package qpack
import (
"bytes"
"errors"
"fmt"
"sync"
"golang.org/x/net/http2/hpack"
)
// A decodingError is something the spec defines as a decoding error.
type decodingError struct {
err error
}
func (de decodingError) Error() string {
return fmt.Sprintf("decoding error: %v", de.err)
}
// An invalidIndexError is returned when an encoder references a table
// entry before the static table or after the end of the dynamic table.
type invalidIndexError int
func (e invalidIndexError) Error() string {
return fmt.Sprintf("invalid indexed representation index %d", int(e))
}
var errNoDynamicTable = decodingError{errors.New("no dynamic table")}
// errNeedMore is an internal sentinel error value that means the
// buffer is truncated and we need to read more data before we can
// continue parsing.
var errNeedMore = errors.New("need more data")
// A Decoder is the decoding context for incremental processing of
// header blocks.
type Decoder struct {
mutex sync.Mutex
emitFunc func(f HeaderField)
readRequiredInsertCount bool
readDeltaBase bool
// buf is the unparsed buffer. It's only written to
// saveBuf if it was truncated in the middle of a header
// block. Because it's usually not owned, we can only
// process it under Write.
buf []byte // not owned; only valid during Write
// saveBuf is previous data passed to Write which we weren't able
// to fully parse before. Unlike buf, we own this data.
saveBuf bytes.Buffer
}
// NewDecoder returns a new decoder.
// The emitFunc will be called for each valid field parsed,
// in the same goroutine as calls to Write, before Write returns.
func NewDecoder(emitFunc func(f HeaderField)) *Decoder {
return &Decoder{emitFunc: emitFunc}
}
func (d *Decoder) Write(p []byte) (int, error) {
if len(p) == 0 {
return 0, nil
}
d.mutex.Lock()
n, err := d.writeLocked(p)
d.mutex.Unlock()
return n, err
}
func (d *Decoder) writeLocked(p []byte) (int, error) {
// Only copy the data if we have to. Optimistically assume
// that p will contain a complete header block.
if d.saveBuf.Len() == 0 {
d.buf = p
} else {
d.saveBuf.Write(p)
d.buf = d.saveBuf.Bytes()
d.saveBuf.Reset()
}
if err := d.decode(); err != nil {
if err != errNeedMore {
return 0, err
}
// TODO: limit the size of the buffer
d.saveBuf.Write(d.buf)
}
return len(p), nil
}
// DecodeFull decodes an entire block.
func (d *Decoder) DecodeFull(p []byte) ([]HeaderField, error) {
if len(p) == 0 {
return []HeaderField{}, nil
}
d.mutex.Lock()
defer d.mutex.Unlock()
saveFunc := d.emitFunc
defer func() { d.emitFunc = saveFunc }()
var hf []HeaderField
d.emitFunc = func(f HeaderField) { hf = append(hf, f) }
if _, err := d.writeLocked(p); err != nil {
return nil, err
}
if err := d.Close(); err != nil {
return nil, err
}
return hf, nil
}
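// Illustrative usage sketch (not part of the original file; assumes an external
// program importing this package as qpack, with headerBlock a placeholder []byte):
//
//	// one-shot decoding of a complete header block
//	hfs, err := qpack.NewDecoder(nil).DecodeFull(headerBlock)
//
//	// incremental decoding: emitFunc is invoked for every decoded field
//	dec := qpack.NewDecoder(func(f qpack.HeaderField) { fmt.Println(f.Name, f.Value) })
//	_, err = dec.Write(headerBlock)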
// Close declares that the decoding is complete and resets the Decoder
// to be reused again for a new header block. If there is any remaining
// data in the decoder's buffer, Close returns an error.
func (d *Decoder) Close() error {
if d.saveBuf.Len() > 0 {
d.saveBuf.Reset()
return decodingError{errors.New("truncated headers")}
}
d.readRequiredInsertCount = false
d.readDeltaBase = false
return nil
}
func (d *Decoder) decode() error {
if !d.readRequiredInsertCount {
requiredInsertCount, rest, err := readVarInt(8, d.buf)
if err != nil {
return err
}
d.readRequiredInsertCount = true
if requiredInsertCount != 0 {
return decodingError{errors.New("expected Required Insert Count to be zero")}
}
d.buf = rest
}
if !d.readDeltaBase {
base, rest, err := readVarInt(7, d.buf)
if err != nil {
return err
}
d.readDeltaBase = true
if base != 0 {
return decodingError{errors.New("expected Base to be zero")}
}
d.buf = rest
}
if len(d.buf) == 0 {
return errNeedMore
}
for len(d.buf) > 0 {
b := d.buf[0]
var err error
switch {
case b&0x80 > 0: // 1xxxxxxx
err = d.parseIndexedHeaderField()
case b&0xc0 == 0x40: // 01xxxxxx
err = d.parseLiteralHeaderField()
case b&0xe0 == 0x20: // 001xxxxx
err = d.parseLiteralHeaderFieldWithoutNameReference()
default:
err = fmt.Errorf("unexpected type byte: %#x", b)
}
if err != nil {
return err
}
}
return nil
}
func (d *Decoder) parseIndexedHeaderField() error {
buf := d.buf
if buf[0]&0x40 == 0 {
return errNoDynamicTable
}
index, buf, err := readVarInt(6, buf)
if err != nil {
return err
}
hf, ok := d.at(index)
if !ok {
return decodingError{invalidIndexError(index)}
}
d.emitFunc(hf)
d.buf = buf
return nil
}
func (d *Decoder) parseLiteralHeaderField() error {
buf := d.buf
if buf[0]&0x10 == 0 {
return errNoDynamicTable
}
// We don't need to check the value of the N-bit here.
// It's only relevant when re-encoding header fields,
// and determines whether the header field can be added to the dynamic table.
// Since we don't support the dynamic table, we can ignore it.
index, buf, err := readVarInt(4, buf)
if err != nil {
return err
}
hf, ok := d.at(index)
if !ok {
return decodingError{invalidIndexError(index)}
}
if len(buf) == 0 {
return errNeedMore
}
usesHuffman := buf[0]&0x80 > 0
val, buf, err := d.readString(buf, 7, usesHuffman)
if err != nil {
return err
}
hf.Value = val
d.emitFunc(hf)
d.buf = buf
return nil
}
func (d *Decoder) parseLiteralHeaderFieldWithoutNameReference() error {
buf := d.buf
usesHuffmanForName := buf[0]&0x8 > 0
name, buf, err := d.readString(buf, 3, usesHuffmanForName)
if err != nil {
return err
}
if len(buf) == 0 {
return errNeedMore
}
usesHuffmanForVal := buf[0]&0x80 > 0
val, buf, err := d.readString(buf, 7, usesHuffmanForVal)
if err != nil {
return err
}
d.emitFunc(HeaderField{Name: name, Value: val})
d.buf = buf
return nil
}
func (d *Decoder) readString(buf []byte, n uint8, usesHuffman bool) (string, []byte, error) {
l, buf, err := readVarInt(n, buf)
if err != nil {
return "", nil, err
}
if uint64(len(buf)) < l {
return "", nil, errNeedMore
}
var val string
if usesHuffman {
var err error
val, err = hpack.HuffmanDecodeToString(buf[:l])
if err != nil {
return "", nil, err
}
} else {
val = string(buf[:l])
}
buf = buf[l:]
return val, buf, nil
}
func (d *Decoder) at(i uint64) (hf HeaderField, ok bool) {
if i >= uint64(len(staticTableEntries)) {
return
}
return staticTableEntries[i], true
}
package qpack
import (
"io"
"golang.org/x/net/http2/hpack"
)
// An Encoder performs QPACK encoding.
type Encoder struct {
wrotePrefix bool
w io.Writer
buf []byte
}
// NewEncoder returns a new Encoder which performs QPACK encoding.
// The encoded data is written to w.
func NewEncoder(w io.Writer) *Encoder {
return &Encoder{w: w}
}
// WriteField encodes f into a single Write to e's underlying Writer.
// This function may also produce bytes for the Header Block Prefix
// if necessary. If produced, it is done before encoding f.
func (e *Encoder) WriteField(f HeaderField) error {
// write the Header Block Prefix
if !e.wrotePrefix {
e.buf = appendVarInt(e.buf, 8, 0)
e.buf = appendVarInt(e.buf, 7, 0)
e.wrotePrefix = true
}
idxAndVals, nameFound := encoderMap[f.Name]
if nameFound {
if idxAndVals.values == nil {
if len(f.Value) == 0 {
e.writeIndexedField(idxAndVals.idx)
} else {
e.writeLiteralFieldWithNameReference(&f, idxAndVals.idx)
}
} else {
valIdx, valueFound := idxAndVals.values[f.Value]
if valueFound {
e.writeIndexedField(valIdx)
} else {
e.writeLiteralFieldWithNameReference(&f, idxAndVals.idx)
}
}
} else {
e.writeLiteralFieldWithoutNameReference(f)
}
_, err := e.w.Write(e.buf)
e.buf = e.buf[:0]
return err
}
// Close declares that the encoding is complete and resets the Encoder
// to be reused again for a new header block.
func (e *Encoder) Close() error {
e.wrotePrefix = false
return nil
}
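// Illustrative usage sketch (not part of the original file; external usage):
//
//	var buf bytes.Buffer
//	enc := qpack.NewEncoder(&buf)
//	_ = enc.WriteField(qpack.HeaderField{Name: ":method", Value: "GET"})
//	_ = enc.WriteField(qpack.HeaderField{Name: "user-agent", Value: "quic-go"})
//	_ = enc.Close()
//	// buf now starts with the two-byte Header Block Prefix (Required Insert
//	// Count and Base, both zero), followed by the encoded fields.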
func (e *Encoder) writeLiteralFieldWithoutNameReference(f HeaderField) {
offset := len(e.buf)
e.buf = appendVarInt(e.buf, 3, hpack.HuffmanEncodeLength(f.Name))
e.buf[offset] ^= 0x20 ^ 0x8
e.buf = hpack.AppendHuffmanString(e.buf, f.Name)
offset = len(e.buf)
e.buf = appendVarInt(e.buf, 7, hpack.HuffmanEncodeLength(f.Value))
e.buf[offset] ^= 0x80
e.buf = hpack.AppendHuffmanString(e.buf, f.Value)
}
// Encodes a header field whose name is present in one of the tables.
func (e *Encoder) writeLiteralFieldWithNameReference(f *HeaderField, id uint8) {
offset := len(e.buf)
e.buf = appendVarInt(e.buf, 4, uint64(id))
// Set the 01NTxxxx pattern, forcing N to 0 and T to 1
e.buf[offset] ^= 0x50
offset = len(e.buf)
e.buf = appendVarInt(e.buf, 7, hpack.HuffmanEncodeLength(f.Value))
e.buf[offset] ^= 0x80
e.buf = hpack.AppendHuffmanString(e.buf, f.Value)
}
// Encodes an indexed field, meaning it's entirely defined in one of the tables.
func (e *Encoder) writeIndexedField(id uint8) {
offset := len(e.buf)
e.buf = appendVarInt(e.buf, 6, uint64(id))
// Set the 1Txxxxxx pattern, forcing T to 1
e.buf[offset] ^= 0xc0
}
package qpack
import (
"bytes"
"fmt"
"reflect"
"github.com/quic-go/qpack"
)
func Fuzz(data []byte) int {
if len(data) < 1 {
return 0
}
chunkLen := int(data[0]) + 1
data = data[1:]
fields, err := qpack.NewDecoder(nil).DecodeFull(data)
if err != nil {
return 0
}
if len(fields) == 0 {
return 0
}
var writtenFields []qpack.HeaderField
decoder := qpack.NewDecoder(func(hf qpack.HeaderField) {
writtenFields = append(writtenFields, hf)
})
for len(data) > 0 {
var chunk []byte
if chunkLen <= len(data) {
chunk = data[:chunkLen]
data = data[chunkLen:]
} else {
chunk = data
data = nil
}
n, err := decoder.Write(chunk)
if err != nil {
return 0
}
if n != len(chunk) {
panic("len error")
}
}
if !reflect.DeepEqual(fields, writtenFields) {
fmt.Printf("%#v vs %#v", fields, writtenFields)
panic("Write() and DecodeFull() produced different results")
}
buf := &bytes.Buffer{}
encoder := qpack.NewEncoder(buf)
for _, hf := range fields {
if err := encoder.WriteField(hf); err != nil {
panic(err)
}
}
if err := encoder.Close(); err != nil {
panic(err)
}
encodedFields, err := qpack.NewDecoder(nil).DecodeFull(buf.Bytes())
if err != nil {
fmt.Printf("Fields: %#v\n", fields)
panic(err)
}
if !reflect.DeepEqual(fields, encodedFields) {
fmt.Printf("%#v vs %#v", fields, encodedFields)
panic("unequal")
}
return 0
}
package qpack
// A HeaderField is a name-value pair. Both the name and value are
// treated as opaque sequences of octets.
type HeaderField struct {
Name string
Value string
}
// IsPseudo reports whether the header field is an HTTP3 pseudo header.
// That is, it reports whether it starts with a colon.
// It is not otherwise guaranteed to be a valid pseudo header field,
// though.
func (hf HeaderField) IsPseudo() bool {
return len(hf.Name) != 0 && hf.Name[0] == ':'
}
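// For example (illustrative):
//
//	HeaderField{Name: ":status", Value: "200"}.IsPseudo() // true
//	HeaderField{Name: "content-type"}.IsPseudo()          // false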
package qpack
var staticTableEntries = [...]HeaderField{
{Name: ":authority"},
{Name: ":path", Value: "/"},
{Name: "age", Value: "0"},
{Name: "content-disposition"},
{Name: "content-length", Value: "0"},
{Name: "cookie"},
{Name: "date"},
{Name: "etag"},
{Name: "if-modified-since"},
{Name: "if-none-match"},
{Name: "last-modified"},
{Name: "link"},
{Name: "location"},
{Name: "referer"},
{Name: "set-cookie"},
{Name: ":method", Value: "CONNECT"},
{Name: ":method", Value: "DELETE"},
{Name: ":method", Value: "GET"},
{Name: ":method", Value: "HEAD"},
{Name: ":method", Value: "OPTIONS"},
{Name: ":method", Value: "POST"},
{Name: ":method", Value: "PUT"},
{Name: ":scheme", Value: "http"},
{Name: ":scheme", Value: "https"},
{Name: ":status", Value: "103"},
{Name: ":status", Value: "200"},
{Name: ":status", Value: "304"},
{Name: ":status", Value: "404"},
{Name: ":status", Value: "503"},
{Name: "accept", Value: "*/*"},
{Name: "accept", Value: "application/dns-message"},
{Name: "accept-encoding", Value: "gzip, deflate, br"},
{Name: "accept-ranges", Value: "bytes"},
{Name: "access-control-allow-headers", Value: "cache-control"},
{Name: "access-control-allow-headers", Value: "content-type"},
{Name: "access-control-allow-origin", Value: "*"},
{Name: "cache-control", Value: "max-age=0"},
{Name: "cache-control", Value: "max-age=2592000"},
{Name: "cache-control", Value: "max-age=604800"},
{Name: "cache-control", Value: "no-cache"},
{Name: "cache-control", Value: "no-store"},
{Name: "cache-control", Value: "public, max-age=31536000"},
{Name: "content-encoding", Value: "br"},
{Name: "content-encoding", Value: "gzip"},
{Name: "content-type", Value: "application/dns-message"},
{Name: "content-type", Value: "application/javascript"},
{Name: "content-type", Value: "application/json"},
{Name: "content-type", Value: "application/x-www-form-urlencoded"},
{Name: "content-type", Value: "image/gif"},
{Name: "content-type", Value: "image/jpeg"},
{Name: "content-type", Value: "image/png"},
{Name: "content-type", Value: "text/css"},
{Name: "content-type", Value: "text/html; charset=utf-8"},
{Name: "content-type", Value: "text/plain"},
{Name: "content-type", Value: "text/plain;charset=utf-8"},
{Name: "range", Value: "bytes=0-"},
{Name: "strict-transport-security", Value: "max-age=31536000"},
{Name: "strict-transport-security", Value: "max-age=31536000; includesubdomains"},
{Name: "strict-transport-security", Value: "max-age=31536000; includesubdomains; preload"},
{Name: "vary", Value: "accept-encoding"},
{Name: "vary", Value: "origin"},
{Name: "x-content-type-options", Value: "nosniff"},
{Name: "x-xss-protection", Value: "1; mode=block"},
{Name: ":status", Value: "100"},
{Name: ":status", Value: "204"},
{Name: ":status", Value: "206"},
{Name: ":status", Value: "302"},
{Name: ":status", Value: "400"},
{Name: ":status", Value: "403"},
{Name: ":status", Value: "421"},
{Name: ":status", Value: "425"},
{Name: ":status", Value: "500"},
{Name: "accept-language"},
{Name: "access-control-allow-credentials", Value: "FALSE"},
{Name: "access-control-allow-credentials", Value: "TRUE"},
{Name: "access-control-allow-headers", Value: "*"},
{Name: "access-control-allow-methods", Value: "get"},
{Name: "access-control-allow-methods", Value: "get, post, options"},
{Name: "access-control-allow-methods", Value: "options"},
{Name: "access-control-expose-headers", Value: "content-length"},
{Name: "access-control-request-headers", Value: "content-type"},
{Name: "access-control-request-method", Value: "get"},
{Name: "access-control-request-method", Value: "post"},
{Name: "alt-svc", Value: "clear"},
{Name: "authorization"},
{Name: "content-security-policy", Value: "script-src 'none'; object-src 'none'; base-uri 'none'"},
{Name: "early-data", Value: "1"},
{Name: "expect-ct"},
{Name: "forwarded"},
{Name: "if-range"},
{Name: "origin"},
{Name: "purpose", Value: "prefetch"},
{Name: "server"},
{Name: "timing-allow-origin", Value: "*"},
{Name: "upgrade-insecure-requests", Value: "1"},
{Name: "user-agent"},
{Name: "x-forwarded-for"},
{Name: "x-frame-options", Value: "deny"},
{Name: "x-frame-options", Value: "sameorigin"},
}
// Only needed for tests.
// Tests use go:linkname to retrieve the static table.
//
//nolint:unused
func getStaticTable() []HeaderField {
return staticTableEntries[:]
}
type indexAndValues struct {
idx uint8
values map[string]uint8
}
// A map of the header names from the static table to their index in the table.
// This is used by the encoder to quickly find if a header is in the static table
// and what value should be used to encode it.
// There's a second level of mapping for the headers that have some predefined
// values in the static table.
var encoderMap = map[string]indexAndValues{
":authority": {0, nil},
":path": {1, map[string]uint8{"/": 1}},
"age": {2, map[string]uint8{"0": 2}},
"content-disposition": {3, nil},
"content-length": {4, map[string]uint8{"0": 4}},
"cookie": {5, nil},
"date": {6, nil},
"etag": {7, nil},
"if-modified-since": {8, nil},
"if-none-match": {9, nil},
"last-modified": {10, nil},
"link": {11, nil},
"location": {12, nil},
"referer": {13, nil},
"set-cookie": {14, nil},
":method": {15, map[string]uint8{
"CONNECT": 15,
"DELETE": 16,
"GET": 17,
"HEAD": 18,
"OPTIONS": 19,
"POST": 20,
"PUT": 21,
}},
":scheme": {22, map[string]uint8{
"http": 22,
"https": 23,
}},
":status": {24, map[string]uint8{
"103": 24,
"200": 25,
"304": 26,
"404": 27,
"503": 28,
"100": 63,
"204": 64,
"206": 65,
"302": 66,
"400": 67,
"403": 68,
"421": 69,
"425": 70,
"500": 71,
}},
"accept": {29, map[string]uint8{
"*/*": 29,
"application/dns-message": 30,
}},
"accept-encoding": {31, map[string]uint8{"gzip, deflate, br": 31}},
"accept-ranges": {32, map[string]uint8{"bytes": 32}},
"access-control-allow-headers": {33, map[string]uint8{
"cache-control": 33,
"content-type": 34,
"*": 75,
}},
"access-control-allow-origin": {35, map[string]uint8{"*": 35}},
"cache-control": {36, map[string]uint8{
"max-age=0": 36,
"max-age=2592000": 37,
"max-age=604800": 38,
"no-cache": 39,
"no-store": 40,
"public, max-age=31536000": 41,
}},
"content-encoding": {42, map[string]uint8{
"br": 42,
"gzip": 43,
}},
"content-type": {44, map[string]uint8{
"application/dns-message": 44,
"application/javascript": 45,
"application/json": 46,
"application/x-www-form-urlencoded": 47,
"image/gif": 48,
"image/jpeg": 49,
"image/png": 50,
"text/css": 51,
"text/html; charset=utf-8": 52,
"text/plain": 53,
"text/plain;charset=utf-8": 54,
}},
"range": {55, map[string]uint8{"bytes=0-": 55}},
"strict-transport-security": {56, map[string]uint8{
"max-age=31536000": 56,
"max-age=31536000; includesubdomains": 57,
"max-age=31536000; includesubdomains; preload": 58,
}},
"vary": {59, map[string]uint8{
"accept-encoding": 59,
"origin": 60,
}},
"x-content-type-options": {61, map[string]uint8{"nosniff": 61}},
"x-xss-protection": {62, map[string]uint8{"1; mode=block": 62}},
// ":status" is duplicated and takes index 63 to 71
"accept-language": {72, nil},
"access-control-allow-credentials": {73, map[string]uint8{
"FALSE": 73,
"TRUE": 74,
}},
// "access-control-allow-headers" is duplicated and takes index 75
"access-control-allow-methods": {76, map[string]uint8{
"get": 76,
"get, post, options": 77,
"options": 78,
}},
"access-control-expose-headers": {79, map[string]uint8{"content-length": 79}},
"access-control-request-headers": {80, map[string]uint8{"content-type": 80}},
"access-control-request-method": {81, map[string]uint8{
"get": 81,
"post": 82,
}},
"alt-svc": {83, map[string]uint8{"clear": 83}},
"authorization": {84, nil},
"content-security-policy": {85, map[string]uint8{
"script-src 'none'; object-src 'none'; base-uri 'none'": 85,
}},
"early-data": {86, map[string]uint8{"1": 86}},
"expect-ct": {87, nil},
"forwarded": {88, nil},
"if-range": {89, nil},
"origin": {90, nil},
"purpose": {91, map[string]uint8{"prefetch": 91}},
"server": {92, nil},
"timing-allow-origin": {93, map[string]uint8{"*": 93}},
"upgrade-insecure-requests": {94, map[string]uint8{"1": 94}},
"user-agent": {95, nil},
"x-forwarded-for": {96, nil},
"x-frame-options": {97, map[string]uint8{
"deny": 97,
"sameorigin": 98,
}},
}
package qpack
// copied from the Go standard library HPACK implementation
import "errors"
var errVarintOverflow = errors.New("varint integer overflow")
// appendVarInt appends i, as encoded in variable integer form using n
// bit prefix, to dst and returns the extended buffer.
//
// See
// http://http2.github.io/http2-spec/compression.html#integer.representation
func appendVarInt(dst []byte, n byte, i uint64) []byte {
k := uint64((1 << n) - 1)
if i < k {
return append(dst, byte(i))
}
dst = append(dst, byte(k))
i -= k
for ; i >= 128; i >>= 7 {
dst = append(dst, byte(0x80|(i&0x7f)))
}
return append(dst, byte(i))
}
// readVarInt reads an unsigned variable length integer off the
// beginning of p. n is the parameter as described in
// http://http2.github.io/http2-spec/compression.html#rfc.section.5.1.
//
// n must always be between 1 and 8.
//
// The returned remain buffer is either a smaller suffix of p, or err != nil.
// The error is errNeedMore if p doesn't contain a complete integer.
func readVarInt(n byte, p []byte) (i uint64, remain []byte, err error) {
if n < 1 || n > 8 {
panic("bad n")
}
if len(p) == 0 {
return 0, p, errNeedMore
}
i = uint64(p[0])
if n < 8 {
i &= (1 << uint64(n)) - 1
}
if i < (1<<uint64(n))-1 {
return i, p[1:], nil
}
origP := p
p = p[1:]
var m uint64
for len(p) > 0 {
b := p[0]
p = p[1:]
i += uint64(b&127) << m
if b&128 == 0 {
return i, p, nil
}
m += 7
if m >= 63 { // TODO: proper overflow check. making this up.
return 0, origP, errVarintOverflow
}
}
return 0, origP, errNeedMore
}
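// Worked example (illustrative, matching RFC 7541, Section 5.1): encoding 1337
// with a 5-bit prefix produces the bytes 0x1f 0x9a 0x0a, and decoding them
// recovers the value:
//
//	b := appendVarInt(nil, 5, 1337)  // []byte{0x1f, 0x9a, 0x0a}
//	i, rest, err := readVarInt(5, b) // i == 1337, len(rest) == 0, err == nil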
package quic
import (
"sync"
"github.com/quic-go/quic-go/internal/protocol"
)
type packetBuffer struct {
Data []byte
// refCount counts how many packets Data is used in.
// It doesn't support concurrent use.
// It is > 1 when used for coalesced packets.
refCount int
}
// Split increases the refCount.
// It must be called when a packet buffer is used for more than one packet,
// e.g. when splitting coalesced packets.
func (b *packetBuffer) Split() {
b.refCount++
}
// Decrement decrements the reference counter.
// It doesn't put the buffer back into the pool.
func (b *packetBuffer) Decrement() {
b.refCount--
if b.refCount < 0 {
panic("negative packetBuffer refCount")
}
}
// MaybeRelease puts the packet buffer back into the pool,
// if the reference counter already reached 0.
func (b *packetBuffer) MaybeRelease() {
// only put the packetBuffer back if it's not used any more
if b.refCount == 0 {
b.putBack()
}
}
// Release puts back the packet buffer into the pool.
// It should be called when processing is definitely finished.
func (b *packetBuffer) Release() {
b.Decrement()
if b.refCount != 0 {
panic("packetBuffer refCount not zero")
}
b.putBack()
}
// Len returns the length of Data
func (b *packetBuffer) Len() protocol.ByteCount { return protocol.ByteCount(len(b.Data)) }
func (b *packetBuffer) Cap() protocol.ByteCount { return protocol.ByteCount(cap(b.Data)) }
func (b *packetBuffer) putBack() {
if cap(b.Data) == protocol.MaxPacketBufferSize {
bufferPool.Put(b)
return
}
if cap(b.Data) == protocol.MaxLargePacketBufferSize {
largeBufferPool.Put(b)
return
}
panic("putPacketBuffer called with packet of wrong size!")
}
var bufferPool, largeBufferPool sync.Pool
func getPacketBuffer() *packetBuffer {
buf := bufferPool.Get().(*packetBuffer)
buf.refCount = 1
buf.Data = buf.Data[:0]
return buf
}
func getLargePacketBuffer() *packetBuffer {
buf := largeBufferPool.Get().(*packetBuffer)
buf.refCount = 1
buf.Data = buf.Data[:0]
return buf
}
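// Illustrative lifecycle sketch (assumption: one datagram coalescing two packets;
// not part of the original file):
//
//	buf := getPacketBuffer() // refCount == 1
//	buf.Split()              // a second packet shares the buffer, refCount == 2
//	// each packet handler, once it is done with its packet:
//	buf.Decrement()
//	buf.MaybeRelease() // puts the buffer back into the pool once refCount reaches 0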
func init() {
bufferPool.New = func() any {
return &packetBuffer{Data: make([]byte, 0, protocol.MaxPacketBufferSize)}
}
largeBufferPool.New = func() any {
return &packetBuffer{Data: make([]byte, 0, protocol.MaxLargePacketBufferSize)}
}
}
package quic
import (
"context"
"crypto/tls"
"errors"
"net"
"github.com/quic-go/quic-go/internal/protocol"
)
// make it possible to mock connection ID generation for the Initial packet in the tests
var generateConnectionIDForInitial = protocol.GenerateConnectionIDForInitial
// DialAddr establishes a new QUIC connection to a server.
// It resolves the address, and then creates a new UDP connection to dial the QUIC server.
// When the QUIC connection is closed, this UDP connection is closed.
// See [Dial] for more details.
func DialAddr(ctx context.Context, addr string, tlsConf *tls.Config, conf *Config) (*Conn, error) {
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
return nil, err
}
udpAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return nil, err
}
tr, err := setupTransport(udpConn, tlsConf, true)
if err != nil {
return nil, err
}
conn, err := tr.dial(ctx, udpAddr, addr, tlsConf, conf, false)
if err != nil {
tr.Close()
return nil, err
}
return conn, nil
}
// DialAddrEarly establishes a new 0-RTT QUIC connection to a server.
// See [DialAddr] for more details.
func DialAddrEarly(ctx context.Context, addr string, tlsConf *tls.Config, conf *Config) (*Conn, error) {
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
return nil, err
}
udpAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return nil, err
}
tr, err := setupTransport(udpConn, tlsConf, true)
if err != nil {
return nil, err
}
conn, err := tr.dial(ctx, udpAddr, addr, tlsConf, conf, true)
if err != nil {
tr.Close()
return nil, err
}
return conn, nil
}
// DialEarly establishes a new 0-RTT QUIC connection to a server using a net.PacketConn.
// See [Dial] for more details.
func DialEarly(ctx context.Context, c net.PacketConn, addr net.Addr, tlsConf *tls.Config, conf *Config) (*Conn, error) {
dl, err := setupTransport(c, tlsConf, false)
if err != nil {
return nil, err
}
conn, err := dl.DialEarly(ctx, addr, tlsConf, conf)
if err != nil {
dl.Close()
return nil, err
}
return conn, nil
}
// Dial establishes a new QUIC connection to a server using a net.PacketConn.
// If the PacketConn satisfies the [OOBCapablePacketConn] interface (as a [net.UDPConn] does),
// ECN and packet info support will be enabled. In this case, ReadMsgUDP and WriteMsgUDP
// will be used instead of ReadFrom and WriteTo to read/write packets.
// The [tls.Config] must define an application protocol (using tls.Config.NextProtos).
//
// This is a convenience function. More advanced use cases should instantiate a [Transport],
// which offers configuration options for a more fine-grained control of the connection establishment,
// including reusing the underlying UDP socket for multiple QUIC connections.
func Dial(ctx context.Context, c net.PacketConn, addr net.Addr, tlsConf *tls.Config, conf *Config) (*Conn, error) {
dl, err := setupTransport(c, tlsConf, false)
if err != nil {
return nil, err
}
conn, err := dl.Dial(ctx, addr, tlsConf, conf)
if err != nil {
dl.Close()
return nil, err
}
return conn, nil
}
func setupTransport(c net.PacketConn, tlsConf *tls.Config, createdPacketConn bool) (*Transport, error) {
if tlsConf == nil {
return nil, errors.New("quic: tls.Config not set")
}
return &Transport{
Conn: c,
createdConn: createdPacketConn,
isSingleUse: true,
}, nil
}
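// Illustrative dial sketch (not part of the original file; the address and the
// "h3" ALPN are placeholders):
//
//	tlsConf := &tls.Config{NextProtos: []string{"h3"}}
//	conn, err := quic.DialAddr(ctx, "example.com:443", tlsConf, nil)
//	if err != nil {
//		// handle error
//	}
//	defer conn.CloseWithError(0, "")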
package quic
import (
"math/bits"
"net"
"sync/atomic"
"github.com/quic-go/quic-go/internal/utils"
)
// A closedLocalConn is a connection that we closed locally.
// When receiving packets for such a connection, we need to retransmit the packet containing the CONNECTION_CLOSE frame,
// with an exponential backoff.
type closedLocalConn struct {
counter atomic.Uint32
logger utils.Logger
sendPacket func(net.Addr, packetInfo)
}
var _ packetHandler = &closedLocalConn{}
// newClosedLocalConn creates a new closedLocalConn.
func newClosedLocalConn(sendPacket func(net.Addr, packetInfo), logger utils.Logger) packetHandler {
return &closedLocalConn{
sendPacket: sendPacket,
logger: logger,
}
}
func (c *closedLocalConn) handlePacket(p receivedPacket) {
n := c.counter.Add(1)
// exponential backoff
// only send a CONNECTION_CLOSE for the 1st, 2nd, 4th, 8th, 16th, ... packet arriving
if bits.OnesCount32(n) != 1 {
return
}
c.logger.Debugf("Received %d packets after sending CONNECTION_CLOSE. Retransmitting.", n)
c.sendPacket(p.remoteAddr, p.info)
}
func (c *closedLocalConn) destroy(error) {}
func (c *closedLocalConn) closeWithTransportError(TransportErrorCode) {}
// A closedRemoteConn is a connection that was closed remotely.
// For such a connection, we might receive reordered packets that were sent before the CONNECTION_CLOSE.
// We can just ignore those packets.
type closedRemoteConn struct{}
var _ packetHandler = &closedRemoteConn{}
func newClosedRemoteConn() packetHandler {
return &closedRemoteConn{}
}
func (c *closedRemoteConn) handlePacket(receivedPacket) {}
func (c *closedRemoteConn) destroy(error) {}
func (c *closedRemoteConn) closeWithTransportError(TransportErrorCode) {}
package quic
import (
"fmt"
"time"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/quicvarint"
)
// Clone clones a Config.
func (c *Config) Clone() *Config {
copy := *c
return &copy
}
func (c *Config) handshakeTimeout() time.Duration {
return 2 * c.HandshakeIdleTimeout
}
func (c *Config) maxRetryTokenAge() time.Duration {
return c.handshakeTimeout()
}
func validateConfig(config *Config) error {
if config == nil {
return nil
}
const maxStreams = 1 << 60
if config.MaxIncomingStreams > maxStreams {
config.MaxIncomingStreams = maxStreams
}
if config.MaxIncomingUniStreams > maxStreams {
config.MaxIncomingUniStreams = maxStreams
}
if config.MaxStreamReceiveWindow > quicvarint.Max {
config.MaxStreamReceiveWindow = quicvarint.Max
}
if config.MaxConnectionReceiveWindow > quicvarint.Max {
config.MaxConnectionReceiveWindow = quicvarint.Max
}
if config.InitialPacketSize > 0 && config.InitialPacketSize < protocol.MinInitialPacketSize {
config.InitialPacketSize = protocol.MinInitialPacketSize
}
if config.InitialPacketSize > protocol.MaxPacketBufferSize {
config.InitialPacketSize = protocol.MaxPacketBufferSize
}
// check that all QUIC versions are actually supported
for _, v := range config.Versions {
if !protocol.IsValidVersion(v) {
return fmt.Errorf("invalid QUIC version: %s", v)
}
}
return nil
}
// populateConfig populates fields in the quic.Config with their default values, if none are set.
// It may be called with nil.
func populateConfig(config *Config) *Config {
if config == nil {
config = &Config{}
}
versions := config.Versions
if len(versions) == 0 {
versions = protocol.SupportedVersions
}
handshakeIdleTimeout := protocol.DefaultHandshakeIdleTimeout
if config.HandshakeIdleTimeout != 0 {
handshakeIdleTimeout = config.HandshakeIdleTimeout
}
idleTimeout := protocol.DefaultIdleTimeout
if config.MaxIdleTimeout != 0 {
idleTimeout = config.MaxIdleTimeout
}
initialStreamReceiveWindow := config.InitialStreamReceiveWindow
if initialStreamReceiveWindow == 0 {
initialStreamReceiveWindow = protocol.DefaultInitialMaxStreamData
}
maxStreamReceiveWindow := config.MaxStreamReceiveWindow
if maxStreamReceiveWindow == 0 {
maxStreamReceiveWindow = protocol.DefaultMaxReceiveStreamFlowControlWindow
}
initialConnectionReceiveWindow := config.InitialConnectionReceiveWindow
if initialConnectionReceiveWindow == 0 {
initialConnectionReceiveWindow = protocol.DefaultInitialMaxData
}
maxConnectionReceiveWindow := config.MaxConnectionReceiveWindow
if maxConnectionReceiveWindow == 0 {
maxConnectionReceiveWindow = protocol.DefaultMaxReceiveConnectionFlowControlWindow
}
maxIncomingStreams := config.MaxIncomingStreams
if maxIncomingStreams == 0 {
maxIncomingStreams = protocol.DefaultMaxIncomingStreams
} else if maxIncomingStreams < 0 {
maxIncomingStreams = 0
}
maxIncomingUniStreams := config.MaxIncomingUniStreams
if maxIncomingUniStreams == 0 {
maxIncomingUniStreams = protocol.DefaultMaxIncomingUniStreams
} else if maxIncomingUniStreams < 0 {
maxIncomingUniStreams = 0
}
initialPacketSize := config.InitialPacketSize
if initialPacketSize == 0 {
initialPacketSize = protocol.InitialPacketSize
}
return &Config{
GetConfigForClient: config.GetConfigForClient,
Versions: versions,
HandshakeIdleTimeout: handshakeIdleTimeout,
MaxIdleTimeout: idleTimeout,
KeepAlivePeriod: config.KeepAlivePeriod,
InitialStreamReceiveWindow: initialStreamReceiveWindow,
MaxStreamReceiveWindow: maxStreamReceiveWindow,
InitialConnectionReceiveWindow: initialConnectionReceiveWindow,
MaxConnectionReceiveWindow: maxConnectionReceiveWindow,
AllowConnectionWindowIncrease: config.AllowConnectionWindowIncrease,
MaxIncomingStreams: maxIncomingStreams,
MaxIncomingUniStreams: maxIncomingUniStreams,
TokenStore: config.TokenStore,
EnableDatagrams: config.EnableDatagrams,
InitialPacketSize: initialPacketSize,
DisablePathMTUDiscovery: config.DisablePathMTUDiscovery,
EnableStreamResetPartialDelivery: config.EnableStreamResetPartialDelivery,
Allow0RTT: config.Allow0RTT,
Tracer: config.Tracer,
}
}
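// Illustrative sketch (not part of the original file): populateConfig keeps any
// value that is explicitly set and fills in protocol defaults for zero values.
//
//	conf := populateConfig(&Config{MaxIdleTimeout: 30 * time.Second})
//	// conf.MaxIdleTimeout == 30 * time.Second (kept), while fields such as
//	// HandshakeIdleTimeout, the receive windows, and the stream limits are
//	// populated from the protocol defaults.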
package quic
import (
"fmt"
"slices"
"time"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qerr"
"github.com/quic-go/quic-go/internal/wire"
)
type connRunnerCallbacks struct {
AddConnectionID func(protocol.ConnectionID)
RemoveConnectionID func(protocol.ConnectionID)
ReplaceWithClosed func([]protocol.ConnectionID, []byte, time.Duration)
}
// The memory address of the Transport is used as the key.
type connRunners map[connRunner]connRunnerCallbacks
func (cr connRunners) AddConnectionID(id protocol.ConnectionID) {
for _, c := range cr {
c.AddConnectionID(id)
}
}
func (cr connRunners) RemoveConnectionID(id protocol.ConnectionID) {
for _, c := range cr {
c.RemoveConnectionID(id)
}
}
func (cr connRunners) ReplaceWithClosed(ids []protocol.ConnectionID, b []byte, expiry time.Duration) {
for _, c := range cr {
c.ReplaceWithClosed(ids, b, expiry)
}
}
type connIDToRetire struct {
t time.Time
connID protocol.ConnectionID
}
type connIDGenerator struct {
generator ConnectionIDGenerator
highestSeq uint64
connRunners connRunners
activeSrcConnIDs map[uint64]protocol.ConnectionID
connIDsToRetire []connIDToRetire // sorted by t
initialClientDestConnID *protocol.ConnectionID // nil for the client
statelessResetter *statelessResetter
queueControlFrame func(wire.Frame)
}
func newConnIDGenerator(
runner connRunner,
initialConnectionID protocol.ConnectionID,
initialClientDestConnID *protocol.ConnectionID, // nil for the client
statelessResetter *statelessResetter,
callbacks connRunnerCallbacks,
queueControlFrame func(wire.Frame),
generator ConnectionIDGenerator,
) *connIDGenerator {
m := &connIDGenerator{
generator: generator,
activeSrcConnIDs: make(map[uint64]protocol.ConnectionID),
statelessResetter: statelessResetter,
connRunners: map[connRunner]connRunnerCallbacks{runner: callbacks},
queueControlFrame: queueControlFrame,
}
m.activeSrcConnIDs[0] = initialConnectionID
m.initialClientDestConnID = initialClientDestConnID
return m
}
func (m *connIDGenerator) SetMaxActiveConnIDs(limit uint64) error {
if m.generator.ConnectionIDLen() == 0 {
return nil
}
// The active_connection_id_limit transport parameter is the number of
// connection IDs the peer will store. This limit includes the connection ID
// used during the handshake, and the one sent in the preferred_address
// transport parameter.
// We currently don't send the preferred_address transport parameter,
// so we can issue (limit - 1) connection IDs.
for i := uint64(len(m.activeSrcConnIDs)); i < min(limit, protocol.MaxIssuedConnectionIDs); i++ {
if err := m.issueNewConnID(); err != nil {
return err
}
}
return nil
}
func (m *connIDGenerator) Retire(seq uint64, sentWithDestConnID protocol.ConnectionID, expiry time.Time) error {
if seq > m.highestSeq {
return &qerr.TransportError{
ErrorCode: qerr.ProtocolViolation,
ErrorMessage: fmt.Sprintf("retired connection ID %d (highest issued: %d)", seq, m.highestSeq),
}
}
connID, ok := m.activeSrcConnIDs[seq]
// We might already have deleted this connection ID, if this is a duplicate frame.
if !ok {
return nil
}
if connID == sentWithDestConnID {
return &qerr.TransportError{
ErrorCode: qerr.ProtocolViolation,
ErrorMessage: fmt.Sprintf("retired connection ID %d (%s), which was used as the Destination Connection ID on this packet", seq, connID),
}
}
m.queueConnIDForRetiring(connID, expiry)
delete(m.activeSrcConnIDs, seq)
// Don't issue a replacement for the initial connection ID.
if seq == 0 {
return nil
}
return m.issueNewConnID()
}
func (m *connIDGenerator) queueConnIDForRetiring(connID protocol.ConnectionID, expiry time.Time) {
idx := slices.IndexFunc(m.connIDsToRetire, func(c connIDToRetire) bool {
return c.t.After(expiry)
})
if idx == -1 {
idx = len(m.connIDsToRetire)
}
m.connIDsToRetire = slices.Insert(m.connIDsToRetire, idx, connIDToRetire{t: expiry, connID: connID})
}
func (m *connIDGenerator) issueNewConnID() error {
connID, err := m.generator.GenerateConnectionID()
if err != nil {
return err
}
m.activeSrcConnIDs[m.highestSeq+1] = connID
m.connRunners.AddConnectionID(connID)
m.queueControlFrame(&wire.NewConnectionIDFrame{
SequenceNumber: m.highestSeq + 1,
ConnectionID: connID,
StatelessResetToken: m.statelessResetter.GetStatelessResetToken(connID),
})
m.highestSeq++
return nil
}
func (m *connIDGenerator) SetHandshakeComplete(connIDExpiry time.Time) {
if m.initialClientDestConnID != nil {
m.queueConnIDForRetiring(*m.initialClientDestConnID, connIDExpiry)
m.initialClientDestConnID = nil
}
}
func (m *connIDGenerator) NextRetireTime() time.Time {
if len(m.connIDsToRetire) == 0 {
return time.Time{}
}
return m.connIDsToRetire[0].t
}
func (m *connIDGenerator) RemoveRetiredConnIDs(now time.Time) {
if len(m.connIDsToRetire) == 0 {
return
}
for _, c := range m.connIDsToRetire {
if c.t.After(now) {
break
}
m.connRunners.RemoveConnectionID(c.connID)
m.connIDsToRetire = m.connIDsToRetire[1:]
}
}
func (m *connIDGenerator) RemoveAll() {
if m.initialClientDestConnID != nil {
m.connRunners.RemoveConnectionID(*m.initialClientDestConnID)
}
for _, connID := range m.activeSrcConnIDs {
m.connRunners.RemoveConnectionID(connID)
}
for _, c := range m.connIDsToRetire {
m.connRunners.RemoveConnectionID(c.connID)
}
}
func (m *connIDGenerator) ReplaceWithClosed(connClose []byte, expiry time.Duration) {
connIDs := make([]protocol.ConnectionID, 0, len(m.activeSrcConnIDs)+len(m.connIDsToRetire)+1)
if m.initialClientDestConnID != nil {
connIDs = append(connIDs, *m.initialClientDestConnID)
}
for _, connID := range m.activeSrcConnIDs {
connIDs = append(connIDs, connID)
}
for _, c := range m.connIDsToRetire {
connIDs = append(connIDs, c.connID)
}
m.connRunners.ReplaceWithClosed(connIDs, connClose, expiry)
}
func (m *connIDGenerator) AddConnRunner(runner connRunner, r connRunnerCallbacks) {
// The transport might have already been added earlier.
// This happens if the application migrates back to an old path.
if _, ok := m.connRunners[runner]; ok {
return
}
m.connRunners[runner] = r
if m.initialClientDestConnID != nil {
r.AddConnectionID(*m.initialClientDestConnID)
}
for _, connID := range m.activeSrcConnIDs {
r.AddConnectionID(connID)
}
}
package quic
import (
"fmt"
"slices"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qerr"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/internal/wire"
)
type newConnID struct {
SequenceNumber uint64
ConnectionID protocol.ConnectionID
StatelessResetToken protocol.StatelessResetToken
}
type connIDManager struct {
queue []newConnID
highestProbingID uint64
pathProbing map[pathID]newConnID // initialized lazily
handshakeComplete bool
activeSequenceNumber uint64
highestRetired uint64
activeConnectionID protocol.ConnectionID
activeStatelessResetToken *protocol.StatelessResetToken
// We change the connection ID after sending on average
// protocol.PacketsPerConnectionID packets. The actual value is randomized
// to hide the packet loss rate from on-path observers.
rand utils.Rand
packetsSinceLastChange uint32
packetsPerConnectionID uint32
addStatelessResetToken func(protocol.StatelessResetToken)
removeStatelessResetToken func(protocol.StatelessResetToken)
queueControlFrame func(wire.Frame)
closed bool
}
func newConnIDManager(
initialDestConnID protocol.ConnectionID,
addStatelessResetToken func(protocol.StatelessResetToken),
removeStatelessResetToken func(protocol.StatelessResetToken),
queueControlFrame func(wire.Frame),
) *connIDManager {
return &connIDManager{
activeConnectionID: initialDestConnID,
addStatelessResetToken: addStatelessResetToken,
removeStatelessResetToken: removeStatelessResetToken,
queueControlFrame: queueControlFrame,
queue: make([]newConnID, 0, protocol.MaxActiveConnectionIDs),
}
}
func (h *connIDManager) AddFromPreferredAddress(connID protocol.ConnectionID, resetToken protocol.StatelessResetToken) error {
return h.addConnectionID(1, connID, resetToken)
}
func (h *connIDManager) Add(f *wire.NewConnectionIDFrame) error {
if err := h.add(f); err != nil {
return err
}
if len(h.queue) >= protocol.MaxActiveConnectionIDs {
return &qerr.TransportError{ErrorCode: qerr.ConnectionIDLimitError}
}
return nil
}
func (h *connIDManager) add(f *wire.NewConnectionIDFrame) error {
if h.activeConnectionID.Len() == 0 {
return &qerr.TransportError{
ErrorCode: qerr.ProtocolViolation,
ErrorMessage: "received NEW_CONNECTION_ID frame but zero-length connection IDs are in use",
}
}
// If the NEW_CONNECTION_ID frame is reordered, such that its sequence number is smaller than that of the currently active
// connection ID, or if it was already retired, send the RETIRE_CONNECTION_ID frame immediately.
if f.SequenceNumber < max(h.activeSequenceNumber, h.highestProbingID) || f.SequenceNumber < h.highestRetired {
h.queueControlFrame(&wire.RetireConnectionIDFrame{
SequenceNumber: f.SequenceNumber,
})
return nil
}
if f.RetirePriorTo != 0 && h.pathProbing != nil {
for id, entry := range h.pathProbing {
if entry.SequenceNumber < f.RetirePriorTo {
h.queueControlFrame(&wire.RetireConnectionIDFrame{
SequenceNumber: entry.SequenceNumber,
})
h.removeStatelessResetToken(entry.StatelessResetToken)
delete(h.pathProbing, id)
}
}
}
// Retire elements in the queue.
// Doesn't retire the active connection ID.
if f.RetirePriorTo > h.highestRetired {
var newQueue []newConnID
for _, entry := range h.queue {
if entry.SequenceNumber >= f.RetirePriorTo {
newQueue = append(newQueue, entry)
} else {
h.queueControlFrame(&wire.RetireConnectionIDFrame{SequenceNumber: entry.SequenceNumber})
}
}
h.queue = newQueue
h.highestRetired = f.RetirePriorTo
}
if f.SequenceNumber == h.activeSequenceNumber {
return nil
}
if err := h.addConnectionID(f.SequenceNumber, f.ConnectionID, f.StatelessResetToken); err != nil {
return err
}
// Retire the active connection ID, if necessary.
if h.activeSequenceNumber < f.RetirePriorTo {
// The queue is guaranteed to have at least one element at this point.
h.updateConnectionID()
}
return nil
}
func (h *connIDManager) addConnectionID(seq uint64, connID protocol.ConnectionID, resetToken protocol.StatelessResetToken) error {
// fast path: add to the end of the queue
if len(h.queue) == 0 || h.queue[len(h.queue)-1].SequenceNumber < seq {
h.queue = append(h.queue, newConnID{
SequenceNumber: seq,
ConnectionID: connID,
StatelessResetToken: resetToken,
})
return nil
}
// slow path: insert in the middle
for i, entry := range h.queue {
if entry.SequenceNumber == seq {
if entry.ConnectionID != connID {
return fmt.Errorf("received conflicting connection IDs for sequence number %d", seq)
}
if entry.StatelessResetToken != resetToken {
return fmt.Errorf("received conflicting stateless reset tokens for sequence number %d", seq)
}
return nil
}
// insert at the correct position to maintain sorted order
if entry.SequenceNumber > seq {
h.queue = slices.Insert(h.queue, i, newConnID{
SequenceNumber: seq,
ConnectionID: connID,
StatelessResetToken: resetToken,
})
return nil
}
}
return nil // unreachable
}
func (h *connIDManager) updateConnectionID() {
h.assertNotClosed()
h.queueControlFrame(&wire.RetireConnectionIDFrame{
SequenceNumber: h.activeSequenceNumber,
})
h.highestRetired = max(h.highestRetired, h.activeSequenceNumber)
if h.activeStatelessResetToken != nil {
h.removeStatelessResetToken(*h.activeStatelessResetToken)
}
front := h.queue[0]
h.queue = h.queue[1:]
h.activeSequenceNumber = front.SequenceNumber
h.activeConnectionID = front.ConnectionID
h.activeStatelessResetToken = &front.StatelessResetToken
h.packetsSinceLastChange = 0
h.packetsPerConnectionID = protocol.PacketsPerConnectionID/2 + uint32(h.rand.Int31n(protocol.PacketsPerConnectionID))
h.addStatelessResetToken(*h.activeStatelessResetToken)
}
func (h *connIDManager) Close() {
h.closed = true
if h.activeStatelessResetToken != nil {
h.removeStatelessResetToken(*h.activeStatelessResetToken)
}
if h.pathProbing != nil {
for _, entry := range h.pathProbing {
h.removeStatelessResetToken(entry.StatelessResetToken)
}
}
}
// is called when the server performs a Retry
// and when the server changes the connection ID in the first Initial sent
func (h *connIDManager) ChangeInitialConnID(newConnID protocol.ConnectionID) {
if h.activeSequenceNumber != 0 {
panic("expected first connection ID to have sequence number 0")
}
h.activeConnectionID = newConnID
}
// is called when the server provides a stateless reset token in the transport parameters
func (h *connIDManager) SetStatelessResetToken(token protocol.StatelessResetToken) {
h.assertNotClosed()
if h.activeSequenceNumber != 0 {
panic("expected first connection ID to have sequence number 0")
}
h.activeStatelessResetToken = &token
h.addStatelessResetToken(token)
}
func (h *connIDManager) SentPacket() {
h.packetsSinceLastChange++
}
func (h *connIDManager) shouldUpdateConnID() bool {
if !h.handshakeComplete {
return false
}
// initiate the first change as early as possible (after handshake completion)
if len(h.queue) > 0 && h.activeSequenceNumber == 0 {
return true
}
// For later changes, only change if
// 1. The queue of connection IDs is filled more than 50%.
// 2. We sent at least PacketsPerConnectionID packets
return 2*len(h.queue) >= protocol.MaxActiveConnectionIDs &&
h.packetsSinceLastChange >= h.packetsPerConnectionID
}
func (h *connIDManager) Get() protocol.ConnectionID {
h.assertNotClosed()
if h.shouldUpdateConnID() {
h.updateConnectionID()
}
return h.activeConnectionID
}
func (h *connIDManager) SetHandshakeComplete() {
h.handshakeComplete = true
}
// GetConnIDForPath retrieves a connection ID for a new path (i.e. not the active one).
// Once a connection ID is allocated for a path, it cannot be used for a different path.
// When called with the same pathID, it will return the same connection ID,
// unless the peer requested that this connection ID be retired.
func (h *connIDManager) GetConnIDForPath(id pathID) (protocol.ConnectionID, bool) {
h.assertNotClosed()
// if we're using zero-length connection IDs, we don't need to change the connection ID
if h.activeConnectionID.Len() == 0 {
return protocol.ConnectionID{}, true
}
if h.pathProbing == nil {
h.pathProbing = make(map[pathID]newConnID)
}
entry, ok := h.pathProbing[id]
if ok {
return entry.ConnectionID, true
}
if len(h.queue) == 0 {
return protocol.ConnectionID{}, false
}
front := h.queue[0]
h.queue = h.queue[1:]
h.pathProbing[id] = front
h.highestProbingID = front.SequenceNumber
h.addStatelessResetToken(front.StatelessResetToken)
return front.ConnectionID, true
}
func (h *connIDManager) RetireConnIDForPath(pathID pathID) {
h.assertNotClosed()
// if we're using zero-length connection IDs, we don't need to change the connection ID
if h.activeConnectionID.Len() == 0 {
return
}
entry, ok := h.pathProbing[pathID]
if !ok {
return
}
h.queueControlFrame(&wire.RetireConnectionIDFrame{
SequenceNumber: entry.SequenceNumber,
})
h.removeStatelessResetToken(entry.StatelessResetToken)
delete(h.pathProbing, pathID)
}
func (h *connIDManager) IsActiveStatelessResetToken(token protocol.StatelessResetToken) bool {
if h.activeStatelessResetToken != nil {
if *h.activeStatelessResetToken == token {
return true
}
}
if h.pathProbing != nil {
for _, entry := range h.pathProbing {
if entry.StatelessResetToken == token {
return true
}
}
}
return false
}
// Using the connIDManager after it has been closed can have disastrous effects:
// If the connection ID is rotated, a new entry would be inserted into the packet handler map,
// leading to a memory leak of the connection struct.
// See https://github.com/quic-go/quic-go/pull/4852 for more details.
func (h *connIDManager) assertNotClosed() {
if h.closed {
panic("connection ID manager is closed")
}
}
package quic
import (
"bytes"
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"reflect"
"slices"
"sync"
"sync/atomic"
"time"
"github.com/quic-go/quic-go/internal/ackhandler"
"github.com/quic-go/quic-go/internal/flowcontrol"
"github.com/quic-go/quic-go/internal/handshake"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qerr"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/internal/utils/ringbuffer"
"github.com/quic-go/quic-go/internal/wire"
"github.com/quic-go/quic-go/logging"
)
type unpacker interface {
UnpackLongHeader(hdr *wire.Header, data []byte) (*unpackedPacket, error)
UnpackShortHeader(rcvTime time.Time, data []byte) (protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, []byte, error)
}
type cryptoStreamHandler interface {
StartHandshake(context.Context) error
ChangeConnectionID(protocol.ConnectionID)
SetLargest1RTTAcked(protocol.PacketNumber) error
SetHandshakeConfirmed()
GetSessionTicket() ([]byte, error)
NextEvent() handshake.Event
DiscardInitialKeys()
HandleMessage([]byte, protocol.EncryptionLevel) error
io.Closer
ConnectionState() handshake.ConnectionState
}
type receivedPacket struct {
buffer *packetBuffer
remoteAddr net.Addr
rcvTime time.Time
data []byte
ecn protocol.ECN
info packetInfo // only valid if the contained IP address is valid
}
func (p *receivedPacket) Size() protocol.ByteCount { return protocol.ByteCount(len(p.data)) }
func (p *receivedPacket) Clone() *receivedPacket {
return &receivedPacket{
remoteAddr: p.remoteAddr,
rcvTime: p.rcvTime,
data: p.data,
buffer: p.buffer,
ecn: p.ecn,
info: p.info,
}
}
type connRunner interface {
Add(protocol.ConnectionID, packetHandler) bool
Remove(protocol.ConnectionID)
ReplaceWithClosed([]protocol.ConnectionID, []byte, time.Duration)
AddResetToken(protocol.StatelessResetToken, packetHandler)
RemoveResetToken(protocol.StatelessResetToken)
}
type closeError struct {
err error
immediate bool
}
type errCloseForRecreating struct {
nextPacketNumber protocol.PacketNumber
nextVersion protocol.Version
}
func (e *errCloseForRecreating) Error() string {
return "closing connection in order to recreate it"
}
var connTracingID atomic.Uint64 // to be accessed atomically
func nextConnTracingID() ConnectionTracingID { return ConnectionTracingID(connTracingID.Add(1)) }
// A Conn is a QUIC connection between two peers.
// Calls to the connection (and to streams) can return the following types of errors:
// - [ApplicationError]: for errors triggered by the application running on top of QUIC
// - [TransportError]: for errors triggered by the QUIC transport (in many cases a misbehaving peer)
// - [IdleTimeoutError]: when the peer goes away unexpectedly (this is a [net.Error] timeout error)
// - [HandshakeTimeoutError]: when the cryptographic handshake takes too long (this is a [net.Error] timeout error)
// - [StatelessResetError]: when we receive a stateless reset
// - [VersionNegotiationError]: returned by the client, when there's no version overlap between the peers
type Conn struct {
// Destination connection ID used during the handshake.
// Used to check source connection ID on incoming packets.
handshakeDestConnID protocol.ConnectionID
// Set for the client. Destination connection ID used on the first Initial sent.
origDestConnID protocol.ConnectionID
retrySrcConnID *protocol.ConnectionID // only set for the client (and if a Retry was performed)
srcConnIDLen int
perspective protocol.Perspective
version protocol.Version
config *Config
conn sendConn
sendQueue sender
// lazily initialized: most connections never migrate
pathManager *pathManager
largestRcvdAppData protocol.PacketNumber
pathManagerOutgoing atomic.Pointer[pathManagerOutgoing]
streamsMap *streamsMap
connIDManager *connIDManager
connIDGenerator *connIDGenerator
rttStats *utils.RTTStats
cryptoStreamManager *cryptoStreamManager
sentPacketHandler ackhandler.SentPacketHandler
receivedPacketHandler ackhandler.ReceivedPacketHandler
retransmissionQueue *retransmissionQueue
framer *framer
connFlowController flowcontrol.ConnectionFlowController
tokenStoreKey string // only set for the client
tokenGenerator *handshake.TokenGenerator // only set for the server
unpacker unpacker
frameParser wire.FrameParser
packer packer
mtuDiscoverer mtuDiscoverer // initialized when the transport parameters are received
currentMTUEstimate atomic.Uint32
initialStream *initialCryptoStream
handshakeStream *cryptoStream
oneRTTStream *cryptoStream // only set for the server
cryptoStreamHandler cryptoStreamHandler
notifyReceivedPacket chan struct{}
sendingScheduled chan struct{}
receivedPacketMx sync.Mutex
receivedPackets ringbuffer.RingBuffer[receivedPacket]
// closeChan is used to notify the run loop that it should terminate
closeChan chan struct{}
closeErr atomic.Pointer[closeError]
ctx context.Context
ctxCancel context.CancelCauseFunc
handshakeCompleteChan chan struct{}
undecryptablePackets []receivedPacket // undecryptable packets, waiting for a change in encryption level
undecryptablePacketsToProcess []receivedPacket
earlyConnReadyChan chan struct{}
sentFirstPacket bool
droppedInitialKeys bool
handshakeComplete bool
handshakeConfirmed bool
receivedRetry bool
versionNegotiated bool
receivedFirstPacket bool
// the minimum of the max_idle_timeout values advertised by both endpoints
idleTimeout time.Duration
creationTime time.Time
// The idle timeout is set based on the max of the time we received the last packet...
lastPacketReceivedTime time.Time
// ... and the time we sent a new ack-eliciting packet after receiving a packet.
firstAckElicitingPacketAfterIdleSentTime time.Time
// pacingDeadline is the time when the next packet should be sent
pacingDeadline time.Time
peerParams *wire.TransportParameters
timer connectionTimer
// keepAlivePingSent stores whether a keep alive PING is in flight.
// It is reset as soon as we receive a packet from the peer.
keepAlivePingSent bool
keepAliveInterval time.Duration
datagramQueue *datagramQueue
connStateMutex sync.Mutex
connState ConnectionState
logID string
tracer *logging.ConnectionTracer
logger utils.Logger
}
var _ streamSender = &Conn{}
type connTestHooks struct {
run func() error
earlyConnReady func() <-chan struct{}
context func() context.Context
handshakeComplete func() <-chan struct{}
closeWithTransportError func(TransportErrorCode)
destroy func(error)
handlePacket func(receivedPacket)
}
type wrappedConn struct {
testHooks *connTestHooks
*Conn
}
var newConnection = func(
ctx context.Context,
ctxCancel context.CancelCauseFunc,
conn sendConn,
runner connRunner,
origDestConnID protocol.ConnectionID,
retrySrcConnID *protocol.ConnectionID,
clientDestConnID protocol.ConnectionID,
destConnID protocol.ConnectionID,
srcConnID protocol.ConnectionID,
connIDGenerator ConnectionIDGenerator,
statelessResetter *statelessResetter,
conf *Config,
tlsConf *tls.Config,
tokenGenerator *handshake.TokenGenerator,
clientAddressValidated bool,
rtt time.Duration,
tracer *logging.ConnectionTracer,
logger utils.Logger,
v protocol.Version,
) *wrappedConn {
s := &Conn{
ctx: ctx,
ctxCancel: ctxCancel,
conn: conn,
config: conf,
handshakeDestConnID: destConnID,
srcConnIDLen: srcConnID.Len(),
tokenGenerator: tokenGenerator,
oneRTTStream: newCryptoStream(),
perspective: protocol.PerspectiveServer,
tracer: tracer,
logger: logger,
version: v,
}
if origDestConnID.Len() > 0 {
s.logID = origDestConnID.String()
} else {
s.logID = destConnID.String()
}
s.connIDManager = newConnIDManager(
destConnID,
func(token protocol.StatelessResetToken) { runner.AddResetToken(token, s) },
runner.RemoveResetToken,
s.queueControlFrame,
)
s.connIDGenerator = newConnIDGenerator(
runner,
srcConnID,
&clientDestConnID,
statelessResetter,
connRunnerCallbacks{
AddConnectionID: func(connID protocol.ConnectionID) { runner.Add(connID, s) },
RemoveConnectionID: runner.Remove,
ReplaceWithClosed: runner.ReplaceWithClosed,
},
s.queueControlFrame,
connIDGenerator,
)
s.preSetup()
s.rttStats.SetInitialRTT(rtt)
s.sentPacketHandler, s.receivedPacketHandler = ackhandler.NewAckHandler(
0,
protocol.ByteCount(s.config.InitialPacketSize),
s.rttStats,
clientAddressValidated,
s.conn.capabilities().ECN,
s.perspective,
s.tracer,
s.logger,
)
s.currentMTUEstimate.Store(uint32(estimateMaxPayloadSize(protocol.ByteCount(s.config.InitialPacketSize))))
statelessResetToken := statelessResetter.GetStatelessResetToken(srcConnID)
params := &wire.TransportParameters{
InitialMaxStreamDataBidiLocal: protocol.ByteCount(s.config.InitialStreamReceiveWindow),
InitialMaxStreamDataBidiRemote: protocol.ByteCount(s.config.InitialStreamReceiveWindow),
InitialMaxStreamDataUni: protocol.ByteCount(s.config.InitialStreamReceiveWindow),
InitialMaxData: protocol.ByteCount(s.config.InitialConnectionReceiveWindow),
MaxIdleTimeout: s.config.MaxIdleTimeout,
MaxBidiStreamNum: protocol.StreamNum(s.config.MaxIncomingStreams),
MaxUniStreamNum: protocol.StreamNum(s.config.MaxIncomingUniStreams),
MaxAckDelay: protocol.MaxAckDelayInclGranularity,
AckDelayExponent: protocol.AckDelayExponent,
MaxUDPPayloadSize: protocol.MaxPacketBufferSize,
StatelessResetToken: &statelessResetToken,
OriginalDestinationConnectionID: origDestConnID,
// For interoperability with quic-go versions before May 2023, this value must be set to a value
// different from protocol.DefaultActiveConnectionIDLimit.
// If set to the default value, it will be omitted from the transport parameters, which will make
// old quic-go versions interpret it as 0, instead of the default value of 2.
// See https://github.com/quic-go/quic-go/pull/3806.
ActiveConnectionIDLimit: protocol.MaxActiveConnectionIDs,
InitialSourceConnectionID: srcConnID,
RetrySourceConnectionID: retrySrcConnID,
EnableResetStreamAt: conf.EnableStreamResetPartialDelivery,
}
if s.config.EnableDatagrams {
params.MaxDatagramFrameSize = wire.MaxDatagramSize
} else {
params.MaxDatagramFrameSize = protocol.InvalidByteCount
}
if s.tracer != nil && s.tracer.SentTransportParameters != nil {
s.tracer.SentTransportParameters(params)
}
cs := handshake.NewCryptoSetupServer(
clientDestConnID,
conn.LocalAddr(),
conn.RemoteAddr(),
params,
tlsConf,
conf.Allow0RTT,
s.rttStats,
tracer,
logger,
s.version,
)
s.cryptoStreamHandler = cs
s.packer = newPacketPacker(srcConnID, s.connIDManager.Get, s.initialStream, s.handshakeStream, s.sentPacketHandler, s.retransmissionQueue, cs, s.framer, s.receivedPacketHandler, s.datagramQueue, s.perspective)
s.unpacker = newPacketUnpacker(cs, s.srcConnIDLen)
s.cryptoStreamManager = newCryptoStreamManager(s.initialStream, s.handshakeStream, s.oneRTTStream)
return &wrappedConn{Conn: s}
}
// declare this as a variable, such that we can mock it in the tests
var newClientConnection = func(
ctx context.Context,
conn sendConn,
runner connRunner,
destConnID protocol.ConnectionID,
srcConnID protocol.ConnectionID,
connIDGenerator ConnectionIDGenerator,
statelessResetter *statelessResetter,
conf *Config,
tlsConf *tls.Config,
initialPacketNumber protocol.PacketNumber,
enable0RTT bool,
hasNegotiatedVersion bool,
tracer *logging.ConnectionTracer,
logger utils.Logger,
v protocol.Version,
) *wrappedConn {
s := &Conn{
conn: conn,
config: conf,
origDestConnID: destConnID,
handshakeDestConnID: destConnID,
srcConnIDLen: srcConnID.Len(),
perspective: protocol.PerspectiveClient,
logID: destConnID.String(),
logger: logger,
tracer: tracer,
versionNegotiated: hasNegotiatedVersion,
version: v,
}
s.connIDManager = newConnIDManager(
destConnID,
func(token protocol.StatelessResetToken) { runner.AddResetToken(token, s) },
runner.RemoveResetToken,
s.queueControlFrame,
)
s.connIDGenerator = newConnIDGenerator(
runner,
srcConnID,
nil,
statelessResetter,
connRunnerCallbacks{
AddConnectionID: func(connID protocol.ConnectionID) { runner.Add(connID, s) },
RemoveConnectionID: runner.Remove,
ReplaceWithClosed: runner.ReplaceWithClosed,
},
s.queueControlFrame,
connIDGenerator,
)
s.ctx, s.ctxCancel = context.WithCancelCause(ctx)
s.preSetup()
s.sentPacketHandler, s.receivedPacketHandler = ackhandler.NewAckHandler(
initialPacketNumber,
protocol.ByteCount(s.config.InitialPacketSize),
s.rttStats,
false, // has no effect
s.conn.capabilities().ECN,
s.perspective,
s.tracer,
s.logger,
)
s.currentMTUEstimate.Store(uint32(estimateMaxPayloadSize(protocol.ByteCount(s.config.InitialPacketSize))))
oneRTTStream := newCryptoStream()
params := &wire.TransportParameters{
InitialMaxStreamDataBidiRemote: protocol.ByteCount(s.config.InitialStreamReceiveWindow),
InitialMaxStreamDataBidiLocal: protocol.ByteCount(s.config.InitialStreamReceiveWindow),
InitialMaxStreamDataUni: protocol.ByteCount(s.config.InitialStreamReceiveWindow),
InitialMaxData: protocol.ByteCount(s.config.InitialConnectionReceiveWindow),
MaxIdleTimeout: s.config.MaxIdleTimeout,
MaxBidiStreamNum: protocol.StreamNum(s.config.MaxIncomingStreams),
MaxUniStreamNum: protocol.StreamNum(s.config.MaxIncomingUniStreams),
MaxAckDelay: protocol.MaxAckDelayInclGranularity,
MaxUDPPayloadSize: protocol.MaxPacketBufferSize,
AckDelayExponent: protocol.AckDelayExponent,
// For interoperability with quic-go versions before May 2023, this value must be set to a value
// different from protocol.DefaultActiveConnectionIDLimit.
// If set to the default value, it will be omitted from the transport parameters, which will make
// old quic-go versions interpret it as 0, instead of the default value of 2.
// See https://github.com/quic-go/quic-go/pull/3806.
ActiveConnectionIDLimit: protocol.MaxActiveConnectionIDs,
InitialSourceConnectionID: srcConnID,
EnableResetStreamAt: conf.EnableStreamResetPartialDelivery,
}
if s.config.EnableDatagrams {
params.MaxDatagramFrameSize = wire.MaxDatagramSize
} else {
params.MaxDatagramFrameSize = protocol.InvalidByteCount
}
if s.tracer != nil && s.tracer.SentTransportParameters != nil {
s.tracer.SentTransportParameters(params)
}
cs := handshake.NewCryptoSetupClient(
destConnID,
params,
tlsConf,
enable0RTT,
s.rttStats,
tracer,
logger,
s.version,
)
s.cryptoStreamHandler = cs
s.cryptoStreamManager = newCryptoStreamManager(s.initialStream, s.handshakeStream, oneRTTStream)
s.unpacker = newPacketUnpacker(cs, s.srcConnIDLen)
s.packer = newPacketPacker(srcConnID, s.connIDManager.Get, s.initialStream, s.handshakeStream, s.sentPacketHandler, s.retransmissionQueue, cs, s.framer, s.receivedPacketHandler, s.datagramQueue, s.perspective)
if len(tlsConf.ServerName) > 0 {
s.tokenStoreKey = tlsConf.ServerName
} else {
s.tokenStoreKey = conn.RemoteAddr().String()
}
if s.config.TokenStore != nil {
if token := s.config.TokenStore.Pop(s.tokenStoreKey); token != nil {
s.packer.SetToken(token.data)
s.rttStats.SetInitialRTT(token.rtt)
}
}
return &wrappedConn{Conn: s}
}
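// preSetup initializes the parts of the connection state that are common to
// both the client and the server perspective.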
func (c *Conn) preSetup() {
c.largestRcvdAppData = protocol.InvalidPacketNumber
c.initialStream = newInitialCryptoStream(c.perspective == protocol.PerspectiveClient)
c.handshakeStream = newCryptoStream()
c.sendQueue = newSendQueue(c.conn)
c.retransmissionQueue = newRetransmissionQueue()
c.frameParser = *wire.NewFrameParser(
c.config.EnableDatagrams,
c.config.EnableStreamResetPartialDelivery,
false, // ACK_FREQUENCY is not supported yet
)
c.rttStats = &utils.RTTStats{}
c.connFlowController = flowcontrol.NewConnectionFlowController(
protocol.ByteCount(c.config.InitialConnectionReceiveWindow),
protocol.ByteCount(c.config.MaxConnectionReceiveWindow),
func(size protocol.ByteCount) bool {
if c.config.AllowConnectionWindowIncrease == nil {
return true
}
return c.config.AllowConnectionWindowIncrease(c, uint64(size))
},
c.rttStats,
c.logger,
)
c.earlyConnReadyChan = make(chan struct{})
c.streamsMap = newStreamsMap(
c.ctx,
c,
c.queueControlFrame,
c.newFlowController,
uint64(c.config.MaxIncomingStreams),
uint64(c.config.MaxIncomingUniStreams),
c.perspective,
)
c.framer = newFramer(c.connFlowController)
c.receivedPackets.Init(8)
c.notifyReceivedPacket = make(chan struct{}, 1)
c.closeChan = make(chan struct{}, 1)
c.sendingScheduled = make(chan struct{}, 1)
c.handshakeCompleteChan = make(chan struct{})
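// record the creation time; it serves as the starting point for the handshake timeout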
now := time.Now()
c.lastPacketReceivedTime = now
c.creationTime = now
c.datagramQueue = newDatagramQueue(c.scheduleSending, c.logger)
c.connState.Version = c.version
}
// run the connection main loop
func (c *Conn) run() (err error) {
defer func() { c.ctxCancel(err) }()
defer func() {
// drain queued packets that will never be processed
c.receivedPacketMx.Lock()
defer c.receivedPacketMx.Unlock()
for !c.receivedPackets.Empty() {
p := c.receivedPackets.PopFront()
p.buffer.Decrement()
p.buffer.MaybeRelease()
}
}()
c.timer = *newTimer()
if err := c.cryptoStreamHandler.StartHandshake(c.ctx); err != nil {
return err
}
if err := c.handleHandshakeEvents(time.Now()); err != nil {
return err
}
go func() {
if err := c.sendQueue.Run(); err != nil {
c.destroyImpl(err)
}
}()
if c.perspective == protocol.PerspectiveClient {
c.scheduleSending() // so the ClientHello actually gets sent
}
var sendQueueAvailable <-chan struct{}
runLoop:
for {
if c.framer.QueuedTooManyControlFrames() {
c.setCloseError(&closeError{err: &qerr.TransportError{ErrorCode: InternalError}})
break runLoop
}
// Close immediately if requested
select {
case <-c.closeChan:
break runLoop
default:
}
// no need to set a timer if we can send packets immediately
if c.pacingDeadline != deadlineSendImmediately {
c.maybeResetTimer()
}
// 1st: handle undecryptable packets, if any.
// This can only occur before completion of the handshake.
if len(c.undecryptablePacketsToProcess) > 0 {
var processedUndecryptablePacket bool
queue := c.undecryptablePacketsToProcess
c.undecryptablePacketsToProcess = nil
for _, p := range queue {
processed, err := c.handleOnePacket(p)
if err != nil {
c.setCloseError(&closeError{err: err})
break runLoop
}
if processed {
processedUndecryptablePacket = true
}
}
if processedUndecryptablePacket {
// if we processed any undecryptable packets, jump to the resetting of the timers directly
continue
}
}
// 2nd: receive packets.
processed, err := c.handlePackets() // don't check receivedPackets.Len() in the run loop to avoid locking the mutex
if err != nil {
c.setCloseError(&closeError{err: err})
break runLoop
}
// We don't need to wait for new events if:
// * we processed packets: we probably need to send an ACK, and potentially more data
// * the pacer allows us to send more packets immediately
shouldProceedImmediately := sendQueueAvailable == nil && (processed || c.pacingDeadline.Equal(deadlineSendImmediately))
if !shouldProceedImmediately {
// 3rd: wait for something to happen:
// * closing of the connection
// * timer firing
// * sending scheduled
// * send queue available
// * received packets
select {
case <-c.closeChan:
break runLoop
case <-c.timer.Chan():
c.timer.SetRead()
case <-c.sendingScheduled:
case <-sendQueueAvailable:
case <-c.notifyReceivedPacket:
wasProcessed, err := c.handlePackets()
if err != nil {
c.setCloseError(&closeError{err: err})
break runLoop
}
// if no packet was processed, jump back to the top of the loop and reset the timers
if !wasProcessed {
continue
}
}
}
// Check for loss detection timeout.
// This could cause packets to be declared lost, and retransmissions to be enqueued.
now := time.Now()
if timeout := c.sentPacketHandler.GetLossDetectionTimeout(); !timeout.IsZero() && timeout.Before(now) {
if err := c.sentPacketHandler.OnLossDetectionTimeout(now); err != nil {
c.setCloseError(&closeError{err: err})
break runLoop
}
}
if keepAliveTime := c.nextKeepAliveTime(); !keepAliveTime.IsZero() && !now.Before(keepAliveTime) {
// send a PING frame since there is no activity in the connection
c.logger.Debugf("Sending a keep-alive PING to keep the connection alive.")
c.framer.QueueControlFrame(&wire.PingFrame{})
c.keepAlivePingSent = true
} else if !c.handshakeComplete && now.Sub(c.creationTime) >= c.config.handshakeTimeout() {
c.destroyImpl(qerr.ErrHandshakeTimeout)
break runLoop
} else {
idleTimeoutStartTime := c.idleTimeoutStartTime()
if (!c.handshakeComplete && now.Sub(idleTimeoutStartTime) >= c.config.HandshakeIdleTimeout) ||
(c.handshakeComplete && now.After(c.nextIdleTimeoutTime())) {
c.destroyImpl(qerr.ErrIdleTimeout)
break runLoop
}
}
c.connIDGenerator.RemoveRetiredConnIDs(now)
if c.perspective == protocol.PerspectiveClient {
pm := c.pathManagerOutgoing.Load()
if pm != nil {
tr, ok := pm.ShouldSwitchPath()
if ok {
c.switchToNewPath(tr, now)
}
}
}
if c.sendQueue.WouldBlock() {
// The send queue is still busy sending out packets. Wait until there's space to enqueue new packets.
sendQueueAvailable = c.sendQueue.Available()
// Cancel the pacing timer, as we can't send any more packets until the send queue is available again.
c.pacingDeadline = time.Time{}
continue
}
if c.closeErr.Load() != nil {
break runLoop
}
if err := c.triggerSending(now); err != nil {
c.setCloseError(&closeError{err: err})
break runLoop
}
if c.sendQueue.WouldBlock() {
// The send queue is still busy sending out packets. Wait until there's space to enqueue new packets.
sendQueueAvailable = c.sendQueue.Available()
// Cancel the pacing timer, as we can't send any more packets until the send queue is available again.
c.pacingDeadline = time.Time{}
} else {
sendQueueAvailable = nil
}
}
closeErr := c.closeErr.Load()
c.cryptoStreamHandler.Close()
c.sendQueue.Close() // close the send queue before sending the CONNECTION_CLOSE
c.handleCloseError(closeErr)
if c.tracer != nil && c.tracer.Close != nil {
if e := (&errCloseForRecreating{}); !errors.As(closeErr.err, &e) {
c.tracer.Close()
}
}
c.logger.Infof("Connection %s closed.", c.logID)
c.timer.Stop()
return closeErr.err
}
// blocks until the early connection can be used
func (c *Conn) earlyConnReady() <-chan struct{} {
return c.earlyConnReadyChan
}
// Context returns a context that is cancelled when the connection is closed.
// The cancellation cause is set to the error that caused the connection to close.
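// For example (illustrative only), the close reason can be inspected via context.Cause:
//
//	<-conn.Context().Done()
//	reason := context.Cause(conn.Context())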
func (c *Conn) Context() context.Context {
return c.ctx
}
func (c *Conn) supportsDatagrams() bool {
return c.peerParams.MaxDatagramFrameSize > 0
}
// ConnectionState returns basic details about the QUIC connection.
func (c *Conn) ConnectionState() ConnectionState {
c.connStateMutex.Lock()
defer c.connStateMutex.Unlock()
cs := c.cryptoStreamHandler.ConnectionState()
c.connState.TLS = cs.ConnectionState
c.connState.Used0RTT = cs.Used0RTT
c.connState.SupportsStreamResetPartialDelivery = c.peerParams.EnableResetStreamAt
c.connState.GSO = c.conn.capabilities().GSO
return c.connState
}
// Time when the connection should time out
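// The effective idle timeout is at least three times the current PTO, as required by RFC 9000, section 10.1.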
func (c *Conn) nextIdleTimeoutTime() time.Time {
idleTimeout := max(c.idleTimeout, c.rttStats.PTO(true)*3)
return c.idleTimeoutStartTime().Add(idleTimeout)
}
// Time when the next keep-alive packet should be sent.
// It returns a zero time if no keep-alive should be sent.
func (c *Conn) nextKeepAliveTime() time.Time {
if c.config.KeepAlivePeriod == 0 || c.keepAlivePingSent {
return time.Time{}
}
keepAliveInterval := max(c.keepAliveInterval, c.rttStats.PTO(true)*3/2)
return c.lastPacketReceivedTime.Add(keepAliveInterval)
}
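// maybeResetTimer arms the connection timer with the earliest of the relevant deadlines:
// handshake / idle timeout (or keep-alive), connection ID retirement, ACK alarm, loss detection, and pacing.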
func (c *Conn) maybeResetTimer() {
var deadline time.Time
if !c.handshakeComplete {
deadline = c.creationTime.Add(c.config.handshakeTimeout())
if t := c.idleTimeoutStartTime().Add(c.config.HandshakeIdleTimeout); t.Before(deadline) {
deadline = t
}
} else {
if keepAliveTime := c.nextKeepAliveTime(); !keepAliveTime.IsZero() {
deadline = keepAliveTime
} else {
deadline = c.nextIdleTimeoutTime()
}
}
c.timer.SetTimer(
deadline,
c.connIDGenerator.NextRetireTime(),
c.receivedPacketHandler.GetAlarmTimeout(),
c.sentPacketHandler.GetLossDetectionTimeout(),
c.pacingDeadline,
)
}
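// idleTimeoutStartTime returns the time from which the idle timeout is measured:
// the later of the last received packet and the first ack-eliciting packet sent after going idle.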
func (c *Conn) idleTimeoutStartTime() time.Time {
startTime := c.lastPacketReceivedTime
if t := c.firstAckElicitingPacketAfterIdleSentTime; t.After(startTime) {
startTime = t
}
return startTime
}
func (c *Conn) switchToNewPath(tr *Transport, now time.Time) {
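// Reset the congestion controller and MTU discovery, since the path properties may have changed.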
initialPacketSize := protocol.ByteCount(c.config.InitialPacketSize)
c.sentPacketHandler.MigratedPath(now, initialPacketSize)
maxPacketSize := protocol.ByteCount(protocol.MaxPacketBufferSize)
if c.peerParams.MaxUDPPayloadSize > 0 && c.peerParams.MaxUDPPayloadSize < maxPacketSize {
maxPacketSize = c.peerParams.MaxUDPPayloadSize
}
c.mtuDiscoverer.Reset(now, initialPacketSize, maxPacketSize)
c.conn = newSendConn(tr.conn, c.conn.RemoteAddr(), packetInfo{}, utils.DefaultLogger) // TODO: find a better way
c.sendQueue.Close()
c.sendQueue = newSendQueue(c.conn)
go func() {
if err := c.sendQueue.Run(); err != nil {
c.destroyImpl(err)
}
}()
}
func (c *Conn) handleHandshakeComplete(now time.Time) error {
defer close(c.handshakeCompleteChan)
// Once the handshake completes, we have derived 1-RTT keys.
// There's no point in queueing undecryptable packets for later decryption anymore.
c.undecryptablePackets = nil
c.connIDManager.SetHandshakeComplete()
c.connIDGenerator.SetHandshakeComplete(now.Add(3 * c.rttStats.PTO(false)))
if c.tracer != nil && c.tracer.ChoseALPN != nil {
c.tracer.ChoseALPN(c.cryptoStreamHandler.ConnectionState().NegotiatedProtocol)
}
// The server applies transport parameters right away, but the client side has to wait for handshake completion.
// During a 0-RTT connection, the client is only allowed to use the new transport parameters for 1-RTT packets.
if c.perspective == protocol.PerspectiveClient {
c.applyTransportParameters()
return nil
}
// All these only apply to the server side.
if err := c.handleHandshakeConfirmed(now); err != nil {
return err
}
ticket, err := c.cryptoStreamHandler.GetSessionTicket()
if err != nil {
return err
}
if ticket != nil { // may be nil if session tickets are disabled via tls.Config.SessionTicketsDisabled
c.oneRTTStream.Write(ticket)
for c.oneRTTStream.HasData() {
if cf := c.oneRTTStream.PopCryptoFrame(protocol.MaxPostHandshakeCryptoFrameSize); cf != nil {
c.queueControlFrame(cf)
}
}
}
token, err := c.tokenGenerator.NewToken(c.conn.RemoteAddr(), c.rttStats.SmoothedRTT())
if err != nil {
return err
}
c.queueControlFrame(&wire.NewTokenFrame{Token: token})
c.queueControlFrame(&wire.HandshakeDoneFrame{})
return nil
}
func (c *Conn) handleHandshakeConfirmed(now time.Time) error {
if err := c.dropEncryptionLevel(protocol.EncryptionHandshake, now); err != nil {
return err
}
c.handshakeConfirmed = true
c.cryptoStreamHandler.SetHandshakeConfirmed()
if !c.config.DisablePathMTUDiscovery && c.conn.capabilities().DF {
c.mtuDiscoverer.Start(now)
}
return nil
}
func (c *Conn) handlePackets() (wasProcessed bool, _ error) {
// Now process all packets in the receivedPackets queue.
// Limit the number of packets to the queue's current length,
// so we eventually get a chance to send out an ACK when receiving a lot of packets.
c.receivedPacketMx.Lock()
numPackets := c.receivedPackets.Len()
if numPackets == 0 {
c.receivedPacketMx.Unlock()
return false, nil
}
var hasMorePackets bool
for i := 0; i < numPackets; i++ {
if i > 0 {
c.receivedPacketMx.Lock()
}
p := c.receivedPackets.PopFront()
hasMorePackets = !c.receivedPackets.Empty()
c.receivedPacketMx.Unlock()
processed, err := c.handleOnePacket(p)
if err != nil {
return false, err
}
if processed {
wasProcessed = true
}
if !hasMorePackets {
break
}
// only process a single packet at a time before handshake completion
if !c.handshakeComplete {
break
}
}
if hasMorePackets {
select {
case c.notifyReceivedPacket <- struct{}{}:
default:
}
}
return wasProcessed, nil
}
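// handleOnePacket handles a single received UDP datagram.
// The datagram may contain multiple coalesced QUIC packets, which are parsed and handled one after the other.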
func (c *Conn) handleOnePacket(rp receivedPacket) (wasProcessed bool, _ error) {
c.sentPacketHandler.ReceivedBytes(rp.Size(), rp.rcvTime)
if wire.IsVersionNegotiationPacket(rp.data) {
c.handleVersionNegotiationPacket(rp)
return false, nil
}
var counter uint8
var lastConnID protocol.ConnectionID
data := rp.data
p := rp
for len(data) > 0 {
if counter > 0 {
p = *(p.Clone())
p.data = data
destConnID, err := wire.ParseConnectionID(p.data, c.srcConnIDLen)
if err != nil {
if c.tracer != nil && c.tracer.DroppedPacket != nil {
c.tracer.DroppedPacket(logging.PacketTypeNotDetermined, protocol.InvalidPacketNumber, protocol.ByteCount(len(data)), logging.PacketDropHeaderParseError)
}
c.logger.Debugf("error parsing packet, couldn't parse connection ID: %s", err)
break
}
if destConnID != lastConnID {
if c.tracer != nil && c.tracer.DroppedPacket != nil {
c.tracer.DroppedPacket(logging.PacketTypeNotDetermined, protocol.InvalidPacketNumber, protocol.ByteCount(len(data)), logging.PacketDropUnknownConnectionID)
}
c.logger.Debugf("coalesced packet has different destination connection ID: %s, expected %s", destConnID, lastConnID)
break
}
}
if wire.IsLongHeaderPacket(p.data[0]) {
hdr, packetData, rest, err := wire.ParsePacket(p.data)
if err != nil {
if c.tracer != nil && c.tracer.DroppedPacket != nil {
dropReason := logging.PacketDropHeaderParseError
if err == wire.ErrUnsupportedVersion {
dropReason = logging.PacketDropUnsupportedVersion
}
c.tracer.DroppedPacket(logging.PacketTypeNotDetermined, protocol.InvalidPacketNumber, protocol.ByteCount(len(data)), dropReason)
}
c.logger.Debugf("error parsing packet: %s", err)
break
}
lastConnID = hdr.DestConnectionID
if hdr.Version != c.version {
if c.tracer != nil && c.tracer.DroppedPacket != nil {
c.tracer.DroppedPacket(logging.PacketTypeFromHeader(hdr), protocol.InvalidPacketNumber, protocol.ByteCount(len(data)), logging.PacketDropUnexpectedVersion)
}
c.logger.Debugf("Dropping packet with version %x. Expected %x.", hdr.Version, c.version)
break
}
if counter > 0 {
p.buffer.Split()
}
counter++
// only log if this is actually a coalesced packet
if c.logger.Debug() && (counter > 1 || len(rest) > 0) {
c.logger.Debugf("Parsed a coalesced packet. Part %d: %d bytes. Remaining: %d bytes.", counter, len(packetData), len(rest))
}
p.data = packetData
processed, err := c.handleLongHeaderPacket(p, hdr)
if err != nil {
return false, err
}
if processed {
wasProcessed = true
}
data = rest
} else {
if counter > 0 {
p.buffer.Split()
}
processed, err := c.handleShortHeaderPacket(p, counter > 0)
if err != nil {
return false, err
}
if processed {
wasProcessed = true
}
break
}
}
p.buffer.MaybeRelease()
return wasProcessed, nil
}
func (c *Conn) handleShortHeaderPacket(p receivedPacket, isCoalesced bool) (wasProcessed bool, _ error) {
var wasQueued bool
defer func() {
// Put back the packet buffer if the packet wasn't queued for later decryption.
if !wasQueued {
p.buffer.Decrement()
}
}()
destConnID, err := wire.ParseConnectionID(p.data, c.srcConnIDLen)
if err != nil {
if c.tracer != nil && c.tracer.DroppedPacket != nil {
c.tracer.DroppedPacket(logging.PacketType1RTT, protocol.InvalidPacketNumber, protocol.ByteCount(len(p.data)), logging.PacketDropHeaderParseError)
}
return false, nil
}
pn, pnLen, keyPhase, data, err := c.unpacker.UnpackShortHeader(p.rcvTime, p.data)
if err != nil {
// Stateless reset packets (see RFC 9000, section 10.3):
// * fill the entire UDP datagram (i.e. they cannot be part of a coalesced packet)
// * are short header packets (first bit is 0)
// * have the QUIC bit set (second bit is 1)
// * are at least 21 bytes long
if !isCoalesced && len(p.data) >= protocol.MinReceivedStatelessResetSize && p.data[0]&0b11000000 == 0b01000000 {
token := protocol.StatelessResetToken(p.data[len(p.data)-16:])
if c.connIDManager.IsActiveStatelessResetToken(token) {
return false, &StatelessResetError{}
}
}
wasQueued, err = c.handleUnpackError(err, p, logging.PacketType1RTT)
return false, err
}
c.largestRcvdAppData = max(c.largestRcvdAppData, pn)
if c.logger.Debug() {
c.logger.Debugf("<- Reading packet %d (%d bytes) for connection %s, 1-RTT", pn, p.Size(), destConnID)
wire.LogShortHeader(c.logger, destConnID, pn, pnLen, keyPhase)
}
if c.receivedPacketHandler.IsPotentiallyDuplicate(pn, protocol.Encryption1RTT) {
c.logger.Debugf("Dropping (potentially) duplicate packet.")
if c.tracer != nil && c.tracer.DroppedPacket != nil {
c.tracer.DroppedPacket(logging.PacketType1RTT, pn, p.Size(), logging.PacketDropDuplicate)
}
return false, nil
}
var log func([]logging.Frame)
if c.tracer != nil && c.tracer.ReceivedShortHeaderPacket != nil {
log = func(frames []logging.Frame) {
c.tracer.ReceivedShortHeaderPacket(
&logging.ShortHeader{
DestConnectionID: destConnID,
PacketNumber: pn,
PacketNumberLen: pnLen,
KeyPhase: keyPhase,
},
p.Size(),
p.ecn,
frames,
)
}
}
isNonProbing, pathChallenge, err := c.handleUnpackedShortHeaderPacket(destConnID, pn, data, p.ecn, p.rcvTime, log)
if err != nil {
return false, err
}
// In RFC 9000, only the client can migrate between paths.
if c.perspective == protocol.PerspectiveClient {
return true, nil
}
if addrsEqual(p.remoteAddr, c.RemoteAddr()) {
return true, nil
}
if c.pathManager == nil {
c.pathManager = newPathManager(
c.connIDManager.GetConnIDForPath,
c.connIDManager.RetireConnIDForPath,
c.logger,
)
}
destConnID, frames, shouldSwitchPath := c.pathManager.HandlePacket(p.remoteAddr, p.rcvTime, pathChallenge, isNonProbing)
if len(frames) > 0 {
probe, buf, err := c.packer.PackPathProbePacket(destConnID, frames, c.version)
if err != nil {
return true, err
}
c.logger.Debugf("sending path probe packet to %s", p.remoteAddr)
c.logShortHeaderPacket(probe.DestConnID, probe.Ack, probe.Frames, probe.StreamFrames, probe.PacketNumber, probe.PacketNumberLen, probe.KeyPhase, protocol.ECNNon, buf.Len(), false)
c.registerPackedShortHeaderPacket(probe, protocol.ECNNon, p.rcvTime)
c.sendQueue.SendProbe(buf, p.remoteAddr)
}
// We only switch paths in response to the highest-numbered non-probing packet,
// see section 9.3 of RFC 9000.
if !shouldSwitchPath || pn != c.largestRcvdAppData {
return true, nil
}
c.pathManager.SwitchToPath(p.remoteAddr)
c.sentPacketHandler.MigratedPath(p.rcvTime, protocol.ByteCount(c.config.InitialPacketSize))
maxPacketSize := protocol.ByteCount(protocol.MaxPacketBufferSize)
if c.peerParams.MaxUDPPayloadSize > 0 && c.peerParams.MaxUDPPayloadSize < maxPacketSize {
maxPacketSize = c.peerParams.MaxUDPPayloadSize
}
c.mtuDiscoverer.Reset(
p.rcvTime,
protocol.ByteCount(c.config.InitialPacketSize),
maxPacketSize,
)
c.conn.ChangeRemoteAddr(p.remoteAddr, p.info)
return true, nil
}
func (c *Conn) handleLongHeaderPacket(p receivedPacket, hdr *wire.Header) (wasProcessed bool, _ error) {
var wasQueued bool
defer func() {
// Put back the packet buffer if the packet wasn't queued for later decryption.
if !wasQueued {
p.buffer.Decrement()
}
}()
if hdr.Type == protocol.PacketTypeRetry {
return c.handleRetryPacket(hdr, p.data, p.rcvTime), nil
}
// The server can change the source connection ID with the first Handshake packet.
// After this, all packets with a different source connection ID have to be ignored.
if c.receivedFirstPacket && hdr.Type == protocol.PacketTypeInitial && hdr.SrcConnectionID != c.handshakeDestConnID {
if c.tracer != nil && c.tracer.DroppedPacket != nil {
c.tracer.DroppedPacket(logging.PacketTypeInitial, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropUnknownConnectionID)
}
c.logger.Debugf("Dropping Initial packet (%d bytes) with unexpected source connection ID: %s (expected %s)", p.Size(), hdr.SrcConnectionID, c.handshakeDestConnID)
return false, nil
}
// drop 0-RTT packets, if we are a client
if c.perspective == protocol.PerspectiveClient && hdr.Type == protocol.PacketType0RTT {
if c.tracer != nil && c.tracer.DroppedPacket != nil {
c.tracer.DroppedPacket(logging.PacketType0RTT, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropUnexpectedPacket)
}
return false, nil
}
packet, err := c.unpacker.UnpackLongHeader(hdr, p.data)
if err != nil {
wasQueued, err = c.handleUnpackError(err, p, logging.PacketTypeFromHeader(hdr))
return false, err
}
if c.logger.Debug() {
c.logger.Debugf("<- Reading packet %d (%d bytes) for connection %s, %s", packet.hdr.PacketNumber, p.Size(), hdr.DestConnectionID, packet.encryptionLevel)
packet.hdr.Log(c.logger)
}
if pn := packet.hdr.PacketNumber; c.receivedPacketHandler.IsPotentiallyDuplicate(pn, packet.encryptionLevel) {
c.logger.Debugf("Dropping (potentially) duplicate packet.")
if c.tracer != nil && c.tracer.DroppedPacket != nil {
c.tracer.DroppedPacket(logging.PacketTypeFromHeader(hdr), pn, p.Size(), logging.PacketDropDuplicate)
}
return false, nil
}
if err := c.handleUnpackedLongHeaderPacket(packet, p.ecn, p.rcvTime, p.Size()); err != nil {
return false, err
}
return true, nil
}
func (c *Conn) handleUnpackError(err error, p receivedPacket, pt logging.PacketType) (wasQueued bool, _ error) {
switch err {
case handshake.ErrKeysDropped:
if c.tracer != nil && c.tracer.DroppedPacket != nil {
c.tracer.DroppedPacket(pt, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropKeyUnavailable)
}
c.logger.Debugf("Dropping %s packet (%d bytes) because we already dropped the keys.", pt, p.Size())
return false, nil
case handshake.ErrKeysNotYetAvailable:
// Sealer for this encryption level not yet available.
// Try again later.
c.tryQueueingUndecryptablePacket(p, pt)
return true, nil
case wire.ErrInvalidReservedBits:
return false, &qerr.TransportError{
ErrorCode: qerr.ProtocolViolation,
ErrorMessage: err.Error(),
}
case handshake.ErrDecryptionFailed:
// This might be a packet injected by an attacker. Drop it.
if c.tracer != nil && c.tracer.DroppedPacket != nil {
c.tracer.DroppedPacket(pt, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropPayloadDecryptError)
}
c.logger.Debugf("Dropping %s packet (%d bytes) that could not be unpacked. Error: %s", pt, p.Size(), err)
return false, nil
default:
var headerErr *headerParseError
if errors.As(err, &headerErr) {
// This might be a packet injected by an attacker. Drop it.
if c.tracer != nil && c.tracer.DroppedPacket != nil {
c.tracer.DroppedPacket(pt, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropHeaderParseError)
}
c.logger.Debugf("Dropping %s packet (%d bytes) for which we couldn't unpack the header. Error: %s", pt, p.Size(), err)
return false, nil
}
// This is an error returned by the AEAD (other than ErrDecryptionFailed).
// For example, a PROTOCOL_VIOLATION due to key updates.
return false, err
}
}
func (c *Conn) handleRetryPacket(hdr *wire.Header, data []byte, rcvTime time.Time) bool /* was this a valid Retry */ {
if c.perspective == protocol.PerspectiveServer {
if c.tracer != nil && c.tracer.DroppedPacket != nil {
c.tracer.DroppedPacket(logging.PacketTypeRetry, protocol.InvalidPacketNumber, protocol.ByteCount(len(data)), logging.PacketDropUnexpectedPacket)
}
c.logger.Debugf("Ignoring Retry.")
return false
}
if c.receivedFirstPacket {
if c.tracer != nil && c.tracer.DroppedPacket != nil {
c.tracer.DroppedPacket(logging.PacketTypeRetry, protocol.InvalidPacketNumber, protocol.ByteCount(len(data)), logging.PacketDropUnexpectedPacket)
}
c.logger.Debugf("Ignoring Retry, since we already received a packet.")
return false
}
destConnID := c.connIDManager.Get()
if hdr.SrcConnectionID == destConnID {
if c.tracer != nil && c.tracer.DroppedPacket != nil {
c.tracer.DroppedPacket(logging.PacketTypeRetry, protocol.InvalidPacketNumber, protocol.ByteCount(len(data)), logging.PacketDropUnexpectedPacket)
}
c.logger.Debugf("Ignoring Retry, since the server didn't change the Source Connection ID.")
return false
}
// If a token is already set, this means that we already received a Retry from the server.
// Ignore this Retry packet.
if c.receivedRetry {
c.logger.Debugf("Ignoring Retry, since a Retry was already received.")
return false
}
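// Verify the Retry integrity tag (RFC 9001, section 5.8). It is computed over the
// original destination connection ID and the Retry packet without the 16-byte tag.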
tag := handshake.GetRetryIntegrityTag(data[:len(data)-16], destConnID, hdr.Version)
if !bytes.Equal(data[len(data)-16:], tag[:]) {
if c.tracer != nil && c.tracer.DroppedPacket != nil {
c.tracer.DroppedPacket(logging.PacketTypeRetry, protocol.InvalidPacketNumber, protocol.ByteCount(len(data)), logging.PacketDropPayloadDecryptError)
}
c.logger.Debugf("Ignoring spoofed Retry. Integrity Tag doesn't match.")
return false
}
newDestConnID := hdr.SrcConnectionID
c.receivedRetry = true
c.sentPacketHandler.ResetForRetry(rcvTime)
c.handshakeDestConnID = newDestConnID
c.retrySrcConnID = &newDestConnID
c.cryptoStreamHandler.ChangeConnectionID(newDestConnID)
c.packer.SetToken(hdr.Token)
c.connIDManager.ChangeInitialConnID(newDestConnID)
if c.logger.Debug() {
c.logger.Debugf("<- Received Retry:")
(&wire.ExtendedHeader{Header: *hdr}).Log(c.logger)
c.logger.Debugf("Switching destination connection ID to: %s", hdr.SrcConnectionID)
}
if c.tracer != nil && c.tracer.ReceivedRetry != nil {
c.tracer.ReceivedRetry(hdr)
}
c.scheduleSending()
return true
}
func (c *Conn) handleVersionNegotiationPacket(p receivedPacket) {
if c.perspective == protocol.PerspectiveServer || // servers never receive version negotiation packets
c.receivedFirstPacket || c.versionNegotiated { // ignore delayed / duplicated version negotiation packets
if c.tracer != nil && c.tracer.DroppedPacket != nil {
c.tracer.DroppedPacket(logging.PacketTypeVersionNegotiation, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropUnexpectedPacket)
}
return
}
src, dest, supportedVersions, err := wire.ParseVersionNegotiationPacket(p.data)
if err != nil {
if c.tracer != nil && c.tracer.DroppedPacket != nil {
c.tracer.DroppedPacket(logging.PacketTypeVersionNegotiation, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropHeaderParseError)
}
c.logger.Debugf("Error parsing Version Negotiation packet: %s", err)
return
}
if slices.Contains(supportedVersions, c.version) {
if c.tracer != nil && c.tracer.DroppedPacket != nil {
c.tracer.DroppedPacket(logging.PacketTypeVersionNegotiation, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropUnexpectedVersion)
}
// The Version Negotiation packet contains the version that we offered.
// This might be a packet sent by an attacker, or it was corrupted.
return
}
c.logger.Infof("Received a Version Negotiation packet. Supported Versions: %s", supportedVersions)
if c.tracer != nil && c.tracer.ReceivedVersionNegotiationPacket != nil {
c.tracer.ReceivedVersionNegotiationPacket(dest, src, supportedVersions)
}
newVersion, ok := protocol.ChooseSupportedVersion(c.config.Versions, supportedVersions)
if !ok {
c.destroyImpl(&VersionNegotiationError{
Ours: c.config.Versions,
Theirs: supportedVersions,
})
c.logger.Infof("No compatible QUIC version found.")
return
}
if c.tracer != nil && c.tracer.NegotiatedVersion != nil {
c.tracer.NegotiatedVersion(newVersion, c.config.Versions, supportedVersions)
}
c.logger.Infof("Switching to QUIC version %s.", newVersion)
nextPN, _ := c.sentPacketHandler.PeekPacketNumber(protocol.EncryptionInitial)
c.destroyImpl(&errCloseForRecreating{
nextPacketNumber: nextPN,
nextVersion: newVersion,
})
}
func (c *Conn) handleUnpackedLongHeaderPacket(
packet *unpackedPacket,
ecn protocol.ECN,
rcvTime time.Time,
packetSize protocol.ByteCount, // only for logging
) error {
if !c.receivedFirstPacket {
c.receivedFirstPacket = true
if !c.versionNegotiated && c.tracer != nil && c.tracer.NegotiatedVersion != nil {
var clientVersions, serverVersions []protocol.Version
switch c.perspective {
case protocol.PerspectiveClient:
clientVersions = c.config.Versions
case protocol.PerspectiveServer:
serverVersions = c.config.Versions
}
c.tracer.NegotiatedVersion(c.version, clientVersions, serverVersions)
}
// The server can change the source connection ID with the first Handshake packet.
if c.perspective == protocol.PerspectiveClient && packet.hdr.SrcConnectionID != c.handshakeDestConnID {
cid := packet.hdr.SrcConnectionID
c.logger.Debugf("Received first packet. Switching destination connection ID to: %s", cid)
c.handshakeDestConnID = cid
c.connIDManager.ChangeInitialConnID(cid)
}
// We create the connection as soon as we receive the first packet from the client.
// We do that before authenticating the packet.
// That means that if the source connection ID was corrupted,
// we might have created a connection with an incorrect source connection ID.
// Once we authenticate the first packet, we need to update it.
if c.perspective == protocol.PerspectiveServer {
if packet.hdr.SrcConnectionID != c.handshakeDestConnID {
c.handshakeDestConnID = packet.hdr.SrcConnectionID
c.connIDManager.ChangeInitialConnID(packet.hdr.SrcConnectionID)
}
if c.tracer != nil && c.tracer.StartedConnection != nil {
c.tracer.StartedConnection(
c.conn.LocalAddr(),
c.conn.RemoteAddr(),
packet.hdr.SrcConnectionID,
packet.hdr.DestConnectionID,
)
}
}
}
if c.perspective == protocol.PerspectiveServer && packet.encryptionLevel == protocol.EncryptionHandshake &&
!c.droppedInitialKeys {
// On the server side, Initial keys are dropped as soon as the first Handshake packet is received.
// See Section 4.9.1 of RFC 9001.
if err := c.dropEncryptionLevel(protocol.EncryptionInitial, rcvTime); err != nil {
return err
}
}
c.lastPacketReceivedTime = rcvTime
c.firstAckElicitingPacketAfterIdleSentTime = time.Time{}
c.keepAlivePingSent = false
if packet.hdr.Type == protocol.PacketType0RTT {
c.largestRcvdAppData = max(c.largestRcvdAppData, packet.hdr.PacketNumber)
}
var log func([]logging.Frame)
if c.tracer != nil && c.tracer.ReceivedLongHeaderPacket != nil {
log = func(frames []logging.Frame) {
c.tracer.ReceivedLongHeaderPacket(packet.hdr, packetSize, ecn, frames)
}
}
isAckEliciting, _, _, err := c.handleFrames(packet.data, packet.hdr.DestConnectionID, packet.encryptionLevel, log, rcvTime)
if err != nil {
return err
}
return c.receivedPacketHandler.ReceivedPacket(packet.hdr.PacketNumber, ecn, packet.encryptionLevel, rcvTime, isAckEliciting)
}
func (c *Conn) handleUnpackedShortHeaderPacket(
destConnID protocol.ConnectionID,
pn protocol.PacketNumber,
data []byte,
ecn protocol.ECN,
rcvTime time.Time,
log func([]logging.Frame),
) (isNonProbing bool, pathChallenge *wire.PathChallengeFrame, _ error) {
c.lastPacketReceivedTime = rcvTime
c.firstAckElicitingPacketAfterIdleSentTime = time.Time{}
c.keepAlivePingSent = false
isAckEliciting, isNonProbing, pathChallenge, err := c.handleFrames(data, destConnID, protocol.Encryption1RTT, log, rcvTime)
if err != nil {
return false, nil, err
}
if err := c.receivedPacketHandler.ReceivedPacket(pn, ecn, protocol.Encryption1RTT, rcvTime, isAckEliciting); err != nil {
return false, nil, err
}
return isNonProbing, pathChallenge, nil
}
// handleFrames parses the frames, one after the other, and handles them.
// It returns the last PATH_CHALLENGE frame contained in the packet, if any.
func (c *Conn) handleFrames(
data []byte,
destConnID protocol.ConnectionID,
encLevel protocol.EncryptionLevel,
log func([]logging.Frame),
rcvTime time.Time,
) (isAckEliciting, isNonProbing bool, pathChallenge *wire.PathChallengeFrame, _ error) {
// Only used for tracing.
// If we're not tracing, this slice will always remain empty.
var frames []logging.Frame
if log != nil {
frames = make([]logging.Frame, 0, 4)
}
handshakeWasComplete := c.handshakeComplete
var handleErr error
var skipHandling bool
for len(data) > 0 {
frameType, l, err := c.frameParser.ParseType(data, encLevel)
if err != nil {
// The frame parser skips over PADDING frames, and returns io.EOF if the PADDING
// frames were the last frames in this packet.
if err == io.EOF {
break
}
return false, false, nil, err
}
data = data[l:]
if ackhandler.IsFrameTypeAckEliciting(frameType) {
isAckEliciting = true
}
if !wire.IsProbingFrameType(frameType) {
isNonProbing = true
}
// We're inlining common cases, to avoid using interfaces
// Fast path: STREAM, DATAGRAM and ACK
if frameType.IsStreamFrameType() {
streamFrame, l, err := c.frameParser.ParseStreamFrame(frameType, data, c.version)
if err != nil {
return false, false, nil, err
}
data = data[l:]
if log != nil {
frames = append(frames, toLoggingFrame(streamFrame))
}
// an error occurred handling a previous frame, don't handle the current frame
if skipHandling {
continue
}
handleErr = c.streamsMap.HandleStreamFrame(streamFrame, rcvTime)
} else if frameType.IsAckFrameType() {
ackFrame, l, err := c.frameParser.ParseAckFrame(frameType, data, encLevel, c.version)
if err != nil {
return false, false, nil, err
}
data = data[l:]
if log != nil {
frames = append(frames, toLoggingFrame(ackFrame))
}
// an error occurred handling a previous frame, don't handle the current frame
if skipHandling {
continue
}
handleErr = c.handleAckFrame(ackFrame, encLevel, rcvTime)
} else if frameType.IsDatagramFrameType() {
datagramFrame, l, err := c.frameParser.ParseDatagramFrame(frameType, data, c.version)
if err != nil {
return false, false, nil, err
}
data = data[l:]
if log != nil {
frames = append(frames, toLoggingFrame(datagramFrame))
}
// an error occurred handling a previous frame, don't handle the current frame
if skipHandling {
continue
}
handleErr = c.handleDatagramFrame(datagramFrame)
} else {
frame, l, err := c.frameParser.ParseLessCommonFrame(frameType, data, c.version)
if err != nil {
return false, false, nil, err
}
data = data[l:]
if log != nil {
frames = append(frames, toLoggingFrame(frame))
}
// an error occurred handling a previous frame, don't handle the current frame
if skipHandling {
continue
}
pc, err := c.handleFrame(frame, encLevel, destConnID, rcvTime)
if pc != nil {
pathChallenge = pc
}
handleErr = err
}
if handleErr != nil {
// if we're logging, we need to keep parsing (but not handling) all frames
skipHandling = true
if log == nil {
return false, false, nil, handleErr
}
}
}
if log != nil {
log(frames)
if handleErr != nil {
return false, false, nil, handleErr
}
}
// Handle completion of the handshake after processing all the frames.
// This ensures that we correctly handle the following case on the server side:
// We receive a Handshake packet that contains the CRYPTO frame that allows us to complete the handshake,
// and an ACK serialized after that CRYPTO frame. In this case, we still want to process the ACK frame.
if !handshakeWasComplete && c.handshakeComplete {
if err := c.handleHandshakeComplete(rcvTime); err != nil {
return false, false, nil, err
}
}
return
}
func (c *Conn) handleFrame(
f wire.Frame,
encLevel protocol.EncryptionLevel,
destConnID protocol.ConnectionID,
rcvTime time.Time,
) (pathChallenge *wire.PathChallengeFrame, _ error) {
var err error
wire.LogFrame(c.logger, f, false)
switch frame := f.(type) {
case *wire.CryptoFrame:
err = c.handleCryptoFrame(frame, encLevel, rcvTime)
case *wire.ConnectionCloseFrame:
err = c.handleConnectionCloseFrame(frame)
case *wire.ResetStreamFrame:
err = c.streamsMap.HandleResetStreamFrame(frame, rcvTime)
case *wire.MaxDataFrame:
c.connFlowController.UpdateSendWindow(frame.MaximumData)
case *wire.MaxStreamDataFrame:
err = c.streamsMap.HandleMaxStreamDataFrame(frame)
case *wire.MaxStreamsFrame:
c.streamsMap.HandleMaxStreamsFrame(frame)
case *wire.DataBlockedFrame:
case *wire.StreamDataBlockedFrame:
err = c.streamsMap.HandleStreamDataBlockedFrame(frame)
case *wire.StreamsBlockedFrame:
case *wire.StopSendingFrame:
err = c.streamsMap.HandleStopSendingFrame(frame)
case *wire.PingFrame:
case *wire.PathChallengeFrame:
c.handlePathChallengeFrame(frame)
pathChallenge = frame
case *wire.PathResponseFrame:
err = c.handlePathResponseFrame(frame)
case *wire.NewTokenFrame:
err = c.handleNewTokenFrame(frame)
case *wire.NewConnectionIDFrame:
err = c.connIDManager.Add(frame)
case *wire.RetireConnectionIDFrame:
err = c.connIDGenerator.Retire(frame.SequenceNumber, destConnID, rcvTime.Add(3*c.rttStats.PTO(false)))
case *wire.HandshakeDoneFrame:
err = c.handleHandshakeDoneFrame(rcvTime)
default:
err = fmt.Errorf("unexpected frame type: %s", reflect.ValueOf(&frame).Elem().Type().Name())
}
return pathChallenge, err
}
// handlePacket is called by the server with a new packet
func (c *Conn) handlePacket(p receivedPacket) {
c.receivedPacketMx.Lock()
// Discard packets once the number of queued packets reaches
// protocol.MaxConnUnprocessedPackets
if c.receivedPackets.Len() >= protocol.MaxConnUnprocessedPackets {
if c.tracer != nil && c.tracer.DroppedPacket != nil {
c.tracer.DroppedPacket(logging.PacketTypeNotDetermined, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropDOSPrevention)
}
c.receivedPacketMx.Unlock()
return
}
c.receivedPackets.PushBack(p)
c.receivedPacketMx.Unlock()
select {
case c.notifyReceivedPacket <- struct{}{}:
default:
}
}
func (c *Conn) handleConnectionCloseFrame(frame *wire.ConnectionCloseFrame) error {
if frame.IsApplicationError {
return &qerr.ApplicationError{
Remote: true,
ErrorCode: qerr.ApplicationErrorCode(frame.ErrorCode),
ErrorMessage: frame.ReasonPhrase,
}
}
return &qerr.TransportError{
Remote: true,
ErrorCode: qerr.TransportErrorCode(frame.ErrorCode),
FrameType: frame.FrameType,
ErrorMessage: frame.ReasonPhrase,
}
}
func (c *Conn) handleCryptoFrame(frame *wire.CryptoFrame, encLevel protocol.EncryptionLevel, rcvTime time.Time) error {
if err := c.cryptoStreamManager.HandleCryptoFrame(frame, encLevel); err != nil {
return err
}
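// Pass all complete crypto messages to the TLS stack, one at a time.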
for {
data := c.cryptoStreamManager.GetCryptoData(encLevel)
if data == nil {
break
}
if err := c.cryptoStreamHandler.HandleMessage(data, encLevel); err != nil {
return err
}
}
return c.handleHandshakeEvents(rcvTime)
}
func (c *Conn) handleHandshakeEvents(now time.Time) error {
for {
ev := c.cryptoStreamHandler.NextEvent()
var err error
switch ev.Kind {
case handshake.EventNoEvent:
return nil
case handshake.EventHandshakeComplete:
// Don't call handleHandshakeComplete yet.
// It's advantageous to process ACK frames that might be serialized after the CRYPTO frame first.
c.handshakeComplete = true
case handshake.EventReceivedTransportParameters:
err = c.handleTransportParameters(ev.TransportParameters)
case handshake.EventRestoredTransportParameters:
c.restoreTransportParameters(ev.TransportParameters)
close(c.earlyConnReadyChan)
case handshake.EventReceivedReadKeys:
// queue all previously undecryptable packets
c.undecryptablePacketsToProcess = append(c.undecryptablePacketsToProcess, c.undecryptablePackets...)
c.undecryptablePackets = nil
case handshake.EventDiscard0RTTKeys:
err = c.dropEncryptionLevel(protocol.Encryption0RTT, now)
case handshake.EventWriteInitialData:
_, err = c.initialStream.Write(ev.Data)
case handshake.EventWriteHandshakeData:
_, err = c.handshakeStream.Write(ev.Data)
}
if err != nil {
return err
}
}
}
func (c *Conn) handlePathChallengeFrame(f *wire.PathChallengeFrame) {
if c.perspective == protocol.PerspectiveClient {
c.queueControlFrame(&wire.PathResponseFrame{Data: f.Data})
}
}
func (c *Conn) handlePathResponseFrame(f *wire.PathResponseFrame) error {
switch c.perspective {
case protocol.PerspectiveClient:
return c.handlePathResponseFrameClient(f)
case protocol.PerspectiveServer:
return c.handlePathResponseFrameServer(f)
default:
panic("unreachable")
}
}
func (c *Conn) handlePathResponseFrameClient(f *wire.PathResponseFrame) error {
pm := c.pathManagerOutgoing.Load()
if pm == nil {
return &qerr.TransportError{
ErrorCode: qerr.ProtocolViolation,
ErrorMessage: "unexpected PATH_RESPONSE frame",
}
}
pm.HandlePathResponseFrame(f)
return nil
}
func (c *Conn) handlePathResponseFrameServer(f *wire.PathResponseFrame) error {
if c.pathManager == nil {
// since we didn't send PATH_CHALLENGEs yet, we don't expect PATH_RESPONSEs
return &qerr.TransportError{
ErrorCode: qerr.ProtocolViolation,
ErrorMessage: "unexpected PATH_RESPONSE frame",
}
}
c.pathManager.HandlePathResponseFrame(f)
return nil
}
func (c *Conn) handleNewTokenFrame(frame *wire.NewTokenFrame) error {
if c.perspective == protocol.PerspectiveServer {
return &qerr.TransportError{
ErrorCode: qerr.ProtocolViolation,
ErrorMessage: "received NEW_TOKEN frame from the client",
}
}
if c.config.TokenStore != nil {
c.config.TokenStore.Put(c.tokenStoreKey, &ClientToken{data: frame.Token, rtt: c.rttStats.SmoothedRTT()})
}
return nil
}
func (c *Conn) handleHandshakeDoneFrame(rcvTime time.Time) error {
if c.perspective == protocol.PerspectiveServer {
return &qerr.TransportError{
ErrorCode: qerr.ProtocolViolation,
ErrorMessage: "received a HANDSHAKE_DONE frame",
}
}
if !c.handshakeConfirmed {
return c.handleHandshakeConfirmed(rcvTime)
}
return nil
}
func (c *Conn) handleAckFrame(frame *wire.AckFrame, encLevel protocol.EncryptionLevel, rcvTime time.Time) error {
acked1RTTPacket, err := c.sentPacketHandler.ReceivedAck(frame, encLevel, c.lastPacketReceivedTime)
if err != nil {
return err
}
if !acked1RTTPacket {
return nil
}
// On the client side: If the packet acknowledged a 1-RTT packet, this confirms the handshake.
// This is only possible if the ACK was sent in a 1-RTT packet.
// This is an optimization over simply waiting for a HANDSHAKE_DONE frame, see section 4.1.2 of RFC 9001.
if c.perspective == protocol.PerspectiveClient && !c.handshakeConfirmed {
if err := c.handleHandshakeConfirmed(rcvTime); err != nil {
return err
}
}
// If one of the acknowledged packets was a Path MTU probe packet, this might have increased the Path MTU estimate.
if c.mtuDiscoverer != nil {
if mtu := c.mtuDiscoverer.CurrentSize(); mtu > protocol.ByteCount(c.currentMTUEstimate.Load()) {
c.currentMTUEstimate.Store(uint32(mtu))
c.sentPacketHandler.SetMaxDatagramSize(mtu)
}
}
return c.cryptoStreamHandler.SetLargest1RTTAcked(frame.LargestAcked())
}
func (c *Conn) handleDatagramFrame(f *wire.DatagramFrame) error {
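// We advertise wire.MaxDatagramSize as max_datagram_frame_size, so larger frames are a protocol violation.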
if f.Length(c.version) > wire.MaxDatagramSize {
return &qerr.TransportError{
ErrorCode: qerr.ProtocolViolation,
ErrorMessage: "DATAGRAM frame too large",
}
}
c.datagramQueue.HandleDatagramFrame(f)
return nil
}
func (c *Conn) setCloseError(e *closeError) {
c.closeErr.CompareAndSwap(nil, e)
select {
case c.closeChan <- struct{}{}:
default:
}
}
// closeLocal closes the connection and sends a CONNECTION_CLOSE containing the error
func (c *Conn) closeLocal(e error) {
c.setCloseError(&closeError{err: e, immediate: false})
}
// destroy closes the connection without sending the error on the wire
func (c *Conn) destroy(e error) {
c.destroyImpl(e)
<-c.ctx.Done()
}
func (c *Conn) destroyImpl(e error) {
c.setCloseError(&closeError{err: e, immediate: true})
}
// CloseWithError closes the connection with an error.
// The error string will be sent to the peer.
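// CloseWithError blocks until the connection's run loop has terminated.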
func (c *Conn) CloseWithError(code ApplicationErrorCode, desc string) error {
c.closeLocal(&qerr.ApplicationError{
ErrorCode: code,
ErrorMessage: desc,
})
<-c.ctx.Done()
return nil
}
func (c *Conn) closeWithTransportError(code TransportErrorCode) {
c.closeLocal(&qerr.TransportError{ErrorCode: code})
<-c.ctx.Done()
}
func (c *Conn) handleCloseError(closeErr *closeError) {
if closeErr.immediate {
if nerr, ok := closeErr.err.(net.Error); ok && nerr.Timeout() {
c.logger.Errorf("Destroying connection: %s", closeErr.err)
} else {
c.logger.Errorf("Destroying connection with error: %s", closeErr.err)
}
} else {
if closeErr.err == nil {
c.logger.Infof("Closing connection.")
} else {
c.logger.Errorf("Closing connection with error: %s", closeErr.err)
}
}
e := closeErr.err
if e == nil {
e = &qerr.ApplicationError{}
} else {
defer func() { closeErr.err = e }()
}
var (
statelessResetErr *StatelessResetError
versionNegotiationErr *VersionNegotiationError
recreateErr *errCloseForRecreating
applicationErr *ApplicationError
transportErr *TransportError
)
var isRemoteClose bool
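// Determine whether the connection was closed by the peer.
// For remote closes, no CONNECTION_CLOSE packet needs to be sent.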
switch {
case errors.Is(e, qerr.ErrIdleTimeout),
errors.Is(e, qerr.ErrHandshakeTimeout),
errors.As(e, &statelessResetErr),
errors.As(e, &versionNegotiationErr),
errors.As(e, &recreateErr):
case errors.As(e, &applicationErr):
isRemoteClose = applicationErr.Remote
case errors.As(e, &transportErr):
isRemoteClose = transportErr.Remote
case closeErr.immediate:
e = closeErr.err
default:
e = &qerr.TransportError{
ErrorCode: qerr.InternalError,
ErrorMessage: e.Error(),
}
}
c.streamsMap.CloseWithError(e)
if c.datagramQueue != nil {
c.datagramQueue.CloseWithError(e)
}
// In rare instances, the connection ID manager might switch to a new connection ID
// when sending the CONNECTION_CLOSE frame.
// The connection ID manager removes the active stateless reset token from the packet
// handler map when it is closed, so we need to make sure that this happens last.
defer c.connIDManager.Close()
if c.tracer != nil && c.tracer.ClosedConnection != nil && !errors.As(e, &recreateErr) {
c.tracer.ClosedConnection(e)
}
// If this is a remote close we're done here
if isRemoteClose {
c.connIDGenerator.ReplaceWithClosed(nil, 3*c.rttStats.PTO(false))
return
}
if closeErr.immediate {
c.connIDGenerator.RemoveAll()
return
}
// Don't send out any CONNECTION_CLOSE if this is an error that occurred
// before we even sent out the first packet.
if c.perspective == protocol.PerspectiveClient && !c.sentFirstPacket {
c.connIDGenerator.RemoveAll()
return
}
connClosePacket, err := c.sendConnectionClose(e)
if err != nil {
c.logger.Debugf("Error sending CONNECTION_CLOSE: %s", err)
}
c.connIDGenerator.ReplaceWithClosed(connClosePacket, 3*c.rttStats.PTO(false))
}
func (c *Conn) dropEncryptionLevel(encLevel protocol.EncryptionLevel, now time.Time) error {
if c.tracer != nil && c.tracer.DroppedEncryptionLevel != nil {
c.tracer.DroppedEncryptionLevel(encLevel)
}
c.sentPacketHandler.DropPackets(encLevel, now)
c.receivedPacketHandler.DropPackets(encLevel)
//nolint:exhaustive // only Initial and 0-RTT need special treatment
switch encLevel {
case protocol.EncryptionInitial:
c.droppedInitialKeys = true
c.cryptoStreamHandler.DiscardInitialKeys()
case protocol.Encryption0RTT:
c.streamsMap.ResetFor0RTT()
c.framer.Handle0RTTRejection()
return c.connFlowController.Reset()
}
return c.cryptoStreamManager.Drop(encLevel)
}
// restoreTransportParameters is called on the client when restoring transport parameters saved for 0-RTT
func (c *Conn) restoreTransportParameters(params *wire.TransportParameters) {
if c.logger.Debug() {
c.logger.Debugf("Restoring Transport Parameters: %s", params)
}
c.peerParams = params
c.connIDGenerator.SetMaxActiveConnIDs(params.ActiveConnectionIDLimit)
c.connFlowController.UpdateSendWindow(params.InitialMaxData)
c.streamsMap.HandleTransportParameters(params)
c.connStateMutex.Lock()
c.connState.SupportsDatagrams = c.supportsDatagrams()
c.connStateMutex.Unlock()
}
func (c *Conn) handleTransportParameters(params *wire.TransportParameters) error {
if c.tracer != nil && c.tracer.ReceivedTransportParameters != nil {
c.tracer.ReceivedTransportParameters(params)
}
if err := c.checkTransportParameters(params); err != nil {
return &qerr.TransportError{
ErrorCode: qerr.TransportParameterError,
ErrorMessage: err.Error(),
}
}
if c.perspective == protocol.PerspectiveClient && c.peerParams != nil && c.ConnectionState().Used0RTT && !params.ValidForUpdate(c.peerParams) {
return &qerr.TransportError{
ErrorCode: qerr.ProtocolViolation,
ErrorMessage: "server sent reduced limits after accepting 0-RTT data",
}
}
c.peerParams = params
// On the client side we have to wait for handshake completion.
// During a 0-RTT connection, we are only allowed to use the new transport parameters for 1-RTT packets.
if c.perspective == protocol.PerspectiveServer {
c.applyTransportParameters()
// On the server side, the early connection is ready as soon as we processed
// the client's transport parameters.
close(c.earlyConnReadyChan)
}
c.connStateMutex.Lock()
c.connState.SupportsDatagrams = c.supportsDatagrams()
c.connStateMutex.Unlock()
return nil
}
func (c *Conn) checkTransportParameters(params *wire.TransportParameters) error {
if c.logger.Debug() {
c.logger.Debugf("Processed Transport Parameters: %s", params)
}
// check the initial_source_connection_id
if params.InitialSourceConnectionID != c.handshakeDestConnID {
return fmt.Errorf("expected initial_source_connection_id to equal %s, is %s", c.handshakeDestConnID, params.InitialSourceConnectionID)
}
if c.perspective == protocol.PerspectiveServer {
return nil
}
// check the original_destination_connection_id
if params.OriginalDestinationConnectionID != c.origDestConnID {
return fmt.Errorf("expected original_destination_connection_id to equal %s, is %s", c.origDestConnID, params.OriginalDestinationConnectionID)
}
if c.retrySrcConnID != nil { // a Retry was performed
if params.RetrySourceConnectionID == nil {
return errors.New("missing retry_source_connection_id")
}
if *params.RetrySourceConnectionID != *c.retrySrcConnID {
return fmt.Errorf("expected retry_source_connection_id to equal %s, is %s", c.retrySrcConnID, *params.RetrySourceConnectionID)
}
} else if params.RetrySourceConnectionID != nil {
return errors.New("received retry_source_connection_id, although no Retry was performed")
}
return nil
}
func (c *Conn) applyTransportParameters() {
params := c.peerParams
// Our local idle timeout will always be > 0.
c.idleTimeout = c.config.MaxIdleTimeout
// If the peer advertised an idle timeout, take the minimum of the values.
if params.MaxIdleTimeout > 0 {
c.idleTimeout = min(c.idleTimeout, params.MaxIdleTimeout)
}
c.keepAliveInterval = min(c.config.KeepAlivePeriod, c.idleTimeout/2)
c.streamsMap.HandleTransportParameters(params)
c.frameParser.SetAckDelayExponent(params.AckDelayExponent)
c.connFlowController.UpdateSendWindow(params.InitialMaxData)
c.rttStats.SetMaxAckDelay(params.MaxAckDelay)
c.connIDGenerator.SetMaxActiveConnIDs(params.ActiveConnectionIDLimit)
if params.StatelessResetToken != nil {
c.connIDManager.SetStatelessResetToken(*params.StatelessResetToken)
}
// We don't support connection migration yet, so we don't have any use for the preferred_address.
if params.PreferredAddress != nil {
// Retire the connection ID.
c.connIDManager.AddFromPreferredAddress(params.PreferredAddress.ConnectionID, params.PreferredAddress.StatelessResetToken)
}
maxPacketSize := protocol.ByteCount(protocol.MaxPacketBufferSize)
if params.MaxUDPPayloadSize > 0 && params.MaxUDPPayloadSize < maxPacketSize {
maxPacketSize = params.MaxUDPPayloadSize
}
c.mtuDiscoverer = newMTUDiscoverer(
c.rttStats,
protocol.ByteCount(c.config.InitialPacketSize),
maxPacketSize,
c.tracer,
)
}
func (c *Conn) triggerSending(now time.Time) error {
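// Clear the pacing deadline. It is set again below if we're pacing limited.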
c.pacingDeadline = time.Time{}
sendMode := c.sentPacketHandler.SendMode(now)
switch sendMode {
case ackhandler.SendAny:
return c.sendPackets(now)
case ackhandler.SendNone:
return nil
case ackhandler.SendPacingLimited:
deadline := c.sentPacketHandler.TimeUntilSend()
if deadline.IsZero() {
deadline = deadlineSendImmediately
}
c.pacingDeadline = deadline
// Allow sending an ACK when we're pacing limited.
// This makes sure that a peer that is mostly receiving data (and thus has an inaccurate cwnd estimate)
// sends enough ACKs to allow its peer to utilize the bandwidth.
fallthrough
case ackhandler.SendAck:
// We can send at most a single ACK-only packet.
// There will only be a new ACK after receiving new packets.
// SendAck is only returned when we're congestion limited, so we don't need to set the pacing timer.
return c.maybeSendAckOnlyPacket(now)
case ackhandler.SendPTOInitial, ackhandler.SendPTOHandshake, ackhandler.SendPTOAppData:
if err := c.sendProbePacket(sendMode, now); err != nil {
return err
}
if c.sendQueue.WouldBlock() {
c.scheduleSending()
return nil
}
return c.triggerSending(now)
default:
return fmt.Errorf("BUG: invalid send mode %d", sendMode)
}
}
func (c *Conn) sendPackets(now time.Time) error {
if c.perspective == protocol.PerspectiveClient && c.handshakeConfirmed {
if pm := c.pathManagerOutgoing.Load(); pm != nil {
connID, frame, tr, ok := pm.NextPathToProbe()
if ok {
probe, buf, err := c.packer.PackPathProbePacket(connID, []ackhandler.Frame{frame}, c.version)
if err != nil {
return err
}
c.logger.Debugf("sending path probe packet from %s", c.LocalAddr())
c.logShortHeaderPacket(probe.DestConnID, probe.Ack, probe.Frames, probe.StreamFrames, probe.PacketNumber, probe.PacketNumberLen, probe.KeyPhase, protocol.ECNNon, buf.Len(), false)
c.registerPackedShortHeaderPacket(probe, protocol.ECNNon, now)
tr.WriteTo(buf.Data, c.conn.RemoteAddr())
// There's (likely) more data to send. Loop around again.
c.scheduleSending()
return nil
}
}
}
// Path MTU Discovery
// Can't use GSO, since we need to send a single packet that's larger than our current maximum size.
// Performance-wise, this doesn't matter, since we only send a very small (<10) number of
// MTU probe packets per connection.
if c.handshakeConfirmed && c.mtuDiscoverer != nil && c.mtuDiscoverer.ShouldSendProbe(now) {
ping, size := c.mtuDiscoverer.GetPing(now)
p, buf, err := c.packer.PackMTUProbePacket(ping, size, c.version)
if err != nil {
return err
}
ecn := c.sentPacketHandler.ECNMode(true)
c.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, ecn, buf.Len(), false)
c.registerPackedShortHeaderPacket(p, ecn, now)
c.sendQueue.Send(buf, 0, ecn)
// There's (likely) more data to send. Loop around again.
c.scheduleSending()
return nil
}
if offset := c.connFlowController.GetWindowUpdate(now); offset > 0 {
c.framer.QueueControlFrame(&wire.MaxDataFrame{MaximumData: offset})
}
if cf := c.cryptoStreamManager.GetPostHandshakeData(protocol.MaxPostHandshakeCryptoFrameSize); cf != nil {
c.queueControlFrame(cf)
}
if !c.handshakeConfirmed {
packet, err := c.packer.PackCoalescedPacket(false, c.maxPacketSize(), now, c.version)
if err != nil || packet == nil {
return err
}
c.sentFirstPacket = true
if err := c.sendPackedCoalescedPacket(packet, c.sentPacketHandler.ECNMode(packet.IsOnlyShortHeaderPacket()), now); err != nil {
return err
}
//nolint:exhaustive // only need to handle pacing-related events here
switch c.sentPacketHandler.SendMode(now) {
case ackhandler.SendPacingLimited:
c.resetPacingDeadline()
case ackhandler.SendAny:
c.pacingDeadline = deadlineSendImmediately
}
return nil
}
if c.conn.capabilities().GSO {
return c.sendPacketsWithGSO(now)
}
return c.sendPacketsWithoutGSO(now)
}
func (c *Conn) sendPacketsWithoutGSO(now time.Time) error {
for {
buf := getPacketBuffer()
ecn := c.sentPacketHandler.ECNMode(true)
if _, err := c.appendOneShortHeaderPacket(buf, c.maxPacketSize(), ecn, now); err != nil {
if err == errNothingToPack {
buf.Release()
return nil
}
return err
}
c.sendQueue.Send(buf, 0, ecn)
if c.sendQueue.WouldBlock() {
return nil
}
sendMode := c.sentPacketHandler.SendMode(now)
if sendMode == ackhandler.SendPacingLimited {
c.resetPacingDeadline()
return nil
}
if sendMode != ackhandler.SendAny {
return nil
}
// Prioritize receiving packets over sending out more packets.
c.receivedPacketMx.Lock()
hasPackets := !c.receivedPackets.Empty()
c.receivedPacketMx.Unlock()
if hasPackets {
c.pacingDeadline = deadlineSendImmediately
return nil
}
}
}
func (c *Conn) sendPacketsWithGSO(now time.Time) error {
buf := getLargePacketBuffer()
maxSize := c.maxPacketSize()
ecn := c.sentPacketHandler.ECNMode(true)
for {
var dontSendMore bool
size, err := c.appendOneShortHeaderPacket(buf, maxSize, ecn, now)
if err != nil {
if err != errNothingToPack {
return err
}
if buf.Len() == 0 {
buf.Release()
return nil
}
dontSendMore = true
}
if !dontSendMore {
sendMode := c.sentPacketHandler.SendMode(now)
if sendMode == ackhandler.SendPacingLimited {
c.resetPacingDeadline()
}
if sendMode != ackhandler.SendAny {
dontSendMore = true
}
}
// Don't send more packets in this batch if they require a different ECN marking than the previous ones.
nextECN := c.sentPacketHandler.ECNMode(true)
// Append another packet if
// 1. The congestion controller and pacer allow sending more
// 2. The last packet appended was a full-size packet
// 3. The next packet will have the same ECN marking
// 4. We still have enough space for another full-size packet in the buffer
if !dontSendMore && size == maxSize && nextECN == ecn && buf.Len()+maxSize <= buf.Cap() {
continue
}
c.sendQueue.Send(buf, uint16(maxSize), ecn)
if dontSendMore {
return nil
}
if c.sendQueue.WouldBlock() {
return nil
}
// Prioritize receiving packets over sending out more packets.
c.receivedPacketMx.Lock()
hasPackets := !c.receivedPackets.Empty()
c.receivedPacketMx.Unlock()
if hasPackets {
c.pacingDeadline = deadlineSendImmediately
return nil
}
ecn = nextECN
buf = getLargePacketBuffer()
}
}
func (c *Conn) resetPacingDeadline() {
deadline := c.sentPacketHandler.TimeUntilSend()
if deadline.IsZero() {
deadline = deadlineSendImmediately
}
c.pacingDeadline = deadline
}
func (c *Conn) maybeSendAckOnlyPacket(now time.Time) error {
if !c.handshakeConfirmed {
ecn := c.sentPacketHandler.ECNMode(false)
packet, err := c.packer.PackCoalescedPacket(true, c.maxPacketSize(), now, c.version)
if err != nil {
return err
}
if packet == nil {
return nil
}
return c.sendPackedCoalescedPacket(packet, ecn, now)
}
ecn := c.sentPacketHandler.ECNMode(true)
p, buf, err := c.packer.PackAckOnlyPacket(c.maxPacketSize(), now, c.version)
if err != nil {
if err == errNothingToPack {
return nil
}
return err
}
c.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, ecn, buf.Len(), false)
c.registerPackedShortHeaderPacket(p, ecn, now)
c.sendQueue.Send(buf, 0, ecn)
return nil
}
func (c *Conn) sendProbePacket(sendMode ackhandler.SendMode, now time.Time) error {
var encLevel protocol.EncryptionLevel
//nolint:exhaustive // We only need to handle the PTO send modes here.
switch sendMode {
case ackhandler.SendPTOInitial:
encLevel = protocol.EncryptionInitial
case ackhandler.SendPTOHandshake:
encLevel = protocol.EncryptionHandshake
case ackhandler.SendPTOAppData:
encLevel = protocol.Encryption1RTT
default:
return fmt.Errorf("connection BUG: unexpected send mode: %d", sendMode)
}
// Queue probe packets until we actually send out a packet,
// or until there are no more packets to queue.
var packet *coalescedPacket
for packet == nil {
if wasQueued := c.sentPacketHandler.QueueProbePacket(encLevel); !wasQueued {
break
}
var err error
packet, err = c.packer.PackPTOProbePacket(encLevel, c.maxPacketSize(), false, now, c.version)
if err != nil {
return err
}
}
if packet == nil {
var err error
packet, err = c.packer.PackPTOProbePacket(encLevel, c.maxPacketSize(), true, now, c.version)
if err != nil {
return err
}
}
if packet == nil || (len(packet.longHdrPackets) == 0 && packet.shortHdrPacket == nil) {
return fmt.Errorf("connection BUG: couldn't pack %s probe packet: %v", encLevel, packet)
}
return c.sendPackedCoalescedPacket(packet, c.sentPacketHandler.ECNMode(packet.IsOnlyShortHeaderPacket()), now)
}
// appendOneShortHeaderPacket appends a new packet to the given packetBuffer.
// If there was nothing to pack, the returned size is 0.
func (c *Conn) appendOneShortHeaderPacket(buf *packetBuffer, maxSize protocol.ByteCount, ecn protocol.ECN, now time.Time) (protocol.ByteCount, error) {
startLen := buf.Len()
p, err := c.packer.AppendPacket(buf, maxSize, now, c.version)
if err != nil {
return 0, err
}
size := buf.Len() - startLen
c.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, ecn, size, false)
c.registerPackedShortHeaderPacket(p, ecn, now)
return size, nil
}
func (c *Conn) registerPackedShortHeaderPacket(p shortHeaderPacket, ecn protocol.ECN, now time.Time) {
if p.IsPathProbePacket {
c.sentPacketHandler.SentPacket(
now,
p.PacketNumber,
protocol.InvalidPacketNumber,
p.StreamFrames,
p.Frames,
protocol.Encryption1RTT,
ecn,
p.Length,
p.IsPathMTUProbePacket,
true,
)
return
}
if c.firstAckElicitingPacketAfterIdleSentTime.IsZero() && (len(p.StreamFrames) > 0 || ackhandler.HasAckElicitingFrames(p.Frames)) {
c.firstAckElicitingPacketAfterIdleSentTime = now
}
largestAcked := protocol.InvalidPacketNumber
if p.Ack != nil {
largestAcked = p.Ack.LargestAcked()
}
c.sentPacketHandler.SentPacket(
now,
p.PacketNumber,
largestAcked,
p.StreamFrames,
p.Frames,
protocol.Encryption1RTT,
ecn,
p.Length,
p.IsPathMTUProbePacket,
false,
)
c.connIDManager.SentPacket()
}
func (c *Conn) sendPackedCoalescedPacket(packet *coalescedPacket, ecn protocol.ECN, now time.Time) error {
c.logCoalescedPacket(packet, ecn)
for _, p := range packet.longHdrPackets {
if c.firstAckElicitingPacketAfterIdleSentTime.IsZero() && p.IsAckEliciting() {
c.firstAckElicitingPacketAfterIdleSentTime = now
}
largestAcked := protocol.InvalidPacketNumber
if p.ack != nil {
largestAcked = p.ack.LargestAcked()
}
c.sentPacketHandler.SentPacket(
now,
p.header.PacketNumber,
largestAcked,
p.streamFrames,
p.frames,
p.EncryptionLevel(),
ecn,
p.length,
false,
false,
)
if c.perspective == protocol.PerspectiveClient && p.EncryptionLevel() == protocol.EncryptionHandshake &&
!c.droppedInitialKeys {
// On the client side, Initial keys are dropped as soon as the first Handshake packet is sent.
// See Section 4.9.1 of RFC 9001.
if err := c.dropEncryptionLevel(protocol.EncryptionInitial, now); err != nil {
return err
}
}
}
if p := packet.shortHdrPacket; p != nil {
if c.firstAckElicitingPacketAfterIdleSentTime.IsZero() && p.IsAckEliciting() {
c.firstAckElicitingPacketAfterIdleSentTime = now
}
largestAcked := protocol.InvalidPacketNumber
if p.Ack != nil {
largestAcked = p.Ack.LargestAcked()
}
c.sentPacketHandler.SentPacket(
now,
p.PacketNumber,
largestAcked,
p.StreamFrames,
p.Frames,
protocol.Encryption1RTT,
ecn,
p.Length,
p.IsPathMTUProbePacket,
false,
)
}
c.connIDManager.SentPacket()
c.sendQueue.Send(packet.buffer, 0, ecn)
return nil
}
func (c *Conn) sendConnectionClose(e error) ([]byte, error) {
var packet *coalescedPacket
var err error
var transportErr *qerr.TransportError
var applicationErr *qerr.ApplicationError
if errors.As(e, &transportErr) {
packet, err = c.packer.PackConnectionClose(transportErr, c.maxPacketSize(), c.version)
} else if errors.As(e, &applicationErr) {
packet, err = c.packer.PackApplicationClose(applicationErr, c.maxPacketSize(), c.version)
} else {
packet, err = c.packer.PackConnectionClose(&qerr.TransportError{
ErrorCode: qerr.InternalError,
ErrorMessage: fmt.Sprintf("connection BUG: unspecified error type (msg: %s)", e.Error()),
}, c.maxPacketSize(), c.version)
}
if err != nil {
return nil, err
}
ecn := c.sentPacketHandler.ECNMode(packet.IsOnlyShortHeaderPacket())
c.logCoalescedPacket(packet, ecn)
return packet.buffer.Data, c.conn.Write(packet.buffer.Data, 0, ecn)
}
func (c *Conn) maxPacketSize() protocol.ByteCount {
if c.mtuDiscoverer == nil {
// Use the configured packet size on the client side.
// If the server sends a max_udp_payload_size that's smaller than this size, we can ignore this:
// Apparently the server still processed the (fully padded) Initial packet anyway.
if c.perspective == protocol.PerspectiveClient {
return protocol.ByteCount(c.config.InitialPacketSize)
}
// On the server side, there's no downside to using 1200 bytes until we received the client's transport
// parameters:
// * If the first packet didn't contain the entire ClientHello, all we can do is ACK that packet. We don't
// need a lot of bytes for that.
// * If it did, we will have processed the transport parameters and initialized the MTU discoverer.
return protocol.MinInitialPacketSize
}
return c.mtuDiscoverer.CurrentSize()
}
// AcceptStream returns the next stream opened by the peer, blocking until one is available.
func (c *Conn) AcceptStream(ctx context.Context) (*Stream, error) {
return c.streamsMap.AcceptStream(ctx)
}
// AcceptUniStream returns the next unidirectional stream opened by the peer, blocking until one is available.
func (c *Conn) AcceptUniStream(ctx context.Context) (*ReceiveStream, error) {
return c.streamsMap.AcceptUniStream(ctx)
}
// OpenStream opens a new bidirectional QUIC stream.
// There is no signaling to the peer about new streams:
// The peer can only accept the stream after data has been sent on the stream,
// or the stream has been reset or closed.
// When reaching the peer's stream limit, it is not possible to open a new stream until the
// peer raises the stream limit. In that case, a [StreamLimitReachedError] is returned.
func (c *Conn) OpenStream() (*Stream, error) {
return c.streamsMap.OpenStream()
}
// OpenStreamSync opens a new bidirectional QUIC stream.
// It blocks until a new stream can be opened.
// There is no signaling to the peer about new streams:
// The peer can only accept the stream after data has been sent on the stream,
// or the stream has been reset or closed.
func (c *Conn) OpenStreamSync(ctx context.Context) (*Stream, error) {
return c.streamsMap.OpenStreamSync(ctx)
}
// OpenUniStream opens a new outgoing unidirectional QUIC stream.
// There is no signaling to the peer about new streams:
// The peer can only accept the stream after data has been sent on the stream,
// or the stream has been reset or closed.
// When reaching the peer's stream limit, it is not possible to open a new stream until the
// peer raises the stream limit. In that case, a [StreamLimitReachedError] is returned.
func (c *Conn) OpenUniStream() (*SendStream, error) {
return c.streamsMap.OpenUniStream()
}
// OpenUniStreamSync opens a new outgoing unidirectional QUIC stream.
// It blocks until a new stream can be opened.
// There is no signaling to the peer about new streams:
// The peer can only accept the stream after data has been sent on the stream,
// or the stream has been reset or closed.
func (c *Conn) OpenUniStreamSync(ctx context.Context) (*SendStream, error) {
return c.streamsMap.OpenUniStreamSync(ctx)
}
func (c *Conn) newFlowController(id protocol.StreamID) flowcontrol.StreamFlowController {
initialSendWindow := c.peerParams.InitialMaxStreamDataUni
if id.Type() == protocol.StreamTypeBidi {
if id.InitiatedBy() == c.perspective {
initialSendWindow = c.peerParams.InitialMaxStreamDataBidiRemote
} else {
initialSendWindow = c.peerParams.InitialMaxStreamDataBidiLocal
}
}
return flowcontrol.NewStreamFlowController(
id,
c.connFlowController,
protocol.ByteCount(c.config.InitialStreamReceiveWindow),
protocol.ByteCount(c.config.MaxStreamReceiveWindow),
initialSendWindow,
c.rttStats,
c.logger,
)
}
// scheduleSending signals that we have data for sending
func (c *Conn) scheduleSending() {
select {
case c.sendingScheduled <- struct{}{}:
default:
}
}
// tryQueueingUndecryptablePacket queues a packet for which we're missing the decryption keys.
// The logging.PacketType is only used for logging purposes.
func (c *Conn) tryQueueingUndecryptablePacket(p receivedPacket, pt logging.PacketType) {
if c.handshakeComplete {
panic("shouldn't queue undecryptable packets after handshake completion")
}
if len(c.undecryptablePackets)+1 > protocol.MaxUndecryptablePackets {
if c.tracer != nil && c.tracer.DroppedPacket != nil {
c.tracer.DroppedPacket(pt, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropDOSPrevention)
}
c.logger.Infof("Dropping undecryptable packet (%d bytes). Undecryptable packet queue full.", p.Size())
return
}
c.logger.Infof("Queueing packet (%d bytes) for later decryption", p.Size())
if c.tracer != nil && c.tracer.BufferedPacket != nil {
c.tracer.BufferedPacket(pt, p.Size())
}
c.undecryptablePackets = append(c.undecryptablePackets, p)
}
func (c *Conn) queueControlFrame(f wire.Frame) {
c.framer.QueueControlFrame(f)
c.scheduleSending()
}
func (c *Conn) onHasConnectionData() { c.scheduleSending() }
func (c *Conn) onHasStreamData(id protocol.StreamID, str *SendStream) {
c.framer.AddActiveStream(id, str)
c.scheduleSending()
}
func (c *Conn) onHasStreamControlFrame(id protocol.StreamID, str streamControlFrameGetter) {
c.framer.AddStreamWithControlFrames(id, str)
c.scheduleSending()
}
func (c *Conn) onStreamCompleted(id protocol.StreamID) {
if err := c.streamsMap.DeleteStream(id); err != nil {
c.closeLocal(err)
}
c.framer.RemoveActiveStream(id)
}
// SendDatagram sends a message using a QUIC datagram, as specified in RFC 9221,
// if the peer enabled datagram support.
// There is no delivery guarantee for DATAGRAM frames; they are not retransmitted if lost.
// The payload of the datagram needs to fit into a single QUIC packet.
// In addition, a datagram may be dropped before being sent out if the available packet size suddenly decreases.
// If the payload is too large to be sent at the current time, a [DatagramTooLargeError] is returned.
func (c *Conn) SendDatagram(p []byte) error {
if !c.supportsDatagrams() {
return errors.New("datagram support disabled")
}
f := &wire.DatagramFrame{DataLenPresent: true}
// The payload size estimate is conservative.
// Under many circumstances we could send a few more bytes.
maxDataLen := min(
f.MaxDataLen(c.peerParams.MaxDatagramFrameSize, c.version),
protocol.ByteCount(c.currentMTUEstimate.Load()),
)
if protocol.ByteCount(len(p)) > maxDataLen {
return &DatagramTooLargeError{MaxDatagramPayloadSize: int64(maxDataLen)}
}
f.Data = make([]byte, len(p))
copy(f.Data, p)
return c.datagramQueue.Add(f)
}
// ReceiveDatagram gets a message received in a QUIC datagram, as specified in RFC 9221.
func (c *Conn) ReceiveDatagram(ctx context.Context) ([]byte, error) {
if !c.config.EnableDatagrams {
return nil, errors.New("datagram support disabled")
}
return c.datagramQueue.Receive(ctx)
}
// LocalAddr returns the local address of the QUIC connection.
func (c *Conn) LocalAddr() net.Addr { return c.conn.LocalAddr() }
// RemoteAddr returns the remote address of the QUIC connection.
func (c *Conn) RemoteAddr() net.Addr { return c.conn.RemoteAddr() }
func (c *Conn) getPathManager() *pathManagerOutgoing {
c.pathManagerOutgoing.CompareAndSwap(nil,
func() *pathManagerOutgoing { // this function is only called if a swap is performed
return newPathManagerOutgoing(
c.connIDManager.GetConnIDForPath,
c.connIDManager.RetireConnIDForPath,
c.scheduleSending,
)
}(),
)
return c.pathManagerOutgoing.Load()
}
func (c *Conn) AddPath(t *Transport) (*Path, error) {
if c.perspective == protocol.PerspectiveServer {
return nil, errors.New("server cannot initiate connection migration")
}
if c.peerParams.DisableActiveMigration {
return nil, errors.New("server disabled connection migration")
}
if err := t.init(false); err != nil {
return nil, err
}
return c.getPathManager().NewPath(
t,
200*time.Millisecond, // initial RTT estimate
func() {
runner := (*packetHandlerMap)(t)
c.connIDGenerator.AddConnRunner(
runner,
connRunnerCallbacks{
AddConnectionID: func(connID protocol.ConnectionID) { runner.Add(connID, c) },
RemoveConnectionID: runner.Remove,
ReplaceWithClosed: runner.ReplaceWithClosed,
},
)
},
), nil
}
// HandshakeComplete blocks until the handshake completes (or fails).
// For the client, data sent before completion of the handshake is encrypted with 0-RTT keys.
// For the server, data sent before completion of the handshake is encrypted with 1-RTT keys,
// however the client's identity is only verified once the handshake completes.
func (c *Conn) HandshakeComplete() <-chan struct{} {
return c.handshakeCompleteChan
}
func (c *Conn) NextConnection(ctx context.Context) (*Conn, error) {
// The handshake might fail after the server rejected 0-RTT.
// This could happen if the Finished message is malformed or never received.
select {
case <-ctx.Done():
return nil, context.Cause(ctx)
case <-c.Context().Done():
case <-c.HandshakeComplete():
c.streamsMap.UseResetMaps()
}
return c, nil
}
// estimateMaxPayloadSize estimates the maximum payload size for short header packets.
// It is not very sophisticated: it just subtracts the size of the header (assuming the maximum
// connection ID length), and the size of the encryption tag.
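// For example, a 1200 byte packet leaves 1200 - 1 - 20 - 16 = 1163 bytes for the payload.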
func estimateMaxPayloadSize(mtu protocol.ByteCount) protocol.ByteCount {
return mtu - 1 /* type byte */ - 20 /* maximum connection ID length */ - 16 /* tag size */
}
package quic
import (
"slices"
"github.com/quic-go/quic-go/internal/ackhandler"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/wire"
"github.com/quic-go/quic-go/logging"
)
// toLoggingFrame converts a wire.Frame into a logging.Frame.
// This makes it possible for external packages to access the frames.
// Furthermore, it removes the data slices from CRYPTO and STREAM frames.
func toLoggingFrame(frame wire.Frame) logging.Frame {
switch f := frame.(type) {
case *wire.AckFrame:
// We use a pool for ACK frames.
// Implementations of the tracer interface may hold on to frames, so we need to make a copy here.
return toLoggingAckFrame(f)
case *wire.CryptoFrame:
return &logging.CryptoFrame{
Offset: f.Offset,
Length: protocol.ByteCount(len(f.Data)),
}
case *wire.StreamFrame:
return &logging.StreamFrame{
StreamID: f.StreamID,
Offset: f.Offset,
Length: f.DataLen(),
Fin: f.Fin,
}
case *wire.DatagramFrame:
return &logging.DatagramFrame{
Length: logging.ByteCount(len(f.Data)),
}
default:
return logging.Frame(frame)
}
}
func toLoggingAckFrame(f *wire.AckFrame) *logging.AckFrame {
ack := &logging.AckFrame{
AckRanges: slices.Clone(f.AckRanges),
DelayTime: f.DelayTime,
ECNCE: f.ECNCE,
ECT0: f.ECT0,
ECT1: f.ECT1,
}
return ack
}
func (c *Conn) logLongHeaderPacket(p *longHeaderPacket, ecn protocol.ECN) {
// quic-go logging
if c.logger.Debug() {
p.header.Log(c.logger)
if p.ack != nil {
wire.LogFrame(c.logger, p.ack, true)
}
for _, frame := range p.frames {
wire.LogFrame(c.logger, frame.Frame, true)
}
for _, frame := range p.streamFrames {
wire.LogFrame(c.logger, frame.Frame, true)
}
}
// tracing
if c.tracer != nil && c.tracer.SentLongHeaderPacket != nil {
frames := make([]logging.Frame, 0, len(p.frames))
for _, f := range p.frames {
frames = append(frames, toLoggingFrame(f.Frame))
}
for _, f := range p.streamFrames {
frames = append(frames, toLoggingFrame(f.Frame))
}
var ack *logging.AckFrame
if p.ack != nil {
ack = toLoggingAckFrame(p.ack)
}
c.tracer.SentLongHeaderPacket(p.header, p.length, ecn, ack, frames)
}
}
func (c *Conn) logShortHeaderPacket(
destConnID protocol.ConnectionID,
ackFrame *wire.AckFrame,
frames []ackhandler.Frame,
streamFrames []ackhandler.StreamFrame,
pn protocol.PacketNumber,
pnLen protocol.PacketNumberLen,
kp protocol.KeyPhaseBit,
ecn protocol.ECN,
size protocol.ByteCount,
isCoalesced bool,
) {
if c.logger.Debug() && !isCoalesced {
c.logger.Debugf("-> Sending packet %d (%d bytes) for connection %s, 1-RTT (ECN: %s)", pn, size, c.logID, ecn)
}
// quic-go logging
if c.logger.Debug() {
wire.LogShortHeader(c.logger, destConnID, pn, pnLen, kp)
if ackFrame != nil {
wire.LogFrame(c.logger, ackFrame, true)
}
for _, f := range frames {
wire.LogFrame(c.logger, f.Frame, true)
}
for _, f := range streamFrames {
wire.LogFrame(c.logger, f.Frame, true)
}
}
// tracing
if c.tracer != nil && c.tracer.SentShortHeaderPacket != nil {
fs := make([]logging.Frame, 0, len(frames)+len(streamFrames))
for _, f := range frames {
fs = append(fs, toLoggingFrame(f.Frame))
}
for _, f := range streamFrames {
fs = append(fs, toLoggingFrame(f.Frame))
}
var ack *logging.AckFrame
if ackFrame != nil {
ack = toLoggingAckFrame(ackFrame)
}
c.tracer.SentShortHeaderPacket(
&logging.ShortHeader{DestConnectionID: destConnID, PacketNumber: pn, PacketNumberLen: pnLen, KeyPhase: kp},
size,
ecn,
ack,
fs,
)
}
}
func (c *Conn) logCoalescedPacket(packet *coalescedPacket, ecn protocol.ECN) {
if c.logger.Debug() {
// There's a short period between dropping both Initial and Handshake keys and completion of the handshake,
// during which we might call PackCoalescedPacket but just pack a short header packet.
if len(packet.longHdrPackets) == 0 && packet.shortHdrPacket != nil {
c.logShortHeaderPacket(
packet.shortHdrPacket.DestConnID,
packet.shortHdrPacket.Ack,
packet.shortHdrPacket.Frames,
packet.shortHdrPacket.StreamFrames,
packet.shortHdrPacket.PacketNumber,
packet.shortHdrPacket.PacketNumberLen,
packet.shortHdrPacket.KeyPhase,
ecn,
packet.shortHdrPacket.Length,
false,
)
return
}
if len(packet.longHdrPackets) > 1 {
c.logger.Debugf("-> Sending coalesced packet (%d parts, %d bytes) for connection %s", len(packet.longHdrPackets), packet.buffer.Len(), c.logID)
} else {
c.logger.Debugf("-> Sending packet %d (%d bytes) for connection %s, %s", packet.longHdrPackets[0].header.PacketNumber, packet.buffer.Len(), c.logID, packet.longHdrPackets[0].EncryptionLevel())
}
}
for _, p := range packet.longHdrPackets {
c.logLongHeaderPacket(p, ecn)
}
if p := packet.shortHdrPacket; p != nil {
c.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, ecn, p.Length, true)
}
}
package quic
import (
"time"
"github.com/quic-go/quic-go/internal/utils"
)
var deadlineSendImmediately = time.Time{}.Add(42 * time.Millisecond) // any value > time.Time{} and before time.Now() is fine
type connectionTimer struct {
timer *utils.Timer
last time.Time
}
func newTimer() *connectionTimer {
return &connectionTimer{timer: utils.NewTimer()}
}
func (t *connectionTimer) SetRead() {
if deadline := t.timer.Deadline(); deadline != deadlineSendImmediately {
t.last = deadline
}
t.timer.SetRead()
}
func (t *connectionTimer) Chan() <-chan time.Time {
return t.timer.Chan()
}
// SetTimer resets the timer.
// It makes sure that the deadline is strictly increasing.
// This prevents busy-looping in cases where the timer fires, but we can't actually send out a packet.
// This doesn't apply to the pacing deadline, which can be set multiple times to deadlineSendImmediately.
func (t *connectionTimer) SetTimer(idleTimeoutOrKeepAlive, connIDRetirement, ackAlarm, lossTime, pacing time.Time) {
deadline := idleTimeoutOrKeepAlive
if !connIDRetirement.IsZero() && connIDRetirement.Before(deadline) && connIDRetirement.After(t.last) {
deadline = connIDRetirement
}
if !ackAlarm.IsZero() && ackAlarm.Before(deadline) && ackAlarm.After(t.last) {
deadline = ackAlarm
}
if !lossTime.IsZero() && lossTime.Before(deadline) && lossTime.After(t.last) {
deadline = lossTime
}
if !pacing.IsZero() && pacing.Before(deadline) {
deadline = pacing
}
t.timer.Reset(deadline)
}
func (t *connectionTimer) Stop() {
t.timer.Stop()
}
package quic
import (
"errors"
"fmt"
"io"
"os"
"slices"
"strconv"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qerr"
"github.com/quic-go/quic-go/internal/wire"
)
const disableClientHelloScramblingEnv = "QUIC_GO_DISABLE_CLIENTHELLO_SCRAMBLING"
// The baseCryptoStream is used by the cryptoStream and the initialCryptoStream.
// This allows us to implement different logic for PopCryptoFrame for the two streams.
type baseCryptoStream struct {
queue frameSorter
highestOffset protocol.ByteCount
finished bool
writeOffset protocol.ByteCount
writeBuf []byte
}
func newCryptoStream() *cryptoStream {
return &cryptoStream{baseCryptoStream{queue: *newFrameSorter()}}
}
func (s *baseCryptoStream) HandleCryptoFrame(f *wire.CryptoFrame) error {
highestOffset := f.Offset + protocol.ByteCount(len(f.Data))
if maxOffset := highestOffset; maxOffset > protocol.MaxCryptoStreamOffset {
return &qerr.TransportError{
ErrorCode: qerr.CryptoBufferExceeded,
ErrorMessage: fmt.Sprintf("received invalid offset %d on crypto stream, maximum allowed %d", maxOffset, protocol.MaxCryptoStreamOffset),
}
}
if s.finished {
if highestOffset > s.highestOffset {
// reject crypto data received after this stream was already finished
return &qerr.TransportError{
ErrorCode: qerr.ProtocolViolation,
ErrorMessage: "received crypto data after change of encryption level",
}
}
// ignore data with a smaller offset than the highest received
// could e.g. be a retransmission
return nil
}
s.highestOffset = max(s.highestOffset, highestOffset)
return s.queue.Push(f.Data, f.Offset, nil)
}
// GetCryptoData retrieves data that was received in CRYPTO frames
func (s *baseCryptoStream) GetCryptoData() []byte {
_, data, _ := s.queue.Pop()
return data
}
func (s *baseCryptoStream) Finish() error {
if s.queue.HasMoreData() {
return &qerr.TransportError{
ErrorCode: qerr.ProtocolViolation,
ErrorMessage: "encryption level changed, but crypto stream has more data to read",
}
}
s.finished = true
return nil
}
// Write writes data that should be sent out in CRYPTO frames
func (s *baseCryptoStream) Write(p []byte) (int, error) {
s.writeBuf = append(s.writeBuf, p...)
return len(p), nil
}
func (s *baseCryptoStream) HasData() bool {
return len(s.writeBuf) > 0
}
func (s *baseCryptoStream) PopCryptoFrame(maxLen protocol.ByteCount) *wire.CryptoFrame {
f := &wire.CryptoFrame{Offset: s.writeOffset}
n := min(f.MaxDataLen(maxLen), protocol.ByteCount(len(s.writeBuf)))
if n <= 0 {
return nil
}
f.Data = s.writeBuf[:n]
s.writeBuf = s.writeBuf[n:]
s.writeOffset += n
return f
}
type cryptoStream struct {
baseCryptoStream
}
type clientHelloCut struct {
start protocol.ByteCount
end protocol.ByteCount
}
type initialCryptoStream struct {
baseCryptoStream
scramble bool
end protocol.ByteCount
cuts [2]clientHelloCut
}
func newInitialCryptoStream(isClient bool) *initialCryptoStream {
var scramble bool
if isClient {
disabled, err := strconv.ParseBool(os.Getenv(disableClientHelloScramblingEnv))
scramble = err != nil || !disabled
}
s := &initialCryptoStream{
baseCryptoStream: baseCryptoStream{queue: *newFrameSorter()},
scramble: scramble,
}
for i := range len(s.cuts) {
s.cuts[i].start = protocol.InvalidByteCount
s.cuts[i].end = protocol.InvalidByteCount
}
return s
}
func (s *initialCryptoStream) HasData() bool {
// The ClientHello might be written in multiple parts.
// In order to correctly split the ClientHello, we need to wait until the entire ClientHello has been queued.
if s.scramble && s.writeOffset == 0 && s.cuts[0].start == protocol.InvalidByteCount {
return false
}
return s.baseCryptoStream.HasData()
}
func (s *initialCryptoStream) Write(p []byte) (int, error) {
s.writeBuf = append(s.writeBuf, p...)
if !s.scramble {
return len(p), nil
}
if s.cuts[0].start == protocol.InvalidByteCount {
sniPos, sniLen, echPos, err := findSNIAndECH(s.writeBuf)
if errors.Is(err, io.ErrUnexpectedEOF) {
return len(p), nil
}
if err != nil {
return len(p), err
}
if sniPos == -1 && echPos == -1 {
// Neither SNI nor ECH found.
// There's nothing to scramble.
s.scramble = false
return len(p), nil
}
s.end = protocol.ByteCount(len(s.writeBuf))
s.cuts[0].start = protocol.ByteCount(sniPos + sniLen/2) // right in the middle
s.cuts[0].end = protocol.ByteCount(sniPos + sniLen)
if echPos > 0 {
// ECH extension found, cut the ECH extension type value (a uint16) in half
start := protocol.ByteCount(echPos + 1)
s.cuts[1].start = start
// cut somewhere (16 bytes), most likely in the ECH extension value
s.cuts[1].end = min(start+16, s.end)
}
slices.SortFunc(s.cuts[:], func(a, b clientHelloCut) int {
if a.start == protocol.InvalidByteCount {
return 1
}
if a.start > b.start {
return 1
}
return -1
})
}
return len(p), nil
}
func (s *initialCryptoStream) PopCryptoFrame(maxLen protocol.ByteCount) *wire.CryptoFrame {
if !s.scramble {
return s.baseCryptoStream.PopCryptoFrame(maxLen)
}
// send out the skipped parts
if s.writeOffset == s.end {
var foundCuts bool
var f *wire.CryptoFrame
for i, c := range s.cuts {
if c.start == protocol.InvalidByteCount {
continue
}
foundCuts = true
if f != nil {
break
}
f = &wire.CryptoFrame{Offset: c.start}
n := min(f.MaxDataLen(maxLen), c.end-c.start)
if n <= 0 {
return nil
}
f.Data = s.writeBuf[c.start : c.start+n]
s.cuts[i].start += n
if s.cuts[i].start == c.end {
s.cuts[i].start = protocol.InvalidByteCount
s.cuts[i].end = protocol.InvalidByteCount
foundCuts = false
}
}
if !foundCuts {
// no more cuts found, we're done sending out everything up until s.end
s.writeBuf = s.writeBuf[s.end:]
s.end = protocol.InvalidByteCount
s.scramble = false
}
return f
}
nextCut := clientHelloCut{start: protocol.InvalidByteCount, end: protocol.InvalidByteCount}
for _, c := range s.cuts {
if c.start == protocol.InvalidByteCount {
continue
}
if c.start > s.writeOffset {
nextCut = c
break
}
}
f := &wire.CryptoFrame{Offset: s.writeOffset}
maxOffset := nextCut.start
if maxOffset == protocol.InvalidByteCount {
maxOffset = s.end
}
n := min(f.MaxDataLen(maxLen), maxOffset-s.writeOffset)
if n <= 0 {
return nil
}
f.Data = s.writeBuf[s.writeOffset : s.writeOffset+n]
// Don't reslice the writeBuf yet.
// This is done once all parts have been sent out.
s.writeOffset += n
if s.writeOffset == nextCut.start {
s.writeOffset = nextCut.end
}
return f
}
package quic
import (
"fmt"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/wire"
)
type cryptoStreamManager struct {
initialStream *initialCryptoStream
handshakeStream *cryptoStream
oneRTTStream *cryptoStream
}
func newCryptoStreamManager(
initialStream *initialCryptoStream,
handshakeStream *cryptoStream,
oneRTTStream *cryptoStream,
) *cryptoStreamManager {
return &cryptoStreamManager{
initialStream: initialStream,
handshakeStream: handshakeStream,
oneRTTStream: oneRTTStream,
}
}
func (m *cryptoStreamManager) HandleCryptoFrame(frame *wire.CryptoFrame, encLevel protocol.EncryptionLevel) error {
//nolint:exhaustive // CRYPTO frames cannot be sent in 0-RTT packets.
switch encLevel {
case protocol.EncryptionInitial:
return m.initialStream.HandleCryptoFrame(frame)
case protocol.EncryptionHandshake:
return m.handshakeStream.HandleCryptoFrame(frame)
case protocol.Encryption1RTT:
return m.oneRTTStream.HandleCryptoFrame(frame)
default:
return fmt.Errorf("received CRYPTO frame with unexpected encryption level: %s", encLevel)
}
}
func (m *cryptoStreamManager) GetCryptoData(encLevel protocol.EncryptionLevel) []byte {
//nolint:exhaustive // CRYPTO frames cannot be sent in 0-RTT packets.
switch encLevel {
case protocol.EncryptionInitial:
return m.initialStream.GetCryptoData()
case protocol.EncryptionHandshake:
return m.handshakeStream.GetCryptoData()
case protocol.Encryption1RTT:
return m.oneRTTStream.GetCryptoData()
default:
panic(fmt.Sprintf("received CRYPTO frame with unexpected encryption level: %s", encLevel))
}
}
func (m *cryptoStreamManager) GetPostHandshakeData(maxSize protocol.ByteCount) *wire.CryptoFrame {
if !m.oneRTTStream.HasData() {
return nil
}
return m.oneRTTStream.PopCryptoFrame(maxSize)
}
func (m *cryptoStreamManager) Drop(encLevel protocol.EncryptionLevel) error {
//nolint:exhaustive // 1-RTT keys should never get dropped.
switch encLevel {
case protocol.EncryptionInitial:
return m.initialStream.Finish()
case protocol.EncryptionHandshake:
return m.handshakeStream.Finish()
default:
panic(fmt.Sprintf("dropped unexpected encryption level: %s", encLevel))
}
}
package quic
import (
"context"
"sync"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/internal/utils/ringbuffer"
"github.com/quic-go/quic-go/internal/wire"
)
const (
maxDatagramSendQueueLen = 32
maxDatagramRcvQueueLen = 128
)
type datagramQueue struct {
sendMx sync.Mutex
sendQueue ringbuffer.RingBuffer[*wire.DatagramFrame]
sent chan struct{} // used to notify Add that a datagram was dequeued
rcvMx sync.Mutex
rcvQueue [][]byte
rcvd chan struct{} // used to notify Receive that a new datagram was received
closeErr error
closed chan struct{}
hasData func()
logger utils.Logger
}
func newDatagramQueue(hasData func(), logger utils.Logger) *datagramQueue {
return &datagramQueue{
hasData: hasData,
rcvd: make(chan struct{}, 1),
sent: make(chan struct{}, 1),
closed: make(chan struct{}),
logger: logger,
}
}
// Add queues a new DATAGRAM frame for sending.
// Up to 32 DATAGRAM frames will be queued.
// Once that limit is reached, Add blocks until the queue size has reduced.
func (h *datagramQueue) Add(f *wire.DatagramFrame) error {
h.sendMx.Lock()
for {
if h.sendQueue.Len() < maxDatagramSendQueueLen {
h.sendQueue.PushBack(f)
h.sendMx.Unlock()
h.hasData()
return nil
}
select {
case <-h.sent: // drain the channel so we don't loop immediately
default:
}
h.sendMx.Unlock()
select {
case <-h.closed:
return h.closeErr
case <-h.sent:
}
h.sendMx.Lock()
}
}
// Peek gets the next DATAGRAM frame for sending.
// If actually sent out, Pop needs to be called before the next call to Peek.
func (h *datagramQueue) Peek() *wire.DatagramFrame {
h.sendMx.Lock()
defer h.sendMx.Unlock()
if h.sendQueue.Empty() {
return nil
}
return h.sendQueue.PeekFront()
}
func (h *datagramQueue) Pop() {
h.sendMx.Lock()
defer h.sendMx.Unlock()
_ = h.sendQueue.PopFront()
select {
case h.sent <- struct{}{}:
default:
}
}
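// Illustrative sketch, not part of the original source: how the send side of the queue is
// driven. Add queues a frame (blocking if the queue is full), Peek returns the next frame
// without removing it, and Pop must be called once the frame was actually packed into a
// packet. The function name is hypothetical; utils.DefaultLogger is assumed to be available.
func exampleDatagramQueueSendSide() {
	q := newDatagramQueue(func() { /* signal the send loop that there is data */ }, utils.DefaultLogger)
	_ = q.Add(&wire.DatagramFrame{DataLenPresent: true, Data: []byte("ping")})
	if f := q.Peek(); f != nil {
		// ... append f to the packet being built ...
		q.Pop()
	}
}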
// HandleDatagramFrame handles a received DATAGRAM frame.
func (h *datagramQueue) HandleDatagramFrame(f *wire.DatagramFrame) {
data := make([]byte, len(f.Data))
copy(data, f.Data)
var queued bool
h.rcvMx.Lock()
if len(h.rcvQueue) < maxDatagramRcvQueueLen {
h.rcvQueue = append(h.rcvQueue, data)
queued = true
select {
case h.rcvd <- struct{}{}:
default:
}
}
h.rcvMx.Unlock()
if !queued && h.logger.Debug() {
h.logger.Debugf("Discarding received DATAGRAM frame (%d bytes payload)", len(f.Data))
}
}
// Receive gets a received DATAGRAM frame.
func (h *datagramQueue) Receive(ctx context.Context) ([]byte, error) {
for {
h.rcvMx.Lock()
if len(h.rcvQueue) > 0 {
data := h.rcvQueue[0]
h.rcvQueue = h.rcvQueue[1:]
h.rcvMx.Unlock()
return data, nil
}
h.rcvMx.Unlock()
select {
case <-h.rcvd:
continue
case <-h.closed:
return nil, h.closeErr
case <-ctx.Done():
return nil, ctx.Err()
}
}
}
func (h *datagramQueue) CloseWithError(e error) {
h.closeErr = e
close(h.closed)
}
package quic
import (
"fmt"
"github.com/quic-go/quic-go/internal/qerr"
)
type (
// TransportError indicates an error that occurred on the QUIC transport layer.
// Every transport error other than CONNECTION_REFUSED and APPLICATION_ERROR is
// likely a bug in the implementation.
TransportError = qerr.TransportError
// ApplicationError is an application-defined error.
ApplicationError = qerr.ApplicationError
// VersionNegotiationError indicates a failure to negotiate a QUIC version.
VersionNegotiationError = qerr.VersionNegotiationError
// StatelessResetError indicates a stateless reset was received.
// This can happen when the peer reboots, or when packets are misrouted.
// See section 10.3 of RFC 9000 for details.
StatelessResetError = qerr.StatelessResetError
// IdleTimeoutError indicates that the connection timed out because it was inactive for too long.
IdleTimeoutError = qerr.IdleTimeoutError
// HandshakeTimeoutError indicates that the connection timed out before completing the handshake.
HandshakeTimeoutError = qerr.HandshakeTimeoutError
)
type (
// TransportErrorCode is a QUIC transport error code, see section 20 of RFC 9000.
TransportErrorCode = qerr.TransportErrorCode
// ApplicationErrorCode is a QUIC application error code.
ApplicationErrorCode = qerr.ApplicationErrorCode
// StreamErrorCode is a QUIC stream error code. The meaning of the value is defined by the application.
StreamErrorCode = qerr.StreamErrorCode
)
const (
// NoError is the NO_ERROR transport error code.
NoError = qerr.NoError
// InternalError is the INTERNAL_ERROR transport error code.
InternalError = qerr.InternalError
// ConnectionRefused is the CONNECTION_REFUSED transport error code.
ConnectionRefused = qerr.ConnectionRefused
// FlowControlError is the FLOW_CONTROL_ERROR transport error code.
FlowControlError = qerr.FlowControlError
// StreamLimitError is the STREAM_LIMIT_ERROR transport error code.
StreamLimitError = qerr.StreamLimitError
// StreamStateError is the STREAM_STATE_ERROR transport error code.
StreamStateError = qerr.StreamStateError
// FinalSizeError is the FINAL_SIZE_ERROR transport error code.
FinalSizeError = qerr.FinalSizeError
// FrameEncodingError is the FRAME_ENCODING_ERROR transport error code.
FrameEncodingError = qerr.FrameEncodingError
// TransportParameterError is the TRANSPORT_PARAMETER_ERROR transport error code.
TransportParameterError = qerr.TransportParameterError
// ConnectionIDLimitError is the CONNECTION_ID_LIMIT_ERROR transport error code.
ConnectionIDLimitError = qerr.ConnectionIDLimitError
// ProtocolViolation is the PROTOCOL_VIOLATION transport error code.
ProtocolViolation = qerr.ProtocolViolation
// InvalidToken is the INVALID_TOKEN transport error code.
InvalidToken = qerr.InvalidToken
// ApplicationErrorErrorCode is the APPLICATION_ERROR transport error code.
ApplicationErrorErrorCode = qerr.ApplicationErrorErrorCode
// CryptoBufferExceeded is the CRYPTO_BUFFER_EXCEEDED transport error code.
CryptoBufferExceeded = qerr.CryptoBufferExceeded
// KeyUpdateError is the KEY_UPDATE_ERROR transport error code.
KeyUpdateError = qerr.KeyUpdateError
// AEADLimitReached is the AEAD_LIMIT_REACHED transport error code.
AEADLimitReached = qerr.AEADLimitReached
// NoViablePathError is the NO_VIABLE_PATH_ERROR transport error code.
NoViablePathError = qerr.NoViablePathError
)
// A StreamError is used to signal stream cancellations.
// It is returned from the Read and Write methods of the [ReceiveStream], [SendStream] and [Stream].
type StreamError struct {
StreamID StreamID
ErrorCode StreamErrorCode
Remote bool
}
func (e *StreamError) Is(target error) bool {
t, ok := target.(*StreamError)
return ok && e.StreamID == t.StreamID && e.ErrorCode == t.ErrorCode && e.Remote == t.Remote
}
func (e *StreamError) Error() string {
pers := "local"
if e.Remote {
pers = "remote"
}
return fmt.Sprintf("stream %d canceled by %s with error code %d", e.StreamID, pers, e.ErrorCode)
}
// DatagramTooLargeError is returned from Conn.SendDatagram if the payload is too large to be sent.
type DatagramTooLargeError struct {
MaxDatagramPayloadSize int64
}
func (e *DatagramTooLargeError) Is(target error) bool {
t, ok := target.(*DatagramTooLargeError)
return ok && e.MaxDatagramPayloadSize == t.MaxDatagramPayloadSize
}
func (e *DatagramTooLargeError) Error() string { return "DATAGRAM frame too large" }
package quic
import (
"errors"
"sync"
"github.com/quic-go/quic-go/internal/protocol"
list "github.com/quic-go/quic-go/internal/utils/linkedlist"
)
// byteInterval is an interval from one ByteCount to another
type byteInterval struct {
Start protocol.ByteCount
End protocol.ByteCount
}
var byteIntervalElementPool sync.Pool
func init() {
byteIntervalElementPool = *list.NewPool[byteInterval]()
}
type frameSorterEntry struct {
Data []byte
DoneCb func()
}
type frameSorter struct {
queue map[protocol.ByteCount]frameSorterEntry
readPos protocol.ByteCount
gaps *list.List[byteInterval]
}
var errDuplicateStreamData = errors.New("duplicate stream data")
func newFrameSorter() *frameSorter {
s := frameSorter{
gaps: list.NewWithPool[byteInterval](&byteIntervalElementPool),
queue: make(map[protocol.ByteCount]frameSorterEntry),
}
s.gaps.PushFront(byteInterval{Start: 0, End: protocol.MaxByteCount})
return &s
}
func (s *frameSorter) Push(data []byte, offset protocol.ByteCount, doneCb func()) error {
err := s.push(data, offset, doneCb)
if err == errDuplicateStreamData {
if doneCb != nil {
doneCb()
}
return nil
}
return err
}
func (s *frameSorter) push(data []byte, offset protocol.ByteCount, doneCb func()) error {
if len(data) == 0 {
return errDuplicateStreamData
}
start := offset
end := offset + protocol.ByteCount(len(data))
if end <= s.gaps.Front().Value.Start {
return errDuplicateStreamData
}
startGap, startsInGap := s.findStartGap(start)
endGap, endsInGap := s.findEndGap(startGap, end)
startGapEqualsEndGap := startGap == endGap
if (startGapEqualsEndGap && end <= startGap.Value.Start) ||
(!startGapEqualsEndGap && startGap.Value.End >= endGap.Value.Start && end <= startGap.Value.Start) {
return errDuplicateStreamData
}
startGapNext := startGap.Next()
startGapEnd := startGap.Value.End // save it, in case startGap is modified
endGapStart := endGap.Value.Start // save it, in case endGap is modified
endGapEnd := endGap.Value.End // save it, in case endGap is modified
var adjustedStartGapEnd bool
var wasCut bool
pos := start
var hasReplacedAtLeastOne bool
for {
oldEntry, ok := s.queue[pos]
if !ok {
break
}
oldEntryLen := protocol.ByteCount(len(oldEntry.Data))
if end-pos > oldEntryLen || (hasReplacedAtLeastOne && end-pos == oldEntryLen) {
// The existing frame is shorter than the new frame. Replace it.
delete(s.queue, pos)
pos += oldEntryLen
hasReplacedAtLeastOne = true
if oldEntry.DoneCb != nil {
oldEntry.DoneCb()
}
} else {
if !hasReplacedAtLeastOne {
return errDuplicateStreamData
}
// The existing frame is longer than the new frame.
// Cut the new frame such that the end aligns with the start of the existing frame.
data = data[:pos-start]
end = pos
wasCut = true
break
}
}
if !startsInGap && !hasReplacedAtLeastOne {
// cut the frame, such that it starts at the start of the gap
data = data[startGap.Value.Start-start:]
start = startGap.Value.Start
wasCut = true
}
if start <= startGap.Value.Start {
if end >= startGap.Value.End {
// The frame covers the whole startGap. Delete the gap.
s.gaps.Remove(startGap)
} else {
startGap.Value.Start = end
}
} else if !hasReplacedAtLeastOne {
startGap.Value.End = start
adjustedStartGapEnd = true
}
if !startGapEqualsEndGap {
s.deleteConsecutive(startGapEnd)
var nextGap *list.Element[byteInterval]
for gap := startGapNext; gap.Value.End < endGapStart; gap = nextGap {
nextGap = gap.Next()
s.deleteConsecutive(gap.Value.End)
s.gaps.Remove(gap)
}
}
if !endsInGap && start != endGapEnd && end > endGapEnd {
// cut the frame, such that it ends at the end of the gap
data = data[:endGapEnd-start]
end = endGapEnd
wasCut = true
}
if end == endGapEnd {
if !startGapEqualsEndGap {
// The frame covers the whole endGap. Delete the gap.
s.gaps.Remove(endGap)
}
} else {
if startGapEqualsEndGap && adjustedStartGapEnd {
// The frame split the existing gap into two.
s.gaps.InsertAfter(byteInterval{Start: end, End: startGapEnd}, startGap)
} else if !startGapEqualsEndGap {
endGap.Value.Start = end
}
}
if wasCut && len(data) < protocol.MinStreamFrameBufferSize {
newData := make([]byte, len(data))
copy(newData, data)
data = newData
if doneCb != nil {
doneCb()
doneCb = nil
}
}
if s.gaps.Len() > protocol.MaxStreamFrameSorterGaps {
return errors.New("too many gaps in received data")
}
s.queue[start] = frameSorterEntry{Data: data, DoneCb: doneCb}
return nil
}
func (s *frameSorter) findStartGap(offset protocol.ByteCount) (*list.Element[byteInterval], bool) {
for gap := s.gaps.Front(); gap != nil; gap = gap.Next() {
if offset >= gap.Value.Start && offset <= gap.Value.End {
return gap, true
}
if offset < gap.Value.Start {
return gap, false
}
}
panic("no gap found")
}
func (s *frameSorter) findEndGap(startGap *list.Element[byteInterval], offset protocol.ByteCount) (*list.Element[byteInterval], bool) {
for gap := startGap; gap != nil; gap = gap.Next() {
if offset >= gap.Value.Start && offset < gap.Value.End {
return gap, true
}
if offset < gap.Value.Start {
return gap.Prev(), false
}
}
panic("no gap found")
}
// deleteConsecutive deletes consecutive frames from the queue, starting at pos
func (s *frameSorter) deleteConsecutive(pos protocol.ByteCount) {
for {
oldEntry, ok := s.queue[pos]
if !ok {
break
}
oldEntryLen := protocol.ByteCount(len(oldEntry.Data))
delete(s.queue, pos)
if oldEntry.DoneCb != nil {
oldEntry.DoneCb()
}
pos += oldEntryLen
}
}
func (s *frameSorter) Pop() (protocol.ByteCount, []byte, func()) {
entry, ok := s.queue[s.readPos]
if !ok {
return s.readPos, nil, nil
}
delete(s.queue, s.readPos)
offset := s.readPos
s.readPos += protocol.ByteCount(len(entry.Data))
if s.gaps.Front().Value.End <= s.readPos {
panic("frame sorter BUG: read position higher than a gap")
}
return offset, entry.Data, entry.DoneCb
}
// HasMoreData says if there is any more data queued at *any* offset.
func (s *frameSorter) HasMoreData() bool {
return len(s.queue) > 0
}
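// Illustrative sketch, not part of the original source: how the frame sorter reassembles
// out-of-order data. Push stores data at its stream offset; Pop only returns data once it is
// contiguous with the current read position. The function name is hypothetical.
func exampleFrameSorter() {
	s := newFrameSorter()
	_ = s.Push([]byte("world"), 5, nil) // out of order: not readable yet
	if _, data, _ := s.Pop(); data != nil {
		panic("unexpected: offset 0 has not been received yet")
	}
	_ = s.Push([]byte("hello"), 0, nil) // fills the gap at offset 0
	_, first, _ := s.Pop()              // returns offset 0 and "hello"
	_, second, _ := s.Pop()             // returns offset 5 and "world"
	_, _ = first, second
}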
package quic
import (
"slices"
"sync"
"time"
"github.com/quic-go/quic-go/internal/ackhandler"
"github.com/quic-go/quic-go/internal/flowcontrol"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils/ringbuffer"
"github.com/quic-go/quic-go/internal/wire"
"github.com/quic-go/quic-go/quicvarint"
)
const (
maxPathResponses = 256
maxControlFrames = 16 << 10
)
// This is the largest possible size of a stream-related control frame
// (which is the RESET_STREAM frame).
const maxStreamControlFrameSize = 25
type streamFrameGetter interface {
popStreamFrame(protocol.ByteCount, protocol.Version) (ackhandler.StreamFrame, *wire.StreamDataBlockedFrame, bool)
}
type streamControlFrameGetter interface {
getControlFrame(time.Time) (_ ackhandler.Frame, ok, hasMore bool)
}
type framer struct {
mutex sync.Mutex
activeStreams map[protocol.StreamID]streamFrameGetter
streamQueue ringbuffer.RingBuffer[protocol.StreamID]
streamsWithControlFrames map[protocol.StreamID]streamControlFrameGetter
controlFrameMutex sync.Mutex
controlFrames []wire.Frame
pathResponses []*wire.PathResponseFrame
connFlowController flowcontrol.ConnectionFlowController
queuedTooManyControlFrames bool
}
func newFramer(connFlowController flowcontrol.ConnectionFlowController) *framer {
return &framer{
activeStreams: make(map[protocol.StreamID]streamFrameGetter),
streamsWithControlFrames: make(map[protocol.StreamID]streamControlFrameGetter),
connFlowController: connFlowController,
}
}
func (f *framer) HasData() bool {
f.mutex.Lock()
hasData := !f.streamQueue.Empty()
f.mutex.Unlock()
if hasData {
return true
}
f.controlFrameMutex.Lock()
defer f.controlFrameMutex.Unlock()
return len(f.streamsWithControlFrames) > 0 || len(f.controlFrames) > 0 || len(f.pathResponses) > 0
}
func (f *framer) QueueControlFrame(frame wire.Frame) {
f.controlFrameMutex.Lock()
defer f.controlFrameMutex.Unlock()
if pr, ok := frame.(*wire.PathResponseFrame); ok {
// Only queue up to maxPathResponses PATH_RESPONSE frames.
// This limit should be high enough to never be hit in practice,
// unless the peer is doing something malicious.
if len(f.pathResponses) >= maxPathResponses {
return
}
f.pathResponses = append(f.pathResponses, pr)
return
}
// This is a hack.
if len(f.controlFrames) >= maxControlFrames {
f.queuedTooManyControlFrames = true
return
}
f.controlFrames = append(f.controlFrames, frame)
}
func (f *framer) Append(
frames []ackhandler.Frame,
streamFrames []ackhandler.StreamFrame,
maxLen protocol.ByteCount,
now time.Time,
v protocol.Version,
) ([]ackhandler.Frame, []ackhandler.StreamFrame, protocol.ByteCount) {
f.controlFrameMutex.Lock()
frames, controlFrameLen := f.appendControlFrames(frames, maxLen, now, v)
maxLen -= controlFrameLen
var lastFrame ackhandler.StreamFrame
var streamFrameLen protocol.ByteCount
f.mutex.Lock()
// pop STREAM frames, until less than 128 bytes are left in the packet
numActiveStreams := f.streamQueue.Len()
for i := 0; i < numActiveStreams; i++ {
if protocol.MinStreamFrameSize > maxLen {
break
}
sf, blocked := f.getNextStreamFrame(maxLen, v)
if sf.Frame != nil {
streamFrames = append(streamFrames, sf)
maxLen -= sf.Frame.Length(v)
lastFrame = sf
streamFrameLen += sf.Frame.Length(v)
}
// If the stream just became blocked on stream flow control, attempt to pack the
// STREAM_DATA_BLOCKED into the same packet.
if blocked != nil {
l := blocked.Length(v)
// In case it doesn't fit, queue it for the next packet.
if maxLen < l {
f.controlFrames = append(f.controlFrames, blocked)
break
}
frames = append(frames, ackhandler.Frame{Frame: blocked})
maxLen -= l
controlFrameLen += l
}
}
// The only way to become blocked on connection-level flow control is by sending STREAM frames.
if isBlocked, offset := f.connFlowController.IsNewlyBlocked(); isBlocked {
blocked := &wire.DataBlockedFrame{MaximumData: offset}
l := blocked.Length(v)
// In case it doesn't fit, queue it for the next packet.
if maxLen >= l {
frames = append(frames, ackhandler.Frame{Frame: blocked})
controlFrameLen += l
} else {
f.controlFrames = append(f.controlFrames, blocked)
}
}
f.mutex.Unlock()
f.controlFrameMutex.Unlock()
if lastFrame.Frame != nil {
// account for the smaller size of the last STREAM frame
streamFrameLen -= lastFrame.Frame.Length(v)
lastFrame.Frame.DataLenPresent = false
streamFrameLen += lastFrame.Frame.Length(v)
}
return frames, streamFrames, controlFrameLen + streamFrameLen
}
func (f *framer) appendControlFrames(
frames []ackhandler.Frame,
maxLen protocol.ByteCount,
now time.Time,
v protocol.Version,
) ([]ackhandler.Frame, protocol.ByteCount) {
var length protocol.ByteCount
// add a PATH_RESPONSE first, but only pack a single PATH_RESPONSE per packet
if len(f.pathResponses) > 0 {
frame := f.pathResponses[0]
frameLen := frame.Length(v)
if frameLen <= maxLen {
frames = append(frames, ackhandler.Frame{Frame: frame})
length += frameLen
f.pathResponses = f.pathResponses[1:]
}
}
// add stream-related control frames
for id, str := range f.streamsWithControlFrames {
start:
remainingLen := maxLen - length
if remainingLen <= maxStreamControlFrameSize {
break
}
fr, ok, hasMore := str.getControlFrame(now)
if !hasMore {
delete(f.streamsWithControlFrames, id)
}
if !ok {
continue
}
frames = append(frames, fr)
length += fr.Frame.Length(v)
if hasMore {
// It is rare that a stream has more than one control frame to queue.
// We don't want to spawn another loop just to cover that case.
goto start
}
}
for len(f.controlFrames) > 0 {
frame := f.controlFrames[len(f.controlFrames)-1]
frameLen := frame.Length(v)
if length+frameLen > maxLen {
break
}
frames = append(frames, ackhandler.Frame{Frame: frame})
length += frameLen
f.controlFrames = f.controlFrames[:len(f.controlFrames)-1]
}
return frames, length
}
// QueuedTooManyControlFrames says if the control frame queue exceeded its maximum queue length.
// This is a hack.
// It is easier to implement than propagating an error return value in QueueControlFrame.
// The correct solution would be to queue frames with their respective structs.
// See https://github.com/quic-go/quic-go/issues/4271 for the queueing of stream-related control frames.
func (f *framer) QueuedTooManyControlFrames() bool {
return f.queuedTooManyControlFrames
}
func (f *framer) AddActiveStream(id protocol.StreamID, str streamFrameGetter) {
f.mutex.Lock()
if _, ok := f.activeStreams[id]; !ok {
f.streamQueue.PushBack(id)
f.activeStreams[id] = str
}
f.mutex.Unlock()
}
func (f *framer) AddStreamWithControlFrames(id protocol.StreamID, str streamControlFrameGetter) {
f.controlFrameMutex.Lock()
if _, ok := f.streamsWithControlFrames[id]; !ok {
f.streamsWithControlFrames[id] = str
}
f.controlFrameMutex.Unlock()
}
// RemoveActiveStream is called when a stream completes.
func (f *framer) RemoveActiveStream(id protocol.StreamID) {
f.mutex.Lock()
delete(f.activeStreams, id)
// We don't delete the stream from the streamQueue,
// since we'd have to iterate over the ringbuffer.
// Instead, we check if the stream is still in activeStreams when appending STREAM frames.
f.mutex.Unlock()
}
func (f *framer) getNextStreamFrame(maxLen protocol.ByteCount, v protocol.Version) (ackhandler.StreamFrame, *wire.StreamDataBlockedFrame) {
id := f.streamQueue.PopFront()
// The lookup should usually succeed: the stream will only be in the streamQueue
// if it enqueued itself there.
str, ok := f.activeStreams[id]
// The stream might have been removed after being enqueued.
if !ok {
return ackhandler.StreamFrame{}, nil
}
// For the last STREAM frame, we'll remove the DataLen field later.
// Therefore, we can pretend to have more bytes available when popping
// the STREAM frame (which will always have the DataLen set).
maxLen += protocol.ByteCount(quicvarint.Len(uint64(maxLen)))
frame, blocked, hasMoreData := str.popStreamFrame(maxLen, v)
if hasMoreData { // put the stream back in the queue (at the end)
f.streamQueue.PushBack(id)
} else { // no more data to send. Stream is not active
delete(f.activeStreams, id)
}
// Note that the frame.Frame can be nil:
// * if the stream was canceled after it said it had data
// * if the remaining size doesn't allow us to add another STREAM frame
return frame, blocked
}
func (f *framer) Handle0RTTRejection() {
f.mutex.Lock()
defer f.mutex.Unlock()
f.controlFrameMutex.Lock()
defer f.controlFrameMutex.Unlock()
f.streamQueue.Clear()
for id := range f.activeStreams {
delete(f.activeStreams, id)
}
var j int
for i, frame := range f.controlFrames {
switch frame.(type) {
case *wire.MaxDataFrame, *wire.MaxStreamDataFrame, *wire.MaxStreamsFrame,
*wire.DataBlockedFrame, *wire.StreamDataBlockedFrame, *wire.StreamsBlockedFrame:
continue
default:
f.controlFrames[j] = f.controlFrames[i]
j++
}
}
f.controlFrames = slices.Delete(f.controlFrames, j, len(f.controlFrames))
}
package frames
import (
"fmt"
"io"
"github.com/quic-go/quic-go/internal/ackhandler"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/wire"
)
const version = protocol.Version1
// PrefixLen is the number of bytes used for configuration
const PrefixLen = 1
func toEncLevel(v uint8) protocol.EncryptionLevel {
switch v % 3 {
default:
return protocol.EncryptionInitial
case 1:
return protocol.EncryptionHandshake
case 2:
return protocol.Encryption1RTT
}
}
// Fuzz fuzzes the QUIC frames.
//
//go:generate go run ./cmd/corpus.go
func Fuzz(data []byte) int {
if len(data) < PrefixLen {
return 0
}
encLevel := toEncLevel(data[0])
data = data[PrefixLen:]
parser := wire.NewFrameParser(true, true, true)
parser.SetAckDelayExponent(protocol.DefaultAckDelayExponent)
var numFrames int
var b []byte
for len(data) > 0 {
initialLen := len(data)
frameType, l, err := parser.ParseType(data, encLevel)
if err != nil {
if err == io.EOF { // the rest of the packet was PADDING
break
}
break // any other parse error also ends the frame loop
}
data = data[l:]
numFrames++
var f wire.Frame
switch {
case frameType.IsStreamFrameType():
f, l, err = parser.ParseStreamFrame(frameType, data, version)
case frameType == wire.FrameTypeAck || frameType == wire.FrameTypeAckECN:
f, l, err = parser.ParseAckFrame(frameType, data, encLevel, version)
case frameType == wire.FrameTypeDatagramNoLength || frameType == wire.FrameTypeDatagramWithLength:
f, l, err = parser.ParseDatagramFrame(frameType, data, version)
default:
f, l, err = parser.ParseLessCommonFrame(frameType, data, version)
}
if err != nil {
break
}
data = data[l:]
wire.IsProbingFrame(f)
ackhandler.IsFrameAckEliciting(f)
// We accept empty STREAM frames, but we don't write them.
if sf, ok := f.(*wire.StreamFrame); ok {
if sf.DataLen() == 0 {
sf.PutBack()
continue
}
}
validateFrame(f)
startLen := len(b)
parsedLen := initialLen - len(data)
b, err = f.Append(b, version)
if err != nil {
panic(fmt.Sprintf("error writing frame %#v: %s", f, err))
}
frameLen := protocol.ByteCount(len(b) - startLen)
if f.Length(version) != frameLen {
panic(fmt.Sprintf("inconsistent frame length for %#v: expected %d, got %d", f, frameLen, f.Length(version)))
}
if sf, ok := f.(*wire.StreamFrame); ok {
sf.PutBack()
}
if frameLen > protocol.ByteCount(parsedLen) {
panic(fmt.Sprintf("serialized length (%d) is longer than parsed length (%d)", len(b), parsedLen))
}
}
if numFrames == 0 {
return 0
}
return 1
}
func validateFrame(frame wire.Frame) {
switch f := frame.(type) {
case *wire.StreamFrame:
if protocol.ByteCount(len(f.Data)) != f.DataLen() {
panic("STREAM frame: inconsistent data length")
}
case *wire.AckFrame:
if f.DelayTime < 0 {
panic(fmt.Sprintf("invalid ACK delay_time: %s", f.DelayTime))
}
if f.LargestAcked() < f.LowestAcked() {
panic("ACK: largest acknowledged is smaller than lowest acknowledged")
}
for _, r := range f.AckRanges {
if r.Largest < 0 || r.Smallest < 0 {
panic("ACK range contains a negative packet number")
}
}
if !f.AcksPacket(f.LargestAcked()) {
panic("ACK frame claims that largest acknowledged is not acknowledged")
}
if !f.AcksPacket(f.LowestAcked()) {
panic("ACK frame claims that lowest acknowledged is not acknowledged")
}
_ = f.AcksPacket(100)
_ = f.AcksPacket((f.LargestAcked() + f.LowestAcked()) / 2)
case *wire.NewConnectionIDFrame:
if f.ConnectionID.Len() < 1 || f.ConnectionID.Len() > 20 {
panic(fmt.Sprintf("invalid NEW_CONNECTION_ID frame length: %s", f.ConnectionID))
}
case *wire.NewTokenFrame:
if len(f.Token) == 0 {
panic("NEW_TOKEN frame with an empty token")
}
case *wire.MaxStreamsFrame:
if f.MaxStreamNum > protocol.MaxStreamCount {
panic("MAX_STREAMS frame with an invalid Maximum Streams value")
}
case *wire.StreamsBlockedFrame:
if f.StreamLimit > protocol.MaxStreamCount {
panic("STREAMS_BLOCKED frame with an invalid Maximum Streams value")
}
case *wire.ConnectionCloseFrame:
if f.IsApplicationError && f.FrameType != 0 {
panic("CONNECTION_CLOSE for an application error containing a frame type")
}
case *wire.ResetStreamFrame:
if f.FinalSize < f.ReliableSize {
panic("RESET_STREAM frame with a FinalSize smaller than the ReliableSize")
}
case *wire.AckFrequencyFrame:
if f.RequestMaxAckDelay < 0 {
panic("ACK_FREQUENCY frame with a negative RequestMaxAckDelay")
}
}
}
package handshake
import (
"context"
"crypto/ed25519"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io"
"log"
"math"
mrand "math/rand/v2"
"net"
"time"
"github.com/quic-go/quic-go/fuzzing/internal/helper"
"github.com/quic-go/quic-go/internal/handshake"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qtls"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/internal/wire"
)
var (
cert, clientCert *tls.Certificate
certPool, clientCertPool *x509.CertPool
sessionTicketKey = [32]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32}
)
func init() {
_, priv, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
log.Fatal(err)
}
cert, certPool, err = helper.GenerateCertificate(priv)
if err != nil {
log.Fatal(err)
}
_, privClient, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
log.Fatal(err)
}
clientCert, clientCertPool, err = helper.GenerateCertificate(privClient)
if err != nil {
log.Fatal(err)
}
}
type messageType uint8
// TLS handshake message types.
const (
typeClientHello messageType = 1
typeServerHello messageType = 2
typeNewSessionTicket messageType = 4
typeEncryptedExtensions messageType = 8
typeCertificate messageType = 11
typeCertificateRequest messageType = 13
typeCertificateVerify messageType = 15
typeFinished messageType = 20
)
func (m messageType) String() string {
switch m {
case typeClientHello:
return "ClientHello"
case typeServerHello:
return "ServerHello"
case typeNewSessionTicket:
return "NewSessionTicket"
case typeEncryptedExtensions:
return "EncryptedExtensions"
case typeCertificate:
return "Certificate"
case typeCertificateRequest:
return "CertificateRequest"
case typeCertificateVerify:
return "CertificateVerify"
case typeFinished:
return "Finished"
default:
return fmt.Sprintf("unknown message type: %d", m)
}
}
// getClientAuth consumes 3 bits of randomness to select a tls.ClientAuthType.
func getClientAuth(rand uint8) tls.ClientAuthType {
switch rand {
default:
return tls.NoClientCert
case 0:
return tls.RequestClientCert
case 1:
return tls.RequireAnyClientCert
case 2:
return tls.VerifyClientCertIfGiven
case 3:
return tls.RequireAndVerifyClientCert
}
}
const (
alpn = "fuzzing"
alpnWrong = "wrong"
)
func toEncryptionLevel(n uint8) protocol.EncryptionLevel {
switch n % 3 {
default:
return protocol.EncryptionInitial
case 1:
return protocol.EncryptionHandshake
case 2:
return protocol.Encryption1RTT
}
}
func getTransportParameters(seed uint8) *wire.TransportParameters {
const maxVarInt = math.MaxUint64 / 4
r := mrand.New(mrand.NewPCG(uint64(seed), uint64(seed)))
return &wire.TransportParameters{
ActiveConnectionIDLimit: 2,
InitialMaxData: protocol.ByteCount(r.Uint64() % maxVarInt),
InitialMaxStreamDataBidiLocal: protocol.ByteCount(r.Uint64() % maxVarInt),
InitialMaxStreamDataBidiRemote: protocol.ByteCount(r.Uint64() % maxVarInt),
InitialMaxStreamDataUni: protocol.ByteCount(r.Uint64() % maxVarInt),
}
}
// PrefixLen is the number of bytes used for configuration
const (
PrefixLen = 12
confLen = 5
)
// Fuzz fuzzes the TLS 1.3 handshake used by QUIC.
//
//go:generate go run ./cmd/corpus.go
func Fuzz(data []byte) int {
if len(data) < PrefixLen {
return -1
}
dataLen := len(data)
var runConfig1, runConfig2 [confLen]byte
copy(runConfig1[:], data)
data = data[confLen:]
messageConfig1 := data[0]
data = data[1:]
copy(runConfig2[:], data)
data = data[confLen:]
messageConfig2 := data[0]
data = data[1:]
if dataLen != len(data)+PrefixLen {
panic("incorrect configuration")
}
clientConf := &tls.Config{
MinVersion: tls.VersionTLS13,
ServerName: "localhost",
NextProtos: []string{alpn},
RootCAs: certPool,
}
useSessionTicketCache := helper.NthBit(runConfig1[0], 2)
if useSessionTicketCache {
clientConf.ClientSessionCache = tls.NewLRUClientSessionCache(5)
}
if val := runHandshake(runConfig1, messageConfig1, clientConf, data); val != 1 {
return val
}
return runHandshake(runConfig2, messageConfig2, clientConf, data)
}
func runHandshake(runConfig [confLen]byte, messageConfig uint8, clientConf *tls.Config, data []byte) int {
serverConf := &tls.Config{
MinVersion: tls.VersionTLS13,
Certificates: []tls.Certificate{*cert},
NextProtos: []string{alpn},
SessionTicketKey: sessionTicketKey,
}
// This sets the cipher suite for both client and server.
// The way crypto/tls is designed doesn't allow us to set different cipher suites for client and server.
resetCipherSuite := func() {}
switch (runConfig[0] >> 6) % 4 {
case 0:
resetCipherSuite = qtls.SetCipherSuite(tls.TLS_AES_128_GCM_SHA256)
case 1:
resetCipherSuite = qtls.SetCipherSuite(tls.TLS_AES_256_GCM_SHA384)
case 3:
resetCipherSuite = qtls.SetCipherSuite(tls.TLS_CHACHA20_POLY1305_SHA256)
default:
}
defer resetCipherSuite()
enable0RTTClient := helper.NthBit(runConfig[0], 0)
enable0RTTServer := helper.NthBit(runConfig[0], 1)
sendPostHandshakeMessageToClient := helper.NthBit(runConfig[0], 3)
sendPostHandshakeMessageToServer := helper.NthBit(runConfig[0], 4)
sendSessionTicket := helper.NthBit(runConfig[0], 5)
serverConf.ClientAuth = getClientAuth(runConfig[1] & 0b00000111)
serverConf.SessionTicketsDisabled = helper.NthBit(runConfig[1], 3)
if helper.NthBit(runConfig[2], 0) {
clientConf.RootCAs = x509.NewCertPool()
}
if helper.NthBit(runConfig[2], 1) {
serverConf.ClientCAs = clientCertPool
} else {
serverConf.ClientCAs = x509.NewCertPool()
}
if helper.NthBit(runConfig[2], 2) {
serverConf.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) {
if helper.NthBit(runConfig[2], 3) {
return nil, errors.New("getting client config failed")
}
if helper.NthBit(runConfig[2], 4) {
return nil, nil
}
return serverConf, nil
}
}
if helper.NthBit(runConfig[2], 5) {
serverConf.GetCertificate = func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
if helper.NthBit(runConfig[2], 6) {
return nil, errors.New("getting certificate failed")
}
if helper.NthBit(runConfig[2], 7) {
return nil, nil
}
return clientCert, nil // this certificate will be invalid
}
}
if helper.NthBit(runConfig[3], 0) {
serverConf.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
if helper.NthBit(runConfig[3], 1) {
return errors.New("certificate verification failed")
}
return nil
}
}
if helper.NthBit(runConfig[3], 2) {
clientConf.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
if helper.NthBit(runConfig[3], 3) {
return errors.New("certificate verification failed")
}
return nil
}
}
if helper.NthBit(runConfig[3], 4) {
serverConf.NextProtos = []string{alpnWrong}
}
if helper.NthBit(runConfig[3], 5) {
serverConf.NextProtos = []string{alpnWrong, alpn}
}
if helper.NthBit(runConfig[3], 6) {
serverConf.KeyLogWriter = io.Discard
}
if helper.NthBit(runConfig[3], 7) {
clientConf.KeyLogWriter = io.Discard
}
clientTP := getTransportParameters(runConfig[4] & 0x3)
if helper.NthBit(runConfig[4], 3) {
clientTP.MaxAckDelay = protocol.MaxMaxAckDelay + 5
}
serverTP := getTransportParameters(runConfig[4] & 0b00011000)
if helper.NthBit(runConfig[4], 3) {
serverTP.MaxAckDelay = protocol.MaxMaxAckDelay + 5
}
messageToReplace := messageConfig % 32
messageToReplaceEncLevel := toEncryptionLevel(messageConfig >> 6)
if len(data) == 0 {
return -1
}
client := handshake.NewCryptoSetupClient(
protocol.ConnectionID{},
clientTP,
clientConf,
enable0RTTClient,
&utils.RTTStats{},
nil,
utils.DefaultLogger.WithPrefix("client"),
protocol.Version1,
)
if err := client.StartHandshake(context.Background()); err != nil {
log.Fatal(err)
}
defer client.Close()
server := handshake.NewCryptoSetupServer(
protocol.ConnectionID{},
&net.UDPAddr{IP: net.IPv6loopback, Port: 1234},
&net.UDPAddr{IP: net.IPv6loopback, Port: 4321},
serverTP,
serverConf,
enable0RTTServer,
&utils.RTTStats{},
nil,
utils.DefaultLogger.WithPrefix("server"),
protocol.Version1,
)
if err := server.StartHandshake(context.Background()); err != nil {
log.Fatal(err)
}
defer server.Close()
var clientHandshakeComplete, serverHandshakeComplete bool
for {
var processedEvent bool
clientLoop:
for {
ev := client.NextEvent()
//nolint:exhaustive // only need to process a few events
switch ev.Kind {
case handshake.EventNoEvent:
if !processedEvent && !clientHandshakeComplete { // handshake stuck
return 1
}
break clientLoop
case handshake.EventWriteInitialData, handshake.EventWriteHandshakeData:
msg := ev.Data
encLevel := protocol.EncryptionInitial
if ev.Kind == handshake.EventWriteHandshakeData {
encLevel = protocol.EncryptionHandshake
}
if msg[0] == messageToReplace {
fmt.Printf("replacing %s message to the server with %s at %s\n", messageType(msg[0]), messageType(data[0]), messageToReplaceEncLevel)
msg = data
encLevel = messageToReplaceEncLevel
}
if err := server.HandleMessage(msg, encLevel); err != nil {
return 1
}
case handshake.EventHandshakeComplete:
clientHandshakeComplete = true
}
processedEvent = true
}
processedEvent = false
serverLoop:
for {
ev := server.NextEvent()
//nolint:exhaustive // only need to process a few events
switch ev.Kind {
case handshake.EventNoEvent:
if !processedEvent && !serverHandshakeComplete { // handshake stuck
return 1
}
break serverLoop
case handshake.EventWriteInitialData, handshake.EventWriteHandshakeData:
encLevel := protocol.EncryptionInitial
if ev.Kind == handshake.EventWriteHandshakeData {
encLevel = protocol.EncryptionHandshake
}
msg := ev.Data
if msg[0] == messageToReplace {
fmt.Printf("replacing %s message to the client with %s at %s\n", messageType(msg[0]), messageType(data[0]), messageToReplaceEncLevel)
msg = data
encLevel = messageToReplaceEncLevel
}
if err := client.HandleMessage(msg, encLevel); err != nil {
return 1
}
case handshake.EventHandshakeComplete:
serverHandshakeComplete = true
}
processedEvent = true
}
if serverHandshakeComplete && clientHandshakeComplete {
break
}
}
_ = client.ConnectionState()
_ = server.ConnectionState()
sealer, err := client.Get1RTTSealer()
if err != nil {
panic("expected to get a 1-RTT sealer")
}
opener, err := server.Get1RTTOpener()
if err != nil {
panic("expected to get a 1-RTT opener")
}
const msg = "Lorem ipsum dolor sit amet, consectetur adipiscing elit."
encrypted := sealer.Seal(nil, []byte(msg), 1337, []byte("foobar"))
decrypted, err := opener.Open(nil, encrypted, time.Time{}, 1337, protocol.KeyPhaseZero, []byte("foobar"))
if err != nil {
panic(fmt.Sprintf("Decrypting message failed: %s", err.Error()))
}
if string(decrypted) != msg {
panic("wrong message")
}
if sendSessionTicket && !serverConf.SessionTicketsDisabled {
ticket, err := server.GetSessionTicket()
if err != nil {
panic(err)
}
if ticket == nil {
panic("empty ticket")
}
client.HandleMessage(ticket, protocol.Encryption1RTT)
}
if sendPostHandshakeMessageToClient {
fmt.Println("sending post handshake message to the client at", messageToReplaceEncLevel)
client.HandleMessage(data, messageToReplaceEncLevel)
}
if sendPostHandshakeMessageToServer {
fmt.Println("sending post handshake message to the server at", messageToReplaceEncLevel)
server.HandleMessage(data, messageToReplaceEncLevel)
}
return 1
}
package header
import (
"bytes"
"fmt"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/wire"
)
const version = protocol.Version1
// PrefixLen is the number of bytes used for configuration
const PrefixLen = 1
// Fuzz fuzzes the QUIC header.
//
//go:generate go run ./cmd/corpus.go
func Fuzz(data []byte) int {
if len(data) < PrefixLen {
return 0
}
connIDLen := int(data[0] % 21)
data = data[PrefixLen:]
if wire.IsVersionNegotiationPacket(data) {
return fuzzVNP(data)
}
connID, err := wire.ParseConnectionID(data, connIDLen)
if err != nil {
return 0
}
if !wire.IsLongHeaderPacket(data[0]) {
wire.ParseShortHeader(data, connIDLen)
return 1
}
is0RTTPacket := wire.Is0RTTPacket(data)
hdr, _, _, err := wire.ParsePacket(data)
if err != nil {
return 0
}
if hdr.DestConnectionID != connID {
panic(fmt.Sprintf("Expected connection IDs to match: %s vs %s", hdr.DestConnectionID, connID))
}
if (hdr.Type == protocol.PacketType0RTT) != is0RTTPacket {
panic("inconsistent 0-RTT packet detection")
}
var extHdr *wire.ExtendedHeader
// Parse the extended header, if this is not a Retry packet.
if hdr.Type == protocol.PacketTypeRetry {
extHdr = &wire.ExtendedHeader{Header: *hdr}
} else {
var err error
extHdr, err = hdr.ParseExtended(data)
if err != nil {
return 0
}
}
// We always use a 2-byte encoding for the Length field in Long Header packets.
// Serializing the header will fail when using a higher value.
if hdr.Length > 16383 {
return 1
}
b, err := extHdr.Append(nil, version)
if err != nil {
// We are able to parse packets with connection IDs longer than 20 bytes,
// but in QUIC version 1, we don't write headers with longer connection IDs.
if hdr.DestConnectionID.Len() <= protocol.MaxConnIDLen &&
hdr.SrcConnectionID.Len() <= protocol.MaxConnIDLen {
panic(err)
}
return 0
}
// GetLength is not implemented for Retry packets
if hdr.Type != protocol.PacketTypeRetry {
if expLen := extHdr.GetLength(version); expLen != protocol.ByteCount(len(b)) {
panic(fmt.Sprintf("inconsistent header length: %#v. Expected %d, got %d", extHdr, expLen, len(b)))
}
}
return 1
}
func fuzzVNP(data []byte) int {
connID, err := wire.ParseConnectionID(data, 0)
if err != nil {
return 0
}
dest, src, versions, err := wire.ParseVersionNegotiationPacket(data)
if err != nil {
return 0
}
if !bytes.Equal(dest, connID.Bytes()) {
panic("connection IDs don't match")
}
if len(versions) == 0 {
panic("no versions")
}
wire.ComposeVersionNegotiation(src, dest, versions)
return 1
}
package helper
import (
"crypto"
"crypto/rand"
"crypto/sha1"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/hex"
"math/big"
"os"
"path/filepath"
"time"
)
// NthBit gets the n-th bit of a byte (counting starts at 0).
func NthBit(val uint8, n int) bool {
if n < 0 || n > 7 {
panic("invalid value for n")
}
return val>>n&0x1 == 1
}
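// nthBitExample is an illustrative sketch (it is not used by any fuzzer): it shows how a
// single configuration byte is split into independent boolean switches using NthBit.
func nthBitExample(conf uint8) (lowest, highest bool) {
	lowest = NthBit(conf, 0)  // least significant bit
	highest = NthBit(conf, 7) // most significant bit
	return lowest, highest
}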
// WriteCorpusFile writes data to a corpus file in directory path.
// The filename is calculated from the SHA1 sum of the file contents.
func WriteCorpusFile(path string, data []byte) error {
// create the directory, if it doesn't exist yet
if _, err := os.Stat(path); os.IsNotExist(err) {
if err := os.MkdirAll(path, os.ModePerm); err != nil {
return err
}
}
hash := sha1.Sum(data)
return os.WriteFile(filepath.Join(path, hex.EncodeToString(hash[:])), data, 0o644)
}
// WriteCorpusFileWithPrefix writes data to a corpus file in directory path.
// In many fuzzers, the first n bytes are used for configuration.
// This function prepends n zero-bytes to the data.
func WriteCorpusFileWithPrefix(path string, data []byte, n int) error {
return WriteCorpusFile(path, append(make([]byte, n), data...))
}
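// Illustrative usage (path and prefix length assumed): WriteCorpusFileWithPrefix("corpus", b, 1)
// writes b preceded by a single zero byte, so a fuzzer with a 1-byte configuration prefix
// replays the corpus entry with its default configuration.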
// GenerateCertificate generates a self-signed certificate.
// It returns the certificate and a x509.CertPool containing that certificate.
func GenerateCertificate(priv crypto.Signer) (*tls.Certificate, *x509.CertPool, error) {
template := x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{Organization: []string{"quic-go fuzzer"}},
NotBefore: time.Now().Add(-24 * time.Hour),
NotAfter: time.Now().Add(30 * 24 * time.Hour),
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
DNSNames: []string{"localhost"},
BasicConstraintsValid: true,
}
derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, priv.Public(), priv)
if err != nil {
return nil, nil, err
}
cert, err := x509.ParseCertificate(derBytes)
if err != nil {
return nil, nil, err
}
certPool := x509.NewCertPool()
certPool.AddCert(cert)
return &tls.Certificate{
Certificate: [][]byte{derBytes},
PrivateKey: priv,
}, certPool, nil
}
package tokens
import (
"encoding/binary"
"net"
"time"
"github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/internal/handshake"
"github.com/quic-go/quic-go/internal/protocol"
)
func Fuzz(data []byte) int {
if len(data) < 32 {
return -1
}
var key quic.TokenGeneratorKey
copy(key[:], data[:32])
data = data[32:]
tg := handshake.NewTokenGenerator(key)
if len(data) < 1 {
return -1
}
s := data[0] % 3
data = data[1:]
switch s {
case 0:
tg.DecodeToken(data)
return 1
case 1:
return newToken(tg, data)
case 2:
return newRetryToken(tg, data)
}
return -1
}
func newToken(tg *handshake.TokenGenerator, data []byte) int {
if len(data) < 1 {
return -1
}
usesUDPAddr := data[0]%2 == 0
data = data[1:]
if len(data) < 18 {
return -1
}
var addr net.Addr
if usesUDPAddr {
addr = &net.UDPAddr{
Port: int(binary.BigEndian.Uint16(data[:2])),
IP: net.IP(data[2:18]),
}
} else {
addr = &net.TCPAddr{
Port: int(binary.BigEndian.Uint16(data[:2])),
IP: net.IP(data[2:18]),
}
}
data = data[18:]
if len(data) < 1 {
return -1
}
start := time.Now()
encrypted, err := tg.NewToken(addr, time.Duration(data[0])*time.Millisecond)
if err != nil {
panic(err)
}
token, err := tg.DecodeToken(encrypted)
if err != nil {
panic(err)
}
if token.IsRetryToken {
panic("didn't encode a Retry token")
}
if token.SentTime.Before(start) || token.SentTime.After(time.Now()) {
panic("incorrect send time")
}
if token.OriginalDestConnectionID.Len() > 0 || token.RetrySrcConnectionID.Len() > 0 {
panic("didn't expect connection IDs")
}
return 1
}
func newRetryToken(tg *handshake.TokenGenerator, data []byte) int {
if len(data) < 2 {
return -1
}
origDestConnIDLen := int(data[0] % 21)
retrySrcConnIDLen := int(data[1] % 21)
data = data[2:]
if len(data) < origDestConnIDLen {
return -1
}
origDestConnID := protocol.ParseConnectionID(data[:origDestConnIDLen])
data = data[origDestConnIDLen:]
if len(data) < retrySrcConnIDLen {
return -1
}
retrySrcConnID := protocol.ParseConnectionID(data[:retrySrcConnIDLen])
data = data[retrySrcConnIDLen:]
if len(data) < 1 {
return -1
}
usesUDPAddr := data[0]%2 == 0
data = data[1:]
if len(data) != 18 {
return -1
}
start := time.Now()
var addr net.Addr
if usesUDPAddr {
addr = &net.UDPAddr{
Port: int(binary.BigEndian.Uint16(data[:2])),
IP: net.IP(data[2:]),
}
} else {
addr = &net.TCPAddr{
Port: int(binary.BigEndian.Uint16(data[:2])),
IP: net.IP(data[2:]),
}
}
encrypted, err := tg.NewRetryToken(addr, origDestConnID, retrySrcConnID)
if err != nil {
panic(err)
}
token, err := tg.DecodeToken(encrypted)
if err != nil {
panic(err)
}
if !token.IsRetryToken {
panic("expected a Retry token")
}
if token.SentTime.Before(start) || token.SentTime.After(time.Now()) {
panic("incorrect send time")
}
if token.OriginalDestConnectionID != origDestConnID {
panic("orig dest conn ID doesn't match")
}
if token.RetrySrcConnectionID != retrySrcConnID {
panic("retry src conn ID doesn't match")
}
return 1
}
package transportparameters
import (
"errors"
"fmt"
"github.com/quic-go/quic-go/fuzzing/internal/helper"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/wire"
)
// PrefixLen is the number of bytes used for configuration
const PrefixLen = 1
// Fuzz fuzzes the QUIC transport parameters.
//
//go:generate go run ./cmd/corpus.go
func Fuzz(data []byte) int {
if len(data) <= PrefixLen {
return 0
}
if helper.NthBit(data[0], 0) {
return fuzzTransportParametersForSessionTicket(data[PrefixLen:])
}
return fuzzTransportParameters(data[PrefixLen:], helper.NthBit(data[0], 1))
}
func fuzzTransportParameters(data []byte, sentByServer bool) int {
sentBy := protocol.PerspectiveClient
if sentByServer {
sentBy = protocol.PerspectiveServer
}
tp := &wire.TransportParameters{}
if err := tp.Unmarshal(data, sentBy); err != nil {
return 0
}
_ = tp.String()
if err := validateTransportParameters(tp, sentBy); err != nil {
panic(err)
}
tp2 := &wire.TransportParameters{}
if err := tp2.Unmarshal(tp.Marshal(sentBy), sentBy); err != nil {
fmt.Printf("%#v\n", tp)
panic(err)
}
if err := validateTransportParameters(tp2, sentBy); err != nil {
panic(err)
}
return 1
}
func fuzzTransportParametersForSessionTicket(data []byte) int {
tp := &wire.TransportParameters{}
if err := tp.UnmarshalFromSessionTicket(data); err != nil {
return 0
}
b := tp.MarshalForSessionTicket(nil)
tp2 := &wire.TransportParameters{}
if err := tp2.UnmarshalFromSessionTicket(b); err != nil {
panic(err)
}
return 1
}
func validateTransportParameters(tp *wire.TransportParameters, sentBy protocol.Perspective) error {
if sentBy == protocol.PerspectiveClient && tp.StatelessResetToken != nil {
return errors.New("client's transport parameters contained stateless reset token")
}
if tp.MaxIdleTimeout < 0 {
return fmt.Errorf("negative max_idle_timeout: %s", tp.MaxIdleTimeout)
}
if tp.AckDelayExponent > 20 {
return fmt.Errorf("invalid ack_delay_exponent: %d", tp.AckDelayExponent)
}
if tp.MaxUDPPayloadSize < 1200 {
return fmt.Errorf("invalid max_udp_payload_size: %d", tp.MaxUDPPayloadSize)
}
if tp.ActiveConnectionIDLimit < 2 {
return fmt.Errorf("invalid active_connection_id_limit: %d", tp.ActiveConnectionIDLimit)
}
if tp.OriginalDestinationConnectionID.Len() > 20 {
return fmt.Errorf("invalid original_destination_connection_id length: %s", tp.InitialSourceConnectionID)
}
if tp.InitialSourceConnectionID.Len() > 20 {
return fmt.Errorf("invalid initial_source_connection_id length: %s", tp.InitialSourceConnectionID)
}
if tp.RetrySourceConnectionID != nil && tp.RetrySourceConnectionID.Len() > 20 {
return fmt.Errorf("invalid retry_source_connection_id length: %s", tp.RetrySourceConnectionID)
}
if tp.PreferredAddress != nil && tp.PreferredAddress.ConnectionID.Len() > 20 {
return fmt.Errorf("invalid preferred_address connection ID length: %s", tp.PreferredAddress.ConnectionID)
}
return nil
}
package quic
import (
"context"
"crypto/tls"
"errors"
"net"
"slices"
"time"
"github.com/quic-go/quic-go/internal/handshake"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/logging"
)
// The StreamID is the ID of a QUIC stream.
type StreamID = protocol.StreamID
// A Version is a QUIC version number.
type Version = protocol.Version
const (
// Version1 is RFC 9000
Version1 = protocol.Version1
// Version2 is RFC 9369
Version2 = protocol.Version2
)
// SupportedVersions returns the supported versions, sorted in descending order of preference.
func SupportedVersions() []Version {
// clone the slice to prevent the caller from modifying the slice
return slices.Clone(protocol.SupportedVersions)
}
// A ClientToken is a token received by the client.
// It can be used to skip address validation on future connection attempts.
type ClientToken struct {
data []byte
rtt time.Duration
}
type TokenStore interface {
// Pop searches for a ClientToken associated with the given key.
// Since tokens are not supposed to be reused, it must remove the token from the cache.
// It returns nil when no token is found.
Pop(key string) (token *ClientToken)
// Put adds a token to the cache with the given key. It might get called
// multiple times in a connection.
Put(key string, token *ClientToken)
}
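// memoryTokenStore is an illustrative sketch of a TokenStore (it is not the implementation
// used by this package, and it is not safe for concurrent use): it keeps at most one token
// per key and removes the token on Pop, as required by the interface contract.
type memoryTokenStore struct {
	tokens map[string]*ClientToken
}

var _ TokenStore = &memoryTokenStore{}

func (s *memoryTokenStore) Pop(key string) *ClientToken {
	t, ok := s.tokens[key]
	if !ok {
		return nil
	}
	delete(s.tokens, key) // tokens must not be reused
	return t
}

func (s *memoryTokenStore) Put(key string, token *ClientToken) {
	if s.tokens == nil {
		s.tokens = make(map[string]*ClientToken)
	}
	s.tokens[key] = token // overwrites any previous token for this key
}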
// Err0RTTRejected is the error returned from:
// - Open{Uni}Stream{Sync}
// - Accept{Uni}Stream
// - Stream.Read and Stream.Write
//
// when the server rejects a 0-RTT connection attempt.
var Err0RTTRejected = errors.New("0-RTT rejected")
// ConnectionTracingKey can be used to associate a [logging.ConnectionTracer] with a [Conn].
// It is set on the Conn.Context() context,
// as well as on the context passed to logging.Tracer.NewConnectionTracer.
//
// Deprecated: Applications can set their own tracing key using Transport.ConnContext.
var ConnectionTracingKey = connTracingCtxKey{}
// ConnectionTracingID is the type of the context value saved under the ConnectionTracingKey.
//
// Deprecated: Applications can set their own tracing key using Transport.ConnContext.
type ConnectionTracingID uint64
type connTracingCtxKey struct{}
// QUICVersionContextKey can be used to find out the QUIC version of a TLS handshake from the
// context returned by tls.ClientHelloInfo.Context.
var QUICVersionContextKey = handshake.QUICVersionContextKey
// StatelessResetKey is a key used to derive stateless reset tokens.
type StatelessResetKey [32]byte
// TokenGeneratorKey is a key used to encrypt session resumption tokens.
type TokenGeneratorKey = handshake.TokenProtectorKey
// A ConnectionID is a QUIC Connection ID, as defined in RFC 9000.
// It cannot represent QUIC Connection IDs longer than 20 bytes,
// even though RFC 8999 allows them.
type ConnectionID = protocol.ConnectionID
// ConnectionIDFromBytes interprets b as a [ConnectionID]. It panics if b is
// longer than 20 bytes.
func ConnectionIDFromBytes(b []byte) ConnectionID {
return protocol.ParseConnectionID(b)
}
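// For example (illustrative values): ConnectionIDFromBytes([]byte{0xde, 0xad, 0xbe, 0xef})
// yields a 4-byte Connection ID, while passing a 21-byte slice panics.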
// A ConnectionIDGenerator allows the application to take control over the generation of Connection IDs.
// Connection IDs generated by an implementation must be of constant length.
type ConnectionIDGenerator interface {
// GenerateConnectionID generates a new Connection ID.
// Generated Connection IDs must be unique and observers should not be able to correlate two Connection IDs.
GenerateConnectionID() (ConnectionID, error)
// ConnectionIDLen returns the length of Connection IDs generated by this implementation.
// Implementations must return constant-length Connection IDs with lengths between 0 and 20 bytes.
// A length of 0 can only be used when an endpoint doesn't need to multiplex connections during migration.
ConnectionIDLen() int
}
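// A minimal generator satisfying this interface could look like the following sketch
// (it assumes a crypto/rand import and is not part of this package):
//
//	type random8ByteConnIDGenerator struct{}
//
//	func (random8ByteConnIDGenerator) ConnectionIDLen() int { return 8 }
//
//	func (random8ByteConnIDGenerator) GenerateConnectionID() (ConnectionID, error) {
//		b := make([]byte, 8)
//		if _, err := rand.Read(b); err != nil {
//			return ConnectionID{}, err
//		}
//		return ConnectionIDFromBytes(b), nil
//	}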
// Config contains all configuration data needed for a QUIC server or client.
type Config struct {
// GetConfigForClient is called for incoming connections.
// If the error is not nil, the connection attempt is refused.
GetConfigForClient func(info *ClientInfo) (*Config, error)
// The QUIC versions that can be negotiated.
// If not set, it uses all versions available.
Versions []Version
// HandshakeIdleTimeout is the idle timeout before completion of the handshake.
// If we don't receive any packet from the peer within this time, the connection attempt is aborted.
// Additionally, if the handshake doesn't complete in twice this time, the connection attempt is also aborted.
// If this value is zero, the timeout is set to 5 seconds.
HandshakeIdleTimeout time.Duration
// MaxIdleTimeout is the maximum duration that may pass without any incoming network activity.
// The actual value for the idle timeout is the minimum of this value and the peer's.
// This value only applies after the handshake has completed.
// If the timeout is exceeded, the connection is closed.
// If this value is zero, the timeout is set to 30 seconds.
MaxIdleTimeout time.Duration
// The TokenStore stores tokens received from the server.
// Tokens are used to skip address validation on future connection attempts.
// The key used to store tokens is the ServerName from the tls.Config, if set;
// otherwise, the token is associated with the server's IP address.
TokenStore TokenStore
// InitialStreamReceiveWindow is the initial size of the stream-level flow control window for receiving data.
// If the application is consuming data quickly enough, the flow control auto-tuning algorithm
// will increase the window up to MaxStreamReceiveWindow.
// If this value is zero, it will default to 512 KB.
// Values larger than the maximum varint (quicvarint.Max) will be clipped to that value.
InitialStreamReceiveWindow uint64
// MaxStreamReceiveWindow is the maximum stream-level flow control window for receiving data.
// If this value is zero, it will default to 6 MB.
// Values larger than the maximum varint (quicvarint.Max) will be clipped to that value.
MaxStreamReceiveWindow uint64
// InitialConnectionReceiveWindow is the initial size of the connection-level flow control window for receiving data.
// If the application is consuming data quickly enough, the flow control auto-tuning algorithm
// will increase the window up to MaxConnectionReceiveWindow.
// If this value is zero, it will default to 512 KB.
// Values larger than the maximum varint (quicvarint.Max) will be clipped to that value.
InitialConnectionReceiveWindow uint64
// MaxConnectionReceiveWindow is the maximum connection-level flow control window for receiving data.
// If this value is zero, it will default to 15 MB.
// Values larger than the maximum varint (quicvarint.Max) will be clipped to that value.
MaxConnectionReceiveWindow uint64
// AllowConnectionWindowIncrease is called every time the connection flow controller attempts
// to increase the connection flow control window.
// If set, the caller can prevent an increase of the window. Typically, it would do so to
// limit the memory usage.
// To avoid deadlocks, it is not valid to call other functions on the connection or on streams
// in this callback.
AllowConnectionWindowIncrease func(conn *Conn, delta uint64) bool
// MaxIncomingStreams is the maximum number of concurrent bidirectional streams that a peer is allowed to open.
// If not set, it will default to 100.
// If set to a negative value, it doesn't allow any bidirectional streams.
// Values larger than 2^60 will be clipped to that value.
MaxIncomingStreams int64
// MaxIncomingUniStreams is the maximum number of concurrent unidirectional streams that a peer is allowed to open.
// If not set, it will default to 100.
// If set to a negative value, it doesn't allow any unidirectional streams.
// Values larger than 2^60 will be clipped to that value.
MaxIncomingUniStreams int64
// KeepAlivePeriod defines whether this peer will periodically send a packet to keep the connection alive.
// If set to 0, no keep-alives are sent. Otherwise, a keep-alive packet is sent every KeepAlivePeriod,
// or every half of MaxIdleTimeout, whichever is smaller.
KeepAlivePeriod time.Duration
// InitialPacketSize is the initial size (and the lower limit) for packets sent.
// Under most circumstances, it is not necessary to manually set this value,
// since path MTU discovery quickly finds the path's MTU.
// If set too high, the path might not support packets of that size, leading to a timeout of the QUIC handshake.
// Values below 1200 are invalid.
InitialPacketSize uint16
// DisablePathMTUDiscovery disables Path MTU Discovery (RFC 8899).
// This allows the sending of QUIC packets that fully utilize the available MTU of the path.
// Path MTU discovery is only available on systems that allow setting of the Don't Fragment (DF) bit.
DisablePathMTUDiscovery bool
// Allow0RTT allows the application to decide if a 0-RTT connection attempt should be accepted.
// Only valid for the server.
Allow0RTT bool
// Enable QUIC datagram support (RFC 9221).
EnableDatagrams bool
// Enable QUIC Stream Resets with Partial Delivery.
// See https://datatracker.ietf.org/doc/html/draft-ietf-quic-reliable-stream-reset-07.
EnableStreamResetPartialDelivery bool
// Tracer is called when a connection starts. It can return a [logging.ConnectionTracer]
// (or nil) to trace the events happening on that connection.
Tracer func(context.Context, logging.Perspective, ConnectionID) *logging.ConnectionTracer
}
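// An illustrative Config (the values are assumptions, not recommendations):
//
//	cfg := &Config{
//		MaxIdleTimeout:  45 * time.Second,
//		KeepAlivePeriod: 15 * time.Second,
//		EnableDatagrams: true,
//	}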
// ClientHelloInfo contains information about an incoming connection attempt.
//
// Deprecated: Use ClientInfo instead.
type ClientHelloInfo = ClientInfo
// ClientInfo contains information about an incoming connection attempt.
type ClientInfo struct {
// RemoteAddr is the remote address on the Initial packet.
// Unless AddrVerified is set, the address is not yet verified, and could be a spoofed IP address.
RemoteAddr net.Addr
// AddrVerified says if the remote address was verified using QUIC's Retry mechanism.
// Note that the Retry mechanism costs one network roundtrip,
// and is not performed unless Transport.MaxUnvalidatedHandshakes is surpassed.
AddrVerified bool
}
// ConnectionState records basic details about a QUIC connection.
type ConnectionState struct {
// TLS contains information about the TLS connection state, incl. the tls.ConnectionState.
TLS tls.ConnectionState
// SupportsDatagrams indicates whether the peer advertised support for QUIC datagrams (RFC 9221).
// When true, datagrams can be sent using the Conn's SendDatagram method.
// This is a unilateral declaration by the peer - receiving datagrams is only possible if
// datagram support was enabled locally via Config.EnableDatagrams.
SupportsDatagrams bool
// SupportsStreamResetPartialDelivery indicates whether the peer advertised support for QUIC Stream Resets with Partial Delivery.
SupportsStreamResetPartialDelivery bool
// Used0RTT says if 0-RTT resumption was used.
Used0RTT bool
// Version is the QUIC version of the QUIC connection.
Version Version
// GSO says if generic segmentation offload is used.
GSO bool
}
package ackhandler
import "github.com/quic-go/quic-go/internal/wire"
// IsFrameTypeAckEliciting reports whether a frame of the given type is ack-eliciting.
func IsFrameTypeAckEliciting(t wire.FrameType) bool {
//nolint:exhaustive // The default case catches the rest.
switch t {
case wire.FrameTypeAck, wire.FrameTypeAckECN:
return false
case wire.FrameTypeConnectionClose, wire.FrameTypeApplicationClose:
return false
default:
return true
}
}
// IsFrameAckEliciting returns true if the frame is ack-eliciting.
func IsFrameAckEliciting(f wire.Frame) bool {
_, isAck := f.(*wire.AckFrame)
_, isConnectionClose := f.(*wire.ConnectionCloseFrame)
return !isAck && !isConnectionClose
}
// HasAckElicitingFrames returns true if at least one frame is ack-eliciting.
func HasAckElicitingFrames(fs []Frame) bool {
for _, f := range fs {
if IsFrameAckEliciting(f.Frame) {
return true
}
}
return false
}
package ackhandler
import (
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/logging"
)
// NewAckHandler creates a new SentPacketHandler and a new ReceivedPacketHandler.
// clientAddressValidated indicates whether the address was validated beforehand by an address validation token.
// clientAddressValidated has no effect for a client.
func NewAckHandler(
initialPacketNumber protocol.PacketNumber,
initialMaxDatagramSize protocol.ByteCount,
rttStats *utils.RTTStats,
clientAddressValidated bool,
enableECN bool,
pers protocol.Perspective,
tracer *logging.ConnectionTracer,
logger utils.Logger,
) (SentPacketHandler, ReceivedPacketHandler) {
sph := newSentPacketHandler(initialPacketNumber, initialMaxDatagramSize, rttStats, clientAddressValidated, enableECN, pers, tracer, logger)
return sph, newReceivedPacketHandler(sph, logger)
}
package ackhandler
import (
"fmt"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/logging"
)
type ecnState uint8
const (
ecnStateInitial ecnState = iota
ecnStateTesting
ecnStateUnknown
ecnStateCapable
ecnStateFailed
)
// must fit into a uint8; otherwise numSentTesting and numLostTesting need a larger type
const numECNTestingPackets = 10
type ecnHandler interface {
SentPacket(protocol.PacketNumber, protocol.ECN)
Mode() protocol.ECN
HandleNewlyAcked(packets []*packet, ect0, ect1, ecnce int64) (congested bool)
LostPacket(protocol.PacketNumber)
}
// The ecnTracker performs ECN validation of a path.
// Once failed, it doesn't do any re-validation of the path.
// It is designed to only work for 1-RTT packets; it doesn't handle multiple packet number spaces.
// In order to avoid revealing any internal state to on-path observers,
// callers should make sure to start using ECN (i.e. calling Mode) for the very first 1-RTT packet sent.
// The validation logic implemented here strictly follows the algorithm described in RFC 9000 section 13.4.2 and A.4.
type ecnTracker struct {
state ecnState
numSentTesting, numLostTesting uint8
firstTestingPacket protocol.PacketNumber
lastTestingPacket protocol.PacketNumber
firstCapablePacket protocol.PacketNumber
numSentECT0, numSentECT1 int64
numAckedECT0, numAckedECT1, numAckedECNCE int64
tracer *logging.ConnectionTracer
logger utils.Logger
}
var _ ecnHandler = &ecnTracker{}
func newECNTracker(logger utils.Logger, tracer *logging.ConnectionTracer) *ecnTracker {
return &ecnTracker{
firstTestingPacket: protocol.InvalidPacketNumber,
lastTestingPacket: protocol.InvalidPacketNumber,
firstCapablePacket: protocol.InvalidPacketNumber,
state: ecnStateInitial,
logger: logger,
tracer: tracer,
}
}
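// Illustrative call sequence (a sketch of how a sender might drive the ecnHandler interface,
// not taken from the sent packet handler): query Mode() right before sending a 1-RTT packet
// and mark it accordingly, record the marking with SentPacket(pn, ecn), report losses via
// LostPacket(pn), and pass the ECN counts of every ACK that increases the largest acknowledged
// packet to HandleNewlyAcked; a true return value there signals congestion to the congestion
// controller.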
func (e *ecnTracker) SentPacket(pn protocol.PacketNumber, ecn protocol.ECN) {
//nolint:exhaustive // These are the only ones we need to take care of.
switch ecn {
case protocol.ECNNon:
return
case protocol.ECT0:
e.numSentECT0++
case protocol.ECT1:
e.numSentECT1++
case protocol.ECNUnsupported:
if e.state != ecnStateFailed {
panic("didn't expect ECN to be unsupported")
}
default:
panic(fmt.Sprintf("sent packet with unexpected ECN marking: %s", ecn))
}
if e.state == ecnStateCapable && e.firstCapablePacket == protocol.InvalidPacketNumber {
e.firstCapablePacket = pn
}
if e.state != ecnStateTesting {
return
}
e.numSentTesting++
if e.firstTestingPacket == protocol.InvalidPacketNumber {
e.firstTestingPacket = pn
}
if e.numSentECT0+e.numSentECT1 >= numECNTestingPackets {
if e.tracer != nil && e.tracer.ECNStateUpdated != nil {
e.tracer.ECNStateUpdated(logging.ECNStateUnknown, logging.ECNTriggerNoTrigger)
}
e.state = ecnStateUnknown
e.lastTestingPacket = pn
}
}
func (e *ecnTracker) Mode() protocol.ECN {
switch e.state {
case ecnStateInitial:
if e.tracer != nil && e.tracer.ECNStateUpdated != nil {
e.tracer.ECNStateUpdated(logging.ECNStateTesting, logging.ECNTriggerNoTrigger)
}
e.state = ecnStateTesting
return e.Mode()
case ecnStateTesting, ecnStateCapable:
return protocol.ECT0
case ecnStateUnknown, ecnStateFailed:
return protocol.ECNNon
default:
panic(fmt.Sprintf("unknown ECN state: %d", e.state))
}
}
func (e *ecnTracker) LostPacket(pn protocol.PacketNumber) {
if e.state != ecnStateTesting && e.state != ecnStateUnknown {
return
}
if !e.isTestingPacket(pn) {
return
}
e.numLostTesting++
// Only proceed if we have sent all 10 testing packets.
if e.state != ecnStateUnknown {
return
}
if e.numLostTesting >= e.numSentTesting {
e.logger.Debugf("Disabling ECN. All testing packets were lost.")
if e.tracer != nil && e.tracer.ECNStateUpdated != nil {
e.tracer.ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedLostAllTestingPackets)
}
e.state = ecnStateFailed
return
}
// Path validation also fails if some testing packets are lost and all other testing packets were CE-marked.
e.failIfMangled()
}
// HandleNewlyAcked handles the ECN counts on an ACK frame.
// It must only be called for ACK frames that increase the largest acknowledged packet number,
// see section 13.4.2.1 of RFC 9000.
func (e *ecnTracker) HandleNewlyAcked(packets []*packet, ect0, ect1, ecnce int64) (congested bool) {
if e.state == ecnStateFailed {
return false
}
// ECN validation can fail if the received total count for either ECT(0) or ECT(1) exceeds
// the total number of packets sent with each corresponding ECT codepoint.
if ect0 > e.numSentECT0 || ect1 > e.numSentECT1 {
e.logger.Debugf("Disabling ECN. Received more ECT(0) / ECT(1) acknowledgements than packets sent.")
if e.tracer != nil && e.tracer.ECNStateUpdated != nil {
e.tracer.ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedMoreECNCountsThanSent)
}
e.state = ecnStateFailed
return false
}
// Count ECT0 and ECT1 marks that we used when sending the packets that are now being acknowledged.
var ackedECT0, ackedECT1 int64
for _, p := range packets {
//nolint:exhaustive // We only ever send ECT(0) and ECT(1).
switch e.ecnMarking(p.PacketNumber) {
case protocol.ECT0:
ackedECT0++
case protocol.ECT1:
ackedECT1++
}
}
// If an ACK frame newly acknowledges a packet that the endpoint sent with either the ECT(0) or ECT(1)
// codepoint set, ECN validation fails if the corresponding ECN counts are not present in the ACK frame.
// This check detects:
// * paths that bleach all ECN marks, and
// * peers that don't report any ECN counts
if (ackedECT0 > 0 || ackedECT1 > 0) && ect0 == 0 && ect1 == 0 && ecnce == 0 {
e.logger.Debugf("Disabling ECN. ECN-marked packet acknowledged, but no ECN counts on ACK frame.")
if e.tracer != nil && e.tracer.ECNStateUpdated != nil {
e.tracer.ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedNoECNCounts)
}
e.state = ecnStateFailed
return false
}
// Determine the increase in ECT0, ECT1 and ECNCE marks
newECT0 := ect0 - e.numAckedECT0
newECT1 := ect1 - e.numAckedECT1
newECNCE := ecnce - e.numAckedECNCE
// We're only processing ACKs that increase the Largest Acked.
// Therefore, the ECN counters should only ever increase.
// Any decrease means that the peer's counting logic is broken.
if newECT0 < 0 || newECT1 < 0 || newECNCE < 0 {
e.logger.Debugf("Disabling ECN. ECN counts decreased unexpectedly.")
if e.tracer != nil && e.tracer.ECNStateUpdated != nil {
e.tracer.ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedDecreasedECNCounts)
}
e.state = ecnStateFailed
return false
}
// ECN validation also fails if the sum of the increase in ECT(0) and ECN-CE counts is less than the number
// of newly acknowledged packets that were originally sent with an ECT(0) marking.
// This could be the result of (partial) bleaching.
if newECT0+newECNCE < ackedECT0 {
e.logger.Debugf("Disabling ECN. Received less ECT(0) + ECN-CE than packets sent with ECT(0).")
if e.tracer != nil && e.tracer.ECNStateUpdated != nil {
e.tracer.ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedTooFewECNCounts)
}
e.state = ecnStateFailed
return false
}
// Similarly, ECN validation fails if the sum of the increases to ECT(1) and ECN-CE counts is less than
// the number of newly acknowledged packets sent with an ECT(1) marking.
if newECT1+newECNCE < ackedECT1 {
e.logger.Debugf("Disabling ECN. Received less ECT(1) + ECN-CE than packets sent with ECT(1).")
if e.tracer != nil && e.tracer.ECNStateUpdated != nil {
e.tracer.ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedTooFewECNCounts)
}
e.state = ecnStateFailed
return false
}
// update our counters
e.numAckedECT0 = ect0
e.numAckedECT1 = ect1
e.numAckedECNCE = ecnce
// Detect mangling (a path remarking all ECN-marked testing packets as CE),
// once all 10 testing packets have been sent out.
if e.state == ecnStateUnknown {
e.failIfMangled()
if e.state == ecnStateFailed {
return false
}
}
if e.state == ecnStateTesting || e.state == ecnStateUnknown {
var ackedTestingPacket bool
for _, p := range packets {
if e.isTestingPacket(p.PacketNumber) {
ackedTestingPacket = true
break
}
}
// This check won't succeed if the path is mangling ECN-marks (i.e. rewrites all ECN-marked packets to CE).
if ackedTestingPacket && (newECT0 > 0 || newECT1 > 0) {
e.logger.Debugf("ECN capability confirmed.")
if e.tracer != nil && e.tracer.ECNStateUpdated != nil {
e.tracer.ECNStateUpdated(logging.ECNStateCapable, logging.ECNTriggerNoTrigger)
}
e.state = ecnStateCapable
}
}
// Don't trust CE marks before having confirmed ECN capability of the path.
// Otherwise, mangling would be misinterpreted as actual congestion.
return e.state == ecnStateCapable && newECNCE > 0
}
// failIfMangled fails ECN validation if all testing packets are lost or CE-marked.
func (e *ecnTracker) failIfMangled() {
numAckedECNCE := e.numAckedECNCE + int64(e.numLostTesting)
if e.numSentECT0+e.numSentECT1 > numAckedECNCE {
return
}
if e.tracer != nil && e.tracer.ECNStateUpdated != nil {
e.tracer.ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedManglingDetected)
}
e.state = ecnStateFailed
}
func (e *ecnTracker) ecnMarking(pn protocol.PacketNumber) protocol.ECN {
if pn < e.firstTestingPacket || e.firstTestingPacket == protocol.InvalidPacketNumber {
return protocol.ECNNon
}
if pn < e.lastTestingPacket || e.lastTestingPacket == protocol.InvalidPacketNumber {
return protocol.ECT0
}
if pn < e.firstCapablePacket || e.firstCapablePacket == protocol.InvalidPacketNumber {
return protocol.ECNNon
}
// We don't need to deal with the case when ECN validation fails,
// since we're ignoring any ECN counts reported in ACK frames in that case.
return protocol.ECT0
}
func (e *ecnTracker) isTestingPacket(pn protocol.PacketNumber) bool {
if e.firstTestingPacket == protocol.InvalidPacketNumber {
return false
}
return pn >= e.firstTestingPacket && (pn <= e.lastTestingPacket || e.lastTestingPacket == protocol.InvalidPacketNumber)
}
package ackhandler
import (
"sync"
"time"
"github.com/quic-go/quic-go/internal/protocol"
)
// A packet represents a sent packet, tracked for acknowledgement and loss detection.
type packet struct {
SendTime time.Time
PacketNumber protocol.PacketNumber
StreamFrames []StreamFrame
Frames []Frame
LargestAcked protocol.PacketNumber // InvalidPacketNumber if the packet doesn't contain an ACK
Length protocol.ByteCount
EncryptionLevel protocol.EncryptionLevel
IsPathMTUProbePacket bool // We don't report the loss of Path MTU probe packets to the congestion controller.
includedInBytesInFlight bool
declaredLost bool
skippedPacket bool
isPathProbePacket bool
}
func (p *packet) outstanding() bool {
return !p.declaredLost && !p.skippedPacket && !p.IsPathMTUProbePacket && !p.isPathProbePacket
}
var packetPool = sync.Pool{New: func() any { return &packet{} }}
func getPacket() *packet {
p := packetPool.Get().(*packet)
p.PacketNumber = 0
p.StreamFrames = nil
p.Frames = nil
p.LargestAcked = 0
p.Length = 0
p.EncryptionLevel = protocol.EncryptionLevel(0)
p.SendTime = time.Time{}
p.IsPathMTUProbePacket = false
p.includedInBytesInFlight = false
p.declaredLost = false
p.skippedPacket = false
p.isPathProbePacket = false
return p
}
// We currently only return Packets back into the pool when they're acknowledged (not when they're lost).
// This simplifies the code, and gives the vast majority of the performance benefit we can gain from using the pool.
func putPacket(p *packet) {
p.Frames = nil
p.StreamFrames = nil
packetPool.Put(p)
}
package ackhandler
import (
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils"
)
type packetNumberGenerator interface {
Peek() protocol.PacketNumber
// Pop pops the packet number.
// It reports if the packet number (before the one just popped) was skipped.
// It never skips more than one packet number in a row.
Pop() (skipped bool, _ protocol.PacketNumber)
}
type sequentialPacketNumberGenerator struct {
next protocol.PacketNumber
}
var _ packetNumberGenerator = &sequentialPacketNumberGenerator{}
func newSequentialPacketNumberGenerator(initial protocol.PacketNumber) packetNumberGenerator {
return &sequentialPacketNumberGenerator{next: initial}
}
func (p *sequentialPacketNumberGenerator) Peek() protocol.PacketNumber {
return p.next
}
func (p *sequentialPacketNumberGenerator) Pop() (bool, protocol.PacketNumber) {
next := p.next
p.next++
return false, next
}
// The skippingPacketNumberGenerator generates the packet number for the next packet.
// It randomly skips a packet number every period packets (on average).
// It is guaranteed to never skip two consecutive packet numbers.
type skippingPacketNumberGenerator struct {
period protocol.PacketNumber
maxPeriod protocol.PacketNumber
next protocol.PacketNumber
nextToSkip protocol.PacketNumber
rng utils.Rand
}
var _ packetNumberGenerator = &skippingPacketNumberGenerator{}
func newSkippingPacketNumberGenerator(initial, initialPeriod, maxPeriod protocol.PacketNumber) packetNumberGenerator {
g := &skippingPacketNumberGenerator{
next: initial,
period: initialPeriod,
maxPeriod: maxPeriod,
}
g.generateNewSkip()
return g
}
func (p *skippingPacketNumberGenerator) Peek() protocol.PacketNumber {
if p.next == p.nextToSkip {
return p.next + 1
}
return p.next
}
func (p *skippingPacketNumberGenerator) Pop() (bool, protocol.PacketNumber) {
next := p.next
if p.next == p.nextToSkip {
next++
p.next += 2
p.generateNewSkip()
return true, next
}
p.next++ // generate a new packet number for the next packet
return false, next
}
func (p *skippingPacketNumberGenerator) generateNewSkip() {
// make sure that there are never two consecutive packet numbers that are skipped
p.nextToSkip = p.next + 3 + protocol.PacketNumber(p.rng.Int31n(int32(2*p.period)))
p.period = min(2*p.period, p.maxPeriod)
}
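// Illustrative example (numbers assumed): with next = 100 and a skip scheduled at 105,
// Pop returns 100, 101, 102, 103 and 104 normally; the next call skips 105 and returns
// (skipped = true, 106), then schedules the following skip at least 3 packet numbers ahead
// of 107 and doubles the period (capped at maxPeriod) for subsequent skips.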
package ackhandler
import (
"fmt"
"time"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/internal/wire"
)
type receivedPacketHandler struct {
sentPackets sentPacketTracker
initialPackets *receivedPacketTracker
handshakePackets *receivedPacketTracker
appDataPackets appDataReceivedPacketTracker
lowest1RTTPacket protocol.PacketNumber
}
var _ ReceivedPacketHandler = &receivedPacketHandler{}
func newReceivedPacketHandler(sentPackets sentPacketTracker, logger utils.Logger) ReceivedPacketHandler {
return &receivedPacketHandler{
sentPackets: sentPackets,
initialPackets: newReceivedPacketTracker(),
handshakePackets: newReceivedPacketTracker(),
appDataPackets: *newAppDataReceivedPacketTracker(logger),
lowest1RTTPacket: protocol.InvalidPacketNumber,
}
}
func (h *receivedPacketHandler) ReceivedPacket(
pn protocol.PacketNumber,
ecn protocol.ECN,
encLevel protocol.EncryptionLevel,
rcvTime time.Time,
ackEliciting bool,
) error {
h.sentPackets.ReceivedPacket(encLevel, rcvTime)
switch encLevel {
case protocol.EncryptionInitial:
return h.initialPackets.ReceivedPacket(pn, ecn, rcvTime, ackEliciting)
case protocol.EncryptionHandshake:
// The Handshake packet number space might already have been dropped as a result
// of processing the CRYPTO frame that was contained in this packet.
if h.handshakePackets == nil {
return nil
}
return h.handshakePackets.ReceivedPacket(pn, ecn, rcvTime, ackEliciting)
case protocol.Encryption0RTT:
if h.lowest1RTTPacket != protocol.InvalidPacketNumber && pn > h.lowest1RTTPacket {
return fmt.Errorf("received packet number %d on a 0-RTT packet after receiving %d on a 1-RTT packet", pn, h.lowest1RTTPacket)
}
return h.appDataPackets.ReceivedPacket(pn, ecn, rcvTime, ackEliciting)
case protocol.Encryption1RTT:
if h.lowest1RTTPacket == protocol.InvalidPacketNumber || pn < h.lowest1RTTPacket {
h.lowest1RTTPacket = pn
}
if err := h.appDataPackets.ReceivedPacket(pn, ecn, rcvTime, ackEliciting); err != nil {
return err
}
h.appDataPackets.IgnoreBelow(h.sentPackets.GetLowestPacketNotConfirmedAcked())
return nil
default:
panic(fmt.Sprintf("received packet with unknown encryption level: %s", encLevel))
}
}
func (h *receivedPacketHandler) DropPackets(encLevel protocol.EncryptionLevel) {
//nolint:exhaustive // 1-RTT packet number space is never dropped.
switch encLevel {
case protocol.EncryptionInitial:
h.initialPackets = nil
case protocol.EncryptionHandshake:
h.handshakePackets = nil
case protocol.Encryption0RTT:
// Nothing to do here.
// If we are rejecting 0-RTT, no 0-RTT packets will have been decrypted.
default:
panic(fmt.Sprintf("Cannot drop keys for encryption level %s", encLevel))
}
}
func (h *receivedPacketHandler) GetAlarmTimeout() time.Time {
return h.appDataPackets.GetAlarmTimeout()
}
func (h *receivedPacketHandler) GetAckFrame(encLevel protocol.EncryptionLevel, now time.Time, onlyIfQueued bool) *wire.AckFrame {
//nolint:exhaustive // 0-RTT packets can't contain ACK frames.
switch encLevel {
case protocol.EncryptionInitial:
if h.initialPackets != nil {
return h.initialPackets.GetAckFrame()
}
return nil
case protocol.EncryptionHandshake:
if h.handshakePackets != nil {
return h.handshakePackets.GetAckFrame()
}
return nil
case protocol.Encryption1RTT:
return h.appDataPackets.GetAckFrame(now, onlyIfQueued)
default:
// 0-RTT packets can't contain ACK frames
return nil
}
}
func (h *receivedPacketHandler) IsPotentiallyDuplicate(pn protocol.PacketNumber, encLevel protocol.EncryptionLevel) bool {
switch encLevel {
case protocol.EncryptionInitial:
if h.initialPackets != nil {
return h.initialPackets.IsPotentiallyDuplicate(pn)
}
case protocol.EncryptionHandshake:
if h.handshakePackets != nil {
return h.handshakePackets.IsPotentiallyDuplicate(pn)
}
case protocol.Encryption0RTT, protocol.Encryption1RTT:
return h.appDataPackets.IsPotentiallyDuplicate(pn)
}
panic("unexpected encryption level")
}
package ackhandler
import (
"slices"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/wire"
)
// interval is an interval from one PacketNumber to the other
type interval struct {
Start protocol.PacketNumber
End protocol.PacketNumber
}
// The receivedPacketHistory stores if a packet number has already been received.
// It generates ACK ranges which can be used to assemble an ACK frame.
// It does not store packet contents.
type receivedPacketHistory struct {
ranges []interval // maximum length: protocol.MaxNumAckRanges
deletedBelow protocol.PacketNumber
}
func newReceivedPacketHistory() *receivedPacketHistory {
return &receivedPacketHistory{}
}
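// Illustrative example (derived from the logic below): after ReceivedPacket is called for
// packet numbers 1, 2, 3, 5 and 6, the history holds the ranges [{1, 3}, {5, 6}];
// DeleteBelow(5) shrinks this to [{5, 6}], and AppendAckRanges then yields the single
// ACK range {Smallest: 5, Largest: 6}.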
// ReceivedPacket registers a packet with PacketNumber p and updates the ranges
func (h *receivedPacketHistory) ReceivedPacket(p protocol.PacketNumber) bool /* is a new packet (and not a duplicate / delayed packet) */ {
// ignore delayed packets, if we already deleted the range
if p < h.deletedBelow {
return false
}
isNew := h.addToRanges(p)
// Delete old ranges, if we're tracking too many of them.
// This is a DoS defense against a peer that sends us too many gaps.
if len(h.ranges) > protocol.MaxNumAckRanges {
h.ranges = slices.Delete(h.ranges, 0, len(h.ranges)-protocol.MaxNumAckRanges)
}
return isNew
}
func (h *receivedPacketHistory) addToRanges(p protocol.PacketNumber) bool /* is a new packet (and not a duplicate / delayed packet) */ {
if len(h.ranges) == 0 {
h.ranges = append(h.ranges, interval{Start: p, End: p})
return true
}
for i := len(h.ranges) - 1; i >= 0; i-- {
// p already included in an existing range. Nothing to do here
if p >= h.ranges[i].Start && p <= h.ranges[i].End {
return false
}
if h.ranges[i].End == p-1 { // extend a range at the end
h.ranges[i].End = p
return true
}
if h.ranges[i].Start == p+1 { // extend a range at the beginning
h.ranges[i].Start = p
if i > 0 && h.ranges[i-1].End+1 == h.ranges[i].Start { // merge two ranges
h.ranges[i-1].End = h.ranges[i].End
h.ranges = slices.Delete(h.ranges, i, i+1)
}
return true
}
// create a new range after the current one
if p > h.ranges[i].End {
h.ranges = slices.Insert(h.ranges, i+1, interval{Start: p, End: p})
return true
}
}
// create a new range at the beginning
h.ranges = slices.Insert(h.ranges, 0, interval{Start: p, End: p})
return true
}
// DeleteBelow deletes all entries below (but not including) p
func (h *receivedPacketHistory) DeleteBelow(p protocol.PacketNumber) {
if p < h.deletedBelow {
return
}
h.deletedBelow = p
if len(h.ranges) == 0 {
return
}
idx := -1
for i := 0; i < len(h.ranges); i++ {
if h.ranges[i].End < p { // delete a whole range
idx = i
} else if p > h.ranges[i].Start && p <= h.ranges[i].End {
h.ranges[i].Start = p
break
} else { // no ranges affected. Nothing to do
break
}
}
if idx >= 0 {
h.ranges = slices.Delete(h.ranges, 0, idx+1)
}
}
// AppendAckRanges appends to a slice of all AckRanges that can be used in an AckFrame
func (h *receivedPacketHistory) AppendAckRanges(ackRanges []wire.AckRange) []wire.AckRange {
for i := len(h.ranges) - 1; i >= 0; i-- {
ackRanges = append(ackRanges, wire.AckRange{Smallest: h.ranges[i].Start, Largest: h.ranges[i].End})
}
return ackRanges
}
func (h *receivedPacketHistory) GetHighestAckRange() wire.AckRange {
ackRange := wire.AckRange{}
if len(h.ranges) > 0 {
ackRange.Smallest = h.ranges[len(h.ranges)-1].Start
ackRange.Largest = h.ranges[len(h.ranges)-1].End
}
return ackRange
}
func (h *receivedPacketHistory) IsPotentiallyDuplicate(p protocol.PacketNumber) bool {
if p < h.deletedBelow {
return true
}
// Iterating over the slice is faster than using a binary search (using slices.BinarySearchFunc).
for i := len(h.ranges) - 1; i >= 0; i-- {
if p > h.ranges[i].End {
return false
}
if p <= h.ranges[i].End && p >= h.ranges[i].Start {
return true
}
}
return false
}
package ackhandler
import (
"fmt"
"time"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/internal/wire"
)
// The receivedPacketTracker tracks packets for the Initial and Handshake packet number space.
// Every received packet is acknowledged immediately.
type receivedPacketTracker struct {
ect0, ect1, ecnce uint64
packetHistory receivedPacketHistory
lastAck *wire.AckFrame
hasNewAck bool // true as soon as we received an ack-eliciting new packet
}
func newReceivedPacketTracker() *receivedPacketTracker {
return &receivedPacketTracker{packetHistory: *newReceivedPacketHistory()}
}
func (h *receivedPacketTracker) ReceivedPacket(pn protocol.PacketNumber, ecn protocol.ECN, rcvTime time.Time, ackEliciting bool) error {
if isNew := h.packetHistory.ReceivedPacket(pn); !isNew {
return fmt.Errorf("recevedPacketTracker BUG: ReceivedPacket called for old / duplicate packet %d", pn)
}
//nolint:exhaustive // Only need to count ECT(0), ECT(1) and ECN-CE.
switch ecn {
case protocol.ECT0:
h.ect0++
case protocol.ECT1:
h.ect1++
case protocol.ECNCE:
h.ecnce++
}
if !ackEliciting {
return nil
}
h.hasNewAck = true
return nil
}
func (h *receivedPacketTracker) GetAckFrame() *wire.AckFrame {
if !h.hasNewAck {
return nil
}
// This function always returns the same ACK frame struct, filled with the most recent values.
ack := h.lastAck
if ack == nil {
ack = &wire.AckFrame{}
}
ack.Reset()
ack.ECT0 = h.ect0
ack.ECT1 = h.ect1
ack.ECNCE = h.ecnce
ack.AckRanges = h.packetHistory.AppendAckRanges(ack.AckRanges)
h.lastAck = ack
h.hasNewAck = false
return ack
}
func (h *receivedPacketTracker) IsPotentiallyDuplicate(pn protocol.PacketNumber) bool {
return h.packetHistory.IsPotentiallyDuplicate(pn)
}
// number of ack-eliciting packets received before sending an ACK
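// (RFC 9000 recommends sending an ACK after receiving at least two ack-eliciting packets.)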
const packetsBeforeAck = 2
// The appDataReceivedPacketTracker tracks packets received in the Application Data packet number space.
// It waits until at least 2 ack-eliciting packets have been received before queueing an ACK, or until the max_ack_delay is reached.
type appDataReceivedPacketTracker struct {
receivedPacketTracker
largestObservedRcvdTime time.Time
largestObserved protocol.PacketNumber
ignoreBelow protocol.PacketNumber
maxAckDelay time.Duration
ackQueued bool // true if we need to send a new ACK
ackElicitingPacketsReceivedSinceLastAck int
ackAlarm time.Time
logger utils.Logger
}
func newAppDataReceivedPacketTracker(logger utils.Logger) *appDataReceivedPacketTracker {
h := &appDataReceivedPacketTracker{
receivedPacketTracker: *newReceivedPacketTracker(),
maxAckDelay: protocol.MaxAckDelay,
logger: logger,
}
return h
}
func (h *appDataReceivedPacketTracker) ReceivedPacket(pn protocol.PacketNumber, ecn protocol.ECN, rcvTime time.Time, ackEliciting bool) error {
if err := h.receivedPacketTracker.ReceivedPacket(pn, ecn, rcvTime, ackEliciting); err != nil {
return err
}
if pn >= h.largestObserved {
h.largestObserved = pn
h.largestObservedRcvdTime = rcvTime
}
if !ackEliciting {
return nil
}
h.ackElicitingPacketsReceivedSinceLastAck++
isMissing := h.isMissing(pn)
if !h.ackQueued && h.shouldQueueACK(pn, ecn, isMissing) {
h.ackQueued = true
h.ackAlarm = time.Time{} // cancel the ack alarm
}
if !h.ackQueued {
// No ACK queued, but we'll need to acknowledge the packet after max_ack_delay.
h.ackAlarm = rcvTime.Add(h.maxAckDelay)
if h.logger.Debug() {
h.logger.Debugf("\tSetting ACK timer to max ack delay: %s", h.maxAckDelay)
}
}
return nil
}
// IgnoreBelow sets a lower limit for acknowledging packets.
// Packets with packet numbers smaller than p will not be acked.
func (h *appDataReceivedPacketTracker) IgnoreBelow(pn protocol.PacketNumber) {
if pn <= h.ignoreBelow {
return
}
h.ignoreBelow = pn
h.packetHistory.DeleteBelow(pn)
if h.logger.Debug() {
h.logger.Debugf("\tIgnoring all packets below %d.", pn)
}
}
// isMissing says if a packet was reported missing in the last ACK.
func (h *appDataReceivedPacketTracker) isMissing(p protocol.PacketNumber) bool {
if h.lastAck == nil || p < h.ignoreBelow {
return false
}
return p < h.lastAck.LargestAcked() && !h.lastAck.AcksPacket(p)
}
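// hasNewMissingPackets says if there is a new gap above the packets reported in the last ACK.
// For example: if the last ACK acknowledged packets up to 10 and the only newly received packet is 13,
// the highest range is {13, 13} and packets 11 and 12 form a new gap.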
func (h *appDataReceivedPacketTracker) hasNewMissingPackets() bool {
if h.lastAck == nil {
return false
}
highestRange := h.packetHistory.GetHighestAckRange()
return highestRange.Smallest > h.lastAck.LargestAcked()+1 && highestRange.Len() == 1
}
func (h *appDataReceivedPacketTracker) shouldQueueACK(pn protocol.PacketNumber, ecn protocol.ECN, wasMissing bool) bool {
// always acknowledge the first packet
if h.lastAck == nil {
h.logger.Debugf("\tQueueing ACK because the first packet should be acknowledged.")
return true
}
// Send an ACK if this packet was reported missing in an ACK sent before.
// Ack decimation with reordering relies on the timer to send an ACK, but if
// missing packets were reported in the previous ACK, send an ACK immediately.
if wasMissing {
if h.logger.Debug() {
h.logger.Debugf("\tQueueing ACK because packet %d was missing before.", pn)
}
return true
}
// send an ACK every 2 ack-eliciting packets
if h.ackElicitingPacketsReceivedSinceLastAck >= packetsBeforeAck {
if h.logger.Debug() {
h.logger.Debugf("\tQueueing ACK because packet %d packets were received after the last ACK (using initial threshold: %d).", h.ackElicitingPacketsReceivedSinceLastAck, packetsBeforeAck)
}
return true
}
// queue an ACK if there are new missing packets to report
if h.hasNewMissingPackets() {
h.logger.Debugf("\tQueuing ACK because there's a new missing packet to report.")
return true
}
// queue an ACK if the packet was ECN-CE marked
if ecn == protocol.ECNCE {
h.logger.Debugf("\tQueuing ACK because the packet was ECN-CE marked.")
return true
}
return false
}
func (h *appDataReceivedPacketTracker) GetAckFrame(now time.Time, onlyIfQueued bool) *wire.AckFrame {
if onlyIfQueued && !h.ackQueued {
if h.ackAlarm.IsZero() || h.ackAlarm.After(now) {
return nil
}
if h.logger.Debug() && !h.ackAlarm.IsZero() {
h.logger.Debugf("Sending ACK because the ACK timer expired.")
}
}
ack := h.receivedPacketTracker.GetAckFrame()
if ack == nil {
return nil
}
ack.DelayTime = max(0, now.Sub(h.largestObservedRcvdTime))
h.ackQueued = false
h.ackAlarm = time.Time{}
h.ackElicitingPacketsReceivedSinceLastAck = 0
return ack
}
func (h *appDataReceivedPacketTracker) GetAlarmTimeout() time.Time { return h.ackAlarm }
package ackhandler
import "fmt"
// The SendMode says what kind of packets can be sent.
type SendMode uint8
const (
// SendNone means that no packets should be sent
SendNone SendMode = iota
// SendAck means an ACK-only packet should be sent
SendAck
// SendPTOInitial means that an Initial probe packet should be sent
SendPTOInitial
// SendPTOHandshake means that a Handshake probe packet should be sent
SendPTOHandshake
// SendPTOAppData means that an Application data probe packet should be sent
SendPTOAppData
// SendPacingLimited means that the pacer doesn't allow sending of a packet right now,
// but will allow sending again in a little while.
// The timestamp when sending is allowed again can be obtained via the SentPacketHandler.TimeUntilSend.
SendPacingLimited
// SendAny means that any packet should be sent
SendAny
)
func (s SendMode) String() string {
switch s {
case SendNone:
return "none"
case SendAck:
return "ack"
case SendPTOInitial:
return "pto (Initial)"
case SendPTOHandshake:
return "pto (Handshake)"
case SendPTOAppData:
return "pto (Application Data)"
case SendAny:
return "any"
case SendPacingLimited:
return "pacing limited"
default:
return fmt.Sprintf("invalid send mode: %d", s)
}
}
package ackhandler
import (
"errors"
"fmt"
"time"
"github.com/quic-go/quic-go/internal/congestion"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qerr"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/internal/wire"
"github.com/quic-go/quic-go/logging"
)
const (
// Maximum reordering in time space before time based loss detection considers a packet lost.
// Specified as an RTT multiplier.
timeThreshold = 9.0 / 8
// Maximum reordering in packets before packet threshold loss detection considers a packet lost.
packetThreshold = 3
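// These correspond to the kTimeThreshold (9/8) and kPacketThreshold (3) loss detection parameters recommended in RFC 9002.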
// Before validating the client's address, the server won't send more than 3x the number of bytes it received.
amplificationFactor = 3
// We use Retry packets to derive an RTT estimate. Make sure we don't set the RTT to a super low value yet.
minRTTAfterRetry = 5 * time.Millisecond
// The PTO duration uses exponential backoff, but is truncated to a maximum value, as allowed by RFC 8961, section 4.4.
maxPTODuration = 60 * time.Second
)
// Path probe packets are declared lost after this time.
const pathProbePacketLossTimeout = time.Second
type packetNumberSpace struct {
history sentPacketHistory
pns packetNumberGenerator
lossTime time.Time
lastAckElicitingPacketTime time.Time
largestAcked protocol.PacketNumber
largestSent protocol.PacketNumber
}
func newPacketNumberSpace(initialPN protocol.PacketNumber, isAppData bool) *packetNumberSpace {
var pns packetNumberGenerator
if isAppData {
pns = newSkippingPacketNumberGenerator(initialPN, protocol.SkipPacketInitialPeriod, protocol.SkipPacketMaxPeriod)
} else {
pns = newSequentialPacketNumberGenerator(initialPN)
}
return &packetNumberSpace{
history: *newSentPacketHistory(isAppData),
pns: pns,
largestSent: protocol.InvalidPacketNumber,
largestAcked: protocol.InvalidPacketNumber,
}
}
type alarmTimer struct {
Time time.Time
TimerType logging.TimerType
EncryptionLevel protocol.EncryptionLevel
}
type sentPacketHandler struct {
initialPackets *packetNumberSpace
handshakePackets *packetNumberSpace
appDataPackets *packetNumberSpace
// Do we know that the peer completed address validation yet?
// Always true for the server.
peerCompletedAddressValidation bool
bytesReceived protocol.ByteCount
bytesSent protocol.ByteCount
// Have we validated the peer's address yet?
// Always true for the client.
peerAddressValidated bool
handshakeConfirmed bool
// lowestNotConfirmedAcked is the lowest packet number that we sent an ACK for, but haven't received confirmation that this ACK actually arrived
// example: we send an ACK for packets 90-100 with packet number 20
// once we receive an ACK from the peer for packet 20, the lowestNotConfirmedAcked is 101
// Only applies to the application-data packet number space.
lowestNotConfirmedAcked protocol.PacketNumber
ackedPackets []*packet // to avoid allocations in detectAndRemoveAckedPackets
bytesInFlight protocol.ByteCount
congestion congestion.SendAlgorithmWithDebugInfos
rttStats *utils.RTTStats
// The number of times a PTO has been sent without receiving an ack.
ptoCount uint32
ptoMode SendMode
// The number of PTO probe packets that should be sent.
// Only applies to the application-data packet number space.
numProbesToSend int
// The alarm timeout
alarm alarmTimer
enableECN bool
ecnTracker ecnHandler
perspective protocol.Perspective
tracer *logging.ConnectionTracer
logger utils.Logger
}
var (
_ SentPacketHandler = &sentPacketHandler{}
_ sentPacketTracker = &sentPacketHandler{}
)
// clientAddressValidated indicates whether the address was validated beforehand by an address validation token.
// If the address was validated, the amplification limit doesn't apply. It has no effect for a client.
func newSentPacketHandler(
initialPN protocol.PacketNumber,
initialMaxDatagramSize protocol.ByteCount,
rttStats *utils.RTTStats,
clientAddressValidated bool,
enableECN bool,
pers protocol.Perspective,
tracer *logging.ConnectionTracer,
logger utils.Logger,
) *sentPacketHandler {
congestion := congestion.NewCubicSender(
congestion.DefaultClock{},
rttStats,
initialMaxDatagramSize,
true, // use Reno
tracer,
)
h := &sentPacketHandler{
peerCompletedAddressValidation: pers == protocol.PerspectiveServer,
peerAddressValidated: pers == protocol.PerspectiveClient || clientAddressValidated,
initialPackets: newPacketNumberSpace(initialPN, false),
handshakePackets: newPacketNumberSpace(0, false),
appDataPackets: newPacketNumberSpace(0, true),
rttStats: rttStats,
congestion: congestion,
perspective: pers,
tracer: tracer,
logger: logger,
}
if enableECN {
h.enableECN = true
h.ecnTracker = newECNTracker(logger, tracer)
}
return h
}
func (h *sentPacketHandler) removeFromBytesInFlight(p *packet) {
if p.includedInBytesInFlight {
if p.Length > h.bytesInFlight {
panic("negative bytes_in_flight")
}
h.bytesInFlight -= p.Length
p.includedInBytesInFlight = false
}
}
func (h *sentPacketHandler) DropPackets(encLevel protocol.EncryptionLevel, now time.Time) {
// The server won't await address validation after the handshake is confirmed.
// This applies even if we didn't receive an ACK for a Handshake packet.
if h.perspective == protocol.PerspectiveClient && encLevel == protocol.EncryptionHandshake {
h.peerCompletedAddressValidation = true
}
// remove outstanding packets from bytes_in_flight
if encLevel == protocol.EncryptionInitial || encLevel == protocol.EncryptionHandshake {
pnSpace := h.getPacketNumberSpace(encLevel)
// We might already have dropped this packet number space.
if pnSpace == nil {
return
}
for p := range pnSpace.history.Packets() {
h.removeFromBytesInFlight(p)
}
}
// drop the packet history
//nolint:exhaustive // Not every packet number space can be dropped.
switch encLevel {
case protocol.EncryptionInitial:
h.initialPackets = nil
case protocol.EncryptionHandshake:
// Dropping the handshake packet number space means that the handshake is confirmed,
// see section 4.9.2 of RFC 9001.
h.handshakeConfirmed = true
h.handshakePackets = nil
case protocol.Encryption0RTT:
// This function is only called when 0-RTT is rejected,
// and not when the client drops 0-RTT keys when the handshake completes.
// When 0-RTT is rejected, all application data sent so far becomes invalid.
// Delete the packets from the history and remove them from bytes_in_flight.
for p := range h.appDataPackets.history.Packets() {
if p.EncryptionLevel != protocol.Encryption0RTT && !p.skippedPacket {
break
}
h.removeFromBytesInFlight(p)
h.appDataPackets.history.Remove(p.PacketNumber)
}
default:
panic(fmt.Sprintf("Cannot drop keys for encryption level %s", encLevel))
}
if h.tracer != nil && h.tracer.UpdatedPTOCount != nil && h.ptoCount != 0 {
h.tracer.UpdatedPTOCount(0)
}
h.ptoCount = 0
h.numProbesToSend = 0
h.ptoMode = SendNone
h.setLossDetectionTimer(now)
}
func (h *sentPacketHandler) ReceivedBytes(n protocol.ByteCount, t time.Time) {
wasAmplificationLimit := h.isAmplificationLimited()
h.bytesReceived += n
if wasAmplificationLimit && !h.isAmplificationLimited() {
h.setLossDetectionTimer(t)
}
}
func (h *sentPacketHandler) ReceivedPacket(l protocol.EncryptionLevel, t time.Time) {
if h.perspective == protocol.PerspectiveServer && l == protocol.EncryptionHandshake && !h.peerAddressValidated {
h.peerAddressValidated = true
h.setLossDetectionTimer(t)
}
}
func (h *sentPacketHandler) packetsInFlight() int {
packetsInFlight := h.appDataPackets.history.Len()
if h.handshakePackets != nil {
packetsInFlight += h.handshakePackets.history.Len()
}
if h.initialPackets != nil {
packetsInFlight += h.initialPackets.history.Len()
}
return packetsInFlight
}
func (h *sentPacketHandler) SentPacket(
t time.Time,
pn, largestAcked protocol.PacketNumber,
streamFrames []StreamFrame,
frames []Frame,
encLevel protocol.EncryptionLevel,
ecn protocol.ECN,
size protocol.ByteCount,
isPathMTUProbePacket bool,
isPathProbePacket bool,
) {
h.bytesSent += size
pnSpace := h.getPacketNumberSpace(encLevel)
if h.logger.Debug() && (pnSpace.history.HasOutstandingPackets() || pnSpace.history.HasOutstandingPathProbes()) {
for p := max(0, pnSpace.largestSent+1); p < pn; p++ {
h.logger.Debugf("Skipping packet number %d", p)
}
}
pnSpace.largestSent = pn
isAckEliciting := len(streamFrames) > 0 || len(frames) > 0
if isPathProbePacket {
p := getPacket()
p.SendTime = t
p.PacketNumber = pn
p.EncryptionLevel = encLevel
p.Length = size
p.Frames = frames
p.isPathProbePacket = true
pnSpace.history.SentPathProbePacket(p)
h.setLossDetectionTimer(t)
return
}
if isAckEliciting {
pnSpace.lastAckElicitingPacketTime = t
h.bytesInFlight += size
if h.numProbesToSend > 0 {
h.numProbesToSend--
}
}
h.congestion.OnPacketSent(t, h.bytesInFlight, pn, size, isAckEliciting)
if encLevel == protocol.Encryption1RTT && h.ecnTracker != nil {
h.ecnTracker.SentPacket(pn, ecn)
}
if !isAckEliciting {
pnSpace.history.SentNonAckElicitingPacket(pn)
if !h.peerCompletedAddressValidation {
h.setLossDetectionTimer(t)
}
return
}
p := getPacket()
p.SendTime = t
p.PacketNumber = pn
p.EncryptionLevel = encLevel
p.Length = size
p.LargestAcked = largestAcked
p.StreamFrames = streamFrames
p.Frames = frames
p.IsPathMTUProbePacket = isPathMTUProbePacket
p.includedInBytesInFlight = true
pnSpace.history.SentAckElicitingPacket(p)
if h.tracer != nil && h.tracer.UpdatedMetrics != nil {
h.tracer.UpdatedMetrics(h.rttStats, h.congestion.GetCongestionWindow(), h.bytesInFlight, h.packetsInFlight())
}
h.setLossDetectionTimer(t)
}
func (h *sentPacketHandler) getPacketNumberSpace(encLevel protocol.EncryptionLevel) *packetNumberSpace {
switch encLevel {
case protocol.EncryptionInitial:
return h.initialPackets
case protocol.EncryptionHandshake:
return h.handshakePackets
case protocol.Encryption0RTT, protocol.Encryption1RTT:
return h.appDataPackets
default:
panic("invalid packet number space")
}
}
func (h *sentPacketHandler) ReceivedAck(ack *wire.AckFrame, encLevel protocol.EncryptionLevel, rcvTime time.Time) (bool /* contained 1-RTT packet */, error) {
pnSpace := h.getPacketNumberSpace(encLevel)
largestAcked := ack.LargestAcked()
if largestAcked > pnSpace.largestSent {
return false, &qerr.TransportError{
ErrorCode: qerr.ProtocolViolation,
ErrorMessage: "received ACK for an unsent packet",
}
}
// Servers complete address validation when a protected packet is received.
if h.perspective == protocol.PerspectiveClient && !h.peerCompletedAddressValidation &&
(encLevel == protocol.EncryptionHandshake || encLevel == protocol.Encryption1RTT) {
h.peerCompletedAddressValidation = true
h.logger.Debugf("Peer doesn't await address validation any longer.")
// Make sure that the timer is reset, even if this ACK doesn't acknowledge any (ack-eliciting) packets.
h.setLossDetectionTimer(rcvTime)
}
priorInFlight := h.bytesInFlight
ackedPackets, err := h.detectAndRemoveAckedPackets(ack, encLevel)
if err != nil || len(ackedPackets) == 0 {
return false, err
}
// update the RTT, if the largest acked is newly acknowledged
if len(ackedPackets) > 0 {
if p := ackedPackets[len(ackedPackets)-1]; p.PacketNumber == ack.LargestAcked() && !p.isPathProbePacket {
// don't use the ack delay for Initial and Handshake packets
var ackDelay time.Duration
if encLevel == protocol.Encryption1RTT {
ackDelay = min(ack.DelayTime, h.rttStats.MaxAckDelay())
}
h.rttStats.UpdateRTT(rcvTime.Sub(p.SendTime), ackDelay)
if h.logger.Debug() {
h.logger.Debugf("\tupdated RTT: %s (σ: %s)", h.rttStats.SmoothedRTT(), h.rttStats.MeanDeviation())
}
h.congestion.MaybeExitSlowStart()
}
}
// Only inform the ECN tracker about new 1-RTT ACKs if the ACK increases the largest acked.
if encLevel == protocol.Encryption1RTT && h.ecnTracker != nil && largestAcked > pnSpace.largestAcked {
congested := h.ecnTracker.HandleNewlyAcked(ackedPackets, int64(ack.ECT0), int64(ack.ECT1), int64(ack.ECNCE))
if congested {
h.congestion.OnCongestionEvent(largestAcked, 0, priorInFlight)
}
}
pnSpace.largestAcked = max(pnSpace.largestAcked, largestAcked)
h.detectLostPackets(rcvTime, encLevel)
if encLevel == protocol.Encryption1RTT {
h.detectLostPathProbes(rcvTime)
}
var acked1RTTPacket bool
for _, p := range ackedPackets {
if p.includedInBytesInFlight && !p.declaredLost {
h.congestion.OnPacketAcked(p.PacketNumber, p.Length, priorInFlight, rcvTime)
}
if p.EncryptionLevel == protocol.Encryption1RTT {
acked1RTTPacket = true
}
h.removeFromBytesInFlight(p)
if !p.isPathProbePacket {
putPacket(p)
}
}
// After this point, we must not use ackedPackets any longer!
// We've already returned the buffers.
ackedPackets = nil //nolint:ineffassign // This is just to be on the safe side.
// Reset the pto_count unless the client is unsure if the server has validated the client's address.
if h.peerCompletedAddressValidation {
if h.tracer != nil && h.tracer.UpdatedPTOCount != nil && h.ptoCount != 0 {
h.tracer.UpdatedPTOCount(0)
}
h.ptoCount = 0
}
h.numProbesToSend = 0
if h.tracer != nil && h.tracer.UpdatedMetrics != nil {
h.tracer.UpdatedMetrics(h.rttStats, h.congestion.GetCongestionWindow(), h.bytesInFlight, h.packetsInFlight())
}
h.setLossDetectionTimer(rcvTime)
return acked1RTTPacket, nil
}
func (h *sentPacketHandler) GetLowestPacketNotConfirmedAcked() protocol.PacketNumber {
return h.lowestNotConfirmedAcked
}
// Packets are returned in ascending packet number order.
func (h *sentPacketHandler) detectAndRemoveAckedPackets(ack *wire.AckFrame, encLevel protocol.EncryptionLevel) ([]*packet, error) {
pnSpace := h.getPacketNumberSpace(encLevel)
h.ackedPackets = h.ackedPackets[:0]
ackRangeIndex := 0
lowestAcked := ack.LowestAcked()
largestAcked := ack.LargestAcked()
for p := range pnSpace.history.Packets() {
// ignore packets below the lowest acked
if p.PacketNumber < lowestAcked {
continue
}
if p.PacketNumber > largestAcked {
break
}
if ack.HasMissingRanges() {
ackRange := ack.AckRanges[len(ack.AckRanges)-1-ackRangeIndex]
for p.PacketNumber > ackRange.Largest && ackRangeIndex < len(ack.AckRanges)-1 {
ackRangeIndex++
ackRange = ack.AckRanges[len(ack.AckRanges)-1-ackRangeIndex]
}
if p.PacketNumber < ackRange.Smallest { // packet not contained in ACK range
continue
}
if p.PacketNumber > ackRange.Largest {
return nil, fmt.Errorf("BUG: ackhandler would have acked wrong packet %d, while evaluating range %d -> %d", p.PacketNumber, ackRange.Smallest, ackRange.Largest)
}
}
if p.skippedPacket {
return nil, &qerr.TransportError{
ErrorCode: qerr.ProtocolViolation,
ErrorMessage: fmt.Sprintf("received an ACK for skipped packet number: %d (%s)", p.PacketNumber, encLevel),
}
}
if p.isPathProbePacket {
probePacket := pnSpace.history.RemovePathProbe(p.PacketNumber)
// the probe packet might already have been declared lost
if probePacket != nil {
h.ackedPackets = append(h.ackedPackets, probePacket)
}
continue
}
h.ackedPackets = append(h.ackedPackets, p)
}
if h.logger.Debug() && len(h.ackedPackets) > 0 {
pns := make([]protocol.PacketNumber, len(h.ackedPackets))
for i, p := range h.ackedPackets {
pns[i] = p.PacketNumber
}
h.logger.Debugf("\tnewly acked packets (%d): %d", len(pns), pns)
}
for _, p := range h.ackedPackets {
if p.LargestAcked != protocol.InvalidPacketNumber && encLevel == protocol.Encryption1RTT {
h.lowestNotConfirmedAcked = max(h.lowestNotConfirmedAcked, p.LargestAcked+1)
}
for _, f := range p.Frames {
if f.Handler != nil {
f.Handler.OnAcked(f.Frame)
}
}
for _, f := range p.StreamFrames {
if f.Handler != nil {
f.Handler.OnAcked(f.Frame)
}
}
if err := pnSpace.history.Remove(p.PacketNumber); err != nil {
return nil, err
}
if h.tracer != nil && h.tracer.AcknowledgedPacket != nil {
h.tracer.AcknowledgedPacket(encLevel, p.PacketNumber)
}
}
return h.ackedPackets, nil
}
func (h *sentPacketHandler) getLossTimeAndSpace() (time.Time, protocol.EncryptionLevel) {
var encLevel protocol.EncryptionLevel
var lossTime time.Time
if h.initialPackets != nil {
lossTime = h.initialPackets.lossTime
encLevel = protocol.EncryptionInitial
}
if h.handshakePackets != nil && (lossTime.IsZero() || (!h.handshakePackets.lossTime.IsZero() && h.handshakePackets.lossTime.Before(lossTime))) {
lossTime = h.handshakePackets.lossTime
encLevel = protocol.EncryptionHandshake
}
if lossTime.IsZero() || (!h.appDataPackets.lossTime.IsZero() && h.appDataPackets.lossTime.Before(lossTime)) {
lossTime = h.appDataPackets.lossTime
encLevel = protocol.Encryption1RTT
}
return lossTime, encLevel
}
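// getScaledPTO returns the PTO duration, exponentially backed off by the number of unanswered PTOs.
// For example: with a base PTO of 200ms and ptoCount == 2, the scaled PTO is 800ms.
// The result is capped at maxPTODuration; the cap also catches overflow of the shift.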
func (h *sentPacketHandler) getScaledPTO(includeMaxAckDelay bool) time.Duration {
pto := h.rttStats.PTO(includeMaxAckDelay) << h.ptoCount
if pto > maxPTODuration || pto <= 0 {
return maxPTODuration
}
return pto
}
// same logic as getLossTimeAndSpace, but for lastAckElicitingPacketTime instead of lossTime
func (h *sentPacketHandler) getPTOTimeAndSpace(now time.Time) (pto time.Time, encLevel protocol.EncryptionLevel) {
// We only send application data probe packets once the handshake is confirmed,
// because before that, we don't have the keys to decrypt ACKs sent in 1-RTT packets.
if !h.handshakeConfirmed && !h.hasOutstandingCryptoPackets() {
if h.peerCompletedAddressValidation {
return
}
t := now.Add(h.getScaledPTO(false))
if h.initialPackets != nil {
return t, protocol.EncryptionInitial
}
return t, protocol.EncryptionHandshake
}
if h.initialPackets != nil && h.initialPackets.history.HasOutstandingPackets() &&
!h.initialPackets.lastAckElicitingPacketTime.IsZero() {
encLevel = protocol.EncryptionInitial
if t := h.initialPackets.lastAckElicitingPacketTime; !t.IsZero() {
pto = t.Add(h.getScaledPTO(false))
}
}
if h.handshakePackets != nil && h.handshakePackets.history.HasOutstandingPackets() &&
!h.handshakePackets.lastAckElicitingPacketTime.IsZero() {
t := h.handshakePackets.lastAckElicitingPacketTime.Add(h.getScaledPTO(false))
if pto.IsZero() || (!t.IsZero() && t.Before(pto)) {
pto = t
encLevel = protocol.EncryptionHandshake
}
}
if h.handshakeConfirmed && h.appDataPackets.history.HasOutstandingPackets() &&
!h.appDataPackets.lastAckElicitingPacketTime.IsZero() {
t := h.appDataPackets.lastAckElicitingPacketTime.Add(h.getScaledPTO(true))
if pto.IsZero() || (!t.IsZero() && t.Before(pto)) {
pto = t
encLevel = protocol.Encryption1RTT
}
}
return pto, encLevel
}
func (h *sentPacketHandler) hasOutstandingCryptoPackets() bool {
if h.initialPackets != nil && h.initialPackets.history.HasOutstandingPackets() {
return true
}
if h.handshakePackets != nil && h.handshakePackets.history.HasOutstandingPackets() {
return true
}
return false
}
func (h *sentPacketHandler) setLossDetectionTimer(now time.Time) {
oldAlarm := h.alarm // only needed in case tracing is enabled
newAlarm := h.lossDetectionTime(now)
h.alarm = newAlarm
hasAlarm := !newAlarm.Time.IsZero()
if !hasAlarm && !oldAlarm.Time.IsZero() {
h.logger.Debugf("Canceling loss detection timer.")
if h.tracer != nil && h.tracer.LossTimerCanceled != nil {
h.tracer.LossTimerCanceled()
}
}
if hasAlarm && h.tracer != nil && h.tracer.SetLossTimer != nil && newAlarm != oldAlarm {
h.tracer.SetLossTimer(newAlarm.TimerType, newAlarm.EncryptionLevel, newAlarm.Time)
}
}
func (h *sentPacketHandler) lossDetectionTime(now time.Time) alarmTimer {
// cancel the alarm if no packets are outstanding
if h.peerCompletedAddressValidation && !h.hasOutstandingCryptoPackets() &&
!h.appDataPackets.history.HasOutstandingPackets() && !h.appDataPackets.history.HasOutstandingPathProbes() {
return alarmTimer{}
}
// cancel the alarm if amplification limited
if h.isAmplificationLimited() {
return alarmTimer{}
}
var pathProbeLossTime time.Time
if h.appDataPackets.history.HasOutstandingPathProbes() {
if p := h.appDataPackets.history.FirstOutstandingPathProbe(); p != nil {
pathProbeLossTime = p.SendTime.Add(pathProbePacketLossTimeout)
}
}
// early retransmit timer or time loss detection
lossTime, encLevel := h.getLossTimeAndSpace()
if !lossTime.IsZero() && (pathProbeLossTime.IsZero() || lossTime.Before(pathProbeLossTime)) {
return alarmTimer{
Time: lossTime,
TimerType: logging.TimerTypeACK,
EncryptionLevel: encLevel,
}
}
ptoTime, encLevel := h.getPTOTimeAndSpace(now)
if !ptoTime.IsZero() && (pathProbeLossTime.IsZero() || ptoTime.Before(pathProbeLossTime)) {
return alarmTimer{
Time: ptoTime,
TimerType: logging.TimerTypePTO,
EncryptionLevel: encLevel,
}
}
if !pathProbeLossTime.IsZero() {
return alarmTimer{
Time: pathProbeLossTime,
TimerType: logging.TimerTypePathProbe,
EncryptionLevel: protocol.Encryption1RTT,
}
}
return alarmTimer{}
}
func (h *sentPacketHandler) detectLostPathProbes(now time.Time) {
if !h.appDataPackets.history.HasOutstandingPathProbes() {
return
}
lossTime := now.Add(-pathProbePacketLossTimeout)
// RemovePathProbe cannot be called while iterating.
var lostPathProbes []*packet
for p := range h.appDataPackets.history.PathProbes() {
if !p.SendTime.After(lossTime) {
lostPathProbes = append(lostPathProbes, p)
}
}
for _, p := range lostPathProbes {
for _, f := range p.Frames {
f.Handler.OnLost(f.Frame)
}
h.appDataPackets.history.RemovePathProbe(p.PacketNumber)
}
}
func (h *sentPacketHandler) detectLostPackets(now time.Time, encLevel protocol.EncryptionLevel) {
pnSpace := h.getPacketNumberSpace(encLevel)
pnSpace.lossTime = time.Time{}
maxRTT := float64(max(h.rttStats.LatestRTT(), h.rttStats.SmoothedRTT()))
lossDelay := time.Duration(timeThreshold * maxRTT)
// Minimum time of granularity before packets are deemed lost.
lossDelay = max(lossDelay, protocol.TimerGranularity)
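// For example: with an RTT of 80ms, lossDelay works out to 90ms (9/8 * 80ms).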
// Packets sent before this time are deemed lost.
lostSendTime := now.Add(-lossDelay)
priorInFlight := h.bytesInFlight
for p := range pnSpace.history.Packets() {
if p.PacketNumber > pnSpace.largestAcked {
break
}
isRegularPacket := !p.skippedPacket && !p.isPathProbePacket
var packetLost bool
if !p.SendTime.After(lostSendTime) {
packetLost = true
if isRegularPacket {
if h.logger.Debug() {
h.logger.Debugf("\tlost packet %d (time threshold)", p.PacketNumber)
}
if h.tracer != nil && h.tracer.LostPacket != nil {
h.tracer.LostPacket(p.EncryptionLevel, p.PacketNumber, logging.PacketLossTimeThreshold)
}
}
} else if pnSpace.largestAcked >= p.PacketNumber+packetThreshold {
packetLost = true
if isRegularPacket {
if h.logger.Debug() {
h.logger.Debugf("\tlost packet %d (reordering threshold)", p.PacketNumber)
}
if h.tracer != nil && h.tracer.LostPacket != nil {
h.tracer.LostPacket(p.EncryptionLevel, p.PacketNumber, logging.PacketLossReorderingThreshold)
}
}
} else if pnSpace.lossTime.IsZero() {
// Note: This conditional is only entered once per call
lossTime := p.SendTime.Add(lossDelay)
if h.logger.Debug() {
h.logger.Debugf("\tsetting loss timer for packet %d (%s) to %s (in %s)", p.PacketNumber, encLevel, lossDelay, lossTime)
}
pnSpace.lossTime = lossTime
}
if packetLost {
pnSpace.history.DeclareLost(p.PacketNumber)
if isRegularPacket {
// the bytes in flight need to be reduced no matter if the frames in this packet will be retransmitted
h.removeFromBytesInFlight(p)
h.queueFramesForRetransmission(p)
if !p.IsPathMTUProbePacket {
h.congestion.OnCongestionEvent(p.PacketNumber, p.Length, priorInFlight)
}
if encLevel == protocol.Encryption1RTT && h.ecnTracker != nil {
h.ecnTracker.LostPacket(p.PacketNumber)
}
}
}
}
}
func (h *sentPacketHandler) OnLossDetectionTimeout(now time.Time) error {
defer h.setLossDetectionTimer(now)
if h.handshakeConfirmed {
h.detectLostPathProbes(now)
}
earliestLossTime, encLevel := h.getLossTimeAndSpace()
if !earliestLossTime.IsZero() {
if h.logger.Debug() {
h.logger.Debugf("Loss detection alarm fired in loss timer mode. Loss time: %s", earliestLossTime)
}
if h.tracer != nil && h.tracer.LossTimerExpired != nil {
h.tracer.LossTimerExpired(logging.TimerTypeACK, encLevel)
}
// Early retransmit or time loss detection
h.detectLostPackets(now, encLevel)
return nil
}
// PTO
// When all outstanding packets are acknowledged, the alarm is canceled in setLossDetectionTimer.
// However, there's no way to reset the timer in the connection.
// When OnLossDetectionTimeout is called, we therefore need to make sure that there are
// actually packets outstanding.
if h.bytesInFlight == 0 && !h.peerCompletedAddressValidation {
h.ptoCount++
h.numProbesToSend++
if h.initialPackets != nil {
h.ptoMode = SendPTOInitial
} else if h.handshakePackets != nil {
h.ptoMode = SendPTOHandshake
} else {
return errors.New("sentPacketHandler BUG: PTO fired, but bytes_in_flight is 0 and Initial and Handshake already dropped")
}
return nil
}
ptoTime, encLevel := h.getPTOTimeAndSpace(now)
if ptoTime.IsZero() {
return nil
}
ps := h.getPacketNumberSpace(encLevel)
if !ps.history.HasOutstandingPackets() && !ps.history.HasOutstandingPathProbes() && !h.peerCompletedAddressValidation {
return nil
}
h.ptoCount++
if h.logger.Debug() {
h.logger.Debugf("Loss detection alarm for %s fired in PTO mode. PTO count: %d", encLevel, h.ptoCount)
}
if h.tracer != nil {
if h.tracer.LossTimerExpired != nil {
h.tracer.LossTimerExpired(logging.TimerTypePTO, encLevel)
}
if h.tracer.UpdatedPTOCount != nil {
h.tracer.UpdatedPTOCount(h.ptoCount)
}
}
h.numProbesToSend += 2
//nolint:exhaustive // We never arm a PTO timer for 0-RTT packets.
switch encLevel {
case protocol.EncryptionInitial:
h.ptoMode = SendPTOInitial
case protocol.EncryptionHandshake:
h.ptoMode = SendPTOHandshake
case protocol.Encryption1RTT:
// skip a packet number in order to elicit an immediate ACK
pn := h.PopPacketNumber(protocol.Encryption1RTT)
h.getPacketNumberSpace(protocol.Encryption1RTT).history.SkippedPacket(pn)
h.ptoMode = SendPTOAppData
default:
return fmt.Errorf("PTO timer in unexpected encryption level: %s", encLevel)
}
return nil
}
func (h *sentPacketHandler) GetLossDetectionTimeout() time.Time {
return h.alarm.Time
}
func (h *sentPacketHandler) ECNMode(isShortHeaderPacket bool) protocol.ECN {
if !h.enableECN {
return protocol.ECNUnsupported
}
if !isShortHeaderPacket {
return protocol.ECNNon
}
return h.ecnTracker.Mode()
}
func (h *sentPacketHandler) PeekPacketNumber(encLevel protocol.EncryptionLevel) (protocol.PacketNumber, protocol.PacketNumberLen) {
pnSpace := h.getPacketNumberSpace(encLevel)
pn := pnSpace.pns.Peek()
// See section 17.1 of RFC 9000.
return pn, protocol.PacketNumberLengthForHeader(pn, pnSpace.largestAcked)
}
func (h *sentPacketHandler) PopPacketNumber(encLevel protocol.EncryptionLevel) protocol.PacketNumber {
pnSpace := h.getPacketNumberSpace(encLevel)
skipped, pn := pnSpace.pns.Pop()
if skipped {
skippedPN := pn - 1
pnSpace.history.SkippedPacket(skippedPN)
if h.logger.Debug() {
h.logger.Debugf("Skipping packet number %d", skippedPN)
}
}
return pn
}
func (h *sentPacketHandler) SendMode(now time.Time) SendMode {
numTrackedPackets := h.appDataPackets.history.Len()
if h.initialPackets != nil {
numTrackedPackets += h.initialPackets.history.Len()
}
if h.handshakePackets != nil {
numTrackedPackets += h.handshakePackets.history.Len()
}
if h.isAmplificationLimited() {
h.logger.Debugf("Amplification window limited. Received %d bytes, already sent out %d bytes", h.bytesReceived, h.bytesSent)
return SendNone
}
// Don't send any packets if we're already tracking the maximum number of packets.
// Note that since MaxOutstandingSentPackets is smaller than MaxTrackedSentPackets,
// we will stop sending out new data when reaching MaxOutstandingSentPackets,
// but still allow sending of retransmissions and ACKs.
if numTrackedPackets >= protocol.MaxTrackedSentPackets {
if h.logger.Debug() {
h.logger.Debugf("Limited by the number of tracked packets: tracking %d packets, maximum %d", numTrackedPackets, protocol.MaxTrackedSentPackets)
}
return SendNone
}
if h.numProbesToSend > 0 {
return h.ptoMode
}
// Only send ACKs if we're congestion limited.
if !h.congestion.CanSend(h.bytesInFlight) {
if h.logger.Debug() {
h.logger.Debugf("Congestion limited: bytes in flight %d, window %d", h.bytesInFlight, h.congestion.GetCongestionWindow())
}
return SendAck
}
if numTrackedPackets >= protocol.MaxOutstandingSentPackets {
if h.logger.Debug() {
h.logger.Debugf("Max outstanding limited: tracking %d packets, maximum: %d", numTrackedPackets, protocol.MaxOutstandingSentPackets)
}
return SendAck
}
if !h.congestion.HasPacingBudget(now) {
return SendPacingLimited
}
return SendAny
}
func (h *sentPacketHandler) TimeUntilSend() time.Time {
return h.congestion.TimeUntilSend(h.bytesInFlight)
}
func (h *sentPacketHandler) SetMaxDatagramSize(s protocol.ByteCount) {
h.congestion.SetMaxDatagramSize(s)
}
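// isAmplificationLimited says if we are bound by the anti-amplification limit,
// i.e. if we are a server that hasn't validated the peer's address yet and has already
// sent amplificationFactor (3x) times the number of bytes received.
// For example: after receiving a single 1200-byte packet, at most 3600 bytes may be sent.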
func (h *sentPacketHandler) isAmplificationLimited() bool {
if h.peerAddressValidated {
return false
}
return h.bytesSent >= amplificationFactor*h.bytesReceived
}
func (h *sentPacketHandler) QueueProbePacket(encLevel protocol.EncryptionLevel) bool {
pnSpace := h.getPacketNumberSpace(encLevel)
p := pnSpace.history.FirstOutstanding()
if p == nil {
return false
}
h.queueFramesForRetransmission(p)
// TODO: don't declare the packet lost here.
// Keep track of acknowledged frames instead.
h.removeFromBytesInFlight(p)
pnSpace.history.DeclareLost(p.PacketNumber)
return true
}
func (h *sentPacketHandler) queueFramesForRetransmission(p *packet) {
if len(p.Frames) == 0 && len(p.StreamFrames) == 0 {
panic("no frames")
}
for _, f := range p.Frames {
if f.Handler != nil {
f.Handler.OnLost(f.Frame)
}
}
for _, f := range p.StreamFrames {
if f.Handler != nil {
f.Handler.OnLost(f.Frame)
}
}
p.StreamFrames = nil
p.Frames = nil
}
func (h *sentPacketHandler) ResetForRetry(now time.Time) {
h.bytesInFlight = 0
var firstPacketSendTime time.Time
for p := range h.initialPackets.history.Packets() {
if firstPacketSendTime.IsZero() {
firstPacketSendTime = p.SendTime
}
if !p.declaredLost && !p.skippedPacket {
h.queueFramesForRetransmission(p)
}
}
// All application data packets sent at this point are 0-RTT packets.
// In the case of a Retry, we can assume that the server dropped all of them.
for p := range h.appDataPackets.history.Packets() {
if !p.declaredLost && !p.skippedPacket {
h.queueFramesForRetransmission(p)
}
}
// Only use the Retry to estimate the RTT if we didn't send any retransmission for the Initial.
// Otherwise, we don't know which Initial the Retry was sent in response to.
if h.ptoCount == 0 {
// Don't set the RTT to a value lower than 5ms here.
h.rttStats.UpdateRTT(max(minRTTAfterRetry, now.Sub(firstPacketSendTime)), 0)
if h.logger.Debug() {
h.logger.Debugf("\tupdated RTT: %s (σ: %s)", h.rttStats.SmoothedRTT(), h.rttStats.MeanDeviation())
}
if h.tracer != nil && h.tracer.UpdatedMetrics != nil {
h.tracer.UpdatedMetrics(h.rttStats, h.congestion.GetCongestionWindow(), h.bytesInFlight, h.packetsInFlight())
}
}
h.initialPackets = newPacketNumberSpace(h.initialPackets.pns.Peek(), false)
h.appDataPackets = newPacketNumberSpace(h.appDataPackets.pns.Peek(), true)
oldAlarm := h.alarm
h.alarm = alarmTimer{}
if h.tracer != nil {
if h.tracer.UpdatedPTOCount != nil {
h.tracer.UpdatedPTOCount(0)
}
if !oldAlarm.Time.IsZero() && h.tracer.LossTimerCanceled != nil {
h.tracer.LossTimerCanceled()
}
}
h.ptoCount = 0
}
func (h *sentPacketHandler) MigratedPath(now time.Time, initialMaxDatagramSize protocol.ByteCount) {
h.rttStats.ResetForPathMigration()
for p := range h.appDataPackets.history.Packets() {
h.appDataPackets.history.DeclareLost(p.PacketNumber)
if !p.skippedPacket && !p.isPathProbePacket {
h.removeFromBytesInFlight(p)
h.queueFramesForRetransmission(p)
}
}
for p := range h.appDataPackets.history.PathProbes() {
h.appDataPackets.history.RemovePathProbe(p.PacketNumber)
}
h.congestion = congestion.NewCubicSender(
congestion.DefaultClock{},
h.rttStats,
initialMaxDatagramSize,
true, // use Reno
h.tracer,
)
h.setLossDetectionTimer(now)
}
package ackhandler
import (
"fmt"
"iter"
"github.com/quic-go/quic-go/internal/protocol"
)
type sentPacketHistory struct {
packets []*packet
pathProbePackets []*packet
numOutstanding int
highestPacketNumber protocol.PacketNumber
}
func newSentPacketHistory(isAppData bool) *sentPacketHistory {
h := &sentPacketHistory{
highestPacketNumber: protocol.InvalidPacketNumber,
}
if isAppData {
h.packets = make([]*packet, 0, 32)
} else {
h.packets = make([]*packet, 0, 6)
}
return h
}
func (h *sentPacketHistory) checkSequentialPacketNumberUse(pn protocol.PacketNumber) {
if h.highestPacketNumber != protocol.InvalidPacketNumber {
if pn != h.highestPacketNumber+1 {
panic("non-sequential packet number use")
}
}
h.highestPacketNumber = pn
}
func (h *sentPacketHistory) SkippedPacket(pn protocol.PacketNumber) {
h.checkSequentialPacketNumberUse(pn)
h.packets = append(h.packets, &packet{
PacketNumber: pn,
skippedPacket: true,
})
}
func (h *sentPacketHistory) SentNonAckElicitingPacket(pn protocol.PacketNumber) {
h.checkSequentialPacketNumberUse(pn)
if len(h.packets) > 0 {
h.packets = append(h.packets, nil)
}
}
func (h *sentPacketHistory) SentAckElicitingPacket(p *packet) {
h.checkSequentialPacketNumberUse(p.PacketNumber)
h.packets = append(h.packets, p)
if p.outstanding() {
h.numOutstanding++
}
}
func (h *sentPacketHistory) SentPathProbePacket(p *packet) {
h.checkSequentialPacketNumberUse(p.PacketNumber)
h.packets = append(h.packets, &packet{
PacketNumber: p.PacketNumber,
isPathProbePacket: true,
})
h.pathProbePackets = append(h.pathProbePackets, p)
}
func (h *sentPacketHistory) Packets() iter.Seq[*packet] {
return func(yield func(*packet) bool) {
for _, p := range h.packets {
if p == nil {
continue
}
if !yield(p) {
return
}
}
}
}
func (h *sentPacketHistory) PathProbes() iter.Seq[*packet] {
return func(yield func(*packet) bool) {
for _, p := range h.pathProbePackets {
if !yield(p) {
return
}
}
}
}
// FirstOutstanding returns the first outstanding packet.
func (h *sentPacketHistory) FirstOutstanding() *packet {
if !h.HasOutstandingPackets() {
return nil
}
for _, p := range h.packets {
if p != nil && p.outstanding() {
return p
}
}
return nil
}
// FirstOutstandingPathProbe returns the first outstanding path probe packet
func (h *sentPacketHistory) FirstOutstandingPathProbe() *packet {
if len(h.pathProbePackets) == 0 {
return nil
}
return h.pathProbePackets[0]
}
func (h *sentPacketHistory) Len() int {
return len(h.packets)
}
func (h *sentPacketHistory) Remove(pn protocol.PacketNumber) error {
idx, ok := h.getIndex(pn)
if !ok {
return fmt.Errorf("packet %d not found in sent packet history", pn)
}
p := h.packets[idx]
if p.outstanding() {
h.numOutstanding--
if h.numOutstanding < 0 {
panic("negative number of outstanding packets")
}
}
h.packets[idx] = nil
// clean up all skipped packets directly before this packet number
for idx > 0 {
idx--
p := h.packets[idx]
if p == nil || !p.skippedPacket {
break
}
h.packets[idx] = nil
}
if idx == 0 {
h.cleanupStart()
}
if len(h.packets) > 0 && h.packets[0] == nil {
panic("remove failed")
}
return nil
}
// RemovePathProbe removes a path probe packet.
// It runs in O(N), but that's ok, since we don't expect to send many path probe packets.
// It is not valid to call this function while iterating over PathProbes.
func (h *sentPacketHistory) RemovePathProbe(pn protocol.PacketNumber) *packet {
var packetToDelete *packet
idx := -1
for i, p := range h.pathProbePackets {
if p.PacketNumber == pn {
packetToDelete = p
idx = i
break
}
}
if idx != -1 {
// don't use slices.Delete, because it zeros the deleted element
copy(h.pathProbePackets[idx:], h.pathProbePackets[idx+1:])
h.pathProbePackets = h.pathProbePackets[:len(h.pathProbePackets)-1]
}
return packetToDelete
}
// getIndex gets the index of packet p in the packets slice.
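// Since packet numbers are used sequentially (enforced by checkSequentialPacketNumberUse),
// the index is simply the offset from the first packet number in the slice.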
func (h *sentPacketHistory) getIndex(p protocol.PacketNumber) (int, bool) {
if len(h.packets) == 0 {
return 0, false
}
first := h.packets[0].PacketNumber
if p < first {
return 0, false
}
index := int(p - first)
if index > len(h.packets)-1 {
return 0, false
}
return index, true
}
func (h *sentPacketHistory) HasOutstandingPackets() bool {
return h.numOutstanding > 0
}
func (h *sentPacketHistory) HasOutstandingPathProbes() bool {
return len(h.pathProbePackets) > 0
}
// delete all nil entries at the beginning of the packets slice
func (h *sentPacketHistory) cleanupStart() {
for i, p := range h.packets {
if p != nil {
h.packets = h.packets[i:]
return
}
}
h.packets = h.packets[:0]
}
func (h *sentPacketHistory) LowestPacketNumber() protocol.PacketNumber {
if len(h.packets) == 0 {
return protocol.InvalidPacketNumber
}
return h.packets[0].PacketNumber
}
func (h *sentPacketHistory) DeclareLost(pn protocol.PacketNumber) {
idx, ok := h.getIndex(pn)
if !ok {
return
}
p := h.packets[idx]
if p.outstanding() {
h.numOutstanding--
if h.numOutstanding < 0 {
panic("negative number of outstanding packets")
}
}
h.packets[idx] = nil
if idx == 0 {
h.cleanupStart()
}
}
package congestion
import (
"math"
"time"
"github.com/quic-go/quic-go/internal/protocol"
)
// Bandwidth of a connection
type Bandwidth uint64
const infBandwidth Bandwidth = math.MaxUint64
const (
// BitsPerSecond is 1 bit per second
BitsPerSecond Bandwidth = 1
// BytesPerSecond is 1 byte per second
BytesPerSecond = 8 * BitsPerSecond
)
// BandwidthFromDelta calculates the bandwidth from a number of bytes and a time delta
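// For example: 1000 bytes received within 1ms corresponds to 8,000,000 (i.e. 8 Mbit/s).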
func BandwidthFromDelta(bytes protocol.ByteCount, delta time.Duration) Bandwidth {
return Bandwidth(bytes) * Bandwidth(time.Second) / Bandwidth(delta) * BytesPerSecond
}
package congestion
import "time"
// A Clock returns the current time
type Clock interface {
Now() time.Time
}
// DefaultClock implements the Clock interface using the Go stdlib clock.
type DefaultClock struct{}
var _ Clock = DefaultClock{}
// Now gets the current time
func (DefaultClock) Now() time.Time {
return time.Now()
}
package congestion
import (
"math"
"time"
"github.com/quic-go/quic-go/internal/protocol"
)
// This cubic implementation is based on the one found in Chromium's QUIC
// implementation, in the files net/quic/congestion_control/cubic.{hh,cc}.
// Constants based on TCP defaults.
// The following constants are in 2^10 fractions of a second instead of ms to
// allow a 10 shift right to divide.
// 1024*1024^3 (first 1024 is from 0.100^3)
// where 0.100 is 100 ms which is the scaling round trip time.
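// With these values, the cubic term computed in CongestionWindowAfterAck works out to
// roughly (410/1024) * t^3 ≈ 0.4 * t^3 datagrams (with t in seconds), matching the
// constant C = 0.4 from the CUBIC paper / RFC 8312.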
const (
cubeScale = 40
cubeCongestionWindowScale = 410
cubeFactor = 1 << cubeScale / cubeCongestionWindowScale / maxDatagramSize
// TODO: when re-enabling cubic, make sure to use the actual packet size here
maxDatagramSize = protocol.ByteCount(protocol.InitialPacketSize)
)
const defaultNumConnections = 1
// Default Cubic backoff factor
const beta float32 = 0.7
// Additional backoff factor when loss occurs in the concave part of the Cubic
// curve. This additional backoff factor is expected to give up bandwidth to
// new concurrent flows and speed up convergence.
const betaLastMax float32 = 0.85
// Cubic implements the cubic algorithm from TCP
type Cubic struct {
clock Clock
// Number of connections to simulate.
numConnections int
// Time when this cycle started, after last loss event.
epoch time.Time
// Max congestion window used just before last loss event.
// Note: to improve fairness to other streams an additional back off is
// applied to this value if the new value is below our latest value.
lastMaxCongestionWindow protocol.ByteCount
// Number of acked bytes since the cycle started (epoch).
ackedBytesCount protocol.ByteCount
// TCP Reno equivalent congestion window in bytes.
estimatedTCPcongestionWindow protocol.ByteCount
// Origin point of cubic function.
originPointCongestionWindow protocol.ByteCount
// Time to origin point of cubic function in 2^10 fractions of a second.
timeToOriginPoint uint32
// Last congestion window in bytes computed by the cubic function.
lastTargetCongestionWindow protocol.ByteCount
}
// NewCubic returns a new Cubic instance
func NewCubic(clock Clock) *Cubic {
c := &Cubic{
clock: clock,
numConnections: defaultNumConnections,
}
c.Reset()
return c
}
// Reset is called after a timeout to reset the cubic state
func (c *Cubic) Reset() {
c.epoch = time.Time{}
c.lastMaxCongestionWindow = 0
c.ackedBytesCount = 0
c.estimatedTCPcongestionWindow = 0
c.originPointCongestionWindow = 0
c.timeToOriginPoint = 0
c.lastTargetCongestionWindow = 0
}
func (c *Cubic) alpha() float32 {
// TCPFriendly alpha is described in Section 3.3 of the CUBIC paper. Note that
// beta here is a cwnd multiplier, and is equal to 1-beta from the paper.
// We derive the equivalent alpha for an N-connection emulation as:
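// For a single connection with beta = 0.7 this gives alpha = 3 * 0.3 / 1.7 ≈ 0.53.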
b := c.beta()
return 3 * float32(c.numConnections) * float32(c.numConnections) * (1 - b) / (1 + b)
}
func (c *Cubic) beta() float32 {
// kNConnectionBeta is the backoff factor after loss for our N-connection
// emulation, which emulates the effective backoff of an ensemble of N
// TCP-Reno connections on a single loss event. The effective multiplier is
// computed as:
return (float32(c.numConnections) - 1 + beta) / float32(c.numConnections)
}
func (c *Cubic) betaLastMax() float32 {
// betaLastMax is the additional backoff factor after loss for our
// N-connection emulation, which emulates the additional backoff of
// an ensemble of N TCP-Reno connections on a single loss event. The
// effective multiplier is computed as:
return (float32(c.numConnections) - 1 + betaLastMax) / float32(c.numConnections)
}
// OnApplicationLimited is called on ack arrival when sender is unable to use
// the available congestion window. Resets Cubic state during quiescence.
func (c *Cubic) OnApplicationLimited() {
// When sender is not using the available congestion window, the window does
// not grow. But to be RTT-independent, Cubic assumes that the sender has been
// using the entire window during the time since the beginning of the current
// "epoch" (the end of the last loss recovery period). Since
// application-limited periods break this assumption, we reset the epoch when
// in such a period. This reset effectively freezes congestion window growth
// through application-limited periods and allows Cubic growth to continue
// when the entire window is being used.
c.epoch = time.Time{}
}
// CongestionWindowAfterPacketLoss computes a new congestion window to use after
// a loss event. Returns the new congestion window in bytes. The new
// congestion window is a multiplicative decrease of our current window.
func (c *Cubic) CongestionWindowAfterPacketLoss(currentCongestionWindow protocol.ByteCount) protocol.ByteCount {
if currentCongestionWindow+maxDatagramSize < c.lastMaxCongestionWindow {
// We never reached the old max, so assume we are competing with another
// flow. Use our extra back off factor to allow the other flow to go up.
c.lastMaxCongestionWindow = protocol.ByteCount(c.betaLastMax() * float32(currentCongestionWindow))
} else {
c.lastMaxCongestionWindow = currentCongestionWindow
}
c.epoch = time.Time{} // Reset time.
return protocol.ByteCount(float32(currentCongestionWindow) * c.beta())
}
// CongestionWindowAfterAck computes a new congestion window to use after a received ACK.
// Returns the new congestion window in bytes. The new congestion window
// follows a cubic function that depends on the time passed since last
// packet loss.
func (c *Cubic) CongestionWindowAfterAck(
ackedBytes protocol.ByteCount,
currentCongestionWindow protocol.ByteCount,
delayMin time.Duration,
eventTime time.Time,
) protocol.ByteCount {
c.ackedBytesCount += ackedBytes
if c.epoch.IsZero() {
// First ACK after a loss event.
c.epoch = eventTime // Start of epoch.
c.ackedBytesCount = ackedBytes // Reset count.
// Reset estimated_tcp_congestion_window_ to be in sync with cubic.
c.estimatedTCPcongestionWindow = currentCongestionWindow
if c.lastMaxCongestionWindow <= currentCongestionWindow {
c.timeToOriginPoint = 0
c.originPointCongestionWindow = currentCongestionWindow
} else {
c.timeToOriginPoint = uint32(math.Cbrt(float64(cubeFactor * (c.lastMaxCongestionWindow - currentCongestionWindow))))
c.originPointCongestionWindow = c.lastMaxCongestionWindow
}
}
// Change the time unit from microseconds to 2^10 fractions per second. Take
// the round trip time in account. This is done to allow us to use shift as a
// divide operator.
elapsedTime := int64(eventTime.Add(delayMin).Sub(c.epoch)/time.Microsecond) << 10 / (1000 * 1000)
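// For example: an elapsed time of 500ms becomes 512 (0.5 * 1024) in these units.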
// Right-shifts of negative, signed numbers have implementation-dependent
// behavior, so force the offset to be positive, as is done in the kernel.
offset := int64(c.timeToOriginPoint) - elapsedTime
if offset < 0 {
offset = -offset
}
deltaCongestionWindow := protocol.ByteCount(cubeCongestionWindowScale*offset*offset*offset) * maxDatagramSize >> cubeScale
var targetCongestionWindow protocol.ByteCount
if elapsedTime > int64(c.timeToOriginPoint) {
targetCongestionWindow = c.originPointCongestionWindow + deltaCongestionWindow
} else {
targetCongestionWindow = c.originPointCongestionWindow - deltaCongestionWindow
}
// Limit the CWND increase to half the acked bytes.
targetCongestionWindow = min(targetCongestionWindow, currentCongestionWindow+c.ackedBytesCount/2)
// Increase the window by approximately Alpha * 1 MSS of bytes every
// time we ack an estimated tcp window of bytes. For small
// congestion windows (less than 25), the formula below will
// increase slightly slower than linearly per estimated tcp window
// of bytes.
c.estimatedTCPcongestionWindow += protocol.ByteCount(float32(c.ackedBytesCount) * c.alpha() * float32(maxDatagramSize) / float32(c.estimatedTCPcongestionWindow))
c.ackedBytesCount = 0
// We have a new cubic congestion window.
c.lastTargetCongestionWindow = targetCongestionWindow
// Compute target congestion_window based on cubic target and estimated TCP
// congestion_window, use highest (fastest).
if targetCongestionWindow < c.estimatedTCPcongestionWindow {
targetCongestionWindow = c.estimatedTCPcongestionWindow
}
return targetCongestionWindow
}
// SetNumConnections sets the number of emulated connections
func (c *Cubic) SetNumConnections(n int) {
c.numConnections = n
}
package congestion
import (
"fmt"
"time"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/logging"
)
const (
// initialMaxDatagramSize is the default maximum packet size used in the Linux TCP implementation.
// Used in QUIC for congestion window computations in bytes.
initialMaxDatagramSize = protocol.ByteCount(protocol.InitialPacketSize)
maxBurstPackets = 3
renoBeta = 0.7 // Reno backoff factor.
minCongestionWindowPackets = 2
initialCongestionWindow = 32
)
type cubicSender struct {
hybridSlowStart HybridSlowStart
rttStats *utils.RTTStats
cubic *Cubic
pacer *pacer
clock Clock
reno bool
// Track the largest packet that has been sent.
largestSentPacketNumber protocol.PacketNumber
// Track the largest packet that has been acked.
largestAckedPacketNumber protocol.PacketNumber
// Track the largest packet number outstanding when a CWND cutback occurs.
largestSentAtLastCutback protocol.PacketNumber
// Whether the last loss event caused us to exit slowstart.
// Used for stats collection of slowstartPacketsLost
lastCutbackExitedSlowstart bool
// Congestion window in bytes.
congestionWindow protocol.ByteCount
// Slow start congestion window in bytes, aka ssthresh.
slowStartThreshold protocol.ByteCount
// ACK counter for the Reno implementation.
numAckedPackets uint64
initialCongestionWindow protocol.ByteCount
initialMaxCongestionWindow protocol.ByteCount
maxDatagramSize protocol.ByteCount
lastState logging.CongestionState
tracer *logging.ConnectionTracer
}
var (
_ SendAlgorithm = &cubicSender{}
_ SendAlgorithmWithDebugInfos = &cubicSender{}
)
// NewCubicSender makes a new cubic sender
func NewCubicSender(
clock Clock,
rttStats *utils.RTTStats,
initialMaxDatagramSize protocol.ByteCount,
reno bool,
tracer *logging.ConnectionTracer,
) *cubicSender {
return newCubicSender(
clock,
rttStats,
reno,
initialMaxDatagramSize,
initialCongestionWindow*initialMaxDatagramSize,
protocol.MaxCongestionWindowPackets*initialMaxDatagramSize,
tracer,
)
}
func newCubicSender(
clock Clock,
rttStats *utils.RTTStats,
reno bool,
initialMaxDatagramSize,
initialCongestionWindow,
initialMaxCongestionWindow protocol.ByteCount,
tracer *logging.ConnectionTracer,
) *cubicSender {
c := &cubicSender{
rttStats: rttStats,
largestSentPacketNumber: protocol.InvalidPacketNumber,
largestAckedPacketNumber: protocol.InvalidPacketNumber,
largestSentAtLastCutback: protocol.InvalidPacketNumber,
initialCongestionWindow: initialCongestionWindow,
initialMaxCongestionWindow: initialMaxCongestionWindow,
congestionWindow: initialCongestionWindow,
slowStartThreshold: protocol.MaxByteCount,
cubic: NewCubic(clock),
clock: clock,
reno: reno,
tracer: tracer,
maxDatagramSize: initialMaxDatagramSize,
}
c.pacer = newPacer(c.BandwidthEstimate)
if c.tracer != nil && c.tracer.UpdatedCongestionState != nil {
c.lastState = logging.CongestionStateSlowStart
c.tracer.UpdatedCongestionState(logging.CongestionStateSlowStart)
}
return c
}
// TimeUntilSend returns when the next packet should be sent.
func (c *cubicSender) TimeUntilSend(_ protocol.ByteCount) time.Time {
return c.pacer.TimeUntilSend()
}
func (c *cubicSender) HasPacingBudget(now time.Time) bool {
return c.pacer.Budget(now) >= c.maxDatagramSize
}
func (c *cubicSender) maxCongestionWindow() protocol.ByteCount {
return c.maxDatagramSize * protocol.MaxCongestionWindowPackets
}
func (c *cubicSender) minCongestionWindow() protocol.ByteCount {
return c.maxDatagramSize * minCongestionWindowPackets
}
func (c *cubicSender) OnPacketSent(
sentTime time.Time,
_ protocol.ByteCount,
packetNumber protocol.PacketNumber,
bytes protocol.ByteCount,
isRetransmittable bool,
) {
c.pacer.SentPacket(sentTime, bytes)
if !isRetransmittable {
return
}
c.largestSentPacketNumber = packetNumber
c.hybridSlowStart.OnPacketSent(packetNumber)
}
func (c *cubicSender) CanSend(bytesInFlight protocol.ByteCount) bool {
return bytesInFlight < c.GetCongestionWindow()
}
func (c *cubicSender) InRecovery() bool {
return c.largestAckedPacketNumber != protocol.InvalidPacketNumber && c.largestAckedPacketNumber <= c.largestSentAtLastCutback
}
func (c *cubicSender) InSlowStart() bool {
return c.GetCongestionWindow() < c.slowStartThreshold
}
func (c *cubicSender) GetCongestionWindow() protocol.ByteCount {
return c.congestionWindow
}
func (c *cubicSender) MaybeExitSlowStart() {
if c.InSlowStart() &&
c.hybridSlowStart.ShouldExitSlowStart(c.rttStats.LatestRTT(), c.rttStats.MinRTT(), c.GetCongestionWindow()/c.maxDatagramSize) {
// exit slow start
c.slowStartThreshold = c.congestionWindow
c.maybeTraceStateChange(logging.CongestionStateCongestionAvoidance)
}
}
func (c *cubicSender) OnPacketAcked(
ackedPacketNumber protocol.PacketNumber,
ackedBytes protocol.ByteCount,
priorInFlight protocol.ByteCount,
eventTime time.Time,
) {
c.largestAckedPacketNumber = max(ackedPacketNumber, c.largestAckedPacketNumber)
if c.InRecovery() {
return
}
c.maybeIncreaseCwnd(ackedPacketNumber, ackedBytes, priorInFlight, eventTime)
if c.InSlowStart() {
c.hybridSlowStart.OnPacketAcked(ackedPacketNumber)
}
}
func (c *cubicSender) OnCongestionEvent(packetNumber protocol.PacketNumber, lostBytes, priorInFlight protocol.ByteCount) {
// TCP NewReno (RFC6582) says that once a loss occurs, any losses in packets
// already sent should be treated as a single loss event, since it's expected.
if packetNumber <= c.largestSentAtLastCutback {
return
}
c.lastCutbackExitedSlowstart = c.InSlowStart()
c.maybeTraceStateChange(logging.CongestionStateRecovery)
if c.reno {
c.congestionWindow = protocol.ByteCount(float64(c.congestionWindow) * renoBeta)
} else {
c.congestionWindow = c.cubic.CongestionWindowAfterPacketLoss(c.congestionWindow)
}
if minCwnd := c.minCongestionWindow(); c.congestionWindow < minCwnd {
c.congestionWindow = minCwnd
}
c.slowStartThreshold = c.congestionWindow
c.largestSentAtLastCutback = c.largestSentPacketNumber
// reset packet count from congestion avoidance mode. We start
// counting again when we're out of recovery.
c.numAckedPackets = 0
}
// maybeIncreaseCwnd is called when we receive an ack. Normal TCP tracks how many packets one ack
// represents, but QUIC has a separate ack for each packet.
func (c *cubicSender) maybeIncreaseCwnd(
_ protocol.PacketNumber,
ackedBytes protocol.ByteCount,
priorInFlight protocol.ByteCount,
eventTime time.Time,
) {
// Do not increase the congestion window unless the sender is close to using
// the current window.
if !c.isCwndLimited(priorInFlight) {
c.cubic.OnApplicationLimited()
c.maybeTraceStateChange(logging.CongestionStateApplicationLimited)
return
}
if c.congestionWindow >= c.maxCongestionWindow() {
return
}
if c.InSlowStart() {
// TCP slow start, exponential growth, increase by one for each ACK.
c.congestionWindow += c.maxDatagramSize
c.maybeTraceStateChange(logging.CongestionStateSlowStart)
return
}
// Congestion avoidance
c.maybeTraceStateChange(logging.CongestionStateCongestionAvoidance)
if c.reno {
// Classic Reno congestion avoidance.
c.numAckedPackets++
if c.numAckedPackets >= uint64(c.congestionWindow/c.maxDatagramSize) {
c.congestionWindow += c.maxDatagramSize
c.numAckedPackets = 0
}
} else {
c.congestionWindow = min(c.maxCongestionWindow(), c.cubic.CongestionWindowAfterAck(ackedBytes, c.congestionWindow, c.rttStats.MinRTT(), eventTime))
}
}
func (c *cubicSender) isCwndLimited(bytesInFlight protocol.ByteCount) bool {
congestionWindow := c.GetCongestionWindow()
if bytesInFlight >= congestionWindow {
return true
}
availableBytes := congestionWindow - bytesInFlight
slowStartLimited := c.InSlowStart() && bytesInFlight > congestionWindow/2
return slowStartLimited || availableBytes <= maxBurstPackets*c.maxDatagramSize
}
// BandwidthEstimate returns the current bandwidth estimate
func (c *cubicSender) BandwidthEstimate() Bandwidth {
srtt := c.rttStats.SmoothedRTT()
if srtt == 0 {
// If we haven't measured an rtt, the bandwidth estimate is unknown.
return infBandwidth
}
return BandwidthFromDelta(c.GetCongestionWindow(), srtt)
}
// OnRetransmissionTimeout is called on a retransmission timeout
func (c *cubicSender) OnRetransmissionTimeout(packetsRetransmitted bool) {
c.largestSentAtLastCutback = protocol.InvalidPacketNumber
if !packetsRetransmitted {
return
}
c.hybridSlowStart.Restart()
c.cubic.Reset()
c.slowStartThreshold = c.congestionWindow / 2
c.congestionWindow = c.minCongestionWindow()
}
// OnConnectionMigration is called when the connection is migrated.
func (c *cubicSender) OnConnectionMigration() {
c.hybridSlowStart.Restart()
c.largestSentPacketNumber = protocol.InvalidPacketNumber
c.largestAckedPacketNumber = protocol.InvalidPacketNumber
c.largestSentAtLastCutback = protocol.InvalidPacketNumber
c.lastCutbackExitedSlowstart = false
c.cubic.Reset()
c.numAckedPackets = 0
c.congestionWindow = c.initialCongestionWindow
c.slowStartThreshold = c.initialMaxCongestionWindow
}
func (c *cubicSender) maybeTraceStateChange(new logging.CongestionState) {
if c.tracer == nil || c.tracer.UpdatedCongestionState == nil || new == c.lastState {
return
}
c.tracer.UpdatedCongestionState(new)
c.lastState = new
}
func (c *cubicSender) SetMaxDatagramSize(s protocol.ByteCount) {
if s < c.maxDatagramSize {
panic(fmt.Sprintf("congestion BUG: decreased max datagram size from %d to %d", c.maxDatagramSize, s))
}
cwndIsMinCwnd := c.congestionWindow == c.minCongestionWindow()
c.maxDatagramSize = s
if cwndIsMinCwnd {
c.congestionWindow = c.minCongestionWindow()
}
c.pacer.SetMaxDatagramSize(s)
}
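// Illustrative sketch (not part of the original source): the Reno-mode
// reaction described in OnCongestionEvent and maybeIncreaseCwnd above, using
// hypothetical numbers (1200-byte datagrams, a 40-packet window).
func exampleRenoReaction() (afterLoss, afterOneWindowOfAcks float64) {
	const mss = 1200.0
	const renoBeta = 0.7
	cwnd := 40 * mss
	// Loss: multiplicative backoff (clamping to the minimum window happens elsewhere).
	afterLoss = cwnd * renoBeta
	// Congestion avoidance: one MSS per full window of acked packets,
	// i.e. roughly +1 MSS per RTT.
	afterOneWindowOfAcks = afterLoss + mss
	return afterLoss, afterOneWindowOfAcks
}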
package congestion
import (
"time"
"github.com/quic-go/quic-go/internal/protocol"
)
// Note(pwestin): the magic clamping numbers come from the original code in
// tcp_cubic.c.
const hybridStartLowWindow = protocol.ByteCount(16)
// Number of delay samples used to detect an increase in delay.
const hybridStartMinSamples = uint32(8)
// Exit slow start if the min rtt has increased by more than 1/8th.
const hybridStartDelayFactorExp = 3 // 2^3 = 8
// The original paper specifies 2 and 8ms, but those have changed over time.
const (
hybridStartDelayMinThresholdUs = int64(4000)
hybridStartDelayMaxThresholdUs = int64(16000)
)
// HybridSlowStart implements the TCP hybrid slow start algorithm
type HybridSlowStart struct {
endPacketNumber protocol.PacketNumber
lastSentPacketNumber protocol.PacketNumber
started bool
currentMinRTT time.Duration
rttSampleCount uint32
hystartFound bool
}
// StartReceiveRound is called for the start of each receive round (burst) in the slow start phase.
func (s *HybridSlowStart) StartReceiveRound(lastSent protocol.PacketNumber) {
s.endPacketNumber = lastSent
s.currentMinRTT = 0
s.rttSampleCount = 0
s.started = true
}
// IsEndOfRound returns true if this ack is past the last packet number of our current slow start round.
func (s *HybridSlowStart) IsEndOfRound(ack protocol.PacketNumber) bool {
return s.endPacketNumber < ack
}
// ShouldExitSlowStart should be called on every new ack frame, since a new
// RTT measurement can be made then.
// latestRTT: the RTT measurement for this ack.
// minRTT: the lowest delay (RTT) we have seen during the session.
// congestionWindow: the congestion window in packets.
func (s *HybridSlowStart) ShouldExitSlowStart(latestRTT time.Duration, minRTT time.Duration, congestionWindow protocol.ByteCount) bool {
if !s.started {
// Time to start the hybrid slow start.
s.StartReceiveRound(s.lastSentPacketNumber)
}
if s.hystartFound {
return true
}
// Second detection parameter - delay increase detection.
// Compare the minimum delay (s.currentMinRTT) of the current
// burst of packets relative to the minimum delay during the session.
// Note: we only look at the first few (8) packets in each burst, since we
// only want to compare the lowest RTT of the burst relative to previous
// bursts.
s.rttSampleCount++
if s.rttSampleCount <= hybridStartMinSamples {
if s.currentMinRTT == 0 || s.currentMinRTT > latestRTT {
s.currentMinRTT = latestRTT
}
}
// We only need to check this once per round.
if s.rttSampleCount == hybridStartMinSamples {
// Divide minRTT by 8 to get an RTT increase threshold for exiting.
minRTTincreaseThresholdUs := int64(minRTT / time.Microsecond >> hybridStartDelayFactorExp)
// Ensure the RTT threshold is never less than 4ms or more than 16ms.
minRTTincreaseThresholdUs = min(minRTTincreaseThresholdUs, hybridStartDelayMaxThresholdUs)
minRTTincreaseThreshold := time.Duration(max(minRTTincreaseThresholdUs, hybridStartDelayMinThresholdUs)) * time.Microsecond
if s.currentMinRTT > (minRTT + minRTTincreaseThreshold) {
s.hystartFound = true
}
}
// Exit from slow start if the cwnd is at least 16 packets and
// an increase in delay was detected.
return congestionWindow >= hybridStartLowWindow && s.hystartFound
}
// OnPacketSent is called when a packet was sent
func (s *HybridSlowStart) OnPacketSent(packetNumber protocol.PacketNumber) {
s.lastSentPacketNumber = packetNumber
}
// OnPacketAcked gets invoked after ShouldExitSlowStart, so it's best to end
// the round when the final packet of the burst is received and start it on
// the next incoming ack.
func (s *HybridSlowStart) OnPacketAcked(ackedPacketNumber protocol.PacketNumber) {
if s.IsEndOfRound(ackedPacketNumber) {
s.started = false
}
}
// Started returns true if started
func (s *HybridSlowStart) Started() bool {
return s.started
}
// Restart the slow start phase
func (s *HybridSlowStart) Restart() {
s.started = false
s.hystartFound = false
}
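// Illustrative sketch (not part of the original source): the delay-increase
// threshold computed in ShouldExitSlowStart above, assuming a 40ms session
// minimum RTT and a 47ms minimum RTT for the current burst.
func exampleHystartDelayDetection() bool {
	minRTT := 40 * time.Millisecond             // lowest RTT seen during the session (assumed)
	currentBurstMinRTT := 47 * time.Millisecond // lowest RTT of the current burst (assumed)
	// The threshold is minRTT/8, clamped to [4ms, 16ms].
	thresholdUs := int64(minRTT/time.Microsecond) >> hybridStartDelayFactorExp // 5000us
	thresholdUs = min(thresholdUs, hybridStartDelayMaxThresholdUs)
	thresholdUs = max(thresholdUs, hybridStartDelayMinThresholdUs)
	threshold := time.Duration(thresholdUs) * time.Microsecond
	return currentBurstMinRTT > minRTT+threshold // 47ms > 45ms: delay increase detected
}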
package congestion
import (
"time"
"github.com/quic-go/quic-go/internal/protocol"
)
const maxBurstSizePackets = 10
// The pacer implements a token bucket pacing algorithm.
type pacer struct {
budgetAtLastSent protocol.ByteCount
maxDatagramSize protocol.ByteCount
lastSentTime time.Time
adjustedBandwidth func() uint64 // in bytes/s
}
func newPacer(getBandwidth func() Bandwidth) *pacer {
p := &pacer{
maxDatagramSize: initialMaxDatagramSize,
adjustedBandwidth: func() uint64 {
// Bandwidth is in bits/s. We need the value in bytes/s.
bw := uint64(getBandwidth() / BytesPerSecond)
// Use a slightly higher value than the actual measured bandwidth.
// RTT variations then won't result in under-utilization of the congestion window.
// Ultimately, this will result in sending packets as acknowledgments are received rather than when timers fire,
// provided the congestion window is fully utilized and acknowledgments arrive at regular intervals.
return bw * 5 / 4
},
}
p.budgetAtLastSent = p.maxBurstSize()
return p
}
func (p *pacer) SentPacket(sendTime time.Time, size protocol.ByteCount) {
budget := p.Budget(sendTime)
if size >= budget {
p.budgetAtLastSent = 0
} else {
p.budgetAtLastSent = budget - size
}
p.lastSentTime = sendTime
}
func (p *pacer) Budget(now time.Time) protocol.ByteCount {
if p.lastSentTime.IsZero() {
return p.maxBurstSize()
}
budget := p.budgetAtLastSent + (protocol.ByteCount(p.adjustedBandwidth())*protocol.ByteCount(now.Sub(p.lastSentTime).Nanoseconds()))/1e9
if budget < 0 { // protect against overflows
budget = protocol.MaxByteCount
}
return min(p.maxBurstSize(), budget)
}
func (p *pacer) maxBurstSize() protocol.ByteCount {
return max(
protocol.ByteCount(uint64((protocol.MinPacingDelay+protocol.TimerGranularity).Nanoseconds())*p.adjustedBandwidth())/1e9,
maxBurstSizePackets*p.maxDatagramSize,
)
}
// TimeUntilSend returns when the next packet should be sent.
// It returns the zero value of time.Time if a packet can be sent immediately.
func (p *pacer) TimeUntilSend() time.Time {
if p.budgetAtLastSent >= p.maxDatagramSize {
return time.Time{}
}
diff := 1e9 * uint64(p.maxDatagramSize-p.budgetAtLastSent)
bw := p.adjustedBandwidth()
// We might need to round up this value.
// Otherwise, we might have a budget (slightly) smaller than the datagram size when the timer expires.
d := diff / bw
// this is effectively a math.Ceil, but using only integer math
if diff%bw > 0 {
d++
}
return p.lastSentTime.Add(max(protocol.MinPacingDelay, time.Duration(d)*time.Nanosecond))
}
func (p *pacer) SetMaxDatagramSize(s protocol.ByteCount) {
p.maxDatagramSize = s
}
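// Illustrative sketch (not part of the original source): budget accrual in
// the token bucket above, assuming an adjusted bandwidth of 1 MB/s, a burst
// cap of 12,000 bytes, an empty budget and 5ms since the last send.
func examplePacerBudgetAccrual() uint64 {
	const bandwidthBytesPerSecond = 1_000_000
	const maxBurst = 12_000
	var budget uint64
	elapsed := 5 * time.Millisecond
	// Mirrors Budget: tokens accrue at the adjusted bandwidth and are capped at the burst size.
	budget += bandwidthBytesPerSecond * uint64(elapsed.Nanoseconds()) / 1e9
	if budget > maxBurst {
		budget = maxBurst
	}
	return budget // 5,000 bytes
}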
package flowcontrol
import (
"sync"
"time"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils"
)
type baseFlowController struct {
// for sending data
bytesSent protocol.ByteCount
sendWindow protocol.ByteCount
lastBlockedAt protocol.ByteCount
// for receiving data
//nolint:structcheck // The mutex is used both by the stream and the connection flow controller
mutex sync.Mutex
bytesRead protocol.ByteCount
highestReceived protocol.ByteCount
receiveWindow protocol.ByteCount
receiveWindowSize protocol.ByteCount
maxReceiveWindowSize protocol.ByteCount
allowWindowIncrease func(size protocol.ByteCount) bool
epochStartTime time.Time
epochStartOffset protocol.ByteCount
rttStats *utils.RTTStats
logger utils.Logger
}
// IsNewlyBlocked reports whether the sender is newly blocked by flow control.
// For every offset, it only returns true once.
// If it is blocked, the offset is returned.
func (c *baseFlowController) IsNewlyBlocked() (bool, protocol.ByteCount) {
if c.SendWindowSize() != 0 || c.sendWindow == c.lastBlockedAt {
return false, 0
}
c.lastBlockedAt = c.sendWindow
return true, c.sendWindow
}
func (c *baseFlowController) AddBytesSent(n protocol.ByteCount) {
c.bytesSent += n
}
// UpdateSendWindow is called after receiving a MAX_{STREAM_}DATA frame.
func (c *baseFlowController) UpdateSendWindow(offset protocol.ByteCount) (updated bool) {
if offset > c.sendWindow {
c.sendWindow = offset
return true
}
return false
}
func (c *baseFlowController) SendWindowSize() protocol.ByteCount {
// this only happens during connection establishment, when data is sent before we receive the peer's transport parameters
if c.bytesSent > c.sendWindow {
return 0
}
return c.sendWindow - c.bytesSent
}
// must be called while holding the mutex
func (c *baseFlowController) addBytesRead(n protocol.ByteCount) {
c.bytesRead += n
}
func (c *baseFlowController) hasWindowUpdate() bool {
bytesRemaining := c.receiveWindow - c.bytesRead
// update the window when more than the threshold was consumed
return bytesRemaining <= protocol.ByteCount(float64(c.receiveWindowSize)*(1-protocol.WindowUpdateThreshold))
}
// getWindowUpdate updates the receive window, if necessary.
// It returns the new receive window offset, or 0 if no update is needed.
func (c *baseFlowController) getWindowUpdate(now time.Time) protocol.ByteCount {
if !c.hasWindowUpdate() {
return 0
}
c.maybeAdjustWindowSize(now)
c.receiveWindow = c.bytesRead + c.receiveWindowSize
return c.receiveWindow
}
// maybeAdjustWindowSize increases the receiveWindowSize if we're sending updates too often.
// For details about auto-tuning, see https://docs.google.com/document/d/1SExkMmGiz8VYzV3s9E35JQlJ73vhzCekKkDi85F1qCE/edit?usp=sharing.
func (c *baseFlowController) maybeAdjustWindowSize(now time.Time) {
bytesReadInEpoch := c.bytesRead - c.epochStartOffset
// don't do anything if less than half the window has been consumed
if bytesReadInEpoch <= c.receiveWindowSize/2 {
return
}
rtt := c.rttStats.SmoothedRTT()
if rtt == 0 {
return
}
fraction := float64(bytesReadInEpoch) / float64(c.receiveWindowSize)
if now.Sub(c.epochStartTime) < time.Duration(4*fraction*float64(rtt)) {
// window is consumed too fast, try to increase the window size
newSize := min(2*c.receiveWindowSize, c.maxReceiveWindowSize)
if newSize > c.receiveWindowSize && (c.allowWindowIncrease == nil || c.allowWindowIncrease(newSize-c.receiveWindowSize)) {
c.receiveWindowSize = newSize
}
}
c.startNewAutoTuningEpoch(now)
}
func (c *baseFlowController) startNewAutoTuningEpoch(now time.Time) {
c.epochStartTime = now
c.epochStartOffset = c.bytesRead
}
func (c *baseFlowController) checkFlowControlViolation() bool {
return c.highestReceived > c.receiveWindow
}
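// Illustrative sketch (not part of the original source): the auto-tuning
// decision made in maybeAdjustWindowSize above, with hypothetical values
// (64 KiB window, 48 KiB read in the epoch, 50ms RTT, 60ms epoch duration).
func exampleWindowAutoTuning() uint64 {
	const rtt = 50 * time.Millisecond
	const maxWindowSize = uint64(1 << 20) // assumed 1 MiB cap
	windowSize := uint64(64 << 10)
	bytesReadInEpoch := uint64(48 << 10)
	epochDuration := 60 * time.Millisecond
	fraction := float64(bytesReadInEpoch) / float64(windowSize) // 0.75
	// If the window is consumed faster than once every 4 RTTs, double it.
	if epochDuration < time.Duration(4*fraction*float64(rtt)) { // 60ms < 150ms
		windowSize = min(2*windowSize, maxWindowSize)
	}
	return windowSize // 128 KiB
}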
package flowcontrol
import (
"errors"
"fmt"
"time"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qerr"
"github.com/quic-go/quic-go/internal/utils"
)
type connectionFlowController struct {
baseFlowController
}
var _ ConnectionFlowController = &connectionFlowController{}
// NewConnectionFlowController returns a new flow controller for the connection.
// It is created before we receive the peer's transport parameters, thus it starts with a sendWindow of 0.
func NewConnectionFlowController(
receiveWindow protocol.ByteCount,
maxReceiveWindow protocol.ByteCount,
allowWindowIncrease func(size protocol.ByteCount) bool,
rttStats *utils.RTTStats,
logger utils.Logger,
) *connectionFlowController {
return &connectionFlowController{
baseFlowController: baseFlowController{
rttStats: rttStats,
receiveWindow: receiveWindow,
receiveWindowSize: receiveWindow,
maxReceiveWindowSize: maxReceiveWindow,
allowWindowIncrease: allowWindowIncrease,
logger: logger,
},
}
}
// IncrementHighestReceived adds an increment to the highestReceived value
func (c *connectionFlowController) IncrementHighestReceived(increment protocol.ByteCount, now time.Time) error {
c.mutex.Lock()
defer c.mutex.Unlock()
// If this is the first frame received on this connection, start flow-control auto-tuning.
if c.highestReceived == 0 {
c.startNewAutoTuningEpoch(now)
}
c.highestReceived += increment
if c.checkFlowControlViolation() {
return &qerr.TransportError{
ErrorCode: qerr.FlowControlError,
ErrorMessage: fmt.Sprintf("received %d bytes for the connection, allowed %d bytes", c.highestReceived, c.receiveWindow),
}
}
return nil
}
func (c *connectionFlowController) AddBytesRead(n protocol.ByteCount) (hasWindowUpdate bool) {
c.mutex.Lock()
defer c.mutex.Unlock()
c.addBytesRead(n)
return c.hasWindowUpdate()
}
func (c *connectionFlowController) GetWindowUpdate(now time.Time) protocol.ByteCount {
c.mutex.Lock()
defer c.mutex.Unlock()
oldWindowSize := c.receiveWindowSize
offset := c.getWindowUpdate(now)
if c.logger.Debug() && oldWindowSize < c.receiveWindowSize {
c.logger.Debugf("Increasing receive flow control window for the connection to %d kB", c.receiveWindowSize/(1<<10))
}
return offset
}
// EnsureMinimumWindowSize sets a minimum window size.
// It makes sure that the connection-level window is increased when a stream-level window grows.
func (c *connectionFlowController) EnsureMinimumWindowSize(inc protocol.ByteCount, now time.Time) {
c.mutex.Lock()
defer c.mutex.Unlock()
if inc <= c.receiveWindowSize {
return
}
newSize := min(inc, c.maxReceiveWindowSize)
if delta := newSize - c.receiveWindowSize; delta > 0 && c.allowWindowIncrease(delta) {
c.receiveWindowSize = newSize
if c.logger.Debug() {
c.logger.Debugf("Increasing receive flow control window for the connection to %d, in response to stream flow control window increase", newSize)
}
}
c.startNewAutoTuningEpoch(now)
}
// Reset resets the flow controller. This happens when 0-RTT is rejected.
// All stream data is invalidated; it's as if we had never opened a stream and never sent any data.
// At that point, we have only sent stream data, but we didn't have the 1-RTT keys yet.
func (c *connectionFlowController) Reset() error {
c.mutex.Lock()
defer c.mutex.Unlock()
if c.bytesRead > 0 || c.highestReceived > 0 || !c.epochStartTime.IsZero() {
return errors.New("flow controller reset after reading data")
}
c.bytesSent = 0
c.lastBlockedAt = 0
c.sendWindow = 0
return nil
}
package flowcontrol
import (
"fmt"
"time"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qerr"
"github.com/quic-go/quic-go/internal/utils"
)
type streamFlowController struct {
baseFlowController
streamID protocol.StreamID
connection connectionFlowControllerI
receivedFinalOffset bool
}
var _ StreamFlowController = &streamFlowController{}
// NewStreamFlowController gets a new flow controller for a stream
func NewStreamFlowController(
streamID protocol.StreamID,
cfc ConnectionFlowController,
receiveWindow protocol.ByteCount,
maxReceiveWindow protocol.ByteCount,
initialSendWindow protocol.ByteCount,
rttStats *utils.RTTStats,
logger utils.Logger,
) StreamFlowController {
return &streamFlowController{
streamID: streamID,
connection: cfc.(connectionFlowControllerI),
baseFlowController: baseFlowController{
rttStats: rttStats,
receiveWindow: receiveWindow,
receiveWindowSize: receiveWindow,
maxReceiveWindowSize: maxReceiveWindow,
sendWindow: initialSendWindow,
logger: logger,
},
}
}
// UpdateHighestReceived updates the highestReceived value, if the offset is higher.
func (c *streamFlowController) UpdateHighestReceived(offset protocol.ByteCount, final bool, now time.Time) error {
// If the final offset for this stream is already known, check for consistency.
if c.receivedFinalOffset {
// If we receive another final offset, check that it's the same.
if final && offset != c.highestReceived {
return &qerr.TransportError{
ErrorCode: qerr.FinalSizeError,
ErrorMessage: fmt.Sprintf("received inconsistent final offset for stream %d (old: %d, new: %d bytes)", c.streamID, c.highestReceived, offset),
}
}
// Check that the offset is below the final offset.
if offset > c.highestReceived {
return &qerr.TransportError{
ErrorCode: qerr.FinalSizeError,
ErrorMessage: fmt.Sprintf("received offset %d for stream %d, but final offset was already received at %d", offset, c.streamID, c.highestReceived),
}
}
}
if final {
c.receivedFinalOffset = true
}
if offset == c.highestReceived {
return nil
}
// A higher offset was received before. This can happen due to reordering.
if offset < c.highestReceived {
if final {
return &qerr.TransportError{
ErrorCode: qerr.FinalSizeError,
ErrorMessage: fmt.Sprintf("received final offset %d for stream %d, but already received offset %d before", offset, c.streamID, c.highestReceived),
}
}
return nil
}
// If this is the first frame received for this stream, start flow-control auto-tuning.
if c.highestReceived == 0 {
c.startNewAutoTuningEpoch(now)
}
increment := offset - c.highestReceived
c.highestReceived = offset
if c.checkFlowControlViolation() {
return &qerr.TransportError{
ErrorCode: qerr.FlowControlError,
ErrorMessage: fmt.Sprintf("received %d bytes on stream %d, allowed %d bytes", offset, c.streamID, c.receiveWindow),
}
}
return c.connection.IncrementHighestReceived(increment, now)
}
func (c *streamFlowController) AddBytesRead(n protocol.ByteCount) (hasStreamWindowUpdate, hasConnWindowUpdate bool) {
c.mutex.Lock()
c.addBytesRead(n)
hasStreamWindowUpdate = c.shouldQueueWindowUpdate()
c.mutex.Unlock()
hasConnWindowUpdate = c.connection.AddBytesRead(n)
return
}
func (c *streamFlowController) Abandon() {
c.mutex.Lock()
unread := c.highestReceived - c.bytesRead
c.bytesRead = c.highestReceived
c.mutex.Unlock()
if unread > 0 {
c.connection.AddBytesRead(unread)
}
}
func (c *streamFlowController) AddBytesSent(n protocol.ByteCount) {
c.baseFlowController.AddBytesSent(n)
c.connection.AddBytesSent(n)
}
func (c *streamFlowController) SendWindowSize() protocol.ByteCount {
return min(c.baseFlowController.SendWindowSize(), c.connection.SendWindowSize())
}
func (c *streamFlowController) IsNewlyBlocked() bool {
blocked, _ := c.baseFlowController.IsNewlyBlocked()
return blocked
}
func (c *streamFlowController) shouldQueueWindowUpdate() bool {
return !c.receivedFinalOffset && c.hasWindowUpdate()
}
func (c *streamFlowController) GetWindowUpdate(now time.Time) protocol.ByteCount {
// If we already received the final offset for this stream, the peer won't need any additional flow control credit.
if c.receivedFinalOffset {
return 0
}
c.mutex.Lock()
defer c.mutex.Unlock()
oldWindowSize := c.receiveWindowSize
offset := c.getWindowUpdate(now)
if c.receiveWindowSize > oldWindowSize { // auto-tuning enlarged the window size
c.logger.Debugf("Increasing receive flow control window for stream %d to %d", c.streamID, c.receiveWindowSize)
c.connection.EnsureMinimumWindowSize(protocol.ByteCount(float64(c.receiveWindowSize)*protocol.ConnectionFlowControlMultiplier), now)
}
return offset
}
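// Illustrative sketch (not part of the original source): the connection-level
// window bump performed by GetWindowUpdate above when auto-tuning enlarges a
// stream window. The multiplier value of 1.5 is an assumption for this example.
func exampleConnectionWindowScaling() float64 {
	const connectionFlowControlMultiplier = 1.5 // assumed value of protocol.ConnectionFlowControlMultiplier
	streamWindowSize := float64(128 << 10)      // stream window after auto-tuning
	// The connection window is raised to at least this multiple of the stream window.
	return streamWindowSize * connectionFlowControlMultiplier // 192 KiB
}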
package handshake
import (
"encoding/binary"
"github.com/quic-go/quic-go/internal/protocol"
)
func createAEAD(suite *cipherSuite, trafficSecret []byte, v protocol.Version) *xorNonceAEAD {
keyLabel := hkdfLabelKeyV1
ivLabel := hkdfLabelIVV1
if v == protocol.Version2 {
keyLabel = hkdfLabelKeyV2
ivLabel = hkdfLabelIVV2
}
key := hkdfExpandLabel(suite.Hash, trafficSecret, []byte{}, keyLabel, suite.KeyLen)
iv := hkdfExpandLabel(suite.Hash, trafficSecret, []byte{}, ivLabel, suite.IVLen())
return suite.AEAD(key, iv)
}
type longHeaderSealer struct {
aead *xorNonceAEAD
headerProtector headerProtector
nonceBuf [8]byte
}
var _ LongHeaderSealer = &longHeaderSealer{}
func newLongHeaderSealer(aead *xorNonceAEAD, headerProtector headerProtector) LongHeaderSealer {
if aead.NonceSize() != 8 {
panic("unexpected nonce size")
}
return &longHeaderSealer{
aead: aead,
headerProtector: headerProtector,
}
}
func (s *longHeaderSealer) Seal(dst, src []byte, pn protocol.PacketNumber, ad []byte) []byte {
binary.BigEndian.PutUint64(s.nonceBuf[:], uint64(pn))
return s.aead.Seal(dst, s.nonceBuf[:], src, ad)
}
func (s *longHeaderSealer) EncryptHeader(sample []byte, firstByte *byte, pnBytes []byte) {
s.headerProtector.EncryptHeader(sample, firstByte, pnBytes)
}
func (s *longHeaderSealer) Overhead() int {
return s.aead.Overhead()
}
type longHeaderOpener struct {
aead *xorNonceAEAD
headerProtector headerProtector
highestRcvdPN protocol.PacketNumber // highest packet number received (which could be successfully unprotected)
// use a single array to avoid allocations
nonceBuf [8]byte
}
var _ LongHeaderOpener = &longHeaderOpener{}
func newLongHeaderOpener(aead *xorNonceAEAD, headerProtector headerProtector) LongHeaderOpener {
if aead.NonceSize() != 8 {
panic("unexpected nonce size")
}
return &longHeaderOpener{
aead: aead,
headerProtector: headerProtector,
}
}
func (o *longHeaderOpener) DecodePacketNumber(wirePN protocol.PacketNumber, wirePNLen protocol.PacketNumberLen) protocol.PacketNumber {
return protocol.DecodePacketNumber(wirePNLen, o.highestRcvdPN, wirePN)
}
func (o *longHeaderOpener) Open(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) {
binary.BigEndian.PutUint64(o.nonceBuf[:], uint64(pn))
dec, err := o.aead.Open(dst, o.nonceBuf[:], src, ad)
if err == nil {
o.highestRcvdPN = max(o.highestRcvdPN, pn)
} else {
err = ErrDecryptionFailed
}
return dec, err
}
func (o *longHeaderOpener) DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte) {
o.headerProtector.DecryptHeader(sample, firstByte, pnBytes)
}
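// Illustrative sketch (not part of the original source): the packet number
// reconstruction behind DecodePacketNumber, following RFC 9000, Section 17.1
// and Appendix A.3. The input values are the example from that appendix.
func exampleDecodePacketNumber() uint64 {
	const pnLenBits = 16              // 2-byte packet number on the wire
	var largestPN uint64 = 0xa82f30ea // highest packet number decrypted so far
	var truncatedPN uint64 = 0x9b32   // value read from the wire
	expected := largestPN + 1
	pnWin := uint64(1) << pnLenBits
	pnHalfWin := pnWin / 2
	pnMask := pnWin - 1
	candidate := (expected &^ pnMask) | truncatedPN
	if expected > pnHalfWin && candidate <= expected-pnHalfWin && candidate < (1<<62)-pnWin {
		candidate += pnWin
	} else if candidate > expected+pnHalfWin && candidate >= pnWin {
		candidate -= pnWin
	}
	return candidate // 0xa82f9b32
}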
package handshake
import (
"crypto"
"crypto/aes"
"crypto/cipher"
"crypto/tls"
"fmt"
"golang.org/x/crypto/chacha20poly1305"
)
// These cipher suite implementations are copied from the standard library crypto/tls package.
const aeadNonceLength = 12
type cipherSuite struct {
ID uint16
Hash crypto.Hash
KeyLen int
AEAD func(key, nonceMask []byte) *xorNonceAEAD
}
func (s cipherSuite) IVLen() int { return aeadNonceLength }
func getCipherSuite(id uint16) *cipherSuite {
switch id {
case tls.TLS_AES_128_GCM_SHA256:
return &cipherSuite{ID: tls.TLS_AES_128_GCM_SHA256, Hash: crypto.SHA256, KeyLen: 16, AEAD: aeadAESGCMTLS13}
case tls.TLS_CHACHA20_POLY1305_SHA256:
return &cipherSuite{ID: tls.TLS_CHACHA20_POLY1305_SHA256, Hash: crypto.SHA256, KeyLen: 32, AEAD: aeadChaCha20Poly1305}
case tls.TLS_AES_256_GCM_SHA384:
return &cipherSuite{ID: tls.TLS_AES_256_GCM_SHA384, Hash: crypto.SHA384, KeyLen: 32, AEAD: aeadAESGCMTLS13}
default:
panic(fmt.Sprintf("unknown cypher suite: %d", id))
}
}
func aeadAESGCMTLS13(key, nonceMask []byte) *xorNonceAEAD {
if len(nonceMask) != aeadNonceLength {
panic("tls: internal error: wrong nonce length")
}
aes, err := aes.NewCipher(key)
if err != nil {
panic(err)
}
aead, err := cipher.NewGCM(aes)
if err != nil {
panic(err)
}
ret := &xorNonceAEAD{aead: aead}
copy(ret.nonceMask[:], nonceMask)
return ret
}
func aeadChaCha20Poly1305(key, nonceMask []byte) *xorNonceAEAD {
if len(nonceMask) != aeadNonceLength {
panic("tls: internal error: wrong nonce length")
}
aead, err := chacha20poly1305.New(key)
if err != nil {
panic(err)
}
ret := &xorNonceAEAD{aead: aead}
copy(ret.nonceMask[:], nonceMask)
return ret
}
// xorNonceAEAD wraps an AEAD by XORing in a fixed pattern to the nonce
// before each call.
type xorNonceAEAD struct {
nonceMask [aeadNonceLength]byte
aead cipher.AEAD
}
func (f *xorNonceAEAD) NonceSize() int { return 8 } // 64-bit sequence number
func (f *xorNonceAEAD) Overhead() int { return f.aead.Overhead() }
func (f *xorNonceAEAD) explicitNonceLen() int { return 0 }
func (f *xorNonceAEAD) Seal(out, nonce, plaintext, additionalData []byte) []byte {
for i, b := range nonce {
f.nonceMask[4+i] ^= b
}
result := f.aead.Seal(out, f.nonceMask[:], plaintext, additionalData)
for i, b := range nonce {
f.nonceMask[4+i] ^= b
}
return result
}
func (f *xorNonceAEAD) Open(out, nonce, ciphertext, additionalData []byte) ([]byte, error) {
for i, b := range nonce {
f.nonceMask[4+i] ^= b
}
result, err := f.aead.Open(out, f.nonceMask[:], ciphertext, additionalData)
for i, b := range nonce {
f.nonceMask[4+i] ^= b
}
return result, err
}
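// Illustrative sketch (not part of the original source): the nonce
// construction performed by xorNonceAEAD above. The 8-byte big-endian packet
// number is XORed into the last 8 bytes of the 12-byte IV (RFC 9001, Section 5.3).
func exampleAEADNonce(iv [aeadNonceLength]byte, packetNumber uint64) [aeadNonceLength]byte {
	nonce := iv
	// XOR the big-endian packet number into the last 8 bytes of the IV.
	for i := 0; i < 8; i++ {
		nonce[4+i] ^= byte(packetNumber >> (8 * (7 - i)))
	}
	return nonce
}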
package handshake
import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"strings"
"sync/atomic"
"time"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qerr"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/internal/wire"
"github.com/quic-go/quic-go/logging"
"github.com/quic-go/quic-go/quicvarint"
)
type quicVersionContextKey struct{}
var QUICVersionContextKey = &quicVersionContextKey{}
const clientSessionStateRevision = 5
type cryptoSetup struct {
tlsConf *tls.Config
conn *tls.QUICConn
events []Event
version protocol.Version
ourParams *wire.TransportParameters
peerParams *wire.TransportParameters
zeroRTTParameters *wire.TransportParameters
allow0RTT bool
rttStats *utils.RTTStats
tracer *logging.ConnectionTracer
logger utils.Logger
perspective protocol.Perspective
handshakeCompleteTime time.Time
zeroRTTOpener LongHeaderOpener // only set for the server
zeroRTTSealer LongHeaderSealer // only set for the client
initialOpener LongHeaderOpener
initialSealer LongHeaderSealer
handshakeOpener LongHeaderOpener
handshakeSealer LongHeaderSealer
used0RTT atomic.Bool
aead *updatableAEAD
has1RTTSealer bool
has1RTTOpener bool
}
var _ CryptoSetup = &cryptoSetup{}
// NewCryptoSetupClient creates a new crypto setup for the client
func NewCryptoSetupClient(
connID protocol.ConnectionID,
tp *wire.TransportParameters,
tlsConf *tls.Config,
enable0RTT bool,
rttStats *utils.RTTStats,
tracer *logging.ConnectionTracer,
logger utils.Logger,
version protocol.Version,
) CryptoSetup {
cs := newCryptoSetup(
connID,
tp,
rttStats,
tracer,
logger,
protocol.PerspectiveClient,
version,
)
tlsConf = tlsConf.Clone()
tlsConf.MinVersion = tls.VersionTLS13
cs.tlsConf = tlsConf
cs.allow0RTT = enable0RTT
cs.conn = tls.QUICClient(&tls.QUICConfig{
TLSConfig: tlsConf,
EnableSessionEvents: true,
})
cs.conn.SetTransportParameters(cs.ourParams.Marshal(protocol.PerspectiveClient))
return cs
}
// NewCryptoSetupServer creates a new crypto setup for the server
func NewCryptoSetupServer(
connID protocol.ConnectionID,
localAddr, remoteAddr net.Addr,
tp *wire.TransportParameters,
tlsConf *tls.Config,
allow0RTT bool,
rttStats *utils.RTTStats,
tracer *logging.ConnectionTracer,
logger utils.Logger,
version protocol.Version,
) CryptoSetup {
cs := newCryptoSetup(
connID,
tp,
rttStats,
tracer,
logger,
protocol.PerspectiveServer,
version,
)
cs.allow0RTT = allow0RTT
tlsConf = setupConfigForServer(tlsConf, localAddr, remoteAddr)
cs.tlsConf = tlsConf
cs.conn = tls.QUICServer(&tls.QUICConfig{
TLSConfig: tlsConf,
EnableSessionEvents: true,
})
return cs
}
func newCryptoSetup(
connID protocol.ConnectionID,
tp *wire.TransportParameters,
rttStats *utils.RTTStats,
tracer *logging.ConnectionTracer,
logger utils.Logger,
perspective protocol.Perspective,
version protocol.Version,
) *cryptoSetup {
initialSealer, initialOpener := NewInitialAEAD(connID, perspective, version)
if tracer != nil && tracer.UpdatedKeyFromTLS != nil {
tracer.UpdatedKeyFromTLS(protocol.EncryptionInitial, protocol.PerspectiveClient)
tracer.UpdatedKeyFromTLS(protocol.EncryptionInitial, protocol.PerspectiveServer)
}
return &cryptoSetup{
initialSealer: initialSealer,
initialOpener: initialOpener,
aead: newUpdatableAEAD(rttStats, tracer, logger, version),
events: make([]Event, 0, 16),
ourParams: tp,
rttStats: rttStats,
tracer: tracer,
logger: logger,
perspective: perspective,
version: version,
}
}
func (h *cryptoSetup) ChangeConnectionID(id protocol.ConnectionID) {
initialSealer, initialOpener := NewInitialAEAD(id, h.perspective, h.version)
h.initialSealer = initialSealer
h.initialOpener = initialOpener
if h.tracer != nil && h.tracer.UpdatedKeyFromTLS != nil {
h.tracer.UpdatedKeyFromTLS(protocol.EncryptionInitial, protocol.PerspectiveClient)
h.tracer.UpdatedKeyFromTLS(protocol.EncryptionInitial, protocol.PerspectiveServer)
}
}
func (h *cryptoSetup) SetLargest1RTTAcked(pn protocol.PacketNumber) error {
return h.aead.SetLargestAcked(pn)
}
func (h *cryptoSetup) StartHandshake(ctx context.Context) error {
err := h.conn.Start(context.WithValue(ctx, QUICVersionContextKey, h.version))
if err != nil {
return wrapError(err)
}
for {
ev := h.conn.NextEvent()
if err := h.handleEvent(ev); err != nil {
return wrapError(err)
}
if ev.Kind == tls.QUICNoEvent {
break
}
}
if h.perspective == protocol.PerspectiveClient {
if h.zeroRTTSealer != nil && h.zeroRTTParameters != nil {
h.logger.Debugf("Doing 0-RTT.")
h.events = append(h.events, Event{Kind: EventRestoredTransportParameters, TransportParameters: h.zeroRTTParameters})
} else {
h.logger.Debugf("Not doing 0-RTT. Has sealer: %t, has params: %t", h.zeroRTTSealer != nil, h.zeroRTTParameters != nil)
}
}
return nil
}
// Close closes the crypto setup.
// It aborts the handshake, if it is still running.
func (h *cryptoSetup) Close() error {
return h.conn.Close()
}
// HandleMessage handles a TLS handshake message.
// It is called by the crypto streams when a new message is available.
func (h *cryptoSetup) HandleMessage(data []byte, encLevel protocol.EncryptionLevel) error {
if err := h.handleMessage(data, encLevel); err != nil {
return wrapError(err)
}
return nil
}
func (h *cryptoSetup) handleMessage(data []byte, encLevel protocol.EncryptionLevel) error {
if err := h.conn.HandleData(encLevel.ToTLSEncryptionLevel(), data); err != nil {
return err
}
for {
ev := h.conn.NextEvent()
if err := h.handleEvent(ev); err != nil {
return err
}
if ev.Kind == tls.QUICNoEvent {
return nil
}
}
}
func (h *cryptoSetup) handleEvent(ev tls.QUICEvent) (err error) {
switch ev.Kind {
case tls.QUICNoEvent:
return nil
case tls.QUICSetReadSecret:
h.setReadKey(ev.Level, ev.Suite, ev.Data)
return nil
case tls.QUICSetWriteSecret:
h.setWriteKey(ev.Level, ev.Suite, ev.Data)
return nil
case tls.QUICTransportParameters:
return h.handleTransportParameters(ev.Data)
case tls.QUICTransportParametersRequired:
h.conn.SetTransportParameters(h.ourParams.Marshal(h.perspective))
return nil
case tls.QUICRejectedEarlyData:
h.rejected0RTT()
return nil
case tls.QUICWriteData:
h.writeRecord(ev.Level, ev.Data)
return nil
case tls.QUICHandshakeDone:
h.handshakeComplete()
return nil
case tls.QUICStoreSession:
if h.perspective == protocol.PerspectiveServer {
panic("cryptoSetup BUG: unexpected QUICStoreSession event for the server")
}
ev.SessionState.Extra = append(
ev.SessionState.Extra,
addSessionStateExtraPrefix(h.marshalDataForSessionState(ev.SessionState.EarlyData)),
)
return h.conn.StoreSession(ev.SessionState)
case tls.QUICResumeSession:
var allowEarlyData bool
switch h.perspective {
case protocol.PerspectiveClient:
// for clients, this event occurs when a session ticket is selected
allowEarlyData = h.handleDataFromSessionState(
findSessionStateExtraData(ev.SessionState.Extra),
ev.SessionState.EarlyData,
)
case protocol.PerspectiveServer:
// for servers, this event occurs when receiving the client's session ticket
allowEarlyData = h.handleSessionTicket(
findSessionStateExtraData(ev.SessionState.Extra),
ev.SessionState.EarlyData,
)
}
if ev.SessionState.EarlyData {
ev.SessionState.EarlyData = allowEarlyData
}
return nil
default:
// Unknown events should be ignored.
// crypto/tls will ensure that this is safe to do.
// See the discussion following https://github.com/golang/go/issues/68124#issuecomment-2187042510 for details.
return nil
}
}
func (h *cryptoSetup) NextEvent() Event {
if len(h.events) == 0 {
return Event{Kind: EventNoEvent}
}
ev := h.events[0]
h.events = h.events[1:]
return ev
}
func (h *cryptoSetup) handleTransportParameters(data []byte) error {
var tp wire.TransportParameters
if err := tp.Unmarshal(data, h.perspective.Opposite()); err != nil {
return err
}
h.peerParams = &tp
h.events = append(h.events, Event{Kind: EventReceivedTransportParameters, TransportParameters: h.peerParams})
return nil
}
// must be called after receiving the transport parameters
func (h *cryptoSetup) marshalDataForSessionState(earlyData bool) []byte {
b := make([]byte, 0, 256)
b = quicvarint.Append(b, clientSessionStateRevision)
if earlyData {
// only save the transport parameters for 0-RTT enabled session tickets
return h.peerParams.MarshalForSessionTicket(b)
}
return b
}
func (h *cryptoSetup) handleDataFromSessionState(data []byte, earlyData bool) (allowEarlyData bool) {
tp, err := decodeDataFromSessionState(data, earlyData)
if err != nil {
h.logger.Debugf("Restoring of transport parameters from session ticket failed: %s", err.Error())
return
}
// The session ticket might have been saved from a connection that allowed 0-RTT,
// and therefore contain transport parameters.
// Only use them if 0-RTT is actually used on the new connection.
if tp != nil && h.allow0RTT {
h.zeroRTTParameters = tp
return true
}
return false
}
func decodeDataFromSessionState(b []byte, earlyData bool) (*wire.TransportParameters, error) {
ver, l, err := quicvarint.Parse(b)
if err != nil {
return nil, err
}
b = b[l:]
if ver != clientSessionStateRevision {
return nil, fmt.Errorf("mismatching version. Got %d, expected %d", ver, clientSessionStateRevision)
}
if !earlyData {
return nil, nil
}
var tp wire.TransportParameters
if err := tp.UnmarshalFromSessionTicket(b); err != nil {
return nil, err
}
return &tp, nil
}
func (h *cryptoSetup) getDataForSessionTicket() []byte {
return (&sessionTicket{
Parameters: h.ourParams,
}).Marshal()
}
// GetSessionTicket generates a new session ticket.
// Due to limitations in crypto/tls, it's only possible to generate a single session ticket per connection.
// It is only valid for the server.
func (h *cryptoSetup) GetSessionTicket() ([]byte, error) {
if err := h.conn.SendSessionTicket(tls.QUICSessionTicketOptions{
EarlyData: h.allow0RTT,
Extra: [][]byte{addSessionStateExtraPrefix(h.getDataForSessionTicket())},
}); err != nil {
// Session tickets might be disabled by tls.Config.SessionTicketsDisabled.
// We can't check h.tlsConfig here, since the actual config might have been obtained from
// the GetConfigForClient callback.
// See https://github.com/golang/go/issues/62032.
// Once that issue is resolved, this error assertion can be removed.
if strings.Contains(err.Error(), "session ticket keys unavailable") {
return nil, nil
}
return nil, err
}
ev := h.conn.NextEvent()
if ev.Kind != tls.QUICWriteData || ev.Level != tls.QUICEncryptionLevelApplication {
panic("crypto/tls bug: where's my session ticket?")
}
ticket := ev.Data
if ev := h.conn.NextEvent(); ev.Kind != tls.QUICNoEvent {
panic("crypto/tls bug: why more than one ticket?")
}
return ticket, nil
}
// handleSessionTicket is called for the server when receiving the client's session ticket.
// It reads parameters from the session ticket and checks whether to accept 0-RTT if the session ticket enabled 0-RTT.
// Note that the fact that the session ticket allows 0-RTT doesn't mean that the actual TLS handshake enables 0-RTT:
// A client may use a 0-RTT enabled session to resume a TLS session without using 0-RTT.
func (h *cryptoSetup) handleSessionTicket(data []byte, using0RTT bool) (allowEarlyData bool) {
var t sessionTicket
if err := t.Unmarshal(data); err != nil {
h.logger.Debugf("Unmarshalling session ticket failed: %s", err.Error())
return false
}
if !using0RTT {
return false
}
valid := h.ourParams.ValidFor0RTT(t.Parameters)
if !valid {
h.logger.Debugf("Transport parameters changed. Rejecting 0-RTT.")
return false
}
if !h.allow0RTT {
h.logger.Debugf("0-RTT not allowed. Rejecting 0-RTT.")
return false
}
return true
}
// rejected0RTT is called for the client when the server rejects 0-RTT.
func (h *cryptoSetup) rejected0RTT() {
h.logger.Debugf("0-RTT was rejected. Dropping 0-RTT keys.")
had0RTTKeys := h.zeroRTTSealer != nil
h.zeroRTTSealer = nil
if had0RTTKeys {
h.events = append(h.events, Event{Kind: EventDiscard0RTTKeys})
}
}
func (h *cryptoSetup) setReadKey(el tls.QUICEncryptionLevel, suiteID uint16, trafficSecret []byte) {
suite := getCipherSuite(suiteID)
//nolint:exhaustive // The TLS stack doesn't export Initial keys.
switch el {
case tls.QUICEncryptionLevelEarly:
if h.perspective == protocol.PerspectiveClient {
panic("Received 0-RTT read key for the client")
}
h.zeroRTTOpener = newLongHeaderOpener(
createAEAD(suite, trafficSecret, h.version),
newHeaderProtector(suite, trafficSecret, true, h.version),
)
h.used0RTT.Store(true)
if h.logger.Debug() {
h.logger.Debugf("Installed 0-RTT Read keys (using %s)", tls.CipherSuiteName(suite.ID))
}
case tls.QUICEncryptionLevelHandshake:
h.handshakeOpener = newLongHeaderOpener(
createAEAD(suite, trafficSecret, h.version),
newHeaderProtector(suite, trafficSecret, true, h.version),
)
if h.logger.Debug() {
h.logger.Debugf("Installed Handshake Read keys (using %s)", tls.CipherSuiteName(suite.ID))
}
case tls.QUICEncryptionLevelApplication:
h.aead.SetReadKey(suite, trafficSecret)
h.has1RTTOpener = true
if h.logger.Debug() {
h.logger.Debugf("Installed 1-RTT Read keys (using %s)", tls.CipherSuiteName(suite.ID))
}
default:
panic("unexpected read encryption level")
}
h.events = append(h.events, Event{Kind: EventReceivedReadKeys})
if h.tracer != nil && h.tracer.UpdatedKeyFromTLS != nil {
h.tracer.UpdatedKeyFromTLS(protocol.FromTLSEncryptionLevel(el), h.perspective.Opposite())
}
}
func (h *cryptoSetup) setWriteKey(el tls.QUICEncryptionLevel, suiteID uint16, trafficSecret []byte) {
suite := getCipherSuite(suiteID)
//nolint:exhaustive // The TLS stack doesn't export Initial keys.
switch el {
case tls.QUICEncryptionLevelEarly:
if h.perspective == protocol.PerspectiveServer {
panic("Received 0-RTT write key for the server")
}
h.zeroRTTSealer = newLongHeaderSealer(
createAEAD(suite, trafficSecret, h.version),
newHeaderProtector(suite, trafficSecret, true, h.version),
)
if h.logger.Debug() {
h.logger.Debugf("Installed 0-RTT Write keys (using %s)", tls.CipherSuiteName(suite.ID))
}
if h.tracer != nil && h.tracer.UpdatedKeyFromTLS != nil {
h.tracer.UpdatedKeyFromTLS(protocol.Encryption0RTT, h.perspective)
}
// don't set used0RTT here. 0-RTT might still get rejected.
return
case tls.QUICEncryptionLevelHandshake:
h.handshakeSealer = newLongHeaderSealer(
createAEAD(suite, trafficSecret, h.version),
newHeaderProtector(suite, trafficSecret, true, h.version),
)
if h.logger.Debug() {
h.logger.Debugf("Installed Handshake Write keys (using %s)", tls.CipherSuiteName(suite.ID))
}
case tls.QUICEncryptionLevelApplication:
h.aead.SetWriteKey(suite, trafficSecret)
h.has1RTTSealer = true
if h.logger.Debug() {
h.logger.Debugf("Installed 1-RTT Write keys (using %s)", tls.CipherSuiteName(suite.ID))
}
if h.zeroRTTSealer != nil {
// Once we receive handshake keys, we know that 0-RTT was not rejected.
h.used0RTT.Store(true)
h.zeroRTTSealer = nil
h.logger.Debugf("Dropping 0-RTT keys.")
if h.tracer != nil && h.tracer.DroppedEncryptionLevel != nil {
h.tracer.DroppedEncryptionLevel(protocol.Encryption0RTT)
}
}
default:
panic("unexpected write encryption level")
}
if h.tracer != nil && h.tracer.UpdatedKeyFromTLS != nil {
h.tracer.UpdatedKeyFromTLS(protocol.FromTLSEncryptionLevel(el), h.perspective)
}
}
// writeRecord is called when TLS writes data
func (h *cryptoSetup) writeRecord(encLevel tls.QUICEncryptionLevel, p []byte) {
//nolint:exhaustive // handshake records can only be written for Initial and Handshake.
switch encLevel {
case tls.QUICEncryptionLevelInitial:
h.events = append(h.events, Event{Kind: EventWriteInitialData, Data: p})
case tls.QUICEncryptionLevelHandshake:
h.events = append(h.events, Event{Kind: EventWriteHandshakeData, Data: p})
case tls.QUICEncryptionLevelApplication:
panic("unexpected write")
default:
panic(fmt.Sprintf("unexpected write encryption level: %s", encLevel))
}
}
func (h *cryptoSetup) DiscardInitialKeys() {
dropped := h.initialOpener != nil
h.initialOpener = nil
h.initialSealer = nil
if dropped {
h.logger.Debugf("Dropping Initial keys.")
}
}
func (h *cryptoSetup) handshakeComplete() {
h.handshakeCompleteTime = time.Now()
h.events = append(h.events, Event{Kind: EventHandshakeComplete})
}
func (h *cryptoSetup) SetHandshakeConfirmed() {
h.aead.SetHandshakeConfirmed()
// drop Handshake keys
var dropped bool
if h.handshakeOpener != nil {
h.handshakeOpener = nil
h.handshakeSealer = nil
dropped = true
}
if dropped {
h.logger.Debugf("Dropping Handshake keys.")
}
}
func (h *cryptoSetup) GetInitialSealer() (LongHeaderSealer, error) {
if h.initialSealer == nil {
return nil, ErrKeysDropped
}
return h.initialSealer, nil
}
func (h *cryptoSetup) Get0RTTSealer() (LongHeaderSealer, error) {
if h.zeroRTTSealer == nil {
return nil, ErrKeysDropped
}
return h.zeroRTTSealer, nil
}
func (h *cryptoSetup) GetHandshakeSealer() (LongHeaderSealer, error) {
if h.handshakeSealer == nil {
if h.initialSealer == nil {
return nil, ErrKeysDropped
}
return nil, ErrKeysNotYetAvailable
}
return h.handshakeSealer, nil
}
func (h *cryptoSetup) Get1RTTSealer() (ShortHeaderSealer, error) {
if !h.has1RTTSealer {
return nil, ErrKeysNotYetAvailable
}
return h.aead, nil
}
func (h *cryptoSetup) GetInitialOpener() (LongHeaderOpener, error) {
if h.initialOpener == nil {
return nil, ErrKeysDropped
}
return h.initialOpener, nil
}
func (h *cryptoSetup) Get0RTTOpener() (LongHeaderOpener, error) {
if h.zeroRTTOpener == nil {
if h.initialOpener != nil {
return nil, ErrKeysNotYetAvailable
}
// if the initial opener is also not available, the keys were already dropped
return nil, ErrKeysDropped
}
return h.zeroRTTOpener, nil
}
func (h *cryptoSetup) GetHandshakeOpener() (LongHeaderOpener, error) {
if h.handshakeOpener == nil {
if h.initialOpener != nil {
return nil, ErrKeysNotYetAvailable
}
// if the initial opener is also not available, the keys were already dropped
return nil, ErrKeysDropped
}
return h.handshakeOpener, nil
}
func (h *cryptoSetup) Get1RTTOpener() (ShortHeaderOpener, error) {
if h.zeroRTTOpener != nil && time.Since(h.handshakeCompleteTime) > 3*h.rttStats.PTO(true) {
h.zeroRTTOpener = nil
h.logger.Debugf("Dropping 0-RTT keys.")
if h.tracer != nil && h.tracer.DroppedEncryptionLevel != nil {
h.tracer.DroppedEncryptionLevel(protocol.Encryption0RTT)
}
}
if !h.has1RTTOpener {
return nil, ErrKeysNotYetAvailable
}
return h.aead, nil
}
func (h *cryptoSetup) ConnectionState() ConnectionState {
return ConnectionState{
ConnectionState: h.conn.ConnectionState(),
Used0RTT: h.used0RTT.Load(),
}
}
func wrapError(err error) error {
if alertErr := tls.AlertError(0); errors.As(err, &alertErr) {
return qerr.NewLocalCryptoError(uint8(alertErr), err)
}
return &qerr.TransportError{ErrorCode: qerr.InternalError, ErrorMessage: err.Error()}
}
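// Illustrative sketch (not part of the original source): the framing used by
// marshalDataForSessionState / decodeDataFromSessionState above. The session
// state carries a varint-encoded revision, followed by the peer's transport
// parameters for 0-RTT-enabled tickets. Varint values below 64 occupy a
// single byte, so revision 5 encodes as 0x05.
func exampleSessionStateFraming(earlyData bool, marshaledTransportParameters []byte) []byte {
	b := []byte{clientSessionStateRevision} // single-byte varint, since 5 < 64
	if earlyData {
		// Transport parameters are only stored for 0-RTT-enabled tickets.
		b = append(b, marshaledTransportParameters...)
	}
	return b
}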
package handshake
import (
"net"
"time"
)
type conn struct {
localAddr, remoteAddr net.Addr
}
var _ net.Conn = &conn{}
func (c *conn) Read([]byte) (int, error) { return 0, nil }
func (c *conn) Write([]byte) (int, error) { return 0, nil }
func (c *conn) Close() error { return nil }
func (c *conn) RemoteAddr() net.Addr { return c.remoteAddr }
func (c *conn) LocalAddr() net.Addr { return c.localAddr }
func (c *conn) SetReadDeadline(time.Time) error { return nil }
func (c *conn) SetWriteDeadline(time.Time) error { return nil }
func (c *conn) SetDeadline(time.Time) error { return nil }
package handshake
import (
"crypto/aes"
"crypto/cipher"
"crypto/tls"
"encoding/binary"
"fmt"
"golang.org/x/crypto/chacha20"
"github.com/quic-go/quic-go/internal/protocol"
)
type headerProtector interface {
EncryptHeader(sample []byte, firstByte *byte, hdrBytes []byte)
DecryptHeader(sample []byte, firstByte *byte, hdrBytes []byte)
}
func hkdfHeaderProtectionLabel(v protocol.Version) string {
if v == protocol.Version2 {
return "quicv2 hp"
}
return "quic hp"
}
func newHeaderProtector(suite *cipherSuite, trafficSecret []byte, isLongHeader bool, v protocol.Version) headerProtector {
hkdfLabel := hkdfHeaderProtectionLabel(v)
switch suite.ID {
case tls.TLS_AES_128_GCM_SHA256, tls.TLS_AES_256_GCM_SHA384:
return newAESHeaderProtector(suite, trafficSecret, isLongHeader, hkdfLabel)
case tls.TLS_CHACHA20_POLY1305_SHA256:
return newChaChaHeaderProtector(suite, trafficSecret, isLongHeader, hkdfLabel)
default:
panic(fmt.Sprintf("Invalid cipher suite id: %d", suite.ID))
}
}
type aesHeaderProtector struct {
mask [16]byte // AES always has a 16 byte block size
block cipher.Block
isLongHeader bool
}
var _ headerProtector = &aesHeaderProtector{}
func newAESHeaderProtector(suite *cipherSuite, trafficSecret []byte, isLongHeader bool, hkdfLabel string) headerProtector {
hpKey := hkdfExpandLabel(suite.Hash, trafficSecret, []byte{}, hkdfLabel, suite.KeyLen)
block, err := aes.NewCipher(hpKey)
if err != nil {
panic(fmt.Sprintf("error creating new AES cipher: %s", err))
}
return &aesHeaderProtector{
block: block,
isLongHeader: isLongHeader,
}
}
func (p *aesHeaderProtector) DecryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) {
p.apply(sample, firstByte, hdrBytes)
}
func (p *aesHeaderProtector) EncryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) {
p.apply(sample, firstByte, hdrBytes)
}
func (p *aesHeaderProtector) apply(sample []byte, firstByte *byte, hdrBytes []byte) {
if len(sample) != len(p.mask) {
panic("invalid sample size")
}
p.block.Encrypt(p.mask[:], sample)
if p.isLongHeader {
*firstByte ^= p.mask[0] & 0xf
} else {
*firstByte ^= p.mask[0] & 0x1f
}
for i := range hdrBytes {
hdrBytes[i] ^= p.mask[i+1]
}
}
type chachaHeaderProtector struct {
mask [5]byte
key [32]byte
isLongHeader bool
}
var _ headerProtector = &chachaHeaderProtector{}
func newChaChaHeaderProtector(suite *cipherSuite, trafficSecret []byte, isLongHeader bool, hkdfLabel string) headerProtector {
hpKey := hkdfExpandLabel(suite.Hash, trafficSecret, []byte{}, hkdfLabel, suite.KeyLen)
p := &chachaHeaderProtector{
isLongHeader: isLongHeader,
}
copy(p.key[:], hpKey)
return p
}
func (p *chachaHeaderProtector) DecryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) {
p.apply(sample, firstByte, hdrBytes)
}
func (p *chachaHeaderProtector) EncryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) {
p.apply(sample, firstByte, hdrBytes)
}
func (p *chachaHeaderProtector) apply(sample []byte, firstByte *byte, hdrBytes []byte) {
if len(sample) != 16 {
panic("invalid sample size")
}
for i := 0; i < 5; i++ {
p.mask[i] = 0
}
cipher, err := chacha20.NewUnauthenticatedCipher(p.key[:], sample[4:])
if err != nil {
panic(err)
}
cipher.SetCounter(binary.LittleEndian.Uint32(sample[:4]))
cipher.XORKeyStream(p.mask[:], p.mask[:])
p.applyMask(firstByte, hdrBytes)
}
func (p *chachaHeaderProtector) applyMask(firstByte *byte, hdrBytes []byte) {
if p.isLongHeader {
*firstByte ^= p.mask[0] & 0xf
} else {
*firstByte ^= p.mask[0] & 0x1f
}
for i := range hdrBytes {
hdrBytes[i] ^= p.mask[i+1]
}
}
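// Illustrative sketch (not part of the original source): where the header
// protection sample comes from. Per RFC 9001, Section 5.4.2, the 16-byte
// sample starts 4 bytes after the beginning of the Packet Number field,
// regardless of the actual packet number length. pnOffset is a hypothetical
// input for this example.
func exampleHeaderProtectionSample(packet []byte, pnOffset int) []byte {
	return packet[pnOffset+4 : pnOffset+4+16]
}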
package handshake
import (
"crypto"
"encoding/binary"
"golang.org/x/crypto/hkdf"
)
// hkdfExpandLabel implements HKDF-Expand-Label, as defined in RFC 8446, Section 7.1.
func hkdfExpandLabel(hash crypto.Hash, secret, context []byte, label string, length int) []byte {
b := make([]byte, 3, 3+6+len(label)+1+len(context))
binary.BigEndian.PutUint16(b, uint16(length))
b[2] = uint8(6 + len(label))
b = append(b, []byte("tls13 ")...)
b = append(b, []byte(label)...)
b = b[:3+6+len(label)+1]
b[3+6+len(label)] = uint8(len(context))
b = append(b, context...)
out := make([]byte, length)
n, err := hkdf.Expand(hash.New, secret, b).Read(out)
if err != nil || n != length {
panic("quic: HKDF-Expand-Label invocation failed unexpectedly")
}
return out
}
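// Illustrative sketch (not part of the original source): the HkdfLabel bytes
// built by hkdfExpandLabel above for the label "quic key", an output length
// of 16 bytes and an empty context (RFC 8446, Section 7.1).
func exampleHKDFLabelEncoding() []byte {
	b := []byte{0x00, 0x10, 0x0e}              // uint16 length = 16, label length = len("tls13 quic key") = 14
	b = append(b, []byte("tls13 quic key")...) // label with the "tls13 " prefix
	return append(b, 0x00)                     // context length = 0
}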
package handshake
import (
"crypto"
"crypto/tls"
"golang.org/x/crypto/hkdf"
"github.com/quic-go/quic-go/internal/protocol"
)
var (
quicSaltV1 = []byte{0x38, 0x76, 0x2c, 0xf7, 0xf5, 0x59, 0x34, 0xb3, 0x4d, 0x17, 0x9a, 0xe6, 0xa4, 0xc8, 0x0c, 0xad, 0xcc, 0xbb, 0x7f, 0x0a}
quicSaltV2 = []byte{0x0d, 0xed, 0xe3, 0xde, 0xf7, 0x00, 0xa6, 0xdb, 0x81, 0x93, 0x81, 0xbe, 0x6e, 0x26, 0x9d, 0xcb, 0xf9, 0xbd, 0x2e, 0xd9}
)
const (
hkdfLabelKeyV1 = "quic key"
hkdfLabelKeyV2 = "quicv2 key"
hkdfLabelIVV1 = "quic iv"
hkdfLabelIVV2 = "quicv2 iv"
)
func getSalt(v protocol.Version) []byte {
if v == protocol.Version2 {
return quicSaltV2
}
return quicSaltV1
}
var initialSuite = getCipherSuite(tls.TLS_AES_128_GCM_SHA256)
// NewInitialAEAD creates a new AEAD for Initial encryption / decryption.
func NewInitialAEAD(connID protocol.ConnectionID, pers protocol.Perspective, v protocol.Version) (LongHeaderSealer, LongHeaderOpener) {
clientSecret, serverSecret := computeSecrets(connID, v)
var mySecret, otherSecret []byte
if pers == protocol.PerspectiveClient {
mySecret = clientSecret
otherSecret = serverSecret
} else {
mySecret = serverSecret
otherSecret = clientSecret
}
myKey, myIV := computeInitialKeyAndIV(mySecret, v)
otherKey, otherIV := computeInitialKeyAndIV(otherSecret, v)
encrypter := initialSuite.AEAD(myKey, myIV)
decrypter := initialSuite.AEAD(otherKey, otherIV)
return newLongHeaderSealer(encrypter, newHeaderProtector(initialSuite, mySecret, true, v)),
newLongHeaderOpener(decrypter, newAESHeaderProtector(initialSuite, otherSecret, true, hkdfHeaderProtectionLabel(v)))
}
func computeSecrets(connID protocol.ConnectionID, v protocol.Version) (clientSecret, serverSecret []byte) {
initialSecret := hkdf.Extract(crypto.SHA256.New, connID.Bytes(), getSalt(v))
clientSecret = hkdfExpandLabel(crypto.SHA256, initialSecret, []byte{}, "client in", crypto.SHA256.Size())
serverSecret = hkdfExpandLabel(crypto.SHA256, initialSecret, []byte{}, "server in", crypto.SHA256.Size())
return
}
func computeInitialKeyAndIV(secret []byte, v protocol.Version) (key, iv []byte) {
keyLabel := hkdfLabelKeyV1
ivLabel := hkdfLabelIVV1
if v == protocol.Version2 {
keyLabel = hkdfLabelKeyV2
ivLabel = hkdfLabelIVV2
}
key = hkdfExpandLabel(crypto.SHA256, secret, []byte{}, keyLabel, 16)
iv = hkdfExpandLabel(crypto.SHA256, secret, []byte{}, ivLabel, 12)
return
}
package handshake
import (
"context"
"crypto/tls"
"errors"
"io"
"time"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/wire"
)
var (
// ErrKeysNotYetAvailable is returned when an opener or a sealer is requested for an encryption level,
// but the corresponding keys have not yet been derived.
// This can happen when packets arrive out of order.
ErrKeysNotYetAvailable = errors.New("CryptoSetup: keys at this encryption level not yet available")
// ErrKeysDropped is returned when an opener or a sealer is requested for an encryption level,
// but the corresponding keys have already been dropped.
ErrKeysDropped = errors.New("CryptoSetup: keys were already dropped")
// ErrDecryptionFailed is returned when the AEAD fails to open the packet.
ErrDecryptionFailed = errors.New("decryption failed")
)
type headerDecryptor interface {
DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte)
}
// LongHeaderOpener opens a long header packet
type LongHeaderOpener interface {
headerDecryptor
DecodePacketNumber(wirePN protocol.PacketNumber, wirePNLen protocol.PacketNumberLen) protocol.PacketNumber
Open(dst, src []byte, pn protocol.PacketNumber, associatedData []byte) ([]byte, error)
}
// ShortHeaderOpener opens a short header packet
type ShortHeaderOpener interface {
headerDecryptor
DecodePacketNumber(wirePN protocol.PacketNumber, wirePNLen protocol.PacketNumberLen) protocol.PacketNumber
Open(dst, src []byte, rcvTime time.Time, pn protocol.PacketNumber, kp protocol.KeyPhaseBit, associatedData []byte) ([]byte, error)
}
// LongHeaderSealer seals a long header packet
type LongHeaderSealer interface {
Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte
EncryptHeader(sample []byte, firstByte *byte, pnBytes []byte)
Overhead() int
}
// ShortHeaderSealer seals a short header packet
type ShortHeaderSealer interface {
LongHeaderSealer
KeyPhase() protocol.KeyPhaseBit
}
type ConnectionState struct {
tls.ConnectionState
Used0RTT bool
}
// EventKind is the kind of handshake event.
type EventKind uint8
const (
// EventNoEvent signals that there are no new handshake events
EventNoEvent EventKind = iota + 1
// EventWriteInitialData contains new CRYPTO data to send at the Initial encryption level
EventWriteInitialData
// EventWriteHandshakeData contains new CRYPTO data to send at the Handshake encryption level
EventWriteHandshakeData
// EventReceivedReadKeys signals that new decryption keys are available.
// It doesn't say which encryption level those keys are for.
EventReceivedReadKeys
// EventDiscard0RTTKeys signals that the 0-RTT keys were discarded.
EventDiscard0RTTKeys
// EventReceivedTransportParameters contains the transport parameters sent by the peer.
EventReceivedTransportParameters
// EventRestoredTransportParameters contains the transport parameters restored from the session ticket.
// It is only used for the client.
EventRestoredTransportParameters
// EventHandshakeComplete signals that the TLS handshake was completed.
EventHandshakeComplete
)
func (k EventKind) String() string {
switch k {
case EventNoEvent:
return "EventNoEvent"
case EventWriteInitialData:
return "EventWriteInitialData"
case EventWriteHandshakeData:
return "EventWriteHandshakeData"
case EventReceivedReadKeys:
return "EventReceivedReadKeys"
case EventDiscard0RTTKeys:
return "EventDiscard0RTTKeys"
case EventReceivedTransportParameters:
return "EventReceivedTransportParameters"
case EventRestoredTransportParameters:
return "EventRestoredTransportParameters"
case EventHandshakeComplete:
return "EventHandshakeComplete"
default:
return "Unknown EventKind"
}
}
// Event is a handshake event.
type Event struct {
Kind EventKind
Data []byte
TransportParameters *wire.TransportParameters
}
// CryptoSetup handles the handshake and the protection / unprotection of packets
type CryptoSetup interface {
StartHandshake(context.Context) error
io.Closer
ChangeConnectionID(protocol.ConnectionID)
GetSessionTicket() ([]byte, error)
HandleMessage([]byte, protocol.EncryptionLevel) error
NextEvent() Event
SetLargest1RTTAcked(protocol.PacketNumber) error
DiscardInitialKeys()
SetHandshakeConfirmed()
ConnectionState() ConnectionState
GetInitialOpener() (LongHeaderOpener, error)
GetHandshakeOpener() (LongHeaderOpener, error)
Get0RTTOpener() (LongHeaderOpener, error)
Get1RTTOpener() (ShortHeaderOpener, error)
GetInitialSealer() (LongHeaderSealer, error)
GetHandshakeSealer() (LongHeaderSealer, error)
Get0RTTSealer() (LongHeaderSealer, error)
Get1RTTSealer() (ShortHeaderSealer, error)
}
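// Illustrative sketch (not part of this package): the intended event-driven use
// of a CryptoSetup. After handing it CRYPTO data, the caller drains NextEvent
// until EventNoEvent is returned. The variable names (cs, cryptoData) are
// hypothetical.
//
//	if err := cs.HandleMessage(cryptoData, protocol.EncryptionInitial); err != nil {
//		return err
//	}
//	for {
//		ev := cs.NextEvent()
//		if ev.Kind == EventNoEvent {
//			break
//		}
//		// e.g. queue ev.Data for sending, or act on newly available keys
//	}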
package handshake
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"fmt"
"sync"
"github.com/quic-go/quic-go/internal/protocol"
)
// Instead of using an init function, the AEADs are created lazily.
// For more details see https://github.com/quic-go/quic-go/issues/4894.
var (
retryAEADv1 cipher.AEAD // used for QUIC v1 (RFC 9000)
retryAEADv2 cipher.AEAD // used for QUIC v2 (RFC 9369)
)
func initAEAD(key [16]byte) cipher.AEAD {
aes, err := aes.NewCipher(key[:])
if err != nil {
panic(err)
}
aead, err := cipher.NewGCM(aes)
if err != nil {
panic(err)
}
return aead
}
var (
retryBuf bytes.Buffer
retryMutex sync.Mutex
retryNonceV1 = [12]byte{0x46, 0x15, 0x99, 0xd3, 0x5d, 0x63, 0x2b, 0xf2, 0x23, 0x98, 0x25, 0xbb}
retryNonceV2 = [12]byte{0xd8, 0x69, 0x69, 0xbc, 0x2d, 0x7c, 0x6d, 0x99, 0x90, 0xef, 0xb0, 0x4a}
)
// GetRetryIntegrityTag calculates the integrity tag on a Retry packet
func GetRetryIntegrityTag(retry []byte, origDestConnID protocol.ConnectionID, version protocol.Version) *[16]byte {
retryMutex.Lock()
defer retryMutex.Unlock()
retryBuf.WriteByte(uint8(origDestConnID.Len()))
retryBuf.Write(origDestConnID.Bytes())
retryBuf.Write(retry)
defer retryBuf.Reset()
var tag [16]byte
var sealed []byte
if version == protocol.Version2 {
if retryAEADv2 == nil {
retryAEADv2 = initAEAD([16]byte{0x8f, 0xb4, 0xb0, 0x1b, 0x56, 0xac, 0x48, 0xe2, 0x60, 0xfb, 0xcb, 0xce, 0xad, 0x7c, 0xcc, 0x92})
}
sealed = retryAEADv2.Seal(tag[:0], retryNonceV2[:], nil, retryBuf.Bytes())
} else {
if retryAEADv1 == nil {
retryAEADv1 = initAEAD([16]byte{0xbe, 0x0c, 0x69, 0x0b, 0x9f, 0x66, 0x57, 0x5a, 0x1d, 0x76, 0x6b, 0x54, 0xe3, 0x68, 0xc8, 0x4e})
}
sealed = retryAEADv1.Seal(tag[:0], retryNonceV1[:], nil, retryBuf.Bytes())
}
if len(sealed) != 16 {
panic(fmt.Sprintf("unexpected Retry integrity tag length: %d", len(sealed)))
}
return &tag
}
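// Illustrative sketch (not part of this package): computing the integrity tag
// for a serialized Retry packet and appending it, as required by RFC 9001,
// Section 5.8. The variable names are hypothetical.
//
//	tag := GetRetryIntegrityTag(retryWithoutTag, origDestConnID, protocol.Version1)
//	retryPacket := append(retryWithoutTag, tag[:]...)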
package handshake
import (
"bytes"
"errors"
"fmt"
"github.com/quic-go/quic-go/internal/wire"
"github.com/quic-go/quic-go/quicvarint"
)
const sessionTicketRevision = 5
type sessionTicket struct {
Parameters *wire.TransportParameters
}
func (t *sessionTicket) Marshal() []byte {
b := make([]byte, 0, 256)
b = quicvarint.Append(b, sessionTicketRevision)
return t.Parameters.MarshalForSessionTicket(b)
}
func (t *sessionTicket) Unmarshal(b []byte) error {
rev, l, err := quicvarint.Parse(b)
if err != nil {
return errors.New("failed to read session ticket revision")
}
b = b[l:]
if rev != sessionTicketRevision {
return fmt.Errorf("unknown session ticket revision: %d", rev)
}
var tp wire.TransportParameters
if err := tp.UnmarshalFromSessionTicket(b); err != nil {
return fmt.Errorf("unmarshaling transport parameters from session ticket failed: %s", err.Error())
}
t.Parameters = &tp
return nil
}
const extraPrefix = "quic-go1"
func addSessionStateExtraPrefix(b []byte) []byte {
return append([]byte(extraPrefix), b...)
}
func findSessionStateExtraData(extras [][]byte) []byte {
prefix := []byte(extraPrefix)
for _, extra := range extras {
if len(extra) < len(prefix) || !bytes.Equal(prefix, extra[:len(prefix)]) {
continue
}
return extra[len(prefix):]
}
return nil
}
package handshake
import (
"crypto/tls"
"net"
)
func setupConfigForServer(conf *tls.Config, localAddr, remoteAddr net.Addr) *tls.Config {
// Workaround for https://github.com/golang/go/issues/60506.
// This initializes the session tickets _before_ cloning the config.
_, _ = conf.DecryptTicket(nil, tls.ConnectionState{})
conf = conf.Clone()
conf.MinVersion = tls.VersionTLS13
// The tls.Config contains two callbacks that pass in a tls.ClientHelloInfo.
// Since crypto/tls doesn't set this field itself, we need to set the Conn field to a fake net.Conn
// that allows the caller to get the local and the remote address.
if conf.GetConfigForClient != nil {
gcfc := conf.GetConfigForClient
conf.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) {
info.Conn = &conn{localAddr: localAddr, remoteAddr: remoteAddr}
c, err := gcfc(info)
if c != nil {
// we're returning a tls.Config here, so we need to apply this recursively
c = setupConfigForServer(c, localAddr, remoteAddr)
}
return c, err
}
}
if conf.GetCertificate != nil {
gc := conf.GetCertificate
conf.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
info.Conn = &conn{localAddr: localAddr, remoteAddr: remoteAddr}
return gc(info)
}
}
return conf
}
package handshake
import (
"bytes"
"encoding/asn1"
"fmt"
"net"
"time"
"github.com/quic-go/quic-go/internal/protocol"
)
const (
tokenPrefixIP byte = iota
tokenPrefixString
)
// A Token is derived from the client address and can be used to verify the ownership of this address.
type Token struct {
IsRetryToken bool
SentTime time.Time
encodedRemoteAddr []byte
// only set for tokens sent in NEW_TOKEN frames
RTT time.Duration
// only set for retry tokens
OriginalDestConnectionID protocol.ConnectionID
RetrySrcConnectionID protocol.ConnectionID
}
// ValidateRemoteAddr validates the address, but does not check expiration
func (t *Token) ValidateRemoteAddr(addr net.Addr) bool {
return bytes.Equal(encodeRemoteAddr(addr), t.encodedRemoteAddr)
}
// token is the struct that is used for ASN1 serialization and deserialization
type token struct {
IsRetryToken bool
RemoteAddr []byte
Timestamp int64
RTT int64 // in microseconds
OriginalDestConnectionID []byte
RetrySrcConnectionID []byte
}
// A TokenGenerator generates tokens
type TokenGenerator struct {
tokenProtector tokenProtector
}
// NewTokenGenerator initializes a new TokenGenerator
func NewTokenGenerator(key TokenProtectorKey) *TokenGenerator {
return &TokenGenerator{tokenProtector: *newTokenProtector(key)}
}
// NewRetryToken generates a new token for a Retry for a given source address
func (g *TokenGenerator) NewRetryToken(
raddr net.Addr,
origDestConnID protocol.ConnectionID,
retrySrcConnID protocol.ConnectionID,
) ([]byte, error) {
data, err := asn1.Marshal(token{
IsRetryToken: true,
RemoteAddr: encodeRemoteAddr(raddr),
OriginalDestConnectionID: origDestConnID.Bytes(),
RetrySrcConnectionID: retrySrcConnID.Bytes(),
Timestamp: time.Now().UnixNano(),
})
if err != nil {
return nil, err
}
return g.tokenProtector.NewToken(data)
}
// NewToken generates a new token to be sent in a NEW_TOKEN frame
func (g *TokenGenerator) NewToken(raddr net.Addr, rtt time.Duration) ([]byte, error) {
data, err := asn1.Marshal(token{
RemoteAddr: encodeRemoteAddr(raddr),
Timestamp: time.Now().UnixNano(),
RTT: rtt.Microseconds(),
})
if err != nil {
return nil, err
}
return g.tokenProtector.NewToken(data)
}
// DecodeToken decodes a token
func (g *TokenGenerator) DecodeToken(encrypted []byte) (*Token, error) {
// if the client didn't send any token, DecodeToken will be called with a nil-slice
if len(encrypted) == 0 {
return nil, nil
}
data, err := g.tokenProtector.DecodeToken(encrypted)
if err != nil {
return nil, err
}
t := &token{}
rest, err := asn1.Unmarshal(data, t)
if err != nil {
return nil, err
}
if len(rest) != 0 {
return nil, fmt.Errorf("rest when unpacking token: %d", len(rest))
}
token := &Token{
IsRetryToken: t.IsRetryToken,
SentTime: time.Unix(0, t.Timestamp),
encodedRemoteAddr: t.RemoteAddr,
}
if t.IsRetryToken {
token.OriginalDestConnectionID = protocol.ParseConnectionID(t.OriginalDestConnectionID)
token.RetrySrcConnectionID = protocol.ParseConnectionID(t.RetrySrcConnectionID)
} else {
token.RTT = time.Duration(t.RTT) * time.Microsecond
}
return token, nil
}
// encodeRemoteAddr encodes a remote address such that it can be saved in the token
func encodeRemoteAddr(remoteAddr net.Addr) []byte {
if udpAddr, ok := remoteAddr.(*net.UDPAddr); ok {
return append([]byte{tokenPrefixIP}, udpAddr.IP...)
}
return append([]byte{tokenPrefixString}, []byte(remoteAddr.String())...)
}
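// Illustrative sketch (not part of this package): issuing a token for a
// NEW_TOKEN frame and validating it when the client presents it on a later
// connection. The variable names are hypothetical.
//
//	var key TokenProtectorKey // in practice filled with cryptographically random bytes
//	gen := NewTokenGenerator(key)
//	tok, _ := gen.NewToken(remoteAddr, measuredRTT)
//	// ... later, when the token comes back ...
//	decoded, err := gen.DecodeToken(tok)
//	if err == nil && decoded != nil && decoded.ValidateRemoteAddr(remoteAddr) {
//		// address is validated; decoded.RTT carries the previously measured RTT
//	}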
package handshake
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"crypto/sha256"
"fmt"
"io"
"golang.org/x/crypto/hkdf"
)
// TokenProtectorKey is the key used to encrypt both Retry and session resumption tokens.
type TokenProtectorKey [32]byte
const tokenNonceSize = 32
// tokenProtector is used to create and verify a token
type tokenProtector struct {
key TokenProtectorKey
}
// newTokenProtector creates a new tokenProtector for creating and verifying address validation tokens
func newTokenProtector(key TokenProtectorKey) *tokenProtector {
return &tokenProtector{key: key}
}
// NewToken encodes data into a new token.
func (s *tokenProtector) NewToken(data []byte) ([]byte, error) {
var nonce [tokenNonceSize]byte
if _, err := rand.Read(nonce[:]); err != nil {
return nil, err
}
aead, aeadNonce, err := s.createAEAD(nonce[:])
if err != nil {
return nil, err
}
return append(nonce[:], aead.Seal(nil, aeadNonce, data, nil)...), nil
}
// DecodeToken decodes a token.
func (s *tokenProtector) DecodeToken(p []byte) ([]byte, error) {
if len(p) < tokenNonceSize {
return nil, fmt.Errorf("token too short: %d", len(p))
}
nonce := p[:tokenNonceSize]
aead, aeadNonce, err := s.createAEAD(nonce)
if err != nil {
return nil, err
}
return aead.Open(nil, aeadNonce, p[tokenNonceSize:], nil)
}
func (s *tokenProtector) createAEAD(nonce []byte) (cipher.AEAD, []byte, error) {
h := hkdf.New(sha256.New, s.key[:], nonce, []byte("quic-go token source"))
key := make([]byte, 32) // use a 32 byte key, in order to select AES-256
if _, err := io.ReadFull(h, key); err != nil {
return nil, nil, err
}
aeadNonce := make([]byte, 12)
if _, err := io.ReadFull(h, aeadNonce); err != nil {
return nil, nil, err
}
c, err := aes.NewCipher(key)
if err != nil {
return nil, nil, err
}
aead, err := cipher.NewGCM(c)
if err != nil {
return nil, nil, err
}
return aead, aeadNonce, nil
}
package handshake
import (
"crypto"
"crypto/cipher"
"crypto/tls"
"encoding/binary"
"fmt"
"sync/atomic"
"time"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qerr"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/logging"
)
var keyUpdateInterval atomic.Uint64
func init() {
keyUpdateInterval.Store(protocol.KeyUpdateInterval)
}
func SetKeyUpdateInterval(v uint64) (reset func()) {
old := keyUpdateInterval.Swap(v)
return func() { keyUpdateInterval.Store(old) }
}
// FirstKeyUpdateInterval is the maximum number of packets we send or receive before initiating the first key update.
// It's a package-level variable to allow modifying it for testing purposes.
var FirstKeyUpdateInterval uint64 = 100
type updatableAEAD struct {
suite *cipherSuite
keyPhase protocol.KeyPhase
largestAcked protocol.PacketNumber
firstPacketNumber protocol.PacketNumber
handshakeConfirmed bool
invalidPacketLimit uint64
invalidPacketCount uint64
// Time when the keys should be dropped. Keys are dropped on the next call to Open().
prevRcvAEADExpiry time.Time
prevRcvAEAD cipher.AEAD
firstRcvdWithCurrentKey protocol.PacketNumber
firstSentWithCurrentKey protocol.PacketNumber
highestRcvdPN protocol.PacketNumber // highest packet number received (which could be successfully unprotected)
numRcvdWithCurrentKey uint64
numSentWithCurrentKey uint64
rcvAEAD cipher.AEAD
sendAEAD cipher.AEAD
// caches cipher.AEAD.Overhead(). This speeds up calls to Overhead().
aeadOverhead int
nextRcvAEAD cipher.AEAD
nextSendAEAD cipher.AEAD
nextRcvTrafficSecret []byte
nextSendTrafficSecret []byte
headerDecrypter headerProtector
headerEncrypter headerProtector
rttStats *utils.RTTStats
tracer *logging.ConnectionTracer
logger utils.Logger
version protocol.Version
// use a single slice to avoid allocations
nonceBuf []byte
}
var (
_ ShortHeaderOpener = &updatableAEAD{}
_ ShortHeaderSealer = &updatableAEAD{}
)
func newUpdatableAEAD(rttStats *utils.RTTStats, tracer *logging.ConnectionTracer, logger utils.Logger, version protocol.Version) *updatableAEAD {
return &updatableAEAD{
firstPacketNumber: protocol.InvalidPacketNumber,
largestAcked: protocol.InvalidPacketNumber,
firstRcvdWithCurrentKey: protocol.InvalidPacketNumber,
firstSentWithCurrentKey: protocol.InvalidPacketNumber,
rttStats: rttStats,
tracer: tracer,
logger: logger,
version: version,
}
}
func (a *updatableAEAD) rollKeys() {
if a.prevRcvAEAD != nil {
a.logger.Debugf("Dropping key phase %d ahead of scheduled time. Drop time was: %s", a.keyPhase-1, a.prevRcvAEADExpiry)
if a.tracer != nil && a.tracer.DroppedKey != nil {
a.tracer.DroppedKey(a.keyPhase - 1)
}
a.prevRcvAEADExpiry = time.Time{}
}
a.keyPhase++
a.firstRcvdWithCurrentKey = protocol.InvalidPacketNumber
a.firstSentWithCurrentKey = protocol.InvalidPacketNumber
a.numRcvdWithCurrentKey = 0
a.numSentWithCurrentKey = 0
a.prevRcvAEAD = a.rcvAEAD
a.rcvAEAD = a.nextRcvAEAD
a.sendAEAD = a.nextSendAEAD
a.nextRcvTrafficSecret = a.getNextTrafficSecret(a.suite.Hash, a.nextRcvTrafficSecret)
a.nextSendTrafficSecret = a.getNextTrafficSecret(a.suite.Hash, a.nextSendTrafficSecret)
a.nextRcvAEAD = createAEAD(a.suite, a.nextRcvTrafficSecret, a.version)
a.nextSendAEAD = createAEAD(a.suite, a.nextSendTrafficSecret, a.version)
}
func (a *updatableAEAD) startKeyDropTimer(now time.Time) {
d := 3 * a.rttStats.PTO(true)
a.logger.Debugf("Starting key drop timer to drop key phase %d (in %s)", a.keyPhase-1, d)
a.prevRcvAEADExpiry = now.Add(d)
}
func (a *updatableAEAD) getNextTrafficSecret(hash crypto.Hash, ts []byte) []byte {
return hkdfExpandLabel(hash, ts, []byte{}, "quic ku", hash.Size())
}
// SetReadKey sets the read key.
// For the client, this function is called before SetWriteKey.
// For the server, this function is called after SetWriteKey.
func (a *updatableAEAD) SetReadKey(suite *cipherSuite, trafficSecret []byte) {
a.rcvAEAD = createAEAD(suite, trafficSecret, a.version)
a.headerDecrypter = newHeaderProtector(suite, trafficSecret, false, a.version)
if a.suite == nil {
a.setAEADParameters(a.rcvAEAD, suite)
}
a.nextRcvTrafficSecret = a.getNextTrafficSecret(suite.Hash, trafficSecret)
a.nextRcvAEAD = createAEAD(suite, a.nextRcvTrafficSecret, a.version)
}
// SetWriteKey sets the write key.
// For the client, this function is called after SetReadKey.
// For the server, this function is called before SetReadKey.
func (a *updatableAEAD) SetWriteKey(suite *cipherSuite, trafficSecret []byte) {
a.sendAEAD = createAEAD(suite, trafficSecret, a.version)
a.headerEncrypter = newHeaderProtector(suite, trafficSecret, false, a.version)
if a.suite == nil {
a.setAEADParameters(a.sendAEAD, suite)
}
a.nextSendTrafficSecret = a.getNextTrafficSecret(suite.Hash, trafficSecret)
a.nextSendAEAD = createAEAD(suite, a.nextSendTrafficSecret, a.version)
}
func (a *updatableAEAD) setAEADParameters(aead cipher.AEAD, suite *cipherSuite) {
a.nonceBuf = make([]byte, aead.NonceSize())
a.aeadOverhead = aead.Overhead()
a.suite = suite
switch suite.ID {
case tls.TLS_AES_128_GCM_SHA256, tls.TLS_AES_256_GCM_SHA384:
a.invalidPacketLimit = protocol.InvalidPacketLimitAES
case tls.TLS_CHACHA20_POLY1305_SHA256:
a.invalidPacketLimit = protocol.InvalidPacketLimitChaCha
default:
panic(fmt.Sprintf("unknown cipher suite %d", suite.ID))
}
}
func (a *updatableAEAD) DecodePacketNumber(wirePN protocol.PacketNumber, wirePNLen protocol.PacketNumberLen) protocol.PacketNumber {
return protocol.DecodePacketNumber(wirePNLen, a.highestRcvdPN, wirePN)
}
func (a *updatableAEAD) Open(dst, src []byte, rcvTime time.Time, pn protocol.PacketNumber, kp protocol.KeyPhaseBit, ad []byte) ([]byte, error) {
dec, err := a.open(dst, src, rcvTime, pn, kp, ad)
if err == ErrDecryptionFailed {
a.invalidPacketCount++
if a.invalidPacketCount >= a.invalidPacketLimit {
return nil, &qerr.TransportError{ErrorCode: qerr.AEADLimitReached}
}
}
if err == nil {
a.highestRcvdPN = max(a.highestRcvdPN, pn)
}
return dec, err
}
func (a *updatableAEAD) open(dst, src []byte, rcvTime time.Time, pn protocol.PacketNumber, kp protocol.KeyPhaseBit, ad []byte) ([]byte, error) {
if a.prevRcvAEAD != nil && !a.prevRcvAEADExpiry.IsZero() && rcvTime.After(a.prevRcvAEADExpiry) {
a.prevRcvAEAD = nil
a.logger.Debugf("Dropping key phase %d", a.keyPhase-1)
a.prevRcvAEADExpiry = time.Time{}
if a.tracer != nil && a.tracer.DroppedKey != nil {
a.tracer.DroppedKey(a.keyPhase - 1)
}
}
binary.BigEndian.PutUint64(a.nonceBuf[len(a.nonceBuf)-8:], uint64(pn))
if kp != a.keyPhase.Bit() {
if a.keyPhase > 0 && a.firstRcvdWithCurrentKey == protocol.InvalidPacketNumber || pn < a.firstRcvdWithCurrentKey {
if a.prevRcvAEAD == nil {
return nil, ErrKeysDropped
}
// we updated the key, but the peer hasn't updated yet
dec, err := a.prevRcvAEAD.Open(dst, a.nonceBuf, src, ad)
if err != nil {
err = ErrDecryptionFailed
}
return dec, err
}
// try opening the packet with the next key phase
dec, err := a.nextRcvAEAD.Open(dst, a.nonceBuf, src, ad)
if err != nil {
return nil, ErrDecryptionFailed
}
// Opening succeeded. Check if the peer was allowed to update.
if a.keyPhase > 0 && a.firstSentWithCurrentKey == protocol.InvalidPacketNumber {
return nil, &qerr.TransportError{
ErrorCode: qerr.KeyUpdateError,
ErrorMessage: "keys updated too quickly",
}
}
a.rollKeys()
a.logger.Debugf("Peer updated keys to %d", a.keyPhase)
// The peer initiated this key update. It's safe to drop the keys for the previous generation now.
// Start a timer to drop the previous key generation.
a.startKeyDropTimer(rcvTime)
if a.tracer != nil && a.tracer.UpdatedKey != nil {
a.tracer.UpdatedKey(a.keyPhase, true)
}
a.firstRcvdWithCurrentKey = pn
return dec, err
}
// The AEAD we're using here will be the qtls.aeadAESGCM13.
// It XORs the nonce provided here with the IV.
dec, err := a.rcvAEAD.Open(dst, a.nonceBuf, src, ad)
if err != nil {
return dec, ErrDecryptionFailed
}
a.numRcvdWithCurrentKey++
if a.firstRcvdWithCurrentKey == protocol.InvalidPacketNumber {
// We initiated the key update, and now we received the first packet protected with the new key phase.
// Therefore, we are certain that the peer rolled its keys as well. Start a timer to drop the old keys.
if a.keyPhase > 0 {
a.logger.Debugf("Peer confirmed key update to phase %d", a.keyPhase)
a.startKeyDropTimer(rcvTime)
}
a.firstRcvdWithCurrentKey = pn
}
return dec, err
}
func (a *updatableAEAD) Seal(dst, src []byte, pn protocol.PacketNumber, ad []byte) []byte {
if a.firstSentWithCurrentKey == protocol.InvalidPacketNumber {
a.firstSentWithCurrentKey = pn
}
if a.firstPacketNumber == protocol.InvalidPacketNumber {
a.firstPacketNumber = pn
}
a.numSentWithCurrentKey++
binary.BigEndian.PutUint64(a.nonceBuf[len(a.nonceBuf)-8:], uint64(pn))
// The AEAD we're using here will be the qtls.aeadAESGCM13.
// It XORs the nonce provided here with the IV.
return a.sendAEAD.Seal(dst, a.nonceBuf, src, ad)
}
func (a *updatableAEAD) SetLargestAcked(pn protocol.PacketNumber) error {
if a.firstSentWithCurrentKey != protocol.InvalidPacketNumber &&
pn >= a.firstSentWithCurrentKey && a.numRcvdWithCurrentKey == 0 {
return &qerr.TransportError{
ErrorCode: qerr.KeyUpdateError,
ErrorMessage: fmt.Sprintf("received ACK for key phase %d, but peer didn't update keys", a.keyPhase),
}
}
a.largestAcked = pn
return nil
}
func (a *updatableAEAD) SetHandshakeConfirmed() {
a.handshakeConfirmed = true
}
func (a *updatableAEAD) updateAllowed() bool {
if !a.handshakeConfirmed {
return false
}
// the first key update is allowed as soon as the handshake is confirmed
return a.keyPhase == 0 ||
// subsequent key updates as soon as a packet sent with that key phase has been acknowledged
(a.firstSentWithCurrentKey != protocol.InvalidPacketNumber &&
a.largestAcked != protocol.InvalidPacketNumber &&
a.largestAcked >= a.firstSentWithCurrentKey)
}
func (a *updatableAEAD) shouldInitiateKeyUpdate() bool {
if !a.updateAllowed() {
return false
}
// Initiate the first key update shortly after the handshake, in order to exercise the key update mechanism.
if a.keyPhase == 0 {
if a.numRcvdWithCurrentKey >= FirstKeyUpdateInterval || a.numSentWithCurrentKey >= FirstKeyUpdateInterval {
return true
}
}
if a.numRcvdWithCurrentKey >= keyUpdateInterval.Load() {
a.logger.Debugf("Received %d packets with current key phase. Initiating key update to the next key phase: %d", a.numRcvdWithCurrentKey, a.keyPhase+1)
return true
}
if a.numSentWithCurrentKey >= keyUpdateInterval.Load() {
a.logger.Debugf("Sent %d packets with current key phase. Initiating key update to the next key phase: %d", a.numSentWithCurrentKey, a.keyPhase+1)
return true
}
return false
}
func (a *updatableAEAD) KeyPhase() protocol.KeyPhaseBit {
if a.shouldInitiateKeyUpdate() {
a.rollKeys()
a.logger.Debugf("Initiating key update to key phase %d", a.keyPhase)
if a.tracer != nil && a.tracer.UpdatedKey != nil {
a.tracer.UpdatedKey(a.keyPhase, false)
}
}
return a.keyPhase.Bit()
}
func (a *updatableAEAD) Overhead() int {
return a.aeadOverhead
}
func (a *updatableAEAD) EncryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) {
a.headerEncrypter.EncryptHeader(sample, firstByte, hdrBytes)
}
func (a *updatableAEAD) DecryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) {
a.headerDecrypter.DecryptHeader(sample, firstByte, hdrBytes)
}
func (a *updatableAEAD) FirstPacketNumber() protocol.PacketNumber {
return a.firstPacketNumber
}
package protocol
import (
"crypto/rand"
"errors"
"fmt"
"io"
)
var ErrInvalidConnectionIDLen = errors.New("invalid Connection ID length")
// An ArbitraryLenConnectionID is a QUIC Connection ID able to represent Connection IDs according to RFC 8999.
// Future QUIC versions might allow connection ID lengths up to 255 bytes, while QUIC v1
// restricts the length to 20 bytes.
type ArbitraryLenConnectionID []byte
func (c ArbitraryLenConnectionID) Len() int {
return len(c)
}
func (c ArbitraryLenConnectionID) Bytes() []byte {
return c
}
func (c ArbitraryLenConnectionID) String() string {
if c.Len() == 0 {
return "(empty)"
}
return fmt.Sprintf("%x", c.Bytes())
}
const maxConnectionIDLen = 20
// A ConnectionID in QUIC
type ConnectionID struct {
b [20]byte
l uint8
}
// GenerateConnectionID generates a connection ID using a cryptographically secure random source
func GenerateConnectionID(l int) (ConnectionID, error) {
var c ConnectionID
c.l = uint8(l)
_, err := rand.Read(c.b[:l])
return c, err
}
// ParseConnectionID interprets b as a Connection ID.
// It panics if b is longer than 20 bytes.
func ParseConnectionID(b []byte) ConnectionID {
if len(b) > maxConnectionIDLen {
panic("invalid conn id length")
}
var c ConnectionID
c.l = uint8(len(b))
copy(c.b[:c.l], b)
return c
}
// GenerateConnectionIDForInitial generates a connection ID for the Initial packet.
// It uses a length randomly chosen between 8 and 20 bytes.
func GenerateConnectionIDForInitial() (ConnectionID, error) {
r := make([]byte, 1)
if _, err := rand.Read(r); err != nil {
return ConnectionID{}, err
}
l := MinConnectionIDLenInitial + int(r[0])%(maxConnectionIDLen-MinConnectionIDLenInitial+1)
return GenerateConnectionID(l)
}
// ReadConnectionID reads a connection ID of length l from the given io.Reader.
// It returns io.EOF if there are not enough bytes to read.
func ReadConnectionID(r io.Reader, l int) (ConnectionID, error) {
var c ConnectionID
if l == 0 {
return c, nil
}
if l > maxConnectionIDLen {
return c, ErrInvalidConnectionIDLen
}
c.l = uint8(l)
_, err := io.ReadFull(r, c.b[:l])
if err == io.ErrUnexpectedEOF {
return c, io.EOF
}
return c, err
}
// Len returns the length of the connection ID in bytes
func (c ConnectionID) Len() int {
return int(c.l)
}
// Bytes returns the byte representation
func (c ConnectionID) Bytes() []byte {
return c.b[:c.l]
}
func (c ConnectionID) String() string {
if c.Len() == 0 {
return "(empty)"
}
return fmt.Sprintf("%x", c.Bytes())
}
type DefaultConnectionIDGenerator struct {
ConnLen int
}
func (d *DefaultConnectionIDGenerator) GenerateConnectionID() (ConnectionID, error) {
return GenerateConnectionID(d.ConnLen)
}
func (d *DefaultConnectionIDGenerator) ConnectionIDLen() int {
return d.ConnLen
}
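// Illustrative sketch (not part of this package): generating a random
// connection ID and parsing it back from its byte representation.
//
//	cid, _ := GenerateConnectionID(8)        // 8 random bytes
//	parsed := ParseConnectionID(cid.Bytes()) // an equal ConnectionID
//	_ = parsed.Len()                         // 8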
package protocol
import (
"crypto/tls"
"fmt"
)
// EncryptionLevel is the encryption level.
// The zero value is not a valid encryption level.
type EncryptionLevel uint8
const (
// EncryptionInitial is the Initial encryption level
EncryptionInitial EncryptionLevel = 1 + iota
// EncryptionHandshake is the Handshake encryption level
EncryptionHandshake
// Encryption0RTT is the 0-RTT encryption level
Encryption0RTT
// Encryption1RTT is the 1-RTT encryption level
Encryption1RTT
)
func (e EncryptionLevel) String() string {
switch e {
case EncryptionInitial:
return "Initial"
case EncryptionHandshake:
return "Handshake"
case Encryption0RTT:
return "0-RTT"
case Encryption1RTT:
return "1-RTT"
}
return "unknown"
}
func (e EncryptionLevel) ToTLSEncryptionLevel() tls.QUICEncryptionLevel {
switch e {
case EncryptionInitial:
return tls.QUICEncryptionLevelInitial
case EncryptionHandshake:
return tls.QUICEncryptionLevelHandshake
case Encryption1RTT:
return tls.QUICEncryptionLevelApplication
case Encryption0RTT:
return tls.QUICEncryptionLevelEarly
default:
panic(fmt.Sprintf("unexpected encryption level: %s", e))
}
}
func FromTLSEncryptionLevel(e tls.QUICEncryptionLevel) EncryptionLevel {
switch e {
case tls.QUICEncryptionLevelInitial:
return EncryptionInitial
case tls.QUICEncryptionLevelHandshake:
return EncryptionHandshake
case tls.QUICEncryptionLevelApplication:
return Encryption1RTT
case tls.QUICEncryptionLevelEarly:
return Encryption0RTT
default:
panic(fmt.Sprintf("unexpect encryption level: %s", e))
}
}
package protocol
// KeyPhase is the key phase
type KeyPhase uint64
// Bit determines the key phase bit
func (p KeyPhase) Bit() KeyPhaseBit {
if p%2 == 0 {
return KeyPhaseZero
}
return KeyPhaseOne
}
// KeyPhaseBit is the key phase bit
type KeyPhaseBit uint8
const (
// KeyPhaseUndefined is an undefined key phase
KeyPhaseUndefined KeyPhaseBit = iota
// KeyPhaseZero is key phase 0
KeyPhaseZero
// KeyPhaseOne is key phase 1
KeyPhaseOne
)
func (p KeyPhaseBit) String() string {
//nolint:exhaustive
switch p {
case KeyPhaseZero:
return "0"
case KeyPhaseOne:
return "1"
default:
return "undefined"
}
}
package protocol
// A PacketNumber in QUIC
type PacketNumber int64
// InvalidPacketNumber is a packet number that is never sent.
// In QUIC, 0 is a valid packet number.
const InvalidPacketNumber PacketNumber = -1
// PacketNumberLen is the length of the packet number in bytes
type PacketNumberLen uint8
const (
// PacketNumberLen1 is a packet number length of 1 byte
PacketNumberLen1 PacketNumberLen = 1
// PacketNumberLen2 is a packet number length of 2 bytes
PacketNumberLen2 PacketNumberLen = 2
// PacketNumberLen3 is a packet number length of 3 bytes
PacketNumberLen3 PacketNumberLen = 3
// PacketNumberLen4 is a packet number length of 4 bytes
PacketNumberLen4 PacketNumberLen = 4
)
// DecodePacketNumber calculates the packet number based on its length and the last seen packet number.
// This function is taken from https://www.rfc-editor.org/rfc/rfc9000.html#section-a.3.
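// For example (the worked example from RFC 9000, Appendix A.3): if the largest
// successfully authenticated packet number is 0xa82f30ea, a 2-byte truncated
// packet number of 0x9b32 decodes to 0xa82f9b32:
//
//	pn := DecodePacketNumber(PacketNumberLen2, 0xa82f30ea, 0x9b32)
//	// pn == 0xa82f9b32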
func DecodePacketNumber(length PacketNumberLen, largest PacketNumber, truncated PacketNumber) PacketNumber {
expected := largest + 1
win := PacketNumber(1 << (length * 8))
hwin := win / 2
mask := win - 1
candidate := (expected & ^mask) | truncated
if candidate <= expected-hwin && candidate < 1<<62-win {
return candidate + win
}
if candidate > expected+hwin && candidate >= win {
return candidate - win
}
return candidate
}
// PacketNumberLengthForHeader gets the length of the packet number for the public header
// it never chooses a PacketNumberLen of 1 byte, since this is too short under certain circumstances
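// For example, if the next packet to be sent has packet number 1000 and the
// largest acknowledged packet number is 900, at most 100 packets are
// unacknowledged; since 100 < 2^15, two bytes are sufficient:
//
//	l := PacketNumberLengthForHeader(1000, 900) // PacketNumberLen2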
func PacketNumberLengthForHeader(pn, largestAcked PacketNumber) PacketNumberLen {
var numUnacked PacketNumber
if largestAcked == InvalidPacketNumber {
numUnacked = pn + 1
} else {
numUnacked = pn - largestAcked
}
if numUnacked < 1<<(16-1) {
return PacketNumberLen2
}
if numUnacked < 1<<(24-1) {
return PacketNumberLen3
}
return PacketNumberLen4
}
package protocol
// Perspective determines if we're acting as a server or a client
type Perspective int
// the perspectives
const (
PerspectiveServer Perspective = 1
PerspectiveClient Perspective = 2
)
// Opposite returns the perspective of the peer
func (p Perspective) Opposite() Perspective {
return 3 - p
}
func (p Perspective) String() string {
switch p {
case PerspectiveServer:
return "server"
case PerspectiveClient:
return "client"
default:
return "invalid perspective"
}
}
package protocol
import (
"fmt"
"time"
)
// The PacketType is the Long Header Type
type PacketType uint8
const (
// PacketTypeInitial is the packet type of an Initial packet
PacketTypeInitial PacketType = 1 + iota
// PacketTypeRetry is the packet type of a Retry packet
PacketTypeRetry
// PacketTypeHandshake is the packet type of a Handshake packet
PacketTypeHandshake
// PacketType0RTT is the packet type of a 0-RTT packet
PacketType0RTT
)
func (t PacketType) String() string {
switch t {
case PacketTypeInitial:
return "Initial"
case PacketTypeRetry:
return "Retry"
case PacketTypeHandshake:
return "Handshake"
case PacketType0RTT:
return "0-RTT Protected"
default:
return fmt.Sprintf("unknown packet type: %d", t)
}
}
type ECN uint8
const (
ECNUnsupported ECN = iota
ECNNon // 00
ECT1 // 01
ECT0 // 10
ECNCE // 11
)
func ParseECNHeaderBits(bits byte) ECN {
switch bits {
case 0:
return ECNNon
case 0b00000010:
return ECT0
case 0b00000001:
return ECT1
case 0b00000011:
return ECNCE
default:
panic("invalid ECN bits")
}
}
func (e ECN) ToHeaderBits() byte {
//nolint:exhaustive // There are only 4 values.
switch e {
case ECNNon:
return 0
case ECT0:
return 0b00000010
case ECT1:
return 0b00000001
case ECNCE:
return 0b00000011
default:
panic("ECN unsupported")
}
}
func (e ECN) String() string {
switch e {
case ECNUnsupported:
return "ECN unsupported"
case ECNNon:
return "Not-ECT"
case ECT1:
return "ECT(1)"
case ECT0:
return "ECT(0)"
case ECNCE:
return "CE"
default:
return fmt.Sprintf("invalid ECN value: %d", e)
}
}
// A ByteCount in QUIC
type ByteCount int64
// MaxByteCount is the maximum value of a ByteCount
const MaxByteCount = ByteCount(1<<62 - 1)
// InvalidByteCount is an invalid byte count
const InvalidByteCount ByteCount = -1
// A StatelessResetToken is a stateless reset token.
type StatelessResetToken [16]byte
// MaxPacketBufferSize maximum packet size of any QUIC packet, based on
// ethernet's max size, minus the IP and UDP headers. IPv6 has a 40 byte header,
// UDP adds an additional 8 bytes. This is a total overhead of 48 bytes.
// Ethernet's max packet size is 1500 bytes, 1500 - 48 = 1452.
const MaxPacketBufferSize = 1452
// MaxLargePacketBufferSize is used when using GSO
const MaxLargePacketBufferSize = 20 * 1024
// MinInitialPacketSize is the minimum size an Initial packet is required to have.
const MinInitialPacketSize = 1200
// MinUnknownVersionPacketSize is the minimum size a packet with an unknown version
// needs to have in order to trigger a Version Negotiation packet.
const MinUnknownVersionPacketSize = MinInitialPacketSize
// MinStatelessResetSize is the minimum size of a stateless reset packet that we send
const MinStatelessResetSize = 1 /* first byte */ + 20 /* max. conn ID length */ + 4 /* max. packet number length */ + 1 /* min. payload length */ + 16 /* token */
// MinReceivedStatelessResetSize is the minimum size of a received stateless reset,
// as specified in section 10.3 of RFC 9000.
const MinReceivedStatelessResetSize = 5 + 16
// MinConnectionIDLenInitial is the minimum length of the destination connection ID on an Initial packet.
const MinConnectionIDLenInitial = 8
// DefaultAckDelayExponent is the default ack delay exponent
const DefaultAckDelayExponent = 3
// DefaultActiveConnectionIDLimit is the default active connection ID limit
const DefaultActiveConnectionIDLimit = 2
// MaxAckDelayExponent is the maximum ack delay exponent
const MaxAckDelayExponent = 20
// DefaultMaxAckDelay is the default max_ack_delay
const DefaultMaxAckDelay = 25 * time.Millisecond
// MaxMaxAckDelay is the maximum max_ack_delay
const MaxMaxAckDelay = (1<<14 - 1) * time.Millisecond
// MaxConnIDLen is the maximum length of the connection ID
const MaxConnIDLen = 20
// InvalidPacketLimitAES is the maximum number of packets that we can fail to decrypt when using
// AEAD_AES_128_GCM or AEAD_AES_256_GCM.
const InvalidPacketLimitAES = 1 << 52
// InvalidPacketLimitChaCha is the maximum number of packets that we can fail to decrypt when using AEAD_CHACHA20_POLY1305.
const InvalidPacketLimitChaCha = 1 << 36
package protocol
import "github.com/quic-go/quic-go/quicvarint"
// StreamType encodes if this is a unidirectional or bidirectional stream
type StreamType uint8
const (
// StreamTypeUni is a unidirectional stream
StreamTypeUni StreamType = iota
// StreamTypeBidi is a bidirectional stream
StreamTypeBidi
)
// InvalidStreamID is a stream ID that is invalid.
// The first valid stream ID in QUIC is 0.
const InvalidStreamID StreamID = -1
// StreamNum is the stream number
type StreamNum int64
const (
// InvalidStreamNum is an invalid stream number.
InvalidStreamNum = -1
// MaxStreamCount is the maximum stream count value that can be sent in MAX_STREAMS frames
// and as the stream count in the transport parameters
MaxStreamCount StreamNum = 1 << 60
// MaxStreamID is the maximum stream ID
MaxStreamID StreamID = quicvarint.Max
)
const (
// FirstOutgoingBidiStreamClient is the first bidirectional stream opened by the client
FirstOutgoingBidiStreamClient StreamID = 0
// FirstOutgoingUniStreamClient is the first unidirectional stream opened by the client
FirstOutgoingUniStreamClient StreamID = 2
// FirstOutgoingBidiStreamServer is the first bidirectional stream opened by the server
FirstOutgoingBidiStreamServer StreamID = 1
// FirstOutgoingUniStreamServer is the first unidirectional stream opened by the server
FirstOutgoingUniStreamServer StreamID = 3
)
const (
// FirstIncomingBidiStreamServer is the first bidirectional stream accepted by the server
FirstIncomingBidiStreamServer = FirstOutgoingBidiStreamClient
// FirstIncomingUniStreamServer is the first unidirectional stream accepted by the server
FirstIncomingUniStreamServer = FirstOutgoingUniStreamClient
// FirstIncomingBidiStreamClient is the first bidirectional stream accepted by the client
FirstIncomingBidiStreamClient = FirstOutgoingBidiStreamServer
// FirstIncomingUniStreamClient is the first unidirectional stream accepted by the client
FirstIncomingUniStreamClient = FirstOutgoingUniStreamServer
)
// StreamID calculates the stream ID.
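// For example, the third bidirectional stream opened by the server has stream
// ID 9, and the second unidirectional stream opened by the client has stream
// ID 6:
//
//	StreamNum(3).StreamID(StreamTypeBidi, PerspectiveServer) // StreamID(9)
//	StreamNum(2).StreamID(StreamTypeUni, PerspectiveClient)  // StreamID(6)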
func (s StreamNum) StreamID(stype StreamType, pers Perspective) StreamID {
if s == 0 {
return InvalidStreamID
}
var first StreamID
switch stype {
case StreamTypeBidi:
switch pers {
case PerspectiveClient:
first = 0
case PerspectiveServer:
first = 1
}
case StreamTypeUni:
switch pers {
case PerspectiveClient:
first = 2
case PerspectiveServer:
first = 3
}
}
return first + 4*StreamID(s-1)
}
// A StreamID in QUIC
type StreamID int64
// InitiatedBy says if the stream was initiated by the client or by the server
func (s StreamID) InitiatedBy() Perspective {
if s%2 == 0 {
return PerspectiveClient
}
return PerspectiveServer
}
// Type says if this is a unidirectional or bidirectional stream
func (s StreamID) Type() StreamType {
if s%4 >= 2 {
return StreamTypeUni
}
return StreamTypeBidi
}
// StreamNum returns how many streams in total are at or below this stream ID
// Example: for stream 9 it returns 3 (i.e. streams 1, 5 and 9)
func (s StreamID) StreamNum() StreamNum {
return StreamNum(s/4) + 1
}
package protocol
import (
"crypto/rand"
"encoding/binary"
"fmt"
"math"
mrand "math/rand/v2"
"slices"
"sync"
)
// Version is a version number as int
type Version uint32
// gQUIC version range as defined in the wiki: https://github.com/quicwg/base-drafts/wiki/QUIC-Versions
const (
gquicVersion0 = 0x51303030
maxGquicVersion = 0x51303439
)
// The version numbers, making grepping easier
const (
VersionUnknown Version = math.MaxUint32
versionDraft29 Version = 0xff00001d // draft-29 used to be a widely deployed version
Version1 Version = 0x1
Version2 Version = 0x6b3343cf
)
// SupportedVersions lists the versions that the server supports.
// It must be sorted in descending order of preference.
var SupportedVersions = []Version{Version1, Version2}
// IsValidVersion says if the version is known to quic-go
func IsValidVersion(v Version) bool {
return v == Version1 || IsSupportedVersion(SupportedVersions, v)
}
func (vn Version) String() string {
switch vn {
case VersionUnknown:
return "unknown"
case versionDraft29:
return "draft-29"
case Version1:
return "v1"
case Version2:
return "v2"
default:
if vn.isGQUIC() {
return fmt.Sprintf("gQUIC %d", vn.toGQUICVersion())
}
return fmt.Sprintf("%#x", uint32(vn))
}
}
func (vn Version) isGQUIC() bool {
return vn > gquicVersion0 && vn <= maxGquicVersion
}
func (vn Version) toGQUICVersion() int {
return int(10*(vn-gquicVersion0)/0x100) + int(vn%0x10)
}
// IsSupportedVersion returns true if the server supports this version
func IsSupportedVersion(supported []Version, v Version) bool {
return slices.Contains(supported, v)
}
// ChooseSupportedVersion finds the best version in the overlap of ours and theirs
// ours is a slice of versions that we support, sorted by our preference (descending)
// theirs is a slice of versions offered by the peer. The order does not matter.
// The bool returned indicates if a matching version was found.
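// For example, if we prefer v1 over v2 and the peer offers v2 plus a reserved
// greasing version, v2 is the only overlap and is chosen:
//
//	ChooseSupportedVersion(
//		[]Version{Version1, Version2},
//		[]Version{0x1a2a3a4a, Version2},
//	) // returns Version2, true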
func ChooseSupportedVersion(ours, theirs []Version) (Version, bool) {
for _, ourVer := range ours {
if slices.Contains(theirs, ourVer) {
return ourVer, true
}
}
return 0, false
}
var (
versionNegotiationMx sync.Mutex
versionNegotiationRand mrand.Rand
)
func init() {
var seed [16]byte
rand.Read(seed[:])
versionNegotiationRand = *mrand.New(mrand.NewPCG(
binary.BigEndian.Uint64(seed[:8]),
binary.BigEndian.Uint64(seed[8:]),
))
}
// generateReservedVersion generates a reserved version (v & 0x0f0f0f0f == 0x0a0a0a0a)
func generateReservedVersion() Version {
var b [4]byte
binary.BigEndian.PutUint32(b[:], versionNegotiationRand.Uint32())
return Version((binary.BigEndian.Uint32(b[:]) | 0x0a0a0a0a) & 0xfafafafa)
}
// GetGreasedVersions adds one reserved version number to a slice of version numbers, at a random position.
// It doesn't modify the supported slice.
func GetGreasedVersions(supported []Version) []Version {
versionNegotiationMx.Lock()
defer versionNegotiationMx.Unlock()
randPos := versionNegotiationRand.IntN(len(supported) + 1)
greased := make([]Version, len(supported)+1)
copy(greased, supported[:randPos])
greased[randPos] = generateReservedVersion()
copy(greased[randPos+1:], supported[randPos:])
return greased
}
package qerr
import (
"crypto/tls"
"fmt"
)
// TransportErrorCode is a QUIC transport error.
type TransportErrorCode uint64
// The error codes defined by QUIC
const (
NoError TransportErrorCode = 0x0
InternalError TransportErrorCode = 0x1
ConnectionRefused TransportErrorCode = 0x2
FlowControlError TransportErrorCode = 0x3
StreamLimitError TransportErrorCode = 0x4
StreamStateError TransportErrorCode = 0x5
FinalSizeError TransportErrorCode = 0x6
FrameEncodingError TransportErrorCode = 0x7
TransportParameterError TransportErrorCode = 0x8
ConnectionIDLimitError TransportErrorCode = 0x9
ProtocolViolation TransportErrorCode = 0xa
InvalidToken TransportErrorCode = 0xb
ApplicationErrorErrorCode TransportErrorCode = 0xc
CryptoBufferExceeded TransportErrorCode = 0xd
KeyUpdateError TransportErrorCode = 0xe
AEADLimitReached TransportErrorCode = 0xf
NoViablePathError TransportErrorCode = 0x10
)
func (e TransportErrorCode) IsCryptoError() bool {
return e >= 0x100 && e < 0x200
}
// Message is a description of the error.
// It only returns a non-empty string for crypto errors.
func (e TransportErrorCode) Message() string {
if !e.IsCryptoError() {
return ""
}
return tls.AlertError(e - 0x100).Error()
}
func (e TransportErrorCode) String() string {
switch e {
case NoError:
return "NO_ERROR"
case InternalError:
return "INTERNAL_ERROR"
case ConnectionRefused:
return "CONNECTION_REFUSED"
case FlowControlError:
return "FLOW_CONTROL_ERROR"
case StreamLimitError:
return "STREAM_LIMIT_ERROR"
case StreamStateError:
return "STREAM_STATE_ERROR"
case FinalSizeError:
return "FINAL_SIZE_ERROR"
case FrameEncodingError:
return "FRAME_ENCODING_ERROR"
case TransportParameterError:
return "TRANSPORT_PARAMETER_ERROR"
case ConnectionIDLimitError:
return "CONNECTION_ID_LIMIT_ERROR"
case ProtocolViolation:
return "PROTOCOL_VIOLATION"
case InvalidToken:
return "INVALID_TOKEN"
case ApplicationErrorErrorCode:
return "APPLICATION_ERROR"
case CryptoBufferExceeded:
return "CRYPTO_BUFFER_EXCEEDED"
case KeyUpdateError:
return "KEY_UPDATE_ERROR"
case AEADLimitReached:
return "AEAD_LIMIT_REACHED"
case NoViablePathError:
return "NO_VIABLE_PATH"
default:
if e.IsCryptoError() {
return fmt.Sprintf("CRYPTO_ERROR %#x", uint16(e))
}
return fmt.Sprintf("unknown error code: %#x", uint16(e))
}
}
package qerr
import (
"fmt"
"net"
"github.com/quic-go/quic-go/internal/protocol"
)
var (
ErrHandshakeTimeout = &HandshakeTimeoutError{}
ErrIdleTimeout = &IdleTimeoutError{}
)
type TransportError struct {
Remote bool
FrameType uint64
ErrorCode TransportErrorCode
ErrorMessage string
error error // only set for local errors, sometimes
}
var _ error = &TransportError{}
// NewLocalCryptoError creates a new TransportError instance for a crypto error
func NewLocalCryptoError(tlsAlert uint8, err error) *TransportError {
return &TransportError{
ErrorCode: 0x100 + TransportErrorCode(tlsAlert),
error: err,
}
}
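// Illustrative sketch (not part of this package): per RFC 9001, Section 4.8, a
// TLS alert is mapped into the CRYPTO_ERROR code space by adding it to 0x100.
// For example, TLS alert 40 (handshake_failure) yields error code 0x128:
//
//	err := NewLocalCryptoError(40, errors.New("handshake failed"))
//	// err.ErrorCode == 0x128, err.ErrorCode.IsCryptoError() == true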
func (e *TransportError) Error() string {
str := fmt.Sprintf("%s (%s)", e.ErrorCode.String(), getRole(e.Remote))
if e.FrameType != 0 {
str += fmt.Sprintf(" (frame type: %#x)", e.FrameType)
}
msg := e.ErrorMessage
if len(msg) == 0 && e.error != nil {
msg = e.error.Error()
}
if len(msg) == 0 {
msg = e.ErrorCode.Message()
}
if len(msg) == 0 {
return str
}
return str + ": " + msg
}
func (e *TransportError) Unwrap() []error { return []error{net.ErrClosed, e.error} }
func (e *TransportError) Is(target error) bool {
t, ok := target.(*TransportError)
return ok && e.ErrorCode == t.ErrorCode && e.FrameType == t.FrameType && e.Remote == t.Remote
}
// An ApplicationErrorCode is an application-defined error code.
type ApplicationErrorCode uint64
// A StreamErrorCode is an error code used to cancel streams.
type StreamErrorCode uint64
type ApplicationError struct {
Remote bool
ErrorCode ApplicationErrorCode
ErrorMessage string
}
var _ error = &ApplicationError{}
func (e *ApplicationError) Error() string {
if len(e.ErrorMessage) == 0 {
return fmt.Sprintf("Application error %#x (%s)", e.ErrorCode, getRole(e.Remote))
}
return fmt.Sprintf("Application error %#x (%s): %s", e.ErrorCode, getRole(e.Remote), e.ErrorMessage)
}
func (e *ApplicationError) Unwrap() error { return net.ErrClosed }
func (e *ApplicationError) Is(target error) bool {
t, ok := target.(*ApplicationError)
return ok && e.ErrorCode == t.ErrorCode && e.Remote == t.Remote
}
type IdleTimeoutError struct{}
var _ error = &IdleTimeoutError{}
func (e *IdleTimeoutError) Timeout() bool { return true }
func (e *IdleTimeoutError) Temporary() bool { return false }
func (e *IdleTimeoutError) Error() string { return "timeout: no recent network activity" }
func (e *IdleTimeoutError) Unwrap() error { return net.ErrClosed }
type HandshakeTimeoutError struct{}
var _ error = &HandshakeTimeoutError{}
func (e *HandshakeTimeoutError) Timeout() bool { return true }
func (e *HandshakeTimeoutError) Temporary() bool { return false }
func (e *HandshakeTimeoutError) Error() string { return "timeout: handshake did not complete in time" }
func (e *HandshakeTimeoutError) Unwrap() error { return net.ErrClosed }
// A VersionNegotiationError occurs when the client and the server can't agree on a QUIC version.
type VersionNegotiationError struct {
Ours []protocol.Version
Theirs []protocol.Version
}
func (e *VersionNegotiationError) Error() string {
return fmt.Sprintf("no compatible QUIC version found (we support %s, server offered %s)", e.Ours, e.Theirs)
}
func (e *VersionNegotiationError) Unwrap() error { return net.ErrClosed }
// A StatelessResetError occurs when we receive a stateless reset.
type StatelessResetError struct{}
var _ net.Error = &StatelessResetError{}
func (e *StatelessResetError) Error() string {
return "received a stateless reset"
}
func (e *StatelessResetError) Unwrap() error { return net.ErrClosed }
func (e *StatelessResetError) Timeout() bool { return false }
func (e *StatelessResetError) Temporary() bool { return true }
func getRole(remote bool) string {
if remote {
return "remote"
}
return "local"
}
package qtls
import (
"crypto/tls"
"fmt"
"unsafe"
)
//go:linkname cipherSuitesTLS13 crypto/tls.cipherSuitesTLS13
var cipherSuitesTLS13 []unsafe.Pointer
//go:linkname defaultCipherSuitesTLS13 crypto/tls.defaultCipherSuitesTLS13
var defaultCipherSuitesTLS13 []uint16
//go:linkname defaultCipherSuitesTLS13NoAES crypto/tls.defaultCipherSuitesTLS13NoAES
var defaultCipherSuitesTLS13NoAES []uint16
var cipherSuitesModified bool
// SetCipherSuite modifies the cipherSuitesTLS13 slice of cipher suites inside qtls
// such that it only contains the cipher suite with the chosen id.
// The reset function returned resets them back to the original value.
func SetCipherSuite(id uint16) (reset func()) {
if cipherSuitesModified {
panic("cipher suites modified multiple times without resetting")
}
cipherSuitesModified = true
origCipherSuitesTLS13 := append([]unsafe.Pointer{}, cipherSuitesTLS13...)
origDefaultCipherSuitesTLS13 := append([]uint16{}, defaultCipherSuitesTLS13...)
origDefaultCipherSuitesTLS13NoAES := append([]uint16{}, defaultCipherSuitesTLS13NoAES...)
// The order is given by the order of the slice elements in cipherSuitesTLS13 in qtls.
switch id {
case tls.TLS_AES_128_GCM_SHA256:
cipherSuitesTLS13 = cipherSuitesTLS13[:1]
case tls.TLS_CHACHA20_POLY1305_SHA256:
cipherSuitesTLS13 = cipherSuitesTLS13[1:2]
case tls.TLS_AES_256_GCM_SHA384:
cipherSuitesTLS13 = cipherSuitesTLS13[2:]
default:
panic(fmt.Sprintf("unexpected cipher suite: %d", id))
}
defaultCipherSuitesTLS13 = []uint16{id}
defaultCipherSuitesTLS13NoAES = []uint16{id}
return func() {
cipherSuitesTLS13 = origCipherSuitesTLS13
defaultCipherSuitesTLS13 = origDefaultCipherSuitesTLS13
defaultCipherSuitesTLS13NoAES = origDefaultCipherSuitesTLS13NoAES
cipherSuitesModified = false
}
}
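// Illustrative sketch (not part of this package): restricting the TLS 1.3
// handshake to a single cipher suite in a test, and restoring the defaults
// afterwards.
//
//	reset := SetCipherSuite(tls.TLS_CHACHA20_POLY1305_SHA256)
//	defer reset()
//	// ... run the handshake under test ...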
package utils
import (
"bufio"
"io"
)
type bufferedWriteCloser struct {
*bufio.Writer
io.Closer
}
// NewBufferedWriteCloser creates an io.WriteCloser from a bufio.Writer and an io.Closer
func NewBufferedWriteCloser(writer *bufio.Writer, closer io.Closer) io.WriteCloser {
return &bufferedWriteCloser{
Writer: writer,
Closer: closer,
}
}
func (h bufferedWriteCloser) Close() error {
if err := h.Flush(); err != nil {
return err
}
return h.Closer.Close()
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package list implements a doubly linked list.
//
// To iterate over a list (where l is a *List[T]):
//
// for e := l.Front(); e != nil; e = e.Next() {
// // do something with e.Value
// }
package list
import "sync"
func NewPool[T any]() *sync.Pool {
return &sync.Pool{New: func() any { return &Element[T]{} }}
}
// Element is an element of a linked list.
type Element[T any] struct {
// Next and previous pointers in the doubly-linked list of elements.
// To simplify the implementation, internally a list l is implemented
// as a ring, such that &l.root is both the next element of the last
// list element (l.Back()) and the previous element of the first list
// element (l.Front()).
next, prev *Element[T]
// The list to which this element belongs.
list *List[T]
// The value stored with this element.
Value T
}
// Next returns the next list element or nil.
func (e *Element[T]) Next() *Element[T] {
if p := e.next; e.list != nil && p != &e.list.root {
return p
}
return nil
}
// Prev returns the previous list element or nil.
func (e *Element[T]) Prev() *Element[T] {
if p := e.prev; e.list != nil && p != &e.list.root {
return p
}
return nil
}
func (e *Element[T]) List() *List[T] {
return e.list
}
// List represents a doubly linked list.
// The zero value for List is an empty list ready to use.
type List[T any] struct {
root Element[T] // sentinel list element, only &root, root.prev, and root.next are used
len int // current list length excluding (this) sentinel element
pool *sync.Pool
}
// Init initializes or clears list l.
func (l *List[T]) Init() *List[T] {
l.root.next = &l.root
l.root.prev = &l.root
l.len = 0
return l
}
// New returns an initialized list.
func New[T any]() *List[T] { return new(List[T]).Init() }
// NewWithPool returns an initialized list, using a sync.Pool for list elements.
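// Illustrative sketch (not part of this package): reusing list elements via a
// shared pool to avoid a fresh allocation for every inserted element.
//
//	pool := NewPool[int]()
//	l := NewWithPool[int](pool)
//	e := l.PushBack(42)
//	l.Remove(e) // the element is returned to the pool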
func NewWithPool[T any](pool *sync.Pool) *List[T] {
l := &List[T]{pool: pool}
return l.Init()
}
// Len returns the number of elements of list l.
// The complexity is O(1).
func (l *List[T]) Len() int { return l.len }
// Front returns the first element of list l or nil if the list is empty.
func (l *List[T]) Front() *Element[T] {
if l.len == 0 {
return nil
}
return l.root.next
}
// Back returns the last element of list l or nil if the list is empty.
func (l *List[T]) Back() *Element[T] {
if l.len == 0 {
return nil
}
return l.root.prev
}
// lazyInit lazily initializes a zero List value.
func (l *List[T]) lazyInit() {
if l.root.next == nil {
l.Init()
}
}
// insert inserts e after at, increments l.len, and returns e.
func (l *List[T]) insert(e, at *Element[T]) *Element[T] {
e.prev = at
e.next = at.next
e.prev.next = e
e.next.prev = e
e.list = l
l.len++
return e
}
// insertValue is a convenience wrapper for insert(&Element{Value: v}, at).
func (l *List[T]) insertValue(v T, at *Element[T]) *Element[T] {
var e *Element[T]
if l.pool != nil {
e = l.pool.Get().(*Element[T])
} else {
e = &Element[T]{}
}
e.Value = v
return l.insert(e, at)
}
// remove removes e from its list, decrements l.len
func (l *List[T]) remove(e *Element[T]) {
e.prev.next = e.next
e.next.prev = e.prev
e.next = nil // avoid memory leaks
e.prev = nil // avoid memory leaks
e.list = nil
if l.pool != nil {
l.pool.Put(e)
}
l.len--
}
// move moves e to next to at.
func (l *List[T]) move(e, at *Element[T]) {
if e == at {
return
}
e.prev.next = e.next
e.next.prev = e.prev
e.prev = at
e.next = at.next
e.prev.next = e
e.next.prev = e
}
// Remove removes e from l if e is an element of list l.
// It returns the element value e.Value.
// The element must not be nil.
func (l *List[T]) Remove(e *Element[T]) T {
v := e.Value
if e.list == l {
// if e.list == l, l must have been initialized when e was inserted
// in l or l == nil (e is a zero Element) and l.remove will crash
l.remove(e)
}
return v
}
// PushFront inserts a new element e with value v at the front of list l and returns e.
func (l *List[T]) PushFront(v T) *Element[T] {
l.lazyInit()
return l.insertValue(v, &l.root)
}
// PushBack inserts a new element e with value v at the back of list l and returns e.
func (l *List[T]) PushBack(v T) *Element[T] {
l.lazyInit()
return l.insertValue(v, l.root.prev)
}
// InsertBefore inserts a new element e with value v immediately before mark and returns e.
// If mark is not an element of l, the list is not modified.
// The mark must not be nil.
func (l *List[T]) InsertBefore(v T, mark *Element[T]) *Element[T] {
if mark.list != l {
return nil
}
// see comment in List.Remove about initialization of l
return l.insertValue(v, mark.prev)
}
// InsertAfter inserts a new element e with value v immediately after mark and returns e.
// If mark is not an element of l, the list is not modified.
// The mark must not be nil.
func (l *List[T]) InsertAfter(v T, mark *Element[T]) *Element[T] {
if mark.list != l {
return nil
}
// see comment in List.Remove about initialization of l
return l.insertValue(v, mark)
}
// MoveToFront moves element e to the front of list l.
// If e is not an element of l, the list is not modified.
// The element must not be nil.
func (l *List[T]) MoveToFront(e *Element[T]) {
if e.list != l || l.root.next == e {
return
}
// see comment in List.Remove about initialization of l
l.move(e, &l.root)
}
// MoveToBack moves element e to the back of list l.
// If e is not an element of l, the list is not modified.
// The element must not be nil.
func (l *List[T]) MoveToBack(e *Element[T]) {
if e.list != l || l.root.prev == e {
return
}
// see comment in List.Remove about initialization of l
l.move(e, l.root.prev)
}
// MoveBefore moves element e to its new position before mark.
// If e or mark is not an element of l, or e == mark, the list is not modified.
// The element and mark must not be nil.
func (l *List[T]) MoveBefore(e, mark *Element[T]) {
if e.list != l || e == mark || mark.list != l {
return
}
l.move(e, mark.prev)
}
// MoveAfter moves element e to its new position after mark.
// If e or mark is not an element of l, or e == mark, the list is not modified.
// The element and mark must not be nil.
func (l *List[T]) MoveAfter(e, mark *Element[T]) {
if e.list != l || e == mark || mark.list != l {
return
}
l.move(e, mark)
}
// PushBackList inserts a copy of another list at the back of list l.
// The lists l and other may be the same. They must not be nil.
func (l *List[T]) PushBackList(other *List[T]) {
l.lazyInit()
for i, e := other.Len(), other.Front(); i > 0; i, e = i-1, e.Next() {
l.insertValue(e.Value, l.root.prev)
}
}
// PushFrontList inserts a copy of another list at the front of list l.
// The lists l and other may be the same. They must not be nil.
func (l *List[T]) PushFrontList(other *List[T]) {
l.lazyInit()
for i, e := other.Len(), other.Back(); i > 0; i, e = i-1, e.Prev() {
l.insertValue(e.Value, &l.root)
}
}
package utils
import (
"fmt"
"log"
"os"
"strings"
"time"
)
// LogLevel of quic-go
type LogLevel uint8
const (
// LogLevelNothing disables logging
LogLevelNothing LogLevel = iota
// LogLevelError enables error logs
LogLevelError
// LogLevelInfo enables info logs (e.g. packets)
LogLevelInfo
// LogLevelDebug enables debug logs (e.g. packet contents)
LogLevelDebug
)
const logEnv = "QUIC_GO_LOG_LEVEL"
// A Logger logs.
type Logger interface {
SetLogLevel(LogLevel)
SetLogTimeFormat(format string)
WithPrefix(prefix string) Logger
Debug() bool
Errorf(format string, args ...interface{})
Infof(format string, args ...interface{})
Debugf(format string, args ...interface{})
}
// DefaultLogger is used by quic-go for logging.
var DefaultLogger Logger
type defaultLogger struct {
prefix string
logLevel LogLevel
timeFormat string
}
var _ Logger = &defaultLogger{}
// SetLogLevel sets the log level
func (l *defaultLogger) SetLogLevel(level LogLevel) {
l.logLevel = level
}
// SetLogTimeFormat sets the format of the timestamp.
// An empty string disables the logging of timestamps.
func (l *defaultLogger) SetLogTimeFormat(format string) {
log.SetFlags(0) // disable timestamp logging done by the log package
l.timeFormat = format
}
// Debugf logs a message at debug level
func (l *defaultLogger) Debugf(format string, args ...interface{}) {
if l.logLevel == LogLevelDebug {
l.logMessage(format, args...)
}
}
// Infof logs a message at info level
func (l *defaultLogger) Infof(format string, args ...interface{}) {
if l.logLevel >= LogLevelInfo {
l.logMessage(format, args...)
}
}
// Errorf logs a message at error level
func (l *defaultLogger) Errorf(format string, args ...interface{}) {
if l.logLevel >= LogLevelError {
l.logMessage(format, args...)
}
}
func (l *defaultLogger) logMessage(format string, args ...interface{}) {
var pre string
if len(l.timeFormat) > 0 {
pre = time.Now().Format(l.timeFormat) + " "
}
if len(l.prefix) > 0 {
pre += l.prefix + " "
}
log.Printf(pre+format, args...)
}
func (l *defaultLogger) WithPrefix(prefix string) Logger {
if len(l.prefix) > 0 {
prefix = l.prefix + " " + prefix
}
return &defaultLogger{
logLevel: l.logLevel,
timeFormat: l.timeFormat,
prefix: prefix,
}
}
// Debug returns true if the log level is LogLevelDebug
func (l *defaultLogger) Debug() bool {
return l.logLevel == LogLevelDebug
}
func init() {
DefaultLogger = &defaultLogger{}
DefaultLogger.SetLogLevel(readLoggingEnv())
}
func readLoggingEnv() LogLevel {
switch strings.ToLower(os.Getenv(logEnv)) {
case "":
return LogLevelNothing
case "debug":
return LogLevelDebug
case "info":
return LogLevelInfo
case "error":
return LogLevelError
default:
fmt.Fprintln(os.Stderr, "invalid quic-go log level, see https://github.com/quic-go/quic-go/wiki/Logging")
return LogLevelNothing
}
}
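// Illustrative sketch, not part of the original file: how a caller might use
// the package-level DefaultLogger. The prefix "example" and the log messages
// are arbitrary values chosen for this sketch; the log level itself is read
// from the QUIC_GO_LOG_LEVEL environment variable in init().
func exampleLoggerUsage() {
logger := DefaultLogger.WithPrefix("example")
logger.SetLogTimeFormat(time.RFC3339)
logger.Infof("handshake complete after %d packets", 3)
if logger.Debug() {
// only do potentially expensive formatting when debug logging is enabled
logger.Debugf("packet contents: %x", []byte{0x01, 0x02})
}
}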
package utils
import (
"crypto/rand"
"encoding/binary"
)
// Rand is a wrapper around crypto/rand that adds some convenience functions known from math/rand.
type Rand struct {
buf [4]byte
}
func (r *Rand) Int31() int32 {
rand.Read(r.buf[:])
return int32(binary.BigEndian.Uint32(r.buf[:]) & ^uint32(1<<31))
}
// copied from the standard library math/rand implementation of Int31n
func (r *Rand) Int31n(n int32) int32 {
if n&(n-1) == 0 { // n is power of two, can mask
return r.Int31() & (n - 1)
}
max := int32((1 << 31) - 1 - (1<<31)%uint32(n))
v := r.Int31()
for v > max {
v = r.Int31()
}
return v % n
}
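// Illustrative sketch, not part of the original file: Int31n uses rejection
// sampling, i.e. values above the largest multiple of n below 2^31 are
// redrawn, so the result is uniformly distributed in [0, n). The bound 1000
// and the function name are arbitrary choices for this example.
func exampleUniformDraw() int32 {
var r Rand
return r.Int31n(1000)
}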
package ringbuffer
// A RingBuffer is a ring buffer.
// It acts as a FIFO queue that, once grown to its working capacity, doesn't cause any allocations.
type RingBuffer[T any] struct {
ring []T
headPos, tailPos int
full bool
}
// Init preallocates a buffer with a certain size.
func (r *RingBuffer[T]) Init(size int) {
r.ring = make([]T, size)
}
// Len returns the number of elements in the ring buffer.
func (r *RingBuffer[T]) Len() int {
if r.full {
return len(r.ring)
}
if r.tailPos >= r.headPos {
return r.tailPos - r.headPos
}
return r.tailPos - r.headPos + len(r.ring)
}
// Empty says if the ring buffer is empty.
func (r *RingBuffer[T]) Empty() bool {
return !r.full && r.headPos == r.tailPos
}
// PushBack adds a new element.
// If the ring buffer is full, its capacity is increased first.
func (r *RingBuffer[T]) PushBack(t T) {
if r.full || len(r.ring) == 0 {
r.grow()
}
r.ring[r.tailPos] = t
r.tailPos++
if r.tailPos == len(r.ring) {
r.tailPos = 0
}
if r.tailPos == r.headPos {
r.full = true
}
}
// PopFront removes and returns the next element.
// It must not be called when the buffer is empty; callers should check
// Empty() first.
func (r *RingBuffer[T]) PopFront() T {
if r.Empty() {
panic("github.com/quic-go/quic-go/internal/utils/ringbuffer: pop from an empty queue")
}
r.full = false
t := r.ring[r.headPos]
r.ring[r.headPos] = *new(T)
r.headPos++
if r.headPos == len(r.ring) {
r.headPos = 0
}
return t
}
// PeekFront returns the next element without removing it.
// It must not be called when the buffer is empty; callers should check
// Empty() first.
func (r *RingBuffer[T]) PeekFront() T {
if r.Empty() {
panic("github.com/quic-go/quic-go/internal/utils/ringbuffer: peek from an empty queue")
}
return r.ring[r.headPos]
}
// grow increases the capacity of the ring buffer.
// This method assumes that the buffer is full.
func (r *RingBuffer[T]) grow() {
oldRing := r.ring
newSize := len(oldRing) * 2
if newSize == 0 {
newSize = 1
}
r.ring = make([]T, newSize)
headLen := copy(r.ring, oldRing[r.headPos:])
copy(r.ring[headLen:], oldRing[:r.headPos])
r.headPos, r.tailPos, r.full = 0, len(oldRing), false
}
// Clear removes all elements.
func (r *RingBuffer[T]) Clear() {
var zeroValue T
for i := range r.ring {
r.ring[i] = zeroValue
}
r.headPos, r.tailPos, r.full = 0, 0, false
}
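// Illustrative sketch, not part of the original file: typical FIFO usage.
// The initial capacity of 4 is an arbitrary example value; PushBack grows the
// ring automatically (doubling its size) once it is full.
func exampleFIFO() []int {
var rb RingBuffer[int]
rb.Init(4)
for i := 0; i < 6; i++ {
rb.PushBack(i)
}
out := make([]int, 0, rb.Len())
for !rb.Empty() {
out = append(out, rb.PopFront())
}
return out // [0 1 2 3 4 5]
}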
package utils
import (
"time"
"github.com/quic-go/quic-go/internal/protocol"
)
const (
rttAlpha = 0.125
oneMinusAlpha = 1 - rttAlpha
rttBeta = 0.25
oneMinusBeta = 1 - rttBeta
// The default RTT used before an RTT sample is taken.
defaultInitialRTT = 100 * time.Millisecond
)
// RTTStats provides round-trip statistics
type RTTStats struct {
hasMeasurement bool
minRTT time.Duration
latestRTT time.Duration
smoothedRTT time.Duration
meanDeviation time.Duration
maxAckDelay time.Duration
}
// MinRTT returns the minimum RTT for the entire connection.
// May return zero if no valid updates have occurred.
func (r *RTTStats) MinRTT() time.Duration { return r.minRTT }
// LatestRTT returns the most recent RTT measurement.
// May return zero if no valid updates have occurred.
func (r *RTTStats) LatestRTT() time.Duration { return r.latestRTT }
// SmoothedRTT returns the smoothed RTT for the connection.
// May return zero if no valid updates have occurred.
func (r *RTTStats) SmoothedRTT() time.Duration { return r.smoothedRTT }
// MeanDeviation gets the mean deviation
func (r *RTTStats) MeanDeviation() time.Duration { return r.meanDeviation }
// MaxAckDelay gets the max_ack_delay advertised by the peer
func (r *RTTStats) MaxAckDelay() time.Duration { return r.maxAckDelay }
// PTO gets the probe timeout duration.
func (r *RTTStats) PTO(includeMaxAckDelay bool) time.Duration {
if r.SmoothedRTT() == 0 {
return 2 * defaultInitialRTT
}
pto := r.SmoothedRTT() + max(4*r.MeanDeviation(), protocol.TimerGranularity)
if includeMaxAckDelay {
pto += r.MaxAckDelay()
}
return pto
}
// UpdateRTT updates the RTT based on a new sample.
func (r *RTTStats) UpdateRTT(sendDelta, ackDelay time.Duration) {
if sendDelta <= 0 {
return
}
// Update r.minRTT first. r.minRTT does not use an rttSample corrected for
// ackDelay but the raw observed sendDelta, since poor clock granularity at
// the client may cause a high ackDelay to result in underestimation of the
// r.minRTT.
if r.minRTT == 0 || r.minRTT > sendDelta {
r.minRTT = sendDelta
}
// Correct for ackDelay if information received from the peer results in
// an RTT sample at least as large as minRTT. Otherwise, only use the
// sendDelta.
sample := sendDelta
if sample-r.minRTT >= ackDelay {
sample -= ackDelay
}
r.latestRTT = sample
// First time call.
if !r.hasMeasurement {
r.hasMeasurement = true
r.smoothedRTT = sample
r.meanDeviation = sample / 2
} else {
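// Exponentially weighted moving averages, as specified in RFC 9002:
// rttvar = 3/4 * rttvar + 1/4 * |smoothed_rtt - rtt_sample|
// smoothed_rtt = 7/8 * smoothed_rtt + 1/8 * rtt_sample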
r.meanDeviation = time.Duration(oneMinusBeta*float32(r.meanDeviation/time.Microsecond)+rttBeta*float32((r.smoothedRTT-sample).Abs()/time.Microsecond)) * time.Microsecond
r.smoothedRTT = time.Duration((float32(r.smoothedRTT/time.Microsecond)*oneMinusAlpha)+(float32(sample/time.Microsecond)*rttAlpha)) * time.Microsecond
}
}
// SetMaxAckDelay sets the max_ack_delay
func (r *RTTStats) SetMaxAckDelay(mad time.Duration) {
r.maxAckDelay = mad
}
// SetInitialRTT sets the initial RTT.
// It is used during handshake when restoring the RTT stats from the token.
func (r *RTTStats) SetInitialRTT(t time.Duration) {
// On the server side, by the time we get to process the session ticket,
// we might already have obtained an RTT measurement.
// This can happen if we received the ClientHello in multiple pieces, and one of those pieces was lost.
// Discard the restored value. A fresh measurement is always better.
if r.hasMeasurement {
return
}
r.smoothedRTT = t
r.latestRTT = t
}
func (r *RTTStats) ResetForPathMigration() {
r.hasMeasurement = false
r.minRTT = 0
r.latestRTT = 0
r.smoothedRTT = 0
r.meanDeviation = 0
// max_ack_delay remains valid
}
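// Illustrative sketch, not part of the original file: feeding a single RTT
// sample and deriving the probe timeout. The 50ms send delta, 5ms ack delay
// and 25ms max_ack_delay are arbitrary example values.
func exampleRTTStatsUsage() time.Duration {
var rtt RTTStats
rtt.SetMaxAckDelay(25 * time.Millisecond)
rtt.UpdateRTT(50*time.Millisecond, 5*time.Millisecond)
// PTO = smoothed RTT + max(4*mean deviation, timer granularity) + max_ack_delay
return rtt.PTO(true)
}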
package utils
import (
"math"
"time"
)
// A Timer wrapper that behaves correctly when resetting
type Timer struct {
t *time.Timer
read bool
deadline time.Time
}
// NewTimer creates a new timer that is not set
func NewTimer() *Timer {
return &Timer{t: time.NewTimer(time.Duration(math.MaxInt64))}
}
// Chan returns the channel of the wrapped timer
func (t *Timer) Chan() <-chan time.Time {
return t.t.C
}
// Reset the timer, no matter whether the value was read or not
func (t *Timer) Reset(deadline time.Time) {
if deadline.Equal(t.deadline) && !t.read {
// No need to reset the timer
return
}
// We need to drain the timer if the value from its channel was not read yet.
// See https://groups.google.com/forum/#!topic/golang-dev/c9UUfASVPoU
if !t.t.Stop() && !t.read {
<-t.t.C
}
if !deadline.IsZero() {
t.t.Reset(time.Until(deadline))
}
t.read = false
t.deadline = deadline
}
// SetRead should be called after the value from the chan was read
func (t *Timer) SetRead() {
t.read = true
}
func (t *Timer) Deadline() time.Time {
return t.deadline
}
// Stop stops the timer
func (t *Timer) Stop() {
t.t.Stop()
}
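// Illustrative sketch, not part of the original file: the intended usage
// pattern for the Timer wrapper. After receiving from Chan(), SetRead must be
// called so that the next Reset knows the channel has already been drained.
// The one-second re-arm interval is an arbitrary example value.
func exampleTimerUsage(deadline time.Time, done <-chan struct{}) {
t := NewTimer()
t.Reset(deadline)
defer t.Stop()
select {
case <-t.Chan():
t.SetRead()
// handle the timeout, then re-arm if needed
t.Reset(deadline.Add(time.Second))
case <-done:
}
}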
package wire
import (
"errors"
"math"
"sort"
"time"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/quicvarint"
)
var errInvalidAckRanges = errors.New("AckFrame: ACK frame contains invalid ACK ranges")
// An AckFrame is an ACK frame
type AckFrame struct {
AckRanges []AckRange // has to be ordered. The highest ACK range goes first, the lowest ACK range goes last
DelayTime time.Duration
ECT0, ECT1, ECNCE uint64
}
// parseAckFrame reads an ACK frame
func parseAckFrame(frame *AckFrame, b []byte, typ FrameType, ackDelayExponent uint8, _ protocol.Version) (int, error) {
startLen := len(b)
ecn := typ == FrameTypeAckECN
la, l, err := quicvarint.Parse(b)
if err != nil {
return 0, replaceUnexpectedEOF(err)
}
b = b[l:]
largestAcked := protocol.PacketNumber(la)
delay, l, err := quicvarint.Parse(b)
if err != nil {
return 0, replaceUnexpectedEOF(err)
}
b = b[l:]
delayTime := time.Duration(delay*1<<ackDelayExponent) * time.Microsecond
if delayTime < 0 {
// If the delay time overflows, set it to the maximum encodable value.
delayTime = time.Duration(math.MaxInt64)
}
frame.DelayTime = delayTime
numBlocks, l, err := quicvarint.Parse(b)
if err != nil {
return 0, replaceUnexpectedEOF(err)
}
b = b[l:]
// read the first ACK range
ab, l, err := quicvarint.Parse(b)
if err != nil {
return 0, replaceUnexpectedEOF(err)
}
b = b[l:]
ackBlock := protocol.PacketNumber(ab)
if ackBlock > largestAcked {
return 0, errors.New("invalid first ACK range")
}
smallest := largestAcked - ackBlock
frame.AckRanges = append(frame.AckRanges, AckRange{Smallest: smallest, Largest: largestAcked})
// read all the other ACK ranges
for i := uint64(0); i < numBlocks; i++ {
g, l, err := quicvarint.Parse(b)
if err != nil {
return 0, replaceUnexpectedEOF(err)
}
b = b[l:]
gap := protocol.PacketNumber(g)
if smallest < gap+2 {
return 0, errInvalidAckRanges
}
largest := smallest - gap - 2
ab, l, err := quicvarint.Parse(b)
if err != nil {
return 0, replaceUnexpectedEOF(err)
}
b = b[l:]
ackBlock := protocol.PacketNumber(ab)
if ackBlock > largest {
return 0, errInvalidAckRanges
}
smallest = largest - ackBlock
frame.AckRanges = append(frame.AckRanges, AckRange{Smallest: smallest, Largest: largest})
}
if !frame.validateAckRanges() {
return 0, errInvalidAckRanges
}
if ecn {
ect0, l, err := quicvarint.Parse(b)
if err != nil {
return 0, replaceUnexpectedEOF(err)
}
b = b[l:]
frame.ECT0 = ect0
ect1, l, err := quicvarint.Parse(b)
if err != nil {
return 0, replaceUnexpectedEOF(err)
}
b = b[l:]
frame.ECT1 = ect1
ecnce, l, err := quicvarint.Parse(b)
if err != nil {
return 0, replaceUnexpectedEOF(err)
}
b = b[l:]
frame.ECNCE = ecnce
}
return startLen - len(b), nil
}
// Append appends an ACK frame.
func (f *AckFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
hasECN := f.ECT0 > 0 || f.ECT1 > 0 || f.ECNCE > 0
if hasECN {
b = append(b, byte(FrameTypeAckECN))
} else {
b = append(b, byte(FrameTypeAck))
}
b = quicvarint.Append(b, uint64(f.LargestAcked()))
b = quicvarint.Append(b, encodeAckDelay(f.DelayTime))
numRanges := f.numEncodableAckRanges()
b = quicvarint.Append(b, uint64(numRanges-1))
// write the first range
_, firstRange := f.encodeAckRange(0)
b = quicvarint.Append(b, firstRange)
// write all the other ranges
for i := 1; i < numRanges; i++ {
gap, len := f.encodeAckRange(i)
b = quicvarint.Append(b, gap)
b = quicvarint.Append(b, len)
}
if hasECN {
b = quicvarint.Append(b, f.ECT0)
b = quicvarint.Append(b, f.ECT1)
b = quicvarint.Append(b, f.ECNCE)
}
return b, nil
}
// Length of a written frame
func (f *AckFrame) Length(_ protocol.Version) protocol.ByteCount {
largestAcked := f.AckRanges[0].Largest
numRanges := f.numEncodableAckRanges()
length := 1 + quicvarint.Len(uint64(largestAcked)) + quicvarint.Len(encodeAckDelay(f.DelayTime))
length += quicvarint.Len(uint64(numRanges - 1))
lowestInFirstRange := f.AckRanges[0].Smallest
length += quicvarint.Len(uint64(largestAcked - lowestInFirstRange))
for i := 1; i < numRanges; i++ {
gap, len := f.encodeAckRange(i)
length += quicvarint.Len(gap)
length += quicvarint.Len(len)
}
if f.ECT0 > 0 || f.ECT1 > 0 || f.ECNCE > 0 {
length += quicvarint.Len(f.ECT0)
length += quicvarint.Len(f.ECT1)
length += quicvarint.Len(f.ECNCE)
}
return protocol.ByteCount(length)
}
// gets the number of ACK ranges that can be encoded
// such that the resulting frame is smaller than the maximum ACK frame size
func (f *AckFrame) numEncodableAckRanges() int {
length := 1 + quicvarint.Len(uint64(f.LargestAcked())) + quicvarint.Len(encodeAckDelay(f.DelayTime))
length += 2 // assume that the number of ranges will consume 2 bytes
for i := 1; i < len(f.AckRanges); i++ {
gap, len := f.encodeAckRange(i)
rangeLen := quicvarint.Len(gap) + quicvarint.Len(len)
if protocol.ByteCount(length+rangeLen) > protocol.MaxAckFrameSize {
// Writing range i would exceed the MaxAckFrameSize.
// So encode one range less than that.
return i - 1
}
length += rangeLen
}
return len(f.AckRanges)
}
func (f *AckFrame) encodeAckRange(i int) (uint64 /* gap */, uint64 /* length */) {
if i == 0 {
return 0, uint64(f.AckRanges[0].Largest - f.AckRanges[0].Smallest)
}
return uint64(f.AckRanges[i-1].Smallest - f.AckRanges[i].Largest - 2),
uint64(f.AckRanges[i].Largest - f.AckRanges[i].Smallest)
}
// HasMissingRanges returns whether this frame reports any missing packets
func (f *AckFrame) HasMissingRanges() bool {
return len(f.AckRanges) > 1
}
func (f *AckFrame) validateAckRanges() bool {
if len(f.AckRanges) == 0 {
return false
}
// check the validity of every single ACK range
for _, ackRange := range f.AckRanges {
if ackRange.Smallest > ackRange.Largest {
return false
}
}
// check the consistency for ACK with multiple NACK ranges
for i, ackRange := range f.AckRanges {
if i == 0 {
continue
}
lastAckRange := f.AckRanges[i-1]
if lastAckRange.Smallest <= ackRange.Smallest {
return false
}
if lastAckRange.Smallest <= ackRange.Largest+1 {
return false
}
}
return true
}
// LargestAcked is the largest acked packet number
func (f *AckFrame) LargestAcked() protocol.PacketNumber {
return f.AckRanges[0].Largest
}
// LowestAcked is the lowest acked packet number
func (f *AckFrame) LowestAcked() protocol.PacketNumber {
return f.AckRanges[len(f.AckRanges)-1].Smallest
}
// AcksPacket determines if this ACK frame acks a certain packet number
func (f *AckFrame) AcksPacket(p protocol.PacketNumber) bool {
if p < f.LowestAcked() || p > f.LargestAcked() {
return false
}
i := sort.Search(len(f.AckRanges), func(i int) bool {
return p >= f.AckRanges[i].Smallest
})
// i will always be < len(f.AckRanges), since we checked above that p is not bigger than the largest acked
return p <= f.AckRanges[i].Largest
}
func (f *AckFrame) Reset() {
f.DelayTime = 0
f.ECT0 = 0
f.ECT1 = 0
f.ECNCE = 0
// zero the ranges so the reused backing array doesn't retain stale values
for i := range f.AckRanges {
f.AckRanges[i] = AckRange{}
}
f.AckRanges = f.AckRanges[:0]
}
func encodeAckDelay(delay time.Duration) uint64 {
return uint64(delay.Nanoseconds() / (1000 * (1 << protocol.AckDelayExponent)))
}
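// Illustrative sketch, not part of the original file: serializing an ACK
// frame that reports packets 0-9 and 12-15 (10 and 11 are missing), and
// parsing it back. Ranges must be ordered from highest to lowest. The delay
// round-trips exactly because both sides use protocol.AckDelayExponent here.
func exampleAckRoundTrip() (*AckFrame, error) {
f := &AckFrame{
AckRanges: []AckRange{
{Smallest: 12, Largest: 15},
{Smallest: 0, Largest: 9},
},
DelayTime: time.Millisecond,
}
b, err := f.Append(nil, protocol.Version1)
if err != nil {
return nil, err
}
parsed := &AckFrame{}
// skip the frame type byte: parseAckFrame expects the frame payload only
if _, err := parseAckFrame(parsed, b[1:], FrameTypeAck, protocol.AckDelayExponent, protocol.Version1); err != nil {
return nil, err
}
return parsed, nil
}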
package wire
import (
"math"
"time"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/quicvarint"
)
type AckFrequencyFrame struct {
SequenceNumber uint64
AckElicitingThreshold uint64
RequestMaxAckDelay time.Duration
ReorderingThreshold protocol.PacketNumber
}
func parseAckFrequencyFrame(b []byte, _ protocol.Version) (*AckFrequencyFrame, int, error) {
startLen := len(b)
seq, l, err := quicvarint.Parse(b)
if err != nil {
return nil, 0, replaceUnexpectedEOF(err)
}
b = b[l:]
aeth, l, err := quicvarint.Parse(b)
if err != nil {
return nil, 0, replaceUnexpectedEOF(err)
}
b = b[l:]
mad, l, err := quicvarint.Parse(b)
if err != nil {
return nil, 0, replaceUnexpectedEOF(err)
}
// prevents overflows if the peer sends a very large value
maxAckDelay := time.Duration(mad) * time.Microsecond
if maxAckDelay < 0 {
maxAckDelay = math.MaxInt64
}
b = b[l:]
rth, l, err := quicvarint.Parse(b)
if err != nil {
return nil, 0, replaceUnexpectedEOF(err)
}
b = b[l:]
return &AckFrequencyFrame{
SequenceNumber: seq,
AckElicitingThreshold: aeth,
RequestMaxAckDelay: maxAckDelay,
ReorderingThreshold: protocol.PacketNumber(rth),
}, startLen - len(b), nil
}
func (f *AckFrequencyFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
b = quicvarint.Append(b, uint64(FrameTypeAckFrequency))
b = quicvarint.Append(b, f.SequenceNumber)
b = quicvarint.Append(b, f.AckElicitingThreshold)
b = quicvarint.Append(b, uint64(f.RequestMaxAckDelay/time.Microsecond))
return quicvarint.Append(b, uint64(f.ReorderingThreshold)), nil
}
func (f *AckFrequencyFrame) Length(_ protocol.Version) protocol.ByteCount {
return protocol.ByteCount(2 + quicvarint.Len(f.SequenceNumber) + quicvarint.Len(f.AckElicitingThreshold) +
quicvarint.Len(uint64(f.RequestMaxAckDelay/time.Microsecond)) + quicvarint.Len(uint64(f.ReorderingThreshold)))
}
package wire
import "github.com/quic-go/quic-go/internal/protocol"
// AckRange is an ACK range
type AckRange struct {
Smallest protocol.PacketNumber
Largest protocol.PacketNumber
}
// Len returns the number of packets contained in this ACK range
func (r AckRange) Len() protocol.PacketNumber {
return r.Largest - r.Smallest + 1
}
package wire
import (
"io"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/quicvarint"
)
// A ConnectionCloseFrame is a CONNECTION_CLOSE frame
type ConnectionCloseFrame struct {
IsApplicationError bool
ErrorCode uint64
FrameType uint64
ReasonPhrase string
}
func parseConnectionCloseFrame(b []byte, typ FrameType, _ protocol.Version) (*ConnectionCloseFrame, int, error) {
startLen := len(b)
f := &ConnectionCloseFrame{IsApplicationError: typ == FrameTypeApplicationClose}
ec, l, err := quicvarint.Parse(b)
if err != nil {
return nil, 0, replaceUnexpectedEOF(err)
}
b = b[l:]
f.ErrorCode = ec
// read the Frame Type, if this is not an application error
if !f.IsApplicationError {
ft, l, err := quicvarint.Parse(b)
if err != nil {
return nil, 0, replaceUnexpectedEOF(err)
}
b = b[l:]
f.FrameType = ft
}
var reasonPhraseLen uint64
reasonPhraseLen, l, err = quicvarint.Parse(b)
if err != nil {
return nil, 0, replaceUnexpectedEOF(err)
}
b = b[l:]
if int(reasonPhraseLen) > len(b) {
return nil, 0, io.EOF
}
reasonPhrase := make([]byte, reasonPhraseLen)
copy(reasonPhrase, b)
f.ReasonPhrase = string(reasonPhrase)
return f, startLen - len(b) + int(reasonPhraseLen), nil
}
// Length of a written frame
func (f *ConnectionCloseFrame) Length(protocol.Version) protocol.ByteCount {
length := 1 + protocol.ByteCount(quicvarint.Len(f.ErrorCode)+quicvarint.Len(uint64(len(f.ReasonPhrase)))) + protocol.ByteCount(len(f.ReasonPhrase))
if !f.IsApplicationError {
length += protocol.ByteCount(quicvarint.Len(f.FrameType)) // for the frame type
}
return length
}
func (f *ConnectionCloseFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
if f.IsApplicationError {
b = append(b, byte(FrameTypeApplicationClose))
} else {
b = append(b, byte(FrameTypeConnectionClose))
}
b = quicvarint.Append(b, f.ErrorCode)
if !f.IsApplicationError {
b = quicvarint.Append(b, f.FrameType)
}
b = quicvarint.Append(b, uint64(len(f.ReasonPhrase)))
b = append(b, []byte(f.ReasonPhrase)...)
return b, nil
}
package wire
import (
"io"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/quicvarint"
)
// A CryptoFrame is a CRYPTO frame
type CryptoFrame struct {
Offset protocol.ByteCount
Data []byte
}
func parseCryptoFrame(b []byte, _ protocol.Version) (*CryptoFrame, int, error) {
startLen := len(b)
frame := &CryptoFrame{}
offset, l, err := quicvarint.Parse(b)
if err != nil {
return nil, 0, replaceUnexpectedEOF(err)
}
b = b[l:]
frame.Offset = protocol.ByteCount(offset)
dataLen, l, err := quicvarint.Parse(b)
if err != nil {
return nil, 0, replaceUnexpectedEOF(err)
}
b = b[l:]
if dataLen > uint64(len(b)) {
return nil, 0, io.EOF
}
if dataLen != 0 {
frame.Data = make([]byte, dataLen)
copy(frame.Data, b)
}
return frame, startLen - len(b) + int(dataLen), nil
}
func (f *CryptoFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
b = append(b, byte(FrameTypeCrypto))
b = quicvarint.Append(b, uint64(f.Offset))
b = quicvarint.Append(b, uint64(len(f.Data)))
b = append(b, f.Data...)
return b, nil
}
// Length of a written frame
func (f *CryptoFrame) Length(_ protocol.Version) protocol.ByteCount {
return protocol.ByteCount(1 + quicvarint.Len(uint64(f.Offset)) + quicvarint.Len(uint64(len(f.Data))) + len(f.Data))
}
// MaxDataLen returns the maximum data length
func (f *CryptoFrame) MaxDataLen(maxSize protocol.ByteCount) protocol.ByteCount {
// pretend that the data size will be 1 byte
// if it turns out that varint encoding the length will consume 2 bytes, we need to adjust the data length afterwards
headerLen := protocol.ByteCount(1 + quicvarint.Len(uint64(f.Offset)) + 1)
if headerLen > maxSize {
return 0
}
maxDataLen := maxSize - headerLen
if quicvarint.Len(uint64(maxDataLen)) != 1 {
maxDataLen--
}
return maxDataLen
}
// MaybeSplitOffFrame splits a frame such that it is not bigger than maxSize bytes.
// It returns whether the frame was actually split.
// The frame might not be split if:
// * the size is large enough to fit the whole frame
// * the size is too small to fit even a 1-byte frame. In that case, the frame returned is nil.
func (f *CryptoFrame) MaybeSplitOffFrame(maxSize protocol.ByteCount, version protocol.Version) (*CryptoFrame, bool /* was splitting required */) {
if f.Length(version) <= maxSize {
return nil, false
}
n := f.MaxDataLen(maxSize)
if n == 0 {
return nil, true
}
newLen := protocol.ByteCount(len(f.Data)) - n
new := &CryptoFrame{}
new.Offset = f.Offset
new.Data = make([]byte, newLen)
// swap the data slices
new.Data, f.Data = f.Data, new.Data
copy(f.Data, new.Data[n:])
new.Data = new.Data[:n]
f.Offset += n
return new, true
}
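// Illustrative sketch, not part of the original file: splitting a CRYPTO
// frame so that the first part fits into a 100-byte budget (an arbitrary
// example value). The returned frame keeps the original offset, while f's
// offset is advanced past the split-off data.
func exampleCryptoSplit(f *CryptoFrame) (first, rest *CryptoFrame) {
const maxSize = 100
if newFrame, split := f.MaybeSplitOffFrame(maxSize, protocol.Version1); split {
return newFrame, f
}
return f, nil // the frame already fits into maxSize bytes
}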
package wire
import (
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/quicvarint"
)
// A DataBlockedFrame is a DATA_BLOCKED frame
type DataBlockedFrame struct {
MaximumData protocol.ByteCount
}
func parseDataBlockedFrame(b []byte, _ protocol.Version) (*DataBlockedFrame, int, error) {
offset, l, err := quicvarint.Parse(b)
if err != nil {
return nil, 0, replaceUnexpectedEOF(err)
}
return &DataBlockedFrame{MaximumData: protocol.ByteCount(offset)}, l, nil
}
func (f *DataBlockedFrame) Append(b []byte, version protocol.Version) ([]byte, error) {
b = append(b, byte(FrameTypeDataBlocked))
return quicvarint.Append(b, uint64(f.MaximumData)), nil
}
// Length of a written frame
func (f *DataBlockedFrame) Length(version protocol.Version) protocol.ByteCount {
return 1 + protocol.ByteCount(quicvarint.Len(uint64(f.MaximumData)))
}
package wire
import (
"io"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/quicvarint"
)
// MaxDatagramSize is the maximum size of a DATAGRAM frame (RFC 9221).
// By setting it to a large value, we allow all datagrams that fit into a QUIC packet.
// The value is chosen such that it can still be encoded as a 2 byte varint.
// This is a var and not a const so it can be set in tests.
var MaxDatagramSize protocol.ByteCount = 16383
// A DatagramFrame is a DATAGRAM frame
type DatagramFrame struct {
DataLenPresent bool
Data []byte
}
func parseDatagramFrame(b []byte, typ FrameType, _ protocol.Version) (*DatagramFrame, int, error) {
startLen := len(b)
f := &DatagramFrame{}
f.DataLenPresent = uint64(typ)&0x1 > 0
var length uint64
if f.DataLenPresent {
var err error
var l int
length, l, err = quicvarint.Parse(b)
if err != nil {
return nil, 0, replaceUnexpectedEOF(err)
}
b = b[l:]
if length > uint64(len(b)) {
return nil, 0, io.EOF
}
} else {
length = uint64(len(b))
}
f.Data = make([]byte, length)
copy(f.Data, b)
return f, startLen - len(b) + int(length), nil
}
func (f *DatagramFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
typ := uint8(0x30)
if f.DataLenPresent {
typ ^= 0b1
}
b = append(b, typ)
if f.DataLenPresent {
b = quicvarint.Append(b, uint64(len(f.Data)))
}
b = append(b, f.Data...)
return b, nil
}
// MaxDataLen returns the maximum data length
func (f *DatagramFrame) MaxDataLen(maxSize protocol.ByteCount, version protocol.Version) protocol.ByteCount {
headerLen := protocol.ByteCount(1)
if f.DataLenPresent {
// pretend that the data size will be 1 byte
// if it turns out that varint encoding the length will consume 2 bytes, we need to adjust the data length afterwards
headerLen++
}
if headerLen > maxSize {
return 0
}
maxDataLen := maxSize - headerLen
if f.DataLenPresent && quicvarint.Len(uint64(maxDataLen)) != 1 {
maxDataLen--
}
return maxDataLen
}
// Length of a written frame
func (f *DatagramFrame) Length(_ protocol.Version) protocol.ByteCount {
length := 1 + protocol.ByteCount(len(f.Data))
if f.DataLenPresent {
length += protocol.ByteCount(quicvarint.Len(uint64(len(f.Data))))
}
return length
}
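// Illustrative sketch, not part of the original file: determining how much
// application data fits into a DATAGRAM frame, given that 1200 bytes (an
// arbitrary example value) remain in the packet being assembled.
func exampleDatagramBudget() protocol.ByteCount {
f := &DatagramFrame{DataLenPresent: true}
return f.MaxDataLen(1200, protocol.Version1)
}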
package wire
import (
"encoding/binary"
"errors"
"fmt"
"io"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/quicvarint"
)
// ErrInvalidReservedBits is returned when the reserved bits are incorrect.
// When this error is returned, parsing continues, and an ExtendedHeader is returned.
// This is necessary because we need to decrypt the packet in that case,
// in order to avoid a timing side-channel.
var ErrInvalidReservedBits = errors.New("invalid reserved bits")
// ExtendedHeader is the header of a QUIC packet.
type ExtendedHeader struct {
Header
typeByte byte
KeyPhase protocol.KeyPhaseBit
PacketNumberLen protocol.PacketNumberLen
PacketNumber protocol.PacketNumber
parsedLen protocol.ByteCount
}
func (h *ExtendedHeader) parse(data []byte) (bool /* reserved bits valid */, error) {
// read the (now unencrypted) first byte
h.typeByte = data[0]
h.PacketNumberLen = protocol.PacketNumberLen(h.typeByte&0x3) + 1
if protocol.ByteCount(len(data)) < h.Header.ParsedLen()+protocol.ByteCount(h.PacketNumberLen) {
return false, io.EOF
}
pn, err := readPacketNumber(data[h.Header.ParsedLen():], h.PacketNumberLen)
if err != nil {
return true, nil
}
h.PacketNumber = pn
reservedBitsValid := h.typeByte&0xc == 0
h.parsedLen = h.Header.ParsedLen() + protocol.ByteCount(h.PacketNumberLen)
return reservedBitsValid, err
}
// Append appends the Header.
func (h *ExtendedHeader) Append(b []byte, v protocol.Version) ([]byte, error) {
if h.DestConnectionID.Len() > protocol.MaxConnIDLen {
return nil, fmt.Errorf("invalid connection ID length: %d bytes", h.DestConnectionID.Len())
}
if h.SrcConnectionID.Len() > protocol.MaxConnIDLen {
return nil, fmt.Errorf("invalid connection ID length: %d bytes", h.SrcConnectionID.Len())
}
var packetType uint8
if v == protocol.Version2 {
switch h.Type {
case protocol.PacketTypeInitial:
packetType = 0b01
case protocol.PacketType0RTT:
packetType = 0b10
case protocol.PacketTypeHandshake:
packetType = 0b11
case protocol.PacketTypeRetry:
packetType = 0b00
}
} else {
switch h.Type {
case protocol.PacketTypeInitial:
packetType = 0b00
case protocol.PacketType0RTT:
packetType = 0b01
case protocol.PacketTypeHandshake:
packetType = 0b10
case protocol.PacketTypeRetry:
packetType = 0b11
}
}
firstByte := 0xc0 | packetType<<4
if h.Type != protocol.PacketTypeRetry {
// Retry packets don't have a packet number
firstByte |= uint8(h.PacketNumberLen - 1)
}
b = append(b, firstByte)
b = append(b, make([]byte, 4)...)
binary.BigEndian.PutUint32(b[len(b)-4:], uint32(h.Version))
b = append(b, uint8(h.DestConnectionID.Len()))
b = append(b, h.DestConnectionID.Bytes()...)
b = append(b, uint8(h.SrcConnectionID.Len()))
b = append(b, h.SrcConnectionID.Bytes()...)
//nolint:exhaustive
switch h.Type {
case protocol.PacketTypeRetry:
b = append(b, h.Token...)
return b, nil
case protocol.PacketTypeInitial:
b = quicvarint.Append(b, uint64(len(h.Token)))
b = append(b, h.Token...)
}
b = quicvarint.AppendWithLen(b, uint64(h.Length), 2)
return appendPacketNumber(b, h.PacketNumber, h.PacketNumberLen)
}
// ParsedLen returns the number of bytes that were consumed when parsing the header
func (h *ExtendedHeader) ParsedLen() protocol.ByteCount {
return h.parsedLen
}
// GetLength determines the length of the Header.
func (h *ExtendedHeader) GetLength(_ protocol.Version) protocol.ByteCount {
length := 1 /* type byte */ + 4 /* version */ + 1 /* dest conn ID len */ + protocol.ByteCount(h.DestConnectionID.Len()) + 1 /* src conn ID len */ + protocol.ByteCount(h.SrcConnectionID.Len()) + protocol.ByteCount(h.PacketNumberLen) + 2 /* length */
if h.Type == protocol.PacketTypeInitial {
length += protocol.ByteCount(quicvarint.Len(uint64(len(h.Token))) + len(h.Token))
}
return length
}
// Log logs the Header
func (h *ExtendedHeader) Log(logger utils.Logger) {
var token string
if h.Type == protocol.PacketTypeInitial || h.Type == protocol.PacketTypeRetry {
if len(h.Token) == 0 {
token = "Token: (empty), "
} else {
token = fmt.Sprintf("Token: %#x, ", h.Token)
}
if h.Type == protocol.PacketTypeRetry {
logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sVersion: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.Version)
return
}
}
logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sPacketNumber: %d, PacketNumberLen: %d, Length: %d, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.PacketNumber, h.PacketNumberLen, h.Length, h.Version)
}
func appendPacketNumber(b []byte, pn protocol.PacketNumber, pnLen protocol.PacketNumberLen) ([]byte, error) {
switch pnLen {
case protocol.PacketNumberLen1:
b = append(b, uint8(pn))
case protocol.PacketNumberLen2:
buf := make([]byte, 2)
binary.BigEndian.PutUint16(buf, uint16(pn))
b = append(b, buf...)
case protocol.PacketNumberLen3:
buf := make([]byte, 4)
binary.BigEndian.PutUint32(buf, uint32(pn))
b = append(b, buf[1:]...)
case protocol.PacketNumberLen4:
buf := make([]byte, 4)
binary.BigEndian.PutUint32(buf, uint32(pn))
b = append(b, buf...)
default:
return nil, fmt.Errorf("invalid packet number length: %d", pnLen)
}
return b, nil
}
package wire
import (
"github.com/quic-go/quic-go/internal/protocol"
)
// A Frame in QUIC
type Frame interface {
Append(b []byte, version protocol.Version) ([]byte, error)
Length(version protocol.Version) protocol.ByteCount
}
// IsProbingFrame returns true if the frame is a probing frame.
// See section 9.1 of RFC 9000.
func IsProbingFrame(f Frame) bool {
switch f.(type) {
case *PathChallengeFrame, *PathResponseFrame, *NewConnectionIDFrame:
return true
}
return false
}
// IsProbingFrameType returns true if the FrameType is a probing frame.
// See section 9.1 of RFC 9000.
func IsProbingFrameType(f FrameType) bool {
//nolint:exhaustive // PATH_CHALLENGE, PATH_RESPONSE and NEW_CONNECTION_ID are the only probing frames
switch f {
case FrameTypePathChallenge, FrameTypePathResponse, FrameTypeNewConnectionID:
return true
default:
return false
}
}
package wire
import (
"errors"
"fmt"
"io"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qerr"
"github.com/quic-go/quic-go/quicvarint"
)
var errUnknownFrameType = errors.New("unknown frame type")
// The FrameParser parses QUIC frames, one by one.
type FrameParser struct {
ackDelayExponent uint8
supportsDatagrams bool
supportsResetStreamAt bool
supportsAckFrequency bool
// To avoid allocating when parsing, keep a single ACK frame struct.
// It is used over and over again.
ackFrame *AckFrame
}
// NewFrameParser creates a new frame parser.
func NewFrameParser(supportsDatagrams, supportsResetStreamAt, supportsAckFrequency bool) *FrameParser {
return &FrameParser{
supportsDatagrams: supportsDatagrams,
supportsResetStreamAt: supportsResetStreamAt,
supportsAckFrequency: supportsAckFrequency,
ackFrame: &AckFrame{},
}
}
// ParseType parses the frame type of the next frame.
// It skips over PADDING frames.
func (p *FrameParser) ParseType(b []byte, encLevel protocol.EncryptionLevel) (FrameType, int, error) {
var parsed int
for len(b) != 0 {
typ, l, err := quicvarint.Parse(b)
parsed += l
if err != nil {
return 0, parsed, &qerr.TransportError{
ErrorCode: qerr.FrameEncodingError,
ErrorMessage: err.Error(),
}
}
b = b[l:]
if typ == 0x0 { // skip PADDING frames
continue
}
ft := FrameType(typ)
valid := ft.isValidRFC9000() ||
(p.supportsDatagrams && ft.IsDatagramFrameType()) ||
(p.supportsResetStreamAt && ft == FrameTypeResetStreamAt) ||
(p.supportsAckFrequency && (ft == FrameTypeAckFrequency || ft == FrameTypeImmediateAck))
if !valid {
return 0, parsed, &qerr.TransportError{
ErrorCode: qerr.FrameEncodingError,
FrameType: typ,
ErrorMessage: errUnknownFrameType.Error(),
}
}
if !ft.isAllowedAtEncLevel(encLevel) {
return 0, parsed, &qerr.TransportError{
ErrorCode: qerr.FrameEncodingError,
FrameType: typ,
ErrorMessage: fmt.Sprintf("%d not allowed at encryption level %s", ft, encLevel),
}
}
return ft, parsed, nil
}
return 0, parsed, io.EOF
}
func (p *FrameParser) ParseStreamFrame(frameType FrameType, data []byte, v protocol.Version) (*StreamFrame, int, error) {
frame, n, err := ParseStreamFrame(data, frameType, v)
if err != nil {
return nil, n, &qerr.TransportError{
ErrorCode: qerr.FrameEncodingError,
FrameType: uint64(frameType),
ErrorMessage: err.Error(),
}
}
return frame, n, nil
}
func (p *FrameParser) ParseAckFrame(frameType FrameType, data []byte, encLevel protocol.EncryptionLevel, v protocol.Version) (*AckFrame, int, error) {
ackDelayExponent := p.ackDelayExponent
if encLevel != protocol.Encryption1RTT {
ackDelayExponent = protocol.DefaultAckDelayExponent
}
p.ackFrame.Reset()
l, err := parseAckFrame(p.ackFrame, data, frameType, ackDelayExponent, v)
if err != nil {
return nil, l, &qerr.TransportError{
ErrorCode: qerr.FrameEncodingError,
FrameType: uint64(frameType),
ErrorMessage: err.Error(),
}
}
return p.ackFrame, l, nil
}
func (p *FrameParser) ParseDatagramFrame(frameType FrameType, data []byte, v protocol.Version) (*DatagramFrame, int, error) {
f, l, err := parseDatagramFrame(data, frameType, v)
if err != nil {
return nil, 0, &qerr.TransportError{
ErrorCode: qerr.FrameEncodingError,
FrameType: uint64(frameType),
ErrorMessage: err.Error(),
}
}
return f, l, nil
}
// ParseLessCommonFrame parses everything except STREAM, ACK or DATAGRAM.
// These cases should be handled separately for performance reasons.
func (p *FrameParser) ParseLessCommonFrame(frameType FrameType, data []byte, v protocol.Version) (Frame, int, error) {
var frame Frame
var l int
var err error
//nolint:exhaustive // Common frames should already be handled.
switch frameType {
case FrameTypePing:
frame = &PingFrame{}
case FrameTypeResetStream:
frame, l, err = parseResetStreamFrame(data, false, v)
case FrameTypeStopSending:
frame, l, err = parseStopSendingFrame(data, v)
case FrameTypeCrypto:
frame, l, err = parseCryptoFrame(data, v)
case FrameTypeNewToken:
frame, l, err = parseNewTokenFrame(data, v)
case FrameTypeMaxData:
frame, l, err = parseMaxDataFrame(data, v)
case FrameTypeMaxStreamData:
frame, l, err = parseMaxStreamDataFrame(data, v)
case FrameTypeBidiMaxStreams, FrameTypeUniMaxStreams:
frame, l, err = parseMaxStreamsFrame(data, frameType, v)
case FrameTypeDataBlocked:
frame, l, err = parseDataBlockedFrame(data, v)
case FrameTypeStreamDataBlocked:
frame, l, err = parseStreamDataBlockedFrame(data, v)
case FrameTypeBidiStreamBlocked, FrameTypeUniStreamBlocked:
frame, l, err = parseStreamsBlockedFrame(data, frameType, v)
case FrameTypeNewConnectionID:
frame, l, err = parseNewConnectionIDFrame(data, v)
case FrameTypeRetireConnectionID:
frame, l, err = parseRetireConnectionIDFrame(data, v)
case FrameTypePathChallenge:
frame, l, err = parsePathChallengeFrame(data, v)
case FrameTypePathResponse:
frame, l, err = parsePathResponseFrame(data, v)
case FrameTypeConnectionClose, FrameTypeApplicationClose:
frame, l, err = parseConnectionCloseFrame(data, frameType, v)
case FrameTypeHandshakeDone:
frame = &HandshakeDoneFrame{}
case FrameTypeResetStreamAt:
frame, l, err = parseResetStreamFrame(data, true, v)
case FrameTypeAckFrequency:
frame, l, err = parseAckFrequencyFrame(data, v)
case FrameTypeImmediateAck:
frame = &ImmediateAckFrame{}
default:
err = errUnknownFrameType
}
if err != nil {
return frame, l, &qerr.TransportError{
ErrorCode: qerr.FrameEncodingError,
FrameType: uint64(frameType),
ErrorMessage: err.Error(),
}
}
return frame, l, err
}
// SetAckDelayExponent sets the acknowledgment delay exponent (sent in the transport parameters).
// This value is used to scale the ACK Delay field in the ACK frame.
func (p *FrameParser) SetAckDelayExponent(exp uint8) {
p.ackDelayExponent = exp
}
func replaceUnexpectedEOF(e error) error {
if e == io.ErrUnexpectedEOF {
return io.EOF
}
return e
}
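// Illustrative sketch, not part of the original file: a minimal loop that
// parses all frames in a decrypted packet payload, dispatching STREAM, ACK
// and DATAGRAM frames to their dedicated parsing methods. Note that the ACK
// frame returned by ParseAckFrame is reused on the next call, so a real
// caller must process it before parsing further frames.
func exampleParseAllFrames(p *FrameParser, payload []byte, encLevel protocol.EncryptionLevel, v protocol.Version) ([]Frame, error) {
frames := make([]Frame, 0, 4)
for len(payload) > 0 {
typ, n, err := p.ParseType(payload, encLevel)
if err != nil {
if err == io.EOF { // only PADDING frames were left
return frames, nil
}
return nil, err
}
payload = payload[n:]
var frame Frame
var l int
switch {
case typ.IsStreamFrameType():
frame, l, err = p.ParseStreamFrame(typ, payload, v)
case typ.IsAckFrameType():
frame, l, err = p.ParseAckFrame(typ, payload, encLevel, v)
case typ.IsDatagramFrameType():
frame, l, err = p.ParseDatagramFrame(typ, payload, v)
default:
frame, l, err = p.ParseLessCommonFrame(typ, payload, v)
}
if err != nil {
return nil, err
}
payload = payload[l:]
frames = append(frames, frame)
}
return frames, nil
}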
package wire
import "github.com/quic-go/quic-go/internal/protocol"
type FrameType uint64
// These constants correspond to those defined in RFC 9000.
// Stream frame types are not listed explicitly here; use FrameType.IsStreamFrameType() to identify them.
const (
FrameTypePing FrameType = 0x1
FrameTypeAck FrameType = 0x2
FrameTypeAckECN FrameType = 0x3
FrameTypeResetStream FrameType = 0x4
FrameTypeStopSending FrameType = 0x5
FrameTypeCrypto FrameType = 0x6
FrameTypeNewToken FrameType = 0x7
FrameTypeMaxData FrameType = 0x10
FrameTypeMaxStreamData FrameType = 0x11
FrameTypeBidiMaxStreams FrameType = 0x12
FrameTypeUniMaxStreams FrameType = 0x13
FrameTypeDataBlocked FrameType = 0x14
FrameTypeStreamDataBlocked FrameType = 0x15
FrameTypeBidiStreamBlocked FrameType = 0x16
FrameTypeUniStreamBlocked FrameType = 0x17
FrameTypeNewConnectionID FrameType = 0x18
FrameTypeRetireConnectionID FrameType = 0x19
FrameTypePathChallenge FrameType = 0x1a
FrameTypePathResponse FrameType = 0x1b
FrameTypeConnectionClose FrameType = 0x1c
FrameTypeApplicationClose FrameType = 0x1d
FrameTypeHandshakeDone FrameType = 0x1e
// https://datatracker.ietf.org/doc/draft-ietf-quic-reliable-stream-reset/07/
FrameTypeResetStreamAt FrameType = 0x24
// https://datatracker.ietf.org/doc/draft-ietf-quic-ack-frequency/11/
FrameTypeAckFrequency FrameType = 0xaf
FrameTypeImmediateAck FrameType = 0x1f
FrameTypeDatagramNoLength FrameType = 0x30
FrameTypeDatagramWithLength FrameType = 0x31
)
func (t FrameType) IsStreamFrameType() bool {
return t >= 0x8 && t <= 0xf
}
func (t FrameType) isValidRFC9000() bool {
return t <= 0x1e
}
func (t FrameType) IsAckFrameType() bool {
return t == FrameTypeAck || t == FrameTypeAckECN
}
func (t FrameType) IsDatagramFrameType() bool {
return t == FrameTypeDatagramNoLength || t == FrameTypeDatagramWithLength
}
func (t FrameType) isAllowedAtEncLevel(encLevel protocol.EncryptionLevel) bool {
//nolint:exhaustive
switch encLevel {
case protocol.EncryptionInitial, protocol.EncryptionHandshake:
switch t {
case FrameTypeCrypto, FrameTypeAck, FrameTypeAckECN, FrameTypeConnectionClose, FrameTypePing:
return true
default:
return false
}
case protocol.Encryption0RTT:
switch t {
case FrameTypeCrypto, FrameTypeAck, FrameTypeAckECN, FrameTypeConnectionClose, FrameTypeNewToken, FrameTypePathResponse, FrameTypeRetireConnectionID:
return false
default:
return true
}
case protocol.Encryption1RTT:
return true
default:
panic("unknown encryption level")
}
}
package wire
import (
"github.com/quic-go/quic-go/internal/protocol"
)
// A HandshakeDoneFrame is a HANDSHAKE_DONE frame
type HandshakeDoneFrame struct{}
func (f *HandshakeDoneFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
return append(b, byte(FrameTypeHandshakeDone)), nil
}
// Length of a written frame
func (f *HandshakeDoneFrame) Length(_ protocol.Version) protocol.ByteCount {
return 1
}
package wire
import (
"encoding/binary"
"errors"
"fmt"
"io"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/quicvarint"
)
// ParseConnectionID parses the destination connection ID of a packet.
func ParseConnectionID(data []byte, shortHeaderConnIDLen int) (protocol.ConnectionID, error) {
if len(data) == 0 {
return protocol.ConnectionID{}, io.EOF
}
if !IsLongHeaderPacket(data[0]) {
if len(data) < shortHeaderConnIDLen+1 {
return protocol.ConnectionID{}, io.EOF
}
return protocol.ParseConnectionID(data[1 : 1+shortHeaderConnIDLen]), nil
}
if len(data) < 6 {
return protocol.ConnectionID{}, io.EOF
}
destConnIDLen := int(data[5])
if destConnIDLen > protocol.MaxConnIDLen {
return protocol.ConnectionID{}, protocol.ErrInvalidConnectionIDLen
}
if len(data) < 6+destConnIDLen {
return protocol.ConnectionID{}, io.EOF
}
return protocol.ParseConnectionID(data[6 : 6+destConnIDLen]), nil
}
// ParseArbitraryLenConnectionIDs parses the most general form of a Long Header packet,
// using only the version-independent packet format as described in Section 5.1 of RFC 8999:
// https://datatracker.ietf.org/doc/html/rfc8999#section-5.1.
// This function should only be called on Long Header packets for which we don't support the version.
func ParseArbitraryLenConnectionIDs(data []byte) (bytesParsed int, dest, src protocol.ArbitraryLenConnectionID, _ error) {
startLen := len(data)
if len(data) < 6 {
return 0, nil, nil, io.EOF
}
data = data[5:] // skip first byte and version field
destConnIDLen := data[0]
data = data[1:]
destConnID := make(protocol.ArbitraryLenConnectionID, destConnIDLen)
if len(data) < int(destConnIDLen)+1 {
return 0, nil, nil, io.EOF
}
copy(destConnID, data)
data = data[destConnIDLen:]
srcConnIDLen := data[0]
data = data[1:]
if len(data) < int(srcConnIDLen) {
return 0, nil, nil, io.EOF
}
srcConnID := make(protocol.ArbitraryLenConnectionID, srcConnIDLen)
copy(srcConnID, data)
return startLen - len(data) + int(srcConnIDLen), destConnID, srcConnID, nil
}
func IsPotentialQUICPacket(firstByte byte) bool {
return firstByte&0x40 > 0
}
// IsLongHeaderPacket says if this is a Long Header packet
func IsLongHeaderPacket(firstByte byte) bool {
return firstByte&0x80 > 0
}
// ParseVersion parses the QUIC version.
// It should only be called for Long Header packets (Short Header packets don't contain a version number).
func ParseVersion(data []byte) (protocol.Version, error) {
if len(data) < 5 {
return 0, io.EOF
}
return protocol.Version(binary.BigEndian.Uint32(data[1:5])), nil
}
// IsVersionNegotiationPacket says if this is a version negotiation packet
func IsVersionNegotiationPacket(b []byte) bool {
if len(b) < 5 {
return false
}
return IsLongHeaderPacket(b[0]) && b[1] == 0 && b[2] == 0 && b[3] == 0 && b[4] == 0
}
// Is0RTTPacket says if this is a 0-RTT packet.
// A packet sent with a version we don't understand can never be a 0-RTT packet.
func Is0RTTPacket(b []byte) bool {
if len(b) < 5 {
return false
}
if !IsLongHeaderPacket(b[0]) {
return false
}
version := protocol.Version(binary.BigEndian.Uint32(b[1:5]))
//nolint:exhaustive // We only need to test QUIC versions that we support.
switch version {
case protocol.Version1:
return b[0]>>4&0b11 == 0b01
case protocol.Version2:
return b[0]>>4&0b11 == 0b10
default:
return false
}
}
var ErrUnsupportedVersion = errors.New("unsupported version")
// The Header is the version independent part of the header
type Header struct {
typeByte byte
Type protocol.PacketType
Version protocol.Version
SrcConnectionID protocol.ConnectionID
DestConnectionID protocol.ConnectionID
Length protocol.ByteCount
Token []byte
parsedLen protocol.ByteCount // how many bytes were read while parsing this header
}
// ParsePacket parses a long header packet.
// The packet is cut according to the length field.
// If we understand the version, the packet is parsed up to the packet number.
// Otherwise, only the invariant part of the header is parsed.
func ParsePacket(data []byte) (*Header, []byte, []byte, error) {
if len(data) == 0 || !IsLongHeaderPacket(data[0]) {
return nil, nil, nil, errors.New("not a long header packet")
}
hdr, err := parseHeader(data)
if err != nil {
if errors.Is(err, ErrUnsupportedVersion) {
return hdr, nil, nil, err
}
return nil, nil, nil, err
}
if protocol.ByteCount(len(data)) < hdr.ParsedLen()+hdr.Length {
return nil, nil, nil, fmt.Errorf("packet length (%d bytes) is smaller than the expected length (%d bytes)", len(data)-int(hdr.ParsedLen()), hdr.Length)
}
packetLen := int(hdr.ParsedLen() + hdr.Length)
return hdr, data[:packetLen], data[packetLen:], nil
}
// parseHeader parses the header:
// * if we understand the version: up to the packet number
// * if not, only the invariant part of the header
func parseHeader(b []byte) (*Header, error) {
if len(b) == 0 {
return nil, io.EOF
}
typeByte := b[0]
h := &Header{typeByte: typeByte}
l, err := h.parseLongHeader(b[1:])
h.parsedLen = protocol.ByteCount(l) + 1
return h, err
}
func (h *Header) parseLongHeader(b []byte) (int, error) {
startLen := len(b)
if len(b) < 5 {
return 0, io.EOF
}
h.Version = protocol.Version(binary.BigEndian.Uint32(b[:4]))
if h.Version != 0 && h.typeByte&0x40 == 0 {
return startLen - len(b), errors.New("not a QUIC packet")
}
destConnIDLen := int(b[4])
if destConnIDLen > protocol.MaxConnIDLen {
return startLen - len(b), protocol.ErrInvalidConnectionIDLen
}
b = b[5:]
if len(b) < destConnIDLen+1 {
return startLen - len(b), io.EOF
}
h.DestConnectionID = protocol.ParseConnectionID(b[:destConnIDLen])
srcConnIDLen := int(b[destConnIDLen])
if srcConnIDLen > protocol.MaxConnIDLen {
return startLen - len(b), protocol.ErrInvalidConnectionIDLen
}
b = b[destConnIDLen+1:]
if len(b) < srcConnIDLen {
return startLen - len(b), io.EOF
}
h.SrcConnectionID = protocol.ParseConnectionID(b[:srcConnIDLen])
b = b[srcConnIDLen:]
if h.Version == 0 { // version negotiation packet
return startLen - len(b), nil
}
// If we don't understand the version, we have no idea how to interpret the rest of the bytes
if !protocol.IsSupportedVersion(protocol.SupportedVersions, h.Version) {
return startLen - len(b), ErrUnsupportedVersion
}
if h.Version == protocol.Version2 {
switch h.typeByte >> 4 & 0b11 {
case 0b00:
h.Type = protocol.PacketTypeRetry
case 0b01:
h.Type = protocol.PacketTypeInitial
case 0b10:
h.Type = protocol.PacketType0RTT
case 0b11:
h.Type = protocol.PacketTypeHandshake
}
} else {
switch h.typeByte >> 4 & 0b11 {
case 0b00:
h.Type = protocol.PacketTypeInitial
case 0b01:
h.Type = protocol.PacketType0RTT
case 0b10:
h.Type = protocol.PacketTypeHandshake
case 0b11:
h.Type = protocol.PacketTypeRetry
}
}
if h.Type == protocol.PacketTypeRetry {
tokenLen := len(b) - 16
if tokenLen <= 0 {
return startLen - len(b), io.EOF
}
h.Token = make([]byte, tokenLen)
copy(h.Token, b[:tokenLen])
return startLen - len(b) + tokenLen + 16, nil
}
if h.Type == protocol.PacketTypeInitial {
tokenLen, n, err := quicvarint.Parse(b)
if err != nil {
return startLen - len(b), err
}
b = b[n:]
if tokenLen > uint64(len(b)) {
return startLen - len(b), io.EOF
}
h.Token = make([]byte, tokenLen)
copy(h.Token, b[:tokenLen])
b = b[tokenLen:]
}
pl, n, err := quicvarint.Parse(b)
if err != nil {
return 0, err
}
h.Length = protocol.ByteCount(pl)
return startLen - len(b) + n, nil
}
// ParsedLen returns the number of bytes that were consumed when parsing the header
func (h *Header) ParsedLen() protocol.ByteCount {
return h.parsedLen
}
// ParseExtended parses the version dependent part of the header.
// The data slice must start at the first byte of the (long) header.
func (h *Header) ParseExtended(data []byte) (*ExtendedHeader, error) {
extHdr := h.toExtendedHeader()
reservedBitsValid, err := extHdr.parse(data)
if err != nil {
return nil, err
}
if !reservedBitsValid {
return extHdr, ErrInvalidReservedBits
}
return extHdr, nil
}
func (h *Header) toExtendedHeader() *ExtendedHeader {
return &ExtendedHeader{Header: *h}
}
// PacketType is the type of the packet, for logging purposes
func (h *Header) PacketType() string {
return h.Type.String()
}
func readPacketNumber(data []byte, pnLen protocol.PacketNumberLen) (protocol.PacketNumber, error) {
var pn protocol.PacketNumber
switch pnLen {
case protocol.PacketNumberLen1:
pn = protocol.PacketNumber(data[0])
case protocol.PacketNumberLen2:
pn = protocol.PacketNumber(binary.BigEndian.Uint16(data[:2]))
case protocol.PacketNumberLen3:
pn = protocol.PacketNumber(uint32(data[2]) + uint32(data[1])<<8 + uint32(data[0])<<16)
case protocol.PacketNumberLen4:
pn = protocol.PacketNumber(binary.BigEndian.Uint32(data[:4]))
default:
return 0, fmt.Errorf("invalid packet number length: %d", pnLen)
}
return pn, nil
}
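// Illustrative sketch, not part of the original file: splitting a datagram
// that coalesces several long header packets by repeatedly calling
// ParsePacket, which cuts each packet at its Length field. This assumes that
// every packet in the datagram has a long header; a trailing short header
// (1-RTT) packet would make ParsePacket return an error.
func exampleSplitCoalescedPackets(data []byte) ([]*Header, error) {
var hdrs []*Header
for len(data) > 0 {
hdr, _, rest, err := ParsePacket(data)
if err != nil {
return nil, err
}
hdrs = append(hdrs, hdr)
data = rest
}
return hdrs, nil
}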
package wire
import (
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/quicvarint"
)
// An ImmediateAckFrame is an IMMEDIATE_ACK frame
type ImmediateAckFrame struct{}
func (f *ImmediateAckFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
return quicvarint.Append(b, uint64(FrameTypeImmediateAck)), nil
}
// Length of a written frame
func (f *ImmediateAckFrame) Length(_ protocol.Version) protocol.ByteCount {
return protocol.ByteCount(quicvarint.Len(uint64(FrameTypeImmediateAck)))
}
package wire
import (
"fmt"
"strings"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils"
)
// LogFrame logs a frame, either sent or received
func LogFrame(logger utils.Logger, frame Frame, sent bool) {
if !logger.Debug() {
return
}
dir := "<-"
if sent {
dir = "->"
}
switch f := frame.(type) {
case *CryptoFrame:
dataLen := protocol.ByteCount(len(f.Data))
logger.Debugf("\t%s &wire.CryptoFrame{Offset: %d, Data length: %d, Offset + Data length: %d}", dir, f.Offset, dataLen, f.Offset+dataLen)
case *StreamFrame:
logger.Debugf("\t%s &wire.StreamFrame{StreamID: %d, Fin: %t, Offset: %d, Data length: %d, Offset + Data length: %d}", dir, f.StreamID, f.Fin, f.Offset, f.DataLen(), f.Offset+f.DataLen())
case *ResetStreamFrame:
logger.Debugf("\t%s &wire.ResetStreamFrame{StreamID: %d, ErrorCode: %#x, FinalSize: %d}", dir, f.StreamID, f.ErrorCode, f.FinalSize)
case *AckFrame:
hasECN := f.ECT0 > 0 || f.ECT1 > 0 || f.ECNCE > 0
var ecn string
if hasECN {
ecn = fmt.Sprintf(", ECT0: %d, ECT1: %d, CE: %d", f.ECT0, f.ECT1, f.ECNCE)
}
if len(f.AckRanges) > 1 {
ackRanges := make([]string, len(f.AckRanges))
for i, r := range f.AckRanges {
ackRanges[i] = fmt.Sprintf("{Largest: %d, Smallest: %d}", r.Largest, r.Smallest)
}
logger.Debugf("\t%s &wire.AckFrame{LargestAcked: %d, LowestAcked: %d, AckRanges: {%s}, DelayTime: %s%s}", dir, f.LargestAcked(), f.LowestAcked(), strings.Join(ackRanges, ", "), f.DelayTime.String(), ecn)
} else {
logger.Debugf("\t%s &wire.AckFrame{LargestAcked: %d, LowestAcked: %d, DelayTime: %s%s}", dir, f.LargestAcked(), f.LowestAcked(), f.DelayTime.String(), ecn)
}
case *MaxDataFrame:
logger.Debugf("\t%s &wire.MaxDataFrame{MaximumData: %d}", dir, f.MaximumData)
case *MaxStreamDataFrame:
logger.Debugf("\t%s &wire.MaxStreamDataFrame{StreamID: %d, MaximumStreamData: %d}", dir, f.StreamID, f.MaximumStreamData)
case *DataBlockedFrame:
logger.Debugf("\t%s &wire.DataBlockedFrame{MaximumData: %d}", dir, f.MaximumData)
case *StreamDataBlockedFrame:
logger.Debugf("\t%s &wire.StreamDataBlockedFrame{StreamID: %d, MaximumStreamData: %d}", dir, f.StreamID, f.MaximumStreamData)
case *MaxStreamsFrame:
switch f.Type {
case protocol.StreamTypeUni:
logger.Debugf("\t%s &wire.MaxStreamsFrame{Type: uni, MaxStreamNum: %d}", dir, f.MaxStreamNum)
case protocol.StreamTypeBidi:
logger.Debugf("\t%s &wire.MaxStreamsFrame{Type: bidi, MaxStreamNum: %d}", dir, f.MaxStreamNum)
}
case *StreamsBlockedFrame:
switch f.Type {
case protocol.StreamTypeUni:
logger.Debugf("\t%s &wire.StreamsBlockedFrame{Type: uni, MaxStreams: %d}", dir, f.StreamLimit)
case protocol.StreamTypeBidi:
logger.Debugf("\t%s &wire.StreamsBlockedFrame{Type: bidi, MaxStreams: %d}", dir, f.StreamLimit)
}
case *NewConnectionIDFrame:
logger.Debugf("\t%s &wire.NewConnectionIDFrame{SequenceNumber: %d, RetirePriorTo: %d, ConnectionID: %s, StatelessResetToken: %#x}", dir, f.SequenceNumber, f.RetirePriorTo, f.ConnectionID, f.StatelessResetToken)
case *RetireConnectionIDFrame:
logger.Debugf("\t%s &wire.RetireConnectionIDFrame{SequenceNumber: %d}", dir, f.SequenceNumber)
case *NewTokenFrame:
logger.Debugf("\t%s &wire.NewTokenFrame{Token: %#x}", dir, f.Token)
default:
logger.Debugf("\t%s %#v", dir, frame)
}
}
package wire
import (
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/quicvarint"
)
// A MaxDataFrame carries flow control information for the connection
type MaxDataFrame struct {
MaximumData protocol.ByteCount
}
// parseMaxDataFrame parses a MAX_DATA frame
func parseMaxDataFrame(b []byte, _ protocol.Version) (*MaxDataFrame, int, error) {
frame := &MaxDataFrame{}
byteOffset, l, err := quicvarint.Parse(b)
if err != nil {
return nil, 0, replaceUnexpectedEOF(err)
}
frame.MaximumData = protocol.ByteCount(byteOffset)
return frame, l, nil
}
func (f *MaxDataFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
b = append(b, byte(FrameTypeMaxData))
b = quicvarint.Append(b, uint64(f.MaximumData))
return b, nil
}
// Length of a written frame
func (f *MaxDataFrame) Length(_ protocol.Version) protocol.ByteCount {
return 1 + protocol.ByteCount(quicvarint.Len(uint64(f.MaximumData)))
}
package wire
import (
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/quicvarint"
)
// A MaxStreamDataFrame is a MAX_STREAM_DATA frame
type MaxStreamDataFrame struct {
StreamID protocol.StreamID
MaximumStreamData protocol.ByteCount
}
func parseMaxStreamDataFrame(b []byte, _ protocol.Version) (*MaxStreamDataFrame, int, error) {
startLen := len(b)
sid, l, err := quicvarint.Parse(b)
if err != nil {
return nil, 0, replaceUnexpectedEOF(err)
}
b = b[l:]
offset, l, err := quicvarint.Parse(b)
if err != nil {
return nil, 0, replaceUnexpectedEOF(err)
}
b = b[l:]
return &MaxStreamDataFrame{
StreamID: protocol.StreamID(sid),
MaximumStreamData: protocol.ByteCount(offset),
}, startLen - len(b), nil
}
func (f *MaxStreamDataFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
b = append(b, byte(FrameTypeMaxStreamData))
b = quicvarint.Append(b, uint64(f.StreamID))
b = quicvarint.Append(b, uint64(f.MaximumStreamData))
return b, nil
}
// Length of a written frame
func (f *MaxStreamDataFrame) Length(protocol.Version) protocol.ByteCount {
return 1 + protocol.ByteCount(quicvarint.Len(uint64(f.StreamID))+quicvarint.Len(uint64(f.MaximumStreamData)))
}
package wire
import (
"fmt"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/quicvarint"
)
// A MaxStreamsFrame is a MAX_STREAMS frame
type MaxStreamsFrame struct {
Type protocol.StreamType
MaxStreamNum protocol.StreamNum
}
func parseMaxStreamsFrame(b []byte, typ FrameType, _ protocol.Version) (*MaxStreamsFrame, int, error) {
f := &MaxStreamsFrame{}
//nolint:exhaustive // Function will only be called with FrameTypeBidiMaxStreams or FrameTypeUniMaxStreams
switch typ {
case FrameTypeBidiMaxStreams:
f.Type = protocol.StreamTypeBidi
case FrameTypeUniMaxStreams:
f.Type = protocol.StreamTypeUni
}
streamID, l, err := quicvarint.Parse(b)
if err != nil {
return nil, 0, replaceUnexpectedEOF(err)
}
f.MaxStreamNum = protocol.StreamNum(streamID)
if f.MaxStreamNum > protocol.MaxStreamCount {
return nil, 0, fmt.Errorf("%d exceeds the maximum stream count", f.MaxStreamNum)
}
return f, l, nil
}
func (f *MaxStreamsFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
switch f.Type {
case protocol.StreamTypeBidi:
b = append(b, byte(FrameTypeBidiMaxStreams))
case protocol.StreamTypeUni:
b = append(b, byte(FrameTypeUniMaxStreams))
}
b = quicvarint.Append(b, uint64(f.MaxStreamNum))
return b, nil
}
// Length of a written frame
func (f *MaxStreamsFrame) Length(protocol.Version) protocol.ByteCount {
return 1 + protocol.ByteCount(quicvarint.Len(uint64(f.MaxStreamNum)))
}
package wire
import (
"errors"
"fmt"
"io"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/quicvarint"
)
// A NewConnectionIDFrame is a NEW_CONNECTION_ID frame
type NewConnectionIDFrame struct {
SequenceNumber uint64
RetirePriorTo uint64
ConnectionID protocol.ConnectionID
StatelessResetToken protocol.StatelessResetToken
}
func parseNewConnectionIDFrame(b []byte, _ protocol.Version) (*NewConnectionIDFrame, int, error) {
startLen := len(b)
seq, l, err := quicvarint.Parse(b)
if err != nil {
return nil, 0, replaceUnexpectedEOF(err)
}
b = b[l:]
ret, l, err := quicvarint.Parse(b)
if err != nil {
return nil, 0, replaceUnexpectedEOF(err)
}
b = b[l:]
if ret > seq {
//nolint:staticcheck // ST1005: Retire Prior To is the name of the field
return nil, 0, fmt.Errorf("Retire Prior To value (%d) larger than Sequence Number (%d)", ret, seq)
}
if len(b) == 0 {
return nil, 0, io.EOF
}
connIDLen := int(b[0])
b = b[1:]
if connIDLen == 0 {
return nil, 0, errors.New("invalid zero-length connection ID")
}
if connIDLen > protocol.MaxConnIDLen {
return nil, 0, protocol.ErrInvalidConnectionIDLen
}
if len(b) < connIDLen {
return nil, 0, io.EOF
}
frame := &NewConnectionIDFrame{
SequenceNumber: seq,
RetirePriorTo: ret,
ConnectionID: protocol.ParseConnectionID(b[:connIDLen]),
}
b = b[connIDLen:]
if len(b) < len(frame.StatelessResetToken) {
return nil, 0, io.EOF
}
copy(frame.StatelessResetToken[:], b)
return frame, startLen - len(b) + len(frame.StatelessResetToken), nil
}
func (f *NewConnectionIDFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
b = append(b, byte(FrameTypeNewConnectionID))
b = quicvarint.Append(b, f.SequenceNumber)
b = quicvarint.Append(b, f.RetirePriorTo)
connIDLen := f.ConnectionID.Len()
if connIDLen > protocol.MaxConnIDLen {
return nil, fmt.Errorf("invalid connection ID length: %d", connIDLen)
}
b = append(b, uint8(connIDLen))
b = append(b, f.ConnectionID.Bytes()...)
b = append(b, f.StatelessResetToken[:]...)
return b, nil
}
// Length of a written frame
func (f *NewConnectionIDFrame) Length(protocol.Version) protocol.ByteCount {
return 1 + protocol.ByteCount(quicvarint.Len(f.SequenceNumber)+quicvarint.Len(f.RetirePriorTo)+1 /* connection ID length */ +f.ConnectionID.Len()) + 16
}
package wire
import (
"errors"
"io"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/quicvarint"
)
// A NewTokenFrame is a NEW_TOKEN frame
type NewTokenFrame struct {
Token []byte
}
func parseNewTokenFrame(b []byte, _ protocol.Version) (*NewTokenFrame, int, error) {
tokenLen, l, err := quicvarint.Parse(b)
if err != nil {
return nil, 0, replaceUnexpectedEOF(err)
}
b = b[l:]
if tokenLen == 0 {
return nil, 0, errors.New("token must not be empty")
}
if uint64(len(b)) < tokenLen {
return nil, 0, io.EOF
}
token := make([]byte, int(tokenLen))
copy(token, b)
return &NewTokenFrame{Token: token}, l + int(tokenLen), nil
}
func (f *NewTokenFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
b = append(b, byte(FrameTypeNewToken))
b = quicvarint.Append(b, uint64(len(f.Token)))
b = append(b, f.Token...)
return b, nil
}
// Length of a written frame
func (f *NewTokenFrame) Length(protocol.Version) protocol.ByteCount {
return 1 + protocol.ByteCount(quicvarint.Len(uint64(len(f.Token)))+len(f.Token))
}
package wire
import (
"io"
"github.com/quic-go/quic-go/internal/protocol"
)
// A PathChallengeFrame is a PATH_CHALLENGE frame
type PathChallengeFrame struct {
Data [8]byte
}
func parsePathChallengeFrame(b []byte, _ protocol.Version) (*PathChallengeFrame, int, error) {
f := &PathChallengeFrame{}
if len(b) < 8 {
return nil, 0, io.EOF
}
copy(f.Data[:], b)
return f, 8, nil
}
func (f *PathChallengeFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
b = append(b, byte(FrameTypePathChallenge))
b = append(b, f.Data[:]...)
return b, nil
}
// Length of a written frame
func (f *PathChallengeFrame) Length(_ protocol.Version) protocol.ByteCount {
return 1 + 8
}
package wire
import (
"io"
"github.com/quic-go/quic-go/internal/protocol"
)
// A PathResponseFrame is a PATH_RESPONSE frame
type PathResponseFrame struct {
Data [8]byte
}
func parsePathResponseFrame(b []byte, _ protocol.Version) (*PathResponseFrame, int, error) {
f := &PathResponseFrame{}
if len(b) < 8 {
return nil, 0, io.EOF
}
copy(f.Data[:], b)
return f, 8, nil
}
func (f *PathResponseFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
b = append(b, byte(FrameTypePathResponse))
b = append(b, f.Data[:]...)
return b, nil
}
// Length of a written frame
func (f *PathResponseFrame) Length(_ protocol.Version) protocol.ByteCount {
return 1 + 8
}
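// Illustrative sketch (assumed usage, not part of the package API): path validation
// echoes the 8 bytes of a received PATH_CHALLENGE back in a PATH_RESPONSE frame.
func examplePathResponseFor(challenge *PathChallengeFrame) *PathResponseFrame {
	return &PathResponseFrame{Data: challenge.Data}
}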
package wire
import (
"github.com/quic-go/quic-go/internal/protocol"
)
// A PingFrame is a PING frame
type PingFrame struct{}
func (f *PingFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
return append(b, byte(FrameTypePing)), nil
}
// Length of a written frame
func (f *PingFrame) Length(_ protocol.Version) protocol.ByteCount {
return 1
}
package wire
import (
"sync"
"github.com/quic-go/quic-go/internal/protocol"
)
var pool sync.Pool
func init() {
pool.New = func() interface{} {
return &StreamFrame{
Data: make([]byte, 0, protocol.MaxPacketBufferSize),
fromPool: true,
}
}
}
func GetStreamFrame() *StreamFrame {
f := pool.Get().(*StreamFrame)
return f
}
func putStreamFrame(f *StreamFrame) {
if !f.fromPool {
return
}
if protocol.ByteCount(cap(f.Data)) != protocol.MaxPacketBufferSize {
panic("wire.PutStreamFrame called with packet of wrong size!")
}
pool.Put(f)
}
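// Illustrative sketch of the intended pool usage (an assumption, not an exported
// contract): frames obtained via GetStreamFrame carry a reusable buffer with capacity
// protocol.MaxPacketBufferSize and should be returned via PutBack once processed.
func examplePooledStreamFrame(payload []byte) {
	// assumes len(payload) <= protocol.MaxPacketBufferSize
	f := GetStreamFrame()
	f.Data = f.Data[:len(payload)]
	copy(f.Data, payload)
	// ... hand the frame to the send path ...
	f.PutBack() // no-op for frames that didn't come from the pool
}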
package wire
import (
"fmt"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qerr"
"github.com/quic-go/quic-go/quicvarint"
)
// A ResetStreamFrame is a RESET_STREAM or RESET_STREAM_AT frame in QUIC
type ResetStreamFrame struct {
StreamID protocol.StreamID
ErrorCode qerr.StreamErrorCode
FinalSize protocol.ByteCount
ReliableSize protocol.ByteCount
}
func parseResetStreamFrame(b []byte, isResetStreamAt bool, _ protocol.Version) (*ResetStreamFrame, int, error) {
startLen := len(b)
streamID, l, err := quicvarint.Parse(b)
if err != nil {
return nil, 0, replaceUnexpectedEOF(err)
}
b = b[l:]
errorCode, l, err := quicvarint.Parse(b)
if err != nil {
return nil, 0, replaceUnexpectedEOF(err)
}
b = b[l:]
finalSize, l, err := quicvarint.Parse(b)
if err != nil {
return nil, 0, replaceUnexpectedEOF(err)
}
b = b[l:]
var reliableSize uint64
if isResetStreamAt {
reliableSize, l, err = quicvarint.Parse(b)
if err != nil {
return nil, 0, replaceUnexpectedEOF(err)
}
b = b[l:]
}
if reliableSize > finalSize {
return nil, 0, fmt.Errorf("RESET_STREAM_AT: reliable size can't be larger than final size (%d vs %d)", reliableSize, finalSize)
}
return &ResetStreamFrame{
StreamID: protocol.StreamID(streamID),
ErrorCode: qerr.StreamErrorCode(errorCode),
FinalSize: protocol.ByteCount(finalSize),
ReliableSize: protocol.ByteCount(reliableSize),
}, startLen - len(b), nil
}
func (f *ResetStreamFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
if f.ReliableSize == 0 {
b = quicvarint.Append(b, uint64(FrameTypeResetStream))
} else {
b = quicvarint.Append(b, uint64(FrameTypeResetStreamAt))
}
b = quicvarint.Append(b, uint64(f.StreamID))
b = quicvarint.Append(b, uint64(f.ErrorCode))
b = quicvarint.Append(b, uint64(f.FinalSize))
if f.ReliableSize > 0 {
b = quicvarint.Append(b, uint64(f.ReliableSize))
}
return b, nil
}
// Length of a written frame
func (f *ResetStreamFrame) Length(protocol.Version) protocol.ByteCount {
size := 1 // the frame type for both RESET_STREAM and RESET_STREAM_AT fits into 1 byte
if f.ReliableSize > 0 {
size += quicvarint.Len(uint64(f.ReliableSize))
}
return protocol.ByteCount(size + quicvarint.Len(uint64(f.StreamID)) + quicvarint.Len(uint64(f.ErrorCode)) + quicvarint.Len(uint64(f.FinalSize)))
}
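// Illustrative sketch with assumed values: whether Append writes a RESET_STREAM or a
// RESET_STREAM_AT frame is determined solely by ReliableSize. With ReliableSize == 0
// (as below), the frame encodes as RESET_STREAM; a non-zero ReliableSize switches the
// encoding to RESET_STREAM_AT and appends the additional Reliable Size field.
func exampleResetStream() ([]byte, error) {
	f := &ResetStreamFrame{
		StreamID:  4,
		ErrorCode: 0x42,
		FinalSize: 1000,
	}
	return f.Append(nil, protocol.Version1)
}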
package wire
import (
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/quicvarint"
)
// A RetireConnectionIDFrame is a RETIRE_CONNECTION_ID frame
type RetireConnectionIDFrame struct {
SequenceNumber uint64
}
func parseRetireConnectionIDFrame(b []byte, _ protocol.Version) (*RetireConnectionIDFrame, int, error) {
seq, l, err := quicvarint.Parse(b)
if err != nil {
return nil, 0, replaceUnexpectedEOF(err)
}
return &RetireConnectionIDFrame{SequenceNumber: seq}, l, nil
}
func (f *RetireConnectionIDFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
b = append(b, byte(FrameTypeRetireConnectionID))
b = quicvarint.Append(b, f.SequenceNumber)
return b, nil
}
// Length of a written frame
func (f *RetireConnectionIDFrame) Length(protocol.Version) protocol.ByteCount {
return 1 + protocol.ByteCount(quicvarint.Len(f.SequenceNumber))
}
package wire
import (
"errors"
"io"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils"
)
// ParseShortHeader parses a short header packet.
// It must be called after header protection has been removed.
// Otherwise, the check for the reserved bits will (most likely) fail.
func ParseShortHeader(data []byte, connIDLen int) (length int, _ protocol.PacketNumber, _ protocol.PacketNumberLen, _ protocol.KeyPhaseBit, _ error) {
if len(data) == 0 {
return 0, 0, 0, 0, io.EOF
}
if data[0]&0x80 > 0 {
return 0, 0, 0, 0, errors.New("not a short header packet")
}
if data[0]&0x40 == 0 {
return 0, 0, 0, 0, errors.New("not a QUIC packet")
}
pnLen := protocol.PacketNumberLen(data[0]&0b11) + 1
if len(data) < 1+int(pnLen)+connIDLen {
return 0, 0, 0, 0, io.EOF
}
pos := 1 + connIDLen
pn, err := readPacketNumber(data[pos:], pnLen)
if err != nil {
return 0, 0, 0, 0, err
}
kp := protocol.KeyPhaseZero
if data[0]&0b100 > 0 {
kp = protocol.KeyPhaseOne
}
if data[0]&0x18 != 0 {
err = ErrInvalidReservedBits
}
return 1 + connIDLen + int(pnLen), pn, pnLen, kp, err
}
// AppendShortHeader writes a short header.
func AppendShortHeader(b []byte, connID protocol.ConnectionID, pn protocol.PacketNumber, pnLen protocol.PacketNumberLen, kp protocol.KeyPhaseBit) ([]byte, error) {
typeByte := 0x40 | uint8(pnLen-1)
if kp == protocol.KeyPhaseOne {
typeByte |= byte(1 << 2)
}
b = append(b, typeByte)
b = append(b, connID.Bytes()...)
return appendPacketNumber(b, pn, pnLen)
}
func ShortHeaderLen(dest protocol.ConnectionID, pnLen protocol.PacketNumberLen) protocol.ByteCount {
return 1 + protocol.ByteCount(dest.Len()) + protocol.ByteCount(pnLen)
}
func LogShortHeader(logger utils.Logger, dest protocol.ConnectionID, pn protocol.PacketNumber, pnLen protocol.PacketNumberLen, kp protocol.KeyPhaseBit) {
logger.Debugf("\tShort Header{DestConnectionID: %s, PacketNumber: %d, PacketNumberLen: %d, KeyPhase: %s}", dest, pn, pnLen, kp)
}
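// Illustrative sketch, not part of the package API: write a short header and parse it
// back. The connection ID length is not encoded in the short header, so the caller of
// ParseShortHeader has to know it out of band. The values below are assumptions for
// illustration only.
func exampleShortHeaderRoundTrip() error {
	connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4})
	b, err := AppendShortHeader(nil, connID, 0x2a, protocol.PacketNumberLen2, protocol.KeyPhaseZero)
	if err != nil {
		return err
	}
	_, _, _, _, err = ParseShortHeader(b, connID.Len())
	return err
}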
package wire
import (
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qerr"
"github.com/quic-go/quic-go/quicvarint"
)
// A StopSendingFrame is a STOP_SENDING frame
type StopSendingFrame struct {
StreamID protocol.StreamID
ErrorCode qerr.StreamErrorCode
}
// parseStopSendingFrame parses a STOP_SENDING frame
func parseStopSendingFrame(b []byte, _ protocol.Version) (*StopSendingFrame, int, error) {
startLen := len(b)
streamID, l, err := quicvarint.Parse(b)
if err != nil {
return nil, 0, replaceUnexpectedEOF(err)
}
b = b[l:]
errorCode, l, err := quicvarint.Parse(b)
if err != nil {
return nil, 0, replaceUnexpectedEOF(err)
}
b = b[l:]
return &StopSendingFrame{
StreamID: protocol.StreamID(streamID),
ErrorCode: qerr.StreamErrorCode(errorCode),
}, startLen - len(b), nil
}
// Length of a written frame
func (f *StopSendingFrame) Length(_ protocol.Version) protocol.ByteCount {
return 1 + protocol.ByteCount(quicvarint.Len(uint64(f.StreamID))+quicvarint.Len(uint64(f.ErrorCode)))
}
func (f *StopSendingFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
b = append(b, byte(FrameTypeStopSending))
b = quicvarint.Append(b, uint64(f.StreamID))
b = quicvarint.Append(b, uint64(f.ErrorCode))
return b, nil
}
package wire
import (
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/quicvarint"
)
// A StreamDataBlockedFrame is a STREAM_DATA_BLOCKED frame
type StreamDataBlockedFrame struct {
StreamID protocol.StreamID
MaximumStreamData protocol.ByteCount
}
func parseStreamDataBlockedFrame(b []byte, _ protocol.Version) (*StreamDataBlockedFrame, int, error) {
startLen := len(b)
sid, l, err := quicvarint.Parse(b)
if err != nil {
return nil, 0, replaceUnexpectedEOF(err)
}
b = b[l:]
offset, l, err := quicvarint.Parse(b)
if err != nil {
return nil, 0, replaceUnexpectedEOF(err)
}
return &StreamDataBlockedFrame{
StreamID: protocol.StreamID(sid),
MaximumStreamData: protocol.ByteCount(offset),
}, startLen - len(b) + l, nil
}
func (f *StreamDataBlockedFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
b = append(b, 0x15) // STREAM_DATA_BLOCKED frame type
b = quicvarint.Append(b, uint64(f.StreamID))
b = quicvarint.Append(b, uint64(f.MaximumStreamData))
return b, nil
}
// Length of a written frame
func (f *StreamDataBlockedFrame) Length(protocol.Version) protocol.ByteCount {
return 1 + protocol.ByteCount(quicvarint.Len(uint64(f.StreamID))+quicvarint.Len(uint64(f.MaximumStreamData)))
}
package wire
import (
"errors"
"io"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/quicvarint"
)
// A StreamFrame of QUIC
type StreamFrame struct {
StreamID protocol.StreamID
Offset protocol.ByteCount
Data []byte
Fin bool
DataLenPresent bool
fromPool bool
}
func ParseStreamFrame(b []byte, typ FrameType, _ protocol.Version) (*StreamFrame, int, error) {
startLen := len(b)
hasOffset := typ&0b100 > 0
fin := typ&0b1 > 0
hasDataLen := typ&0b10 > 0
streamID, l, err := quicvarint.Parse(b)
if err != nil {
return nil, 0, replaceUnexpectedEOF(err)
}
b = b[l:]
var offset uint64
if hasOffset {
offset, l, err = quicvarint.Parse(b)
if err != nil {
return nil, 0, replaceUnexpectedEOF(err)
}
b = b[l:]
}
var dataLen uint64
if hasDataLen {
var err error
var l int
dataLen, l, err = quicvarint.Parse(b)
if err != nil {
return nil, 0, replaceUnexpectedEOF(err)
}
b = b[l:]
if dataLen > uint64(len(b)) {
return nil, 0, io.EOF
}
} else {
// The rest of the packet is data
dataLen = uint64(len(b))
}
var frame *StreamFrame
if dataLen < protocol.MinStreamFrameBufferSize {
frame = &StreamFrame{}
if dataLen > 0 {
frame.Data = make([]byte, dataLen)
}
} else {
frame = GetStreamFrame()
// The STREAM frame can't be larger than the StreamFrame we obtained from the pool,
// since those StreamFrames have a buffer capacity of the maximum packet buffer size.
if dataLen > uint64(cap(frame.Data)) {
return nil, 0, io.EOF
}
frame.Data = frame.Data[:dataLen]
}
frame.StreamID = protocol.StreamID(streamID)
frame.Offset = protocol.ByteCount(offset)
frame.Fin = fin
frame.DataLenPresent = hasDataLen
if dataLen > 0 {
copy(frame.Data, b)
}
if frame.Offset+frame.DataLen() > protocol.MaxByteCount {
return nil, 0, errors.New("stream data overflows maximum offset")
}
return frame, startLen - len(b) + int(dataLen), nil
}
func (f *StreamFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
if len(f.Data) == 0 && !f.Fin {
return nil, errors.New("StreamFrame: attempting to write empty frame without FIN")
}
typ := byte(0x8)
if f.Fin {
typ ^= 0b1
}
hasOffset := f.Offset != 0
if f.DataLenPresent {
typ ^= 0b10
}
if hasOffset {
typ ^= 0b100
}
b = append(b, typ)
b = quicvarint.Append(b, uint64(f.StreamID))
if hasOffset {
b = quicvarint.Append(b, uint64(f.Offset))
}
if f.DataLenPresent {
b = quicvarint.Append(b, uint64(f.DataLen()))
}
b = append(b, f.Data...)
return b, nil
}
// Length returns the total length of the STREAM frame
func (f *StreamFrame) Length(protocol.Version) protocol.ByteCount {
length := 1 + quicvarint.Len(uint64(f.StreamID))
if f.Offset != 0 {
length += quicvarint.Len(uint64(f.Offset))
}
if f.DataLenPresent {
length += quicvarint.Len(uint64(f.DataLen()))
}
return protocol.ByteCount(length) + f.DataLen()
}
// DataLen gives the length of data in bytes
func (f *StreamFrame) DataLen() protocol.ByteCount {
return protocol.ByteCount(len(f.Data))
}
// MaxDataLen returns the maximum data length
// If 0 is returned, writing will fail (a STREAM frame must contain at least 1 byte of data).
func (f *StreamFrame) MaxDataLen(maxSize protocol.ByteCount, _ protocol.Version) protocol.ByteCount {
headerLen := 1 + protocol.ByteCount(quicvarint.Len(uint64(f.StreamID)))
if f.Offset != 0 {
headerLen += protocol.ByteCount(quicvarint.Len(uint64(f.Offset)))
}
if f.DataLenPresent {
// Pretend that the data size will be 1 byte.
// If it turns out that varint encoding the length will consume 2 bytes, we need to adjust the data length afterward
headerLen++
}
if headerLen > maxSize {
return 0
}
maxDataLen := maxSize - headerLen
if f.DataLenPresent && quicvarint.Len(uint64(maxDataLen)) != 1 {
maxDataLen--
}
return maxDataLen
}
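// Worked example for MaxDataLen (assumed numbers): for a frame with StreamID 8,
// Offset 0 and DataLenPresent set, the header is 1 (type) + 1 (stream ID) + 1
// (length, assumed to be a 1-byte varint) = 3 bytes. With maxSize = 60 this gives
// maxDataLen = 57, which still encodes as a 1-byte varint, so nothing is adjusted:
// 1 + 1 + 1 + 57 = 60. With maxSize = 100, maxDataLen would be 97, whose varint
// encoding needs 2 bytes, so one byte is subtracted: 1 + 1 + 2 + 96 = 100.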
// MaybeSplitOffFrame splits a frame such that it is not bigger than maxSize bytes.
// It reports whether the frame was actually split.
// The frame might not be split if:
// * maxSize is large enough to fit the whole frame, or
// * maxSize is too small to fit even a single byte of data. In that case, the returned frame is nil.
func (f *StreamFrame) MaybeSplitOffFrame(maxSize protocol.ByteCount, version protocol.Version) (*StreamFrame, bool /* was splitting required */) {
if maxSize >= f.Length(version) {
return nil, false
}
n := f.MaxDataLen(maxSize, version)
if n == 0 {
return nil, true
}
new := GetStreamFrame()
new.StreamID = f.StreamID
new.Offset = f.Offset
new.Fin = false
new.DataLenPresent = f.DataLenPresent
// swap the data slices
new.Data, f.Data = f.Data, new.Data
new.fromPool, f.fromPool = f.fromPool, new.fromPool
f.Data = f.Data[:protocol.ByteCount(len(new.Data))-n]
copy(f.Data, new.Data[n:])
new.Data = new.Data[:n]
f.Offset += n
return new, true
}
func (f *StreamFrame) PutBack() {
putStreamFrame(f)
}
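// Illustrative sketch (assumed usage): splitting a STREAM frame that doesn't fit into
// the remaining space of a packet. The returned head frame carries the first bytes at
// the original offset, while f keeps the remainder with its Offset advanced.
func exampleSplitStreamFrame(f *StreamFrame, remaining protocol.ByteCount) (*StreamFrame, *StreamFrame) {
	head, needed := f.MaybeSplitOffFrame(remaining, protocol.Version1)
	if !needed {
		return f, nil // the whole frame fits into the remaining space
	}
	if head == nil {
		return nil, f // not even a single byte of stream data fits
	}
	return head, f // head goes into this packet, f is queued for a later packet
}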
package wire
import (
"fmt"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/quicvarint"
)
// A StreamsBlockedFrame is a STREAMS_BLOCKED frame
type StreamsBlockedFrame struct {
Type protocol.StreamType
StreamLimit protocol.StreamNum
}
func parseStreamsBlockedFrame(b []byte, typ FrameType, _ protocol.Version) (*StreamsBlockedFrame, int, error) {
f := &StreamsBlockedFrame{}
//nolint:exhaustive // This function is only called with FrameTypeBidiStreamBlocked or FrameTypeUniStreamBlocked.
switch typ {
case FrameTypeBidiStreamBlocked:
f.Type = protocol.StreamTypeBidi
case FrameTypeUniStreamBlocked:
f.Type = protocol.StreamTypeUni
}
streamLimit, l, err := quicvarint.Parse(b)
if err != nil {
return nil, 0, replaceUnexpectedEOF(err)
}
f.StreamLimit = protocol.StreamNum(streamLimit)
if f.StreamLimit > protocol.MaxStreamCount {
return nil, 0, fmt.Errorf("%d exceeds the maximum stream count", f.StreamLimit)
}
return f, l, nil
}
func (f *StreamsBlockedFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
switch f.Type {
case protocol.StreamTypeBidi:
b = append(b, byte(FrameTypeBidiStreamBlocked))
case protocol.StreamTypeUni:
b = append(b, byte(FrameTypeUniStreamBlocked))
}
b = quicvarint.Append(b, uint64(f.StreamLimit))
return b, nil
}
// Length of a written frame
func (f *StreamsBlockedFrame) Length(_ protocol.Version) protocol.ByteCount {
return 1 + protocol.ByteCount(quicvarint.Len(uint64(f.StreamLimit)))
}
package wire
import (
"crypto/rand"
"encoding/binary"
"errors"
"fmt"
"io"
"net/netip"
"slices"
"time"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qerr"
"github.com/quic-go/quic-go/quicvarint"
)
// AdditionalTransportParametersClient are additional transport parameters that will be added
// to the client's transport parameters.
// This is not intended for production use, but _only_ to increase the size of the ClientHello beyond
// the usual size of less than 1 MTU.
var AdditionalTransportParametersClient map[uint64][]byte
const transportParameterMarshalingVersion = 1
type transportParameterID uint64
const (
originalDestinationConnectionIDParameterID transportParameterID = 0x0
maxIdleTimeoutParameterID transportParameterID = 0x1
statelessResetTokenParameterID transportParameterID = 0x2
maxUDPPayloadSizeParameterID transportParameterID = 0x3
initialMaxDataParameterID transportParameterID = 0x4
initialMaxStreamDataBidiLocalParameterID transportParameterID = 0x5
initialMaxStreamDataBidiRemoteParameterID transportParameterID = 0x6
initialMaxStreamDataUniParameterID transportParameterID = 0x7
initialMaxStreamsBidiParameterID transportParameterID = 0x8
initialMaxStreamsUniParameterID transportParameterID = 0x9
ackDelayExponentParameterID transportParameterID = 0xa
maxAckDelayParameterID transportParameterID = 0xb
disableActiveMigrationParameterID transportParameterID = 0xc
preferredAddressParameterID transportParameterID = 0xd
activeConnectionIDLimitParameterID transportParameterID = 0xe
initialSourceConnectionIDParameterID transportParameterID = 0xf
retrySourceConnectionIDParameterID transportParameterID = 0x10
// RFC 9221
maxDatagramFrameSizeParameterID transportParameterID = 0x20
// https://datatracker.ietf.org/doc/draft-ietf-quic-reliable-stream-reset/06/
resetStreamAtParameterID transportParameterID = 0x17f7586d2cb571
)
// PreferredAddress is the value encoding in the preferred_address transport parameter
type PreferredAddress struct {
IPv4, IPv6 netip.AddrPort
ConnectionID protocol.ConnectionID
StatelessResetToken protocol.StatelessResetToken
}
// TransportParameters are parameters sent to the peer during the handshake
type TransportParameters struct {
InitialMaxStreamDataBidiLocal protocol.ByteCount
InitialMaxStreamDataBidiRemote protocol.ByteCount
InitialMaxStreamDataUni protocol.ByteCount
InitialMaxData protocol.ByteCount
MaxAckDelay time.Duration
AckDelayExponent uint8
DisableActiveMigration bool
MaxUDPPayloadSize protocol.ByteCount
MaxUniStreamNum protocol.StreamNum
MaxBidiStreamNum protocol.StreamNum
MaxIdleTimeout time.Duration
PreferredAddress *PreferredAddress
OriginalDestinationConnectionID protocol.ConnectionID
InitialSourceConnectionID protocol.ConnectionID
RetrySourceConnectionID *protocol.ConnectionID // use a pointer here to distinguish zero-length connection IDs from missing transport parameters
StatelessResetToken *protocol.StatelessResetToken
ActiveConnectionIDLimit uint64
MaxDatagramFrameSize protocol.ByteCount // RFC 9221
EnableResetStreamAt bool // https://datatracker.ietf.org/doc/draft-ietf-quic-reliable-stream-reset/06/
}
// Unmarshal the transport parameters
func (p *TransportParameters) Unmarshal(data []byte, sentBy protocol.Perspective) error {
if err := p.unmarshal(data, sentBy, false); err != nil {
return &qerr.TransportError{
ErrorCode: qerr.TransportParameterError,
ErrorMessage: err.Error(),
}
}
return nil
}
func (p *TransportParameters) unmarshal(b []byte, sentBy protocol.Perspective, fromSessionTicket bool) error {
// needed to check that every parameter is sent at most once
parameterIDs := make([]transportParameterID, 0, 32)
var (
readOriginalDestinationConnectionID bool
readInitialSourceConnectionID bool
readActiveConnectionIDLimit bool
)
p.AckDelayExponent = protocol.DefaultAckDelayExponent
p.MaxAckDelay = protocol.DefaultMaxAckDelay
p.MaxDatagramFrameSize = protocol.InvalidByteCount
for len(b) > 0 {
paramIDInt, l, err := quicvarint.Parse(b)
if err != nil {
return err
}
paramID := transportParameterID(paramIDInt)
b = b[l:]
paramLen, l, err := quicvarint.Parse(b)
if err != nil {
return err
}
b = b[l:]
if uint64(len(b)) < paramLen {
return fmt.Errorf("remaining length (%d) smaller than parameter length (%d)", len(b), paramLen)
}
parameterIDs = append(parameterIDs, paramID)
switch paramID {
case activeConnectionIDLimitParameterID:
readActiveConnectionIDLimit = true
fallthrough
case maxIdleTimeoutParameterID,
maxUDPPayloadSizeParameterID,
initialMaxDataParameterID,
initialMaxStreamDataBidiLocalParameterID,
initialMaxStreamDataBidiRemoteParameterID,
initialMaxStreamDataUniParameterID,
initialMaxStreamsBidiParameterID,
initialMaxStreamsUniParameterID,
maxAckDelayParameterID,
maxDatagramFrameSizeParameterID,
ackDelayExponentParameterID:
if err := p.readNumericTransportParameter(b, paramID, int(paramLen)); err != nil {
return err
}
b = b[paramLen:]
case preferredAddressParameterID:
if sentBy == protocol.PerspectiveClient {
return errors.New("client sent a preferred_address")
}
if err := p.readPreferredAddress(b, int(paramLen)); err != nil {
return err
}
b = b[paramLen:]
case disableActiveMigrationParameterID:
if paramLen != 0 {
return fmt.Errorf("wrong length for disable_active_migration: %d (expected empty)", paramLen)
}
p.DisableActiveMigration = true
case statelessResetTokenParameterID:
if sentBy == protocol.PerspectiveClient {
return errors.New("client sent a stateless_reset_token")
}
if paramLen != 16 {
return fmt.Errorf("wrong length for stateless_reset_token: %d (expected 16)", paramLen)
}
var token protocol.StatelessResetToken
if len(b) < len(token) {
return io.EOF
}
copy(token[:], b)
b = b[len(token):]
p.StatelessResetToken = &token
case originalDestinationConnectionIDParameterID:
if sentBy == protocol.PerspectiveClient {
return errors.New("client sent an original_destination_connection_id")
}
if paramLen > protocol.MaxConnIDLen {
return protocol.ErrInvalidConnectionIDLen
}
p.OriginalDestinationConnectionID = protocol.ParseConnectionID(b[:paramLen])
b = b[paramLen:]
readOriginalDestinationConnectionID = true
case initialSourceConnectionIDParameterID:
if paramLen > protocol.MaxConnIDLen {
return protocol.ErrInvalidConnectionIDLen
}
p.InitialSourceConnectionID = protocol.ParseConnectionID(b[:paramLen])
b = b[paramLen:]
readInitialSourceConnectionID = true
case retrySourceConnectionIDParameterID:
if sentBy == protocol.PerspectiveClient {
return errors.New("client sent a retry_source_connection_id")
}
if paramLen > protocol.MaxConnIDLen {
return protocol.ErrInvalidConnectionIDLen
}
connID := protocol.ParseConnectionID(b[:paramLen])
b = b[paramLen:]
p.RetrySourceConnectionID = &connID
case resetStreamAtParameterID:
if paramLen != 0 {
return fmt.Errorf("wrong length for reset_stream_at: %d (expected empty)", paramLen)
}
p.EnableResetStreamAt = true
default:
b = b[paramLen:]
}
}
if !readActiveConnectionIDLimit {
p.ActiveConnectionIDLimit = protocol.DefaultActiveConnectionIDLimit
}
if !fromSessionTicket {
if sentBy == protocol.PerspectiveServer && !readOriginalDestinationConnectionID {
return errors.New("missing original_destination_connection_id")
}
if p.MaxUDPPayloadSize == 0 {
p.MaxUDPPayloadSize = protocol.MaxByteCount
}
if !readInitialSourceConnectionID {
return errors.New("missing initial_source_connection_id")
}
}
// check that every transport parameter was sent at most once
slices.SortFunc(parameterIDs, func(a, b transportParameterID) int {
if a < b {
return -1
}
return 1
})
for i := 0; i < len(parameterIDs)-1; i++ {
if parameterIDs[i] == parameterIDs[i+1] {
return fmt.Errorf("received duplicate transport parameter %#x", parameterIDs[i])
}
}
return nil
}
func (p *TransportParameters) readPreferredAddress(b []byte, expectedLen int) error {
remainingLen := len(b)
pa := &PreferredAddress{}
if len(b) < 4+2+16+2+1 {
return io.EOF
}
var ipv4 [4]byte
copy(ipv4[:], b[:4])
port4 := binary.BigEndian.Uint16(b[4:])
b = b[4+2:]
if port4 != 0 && ipv4 != [4]byte{} {
pa.IPv4 = netip.AddrPortFrom(netip.AddrFrom4(ipv4), port4)
}
var ipv6 [16]byte
copy(ipv6[:], b[:16])
port6 := binary.BigEndian.Uint16(b[16:])
if port6 != 0 && ipv6 != [16]byte{} {
pa.IPv6 = netip.AddrPortFrom(netip.AddrFrom16(ipv6), port6)
}
b = b[16+2:]
connIDLen := int(b[0])
b = b[1:]
if connIDLen == 0 || connIDLen > protocol.MaxConnIDLen {
return fmt.Errorf("invalid connection ID length: %d", connIDLen)
}
if len(b) < connIDLen+len(pa.StatelessResetToken) {
return io.EOF
}
pa.ConnectionID = protocol.ParseConnectionID(b[:connIDLen])
b = b[connIDLen:]
copy(pa.StatelessResetToken[:], b)
b = b[len(pa.StatelessResetToken):]
if bytesRead := remainingLen - len(b); bytesRead != expectedLen {
return fmt.Errorf("expected preferred_address to be %d long, read %d bytes", expectedLen, bytesRead)
}
p.PreferredAddress = pa
return nil
}
func (p *TransportParameters) readNumericTransportParameter(b []byte, paramID transportParameterID, expectedLen int) error {
val, l, err := quicvarint.Parse(b)
if err != nil {
return fmt.Errorf("error while reading transport parameter %d: %s", paramID, err)
}
if l != expectedLen {
return fmt.Errorf("inconsistent transport parameter length for transport parameter %#x", paramID)
}
//nolint:exhaustive // This only covers the numeric transport parameters.
switch paramID {
case initialMaxStreamDataBidiLocalParameterID:
p.InitialMaxStreamDataBidiLocal = protocol.ByteCount(val)
case initialMaxStreamDataBidiRemoteParameterID:
p.InitialMaxStreamDataBidiRemote = protocol.ByteCount(val)
case initialMaxStreamDataUniParameterID:
p.InitialMaxStreamDataUni = protocol.ByteCount(val)
case initialMaxDataParameterID:
p.InitialMaxData = protocol.ByteCount(val)
case initialMaxStreamsBidiParameterID:
p.MaxBidiStreamNum = protocol.StreamNum(val)
if p.MaxBidiStreamNum > protocol.MaxStreamCount {
return fmt.Errorf("initial_max_streams_bidi too large: %d (maximum %d)", p.MaxBidiStreamNum, protocol.MaxStreamCount)
}
case initialMaxStreamsUniParameterID:
p.MaxUniStreamNum = protocol.StreamNum(val)
if p.MaxUniStreamNum > protocol.MaxStreamCount {
return fmt.Errorf("initial_max_streams_uni too large: %d (maximum %d)", p.MaxUniStreamNum, protocol.MaxStreamCount)
}
case maxIdleTimeoutParameterID:
p.MaxIdleTimeout = max(protocol.MinRemoteIdleTimeout, time.Duration(val)*time.Millisecond)
case maxUDPPayloadSizeParameterID:
if val < 1200 {
return fmt.Errorf("invalid value for max_udp_payload_size: %d (minimum 1200)", val)
}
p.MaxUDPPayloadSize = protocol.ByteCount(val)
case ackDelayExponentParameterID:
if val > protocol.MaxAckDelayExponent {
return fmt.Errorf("invalid value for ack_delay_exponent: %d (maximum %d)", val, protocol.MaxAckDelayExponent)
}
p.AckDelayExponent = uint8(val)
case maxAckDelayParameterID:
if val > uint64(protocol.MaxMaxAckDelay/time.Millisecond) {
return fmt.Errorf("invalid value for max_ack_delay: %dms (maximum %dms)", val, protocol.MaxMaxAckDelay/time.Millisecond)
}
p.MaxAckDelay = time.Duration(val) * time.Millisecond
case activeConnectionIDLimitParameterID:
if val < 2 {
return fmt.Errorf("invalid value for active_connection_id_limit: %d (minimum 2)", val)
}
p.ActiveConnectionIDLimit = val
case maxDatagramFrameSizeParameterID:
p.MaxDatagramFrameSize = protocol.ByteCount(val)
default:
return fmt.Errorf("TransportParameter BUG: transport parameter %d not found", paramID)
}
return nil
}
// Marshal the transport parameters
func (p *TransportParameters) Marshal(pers protocol.Perspective) []byte {
// Typical Transport Parameters consume around 110 bytes, depending on the exact values,
// especially the lengths of the Connection IDs.
// Allocate 256 bytes, so we won't have to grow the slice in any case.
b := make([]byte, 0, 256)
// add a greased value
random := make([]byte, 18)
rand.Read(random)
b = quicvarint.Append(b, 27+31*uint64(random[0]))
length := random[1] % 16
b = quicvarint.Append(b, uint64(length))
b = append(b, random[2:2+length]...)
// initial_max_stream_data_bidi_local
b = p.marshalVarintParam(b, initialMaxStreamDataBidiLocalParameterID, uint64(p.InitialMaxStreamDataBidiLocal))
// initial_max_stream_data_bidi_remote
b = p.marshalVarintParam(b, initialMaxStreamDataBidiRemoteParameterID, uint64(p.InitialMaxStreamDataBidiRemote))
// initial_max_stream_data_uni
b = p.marshalVarintParam(b, initialMaxStreamDataUniParameterID, uint64(p.InitialMaxStreamDataUni))
// initial_max_data
b = p.marshalVarintParam(b, initialMaxDataParameterID, uint64(p.InitialMaxData))
// initial_max_streams_bidi
b = p.marshalVarintParam(b, initialMaxStreamsBidiParameterID, uint64(p.MaxBidiStreamNum))
// initial_max_streams_uni
b = p.marshalVarintParam(b, initialMaxStreamsUniParameterID, uint64(p.MaxUniStreamNum))
// max_idle_timeout
b = p.marshalVarintParam(b, maxIdleTimeoutParameterID, uint64(p.MaxIdleTimeout/time.Millisecond))
// max_udp_payload_size
if p.MaxUDPPayloadSize > 0 {
b = p.marshalVarintParam(b, maxUDPPayloadSizeParameterID, uint64(p.MaxUDPPayloadSize))
}
// max_ack_delay
// Only send it if it differs from the default value.
if p.MaxAckDelay != protocol.DefaultMaxAckDelay {
b = p.marshalVarintParam(b, maxAckDelayParameterID, uint64(p.MaxAckDelay/time.Millisecond))
}
// ack_delay_exponent
// Only send it if it differs from the default value.
if p.AckDelayExponent != protocol.DefaultAckDelayExponent {
b = p.marshalVarintParam(b, ackDelayExponentParameterID, uint64(p.AckDelayExponent))
}
// disable_active_migration
if p.DisableActiveMigration {
b = quicvarint.Append(b, uint64(disableActiveMigrationParameterID))
b = quicvarint.Append(b, 0)
}
if pers == protocol.PerspectiveServer {
// stateless_reset_token
if p.StatelessResetToken != nil {
b = quicvarint.Append(b, uint64(statelessResetTokenParameterID))
b = quicvarint.Append(b, 16)
b = append(b, p.StatelessResetToken[:]...)
}
// original_destination_connection_id
b = quicvarint.Append(b, uint64(originalDestinationConnectionIDParameterID))
b = quicvarint.Append(b, uint64(p.OriginalDestinationConnectionID.Len()))
b = append(b, p.OriginalDestinationConnectionID.Bytes()...)
// preferred_address
if p.PreferredAddress != nil {
b = quicvarint.Append(b, uint64(preferredAddressParameterID))
b = quicvarint.Append(b, 4+2+16+2+1+uint64(p.PreferredAddress.ConnectionID.Len())+16)
if p.PreferredAddress.IPv4.IsValid() {
ipv4 := p.PreferredAddress.IPv4.Addr().As4()
b = append(b, ipv4[:]...)
b = binary.BigEndian.AppendUint16(b, p.PreferredAddress.IPv4.Port())
} else {
b = append(b, make([]byte, 6)...)
}
if p.PreferredAddress.IPv6.IsValid() {
ipv6 := p.PreferredAddress.IPv6.Addr().As16()
b = append(b, ipv6[:]...)
b = binary.BigEndian.AppendUint16(b, p.PreferredAddress.IPv6.Port())
} else {
b = append(b, make([]byte, 18)...)
}
b = append(b, uint8(p.PreferredAddress.ConnectionID.Len()))
b = append(b, p.PreferredAddress.ConnectionID.Bytes()...)
b = append(b, p.PreferredAddress.StatelessResetToken[:]...)
}
}
// active_connection_id_limit
if p.ActiveConnectionIDLimit != protocol.DefaultActiveConnectionIDLimit {
b = p.marshalVarintParam(b, activeConnectionIDLimitParameterID, p.ActiveConnectionIDLimit)
}
// initial_source_connection_id
b = quicvarint.Append(b, uint64(initialSourceConnectionIDParameterID))
b = quicvarint.Append(b, uint64(p.InitialSourceConnectionID.Len()))
b = append(b, p.InitialSourceConnectionID.Bytes()...)
// retry_source_connection_id
if pers == protocol.PerspectiveServer && p.RetrySourceConnectionID != nil {
b = quicvarint.Append(b, uint64(retrySourceConnectionIDParameterID))
b = quicvarint.Append(b, uint64(p.RetrySourceConnectionID.Len()))
b = append(b, p.RetrySourceConnectionID.Bytes()...)
}
// QUIC datagrams
if p.MaxDatagramFrameSize != protocol.InvalidByteCount {
b = p.marshalVarintParam(b, maxDatagramFrameSizeParameterID, uint64(p.MaxDatagramFrameSize))
}
// QUIC Stream Resets with Partial Delivery
if p.EnableResetStreamAt {
b = quicvarint.Append(b, uint64(resetStreamAtParameterID))
b = quicvarint.Append(b, 0)
}
if pers == protocol.PerspectiveClient && len(AdditionalTransportParametersClient) > 0 {
for k, v := range AdditionalTransportParametersClient {
b = quicvarint.Append(b, k)
b = quicvarint.Append(b, uint64(len(v)))
b = append(b, v...)
}
}
return b
}
func (p *TransportParameters) marshalVarintParam(b []byte, id transportParameterID, val uint64) []byte {
b = quicvarint.Append(b, uint64(id))
b = quicvarint.Append(b, uint64(quicvarint.Len(val)))
return quicvarint.Append(b, val)
}
// MarshalForSessionTicket marshals the transport parameters we save in the session ticket.
// When sending a 0-RTT enabled TLS session ticket, we need to save the transport parameters.
// The client will remember the transport parameters used in the last session,
// and apply those to the 0-RTT data it sends.
// Saving the transport parameters in the ticket gives the server the option to reject 0-RTT
// if the transport parameters changed.
// Since the session ticket is encrypted, the serialization format is defined by the server.
// For convenience, we use the same format that we also use for sending the transport parameters.
func (p *TransportParameters) MarshalForSessionTicket(b []byte) []byte {
b = quicvarint.Append(b, transportParameterMarshalingVersion)
// initial_max_stream_data_bidi_local
b = p.marshalVarintParam(b, initialMaxStreamDataBidiLocalParameterID, uint64(p.InitialMaxStreamDataBidiLocal))
// initial_max_stream_data_bidi_remote
b = p.marshalVarintParam(b, initialMaxStreamDataBidiRemoteParameterID, uint64(p.InitialMaxStreamDataBidiRemote))
// initial_max_stream_data_uni
b = p.marshalVarintParam(b, initialMaxStreamDataUniParameterID, uint64(p.InitialMaxStreamDataUni))
// initial_max_data
b = p.marshalVarintParam(b, initialMaxDataParameterID, uint64(p.InitialMaxData))
// initial_max_streams_bidi
b = p.marshalVarintParam(b, initialMaxStreamsBidiParameterID, uint64(p.MaxBidiStreamNum))
// initial_max_streams_uni
b = p.marshalVarintParam(b, initialMaxStreamsUniParameterID, uint64(p.MaxUniStreamNum))
// active_connection_id_limit
b = p.marshalVarintParam(b, activeConnectionIDLimitParameterID, p.ActiveConnectionIDLimit)
// max_datagram_frame_size
if p.MaxDatagramFrameSize != protocol.InvalidByteCount {
b = p.marshalVarintParam(b, maxDatagramFrameSizeParameterID, uint64(p.MaxDatagramFrameSize))
}
// reset_stream_at
if p.EnableResetStreamAt {
b = quicvarint.Append(b, uint64(resetStreamAtParameterID))
b = quicvarint.Append(b, 0)
}
return b
}
// UnmarshalFromSessionTicket unmarshals transport parameters from a session ticket.
func (p *TransportParameters) UnmarshalFromSessionTicket(b []byte) error {
version, l, err := quicvarint.Parse(b)
if err != nil {
return err
}
if version != transportParameterMarshalingVersion {
return fmt.Errorf("unknown transport parameter marshaling version: %d", version)
}
return p.unmarshal(b[l:], protocol.PerspectiveServer, true)
}
// ValidFor0RTT checks if the transport parameters match those saved in the session ticket.
func (p *TransportParameters) ValidFor0RTT(saved *TransportParameters) bool {
if saved.MaxDatagramFrameSize != protocol.InvalidByteCount && (p.MaxDatagramFrameSize == protocol.InvalidByteCount || p.MaxDatagramFrameSize < saved.MaxDatagramFrameSize) {
return false
}
return p.InitialMaxStreamDataBidiLocal >= saved.InitialMaxStreamDataBidiLocal &&
p.InitialMaxStreamDataBidiRemote >= saved.InitialMaxStreamDataBidiRemote &&
p.InitialMaxStreamDataUni >= saved.InitialMaxStreamDataUni &&
p.InitialMaxData >= saved.InitialMaxData &&
p.MaxBidiStreamNum >= saved.MaxBidiStreamNum &&
p.MaxUniStreamNum >= saved.MaxUniStreamNum &&
p.ActiveConnectionIDLimit == saved.ActiveConnectionIDLimit
}
// ValidForUpdate checks that the new transport parameters don't reduce limits after resuming a 0-RTT connection.
// It is only used on the client side.
func (p *TransportParameters) ValidForUpdate(saved *TransportParameters) bool {
if saved.MaxDatagramFrameSize != protocol.InvalidByteCount && (p.MaxDatagramFrameSize == protocol.InvalidByteCount || p.MaxDatagramFrameSize < saved.MaxDatagramFrameSize) {
return false
}
return p.ActiveConnectionIDLimit >= saved.ActiveConnectionIDLimit &&
p.InitialMaxData >= saved.InitialMaxData &&
p.InitialMaxStreamDataBidiLocal >= saved.InitialMaxStreamDataBidiLocal &&
p.InitialMaxStreamDataBidiRemote >= saved.InitialMaxStreamDataBidiRemote &&
p.InitialMaxStreamDataUni >= saved.InitialMaxStreamDataUni &&
p.MaxBidiStreamNum >= saved.MaxBidiStreamNum &&
p.MaxUniStreamNum >= saved.MaxUniStreamNum
}
// String returns a string representation, intended for logging.
func (p *TransportParameters) String() string {
logString := "&wire.TransportParameters{OriginalDestinationConnectionID: %s, InitialSourceConnectionID: %s, "
logParams := []any{p.OriginalDestinationConnectionID, p.InitialSourceConnectionID}
if p.RetrySourceConnectionID != nil {
logString += "RetrySourceConnectionID: %s, "
logParams = append(logParams, p.RetrySourceConnectionID)
}
logString += "InitialMaxStreamDataBidiLocal: %d, InitialMaxStreamDataBidiRemote: %d, InitialMaxStreamDataUni: %d, InitialMaxData: %d, MaxBidiStreamNum: %d, MaxUniStreamNum: %d, MaxIdleTimeout: %s, AckDelayExponent: %d, MaxAckDelay: %s, ActiveConnectionIDLimit: %d"
logParams = append(logParams, []any{p.InitialMaxStreamDataBidiLocal, p.InitialMaxStreamDataBidiRemote, p.InitialMaxStreamDataUni, p.InitialMaxData, p.MaxBidiStreamNum, p.MaxUniStreamNum, p.MaxIdleTimeout, p.AckDelayExponent, p.MaxAckDelay, p.ActiveConnectionIDLimit}...)
if p.StatelessResetToken != nil { // the client never sends a stateless reset token
logString += ", StatelessResetToken: %#x"
logParams = append(logParams, *p.StatelessResetToken)
}
if p.MaxDatagramFrameSize != protocol.InvalidByteCount {
logString += ", MaxDatagramFrameSize: %d"
logParams = append(logParams, p.MaxDatagramFrameSize)
}
logString += ", EnableResetStreamAt: %t"
logParams = append(logParams, p.EnableResetStreamAt)
logString += "}"
return fmt.Sprintf(logString, logParams...)
}
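// Illustrative sketch, not part of the package API: marshal transport parameters as a
// server would send them, and unmarshal them as a client would. All values are
// assumptions for illustration only.
func exampleTransportParameterRoundTrip() (*TransportParameters, error) {
	server := &TransportParameters{
		InitialMaxData:                  1 << 20,
		InitialMaxStreamDataBidiLocal:   1 << 18,
		MaxBidiStreamNum:                100,
		ActiveConnectionIDLimit:         4,
		OriginalDestinationConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}),
		InitialSourceConnectionID:       protocol.ParseConnectionID([]byte{5, 6, 7, 8}),
	}
	data := server.Marshal(protocol.PerspectiveServer)
	// The second argument is the perspective of the *sender* of the parameters.
	var client TransportParameters
	if err := client.Unmarshal(data, protocol.PerspectiveServer); err != nil {
		return nil, err
	}
	return &client, nil
}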
package wire
import (
"crypto/rand"
"encoding/binary"
"errors"
"github.com/quic-go/quic-go/internal/protocol"
)
// ParseVersionNegotiationPacket parses a Version Negotiation packet.
func ParseVersionNegotiationPacket(b []byte) (dest, src protocol.ArbitraryLenConnectionID, _ []protocol.Version, _ error) {
n, dest, src, err := ParseArbitraryLenConnectionIDs(b)
if err != nil {
return nil, nil, nil, err
}
b = b[n:]
if len(b) == 0 {
//nolint:staticcheck // ST1005: the packet is called Version Negotiation packet
return nil, nil, nil, errors.New("Version Negotiation packet has empty version list")
}
if len(b)%4 != 0 {
//nolint:staticcheck // ST1005: the packet is called Version Negotiation packet
return nil, nil, nil, errors.New("Version Negotiation packet has a version list with an invalid length")
}
versions := make([]protocol.Version, len(b)/4)
for i := 0; len(b) > 0; i++ {
versions[i] = protocol.Version(binary.BigEndian.Uint32(b[:4]))
b = b[4:]
}
return dest, src, versions, nil
}
// ComposeVersionNegotiation composes a Version Negotiation
func ComposeVersionNegotiation(destConnID, srcConnID protocol.ArbitraryLenConnectionID, versions []protocol.Version) []byte {
greasedVersions := protocol.GetGreasedVersions(versions)
expectedLen := 1 /* type byte */ + 4 /* version field */ + 1 /* dest connection ID length field */ + destConnID.Len() + 1 /* src connection ID length field */ + srcConnID.Len() + len(greasedVersions)*4
buf := make([]byte, 1+4 /* type byte and version field */, expectedLen)
_, _ = rand.Read(buf[:1]) // ignore the error here. It is not critical to have perfect random here.
// Setting the "QUIC bit" (0x40) is not required by the RFC,
// but it allows clients to demultiplex QUIC with a long list of other protocols.
// See RFC 9443 and https://mailarchive.ietf.org/arch/msg/quic/oR4kxGKY6mjtPC1CZegY1ED4beg/ for details.
buf[0] |= 0xc0
// The next 4 bytes are left at 0 (version number).
buf = append(buf, uint8(destConnID.Len()))
buf = append(buf, destConnID.Bytes()...)
buf = append(buf, uint8(srcConnID.Len()))
buf = append(buf, srcConnID.Bytes()...)
for _, v := range greasedVersions {
buf = binary.BigEndian.AppendUint32(buf, uint32(v))
}
return buf
}
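// Illustrative sketch with assumed connection IDs: compose a Version Negotiation
// packet offering QUIC v1. ComposeVersionNegotiation adds greased version numbers, so
// the resulting packet advertises more versions than were passed in.
func exampleComposeVersionNegotiation() []byte {
	dest := protocol.ArbitraryLenConnectionID{1, 2, 3, 4}
	src := protocol.ArbitraryLenConnectionID{5, 6, 7, 8}
	return ComposeVersionNegotiation(dest, src, []protocol.Version{protocol.Version1})
}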
// Code generated by generate_multiplexer.go; DO NOT EDIT.
package logging
import (
"net"
"time"
)
func NewMultiplexedConnectionTracer(tracers ...*ConnectionTracer) *ConnectionTracer {
if len(tracers) == 0 {
return nil
}
if len(tracers) == 1 {
return tracers[0]
}
return &ConnectionTracer{
StartedConnection: func(local net.Addr, remote net.Addr, srcConnID ConnectionID, destConnID ConnectionID) {
for _, t := range tracers {
if t.StartedConnection != nil {
t.StartedConnection(local, remote, srcConnID, destConnID)
}
}
},
NegotiatedVersion: func(chosen Version, clientVersions []Version, serverVersions []Version) {
for _, t := range tracers {
if t.NegotiatedVersion != nil {
t.NegotiatedVersion(chosen, clientVersions, serverVersions)
}
}
},
ClosedConnection: func(err error) {
for _, t := range tracers {
if t.ClosedConnection != nil {
t.ClosedConnection(err)
}
}
},
SentTransportParameters: func(parameters *TransportParameters) {
for _, t := range tracers {
if t.SentTransportParameters != nil {
t.SentTransportParameters(parameters)
}
}
},
ReceivedTransportParameters: func(parameters *TransportParameters) {
for _, t := range tracers {
if t.ReceivedTransportParameters != nil {
t.ReceivedTransportParameters(parameters)
}
}
},
RestoredTransportParameters: func(parameters *TransportParameters) {
for _, t := range tracers {
if t.RestoredTransportParameters != nil {
t.RestoredTransportParameters(parameters)
}
}
},
SentLongHeaderPacket: func(hdr *ExtendedHeader, size ByteCount, ecn ECN, ack *AckFrame, frames []Frame) {
for _, t := range tracers {
if t.SentLongHeaderPacket != nil {
t.SentLongHeaderPacket(hdr, size, ecn, ack, frames)
}
}
},
SentShortHeaderPacket: func(hdr *ShortHeader, size ByteCount, ecn ECN, ack *AckFrame, frames []Frame) {
for _, t := range tracers {
if t.SentShortHeaderPacket != nil {
t.SentShortHeaderPacket(hdr, size, ecn, ack, frames)
}
}
},
ReceivedVersionNegotiationPacket: func(dest ArbitraryLenConnectionID, src ArbitraryLenConnectionID, versions []Version) {
for _, t := range tracers {
if t.ReceivedVersionNegotiationPacket != nil {
t.ReceivedVersionNegotiationPacket(dest, src, versions)
}
}
},
ReceivedRetry: func(hdr *Header) {
for _, t := range tracers {
if t.ReceivedRetry != nil {
t.ReceivedRetry(hdr)
}
}
},
ReceivedLongHeaderPacket: func(hdr *ExtendedHeader, size ByteCount, ecn ECN, frames []Frame) {
for _, t := range tracers {
if t.ReceivedLongHeaderPacket != nil {
t.ReceivedLongHeaderPacket(hdr, size, ecn, frames)
}
}
},
ReceivedShortHeaderPacket: func(hdr *ShortHeader, size ByteCount, ecn ECN, frames []Frame) {
for _, t := range tracers {
if t.ReceivedShortHeaderPacket != nil {
t.ReceivedShortHeaderPacket(hdr, size, ecn, frames)
}
}
},
BufferedPacket: func(packetType PacketType, size ByteCount) {
for _, t := range tracers {
if t.BufferedPacket != nil {
t.BufferedPacket(packetType, size)
}
}
},
DroppedPacket: func(packetType PacketType, pn PacketNumber, size ByteCount, reason PacketDropReason) {
for _, t := range tracers {
if t.DroppedPacket != nil {
t.DroppedPacket(packetType, pn, size, reason)
}
}
},
UpdatedMetrics: func(rttStats *RTTStats, cwnd ByteCount, bytesInFlight ByteCount, packetsInFlight int) {
for _, t := range tracers {
if t.UpdatedMetrics != nil {
t.UpdatedMetrics(rttStats, cwnd, bytesInFlight, packetsInFlight)
}
}
},
AcknowledgedPacket: func(encLevel EncryptionLevel, pn PacketNumber) {
for _, t := range tracers {
if t.AcknowledgedPacket != nil {
t.AcknowledgedPacket(encLevel, pn)
}
}
},
LostPacket: func(encLevel EncryptionLevel, pn PacketNumber, reason PacketLossReason) {
for _, t := range tracers {
if t.LostPacket != nil {
t.LostPacket(encLevel, pn, reason)
}
}
},
UpdatedMTU: func(mtu ByteCount, done bool) {
for _, t := range tracers {
if t.UpdatedMTU != nil {
t.UpdatedMTU(mtu, done)
}
}
},
UpdatedCongestionState: func(state CongestionState) {
for _, t := range tracers {
if t.UpdatedCongestionState != nil {
t.UpdatedCongestionState(state)
}
}
},
UpdatedPTOCount: func(value uint32) {
for _, t := range tracers {
if t.UpdatedPTOCount != nil {
t.UpdatedPTOCount(value)
}
}
},
UpdatedKeyFromTLS: func(encLevel EncryptionLevel, p Perspective) {
for _, t := range tracers {
if t.UpdatedKeyFromTLS != nil {
t.UpdatedKeyFromTLS(encLevel, p)
}
}
},
UpdatedKey: func(keyPhase KeyPhase, remote bool) {
for _, t := range tracers {
if t.UpdatedKey != nil {
t.UpdatedKey(keyPhase, remote)
}
}
},
DroppedEncryptionLevel: func(encLevel EncryptionLevel) {
for _, t := range tracers {
if t.DroppedEncryptionLevel != nil {
t.DroppedEncryptionLevel(encLevel)
}
}
},
DroppedKey: func(keyPhase KeyPhase) {
for _, t := range tracers {
if t.DroppedKey != nil {
t.DroppedKey(keyPhase)
}
}
},
SetLossTimer: func(timerType TimerType, encLevel EncryptionLevel, time time.Time) {
for _, t := range tracers {
if t.SetLossTimer != nil {
t.SetLossTimer(timerType, encLevel, time)
}
}
},
LossTimerExpired: func(timerType TimerType, encLevel EncryptionLevel) {
for _, t := range tracers {
if t.LossTimerExpired != nil {
t.LossTimerExpired(timerType, encLevel)
}
}
},
LossTimerCanceled: func() {
for _, t := range tracers {
if t.LossTimerCanceled != nil {
t.LossTimerCanceled()
}
}
},
ECNStateUpdated: func(state ECNState, trigger ECNStateTrigger) {
for _, t := range tracers {
if t.ECNStateUpdated != nil {
t.ECNStateUpdated(state, trigger)
}
}
},
ChoseALPN: func(protocol string) {
for _, t := range tracers {
if t.ChoseALPN != nil {
t.ChoseALPN(protocol)
}
}
},
Close: func() {
for _, t := range tracers {
if t.Close != nil {
t.Close()
}
}
},
Debug: func(name string, msg string) {
for _, t := range tracers {
if t.Debug != nil {
t.Debug(name, msg)
}
}
},
}
}
package logging
import (
"github.com/quic-go/quic-go/internal/protocol"
)
// PacketTypeFromHeader determines the packet type from a *wire.Header.
func PacketTypeFromHeader(hdr *Header) PacketType {
if hdr.Version == 0 {
return PacketTypeVersionNegotiation
}
switch hdr.Type {
case protocol.PacketTypeInitial:
return PacketTypeInitial
case protocol.PacketTypeHandshake:
return PacketTypeHandshake
case protocol.PacketType0RTT:
return PacketType0RTT
case protocol.PacketTypeRetry:
return PacketTypeRetry
default:
return PacketTypeNotDetermined
}
}
// Code generated by generate_multiplexer.go; DO NOT EDIT.
package logging
import "net"
func NewMultiplexedTracer(tracers ...*Tracer) *Tracer {
if len(tracers) == 0 {
return nil
}
if len(tracers) == 1 {
return tracers[0]
}
return &Tracer{
SentPacket: func(dest net.Addr, hdr *Header, size ByteCount, frames []Frame) {
for _, t := range tracers {
if t.SentPacket != nil {
t.SentPacket(dest, hdr, size, frames)
}
}
},
SentVersionNegotiationPacket: func(dest net.Addr, destConnID ArbitraryLenConnectionID, srcConnID ArbitraryLenConnectionID, versions []Version) {
for _, t := range tracers {
if t.SentVersionNegotiationPacket != nil {
t.SentVersionNegotiationPacket(dest, destConnID, srcConnID, versions)
}
}
},
DroppedPacket: func(addr net.Addr, packetType PacketType, size ByteCount, reason PacketDropReason) {
for _, t := range tracers {
if t.DroppedPacket != nil {
t.DroppedPacket(addr, packetType, size, reason)
}
}
},
Debug: func(name string, msg string) {
for _, t := range tracers {
if t.Debug != nil {
t.Debug(name, msg)
}
}
},
Close: func() {
for _, t := range tracers {
if t.Close != nil {
t.Close()
}
}
},
}
}
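// Illustrative sketch (assumed usage): multiplexing a tracer that only implements the
// Debug callback with another tracer. Callbacks that are nil on one of the tracers are
// simply skipped for that tracer, so the tracers don't need to implement the same hooks.
func exampleMultiplexedTracer(other *Tracer, logf func(name, msg string)) *Tracer {
	debugTracer := &Tracer{Debug: logf}
	return NewMultiplexedTracer(other, debugTracer)
}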
package quic
import (
"time"
"github.com/quic-go/quic-go/internal/ackhandler"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/internal/wire"
"github.com/quic-go/quic-go/logging"
)
type mtuDiscoverer interface {
// Start starts the MTU discovery process.
// It's unnecessary to call ShouldSendProbe before that.
Start(now time.Time)
ShouldSendProbe(now time.Time) bool
CurrentSize() protocol.ByteCount
GetPing(now time.Time) (ping ackhandler.Frame, datagramSize protocol.ByteCount)
Reset(now time.Time, start, max protocol.ByteCount)
}
const (
// At some point, we have to stop searching for a higher MTU.
// We're happy to send a packet that's maxMTUDiff bytes smaller than the actual MTU.
maxMTUDiff protocol.ByteCount = 20
// send a probe packet every mtuProbeDelay RTTs
mtuProbeDelay = 5
// Once maxLostMTUProbes MTU probe packets larger than a certain size are lost,
// MTU discovery won't probe for larger MTUs than this size.
// The algorithm used here is resilient to packet loss of (maxLostMTUProbes - 1) packets.
maxLostMTUProbes = 3
)
// The Path MTU is found by sending a larger packet every now and then.
// If the packet is acknowledged, we conclude that the path supports this larger packet size.
// If the packet is lost, this can mean one of two things:
// 1. The path doesn't support this larger packet size, or
// 2. The packet was lost due to packet loss, independent of its size.
// The algorithm used here is resilient to packet loss of (maxLostMTUProbes - 1) packets.
// For simplicity, the following example uses maxLostMTUProbes = 2.
//
// Initialization:
// |------------------------------------------------------------------------------|
// min max
//
// The first MTU probe packet will have size (min+max)/2.
// Assume that this packet is acknowledged. We can now move the min marker,
// and continue the search in the resulting interval.
//
// If 1st probe packet acknowledged:
// |---------------------------------------|--------------------------------------|
// min max
//
// If 1st probe packet lost:
// |---------------------------------------|--------------------------------------|
// min lost[0] max
//
// We can't conclude that the path doesn't support this packet size, since the loss of the probe
// packet could have been unrelated to the packet size. A larger probe packet will be sent later on.
// After a loss, the next probe packet has size (min+lost[0])/2.
// Now assume this probe packet is acknowledged:
//
// 2nd probe packet acknowledged:
// |------------------|--------------------|--------------------------------------|
// min lost[0] max
//
// First of all, we conclude that the path supports at least this MTU. That's progress!
// Second, we probe a bit more aggressively with the next probe packet:
// After an acknowledgement, the next probe packet has size (min+max)/2.
// This means we'll send a packet larger than the first probe packet (which was lost).
//
// If 3rd probe packet acknowledged:
// |-------------------------------------------------|----------------------------|
// min max
//
// We can conclude that the loss of the 1st probe packet was not due to its size, and
// continue searching in a much smaller interval now.
//
// If 3rd probe packet lost:
// |------------------|--------------------|---------|----------------------------|
// min lost[0] max
//
// Since in our example maxLostMTUProbes = 2, and we lost 2 packets smaller than max, we
// conclude that this packet size is not supported on the path, and reduce the maximum
// value of the search interval.
//
// MTU discovery concludes once the interval between min and max has been narrowed down to maxMTUDiff.
type mtuFinder struct {
lastProbeTime time.Time
rttStats *utils.RTTStats
inFlight protocol.ByteCount // the size of the probe packet currently in flight. InvalidByteCount if none is in flight
min protocol.ByteCount
// on initialization, we treat the maximum size as the first "lost" packet
lost [maxLostMTUProbes]protocol.ByteCount
lastProbeWasLost bool
// The generation is used to ignore ACKs / losses for probe packets sent before a reset.
// Resets happen when the connection is migrated to a new path.
// We're therefore not concerned about overflows of this counter.
generation uint8
tracer *logging.ConnectionTracer
}
var _ mtuDiscoverer = &mtuFinder{}
func newMTUDiscoverer(
rttStats *utils.RTTStats,
start, max protocol.ByteCount,
tracer *logging.ConnectionTracer,
) *mtuFinder {
f := &mtuFinder{
inFlight: protocol.InvalidByteCount,
rttStats: rttStats,
tracer: tracer,
}
f.init(start, max)
return f
}
func (f *mtuFinder) init(start, max protocol.ByteCount) {
f.min = start
for i := range f.lost {
if i == 0 {
f.lost[i] = max
continue
}
f.lost[i] = protocol.InvalidByteCount
}
}
func (f *mtuFinder) done() bool {
return f.max()-f.min <= maxMTUDiff+1
}
func (f *mtuFinder) max() protocol.ByteCount {
for i, v := range f.lost {
if v == protocol.InvalidByteCount {
return f.lost[i-1]
}
}
return f.lost[len(f.lost)-1]
}
func (f *mtuFinder) Start(now time.Time) {
f.lastProbeTime = now // makes sure the first probe packet is not sent immediately
}
func (f *mtuFinder) ShouldSendProbe(now time.Time) bool {
if f.lastProbeTime.IsZero() {
return false
}
if f.inFlight != protocol.InvalidByteCount || f.done() {
return false
}
return !now.Before(f.lastProbeTime.Add(mtuProbeDelay * f.rttStats.SmoothedRTT()))
}
func (f *mtuFinder) GetPing(now time.Time) (ackhandler.Frame, protocol.ByteCount) {
var size protocol.ByteCount
if f.lastProbeWasLost {
size = (f.min + f.lost[0]) / 2
} else {
size = (f.min + f.max()) / 2
}
f.lastProbeTime = now
f.inFlight = size
return ackhandler.Frame{
Frame: &wire.PingFrame{},
Handler: &mtuFinderAckHandler{mtuFinder: f, generation: f.generation},
}, size
}
func (f *mtuFinder) CurrentSize() protocol.ByteCount {
return f.min
}
func (f *mtuFinder) Reset(now time.Time, start, max protocol.ByteCount) {
f.generation++
f.lastProbeTime = now
f.lastProbeWasLost = false
f.inFlight = protocol.InvalidByteCount
f.init(start, max)
}
type mtuFinderAckHandler struct {
*mtuFinder
generation uint8
}
var _ ackhandler.FrameHandler = &mtuFinderAckHandler{}
func (h *mtuFinderAckHandler) OnAcked(wire.Frame) {
if h.generation != h.mtuFinder.generation {
// ACK for probe sent before reset
return
}
size := h.inFlight
if size == protocol.InvalidByteCount {
panic("OnAcked callback called although there's no MTU probe packet in flight")
}
h.inFlight = protocol.InvalidByteCount
h.min = size
h.lastProbeWasLost = false
// remove all values smaller than size from the lost array
var j int
for i, v := range h.lost {
if size < v {
j = i
break
}
}
if j > 0 {
for i := 0; i < len(h.lost); i++ {
if i+j < len(h.lost) {
h.lost[i] = h.lost[i+j]
} else {
h.lost[i] = protocol.InvalidByteCount
}
}
}
if h.tracer != nil && h.tracer.UpdatedMTU != nil {
h.tracer.UpdatedMTU(size, h.done())
}
}
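// Worked example of the lost array (illustrative, assuming maxLostMTUProbes == 3, min = 1300, max = 1500):
//
//	init:             lost = [1500, -, -]        max() = 1500
//	probe 1400 lost:  lost = [1400, 1500, -]     max() = 1500
//	probe 1350 lost:  lost = [1350, 1400, 1500]  max() = 1500
//	probe 1325 acked: min = 1325, lost unchanged
//	probe 1412 acked: min = 1412, lost = [1500, -, -] (1350 and 1400 are removed:
//	                  their loss evidently wasn't caused by their size)
//
// Had the 1325 byte probe been lost as well, the array would have become [1325, 1350, 1400],
// dropping 1500 and thereby reducing max() to 1400.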
func (h *mtuFinderAckHandler) OnLost(wire.Frame) {
if h.generation != h.mtuFinder.generation {
// the lost probe packet was sent before the reset
return
}
size := h.inFlight
if size == protocol.InvalidByteCount {
panic("OnLost callback called although there's no MTU probe packet in flight")
}
h.lastProbeWasLost = true
h.inFlight = protocol.InvalidByteCount
for i, v := range h.lost {
if size < v {
copy(h.lost[i+1:], h.lost[i:])
h.lost[i] = size
break
}
}
}
package quic
import (
crand "crypto/rand"
"encoding/binary"
"errors"
"fmt"
"math/rand/v2"
"time"
"github.com/quic-go/quic-go/internal/ackhandler"
"github.com/quic-go/quic-go/internal/handshake"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qerr"
"github.com/quic-go/quic-go/internal/wire"
)
var errNothingToPack = errors.New("nothing to pack")
type packer interface {
PackCoalescedPacket(onlyAck bool, maxPacketSize protocol.ByteCount, now time.Time, v protocol.Version) (*coalescedPacket, error)
PackAckOnlyPacket(maxPacketSize protocol.ByteCount, now time.Time, v protocol.Version) (shortHeaderPacket, *packetBuffer, error)
AppendPacket(_ *packetBuffer, maxPacketSize protocol.ByteCount, now time.Time, v protocol.Version) (shortHeaderPacket, error)
PackPTOProbePacket(_ protocol.EncryptionLevel, _ protocol.ByteCount, addPingIfEmpty bool, now time.Time, v protocol.Version) (*coalescedPacket, error)
PackConnectionClose(*qerr.TransportError, protocol.ByteCount, protocol.Version) (*coalescedPacket, error)
PackApplicationClose(*qerr.ApplicationError, protocol.ByteCount, protocol.Version) (*coalescedPacket, error)
PackPathProbePacket(protocol.ConnectionID, []ackhandler.Frame, protocol.Version) (shortHeaderPacket, *packetBuffer, error)
PackMTUProbePacket(ping ackhandler.Frame, size protocol.ByteCount, v protocol.Version) (shortHeaderPacket, *packetBuffer, error)
SetToken([]byte)
}
type sealer interface {
handshake.LongHeaderSealer
}
type payload struct {
streamFrames []ackhandler.StreamFrame
frames []ackhandler.Frame
ack *wire.AckFrame
length protocol.ByteCount
}
type longHeaderPacket struct {
header *wire.ExtendedHeader
ack *wire.AckFrame
frames []ackhandler.Frame
streamFrames []ackhandler.StreamFrame // only used for 0-RTT packets
length protocol.ByteCount
}
type shortHeaderPacket struct {
PacketNumber protocol.PacketNumber
Frames []ackhandler.Frame
StreamFrames []ackhandler.StreamFrame
Ack *wire.AckFrame
Length protocol.ByteCount
IsPathMTUProbePacket bool
IsPathProbePacket bool
// used for logging
DestConnID protocol.ConnectionID
PacketNumberLen protocol.PacketNumberLen
KeyPhase protocol.KeyPhaseBit
}
func (p *shortHeaderPacket) IsAckEliciting() bool { return ackhandler.HasAckElicitingFrames(p.Frames) }
type coalescedPacket struct {
buffer *packetBuffer
longHdrPackets []*longHeaderPacket
shortHdrPacket *shortHeaderPacket
}
// IsOnlyShortHeaderPacket reports whether this packet contains only a short header packet (and no long header packets).
func (p *coalescedPacket) IsOnlyShortHeaderPacket() bool {
return len(p.longHdrPackets) == 0 && p.shortHdrPacket != nil
}
func (p *longHeaderPacket) EncryptionLevel() protocol.EncryptionLevel {
//nolint:exhaustive // Will never be called for Retry packets (and they don't have encrypted data).
switch p.header.Type {
case protocol.PacketTypeInitial:
return protocol.EncryptionInitial
case protocol.PacketTypeHandshake:
return protocol.EncryptionHandshake
case protocol.PacketType0RTT:
return protocol.Encryption0RTT
default:
panic("can't determine encryption level")
}
}
func (p *longHeaderPacket) IsAckEliciting() bool { return ackhandler.HasAckElicitingFrames(p.frames) }
type packetNumberManager interface {
PeekPacketNumber(protocol.EncryptionLevel) (protocol.PacketNumber, protocol.PacketNumberLen)
PopPacketNumber(protocol.EncryptionLevel) protocol.PacketNumber
}
type sealingManager interface {
GetInitialSealer() (handshake.LongHeaderSealer, error)
GetHandshakeSealer() (handshake.LongHeaderSealer, error)
Get0RTTSealer() (handshake.LongHeaderSealer, error)
Get1RTTSealer() (handshake.ShortHeaderSealer, error)
}
type frameSource interface {
HasData() bool
Append([]ackhandler.Frame, []ackhandler.StreamFrame, protocol.ByteCount, time.Time, protocol.Version) ([]ackhandler.Frame, []ackhandler.StreamFrame, protocol.ByteCount)
}
type ackFrameSource interface {
GetAckFrame(_ protocol.EncryptionLevel, now time.Time, onlyIfQueued bool) *wire.AckFrame
}
type packetPacker struct {
srcConnID protocol.ConnectionID
getDestConnID func() protocol.ConnectionID
perspective protocol.Perspective
cryptoSetup sealingManager
initialStream *initialCryptoStream
handshakeStream *cryptoStream
token []byte
pnManager packetNumberManager
framer frameSource
acks ackFrameSource
datagramQueue *datagramQueue
retransmissionQueue *retransmissionQueue
rand rand.Rand
numNonAckElicitingAcks int
}
var _ packer = &packetPacker{}
func newPacketPacker(
srcConnID protocol.ConnectionID,
getDestConnID func() protocol.ConnectionID,
initialStream *initialCryptoStream,
handshakeStream *cryptoStream,
packetNumberManager packetNumberManager,
retransmissionQueue *retransmissionQueue,
cryptoSetup sealingManager,
framer frameSource,
acks ackFrameSource,
datagramQueue *datagramQueue,
perspective protocol.Perspective,
) *packetPacker {
var b [16]byte
_, _ = crand.Read(b[:])
return &packetPacker{
cryptoSetup: cryptoSetup,
getDestConnID: getDestConnID,
srcConnID: srcConnID,
initialStream: initialStream,
handshakeStream: handshakeStream,
retransmissionQueue: retransmissionQueue,
datagramQueue: datagramQueue,
perspective: perspective,
framer: framer,
acks: acks,
rand: *rand.New(rand.NewPCG(binary.BigEndian.Uint64(b[:8]), binary.BigEndian.Uint64(b[8:]))),
pnManager: packetNumberManager,
}
}
// PackConnectionClose packs a packet that closes the connection with a transport error.
func (p *packetPacker) PackConnectionClose(e *qerr.TransportError, maxPacketSize protocol.ByteCount, v protocol.Version) (*coalescedPacket, error) {
var reason string
// don't send details of crypto errors
if !e.ErrorCode.IsCryptoError() {
reason = e.ErrorMessage
}
return p.packConnectionClose(false, uint64(e.ErrorCode), e.FrameType, reason, maxPacketSize, v)
}
// PackApplicationClose packs a packet that closes the connection with an application error.
func (p *packetPacker) PackApplicationClose(e *qerr.ApplicationError, maxPacketSize protocol.ByteCount, v protocol.Version) (*coalescedPacket, error) {
return p.packConnectionClose(true, uint64(e.ErrorCode), 0, e.ErrorMessage, maxPacketSize, v)
}
func (p *packetPacker) packConnectionClose(
isApplicationError bool,
errorCode uint64,
frameType uint64,
reason string,
maxPacketSize protocol.ByteCount,
v protocol.Version,
) (*coalescedPacket, error) {
var sealers [4]sealer
var hdrs [3]*wire.ExtendedHeader
var payloads [4]payload
var size protocol.ByteCount
var connID protocol.ConnectionID
var oneRTTPacketNumber protocol.PacketNumber
var oneRTTPacketNumberLen protocol.PacketNumberLen
var keyPhase protocol.KeyPhaseBit // only set for 1-RTT
var numLongHdrPackets uint8
encLevels := [4]protocol.EncryptionLevel{protocol.EncryptionInitial, protocol.EncryptionHandshake, protocol.Encryption0RTT, protocol.Encryption1RTT}
for i, encLevel := range encLevels {
if p.perspective == protocol.PerspectiveServer && encLevel == protocol.Encryption0RTT {
continue
}
ccf := &wire.ConnectionCloseFrame{
IsApplicationError: isApplicationError,
ErrorCode: errorCode,
FrameType: frameType,
ReasonPhrase: reason,
}
// don't send application errors in Initial or Handshake packets
if isApplicationError && (encLevel == protocol.EncryptionInitial || encLevel == protocol.EncryptionHandshake) {
ccf.IsApplicationError = false
ccf.ErrorCode = uint64(qerr.ApplicationErrorErrorCode)
ccf.ReasonPhrase = ""
}
pl := payload{
frames: []ackhandler.Frame{{Frame: ccf}},
length: ccf.Length(v),
}
var sealer sealer
var err error
switch encLevel {
case protocol.EncryptionInitial:
sealer, err = p.cryptoSetup.GetInitialSealer()
case protocol.EncryptionHandshake:
sealer, err = p.cryptoSetup.GetHandshakeSealer()
case protocol.Encryption0RTT:
sealer, err = p.cryptoSetup.Get0RTTSealer()
case protocol.Encryption1RTT:
var s handshake.ShortHeaderSealer
s, err = p.cryptoSetup.Get1RTTSealer()
if err == nil {
keyPhase = s.KeyPhase()
}
sealer = s
}
if err == handshake.ErrKeysNotYetAvailable || err == handshake.ErrKeysDropped {
continue
}
if err != nil {
return nil, err
}
sealers[i] = sealer
var hdr *wire.ExtendedHeader
if encLevel == protocol.Encryption1RTT {
connID = p.getDestConnID()
oneRTTPacketNumber, oneRTTPacketNumberLen = p.pnManager.PeekPacketNumber(protocol.Encryption1RTT)
size += p.shortHeaderPacketLength(connID, oneRTTPacketNumberLen, pl)
} else {
hdr = p.getLongHeader(encLevel, v)
hdrs[i] = hdr
size += p.longHeaderPacketLength(hdr, pl, v) + protocol.ByteCount(sealer.Overhead())
numLongHdrPackets++
}
payloads[i] = pl
}
buffer := getPacketBuffer()
packet := &coalescedPacket{
buffer: buffer,
longHdrPackets: make([]*longHeaderPacket, 0, numLongHdrPackets),
}
for i, encLevel := range encLevels {
if sealers[i] == nil {
continue
}
if encLevel == protocol.Encryption1RTT {
shp, err := p.appendShortHeaderPacket(buffer, connID, oneRTTPacketNumber, oneRTTPacketNumberLen, keyPhase, payloads[i], 0, maxPacketSize, sealers[i], false, v)
if err != nil {
return nil, err
}
packet.shortHdrPacket = &shp
} else {
var paddingLen protocol.ByteCount
if encLevel == protocol.EncryptionInitial {
paddingLen = p.initialPaddingLen(payloads[i].frames, size, maxPacketSize)
}
longHdrPacket, err := p.appendLongHeaderPacket(buffer, hdrs[i], payloads[i], paddingLen, encLevel, sealers[i], v)
if err != nil {
return nil, err
}
packet.longHdrPackets = append(packet.longHdrPackets, longHdrPacket)
}
}
return packet, nil
}
// longHeaderPacketLength calculates the length of a serialized long header packet.
// It takes into account that packets with a tiny payload need to be padded,
// such that len(payload) + packet number length >= 4; together with the AEAD overhead added
// during encryption, this leaves enough ciphertext for header protection sampling.
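// For example, a packet with a 1 byte packet number and a 2 byte payload receives 1 byte of padding.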
func (p *packetPacker) longHeaderPacketLength(hdr *wire.ExtendedHeader, pl payload, v protocol.Version) protocol.ByteCount {
var paddingLen protocol.ByteCount
pnLen := protocol.ByteCount(hdr.PacketNumberLen)
if pl.length < 4-pnLen {
paddingLen = 4 - pnLen - pl.length
}
return hdr.GetLength(v) + pl.length + paddingLen
}
// shortHeaderPacketLength calculates the length of a serialized short header packet.
// It takes into account that packets with a tiny payload need to be padded,
// such that len(payload) + packet number length >= 4; together with the AEAD overhead added
// during encryption, this leaves enough ciphertext for header protection sampling.
func (p *packetPacker) shortHeaderPacketLength(connID protocol.ConnectionID, pnLen protocol.PacketNumberLen, pl payload) protocol.ByteCount {
var paddingLen protocol.ByteCount
if pl.length < 4-protocol.ByteCount(pnLen) {
paddingLen = 4 - protocol.ByteCount(pnLen) - pl.length
}
return wire.ShortHeaderLen(connID, pnLen) + pl.length + paddingLen
}
// currentSize is the expected size of the packet before any padding is applied.
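// For example (illustrative numbers), a client's 300 byte Initial packet in a 1252 byte datagram
// is padded by 952 bytes, satisfying QUIC's requirement that datagrams carrying Initial packets
// be at least 1200 bytes.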
func (p *packetPacker) initialPaddingLen(frames []ackhandler.Frame, currentSize, maxPacketSize protocol.ByteCount) protocol.ByteCount {
// For the server, only ack-eliciting Initial packets need to be padded.
if p.perspective == protocol.PerspectiveServer && !ackhandler.HasAckElicitingFrames(frames) {
return 0
}
if currentSize >= maxPacketSize {
return 0
}
return maxPacketSize - currentSize
}
// PackCoalescedPacket packs a new packet.
// It packs an Initial / Handshake if there is data to send in these packet number spaces.
// It should only be called before the handshake is confirmed.
func (p *packetPacker) PackCoalescedPacket(onlyAck bool, maxSize protocol.ByteCount, now time.Time, v protocol.Version) (*coalescedPacket, error) {
var (
initialHdr, handshakeHdr, zeroRTTHdr *wire.ExtendedHeader
initialPayload, handshakePayload, zeroRTTPayload, oneRTTPayload payload
oneRTTPacketNumber protocol.PacketNumber
oneRTTPacketNumberLen protocol.PacketNumberLen
)
// Try packing an Initial packet.
initialSealer, err := p.cryptoSetup.GetInitialSealer()
if err != nil && err != handshake.ErrKeysDropped {
return nil, err
}
var size protocol.ByteCount
if initialSealer != nil {
initialHdr, initialPayload = p.maybeGetCryptoPacket(
maxSize-protocol.ByteCount(initialSealer.Overhead()),
protocol.EncryptionInitial,
now,
false,
onlyAck,
true,
v,
)
if initialPayload.length > 0 {
size += p.longHeaderPacketLength(initialHdr, initialPayload, v) + protocol.ByteCount(initialSealer.Overhead())
}
}
// Add a Handshake packet.
var handshakeSealer sealer
if (onlyAck && size == 0) || (!onlyAck && size < maxSize-protocol.MinCoalescedPacketSize) {
var err error
handshakeSealer, err = p.cryptoSetup.GetHandshakeSealer()
if err != nil && err != handshake.ErrKeysDropped && err != handshake.ErrKeysNotYetAvailable {
return nil, err
}
if handshakeSealer != nil {
handshakeHdr, handshakePayload = p.maybeGetCryptoPacket(
maxSize-size-protocol.ByteCount(handshakeSealer.Overhead()),
protocol.EncryptionHandshake,
now,
false,
onlyAck,
size == 0,
v,
)
if handshakePayload.length > 0 {
s := p.longHeaderPacketLength(handshakeHdr, handshakePayload, v) + protocol.ByteCount(handshakeSealer.Overhead())
size += s
}
}
}
// Add a 0-RTT / 1-RTT packet.
var zeroRTTSealer sealer
var oneRTTSealer handshake.ShortHeaderSealer
var connID protocol.ConnectionID
var kp protocol.KeyPhaseBit
if (onlyAck && size == 0) || (!onlyAck && size < maxSize-protocol.MinCoalescedPacketSize) {
var err error
oneRTTSealer, err = p.cryptoSetup.Get1RTTSealer()
if err != nil && err != handshake.ErrKeysDropped && err != handshake.ErrKeysNotYetAvailable {
return nil, err
}
if err == nil { // 1-RTT
kp = oneRTTSealer.KeyPhase()
connID = p.getDestConnID()
oneRTTPacketNumber, oneRTTPacketNumberLen = p.pnManager.PeekPacketNumber(protocol.Encryption1RTT)
hdrLen := wire.ShortHeaderLen(connID, oneRTTPacketNumberLen)
oneRTTPayload = p.maybeGetShortHeaderPacket(oneRTTSealer, hdrLen, maxSize-size, onlyAck, size == 0, now, v)
if oneRTTPayload.length > 0 {
size += p.shortHeaderPacketLength(connID, oneRTTPacketNumberLen, oneRTTPayload) + protocol.ByteCount(oneRTTSealer.Overhead())
}
} else if p.perspective == protocol.PerspectiveClient && !onlyAck { // 0-RTT packets can't contain ACK frames
var err error
zeroRTTSealer, err = p.cryptoSetup.Get0RTTSealer()
if err != nil && err != handshake.ErrKeysDropped && err != handshake.ErrKeysNotYetAvailable {
return nil, err
}
if zeroRTTSealer != nil {
zeroRTTHdr, zeroRTTPayload = p.maybeGetAppDataPacketFor0RTT(zeroRTTSealer, maxSize-size, now, v)
if zeroRTTPayload.length > 0 {
size += p.longHeaderPacketLength(zeroRTTHdr, zeroRTTPayload, v) + protocol.ByteCount(zeroRTTSealer.Overhead())
}
}
}
}
if initialPayload.length == 0 && handshakePayload.length == 0 && zeroRTTPayload.length == 0 && oneRTTPayload.length == 0 {
return nil, nil
}
buffer := getPacketBuffer()
packet := &coalescedPacket{
buffer: buffer,
longHdrPackets: make([]*longHeaderPacket, 0, 3),
}
if initialPayload.length > 0 {
padding := p.initialPaddingLen(initialPayload.frames, size, maxSize)
cont, err := p.appendLongHeaderPacket(buffer, initialHdr, initialPayload, padding, protocol.EncryptionInitial, initialSealer, v)
if err != nil {
return nil, err
}
packet.longHdrPackets = append(packet.longHdrPackets, cont)
}
if handshakePayload.length > 0 {
cont, err := p.appendLongHeaderPacket(buffer, handshakeHdr, handshakePayload, 0, protocol.EncryptionHandshake, handshakeSealer, v)
if err != nil {
return nil, err
}
packet.longHdrPackets = append(packet.longHdrPackets, cont)
}
if zeroRTTPayload.length > 0 {
longHdrPacket, err := p.appendLongHeaderPacket(buffer, zeroRTTHdr, zeroRTTPayload, 0, protocol.Encryption0RTT, zeroRTTSealer, v)
if err != nil {
return nil, err
}
packet.longHdrPackets = append(packet.longHdrPackets, longHdrPacket)
} else if oneRTTPayload.length > 0 {
shp, err := p.appendShortHeaderPacket(buffer, connID, oneRTTPacketNumber, oneRTTPacketNumberLen, kp, oneRTTPayload, 0, maxSize, oneRTTSealer, false, v)
if err != nil {
return nil, err
}
packet.shortHdrPacket = &shp
}
return packet, nil
}
// PackAckOnlyPacket packs a packet containing only an ACK in the application data packet number space.
// It should be called after the handshake is confirmed.
func (p *packetPacker) PackAckOnlyPacket(maxSize protocol.ByteCount, now time.Time, v protocol.Version) (shortHeaderPacket, *packetBuffer, error) {
buf := getPacketBuffer()
packet, err := p.appendPacket(buf, true, maxSize, now, v)
return packet, buf, err
}
// AppendPacket packs a packet in the application data packet number space.
// It should be called after the handshake is confirmed.
func (p *packetPacker) AppendPacket(buf *packetBuffer, maxSize protocol.ByteCount, now time.Time, v protocol.Version) (shortHeaderPacket, error) {
return p.appendPacket(buf, false, maxSize, now, v)
}
func (p *packetPacker) appendPacket(
buf *packetBuffer,
onlyAck bool,
maxPacketSize protocol.ByteCount,
now time.Time,
v protocol.Version,
) (shortHeaderPacket, error) {
sealer, err := p.cryptoSetup.Get1RTTSealer()
if err != nil {
return shortHeaderPacket{}, err
}
pn, pnLen := p.pnManager.PeekPacketNumber(protocol.Encryption1RTT)
connID := p.getDestConnID()
hdrLen := wire.ShortHeaderLen(connID, pnLen)
pl := p.maybeGetShortHeaderPacket(sealer, hdrLen, maxPacketSize, onlyAck, true, now, v)
if pl.length == 0 {
return shortHeaderPacket{}, errNothingToPack
}
kp := sealer.KeyPhase()
return p.appendShortHeaderPacket(buf, connID, pn, pnLen, kp, pl, 0, maxPacketSize, sealer, false, v)
}
func (p *packetPacker) maybeGetCryptoPacket(
maxPacketSize protocol.ByteCount,
encLevel protocol.EncryptionLevel,
now time.Time,
addPingIfEmpty bool,
onlyAck, ackAllowed bool,
v protocol.Version,
) (*wire.ExtendedHeader, payload) {
if onlyAck {
if ack := p.acks.GetAckFrame(encLevel, now, true); ack != nil {
return p.getLongHeader(encLevel, v), payload{
ack: ack,
length: ack.Length(v),
}
}
return nil, payload{}
}
var hasCryptoData func() bool
var popCryptoFrame func(maxLen protocol.ByteCount) *wire.CryptoFrame
//nolint:exhaustive // Initial and Handshake are the only two encryption levels here.
switch encLevel {
case protocol.EncryptionInitial:
hasCryptoData = p.initialStream.HasData
popCryptoFrame = p.initialStream.PopCryptoFrame
case protocol.EncryptionHandshake:
hasCryptoData = p.handshakeStream.HasData
popCryptoFrame = p.handshakeStream.PopCryptoFrame
}
handler := p.retransmissionQueue.AckHandler(encLevel)
hasRetransmission := p.retransmissionQueue.HasData(encLevel)
var ack *wire.AckFrame
if ackAllowed {
ack = p.acks.GetAckFrame(encLevel, now, !hasRetransmission && !hasCryptoData())
}
var pl payload
if !hasCryptoData() && !hasRetransmission && ack == nil {
if !addPingIfEmpty {
// nothing to send
return nil, payload{}
}
ping := &wire.PingFrame{}
pl.frames = append(pl.frames, ackhandler.Frame{Frame: ping, Handler: emptyHandler{}})
pl.length += ping.Length(v)
}
if ack != nil {
pl.ack = ack
pl.length = ack.Length(v)
maxPacketSize -= pl.length
}
hdr := p.getLongHeader(encLevel, v)
maxPacketSize -= hdr.GetLength(v)
if hasRetransmission {
for {
frame := p.retransmissionQueue.GetFrame(encLevel, maxPacketSize, v)
if frame == nil {
break
}
pl.frames = append(pl.frames, ackhandler.Frame{
Frame: frame,
Handler: p.retransmissionQueue.AckHandler(encLevel),
})
frameLen := frame.Length(v)
pl.length += frameLen
maxPacketSize -= frameLen
}
return hdr, pl
} else {
for hasCryptoData() {
cf := popCryptoFrame(maxPacketSize)
if cf == nil {
break
}
pl.frames = append(pl.frames, ackhandler.Frame{Frame: cf, Handler: handler})
pl.length += cf.Length(v)
maxPacketSize -= cf.Length(v)
}
}
return hdr, pl
}
func (p *packetPacker) maybeGetAppDataPacketFor0RTT(sealer sealer, maxSize protocol.ByteCount, now time.Time, v protocol.Version) (*wire.ExtendedHeader, payload) {
if p.perspective != protocol.PerspectiveClient {
return nil, payload{}
}
hdr := p.getLongHeader(protocol.Encryption0RTT, v)
maxPayloadSize := maxSize - hdr.GetLength(v) - protocol.ByteCount(sealer.Overhead())
return hdr, p.maybeGetAppDataPacket(maxPayloadSize, false, false, now, v)
}
func (p *packetPacker) maybeGetShortHeaderPacket(
sealer handshake.ShortHeaderSealer,
hdrLen, maxPacketSize protocol.ByteCount,
onlyAck, ackAllowed bool,
now time.Time,
v protocol.Version,
) payload {
maxPayloadSize := maxPacketSize - hdrLen - protocol.ByteCount(sealer.Overhead())
return p.maybeGetAppDataPacket(maxPayloadSize, onlyAck, ackAllowed, now, v)
}
func (p *packetPacker) maybeGetAppDataPacket(
maxPayloadSize protocol.ByteCount,
onlyAck, ackAllowed bool,
now time.Time,
v protocol.Version,
) payload {
pl := p.composeNextPacket(maxPayloadSize, onlyAck, ackAllowed, now, v)
// check if we have anything to send
if len(pl.frames) == 0 && len(pl.streamFrames) == 0 {
if pl.ack == nil {
return payload{}
}
// the packet only contains an ACK
if p.numNonAckElicitingAcks >= protocol.MaxNonAckElicitingAcks {
ping := &wire.PingFrame{}
pl.frames = append(pl.frames, ackhandler.Frame{Frame: ping})
pl.length += ping.Length(v)
p.numNonAckElicitingAcks = 0
} else {
p.numNonAckElicitingAcks++
}
} else {
p.numNonAckElicitingAcks = 0
}
return pl
}
func (p *packetPacker) composeNextPacket(
maxPayloadSize protocol.ByteCount,
onlyAck, ackAllowed bool,
now time.Time,
v protocol.Version,
) payload {
if onlyAck {
if ack := p.acks.GetAckFrame(protocol.Encryption1RTT, now, true); ack != nil {
return payload{ack: ack, length: ack.Length(v)}
}
return payload{}
}
hasData := p.framer.HasData()
hasRetransmission := p.retransmissionQueue.HasData(protocol.Encryption1RTT)
var hasAck bool
var pl payload
if ackAllowed {
if ack := p.acks.GetAckFrame(protocol.Encryption1RTT, now, !hasRetransmission && !hasData); ack != nil {
pl.ack = ack
pl.length += ack.Length(v)
hasAck = true
}
}
if p.datagramQueue != nil {
if f := p.datagramQueue.Peek(); f != nil {
size := f.Length(v)
if size <= maxPayloadSize-pl.length { // DATAGRAM frame fits
pl.frames = append(pl.frames, ackhandler.Frame{Frame: f})
pl.length += size
p.datagramQueue.Pop()
} else if !hasAck {
// The DATAGRAM frame doesn't fit, and the packet doesn't contain an ACK.
// Discard this frame. There's no point in retrying this in the next packet,
// as it's unlikely that the available packet size will increase.
p.datagramQueue.Pop()
}
// If the DATAGRAM frame was too large and the packet contained an ACK, we'll try to send it out later.
}
}
if hasAck && !hasData && !hasRetransmission {
return pl
}
if hasRetransmission {
for {
remainingLen := maxPayloadSize - pl.length
if remainingLen < protocol.MinStreamFrameSize {
break
}
f := p.retransmissionQueue.GetFrame(protocol.Encryption1RTT, remainingLen, v)
if f == nil {
break
}
pl.frames = append(pl.frames, ackhandler.Frame{Frame: f, Handler: p.retransmissionQueue.AckHandler(protocol.Encryption1RTT)})
pl.length += f.Length(v)
}
}
if hasData {
var lengthAdded protocol.ByteCount
startLen := len(pl.frames)
pl.frames, pl.streamFrames, lengthAdded = p.framer.Append(pl.frames, pl.streamFrames, maxPayloadSize-pl.length, now, v)
pl.length += lengthAdded
// add handlers for the control frames that were added
for i := startLen; i < len(pl.frames); i++ {
if pl.frames[i].Handler != nil {
continue
}
switch pl.frames[i].Frame.(type) {
case *wire.PathChallengeFrame, *wire.PathResponseFrame:
// PATH_CHALLENGE and PATH_RESPONSE frames are never retransmitted,
// so they don't need an OnAcked / OnLost handler here.
default:
// we might be packing a 0-RTT packet, but we need to use the 1-RTT ack handler anyway
pl.frames[i].Handler = p.retransmissionQueue.AckHandler(protocol.Encryption1RTT)
}
}
}
return pl
}
func (p *packetPacker) PackPTOProbePacket(
encLevel protocol.EncryptionLevel,
maxPacketSize protocol.ByteCount,
addPingIfEmpty bool,
now time.Time,
v protocol.Version,
) (*coalescedPacket, error) {
if encLevel == protocol.Encryption1RTT {
return p.packPTOProbePacket1RTT(maxPacketSize, addPingIfEmpty, now, v)
}
var sealer handshake.LongHeaderSealer
//nolint:exhaustive // Probe packets are never sent for 0-RTT.
switch encLevel {
case protocol.EncryptionInitial:
var err error
sealer, err = p.cryptoSetup.GetInitialSealer()
if err != nil {
return nil, err
}
case protocol.EncryptionHandshake:
var err error
sealer, err = p.cryptoSetup.GetHandshakeSealer()
if err != nil {
return nil, err
}
default:
panic("unknown encryption level")
}
hdr, pl := p.maybeGetCryptoPacket(
maxPacketSize-protocol.ByteCount(sealer.Overhead()),
encLevel,
now,
addPingIfEmpty,
false,
true,
v,
)
if pl.length == 0 {
return nil, nil
}
buffer := getPacketBuffer()
packet := &coalescedPacket{buffer: buffer}
size := p.longHeaderPacketLength(hdr, pl, v) + protocol.ByteCount(sealer.Overhead())
var padding protocol.ByteCount
if encLevel == protocol.EncryptionInitial {
padding = p.initialPaddingLen(pl.frames, size, maxPacketSize)
}
longHdrPacket, err := p.appendLongHeaderPacket(buffer, hdr, pl, padding, encLevel, sealer, v)
if err != nil {
return nil, err
}
packet.longHdrPackets = []*longHeaderPacket{longHdrPacket}
return packet, nil
}
func (p *packetPacker) packPTOProbePacket1RTT(maxPacketSize protocol.ByteCount, addPingIfEmpty bool, now time.Time, v protocol.Version) (*coalescedPacket, error) {
s, err := p.cryptoSetup.Get1RTTSealer()
if err != nil {
return nil, err
}
kp := s.KeyPhase()
connID := p.getDestConnID()
pn, pnLen := p.pnManager.PeekPacketNumber(protocol.Encryption1RTT)
hdrLen := wire.ShortHeaderLen(connID, pnLen)
pl := p.maybeGetAppDataPacket(maxPacketSize-protocol.ByteCount(s.Overhead())-hdrLen, false, true, now, v)
if pl.length == 0 {
if !addPingIfEmpty {
return nil, nil
}
ping := &wire.PingFrame{}
pl.frames = append(pl.frames, ackhandler.Frame{Frame: ping, Handler: emptyHandler{}})
pl.length += ping.Length(v)
}
buffer := getPacketBuffer()
packet := &coalescedPacket{buffer: buffer}
shp, err := p.appendShortHeaderPacket(buffer, connID, pn, pnLen, kp, pl, 0, maxPacketSize, s, false, v)
if err != nil {
return nil, err
}
packet.shortHdrPacket = &shp
return packet, nil
}
func (p *packetPacker) PackMTUProbePacket(ping ackhandler.Frame, size protocol.ByteCount, v protocol.Version) (shortHeaderPacket, *packetBuffer, error) {
pl := payload{
frames: []ackhandler.Frame{ping},
length: ping.Frame.Length(v),
}
buffer := getPacketBuffer()
s, err := p.cryptoSetup.Get1RTTSealer()
if err != nil {
return shortHeaderPacket{}, nil, err
}
connID := p.getDestConnID()
pn, pnLen := p.pnManager.PeekPacketNumber(protocol.Encryption1RTT)
padding := size - p.shortHeaderPacketLength(connID, pnLen, pl) - protocol.ByteCount(s.Overhead())
kp := s.KeyPhase()
packet, err := p.appendShortHeaderPacket(buffer, connID, pn, pnLen, kp, pl, padding, size, s, true, v)
return packet, buffer, err
}
func (p *packetPacker) PackPathProbePacket(connID protocol.ConnectionID, frames []ackhandler.Frame, v protocol.Version) (shortHeaderPacket, *packetBuffer, error) {
pn, pnLen := p.pnManager.PeekPacketNumber(protocol.Encryption1RTT)
buf := getPacketBuffer()
s, err := p.cryptoSetup.Get1RTTSealer()
if err != nil {
return shortHeaderPacket{}, nil, err
}
var l protocol.ByteCount
for _, f := range frames {
l += f.Frame.Length(v)
}
payload := payload{
frames: frames,
length: l,
}
padding := protocol.MinInitialPacketSize - p.shortHeaderPacketLength(connID, pnLen, payload) - protocol.ByteCount(s.Overhead())
packet, err := p.appendShortHeaderPacket(buf, connID, pn, pnLen, s.KeyPhase(), payload, padding, protocol.MinInitialPacketSize, s, false, v)
if err != nil {
return shortHeaderPacket{}, nil, err
}
packet.IsPathProbePacket = true
return packet, buf, err
}
func (p *packetPacker) getLongHeader(encLevel protocol.EncryptionLevel, v protocol.Version) *wire.ExtendedHeader {
pn, pnLen := p.pnManager.PeekPacketNumber(encLevel)
hdr := &wire.ExtendedHeader{
PacketNumber: pn,
PacketNumberLen: pnLen,
}
hdr.Version = v
hdr.SrcConnectionID = p.srcConnID
hdr.DestConnectionID = p.getDestConnID()
//nolint:exhaustive // 1-RTT packets are not long header packets.
switch encLevel {
case protocol.EncryptionInitial:
hdr.Type = protocol.PacketTypeInitial
hdr.Token = p.token
case protocol.EncryptionHandshake:
hdr.Type = protocol.PacketTypeHandshake
case protocol.Encryption0RTT:
hdr.Type = protocol.PacketType0RTT
}
return hdr
}
func (p *packetPacker) appendLongHeaderPacket(buffer *packetBuffer, header *wire.ExtendedHeader, pl payload, padding protocol.ByteCount, encLevel protocol.EncryptionLevel, sealer sealer, v protocol.Version) (*longHeaderPacket, error) {
var paddingLen protocol.ByteCount
pnLen := protocol.ByteCount(header.PacketNumberLen)
if pl.length < 4-pnLen {
paddingLen = 4 - pnLen - pl.length
}
paddingLen += padding
header.Length = pnLen + protocol.ByteCount(sealer.Overhead()) + pl.length + paddingLen
startLen := len(buffer.Data)
raw := buffer.Data[startLen:]
raw, err := header.Append(raw, v)
if err != nil {
return nil, err
}
payloadOffset := protocol.ByteCount(len(raw))
raw, err = p.appendPacketPayload(raw, pl, paddingLen, v)
if err != nil {
return nil, err
}
raw = p.encryptPacket(raw, sealer, header.PacketNumber, payloadOffset, pnLen)
buffer.Data = buffer.Data[:len(buffer.Data)+len(raw)]
if pn := p.pnManager.PopPacketNumber(encLevel); pn != header.PacketNumber {
return nil, fmt.Errorf("packetPacker BUG: Peeked and Popped packet numbers do not match: expected %d, got %d", pn, header.PacketNumber)
}
return &longHeaderPacket{
header: header,
ack: pl.ack,
frames: pl.frames,
streamFrames: pl.streamFrames,
length: protocol.ByteCount(len(raw)),
}, nil
}
func (p *packetPacker) appendShortHeaderPacket(
buffer *packetBuffer,
connID protocol.ConnectionID,
pn protocol.PacketNumber,
pnLen protocol.PacketNumberLen,
kp protocol.KeyPhaseBit,
pl payload,
padding, maxPacketSize protocol.ByteCount,
sealer sealer,
isMTUProbePacket bool,
v protocol.Version,
) (shortHeaderPacket, error) {
var paddingLen protocol.ByteCount
if pl.length < 4-protocol.ByteCount(pnLen) {
paddingLen = 4 - protocol.ByteCount(pnLen) - pl.length
}
paddingLen += padding
startLen := len(buffer.Data)
raw := buffer.Data[startLen:]
raw, err := wire.AppendShortHeader(raw, connID, pn, pnLen, kp)
if err != nil {
return shortHeaderPacket{}, err
}
payloadOffset := protocol.ByteCount(len(raw))
raw, err = p.appendPacketPayload(raw, pl, paddingLen, v)
if err != nil {
return shortHeaderPacket{}, err
}
if !isMTUProbePacket {
if size := protocol.ByteCount(len(raw) + sealer.Overhead()); size > maxPacketSize {
return shortHeaderPacket{}, fmt.Errorf("PacketPacker BUG: packet too large (%d bytes, allowed %d bytes)", size, maxPacketSize)
}
}
raw = p.encryptPacket(raw, sealer, pn, payloadOffset, protocol.ByteCount(pnLen))
buffer.Data = buffer.Data[:len(buffer.Data)+len(raw)]
if newPN := p.pnManager.PopPacketNumber(protocol.Encryption1RTT); newPN != pn {
return shortHeaderPacket{}, fmt.Errorf("packetPacker BUG: Peeked and Popped packet numbers do not match: expected %d, got %d", pn, newPN)
}
return shortHeaderPacket{
PacketNumber: pn,
PacketNumberLen: pnLen,
KeyPhase: kp,
StreamFrames: pl.streamFrames,
Frames: pl.frames,
Ack: pl.ack,
Length: protocol.ByteCount(len(raw)),
DestConnID: connID,
IsPathMTUProbePacket: isMTUProbePacket,
}, nil
}
// appendPacketPayload serializes the payload of a packet into the raw byte slice.
// It modifies the order of payload.frames.
func (p *packetPacker) appendPacketPayload(raw []byte, pl payload, paddingLen protocol.ByteCount, v protocol.Version) ([]byte, error) {
payloadOffset := len(raw)
if pl.ack != nil {
var err error
raw, err = pl.ack.Append(raw, v)
if err != nil {
return nil, err
}
}
if paddingLen > 0 {
raw = append(raw, make([]byte, paddingLen)...)
}
// Randomize the order of the control frames.
// This makes sure that the receiver doesn't rely on the order in which frames are packed.
if len(pl.frames) > 1 {
p.rand.Shuffle(len(pl.frames), func(i, j int) { pl.frames[i], pl.frames[j] = pl.frames[j], pl.frames[i] })
}
for _, f := range pl.frames {
var err error
raw, err = f.Frame.Append(raw, v)
if err != nil {
return nil, err
}
}
for _, f := range pl.streamFrames {
var err error
raw, err = f.Frame.Append(raw, v)
if err != nil {
return nil, err
}
}
if payloadSize := protocol.ByteCount(len(raw)-payloadOffset) - paddingLen; payloadSize != pl.length {
return nil, fmt.Errorf("PacketPacker BUG: payload size inconsistent (expected %d, got %d bytes)", pl.length, payloadSize)
}
return raw, nil
}
func (p *packetPacker) encryptPacket(raw []byte, sealer sealer, pn protocol.PacketNumber, payloadOffset, pnLen protocol.ByteCount) []byte {
_ = sealer.Seal(raw[payloadOffset:payloadOffset], raw[payloadOffset:], pn, raw[:payloadOffset])
raw = raw[:len(raw)+sealer.Overhead()]
// apply header protection
pnOffset := payloadOffset - pnLen
sealer.EncryptHeader(raw[pnOffset+4:pnOffset+4+16], &raw[0], raw[pnOffset:payloadOffset])
return raw
}
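// Memory layout operated on by encryptPacket (illustrative):
//
//	raw: | header fields | packet number | payload ............. | AEAD tag |
//	                     ^ pnOffset      ^ payloadOffset
//
// The payload is sealed in place, with everything up to payloadOffset used as associated data.
// Header protection then samples 16 bytes of ciphertext starting 4 bytes after pnOffset and
// uses them to mask the first byte and the packet number bytes.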
func (p *packetPacker) SetToken(token []byte) {
p.token = token
}
type emptyHandler struct{}
var _ ackhandler.FrameHandler = emptyHandler{}
func (emptyHandler) OnAcked(wire.Frame) {}
func (emptyHandler) OnLost(wire.Frame) {}
package quic
import (
"fmt"
"time"
"github.com/quic-go/quic-go/internal/handshake"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qerr"
"github.com/quic-go/quic-go/internal/wire"
)
type headerDecryptor interface {
DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte)
}
type headerParseError struct {
err error
}
func (e *headerParseError) Unwrap() error {
return e.err
}
func (e *headerParseError) Error() string {
return e.err.Error()
}
type unpackedPacket struct {
hdr *wire.ExtendedHeader
encryptionLevel protocol.EncryptionLevel
data []byte
}
// The packetUnpacker unpacks QUIC packets.
type packetUnpacker struct {
cs handshake.CryptoSetup
shortHdrConnIDLen int
}
var _ unpacker = &packetUnpacker{}
func newPacketUnpacker(cs handshake.CryptoSetup, shortHdrConnIDLen int) *packetUnpacker {
return &packetUnpacker{
cs: cs,
shortHdrConnIDLen: shortHdrConnIDLen,
}
}
// UnpackLongHeader unpacks a Long Header packet.
// If the reserved bits are invalid, the error is wire.ErrInvalidReservedBits.
// If any other error occurred when parsing the header, the error is of type headerParseError.
// If decrypting the payload fails for any reason, the error is the error returned by the AEAD.
func (u *packetUnpacker) UnpackLongHeader(hdr *wire.Header, data []byte) (*unpackedPacket, error) {
var encLevel protocol.EncryptionLevel
var extHdr *wire.ExtendedHeader
var decrypted []byte
//nolint:exhaustive // Retry packets can't be unpacked.
switch hdr.Type {
case protocol.PacketTypeInitial:
encLevel = protocol.EncryptionInitial
opener, err := u.cs.GetInitialOpener()
if err != nil {
return nil, err
}
extHdr, decrypted, err = u.unpackLongHeaderPacket(opener, hdr, data)
if err != nil {
return nil, err
}
case protocol.PacketTypeHandshake:
encLevel = protocol.EncryptionHandshake
opener, err := u.cs.GetHandshakeOpener()
if err != nil {
return nil, err
}
extHdr, decrypted, err = u.unpackLongHeaderPacket(opener, hdr, data)
if err != nil {
return nil, err
}
case protocol.PacketType0RTT:
encLevel = protocol.Encryption0RTT
opener, err := u.cs.Get0RTTOpener()
if err != nil {
return nil, err
}
extHdr, decrypted, err = u.unpackLongHeaderPacket(opener, hdr, data)
if err != nil {
return nil, err
}
default:
return nil, fmt.Errorf("unknown packet type: %s", hdr.Type)
}
if len(decrypted) == 0 {
return nil, &qerr.TransportError{
ErrorCode: qerr.ProtocolViolation,
ErrorMessage: "empty packet",
}
}
return &unpackedPacket{
hdr: extHdr,
encryptionLevel: encLevel,
data: decrypted,
}, nil
}
func (u *packetUnpacker) UnpackShortHeader(rcvTime time.Time, data []byte) (protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, []byte, error) {
opener, err := u.cs.Get1RTTOpener()
if err != nil {
return 0, 0, 0, nil, err
}
pn, pnLen, kp, decrypted, err := u.unpackShortHeaderPacket(opener, rcvTime, data)
if err != nil {
return 0, 0, 0, nil, err
}
if len(decrypted) == 0 {
return 0, 0, 0, nil, &qerr.TransportError{
ErrorCode: qerr.ProtocolViolation,
ErrorMessage: "empty packet",
}
}
return pn, pnLen, kp, decrypted, nil
}
func (u *packetUnpacker) unpackLongHeaderPacket(opener handshake.LongHeaderOpener, hdr *wire.Header, data []byte) (*wire.ExtendedHeader, []byte, error) {
extHdr, parseErr := u.unpackLongHeader(opener, hdr, data)
// If the reserved bits are set incorrectly, we still need to continue unpacking.
// This avoids a timing side-channel, which otherwise might allow an attacker
// to gain information about the header encryption.
if parseErr != nil && parseErr != wire.ErrInvalidReservedBits {
return nil, nil, parseErr
}
extHdrLen := extHdr.ParsedLen()
extHdr.PacketNumber = opener.DecodePacketNumber(extHdr.PacketNumber, extHdr.PacketNumberLen)
decrypted, err := opener.Open(data[extHdrLen:extHdrLen], data[extHdrLen:], extHdr.PacketNumber, data[:extHdrLen])
if err != nil {
return nil, nil, err
}
if parseErr != nil {
return nil, nil, parseErr
}
return extHdr, decrypted, nil
}
func (u *packetUnpacker) unpackShortHeaderPacket(opener handshake.ShortHeaderOpener, rcvTime time.Time, data []byte) (protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, []byte, error) {
l, pn, pnLen, kp, parseErr := u.unpackShortHeader(opener, data)
// If the reserved bits are set incorrectly, we still need to continue unpacking.
// This avoids a timing side-channel, which otherwise might allow an attacker
// to gain information about the header encryption.
if parseErr != nil && parseErr != wire.ErrInvalidReservedBits {
return 0, 0, 0, nil, &headerParseError{parseErr}
}
pn = opener.DecodePacketNumber(pn, pnLen)
decrypted, err := opener.Open(data[l:l], data[l:], rcvTime, pn, kp, data[:l])
if err != nil {
return 0, 0, 0, nil, err
}
return pn, pnLen, kp, decrypted, parseErr
}
func (u *packetUnpacker) unpackShortHeader(hd headerDecryptor, data []byte) (int, protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, error) {
hdrLen := 1 /* first header byte */ + u.shortHdrConnIDLen
if len(data) < hdrLen+4+16 {
return 0, 0, 0, 0, fmt.Errorf("packet too small, expected at least 20 bytes after the header, got %d", len(data)-hdrLen)
}
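// 1. save a copy of the 4 bytes following the header
// (the packet number can be up to 4 bytes long; its actual length is only learned in step 3,
// and any of these bytes that turn out to belong to the payload are restored in step 4)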
origPNBytes := make([]byte, 4)
copy(origPNBytes, data[hdrLen:hdrLen+4])
// 2. decrypt the header, assuming a 4 byte packet number
hd.DecryptHeader(
data[hdrLen+4:hdrLen+4+16],
&data[0],
data[hdrLen:hdrLen+4],
)
// 3. parse the header (and learn the actual length of the packet number)
l, pn, pnLen, kp, parseErr := wire.ParseShortHeader(data, u.shortHdrConnIDLen)
if parseErr != nil && parseErr != wire.ErrInvalidReservedBits {
return l, pn, pnLen, kp, parseErr
}
// 4. if the packet number is shorter than 4 bytes, replace the remaining bytes with the copy we saved earlier
if pnLen != protocol.PacketNumberLen4 {
copy(data[hdrLen+int(pnLen):hdrLen+4], origPNBytes[int(pnLen):])
}
return l, pn, pnLen, kp, parseErr
}
// The error is either nil, a wire.ErrInvalidReservedBits or of type headerParseError.
func (u *packetUnpacker) unpackLongHeader(hd headerDecryptor, hdr *wire.Header, data []byte) (*wire.ExtendedHeader, error) {
extHdr, err := unpackLongHeader(hd, hdr, data)
if err != nil && err != wire.ErrInvalidReservedBits {
return nil, &headerParseError{err: err}
}
return extHdr, err
}
func unpackLongHeader(hd headerDecryptor, hdr *wire.Header, data []byte) (*wire.ExtendedHeader, error) {
hdrLen := hdr.ParsedLen()
if protocol.ByteCount(len(data)) < hdrLen+4+16 {
return nil, fmt.Errorf("packet too small, expected at least 20 bytes after the header, got %d", protocol.ByteCount(len(data))-hdrLen)
}
// The packet number can be up to 4 bytes long, but we won't know the length until we decrypt it.
// 1. save a copy of the 4 bytes
origPNBytes := make([]byte, 4)
copy(origPNBytes, data[hdrLen:hdrLen+4])
// 2. decrypt the header, assuming a 4 byte packet number
hd.DecryptHeader(
data[hdrLen+4:hdrLen+4+16],
&data[0],
data[hdrLen:hdrLen+4],
)
// 3. parse the header (and learn the actual length of the packet number)
extHdr, parseErr := hdr.ParseExtended(data)
if parseErr != nil && parseErr != wire.ErrInvalidReservedBits {
return nil, parseErr
}
// 4. if the packet number is shorter than 4 bytes, replace the remaining bytes with the copy we saved earlier
if extHdr.PacketNumberLen != protocol.PacketNumberLen4 {
copy(data[extHdr.ParsedLen():hdrLen+4], origPNBytes[int(extHdr.PacketNumberLen):])
}
return extHdr, parseErr
}
package quic
import (
"crypto/rand"
"net"
"slices"
"time"
"github.com/quic-go/quic-go/internal/ackhandler"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/internal/wire"
)
type pathID int64
const invalidPathID pathID = -1
// Maximum number of paths to keep track of.
// If the peer probes another path (before the pathTimeout of an existing path expires),
// this probing attempt is ignored.
const maxPaths = 3
// If no packet is received for a path for pathTimeout,
// the path can be evicted when the peer probes another path.
// This prevents an attacker from churning through paths by duplicating packets and
// sending them with spoofed source addresses.
const pathTimeout = 5 * time.Second
type path struct {
id pathID
addr net.Addr
lastPacketTime time.Time
pathChallenge [8]byte
validated bool
rcvdNonProbing bool
}
type pathManager struct {
nextPathID pathID
// ordered by lastPacketTime, with the most recently used path at the end
paths []*path
getConnID func(pathID) (_ protocol.ConnectionID, ok bool)
retireConnID func(pathID)
logger utils.Logger
}
func newPathManager(
getConnID func(pathID) (_ protocol.ConnectionID, ok bool),
retireConnID func(pathID),
logger utils.Logger,
) *pathManager {
return &pathManager{
paths: make([]*path, 0, maxPaths+1),
getConnID: getConnID,
retireConnID: retireConnID,
logger: logger,
}
}
// HandlePacket returns the frames that should be sent on this path, if any: a PATH_CHALLENGE
// when probing a new path, and a PATH_RESPONSE when the received packet contained a PATH_CHALLENGE.
// The returned slice may be nil.
func (pm *pathManager) HandlePacket(
remoteAddr net.Addr,
t time.Time,
pathChallenge *wire.PathChallengeFrame, // may be nil if the packet didn't contain a PATH_CHALLENGE
isNonProbing bool,
) (_ protocol.ConnectionID, _ []ackhandler.Frame, shouldSwitch bool) {
var p *path
for i, path := range pm.paths {
if addrsEqual(path.addr, remoteAddr) {
p = path
p.lastPacketTime = t
// already sent a PATH_CHALLENGE for this path
if isNonProbing {
path.rcvdNonProbing = true
}
if pm.logger.Debug() {
pm.logger.Debugf("received packet for path %s that was already probed, validated: %t", remoteAddr, path.validated)
}
shouldSwitch = path.validated && path.rcvdNonProbing
if i != len(pm.paths)-1 {
// move the path to the end of the list
pm.paths = slices.Delete(pm.paths, i, i+1)
pm.paths = append(pm.paths, p)
}
if pathChallenge == nil {
return protocol.ConnectionID{}, nil, shouldSwitch
}
}
}
if len(pm.paths) >= maxPaths {
if pm.paths[0].lastPacketTime.Add(pathTimeout).After(t) {
if pm.logger.Debug() {
pm.logger.Debugf("received packet for previously unseen path %s, but already have %d paths", remoteAddr, len(pm.paths))
}
return protocol.ConnectionID{}, nil, shouldSwitch
}
// evict the oldest path, if the last packet was received more than pathTimeout ago
pm.retireConnID(pm.paths[0].id)
pm.paths = pm.paths[1:]
}
var pathID pathID
if p != nil {
pathID = p.id
} else {
pathID = pm.nextPathID
}
// previously unseen path, initiate path validation by sending a PATH_CHALLENGE
connID, ok := pm.getConnID(pathID)
if !ok {
pm.logger.Debugf("skipping validation of new path %s since no connection ID is available", remoteAddr)
return protocol.ConnectionID{}, nil, shouldSwitch
}
frames := make([]ackhandler.Frame, 0, 2)
if p == nil {
var pathChallengeData [8]byte
rand.Read(pathChallengeData[:])
p = &path{
id: pm.nextPathID,
addr: remoteAddr,
lastPacketTime: t,
rcvdNonProbing: isNonProbing,
pathChallenge: pathChallengeData,
}
pm.nextPathID++
pm.paths = append(pm.paths, p)
frames = append(frames, ackhandler.Frame{
Frame: &wire.PathChallengeFrame{Data: p.pathChallenge},
Handler: (*pathManagerAckHandler)(pm),
})
pm.logger.Debugf("enqueueing PATH_CHALLENGE for new path %s", remoteAddr)
}
if pathChallenge != nil {
frames = append(frames, ackhandler.Frame{
Frame: &wire.PathResponseFrame{Data: pathChallenge.Data},
Handler: (*pathManagerAckHandler)(pm),
})
}
return connID, frames, shouldSwitch
}
func (pm *pathManager) HandlePathResponseFrame(f *wire.PathResponseFrame) {
for _, p := range pm.paths {
if f.Data == p.pathChallenge {
// path validated
p.validated = true
pm.logger.Debugf("path %s validated", p.addr)
break
}
}
}
// SwitchToPath is called when the connection switches to a new path
func (pm *pathManager) SwitchToPath(addr net.Addr) {
// retire all other paths
for _, path := range pm.paths {
if addrsEqual(path.addr, addr) {
pm.logger.Debugf("switching to path %d (%s)", path.id, addr)
continue
}
pm.retireConnID(path.id)
}
clear(pm.paths)
pm.paths = pm.paths[:0]
}
type pathManagerAckHandler pathManager
var _ ackhandler.FrameHandler = &pathManagerAckHandler{}
// Acknowledging the frame doesn't validate the path, only receiving the PATH_RESPONSE does.
func (pm *pathManagerAckHandler) OnAcked(f wire.Frame) {}
func (pm *pathManagerAckHandler) OnLost(f wire.Frame) {
pc, ok := f.(*wire.PathChallengeFrame)
if !ok {
return
}
for i, path := range pm.paths {
if path.pathChallenge == pc.Data {
pm.paths = slices.Delete(pm.paths, i, i+1)
pm.retireConnID(path.id)
break
}
}
}
func addrsEqual(addr1, addr2 net.Addr) bool {
if addr1 == nil || addr2 == nil {
return false
}
a1, ok1 := addr1.(*net.UDPAddr)
a2, ok2 := addr2.(*net.UDPAddr)
if ok1 && ok2 {
return a1.IP.Equal(a2.IP) && a1.Port == a2.Port
}
return addr1.String() == addr2.String()
}
package quic
import (
"context"
"crypto/rand"
"errors"
"slices"
"sync"
"sync/atomic"
"time"
"github.com/quic-go/quic-go/internal/ackhandler"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/wire"
)
var (
// ErrPathClosed is returned when trying to switch to a path that has been closed.
ErrPathClosed = errors.New("path closed")
// ErrPathNotValidated is returned when trying to use a path before path probing has completed.
ErrPathNotValidated = errors.New("path not yet validated")
)
var errPathDoesNotExist = errors.New("path does not exist")
// Path is a network path.
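//
// Illustrative usage, assuming a *Path obtained from the connection's path API:
//
//	if err := path.Probe(ctx); err != nil {
//		// the path could not be validated (closed, or context cancelled)
//	}
//	if err := path.Switch(); err != nil {
//		// e.g. ErrPathNotValidated or ErrPathClosed
//	}
//
// A path that is no longer needed can be abandoned by calling Close, as long as it is not
// the currently active path.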
type Path struct {
id pathID
pathManager *pathManagerOutgoing
tr *Transport
initialRTT time.Duration
enablePath func()
validated atomic.Bool
abandon chan struct{}
}
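// Probe requests that PATH_CHALLENGE frames be sent on this path, until the path is validated
// by a matching PATH_RESPONSE, the path is closed, or the context is cancelled.
// Probes are retransmitted with exponential backoff: for example (illustrative numbers), with an
// initial RTT of 100ms, retransmissions are scheduled after roughly 100ms, 200ms, 400ms, and so on.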
func (p *Path) Probe(ctx context.Context) error {
path := p.pathManager.addPath(p, p.enablePath)
p.pathManager.enqueueProbe(p)
nextProbeDur := p.initialRTT
var timer *time.Timer
var timerChan <-chan time.Time
for {
select {
case <-ctx.Done():
return context.Cause(ctx)
case <-path.Validated():
p.validated.Store(true)
return nil
case <-timerChan:
nextProbeDur *= 2 // exponential backoff
p.pathManager.enqueueProbe(p)
case <-path.ProbeSent():
case <-p.abandon:
return ErrPathClosed
}
if timer != nil {
timer.Stop()
}
timer = time.NewTimer(nextProbeDur)
timerChan = timer.C
}
}
// Switch switches the QUIC connection to this path.
// It immediately stops sending on the old path, and sends on this new path.
func (p *Path) Switch() error {
if err := p.pathManager.switchToPath(p.id); err != nil {
switch {
case errors.Is(err, ErrPathNotValidated):
return err
case errors.Is(err, errPathDoesNotExist) && !p.validated.Load():
select {
case <-p.abandon:
return ErrPathClosed
default:
return ErrPathNotValidated
}
default:
return ErrPathClosed
}
}
return nil
}
// Close abandons a path.
// It is not possible to close the path that’s currently active.
// After closing, it is not possible to probe this path again.
func (p *Path) Close() error {
select {
case <-p.abandon:
return nil
default:
}
if err := p.pathManager.removePath(p.id); err != nil {
return err
}
close(p.abandon)
return nil
}
type pathOutgoing struct {
pathChallenges [][8]byte // length is implicitly limited by exponential backoff
tr *Transport
isValidated bool
probeSent chan struct{} // receives when a PATH_CHALLENGE is sent
validated chan struct{} // closed when the corresponding PATH_RESPONSE is received, i.e. when the path is validated
enablePath func()
}
func (p *pathOutgoing) ProbeSent() <-chan struct{} { return p.probeSent }
func (p *pathOutgoing) Validated() <-chan struct{} { return p.validated }
type pathManagerOutgoing struct {
getConnID func(pathID) (_ protocol.ConnectionID, ok bool)
retireConnID func(pathID)
scheduleSending func()
mx sync.Mutex
activePath pathID
pathsToProbe []pathID
paths map[pathID]*pathOutgoing
nextPathID pathID
pathToSwitchTo *pathOutgoing
}
func newPathManagerOutgoing(
getConnID func(pathID) (_ protocol.ConnectionID, ok bool),
retireConnID func(pathID),
scheduleSending func(),
) *pathManagerOutgoing {
return &pathManagerOutgoing{
activePath: 0, // at initialization time, we're guaranteed to be using the handshake path
nextPathID: 1,
getConnID: getConnID,
retireConnID: retireConnID,
scheduleSending: scheduleSending,
paths: make(map[pathID]*pathOutgoing, 4),
}
}
func (pm *pathManagerOutgoing) addPath(p *Path, enablePath func()) *pathOutgoing {
pm.mx.Lock()
defer pm.mx.Unlock()
// the path might already exist, and just be re-probed
if existingPath, ok := pm.paths[p.id]; ok {
existingPath.validated = make(chan struct{})
return existingPath
}
path := &pathOutgoing{
tr: p.tr,
probeSent: make(chan struct{}, 1),
validated: make(chan struct{}),
enablePath: enablePath,
}
pm.paths[p.id] = path
return path
}
func (pm *pathManagerOutgoing) enqueueProbe(p *Path) {
pm.mx.Lock()
pm.pathsToProbe = append(pm.pathsToProbe, p.id)
pm.mx.Unlock()
pm.scheduleSending()
}
func (pm *pathManagerOutgoing) removePath(id pathID) error {
if err := pm.removePathImpl(id); err != nil {
return err
}
pm.scheduleSending()
return nil
}
func (pm *pathManagerOutgoing) removePathImpl(id pathID) error {
pm.mx.Lock()
defer pm.mx.Unlock()
if id == pm.activePath {
return errors.New("cannot close active path")
}
p, ok := pm.paths[id]
if !ok {
return nil
}
if len(p.pathChallenges) > 0 {
pm.retireConnID(id)
}
delete(pm.paths, id)
return nil
}
func (pm *pathManagerOutgoing) switchToPath(id pathID) error {
pm.mx.Lock()
defer pm.mx.Unlock()
p, ok := pm.paths[id]
if !ok {
return errPathDoesNotExist
}
if !p.isValidated {
return ErrPathNotValidated
}
pm.pathToSwitchTo = p
pm.activePath = id
return nil
}
func (pm *pathManagerOutgoing) NewPath(t *Transport, initialRTT time.Duration, enablePath func()) *Path {
pm.mx.Lock()
defer pm.mx.Unlock()
id := pm.nextPathID
pm.nextPathID++
return &Path{
pathManager: pm,
id: id,
tr: t,
enablePath: enablePath,
initialRTT: initialRTT,
abandon: make(chan struct{}),
}
}
func (pm *pathManagerOutgoing) NextPathToProbe() (_ protocol.ConnectionID, _ ackhandler.Frame, _ *Transport, hasPath bool) {
pm.mx.Lock()
defer pm.mx.Unlock()
var p *pathOutgoing
id := invalidPathID
for _, pID := range pm.pathsToProbe {
var ok bool
p, ok = pm.paths[pID]
if ok {
id = pID
break
}
// if the path doesn't exist in the map, it might have been abandoned
pm.pathsToProbe = pm.pathsToProbe[1:]
}
if id == invalidPathID {
return protocol.ConnectionID{}, ackhandler.Frame{}, nil, false
}
connID, ok := pm.getConnID(id)
if !ok {
return protocol.ConnectionID{}, ackhandler.Frame{}, nil, false
}
var b [8]byte
_, _ = rand.Read(b[:])
p.pathChallenges = append(p.pathChallenges, b)
pm.pathsToProbe = pm.pathsToProbe[1:]
p.enablePath()
select {
case p.probeSent <- struct{}{}:
default:
}
frame := ackhandler.Frame{
Frame: &wire.PathChallengeFrame{Data: b},
Handler: (*pathManagerOutgoingAckHandler)(pm),
}
return connID, frame, p.tr, true
}
func (pm *pathManagerOutgoing) HandlePathResponseFrame(f *wire.PathResponseFrame) {
pm.mx.Lock()
defer pm.mx.Unlock()
for _, p := range pm.paths {
if slices.Contains(p.pathChallenges, f.Data) {
// path validated
if !p.isValidated {
// make sure that duplicate PATH_RESPONSE frames are ignored
p.isValidated = true
p.pathChallenges = nil
close(p.validated)
}
break
}
}
}
func (pm *pathManagerOutgoing) ShouldSwitchPath() (*Transport, bool) {
pm.mx.Lock()
defer pm.mx.Unlock()
if pm.pathToSwitchTo == nil {
return nil, false
}
p := pm.pathToSwitchTo
pm.pathToSwitchTo = nil
return p.tr, true
}
type pathManagerOutgoingAckHandler pathManagerOutgoing
var _ ackhandler.FrameHandler = &pathManagerOutgoingAckHandler{}
// OnAcked is called when the PATH_CHALLENGE is acked.
// This doesn't validate the path, only receiving the PATH_RESPONSE does.
func (pm *pathManagerOutgoingAckHandler) OnAcked(wire.Frame) {}
func (pm *pathManagerOutgoingAckHandler) OnLost(wire.Frame) {}
package quicvarint
import (
"bytes"
"io"
)
// Reader implements both the io.ByteReader and io.Reader interfaces.
type Reader interface {
io.ByteReader
io.Reader
}
var _ Reader = &bytes.Reader{}
type byteReader struct {
io.Reader
}
var _ Reader = &byteReader{}
// NewReader returns a Reader for r.
// If r already implements both io.ByteReader and io.Reader, NewReader returns r.
// Otherwise, r is wrapped to add the missing interfaces.
func NewReader(r io.Reader) Reader {
if r, ok := r.(Reader); ok {
return r
}
return &byteReader{r}
}
func (r *byteReader) ReadByte() (byte, error) {
var b [1]byte
n, err := r.Read(b[:])
if n == 1 && err == io.EOF {
err = nil
}
return b[0], err
}
// Writer implements both the io.ByteWriter and io.Writer interfaces.
type Writer interface {
io.ByteWriter
io.Writer
}
var _ Writer = &bytes.Buffer{}
type byteWriter struct {
io.Writer
}
var _ Writer = &byteWriter{}
// NewWriter returns a Writer for w.
// If w already implements both io.ByteWriter and io.Writer, NewWriter returns w.
// Otherwise, w is wrapped to add the missing interfaces.
func NewWriter(w io.Writer) Writer {
if w, ok := w.(Writer); ok {
return w
}
return &byteWriter{w}
}
func (w *byteWriter) WriteByte(c byte) error {
_, err := w.Write([]byte{c})
return err
}
package quicvarint
import (
"encoding/binary"
"fmt"
"io"
)
// taken from RFC 9000
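// The two most significant bits of the first byte select the encoding:
//
//	prefix 0b00: 1 byte,  values up to 2^6-1
//	prefix 0b01: 2 bytes, values up to 2^14-1
//	prefix 0b10: 4 bytes, values up to 2^30-1
//	prefix 0b11: 8 bytes, values up to 2^62-1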
const (
// Min is the minimum value allowed for a QUIC varint.
Min = 0
// Max is the maximum allowed value for a QUIC varint (2^62-1).
Max = maxVarInt8
maxVarInt1 = 63
maxVarInt2 = 16383
maxVarInt4 = 1073741823
maxVarInt8 = 4611686018427387903
)
// Read reads a number in the QUIC varint format from r.
func Read(r io.ByteReader) (uint64, error) {
firstByte, err := r.ReadByte()
if err != nil {
return 0, err
}
// the first two bits of the first byte encode the length
l := 1 << ((firstByte & 0xc0) >> 6)
b1 := firstByte & (0xff - 0xc0)
if l == 1 {
return uint64(b1), nil
}
b2, err := r.ReadByte()
if err != nil {
return 0, err
}
if l == 2 {
return uint64(b2) + uint64(b1)<<8, nil
}
b3, err := r.ReadByte()
if err != nil {
return 0, err
}
b4, err := r.ReadByte()
if err != nil {
return 0, err
}
if l == 4 {
return uint64(b4) + uint64(b3)<<8 + uint64(b2)<<16 + uint64(b1)<<24, nil
}
b5, err := r.ReadByte()
if err != nil {
return 0, err
}
b6, err := r.ReadByte()
if err != nil {
return 0, err
}
b7, err := r.ReadByte()
if err != nil {
return 0, err
}
b8, err := r.ReadByte()
if err != nil {
return 0, err
}
return uint64(b8) + uint64(b7)<<8 + uint64(b6)<<16 + uint64(b5)<<24 + uint64(b4)<<32 + uint64(b3)<<40 + uint64(b2)<<48 + uint64(b1)<<56, nil
}
// Parse reads a number in the QUIC varint format.
// It returns the number of bytes consumed.
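// For example, Parse([]byte{0x7b, 0xbd}) returns (15293, 2, nil).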
func Parse(b []byte) (uint64 /* value */, int /* bytes consumed */, error) {
if len(b) == 0 {
return 0, 0, io.EOF
}
first := b[0]
switch first >> 6 {
case 0: // 1-byte encoding: 00xxxxxx
return uint64(first & 0b00111111), 1, nil
case 1: // 2-byte encoding: 01xxxxxx
if len(b) < 2 {
return 0, 0, io.ErrUnexpectedEOF
}
return uint64(b[1]) | uint64(first&0b00111111)<<8, 2, nil
case 2: // 4-byte encoding: 10xxxxxx
if len(b) < 4 {
return 0, 0, io.ErrUnexpectedEOF
}
return uint64(b[3]) | uint64(b[2])<<8 | uint64(b[1])<<16 | uint64(first&0b00111111)<<24, 4, nil
case 3: // 8-byte encoding: 11xxxxxx
if len(b) < 8 {
return 0, 0, io.ErrUnexpectedEOF
}
// binary.BigEndian.Uint64 only reads the first 8 bytes. Passing the full slice avoids slicing overhead.
return binary.BigEndian.Uint64(b) & 0x3fffffffffffffff, 8, nil
}
panic("unreachable")
}
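// In contrast to Read, Parse operates on a byte slice and also reports how many
// bytes were consumed. A sketch using the RFC 9000, Appendix A.1 test vector:
//
//	v, n, err := Parse([]byte{0x9d, 0x7f, 0x3e, 0x7d})
//	// v == 494878333, n == 4, err == nil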
// Append appends i in the QUIC varint format.
func Append(b []byte, i uint64) []byte {
if i <= maxVarInt1 {
return append(b, uint8(i))
}
if i <= maxVarInt2 {
return append(b, []byte{uint8(i>>8) | 0x40, uint8(i)}...)
}
if i <= maxVarInt4 {
return append(b, []byte{uint8(i>>24) | 0x80, uint8(i >> 16), uint8(i >> 8), uint8(i)}...)
}
if i <= maxVarInt8 {
return append(b, []byte{
uint8(i>>56) | 0xc0, uint8(i >> 48), uint8(i >> 40), uint8(i >> 32),
uint8(i >> 24), uint8(i >> 16), uint8(i >> 8), uint8(i),
}...)
}
panic(fmt.Sprintf("%#x doesn't fit into 62 bits", i))
}
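// Append always picks the shortest possible encoding, e.g.:
//
//	b := Append(nil, 15293)
//	// b == []byte{0x7b, 0xbd}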
// AppendWithLen appends i in the QUIC varint format, using the desired encoding length.
func AppendWithLen(b []byte, i uint64, length int) []byte {
if length != 1 && length != 2 && length != 4 && length != 8 {
panic("invalid varint length")
}
l := Len(i)
if l == length {
return Append(b, i)
}
if l > length {
panic(fmt.Sprintf("cannot encode %d in %d bytes", i, length))
}
switch length {
case 2:
b = append(b, 0b01000000)
case 4:
b = append(b, 0b10000000)
case 8:
b = append(b, 0b11000000)
}
for range length - l - 1 {
b = append(b, 0)
}
for j := range l {
b = append(b, uint8(i>>(8*(l-1-j))))
}
return b
}
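// A sketch of forcing a longer-than-minimal encoding, e.g. when space for a field
// is reserved up front and filled in later (hypothetical use case):
//
//	b := AppendWithLen(nil, 37, 2)
//	// b == []byte{0x40, 0x25}, instead of the minimal single byte 0x25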
// Len determines the number of bytes that will be needed to write the number i.
func Len(i uint64) int {
if i <= maxVarInt1 {
return 1
}
if i <= maxVarInt2 {
return 2
}
if i <= maxVarInt4 {
return 4
}
if i <= maxVarInt8 {
return 8
}
// Don't use a fmt.Sprintf here to format the error message.
// The function would then exceed the inlining budget.
panic(struct {
message string
num uint64
}{"value doesn't fit into 62 bits: ", i})
}
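// For illustration:
//
//	Len(37)      // 1
//	Len(15293)   // 2
//	Len(1 << 30) // 8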
package quic
import (
"fmt"
"io"
"sync"
"time"
"github.com/quic-go/quic-go/internal/ackhandler"
"github.com/quic-go/quic-go/internal/flowcontrol"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qerr"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/internal/wire"
)
// A ReceiveStream is a unidirectional Receive Stream.
type ReceiveStream struct {
mutex sync.Mutex
streamID protocol.StreamID
sender streamSender
frameQueue *frameSorter
finalOffset protocol.ByteCount
currentFrame []byte
currentFrameDone func()
readPosInFrame int
currentFrameIsLast bool // is the currentFrame the last frame on this stream
queuedStopSending bool
queuedMaxStreamData bool
// Set once we read the io.EOF or the cancellation error.
// Note that for local cancellations, this doesn't necessarily mean that we know the final offset yet.
errorRead bool
completed bool // set once we've called streamSender.onStreamCompleted
cancelledRemotely bool
cancelledLocally bool
cancelErr *StreamError
closeForShutdownErr error
readPos protocol.ByteCount
reliableSize protocol.ByteCount
readChan chan struct{}
readOnce chan struct{} // cap: 1, to protect against concurrent use of Read
deadline time.Time
flowController flowcontrol.StreamFlowController
}
var (
_ streamControlFrameGetter = &ReceiveStream{}
_ receiveStreamFrameHandler = &ReceiveStream{}
)
func newReceiveStream(
streamID protocol.StreamID,
sender streamSender,
flowController flowcontrol.StreamFlowController,
) *ReceiveStream {
return &ReceiveStream{
streamID: streamID,
sender: sender,
flowController: flowController,
frameQueue: newFrameSorter(),
readChan: make(chan struct{}, 1),
readOnce: make(chan struct{}, 1),
finalOffset: protocol.MaxByteCount,
}
}
// StreamID returns the stream ID.
func (s *ReceiveStream) StreamID() protocol.StreamID {
return s.streamID
}
// Read reads data from the stream.
// Read can be made to time out using [ReceiveStream.SetReadDeadline].
// If the stream was canceled, the error is a [StreamError].
func (s *ReceiveStream) Read(p []byte) (int, error) {
// Concurrent use of Read is not permitted (and doesn't make any sense),
// but sometimes people do it anyway.
// Make sure that we only execute one call at any given time to avoid hard to debug failures.
s.readOnce <- struct{}{}
defer func() { <-s.readOnce }()
s.mutex.Lock()
queuedStreamWindowUpdate, queuedConnWindowUpdate, n, err := s.readImpl(p)
completed := s.isNewlyCompleted()
s.mutex.Unlock()
if completed {
s.sender.onStreamCompleted(s.streamID)
}
if queuedStreamWindowUpdate {
s.sender.onHasStreamControlFrame(s.streamID, s)
}
if queuedConnWindowUpdate {
s.sender.onHasConnectionData()
}
return n, err
}
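// A minimal read loop on the application side (hypothetical caller code;
// str and handle are placeholders):
//
//	buf := make([]byte, 4096)
//	for {
//		n, err := str.Read(buf)
//		handle(buf[:n]) // data may be returned together with an error
//		if err == io.EOF {
//			break // the peer closed the stream cleanly
//		}
//		if err != nil {
//			var streamErr *quic.StreamError
//			if errors.As(err, &streamErr) {
//				// the stream was reset by the peer, or CancelRead was called
//			}
//			break
//		}
//	}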
func (s *ReceiveStream) isNewlyCompleted() bool {
if s.completed {
return false
}
// We need to know the final offset (either via FIN or RESET_STREAM) for flow control accounting.
if s.finalOffset == protocol.MaxByteCount {
return false
}
// We're done with the stream if it was cancelled locally...
if s.cancelledLocally {
s.completed = true
return true
}
// ... or if the error (either io.EOF or the reset error) was read
if s.errorRead {
s.completed = true
return true
}
return false
}
func (s *ReceiveStream) readImpl(p []byte) (hasStreamWindowUpdate bool, hasConnWindowUpdate bool, _ int, _ error) {
if s.currentFrameIsLast && s.currentFrame == nil {
s.errorRead = true
return false, false, 0, io.EOF
}
if s.cancelledLocally || (s.cancelledRemotely && s.readPos >= s.reliableSize) {
s.errorRead = true
return false, false, 0, s.cancelErr
}
if s.closeForShutdownErr != nil {
return false, false, 0, s.closeForShutdownErr
}
var bytesRead int
var deadlineTimer *utils.Timer
for bytesRead < len(p) {
if s.currentFrame == nil || s.readPosInFrame >= len(s.currentFrame) {
s.dequeueNextFrame()
}
if s.currentFrame == nil && bytesRead > 0 {
return hasStreamWindowUpdate, hasConnWindowUpdate, bytesRead, s.closeForShutdownErr
}
for {
// Stop waiting on errors
if s.closeForShutdownErr != nil {
return hasStreamWindowUpdate, hasConnWindowUpdate, bytesRead, s.closeForShutdownErr
}
if s.cancelledLocally || (s.cancelledRemotely && s.readPos >= s.reliableSize) {
s.errorRead = true
return hasStreamWindowUpdate, hasConnWindowUpdate, bytesRead, s.cancelErr
}
deadline := s.deadline
if !deadline.IsZero() {
if !time.Now().Before(deadline) {
return hasStreamWindowUpdate, hasConnWindowUpdate, bytesRead, errDeadline
}
if deadlineTimer == nil {
deadlineTimer = utils.NewTimer()
defer deadlineTimer.Stop()
}
deadlineTimer.Reset(deadline)
}
if s.currentFrame != nil || s.currentFrameIsLast {
break
}
s.mutex.Unlock()
if deadline.IsZero() {
<-s.readChan
} else {
select {
case <-s.readChan:
case <-deadlineTimer.Chan():
deadlineTimer.SetRead()
}
}
s.mutex.Lock()
if s.currentFrame == nil {
s.dequeueNextFrame()
}
}
if bytesRead > len(p) {
return hasStreamWindowUpdate, hasConnWindowUpdate, bytesRead, fmt.Errorf("BUG: bytesRead (%d) > len(p) (%d) in stream.Read", bytesRead, len(p))
}
if s.readPosInFrame > len(s.currentFrame) {
return hasStreamWindowUpdate, hasConnWindowUpdate, bytesRead, fmt.Errorf("BUG: readPosInFrame (%d) > frame.DataLen (%d) in stream.Read", s.readPosInFrame, len(s.currentFrame))
}
m := copy(p[bytesRead:], s.currentFrame[s.readPosInFrame:])
// when a RESET_STREAM was received, the flow controller was already
// informed about the final offset for this stream
if !s.cancelledRemotely || s.readPos < s.reliableSize {
hasStream, hasConn := s.flowController.AddBytesRead(protocol.ByteCount(m))
if hasStream {
s.queuedMaxStreamData = true
hasStreamWindowUpdate = true
}
if hasConn {
hasConnWindowUpdate = true
}
}
s.readPosInFrame += m
s.readPos += protocol.ByteCount(m)
bytesRead += m
if s.cancelledRemotely && s.readPos >= s.reliableSize {
s.flowController.Abandon()
}
if s.readPosInFrame >= len(s.currentFrame) && s.currentFrameIsLast {
s.currentFrame = nil
if s.currentFrameDone != nil {
s.currentFrameDone()
}
s.errorRead = true
return hasStreamWindowUpdate, hasConnWindowUpdate, bytesRead, io.EOF
}
}
if s.cancelledRemotely && s.readPos >= s.reliableSize {
s.errorRead = true
return hasStreamWindowUpdate, hasConnWindowUpdate, bytesRead, s.cancelErr
}
return hasStreamWindowUpdate, hasConnWindowUpdate, bytesRead, nil
}
func (s *ReceiveStream) dequeueNextFrame() {
var offset protocol.ByteCount
// We're done with the last frame. Release the buffer.
if s.currentFrameDone != nil {
s.currentFrameDone()
}
offset, s.currentFrame, s.currentFrameDone = s.frameQueue.Pop()
s.currentFrameIsLast = offset+protocol.ByteCount(len(s.currentFrame)) >= s.finalOffset && !s.cancelledRemotely
s.readPosInFrame = 0
}
// CancelRead aborts receiving on this stream.
// It instructs the peer to stop transmitting stream data.
// Read will unblock immediately, and future Read calls will fail.
// When called multiple times, or after io.EOF has been read, it is a no-op.
func (s *ReceiveStream) CancelRead(errorCode StreamErrorCode) {
s.mutex.Lock()
queuedNewControlFrame := s.cancelReadImpl(errorCode)
completed := s.isNewlyCompleted()
s.mutex.Unlock()
if queuedNewControlFrame {
s.sender.onHasStreamControlFrame(s.streamID, s)
}
if completed {
s.flowController.Abandon()
s.sender.onStreamCompleted(s.streamID)
}
}
func (s *ReceiveStream) cancelReadImpl(errorCode qerr.StreamErrorCode) (queuedNewControlFrame bool) {
if s.cancelledLocally { // duplicate call to CancelRead
return false
}
if s.closeForShutdownErr != nil {
return false
}
s.cancelledLocally = true
if s.errorRead || s.cancelledRemotely {
return false
}
s.queuedStopSending = true
s.cancelErr = &StreamError{StreamID: s.streamID, ErrorCode: errorCode, Remote: false}
s.signalRead()
return true
}
func (s *ReceiveStream) handleStreamFrame(frame *wire.StreamFrame, now time.Time) error {
s.mutex.Lock()
err := s.handleStreamFrameImpl(frame, now)
completed := s.isNewlyCompleted()
s.mutex.Unlock()
if completed {
s.flowController.Abandon()
s.sender.onStreamCompleted(s.streamID)
}
return err
}
func (s *ReceiveStream) handleStreamFrameImpl(frame *wire.StreamFrame, now time.Time) error {
maxOffset := frame.Offset + frame.DataLen()
if err := s.flowController.UpdateHighestReceived(maxOffset, frame.Fin, now); err != nil {
return err
}
if frame.Fin {
s.finalOffset = maxOffset
}
if s.cancelledLocally {
return nil
}
if err := s.frameQueue.Push(frame.Data, frame.Offset, frame.PutBack); err != nil {
return err
}
s.signalRead()
return nil
}
func (s *ReceiveStream) handleResetStreamFrame(frame *wire.ResetStreamFrame, now time.Time) error {
s.mutex.Lock()
err := s.handleResetStreamFrameImpl(frame, now)
completed := s.isNewlyCompleted()
s.mutex.Unlock()
if completed {
s.sender.onStreamCompleted(s.streamID)
}
return err
}
func (s *ReceiveStream) handleResetStreamFrameImpl(frame *wire.ResetStreamFrame, now time.Time) error {
if s.closeForShutdownErr != nil {
return nil
}
if err := s.flowController.UpdateHighestReceived(frame.FinalSize, true, now); err != nil {
return err
}
s.finalOffset = frame.FinalSize
// senders are allowed to reduce the reliable size, but frames might have been reordered
if (!s.cancelledRemotely && s.reliableSize == 0) || frame.ReliableSize < s.reliableSize {
s.reliableSize = frame.ReliableSize
}
if s.readPos >= s.reliableSize {
// calling Abandon multiple times is a no-op
s.flowController.Abandon()
}
// ignore duplicate RESET_STREAM frames for this stream (after checking their final offset)
if s.cancelledRemotely {
return nil
}
// don't save the error if the RESET_STREAM frame was received after CancelRead was called
if s.cancelledLocally {
return nil
}
s.cancelledRemotely = true
s.cancelErr = &StreamError{StreamID: s.streamID, ErrorCode: frame.ErrorCode, Remote: true}
s.signalRead()
return nil
}
func (s *ReceiveStream) getControlFrame(now time.Time) (_ ackhandler.Frame, ok, hasMore bool) {
s.mutex.Lock()
defer s.mutex.Unlock()
if !s.queuedStopSending && !s.queuedMaxStreamData {
return ackhandler.Frame{}, false, false
}
if s.queuedStopSending {
s.queuedStopSending = false
return ackhandler.Frame{
Frame: &wire.StopSendingFrame{StreamID: s.streamID, ErrorCode: s.cancelErr.ErrorCode},
}, true, s.queuedMaxStreamData
}
s.queuedMaxStreamData = false
return ackhandler.Frame{
Frame: &wire.MaxStreamDataFrame{
StreamID: s.streamID,
MaximumStreamData: s.flowController.GetWindowUpdate(now),
},
}, true, false
}
// SetReadDeadline sets the deadline for future Read calls and
// any currently-blocked Read call.
// A zero value for t means Read will not time out.
func (s *ReceiveStream) SetReadDeadline(t time.Time) error {
s.mutex.Lock()
s.deadline = t
s.mutex.Unlock()
s.signalRead()
return nil
}
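// A sketch of a deadline-bounded read (hypothetical caller code):
//
//	str.SetReadDeadline(time.Now().Add(5 * time.Second))
//	n, err := str.Read(buf)
//	// if the deadline expires, err reports a timeout (a net.Error with Timeout() == true)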
// closeForShutdown closes the stream abruptly.
// It makes Read unblock (and return the error) immediately.
// The peer will NOT be informed about this: the stream is closed without sending a FIN or RESET.
func (s *ReceiveStream) closeForShutdown(err error) {
s.mutex.Lock()
s.closeForShutdownErr = err
s.mutex.Unlock()
s.signalRead()
}
// signalRead performs a non-blocking send on the readChan
func (s *ReceiveStream) signalRead() {
select {
case s.readChan <- struct{}{}:
default:
}
}
package quic
import (
"fmt"
"github.com/quic-go/quic-go/internal/ackhandler"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/wire"
)
type framesToRetransmit struct {
crypto []*wire.CryptoFrame
other []wire.Frame
}
type retransmissionQueue struct {
initial *framesToRetransmit
handshake *framesToRetransmit
appData framesToRetransmit
}
func newRetransmissionQueue() *retransmissionQueue {
return &retransmissionQueue{
initial: &framesToRetransmit{},
handshake: &framesToRetransmit{},
}
}
func (q *retransmissionQueue) addInitial(f wire.Frame) {
if q.initial == nil {
return
}
if cf, ok := f.(*wire.CryptoFrame); ok {
q.initial.crypto = append(q.initial.crypto, cf)
return
}
q.initial.other = append(q.initial.other, f)
}
func (q *retransmissionQueue) addHandshake(f wire.Frame) {
if q.handshake == nil {
return
}
if cf, ok := f.(*wire.CryptoFrame); ok {
q.handshake.crypto = append(q.handshake.crypto, cf)
return
}
q.handshake.other = append(q.handshake.other, f)
}
func (q *retransmissionQueue) addAppData(f wire.Frame) {
switch f := f.(type) {
case *wire.StreamFrame:
panic("STREAM frames are handled with their respective streams.")
case *wire.CryptoFrame:
q.appData.crypto = append(q.appData.crypto, f)
default:
q.appData.other = append(q.appData.other, f)
}
}
func (q *retransmissionQueue) HasData(encLevel protocol.EncryptionLevel) bool {
//nolint:exhaustive // 0-RTT data is retransmitted in 1-RTT packets.
switch encLevel {
case protocol.EncryptionInitial:
return q.initial != nil &&
(len(q.initial.crypto) > 0 || len(q.initial.other) > 0)
case protocol.EncryptionHandshake:
return q.handshake != nil &&
(len(q.handshake.crypto) > 0 || len(q.handshake.other) > 0)
case protocol.Encryption1RTT:
return len(q.appData.crypto) > 0 || len(q.appData.other) > 0
}
return false
}
func (q *retransmissionQueue) GetFrame(encLevel protocol.EncryptionLevel, maxLen protocol.ByteCount, v protocol.Version) wire.Frame {
var r *framesToRetransmit
//nolint:exhaustive // 0-RTT data is retransmitted in 1-RTT packets.
switch encLevel {
case protocol.EncryptionInitial:
r = q.initial
case protocol.EncryptionHandshake:
r = q.handshake
case protocol.Encryption1RTT:
r = &q.appData
}
if r == nil {
return nil
}
if len(r.crypto) > 0 {
f := r.crypto[0]
newFrame, needsSplit := f.MaybeSplitOffFrame(maxLen, v)
if newFrame == nil && !needsSplit { // the whole frame fits
r.crypto = r.crypto[1:]
return f
}
if newFrame != nil { // frame was split. Leave the original frame in the queue.
return newFrame
}
}
if len(r.other) == 0 {
return nil
}
f := r.other[0]
if f.Length(v) > maxLen {
return nil
}
r.other = r.other[1:]
return f
}
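// A simplified sketch of how a packet packer might drain this queue when
// retransmitting Initial data (values chosen for illustration only):
//
//	q := newRetransmissionQueue()
//	q.addInitial(&wire.PingFrame{})
//	for q.HasData(protocol.EncryptionInitial) {
//		f := q.GetFrame(protocol.EncryptionInitial, 1200, protocol.Version1)
//		if f == nil {
//			break // the remaining frames don't fit into this packet
//		}
//		// append f to the packet currently being assembled
//	}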
func (q *retransmissionQueue) DropPackets(encLevel protocol.EncryptionLevel) {
//nolint:exhaustive // Can only drop Initial and Handshake packet number space.
switch encLevel {
case protocol.EncryptionInitial:
q.initial = nil
case protocol.EncryptionHandshake:
q.handshake = nil
default:
panic(fmt.Sprintf("unexpected encryption level: %s", encLevel))
}
}
func (q *retransmissionQueue) AckHandler(encLevel protocol.EncryptionLevel) ackhandler.FrameHandler {
switch encLevel {
case protocol.EncryptionInitial:
return (*retransmissionQueueInitialAckHandler)(q)
case protocol.EncryptionHandshake:
return (*retransmissionQueueHandshakeAckHandler)(q)
case protocol.Encryption0RTT, protocol.Encryption1RTT:
return (*retransmissionQueueAppDataAckHandler)(q)
}
return nil
}
type retransmissionQueueInitialAckHandler retransmissionQueue
func (q *retransmissionQueueInitialAckHandler) OnAcked(wire.Frame) {}
func (q *retransmissionQueueInitialAckHandler) OnLost(f wire.Frame) {
(*retransmissionQueue)(q).addInitial(f)
}
type retransmissionQueueHandshakeAckHandler retransmissionQueue
func (q *retransmissionQueueHandshakeAckHandler) OnAcked(wire.Frame) {}
func (q *retransmissionQueueHandshakeAckHandler) OnLost(f wire.Frame) {
(*retransmissionQueue)(q).addHandshake(f)
}
type retransmissionQueueAppDataAckHandler retransmissionQueue
func (q *retransmissionQueueAppDataAckHandler) OnAcked(wire.Frame) {}
func (q *retransmissionQueueAppDataAckHandler) OnLost(f wire.Frame) {
(*retransmissionQueue)(q).addAppData(f)
}
package quic
import (
"net"
"sync/atomic"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils"
)
// A sendConn allows sending using a simple Write() on a non-connected packet conn.
type sendConn interface {
Write(b []byte, gsoSize uint16, ecn protocol.ECN) error
WriteTo([]byte, net.Addr) error
Close() error
LocalAddr() net.Addr
RemoteAddr() net.Addr
ChangeRemoteAddr(addr net.Addr, info packetInfo)
capabilities() connCapabilities
}
type remoteAddrInfo struct {
addr net.Addr
oob []byte
}
type sconn struct {
rawConn
localAddr net.Addr
remoteAddrInfo atomic.Pointer[remoteAddrInfo]
logger utils.Logger
// If GSO is enabled and we receive a GSO error for this remote address, GSO is disabled.
gotGSOError bool
// Used to catch the error sometimes returned by the first sendmsg call on Linux,
// see https://github.com/golang/go/issues/63322.
wroteFirstPacket bool
}
var _ sendConn = &sconn{}
func newSendConn(c rawConn, remote net.Addr, info packetInfo, logger utils.Logger) *sconn {
localAddr := c.LocalAddr()
if info.addr.IsValid() {
if udpAddr, ok := localAddr.(*net.UDPAddr); ok {
addrCopy := *udpAddr
addrCopy.IP = info.addr.AsSlice()
localAddr = &addrCopy
}
}
oob := info.OOB()
// increase oob slice capacity, so we can add the UDP_SEGMENT and ECN control messages without allocating
l := len(oob)
oob = append(oob, make([]byte, 64)...)[:l]
sc := &sconn{
rawConn: c,
localAddr: localAddr,
logger: logger,
}
sc.remoteAddrInfo.Store(&remoteAddrInfo{
addr: remote,
oob: oob,
})
return sc
}
func (c *sconn) Write(p []byte, gsoSize uint16, ecn protocol.ECN) error {
ai := c.remoteAddrInfo.Load()
err := c.writePacket(p, ai.addr, ai.oob, gsoSize, ecn)
if err != nil && isGSOError(err) {
// disable GSO for future calls
c.gotGSOError = true
if c.logger.Debug() {
c.logger.Debugf("GSO failed when sending to %s", ai.addr)
}
// send out the packets one by one
for len(p) > 0 {
l := len(p)
if l > int(gsoSize) {
l = int(gsoSize)
}
if err := c.writePacket(p[:l], ai.addr, ai.oob, 0, ecn); err != nil {
return err
}
p = p[l:]
}
return nil
}
return err
}
func (c *sconn) writePacket(p []byte, addr net.Addr, oob []byte, gsoSize uint16, ecn protocol.ECN) error {
_, err := c.WritePacket(p, addr, oob, gsoSize, ecn)
if err != nil && !c.wroteFirstPacket && isPermissionError(err) {
_, err = c.WritePacket(p, addr, oob, gsoSize, ecn)
}
c.wroteFirstPacket = true
return err
}
func (c *sconn) WriteTo(b []byte, addr net.Addr) error {
_, err := c.WritePacket(b, addr, nil, 0, protocol.ECNUnsupported)
return err
}
func (c *sconn) capabilities() connCapabilities {
capabilities := c.rawConn.capabilities()
if capabilities.GSO {
capabilities.GSO = !c.gotGSOError
}
return capabilities
}
func (c *sconn) ChangeRemoteAddr(addr net.Addr, info packetInfo) {
c.remoteAddrInfo.Store(&remoteAddrInfo{
addr: addr,
oob: info.OOB(),
})
}
func (c *sconn) RemoteAddr() net.Addr { return c.remoteAddrInfo.Load().addr }
func (c *sconn) LocalAddr() net.Addr { return c.localAddr }
package quic
import (
"net"
"github.com/quic-go/quic-go/internal/protocol"
)
type sender interface {
Send(p *packetBuffer, gsoSize uint16, ecn protocol.ECN)
SendProbe(*packetBuffer, net.Addr)
Run() error
WouldBlock() bool
Available() <-chan struct{}
Close()
}
type queueEntry struct {
buf *packetBuffer
gsoSize uint16
ecn protocol.ECN
}
type sendQueue struct {
queue chan queueEntry
closeCalled chan struct{} // closed when Close() is called
runStopped chan struct{} // closed when the run loop returns
available chan struct{}
conn sendConn
}
var _ sender = &sendQueue{}
const sendQueueCapacity = 8
func newSendQueue(conn sendConn) sender {
return &sendQueue{
conn: conn,
runStopped: make(chan struct{}),
closeCalled: make(chan struct{}),
available: make(chan struct{}, 1),
queue: make(chan queueEntry, sendQueueCapacity),
}
}
// Send sends out a packet. It's guaranteed to not block.
// Callers need to make sure that there's actually space in the send queue by calling WouldBlock.
// Otherwise Send will panic.
func (h *sendQueue) Send(p *packetBuffer, gsoSize uint16, ecn protocol.ECN) {
select {
case h.queue <- queueEntry{buf: p, gsoSize: gsoSize, ecn: ecn}:
// clear available channel if we've reached capacity
if len(h.queue) == sendQueueCapacity {
select {
case <-h.available:
default:
}
}
case <-h.runStopped:
default:
panic("sendQueue.Send would have blocked")
}
}
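// A simplified sketch of the intended usage pattern (the caller is assumed to be
// the connection's run loop; not actual caller code):
//
//	q := newSendQueue(conn)
//	go q.Run()
//	if q.WouldBlock() {
//		<-q.Available() // wait until a packet has been dequeued
//	}
//	q.Send(buf, 0, protocol.ECNNon)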
func (h *sendQueue) SendProbe(p *packetBuffer, addr net.Addr) {
h.conn.WriteTo(p.Data, addr)
}
func (h *sendQueue) WouldBlock() bool {
return len(h.queue) == sendQueueCapacity
}
func (h *sendQueue) Available() <-chan struct{} {
return h.available
}
func (h *sendQueue) Run() error {
defer close(h.runStopped)
var shouldClose bool
for {
if shouldClose && len(h.queue) == 0 {
return nil
}
select {
case <-h.closeCalled:
h.closeCalled = nil // prevent this case from being selected again
// make sure that all queued packets are actually sent out
shouldClose = true
case e := <-h.queue:
if err := h.conn.Write(e.buf.Data, e.gsoSize, e.ecn); err != nil {
// This additional check enables:
// 1. checking for "datagram too large" messages from the kernel, and as such,
// 2. Path MTU discovery, and
// 3. eventual detection of the lost PING frame.
if !isSendMsgSizeErr(err) {
return err
}
}
e.buf.Release()
select {
case h.available <- struct{}{}:
default:
}
}
}
}
func (h *sendQueue) Close() {
close(h.closeCalled)
// wait until the run loop has returned
<-h.runStopped
}
package quic
import (
"context"
"fmt"
"sync"
"time"
"github.com/quic-go/quic-go/internal/ackhandler"
"github.com/quic-go/quic-go/internal/flowcontrol"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/internal/wire"
)
// A SendStream is a unidirectional Send Stream.
type SendStream struct {
mutex sync.Mutex
numOutstandingFrames int64 // outstanding STREAM and RESET_STREAM frames
retransmissionQueue []*wire.StreamFrame
ctx context.Context
ctxCancel context.CancelCauseFunc
streamID protocol.StreamID
sender streamSender
// reliableSize is the portion of the stream that needs to be transmitted reliably,
// even if the stream is cancelled.
// This requires the peer to support RESET_STREAM_AT.
// This value should not be accessed directly, but only through the reliableOffset method.
// This method returns 0 if the peer doesn't support the RESET_STREAM_AT extension.
reliableSize protocol.ByteCount
writeOffset protocol.ByteCount
shutdownErr error
resetErr *StreamError
queuedResetStreamFrame *wire.ResetStreamFrame
supportsResetStreamAt bool
finishedWriting bool // set once Close() is called
finSent bool // set when a STREAM frame with the FIN bit has been sent
// Set when the application knows about the cancellation.
// This can happen because the application called CancelWrite,
// or because Write returned the error (for remote cancellations).
cancellationFlagged bool
completed bool // set when this stream has been reported to the streamSender as completed
dataForWriting []byte // during a Write() call, this slice is the part of p that still needs to be sent out
nextFrame *wire.StreamFrame
writeChan chan struct{}
writeOnce chan struct{}
deadline time.Time
flowController flowcontrol.StreamFlowController
}
var (
_ streamControlFrameGetter = &SendStream{}
_ outgoingStream = &SendStream{}
_ sendStreamFrameHandler = &SendStream{}
)
func newSendStream(
ctx context.Context,
streamID protocol.StreamID,
sender streamSender,
flowController flowcontrol.StreamFlowController,
supportsResetStreamAt bool,
) *SendStream {
s := &SendStream{
streamID: streamID,
sender: sender,
flowController: flowController,
writeChan: make(chan struct{}, 1),
writeOnce: make(chan struct{}, 1), // cap: 1, to protect against concurrent use of Write
supportsResetStreamAt: supportsResetStreamAt,
}
s.ctx, s.ctxCancel = context.WithCancelCause(ctx)
return s
}
// StreamID returns the stream ID.
func (s *SendStream) StreamID() StreamID {
return s.streamID // same for receiveStream and sendStream
}
// Write writes data to the stream.
// Write can be made to time out using [SendStream.SetWriteDeadline].
// If the stream was canceled, the error is a [StreamError].
func (s *SendStream) Write(p []byte) (int, error) {
// Concurrent use of Write is not permitted (and doesn't make any sense),
// but sometimes people do it anyway.
// Make sure that we only execute one call at any given time to avoid hard to debug failures.
s.writeOnce <- struct{}{}
defer func() { <-s.writeOnce }()
isNewlyCompleted, n, err := s.write(p)
if isNewlyCompleted {
s.sender.onStreamCompleted(s.streamID)
}
return n, err
}
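// A minimal write sketch (hypothetical application code):
//
//	str.SetWriteDeadline(time.Now().Add(time.Second))
//	if _, err := str.Write(data); err != nil {
//		var streamErr *quic.StreamError
//		if errors.As(err, &streamErr) {
//			// the peer sent STOP_SENDING, or CancelWrite was called
//		}
//	}
//	_ = str.Close() // sends the FIN, unless the stream was cancelled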
func (s *SendStream) write(p []byte) (bool /* is newly completed */, int, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.resetErr != nil {
s.cancellationFlagged = true
return s.isNewlyCompleted(), 0, s.resetErr
}
if s.shutdownErr != nil {
return false, 0, s.shutdownErr
}
if s.finishedWriting {
return false, 0, fmt.Errorf("write on closed stream %d", s.streamID)
}
if !s.deadline.IsZero() && !time.Now().Before(s.deadline) {
return false, 0, errDeadline
}
if len(p) == 0 {
return false, 0, nil
}
s.dataForWriting = p
var (
deadlineTimer *utils.Timer
bytesWritten int
notifiedSender bool
)
for {
var copied bool
var deadline time.Time
// As soon as dataForWriting becomes smaller than a certain size x, we copy all the data to a STREAM frame (s.nextFrame),
// which can then be popped the next time we assemble a packet.
// This allows us to return from Write() once all but the last x bytes have been sent out.
// When the user now calls Close(), this is much more likely to happen before we popped that last STREAM frame,
// allowing us to set the FIN bit on that frame (instead of sending an empty STREAM frame with FIN).
if s.canBufferStreamFrame() && len(s.dataForWriting) > 0 {
if s.nextFrame == nil {
f := wire.GetStreamFrame()
f.Offset = s.writeOffset
f.StreamID = s.streamID
f.DataLenPresent = true
f.Data = f.Data[:len(s.dataForWriting)]
copy(f.Data, s.dataForWriting)
s.nextFrame = f
} else {
l := len(s.nextFrame.Data)
s.nextFrame.Data = s.nextFrame.Data[:l+len(s.dataForWriting)]
copy(s.nextFrame.Data[l:], s.dataForWriting)
}
s.dataForWriting = nil
bytesWritten = len(p)
copied = true
} else {
bytesWritten = len(p) - len(s.dataForWriting)
deadline = s.deadline
if !deadline.IsZero() {
if !time.Now().Before(deadline) {
s.dataForWriting = nil
return false, bytesWritten, errDeadline
}
if deadlineTimer == nil {
deadlineTimer = utils.NewTimer()
defer deadlineTimer.Stop()
}
deadlineTimer.Reset(deadline)
}
if s.dataForWriting == nil || s.shutdownErr != nil || s.resetErr != nil {
break
}
}
s.mutex.Unlock()
if !notifiedSender {
s.sender.onHasStreamData(s.streamID, s) // must be called without holding the mutex
notifiedSender = true
}
if copied {
s.mutex.Lock()
break
}
if deadline.IsZero() {
<-s.writeChan
} else {
select {
case <-s.writeChan:
case <-deadlineTimer.Chan():
deadlineTimer.SetRead()
}
}
s.mutex.Lock()
}
if bytesWritten == len(p) {
return false, bytesWritten, nil
}
if s.shutdownErr != nil {
return false, bytesWritten, s.shutdownErr
}
if s.resetErr != nil {
s.cancellationFlagged = true
return s.isNewlyCompleted(), bytesWritten, s.resetErr
}
return false, bytesWritten, nil
}
func (s *SendStream) canBufferStreamFrame() bool {
var l protocol.ByteCount
if s.nextFrame != nil {
l = s.nextFrame.DataLen()
}
return l+protocol.ByteCount(len(s.dataForWriting)) <= protocol.MaxPacketBufferSize
}
// popStreamFrame returns the next STREAM frame that is supposed to be sent on this stream
// maxBytes is the maximum length this frame (including frame header) will have.
func (s *SendStream) popStreamFrame(maxBytes protocol.ByteCount, v protocol.Version) (_ ackhandler.StreamFrame, _ *wire.StreamDataBlockedFrame, hasMore bool) {
s.mutex.Lock()
f, blocked, hasMoreData := s.popNewOrRetransmittedStreamFrame(maxBytes, v)
if f != nil {
s.numOutstandingFrames++
}
s.mutex.Unlock()
if f == nil {
return ackhandler.StreamFrame{}, blocked, hasMoreData
}
return ackhandler.StreamFrame{
Frame: f,
Handler: (*sendStreamAckHandler)(s),
}, blocked, hasMoreData
}
func (s *SendStream) popNewOrRetransmittedStreamFrame(maxBytes protocol.ByteCount, v protocol.Version) (_ *wire.StreamFrame, _ *wire.StreamDataBlockedFrame, hasMoreData bool) {
if s.shutdownErr != nil {
return nil, nil, false
}
if s.resetErr != nil {
reliableOffset := s.reliableOffset()
if reliableOffset == 0 || (s.writeOffset >= reliableOffset && len(s.retransmissionQueue) == 0) {
return nil, nil, false
}
}
if len(s.retransmissionQueue) > 0 {
f, hasMoreRetransmissions := s.maybeGetRetransmission(maxBytes, v)
if f != nil || hasMoreRetransmissions {
if f == nil {
return nil, nil, true
}
// We always claim that we have more data to send.
// This might be incorrect, in which case there'll be a spurious call to popStreamFrame in the future.
return f, nil, true
}
}
if len(s.dataForWriting) == 0 && s.nextFrame == nil {
if s.finishedWriting && !s.finSent {
s.finSent = true
return &wire.StreamFrame{
StreamID: s.streamID,
Offset: s.writeOffset,
DataLenPresent: true,
Fin: true,
}, nil, false
}
return nil, nil, false
}
maxDataLen := s.flowController.SendWindowSize()
if maxDataLen == 0 {
return nil, nil, true
}
// if the stream is canceled, only data up to the reliable size needs to be sent
reliableOffset := s.reliableOffset()
if s.resetErr != nil && reliableOffset > 0 {
maxDataLen = min(maxDataLen, reliableOffset-s.writeOffset)
}
f, hasMoreData := s.popNewStreamFrame(maxBytes, maxDataLen, v)
if f == nil {
return nil, nil, hasMoreData
}
if f.DataLen() > 0 {
s.writeOffset += f.DataLen()
s.flowController.AddBytesSent(f.DataLen())
}
if s.resetErr != nil && s.writeOffset >= reliableOffset {
hasMoreData = false
}
var blocked *wire.StreamDataBlockedFrame
// If the entire send window is used, the stream might have become blocked on stream-level flow control.
// This is not guaranteed though, because the stream might also have been blocked on connection-level flow control.
if f.DataLen() == maxDataLen && s.flowController.IsNewlyBlocked() {
blocked = &wire.StreamDataBlockedFrame{StreamID: s.streamID, MaximumStreamData: s.writeOffset}
}
f.Fin = s.finishedWriting && s.dataForWriting == nil && s.nextFrame == nil && !s.finSent
if f.Fin {
s.finSent = true
}
return f, blocked, hasMoreData
}
// popNewStreamFrame returns a new STREAM frame to send for this stream
// hasMoreData says if there's more data to send, *not* taking into account the reliable size
func (s *SendStream) popNewStreamFrame(maxBytes, maxDataLen protocol.ByteCount, v protocol.Version) (_ *wire.StreamFrame, hasMoreData bool) {
if s.nextFrame != nil {
maxDataLen := min(maxDataLen, s.nextFrame.MaxDataLen(maxBytes, v))
if maxDataLen == 0 {
return nil, true
}
nextFrame := s.nextFrame
s.nextFrame = nil
if nextFrame.DataLen() > maxDataLen {
s.nextFrame = wire.GetStreamFrame()
s.nextFrame.StreamID = s.streamID
s.nextFrame.Offset = s.writeOffset + maxDataLen
s.nextFrame.Data = s.nextFrame.Data[:nextFrame.DataLen()-maxDataLen]
s.nextFrame.DataLenPresent = true
copy(s.nextFrame.Data, nextFrame.Data[maxDataLen:])
nextFrame.Data = nextFrame.Data[:maxDataLen]
} else {
s.signalWrite()
}
return nextFrame, s.nextFrame != nil || s.dataForWriting != nil
}
f := wire.GetStreamFrame()
f.Fin = false
f.StreamID = s.streamID
f.Offset = s.writeOffset
f.DataLenPresent = true
f.Data = f.Data[:0]
hasMoreData = s.popNewStreamFrameWithoutBuffer(f, maxBytes, maxDataLen, v)
if len(f.Data) == 0 && !f.Fin {
f.PutBack()
return nil, hasMoreData
}
return f, hasMoreData
}
func (s *SendStream) popNewStreamFrameWithoutBuffer(f *wire.StreamFrame, maxBytes, sendWindow protocol.ByteCount, v protocol.Version) bool {
maxDataLen := f.MaxDataLen(maxBytes, v)
if maxDataLen == 0 { // a STREAM frame must have at least one byte of data
return s.dataForWriting != nil || s.nextFrame != nil || s.finishedWriting
}
s.getDataForWriting(f, min(maxDataLen, sendWindow))
return s.dataForWriting != nil || s.nextFrame != nil || s.finishedWriting
}
func (s *SendStream) maybeGetRetransmission(maxBytes protocol.ByteCount, v protocol.Version) (*wire.StreamFrame, bool /* has more retransmissions */) {
f := s.retransmissionQueue[0]
newFrame, needsSplit := f.MaybeSplitOffFrame(maxBytes, v)
if needsSplit {
return newFrame, true
}
s.retransmissionQueue = s.retransmissionQueue[1:]
return f, len(s.retransmissionQueue) > 0
}
func (s *SendStream) getDataForWriting(f *wire.StreamFrame, maxBytes protocol.ByteCount) {
if protocol.ByteCount(len(s.dataForWriting)) <= maxBytes {
f.Data = f.Data[:len(s.dataForWriting)]
copy(f.Data, s.dataForWriting)
s.dataForWriting = nil
s.signalWrite()
return
}
f.Data = f.Data[:maxBytes]
copy(f.Data, s.dataForWriting)
s.dataForWriting = s.dataForWriting[maxBytes:]
if s.canBufferStreamFrame() {
s.signalWrite()
}
}
func (s *SendStream) isNewlyCompleted() bool {
if s.completed {
return false
}
if s.nextFrame != nil && s.nextFrame.DataLen() > 0 {
return false
}
// We need to keep the stream around until all frames have been sent and acknowledged.
if s.numOutstandingFrames > 0 || len(s.retransmissionQueue) > 0 || s.queuedResetStreamFrame != nil {
return false
}
// The stream is completed if we sent the FIN.
if s.finSent {
s.completed = true
return true
}
// The stream is also completed if:
// 1. the application called CancelWrite, or
// 2. we received a STOP_SENDING, and
// * the application consumed the error via Write, or
// * the application called Close
if s.resetErr != nil && (s.cancellationFlagged || s.finishedWriting) {
s.completed = true
return true
}
return false
}
// Close closes the write-direction of the stream.
// Future calls to Write are not permitted after calling Close.
// It must not be called concurrently with Write.
// It must not be called after calling CancelWrite.
func (s *SendStream) Close() error {
s.mutex.Lock()
if s.shutdownErr != nil || s.finishedWriting {
s.mutex.Unlock()
return nil
}
s.finishedWriting = true
cancelled := s.resetErr != nil
if cancelled {
s.cancellationFlagged = true
}
completed := s.isNewlyCompleted()
s.mutex.Unlock()
if completed {
s.sender.onStreamCompleted(s.streamID)
}
if cancelled {
return fmt.Errorf("close called for canceled stream %d", s.streamID)
}
s.sender.onHasStreamData(s.streamID, s) // need to send the FIN, must be called without holding the mutex
s.ctxCancel(nil)
return nil
}
// SetReliableBoundary marks the data written to this stream so far as reliable.
// It is valid to call this function multiple times, thereby increasing the reliable size.
// It only has an effect if the peer enabled support for the RESET_STREAM_AT extension,
// otherwise, it is a no-op.
func (s *SendStream) SetReliableBoundary() {
s.mutex.Lock()
defer s.mutex.Unlock()
s.reliableSize = s.writeOffset
if s.nextFrame != nil {
s.reliableSize += s.nextFrame.DataLen()
}
}
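// A sketch of the intended use (hypothetical caller code; only effective if the
// peer supports the RESET_STREAM_AT extension, otherwise delivery of the prefix
// is not guaranteed after cancellation):
//
//	str.Write(header)         // e.g. a framing prefix that must arrive in any case
//	str.SetReliableBoundary() // everything written so far survives a later cancellation
//	str.Write(body)
//	str.CancelWrite(42)       // the body may be dropped, the header is still retransmitted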
// CancelWrite aborts sending on this stream.
// Data already written, but not yet delivered to the peer, is not guaranteed to be delivered reliably.
// Write will unblock immediately, and future calls to Write will fail.
// When called multiple times it is a no-op.
// When called after Close, it aborts reliable delivery of outstanding stream data.
// Note that there is no guarantee whether the peer will receive the FIN or the cancellation error first.
func (s *SendStream) CancelWrite(errorCode StreamErrorCode) {
s.mutex.Lock()
if s.shutdownErr != nil {
s.mutex.Unlock()
return
}
s.cancellationFlagged = true
if s.resetErr != nil {
completed := s.isNewlyCompleted()
s.mutex.Unlock()
// The user has called CancelWrite. If the previous cancellation was because of a
// STOP_SENDING, we don't need to flag the error to the user anymore.
if completed {
s.sender.onStreamCompleted(s.streamID)
}
return
}
s.resetErr = &StreamError{StreamID: s.streamID, ErrorCode: errorCode, Remote: false}
s.ctxCancel(s.resetErr)
reliableOffset := s.reliableOffset()
if reliableOffset == 0 {
s.numOutstandingFrames = 0
s.retransmissionQueue = nil
}
s.queuedResetStreamFrame = &wire.ResetStreamFrame{
StreamID: s.streamID,
FinalSize: max(s.writeOffset, reliableOffset),
ErrorCode: errorCode,
// if the peer doesn't support the extension, the reliable offset will always be 0
ReliableSize: reliableOffset,
}
if reliableOffset > 0 {
if s.nextFrame != nil {
if s.nextFrame.Offset >= reliableOffset {
s.nextFrame.PutBack()
s.nextFrame = nil
} else if s.nextFrame.Offset+s.nextFrame.DataLen() > reliableOffset {
s.nextFrame.Data = s.nextFrame.Data[:reliableOffset-s.nextFrame.Offset]
}
}
if len(s.retransmissionQueue) > 0 {
retransmissionQueue := make([]*wire.StreamFrame, 0, len(s.retransmissionQueue))
for _, f := range s.retransmissionQueue {
if f.Offset >= reliableOffset {
f.PutBack()
continue
}
if f.Offset+f.DataLen() <= reliableOffset {
retransmissionQueue = append(retransmissionQueue, f)
} else {
f.Data = f.Data[:reliableOffset-f.Offset]
retransmissionQueue = append(retransmissionQueue, f)
}
}
s.retransmissionQueue = retransmissionQueue
}
}
s.mutex.Unlock()
s.signalWrite()
s.sender.onHasStreamControlFrame(s.streamID, s)
}
func (s *SendStream) enableResetStreamAt() {
s.mutex.Lock()
s.supportsResetStreamAt = true
s.mutex.Unlock()
}
func (s *SendStream) updateSendWindow(limit protocol.ByteCount) {
updated := s.flowController.UpdateSendWindow(limit)
if !updated { // duplicate or reordered MAX_STREAM_DATA frame
return
}
s.mutex.Lock()
hasStreamData := s.dataForWriting != nil || s.nextFrame != nil
s.mutex.Unlock()
if hasStreamData {
s.sender.onHasStreamData(s.streamID, s)
}
}
func (s *SendStream) handleStopSendingFrame(f *wire.StopSendingFrame) {
s.mutex.Lock()
if s.shutdownErr != nil {
s.mutex.Unlock()
return
}
// If the stream was already cancelled (either locally, or due to a previous STOP_SENDING frame),
// there's nothing else to do.
if s.resetErr != nil && s.reliableOffset() == 0 {
s.mutex.Unlock()
return
}
// if the peer stopped reading from the stream, there's no need to transmit any data reliably
s.reliableSize = 0
s.numOutstandingFrames = 0
s.retransmissionQueue = nil
if s.resetErr == nil {
s.resetErr = &StreamError{StreamID: s.streamID, ErrorCode: f.ErrorCode, Remote: true}
s.ctxCancel(s.resetErr)
}
s.queuedResetStreamFrame = &wire.ResetStreamFrame{
StreamID: s.streamID,
FinalSize: s.writeOffset,
ErrorCode: s.resetErr.ErrorCode,
}
s.mutex.Unlock()
s.signalWrite()
s.sender.onHasStreamControlFrame(s.streamID, s)
}
func (s *SendStream) getControlFrame(time.Time) (_ ackhandler.Frame, ok, hasMore bool) {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.queuedResetStreamFrame == nil {
return ackhandler.Frame{}, false, false
}
s.numOutstandingFrames++
f := ackhandler.Frame{
Frame: s.queuedResetStreamFrame,
Handler: (*sendStreamResetStreamHandler)(s),
}
s.queuedResetStreamFrame = nil
return f, true, false
}
func (s *SendStream) reliableOffset() protocol.ByteCount {
if !s.supportsResetStreamAt {
return 0
}
return s.reliableSize
}
// Context returns a context that is canceled as soon as the write-side of the stream is closed.
// This happens when Close() or CancelWrite() is called, or when the peer
// cancels the read-side of their stream.
// The cancellation cause is set to the error that caused the stream to
// close, or `context.Canceled` in case the stream is closed without error.
func (s *SendStream) Context() context.Context {
return s.ctx
}
// SetWriteDeadline sets the deadline for future Write calls
// and any currently-blocked Write call.
// Even if Write times out, it may return n > 0, indicating that
// some data was successfully written.
// A zero value for t means Write will not time out.
func (s *SendStream) SetWriteDeadline(t time.Time) error {
s.mutex.Lock()
s.deadline = t
s.mutex.Unlock()
s.signalWrite()
return nil
}
// closeForShutdown closes the stream abruptly.
// It makes Write unblock (and return the error) immediately.
// The peer will NOT be informed about this: the stream is closed without sending a FIN or RESET_STREAM.
func (s *SendStream) closeForShutdown(err error) {
s.mutex.Lock()
if s.shutdownErr == nil && !s.finishedWriting {
s.shutdownErr = err
}
s.mutex.Unlock()
s.signalWrite()
}
// signalWrite performs a non-blocking send on the writeChan
func (s *SendStream) signalWrite() {
select {
case s.writeChan <- struct{}{}:
default:
}
}
type sendStreamAckHandler SendStream
var _ ackhandler.FrameHandler = &sendStreamAckHandler{}
func (s *sendStreamAckHandler) OnAcked(f wire.Frame) {
sf := f.(*wire.StreamFrame)
sf.PutBack()
s.mutex.Lock()
if s.resetErr != nil && (*SendStream)(s).reliableOffset() == 0 {
s.mutex.Unlock()
return
}
s.numOutstandingFrames--
if s.numOutstandingFrames < 0 {
panic("numOutStandingFrames negative")
}
completed := (*SendStream)(s).isNewlyCompleted()
s.mutex.Unlock()
if completed {
s.sender.onStreamCompleted(s.streamID)
}
}
func (s *sendStreamAckHandler) OnLost(f wire.Frame) {
sf := f.(*wire.StreamFrame)
s.mutex.Lock()
// If the reliable size was 0 when the stream was cancelled,
// the number of outstanding frames was immediately set to 0, and the retransmission queue was dropped.
if s.resetErr != nil && (*SendStream)(s).reliableOffset() == 0 {
s.mutex.Unlock()
return
}
s.numOutstandingFrames--
if s.numOutstandingFrames < 0 {
panic("numOutStandingFrames negative")
}
if s.resetErr != nil && (*SendStream)(s).reliableOffset() > 0 {
// If the stream was reset, and this frame is beyond the reliable offset,
// it doesn't need to be retransmitted.
if sf.Offset >= (*SendStream)(s).reliableOffset() {
sf.PutBack()
// If this frame was the last one tracked, losing it might cause the stream to be completed.
completed := (*SendStream)(s).isNewlyCompleted()
s.mutex.Unlock()
if completed {
s.sender.onStreamCompleted(s.streamID)
}
return
}
// If the payload of the frame extends beyond the reliable size,
// truncate the frame to the reliable size.
if sf.Offset+sf.DataLen() > (*SendStream)(s).reliableOffset() {
sf.Data = sf.Data[:(*SendStream)(s).reliableOffset()-sf.Offset]
}
}
sf.DataLenPresent = true
s.retransmissionQueue = append(s.retransmissionQueue, sf)
s.mutex.Unlock()
s.sender.onHasStreamData(s.streamID, (*SendStream)(s))
}
type sendStreamResetStreamHandler SendStream
var _ ackhandler.FrameHandler = &sendStreamResetStreamHandler{}
func (s *sendStreamResetStreamHandler) OnAcked(f wire.Frame) {
rsf := f.(*wire.ResetStreamFrame)
s.mutex.Lock()
// If the peer sent a STOP_SENDING after we sent a RESET_STREAM_AT frame,
// we 1. reduced the reliable size to 0, and 2. sent a RESET_STREAM frame.
// In this case, we don't care about the acknowledgment of this frame.
if rsf.ReliableSize != (*SendStream)(s).reliableOffset() {
s.mutex.Unlock()
return
}
s.numOutstandingFrames--
if s.numOutstandingFrames < 0 {
panic("numOutStandingFrames negative")
}
completed := (*SendStream)(s).isNewlyCompleted()
s.mutex.Unlock()
if completed {
s.sender.onStreamCompleted(s.streamID)
}
}
func (s *sendStreamResetStreamHandler) OnLost(f wire.Frame) {
rsf := f.(*wire.ResetStreamFrame)
s.mutex.Lock()
// If the peer sent a STOP_SENDING after we sent a RESET_STREAM_AT frame,
// we 1. reduced the reliable size to 0, and 2. sent a RESET_STREAM frame.
// In this case, the loss of the RESET_STREAM_AT frame can be ignored.
if rsf.ReliableSize != (*SendStream)(s).reliableOffset() {
s.mutex.Unlock()
return
}
s.queuedResetStreamFrame = rsf
s.numOutstandingFrames--
s.mutex.Unlock()
s.sender.onHasStreamControlFrame(s.streamID, (*SendStream)(s))
}
package quic
import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"sync"
"time"
"github.com/quic-go/quic-go/internal/handshake"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qerr"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/internal/wire"
"github.com/quic-go/quic-go/logging"
)
// ErrServerClosed is returned by the [Listener] or [EarlyListener]'s Accept method after a call to Close.
var ErrServerClosed = errServerClosed{}
type errServerClosed struct{}
func (errServerClosed) Error() string { return "quic: server closed" }
func (errServerClosed) Unwrap() error { return net.ErrClosed }
// packetHandler handles packets
type packetHandler interface {
handlePacket(receivedPacket)
destroy(error)
closeWithTransportError(qerr.TransportErrorCode)
}
type zeroRTTQueue struct {
packets []receivedPacket
expiration time.Time
}
type rejectedPacket struct {
receivedPacket
hdr *wire.Header
}
// baseServer is the QUIC server implementation shared by Listener and EarlyListener.
type baseServer struct {
tr *packetHandlerMap
disableVersionNegotiation bool
acceptEarlyConns bool
tlsConf *tls.Config
config *Config
conn rawConn
tokenGenerator *handshake.TokenGenerator
maxTokenAge time.Duration
connIDGenerator ConnectionIDGenerator
statelessResetter *statelessResetter
onClose func()
receivedPackets chan receivedPacket
nextZeroRTTCleanup time.Time
zeroRTTQueues map[protocol.ConnectionID]*zeroRTTQueue // only initialized if acceptEarlyConns == true
connContext func(context.Context, *ClientInfo) (context.Context, error)
// stored as a member, so it can be overridden in tests
newConn func(
context.Context,
context.CancelCauseFunc,
sendConn,
connRunner,
protocol.ConnectionID, /* original dest connection ID */
*protocol.ConnectionID, /* retry src connection ID */
protocol.ConnectionID, /* client dest connection ID */
protocol.ConnectionID, /* destination connection ID */
protocol.ConnectionID, /* source connection ID */
ConnectionIDGenerator,
*statelessResetter,
*Config,
*tls.Config,
*handshake.TokenGenerator,
bool, /* client address validated by an address validation token */
time.Duration,
*logging.ConnectionTracer,
utils.Logger,
protocol.Version,
) *wrappedConn
closeMx sync.Mutex
// errorChan is closed when Close is called. This has two effects:
// 1. it cancels handshakes that are still in flight (using CONNECTION_REFUSED errors)
// 2. it stops handling of packets passed to this server
errorChan chan struct{}
// stopAccepting is closed when Close returns.
// This only happens once all handshakes in flight have either completed or been canceled.
// Calls to Accept will first drain the queue of connections that have completed the handshake,
// and then return ErrServerClosed.
stopAccepting chan struct{}
closeErr error
running chan struct{} // closed as soon as run() returns
versionNegotiationQueue chan receivedPacket
invalidTokenQueue chan rejectedPacket
connectionRefusedQueue chan rejectedPacket
retryQueue chan rejectedPacket
handshakingCount sync.WaitGroup
verifySourceAddress func(net.Addr) bool
connQueue chan *Conn
tracer *logging.Tracer
logger utils.Logger
}
// A Listener listens for incoming QUIC connections.
// It returns connections once the handshake has completed.
type Listener struct {
baseServer *baseServer
}
// Accept returns new connections. It should be called in a loop.
func (l *Listener) Accept(ctx context.Context) (*Conn, error) {
return l.baseServer.Accept(ctx)
}
// Close closes the listener.
// Accept will return [ErrServerClosed] as soon as all connections in the accept queue have been accepted.
// QUIC handshakes that are still in flight will be rejected with a CONNECTION_REFUSED error.
// Already established (accepted) connections will be unaffected.
func (l *Listener) Close() error {
return l.baseServer.Close()
}
// Addr returns the local network address that the server is listening on.
func (l *Listener) Addr() net.Addr {
return l.baseServer.Addr()
}
// An EarlyListener listens for incoming QUIC connections, and returns them before the handshake completes.
// For connections that don't use 0-RTT, this allows the server to send 0.5-RTT data.
// This data is encrypted with forward-secure keys, however, the client's identity has not yet been verified.
// For connections using 0-RTT, this allows the server to accept and respond to streams that the client opened in the
// 0-RTT data it sent. Note that at this point during the handshake, the liveness of the
// client has not yet been confirmed, and the 0-RTT data could have been replayed by an attacker.
type EarlyListener struct {
baseServer *baseServer
}
// Accept returns new connections. It should be called in a loop.
func (l *EarlyListener) Accept(ctx context.Context) (*Conn, error) {
conn, err := l.baseServer.accept(ctx)
if err != nil {
return nil, err
}
return conn, nil
}
// Close closes the server. All active connections will be closed.
func (l *EarlyListener) Close() error {
return l.baseServer.Close()
}
// Addr returns the local network addr that the server is listening on.
func (l *EarlyListener) Addr() net.Addr {
return l.baseServer.Addr()
}
// ListenAddr creates a QUIC server listening on a given address.
// See [Listen] for more details.
func ListenAddr(addr string, tlsConf *tls.Config, config *Config) (*Listener, error) {
conn, err := listenUDP(addr)
if err != nil {
return nil, err
}
return (&Transport{
Conn: conn,
createdConn: true,
isSingleUse: true,
}).Listen(tlsConf, config)
}
// ListenAddrEarly works like [ListenAddr], but it returns connections before the handshake completes.
func ListenAddrEarly(addr string, tlsConf *tls.Config, config *Config) (*EarlyListener, error) {
conn, err := listenUDP(addr)
if err != nil {
return nil, err
}
return (&Transport{
Conn: conn,
createdConn: true,
isSingleUse: true,
}).ListenEarly(tlsConf, config)
}
func listenUDP(addr string) (*net.UDPConn, error) {
udpAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return nil, err
}
return net.ListenUDP("udp", udpAddr)
}
// Listen listens for QUIC connections on a given net.PacketConn.
// If the PacketConn satisfies the [OOBCapablePacketConn] interface (as a [net.UDPConn] does),
// ECN and packet info support will be enabled. In this case, ReadMsgUDP and WriteMsgUDP
// will be used instead of ReadFrom and WriteTo to read/write packets.
// A single net.PacketConn can only be used for a single call to Listen.
//
// The tls.Config must not be nil and must contain a certificate configuration.
// Furthermore, it must define an application protocol (using [NextProtos]).
// The quic.Config may be nil; in that case, default values are used.
//
// This is a convenience function. More advanced use cases should instantiate a [Transport],
// which offers configuration options for a more fine-grained control of the connection establishment,
// including reusing the underlying UDP socket for outgoing QUIC connections.
// When closing a listener created with Listen, all established QUIC connections will be closed immediately.
func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (*Listener, error) {
tr := &Transport{Conn: conn, isSingleUse: true}
return tr.Listen(tlsConf, config)
}
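// A minimal accept loop (hypothetical caller code; handleConn is a placeholder,
// error handling elided):
//
//	ln, _ := quic.Listen(udpConn, tlsConf, nil)
//	for {
//		conn, err := ln.Accept(context.Background())
//		if err != nil {
//			return // e.g. ErrServerClosed after ln.Close()
//		}
//		go handleConn(conn)
//	}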
// ListenEarly works like [Listen], but it returns connections before the handshake completes.
func ListenEarly(conn net.PacketConn, tlsConf *tls.Config, config *Config) (*EarlyListener, error) {
tr := &Transport{Conn: conn, isSingleUse: true}
return tr.ListenEarly(tlsConf, config)
}
func newServer(
conn rawConn,
tr *packetHandlerMap,
connIDGenerator ConnectionIDGenerator,
statelessResetter *statelessResetter,
connContext func(context.Context, *ClientInfo) (context.Context, error),
tlsConf *tls.Config,
config *Config,
tracer *logging.Tracer,
onClose func(),
tokenGeneratorKey TokenGeneratorKey,
maxTokenAge time.Duration,
verifySourceAddress func(net.Addr) bool,
disableVersionNegotiation bool,
acceptEarly bool,
) *baseServer {
s := &baseServer{
conn: conn,
connContext: connContext,
tr: tr,
tlsConf: tlsConf,
config: config,
tokenGenerator: handshake.NewTokenGenerator(tokenGeneratorKey),
maxTokenAge: maxTokenAge,
verifySourceAddress: verifySourceAddress,
connIDGenerator: connIDGenerator,
statelessResetter: statelessResetter,
connQueue: make(chan *Conn, protocol.MaxAcceptQueueSize),
errorChan: make(chan struct{}),
stopAccepting: make(chan struct{}),
running: make(chan struct{}),
receivedPackets: make(chan receivedPacket, protocol.MaxServerUnprocessedPackets),
versionNegotiationQueue: make(chan receivedPacket, 4),
invalidTokenQueue: make(chan rejectedPacket, 4),
connectionRefusedQueue: make(chan rejectedPacket, 4),
retryQueue: make(chan rejectedPacket, 8),
newConn: newConnection,
tracer: tracer,
logger: utils.DefaultLogger.WithPrefix("server"),
acceptEarlyConns: acceptEarly,
disableVersionNegotiation: disableVersionNegotiation,
onClose: onClose,
}
if acceptEarly {
s.zeroRTTQueues = map[protocol.ConnectionID]*zeroRTTQueue{}
}
go s.run()
go s.runSendQueue()
s.logger.Debugf("Listening for %s connections on %s", conn.LocalAddr().Network(), conn.LocalAddr().String())
return s
}
func (s *baseServer) run() {
defer close(s.running)
for {
select {
case <-s.errorChan:
return
default:
}
select {
case <-s.errorChan:
return
case p := <-s.receivedPackets:
if bufferStillInUse := s.handlePacketImpl(p); !bufferStillInUse {
p.buffer.Release()
}
}
}
}
func (s *baseServer) runSendQueue() {
for {
select {
case <-s.running:
return
case p := <-s.versionNegotiationQueue:
s.maybeSendVersionNegotiationPacket(p)
case p := <-s.invalidTokenQueue:
s.maybeSendInvalidToken(p)
case p := <-s.connectionRefusedQueue:
s.sendConnectionRefused(p)
case p := <-s.retryQueue:
s.sendRetry(p)
}
}
}
// Accept returns connections that already completed the handshake.
// It is only valid if acceptEarlyConns is false.
func (s *baseServer) Accept(ctx context.Context) (*Conn, error) {
return s.accept(ctx)
}
func (s *baseServer) accept(ctx context.Context) (*Conn, error) {
select {
case <-ctx.Done():
return nil, ctx.Err()
case conn := <-s.connQueue:
return conn, nil
case <-s.stopAccepting:
// first drain the queue
select {
case conn := <-s.connQueue:
return conn, nil
default:
}
return nil, s.closeErr
}
}
func (s *baseServer) Close() error {
s.close(ErrServerClosed, false)
return nil
}
// close closes the server. The Transport mutex must not be held while calling this method,
// since closing any handshaking connections requires acquiring the transport mutex.
func (s *baseServer) close(e error, transportClose bool) {
s.closeMx.Lock()
if s.closeErr != nil {
s.closeMx.Unlock()
return
}
s.closeErr = e
close(s.errorChan)
<-s.running
s.closeMx.Unlock()
if !transportClose {
s.onClose()
}
// wait until all handshakes in flight have terminated
s.handshakingCount.Wait()
close(s.stopAccepting)
if transportClose {
// if the transport is closing, drain the connQueue. All connections in the queue
// will be closed by the transport.
for {
select {
case <-s.connQueue:
default:
return
}
}
}
}
// Addr returns the server's network address
func (s *baseServer) Addr() net.Addr {
return s.conn.LocalAddr()
}
func (s *baseServer) handlePacket(p receivedPacket) {
select {
case s.receivedPackets <- p:
case <-s.errorChan:
return
default:
s.logger.Debugf("Dropping packet from %s (%d bytes). Server receive queue full.", p.remoteAddr, p.Size())
if s.tracer != nil && s.tracer.DroppedPacket != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropDOSPrevention)
}
}
}
func (s *baseServer) handlePacketImpl(p receivedPacket) bool /* is the buffer still in use? */ {
if !s.nextZeroRTTCleanup.IsZero() && p.rcvTime.After(s.nextZeroRTTCleanup) {
defer s.cleanupZeroRTTQueues(p.rcvTime)
}
if wire.IsVersionNegotiationPacket(p.data) {
s.logger.Debugf("Dropping Version Negotiation packet.")
if s.tracer != nil && s.tracer.DroppedPacket != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeVersionNegotiation, p.Size(), logging.PacketDropUnexpectedPacket)
}
return false
}
// Short header packets should never end up here in the first place
if !wire.IsLongHeaderPacket(p.data[0]) {
panic(fmt.Sprintf("misrouted packet: %#v", p.data))
}
v, err := wire.ParseVersion(p.data)
// drop the packet if we failed to parse the protocol version
if err != nil {
s.logger.Debugf("Dropping a packet with an unknown version")
if s.tracer != nil && s.tracer.DroppedPacket != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket)
}
return false
}
// send a Version Negotiation Packet if the client is speaking a different protocol version
if !protocol.IsSupportedVersion(s.config.Versions, v) {
if s.disableVersionNegotiation {
if s.tracer != nil && s.tracer.DroppedPacket != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedVersion)
}
return false
}
if p.Size() < protocol.MinUnknownVersionPacketSize {
s.logger.Debugf("Dropping a packet with an unsupported version number %d that is too small (%d bytes)", v, p.Size())
if s.tracer != nil && s.tracer.DroppedPacket != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket)
}
return false
}
return s.enqueueVersionNegotiationPacket(p)
}
if wire.Is0RTTPacket(p.data) {
if !s.acceptEarlyConns {
if s.tracer != nil && s.tracer.DroppedPacket != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropUnexpectedPacket)
}
return false
}
return s.handle0RTTPacket(p)
}
// If we're creating a new connection, the packet will be passed to the connection.
// The header will then be parsed again.
hdr, _, _, err := wire.ParsePacket(p.data)
if err != nil {
if s.tracer != nil && s.tracer.DroppedPacket != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError)
}
s.logger.Debugf("Error parsing packet: %s", err)
return false
}
if hdr.Type == protocol.PacketTypeInitial && p.Size() < protocol.MinInitialPacketSize {
s.logger.Debugf("Dropping a packet that is too small to be a valid Initial (%d bytes)", p.Size())
if s.tracer != nil && s.tracer.DroppedPacket != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket)
}
return false
}
if hdr.Type != protocol.PacketTypeInitial {
// Drop long header packets.
// There's little point in sending a Stateless Reset, since the client
// might not have received the token yet.
s.logger.Debugf("Dropping long header packet of type %s (%d bytes)", hdr.Type, len(p.data))
if s.tracer != nil && s.tracer.DroppedPacket != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeFromHeader(hdr), p.Size(), logging.PacketDropUnexpectedPacket)
}
return false
}
s.logger.Debugf("<- Received Initial packet.")
if err := s.handleInitialImpl(p, hdr); err != nil {
s.logger.Errorf("Error occurred handling initial packet: %s", err)
}
// Don't put the packet buffer back.
// handleInitialImpl deals with the buffer.
return true
}
func (s *baseServer) handle0RTTPacket(p receivedPacket) bool {
connID, err := wire.ParseConnectionID(p.data, 0)
if err != nil {
if s.tracer != nil && s.tracer.DroppedPacket != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropHeaderParseError)
}
return false
}
// check again if we might have a connection now
if handler, ok := s.tr.Get(connID); ok {
handler.handlePacket(p)
return true
}
if q, ok := s.zeroRTTQueues[connID]; ok {
if len(q.packets) >= protocol.Max0RTTQueueLen {
if s.tracer != nil && s.tracer.DroppedPacket != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention)
}
return false
}
q.packets = append(q.packets, p)
return true
}
if len(s.zeroRTTQueues) >= protocol.Max0RTTQueues {
if s.tracer != nil && s.tracer.DroppedPacket != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention)
}
return false
}
queue := &zeroRTTQueue{packets: make([]receivedPacket, 1, 8)}
queue.packets[0] = p
expiration := p.rcvTime.Add(protocol.Max0RTTQueueingDuration)
queue.expiration = expiration
if s.nextZeroRTTCleanup.IsZero() || s.nextZeroRTTCleanup.After(expiration) {
s.nextZeroRTTCleanup = expiration
}
s.zeroRTTQueues[connID] = queue
return true
}
func (s *baseServer) cleanupZeroRTTQueues(now time.Time) {
// Iterate over all queues to find those that are expired.
// This is ok since we're placing a pretty low limit on the number of queues.
var nextCleanup time.Time
for connID, q := range s.zeroRTTQueues {
if q.expiration.After(now) {
if nextCleanup.IsZero() || nextCleanup.After(q.expiration) {
nextCleanup = q.expiration
}
continue
}
for _, p := range q.packets {
if s.tracer != nil && s.tracer.DroppedPacket != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention)
}
p.buffer.Release()
}
delete(s.zeroRTTQueues, connID)
if s.logger.Debug() {
s.logger.Debugf("Removing 0-RTT queue for %s.", connID)
}
}
s.nextZeroRTTCleanup = nextCleanup
}
// validateToken returns false if:
// - the token is nil,
// - the token doesn't match the remote address, or
// - the token has expired.
func (s *baseServer) validateToken(token *handshake.Token, addr net.Addr) bool {
if token == nil {
return false
}
if !token.ValidateRemoteAddr(addr) {
return false
}
if !token.IsRetryToken && time.Since(token.SentTime) > s.maxTokenAge {
return false
}
if token.IsRetryToken && time.Since(token.SentTime) > s.config.maxRetryTokenAge() {
return false
}
return true
}
func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error {
if len(hdr.Token) == 0 && hdr.DestConnectionID.Len() < protocol.MinConnectionIDLenInitial {
if s.tracer != nil && s.tracer.DroppedPacket != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket)
}
p.buffer.Release()
return errors.New("too short connection ID")
}
// The server queues packets for a while, and we might already have established a connection by now.
// This results in a second check in the connection map.
// That's ok since it's not the hot path (it's only taken by some Initial and 0-RTT packets).
if handler, ok := s.tr.Get(hdr.DestConnectionID); ok {
handler.handlePacket(p)
return nil
}
var (
token *handshake.Token
retrySrcConnID *protocol.ConnectionID
clientAddrVerified bool
)
origDestConnID := hdr.DestConnectionID
if len(hdr.Token) > 0 {
tok, err := s.tokenGenerator.DecodeToken(hdr.Token)
if err == nil {
if tok.IsRetryToken {
origDestConnID = tok.OriginalDestConnectionID
retrySrcConnID = &tok.RetrySrcConnectionID
}
token = tok
}
}
if token != nil {
clientAddrVerified = s.validateToken(token, p.remoteAddr)
if !clientAddrVerified {
// For invalid and expired non-retry tokens, we don't send an INVALID_TOKEN error.
// We just ignore them, and act as if there was no token on this packet at all.
// This also means we might send a Retry later.
if !token.IsRetryToken {
token = nil
} else {
// For Retry tokens, we send an INVALID_TOKEN error if
// * the token is too old, or
// * the token is invalid.
select {
case s.invalidTokenQueue <- rejectedPacket{receivedPacket: p, hdr: hdr}:
default:
// drop packet if we can't send out the INVALID_TOKEN packets fast enough
p.buffer.Release()
}
return nil
}
}
}
if token == nil && s.verifySourceAddress != nil && s.verifySourceAddress(p.remoteAddr) {
// Retry invalidates all 0-RTT packets sent.
delete(s.zeroRTTQueues, hdr.DestConnectionID)
select {
case s.retryQueue <- rejectedPacket{receivedPacket: p, hdr: hdr}:
default:
// drop packet if we can't send out Retry packets fast enough
p.buffer.Release()
}
return nil
}
// restore RTT from token
var rtt time.Duration
if token != nil && !token.IsRetryToken {
rtt = token.RTT
}
config := s.config
clientInfo := &ClientInfo{
RemoteAddr: p.remoteAddr,
AddrVerified: clientAddrVerified,
}
if s.config.GetConfigForClient != nil {
conf, err := s.config.GetConfigForClient(clientInfo)
if err != nil {
s.logger.Debugf("Rejecting new connection due to GetConfigForClient callback")
s.refuseNewConn(p, hdr)
return nil
}
config = populateConfig(conf)
}
var conn *wrappedConn
var cancel context.CancelCauseFunc
ctx, cancel1 := context.WithCancelCause(context.Background())
if s.connContext != nil {
var err error
ctx, err = s.connContext(ctx, clientInfo)
if err != nil {
cancel1(err)
s.logger.Debugf("Rejecting new connection due to ConnContext callback: %s", err)
s.refuseNewConn(p, hdr)
return nil
}
if ctx == nil {
panic("quic: ConnContext returned nil")
}
// There's no guarantee that the application returns a context
// that's derived from the context we passed into ConnContext.
// We need to make sure that both contexts are cancelled.
var cancel2 context.CancelCauseFunc
ctx, cancel2 = context.WithCancelCause(ctx)
cancel = func(cause error) {
cancel1(cause)
cancel2(cause)
}
} else {
cancel = cancel1
}
ctx = context.WithValue(ctx, ConnectionTracingKey, nextConnTracingID())
var tracer *logging.ConnectionTracer
if config.Tracer != nil {
// Use the same connection ID that is passed to the client's GetLogWriter callback.
connID := hdr.DestConnectionID
if origDestConnID.Len() > 0 {
connID = origDestConnID
}
tracer = config.Tracer(ctx, protocol.PerspectiveServer, connID)
}
connID, err := s.connIDGenerator.GenerateConnectionID()
if err != nil {
return err
}
s.logger.Debugf("Changing connection ID to %s.", connID)
conn = s.newConn(
ctx,
cancel,
newSendConn(s.conn, p.remoteAddr, p.info, s.logger),
s.tr,
origDestConnID,
retrySrcConnID,
hdr.DestConnectionID,
hdr.SrcConnectionID,
connID,
s.connIDGenerator,
s.statelessResetter,
config,
s.tlsConf,
s.tokenGenerator,
clientAddrVerified,
rtt,
tracer,
s.logger,
hdr.Version,
)
conn.handlePacket(p)
// Adding the connection will fail if the client's chosen Destination Connection ID is already in use.
// This is very unlikely: Even if an attacker chooses a connection ID that's already in use,
// under normal circumstances the packet would just be routed to that connection.
// The only time this collision can occur is if we receive the two Initial packets at (almost) the same time.
if added := s.tr.AddWithConnID(hdr.DestConnectionID, connID, conn); !added {
delete(s.zeroRTTQueues, hdr.DestConnectionID)
conn.closeWithTransportError(ConnectionRefused)
return nil
}
// Pass queued 0-RTT to the newly established connection.
if q, ok := s.zeroRTTQueues[hdr.DestConnectionID]; ok {
for _, p := range q.packets {
conn.handlePacket(p)
}
delete(s.zeroRTTQueues, hdr.DestConnectionID)
}
s.handshakingCount.Add(1)
go func() {
defer s.handshakingCount.Done()
s.handleNewConn(conn)
}()
go conn.run()
return nil
}
func (s *baseServer) refuseNewConn(p receivedPacket, hdr *wire.Header) {
delete(s.zeroRTTQueues, hdr.DestConnectionID)
select {
case s.connectionRefusedQueue <- rejectedPacket{receivedPacket: p, hdr: hdr}:
default:
// drop packet if we can't send out the CONNECTION_REFUSED fast enough
p.buffer.Release()
}
}
func (s *baseServer) handleNewConn(conn *wrappedConn) {
if s.acceptEarlyConns {
// wait until the early connection is ready, the handshake fails, or the server is closed
select {
case <-s.errorChan:
conn.closeWithTransportError(ConnectionRefused)
return
case <-conn.Context().Done():
return
case <-conn.earlyConnReady():
}
} else {
// wait until the handshake completes, fails, or the server is closed
select {
case <-s.errorChan:
conn.closeWithTransportError(ConnectionRefused)
return
case <-conn.Context().Done():
return
case <-conn.HandshakeComplete():
}
}
select {
case s.connQueue <- conn.Conn:
default:
conn.closeWithTransportError(ConnectionRefused)
}
}
func (s *baseServer) sendRetry(p rejectedPacket) {
if err := s.sendRetryPacket(p); err != nil {
s.logger.Debugf("Error sending Retry packet: %s", err)
}
}
func (s *baseServer) sendRetryPacket(p rejectedPacket) error {
hdr := p.hdr
// Log the Initial packet now.
// If no Retry is sent, the packet will be logged by the connection.
(&wire.ExtendedHeader{Header: *hdr}).Log(s.logger)
srcConnID, err := s.connIDGenerator.GenerateConnectionID()
if err != nil {
return err
}
token, err := s.tokenGenerator.NewRetryToken(p.remoteAddr, hdr.DestConnectionID, srcConnID)
if err != nil {
return err
}
replyHdr := &wire.ExtendedHeader{}
replyHdr.Type = protocol.PacketTypeRetry
replyHdr.Version = hdr.Version
replyHdr.SrcConnectionID = srcConnID
replyHdr.DestConnectionID = hdr.SrcConnectionID
replyHdr.Token = token
if s.logger.Debug() {
s.logger.Debugf("Changing connection ID to %s.", srcConnID)
s.logger.Debugf("-> Sending Retry")
replyHdr.Log(s.logger)
}
buf := getPacketBuffer()
defer buf.Release()
buf.Data, err = replyHdr.Append(buf.Data, hdr.Version)
if err != nil {
return err
}
// append the Retry integrity tag
tag := handshake.GetRetryIntegrityTag(buf.Data, hdr.DestConnectionID, hdr.Version)
buf.Data = append(buf.Data, tag[:]...)
if s.tracer != nil && s.tracer.SentPacket != nil {
s.tracer.SentPacket(p.remoteAddr, &replyHdr.Header, protocol.ByteCount(len(buf.Data)), nil)
}
_, err = s.conn.WritePacket(buf.Data, p.remoteAddr, p.info.OOB(), 0, protocol.ECNUnsupported)
return err
}
func (s *baseServer) maybeSendInvalidToken(p rejectedPacket) {
defer p.buffer.Release()
// Only send INVALID_TOKEN if we can unprotect the packet.
// This makes sure that we won't send it for packets that were corrupted.
hdr := p.hdr
sealer, opener := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveServer, hdr.Version)
data := p.data[:hdr.ParsedLen()+hdr.Length]
extHdr, err := unpackLongHeader(opener, hdr, data)
if err != nil {
if s.tracer != nil && s.tracer.DroppedPacket != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropHeaderParseError)
}
return
}
hdrLen := extHdr.ParsedLen()
if _, err := opener.Open(data[hdrLen:hdrLen], data[hdrLen:], extHdr.PacketNumber, data[:hdrLen]); err != nil {
if s.tracer != nil && s.tracer.DroppedPacket != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropPayloadDecryptError)
}
return
}
if s.logger.Debug() {
s.logger.Debugf("Client sent an invalid retry token. Sending INVALID_TOKEN to %s.", p.remoteAddr)
}
if err := s.sendError(p.remoteAddr, hdr, sealer, InvalidToken, p.info); err != nil {
s.logger.Debugf("Error sending INVALID_TOKEN error: %s", err)
}
}
func (s *baseServer) sendConnectionRefused(p rejectedPacket) {
defer p.buffer.Release()
sealer, _ := handshake.NewInitialAEAD(p.hdr.DestConnectionID, protocol.PerspectiveServer, p.hdr.Version)
if err := s.sendError(p.remoteAddr, p.hdr, sealer, ConnectionRefused, p.info); err != nil {
s.logger.Debugf("Error sending CONNECTION_REFUSED error: %s", err)
}
}
// sendError sends the error as a response to the packet received with header hdr
func (s *baseServer) sendError(remoteAddr net.Addr, hdr *wire.Header, sealer handshake.LongHeaderSealer, errorCode qerr.TransportErrorCode, info packetInfo) error {
b := getPacketBuffer()
defer b.Release()
ccf := &wire.ConnectionCloseFrame{ErrorCode: uint64(errorCode)}
replyHdr := &wire.ExtendedHeader{}
replyHdr.Type = protocol.PacketTypeInitial
replyHdr.Version = hdr.Version
replyHdr.SrcConnectionID = hdr.DestConnectionID
replyHdr.DestConnectionID = hdr.SrcConnectionID
replyHdr.PacketNumberLen = protocol.PacketNumberLen4
replyHdr.Length = 4 /* packet number len */ + ccf.Length(hdr.Version) + protocol.ByteCount(sealer.Overhead())
var err error
b.Data, err = replyHdr.Append(b.Data, hdr.Version)
if err != nil {
return err
}
payloadOffset := len(b.Data)
b.Data, err = ccf.Append(b.Data, hdr.Version)
if err != nil {
return err
}
_ = sealer.Seal(b.Data[payloadOffset:payloadOffset], b.Data[payloadOffset:], replyHdr.PacketNumber, b.Data[:payloadOffset])
b.Data = b.Data[0 : len(b.Data)+sealer.Overhead()]
pnOffset := payloadOffset - int(replyHdr.PacketNumberLen)
sealer.EncryptHeader(
b.Data[pnOffset+4:pnOffset+4+16],
&b.Data[0],
b.Data[pnOffset:payloadOffset],
)
replyHdr.Log(s.logger)
wire.LogFrame(s.logger, ccf, true)
if s.tracer != nil && s.tracer.SentPacket != nil {
s.tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(b.Data)), []logging.Frame{ccf})
}
_, err = s.conn.WritePacket(b.Data, remoteAddr, info.OOB(), 0, protocol.ECNUnsupported)
return err
}
func (s *baseServer) enqueueVersionNegotiationPacket(p receivedPacket) (bufferInUse bool) {
select {
case s.versionNegotiationQueue <- p:
return true
default:
// it's fine to not send version negotiation packets when we are busy
}
return false
}
func (s *baseServer) maybeSendVersionNegotiationPacket(p receivedPacket) {
defer p.buffer.Release()
v, err := wire.ParseVersion(p.data)
if err != nil {
s.logger.Debugf("failed to parse version for sending version negotiation packet: %s", err)
return
}
_, src, dest, err := wire.ParseArbitraryLenConnectionIDs(p.data)
if err != nil { // should never happen
s.logger.Debugf("Dropping a packet with an unknown version for which we failed to parse connection IDs")
if s.tracer != nil && s.tracer.DroppedPacket != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket)
}
return
}
s.logger.Debugf("Client offered version %s, sending Version Negotiation", v)
data := wire.ComposeVersionNegotiation(dest, src, s.config.Versions)
if s.tracer != nil && s.tracer.SentVersionNegotiationPacket != nil {
s.tracer.SentVersionNegotiationPacket(p.remoteAddr, src, dest, s.config.Versions)
}
if _, err := s.conn.WritePacket(data, p.remoteAddr, p.info.OOB(), 0, protocol.ECNUnsupported); err != nil {
s.logger.Debugf("Error sending Version Negotiation: %s", err)
}
}
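// Illustrative sketch (not part of quic-go): a minimal accept loop using the
// public API that ends up in baseServer.Accept above. The TLS config is an
// assumption: it must contain a certificate and the ALPN protocols offered by
// clients. Accept blocks until a connection has completed its handshake, and
// returns quic.ErrServerClosed once the listener has been closed.
package main

import (
	"context"
	"crypto/tls"
	"log"

	"github.com/quic-go/quic-go"
)

func main() {
	var tlsConf *tls.Config // assumed: populated with Certificates and NextProtos
	ln, err := quic.ListenAddr("localhost:4433", tlsConf, &quic.Config{})
	if err != nil {
		log.Fatal(err)
	}
	defer ln.Close()
	for {
		conn, err := ln.Accept(context.Background())
		if err != nil {
			log.Fatal(err) // quic.ErrServerClosed after ln.Close()
		}
		log.Printf("new connection from %s", conn.RemoteAddr())
	}
}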
package quic
import (
"encoding/binary"
"errors"
"io"
)
const (
extTypeSNI = 0
extTypeECH = 0xfe0d
)
// findSNIAndECH parses the given byte slice as a ClientHello, and locates:
// - the position and length of the host_name inside the Server Name Indication (SNI) extension,
// - the position of the Encrypted Client Hello (ECH) extension.
// If no SNI extension is found, it returns -1 for the SNI position.
// If no ECH extension is found, it returns -1 for the ECH position.
func findSNIAndECH(data []byte) (sniPos, sniLen, echPos int, err error) {
if len(data) < 4 {
return 0, 0, 0, io.ErrUnexpectedEOF
}
if data[0] != 1 {
return 0, 0, 0, errors.New("not a ClientHello")
}
handshakeLen := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
if len(data) != 4+handshakeLen {
return 0, 0, 0, io.ErrUnexpectedEOF
}
parsePos := 4
// Skip protocol version (2 bytes)
if parsePos+2 > len(data) {
return 0, 0, 0, io.ErrUnexpectedEOF
}
parsePos += 2
// skip random (32 bytes)
if parsePos+32 > len(data) {
return 0, 0, 0, io.ErrUnexpectedEOF
}
parsePos += 32
// session ID
if parsePos+1 > len(data) {
return 0, 0, 0, io.ErrUnexpectedEOF
}
sessionIDLen := int(data[parsePos])
parsePos++
if parsePos+sessionIDLen > len(data) {
return 0, 0, 0, io.ErrUnexpectedEOF
}
parsePos += sessionIDLen
// cipher suites
if parsePos+2 > len(data) {
return 0, 0, 0, io.ErrUnexpectedEOF
}
cipherSuitesLen := int(binary.BigEndian.Uint16(data[parsePos:]))
parsePos += 2
if parsePos+cipherSuitesLen > len(data) {
return 0, 0, 0, io.ErrUnexpectedEOF
}
parsePos += cipherSuitesLen
// compression methods
if parsePos+1 > len(data) {
return 0, 0, 0, io.ErrUnexpectedEOF
}
compressionMethodsLen := int(data[parsePos])
parsePos++
if parsePos+compressionMethodsLen > len(data) {
return 0, 0, 0, io.ErrUnexpectedEOF
}
parsePos += compressionMethodsLen
// extensions
if parsePos+2 > len(data) {
return 0, 0, 0, io.ErrUnexpectedEOF
}
extensionsLen := int(binary.BigEndian.Uint16(data[parsePos:]))
parsePos += 2
if parsePos+extensionsLen > len(data) {
return 0, 0, 0, io.ErrUnexpectedEOF
}
extensionsStart := parsePos
extensions := data[extensionsStart : extensionsStart+extensionsLen]
// parse extensions
var extPos int
sniPos = -1
echPos = -1
for extPos+4 <= extensionsLen {
extType := binary.BigEndian.Uint16(extensions[extPos:])
extLen := int(binary.BigEndian.Uint16(extensions[extPos+2:]))
if extPos+4+extLen > extensionsLen {
return 0, 0, 0, io.ErrUnexpectedEOF
}
switch extType {
case extTypeSNI:
if sniPos != -1 {
return 0, 0, 0, errors.New("multiple SNI extensions")
}
sniData := extensions[extPos+4 : extPos+4+extLen]
if len(sniData) < 2 {
return 0, 0, 0, io.ErrUnexpectedEOF
}
nameListLen := int(binary.BigEndian.Uint16(sniData))
if len(sniData) != 2+nameListLen {
return 0, 0, 0, io.ErrUnexpectedEOF
}
listPos := 2
for listPos+3 <= nameListLen+2 {
nameType := sniData[listPos]
sniLen = int(binary.BigEndian.Uint16(sniData[listPos+1:]))
if listPos+3+sniLen > len(sniData) {
return 0, 0, 0, io.ErrUnexpectedEOF
}
if nameType == 0 { // host_name
sniPos = extensionsStart + extPos + 4 + listPos + 3
break // stop after first host_name
}
listPos += 3 + sniLen
}
if sniPos == -1 { // the SNI extension was present, but contained no host_name entry
return 0, 0, 0, errors.New("SNI host_name not found")
}
case extTypeECH:
if echPos != -1 {
return 0, 0, 0, errors.New("multiple ECH extensions")
}
echPos = extensionsStart + extPos
}
extPos += 4 + extLen
if sniPos != -1 && echPos != -1 {
break
}
}
return sniPos, sniLen, echPos, nil
}
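// Illustrative sketch (not part of quic-go): a caller in this package could use
// the offsets returned by findSNIAndECH to, for example, blank out the server
// name in a raw ClientHello in place. clientHello is assumed to hold the
// complete handshake message, as required by the length check above.
//
//	sniPos, sniLen, echPos, err := findSNIAndECH(clientHello)
//	if err == nil && sniPos != -1 {
//		for i := 0; i < sniLen; i++ {
//			clientHello[sniPos+i] = 'x' // overwrite the host_name bytes
//		}
//	}
//	_ = echPos // start of the ECH extension header, or -1 if absent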
package quic
import (
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"hash"
"sync"
"github.com/quic-go/quic-go/internal/protocol"
)
type statelessResetter struct {
mx sync.Mutex
h hash.Hash
}
// newStatelessResetter creates a new stateless reset generator.
// It is valid to use a nil key. In that case, a random key will be used.
// This makes it impossible for on-path attackers to shut down established connections.
func newStatelessResetter(key *StatelessResetKey) *statelessResetter {
var h hash.Hash
if key != nil {
h = hmac.New(sha256.New, key[:])
} else {
b := make([]byte, 32)
_, _ = rand.Read(b)
h = hmac.New(sha256.New, b)
}
return &statelessResetter{h: h}
}
func (r *statelessResetter) GetStatelessResetToken(connID protocol.ConnectionID) protocol.StatelessResetToken {
r.mx.Lock()
defer r.mx.Unlock()
var token protocol.StatelessResetToken
r.h.Write(connID.Bytes())
copy(token[:], r.h.Sum(nil))
r.h.Reset()
return token
}
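// Illustrative sketch (not part of quic-go): the derivation above boils down to
// a truncated HMAC-SHA256 of the connection ID under a fixed key, so the same
// key and connection ID always yield the same 16-byte stateless reset token.
// The key and connection ID below are made-up example values.
package main

import (
	"crypto/hmac"
	"crypto/sha256"
	"encoding/hex"
	"fmt"
)

func deriveResetToken(key, connID []byte) [16]byte {
	h := hmac.New(sha256.New, key)
	h.Write(connID)
	var token [16]byte
	copy(token[:], h.Sum(nil)) // keep the first 16 bytes of the 32-byte MAC
	return token
}

func main() {
	key := []byte("example-stateless-reset-key")
	connID := []byte{0xde, 0xad, 0xbe, 0xef}
	token := deriveResetToken(key, connID)
	fmt.Println(hex.EncodeToString(token[:]))
}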
package quic
import (
"context"
"net"
"os"
"sync"
"time"
"github.com/quic-go/quic-go/internal/ackhandler"
"github.com/quic-go/quic-go/internal/flowcontrol"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/wire"
)
type deadlineError struct{}
func (deadlineError) Error() string { return "deadline exceeded" }
func (deadlineError) Temporary() bool { return true }
func (deadlineError) Timeout() bool { return true }
func (deadlineError) Unwrap() error { return os.ErrDeadlineExceeded }
var errDeadline net.Error = &deadlineError{}
// The streamSender is notified by the stream about various events.
type streamSender interface {
onHasConnectionData()
onHasStreamData(protocol.StreamID, *SendStream)
onHasStreamControlFrame(protocol.StreamID, streamControlFrameGetter)
// must be called without holding the mutex that is acquired by closeForShutdown
onStreamCompleted(protocol.StreamID)
}
// Each of the two stream halves gets its own uniStreamSender.
// This is necessary in order to keep track of when both halves have been completed.
type uniStreamSender struct {
streamSender
onStreamCompletedImpl func()
onHasStreamControlFrameImpl func(protocol.StreamID, streamControlFrameGetter)
}
func (s *uniStreamSender) onHasStreamData(id protocol.StreamID, str *SendStream) {
s.streamSender.onHasStreamData(id, str)
}
func (s *uniStreamSender) onStreamCompleted(protocol.StreamID) { s.onStreamCompletedImpl() }
func (s *uniStreamSender) onHasStreamControlFrame(id protocol.StreamID, str streamControlFrameGetter) {
s.onHasStreamControlFrameImpl(id, str)
}
var _ streamSender = &uniStreamSender{}
type Stream struct {
receiveStr *ReceiveStream
sendStr *SendStream
completedMutex sync.Mutex
sender streamSender
receiveStreamCompleted bool
sendStreamCompleted bool
}
var (
_ outgoingStream = &Stream{}
_ sendStreamFrameHandler = &Stream{}
_ receiveStreamFrameHandler = &Stream{}
)
// newStream creates a new Stream
func newStream(
ctx context.Context,
streamID protocol.StreamID,
sender streamSender,
flowController flowcontrol.StreamFlowController,
supportsResetStreamAt bool,
) *Stream {
s := &Stream{sender: sender}
senderForSendStream := &uniStreamSender{
streamSender: sender,
onStreamCompletedImpl: func() {
s.completedMutex.Lock()
s.sendStreamCompleted = true
s.checkIfCompleted()
s.completedMutex.Unlock()
},
onHasStreamControlFrameImpl: func(id protocol.StreamID, str streamControlFrameGetter) {
sender.onHasStreamControlFrame(streamID, s)
},
}
s.sendStr = newSendStream(ctx, streamID, senderForSendStream, flowController, supportsResetStreamAt)
senderForReceiveStream := &uniStreamSender{
streamSender: sender,
onStreamCompletedImpl: func() {
s.completedMutex.Lock()
s.receiveStreamCompleted = true
s.checkIfCompleted()
s.completedMutex.Unlock()
},
onHasStreamControlFrameImpl: func(id protocol.StreamID, str streamControlFrameGetter) {
sender.onHasStreamControlFrame(streamID, s)
},
}
s.receiveStr = newReceiveStream(streamID, senderForReceiveStream, flowController)
return s
}
// StreamID returns the stream ID.
func (s *Stream) StreamID() protocol.StreamID {
// the result is the same for receiveStream and sendStream
return s.sendStr.StreamID()
}
// Read reads data from the stream.
// Read can be made to time out using [Stream.SetReadDeadline] and [Stream.SetDeadline].
// If the stream was canceled, the error is a [StreamError].
func (s *Stream) Read(p []byte) (int, error) {
return s.receiveStr.Read(p)
}
// Write writes data to the stream.
// Write can be made to time out using [Stream.SetWriteDeadline] or [Stream.SetDeadline].
// If the stream was canceled, the error is a [StreamError].
func (s *Stream) Write(p []byte) (int, error) {
return s.sendStr.Write(p)
}
// CancelWrite aborts sending on this stream.
// See [SendStream.CancelWrite] for more details.
func (s *Stream) CancelWrite(errorCode StreamErrorCode) {
s.sendStr.CancelWrite(errorCode)
}
// CancelRead aborts receiving on this stream.
// See [ReceiveStream.CancelRead] for more details.
func (s *Stream) CancelRead(errorCode StreamErrorCode) {
s.receiveStr.CancelRead(errorCode)
}
// Context returns a context that is canceled as soon as the write-side of the stream is closed.
// See [SendStream.Context] for more details.
func (s *Stream) Context() context.Context {
return s.sendStr.Context()
}
// Close closes the send-direction of the stream.
// It does not close the receive-direction of the stream.
func (s *Stream) Close() error {
return s.sendStr.Close()
}
func (s *Stream) handleResetStreamFrame(frame *wire.ResetStreamFrame, rcvTime time.Time) error {
return s.receiveStr.handleResetStreamFrame(frame, rcvTime)
}
func (s *Stream) handleStreamFrame(frame *wire.StreamFrame, rcvTime time.Time) error {
return s.receiveStr.handleStreamFrame(frame, rcvTime)
}
func (s *Stream) handleStopSendingFrame(frame *wire.StopSendingFrame) {
s.sendStr.handleStopSendingFrame(frame)
}
func (s *Stream) updateSendWindow(limit protocol.ByteCount) {
s.sendStr.updateSendWindow(limit)
}
func (s *Stream) enableResetStreamAt() {
s.sendStr.enableResetStreamAt()
}
func (s *Stream) popStreamFrame(maxBytes protocol.ByteCount, v protocol.Version) (_ ackhandler.StreamFrame, _ *wire.StreamDataBlockedFrame, hasMore bool) {
return s.sendStr.popStreamFrame(maxBytes, v)
}
func (s *Stream) getControlFrame(now time.Time) (_ ackhandler.Frame, ok, hasMore bool) {
f, ok, _ := s.sendStr.getControlFrame(now)
if ok {
return f, true, true
}
return s.receiveStr.getControlFrame(now)
}
// SetReadDeadline sets the deadline for future Read calls.
// See [ReceiveStream.SetReadDeadline] for more details.
func (s *Stream) SetReadDeadline(t time.Time) error {
return s.receiveStr.SetReadDeadline(t)
}
// SetWriteDeadline sets the deadline for future Write calls.
// See [SendStream.SetWriteDeadline] for more details.
func (s *Stream) SetWriteDeadline(t time.Time) error {
return s.sendStr.SetWriteDeadline(t)
}
// SetDeadline sets the read and write deadlines associated with the stream.
// It is equivalent to calling both SetReadDeadline and SetWriteDeadline.
func (s *Stream) SetDeadline(t time.Time) error {
_ = s.receiveStr.SetReadDeadline(t) // SetReadDeadline never errors
_ = s.sendStr.SetWriteDeadline(t) // SetWriteDeadline never errors
return nil
}
// closeForShutdown closes a stream abruptly.
// It makes Read and Write unblock (and return the error) immediately.
// The peer will NOT be informed about this: the stream is closed without sending a FIN or RST.
func (s *Stream) closeForShutdown(err error) {
s.sendStr.closeForShutdown(err)
s.receiveStr.closeForShutdown(err)
}
// checkIfCompleted is called from the uniStreamSender, when one of the stream halves is completed.
// It makes sure that the onStreamCompleted callback is only called if both receive and send side have completed.
func (s *Stream) checkIfCompleted() {
if s.sendStreamCompleted && s.receiveStreamCompleted {
s.sender.onStreamCompleted(s.StreamID())
}
}
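// Illustrative sketch (not part of quic-go): echoing a single message on a
// bidirectional stream using the Stream API above. conn is assumed to be an
// established *quic.Conn; the deadline makes Read fail with a timeout error if
// the peer stays silent, and Close only closes the send direction.
package example

import (
	"context"
	"errors"
	"io"
	"time"

	"github.com/quic-go/quic-go"
)

func echoOnce(ctx context.Context, conn *quic.Conn) error {
	str, err := conn.AcceptStream(ctx)
	if err != nil {
		return err
	}
	defer str.Close() // send direction only; the receive direction is unaffected
	_ = str.SetDeadline(time.Now().Add(5 * time.Second))
	buf := make([]byte, 1024)
	n, err := str.Read(buf)
	if err != nil && !errors.Is(err, io.EOF) {
		return err
	}
	_, err = str.Write(buf[:n])
	return err
}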
package quic
import (
"context"
"fmt"
"sync"
"time"
"github.com/quic-go/quic-go/internal/flowcontrol"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qerr"
"github.com/quic-go/quic-go/internal/wire"
)
// StreamLimitReachedError is returned from Conn.OpenStream and Conn.OpenUniStream
// when it is not possible to open a new stream because the number of open streams
// has reached the peer's stream limit.
type StreamLimitReachedError struct{}
func (e StreamLimitReachedError) Error() string { return "too many open streams" }
type streamsMap struct {
ctx context.Context // not used for cancellations, but carries the values associated with the connection
perspective protocol.Perspective
maxIncomingBidiStreams uint64
maxIncomingUniStreams uint64
sender streamSender
queueControlFrame func(wire.Frame)
newFlowController func(protocol.StreamID) flowcontrol.StreamFlowController
mutex sync.Mutex
outgoingBidiStreams *outgoingStreamsMap[*Stream]
outgoingUniStreams *outgoingStreamsMap[*SendStream]
incomingBidiStreams *incomingStreamsMap[*Stream]
incomingUniStreams *incomingStreamsMap[*ReceiveStream]
reset bool
supportsResetStreamAt bool
}
func newStreamsMap(
ctx context.Context,
sender streamSender,
queueControlFrame func(wire.Frame),
newFlowController func(protocol.StreamID) flowcontrol.StreamFlowController,
maxIncomingBidiStreams uint64,
maxIncomingUniStreams uint64,
perspective protocol.Perspective,
) *streamsMap {
m := &streamsMap{
ctx: ctx,
perspective: perspective,
queueControlFrame: queueControlFrame,
newFlowController: newFlowController,
maxIncomingBidiStreams: maxIncomingBidiStreams,
maxIncomingUniStreams: maxIncomingUniStreams,
sender: sender,
}
m.initMaps()
return m
}
func (m *streamsMap) initMaps() {
m.outgoingBidiStreams = newOutgoingStreamsMap(
protocol.StreamTypeBidi,
func(id protocol.StreamID) *Stream {
return newStream(m.ctx, id, m.sender, m.newFlowController(id), m.supportsResetStreamAt)
},
m.queueControlFrame,
m.perspective,
)
m.incomingBidiStreams = newIncomingStreamsMap(
protocol.StreamTypeBidi,
func(id protocol.StreamID) *Stream {
return newStream(m.ctx, id, m.sender, m.newFlowController(id), m.supportsResetStreamAt)
},
m.maxIncomingBidiStreams,
m.queueControlFrame,
m.perspective,
)
m.outgoingUniStreams = newOutgoingStreamsMap(
protocol.StreamTypeUni,
func(id protocol.StreamID) *SendStream {
return newSendStream(m.ctx, id, m.sender, m.newFlowController(id), m.supportsResetStreamAt)
},
m.queueControlFrame,
m.perspective,
)
m.incomingUniStreams = newIncomingStreamsMap(
protocol.StreamTypeUni,
func(id protocol.StreamID) *ReceiveStream {
return newReceiveStream(id, m.sender, m.newFlowController(id))
},
m.maxIncomingUniStreams,
m.queueControlFrame,
m.perspective,
)
}
func (m *streamsMap) OpenStream() (*Stream, error) {
m.mutex.Lock()
reset := m.reset
mm := m.outgoingBidiStreams
m.mutex.Unlock()
if reset {
return nil, Err0RTTRejected
}
return mm.OpenStream()
}
func (m *streamsMap) OpenStreamSync(ctx context.Context) (*Stream, error) {
m.mutex.Lock()
reset := m.reset
mm := m.outgoingBidiStreams
m.mutex.Unlock()
if reset {
return nil, Err0RTTRejected
}
return mm.OpenStreamSync(ctx)
}
func (m *streamsMap) OpenUniStream() (*SendStream, error) {
m.mutex.Lock()
reset := m.reset
mm := m.outgoingUniStreams
m.mutex.Unlock()
if reset {
return nil, Err0RTTRejected
}
return mm.OpenStream()
}
func (m *streamsMap) OpenUniStreamSync(ctx context.Context) (*SendStream, error) {
m.mutex.Lock()
reset := m.reset
mm := m.outgoingUniStreams
m.mutex.Unlock()
if reset {
return nil, Err0RTTRejected
}
return mm.OpenStreamSync(ctx)
}
func (m *streamsMap) AcceptStream(ctx context.Context) (*Stream, error) {
m.mutex.Lock()
reset := m.reset
mm := m.incomingBidiStreams
m.mutex.Unlock()
if reset {
return nil, Err0RTTRejected
}
return mm.AcceptStream(ctx)
}
func (m *streamsMap) AcceptUniStream(ctx context.Context) (*ReceiveStream, error) {
m.mutex.Lock()
reset := m.reset
mm := m.incomingUniStreams
m.mutex.Unlock()
if reset {
return nil, Err0RTTRejected
}
return mm.AcceptStream(ctx)
}
func (m *streamsMap) DeleteStream(id protocol.StreamID) error {
switch id.Type() {
case protocol.StreamTypeUni:
if id.InitiatedBy() == m.perspective {
return m.outgoingUniStreams.DeleteStream(id)
}
return m.incomingUniStreams.DeleteStream(id)
case protocol.StreamTypeBidi:
if id.InitiatedBy() == m.perspective {
return m.outgoingBidiStreams.DeleteStream(id)
}
return m.incomingBidiStreams.DeleteStream(id)
}
panic("")
}
func (m *streamsMap) HandleMaxStreamsFrame(f *wire.MaxStreamsFrame) {
switch f.Type {
case protocol.StreamTypeUni:
m.outgoingUniStreams.SetMaxStream(f.MaxStreamNum.StreamID(protocol.StreamTypeUni, m.perspective))
case protocol.StreamTypeBidi:
m.outgoingBidiStreams.SetMaxStream(f.MaxStreamNum.StreamID(protocol.StreamTypeBidi, m.perspective))
}
}
type sendStreamFrameHandler interface {
updateSendWindow(protocol.ByteCount)
handleStopSendingFrame(*wire.StopSendingFrame)
}
func (m *streamsMap) getSendStream(id protocol.StreamID) (sendStreamFrameHandler, error) {
switch id.Type() {
case protocol.StreamTypeUni:
if id.InitiatedBy() != m.perspective {
// an incoming unidirectional stream is a receive stream, not a send stream
return nil, &qerr.TransportError{
ErrorCode: qerr.StreamStateError,
ErrorMessage: fmt.Sprintf("invalid frame for send stream %d", id),
}
}
str, err := m.outgoingUniStreams.GetStream(id)
if str == nil || err != nil {
return nil, err
}
return str, nil
case protocol.StreamTypeBidi:
if id.InitiatedBy() == m.perspective {
str, err := m.outgoingBidiStreams.GetStream(id)
if str == nil || err != nil {
return nil, err
}
return str, nil
}
str, err := m.incomingBidiStreams.GetOrOpenStream(id)
if str == nil || err != nil {
return nil, err
}
return str, nil
}
panic("unreachable")
}
func (m *streamsMap) HandleMaxStreamDataFrame(f *wire.MaxStreamDataFrame) error {
str, err := m.getSendStream(f.StreamID)
if err != nil {
return err
}
if str == nil { // stream already deleted
return nil
}
str.updateSendWindow(f.MaximumStreamData)
return nil
}
func (m *streamsMap) HandleStopSendingFrame(f *wire.StopSendingFrame) error {
str, err := m.getSendStream(f.StreamID)
if err != nil {
return err
}
if str == nil { // stream already deleted
return nil
}
str.handleStopSendingFrame(f)
return nil
}
type receiveStreamFrameHandler interface {
handleResetStreamFrame(*wire.ResetStreamFrame, time.Time) error
handleStreamFrame(*wire.StreamFrame, time.Time) error
}
func (m *streamsMap) getReceiveStream(id protocol.StreamID) (receiveStreamFrameHandler, error) {
switch id.Type() {
case protocol.StreamTypeUni:
// an outgoing unidirectional stream is a send stream, not a receive stream
if id.InitiatedBy() == m.perspective {
return nil, &qerr.TransportError{
ErrorCode: qerr.StreamStateError,
ErrorMessage: fmt.Sprintf("invalid frame for receive stream %d", id),
}
}
str, err := m.incomingUniStreams.GetOrOpenStream(id)
if err != nil || str == nil {
return nil, err
}
return str, nil
case protocol.StreamTypeBidi:
var str *Stream
var err error
if id.InitiatedBy() == m.perspective {
str, err = m.outgoingBidiStreams.GetStream(id)
} else {
str, err = m.incomingBidiStreams.GetOrOpenStream(id)
}
if str == nil || err != nil {
return nil, err
}
return str, nil
}
panic("unreachable")
}
func (m *streamsMap) HandleStreamDataBlockedFrame(f *wire.StreamDataBlockedFrame) error {
if _, err := m.getReceiveStream(f.StreamID); err != nil {
return err
}
// We don't need to do anything in response to a STREAM_DATA_BLOCKED frame,
// but we need to make sure that the stream ID is valid.
return nil
}
func (m *streamsMap) HandleResetStreamFrame(f *wire.ResetStreamFrame, rcvTime time.Time) error {
str, err := m.getReceiveStream(f.StreamID)
if err != nil {
return err
}
if str == nil { // stream already deleted
return nil
}
return str.handleResetStreamFrame(f, rcvTime)
}
func (m *streamsMap) HandleStreamFrame(f *wire.StreamFrame, rcvTime time.Time) error {
str, err := m.getReceiveStream(f.StreamID)
if err != nil {
return err
}
if str == nil { // stream already deleted
return nil
}
return str.handleStreamFrame(f, rcvTime)
}
func (m *streamsMap) HandleTransportParameters(p *wire.TransportParameters) {
m.supportsResetStreamAt = p.EnableResetStreamAt
m.outgoingBidiStreams.EnableResetStreamAt()
m.outgoingUniStreams.EnableResetStreamAt()
m.outgoingBidiStreams.UpdateSendWindow(p.InitialMaxStreamDataBidiRemote)
m.outgoingBidiStreams.SetMaxStream(p.MaxBidiStreamNum.StreamID(protocol.StreamTypeBidi, m.perspective))
m.outgoingUniStreams.UpdateSendWindow(p.InitialMaxStreamDataUni)
m.outgoingUniStreams.SetMaxStream(p.MaxUniStreamNum.StreamID(protocol.StreamTypeUni, m.perspective))
}
func (m *streamsMap) CloseWithError(err error) {
m.outgoingBidiStreams.CloseWithError(err)
m.outgoingUniStreams.CloseWithError(err)
m.incomingBidiStreams.CloseWithError(err)
m.incomingUniStreams.CloseWithError(err)
}
// ResetFor0RTT is used when 0-RTT is rejected. In that case, the stream maps are
// 1. closed with an Err0RTTRejected, making calls to Open{Uni}Stream{Sync} / Accept{Uni}Stream return that error.
// 2. reset to their initial state, such that we can immediately process new incoming stream data.
// Afterwards, calls to Open{Uni}Stream{Sync} / Accept{Uni}Stream will continue to return the error,
// until UseResetMaps() has been called.
func (m *streamsMap) ResetFor0RTT() {
m.mutex.Lock()
defer m.mutex.Unlock()
m.reset = true
m.CloseWithError(Err0RTTRejected)
m.initMaps()
}
func (m *streamsMap) UseResetMaps() {
m.mutex.Lock()
m.reset = false
m.mutex.Unlock()
}
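// Illustrative sketch (not part of quic-go): how an application sees the two
// open paths above. OpenStream fails fast with a StreamLimitReachedError once
// the peer's stream limit is exhausted, while OpenStreamSync blocks until a
// MAX_STREAMS frame raises the limit or the context expires. conn is assumed
// to be an established *quic.Conn.
package example

import (
	"context"
	"errors"
	"time"

	"github.com/quic-go/quic-go"
)

func openWithFallback(conn *quic.Conn) (*quic.Stream, error) {
	str, err := conn.OpenStream()
	if err == nil {
		return str, nil
	}
	var limitErr *quic.StreamLimitReachedError
	if !errors.As(err, &limitErr) {
		return nil, err
	}
	// at the stream limit: wait up to two seconds for the peer to grant more credit
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()
	return conn.OpenStreamSync(ctx)
}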
package quic
import (
"context"
"fmt"
"sync"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qerr"
"github.com/quic-go/quic-go/internal/wire"
)
type incomingStream interface {
closeForShutdown(error)
}
// When a stream is deleted before it was accepted, we can't delete it from the map immediately.
// We need to wait until the application accepts it, and delete it then.
type incomingStreamEntry[T incomingStream] struct {
stream T
shouldDelete bool
}
type incomingStreamsMap[T incomingStream] struct {
mutex sync.RWMutex
newStreamChan chan struct{}
streamType protocol.StreamType
streams map[protocol.StreamID]incomingStreamEntry[T]
nextStreamToAccept protocol.StreamID // the next stream that will be returned by AcceptStream()
nextStreamToOpen protocol.StreamID // the next stream that the peer will open
maxStream protocol.StreamID // the highest stream that the peer is allowed to open
maxNumStreams uint64 // maximum number of streams
newStream func(protocol.StreamID) T
queueMaxStreamID func(*wire.MaxStreamsFrame)
closeErr error
}
func newIncomingStreamsMap[T incomingStream](
streamType protocol.StreamType,
newStream func(protocol.StreamID) T,
maxStreams uint64,
queueControlFrame func(wire.Frame),
pers protocol.Perspective,
) *incomingStreamsMap[T] {
var nextStreamToAccept protocol.StreamID
switch {
case streamType == protocol.StreamTypeBidi && pers == protocol.PerspectiveServer:
nextStreamToAccept = protocol.FirstIncomingBidiStreamServer
case streamType == protocol.StreamTypeBidi && pers == protocol.PerspectiveClient:
nextStreamToAccept = protocol.FirstIncomingBidiStreamClient
case streamType == protocol.StreamTypeUni && pers == protocol.PerspectiveServer:
nextStreamToAccept = protocol.FirstIncomingUniStreamServer
case streamType == protocol.StreamTypeUni && pers == protocol.PerspectiveClient:
nextStreamToAccept = protocol.FirstIncomingUniStreamClient
}
return &incomingStreamsMap[T]{
newStreamChan: make(chan struct{}, 1),
streamType: streamType,
streams: make(map[protocol.StreamID]incomingStreamEntry[T]),
maxStream: protocol.StreamNum(maxStreams).StreamID(streamType, pers.Opposite()),
maxNumStreams: maxStreams,
newStream: newStream,
nextStreamToOpen: nextStreamToAccept,
nextStreamToAccept: nextStreamToAccept,
queueMaxStreamID: func(f *wire.MaxStreamsFrame) { queueControlFrame(f) },
}
}
func (m *incomingStreamsMap[T]) AcceptStream(ctx context.Context) (T, error) {
// drain the newStreamChan, so we don't check the map twice if the stream doesn't exist
select {
case <-m.newStreamChan:
default:
}
m.mutex.Lock()
var id protocol.StreamID
var entry incomingStreamEntry[T]
for {
id = m.nextStreamToAccept
if m.closeErr != nil {
m.mutex.Unlock()
return *new(T), m.closeErr
}
var ok bool
entry, ok = m.streams[id]
if ok {
break
}
m.mutex.Unlock()
select {
case <-ctx.Done():
return *new(T), ctx.Err()
case <-m.newStreamChan:
}
m.mutex.Lock()
}
m.nextStreamToAccept += 4
// If this stream was completed before being accepted, we can delete it now.
if entry.shouldDelete {
if err := m.deleteStream(id); err != nil {
m.mutex.Unlock()
return *new(T), err
}
}
m.mutex.Unlock()
return entry.stream, nil
}
func (m *incomingStreamsMap[T]) GetOrOpenStream(id protocol.StreamID) (T, error) {
m.mutex.RLock()
if id > m.maxStream {
m.mutex.RUnlock()
return *new(T), &qerr.TransportError{
ErrorCode: qerr.StreamLimitError,
ErrorMessage: fmt.Sprintf("peer tried to open stream %d (current limit: %d)", id, m.maxStream),
}
}
// if the stream ID is smaller than the highest stream the peer has opened so far:
// * the stream exists in the map, and we can return it, or
// * the stream was already closed, and we return the zero value of T
if id < m.nextStreamToOpen {
var s T
// If the stream was already queued for deletion, and is just waiting to be accepted, don't return it.
if entry, ok := m.streams[id]; ok && !entry.shouldDelete {
s = entry.stream
}
m.mutex.RUnlock()
return s, nil
}
m.mutex.RUnlock()
m.mutex.Lock()
// no need to check the two error conditions from above again
// * maxStream can only increase, so if the id was valid before, it definitely is valid now
// * nextStreamToOpen is only modified by this function
for newNum := m.nextStreamToOpen; newNum <= id; newNum += 4 {
m.streams[newNum] = incomingStreamEntry[T]{stream: m.newStream(newNum)}
select {
case m.newStreamChan <- struct{}{}:
default:
}
}
m.nextStreamToOpen = id + 4
entry := m.streams[id]
m.mutex.Unlock()
return entry.stream, nil
}
func (m *incomingStreamsMap[T]) DeleteStream(id protocol.StreamID) error {
m.mutex.Lock()
defer m.mutex.Unlock()
if err := m.deleteStream(id); err != nil {
return &qerr.TransportError{
ErrorCode: qerr.StreamStateError,
ErrorMessage: err.Error(),
}
}
return nil
}
func (m *incomingStreamsMap[T]) deleteStream(id protocol.StreamID) error {
if _, ok := m.streams[id]; !ok {
return fmt.Errorf("tried to delete unknown incoming stream %d", id)
}
// Don't delete this stream yet, if it was not yet accepted.
// Just save it to streamsToDelete map, to make sure it is deleted as soon as it gets accepted.
if id >= m.nextStreamToAccept {
entry, ok := m.streams[id]
if ok && entry.shouldDelete {
return fmt.Errorf("tried to delete incoming stream %d multiple times", id)
}
entry.shouldDelete = true
m.streams[id] = entry // can't assign to struct in map, so we need to reassign
return nil
}
delete(m.streams, id)
// queue a MAX_STREAMS frame, giving the peer the option to open a new stream
if m.maxNumStreams > uint64(len(m.streams)) {
maxStream := m.nextStreamToOpen + 4*protocol.StreamID(m.maxNumStreams-uint64(len(m.streams))-1)
// never send a value larger than the maximum value for a stream number
if maxStream <= protocol.MaxStreamID {
m.maxStream = maxStream
m.queueMaxStreamID(&wire.MaxStreamsFrame{
Type: m.streamType,
MaxStreamNum: m.maxStream.StreamNum(),
})
}
}
return nil
}
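// Worked example for the MAX_STREAMS update above (illustrative numbers):
// with maxNumStreams = 100 incoming bidirectional streams on a server,
// 3 entries still tracked after this deletion, and nextStreamToOpen = 400,
// the new limit is 400 + 4*(100-3-1) = 784. Stream IDs of one type and
// initiator are spaced 4 apart, so the peer may open the 97 streams with IDs
// 400, 404, ..., 784; together with the 3 streams still in the map, at most
// 100 streams are tracked at any time.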
func (m *incomingStreamsMap[T]) CloseWithError(err error) {
m.mutex.Lock()
m.closeErr = err
for _, entry := range m.streams {
entry.stream.closeForShutdown(err)
}
m.mutex.Unlock()
close(m.newStreamChan)
}
package quic
import (
"context"
"fmt"
"slices"
"sync"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qerr"
"github.com/quic-go/quic-go/internal/wire"
)
type outgoingStream interface {
updateSendWindow(protocol.ByteCount)
enableResetStreamAt()
closeForShutdown(error)
}
type outgoingStreamsMap[T outgoingStream] struct {
mutex sync.RWMutex
streamType protocol.StreamType
streams map[protocol.StreamID]T
openQueue []chan struct{}
nextStream protocol.StreamID // stream ID of the stream returned by OpenStream(Sync)
maxStream protocol.StreamID // the maximum stream ID we're allowed to open
blockedSent bool // was a STREAMS_BLOCKED sent for the current maxStream
newStream func(protocol.StreamID) T
queueStreamIDBlocked func(*wire.StreamsBlockedFrame)
closeErr error
}
func newOutgoingStreamsMap[T outgoingStream](
streamType protocol.StreamType,
newStream func(protocol.StreamID) T,
queueControlFrame func(wire.Frame),
pers protocol.Perspective,
) *outgoingStreamsMap[T] {
var nextStream protocol.StreamID
switch {
case streamType == protocol.StreamTypeBidi && pers == protocol.PerspectiveServer:
nextStream = protocol.FirstOutgoingBidiStreamServer
case streamType == protocol.StreamTypeBidi && pers == protocol.PerspectiveClient:
nextStream = protocol.FirstOutgoingBidiStreamClient
case streamType == protocol.StreamTypeUni && pers == protocol.PerspectiveServer:
nextStream = protocol.FirstOutgoingUniStreamServer
case streamType == protocol.StreamTypeUni && pers == protocol.PerspectiveClient:
nextStream = protocol.FirstOutgoingUniStreamClient
}
return &outgoingStreamsMap[T]{
streamType: streamType,
streams: make(map[protocol.StreamID]T),
maxStream: protocol.InvalidStreamID,
nextStream: nextStream,
newStream: newStream,
queueStreamIDBlocked: func(f *wire.StreamsBlockedFrame) { queueControlFrame(f) },
}
}
func (m *outgoingStreamsMap[T]) OpenStream() (T, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
if m.closeErr != nil {
return *new(T), m.closeErr
}
// if there are OpenStreamSync calls waiting, return an error here
if len(m.openQueue) > 0 || m.nextStream > m.maxStream {
m.maybeSendBlockedFrame()
return *new(T), &StreamLimitReachedError{}
}
return m.openStream(), nil
}
func (m *outgoingStreamsMap[T]) OpenStreamSync(ctx context.Context) (T, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
if m.closeErr != nil {
return *new(T), m.closeErr
}
if err := ctx.Err(); err != nil {
return *new(T), err
}
if len(m.openQueue) == 0 && m.nextStream <= m.maxStream {
return m.openStream(), nil
}
waitChan := make(chan struct{}, 1)
m.openQueue = append(m.openQueue, waitChan)
m.maybeSendBlockedFrame()
for {
m.mutex.Unlock()
select {
case <-ctx.Done():
m.mutex.Lock()
m.openQueue = slices.DeleteFunc(m.openQueue, func(c chan struct{}) bool {
return c == waitChan
})
// If we just received a MAX_STREAMS frame, this might have been the next stream
// that could be opened. Make sure we unblock the next OpenStreamSync call.
m.maybeUnblockOpenSync()
return *new(T), ctx.Err()
case <-waitChan:
}
m.mutex.Lock()
if m.closeErr != nil {
return *new(T), m.closeErr
}
if m.nextStream > m.maxStream {
// no stream available. Continue waiting
continue
}
str := m.openStream()
m.openQueue = m.openQueue[1:]
m.maybeUnblockOpenSync()
return str, nil
}
}
func (m *outgoingStreamsMap[T]) openStream() T {
s := m.newStream(m.nextStream)
m.streams[m.nextStream] = s
m.nextStream += 4
return s
}
// maybeSendBlockedFrame queues a STREAMS_BLOCKED frame for the current stream offset,
// if we haven't sent one for this offset yet
func (m *outgoingStreamsMap[T]) maybeSendBlockedFrame() {
if m.blockedSent {
return
}
var streamLimit protocol.StreamNum
if m.maxStream != protocol.InvalidStreamID {
streamLimit = m.maxStream.StreamNum()
}
m.queueStreamIDBlocked(&wire.StreamsBlockedFrame{
Type: m.streamType,
StreamLimit: streamLimit,
})
m.blockedSent = true
}
func (m *outgoingStreamsMap[T]) GetStream(id protocol.StreamID) (T, error) {
m.mutex.RLock()
if id >= m.nextStream {
m.mutex.RUnlock()
return *new(T), &qerr.TransportError{
ErrorCode: qerr.StreamStateError,
ErrorMessage: fmt.Sprintf("peer attempted to open stream %d", id),
}
}
s := m.streams[id]
m.mutex.RUnlock()
return s, nil
}
func (m *outgoingStreamsMap[T]) DeleteStream(id protocol.StreamID) error {
m.mutex.Lock()
defer m.mutex.Unlock()
if _, ok := m.streams[id]; !ok {
return &qerr.TransportError{
ErrorCode: qerr.StreamStateError,
ErrorMessage: fmt.Sprintf("tried to delete unknown outgoing stream %d", id),
}
}
delete(m.streams, id)
return nil
}
func (m *outgoingStreamsMap[T]) SetMaxStream(id protocol.StreamID) {
m.mutex.Lock()
defer m.mutex.Unlock()
if id <= m.maxStream {
return
}
m.maxStream = id
m.blockedSent = false
if m.maxStream < m.nextStream-4+4*protocol.StreamID(len(m.openQueue)) {
m.maybeSendBlockedFrame()
}
m.maybeUnblockOpenSync()
}
// UpdateSendWindow is called when the peer's transport parameters are received.
// Only in the case of a 0-RTT handshake will we have open streams at this point.
// We might need to update the send window, in case the server increased it.
func (m *outgoingStreamsMap[T]) UpdateSendWindow(limit protocol.ByteCount) {
m.mutex.Lock()
for _, str := range m.streams {
str.updateSendWindow(limit)
}
m.mutex.Unlock()
}
func (m *outgoingStreamsMap[T]) EnableResetStreamAt() {
m.mutex.Lock()
for _, str := range m.streams {
str.enableResetStreamAt()
}
m.mutex.Unlock()
}
// maybeUnblockOpenSync unblocks the next OpenStreamSync goroutine to open a new stream
func (m *outgoingStreamsMap[T]) maybeUnblockOpenSync() {
if len(m.openQueue) == 0 {
return
}
if m.nextStream > m.maxStream {
return
}
// maybeUnblockOpenSync is called both from OpenStreamSync and from SetMaxStream.
// It's sufficient to only unblock OpenStreamSync once.
select {
case m.openQueue[0] <- struct{}{}:
default:
}
}
func (m *outgoingStreamsMap[T]) CloseWithError(err error) {
m.mutex.Lock()
defer m.mutex.Unlock()
m.closeErr = err
for _, str := range m.streams {
str.closeForShutdown(err)
}
for _, c := range m.openQueue {
if c != nil {
close(c)
}
}
m.openQueue = nil
}
package quic
import (
"io"
"log"
"net"
"os"
"strconv"
"strings"
"syscall"
"time"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils"
)
type connCapabilities struct {
// This connection has the Don't Fragment (DF) bit set.
// This means it makes sense to run DPLPMTUD.
DF bool
// GSO (Generic Segmentation Offload) supported
GSO bool
// ECN (Explicit Congestion Notifications) supported
ECN bool
}
// rawConn is a connection that allows reading of a receivedPacket.
type rawConn interface {
ReadPacket() (receivedPacket, error)
// WritePacket writes a packet on the wire.
// gsoSize is the size of a single packet, or 0 to disable GSO.
// It is invalid to set gsoSize if capabilities.GSO is not set.
WritePacket(b []byte, addr net.Addr, packetInfoOOB []byte, gsoSize uint16, ecn protocol.ECN) (int, error)
LocalAddr() net.Addr
SetReadDeadline(time.Time) error
io.Closer
capabilities() connCapabilities
}
// OOBCapablePacketConn is a connection that allows the reading of ECN bits from the IP header.
// If the PacketConn passed to the [Transport] satisfies this interface, quic-go will use it.
// In this case, ReadMsgUDP() will be used instead of ReadFrom() to read packets.
type OOBCapablePacketConn interface {
net.PacketConn
SyscallConn() (syscall.RawConn, error)
SetReadBuffer(int) error
ReadMsgUDP(b, oob []byte) (n, oobn, flags int, addr *net.UDPAddr, err error)
WriteMsgUDP(b, oob []byte, addr *net.UDPAddr) (n, oobn int, err error)
}
var _ OOBCapablePacketConn = &net.UDPConn{}
func wrapConn(pc net.PacketConn) (rawConn, error) {
if err := setReceiveBuffer(pc); err != nil {
if !strings.Contains(err.Error(), "use of closed network connection") {
setBufferWarningOnce.Do(func() {
if disable, _ := strconv.ParseBool(os.Getenv("QUIC_GO_DISABLE_RECEIVE_BUFFER_WARNING")); disable {
return
}
log.Printf("%s. See https://github.com/quic-go/quic-go/wiki/UDP-Buffer-Sizes for details.", err)
})
}
}
if err := setSendBuffer(pc); err != nil {
if !strings.Contains(err.Error(), "use of closed network connection") {
setBufferWarningOnce.Do(func() {
if disable, _ := strconv.ParseBool(os.Getenv("QUIC_GO_DISABLE_RECEIVE_BUFFER_WARNING")); disable {
return
}
log.Printf("%s. See https://github.com/quic-go/quic-go/wiki/UDP-Buffer-Sizes for details.", err)
})
}
}
conn, ok := pc.(interface {
SyscallConn() (syscall.RawConn, error)
})
var supportsDF bool
if ok {
rawConn, err := conn.SyscallConn()
if err != nil {
return nil, err
}
// only set DF on UDP sockets
if _, ok := pc.LocalAddr().(*net.UDPAddr); ok {
var err error
supportsDF, err = setDF(rawConn)
if err != nil {
return nil, err
}
}
}
c, ok := pc.(OOBCapablePacketConn)
if !ok {
utils.DefaultLogger.Infof("PacketConn is not a net.UDPConn. Disabling optimizations possible on UDP connections.")
return &basicConn{PacketConn: pc, supportsDF: supportsDF}, nil
}
return newConn(c, supportsDF)
}
// The basicConn is the most trivial implementation of a rawConn.
// It reads a single packet from the underlying net.PacketConn.
// It is used when
// * the net.PacketConn is not an OOBCapablePacketConn, or
// * the OS doesn't support OOB.
type basicConn struct {
net.PacketConn
supportsDF bool
}
var _ rawConn = &basicConn{}
func (c *basicConn) ReadPacket() (receivedPacket, error) {
buffer := getPacketBuffer()
// The packet size should not exceed protocol.MaxPacketBufferSize bytes
// If it does, we only read a truncated packet, which will then end up undecryptable
buffer.Data = buffer.Data[:protocol.MaxPacketBufferSize]
n, addr, err := c.ReadFrom(buffer.Data)
if err != nil {
return receivedPacket{}, err
}
return receivedPacket{
remoteAddr: addr,
rcvTime: time.Now(),
data: buffer.Data[:n],
buffer: buffer,
}, nil
}
func (c *basicConn) WritePacket(b []byte, addr net.Addr, _ []byte, gsoSize uint16, ecn protocol.ECN) (n int, err error) {
if gsoSize != 0 {
panic("cannot use GSO with a basicConn")
}
if ecn != protocol.ECNUnsupported {
panic("cannot use ECN with a basicConn")
}
return c.WriteTo(b, addr)
}
func (c *basicConn) capabilities() connCapabilities { return connCapabilities{DF: c.supportsDF} }
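// Illustrative sketch (not part of quic-go): handing an existing UDP socket to
// a Transport. Since *net.UDPConn satisfies OOBCapablePacketConn, the optimized
// read path is used; any other net.PacketConn falls back to the basicConn
// above. The TLS config is an assumption and must be populated by the caller.
package main

import (
	"crypto/tls"
	"log"
	"net"

	"github.com/quic-go/quic-go"
)

func main() {
	udpConn, err := net.ListenUDP("udp", &net.UDPAddr{Port: 4433})
	if err != nil {
		log.Fatal(err)
	}
	var tlsConf *tls.Config // assumed: populated with Certificates and NextProtos
	tr := &quic.Transport{Conn: udpConn}
	ln, err := tr.Listen(tlsConf, &quic.Config{})
	if err != nil {
		log.Fatal(err)
	}
	defer ln.Close()
	// accept connections from ln as in the server example further up
}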
package quic
import (
"errors"
"fmt"
"net"
"syscall"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils"
)
//go:generate sh -c "echo '// Code generated by go generate. DO NOT EDIT.\n// Source: sys_conn_buffers.go\n' > sys_conn_buffers_write.go && sed -e 's/SetReadBuffer/SetWriteBuffer/g' -e 's/setReceiveBuffer/setSendBuffer/g' -e 's/inspectReadBuffer/inspectWriteBuffer/g' -e 's/protocol\\.DesiredReceiveBufferSize/protocol\\.DesiredSendBufferSize/g' -e 's/forceSetReceiveBuffer/forceSetSendBuffer/g' -e 's/receive buffer/send buffer/g' sys_conn_buffers.go | sed '/^\\/\\/go:generate/d' >> sys_conn_buffers_write.go"
func setReceiveBuffer(c net.PacketConn) error {
conn, ok := c.(interface{ SetReadBuffer(int) error })
if !ok {
return errors.New("connection doesn't allow setting of receive buffer size. Not a *net.UDPConn?")
}
var syscallConn syscall.RawConn
if sc, ok := c.(interface {
SyscallConn() (syscall.RawConn, error)
}); ok {
var err error
syscallConn, err = sc.SyscallConn()
if err != nil {
syscallConn = nil
}
}
// The connection has a SetReadBuffer method, but we couldn't obtain a syscall.RawConn.
// This shouldn't happen for a net.UDPConn, but is possible if the connection just implements the
// net.PacketConn interface and the SetReadBuffer method.
// We have no way of checking if increasing the buffer size actually worked.
if syscallConn == nil {
return conn.SetReadBuffer(protocol.DesiredReceiveBufferSize)
}
size, err := inspectReadBuffer(syscallConn)
if err != nil {
return fmt.Errorf("failed to determine receive buffer size: %w", err)
}
if size >= protocol.DesiredReceiveBufferSize {
utils.DefaultLogger.Debugf("Conn has receive buffer of %d kiB (wanted: at least %d kiB)", size/1024, protocol.DesiredReceiveBufferSize/1024)
return nil
}
// Ignore the error. We check if we succeeded by querying the buffer size afterward.
_ = conn.SetReadBuffer(protocol.DesiredReceiveBufferSize)
newSize, err := inspectReadBuffer(syscallConn)
if newSize < protocol.DesiredReceiveBufferSize {
// Try again with RCVBUFFORCE on Linux
_ = forceSetReceiveBuffer(syscallConn, protocol.DesiredReceiveBufferSize)
newSize, err = inspectReadBuffer(syscallConn)
if err != nil {
return fmt.Errorf("failed to determine receive buffer size: %w", err)
}
}
if err != nil {
return fmt.Errorf("failed to determine receive buffer size: %w", err)
}
if newSize == size {
return fmt.Errorf("failed to increase receive buffer size (wanted: %d kiB, got %d kiB)", protocol.DesiredReceiveBufferSize/1024, newSize/1024)
}
if newSize < protocol.DesiredReceiveBufferSize {
return fmt.Errorf("failed to sufficiently increase receive buffer size (was: %d kiB, wanted: %d kiB, got: %d kiB)", size/1024, protocol.DesiredReceiveBufferSize/1024, newSize/1024)
}
utils.DefaultLogger.Debugf("Increased receive buffer size to %d kiB", newSize/1024)
return nil
}
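// Illustrative sketch (not part of the original source): how setReceiveBuffer and the
// generated setSendBuffer might be applied to a plain *net.UDPConn. A failure is treated
// as non-fatal here, since the socket keeps working with the kernel's default buffer
// sizes; the helper name exampleTuneSocketBuffers is an assumption.
func exampleTuneSocketBuffers() {
	udpConn, err := net.ListenUDP("udp", &net.UDPAddr{Port: 0})
	if err != nil {
		return
	}
	defer udpConn.Close()
	if err := setReceiveBuffer(udpConn); err != nil {
		utils.DefaultLogger.Debugf("failed to increase receive buffer size: %s", err)
	}
	if err := setSendBuffer(udpConn); err != nil {
		utils.DefaultLogger.Debugf("failed to increase send buffer size: %s", err)
	}
}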
// Code generated by go generate. DO NOT EDIT.
// Source: sys_conn_buffers.go
package quic
import (
"errors"
"fmt"
"net"
"syscall"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils"
)
func setSendBuffer(c net.PacketConn) error {
conn, ok := c.(interface{ SetWriteBuffer(int) error })
if !ok {
return errors.New("connection doesn't allow setting of send buffer size. Not a *net.UDPConn?")
}
var syscallConn syscall.RawConn
if sc, ok := c.(interface {
SyscallConn() (syscall.RawConn, error)
}); ok {
var err error
syscallConn, err = sc.SyscallConn()
if err != nil {
syscallConn = nil
}
}
// The connection has a SetWriteBuffer method, but we couldn't obtain a syscall.RawConn.
// This shouldn't happen for a net.UDPConn, but is possible if the connection just implements the
// net.PacketConn interface and the SetWriteBuffer method.
// We have no way of checking if increasing the buffer size actually worked.
if syscallConn == nil {
return conn.SetWriteBuffer(protocol.DesiredSendBufferSize)
}
size, err := inspectWriteBuffer(syscallConn)
if err != nil {
return fmt.Errorf("failed to determine send buffer size: %w", err)
}
if size >= protocol.DesiredSendBufferSize {
utils.DefaultLogger.Debugf("Conn has send buffer of %d kiB (wanted: at least %d kiB)", size/1024, protocol.DesiredSendBufferSize/1024)
return nil
}
// Ignore the error. We check if we succeeded by querying the buffer size afterward.
_ = conn.SetWriteBuffer(protocol.DesiredSendBufferSize)
newSize, err := inspectWriteBuffer(syscallConn)
if newSize < protocol.DesiredSendBufferSize {
// Try again with SNDBUFFORCE on Linux
_ = forceSetSendBuffer(syscallConn, protocol.DesiredSendBufferSize)
newSize, err = inspectWriteBuffer(syscallConn)
if err != nil {
return fmt.Errorf("failed to determine send buffer size: %w", err)
}
}
if err != nil {
return fmt.Errorf("failed to determine send buffer size: %w", err)
}
if newSize == size {
return fmt.Errorf("failed to increase send buffer size (wanted: %d kiB, got %d kiB)", protocol.DesiredSendBufferSize/1024, newSize/1024)
}
if newSize < protocol.DesiredSendBufferSize {
return fmt.Errorf("failed to sufficiently increase send buffer size (was: %d kiB, wanted: %d kiB, got: %d kiB)", size/1024, protocol.DesiredSendBufferSize/1024, newSize/1024)
}
utils.DefaultLogger.Debugf("Increased send buffer size to %d kiB", newSize/1024)
return nil
}
//go:build linux
package quic
import (
"errors"
"syscall"
"golang.org/x/sys/unix"
"github.com/quic-go/quic-go/internal/utils"
)
func setDF(rawConn syscall.RawConn) (bool, error) {
// Enabling IP_MTU_DISCOVER will force the kernel to return "sendto: message too long"
// and the datagram will not be fragmented
var errDFIPv4, errDFIPv6 error
if err := rawConn.Control(func(fd uintptr) {
errDFIPv4 = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_MTU_DISCOVER, unix.IP_PMTUDISC_PROBE)
errDFIPv6 = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_MTU_DISCOVER, unix.IPV6_PMTUDISC_PROBE)
}); err != nil {
return false, err
}
switch {
case errDFIPv4 == nil && errDFIPv6 == nil:
utils.DefaultLogger.Debugf("Setting DF for IPv4 and IPv6.")
case errDFIPv4 == nil && errDFIPv6 != nil:
utils.DefaultLogger.Debugf("Setting DF for IPv4.")
case errDFIPv4 != nil && errDFIPv6 == nil:
utils.DefaultLogger.Debugf("Setting DF for IPv6.")
case errDFIPv4 != nil && errDFIPv6 != nil:
return false, errors.New("setting DF failed for both IPv4 and IPv6")
}
return true, nil
}
func isSendMsgSizeErr(err error) bool {
// https://man7.org/linux/man-pages/man7/udp.7.html
return errors.Is(err, unix.EMSGSIZE)
}
func isRecvMsgSizeErr(error) bool { return false }
//go:build linux
package quic
import (
"encoding/binary"
"errors"
"net/netip"
"os"
"strconv"
"syscall"
"unsafe"
"golang.org/x/sys/unix"
)
const (
msgTypeIPTOS = unix.IP_TOS
ipv4PKTINFO = unix.IP_PKTINFO
)
const ecnIPv4DataLen = 1
const batchSize = 8 // needs to be smaller than MaxUint8 (otherwise the type of oobConn.readPos has to be changed)
var kernelVersionMajor int
func init() {
kernelVersionMajor, _ = kernelVersion()
}
func forceSetReceiveBuffer(c syscall.RawConn, bytes int) error {
var serr error
if err := c.Control(func(fd uintptr) {
serr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, bytes)
}); err != nil {
return err
}
return serr
}
func forceSetSendBuffer(c syscall.RawConn, bytes int) error {
var serr error
if err := c.Control(func(fd uintptr) {
serr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUFFORCE, bytes)
}); err != nil {
return err
}
return serr
}
func parseIPv4PktInfo(body []byte) (ip netip.Addr, ifIndex uint32, ok bool) {
// struct in_pktinfo {
// unsigned int ipi_ifindex; /* Interface index */
// struct in_addr ipi_spec_dst; /* Local address */
// struct in_addr ipi_addr; /* Header Destination address */
// };
if len(body) != 12 {
return netip.Addr{}, 0, false
}
return netip.AddrFrom4(*(*[4]byte)(body[8:12])), binary.NativeEndian.Uint32(body), true
}
// isGSOEnabled tests if the kernel supports GSO.
// Sending with GSO might still fail later on, if the interface doesn't support it (see isGSOError).
func isGSOEnabled(conn syscall.RawConn) bool {
if kernelVersionMajor < 5 {
return false
}
disabled, err := strconv.ParseBool(os.Getenv("QUIC_GO_DISABLE_GSO"))
if err == nil && disabled {
return false
}
var serr error
if err := conn.Control(func(fd uintptr) {
_, serr = unix.GetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_SEGMENT)
}); err != nil {
return false
}
return serr == nil
}
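// Illustrative note (not part of the original source): GSO support is probed once at
// socket setup, but it can also be turned off at runtime via the environment, e.g. when
// a NIC or driver misbehaves. Any value accepted by strconv.ParseBool works:
//
//	QUIC_GO_DISABLE_GSO=true ./your-quic-binary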
func appendUDPSegmentSizeMsg(b []byte, size uint16) []byte {
startLen := len(b)
const dataLen = 2 // payload is a uint16
b = append(b, make([]byte, unix.CmsgSpace(dataLen))...)
h := (*unix.Cmsghdr)(unsafe.Pointer(&b[startLen]))
h.Level = syscall.IPPROTO_UDP
h.Type = unix.UDP_SEGMENT
h.SetLen(unix.CmsgLen(dataLen))
// UnixRights uses the private `data` method, but I *think* this achieves the same goal.
offset := startLen + unix.CmsgSpace(0)
*(*uint16)(unsafe.Pointer(&b[offset])) = size
return b
}
func isGSOError(err error) bool {
var serr *os.SyscallError
if errors.As(err, &serr) {
// EIO is returned by udp_send_skb() if the device driver does not have tx checksums enabled,
// which is a hard requirement of UDP_SEGMENT. See:
// https://git.kernel.org/pub/scm/docs/man-pages/man-pages.git/tree/man7/udp.7?id=806eabd74910447f21005160e90957bde4db0183#n228
// https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git/tree/net/ipv4/udp.c?h=v6.2&id=c9c3395d5e3dcc6daee66c6908354d47bf98cb0c#n942
return serr.Err == unix.EIO
}
return false
}
// The first sendmsg call on a new UDP socket sometimes errors on Linux.
// It's not clear why this happens.
// See https://github.com/golang/go/issues/63322.
func isPermissionError(err error) bool {
var serr *os.SyscallError
if errors.As(err, &serr) {
return serr.Syscall == "sendmsg" && serr.Err == unix.EPERM
}
return false
}
func isECNEnabled() bool {
return kernelVersionMajor >= 5 && !isECNDisabledUsingEnv()
}
// kernelVersion returns major and minor kernel version numbers, parsed from
// the syscall.Uname's Release field, or 0, 0 if the version can't be obtained
// or parsed.
//
// copied from the standard library's internal/syscall/unix/kernel_version_linux.go
func kernelVersion() (major, minor int) {
var uname syscall.Utsname
if err := syscall.Uname(&uname); err != nil {
return
}
var (
values [2]int
value, vi int
)
for _, c := range uname.Release {
if '0' <= c && c <= '9' {
value = (value * 10) + int(c-'0')
} else {
// Note that we're assuming N.N.N here.
// If we see anything else, we are likely to mis-parse it.
values[vi] = value
vi++
if vi >= len(values) {
break
}
value = 0
}
}
return values[0], values[1]
}
//go:build darwin || linux || freebsd
package quic
import (
"encoding/binary"
"errors"
"log"
"net"
"net/netip"
"os"
"strconv"
"sync"
"syscall"
"time"
"unsafe"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
"golang.org/x/sys/unix"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils"
)
const (
ecnMask = 0x3
oobBufferSize = 128
)
// Contrary to what the naming suggests, ipv4.Message and ipv6.Message are not dependent on the IP version.
// They're both just aliases for x/net/internal/socket.Message.
// This means we can use either type to read from a socket that receives both IPv4 and IPv6 messages.
var _ ipv4.Message = ipv6.Message{}
type batchConn interface {
ReadBatch(ms []ipv4.Message, flags int) (int, error)
}
func inspectReadBuffer(c syscall.RawConn) (int, error) {
var size int
var serr error
if err := c.Control(func(fd uintptr) {
size, serr = unix.GetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUF)
}); err != nil {
return 0, err
}
return size, serr
}
func inspectWriteBuffer(c syscall.RawConn) (int, error) {
var size int
var serr error
if err := c.Control(func(fd uintptr) {
size, serr = unix.GetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUF)
}); err != nil {
return 0, err
}
return size, serr
}
func isECNDisabledUsingEnv() bool {
disabled, err := strconv.ParseBool(os.Getenv("QUIC_GO_DISABLE_ECN"))
return err == nil && disabled
}
type oobConn struct {
OOBCapablePacketConn
batchConn batchConn
readPos uint8
// Packets received from the kernel, but not yet returned by ReadPacket().
messages []ipv4.Message
buffers [batchSize]*packetBuffer
cap connCapabilities
}
var _ rawConn = &oobConn{}
func newConn(c OOBCapablePacketConn, supportsDF bool) (*oobConn, error) {
rawConn, err := c.SyscallConn()
if err != nil {
return nil, err
}
var needsPacketInfo bool
if udpAddr, ok := c.LocalAddr().(*net.UDPAddr); ok && udpAddr.IP.IsUnspecified() {
needsPacketInfo = true
}
// We don't know if this is an IPv4-only, an IPv6-only, or an IPv4-and-IPv6 connection.
// Try enabling receiving of ECN and packet info for both IP versions.
// We expect at least one of those syscalls to succeed.
var errECNIPv4, errECNIPv6, errPIIPv4, errPIIPv6 error
if err := rawConn.Control(func(fd uintptr) {
errECNIPv4 = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_RECVTOS, 1)
errECNIPv6 = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVTCLASS, 1)
if needsPacketInfo {
errPIIPv4 = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, ipv4PKTINFO, 1)
errPIIPv6 = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVPKTINFO, 1)
}
}); err != nil {
return nil, err
}
switch {
case errECNIPv4 == nil && errECNIPv6 == nil:
utils.DefaultLogger.Debugf("Activating reading of ECN bits for IPv4 and IPv6.")
case errECNIPv4 == nil && errECNIPv6 != nil:
utils.DefaultLogger.Debugf("Activating reading of ECN bits for IPv4.")
case errECNIPv4 != nil && errECNIPv6 == nil:
utils.DefaultLogger.Debugf("Activating reading of ECN bits for IPv6.")
case errECNIPv4 != nil && errECNIPv6 != nil:
return nil, errors.New("activating ECN failed for both IPv4 and IPv6")
}
if needsPacketInfo {
switch {
case errPIIPv4 == nil && errPIIPv6 == nil:
utils.DefaultLogger.Debugf("Activating reading of packet info for IPv4 and IPv6.")
case errPIIPv4 == nil && errPIIPv6 != nil:
utils.DefaultLogger.Debugf("Activating reading of packet info bits for IPv4.")
case errPIIPv4 != nil && errPIIPv6 == nil:
utils.DefaultLogger.Debugf("Activating reading of packet info bits for IPv6.")
case errPIIPv4 != nil && errPIIPv6 != nil:
return nil, errors.New("activating packet info failed for both IPv4 and IPv6")
}
}
// Allows callers to pass in a connection that already satisfies the batchConn interface
// to make use of the optimisation. Otherwise, ipv4.NewPacketConn would unwrap the file descriptor
// via SyscallConn(), and read it that way, which might not be what the caller wants.
var bc batchConn
if ibc, ok := c.(batchConn); ok {
bc = ibc
} else {
bc = ipv4.NewPacketConn(c)
}
msgs := make([]ipv4.Message, batchSize)
for i := range msgs {
// preallocate the [][]byte
msgs[i].Buffers = make([][]byte, 1)
}
oobConn := &oobConn{
OOBCapablePacketConn: c,
batchConn: bc,
messages: msgs,
readPos: batchSize,
cap: connCapabilities{
DF: supportsDF,
GSO: isGSOEnabled(rawConn),
ECN: isECNEnabled(),
},
}
for i := 0; i < batchSize; i++ {
oobConn.messages[i].OOB = make([]byte, oobBufferSize)
}
return oobConn, nil
}
var invalidCmsgOnceV4, invalidCmsgOnceV6 sync.Once
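// Illustrative note (not part of the original source): ReadPacket reads up to batchSize
// datagrams from the kernel with a single ReadBatch (recvmmsg) call and then hands them
// out one at a time. readPos tracks how many of the buffered messages have already been
// returned; a new batch is only requested once all of them have been consumed.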
func (c *oobConn) ReadPacket() (receivedPacket, error) {
if len(c.messages) == int(c.readPos) { // all messages read. Read the next batch of messages.
c.messages = c.messages[:batchSize]
// replace the data buffers of the messages that were consumed during the last ReadBatch call
for i := uint8(0); i < c.readPos; i++ {
buffer := getPacketBuffer()
buffer.Data = buffer.Data[:protocol.MaxPacketBufferSize]
c.buffers[i] = buffer
c.messages[i].Buffers[0] = c.buffers[i].Data
}
c.readPos = 0
n, err := c.batchConn.ReadBatch(c.messages, 0)
if n == 0 || err != nil {
return receivedPacket{}, err
}
c.messages = c.messages[:n]
}
msg := c.messages[c.readPos]
buffer := c.buffers[c.readPos]
c.readPos++
data := msg.OOB[:msg.NN]
p := receivedPacket{
remoteAddr: msg.Addr,
rcvTime: time.Now(),
data: msg.Buffers[0][:msg.N],
buffer: buffer,
}
for len(data) > 0 {
hdr, body, remainder, err := unix.ParseOneSocketControlMessage(data)
if err != nil {
return receivedPacket{}, err
}
if hdr.Level == unix.IPPROTO_IP {
switch hdr.Type {
case msgTypeIPTOS:
if len(body) != 1 {
return receivedPacket{}, errors.New("invalid IPTOS size")
}
p.ecn = protocol.ParseECNHeaderBits(body[0] & ecnMask)
case ipv4PKTINFO:
ip, ifIndex, ok := parseIPv4PktInfo(body)
if ok {
p.info.addr = ip
p.info.ifIndex = ifIndex
} else {
invalidCmsgOnceV4.Do(func() {
log.Printf("Received invalid IPv4 packet info control message: %+x. "+
"This should never occur, please open a new issue and include details about the architecture.", body)
})
}
}
}
if hdr.Level == unix.IPPROTO_IPV6 {
switch hdr.Type {
case unix.IPV6_TCLASS:
if len(body) != 4 {
return receivedPacket{}, errors.New("invalid IPV6_TCLASS size")
}
bits := uint8(binary.NativeEndian.Uint32(body)) & ecnMask
p.ecn = protocol.ParseECNHeaderBits(bits)
case unix.IPV6_PKTINFO:
// struct in6_pktinfo {
// struct in6_addr ipi6_addr; /* src/dst IPv6 address */
// unsigned int ipi6_ifindex; /* send/recv interface index */
// };
if len(body) == 20 {
p.info.addr = netip.AddrFrom16(*(*[16]byte)(body[:16])).Unmap()
p.info.ifIndex = binary.NativeEndian.Uint32(body[16:])
} else {
invalidCmsgOnceV6.Do(func() {
log.Printf("Received invalid IPv6 packet info control message: %+x. "+
"This should never occur, please open a new issue and include details about the architecture.", body)
})
}
}
}
data = remainder
}
return p, nil
}
// WritePacket writes a new packet.
func (c *oobConn) WritePacket(b []byte, addr net.Addr, packetInfoOOB []byte, gsoSize uint16, ecn protocol.ECN) (int, error) {
oob := packetInfoOOB
if gsoSize > 0 {
if !c.capabilities().GSO {
panic("GSO disabled")
}
oob = appendUDPSegmentSizeMsg(oob, gsoSize)
}
if ecn != protocol.ECNUnsupported {
if !c.capabilities().ECN {
panic("tried to send an ECN-marked packet although ECN is disabled")
}
if remoteUDPAddr, ok := addr.(*net.UDPAddr); ok {
if remoteUDPAddr.IP.To4() != nil {
oob = appendIPv4ECNMsg(oob, ecn)
} else {
oob = appendIPv6ECNMsg(oob, ecn)
}
}
}
n, _, err := c.WriteMsgUDP(b, oob, addr.(*net.UDPAddr))
return n, err
}
func (c *oobConn) capabilities() connCapabilities {
return c.cap
}
type packetInfo struct {
addr netip.Addr
ifIndex uint32
}
func (info *packetInfo) OOB() []byte {
if info == nil {
return nil
}
if info.addr.Is4() {
ip := info.addr.As4()
// struct in_pktinfo {
// unsigned int ipi_ifindex; /* Interface index */
// struct in_addr ipi_spec_dst; /* Local address */
// struct in_addr ipi_addr; /* Header Destination address */
// };
cm := ipv4.ControlMessage{
Src: ip[:],
IfIndex: int(info.ifIndex),
}
return cm.Marshal()
} else if info.addr.Is6() {
ip := info.addr.As16()
// struct in6_pktinfo {
// struct in6_addr ipi6_addr; /* src/dst IPv6 address */
// unsigned int ipi6_ifindex; /* send/recv interface index */
// };
cm := ipv6.ControlMessage{
Src: ip[:],
IfIndex: int(info.ifIndex),
}
return cm.Marshal()
}
return nil
}
func appendIPv4ECNMsg(b []byte, val protocol.ECN) []byte {
startLen := len(b)
b = append(b, make([]byte, unix.CmsgSpace(ecnIPv4DataLen))...)
h := (*unix.Cmsghdr)(unsafe.Pointer(&b[startLen]))
h.Level = syscall.IPPROTO_IP
h.Type = unix.IP_TOS
h.SetLen(unix.CmsgLen(ecnIPv4DataLen))
// UnixRights uses the private `data` method, but I *think* this achieves the same goal.
offset := startLen + unix.CmsgSpace(0)
b[offset] = val.ToHeaderBits()
return b
}
func appendIPv6ECNMsg(b []byte, val protocol.ECN) []byte {
startLen := len(b)
const dataLen = 4
b = append(b, make([]byte, unix.CmsgSpace(dataLen))...)
h := (*unix.Cmsghdr)(unsafe.Pointer(&b[startLen]))
h.Level = syscall.IPPROTO_IPV6
h.Type = unix.IPV6_TCLASS
h.SetLen(unix.CmsgLen(dataLen))
// UnixRights uses the private `data` method, but I *think* this achieves the same goal.
offset := startLen + unix.CmsgSpace(0)
binary.NativeEndian.PutUint32(b[offset:offset+dataLen], uint32(val.ToHeaderBits()))
return b
}
package quic
import (
"sync"
list "github.com/quic-go/quic-go/internal/utils/linkedlist"
)
type singleOriginTokenStore struct {
tokens []*ClientToken
len int
p int
}
func newSingleOriginTokenStore(size int) *singleOriginTokenStore {
return &singleOriginTokenStore{tokens: make([]*ClientToken, size)}
}
func (s *singleOriginTokenStore) Add(token *ClientToken) {
s.tokens[s.p] = token
s.p = s.index(s.p + 1)
s.len = min(s.len+1, len(s.tokens))
}
func (s *singleOriginTokenStore) Pop() *ClientToken {
s.p = s.index(s.p - 1)
token := s.tokens[s.p]
s.tokens[s.p] = nil
s.len = max(s.len-1, 0)
return token
}
func (s *singleOriginTokenStore) Len() int {
return s.len
}
func (s *singleOriginTokenStore) index(i int) int {
mod := len(s.tokens)
return (i + mod) % mod
}
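// Illustrative sketch (not part of the original source): singleOriginTokenStore behaves
// as a fixed-size LIFO ring buffer. The most recently added token is popped first, and
// once the store is full the oldest entry is overwritten.
func exampleSingleOriginTokenStore() {
	a, b, c := &ClientToken{}, &ClientToken{}, &ClientToken{}
	s := newSingleOriginTokenStore(2)
	s.Add(a)
	s.Add(b)
	s.Add(c)         // overwrites a, the oldest entry
	_ = s.Pop() == c // true: the most recently added token comes back first
	_ = s.Pop() == b // true
	_ = s.Len()      // 0
}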
type lruTokenStoreEntry struct {
key string
cache *singleOriginTokenStore
}
type lruTokenStore struct {
mutex sync.Mutex
m map[string]*list.Element[*lruTokenStoreEntry]
q *list.List[*lruTokenStoreEntry]
capacity int
singleOriginSize int
}
var _ TokenStore = &lruTokenStore{}
// NewLRUTokenStore creates a new LRU cache for tokens received by the client.
// maxOrigins specifies how many origins this cache is saving tokens for.
// tokensPerOrigin specifies the maximum number of tokens per origin.
func NewLRUTokenStore(maxOrigins, tokensPerOrigin int) TokenStore {
return &lruTokenStore{
m: make(map[string]*list.Element[*lruTokenStoreEntry]),
q: list.New[*lruTokenStoreEntry](),
capacity: maxOrigins,
singleOriginSize: tokensPerOrigin,
}
}
func (s *lruTokenStore) Put(key string, token *ClientToken) {
s.mutex.Lock()
defer s.mutex.Unlock()
if el, ok := s.m[key]; ok {
entry := el.Value
entry.cache.Add(token)
s.q.MoveToFront(el)
return
}
if s.q.Len() < s.capacity {
entry := &lruTokenStoreEntry{
key: key,
cache: newSingleOriginTokenStore(s.singleOriginSize),
}
entry.cache.Add(token)
s.m[key] = s.q.PushFront(entry)
return
}
elem := s.q.Back()
entry := elem.Value
delete(s.m, entry.key)
entry.key = key
entry.cache = newSingleOriginTokenStore(s.singleOriginSize)
entry.cache.Add(token)
s.q.MoveToFront(elem)
s.m[key] = elem
}
func (s *lruTokenStore) Pop(key string) *ClientToken {
s.mutex.Lock()
defer s.mutex.Unlock()
var token *ClientToken
if el, ok := s.m[key]; ok {
s.q.MoveToFront(el)
cache := el.Value.cache
token = cache.Pop()
if cache.Len() == 0 {
s.q.Remove(el)
delete(s.m, key)
}
}
return token
}
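// Illustrative sketch (not part of the original source): how a token store is typically
// plugged into a client configuration, so that tokens received via NEW_TOKEN frames are
// reused across connections to the same origin. The capacity values are assumptions.
//
//	conf := &quic.Config{
//		TokenStore: quic.NewLRUTokenStore(10, 4), // 10 origins, 4 tokens per origin
//	}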
package quic
import (
"context"
"crypto/rand"
"crypto/tls"
"errors"
"fmt"
"net"
"sync"
"sync/atomic"
"time"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/internal/wire"
"github.com/quic-go/quic-go/logging"
)
// ErrTransportClosed is returned by the [Transport]'s Listen or Dial method after the transport was closed.
var ErrTransportClosed = &errTransportClosed{}
type errTransportClosed struct {
err error
}
func (e *errTransportClosed) Unwrap() []error { return []error{net.ErrClosed, e.err} }
func (e *errTransportClosed) Error() string {
if e.err == nil {
return "quic: transport closed"
}
return fmt.Sprintf("quic: transport closed: %s", e.err)
}
func (e *errTransportClosed) Is(target error) bool {
_, ok := target.(*errTransportClosed)
return ok
}
var errListenerAlreadySet = errors.New("listener already set")
type closePacket struct {
payload []byte
addr net.Addr
info packetInfo
}
// The Transport is the central point to manage incoming and outgoing QUIC connections.
// QUIC demultiplexes connections based on their QUIC Connection IDs, not based on the 4-tuple.
// This means that a single UDP socket can be used for listening for incoming connections, as well as
// for dialing an arbitrary number of outgoing connections.
// A Transport handles a single net.PacketConn, and offers a range of configuration options
// compared to the simple helper functions like [Listen] and [Dial] that this package provides.
type Transport struct {
// A single net.PacketConn can only be handled by one Transport.
// Bad things will happen if passed to multiple Transports.
//
// A number of optimizations will be enabled if the connection implements the OOBCapablePacketConn interface,
// as a *net.UDPConn does.
// 1. It enables the Don't Fragment (DF) bit on the IP header.
// This is required to run DPLPMTUD (Path MTU Discovery, RFC 8899).
// 2. It enables reading of the ECN bits from the IP header.
// This allows the remote node to speed up its loss detection and recovery.
// 3. It uses batched syscalls (recvmmsg) to more efficiently receive packets from the socket.
// 4. It uses Generic Segmentation Offload (GSO) to efficiently send batches of packets (on Linux).
//
// After passing the connection to the Transport, it's invalid to call ReadFrom or WriteTo on the connection.
Conn net.PacketConn
// The length of the connection ID in bytes.
// It can be any value between 1 and 20.
// Due to the increased risk of collisions, it is not recommended to use connection IDs shorter than 4 bytes.
// If unset, a 4 byte connection ID will be used.
ConnectionIDLength int
// Used for generating new connection IDs.
// This allows the application to control the connection IDs used,
// which allows routing / load balancing based on connection IDs.
// All Connection IDs returned by the ConnectionIDGenerator MUST
// have the same length.
ConnectionIDGenerator ConnectionIDGenerator
// The StatelessResetKey is used to generate stateless reset tokens.
// If no key is configured, sending of stateless resets is disabled.
// It is highly recommended to configure a stateless reset key, as stateless resets
// allow the peer to quickly recover from crashes and reboots of this node.
// See section 10.3 of RFC 9000 for details.
StatelessResetKey *StatelessResetKey
// The TokenGeneratorKey is used to encrypt session resumption tokens.
// If no key is configured, a random key will be generated.
// If multiple servers are authoritative for the same domain, they should use the same key,
// see section 8.1.3 of RFC 9000 for details.
TokenGeneratorKey *TokenGeneratorKey
// MaxTokenAge is the maximum age of the resumption token presented during the handshake.
// These tokens allow skipping address validation when resuming a QUIC connection,
// and are especially useful when using 0-RTT.
// If not set, it defaults to 24 hours.
// See section 8.1.3 of RFC 9000 for details.
MaxTokenAge time.Duration
// DisableVersionNegotiationPackets disables the sending of Version Negotiation packets.
// This can be useful if version information is exchanged out-of-band.
// It has no effect for clients.
DisableVersionNegotiationPackets bool
// VerifySourceAddress decides if a connection attempt originating from unvalidated source
// addresses first needs to go through source address validation using QUIC's Retry mechanism,
// as described in RFC 9000 section 8.1.2.
// Note that the address passed to this callback is unvalidated, and might be spoofed in case
// of an attack.
// Validating the source address adds one additional network roundtrip to the handshake,
// and should therefore only be used if a suspiciously high number of incoming connections is observed.
// For most use cases, wrapping the Allow function of a rate.Limiter will be a reasonable
// implementation of this callback (negating its return value).
VerifySourceAddress func(net.Addr) bool
// ConnContext is called when the server accepts a new connection. To reject a connection return
// a non-nil error.
// The context is closed when the connection is closed, or when the handshake fails for any reason.
// The context returned from the callback is used to derive every other context used during the
// lifetime of the connection:
// * the context passed to crypto/tls (and used on the tls.ClientHelloInfo)
// * the context used in Config.Tracer
// * the context returned from Conn.Context
// * the context returned from SendStream.Context
// It is not used for dialed connections.
ConnContext func(context.Context, *ClientInfo) (context.Context, error)
// A Tracer traces events that don't belong to a single QUIC connection.
// Tracer.Close is called when the transport is closed.
Tracer *logging.Tracer
mutex sync.Mutex
handlers map[protocol.ConnectionID]packetHandler
resetTokens map[protocol.StatelessResetToken]packetHandler
initOnce sync.Once
initErr error
// If no ConnectionIDGenerator is set, this is the ConnectionIDLength.
connIDLen int
// Set in init.
// If no ConnectionIDGenerator is set, this is set to a default.
connIDGenerator ConnectionIDGenerator
statelessResetter *statelessResetter
server *baseServer
conn rawConn
closeQueue chan closePacket
statelessResetQueue chan receivedPacket
listening chan struct{} // is closed when listen returns
closeErr error
createdConn bool
isSingleUse bool // was created for a single server or client, i.e. by calling quic.Listen or quic.Dial
readingNonQUICPackets atomic.Bool
nonQUICPackets chan receivedPacket
logger utils.Logger
}
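// Illustrative sketch (not part of the original source): the VerifySourceAddress hook can
// be backed by a golang.org/x/time/rate limiter, as suggested in the field documentation
// above, so that Retry-based address validation only kicks in once the rate of incoming
// connection attempts becomes suspicious. The limiter parameters are assumptions.
//
//	limiter := rate.NewLimiter(100, 1000)
//	tr := &Transport{
//		Conn:                udpConn,
//		VerifySourceAddress: func(net.Addr) bool { return !limiter.Allow() },
//	}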
// Listen starts listening for incoming QUIC connections.
// There can only be a single listener on any net.PacketConn.
// Listen may only be called again after the current listener was closed.
func (t *Transport) Listen(tlsConf *tls.Config, conf *Config) (*Listener, error) {
s, err := t.createServer(tlsConf, conf, false)
if err != nil {
return nil, err
}
return &Listener{baseServer: s}, nil
}
// ListenEarly starts listening for incoming QUIC connections.
// There can only be a single listener on any net.PacketConn.
// ListenEarly may only be called again after the current listener was closed.
func (t *Transport) ListenEarly(tlsConf *tls.Config, conf *Config) (*EarlyListener, error) {
s, err := t.createServer(tlsConf, conf, true)
if err != nil {
return nil, err
}
return &EarlyListener{baseServer: s}, nil
}
func (t *Transport) createServer(tlsConf *tls.Config, conf *Config, allow0RTT bool) (*baseServer, error) {
if tlsConf == nil {
return nil, errors.New("quic: tls.Config not set")
}
if err := validateConfig(conf); err != nil {
return nil, err
}
t.mutex.Lock()
defer t.mutex.Unlock()
if t.closeErr != nil {
return nil, t.closeErr
}
if t.server != nil {
return nil, errListenerAlreadySet
}
conf = populateConfig(conf)
if err := t.init(false); err != nil {
return nil, err
}
maxTokenAge := t.MaxTokenAge
if maxTokenAge == 0 {
maxTokenAge = 24 * time.Hour
}
s := newServer(
t.conn,
(*packetHandlerMap)(t),
t.connIDGenerator,
t.statelessResetter,
t.ConnContext,
tlsConf,
conf,
t.Tracer,
t.closeServer,
*t.TokenGeneratorKey,
maxTokenAge,
t.VerifySourceAddress,
t.DisableVersionNegotiationPackets,
allow0RTT,
)
t.server = s
return s, nil
}
// Dial dials a new connection to a remote host (not using 0-RTT).
func (t *Transport) Dial(ctx context.Context, addr net.Addr, tlsConf *tls.Config, conf *Config) (*Conn, error) {
return t.dial(ctx, addr, "", tlsConf, conf, false)
}
// DialEarly dials a new connection, attempting to use 0-RTT if possible.
func (t *Transport) DialEarly(ctx context.Context, addr net.Addr, tlsConf *tls.Config, conf *Config) (*Conn, error) {
return t.dial(ctx, addr, "", tlsConf, conf, true)
}
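// Illustrative sketch (not part of the original source): a single Transport multiplexing
// one UDP socket for an incoming listener and an outgoing dial at the same time, as
// described in the Transport documentation. The tls.Config values and the remote address
// are assumptions; real code needs certificates, ALPN and proper error handling.
func exampleTransportUsage(ctx context.Context, serverTLS, clientTLS *tls.Config) error {
	udpConn, err := net.ListenUDP("udp", &net.UDPAddr{Port: 0})
	if err != nil {
		return err
	}
	tr := &Transport{Conn: udpConn}
	defer tr.Close()
	ln, err := tr.Listen(serverTLS, &Config{})
	if err != nil {
		return err
	}
	defer ln.Close()
	// The same socket can dial out concurrently; QUIC demultiplexes by connection ID.
	remote := &net.UDPAddr{IP: net.IPv4(192, 0, 2, 1), Port: 4433}
	conn, err := tr.Dial(ctx, remote, clientTLS, &Config{})
	if err != nil {
		return err
	}
	return conn.CloseWithError(0, "done")
}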
func (t *Transport) dial(ctx context.Context, addr net.Addr, host string, tlsConf *tls.Config, conf *Config, use0RTT bool) (*Conn, error) {
if err := t.init(t.isSingleUse); err != nil {
return nil, err
}
if err := validateConfig(conf); err != nil {
return nil, err
}
conf = populateConfig(conf)
tlsConf = tlsConf.Clone()
setTLSConfigServerName(tlsConf, addr, host)
return t.doDial(ctx,
newSendConn(t.conn, addr, packetInfo{}, utils.DefaultLogger),
tlsConf,
conf,
0,
false,
use0RTT,
conf.Versions[0],
)
}
func (t *Transport) doDial(
ctx context.Context,
sendConn sendConn,
tlsConf *tls.Config,
config *Config,
initialPacketNumber protocol.PacketNumber,
hasNegotiatedVersion bool,
use0RTT bool,
version protocol.Version,
) (*Conn, error) {
srcConnID, err := t.connIDGenerator.GenerateConnectionID()
if err != nil {
return nil, err
}
destConnID, err := generateConnectionIDForInitial()
if err != nil {
return nil, err
}
tracingID := nextConnTracingID()
ctx = context.WithValue(ctx, ConnectionTracingKey, tracingID)
t.mutex.Lock()
if t.closeErr != nil {
t.mutex.Unlock()
return nil, t.closeErr
}
var tracer *logging.ConnectionTracer
if config.Tracer != nil {
tracer = config.Tracer(ctx, protocol.PerspectiveClient, destConnID)
}
if tracer != nil && tracer.StartedConnection != nil {
tracer.StartedConnection(sendConn.LocalAddr(), sendConn.RemoteAddr(), srcConnID, destConnID)
}
logger := utils.DefaultLogger.WithPrefix("client")
logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", tlsConf.ServerName, sendConn.LocalAddr(), sendConn.RemoteAddr(), srcConnID, destConnID, version)
conn := newClientConnection(
context.WithoutCancel(ctx),
sendConn,
(*packetHandlerMap)(t),
destConnID,
srcConnID,
t.connIDGenerator,
t.statelessResetter,
config,
tlsConf,
initialPacketNumber,
use0RTT,
hasNegotiatedVersion,
tracer,
logger,
version,
)
t.handlers[srcConnID] = conn
t.mutex.Unlock()
// The error channel needs to be buffered, as the run loop will continue running
// after doDial returns (if the handshake is successful).
// Similarly, the recreateChan needs to be buffered, in case a different case is selected.
errChan := make(chan error, 1)
recreateChan := make(chan errCloseForRecreating, 1)
go func() {
err := conn.run()
var recreateErr *errCloseForRecreating
if errors.As(err, &recreateErr) {
recreateChan <- *recreateErr
return
}
if t.isSingleUse {
t.Close()
}
errChan <- err
}()
// Only set when we're using 0-RTT.
// Otherwise, earlyConnChan will be nil. Receiving from a nil chan blocks forever.
var earlyConnChan <-chan struct{}
if use0RTT {
earlyConnChan = conn.earlyConnReady()
}
select {
case <-ctx.Done():
conn.destroy(nil)
// wait until the Go routine that called Conn.run() returns
select {
case <-errChan:
case <-recreateChan:
}
return nil, context.Cause(ctx)
case params := <-recreateChan:
return t.doDial(ctx,
sendConn,
tlsConf,
config,
params.nextPacketNumber,
true,
use0RTT,
params.nextVersion,
)
case err := <-errChan:
return nil, err
case <-earlyConnChan:
// ready to send 0-RTT data
return conn.Conn, nil
case <-conn.HandshakeComplete():
// handshake successfully completed
return conn.Conn, nil
}
}
func (t *Transport) init(allowZeroLengthConnIDs bool) error {
t.initOnce.Do(func() {
var conn rawConn
if c, ok := t.Conn.(rawConn); ok {
conn = c
} else {
var err error
conn, err = wrapConn(t.Conn)
if err != nil {
t.initErr = err
return
}
}
t.logger = utils.DefaultLogger // TODO: make this configurable
t.conn = conn
t.handlers = make(map[protocol.ConnectionID]packetHandler)
t.resetTokens = make(map[protocol.StatelessResetToken]packetHandler)
t.listening = make(chan struct{})
t.closeQueue = make(chan closePacket, 4)
t.statelessResetQueue = make(chan receivedPacket, 4)
if t.TokenGeneratorKey == nil {
var key TokenGeneratorKey
if _, err := rand.Read(key[:]); err != nil {
t.initErr = err
return
}
t.TokenGeneratorKey = &key
}
if t.ConnectionIDGenerator != nil {
t.connIDGenerator = t.ConnectionIDGenerator
t.connIDLen = t.ConnectionIDGenerator.ConnectionIDLen()
} else {
connIDLen := t.ConnectionIDLength
if t.ConnectionIDLength == 0 && !allowZeroLengthConnIDs {
connIDLen = protocol.DefaultConnectionIDLength
}
t.connIDLen = connIDLen
t.connIDGenerator = &protocol.DefaultConnectionIDGenerator{ConnLen: t.connIDLen}
}
t.statelessResetter = newStatelessResetter(t.StatelessResetKey)
go func() {
defer close(t.listening)
t.listen(conn)
if t.createdConn {
conn.Close()
}
}()
go t.runSendQueue()
})
return t.initErr
}
// WriteTo sends a packet on the underlying connection.
func (t *Transport) WriteTo(b []byte, addr net.Addr) (int, error) {
if err := t.init(false); err != nil {
return 0, err
}
return t.conn.WritePacket(b, addr, nil, 0, protocol.ECNUnsupported)
}
func (t *Transport) runSendQueue() {
for {
select {
case <-t.listening:
return
case p := <-t.closeQueue:
t.conn.WritePacket(p.payload, p.addr, p.info.OOB(), 0, protocol.ECNUnsupported)
case p := <-t.statelessResetQueue:
t.sendStatelessReset(p)
}
}
}
// Close stops listening for UDP datagrams on the Transport.Conn.
// It abruptly terminates all existing connections, without sending a CONNECTION_CLOSE
// to the peers. It is the application's responsibility to cleanly terminate existing
// connections prior to calling Close.
//
// If a server was started, it will be closed as well.
// It is not possible to start any new server or dial new connections after that.
func (t *Transport) Close() error {
// avoid race condition if the transport is currently being initialized
t.init(false)
t.close(nil)
if t.createdConn {
if err := t.Conn.Close(); err != nil {
return err
}
} else if t.conn != nil {
t.conn.SetReadDeadline(time.Now())
defer func() { t.conn.SetReadDeadline(time.Time{}) }()
}
if t.listening != nil {
<-t.listening // wait until listening returns
}
return nil
}
func (t *Transport) closeServer() {
t.mutex.Lock()
defer t.mutex.Unlock()
t.server = nil
if t.isSingleUse {
t.closeErr = ErrServerClosed
}
if len(t.handlers) == 0 {
t.maybeStopListening()
}
}
func (t *Transport) close(e error) {
t.mutex.Lock()
if t.closeErr != nil {
t.mutex.Unlock()
return
}
e = &errTransportClosed{err: e}
t.closeErr = e
server := t.server
t.server = nil
if server != nil {
t.mutex.Unlock()
server.close(e, true)
t.mutex.Lock()
}
// Close existing connections
var wg sync.WaitGroup
for _, handler := range t.handlers {
wg.Add(1)
go func(handler packetHandler) {
handler.destroy(e)
wg.Done()
}(handler)
}
t.mutex.Unlock() // closing connections requires releasing transport mutex
wg.Wait()
if t.Tracer != nil && t.Tracer.Close != nil {
t.Tracer.Close()
}
}
// only print warnings about the UDP receive buffer size once
var setBufferWarningOnce sync.Once
func (t *Transport) listen(conn rawConn) {
for {
p, err := conn.ReadPacket()
//nolint:staticcheck // SA1019 ignore this!
// TODO: This code is used to ignore WSA errors on Windows.
// Since net.Error.Temporary is deprecated as of Go 1.18, we should find a better solution.
// See https://github.com/quic-go/quic-go/issues/1737 for details.
if nerr, ok := err.(net.Error); ok && nerr.Temporary() {
t.mutex.Lock()
closed := t.closeErr != nil
t.mutex.Unlock()
if closed {
return
}
t.logger.Debugf("Temporary error reading from conn: %w", err)
continue
}
if err != nil {
// Windows returns an error when receiving a UDP datagram that doesn't fit into the provided buffer.
if isRecvMsgSizeErr(err) {
continue
}
t.close(err)
return
}
t.handlePacket(p)
}
}
func (t *Transport) maybeStopListening() {
if t.isSingleUse && t.closeErr != nil {
t.conn.SetReadDeadline(time.Now())
}
}
func (t *Transport) handlePacket(p receivedPacket) {
if len(p.data) == 0 {
return
}
if !wire.IsPotentialQUICPacket(p.data[0]) && !wire.IsLongHeaderPacket(p.data[0]) {
t.handleNonQUICPacket(p)
return
}
connID, err := wire.ParseConnectionID(p.data, t.connIDLen)
if err != nil {
t.logger.Debugf("error parsing connection ID on packet from %s: %s", p.remoteAddr, err)
if t.Tracer != nil && t.Tracer.DroppedPacket != nil {
t.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError)
}
p.buffer.MaybeRelease()
return
}
// If there's a connection associated with the connection ID, pass the packet there.
if handler, ok := (*packetHandlerMap)(t).Get(connID); ok {
handler.handlePacket(p)
return
}
// RFC 9000 section 10.3.1 requires that the stateless reset detection logic is run for both
// packets that cannot be associated with any connections, and for packets that can't be decrypted.
// We deviate from the RFC and ignore the latter: If a packet's connection ID is associated with an
// existing connection, it is dropped there if it can't be decrypted.
// Stateless resets use random connection IDs, and at reasonable connection ID lengths collisions are
// exceedingly rare. In the unlikely event that a stateless reset is misrouted to an existing connection,
// it is to be expected that the next stateless reset will be correctly detected.
if isStatelessReset := t.maybeHandleStatelessReset(p.data); isStatelessReset {
return
}
if !wire.IsLongHeaderPacket(p.data[0]) {
if statelessResetQueued := t.maybeSendStatelessReset(p); !statelessResetQueued {
if t.Tracer != nil && t.Tracer.DroppedPacket != nil {
t.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnknownConnectionID)
}
p.buffer.Release()
}
return
}
t.mutex.Lock()
defer t.mutex.Unlock()
if t.server == nil { // no server set
t.logger.Debugf("received a packet with an unexpected connection ID %s", connID)
if t.Tracer != nil && t.Tracer.DroppedPacket != nil {
t.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnknownConnectionID)
}
p.buffer.MaybeRelease()
return
}
t.server.handlePacket(p)
}
func (t *Transport) maybeSendStatelessReset(p receivedPacket) (statelessResetQueued bool) {
if t.StatelessResetKey == nil {
return false
}
// Don't send a stateless reset in response to very small packets.
// This includes packets that could be stateless resets.
if len(p.data) <= protocol.MinStatelessResetSize {
return false
}
select {
case t.statelessResetQueue <- p:
return true
default:
// it's fine to not send a stateless reset when we're busy
return false
}
}
func (t *Transport) sendStatelessReset(p receivedPacket) {
defer p.buffer.Release()
connID, err := wire.ParseConnectionID(p.data, t.connIDLen)
if err != nil {
t.logger.Errorf("error parsing connection ID on packet from %s: %s", p.remoteAddr, err)
return
}
token := t.statelessResetter.GetStatelessResetToken(connID)
t.logger.Debugf("Sending stateless reset to %s (connection ID: %s). Token: %#x", p.remoteAddr, connID, token)
data := make([]byte, protocol.MinStatelessResetSize-16, protocol.MinStatelessResetSize)
rand.Read(data)
data[0] = (data[0] & 0x7f) | 0x40
data = append(data, token[:]...)
if _, err := t.conn.WritePacket(data, p.remoteAddr, p.info.OOB(), 0, protocol.ECNUnsupported); err != nil {
t.logger.Debugf("Error sending Stateless Reset to %s: %s", p.remoteAddr, err)
}
}
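// Illustrative sketch (not part of the original source): stateless resets are only sent
// if a StatelessResetKey is configured on the Transport. Deriving the key from stable
// secret material (rather than generating a fresh random key per process) keeps the
// tokens valid across restarts. The key source shown here is an assumption.
//
//	var key quic.StatelessResetKey
//	copy(key[:], stableSecret) // 32 bytes of secret material, e.g. from a KMS
//	tr := &quic.Transport{Conn: udpConn, StatelessResetKey: &key}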
func (t *Transport) maybeHandleStatelessReset(data []byte) bool {
// stateless resets are always short header packets
if wire.IsLongHeaderPacket(data[0]) {
return false
}
if len(data) < 17 /* type byte + 16 bytes for the reset token */ {
return false
}
token := protocol.StatelessResetToken(data[len(data)-16:])
t.mutex.Lock()
conn, ok := t.resetTokens[token]
t.mutex.Unlock()
if ok {
t.logger.Debugf("Received a stateless reset with token %#x. Closing connection.", token)
go conn.destroy(&StatelessResetError{})
return true
}
return false
}
func (t *Transport) handleNonQUICPacket(p receivedPacket) {
// Strictly speaking, this is racy,
// but we only care about receiving packets at some point after ReadNonQUICPacket has been called.
if !t.readingNonQUICPackets.Load() {
return
}
select {
case t.nonQUICPackets <- p:
default:
if t.Tracer != nil && t.Tracer.DroppedPacket != nil {
t.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropDOSPrevention)
}
}
}
const maxQueuedNonQUICPackets = 32
// ReadNonQUICPacket reads non-QUIC packets received on the underlying connection.
// The detection logic is very simple: any packet whose first byte has the two most significant bits set to 0 is treated as a non-QUIC packet.
// Note that this is stricter than the detection logic defined in RFC 9443.
func (t *Transport) ReadNonQUICPacket(ctx context.Context, b []byte) (int, net.Addr, error) {
if err := t.init(false); err != nil {
return 0, nil, err
}
if !t.readingNonQUICPackets.Load() {
t.nonQUICPackets = make(chan receivedPacket, maxQueuedNonQUICPackets)
t.readingNonQUICPackets.Store(true)
}
select {
case <-ctx.Done():
return 0, nil, ctx.Err()
case p := <-t.nonQUICPackets:
n := copy(b, p.data)
return n, p.remoteAddr, nil
case <-t.listening:
return 0, nil, errors.New("closed")
}
}
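// Illustrative sketch (not part of the original source): draining non-QUIC datagrams
// (for example STUN traffic multiplexed on the same socket, cf. RFC 9443) alongside the
// QUIC listener. The buffer size is an assumption; a datagram larger than the buffer is
// truncated by the copy inside ReadNonQUICPacket.
func exampleDrainNonQUICPackets(ctx context.Context, tr *Transport) {
	buf := make([]byte, 1500)
	for {
		n, addr, err := tr.ReadNonQUICPacket(ctx, buf)
		if err != nil {
			return
		}
		utils.DefaultLogger.Debugf("received %d non-QUIC bytes from %s", n, addr)
	}
}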
func setTLSConfigServerName(tlsConf *tls.Config, addr net.Addr, host string) {
// If no ServerName is set, infer the ServerName from the host we're connecting to.
if tlsConf.ServerName != "" {
return
}
if host == "" {
if udpAddr, ok := addr.(*net.UDPAddr); ok {
tlsConf.ServerName = udpAddr.IP.String()
return
}
}
h, _, err := net.SplitHostPort(host)
if err != nil { // This happens if the host doesn't contain a port number.
tlsConf.ServerName = host
return
}
tlsConf.ServerName = h
}
type packetHandlerMap Transport
var _ connRunner = &packetHandlerMap{}
func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler) bool /* was added */ {
h.mutex.Lock()
defer h.mutex.Unlock()
if _, ok := h.handlers[id]; ok {
h.logger.Debugf("Not adding connection ID %s, as it already exists.", id)
return false
}
h.handlers[id] = handler
h.logger.Debugf("Adding connection ID %s.", id)
return true
}
func (h *packetHandlerMap) Get(connID protocol.ConnectionID) (packetHandler, bool) {
h.mutex.Lock()
defer h.mutex.Unlock()
handler, ok := h.handlers[connID]
return handler, ok
}
func (h *packetHandlerMap) AddResetToken(token protocol.StatelessResetToken, handler packetHandler) {
h.mutex.Lock()
h.resetTokens[token] = handler
h.mutex.Unlock()
}
func (h *packetHandlerMap) RemoveResetToken(token protocol.StatelessResetToken) {
h.mutex.Lock()
delete(h.resetTokens, token)
h.mutex.Unlock()
}
func (h *packetHandlerMap) AddWithConnID(clientDestConnID, newConnID protocol.ConnectionID, handler packetHandler) bool {
h.mutex.Lock()
defer h.mutex.Unlock()
if _, ok := h.handlers[clientDestConnID]; ok {
h.logger.Debugf("Not adding connection ID %s for a new connection, as it already exists.", clientDestConnID)
return false
}
h.handlers[clientDestConnID] = handler
h.handlers[newConnID] = handler
h.logger.Debugf("Adding connection IDs %s and %s for a new connection.", clientDestConnID, newConnID)
return true
}
func (h *packetHandlerMap) Remove(id protocol.ConnectionID) {
h.mutex.Lock()
delete(h.handlers, id)
h.mutex.Unlock()
h.logger.Debugf("Removing connection ID %s.", id)
}
// ReplaceWithClosed is called when a connection is closed.
// Depending on which side closed the connection, we need to:
// * remote close: absorb delayed packets
// * local close: retransmit the CONNECTION_CLOSE packet, in case it was lost
func (h *packetHandlerMap) ReplaceWithClosed(ids []protocol.ConnectionID, connClosePacket []byte, expiry time.Duration) {
var handler packetHandler
if connClosePacket != nil {
handler = newClosedLocalConn(
func(addr net.Addr, info packetInfo) {
select {
case h.closeQueue <- closePacket{payload: connClosePacket, addr: addr, info: info}:
default:
// We're backlogged.
// Just drop the packet, sending CONNECTION_CLOSE copies is best effort anyway.
}
},
h.logger,
)
} else {
handler = newClosedRemoteConn()
}
h.mutex.Lock()
for _, id := range ids {
h.handlers[id] = handler
}
h.mutex.Unlock()
h.logger.Debugf("Replacing connection for connection IDs %s with a closed connection.", ids)
time.AfterFunc(expiry, func() {
h.mutex.Lock()
for _, id := range ids {
delete(h.handlers, id)
}
if len(h.handlers) == 0 {
t := (*Transport)(h)
t.maybeStopListening()
}
h.mutex.Unlock()
h.logger.Debugf("Removing connection IDs %s for a closed connection after it has been retired.", ids)
})
}