package fasthttp
import (
"bytes"
"errors"
"io"
"iter"
"sort"
"sync"
)
const (
argsNoValue = true
argsHasValue = false
)
// AcquireArgs returns an empty Args object from the pool.
//
// The returned Args may be returned to the pool with ReleaseArgs
// when no longer needed. This reduces GC load.
func AcquireArgs() *Args {
return argsPool.Get().(*Args)
}
// ReleaseArgs returns the object acquired via AcquireArgs to the pool.
//
// Do not access the released Args object, otherwise data races may occur.
func ReleaseArgs(a *Args) {
a.Reset()
argsPool.Put(a)
}
var argsPool = &sync.Pool{
New: func() any {
return &Args{}
},
}
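// Typical pooled usage (an illustrative sketch, not part of the package API):
//
//	a := AcquireArgs()
//	a.Set("q", "fasthttp")
//	qs := string(a.QueryString()) // qs == "q=fasthttp"
//	ReleaseArgs(a)                // a must not be touched after this point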
// Args represents query arguments.
//
// Copying Args instances is forbidden. Create new instances
// and use CopyTo() instead.
//
// Args instance MUST NOT be used from concurrently running goroutines.
type Args struct {
noCopy noCopy
args []argsKV
buf []byte
}
type argsKV struct {
key []byte
value []byte
noValue bool
}
// Reset clears query args.
func (a *Args) Reset() {
a.args = a.args[:0]
}
// CopyTo copies all args to dst.
func (a *Args) CopyTo(dst *Args) {
dst.args = copyArgs(dst.args, a.args)
}
// All returns an iterator over key-value pairs from args.
//
// The key and value may be invalid outside the iteration loop.
// Make copies if you need to use them after the loop ends.
func (a *Args) All() iter.Seq2[[]byte, []byte] {
return func(yield func([]byte, []byte) bool) {
for i := range a.args {
if !yield(a.args[i].key, a.args[i].value) {
break
}
}
}
}
// VisitAll calls f for each existing arg.
//
// f must not retain references to key and value after returning.
// Make key and/or value copies if you need to store them after returning.
//
// Deprecated: Use All instead.
func (a *Args) VisitAll(f func(key, value []byte)) {
a.All()(func(key, value []byte) bool {
f(key, value)
return true
})
}
// Len returns the number of query args.
func (a *Args) Len() int {
return len(a.args)
}
// Parse parses the given string containing query args.
func (a *Args) Parse(s string) {
a.buf = append(a.buf[:0], s...)
a.ParseBytes(a.buf)
}
// ParseBytes parses the given b containing query args.
func (a *Args) ParseBytes(b []byte) {
a.Reset()
var s argsScanner
s.b = b
var kv *argsKV
a.args, kv = allocArg(a.args)
for s.next(kv) {
if len(kv.key) > 0 || len(kv.value) > 0 {
a.args, kv = allocArg(a.args)
}
}
a.args = releaseArg(a.args)
}
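// For illustration, parsing a raw query string behaves as follows
// (a hedged sketch of the scanning and decoding semantics above):
//
//	var a Args
//	a.Parse("x=1&y=hello%20world&flag")
//	a.Peek("y")   // []byte("hello world") - percent-decoded
//	a.Has("flag") // true; "flag" was scanned as a key without a value
//	a.Len()       // 3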
// String returns string representation of query args.
func (a *Args) String() string {
return string(a.QueryString())
}
// QueryString returns query string for the args.
//
// The returned value is valid until the Args is reused or released (ReleaseArgs).
// Do not store references to the returned value. Make copies instead.
func (a *Args) QueryString() []byte {
a.buf = a.AppendBytes(a.buf[:0])
return a.buf
}
// Sort sorts Args by key and then value using 'f' as comparison function.
//
// For example args.Sort(bytes.Compare).
func (a *Args) Sort(f func(x, y []byte) int) {
sort.SliceStable(a.args, func(i, j int) bool {
n := f(a.args[i].key, a.args[j].key)
if n == 0 {
return f(a.args[i].value, a.args[j].value) == -1
}
return n == -1
})
}
// AppendBytes appends query string to dst and returns the extended dst.
func (a *Args) AppendBytes(dst []byte) []byte {
for i, n := 0, len(a.args); i < n; i++ {
kv := &a.args[i]
dst = AppendQuotedArg(dst, kv.key)
if !kv.noValue {
dst = append(dst, '=')
if len(kv.value) > 0 {
dst = AppendQuotedArg(dst, kv.value)
}
}
if i+1 < n {
dst = append(dst, '&')
}
}
return dst
}
// WriteTo writes query string to w.
//
// WriteTo implements io.WriterTo interface.
func (a *Args) WriteTo(w io.Writer) (int64, error) {
n, err := w.Write(a.QueryString())
return int64(n), err
}
// Del deletes argument with the given key from query args.
func (a *Args) Del(key string) {
a.args = delAllArgsStable(a.args, key)
}
// DelBytes deletes argument with the given key from query args.
func (a *Args) DelBytes(key []byte) {
a.args = delAllArgsStable(a.args, b2s(key))
}
// Add adds 'key=value' argument.
//
// Multiple values for the same key may be added.
func (a *Args) Add(key, value string) {
a.args = appendArg(a.args, key, value, argsHasValue)
}
// AddBytesK adds 'key=value' argument.
//
// Multiple values for the same key may be added.
func (a *Args) AddBytesK(key []byte, value string) {
a.args = appendArg(a.args, b2s(key), value, argsHasValue)
}
// AddBytesV adds 'key=value' argument.
//
// Multiple values for the same key may be added.
func (a *Args) AddBytesV(key string, value []byte) {
a.args = appendArg(a.args, key, b2s(value), argsHasValue)
}
// AddBytesKV adds 'key=value' argument.
//
// Multiple values for the same key may be added.
func (a *Args) AddBytesKV(key, value []byte) {
a.args = appendArg(a.args, b2s(key), b2s(value), argsHasValue)
}
// AddNoValue adds only 'key' as argument without the '='.
//
// Multiple values for the same key may be added.
func (a *Args) AddNoValue(key string) {
a.args = appendArg(a.args, key, "", argsNoValue)
}
// AddBytesKNoValue adds only 'key' as argument without the '='.
//
// Multiple values for the same key may be added.
func (a *Args) AddBytesKNoValue(key []byte) {
a.args = appendArg(a.args, b2s(key), "", argsNoValue)
}
// Set sets 'key=value' argument.
func (a *Args) Set(key, value string) {
a.args = setArg(a.args, key, value, argsHasValue)
}
// SetBytesK sets 'key=value' argument.
func (a *Args) SetBytesK(key []byte, value string) {
a.args = setArg(a.args, b2s(key), value, argsHasValue)
}
// SetBytesV sets 'key=value' argument.
func (a *Args) SetBytesV(key string, value []byte) {
a.args = setArg(a.args, key, b2s(value), argsHasValue)
}
// SetBytesKV sets 'key=value' argument.
func (a *Args) SetBytesKV(key, value []byte) {
a.args = setArgBytes(a.args, key, value, argsHasValue)
}
// SetNoValue sets only 'key' as argument without the '='.
//
// Only the key is written to the query string, e.g. key1&key2.
func (a *Args) SetNoValue(key string) {
a.args = setArg(a.args, key, "", argsNoValue)
}
// SetBytesKNoValue sets 'key' argument.
func (a *Args) SetBytesKNoValue(key []byte) {
a.args = setArg(a.args, b2s(key), "", argsNoValue)
}
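// A key set via SetNoValue serializes without a trailing '='
// (illustrative sketch):
//
//	var a Args
//	a.Set("k1", "v1")
//	a.SetNoValue("k2")
//	a.String() // "k1=v1&k2"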
// Peek returns query arg value for the given key.
//
// The returned value is valid until the Args is reused or released (ReleaseArgs).
// Do not store references to the returned value. Make copies instead.
func (a *Args) Peek(key string) []byte {
return peekArgStr(a.args, key)
}
// PeekBytes returns query arg value for the given key.
//
// The returned value is valid until the Args is reused or released (ReleaseArgs).
// Do not store references to the returned value. Make copies instead.
func (a *Args) PeekBytes(key []byte) []byte {
return peekArgBytes(a.args, key)
}
// PeekMulti returns all the arg values for the given key.
func (a *Args) PeekMulti(key string) [][]byte {
var values [][]byte
for k, v := range a.All() {
if string(k) == key {
values = append(values, v)
}
}
return values
}
// PeekMultiBytes returns all the arg values for the given key.
func (a *Args) PeekMultiBytes(key []byte) [][]byte {
return a.PeekMulti(b2s(key))
}
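// PeekMulti collects every value stored under a repeated key
// (illustrative sketch):
//
//	var a Args
//	a.Add("tag", "go")
//	a.Add("tag", "http")
//	a.PeekMulti("tag") // [][]byte{[]byte("go"), []byte("http")}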
// Has returns true if the given key exists in Args.
func (a *Args) Has(key string) bool {
return hasArg(a.args, key)
}
// HasBytes returns true if the given key exists in Args.
func (a *Args) HasBytes(key []byte) bool {
return hasArg(a.args, b2s(key))
}
// ErrNoArgValue is returned when Args value with the given key is missing.
var ErrNoArgValue = errors.New("no Args value for the given key")
// GetUint returns uint value for the given key.
func (a *Args) GetUint(key string) (int, error) {
value := a.Peek(key)
if len(value) == 0 {
return -1, ErrNoArgValue
}
return ParseUint(value)
}
// SetUint sets uint value for the given key.
func (a *Args) SetUint(key string, value int) {
a.buf = AppendUint(a.buf[:0], value)
a.SetBytesV(key, a.buf)
}
// SetUintBytes sets uint value for the given key.
func (a *Args) SetUintBytes(key []byte, value int) {
a.SetUint(b2s(key), value)
}
// GetUintOrZero returns uint value for the given key.
//
// Zero (0) is returned on error.
func (a *Args) GetUintOrZero(key string) int {
n, err := a.GetUint(key)
if err != nil {
n = 0
}
return n
}
// GetUfloat returns ufloat value for the given key.
func (a *Args) GetUfloat(key string) (float64, error) {
value := a.Peek(key)
if len(value) == 0 {
return -1, ErrNoArgValue
}
return ParseUfloat(value)
}
// GetUfloatOrZero returns ufloat value for the given key.
//
// Zero (0) is returned on error.
func (a *Args) GetUfloatOrZero(key string) float64 {
f, err := a.GetUfloat(key)
if err != nil {
f = 0
}
return f
}
// GetBool returns boolean value for the given key.
//
// true is returned for "1", "t", "T", "true", "TRUE", "True", "y", "yes", "Y", "YES", "Yes",
// otherwise false is returned.
func (a *Args) GetBool(key string) bool {
switch string(a.Peek(key)) {
// Support the same true cases as strconv.ParseBool
// See: https://github.com/golang/go/blob/4e1b11e2c9bdb0ddea1141eed487be1a626ff5be/src/strconv/atob.go#L12
// and Y and Yes versions.
case "1", "t", "T", "true", "TRUE", "True", "y", "yes", "Y", "YES", "Yes":
return true
default:
return false
}
}
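// For example (illustrative):
//
//	var a Args
//	a.Parse("debug=yes&verbose=0")
//	a.GetBool("debug")   // true
//	a.GetBool("verbose") // false - "0" is not an accepted true value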
func copyArgs(dst, src []argsKV) []argsKV {
if cap(dst) < len(src) {
tmp := make([]argsKV, len(src))
dstLen := len(dst)
dst = dst[:cap(dst)] // copy all of dst.
copy(tmp, dst)
for i := dstLen; i < len(tmp); i++ {
// Make sure nothing is nil.
tmp[i].key = []byte{}
tmp[i].value = []byte{}
}
dst = tmp
}
n := len(src)
dst = dst[:n]
for i := 0; i < n; i++ {
dstKV := &dst[i]
srcKV := &src[i]
dstKV.key = append(dstKV.key[:0], srcKV.key...)
if srcKV.noValue {
dstKV.value = dstKV.value[:0]
} else {
dstKV.value = append(dstKV.value[:0], srcKV.value...)
}
dstKV.noValue = srcKV.noValue
}
return dst
}
func delAllArgsStable(args []argsKV, key string) []argsKV {
for i, n := 0, len(args); i < n; i++ {
kv := &args[i]
if key == string(kv.key) {
tmp := *kv
copy(args[i:], args[i+1:])
n--
i--
args[n] = tmp
args = args[:n]
}
}
return args
}
func delAllArgs(args []argsKV, key string) []argsKV {
n := len(args)
for i := 0; i < n; i++ {
if key == string(args[i].key) {
args[i], args[n-1] = args[n-1], args[i]
n--
i--
}
}
return args[:n]
}
func setArgBytes(h []argsKV, key, value []byte, noValue bool) []argsKV {
return setArg(h, b2s(key), b2s(value), noValue)
}
func setArg(h []argsKV, key, value string, noValue bool) []argsKV {
n := len(h)
for i := 0; i < n; i++ {
kv := &h[i]
if key == string(kv.key) {
if noValue {
kv.value = kv.value[:0]
} else {
kv.value = append(kv.value[:0], value...)
}
kv.noValue = noValue
return h
}
}
return appendArg(h, key, value, noValue)
}
func appendArgBytes(h []argsKV, key, value []byte, noValue bool) []argsKV {
return appendArg(h, b2s(key), b2s(value), noValue)
}
func appendArg(args []argsKV, key, value string, noValue bool) []argsKV {
var kv *argsKV
args, kv = allocArg(args)
kv.key = append(kv.key[:0], key...)
if noValue {
kv.value = kv.value[:0]
} else {
kv.value = append(kv.value[:0], value...)
}
kv.noValue = noValue
return args
}
func allocArg(h []argsKV) ([]argsKV, *argsKV) {
n := len(h)
if cap(h) > n {
h = h[:n+1]
} else {
h = append(h, argsKV{
value: []byte{},
})
}
return h, &h[n]
}
func releaseArg(h []argsKV) []argsKV {
return h[:len(h)-1]
}
func hasArg(h []argsKV, key string) bool {
for i, n := 0, len(h); i < n; i++ {
kv := &h[i]
if key == string(kv.key) {
return true
}
}
return false
}
func peekArgBytes(h []argsKV, k []byte) []byte {
for i, n := 0, len(h); i < n; i++ {
kv := &h[i]
if bytes.Equal(kv.key, k) {
return kv.value
}
}
return nil
}
func peekArgStr(h []argsKV, k string) []byte {
for i, n := 0, len(h); i < n; i++ {
kv := &h[i]
if string(kv.key) == k {
return kv.value
}
}
return nil
}
type argsScanner struct {
b []byte
}
func (s *argsScanner) next(kv *argsKV) bool {
if len(s.b) == 0 {
return false
}
kv.noValue = argsHasValue
isKey := true
k := 0
for i, c := range s.b {
switch c {
case '=':
if isKey {
isKey = false
kv.key = decodeArgAppend(kv.key[:0], s.b[:i])
k = i + 1
}
case '&':
if isKey {
kv.key = decodeArgAppend(kv.key[:0], s.b[:i])
kv.value = kv.value[:0]
kv.noValue = argsNoValue
} else {
kv.value = decodeArgAppend(kv.value[:0], s.b[k:i])
}
s.b = s.b[i+1:]
return true
}
}
if isKey {
kv.key = decodeArgAppend(kv.key[:0], s.b)
kv.value = kv.value[:0]
kv.noValue = argsNoValue
} else {
kv.value = decodeArgAppend(kv.value[:0], s.b[k:])
}
s.b = s.b[len(s.b):]
return true
}
func decodeArgAppend(dst, src []byte) []byte {
idxPercent := bytes.IndexByte(src, '%')
idxPlus := bytes.IndexByte(src, '+')
if idxPercent == -1 && idxPlus == -1 {
// fast path: src doesn't contain encoded chars
return append(dst, src...)
}
var idx int
switch {
case idxPercent == -1:
idx = idxPlus
case idxPlus == -1:
idx = idxPercent
case idxPercent > idxPlus:
idx = idxPlus
default:
idx = idxPercent
}
dst = append(dst, src[:idx]...)
// slow path
for i := idx; i < len(src); i++ {
c := src[i]
switch c {
case '%':
if i+2 >= len(src) {
return append(dst, src[i:]...)
}
x2 := hex2intTable[src[i+2]]
x1 := hex2intTable[src[i+1]]
if x1 == 16 || x2 == 16 {
dst = append(dst, '%')
} else {
dst = append(dst, x1<<4|x2)
i += 2
}
case '+':
dst = append(dst, ' ')
default:
dst = append(dst, c)
}
}
return dst
}
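// Illustrative behavior of decodeArgAppend: '+' decodes to a space and
// malformed or truncated percent escapes are kept as-is (a sketch derived
// from the loop above):
//
//	decodeArgAppend(nil, []byte("a+b%20c")) // []byte("a b c")
//	decodeArgAppend(nil, []byte("100%"))    // []byte("100%") - incomplete escape preserved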
// decodeArgAppendNoPlus is almost identical to decodeArgAppend, but it doesn't
// substitute '+' with ' '.
//
// The function is copy-pasted from decodeArgAppend for performance
// reasons only.
func decodeArgAppendNoPlus(dst, src []byte) []byte {
idx := bytes.IndexByte(src, '%')
if idx < 0 {
// fast path: src doesn't contain encoded chars
return append(dst, src...)
}
dst = append(dst, src[:idx]...)
// slow path
for i := idx; i < len(src); i++ {
c := src[i]
if c == '%' {
if i+2 >= len(src) {
return append(dst, src[i:]...)
}
x2 := hex2intTable[src[i+2]]
x1 := hex2intTable[src[i+1]]
if x1 == 16 || x2 == 16 {
dst = append(dst, '%')
} else {
dst = append(dst, x1<<4|x2)
i += 2
}
} else {
dst = append(dst, c)
}
}
return dst
}
func peekAllArgBytesToDst(dst [][]byte, h []argsKV, k []byte) [][]byte {
for i, n := 0, len(h); i < n; i++ {
kv := &h[i]
if bytes.Equal(kv.key, k) {
dst = append(dst, kv.value)
}
}
return dst
}
func peekArgsKeys(dst [][]byte, h []argsKV) [][]byte {
for i, n := 0, len(h); i < n; i++ {
kv := &h[i]
dst = append(dst, kv.key)
}
return dst
}
package fasthttp
import "unsafe"
// b2s converts byte slice to a string without memory allocation.
// See https://groups.google.com/forum/#!msg/Golang-Nuts/ENgbUzYvCuU/90yGx7GUAgAJ .
func b2s(b []byte) string {
return unsafe.String(unsafe.SliceData(b), len(b))
}
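// Note that the returned string aliases b's memory: mutating b afterwards
// changes the string. An illustrative (and intentionally wrong) sketch:
//
//	b := []byte("key")
//	s := b2s(b) // s == "key"
//	b[0] = 'z'  // s silently becomes "zey" - callers must keep b immutable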
package fasthttp
import (
"bytes"
"fmt"
"io"
"sync"
"github.com/andybalholm/brotli"
"github.com/valyala/bytebufferpool"
"github.com/valyala/fasthttp/stackless"
)
// Supported compression levels.
const (
CompressBrotliNoCompression = 0
CompressBrotliBestSpeed = brotli.BestSpeed
CompressBrotliBestCompression = brotli.BestCompression
// Choose a default brotli compression level comparable to
// CompressDefaultCompression (gzip 6)
// See: https://github.com/valyala/fasthttp/issues/798#issuecomment-626293806
CompressBrotliDefaultCompression = 4
)
func acquireBrotliReader(r io.Reader) (*brotli.Reader, error) {
v := brotliReaderPool.Get()
if v == nil {
return brotli.NewReader(r), nil
}
zr := v.(*brotli.Reader)
if err := zr.Reset(r); err != nil {
return nil, err
}
return zr, nil
}
func releaseBrotliReader(zr *brotli.Reader) {
brotliReaderPool.Put(zr)
}
var brotliReaderPool sync.Pool
func acquireStacklessBrotliWriter(w io.Writer, level int) stackless.Writer {
nLevel := normalizeBrotliCompressLevel(level)
p := stacklessBrotliWriterPoolMap[nLevel]
v := p.Get()
if v == nil {
return stackless.NewWriter(w, func(w io.Writer) stackless.Writer {
return acquireRealBrotliWriter(w, level)
})
}
sw := v.(stackless.Writer)
sw.Reset(w)
return sw
}
func releaseStacklessBrotliWriter(sw stackless.Writer, level int) {
sw.Close()
nLevel := normalizeBrotliCompressLevel(level)
p := stacklessBrotliWriterPoolMap[nLevel]
p.Put(sw)
}
func acquireRealBrotliWriter(w io.Writer, level int) *brotli.Writer {
nLevel := normalizeBrotliCompressLevel(level)
p := realBrotliWriterPoolMap[nLevel]
v := p.Get()
if v == nil {
zw := brotli.NewWriterLevel(w, level)
return zw
}
zw := v.(*brotli.Writer)
zw.Reset(w)
return zw
}
func releaseRealBrotliWriter(zw *brotli.Writer, level int) {
zw.Close()
nLevel := normalizeBrotliCompressLevel(level)
p := realBrotliWriterPoolMap[nLevel]
p.Put(zw)
}
var (
stacklessBrotliWriterPoolMap = newCompressWriterPoolMap()
realBrotliWriterPoolMap = newCompressWriterPoolMap()
)
// AppendBrotliBytesLevel appends brotlied src to dst using the given
// compression level and returns the resulting dst.
//
// Supported compression levels are:
//
// - CompressBrotliNoCompression
// - CompressBrotliBestSpeed
// - CompressBrotliBestCompression
// - CompressBrotliDefaultCompression
func AppendBrotliBytesLevel(dst, src []byte, level int) []byte {
w := &byteSliceWriter{b: dst}
WriteBrotliLevel(w, src, level) //nolint:errcheck
return w.b
}
// WriteBrotliLevel writes brotlied p to w using the given compression level
// and returns the number of compressed bytes written to w.
//
// Supported compression levels are:
//
// - CompressBrotliNoCompression
// - CompressBrotliBestSpeed
// - CompressBrotliBestCompression
// - CompressBrotliDefaultCompression
func WriteBrotliLevel(w io.Writer, p []byte, level int) (int, error) {
switch w.(type) {
case *byteSliceWriter,
*bytes.Buffer,
*bytebufferpool.ByteBuffer:
// These writers don't block, so we can just use stacklessWriteBrotli
ctx := &compressCtx{
w: w,
p: p,
level: level,
}
stacklessWriteBrotli(ctx)
return len(p), nil
default:
zw := acquireStacklessBrotliWriter(w, level)
n, err := zw.Write(p)
releaseStacklessBrotliWriter(zw, level)
return n, err
}
}
var (
stacklessWriteBrotliOnce sync.Once
stacklessWriteBrotliFunc func(ctx any) bool
)
func stacklessWriteBrotli(ctx any) {
stacklessWriteBrotliOnce.Do(func() {
stacklessWriteBrotliFunc = stackless.NewFunc(nonblockingWriteBrotli)
})
stacklessWriteBrotliFunc(ctx)
}
func nonblockingWriteBrotli(ctxv any) {
ctx := ctxv.(*compressCtx)
zw := acquireRealBrotliWriter(ctx.w, ctx.level)
zw.Write(ctx.p) //nolint:errcheck // no way to handle this error anyway
releaseRealBrotliWriter(zw, ctx.level)
}
// WriteBrotli writes brotlied p to w and returns the number of compressed
// bytes written to w.
func WriteBrotli(w io.Writer, p []byte) (int, error) {
return WriteBrotliLevel(w, p, CompressBrotliDefaultCompression)
}
// AppendBrotliBytes appends brotlied src to dst and returns the resulting dst.
func AppendBrotliBytes(dst, src []byte) []byte {
return AppendBrotliBytesLevel(dst, src, CompressBrotliDefaultCompression)
}
// WriteUnbrotli writes unbrotlied p to w and returns the number of uncompressed
// bytes written to w.
func WriteUnbrotli(w io.Writer, p []byte) (int, error) {
r := &byteSliceReader{b: p}
zr, err := acquireBrotliReader(r)
if err != nil {
return 0, err
}
n, err := copyZeroAlloc(w, zr)
releaseBrotliReader(zr)
nn := int(n)
if int64(nn) != n {
return 0, fmt.Errorf("too much data unbrotlied: %d", n)
}
return nn, err
}
// AppendUnbrotliBytes appends unbrotlied src to dst and returns the resulting dst.
func AppendUnbrotliBytes(dst, src []byte) ([]byte, error) {
w := &byteSliceWriter{b: dst}
_, err := WriteUnbrotli(w, src)
return w.b, err
}
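// A compress/decompress round trip using the helpers above (a sketch):
//
//	compressed := AppendBrotliBytes(nil, []byte("hello brotli"))
//	plain, err := AppendUnbrotliBytes(nil, compressed)
//	// err == nil && string(plain) == "hello brotli"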
// normalizeBrotliCompressLevel normalizes the compression level into
// [0..11], so it can be used as an index in *PoolMap.
func normalizeBrotliCompressLevel(level int) int {
// Valid brotli levels range from 0 (CompressBrotliNoCompression)
// to 11 (CompressBrotliBestCompression).
if level < 0 || level > 11 {
level = CompressBrotliDefaultCompression
}
return level
}
//go:generate go run bytesconv_table_gen.go
package fasthttp
import (
"bufio"
"bytes"
"errors"
"fmt"
"io"
"net"
"strconv"
"sync"
"time"
)
// AppendHTMLEscape appends html-escaped s to dst and returns the extended dst.
func AppendHTMLEscape(dst []byte, s string) []byte {
var (
prev int
sub string
)
for i, n := 0, len(s); i < n; i++ {
sub = ""
switch s[i] {
case '&':
sub = "&amp;"
case '<':
sub = "&lt;"
case '>':
sub = "&gt;"
case '"':
sub = "&#34;" // "&#34;" is shorter than "&quot;".
case '\'':
sub = "&#39;" // "&#39;" is shorter than "&apos;" and apos was not in HTML until HTML5.
}
if sub != "" {
dst = append(dst, s[prev:i]...)
dst = append(dst, sub...)
prev = i + 1
}
}
return append(dst, s[prev:]...)
}
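// For example (illustrative):
//
//	AppendHTMLEscape(nil, `<a href="x">`) // []byte(`&lt;a href=&#34;x&#34;&gt;`)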
// AppendHTMLEscapeBytes appends html-escaped s to dst and returns
// the extended dst.
func AppendHTMLEscapeBytes(dst, s []byte) []byte {
return AppendHTMLEscape(dst, b2s(s))
}
// AppendIPv4 appends string representation of the given ip v4 to dst
// and returns the extended dst.
func AppendIPv4(dst []byte, ip net.IP) []byte {
ip = ip.To4()
if ip == nil {
return append(dst, "non-v4 ip passed to AppendIPv4"...)
}
dst = AppendUint(dst, int(ip[0]))
for i := 1; i < 4; i++ {
dst = append(dst, '.')
dst = AppendUint(dst, int(ip[i]))
}
return dst
}
var errEmptyIPStr = errors.New("empty ip address string")
// ParseIPv4 parses ip address from ipStr into dst and returns the extended dst.
func ParseIPv4(dst net.IP, ipStr []byte) (net.IP, error) {
if len(ipStr) == 0 {
return dst, errEmptyIPStr
}
if len(dst) != net.IPv4len {
dst = make([]byte, net.IPv4len)
}
copy(dst, net.IPv4zero)
dst = dst.To4() // dst is always non-nil here
b := ipStr
for i := 0; i < 3; i++ {
n := bytes.IndexByte(b, '.')
if n < 0 {
return dst, fmt.Errorf("cannot find dot in ipStr %q", ipStr)
}
v, err := ParseUint(b[:n])
if err != nil {
return dst, fmt.Errorf("cannot parse ipStr %q: %w", ipStr, err)
}
if v > 255 {
return dst, fmt.Errorf("cannot parse ipStr %q: ip part cannot exceed 255: parsed %d", ipStr, v)
}
dst[i] = byte(v)
b = b[n+1:]
}
v, err := ParseUint(b)
if err != nil {
return dst, fmt.Errorf("cannot parse ipStr %q: %w", ipStr, err)
}
if v > 255 {
return dst, fmt.Errorf("cannot parse ipStr %q: ip part cannot exceed 255: parsed %d", ipStr, v)
}
dst[3] = byte(v)
return dst, nil
}
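// For example (illustrative):
//
//	ip, err := ParseIPv4(nil, []byte("127.0.0.1"))
//	// err == nil && ip.String() == "127.0.0.1"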
// AppendHTTPDate appends HTTP-compliant (RFC1123) representation of date
// to dst and returns the extended dst.
func AppendHTTPDate(dst []byte, date time.Time) []byte {
dst = date.In(time.UTC).AppendFormat(dst, time.RFC1123)
copy(dst[len(dst)-3:], strGMT)
return dst
}
// ParseHTTPDate parses HTTP-compliant (RFC1123) date.
func ParseHTTPDate(date []byte) (time.Time, error) {
return time.Parse(time.RFC1123, b2s(date))
}
// AppendUint appends n to dst and returns the extended dst.
func AppendUint(dst []byte, n int) []byte {
if n < 0 {
// developer sanity-check
panic("BUG: int must be positive")
}
return strconv.AppendUint(dst, uint64(n), 10)
}
// ParseUint parses uint from buf.
func ParseUint(buf []byte) (int, error) {
v, n, err := parseUintBuf(buf)
if n != len(buf) {
return -1, errUnexpectedTrailingChar
}
return v, err
}
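// Illustrative results (a sketch of the parsing rules above):
//
//	ParseUint([]byte("1234")) // 1234, nil
//	ParseUint([]byte("123a")) // -1, errUnexpectedTrailingChar
//	ParseUint([]byte(""))     // -1, errEmptyInt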
var (
errEmptyInt = errors.New("empty integer")
errUnexpectedFirstChar = errors.New("unexpected first char found. Expecting 0-9")
errUnexpectedTrailingChar = errors.New("unexpected trailing char found. Expecting 0-9")
errTooLongInt = errors.New("too long int")
)
func parseUintBuf(b []byte) (int, int, error) {
n := len(b)
if n == 0 {
return -1, 0, errEmptyInt
}
v := 0
for i := 0; i < n; i++ {
c := b[i]
k := c - '0'
if k > 9 {
if i == 0 {
return -1, i, errUnexpectedFirstChar
}
return v, i, nil
}
vNew := 10*v + int(k)
// Test for overflow.
if vNew < v {
return -1, i, errTooLongInt
}
v = vNew
}
return v, n, nil
}
// ParseUfloat parses unsigned float from buf.
func ParseUfloat(buf []byte) (float64, error) {
// The implementation of parsing a float string is not easy.
// We believe that the conservative approach is to call strconv.ParseFloat.
// https://github.com/valyala/fasthttp/pull/1865
res, err := strconv.ParseFloat(b2s(buf), 64)
if res < 0 {
return -1, errors.New("negative input is invalid")
}
if err != nil {
return -1, err
}
return res, nil
}
var (
errEmptyHexNum = errors.New("empty hex number")
errTooLargeHexNum = errors.New("too large hex number")
)
func readHexInt(r *bufio.Reader) (int, error) {
var k, i, n int
for {
c, err := r.ReadByte()
if err != nil {
if err == io.EOF && i > 0 {
return n, nil
}
return -1, err
}
k = int(hex2intTable[c])
if k == 16 {
if i == 0 {
return -1, errEmptyHexNum
}
if err := r.UnreadByte(); err != nil {
return -1, err
}
return n, nil
}
if i >= maxHexIntChars {
return -1, errTooLargeHexNum
}
n = (n << 4) | k
i++
}
}
var hexIntBufPool sync.Pool
func writeHexInt(w *bufio.Writer, n int) error {
if n < 0 {
// developer sanity-check
panic("BUG: int must be positive")
}
v := hexIntBufPool.Get()
if v == nil {
v = make([]byte, maxHexIntChars+1)
}
buf := v.([]byte)
i := len(buf) - 1
for {
buf[i] = lowerhex[n&0xf]
n >>= 4
if n == 0 {
break
}
i--
}
_, err := w.Write(buf[i:])
hexIntBufPool.Put(v)
return err
}
const (
upperhex = "0123456789ABCDEF"
lowerhex = "0123456789abcdef"
)
func lowercaseBytes(b []byte) {
for i := 0; i < len(b); i++ {
p := &b[i]
*p = toLowerTable[*p]
}
}
// AppendUnquotedArg appends url-decoded src to dst and returns appended dst.
//
// dst may point to src. In this case src will be overwritten.
func AppendUnquotedArg(dst, src []byte) []byte {
return decodeArgAppend(dst, src)
}
// AppendQuotedArg appends url-encoded src to dst and returns appended dst.
func AppendQuotedArg(dst, src []byte) []byte {
for _, c := range src {
switch {
case c == ' ':
dst = append(dst, '+')
case quotedArgShouldEscapeTable[int(c)] != 0:
dst = append(dst, '%', upperhex[c>>4], upperhex[c&0xf])
default:
dst = append(dst, c)
}
}
return dst
}
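// Illustrative encoding, assuming the generated escape table marks query
// delimiters such as '&' and '=' for escaping:
//
//	AppendQuotedArg(nil, []byte("a b&c")) // []byte("a+b%26c")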
func appendQuotedPath(dst, src []byte) []byte {
// Fix issue in https://github.com/golang/go/issues/11202
if len(src) == 1 && src[0] == '*' {
return append(dst, '*')
}
for _, c := range src {
if quotedPathShouldEscapeTable[int(c)] != 0 {
dst = append(dst, '%', upperhex[c>>4], upperhex[c&0xf])
} else {
dst = append(dst, c)
}
}
return dst
}
package fasthttp
import (
"bufio"
"bytes"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"strings"
"sync"
"sync/atomic"
"time"
)
// Do performs the given http request and fills the given http response.
//
// Request must contain at least non-zero RequestURI with full url (including
// scheme and host) or non-zero Host header + RequestURI.
//
// Client determines the server to be requested in the following order:
//
// - from RequestURI if it contains full url with scheme and host;
// - from Host header otherwise.
//
// The function doesn't follow redirects. Use Get* for following redirects.
//
// Response is ignored if resp is nil.
//
// ErrNoFreeConns is returned if all DefaultMaxConnsPerHost connections
// to the requested host are busy.
//
// It is recommended to obtain req and resp via AcquireRequest
// and AcquireResponse in performance-critical code.
func Do(req *Request, resp *Response) error {
return defaultClient.Do(req, resp)
}
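// Typical usage with pooled request/response objects (an illustrative
// sketch; example.com stands in for a real host):
//
//	req := AcquireRequest()
//	resp := AcquireResponse()
//	req.SetRequestURI("http://example.com/")
//	err := Do(req, resp)
//	// inspect resp.StatusCode() and resp.Body() before releasing
//	ReleaseRequest(req)
//	ReleaseResponse(resp)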
// DoTimeout performs the given request and waits for response during
// the given timeout duration.
//
// Request must contain at least non-zero RequestURI with full url (including
// scheme and host) or non-zero Host header + RequestURI.
//
// Client determines the server to be requested in the following order:
//
// - from RequestURI if it contains full url with scheme and host;
// - from Host header otherwise.
//
// The function doesn't follow redirects. Use Get* for following redirects.
//
// Response is ignored if resp is nil.
//
// ErrTimeout is returned if the response wasn't returned during
// the given timeout.
//
// ErrNoFreeConns is returned if all DefaultMaxConnsPerHost connections
// to the requested host are busy.
//
// It is recommended to obtain req and resp via AcquireRequest
// and AcquireResponse in performance-critical code.
func DoTimeout(req *Request, resp *Response, timeout time.Duration) error {
return defaultClient.DoTimeout(req, resp, timeout)
}
// DoDeadline performs the given request and waits for response until
// the given deadline.
//
// Request must contain at least non-zero RequestURI with full url (including
// scheme and host) or non-zero Host header + RequestURI.
//
// Client determines the server to be requested in the following order:
//
// - from RequestURI if it contains full url with scheme and host;
// - from Host header otherwise.
//
// The function doesn't follow redirects. Use Get* for following redirects.
//
// Response is ignored if resp is nil.
//
// ErrTimeout is returned if the response wasn't returned until
// the given deadline.
//
// ErrNoFreeConns is returned if all DefaultMaxConnsPerHost connections
// to the requested host are busy.
//
// It is recommended to obtain req and resp via AcquireRequest
// and AcquireResponse in performance-critical code.
func DoDeadline(req *Request, resp *Response, deadline time.Time) error {
return defaultClient.DoDeadline(req, resp, deadline)
}
// DoRedirects performs the given http request and fills the given http response,
// following up to maxRedirectsCount redirects. When the redirect count exceeds
// maxRedirectsCount, ErrTooManyRedirects is returned.
//
// Request must contain at least non-zero RequestURI with full url (including
// scheme and host) or non-zero Host header + RequestURI.
//
// Client determines the server to be requested in the following order:
//
// - from RequestURI if it contains full url with scheme and host;
// - from Host header otherwise.
//
// Response is ignored if resp is nil.
//
// ErrNoFreeConns is returned if all DefaultMaxConnsPerHost connections
// to the requested host are busy.
//
// It is recommended to obtain req and resp via AcquireRequest
// and AcquireResponse in performance-critical code.
func DoRedirects(req *Request, resp *Response, maxRedirectsCount int) error {
if defaultClient.DisablePathNormalizing {
req.URI().DisablePathNormalizing = true
}
_, _, err := doRequestFollowRedirects(req, resp, req.URI().String(), maxRedirectsCount, &defaultClient)
return err
}
// Get returns the status code and body of url.
//
// The contents of dst will be replaced by the body and returned; if dst
// is too small, a new slice will be allocated.
//
// The function follows redirects. Use Do* for manually handling redirects.
func Get(dst []byte, url string) (statusCode int, body []byte, err error) {
return defaultClient.Get(dst, url)
}
// GetTimeout returns the status code and body of url.
//
// The contents of dst will be replaced by the body and returned; if dst
// is too small, a new slice will be allocated.
//
// The function follows redirects. Use Do* for manually handling redirects.
//
// ErrTimeout error is returned if url contents couldn't be fetched
// during the given timeout.
func GetTimeout(dst []byte, url string, timeout time.Duration) (statusCode int, body []byte, err error) {
return defaultClient.GetTimeout(dst, url, timeout)
}
// GetDeadline returns the status code and body of url.
//
// The contents of dst will be replaced by the body and returned; if dst
// is too small, a new slice will be allocated.
//
// The function follows redirects. Use Do* for manually handling redirects.
//
// ErrTimeout error is returned if url contents couldn't be fetched
// until the given deadline.
func GetDeadline(dst []byte, url string, deadline time.Time) (statusCode int, body []byte, err error) {
return defaultClient.GetDeadline(dst, url, deadline)
}
// Post sends POST request to the given url with the given POST arguments.
//
// The contents of dst will be replaced by the body and returned; if dst
// is too small, a new slice will be allocated.
//
// The function follows redirects. Use Do* for manually handling redirects.
//
// Empty POST body is sent if postArgs is nil.
func Post(dst []byte, url string, postArgs *Args) (statusCode int, body []byte, err error) {
return defaultClient.Post(dst, url, postArgs)
}
var defaultClient Client
// Client implements an HTTP client.
//
// Copying Client by value is prohibited. Create a new instance instead.
//
// It is safe to call Client methods from concurrently running goroutines.
//
// The fields of a Client should not be changed while it is in use.
type Client struct {
noCopy noCopy
readerPool sync.Pool
writerPool sync.Pool
// Transport defines a transport-like mechanism that wraps every request/response.
Transport RoundTripper
// Callback for establishing new connections to hosts.
//
// Default DialTimeout is used if not set.
DialTimeout DialFuncWithTimeout
// Callback for establishing new connections to hosts.
//
// Note that if Dial is set instead of DialTimeout, Dial will ignore Request timeout.
// If you want the tcp dial process to account for request timeouts, use DialTimeout instead.
//
// If not set, DialTimeout is used.
Dial DialFunc
// TLS config for https connections.
//
// Default TLS config is used if not set.
TLSConfig *tls.Config
// RetryIf controls whether a retry should be attempted after an error.
//
// By default the isIdempotent function is used.
//
// Deprecated: Use RetryIfErr instead.
// This field is only effective when the `RetryIfErr` field is not set.
RetryIf RetryIfFunc
// When the client encounters an error during a request, RetryIfErr
// determines whether to retry and whether to reset the request timeout.
// This field is only effective within the range of MaxIdemponentCallAttempts.
RetryIfErr RetryIfErrFunc
// ConfigureClient configures the fasthttp.HostClient.
ConfigureClient func(hc *HostClient) error
m map[string]*HostClient
ms map[string]*HostClient
// Client name. Used in User-Agent request header.
//
// Default client name is used if not set.
Name string
// Maximum number of connections per each host which may be established.
//
// DefaultMaxConnsPerHost is used if not set.
MaxConnsPerHost int
// Idle keep-alive connections are closed after this duration.
//
// By default idle connections are closed
// after DefaultMaxIdleConnDuration.
MaxIdleConnDuration time.Duration
// Keep-alive connections are closed after this duration.
//
// By default connection duration is unlimited.
MaxConnDuration time.Duration
// Maximum number of attempts for idempotent calls.
//
// DefaultMaxIdemponentCallAttempts is used if not set.
MaxIdemponentCallAttempts int
// Per-connection buffer size for responses' reading.
// This also limits the maximum header size.
//
// Default buffer size is used if 0.
ReadBufferSize int
// Per-connection buffer size for requests' writing.
//
// Default buffer size is used if 0.
WriteBufferSize int
// Maximum duration for full response reading (including body).
//
// By default response read timeout is unlimited.
ReadTimeout time.Duration
// Maximum duration for full request writing (including body).
//
// By default request write timeout is unlimited.
WriteTimeout time.Duration
// Maximum response body size.
//
// The client returns ErrBodyTooLarge if this limit is greater than 0
// and response body is greater than the limit.
//
// By default response body size is unlimited.
MaxResponseBodySize int
// Maximum duration for waiting for a free connection.
//
// By default the client does not wait and returns ErrNoFreeConns immediately.
MaxConnWaitTimeout time.Duration
// Connection pool strategy. Can be either LIFO or FIFO (default).
ConnPoolStrategy ConnPoolStrategyType
mLock sync.RWMutex
mOnce sync.Once
// NoDefaultUserAgentHeader when set to true, causes the default
// User-Agent header to be excluded from the Request.
NoDefaultUserAgentHeader bool
// Attempt to connect to both ipv4 and ipv6 addresses if set to true.
//
// This option is used only if default TCP dialer is used,
// i.e. if Dial is blank.
//
// By default client connects only to ipv4 addresses,
// since unfortunately ipv6 remains broken in many networks worldwide :)
DialDualStack bool
// Header names are passed as-is without normalization
// if this option is set.
//
// Disabled header names' normalization may be useful only for proxying
// responses to other clients expecting case-sensitive
// header names. See https://github.com/valyala/fasthttp/issues/57
// for details.
//
// By default request and response header names are normalized, i.e.
// The first letter and the first letters following dashes
// are uppercased, while all the other letters are lowercased.
// Examples:
//
// * HOST -> Host
// * content-type -> Content-Type
// * cONTENT-lenGTH -> Content-Length
DisableHeaderNamesNormalizing bool
// Path values are sent as-is without normalization.
//
// Disabled path normalization may be useful for proxying incoming requests
// to servers that are expecting paths to be forwarded as-is.
//
// By default path values are normalized, i.e.
// extra slashes are removed, special characters are encoded.
DisablePathNormalizing bool
// StreamResponseBody enables response body streaming.
StreamResponseBody bool
}
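// A minimal configuration sketch (the field values are illustrative,
// not recommendations):
//
//	c := &Client{
//		Name:                "my-client/1.0",
//		MaxConnsPerHost:     1024,
//		ReadTimeout:         5 * time.Second,
//		WriteTimeout:        5 * time.Second,
//		MaxIdleConnDuration: time.Minute,
//	}
//	statusCode, body, err := c.Get(nil, "http://example.com/")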
// Get returns the status code and body of url.
//
// The contents of dst will be replaced by the body and returned; if dst
// is too small, a new slice will be allocated.
//
// The function follows redirects. Use Do* for manually handling redirects.
func (c *Client) Get(dst []byte, url string) (statusCode int, body []byte, err error) {
return clientGetURL(dst, url, c)
}
// GetTimeout returns the status code and body of url.
//
// The contents of dst will be replaced by the body and returned; if dst
// is too small, a new slice will be allocated.
//
// The function follows redirects. Use Do* for manually handling redirects.
//
// ErrTimeout error is returned if url contents couldn't be fetched
// during the given timeout.
func (c *Client) GetTimeout(dst []byte, url string, timeout time.Duration) (statusCode int, body []byte, err error) {
return clientGetURLTimeout(dst, url, timeout, c)
}
// GetDeadline returns the status code and body of url.
//
// The contents of dst will be replaced by the body and returned; if dst
// is too small, a new slice will be allocated.
//
// The function follows redirects. Use Do* for manually handling redirects.
//
// ErrTimeout error is returned if url contents couldn't be fetched
// until the given deadline.
func (c *Client) GetDeadline(dst []byte, url string, deadline time.Time) (statusCode int, body []byte, err error) {
return clientGetURLDeadline(dst, url, deadline, c)
}
// Post sends POST request to the given url with the given POST arguments.
//
// The contents of dst will be replaced by the body and returned; if dst
// is too small, a new slice will be allocated.
//
// The function follows redirects. Use Do* for manually handling redirects.
//
// Empty POST body is sent if postArgs is nil.
func (c *Client) Post(dst []byte, url string, postArgs *Args) (statusCode int, body []byte, err error) {
return clientPostURL(dst, url, postArgs, c)
}
// DoTimeout performs the given request and waits for response during
// the given timeout duration.
//
// Request must contain at least non-zero RequestURI with full url (including
// scheme and host) or non-zero Host header + RequestURI.
//
// Client determines the server to be requested in the following order:
//
// - from RequestURI if it contains full url with scheme and host;
// - from Host header otherwise.
//
// The function doesn't follow redirects. Use Get* for following redirects.
//
// Response is ignored if resp is nil.
//
// ErrTimeout is returned if the response wasn't returned during
// the given timeout.
// Immediately returns ErrTimeout if the timeout value is zero or negative.
//
// ErrNoFreeConns is returned if all Client.MaxConnsPerHost connections
// to the requested host are busy.
//
// It is recommended to obtain req and resp via AcquireRequest
// and AcquireResponse in performance-critical code.
func (c *Client) DoTimeout(req *Request, resp *Response, timeout time.Duration) error {
req.timeout = timeout
if req.timeout <= 0 {
return ErrTimeout
}
return c.Do(req, resp)
}
// DoDeadline performs the given request and waits for response until
// the given deadline.
//
// Request must contain at least non-zero RequestURI with full url (including
// scheme and host) or non-zero Host header + RequestURI.
//
// Client determines the server to be requested in the following order:
//
// - from RequestURI if it contains full url with scheme and host;
// - from Host header otherwise.
//
// The function doesn't follow redirects. Use Get* for following redirects.
//
// Response is ignored if resp is nil.
//
// ErrTimeout is returned if the response wasn't returned until
// the given deadline.
// Immediately returns ErrTimeout if the deadline has already been reached.
//
// ErrNoFreeConns is returned if all Client.MaxConnsPerHost connections
// to the requested host are busy.
//
// It is recommended to obtain req and resp via AcquireRequest
// and AcquireResponse in performance-critical code.
func (c *Client) DoDeadline(req *Request, resp *Response, deadline time.Time) error {
req.timeout = time.Until(deadline)
if req.timeout <= 0 {
return ErrTimeout
}
return c.Do(req, resp)
}
// DoRedirects performs the given http request and fills the given http response,
// following up to maxRedirectsCount redirects. When the redirect count exceeds
// maxRedirectsCount, ErrTooManyRedirects is returned.
//
// Request must contain at least non-zero RequestURI with full url (including
// scheme and host) or non-zero Host header + RequestURI.
//
// Client determines the server to be requested in the following order:
//
// - from RequestURI if it contains full url with scheme and host;
// - from Host header otherwise.
//
// Response is ignored if resp is nil.
//
// ErrNoFreeConns is returned if all DefaultMaxConnsPerHost connections
// to the requested host are busy.
//
// It is recommended to obtain req and resp via AcquireRequest
// and AcquireResponse in performance-critical code.
func (c *Client) DoRedirects(req *Request, resp *Response, maxRedirectsCount int) error {
if c.DisablePathNormalizing {
req.URI().DisablePathNormalizing = true
}
_, _, err := doRequestFollowRedirects(req, resp, req.URI().String(), maxRedirectsCount, c)
return err
}
// Do performs the given http request and fills the given http response.
//
// Request must contain at least non-zero RequestURI with full url (including
// scheme and host) or non-zero Host header + RequestURI.
//
// Client determines the server to be requested in the following order:
//
// - from RequestURI if it contains full url with scheme and host;
// - from Host header otherwise.
//
// Response is ignored if resp is nil.
//
// The function doesn't follow redirects. Use Get* for following redirects.
//
// ErrNoFreeConns is returned if all Client.MaxConnsPerHost connections
// to the requested host are busy.
//
// It is recommended to obtain req and resp via AcquireRequest
// and AcquireResponse in performance-critical code.
func (c *Client) Do(req *Request, resp *Response) error {
uri := req.URI()
if uri == nil {
return ErrorInvalidURI
}
host := uri.Host()
if bytes.ContainsRune(host, ',') {
return fmt.Errorf("invalid host %q. Use HostClient for multiple hosts", host)
}
isTLS := false
if uri.isHTTPS() {
isTLS = true
} else if !uri.isHTTP() {
return fmt.Errorf("unsupported protocol %q. http and https are supported", uri.Scheme())
}
c.mOnce.Do(func() {
c.m = make(map[string]*HostClient)
c.ms = make(map[string]*HostClient)
})
hc, err := c.hostClient(host, isTLS)
if err != nil {
return err
}
atomic.AddInt32(&hc.pendingClientRequests, 1)
defer atomic.AddInt32(&hc.pendingClientRequests, -1)
return hc.Do(req, resp)
}
func (c *Client) hostClient(host []byte, isTLS bool) (*HostClient, error) {
m := c.m
if isTLS {
m = c.ms
}
c.mLock.RLock()
hc, exist := m[string(host)]
c.mLock.RUnlock()
if exist {
return hc, nil
}
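// Slow path: take the write lock and re-check, since another goroutine
// may have registered the HostClient between RUnlock and Lock
// (double-checked locking).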
c.mLock.Lock()
defer c.mLock.Unlock()
hc, exist = m[string(host)]
if exist {
return hc, nil
}
hc = &HostClient{
Addr: AddMissingPort(string(host), isTLS),
Transport: c.Transport,
Name: c.Name,
NoDefaultUserAgentHeader: c.NoDefaultUserAgentHeader,
Dial: c.Dial,
DialTimeout: c.DialTimeout,
DialDualStack: c.DialDualStack,
IsTLS: isTLS,
TLSConfig: c.TLSConfig,
MaxConns: c.MaxConnsPerHost,
MaxIdleConnDuration: c.MaxIdleConnDuration,
MaxConnDuration: c.MaxConnDuration,
MaxIdemponentCallAttempts: c.MaxIdemponentCallAttempts,
ReadBufferSize: c.ReadBufferSize,
WriteBufferSize: c.WriteBufferSize,
ReadTimeout: c.ReadTimeout,
WriteTimeout: c.WriteTimeout,
MaxResponseBodySize: c.MaxResponseBodySize,
DisableHeaderNamesNormalizing: c.DisableHeaderNamesNormalizing,
DisablePathNormalizing: c.DisablePathNormalizing,
MaxConnWaitTimeout: c.MaxConnWaitTimeout,
RetryIf: c.RetryIf,
RetryIfErr: c.RetryIfErr,
ConnPoolStrategy: c.ConnPoolStrategy,
StreamResponseBody: c.StreamResponseBody,
clientReaderPool: &c.readerPool,
clientWriterPool: &c.writerPool,
}
if c.ConfigureClient != nil {
if err := c.ConfigureClient(hc); err != nil {
return nil, err
}
}
m[string(host)] = hc
if len(m) == 1 {
go c.mCleaner(m)
}
return hc, nil
}
// CloseIdleConnections closes any connections which were established by
// previous requests but are now sitting idle in a "keep-alive" state.
// It does not interrupt any connections currently in use.
func (c *Client) CloseIdleConnections() {
c.mLock.RLock()
for _, v := range c.m {
v.CloseIdleConnections()
}
for _, v := range c.ms {
v.CloseIdleConnections()
}
c.mLock.RUnlock()
}
func (c *Client) mCleaner(m map[string]*HostClient) {
mustStop := false
sleep := c.MaxIdleConnDuration
if sleep < time.Second {
sleep = time.Second
} else if sleep > 10*time.Second {
sleep = 10 * time.Second
}
for {
time.Sleep(sleep)
c.mLock.Lock()
for k, v := range m {
v.connsLock.Lock()
if v.connsCount == 0 && atomic.LoadInt32(&v.pendingClientRequests) == 0 {
delete(m, k)
}
v.connsLock.Unlock()
}
if len(m) == 0 {
mustStop = true
}
c.mLock.Unlock()
if mustStop {
break
}
}
}
// DefaultMaxConnsPerHost is the maximum number of concurrent connections
// the HTTP client may establish per host by default (i.e. if
// Client.MaxConnsPerHost isn't set).
const DefaultMaxConnsPerHost = 512
// DefaultMaxIdleConnDuration is the default duration before an idle
// keep-alive connection is closed.
const DefaultMaxIdleConnDuration = 10 * time.Second
// DefaultMaxIdemponentCallAttempts is the default number of attempts for idempotent calls.
const DefaultMaxIdemponentCallAttempts = 5
// DialFunc must establish connection to addr.
//
// There is no need to establish a TLS (SSL) connection for https.
// The client automatically converts connection to TLS
// if HostClient.IsTLS is set.
//
// TCP address passed to DialFunc always contains host and port.
// Example TCP addr values:
//
// - foobar.com:80
// - foobar.com:443
// - foobar.com:8080
type DialFunc func(addr string) (net.Conn, error)
// DialFuncWithTimeout must establish connection to addr.
// Unlike DialFunc, it also accepts a timeout.
//
// There is no need to establish a TLS (SSL) connection for https.
// The client automatically converts connection to TLS
// if HostClient.IsTLS is set.
//
// TCP address passed to DialFuncWithTimeout always contains host and port.
// Example TCP addr values:
//
// - foobar.com:80
// - foobar.com:443
// - foobar.com:8080
type DialFuncWithTimeout func(addr string, timeout time.Duration) (net.Conn, error)
// RetryIfFunc defines the signature of the retry-if function.
// The request is passed to RetryIfFunc when a request error occurs.
type RetryIfFunc func(request *Request) bool
// RetryIfErrFunc decides, after an error occurs during a request, whether
// the request should be retried and whether the request timeout should be
// reset, based on its return values.
//
// attempts indicates how many requests have been attempted so far;
// the first request counts as the first attempt.
//
// err represents the error encountered while attempting the `attempts`-th request.
//
// resetTimeout indicates whether to reuse the `Request`'s timeout as the timeout interval,
// rather than using the timeout after subtracting the time spent on previous failed requests.
// This return value is meaningful only when you use `Request.SetTimeout`, `DoTimeout`, or `DoDeadline`.
//
// retry indicates whether to retry the current request. If it is false,
// the request function will immediately return with the `err`.
type RetryIfErrFunc func(request *Request, attempts int, err error) (resetTimeout bool, retry bool)
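// An illustrative retry policy for a Client or HostClient c (a sketch; the
// "retry GETs for up to 3 attempts" rule is an assumption, not part of this
// package):
//
//	c.RetryIfErr = func(req *Request, attempts int, err error) (resetTimeout bool, retry bool) {
//		// Keep the remaining timeout (resetTimeout == false) and retry
//		// only idempotent GET requests for the first few attempts.
//		return false, attempts < 3 && req.Header.IsGet()
//	}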
// RoundTripper wraps every request/response.
type RoundTripper interface {
RoundTrip(hc *HostClient, req *Request, resp *Response) (retry bool, err error)
}
// ConnPoolStrategyType defines the strategy of connection pool enqueue/dequeue.
type ConnPoolStrategyType int
const (
FIFO ConnPoolStrategyType = iota
LIFO
)
// HostClient balances HTTP requests among hosts listed in Addr.
//
// HostClient may be used for balancing load among multiple upstream hosts.
// While multiple addresses passed to HostClient.Addr may be used for balancing
// load among them, it is better to use LBClient instead, since HostClient
// may unevenly balance load among upstream hosts.
//
// Copying HostClient instances is forbidden. Create new instances instead.
//
// It is safe to call HostClient methods from concurrently running goroutines.
type HostClient struct {
noCopy noCopy
readerPool sync.Pool
writerPool sync.Pool
// Transport defines a transport-like mechanism that wraps every request/response.
Transport RoundTripper
// Callback for establishing new connections to hosts.
//
// Default DialTimeout is used if not set.
DialTimeout DialFuncWithTimeout
// Callback for establishing new connections to hosts.
//
// Note that if Dial is set instead of DialTimeout, Dial will ignore Request timeout.
// If you want the tcp dial process to account for request timeouts, use DialTimeout instead.
//
// If not set, DialTimeout is used.
Dial DialFunc
// Optional TLS config.
TLSConfig *tls.Config
// RetryIf controls whether a retry should be attempted after an error.
// By default, it uses the isIdempotent function.
//
// Deprecated: Use RetryIfErr instead.
// This field is only effective when the `RetryIfErr` field is not set.
RetryIf RetryIfFunc
// When the client encounters an error during a request, RetryIfErr
// determines whether to retry and whether to reset the request timeout.
// This field is only effective within the range of MaxIdemponentCallAttempts.
RetryIfErr RetryIfErrFunc
connsWait *wantConnQueue
tlsConfigMap map[string]*tls.Config
clientReaderPool *sync.Pool
clientWriterPool *sync.Pool
// Comma-separated list of upstream HTTP server host addresses,
// which are passed to Dial or DialTimeout in a round-robin manner.
//
// Each address may contain port if default dialer is used.
// For example,
//
// - foobar.com:80
// - foobar.com:443
// - foobar.com:8080
Addr string
// Client name. Used in User-Agent request header.
Name string
conns []*clientConn
addrs []string
// Maximum number of connections which may be established to all hosts
// listed in Addr.
//
// You can change this value while the HostClient is being used
// with HostClient.SetMaxConns(value).
//
// DefaultMaxConnsPerHost is used if not set.
MaxConns int
// Keep-alive connections are closed after this duration.
//
// By default connection duration is unlimited.
MaxConnDuration time.Duration
// Idle keep-alive connections are closed after this duration.
//
// By default idle connections are closed
// after DefaultMaxIdleConnDuration.
MaxIdleConnDuration time.Duration
// Maximum number of attempts for idempotent calls.
//
// A value of 0 or a negative value represents using DefaultMaxIdemponentCallAttempts.
// For example, a value of 1 means the request will be executed only once,
// while 2 means the request will be executed at most twice.
// The RetryIfErr and RetryIf fields can invalidate remaining attempts.
MaxIdemponentCallAttempts int
// Per-connection buffer size for responses' reading.
// This also limits the maximum header size.
//
// Default buffer size is used if 0.
ReadBufferSize int
// Per-connection buffer size for requests' writing.
//
// Default buffer size is used if 0.
WriteBufferSize int
// Maximum duration for full response reading (including body).
//
// By default response read timeout is unlimited.
ReadTimeout time.Duration
// Maximum duration for full request writing (including body).
//
// By default request write timeout is unlimited.
WriteTimeout time.Duration
// Maximum response body size.
//
// The client returns ErrBodyTooLarge if this limit is greater than 0
// and response body is greater than the limit.
//
// By default response body size is unlimited.
MaxResponseBodySize int
// Maximum duration for waiting for a free connection.
//
// By default the client does not wait and returns ErrNoFreeConns immediately.
MaxConnWaitTimeout time.Duration
// Connection pool strategy. Can be either LIFO or FIFO (default).
ConnPoolStrategy ConnPoolStrategyType
connsCount int
connsLock sync.Mutex
addrsLock sync.Mutex
tlsConfigMapLock sync.Mutex
addrIdx uint32
lastUseTime uint32
pendingRequests int32
// pendingClientRequests counts the number of requests that a Client is currently running using this HostClient.
// It will be incremented earlier than pendingRequests and will be used by Client to see if the HostClient is still in use.
pendingClientRequests int32
// NoDefaultUserAgentHeader when set to true, causes the default
// User-Agent header to be excluded from the Request.
NoDefaultUserAgentHeader bool
// Attempt to connect to both ipv4 and ipv6 host addresses
// if set to true.
//
// This option is used only if default TCP dialer is used,
// i.e. if Dial and DialTimeout are blank.
//
// By default client connects only to ipv4 addresses,
// since unfortunately ipv6 remains broken in many networks worldwide :)
DialDualStack bool
// Whether to use TLS (aka SSL or HTTPS) for host connections.
IsTLS bool
// Header names are passed as-is without normalization
// if this option is set.
//
// Disabled header names' normalization may be useful only for proxying
// responses to other clients expecting case-sensitive
// header names. See https://github.com/valyala/fasthttp/issues/57
// for details.
//
// By default request and response header names are normalized, i.e.
// The first letter and the first letters following dashes
// are uppercased, while all the other letters are lowercased.
// Examples:
//
// * HOST -> Host
// * content-type -> Content-Type
// * cONTENT-lenGTH -> Content-Length
DisableHeaderNamesNormalizing bool
// Path values are sent as-is without normalization.
//
// Disabled path normalization may be useful for proxying incoming requests
// to servers that are expecting paths to be forwarded as-is.
//
// By default path values are normalized, i.e.
// extra slashes are removed, special characters are encoded.
DisablePathNormalizing bool
// Will not log potentially sensitive content in error logs.
//
// This option is useful for servers that handle sensitive data
// in the request/response.
//
// Client logs full errors by default.
SecureErrorLogMessage bool
// StreamResponseBody enables response body streaming.
StreamResponseBody bool
connsCleanerRun bool
}
type clientConn struct {
c net.Conn
createdTime time.Time
lastUseTime time.Time
}
// Conn returns the client connection.
func (cc *clientConn) Conn() net.Conn {
return cc.c
}
// CreatedTime returns the time the connection was created.
func (cc *clientConn) CreatedTime() time.Time {
return cc.createdTime
}
// LastUseTime returns the time the connection was last used.
func (cc *clientConn) LastUseTime() time.Time {
return cc.lastUseTime
}
var startTimeUnix = time.Now().Unix()
// LastUseTime returns the time the client was last used.
func (c *HostClient) LastUseTime() time.Time {
n := atomic.LoadUint32(&c.lastUseTime)
return time.Unix(startTimeUnix+int64(n), 0)
}
// Get returns the status code and body of url.
//
// The contents of dst will be replaced by the body and returned; if dst
// is too small, a new slice will be allocated.
//
// The function follows redirects. Use Do* for manually handling redirects.
func (c *HostClient) Get(dst []byte, url string) (statusCode int, body []byte, err error) {
return clientGetURL(dst, url, c)
}
// GetTimeout returns the status code and body of url.
//
// The contents of dst will be replaced by the body and returned; if dst
// is too small, a new slice will be allocated.
//
// The function follows redirects. Use Do* for manually handling redirects.
//
// ErrTimeout error is returned if url contents couldn't be fetched
// during the given timeout.
func (c *HostClient) GetTimeout(dst []byte, url string, timeout time.Duration) (statusCode int, body []byte, err error) {
return clientGetURLTimeout(dst, url, timeout, c)
}
// GetDeadline returns the status code and body of url.
//
// The contents of dst will be replaced by the body and returned; if dst
// is too small, a new slice will be allocated.
//
// The function follows redirects. Use Do* for manually handling redirects.
//
// ErrTimeout error is returned if url contents couldn't be fetched
// until the given deadline.
func (c *HostClient) GetDeadline(dst []byte, url string, deadline time.Time) (statusCode int, body []byte, err error) {
return clientGetURLDeadline(dst, url, deadline, c)
}
// Post sends POST request to the given url with the given POST arguments.
//
// The contents of dst will be replaced by the body and returned; if dst
// is too small, a new slice will be allocated.
//
// The function follows redirects. Use Do* for manually handling redirects.
//
// Empty POST body is sent if postArgs is nil.
func (c *HostClient) Post(dst []byte, url string, postArgs *Args) (statusCode int, body []byte, err error) {
return clientPostURL(dst, url, postArgs, c)
}
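// Example usage, given a configured HostClient c (a minimal sketch; the key,
// value and URL are placeholders):
//
//	args := AcquireArgs()
//	args.Set("q", "fasthttp")
//	statusCode, body, err := c.Post(nil, "http://example.com/search", args)
//	ReleaseArgs(args)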
type clientDoer interface {
Do(req *Request, resp *Response) error
}
func clientGetURL(dst []byte, url string, c clientDoer) (statusCode int, body []byte, err error) {
req := AcquireRequest()
statusCode, body, err = doRequestFollowRedirectsBuffer(req, dst, url, c)
ReleaseRequest(req)
return statusCode, body, err
}
func clientGetURLTimeout(dst []byte, url string, timeout time.Duration, c clientDoer) (statusCode int, body []byte, err error) {
deadline := time.Now().Add(timeout)
return clientGetURLDeadline(dst, url, deadline, c)
}
type clientURLResponse struct {
err error
body []byte
statusCode int
}
func clientGetURLDeadline(dst []byte, url string, deadline time.Time, c clientDoer) (statusCode int, body []byte, err error) {
timeout := time.Until(deadline)
if timeout <= 0 {
return 0, dst, ErrTimeout
}
var ch chan clientURLResponse
chv := clientURLResponseChPool.Get()
if chv == nil {
chv = make(chan clientURLResponse, 1)
}
ch = chv.(chan clientURLResponse)
// Note that the request continues execution on ErrTimeout until
// the client-specific ReadTimeout is exceeded. This helps limit the load
// on slow hosts to MaxConns* concurrent requests.
//
// Without this 'hack' the load on a slow host could exceed MaxConns*
// concurrent requests, since requests that time out on the client side
// usually continue execution on the host.
var mu sync.Mutex
var timedout, responded bool
go func() {
req := AcquireRequest()
statusCodeCopy, bodyCopy, errCopy := doRequestFollowRedirectsBuffer(req, dst, url, c)
mu.Lock()
if !timedout {
ch <- clientURLResponse{
statusCode: statusCodeCopy,
body: bodyCopy,
err: errCopy,
}
responded = true
}
mu.Unlock()
ReleaseRequest(req)
}()
tc := AcquireTimer(timeout)
select {
case resp := <-ch:
statusCode = resp.statusCode
body = resp.body
err = resp.err
case <-tc.C:
mu.Lock()
if responded {
resp := <-ch
statusCode = resp.statusCode
body = resp.body
err = resp.err
} else {
timedout = true
err = ErrTimeout
body = dst
}
mu.Unlock()
}
ReleaseTimer(tc)
clientURLResponseChPool.Put(chv)
return statusCode, body, err
}
var clientURLResponseChPool sync.Pool
func clientPostURL(dst []byte, url string, postArgs *Args, c clientDoer) (statusCode int, body []byte, err error) {
req := AcquireRequest()
defer ReleaseRequest(req)
req.Header.SetMethod(MethodPost)
req.Header.SetContentTypeBytes(strPostArgsContentType)
if postArgs != nil {
if _, err := postArgs.WriteTo(req.BodyWriter()); err != nil {
return 0, nil, err
}
}
statusCode, body, err = doRequestFollowRedirectsBuffer(req, dst, url, c)
return statusCode, body, err
}
var (
// ErrMissingLocation is returned by clients when the Location header is missing on
// an HTTP response with a redirect status code.
ErrMissingLocation = errors.New("missing Location header for http redirect")
// ErrTooManyRedirects is returned by clients when the number of redirects followed
// exceeds the max count.
ErrTooManyRedirects = errors.New("too many redirects detected when doing the request")
// ErrHostClientRedirectToDifferentScheme is returned when a redirect points
// to a different protocol; HostClients can only follow redirects within the same protocol.
ErrHostClientRedirectToDifferentScheme = errors.New("HostClient can't follow redirects to a different protocol," +
" please use Client instead")
)
const defaultMaxRedirectsCount = 16
func doRequestFollowRedirectsBuffer(req *Request, dst []byte, url string, c clientDoer) (statusCode int, body []byte, err error) {
resp := AcquireResponse()
bodyBuf := resp.bodyBuffer()
resp.keepBodyBuffer = true
oldBody := bodyBuf.B
bodyBuf.B = dst
statusCode, _, err = doRequestFollowRedirects(req, resp, url, defaultMaxRedirectsCount, c)
body = bodyBuf.B
bodyBuf.B = oldBody
resp.keepBodyBuffer = false
ReleaseResponse(resp)
return statusCode, body, err
}
func doRequestFollowRedirects(
req *Request, resp *Response, url string, maxRedirectsCount int, c clientDoer,
) (statusCode int, body []byte, err error) {
redirectsCount := 0
for {
req.SetRequestURI(url)
if err := req.parseURI(); err != nil {
return 0, nil, err
}
if err = c.Do(req, resp); err != nil {
break
}
statusCode = resp.Header.StatusCode()
if !StatusCodeIsRedirect(statusCode) {
break
}
redirectsCount++
if redirectsCount > maxRedirectsCount {
err = ErrTooManyRedirects
break
}
location := resp.Header.peek(strLocation)
if len(location) == 0 {
err = ErrMissingLocation
break
}
url = getRedirectURL(url, location, req.DisableRedirectPathNormalizing)
if req.Header.IsPost() && (statusCode == StatusMovedPermanently || statusCode == StatusFound) {
req.Header.SetMethod(MethodGet)
}
}
return statusCode, body, err
}
func getRedirectURL(baseURL string, location []byte, disablePathNormalizing bool) string {
u := AcquireURI()
u.Update(baseURL)
u.UpdateBytes(location)
u.DisablePathNormalizing = disablePathNormalizing
redirectURL := u.String()
ReleaseURI(u)
return redirectURL
}
// StatusCodeIsRedirect returns true if the status code indicates a redirect.
func StatusCodeIsRedirect(statusCode int) bool {
return statusCode == StatusMovedPermanently ||
statusCode == StatusFound ||
statusCode == StatusSeeOther ||
statusCode == StatusTemporaryRedirect ||
statusCode == StatusPermanentRedirect
}
var (
requestPool sync.Pool
responsePool sync.Pool
)
// AcquireRequest returns an empty Request instance from request pool.
//
// The returned Request instance may be passed to ReleaseRequest when it is
// no longer needed. This allows Request recycling, reduces GC pressure
// and usually improves performance.
func AcquireRequest() *Request {
v := requestPool.Get()
if v == nil {
return &Request{}
}
return v.(*Request)
}
// ReleaseRequest returns req acquired via AcquireRequest to request pool.
//
// It is forbidden accessing req and/or its members after returning
// it to request pool.
func ReleaseRequest(req *Request) {
req.Reset()
requestPool.Put(req)
}
// AcquireResponse returns an empty Response instance from response pool.
//
// The returned Response instance may be passed to ReleaseResponse when it is
// no longer needed. This allows Response recycling, reduces GC pressure
// and usually improves performance.
func AcquireResponse() *Response {
v := responsePool.Get()
if v == nil {
return &Response{}
}
return v.(*Response)
}
// ReleaseResponse returns resp acquired via AcquireResponse to response pool.
//
// It is forbidden accessing resp and/or its members after returning
// it to response pool.
func ReleaseResponse(resp *Response) {
resp.Reset()
responsePool.Put(resp)
}
// DoTimeout performs the given request and waits for response during
// the given timeout duration.
//
// Request must contain at least non-zero RequestURI with full url (including
// scheme and host) or non-zero Host header + RequestURI.
//
// The function doesn't follow redirects. Use Get* for following redirects.
//
// Response is ignored if resp is nil.
//
// ErrTimeout is returned if the response wasn't returned during
// the given timeout.
// ErrTimeout is returned immediately if the timeout value is zero or negative.
//
// ErrNoFreeConns is returned if all HostClient.MaxConns connections
// to the host are busy.
//
// It is recommended obtaining req and resp via AcquireRequest
// and AcquireResponse in performance-critical code.
func (c *HostClient) DoTimeout(req *Request, resp *Response, timeout time.Duration) error {
req.timeout = timeout
if req.timeout <= 0 {
return ErrTimeout
}
return c.Do(req, resp)
}
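// Example usage, given a configured HostClient c (a minimal sketch; the URL
// and timeout are placeholders):
//
//	req := AcquireRequest()
//	resp := AcquireResponse()
//	req.SetRequestURI("http://example.com/")
//	err := c.DoTimeout(req, resp, time.Second)
//	if err == nil {
//		// resp.Body() is valid until resp is released.
//	}
//	ReleaseRequest(req)
//	ReleaseResponse(resp)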
// DoDeadline performs the given request and waits for response until
// the given deadline.
//
// Request must contain at least non-zero RequestURI with full url (including
// scheme and host) or non-zero Host header + RequestURI.
//
// The function doesn't follow redirects. Use Get* for following redirects.
//
// Response is ignored if resp is nil.
//
// ErrTimeout is returned if the response wasn't returned until
// the given deadline.
// Immediately returns ErrTimeout if the deadline has already been reached.
//
// ErrNoFreeConns is returned if all HostClient.MaxConns connections
// to the host are busy.
//
// It is recommended obtaining req and resp via AcquireRequest
// and AcquireResponse in performance-critical code.
func (c *HostClient) DoDeadline(req *Request, resp *Response, deadline time.Time) error {
req.timeout = time.Until(deadline)
if req.timeout <= 0 {
return ErrTimeout
}
return c.Do(req, resp)
}
// DoRedirects performs the given http request and fills the given http response,
// following up to maxRedirectsCount redirects. When the redirect count exceeds
// maxRedirectsCount, ErrTooManyRedirects is returned.
//
// Request must contain at least non-zero RequestURI with full url (including
// scheme and host) or non-zero Host header + RequestURI.
//
// Client determines the server to be requested in the following order:
//
// - from RequestURI if it contains full url with scheme and host;
// - from Host header otherwise.
//
// Response is ignored if resp is nil.
//
// ErrNoFreeConns is returned if all DefaultMaxConnsPerHost connections
// to the requested host are busy.
//
// It is recommended obtaining req and resp via AcquireRequest
// and AcquireResponse in performance-critical code.
func (c *HostClient) DoRedirects(req *Request, resp *Response, maxRedirectsCount int) error {
if c.DisablePathNormalizing {
req.URI().DisablePathNormalizing = true
}
_, _, err := doRequestFollowRedirects(req, resp, req.URI().String(), maxRedirectsCount, c)
return err
}
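// Example usage (a minimal sketch mirroring the DoTimeout example above,
// but following up to 5 redirects; the limit is arbitrary):
//
//	err := c.DoRedirects(req, resp, 5)
//	if err == ErrTooManyRedirects {
//		// The redirect chain was longer than 5 hops.
//	}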
// Do performs the given http request and sets the corresponding response.
//
// Request must contain at least non-zero RequestURI with full url (including
// scheme and host) or non-zero Host header + RequestURI.
//
// The function doesn't follow redirects. Use Get* for following redirects.
//
// Response is ignored if resp is nil.
//
// ErrNoFreeConns is returned if all HostClient.MaxConns connections
// to the host are busy.
//
// It is recommended obtaining req and resp via AcquireRequest
// and AcquireResponse in performance-critical code.
func (c *HostClient) Do(req *Request, resp *Response) error {
var (
err error
retry bool
resetTimeout bool
)
maxAttempts := c.MaxIdemponentCallAttempts
if maxAttempts <= 0 {
maxAttempts = DefaultMaxIdemponentCallAttempts
}
attempts := 0
hasBodyStream := req.IsBodyStream()
// If a request has a timeout we store the timeout
// and calculate a deadline so we can keep updating the
// timeout on each retry.
deadline := time.Time{}
timeout := req.timeout
if timeout > 0 {
deadline = time.Now().Add(timeout)
}
retryFunc := c.RetryIf
if retryFunc == nil {
retryFunc = isIdempotent
}
atomic.AddInt32(&c.pendingRequests, 1)
for {
// If the original timeout was set, we need to update
// the one set on the request to reflect the remaining time.
if timeout > 0 {
req.timeout = time.Until(deadline)
if req.timeout <= 0 {
err = ErrTimeout
break
}
}
retry, err = c.do(req, resp)
if err == nil || !retry {
break
}
if hasBodyStream {
break
}
// The retry checks below are ordered from cheapest to most expensive.
attempts++
if attempts >= maxAttempts {
break
}
if c.RetryIfErr != nil {
resetTimeout, retry = c.RetryIfErr(req, attempts, err)
} else {
retry = retryFunc(req)
}
if !retry {
break
}
if timeout > 0 && resetTimeout {
deadline = time.Now().Add(timeout)
}
}
atomic.AddInt32(&c.pendingRequests, -1)
// Restore the original timeout.
req.timeout = timeout
if err == io.EOF {
err = ErrConnectionClosed
}
return err
}
// PendingRequests returns the current number of requests the client
// is executing.
//
// This function may be used for balancing load among multiple HostClient
// instances.
func (c *HostClient) PendingRequests() int {
return int(atomic.LoadInt32(&c.pendingRequests))
}
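// A hypothetical load-balancing sketch built on PendingRequests: pick the
// HostClient with the fewest requests in flight from a user-maintained slice.
//
//	func leastLoaded(clients []*HostClient) *HostClient {
//		best := clients[0]
//		for _, hc := range clients[1:] {
//			if hc.PendingRequests() < best.PendingRequests() {
//				best = hc
//			}
//		}
//		return best
//	}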
func isIdempotent(req *Request) bool {
return req.Header.IsGet() || req.Header.IsHead() || req.Header.IsPut()
}
func (c *HostClient) do(req *Request, resp *Response) (bool, error) {
if resp == nil {
resp = AcquireResponse()
defer ReleaseResponse(resp)
}
return c.doNonNilReqResp(req, resp)
}
func (c *HostClient) doNonNilReqResp(req *Request, resp *Response) (bool, error) {
if req == nil {
// for debugging purposes
panic("BUG: req cannot be nil")
}
if resp == nil {
// for debugging purposes
panic("BUG: resp cannot be nil")
}
// Propagate the secure error log configuration to the request and response.
resp.secureErrorLogMessage = c.SecureErrorLogMessage
resp.Header.secureErrorLogMessage = c.SecureErrorLogMessage
req.secureErrorLogMessage = c.SecureErrorLogMessage
req.Header.secureErrorLogMessage = c.SecureErrorLogMessage
if c.IsTLS != req.URI().isHTTPS() {
return false, ErrHostClientRedirectToDifferentScheme
}
atomic.StoreUint32(&c.lastUseTime, uint32(time.Now().Unix()-startTimeUnix)) // #nosec G115
// Free up resources occupied by response before sending the request,
// so the GC may reclaim these resources (e.g. response body).
// Back up SkipBody in case it was set explicitly.
customSkipBody := resp.SkipBody
customStreamBody := resp.StreamBody || c.StreamResponseBody
resp.Reset()
resp.SkipBody = customSkipBody
resp.StreamBody = customStreamBody
req.URI().DisablePathNormalizing = c.DisablePathNormalizing
userAgentOld := req.Header.UserAgent()
if len(userAgentOld) == 0 {
userAgent := c.Name
if userAgent == "" && !c.NoDefaultUserAgentHeader {
userAgent = defaultUserAgent
}
if userAgent != "" {
req.Header.userAgent = append(req.Header.userAgent[:0], userAgent...)
}
}
return c.transport().RoundTrip(c, req, resp)
}
func (c *HostClient) transport() RoundTripper {
if c.Transport == nil {
return DefaultTransport
}
return c.Transport
}
var (
// ErrNoFreeConns is returned when no free connections available
// to the given host.
//
// Increase the allowed number of connections per host if you
// see this error.
ErrNoFreeConns = errors.New("no free connections available to host")
// ErrConnectionClosed may be returned from client methods if the server
// closes connection before returning the first response byte.
//
// If you see this error, then either fix the server by returning
// 'Connection: close' response header before closing the connection
// or add 'Connection: close' request header before sending requests
// to broken server.
ErrConnectionClosed = errors.New("the server closed connection before returning the first response byte. " +
"Make sure the server returns 'Connection: close' response header before closing the connection")
// ErrConnPoolStrategyNotImpl is returned when HostClient.ConnPoolStrategy is not implemented yet.
// If you see this error, then you need to check your HostClient configuration.
ErrConnPoolStrategyNotImpl = errors.New("connection pool strategy is not implemented")
)
type timeoutError struct{}
func (e *timeoutError) Error() string {
return "timeout"
}
// Only implements the Timeout() method of the net.Error interface.
// This allows for checks like:
//
// if x, ok := err.(interface{ Timeout() bool }); ok && x.Timeout() {
func (e *timeoutError) Timeout() bool {
return true
}
// ErrTimeout is returned from timed out calls.
var ErrTimeout = &timeoutError{}
// SetMaxConns sets the maximum number of connections which may be established to all hosts listed in Addr.
func (c *HostClient) SetMaxConns(newMaxConns int) {
c.connsLock.Lock()
c.MaxConns = newMaxConns
c.connsLock.Unlock()
}
func (c *HostClient) AcquireConn(reqTimeout time.Duration, connectionClose bool) (cc *clientConn, err error) {
createConn := false
startCleaner := false
var n int
c.connsLock.Lock()
n = len(c.conns)
if n == 0 {
maxConns := c.MaxConns
if maxConns <= 0 {
maxConns = DefaultMaxConnsPerHost
}
if c.connsCount < maxConns {
c.connsCount++
createConn = true
if !c.connsCleanerRun && !connectionClose {
startCleaner = true
c.connsCleanerRun = true
}
}
} else {
switch c.ConnPoolStrategy {
case LIFO:
n--
cc = c.conns[n]
c.conns[n] = nil
c.conns = c.conns[:n]
case FIFO:
cc = c.conns[0]
copy(c.conns, c.conns[1:])
c.conns[n-1] = nil
c.conns = c.conns[:n-1]
default:
c.connsLock.Unlock()
return nil, ErrConnPoolStrategyNotImpl
}
}
c.connsLock.Unlock()
if cc != nil {
return cc, nil
}
if !createConn {
if c.MaxConnWaitTimeout <= 0 {
return nil, ErrNoFreeConns
}
//nolint:dupword
// reqTimeout  c.MaxConnWaitTimeout  wait duration
//     d1              d2            min(d1, d2)
//  0(not set)         d2            d2
//     d1         0(don't wait)      0(don't wait)
timeout := c.MaxConnWaitTimeout
timeoutOverridden := false
// reqTimeout == 0 means not set
if reqTimeout > 0 && reqTimeout < timeout {
timeout = reqTimeout
timeoutOverridden = true
}
// wait for a free connection
tc := AcquireTimer(timeout)
defer ReleaseTimer(tc)
w := &wantConn{
ready: make(chan struct{}, 1),
}
defer func() {
if err != nil {
w.cancel(c, err)
}
}()
c.queueForIdle(w)
select {
case <-w.ready:
return w.conn, w.err
case <-tc.C:
c.connsWait.failedWaiters.Add(1)
if timeoutOverridden {
return nil, ErrTimeout
}
return nil, ErrNoFreeConns
}
}
if startCleaner {
go c.connsCleaner()
}
conn, err := c.dialHostHard(reqTimeout)
if err != nil {
c.decConnsCount()
return nil, err
}
cc = acquireClientConn(conn)
return cc, nil
}
func (c *HostClient) queueForIdle(w *wantConn) {
c.connsLock.Lock()
defer c.connsLock.Unlock()
if c.connsWait == nil {
c.connsWait = &wantConnQueue{}
}
c.connsWait.clearFront()
c.connsWait.pushBack(w)
}
func (c *HostClient) dialConnFor(w *wantConn) {
conn, err := c.dialHostHard(0)
if err != nil {
w.tryDeliver(nil, err)
c.decConnsCount()
return
}
cc := acquireClientConn(conn)
if !w.tryDeliver(cc, nil) {
// not delivered, return idle connection
c.ReleaseConn(cc)
}
}
// CloseIdleConnections closes any connections which were established by
// previous requests but are now sitting idle in a
// "keep-alive" state. It does not interrupt any connections currently
// in use.
func (c *HostClient) CloseIdleConnections() {
c.connsLock.Lock()
scratch := append([]*clientConn{}, c.conns...)
for i := range c.conns {
c.conns[i] = nil
}
c.conns = c.conns[:0]
c.connsLock.Unlock()
for _, cc := range scratch {
c.CloseConn(cc)
}
}
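// A hypothetical maintenance loop (the interval is arbitrary):
//
//	go func() {
//		for range time.Tick(time.Minute) {
//			c.CloseIdleConnections()
//		}
//	}()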
func (c *HostClient) connsCleaner() {
var (
scratch []*clientConn
maxIdleConnDuration = c.MaxIdleConnDuration
)
if maxIdleConnDuration <= 0 {
maxIdleConnDuration = DefaultMaxIdleConnDuration
}
for {
currentTime := time.Now()
// Determine idle connections to be closed.
c.connsLock.Lock()
conns := c.conns
n := len(conns)
i := 0
for i < n && currentTime.Sub(conns[i].lastUseTime) > maxIdleConnDuration {
i++
}
sleepFor := maxIdleConnDuration
if i < n {
// + 1 so we actually sleep past the expiration time and not up to it.
// Otherwise the > check above would still fail.
sleepFor = maxIdleConnDuration - currentTime.Sub(conns[i].lastUseTime) + 1
}
scratch = append(scratch[:0], conns[:i]...)
if i > 0 {
m := copy(conns, conns[i:])
for i = m; i < n; i++ {
conns[i] = nil
}
c.conns = conns[:m]
}
c.connsLock.Unlock()
// Close idle connections.
for i, cc := range scratch {
c.CloseConn(cc)
scratch[i] = nil
}
// Determine whether to stop the connsCleaner.
c.connsLock.Lock()
mustStop := c.connsCount == 0
if mustStop {
c.connsCleanerRun = false
}
c.connsLock.Unlock()
if mustStop {
break
}
time.Sleep(sleepFor)
}
}
func (c *HostClient) CloseConn(cc *clientConn) {
c.decConnsCount()
cc.c.Close()
releaseClientConn(cc)
}
func (c *HostClient) decConnsCount() {
if c.MaxConnWaitTimeout <= 0 {
c.connsLock.Lock()
c.connsCount--
c.connsLock.Unlock()
return
}
c.connsLock.Lock()
defer c.connsLock.Unlock()
dialed := false
if q := c.connsWait; q != nil && q.len() > 0 {
for q.len() > 0 {
w := q.popFront()
if w.waiting() {
go c.dialConnFor(w)
dialed = true
break
}
c.connsWait.failedWaiters.Add(-1)
}
}
if !dialed {
c.connsCount--
}
}
// ConnsCount returns connection count of HostClient.
func (c *HostClient) ConnsCount() int {
c.connsLock.Lock()
defer c.connsLock.Unlock()
return c.connsCount
}
func acquireClientConn(conn net.Conn) *clientConn {
v := clientConnPool.Get()
if v == nil {
v = &clientConn{}
}
cc := v.(*clientConn)
cc.c = conn
cc.createdTime = time.Now()
return cc
}
func releaseClientConn(cc *clientConn) {
// Reset all fields.
*cc = clientConn{}
clientConnPool.Put(cc)
}
var clientConnPool sync.Pool
func (c *HostClient) ReleaseConn(cc *clientConn) {
cc.lastUseTime = time.Now()
if c.MaxConnWaitTimeout <= 0 {
c.connsLock.Lock()
c.conns = append(c.conns, cc)
c.connsLock.Unlock()
return
}
// try to deliver an idle connection to a *wantConn
c.connsLock.Lock()
defer c.connsLock.Unlock()
delivered := false
if q := c.connsWait; q != nil && q.len() > 0 {
for q.len() > 0 {
w := q.popFront()
if w.waiting() {
delivered = w.tryDeliver(cc, nil)
// This is the last resort to hand over the connsCount semaphore.
// We must ensure that there are no valid waiters left in connsWait
// when we exit this loop.
//
// We did not apply the same looping pattern in the decConnsCount
// method because it needs to dial a new connection, and when
// MaxConnWaitTimeout > 0 its call chain inevitably reaches this
// point anyway.
if delivered {
break
}
}
c.connsWait.failedWaiters.Add(-1)
}
}
if !delivered {
c.conns = append(c.conns, cc)
}
}
func (c *HostClient) AcquireWriter(conn net.Conn) *bufio.Writer {
var v any
if c.clientWriterPool != nil {
v = c.clientWriterPool.Get()
} else {
v = c.writerPool.Get()
}
if v == nil {
n := c.WriteBufferSize
if n <= 0 {
n = defaultWriteBufferSize
}
return bufio.NewWriterSize(conn, n)
}
bw := v.(*bufio.Writer)
bw.Reset(conn)
return bw
}
func (c *HostClient) ReleaseWriter(bw *bufio.Writer) {
if c.clientWriterPool != nil {
c.clientWriterPool.Put(bw)
} else {
c.writerPool.Put(bw)
}
}
func (c *HostClient) AcquireReader(conn net.Conn) *bufio.Reader {
var v any
if c.clientReaderPool != nil {
v = c.clientReaderPool.Get()
} else {
v = c.readerPool.Get()
}
if v == nil {
n := c.ReadBufferSize
if n <= 0 {
n = defaultReadBufferSize
}
return bufio.NewReaderSize(conn, n)
}
br := v.(*bufio.Reader)
br.Reset(conn)
return br
}
func (c *HostClient) ReleaseReader(br *bufio.Reader) {
if c.clientReaderPool != nil {
c.clientReaderPool.Put(br)
} else {
c.readerPool.Put(br)
}
}
func newClientTLSConfig(c *tls.Config, addr string) *tls.Config {
if c == nil {
c = &tls.Config{}
} else {
c = c.Clone()
}
if c.ServerName == "" {
serverName := tlsServerName(addr)
if serverName == "*" {
c.InsecureSkipVerify = true
} else {
c.ServerName = serverName
}
}
return c
}
func tlsServerName(addr string) string {
if !strings.Contains(addr, ":") {
return addr
}
host, _, err := net.SplitHostPort(addr)
if err != nil {
return "*"
}
return host
}
func (c *HostClient) nextAddr() string {
c.addrsLock.Lock()
if c.addrs == nil {
c.addrs = strings.Split(c.Addr, ",")
}
addr := c.addrs[0]
if len(c.addrs) > 1 {
addr = c.addrs[c.addrIdx%uint32(len(c.addrs))] // #nosec G115
c.addrIdx++
}
c.addrsLock.Unlock()
return addr
}
func (c *HostClient) dialHostHard(dialTimeout time.Duration) (conn net.Conn, err error) {
// Use dialTimeout to control the timeout of each dial. It has no effect
// if dialTimeout is 0, or if c.DialTimeout is unset while c.Dial is set.
// Attempt to dial all the available hosts before giving up.
c.addrsLock.Lock()
n := len(c.addrs)
c.addrsLock.Unlock()
if n == 0 {
// It looks like c.addrs isn't initialized yet.
n = 1
}
timeout := c.ReadTimeout + c.WriteTimeout
if timeout <= 0 {
timeout = DefaultDialTimeout
}
deadline := time.Now().Add(timeout)
for n > 0 {
addr := c.nextAddr()
tlsConfig := c.cachedTLSConfig(addr)
conn, err = dialAddr(addr, c.Dial, c.DialTimeout, c.DialDualStack, c.IsTLS, tlsConfig, dialTimeout, c.WriteTimeout)
if err == nil {
return conn, nil
}
if time.Since(deadline) >= 0 {
break
}
n--
}
return nil, err
}
func (c *HostClient) cachedTLSConfig(addr string) *tls.Config {
if !c.IsTLS {
return nil
}
c.tlsConfigMapLock.Lock()
if c.tlsConfigMap == nil {
c.tlsConfigMap = make(map[string]*tls.Config)
}
cfg := c.tlsConfigMap[addr]
if cfg == nil {
cfg = newClientTLSConfig(c.TLSConfig, addr)
c.tlsConfigMap[addr] = cfg
}
c.tlsConfigMapLock.Unlock()
return cfg
}
// ErrTLSHandshakeTimeout indicates there is a timeout from tls handshake.
var ErrTLSHandshakeTimeout = errors.New("tls handshake timed out")
func tlsClientHandshake(rawConn net.Conn, tlsConfig *tls.Config, deadline time.Time) (_ net.Conn, retErr error) {
defer func() {
if retErr != nil {
rawConn.Close()
}
}()
conn := tls.Client(rawConn, tlsConfig)
err := conn.SetDeadline(deadline)
if err != nil {
return nil, err
}
err = conn.Handshake()
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
return nil, ErrTLSHandshakeTimeout
}
if err != nil {
return nil, err
}
err = conn.SetDeadline(time.Time{})
if err != nil {
return nil, err
}
return conn, nil
}
func dialAddr(
addr string, dial DialFunc, dialWithTimeout DialFuncWithTimeout, dialDualStack, isTLS bool,
tlsConfig *tls.Config, dialTimeout, writeTimeout time.Duration,
) (net.Conn, error) {
deadline := time.Now().Add(writeTimeout)
conn, err := callDialFunc(addr, dial, dialWithTimeout, dialDualStack, isTLS, dialTimeout)
if err != nil {
return nil, err
}
if conn == nil {
return nil, errors.New("dialling unsuccessful. Please report this bug")
}
// We assume that any conn that has the Handshake() method is already a TLS conn.
// This covers not just *tls.Conn but also other TLS implementations.
_, isTLSAlready := conn.(interface{ Handshake() error })
if isTLS && !isTLSAlready {
if writeTimeout == 0 {
return tls.Client(conn, tlsConfig), nil
}
return tlsClientHandshake(conn, tlsConfig, deadline)
}
return conn, nil
}
func callDialFunc(
addr string, dial DialFunc, dialWithTimeout DialFuncWithTimeout, dialDualStack, isTLS bool, timeout time.Duration,
) (net.Conn, error) {
if dialWithTimeout != nil {
return dialWithTimeout(addr, timeout)
}
if dial != nil {
return dial(addr)
}
addr = AddMissingPort(addr, isTLS)
if timeout > 0 {
if dialDualStack {
return DialDualStackTimeout(addr, timeout)
}
return DialTimeout(addr, timeout)
}
if dialDualStack {
return DialDualStack(addr)
}
return Dial(addr)
}
// AddMissingPort adds a port to a host if it is missing.
// A literal IPv6 address in hostport must be enclosed in square
// brackets, as in "[::1]:80", "[::1%lo0]:80".
func AddMissingPort(addr string, isTLS bool) string {
addrLen := len(addr)
if addrLen == 0 {
return addr
}
isIP6 := addr[0] == '['
if isIP6 {
// If the IPv6 address has an opening bracket and the closing bracket is the last char, it doesn't have a port.
isIP6WithoutPort := addr[addrLen-1] == ']'
if !isIP6WithoutPort {
return addr
}
} else { // IPv4
colonPos := strings.LastIndexByte(addr, ':')
if colonPos > 0 {
return addr
}
}
port := ":80"
if isTLS {
port = ":443"
}
return addr + port
}
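// For example (inputs are illustrative):
//
//	AddMissingPort("example.com", false)      // "example.com:80"
//	AddMissingPort("example.com", true)       // "example.com:443"
//	AddMissingPort("example.com:8080", false) // unchanged
//	AddMissingPort("[::1]", true)             // "[::1]:443"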
// A wantConn records state about a wanted connection
// (that is, an active call to getConn).
// The conn may be gotten by dialing or by finding an idle connection,
// or a cancellation may make the conn no longer wanted.
// These three options are racing against each other and use
// wantConn to coordinate and agree about the winning outcome.
//
// Inspired by net/http/transport.go.
type wantConn struct {
err error
ready chan struct{}
conn *clientConn
mu sync.Mutex // protects conn, err, close(ready)
}
// waiting reports whether w is still waiting for an answer (connection or error).
func (w *wantConn) waiting() bool {
select {
case <-w.ready:
return false
default:
return true
}
}
// tryDeliver attempts to deliver conn, err to w and reports whether it succeeded.
func (w *wantConn) tryDeliver(conn *clientConn, err error) bool {
w.mu.Lock()
defer w.mu.Unlock()
if w.conn != nil || w.err != nil {
return false
}
w.conn = conn
w.err = err
if w.conn == nil && w.err == nil {
panic("fasthttp: internal error: misuse of tryDeliver")
}
close(w.ready)
return true
}
// cancel marks w as no longer wanting a result (for example, due to cancellation).
// If a connection has been delivered already, cancel returns it with c.ReleaseConn.
func (w *wantConn) cancel(c *HostClient, err error) {
w.mu.Lock()
if w.conn == nil && w.err == nil {
close(w.ready) // catch misbehavior in future delivery
}
conn := w.conn
w.conn = nil
w.err = err
w.mu.Unlock()
if conn != nil {
c.ReleaseConn(conn)
}
}
// A wantConnQueue is a queue of wantConns.
//
// Inspired by net/http/transport.go.
type wantConnQueue struct {
// This is a queue, not a deque.
// It is split into two stages - head[headPos:] and tail.
// popFront is trivial (headPos++) on the first stage, and
// pushBack is trivial (append) on the second stage.
// If the first stage is empty, popFront can swap the
// first and second stages to remedy the situation.
//
// This two-stage split is analogous to the use of two lists
// in Okasaki's purely functional queue but without the
// overhead of reversing the list when swapping stages.
head []*wantConn
tail []*wantConn
headPos int
// failedWaiters is the number of waiters in the head or tail queue
// that are no longer valid (e.g. they have timed out).
// Such entries cannot truly be considered waiters; the current
// implementation does not immediately remove them when they become
// invalid, it only counts them.
failedWaiters atomic.Int64
}
// len returns the number of valid items in the queue, excluding failed waiters.
func (q *wantConnQueue) len() int {
return len(q.head) - q.headPos + len(q.tail) - int(q.failedWaiters.Load())
}
// pushBack adds w to the back of the queue.
func (q *wantConnQueue) pushBack(w *wantConn) {
q.tail = append(q.tail, w)
}
// popFront removes and returns the wantConn at the front of the queue.
func (q *wantConnQueue) popFront() *wantConn {
if q.headPos >= len(q.head) {
if len(q.tail) == 0 {
return nil
}
// Pick up tail as new head, clear tail.
q.head, q.headPos, q.tail = q.tail, 0, q.head[:0]
}
w := q.head[q.headPos]
q.head[q.headPos] = nil
q.headPos++
return w
}
// peekFront returns the wantConn at the front of the queue without removing it.
func (q *wantConnQueue) peekFront() *wantConn {
if q.headPos < len(q.head) {
return q.head[q.headPos]
}
if len(q.tail) > 0 {
return q.tail[0]
}
return nil
}
// clearFront pops any wantConns that are no longer waiting from the head of the
// queue, reporting whether any were popped.
func (q *wantConnQueue) clearFront() (cleaned bool) {
for {
w := q.peekFront()
if w == nil || w.waiting() {
return cleaned
}
q.popFront()
q.failedWaiters.Add(-1)
cleaned = true
}
}
// PipelineClient pipelines requests over a limited set of concurrent
// connections to the given Addr.
//
// This client may be used in highly loaded HTTP-based RPC systems for reducing
// context switches and network level overhead.
// See https://en.wikipedia.org/wiki/HTTP_pipelining for details.
//
// It is forbidden copying PipelineClient instances. Create new instances
// instead.
//
// It is safe calling PipelineClient methods from concurrently running
// goroutines.
type PipelineClient struct {
noCopy noCopy
// Logger for logging client errors.
//
// By default standard logger from log package is used.
Logger Logger
// Callback for connection establishing to the host.
//
// Default Dial is used if not set.
Dial DialFunc
// Optional TLS config.
TLSConfig *tls.Config
// Address of the host to connect to.
Addr string
// PipelineClient name. Used in User-Agent request header.
Name string
connClients []*pipelineConnClient
// The maximum number of concurrent connections to the Addr.
//
// A single connection is used by default.
MaxConns int
// The maximum number of pending pipelined requests over
// a single connection to Addr.
//
// DefaultMaxPendingRequests is used by default.
MaxPendingRequests int
// The maximum delay before sending pipelined requests as a batch
// to the server.
//
// By default requests are sent immediately to the server.
MaxBatchDelay time.Duration
// Idle connection to the host is closed after this duration.
//
// By default idle connection is closed after
// DefaultMaxIdleConnDuration.
MaxIdleConnDuration time.Duration
// Buffer size for responses' reading.
// This also limits the maximum header size.
//
// Default buffer size is used if 0.
ReadBufferSize int
// Buffer size for requests' writing.
//
// Default buffer size is used if 0.
WriteBufferSize int
// Maximum duration for full response reading (including body).
//
// By default response read timeout is unlimited.
ReadTimeout time.Duration
// Maximum duration for full request writing (including body).
//
// By default request write timeout is unlimited.
WriteTimeout time.Duration
connClientsLock sync.Mutex
// NoDefaultUserAgentHeader when set to true, causes the default
// User-Agent header to be excluded from the Request.
NoDefaultUserAgentHeader bool
// Attempt to connect to both ipv4 and ipv6 host addresses
// if set to true.
//
// This option is used only if default TCP dialer is used,
// i.e. if Dial is blank.
//
// By default client connects only to ipv4 addresses,
// since unfortunately ipv6 remains broken in many networks worldwide :)
DialDualStack bool
// Response header names are passed as-is without normalization
// if this option is set.
//
// Disabling header name normalization may be useful only when proxying
// responses to other clients that expect case-sensitive
// header names. See https://github.com/valyala/fasthttp/issues/57
// for details.
//
// By default request and response header names are normalized, i.e.
// the first letter and the first letters following dashes
// are uppercased, while all the other letters are lowercased.
// Examples:
//
// * HOST -> Host
// * content-type -> Content-Type
// * cONTENT-lenGTH -> Content-Length
DisableHeaderNamesNormalizing bool
// Path values are sent as-is without normalization.
//
// Disabling path normalization may be useful when proxying incoming requests
// to servers that are expecting paths to be forwarded as-is.
//
// By default path values are normalized, i.e.
// extra slashes are removed, special characters are encoded.
DisablePathNormalizing bool
// Whether to use TLS (aka SSL or HTTPS) for host connections.
IsTLS bool
}
type pipelineConnClient struct {
noCopy noCopy
workPool sync.Pool
Logger Logger
Dial DialFunc
TLSConfig *tls.Config
chW chan *pipelineWork
chR chan *pipelineWork
tlsConfig *tls.Config
Addr string
Name string
MaxPendingRequests int
MaxBatchDelay time.Duration
MaxIdleConnDuration time.Duration
ReadBufferSize int
WriteBufferSize int
ReadTimeout time.Duration
WriteTimeout time.Duration
chLock sync.Mutex
tlsConfigLock sync.Mutex
NoDefaultUserAgentHeader bool
DialDualStack bool
DisableHeaderNamesNormalizing bool
DisablePathNormalizing bool
IsTLS bool
}
type pipelineWork struct {
respCopy Response
deadline time.Time
err error
req *Request
resp *Response
t *time.Timer
done chan struct{}
reqCopy Request
}
// DoTimeout performs the given request and waits for response during
// the given timeout duration.
//
// Request must contain at least non-zero RequestURI with full url (including
// scheme and host) or non-zero Host header + RequestURI.
//
// The function doesn't follow redirects.
//
// Response is ignored if resp is nil.
//
// ErrTimeout is returned if the response wasn't returned during
// the given timeout.
//
// It is recommended obtaining req and resp via AcquireRequest
// and AcquireResponse in performance-critical code.
func (c *PipelineClient) DoTimeout(req *Request, resp *Response, timeout time.Duration) error {
return c.DoDeadline(req, resp, time.Now().Add(timeout))
}
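// Example usage (a minimal sketch; the address, URL and timeout are
// placeholders):
//
//	pc := &PipelineClient{Addr: "example.com:80"}
//	req := AcquireRequest()
//	resp := AcquireResponse()
//	req.SetRequestURI("http://example.com/rpc")
//	err := pc.DoTimeout(req, resp, 100*time.Millisecond)
//	ReleaseRequest(req)
//	ReleaseResponse(resp)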
// DoDeadline performs the given request and waits for response until
// the given deadline.
//
// Request must contain at least non-zero RequestURI with full url (including
// scheme and host) or non-zero Host header + RequestURI.
//
// The function doesn't follow redirects.
//
// Response is ignored if resp is nil.
//
// ErrTimeout is returned if the response wasn't returned until
// the given deadline.
//
// It is recommended obtaining req and resp via AcquireRequest
// and AcquireResponse in performance-critical code.
func (c *PipelineClient) DoDeadline(req *Request, resp *Response, deadline time.Time) error {
return c.getConnClient().DoDeadline(req, resp, deadline)
}
func (c *pipelineConnClient) DoDeadline(req *Request, resp *Response, deadline time.Time) error {
c.init()
timeout := time.Until(deadline)
if timeout <= 0 {
return ErrTimeout
}
if c.DisablePathNormalizing {
req.URI().DisablePathNormalizing = true
}
userAgentOld := req.Header.UserAgent()
if len(userAgentOld) == 0 {
userAgent := c.Name
if userAgent == "" && !c.NoDefaultUserAgentHeader {
userAgent = defaultUserAgent
}
if userAgent != "" {
req.Header.userAgent = append(req.Header.userAgent[:0], userAgent...)
}
}
w := c.acquirePipelineWork(timeout)
w.respCopy.Header.disableNormalizing = c.DisableHeaderNamesNormalizing
w.req = &w.reqCopy
w.resp = &w.respCopy
// Make a copy of the request in order to avoid data races on timeouts
req.copyToSkipBody(&w.reqCopy)
swapRequestBody(req, &w.reqCopy)
// Put the request to outgoing queue
select {
case c.chW <- w:
// Fast path: len(c.chW) < cap(c.chW)
default:
// Slow path
select {
case c.chW <- w:
case <-w.t.C:
c.releasePipelineWork(w)
return ErrTimeout
}
}
// Wait for the response
var err error
select {
case <-w.done:
if resp != nil {
w.respCopy.copyToSkipBody(resp)
swapResponseBody(resp, &w.respCopy)
}
err = w.err
c.releasePipelineWork(w)
case <-w.t.C:
err = ErrTimeout
}
return err
}
func (c *pipelineConnClient) acquirePipelineWork(timeout time.Duration) (w *pipelineWork) {
v := c.workPool.Get()
if v != nil {
w = v.(*pipelineWork)
} else {
w = &pipelineWork{
done: make(chan struct{}, 1),
}
}
if timeout > 0 {
if w.t == nil {
w.t = time.NewTimer(timeout)
} else {
w.t.Reset(timeout)
}
w.deadline = time.Now().Add(timeout)
} else {
w.deadline = zeroTime
}
return w
}
func (c *pipelineConnClient) releasePipelineWork(w *pipelineWork) {
if w.t != nil {
w.t.Stop()
}
w.reqCopy.Reset()
w.respCopy.Reset()
w.req = nil
w.resp = nil
w.err = nil
c.workPool.Put(w)
}
// Do performs the given http request and sets the corresponding response.
//
// Request must contain at least non-zero RequestURI with full url (including
// scheme and host) or non-zero Host header + RequestURI.
//
// The function doesn't follow redirects. Use Get* for following redirects.
//
// Response is ignored if resp is nil.
//
// It is recommended obtaining req and resp via AcquireRequest
// and AcquireResponse in performance-critical code.
func (c *PipelineClient) Do(req *Request, resp *Response) error {
return c.getConnClient().Do(req, resp)
}
func (c *pipelineConnClient) Do(req *Request, resp *Response) error {
c.init()
if c.DisablePathNormalizing {
req.URI().DisablePathNormalizing = true
}
userAgentOld := req.Header.UserAgent()
if len(userAgentOld) == 0 {
userAgent := c.Name
if userAgent == "" && !c.NoDefaultUserAgentHeader {
userAgent = defaultUserAgent
}
if userAgent != "" {
req.Header.userAgent = append(req.Header.userAgent[:0], userAgent...)
}
}
w := c.acquirePipelineWork(0)
w.req = req
if resp != nil {
resp.Header.disableNormalizing = c.DisableHeaderNamesNormalizing
w.resp = resp
} else {
w.resp = &w.respCopy
}
// Put the request to outgoing queue
select {
case c.chW <- w:
default:
// Try substituting the oldest w with the current one.
select {
case wOld := <-c.chW:
wOld.err = ErrPipelineOverflow
wOld.done <- struct{}{}
default:
}
select {
case c.chW <- w:
default:
c.releasePipelineWork(w)
return ErrPipelineOverflow
}
}
// Wait for the response
<-w.done
err := w.err
c.releasePipelineWork(w)
return err
}
func (c *PipelineClient) getConnClient() *pipelineConnClient {
c.connClientsLock.Lock()
cc := c.getConnClientUnlocked()
c.connClientsLock.Unlock()
return cc
}
func (c *PipelineClient) getConnClientUnlocked() *pipelineConnClient {
if len(c.connClients) == 0 {
return c.newConnClient()
}
// Return the client with the minimum number of pending requests.
minCC := c.connClients[0]
minReqs := minCC.PendingRequests()
if minReqs == 0 {
return minCC
}
for i := 1; i < len(c.connClients); i++ {
cc := c.connClients[i]
reqs := cc.PendingRequests()
if reqs == 0 {
return cc
}
if reqs < minReqs {
minCC = cc
minReqs = reqs
}
}
maxConns := c.MaxConns
if maxConns <= 0 {
maxConns = 1
}
if len(c.connClients) < maxConns {
return c.newConnClient()
}
return minCC
}
func (c *PipelineClient) newConnClient() *pipelineConnClient {
cc := &pipelineConnClient{
Addr: c.Addr,
Name: c.Name,
NoDefaultUserAgentHeader: c.NoDefaultUserAgentHeader,
MaxPendingRequests: c.MaxPendingRequests,
MaxBatchDelay: c.MaxBatchDelay,
Dial: c.Dial,
DialDualStack: c.DialDualStack,
DisableHeaderNamesNormalizing: c.DisableHeaderNamesNormalizing,
DisablePathNormalizing: c.DisablePathNormalizing,
IsTLS: c.IsTLS,
TLSConfig: c.TLSConfig,
MaxIdleConnDuration: c.MaxIdleConnDuration,
ReadBufferSize: c.ReadBufferSize,
WriteBufferSize: c.WriteBufferSize,
ReadTimeout: c.ReadTimeout,
WriteTimeout: c.WriteTimeout,
Logger: c.Logger,
}
c.connClients = append(c.connClients, cc)
return cc
}
// ErrPipelineOverflow may be returned from PipelineClient.Do*
// if the requests' queue is overflowed.
var ErrPipelineOverflow = errors.New("pipelined requests' queue has been overflowed. Increase MaxConns and/or MaxPendingRequests")
// DefaultMaxPendingRequests is the default value
// for PipelineClient.MaxPendingRequests.
const DefaultMaxPendingRequests = 1024
func (c *pipelineConnClient) init() {
c.chLock.Lock()
if c.chR == nil {
maxPendingRequests := c.MaxPendingRequests
if maxPendingRequests <= 0 {
maxPendingRequests = DefaultMaxPendingRequests
}
c.chR = make(chan *pipelineWork, maxPendingRequests)
if c.chW == nil {
c.chW = make(chan *pipelineWork, maxPendingRequests)
}
go func() {
// Keep restarting the worker if it fails (connection errors for example).
for {
if err := c.worker(); err != nil {
c.logger().Printf("error in PipelineClient(%q): %v", c.Addr, err)
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
// Throttle client reconnections on timeout errors
time.Sleep(time.Second)
}
} else {
c.chLock.Lock()
stop := len(c.chR) == 0 && len(c.chW) == 0
if stop {
c.chR = nil
c.chW = nil
}
c.chLock.Unlock()
if stop {
break
}
}
}
}()
}
c.chLock.Unlock()
}
func (c *pipelineConnClient) worker() error {
tlsConfig := c.cachedTLSConfig()
conn, err := dialAddr(c.Addr, c.Dial, nil, c.DialDualStack, c.IsTLS, tlsConfig, 0, c.WriteTimeout)
if err != nil {
return err
}
// Start reader and writer
stopW := make(chan struct{})
doneW := make(chan error)
go func() {
doneW <- c.writer(conn, stopW)
}()
stopR := make(chan struct{})
doneR := make(chan error)
go func() {
doneR <- c.reader(conn, stopR)
}()
// Wait until reader and writer are stopped
select {
case err = <-doneW:
conn.Close()
close(stopR)
<-doneR
case err = <-doneR:
conn.Close()
close(stopW)
<-doneW
}
// Notify pending readers
for len(c.chR) > 0 {
w := <-c.chR
w.err = errPipelineConnStopped
w.done <- struct{}{}
}
return err
}
func (c *pipelineConnClient) cachedTLSConfig() *tls.Config {
if !c.IsTLS {
return nil
}
c.tlsConfigLock.Lock()
cfg := c.tlsConfig
if cfg == nil {
cfg = newClientTLSConfig(c.TLSConfig, c.Addr)
c.tlsConfig = cfg
}
c.tlsConfigLock.Unlock()
return cfg
}
func (c *pipelineConnClient) writer(conn net.Conn, stopCh <-chan struct{}) error {
writeBufferSize := c.WriteBufferSize
if writeBufferSize <= 0 {
writeBufferSize = defaultWriteBufferSize
}
bw := bufio.NewWriterSize(conn, writeBufferSize)
defer bw.Flush()
chR := c.chR
chW := c.chW
writeTimeout := c.WriteTimeout
maxIdleConnDuration := c.MaxIdleConnDuration
if maxIdleConnDuration <= 0 {
maxIdleConnDuration = DefaultMaxIdleConnDuration
}
maxBatchDelay := c.MaxBatchDelay
var (
stopTimer = time.NewTimer(time.Hour)
flushTimer = time.NewTimer(time.Hour)
flushTimerCh <-chan time.Time
instantTimerCh = make(chan time.Time)
w *pipelineWork
err error
)
close(instantTimerCh)
for {
againChW:
select {
case w = <-chW:
// Fast path: len(chW) > 0
default:
// Slow path
stopTimer.Reset(maxIdleConnDuration)
select {
case w = <-chW:
case <-stopTimer.C:
return nil
case <-stopCh:
return nil
case <-flushTimerCh:
if err = bw.Flush(); err != nil {
return err
}
flushTimerCh = nil
goto againChW
}
}
if !w.deadline.IsZero() && time.Since(w.deadline) >= 0 {
w.err = ErrTimeout
w.done <- struct{}{}
continue
}
w.resp.ParseNetConn(conn)
if writeTimeout > 0 {
// Set Deadline every time, since golang has fixed the performance issue
// See https://github.com/golang/go/issues/15133#issuecomment-271571395 for details
currentTime := time.Now()
if err = conn.SetWriteDeadline(currentTime.Add(writeTimeout)); err != nil {
w.err = err
w.done <- struct{}{}
return err
}
}
if err = w.req.Write(bw); err != nil {
w.err = err
w.done <- struct{}{}
return err
}
if flushTimerCh == nil && (len(chW) == 0 || len(chR) == cap(chR)) {
if maxBatchDelay > 0 {
flushTimer.Reset(maxBatchDelay)
flushTimerCh = flushTimer.C
} else {
flushTimerCh = instantTimerCh
}
}
againChR:
select {
case chR <- w:
// Fast path: len(chR) < cap(chR)
default:
// Slow path
select {
case chR <- w:
case <-stopCh:
w.err = errPipelineConnStopped
w.done <- struct{}{}
return nil
case <-flushTimerCh:
if err = bw.Flush(); err != nil {
w.err = err
w.done <- struct{}{}
return err
}
flushTimerCh = nil
goto againChR
}
}
}
}
func (c *pipelineConnClient) reader(conn net.Conn, stopCh <-chan struct{}) error {
readBufferSize := c.ReadBufferSize
if readBufferSize <= 0 {
readBufferSize = defaultReadBufferSize
}
br := bufio.NewReaderSize(conn, readBufferSize)
chR := c.chR
readTimeout := c.ReadTimeout
var (
w *pipelineWork
err error
)
for {
select {
case w = <-chR:
// Fast path: len(chR) > 0
default:
// Slow path
select {
case w = <-chR:
case <-stopCh:
return nil
}
}
if readTimeout > 0 {
// Set Deadline every time, since golang has fixed the performance issue
// See https://github.com/golang/go/issues/15133#issuecomment-271571395 for details
currentTime := time.Now()
if err = conn.SetReadDeadline(currentTime.Add(readTimeout)); err != nil {
w.err = err
w.done <- struct{}{}
return err
}
}
if err = w.resp.Read(br); err != nil {
w.err = err
w.done <- struct{}{}
return err
}
w.done <- struct{}{}
}
}
func (c *pipelineConnClient) logger() Logger {
if c.Logger != nil {
return c.Logger
}
return defaultLogger
}
// PendingRequests returns the current number of pending requests pipelined
// to the server.
//
// This number may exceed MaxPendingRequests*MaxConns by up to a factor of
// two, since each connection to the server may keep up to MaxPendingRequests
// requests in the queue before sending them to the server.
//
// This function may be used for balancing load among multiple PipelineClient
// instances.
func (c *PipelineClient) PendingRequests() int {
c.connClientsLock.Lock()
n := 0
for _, cc := range c.connClients {
n += cc.PendingRequests()
}
c.connClientsLock.Unlock()
return n
}
func (c *pipelineConnClient) PendingRequests() int {
c.init()
c.chLock.Lock()
n := len(c.chR) + len(c.chW)
c.chLock.Unlock()
return n
}
var errPipelineConnStopped = errors.New("pipeline connection has been stopped")
// DefaultTransport is the RoundTripper used by HostClient when Transport is not set.
var DefaultTransport RoundTripper = &transport{}
type transport struct{}
func (t *transport) RoundTrip(hc *HostClient, req *Request, resp *Response) (retry bool, err error) {
customSkipBody := resp.SkipBody
customStreamBody := resp.StreamBody
var deadline time.Time
if req.timeout > 0 {
deadline = time.Now().Add(req.timeout)
}
cc, err := hc.AcquireConn(req.timeout, req.ConnectionClose())
if err != nil {
return false, err
}
conn := cc.c
resp.ParseNetConn(conn)
writeDeadline := deadline
if hc.WriteTimeout > 0 {
tmpWriteDeadline := time.Now().Add(hc.WriteTimeout)
if writeDeadline.IsZero() || tmpWriteDeadline.Before(writeDeadline) {
writeDeadline = tmpWriteDeadline
}
}
if err = conn.SetWriteDeadline(writeDeadline); err != nil {
hc.CloseConn(cc)
return true, err
}
resetConnection := false
if hc.MaxConnDuration > 0 && time.Since(cc.createdTime) > hc.MaxConnDuration && !req.ConnectionClose() {
req.SetConnectionClose()
resetConnection = true
}
bw := hc.AcquireWriter(conn)
err = req.Write(bw)
if resetConnection {
req.Header.ResetConnectionClose()
}
if err == nil {
err = bw.Flush()
}
hc.ReleaseWriter(bw)
// Return ErrTimeout on any timeout.
if x, ok := err.(interface{ Timeout() bool }); ok && x.Timeout() {
err = ErrTimeout
}
if err != nil {
hc.CloseConn(cc)
return true, err
}
readDeadline := deadline
if hc.ReadTimeout > 0 {
tmpReadDeadline := time.Now().Add(hc.ReadTimeout)
if readDeadline.IsZero() || tmpReadDeadline.Before(readDeadline) {
readDeadline = tmpReadDeadline
}
}
if err = conn.SetReadDeadline(readDeadline); err != nil {
hc.CloseConn(cc)
return true, err
}
if customSkipBody || req.Header.IsHead() {
resp.SkipBody = true
}
if hc.DisableHeaderNamesNormalizing {
resp.Header.DisableNormalizing()
}
br := hc.AcquireReader(conn)
err = resp.ReadLimitBody(br, hc.MaxResponseBodySize)
if err != nil {
hc.ReleaseReader(br)
hc.CloseConn(cc)
// Don't retry in case of ErrBodyTooLarge since we will just get the same again.
needRetry := err != ErrBodyTooLarge
return needRetry, err
}
closeConn := resetConnection || req.ConnectionClose() || resp.ConnectionClose()
if customStreamBody && resp.bodyStream != nil {
rbs := resp.bodyStream
resp.bodyStream = newCloseReaderWithError(rbs, func(wErr error) error {
hc.ReleaseReader(br)
if r, ok := rbs.(*requestStream); ok {
releaseRequestStream(r)
}
if closeConn || resp.ConnectionClose() || wErr != nil {
hc.CloseConn(cc)
} else {
hc.ReleaseConn(cc)
}
return nil
})
return false, nil
}
hc.ReleaseReader(br)
if closeConn {
hc.CloseConn(cc)
} else {
hc.ReleaseConn(cc)
}
return false, nil
}
package fasthttp
import (
"time"
)
// CoarseTimeNow returns the current time truncated to the nearest second.
//
// Deprecated: This is slower than calling time.Now() directly.
// It is now merely a shortcut for time.Now().Truncate(time.Second).
func CoarseTimeNow() time.Time {
return time.Now().Truncate(time.Second)
}
package fasthttp
import (
"bytes"
"fmt"
"io"
"io/fs"
"sync"
"github.com/klauspost/compress/flate"
"github.com/klauspost/compress/gzip"
"github.com/klauspost/compress/zlib"
"github.com/valyala/bytebufferpool"
"github.com/valyala/fasthttp/stackless"
)
// Supported compression levels.
const (
CompressNoCompression = flate.NoCompression
CompressBestSpeed = flate.BestSpeed
CompressBestCompression = flate.BestCompression
CompressDefaultCompression = 6 // flate.DefaultCompression
CompressHuffmanOnly = -2 // flate.HuffmanOnly
)
func acquireGzipReader(r io.Reader) (*gzip.Reader, error) {
v := gzipReaderPool.Get()
if v == nil {
return gzip.NewReader(r)
}
zr := v.(*gzip.Reader)
if err := zr.Reset(r); err != nil {
return nil, err
}
return zr, nil
}
func releaseGzipReader(zr *gzip.Reader) {
zr.Close()
gzipReaderPool.Put(zr)
}
var gzipReaderPool sync.Pool
func acquireFlateReader(r io.Reader) (io.ReadCloser, error) {
v := flateReaderPool.Get()
if v == nil {
zr, err := zlib.NewReader(r)
if err != nil {
return nil, err
}
return zr, nil
}
zr := v.(io.ReadCloser)
if err := resetFlateReader(zr, r); err != nil {
return nil, err
}
return zr, nil
}
func releaseFlateReader(zr io.ReadCloser) {
zr.Close()
flateReaderPool.Put(zr)
}
func resetFlateReader(zr io.ReadCloser, r io.Reader) error {
zrr, ok := zr.(zlib.Resetter)
if !ok {
// Sanity check: this should only be called with a zlib.Reader.
panic("BUG: zlib.Reader doesn't implement zlib.Resetter???")
}
return zrr.Reset(r, nil)
}
var flateReaderPool sync.Pool
func acquireStacklessGzipWriter(w io.Writer, level int) stackless.Writer {
nLevel := normalizeCompressLevel(level)
p := stacklessGzipWriterPoolMap[nLevel]
v := p.Get()
if v == nil {
return stackless.NewWriter(w, func(w io.Writer) stackless.Writer {
return acquireRealGzipWriter(w, level)
})
}
sw := v.(stackless.Writer)
sw.Reset(w)
return sw
}
func releaseStacklessGzipWriter(sw stackless.Writer, level int) {
sw.Close()
nLevel := normalizeCompressLevel(level)
p := stacklessGzipWriterPoolMap[nLevel]
p.Put(sw)
}
func acquireRealGzipWriter(w io.Writer, level int) *gzip.Writer {
nLevel := normalizeCompressLevel(level)
p := realGzipWriterPoolMap[nLevel]
v := p.Get()
if v == nil {
zw, err := gzip.NewWriterLevel(w, level)
if err != nil {
// gzip.NewWriterLevel only errors for invalid
// compression levels. Clamp it to be min or max.
if level < gzip.HuffmanOnly {
level = gzip.HuffmanOnly
} else {
level = gzip.BestCompression
}
zw, _ = gzip.NewWriterLevel(w, level)
}
return zw
}
zw := v.(*gzip.Writer)
zw.Reset(w)
return zw
}
func releaseRealGzipWriter(zw *gzip.Writer, level int) {
zw.Close()
nLevel := normalizeCompressLevel(level)
p := realGzipWriterPoolMap[nLevel]
p.Put(zw)
}
var (
stacklessGzipWriterPoolMap = newCompressWriterPoolMap()
realGzipWriterPoolMap = newCompressWriterPoolMap()
)
// AppendGzipBytesLevel appends gzipped src to dst using the given
// compression level and returns the resulting dst.
//
// Supported compression levels are:
//
// - CompressNoCompression
// - CompressBestSpeed
// - CompressBestCompression
// - CompressDefaultCompression
// - CompressHuffmanOnly
func AppendGzipBytesLevel(dst, src []byte, level int) []byte {
w := &byteSliceWriter{b: dst}
WriteGzipLevel(w, src, level) //nolint:errcheck
return w.b
}
// WriteGzipLevel writes gzipped p to w using the given compression level
// and returns the number of compressed bytes written to w.
//
// Supported compression levels are:
//
// - CompressNoCompression
// - CompressBestSpeed
// - CompressBestCompression
// - CompressDefaultCompression
// - CompressHuffmanOnly
func WriteGzipLevel(w io.Writer, p []byte, level int) (int, error) {
switch w.(type) {
case *byteSliceWriter,
*bytes.Buffer,
*bytebufferpool.ByteBuffer:
// These writers don't block, so we can just use stacklessWriteGzip
ctx := &compressCtx{
w: w,
p: p,
level: level,
}
stacklessWriteGzip(ctx)
return len(p), nil
default:
zw := acquireStacklessGzipWriter(w, level)
n, err := zw.Write(p)
releaseStacklessGzipWriter(zw, level)
return n, err
}
}
var (
stacklessWriteGzipOnce sync.Once
stacklessWriteGzipFunc func(ctx any) bool
)
func stacklessWriteGzip(ctx any) {
stacklessWriteGzipOnce.Do(func() {
stacklessWriteGzipFunc = stackless.NewFunc(nonblockingWriteGzip)
})
stacklessWriteGzipFunc(ctx)
}
func nonblockingWriteGzip(ctxv any) {
ctx := ctxv.(*compressCtx)
zw := acquireRealGzipWriter(ctx.w, ctx.level)
zw.Write(ctx.p) //nolint:errcheck // no way to handle this error anyway
releaseRealGzipWriter(zw, ctx.level)
}
// WriteGzip writes gzipped p to w and returns the number of compressed
// bytes written to w.
func WriteGzip(w io.Writer, p []byte) (int, error) {
return WriteGzipLevel(w, p, CompressDefaultCompression)
}
// AppendGzipBytes appends gzipped src to dst and returns the resulting dst.
func AppendGzipBytes(dst, src []byte) []byte {
return AppendGzipBytesLevel(dst, src, CompressDefaultCompression)
}
// WriteGunzip writes ungzipped p to w and returns the number of uncompressed
// bytes written to w.
func WriteGunzip(w io.Writer, p []byte) (int, error) {
r := &byteSliceReader{b: p}
zr, err := acquireGzipReader(r)
if err != nil {
return 0, err
}
n, err := copyZeroAlloc(w, zr)
releaseGzipReader(zr)
nn := int(n)
if int64(nn) != n {
return 0, fmt.Errorf("too much data gunzipped: %d", n)
}
return nn, err
}
// AppendGunzipBytes appends gunzipped src to dst and returns the resulting dst.
func AppendGunzipBytes(dst, src []byte) ([]byte, error) {
w := &byteSliceWriter{b: dst}
_, err := WriteGunzip(w, src)
return w.b, err
}
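// A round-trip sketch (the payload is arbitrary):
//
//	compressed := AppendGzipBytes(nil, []byte("hello, world"))
//	plain, err := AppendGunzipBytes(nil, compressed)
//	// err == nil && string(plain) == "hello, world"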
// AppendDeflateBytesLevel appends deflated src to dst using the given
// compression level and returns the resulting dst.
//
// Supported compression levels are:
//
// - CompressNoCompression
// - CompressBestSpeed
// - CompressBestCompression
// - CompressDefaultCompression
// - CompressHuffmanOnly
func AppendDeflateBytesLevel(dst, src []byte, level int) []byte {
w := &byteSliceWriter{b: dst}
WriteDeflateLevel(w, src, level) //nolint:errcheck
return w.b
}
// WriteDeflateLevel writes deflated p to w using the given compression level
// and returns the number of compressed bytes written to w.
//
// Supported compression levels are:
//
// - CompressNoCompression
// - CompressBestSpeed
// - CompressBestCompression
// - CompressDefaultCompression
// - CompressHuffmanOnly
func WriteDeflateLevel(w io.Writer, p []byte, level int) (int, error) {
switch w.(type) {
case *byteSliceWriter,
*bytes.Buffer,
*bytebufferpool.ByteBuffer:
// These writers don't block, so we can just use stacklessWriteDeflate
ctx := &compressCtx{
w: w,
p: p,
level: level,
}
stacklessWriteDeflate(ctx)
return len(p), nil
default:
zw := acquireStacklessDeflateWriter(w, level)
n, err := zw.Write(p)
releaseStacklessDeflateWriter(zw, level)
return n, err
}
}
var (
stacklessWriteDeflateOnce sync.Once
stacklessWriteDeflateFunc func(ctx any) bool
)
func stacklessWriteDeflate(ctx any) {
stacklessWriteDeflateOnce.Do(func() {
stacklessWriteDeflateFunc = stackless.NewFunc(nonblockingWriteDeflate)
})
stacklessWriteDeflateFunc(ctx)
}
func nonblockingWriteDeflate(ctxv any) {
ctx := ctxv.(*compressCtx)
zw := acquireRealDeflateWriter(ctx.w, ctx.level)
zw.Write(ctx.p) //nolint:errcheck // no way to handle this error anyway
releaseRealDeflateWriter(zw, ctx.level)
}
type compressCtx struct {
w io.Writer
p []byte
level int
}
// WriteDeflate writes deflated p to w and returns the number of compressed
// bytes written to w.
func WriteDeflate(w io.Writer, p []byte) (int, error) {
return WriteDeflateLevel(w, p, CompressDefaultCompression)
}
// AppendDeflateBytes appends deflated src to dst and returns the resulting dst.
func AppendDeflateBytes(dst, src []byte) []byte {
return AppendDeflateBytesLevel(dst, src, CompressDefaultCompression)
}
// WriteInflate writes inflated p to w and returns the number of uncompressed
// bytes written to w.
func WriteInflate(w io.Writer, p []byte) (int, error) {
r := &byteSliceReader{b: p}
zr, err := acquireFlateReader(r)
if err != nil {
return 0, err
}
n, err := copyZeroAlloc(w, zr)
releaseFlateReader(zr)
nn := int(n)
if int64(nn) != n {
return 0, fmt.Errorf("too much data inflated: %d", n)
}
return nn, err
}
// AppendInflateBytes appends inflated src to dst and returns the resulting dst.
func AppendInflateBytes(dst, src []byte) ([]byte, error) {
w := &byteSliceWriter{b: dst}
_, err := WriteInflate(w, src)
return w.b, err
}
type byteSliceWriter struct {
b []byte
}
func (w *byteSliceWriter) Write(p []byte) (int, error) {
w.b = append(w.b, p...)
return len(p), nil
}
func (w *byteSliceWriter) WriteString(s string) (int, error) {
w.b = append(w.b, s...)
return len(s), nil
}
type byteSliceReader struct {
b []byte
}
func (r *byteSliceReader) Read(p []byte) (int, error) {
if len(r.b) == 0 {
return 0, io.EOF
}
n := copy(p, r.b)
r.b = r.b[n:]
return n, nil
}
func (r *byteSliceReader) ReadByte() (byte, error) {
if len(r.b) == 0 {
return 0, io.EOF
}
n := r.b[0]
r.b = r.b[1:]
return n, nil
}
func acquireStacklessDeflateWriter(w io.Writer, level int) stackless.Writer {
nLevel := normalizeCompressLevel(level)
p := stacklessDeflateWriterPoolMap[nLevel]
v := p.Get()
if v == nil {
return stackless.NewWriter(w, func(w io.Writer) stackless.Writer {
return acquireRealDeflateWriter(w, level)
})
}
sw := v.(stackless.Writer)
sw.Reset(w)
return sw
}
func releaseStacklessDeflateWriter(sw stackless.Writer, level int) {
sw.Close()
nLevel := normalizeCompressLevel(level)
p := stacklessDeflateWriterPoolMap[nLevel]
p.Put(sw)
}
func acquireRealDeflateWriter(w io.Writer, level int) *zlib.Writer {
nLevel := normalizeCompressLevel(level)
p := realDeflateWriterPoolMap[nLevel]
v := p.Get()
if v == nil {
zw, err := zlib.NewWriterLevel(w, level)
if err != nil {
// zlib.NewWriterLevel only errors on invalid
// compression levels. Clamp the level to the min or max.
if level < zlib.HuffmanOnly {
level = zlib.HuffmanOnly
} else {
level = zlib.BestCompression
}
zw, _ = zlib.NewWriterLevel(w, level)
}
return zw
}
zw := v.(*zlib.Writer)
zw.Reset(w)
return zw
}
func releaseRealDeflateWriter(zw *zlib.Writer, level int) {
zw.Close()
nLevel := normalizeCompressLevel(level)
p := realDeflateWriterPoolMap[nLevel]
p.Put(zw)
}
var (
stacklessDeflateWriterPoolMap = newCompressWriterPoolMap()
realDeflateWriterPoolMap = newCompressWriterPoolMap()
)
func newCompressWriterPoolMap() []*sync.Pool {
// Initialize pools for all the compression levels defined
// in https://pkg.go.dev/compress/flate#pkg-constants .
// Compression levels are normalized with normalizeCompressLevel,
// so they fit in [0..11].
var m []*sync.Pool
for i := 0; i < 12; i++ {
m = append(m, &sync.Pool{})
}
return m
}
func isFileCompressible(f fs.File, minCompressRatio float64) bool {
// Try compressing the first 4KiB of the file
// and check whether it compresses to less than
// the given minCompressRatio of its original size.
b := bytebufferpool.Get()
zw := acquireStacklessGzipWriter(b, CompressDefaultCompression)
lr := &io.LimitedReader{
R: f,
N: 4096,
}
_, err := copyZeroAlloc(zw, lr)
releaseStacklessGzipWriter(zw, CompressDefaultCompression)
seeker, ok := f.(io.Seeker)
if !ok {
return false
}
seeker.Seek(0, io.SeekStart) //nolint:errcheck
if err != nil {
return false
}
n := 4096 - lr.N
zn := len(b.B)
bytebufferpool.Put(b)
return float64(zn) < float64(n)*minCompressRatio
}
// normalizes compression level into [0..11], so it could be used as an index
// in *PoolMap.
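//
// For example, CompressHuffmanOnly (-2) maps to index 0,
// CompressBestCompression (9) maps to index 11, and any
// out-of-range level falls back to CompressDefaultCompression.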
func normalizeCompressLevel(level int) int {
// -2 is the lowest compression level - CompressHuffmanOnly
// 9 is the highest compression level - CompressBestCompression
if level < -2 || level > 9 {
level = CompressDefaultCompression
}
return level + 2
}
package fasthttp
import (
"bytes"
"errors"
"io"
"sync"
"time"
)
var zeroTime time.Time
var (
// CookieExpireDelete may be set on Cookie.Expire for expiring the given cookie.
CookieExpireDelete = time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC)
// CookieExpireUnlimited indicates that the cookie doesn't expire.
CookieExpireUnlimited = zeroTime
)
// CookieSameSite is an enum for the mode in which the SameSite flag should be set for the given cookie.
// See https://tools.ietf.org/html/draft-ietf-httpbis-cookie-same-site-00 for details.
type CookieSameSite int
const (
// CookieSameSiteDisabled removes the SameSite flag.
CookieSameSiteDisabled CookieSameSite = iota
// CookieSameSiteDefaultMode sets the SameSite flag.
CookieSameSiteDefaultMode
// CookieSameSiteLaxMode sets the SameSite flag with the "Lax" parameter.
CookieSameSiteLaxMode
// CookieSameSiteStrictMode sets the SameSite flag with the "Strict" parameter.
CookieSameSiteStrictMode
// CookieSameSiteNoneMode sets the SameSite flag with the "None" parameter.
// See https://tools.ietf.org/html/draft-west-cookie-incrementalism-00
CookieSameSiteNoneMode // third-party cookies are being phased out; use Partitioned cookies instead
)
// AcquireCookie returns an empty Cookie object from the pool.
//
// The returned object may be returned to the pool with ReleaseCookie.
// This allows reducing GC load.
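//
// Typical usage (a hedged sketch; the key and value are illustrative):
//
//	c := AcquireCookie()
//	defer ReleaseCookie(c)
//	c.SetKey("session")
//	c.SetValue("abc123")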
func AcquireCookie() *Cookie {
return cookiePool.Get().(*Cookie)
}
// ReleaseCookie returns the Cookie object acquired with AcquireCookie back
// to the pool.
//
// Do not access the released Cookie object, otherwise data races may occur.
func ReleaseCookie(c *Cookie) {
c.Reset()
cookiePool.Put(c)
}
var cookiePool = &sync.Pool{
New: func() any {
return &Cookie{}
},
}
// Cookie represents an HTTP response cookie.
//
// Do not copy Cookie objects. Create a new object and use CopyTo instead.
//
// Cookie instance MUST NOT be used from concurrently running goroutines.
type Cookie struct {
noCopy noCopy
expire time.Time
key []byte
value []byte
domain []byte
path []byte
bufK []byte
bufV []byte
// maxAge=0 means no 'max-age' attribute specified.
// maxAge<0 means delete cookie now, equivalently 'max-age=0'
// maxAge>0 means 'max-age' attribute present and given in seconds
maxAge int
sameSite CookieSameSite
httpOnly bool
secure bool
partitioned bool
}
// CopyTo copies src cookie to c.
func (c *Cookie) CopyTo(src *Cookie) {
c.Reset()
c.key = append(c.key, src.key...)
c.value = append(c.value, src.value...)
c.expire = src.expire
c.maxAge = src.maxAge
c.domain = append(c.domain, src.domain...)
c.path = append(c.path, src.path...)
c.httpOnly = src.httpOnly
c.secure = src.secure
c.sameSite = src.sameSite
c.partitioned = src.partitioned
}
// HTTPOnly returns true if the cookie is http only.
func (c *Cookie) HTTPOnly() bool {
return c.httpOnly
}
// SetHTTPOnly sets cookie's httpOnly flag to the given value.
func (c *Cookie) SetHTTPOnly(httpOnly bool) {
c.httpOnly = httpOnly
}
// Secure returns true if the cookie is secure.
func (c *Cookie) Secure() bool {
return c.secure
}
// SetSecure sets cookie's secure flag to the given value.
func (c *Cookie) SetSecure(secure bool) {
c.secure = secure
}
// SameSite returns the SameSite mode.
func (c *Cookie) SameSite() CookieSameSite {
return c.sameSite
}
// SetSameSite sets the cookie's SameSite flag to the given value.
// Setting the value to CookieSameSiteNoneMode also sets Secure to true to avoid browser rejection.
func (c *Cookie) SetSameSite(mode CookieSameSite) {
c.sameSite = mode
if mode == CookieSameSiteNoneMode {
c.SetSecure(true)
}
}
// Partitioned returns true if the cookie is partitioned.
func (c *Cookie) Partitioned() bool {
return c.partitioned
}
// SetPartitioned sets the cookie's Partitioned flag to the given value.
// Setting Partitioned to true also sets Secure to true and Path to "/" to avoid browser rejection.
func (c *Cookie) SetPartitioned(partitioned bool) {
c.partitioned = partitioned
if partitioned {
c.SetSecure(true)
c.SetPath("/")
}
}
// Path returns cookie path.
func (c *Cookie) Path() []byte {
return c.path
}
// SetPath sets cookie path.
func (c *Cookie) SetPath(path string) {
c.bufK = append(c.bufK[:0], path...)
c.path = normalizePath(c.path, c.bufK)
}
// SetPathBytes sets cookie path.
func (c *Cookie) SetPathBytes(path []byte) {
c.bufK = append(c.bufK[:0], path...)
c.path = normalizePath(c.path, c.bufK)
}
// Domain returns cookie domain.
//
// The returned value is valid until the Cookie is reused or released (ReleaseCookie).
// Do not store references to the returned value. Make copies instead.
func (c *Cookie) Domain() []byte {
return c.domain
}
// SetDomain sets cookie domain.
func (c *Cookie) SetDomain(domain string) {
c.domain = append(c.domain[:0], domain...)
}
// SetDomainBytes sets cookie domain.
func (c *Cookie) SetDomainBytes(domain []byte) {
c.domain = append(c.domain[:0], domain...)
}
// MaxAge returns the seconds until the cookie is meant to expire or 0
// if no max age.
func (c *Cookie) MaxAge() int {
return c.maxAge
}
// SetMaxAge sets cookie expiration time based on seconds. This takes precedence
// over any absolute expiry set on the cookie.
//
// 'max-age' is set when the maxAge is non-zero. That is, if maxAge = 0,
// the 'max-age' is unset. If maxAge < 0, it indicates that the cookie should
// be deleted immediately, equivalent to 'max-age=0'. This behavior is
// consistent with the Go standard library's net/http package.
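//
// For example (sketch):
//
//	c.SetMaxAge(3600) // serialized as 'max-age=3600'
//	c.SetMaxAge(-1)   // serialized as 'max-age=0', deleting the cookie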
func (c *Cookie) SetMaxAge(seconds int) {
c.maxAge = seconds
}
// Expire returns cookie expiration time.
//
// CookieExpireUnlimited is returned if the cookie doesn't expire.
func (c *Cookie) Expire() time.Time {
expire := c.expire
if expire.IsZero() {
expire = CookieExpireUnlimited
}
return expire
}
// SetExpire sets cookie expiration time.
//
// Set expiration time to CookieExpireDelete for expiring (deleting)
// the cookie on the client.
//
// By default the cookie lifetime is limited to the browser session.
func (c *Cookie) SetExpire(expire time.Time) {
c.expire = expire
}
// Value returns cookie value.
//
// The returned value is valid until the Cookie is reused or released (ReleaseCookie).
// Do not store references to the returned value. Make copies instead.
func (c *Cookie) Value() []byte {
return c.value
}
// SetValue sets cookie value.
func (c *Cookie) SetValue(value string) {
c.value = append(c.value[:0], value...)
}
// SetValueBytes sets cookie value.
func (c *Cookie) SetValueBytes(value []byte) {
c.value = append(c.value[:0], value...)
}
// Key returns cookie name.
//
// The returned value is valid until the Cookie is reused or released (ReleaseCookie).
// Do not store references to the returned value. Make copies instead.
func (c *Cookie) Key() []byte {
return c.key
}
// SetKey sets cookie name.
func (c *Cookie) SetKey(key string) {
c.key = append(c.key[:0], key...)
}
// SetKeyBytes sets cookie name.
func (c *Cookie) SetKeyBytes(key []byte) {
c.key = append(c.key[:0], key...)
}
// Reset clears the cookie.
func (c *Cookie) Reset() {
c.key = c.key[:0]
c.value = c.value[:0]
c.expire = zeroTime
c.maxAge = 0
c.domain = c.domain[:0]
c.path = c.path[:0]
c.httpOnly = false
c.secure = false
c.sameSite = CookieSameSiteDisabled
c.partitioned = false
}
// AppendBytes appends cookie representation to dst and returns
// the extended dst.
func (c *Cookie) AppendBytes(dst []byte) []byte {
if len(c.key) > 0 {
dst = append(dst, c.key...)
dst = append(dst, '=')
}
dst = append(dst, c.value...)
if c.maxAge != 0 {
dst = append(dst, ';', ' ')
dst = append(dst, strCookieMaxAge...)
dst = append(dst, '=')
if c.maxAge < 0 {
// See https://github.com/valyala/fasthttp/issues/1900
dst = AppendUint(dst, 0)
} else {
dst = AppendUint(dst, c.maxAge)
}
} else if !c.expire.IsZero() {
c.bufV = AppendHTTPDate(c.bufV[:0], c.expire)
dst = append(dst, ';', ' ')
dst = append(dst, strCookieExpires...)
dst = append(dst, '=')
dst = append(dst, c.bufV...)
}
if len(c.domain) > 0 {
dst = appendCookiePart(dst, strCookieDomain, c.domain)
}
if len(c.path) > 0 {
dst = appendCookiePart(dst, strCookiePath, c.path)
}
if c.httpOnly {
dst = append(dst, ';', ' ')
dst = append(dst, strCookieHTTPOnly...)
}
if c.secure {
dst = append(dst, ';', ' ')
dst = append(dst, strCookieSecure...)
}
switch c.sameSite {
case CookieSameSiteDefaultMode:
dst = append(dst, ';', ' ')
dst = append(dst, strCookieSameSite...)
case CookieSameSiteLaxMode:
dst = append(dst, ';', ' ')
dst = append(dst, strCookieSameSite...)
dst = append(dst, '=')
dst = append(dst, strCookieSameSiteLax...)
case CookieSameSiteStrictMode:
dst = append(dst, ';', ' ')
dst = append(dst, strCookieSameSite...)
dst = append(dst, '=')
dst = append(dst, strCookieSameSiteStrict...)
case CookieSameSiteNoneMode:
dst = append(dst, ';', ' ')
dst = append(dst, strCookieSameSite...)
dst = append(dst, '=')
dst = append(dst, strCookieSameSiteNone...)
}
if c.partitioned {
dst = append(dst, ';', ' ')
dst = append(dst, strCookiePartitioned...)
}
return dst
}
// Cookie returns cookie representation.
//
// The returned value is valid until the Cookie is reused or released (ReleaseCookie).
// Do not store references to the returned value. Make copies instead.
func (c *Cookie) Cookie() []byte {
c.bufK = c.AppendBytes(c.bufK[:0])
return c.bufK
}
// String returns cookie representation.
func (c *Cookie) String() string {
return string(c.Cookie())
}
// WriteTo writes cookie representation to w.
//
// WriteTo implements io.WriterTo interface.
func (c *Cookie) WriteTo(w io.Writer) (int64, error) {
n, err := w.Write(c.Cookie())
return int64(n), err
}
var errNoCookies = errors.New("no cookies found")
// Parse parses Set-Cookie header.
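//
// A hedged parsing sketch (the header value is illustrative):
//
//	var c Cookie
//	if err := c.Parse("foo=bar; Path=/; HttpOnly"); err == nil {
//		// c.Key() == "foo", c.Value() == "bar", c.HTTPOnly() == true
//	}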
func (c *Cookie) Parse(src string) error {
c.bufK = append(c.bufK[:0], src...)
return c.ParseBytes(c.bufK)
}
// ParseBytes parses Set-Cookie header.
func (c *Cookie) ParseBytes(src []byte) error {
c.Reset()
var s cookieScanner
s.b = src
if !s.next(&c.bufK, &c.bufV) {
return errNoCookies
}
c.key = append(c.key, c.bufK...)
c.value = append(c.value, c.bufV...)
for s.next(&c.bufK, &c.bufV) {
if len(c.bufK) != 0 {
// Case insensitive switch on first char
switch c.bufK[0] | 0x20 {
case 'm':
if caseInsensitiveCompare(strCookieMaxAge, c.bufK) {
maxAge, err := ParseUint(c.bufV)
if err != nil {
return err
}
c.maxAge = maxAge
}
case 'e': // "expires"
if caseInsensitiveCompare(strCookieExpires, c.bufK) {
v := b2s(c.bufV)
// Try the same two formats as net/http
// See: https://github.com/golang/go/blob/00379be17e63a5b75b3237819392d2dc3b313a27/src/net/http/cookie.go#L133-L135
exptime, err := time.ParseInLocation(time.RFC1123, v, time.UTC)
if err != nil {
exptime, err = time.Parse("Mon, 02-Jan-2006 15:04:05 MST", v)
if err != nil {
return err
}
}
c.expire = exptime
}
case 'd': // "domain"
if caseInsensitiveCompare(strCookieDomain, c.bufK) {
c.domain = append(c.domain, c.bufV...)
}
case 'p': // "path"
if caseInsensitiveCompare(strCookiePath, c.bufK) {
c.path = append(c.path, c.bufV...)
}
case 's': // "samesite"
if caseInsensitiveCompare(strCookieSameSite, c.bufK) {
if len(c.bufV) > 0 {
// Case insensitive switch on first char
switch c.bufV[0] | 0x20 {
case 'l': // "lax"
if caseInsensitiveCompare(strCookieSameSiteLax, c.bufV) {
c.sameSite = CookieSameSiteLaxMode
}
case 's': // "strict"
if caseInsensitiveCompare(strCookieSameSiteStrict, c.bufV) {
c.sameSite = CookieSameSiteStrictMode
}
case 'n': // "none"
if caseInsensitiveCompare(strCookieSameSiteNone, c.bufV) {
c.sameSite = CookieSameSiteNoneMode
}
}
}
}
}
} else if len(c.bufV) != 0 {
// Case insensitive switch on first char
switch c.bufV[0] | 0x20 {
case 'h': // "httponly"
if caseInsensitiveCompare(strCookieHTTPOnly, c.bufV) {
c.httpOnly = true
}
case 's': // "secure"
if caseInsensitiveCompare(strCookieSecure, c.bufV) {
c.secure = true
} else if caseInsensitiveCompare(strCookieSameSite, c.bufV) {
c.sameSite = CookieSameSiteDefaultMode
}
case 'p': // "partitioned"
if caseInsensitiveCompare(strCookiePartitioned, c.bufV) {
c.partitioned = true
}
}
} // else empty or no match
}
return nil
}
func appendCookiePart(dst, key, value []byte) []byte {
dst = append(dst, ';', ' ')
dst = append(dst, key...)
dst = append(dst, '=')
return append(dst, value...)
}
func getCookieKey(dst, src []byte) []byte {
n := bytes.IndexByte(src, '=')
if n >= 0 {
src = src[:n]
}
return decodeCookieArg(dst, src, false)
}
func appendRequestCookieBytes(dst []byte, cookies []argsKV) []byte {
for i, n := 0, len(cookies); i < n; i++ {
kv := &cookies[i]
if len(kv.key) > 0 {
dst = append(dst, kv.key...)
dst = append(dst, '=')
}
dst = append(dst, kv.value...)
if i+1 < n {
dst = append(dst, ';', ' ')
}
}
return dst
}
// For Response we cannot use the above function, as response cookies
// already contain the 'key=' in the value.
func appendResponseCookieBytes(dst []byte, cookies []argsKV) []byte {
for i, n := 0, len(cookies); i < n; i++ {
kv := &cookies[i]
dst = append(dst, kv.value...)
if i+1 < n {
dst = append(dst, ';', ' ')
}
}
return dst
}
func parseRequestCookies(cookies []argsKV, src []byte) []argsKV {
var s cookieScanner
s.b = src
var kv *argsKV
cookies, kv = allocArg(cookies)
for s.next(&kv.key, &kv.value) {
if len(kv.key) > 0 || len(kv.value) > 0 {
cookies, kv = allocArg(cookies)
}
}
return releaseArg(cookies)
}
type cookieScanner struct {
b []byte
}
func (s *cookieScanner) next(key, val *[]byte) bool {
b := s.b
if len(b) == 0 {
return false
}
isKey := true
k := 0
for i, c := range b {
switch c {
case '=':
if isKey {
isKey = false
*key = decodeCookieArg(*key, b[:i], false)
k = i + 1
}
case ';':
if isKey {
*key = (*key)[:0]
}
*val = decodeCookieArg(*val, b[k:i], true)
s.b = b[i+1:]
return true
}
}
if isKey {
*key = (*key)[:0]
}
*val = decodeCookieArg(*val, b[k:], true)
s.b = b[len(b):]
return true
}
func decodeCookieArg(dst, src []byte, skipQuotes bool) []byte {
for len(src) > 0 && src[0] == ' ' {
src = src[1:]
}
for len(src) > 0 && src[len(src)-1] == ' ' {
src = src[:len(src)-1]
}
if skipQuotes {
if len(src) > 1 && src[0] == '"' && src[len(src)-1] == '"' {
src = src[1 : len(src)-1]
}
}
return append(dst[:0], src...)
}
// caseInsensitiveCompare does a case insensitive equality comparison of
// two []byte. Assumes only letters need to be matched.
func caseInsensitiveCompare(a, b []byte) bool {
if len(a) != len(b) {
return false
}
for i := 0; i < len(a); i++ {
if a[i]|0x20 != b[i]|0x20 {
return false
}
}
return true
}
package fasthttp
import (
"bytes"
"errors"
"fmt"
"html"
"io"
"io/fs"
"mime"
"net/http"
"os"
"path/filepath"
"sort"
"strings"
"sync"
"time"
"github.com/andybalholm/brotli"
"github.com/klauspost/compress/gzip"
"github.com/klauspost/compress/zstd"
"github.com/valyala/bytebufferpool"
)
// ServeFileBytesUncompressed returns HTTP response containing file contents
// from the given path.
//
// Directory contents are returned if path points to a directory.
//
// ServeFileBytes may be used instead to save network traffic when serving
// files with a good compression ratio.
//
// See also RequestCtx.SendFileBytes.
//
// WARNING: do not pass any user supplied paths to this function!
// WARNING: if path is based on user input, users will be able to request
// any file on your filesystem! Use fasthttp.FS with a sane Root instead.
func ServeFileBytesUncompressed(ctx *RequestCtx, path []byte) {
ServeFileUncompressed(ctx, b2s(path))
}
// ServeFileUncompressed returns HTTP response containing file contents
// from the given path.
//
// Directory contents are returned if path points to a directory.
//
// ServeFile may be used instead to save network traffic when serving
// files with a good compression ratio.
//
// See also RequestCtx.SendFile.
//
// WARNING: do not pass any user supplied paths to this function!
// WARNING: if path is based on user input, users will be able to request
// any file on your filesystem! Use fasthttp.FS with a sane Root instead.
func ServeFileUncompressed(ctx *RequestCtx, path string) {
ctx.Request.Header.DelBytes(strAcceptEncoding)
ServeFile(ctx, path)
}
// ServeFileBytes returns HTTP response containing compressed file contents
// from the given path.
//
// HTTP response may contain uncompressed file contents in the following cases:
//
// - Missing 'Accept-Encoding: gzip' request header.
// - No write access to the directory containing the file.
//
// Directory contents are returned if path points to a directory.
//
// Use ServeFileBytesUncompressed if you don't need to serve compressed
// file contents.
//
// See also RequestCtx.SendFileBytes.
//
// WARNING: do not pass any user supplied paths to this function!
// WARNING: if path is based on user input, users will be able to request
// any file on your filesystem! Use fasthttp.FS with a sane Root instead.
func ServeFileBytes(ctx *RequestCtx, path []byte) {
ServeFile(ctx, b2s(path))
}
// ServeFile returns HTTP response containing compressed file contents
// from the given path.
//
// HTTP response may contain uncompressed file contents in the following cases:
//
// - Missing 'Accept-Encoding: gzip' request header.
// - No write access to the directory containing the file.
//
// Directory contents are returned if path points to a directory.
//
// Use ServeFileUncompressed if you don't need to serve compressed file contents.
//
// See also RequestCtx.SendFile.
//
// WARNING: do not pass any user supplied paths to this function!
// WARNING: if path is based on user input, users will be able to request
// any file on your filesystem! Use fasthttp.FS with a sane Root instead.
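//
// A hedged usage sketch inside a request handler (the path is illustrative):
//
//	func staticHandler(ctx *fasthttp.RequestCtx) {
//		fasthttp.ServeFile(ctx, "/var/www/index.html")
//	}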
func ServeFile(ctx *RequestCtx, path string) {
rootFSOnce.Do(func() {
rootFSHandler = rootFS.NewRequestHandler()
})
if path == "" || !filepath.IsAbs(path) {
// extend relative path to absolute path
hasTrailingSlash := path != "" && (path[len(path)-1] == '/' || path[len(path)-1] == '\\')
var err error
path = filepath.FromSlash(path)
if path, err = filepath.Abs(path); err != nil {
ctx.Logger().Printf("cannot resolve path %q to absolute file path: %v", path, err)
ctx.Error("Internal Server Error", StatusInternalServerError)
return
}
if hasTrailingSlash {
path += "/"
}
}
// convert the path to forward slashes regardless of the OS in order to set the URI properly
// the handler will convert back to OS path separator before opening the file
path = filepath.ToSlash(path)
ctx.Request.SetRequestURI(path)
rootFSHandler(ctx)
}
var (
rootFSOnce sync.Once
rootFS = &FS{
Root: "",
AllowEmptyRoot: true,
GenerateIndexPages: true,
Compress: true,
CompressBrotli: true,
CompressZstd: true,
AcceptByteRange: true,
}
rootFSHandler RequestHandler
)
// ServeFS returns HTTP response containing compressed file contents from the given fs.FS's path.
//
// HTTP response may contain uncompressed file contents in the following cases:
//
// - Missing 'Accept-Encoding: gzip' request header.
// - No write access to the directory containing the file.
//
// Directory contents are returned if path points to a directory.
//
// See also ServeFile.
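//
// A hedged sketch using an embedded filesystem (the embed variable and paths
// are illustrative):
//
//	//go:embed static/*
//	var content embed.FS
//
//	func handler(ctx *fasthttp.RequestCtx) {
//		fasthttp.ServeFS(ctx, content, "/static/index.html")
//	}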
func ServeFS(ctx *RequestCtx, filesystem fs.FS, path string) {
f := &FS{
FS: filesystem,
Root: "",
AllowEmptyRoot: true,
GenerateIndexPages: true,
Compress: true,
CompressBrotli: true,
CompressZstd: true,
AcceptByteRange: true,
}
handler := f.NewRequestHandler()
ctx.Request.SetRequestURI(path)
handler(ctx)
}
// PathRewriteFunc must return new request path based on arbitrary ctx
// info such as ctx.Path().
//
// Path rewriter is used in FS for translating the current request
// to the local filesystem path relative to FS.Root.
//
// The returned path must not contain '/../' substrings for security reasons,
// since such paths may refer to files outside FS.Root.
//
// The returned path may refer to ctx members. For example, ctx.Path().
type PathRewriteFunc func(ctx *RequestCtx) []byte
// NewVHostPathRewriter returns a path rewriter that strips slashesCount
// leading slashes from the path and prepends the path with the request's host,
// thus simplifying virtual hosting for static files.
//
// Examples:
//
// - host=foobar.com, slashesCount=0, original path="/foo/bar".
// Resulting path: "/foobar.com/foo/bar"
//
// - host=img.aaa.com, slashesCount=1, original path="/images/123/456.jpg"
// Resulting path: "/img.aaa.com/123/456.jpg"
func NewVHostPathRewriter(slashesCount int) PathRewriteFunc {
return func(ctx *RequestCtx) []byte {
path := stripLeadingSlashes(ctx.Path(), slashesCount)
host := ctx.Host()
if n := bytes.IndexByte(host, '/'); n >= 0 {
host = nil
}
if len(host) == 0 {
host = strInvalidHost
}
b := bytebufferpool.Get()
b.B = append(b.B, '/')
b.B = append(b.B, host...)
b.B = append(b.B, path...)
ctx.URI().SetPathBytes(b.B)
bytebufferpool.Put(b)
return ctx.Path()
}
}
var strInvalidHost = []byte("invalid-host")
// NewPathSlashesStripper returns a path rewriter that strips slashesCount
// leading slashes from the path.
//
// Examples:
//
// - slashesCount = 0, original path: "/foo/bar", result: "/foo/bar"
// - slashesCount = 1, original path: "/foo/bar", result: "/bar"
// - slashesCount = 2, original path: "/foo/bar", result: ""
//
// The returned path rewriter may be used as FS.PathRewrite .
func NewPathSlashesStripper(slashesCount int) PathRewriteFunc {
return func(ctx *RequestCtx) []byte {
return stripLeadingSlashes(ctx.Path(), slashesCount)
}
}
// NewPathPrefixStripper returns a path rewriter that removes prefixSize bytes
// from the path prefix.
//
// Examples:
//
// - prefixSize = 0, original path: "/foo/bar", result: "/foo/bar"
// - prefixSize = 3, original path: "/foo/bar", result: "o/bar"
// - prefixSize = 7, original path: "/foo/bar", result: "r"
//
// The returned path rewriter may be used as FS.PathRewrite .
func NewPathPrefixStripper(prefixSize int) PathRewriteFunc {
return func(ctx *RequestCtx) []byte {
path := ctx.Path()
if len(path) >= prefixSize {
path = path[prefixSize:]
}
return path
}
}
// FS represents settings for request handler serving static files
// from the local filesystem.
//
// Copying FS values is prohibited. Create new values instead.
type FS struct {
noCopy noCopy
// FS is the filesystem to serve files from, e.g. embed.FS or os.DirFS.
FS fs.FS
// Path rewriting function.
//
// By default request path is not modified.
PathRewrite PathRewriteFunc
// PathNotFound fires when the file is not found in the filesystem.
// This handler tries to replace the "Cannot open requested path"
// server response, giving the programmer control of the server flow.
//
// By default PathNotFound returns
// "Cannot open requested path"
PathNotFound RequestHandler
// List of suffixes to add to the cached compressed file name, depending on encoding.
//
// This value makes sense only if Compress is set.
//
// FSCompressedFileSuffixes is used by default.
CompressedFileSuffixes map[string]string
// If CleanStop is set, the channel can be closed to stop the cleanup handlers
// for the FS RequestHandlers created with NewRequestHandler.
// NEVER close this channel while the handler is still being used!
CleanStop chan struct{}
h RequestHandler
// Path to the root directory to serve files from.
Root string
// Path to the compressed root directory to serve files from. If this value
// is empty, Root is used.
CompressRoot string
// Suffix to add to the name of cached compressed file.
//
// This value makes sense only if Compress is set.
//
// FSCompressedFileSuffix is used by default.
CompressedFileSuffix string
// List of index file names to try opening during directory access.
//
// For example:
//
// * index.html
// * index.htm
// * my-super-index.xml
//
// By default the list is empty.
IndexNames []string
// Expiration duration for inactive file handlers.
//
// FSHandlerCacheDuration is used by default.
CacheDuration time.Duration
once sync.Once
// AllowEmptyRoot controls what happens when Root is empty. When false (default) it will default to the
// current working directory. An empty root is mostly useful when you want to use absolute paths
// on Windows that are on different filesystems. On Linux, setting Root to "/" already allows you to use
// absolute paths on any filesystem.
AllowEmptyRoot bool
// Uses brotli encoding and falls back to zstd or gzip in responses if set to true; uses zstd or gzip if set to false.
//
// This value makes sense only if Compress is set.
//
// Brotli encoding is disabled by default.
CompressBrotli bool
// Uses zstd encoding and falls back to gzip in responses if set to true; uses gzip if set to false.
//
// This value makes sense only if Compress is set.
//
// zstd encoding is disabled by default.
CompressZstd bool
// Index pages for directories without files matching IndexNames
// are automatically generated if set.
//
// Directory index generation may be quite slow for directories
// with many files (more than 1K), so enabling index page
// generation for such directories is discouraged.
//
// By default index pages aren't generated.
GenerateIndexPages bool
// Transparently compresses responses if set to true.
//
// The server tries minimizing CPU usage by caching compressed files.
// It adds CompressedFileSuffix suffix to the original file name and
// tries saving the resulting compressed file under the new file name.
// So it is advisable to give the server write access to Root
// and to all inner folders in order to minimize CPU usage when serving
// compressed responses.
//
// Transparent compression is disabled by default.
Compress bool
// Enables byte range requests if set to true.
//
// Byte range requests are disabled by default.
AcceptByteRange bool
// SkipCache disables file handler caching if set to true.
//
// By default it is false.
SkipCache bool
}
// FSCompressedFileSuffix is the suffix FS adds to the original file names
// when trying to store compressed file under the new file name.
// See FS.Compress for details.
const FSCompressedFileSuffix = ".fasthttp.gz"
// FSCompressedFileSuffixes is the suffixes FS adds to the original file names depending on encoding
// when trying to store compressed file under the new file name.
// See FS.Compress for details.
var FSCompressedFileSuffixes = map[string]string{
"gzip": ".fasthttp.gz",
"br": ".fasthttp.br",
"zstd": ".fasthttp.zst",
}
// FSHandlerCacheDuration is the default expiration duration for inactive
// file handlers opened by FS.
const FSHandlerCacheDuration = 10 * time.Second
// FSHandler returns request handler serving static files from
// the given root folder.
//
// stripSlashes indicates how many leading slashes must be stripped
// from requested path before searching requested file in the root folder.
// Examples:
//
// - stripSlashes = 0, original path: "/foo/bar", result: "/foo/bar"
// - stripSlashes = 1, original path: "/foo/bar", result: "/bar"
// - stripSlashes = 2, original path: "/foo/bar", result: ""
//
// The returned request handler automatically generates index pages
// for directories without index.html.
//
// The returned handler caches requested file handles
// for FSHandlerCacheDuration.
// Make sure your program has a high enough 'max open files' limit, aka
// 'ulimit -n', if the root folder contains many files.
//
// Do not create multiple request handler instances for the same
// (root, stripSlashes) arguments - just reuse a single instance.
// Otherwise a goroutine leak will occur.
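//
// Usage sketch (the root and strip count are illustrative):
//
//	handler := FSHandler("/var/www", 1)
//	// a request for /static/foo now serves /var/www/foo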
func FSHandler(root string, stripSlashes int) RequestHandler {
fs := &FS{
Root: root,
IndexNames: []string{"index.html"},
GenerateIndexPages: true,
AcceptByteRange: true,
}
if stripSlashes > 0 {
fs.PathRewrite = NewPathSlashesStripper(stripSlashes)
}
return fs.NewRequestHandler()
}
// NewRequestHandler returns new request handler with the given FS settings.
//
// The returned handler caches requested file handles
// for FS.CacheDuration.
// Make sure your program has a high enough 'max open files' limit, aka
// 'ulimit -n', if the FS.Root folder contains many files.
//
// Do not create multiple request handlers from a single FS instance -
// just reuse a single request handler.
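//
// A hedged configuration sketch (values are illustrative):
//
//	fs := &FS{
//		Root:               "/var/www",
//		IndexNames:         []string{"index.html"},
//		GenerateIndexPages: true,
//		Compress:           true,
//	}
//	handler := fs.NewRequestHandler()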
func (fs *FS) NewRequestHandler() RequestHandler {
fs.once.Do(fs.initRequestHandler)
return fs.h
}
func (fs *FS) normalizeRoot(root string) string {
// fs.FS uses relative paths that are slash-separated on all systems, even Windows.
if fs.FS == nil {
// Serve files from the current working directory if Root is empty or if Root is a relative path.
if (!fs.AllowEmptyRoot && root == "") || (root != "" && !filepath.IsAbs(root)) {
path, err := os.Getwd()
if err != nil {
path = "."
}
root = path + "/" + root
}
// convert the root directory slashes to the native format
root = filepath.FromSlash(root)
}
// strip trailing slashes from the root path
for root != "" && root[len(root)-1] == os.PathSeparator {
root = root[:len(root)-1]
}
return root
}
func (fs *FS) initRequestHandler() {
root := fs.normalizeRoot(fs.Root)
compressRoot := fs.CompressRoot
if compressRoot == "" {
compressRoot = root
} else {
compressRoot = fs.normalizeRoot(compressRoot)
}
compressedFileSuffixes := fs.CompressedFileSuffixes
if compressedFileSuffixes["br"] == "" || compressedFileSuffixes["gzip"] == "" ||
compressedFileSuffixes["zstd"] == "" || compressedFileSuffixes["br"] == compressedFileSuffixes["gzip"] ||
compressedFileSuffixes["br"] == compressedFileSuffixes["zstd"] ||
compressedFileSuffixes["gzip"] == compressedFileSuffixes["zstd"] {
// Copy global map
compressedFileSuffixes = make(map[string]string, len(FSCompressedFileSuffixes))
for k, v := range FSCompressedFileSuffixes {
compressedFileSuffixes[k] = v
}
}
if fs.CompressedFileSuffix != "" {
compressedFileSuffixes["gzip"] = fs.CompressedFileSuffix
compressedFileSuffixes["br"] = FSCompressedFileSuffixes["br"]
compressedFileSuffixes["zstd"] = FSCompressedFileSuffixes["zstd"]
}
h := &fsHandler{
filesystem: fs.FS,
root: root,
indexNames: fs.IndexNames,
pathRewrite: fs.PathRewrite,
generateIndexPages: fs.GenerateIndexPages,
compress: fs.Compress,
compressBrotli: fs.CompressBrotli,
compressZstd: fs.CompressZstd,
compressRoot: compressRoot,
pathNotFound: fs.PathNotFound,
acceptByteRange: fs.AcceptByteRange,
compressedFileSuffixes: compressedFileSuffixes,
}
h.cacheManager = newCacheManager(fs)
if h.filesystem == nil {
h.filesystem = &osFS{} // It provides os.Open and os.Stat
}
fs.h = h.handleRequest
}
type fsHandler struct {
smallFileReaderPool sync.Pool
filesystem fs.FS
cacheManager cacheManager
pathRewrite PathRewriteFunc
pathNotFound RequestHandler
compressedFileSuffixes map[string]string
root string
compressRoot string
indexNames []string
generateIndexPages bool
compress bool
compressBrotli bool
compressZstd bool
acceptByteRange bool
}
type fsFile struct {
lastModified time.Time
t time.Time
f fs.File
h *fsHandler
filename string // fs.FileInfo.Name() returns the filename, not the filepath.
contentType string
dirIndex []byte
lastModifiedStr []byte
bigFiles []*bigFileReader
contentLength int
readersCount int
bigFilesLock sync.Mutex
compressed bool
}
func (ff *fsFile) NewReader() (io.Reader, error) {
if ff.isBig() {
r, err := ff.bigFileReader()
if err != nil {
ff.decReadersCount()
}
return r, err
}
return ff.smallFileReader()
}
func (ff *fsFile) smallFileReader() (io.Reader, error) {
v := ff.h.smallFileReaderPool.Get()
if v == nil {
v = &fsSmallFileReader{}
}
r := v.(*fsSmallFileReader)
r.ff = ff
r.endPos = ff.contentLength
if r.startPos > 0 {
return nil, errors.New("bug: fsSmallFileReader with non-nil startPos found in the pool")
}
return r, nil
}
// Files bigger than this size are sent with sendfile.
const maxSmallFileSize = 2 * 4096
func (ff *fsFile) isBig() bool {
if _, ok := ff.h.filesystem.(*osFS); !ok { // fs.FS only uses bigFileReader, memory cache uses fsSmallFileReader
return ff.f != nil
}
return ff.contentLength > maxSmallFileSize && len(ff.dirIndex) == 0
}
func (ff *fsFile) bigFileReader() (io.Reader, error) {
if ff.f == nil {
return nil, errors.New("bug: ff.f must be non-nil in bigFileReader")
}
var r io.Reader
ff.bigFilesLock.Lock()
n := len(ff.bigFiles)
if n > 0 {
r = ff.bigFiles[n-1]
ff.bigFiles = ff.bigFiles[:n-1]
}
ff.bigFilesLock.Unlock()
if r != nil {
return r, nil
}
f, err := ff.h.filesystem.Open(ff.filename)
if err != nil {
return nil, fmt.Errorf("cannot open already opened file: %w", err)
}
return &bigFileReader{
f: f,
ff: ff,
r: f,
}, nil
}
func (ff *fsFile) Release() {
if ff.f != nil {
_ = ff.f.Close()
if ff.isBig() {
ff.bigFilesLock.Lock()
for _, r := range ff.bigFiles {
_ = r.f.Close()
}
ff.bigFilesLock.Unlock()
}
}
}
func (ff *fsFile) decReadersCount() {
ff.h.cacheManager.WithLock(func() {
ff.readersCount--
if ff.readersCount < 0 {
ff.readersCount = 0
}
})
}
// bigFileReader attempts to trigger sendfile
// for sending big files over the wire.
type bigFileReader struct {
f fs.File
ff *fsFile
r io.Reader
lr io.LimitedReader
}
func (r *bigFileReader) UpdateByteRange(startPos, endPos int) error {
seeker, ok := r.f.(io.Seeker)
if !ok {
return errors.New("must implement io.Seeker")
}
if _, err := seeker.Seek(int64(startPos), io.SeekStart); err != nil {
return err
}
r.r = &r.lr
r.lr.R = r.f
r.lr.N = int64(endPos - startPos + 1)
return nil
}
func (r *bigFileReader) Read(p []byte) (int, error) {
return r.r.Read(p)
}
func (r *bigFileReader) WriteTo(w io.Writer) (int64, error) {
if rf, ok := w.(io.ReaderFrom); ok {
// fast path: the underlying writer may trigger sendfile
return rf.ReadFrom(r.r)
}
// slow path
return copyZeroAlloc(w, r.r)
}
func (r *bigFileReader) Close() error {
r.r = r.f
seeker, ok := r.f.(io.Seeker)
if !ok {
_ = r.f.Close()
return errors.New("must implement io.Seeker")
}
n, err := seeker.Seek(0, io.SeekStart)
if err == nil {
if n == 0 {
ff := r.ff
ff.bigFilesLock.Lock()
ff.bigFiles = append(ff.bigFiles, r)
ff.bigFilesLock.Unlock()
} else {
_ = r.f.Close()
err = errors.New("bug: File.Seek(0, io.SeekStart) returned (non-zero, nil)")
}
} else {
_ = r.f.Close()
}
r.ff.decReadersCount()
return err
}
type fsSmallFileReader struct {
ff *fsFile
startPos int
endPos int
}
func (r *fsSmallFileReader) Close() error {
ff := r.ff
ff.decReadersCount()
r.ff = nil
r.startPos = 0
r.endPos = 0
ff.h.smallFileReaderPool.Put(r)
return nil
}
func (r *fsSmallFileReader) UpdateByteRange(startPos, endPos int) error {
r.startPos = startPos
r.endPos = endPos + 1
return nil
}
func (r *fsSmallFileReader) Read(p []byte) (int, error) {
tailLen := r.endPos - r.startPos
if tailLen <= 0 {
return 0, io.EOF
}
if len(p) > tailLen {
p = p[:tailLen]
}
ff := r.ff
if ff.f != nil {
ra, ok := ff.f.(io.ReaderAt)
if !ok {
return 0, errors.New("must implement io.ReaderAt")
}
n, err := ra.ReadAt(p, int64(r.startPos))
r.startPos += n
return n, err
}
n := copy(p, ff.dirIndex[r.startPos:])
r.startPos += n
return n, nil
}
func (r *fsSmallFileReader) WriteTo(w io.Writer) (int64, error) {
ff := r.ff
var n int
var err error
if ff.f == nil {
n, err = w.Write(ff.dirIndex[r.startPos:r.endPos])
return int64(n), err
}
if rf, ok := w.(io.ReaderFrom); ok {
return rf.ReadFrom(r)
}
curPos := r.startPos
bufv := copyBufPool.Get()
buf := bufv.([]byte)
for err == nil {
tailLen := r.endPos - curPos
if tailLen <= 0 {
break
}
if len(buf) > tailLen {
buf = buf[:tailLen]
}
ra, ok := ff.f.(io.ReaderAt)
if !ok {
return 0, errors.New("must implement io.ReaderAt")
}
n, err = ra.ReadAt(buf, int64(curPos))
nw, errw := w.Write(buf[:n])
curPos += nw
if errw == nil && nw != n {
errw = errors.New("bug: Write(p) returned (n, nil), where n != len(p)")
}
if err == nil {
err = errw
}
}
copyBufPool.Put(bufv)
if err == io.EOF {
err = nil
}
return int64(curPos - r.startPos), err
}
type cacheManager interface {
WithLock(work func())
GetFileFromCache(cacheKind CacheKind, path string) (*fsFile, bool)
SetFileToCache(cacheKind CacheKind, path string, ff *fsFile) *fsFile
}
var (
_ cacheManager = (*inMemoryCacheManager)(nil)
_ cacheManager = (*noopCacheManager)(nil)
)
type CacheKind uint8
const (
defaultCacheKind CacheKind = iota
brotliCacheKind
gzipCacheKind
zstdCacheKind
)
func newCacheManager(fs *FS) cacheManager {
if fs.SkipCache {
return &noopCacheManager{}
}
cacheDuration := fs.CacheDuration
if cacheDuration <= 0 {
cacheDuration = FSHandlerCacheDuration
}
instance := &inMemoryCacheManager{
cacheDuration: cacheDuration,
cache: make(map[string]*fsFile),
cacheBrotli: make(map[string]*fsFile),
cacheGzip: make(map[string]*fsFile),
cacheZstd: make(map[string]*fsFile),
}
go instance.handleCleanCache(fs.CleanStop)
return instance
}
type noopCacheManager struct {
cacheLock sync.Mutex
}
func (n *noopCacheManager) WithLock(work func()) {
n.cacheLock.Lock()
work()
n.cacheLock.Unlock()
}
func (*noopCacheManager) GetFileFromCache(cacheKind CacheKind, path string) (*fsFile, bool) {
return nil, false
}
func (*noopCacheManager) SetFileToCache(cacheKind CacheKind, path string, ff *fsFile) *fsFile {
return ff
}
type inMemoryCacheManager struct {
cache map[string]*fsFile
cacheBrotli map[string]*fsFile
cacheGzip map[string]*fsFile
cacheZstd map[string]*fsFile
cacheDuration time.Duration
cacheLock sync.Mutex
}
func (cm *inMemoryCacheManager) WithLock(work func()) {
cm.cacheLock.Lock()
work()
cm.cacheLock.Unlock()
}
func (cm *inMemoryCacheManager) getFsCache(cacheKind CacheKind) map[string]*fsFile {
fileCache := cm.cache
switch cacheKind {
case brotliCacheKind:
fileCache = cm.cacheBrotli
case gzipCacheKind:
fileCache = cm.cacheGzip
case zstdCacheKind:
fileCache = cm.cacheZstd
}
return fileCache
}
func (cm *inMemoryCacheManager) GetFileFromCache(cacheKind CacheKind, path string) (*fsFile, bool) {
fileCache := cm.getFsCache(cacheKind)
cm.cacheLock.Lock()
ff, ok := fileCache[path]
if ok {
ff.readersCount++
}
cm.cacheLock.Unlock()
return ff, ok
}
func (cm *inMemoryCacheManager) SetFileToCache(cacheKind CacheKind, path string, ff *fsFile) *fsFile {
fileCache := cm.getFsCache(cacheKind)
cm.cacheLock.Lock()
ff1, ok := fileCache[path]
if !ok {
fileCache[path] = ff
ff.readersCount++
} else {
ff1.readersCount++
}
cm.cacheLock.Unlock()
if ok {
// The file has already been opened by another
// goroutine, so close the current file and use
// the file opened by the other goroutine instead.
ff.Release()
ff = ff1
}
return ff
}
func (cm *inMemoryCacheManager) handleCleanCache(cleanStop chan struct{}) {
var pendingFiles []*fsFile
clean := func() {
pendingFiles = cm.cleanCache(pendingFiles)
}
if cleanStop != nil {
t := time.NewTicker(cm.cacheDuration / 2)
for {
select {
case <-t.C:
clean()
case _, stillOpen := <-cleanStop:
// Ignore values sent on the channel; only stop when it is closed.
if !stillOpen {
t.Stop()
return
}
}
}
}
for {
time.Sleep(cm.cacheDuration / 2)
clean()
}
}
func (cm *inMemoryCacheManager) cleanCache(pendingFiles []*fsFile) []*fsFile {
var filesToRelease []*fsFile
cm.cacheLock.Lock()
// Close files which couldn't be closed before due to non-zero
// readers count on the previous run.
var remainingFiles []*fsFile
for _, ff := range pendingFiles {
if ff.readersCount > 0 {
remainingFiles = append(remainingFiles, ff)
} else {
filesToRelease = append(filesToRelease, ff)
}
}
pendingFiles = remainingFiles
pendingFiles, filesToRelease = cleanCacheNolock(cm.cache, pendingFiles, filesToRelease, cm.cacheDuration)
pendingFiles, filesToRelease = cleanCacheNolock(cm.cacheBrotli, pendingFiles, filesToRelease, cm.cacheDuration)
pendingFiles, filesToRelease = cleanCacheNolock(cm.cacheGzip, pendingFiles, filesToRelease, cm.cacheDuration)
pendingFiles, filesToRelease = cleanCacheNolock(cm.cacheZstd, pendingFiles, filesToRelease, cm.cacheDuration)
cm.cacheLock.Unlock()
for _, ff := range filesToRelease {
ff.Release()
}
return pendingFiles
}
func cleanCacheNolock(
cache map[string]*fsFile, pendingFiles, filesToRelease []*fsFile, cacheDuration time.Duration,
) ([]*fsFile, []*fsFile) {
t := time.Now()
for k, ff := range cache {
if t.Sub(ff.t) > cacheDuration {
if ff.readersCount > 0 {
// There are pending readers on stale file handle,
// so we cannot close it. Put it into pendingFiles
// so it will be closed later.
pendingFiles = append(pendingFiles, ff)
} else {
filesToRelease = append(filesToRelease, ff)
}
delete(cache, k)
}
}
return pendingFiles, filesToRelease
}
func (h *fsHandler) pathToFilePath(path string) string {
if _, ok := h.filesystem.(*osFS); !ok {
if len(path) < 1 {
return path
}
return path[1:]
}
return filepath.FromSlash(h.root + path)
}
func (h *fsHandler) filePathToCompressed(filePath string) string {
if h.root == h.compressRoot {
return filePath
}
if !strings.HasPrefix(filePath, h.root) {
return filePath
}
return filepath.FromSlash(h.compressRoot + filePath[len(h.root):])
}
func (h *fsHandler) handleRequest(ctx *RequestCtx) {
var path []byte
if h.pathRewrite != nil {
path = h.pathRewrite(ctx)
} else {
path = ctx.Path()
}
hasTrailingSlash := len(path) > 0 && path[len(path)-1] == '/'
if n := bytes.IndexByte(path, 0); n >= 0 {
ctx.Logger().Printf("cannot serve path with nil byte at position %d: %q", n, path)
ctx.Error("Are you a hacker?", StatusBadRequest)
return
}
if h.pathRewrite != nil {
// There is no need to check for '/../' if path = ctx.Path(),
// since ctx.Path must normalize and sanitize the path.
if n := bytes.Index(path, strSlashDotDotSlash); n >= 0 {
ctx.Logger().Printf("cannot serve path with '/../' at position %d due to security reasons: %q", n, path)
ctx.Error("Internal Server Error", StatusInternalServerError)
return
}
}
mustCompress := false
fileCacheKind := defaultCacheKind
fileEncoding := ""
byteRange := ctx.Request.Header.peek(strRange)
if len(byteRange) == 0 && h.compress {
switch {
case h.compressBrotli && ctx.Request.Header.HasAcceptEncodingBytes(strBr):
mustCompress = true
fileCacheKind = brotliCacheKind
fileEncoding = "br"
case h.compressZstd && ctx.Request.Header.HasAcceptEncodingBytes(strZstd):
mustCompress = true
fileCacheKind = zstdCacheKind
fileEncoding = "zstd"
case ctx.Request.Header.HasAcceptEncodingBytes(strGzip):
mustCompress = true
fileCacheKind = gzipCacheKind
fileEncoding = "gzip"
}
}
originalPathStr := string(path)
pathStr := originalPathStr
if hasTrailingSlash {
pathStr = originalPathStr[:len(originalPathStr)-1]
}
ff, ok := h.cacheManager.GetFileFromCache(fileCacheKind, originalPathStr)
if !ok {
filePath := h.pathToFilePath(pathStr)
var err error
ff, err = h.openFSFile(filePath, mustCompress, fileEncoding)
if mustCompress && err == errNoCreatePermission {
ctx.Logger().Printf("insufficient permissions for saving compressed file for %q. Serving uncompressed file. "+
"Allow write access to the directory with this file in order to improve fasthttp performance", filePath)
mustCompress = false
ff, err = h.openFSFile(filePath, mustCompress, fileEncoding)
}
if errors.Is(err, errDirIndexRequired) {
if !hasTrailingSlash {
ctx.RedirectBytes(append(path, '/'), StatusFound)
return
}
ff, err = h.openIndexFile(ctx, filePath, mustCompress, fileEncoding)
if err != nil {
ctx.Logger().Printf("cannot open dir index %q: %v", filePath, err)
ctx.Error("Directory index is forbidden", StatusForbidden)
return
}
} else if err != nil {
ctx.Logger().Printf("cannot open file %q: %v", filePath, err)
if h.pathNotFound == nil {
ctx.Error("Cannot open requested path", StatusNotFound)
} else {
ctx.SetStatusCode(StatusNotFound)
h.pathNotFound(ctx)
}
return
}
ff = h.cacheManager.SetFileToCache(fileCacheKind, originalPathStr, ff)
}
if !ctx.IfModifiedSince(ff.lastModified) {
ff.decReadersCount()
ctx.NotModified()
return
}
r, err := ff.NewReader()
if err != nil {
ctx.Logger().Printf("cannot obtain file reader for path=%q: %v", path, err)
ctx.Error("Internal Server Error", StatusInternalServerError)
return
}
hdr := &ctx.Response.Header
if ff.compressed {
switch fileEncoding {
case "br":
hdr.SetContentEncodingBytes(strBr)
hdr.addVaryBytes(strAcceptEncoding)
case "gzip":
hdr.SetContentEncodingBytes(strGzip)
hdr.addVaryBytes(strAcceptEncoding)
case "zstd":
hdr.SetContentEncodingBytes(strZstd)
hdr.addVaryBytes(strAcceptEncoding)
}
}
statusCode := StatusOK
contentLength := ff.contentLength
if h.acceptByteRange {
hdr.setNonSpecial(strAcceptRanges, strBytes)
if len(byteRange) > 0 {
startPos, endPos, err := ParseByteRange(byteRange, contentLength)
if err != nil {
_ = r.(io.Closer).Close()
ctx.Logger().Printf("cannot parse byte range %q for path=%q: %v", byteRange, path, err)
ctx.Error("Range Not Satisfiable", StatusRequestedRangeNotSatisfiable)
return
}
if err = r.(byteRangeUpdater).UpdateByteRange(startPos, endPos); err != nil {
_ = r.(io.Closer).Close()
ctx.Logger().Printf("cannot seek byte range %q for path=%q: %v", byteRange, path, err)
ctx.Error("Internal Server Error", StatusInternalServerError)
return
}
hdr.SetContentRange(startPos, endPos, contentLength)
contentLength = endPos - startPos + 1
statusCode = StatusPartialContent
}
}
hdr.setNonSpecial(strLastModified, ff.lastModifiedStr)
if !ctx.IsHead() {
ctx.SetBodyStream(r, contentLength)
} else {
ctx.Response.ResetBody()
ctx.Response.SkipBody = true
ctx.Response.Header.SetContentLength(contentLength)
if rc, ok := r.(io.Closer); ok {
if err := rc.Close(); err != nil {
ctx.Logger().Printf("cannot close file reader: %v", err)
ctx.Error("Internal Server Error", StatusInternalServerError)
return
}
}
}
hdr.noDefaultContentType = true
if len(hdr.ContentType()) == 0 {
ctx.SetContentType(ff.contentType)
}
ctx.SetStatusCode(statusCode)
}
type byteRangeUpdater interface {
UpdateByteRange(startPos, endPos int) error
}
// ParseByteRange parses 'Range: bytes=...' header value.
//
// It follows https://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html#sec14.35 .
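//
// For example (sketch):
//
//	start, end, err := ParseByteRange([]byte("bytes=0-499"), 10000)
//	// start == 0, end == 499 on success
//
//	start, end, err = ParseByteRange([]byte("bytes=-500"), 10000)
//	// start == 9500, end == 9999 (the final 500 bytes)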
func ParseByteRange(byteRange []byte, contentLength int) (startPos, endPos int, err error) {
b := byteRange
if !bytes.HasPrefix(b, strBytes) {
return 0, 0, fmt.Errorf("unsupported range units: %q. Expecting %q", byteRange, strBytes)
}
b = b[len(strBytes):]
if len(b) == 0 || b[0] != '=' {
return 0, 0, fmt.Errorf("missing byte range in %q", byteRange)
}
b = b[1:]
n := bytes.IndexByte(b, '-')
if n < 0 {
return 0, 0, fmt.Errorf("missing the end position of byte range in %q", byteRange)
}
if n == 0 {
v, err := ParseUint(b[n+1:])
if err != nil {
return 0, 0, err
}
startPos := contentLength - v
if startPos < 0 {
startPos = 0
}
return startPos, contentLength - 1, nil
}
if startPos, err = ParseUint(b[:n]); err != nil {
return 0, 0, err
}
if startPos >= contentLength {
return 0, 0, fmt.Errorf("the start position of byte range cannot exceed %d. byte range %q", contentLength-1, byteRange)
}
b = b[n+1:]
if len(b) == 0 {
return startPos, contentLength - 1, nil
}
if endPos, err = ParseUint(b); err != nil {
return 0, 0, err
}
if endPos >= contentLength {
endPos = contentLength - 1
}
if endPos < startPos {
return 0, 0, fmt.Errorf("the start position of byte range cannot exceed the end position. byte range %q", byteRange)
}
return startPos, endPos, nil
}
func (h *fsHandler) openIndexFile(ctx *RequestCtx, dirPath string, mustCompress bool, fileEncoding string) (*fsFile, error) {
for _, indexName := range h.indexNames {
indexFilePath := indexName
if dirPath != "" {
indexFilePath = dirPath + "/" + indexName
}
ff, err := h.openFSFile(indexFilePath, mustCompress, fileEncoding)
if err == nil {
return ff, nil
}
if mustCompress && err == errNoCreatePermission {
ctx.Logger().Printf("insufficient permissions for saving compressed file for %q. Serving uncompressed file. "+
"Allow write access to the directory with this file in order to improve fasthttp performance", indexFilePath)
mustCompress = false
return h.openFSFile(indexFilePath, mustCompress, fileEncoding)
}
if !errors.Is(err, fs.ErrNotExist) {
return nil, fmt.Errorf("cannot open file %q: %w", indexFilePath, err)
}
}
if !h.generateIndexPages {
return nil, fmt.Errorf("cannot access directory without index page. Directory %q", dirPath)
}
return h.createDirIndex(ctx, dirPath, mustCompress, fileEncoding)
}
var (
errDirIndexRequired = errors.New("directory index required")
errNoCreatePermission = errors.New("no 'create file' permissions")
)
func (h *fsHandler) createDirIndex(ctx *RequestCtx, dirPath string, mustCompress bool, fileEncoding string) (*fsFile, error) {
w := &bytebufferpool.ByteBuffer{}
base := ctx.URI()
// io/fs doesn't support ReadDir with empty path.
if dirPath == "" {
dirPath = "."
}
basePathEscaped := html.EscapeString(string(base.Path()))
_, _ = fmt.Fprintf(w, "<html><head><title>%s</title><style>.dir { font-weight: bold }</style></head><body>", basePathEscaped)
_, _ = fmt.Fprintf(w, "<h1>%s</h1>", basePathEscaped)
_, _ = fmt.Fprintf(w, "<ul>")
if len(basePathEscaped) > 1 {
var parentURI URI
base.CopyTo(&parentURI)
parentURI.Update(string(base.Path()) + "/..")
parentPathEscaped := html.EscapeString(string(parentURI.Path()))
_, _ = fmt.Fprintf(w, `<li><a href="%s" class="dir">..</a></li>`, parentPathEscaped)
}
dirEntries, err := fs.ReadDir(h.filesystem, dirPath)
if err != nil {
return nil, err
}
fm := make(map[string]fs.FileInfo, len(dirEntries))
filenames := make([]string, 0, len(dirEntries))
nestedContinue:
for _, de := range dirEntries {
name := de.Name()
for _, cfs := range h.compressedFileSuffixes {
if strings.HasSuffix(name, cfs) {
// Do not show compressed files on index page.
continue nestedContinue
}
}
fi, err := de.Info()
if err != nil {
ctx.Logger().Printf("cannot fetch information from dir entry %q: %v, skip", name, err)
continue nestedContinue
}
fm[name] = fi
filenames = append(filenames, name)
}
var u URI
base.CopyTo(&u)
u.Update(string(u.Path()) + "/")
sort.Strings(filenames)
for _, name := range filenames {
u.Update(name)
pathEscaped := html.EscapeString(string(u.Path()))
fi := fm[name]
auxStr := "dir"
className := "dir"
if !fi.IsDir() {
auxStr = fmt.Sprintf("file, %d bytes", fi.Size())
className = "file"
}
_, _ = fmt.Fprintf(w, `<li><a href="%s" class="%s">%s</a>, %s, last modified %s</li>`,
pathEscaped, className, html.EscapeString(name), auxStr, fsModTime(fi.ModTime()))
}
_, _ = fmt.Fprintf(w, "</ul></body></html>")
if mustCompress {
var zbuf bytebufferpool.ByteBuffer
switch fileEncoding {
case "br":
zbuf.B = AppendBrotliBytesLevel(zbuf.B, w.B, CompressDefaultCompression)
case "gzip":
zbuf.B = AppendGzipBytesLevel(zbuf.B, w.B, CompressDefaultCompression)
case "zstd":
zbuf.B = AppendZstdBytesLevel(zbuf.B, w.B, CompressZstdDefault)
}
w = &zbuf
}
dirIndex := w.B
lastModified := time.Now()
ff := &fsFile{
h: h,
dirIndex: dirIndex,
contentType: "text/html; charset=utf-8",
contentLength: len(dirIndex),
compressed: mustCompress,
lastModified: lastModified,
lastModifiedStr: AppendHTTPDate(nil, lastModified),
t: lastModified,
}
return ff, nil
}
const (
fsMinCompressRatio = 0.8
fsMaxCompressibleFileSize = 8 * 1024 * 1024
)
func (h *fsHandler) compressAndOpenFSFile(filePath, fileEncoding string) (*fsFile, error) {
f, err := h.filesystem.Open(filePath)
if err != nil {
return nil, err
}
fileInfo, err := f.Stat()
if err != nil {
_ = f.Close()
return nil, fmt.Errorf("cannot obtain info for file %q: %w", filePath, err)
}
if fileInfo.IsDir() {
_ = f.Close()
return nil, errDirIndexRequired
}
if strings.HasSuffix(filePath, h.compressedFileSuffixes[fileEncoding]) ||
fileInfo.Size() > fsMaxCompressibleFileSize ||
!isFileCompressible(f, fsMinCompressRatio) {
return h.newFSFile(f, fileInfo, false, filePath, "")
}
compressedFilePath := h.filePathToCompressed(filePath)
if _, ok := h.filesystem.(*osFS); !ok {
return h.newCompressedFSFileCache(f, fileInfo, compressedFilePath, fileEncoding)
}
if compressedFilePath != filePath {
if err := os.MkdirAll(filepath.Dir(compressedFilePath), 0o750); err != nil {
return nil, err
}
}
compressedFilePath += h.compressedFileSuffixes[fileEncoding]
absPath, err := filepath.Abs(compressedFilePath)
if err != nil {
_ = f.Close()
return nil, fmt.Errorf("cannot determine absolute path for %q: %v", compressedFilePath, err)
}
flock := getFileLock(absPath)
flock.Lock()
ff, err := h.compressFileNolock(f, fileInfo, filePath, compressedFilePath, fileEncoding)
flock.Unlock()
return ff, err
}
func (h *fsHandler) compressFileNolock(
f fs.File, fileInfo fs.FileInfo, filePath, compressedFilePath, fileEncoding string,
) (*fsFile, error) {
// Attempt to open a compressed file created by another concurrent
// goroutine.
// It is safe to open such a file, since file creation
// is guarded by a file mutex - see the getFileLock call.
if _, err := os.Stat(compressedFilePath); err == nil {
_ = f.Close()
return h.newCompressedFSFile(compressedFilePath, fileEncoding)
}
// Write to a temporary file first, so concurrent goroutines
// don't see the compressed file until it is fully written.
tmpFilePath := compressedFilePath + ".tmp"
zf, err := os.Create(tmpFilePath)
if err != nil {
_ = f.Close()
if !errors.Is(err, fs.ErrPermission) {
return nil, fmt.Errorf("cannot create temporary file %q: %w", tmpFilePath, err)
}
return nil, errNoCreatePermission
}
switch fileEncoding {
case "br":
zw := acquireStacklessBrotliWriter(zf, CompressDefaultCompression)
_, err = copyZeroAlloc(zw, f)
if errf := zw.Flush(); err == nil {
err = errf
}
releaseStacklessBrotliWriter(zw, CompressDefaultCompression)
case "gzip":
zw := acquireStacklessGzipWriter(zf, CompressDefaultCompression)
_, err = copyZeroAlloc(zw, f)
if errf := zw.Flush(); err == nil {
err = errf
}
releaseStacklessGzipWriter(zw, CompressDefaultCompression)
case "zstd":
zw := acquireStacklessZstdWriter(zf, CompressZstdDefault)
_, err = copyZeroAlloc(zw, f)
if errf := zw.Flush(); err == nil {
err = errf
}
releaseStacklessZstdWriter(zw, CompressZstdDefault)
}
_ = zf.Close()
_ = f.Close()
if err != nil {
return nil, fmt.Errorf("error when compressing file %q to %q: %w", filePath, tmpFilePath, err)
}
if err = os.Chtimes(tmpFilePath, time.Now(), fileInfo.ModTime()); err != nil {
return nil, fmt.Errorf("cannot change modification time to %v for tmp file %q: %v",
fileInfo.ModTime(), tmpFilePath, err)
}
if err = os.Rename(tmpFilePath, compressedFilePath); err != nil {
return nil, fmt.Errorf("cannot move compressed file from %q to %q: %w", tmpFilePath, compressedFilePath, err)
}
return h.newCompressedFSFile(compressedFilePath, fileEncoding)
}
// newCompressedFSFileCache compresses f into an in-memory cache and returns an fsFile serving the compressed contents.
func (h *fsHandler) newCompressedFSFileCache(f fs.File, fileInfo fs.FileInfo, filePath, fileEncoding string) (*fsFile, error) {
var (
w = &bytebufferpool.ByteBuffer{}
err error
)
switch fileEncoding {
case "br":
zw := acquireStacklessBrotliWriter(w, CompressDefaultCompression)
_, err = copyZeroAlloc(zw, f)
if errf := zw.Flush(); err == nil {
err = errf
}
releaseStacklessBrotliWriter(zw, CompressDefaultCompression)
case "gzip":
zw := acquireStacklessGzipWriter(w, CompressDefaultCompression)
_, err = copyZeroAlloc(zw, f)
if errf := zw.Flush(); err == nil {
err = errf
}
releaseStacklessGzipWriter(zw, CompressDefaultCompression)
case "zstd":
zw := acquireStacklessZstdWriter(w, CompressZstdDefault)
_, err = copyZeroAlloc(zw, f)
if errf := zw.Flush(); err == nil {
err = errf
}
releaseStacklessZstdWriter(zw, CompressZstdDefault)
}
defer func() { _ = f.Close() }()
if err != nil {
return nil, fmt.Errorf("error when compressing file %q: %w", filePath, err)
}
seeker, ok := f.(io.Seeker)
if !ok {
return nil, errors.New("not implemented io.Seeker")
}
if _, err = seeker.Seek(0, io.SeekStart); err != nil {
return nil, err
}
ext := fileExtension(fileInfo.Name(), false, h.compressedFileSuffixes[fileEncoding])
contentType := mime.TypeByExtension(ext)
if contentType == "" {
data, err := readFileHeader(f, false, fileEncoding)
if err != nil {
return nil, fmt.Errorf("cannot read header of the file %q: %w", fileInfo.Name(), err)
}
contentType = http.DetectContentType(data)
}
dirIndex := w.B
lastModified := fileInfo.ModTime()
ff := &fsFile{
h: h,
dirIndex: dirIndex,
contentType: contentType,
contentLength: len(dirIndex),
compressed: true,
lastModified: lastModified,
lastModifiedStr: AppendHTTPDate(nil, lastModified),
t: time.Now(),
}
return ff, nil
}
func (h *fsHandler) newCompressedFSFile(filePath, fileEncoding string) (*fsFile, error) {
f, err := h.filesystem.Open(filePath)
if err != nil {
return nil, fmt.Errorf("cannot open compressed file %q: %w", filePath, err)
}
fileInfo, err := f.Stat()
if err != nil {
_ = f.Close()
return nil, fmt.Errorf("cannot obtain info for compressed file %q: %w", filePath, err)
}
return h.newFSFile(f, fileInfo, true, filePath, fileEncoding)
}
func (h *fsHandler) openFSFile(filePath string, mustCompress bool, fileEncoding string) (*fsFile, error) {
filePathOriginal := filePath
if mustCompress {
filePath += h.compressedFileSuffixes[fileEncoding]
}
f, err := h.filesystem.Open(filePath)
if err != nil {
if mustCompress && errors.Is(err, fs.ErrNotExist) {
return h.compressAndOpenFSFile(filePathOriginal, fileEncoding)
}
// If the file is not found and the path is empty, return errDirIndexRequired.
if filePath == "" && (errors.Is(err, fs.ErrNotExist) || errors.Is(err, fs.ErrInvalid)) {
return nil, errDirIndexRequired
}
return nil, err
}
fileInfo, err := f.Stat()
if err != nil {
_ = f.Close()
return nil, fmt.Errorf("cannot obtain info for file %q: %w", filePath, err)
}
if fileInfo.IsDir() {
_ = f.Close()
if mustCompress {
return nil, fmt.Errorf("directory with unexpected suffix found: %q. Suffix: %q",
filePath, h.compressedFileSuffixes[fileEncoding])
}
return nil, errDirIndexRequired
}
if mustCompress {
fileInfoOriginal, err := fs.Stat(h.filesystem, filePathOriginal)
if err != nil {
_ = f.Close()
return nil, fmt.Errorf("cannot obtain info for original file %q: %w", filePathOriginal, err)
}
// Only re-create the compressed file if there was more than a second between the mod times.
// On macOS, gzip seems to truncate the nanoseconds in the mod time, causing the original file
// to look newer than the gzipped file.
if fileInfoOriginal.ModTime().Sub(fileInfo.ModTime()) >= time.Second {
// The compressed file became stale. Re-create it.
_ = f.Close()
_ = os.Remove(filePath)
return h.compressAndOpenFSFile(filePathOriginal, fileEncoding)
}
}
return h.newFSFile(f, fileInfo, mustCompress, filePath, fileEncoding)
}
func (h *fsHandler) newFSFile(f fs.File, fileInfo fs.FileInfo, compressed bool, filePath, fileEncoding string) (*fsFile, error) {
n := fileInfo.Size()
contentLength := int(n)
if n != int64(contentLength) {
_ = f.Close()
return nil, fmt.Errorf("too big file: %d bytes", n)
}
// detect content-type
ext := fileExtension(fileInfo.Name(), compressed, h.compressedFileSuffixes[fileEncoding])
contentType := mime.TypeByExtension(ext)
if contentType == "" {
data, err := readFileHeader(f, compressed, fileEncoding)
if err != nil {
return nil, fmt.Errorf("cannot read header of the file %q: %w", fileInfo.Name(), err)
}
contentType = http.DetectContentType(data)
}
lastModified := fileInfo.ModTime()
ff := &fsFile{
h: h,
f: f,
filename: filePath,
contentType: contentType,
contentLength: contentLength,
compressed: compressed,
lastModified: lastModified,
lastModifiedStr: AppendHTTPDate(nil, lastModified),
t: time.Now(),
}
return ff, nil
}
func readFileHeader(f io.Reader, compressed bool, fileEncoding string) ([]byte, error) {
r := f
var (
br *brotli.Reader
zr *gzip.Reader
zsr *zstd.Decoder
)
if compressed {
var err error
switch fileEncoding {
case "br":
if br, err = acquireBrotliReader(f); err != nil {
return nil, err
}
r = br
case "gzip":
if zr, err = acquireGzipReader(f); err != nil {
return nil, err
}
r = zr
case "zstd":
if zsr, err = acquireZstdReader(f); err != nil {
return nil, err
}
r = zsr
}
}
lr := &io.LimitedReader{
R: r,
N: 512,
}
data, err := io.ReadAll(lr)
seeker, ok := f.(io.Seeker)
if !ok {
return nil, errors.New("must implement io.Seeker")
}
if _, err := seeker.Seek(0, io.SeekStart); err != nil {
return nil, err
}
if br != nil {
releaseBrotliReader(br)
}
if zr != nil {
releaseGzipReader(zr)
}
if zsr != nil {
releaseZstdReader(zsr)
}
return data, err
}
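// stripLeadingSlashes strips up to stripSlashes leading path segments
// (each starting with '/') from path. For example, stripping one segment
// from "/foo/bar" yields "/bar". If path contains fewer segments than
// stripSlashes, an empty path is returned.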
func stripLeadingSlashes(path []byte, stripSlashes int) []byte {
for stripSlashes > 0 && len(path) > 0 {
if path[0] != '/' {
// developer sanity-check
panic("BUG: path must start with slash")
}
n := bytes.IndexByte(path[1:], '/')
if n < 0 {
path = path[:0]
break
}
path = path[n+1:]
stripSlashes--
}
return path
}
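// stripTrailingSlashes removes all trailing '/' bytes from path.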
func stripTrailingSlashes(path []byte) []byte {
for len(path) > 0 && path[len(path)-1] == '/' {
path = path[:len(path)-1]
}
return path
}
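// fileExtension returns the extension of path, including the leading dot.
// When compressed is true, compressedFileSuffix is stripped first, so that
// e.g. "index.html.gz" with suffix ".gz" yields ".html".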
func fileExtension(path string, compressed bool, compressedFileSuffix string) string {
if compressed && strings.HasSuffix(path, compressedFileSuffix) {
path = path[:len(path)-len(compressedFileSuffix)]
}
n := strings.LastIndexByte(path, '.')
if n < 0 {
return ""
}
return path[n:]
}
// FileLastModified returns last modified time for the file.
func FileLastModified(path string) (time.Time, error) {
f, err := os.Open(path)
if err != nil {
return zeroTime, err
}
fileInfo, err := f.Stat()
_ = f.Close()
if err != nil {
return zeroTime, err
}
return fsModTime(fileInfo.ModTime()), nil
}
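// fsModTime returns t converted to UTC and truncated to second precision,
// matching the resolution of HTTP date headers.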
func fsModTime(t time.Time) time.Time {
return t.In(time.UTC).Truncate(time.Second)
}
var filesLockMap sync.Map
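// getFileLock returns the mutex associated with absPath, creating it on
// first use. It serializes concurrent creation of the same compressed file.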
func getFileLock(absPath string) *sync.Mutex {
v, _ := filesLockMap.LoadOrStore(absPath, &sync.Mutex{})
filelock := v.(*sync.Mutex)
return filelock
}
var _ fs.FS = (*osFS)(nil)
type osFS struct{}
func (o *osFS) Open(name string) (fs.File, error) { return os.Open(name) }
func (o *osFS) Stat(name string) (fs.FileInfo, error) { return os.Stat(name) }
package fasthttp
import (
"bufio"
"bytes"
"errors"
"fmt"
"io"
"iter"
"log/slog"
"sync"
"sync/atomic"
"time"
)
const (
rChar = byte('\r')
nChar = byte('\n')
)
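// header holds the fields shared by RequestHeader and ResponseHeader.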
type header struct {
h []argsKV
cookies []argsKV
bufK []byte
bufV []byte
contentLengthBytes []byte
contentType []byte
protocol []byte
mulHeader [][]byte
trailer [][]byte
contentLength int
disableNormalizing bool
secureErrorLogMessage bool
noHTTP11 bool
connectionClose bool
noDefaultContentType bool
}
// ResponseHeader represents HTTP response header.
//
// It is forbidden copying ResponseHeader instances.
// Create new instances instead and use CopyTo.
//
// ResponseHeader instance MUST NOT be used from concurrently running
// goroutines.
type ResponseHeader struct {
noCopy noCopy
header
statusMessage []byte
contentEncoding []byte
server []byte
statusCode int
noDefaultDate bool
}
// RequestHeader represents HTTP request header.
//
// It is forbidden copying RequestHeader instances.
// Create new instances instead and use CopyTo.
//
// RequestHeader instance MUST NOT be used from concurrently running
// goroutines.
type RequestHeader struct {
noCopy noCopy
header
method []byte
requestURI []byte
host []byte
userAgent []byte
// stores an immutable copy of headers as they were received from the
// wire.
rawHeaders []byte
disableSpecialHeader bool
cookiesCollected bool
}
// SetContentRange sets 'Content-Range: bytes startPos-endPos/contentLength'
// header.
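//
// For example (illustrative values):
//
// h.SetContentRange(0, 99, 1000) // Content-Range: bytes 0-99/1000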
func (h *ResponseHeader) SetContentRange(startPos, endPos, contentLength int) {
b := h.bufV[:0]
b = append(b, strBytes...)
b = append(b, ' ')
b = AppendUint(b, startPos)
b = append(b, '-')
b = AppendUint(b, endPos)
b = append(b, '/')
b = AppendUint(b, contentLength)
h.bufV = b
h.setNonSpecial(strContentRange, h.bufV)
}
// SetByteRange sets 'Range: bytes=startPos-endPos' header.
//
// - If startPos is negative, then 'bytes=-startPos' value is set.
// - If endPos is negative, then 'bytes=startPos-' value is set.
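//
// For example (illustrative values):
//
// h.SetByteRange(100, 199) // Range: bytes=100-199
// h.SetByteRange(100, -1)  // Range: bytes=100-
// h.SetByteRange(-500, 0)  // Range: bytes=-500 (the last 500 bytes)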
func (h *RequestHeader) SetByteRange(startPos, endPos int) {
b := h.bufV[:0]
b = append(b, strBytes...)
b = append(b, '=')
if startPos >= 0 {
b = AppendUint(b, startPos)
} else {
endPos = -startPos
}
b = append(b, '-')
if endPos >= 0 {
b = AppendUint(b, endPos)
}
h.bufV = b
h.setNonSpecial(strRange, h.bufV)
}
// StatusCode returns response status code.
func (h *ResponseHeader) StatusCode() int {
if h.statusCode == 0 {
return StatusOK
}
return h.statusCode
}
// SetStatusCode sets response status code.
func (h *ResponseHeader) SetStatusCode(statusCode int) {
h.statusCode = statusCode
}
// StatusMessage returns response status message.
func (h *ResponseHeader) StatusMessage() []byte {
return h.statusMessage
}
// SetStatusMessage sets response status message bytes.
func (h *ResponseHeader) SetStatusMessage(statusMessage []byte) {
h.statusMessage = append(h.statusMessage[:0], statusMessage...)
}
// SetProtocol sets response protocol bytes.
func (h *ResponseHeader) SetProtocol(protocol []byte) {
h.protocol = append(h.protocol[:0], protocol...)
}
// SetLastModified sets 'Last-Modified' header to the given value.
func (h *ResponseHeader) SetLastModified(t time.Time) {
h.bufV = AppendHTTPDate(h.bufV[:0], t)
h.setNonSpecial(strLastModified, h.bufV)
}
// ConnectionClose returns true if 'Connection: close' header is set.
func (h *header) ConnectionClose() bool {
return h.connectionClose
}
// SetConnectionClose sets 'Connection: close' header.
func (h *header) SetConnectionClose() {
h.connectionClose = true
}
// ResetConnectionClose clears 'Connection: close' header if it exists.
func (h *header) ResetConnectionClose() {
if h.connectionClose {
h.connectionClose = false
h.h = delAllArgs(h.h, HeaderConnection)
}
}
// ConnectionUpgrade returns true if 'Connection: Upgrade' header is set.
func (h *ResponseHeader) ConnectionUpgrade() bool {
return hasHeaderValue(h.Peek(HeaderConnection), strUpgrade)
}
// ConnectionUpgrade returns true if 'Connection: Upgrade' header is set.
func (h *RequestHeader) ConnectionUpgrade() bool {
return hasHeaderValue(h.Peek(HeaderConnection), strUpgrade)
}
// PeekCookie returns the response cookie value for the given key.
func (h *ResponseHeader) PeekCookie(key string) []byte {
return peekArgStr(h.cookies, key)
}
// ContentLength returns Content-Length header value.
//
// It may be negative:
// -1 means Transfer-Encoding: chunked.
// -2 means Transfer-Encoding: identity.
func (h *header) ContentLength() int {
return h.contentLength
}
// SetContentLength sets Content-Length header value.
//
// Content-Length may be negative:
// -1 means Transfer-Encoding: chunked.
// -2 means Transfer-Encoding: identity.
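//
// For example (illustrative values):
//
// h.SetContentLength(1024) // Content-Length: 1024
// h.SetContentLength(-1)   // Transfer-Encoding: chunked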
func (h *ResponseHeader) SetContentLength(contentLength int) {
if h.mustSkipContentLength() {
return
}
h.contentLength = contentLength
if contentLength >= 0 {
h.contentLengthBytes = AppendUint(h.contentLengthBytes[:0], contentLength)
h.h = delAllArgs(h.h, HeaderTransferEncoding)
return
} else if contentLength == -1 {
h.contentLengthBytes = h.contentLengthBytes[:0]
h.h = setArgBytes(h.h, strTransferEncoding, strChunked, argsHasValue)
return
}
h.SetConnectionClose()
}
func (h *ResponseHeader) mustSkipContentLength() bool {
// From http/1.1 specs:
// All 1xx (informational), 204 (no content), and 304 (not modified) responses MUST NOT include a message-body
statusCode := h.StatusCode()
// Fast path.
if statusCode < 100 || statusCode == StatusOK {
return false
}
// Slow path.
return statusCode == StatusNotModified || statusCode == StatusNoContent || statusCode < 200
}
// SetContentLength sets Content-Length header value.
//
// Negative content-length sets 'Transfer-Encoding: chunked' header.
func (h *RequestHeader) SetContentLength(contentLength int) {
h.contentLength = contentLength
if contentLength >= 0 {
h.contentLengthBytes = AppendUint(h.contentLengthBytes[:0], contentLength)
h.h = delAllArgs(h.h, HeaderTransferEncoding)
} else {
h.contentLengthBytes = h.contentLengthBytes[:0]
h.h = setArgBytes(h.h, strTransferEncoding, strChunked, argsHasValue)
}
}
func (h *ResponseHeader) isCompressibleContentType() bool {
contentType := h.ContentType()
return bytes.HasPrefix(contentType, strTextSlash) ||
bytes.HasPrefix(contentType, strApplicationSlash) ||
bytes.HasPrefix(contentType, strImageSVG) ||
bytes.HasPrefix(contentType, strImageIcon) ||
bytes.HasPrefix(contentType, strFontSlash) ||
bytes.HasPrefix(contentType, strMultipartSlash)
}
// ContentType returns Content-Type header value.
func (h *ResponseHeader) ContentType() []byte {
contentType := h.contentType
if !h.noDefaultContentType && len(h.contentType) == 0 {
contentType = defaultContentType
}
return contentType
}
// SetContentType sets Content-Type header value.
func (h *header) SetContentType(contentType string) {
h.contentType = append(h.contentType[:0], contentType...)
}
// SetContentTypeBytes sets Content-Type header value.
func (h *header) SetContentTypeBytes(contentType []byte) {
h.contentType = append(h.contentType[:0], contentType...)
}
// ContentEncoding returns Content-Encoding header value.
func (h *ResponseHeader) ContentEncoding() []byte {
return h.contentEncoding
}
// SetContentEncoding sets Content-Encoding header value.
func (h *ResponseHeader) SetContentEncoding(contentEncoding string) {
h.contentEncoding = append(h.contentEncoding[:0], contentEncoding...)
}
// SetContentEncodingBytes sets Content-Encoding header value.
func (h *ResponseHeader) SetContentEncodingBytes(contentEncoding []byte) {
h.contentEncoding = append(h.contentEncoding[:0], contentEncoding...)
}
// addVaryBytes adds value to the 'Vary' header if it is not already included.
func (h *ResponseHeader) addVaryBytes(value []byte) {
v := h.peek(strVary)
if len(v) == 0 {
// 'Vary' is not set
h.SetBytesV(HeaderVary, value)
} else if !bytes.Contains(v, value) {
// 'Vary' is set and does not contain the target value
h.SetBytesV(HeaderVary, append(append(v, ','), value...))
} // else: 'Vary' is set and contains target value
}
// Server returns Server header value.
func (h *ResponseHeader) Server() []byte {
return h.server
}
// SetServer sets Server header value.
func (h *ResponseHeader) SetServer(server string) {
h.server = append(h.server[:0], server...)
}
// SetServerBytes sets Server header value.
func (h *ResponseHeader) SetServerBytes(server []byte) {
h.server = append(h.server[:0], server...)
}
// ContentType returns Content-Type header value.
func (h *RequestHeader) ContentType() []byte {
if h.disableSpecialHeader {
return peekArgBytes(h.h, []byte(HeaderContentType))
}
return h.contentType
}
// ContentEncoding returns Content-Encoding header value.
func (h *RequestHeader) ContentEncoding() []byte {
return peekArgBytes(h.h, strContentEncoding)
}
// SetContentEncoding sets Content-Encoding header value.
func (h *RequestHeader) SetContentEncoding(contentEncoding string) {
h.SetBytesK(strContentEncoding, contentEncoding)
}
// SetContentEncodingBytes sets Content-Encoding header value.
func (h *RequestHeader) SetContentEncodingBytes(contentEncoding []byte) {
h.setNonSpecial(strContentEncoding, contentEncoding)
}
// SetMultipartFormBoundary sets the following Content-Type:
// 'multipart/form-data; boundary=...'
// where ... is substituted by the given boundary.
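//
// For example (illustrative boundary):
//
// h.SetMultipartFormBoundary("foo") // Content-Type: multipart/form-data; boundary=foo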
func (h *RequestHeader) SetMultipartFormBoundary(boundary string) {
b := h.bufV[:0]
b = append(b, strMultipartFormData...)
b = append(b, ';', ' ')
b = append(b, strBoundary...)
b = append(b, '=')
b = append(b, boundary...)
h.bufV = b
h.SetContentTypeBytes(h.bufV)
}
// SetMultipartFormBoundaryBytes sets the following Content-Type:
// 'multipart/form-data; boundary=...'
// where ... is substituted by the given boundary.
func (h *RequestHeader) SetMultipartFormBoundaryBytes(boundary []byte) {
b := h.bufV[:0]
b = append(b, strMultipartFormData...)
b = append(b, ';', ' ')
b = append(b, strBoundary...)
b = append(b, '=')
b = append(b, boundary...)
h.bufV = b
h.SetContentTypeBytes(h.bufV)
}
// SetTrailer sets the Trailer header value for a chunked response
// to indicate which headers will be sent after the body.
//
// Use Set to set the trailer header later.
//
// Trailers are only supported with chunked transfer.
// Trailers allow the sender to include additional headers at the end of chunked messages.
//
// The following trailers are forbidden:
// 1. necessary for message framing (e.g., Transfer-Encoding and Content-Length),
// 2. routing (e.g., Host),
// 3. request modifiers (e.g., controls and conditionals in Section 5 of [RFC7231]),
// 4. authentication (e.g., see [RFC7235] and [RFC6265]),
// 5. response control data (e.g., see Section 7.1 of [RFC7231]),
// 6. determining how to process the payload (e.g., Content-Encoding, Content-Type, Content-Range, and Trailer)
//
// Returns ErrBadTrailer if the trailer contains any forbidden headers.
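//
// For example (the trailer names are illustrative):
//
// if err := h.SetTrailer("Expires, X-Checksum"); err != nil {
//     // at least one forbidden trailer name was skipped
// }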
func (h *header) SetTrailer(trailer string) error {
return h.SetTrailerBytes(s2b(trailer))
}
// SetTrailerBytes sets the Trailer header value for a chunked response
// to indicate which headers will be sent after the body.
//
// Use Set to set the trailer header later.
//
// Trailers are only supported with chunked transfer.
// Trailers allow the sender to include additional headers at the end of chunked messages.
//
// The following trailers are forbidden:
// 1. necessary for message framing (e.g., Transfer-Encoding and Content-Length),
// 2. routing (e.g., Host),
// 3. request modifiers (e.g., controls and conditionals in Section 5 of [RFC7231]),
// 4. authentication (e.g., see [RFC7235] and [RFC6265]),
// 5. response control data (e.g., see Section 7.1 of [RFC7231]),
// 6. determining how to process the payload (e.g., Content-Encoding, Content-Type, Content-Range, and Trailer)
//
// Returns ErrBadTrailer if the trailer contains any forbidden headers.
func (h *header) SetTrailerBytes(trailer []byte) error {
h.trailer = h.trailer[:0]
return h.AddTrailerBytes(trailer)
}
// AddTrailer adds a Trailer header value for a chunked response
// to indicate which headers will be sent after the body.
//
// Use Set to set the trailer header later.
//
// Trailers are only supported with chunked transfer.
// Trailers allow the sender to include additional headers at the end of chunked messages.
//
// The following trailers are forbidden:
// 1. necessary for message framing (e.g., Transfer-Encoding and Content-Length),
// 2. routing (e.g., Host),
// 3. request modifiers (e.g., controls and conditionals in Section 5 of [RFC7231]),
// 4. authentication (e.g., see [RFC7235] and [RFC6265]),
// 5. response control data (e.g., see Section 7.1 of [RFC7231]),
// 6. determining how to process the payload (e.g., Content-Encoding, Content-Type, Content-Range, and Trailer)
//
// Returns ErrBadTrailer if the trailer contains any forbidden headers.
func (h *header) AddTrailer(trailer string) error {
return h.AddTrailerBytes(s2b(trailer))
}
var ErrBadTrailer = errors.New("contains forbidden trailer")
// AddTrailerBytes adds a Trailer header value for a chunked response
// to indicate which headers will be sent after the body.
//
// Use Set to set the trailer header later.
//
// Trailers are only supported with chunked transfer.
// Trailers allow the sender to include additional headers at the end of chunked messages.
//
// The following trailers are forbidden:
// 1. necessary for message framing (e.g., Transfer-Encoding and Content-Length),
// 2. routing (e.g., Host),
// 3. request modifiers (e.g., controls and conditionals in Section 5 of [RFC7231]),
// 4. authentication (e.g., see [RFC7235] and [RFC6265]),
// 5. response control data (e.g., see Section 7.1 of [RFC7231]),
// 6. determining how to process the payload (e.g., Content-Encoding, Content-Type, Content-Range, and Trailer)
//
// Returns ErrBadTrailer if the trailer contains any forbidden headers.
func (h *header) AddTrailerBytes(trailer []byte) (err error) {
for i := -1; i+1 < len(trailer); {
trailer = trailer[i+1:]
i = bytes.IndexByte(trailer, ',')
if i < 0 {
i = len(trailer)
}
key := trailer[:i]
for len(key) > 0 && key[0] == ' ' {
key = key[1:]
}
for len(key) > 0 && key[len(key)-1] == ' ' {
key = key[:len(key)-1]
}
// Forbidden by RFC 7230, section 4.1.2
if isBadTrailer(key) {
err = ErrBadTrailer
continue
}
h.bufK = append(h.bufK[:0], key...)
normalizeHeaderKey(h.bufK, h.disableNormalizing || bytes.IndexByte(h.bufK, ' ') != -1)
if cap(h.trailer) > len(h.trailer) {
h.trailer = h.trailer[:len(h.trailer)+1]
h.trailer[len(h.trailer)-1] = append(h.trailer[len(h.trailer)-1][:0], h.bufK...)
} else {
key = make([]byte, len(h.bufK))
copy(key, h.bufK)
h.trailer = append(h.trailer, key)
}
}
return err
}
// validHeaderFieldByte returns true if c is a valid header field byte
// as defined by RFC 7230.
func validHeaderFieldByte(c byte) bool {
return c < 128 && validHeaderFieldByteTable[c] == 1
}
// validHeaderValueByte returns true if c is a valid header value byte
// as defined by RFC 7230.
func validHeaderValueByte(c byte) bool {
return validHeaderValueByteTable[c] == 1
}
// VisitHeaderParams calls f for each parameter in the given header bytes.
// It stops processing when f returns false or an invalid parameter is found.
// Parameter values may be quoted, in which case \ is treated as an escape
// character, and the value is unquoted before being passed to f.
// See: https://www.rfc-editor.org/rfc/rfc9110#section-5.6.6
//
// f must not retain references to key and/or value after returning.
// Copy key and/or value contents before returning if you need to retain them.
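//
// For example (illustrative header value):
//
// ct := []byte(`text/html; charset="utf-8"; q=0.9`)
// VisitHeaderParams(ct, func(key, value []byte) bool {
//     // called with ("charset", "utf-8"), then ("q", "0.9")
//     return true // return false to stop early
// })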
func VisitHeaderParams(b []byte, f func(key, value []byte) bool) {
for len(b) > 0 {
idxSemi := 0
for idxSemi < len(b) && b[idxSemi] != ';' {
idxSemi++
}
if idxSemi >= len(b) {
return
}
b = b[idxSemi+1:]
for len(b) > 0 && b[0] == ' ' {
b = b[1:]
}
n := 0
if len(b) == 0 || !validHeaderFieldByte(b[n]) {
return
}
n++
for n < len(b) && validHeaderFieldByte(b[n]) {
n++
}
if n >= len(b)-1 || b[n] != '=' {
return
}
param := b[:n]
n++
switch {
case validHeaderFieldByte(b[n]):
m := n
n++
for n < len(b) && validHeaderFieldByte(b[n]) {
n++
}
if !f(param, b[m:n]) {
return
}
case b[n] == '"':
foundEndQuote := false
escaping := false
n++
m := n
for ; n < len(b); n++ {
if b[n] == '"' && !escaping {
foundEndQuote = true
break
}
escaping = (b[n] == '\\' && !escaping)
}
if !foundEndQuote {
return
}
if !f(param, b[m:n]) {
return
}
n++
default:
return
}
b = b[n:]
}
}
// MultipartFormBoundary returns boundary part
// from 'multipart/form-data; boundary=...' Content-Type.
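//
// For example, for the Content-Type value
// 'multipart/form-data; boundary=foo' it returns 'foo'.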
func (h *RequestHeader) MultipartFormBoundary() []byte {
b := h.ContentType()
if !bytes.HasPrefix(b, strMultipartFormData) {
return nil
}
b = b[len(strMultipartFormData):]
if len(b) == 0 || b[0] != ';' {
return nil
}
var n int
for len(b) > 0 {
n++
for len(b) > n && b[n] == ' ' {
n++
}
b = b[n:]
if !bytes.HasPrefix(b, strBoundary) {
if n = bytes.IndexByte(b, ';'); n < 0 {
return nil
}
continue
}
b = b[len(strBoundary):]
if len(b) == 0 || b[0] != '=' {
return nil
}
b = b[1:]
if n = bytes.IndexByte(b, ';'); n >= 0 {
b = b[:n]
}
if len(b) > 1 && b[0] == '"' && b[len(b)-1] == '"' {
b = b[1 : len(b)-1]
}
return b
}
return nil
}
// Host returns Host header value.
func (h *RequestHeader) Host() []byte {
if h.disableSpecialHeader {
return peekArgBytes(h.h, []byte(HeaderHost))
}
return h.host
}
// SetHost sets Host header value.
func (h *RequestHeader) SetHost(host string) {
h.host = append(h.host[:0], host...)
}
// SetHostBytes sets Host header value.
func (h *RequestHeader) SetHostBytes(host []byte) {
h.host = append(h.host[:0], host...)
}
// UserAgent returns User-Agent header value.
func (h *RequestHeader) UserAgent() []byte {
if h.disableSpecialHeader {
return peekArgBytes(h.h, []byte(HeaderUserAgent))
}
return h.userAgent
}
// SetUserAgent sets User-Agent header value.
func (h *RequestHeader) SetUserAgent(userAgent string) {
h.userAgent = append(h.userAgent[:0], userAgent...)
}
// SetUserAgentBytes sets User-Agent header value.
func (h *RequestHeader) SetUserAgentBytes(userAgent []byte) {
h.userAgent = append(h.userAgent[:0], userAgent...)
}
// Referer returns Referer header value.
func (h *RequestHeader) Referer() []byte {
return peekArgBytes(h.h, strReferer)
}
// SetReferer sets Referer header value.
func (h *RequestHeader) SetReferer(referer string) {
h.SetBytesK(strReferer, referer)
}
// SetRefererBytes sets Referer header value.
func (h *RequestHeader) SetRefererBytes(referer []byte) {
h.setNonSpecial(strReferer, referer)
}
// Method returns HTTP request method.
func (h *RequestHeader) Method() []byte {
if len(h.method) == 0 {
return []byte(MethodGet)
}
return h.method
}
// SetMethod sets HTTP request method.
func (h *RequestHeader) SetMethod(method string) {
h.method = append(h.method[:0], method...)
}
// SetMethodBytes sets HTTP request method.
func (h *RequestHeader) SetMethodBytes(method []byte) {
h.method = append(h.method[:0], method...)
}
// Protocol returns HTTP protocol.
func (h *header) Protocol() []byte {
if len(h.protocol) == 0 {
return strHTTP11
}
return h.protocol
}
// SetProtocol sets HTTP request protocol.
func (h *RequestHeader) SetProtocol(protocol string) {
h.protocol = append(h.protocol[:0], protocol...)
h.noHTTP11 = !bytes.Equal(h.protocol, strHTTP11)
}
// SetProtocolBytes sets HTTP request protocol.
func (h *RequestHeader) SetProtocolBytes(protocol []byte) {
h.protocol = append(h.protocol[:0], protocol...)
h.noHTTP11 = !bytes.Equal(h.protocol, strHTTP11)
}
// RequestURI returns RequestURI from the first HTTP request line.
func (h *RequestHeader) RequestURI() []byte {
requestURI := h.requestURI
if len(requestURI) == 0 {
requestURI = strSlash
}
return requestURI
}
// SetRequestURI sets RequestURI for the first HTTP request line.
// RequestURI must be properly encoded.
// Use URI.RequestURI for constructing proper RequestURI if unsure.
func (h *RequestHeader) SetRequestURI(requestURI string) {
h.requestURI = append(h.requestURI[:0], requestURI...)
}
// SetRequestURIBytes sets RequestURI for the first HTTP request line.
// RequestURI must be properly encoded.
// Use URI.RequestURI for constructing proper RequestURI if unsure.
func (h *RequestHeader) SetRequestURIBytes(requestURI []byte) {
h.requestURI = append(h.requestURI[:0], requestURI...)
}
// IsGet returns true if request method is GET.
func (h *RequestHeader) IsGet() bool {
return string(h.Method()) == MethodGet
}
// IsPost returns true if request method is POST.
func (h *RequestHeader) IsPost() bool {
return string(h.Method()) == MethodPost
}
// IsPut returns true if request method is PUT.
func (h *RequestHeader) IsPut() bool {
return string(h.Method()) == MethodPut
}
// IsHead returns true if request method is HEAD.
func (h *RequestHeader) IsHead() bool {
return string(h.Method()) == MethodHead
}
// IsDelete returns true if request method is DELETE.
func (h *RequestHeader) IsDelete() bool {
return string(h.Method()) == MethodDelete
}
// IsConnect returns true if request method is CONNECT.
func (h *RequestHeader) IsConnect() bool {
return string(h.Method()) == MethodConnect
}
// IsOptions returns true if request method is OPTIONS.
func (h *RequestHeader) IsOptions() bool {
return string(h.Method()) == MethodOptions
}
// IsTrace returns true if request method is TRACE.
func (h *RequestHeader) IsTrace() bool {
return string(h.Method()) == MethodTrace
}
// IsPatch returns true if request method is PATCH.
func (h *RequestHeader) IsPatch() bool {
return string(h.Method()) == MethodPatch
}
// IsHTTP11 returns true if the header is HTTP/1.1.
func (h *header) IsHTTP11() bool {
return !h.noHTTP11
}
// HasAcceptEncoding returns true if the header contains
// the given Accept-Encoding value.
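//
// For example, with the request header
// 'Accept-Encoding: gzip, deflate, br':
//
// h.HasAcceptEncoding("gzip") // true
// h.HasAcceptEncoding("zstd") // false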
func (h *RequestHeader) HasAcceptEncoding(acceptEncoding string) bool {
h.bufV = append(h.bufV[:0], acceptEncoding...)
return h.HasAcceptEncodingBytes(h.bufV)
}
// HasAcceptEncodingBytes returns true if the header contains
// the given Accept-Encoding value.
func (h *RequestHeader) HasAcceptEncodingBytes(acceptEncoding []byte) bool {
ae := h.peek(strAcceptEncoding)
n := bytes.Index(ae, acceptEncoding)
if n < 0 {
return false
}
b := ae[n+len(acceptEncoding):]
if len(b) > 0 && b[0] != ',' {
return false
}
if n == 0 {
return true
}
return ae[n-1] == ' '
}
// Len returns the number of headers set,
// i.e. the number of times f is called in VisitAll.
func (h *ResponseHeader) Len() int {
n := 0
for range h.All() {
n++
}
return n
}
// Len returns the number of headers set,
// i.e. the number of times f is called in VisitAll.
func (h *RequestHeader) Len() int {
n := 0
for range h.All() {
n++
}
return n
}
// DisableSpecialHeader disables special header processing.
// fasthttp will not set any special headers for you, such as Host, Content-Type, User-Agent, etc.
// You must set everything yourself.
// If RequestHeader.Read() is called, special headers will be ignored.
// This can be used to control case and order of special headers.
// This is generally not recommended.
func (h *RequestHeader) DisableSpecialHeader() {
h.disableSpecialHeader = true
}
// EnableSpecialHeader enables special header processing.
// fasthttp will send Host, Content-Type, User-Agent, etc headers for you.
// This is suggested and enabled by default.
func (h *RequestHeader) EnableSpecialHeader() {
h.disableSpecialHeader = false
}
// DisableNormalizing disables header names' normalization.
//
// By default all the header names are normalized by uppercasing
// the first letter and all the first letters following dashes,
// while lowercasing all the other letters.
// Examples:
//
// - CONNECTION -> Connection
// - conteNT-tYPE -> Content-Type
// - foo-bar-baz -> Foo-Bar-Baz
//
// Disable header names' normalization only if you know what you are doing.
func (h *header) DisableNormalizing() {
h.disableNormalizing = true
}
// EnableNormalizing enables header names' normalization.
//
// Header names are normalized by uppercasing the first letter and
// all the first letters following dashes, while lowercasing all
// the other letters.
// Examples:
//
// - CONNECTION -> Connection
// - conteNT-tYPE -> Content-Type
// - foo-bar-baz -> Foo-Bar-Baz
//
// This is enabled by default unless disabled using DisableNormalizing().
func (h *header) EnableNormalizing() {
h.disableNormalizing = false
}
// SetNoDefaultContentType controls whether a default Content-Type header is sent when none is set explicitly: true suppresses the default, false (the default) keeps it.
func (h *header) SetNoDefaultContentType(noDefaultContentType bool) {
h.noDefaultContentType = noDefaultContentType
}
// Reset clears response header.
func (h *ResponseHeader) Reset() {
h.disableNormalizing = false
h.SetNoDefaultContentType(false)
h.noDefaultDate = false
h.resetSkipNormalize()
}
func (h *ResponseHeader) resetSkipNormalize() {
h.noHTTP11 = false
h.connectionClose = false
h.statusCode = 0
h.statusMessage = h.statusMessage[:0]
h.protocol = h.protocol[:0]
h.contentLength = 0
h.contentLengthBytes = h.contentLengthBytes[:0]
h.contentType = h.contentType[:0]
h.contentEncoding = h.contentEncoding[:0]
h.server = h.server[:0]
h.h = h.h[:0]
h.cookies = h.cookies[:0]
h.trailer = h.trailer[:0]
h.mulHeader = h.mulHeader[:0]
}
// Reset clears request header.
func (h *RequestHeader) Reset() {
h.disableSpecialHeader = false
h.disableNormalizing = false
h.SetNoDefaultContentType(false)
h.resetSkipNormalize()
}
func (h *RequestHeader) resetSkipNormalize() {
h.noHTTP11 = false
h.connectionClose = false
h.contentLength = 0
h.contentLengthBytes = h.contentLengthBytes[:0]
h.method = h.method[:0]
h.protocol = h.protocol[:0]
h.requestURI = h.requestURI[:0]
h.host = h.host[:0]
h.contentType = h.contentType[:0]
h.userAgent = h.userAgent[:0]
h.trailer = h.trailer[:0]
h.mulHeader = h.mulHeader[:0]
h.h = h.h[:0]
h.cookies = h.cookies[:0]
h.cookiesCollected = false
h.rawHeaders = h.rawHeaders[:0]
}
func (h *header) copyTo(dst *header) {
dst.disableNormalizing = h.disableNormalizing
dst.noHTTP11 = h.noHTTP11
dst.connectionClose = h.connectionClose
dst.noDefaultContentType = h.noDefaultContentType
dst.contentLength = h.contentLength
dst.contentLengthBytes = append(dst.contentLengthBytes, h.contentLengthBytes...)
dst.protocol = append(dst.protocol, h.protocol...)
dst.contentType = append(dst.contentType, h.contentType...)
dst.trailer = copyTrailer(dst.trailer, h.trailer)
dst.cookies = copyArgs(dst.cookies, h.cookies)
dst.h = copyArgs(dst.h, h.h)
}
// CopyTo copies all the headers to dst.
func (h *ResponseHeader) CopyTo(dst *ResponseHeader) {
dst.Reset()
h.copyTo(&dst.header)
dst.noDefaultDate = h.noDefaultDate
dst.statusCode = h.statusCode
dst.statusMessage = append(dst.statusMessage, h.statusMessage...)
dst.contentEncoding = append(dst.contentEncoding, h.contentEncoding...)
dst.server = append(dst.server, h.server...)
}
// CopyTo copies all the headers to dst.
func (h *RequestHeader) CopyTo(dst *RequestHeader) {
dst.Reset()
h.copyTo(&dst.header)
dst.method = append(dst.method, h.method...)
dst.requestURI = append(dst.requestURI, h.requestURI...)
dst.host = append(dst.host, h.host...)
dst.userAgent = append(dst.userAgent, h.userAgent...)
dst.cookiesCollected = h.cookiesCollected
dst.rawHeaders = append(dst.rawHeaders, h.rawHeaders...)
}
// All returns an iterator over key-value pairs in h.
//
// The key and value may be invalid outside the iteration loop.
// Copy key and/or value contents for each iteration if you need to retain
// them.
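//
// For example:
//
// for key, value := range h.All() {
//     // use key and value; copy them if they must outlive the loop
// }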
func (h *ResponseHeader) All() iter.Seq2[[]byte, []byte] {
return func(yield func([]byte, []byte) bool) {
if len(h.contentLengthBytes) > 0 && !yield(strContentLength, h.contentLengthBytes) {
return
}
if contentType := h.ContentType(); len(contentType) > 0 && !yield(strContentType, contentType) {
return
}
if contentEncoding := h.ContentEncoding(); len(contentEncoding) > 0 && !yield(strContentEncoding, contentEncoding) {
return
}
if server := h.Server(); len(server) > 0 && !yield(strServer, server) {
return
}
for i := range h.cookies {
if !yield(strSetCookie, h.cookies[i].value) {
return
}
}
if len(h.trailer) > 0 && !yield(strTrailer, appendTrailerBytes(nil, h.trailer, strCommaSpace)) {
return
}
for i := range h.h {
if !yield(h.h[i].key, h.h[i].value) {
return
}
}
if h.ConnectionClose() && !yield(strConnection, strClose) {
return
}
}
}
// VisitAll calls f for each header.
//
// f must not retain references to key and/or value after returning.
// Copy key and/or value contents before returning if you need to retain them.
//
// Deprecated: Use All instead.
func (h *ResponseHeader) VisitAll(f func(key, value []byte)) {
h.All()(func(key, value []byte) bool {
f(key, value)
return true
})
}
// Trailers returns an iterator over trailers in h.
//
// The trailer value may be invalid outside the iteration loop.
func (h *header) Trailers() iter.Seq[[]byte] {
return func(yield func([]byte) bool) {
for i := range h.trailer {
if !yield(h.trailer[i]) {
break
}
}
}
}
// VisitAllTrailer calls f for each trailer value.
//
// f must not retain references to value after returning.
//
// Deprecated: Use Trailers instead.
func (h *header) VisitAllTrailer(f func(value []byte)) {
h.Trailers()(func(v []byte) bool {
f(v)
return true
})
}
// Cookies returns an iterator over key-value pairs of the response cookies in h.
//
// Cookie name is passed in key and the whole Set-Cookie header value
// is passed in value for each iteration. Value may be parsed with
// Cookie.ParseBytes().
//
// The key and value may be invalid outside the iteration loop.
// Copy key and/or value contents for each iteration if you need to retain
// them.
func (h *ResponseHeader) Cookies() iter.Seq2[[]byte, []byte] {
return func(yield func([]byte, []byte) bool) {
for i := range h.cookies {
if !yield(h.cookies[i].key, h.cookies[i].value) {
break
}
}
}
}
// VisitAllCookie calls f for each response cookie.
//
// Cookie name is passed in key and the whole Set-Cookie header value
// is passed in value on each f invocation. Value may be parsed
// with Cookie.ParseBytes().
//
// f must not retain references to key and/or value after returning.
//
// Deprecated: Use Cookies instead.
func (h *ResponseHeader) VisitAllCookie(f func(key, value []byte)) {
h.Cookies()(func(key, value []byte) bool {
f(key, value)
return true
})
}
// Cookies returns an iterator over key-value pairs of the request cookies in h.
//
// The key and value may be invalid outside the iteration loop.
// Copy key and/or value contents for each iteration if you need to retain
// them.
func (h *RequestHeader) Cookies() iter.Seq2[[]byte, []byte] {
return func(yield func([]byte, []byte) bool) {
h.collectCookies()
for i := range h.cookies {
if !yield(h.cookies[i].key, h.cookies[i].value) {
break
}
}
}
}
// VisitAllCookie calls f for each request cookie.
//
// f must not retain references to key and/or value after returning.
//
// Deprecated: Use Cookies instead.
func (h *RequestHeader) VisitAllCookie(f func(key, value []byte)) {
h.Cookies()(func(key, value []byte) bool {
f(key, value)
return true
})
}
// All returns an iterator over key-value pairs in h.
//
// The key and value may be invalid outside the iteration loop.
// Copy key and/or value contents for each iteration if you need to retain
// them.
//
// To get the headers in order they were received use AllInOrder.
func (h *RequestHeader) All() iter.Seq2[[]byte, []byte] {
return func(yield func([]byte, []byte) bool) {
if host := h.Host(); len(host) > 0 && !yield(strHost, host) {
return
}
if len(h.contentLengthBytes) > 0 && !yield(strContentLength, h.contentLengthBytes) {
return
}
if contentType := h.ContentType(); len(contentType) > 0 && !yield(strContentType, contentType) {
return
}
if userAgent := h.UserAgent(); len(userAgent) > 0 && !yield(strUserAgent, userAgent) {
return
}
if len(h.trailer) > 0 && !yield(strTrailer, appendTrailerBytes(nil, h.trailer, strCommaSpace)) {
return
}
h.collectCookies()
if len(h.cookies) > 0 {
h.bufV = appendRequestCookieBytes(h.bufV[:0], h.cookies)
if !yield(strCookie, h.bufV) {
return
}
}
for i := range h.h {
if !yield(h.h[i].key, h.h[i].value) {
return
}
}
if h.ConnectionClose() && !yield(strConnection, strClose) {
return
}
}
}
// VisitAll calls f for each header.
//
// f must not retain references to key and/or value after returning.
// Copy key and/or value contents before returning if you need to retain them.
//
// To get the headers in order they were received use VisitAllInOrder.
//
// Deprecated: Use All instead.
func (h *RequestHeader) VisitAll(f func(key, value []byte)) {
h.All()(func(key, value []byte) bool {
f(key, value)
return true
})
}
// AllInOrder returns an iterator over key-value pairs in h in the order they
// were received.
//
// The key and value may be invalid outside the iteration loop.
// Copy key and/or value contents for each iteration if you need to retain
// them.
//
// The returned iterator is slightly slower than All because it has to reparse
// the raw headers to get the order.
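//
// For example:
//
// for key, value := range h.AllInOrder() {
//     // headers are yielded in the order they were received
// }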
func (h *RequestHeader) AllInOrder() iter.Seq2[[]byte, []byte] {
return func(yield func([]byte, []byte) bool) {
var s headerScanner
s.b = h.rawHeaders
for s.next() {
normalizeHeaderKey(s.key, h.disableNormalizing || bytes.IndexByte(s.key, ' ') != -1)
if len(s.key) > 0 {
if !yield(s.key, s.value) {
break
}
}
}
}
}
// VisitAllInOrder calls f for each header in the order they were received.
//
// f must not retain references to key and/or value after returning.
// Copy key and/or value contents before returning if you need to retain them.
//
// This function is slightly slower than VisitAll because it has to reparse the
// raw headers to get the order.
//
// Deprecated: Use AllInOrder instead.
func (h *RequestHeader) VisitAllInOrder(f func(key, value []byte)) {
h.AllInOrder()(func(key, value []byte) bool {
f(key, value)
return true
})
}
// Del deletes header with the given key.
func (h *ResponseHeader) Del(key string) {
h.bufK = getHeaderKeyBytes(h.bufK, key, h.disableNormalizing)
h.del(h.bufK)
}
// DelBytes deletes header with the given key.
func (h *ResponseHeader) DelBytes(key []byte) {
h.bufK = append(h.bufK[:0], key...)
normalizeHeaderKey(h.bufK, h.disableNormalizing || bytes.IndexByte(key, ' ') != -1)
h.del(h.bufK)
}
func (h *ResponseHeader) del(key []byte) {
switch string(key) {
case HeaderContentType:
h.contentType = h.contentType[:0]
case HeaderContentEncoding:
h.contentEncoding = h.contentEncoding[:0]
case HeaderServer:
h.server = h.server[:0]
case HeaderSetCookie:
h.cookies = h.cookies[:0]
case HeaderContentLength:
h.contentLength = 0
h.contentLengthBytes = h.contentLengthBytes[:0]
case HeaderConnection:
h.connectionClose = false
case HeaderTrailer:
h.trailer = h.trailer[:0]
}
h.h = delAllArgs(h.h, b2s(key))
}
// Del deletes header with the given key.
func (h *RequestHeader) Del(key string) {
h.bufK = getHeaderKeyBytes(h.bufK, key, h.disableNormalizing)
h.del(h.bufK)
}
// DelBytes deletes header with the given key.
func (h *RequestHeader) DelBytes(key []byte) {
h.bufK = append(h.bufK[:0], key...)
normalizeHeaderKey(h.bufK, h.disableNormalizing || bytes.IndexByte(key, ' ') != -1)
h.del(h.bufK)
}
func (h *RequestHeader) del(key []byte) {
switch string(key) {
case HeaderHost:
h.host = h.host[:0]
case HeaderContentType:
h.contentType = h.contentType[:0]
case HeaderUserAgent:
h.userAgent = h.userAgent[:0]
case HeaderCookie:
h.cookies = h.cookies[:0]
case HeaderContentLength:
h.contentLength = 0
h.contentLengthBytes = h.contentLengthBytes[:0]
case HeaderConnection:
h.connectionClose = false
case HeaderTrailer:
h.trailer = h.trailer[:0]
}
h.h = delAllArgs(h.h, b2s(key))
}
// setSpecialHeader handles special headers and returns true when a header is processed.
func (h *ResponseHeader) setSpecialHeader(key, value []byte) bool {
if len(key) == 0 {
return false
}
switch key[0] | 0x20 {
case 'c':
switch {
case caseInsensitiveCompare(strContentType, key):
h.SetContentTypeBytes(value)
return true
case caseInsensitiveCompare(strContentLength, key):
if contentLength, err := parseContentLength(value); err == nil {
h.contentLength = contentLength
h.contentLengthBytes = append(h.contentLengthBytes[:0], value...)
}
return true
case caseInsensitiveCompare(strContentEncoding, key):
h.SetContentEncodingBytes(value)
return true
case caseInsensitiveCompare(strConnection, key):
if bytes.Equal(strClose, value) {
h.SetConnectionClose()
} else {
h.ResetConnectionClose()
h.setNonSpecial(key, value)
}
return true
}
case 's':
if caseInsensitiveCompare(strServer, key) {
h.SetServerBytes(value)
return true
} else if caseInsensitiveCompare(strSetCookie, key) {
var kv *argsKV
h.cookies, kv = allocArg(h.cookies)
kv.key = getCookieKey(kv.key, value)
kv.value = append(kv.value[:0], value...)
return true
}
case 't':
if caseInsensitiveCompare(strTransferEncoding, key) {
// Transfer-Encoding is managed automatically.
return true
} else if caseInsensitiveCompare(strTrailer, key) {
_ = h.SetTrailerBytes(value)
return true
}
case 'd':
if caseInsensitiveCompare(strDate, key) {
// Date is managed automatically.
return true
}
}
return false
}
// setNonSpecial stores a header that needs no special handling directly in the header map.
func (h *header) setNonSpecial(key, value []byte) {
h.h = setArgBytes(h.h, key, value, argsHasValue)
}
// setSpecialHeader handles special headers and returns true when a header is processed.
func (h *RequestHeader) setSpecialHeader(key, value []byte) bool {
if len(key) == 0 || h.disableSpecialHeader {
return false
}
switch key[0] | 0x20 {
case 'c':
switch {
case caseInsensitiveCompare(strContentType, key):
h.SetContentTypeBytes(value)
return true
case caseInsensitiveCompare(strContentLength, key):
if contentLength, err := parseContentLength(value); err == nil {
h.contentLength = contentLength
h.contentLengthBytes = append(h.contentLengthBytes[:0], value...)
}
return true
case caseInsensitiveCompare(strConnection, key):
if bytes.Equal(strClose, value) {
h.SetConnectionClose()
} else {
h.ResetConnectionClose()
h.setNonSpecial(key, value)
}
return true
case caseInsensitiveCompare(strCookie, key):
h.collectCookies()
h.cookies = parseRequestCookies(h.cookies, value)
return true
}
case 't':
if caseInsensitiveCompare(strTransferEncoding, key) {
// Transfer-Encoding is managed automatically.
return true
} else if caseInsensitiveCompare(strTrailer, key) {
_ = h.SetTrailerBytes(value)
return true
}
case 'h':
if caseInsensitiveCompare(strHost, key) {
h.SetHostBytes(value)
return true
}
case 'u':
if caseInsensitiveCompare(strUserAgent, key) {
h.SetUserAgentBytes(value)
return true
}
}
return false
}
// Add adds the given 'key: value' header.
//
// Multiple headers with the same key may be added with this function.
// Use Set for setting a single header for the given key.
//
// The Content-Type, Content-Length, Connection, Server, Transfer-Encoding
// and Date headers can only be set once and will overwrite the previous value,
// while Set-Cookie will not clear previous cookies.
//
// If the header is set as a Trailer (forbidden trailers will not be set, see AddTrailer for more details),
// it will be sent after the chunked response body.
func (h *ResponseHeader) Add(key, value string) {
h.AddBytesKV(s2b(key), s2b(value))
}
// AddBytesK adds the given 'key: value' header.
//
// Multiple headers with the same key may be added with this function.
// Use SetBytesK for setting a single header for the given key.
//
// The Content-Type, Content-Length, Connection, Server, Transfer-Encoding
// and Date headers can only be set once and will overwrite the previous value,
// while Set-Cookie will not clear previous cookies.
//
// If the header is set as a Trailer (forbidden trailers will not be set, see AddTrailer for more details),
// it will be sent after the chunked response body.
func (h *ResponseHeader) AddBytesK(key []byte, value string) {
h.AddBytesKV(key, s2b(value))
}
// AddBytesV adds the given 'key: value' header.
//
// Multiple headers with the same key may be added with this function.
// Use SetBytesV for setting a single header for the given key.
//
// The Content-Type, Content-Length, Connection, Server, Transfer-Encoding
// and Date headers can only be set once and will overwrite the previous value,
// while Set-Cookie will not clear previous cookies.
//
// If the header is set as a Trailer (forbidden trailers will not be set, see AddTrailer for more details),
// it will be sent after the chunked response body.
func (h *ResponseHeader) AddBytesV(key string, value []byte) {
h.AddBytesKV(s2b(key), value)
}
// AddBytesKV adds the given 'key: value' header.
//
// Multiple headers with the same key may be added with this function.
// Use SetBytesKV for setting a single header for the given key.
//
// The Content-Type, Content-Length, Connection, Server, Transfer-Encoding
// and Date headers can only be set once and will overwrite the previous value,
// while the Set-Cookie header will not clear previous cookies.
//
// If the header is set as a Trailer (forbidden trailers will not be set, see AddTrailer for more details),
// it will be sent after the chunked response body.
func (h *ResponseHeader) AddBytesKV(key, value []byte) {
if h.setSpecialHeader(key, value) {
return
}
h.bufK = getHeaderKeyBytes(h.bufK, b2s(key), h.disableNormalizing)
h.h = appendArgBytes(h.h, h.bufK, value, argsHasValue)
}
// Set sets the given 'key: value' header.
//
// Please note that the Set-Cookie header will not clear previous cookies,
// use SetCookie instead to reset cookies.
//
// If the header is set as a Trailer (forbidden trailers will not be set, see SetTrailer for more details),
// it will be sent after the chunked response body.
//
// Use Add for setting multiple header values under the same key.
func (h *ResponseHeader) Set(key, value string) {
h.bufK, h.bufV = initHeaderKV(h.bufK, h.bufV, key, value, h.disableNormalizing)
h.SetCanonical(h.bufK, h.bufV)
}
// SetBytesK sets the given 'key: value' header.
//
// Please note that the Set-Cookie header will not clear previous cookies,
// use SetCookie instead to reset cookies.
//
// If the header is set as a Trailer (forbidden trailers will not be set, see SetTrailer for more details),
// it will be sent after the chunked response body.
//
// Use AddBytesK for setting multiple header values under the same key.
func (h *ResponseHeader) SetBytesK(key []byte, value string) {
h.bufV = append(h.bufV[:0], value...)
h.SetBytesKV(key, h.bufV)
}
// SetBytesV sets the given 'key: value' header.
//
// Please note that the Set-Cookie header will not clear previous cookies,
// use SetCookie instead to reset cookies.
//
// If the header is set as a Trailer (forbidden trailers will not be set, see SetTrailer for more details),
// it will be sent after the chunked response body.
//
// Use AddBytesV for setting multiple header values under the same key.
func (h *ResponseHeader) SetBytesV(key string, value []byte) {
h.bufK = getHeaderKeyBytes(h.bufK, key, h.disableNormalizing)
h.SetCanonical(h.bufK, value)
}
// SetBytesKV sets the given 'key: value' header.
//
// Please note that the Set-Cookie header will not clear previous cookies,
// use SetCookie instead to reset cookies.
//
// If the header is set as a Trailer (forbidden trailers will not be set, see SetTrailer for more details),
// it will be sent after the chunked response body.
//
// Use AddBytesKV for setting multiple header values under the same key.
func (h *ResponseHeader) SetBytesKV(key, value []byte) {
h.bufK = append(h.bufK[:0], key...)
normalizeHeaderKey(h.bufK, h.disableNormalizing || bytes.IndexByte(key, ' ') != -1)
h.SetCanonical(h.bufK, value)
}
// SetCanonical sets the given 'key: value' header assuming that
// key is in canonical form.
//
// Please note that the Set-Cookie header will not clear previous cookies,
// use SetCookie instead to reset cookies.
//
// If the header is set as a Trailer (forbidden trailers will not be set, see SetTrailer for more details),
// it will be sent after the chunked response body.
func (h *ResponseHeader) SetCanonical(key, value []byte) {
if h.setSpecialHeader(key, value) {
return
}
h.setNonSpecial(key, value)
}
// SetCookie sets the given response cookie.
//
// It is safe re-using the cookie after the function returns.
func (h *ResponseHeader) SetCookie(cookie *Cookie) {
h.cookies = setArgBytes(h.cookies, cookie.Key(), cookie.Cookie(), argsHasValue)
}
// SetCookie sets 'key: value' cookies.
func (h *RequestHeader) SetCookie(key, value string) {
h.collectCookies()
h.cookies = setArg(h.cookies, key, value, argsHasValue)
}
// SetCookieBytesK sets 'key: value' cookies.
func (h *RequestHeader) SetCookieBytesK(key []byte, value string) {
h.SetCookie(b2s(key), value)
}
// SetCookieBytesKV sets 'key: value' cookies.
func (h *RequestHeader) SetCookieBytesKV(key, value []byte) {
h.SetCookie(b2s(key), b2s(value))
}
// DelClientCookie instructs the client to remove the given cookie.
// This doesn't work for a cookie with specific domain or path,
// you should delete it manually like:
//
// c := AcquireCookie()
// c.SetKey(key)
// c.SetDomain("example.com")
// c.SetPath("/path")
// c.SetExpire(CookieExpireDelete)
// h.SetCookie(c)
// ReleaseCookie(c)
//
// Use DelCookie if you just want to remove the cookie from the response header.
func (h *ResponseHeader) DelClientCookie(key string) {
h.DelCookie(key)
c := AcquireCookie()
c.SetKey(key)
c.SetExpire(CookieExpireDelete)
h.SetCookie(c)
ReleaseCookie(c)
}
// DelClientCookieBytes instructs the client to remove the given cookie.
// This doesn't work for a cookie with specific domain or path,
// you should delete it manually like:
//
// c := AcquireCookie()
// c.SetKey(key)
// c.SetDomain("example.com")
// c.SetPath("/path")
// c.SetExpire(CookieExpireDelete)
// h.SetCookie(c)
// ReleaseCookie(c)
//
// Use DelCookieBytes if you just want to remove the cookie from the response header.
func (h *ResponseHeader) DelClientCookieBytes(key []byte) {
h.DelClientCookie(b2s(key))
}
// DelCookie removes cookie under the given key from response header.
//
// Note that DelCookie doesn't remove the cookie from the client.
// Use DelClientCookie instead.
func (h *ResponseHeader) DelCookie(key string) {
h.cookies = delAllArgs(h.cookies, key)
}
// DelCookieBytes removes cookie under the given key from response header.
//
// Note that DelCookieBytes doesn't remove the cookie from the client.
// Use DelClientCookieBytes instead.
func (h *ResponseHeader) DelCookieBytes(key []byte) {
h.DelCookie(b2s(key))
}
// DelCookie removes cookie under the given key.
func (h *RequestHeader) DelCookie(key string) {
h.collectCookies()
h.cookies = delAllArgs(h.cookies, key)
}
// DelCookieBytes removes cookie under the given key.
func (h *RequestHeader) DelCookieBytes(key []byte) {
h.DelCookie(b2s(key))
}
// DelAllCookies removes all the cookies from response headers.
func (h *ResponseHeader) DelAllCookies() {
h.cookies = h.cookies[:0]
}
// DelAllCookies removes all the cookies from request headers.
func (h *RequestHeader) DelAllCookies() {
h.collectCookies()
h.cookies = h.cookies[:0]
}
// Add adds the given 'key: value' header.
//
// Multiple headers with the same key may be added with this function.
// Use Set for setting a single header for the given key.
//
// If the header is set as a Trailer (forbidden trailers will not be set, see AddTrailer for more details),
// it will be sent after the chunked request body.
func (h *RequestHeader) Add(key, value string) {
h.AddBytesKV(s2b(key), s2b(value))
}
// AddBytesK adds the given 'key: value' header.
//
// Multiple headers with the same key may be added with this function.
// Use SetBytesK for setting a single header for the given key.
//
// If the header is set as a Trailer (forbidden trailers will not be set, see AddTrailer for more details),
// it will be sent after the chunked request body.
func (h *RequestHeader) AddBytesK(key []byte, value string) {
h.AddBytesKV(key, s2b(value))
}
// AddBytesV adds the given 'key: value' header.
//
// Multiple headers with the same key may be added with this function.
// Use SetBytesV for setting a single header for the given key.
//
// If the header is set as a Trailer (forbidden trailers will not be set, see AddTrailer for more details),
// it will be sent after the chunked request body.
func (h *RequestHeader) AddBytesV(key string, value []byte) {
h.AddBytesKV(s2b(key), value)
}
// AddBytesKV adds the given 'key: value' header.
//
// Multiple headers with the same key may be added with this function.
// Use SetBytesKV for setting a single header for the given key.
//
// The Content-Type, Content-Length, Connection, Transfer-Encoding,
// Host and User-Agent headers can only be set once and will overwrite
// the previous value, while the Cookie header will not clear previous cookies.
//
// If the header is set as a Trailer (forbidden trailers will not be set, see AddTrailer for more details),
// it will be sent after the chunked request body.
func (h *RequestHeader) AddBytesKV(key, value []byte) {
if h.setSpecialHeader(key, value) {
return
}
h.bufK = getHeaderKeyBytes(h.bufK, b2s(key), h.disableNormalizing)
h.h = appendArgBytes(h.h, h.bufK, value, argsHasValue)
}
// Set sets the given 'key: value' header.
//
// Please note that the Cookie header will not clear previous cookies;
// delete existing cookies before calling in order to reset them.
//
// If the header is set as a Trailer (forbidden trailers will not be set, see SetTrailer for more details),
// it will be sent after the chunked request body.
//
// Use Add for setting multiple header values under the same key.
func (h *RequestHeader) Set(key, value string) {
h.bufK, h.bufV = initHeaderKV(h.bufK, h.bufV, key, value, h.disableNormalizing)
h.SetCanonical(h.bufK, h.bufV)
}
// SetBytesK sets the given 'key: value' header.
//
// Please note that the Cookie header will not clear previous cookies;
// delete existing cookies before calling in order to reset them.
//
// If the header is set as a Trailer (forbidden trailers will not be set, see SetTrailer for more details),
// it will be sent after the chunked request body.
//
// Use AddBytesK for setting multiple header values under the same key.
func (h *RequestHeader) SetBytesK(key []byte, value string) {
h.bufV = append(h.bufV[:0], value...)
h.SetBytesKV(key, h.bufV)
}
// SetBytesV sets the given 'key: value' header.
//
// Please note that the Cookie header will not clear previous cookies;
// delete existing cookies before calling in order to reset them.
//
// If the header is set as a Trailer (forbidden trailers will not be set, see SetTrailer for more details),
// it will be sent after the chunked request body.
//
// Use AddBytesV for setting multiple header values under the same key.
func (h *RequestHeader) SetBytesV(key string, value []byte) {
h.bufK = getHeaderKeyBytes(h.bufK, key, h.disableNormalizing)
h.SetCanonical(h.bufK, value)
}
// SetBytesKV sets the given 'key: value' header.
//
// Please note that the Cookie header will not clear previous cookies;
// delete existing cookies before calling in order to reset them.
//
// If the header is set as a Trailer (forbidden trailers will not be set, see SetTrailer for more details),
// it will be sent after the chunked request body.
//
// Use AddBytesKV for setting multiple header values under the same key.
func (h *RequestHeader) SetBytesKV(key, value []byte) {
h.bufK = append(h.bufK[:0], key...)
normalizeHeaderKey(h.bufK, h.disableNormalizing || bytes.IndexByte(key, ' ') != -1)
h.SetCanonical(h.bufK, value)
}
// SetCanonical sets the given 'key: value' header assuming that
// key is in canonical form.
//
// Please note that the Cookie header will not clear previous cookies;
// delete existing cookies before calling in order to reset them.
//
// If the header is set as a Trailer (forbidden trailers will not be set, see SetTrailer for more details),
// it will be sent after the chunked request body.
func (h *RequestHeader) SetCanonical(key, value []byte) {
if h.setSpecialHeader(key, value) {
return
}
h.setNonSpecial(key, value)
}
// Peek returns header value for the given key.
//
// The returned value is valid until the response is released,
// either through ReleaseResponse or your request handler returning.
// Do not store references to the returned value. Make copies instead.
func (h *ResponseHeader) Peek(key string) []byte {
h.bufK = getHeaderKeyBytes(h.bufK, key, h.disableNormalizing)
return h.peek(h.bufK)
}
// PeekBytes returns header value for the given key.
//
// The returned value is valid until the response is released,
// either through ReleaseResponse or your request handler returning.
// Do not store references to the returned value. Make copies instead.
func (h *ResponseHeader) PeekBytes(key []byte) []byte {
h.bufK = append(h.bufK[:0], key...)
normalizeHeaderKey(h.bufK, h.disableNormalizing || bytes.IndexByte(key, ' ') != -1)
return h.peek(h.bufK)
}
// Peek returns header value for the given key.
//
// The returned value is valid until the request is released,
// either through ReleaseRequest or your request handler returning.
// Do not store references to the returned value. Make copies instead.
func (h *RequestHeader) Peek(key string) []byte {
h.bufK = getHeaderKeyBytes(h.bufK, key, h.disableNormalizing)
return h.peek(h.bufK)
}
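// examplePeekCopy is an illustrative sketch (hypothetical, not part of the
// package): the slice returned by Peek aliases the header's internal buffers,
// so copy it before storing it anywhere long-lived. The header name is made up.
func examplePeekCopy(h *RequestHeader) []byte {
v := h.Peek("X-Request-Id") // valid only until the header is reused
saved := make([]byte, len(v))
copy(saved, v) // safe to retain after the handler returns
return saved
}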
// PeekBytes returns header value for the given key.
//
// The returned value is valid until the request is released,
// either through ReleaseRequest or your request handler returning.
// Do not store references to the returned value. Make copies instead.
func (h *RequestHeader) PeekBytes(key []byte) []byte {
h.bufK = append(h.bufK[:0], key...)
normalizeHeaderKey(h.bufK, h.disableNormalizing || bytes.IndexByte(key, ' ') != -1)
return h.peek(h.bufK)
}
func (h *ResponseHeader) peek(key []byte) []byte {
switch string(key) {
case HeaderContentType:
return h.ContentType()
case HeaderContentEncoding:
return h.ContentEncoding()
case HeaderServer:
return h.Server()
case HeaderConnection:
if h.ConnectionClose() {
return strClose
}
return peekArgBytes(h.h, key)
case HeaderContentLength:
return h.contentLengthBytes
case HeaderSetCookie:
return appendResponseCookieBytes(nil, h.cookies)
case HeaderTrailer:
return appendTrailerBytes(nil, h.trailer, strCommaSpace)
default:
return peekArgBytes(h.h, key)
}
}
func (h *RequestHeader) peek(key []byte) []byte {
switch string(key) {
case HeaderHost:
return h.Host()
case HeaderContentType:
return h.ContentType()
case HeaderUserAgent:
return h.UserAgent()
case HeaderConnection:
if h.ConnectionClose() {
return strClose
}
return peekArgBytes(h.h, key)
case HeaderContentLength:
return h.contentLengthBytes
case HeaderCookie:
if h.cookiesCollected {
return appendRequestCookieBytes(nil, h.cookies)
}
return peekArgBytes(h.h, key)
case HeaderTrailer:
return appendTrailerBytes(nil, h.trailer, strCommaSpace)
default:
return peekArgBytes(h.h, key)
}
}
// PeekAll returns all header values for the given key.
//
// The returned value is valid until the request is released,
// either through ReleaseRequest or your request handler returning.
// Any subsequent call to a Peek* method will modify the returned value.
// Do not store references to the returned value. Make copies instead.
func (h *RequestHeader) PeekAll(key string) [][]byte {
h.bufK = getHeaderKeyBytes(h.bufK, key, h.disableNormalizing)
return h.peekAll(h.bufK)
}
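// examplePeekAll is an illustrative sketch (hypothetical, not part of the
// package) showing how to snapshot every value of a repeated header, since
// the returned slices are reused by subsequent Peek* calls.
func examplePeekAll(h *RequestHeader) []string {
var out []string
for _, v := range h.PeekAll("X-Forwarded-For") {
out = append(out, string(v)) // string() copies the bytes
}
return out
}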
func (h *RequestHeader) peekAll(key []byte) [][]byte {
h.mulHeader = h.mulHeader[:0]
switch string(key) {
case HeaderHost:
if host := h.Host(); len(host) > 0 {
h.mulHeader = append(h.mulHeader, host)
}
case HeaderContentType:
if contentType := h.ContentType(); len(contentType) > 0 {
h.mulHeader = append(h.mulHeader, contentType)
}
case HeaderUserAgent:
if ua := h.UserAgent(); len(ua) > 0 {
h.mulHeader = append(h.mulHeader, ua)
}
case HeaderConnection:
if h.ConnectionClose() {
h.mulHeader = append(h.mulHeader, strClose)
} else {
h.mulHeader = peekAllArgBytesToDst(h.mulHeader, h.h, key)
}
case HeaderContentLength:
h.mulHeader = append(h.mulHeader, h.contentLengthBytes)
case HeaderCookie:
if h.cookiesCollected {
h.mulHeader = append(h.mulHeader, appendRequestCookieBytes(nil, h.cookies))
} else {
h.mulHeader = peekAllArgBytesToDst(h.mulHeader, h.h, key)
}
case HeaderTrailer:
h.mulHeader = append(h.mulHeader, appendTrailerBytes(nil, h.trailer, strCommaSpace))
default:
h.mulHeader = peekAllArgBytesToDst(h.mulHeader, h.h, key)
}
return h.mulHeader
}
// PeekAll returns all header values for the given key.
//
// The returned value is valid until the response is released,
// either through ReleaseResponse or your request handler returning.
// Any subsequent call to a Peek* method will modify the returned value.
// Do not store references to the returned value. Make copies instead.
func (h *ResponseHeader) PeekAll(key string) [][]byte {
h.bufK = getHeaderKeyBytes(h.bufK, key, h.disableNormalizing)
return h.peekAll(h.bufK)
}
func (h *ResponseHeader) peekAll(key []byte) [][]byte {
h.mulHeader = h.mulHeader[:0]
switch string(key) {
case HeaderContentType:
if contentType := h.ContentType(); len(contentType) > 0 {
h.mulHeader = append(h.mulHeader, contentType)
}
case HeaderContentEncoding:
if contentEncoding := h.ContentEncoding(); len(contentEncoding) > 0 {
h.mulHeader = append(h.mulHeader, contentEncoding)
}
case HeaderServer:
if server := h.Server(); len(server) > 0 {
h.mulHeader = append(h.mulHeader, server)
}
case HeaderConnection:
if h.ConnectionClose() {
h.mulHeader = append(h.mulHeader, strClose)
} else {
h.mulHeader = peekAllArgBytesToDst(h.mulHeader, h.h, key)
}
case HeaderContentLength:
h.mulHeader = append(h.mulHeader, h.contentLengthBytes)
case HeaderSetCookie:
h.mulHeader = append(h.mulHeader, appendResponseCookieBytes(nil, h.cookies))
case HeaderTrailer:
h.mulHeader = append(h.mulHeader, appendTrailerBytes(nil, h.trailer, strCommaSpace))
default:
h.mulHeader = peekAllArgBytesToDst(h.mulHeader, h.h, key)
}
return h.mulHeader
}
// PeekKeys returns all header keys.
//
// The returned value is valid until the request or response is released,
// either through ReleaseRequest/ReleaseResponse or your request handler returning.
// Any subsequent call to a Peek* method will modify the returned value.
// Do not store references to the returned value. Make copies instead.
func (h *header) PeekKeys() [][]byte {
h.mulHeader = h.mulHeader[:0]
h.mulHeader = peekArgsKeys(h.mulHeader, h.h)
return h.mulHeader
}
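// examplePeekKeys is an illustrative sketch (hypothetical, not part of the
// package), assuming the shared header type is embedded in RequestHeader:
// copy each key before triggering any further Peek* call.
func examplePeekKeys(h *RequestHeader) []string {
var keys []string
for _, k := range h.PeekKeys() {
keys = append(keys, string(k)) // string() copies; the backing slice is reused
}
return keys
}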
// PeekTrailerKeys returns all trailer keys.
//
// The returned value is valid until the request or response is released,
// either through ReleaseRequest/ReleaseResponse or your request handler returning.
// Any subsequent call to a Peek* method will modify the returned value.
// Do not store references to the returned value. Make copies instead.
func (h *header) PeekTrailerKeys() [][]byte {
h.mulHeader = copyTrailer(h.mulHeader, h.trailer)
return h.mulHeader
}
// Cookie returns cookie for the given key.
func (h *RequestHeader) Cookie(key string) []byte {
h.collectCookies()
return peekArgStr(h.cookies, key)
}
// CookieBytes returns cookie for the given key.
func (h *RequestHeader) CookieBytes(key []byte) []byte {
h.collectCookies()
return peekArgBytes(h.cookies, key)
}
// Cookie fills cookie for the given cookie.Key.
//
// Returns false if cookie with the given cookie.Key is missing.
func (h *ResponseHeader) Cookie(cookie *Cookie) bool {
v := peekArgBytes(h.cookies, cookie.Key())
if v == nil {
return false
}
cookie.ParseBytes(v) //nolint:errcheck
return true
}
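// exampleResponseCookie is an illustrative sketch (hypothetical, not part of
// the package) showing how to read a Set-Cookie entry from a response header
// into a pooled Cookie object. The cookie name is made up.
func exampleResponseCookie(h *ResponseHeader) (string, bool) {
c := AcquireCookie()
defer ReleaseCookie(c)
c.SetKey("session") // the key to look up
if !h.Cookie(c) { // fills c from the matching Set-Cookie entry
return "", false
}
return string(c.Value()), true
}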
// Read reads response header from r.
//
// io.EOF is returned if r is closed before reading the first header byte.
func (h *ResponseHeader) Read(r *bufio.Reader) error {
n := 1
for {
err := h.tryRead(r, n)
if err == nil {
return nil
}
if err != errNeedMore {
h.resetSkipNormalize()
return err
}
n = r.Buffered() + 1
}
}
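// exampleReadResponseHeader is an illustrative sketch (hypothetical, not part
// of the package): Read expects a bufio.Reader, which can wrap any io.Reader.
func exampleReadResponseHeader() error {
raw := "HTTP/1.1 200 OK\r\nContent-Length: 0\r\nContent-Type: text/plain\r\n\r\n"
br := bufio.NewReader(bytes.NewReader([]byte(raw)))
var h ResponseHeader
if err := h.Read(br); err != nil { // io.EOF if the reader is empty
return err
}
_ = h.StatusCode() // 200
return nil
}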
func (h *ResponseHeader) tryRead(r *bufio.Reader, n int) error {
h.resetSkipNormalize()
b, err := r.Peek(n)
if len(b) == 0 {
// Return ErrTimeout on any timeout.
if x, ok := err.(interface{ Timeout() bool }); ok && x.Timeout() {
return ErrTimeout
}
// treat all other errors on the first byte read as EOF
if n == 1 || err == io.EOF {
return io.EOF
}
// This works around a Go 1.6 bug. See https://github.com/golang/go/issues/14121 .
if err == bufio.ErrBufferFull {
if h.secureErrorLogMessage {
return &ErrSmallBuffer{
error: errors.New("error when reading response headers"),
}
}
return &ErrSmallBuffer{
error: fmt.Errorf("error when reading response headers: %w", errSmallBuffer),
}
}
return fmt.Errorf("error when reading response headers: %w", err)
}
b = mustPeekBuffered(r)
headersLen, errParse := h.parse(b)
if errParse != nil {
return headerError("response", err, errParse, b, h.secureErrorLogMessage)
}
mustDiscard(r, headersLen)
return nil
}
// ReadTrailer reads response trailer header from r.
//
// io.EOF is returned if r is closed before reading the first byte.
func (h *header) ReadTrailer(r *bufio.Reader) error {
n := 1
for {
err := h.tryReadTrailer(r, n)
if err == nil {
return nil
}
if err != errNeedMore {
return err
}
n = r.Buffered() + 1
}
}
func (h *header) tryReadTrailer(r *bufio.Reader, n int) error {
b, err := r.Peek(n)
if len(b) == 0 {
// Return ErrTimeout on any timeout.
if x, ok := err.(interface{ Timeout() bool }); ok && x.Timeout() {
return ErrTimeout
}
if n == 1 || err == io.EOF {
return io.EOF
}
// This works around a Go 1.6 bug. See https://github.com/golang/go/issues/14121 .
if err == bufio.ErrBufferFull {
if h.secureErrorLogMessage {
return &ErrSmallBuffer{
error: errors.New("error when reading response trailer"),
}
}
return &ErrSmallBuffer{
error: fmt.Errorf("error when reading response trailer: %w", errSmallBuffer),
}
}
return fmt.Errorf("error when reading response trailer: %w", err)
}
b = mustPeekBuffered(r)
trailers, headersLen, errParse := parseTrailer(b, h.h, h.disableNormalizing)
h.h = trailers
if errParse != nil {
if err == io.EOF {
return err
}
return headerError("response", err, errParse, b, h.secureErrorLogMessage)
}
mustDiscard(r, headersLen)
return nil
}
func headerError(typ string, err, errParse error, b []byte, secureErrorLogMessage bool) error {
if errParse != errNeedMore {
return headerErrorMsg(typ, errParse, b, secureErrorLogMessage)
}
if err == nil {
return errNeedMore
}
// Buggy servers may leave trailing CRLFs after http body.
// Treat this case as EOF.
if isOnlyCRLF(b) {
return io.EOF
}
if err != bufio.ErrBufferFull {
return headerErrorMsg(typ, err, b, secureErrorLogMessage)
}
return &ErrSmallBuffer{
error: headerErrorMsg(typ, errSmallBuffer, b, secureErrorLogMessage),
}
}
func headerErrorMsg(typ string, err error, b []byte, secureErrorLogMessage bool) error {
if secureErrorLogMessage {
return fmt.Errorf("error when reading %s headers: %w. Buffer size=%d", typ, err, len(b))
}
return fmt.Errorf("error when reading %s headers: %w. Buffer size=%d, contents: %s", typ, err, len(b), bufferSnippet(b))
}
// Read reads request header from r.
//
// io.EOF is returned if r is closed before reading the first header byte.
func (h *RequestHeader) Read(r *bufio.Reader) error {
return h.readLoop(r, true)
}
// readLoop reads the request header from r, optionally looping until enough data is available.
//
// io.EOF is returned if r is closed before reading the first header byte.
func (h *RequestHeader) readLoop(r *bufio.Reader, waitForMore bool) error {
n := 1
for {
err := h.tryRead(r, n)
if err == nil {
return nil
}
if !waitForMore || err != errNeedMore {
h.resetSkipNormalize()
return err
}
n = r.Buffered() + 1
}
}
func (h *RequestHeader) tryRead(r *bufio.Reader, n int) error {
h.resetSkipNormalize()
b, err := r.Peek(n)
if len(b) == 0 {
if err == io.EOF {
return err
}
if err == nil {
panic("bufio.Reader.Peek() returned nil, nil")
}
// This works around a Go 1.6 bug. See https://github.com/golang/go/issues/14121 .
if err == bufio.ErrBufferFull {
return &ErrSmallBuffer{
error: fmt.Errorf("error when reading request headers: %w (n=%d, r.Buffered()=%d)", errSmallBuffer, n, r.Buffered()),
}
}
// n == 1 on the first read for the request.
if n == 1 {
// We didn't read a single byte.
return ErrNothingRead{error: err}
}
return fmt.Errorf("error when reading request headers: %w", err)
}
b = mustPeekBuffered(r)
headersLen, errParse := h.parse(b)
if errParse != nil {
return headerError("request", err, errParse, b, h.secureErrorLogMessage)
}
mustDiscard(r, headersLen)
return nil
}
func bufferSnippet(b []byte) string {
n := len(b)
start := 200
end := n - start
if start >= end {
start = n
end = n
}
bStart, bEnd := b[:start], b[end:]
if len(bEnd) == 0 {
return fmt.Sprintf("%q", b)
}
return fmt.Sprintf("%q...%q", bStart, bEnd)
}
func isOnlyCRLF(b []byte) bool {
for _, ch := range b {
if ch != rChar && ch != nChar {
return false
}
}
return true
}
func updateServerDate() {
refreshServerDate()
go func() {
for {
time.Sleep(time.Second)
refreshServerDate()
}
}()
}
var (
serverDate atomic.Value
serverDateOnce sync.Once // serverDateOnce.Do(updateServerDate)
)
func refreshServerDate() {
b := AppendHTTPDate(nil, time.Now())
serverDate.Store(b)
}
// Write writes response header to w.
func (h *ResponseHeader) Write(w *bufio.Writer) error {
_, err := w.Write(h.Header())
return err
}
// WriteTo writes response header to w.
//
// WriteTo implements io.WriterTo interface.
func (h *ResponseHeader) WriteTo(w io.Writer) (int64, error) {
n, err := w.Write(h.Header())
return int64(n), err
}
// Header returns response header representation.
//
// Headers set as trailers are not represented here. Use TrailerHeader for trailers.
//
// The returned value is valid until the response is released,
// either through ReleaseResponse or your request handler returning.
// Do not store references to the returned value. Make copies instead.
func (h *ResponseHeader) Header() []byte {
h.bufV = h.AppendBytes(h.bufV[:0])
return h.bufV
}
// writeTrailer writes response trailer to w.
func (h *ResponseHeader) writeTrailer(w *bufio.Writer) error {
_, err := w.Write(h.TrailerHeader())
return err
}
// TrailerHeader returns response trailer header representation.
//
// Trailers will only be received with chunked transfer.
//
// The returned value is valid until the response is released,
// either through ReleaseResponse or your request handler returning.
// Do not store references to the returned value. Make copies instead.
func (h *ResponseHeader) TrailerHeader() []byte {
h.bufV = h.bufV[:0]
for _, t := range h.trailer {
value := h.peek(t)
h.bufV = appendHeaderLine(h.bufV, t, value)
}
h.bufV = append(h.bufV, strCRLF...)
return h.bufV
}
// String returns response header representation.
func (h *ResponseHeader) String() string {
return string(h.Header())
}
// appendStatusLine appends the response status line to dst and returns
// the extended dst.
func (h *ResponseHeader) appendStatusLine(dst []byte) []byte {
statusCode := h.StatusCode()
if statusCode < 0 {
statusCode = StatusOK
}
return formatStatusLine(dst, h.Protocol(), statusCode, h.StatusMessage())
}
// AppendBytes appends response header representation to dst and returns
// the extended dst.
func (h *ResponseHeader) AppendBytes(dst []byte) []byte {
dst = h.appendStatusLine(dst[:0])
server := h.Server()
if len(server) != 0 {
dst = appendHeaderLine(dst, strServer, server)
}
if !h.noDefaultDate {
serverDateOnce.Do(updateServerDate)
dst = appendHeaderLine(dst, strDate, serverDate.Load().([]byte))
}
// Append Content-Type only for non-zero responses
// or if it is explicitly set.
// See https://github.com/valyala/fasthttp/issues/28 .
if h.ContentLength() != 0 || len(h.contentType) > 0 {
contentType := h.ContentType()
if len(contentType) > 0 {
dst = appendHeaderLine(dst, strContentType, contentType)
}
}
contentEncoding := h.ContentEncoding()
if len(contentEncoding) > 0 {
dst = appendHeaderLine(dst, strContentEncoding, contentEncoding)
}
if len(h.contentLengthBytes) > 0 {
dst = appendHeaderLine(dst, strContentLength, h.contentLengthBytes)
}
for i, n := 0, len(h.h); i < n; i++ {
kv := &h.h[i]
// Exclude trailer from header
exclude := false
for _, t := range h.trailer {
if bytes.Equal(kv.key, t) {
exclude = true
break
}
}
if !exclude && (h.noDefaultDate || !bytes.Equal(kv.key, strDate)) {
dst = appendHeaderLine(dst, kv.key, kv.value)
}
}
if len(h.trailer) > 0 {
dst = appendHeaderLine(dst, strTrailer, appendTrailerBytes(nil, h.trailer, strCommaSpace))
}
n := len(h.cookies)
if n > 0 {
for i := 0; i < n; i++ {
kv := &h.cookies[i]
dst = appendHeaderLine(dst, strSetCookie, kv.value)
}
}
if h.ConnectionClose() {
dst = appendHeaderLine(dst, strConnection, strClose)
}
return append(dst, strCRLF...)
}
// Write writes request header to w.
func (h *RequestHeader) Write(w *bufio.Writer) error {
_, err := w.Write(h.Header())
return err
}
// WriteTo writes request header to w.
//
// WriteTo implements io.WriterTo interface.
func (h *RequestHeader) WriteTo(w io.Writer) (int64, error) {
n, err := w.Write(h.Header())
return int64(n), err
}
// Header returns request header representation.
//
// Headers set as trailers are not represented here. Use TrailerHeader for trailers.
//
// The returned value is valid until the request is released,
// either through ReleaseRequest or your request handler returning.
// Do not store references to the returned value. Make copies instead.
func (h *RequestHeader) Header() []byte {
h.bufV = h.AppendBytes(h.bufV[:0])
return h.bufV
}
// writeTrailer writes request trailer to w.
func (h *RequestHeader) writeTrailer(w *bufio.Writer) error {
_, err := w.Write(h.TrailerHeader())
return err
}
// TrailerHeader returns request trailer header representation.
//
// Trailers will only be received with chunked transfer.
//
// The returned value is valid until the request is released,
// either through ReleaseRequest or your request handler returning.
// Do not store references to the returned value. Make copies instead.
func (h *RequestHeader) TrailerHeader() []byte {
h.bufV = h.bufV[:0]
for _, t := range h.trailer {
value := h.peek(t)
h.bufV = appendHeaderLine(h.bufV, t, value)
}
h.bufV = append(h.bufV, strCRLF...)
return h.bufV
}
// RawHeaders returns raw header key/value bytes.
//
// Depending on server configuration, header keys may be normalized to
// capital-case in place.
//
// This copy is set aside during parsing, so an empty slice is returned
// whenever parsing did not happen. Similarly, the request line is not
// stored during parsing and cannot be returned.
//
// The slice is not safe to use after the handler returns.
func (h *RequestHeader) RawHeaders() []byte {
return h.rawHeaders
}
// String returns request header representation.
func (h *RequestHeader) String() string {
return string(h.Header())
}
// AppendBytes appends request header representation to dst and returns
// the extended dst.
func (h *RequestHeader) AppendBytes(dst []byte) []byte {
dst = append(dst, h.Method()...)
dst = append(dst, ' ')
dst = append(dst, h.RequestURI()...)
dst = append(dst, ' ')
dst = append(dst, h.Protocol()...)
dst = append(dst, strCRLF...)
userAgent := h.UserAgent()
if len(userAgent) > 0 && !h.disableSpecialHeader {
dst = appendHeaderLine(dst, strUserAgent, userAgent)
}
host := h.Host()
if len(host) > 0 && !h.disableSpecialHeader {
dst = appendHeaderLine(dst, strHost, host)
}
contentType := h.ContentType()
if !h.noDefaultContentType && len(contentType) == 0 && !h.ignoreBody() {
contentType = strDefaultContentType
}
if len(contentType) > 0 && !h.disableSpecialHeader {
dst = appendHeaderLine(dst, strContentType, contentType)
}
if len(h.contentLengthBytes) > 0 && !h.disableSpecialHeader {
dst = appendHeaderLine(dst, strContentLength, h.contentLengthBytes)
}
for i, n := 0, len(h.h); i < n; i++ {
kv := &h.h[i]
// Exclude trailer from header
exclude := false
for _, t := range h.trailer {
if bytes.Equal(kv.key, t) {
exclude = true
break
}
}
if !exclude {
dst = appendHeaderLine(dst, kv.key, kv.value)
}
}
if len(h.trailer) > 0 {
dst = appendHeaderLine(dst, strTrailer, appendTrailerBytes(nil, h.trailer, strCommaSpace))
}
// There is no need to call h.collectCookies() here: if the cookies haven't
// been collected yet, they are all still located in h.h.
n := len(h.cookies)
if n > 0 && !h.disableSpecialHeader {
dst = append(dst, strCookie...)
dst = append(dst, strColonSpace...)
dst = appendRequestCookieBytes(dst, h.cookies)
dst = append(dst, strCRLF...)
}
if h.ConnectionClose() && !h.disableSpecialHeader {
dst = appendHeaderLine(dst, strConnection, strClose)
}
return append(dst, strCRLF...)
}
func appendHeaderLine(dst, key, value []byte) []byte {
dst = append(dst, key...)
dst = append(dst, strColonSpace...)
dst = append(dst, value...)
return append(dst, strCRLF...)
}
func (h *ResponseHeader) parse(buf []byte) (int, error) {
m, err := h.parseFirstLine(buf)
if err != nil {
return 0, err
}
n, err := h.parseHeaders(buf[m:])
if err != nil {
return 0, err
}
return m + n, nil
}
func (h *RequestHeader) ignoreBody() bool {
return h.IsGet() || h.IsHead()
}
func (h *RequestHeader) parse(buf []byte) (int, error) {
m, err := h.parseFirstLine(buf)
if err != nil {
return 0, err
}
h.rawHeaders, _, err = readRawHeaders(h.rawHeaders[:0], buf[m:])
if err != nil {
return 0, err
}
var n int
n, err = h.parseHeaders(buf[m:])
if err != nil {
return 0, err
}
return m + n, nil
}
func parseTrailer(src []byte, dest []argsKV, disableNormalizing bool) ([]argsKV, int, error) {
// Skip any 0 length chunk.
if src[0] == '0' {
skip := len(strCRLF) + 1
if len(src) < skip {
return dest, 0, io.EOF
}
src = src[skip:]
}
var s headerScanner
s.b = src
for s.next() {
if len(s.key) == 0 {
continue
}
disable := disableNormalizing
for _, ch := range s.key {
if !validHeaderFieldByte(ch) {
// We accept invalid headers with a space before the
// colon, but must not canonicalize them.
// See: https://github.com/valyala/fasthttp/issues/1917
if ch == ' ' {
disable = true
continue
}
return dest, 0, fmt.Errorf("invalid trailer key %q", s.key)
}
}
// Forbidden by RFC 7230, section 4.1.2
if isBadTrailer(s.key) {
return dest, 0, fmt.Errorf("forbidden trailer key %q", s.key)
}
normalizeHeaderKey(s.key, disable)
dest = appendArgBytes(dest, s.key, s.value, argsHasValue)
}
if s.err != nil {
return dest, 0, s.err
}
return dest, s.hLen, nil
}
func isBadTrailer(key []byte) bool {
if len(key) == 0 {
return true
}
switch key[0] | 0x20 {
case 'a':
return caseInsensitiveCompare(key, strAuthorization)
case 'c':
if len(key) > len(HeaderContentType) && caseInsensitiveCompare(key[:8], strContentType[:8]) {
// skip compare prefix 'Content-'
return caseInsensitiveCompare(key[8:], strContentEncoding[8:]) ||
caseInsensitiveCompare(key[8:], strContentLength[8:]) ||
caseInsensitiveCompare(key[8:], strContentType[8:]) ||
caseInsensitiveCompare(key[8:], strContentRange[8:])
}
return caseInsensitiveCompare(key, strConnection)
case 'e':
return caseInsensitiveCompare(key, strExpect)
case 'h':
return caseInsensitiveCompare(key, strHost)
case 'k':
return caseInsensitiveCompare(key, strKeepAlive)
case 'm':
return caseInsensitiveCompare(key, strMaxForwards)
case 'p':
if len(key) > len(HeaderProxyConnection) && caseInsensitiveCompare(key[:6], strProxyConnection[:6]) {
// skip compare prefix 'Proxy-'
return caseInsensitiveCompare(key[6:], strProxyConnection[6:]) ||
caseInsensitiveCompare(key[6:], strProxyAuthenticate[6:]) ||
caseInsensitiveCompare(key[6:], strProxyAuthorization[6:])
}
case 'r':
return caseInsensitiveCompare(key, strRange)
case 't':
return caseInsensitiveCompare(key, strTE) ||
caseInsensitiveCompare(key, strTrailer) ||
caseInsensitiveCompare(key, strTransferEncoding)
case 'w':
return caseInsensitiveCompare(key, strWWWAuthenticate)
}
return false
}
func (h *ResponseHeader) parseFirstLine(buf []byte) (int, error) {
bNext := buf
var b []byte
var err error
for len(b) == 0 {
if b, bNext, err = nextLine(bNext); err != nil {
return 0, err
}
}
// parse protocol
n := bytes.IndexByte(b, ' ')
if n < 0 {
if h.secureErrorLogMessage {
return 0, errors.New("cannot find whitespace in the first line of response")
}
return 0, fmt.Errorf("cannot find whitespace in the first line of response %q", buf)
}
h.noHTTP11 = !bytes.Equal(b[:n], strHTTP11)
b = b[n+1:]
// parse status code
h.statusCode, n, err = parseUintBuf(b)
if err != nil {
if h.secureErrorLogMessage {
return 0, fmt.Errorf("cannot parse response status code: %w", err)
}
return 0, fmt.Errorf("cannot parse response status code: %w. Response %q", err, buf)
}
if len(b) > n && b[n] != ' ' {
if h.secureErrorLogMessage {
return 0, errors.New("unexpected char at the end of status code")
}
return 0, fmt.Errorf("unexpected char at the end of status code. Response %q", buf)
}
if len(b) > n+1 {
h.SetStatusMessage(b[n+1:])
}
return len(buf) - len(bNext), nil
}
func isValidMethod(method []byte) bool {
for _, ch := range method {
if validMethodValueByteTable[ch] == 0 {
return false
}
}
return true
}
func (h *RequestHeader) parseFirstLine(buf []byte) (int, error) {
bNext := buf
var b []byte
var err error
for len(b) == 0 {
if b, bNext, err = nextLine(bNext); err != nil {
return 0, err
}
}
// parse method
n := bytes.IndexByte(b, ' ')
if n <= 0 {
if h.secureErrorLogMessage {
return 0, errors.New("cannot find http request method")
}
return 0, fmt.Errorf("cannot find http request method in %q", buf)
}
h.method = append(h.method[:0], b[:n]...)
if !isValidMethod(h.method) {
if h.secureErrorLogMessage {
return 0, errors.New("unsupported http request method")
}
return 0, fmt.Errorf("unsupported http request method %q in %q", h.method, buf)
}
b = b[n+1:]
// parse requestURI
n = bytes.LastIndexByte(b, ' ')
if n < 0 {
return 0, fmt.Errorf("cannot find whitespace in the first line of request %q", buf)
} else if n == 0 {
if h.secureErrorLogMessage {
return 0, errors.New("requestURI cannot be empty")
}
return 0, fmt.Errorf("requestURI cannot be empty in %q", buf)
}
protoStr := b[n+1:]
// Follow RFCs 7230 and 9112 and require that HTTP versions match the following pattern: HTTP/[0-9]\.[0-9]
if len(protoStr) != len(strHTTP11) {
if h.secureErrorLogMessage {
return 0, fmt.Errorf("unsupported HTTP version %q", protoStr)
}
return 0, fmt.Errorf("unsupported HTTP version %q in %q", protoStr, buf)
}
if !bytes.HasPrefix(protoStr, strHTTP11[:5]) {
if h.secureErrorLogMessage {
return 0, fmt.Errorf("unsupported HTTP version %q", protoStr)
}
return 0, fmt.Errorf("unsupported HTTP version %q in %q", protoStr, buf)
}
if protoStr[5] < '0' || protoStr[5] > '9' || protoStr[7] < '0' || protoStr[7] > '9' {
if h.secureErrorLogMessage {
return 0, fmt.Errorf("unsupported HTTP version %q", protoStr)
}
return 0, fmt.Errorf("unsupported HTTP version %q in %q", protoStr, buf)
}
h.noHTTP11 = !bytes.Equal(protoStr, strHTTP11)
h.protocol = append(h.protocol[:0], protoStr...)
h.requestURI = append(h.requestURI[:0], b[:n]...)
return len(buf) - len(bNext), nil
}
func readRawHeaders(dst, buf []byte) ([]byte, int, error) {
n := bytes.IndexByte(buf, nChar)
if n < 0 {
return dst[:0], 0, errNeedMore
}
if (n == 1 && buf[0] == rChar) || n == 0 {
// empty headers
return dst, n + 1, nil
}
n++
b := buf
m := n
for {
b = b[m:]
m = bytes.IndexByte(b, nChar)
if m < 0 {
return dst, 0, errNeedMore
}
m++
n += m
if (m == 2 && b[0] == rChar) || m == 1 {
dst = append(dst, buf[:n]...)
return dst, n, nil
}
}
}
func (h *ResponseHeader) parseHeaders(buf []byte) (int, error) {
// 'identity' content-length by default
h.contentLength = -2
var s headerScanner
s.b = buf
var kv *argsKV
for s.next() {
if len(s.key) == 0 {
h.connectionClose = true
return 0, fmt.Errorf("invalid header key %q", s.key)
}
disableNormalizing := h.disableNormalizing
for _, ch := range s.key {
if !validHeaderFieldByte(ch) {
h.connectionClose = true
// We accept invalid headers with a space before the
// colon, but must not canonicalize them.
// See: https://github.com/valyala/fasthttp/issues/1917
if ch == ' ' {
disableNormalizing = true
continue
}
return 0, fmt.Errorf("invalid header key %q", s.key)
}
}
normalizeHeaderKey(s.key, disableNormalizing)
for _, ch := range s.value {
if !validHeaderValueByte(ch) {
h.connectionClose = true
return 0, fmt.Errorf("invalid header value %q", s.value)
}
}
switch s.key[0] | 0x20 {
case 'c':
if caseInsensitiveCompare(s.key, strContentType) {
h.contentType = append(h.contentType[:0], s.value...)
continue
}
if caseInsensitiveCompare(s.key, strContentEncoding) {
h.contentEncoding = append(h.contentEncoding[:0], s.value...)
continue
}
if caseInsensitiveCompare(s.key, strContentLength) {
if h.contentLength != -1 {
var err error
h.contentLength, err = parseContentLength(s.value)
if err != nil {
h.contentLength = -2
h.connectionClose = true
return 0, err
}
h.contentLengthBytes = append(h.contentLengthBytes[:0], s.value...)
}
continue
}
if caseInsensitiveCompare(s.key, strConnection) {
if bytes.Equal(s.value, strClose) {
h.connectionClose = true
} else {
h.connectionClose = false
h.h = appendArgBytes(h.h, s.key, s.value, argsHasValue)
}
continue
}
case 's':
if caseInsensitiveCompare(s.key, strServer) {
h.server = append(h.server[:0], s.value...)
continue
}
if caseInsensitiveCompare(s.key, strSetCookie) {
h.cookies, kv = allocArg(h.cookies)
kv.key = getCookieKey(kv.key, s.value)
kv.value = append(kv.value[:0], s.value...)
continue
}
case 't':
if caseInsensitiveCompare(s.key, strTransferEncoding) {
if len(s.value) > 0 && !bytes.Equal(s.value, strIdentity) {
h.contentLength = -1
h.h = setArgBytes(h.h, strTransferEncoding, strChunked, argsHasValue)
}
continue
}
if caseInsensitiveCompare(s.key, strTrailer) {
err := h.SetTrailerBytes(s.value)
if err != nil {
h.connectionClose = true
return 0, err
}
continue
}
}
h.h = appendArgBytes(h.h, s.key, s.value, argsHasValue)
}
if s.err != nil {
h.connectionClose = true
return 0, s.err
}
if h.contentLength < 0 {
h.contentLengthBytes = h.contentLengthBytes[:0]
}
if h.contentLength == -2 && !h.ConnectionUpgrade() && !h.mustSkipContentLength() {
// According to modern HTTP/1.1 specifications (RFC 7230):
// `identity` as a value for `Transfer-Encoding` was removed
// in the errata to RFC 2616.
// Therefore, we do not include `Transfer-Encoding: identity` in the header.
// See: https://github.com/valyala/fasthttp/issues/1909
h.connectionClose = true
}
if h.noHTTP11 && !h.connectionClose {
// close connection for non-http/1.1 response unless 'Connection: keep-alive' is set.
v := peekArgBytes(h.h, strConnection)
h.connectionClose = !hasHeaderValue(v, strKeepAlive)
}
return len(buf) - len(s.b), nil
}
func (h *RequestHeader) parseHeaders(buf []byte) (int, error) {
h.contentLength = -2
contentLengthSeen := false
var s headerScanner
s.b = buf
for s.next() {
if len(s.key) == 0 {
h.connectionClose = true
return 0, fmt.Errorf("invalid header key %q", s.key)
}
disableNormalizing := h.disableNormalizing
for _, ch := range s.key {
if !validHeaderFieldByte(ch) {
if ch == ' ' {
disableNormalizing = true
continue
}
h.connectionClose = true
return 0, fmt.Errorf("invalid header key %q", s.key)
}
}
normalizeHeaderKey(s.key, disableNormalizing)
for _, ch := range s.value {
if !validHeaderValueByte(ch) {
h.connectionClose = true
return 0, fmt.Errorf("invalid header value %q", s.value)
}
}
if h.disableSpecialHeader {
h.h = appendArgBytes(h.h, s.key, s.value, argsHasValue)
continue
}
switch s.key[0] | 0x20 {
case 'h':
if caseInsensitiveCompare(s.key, strHost) {
h.host = append(h.host[:0], s.value...)
continue
}
case 'u':
if caseInsensitiveCompare(s.key, strUserAgent) {
h.userAgent = append(h.userAgent[:0], s.value...)
continue
}
case 'c':
if caseInsensitiveCompare(s.key, strContentType) {
h.contentType = append(h.contentType[:0], s.value...)
continue
}
if caseInsensitiveCompare(s.key, strContentLength) {
if contentLengthSeen {
h.connectionClose = true
return 0, errors.New("duplicate Content-Length header")
}
contentLengthSeen = true
if h.contentLength != -1 {
var err error
h.contentLength, err = parseContentLength(s.value)
if err != nil {
h.contentLength = -2
h.connectionClose = true
return 0, err
}
h.contentLengthBytes = append(h.contentLengthBytes[:0], s.value...)
}
continue
}
if caseInsensitiveCompare(s.key, strConnection) {
if bytes.Equal(s.value, strClose) {
h.connectionClose = true
} else {
h.connectionClose = false
h.h = appendArgBytes(h.h, s.key, s.value, argsHasValue)
}
continue
}
case 't':
if caseInsensitiveCompare(s.key, strTransferEncoding) {
isIdentity := caseInsensitiveCompare(s.value, strIdentity)
isChunked := caseInsensitiveCompare(s.value, strChunked)
if !isIdentity && !isChunked {
h.connectionClose = true
if h.secureErrorLogMessage {
return 0, errors.New("unsupported Transfer-Encoding")
}
return 0, fmt.Errorf("unsupported Transfer-Encoding: %q", s.value)
}
if isChunked {
h.contentLength = -1
h.h = setArgBytes(h.h, strTransferEncoding, strChunked, argsHasValue)
}
continue
}
if caseInsensitiveCompare(s.key, strTrailer) {
err := h.SetTrailerBytes(s.value)
if err != nil {
h.connectionClose = true
return 0, err
}
continue
}
}
h.h = appendArgBytes(h.h, s.key, s.value, argsHasValue)
}
if s.err != nil {
h.connectionClose = true
return 0, s.err
}
if h.contentLength < 0 {
h.contentLengthBytes = h.contentLengthBytes[:0]
}
if h.noHTTP11 && !h.connectionClose {
// close connection for non-http/1.1 request unless 'Connection: keep-alive' is set.
v := peekArgBytes(h.h, strConnection)
h.connectionClose = !hasHeaderValue(v, strKeepAlive)
}
return s.hLen, nil
}
func (h *RequestHeader) collectCookies() {
if h.cookiesCollected {
return
}
for i, n := 0, len(h.h); i < n; i++ {
kv := &h.h[i]
if caseInsensitiveCompare(kv.key, strCookie) {
h.cookies = parseRequestCookies(h.cookies, kv.value)
tmp := *kv
copy(h.h[i:], h.h[i+1:])
n--
i--
h.h[n] = tmp
h.h = h.h[:n]
}
}
h.cookiesCollected = true
}
var errNonNumericChars = errors.New("non-numeric chars found")
func parseContentLength(b []byte) (int, error) {
v, n, err := parseUintBuf(b)
if err != nil {
return -1, fmt.Errorf("cannot parse Content-Length: %w", err)
}
if n != len(b) {
return -1, fmt.Errorf("cannot parse Content-Length: %w", errNonNumericChars)
}
return v, nil
}
type headerScanner struct {
err error
b []byte
key []byte
value []byte
// hLen stores header subslice len
hLen int
// By checking whether the next line contains a colon we can tell whether
// it is a new header entry or a multi-line value of the current entry.
// As a side effect we learn the indexes of the next colon and newline,
// which can be reused during the next iteration instead of being
// searched for again.
nextColon int
nextNewLine int
initialized bool
// This is only used to print the deprecated newline separator warning.
// TODO: Remove this again once the newline separator is removed.
warned bool
}
// DeprecatedNewlineIncludeContext is used to control whether the context of the
// header is included in the warning message about the deprecated newline
// separator.
// Warning: this can potentially leak sensitive information such as auth headers.
var DeprecatedNewlineIncludeContext atomic.Bool
// TODO: Remove this again once the newline separator is removed.
var warnedAboutDeprecatedNewlineSeparatorLimiter atomic.Int64
func (s *headerScanner) next() bool {
if !s.initialized {
s.nextColon = -1
s.nextNewLine = -1
s.initialized = true
}
bLen := len(s.b)
if bLen >= 2 && s.b[0] == rChar && s.b[1] == nChar {
s.b = s.b[2:]
s.hLen += 2
return false
}
if bLen >= 1 && s.b[0] == nChar {
s.b = s.b[1:]
s.hLen++
return false
}
var n int
if s.nextColon >= 0 {
n = s.nextColon
s.nextColon = -1
} else {
n = bytes.IndexByte(s.b, ':')
// There can't be a \n inside the header name, check for this.
x := bytes.IndexByte(s.b, nChar)
if x < 0 {
// A header name should always at some point be followed by a \n
// even if it's the one that terminates the header block.
s.err = errNeedMore
return false
}
if x < n {
// There was a \n before the :
s.err = errInvalidName
return false
}
// If the character before '\n' isn't '\r', print a warning.
if !s.warned && x > 1 && s.b[x-1] != rChar {
// Only warn once per second.
now := time.Now().Unix()
if warnedAboutDeprecatedNewlineSeparatorLimiter.Load() < now {
if warnedAboutDeprecatedNewlineSeparatorLimiter.Swap(now) < now {
if DeprecatedNewlineIncludeContext.Load() {
// Include 20 characters after the '\n'.
xx := x + 20
if len(s.b) < xx {
xx = len(s.b)
}
slog.Error("Deprecated newline only separator found in header", "context", fmt.Sprintf("%q", s.b[:xx]))
} else {
slog.Error("Deprecated newline only separator found in header")
}
s.warned = true
}
}
}
}
if n < 0 {
s.err = errNeedMore
return false
}
s.key = s.b[:n]
n++
for len(s.b) > n && (s.b[n] == ' ' || s.b[n] == '\t') {
n++
// The newline index is relative to s.b, and the lines below trim s.b
// by n, so the relative newline index shifts as well. It is safe for it
// to go negative: a negative value means it is invalid and the newline
// will be searched for again.
s.nextNewLine--
}
s.hLen += n
s.b = s.b[n:]
if s.nextNewLine >= 0 {
n = s.nextNewLine
s.nextNewLine = -1
} else {
n = bytes.IndexByte(s.b, nChar)
}
if n < 0 {
s.err = errNeedMore
return false
}
for n+1 < len(s.b) {
if s.b[n+1] != ' ' && s.b[n+1] != '\t' {
break
}
d := bytes.IndexByte(s.b[n+1:], nChar)
if d <= 0 {
break
} else if d == 1 && s.b[n+1] == rChar {
break
}
e := n + d + 1
if c := bytes.IndexByte(s.b[n+1:e], ':'); c >= 0 {
s.nextColon = c
s.nextNewLine = d - c - 1
break
}
n = e
}
if n >= len(s.b) {
s.err = errNeedMore
return false
}
s.value = s.b[:n]
s.hLen += n + 1
s.b = s.b[n+1:]
if n > 0 && s.value[n-1] == rChar {
n--
}
for n > 0 && (s.value[n-1] == ' ' || s.value[n-1] == '\t') {
n--
}
s.value = s.value[:n]
if bytes.Contains(s.b, strCRLF) {
s.value = normalizeHeaderValue(s.value)
}
return true
}
type headerValueScanner struct {
b []byte
value []byte
}
func (s *headerValueScanner) next() bool {
b := s.b
if len(b) == 0 {
return false
}
n := bytes.IndexByte(b, ',')
if n < 0 {
s.value = stripSpace(b)
s.b = b[len(b):]
return true
}
s.value = stripSpace(b[:n])
s.b = b[n+1:]
return true
}
func stripSpace(b []byte) []byte {
for len(b) > 0 && b[0] == ' ' {
b = b[1:]
}
for len(b) > 0 && b[len(b)-1] == ' ' {
b = b[:len(b)-1]
}
return b
}
func hasHeaderValue(s, value []byte) bool {
var vs headerValueScanner
vs.b = s
for vs.next() {
if caseInsensitiveCompare(vs.value, value) {
return true
}
}
return false
}
func nextLine(b []byte) ([]byte, []byte, error) {
nNext := bytes.IndexByte(b, nChar)
if nNext < 0 {
return nil, nil, errNeedMore
}
n := nNext
if n > 0 && b[n-1] == rChar {
n--
}
return b[:n], b[nNext+1:], nil
}
func initHeaderKV(bufK, bufV []byte, key, value string, disableNormalizing bool) ([]byte, []byte) {
bufK = getHeaderKeyBytes(bufK, key, disableNormalizing)
// https://tools.ietf.org/html/rfc7230#section-3.2.4
bufV = append(bufV[:0], value...)
bufV = removeNewLines(bufV)
return bufK, bufV
}
func getHeaderKeyBytes(bufK []byte, key string, disableNormalizing bool) []byte {
bufK = append(bufK[:0], key...)
normalizeHeaderKey(bufK, disableNormalizing || bytes.IndexByte(bufK, ' ') != -1)
return bufK
}
func normalizeHeaderValue(ov []byte) (nv []byte) {
nv = ov
length := len(ov)
if length <= 0 {
return
}
write := 0
shrunk := 0
once := false
lineStart := false
for read := 0; read < length; read++ {
c := ov[read]
switch {
case c == rChar || c == nChar:
shrunk++
if c == nChar {
lineStart = true
once = false
}
continue
case lineStart && (c == '\t' || c == ' '):
if !once {
c = ' '
once = true
} else {
shrunk++
continue
}
default:
lineStart = false
}
nv[write] = c
write++
}
nv = nv[:length-shrunk]
return
}
func normalizeHeaderKey(b []byte, disableNormalizing bool) {
if disableNormalizing {
return
}
n := len(b)
if n == 0 {
return
}
b[0] = toUpperTable[b[0]]
for i := 1; i < n; i++ {
p := &b[i]
if *p == '-' {
i++
if i < n {
b[i] = toUpperTable[b[i]]
}
continue
}
*p = toLowerTable[*p]
}
}
// removeNewLines replaces `\r` and `\n` with a space.
func removeNewLines(raw []byte) []byte {
// Check whether a `\r` is present and save its position.
// If no `\r` is found, check whether a `\n` is present.
foundR := bytes.IndexByte(raw, rChar)
foundN := bytes.IndexByte(raw, nChar)
start := 0
switch {
case foundN != -1:
if foundR > foundN {
start = foundN
} else if foundR != -1 {
start = foundR
}
case foundR != -1:
start = foundR
default:
return raw
}
for i := start; i < len(raw); i++ {
switch raw[i] {
case rChar, nChar:
raw[i] = ' '
default:
continue
}
}
return raw
}
// AppendNormalizedHeaderKey appends normalized header key (name) to dst
// and returns the resulting dst.
//
// Normalized header key starts with uppercase letter. The first letters
// after dashes are also uppercased. All the other letters are lowercased.
// Examples:
//
// - coNTENT-TYPe -> Content-Type
// - HOST -> Host
// - foo-bar-baz -> Foo-Bar-Baz
func AppendNormalizedHeaderKey(dst []byte, key string) []byte {
dst = append(dst, key...)
normalizeHeaderKey(dst[len(dst)-len(key):], false)
return dst
}
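// exampleNormalizeKey is an illustrative sketch (hypothetical, not part of
// the package) of AppendNormalizedHeaderKey's canonicalization rules.
func exampleNormalizeKey() {
dst := AppendNormalizedHeaderKey(nil, "coNTENT-TYPe")
// dst now holds "Content-Type": the first letter and letters after
// dashes are uppercased, everything else is lowercased.
_ = dst
}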
// AppendNormalizedHeaderKeyBytes appends normalized header key (name) to dst
// and returns the resulting dst.
//
// Normalized header key starts with uppercase letter. The first letters
// after dashes are also uppercased. All the other letters are lowercased.
// Examples:
//
// - coNTENT-TYPe -> Content-Type
// - HOST -> Host
// - foo-bar-baz -> Foo-Bar-Baz
func AppendNormalizedHeaderKeyBytes(dst, key []byte) []byte {
return AppendNormalizedHeaderKey(dst, b2s(key))
}
func appendTrailerBytes(dst []byte, trailer [][]byte, sep []byte) []byte {
for i, n := 0, len(trailer); i < n; i++ {
dst = append(dst, trailer[i]...)
if i+1 < n {
dst = append(dst, sep...)
}
}
return dst
}
func copyTrailer(dst, src [][]byte) [][]byte {
if cap(dst) > len(src) {
dst = dst[:len(src)]
} else {
dst = append(dst[:0], src...)
}
for i := range dst {
dst[i] = make([]byte, len(src[i]))
copy(dst[i], src[i])
}
return dst
}
var (
errNeedMore = errors.New("need more data: cannot find trailing lf")
errInvalidName = errors.New("invalid header name")
errSmallBuffer = errors.New("small read buffer. Increase ReadBufferSize")
)
// ErrNothingRead is returned when a keep-alive connection is closed,
// either because the remote closed it or because of a read timeout.
type ErrNothingRead struct {
error
}
// ErrSmallBuffer is returned when the provided buffer size is too small
// for reading request and/or response headers.
//
// Increasing the ReadBufferSize value on the Server or clients should
// reduce the number of such errors.
type ErrSmallBuffer struct {
error
}
func mustPeekBuffered(r *bufio.Reader) []byte {
buf, err := r.Peek(r.Buffered())
if len(buf) == 0 || err != nil {
panic(fmt.Sprintf("bufio.Reader.Peek() returned unexpected data (%q, %v)", buf, err))
}
return buf
}
func mustDiscard(r *bufio.Reader, n int) {
if _, err := r.Discard(n); err != nil {
panic(fmt.Sprintf("bufio.Reader.Discard(%d) failed: %v", n, err))
}
}
package fasthttp
import (
"bufio"
"bytes"
"compress/gzip"
"encoding/base64"
"errors"
"fmt"
"io"
"mime/multipart"
"net"
"os"
"sync"
"time"
"github.com/valyala/bytebufferpool"
)
var (
requestBodyPoolSizeLimit = -1
responseBodyPoolSizeLimit = -1
)
// SetBodySizePoolLimit sets the maximum body size for bodies to be returned to the pool.
// If a body is larger, it is released instead of being put back into the pool for reuse.
func SetBodySizePoolLimit(reqBodyLimit, respBodyLimit int) {
requestBodyPoolSizeLimit = reqBodyLimit
responseBodyPoolSizeLimit = respBodyLimit
}
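// examplePoolLimit is an illustrative sketch (hypothetical, not part of the
// package): capping pooled body buffers so rare huge bodies don't pin large
// allocations in the pool. The 4 KB limits are made-up values.
func examplePoolLimit() {
SetBodySizePoolLimit(4*1024, 4*1024) // request and response limits
}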
// Request represents HTTP request.
//
// It is forbidden copying Request instances. Create new instances
// and use CopyTo instead.
//
// Request instance MUST NOT be used from concurrently running goroutines.
type Request struct {
noCopy noCopy
bodyStream io.Reader
w requestBodyWriter
body *bytebufferpool.ByteBuffer
multipartForm *multipart.Form
multipartFormBoundary string
postArgs Args
userValues userData
bodyRaw []byte
uri URI
// Request header.
//
// Copying Header by value is forbidden. Use pointer to Header instead.
Header RequestHeader
// Request timeout. Usually set by DoDeadline or DoTimeout.
// If <= 0, the timeout is not set.
timeout time.Duration
secureErrorLogMessage bool
// Group bool members in order to reduce Request object size.
parsedURI bool
parsedPostArgs bool
keepBodyBuffer bool
// Used by Server to indicate the request was received on an HTTPS endpoint.
// Client/HostClient shouldn't use this field but should rely on uri.scheme instead.
isTLS bool
// Use the Host header (request.Header.SetHost) instead of the host from SetRequestURI, SetHost, or URI().SetHost.
UseHostHeader bool
// DisableRedirectPathNormalizing disables redirect path normalization when used with DoRedirects.
//
// By default redirect path values are normalized, i.e.
// extra slashes are removed, special characters are encoded.
DisableRedirectPathNormalizing bool
}
// Response represents HTTP response.
//
// It is forbidden copying Response instances. Create new instances
// and use CopyTo instead.
//
// Response instance MUST NOT be used from concurrently running goroutines.
type Response struct {
noCopy noCopy
bodyStream io.Reader
// Remote TCPAddr copied from the underlying net.Conn.
raddr net.Addr
// Local TCPAddr copied from the underlying net.Conn.
laddr net.Addr
w responseBodyWriter
body *bytebufferpool.ByteBuffer
bodyRaw []byte
// Response header.
//
// Copying Header by value is forbidden. Use pointer to Header instead.
Header ResponseHeader
// Flush headers as soon as possible without waiting for first body bytes.
// Relevant for bodyStream only.
ImmediateHeaderFlush bool
// StreamBody enables response body streaming.
// Use SetBodyStream to set the body stream.
StreamBody bool
// Response.Read() skips reading body if set to true.
// Use it for reading HEAD responses.
//
// Response.Write() skips writing body if set to true.
// Use it for writing HEAD responses.
SkipBody bool
keepBodyBuffer bool
secureErrorLogMessage bool
}
// SetHost sets host for the request.
func (req *Request) SetHost(host string) {
req.URI().SetHost(host)
}
// SetHostBytes sets host for the request.
func (req *Request) SetHostBytes(host []byte) {
req.URI().SetHostBytes(host)
}
// Host returns the host for the given request.
func (req *Request) Host() []byte {
return req.URI().Host()
}
// SetRequestURI sets RequestURI.
func (req *Request) SetRequestURI(requestURI string) {
req.Header.SetRequestURI(requestURI)
req.parsedURI = false
}
// SetRequestURIBytes sets RequestURI.
func (req *Request) SetRequestURIBytes(requestURI []byte) {
req.Header.SetRequestURIBytes(requestURI)
req.parsedURI = false
}
// RequestURI returns request's URI.
func (req *Request) RequestURI() []byte {
if req.parsedURI {
requestURI := req.uri.RequestURI()
req.SetRequestURIBytes(requestURI)
}
return req.Header.RequestURI()
}
// StatusCode returns response status code.
func (resp *Response) StatusCode() int {
return resp.Header.StatusCode()
}
// SetStatusCode sets response status code.
func (resp *Response) SetStatusCode(statusCode int) {
resp.Header.SetStatusCode(statusCode)
}
// ConnectionClose returns true if 'Connection: close' header is set.
func (resp *Response) ConnectionClose() bool {
return resp.Header.ConnectionClose()
}
// SetConnectionClose sets 'Connection: close' header.
func (resp *Response) SetConnectionClose() {
resp.Header.SetConnectionClose()
}
// ConnectionClose returns true if 'Connection: close' header is set.
func (req *Request) ConnectionClose() bool {
return req.Header.ConnectionClose()
}
// SetConnectionClose sets 'Connection: close' header.
func (req *Request) SetConnectionClose() {
req.Header.SetConnectionClose()
}
// GetTimeOut retrieves the timeout duration set for the Request.
//
// This method returns a time.Duration that determines how long the request
// can wait before it times out. In the default use case, the timeout applies
// to the entire request lifecycle, including both receiving the response
// headers and the response body.
func (req *Request) GetTimeOut() time.Duration {
return req.timeout
}
// SendFile registers file on the given path to be used as response body
// when Write is called.
//
// Note that SendFile doesn't set Content-Type, so set it yourself
// with Header.SetContentType.
func (resp *Response) SendFile(path string) error {
f, err := os.Open(path)
if err != nil {
return err
}
fileInfo, err := f.Stat()
if err != nil {
f.Close()
return err
}
size64 := fileInfo.Size()
size := int(size64)
if int64(size) != size64 {
size = -1
}
resp.Header.SetLastModified(fileInfo.ModTime())
resp.SetBodyStream(f, size)
return nil
}
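// exampleSendFile is an illustrative sketch (hypothetical, not part of the
// package): SendFile leaves Content-Type untouched, so set it yourself.
// The file path is made up.
func exampleSendFile(resp *Response) error {
if err := resp.SendFile("/var/www/report.pdf"); err != nil {
return err
}
resp.Header.SetContentType("application/pdf") // not set by SendFile
return nil
}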
// SetBodyStream sets the request body stream and, optionally, the body size.
//
// If bodySize is >= 0, then the bodyStream must provide exactly bodySize bytes
// before returning io.EOF.
//
// If bodySize < 0, then bodyStream is read until io.EOF.
//
// bodyStream.Close() is called after finishing reading all body data
// if it implements io.Closer.
//
// Note that GET and HEAD requests cannot have a body.
//
// See also SetBodyStreamWriter.
func (req *Request) SetBodyStream(bodyStream io.Reader, bodySize int) {
req.ResetBody()
req.bodyStream = bodyStream
req.Header.SetContentLength(bodySize)
}
// SetBodyStream sets the response body stream and, optionally, the body size.
//
// If bodySize is >= 0, then the bodyStream must provide exactly bodySize bytes
// before returning io.EOF.
//
// If bodySize < 0, then bodyStream is read until io.EOF.
//
// bodyStream.Close() is called after finishing reading all body data
// if it implements io.Closer.
//
// See also SetBodyStreamWriter.
func (resp *Response) SetBodyStream(bodyStream io.Reader, bodySize int) {
resp.ResetBody()
resp.bodyStream = bodyStream
resp.Header.SetContentLength(bodySize)
}
// IsBodyStream returns true if body is set via SetBodyStream*.
func (req *Request) IsBodyStream() bool {
return req.bodyStream != nil
}
// IsBodyStream returns true if body is set via SetBodyStream*.
func (resp *Response) IsBodyStream() bool {
return resp.bodyStream != nil
}
// SetBodyStreamWriter registers the given sw for populating request body.
//
// This function may be used in the following cases:
//
// - if request body is too big (more than 10MB).
// - if request body is streamed from slow external sources.
// - if request body must be streamed to the server in chunks
// (aka `http client push` or `chunked transfer-encoding`).
//
// Note that GET and HEAD requests cannot have a body.
//
// See also SetBodyStream.
func (req *Request) SetBodyStreamWriter(sw StreamWriter) {
sr := NewStreamReader(sw)
req.SetBodyStream(sr, -1)
}
// SetBodyStreamWriter registers the given sw for populating response body.
//
// This function may be used in the following cases:
//
// - if response body is too big (more than 10MB).
// - if response body is streamed from slow external sources.
// - if response body must be streamed to the client in chunks
// (aka `http server push` or `chunked transfer-encoding`).
//
// See also SetBodyStream.
func (resp *Response) SetBodyStreamWriter(sw StreamWriter) {
sr := NewStreamReader(sw)
resp.SetBodyStream(sr, -1)
}
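// exampleStreamWriter is an illustrative sketch (hypothetical, not part of
// the package): generating a chunked response body on the fly instead of
// buffering it in memory.
func exampleStreamWriter(resp *Response) {
resp.SetBodyStreamWriter(func(w *bufio.Writer) {
for i := 0; i < 3; i++ {
fmt.Fprintf(w, "chunk %d\n", i)
w.Flush() // push each chunk to the client immediately
}
})
}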
// BodyWriter returns writer for populating response body.
//
// If used inside RequestHandler, the returned writer must not be used
// after returning from RequestHandler. Use RequestCtx.Write
// or SetBodyStreamWriter in this case.
func (resp *Response) BodyWriter() io.Writer {
resp.w.r = resp
return &resp.w
}
// BodyStream returns io.Reader.
//
// You must call CloseBodyStream or ReleaseRequest after you are done with it.
func (req *Request) BodyStream() io.Reader {
return req.bodyStream
}
func (req *Request) CloseBodyStream() error {
return req.closeBodyStream()
}
// BodyStream returns io.Reader.
//
// You must call CloseBodyStream or ReleaseResponse after you are done with it.
func (resp *Response) BodyStream() io.Reader {
return resp.bodyStream
}
func (resp *Response) CloseBodyStream() error {
return resp.closeBodyStream(nil)
}
type ReadCloserWithError interface {
io.Reader
CloseWithError(err error) error
}
type closeReader struct {
io.Reader
closeFunc func(err error) error
}
func newCloseReaderWithError(r io.Reader, closeFunc func(err error) error) ReadCloserWithError {
if r == nil {
panic(`BUG: reader is nil`)
}
return &closeReader{Reader: r, closeFunc: closeFunc}
}
func (c *closeReader) CloseWithError(err error) error {
if c.closeFunc == nil {
return nil
}
return c.closeFunc(err)
}
// BodyWriter returns writer for populating request body.
func (req *Request) BodyWriter() io.Writer {
req.w.r = req
return &req.w
}
type responseBodyWriter struct {
r *Response
}
func (w *responseBodyWriter) Write(p []byte) (int, error) {
w.r.AppendBody(p)
return len(p), nil
}
func (w *responseBodyWriter) WriteString(s string) (int, error) {
w.r.AppendBodyString(s)
return len(s), nil
}
type requestBodyWriter struct {
r *Request
}
func (w *requestBodyWriter) Write(p []byte) (int, error) {
w.r.AppendBody(p)
return len(p), nil
}
func (w *requestBodyWriter) WriteString(s string) (int, error) {
w.r.AppendBodyString(s)
return len(s), nil
}
func (resp *Response) ParseNetConn(conn net.Conn) {
resp.raddr = conn.RemoteAddr()
resp.laddr = conn.LocalAddr()
}
// RemoteAddr returns the remote network address. The Addr returned is shared
// by all invocations of RemoteAddr, so do not modify it.
func (resp *Response) RemoteAddr() net.Addr {
return resp.raddr
}
// LocalAddr returns the local network address. The Addr returned is shared
// by all invocations of LocalAddr, so do not modify it.
func (resp *Response) LocalAddr() net.Addr {
return resp.laddr
}
// Body returns response body.
//
// The returned value is valid until the response is released,
// either through ReleaseResponse or your request handler returning.
// Do not store references to the returned value. Make copies instead.
func (resp *Response) Body() []byte {
if resp.bodyStream != nil {
bodyBuf := resp.bodyBuffer()
bodyBuf.Reset()
_, err := copyZeroAlloc(bodyBuf, resp.bodyStream)
resp.closeBodyStream(err) //nolint:errcheck
if err != nil {
bodyBuf.SetString(err.Error())
}
}
return resp.bodyBytes()
}
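// exampleBodyCopy is an illustrative sketch (hypothetical, not part of the
// package): Body returns a slice backed by pooled buffers, so copy it before
// releasing the response.
func exampleBodyCopy(resp *Response) []byte {
b := resp.Body() // also drains a body stream, if any
return append([]byte(nil), b...) // detach from the pooled buffer
}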
func (resp *Response) bodyBytes() []byte {
if resp.bodyRaw != nil {
return resp.bodyRaw
}
if resp.body == nil {
return nil
}
return resp.body.B
}
func (req *Request) bodyBytes() []byte {
if req.bodyRaw != nil {
return req.bodyRaw
}
if req.bodyStream != nil {
bodyBuf := req.bodyBuffer()
bodyBuf.Reset()
_, err := copyZeroAlloc(bodyBuf, req.bodyStream)
req.closeBodyStream() //nolint:errcheck
if err != nil {
bodyBuf.SetString(err.Error())
}
}
if req.body == nil {
return nil
}
return req.body.B
}
func (resp *Response) bodyBuffer() *bytebufferpool.ByteBuffer {
if resp.body == nil {
resp.body = responseBodyPool.Get()
}
resp.bodyRaw = nil
return resp.body
}
func (req *Request) bodyBuffer() *bytebufferpool.ByteBuffer {
if req.body == nil {
req.body = requestBodyPool.Get()
}
req.bodyRaw = nil
return req.body
}
var (
responseBodyPool bytebufferpool.Pool
requestBodyPool bytebufferpool.Pool
)
// BodyGunzip returns un-gzipped body data.
//
// This method may be used if the request header contains
// 'Content-Encoding: gzip' for reading un-gzipped body.
// Use Body for reading gzipped request body.
func (req *Request) BodyGunzip() ([]byte, error) {
return gunzipData(req.Body())
}
// BodyGunzip returns un-gzipped body data.
//
// This method may be used if the response header contains
// 'Content-Encoding: gzip' for reading un-gzipped body.
// Use Body for reading gzipped response body.
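//
// A minimal sketch of explicit decoding:
//
//	body := resp.Body()
//	if bytes.Equal(resp.Header.ContentEncoding(), []byte("gzip")) {
//		var err error
//		body, err = resp.BodyGunzip()
//		if err != nil {
//			// handle the malformed gzip body
//		}
//	}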
func (resp *Response) BodyGunzip() ([]byte, error) {
return gunzipData(resp.Body())
}
func gunzipData(p []byte) ([]byte, error) {
var bb bytebufferpool.ByteBuffer
_, err := WriteGunzip(&bb, p)
if err != nil {
return nil, err
}
return bb.B, nil
}
// BodyUnbrotli returns un-brotlied body data.
//
// This method may be used if the request header contains
// 'Content-Encoding: br' for reading un-brotlied body.
// Use Body for reading brotlied request body.
func (req *Request) BodyUnbrotli() ([]byte, error) {
return unBrotliData(req.Body())
}
// BodyUnbrotli returns un-brotlied body data.
//
// This method may be used if the response header contains
// 'Content-Encoding: br' for reading un-brotlied body.
// Use Body for reading brotlied response body.
func (resp *Response) BodyUnbrotli() ([]byte, error) {
return unBrotliData(resp.Body())
}
func unBrotliData(p []byte) ([]byte, error) {
var bb bytebufferpool.ByteBuffer
_, err := WriteUnbrotli(&bb, p)
if err != nil {
return nil, err
}
return bb.B, nil
}
// BodyInflate returns inflated body data.
//
// This method may be used if the request header contains
// 'Content-Encoding: deflate' for reading inflated request body.
// Use Body for reading deflated request body.
func (req *Request) BodyInflate() ([]byte, error) {
return inflateData(req.Body())
}
// BodyInflate returns inflated body data.
//
// This method may be used if the response header contains
// 'Content-Encoding: deflate' for reading inflated response body.
// Use Body for reading deflated response body.
func (resp *Response) BodyInflate() ([]byte, error) {
return inflateData(resp.Body())
}
// RequestBodyStream returns the request body stream.
func (ctx *RequestCtx) RequestBodyStream() io.Reader {
return ctx.Request.bodyStream
}
// BodyUnzstd returns un-zstded body data.
//
// This method may be used if the request header contains
// 'Content-Encoding: zstd' for reading un-zstded body.
// Use Body for reading the zstded request body.
func (req *Request) BodyUnzstd() ([]byte, error) {
return unzstdData(req.Body())
}
// BodyUnzstd returns un-zstded body data.
//
// This method may be used if the response header contains
// 'Content-Encoding: zstd' for reading un-zstded body.
// Use Body for reading the zstded response body.
func (resp *Response) BodyUnzstd() ([]byte, error) {
return unzstdData(resp.Body())
}
func unzstdData(p []byte) ([]byte, error) {
var bb bytebufferpool.ByteBuffer
_, err := WriteUnzstd(&bb, p)
if err != nil {
return nil, err
}
return bb.B, nil
}
func inflateData(p []byte) ([]byte, error) {
var bb bytebufferpool.ByteBuffer
_, err := WriteInflate(&bb, p)
if err != nil {
return nil, err
}
return bb.B, nil
}
var ErrContentEncodingUnsupported = errors.New("unsupported Content-Encoding")
// BodyUncompressed returns body data, decompressing it from gzip, deflate,
// Brotli or zstd when needed.
//
// This method may be used if the request header contains
// 'Content-Encoding' for reading the uncompressed request body.
// Use Body for reading the raw request body.
func (req *Request) BodyUncompressed() ([]byte, error) {
switch string(req.Header.ContentEncoding()) {
case "":
return req.Body(), nil
case "deflate":
return req.BodyInflate()
case "gzip":
return req.BodyGunzip()
case "br":
return req.BodyUnbrotli()
case "zstd":
return req.BodyUnzstd()
default:
return nil, ErrContentEncodingUnsupported
}
}
// BodyUncompressed returns body data, decompressing it from gzip, deflate,
// Brotli or zstd when needed.
//
// This method may be used if the response header contains
// 'Content-Encoding' for reading the uncompressed response body.
// Use Body for reading the raw response body.
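//
// A minimal sketch with a raw-bytes fallback:
//
//	body, err := resp.BodyUncompressed()
//	if errors.Is(err, ErrContentEncodingUnsupported) {
//		body = resp.Body() // fall back to the raw bytes
//	}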
func (resp *Response) BodyUncompressed() ([]byte, error) {
switch string(resp.Header.ContentEncoding()) {
case "":
return resp.Body(), nil
case "deflate":
return resp.BodyInflate()
case "gzip":
return resp.BodyGunzip()
case "br":
return resp.BodyUnbrotli()
case "zstd":
return resp.BodyUnzstd()
default:
return nil, ErrContentEncodingUnsupported
}
}
// BodyWriteTo writes request body to w.
func (req *Request) BodyWriteTo(w io.Writer) error {
if req.bodyStream != nil {
_, err := copyZeroAlloc(w, req.bodyStream)
req.closeBodyStream() //nolint:errcheck
return err
}
if req.onlyMultipartForm() {
return WriteMultipartForm(w, req.multipartForm, req.multipartFormBoundary)
}
_, err := w.Write(req.bodyBytes())
return err
}
// BodyWriteTo writes response body to w.
func (resp *Response) BodyWriteTo(w io.Writer) error {
if resp.bodyStream != nil {
_, err := copyZeroAlloc(w, resp.bodyStream)
resp.closeBodyStream(err) //nolint:errcheck
return err
}
_, err := w.Write(resp.bodyBytes())
return err
}
// AppendBody appends p to response body.
//
// It is safe to re-use p after the function returns.
func (resp *Response) AppendBody(p []byte) {
resp.closeBodyStream(nil) //nolint:errcheck
resp.bodyBuffer().Write(p) //nolint:errcheck
}
// AppendBodyString appends s to response body.
func (resp *Response) AppendBodyString(s string) {
resp.closeBodyStream(nil) //nolint:errcheck
resp.bodyBuffer().WriteString(s) //nolint:errcheck
}
// SetBody sets response body.
//
// It is safe to re-use the body argument after the function returns.
func (resp *Response) SetBody(body []byte) {
resp.closeBodyStream(nil) //nolint:errcheck
bodyBuf := resp.bodyBuffer()
bodyBuf.Reset()
bodyBuf.Write(body) //nolint:errcheck
}
// SetBodyString sets response body.
func (resp *Response) SetBodyString(body string) {
resp.closeBodyStream(nil) //nolint:errcheck
bodyBuf := resp.bodyBuffer()
bodyBuf.Reset()
bodyBuf.WriteString(body) //nolint:errcheck
}
// ResetBody resets response body.
func (resp *Response) ResetBody() {
resp.bodyRaw = nil
resp.closeBodyStream(nil) //nolint:errcheck
if resp.body != nil {
if resp.keepBodyBuffer {
resp.body.Reset()
} else {
responseBodyPool.Put(resp.body)
resp.body = nil
}
}
}
// SetBodyRaw sets response body, but without copying it.
//
// From this point onward the body argument must not be changed.
func (resp *Response) SetBodyRaw(body []byte) {
resp.ResetBody()
resp.bodyRaw = body
}
// SetBodyRaw sets response body, but without copying it.
//
// From this point onward the body argument must not be changed.
func (req *Request) SetBodyRaw(body []byte) {
req.ResetBody()
req.bodyRaw = body
}
// ReleaseBody retires the response body if it is greater than "size" bytes.
//
// This permits GC to reclaim the large buffer. If used, it must be
// called before ReleaseResponse.
//
// Use this method only if you really understand how it works.
// The majority of workloads don't need this method.
func (resp *Response) ReleaseBody(size int) {
resp.bodyRaw = nil
if resp.body == nil {
return
}
if cap(resp.body.B) > size {
resp.closeBodyStream(nil) //nolint:errcheck
resp.body = nil
}
}
// ReleaseBody retires the request body if it is greater than "size" bytes.
//
// This permits GC to reclaim the large buffer. If used, it must be
// called before ReleaseRequest.
//
// Use this method only if you really understand how it works.
// The majority of workloads don't need this method.
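//
// A minimal sketch (the 1MB threshold is illustrative only):
//
//	req.ReleaseBody(1 << 20) // drop the pooled buffer if it grew beyond 1MB
//	ReleaseRequest(req)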
func (req *Request) ReleaseBody(size int) {
req.bodyRaw = nil
if req.body == nil {
return
}
if cap(req.body.B) > size {
req.closeBodyStream() //nolint:errcheck
req.body = nil
}
}
// SwapBody swaps response body with the given body and returns
// the previous response body.
//
// It is forbidden to use the body passed to SwapBody after
// the function returns.
func (resp *Response) SwapBody(body []byte) []byte {
bb := resp.bodyBuffer()
if resp.bodyStream != nil {
bb.Reset()
_, err := copyZeroAlloc(bb, resp.bodyStream)
resp.closeBodyStream(err) //nolint:errcheck
if err != nil {
bb.Reset()
bb.SetString(err.Error())
}
}
resp.bodyRaw = nil
oldBody := bb.B
bb.B = body
return oldBody
}
// SwapBody swaps request body with the given body and returns
// the previous request body.
//
// It is forbidden to use the body passed to SwapBody after
// the function returns.
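//
// A minimal sketch that takes ownership of the current body without copying:
//
//	body := req.SwapBody(nil) // req is left with an empty body
//	// body now belongs to the caller; req no longer references it.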
func (req *Request) SwapBody(body []byte) []byte {
bb := req.bodyBuffer()
if req.bodyStream != nil {
bb.Reset()
_, err := copyZeroAlloc(bb, req.bodyStream)
req.closeBodyStream() //nolint:errcheck
if err != nil {
bb.Reset()
bb.SetString(err.Error())
}
}
req.bodyRaw = nil
oldBody := bb.B
bb.B = body
return oldBody
}
// Body returns request body.
//
// The returned value is valid until the request is released,
// either though ReleaseRequest or your request handler returning.
// Do not store references to returned value. Make copies instead.
func (req *Request) Body() []byte {
if req.bodyRaw != nil {
return req.bodyRaw
} else if req.onlyMultipartForm() {
body, err := marshalMultipartForm(req.multipartForm, req.multipartFormBoundary)
if err != nil {
return []byte(err.Error())
}
return body
}
return req.bodyBytes()
}
// AppendBody appends p to request body.
//
// It is safe to re-use p after the function returns.
func (req *Request) AppendBody(p []byte) {
req.RemoveMultipartFormFiles()
req.closeBodyStream() //nolint:errcheck
req.bodyBuffer().Write(p) //nolint:errcheck
}
// AppendBodyString appends s to request body.
func (req *Request) AppendBodyString(s string) {
req.RemoveMultipartFormFiles()
req.closeBodyStream() //nolint:errcheck
req.bodyBuffer().WriteString(s) //nolint:errcheck
}
// SetBody sets request body.
//
// It is safe to re-use the body argument after the function returns.
func (req *Request) SetBody(body []byte) {
req.RemoveMultipartFormFiles()
req.closeBodyStream() //nolint:errcheck
req.bodyBuffer().Set(body)
}
// SetBodyString sets request body.
func (req *Request) SetBodyString(body string) {
req.RemoveMultipartFormFiles()
req.closeBodyStream() //nolint:errcheck
req.bodyBuffer().SetString(body)
}
// ResetBody resets request body.
func (req *Request) ResetBody() {
req.bodyRaw = nil
req.RemoveMultipartFormFiles()
req.closeBodyStream() //nolint:errcheck
if req.body != nil {
if req.keepBodyBuffer {
req.body.Reset()
} else {
requestBodyPool.Put(req.body)
req.body = nil
}
}
}
// CopyTo copies req contents to dst except for the body stream.
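//
// A minimal sketch: snapshot a request before mutating it, e.g. for retries:
//
//	var backup Request
//	req.CopyTo(&backup)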
func (req *Request) CopyTo(dst *Request) {
req.copyToSkipBody(dst)
switch {
case req.bodyRaw != nil:
dst.bodyRaw = append(dst.bodyRaw[:0], req.bodyRaw...)
if dst.body != nil {
dst.body.Reset()
}
case req.body != nil:
dst.bodyBuffer().Set(req.body.B)
case dst.body != nil:
dst.body.Reset()
}
}
func (req *Request) copyToSkipBody(dst *Request) {
dst.Reset()
req.Header.CopyTo(&dst.Header)
req.uri.CopyTo(&dst.uri)
dst.parsedURI = req.parsedURI
req.postArgs.CopyTo(&dst.postArgs)
dst.parsedPostArgs = req.parsedPostArgs
dst.isTLS = req.isTLS
dst.UseHostHeader = req.UseHostHeader
// do not copy multipartForm - it will be automatically
// re-created on the first call to MultipartForm.
}
// CopyTo copies resp contents to dst except for the body stream.
func (resp *Response) CopyTo(dst *Response) {
resp.copyToSkipBody(dst)
switch {
case resp.bodyRaw != nil:
dst.bodyRaw = append(dst.bodyRaw[:0], resp.bodyRaw...)
if dst.body != nil {
dst.body.Reset()
}
case resp.body != nil:
dst.bodyBuffer().Set(resp.body.B)
case dst.body != nil:
dst.body.Reset()
}
}
func (resp *Response) copyToSkipBody(dst *Response) {
dst.Reset()
resp.Header.CopyTo(&dst.Header)
dst.SkipBody = resp.SkipBody
dst.raddr = resp.raddr
dst.laddr = resp.laddr
}
func swapRequestBody(a, b *Request) {
a.body, b.body = b.body, a.body
a.bodyRaw, b.bodyRaw = b.bodyRaw, a.bodyRaw
a.bodyStream, b.bodyStream = b.bodyStream, a.bodyStream
// This code assumes that if a requestStream was swapped the headers are also swapped or copied.
if rs, ok := a.bodyStream.(*requestStream); ok {
rs.header = &a.Header
}
if rs, ok := b.bodyStream.(*requestStream); ok {
rs.header = &b.Header
}
}
func swapResponseBody(a, b *Response) {
a.body, b.body = b.body, a.body
a.bodyRaw, b.bodyRaw = b.bodyRaw, a.bodyRaw
a.bodyStream, b.bodyStream = b.bodyStream, a.bodyStream
}
// URI returns request URI.
func (req *Request) URI() *URI {
req.parseURI() //nolint:errcheck
return &req.uri
}
// SetURI initializes request URI.
// Use this method if a single URI may be reused across multiple requests.
// Otherwise, you can just use SetRequestURI() and it will be parsed as a new URI.
// The URI is copied and can be safely modified later.
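//
// A minimal sketch (the URL is illustrative only):
//
//	u := AcquireURI()
//	if err := u.Parse(nil, []byte("https://example.com/search?q=fasthttp")); err == nil {
//		req.SetURI(u)
//	}
//	ReleaseURI(u) // safe: SetURI copied the URI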
func (req *Request) SetURI(newURI *URI) {
if newURI != nil {
newURI.CopyTo(&req.uri)
req.parsedURI = true
return
}
req.uri.Reset()
req.parsedURI = false
}
func (req *Request) parseURI() error {
if req.parsedURI {
return nil
}
req.parsedURI = true
return req.uri.parse(req.Header.Host(), req.Header.RequestURI(), req.isTLS)
}
// PostArgs returns POST arguments.
func (req *Request) PostArgs() *Args {
req.parsePostArgs()
return &req.postArgs
}
func (req *Request) parsePostArgs() {
if req.parsedPostArgs {
return
}
req.parsedPostArgs = true
if !bytes.HasPrefix(req.Header.ContentType(), strPostArgsContentType) {
return
}
req.postArgs.ParseBytes(req.bodyBytes())
}
// ErrNoMultipartForm means that the request's Content-Type
// isn't 'multipart/form-data'.
var ErrNoMultipartForm = errors.New("request Content-Type has bad boundary or is not multipart/form-data")
// MultipartForm returns request's multipart form.
//
// Returns ErrNoMultipartForm if request's Content-Type
// isn't 'multipart/form-data'.
//
// RemoveMultipartFormFiles must be called after the returned multipart form
// is processed.
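//
// A minimal sketch:
//
//	form, err := req.MultipartForm()
//	if err != nil {
//		return // e.g. ErrNoMultipartForm
//	}
//	defer req.RemoveMultipartFormFiles()
//	for name, values := range form.Value {
//		fmt.Printf("%s=%v\n", name, values)
//	}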
func (req *Request) MultipartForm() (*multipart.Form, error) {
if req.multipartForm != nil {
return req.multipartForm, nil
}
req.multipartFormBoundary = string(req.Header.MultipartFormBoundary())
if req.multipartFormBoundary == "" {
return nil, ErrNoMultipartForm
}
var err error
ce := req.Header.peek(strContentEncoding)
if req.bodyStream != nil {
bodyStream := req.bodyStream
if bytes.Equal(ce, strGzip) {
// Do not care about memory usage here.
if bodyStream, err = gzip.NewReader(bodyStream); err != nil {
return nil, fmt.Errorf("cannot gunzip request body: %w", err)
}
} else if len(ce) > 0 {
return nil, fmt.Errorf("unsupported Content-Encoding: %q", ce)
}
mr := multipart.NewReader(bodyStream, req.multipartFormBoundary)
req.multipartForm, err = mr.ReadForm(8 * 1024)
if err != nil {
return nil, fmt.Errorf("cannot read multipart/form-data body: %w", err)
}
} else {
body := req.bodyBytes()
if bytes.Equal(ce, strGzip) {
// Do not care about memory usage here.
if body, err = AppendGunzipBytes(nil, body); err != nil {
return nil, fmt.Errorf("cannot gunzip request body: %w", err)
}
} else if len(ce) > 0 {
return nil, fmt.Errorf("unsupported Content-Encoding: %q", ce)
}
req.multipartForm, err = readMultipartForm(bytes.NewReader(body), req.multipartFormBoundary, len(body), len(body))
if err != nil {
return nil, err
}
}
return req.multipartForm, nil
}
func marshalMultipartForm(f *multipart.Form, boundary string) ([]byte, error) {
var buf bytebufferpool.ByteBuffer
if err := WriteMultipartForm(&buf, f, boundary); err != nil {
return nil, err
}
return buf.B, nil
}
// WriteMultipartForm writes the given multipart form f with the given
// boundary to w.
func WriteMultipartForm(w io.Writer, f *multipart.Form, boundary string) error {
// Do not care about memory allocations here, since multipart
// form processing is slow.
if boundary == "" {
return errors.New("form boundary cannot be empty")
}
mw := multipart.NewWriter(w)
if err := mw.SetBoundary(boundary); err != nil {
return fmt.Errorf("cannot use form boundary %q: %w", boundary, err)
}
// marshal values
for k, vv := range f.Value {
for _, v := range vv {
if err := mw.WriteField(k, v); err != nil {
return fmt.Errorf("cannot write form field %q value %q: %w", k, v, err)
}
}
}
// marshal files
for k, fvv := range f.File {
for _, fv := range fvv {
vw, err := mw.CreatePart(fv.Header)
if err != nil {
return fmt.Errorf("cannot create form file %q (%q): %w", k, fv.Filename, err)
}
fh, err := fv.Open()
if err != nil {
return fmt.Errorf("cannot open form file %q (%q): %w", k, fv.Filename, err)
}
if _, err = copyZeroAlloc(vw, fh); err != nil {
_ = fh.Close()
return fmt.Errorf("error when copying form file %q (%q): %w", k, fv.Filename, err)
}
if err = fh.Close(); err != nil {
return fmt.Errorf("cannot close form file %q (%q): %w", k, fv.Filename, err)
}
}
}
if err := mw.Close(); err != nil {
return fmt.Errorf("error when closing multipart form writer: %w", err)
}
return nil
}
func readMultipartForm(r io.Reader, boundary string, size, maxInMemoryFileSize int) (*multipart.Form, error) {
// Do not care about memory allocations here, since they are tiny
// compared to multipart data (aka multi-MB files) usually sent
// in multipart/form-data requests.
if size <= 0 {
return nil, fmt.Errorf("form size must be greater than 0. Given %d", size)
}
lr := io.LimitReader(r, int64(size))
mr := multipart.NewReader(lr, boundary)
f, err := mr.ReadForm(int64(maxInMemoryFileSize))
if err != nil {
return nil, fmt.Errorf("cannot read multipart/form-data body: %w", err)
}
return f, nil
}
// Reset clears request contents.
func (req *Request) Reset() {
req.userValues.Reset() // it should be at the top, since some values might implement io.Closer interface
if requestBodyPoolSizeLimit >= 0 && req.body != nil {
req.ReleaseBody(requestBodyPoolSizeLimit)
}
req.Header.Reset()
req.resetSkipHeader()
req.timeout = 0
req.UseHostHeader = false
req.DisableRedirectPathNormalizing = false
}
func (req *Request) resetSkipHeader() {
req.ResetBody()
req.uri.Reset()
req.parsedURI = false
req.postArgs.Reset()
req.parsedPostArgs = false
req.isTLS = false
}
// RemoveMultipartFormFiles removes multipart/form-data temporary files
// associated with the request.
func (req *Request) RemoveMultipartFormFiles() {
if req.multipartForm != nil {
// Do not check for error, since these files may be deleted or moved
// to new places by user code.
req.multipartForm.RemoveAll() //nolint:errcheck
req.multipartForm = nil
}
req.multipartFormBoundary = ""
}
// Reset clears response contents.
func (resp *Response) Reset() {
if responseBodyPoolSizeLimit >= 0 && resp.body != nil {
resp.ReleaseBody(responseBodyPoolSizeLimit)
}
resp.resetSkipHeader()
resp.Header.Reset()
resp.SkipBody = false
resp.raddr = nil
resp.laddr = nil
resp.ImmediateHeaderFlush = false
resp.StreamBody = false
}
func (resp *Response) resetSkipHeader() {
resp.ResetBody()
}
// Read reads request (including body) from the given r.
//
// RemoveMultipartFormFiles or Reset must be called after
// reading multipart/form-data request in order to delete temporarily
// uploaded files.
//
// If MayContinue returns true, the caller must:
//
// - Either send StatusExpectationFailed response if request headers don't
// satisfy the caller.
// - Or send StatusContinue response before reading request body
// with ContinueReadBody.
// - Or close the connection.
//
// io.EOF is returned if r is closed before reading the first header byte.
func (req *Request) Read(r *bufio.Reader) error {
return req.ReadLimitBody(r, 0)
}
// Uploaded multipart file parts above this in-memory limit are spooled
// to temporary files on disk by multipart.Reader.ReadForm.
const defaultMaxInMemoryFileSize = 16 * 1024 * 1024
// ErrGetOnly is returned when server expects only GET requests,
// but some other type of request came (Server.GetOnly option is true).
var ErrGetOnly = errors.New("non-GET request received")
// ReadLimitBody reads request from the given r, limiting the body size.
//
// If maxBodySize > 0 and the body size exceeds maxBodySize,
// then ErrBodyTooLarge is returned.
//
// RemoveMultipartFormFiles or Reset must be called after
// reading multipart/form-data request in order to delete temporarily
// uploaded files.
//
// If MayContinue returns true, the caller must:
//
// - Either send StatusExpectationFailed response if request headers don't
// satisfy the caller.
// - Or send StatusContinue response before reading request body
// with ContinueReadBody.
// - Or close the connection.
//
// io.EOF is returned if r is closed before reading the first header byte.
func (req *Request) ReadLimitBody(r *bufio.Reader, maxBodySize int) error {
req.resetSkipHeader()
if err := req.Header.Read(r); err != nil {
return err
}
return req.readLimitBody(r, maxBodySize, false, true)
}
func (req *Request) readLimitBody(r *bufio.Reader, maxBodySize int, getOnly, preParseMultipartForm bool) error {
// Do not reset the request here - the caller must reset it before
// calling this method.
if getOnly && !req.Header.IsGet() && !req.Header.IsHead() {
return ErrGetOnly
}
if req.MayContinue() {
// 'Expect: 100-continue' header found. Let the caller decide
// whether to read the request body or
// to return StatusExpectationFailed.
return nil
}
return req.ContinueReadBody(r, maxBodySize, preParseMultipartForm)
}
func (req *Request) readBodyStream(r *bufio.Reader, maxBodySize int, getOnly, preParseMultipartForm bool) error {
// Do not reset the request here - the caller must reset it before
// calling this method.
if getOnly && !req.Header.IsGet() && !req.Header.IsHead() {
return ErrGetOnly
}
if req.MayContinue() {
// 'Expect: 100-continue' header found. Let the caller decide
// whether to read the request body or
// to return StatusExpectationFailed.
return nil
}
return req.ContinueReadBodyStream(r, maxBodySize, preParseMultipartForm)
}
// MayContinue returns true if the request contains
// 'Expect: 100-continue' header.
//
// The caller must do one of the following actions if MayContinue returns true:
//
// - Either send StatusExpectationFailed response if request headers don't
// satisfy the caller.
// - Or send StatusContinue response before reading request body
// with ContinueReadBody.
// - Or close the connection.
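//
// A minimal sketch of the accept path for servers reading requests by hand
// (bw and br wrap the connection; maxBodySize is an assumed limit):
//
//	if req.MayContinue() {
//		bw.WriteString("HTTP/1.1 100 Continue\r\n\r\n")
//		bw.Flush()
//		err = req.ContinueReadBody(br, maxBodySize)
//	}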
func (req *Request) MayContinue() bool {
return bytes.Equal(req.Header.peek(strExpect), str100Continue)
}
// ContinueReadBody reads request body if request header contains
// 'Expect: 100-continue'.
//
// The caller must send StatusContinue response before calling this method.
//
// If maxBodySize > 0 and the body size exceeds maxBodySize,
// then ErrBodyTooLarge is returned.
func (req *Request) ContinueReadBody(r *bufio.Reader, maxBodySize int, preParseMultipartForm ...bool) error {
var err error
contentLength := req.Header.ContentLength()
if contentLength > 0 {
if maxBodySize > 0 && contentLength > maxBodySize {
return ErrBodyTooLarge
}
if len(preParseMultipartForm) == 0 || preParseMultipartForm[0] {
// Pre-read multipart form data of known length.
// This way we limit memory usage for large file uploads, since their contents
// are streamed into temporary files if the file size exceeds defaultMaxInMemoryFileSize.
req.multipartFormBoundary = string(req.Header.MultipartFormBoundary())
if req.multipartFormBoundary != "" && len(req.Header.peek(strContentEncoding)) == 0 {
req.multipartForm, err = readMultipartForm(r, req.multipartFormBoundary, contentLength, defaultMaxInMemoryFileSize)
if err != nil {
req.Reset()
}
return err
}
}
}
if contentLength == -2 {
// An identity body makes no sense for HTTP requests, since
// the end of the body is determined by connection close.
// So just ignore the request body for requests without
// 'Content-Length' and 'Transfer-Encoding' headers.
// Refer to https://tools.ietf.org/html/rfc7230#section-3.3.2
if !req.Header.ignoreBody() {
req.Header.SetContentLength(0)
}
return nil
}
if err = req.ReadBody(r, contentLength, maxBodySize); err != nil {
return err
}
if contentLength == -1 {
err = req.Header.ReadTrailer(r)
if err != nil && err != io.EOF {
return err
}
}
return nil
}
// ReadBody reads request body from the given r, limiting the body size.
//
// If maxBodySize > 0 and the body size exceeds maxBodySize,
// then ErrBodyTooLarge is returned.
func (req *Request) ReadBody(r *bufio.Reader, contentLength, maxBodySize int) (err error) {
bodyBuf := req.bodyBuffer()
bodyBuf.Reset()
switch {
case contentLength >= 0:
bodyBuf.B, err = readBody(r, contentLength, maxBodySize, bodyBuf.B)
case contentLength == -1:
bodyBuf.B, err = readBodyChunked(r, maxBodySize, bodyBuf.B)
if err == nil && len(bodyBuf.B) == 0 {
req.Header.SetContentLength(0)
}
default:
bodyBuf.B, err = readBodyIdentity(r, maxBodySize, bodyBuf.B)
req.Header.SetContentLength(len(bodyBuf.B))
}
if err != nil {
req.Reset()
return err
}
return nil
}
// ContinueReadBodyStream reads request body if request header contains
// 'Expect: 100-continue'.
//
// The caller must send StatusContinue response before calling this method.
//
// If maxBodySize > 0 and the body size exceeds maxBodySize,
// then ErrBodyTooLarge is returned.
func (req *Request) ContinueReadBodyStream(r *bufio.Reader, maxBodySize int, preParseMultipartForm ...bool) error {
var err error
contentLength := req.Header.ContentLength()
if contentLength > 0 {
if len(preParseMultipartForm) == 0 || preParseMultipartForm[0] {
// Pre-read multipart form data of known length.
// This way we limit memory usage for large file uploads, since their contents
// are streamed into temporary files if the file size exceeds defaultMaxInMemoryFileSize.
req.multipartFormBoundary = b2s(req.Header.MultipartFormBoundary())
if req.multipartFormBoundary != "" && len(req.Header.peek(strContentEncoding)) == 0 {
req.multipartForm, err = readMultipartForm(r, req.multipartFormBoundary, contentLength, defaultMaxInMemoryFileSize)
if err != nil {
req.Reset()
}
return err
}
}
}
if contentLength == -2 {
// An identity body makes no sense for HTTP requests, since
// the end of the body is determined by connection close.
// So just ignore the request body for requests without
// 'Content-Length' and 'Transfer-Encoding' headers.
// Refer to https://tools.ietf.org/html/rfc7230#section-3.3.2
if !req.Header.ignoreBody() {
req.Header.SetContentLength(0)
}
return nil
}
bodyBuf := req.bodyBuffer()
bodyBuf.Reset()
bodyBuf.B, err = readBodyWithStreaming(r, contentLength, maxBodySize, bodyBuf.B)
if err != nil {
if err == ErrBodyTooLarge {
req.Header.SetContentLength(contentLength)
req.body = bodyBuf
req.bodyStream = acquireRequestStream(bodyBuf, r, &req.Header)
return nil
}
if err == errChunkedStream {
req.body = bodyBuf
req.bodyStream = acquireRequestStream(bodyBuf, r, &req.Header)
return nil
}
req.Reset()
return err
}
req.body = bodyBuf
req.bodyStream = acquireRequestStream(bodyBuf, r, &req.Header)
req.Header.SetContentLength(contentLength)
return nil
}
// Read reads response (including body) from the given r.
//
// io.EOF is returned if r is closed before reading the first header byte.
func (resp *Response) Read(r *bufio.Reader) error {
return resp.ReadLimitBody(r, 0)
}
// ReadLimitBody reads response headers from the given r,
// then reads the body using the ReadBody function and limiting the body size.
//
// If resp.SkipBody is true then it skips reading the response body.
//
// If maxBodySize > 0 and the body size exceeds maxBodySize,
// then ErrBodyTooLarge is returned.
//
// io.EOF is returned if r is closed before reading the first header byte.
func (resp *Response) ReadLimitBody(r *bufio.Reader, maxBodySize int) error {
resp.resetSkipHeader()
err := resp.Header.Read(r)
if err != nil {
return err
}
if resp.Header.StatusCode() == StatusContinue {
// Read the next response according to http://www.w3.org/Protocols/rfc2616/rfc2616-sec8.html .
if err = resp.Header.Read(r); err != nil {
return err
}
}
if !resp.mustSkipBody() {
err = resp.ReadBody(r, maxBodySize)
if err != nil {
return err
}
}
// A response without a body can't have trailers.
if resp.Header.ContentLength() == -1 && !resp.StreamBody && !resp.mustSkipBody() {
err = resp.Header.ReadTrailer(r)
if err != nil && err != io.EOF {
return err
}
}
return nil
}
// ReadBody reads response body from the given r, limiting the body size.
//
// If maxBodySize > 0 and the body size exceeds maxBodySize,
// then ErrBodyTooLarge is returned.
func (resp *Response) ReadBody(r *bufio.Reader, maxBodySize int) (err error) {
bodyBuf := resp.bodyBuffer()
bodyBuf.Reset()
contentLength := resp.Header.ContentLength()
switch {
case contentLength >= 0:
bodyBuf.B, err = readBody(r, contentLength, maxBodySize, bodyBuf.B)
if err == ErrBodyTooLarge && resp.StreamBody {
resp.bodyStream = acquireRequestStream(bodyBuf, r, &resp.Header)
err = nil
}
case contentLength == -1:
if resp.StreamBody {
resp.bodyStream = acquireRequestStream(bodyBuf, r, &resp.Header)
} else {
bodyBuf.B, err = readBodyChunked(r, maxBodySize, bodyBuf.B)
}
default:
if resp.StreamBody {
resp.bodyStream = acquireRequestStream(bodyBuf, r, &resp.Header)
} else {
bodyBuf.B, err = readBodyIdentity(r, maxBodySize, bodyBuf.B)
resp.Header.SetContentLength(len(bodyBuf.B))
}
}
if err == nil && resp.StreamBody && resp.bodyStream == nil {
resp.bodyStream = bytes.NewReader(bodyBuf.B)
}
return err
}
func (resp *Response) mustSkipBody() bool {
return resp.SkipBody || resp.Header.mustSkipContentLength()
}
var errRequestHostRequired = errors.New("missing required Host header in request")
// WriteTo writes request to w. It implements io.WriterTo.
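//
// A minimal sketch that dumps the full request to stdout:
//
//	if _, err := req.WriteTo(os.Stdout); err != nil {
//		// handle err
//	}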
func (req *Request) WriteTo(w io.Writer) (int64, error) {
return writeBufio(req, w)
}
// WriteTo writes response to w. It implements io.WriterTo.
func (resp *Response) WriteTo(w io.Writer) (int64, error) {
return writeBufio(resp, w)
}
func writeBufio(hw httpWriter, w io.Writer) (int64, error) {
sw := acquireStatsWriter(w)
bw := acquireBufioWriter(sw)
errw := hw.Write(bw)
errf := bw.Flush()
releaseBufioWriter(bw)
n := sw.bytesWritten
releaseStatsWriter(sw)
err := errw
if err == nil {
err = errf
}
return n, err
}
type statsWriter struct {
w io.Writer
bytesWritten int64
}
func (w *statsWriter) Write(p []byte) (int, error) {
n, err := w.w.Write(p)
w.bytesWritten += int64(n)
return n, err
}
func (w *statsWriter) WriteString(s string) (int, error) {
n, err := w.w.Write(s2b(s))
w.bytesWritten += int64(n)
return n, err
}
func acquireStatsWriter(w io.Writer) *statsWriter {
v := statsWriterPool.Get()
if v == nil {
return &statsWriter{
w: w,
}
}
sw := v.(*statsWriter)
sw.w = w
return sw
}
func releaseStatsWriter(sw *statsWriter) {
sw.w = nil
sw.bytesWritten = 0
statsWriterPool.Put(sw)
}
var statsWriterPool sync.Pool
func acquireBufioWriter(w io.Writer) *bufio.Writer {
v := bufioWriterPool.Get()
if v == nil {
return bufio.NewWriter(w)
}
bw := v.(*bufio.Writer)
bw.Reset(w)
return bw
}
func releaseBufioWriter(bw *bufio.Writer) {
bufioWriterPool.Put(bw)
}
var bufioWriterPool sync.Pool
func (req *Request) onlyMultipartForm() bool {
return req.multipartForm != nil && (req.body == nil || len(req.body.B) == 0)
}
// Write writes request to w.
//
// Write doesn't flush request to w for performance reasons.
//
// See also WriteTo.
func (req *Request) Write(w *bufio.Writer) error {
if len(req.Header.Host()) == 0 || req.parsedURI {
uri := req.URI()
host := uri.Host()
if len(req.Header.Host()) == 0 {
if len(host) == 0 {
return errRequestHostRequired
}
req.Header.SetHostBytes(host)
} else if !req.UseHostHeader {
req.Header.SetHostBytes(host)
}
req.Header.SetRequestURIBytes(uri.RequestURI())
if len(uri.username) > 0 {
// RequestHeader.SetBytesKV only uses RequestHeader.bufKV.key
// So we are free to use RequestHeader.bufKV.value as a scratch pad for
// the base64 encoding.
nl := len(uri.username) + len(uri.password) + 1
nb := nl + len(strBasicSpace)
tl := nb + base64.StdEncoding.EncodedLen(nl)
if tl > cap(req.Header.bufV) {
req.Header.bufV = make([]byte, 0, tl)
}
buf := req.Header.bufV[:0]
buf = append(buf, uri.username...)
buf = append(buf, strColon...)
buf = append(buf, uri.password...)
buf = append(buf, strBasicSpace...)
base64.StdEncoding.Encode(buf[nb:tl], buf[:nl])
req.Header.SetBytesKV(strAuthorization, buf[nl:tl])
}
}
if req.bodyStream != nil {
return req.writeBodyStream(w)
}
body := req.bodyBytes()
var err error
if req.onlyMultipartForm() {
body, err = marshalMultipartForm(req.multipartForm, req.multipartFormBoundary)
if err != nil {
return fmt.Errorf("error when marshaling multipart form: %w", err)
}
req.Header.SetMultipartFormBoundary(req.multipartFormBoundary)
}
hasBody := false
if len(body) == 0 {
body = req.postArgs.QueryString()
}
if len(body) != 0 || !req.Header.ignoreBody() {
hasBody = true
req.Header.SetContentLength(len(body))
}
if err = req.Header.Write(w); err != nil {
return err
}
if hasBody {
_, err = w.Write(body)
} else if len(body) > 0 {
if req.secureErrorLogMessage {
return errors.New("non-zero body for non-POST request")
}
return fmt.Errorf("non-zero body for non-POST request. body=%q", body)
}
return err
}
// WriteGzip writes response with gzipped body to w.
//
// The method gzips response body and sets 'Content-Encoding: gzip'
// header before writing response to w.
//
// WriteGzip doesn't flush response to w for performance reasons.
func (resp *Response) WriteGzip(w *bufio.Writer) error {
return resp.WriteGzipLevel(w, CompressDefaultCompression)
}
// WriteGzipLevel writes response with gzipped body to w.
//
// Level is the desired compression level:
//
// - CompressNoCompression
// - CompressBestSpeed
// - CompressBestCompression
// - CompressDefaultCompression
// - CompressHuffmanOnly
//
// The method gzips response body and sets 'Content-Encoding: gzip'
// header before writing response to w.
//
// WriteGzipLevel doesn't flush response to w for performance reasons.
func (resp *Response) WriteGzipLevel(w *bufio.Writer, level int) error {
resp.gzipBody(level)
return resp.Write(w)
}
// WriteDeflate writes response with deflated body to w.
//
// The method deflates response body and sets 'Content-Encoding: deflate'
// header before writing response to w.
//
// WriteDeflate doesn't flush response to w for performance reasons.
func (resp *Response) WriteDeflate(w *bufio.Writer) error {
return resp.WriteDeflateLevel(w, CompressDefaultCompression)
}
// WriteDeflateLevel writes response with deflated body to w.
//
// Level is the desired compression level:
//
// - CompressNoCompression
// - CompressBestSpeed
// - CompressBestCompression
// - CompressDefaultCompression
// - CompressHuffmanOnly
//
// The method deflates response body and sets 'Content-Encoding: deflate'
// header before writing response to w.
//
// WriteDeflateLevel doesn't flush response to w for performance reasons.
func (resp *Response) WriteDeflateLevel(w *bufio.Writer, level int) error {
resp.deflateBody(level)
return resp.Write(w)
}
func (resp *Response) brotliBody(level int) {
if len(resp.Header.ContentEncoding()) > 0 {
// It looks like the body is already compressed.
// Do not compress it again.
return
}
if !resp.Header.isCompressibleContentType() {
// The content-type cannot be compressed.
return
}
if resp.bodyStream != nil {
// Reset Content-Length to -1, since it is impossible
// to determine the body size beforehand for streamed compression.
// See https://github.com/valyala/fasthttp/issues/176 .
resp.Header.SetContentLength(-1)
// Do not care about memory allocations here, since brotli is slow
// and allocates a lot of memory by itself.
bs := resp.bodyStream
resp.bodyStream = NewStreamReader(func(sw *bufio.Writer) {
zw := acquireStacklessBrotliWriter(sw, level)
fw := &flushWriter{
wf: zw,
bw: sw,
}
_, wErr := copyZeroAlloc(fw, bs)
releaseStacklessBrotliWriter(zw, level)
switch v := bs.(type) {
case io.Closer:
v.Close()
case ReadCloserWithError:
v.CloseWithError(wErr) //nolint:errcheck
}
})
} else {
bodyBytes := resp.bodyBytes()
if len(bodyBytes) < minCompressLen {
// There is no sense in spending CPU time on small body compression,
// since there is a very high probability that the compressed
// body size will be bigger than the original body size.
return
}
w := responseBodyPool.Get()
w.B = AppendBrotliBytesLevel(w.B, bodyBytes, level)
// Hack: swap resp.body with w.
if resp.body != nil {
responseBodyPool.Put(resp.body)
}
resp.body = w
resp.bodyRaw = nil
}
resp.Header.SetContentEncodingBytes(strBr)
resp.Header.addVaryBytes(strAcceptEncoding)
}
func (resp *Response) gzipBody(level int) {
if len(resp.Header.ContentEncoding()) > 0 {
// It looks like the body is already compressed.
// Do not compress it again.
return
}
if !resp.Header.isCompressibleContentType() {
// The content-type cannot be compressed.
return
}
if resp.bodyStream != nil {
// Reset Content-Length to -1, since it is impossible
// to determine the body size beforehand for streamed compression.
// See https://github.com/valyala/fasthttp/issues/176 .
resp.Header.SetContentLength(-1)
// Do not care about memory allocations here, since gzip is slow
// and allocates a lot of memory by itself.
bs := resp.bodyStream
resp.bodyStream = NewStreamReader(func(sw *bufio.Writer) {
zw := acquireStacklessGzipWriter(sw, level)
fw := &flushWriter{
wf: zw,
bw: sw,
}
_, wErr := copyZeroAlloc(fw, bs)
releaseStacklessGzipWriter(zw, level)
switch v := bs.(type) {
case io.Closer:
v.Close()
case ReadCloserWithError:
v.CloseWithError(wErr) //nolint:errcheck
}
})
} else {
bodyBytes := resp.bodyBytes()
if len(bodyBytes) < minCompressLen {
// There is no sense in spending CPU time on small body compression,
// since there is a very high probability that the compressed
// body size will be bigger than the original body size.
return
}
w := responseBodyPool.Get()
w.B = AppendGzipBytesLevel(w.B, bodyBytes, level)
// Hack: swap resp.body with w.
if resp.body != nil {
responseBodyPool.Put(resp.body)
}
resp.body = w
resp.bodyRaw = nil
}
resp.Header.SetContentEncodingBytes(strGzip)
resp.Header.addVaryBytes(strAcceptEncoding)
}
func (resp *Response) deflateBody(level int) {
if len(resp.Header.ContentEncoding()) > 0 {
// It looks like the body is already compressed.
// Do not compress it again.
return
}
if !resp.Header.isCompressibleContentType() {
// The content-type cannot be compressed.
return
}
if resp.bodyStream != nil {
// Reset Content-Length to -1, since it is impossible
// to determine the body size beforehand for streamed compression.
// See https://github.com/valyala/fasthttp/issues/176 .
resp.Header.SetContentLength(-1)
// Do not care about memory allocations here, since flate is slow
// and allocates a lot of memory by itself.
bs := resp.bodyStream
resp.bodyStream = NewStreamReader(func(sw *bufio.Writer) {
zw := acquireStacklessDeflateWriter(sw, level)
fw := &flushWriter{
wf: zw,
bw: sw,
}
_, wErr := copyZeroAlloc(fw, bs)
releaseStacklessDeflateWriter(zw, level)
switch v := bs.(type) {
case io.Closer:
v.Close()
case ReadCloserWithError:
v.CloseWithError(wErr) //nolint:errcheck
}
})
} else {
bodyBytes := resp.bodyBytes()
if len(bodyBytes) < minCompressLen {
// There is no sense in spending CPU time on small body compression,
// since there is a very high probability that the compressed
// body size will be bigger than the original body size.
return
}
w := responseBodyPool.Get()
w.B = AppendDeflateBytesLevel(w.B, bodyBytes, level)
// Hack: swap resp.body with w.
if resp.body != nil {
responseBodyPool.Put(resp.body)
}
resp.body = w
resp.bodyRaw = nil
}
resp.Header.SetContentEncodingBytes(strDeflate)
resp.Header.addVaryBytes(strAcceptEncoding)
}
func (resp *Response) zstdBody(level int) {
if len(resp.Header.ContentEncoding()) > 0 {
return
}
if !resp.Header.isCompressibleContentType() {
return
}
if resp.bodyStream != nil {
// Reset Content-Length to -1, since it is impossible
// to determine the body size beforehand for streamed compression.
// See https://github.com/valyala/fasthttp/issues/176 .
resp.Header.SetContentLength(-1)
// Do not care about memory allocations here, since zstd is slow
// and allocates a lot of memory by itself.
bs := resp.bodyStream
resp.bodyStream = NewStreamReader(func(sw *bufio.Writer) {
zw := acquireStacklessZstdWriter(sw, level)
fw := &flushWriter{
wf: zw,
bw: sw,
}
_, wErr := copyZeroAlloc(fw, bs)
releaseStacklessZstdWriter(zw, level)
switch v := bs.(type) {
case io.Closer:
v.Close()
case ReadCloserWithError:
v.CloseWithError(wErr) //nolint:errcheck
}
})
} else {
bodyBytes := resp.bodyBytes()
if len(bodyBytes) < minCompressLen {
return
}
w := responseBodyPool.Get()
w.B = AppendZstdBytesLevel(w.B, bodyBytes, level)
if resp.body != nil {
responseBodyPool.Put(resp.body)
}
resp.body = w
resp.bodyRaw = nil
}
resp.Header.SetContentEncodingBytes(strZstd)
resp.Header.addVaryBytes(strAcceptEncoding)
}
// Bodies with sizes smaller than minCompressLen aren't compressed at all.
const minCompressLen = 200
type writeFlusher interface {
io.Writer
Flush() error
}
type flushWriter struct {
wf writeFlusher
bw *bufio.Writer
}
func (w *flushWriter) Write(p []byte) (int, error) {
n, err := w.wf.Write(p)
if err != nil {
return 0, err
}
if err = w.wf.Flush(); err != nil {
return 0, err
}
if err = w.bw.Flush(); err != nil {
return 0, err
}
return n, nil
}
func (w *flushWriter) WriteString(s string) (int, error) {
return w.Write(s2b(s))
}
// Write writes response to w.
//
// Write doesn't flush response to w for performance reasons.
//
// See also WriteTo.
func (resp *Response) Write(w *bufio.Writer) error {
sendBody := !resp.mustSkipBody()
if resp.bodyStream != nil {
return resp.writeBodyStream(w, sendBody)
}
body := resp.bodyBytes()
bodyLen := len(body)
if sendBody || bodyLen > 0 {
resp.Header.SetContentLength(bodyLen)
}
if err := resp.Header.Write(w); err != nil {
return err
}
if sendBody {
if _, err := w.Write(body); err != nil {
return err
}
}
return nil
}
func (req *Request) writeBodyStream(w *bufio.Writer) error {
var err error
contentLength := req.Header.ContentLength()
if contentLength < 0 {
lrSize := limitedReaderSize(req.bodyStream)
if lrSize >= 0 {
contentLength = int(lrSize)
if int64(contentLength) != lrSize {
contentLength = -1
}
if contentLength >= 0 {
req.Header.SetContentLength(contentLength)
}
}
}
if contentLength >= 0 {
if err = req.Header.Write(w); err == nil {
err = writeBodyFixedSize(w, req.bodyStream, int64(contentLength))
}
} else {
req.Header.SetContentLength(-1)
err = req.Header.Write(w)
if err == nil {
err = writeBodyChunked(w, req.bodyStream)
}
if err == nil {
err = req.Header.writeTrailer(w)
}
}
errc := req.closeBodyStream()
if err == nil {
err = errc
}
return err
}
// ErrBodyStreamWritePanic is returned when a panic happens while writing the body stream.
type ErrBodyStreamWritePanic struct {
error
}
func (resp *Response) writeBodyStream(w *bufio.Writer, sendBody bool) (err error) {
defer func() {
if r := recover(); r != nil {
err = &ErrBodyStreamWritePanic{
error: fmt.Errorf("panic while writing body stream: %+v", r),
}
}
}()
contentLength := resp.Header.ContentLength()
if contentLength < 0 {
lrSize := limitedReaderSize(resp.bodyStream)
if lrSize >= 0 {
contentLength = int(lrSize)
if int64(contentLength) != lrSize {
contentLength = -1
}
if contentLength >= 0 {
resp.Header.SetContentLength(contentLength)
}
}
}
if contentLength >= 0 {
if err = resp.Header.Write(w); err == nil {
if resp.ImmediateHeaderFlush {
err = w.Flush()
}
if err == nil && sendBody {
err = writeBodyFixedSize(w, resp.bodyStream, int64(contentLength))
}
}
} else {
resp.Header.SetContentLength(-1)
if err = resp.Header.Write(w); err == nil {
if resp.ImmediateHeaderFlush {
err = w.Flush()
}
if err == nil && sendBody {
err = writeBodyChunked(w, resp.bodyStream)
}
if err == nil {
err = resp.Header.writeTrailer(w)
}
}
}
errc := resp.closeBodyStream(err)
if err == nil {
err = errc
}
return err
}
func (req *Request) closeBodyStream() error {
if req.bodyStream == nil {
return nil
}
var err error
if bsc, ok := req.bodyStream.(io.Closer); ok {
err = bsc.Close()
}
if rs, ok := req.bodyStream.(*requestStream); ok {
releaseRequestStream(rs)
}
req.bodyStream = nil
return err
}
func (resp *Response) closeBodyStream(wErr error) error {
if resp.bodyStream == nil {
return nil
}
var err error
if bsc, ok := resp.bodyStream.(io.Closer); ok {
err = bsc.Close()
}
if bsc, ok := resp.bodyStream.(ReadCloserWithError); ok {
err = bsc.CloseWithError(wErr)
}
if bsr, ok := resp.bodyStream.(*requestStream); ok {
releaseRequestStream(bsr)
}
resp.bodyStream = nil
return err
}
// String returns request representation.
//
// Returns error message instead of request representation on error.
//
// Use Write instead of String for performance-critical code.
func (req *Request) String() string {
return getHTTPString(req)
}
// String returns response representation.
//
// Returns error message instead of response representation on error.
//
// Use Write instead of String for performance-critical code.
func (resp *Response) String() string {
return getHTTPString(resp)
}
// SetUserValue stores the given value (arbitrary object)
// under the given key in Request.
//
// The value stored in Request may be obtained by UserValue*.
//
// This functionality may be useful for passing arbitrary values between
// functions involved in request processing.
//
// All the values are removed from Request after returning from the top
// RequestHandler. Additionally, Close method is called on each value
// implementing io.Closer before removing the value from Request.
func (req *Request) SetUserValue(key, value any) {
req.userValues.Set(key, value)
}
// SetUserValueBytes stores the given value (arbitrary object)
// under the given key in Request.
//
// The value stored in Request may be obtained by UserValue*.
//
// This functionality may be useful for passing arbitrary values between
// functions involved in request processing.
//
// All the values stored in Request are deleted after returning from RequestHandler.
func (req *Request) SetUserValueBytes(key []byte, value any) {
req.userValues.SetBytes(key, value)
}
// UserValue returns the value stored via SetUserValue* under the given key.
func (req *Request) UserValue(key any) any {
return req.userValues.Get(key)
}
// UserValueBytes returns the value stored via SetUserValue*
// under the given key.
func (req *Request) UserValueBytes(key []byte) any {
return req.userValues.GetBytes(key)
}
// VisitUserValues calls visitor for each existing userValue with a key that is a string or []byte.
//
// visitor must not retain references to key and value after returning.
// Make key and/or value copies if you need storing them after returning.
func (req *Request) VisitUserValues(visitor func([]byte, any)) {
for i, n := 0, len(req.userValues); i < n; i++ {
kv := &req.userValues[i]
if _, ok := kv.key.(string); ok {
visitor(s2b(kv.key.(string)), kv.value)
}
}
}
// VisitUserValuesAll calls visitor for each existing userValue.
//
// visitor must not retain references to key and value after returning.
// Make key and/or value copies if you need storing them after returning.
func (req *Request) VisitUserValuesAll(visitor func(any, any)) {
for i, n := 0, len(req.userValues); i < n; i++ {
kv := &req.userValues[i]
visitor(kv.key, kv.value)
}
}
// ResetUserValues allows resetting user values from the Request context.
func (req *Request) ResetUserValues() {
req.userValues.Reset()
}
// RemoveUserValue removes the given key and the value under it in Request.
func (req *Request) RemoveUserValue(key any) {
req.userValues.Remove(key)
}
// RemoveUserValueBytes removes the given key and the value under it in Request.
func (req *Request) RemoveUserValueBytes(key []byte) {
req.userValues.RemoveBytes(key)
}
func getHTTPString(hw httpWriter) string {
w := bytebufferpool.Get()
defer bytebufferpool.Put(w)
bw := bufio.NewWriter(w)
if err := hw.Write(bw); err != nil {
return err.Error()
}
if err := bw.Flush(); err != nil {
return err.Error()
}
s := string(w.B)
return s
}
type httpWriter interface {
Write(w *bufio.Writer) error
}
func writeBodyChunked(w *bufio.Writer, r io.Reader) error {
vbuf := copyBufPool.Get()
buf := vbuf.([]byte)
var err error
var n int
for {
n, err = r.Read(buf)
if n == 0 {
if err == nil {
continue
}
if err == io.EOF {
if err = writeChunk(w, buf[:0]); err != nil {
break
}
err = nil
}
break
}
if err = writeChunk(w, buf[:n]); err != nil {
break
}
}
copyBufPool.Put(vbuf)
return err
}
func limitedReaderSize(r io.Reader) int64 {
lr, ok := r.(*io.LimitedReader)
if !ok {
return -1
}
return lr.N
}
func writeBodyFixedSize(w *bufio.Writer, r io.Reader, size int64) error {
if size > maxSmallFileSize {
earlyFlush := false
switch r := r.(type) {
case *os.File:
earlyFlush = true
case *io.LimitedReader:
_, earlyFlush = r.R.(*os.File)
}
if earlyFlush {
// w buffer must be empty for triggering
// sendfile path in bufio.Writer.ReadFrom.
if err := w.Flush(); err != nil {
return err
}
}
}
n, err := copyZeroAlloc(w, r)
if n != size && err == nil {
err = fmt.Errorf("copied %d bytes from body stream instead of %d bytes", n, size)
}
return err
}
// copyZeroAlloc optimizes io.Copy by calling ReadFrom or WriteTo only when
// copying between os.File and net.TCPConn. If the reader has a WriteTo
// method, it uses WriteTo for copying; if the writer has a ReadFrom method,
// it uses ReadFrom for copying. If neither method is available, it gets a
// buffer from sync.Pool to perform the copy.
//
// io.CopyBuffer always uses the WriterTo or ReadFrom interface if it's
// available. However, os.File and net.TCPConn unfortunately have a
// fallback in their WriterTo that calls io.Copy if sendfile isn't possible.
//
// See issue: https://github.com/valyala/fasthttp/issues/1889
//
// sendfile can only be triggered when copying between os.File and net.TCPConn.
// Since the function confirming zero-copy is a private function, we use
// ReadFrom only in this specific scenario. For all other cases, we prioritize
// using our own copyBuffer method.
//
// o: our copyBuffer
// r: readFrom
// w: writeTo
//
// write\read *File *TCPConn writeTo other
// *File o r w o
// *TCPConn w,r o w o
// readFrom r r w r
// other o o w o
//
//nolint:dupword
func copyZeroAlloc(w io.Writer, r io.Reader) (int64, error) {
var readerIsFile, readerIsConn bool
switch r := r.(type) {
case *os.File:
readerIsFile = true
case *net.TCPConn:
readerIsConn = true
case io.WriterTo:
return r.WriteTo(w)
}
switch w := w.(type) {
case *os.File:
if readerIsConn {
return w.ReadFrom(r)
}
case *net.TCPConn:
if readerIsFile {
// net.WriteTo requires go1.22 or later
// Benchmark tests show that on Windows, WriteTo performs
// significantly better than ReadFrom. On Linux, however,
// ReadFrom slightly outperforms WriteTo. When possible,
// copyZeroAlloc aims to perform better than or as well
// as io.Copy, so we use WriteTo whenever possible for
// optimal performance.
if rt, ok := r.(io.WriterTo); ok {
return rt.WriteTo(w)
}
return w.ReadFrom(r)
}
case io.ReaderFrom:
return w.ReadFrom(r)
}
vbuf := copyBufPool.Get()
buf := vbuf.([]byte)
n, err := copyBuffer(w, r, buf)
copyBufPool.Put(vbuf)
return n, err
}
// copyBuffer is rewritten from io.copyBuffer. We do not check if src has a
// WriteTo method, if dst has a ReadFrom method, or if buf is empty.
func copyBuffer(dst io.Writer, src io.Reader, buf []byte) (written int64, err error) {
for {
nr, er := src.Read(buf)
if nr > 0 {
nw, ew := dst.Write(buf[0:nr])
if nw < 0 || nr < nw {
nw = 0
if ew == nil {
ew = errors.New("invalid write result")
}
}
written += int64(nw)
if ew != nil {
err = ew
break
}
if nr != nw {
err = io.ErrShortWrite
break
}
}
if er != nil {
if er != io.EOF {
err = er
}
break
}
}
return written, err
}
var copyBufPool = sync.Pool{
New: func() any {
return make([]byte, 4096)
},
}
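// writeChunk writes b to w using standard chunked transfer encoding framing.
// For b = []byte("hello") the bytes on the wire are:
//
//	5\r\nhello\r\n
//
// The terminating call with len(b) == 0 emits only "0\r\n"; its final CRLF
// is written after the trailer.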
func writeChunk(w *bufio.Writer, b []byte) error {
n := len(b)
if err := writeHexInt(w, n); err != nil {
return err
}
if _, err := w.Write(strCRLF); err != nil {
return err
}
if _, err := w.Write(b); err != nil {
return err
}
// For the terminating chunk (len(b) == 0) the final CRLF is written
// after the trailer, so only non-empty chunks get a CRLF here.
if n > 0 {
if _, err := w.Write(strCRLF); err != nil {
return err
}
}
return w.Flush()
}
// ErrBodyTooLarge is returned if either request or response body exceeds
// the given limit.
var ErrBodyTooLarge = errors.New("body size exceeds the given limit")
func readBody(r *bufio.Reader, contentLength, maxBodySize int, dst []byte) ([]byte, error) {
if maxBodySize > 0 && contentLength > maxBodySize {
return dst, ErrBodyTooLarge
}
return appendBodyFixedSize(r, dst, contentLength)
}
var errChunkedStream = errors.New("chunked stream")
func readBodyWithStreaming(r *bufio.Reader, contentLength, maxBodySize int, dst []byte) (b []byte, err error) {
if contentLength == -1 {
// handled in requestStream.Read()
return b, errChunkedStream
}
dst = dst[:0]
readN := maxBodySize
if readN > contentLength {
readN = contentLength
}
if readN > 8*1024 {
readN = 8 * 1024
}
// A fixed-length pre-read function must be used here; otherwise
// it may read content beyond the request body into areas outside
// the br buffer, which could corrupt the handling of the next
// pipelined request in the buffer, if there is one. This single
// branch replaces the original two branches and fixes
// https://github.com/valyala/fasthttp/issues/1816
b, err = appendBodyFixedSize(r, dst, readN)
if err != nil {
return b, err
}
if contentLength > maxBodySize {
return b, ErrBodyTooLarge
}
return b, nil
}
func readBodyIdentity(r *bufio.Reader, maxBodySize int, dst []byte) ([]byte, error) {
dst = dst[:cap(dst)]
if len(dst) == 0 {
dst = make([]byte, 1024)
}
offset := 0
for {
nn, err := r.Read(dst[offset:])
if nn <= 0 {
switch {
case errors.Is(err, io.EOF):
return dst[:offset], nil
case err != nil:
return dst[:offset], err
default:
return dst[:offset], fmt.Errorf("bufio.Read() returned (%d, nil)", nn)
}
}
offset += nn
if maxBodySize > 0 && offset > maxBodySize {
return dst[:offset], ErrBodyTooLarge
}
if len(dst) == offset {
n := roundUpForSliceCap(2 * offset)
if maxBodySize > 0 && n > maxBodySize {
n = maxBodySize + 1
}
b := make([]byte, n)
copy(b, dst)
dst = b
}
}
}
func appendBodyFixedSize(r *bufio.Reader, dst []byte, n int) ([]byte, error) {
if n == 0 {
return dst, nil
}
offset := len(dst)
dstLen := offset + n
if cap(dst) < dstLen {
b := make([]byte, roundUpForSliceCap(dstLen))
copy(b, dst)
dst = b
}
dst = dst[:dstLen]
for {
nn, err := r.Read(dst[offset:])
if nn <= 0 {
switch {
case errors.Is(err, io.EOF):
return dst[:offset], io.ErrUnexpectedEOF
case err != nil:
return dst[:offset], err
default:
return dst[:offset], fmt.Errorf("bufio.Read() returned (%d, nil)", nn)
}
}
offset += nn
if offset == dstLen {
return dst, nil
}
}
}
// ErrBrokenChunk is returned when the server receives a broken chunked body (Transfer-Encoding: chunked).
type ErrBrokenChunk struct {
error
}
func readBodyChunked(r *bufio.Reader, maxBodySize int, dst []byte) ([]byte, error) {
if len(dst) > 0 {
// Data integrity might be in danger: we have no idea what is already
// in dst, and appending chunked data to it would corrupt the body.
panic("BUG: expected zero-length buffer")
}
strCRLFLen := len(strCRLF)
for {
chunkSize, err := parseChunkSize(r)
if err != nil {
return dst, err
}
if chunkSize == 0 {
return dst, err
}
if maxBodySize > 0 && len(dst)+chunkSize > maxBodySize {
return dst, ErrBodyTooLarge
}
dst, err = appendBodyFixedSize(r, dst, chunkSize+strCRLFLen)
if err != nil {
return dst, err
}
if !bytes.Equal(dst[len(dst)-strCRLFLen:], strCRLF) {
return dst, ErrBrokenChunk{
error: errors.New("cannot find crlf at the end of chunk"),
}
}
dst = dst[:len(dst)-strCRLFLen]
}
}
func parseChunkSize(r *bufio.Reader) (int, error) {
n, err := readHexInt(r)
if err != nil {
return -1, err
}
for {
c, err := r.ReadByte()
if err != nil {
return -1, ErrBrokenChunk{
error: fmt.Errorf("cannot read '\\r' char at the end of chunk size: %w", err),
}
}
// Skip chunk extension after chunk size.
// Add support later if anyone needs it.
if c != '\r' {
// Security: Don't allow newlines in chunk extensions.
// This can lead to request smuggling issues with some reverse proxies.
if c == '\n' {
return -1, ErrBrokenChunk{
error: errors.New("invalid character '\\n' after chunk size"),
}
}
continue
}
if err := r.UnreadByte(); err != nil {
return -1, ErrBrokenChunk{
error: fmt.Errorf("cannot unread '\\r' char at the end of chunk size: %w", err),
}
}
break
}
err = readCrLf(r)
if err != nil {
return -1, err
}
return n, nil
}
func readCrLf(r *bufio.Reader) error {
for _, exp := range []byte{'\r', '\n'} {
c, err := r.ReadByte()
if err != nil {
return ErrBrokenChunk{
error: fmt.Errorf("cannot read %q char at the end of chunk size: %w", exp, err),
}
}
if c != exp {
return ErrBrokenChunk{
error: fmt.Errorf("unexpected char %q at the end of chunk size. Expected %q", c, exp),
}
}
}
return nil
}
// SetTimeout sets timeout for the request.
//
// The following code:
//
// req.SetTimeout(t)
// c.Do(&req, &resp)
//
// is equivalent to
//
// c.DoTimeout(&req, &resp, t)
func (req *Request) SetTimeout(t time.Duration) {
req.timeout = t
}
package fasthttp
import (
"sync"
"sync/atomic"
"time"
)
// BalancingClient is the interface for clients, which may be passed
// to LBClient.Clients.
type BalancingClient interface {
DoDeadline(req *Request, resp *Response, deadline time.Time) error
PendingRequests() int
}
// LBClient balances requests among available LBClient.Clients.
//
// It has the following features:
//
// - Balances load among available clients using 'least loaded' + 'least total'
// hybrid technique.
// - Dynamically decreases load on unhealthy clients.
//
// It is forbidden copying LBClient instances. Create new instances instead.
//
// It is safe calling LBClient methods from concurrently running goroutines.
type LBClient struct {
noCopy noCopy
// HealthCheck is a callback called after each request.
//
// The request, response and the error returned by the client
// are passed to HealthCheck, so the callback may determine whether
// the client is healthy.
//
// Load on the current client is decreased if HealthCheck returns false.
//
// By default HealthCheck returns false if err != nil.
HealthCheck func(req *Request, resp *Response, err error) bool
// Clients must contain a non-empty list of clients.
// Incoming requests are balanced among these clients.
Clients []BalancingClient
cs []*lbClient
// Timeout is the request timeout used when calling LBClient.Do.
//
// DefaultLBClientTimeout is used by default.
Timeout time.Duration
mu sync.RWMutex
once sync.Once
}
// DefaultLBClientTimeout is the default request timeout used by LBClient
// when calling LBClient.Do.
//
// The timeout may be overridden via LBClient.Timeout.
const DefaultLBClientTimeout = time.Second
// DoDeadline calls DoDeadline on the least loaded client.
func (cc *LBClient) DoDeadline(req *Request, resp *Response, deadline time.Time) error {
return cc.get().DoDeadline(req, resp, deadline)
}
// DoTimeout calculates deadline and calls DoDeadline on the least loaded client.
func (cc *LBClient) DoTimeout(req *Request, resp *Response, timeout time.Duration) error {
deadline := time.Now().Add(timeout)
return cc.get().DoDeadline(req, resp, deadline)
}
// Do calculates timeout using LBClient.Timeout and calls DoTimeout
// on the least loaded client.
func (cc *LBClient) Do(req *Request, resp *Response) error {
timeout := cc.Timeout
if timeout <= 0 {
timeout = DefaultLBClientTimeout
}
return cc.DoTimeout(req, resp, timeout)
}
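// Usage sketch (backend addresses are hypothetical; HostClient satisfies
// BalancingClient since it provides DoDeadline and PendingRequests):
//
// lb := &fasthttp.LBClient{
// Clients: []fasthttp.BalancingClient{
// &fasthttp.HostClient{Addr: "backend1:80"},
// &fasthttp.HostClient{Addr: "backend2:80"},
// },
// Timeout: 3 * time.Second,
// }
// req := fasthttp.AcquireRequest()
// resp := fasthttp.AcquireResponse()
// req.SetRequestURI("http://backends/some/path")
// err := lb.Do(req, resp)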
func (cc *LBClient) init() {
cc.mu.Lock()
defer cc.mu.Unlock()
if len(cc.Clients) == 0 {
// developer sanity-check
panic("BUG: LBClient.Clients cannot be empty")
}
for _, c := range cc.Clients {
cc.cs = append(cc.cs, &lbClient{
c: c,
healthCheck: cc.HealthCheck,
})
}
}
// AddClient adds a new client to the balanced clients and
// returns the new total number of clients.
func (cc *LBClient) AddClient(c BalancingClient) int {
cc.mu.Lock()
cc.cs = append(cc.cs, &lbClient{
c: c,
healthCheck: cc.HealthCheck,
})
cc.mu.Unlock()
return len(cc.cs)
}
// RemoveClients removes clients using the provided callback.
// If rc returns true, the passed client will be removed.
// Returns the new total number of clients.
func (cc *LBClient) RemoveClients(rc func(BalancingClient) bool) int {
cc.mu.Lock()
n := 0
for idx, cs := range cc.cs {
cc.cs[idx] = nil
if rc(cs.c) {
continue
}
cc.cs[n] = cs
n++
}
cc.cs = cc.cs[:n]
cc.mu.Unlock()
return len(cc.cs)
}
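// Removal sketch (assumes *HostClient values were registered as the
// BalancingClient implementations):
//
// lb.RemoveClients(func(c fasthttp.BalancingClient) bool {
// hc, ok := c.(*fasthttp.HostClient)
// return ok && hc.Addr == "backend2:80" // hypothetical address
// })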
func (cc *LBClient) get() *lbClient {
cc.once.Do(cc.init)
cc.mu.RLock()
cs := cc.cs
minC := cs[0]
minN := minC.PendingRequests()
minT := atomic.LoadUint64(&minC.total)
for _, c := range cs[1:] {
n := c.PendingRequests()
t := atomic.LoadUint64(&c.total)
if n < minN || (n == minN && t < minT) {
minC = c
minN = n
minT = t
}
}
cc.mu.RUnlock()
return minC
}
type lbClient struct {
c BalancingClient
healthCheck func(req *Request, resp *Response, err error) bool
penalty uint32
// total amount of requests handled.
total uint64
}
func (c *lbClient) DoDeadline(req *Request, resp *Response, deadline time.Time) error {
err := c.c.DoDeadline(req, resp, deadline)
if !c.isHealthy(req, resp, err) && c.incPenalty() {
// Penalize the client returning the error, so subsequent requests
// are routed to other clients.
time.AfterFunc(penaltyDuration, c.decPenalty)
} else {
atomic.AddUint64(&c.total, 1)
}
return err
}
func (c *lbClient) PendingRequests() int {
n := c.c.PendingRequests()
m := atomic.LoadUint32(&c.penalty)
return n + int(m)
}
func (c *lbClient) isHealthy(req *Request, resp *Response, err error) bool {
if c.healthCheck == nil {
return err == nil
}
return c.healthCheck(req, resp, err)
}
func (c *lbClient) incPenalty() bool {
m := atomic.AddUint32(&c.penalty, 1)
if m > maxPenalty {
c.decPenalty()
return false
}
return true
}
func (c *lbClient) decPenalty() {
atomic.AddUint32(&c.penalty, ^uint32(0))
}
const (
maxPenalty = 300
penaltyDuration = 3 * time.Second
)
package fasthttp
// Embed this type into a struct, which mustn't be copied,
// so `go vet` gives a warning if this struct is copied.
//
// See https://github.com/golang/go/issues/8005#issuecomment-190753527
// and https://stackoverflow.com/questions/52494458/nocopy-minimal-example for details.
type noCopy struct{}
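// Embedding sketch: since noCopy has Lock/Unlock methods, `go vet`'s
// copylocks check flags any copy of a struct that embeds it:
//
// type T struct {
// noCopy noCopy
// // other fields
// }
//
// t2 := t1 // go vet: assignment copies lock value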
func (*noCopy) Lock() {}
func (*noCopy) Unlock() {}
package fasthttp
import (
"crypto/tls"
"net"
"sync"
)
type perIPConnCounter struct {
perIPConnPool sync.Pool
perIPTLSConnPool sync.Pool
m map[uint32]int
lock sync.Mutex
}
func (cc *perIPConnCounter) Register(ip uint32) int {
cc.lock.Lock()
if cc.m == nil {
cc.m = make(map[uint32]int)
}
n := cc.m[ip] + 1
cc.m[ip] = n
cc.lock.Unlock()
return n
}
func (cc *perIPConnCounter) Unregister(ip uint32) {
cc.lock.Lock()
defer cc.lock.Unlock()
if cc.m == nil {
// developer safeguard
panic("BUG: perIPConnCounter.Register() wasn't called")
}
n := cc.m[ip] - 1
if n < 0 {
n = 0
}
cc.m[ip] = n
}
type perIPConn struct {
net.Conn
perIPConnCounter *perIPConnCounter
ip uint32
lock sync.Mutex
}
type perIPTLSConn struct {
*tls.Conn
perIPConnCounter *perIPConnCounter
ip uint32
lock sync.Mutex
}
func acquirePerIPConn(conn net.Conn, ip uint32, counter *perIPConnCounter) net.Conn {
if tlsConn, ok := conn.(*tls.Conn); ok {
v := counter.perIPTLSConnPool.Get()
if v == nil {
return &perIPTLSConn{
perIPConnCounter: counter,
Conn: tlsConn,
ip: ip,
}
}
c := v.(*perIPTLSConn)
c.Conn = tlsConn
c.ip = ip
return c
}
v := counter.perIPConnPool.Get()
if v == nil {
return &perIPConn{
perIPConnCounter: counter,
Conn: conn,
ip: ip,
}
}
c := v.(*perIPConn)
c.Conn = conn
c.ip = ip
return c
}
func (c *perIPConn) Close() error {
c.lock.Lock()
cc := c.Conn
c.Conn = nil
c.lock.Unlock()
if cc == nil {
return nil
}
err := cc.Close()
c.perIPConnCounter.Unregister(c.ip)
c.perIPConnCounter.perIPConnPool.Put(c)
return err
}
func (c *perIPTLSConn) Close() error {
c.lock.Lock()
cc := c.Conn
c.Conn = nil
c.lock.Unlock()
if cc == nil {
return nil
}
err := cc.Close()
c.perIPConnCounter.Unregister(c.ip)
c.perIPConnCounter.perIPTLSConnPool.Put(c)
return err
}
func getUint32IP(c net.Conn) uint32 {
return ip2uint32(getConnIP4(c))
}
func getConnIP4(c net.Conn) net.IP {
addr := c.RemoteAddr()
ipAddr, ok := addr.(*net.TCPAddr)
if !ok {
return net.IPv4zero
}
return ipAddr.IP.To4()
}
func ip2uint32(ip net.IP) uint32 {
if len(ip) != 4 {
return 0
}
return uint32(ip[0])<<24 | uint32(ip[1])<<16 | uint32(ip[2])<<8 | uint32(ip[3])
}
func uint322ip(ip uint32) net.IP {
b := make([]byte, 4)
b[0] = byte(ip >> 24)
b[1] = byte(ip >> 16)
b[2] = byte(ip >> 8)
b[3] = byte(ip)
return b
}
//go:build amd64 || arm64 || ppc64 || ppc64le || riscv64 || s390x
package fasthttp
func roundUpForSliceCap(n int) int {
if n <= 0 {
return 0
}
// Above 100MB, we don't round up as the overhead is too large.
if n > 100*1024*1024 {
return n
}
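// Smear the highest set bit into all lower bits, then add one: this
// rounds n up to the next power of two, e.g. 5 -> 8, 4096 -> 4096,
// 4097 -> 8192, so growing slices reallocate less often.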
x := uint64(n - 1) // #nosec G115
x |= x >> 1
x |= x >> 2
x |= x >> 4
x |= x >> 8
x |= x >> 16
return int(x + 1) // #nosec G115
}
package fasthttp
import "unsafe"
// s2b converts string to a byte slice without memory allocation.
//
// The returned byte slice must not be modified, since it shares
// the string's underlying memory and strings are immutable.
func s2b(s string) []byte {
return unsafe.Slice(unsafe.StringData(s), len(s))
}
package fasthttp
import (
"bufio"
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"log"
"mime/multipart"
"net"
"os"
"strings"
"sync"
"sync/atomic"
"time"
)
var errNoCertOrKeyProvided = errors.New("cert or key has not been provided")
// ErrAlreadyServing is deprecated.
// Deprecated: ErrAlreadyServing is never returned from Serve. See issue #633.
var ErrAlreadyServing = errors.New("Server is already serving connections")
// ServeConn serves HTTP requests from the given connection
// using the given handler.
//
// ServeConn returns nil if all requests from c are successfully served.
// It returns non-nil error otherwise.
//
// Connection c must immediately propagate all the data passed to Write()
// to the client. Otherwise request processing may hang.
//
// ServeConn closes c before returning.
func ServeConn(c net.Conn, handler RequestHandler) error {
v := serverPool.Get()
if v == nil {
v = &Server{}
}
s := v.(*Server)
s.Handler = handler
err := s.ServeConn(c)
s.Handler = nil
serverPool.Put(v)
return err
}
var serverPool sync.Pool
// Serve serves incoming connections from the given listener
// using the given handler.
//
// Serve blocks until the given listener returns a permanent error.
func Serve(ln net.Listener, handler RequestHandler) error {
s := &Server{
Handler: handler,
}
return s.Serve(ln)
}
// ServeTLS serves HTTPS requests from the given net.Listener
// using the given handler.
//
// certFile and keyFile are paths to TLS certificate and key files.
func ServeTLS(ln net.Listener, certFile, keyFile string, handler RequestHandler) error {
s := &Server{
Handler: handler,
}
return s.ServeTLS(ln, certFile, keyFile)
}
// ServeTLSEmbed serves HTTPS requests from the given net.Listener
// using the given handler.
//
// certData and keyData must contain valid TLS certificate and key data.
func ServeTLSEmbed(ln net.Listener, certData, keyData []byte, handler RequestHandler) error {
s := &Server{
Handler: handler,
}
return s.ServeTLSEmbed(ln, certData, keyData)
}
// ListenAndServe serves HTTP requests from the given TCP addr
// using the given handler.
func ListenAndServe(addr string, handler RequestHandler) error {
s := &Server{
Handler: handler,
}
return s.ListenAndServe(addr)
}
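// Usage sketch: a minimal hello-world server (error handling elided;
// RequestCtx implements io.Writer, so fmt.Fprintf writes the body):
//
// fasthttp.ListenAndServe(":8080", func(ctx *fasthttp.RequestCtx) {
// fmt.Fprintf(ctx, "Hello, world! Requested path is %q", ctx.Path())
// })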
// ListenAndServeUNIX serves HTTP requests from the given UNIX addr
// using the given handler.
//
// The function deletes the existing file at addr before it starts serving.
//
// The server sets the given file mode for the UNIX addr.
func ListenAndServeUNIX(addr string, mode os.FileMode, handler RequestHandler) error {
s := &Server{
Handler: handler,
}
return s.ListenAndServeUNIX(addr, mode)
}
// ListenAndServeTLS serves HTTPS requests from the given TCP addr
// using the given handler.
//
// certFile and keyFile are paths to TLS certificate and key files.
func ListenAndServeTLS(addr, certFile, keyFile string, handler RequestHandler) error {
s := &Server{
Handler: handler,
}
return s.ListenAndServeTLS(addr, certFile, keyFile)
}
// ListenAndServeTLSEmbed serves HTTPS requests from the given TCP addr
// using the given handler.
//
// certData and keyData must contain valid TLS certificate and key data.
func ListenAndServeTLSEmbed(addr string, certData, keyData []byte, handler RequestHandler) error {
s := &Server{
Handler: handler,
}
return s.ListenAndServeTLSEmbed(addr, certData, keyData)
}
// RequestHandler must process incoming requests.
//
// RequestHandler must call ctx.TimeoutError() before returning
// if it keeps references to ctx and/or its members after the return.
// Consider wrapping RequestHandler into TimeoutHandler if response time
// must be limited.
type RequestHandler func(ctx *RequestCtx)
// ServeHandler must process tls.Config.NextProto negotiated requests.
type ServeHandler func(c net.Conn) error
// Server implements HTTP server.
//
// Default Server settings should satisfy the majority of Server users.
// Adjust Server settings only if you really understand the consequences.
//
// It is forbidden copying Server instances. Create new Server instances
// instead.
//
// It is safe to call Server methods from concurrently running goroutines.
type Server struct {
noCopy noCopy
perIPConnCounter perIPConnCounter
ctxPool sync.Pool
readerPool sync.Pool
writerPool sync.Pool
hijackConnPool sync.Pool
// Logger, which is used by RequestCtx.Logger().
//
// By default standard logger from log package is used.
Logger Logger
// Handler for processing incoming requests.
//
// Note that fasthttp performs no `panic` recovery: any `panic` in the
// handler will take down the entire server. Use `recover` inside the
// handler to deal with such situations.
Handler RequestHandler
// ErrorHandler for returning a response in case of an error while receiving or parsing the request.
//
// The following is a non-exhaustive list of errors that can be expected as argument:
// * io.EOF
// * io.ErrUnexpectedEOF
// * ErrGetOnly
// * ErrSmallBuffer
// * ErrBodyTooLarge
// * ErrBrokenChunk
ErrorHandler func(ctx *RequestCtx, err error)
// HeaderReceived is called after receiving the header.
//
// Non-zero RequestConfig field values override the corresponding server defaults.
HeaderReceived func(header *RequestHeader) RequestConfig
// ContinueHandler is called after receiving the Expect 100 Continue Header.
//
// https://www.w3.org/Protocols/rfc2616/rfc2616-sec8.html#sec8.2.3
// https://www.w3.org/Protocols/rfc2616/rfc2616-sec10.html#sec10.1.1
// Using ContinueHandler, a server can decide whether or not
// to read a potentially large request body based on the headers.
//
// By default, request bodies of Expect 100 Continue requests
// are read automatically, as if they were normal requests.
ContinueHandler func(header *RequestHeader) bool
// ConnState specifies an optional callback function that is
// called when a client connection changes state. See the
// ConnState type and associated constants for details.
ConnState func(net.Conn, ConnState)
// TLSConfig optionally provides a TLS configuration for use
// by ServeTLS, ServeTLSEmbed, ListenAndServeTLS, ListenAndServeTLSEmbed,
// AppendCert, AppendCertEmbed and NextProto.
//
// Note that this value is cloned by ServeTLS, ServeTLSEmbed, ListenAndServeTLS
// and ListenAndServeTLSEmbed, so it's not possible to modify the configuration
// with methods like tls.Config.SetSessionTicketKeys.
// To use SetSessionTicketKeys, use Server.Serve with a TLS Listener
// instead.
TLSConfig *tls.Config
// FormValueFunc is used by RequestCtx.FormValue and allows customizing
// the behaviour of the RequestCtx.FormValue function.
//
// NetHttpFormValueFunc gives a FormValueFunc func implementation that is consistent with net/http.
FormValueFunc FormValueFunc
nextProtos map[string]ServeHandler
concurrencyCh chan struct{}
idleConns map[net.Conn]*atomic.Int64
done chan struct{}
// Server name for sending in response headers.
//
// Default server name is used if left blank.
Name string
// We need to know our listeners and idle connections so we can close them in Shutdown().
ln []net.Listener
// The maximum number of concurrent connections the server may serve.
//
// DefaultConcurrency is used if not set.
//
// Concurrency only works if you either call Serve once, or only call
// ServeConn multiple times. It also works with ListenAndServe.
Concurrency int
// Per-connection buffer size for requests' reading.
// This also limits the maximum header size.
//
// Increase this buffer if your clients send multi-KB RequestURIs
// and/or multi-KB headers (for example, BIG cookies).
//
// Default buffer size is used if not set.
ReadBufferSize int
// Per-connection buffer size for responses' writing.
//
// Default buffer size is used if not set.
WriteBufferSize int
// ReadTimeout is the amount of time allowed to read
// the full request including body. The connection's read
// deadline is reset when the connection opens, or for
// keep-alive connections after the first byte has been read.
//
// By default request read timeout is unlimited.
ReadTimeout time.Duration
// WriteTimeout is the maximum duration before timing out
// writes of the response. It is reset after the request handler
// has returned.
//
// By default response write timeout is unlimited.
WriteTimeout time.Duration
// IdleTimeout is the maximum amount of time to wait for the
// next request when keep-alive is enabled. If IdleTimeout
// is zero, the value of ReadTimeout is used.
IdleTimeout time.Duration
// Maximum number of concurrent client connections allowed per IP.
//
// By default unlimited number of concurrent connections
// may be established to the server from a single IP address.
MaxConnsPerIP int
// Maximum number of requests served per connection.
//
// The server closes connection after the last request.
// 'Connection: close' header is added to the last response.
//
// By default unlimited number of requests may be served per connection.
MaxRequestsPerConn int
// MaxKeepaliveDuration is a no-op and only left here for backwards compatibility.
// Deprecated: Use IdleTimeout instead.
MaxKeepaliveDuration time.Duration
// MaxIdleWorkerDuration is the maximum idle time of a single worker in the underlying
// worker pool of the Server. Idle workers beyond this time will be cleared.
MaxIdleWorkerDuration time.Duration
// Period between tcp keep-alive messages.
//
// TCP keep-alive period is determined by the operating system by default.
TCPKeepalivePeriod time.Duration
// Maximum request body size.
//
// The server rejects requests with bodies exceeding this limit.
//
// Request body size is limited by DefaultMaxRequestBodySize by default.
MaxRequestBodySize int
// SleepWhenConcurrencyLimitsExceeded is the duration to sleep if the
// concurrency limit is exceeded (default, when 0: don't sleep
// and accept new connections immediately).
SleepWhenConcurrencyLimitsExceeded time.Duration
idleConnsMu sync.Mutex
mu sync.Mutex
concurrency uint32
open int32
stop int32
rejectedRequestsCount uint32
// Whether to disable keep-alive connections.
//
// The server will close all the incoming connections after sending
// the first response to client if this option is set to true.
//
// By default keep-alive connections are enabled.
DisableKeepalive bool
// Whether to enable tcp keep-alive connections.
//
// Whether the operating system should send tcp keep-alive messages on the tcp connection.
//
// By default tcp keep-alive connections are disabled.
TCPKeepalive bool
// Aggressively reduces memory usage at the cost of higher CPU usage
// if set to true.
//
// Try enabling this option only if the server consumes too much memory
// serving mostly idle keep-alive connections. This may reduce memory
// usage by more than 50%.
//
// Aggressive memory usage reduction is disabled by default.
ReduceMemoryUsage bool
// Rejects all non-GET requests if set to true.
//
// This option is useful as anti-DoS protection for servers
// accepting only GET requests and HEAD requests. The request size is limited
// by ReadBufferSize if GetOnly is set.
//
// Server accepts all the requests by default.
GetOnly bool
// Will not pre-parse multipart form data if set to true.
//
// This option is useful for servers that desire to treat
// multipart form data as a binary blob, or choose when to parse the data.
//
// Server pre-parses multipart form data by default.
DisablePreParseMultipartForm bool
// Logs all errors, including the most frequent
// 'connection reset by peer', 'broken pipe' and 'connection timeout'
// errors. Such errors are common in production serving real-world
// clients.
//
// By default the most frequent errors such as
// 'connection reset by peer', 'broken pipe' and 'connection timeout'
// are suppressed in order to limit output log traffic.
LogAllErrors bool
// Will not log potentially sensitive content in error logs if set to true.
//
// This option is useful for servers that handle sensitive data
// in the request/response.
//
// Server logs all full errors by default.
SecureErrorLogMessage bool
// Header names are passed as-is without normalization
// if this option is set.
//
// Disabling header name normalization may be useful only for proxying
// incoming requests to other servers expecting case-sensitive
// header names. See https://github.com/valyala/fasthttp/issues/57
// for details.
//
// By default request and response header names are normalized, i.e.
// the first letter and the first letters following dashes
// are uppercased, while all the other letters are lowercased.
// Examples:
//
// * HOST -> Host
// * content-type -> Content-Type
// * cONTENT-lenGTH -> Content-Length
DisableHeaderNamesNormalizing bool
// NoDefaultServerHeader, when set to true, causes the default Server header
// to be excluded from the Response.
//
// The default Server header value is the value of the Name field or an
// internal default value in its absence. With this option set to true,
// the only time a Server header will be sent is if a non-zero length
// value is explicitly provided during a request.
NoDefaultServerHeader bool
// NoDefaultDate, when set to true, causes the default Date
// header to be excluded from the Response.
//
// The default Date header value is the current date value. When
// set to true, the Date will not be present.
NoDefaultDate bool
// NoDefaultContentType, when set to true, causes the default Content-Type
// header to be excluded from the Response.
//
// The default Content-Type header value is the internal default value. When
// set to true, the Content-Type will not be present.
NoDefaultContentType bool
// KeepHijackedConns is an opt-in flag to disable closing the connection
// by fasthttp after the connection's HijackHandler returns.
// This allows saving goroutines, e.g. when fasthttp is used to upgrade
// HTTP connections to WebSocket and the connection is passed to another
// handler, which will close it when needed.
KeepHijackedConns bool
// CloseOnShutdown when true adds a `Connection: close` header when the server is shutting down.
CloseOnShutdown bool
// StreamRequestBody enables request body streaming,
// and calls the handler sooner when the given body is
// larger than the current limit.
StreamRequestBody bool
}
// TimeoutHandler creates a RequestHandler that returns a StatusRequestTimeout
// error with the given msg to the client if h doesn't return within
// the given duration.
//
// The returned handler may return a StatusTooManyRequests error with the given
// msg to the client if more than Server.Concurrency concurrent
// handlers h are running at the moment.
func TimeoutHandler(h RequestHandler, timeout time.Duration, msg string) RequestHandler {
return TimeoutWithCodeHandler(h, timeout, msg, StatusRequestTimeout)
}
// TimeoutWithCodeHandler creates a RequestHandler that returns an error with
// the given msg and status code to the client if h doesn't return within
// the given duration.
//
// The returned handler may return a StatusTooManyRequests error with the given
// msg to the client if more than Server.Concurrency concurrent
// handlers h are running at the moment.
func TimeoutWithCodeHandler(h RequestHandler, timeout time.Duration, msg string, statusCode int) RequestHandler {
if timeout <= 0 {
return h
}
return func(ctx *RequestCtx) {
concurrencyCh := ctx.s.concurrencyCh
select {
case concurrencyCh <- struct{}{}:
default:
ctx.Error(msg, StatusTooManyRequests)
return
}
ch := ctx.timeoutCh
if ch == nil {
ch = make(chan struct{}, 1)
ctx.timeoutCh = ch
}
go func() {
h(ctx)
ch <- struct{}{}
<-concurrencyCh
}()
ctx.timeoutTimer = initTimer(ctx.timeoutTimer, timeout)
select {
case <-ch:
case <-ctx.timeoutTimer.C:
ctx.TimeoutErrorWithCode(msg, statusCode)
}
stopTimer(ctx.timeoutTimer)
}
}
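// Usage sketch: clients get a 408 with the given message after one second
// instead of waiting for the slow handler (handler body is hypothetical):
//
// slow := func(ctx *fasthttp.RequestCtx) {
// time.Sleep(5 * time.Second) // some time-consuming work
// ctx.WriteString("done")
// }
// h := fasthttp.TimeoutHandler(slow, time.Second, "request timed out")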
// RequestConfig configures the per-request deadlines and body limits.
type RequestConfig struct {
// ReadTimeout is the maximum duration for reading the entire
// request body.
// A zero value means that default values will be honored.
ReadTimeout time.Duration
// WriteTimeout is the maximum duration before timing out
// writes of the response.
// A zero value means that default values will be honored.
WriteTimeout time.Duration
// Maximum request body size.
// A zero value means that default values will be honored.
MaxRequestBodySize int
}
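// Sketch of a HeaderReceived callback raising the body limit for a
// hypothetical upload path only; zero-valued fields keep server defaults:
//
// srv.HeaderReceived = func(h *fasthttp.RequestHeader) fasthttp.RequestConfig {
// var rc fasthttp.RequestConfig
// if bytes.HasPrefix(h.RequestURI(), []byte("/upload")) {
// rc.MaxRequestBodySize = 100 * 1024 * 1024
// }
// return rc
// }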
// CompressHandler returns a RequestHandler that transparently compresses
// the response body generated by h if the request contains a 'gzip',
// 'deflate' or 'zstd' 'Accept-Encoding' header.
func CompressHandler(h RequestHandler) RequestHandler {
return CompressHandlerLevel(h, CompressDefaultCompression)
}
// CompressHandlerLevel returns a RequestHandler that transparently compresses
// the response body generated by h if the request contains a 'gzip',
// 'deflate' or 'zstd' 'Accept-Encoding' header.
//
// Level is the desired compression level:
//
// - CompressNoCompression
// - CompressBestSpeed
// - CompressBestCompression
// - CompressDefaultCompression
// - CompressHuffmanOnly
func CompressHandlerLevel(h RequestHandler, level int) RequestHandler {
return func(ctx *RequestCtx) {
h(ctx)
switch {
case ctx.Request.Header.HasAcceptEncodingBytes(strGzip):
ctx.Response.gzipBody(level)
case ctx.Request.Header.HasAcceptEncodingBytes(strDeflate):
ctx.Response.deflateBody(level)
case ctx.Request.Header.HasAcceptEncodingBytes(strZstd):
ctx.Response.zstdBody(level)
}
}
}
// CompressHandlerBrotliLevel returns a RequestHandler that transparently
// compresses the response body generated by h if the request contains
// a 'br', 'gzip', 'deflate' or 'zstd' 'Accept-Encoding' header.
//
// brotliLevel is the desired compression level for brotli.
//
// - CompressBrotliNoCompression
// - CompressBrotliBestSpeed
// - CompressBrotliBestCompression
// - CompressBrotliDefaultCompression
//
// otherLevel is the desired compression level for gzip, deflate and zstd.
//
// - CompressNoCompression
// - CompressBestSpeed
// - CompressBestCompression
// - CompressDefaultCompression
// - CompressHuffmanOnly
func CompressHandlerBrotliLevel(h RequestHandler, brotliLevel, otherLevel int) RequestHandler {
return func(ctx *RequestCtx) {
h(ctx)
switch {
case ctx.Request.Header.HasAcceptEncodingBytes(strBr):
ctx.Response.brotliBody(brotliLevel)
case ctx.Request.Header.HasAcceptEncodingBytes(strGzip):
ctx.Response.gzipBody(otherLevel)
case ctx.Request.Header.HasAcceptEncodingBytes(strDeflate):
ctx.Response.deflateBody(otherLevel)
case ctx.Request.Header.HasAcceptEncodingBytes(strZstd):
ctx.Response.zstdBody(otherLevel)
}
}
}
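// Usage sketch: compression is opt-in per handler chain, e.g.
//
// h := fasthttp.CompressHandlerBrotliLevel(myHandler, // myHandler assumed defined elsewhere
// fasthttp.CompressBrotliDefaultCompression,
// fasthttp.CompressDefaultCompression)
// fasthttp.ListenAndServe(":8080", h)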
// RequestCtx contains incoming request and manages outgoing response.
//
// Copying RequestCtx instances is forbidden.
//
// RequestHandler should avoid holding references to incoming RequestCtx and/or
// its members after the return.
// If holding RequestCtx references after the return is unavoidable
// (for instance, ctx is passed to a separate goroutine and ctx lifetime cannot
// be controlled), then the RequestHandler MUST call ctx.TimeoutError()
// before return.
//
// It is unsafe to modify/read a RequestCtx instance from concurrently
// running goroutines. The only exception is TimeoutError*, which may be called
// while other goroutines are accessing RequestCtx.
type RequestCtx struct {
noCopy noCopy
// Outgoing response.
//
// Copying Response by value is forbidden. Use pointer to Response instead.
Response Response
connTime time.Time
time time.Time
logger ctxLogger
remoteAddr net.Addr
c net.Conn
s *Server
timeoutResponse *Response
timeoutCh chan struct{}
timeoutTimer *time.Timer
hijackHandler HijackHandler
formValueFunc FormValueFunc
fbr firstByteReader
// Incoming request.
//
// Copying Request by value is forbidden. Use pointer to Request instead.
Request Request
connID uint64
connRequestNum uint64
hijackNoResponse bool
}
// EarlyHints allows the server to hint to the browser what resources a page would need
// so the browser can preload them while waiting for the server's full response. Only Link
// headers already written to the response will be transmitted as Early Hints.
//
// This is an HTTP/2+ feature, but all browsers will either understand it or safely ignore it.
//
// NOTE: Older HTTP/1.1 non-browser clients may face compatibility issues.
//
// See: https://developer.chrome.com/docs/web-platform/early-hints and
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Headers/Link#syntax
//
// Example:
//
// func(ctx *fasthttp.RequestCtx) {
// ctx.Response.Header.Add("Link", "<https://fonts.google.com>; rel=preconnect")
// ctx.EarlyHints()
// time.Sleep(5*time.Second) // some time-consuming task
// ctx.SetStatusCode(fasthttp.StatusOK)
// ctx.SetBody([]byte("<html><head></head><body><h1>Hello from Fasthttp</h1></body></html>"))
// }
func (ctx *RequestCtx) EarlyHints() error {
links := ctx.Response.Header.PeekAll(b2s(strLink))
if len(links) > 0 {
c := acquireWriter(ctx)
defer releaseWriter(ctx.s, c)
_, err := c.Write(strEarlyHints)
if err != nil {
return err
}
for _, l := range links {
if len(l) == 0 {
continue
}
_, err = c.Write(strLink)
if err != nil {
return err
}
_, err = c.Write(strColon)
if err != nil {
return err
}
_, err = c.Write(strSpace)
if err != nil {
return err
}
_, err = c.Write(l)
if err != nil {
return err
}
_, err = c.Write(strCRLF)
if err != nil {
return err
}
}
_, err = c.Write(strCRLF)
if err != nil {
return err
}
err = c.Flush()
if err != nil {
return err
}
}
return nil
}
// HijackHandler must process the hijacked connection c.
//
// If KeepHijackedConns is disabled, which is by default,
// the connection c is automatically closed after returning from HijackHandler.
//
// The connection c must not be used after returning from the handler, if KeepHijackedConns is disabled.
//
// When KeepHijackedConns is enabled, fasthttp will not Close() the connection;
// you must do it yourself when needed. You must not use c in any way after calling Close().
type HijackHandler func(c net.Conn)
// Hijack registers the given handler for connection hijacking.
//
// The handler is called after returning from RequestHandler
// and sending http response. The current connection is passed
// to the handler. The connection is automatically closed after
// returning from the handler.
//
// The server skips calling the handler in the following cases:
//
// - 'Connection: close' header exists in either request or response.
// - Unexpected error during response writing to the connection.
//
// The server stops processing requests from hijacked connections.
//
// Server limits such as Concurrency, ReadTimeout, WriteTimeout, etc.
// aren't applied to hijacked connections.
//
// The handler must not retain references to ctx members.
//
// Arbitrary 'Connection: Upgrade' protocols may be implemented
// with HijackHandler. For instance,
//
// - WebSocket ( https://en.wikipedia.org/wiki/WebSocket )
// - HTTP/2.0 ( https://en.wikipedia.org/wiki/HTTP/2 )
func (ctx *RequestCtx) Hijack(handler HijackHandler) {
ctx.hijackHandler = handler
}
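// Hijack sketch: a toy echo protocol taking over the connection after the
// HTTP response is sent (the connection is closed automatically afterwards
// unless KeepHijackedConns is set):
//
// ctx.Hijack(func(c net.Conn) {
// io.Copy(c, c) // echo until the peer closes the connection
// })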
// HijackSetNoResponse changes the behavior of hijacking a request.
// If HijackSetNoResponse is called with false, fasthttp will send a response
// to the client before calling the HijackHandler (default). If HijackSetNoResponse
// is called with true, no response is sent back before calling the
// HijackHandler supplied in the Hijack function.
func (ctx *RequestCtx) HijackSetNoResponse(noResponse bool) {
ctx.hijackNoResponse = noResponse
}
// Hijacked returns true after Hijack is called.
func (ctx *RequestCtx) Hijacked() bool {
return ctx.hijackHandler != nil
}
// SetUserValue stores the given value (arbitrary object)
// under the given key in Request.
//
// The value stored in Request may be obtained by UserValue*.
//
// This functionality may be useful for passing arbitrary values between
// functions involved in request processing.
//
// All the values are removed from Request after returning from the top
// RequestHandler. Additionally, the Close method is called on each value
// implementing io.Closer before removing the value from Request.
func (ctx *RequestCtx) SetUserValue(key, value any) {
ctx.Request.SetUserValue(key, value)
}
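// Sketch of a hypothetical middleware passing a request-scoped value to
// the next handler, which reads it back via UserValue:
//
// withUser := func(next fasthttp.RequestHandler) fasthttp.RequestHandler {
// return func(ctx *fasthttp.RequestCtx) {
// ctx.SetUserValue("user", "anonymous")
// next(ctx)
// }
// }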
// SetUserValueBytes stores the given value (arbitrary object)
// under the given key in Request.
//
// The value stored in Request may be obtained by UserValue*.
//
// This functionality may be useful for passing arbitrary values between
// functions involved in request processing.
//
// All the values stored in Request are deleted after returning from RequestHandler.
func (ctx *RequestCtx) SetUserValueBytes(key []byte, value any) {
ctx.Request.SetUserValueBytes(key, value)
}
// UserValue returns the value stored via SetUserValue* under the given key.
func (ctx *RequestCtx) UserValue(key any) any {
return ctx.Request.UserValue(key)
}
// UserValueBytes returns the value stored via SetUserValue*
// under the given key.
func (ctx *RequestCtx) UserValueBytes(key []byte) any {
return ctx.Request.UserValueBytes(key)
}
// VisitUserValues calls visitor for each existing userValue with a key that is a string or []byte.
//
// visitor must not retain references to key and value after returning.
// Make key and/or value copies if you need storing them after returning.
func (ctx *RequestCtx) VisitUserValues(visitor func([]byte, any)) {
ctx.Request.VisitUserValues(visitor)
}
// VisitUserValuesAll calls visitor for each existing userValue.
//
// visitor must not retain references to key and value after returning.
// Make key and/or value copies if you need storing them after returning.
func (ctx *RequestCtx) VisitUserValuesAll(visitor func(any, any)) {
ctx.Request.VisitUserValuesAll(visitor)
}
// ResetUserValues resets all user values stored in Request.
func (ctx *RequestCtx) ResetUserValues() {
ctx.Request.ResetUserValues()
}
// RemoveUserValue removes the given key and the value under it in Request.
func (ctx *RequestCtx) RemoveUserValue(key any) {
ctx.Request.RemoveUserValue(key)
}
// RemoveUserValueBytes removes the given key and the value under it in Request.
func (ctx *RequestCtx) RemoveUserValueBytes(key []byte) {
ctx.Request.RemoveUserValueBytes(key)
}
type connTLSer interface {
Handshake() error
ConnectionState() tls.ConnectionState
}
// IsTLS returns true if the underlying connection is tls.Conn.
//
// tls.Conn is an encrypted connection (aka SSL, HTTPS).
func (ctx *RequestCtx) IsTLS() bool {
// cast to (connTLSer) instead of (*tls.Conn), since it catches
// cases with overridden tls.Conn such as:
//
// type customConn struct {
// *tls.Conn
//
// // other custom fields here
// }
// perIPConn wraps the net.Conn in the Conn field
if pic, ok := ctx.c.(*perIPConn); ok {
_, ok := pic.Conn.(connTLSer)
return ok
}
_, ok := ctx.c.(connTLSer)
return ok
}
// TLSConnectionState returns TLS connection state.
//
// The function returns nil if the underlying connection isn't tls.Conn.
//
// The returned state may be used for verifying TLS version, client certificates,
// etc.
func (ctx *RequestCtx) TLSConnectionState() *tls.ConnectionState {
tlsConn, ok := ctx.c.(connTLSer)
if !ok {
return nil
}
state := tlsConn.ConnectionState()
return &state
}
// Conn returns a reference to the underlying net.Conn.
//
// WARNING: Only use this method if you know what you are doing!
//
// Reading from or writing to the returned connection will end badly!
func (ctx *RequestCtx) Conn() net.Conn {
return ctx.c
}
func (ctx *RequestCtx) reset() {
ctx.Request.Reset()
ctx.Response.Reset()
ctx.fbr.reset()
ctx.connID = 0
ctx.connRequestNum = 0
ctx.connTime = zeroTime
ctx.remoteAddr = nil
ctx.time = zeroTime
ctx.c = nil
// Don't reset ctx.s!
// We have a pool per server so the next time this ctx is used it
// will be assigned the same value again.
// ctx might still be in use for context.Done() and context.Err()
// which are safe to use as they only use ctx.s and no other value.
if ctx.timeoutResponse != nil {
ctx.timeoutResponse.Reset()
}
if ctx.timeoutTimer != nil {
stopTimer(ctx.timeoutTimer)
}
ctx.hijackHandler = nil
ctx.hijackNoResponse = false
}
type firstByteReader struct {
c net.Conn
ch byte
byteRead bool
}
func (r *firstByteReader) reset() {
r.c = nil
r.ch = 0
r.byteRead = false
}
func (r *firstByteReader) Read(b []byte) (int, error) {
if len(b) == 0 {
return 0, nil
}
nn := 0
if !r.byteRead {
b[0] = r.ch
b = b[1:]
r.byteRead = true
nn = 1
}
n, err := r.c.Read(b)
return n + nn, err
}
// Logger is used for logging formatted messages.
type Logger interface {
// Printf must have the same semantics as log.Printf.
Printf(format string, args ...any)
}
var ctxLoggerLock sync.Mutex
type ctxLogger struct {
ctx *RequestCtx
logger Logger
}
func (cl *ctxLogger) Printf(format string, args ...any) {
msg := fmt.Sprintf(format, args...)
ctxLoggerLock.Lock()
cl.logger.Printf("%.3f %s - %s", time.Since(cl.ctx.ConnTime()).Seconds(), cl.ctx.String(), msg)
ctxLoggerLock.Unlock()
}
var zeroTCPAddr = &net.TCPAddr{
IP: net.IPv4zero,
}
// String returns unique string representation of the ctx.
//
// The returned value may be useful for logging.
func (ctx *RequestCtx) String() string {
return fmt.Sprintf("#%016X - %s<->%s - %s %s", ctx.ID(), ctx.LocalAddr(), ctx.RemoteAddr(),
ctx.Request.Header.Method(), ctx.URI().FullURI())
}
// ID returns unique ID of the request.
func (ctx *RequestCtx) ID() uint64 {
return (ctx.connID << 32) | ctx.connRequestNum
}
// ConnID returns unique connection ID.
//
// This ID may be used to match distinct requests to the same incoming
// connection.
func (ctx *RequestCtx) ConnID() uint64 {
return ctx.connID
}
// Time returns RequestHandler call time.
func (ctx *RequestCtx) Time() time.Time {
return ctx.time
}
// ConnTime returns the time the server started serving the connection
// the current request came from.
func (ctx *RequestCtx) ConnTime() time.Time {
return ctx.connTime
}
// ConnRequestNum returns request sequence number
// for the current connection.
//
// Sequence starts with 1.
func (ctx *RequestCtx) ConnRequestNum() uint64 {
return ctx.connRequestNum
}
// SetConnectionClose sets 'Connection: close' response header and closes
// connection after the RequestHandler returns.
func (ctx *RequestCtx) SetConnectionClose() {
ctx.Response.SetConnectionClose()
}
// SetStatusCode sets response status code.
func (ctx *RequestCtx) SetStatusCode(statusCode int) {
ctx.Response.SetStatusCode(statusCode)
}
// SetContentType sets response Content-Type.
func (ctx *RequestCtx) SetContentType(contentType string) {
ctx.Response.Header.SetContentType(contentType)
}
// SetContentTypeBytes sets response Content-Type.
//
// It is safe to modify the contentType buffer after the function returns.
func (ctx *RequestCtx) SetContentTypeBytes(contentType []byte) {
ctx.Response.Header.SetContentTypeBytes(contentType)
}
// RequestURI returns RequestURI.
//
// The returned bytes are valid until your request handler returns.
func (ctx *RequestCtx) RequestURI() []byte {
return ctx.Request.Header.RequestURI()
}
// URI returns requested uri.
//
// This uri is valid until your request handler returns.
func (ctx *RequestCtx) URI() *URI {
return ctx.Request.URI()
}
// Referer returns request referer.
//
// The returned bytes are valid until your request handler returns.
func (ctx *RequestCtx) Referer() []byte {
return ctx.Request.Header.Referer()
}
// UserAgent returns User-Agent header value from the request.
//
// The returned bytes are valid until your request handler returns.
func (ctx *RequestCtx) UserAgent() []byte {
return ctx.Request.Header.UserAgent()
}
// Path returns requested path.
//
// The returned bytes are valid until your request handler returns.
func (ctx *RequestCtx) Path() []byte {
return ctx.URI().Path()
}
// Host returns requested host.
//
// The returned bytes are valid until your request handler returns.
func (ctx *RequestCtx) Host() []byte {
return ctx.URI().Host()
}
// QueryArgs returns query arguments from RequestURI.
//
// It doesn't return POST'ed arguments - use PostArgs() for this.
//
// See also PostArgs, FormValue and FormFile.
//
// These args are valid until your request handler returns.
func (ctx *RequestCtx) QueryArgs() *Args {
return ctx.URI().QueryArgs()
}
// PostArgs returns POST arguments.
//
// It doesn't return query arguments from RequestURI - use QueryArgs for this.
//
// See also QueryArgs, FormValue and FormFile.
//
// These args are valid until your request handler returns.
func (ctx *RequestCtx) PostArgs() *Args {
return ctx.Request.PostArgs()
}
// MultipartForm returns request's multipart form.
//
// Returns ErrNoMultipartForm if request's content-type
// isn't 'multipart/form-data'.
//
// All uploaded temporary files are automatically deleted after
// returning from RequestHandler. Either move or copy uploaded files
// to a new place if you want to retain them.
//
// Use the SaveMultipartFile function to save an uploaded file permanently.
//
// The returned form is valid until your request handler returns.
//
// See also FormFile and FormValue.
func (ctx *RequestCtx) MultipartForm() (*multipart.Form, error) {
return ctx.Request.MultipartForm()
}
// FormFile returns uploaded file associated with the given multipart form key.
//
// The file is automatically deleted after returning from RequestHandler,
// so either move or copy the uploaded file to a new place if you want to retain it.
//
// Use the SaveMultipartFile function to save an uploaded file permanently.
//
// The returned file header is valid until your request handler returns.
func (ctx *RequestCtx) FormFile(key string) (*multipart.FileHeader, error) {
mf, err := ctx.MultipartForm()
if err != nil {
return nil, err
}
if mf.File == nil {
// No files were uploaded at all; report a missing file instead of
// returning a nil *multipart.FileHeader with a nil error.
return nil, ErrMissingFile
}
fhh := mf.File[key]
if fhh == nil {
return nil, ErrMissingFile
}
return fhh[0], nil
}
// ErrMissingFile may be returned from FormFile when there is no uploaded file
// associated with the given multipart form key.
var ErrMissingFile = errors.New("there is no uploaded file associated with the given key")
// SaveMultipartFile saves multipart file fh under the given filename path.
func SaveMultipartFile(fh *multipart.FileHeader, path string) (err error) {
var (
f multipart.File
ff *os.File
)
f, err = fh.Open()
if err != nil {
return
}
var ok bool
if ff, ok = f.(*os.File); ok {
// Windows can't rename files that are opened.
if err = f.Close(); err != nil {
return
}
// If renaming fails we try the normal copying method.
// Renaming could fail if the files are on different devices.
if os.Rename(ff.Name(), path) == nil {
return nil
}
// Reopen f for the code below.
if f, err = fh.Open(); err != nil {
return
}
}
defer func() {
e := f.Close()
if err == nil {
err = e
}
}()
if ff, err = os.Create(path); err != nil {
return
}
defer func() {
e := ff.Close()
if err == nil {
err = e
}
}()
_, err = copyZeroAlloc(ff, f)
return
}
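// Sketch: persist an upload from a hypothetical "file" form field before
// the handler returns, since temporary files are deleted afterwards:
//
// fh, err := ctx.FormFile("file")
// if err == nil {
// err = fasthttp.SaveMultipartFile(fh, "/tmp/uploaded")
// }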
// FormValue returns form value associated with the given key.
//
// The value is searched in the following places:
//
// - Query string.
// - POST or PUT body.
//
// There are more fine-grained methods for obtaining form values:
//
// - QueryArgs for obtaining values from query string.
// - PostArgs for obtaining values from POST or PUT body.
// - MultipartForm for obtaining values from multipart form.
// - FormFile for obtaining uploaded files.
//
// The returned value is valid until your request handler returns.
func (ctx *RequestCtx) FormValue(key string) []byte {
if ctx.formValueFunc != nil {
return ctx.formValueFunc(ctx, key)
}
return defaultFormValue(ctx, key)
}
// FormValueFunc returns the form value associated with the given key for
// the given RequestCtx; it allows customizing RequestCtx.FormValue behaviour.
type FormValueFunc func(*RequestCtx, string) []byte
var (
defaultFormValue = func(ctx *RequestCtx, key string) []byte {
v := ctx.QueryArgs().Peek(key)
if len(v) > 0 {
return v
}
v = ctx.PostArgs().Peek(key)
if len(v) > 0 {
return v
}
mf, err := ctx.MultipartForm()
if err == nil && mf.Value != nil {
vv := mf.Value[key]
if len(vv) > 0 {
return []byte(vv[0])
}
}
return nil
}
// NetHttpFormValueFunc gives consistent behavior with net/http.
// POST and PUT body parameters take precedence over URL query string values.
//
//nolint:staticcheck // backwards compatibility
NetHttpFormValueFunc = func(ctx *RequestCtx, key string) []byte {
v := ctx.PostArgs().Peek(key)
if len(v) > 0 {
return v
}
mf, err := ctx.MultipartForm()
if err == nil && mf.Value != nil {
vv := mf.Value[key]
if len(vv) > 0 {
return []byte(vv[0])
}
}
v = ctx.QueryArgs().Peek(key)
if len(v) > 0 {
return v
}
return nil
}
)
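// Sketch: opting the whole server into net/http-compatible precedence:
//
// srv := &fasthttp.Server{
// Handler: myHandler, // assumed defined elsewhere
// FormValueFunc: fasthttp.NetHttpFormValueFunc,
// }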
// IsGet returns true if request method is GET.
func (ctx *RequestCtx) IsGet() bool {
return ctx.Request.Header.IsGet()
}
// IsPost returns true if request method is POST.
func (ctx *RequestCtx) IsPost() bool {
return ctx.Request.Header.IsPost()
}
// IsPut returns true if request method is PUT.
func (ctx *RequestCtx) IsPut() bool {
return ctx.Request.Header.IsPut()
}
// IsDelete returns true if request method is DELETE.
func (ctx *RequestCtx) IsDelete() bool {
return ctx.Request.Header.IsDelete()
}
// IsConnect returns true if request method is CONNECT.
func (ctx *RequestCtx) IsConnect() bool {
return ctx.Request.Header.IsConnect()
}
// IsOptions returns true if request method is OPTIONS.
func (ctx *RequestCtx) IsOptions() bool {
return ctx.Request.Header.IsOptions()
}
// IsTrace returns true if request method is TRACE.
func (ctx *RequestCtx) IsTrace() bool {
return ctx.Request.Header.IsTrace()
}
// IsPatch returns true if request method is PATCH.
func (ctx *RequestCtx) IsPatch() bool {
return ctx.Request.Header.IsPatch()
}
// Method returns the request method.
//
// Returned value is valid until your request handler returns.
func (ctx *RequestCtx) Method() []byte {
return ctx.Request.Header.Method()
}
// IsHead returns true if request method is HEAD.
func (ctx *RequestCtx) IsHead() bool {
return ctx.Request.Header.IsHead()
}
// RemoteAddr returns client address for the given request.
//
// Always returns non-nil result.
func (ctx *RequestCtx) RemoteAddr() net.Addr {
if ctx.remoteAddr != nil {
return ctx.remoteAddr
}
if ctx.c == nil {
return zeroTCPAddr
}
addr := ctx.c.RemoteAddr()
if addr == nil {
return zeroTCPAddr
}
return addr
}
// SetRemoteAddr sets remote address to the given value.
//
// Set nil value to restore default behaviour for using
// connection remote address.
func (ctx *RequestCtx) SetRemoteAddr(remoteAddr net.Addr) {
ctx.remoteAddr = remoteAddr
}
// LocalAddr returns server address for the given request.
//
// Always returns non-nil result.
func (ctx *RequestCtx) LocalAddr() net.Addr {
if ctx.c == nil {
return zeroTCPAddr
}
addr := ctx.c.LocalAddr()
if addr == nil {
return zeroTCPAddr
}
return addr
}
// RemoteIP returns the client ip the request came from.
//
// Always returns non-nil result.
func (ctx *RequestCtx) RemoteIP() net.IP {
return addrToIP(ctx.RemoteAddr())
}
// LocalIP returns the server ip the request came to.
//
// Always returns non-nil result.
func (ctx *RequestCtx) LocalIP() net.IP {
return addrToIP(ctx.LocalAddr())
}
func addrToIP(addr net.Addr) net.IP {
x, ok := addr.(*net.TCPAddr)
if !ok {
return net.IPv4zero
}
return x.IP
}
// Error sets response status code to the given value and sets response body
// to the given message.
//
// Warning: this will reset the response headers and body already set!
func (ctx *RequestCtx) Error(msg string, statusCode int) {
ctx.Response.Reset()
ctx.SetStatusCode(statusCode)
ctx.SetContentTypeBytes(defaultContentType)
ctx.SetBodyString(msg)
}
// Success sets response Content-Type and body to the given values.
func (ctx *RequestCtx) Success(contentType string, body []byte) {
ctx.SetContentType(contentType)
ctx.SetBody(body)
}
// SuccessString sets response Content-Type and body to the given values.
func (ctx *RequestCtx) SuccessString(contentType, body string) {
ctx.SetContentType(contentType)
ctx.SetBodyString(body)
}
// Redirect sets 'Location: uri' response header and sets the given statusCode.
//
// statusCode must have one of the following values:
//
// - StatusMovedPermanently (301)
// - StatusFound (302)
// - StatusSeeOther (303)
// - StatusTemporaryRedirect (307)
// - StatusPermanentRedirect (308)
//
// All other statusCode values are replaced by StatusFound (302).
//
// The redirect uri may be either absolute or relative to the current
// request uri. Fasthttp will always send an absolute uri back to the client.
// To send a relative uri you can use the following code:
//
// strLocation = []byte("Location") // Put this with your top level var () declarations.
// ctx.Response.Header.SetCanonical(strLocation, []byte("/relative?uri"))
// ctx.Response.SetStatusCode(fasthttp.StatusMovedPermanently)
func (ctx *RequestCtx) Redirect(uri string, statusCode int) {
u := AcquireURI()
ctx.URI().CopyTo(u)
u.Update(uri)
ctx.redirect(u.FullURI(), statusCode)
ReleaseURI(u)
}
// RedirectBytes sets 'Location: uri' response header and sets
// the given statusCode.
//
// statusCode must have one of the following values:
//
// - StatusMovedPermanently (301)
// - StatusFound (302)
// - StatusSeeOther (303)
// - StatusTemporaryRedirect (307)
// - StatusPermanentRedirect (308)
//
// All other statusCode values are replaced by StatusFound (302).
//
// The redirect uri may be either absolute or relative to the current
// request uri. Fasthttp will always send an absolute uri back to the client.
// To send a relative uri you can use the following code:
//
// strLocation = []byte("Location") // Put this with your top level var () declarations.
// ctx.Response.Header.SetCanonical(strLocation, []byte("/relative?uri"))
// ctx.Response.SetStatusCode(fasthttp.StatusMovedPermanently)
func (ctx *RequestCtx) RedirectBytes(uri []byte, statusCode int) {
s := b2s(uri)
ctx.Redirect(s, statusCode)
}
func (ctx *RequestCtx) redirect(uri []byte, statusCode int) {
ctx.Response.Header.setNonSpecial(strLocation, uri)
statusCode = getRedirectStatusCode(statusCode)
ctx.Response.SetStatusCode(statusCode)
}
func getRedirectStatusCode(statusCode int) int {
if statusCode == StatusMovedPermanently || statusCode == StatusFound ||
statusCode == StatusSeeOther || statusCode == StatusTemporaryRedirect ||
statusCode == StatusPermanentRedirect {
return statusCode
}
return StatusFound
}
// SetBody sets response body to the given value.
//
// It is safe to re-use the body argument after the function returns.
func (ctx *RequestCtx) SetBody(body []byte) {
ctx.Response.SetBody(body)
}
// SetBodyString sets response body to the given value.
func (ctx *RequestCtx) SetBodyString(body string) {
ctx.Response.SetBodyString(body)
}
// ResetBody resets response body contents.
func (ctx *RequestCtx) ResetBody() {
ctx.Response.ResetBody()
}
// SendFile sends local file contents from the given path as response body.
//
// This is a shortcut to ServeFile(ctx, path).
//
// SendFile logs all the errors via ctx.Logger.
//
// See also ServeFile, FSHandler and FS.
//
// WARNING: do not pass any user supplied paths to this function!
// WARNING: if path is based on user input users will be able to request
// any file on your filesystem! Use fasthttp.FS with a sane Root instead.
func (ctx *RequestCtx) SendFile(path string) {
ServeFile(ctx, path)
}
// SendFileBytes sends local file contents from the given path as response body.
//
// This is a shortcut to ServeFileBytes(ctx, path).
//
// SendFileBytes logs all the errors via ctx.Logger.
//
// See also ServeFileBytes, FSHandler and FS.
//
// WARNING: do not pass any user supplied paths to this function!
// WARNING: if path is based on user input users will be able to request
// any file on your filesystem! Use fasthttp.FS with a sane Root instead.
func (ctx *RequestCtx) SendFileBytes(path []byte) {
ServeFileBytes(ctx, path)
}
// IfModifiedSince returns true if lastModified exceeds the 'If-Modified-Since'
// value from the request header.
//
// The function also returns true if the 'If-Modified-Since' request header is missing.
func (ctx *RequestCtx) IfModifiedSince(lastModified time.Time) bool {
ifModStr := ctx.Request.Header.peek(strIfModifiedSince)
if len(ifModStr) == 0 {
return true
}
ifMod, err := ParseHTTPDate(ifModStr)
if err != nil {
return true
}
lastModified = lastModified.Truncate(time.Second)
return ifMod.Before(lastModified)
}
// NotModified resets response and sets '304 Not Modified' response status code.
func (ctx *RequestCtx) NotModified() {
ctx.Response.Reset()
ctx.SetStatusCode(StatusNotModified)
}
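// Conditional-response sketch combining IfModifiedSince and NotModified
// (lastModified would come from the resource being served; SetLastModified
// is assumed available on the response header):
//
// if !ctx.IfModifiedSince(lastModified) {
// ctx.NotModified()
// return
// }
// ctx.Response.Header.SetLastModified(lastModified)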
// NotFound resets response and sets '404 Not Found' response status code.
func (ctx *RequestCtx) NotFound() {
ctx.Response.Reset()
ctx.SetStatusCode(StatusNotFound)
ctx.SetBodyString("404 Page not found")
}
// Write writes p into response body.
func (ctx *RequestCtx) Write(p []byte) (int, error) {
ctx.Response.AppendBody(p)
return len(p), nil
}
// WriteString appends s to response body.
func (ctx *RequestCtx) WriteString(s string) (int, error) {
ctx.Response.AppendBodyString(s)
return len(s), nil
}
// PostBody returns POST request body.
//
// The returned bytes are valid until your request handler returns.
func (ctx *RequestCtx) PostBody() []byte {
return ctx.Request.Body()
}
// SetBodyStream sets response body stream and, optionally body size.
//
// bodyStream.Close() is called after finishing reading all body data
// if it implements io.Closer.
//
// If bodySize is >= 0, then bodySize bytes must be provided by bodyStream
// before returning io.EOF.
//
// If bodySize < 0, then bodyStream is read until io.EOF.
//
// See also SetBodyStreamWriter.
func (ctx *RequestCtx) SetBodyStream(bodyStream io.Reader, bodySize int) {
ctx.Response.SetBodyStream(bodyStream, bodySize)
}
// SetBodyStreamWriter registers the given stream writer for populating
// response body.
//
// Access to RequestCtx and/or its members is forbidden from sw.
//
// This function may be used in the following cases:
//
// - if response body is too big (more than 10MB).
// - if response body is streamed from slow external sources.
// - if response body must be streamed to the client in chunks.
// (aka `http server push`).
func (ctx *RequestCtx) SetBodyStreamWriter(sw StreamWriter) {
ctx.Response.SetBodyStreamWriter(sw)
}
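// Streaming sketch: emit the body in chunks without buffering it all
// (sw receives the response's bufio.Writer; Flush pushes data out):
//
// ctx.SetBodyStreamWriter(func(w *bufio.Writer) {
// for i := 0; i < 10; i++ {
// fmt.Fprintf(w, "chunk %d\n", i)
// w.Flush()
// }
// })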
// IsBodyStream returns true if response body is set via SetBodyStream*.
func (ctx *RequestCtx) IsBodyStream() bool {
return ctx.Response.IsBodyStream()
}
// Logger returns logger, which may be used for logging arbitrary
// request-specific messages inside RequestHandler.
//
// Each message logged via returned logger contains request-specific information
// such as request id, request duration, local address, remote address,
// request method and request url.
//
// It is safe re-using returned logger for logging multiple messages
// for the current request.
//
// The returned logger is valid until your request handler returns.
func (ctx *RequestCtx) Logger() Logger {
if ctx.logger.ctx == nil {
ctx.logger.ctx = ctx
}
if ctx.logger.logger == nil {
ctx.logger.logger = ctx.s.logger()
}
return &ctx.logger
}
// TimeoutError sets response status code to StatusRequestTimeout and sets
// body to the given msg.
//
// All response modifications after TimeoutError call are ignored.
//
// TimeoutError MUST be called before returning from RequestHandler if
// references to ctx and/or its members remain in other goroutines.
//
// Usage of this function is discouraged. Prefer eliminating ctx references
// from pending goroutines instead of using this function.
func (ctx *RequestCtx) TimeoutError(msg string) {
ctx.TimeoutErrorWithCode(msg, StatusRequestTimeout)
}
// TimeoutErrorWithCode sets response body to msg and response status
// code to statusCode.
//
// All response modifications after TimeoutErrorWithCode call are ignored.
//
// TimeoutErrorWithCode MUST be called before returning from RequestHandler
// if references to ctx and/or its members remain in other goroutines.
//
// Usage of this function is discouraged. Prefer eliminating ctx references
// from pending goroutines instead of using this function.
func (ctx *RequestCtx) TimeoutErrorWithCode(msg string, statusCode int) {
var resp Response
resp.SetStatusCode(statusCode)
resp.SetBodyString(msg)
ctx.TimeoutErrorWithResponse(&resp)
}
// TimeoutErrorWithResponse marks the ctx as timed out and sends the given
// response to the client.
//
// All ctx modifications after TimeoutErrorWithResponse call are ignored.
//
// TimeoutErrorWithResponse MUST be called before returning from RequestHandler
// if references to ctx and/or its members remain in other goroutines.
//
// Usage of this function is discouraged. Prefer eliminating ctx references
// from pending goroutines instead of using this function.
func (ctx *RequestCtx) TimeoutErrorWithResponse(resp *Response) {
respCopy := &Response{}
resp.CopyTo(respCopy)
ctx.timeoutResponse = respCopy
}
// NextProto registers nph to be invoked when key is negotiated during
// TLS connection establishment.
//
// This function can only be called before the server is started.
func (s *Server) NextProto(key string, nph ServeHandler) {
if s.nextProtos == nil {
s.nextProtos = make(map[string]ServeHandler)
}
s.configTLS()
s.TLSConfig.NextProtos = append(s.TLSConfig.NextProtos, key)
s.nextProtos[key] = nph
}
func (s *Server) getNextProto(c net.Conn) (proto string, err error) {
if tlsConn, ok := c.(connTLSer); ok {
if s.ReadTimeout > 0 {
if err = c.SetReadDeadline(time.Now().Add(s.ReadTimeout)); err != nil {
return
}
}
if s.WriteTimeout > 0 {
if err = c.SetWriteDeadline(time.Now().Add(s.WriteTimeout)); err != nil {
return
}
}
err = tlsConn.Handshake()
if err == nil {
proto = tlsConn.ConnectionState().NegotiatedProtocol
}
}
return
}
// ListenAndServe serves HTTP requests from the given TCP4 addr.
//
// Pass a custom listener to Serve if you need to listen on non-TCP4 media
// such as IPv6.
//
// Accepted connections are configured to enable TCP keep-alives.
func (s *Server) ListenAndServe(addr string) error {
ln, err := net.Listen("tcp4", addr)
if err != nil {
return err
}
return s.Serve(ln)
}
// ListenAndServeUNIX serves HTTP requests from the given UNIX addr.
//
// The function deletes the existing file at addr before it starts serving.
//
// The server sets the given file mode for the UNIX addr.
func (s *Server) ListenAndServeUNIX(addr string, mode os.FileMode) error {
if err := os.Remove(addr); err != nil && !os.IsNotExist(err) {
return fmt.Errorf("unexpected error when trying to remove unix socket file %q: %w", addr, err)
}
ln, err := net.Listen("unix", addr)
if err != nil {
return err
}
if err = os.Chmod(addr, mode); err != nil {
return fmt.Errorf("cannot chmod %#o for %q: %w", mode, addr, err)
}
return s.Serve(ln)
}
// ListenAndServeTLS serves HTTPS requests from the given TCP4 addr.
//
// certFile and keyFile are paths to TLS certificate and key files.
//
// Pass a custom listener to Serve if you need to listen on non-TCP4 media
// such as IPv6.
//
// If the certFile or keyFile has not been provided to the server structure,
// the function will use the previously added TLS configuration.
//
// Accepted connections are configured to enable TCP keep-alives.
func (s *Server) ListenAndServeTLS(addr, certFile, keyFile string) error {
ln, err := net.Listen("tcp4", addr)
if err != nil {
return err
}
return s.ServeTLS(ln, certFile, keyFile)
}
// ListenAndServeTLSEmbed serves HTTPS requests from the given TCP4 addr.
//
// certData and keyData must contain valid TLS certificate and key data.
//
// Pass a custom listener to Serve if you need to listen on non-TCP4 media
// such as IPv6.
//
// If certData or keyData has not been provided to the server structure,
// the function will use the previously added TLS configuration.
//
// Accepted connections are configured to enable TCP keep-alives.
func (s *Server) ListenAndServeTLSEmbed(addr string, certData, keyData []byte) error {
ln, err := net.Listen("tcp4", addr)
if err != nil {
return err
}
return s.ServeTLSEmbed(ln, certData, keyData)
}
// ServeTLS serves HTTPS requests from the given listener.
//
// certFile and keyFile are paths to TLS certificate and key files.
//
// If the certFile or keyFile has not been provided to the server structure,
// the function will use the previously added TLS configuration.
func (s *Server) ServeTLS(ln net.Listener, certFile, keyFile string) error {
s.mu.Lock()
s.configTLS()
configHasCert := len(s.TLSConfig.Certificates) > 0 || s.TLSConfig.GetCertificate != nil
if !configHasCert || certFile != "" || keyFile != "" {
if err := s.AppendCert(certFile, keyFile); err != nil {
s.mu.Unlock()
return err
}
}
s.mu.Unlock()
return s.Serve(
tls.NewListener(ln, s.TLSConfig.Clone()),
)
}
// ServeTLSEmbed serves HTTPS requests from the given listener.
//
// certData and keyData must contain valid TLS certificate and key data.
//
// If certData or keyData has not been provided to the server structure,
// the function will use the previously added TLS configuration.
func (s *Server) ServeTLSEmbed(ln net.Listener, certData, keyData []byte) error {
s.mu.Lock()
s.configTLS()
configHasCert := len(s.TLSConfig.Certificates) > 0 || s.TLSConfig.GetCertificate != nil
if !configHasCert || len(certData) != 0 || len(keyData) != 0 {
if err := s.AppendCertEmbed(certData, keyData); err != nil {
s.mu.Unlock()
return err
}
}
s.mu.Unlock()
return s.Serve(
tls.NewListener(ln, s.TLSConfig.Clone()),
)
}
// AppendCert appends certificate and keyfile to TLS Configuration.
//
// This function allows the programmer to handle multiple domains
// with one server instance. See examples/multidomain.
func (s *Server) AppendCert(certFile, keyFile string) error {
if certFile == "" && keyFile == "" {
return errNoCertOrKeyProvided
}
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
return fmt.Errorf("cannot load TLS key pair from certFile=%q and keyFile=%q: %w", certFile, keyFile, err)
}
s.configTLS()
s.TLSConfig.Certificates = append(s.TLSConfig.Certificates, cert)
return nil
}
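// A minimal multidomain sketch, assuming the certificate files exist: append
// one key pair per domain, then start serving with empty certFile/keyFile so
// the previously added TLS configuration is reused.
//
//	s := &fasthttp.Server{Handler: handler}
//	if err := s.AppendCert("example.com.crt", "example.com.key"); err != nil {
//		log.Fatal(err)
//	}
//	if err := s.AppendCert("example.org.crt", "example.org.key"); err != nil {
//		log.Fatal(err)
//	}
//	log.Fatal(s.ListenAndServeTLS(":443", "", ""))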
// AppendCertEmbed does the same as AppendCert but using in-memory data.
func (s *Server) AppendCertEmbed(certData, keyData []byte) error {
if len(certData) == 0 && len(keyData) == 0 {
return errNoCertOrKeyProvided
}
cert, err := tls.X509KeyPair(certData, keyData)
if err != nil {
return fmt.Errorf("cannot load TLS key pair from the provided certData(%d) and keyData(%d): %w",
len(certData), len(keyData), err)
}
s.configTLS()
s.TLSConfig.Certificates = append(s.TLSConfig.Certificates, cert)
return nil
}
func (s *Server) configTLS() {
if s.TLSConfig == nil {
s.TLSConfig = &tls.Config{}
}
}
// DefaultConcurrency is the maximum number of concurrent connections
// the Server may serve by default (i.e. if Server.Concurrency isn't set).
const DefaultConcurrency = 256 * 1024
// Serve serves incoming connections from the given listener.
//
// Serve blocks until the given listener returns permanent error.
func (s *Server) Serve(ln net.Listener) error {
var lastOverflowErrorTime time.Time
var lastPerIPErrorTime time.Time
maxWorkersCount := s.getConcurrency()
s.mu.Lock()
s.ln = append(s.ln, ln)
if s.done == nil {
s.done = make(chan struct{})
}
if s.concurrencyCh == nil {
s.concurrencyCh = make(chan struct{}, maxWorkersCount)
}
s.mu.Unlock()
wp := &workerPool{
WorkerFunc: s.serveConn,
MaxWorkersCount: maxWorkersCount,
LogAllErrors: s.LogAllErrors,
MaxIdleWorkerDuration: s.MaxIdleWorkerDuration,
Logger: s.logger(),
connState: s.setState,
}
wp.Start()
// Count our waiting to accept a connection as an open connection.
// This way we can't get into a weird state where, just after accepting
// a connection, Shutdown is called and reads open as 0 because the counter
// hasn't been incremented yet.
atomic.AddInt32(&s.open, 1)
defer atomic.AddInt32(&s.open, -1)
for {
c, err := acceptConn(s, ln, &lastPerIPErrorTime)
if err != nil {
wp.Stop()
if err == io.EOF {
return nil
}
return err
}
s.setState(c, StateNew)
atomic.AddInt32(&s.open, 1)
if !wp.Serve(c) {
atomic.AddInt32(&s.open, -1)
atomic.AddUint32(&s.rejectedRequestsCount, 1)
s.writeFastError(c, StatusServiceUnavailable,
"The connection cannot be served because Server.Concurrency limit exceeded")
c.Close()
s.setState(c, StateClosed)
if time.Since(lastOverflowErrorTime) > time.Minute {
s.logger().Printf("The incoming connection cannot be served, because %d concurrent connections are served. "+
"Try increasing Server.Concurrency", maxWorkersCount)
lastOverflowErrorTime = time.Now()
}
// The current server reached its concurrency limit,
// so give other concurrently running servers a chance
// to accept incoming connections on the same address.
//
// There is a hope other servers didn't reach their
// concurrency limits yet :)
//
// See also: https://github.com/valyala/fasthttp/pull/485#discussion_r239994990
if s.SleepWhenConcurrencyLimitsExceeded > 0 {
time.Sleep(s.SleepWhenConcurrencyLimitsExceeded)
}
}
}
}
// Shutdown gracefully shuts down the server without interrupting any active connections.
// Shutdown works by first closing all open listeners and then waiting indefinitely for all connections
// to return to idle and then shut down.
//
// When Shutdown is called, Serve, ListenAndServe, and ListenAndServeTLS immediately return nil.
// Make sure the program doesn't exit, but instead waits for Shutdown to return.
//
// Shutdown does not close keepalive connections, so it's recommended to set ReadTimeout and IdleTimeout to something other than 0.
func (s *Server) Shutdown() error {
return s.ShutdownWithContext(context.Background())
}
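// A minimal graceful-shutdown sketch: run the server in a goroutine, block on
// an interrupt signal, then keep the process alive until Shutdown returns.
//
//	s := &fasthttp.Server{Handler: handler, IdleTimeout: time.Minute}
//	go s.ListenAndServe(":8080") //nolint:errcheck
//
//	sig := make(chan os.Signal, 1)
//	signal.Notify(sig, os.Interrupt)
//	<-sig
//
//	// ListenAndServe returns nil as soon as the listener closes;
//	// the process must not exit until Shutdown returns.
//	if err := s.Shutdown(); err != nil {
//		log.Printf("shutdown error: %v", err)
//	}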
// ShutdownWithContext gracefully shuts down the server without interrupting any active connections.
// ShutdownWithContext works by first closing all open listeners and then waiting for all connections to return to idle
// or context timeout and then shut down.
//
// When ShutdownWithContext is called, Serve, ListenAndServe, and ListenAndServeTLS immediately return nil.
// Make sure the program doesn't exit, but instead waits for Shutdown to return.
//
// ShutdownWithContext does not close keepalive connections, so it's recommended to set ReadTimeout and IdleTimeout
// to something other than 0.
//
// If ShutdownWithContext returns an error, the Server is left in an undefined state and further operations on it are not supported.
func (s *Server) ShutdownWithContext(ctx context.Context) (err error) {
s.mu.Lock()
defer s.mu.Unlock()
atomic.StoreInt32(&s.stop, 1)
defer atomic.StoreInt32(&s.stop, 0)
if s.ln == nil {
return nil
}
lnerr := s.closeListenersLocked()
if s.done != nil {
close(s.done)
}
// Closing the listener will make Serve() call Stop on the worker pool.
// Setting .stop to 1 will make serveConn() break out of its loop.
// Now we just have to wait until all workers are done or timeout.
ticker := time.NewTicker(time.Millisecond * 100)
defer ticker.Stop()
for {
s.closeIdleConns()
if open := atomic.LoadInt32(&s.open); open == 0 {
// A pending request may still call ctx.Done(), which returns s.done. Therefore, we only set s.done to nil when open == 0.
s.done = nil
return lnerr
}
// This is not an optimal solution, but using a sync.WaitGroup
// here causes data races, as it's hard to prevent Add() from being
// called while Wait() is waiting.
select {
case <-ctx.Done():
return ctx.Err()
case <-ticker.C:
continue
}
}
}
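// A minimal sketch of bounding the shutdown with a deadline, assuming s is a
// running *fasthttp.Server: if connections don't go idle within 10 seconds,
// ShutdownWithContext returns ctx.Err().
//
//	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
//	defer cancel()
//	if err := s.ShutdownWithContext(ctx); err != nil {
//		log.Printf("forced shutdown: %v", err)
//	}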
type connKeepAliveer interface {
SetKeepAlive(keepalive bool) error
SetKeepAlivePeriod(d time.Duration) error
io.Closer
}
func acceptConn(s *Server, ln net.Listener, lastPerIPErrorTime *time.Time) (net.Conn, error) {
for {
c, err := ln.Accept()
if err != nil {
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
s.logger().Printf("Timeout error when accepting new connections: %v", netErr)
time.Sleep(time.Second)
continue
}
if err != io.EOF && !strings.Contains(err.Error(), "use of closed network connection") {
s.logger().Printf("Permanent error when accepting new connections: %v", err)
return nil, err
}
return nil, io.EOF
}
if tc, ok := c.(connKeepAliveer); ok && s.TCPKeepalive {
if err := tc.SetKeepAlive(s.TCPKeepalive); err != nil {
_ = tc.Close()
return nil, err
}
if s.TCPKeepalivePeriod > 0 {
if err := tc.SetKeepAlivePeriod(s.TCPKeepalivePeriod); err != nil {
_ = tc.Close()
return nil, err
}
}
}
if s.MaxConnsPerIP > 0 {
pic := wrapPerIPConn(s, c)
if pic == nil {
if time.Since(*lastPerIPErrorTime) > time.Minute {
s.logger().Printf("The number of connections from %s exceeds MaxConnsPerIP=%d",
getConnIP4(c), s.MaxConnsPerIP)
*lastPerIPErrorTime = time.Now()
}
continue
}
c = pic
}
return c, nil
}
}
func wrapPerIPConn(s *Server, c net.Conn) net.Conn {
ip := getUint32IP(c)
if ip == 0 {
return c
}
n := s.perIPConnCounter.Register(ip)
if n > s.MaxConnsPerIP {
s.perIPConnCounter.Unregister(ip)
s.writeFastError(c, StatusTooManyRequests, "The number of connections from your ip exceeds MaxConnsPerIP")
c.Close()
return nil
}
return acquirePerIPConn(c, ip, &s.perIPConnCounter)
}
var defaultLogger = Logger(log.New(os.Stderr, "", log.LstdFlags))
func (s *Server) logger() Logger {
if s.Logger != nil {
return s.Logger
}
return defaultLogger
}
var (
// ErrPerIPConnLimit may be returned from ServeConn if the number of connections
// per ip exceeds Server.MaxConnsPerIP.
ErrPerIPConnLimit = errors.New("too many connections per ip")
// ErrConcurrencyLimit may be returned from ServeConn if the number
// of concurrently served connections exceeds Server.Concurrency.
ErrConcurrencyLimit = errors.New("cannot serve the connection because Server.Concurrency concurrent connections are served")
)
// ServeConn serves HTTP requests from the given connection.
//
// ServeConn returns nil if all requests from the c are successfully served.
// It returns non-nil error otherwise.
//
// Connection c must immediately propagate all the data passed to Write()
// to the client. Otherwise requests' processing may hang.
//
// ServeConn closes c before returning.
func (s *Server) ServeConn(c net.Conn) error {
if s.MaxConnsPerIP > 0 {
pic := wrapPerIPConn(s, c)
if pic == nil {
return ErrPerIPConnLimit
}
c = pic
}
n := int(atomic.AddUint32(&s.concurrency, 1)) // #nosec G115
if n > s.getConcurrency() {
atomic.AddUint32(&s.concurrency, ^uint32(0))
s.writeFastError(c, StatusServiceUnavailable, "The connection cannot be served because Server.Concurrency limit exceeded")
c.Close()
return ErrConcurrencyLimit
}
atomic.AddInt32(&s.open, 1)
err := s.serveConn(c)
atomic.AddUint32(&s.concurrency, ^uint32(0))
if err != errHijacked {
errc := c.Close()
s.setState(c, StateClosed)
if err == nil {
err = errc
}
} else {
err = nil
s.setState(c, StateHijacked)
}
return err
}
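// A minimal sketch of a custom accept loop built on ServeConn, assuming s is a
// configured *fasthttp.Server; error handling is elided for brevity.
//
//	ln, err := net.Listen("tcp", ":8080")
//	if err != nil {
//		log.Fatal(err)
//	}
//	for {
//		c, err := ln.Accept()
//		if err != nil {
//			break
//		}
//		go s.ServeConn(c) //nolint:errcheck
//	}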
var errHijacked = errors.New("connection has been hijacked")
// GetCurrentConcurrency returns the number of currently served
// connections.
//
// This function is intended to be used by monitoring systems.
func (s *Server) GetCurrentConcurrency() uint32 {
return atomic.LoadUint32(&s.concurrency)
}
// GetOpenConnectionsCount returns the number of open connections.
//
// This function is intended to be used by monitoring systems.
func (s *Server) GetOpenConnectionsCount() int32 {
if atomic.LoadInt32(&s.stop) == 0 {
// Decrement by one to avoid reporting the extra open value that gets
// counted while the server is listening.
return atomic.LoadInt32(&s.open) - 1
}
// This is not perfect, because s.stop could have changed to zero
// before we load the value of s.open. However, in the common case
// this avoids underreporting open connections by 1 during server shutdown.
return atomic.LoadInt32(&s.open)
}
// GetRejectedConnectionsCount returns the number of rejected connections.
//
// This function is intended to be used by monitoring systems.
func (s *Server) GetRejectedConnectionsCount() uint32 {
return atomic.LoadUint32(&s.rejectedRequestsCount)
}
func (s *Server) getConcurrency() int {
n := s.Concurrency
if n <= 0 {
n = DefaultConcurrency
}
return n
}
var globalConnID uint64
func nextConnID() uint64 {
return atomic.AddUint64(&globalConnID, 1)
}
// DefaultMaxRequestBodySize is the maximum request body size the server
// reads by default.
//
// See Server.MaxRequestBodySize for details.
const DefaultMaxRequestBodySize = 4 * 1024 * 1024
func (s *Server) idleTimeout() time.Duration {
if s.IdleTimeout != 0 {
return s.IdleTimeout
}
return s.ReadTimeout
}
func (s *Server) serveConnCleanup() {
atomic.AddInt32(&s.open, -1)
atomic.AddUint32(&s.concurrency, ^uint32(0))
}
func (s *Server) serveConn(c net.Conn) (err error) {
defer s.serveConnCleanup()
atomic.AddUint32(&s.concurrency, 1)
var proto string
if proto, err = s.getNextProto(c); err != nil {
return
}
if handler, ok := s.nextProtos[proto]; ok {
// Remove read or write deadlines that might have previously been set.
// The next handler is responsible for setting its own deadlines.
if s.ReadTimeout > 0 || s.WriteTimeout > 0 {
if err = c.SetDeadline(zeroTime); err != nil {
return
}
}
return handler(c)
}
s.idleConnsMu.Lock()
if s.idleConns == nil {
s.idleConns = make(map[net.Conn]*atomic.Int64)
}
idleConnTime, ok := s.idleConns[c]
if !ok {
v := idleConnTimePool.Get()
if v == nil {
v = &atomic.Int64{}
}
idleConnTime = v.(*atomic.Int64)
s.idleConns[c] = idleConnTime
}
// Count the connection as Idle after 5 seconds.
// Same as net/http.Server:
// https://github.com/golang/go/blob/85d7bab91d9a3ed1f76842e4328973ea75efef54/src/net/http/server.go#L2834-L2836
idleConnTime.Store(time.Now().Add(time.Second * 5).Unix())
s.idleConnsMu.Unlock()
serverName := s.getServerName()
connRequestNum := uint64(0)
connID := nextConnID()
connTime := time.Now()
maxRequestBodySize := s.MaxRequestBodySize
if maxRequestBodySize <= 0 {
maxRequestBodySize = DefaultMaxRequestBodySize
}
writeTimeout := s.WriteTimeout
previousWriteTimeout := time.Duration(0)
ctx := s.acquireCtx(c)
ctx.connTime = connTime
isTLS := ctx.IsTLS()
var (
br *bufio.Reader
bw *bufio.Writer
timeoutResponse *Response
hijackHandler HijackHandler
hijackNoResponse bool
connectionClose bool
continueReadingRequest = true
)
for {
connRequestNum++
// If this is a keep-alive connection set the idle timeout.
if connRequestNum > 1 {
if d := s.idleTimeout(); d > 0 {
if err = c.SetReadDeadline(time.Now().Add(d)); err != nil {
break
}
}
}
if !s.ReduceMemoryUsage || br != nil {
if br == nil {
br = acquireReader(ctx)
}
// If this is a keep-alive connection we want to try and read the first bytes
// within the idle time.
if connRequestNum > 1 {
var b []byte
b, err = br.Peek(1)
if len(b) == 0 {
// If reading from a keep-alive connection returns nothing it means
// the connection was closed (either timeout or from the other side).
if err != io.EOF {
err = ErrNothingRead{error: err}
}
}
}
} else {
// If this is a keep-alive connection acquireByteReader will try to peek
// a couple of bytes already so the idle timeout will already be used.
br, err = acquireByteReader(&ctx)
}
ctx.Request.isTLS = isTLS
ctx.Response.Header.noDefaultContentType = s.NoDefaultContentType
ctx.Response.Header.noDefaultDate = s.NoDefaultDate
// Secure header error logs configuration
ctx.Request.Header.secureErrorLogMessage = s.SecureErrorLogMessage
ctx.Response.Header.secureErrorLogMessage = s.SecureErrorLogMessage
ctx.Request.secureErrorLogMessage = s.SecureErrorLogMessage
ctx.Response.secureErrorLogMessage = s.SecureErrorLogMessage
if err == nil {
s.setState(c, StateActive)
idleConnTime.Store(0)
if s.ReadTimeout > 0 {
if err = c.SetReadDeadline(time.Now().Add(s.ReadTimeout)); err != nil {
break
}
} else if s.IdleTimeout > 0 && connRequestNum > 1 {
// If this was an idle connection and the server has an IdleTimeout but
// no ReadTimeout then we should remove the ReadTimeout.
if err = c.SetReadDeadline(zeroTime); err != nil {
break
}
}
if s.DisableHeaderNamesNormalizing {
ctx.Request.Header.DisableNormalizing()
ctx.Response.Header.DisableNormalizing()
}
// Reading Headers.
//
// If we have a pipelined response in the outgoing buffer,
// we only want to try reading the next headers once.
// If we have to wait for the next request, we flush the
// outgoing buffer first so the buffered response doesn't have to wait.
if bw != nil && bw.Buffered() > 0 {
err = ctx.Request.Header.readLoop(br, false)
if err == errNeedMore {
err = bw.Flush()
if err != nil {
break
}
err = ctx.Request.Header.Read(br)
}
} else {
err = ctx.Request.Header.Read(br)
}
if err == nil {
if onHdrRecv := s.HeaderReceived; onHdrRecv != nil {
reqConf := onHdrRecv(&ctx.Request.Header)
if reqConf.ReadTimeout > 0 {
deadline := time.Now().Add(reqConf.ReadTimeout)
if err = c.SetReadDeadline(deadline); err != nil {
break
}
}
switch {
case reqConf.MaxRequestBodySize > 0:
maxRequestBodySize = reqConf.MaxRequestBodySize
case s.MaxRequestBodySize > 0:
maxRequestBodySize = s.MaxRequestBodySize
default:
maxRequestBodySize = DefaultMaxRequestBodySize
}
if reqConf.WriteTimeout > 0 {
writeTimeout = reqConf.WriteTimeout
} else {
writeTimeout = s.WriteTimeout
}
}
// read body
if s.StreamRequestBody {
err = ctx.Request.readBodyStream(br, maxRequestBodySize, s.GetOnly, !s.DisablePreParseMultipartForm)
} else {
err = ctx.Request.readLimitBody(br, maxRequestBodySize, s.GetOnly, !s.DisablePreParseMultipartForm)
}
}
// When StreamRequestBody is set to true, we cannot safely release br.
// For example, when using chunked encoding, it's possible that br has only read the request headers.
if (!s.StreamRequestBody && s.ReduceMemoryUsage && br.Buffered() == 0) || err != nil {
releaseReader(s, br)
br = nil
}
}
if err != nil {
if err == io.EOF {
err = nil
} else if nr, ok := err.(ErrNothingRead); ok {
if connRequestNum > 1 {
// This is not the first request and we haven't read a single byte
// of a new request yet. This means it's just a keep-alive connection
// closing down, either because the remote closed it or because of
// a read timeout on our side. Either way, just close the connection
// and don't return any error response.
err = nil
} else {
err = nr.error
}
}
if err != nil {
bw = s.writeErrorResponse(bw, ctx, serverName, err)
}
break
}
// 'Expect: 100-continue' request handling.
// See https://www.w3.org/Protocols/rfc2616/rfc2616-sec8.html#sec8.2.3 for details.
if ctx.Request.MayContinue() {
// Allow the ContinueHandler to deny reading the incoming request body
if s.ContinueHandler != nil {
if continueReadingRequest = s.ContinueHandler(&ctx.Request.Header); !continueReadingRequest {
if br != nil {
br.Reset(ctx.c)
}
ctx.SetStatusCode(StatusExpectationFailed)
}
}
if continueReadingRequest {
if bw == nil {
bw = acquireWriter(ctx)
}
// Send 'HTTP/1.1 100 Continue' response.
_, err = bw.Write(strResponseContinue)
if err != nil {
break
}
err = bw.Flush()
if err != nil {
break
}
if s.ReduceMemoryUsage {
releaseWriter(s, bw)
bw = nil
}
// Read request body.
if br == nil {
br = acquireReader(ctx)
}
if s.StreamRequestBody {
err = ctx.Request.ContinueReadBodyStream(br, maxRequestBodySize, !s.DisablePreParseMultipartForm)
} else {
err = ctx.Request.ContinueReadBody(br, maxRequestBodySize, !s.DisablePreParseMultipartForm)
}
if (!s.StreamRequestBody && s.ReduceMemoryUsage && br.Buffered() == 0) || err != nil {
releaseReader(s, br)
br = nil
}
if err != nil {
bw = s.writeErrorResponse(bw, ctx, serverName, err)
break
}
}
}
// store req.ConnectionClose so even if it was changed inside of handler
connectionClose = s.DisableKeepalive || ctx.Request.Header.ConnectionClose()
if serverName != "" {
ctx.Response.Header.SetServer(serverName)
}
ctx.connID = connID
ctx.connRequestNum = connRequestNum
ctx.time = time.Now()
// If the ContinueHandler denied reading the request body, the handler should not be called
if continueReadingRequest {
s.Handler(ctx)
}
timeoutResponse = ctx.timeoutResponse
if timeoutResponse != nil {
// Acquire a new ctx because the old one will still be in use by the timed out handler.
ctx = s.acquireCtx(c)
timeoutResponse.CopyTo(&ctx.Response)
}
if ctx.IsHead() {
ctx.Response.SkipBody = true
}
hijackHandler = ctx.hijackHandler
ctx.hijackHandler = nil
hijackNoResponse = ctx.hijackNoResponse && hijackHandler != nil
ctx.hijackNoResponse = false
if writeTimeout > 0 {
if err = c.SetWriteDeadline(time.Now().Add(writeTimeout)); err != nil {
break
}
previousWriteTimeout = writeTimeout
} else if previousWriteTimeout > 0 {
// We don't want a write timeout but we previously set one, remove it.
if err = c.SetWriteDeadline(zeroTime); err != nil {
break
}
previousWriteTimeout = 0
}
connectionClose = connectionClose ||
(s.MaxRequestsPerConn > 0 && connRequestNum >= uint64(s.MaxRequestsPerConn)) || // #nosec G115
ctx.Response.Header.ConnectionClose() ||
(s.CloseOnShutdown && atomic.LoadInt32(&s.stop) == 1)
if connectionClose {
ctx.Response.Header.SetConnectionClose()
} else if !ctx.Request.Header.IsHTTP11() {
// Set 'Connection: keep-alive' response header for HTTP/1.0 request.
// There is no need to set this header for HTTP/1.1, since HTTP/1.1
// connections are keep-alive by default.
ctx.Response.Header.setNonSpecial(strConnection, strKeepAlive)
}
if serverName != "" && len(ctx.Response.Header.Server()) == 0 {
ctx.Response.Header.SetServer(serverName)
}
if !hijackNoResponse {
if bw == nil {
bw = acquireWriter(ctx)
}
if err = writeResponse(ctx, bw); err != nil {
break
}
// Only flush the writer if we don't have another request in the pipeline.
// This is a bit of an ugly optimization for https://www.techempower.com/benchmarks/
// This benchmark sends 16 pipelined requests. It is faster to pack as many responses
// into a TCP packet and send them back at once than to wait for a flush after every request.
// In real world circumstances this behaviour could be argued as being wrong.
if br == nil || br.Buffered() == 0 || connectionClose || (s.ReduceMemoryUsage && hijackHandler == nil) {
err = bw.Flush()
if err != nil {
break
}
}
if connectionClose {
break
}
if s.ReduceMemoryUsage && hijackHandler == nil {
releaseWriter(s, bw)
bw = nil
}
}
if hijackHandler != nil {
var hjr io.Reader = c
if br != nil {
hjr = br
br = nil
}
if bw != nil {
err = bw.Flush()
if err != nil {
break
}
releaseWriter(s, bw)
bw = nil
}
err = c.SetDeadline(zeroTime)
if err != nil {
break
}
go hijackConnHandler(ctx, hjr, c, s, hijackHandler)
err = errHijacked
break
}
if ctx.Request.bodyStream != nil {
if rs, ok := ctx.Request.bodyStream.(*requestStream); ok {
releaseRequestStream(rs)
}
ctx.Request.bodyStream = nil
}
s.setState(c, StateIdle)
ctx.Request.Reset()
ctx.Response.Reset()
if atomic.LoadInt32(&s.stop) == 1 {
err = nil
break
}
idleConnTime.Store(time.Now().Unix())
}
if br != nil {
releaseReader(s, br)
}
if bw != nil {
releaseWriter(s, bw)
}
if hijackHandler == nil {
s.releaseCtx(ctx)
}
s.idleConnsMu.Lock()
ic, ok := s.idleConns[c]
if ok {
idleConnTimePool.Put(ic)
delete(s.idleConns, c)
}
s.idleConnsMu.Unlock()
return
}
func (s *Server) setState(nc net.Conn, state ConnState) {
if hook := s.ConnState; hook != nil {
hook(nc, state)
}
}
func hijackConnHandler(ctx *RequestCtx, r io.Reader, c net.Conn, s *Server, h HijackHandler) {
hjc := s.acquireHijackConn(r, c)
h(hjc)
if br, ok := r.(*bufio.Reader); ok {
releaseReader(s, br)
}
if !s.KeepHijackedConns {
c.Close()
s.releaseHijackConn(hjc)
}
s.releaseCtx(ctx)
}
func (s *Server) acquireHijackConn(r io.Reader, c net.Conn) *hijackConn {
v := s.hijackConnPool.Get()
if v == nil {
hjc := &hijackConn{
Conn: c,
r: r,
s: s,
}
return hjc
}
hjc := v.(*hijackConn)
hjc.Conn = c
hjc.r = r
return hjc
}
func (s *Server) releaseHijackConn(hjc *hijackConn) {
hjc.Conn = nil
hjc.r = nil
s.hijackConnPool.Put(hjc)
}
type hijackConn struct {
net.Conn
r io.Reader
s *Server
}
func (c *hijackConn) UnsafeConn() net.Conn {
return c.Conn
}
func (c *hijackConn) Read(p []byte) (int, error) {
return c.r.Read(p)
}
func (c *hijackConn) Close() error {
if !c.s.KeepHijackedConns {
// when we do not keep hijacked connections,
// it is closed in hijackConnHandler.
return nil
}
return c.Conn.Close()
}
// LastTimeoutErrorResponse returns the last timeout response set
// via TimeoutError* call.
//
// This function is intended for custom server implementations.
func (ctx *RequestCtx) LastTimeoutErrorResponse() *Response {
return ctx.timeoutResponse
}
func writeResponse(ctx *RequestCtx, w *bufio.Writer) error {
if ctx.timeoutResponse != nil {
return errors.New("cannot write timed out response")
}
err := ctx.Response.Write(w)
return err
}
const (
defaultReadBufferSize = 4096
defaultWriteBufferSize = 4096
)
func acquireByteReader(ctxP **RequestCtx) (*bufio.Reader, error) {
ctx := *ctxP
s := ctx.s
c := ctx.c
s.releaseCtx(ctx)
//nolint:wastedassign // Make GC happy, so it could garbage collect ctx while we wait for the
// next request.
ctx = nil
*ctxP = nil
var b [1]byte
n, err := c.Read(b[:])
ctx = s.acquireCtx(c)
*ctxP = ctx
if err != nil {
// Treat all errors as EOF on unsuccessful read
// of the first request byte.
return nil, io.EOF
}
if n != 1 {
// developer sanity-check
panic("BUG: Reader must return at least one byte")
}
ctx.fbr.c = c
ctx.fbr.ch = b[0]
ctx.fbr.byteRead = false
r := acquireReader(ctx)
r.Reset(&ctx.fbr)
return r, nil
}
func acquireReader(ctx *RequestCtx) *bufio.Reader {
v := ctx.s.readerPool.Get()
if v == nil {
n := ctx.s.ReadBufferSize
if n <= 0 {
n = defaultReadBufferSize
}
return bufio.NewReaderSize(ctx.c, n)
}
r := v.(*bufio.Reader)
r.Reset(ctx.c)
return r
}
func releaseReader(s *Server, r *bufio.Reader) {
s.readerPool.Put(r)
}
func acquireWriter(ctx *RequestCtx) *bufio.Writer {
v := ctx.s.writerPool.Get()
if v == nil {
n := ctx.s.WriteBufferSize
if n <= 0 {
n = defaultWriteBufferSize
}
return bufio.NewWriterSize(ctx.c, n)
}
w := v.(*bufio.Writer)
w.Reset(ctx.c)
return w
}
func releaseWriter(s *Server, w *bufio.Writer) {
s.writerPool.Put(w)
}
func (s *Server) acquireCtx(c net.Conn) (ctx *RequestCtx) {
v := s.ctxPool.Get()
if v == nil {
keepBodyBuffer := !s.ReduceMemoryUsage
ctx = new(RequestCtx)
ctx.Request.keepBodyBuffer = keepBodyBuffer
ctx.Response.keepBodyBuffer = keepBodyBuffer
ctx.s = s
} else {
ctx = v.(*RequestCtx)
}
if s.FormValueFunc != nil {
ctx.formValueFunc = s.FormValueFunc
}
ctx.c = c
return ctx
}
// Init2 prepares ctx for passing to RequestHandler.
//
// conn is used only for determining local and remote addresses.
//
// This function is intended for custom Server implementations.
// See https://github.com/valyala/httpteleport for details.
func (ctx *RequestCtx) Init2(conn net.Conn, logger Logger, reduceMemoryUsage bool) {
ctx.c = conn
ctx.remoteAddr = nil
ctx.logger.logger = logger
ctx.connID = nextConnID()
ctx.s = fakeServer
ctx.connRequestNum = 0
ctx.connTime = time.Now()
keepBodyBuffer := !reduceMemoryUsage
ctx.Request.keepBodyBuffer = keepBodyBuffer
ctx.Response.keepBodyBuffer = keepBodyBuffer
}
// Init prepares ctx for passing to RequestHandler.
//
// remoteAddr and logger are optional. They are used by RequestCtx.Logger().
//
// This function is intended for custom Server implementations.
func (ctx *RequestCtx) Init(req *Request, remoteAddr net.Addr, logger Logger) {
if remoteAddr == nil {
remoteAddr = zeroTCPAddr
}
c := &fakeAddrer{
laddr: zeroTCPAddr,
raddr: remoteAddr,
}
if logger == nil {
logger = defaultLogger
}
ctx.Init2(c, logger, true)
req.CopyTo(&ctx.Request)
}
// Deadline returns the time when work done on behalf of this context
// should be canceled. Deadline returns ok==false when no deadline is
// set. Successive calls to Deadline return the same results.
//
// This method always returns the zero time and false; it is only present
// to make RequestCtx implement the context interface.
func (ctx *RequestCtx) Deadline() (deadline time.Time, ok bool) {
return
}
// Done returns a channel that's closed when work done on behalf of this
// context should be canceled. Done may return nil if this context can
// never be canceled. Successive calls to Done return the same value.
//
// Note: Because creating a new channel for every request is just too expensive,
// RequestCtx.s.done is only closed when the server is shutting down.
func (ctx *RequestCtx) Done() <-chan struct{} {
return ctx.s.done
}
// Err returns a non-nil error value after Done is closed,
// successive calls to Err return the same error.
// If Done is not yet closed, Err returns nil.
// If Done is closed, Err returns a non-nil error explaining why:
// Canceled if the context was canceled (via server Shutdown)
// or DeadlineExceeded if the context's deadline passed.
//
// Note: Because creating a new channel for every request is just too expensive,
// RequestCtx.s.done is only closed when the server is shutting down.
func (ctx *RequestCtx) Err() error {
select {
case <-ctx.Done():
return context.Canceled
default:
return nil
}
}
// Value returns the value associated with this context for key, or nil
// if no value is associated with key. Successive calls to Value with
// the same key returns the same result.
//
// This method is present to make RequestCtx implement the context interface.
// This method is the same as calling ctx.UserValue(key).
func (ctx *RequestCtx) Value(key any) any {
return ctx.UserValue(key)
}
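// A minimal sketch of passing RequestCtx where a context.Context is expected;
// doWork and the "traceID" key are illustrative only.
//
//	func handler(ctx *fasthttp.RequestCtx) {
//		ctx.SetUserValue("traceID", "abc123")
//		doWork(ctx) // *RequestCtx satisfies context.Context
//	}
//
//	func doWork(c context.Context) {
//		id := c.Value("traceID") // "abc123"
//		_ = id
//	}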
var fakeServer = &Server{
done: make(chan struct{}),
// Initialize concurrencyCh for TimeoutHandler
concurrencyCh: make(chan struct{}, DefaultConcurrency),
}
type fakeAddrer struct {
net.Conn
laddr net.Addr
raddr net.Addr
}
func (fa *fakeAddrer) RemoteAddr() net.Addr {
return fa.raddr
}
func (fa *fakeAddrer) LocalAddr() net.Addr {
return fa.laddr
}
func (fa *fakeAddrer) Read(p []byte) (int, error) {
// developer sanity-check
panic("BUG: unexpected Read call")
}
func (fa *fakeAddrer) Write(p []byte) (int, error) {
// developer sanity-check
panic("BUG: unexpected Write call")
}
func (fa *fakeAddrer) Close() error {
// developer sanity-check
panic("BUG: unexpected Close call")
}
func (s *Server) releaseCtx(ctx *RequestCtx) {
if ctx.timeoutResponse != nil {
// developer sanity-check
panic("BUG: cannot release timed out RequestCtx")
}
ctx.reset()
s.ctxPool.Put(ctx)
}
func (s *Server) getServerName() string {
serverName := s.Name
if serverName == "" {
if !s.NoDefaultServerHeader {
serverName = defaultServerName
}
}
return serverName
}
func (s *Server) writeFastError(w io.Writer, statusCode int, msg string) {
w.Write(formatStatusLine(nil, strHTTP11, statusCode, s2b(StatusMessage(statusCode)))) //nolint:errcheck
server := s.getServerName()
if server != "" {
server = fmt.Sprintf("Server: %s\r\n", server)
}
date := ""
if !s.NoDefaultDate {
serverDateOnce.Do(updateServerDate)
date = fmt.Sprintf("Date: %s\r\n", serverDate.Load())
}
fmt.Fprintf(w, "Connection: close\r\n"+
server+
date+
"Content-Type: text/plain\r\n"+
"Content-Length: %d\r\n"+
"\r\n"+
"%s",
len(msg), msg)
}
func defaultErrorHandler(ctx *RequestCtx, err error) {
if _, ok := err.(*ErrSmallBuffer); ok {
ctx.Error("Too big request header", StatusRequestHeaderFieldsTooLarge)
} else if netErr, ok := err.(*net.OpError); ok && netErr.Timeout() {
ctx.Error("Request timeout", StatusRequestTimeout)
} else {
ctx.Error("Error when parsing request", StatusBadRequest)
}
}
func (s *Server) writeErrorResponse(bw *bufio.Writer, ctx *RequestCtx, serverName string, err error) *bufio.Writer {
errorHandler := defaultErrorHandler
if s.ErrorHandler != nil {
errorHandler = s.ErrorHandler
}
errorHandler(ctx, err)
if serverName != "" {
ctx.Response.Header.SetServer(serverName)
}
ctx.SetConnectionClose()
if bw == nil {
bw = acquireWriter(ctx)
}
writeResponse(ctx, bw) //nolint:errcheck
ctx.Response.Reset()
bw.Flush()
return bw
}
var idleConnTimePool sync.Pool
func (s *Server) closeIdleConns() {
s.idleConnsMu.Lock()
now := time.Now().Unix()
for c, ict := range s.idleConns {
t := ict.Load()
if t != 0 && now-t >= 0 {
_ = c.Close()
delete(s.idleConns, c)
idleConnTimePool.Put(ict)
}
}
s.idleConnsMu.Unlock()
}
func (s *Server) closeListenersLocked() error {
var err error
for _, ln := range s.ln {
if cerr := ln.Close(); cerr != nil && err == nil {
err = cerr
}
}
s.ln = nil
return err
}
// A ConnState represents the state of a client connection to a server.
// It's used by the optional Server.ConnState hook.
type ConnState int
const (
// StateNew represents a new connection that is expected to
// send a request immediately. Connections begin at this
// state and then transition to either StateActive or
// StateClosed.
StateNew ConnState = iota
// StateActive represents a connection that has read 1 or more
// bytes of a request. The Server.ConnState hook for
// StateActive fires before the request has entered a handler
// and doesn't fire again until the request has been
// handled. After the request is handled, the state
// transitions to StateClosed, StateHijacked, or StateIdle.
// For HTTP/2, StateActive fires on the transition from zero
// to one active request, and only transitions away once all
// active requests are complete. That means that ConnState
// cannot be used to do per-request work; ConnState only notes
// the overall state of the connection.
StateActive
// StateIdle represents a connection that has finished
// handling a request and is in the keep-alive state, waiting
// for a new request. Connections transition from StateIdle
// to either StateActive or StateClosed.
StateIdle
// StateHijacked represents a hijacked connection.
// This is a terminal state. It does not transition to StateClosed.
StateHijacked
// StateClosed represents a closed connection.
// This is a terminal state. Hijacked connections do not
// transition to StateClosed.
StateClosed
)
var stateName = []string{
StateNew: "new",
StateActive: "active",
StateIdle: "idle",
StateHijacked: "hijacked",
StateClosed: "closed",
}
func (c ConnState) String() string {
return stateName[c]
}
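// A minimal sketch of observing connection state transitions via the optional
// Server.ConnState hook.
//
//	s := &fasthttp.Server{
//		Handler: handler,
//		ConnState: func(c net.Conn, st fasthttp.ConnState) {
//			log.Printf("%s: %s", c.RemoteAddr(), st)
//		},
//	}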
package fasthttp
import (
"strconv"
)
const (
statusMessageMin = 100
statusMessageMax = 511
)
// HTTP status codes were stolen from net/http.
const (
StatusContinue = 100 // RFC 7231, 6.2.1
StatusSwitchingProtocols = 101 // RFC 7231, 6.2.2
StatusProcessing = 102 // RFC 2518, 10.1
StatusEarlyHints = 103 // RFC 8297
StatusOK = 200 // RFC 7231, 6.3.1
StatusCreated = 201 // RFC 7231, 6.3.2
StatusAccepted = 202 // RFC 7231, 6.3.3
StatusNonAuthoritativeInfo = 203 // RFC 7231, 6.3.4
StatusNoContent = 204 // RFC 7231, 6.3.5
StatusResetContent = 205 // RFC 7231, 6.3.6
StatusPartialContent = 206 // RFC 7233, 4.1
StatusMultiStatus = 207 // RFC 4918, 11.1
StatusAlreadyReported = 208 // RFC 5842, 7.1
StatusIMUsed = 226 // RFC 3229, 10.4.1
StatusMultipleChoices = 300 // RFC 7231, 6.4.1
StatusMovedPermanently = 301 // RFC 7231, 6.4.2
StatusFound = 302 // RFC 7231, 6.4.3
StatusSeeOther = 303 // RFC 7231, 6.4.4
StatusNotModified = 304 // RFC 7232, 4.1
StatusUseProxy = 305 // RFC 7231, 6.4.5
_ = 306 // RFC 7231, 6.4.6 (Unused)
StatusTemporaryRedirect = 307 // RFC 7231, 6.4.7
StatusPermanentRedirect = 308 // RFC 7538, 3
StatusBadRequest = 400 // RFC 7231, 6.5.1
StatusUnauthorized = 401 // RFC 7235, 3.1
StatusPaymentRequired = 402 // RFC 7231, 6.5.2
StatusForbidden = 403 // RFC 7231, 6.5.3
StatusNotFound = 404 // RFC 7231, 6.5.4
StatusMethodNotAllowed = 405 // RFC 7231, 6.5.5
StatusNotAcceptable = 406 // RFC 7231, 6.5.6
StatusProxyAuthRequired = 407 // RFC 7235, 3.2
StatusRequestTimeout = 408 // RFC 7231, 6.5.7
StatusConflict = 409 // RFC 7231, 6.5.8
StatusGone = 410 // RFC 7231, 6.5.9
StatusLengthRequired = 411 // RFC 7231, 6.5.10
StatusPreconditionFailed = 412 // RFC 7232, 4.2
StatusRequestEntityTooLarge = 413 // RFC 7231, 6.5.11
StatusRequestURITooLong = 414 // RFC 7231, 6.5.12
StatusUnsupportedMediaType = 415 // RFC 7231, 6.5.13
StatusRequestedRangeNotSatisfiable = 416 // RFC 7233, 4.4
StatusExpectationFailed = 417 // RFC 7231, 6.5.14
StatusTeapot = 418 // RFC 7168, 2.3.3
StatusMisdirectedRequest = 421 // RFC 7540, 9.1.2
StatusUnprocessableEntity = 422 // RFC 4918, 11.2
StatusLocked = 423 // RFC 4918, 11.3
StatusFailedDependency = 424 // RFC 4918, 11.4
StatusUpgradeRequired = 426 // RFC 7231, 6.5.15
StatusPreconditionRequired = 428 // RFC 6585, 3
StatusTooManyRequests = 429 // RFC 6585, 4
StatusRequestHeaderFieldsTooLarge = 431 // RFC 6585, 5
StatusUnavailableForLegalReasons = 451 // RFC 7725, 3
StatusInternalServerError = 500 // RFC 7231, 6.6.1
StatusNotImplemented = 501 // RFC 7231, 6.6.2
StatusBadGateway = 502 // RFC 7231, 6.6.3
StatusServiceUnavailable = 503 // RFC 7231, 6.6.4
StatusGatewayTimeout = 504 // RFC 7231, 6.6.5
StatusHTTPVersionNotSupported = 505 // RFC 7231, 6.6.6
StatusVariantAlsoNegotiates = 506 // RFC 2295, 8.1
StatusInsufficientStorage = 507 // RFC 4918, 11.5
StatusLoopDetected = 508 // RFC 5842, 7.2
StatusNotExtended = 510 // RFC 2774, 7
StatusNetworkAuthenticationRequired = 511 // RFC 6585, 6
)
var (
unknownStatusCode = "Unknown Status Code"
statusMessages = []string{
StatusContinue: "Continue",
StatusSwitchingProtocols: "Switching Protocols",
StatusProcessing: "Processing",
StatusEarlyHints: "Early Hints",
StatusOK: "OK",
StatusCreated: "Created",
StatusAccepted: "Accepted",
StatusNonAuthoritativeInfo: "Non-Authoritative Information",
StatusNoContent: "No Content",
StatusResetContent: "Reset Content",
StatusPartialContent: "Partial Content",
StatusMultiStatus: "Multi-Status",
StatusAlreadyReported: "Already Reported",
StatusIMUsed: "IM Used",
StatusMultipleChoices: "Multiple Choices",
StatusMovedPermanently: "Moved Permanently",
StatusFound: "Found",
StatusSeeOther: "See Other",
StatusNotModified: "Not Modified",
StatusUseProxy: "Use Proxy",
StatusTemporaryRedirect: "Temporary Redirect",
StatusPermanentRedirect: "Permanent Redirect",
StatusBadRequest: "Bad Request",
StatusUnauthorized: "Unauthorized",
StatusPaymentRequired: "Payment Required",
StatusForbidden: "Forbidden",
StatusNotFound: "Not Found",
StatusMethodNotAllowed: "Method Not Allowed",
StatusNotAcceptable: "Not Acceptable",
StatusProxyAuthRequired: "Proxy Authentication Required",
StatusRequestTimeout: "Request Timeout",
StatusConflict: "Conflict",
StatusGone: "Gone",
StatusLengthRequired: "Length Required",
StatusPreconditionFailed: "Precondition Failed",
StatusRequestEntityTooLarge: "Request Entity Too Large",
StatusRequestURITooLong: "Request URI Too Long",
StatusUnsupportedMediaType: "Unsupported Media Type",
StatusRequestedRangeNotSatisfiable: "Requested Range Not Satisfiable",
StatusExpectationFailed: "Expectation Failed",
StatusTeapot: "I'm a teapot",
StatusMisdirectedRequest: "Misdirected Request",
StatusUnprocessableEntity: "Unprocessable Entity",
StatusLocked: "Locked",
StatusFailedDependency: "Failed Dependency",
StatusUpgradeRequired: "Upgrade Required",
StatusPreconditionRequired: "Precondition Required",
StatusTooManyRequests: "Too Many Requests",
StatusRequestHeaderFieldsTooLarge: "Request Header Fields Too Large",
StatusUnavailableForLegalReasons: "Unavailable For Legal Reasons",
StatusInternalServerError: "Internal Server Error",
StatusNotImplemented: "Not Implemented",
StatusBadGateway: "Bad Gateway",
StatusServiceUnavailable: "Service Unavailable",
StatusGatewayTimeout: "Gateway Timeout",
StatusHTTPVersionNotSupported: "HTTP Version Not Supported",
StatusVariantAlsoNegotiates: "Variant Also Negotiates",
StatusInsufficientStorage: "Insufficient Storage",
StatusLoopDetected: "Loop Detected",
StatusNotExtended: "Not Extended",
StatusNetworkAuthenticationRequired: "Network Authentication Required",
}
)
// StatusMessage returns HTTP status message for the given status code.
func StatusMessage(statusCode int) string {
if statusCode < statusMessageMin || statusCode > statusMessageMax {
return unknownStatusCode
}
if s := statusMessages[statusCode]; s != "" {
return s
}
return unknownStatusCode
}
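// A small illustration of the fallback behaviour implemented above:
//
//	fasthttp.StatusMessage(fasthttp.StatusTeapot) // "I'm a teapot"
//	fasthttp.StatusMessage(99)                    // "Unknown Status Code"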
func formatStatusLine(dst, protocol []byte, statusCode int, statusText []byte) []byte {
dst = append(dst, protocol...)
dst = append(dst, ' ')
dst = strconv.AppendInt(dst, int64(statusCode), 10)
dst = append(dst, ' ')
if len(statusText) == 0 {
dst = append(dst, s2b(StatusMessage(statusCode))...)
} else {
dst = append(dst, statusText...)
}
return append(dst, strCRLF...)
}
package fasthttp
import (
"bufio"
"io"
"sync"
"github.com/valyala/fasthttp/fasthttputil"
)
// StreamWriter must write data to w.
//
// Usually StreamWriter writes data to w in a loop (aka 'data streaming').
//
// StreamWriter must return immediately if w returns an error.
//
// Since the written data is buffered, do not forget to call w.Flush
// when the data must be propagated to the reader.
type StreamWriter func(w *bufio.Writer)
// NewStreamReader returns a reader, which replays all the data generated by sw.
//
// The returned reader may be passed to Response.SetBodyStream.
//
// Close must be called on the returned reader after all the required data
// has been read. Otherwise a goroutine leak may occur.
//
// See also Response.SetBodyStreamWriter.
func NewStreamReader(sw StreamWriter) io.ReadCloser {
pc := fasthttputil.NewPipeConns()
pw := pc.Conn1()
pr := pc.Conn2()
var bw *bufio.Writer
v := streamWriterBufPool.Get()
if v == nil {
bw = bufio.NewWriter(pw)
} else {
bw = v.(*bufio.Writer)
bw.Reset(pw)
}
go func() {
sw(bw)
bw.Flush()
pw.Close()
streamWriterBufPool.Put(bw)
}()
return pr
}
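// A minimal usage sketch: generate the body in a background goroutine and
// consume it through the returned reader, closing it to avoid a goroutine leak.
//
//	r := fasthttp.NewStreamReader(func(w *bufio.Writer) {
//		for i := 0; i < 3; i++ {
//			fmt.Fprintf(w, "chunk %d\n", i)
//			w.Flush() // propagate buffered data to the reader
//		}
//	})
//	data, _ := io.ReadAll(r)
//	r.Close() // required once all needed data has been read
//	fmt.Printf("%s", data)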
var streamWriterBufPool sync.Pool
package fasthttp
import (
"bufio"
"bytes"
"io"
"sync"
"github.com/valyala/bytebufferpool"
)
type headerInterface interface {
ContentLength() int
ReadTrailer(r *bufio.Reader) error
}
type requestStream struct {
header headerInterface
prefetchedBytes *bytes.Reader
reader *bufio.Reader
totalBytesRead int
chunkLeft int
}
func (rs *requestStream) Read(p []byte) (int, error) {
var (
n int
err error
)
if rs.header.ContentLength() == -1 {
if rs.chunkLeft == 0 {
chunkSize, err := parseChunkSize(rs.reader)
if err != nil {
return 0, err
}
if chunkSize == 0 {
err = rs.header.ReadTrailer(rs.reader)
if err != nil && err != io.EOF {
return 0, err
}
return 0, io.EOF
}
rs.chunkLeft = chunkSize
}
bytesToRead := len(p)
if rs.chunkLeft < len(p) {
bytesToRead = rs.chunkLeft
}
n, err = rs.reader.Read(p[:bytesToRead])
rs.totalBytesRead += n
rs.chunkLeft -= n
if err == io.EOF {
err = io.ErrUnexpectedEOF
}
if err == nil && rs.chunkLeft == 0 {
err = readCrLf(rs.reader)
}
return n, err
}
if rs.totalBytesRead == rs.header.ContentLength() {
return 0, io.EOF
}
prefetchedSize := int(rs.prefetchedBytes.Size())
if prefetchedSize > rs.totalBytesRead {
left := prefetchedSize - rs.totalBytesRead
if len(p) > left {
p = p[:left]
}
n, err := rs.prefetchedBytes.Read(p)
rs.totalBytesRead += n
if n == rs.header.ContentLength() {
return n, io.EOF
}
return n, err
}
left := rs.header.ContentLength() - rs.totalBytesRead
if left > 0 && len(p) > left {
p = p[:left]
}
n, err = rs.reader.Read(p)
rs.totalBytesRead += n
if err != nil {
return n, err
}
if rs.totalBytesRead == rs.header.ContentLength() {
err = io.EOF
}
return n, err
}
func acquireRequestStream(b *bytebufferpool.ByteBuffer, r *bufio.Reader, h headerInterface) *requestStream {
rs := requestStreamPool.Get().(*requestStream)
rs.prefetchedBytes = bytes.NewReader(b.B)
rs.reader = r
rs.header = h
return rs
}
func releaseRequestStream(rs *requestStream) {
rs.prefetchedBytes = nil
rs.totalBytesRead = 0
rs.chunkLeft = 0
rs.reader = nil
rs.header = nil
requestStreamPool.Put(rs)
}
var requestStreamPool = sync.Pool{
New: func() any {
return &requestStream{}
},
}
package fasthttp
import (
"context"
"errors"
"fmt"
"net"
"strconv"
"sync"
"sync/atomic"
"time"
)
// Dial dials the given TCP addr using tcp4.
//
// This function has the following additional features compared to net.Dial:
//
// - It reduces load on the DNS resolver by caching resolved TCP addresses
// for DNSCacheDuration.
// - It dials all the resolved TCP addresses in round-robin manner until
// a connection is established. This may be useful if certain addresses
// are temporarily unreachable.
// - It returns ErrDialTimeout if a connection cannot be established within
// DefaultDialTimeout. Use DialTimeout to customize the dial timeout.
//
// This dialer is intended to be wrapped by custom code before being passed
// to Client.Dial or HostClient.Dial.
//
// For instance, per-host counters and/or limits may be implemented
// by such wrappers.
//
// The addr passed to the function must contain a port. Example addr values:
//
// - foobar.baz:443
// - foo.bar:80
// - aaa.com:8080
func Dial(addr string) (net.Conn, error) {
return defaultDialer.Dial(addr)
}
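// A hedged sketch of the wrapping pattern mentioned above: counting dials
// per host before delegating to Dial. dialCounts is illustrative only.
//
//	var dialCounts sync.Map
//
//	c := &fasthttp.Client{
//		Dial: func(addr string) (net.Conn, error) {
//			v, _ := dialCounts.LoadOrStore(addr, new(int64))
//			atomic.AddInt64(v.(*int64), 1)
//			return fasthttp.Dial(addr)
//		},
//	}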
// DialTimeout dials the given TCP addr using tcp4 using the given timeout.
//
// This function has the following additional features compared to net.Dial:
//
// - It reduces load on the DNS resolver by caching resolved TCP addresses
// for DNSCacheDuration.
// - It dials all the resolved TCP addresses in round-robin manner until
// a connection is established. This may be useful if certain addresses
// are temporarily unreachable.
//
// This dialer is intended to be wrapped by custom code before being passed
// to Client.DialTimeout or HostClient.DialTimeout.
//
// For instance, per-host counters and/or limits may be implemented
// by such wrappers.
//
// The addr passed to the function must contain a port. Example addr values:
//
// - foobar.baz:443
// - foo.bar:80
// - aaa.com:8080
func DialTimeout(addr string, timeout time.Duration) (net.Conn, error) {
return defaultDialer.DialTimeout(addr, timeout)
}
// DialDualStack dials the given TCP addr using both tcp4 and tcp6.
//
// This function has the following additional features compared to net.Dial:
//
// - It reduces load on the DNS resolver by caching resolved TCP addresses
// for DNSCacheDuration.
// - It dials all the resolved TCP addresses in round-robin manner until
// a connection is established. This may be useful if certain addresses
// are temporarily unreachable.
// - It returns ErrDialTimeout if a connection cannot be established within
// DefaultDialTimeout. Use DialDualStackTimeout to customize the dial
// timeout.
//
// This dialer is intended to be wrapped by custom code before being passed
// to Client.Dial or HostClient.Dial.
//
// For instance, per-host counters and/or limits may be implemented
// by such wrappers.
//
// The addr passed to the function must contain a port. Example addr values:
//
// - foobar.baz:443
// - foo.bar:80
// - aaa.com:8080
func DialDualStack(addr string) (net.Conn, error) {
return defaultDialer.DialDualStack(addr)
}
// DialDualStackTimeout dials the given TCP addr using both tcp4 and tcp6
// using the given timeout.
//
// This function has the following additional features compared to net.Dial:
//
// - It reduces load on the DNS resolver by caching resolved TCP addresses
// for DNSCacheDuration.
// - It dials all the resolved TCP addresses in round-robin manner until
// a connection is established. This may be useful if certain addresses
// are temporarily unreachable.
//
// This dialer is intended to be wrapped by custom code before being passed
// to Client.DialTimeout or HostClient.DialTimeout.
//
// For instance, per-host counters and/or limits may be implemented
// by such wrappers.
//
// The addr passed to the function must contain a port. Example addr values:
//
// - foobar.baz:443
// - foo.bar:80
// - aaa.com:8080
func DialDualStackTimeout(addr string, timeout time.Duration) (net.Conn, error) {
return defaultDialer.DialDualStackTimeout(addr, timeout)
}
var defaultDialer = &TCPDialer{Concurrency: 1000}
// Resolver represents the interface of a TCP resolver.
type Resolver interface {
LookupIPAddr(context.Context, string) (names []net.IPAddr, err error)
}
// TCPDialer contains options to control a group of Dial calls.
type TCPDialer struct {
// This may be used to override DNS resolving policy, like this:
// var dialer = &fasthttp.TCPDialer{
// Resolver: &net.Resolver{
// PreferGo: true,
// StrictErrors: false,
// Dial: func (ctx context.Context, network, address string) (net.Conn, error) {
// d := net.Dialer{}
// return d.DialContext(ctx, "udp", "8.8.8.8:53")
// },
// },
// }
Resolver Resolver
// LocalAddr is the local address to use when dialing an
// address.
// If nil, a local address is automatically chosen.
LocalAddr *net.TCPAddr
concurrencyCh chan struct{}
tcpAddrsMap sync.Map
// Concurrency controls the maximum number of concurrent Dials
// that can be performed using this object.
// Setting this to 0 means unlimited.
//
// WARNING: This can only be changed before the first Dial.
// Changes made after the first Dial will not affect anything.
Concurrency int
// DNSCacheDuration may be used to override the default DNS cache duration (DefaultDNSCacheDuration)
DNSCacheDuration time.Duration
once sync.Once
// DisableDNSResolution may be used to disable DNS resolution
DisableDNSResolution bool
}
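// A minimal sketch of plugging a tuned TCPDialer into a Client; the values
// shown are illustrative, not recommendations.
//
//	d := &fasthttp.TCPDialer{
//		Concurrency:      1000,
//		DNSCacheDuration: 5 * time.Minute,
//	}
//	c := &fasthttp.Client{
//		Dial: d.Dial,
//	}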
// Dial dials the given TCP addr using tcp4.
//
// This function has the following additional features compared to net.Dial:
//
// - It reduces load on the DNS resolver by caching resolved TCP addresses
// for DNSCacheDuration.
// - It dials all the resolved TCP addresses in round-robin manner until
// a connection is established. This may be useful if certain addresses
// are temporarily unreachable.
// - It returns ErrDialTimeout if a connection cannot be established within
// DefaultDialTimeout. Use DialTimeout to customize the dial timeout.
//
// This dialer is intended to be wrapped by custom code before being passed
// to Client.Dial or HostClient.Dial.
//
// For instance, per-host counters and/or limits may be implemented
// by such wrappers.
//
// The addr passed to the function must contain a port. Example addr values:
//
// - foobar.baz:443
// - foo.bar:80
// - aaa.com:8080
func (d *TCPDialer) Dial(addr string) (net.Conn, error) {
return d.dial(addr, false, DefaultDialTimeout)
}
// DialTimeout dials the given TCP addr using tcp4 using the given timeout.
//
// This function has the following additional features compared to net.Dial:
//
// - It reduces load on the DNS resolver by caching resolved TCP addresses
// for DNSCacheDuration.
// - It dials all the resolved TCP addresses in round-robin manner until
// a connection is established. This may be useful if certain addresses
// are temporarily unreachable.
//
// This dialer is intended to be wrapped by custom code before being passed
// to Client.DialTimeout or HostClient.DialTimeout.
//
// For instance, per-host counters and/or limits may be implemented
// by such wrappers.
//
// The addr passed to the function must contain a port. Example addr values:
//
// - foobar.baz:443
// - foo.bar:80
// - aaa.com:8080
func (d *TCPDialer) DialTimeout(addr string, timeout time.Duration) (net.Conn, error) {
return d.dial(addr, false, timeout)
}
// DialDualStack dials the given TCP addr using both tcp4 and tcp6.
//
// This function has the following additional features compared to net.Dial:
//
// - It reduces load on the DNS resolver by caching resolved TCP addresses
// for DNSCacheDuration.
// - It dials all the resolved TCP addresses in round-robin manner until
// a connection is established. This may be useful if certain addresses
// are temporarily unreachable.
// - It returns ErrDialTimeout if a connection cannot be established within
// DefaultDialTimeout. Use DialDualStackTimeout to customize the dial
// timeout.
//
// This dialer is intended to be wrapped by custom code before being passed
// to Client.Dial or HostClient.Dial.
//
// For instance, per-host counters and/or limits may be implemented
// by such wrappers.
//
// The addr passed to the function must contain a port. Example addr values:
//
// - foobar.baz:443
// - foo.bar:80
// - aaa.com:8080
func (d *TCPDialer) DialDualStack(addr string) (net.Conn, error) {
return d.dial(addr, true, DefaultDialTimeout)
}
// DialDualStackTimeout dials the given TCP addr using both tcp4 and tcp6
// using the given timeout.
//
// This function has the following additional features compared to net.Dial:
//
// - It reduces load on the DNS resolver by caching resolved TCP addresses
// for DNSCacheDuration.
// - It dials all the resolved TCP addresses in round-robin manner until
// a connection is established. This may be useful if certain addresses
// are temporarily unreachable.
//
// This dialer is intended to be wrapped by custom code before being passed
// to Client.DialTimeout or HostClient.DialTimeout.
//
// For instance, per-host counters and/or limits may be implemented
// by such wrappers.
//
// The addr passed to the function must contain a port. Example addr values:
//
// - foobar.baz:443
// - foo.bar:80
// - aaa.com:8080
func (d *TCPDialer) DialDualStackTimeout(addr string, timeout time.Duration) (net.Conn, error) {
return d.dial(addr, true, timeout)
}
func (d *TCPDialer) dial(addr string, dualStack bool, timeout time.Duration) (net.Conn, error) {
d.once.Do(func() {
if d.Concurrency > 0 {
d.concurrencyCh = make(chan struct{}, d.Concurrency)
}
if d.DNSCacheDuration == 0 {
d.DNSCacheDuration = DefaultDNSCacheDuration
}
if !d.DisableDNSResolution {
go d.tcpAddrsClean()
}
})
deadline := time.Now().Add(timeout)
network := "tcp4"
if dualStack {
network = "tcp"
}
if d.DisableDNSResolution {
return d.tryDial(network, addr, deadline, d.concurrencyCh)
}
addrs, idx, err := d.getTCPAddrs(addr, dualStack, deadline)
if err != nil {
return nil, err
}
var conn net.Conn
n := uint32(len(addrs)) // #nosec G115
for range n {
conn, err = d.tryDial(network, addrs[idx%n].String(), deadline, d.concurrencyCh)
if err == nil {
return conn, nil
}
if errors.Is(err, ErrDialTimeout) {
return nil, err
}
idx++
}
return nil, err
}
func (d *TCPDialer) tryDial(
network string, addr string, deadline time.Time, concurrencyCh chan struct{},
) (net.Conn, error) {
timeout := time.Until(deadline)
if timeout <= 0 {
return nil, wrapDialWithUpstream(ErrDialTimeout, addr)
}
if concurrencyCh != nil {
select {
case concurrencyCh <- struct{}{}:
default:
tc := AcquireTimer(timeout)
isTimeout := false
select {
case concurrencyCh <- struct{}{}:
case <-tc.C:
isTimeout = true
}
ReleaseTimer(tc)
if isTimeout {
return nil, wrapDialWithUpstream(ErrDialTimeout, addr)
}
}
defer func() { <-concurrencyCh }()
}
dialer := net.Dialer{}
if d.LocalAddr != nil {
dialer.LocalAddr = d.LocalAddr
}
ctx, cancelCtx := context.WithDeadline(context.Background(), deadline)
defer cancelCtx()
conn, err := dialer.DialContext(ctx, network, addr)
if err != nil {
if ctx.Err() == context.DeadlineExceeded {
return nil, wrapDialWithUpstream(ErrDialTimeout, addr)
}
return nil, wrapDialWithUpstream(err, addr)
}
return conn, nil
}
// ErrDialTimeout is returned when TCP dialing times out.
var ErrDialTimeout = errors.New("dialing to the given TCP address timed out")
// ErrDialWithUpstream wraps dial error with upstream info.
//
// Use errors.As to get the upstream information from the error:
//
// hc := fasthttp.HostClient{Addr: "foo.com,bar.com"}
// err := hc.Do(req, res)
//
// var dialErr *fasthttp.ErrDialWithUpstream
// if errors.As(err, &dialErr) {
// upstream = dialErr.Upstream // 34.206.39.153:80
// }
type ErrDialWithUpstream struct {
wrapErr error
Upstream string
}
func (e *ErrDialWithUpstream) Error() string {
return fmt.Sprintf("error when dialing %s: %s", e.Upstream, e.wrapErr.Error())
}
func (e *ErrDialWithUpstream) Unwrap() error {
return e.wrapErr
}
func wrapDialWithUpstream(err error, upstream string) error {
return &ErrDialWithUpstream{
Upstream: upstream,
wrapErr: err,
}
}
// DefaultDialTimeout is the timeout used by Dial and DialDualStack
// for establishing TCP connections.
const DefaultDialTimeout = 3 * time.Second
type tcpAddrEntry struct {
resolveTime time.Time
addrs []net.TCPAddr
addrsIdx uint32
pending int32
}
// DefaultDNSCacheDuration is the duration for caching resolved TCP addresses
// by Dial* functions.
const DefaultDNSCacheDuration = time.Minute
func (d *TCPDialer) tcpAddrsClean() {
expireDuration := 2 * d.DNSCacheDuration
for {
time.Sleep(time.Second)
t := time.Now()
d.tcpAddrsMap.Range(func(k, v any) bool {
if e, ok := v.(*tcpAddrEntry); ok && t.Sub(e.resolveTime) > expireDuration {
d.tcpAddrsMap.Delete(k)
}
return true
})
}
}
func (d *TCPDialer) getTCPAddrs(addr string, dualStack bool, deadline time.Time) ([]net.TCPAddr, uint32, error) {
item, exist := d.tcpAddrsMap.Load(addr)
e, ok := item.(*tcpAddrEntry)
if exist && ok && e != nil && time.Since(e.resolveTime) > d.DNSCacheDuration {
// Only let one goroutine re-resolve at a time.
if atomic.SwapInt32(&e.pending, 1) == 0 {
e = nil
}
}
if e == nil {
addrs, err := resolveTCPAddrs(addr, dualStack, d.Resolver, deadline)
if err != nil {
item, exist := d.tcpAddrsMap.Load(addr)
e, ok = item.(*tcpAddrEntry)
if exist && ok && e != nil {
// Set pending to 0 so another goroutine can retry.
atomic.StoreInt32(&e.pending, 0)
}
return nil, 0, err
}
e = &tcpAddrEntry{
addrs: addrs,
resolveTime: time.Now(),
}
d.tcpAddrsMap.Store(addr, e)
}
idx := atomic.AddUint32(&e.addrsIdx, 1)
return e.addrs, idx, nil
}
func resolveTCPAddrs(addr string, dualStack bool, resolver Resolver, deadline time.Time) ([]net.TCPAddr, error) {
host, portS, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
port, err := strconv.Atoi(portS)
if err != nil {
return nil, err
}
if resolver == nil {
resolver = net.DefaultResolver
}
ctx, cancel := context.WithDeadline(context.Background(), deadline)
defer cancel()
ipaddrs, err := resolver.LookupIPAddr(ctx, host)
if err != nil {
return nil, err
}
n := len(ipaddrs)
addrs := make([]net.TCPAddr, 0, n)
for i := 0; i < n; i++ {
ip := ipaddrs[i]
if !dualStack && ip.IP.To4() == nil {
continue
}
addrs = append(addrs, net.TCPAddr{
IP: ip.IP,
Port: port,
Zone: ip.Zone,
})
}
if len(addrs) == 0 {
return nil, errNoDNSEntries
}
return addrs, nil
}
var errNoDNSEntries = errors.New("couldn't find DNS entries for the given domain. Try using DialDualStack")
package fasthttp
import (
"sync"
"time"
)
func initTimer(t *time.Timer, timeout time.Duration) *time.Timer {
if t == nil {
return time.NewTimer(timeout)
}
if t.Reset(timeout) {
// developer sanity-check
panic("BUG: active timer trapped into initTimer()")
}
return t
}
func stopTimer(t *time.Timer) {
if !t.Stop() {
// Drain the timer's channel in case the timer has already fired
// and nobody has collected its value.
select {
case <-t.C:
default:
}
}
}
// AcquireTimer returns a time.Timer from the pool and updates it to
// send the current time on its channel after at least timeout.
//
// The returned Timer may be returned to the pool with ReleaseTimer
// when no longer needed. This allows reducing GC load.
func AcquireTimer(timeout time.Duration) *time.Timer {
v := timerPool.Get()
if v == nil {
return time.NewTimer(timeout)
}
t := v.(*time.Timer)
initTimer(t, timeout)
return t
}
// ReleaseTimer returns the time.Timer acquired via AcquireTimer to the pool
// and prevents the Timer from firing.
//
// Do not access the released time.Timer or read from its channel otherwise
// data races may occur.
func ReleaseTimer(t *time.Timer) {
stopTimer(t)
timerPool.Put(t)
}
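// An illustrative use of AcquireTimer/ReleaseTimer (a sketch; someCh is a
// placeholder channel):
//
// t := fasthttp.AcquireTimer(time.Second)
// select {
// case <-someCh:
// // received a value in time
// case <-t.C:
// // timed out
// }
// fasthttp.ReleaseTimer(t)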
var timerPool sync.Pool
package fasthttp
import (
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"math/big"
"time"
)
// GenerateTestCertificate generates a test certificate and private key based on the given host.
func GenerateTestCertificate(host string) ([]byte, []byte, error) {
priv, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return nil, nil, err
}
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
if err != nil {
return nil, nil, err
}
cert := &x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{
Organization: []string{"fasthttp test"},
},
NotBefore: time.Now(),
NotAfter: time.Now().Add(365 * 24 * time.Hour),
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
SignatureAlgorithm: x509.SHA256WithRSA,
DNSNames: []string{host},
BasicConstraintsValid: true,
IsCA: true,
}
certBytes, err := x509.CreateCertificate(
rand.Reader, cert, cert, &priv.PublicKey, priv,
)
if err != nil {
return nil, nil, err
}
p := pem.EncodeToMemory(
&pem.Block{
// The key bytes are PKCS#1-encoded, so label the PEM block accordingly.
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(priv),
},
)
b := pem.EncodeToMemory(
&pem.Block{
Type: "CERTIFICATE",
Bytes: certBytes,
},
)
return b, p, nil
}
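// An illustrative use of GenerateTestCertificate (a sketch, e.g. for wiring
// up a TLS test server; "localhost" is a placeholder host):
//
// certPEM, keyPEM, err := fasthttp.GenerateTestCertificate("localhost")
// if err != nil {
// // handle the error
// }
// cert, err := tls.X509KeyPair(certPEM, keyPEM)
// // use cert in a tls.Config for the test server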
package fasthttp
import (
"bytes"
"errors"
"fmt"
"io"
"path/filepath"
"strconv"
"sync"
)
// AcquireURI returns an empty URI instance from the pool.
//
// Release the URI with ReleaseURI after the URI is no longer needed.
// This allows reducing GC load.
func AcquireURI() *URI {
return uriPool.Get().(*URI)
}
// ReleaseURI releases the URI acquired via AcquireURI.
//
// The released URI mustn't be used after releasing it, otherwise data races
// may occur.
func ReleaseURI(u *URI) {
u.Reset()
uriPool.Put(u)
}
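// An illustrative use of AcquireURI/ReleaseURI (a sketch; the URL is a
// placeholder):
//
// u := fasthttp.AcquireURI()
// u.Update("https://example.com/foo/bar?baz=123")
// // ... use u ...
// fasthttp.ReleaseURI(u)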
var uriPool = &sync.Pool{
New: func() any {
return &URI{}
},
}
// URI represents URI :) .
//
// It is forbidden copying URI instances. Create new instance and use CopyTo
// instead.
//
// URI instance MUST NOT be used from concurrently running goroutines.
type URI struct {
noCopy noCopy
queryArgs Args
pathOriginal []byte
scheme []byte
path []byte
queryString []byte
hash []byte
host []byte
fullURI []byte
requestURI []byte
username []byte
password []byte
parsedQueryArgs bool
// Path values are sent as-is without normalization.
//
// Disabled path normalization may be useful for proxying incoming requests
// to servers that are expecting paths to be forwarded as-is.
//
// By default path values are normalized, i.e.
// extra slashes are removed, special characters are encoded.
DisablePathNormalizing bool
}
// CopyTo copies uri contents to dst.
func (u *URI) CopyTo(dst *URI) {
dst.Reset()
dst.pathOriginal = append(dst.pathOriginal, u.pathOriginal...)
dst.scheme = append(dst.scheme, u.scheme...)
dst.path = append(dst.path, u.path...)
dst.queryString = append(dst.queryString, u.queryString...)
dst.hash = append(dst.hash, u.hash...)
dst.host = append(dst.host, u.host...)
dst.username = append(dst.username, u.username...)
dst.password = append(dst.password, u.password...)
u.queryArgs.CopyTo(&dst.queryArgs)
dst.parsedQueryArgs = u.parsedQueryArgs
dst.DisablePathNormalizing = u.DisablePathNormalizing
// fullURI and requestURI shouldn't be copied, since they are created
// from scratch on each FullURI() and RequestURI() call.
}
// Hash returns URI hash, i.e. qwe of http://aaa.com/foo/bar?baz=123#qwe .
//
// The returned bytes are valid until the next URI method call.
func (u *URI) Hash() []byte {
return u.hash
}
// SetHash sets URI hash.
func (u *URI) SetHash(hash string) {
u.hash = append(u.hash[:0], hash...)
}
// SetHashBytes sets URI hash.
func (u *URI) SetHashBytes(hash []byte) {
u.hash = append(u.hash[:0], hash...)
}
// Username returns URI username.
//
// The returned bytes are valid until the next URI method call.
func (u *URI) Username() []byte {
return u.username
}
// SetUsername sets URI username.
func (u *URI) SetUsername(username string) {
u.username = append(u.username[:0], username...)
}
// SetUsernameBytes sets URI username.
func (u *URI) SetUsernameBytes(username []byte) {
u.username = append(u.username[:0], username...)
}
// Password returns URI password.
//
// The returned bytes are valid until the next URI method call.
func (u *URI) Password() []byte {
return u.password
}
// SetPassword sets URI password.
func (u *URI) SetPassword(password string) {
u.password = append(u.password[:0], password...)
}
// SetPasswordBytes sets URI password.
func (u *URI) SetPasswordBytes(password []byte) {
u.password = append(u.password[:0], password...)
}
// QueryString returns URI query string,
// i.e. baz=123 of http://aaa.com/foo/bar?baz=123#qwe .
//
// The returned bytes are valid until the next URI method call.
func (u *URI) QueryString() []byte {
return u.queryString
}
// SetQueryString sets URI query string.
func (u *URI) SetQueryString(queryString string) {
u.queryString = append(u.queryString[:0], queryString...)
u.parsedQueryArgs = false
}
// SetQueryStringBytes sets URI query string.
func (u *URI) SetQueryStringBytes(queryString []byte) {
u.queryString = append(u.queryString[:0], queryString...)
u.parsedQueryArgs = false
}
// Path returns URI path, i.e. /foo/bar of http://aaa.com/foo/bar?baz=123#qwe .
//
// The returned path is always urldecoded and normalized,
// i.e. '//f%20obar/baz/../zzz' becomes '/f obar/zzz'.
//
// The returned bytes are valid until the next URI method call.
func (u *URI) Path() []byte {
path := u.path
if len(path) == 0 {
path = strSlash
}
return path
}
// SetPath sets URI path.
func (u *URI) SetPath(path string) {
u.pathOriginal = append(u.pathOriginal[:0], path...)
u.path = normalizePath(u.path, u.pathOriginal)
}
// SetPathBytes sets URI path.
func (u *URI) SetPathBytes(path []byte) {
u.pathOriginal = append(u.pathOriginal[:0], path...)
u.path = normalizePath(u.path, u.pathOriginal)
}
// PathOriginal returns the original path from requestURI passed to URI.Parse().
//
// The returned bytes are valid until the next URI method call.
func (u *URI) PathOriginal() []byte {
return u.pathOriginal
}
// Scheme returns URI scheme, i.e. http of http://aaa.com/foo/bar?baz=123#qwe .
//
// Returned scheme is always lowercased.
//
// The returned bytes are valid until the next URI method call.
func (u *URI) Scheme() []byte {
scheme := u.scheme
if len(scheme) == 0 {
scheme = strHTTP
}
return scheme
}
// SetScheme sets URI scheme, i.e. http, https, ftp, etc.
func (u *URI) SetScheme(scheme string) {
u.scheme = append(u.scheme[:0], scheme...)
lowercaseBytes(u.scheme)
}
// SetSchemeBytes sets URI scheme, i.e. http, https, ftp, etc.
func (u *URI) SetSchemeBytes(scheme []byte) {
u.scheme = append(u.scheme[:0], scheme...)
lowercaseBytes(u.scheme)
}
func (u *URI) isHTTPS() bool {
return bytes.Equal(u.scheme, strHTTPS)
}
func (u *URI) isHTTP() bool {
return len(u.scheme) == 0 || bytes.Equal(u.scheme, strHTTP)
}
// Reset clears uri.
func (u *URI) Reset() {
u.pathOriginal = u.pathOriginal[:0]
u.scheme = u.scheme[:0]
u.path = u.path[:0]
u.queryString = u.queryString[:0]
u.hash = u.hash[:0]
u.username = u.username[:0]
u.password = u.password[:0]
u.host = u.host[:0]
u.queryArgs.Reset()
u.parsedQueryArgs = false
u.DisablePathNormalizing = false
// There is no need to reset u.fullURI, since the full uri
// is recalculated on each call to FullURI().
// There is no need to reset u.requestURI, since requestURI
// is recalculated on each call to RequestURI().
}
// Host returns host part, i.e. aaa.com of http://aaa.com/foo/bar?baz=123#qwe .
//
// Host is always lowercased.
//
// The returned bytes are valid until the next URI method call.
func (u *URI) Host() []byte {
return u.host
}
// SetHost sets host for the uri.
func (u *URI) SetHost(host string) {
u.host = append(u.host[:0], host...)
lowercaseBytes(u.host)
}
// SetHostBytes sets host for the uri.
func (u *URI) SetHostBytes(host []byte) {
u.host = append(u.host[:0], host...)
lowercaseBytes(u.host)
}
// ErrorInvalidURI is returned when an invalid uri is passed to URI.Parse,
// e.g. one containing ASCII control characters.
var ErrorInvalidURI = errors.New("invalid uri")
// Parse initializes URI from the given host and uri.
//
// host may be nil. In this case uri must contain fully qualified uri,
// i.e. with scheme and host. http is assumed if scheme is omitted.
//
// uri may contain e.g. RequestURI without scheme and host if host is non-empty.
func (u *URI) Parse(host, uri []byte) error {
return u.parse(host, uri, false)
}
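// An illustrative use of Parse (a sketch; the expected values follow from
// the accessor docs above):
//
// var u fasthttp.URI
// if err := u.Parse(nil, []byte("http://example.com/foo/bar?baz=123#qwe")); err != nil {
// // handle the parse error
// }
// // u.Scheme() == "http", u.Host() == "example.com", u.Path() == "/foo/bar",
// // u.QueryString() == "baz=123", u.Hash() == "qwe"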
func (u *URI) parse(host, uri []byte, isTLS bool) error {
u.Reset()
if stringContainsCTLByte(uri) {
return ErrorInvalidURI
}
if len(host) == 0 || bytes.Contains(uri, strColonSlashSlash) {
scheme, newHost, newURI := splitHostURI(host, uri)
u.SetSchemeBytes(scheme)
host = newHost
uri = newURI
}
if isTLS {
u.SetSchemeBytes(strHTTPS)
}
if n := bytes.IndexByte(host, '@'); n >= 0 {
auth := host[:n]
host = host[n+1:]
if n := bytes.IndexByte(auth, ':'); n >= 0 {
u.username = append(u.username[:0], auth[:n]...)
u.password = append(u.password[:0], auth[n+1:]...)
} else {
u.username = append(u.username[:0], auth...)
u.password = u.password[:0]
}
}
u.host = append(u.host, host...)
parsedHost, err := parseHost(u.host)
if err != nil {
return err
}
u.host = parsedHost
lowercaseBytes(u.host)
b := uri
queryIndex := bytes.IndexByte(b, '?')
fragmentIndex := bytes.IndexByte(b, '#')
// Ignore query in fragment part
if fragmentIndex >= 0 && queryIndex > fragmentIndex {
queryIndex = -1
}
if queryIndex < 0 && fragmentIndex < 0 {
u.pathOriginal = append(u.pathOriginal, b...)
u.path = normalizePath(u.path, u.pathOriginal)
return nil
}
if queryIndex >= 0 {
// Path is everything up to the start of the query
u.pathOriginal = append(u.pathOriginal, b[:queryIndex]...)
u.path = normalizePath(u.path, u.pathOriginal)
if fragmentIndex < 0 {
u.queryString = append(u.queryString, b[queryIndex+1:]...)
} else {
u.queryString = append(u.queryString, b[queryIndex+1:fragmentIndex]...)
u.hash = append(u.hash, b[fragmentIndex+1:]...)
}
return nil
}
// fragmentIndex >= 0 && queryIndex < 0
// Path is up to the start of fragment
u.pathOriginal = append(u.pathOriginal, b[:fragmentIndex]...)
u.path = normalizePath(u.path, u.pathOriginal)
u.hash = append(u.hash, b[fragmentIndex+1:]...)
return nil
}
// parseHost parses host as an authority without user
// information. That is, as host[:port].
//
// Based on https://github.com/golang/go/blob/8ac5cbe05d61df0a7a7c9a38ff33305d4dcfea32/src/net/url/url.go#L619
//
// The host is parsed and unescaped in place overwriting the contents of the host parameter.
func parseHost(host []byte) ([]byte, error) {
if len(host) > 0 && host[0] == '[' {
// Parse an IP-Literal in RFC 3986 and RFC 6874.
// E.g., "[fe80::1]", "[fe80::1%25en0]", "[fe80::1]:80".
i := bytes.LastIndexByte(host, ']')
if i < 0 {
return nil, errors.New("missing ']' in host")
}
colonPort := host[i+1:]
if !validOptionalPort(colonPort) {
return nil, fmt.Errorf("invalid port %q after host", colonPort)
}
// RFC 6874 defines that %25 (%-encoded percent) introduces
// the zone identifier, and the zone identifier can use basically
// any %-encoding it likes. That's different from the host, which
// can only %-encode non-ASCII bytes.
// We do impose some restrictions on the zone, to avoid stupidity
// like newlines.
zone := bytes.Index(host[:i], []byte("%25"))
if zone >= 0 {
host1, err := unescape(host[:zone], encodeHost)
if err != nil {
return nil, err
}
host2, err := unescape(host[zone:i], encodeZone)
if err != nil {
return nil, err
}
host3, err := unescape(host[i:], encodeHost)
if err != nil {
return nil, err
}
return append(host1, append(host2, host3...)...), nil
}
} else if i := bytes.LastIndexByte(host, ':'); i != -1 {
colonPort := host[i:]
if !validOptionalPort(colonPort) {
return nil, fmt.Errorf("invalid port %q after host", colonPort)
}
}
var err error
if host, err = unescape(host, encodeHost); err != nil {
return nil, err
}
return host, nil
}
type encoding int
const (
encodeHost encoding = 1 + iota
encodeZone
)
// EscapeError describes a malformed %-escape sequence in a URL.
type EscapeError string
func (e EscapeError) Error() string {
return "invalid URL escape " + strconv.Quote(string(e))
}
// InvalidHostError describes an invalid character in a host name.
type InvalidHostError string
func (e InvalidHostError) Error() string {
return "invalid character " + strconv.Quote(string(e)) + " in host name"
}
// unescape unescapes a string; the mode specifies
// which section of the URL string is being unescaped.
//
// Based on https://github.com/golang/go/blob/8ac5cbe05d61df0a7a7c9a38ff33305d4dcfea32/src/net/url/url.go#L199
//
// Unescapes in place overwriting the contents of s and returning it.
func unescape(s []byte, mode encoding) ([]byte, error) {
// Count %, check that they're well-formed.
n := 0
for i := 0; i < len(s); {
switch s[i] {
case '%':
n++
if i+2 >= len(s) || !ishex(s[i+1]) || !ishex(s[i+2]) {
s = s[i:]
if len(s) > 3 {
s = s[:3]
}
return nil, EscapeError(s)
}
// Per https://tools.ietf.org/html/rfc3986#page-21
// in the host component %-encoding can only be used
// for non-ASCII bytes.
// But https://tools.ietf.org/html/rfc6874#section-2
// introduces %25 being allowed to escape a percent sign
// in IPv6 scoped-address literals. Yay.
if mode == encodeHost && unhex(s[i+1]) < 8 && !bytes.Equal(s[i:i+3], []byte("%25")) {
return nil, EscapeError(s[i : i+3])
}
if mode == encodeZone {
// RFC 6874 says basically "anything goes" for zone identifiers
// and that even non-ASCII can be redundantly escaped,
// but it seems prudent to restrict %-escaped bytes here to those
// that are valid host name bytes in their unescaped form.
// That is, you can use escaping in the zone identifier but not
// to introduce bytes you couldn't just write directly.
// But Windows puts spaces here! Yay.
v := unhex(s[i+1])<<4 | unhex(s[i+2])
if !bytes.Equal(s[i:i+3], []byte("%25")) && v != ' ' && shouldEscape(v, encodeHost) {
return nil, EscapeError(s[i : i+3])
}
}
i += 3
default:
if (mode == encodeHost || mode == encodeZone) && s[i] < 0x80 && shouldEscape(s[i], mode) {
return nil, InvalidHostError(s[i : i+1])
}
i++
}
}
if n == 0 {
return s, nil
}
t := s[:0]
for i := 0; i < len(s); i++ {
switch s[i] {
case '%':
t = append(t, unhex(s[i+1])<<4|unhex(s[i+2]))
i += 2
default:
t = append(t, s[i])
}
}
return t, nil
}
// shouldEscape reports whether the specified character should be escaped
// when appearing in a URL string, according to RFC 3986.
//
// Note that shouldEscape does not yet check all reserved characters
// correctly. See https://github.com/golang/go/issues/5684.
//
// Based on https://github.com/golang/go/blob/8ac5cbe05d61df0a7a7c9a38ff33305d4dcfea32/src/net/url/url.go#L100
func shouldEscape(c byte, mode encoding) bool {
// §2.3 Unreserved characters (alphanum)
if 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z' || '0' <= c && c <= '9' {
return false
}
if mode == encodeHost || mode == encodeZone {
// §3.2.2 Host allows
// sub-delims = "!" / "$" / "&" / "'" / "(" / ")" / "*" / "+" / "," / ";" / "="
// as part of reg-name.
// We add : because we include :port as part of host.
// We add [ ] because we include [ipv6]:port as part of host.
// We add < > because they're the only characters left that
// we could possibly allow, and Parse will reject them if we
// escape them (because hosts can't use %-encoding for
// ASCII bytes).
switch c {
case '!', '$', '&', '\'', '(', ')', '*', '+', ',', ';', '=', ':', '[', ']', '<', '>', '"':
return false
}
}
if c == '-' || c == '_' || c == '.' || c == '~' { // §2.3 Unreserved characters (mark)
return false
}
// Everything else must be escaped.
return true
}
func ishex(c byte) bool {
return hex2intTable[c] < 16
}
func unhex(c byte) byte {
return hex2intTable[c] & 15
}
// validOptionalPort reports whether port is either an empty string
// or matches /^:\d*$/.
func validOptionalPort(port []byte) bool {
if len(port) == 0 {
return true
}
if port[0] != ':' {
return false
}
for _, b := range port[1:] {
if b < '0' || b > '9' {
return false
}
}
return true
}
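// normalizePath appends the normalized form of src to dst[:0] and returns it:
// a leading slash is ensured, percent-escapes are decoded, and duplicate
// slashes as well as "/./" and "/foo/../" segments are collapsed,
// e.g. "//f%20obar/baz/../zzz" becomes "/f obar/zzz".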
func normalizePath(dst, src []byte) []byte {
dst = dst[:0]
dst = addLeadingSlash(dst, src)
dst = decodeArgAppendNoPlus(dst, src)
// remove duplicate slashes
b := dst
bSize := len(b)
for {
n := bytes.Index(b, strSlashSlash)
if n < 0 {
break
}
b = b[n:]
copy(b, b[1:])
b = b[:len(b)-1]
bSize--
}
dst = dst[:bSize]
// remove /./ parts
b = dst
for {
n := bytes.Index(b, strSlashDotSlash)
if n < 0 {
break
}
nn := n + len(strSlashDotSlash) - 1
copy(b[n:], b[nn:])
b = b[:len(b)-nn+n]
}
// remove /foo/../ parts
for {
n := bytes.Index(b, strSlashDotDotSlash)
if n < 0 {
break
}
nn := bytes.LastIndexByte(b[:n], '/')
if nn < 0 {
nn = 0
}
n += len(strSlashDotDotSlash) - 1
copy(b[nn:], b[n:])
b = b[:len(b)-n+nn]
}
// remove trailing /foo/..
n := bytes.LastIndex(b, strSlashDotDot)
if n >= 0 && n+len(strSlashDotDot) == len(b) {
nn := bytes.LastIndexByte(b[:n], '/')
if nn < 0 {
return append(dst[:0], strSlash...)
}
b = b[:nn+1]
}
if filepath.Separator == '\\' {
// remove \.\ parts
for {
n := bytes.Index(b, strBackSlashDotBackSlash)
if n < 0 {
break
}
nn := n + len(strBackSlashDotBackSlash) - 1
copy(b[n:], b[nn:])
b = b[:len(b)-nn+n]
}
// remove /foo/..\ parts
for {
n := bytes.Index(b, strSlashDotDotBackSlash)
if n < 0 {
break
}
nn := bytes.LastIndexByte(b[:n], '/')
if nn < 0 {
nn = 0
}
nn++
n += len(strSlashDotDotBackSlash)
copy(b[nn:], b[n:])
b = b[:len(b)-n+nn]
}
// remove /foo\..\ parts
for {
n := bytes.Index(b, strBackSlashDotDotBackSlash)
if n < 0 {
break
}
nn := bytes.LastIndexByte(b[:n], '/')
if nn < 0 {
nn = 0
}
n += len(strBackSlashDotDotBackSlash) - 1
copy(b[nn:], b[n:])
b = b[:len(b)-n+nn]
}
// remove trailing \foo\..
n := bytes.LastIndex(b, strBackSlashDotDot)
if n >= 0 && n+len(strBackSlashDotDot) == len(b) {
nn := bytes.LastIndexByte(b[:n], '/')
if nn < 0 {
return append(dst[:0], strSlash...)
}
b = b[:nn+1]
}
}
return b
}
// RequestURI returns RequestURI - i.e. URI without Scheme and Host.
func (u *URI) RequestURI() []byte {
var dst []byte
if u.DisablePathNormalizing {
dst = u.requestURI[:0]
dst = append(dst, u.PathOriginal()...)
} else {
dst = appendQuotedPath(u.requestURI[:0], u.Path())
}
if u.parsedQueryArgs && u.queryArgs.Len() > 0 {
dst = append(dst, '?')
dst = u.queryArgs.AppendBytes(dst)
} else if len(u.queryString) > 0 {
dst = append(dst, '?')
dst = append(dst, u.queryString...)
}
u.requestURI = dst
return u.requestURI
}
// LastPathSegment returns the last part of uri path after '/'.
//
// Examples:
//
// - For /foo/bar/baz.html path returns baz.html.
// - For /foo/bar/ returns empty byte slice.
// - For /foobar.js returns foobar.js.
//
// The returned bytes are valid until the next URI method call.
func (u *URI) LastPathSegment() []byte {
path := u.Path()
n := bytes.LastIndexByte(path, '/')
if n < 0 {
return path
}
return path[n+1:]
}
// Update updates uri.
//
// The following newURI types are accepted:
//
// - Absolute, i.e. http://foobar.com/aaa/bb?cc . In this case the original
// uri is replaced by newURI.
// - Absolute without scheme, i.e. //foobar.com/aaa/bb?cc. In this case
// the original scheme is preserved.
// - Missing host, i.e. /aaa/bb?cc . In this case only RequestURI part
// of the original uri is replaced.
// - Relative path, i.e. xx?yy=abc . In this case the original RequestURI
// is updated according to the new relative path.
func (u *URI) Update(newURI string) {
u.UpdateBytes(s2b(newURI))
}
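// An illustrative use of Update (a sketch; the URLs are placeholders):
//
// var u fasthttp.URI
// u.Update("http://example.com/aaa/bb?cc")
// u.Update("/xx/yy") // only the RequestURI part changes
// u.Update("//other.com/zz") // host changes, the original scheme is kept
// // u.String() == "http://other.com/zz"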
// UpdateBytes updates uri.
//
// The following newURI types are accepted:
//
// - Absolute, i.e. http://foobar.com/aaa/bb?cc . In this case the original
// uri is replaced by newURI.
// - Absolute without scheme, i.e. //foobar.com/aaa/bb?cc. In this case
// the original scheme is preserved.
// - Missing host, i.e. /aaa/bb?cc . In this case only RequestURI part
// of the original uri is replaced.
// - Relative path, i.e. xx?yy=abc . In this case the original RequestURI
// is updated according to the new relative path.
func (u *URI) UpdateBytes(newURI []byte) {
u.requestURI = u.updateBytes(newURI, u.requestURI)
}
func (u *URI) updateBytes(newURI, buf []byte) []byte {
if len(newURI) == 0 {
return buf
}
n := bytes.Index(newURI, strSlashSlash)
if n >= 0 {
// absolute uri
var b [32]byte
schemeOriginal := b[:0]
if len(u.scheme) > 0 {
schemeOriginal = append(schemeOriginal, u.scheme...)
}
if err := u.Parse(nil, newURI); err != nil {
return nil
}
if len(schemeOriginal) > 0 && len(u.scheme) == 0 {
u.scheme = append(u.scheme[:0], schemeOriginal...)
}
return buf
}
if newURI[0] == '/' {
// uri without host
buf = u.appendSchemeHost(buf[:0])
buf = append(buf, newURI...)
if err := u.Parse(nil, buf); err != nil {
return nil
}
return buf
}
// relative path
switch newURI[0] {
case '?':
// query string only update
u.SetQueryStringBytes(newURI[1:])
return append(buf[:0], u.FullURI()...)
case '#':
// update only hash
u.SetHashBytes(newURI[1:])
return append(buf[:0], u.FullURI()...)
default:
// update the last path part after the slash
path := u.Path()
n = bytes.LastIndexByte(path, '/')
if n < 0 {
panic(fmt.Sprintf("BUG: path must contain at least one slash: %q %q", u.Path(), newURI))
}
buf = u.appendSchemeHost(buf[:0])
buf = appendQuotedPath(buf, path[:n+1])
buf = append(buf, newURI...)
if err := u.Parse(nil, buf); err != nil {
return nil
}
return buf
}
}
// FullURI returns full uri in the form {Scheme}://{Host}{RequestURI}#{Hash}.
//
// The returned bytes are valid until the next URI method call.
func (u *URI) FullURI() []byte {
u.fullURI = u.AppendBytes(u.fullURI[:0])
return u.fullURI
}
// AppendBytes appends full uri to dst and returns the extended dst.
func (u *URI) AppendBytes(dst []byte) []byte {
dst = u.appendSchemeHost(dst)
dst = append(dst, u.RequestURI()...)
if len(u.hash) > 0 {
dst = append(dst, '#')
dst = append(dst, u.hash...)
}
return dst
}
func (u *URI) appendSchemeHost(dst []byte) []byte {
dst = append(dst, u.Scheme()...)
dst = append(dst, strColonSlashSlash...)
return append(dst, u.Host()...)
}
// WriteTo writes full uri to w.
//
// WriteTo implements io.WriterTo interface.
func (u *URI) WriteTo(w io.Writer) (int64, error) {
n, err := w.Write(u.FullURI())
return int64(n), err
}
// String returns full uri.
func (u *URI) String() string {
return string(u.FullURI())
}
func splitHostURI(host, uri []byte) ([]byte, []byte, []byte) {
n := bytes.Index(uri, strSlashSlash)
if n < 0 {
return strHTTP, host, uri
}
scheme := uri[:n]
if bytes.IndexByte(scheme, '/') >= 0 {
return strHTTP, host, uri
}
if len(scheme) > 0 && scheme[len(scheme)-1] == ':' {
scheme = scheme[:len(scheme)-1]
}
n += len(strSlashSlash)
uri = uri[n:]
n = bytes.IndexByte(uri, '/')
nq := bytes.IndexByte(uri, '?')
if nq >= 0 && (n < 0 || nq < n) {
// A hack for urls like foobar.com?a=b/xyz
n = nq
}
nh := bytes.IndexByte(uri, '#')
if nh >= 0 && (n < 0 || nh < n) {
// A hack for urls like foobar.com#abc.com
n = nh
}
if n < 0 {
return scheme, uri, strSlash
}
return scheme, uri[:n], uri[n:]
}
// QueryArgs returns query args.
//
// The returned args are valid until the next URI method call.
func (u *URI) QueryArgs() *Args {
u.parseQueryArgs()
return &u.queryArgs
}
func (u *URI) parseQueryArgs() {
if u.parsedQueryArgs {
return
}
u.queryArgs.ParseBytes(u.queryString)
u.parsedQueryArgs = true
}
// stringContainsCTLByte reports whether s contains any ASCII control character.
func stringContainsCTLByte(s []byte) bool {
for i := 0; i < len(s); i++ {
b := s[i]
if b < ' ' || b == 0x7f {
return true
}
}
return false
}
//go:build !windows
package fasthttp
func addLeadingSlash(dst, src []byte) []byte {
// add leading slash for unix paths
if len(src) == 0 || src[0] != '/' {
dst = append(dst, '/')
}
return dst
}
package fasthttp
import (
"io"
)
type userDataKV struct {
key any
value any
}
type userData []userDataKV
func (d *userData) Set(key, value any) {
if b, ok := key.([]byte); ok {
key = string(b)
}
args := *d
n := len(args)
for i := 0; i < n; i++ {
kv := &args[i]
if kv.key == key {
kv.value = value
return
}
}
if value == nil {
return
}
c := cap(args)
if c > n {
args = args[:n+1]
kv := &args[n]
kv.key = key
kv.value = value
*d = args
return
}
args = append(args, userDataKV{key: key, value: value})
*d = args
}
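// An illustrative sketch of userData behavior (the key and value are
// hypothetical):
//
// var d userData
// d.Set("traceID", "abc123")
// v := d.Get("traceID") // "abc123"
// d.Remove("traceID") // subsequent Get returns nil
// // Reset additionally calls Close on any io.Closer values.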
func (d *userData) SetBytes(key []byte, value any) {
d.Set(key, value)
}
func (d *userData) Get(key any) any {
if b, ok := key.([]byte); ok {
key = b2s(b)
}
args := *d
n := len(args)
for i := 0; i < n; i++ {
kv := &args[i]
if kv.key == key {
return kv.value
}
}
return nil
}
func (d *userData) GetBytes(key []byte) any {
return d.Get(key)
}
func (d *userData) Reset() {
args := *d
n := len(args)
for i := 0; i < n; i++ {
v := args[i].value
if vc, ok := v.(io.Closer); ok {
vc.Close()
}
(*d)[i].value = nil
(*d)[i].key = nil
}
*d = (*d)[:0]
}
func (d *userData) Remove(key any) {
if b, ok := key.([]byte); ok {
key = b2s(b)
}
args := *d
n := len(args)
for i := 0; i < n; i++ {
kv := &args[i]
if kv.key == key {
n--
args[i], args[n] = args[n], args[i]
args[n].key = nil
args[n].value = nil
args = args[:n]
*d = args
return
}
}
}
func (d *userData) RemoveBytes(key []byte) {
d.Remove(key)
}
package fasthttp
import (
"errors"
"net"
"runtime"
"strings"
"sync"
"time"
)
// workerPool serves incoming connections via a pool of workers
// in FILO order, i.e. the most recently stopped worker will serve the next
// incoming connection.
//
// Such a scheme keeps CPU caches hot (in theory).
type workerPool struct {
workerChanPool sync.Pool
Logger Logger
// Function for serving server connections.
// It must leave c unclosed.
WorkerFunc ServeHandler
stopCh chan struct{}
connState func(net.Conn, ConnState)
ready []*workerChan
MaxWorkersCount int
MaxIdleWorkerDuration time.Duration
workersCount int
lock sync.Mutex
LogAllErrors bool
mustStop bool
}
type workerChan struct {
lastUseTime time.Time
ch chan net.Conn
}
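// A typical lifecycle sketch (illustrative only; serveConn, conn and logger
// are placeholders, and Server normally drives this internally):
//
// wp := &workerPool{WorkerFunc: serveConn, MaxWorkersCount: 256, Logger: logger}
// wp.Start()
// ok := wp.Serve(conn) // false when all MaxWorkersCount workers are busy
// wp.Stop()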
func (wp *workerPool) Start() {
if wp.stopCh != nil {
return
}
wp.stopCh = make(chan struct{})
stopCh := wp.stopCh
wp.workerChanPool.New = func() any {
return &workerChan{
ch: make(chan net.Conn, workerChanCap),
}
}
go func() {
var scratch []*workerChan
for {
wp.clean(&scratch)
select {
case <-stopCh:
return
default:
time.Sleep(wp.getMaxIdleWorkerDuration())
}
}
}()
}
func (wp *workerPool) Stop() {
if wp.stopCh == nil {
return
}
close(wp.stopCh)
wp.stopCh = nil
// Stop all the workers waiting for incoming connections.
// Do not wait for busy workers - they will stop after
// serving the connection and noticing wp.mustStop = true.
wp.lock.Lock()
ready := wp.ready
for i := range ready {
ready[i].ch <- nil
ready[i] = nil
}
wp.ready = ready[:0]
wp.mustStop = true
wp.lock.Unlock()
}
func (wp *workerPool) getMaxIdleWorkerDuration() time.Duration {
if wp.MaxIdleWorkerDuration <= 0 {
return 10 * time.Second
}
return wp.MaxIdleWorkerDuration
}
func (wp *workerPool) clean(scratch *[]*workerChan) {
maxIdleWorkerDuration := wp.getMaxIdleWorkerDuration()
// Clean least recently used workers if they didn't serve connections
// for more than maxIdleWorkerDuration.
criticalTime := time.Now().Add(-maxIdleWorkerDuration)
wp.lock.Lock()
ready := wp.ready
n := len(ready)
// Use binary search to find the index of the least recently used worker that can be cleaned up.
l, r := 0, n-1
for l <= r {
mid := (l + r) / 2
if criticalTime.After(wp.ready[mid].lastUseTime) {
l = mid + 1
} else {
r = mid - 1
}
}
i := r
if i == -1 {
wp.lock.Unlock()
return
}
*scratch = append((*scratch)[:0], ready[:i+1]...)
m := copy(ready, ready[i+1:])
for i = m; i < n; i++ {
ready[i] = nil
}
wp.ready = ready[:m]
wp.lock.Unlock()
// Notify obsolete workers to stop.
// This notification must be outside the wp.lock, since ch.ch
// may be blocking and may consume a lot of time if many workers
// are located on non-local CPUs.
tmp := *scratch
for i := range tmp {
tmp[i].ch <- nil
tmp[i] = nil
}
}
func (wp *workerPool) Serve(c net.Conn) bool {
ch := wp.getCh()
if ch == nil {
return false
}
ch.ch <- c
return true
}
var workerChanCap = func() int {
// Use blocking workerChan if GOMAXPROCS=1.
// This immediately switches Serve to WorkerFunc, which results
// in higher performance (under go1.5 at least).
if runtime.GOMAXPROCS(0) == 1 {
return 0
}
// Use non-blocking workerChan if GOMAXPROCS>1,
// since otherwise the Serve caller (Acceptor) may lag accepting
// new connections if WorkerFunc is CPU-bound.
return 1
}()
func (wp *workerPool) getCh() *workerChan {
var ch *workerChan
createWorker := false
wp.lock.Lock()
ready := wp.ready
n := len(ready) - 1
if n < 0 {
if wp.workersCount < wp.MaxWorkersCount {
createWorker = true
wp.workersCount++
}
} else {
ch = ready[n]
ready[n] = nil
wp.ready = ready[:n]
}
wp.lock.Unlock()
if ch == nil {
if !createWorker {
return nil
}
vch := wp.workerChanPool.Get()
ch = vch.(*workerChan)
go func() {
wp.workerFunc(ch)
wp.workerChanPool.Put(vch)
}()
}
return ch
}
func (wp *workerPool) release(ch *workerChan) bool {
ch.lastUseTime = time.Now()
wp.lock.Lock()
if wp.mustStop {
wp.lock.Unlock()
return false
}
wp.ready = append(wp.ready, ch)
wp.lock.Unlock()
return true
}
func (wp *workerPool) workerFunc(ch *workerChan) {
var c net.Conn
var err error
for c = range ch.ch {
if c == nil {
break
}
if err = wp.WorkerFunc(c); err != nil && err != errHijacked {
errStr := err.Error()
shouldIgnore := strings.Contains(errStr, "broken pipe") ||
strings.Contains(errStr, "reset by peer") ||
strings.Contains(errStr, "request headers: small read buffer") ||
strings.Contains(errStr, "unexpected EOF") ||
strings.Contains(errStr, "i/o timeout") ||
errors.Is(err, ErrBadTrailer)
if wp.LogAllErrors || !shouldIgnore {
wp.Logger.Printf("error when serving connection %q<->%q: %v", c.LocalAddr(), c.RemoteAddr(), err)
}
}
if err == errHijacked {
wp.connState(c, StateHijacked)
} else {
_ = c.Close()
wp.connState(c, StateClosed)
}
if !wp.release(ch) {
break
}
}
wp.lock.Lock()
wp.workersCount--
wp.lock.Unlock()
}
package fasthttp
import (
"bytes"
"fmt"
"io"
"sync"
"github.com/klauspost/compress/zstd"
"github.com/valyala/bytebufferpool"
"github.com/valyala/fasthttp/stackless"
)
// Supported zstd compression levels.
const (
CompressZstdSpeedNotSet = iota
CompressZstdBestSpeed
CompressZstdDefault
CompressZstdSpeedBetter
CompressZstdBestCompression
)
var (
zstdDecoderPool sync.Pool
realZstdWriterPoolMap = newCompressWriterPoolMap()
stacklessZstdWriterPoolMap = newCompressWriterPoolMap()
)
func acquireZstdReader(r io.Reader) (*zstd.Decoder, error) {
v := zstdDecoderPool.Get()
if v == nil {
return zstd.NewReader(r)
}
zr := v.(*zstd.Decoder)
if err := zr.Reset(r); err != nil {
return nil, err
}
return zr, nil
}
func releaseZstdReader(zr *zstd.Decoder) {
zstdDecoderPool.Put(zr)
}
func acquireStacklessZstdWriter(w io.Writer, compressLevel int) stackless.Writer {
nLevel := normalizeZstdCompressLevel(compressLevel)
p := stacklessZstdWriterPoolMap[nLevel]
v := p.Get()
if v == nil {
return stackless.NewWriter(w, func(w io.Writer) stackless.Writer {
return acquireRealZstdWriter(w, compressLevel)
})
}
sw := v.(stackless.Writer)
sw.Reset(w)
return sw
}
func releaseStacklessZstdWriter(zf stackless.Writer, level int) {
zf.Close()
nLevel := normalizeZstdCompressLevel(level)
p := stacklessZstdWriterPoolMap[nLevel]
p.Put(zf)
}
func acquireRealZstdWriter(w io.Writer, level int) *zstd.Encoder {
nLevel := normalizeZstdCompressLevel(level)
p := realZstdWriterPoolMap[nLevel]
v := p.Get()
if v == nil {
// Use the normalized level so the writer matches the pool it will be returned to.
zw, err := zstd.NewWriter(w, zstd.WithEncoderLevel(zstd.EncoderLevel(nLevel)))
if err != nil {
panic(err)
}
return zw
}
zw := v.(*zstd.Encoder)
zw.Reset(w)
return zw
}
func releaseRealZstdWriter(zw *zstd.Encoder, level int) {
zw.Close()
nLevel := normalizeZstdCompressLevel(level)
p := realZstdWriterPoolMap[nLevel]
p.Put(zw)
}
// AppendZstdBytesLevel appends zstd-compressed src to dst using the given
// compression level and returns the resulting dst.
func AppendZstdBytesLevel(dst, src []byte, level int) []byte {
w := &byteSliceWriter{b: dst}
WriteZstdLevel(w, src, level) //nolint:errcheck
return w.b
}
// WriteZstdLevel writes zstd-compressed p to w using the given compression
// level.
func WriteZstdLevel(w io.Writer, p []byte, level int) (int, error) {
level = normalizeZstdCompressLevel(level)
switch w.(type) {
case *byteSliceWriter,
*bytes.Buffer,
*bytebufferpool.ByteBuffer:
ctx := &compressCtx{
w: w,
p: p,
level: level,
}
stacklessWriteZstd(ctx)
return len(p), nil
default:
zw := acquireStacklessZstdWriter(w, level)
n, err := zw.Write(p)
releaseStacklessZstdWriter(zw, level)
return n, err
}
}
var (
stacklessWriteZstdOnce sync.Once
stacklessWriteZstdFunc func(ctx any) bool
)
func stacklessWriteZstd(ctx any) {
stacklessWriteZstdOnce.Do(func() {
stacklessWriteZstdFunc = stackless.NewFunc(nonblockingWriteZstd)
})
stacklessWriteZstdFunc(ctx)
}
func nonblockingWriteZstd(ctxv any) {
ctx := ctxv.(*compressCtx)
zw := acquireRealZstdWriter(ctx.w, ctx.level)
zw.Write(ctx.p) //nolint:errcheck
releaseRealZstdWriter(zw, ctx.level)
}
// AppendZstdBytes appends zstd-compressed src to dst and returns the resulting dst.
func AppendZstdBytes(dst, src []byte) []byte {
return AppendZstdBytesLevel(dst, src, CompressZstdDefault)
}
// WriteUnzstd decompresses zstd-compressed p, writes the result to w
// and returns the number of uncompressed bytes written to w.
func WriteUnzstd(w io.Writer, p []byte) (int, error) {
r := &byteSliceReader{b: p}
zr, err := acquireZstdReader(r)
if err != nil {
return 0, err
}
n, err := copyZeroAlloc(w, zr)
releaseZstdReader(zr)
nn := int(n)
if int64(nn) != n {
return 0, fmt.Errorf("too much data unzstd: %d", n)
}
return nn, err
}
// AppendUnzstdBytes decompresses zstd-compressed src, appends the result to dst and returns the resulting dst.
func AppendUnzstdBytes(dst, src []byte) ([]byte, error) {
w := &byteSliceWriter{b: dst}
_, err := WriteUnzstd(w, src)
return w.b, err
}
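// An illustrative compress/decompress round trip (a sketch):
//
// compressed := fasthttp.AppendZstdBytesLevel(nil, []byte("hello"), fasthttp.CompressZstdDefault)
// plain, err := fasthttp.AppendUnzstdBytes(nil, compressed)
// // err == nil && string(plain) == "hello"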
// normalizeZstdCompressLevel normalizes the compression level into
// [CompressZstdSpeedNotSet..CompressZstdBestCompression], so it can be used
// as an index in *PoolMap.
func normalizeZstdCompressLevel(level int) int {
if level < CompressZstdSpeedNotSet || level > CompressZstdBestCompression {
level = CompressZstdDefault
}
return level
}