// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package tar implements access to tar archives.
//
// Tape archives (tar) are a file format for storing a sequence of files that
// can be read and written in a streaming manner.
// This package aims to cover most variations of the format,
// including those produced by GNU and BSD tar tools.
package tar
import (
"errors"
"fmt"
"internal/godebug"
"io/fs"
"maps"
"math"
"path"
"reflect"
"strconv"
"strings"
"time"
)
// BUG: Use of the Uid and Gid fields in Header could overflow on 32-bit
// architectures. If a large value is encountered when decoding, the result
// stored in Header will be the truncated version.
var tarinsecurepath = godebug.New("tarinsecurepath")
var (
ErrHeader = errors.New("archive/tar: invalid tar header")
ErrWriteTooLong = errors.New("archive/tar: write too long")
ErrFieldTooLong = errors.New("archive/tar: header field too long")
ErrWriteAfterClose = errors.New("archive/tar: write after close")
ErrInsecurePath = errors.New("archive/tar: insecure file path")
errMissData = errors.New("archive/tar: sparse file references non-existent data")
errUnrefData = errors.New("archive/tar: sparse file contains unreferenced data")
errWriteHole = errors.New("archive/tar: write non-NUL byte in sparse hole")
)
type headerError []string
func (he headerError) Error() string {
const prefix = "archive/tar: cannot encode header"
var ss []string
for _, s := range he {
if s != "" {
ss = append(ss, s)
}
}
if len(ss) == 0 {
return prefix
}
return fmt.Sprintf("%s: %v", prefix, strings.Join(ss, "; and "))
}
// Type flags for Header.Typeflag.
const (
// Type '0' indicates a regular file.
TypeReg = '0'
// Deprecated: Use TypeReg instead.
TypeRegA = '\x00'
// Types '1' through '6' are header-only flags and may not have a data body.
TypeLink = '1' // Hard link
TypeSymlink = '2' // Symbolic link
TypeChar = '3' // Character device node
TypeBlock = '4' // Block device node
TypeDir = '5' // Directory
TypeFifo = '6' // FIFO node
// Type '7' is reserved.
TypeCont = '7'
// Type 'x' is used by the PAX format to store key-value records that
// are only relevant to the next file.
// This package transparently handles these types.
TypeXHeader = 'x'
// Type 'g' is used by the PAX format to store key-value records that
// are relevant to all subsequent files.
// This package only supports parsing and composing such headers,
// but does not currently support persisting the global state across files.
TypeXGlobalHeader = 'g'
// Type 'S' indicates a sparse file in the GNU format.
TypeGNUSparse = 'S'
// Types 'L' and 'K' are used by the GNU format for a meta file
// used to store the path or link name for the next file.
// This package transparently handles these types.
TypeGNULongName = 'L'
TypeGNULongLink = 'K'
)
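// As an illustrative sketch (caller-side, not part of this package; names and
// modes are made up), a writer might set Typeflag explicitly for non-regular
// entries:
//
//	dir := &tar.Header{Name: "pkg/", Typeflag: tar.TypeDir, Mode: 0o755}
//	link := &tar.Header{Name: "pkg/latest", Typeflag: tar.TypeSymlink, Linkname: "v1", Mode: 0o777}
//	fifo := &tar.Header{Name: "pkg/pipe", Typeflag: tar.TypeFifo, Mode: 0o644}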
// Keywords for PAX extended header records.
const (
paxNone = "" // Indicates that no PAX key is suitable
paxPath = "path"
paxLinkpath = "linkpath"
paxSize = "size"
paxUid = "uid"
paxGid = "gid"
paxUname = "uname"
paxGname = "gname"
paxMtime = "mtime"
paxAtime = "atime"
paxCtime = "ctime" // Removed from later revision of PAX spec, but was valid
paxCharset = "charset" // Currently unused
paxComment = "comment" // Currently unused
paxSchilyXattr = "SCHILY.xattr."
// Keywords for GNU sparse files in a PAX extended header.
paxGNUSparse = "GNU.sparse."
paxGNUSparseNumBlocks = "GNU.sparse.numblocks"
paxGNUSparseOffset = "GNU.sparse.offset"
paxGNUSparseNumBytes = "GNU.sparse.numbytes"
paxGNUSparseMap = "GNU.sparse.map"
paxGNUSparseName = "GNU.sparse.name"
paxGNUSparseMajor = "GNU.sparse.major"
paxGNUSparseMinor = "GNU.sparse.minor"
paxGNUSparseSize = "GNU.sparse.size"
paxGNUSparseRealSize = "GNU.sparse.realsize"
)
// basicKeys is a set of the PAX keys for which we have built-in support.
// This does not contain "charset" or "comment", which are both PAX-specific,
// so adding them as first-class features of Header is unlikely.
// Users can set them via the PAXRecords field.
var basicKeys = map[string]bool{
paxPath: true, paxLinkpath: true, paxSize: true, paxUid: true, paxGid: true,
paxUname: true, paxGname: true, paxMtime: true, paxAtime: true, paxCtime: true,
}
// A Header represents a single header in a tar archive.
// Some fields may not be populated.
//
// For forward compatibility, users that retrieve a Header from Reader.Next,
// mutate it in some ways, and then pass it back to Writer.WriteHeader
// should do so by creating a new Header and copying the fields
// that they are interested in preserving.
type Header struct {
// Typeflag is the type of header entry.
// The zero value is automatically promoted to either TypeReg or TypeDir
// depending on the presence of a trailing slash in Name.
Typeflag byte
Name string // Name of file entry
Linkname string // Target name of link (valid for TypeLink or TypeSymlink)
Size int64 // Logical file size in bytes
Mode int64 // Permission and mode bits
Uid int // User ID of owner
Gid int // Group ID of owner
Uname string // User name of owner
Gname string // Group name of owner
// If the Format is unspecified, then Writer.WriteHeader rounds ModTime
// to the nearest second and ignores the AccessTime and ChangeTime fields.
//
// To use AccessTime or ChangeTime, specify the Format as PAX or GNU.
// To use sub-second resolution, specify the Format as PAX.
ModTime time.Time // Modification time
AccessTime time.Time // Access time (requires either PAX or GNU support)
ChangeTime time.Time // Change time (requires either PAX or GNU support)
Devmajor int64 // Major device number (valid for TypeChar or TypeBlock)
Devminor int64 // Minor device number (valid for TypeChar or TypeBlock)
// Xattrs stores extended attributes as PAX records under the
// "SCHILY.xattr." namespace.
//
// The following are semantically equivalent:
// h.Xattrs[key] = value
// h.PAXRecords["SCHILY.xattr."+key] = value
//
// When Writer.WriteHeader is called, the contents of Xattrs will take
// precedence over those in PAXRecords.
//
// Deprecated: Use PAXRecords instead.
Xattrs map[string]string
// PAXRecords is a map of PAX extended header records.
//
// User-defined records should have keys of the following form:
// VENDOR.keyword
// Where VENDOR is some namespace in all uppercase, and keyword may
// not contain the '=' character (e.g., "GOLANG.pkg.version").
// The key and value should be non-empty UTF-8 strings.
//
// When Writer.WriteHeader is called, PAX records derived from the
// other fields in Header take precedence over PAXRecords.
PAXRecords map[string]string
// Format specifies the format of the tar header.
//
// This is set by Reader.Next as a best-effort guess at the format.
// Since the Reader liberally reads some non-compliant files,
// it is possible for this to be FormatUnknown.
//
// If the format is unspecified when Writer.WriteHeader is called,
// then it uses the first format (in the order of USTAR, PAX, GNU)
// capable of encoding this Header (see Format).
Format Format
}
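// As an illustrative sketch (caller-side), a user-defined PAX record can be
// attached before writing; the "GOLANG.pkg.version" keyword below is the
// hypothetical example from the PAXRecords comment above:
//
//	var buf bytes.Buffer
//	tw := tar.NewWriter(&buf)
//	body := []byte("hello\n")
//	hdr := &tar.Header{
//		Name:       "module/info.txt",
//		Mode:       0o644,
//		Size:       int64(len(body)),
//		ModTime:    time.Now(),
//		PAXRecords: map[string]string{"GOLANG.pkg.version": "v1.2.3"},
//	}
//	// A non-empty PAXRecords map restricts the output to the PAX format.
//	if err := tw.WriteHeader(hdr); err != nil {
//		log.Fatal(err)
//	}
//	if _, err := tw.Write(body); err != nil {
//		log.Fatal(err)
//	}
//	if err := tw.Close(); err != nil {
//		log.Fatal(err)
//	}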
// sparseEntry represents a Length-sized fragment at Offset in the file.
type sparseEntry struct{ Offset, Length int64 }
func (s sparseEntry) endOffset() int64 { return s.Offset + s.Length }
// A sparse file can be represented as either a sparseDatas or a sparseHoles.
// As long as the total size is known, they are equivalent and one can be
// converted to the other form and back. The various tar formats with sparse
// file support represent sparse files in the sparseDatas form. That is, they
// specify the fragments in the file that have data, and treat everything else as
// having zero bytes. As such, the encoding and decoding logic in this package
// deals with sparseDatas.
//
// However, the external API uses sparseHoles instead of sparseDatas because the
// zero value of sparseHoles logically represents a normal file (i.e., there are
// no holes in it). On the other hand, the zero value of sparseDatas implies
// that the file has no data in it, which is rather odd.
//
// As an example, if the underlying raw file contains the 8-byte data:
//
// var compactFile = "abcdefgh"
//
// And the sparse map has the following entries:
//
// var spd sparseDatas = []sparseEntry{
// {Offset: 2, Length: 5}, // Data fragment for 2..6
// {Offset: 18, Length: 3}, // Data fragment for 18..20
// }
// var sph sparseHoles = []sparseEntry{
// {Offset: 0, Length: 2}, // Hole fragment for 0..1
// {Offset: 7, Length: 11}, // Hole fragment for 7..17
// {Offset: 21, Length: 4}, // Hole fragment for 21..24
// }
//
// Then the content of the resulting sparse file with a Header.Size of 25 is:
//
// var sparseFile = "\x00"*2 + "abcde" + "\x00"*11 + "fgh" + "\x00"*4
type (
sparseDatas []sparseEntry
sparseHoles []sparseEntry
)
// validateSparseEntries reports whether sp is a valid sparse map.
// It does not matter whether sp represents data fragments or hole fragments.
func validateSparseEntries(sp []sparseEntry, size int64) bool {
// Validate all sparse entries. These are the same checks as performed by
// the BSD tar utility.
if size < 0 {
return false
}
var pre sparseEntry
for _, cur := range sp {
switch {
case cur.Offset < 0 || cur.Length < 0:
return false // Negative values are never okay
case cur.Offset > math.MaxInt64-cur.Length:
return false // Integer overflow with large length
case cur.endOffset() > size:
return false // Region extends beyond the actual size
case pre.endOffset() > cur.Offset:
return false // Regions cannot overlap and must be in order
}
pre = cur
}
return true
}
// alignSparseEntries mutates src and returns dst where each fragment's
// starting offset is aligned up to the nearest block edge, and each
// ending offset is aligned down to the nearest block edge.
//
// Even though the Go tar Reader and the BSD tar utility can handle entries
// with arbitrary offsets and lengths, the GNU tar utility can only handle
// offsets and lengths that are multiples of blockSize.
func alignSparseEntries(src []sparseEntry, size int64) []sparseEntry {
dst := src[:0]
for _, s := range src {
pos, end := s.Offset, s.endOffset()
pos += blockPadding(+pos) // Round-up to nearest blockSize
if end != size {
end -= blockPadding(-end) // Round-down to nearest blockSize
}
if pos < end {
dst = append(dst, sparseEntry{Offset: pos, Length: end - pos})
}
}
return dst
}
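// As an illustrative sketch (values chosen for the example), a fragment that
// does not start and end on block edges shrinks to the blocks it fully covers:
//
//	src := []sparseEntry{{Offset: 100, Length: 1000}} // covers bytes 100..1099
//	dst := alignSparseEntries(src, 2048)
//	// dst is []sparseEntry{{Offset: 512, Length: 512}}, covering bytes 512..1023.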
// invertSparseEntries converts a sparse map from one form to the other.
// If the input is sparseHoles, then it will output sparseDatas and vice-versa.
// The input must have been already validated.
//
// This function mutates src and returns a normalized map where:
// - adjacent fragments are coalesced together
// - only the last fragment may be empty
// - the endOffset of the last fragment is the total size
func invertSparseEntries(src []sparseEntry, size int64) []sparseEntry {
dst := src[:0]
var pre sparseEntry
for _, cur := range src {
if cur.Length == 0 {
continue // Skip empty fragments
}
pre.Length = cur.Offset - pre.Offset
if pre.Length > 0 {
dst = append(dst, pre) // Only add non-empty fragments
}
pre.Offset = cur.endOffset()
}
pre.Length = size - pre.Offset // Possibly the only empty fragment
return append(dst, pre)
}
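// As an illustrative sketch, inverting the sparseDatas example documented
// above (with a total size of 25) yields the corresponding sparseHoles:
//
//	spd := []sparseEntry{{Offset: 2, Length: 5}, {Offset: 18, Length: 3}}
//	sph := invertSparseEntries(spd, 25)
//	// sph is []sparseEntry{{Offset: 0, Length: 2}, {Offset: 7, Length: 11}, {Offset: 21, Length: 4}}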
// fileState tracks the number of logical (includes sparse holes) and physical
// (actual in tar archive) bytes remaining for the current file.
//
// Invariant: logicalRemaining >= physicalRemaining
type fileState interface {
logicalRemaining() int64
physicalRemaining() int64
}
// allowedFormats determines which formats can be used.
// The value returned is the logical OR of multiple possible formats.
// If the value is FormatUnknown, then the input Header cannot be encoded
// and an error is returned explaining why.
//
// As a by-product of checking the fields, this function returns paxHdrs, which
// contains all fields that could not be directly encoded.
// A value receiver ensures that this method does not mutate the source Header.
func (h Header) allowedFormats() (format Format, paxHdrs map[string]string, err error) {
format = FormatUSTAR | FormatPAX | FormatGNU
paxHdrs = make(map[string]string)
var whyNoUSTAR, whyNoPAX, whyNoGNU string
var preferPAX bool // Prefer PAX over USTAR
verifyString := func(s string, size int, name, paxKey string) {
// NUL-terminator is optional for path and linkpath.
// Technically, it is required for uname and gname,
// but neither GNU nor BSD tar checks for it.
tooLong := len(s) > size
allowLongGNU := paxKey == paxPath || paxKey == paxLinkpath
if hasNUL(s) || (tooLong && !allowLongGNU) {
whyNoGNU = fmt.Sprintf("GNU cannot encode %s=%q", name, s)
format.mustNotBe(FormatGNU)
}
if !isASCII(s) || tooLong {
canSplitUSTAR := paxKey == paxPath
if _, _, ok := splitUSTARPath(s); !canSplitUSTAR || !ok {
whyNoUSTAR = fmt.Sprintf("USTAR cannot encode %s=%q", name, s)
format.mustNotBe(FormatUSTAR)
}
if paxKey == paxNone {
whyNoPAX = fmt.Sprintf("PAX cannot encode %s=%q", name, s)
format.mustNotBe(FormatPAX)
} else {
paxHdrs[paxKey] = s
}
}
if v, ok := h.PAXRecords[paxKey]; ok && v == s {
paxHdrs[paxKey] = v
}
}
verifyNumeric := func(n int64, size int, name, paxKey string) {
if !fitsInBase256(size, n) {
whyNoGNU = fmt.Sprintf("GNU cannot encode %s=%d", name, n)
format.mustNotBe(FormatGNU)
}
if !fitsInOctal(size, n) {
whyNoUSTAR = fmt.Sprintf("USTAR cannot encode %s=%d", name, n)
format.mustNotBe(FormatUSTAR)
if paxKey == paxNone {
whyNoPAX = fmt.Sprintf("PAX cannot encode %s=%d", name, n)
format.mustNotBe(FormatPAX)
} else {
paxHdrs[paxKey] = strconv.FormatInt(n, 10)
}
}
if v, ok := h.PAXRecords[paxKey]; ok && v == strconv.FormatInt(n, 10) {
paxHdrs[paxKey] = v
}
}
verifyTime := func(ts time.Time, size int, name, paxKey string) {
if ts.IsZero() {
return // Always okay
}
if !fitsInBase256(size, ts.Unix()) {
whyNoGNU = fmt.Sprintf("GNU cannot encode %s=%v", name, ts)
format.mustNotBe(FormatGNU)
}
isMtime := paxKey == paxMtime
fitsOctal := fitsInOctal(size, ts.Unix())
if (isMtime && !fitsOctal) || !isMtime {
whyNoUSTAR = fmt.Sprintf("USTAR cannot encode %s=%v", name, ts)
format.mustNotBe(FormatUSTAR)
}
needsNano := ts.Nanosecond() != 0
if !isMtime || !fitsOctal || needsNano {
preferPAX = true // USTAR may truncate sub-second measurements
if paxKey == paxNone {
whyNoPAX = fmt.Sprintf("PAX cannot encode %s=%v", name, ts)
format.mustNotBe(FormatPAX)
} else {
paxHdrs[paxKey] = formatPAXTime(ts)
}
}
if v, ok := h.PAXRecords[paxKey]; ok && v == formatPAXTime(ts) {
paxHdrs[paxKey] = v
}
}
// Check basic fields.
var blk block
v7 := blk.toV7()
ustar := blk.toUSTAR()
gnu := blk.toGNU()
verifyString(h.Name, len(v7.name()), "Name", paxPath)
verifyString(h.Linkname, len(v7.linkName()), "Linkname", paxLinkpath)
verifyString(h.Uname, len(ustar.userName()), "Uname", paxUname)
verifyString(h.Gname, len(ustar.groupName()), "Gname", paxGname)
verifyNumeric(h.Mode, len(v7.mode()), "Mode", paxNone)
verifyNumeric(int64(h.Uid), len(v7.uid()), "Uid", paxUid)
verifyNumeric(int64(h.Gid), len(v7.gid()), "Gid", paxGid)
verifyNumeric(h.Size, len(v7.size()), "Size", paxSize)
verifyNumeric(h.Devmajor, len(ustar.devMajor()), "Devmajor", paxNone)
verifyNumeric(h.Devminor, len(ustar.devMinor()), "Devminor", paxNone)
verifyTime(h.ModTime, len(v7.modTime()), "ModTime", paxMtime)
verifyTime(h.AccessTime, len(gnu.accessTime()), "AccessTime", paxAtime)
verifyTime(h.ChangeTime, len(gnu.changeTime()), "ChangeTime", paxCtime)
// Check for header-only types.
var whyOnlyPAX, whyOnlyGNU string
switch h.Typeflag {
case TypeReg, TypeChar, TypeBlock, TypeFifo, TypeGNUSparse:
// Exclude TypeLink and TypeSymlink, since they may reference directories.
if strings.HasSuffix(h.Name, "/") {
return FormatUnknown, nil, headerError{"filename may not have trailing slash"}
}
case TypeXHeader, TypeGNULongName, TypeGNULongLink:
return FormatUnknown, nil, headerError{"cannot manually encode TypeXHeader, TypeGNULongName, or TypeGNULongLink headers"}
case TypeXGlobalHeader:
h2 := Header{Name: h.Name, Typeflag: h.Typeflag, Xattrs: h.Xattrs, PAXRecords: h.PAXRecords, Format: h.Format}
if !reflect.DeepEqual(h, h2) {
return FormatUnknown, nil, headerError{"only PAXRecords should be set for TypeXGlobalHeader"}
}
whyOnlyPAX = "only PAX supports TypeXGlobalHeader"
format.mayOnlyBe(FormatPAX)
}
if !isHeaderOnlyType(h.Typeflag) && h.Size < 0 {
return FormatUnknown, nil, headerError{"negative size on header-only type"}
}
// Check PAX records.
if len(h.Xattrs) > 0 {
for k, v := range h.Xattrs {
paxHdrs[paxSchilyXattr+k] = v
}
whyOnlyPAX = "only PAX supports Xattrs"
format.mayOnlyBe(FormatPAX)
}
if len(h.PAXRecords) > 0 {
for k, v := range h.PAXRecords {
switch _, exists := paxHdrs[k]; {
case exists:
continue // Do not overwrite existing records
case h.Typeflag == TypeXGlobalHeader:
paxHdrs[k] = v // Copy all records
case !basicKeys[k] && !strings.HasPrefix(k, paxGNUSparse):
paxHdrs[k] = v // Ignore local records that may conflict
}
}
whyOnlyPAX = "only PAX supports PAXRecords"
format.mayOnlyBe(FormatPAX)
}
for k, v := range paxHdrs {
if !validPAXRecord(k, v) {
return FormatUnknown, nil, headerError{fmt.Sprintf("invalid PAX record: %q", k+" = "+v)}
}
}
// TODO(dsnet): Re-enable this when adding sparse support.
// See https://golang.org/issue/22735
/*
// Check sparse files.
if len(h.SparseHoles) > 0 || h.Typeflag == TypeGNUSparse {
if isHeaderOnlyType(h.Typeflag) {
return FormatUnknown, nil, headerError{"header-only type cannot be sparse"}
}
if !validateSparseEntries(h.SparseHoles, h.Size) {
return FormatUnknown, nil, headerError{"invalid sparse holes"}
}
if h.Typeflag == TypeGNUSparse {
whyOnlyGNU = "only GNU supports TypeGNUSparse"
format.mayOnlyBe(FormatGNU)
} else {
whyNoGNU = "GNU supports sparse files only with TypeGNUSparse"
format.mustNotBe(FormatGNU)
}
whyNoUSTAR = "USTAR does not support sparse files"
format.mustNotBe(FormatUSTAR)
}
*/
// Check desired format.
if wantFormat := h.Format; wantFormat != FormatUnknown {
if wantFormat.has(FormatPAX) && !preferPAX {
wantFormat.mayBe(FormatUSTAR) // PAX implies USTAR allowed too
}
format.mayOnlyBe(wantFormat) // Set union of formats allowed and format wanted
}
if format == FormatUnknown {
switch h.Format {
case FormatUSTAR:
err = headerError{"Format specifies USTAR", whyNoUSTAR, whyOnlyPAX, whyOnlyGNU}
case FormatPAX:
err = headerError{"Format specifies PAX", whyNoPAX, whyOnlyGNU}
case FormatGNU:
err = headerError{"Format specifies GNU", whyNoGNU, whyOnlyPAX}
default:
err = headerError{whyNoUSTAR, whyNoPAX, whyNoGNU, whyOnlyPAX, whyOnlyGNU}
}
}
return format, paxHdrs, err
}
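// As an illustrative sketch (values made up), a Uname longer than the
// 32-byte USTAR and GNU field can only be represented as a PAX "uname"
// record, so allowedFormats reports FormatPAX alone for such a header:
//
//	h := Header{Name: "a.txt", Uname: strings.Repeat("x", 40)}
//	format, paxHdrs, err := h.allowedFormats()
//	// format == FormatPAX, paxHdrs["uname"] holds the long name, err == nil.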
// FileInfo returns an fs.FileInfo for the Header.
func (h *Header) FileInfo() fs.FileInfo {
return headerFileInfo{h}
}
// headerFileInfo implements fs.FileInfo.
type headerFileInfo struct {
h *Header
}
func (fi headerFileInfo) Size() int64 { return fi.h.Size }
func (fi headerFileInfo) IsDir() bool { return fi.Mode().IsDir() }
func (fi headerFileInfo) ModTime() time.Time { return fi.h.ModTime }
func (fi headerFileInfo) Sys() any { return fi.h }
// Name returns the base name of the file.
func (fi headerFileInfo) Name() string {
if fi.IsDir() {
return path.Base(path.Clean(fi.h.Name))
}
return path.Base(fi.h.Name)
}
// Mode returns the permission and mode bits for the headerFileInfo.
func (fi headerFileInfo) Mode() (mode fs.FileMode) {
// Set file permission bits.
mode = fs.FileMode(fi.h.Mode).Perm()
// Set setuid, setgid and sticky bits.
if fi.h.Mode&c_ISUID != 0 {
mode |= fs.ModeSetuid
}
if fi.h.Mode&c_ISGID != 0 {
mode |= fs.ModeSetgid
}
if fi.h.Mode&c_ISVTX != 0 {
mode |= fs.ModeSticky
}
// Set file mode bits; clear perm, setuid, setgid, and sticky bits.
switch m := fs.FileMode(fi.h.Mode) &^ 07777; m {
case c_ISDIR:
mode |= fs.ModeDir
case c_ISFIFO:
mode |= fs.ModeNamedPipe
case c_ISLNK:
mode |= fs.ModeSymlink
case c_ISBLK:
mode |= fs.ModeDevice
case c_ISCHR:
mode |= fs.ModeDevice
mode |= fs.ModeCharDevice
case c_ISSOCK:
mode |= fs.ModeSocket
}
switch fi.h.Typeflag {
case TypeSymlink:
mode |= fs.ModeSymlink
case TypeChar:
mode |= fs.ModeDevice
mode |= fs.ModeCharDevice
case TypeBlock:
mode |= fs.ModeDevice
case TypeDir:
mode |= fs.ModeDir
case TypeFifo:
mode |= fs.ModeNamedPipe
}
return mode
}
func (fi headerFileInfo) String() string {
return fs.FormatFileInfo(fi)
}
// sysStat, if non-nil, populates h from system-dependent fields of fi.
var sysStat func(fi fs.FileInfo, h *Header, doNameLookups bool) error
const (
// Mode constants from the USTAR spec:
// See http://pubs.opengroup.org/onlinepubs/9699919799/utilities/pax.html#tag_20_92_13_06
c_ISUID = 04000 // Set uid
c_ISGID = 02000 // Set gid
c_ISVTX = 01000 // Save text (sticky bit)
// Common Unix mode constants; these are not defined in any common tar standard.
// Header.FileInfo understands these, but FileInfoHeader will never produce these.
c_ISDIR = 040000 // Directory
c_ISFIFO = 010000 // FIFO
c_ISREG = 0100000 // Regular file
c_ISLNK = 0120000 // Symbolic link
c_ISBLK = 060000 // Block special file
c_ISCHR = 020000 // Character special file
c_ISSOCK = 0140000 // Socket
)
// FileInfoHeader creates a partially-populated [Header] from fi.
// If fi describes a symlink, FileInfoHeader records link as the link target.
// If fi describes a directory, a slash is appended to the name.
//
// Since fs.FileInfo's Name method only returns the base name of
// the file it describes, it may be necessary to modify Header.Name
// to provide the full path name of the file.
//
// If fi implements [FileInfoNames], Header.Gname and Header.Uname
// are provided by the methods of the interface.
func FileInfoHeader(fi fs.FileInfo, link string) (*Header, error) {
if fi == nil {
return nil, errors.New("archive/tar: FileInfo is nil")
}
fm := fi.Mode()
h := &Header{
Name: fi.Name(),
ModTime: fi.ModTime(),
Mode: int64(fm.Perm()), // or'd with c_IS* constants later
}
switch {
case fm.IsRegular():
h.Typeflag = TypeReg
h.Size = fi.Size()
case fi.IsDir():
h.Typeflag = TypeDir
h.Name += "/"
case fm&fs.ModeSymlink != 0:
h.Typeflag = TypeSymlink
h.Linkname = link
case fm&fs.ModeDevice != 0:
if fm&fs.ModeCharDevice != 0 {
h.Typeflag = TypeChar
} else {
h.Typeflag = TypeBlock
}
case fm&fs.ModeNamedPipe != 0:
h.Typeflag = TypeFifo
case fm&fs.ModeSocket != 0:
return nil, fmt.Errorf("archive/tar: sockets not supported")
default:
return nil, fmt.Errorf("archive/tar: unknown file mode %v", fm)
}
if fm&fs.ModeSetuid != 0 {
h.Mode |= c_ISUID
}
if fm&fs.ModeSetgid != 0 {
h.Mode |= c_ISGID
}
if fm&fs.ModeSticky != 0 {
h.Mode |= c_ISVTX
}
// If possible, populate additional fields from OS-specific
// FileInfo fields.
if sys, ok := fi.Sys().(*Header); ok {
// This FileInfo came from a Header (not the OS). Use the
// original Header to populate all remaining fields.
h.Uid = sys.Uid
h.Gid = sys.Gid
h.Uname = sys.Uname
h.Gname = sys.Gname
h.AccessTime = sys.AccessTime
h.ChangeTime = sys.ChangeTime
h.Xattrs = maps.Clone(sys.Xattrs)
if sys.Typeflag == TypeLink {
// hard link
h.Typeflag = TypeLink
h.Size = 0
h.Linkname = sys.Linkname
}
h.PAXRecords = maps.Clone(sys.PAXRecords)
}
var doNameLookups = true
if iface, ok := fi.(FileInfoNames); ok {
doNameLookups = false
var err error
h.Gname, err = iface.Gname()
if err != nil {
return nil, err
}
h.Uname, err = iface.Uname()
if err != nil {
return nil, err
}
}
if sysStat != nil {
return h, sysStat(fi, h, doNameLookups)
}
return h, nil
}
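// As an illustrative sketch (caller-side; the file names are made up), and
// since fi.Name() is only a base name, callers typically overwrite
// Header.Name with the archive-relative path afterwards:
//
//	fi, err := os.Lstat("file.txt")
//	if err != nil {
//		log.Fatal(err)
//	}
//	hdr, err := tar.FileInfoHeader(fi, "")
//	if err != nil {
//		log.Fatal(err)
//	}
//	hdr.Name = "dir/file.txt" // full path within the archive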
// FileInfoNames extends [fs.FileInfo].
// Passing an instance of this to [FileInfoHeader] permits the caller
// to avoid a system-dependent name lookup by specifying the Uname and Gname directly.
type FileInfoNames interface {
fs.FileInfo
// Uname should give a user name.
Uname() (string, error)
// Gname should give a group name.
Gname() (string, error)
}
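// As an illustrative sketch, a wrapper type can supply fixed names and skip
// the OS-level user/group lookup (the type and values below are made up):
//
//	type staticNames struct {
//		fs.FileInfo
//	}
//
//	func (staticNames) Uname() (string, error) { return "root", nil }
//	func (staticNames) Gname() (string, error) { return "root", nil }
//
//	// hdr, err := tar.FileInfoHeader(staticNames{FileInfo: fi}, "")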
// isHeaderOnlyType checks if the given type flag is of the type that has no
// data section even if a size is specified.
func isHeaderOnlyType(flag byte) bool {
switch flag {
case TypeLink, TypeSymlink, TypeChar, TypeBlock, TypeDir, TypeFifo:
return true
default:
return false
}
}
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tar
import "strings"
// Format represents the tar archive format.
//
// The original tar format was introduced in Unix V7.
// Since then, there have been multiple competing formats attempting to
// standardize or extend the V7 format to overcome its limitations.
// The most common formats are the USTAR, PAX, and GNU formats,
// each with their own advantages and limitations.
//
// The following table captures the capabilities of each format:
//
// | USTAR | PAX | GNU
// ------------------+--------+-----------+----------
// Name | 256B | unlimited | unlimited
// Linkname | 100B | unlimited | unlimited
// Size | uint33 | unlimited | uint89
// Mode | uint21 | uint21 | uint57
// Uid/Gid | uint21 | unlimited | uint57
// Uname/Gname | 32B | unlimited | 32B
// ModTime | uint33 | unlimited | int89
// AccessTime | n/a | unlimited | int89
// ChangeTime | n/a | unlimited | int89
// Devmajor/Devminor | uint21 | uint21 | uint57
// ------------------+--------+-----------+----------
// string encoding | ASCII | UTF-8 | binary
// sub-second times | no | yes | no
// sparse files | no | yes | yes
//
// The table's upper portion shows the [Header] fields, where each format reports
// the maximum number of bytes allowed for each string field and
// the integer type used to store each numeric field
// (where timestamps are stored as the number of seconds since the Unix epoch).
//
// The table's lower portion shows specialized features of each format,
// such as supported string encodings, support for sub-second timestamps,
// or support for sparse files.
//
// The Writer currently provides no support for sparse files.
type Format int
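// As an illustrative sketch (caller-side; tw is assumed to be an existing
// *tar.Writer), a specific format can be requested per header, and
// WriteHeader fails if the header cannot be encoded in that format:
//
//	hdr := &tar.Header{Name: "data.bin", Mode: 0o644, Size: 0}
//	hdr.Format = tar.FormatPAX
//	err := tw.WriteHeader(hdr)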
// Constants to identify various tar formats.
const (
// Deliberately hide the meaning of constants from public API.
_ Format = (1 << iota) / 4 // Sequence of 0, 0, 1, 2, 4, 8, etc...
// FormatUnknown indicates that the format is unknown.
FormatUnknown
// The format of the original Unix V7 tar tool prior to standardization.
formatV7
// FormatUSTAR represents the USTAR header format defined in POSIX.1-1988.
//
// While this format is compatible with most tar readers,
// the format has several limitations making it unsuitable for some usages.
// Most notably, it cannot support sparse files, files larger than 8GiB,
// filenames larger than 256 characters, and non-ASCII filenames.
//
// Reference:
// http://pubs.opengroup.org/onlinepubs/9699919799/utilities/pax.html#tag_20_92_13_06
FormatUSTAR
// FormatPAX represents the PAX header format defined in POSIX.1-2001.
//
// PAX extends USTAR by writing a special file with Typeflag TypeXHeader
// preceding the original header. This file contains a set of key-value
// records, which are used to overcome USTAR's shortcomings, in addition to
// providing the ability to have sub-second resolution for timestamps.
//
// Some newer formats add their own extensions to PAX by defining their
// own keys and assigning certain semantic meaning to the associated values.
// For example, sparse file support in PAX is implemented using keys
// defined by the GNU manual (e.g., "GNU.sparse.map").
//
// Reference:
// http://pubs.opengroup.org/onlinepubs/009695399/utilities/pax.html
FormatPAX
// FormatGNU represents the GNU header format.
//
// The GNU header format is older than the USTAR and PAX standards and
// is not compatible with them. The GNU format supports
// arbitrary file sizes, filenames of arbitrary encoding and length,
// sparse files, and other features.
//
// It is recommended that PAX be chosen over GNU unless the target
// application can only parse GNU formatted archives.
//
// Reference:
// https://www.gnu.org/software/tar/manual/html_node/Standard.html
FormatGNU
// Schily's tar format, which is incompatible with USTAR.
// This does not cover STAR extensions to the PAX format; these fall under
// the PAX format.
formatSTAR
formatMax
)
func (f Format) has(f2 Format) bool { return f&f2 != 0 }
func (f *Format) mayBe(f2 Format) { *f |= f2 }
func (f *Format) mayOnlyBe(f2 Format) { *f &= f2 }
func (f *Format) mustNotBe(f2 Format) { *f &^= f2 }
var formatNames = map[Format]string{
formatV7: "V7", FormatUSTAR: "USTAR", FormatPAX: "PAX", FormatGNU: "GNU", formatSTAR: "STAR",
}
func (f Format) String() string {
var ss []string
for f2 := Format(1); f2 < formatMax; f2 <<= 1 {
if f.has(f2) {
ss = append(ss, formatNames[f2])
}
}
switch len(ss) {
case 0:
return "<unknown>"
case 1:
return ss[0]
default:
return "(" + strings.Join(ss, " | ") + ")"
}
}
// Magics used to identify various formats.
const (
magicGNU, versionGNU = "ustar ", " \x00"
magicUSTAR, versionUSTAR = "ustar\x00", "00"
trailerSTAR = "tar\x00"
)
// Size constants from various tar specifications.
const (
blockSize = 512 // Size of each block in a tar stream
nameSize = 100 // Max length of the name field in USTAR format
prefixSize = 155 // Max length of the prefix field in USTAR format
// Max length of a special file (PAX header, GNU long name or link).
// This matches the limit used by libarchive.
maxSpecialFileSize = 1 << 20
)
// blockPadding computes the number of bytes needed to pad offset up to the
// nearest block edge where 0 <= n < blockSize.
func blockPadding(offset int64) (n int64) {
return -offset & (blockSize - 1)
}
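// For example:
//
//	blockPadding(0)   == 0
//	blockPadding(1)   == 511
//	blockPadding(512) == 0
//	blockPadding(513) == 511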
var zeroBlock block
type block [blockSize]byte
// Convert block to any number of formats.
func (b *block) toV7() *headerV7 { return (*headerV7)(b) }
func (b *block) toGNU() *headerGNU { return (*headerGNU)(b) }
func (b *block) toSTAR() *headerSTAR { return (*headerSTAR)(b) }
func (b *block) toUSTAR() *headerUSTAR { return (*headerUSTAR)(b) }
func (b *block) toSparse() sparseArray { return sparseArray(b[:]) }
// getFormat checks that the block is a valid tar header based on the checksum.
// It then attempts to guess the specific format based on magic values.
// If the checksum fails, then FormatUnknown is returned.
func (b *block) getFormat() Format {
// Verify checksum.
var p parser
value := p.parseOctal(b.toV7().chksum())
chksum1, chksum2 := b.computeChecksum()
if p.err != nil || (value != chksum1 && value != chksum2) {
return FormatUnknown
}
// Guess the magic values.
magic := string(b.toUSTAR().magic())
version := string(b.toUSTAR().version())
trailer := string(b.toSTAR().trailer())
switch {
case magic == magicUSTAR && trailer == trailerSTAR:
return formatSTAR
case magic == magicUSTAR:
return FormatUSTAR | FormatPAX
case magic == magicGNU && version == versionGNU:
return FormatGNU
default:
return formatV7
}
}
// setFormat writes the magic values necessary for the specified format
// and then updates the checksum accordingly.
func (b *block) setFormat(format Format) {
// Set the magic values.
switch {
case format.has(formatV7):
// Do nothing.
case format.has(FormatGNU):
copy(b.toGNU().magic(), magicGNU)
copy(b.toGNU().version(), versionGNU)
case format.has(formatSTAR):
copy(b.toSTAR().magic(), magicUSTAR)
copy(b.toSTAR().version(), versionUSTAR)
copy(b.toSTAR().trailer(), trailerSTAR)
case format.has(FormatUSTAR | FormatPAX):
copy(b.toUSTAR().magic(), magicUSTAR)
copy(b.toUSTAR().version(), versionUSTAR)
default:
panic("invalid format")
}
// Update checksum.
// This field is special in that it is terminated by a NULL then space.
var f formatter
field := b.toV7().chksum()
chksum, _ := b.computeChecksum() // Possible values are 256..128776
f.formatOctal(field[:7], chksum) // Never fails since 128776 < 262143
field[7] = ' '
}
// computeChecksum computes the checksum for the header block.
// POSIX specifies a sum of the unsigned byte values, but the Sun tar used
// signed byte values.
// We compute and return both.
func (b *block) computeChecksum() (unsigned, signed int64) {
for i, c := range b {
if 148 <= i && i < 156 {
c = ' ' // Treat the checksum field itself as all spaces.
}
unsigned += int64(c)
signed += int64(int8(c))
}
return unsigned, signed
}
// reset clears the block with all zeros.
func (b *block) reset() {
*b = block{}
}
type headerV7 [blockSize]byte
func (h *headerV7) name() []byte { return h[000:][:100] }
func (h *headerV7) mode() []byte { return h[100:][:8] }
func (h *headerV7) uid() []byte { return h[108:][:8] }
func (h *headerV7) gid() []byte { return h[116:][:8] }
func (h *headerV7) size() []byte { return h[124:][:12] }
func (h *headerV7) modTime() []byte { return h[136:][:12] }
func (h *headerV7) chksum() []byte { return h[148:][:8] }
func (h *headerV7) typeFlag() []byte { return h[156:][:1] }
func (h *headerV7) linkName() []byte { return h[157:][:100] }
type headerGNU [blockSize]byte
func (h *headerGNU) v7() *headerV7 { return (*headerV7)(h) }
func (h *headerGNU) magic() []byte { return h[257:][:6] }
func (h *headerGNU) version() []byte { return h[263:][:2] }
func (h *headerGNU) userName() []byte { return h[265:][:32] }
func (h *headerGNU) groupName() []byte { return h[297:][:32] }
func (h *headerGNU) devMajor() []byte { return h[329:][:8] }
func (h *headerGNU) devMinor() []byte { return h[337:][:8] }
func (h *headerGNU) accessTime() []byte { return h[345:][:12] }
func (h *headerGNU) changeTime() []byte { return h[357:][:12] }
func (h *headerGNU) sparse() sparseArray { return sparseArray(h[386:][:24*4+1]) }
func (h *headerGNU) realSize() []byte { return h[483:][:12] }
type headerSTAR [blockSize]byte
func (h *headerSTAR) v7() *headerV7 { return (*headerV7)(h) }
func (h *headerSTAR) magic() []byte { return h[257:][:6] }
func (h *headerSTAR) version() []byte { return h[263:][:2] }
func (h *headerSTAR) userName() []byte { return h[265:][:32] }
func (h *headerSTAR) groupName() []byte { return h[297:][:32] }
func (h *headerSTAR) devMajor() []byte { return h[329:][:8] }
func (h *headerSTAR) devMinor() []byte { return h[337:][:8] }
func (h *headerSTAR) prefix() []byte { return h[345:][:131] }
func (h *headerSTAR) accessTime() []byte { return h[476:][:12] }
func (h *headerSTAR) changeTime() []byte { return h[488:][:12] }
func (h *headerSTAR) trailer() []byte { return h[508:][:4] }
type headerUSTAR [blockSize]byte
func (h *headerUSTAR) v7() *headerV7 { return (*headerV7)(h) }
func (h *headerUSTAR) magic() []byte { return h[257:][:6] }
func (h *headerUSTAR) version() []byte { return h[263:][:2] }
func (h *headerUSTAR) userName() []byte { return h[265:][:32] }
func (h *headerUSTAR) groupName() []byte { return h[297:][:32] }
func (h *headerUSTAR) devMajor() []byte { return h[329:][:8] }
func (h *headerUSTAR) devMinor() []byte { return h[337:][:8] }
func (h *headerUSTAR) prefix() []byte { return h[345:][:155] }
type sparseArray []byte
func (s sparseArray) entry(i int) sparseElem { return sparseElem(s[i*24:]) }
func (s sparseArray) isExtended() []byte { return s[24*s.maxEntries():][:1] }
func (s sparseArray) maxEntries() int { return len(s) / 24 }
type sparseElem []byte
func (s sparseElem) offset() []byte { return s[00:][:12] }
func (s sparseElem) length() []byte { return s[12:][:12] }
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tar
import (
"bytes"
"io"
"path/filepath"
"strconv"
"strings"
"time"
)
// Reader provides sequential access to the contents of a tar archive.
// Reader.Next advances to the next file in the archive (including the first),
// and then Reader can be treated as an io.Reader to access the file's data.
type Reader struct {
r io.Reader
pad int64 // Amount of padding (ignored) after current file entry
curr fileReader // Reader for current file entry
blk block // Buffer to use as temporary local storage
// err is a persistent error.
// Only the exported methods of Reader are responsible for
// ensuring that this error is sticky.
err error
}
type fileReader interface {
io.Reader
fileState
WriteTo(io.Writer) (int64, error)
}
// NewReader creates a new [Reader] reading from r.
func NewReader(r io.Reader) *Reader {
return &Reader{r: r, curr: &regFileReader{r, 0}}
}
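// As an illustrative sketch (caller-side), the typical read loop over an
// archive; reading from os.Stdin is an arbitrary choice for the example:
//
//	tr := tar.NewReader(os.Stdin)
//	for {
//		hdr, err := tr.Next()
//		if err == io.EOF {
//			break // End of archive
//		}
//		if err != nil {
//			log.Fatal(err)
//		}
//		fmt.Println(hdr.Name)
//		if _, err := io.Copy(io.Discard, tr); err != nil {
//			log.Fatal(err)
//		}
//	}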
// Next advances to the next entry in the tar archive.
// The Header.Size determines how many bytes can be read for the next file.
// Any remaining data in the current file is automatically discarded.
// At the end of the archive, Next returns the error io.EOF.
//
// If Next encounters a non-local name (as defined by [filepath.IsLocal])
// and the GODEBUG environment variable contains `tarinsecurepath=0`,
// Next returns the header with an [ErrInsecurePath] error.
// A future version of Go may introduce this behavior by default.
// Programs that want to accept non-local names can ignore
// the [ErrInsecurePath] error and use the returned header.
func (tr *Reader) Next() (*Header, error) {
if tr.err != nil {
return nil, tr.err
}
hdr, err := tr.next()
tr.err = err
if err == nil && !filepath.IsLocal(hdr.Name) {
if tarinsecurepath.Value() == "0" {
tarinsecurepath.IncNonDefault()
err = ErrInsecurePath
}
}
return hdr, err
}
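// As an illustrative sketch, a caller running with tarinsecurepath=0 that
// still wants to accept non-local names can inspect the header and clear
// the error:
//
//	hdr, err := tr.Next()
//	if errors.Is(err, tar.ErrInsecurePath) {
//		// hdr is valid; decide whether hdr.Name is acceptable before using it.
//		err = nil
//	}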
func (tr *Reader) next() (*Header, error) {
var paxHdrs map[string]string
var gnuLongName, gnuLongLink string
// Externally, Next iterates through the tar archive as if it is a series of
// files. Internally, the tar format often uses fake "files" to add meta
// data that describes the next file. These meta data "files" should not
// normally be visible to the outside. As such, this loop iterates through
// one or more "header files" until it finds a "normal file".
format := FormatUSTAR | FormatPAX | FormatGNU
for {
// Discard the remainder of the file and any padding.
if err := discard(tr.r, tr.curr.physicalRemaining()); err != nil {
return nil, err
}
if _, err := tryReadFull(tr.r, tr.blk[:tr.pad]); err != nil {
return nil, err
}
tr.pad = 0
hdr, rawHdr, err := tr.readHeader()
if err != nil {
return nil, err
}
if err := tr.handleRegularFile(hdr); err != nil {
return nil, err
}
format.mayOnlyBe(hdr.Format)
// Check for PAX/GNU special headers and files.
switch hdr.Typeflag {
case TypeXHeader, TypeXGlobalHeader:
format.mayOnlyBe(FormatPAX)
paxHdrs, err = parsePAX(tr)
if err != nil {
return nil, err
}
if hdr.Typeflag == TypeXGlobalHeader {
mergePAX(hdr, paxHdrs)
return &Header{
Name: hdr.Name,
Typeflag: hdr.Typeflag,
Xattrs: hdr.Xattrs,
PAXRecords: hdr.PAXRecords,
Format: format,
}, nil
}
continue // This is a meta header affecting the next header
case TypeGNULongName, TypeGNULongLink:
format.mayOnlyBe(FormatGNU)
realname, err := readSpecialFile(tr)
if err != nil {
return nil, err
}
var p parser
switch hdr.Typeflag {
case TypeGNULongName:
gnuLongName = p.parseString(realname)
case TypeGNULongLink:
gnuLongLink = p.parseString(realname)
}
continue // This is a meta header affecting the next header
default:
// The old GNU sparse format is handled here since it is technically
// just a regular file with additional attributes.
if err := mergePAX(hdr, paxHdrs); err != nil {
return nil, err
}
if gnuLongName != "" {
hdr.Name = gnuLongName
}
if gnuLongLink != "" {
hdr.Linkname = gnuLongLink
}
if hdr.Typeflag == TypeRegA {
if strings.HasSuffix(hdr.Name, "/") {
hdr.Typeflag = TypeDir // Legacy archives use trailing slash for directories
} else {
hdr.Typeflag = TypeReg
}
}
// The extended headers may have updated the size.
// Thus, setup the regFileReader again after merging PAX headers.
if err := tr.handleRegularFile(hdr); err != nil {
return nil, err
}
// Sparse formats rely on being able to read from the logical data
// section; there must be a preceding call to handleRegularFile.
if err := tr.handleSparseFile(hdr, rawHdr); err != nil {
return nil, err
}
// Set the final guess at the format.
if format.has(FormatUSTAR) && format.has(FormatPAX) {
format.mayOnlyBe(FormatUSTAR)
}
hdr.Format = format
return hdr, nil // This is a file, so stop
}
}
}
// handleRegularFile sets up the current file reader and padding such that it
// can only read the following logical data section. It will properly handle
// special headers that contain no data section.
func (tr *Reader) handleRegularFile(hdr *Header) error {
nb := hdr.Size
if isHeaderOnlyType(hdr.Typeflag) {
nb = 0
}
if nb < 0 {
return ErrHeader
}
tr.pad = blockPadding(nb)
tr.curr = &regFileReader{r: tr.r, nb: nb}
return nil
}
// handleSparseFile checks if the current file is a sparse format of any type
// and sets the curr reader appropriately.
func (tr *Reader) handleSparseFile(hdr *Header, rawHdr *block) error {
var spd sparseDatas
var err error
if hdr.Typeflag == TypeGNUSparse {
spd, err = tr.readOldGNUSparseMap(hdr, rawHdr)
} else {
spd, err = tr.readGNUSparsePAXHeaders(hdr)
}
// If sp is non-nil, then this is a sparse file.
// Note that it is possible for len(sp) == 0.
if err == nil && spd != nil {
if isHeaderOnlyType(hdr.Typeflag) || !validateSparseEntries(spd, hdr.Size) {
return ErrHeader
}
sph := invertSparseEntries(spd, hdr.Size)
tr.curr = &sparseFileReader{tr.curr, sph, 0}
}
return err
}
// readGNUSparsePAXHeaders checks the PAX headers for GNU sparse headers.
// If they are found, then this function reads the sparse map and returns it.
// This assumes that 0.0 headers have already been converted to 0.1 headers
// by the PAX header parsing logic.
func (tr *Reader) readGNUSparsePAXHeaders(hdr *Header) (sparseDatas, error) {
// Identify the version of GNU headers.
var is1x0 bool
major, minor := hdr.PAXRecords[paxGNUSparseMajor], hdr.PAXRecords[paxGNUSparseMinor]
switch {
case major == "0" && (minor == "0" || minor == "1"):
is1x0 = false
case major == "1" && minor == "0":
is1x0 = true
case major != "" || minor != "":
return nil, nil // Unknown GNU sparse PAX version
case hdr.PAXRecords[paxGNUSparseMap] != "":
is1x0 = false // 0.0 and 0.1 did not have explicit version records, so guess
default:
return nil, nil // Not a PAX format GNU sparse file.
}
hdr.Format.mayOnlyBe(FormatPAX)
// Update hdr from GNU sparse PAX headers.
if name := hdr.PAXRecords[paxGNUSparseName]; name != "" {
hdr.Name = name
}
size := hdr.PAXRecords[paxGNUSparseSize]
if size == "" {
size = hdr.PAXRecords[paxGNUSparseRealSize]
}
if size != "" {
n, err := strconv.ParseInt(size, 10, 64)
if err != nil {
return nil, ErrHeader
}
hdr.Size = n
}
// Read the sparse map according to the appropriate format.
if is1x0 {
return readGNUSparseMap1x0(tr.curr)
}
return readGNUSparseMap0x1(hdr.PAXRecords)
}
// mergePAX merges paxHdrs into hdr for all relevant fields of Header.
func mergePAX(hdr *Header, paxHdrs map[string]string) (err error) {
for k, v := range paxHdrs {
if v == "" {
continue // Keep the original USTAR value
}
var id64 int64
switch k {
case paxPath:
hdr.Name = v
case paxLinkpath:
hdr.Linkname = v
case paxUname:
hdr.Uname = v
case paxGname:
hdr.Gname = v
case paxUid:
id64, err = strconv.ParseInt(v, 10, 64)
hdr.Uid = int(id64) // Integer overflow possible
case paxGid:
id64, err = strconv.ParseInt(v, 10, 64)
hdr.Gid = int(id64) // Integer overflow possible
case paxAtime:
hdr.AccessTime, err = parsePAXTime(v)
case paxMtime:
hdr.ModTime, err = parsePAXTime(v)
case paxCtime:
hdr.ChangeTime, err = parsePAXTime(v)
case paxSize:
hdr.Size, err = strconv.ParseInt(v, 10, 64)
default:
if strings.HasPrefix(k, paxSchilyXattr) {
if hdr.Xattrs == nil {
hdr.Xattrs = make(map[string]string)
}
hdr.Xattrs[k[len(paxSchilyXattr):]] = v
}
}
if err != nil {
return ErrHeader
}
}
hdr.PAXRecords = paxHdrs
return nil
}
// parsePAX parses PAX headers.
// If an extended header (type 'x') is invalid, ErrHeader is returned.
func parsePAX(r io.Reader) (map[string]string, error) {
buf, err := readSpecialFile(r)
if err != nil {
return nil, err
}
sbuf := string(buf)
// For GNU PAX sparse format 0.0 support.
// This function transforms the sparse format 0.0 headers into format 0.1
// headers since 0.0 headers were not PAX compliant.
var sparseMap []string
paxHdrs := make(map[string]string)
for len(sbuf) > 0 {
key, value, residual, err := parsePAXRecord(sbuf)
if err != nil {
return nil, ErrHeader
}
sbuf = residual
switch key {
case paxGNUSparseOffset, paxGNUSparseNumBytes:
// Validate sparse header order and value.
if (len(sparseMap)%2 == 0 && key != paxGNUSparseOffset) ||
(len(sparseMap)%2 == 1 && key != paxGNUSparseNumBytes) ||
strings.Contains(value, ",") {
return nil, ErrHeader
}
sparseMap = append(sparseMap, value)
default:
paxHdrs[key] = value
}
}
if len(sparseMap) > 0 {
paxHdrs[paxGNUSparseMap] = strings.Join(sparseMap, ",")
}
return paxHdrs, nil
}
// readHeader reads the next block header and assumes that the underlying reader
// is already aligned to a block boundary. It returns the raw block of the
// header in case further processing is required.
//
// The err will be set to io.EOF only when one of the following occurs:
// - Exactly 0 bytes are read and EOF is hit.
// - Exactly 1 block of zeros is read and EOF is hit.
// - At least 2 blocks of zeros are read.
func (tr *Reader) readHeader() (*Header, *block, error) {
// Two blocks of zero bytes marks the end of the archive.
if _, err := io.ReadFull(tr.r, tr.blk[:]); err != nil {
return nil, nil, err // EOF is okay here; exactly 0 bytes read
}
if bytes.Equal(tr.blk[:], zeroBlock[:]) {
if _, err := io.ReadFull(tr.r, tr.blk[:]); err != nil {
return nil, nil, err // EOF is okay here; exactly 1 block of zeros read
}
if bytes.Equal(tr.blk[:], zeroBlock[:]) {
return nil, nil, io.EOF // normal EOF; exactly 2 blocks of zeros read
}
return nil, nil, ErrHeader // Zero block and then non-zero block
}
// Verify the header matches a known format.
format := tr.blk.getFormat()
if format == FormatUnknown {
return nil, nil, ErrHeader
}
var p parser
hdr := new(Header)
// Unpack the V7 header.
v7 := tr.blk.toV7()
hdr.Typeflag = v7.typeFlag()[0]
hdr.Name = p.parseString(v7.name())
hdr.Linkname = p.parseString(v7.linkName())
hdr.Size = p.parseNumeric(v7.size())
hdr.Mode = p.parseNumeric(v7.mode())
hdr.Uid = int(p.parseNumeric(v7.uid()))
hdr.Gid = int(p.parseNumeric(v7.gid()))
hdr.ModTime = time.Unix(p.parseNumeric(v7.modTime()), 0)
// Unpack format specific fields.
if format > formatV7 {
ustar := tr.blk.toUSTAR()
hdr.Uname = p.parseString(ustar.userName())
hdr.Gname = p.parseString(ustar.groupName())
hdr.Devmajor = p.parseNumeric(ustar.devMajor())
hdr.Devminor = p.parseNumeric(ustar.devMinor())
var prefix string
switch {
case format.has(FormatUSTAR | FormatPAX):
hdr.Format = format
ustar := tr.blk.toUSTAR()
prefix = p.parseString(ustar.prefix())
// For Format detection, check if block is properly formatted since
// the parser is more liberal than what USTAR actually permits.
notASCII := func(r rune) bool { return r >= 0x80 }
if bytes.IndexFunc(tr.blk[:], notASCII) >= 0 {
hdr.Format = FormatUnknown // Non-ASCII characters in block.
}
nul := func(b []byte) bool { return int(b[len(b)-1]) == 0 }
if !(nul(v7.size()) && nul(v7.mode()) && nul(v7.uid()) && nul(v7.gid()) &&
nul(v7.modTime()) && nul(ustar.devMajor()) && nul(ustar.devMinor())) {
hdr.Format = FormatUnknown // Numeric fields must end in NUL
}
case format.has(formatSTAR):
star := tr.blk.toSTAR()
prefix = p.parseString(star.prefix())
hdr.AccessTime = time.Unix(p.parseNumeric(star.accessTime()), 0)
hdr.ChangeTime = time.Unix(p.parseNumeric(star.changeTime()), 0)
case format.has(FormatGNU):
hdr.Format = format
var p2 parser
gnu := tr.blk.toGNU()
if b := gnu.accessTime(); b[0] != 0 {
hdr.AccessTime = time.Unix(p2.parseNumeric(b), 0)
}
if b := gnu.changeTime(); b[0] != 0 {
hdr.ChangeTime = time.Unix(p2.parseNumeric(b), 0)
}
// Prior to Go1.8, the Writer had a bug where it would output
// an invalid tar file in certain rare situations because the logic
// incorrectly believed that the old GNU format had a prefix field.
// This is wrong and leads to an output file that mangles the
// atime and ctime fields, which are often left unused.
//
// In order to continue reading tar files created by former, buggy
// versions of Go, we skeptically parse the atime and ctime fields.
// If we are unable to parse them and the prefix field looks like
// an ASCII string, then we fall back on the pre-Go1.8 behavior
// of treating these fields as the USTAR prefix field.
//
// Note that this will not use the fallback logic for all possible
// files generated by a pre-Go1.8 toolchain. If the generated file
// happened to have a prefix field that parses as valid
// atime and ctime fields (e.g., when they are valid octal strings),
// then it is impossible to distinguish between a valid GNU file
// and an invalid pre-Go1.8 file.
//
// See https://golang.org/issues/12594
// See https://golang.org/issues/21005
if p2.err != nil {
hdr.AccessTime, hdr.ChangeTime = time.Time{}, time.Time{}
ustar := tr.blk.toUSTAR()
if s := p.parseString(ustar.prefix()); isASCII(s) {
prefix = s
}
hdr.Format = FormatUnknown // Buggy file is not GNU
}
}
if len(prefix) > 0 {
hdr.Name = prefix + "/" + hdr.Name
}
}
return hdr, &tr.blk, p.err
}
// readOldGNUSparseMap reads the sparse map from the old GNU sparse format.
// The sparse map is stored in the tar header if it's small enough.
// If it's larger than four entries, then one or more extension headers are used
// to store the rest of the sparse map.
//
// The Header.Size does not reflect the size of any extended headers used.
// Thus, this function will read from the raw io.Reader to fetch extra headers.
// This method mutates blk in the process.
func (tr *Reader) readOldGNUSparseMap(hdr *Header, blk *block) (sparseDatas, error) {
// Make sure that the input format is GNU.
// Unfortunately, the STAR format also has a sparse header format that uses
// the same type flag but has a completely different layout.
if blk.getFormat() != FormatGNU {
return nil, ErrHeader
}
hdr.Format.mayOnlyBe(FormatGNU)
var p parser
hdr.Size = p.parseNumeric(blk.toGNU().realSize())
if p.err != nil {
return nil, p.err
}
s := blk.toGNU().sparse()
spd := make(sparseDatas, 0, s.maxEntries())
for {
for i := 0; i < s.maxEntries(); i++ {
// This termination condition is identical to GNU and BSD tar.
if s.entry(i).offset()[0] == 0x00 {
break // Don't return, need to process extended headers (even if empty)
}
offset := p.parseNumeric(s.entry(i).offset())
length := p.parseNumeric(s.entry(i).length())
if p.err != nil {
return nil, p.err
}
spd = append(spd, sparseEntry{Offset: offset, Length: length})
}
if s.isExtended()[0] > 0 {
// There are more entries. Read an extension header and parse its entries.
if _, err := mustReadFull(tr.r, blk[:]); err != nil {
return nil, err
}
s = blk.toSparse()
continue
}
return spd, nil // Done
}
}
// readGNUSparseMap1x0 reads the sparse map as stored in GNU's PAX sparse format
// version 1.0. The format of the sparse map consists of a series of
// newline-terminated numeric fields. The first field is the number of entries
// and is always present. Following this are the entries, consisting of two
// fields (offset, length). This function must stop reading at the end
// boundary of the block containing the last newline.
//
// Note that the GNU manual says that numeric values should be encoded in octal
// format. However, the GNU tar utility itself outputs these values in decimal.
// As such, this library treats values as being encoded in decimal.
func readGNUSparseMap1x0(r io.Reader) (sparseDatas, error) {
var (
cntNewline int64
buf bytes.Buffer
blk block
)
// feedTokens copies data in blocks from r into buf until there are
// at least cnt newlines in buf. It will not read more blocks than needed.
feedTokens := func(n int64) error {
for cntNewline < n {
if _, err := mustReadFull(r, blk[:]); err != nil {
return err
}
buf.Write(blk[:])
for _, c := range blk {
if c == '\n' {
cntNewline++
}
}
}
return nil
}
// nextToken gets the next token delimited by a newline. This assumes that
// at least one newline exists in the buffer.
nextToken := func() string {
cntNewline--
tok, _ := buf.ReadString('\n')
return strings.TrimRight(tok, "\n")
}
// Parse for the number of entries.
// Use integer overflow resistant math to check this.
if err := feedTokens(1); err != nil {
return nil, err
}
numEntries, err := strconv.ParseInt(nextToken(), 10, 0) // Intentionally parse as native int
if err != nil || numEntries < 0 || int(2*numEntries) < int(numEntries) {
return nil, ErrHeader
}
// Parse for all member entries.
// numEntries is trusted after this since a potential attacker must have
// committed resources proportional to what this library used.
if err := feedTokens(2 * numEntries); err != nil {
return nil, err
}
spd := make(sparseDatas, 0, numEntries)
for i := int64(0); i < numEntries; i++ {
offset, err1 := strconv.ParseInt(nextToken(), 10, 64)
length, err2 := strconv.ParseInt(nextToken(), 10, 64)
if err1 != nil || err2 != nil {
return nil, ErrHeader
}
spd = append(spd, sparseEntry{Offset: offset, Length: length})
}
return spd, nil
}
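// As an illustrative sketch (values made up), a 1.0 sparse map with a single
// entry at offset 2 of length 5 is encoded as the text "1\n2\n5\n" padded out
// to a full 512-byte block, and parses as:
//
//	spd, err := readGNUSparseMap1x0(r) // r yields the padded block above
//	// spd is sparseDatas{{Offset: 2, Length: 5}}, err is nil.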
// readGNUSparseMap0x1 reads the sparse map as stored in GNU's PAX sparse format
// version 0.1. The sparse map is stored in the PAX headers.
func readGNUSparseMap0x1(paxHdrs map[string]string) (sparseDatas, error) {
// Get number of entries.
// Use integer overflow resistant math to check this.
numEntriesStr := paxHdrs[paxGNUSparseNumBlocks]
numEntries, err := strconv.ParseInt(numEntriesStr, 10, 0) // Intentionally parse as native int
if err != nil || numEntries < 0 || int(2*numEntries) < int(numEntries) {
return nil, ErrHeader
}
// There should be two numbers in sparseMap for each entry.
sparseMap := strings.Split(paxHdrs[paxGNUSparseMap], ",")
if len(sparseMap) == 1 && sparseMap[0] == "" {
sparseMap = sparseMap[:0]
}
if int64(len(sparseMap)) != 2*numEntries {
return nil, ErrHeader
}
// Loop through the entries in the sparse map.
// numEntries is trusted now.
spd := make(sparseDatas, 0, numEntries)
for len(sparseMap) >= 2 {
offset, err1 := strconv.ParseInt(sparseMap[0], 10, 64)
length, err2 := strconv.ParseInt(sparseMap[1], 10, 64)
if err1 != nil || err2 != nil {
return nil, ErrHeader
}
spd = append(spd, sparseEntry{Offset: offset, Length: length})
sparseMap = sparseMap[2:]
}
return spd, nil
}
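// As an illustrative sketch (values made up), in the 0.1 format the map
// arrives as ordinary PAX records, e.g. two fragments of 512 bytes at
// offsets 0 and 1024:
//
//	paxHdrs := map[string]string{
//		paxGNUSparseNumBlocks: "2",
//		paxGNUSparseMap:       "0,512,1024,512",
//	}
//	spd, err := readGNUSparseMap0x1(paxHdrs)
//	// spd is sparseDatas{{Offset: 0, Length: 512}, {Offset: 1024, Length: 512}}, err is nil.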
// Read reads from the current file in the tar archive.
// It returns (0, io.EOF) when it reaches the end of that file,
// until [Next] is called to advance to the next file.
//
// If the current file is sparse, then the regions marked as a hole
// are read back as NUL-bytes.
//
// Calling Read on special types like [TypeLink], [TypeSymlink], [TypeChar],
// [TypeBlock], [TypeDir], and [TypeFifo] returns (0, [io.EOF]) regardless of what
// the [Header.Size] claims.
func (tr *Reader) Read(b []byte) (int, error) {
if tr.err != nil {
return 0, tr.err
}
n, err := tr.curr.Read(b)
if err != nil && err != io.EOF {
tr.err = err
}
return n, err
}
// writeTo writes the content of the current file to w.
// The number of bytes written matches the number of remaining bytes in the current file.
//
// If the current file is sparse and w is an io.WriteSeeker,
// then writeTo uses Seek to skip past holes defined in Header.SparseHoles,
// assuming that skipped regions are filled with NULs.
// This always writes the last byte to ensure w is the right size.
//
// TODO(dsnet): Re-export this when adding sparse file support.
// See https://golang.org/issue/22735
func (tr *Reader) writeTo(w io.Writer) (int64, error) {
if tr.err != nil {
return 0, tr.err
}
n, err := tr.curr.WriteTo(w)
if err != nil {
tr.err = err
}
return n, err
}
// regFileReader is a fileReader for reading data from a regular file entry.
type regFileReader struct {
r io.Reader // Underlying Reader
nb int64 // Number of remaining bytes to read
}
func (fr *regFileReader) Read(b []byte) (n int, err error) {
if int64(len(b)) > fr.nb {
b = b[:fr.nb]
}
if len(b) > 0 {
n, err = fr.r.Read(b)
fr.nb -= int64(n)
}
switch {
case err == io.EOF && fr.nb > 0:
return n, io.ErrUnexpectedEOF
case err == nil && fr.nb == 0:
return n, io.EOF
default:
return n, err
}
}
func (fr *regFileReader) WriteTo(w io.Writer) (int64, error) {
return io.Copy(w, struct{ io.Reader }{fr})
}
// logicalRemaining implements fileState.logicalRemaining.
func (fr regFileReader) logicalRemaining() int64 {
return fr.nb
}
// physicalRemaining implements fileState.physicalRemaining.
func (fr regFileReader) physicalRemaining() int64 {
return fr.nb
}
// sparseFileReader is a fileReader for reading data from a sparse file entry.
type sparseFileReader struct {
fr fileReader // Underlying fileReader
sp sparseHoles // Normalized list of sparse holes
pos int64 // Current position in sparse file
}
func (sr *sparseFileReader) Read(b []byte) (n int, err error) {
finished := int64(len(b)) >= sr.logicalRemaining()
if finished {
b = b[:sr.logicalRemaining()]
}
b0 := b
endPos := sr.pos + int64(len(b))
for endPos > sr.pos && err == nil {
var nf int // Bytes read in fragment
holeStart, holeEnd := sr.sp[0].Offset, sr.sp[0].endOffset()
if sr.pos < holeStart { // In a data fragment
bf := b[:min(int64(len(b)), holeStart-sr.pos)]
nf, err = tryReadFull(sr.fr, bf)
} else { // In a hole fragment
bf := b[:min(int64(len(b)), holeEnd-sr.pos)]
nf, err = tryReadFull(zeroReader{}, bf)
}
b = b[nf:]
sr.pos += int64(nf)
if sr.pos >= holeEnd && len(sr.sp) > 1 {
sr.sp = sr.sp[1:] // Ensure last fragment always remains
}
}
n = len(b0) - len(b)
switch {
case err == io.EOF:
return n, errMissData // Less data in dense file than sparse file
case err != nil:
return n, err
case sr.logicalRemaining() == 0 && sr.physicalRemaining() > 0:
return n, errUnrefData // More data in dense file than sparse file
case finished:
return n, io.EOF
default:
return n, nil
}
}
func (sr *sparseFileReader) WriteTo(w io.Writer) (n int64, err error) {
ws, ok := w.(io.WriteSeeker)
if ok {
if _, err := ws.Seek(0, io.SeekCurrent); err != nil {
ok = false // Not all io.Seeker can really seek
}
}
if !ok {
return io.Copy(w, struct{ io.Reader }{sr})
}
var writeLastByte bool
pos0 := sr.pos
for sr.logicalRemaining() > 0 && !writeLastByte && err == nil {
var nf int64 // Size of fragment
holeStart, holeEnd := sr.sp[0].Offset, sr.sp[0].endOffset()
if sr.pos < holeStart { // In a data fragment
nf = holeStart - sr.pos
nf, err = io.CopyN(ws, sr.fr, nf)
} else { // In a hole fragment
nf = holeEnd - sr.pos
if sr.physicalRemaining() == 0 {
writeLastByte = true
nf--
}
_, err = ws.Seek(nf, io.SeekCurrent)
}
sr.pos += nf
if sr.pos >= holeEnd && len(sr.sp) > 1 {
sr.sp = sr.sp[1:] // Ensure last fragment always remains
}
}
// If the last fragment is a hole, then seek to 1-byte before EOF, and
// write a single byte to ensure the file is the right size.
if writeLastByte && err == nil {
_, err = ws.Write([]byte{0})
sr.pos++
}
n = sr.pos - pos0
switch {
case err == io.EOF:
return n, errMissData // Less data in dense file than sparse file
case err != nil:
return n, err
case sr.logicalRemaining() == 0 && sr.physicalRemaining() > 0:
return n, errUnrefData // More data in dense file than sparse file
default:
return n, nil
}
}
func (sr sparseFileReader) logicalRemaining() int64 {
return sr.sp[len(sr.sp)-1].endOffset() - sr.pos
}
func (sr sparseFileReader) physicalRemaining() int64 {
return sr.fr.physicalRemaining()
}
type zeroReader struct{}
func (zeroReader) Read(b []byte) (int, error) {
clear(b)
return len(b), nil
}
// mustReadFull is like io.ReadFull except it returns
// io.ErrUnexpectedEOF when io.EOF is hit before len(b) bytes are read.
func mustReadFull(r io.Reader, b []byte) (int, error) {
n, err := tryReadFull(r, b)
if err == io.EOF {
err = io.ErrUnexpectedEOF
}
return n, err
}
// tryReadFull is like io.ReadFull except it returns
// io.EOF when it is hit before len(b) bytes are read.
func tryReadFull(r io.Reader, b []byte) (n int, err error) {
for len(b) > n && err == nil {
var nn int
nn, err = r.Read(b[n:])
n += nn
}
if len(b) == n && err == io.EOF {
err = nil
}
return n, err
}
// readSpecialFile is like io.ReadAll except it returns
// ErrFieldTooLong if more than maxSpecialFileSize is read.
func readSpecialFile(r io.Reader) ([]byte, error) {
buf, err := io.ReadAll(io.LimitReader(r, maxSpecialFileSize+1))
if len(buf) > maxSpecialFileSize {
return nil, ErrFieldTooLong
}
return buf, err
}
// discard skips n bytes in r, reporting an error if unable to do so.
func discard(r io.Reader, n int64) error {
// If possible, Seek to the last byte before the end of the data section.
// Do this because Seek is often lazy about reporting errors; seeking all the
// way would mask the fact that the stream may be truncated. We can rely on
// the io.CopyN done shortly afterwards to trigger any IO errors.
var seekSkipped int64 // Number of bytes skipped via Seek
if sr, ok := r.(io.Seeker); ok && n > 1 {
// Not all io.Seeker can actually Seek. For example, os.Stdin implements
// io.Seeker, but calling Seek always returns an error and performs
// no action. Thus, we try an innocent seek to the current position
// to see if Seek is really supported.
pos1, err := sr.Seek(0, io.SeekCurrent)
if pos1 >= 0 && err == nil {
// Seek seems supported, so perform the real Seek.
pos2, err := sr.Seek(n-1, io.SeekCurrent)
if pos2 < 0 || err != nil {
return err
}
seekSkipped = pos2 - pos1
}
}
copySkipped, err := io.CopyN(io.Discard, r, n-seekSkipped)
if err == io.EOF && seekSkipped+copySkipped < n {
err = io.ErrUnexpectedEOF
}
return err
}
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build aix || linux || dragonfly || openbsd || solaris
package tar
import (
"syscall"
"time"
)
func statAtime(st *syscall.Stat_t) time.Time {
return time.Unix(st.Atim.Unix())
}
func statCtime(st *syscall.Stat_t) time.Time {
return time.Unix(st.Ctim.Unix())
}
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build unix
package tar
import (
"io/fs"
"os/user"
"runtime"
"strconv"
"sync"
"syscall"
)
func init() {
sysStat = statUnix
}
// userMap and groupMap cache UID and GID lookups for performance reasons.
// The downside is that, if the OS renames a user or group, the cached
// uname or gname never reflects the change.
var userMap, groupMap sync.Map // map[int]string
func statUnix(fi fs.FileInfo, h *Header, doNameLookups bool) error {
sys, ok := fi.Sys().(*syscall.Stat_t)
if !ok {
return nil
}
h.Uid = int(sys.Uid)
h.Gid = int(sys.Gid)
if doNameLookups {
// Best effort at populating Uname and Gname.
// The os/user functions may fail for any number of reasons
// (not implemented on that platform, cgo not enabled, etc).
if u, ok := userMap.Load(h.Uid); ok {
h.Uname = u.(string)
} else if u, err := user.LookupId(strconv.Itoa(h.Uid)); err == nil {
h.Uname = u.Username
userMap.Store(h.Uid, h.Uname)
}
if g, ok := groupMap.Load(h.Gid); ok {
h.Gname = g.(string)
} else if g, err := user.LookupGroupId(strconv.Itoa(h.Gid)); err == nil {
h.Gname = g.Name
groupMap.Store(h.Gid, h.Gname)
}
}
h.AccessTime = statAtime(sys)
h.ChangeTime = statCtime(sys)
// Best effort at populating Devmajor and Devminor.
if h.Typeflag == TypeChar || h.Typeflag == TypeBlock {
dev := uint64(sys.Rdev) // May be int32 or uint32
switch runtime.GOOS {
case "aix":
var major, minor uint32
major = uint32((dev & 0x3fffffff00000000) >> 32)
minor = uint32((dev & 0x00000000ffffffff) >> 0)
h.Devmajor, h.Devminor = int64(major), int64(minor)
case "linux":
// Copied from golang.org/x/sys/unix/dev_linux.go.
major := uint32((dev & 0x00000000000fff00) >> 8)
major |= uint32((dev & 0xfffff00000000000) >> 32)
minor := uint32((dev & 0x00000000000000ff) >> 0)
minor |= uint32((dev & 0x00000ffffff00000) >> 12)
h.Devmajor, h.Devminor = int64(major), int64(minor)
case "darwin", "ios":
// Copied from golang.org/x/sys/unix/dev_darwin.go.
major := uint32((dev >> 24) & 0xff)
minor := uint32(dev & 0xffffff)
h.Devmajor, h.Devminor = int64(major), int64(minor)
case "dragonfly":
// Copied from golang.org/x/sys/unix/dev_dragonfly.go.
major := uint32((dev >> 8) & 0xff)
minor := uint32(dev & 0xffff00ff)
h.Devmajor, h.Devminor = int64(major), int64(minor)
case "freebsd":
// Copied from golang.org/x/sys/unix/dev_freebsd.go.
major := uint32((dev >> 8) & 0xff)
minor := uint32(dev & 0xffff00ff)
h.Devmajor, h.Devminor = int64(major), int64(minor)
case "netbsd":
// Copied from golang.org/x/sys/unix/dev_netbsd.go.
major := uint32((dev & 0x000fff00) >> 8)
minor := uint32((dev & 0x000000ff) >> 0)
minor |= uint32((dev & 0xfff00000) >> 12)
h.Devmajor, h.Devminor = int64(major), int64(minor)
case "openbsd":
// Copied from golang.org/x/sys/unix/dev_openbsd.go.
major := uint32((dev & 0x0000ff00) >> 8)
minor := uint32((dev & 0x000000ff) >> 0)
minor |= uint32((dev & 0xffff0000) >> 8)
h.Devmajor, h.Devminor = int64(major), int64(minor)
default:
// TODO: Implement solaris (see https://golang.org/issue/8106)
}
}
return nil
}
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tar
import (
"bytes"
"fmt"
"strconv"
"strings"
"time"
)
// hasNUL reports whether the NUL character exists within s.
func hasNUL(s string) bool {
return strings.Contains(s, "\x00")
}
// isASCII reports whether the input is an ASCII C-style string.
func isASCII(s string) bool {
for _, c := range s {
if c >= 0x80 || c == 0x00 {
return false
}
}
return true
}
// toASCII converts the input to an ASCII C-style string.
// This is a best effort conversion, so invalid characters are dropped.
func toASCII(s string) string {
if isASCII(s) {
return s
}
b := make([]byte, 0, len(s))
for _, c := range s {
if c < 0x80 && c != 0x00 {
b = append(b, byte(c))
}
}
return string(b)
}
type parser struct {
err error // Last error seen
}
type formatter struct {
err error // Last error seen
}
// parseString parses bytes as a NUL-terminated C-style string.
// If a NUL byte is not found then the whole slice is returned as a string.
func (*parser) parseString(b []byte) string {
if i := bytes.IndexByte(b, 0); i >= 0 {
return string(b[:i])
}
return string(b)
}
// formatString copies s into b, NUL-terminating if possible.
func (f *formatter) formatString(b []byte, s string) {
if len(s) > len(b) {
f.err = ErrFieldTooLong
}
copy(b, s)
if len(s) < len(b) {
b[len(s)] = 0
}
// Some buggy readers treat regular files with a trailing slash
// in the V7 path field as a directory even though the full path
// recorded elsewhere (e.g., via PAX record) contains no trailing slash.
if len(s) > len(b) && b[len(b)-1] == '/' {
n := len(strings.TrimRight(s[:len(b)-1], "/"))
b[n] = 0 // Replace trailing slash with NUL terminator
}
}
// fitsInBase256 reports whether x can be encoded into n bytes using base-256
// encoding. Unlike octal encoding, base-256 encoding does not require that the
// string ends with a NUL character. Thus, all n bytes are available for output.
//
// If operating in binary mode, this assumes strict GNU binary mode, which means
// that the first byte can only be either 0x80 or 0xff. Thus, the first byte is
// equivalent to the sign bit in two's complement form.
func fitsInBase256(n int, x int64) bool {
binBits := uint(n-1) * 8
return n >= 9 || (x >= -1<<binBits && x < 1<<binBits)
}
// parseNumeric parses the input as being encoded in either base-256 or octal.
// This function may return negative numbers.
// If parsing fails or an integer overflow occurs, err will be set.
func (p *parser) parseNumeric(b []byte) int64 {
// Check for base-256 (binary) format first.
// If the first bit is set, then all following bits constitute a two's
// complement encoded number in big-endian byte order.
if len(b) > 0 && b[0]&0x80 != 0 {
// Handling negative numbers relies on the following identity:
// -a-1 == ^a
//
// If the number is negative, we use an inversion mask to invert the
// data bytes and treat the value as an unsigned number.
var inv byte // 0x00 if positive or zero, 0xff if negative
if b[0]&0x40 != 0 {
inv = 0xff
}
var x uint64
for i, c := range b {
c ^= inv // Inverts c only if inv is 0xff, otherwise does nothing
if i == 0 {
c &= 0x7f // Ignore sign bit in first byte
}
if (x >> 56) > 0 {
p.err = ErrHeader // Integer overflow
return 0
}
x = x<<8 | uint64(c)
}
if (x >> 63) > 0 {
p.err = ErrHeader // Integer overflow
return 0
}
if inv == 0xff {
return ^int64(x)
}
return int64(x)
}
// Normal case is base-8 (octal) format.
return p.parseOctal(b)
}
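// Worked examples (illustrative, not used by the package): the same field may
// hold octal text or a base-256 (binary) value flagged by the high bit of the
// first byte.
//
//	var p parser
//	p.parseNumeric([]byte("0000644\x00"))             // 0o644 == 420 (octal text)
//	p.parseNumeric([]byte{0x80, 0, 0, 0, 0, 0, 0, 1}) // 1 (positive base-256)
//	p.parseNumeric([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}) // -1 (negative base-256)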
// formatNumeric encodes x into b using base-8 (octal) encoding if possible.
// Otherwise it will attempt to use base-256 (binary) encoding.
func (f *formatter) formatNumeric(b []byte, x int64) {
if fitsInOctal(len(b), x) {
f.formatOctal(b, x)
return
}
if fitsInBase256(len(b), x) {
for i := len(b) - 1; i >= 0; i-- {
b[i] = byte(x)
x >>= 8
}
b[0] |= 0x80 // Highest bit indicates binary format
return
}
f.formatOctal(b, 0) // Last resort, just write zero
f.err = ErrFieldTooLong
}
func (p *parser) parseOctal(b []byte) int64 {
// Because unused fields are filled with NULs, we need
// to skip leading NULs. Fields may also be padded with
// spaces or NULs.
// So we remove leading and trailing NULs and spaces to
// be sure.
b = bytes.Trim(b, " \x00")
if len(b) == 0 {
return 0
}
x, perr := strconv.ParseUint(p.parseString(b), 8, 64)
if perr != nil {
p.err = ErrHeader
}
return int64(x)
}
func (f *formatter) formatOctal(b []byte, x int64) {
if !fitsInOctal(len(b), x) {
x = 0 // Last resort, just write zero
f.err = ErrFieldTooLong
}
s := strconv.FormatInt(x, 8)
// Add leading zeros, but leave room for a NUL.
if n := len(b) - len(s) - 1; n > 0 {
s = strings.Repeat("0", n) + s
}
f.formatString(b, s)
}
// fitsInOctal reports whether the integer x fits in a field n-bytes long
// using octal encoding with the appropriate NUL terminator.
func fitsInOctal(n int, x int64) bool {
octBits := uint(n-1) * 3
return x >= 0 && (n >= 22 || x < 1<<octBits)
}
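// Worked example (illustrative, not used by the package): an n-byte octal
// field reserves one byte for the NUL terminator, so the 12-byte size field
// tops out just below 8GiB, which is why larger files require PAX or GNU
// base-256 encoding.
//
//	fitsInOctal(12, 1<<33-1)  // true
//	fitsInOctal(12, 1<<33)    // false
//	fitsInOctal(8, 0o7777777) // true (e.g. the 8-byte mode field)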
// parsePAXTime takes a string of the form %d.%d as described in the PAX
// specification. Note that this implementation allows for negative timestamps,
// which the PAX specification permits, but which is not always portable.
func parsePAXTime(s string) (time.Time, error) {
const maxNanoSecondDigits = 9
// Split string into seconds and sub-seconds parts.
ss, sn, _ := strings.Cut(s, ".")
// Parse the seconds.
secs, err := strconv.ParseInt(ss, 10, 64)
if err != nil {
return time.Time{}, ErrHeader
}
if len(sn) == 0 {
return time.Unix(secs, 0), nil // No sub-second values
}
// Parse the nanoseconds.
if strings.Trim(sn, "0123456789") != "" {
return time.Time{}, ErrHeader
}
if len(sn) < maxNanoSecondDigits {
sn += strings.Repeat("0", maxNanoSecondDigits-len(sn)) // Right pad
} else {
sn = sn[:maxNanoSecondDigits] // Right truncate
}
nsecs, _ := strconv.ParseInt(sn, 10, 64) // Must succeed
if len(ss) > 0 && ss[0] == '-' {
return time.Unix(secs, -1*nsecs), nil // Negative correction
}
return time.Unix(secs, nsecs), nil
}
// formatPAXTime converts ts into a time of the form %d.%d as described in the
// PAX specification. This function handles negative timestamps.
func formatPAXTime(ts time.Time) (s string) {
secs, nsecs := ts.Unix(), ts.Nanosecond()
if nsecs == 0 {
return strconv.FormatInt(secs, 10)
}
// If seconds is negative, then perform correction.
sign := ""
if secs < 0 {
sign = "-" // Remember sign
secs = -(secs + 1) // Add a second to secs
nsecs = -(nsecs - 1e9) // Take that second away from nsecs
}
return strings.TrimRight(fmt.Sprintf("%s%d.%09d", sign, secs, nsecs), "0")
}
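// Worked example (illustrative, not used by the package): how the %d.%d form
// maps to and from a time.Time.
//
//	t, _ := parsePAXTime("1350244992.023960108") // time.Unix(1350244992, 23960108)
//	s := formatPAXTime(t)                        // "1350244992.023960108"
//	t2, _ := parsePAXTime("1350244992.5")        // sub-seconds right-padded: 500000000ns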
// parsePAXRecord parses the input PAX record string into a key-value pair.
// If parsing is successful, it will slice off the currently read record and
// return the remainder as r.
func parsePAXRecord(s string) (k, v, r string, err error) {
// The size field ends at the first space.
nStr, rest, ok := strings.Cut(s, " ")
if !ok {
return "", "", s, ErrHeader
}
// Parse the first token as a decimal integer.
n, perr := strconv.ParseInt(nStr, 10, 0) // Intentionally parse as native int
if perr != nil || n < 5 || n > int64(len(s)) {
return "", "", s, ErrHeader
}
n -= int64(len(nStr) + 1) // convert from index in s to index in rest
if n <= 0 {
return "", "", s, ErrHeader
}
// Extract everything between the space and the final newline.
rec, nl, rem := rest[:n-1], rest[n-1:n], rest[n:]
if nl != "\n" {
return "", "", s, ErrHeader
}
// The first equals separates the key from the value.
k, v, ok = strings.Cut(rec, "=")
if !ok {
return "", "", s, ErrHeader
}
if !validPAXRecord(k, v) {
return "", "", s, ErrHeader
}
return k, v, rem, nil
}
// formatPAXRecord formats a single PAX record, prefixing it with the
// appropriate length.
func formatPAXRecord(k, v string) (string, error) {
if !validPAXRecord(k, v) {
return "", ErrHeader
}
const padding = 3 // Extra padding for ' ', '=', and '\n'
size := len(k) + len(v) + padding
size += len(strconv.Itoa(size))
record := strconv.Itoa(size) + " " + k + "=" + v + "\n"
// Final adjustment if adding size field increased the record size.
if len(record) != size {
size = len(record)
record = strconv.Itoa(size) + " " + k + "=" + v + "\n"
}
return record, nil
}
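// Worked example (illustrative, not used by the package): a PAX record's
// leading number is the length of the entire record, including the length
// field itself and the trailing newline.
//
//	rec, _ := formatPAXRecord("mtime", "1350244992.023960108")
//	// rec == "30 mtime=1350244992.023960108\n"
//	k, v, rest, _ := parsePAXRecord(rec)
//	// k == "mtime", v == "1350244992.023960108", rest == ""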
// validPAXRecord reports whether the key-value pair is valid where each
// record is formatted as:
//
// "%d %s=%s\n" % (size, key, value)
//
// Keys and values should be UTF-8, but the number of bad writers out there
// forces us to be more liberal.
// Thus, we reject any key containing NUL, and reject NULs in values only
// for the PAX version of the USTAR string fields.
// The key must not contain an '=' character.
func validPAXRecord(k, v string) bool {
if k == "" || strings.Contains(k, "=") {
return false
}
switch k {
case paxPath, paxLinkpath, paxUname, paxGname:
return !hasNUL(v)
default:
return !hasNUL(k)
}
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tar
import (
"errors"
"fmt"
"io"
"io/fs"
"maps"
"path"
"slices"
"strings"
"time"
)
// Writer provides sequential writing of a tar archive.
// [Writer.WriteHeader] begins a new file with the provided [Header],
// and then Writer can be treated as an io.Writer to supply that file's data.
type Writer struct {
w io.Writer
pad int64 // Amount of padding to write after current file entry
curr fileWriter // Writer for current file entry
hdr Header // Shallow copy of Header that is safe for mutations
blk block // Buffer to use as temporary local storage
// err is a persistent error.
// Only the exported methods of Writer are responsible for
// ensuring that this error is sticky.
err error
}
// NewWriter creates a new Writer writing to w.
func NewWriter(w io.Writer) *Writer {
return &Writer{w: w, curr: &regFileWriter{w, 0}}
}
type fileWriter interface {
io.Writer
fileState
ReadFrom(io.Reader) (int64, error)
}
// Flush finishes writing the current file's block padding.
// The current file must be fully written before Flush can be called.
//
// This is unnecessary as the next call to [Writer.WriteHeader] or [Writer.Close]
// will implicitly flush out the file's padding.
func (tw *Writer) Flush() error {
if tw.err != nil {
return tw.err
}
if nb := tw.curr.logicalRemaining(); nb > 0 {
return fmt.Errorf("archive/tar: missed writing %d bytes", nb)
}
if _, tw.err = tw.w.Write(zeroBlock[:tw.pad]); tw.err != nil {
return tw.err
}
tw.pad = 0
return nil
}
// WriteHeader writes hdr and prepares to accept the file's contents.
// The Header.Size determines how many bytes can be written for the next file.
// If the current file is not fully written, then this returns an error.
// This implicitly flushes any padding necessary before writing the header.
func (tw *Writer) WriteHeader(hdr *Header) error {
if err := tw.Flush(); err != nil {
return err
}
tw.hdr = *hdr // Shallow copy of Header
// Avoid usage of the legacy TypeRegA flag, and automatically promote
// it to use TypeReg or TypeDir.
if tw.hdr.Typeflag == TypeRegA {
if strings.HasSuffix(tw.hdr.Name, "/") {
tw.hdr.Typeflag = TypeDir
} else {
tw.hdr.Typeflag = TypeReg
}
}
// Round ModTime and ignore AccessTime and ChangeTime unless
// the format is explicitly chosen.
// This ensures nominal usage of WriteHeader (without specifying the format)
// does not always result in the PAX format being chosen, which
// causes a 1KiB increase to every header.
if tw.hdr.Format == FormatUnknown {
tw.hdr.ModTime = tw.hdr.ModTime.Round(time.Second)
tw.hdr.AccessTime = time.Time{}
tw.hdr.ChangeTime = time.Time{}
}
allowedFormats, paxHdrs, err := tw.hdr.allowedFormats()
switch {
case allowedFormats.has(FormatUSTAR):
tw.err = tw.writeUSTARHeader(&tw.hdr)
return tw.err
case allowedFormats.has(FormatPAX):
tw.err = tw.writePAXHeader(&tw.hdr, paxHdrs)
return tw.err
case allowedFormats.has(FormatGNU):
tw.err = tw.writeGNUHeader(&tw.hdr)
return tw.err
default:
return err // Non-fatal error
}
}
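// Illustrative caller-side sketch (not part of this file): the usual
// WriteHeader/Write/Close sequence. Assumes the caller imports archive/tar
// and io; the file name and contents are made up for the example.
//
//	func writeOneFile(w io.Writer) error {
//		tw := tar.NewWriter(w)
//		body := "hello world\n"
//		hdr := &tar.Header{
//			Name: "hello.txt",
//			Mode: 0o644,
//			Size: int64(len(body)),
//		}
//		if err := tw.WriteHeader(hdr); err != nil {
//			return err
//		}
//		if _, err := io.WriteString(tw, body); err != nil {
//			return err
//		}
//		return tw.Close() // flushes padding and writes the two-block trailer
//	}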
func (tw *Writer) writeUSTARHeader(hdr *Header) error {
// Check if we can use USTAR prefix/suffix splitting.
var namePrefix string
if prefix, suffix, ok := splitUSTARPath(hdr.Name); ok {
namePrefix, hdr.Name = prefix, suffix
}
// Pack the main header.
var f formatter
blk := tw.templateV7Plus(hdr, f.formatString, f.formatOctal)
f.formatString(blk.toUSTAR().prefix(), namePrefix)
blk.setFormat(FormatUSTAR)
if f.err != nil {
return f.err // Should never happen since header is validated
}
return tw.writeRawHeader(blk, hdr.Size, hdr.Typeflag)
}
func (tw *Writer) writePAXHeader(hdr *Header, paxHdrs map[string]string) error {
realName, realSize := hdr.Name, hdr.Size
// TODO(dsnet): Re-enable this when adding sparse support.
// See https://golang.org/issue/22735
/*
// Handle sparse files.
var spd sparseDatas
var spb []byte
if len(hdr.SparseHoles) > 0 {
sph := append([]sparseEntry{}, hdr.SparseHoles...) // Copy sparse map
sph = alignSparseEntries(sph, hdr.Size)
spd = invertSparseEntries(sph, hdr.Size)
// Format the sparse map.
hdr.Size = 0 // Replace with encoded size
spb = append(strconv.AppendInt(spb, int64(len(spd)), 10), '\n')
for _, s := range spd {
hdr.Size += s.Length
spb = append(strconv.AppendInt(spb, s.Offset, 10), '\n')
spb = append(strconv.AppendInt(spb, s.Length, 10), '\n')
}
pad := blockPadding(int64(len(spb)))
spb = append(spb, zeroBlock[:pad]...)
hdr.Size += int64(len(spb)) // Accounts for encoded sparse map
// Add and modify appropriate PAX records.
dir, file := path.Split(realName)
hdr.Name = path.Join(dir, "GNUSparseFile.0", file)
paxHdrs[paxGNUSparseMajor] = "1"
paxHdrs[paxGNUSparseMinor] = "0"
paxHdrs[paxGNUSparseName] = realName
paxHdrs[paxGNUSparseRealSize] = strconv.FormatInt(realSize, 10)
paxHdrs[paxSize] = strconv.FormatInt(hdr.Size, 10)
delete(paxHdrs, paxPath) // Recorded by paxGNUSparseName
}
*/
_ = realSize
// Write PAX records to the output.
isGlobal := hdr.Typeflag == TypeXGlobalHeader
if len(paxHdrs) > 0 || isGlobal {
// Write each record to a buffer.
var buf strings.Builder
// Sort keys for deterministic ordering.
for _, k := range slices.Sorted(maps.Keys(paxHdrs)) {
rec, err := formatPAXRecord(k, paxHdrs[k])
if err != nil {
return err
}
buf.WriteString(rec)
}
// Write the extended header file.
var name string
var flag byte
if isGlobal {
name = realName
if name == "" {
name = "GlobalHead.0.0"
}
flag = TypeXGlobalHeader
} else {
dir, file := path.Split(realName)
name = path.Join(dir, "PaxHeaders.0", file)
flag = TypeXHeader
}
data := buf.String()
if len(data) > maxSpecialFileSize {
return ErrFieldTooLong
}
if err := tw.writeRawFile(name, data, flag, FormatPAX); err != nil || isGlobal {
return err // Global headers return here
}
}
// Pack the main header.
var f formatter // Ignore errors since they are expected
fmtStr := func(b []byte, s string) { f.formatString(b, toASCII(s)) }
blk := tw.templateV7Plus(hdr, fmtStr, f.formatOctal)
blk.setFormat(FormatPAX)
if err := tw.writeRawHeader(blk, hdr.Size, hdr.Typeflag); err != nil {
return err
}
// TODO(dsnet): Re-enable this when adding sparse support.
// See https://golang.org/issue/22735
/*
// Write the sparse map and setup the sparse writer if necessary.
if len(spd) > 0 {
// Use tw.curr since the sparse map is accounted for in hdr.Size.
if _, err := tw.curr.Write(spb); err != nil {
return err
}
tw.curr = &sparseFileWriter{tw.curr, spd, 0}
}
*/
return nil
}
func (tw *Writer) writeGNUHeader(hdr *Header) error {
// Use long-link files if Name or Linkname exceeds the field size.
const longName = "././@LongLink"
if len(hdr.Name) > nameSize {
data := hdr.Name + "\x00"
if err := tw.writeRawFile(longName, data, TypeGNULongName, FormatGNU); err != nil {
return err
}
}
if len(hdr.Linkname) > nameSize {
data := hdr.Linkname + "\x00"
if err := tw.writeRawFile(longName, data, TypeGNULongLink, FormatGNU); err != nil {
return err
}
}
// Pack the main header.
var f formatter // Ignore errors since they are expected
var spd sparseDatas
var spb []byte
blk := tw.templateV7Plus(hdr, f.formatString, f.formatNumeric)
if !hdr.AccessTime.IsZero() {
f.formatNumeric(blk.toGNU().accessTime(), hdr.AccessTime.Unix())
}
if !hdr.ChangeTime.IsZero() {
f.formatNumeric(blk.toGNU().changeTime(), hdr.ChangeTime.Unix())
}
// TODO(dsnet): Re-enable this when adding sparse support.
// See https://golang.org/issue/22735
/*
if hdr.Typeflag == TypeGNUSparse {
sph := append([]sparseEntry{}, hdr.SparseHoles...) // Copy sparse map
sph = alignSparseEntries(sph, hdr.Size)
spd = invertSparseEntries(sph, hdr.Size)
// Format the sparse map.
formatSPD := func(sp sparseDatas, sa sparseArray) sparseDatas {
for i := 0; len(sp) > 0 && i < sa.MaxEntries(); i++ {
f.formatNumeric(sa.Entry(i).Offset(), sp[0].Offset)
f.formatNumeric(sa.Entry(i).Length(), sp[0].Length)
sp = sp[1:]
}
if len(sp) > 0 {
sa.IsExtended()[0] = 1
}
return sp
}
sp2 := formatSPD(spd, blk.GNU().Sparse())
for len(sp2) > 0 {
var spHdr block
sp2 = formatSPD(sp2, spHdr.Sparse())
spb = append(spb, spHdr[:]...)
}
// Update size fields in the header block.
realSize := hdr.Size
hdr.Size = 0 // Encoded size; does not account for encoded sparse map
for _, s := range spd {
hdr.Size += s.Length
}
copy(blk.V7().Size(), zeroBlock[:]) // Reset field
f.formatNumeric(blk.V7().Size(), hdr.Size)
f.formatNumeric(blk.GNU().RealSize(), realSize)
}
*/
blk.setFormat(FormatGNU)
if err := tw.writeRawHeader(blk, hdr.Size, hdr.Typeflag); err != nil {
return err
}
// Write the extended sparse map and set up the sparse writer if necessary.
if len(spd) > 0 {
// Use tw.w since the sparse map is not accounted for in hdr.Size.
if _, err := tw.w.Write(spb); err != nil {
return err
}
tw.curr = &sparseFileWriter{tw.curr, spd, 0}
}
return nil
}
type (
stringFormatter func([]byte, string)
numberFormatter func([]byte, int64)
)
// templateV7Plus fills out the V7 fields of a block using values from hdr.
// It also fills out fields (uname, gname, devmajor, devminor) that are
// shared in the USTAR, PAX, and GNU formats using the provided formatters.
//
// The block returned is only valid until the next call to
// templateV7Plus or writeRawFile.
func (tw *Writer) templateV7Plus(hdr *Header, fmtStr stringFormatter, fmtNum numberFormatter) *block {
tw.blk.reset()
modTime := hdr.ModTime
if modTime.IsZero() {
modTime = time.Unix(0, 0)
}
v7 := tw.blk.toV7()
v7.typeFlag()[0] = hdr.Typeflag
fmtStr(v7.name(), hdr.Name)
fmtStr(v7.linkName(), hdr.Linkname)
fmtNum(v7.mode(), hdr.Mode)
fmtNum(v7.uid(), int64(hdr.Uid))
fmtNum(v7.gid(), int64(hdr.Gid))
fmtNum(v7.size(), hdr.Size)
fmtNum(v7.modTime(), modTime.Unix())
ustar := tw.blk.toUSTAR()
fmtStr(ustar.userName(), hdr.Uname)
fmtStr(ustar.groupName(), hdr.Gname)
fmtNum(ustar.devMajor(), hdr.Devmajor)
fmtNum(ustar.devMinor(), hdr.Devminor)
return &tw.blk
}
// writeRawFile writes a minimal file with the given name and flag type.
// It uses format to encode the header format and will write data as the body.
// It uses default values for all of the other fields (as BSD and GNU tar do).
func (tw *Writer) writeRawFile(name, data string, flag byte, format Format) error {
tw.blk.reset()
// Best effort for the filename.
name = toASCII(name)
if len(name) > nameSize {
name = name[:nameSize]
}
name = strings.TrimRight(name, "/")
var f formatter
v7 := tw.blk.toV7()
v7.typeFlag()[0] = flag
f.formatString(v7.name(), name)
f.formatOctal(v7.mode(), 0)
f.formatOctal(v7.uid(), 0)
f.formatOctal(v7.gid(), 0)
f.formatOctal(v7.size(), int64(len(data))) // Must be < 8GiB
f.formatOctal(v7.modTime(), 0)
tw.blk.setFormat(format)
if f.err != nil {
return f.err // Only occurs if size condition is violated
}
// Write the header and data.
if err := tw.writeRawHeader(&tw.blk, int64(len(data)), flag); err != nil {
return err
}
_, err := io.WriteString(tw, data)
return err
}
// writeRawHeader writes blk as-is, regardless of whether its contents are valid.
// It sets up the Writer such that it can accept a file of the given size.
// If the flag is a special header-only flag, then the size is treated as zero.
func (tw *Writer) writeRawHeader(blk *block, size int64, flag byte) error {
if err := tw.Flush(); err != nil {
return err
}
if _, err := tw.w.Write(blk[:]); err != nil {
return err
}
if isHeaderOnlyType(flag) {
size = 0
}
tw.curr = &regFileWriter{tw.w, size}
tw.pad = blockPadding(size)
return nil
}
// AddFS adds the files from fs.FS to the archive.
// It walks the directory tree starting at the root of the filesystem,
// adding each file to the tar archive while maintaining the directory structure.
func (tw *Writer) AddFS(fsys fs.FS) error {
return fs.WalkDir(fsys, ".", func(name string, d fs.DirEntry, err error) error {
if err != nil {
return err
}
if name == "." {
return nil
}
info, err := d.Info()
if err != nil {
return err
}
linkTarget := ""
if typ := d.Type(); typ == fs.ModeSymlink {
var err error
linkTarget, err = fs.ReadLink(fsys, name)
if err != nil {
return err
}
} else if !typ.IsRegular() && typ != fs.ModeDir {
return errors.New("tar: cannot add non-regular file")
}
h, err := FileInfoHeader(info, linkTarget)
if err != nil {
return err
}
h.Name = name
if d.IsDir() {
h.Name += "/"
}
if err := tw.WriteHeader(h); err != nil {
return err
}
if !d.Type().IsRegular() {
return nil
}
f, err := fsys.Open(name)
if err != nil {
return err
}
defer f.Close()
_, err = io.Copy(tw, f)
return err
})
}
// splitUSTARPath splits a path according to USTAR prefix and suffix rules.
// If the path is not splittable, then it will return ("", "", false).
func splitUSTARPath(name string) (prefix, suffix string, ok bool) {
length := len(name)
if length <= nameSize || !isASCII(name) {
return "", "", false
} else if length > prefixSize+1 {
length = prefixSize + 1
} else if name[length-1] == '/' {
length--
}
i := strings.LastIndex(name[:length], "/")
nlen := len(name) - i - 1 // nlen is length of suffix
plen := i // plen is length of prefix
if i <= 0 || nlen > nameSize || nlen == 0 || plen > prefixSize {
return "", "", false
}
return name[:i], name[i+1:], true
}
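// Worked example (illustrative, not used by the package): a 111-byte ASCII
// path with a slash at byte 90 splits into a 90-byte prefix and a 20-byte
// suffix, both of which fit the USTAR prefix (155 bytes) and name (100 bytes)
// fields.
//
//	long := strings.Repeat("d", 90) + "/" + strings.Repeat("f", 20)
//	prefix, suffix, ok := splitUSTARPath(long) // ok == true
//	// len(prefix) == 90, len(suffix) == 20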
// Write writes to the current file in the tar archive.
// Write returns the error [ErrWriteTooLong] if more than
// Header.Size bytes are written after [Writer.WriteHeader].
//
// Calling Write on special types like [TypeLink], [TypeSymlink], [TypeChar],
// [TypeBlock], [TypeDir], and [TypeFifo] returns (0, [ErrWriteTooLong]) regardless
// of what the [Header.Size] claims.
func (tw *Writer) Write(b []byte) (int, error) {
if tw.err != nil {
return 0, tw.err
}
n, err := tw.curr.Write(b)
if err != nil && err != ErrWriteTooLong {
tw.err = err
}
return n, err
}
// readFrom populates the content of the current file by reading from r.
// The bytes read must match the number of remaining bytes in the current file.
//
// If the current file is sparse and r is an io.ReadSeeker,
// then readFrom uses Seek to skip past holes defined in Header.SparseHoles,
// assuming that skipped regions are all NULs.
// This always reads the last byte to ensure r is the right size.
//
// TODO(dsnet): Re-export this when adding sparse file support.
// See https://golang.org/issue/22735
func (tw *Writer) readFrom(r io.Reader) (int64, error) {
if tw.err != nil {
return 0, tw.err
}
n, err := tw.curr.ReadFrom(r)
if err != nil && err != ErrWriteTooLong {
tw.err = err
}
return n, err
}
// Close closes the tar archive by flushing the padding, and writing the footer.
// If the current file (from a prior call to [Writer.WriteHeader]) is not fully written,
// then this returns an error.
func (tw *Writer) Close() error {
if tw.err == ErrWriteAfterClose {
return nil
}
if tw.err != nil {
return tw.err
}
// Trailer: two zero blocks.
err := tw.Flush()
for i := 0; i < 2 && err == nil; i++ {
_, err = tw.w.Write(zeroBlock[:])
}
// Ensure all future actions are invalid.
tw.err = ErrWriteAfterClose
return err // Report IO errors
}
// regFileWriter is a fileWriter for writing data to a regular file entry.
type regFileWriter struct {
w io.Writer // Underlying Writer
nb int64 // Number of remaining bytes to write
}
func (fw *regFileWriter) Write(b []byte) (n int, err error) {
overwrite := int64(len(b)) > fw.nb
if overwrite {
b = b[:fw.nb]
}
if len(b) > 0 {
n, err = fw.w.Write(b)
fw.nb -= int64(n)
}
switch {
case err != nil:
return n, err
case overwrite:
return n, ErrWriteTooLong
default:
return n, nil
}
}
func (fw *regFileWriter) ReadFrom(r io.Reader) (int64, error) {
return io.Copy(struct{ io.Writer }{fw}, r)
}
// logicalRemaining implements fileState.logicalRemaining.
func (fw regFileWriter) logicalRemaining() int64 {
return fw.nb
}
// physicalRemaining implements fileState.physicalRemaining.
func (fw regFileWriter) physicalRemaining() int64 {
return fw.nb
}
// sparseFileWriter is a fileWriter for writing data to a sparse file entry.
type sparseFileWriter struct {
fw fileWriter // Underlying fileWriter
sp sparseDatas // Normalized list of data fragments
pos int64 // Current position in sparse file
}
func (sw *sparseFileWriter) Write(b []byte) (n int, err error) {
overwrite := int64(len(b)) > sw.logicalRemaining()
if overwrite {
b = b[:sw.logicalRemaining()]
}
b0 := b
endPos := sw.pos + int64(len(b))
for endPos > sw.pos && err == nil {
var nf int // Bytes written in fragment
dataStart, dataEnd := sw.sp[0].Offset, sw.sp[0].endOffset()
if sw.pos < dataStart { // In a hole fragment
bf := b[:min(int64(len(b)), dataStart-sw.pos)]
nf, err = zeroWriter{}.Write(bf)
} else { // In a data fragment
bf := b[:min(int64(len(b)), dataEnd-sw.pos)]
nf, err = sw.fw.Write(bf)
}
b = b[nf:]
sw.pos += int64(nf)
if sw.pos >= dataEnd && len(sw.sp) > 1 {
sw.sp = sw.sp[1:] // Ensure last fragment always remains
}
}
n = len(b0) - len(b)
switch {
case err == ErrWriteTooLong:
return n, errMissData // Not possible; implies bug in validation logic
case err != nil:
return n, err
case sw.logicalRemaining() == 0 && sw.physicalRemaining() > 0:
return n, errUnrefData // Not possible; implies bug in validation logic
case overwrite:
return n, ErrWriteTooLong
default:
return n, nil
}
}
func (sw *sparseFileWriter) ReadFrom(r io.Reader) (n int64, err error) {
rs, ok := r.(io.ReadSeeker)
if ok {
if _, err := rs.Seek(0, io.SeekCurrent); err != nil {
ok = false // Not all io.Seeker can really seek
}
}
if !ok {
return io.Copy(struct{ io.Writer }{sw}, r)
}
var readLastByte bool
pos0 := sw.pos
for sw.logicalRemaining() > 0 && !readLastByte && err == nil {
var nf int64 // Size of fragment
dataStart, dataEnd := sw.sp[0].Offset, sw.sp[0].endOffset()
if sw.pos < dataStart { // In a hole fragment
nf = dataStart - sw.pos
if sw.physicalRemaining() == 0 {
readLastByte = true
nf--
}
_, err = rs.Seek(nf, io.SeekCurrent)
} else { // In a data fragment
nf = dataEnd - sw.pos
nf, err = io.CopyN(sw.fw, rs, nf)
}
sw.pos += nf
if sw.pos >= dataEnd && len(sw.sp) > 1 {
sw.sp = sw.sp[1:] // Ensure last fragment always remains
}
}
// If the last fragment is a hole, then seek to 1-byte before EOF, and
// read a single byte to ensure the file is the right size.
if readLastByte && err == nil {
_, err = mustReadFull(rs, []byte{0})
sw.pos++
}
n = sw.pos - pos0
switch {
case err == io.EOF:
return n, io.ErrUnexpectedEOF
case err == ErrWriteTooLong:
return n, errMissData // Not possible; implies bug in validation logic
case err != nil:
return n, err
case sw.logicalRemaining() == 0 && sw.physicalRemaining() > 0:
return n, errUnrefData // Not possible; implies bug in validation logic
default:
return n, ensureEOF(rs)
}
}
func (sw sparseFileWriter) logicalRemaining() int64 {
return sw.sp[len(sw.sp)-1].endOffset() - sw.pos
}
func (sw sparseFileWriter) physicalRemaining() int64 {
return sw.fw.physicalRemaining()
}
// zeroWriter may only be written with NULs, otherwise it returns errWriteHole.
type zeroWriter struct{}
func (zeroWriter) Write(b []byte) (int, error) {
for i, c := range b {
if c != 0 {
return i, errWriteHole
}
}
return len(b), nil
}
// ensureEOF checks whether r is at EOF, reporting ErrWriteTooLong if not so.
func ensureEOF(r io.Reader) error {
n, err := tryReadFull(r, []byte{0})
switch {
case n > 0:
return ErrWriteTooLong
case err == io.EOF:
return nil
default:
return err
}
}
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package zip
import (
"bufio"
"encoding/binary"
"errors"
"fmt"
"hash"
"hash/crc32"
"internal/godebug"
"io"
"io/fs"
"os"
"path"
"path/filepath"
"slices"
"strings"
"sync"
"time"
)
var zipinsecurepath = godebug.New("zipinsecurepath")
var (
ErrFormat = errors.New("zip: not a valid zip file")
ErrAlgorithm = errors.New("zip: unsupported compression algorithm")
ErrChecksum = errors.New("zip: checksum error")
ErrInsecurePath = errors.New("zip: insecure file path")
)
// A Reader serves content from a ZIP archive.
type Reader struct {
r io.ReaderAt
File []*File
Comment string
decompressors map[uint16]Decompressor
// Some JAR files are zip files with a prefix that is a bash script.
// The baseOffset field is the start of the zip file proper.
baseOffset int64
// fileList is a list of files sorted by ename,
// for use by the Open method.
fileListOnce sync.Once
fileList []fileListEntry
}
// A ReadCloser is a [Reader] that must be closed when no longer needed.
type ReadCloser struct {
f *os.File
Reader
}
// A File is a single file in a ZIP archive.
// The file information is in the embedded [FileHeader].
// The file content can be accessed by calling [File.Open].
type File struct {
FileHeader
zip *Reader
zipr io.ReaderAt
headerOffset int64 // includes overall ZIP archive baseOffset
zip64 bool // zip64 extended information extra field presence
}
// OpenReader will open the Zip file specified by name and return a ReadCloser.
//
// If any file inside the archive uses a non-local name
// (as defined by [filepath.IsLocal]) or a name containing backslashes
// and the GODEBUG environment variable contains `zipinsecurepath=0`,
// OpenReader returns the reader with an ErrInsecurePath error.
// A future version of Go may introduce this behavior by default.
// Programs that want to accept non-local names can ignore
// the ErrInsecurePath error and use the returned reader.
func OpenReader(name string) (*ReadCloser, error) {
f, err := os.Open(name)
if err != nil {
return nil, err
}
fi, err := f.Stat()
if err != nil {
f.Close()
return nil, err
}
r := new(ReadCloser)
if err = r.init(f, fi.Size()); err != nil && err != ErrInsecurePath {
f.Close()
return nil, err
}
r.f = f
return r, err
}
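// Illustrative caller-side sketch (not part of this file): OpenReader can
// return both a usable ReadCloser and ErrInsecurePath, so callers that accept
// non-local names check for that error specifically. Assumes the caller
// imports archive/zip and errors; the function name is made up.
//
//	func openTolerant(name string) (*zip.ReadCloser, error) {
//		r, err := zip.OpenReader(name)
//		if err != nil && !errors.Is(err, zip.ErrInsecurePath) {
//			return nil, err
//		}
//		// On ErrInsecurePath, r is still valid; the caller decides whether
//		// to trust the non-local names inside the archive.
//		return r, nil
//	}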
// NewReader returns a new [Reader] reading from r, which is assumed to
// have the given size in bytes.
//
// If any file inside the archive uses a non-local name
// (as defined by [filepath.IsLocal]) or a name containing backslashes
// and the GODEBUG environment variable contains `zipinsecurepath=0`,
// NewReader returns the reader with an [ErrInsecurePath] error.
// A future version of Go may introduce this behavior by default.
// Programs that want to accept non-local names can ignore
// the [ErrInsecurePath] error and use the returned reader.
func NewReader(r io.ReaderAt, size int64) (*Reader, error) {
if size < 0 {
return nil, errors.New("zip: size cannot be negative")
}
zr := new(Reader)
var err error
if err = zr.init(r, size); err != nil && err != ErrInsecurePath {
return nil, err
}
return zr, err
}
func (r *Reader) init(rdr io.ReaderAt, size int64) error {
end, baseOffset, err := readDirectoryEnd(rdr, size)
if err != nil {
return err
}
r.r = rdr
r.baseOffset = baseOffset
// Since the number of directory records is not validated, it is not
// safe to preallocate r.File without first checking that the specified
// number of files is reasonable, since a malformed archive may
// indicate it contains up to 1 << 128 - 1 files. Since each file has a
// header which will be _at least_ 30 bytes we can safely preallocate
// if (data size / 30) >= end.directoryRecords.
if end.directorySize < uint64(size) && (uint64(size)-end.directorySize)/30 >= end.directoryRecords {
r.File = make([]*File, 0, end.directoryRecords)
}
r.Comment = end.comment
rs := io.NewSectionReader(rdr, 0, size)
if _, err = rs.Seek(r.baseOffset+int64(end.directoryOffset), io.SeekStart); err != nil {
return err
}
buf := bufio.NewReader(rs)
// The count of files inside a zip is truncated to fit in a uint16.
// Gloss over this by reading headers until we encounter
// a bad one, and then only report an ErrFormat or UnexpectedEOF if
// the file count modulo 65536 is incorrect.
for {
f := &File{zip: r, zipr: rdr}
err = readDirectoryHeader(f, buf)
if err == ErrFormat || err == io.ErrUnexpectedEOF {
break
}
if err != nil {
return err
}
f.headerOffset += r.baseOffset
r.File = append(r.File, f)
}
if uint16(len(r.File)) != uint16(end.directoryRecords) { // only compare 16 bits here
// Return the readDirectoryHeader error if we read
// the wrong number of directory entries.
return err
}
if zipinsecurepath.Value() == "0" {
for _, f := range r.File {
if f.Name == "" {
// Zip permits an empty file name field.
continue
}
// The zip specification states that names must use forward slashes,
// so consider any backslashes in the name insecure.
if !filepath.IsLocal(f.Name) || strings.Contains(f.Name, `\`) {
zipinsecurepath.IncNonDefault()
return ErrInsecurePath
}
}
}
return nil
}
// RegisterDecompressor registers or overrides a custom decompressor for a
// specific method ID. If a decompressor for a given method is not found,
// [Reader] will default to looking up the decompressor at the package level.
func (r *Reader) RegisterDecompressor(method uint16, dcomp Decompressor) {
if r.decompressors == nil {
r.decompressors = make(map[uint16]Decompressor)
}
r.decompressors[method] = dcomp
}
func (r *Reader) decompressor(method uint16) Decompressor {
dcomp := r.decompressors[method]
if dcomp == nil {
dcomp = decompressor(method)
}
return dcomp
}
// Close closes the Zip file, rendering it unusable for I/O.
func (rc *ReadCloser) Close() error {
return rc.f.Close()
}
// DataOffset returns the offset of the file's possibly-compressed
// data, relative to the beginning of the zip file.
//
// Most callers should instead use [File.Open], which transparently
// decompresses data and verifies checksums.
func (f *File) DataOffset() (offset int64, err error) {
bodyOffset, err := f.findBodyOffset()
if err != nil {
return
}
return f.headerOffset + bodyOffset, nil
}
// Open returns a [ReadCloser] that provides access to the [File]'s contents.
// Multiple files may be read concurrently.
func (f *File) Open() (io.ReadCloser, error) {
bodyOffset, err := f.findBodyOffset()
if err != nil {
return nil, err
}
if strings.HasSuffix(f.Name, "/") {
// The ZIP specification (APPNOTE.TXT) specifies that directories, which
// are technically zero-byte files, must not have any associated file
// data. We previously tried failing here if f.CompressedSize64 != 0,
// but it turns out that a number of implementations (namely, the Java
// jar tool) don't properly set the storage method on directories
// resulting in a file with compressed size > 0 but uncompressed size ==
// 0. We still want to fail when a directory has associated uncompressed
// data, but we are tolerant of cases where the uncompressed size is
// zero but compressed size is not.
if f.UncompressedSize64 != 0 {
return &dirReader{ErrFormat}, nil
} else {
return &dirReader{io.EOF}, nil
}
}
size := int64(f.CompressedSize64)
r := io.NewSectionReader(f.zipr, f.headerOffset+bodyOffset, size)
dcomp := f.zip.decompressor(f.Method)
if dcomp == nil {
return nil, ErrAlgorithm
}
var rc io.ReadCloser = dcomp(r)
var desr io.Reader
if f.hasDataDescriptor() {
desr = io.NewSectionReader(f.zipr, f.headerOffset+bodyOffset+size, dataDescriptorLen)
}
rc = &checksumReader{
rc: rc,
hash: crc32.NewIEEE(),
f: f,
desr: desr,
}
return rc, nil
}
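// Illustrative caller-side sketch (not part of this file): reading every file
// in an already-opened Reader. The CRC32 check happens inside the returned
// ReadCloser as the body is consumed. Assumes the caller imports archive/zip
// and io; the function name is made up.
//
//	func drainAll(zr *zip.Reader) error {
//		for _, f := range zr.File {
//			rc, err := f.Open()
//			if err != nil {
//				return err
//			}
//			_, err = io.Copy(io.Discard, rc) // checksum verified during the copy
//			rc.Close()
//			if err != nil {
//				return err
//			}
//		}
//		return nil
//	}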
// OpenRaw returns a [Reader] that provides access to the [File]'s contents without
// decompression.
func (f *File) OpenRaw() (io.Reader, error) {
bodyOffset, err := f.findBodyOffset()
if err != nil {
return nil, err
}
r := io.NewSectionReader(f.zipr, f.headerOffset+bodyOffset, int64(f.CompressedSize64))
return r, nil
}
type dirReader struct {
err error
}
func (r *dirReader) Read([]byte) (int, error) {
return 0, r.err
}
func (r *dirReader) Close() error {
return nil
}
type checksumReader struct {
rc io.ReadCloser
hash hash.Hash32
nread uint64 // number of bytes read so far
f *File
desr io.Reader // if non-nil, where to read the data descriptor
err error // sticky error
}
func (r *checksumReader) Stat() (fs.FileInfo, error) {
return headerFileInfo{&r.f.FileHeader}, nil
}
func (r *checksumReader) Read(b []byte) (n int, err error) {
if r.err != nil {
return 0, r.err
}
n, err = r.rc.Read(b)
r.hash.Write(b[:n])
r.nread += uint64(n)
if r.nread > r.f.UncompressedSize64 {
return 0, ErrFormat
}
if err == nil {
return
}
if err == io.EOF {
if r.nread != r.f.UncompressedSize64 {
return 0, io.ErrUnexpectedEOF
}
if r.desr != nil {
if err1 := readDataDescriptor(r.desr, r.f); err1 != nil {
if err1 == io.EOF {
err = io.ErrUnexpectedEOF
} else {
err = err1
}
} else if r.hash.Sum32() != r.f.CRC32 {
err = ErrChecksum
}
} else {
// If there's not a data descriptor, we still compare
// the CRC32 of what we've read against the file header
// or TOC's CRC32, if it seems like it was set.
if r.f.CRC32 != 0 && r.hash.Sum32() != r.f.CRC32 {
err = ErrChecksum
}
}
}
r.err = err
return
}
func (r *checksumReader) Close() error { return r.rc.Close() }
// findBodyOffset does the minimum work to verify the file has a header
// and returns the file body offset.
func (f *File) findBodyOffset() (int64, error) {
var buf [fileHeaderLen]byte
if _, err := f.zipr.ReadAt(buf[:], f.headerOffset); err != nil {
return 0, err
}
b := readBuf(buf[:])
if sig := b.uint32(); sig != fileHeaderSignature {
return 0, ErrFormat
}
b = b[22:] // skip over most of the header
filenameLen := int(b.uint16())
extraLen := int(b.uint16())
return int64(fileHeaderLen + filenameLen + extraLen), nil
}
// readDirectoryHeader attempts to read a directory header from r.
// It returns io.ErrUnexpectedEOF if it cannot read a complete header,
// and ErrFormat if it doesn't find a valid header signature.
func readDirectoryHeader(f *File, r io.Reader) error {
var buf [directoryHeaderLen]byte
if _, err := io.ReadFull(r, buf[:]); err != nil {
return err
}
b := readBuf(buf[:])
if sig := b.uint32(); sig != directoryHeaderSignature {
return ErrFormat
}
f.CreatorVersion = b.uint16()
f.ReaderVersion = b.uint16()
f.Flags = b.uint16()
f.Method = b.uint16()
f.ModifiedTime = b.uint16()
f.ModifiedDate = b.uint16()
f.CRC32 = b.uint32()
f.CompressedSize = b.uint32()
f.UncompressedSize = b.uint32()
f.CompressedSize64 = uint64(f.CompressedSize)
f.UncompressedSize64 = uint64(f.UncompressedSize)
filenameLen := int(b.uint16())
extraLen := int(b.uint16())
commentLen := int(b.uint16())
b = b[4:] // skipped start disk number and internal attributes (2x uint16)
f.ExternalAttrs = b.uint32()
f.headerOffset = int64(b.uint32())
d := make([]byte, filenameLen+extraLen+commentLen)
if _, err := io.ReadFull(r, d); err != nil {
return err
}
f.Name = string(d[:filenameLen])
f.Extra = d[filenameLen : filenameLen+extraLen]
f.Comment = string(d[filenameLen+extraLen:])
// Determine the character encoding.
utf8Valid1, utf8Require1 := detectUTF8(f.Name)
utf8Valid2, utf8Require2 := detectUTF8(f.Comment)
switch {
case !utf8Valid1 || !utf8Valid2:
// Name and Comment definitely not UTF-8.
f.NonUTF8 = true
case !utf8Require1 && !utf8Require2:
// Name and Comment use only single-byte runes that overlap with UTF-8.
f.NonUTF8 = false
default:
// Might be UTF-8, might be some other encoding; preserve existing flag.
// Some ZIP writers use UTF-8 encoding without setting the UTF-8 flag.
// Since it is impossible to always distinguish valid UTF-8 from some
// other encoding (e.g., GBK or Shift-JIS), we trust the flag.
f.NonUTF8 = f.Flags&0x800 == 0
}
needUSize := f.UncompressedSize == ^uint32(0)
needCSize := f.CompressedSize == ^uint32(0)
needHeaderOffset := f.headerOffset == int64(^uint32(0))
// Best effort to find what we need.
// Other zip authors might not even follow the basic format,
// and we'll just ignore the Extra content in that case.
var modified time.Time
parseExtras:
for extra := readBuf(f.Extra); len(extra) >= 4; { // need at least tag and size
fieldTag := extra.uint16()
fieldSize := int(extra.uint16())
if len(extra) < fieldSize {
break
}
fieldBuf := extra.sub(fieldSize)
switch fieldTag {
case zip64ExtraID:
f.zip64 = true
// update directory values from the zip64 extra block.
// They should only be consulted if the sizes read earlier
// are maxed out.
// See golang.org/issue/13367.
if needUSize {
needUSize = false
if len(fieldBuf) < 8 {
return ErrFormat
}
f.UncompressedSize64 = fieldBuf.uint64()
}
if needCSize {
needCSize = false
if len(fieldBuf) < 8 {
return ErrFormat
}
f.CompressedSize64 = fieldBuf.uint64()
}
if needHeaderOffset {
needHeaderOffset = false
if len(fieldBuf) < 8 {
return ErrFormat
}
f.headerOffset = int64(fieldBuf.uint64())
}
case ntfsExtraID:
if len(fieldBuf) < 4 {
continue parseExtras
}
fieldBuf.uint32() // reserved (ignored)
for len(fieldBuf) >= 4 { // need at least tag and size
attrTag := fieldBuf.uint16()
attrSize := int(fieldBuf.uint16())
if len(fieldBuf) < attrSize {
continue parseExtras
}
attrBuf := fieldBuf.sub(attrSize)
if attrTag != 1 || attrSize != 24 {
continue // Ignore irrelevant attributes
}
const ticksPerSecond = 1e7 // Windows timestamp resolution
ts := int64(attrBuf.uint64()) // ModTime since Windows epoch
secs := ts / ticksPerSecond
nsecs := (1e9 / ticksPerSecond) * (ts % ticksPerSecond)
epoch := time.Date(1601, time.January, 1, 0, 0, 0, 0, time.UTC)
modified = time.Unix(epoch.Unix()+secs, nsecs)
}
case unixExtraID, infoZipUnixExtraID:
if len(fieldBuf) < 8 {
continue parseExtras
}
fieldBuf.uint32() // AcTime (ignored)
ts := int64(fieldBuf.uint32()) // ModTime since Unix epoch
modified = time.Unix(ts, 0)
case extTimeExtraID:
if len(fieldBuf) < 5 || fieldBuf.uint8()&1 == 0 {
continue parseExtras
}
ts := int64(fieldBuf.uint32()) // ModTime since Unix epoch
modified = time.Unix(ts, 0)
}
}
msdosModified := msDosTimeToTime(f.ModifiedDate, f.ModifiedTime)
f.Modified = msdosModified
if !modified.IsZero() {
f.Modified = modified.UTC()
// If legacy MS-DOS timestamps are set, we can use the delta between
// the legacy and extended versions to estimate timezone offset.
//
// A non-UTC timezone is always used (even if offset is zero).
// Thus, FileHeader.Modified.Location() == time.UTC is useful for
// determining whether extended timestamps are present.
// This is necessary for users that need to do additional time
// calculations when dealing with legacy ZIP formats.
if f.ModifiedTime != 0 || f.ModifiedDate != 0 {
f.Modified = modified.In(timeZone(msdosModified.Sub(modified)))
}
}
// Assume that uncompressed size 2³²-1 could plausibly happen in
// an old zip32 file that was sharding inputs into the largest chunks
// possible (or is just malicious; search the web for 42.zip).
// If needUSize is true still, it means we didn't see a zip64 extension.
// As long as the compressed size is not also 2³²-1 (implausible)
// and the header is not also 2³²-1 (equally implausible),
// accept the uncompressed size 2³²-1 as valid.
// If nothing else, this keeps archive/zip working with 42.zip.
_ = needUSize
if needCSize || needHeaderOffset {
return ErrFormat
}
return nil
}
func readDataDescriptor(r io.Reader, f *File) error {
var buf [dataDescriptorLen]byte
// The spec says: "Although not originally assigned a
// signature, the value 0x08074b50 has commonly been adopted
// as a signature value for the data descriptor record.
// Implementers should be aware that ZIP files may be
// encountered with or without this signature marking data
// descriptors and should account for either case when reading
// ZIP files to ensure compatibility."
//
// dataDescriptorLen includes the size of the signature but
// first read just those 4 bytes to see if it exists.
if _, err := io.ReadFull(r, buf[:4]); err != nil {
return err
}
off := 0
maybeSig := readBuf(buf[:4])
if maybeSig.uint32() != dataDescriptorSignature {
// No data descriptor signature. Keep these four
// bytes.
off += 4
}
if _, err := io.ReadFull(r, buf[off:12]); err != nil {
return err
}
b := readBuf(buf[:12])
if b.uint32() != f.CRC32 {
return ErrChecksum
}
// The two sizes that follow here can be either 32 bits or 64 bits,
// but the spec is not very clear on this and different
// interpretations have been made, causing incompatibilities. We
// already have the sizes from the central directory, so we can
// just ignore these.
return nil
}
func readDirectoryEnd(r io.ReaderAt, size int64) (dir *directoryEnd, baseOffset int64, err error) {
// look for directoryEndSignature in the last 1k, then in the last 65k
var buf []byte
var directoryEndOffset int64
for i, bLen := range []int64{1024, 65 * 1024} {
if bLen > size {
bLen = size
}
buf = make([]byte, int(bLen))
if _, err := r.ReadAt(buf, size-bLen); err != nil && err != io.EOF {
return nil, 0, err
}
if p := findSignatureInBlock(buf); p >= 0 {
buf = buf[p:]
directoryEndOffset = size - bLen + int64(p)
break
}
if i == 1 || bLen == size {
return nil, 0, ErrFormat
}
}
// read header into struct
b := readBuf(buf[4:]) // skip signature
d := &directoryEnd{
diskNbr: uint32(b.uint16()),
dirDiskNbr: uint32(b.uint16()),
dirRecordsThisDisk: uint64(b.uint16()),
directoryRecords: uint64(b.uint16()),
directorySize: uint64(b.uint32()),
directoryOffset: uint64(b.uint32()),
commentLen: b.uint16(),
}
l := int(d.commentLen)
if l > len(b) {
return nil, 0, errors.New("zip: invalid comment length")
}
d.comment = string(b[:l])
// These values mean that the file can be a zip64 file
if d.directoryRecords == 0xffff || d.directorySize == 0xffff || d.directoryOffset == 0xffffffff {
p, err := findDirectory64End(r, directoryEndOffset)
if err == nil && p >= 0 {
directoryEndOffset = p
err = readDirectory64End(r, p, d)
}
if err != nil {
return nil, 0, err
}
}
maxInt64 := uint64(1<<63 - 1)
if d.directorySize > maxInt64 || d.directoryOffset > maxInt64 {
return nil, 0, ErrFormat
}
baseOffset = directoryEndOffset - int64(d.directorySize) - int64(d.directoryOffset)
// Make sure directoryOffset points to somewhere in our file.
if o := baseOffset + int64(d.directoryOffset); o < 0 || o >= size {
return nil, 0, ErrFormat
}
// If the directory end data tells us to use a non-zero baseOffset,
// but we would find a valid directory entry if we assume that the
// baseOffset is 0, then just use a baseOffset of 0.
// We've seen files in which the directory end data gives us
// an incorrect baseOffset.
if baseOffset > 0 {
off := int64(d.directoryOffset)
rs := io.NewSectionReader(r, off, size-off)
if readDirectoryHeader(&File{}, rs) == nil {
baseOffset = 0
}
}
return d, baseOffset, nil
}
// findDirectory64End tries to read the zip64 locator just before the
// directory end and returns the offset of the zip64 directory end if
// found.
func findDirectory64End(r io.ReaderAt, directoryEndOffset int64) (int64, error) {
locOffset := directoryEndOffset - directory64LocLen
if locOffset < 0 {
return -1, nil // no need to look for a header outside the file
}
buf := make([]byte, directory64LocLen)
if _, err := r.ReadAt(buf, locOffset); err != nil {
return -1, err
}
b := readBuf(buf)
if sig := b.uint32(); sig != directory64LocSignature {
return -1, nil
}
if b.uint32() != 0 { // number of the disk with the start of the zip64 end of central directory
return -1, nil // the file is not a valid zip64-file
}
p := b.uint64() // relative offset of the zip64 end of central directory record
if b.uint32() != 1 { // total number of disks
return -1, nil // the file is not a valid zip64-file
}
return int64(p), nil
}
// readDirectory64End reads the zip64 directory end and updates the
// directory end with the zip64 directory end values.
func readDirectory64End(r io.ReaderAt, offset int64, d *directoryEnd) (err error) {
buf := make([]byte, directory64EndLen)
if _, err := r.ReadAt(buf, offset); err != nil {
return err
}
b := readBuf(buf)
if sig := b.uint32(); sig != directory64EndSignature {
return ErrFormat
}
b = b[12:] // skip dir size, version and version needed (uint64 + 2x uint16)
d.diskNbr = b.uint32() // number of this disk
d.dirDiskNbr = b.uint32() // number of the disk with the start of the central directory
d.dirRecordsThisDisk = b.uint64() // total number of entries in the central directory on this disk
d.directoryRecords = b.uint64() // total number of entries in the central directory
d.directorySize = b.uint64() // size of the central directory
d.directoryOffset = b.uint64() // offset of start of central directory with respect to the starting disk number
return nil
}
func findSignatureInBlock(b []byte) int {
for i := len(b) - directoryEndLen; i >= 0; i-- {
// defined from directoryEndSignature in struct.go
if b[i] == 'P' && b[i+1] == 'K' && b[i+2] == 0x05 && b[i+3] == 0x06 {
// n is length of comment
n := int(b[i+directoryEndLen-2]) | int(b[i+directoryEndLen-1])<<8
if n+directoryEndLen+i > len(b) {
// Truncated comment.
// Some parsers (such as Info-ZIP) ignore the truncated comment
// rather than treating it as a hard error.
return -1
}
return i
}
}
return -1
}
type readBuf []byte
func (b *readBuf) uint8() uint8 {
v := (*b)[0]
*b = (*b)[1:]
return v
}
func (b *readBuf) uint16() uint16 {
v := binary.LittleEndian.Uint16(*b)
*b = (*b)[2:]
return v
}
func (b *readBuf) uint32() uint32 {
v := binary.LittleEndian.Uint32(*b)
*b = (*b)[4:]
return v
}
func (b *readBuf) uint64() uint64 {
v := binary.LittleEndian.Uint64(*b)
*b = (*b)[8:]
return v
}
func (b *readBuf) sub(n int) readBuf {
b2 := (*b)[:n]
*b = (*b)[n:]
return b2
}
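// Illustrative sketch (not part of the original file): readBuf consumes
// little-endian fields from the front of a byte slice, so successive calls
// walk a fixed-layout record.
//
//	b := readBuf([]byte{0x50, 0x4b, 0x05, 0x06, 0x2a, 0x00})
//	sig := b.uint32() // 0x06054b50, the end-of-central-directory signature
//	n := b.uint16()   // 42; b is now empty
//	_, _ = sig, n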
// A fileListEntry is a File and its name.
// If file == nil, the fileListEntry describes a directory without metadata.
type fileListEntry struct {
name string
file *File
isDir bool
isDup bool
}
type fileInfoDirEntry interface {
fs.FileInfo
fs.DirEntry
}
func (f *fileListEntry) stat() (fileInfoDirEntry, error) {
if f.isDup {
return nil, errors.New(f.name + ": duplicate entries in zip file")
}
if !f.isDir {
return headerFileInfo{&f.file.FileHeader}, nil
}
return f, nil
}
// Only used for directories.
func (f *fileListEntry) Name() string { _, elem, _ := split(f.name); return elem }
func (f *fileListEntry) Size() int64 { return 0 }
func (f *fileListEntry) Mode() fs.FileMode { return fs.ModeDir | 0555 }
func (f *fileListEntry) Type() fs.FileMode { return fs.ModeDir }
func (f *fileListEntry) IsDir() bool { return true }
func (f *fileListEntry) Sys() any { return nil }
func (f *fileListEntry) ModTime() time.Time {
if f.file == nil {
return time.Time{}
}
return f.file.FileHeader.Modified.UTC()
}
func (f *fileListEntry) Info() (fs.FileInfo, error) { return f, nil }
func (f *fileListEntry) String() string {
return fs.FormatDirEntry(f)
}
// toValidName coerces name to be a valid name for fs.FS.Open.
func toValidName(name string) string {
name = strings.ReplaceAll(name, `\`, `/`)
p := path.Clean(name)
p = strings.TrimPrefix(p, "/")
for strings.HasPrefix(p, "../") {
p = p[len("../"):]
}
return p
}
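// Illustrative sketch (not part of the original file): toValidName maps
// archive names onto fs.FS-style paths.
//
//	toValidName(`..\dir\file.txt`) // "dir/file.txt" (backslashes and leading "../" stripped)
//	toValidName("/a//b/./c")       // "a/b/c"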
func (r *Reader) initFileList() {
r.fileListOnce.Do(func() {
// Preallocate the minimum size of the index.
// We may also synthesize additional directory entries.
r.fileList = make([]fileListEntry, 0, len(r.File))
// files and knownDirs map from a file/directory name
// to its index in the r.fileList slice that we are building.
// They are used to mark duplicate entries.
files := make(map[string]int)
knownDirs := make(map[string]int)
// dirs[name] is true if name is known to be a directory,
// because it appears as a prefix in a path.
dirs := make(map[string]bool)
for _, file := range r.File {
isDir := len(file.Name) > 0 && file.Name[len(file.Name)-1] == '/'
name := toValidName(file.Name)
if name == "" {
continue
}
if idx, ok := files[name]; ok {
r.fileList[idx].isDup = true
continue
}
if idx, ok := knownDirs[name]; ok {
r.fileList[idx].isDup = true
continue
}
for dir := path.Dir(name); dir != "."; dir = path.Dir(dir) {
dirs[dir] = true
}
idx := len(r.fileList)
entry := fileListEntry{
name: name,
file: file,
isDir: isDir,
}
r.fileList = append(r.fileList, entry)
if isDir {
knownDirs[name] = idx
} else {
files[name] = idx
}
}
for dir := range dirs {
if _, ok := knownDirs[dir]; !ok {
if idx, ok := files[dir]; ok {
r.fileList[idx].isDup = true
} else {
entry := fileListEntry{
name: dir,
file: nil,
isDir: true,
}
r.fileList = append(r.fileList, entry)
}
}
}
slices.SortFunc(r.fileList, func(a, b fileListEntry) int {
return fileEntryCompare(a.name, b.name)
})
})
}
func fileEntryCompare(x, y string) int {
xdir, xelem, _ := split(x)
ydir, yelem, _ := split(y)
if xdir != ydir {
return strings.Compare(xdir, ydir)
}
return strings.Compare(xelem, yelem)
}
// Open opens the named file in the ZIP archive,
// using the semantics of fs.FS.Open:
// paths are always slash separated, with no
// leading / or ../ elements.
func (r *Reader) Open(name string) (fs.File, error) {
r.initFileList()
if !fs.ValidPath(name) {
return nil, &fs.PathError{Op: "open", Path: name, Err: fs.ErrInvalid}
}
e := r.openLookup(name)
if e == nil {
return nil, &fs.PathError{Op: "open", Path: name, Err: fs.ErrNotExist}
}
if e.isDir {
return &openDir{e, r.openReadDir(name), 0}, nil
}
rc, err := e.file.Open()
if err != nil {
return nil, err
}
return rc.(fs.File), nil
}
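// Illustrative sketch (not part of the original file, assuming a client
// package): because Reader implements fs.FS, a zip archive can be used with
// the standard fs helpers.
//
//	zr, err := zip.NewReader(readerAt, size)        // readerAt and size obtained elsewhere
//	data, err := fs.ReadFile(zr, "docs/readme.txt") // no leading "/" or "./" in names
//	sub, err := fs.Sub(zr, "docs")                  // view a subdirectory as its own fs.FS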
func split(name string) (dir, elem string, isDir bool) {
name, isDir = strings.CutSuffix(name, "/")
i := strings.LastIndexByte(name, '/')
if i < 0 {
return ".", name, isDir
}
return name[:i], name[i+1:], isDir
}
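// Illustrative sketch (not part of the original file): split separates the
// directory from the final element and reports a trailing slash.
//
//	split("a/b/c/") // dir "a/b", elem "c", isDir true
//	split("top")    // dir ".",   elem "top", isDir false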
var dotFile = &fileListEntry{name: "./", isDir: true}
func (r *Reader) openLookup(name string) *fileListEntry {
if name == "." {
return dotFile
}
dir, elem, _ := split(name)
files := r.fileList
i, _ := slices.BinarySearchFunc(files, dir, func(a fileListEntry, dir string) (ret int) {
idir, ielem, _ := split(a.name)
if dir != idir {
return strings.Compare(idir, dir)
}
return strings.Compare(ielem, elem)
})
if i < len(files) {
fname := files[i].name
if fname == name || len(fname) == len(name)+1 && fname[len(name)] == '/' && fname[:len(name)] == name {
return &files[i]
}
}
return nil
}
func (r *Reader) openReadDir(dir string) []fileListEntry {
files := r.fileList
i, _ := slices.BinarySearchFunc(files, dir, func(a fileListEntry, dir string) int {
idir, _, _ := split(a.name)
if dir != idir {
return strings.Compare(idir, dir)
}
// find the first entry with dir
return +1
})
j, _ := slices.BinarySearchFunc(files, dir, func(a fileListEntry, dir string) int {
jdir, _, _ := split(a.name)
if dir != jdir {
return strings.Compare(jdir, dir)
}
// find the last entry with dir
return -1
})
return files[i:j]
}
type openDir struct {
e *fileListEntry
files []fileListEntry
offset int
}
func (d *openDir) Close() error { return nil }
func (d *openDir) Stat() (fs.FileInfo, error) { return d.e.stat() }
func (d *openDir) Read([]byte) (int, error) {
return 0, &fs.PathError{Op: "read", Path: d.e.name, Err: errors.New("is a directory")}
}
func (d *openDir) ReadDir(count int) ([]fs.DirEntry, error) {
n := len(d.files) - d.offset
if count > 0 && n > count {
n = count
}
if n == 0 {
if count <= 0 {
return nil, nil
}
return nil, io.EOF
}
list := make([]fs.DirEntry, n)
for i := range list {
s, err := d.files[d.offset+i].stat()
if err != nil {
return nil, err
} else if s.Name() == "." || !fs.ValidPath(s.Name()) {
return nil, &fs.PathError{
Op: "readdir",
Path: d.e.name,
Err: fmt.Errorf("invalid file name: %v", d.files[d.offset+i].name),
}
}
list[i] = s
}
d.offset += n
return list, nil
}
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package zip
import (
"compress/flate"
"errors"
"io"
"sync"
)
// A Compressor returns a new compressing writer, writing to w.
// The WriteCloser's Close method must be used to flush pending data to w.
// The Compressor itself must be safe to invoke from multiple goroutines
// simultaneously, but each returned writer will be used only by
// one goroutine at a time.
type Compressor func(w io.Writer) (io.WriteCloser, error)
// A Decompressor returns a new decompressing reader, reading from r.
// The [io.ReadCloser]'s Close method must be used to release associated resources.
// The Decompressor itself must be safe to invoke from multiple goroutines
// simultaneously, but each returned reader will be used only by
// one goroutine at a time.
type Decompressor func(r io.Reader) io.ReadCloser
var flateWriterPool sync.Pool
func newFlateWriter(w io.Writer) io.WriteCloser {
fw, ok := flateWriterPool.Get().(*flate.Writer)
if ok {
fw.Reset(w)
} else {
fw, _ = flate.NewWriter(w, 5)
}
return &pooledFlateWriter{fw: fw}
}
type pooledFlateWriter struct {
mu sync.Mutex // guards Close and Write
fw *flate.Writer
}
func (w *pooledFlateWriter) Write(p []byte) (n int, err error) {
w.mu.Lock()
defer w.mu.Unlock()
if w.fw == nil {
return 0, errors.New("Write after Close")
}
return w.fw.Write(p)
}
func (w *pooledFlateWriter) Close() error {
w.mu.Lock()
defer w.mu.Unlock()
var err error
if w.fw != nil {
err = w.fw.Close()
flateWriterPool.Put(w.fw)
w.fw = nil
}
return err
}
var flateReaderPool sync.Pool
func newFlateReader(r io.Reader) io.ReadCloser {
fr, ok := flateReaderPool.Get().(io.ReadCloser)
if ok {
fr.(flate.Resetter).Reset(r, nil)
} else {
fr = flate.NewReader(r)
}
return &pooledFlateReader{fr: fr}
}
type pooledFlateReader struct {
mu sync.Mutex // guards Close and Read
fr io.ReadCloser
}
func (r *pooledFlateReader) Read(p []byte) (n int, err error) {
r.mu.Lock()
defer r.mu.Unlock()
if r.fr == nil {
return 0, errors.New("Read after Close")
}
return r.fr.Read(p)
}
func (r *pooledFlateReader) Close() error {
r.mu.Lock()
defer r.mu.Unlock()
var err error
if r.fr != nil {
err = r.fr.Close()
flateReaderPool.Put(r.fr)
r.fr = nil
}
return err
}
var (
compressors sync.Map // map[uint16]Compressor
decompressors sync.Map // map[uint16]Decompressor
)
func init() {
compressors.Store(Store, Compressor(func(w io.Writer) (io.WriteCloser, error) { return &nopCloser{w}, nil }))
compressors.Store(Deflate, Compressor(func(w io.Writer) (io.WriteCloser, error) { return newFlateWriter(w), nil }))
decompressors.Store(Store, Decompressor(io.NopCloser))
decompressors.Store(Deflate, Decompressor(newFlateReader))
}
// RegisterDecompressor allows custom decompressors for a specified method ID.
// The common methods [Store] and [Deflate] are built in.
func RegisterDecompressor(method uint16, dcomp Decompressor) {
if _, dup := decompressors.LoadOrStore(method, dcomp); dup {
panic("decompressor already registered")
}
}
// RegisterCompressor registers custom compressors for a specified method ID.
// The common methods [Store] and [Deflate] are built in.
func RegisterCompressor(method uint16, comp Compressor) {
if _, dup := compressors.LoadOrStore(method, comp); dup {
panic("compressor already registered")
}
}
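// Illustrative sketch (not part of the original file, assuming a client
// package): registering a read-only bzip2 decoder under method ID 12, the
// method number the ZIP specification assigns to bzip2.
//
//	zip.RegisterDecompressor(12, func(r io.Reader) io.ReadCloser {
//		return io.NopCloser(bzip2.NewReader(r)) // compress/bzip2
//	})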
func compressor(method uint16) Compressor {
ci, ok := compressors.Load(method)
if !ok {
return nil
}
return ci.(Compressor)
}
func decompressor(method uint16) Decompressor {
di, ok := decompressors.Load(method)
if !ok {
return nil
}
return di.(Decompressor)
}
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
/*
Package zip provides support for reading and writing ZIP archives.
See the [ZIP specification] for details.
This package does not support disk spanning.
A note about ZIP64:
To be backwards compatible the FileHeader has both 32 and 64 bit Size
fields. The 64 bit fields will always contain the correct value and
for normal archives both fields will be the same. For files requiring
the ZIP64 format the 32 bit fields will be 0xffffffff and the 64 bit
fields must be used instead.
[ZIP specification]: https://support.pkware.com/pkzip/appnote
*/
package zip
import (
"io/fs"
"path"
"time"
)
// Compression methods.
const (
Store uint16 = 0 // no compression
Deflate uint16 = 8 // DEFLATE compressed
)
const (
fileHeaderSignature = 0x04034b50
directoryHeaderSignature = 0x02014b50
directoryEndSignature = 0x06054b50
directory64LocSignature = 0x07064b50
directory64EndSignature = 0x06064b50
dataDescriptorSignature = 0x08074b50 // de-facto standard; required by OS X Finder
fileHeaderLen = 30 // + filename + extra
directoryHeaderLen = 46 // + filename + extra + comment
directoryEndLen = 22 // + comment
dataDescriptorLen = 16 // four uint32: descriptor signature, crc32, compressed size, size
dataDescriptor64Len = 24 // two uint32: signature, crc32 | two uint64: compressed size, size
directory64LocLen = 20 // signature + disk number + zip64 EOCD offset + total disks (4+4+8+4)
directory64EndLen = 56 // + extra
// Constants for the first byte in CreatorVersion.
creatorFAT = 0
creatorUnix = 3
creatorNTFS = 11
creatorVFAT = 14
creatorMacOSX = 19
// Version numbers.
zipVersion20 = 20 // 2.0
zipVersion45 = 45 // 4.5 (reads and writes zip64 archives)
// Limits for non zip64 files.
uint16max = (1 << 16) - 1
uint32max = (1 << 32) - 1
// Extra header IDs.
//
// IDs 0..31 are reserved for official use by PKWARE.
// IDs above that range are defined by third-party vendors.
// Since ZIP lacks high-precision timestamps (and has no official specification
// of the timezone used for the date fields), many competing extra fields
// have been invented. Pervasive use has effectively made them "official".
//
// See http://mdfs.net/Docs/Comp/Archiving/Zip/ExtraField
zip64ExtraID = 0x0001 // Zip64 extended information
ntfsExtraID = 0x000a // NTFS
unixExtraID = 0x000d // UNIX
extTimeExtraID = 0x5455 // Extended timestamp
infoZipUnixExtraID = 0x5855 // Info-ZIP Unix extension
)
// FileHeader describes a file within a ZIP file.
// See the [ZIP specification] for details.
//
// [ZIP specification]: https://support.pkware.com/pkzip/appnote
type FileHeader struct {
// Name is the name of the file.
//
// It must be a relative path, not start with a drive letter (such as "C:"),
// and must use forward slashes instead of back slashes. A trailing slash
// indicates that this file is a directory and should have no data.
Name string
// Comment is any arbitrary user-defined string shorter than 64KiB.
Comment string
// NonUTF8 indicates that Name and Comment are not encoded in UTF-8.
//
// By specification, the only other encoding permitted should be CP-437,
// but historically many ZIP readers interpret Name and Comment as whatever
// the system's local character encoding happens to be.
//
// This flag should only be set if the user intends to encode a non-portable
// ZIP file for a specific localized region. Otherwise, the Writer
// automatically sets the ZIP format's UTF-8 flag for valid UTF-8 strings.
NonUTF8 bool
CreatorVersion uint16
ReaderVersion uint16
Flags uint16
// Method is the compression method. If zero, Store is used.
Method uint16
// Modified is the modified time of the file.
//
// When reading, an extended timestamp is preferred over the legacy MS-DOS
// date field, and the offset between the times is used as the timezone.
// If only the MS-DOS date is present, the timezone is assumed to be UTC.
//
// When writing, an extended timestamp (which is timezone-agnostic) is
// always emitted. The legacy MS-DOS date field is encoded according to the
// location of the Modified time.
Modified time.Time
// ModifiedTime is an MS-DOS-encoded time.
//
// Deprecated: Use Modified instead.
ModifiedTime uint16
// ModifiedDate is an MS-DOS-encoded date.
//
// Deprecated: Use Modified instead.
ModifiedDate uint16
// CRC32 is the CRC32 checksum of the file content.
CRC32 uint32
// CompressedSize is the compressed size of the file in bytes.
// If either the uncompressed or compressed size of the file
// does not fit in 32 bits, CompressedSize is set to ^uint32(0).
//
// Deprecated: Use CompressedSize64 instead.
CompressedSize uint32
// UncompressedSize is the uncompressed size of the file in bytes.
// If either the uncompressed or compressed size of the file
// does not fit in 32 bits, UncompressedSize is set to ^uint32(0).
//
// Deprecated: Use UncompressedSize64 instead.
UncompressedSize uint32
// CompressedSize64 is the compressed size of the file in bytes.
CompressedSize64 uint64
// UncompressedSize64 is the uncompressed size of the file in bytes.
UncompressedSize64 uint64
Extra []byte
ExternalAttrs uint32 // Meaning depends on CreatorVersion
}
// FileInfo returns an fs.FileInfo for the [FileHeader].
func (h *FileHeader) FileInfo() fs.FileInfo {
return headerFileInfo{h}
}
// headerFileInfo implements [fs.FileInfo].
type headerFileInfo struct {
fh *FileHeader
}
func (fi headerFileInfo) Name() string { return path.Base(fi.fh.Name) }
func (fi headerFileInfo) Size() int64 {
if fi.fh.UncompressedSize64 > 0 {
return int64(fi.fh.UncompressedSize64)
}
return int64(fi.fh.UncompressedSize)
}
func (fi headerFileInfo) IsDir() bool { return fi.Mode().IsDir() }
func (fi headerFileInfo) ModTime() time.Time {
if fi.fh.Modified.IsZero() {
return fi.fh.ModTime()
}
return fi.fh.Modified.UTC()
}
func (fi headerFileInfo) Mode() fs.FileMode { return fi.fh.Mode() }
func (fi headerFileInfo) Type() fs.FileMode { return fi.fh.Mode().Type() }
func (fi headerFileInfo) Sys() any { return fi.fh }
func (fi headerFileInfo) Info() (fs.FileInfo, error) { return fi, nil }
func (fi headerFileInfo) String() string {
return fs.FormatFileInfo(fi)
}
// FileInfoHeader creates a partially-populated [FileHeader] from an
// fs.FileInfo.
// Because fs.FileInfo's Name method returns only the base name of
// the file it describes, it may be necessary to modify the Name field
// of the returned header to provide the full path name of the file.
// If compression is desired, callers should set the FileHeader.Method
// field; it is unset by default.
func FileInfoHeader(fi fs.FileInfo) (*FileHeader, error) {
size := fi.Size()
fh := &FileHeader{
Name: fi.Name(),
UncompressedSize64: uint64(size),
}
fh.SetModTime(fi.ModTime())
fh.SetMode(fi.Mode())
if fh.UncompressedSize64 > uint32max {
fh.UncompressedSize = uint32max
} else {
fh.UncompressedSize = uint32(fh.UncompressedSize64)
}
return fh, nil
}
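// Illustrative sketch (not part of the original file, assuming a client
// package and an existing *zip.Writer zw): FileInfoHeader fills in only the
// base name, so callers typically restore the full path and pick a method.
//
//	info, err := os.Stat("notes/todo.txt")
//	fh, _ := zip.FileInfoHeader(info)
//	fh.Name = "notes/todo.txt" // fs.FileInfo carries only the base name
//	fh.Method = zip.Deflate
//	w, err := zw.CreateHeader(fh) // write the file contents to w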
type directoryEnd struct {
diskNbr uint32 // unused
dirDiskNbr uint32 // unused
dirRecordsThisDisk uint64 // unused
directoryRecords uint64
directorySize uint64
directoryOffset uint64 // relative to file
commentLen uint16
comment string
}
// timeZone returns a *time.Location based on the provided offset.
// If the offset is non-sensible, then this uses an offset of zero.
func timeZone(offset time.Duration) *time.Location {
const (
minOffset = -12 * time.Hour // E.g., Baker island at -12:00
maxOffset = +14 * time.Hour // E.g., Line island at +14:00
offsetAlias = 15 * time.Minute // E.g., Nepal at +5:45
)
offset = offset.Round(offsetAlias)
if offset < minOffset || maxOffset < offset {
offset = 0
}
return time.FixedZone("", int(offset/time.Second))
}
// msDosTimeToTime converts an MS-DOS date and time into a time.Time.
// The resolution is 2s.
// See: https://learn.microsoft.com/en-us/windows/win32/api/winbase/nf-winbase-dosdatetimetofiletime
func msDosTimeToTime(dosDate, dosTime uint16) time.Time {
return time.Date(
// date bits 0-4: day of month; 5-8: month; 9-15: years since 1980
int(dosDate>>9+1980),
time.Month(dosDate>>5&0xf),
int(dosDate&0x1f),
// time bits 0-4: second/2; 5-10: minute; 11-15: hour
int(dosTime>>11),
int(dosTime>>5&0x3f),
int(dosTime&0x1f*2),
0, // nanoseconds
time.UTC,
)
}
// timeToMsDosTime converts a time.Time to an MS-DOS date and time.
// The resolution is 2s.
// See: https://learn.microsoft.com/en-us/windows/win32/api/winbase/nf-winbase-filetimetodosdatetime
func timeToMsDosTime(t time.Time) (fDate uint16, fTime uint16) {
fDate = uint16(t.Day() + int(t.Month())<<5 + (t.Year()-1980)<<9)
fTime = uint16(t.Second()/2 + t.Minute()<<5 + t.Hour()<<11)
return
}
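// Illustrative sketch (not part of the original file): round-tripping a time
// through the MS-DOS fields keeps the date and time but drops sub-2s
// precision and the timezone.
//
//	t := time.Date(2009, time.November, 10, 23, 30, 5, 0, time.UTC)
//	d, tm := timeToMsDosTime(t) // d == 15210 (10 | 11<<5 | 29<<9), tm == 48066 (5/2 | 30<<5 | 23<<11)
//	_ = msDosTimeToTime(d, tm)  // 2009-11-10 23:30:04 +0000 UTC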
// ModTime returns the modification time in UTC using the legacy
// [ModifiedDate] and [ModifiedTime] fields.
//
// Deprecated: Use [Modified] instead.
func (h *FileHeader) ModTime() time.Time {
return msDosTimeToTime(h.ModifiedDate, h.ModifiedTime)
}
// SetModTime sets the [Modified], [ModifiedTime], and [ModifiedDate] fields
// to the given time in UTC.
//
// Deprecated: Use [Modified] instead.
func (h *FileHeader) SetModTime(t time.Time) {
t = t.UTC() // Convert to UTC for compatibility
h.Modified = t
h.ModifiedDate, h.ModifiedTime = timeToMsDosTime(t)
}
const (
// Unix constants. The specification doesn't mention them,
// but these seem to be the values agreed on by tools.
s_IFMT = 0xf000
s_IFSOCK = 0xc000
s_IFLNK = 0xa000
s_IFREG = 0x8000
s_IFBLK = 0x6000
s_IFDIR = 0x4000
s_IFCHR = 0x2000
s_IFIFO = 0x1000
s_ISUID = 0x800
s_ISGID = 0x400
s_ISVTX = 0x200
msdosDir = 0x10
msdosReadOnly = 0x01
)
// Mode returns the permission and mode bits for the [FileHeader].
func (h *FileHeader) Mode() (mode fs.FileMode) {
switch h.CreatorVersion >> 8 {
case creatorUnix, creatorMacOSX:
mode = unixModeToFileMode(h.ExternalAttrs >> 16)
case creatorNTFS, creatorVFAT, creatorFAT:
mode = msdosModeToFileMode(h.ExternalAttrs)
}
if len(h.Name) > 0 && h.Name[len(h.Name)-1] == '/' {
mode |= fs.ModeDir
}
return mode
}
// SetMode changes the permission and mode bits for the [FileHeader].
func (h *FileHeader) SetMode(mode fs.FileMode) {
h.CreatorVersion = h.CreatorVersion&0xff | creatorUnix<<8
h.ExternalAttrs = fileModeToUnixMode(mode) << 16
// set MSDOS attributes too, as the original zip does.
if mode&fs.ModeDir != 0 {
h.ExternalAttrs |= msdosDir
}
if mode&0200 == 0 {
h.ExternalAttrs |= msdosReadOnly
}
}
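// Illustrative sketch (not part of the original file): SetMode stores the
// Unix mode bits in the high 16 bits of ExternalAttrs and marks the header
// as Unix-created.
//
//	var h FileHeader
//	h.SetMode(0o644)
//	// h.ExternalAttrs == 0x81A40000 (s_IFREG|0644 shifted into the high half)
//	// h.CreatorVersion>>8 == creatorUnix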
// isZip64 reports whether the file size exceeds the 32 bit limit
func (h *FileHeader) isZip64() bool {
return h.CompressedSize64 >= uint32max || h.UncompressedSize64 >= uint32max
}
func (h *FileHeader) hasDataDescriptor() bool {
return h.Flags&0x8 != 0
}
func msdosModeToFileMode(m uint32) (mode fs.FileMode) {
if m&msdosDir != 0 {
mode = fs.ModeDir | 0777
} else {
mode = 0666
}
if m&msdosReadOnly != 0 {
mode &^= 0222
}
return mode
}
func fileModeToUnixMode(mode fs.FileMode) uint32 {
var m uint32
switch mode & fs.ModeType {
default:
m = s_IFREG
case fs.ModeDir:
m = s_IFDIR
case fs.ModeSymlink:
m = s_IFLNK
case fs.ModeNamedPipe:
m = s_IFIFO
case fs.ModeSocket:
m = s_IFSOCK
case fs.ModeDevice:
m = s_IFBLK
case fs.ModeDevice | fs.ModeCharDevice:
m = s_IFCHR
}
if mode&fs.ModeSetuid != 0 {
m |= s_ISUID
}
if mode&fs.ModeSetgid != 0 {
m |= s_ISGID
}
if mode&fs.ModeSticky != 0 {
m |= s_ISVTX
}
return m | uint32(mode&0777)
}
func unixModeToFileMode(m uint32) fs.FileMode {
mode := fs.FileMode(m & 0777)
switch m & s_IFMT {
case s_IFBLK:
mode |= fs.ModeDevice
case s_IFCHR:
mode |= fs.ModeDevice | fs.ModeCharDevice
case s_IFDIR:
mode |= fs.ModeDir
case s_IFIFO:
mode |= fs.ModeNamedPipe
case s_IFLNK:
mode |= fs.ModeSymlink
case s_IFREG:
// nothing to do
case s_IFSOCK:
mode |= fs.ModeSocket
}
if m&s_ISGID != 0 {
mode |= fs.ModeSetgid
}
if m&s_ISUID != 0 {
mode |= fs.ModeSetuid
}
if m&s_ISVTX != 0 {
mode |= fs.ModeSticky
}
return mode
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package zip
import (
"bufio"
"encoding/binary"
"errors"
"hash"
"hash/crc32"
"io"
"io/fs"
"strings"
"unicode/utf8"
)
var (
errLongName = errors.New("zip: FileHeader.Name too long")
errLongExtra = errors.New("zip: FileHeader.Extra too long")
)
// Writer implements a zip file writer.
type Writer struct {
cw *countWriter
dir []*header
last *fileWriter
closed bool
compressors map[uint16]Compressor
comment string
// testHookCloseSizeOffset, if non-nil, is called with the size
// and offset of the central directory at Close.
testHookCloseSizeOffset func(size, offset uint64)
}
type header struct {
*FileHeader
offset uint64
raw bool
}
// NewWriter returns a new [Writer] writing a zip file to w.
func NewWriter(w io.Writer) *Writer {
return &Writer{cw: &countWriter{w: bufio.NewWriter(w)}}
}
// SetOffset sets the offset of the beginning of the zip data within the
// underlying writer. It should be used when the zip data is appended to an
// existing file, such as a binary executable.
// It must be called before any data is written.
func (w *Writer) SetOffset(n int64) {
if w.cw.count != 0 {
panic("zip: SetOffset called after data was written")
}
w.cw.count = n
}
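// Illustrative sketch (not part of the original file, assuming a client
// package): SetOffset is what makes "zip appended to an executable" work,
// since central directory offsets must include the non-zip prefix.
//
//	f, _ := os.OpenFile("app", os.O_WRONLY|os.O_APPEND, 0)
//	fi, _ := f.Stat()
//	zw := zip.NewWriter(f)
//	zw.SetOffset(fi.Size()) // account for the executable that precedes the zip data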
// Flush flushes any buffered data to the underlying writer.
// Calling Flush is not normally necessary; calling Close is sufficient.
func (w *Writer) Flush() error {
return w.cw.w.(*bufio.Writer).Flush()
}
// SetComment sets the end-of-central-directory comment field.
// It can only be called before [Writer.Close].
func (w *Writer) SetComment(comment string) error {
if len(comment) > uint16max {
return errors.New("zip: Writer.Comment too long")
}
w.comment = comment
return nil
}
// Close finishes writing the zip file by writing the central directory.
// It does not close the underlying writer.
func (w *Writer) Close() error {
if w.last != nil && !w.last.closed {
if err := w.last.close(); err != nil {
return err
}
w.last = nil
}
if w.closed {
return errors.New("zip: writer closed twice")
}
w.closed = true
// write central directory
start := w.cw.count
for _, h := range w.dir {
var buf [directoryHeaderLen]byte
b := writeBuf(buf[:])
b.uint32(uint32(directoryHeaderSignature))
b.uint16(h.CreatorVersion)
b.uint16(h.ReaderVersion)
b.uint16(h.Flags)
b.uint16(h.Method)
b.uint16(h.ModifiedTime)
b.uint16(h.ModifiedDate)
b.uint32(h.CRC32)
if h.isZip64() || h.offset >= uint32max {
// the file needs a zip64 header. store maxint in both
// 32 bit size fields (and offset later) to signal that the
// zip64 extra header should be used.
b.uint32(uint32max) // compressed size
b.uint32(uint32max) // uncompressed size
// append a zip64 extra block to Extra
var buf [28]byte // 2x uint16 + 3x uint64
eb := writeBuf(buf[:])
eb.uint16(zip64ExtraID)
eb.uint16(24) // size = 3x uint64
eb.uint64(h.UncompressedSize64)
eb.uint64(h.CompressedSize64)
eb.uint64(h.offset)
h.Extra = append(h.Extra, buf[:]...)
} else {
b.uint32(h.CompressedSize)
b.uint32(h.UncompressedSize)
}
b.uint16(uint16(len(h.Name)))
b.uint16(uint16(len(h.Extra)))
b.uint16(uint16(len(h.Comment)))
b = b[4:] // skip disk number start and internal file attr (2x uint16)
b.uint32(h.ExternalAttrs)
if h.offset > uint32max {
b.uint32(uint32max)
} else {
b.uint32(uint32(h.offset))
}
if _, err := w.cw.Write(buf[:]); err != nil {
return err
}
if _, err := io.WriteString(w.cw, h.Name); err != nil {
return err
}
if _, err := w.cw.Write(h.Extra); err != nil {
return err
}
if _, err := io.WriteString(w.cw, h.Comment); err != nil {
return err
}
}
end := w.cw.count
records := uint64(len(w.dir))
size := uint64(end - start)
offset := uint64(start)
if f := w.testHookCloseSizeOffset; f != nil {
f(size, offset)
}
if records >= uint16max || size >= uint32max || offset >= uint32max {
var buf [directory64EndLen + directory64LocLen]byte
b := writeBuf(buf[:])
// zip64 end of central directory record
b.uint32(directory64EndSignature)
b.uint64(directory64EndLen - 12) // length minus signature (uint32) and length fields (uint64)
b.uint16(zipVersion45) // version made by
b.uint16(zipVersion45) // version needed to extract
b.uint32(0) // number of this disk
b.uint32(0) // number of the disk with the start of the central directory
b.uint64(records) // total number of entries in the central directory on this disk
b.uint64(records) // total number of entries in the central directory
b.uint64(size) // size of the central directory
b.uint64(offset) // offset of start of central directory with respect to the starting disk number
// zip64 end of central directory locator
b.uint32(directory64LocSignature)
b.uint32(0) // number of the disk with the start of the zip64 end of central directory
b.uint64(uint64(end)) // relative offset of the zip64 end of central directory record
b.uint32(1) // total number of disks
if _, err := w.cw.Write(buf[:]); err != nil {
return err
}
// store max values in the regular end record to signal
// that the zip64 values should be used instead
records = uint16max
size = uint32max
offset = uint32max
}
// write end record
var buf [directoryEndLen]byte
b := writeBuf(buf[:])
b.uint32(uint32(directoryEndSignature))
b = b[4:] // skip over disk number and first disk number (2x uint16)
b.uint16(uint16(records)) // number of entries this disk
b.uint16(uint16(records)) // number of entries total
b.uint32(uint32(size)) // size of directory
b.uint32(uint32(offset)) // start of directory
b.uint16(uint16(len(w.comment))) // byte size of EOCD comment
if _, err := w.cw.Write(buf[:]); err != nil {
return err
}
if _, err := io.WriteString(w.cw, w.comment); err != nil {
return err
}
return w.cw.w.(*bufio.Writer).Flush()
}
// Create adds a file to the zip file using the provided name.
// It returns an [io.Writer] to which the file contents should be written.
// The file contents will be compressed using the [Deflate] method.
// The name must be a relative path: it must not start with a drive
// letter (e.g. C:) or leading slash, and only forward slashes are
// allowed. To create a directory instead of a file, add a trailing
// slash to the name. Duplicate names will not overwrite previous entries
// and are appended to the zip file.
// The file's contents must be written to the [io.Writer] before the next
// call to [Writer.Create], [Writer.CreateHeader], or [Writer.Close].
func (w *Writer) Create(name string) (io.Writer, error) {
header := &FileHeader{
Name: name,
Method: Deflate,
}
return w.CreateHeader(header)
}
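// Illustrative sketch (not part of the original file, assuming a client
// package): the typical write path is Create, write the contents, then Close
// to emit the central directory.
//
//	var buf bytes.Buffer
//	zw := zip.NewWriter(&buf)
//	w, err := zw.Create("hello/readme.txt")
//	if err == nil {
//		w.Write([]byte("hello, zip\n"))
//	}
//	err = zw.Close() // finishes the archive; buf now holds a complete zip file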
// detectUTF8 reports whether s is a valid UTF-8 string, and whether the string
// must be considered UTF-8 encoding (i.e., not compatible with CP-437, ASCII,
// or any other common encoding).
func detectUTF8(s string) (valid, require bool) {
for i := 0; i < len(s); {
r, size := utf8.DecodeRuneInString(s[i:])
i += size
// Officially, ZIP uses CP-437, but many readers use the system's
// local character encoding. Most encodings are compatible with a large
// subset of CP-437, which itself is ASCII-like.
//
// Forbid 0x7e and 0x5c since EUC-KR and Shift-JIS replace those
// characters with localized currency and overline characters.
if r < 0x20 || r > 0x7d || r == 0x5c {
if !utf8.ValidRune(r) || (r == utf8.RuneError && size == 1) {
return false, false
}
require = true
}
}
return true, require
}
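// Illustrative sketch (not part of the original file): how detectUTF8
// classifies a few names.
//
//	detectUTF8("notes.txt")          // valid=true,  require=false (CP-437/ASCII safe)
//	detectUTF8("résumé.txt")         // valid=true,  require=true  (needs the UTF-8 flag)
//	detectUTF8(string([]byte{0xff})) // valid=false, require=false (not valid UTF-8)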
// prepare performs the bookkeeping operations required at the start of
// CreateHeader and CreateRaw.
func (w *Writer) prepare(fh *FileHeader) error {
if w.last != nil && !w.last.closed {
if err := w.last.close(); err != nil {
return err
}
}
if len(w.dir) > 0 && w.dir[len(w.dir)-1].FileHeader == fh {
// See the confusion described in https://golang.org/issue/11144.
return errors.New("archive/zip: invalid duplicate FileHeader")
}
return nil
}
// CreateHeader adds a file to the zip archive using the provided [FileHeader]
// for the file metadata. [Writer] takes ownership of fh and may mutate
// its fields. The caller must not modify fh after calling [Writer.CreateHeader].
//
// It returns an [io.Writer] to which the file contents should be written.
// The file's contents must be written to the io.Writer before the next
// call to [Writer.Create], [Writer.CreateHeader], [Writer.CreateRaw], or [Writer.Close].
func (w *Writer) CreateHeader(fh *FileHeader) (io.Writer, error) {
if err := w.prepare(fh); err != nil {
return nil, err
}
// The ZIP format has a sad state of affairs regarding character encoding.
// Officially, the name and comment fields are supposed to be encoded
// in CP-437 (which is mostly compatible with ASCII), unless the UTF-8
// flag bit is set. However, there are several problems:
//
// * Many ZIP readers still do not support UTF-8.
// * If the UTF-8 flag is cleared, several readers simply interpret the
// name and comment fields as whatever the local system encoding is.
//
// In order to avoid breaking readers without UTF-8 support,
// we avoid setting the UTF-8 flag if the strings are CP-437 compatible.
// However, if the strings require multibyte UTF-8 encoding and are valid
// UTF-8, then we set the UTF-8 bit.
//
// In the case where the user explicitly wants to specify the encoding as
// UTF-8, they will need to set the flag bit themselves.
utf8Valid1, utf8Require1 := detectUTF8(fh.Name)
utf8Valid2, utf8Require2 := detectUTF8(fh.Comment)
switch {
case fh.NonUTF8:
fh.Flags &^= 0x800
case (utf8Require1 || utf8Require2) && (utf8Valid1 && utf8Valid2):
fh.Flags |= 0x800
}
fh.CreatorVersion = fh.CreatorVersion&0xff00 | zipVersion20 // preserve compatibility byte
fh.ReaderVersion = zipVersion20
// If Modified is set, this takes precedence over MS-DOS timestamp fields.
if !fh.Modified.IsZero() {
// Contrary to the FileHeader.SetModTime method, we intentionally
// do not convert to UTC, because we assume the user intends to encode
// the date using the specified timezone. A user may want this control
// because many legacy ZIP readers interpret the timestamp according
// to the local timezone.
//
// The timezone is only non-UTC if the user sets the Modified field
// directly. All other code paths set UTC.
fh.ModifiedDate, fh.ModifiedTime = timeToMsDosTime(fh.Modified)
// Use "extended timestamp" format since this is what Info-ZIP uses.
// Nearly every major ZIP implementation uses a different format,
// but at least most seem to be able to understand the other formats.
//
// This format happens to be identical for both local and central header
// if modification time is the only timestamp being encoded.
var mbuf [9]byte // 2*SizeOf(uint16) + SizeOf(uint8) + SizeOf(uint32)
mt := uint32(fh.Modified.Unix())
eb := writeBuf(mbuf[:])
eb.uint16(extTimeExtraID)
eb.uint16(5) // Size: SizeOf(uint8) + SizeOf(uint32)
eb.uint8(1) // Flags: ModTime
eb.uint32(mt) // ModTime
fh.Extra = append(fh.Extra, mbuf[:]...)
}
var (
ow io.Writer
fw *fileWriter
)
h := &header{
FileHeader: fh,
offset: uint64(w.cw.count),
}
if strings.HasSuffix(fh.Name, "/") {
// Set the compression method to Store to ensure data length is truly zero,
// which the writeHeader method always encodes for the size fields.
// This is necessary as most compression formats have non-zero lengths
// even when compressing an empty string.
fh.Method = Store
fh.Flags &^= 0x8 // we will not write a data descriptor
// Explicitly clear sizes as they have no meaning for directories.
fh.CompressedSize = 0
fh.CompressedSize64 = 0
fh.UncompressedSize = 0
fh.UncompressedSize64 = 0
ow = dirWriter{}
} else {
fh.Flags |= 0x8 // we will write a data descriptor
fw = &fileWriter{
zipw: w.cw,
compCount: &countWriter{w: w.cw},
crc32: crc32.NewIEEE(),
}
comp := w.compressor(fh.Method)
if comp == nil {
return nil, ErrAlgorithm
}
var err error
fw.comp, err = comp(fw.compCount)
if err != nil {
return nil, err
}
fw.rawCount = &countWriter{w: fw.comp}
fw.header = h
ow = fw
}
w.dir = append(w.dir, h)
if err := writeHeader(w.cw, h); err != nil {
return nil, err
}
// If we're creating a directory, fw is nil.
w.last = fw
return ow, nil
}
func writeHeader(w io.Writer, h *header) error {
const maxUint16 = 1<<16 - 1
if len(h.Name) > maxUint16 {
return errLongName
}
if len(h.Extra) > maxUint16 {
return errLongExtra
}
var buf [fileHeaderLen]byte
b := writeBuf(buf[:])
b.uint32(uint32(fileHeaderSignature))
b.uint16(h.ReaderVersion)
b.uint16(h.Flags)
b.uint16(h.Method)
b.uint16(h.ModifiedTime)
b.uint16(h.ModifiedDate)
// In raw mode (caller does the compression), the values are either
// written here or in the trailing data descriptor based on the header
// flags.
if h.raw && !h.hasDataDescriptor() {
b.uint32(h.CRC32)
b.uint32(uint32(min(h.CompressedSize64, uint32max)))
b.uint32(uint32(min(h.UncompressedSize64, uint32max)))
} else {
// When this package handles the compression, these values are
// always written to the trailing data descriptor.
b.uint32(0) // crc32
b.uint32(0) // compressed size
b.uint32(0) // uncompressed size
}
b.uint16(uint16(len(h.Name)))
b.uint16(uint16(len(h.Extra)))
if _, err := w.Write(buf[:]); err != nil {
return err
}
if _, err := io.WriteString(w, h.Name); err != nil {
return err
}
_, err := w.Write(h.Extra)
return err
}
// CreateRaw adds a file to the zip archive using the provided [FileHeader] and
// returns an [io.Writer] to which the file contents should be written. The file's
// contents must be written to the io.Writer before the next call to [Writer.Create],
// [Writer.CreateHeader], [Writer.CreateRaw], or [Writer.Close].
//
// In contrast to [Writer.CreateHeader], the bytes passed to Writer are not compressed.
//
// CreateRaw's argument is stored in w. If the argument is a pointer to the embedded
// [FileHeader] in a [File] obtained from a [Reader] created from in-memory data,
// then w will refer to all of that memory.
func (w *Writer) CreateRaw(fh *FileHeader) (io.Writer, error) {
if err := w.prepare(fh); err != nil {
return nil, err
}
fh.CompressedSize = uint32(min(fh.CompressedSize64, uint32max))
fh.UncompressedSize = uint32(min(fh.UncompressedSize64, uint32max))
h := &header{
FileHeader: fh,
offset: uint64(w.cw.count),
raw: true,
}
w.dir = append(w.dir, h)
if err := writeHeader(w.cw, h); err != nil {
return nil, err
}
if strings.HasSuffix(fh.Name, "/") {
w.last = nil
return dirWriter{}, nil
}
fw := &fileWriter{
header: h,
zipw: w.cw,
}
w.last = fw
return fw, nil
}
// Copy copies the file f (obtained from a [Reader]) into w. It copies the raw
// form directly bypassing decompression, compression, and validation.
func (w *Writer) Copy(f *File) error {
r, err := f.OpenRaw()
if err != nil {
return err
}
// Copy the FileHeader so w doesn't store a pointer to the data
// of f's entire archive. See #65499.
fh := f.FileHeader
fw, err := w.CreateRaw(&fh)
if err != nil {
return err
}
_, err = io.Copy(fw, r)
return err
}
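// Illustrative sketch (not part of the original file, assuming a client
// package with zr (*zip.Reader) and zw (*zip.Writer)): Copy makes it cheap to
// filter an existing archive without recompressing the entries that are kept.
//
//	for _, f := range zr.File {
//		if strings.HasSuffix(f.Name, ".tmp") {
//			continue // drop temporary files
//		}
//		err := zw.Copy(f) // raw copy, no decompression or recompression
//	}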
// RegisterCompressor registers or overrides a custom compressor for a specific
// method ID. If a compressor for a given method is not found, [Writer] will
// default to looking up the compressor at the package level.
func (w *Writer) RegisterCompressor(method uint16, comp Compressor) {
if w.compressors == nil {
w.compressors = make(map[uint16]Compressor)
}
w.compressors[method] = comp
}
// AddFS adds the files from fs.FS to the archive.
// It walks the directory tree starting at the root of the filesystem,
// adding each file to the zip using deflate while maintaining the directory structure.
func (w *Writer) AddFS(fsys fs.FS) error {
return fs.WalkDir(fsys, ".", func(name string, d fs.DirEntry, err error) error {
if err != nil {
return err
}
if name == "." {
return nil
}
info, err := d.Info()
if err != nil {
return err
}
if !d.IsDir() && !info.Mode().IsRegular() {
return errors.New("zip: cannot add non-regular file")
}
h, err := FileInfoHeader(info)
if err != nil {
return err
}
h.Name = name
if d.IsDir() {
h.Name += "/"
}
h.Method = Deflate
fw, err := w.CreateHeader(h)
if err != nil {
return err
}
if d.IsDir() {
return nil
}
f, err := fsys.Open(name)
if err != nil {
return err
}
defer f.Close()
_, err = io.Copy(fw, f)
return err
})
}
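// Illustrative sketch (not part of the original file, assuming a client
// package): AddFS can archive a whole directory tree in one call.
//
//	zw := zip.NewWriter(out) // out is e.g. an *os.File
//	err := zw.AddFS(os.DirFS("./site"))
//	err = zw.Close()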
func (w *Writer) compressor(method uint16) Compressor {
comp := w.compressors[method]
if comp == nil {
comp = compressor(method)
}
return comp
}
type dirWriter struct{}
func (dirWriter) Write(b []byte) (int, error) {
if len(b) == 0 {
return 0, nil
}
return 0, errors.New("zip: write to directory")
}
type fileWriter struct {
*header
zipw io.Writer
rawCount *countWriter
comp io.WriteCloser
compCount *countWriter
crc32 hash.Hash32
closed bool
}
func (w *fileWriter) Write(p []byte) (int, error) {
if w.closed {
return 0, errors.New("zip: write to closed file")
}
if w.raw {
return w.zipw.Write(p)
}
w.crc32.Write(p)
return w.rawCount.Write(p)
}
func (w *fileWriter) close() error {
if w.closed {
return errors.New("zip: file closed twice")
}
w.closed = true
if w.raw {
return w.writeDataDescriptor()
}
if err := w.comp.Close(); err != nil {
return err
}
// update FileHeader
fh := w.header.FileHeader
fh.CRC32 = w.crc32.Sum32()
fh.CompressedSize64 = uint64(w.compCount.count)
fh.UncompressedSize64 = uint64(w.rawCount.count)
if fh.isZip64() {
fh.CompressedSize = uint32max
fh.UncompressedSize = uint32max
fh.ReaderVersion = zipVersion45 // requires 4.5 - File uses ZIP64 format extensions
} else {
fh.CompressedSize = uint32(fh.CompressedSize64)
fh.UncompressedSize = uint32(fh.UncompressedSize64)
}
return w.writeDataDescriptor()
}
func (w *fileWriter) writeDataDescriptor() error {
if !w.hasDataDescriptor() {
return nil
}
// Write data descriptor. This is more complicated than one would
// think, see e.g. comments in zipfile.c:putextended() and
// https://bugs.openjdk.org/browse/JDK-7073588.
// The approach here is to write 8 byte sizes if needed without
// adding a zip64 extra in the local header (too late anyway).
var buf []byte
if w.isZip64() {
buf = make([]byte, dataDescriptor64Len)
} else {
buf = make([]byte, dataDescriptorLen)
}
b := writeBuf(buf)
b.uint32(dataDescriptorSignature) // de-facto standard, required by OS X
b.uint32(w.CRC32)
if w.isZip64() {
b.uint64(w.CompressedSize64)
b.uint64(w.UncompressedSize64)
} else {
b.uint32(w.CompressedSize)
b.uint32(w.UncompressedSize)
}
_, err := w.zipw.Write(buf)
return err
}
type countWriter struct {
w io.Writer
count int64
}
func (w *countWriter) Write(p []byte) (int, error) {
n, err := w.w.Write(p)
w.count += int64(n)
return n, err
}
type nopCloser struct {
io.Writer
}
func (w nopCloser) Close() error {
return nil
}
type writeBuf []byte
func (b *writeBuf) uint8(v uint8) {
(*b)[0] = v
*b = (*b)[1:]
}
func (b *writeBuf) uint16(v uint16) {
binary.LittleEndian.PutUint16(*b, v)
*b = (*b)[2:]
}
func (b *writeBuf) uint32(v uint32) {
binary.LittleEndian.PutUint32(*b, v)
*b = (*b)[4:]
}
func (b *writeBuf) uint64(v uint64) {
binary.LittleEndian.PutUint64(*b, v)
*b = (*b)[8:]
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package bufio implements buffered I/O. It wraps an io.Reader or io.Writer
// object, creating another object (Reader or Writer) that also implements
// the interface but provides buffering and some help for textual I/O.
package bufio
import (
"bytes"
"errors"
"io"
"strings"
"unicode/utf8"
)
const (
defaultBufSize = 4096
)
var (
ErrInvalidUnreadByte = errors.New("bufio: invalid use of UnreadByte")
ErrInvalidUnreadRune = errors.New("bufio: invalid use of UnreadRune")
ErrBufferFull = errors.New("bufio: buffer full")
ErrNegativeCount = errors.New("bufio: negative count")
)
// Buffered input.
// Reader implements buffering for an io.Reader object.
// A new Reader is created by calling [NewReader] or [NewReaderSize];
// alternatively the zero value of a Reader may be used after calling [Reset]
// on it.
type Reader struct {
buf []byte
rd io.Reader // reader provided by the client
r, w int // buf read and write positions
err error
lastByte int // last byte read for UnreadByte; -1 means invalid
lastRuneSize int // size of last rune read for UnreadRune; -1 means invalid
}
const minReadBufferSize = 16
const maxConsecutiveEmptyReads = 100
// NewReaderSize returns a new [Reader] whose buffer has at least the specified
// size. If the argument io.Reader is already a [Reader] with large enough
// size, it returns the underlying [Reader].
func NewReaderSize(rd io.Reader, size int) *Reader {
// Is it already a Reader?
b, ok := rd.(*Reader)
if ok && len(b.buf) >= size {
return b
}
r := new(Reader)
r.reset(make([]byte, max(size, minReadBufferSize)), rd)
return r
}
// NewReader returns a new [Reader] whose buffer has the default size.
func NewReader(rd io.Reader) *Reader {
return NewReaderSize(rd, defaultBufSize)
}
// Size returns the size of the underlying buffer in bytes.
func (b *Reader) Size() int { return len(b.buf) }
// Reset discards any buffered data, resets all state, and switches
// the buffered reader to read from r.
// Calling Reset on the zero value of [Reader] initializes the internal buffer
// to the default size.
// Calling b.Reset(b) (that is, resetting a [Reader] to itself) does nothing.
func (b *Reader) Reset(r io.Reader) {
// If a Reader r is passed to NewReader, NewReader will return r.
// Different layers of code may do that, and then later pass r
// to Reset. Avoid infinite recursion in that case.
if b == r {
return
}
if b.buf == nil {
b.buf = make([]byte, defaultBufSize)
}
b.reset(b.buf, r)
}
func (b *Reader) reset(buf []byte, r io.Reader) {
*b = Reader{
buf: buf,
rd: r,
lastByte: -1,
lastRuneSize: -1,
}
}
var errNegativeRead = errors.New("bufio: reader returned negative count from Read")
// fill reads a new chunk into the buffer.
func (b *Reader) fill() {
// Slide existing data to beginning.
if b.r > 0 {
copy(b.buf, b.buf[b.r:b.w])
b.w -= b.r
b.r = 0
}
if b.w >= len(b.buf) {
panic("bufio: tried to fill full buffer")
}
// Read new data: try a limited number of times.
for i := maxConsecutiveEmptyReads; i > 0; i-- {
n, err := b.rd.Read(b.buf[b.w:])
if n < 0 {
panic(errNegativeRead)
}
b.w += n
if err != nil {
b.err = err
return
}
if n > 0 {
return
}
}
b.err = io.ErrNoProgress
}
func (b *Reader) readErr() error {
err := b.err
b.err = nil
return err
}
// Peek returns the next n bytes without advancing the reader. The bytes stop
// being valid at the next read call. If necessary, Peek will read more bytes
// into the buffer in order to make n bytes available. If Peek returns fewer
// than n bytes, it also returns an error explaining why the read is short.
// The error is [ErrBufferFull] if n is larger than b's buffer size.
//
// Calling Peek prevents a [Reader.UnreadByte] or [Reader.UnreadRune] call from succeeding
// until the next read operation.
func (b *Reader) Peek(n int) ([]byte, error) {
if n < 0 {
return nil, ErrNegativeCount
}
b.lastByte = -1
b.lastRuneSize = -1
for b.w-b.r < n && b.w-b.r < len(b.buf) && b.err == nil {
b.fill() // b.w-b.r < len(b.buf) => buffer is not full
}
if n > len(b.buf) {
return b.buf[b.r:b.w], ErrBufferFull
}
// 0 <= n <= len(b.buf)
var err error
if avail := b.w - b.r; avail < n {
// not enough data in buffer
n = avail
err = b.readErr()
if err == nil {
err = ErrBufferFull
}
}
return b.buf[b.r : b.r+n], err
}
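// Illustrative sketch (not part of the original file, assuming a client
// package): Peek lets a parser look at upcoming bytes without consuming them.
//
//	br := bufio.NewReader(strings.NewReader("GET / HTTP/1.1\r\n"))
//	p, _ := br.Peek(3)             // p == []byte("GET"); the reader has not advanced
//	line, _ := br.ReadString('\n') // still returns the full request line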
// Discard skips the next n bytes, returning the number of bytes discarded.
//
// If Discard skips fewer than n bytes, it also returns an error.
// If 0 <= n <= b.Buffered(), Discard is guaranteed to succeed without
// reading from the underlying io.Reader.
func (b *Reader) Discard(n int) (discarded int, err error) {
if n < 0 {
return 0, ErrNegativeCount
}
if n == 0 {
return
}
b.lastByte = -1
b.lastRuneSize = -1
remain := n
for {
skip := b.Buffered()
if skip == 0 {
b.fill()
skip = b.Buffered()
}
if skip > remain {
skip = remain
}
b.r += skip
remain -= skip
if remain == 0 {
return n, nil
}
if b.err != nil {
return n - remain, b.readErr()
}
}
}
// Read reads data into p.
// It returns the number of bytes read into p.
// The bytes are taken from at most one Read on the underlying [Reader],
// hence n may be less than len(p).
// To read exactly len(p) bytes, use io.ReadFull(b, p).
// If the underlying [Reader] can return a non-zero count with io.EOF,
// then this Read method can do so as well; see the [io.Reader] docs.
func (b *Reader) Read(p []byte) (n int, err error) {
n = len(p)
if n == 0 {
if b.Buffered() > 0 {
return 0, nil
}
return 0, b.readErr()
}
if b.r == b.w {
if b.err != nil {
return 0, b.readErr()
}
if len(p) >= len(b.buf) {
// Large read, empty buffer.
// Read directly into p to avoid copy.
n, b.err = b.rd.Read(p)
if n < 0 {
panic(errNegativeRead)
}
if n > 0 {
b.lastByte = int(p[n-1])
b.lastRuneSize = -1
}
return n, b.readErr()
}
// One read.
// Do not use b.fill, which will loop.
b.r = 0
b.w = 0
n, b.err = b.rd.Read(b.buf)
if n < 0 {
panic(errNegativeRead)
}
if n == 0 {
return 0, b.readErr()
}
b.w += n
}
// copy as much as we can
// Note: if the slice panics here, it is probably because
// the underlying reader returned a bad count. See issue 49795.
n = copy(p, b.buf[b.r:b.w])
b.r += n
b.lastByte = int(b.buf[b.r-1])
b.lastRuneSize = -1
return n, nil
}
// ReadByte reads and returns a single byte.
// If no byte is available, it returns an error.
func (b *Reader) ReadByte() (byte, error) {
b.lastRuneSize = -1
for b.r == b.w {
if b.err != nil {
return 0, b.readErr()
}
b.fill() // buffer is empty
}
c := b.buf[b.r]
b.r++
b.lastByte = int(c)
return c, nil
}
// UnreadByte unreads the last byte. Only the most recently read byte can be unread.
//
// UnreadByte returns an error if the most recent method called on the
// [Reader] was not a read operation. Notably, [Reader.Peek], [Reader.Discard], and [Reader.WriteTo] are not
// considered read operations.
func (b *Reader) UnreadByte() error {
if b.lastByte < 0 || b.r == 0 && b.w > 0 {
return ErrInvalidUnreadByte
}
// b.r > 0 || b.w == 0
if b.r > 0 {
b.r--
} else {
// b.r == 0 && b.w == 0
b.w = 1
}
b.buf[b.r] = byte(b.lastByte)
b.lastByte = -1
b.lastRuneSize = -1
return nil
}
// ReadRune reads a single UTF-8 encoded Unicode character and returns the
// rune and its size in bytes. If the encoded rune is invalid, it consumes one byte
// and returns unicode.ReplacementChar (U+FFFD) with a size of 1.
func (b *Reader) ReadRune() (r rune, size int, err error) {
for b.r+utf8.UTFMax > b.w && !utf8.FullRune(b.buf[b.r:b.w]) && b.err == nil && b.w-b.r < len(b.buf) {
b.fill() // b.w-b.r < len(buf) => buffer is not full
}
b.lastRuneSize = -1
if b.r == b.w {
return 0, 0, b.readErr()
}
r, size = rune(b.buf[b.r]), 1
if r >= utf8.RuneSelf {
r, size = utf8.DecodeRune(b.buf[b.r:b.w])
}
b.r += size
b.lastByte = int(b.buf[b.r-1])
b.lastRuneSize = size
return r, size, nil
}
// UnreadRune unreads the last rune. If the most recent method called on
// the [Reader] was not a [Reader.ReadRune], [Reader.UnreadRune] returns an error. (In this
// regard it is stricter than [Reader.UnreadByte], which will unread the last byte
// from any read operation.)
func (b *Reader) UnreadRune() error {
if b.lastRuneSize < 0 || b.r < b.lastRuneSize {
return ErrInvalidUnreadRune
}
b.r -= b.lastRuneSize
b.lastByte = -1
b.lastRuneSize = -1
return nil
}
// Buffered returns the number of bytes that can be read from the current buffer.
func (b *Reader) Buffered() int { return b.w - b.r }
// ReadSlice reads until the first occurrence of delim in the input,
// returning a slice pointing at the bytes in the buffer.
// The bytes stop being valid at the next read.
// If ReadSlice encounters an error before finding a delimiter,
// it returns all the data in the buffer and the error itself (often io.EOF).
// ReadSlice fails with error [ErrBufferFull] if the buffer fills without a delim.
// Because the data returned from ReadSlice will be overwritten
// by the next I/O operation, most clients should use
// [Reader.ReadBytes] or [Reader.ReadString] instead.
// ReadSlice returns err != nil if and only if line does not end in delim.
func (b *Reader) ReadSlice(delim byte) (line []byte, err error) {
s := 0 // search start index
for {
// Search buffer.
if i := bytes.IndexByte(b.buf[b.r+s:b.w], delim); i >= 0 {
i += s
line = b.buf[b.r : b.r+i+1]
b.r += i + 1
break
}
// Pending error?
if b.err != nil {
line = b.buf[b.r:b.w]
b.r = b.w
err = b.readErr()
break
}
// Buffer full?
if b.Buffered() >= len(b.buf) {
b.r = b.w
line = b.buf
err = ErrBufferFull
break
}
s = b.w - b.r // do not rescan area we scanned before
b.fill() // buffer is not full
}
// Handle last byte, if any.
if i := len(line) - 1; i >= 0 {
b.lastByte = int(line[i])
b.lastRuneSize = -1
}
return
}
// ReadLine is a low-level line-reading primitive. Most callers should use
// [Reader.ReadBytes]('\n') or [Reader.ReadString]('\n') instead or use a [Scanner].
//
// ReadLine tries to return a single line, not including the end-of-line bytes.
// If the line was too long for the buffer then isPrefix is set and the
// beginning of the line is returned. The rest of the line will be returned
// from future calls. isPrefix will be false when returning the last fragment
// of the line. The returned buffer is only valid until the next call to
// ReadLine. ReadLine either returns a non-nil line or it returns an error,
// never both.
//
// The text returned from ReadLine does not include the line end ("\r\n" or "\n").
// No indication or error is given if the input ends without a final line end.
// Calling [Reader.UnreadByte] after ReadLine will always unread the last byte read
// (possibly a character belonging to the line end) even if that byte is not
// part of the line returned by ReadLine.
func (b *Reader) ReadLine() (line []byte, isPrefix bool, err error) {
line, err = b.ReadSlice('\n')
if err == ErrBufferFull {
// Handle the case where "\r\n" straddles the buffer.
if len(line) > 0 && line[len(line)-1] == '\r' {
// Put the '\r' back on buf and drop it from line.
// Let the next call to ReadLine check for "\r\n".
if b.r == 0 {
// should be unreachable
panic("bufio: tried to rewind past start of buffer")
}
b.r--
line = line[:len(line)-1]
}
return line, true, nil
}
if len(line) == 0 {
if err != nil {
line = nil
}
return
}
err = nil
if line[len(line)-1] == '\n' {
drop := 1
if len(line) > 1 && line[len(line)-2] == '\r' {
drop = 2
}
line = line[:len(line)-drop]
}
return
}
// collectFragments reads until the first occurrence of delim in the input. It
// returns (slice of full buffers, remaining bytes before delim, total number
// of bytes in the combined first two elements, error).
// The complete result is equal to
// `bytes.Join(append(fullBuffers, finalFragment), nil)`, which has a
// length of `totalLen`. The result is structured in this way to allow callers
// to minimize allocations and copies.
func (b *Reader) collectFragments(delim byte) (fullBuffers [][]byte, finalFragment []byte, totalLen int, err error) {
var frag []byte
// Use ReadSlice to look for delim, accumulating full buffers.
for {
var e error
frag, e = b.ReadSlice(delim)
if e == nil { // got final fragment
break
}
if e != ErrBufferFull { // unexpected error
err = e
break
}
// Make a copy of the buffer.
buf := bytes.Clone(frag)
fullBuffers = append(fullBuffers, buf)
totalLen += len(buf)
}
totalLen += len(frag)
return fullBuffers, frag, totalLen, err
}
// ReadBytes reads until the first occurrence of delim in the input,
// returning a slice containing the data up to and including the delimiter.
// If ReadBytes encounters an error before finding a delimiter,
// it returns the data read before the error and the error itself (often io.EOF).
// ReadBytes returns err != nil if and only if the returned data does not end in
// delim.
// For simple uses, a Scanner may be more convenient.
func (b *Reader) ReadBytes(delim byte) ([]byte, error) {
full, frag, n, err := b.collectFragments(delim)
// Allocate new buffer to hold the full pieces and the fragment.
buf := make([]byte, n)
n = 0
// Copy full pieces and fragment in.
for i := range full {
n += copy(buf[n:], full[i])
}
copy(buf[n:], frag)
return buf, err
}
// ReadString reads until the first occurrence of delim in the input,
// returning a string containing the data up to and including the delimiter.
// If ReadString encounters an error before finding a delimiter,
// it returns the data read before the error and the error itself (often io.EOF).
// ReadString returns err != nil if and only if the returned data does not end in
// delim.
// For simple uses, a Scanner may be more convenient.
func (b *Reader) ReadString(delim byte) (string, error) {
full, frag, n, err := b.collectFragments(delim)
// Allocate new buffer to hold the full pieces and the fragment.
var buf strings.Builder
buf.Grow(n)
// Copy full pieces and fragment in.
for _, fb := range full {
buf.Write(fb)
}
buf.Write(frag)
return buf.String(), err
}
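// Illustrative sketch (not part of the original file, assuming a client
// package): a common line-at-a-time loop over an io.Reader r.
//
//	br := bufio.NewReader(r)
//	for {
//		line, err := br.ReadString('\n') // line includes the '\n' unless err != nil
//		process(line)                    // process is a hypothetical callback
//		if err != nil {
//			break // io.EOF once the input is exhausted
//		}
//	}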
// WriteTo implements io.WriterTo.
// This may make multiple calls to the [Reader.Read] method of the underlying [Reader].
// If the underlying reader implements [io.WriterTo],
// this calls its WriteTo method directly, without buffering.
func (b *Reader) WriteTo(w io.Writer) (n int64, err error) {
b.lastByte = -1
b.lastRuneSize = -1
if b.r < b.w {
n, err = b.writeBuf(w)
if err != nil {
return
}
}
if r, ok := b.rd.(io.WriterTo); ok {
m, err := r.WriteTo(w)
n += m
return n, err
}
if w, ok := w.(io.ReaderFrom); ok {
m, err := w.ReadFrom(b.rd)
n += m
return n, err
}
if b.w-b.r < len(b.buf) {
b.fill() // buffer not full
}
for b.r < b.w {
// b.r < b.w => buffer is not empty
m, err := b.writeBuf(w)
n += m
if err != nil {
return n, err
}
b.fill() // buffer is empty
}
if b.err == io.EOF {
b.err = nil
}
return n, b.readErr()
}
var errNegativeWrite = errors.New("bufio: writer returned negative count from Write")
// writeBuf writes the [Reader]'s buffer to the writer.
func (b *Reader) writeBuf(w io.Writer) (int64, error) {
n, err := w.Write(b.buf[b.r:b.w])
if n < 0 {
panic(errNegativeWrite)
}
b.r += n
return int64(n), err
}
// buffered output
// Writer implements buffering for an [io.Writer] object.
// If an error occurs writing to a [Writer], no more data will be
// accepted and all subsequent writes, and [Writer.Flush], will return the error.
// After all data has been written, the client should call the
// [Writer.Flush] method to guarantee all data has been forwarded to
// the underlying [io.Writer].
type Writer struct {
err error
buf []byte
n int
wr io.Writer
}
// NewWriterSize returns a new [Writer] whose buffer has at least the specified
// size. If the argument io.Writer is already a [Writer] with large enough
// size, it returns the underlying [Writer].
func NewWriterSize(w io.Writer, size int) *Writer {
// Is it already a Writer?
b, ok := w.(*Writer)
if ok && len(b.buf) >= size {
return b
}
if size <= 0 {
size = defaultBufSize
}
return &Writer{
buf: make([]byte, size),
wr: w,
}
}
// NewWriter returns a new [Writer] whose buffer has the default size.
// If the argument io.Writer is already a [Writer] with large enough buffer size,
// it returns the underlying [Writer].
func NewWriter(w io.Writer) *Writer {
return NewWriterSize(w, defaultBufSize)
}
// Size returns the size of the underlying buffer in bytes.
func (b *Writer) Size() int { return len(b.buf) }
// Reset discards any unflushed buffered data, clears any error, and
// resets b to write its output to w.
// Calling Reset on the zero value of [Writer] initializes the internal buffer
// to the default size.
// Calling w.Reset(w) (that is, resetting a [Writer] to itself) does nothing.
func (b *Writer) Reset(w io.Writer) {
// If a Writer w is passed to NewWriter, NewWriter will return w.
// Different layers of code may do that, and then later pass w
// to Reset. Avoid infinite recursion in that case.
if b == w {
return
}
if b.buf == nil {
b.buf = make([]byte, defaultBufSize)
}
b.err = nil
b.n = 0
b.wr = w
}
// Flush writes any buffered data to the underlying [io.Writer].
func (b *Writer) Flush() error {
if b.err != nil {
return b.err
}
if b.n == 0 {
return nil
}
n, err := b.wr.Write(b.buf[0:b.n])
if n < b.n && err == nil {
err = io.ErrShortWrite
}
if err != nil {
if n > 0 && n < b.n {
copy(b.buf[0:b.n-n], b.buf[n:b.n])
}
b.n -= n
b.err = err
return err
}
b.n = 0
return nil
}
// Available returns how many bytes are unused in the buffer.
func (b *Writer) Available() int { return len(b.buf) - b.n }
// AvailableBuffer returns an empty buffer with b.Available() capacity.
// This buffer is intended to be appended to and
// passed to an immediately succeeding [Writer.Write] call.
// The buffer is only valid until the next write operation on b.
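//
// A sketch of the intended append-then-write pattern (b is a *Writer and
// i an integer from the surrounding code):
//
//	buf := b.AvailableBuffer()
//	buf = strconv.AppendInt(buf, int64(i), 10)
//	buf = append(buf, ' ')
//	b.Write(buf)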
func (b *Writer) AvailableBuffer() []byte {
return b.buf[b.n:][:0]
}
// Buffered returns the number of bytes that have been written into the current buffer.
func (b *Writer) Buffered() int { return b.n }
// Write writes the contents of p into the buffer.
// It returns the number of bytes written.
// If nn < len(p), it also returns an error explaining
// why the write is short.
func (b *Writer) Write(p []byte) (nn int, err error) {
for len(p) > b.Available() && b.err == nil {
var n int
if b.Buffered() == 0 {
// Large write, empty buffer.
// Write directly from p to avoid copy.
n, b.err = b.wr.Write(p)
} else {
n = copy(b.buf[b.n:], p)
b.n += n
b.Flush()
}
nn += n
p = p[n:]
}
if b.err != nil {
return nn, b.err
}
n := copy(b.buf[b.n:], p)
b.n += n
nn += n
return nn, nil
}
// WriteByte writes a single byte.
func (b *Writer) WriteByte(c byte) error {
if b.err != nil {
return b.err
}
if b.Available() <= 0 && b.Flush() != nil {
return b.err
}
b.buf[b.n] = c
b.n++
return nil
}
// WriteRune writes a single Unicode code point, returning
// the number of bytes written and any error.
func (b *Writer) WriteRune(r rune) (size int, err error) {
// Compare as uint32 to correctly handle negative runes.
if uint32(r) < utf8.RuneSelf {
err = b.WriteByte(byte(r))
if err != nil {
return 0, err
}
return 1, nil
}
if b.err != nil {
return 0, b.err
}
n := b.Available()
if n < utf8.UTFMax {
if b.Flush(); b.err != nil {
return 0, b.err
}
n = b.Available()
if n < utf8.UTFMax {
// Can only happen if buffer is silly small.
return b.WriteString(string(r))
}
}
size = utf8.EncodeRune(b.buf[b.n:], r)
b.n += size
return size, nil
}
// WriteString writes a string.
// It returns the number of bytes written.
// If the count is less than len(s), it also returns an error explaining
// why the write is short.
func (b *Writer) WriteString(s string) (int, error) {
var sw io.StringWriter
tryStringWriter := true
nn := 0
for len(s) > b.Available() && b.err == nil {
var n int
if b.Buffered() == 0 && sw == nil && tryStringWriter {
// Check at most once whether b.wr is a StringWriter.
sw, tryStringWriter = b.wr.(io.StringWriter)
}
if b.Buffered() == 0 && tryStringWriter {
// Large write, empty buffer, and the underlying writer supports
// WriteString: forward the write to the underlying StringWriter.
// This avoids an extra copy.
n, b.err = sw.WriteString(s)
} else {
n = copy(b.buf[b.n:], s)
b.n += n
b.Flush()
}
nn += n
s = s[n:]
}
if b.err != nil {
return nn, b.err
}
n := copy(b.buf[b.n:], s)
b.n += n
nn += n
return nn, nil
}
// ReadFrom implements [io.ReaderFrom]. If the underlying writer
// supports the ReadFrom method, this calls the underlying ReadFrom.
// If there is buffered data and an underlying ReadFrom, this fills
// the buffer and writes it before calling ReadFrom.
func (b *Writer) ReadFrom(r io.Reader) (n int64, err error) {
if b.err != nil {
return 0, b.err
}
readerFrom, readerFromOK := b.wr.(io.ReaderFrom)
var m int
for {
if b.Available() == 0 {
if err1 := b.Flush(); err1 != nil {
return n, err1
}
}
if readerFromOK && b.Buffered() == 0 {
nn, err := readerFrom.ReadFrom(r)
b.err = err
n += nn
return n, err
}
nr := 0
for nr < maxConsecutiveEmptyReads {
m, err = r.Read(b.buf[b.n:])
if m != 0 || err != nil {
break
}
nr++
}
if nr == maxConsecutiveEmptyReads {
return n, io.ErrNoProgress
}
b.n += m
n += int64(m)
if err != nil {
break
}
}
if err == io.EOF {
// If we filled the buffer exactly, flush preemptively.
if b.Available() == 0 {
err = b.Flush()
} else {
err = nil
}
}
return n, err
}
// buffered input and output
// ReadWriter stores pointers to a [Reader] and a [Writer].
// It implements [io.ReadWriter].
type ReadWriter struct {
*Reader
*Writer
}
// NewReadWriter allocates a new [ReadWriter] that dispatches to r and w.
func NewReadWriter(r *Reader, w *Writer) *ReadWriter {
return &ReadWriter{r, w}
}
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package bufio
import (
"bytes"
"errors"
"io"
"unicode/utf8"
)
// Scanner provides a convenient interface for reading data such as
// a file of newline-delimited lines of text. Successive calls to
// the [Scanner.Scan] method will step through the 'tokens' of a file, skipping
// the bytes between the tokens. The specification of a token is
// defined by a split function of type [SplitFunc]; the default split
// function breaks the input into lines with line termination stripped. [Scanner.Split]
// functions are defined in this package for scanning a file into
// lines, bytes, UTF-8-encoded runes, and space-delimited words. The
// client may instead provide a custom split function.
//
// Scanning stops unrecoverably at EOF, the first I/O error, or a token too
// large to fit in the [Scanner.Buffer]. When a scan stops, the reader may have
// advanced arbitrarily far past the last token. Programs that need more
// control over error handling or large tokens, or must run sequential scans
// on a reader, should use [bufio.Reader] instead.
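//
// As an illustration (a minimal sketch; only the final Err check handles
// errors), scanning standard input line by line:
//
//	scanner := bufio.NewScanner(os.Stdin)
//	for scanner.Scan() {
//		fmt.Println(scanner.Text()) // Println adds back the stripped newline.
//	}
//	if err := scanner.Err(); err != nil {
//		fmt.Fprintln(os.Stderr, "reading standard input:", err)
//	}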
type Scanner struct {
r io.Reader // The reader provided by the client.
split SplitFunc // The function to split the tokens.
maxTokenSize int // Maximum size of a token; modified by tests.
token []byte // Last token returned by split.
buf []byte // Buffer used as argument to split.
start int // First non-processed byte in buf.
end int // End of data in buf.
err error // Sticky error.
empties int // Count of successive empty tokens.
scanCalled bool // Scan has been called; buffer is in use.
done bool // Scan has finished.
}
// SplitFunc is the signature of the split function used to tokenize the
// input. The arguments are an initial substring of the remaining unprocessed
// data and a flag, atEOF, that reports whether the [Reader] has no more data
// to give. The return values are the number of bytes to advance the input
// and the next token to return to the user, if any, plus an error, if any.
//
// Scanning stops if the function returns an error, in which case some of
// the input may be discarded. If that error is [ErrFinalToken], scanning
// stops with no error. A non-nil token delivered with [ErrFinalToken]
// will be the last token, and a nil token with [ErrFinalToken]
// immediately stops the scanning.
//
// Otherwise, the [Scanner] advances the input. If the token is not nil,
// the [Scanner] returns it to the user. If the token is nil, the
// Scanner reads more data and continues scanning; if there is no more
// data--if atEOF was true--the [Scanner] returns. If the data does not
// yet hold a complete token, for instance if it has no newline while
// scanning lines, a [SplitFunc] can return (0, nil, nil) to signal the
// [Scanner] to read more data into the slice and try again with a
// longer slice starting at the same point in the input.
//
// The function is never called with an empty data slice unless atEOF
// is true. If atEOF is true, however, data may be non-empty and,
// as always, holds unprocessed text.
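//
// As a sketch, a custom split function can wrap an existing one; the
// hypothetical wrapper below runs [ScanWords] and then rejects any token
// that does not parse as a 32-bit integer. It would be installed with
// [Scanner.Split]:
//
//	split := func(data []byte, atEOF bool) (advance int, token []byte, err error) {
//		advance, token, err = bufio.ScanWords(data, atEOF)
//		if err == nil && token != nil {
//			_, err = strconv.ParseInt(string(token), 10, 32)
//		}
//		return
//	}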
type SplitFunc func(data []byte, atEOF bool) (advance int, token []byte, err error)
// Errors returned by Scanner.
var (
ErrTooLong = errors.New("bufio.Scanner: token too long")
ErrNegativeAdvance = errors.New("bufio.Scanner: SplitFunc returns negative advance count")
ErrAdvanceTooFar = errors.New("bufio.Scanner: SplitFunc returns advance count beyond input")
ErrBadReadCount = errors.New("bufio.Scanner: Read returned impossible count")
)
const (
// MaxScanTokenSize is the maximum size used to buffer a token
// unless the user provides an explicit buffer with [Scanner.Buffer].
// The actual maximum token size may be smaller as the buffer
// may need to include, for instance, a newline.
MaxScanTokenSize = 64 * 1024
startBufSize = 4096 // Size of initial allocation for buffer.
)
// NewScanner returns a new [Scanner] to read from r.
// The split function defaults to [ScanLines].
func NewScanner(r io.Reader) *Scanner {
return &Scanner{
r: r,
split: ScanLines,
maxTokenSize: MaxScanTokenSize,
}
}
// Err returns the first non-EOF error that was encountered by the [Scanner].
func (s *Scanner) Err() error {
if s.err == io.EOF {
return nil
}
return s.err
}
// Bytes returns the most recent token generated by a call to [Scanner.Scan].
// The underlying array may point to data that will be overwritten
// by a subsequent call to Scan. It does no allocation.
func (s *Scanner) Bytes() []byte {
return s.token
}
// Text returns the most recent token generated by a call to [Scanner.Scan]
// as a newly allocated string holding its bytes.
func (s *Scanner) Text() string {
return string(s.token)
}
// ErrFinalToken is a special sentinel error value. It is intended to be
// returned by a Split function to indicate that the scanning should stop
// with no error. If the token being delivered with this error is not nil,
// the token is the last token.
//
// The value is useful to stop processing early or when it is necessary to
// deliver a final empty token (which is different from a nil token).
// One could achieve the same behavior with a custom error value but
// providing one here is tidier.
// See the emptyFinalToken example for a use of this value.
var ErrFinalToken = errors.New("final token")
// Scan advances the [Scanner] to the next token, which will then be
// available through the [Scanner.Bytes] or [Scanner.Text] method. It returns false when
// there are no more tokens, either by reaching the end of the input or an error.
// After Scan returns false, the [Scanner.Err] method will return any error that
// occurred during scanning, except that if it was [io.EOF], [Scanner.Err]
// will return nil.
// Scan panics if the split function returns too many empty
// tokens without advancing the input. This is a common error mode for
// scanners.
func (s *Scanner) Scan() bool {
if s.done {
return false
}
s.scanCalled = true
// Loop until we have a token.
for {
// See if we can get a token with what we already have.
// If we've run out of data but have an error, give the split function
// a chance to recover any remaining, possibly empty token.
if s.end > s.start || s.err != nil {
advance, token, err := s.split(s.buf[s.start:s.end], s.err != nil)
if err != nil {
if err == ErrFinalToken {
s.token = token
s.done = true
// When token is not nil, it means the scanning stops
// with a trailing token, and thus the return value
// should be true to indicate the existence of the token.
return token != nil
}
s.setErr(err)
return false
}
if !s.advance(advance) {
return false
}
s.token = token
if token != nil {
if s.err == nil || advance > 0 {
s.empties = 0
} else {
// Returning tokens without advancing the input at EOF.
s.empties++
if s.empties > maxConsecutiveEmptyReads {
panic("bufio.Scan: too many empty tokens without progressing")
}
}
return true
}
}
// We cannot generate a token with what we are holding.
// If we've already hit EOF or an I/O error, we are done.
if s.err != nil {
// Shut it down.
s.start = 0
s.end = 0
return false
}
// Must read more data.
// First, shift data to beginning of buffer if there's lots of empty space
// or space is needed.
if s.start > 0 && (s.end == len(s.buf) || s.start > len(s.buf)/2) {
copy(s.buf, s.buf[s.start:s.end])
s.end -= s.start
s.start = 0
}
// Is the buffer full? If so, resize.
if s.end == len(s.buf) {
// Guarantee no overflow in the multiplication below.
const maxInt = int(^uint(0) >> 1)
if len(s.buf) >= s.maxTokenSize || len(s.buf) > maxInt/2 {
s.setErr(ErrTooLong)
return false
}
newSize := len(s.buf) * 2
if newSize == 0 {
newSize = startBufSize
}
newSize = min(newSize, s.maxTokenSize)
newBuf := make([]byte, newSize)
copy(newBuf, s.buf[s.start:s.end])
s.buf = newBuf
s.end -= s.start
s.start = 0
}
// Finally we can read some input. Make sure we don't get stuck with
// a misbehaving Reader. Officially we don't need to do this, but let's
// be extra careful: Scanner is for safe, simple jobs.
for loop := 0; ; {
n, err := s.r.Read(s.buf[s.end:len(s.buf)])
if n < 0 || len(s.buf)-s.end < n {
s.setErr(ErrBadReadCount)
break
}
s.end += n
if err != nil {
s.setErr(err)
break
}
if n > 0 {
s.empties = 0
break
}
loop++
if loop > maxConsecutiveEmptyReads {
s.setErr(io.ErrNoProgress)
break
}
}
}
}
// advance consumes n bytes of the buffer. It reports whether the advance was legal.
func (s *Scanner) advance(n int) bool {
if n < 0 {
s.setErr(ErrNegativeAdvance)
return false
}
if n > s.end-s.start {
s.setErr(ErrAdvanceTooFar)
return false
}
s.start += n
return true
}
// setErr records the first error encountered.
func (s *Scanner) setErr(err error) {
if s.err == nil || s.err == io.EOF {
s.err = err
}
}
// Buffer controls memory allocation by the Scanner.
// It sets the initial buffer to use when scanning
// and the maximum size of buffer that may be allocated during scanning.
// The contents of the buffer are ignored.
//
// The maximum token size must be less than the larger of max and cap(buf).
// If max <= cap(buf), [Scanner.Scan] will use this buffer only and do no allocation.
//
// By default, [Scanner.Scan] uses an internal buffer and sets the
// maximum token size to [MaxScanTokenSize].
//
// Buffer panics if it is called after scanning has started.
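//
// For example (an illustrative sketch, where r is any io.Reader), allowing
// tokens of up to 1 MiB while starting from a 64 KiB initial buffer:
//
//	scanner := bufio.NewScanner(r)
//	scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)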
func (s *Scanner) Buffer(buf []byte, max int) {
if s.scanCalled {
panic("Buffer called after Scan")
}
s.buf = buf[0:cap(buf)]
s.maxTokenSize = max
}
// Split sets the split function for the [Scanner].
// The default split function is [ScanLines].
//
// Split panics if it is called after scanning has started.
func (s *Scanner) Split(split SplitFunc) {
if s.scanCalled {
panic("Split called after Scan")
}
s.split = split
}
// Split functions
// ScanBytes is a split function for a [Scanner] that returns each byte as a token.
func ScanBytes(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
return 1, data[0:1], nil
}
var errorRune = []byte(string(utf8.RuneError))
// ScanRunes is a split function for a [Scanner] that returns each
// UTF-8-encoded rune as a token. The sequence of runes returned is
// equivalent to that from a range loop over the input as a string, which
// means that erroneous UTF-8 encodings translate to U+FFFD = "\xef\xbf\xbd".
// Because of the Scan interface, this makes it impossible for the client to
// distinguish correctly encoded replacement runes from encoding errors.
func ScanRunes(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
// Fast path 1: ASCII.
if data[0] < utf8.RuneSelf {
return 1, data[0:1], nil
}
// Fast path 2: Correct UTF-8 decode without error.
_, width := utf8.DecodeRune(data)
if width > 1 {
// It's a valid encoding. Width cannot be one for a correctly encoded
// non-ASCII rune.
return width, data[0:width], nil
}
// We know it's an error: we have width==1 and implicitly r==utf8.RuneError.
// Is the error because there wasn't a full rune to be decoded?
// FullRune distinguishes correctly between erroneous and incomplete encodings.
if !atEOF && !utf8.FullRune(data) {
// Incomplete; get more bytes.
return 0, nil, nil
}
// We have a real UTF-8 encoding error. Return a properly encoded error rune
// but advance only one byte. This matches the behavior of a range loop over
// an incorrectly encoded string.
return 1, errorRune, nil
}
// dropCR drops a terminal \r from the data.
func dropCR(data []byte) []byte {
if len(data) > 0 && data[len(data)-1] == '\r' {
return data[0 : len(data)-1]
}
return data
}
// ScanLines is a split function for a [Scanner] that returns each line of
// text, stripped of any trailing end-of-line marker. The returned line may
// be empty. The end-of-line marker is one optional carriage return followed
// by one mandatory newline. In regular expression notation, it is `\r?\n`.
// The last non-empty line of input will be returned even if it has no
// newline.
func ScanLines(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := bytes.IndexByte(data, '\n'); i >= 0 {
// We have a full newline-terminated line.
return i + 1, dropCR(data[0:i]), nil
}
// If we're at EOF, we have a final, non-terminated line. Return it.
if atEOF {
return len(data), dropCR(data), nil
}
// Request more data.
return 0, nil, nil
}
// isSpace reports whether the character is a Unicode white space character.
// We avoid dependency on the unicode package, but check validity of the implementation
// in the tests.
func isSpace(r rune) bool {
if r <= '\u00FF' {
// Obvious ASCII ones: \t through \r plus space. Plus two Latin-1 oddballs.
switch r {
case ' ', '\t', '\n', '\v', '\f', '\r':
return true
case '\u0085', '\u00A0':
return true
}
return false
}
// High-valued ones.
if '\u2000' <= r && r <= '\u200a' {
return true
}
switch r {
case '\u1680', '\u2028', '\u2029', '\u202f', '\u205f', '\u3000':
return true
}
return false
}
// ScanWords is a split function for a [Scanner] that returns each
// space-separated word of text, with surrounding spaces deleted. It will
// never return an empty string. The definition of space is set by
// unicode.IsSpace.
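//
// A small sketch of its use, counting the words in a string:
//
//	scanner := bufio.NewScanner(strings.NewReader("one two three four"))
//	scanner.Split(bufio.ScanWords)
//	count := 0
//	for scanner.Scan() {
//		count++
//	}
//	// count is now 4.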
func ScanWords(data []byte, atEOF bool) (advance int, token []byte, err error) {
// Skip leading spaces.
start := 0
for width := 0; start < len(data); start += width {
var r rune
r, width = utf8.DecodeRune(data[start:])
if !isSpace(r) {
break
}
}
// Scan until space, marking end of word.
for width, i := 0, start; i < len(data); i += width {
var r rune
r, width = utf8.DecodeRune(data[i:])
if isSpace(r) {
return i + width, data[start:i], nil
}
}
// If we're at EOF, we have a final, non-empty, non-terminated word. Return it.
if atEOF && len(data) > start {
return len(data), data[start:], nil
}
// Request more data.
return start, nil, nil
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package bzip2
import (
"bufio"
"io"
)
// bitReader wraps an io.Reader and provides the ability to read values,
// bit-by-bit, from it. Its Read* methods don't return the usual error
// because the error handling was verbose. Instead, any error is kept and can
// be checked afterwards.
type bitReader struct {
r io.ByteReader
n uint64
bits uint
err error
}
// newBitReader returns a new bitReader reading from r. If r is not
// already an io.ByteReader, it will be wrapped in a bufio.Reader.
func newBitReader(r io.Reader) bitReader {
byter, ok := r.(io.ByteReader)
if !ok {
byter = bufio.NewReader(r)
}
return bitReader{r: byter}
}
// ReadBits64 reads the given number of bits and returns them in the
// least-significant part of a uint64. In the event of an error, it returns 0
// and the error can be obtained by calling bitReader.Err().
func (br *bitReader) ReadBits64(bits uint) (n uint64) {
for bits > br.bits {
b, err := br.r.ReadByte()
if err == io.EOF {
err = io.ErrUnexpectedEOF
}
if err != nil {
br.err = err
return 0
}
br.n <<= 8
br.n |= uint64(b)
br.bits += 8
}
// br.n looks like this (assuming that br.bits = 14 and bits = 6):
// Bit: 111111
//      5432109876543210
//
//         (6 bits, the desired output)
//        |-----|
//        V     V
//      0101101101001110
//        ^            ^
//        |------------|
//           br.bits (num valid bits)
//
// The next line right shifts the desired bits into the
// least-significant places and masks off anything above.
n = (br.n >> (br.bits - bits)) & ((1 << bits) - 1)
br.bits -= bits
return
}
func (br *bitReader) ReadBits(bits uint) (n int) {
n64 := br.ReadBits64(bits)
return int(n64)
}
func (br *bitReader) ReadBit() bool {
n := br.ReadBits(1)
return n != 0
}
func (br *bitReader) Err() error {
return br.err
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package bzip2 implements bzip2 decompression.
package bzip2
import "io"
// There's no RFC for bzip2. I used the Wikipedia page for reference and a lot
// of guessing: https://en.wikipedia.org/wiki/Bzip2
// The source code to pyflate was useful for debugging:
// http://www.paul.sladen.org/projects/pyflate
// A StructuralError is returned when the bzip2 data is found to be
// syntactically invalid.
type StructuralError string
func (s StructuralError) Error() string {
return "bzip2 data invalid: " + string(s)
}
// A reader decompresses bzip2 compressed data.
type reader struct {
br bitReader
fileCRC uint32
blockCRC uint32
wantBlockCRC uint32
setupDone bool // true if we have parsed the bzip2 header.
eof bool
blockSize int // blockSize in bytes, i.e. 900 * 1000.
c [256]uint // the ``C'' array for the inverse BWT.
tt []uint32 // mirrors the ``tt'' array in the bzip2 source and contains the P array in the upper 24 bits.
tPos uint32 // Index of the next output byte in tt.
preRLE []uint32 // contains the RLE data still to be processed.
preRLEUsed int // number of entries of preRLE used.
lastByte int // the last byte value seen.
byteRepeats uint // the number of repeats of lastByte seen.
repeats uint // the number of copies of lastByte to output.
}
// NewReader returns an [io.Reader] which decompresses bzip2 data from r.
// If r does not also implement [io.ByteReader],
// the decompressor may read more data than necessary from r.
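//
// A minimal sketch of decompressing a bzip2 file (the file name is
// hypothetical; error handling omitted):
//
//	f, _ := os.Open("data.bz2")
//	defer f.Close()
//	io.Copy(os.Stdout, bzip2.NewReader(f))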
func NewReader(r io.Reader) io.Reader {
bz2 := new(reader)
bz2.br = newBitReader(r)
return bz2
}
const bzip2FileMagic = 0x425a // "BZ"
const bzip2BlockMagic = 0x314159265359
const bzip2FinalMagic = 0x177245385090
// setup parses the bzip2 header.
func (bz2 *reader) setup(needMagic bool) error {
br := &bz2.br
if needMagic {
magic := br.ReadBits(16)
if magic != bzip2FileMagic {
return StructuralError("bad magic value")
}
}
t := br.ReadBits(8)
if t != 'h' {
return StructuralError("non-Huffman entropy encoding")
}
level := br.ReadBits(8)
if level < '1' || level > '9' {
return StructuralError("invalid compression level")
}
bz2.fileCRC = 0
bz2.blockSize = 100 * 1000 * (level - '0')
if bz2.blockSize > len(bz2.tt) {
bz2.tt = make([]uint32, bz2.blockSize)
}
return nil
}
func (bz2 *reader) Read(buf []byte) (n int, err error) {
if bz2.eof {
return 0, io.EOF
}
if !bz2.setupDone {
err = bz2.setup(true)
brErr := bz2.br.Err()
if brErr != nil {
err = brErr
}
if err != nil {
return 0, err
}
bz2.setupDone = true
}
n, err = bz2.read(buf)
brErr := bz2.br.Err()
if brErr != nil {
err = brErr
}
return
}
func (bz2 *reader) readFromBlock(buf []byte) int {
// bzip2 is a block based compressor, except that it has a run-length
// preprocessing step. The block based nature means that we can
// preallocate fixed-size buffers and reuse them. However, the RLE
// preprocessing would require allocating huge buffers to store the
// maximum expansion. Thus we process blocks all at once, except for
// the RLE which we decompress as required.
n := 0
for (bz2.repeats > 0 || bz2.preRLEUsed < len(bz2.preRLE)) && n < len(buf) {
// We have RLE data pending.
// The run-length encoding works like this:
// Any sequence of four equal bytes is followed by a length
// byte which contains the number of repeats of that byte to
// include. (The number of repeats can be zero.) Because we are
// decompressing on-demand our state is kept in the reader
// object.
if bz2.repeats > 0 {
buf[n] = byte(bz2.lastByte)
n++
bz2.repeats--
if bz2.repeats == 0 {
bz2.lastByte = -1
}
continue
}
bz2.tPos = bz2.preRLE[bz2.tPos]
b := byte(bz2.tPos)
bz2.tPos >>= 8
bz2.preRLEUsed++
if bz2.byteRepeats == 3 {
bz2.repeats = uint(b)
bz2.byteRepeats = 0
continue
}
if bz2.lastByte == int(b) {
bz2.byteRepeats++
} else {
bz2.byteRepeats = 0
}
bz2.lastByte = int(b)
buf[n] = b
n++
}
return n
}
func (bz2 *reader) read(buf []byte) (int, error) {
for {
n := bz2.readFromBlock(buf)
if n > 0 || len(buf) == 0 {
bz2.blockCRC = updateCRC(bz2.blockCRC, buf[:n])
return n, nil
}
// End of block. Check CRC.
if bz2.blockCRC != bz2.wantBlockCRC {
bz2.br.err = StructuralError("block checksum mismatch")
return 0, bz2.br.err
}
// Find next block.
br := &bz2.br
switch br.ReadBits64(48) {
default:
return 0, StructuralError("bad magic value found")
case bzip2BlockMagic:
// Start of block.
err := bz2.readBlock()
if err != nil {
return 0, err
}
case bzip2FinalMagic:
// Check end-of-file CRC.
wantFileCRC := uint32(br.ReadBits64(32))
if br.err != nil {
return 0, br.err
}
if bz2.fileCRC != wantFileCRC {
br.err = StructuralError("file checksum mismatch")
return 0, br.err
}
// Skip ahead to byte boundary.
// Is there a file concatenated to this one?
// It would start with BZ.
if br.bits%8 != 0 {
br.ReadBits(br.bits % 8)
}
b, err := br.r.ReadByte()
if err == io.EOF {
br.err = io.EOF
bz2.eof = true
return 0, io.EOF
}
if err != nil {
br.err = err
return 0, err
}
z, err := br.r.ReadByte()
if err != nil {
if err == io.EOF {
err = io.ErrUnexpectedEOF
}
br.err = err
return 0, err
}
if b != 'B' || z != 'Z' {
return 0, StructuralError("bad magic value in continuation file")
}
if err := bz2.setup(false); err != nil {
return 0, err
}
}
}
}
// readBlock reads a bzip2 block. The magic number should already have been consumed.
func (bz2 *reader) readBlock() (err error) {
br := &bz2.br
bz2.wantBlockCRC = uint32(br.ReadBits64(32)) // the expected block CRC, checked against the computed CRC at the end of the block.
bz2.blockCRC = 0
bz2.fileCRC = (bz2.fileCRC<<1 | bz2.fileCRC>>31) ^ bz2.wantBlockCRC
randomized := br.ReadBits(1)
if randomized != 0 {
return StructuralError("deprecated randomized files")
}
origPtr := uint(br.ReadBits(24))
// If not every byte value is used in the block (i.e., it's text) then
// the symbol set is reduced. The symbols used are stored as a
// two-level, 16x16 bitmap.
symbolRangeUsedBitmap := br.ReadBits(16)
symbolPresent := make([]bool, 256)
numSymbols := 0
for symRange := uint(0); symRange < 16; symRange++ {
if symbolRangeUsedBitmap&(1<<(15-symRange)) != 0 {
bits := br.ReadBits(16)
for symbol := uint(0); symbol < 16; symbol++ {
if bits&(1<<(15-symbol)) != 0 {
symbolPresent[16*symRange+symbol] = true
numSymbols++
}
}
}
}
if numSymbols == 0 {
// There must be an EOF symbol.
return StructuralError("no symbols in input")
}
// A block uses between two and six different Huffman trees.
numHuffmanTrees := br.ReadBits(3)
if numHuffmanTrees < 2 || numHuffmanTrees > 6 {
return StructuralError("invalid number of Huffman trees")
}
// The Huffman tree can switch every 50 symbols so there's a list of
// tree indexes telling us which tree to use for each 50 symbol block.
numSelectors := br.ReadBits(15)
treeIndexes := make([]uint8, numSelectors)
// The tree indexes are move-to-front transformed and stored as unary
// numbers.
mtfTreeDecoder := newMTFDecoderWithRange(numHuffmanTrees)
for i := range treeIndexes {
c := 0
for {
inc := br.ReadBits(1)
if inc == 0 {
break
}
c++
}
if c >= numHuffmanTrees {
return StructuralError("tree index too large")
}
treeIndexes[i] = mtfTreeDecoder.Decode(c)
}
// The list of symbols for the move-to-front transform is taken from
// the previously decoded symbol bitmap.
symbols := make([]byte, numSymbols)
nextSymbol := 0
for i := 0; i < 256; i++ {
if symbolPresent[i] {
symbols[nextSymbol] = byte(i)
nextSymbol++
}
}
mtf := newMTFDecoder(symbols)
numSymbols += 2 // to account for RUNA and RUNB symbols
huffmanTrees := make([]huffmanTree, numHuffmanTrees)
// Now we decode the arrays of code-lengths for each tree.
lengths := make([]uint8, numSymbols)
for i := range huffmanTrees {
// The code lengths are delta encoded from a 5-bit base value.
length := br.ReadBits(5)
for j := range lengths {
for {
if length < 1 || length > 20 {
return StructuralError("Huffman length out of range")
}
if !br.ReadBit() {
break
}
if br.ReadBit() {
length--
} else {
length++
}
}
lengths[j] = uint8(length)
}
huffmanTrees[i], err = newHuffmanTree(lengths)
if err != nil {
return err
}
}
selectorIndex := 1 // the next tree index to use
if len(treeIndexes) == 0 {
return StructuralError("no tree selectors given")
}
if int(treeIndexes[0]) >= len(huffmanTrees) {
return StructuralError("tree selector out of range")
}
currentHuffmanTree := huffmanTrees[treeIndexes[0]]
bufIndex := 0 // indexes bz2.buf, the output buffer.
// The output of the move-to-front transform is run-length encoded and
// we merge the decoding into the Huffman parsing loop. These two
// variables accumulate the repeat count. See the Wikipedia page for
// details.
repeat := 0
repeatPower := 0
// The `C' array (used by the inverse BWT) needs to be zero initialized.
clear(bz2.c[:])
decoded := 0 // counts the number of symbols decoded by the current tree.
for {
if decoded == 50 {
if selectorIndex >= numSelectors {
return StructuralError("insufficient selector indices for number of symbols")
}
if int(treeIndexes[selectorIndex]) >= len(huffmanTrees) {
return StructuralError("tree selector out of range")
}
currentHuffmanTree = huffmanTrees[treeIndexes[selectorIndex]]
selectorIndex++
decoded = 0
}
v := currentHuffmanTree.Decode(br)
decoded++
if v < 2 {
// This is either the RUNA or RUNB symbol.
if repeat == 0 {
repeatPower = 1
}
repeat += repeatPower << v
repeatPower <<= 1
// This limit of 2 million comes from the bzip2 source
// code. It prevents repeat from overflowing.
if repeat > 2*1024*1024 {
return StructuralError("repeat count too large")
}
continue
}
if repeat > 0 {
// We have decoded a complete run-length so we need to
// replicate the last output symbol.
if repeat > bz2.blockSize-bufIndex {
return StructuralError("repeats past end of block")
}
for i := 0; i < repeat; i++ {
b := mtf.First()
bz2.tt[bufIndex] = uint32(b)
bz2.c[b]++
bufIndex++
}
repeat = 0
}
if int(v) == numSymbols-1 {
// This is the EOF symbol. Because it's always at the
// end of the move-to-front list, and never gets moved
// to the front, it has this unique value.
break
}
// Since two metasymbols (RUNA and RUNB) have values 0 and 1,
// one would expect |v-2| to be passed to the MTF decoder.
// However, the front of the MTF list is never referenced as 0,
// it's always referenced with a run-length of 1. Thus 0
// doesn't need to be encoded and we have |v-1| in the next
// line.
b := mtf.Decode(int(v - 1))
if bufIndex >= bz2.blockSize {
return StructuralError("data exceeds block size")
}
bz2.tt[bufIndex] = uint32(b)
bz2.c[b]++
bufIndex++
}
if origPtr >= uint(bufIndex) {
return StructuralError("origPtr out of bounds")
}
// We have completed the entropy decoding. Now we can perform the
// inverse BWT and setup the RLE buffer.
bz2.preRLE = bz2.tt[:bufIndex]
bz2.preRLEUsed = 0
bz2.tPos = inverseBWT(bz2.preRLE, origPtr, bz2.c[:])
bz2.lastByte = -1
bz2.byteRepeats = 0
bz2.repeats = 0
return nil
}
// inverseBWT implements the inverse Burrows-Wheeler transform as described in
// http://www.hpl.hp.com/techreports/Compaq-DEC/SRC-RR-124.pdf, section 4.2.
// In that document, origPtr is called “I” and c is the “C” array after the
// first pass over the data. It's an argument here because we merge the first
// pass with the Huffman decoding.
//
// This also implements the “single array” method from the bzip2 source code
// which leaves the output, still shuffled, in the bottom 8 bits of tt with the
// index of the next byte in the top 24-bits. The index of the first byte is
// returned.
func inverseBWT(tt []uint32, origPtr uint, c []uint) uint32 {
sum := uint(0)
for i := 0; i < 256; i++ {
sum += c[i]
c[i] = sum - c[i]
}
for i := range tt {
b := tt[i] & 0xff
tt[c[b]] |= uint32(i) << 8
c[b]++
}
return tt[origPtr] >> 8
}
// This is a standard CRC32 like in hash/crc32 except that all the shifts are reversed,
// causing the bits in the input to be processed in the reverse of the usual order.
var crctab [256]uint32
func init() {
const poly = 0x04C11DB7
for i := range crctab {
crc := uint32(i) << 24
for j := 0; j < 8; j++ {
if crc&0x80000000 != 0 {
crc = (crc << 1) ^ poly
} else {
crc <<= 1
}
}
crctab[i] = crc
}
}
// updateCRC updates the crc value to incorporate the data in b.
// The initial value is 0.
func updateCRC(val uint32, b []byte) uint32 {
crc := ^val
for _, v := range b {
crc = crctab[byte(crc>>24)^v] ^ (crc << 8)
}
return ^crc
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package bzip2
import (
"cmp"
"slices"
)
// A huffmanTree is a binary tree which is navigated, bit-by-bit to reach a
// symbol.
type huffmanTree struct {
// nodes contains all the non-leaf nodes in the tree. nodes[0] is the
// root of the tree and nextNode contains the index of the next element
// of nodes to use when the tree is being constructed.
nodes []huffmanNode
nextNode int
}
// A huffmanNode is a node in the tree. left and right contain indexes into the
// nodes slice of the tree. If left or right is invalidNodeValue then the child
// is a leaf node and its value is in leftValue/rightValue.
//
// The symbols are uint16s because bzip2 encodes not only MTF indexes in the
// tree, but also two magic values for run-length encoding and an EOF symbol.
// Thus there are more than 256 possible symbols.
type huffmanNode struct {
left, right uint16
leftValue, rightValue uint16
}
// invalidNodeValue is an invalid index which marks a leaf node in the tree.
const invalidNodeValue = 0xffff
// Decode reads bits from the given bitReader and navigates the tree until a
// symbol is found.
func (t *huffmanTree) Decode(br *bitReader) (v uint16) {
nodeIndex := uint16(0) // node 0 is the root of the tree.
for {
node := &t.nodes[nodeIndex]
var bit uint16
if br.bits > 0 {
// Get next bit - fast path.
br.bits--
bit = uint16(br.n>>(br.bits&63)) & 1
} else {
// Get next bit - slow path.
// Use ReadBits to retrieve a single bit
// from the underlying io.ByteReader.
bit = uint16(br.ReadBits(1))
}
// Trick a compiler into generating conditional move instead of branch,
// by making both loads unconditional.
l, r := node.left, node.right
if bit == 1 {
nodeIndex = l
} else {
nodeIndex = r
}
if nodeIndex == invalidNodeValue {
// We found a leaf. Use the value of bit to decide
// whether it is the left or the right value.
l, r := node.leftValue, node.rightValue
if bit == 1 {
v = l
} else {
v = r
}
return
}
}
}
// newHuffmanTree builds a Huffman tree from a slice containing the code
// lengths of each symbol. The maximum code length is 32 bits.
func newHuffmanTree(lengths []uint8) (huffmanTree, error) {
// There are many possible trees that assign the same code length to
// each symbol (consider reflecting a tree down the middle, for
// example). Since the code length assignments determine the
// efficiency of the tree, each of these trees is equally good. In
// order to minimize the amount of information needed to build a tree
// bzip2 uses a canonical tree so that it can be reconstructed given
// only the code length assignments.
if len(lengths) < 2 {
panic("newHuffmanTree: too few symbols")
}
var t huffmanTree
// First we sort the code length assignments by ascending code length,
// using the symbol value to break ties.
pairs := make([]huffmanSymbolLengthPair, len(lengths))
for i, length := range lengths {
pairs[i].value = uint16(i)
pairs[i].length = length
}
slices.SortFunc(pairs, func(a, b huffmanSymbolLengthPair) int {
if c := cmp.Compare(a.length, b.length); c != 0 {
return c
}
return cmp.Compare(a.value, b.value)
})
// Now we assign codes to the symbols, starting with the longest code.
// We keep the codes packed into a uint32, at the most-significant end.
// So branches are taken from the MSB downwards. This makes it easy to
// sort them later.
code := uint32(0)
length := uint8(32)
codes := make([]huffmanCode, len(lengths))
for i := len(pairs) - 1; i >= 0; i-- {
if length > pairs[i].length {
length = pairs[i].length
}
codes[i].code = code
codes[i].codeLen = length
codes[i].value = pairs[i].value
// We need to 'increment' the code, which means treating |code|
// like a |length| bit number.
code += 1 << (32 - length)
}
// Now we can sort by the code so that the left half of each branch are
// grouped together, recursively.
slices.SortFunc(codes, func(a, b huffmanCode) int {
return cmp.Compare(a.code, b.code)
})
t.nodes = make([]huffmanNode, len(codes))
_, err := buildHuffmanNode(&t, codes, 0)
return t, err
}
// huffmanSymbolLengthPair contains a symbol and its code length.
type huffmanSymbolLengthPair struct {
value uint16
length uint8
}
// huffmanCode contains a symbol, its code and code length.
type huffmanCode struct {
code uint32
codeLen uint8
value uint16
}
// buildHuffmanNode takes a slice of sorted huffmanCodes and builds a node in
// the Huffman tree at the given level. It returns the index of the newly
// constructed node.
func buildHuffmanNode(t *huffmanTree, codes []huffmanCode, level uint32) (nodeIndex uint16, err error) {
test := uint32(1) << (31 - level)
// We have to search the list of codes to find the divide between the left and right sides.
firstRightIndex := len(codes)
for i, code := range codes {
if code.code&test != 0 {
firstRightIndex = i
break
}
}
left := codes[:firstRightIndex]
right := codes[firstRightIndex:]
if len(left) == 0 || len(right) == 0 {
// There is a superfluous level in the Huffman tree indicating
// a bug in the encoder. However, this bug has been observed in
// the wild so we handle it.
// If this function was called recursively then we know that
// len(codes) >= 2 because, otherwise, we would have hit the
// "leaf node" case, below, and not recurred.
//
// However, for the initial call it's possible that len(codes)
// is zero or one. Both cases are invalid because a zero length
// tree cannot encode anything and a length-1 tree can only
// encode EOF and so is superfluous. We reject both.
if len(codes) < 2 {
return 0, StructuralError("empty Huffman tree")
}
// In this case the recursion doesn't always reduce the length
// of codes so we need to ensure termination via another
// mechanism.
if level == 31 {
// Since len(codes) >= 2 the only way that the values
// can match at all 32 bits is if they are equal, which
// is invalid. This ensures that we never enter
// infinite recursion.
return 0, StructuralError("equal symbols in Huffman tree")
}
if len(left) == 0 {
return buildHuffmanNode(t, right, level+1)
}
return buildHuffmanNode(t, left, level+1)
}
nodeIndex = uint16(t.nextNode)
node := &t.nodes[t.nextNode]
t.nextNode++
if len(left) == 1 {
// leaf node
node.left = invalidNodeValue
node.leftValue = left[0].value
} else {
node.left, err = buildHuffmanNode(t, left, level+1)
}
if err != nil {
return
}
if len(right) == 1 {
// leaf node
node.right = invalidNodeValue
node.rightValue = right[0].value
} else {
node.right, err = buildHuffmanNode(t, right, level+1)
}
return
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package bzip2
// moveToFrontDecoder implements a move-to-front list. Such a list is an
// efficient way to transform a string with repeating elements into one with
// many small valued numbers, which is suitable for entropy encoding. It works
// by starting with an initial list of symbols and references symbols by their
// index into that list. When a symbol is referenced, it's moved to the front
// of the list. Thus, a repeated symbol ends up being encoded with many zeros,
// as the symbol will be at the front of the list after the first access.
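//
// For example, starting from the list [a b c], decoding the indexes 2, 2, 0
// yields c, b, b: index 2 selects c and moves it to the front ([c a b]),
// the next index 2 selects b ([b c a]), and index 0 repeats the symbol
// currently at the front, b.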
type moveToFrontDecoder []byte
// newMTFDecoder creates a move-to-front decoder with an explicit initial list
// of symbols.
func newMTFDecoder(symbols []byte) moveToFrontDecoder {
if len(symbols) > 256 {
panic("too many symbols")
}
return moveToFrontDecoder(symbols)
}
// newMTFDecoderWithRange creates a move-to-front decoder with an initial
// symbol list of 0...n-1.
func newMTFDecoderWithRange(n int) moveToFrontDecoder {
if n > 256 {
panic("newMTFDecoderWithRange: cannot have > 256 symbols")
}
m := make([]byte, n)
for i := 0; i < n; i++ {
m[i] = byte(i)
}
return moveToFrontDecoder(m)
}
func (m moveToFrontDecoder) Decode(n int) (b byte) {
// Implement move-to-front with a simple copy. This approach
// beats more sophisticated approaches in benchmarking, probably
// because it has high locality of reference inside of a
// single cache line (most move-to-front operations have n < 64).
b = m[n]
copy(m[1:], m[:n])
m[0] = b
return
}
// First returns the symbol at the front of the list.
func (m moveToFrontDecoder) First() byte {
return m[0]
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package flate
import (
"errors"
"fmt"
"io"
"math"
)
const (
NoCompression = 0
BestSpeed = 1
BestCompression = 9
DefaultCompression = -1
// HuffmanOnly disables Lempel-Ziv match searching and only performs Huffman
// entropy encoding. This mode is useful in compressing data that has
// already been compressed with an LZ style algorithm (e.g. Snappy or LZ4)
// that lacks an entropy encoder. Compression gains are achieved when
// certain bytes in the input stream occur more frequently than others.
//
// Note that HuffmanOnly produces a compressed output that is
// RFC 1951 compliant. That is, any valid DEFLATE decompressor will
// continue to be able to decompress this output.
HuffmanOnly = -2
)
const (
logWindowSize = 15
windowSize = 1 << logWindowSize
windowMask = windowSize - 1
// The LZ77 step produces a sequence of literal tokens and <length, offset>
// pair tokens. The offset is also known as distance. The underlying wire
// format limits the range of lengths and offsets. For example, there are
// 256 legitimate lengths: those in the range [3, 258]. This package's
// compressor uses a higher minimum match length, enabling optimizations
// such as finding matches via 32-bit loads and compares.
baseMatchLength = 3 // The smallest match length per the RFC section 3.2.5
minMatchLength = 4 // The smallest match length that the compressor actually emits
maxMatchLength = 258 // The largest match length
baseMatchOffset = 1 // The smallest match offset
maxMatchOffset = 1 << 15 // The largest match offset
// The maximum number of tokens we put into a single flate block, just to
// stop things from getting too large.
maxFlateBlockTokens = 1 << 14
maxStoreBlockSize = 65535
hashBits = 17 // After 17 performance degrades
hashSize = 1 << hashBits
hashMask = (1 << hashBits) - 1
maxHashOffset = 1 << 24
skipNever = math.MaxInt32
)
type compressionLevel struct {
level, good, lazy, nice, chain, fastSkipHashing int
}
var levels = []compressionLevel{
{0, 0, 0, 0, 0, 0}, // NoCompression.
{1, 0, 0, 0, 0, 0}, // BestSpeed uses a custom algorithm; see deflatefast.go.
// For levels 2-3 we don't bother trying with lazy matches.
{2, 4, 0, 16, 8, 5},
{3, 4, 0, 32, 32, 6},
// Levels 4-9 use increasingly more lazy matching
// and increasingly stringent conditions for "good enough".
{4, 4, 4, 16, 16, skipNever},
{5, 8, 16, 32, 32, skipNever},
{6, 8, 16, 128, 128, skipNever},
{7, 8, 32, 128, 256, skipNever},
{8, 32, 128, 258, 1024, skipNever},
{9, 32, 258, 258, 4096, skipNever},
}
type compressor struct {
compressionLevel
w *huffmanBitWriter
bulkHasher func([]byte, []uint32)
// compression algorithm
fill func(*compressor, []byte) int // copy data to window
step func(*compressor) // process window
bestSpeed *deflateFast // Encoder for BestSpeed
// Input hash chains
// hashHead[hashValue] contains the largest inputIndex with the specified hash value
// If hashHead[hashValue] is within the current window, then
// hashPrev[hashHead[hashValue] & windowMask] contains the previous index
// with the same hash value.
chainHead int
hashHead [hashSize]uint32
hashPrev [windowSize]uint32
hashOffset int
// input window: unprocessed data is window[index:windowEnd]
index int
window []byte
windowEnd int
blockStart int // window index where current tokens start
byteAvailable bool // if true, still need to process window[index-1].
sync bool // requesting flush
// queued output tokens
tokens []token
// deflate state
length int
offset int
maxInsertIndex int
err error
// hashMatch must be able to contain hashes for the maximum match length.
hashMatch [maxMatchLength - 1]uint32
}
func (d *compressor) fillDeflate(b []byte) int {
if d.index >= 2*windowSize-(minMatchLength+maxMatchLength) {
// shift the window by windowSize
copy(d.window, d.window[windowSize:2*windowSize])
d.index -= windowSize
d.windowEnd -= windowSize
if d.blockStart >= windowSize {
d.blockStart -= windowSize
} else {
d.blockStart = math.MaxInt32
}
d.hashOffset += windowSize
if d.hashOffset > maxHashOffset {
delta := d.hashOffset - 1
d.hashOffset -= delta
d.chainHead -= delta
// Iterate over slices instead of arrays to avoid copying
// the entire table onto the stack (Issue #18625).
for i, v := range d.hashPrev[:] {
if int(v) > delta {
d.hashPrev[i] = uint32(int(v) - delta)
} else {
d.hashPrev[i] = 0
}
}
for i, v := range d.hashHead[:] {
if int(v) > delta {
d.hashHead[i] = uint32(int(v) - delta)
} else {
d.hashHead[i] = 0
}
}
}
}
n := copy(d.window[d.windowEnd:], b)
d.windowEnd += n
return n
}
func (d *compressor) writeBlock(tokens []token, index int) error {
if index > 0 {
var window []byte
if d.blockStart <= index {
window = d.window[d.blockStart:index]
}
d.blockStart = index
d.w.writeBlock(tokens, false, window)
return d.w.err
}
return nil
}
// fillWindow will fill the current window with the supplied
// dictionary and calculate all hashes.
// This is much faster than doing a full encode.
// Should only be used after a reset.
func (d *compressor) fillWindow(b []byte) {
// Do not fill window if we are in store-only mode.
if d.compressionLevel.level < 2 {
return
}
if d.index != 0 || d.windowEnd != 0 {
panic("internal error: fillWindow called with stale data")
}
// If we are given too much, cut it.
if len(b) > windowSize {
b = b[len(b)-windowSize:]
}
// Add all to window.
n := copy(d.window, b)
// Calculate 256 hashes at a time (for better L1 cache hit rates).
loops := (n + 256 - minMatchLength) / 256
for j := 0; j < loops; j++ {
index := j * 256
end := index + 256 + minMatchLength - 1
if end > n {
end = n
}
toCheck := d.window[index:end]
dstSize := len(toCheck) - minMatchLength + 1
if dstSize <= 0 {
continue
}
dst := d.hashMatch[:dstSize]
d.bulkHasher(toCheck, dst)
for i, val := range dst {
di := i + index
hh := &d.hashHead[val&hashMask]
// Get previous value with the same hash.
// Our chain should point to the previous value.
d.hashPrev[di&windowMask] = *hh
// Set the head of the hash chain to us.
*hh = uint32(di + d.hashOffset)
}
}
// Update window information.
d.windowEnd = n
d.index = n
}
// findMatch tries to find a match starting at pos whose length is greater than prevLength.
// It looks at no more than d.chain possibilities before giving up.
func (d *compressor) findMatch(pos int, prevHead int, prevLength int, lookahead int) (length, offset int, ok bool) {
minMatchLook := maxMatchLength
if lookahead < minMatchLook {
minMatchLook = lookahead
}
win := d.window[0 : pos+minMatchLook]
// We quit when we get a match that's at least nice long
nice := len(win) - pos
if d.nice < nice {
nice = d.nice
}
// If we've got a match that's good enough, only look in 1/4 the chain.
tries := d.chain
length = prevLength
if length >= d.good {
tries >>= 2
}
wEnd := win[pos+length]
wPos := win[pos:]
minIndex := pos - windowSize
for i := prevHead; tries > 0; tries-- {
if wEnd == win[i+length] {
n := matchLen(win[i:], wPos, minMatchLook)
if n > length && (n > minMatchLength || pos-i <= 4096) {
length = n
offset = pos - i
ok = true
if n >= nice {
// The match is good enough that we don't try to find a better one.
break
}
wEnd = win[pos+n]
}
}
if i == minIndex {
// hashPrev[i & windowMask] has already been overwritten, so stop now.
break
}
i = int(d.hashPrev[i&windowMask]) - d.hashOffset
if i < minIndex || i < 0 {
break
}
}
return
}
func (d *compressor) writeStoredBlock(buf []byte) error {
if d.w.writeStoredHeader(len(buf), false); d.w.err != nil {
return d.w.err
}
d.w.writeBytes(buf)
return d.w.err
}
const hashmul = 0x1e35a7bd
// hash4 returns a hash representation of the first 4 bytes
// of the supplied slice.
// The caller must ensure that len(b) >= 4.
func hash4(b []byte) uint32 {
return ((uint32(b[3]) | uint32(b[2])<<8 | uint32(b[1])<<16 | uint32(b[0])<<24) * hashmul) >> (32 - hashBits)
}
// bulkHash4 will compute hashes using the same
// algorithm as hash4.
func bulkHash4(b []byte, dst []uint32) {
if len(b) < minMatchLength {
return
}
hb := uint32(b[3]) | uint32(b[2])<<8 | uint32(b[1])<<16 | uint32(b[0])<<24
dst[0] = (hb * hashmul) >> (32 - hashBits)
end := len(b) - minMatchLength + 1
for i := 1; i < end; i++ {
hb = (hb << 8) | uint32(b[i+3])
dst[i] = (hb * hashmul) >> (32 - hashBits)
}
}
// matchLen returns the number of matching bytes in a and b
// up to length 'max'. Both slices must be at least 'max'
// bytes in size.
func matchLen(a, b []byte, max int) int {
a = a[:max]
b = b[:len(a)]
for i, av := range a {
if b[i] != av {
return i
}
}
return max
}
// encSpeed compresses and stores the currently added data
// if enough has been accumulated or we are at the end of the stream.
// Any error that occurs is recorded in d.err.
func (d *compressor) encSpeed() {
// We only compress if we have maxStoreBlockSize.
if d.windowEnd < maxStoreBlockSize {
if !d.sync {
return
}
// Handle small sizes.
if d.windowEnd < 128 {
switch {
case d.windowEnd == 0:
return
case d.windowEnd <= 16:
d.err = d.writeStoredBlock(d.window[:d.windowEnd])
default:
d.w.writeBlockHuff(false, d.window[:d.windowEnd])
d.err = d.w.err
}
d.windowEnd = 0
d.bestSpeed.reset()
return
}
}
// Encode the block.
d.tokens = d.bestSpeed.encode(d.tokens[:0], d.window[:d.windowEnd])
// If we removed less than 1/16th, Huffman compress the block.
if len(d.tokens) > d.windowEnd-(d.windowEnd>>4) {
d.w.writeBlockHuff(false, d.window[:d.windowEnd])
} else {
d.w.writeBlockDynamic(d.tokens, false, d.window[:d.windowEnd])
}
d.err = d.w.err
d.windowEnd = 0
}
func (d *compressor) initDeflate() {
d.window = make([]byte, 2*windowSize)
d.hashOffset = 1
d.tokens = make([]token, 0, maxFlateBlockTokens+1)
d.length = minMatchLength - 1
d.offset = 0
d.byteAvailable = false
d.index = 0
d.chainHead = -1
d.bulkHasher = bulkHash4
}
func (d *compressor) deflate() {
if d.windowEnd-d.index < minMatchLength+maxMatchLength && !d.sync {
return
}
d.maxInsertIndex = d.windowEnd - (minMatchLength - 1)
Loop:
for {
if d.index > d.windowEnd {
panic("index > windowEnd")
}
lookahead := d.windowEnd - d.index
if lookahead < minMatchLength+maxMatchLength {
if !d.sync {
break Loop
}
if d.index > d.windowEnd {
panic("index > windowEnd")
}
if lookahead == 0 {
// Flush current output block if any.
if d.byteAvailable {
// There is still one pending token that needs to be flushed
d.tokens = append(d.tokens, literalToken(uint32(d.window[d.index-1])))
d.byteAvailable = false
}
if len(d.tokens) > 0 {
if d.err = d.writeBlock(d.tokens, d.index); d.err != nil {
return
}
d.tokens = d.tokens[:0]
}
break Loop
}
}
if d.index < d.maxInsertIndex {
// Update the hash
hash := hash4(d.window[d.index : d.index+minMatchLength])
hh := &d.hashHead[hash&hashMask]
d.chainHead = int(*hh)
d.hashPrev[d.index&windowMask] = uint32(d.chainHead)
*hh = uint32(d.index + d.hashOffset)
}
prevLength := d.length
prevOffset := d.offset
d.length = minMatchLength - 1
d.offset = 0
minIndex := d.index - windowSize
if minIndex < 0 {
minIndex = 0
}
if d.chainHead-d.hashOffset >= minIndex &&
(d.fastSkipHashing != skipNever && lookahead > minMatchLength-1 ||
d.fastSkipHashing == skipNever && lookahead > prevLength && prevLength < d.lazy) {
if newLength, newOffset, ok := d.findMatch(d.index, d.chainHead-d.hashOffset, minMatchLength-1, lookahead); ok {
d.length = newLength
d.offset = newOffset
}
}
if d.fastSkipHashing != skipNever && d.length >= minMatchLength ||
d.fastSkipHashing == skipNever && prevLength >= minMatchLength && d.length <= prevLength {
// There was a match at the previous step, and the current match is
// not better. Output the previous match.
if d.fastSkipHashing != skipNever {
d.tokens = append(d.tokens, matchToken(uint32(d.length-baseMatchLength), uint32(d.offset-baseMatchOffset)))
} else {
d.tokens = append(d.tokens, matchToken(uint32(prevLength-baseMatchLength), uint32(prevOffset-baseMatchOffset)))
}
// Insert in the hash table all strings up to the end of the match.
// index and index-1 are already inserted. If there is not enough
// lookahead, the last two strings are not inserted into the hash
// table.
if d.length <= d.fastSkipHashing {
var newIndex int
if d.fastSkipHashing != skipNever {
newIndex = d.index + d.length
} else {
newIndex = d.index + prevLength - 1
}
index := d.index
for index++; index < newIndex; index++ {
if index < d.maxInsertIndex {
hash := hash4(d.window[index : index+minMatchLength])
// Get previous value with the same hash.
// Our chain should point to the previous value.
hh := &d.hashHead[hash&hashMask]
d.hashPrev[index&windowMask] = *hh
// Set the head of the hash chain to us.
*hh = uint32(index + d.hashOffset)
}
}
d.index = index
if d.fastSkipHashing == skipNever {
d.byteAvailable = false
d.length = minMatchLength - 1
}
} else {
// For matches this long, we don't bother inserting each individual
// item into the table.
d.index += d.length
}
if len(d.tokens) == maxFlateBlockTokens {
// The block includes the current character
if d.err = d.writeBlock(d.tokens, d.index); d.err != nil {
return
}
d.tokens = d.tokens[:0]
}
} else {
if d.fastSkipHashing != skipNever || d.byteAvailable {
i := d.index - 1
if d.fastSkipHashing != skipNever {
i = d.index
}
d.tokens = append(d.tokens, literalToken(uint32(d.window[i])))
if len(d.tokens) == maxFlateBlockTokens {
if d.err = d.writeBlock(d.tokens, i+1); d.err != nil {
return
}
d.tokens = d.tokens[:0]
}
}
d.index++
if d.fastSkipHashing == skipNever {
d.byteAvailable = true
}
}
}
}
func (d *compressor) fillStore(b []byte) int {
n := copy(d.window[d.windowEnd:], b)
d.windowEnd += n
return n
}
func (d *compressor) store() {
if d.windowEnd > 0 && (d.windowEnd == maxStoreBlockSize || d.sync) {
d.err = d.writeStoredBlock(d.window[:d.windowEnd])
d.windowEnd = 0
}
}
// storeHuff compresses and stores the currently added data
// when d.window is full or we are at the end of the stream.
// Any error that occurs is recorded in d.err.
func (d *compressor) storeHuff() {
if d.windowEnd < len(d.window) && !d.sync || d.windowEnd == 0 {
return
}
d.w.writeBlockHuff(false, d.window[:d.windowEnd])
d.err = d.w.err
d.windowEnd = 0
}
func (d *compressor) write(b []byte) (n int, err error) {
if d.err != nil {
return 0, d.err
}
n = len(b)
for len(b) > 0 {
d.step(d)
b = b[d.fill(d, b):]
if d.err != nil {
return 0, d.err
}
}
return n, nil
}
func (d *compressor) syncFlush() error {
if d.err != nil {
return d.err
}
d.sync = true
d.step(d)
if d.err == nil {
d.w.writeStoredHeader(0, false)
d.w.flush()
d.err = d.w.err
}
d.sync = false
return d.err
}
func (d *compressor) init(w io.Writer, level int) (err error) {
d.w = newHuffmanBitWriter(w)
switch {
case level == NoCompression:
d.window = make([]byte, maxStoreBlockSize)
d.fill = (*compressor).fillStore
d.step = (*compressor).store
case level == HuffmanOnly:
d.window = make([]byte, maxStoreBlockSize)
d.fill = (*compressor).fillStore
d.step = (*compressor).storeHuff
case level == BestSpeed:
d.compressionLevel = levels[level]
d.window = make([]byte, maxStoreBlockSize)
d.fill = (*compressor).fillStore
d.step = (*compressor).encSpeed
d.bestSpeed = newDeflateFast()
d.tokens = make([]token, maxStoreBlockSize)
case level == DefaultCompression:
level = 6
fallthrough
case 2 <= level && level <= 9:
d.compressionLevel = levels[level]
d.initDeflate()
d.fill = (*compressor).fillDeflate
d.step = (*compressor).deflate
default:
return fmt.Errorf("flate: invalid compression level %d: want value in range [-2, 9]", level)
}
return nil
}
func (d *compressor) reset(w io.Writer) {
d.w.reset(w)
d.sync = false
d.err = nil
switch d.compressionLevel.level {
case NoCompression:
d.windowEnd = 0
case BestSpeed:
d.windowEnd = 0
d.tokens = d.tokens[:0]
d.bestSpeed.reset()
default:
d.chainHead = -1
clear(d.hashHead[:])
clear(d.hashPrev[:])
d.hashOffset = 1
d.index, d.windowEnd = 0, 0
d.blockStart, d.byteAvailable = 0, false
d.tokens = d.tokens[:0]
d.length = minMatchLength - 1
d.offset = 0
d.maxInsertIndex = 0
}
}
func (d *compressor) close() error {
if d.err == errWriterClosed {
return nil
}
if d.err != nil {
return d.err
}
d.sync = true
d.step(d)
if d.err != nil {
return d.err
}
if d.w.writeStoredHeader(0, true); d.w.err != nil {
return d.w.err
}
d.w.flush()
if d.w.err != nil {
return d.w.err
}
d.err = errWriterClosed
return nil
}
// NewWriter returns a new [Writer] compressing data at the given level.
// Following zlib, levels range from 1 ([BestSpeed]) to 9 ([BestCompression]);
// higher levels typically run slower but compress more. Level 0
// ([NoCompression]) does not attempt any compression; it only adds the
// necessary DEFLATE framing.
// Level -1 ([DefaultCompression]) uses the default compression level.
// Level -2 ([HuffmanOnly]) will use Huffman compression only, giving
// a very fast compression for all types of input, but sacrificing considerable
// compression efficiency.
//
// If level is in the range [-2, 9] then the error returned will be nil.
// Otherwise the error returned will be non-nil.
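//
// A minimal usage sketch (illustrative only; from client code it assumes the
// standard library bytes and log packages):
//
//	var buf bytes.Buffer
//	zw, err := flate.NewWriter(&buf, flate.BestCompression)
//	if err != nil {
//		log.Fatal(err)
//	}
//	if _, err := zw.Write([]byte("hello, deflate")); err != nil {
//		log.Fatal(err)
//	}
//	if err := zw.Close(); err != nil {
//		log.Fatal(err)
//	}
//	// buf now holds a raw DEFLATE stream that flate.NewReader can decode.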
func NewWriter(w io.Writer, level int) (*Writer, error) {
var dw Writer
if err := dw.d.init(w, level); err != nil {
return nil, err
}
return &dw, nil
}
// NewWriterDict is like [NewWriter] but initializes the new
// [Writer] with a preset dictionary. The returned [Writer] behaves
// as if the dictionary had been written to it without producing
// any compressed output. The compressed data written to w
// can only be decompressed by a reader initialized with the
// same dictionary (see [NewReaderDict]).
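//
// A round-trip sketch (illustrative only; it assumes the standard library
// bytes, io and log packages, and that the same dict is available on both
// sides):
//
//	dict := []byte("a preset dictionary shared by both sides")
//	var buf bytes.Buffer
//	zw, err := flate.NewWriterDict(&buf, flate.DefaultCompression, dict)
//	if err != nil {
//		log.Fatal(err)
//	}
//	zw.Write([]byte("payload that hopefully repeats the preset dictionary"))
//	zw.Close()
//	zr := flate.NewReaderDict(&buf, dict) // must use the same dictionary
//	plain, err := io.ReadAll(zr)
//	if err != nil {
//		log.Fatal(err)
//	}
//	zr.Close()
//	_ = plain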
func NewWriterDict(w io.Writer, level int, dict []byte) (*Writer, error) {
dw := &dictWriter{w}
zw, err := NewWriter(dw, level)
if err != nil {
return nil, err
}
zw.d.fillWindow(dict)
zw.dict = append(zw.dict, dict...) // duplicate dictionary for Reset method.
return zw, nil
}
type dictWriter struct {
w io.Writer
}
func (w *dictWriter) Write(b []byte) (n int, err error) {
return w.w.Write(b)
}
var errWriterClosed = errors.New("flate: closed writer")
// A Writer takes data written to it and writes the compressed
// form of that data to an underlying writer (see [NewWriter]).
type Writer struct {
d compressor
dict []byte
}
// Write writes data to w, which will eventually write the
// compressed form of data to its underlying writer.
func (w *Writer) Write(data []byte) (n int, err error) {
return w.d.write(data)
}
// Flush flushes any pending data to the underlying writer.
// It is useful mainly in compressed network protocols, to ensure that
// a remote reader has enough data to reconstruct a packet.
// Flush does not return until the data has been written.
// Calling Flush when there is no pending data still causes the [Writer]
// to emit a sync marker of at least 4 bytes.
// If the underlying writer returns an error, Flush returns that error.
//
// In the terminology of the zlib library, Flush is equivalent to Z_SYNC_FLUSH.
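//
// A sketch of the usual protocol pattern (illustrative only; conn and packet
// are placeholders for the caller's connection and message):
//
//	zw, _ := flate.NewWriter(conn, flate.BestSpeed)
//	zw.Write(packet)
//	if err := zw.Flush(); err != nil {
//		// handle the error
//	}
//	// The remote end can now decode everything written so far with
//	// flate.NewReader, without waiting for Close.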
func (w *Writer) Flush() error {
// For more about flushing:
// https://www.bolet.org/~pornin/deflate-flush.html
return w.d.syncFlush()
}
// Close flushes and closes the writer.
func (w *Writer) Close() error {
return w.d.close()
}
// Reset discards the writer's state and makes it equivalent to
// the result of [NewWriter] or [NewWriterDict] called with dst
// and w's level and dictionary.
func (w *Writer) Reset(dst io.Writer) {
if dw, ok := w.d.w.writer.(*dictWriter); ok {
// w was created with NewWriterDict
dw.w = dst
w.d.reset(dw)
w.d.fillWindow(w.dict)
} else {
// w was created with NewWriter
w.d.reset(dst)
}
}
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package flate
import "math"
// This encoding algorithm, which prioritizes speed over output size, is
// based on Snappy's LZ77-style encoder: github.com/golang/snappy
const (
tableBits = 14 // Bits used in the table.
tableSize = 1 << tableBits // Size of the table.
tableMask = tableSize - 1 // Mask for table indices. Redundant, but can eliminate bounds checks.
tableShift = 32 - tableBits // Right-shift to get the tableBits most significant bits of a uint32.
// Reset the buffer offset when reaching this.
// Offsets are stored between blocks as int32 values.
// Since the offset we are checking against is at the beginning
// of the buffer, we subtract the sizes of the current and input
// buffers so the int32 cannot overflow.
bufferReset = math.MaxInt32 - maxStoreBlockSize*2
)
func load32(b []byte, i int32) uint32 {
b = b[i : i+4 : len(b)] // Help the compiler eliminate bounds checks on the next line.
return uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24
}
func load64(b []byte, i int32) uint64 {
b = b[i : i+8 : len(b)] // Help the compiler eliminate bounds checks on the next line.
return uint64(b[0]) | uint64(b[1])<<8 | uint64(b[2])<<16 | uint64(b[3])<<24 |
uint64(b[4])<<32 | uint64(b[5])<<40 | uint64(b[6])<<48 | uint64(b[7])<<56
}
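// hash returns the top tableBits bits of a multiplicative hash of u, where u
// holds four input bytes packed by load32; the multiplier matches the one
// used by the Snappy encoder this file is based on.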
func hash(u uint32) uint32 {
return (u * 0x1e35a7bd) >> tableShift
}
// These constants are defined by the Snappy implementation so that its
// assembly implementation can fast-path some 16-bytes-at-a-time copies. They
// aren't necessary in the pure Go implementation, as we don't use those same
// optimizations, but using the same thresholds doesn't really hurt.
const (
inputMargin = 16 - 1
minNonLiteralBlockSize = 1 + 1 + inputMargin
)
type tableEntry struct {
val uint32 // Value at destination
offset int32
}
// deflateFast maintains the table for matches,
// and the previous byte block for cross block matching.
type deflateFast struct {
table [tableSize]tableEntry
prev []byte // Previous block, zero length if unknown.
cur int32 // Current match offset.
}
func newDeflateFast() *deflateFast {
return &deflateFast{cur: maxStoreBlockSize, prev: make([]byte, 0, maxStoreBlockSize)}
}
// encode encodes a block given in src and appends tokens
// to dst and returns the result.
func (e *deflateFast) encode(dst []token, src []byte) []token {
// Ensure that e.cur doesn't wrap.
if e.cur >= bufferReset {
e.shiftOffsets()
}
// This check isn't in the Snappy implementation, but there, the caller
// instead of the callee handles this case.
if len(src) < minNonLiteralBlockSize {
e.cur += maxStoreBlockSize
e.prev = e.prev[:0]
return emitLiteral(dst, src)
}
// sLimit is when to stop looking for offset/length copies. The inputMargin
// lets us use a fast path for emitLiteral in the main loop, while we are
// looking for copies.
sLimit := int32(len(src) - inputMargin)
// nextEmit is where in src the next emitLiteral should start from.
nextEmit := int32(0)
s := int32(0)
cv := load32(src, s)
nextHash := hash(cv)
for {
// Copied from the C++ snappy implementation:
//
// Heuristic match skipping: If 32 bytes are scanned with no matches
// found, start looking only at every other byte. If 32 more bytes are
// scanned (or skipped), look at every third byte, etc. When a match
// is found, immediately go back to looking at every byte. This is a
// small loss (~5% performance, ~0.1% density) for compressible data
// due to more bookkeeping, but for non-compressible data (such as
// JPEG) it's a huge win since the compressor quickly "realizes" the
// data is incompressible and doesn't bother looking for matches
// everywhere.
//
// The "skip" variable keeps track of how many bytes there are since
// the last match; dividing it by 32 (i.e. right-shifting by five) gives
// the number of bytes to move ahead for each iteration.
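// As a concrete illustration: skip starts at 32, so the first 32 probes
// advance one byte each (skip grows to 64), the next 16 probes advance two
// bytes each (skip grows to 96), the following probes advance three bytes
// each, and so on, until a match is found and the outer loop restarts with
// skip back at 32.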
skip := int32(32)
nextS := s
var candidate tableEntry
for {
s = nextS
bytesBetweenHashLookups := skip >> 5
nextS = s + bytesBetweenHashLookups
skip += bytesBetweenHashLookups
if nextS > sLimit {
goto emitRemainder
}
candidate = e.table[nextHash&tableMask]
now := load32(src, nextS)
e.table[nextHash&tableMask] = tableEntry{offset: s + e.cur, val: cv}
nextHash = hash(now)
offset := s - (candidate.offset - e.cur)
if offset > maxMatchOffset || cv != candidate.val {
// Out of range or not matched.
cv = now
continue
}
break
}
// A 4-byte match has been found. We'll later see if more than 4 bytes
// match. But, prior to the match, src[nextEmit:s] are unmatched. Emit
// them as literal bytes.
dst = emitLiteral(dst, src[nextEmit:s])
// Call emitCopy, and then see if another emitCopy could be our next
// move. Repeat until we find no match for the input immediately after
// what was consumed by the last emitCopy call.
//
// If we exit this loop normally then we need to call emitLiteral next,
// though we don't yet know how big the literal will be. We handle that
// by proceeding to the next iteration of the main loop. We also can
// exit this loop via goto if we get close to exhausting the input.
for {
// Invariant: we have a 4-byte match at s, and no need to emit any
// literal bytes prior to s.
// Extend the 4-byte match as long as possible.
//
s += 4
t := candidate.offset - e.cur + 4
l := e.matchLen(s, t, src)
// matchToken is flate's equivalent of Snappy's emitCopy. (length,offset)
dst = append(dst, matchToken(uint32(l+4-baseMatchLength), uint32(s-t-baseMatchOffset)))
s += l
nextEmit = s
if s >= sLimit {
goto emitRemainder
}
// We could immediately start working at s now, but to improve
// compression we first update the hash table at s-1 and at s. If
// another emitCopy is not our next move, also calculate nextHash
// at s+1. At least on GOARCH=amd64, these three hash calculations
// are faster as one load64 call (with some shifts) instead of
// three load32 calls.
x := load64(src, s-1)
prevHash := hash(uint32(x))
e.table[prevHash&tableMask] = tableEntry{offset: e.cur + s - 1, val: uint32(x)}
x >>= 8
currHash := hash(uint32(x))
candidate = e.table[currHash&tableMask]
e.table[currHash&tableMask] = tableEntry{offset: e.cur + s, val: uint32(x)}
offset := s - (candidate.offset - e.cur)
if offset > maxMatchOffset || uint32(x) != candidate.val {
cv = uint32(x >> 8)
nextHash = hash(cv)
s++
break
}
}
}
emitRemainder:
if int(nextEmit) < len(src) {
dst = emitLiteral(dst, src[nextEmit:])
}
e.cur += int32(len(src))
e.prev = e.prev[:len(src)]
copy(e.prev, src)
return dst
}
func emitLiteral(dst []token, lit []byte) []token {
for _, v := range lit {
dst = append(dst, literalToken(uint32(v)))
}
return dst
}
// matchLen returns the match length between src[s:] and src[t:].
// t can be negative to indicate the match is starting in e.prev.
// We assume that src[s-4:s] and src[t-4:t] already match.
func (e *deflateFast) matchLen(s, t int32, src []byte) int32 {
s1 := int(s) + maxMatchLength - 4
if s1 > len(src) {
s1 = len(src)
}
// If we are inside the current block
if t >= 0 {
b := src[t:]
a := src[s:s1]
b = b[:len(a)]
// Extend the match to be as long as possible.
for i := range a {
if a[i] != b[i] {
return int32(i)
}
}
return int32(len(a))
}
// We found a match in the previous block.
tp := int32(len(e.prev)) + t
if tp < 0 {
return 0
}
// Extend the match to be as long as possible.
a := src[s:s1]
b := e.prev[tp:]
if len(b) > len(a) {
b = b[:len(a)]
}
a = a[:len(b)]
for i := range b {
if a[i] != b[i] {
return int32(i)
}
}
// If we reached our limit, we matched everything we are
// allowed to in the previous block and we return.
n := int32(len(b))
if int(s+n) == s1 {
return n
}
// Continue looking for more matches in the current block.
a = src[s+n : s1]
b = src[:len(a)]
for i := range a {
if a[i] != b[i] {
return int32(i) + n
}
}
return int32(len(a)) + n
}
// Reset resets the encoding history.
// This ensures that no matches are made to the previous block.
func (e *deflateFast) reset() {
e.prev = e.prev[:0]
// Bump the offset, so all matches will fail distance check.
// Nothing should be >= e.cur in the table.
e.cur += maxMatchOffset
// Protect against e.cur wraparound.
if e.cur >= bufferReset {
e.shiftOffsets()
}
}
// shiftOffsets will shift down all match offsets.
// This is only called in rare situations to prevent integer overflow.
//
// See https://golang.org/issue/18636 and https://github.com/golang/go/issues/34121.
func (e *deflateFast) shiftOffsets() {
if len(e.prev) == 0 {
// We have no history; just clear the table.
clear(e.table[:])
e.cur = maxMatchOffset + 1
return
}
// Shift down everything in the table that isn't already too far away.
for i := range e.table[:] {
v := e.table[i].offset - e.cur + maxMatchOffset + 1
if v < 0 {
// We want to reset e.cur to maxMatchOffset + 1, so we need to shift
// all table entries down by (e.cur - (maxMatchOffset + 1)).
// Because we ignore matches > maxMatchOffset, we can cap
// any negative offsets at 0.
v = 0
}
e.table[i].offset = v
}
e.cur = maxMatchOffset + 1
}
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package flate
// dictDecoder implements the LZ77 sliding dictionary as used in decompression.
// LZ77 decompresses data through sequences of two forms of commands:
//
// - Literal insertions: Runs of one or more symbols are inserted into the data
// stream as is. This is accomplished through the writeByte method for a
// single symbol, or combinations of writeSlice/writeMark for multiple symbols.
// Any valid stream must start with a literal insertion if no preset dictionary
// is used.
//
// - Backward copies: Runs of one or more symbols are copied from previously
// emitted data. Backward copies come as the tuple (dist, length) where dist
// determines how far back in the stream to copy from and length determines how
// many bytes to copy. Note that it is valid for the length to be greater than
// the distance. Since LZ77 uses forward copies, that situation is used to
// perform a form of run-length encoding on repeated runs of symbols.
// The writeCopy and tryWriteCopy are used to implement this command.
//
// For performance reasons, this implementation performs little to no sanity
// checks about the arguments. As such, the invariants documented for each
// method call must be respected.
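//
// As a concrete illustration, the output "abcabcabcabc" can be produced by a
// literal insertion of "abc" followed by a single backward copy with
// dist = 3 and length = 9: each copied byte has already been written by the
// time it is read, so the copy legally overlaps its own output.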
type dictDecoder struct {
hist []byte // Sliding window history
// Invariant: 0 <= rdPos <= wrPos <= len(hist)
wrPos int // Current output position in buffer
rdPos int // Have emitted hist[:rdPos] already
full bool // Has a full window length been written yet?
}
// init initializes dictDecoder to have a sliding window dictionary of the given
// size. If a preset dict is provided, it will initialize the dictionary with
// the contents of dict.
func (dd *dictDecoder) init(size int, dict []byte) {
*dd = dictDecoder{hist: dd.hist}
if cap(dd.hist) < size {
dd.hist = make([]byte, size)
}
dd.hist = dd.hist[:size]
if len(dict) > len(dd.hist) {
dict = dict[len(dict)-len(dd.hist):]
}
dd.wrPos = copy(dd.hist, dict)
if dd.wrPos == len(dd.hist) {
dd.wrPos = 0
dd.full = true
}
dd.rdPos = dd.wrPos
}
// histSize reports the total amount of historical data in the dictionary.
func (dd *dictDecoder) histSize() int {
if dd.full {
return len(dd.hist)
}
return dd.wrPos
}
// availRead reports the number of bytes that can be flushed by readFlush.
func (dd *dictDecoder) availRead() int {
return dd.wrPos - dd.rdPos
}
// availWrite reports the available amount of output buffer space.
func (dd *dictDecoder) availWrite() int {
return len(dd.hist) - dd.wrPos
}
// writeSlice returns a slice of the available buffer to write data to.
//
// This invariant will be kept: len(s) <= availWrite()
func (dd *dictDecoder) writeSlice() []byte {
return dd.hist[dd.wrPos:]
}
// writeMark advances the writer pointer by cnt.
//
// This invariant must be kept: 0 <= cnt <= availWrite()
func (dd *dictDecoder) writeMark(cnt int) {
dd.wrPos += cnt
}
// writeByte writes a single byte to the dictionary.
//
// This invariant must be kept: 0 < availWrite()
func (dd *dictDecoder) writeByte(c byte) {
dd.hist[dd.wrPos] = c
dd.wrPos++
}
// writeCopy copies a string at a given (dist, length) to the output.
// This returns the number of bytes copied and may be less than the requested
// length if the available space in the output buffer is too small.
//
// This invariant must be kept: 0 < dist <= histSize()
func (dd *dictDecoder) writeCopy(dist, length int) int {
dstBase := dd.wrPos
dstPos := dstBase
srcPos := dstPos - dist
endPos := dstPos + length
if endPos > len(dd.hist) {
endPos = len(dd.hist)
}
// Copy non-overlapping section after destination position.
//
// This section is non-overlapping in that the copy length for this section
// is always less than or equal to the backwards distance. This can occur
// if a distance refers to data that wraps around in the buffer.
// Thus, a backwards copy is performed here; that is, the exact bytes in
// the source prior to the copy are placed in the destination.
if srcPos < 0 {
srcPos += len(dd.hist)
dstPos += copy(dd.hist[dstPos:endPos], dd.hist[srcPos:])
srcPos = 0
}
// Copy possibly overlapping section before destination position.
//
// This section can overlap if the copy length for this section is larger
// than the backwards distance. This is allowed by LZ77 so that repeated
// strings can be succinctly represented using (dist, length) pairs.
// Thus, a forwards copy is performed here; that is, the bytes copied may
// depend on the resulting bytes in the destination as the copy
// progresses. This is functionally equivalent to the following:
//
// for i := 0; i < endPos-dstPos; i++ {
// dd.hist[dstPos+i] = dd.hist[srcPos+i]
// }
// dstPos = endPos
//
for dstPos < endPos {
dstPos += copy(dd.hist[dstPos:endPos], dd.hist[srcPos:dstPos])
}
dd.wrPos = dstPos
return dstPos - dstBase
}
// tryWriteCopy tries to copy a string at a given (distance, length) to the
// output. This specialized version is optimized for short distances.
//
// This method is designed to be inlined for performance reasons.
//
// This invariant must be kept: 0 < dist <= histSize()
func (dd *dictDecoder) tryWriteCopy(dist, length int) int {
dstPos := dd.wrPos
endPos := dstPos + length
if dstPos < dist || endPos > len(dd.hist) {
return 0
}
dstBase := dstPos
srcPos := dstPos - dist
// Copy possibly overlapping section before destination position.
for dstPos < endPos {
dstPos += copy(dd.hist[dstPos:endPos], dd.hist[srcPos:dstPos])
}
dd.wrPos = dstPos
return dstPos - dstBase
}
// readFlush returns a slice of the historical buffer that is ready to be
// emitted to the user. The data returned by readFlush must be fully consumed
// before calling any other dictDecoder methods.
func (dd *dictDecoder) readFlush() []byte {
toRead := dd.hist[dd.rdPos:dd.wrPos]
dd.rdPos = dd.wrPos
if dd.wrPos == len(dd.hist) {
dd.wrPos, dd.rdPos = 0, 0
dd.full = true
}
return toRead
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package flate
import (
"io"
)
const (
// The largest offset code.
offsetCodeCount = 30
// The special code used to mark the end of a block.
endBlockMarker = 256
// The first length code.
lengthCodesStart = 257
// The number of codegen codes.
codegenCodeCount = 19
badCode = 255
// bufferFlushSize indicates the buffer size
// after which bytes are flushed to the writer.
// Should preferably be a multiple of 6, since
// we accumulate 6 bytes between writes to the buffer.
bufferFlushSize = 240
// bufferSize is the actual output byte buffer size.
// It must have additional headroom for a flush
// which can contain up to 8 bytes.
bufferSize = bufferFlushSize + 8
)
// The number of extra bits needed by length code X - LENGTH_CODES_START.
var lengthExtraBits = []int8{
/* 257 */ 0, 0, 0,
/* 260 */ 0, 0, 0, 0, 0, 1, 1, 1, 1, 2,
/* 270 */ 2, 2, 2, 3, 3, 3, 3, 4, 4, 4,
/* 280 */ 4, 5, 5, 5, 5, 0,
}
// The length indicated by length code X - LENGTH_CODES_START.
var lengthBase = []uint32{
0, 1, 2, 3, 4, 5, 6, 7, 8, 10,
12, 14, 16, 20, 24, 28, 32, 40, 48, 56,
64, 80, 96, 112, 128, 160, 192, 224, 255,
}
// offset code word extra bits.
var offsetExtraBits = []int8{
0, 0, 0, 0, 1, 1, 2, 2, 3, 3,
4, 4, 5, 5, 6, 6, 7, 7, 8, 8,
9, 9, 10, 10, 11, 11, 12, 12, 13, 13,
}
var offsetBase = []uint32{
0x000000, 0x000001, 0x000002, 0x000003, 0x000004,
0x000006, 0x000008, 0x00000c, 0x000010, 0x000018,
0x000020, 0x000030, 0x000040, 0x000060, 0x000080,
0x0000c0, 0x000100, 0x000180, 0x000200, 0x000300,
0x000400, 0x000600, 0x000800, 0x000c00, 0x001000,
0x001800, 0x002000, 0x003000, 0x004000, 0x006000,
}
// The odd order in which the codegen code sizes are written.
var codegenOrder = []uint32{16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15}
type huffmanBitWriter struct {
// writer is the underlying writer.
// Do not use it directly; use the write method, which ensures
// that Write errors are sticky.
writer io.Writer
// Data waiting to be written is bytes[0:nbytes]
// and then the low nbits of bits. Data is always written
// sequentially into the bytes array.
bits uint64
nbits uint
bytes [bufferSize]byte
codegenFreq [codegenCodeCount]int32
nbytes int
literalFreq []int32
offsetFreq []int32
codegen []uint8
literalEncoding *huffmanEncoder
offsetEncoding *huffmanEncoder
codegenEncoding *huffmanEncoder
err error
}
func newHuffmanBitWriter(w io.Writer) *huffmanBitWriter {
return &huffmanBitWriter{
writer: w,
literalFreq: make([]int32, maxNumLit),
offsetFreq: make([]int32, offsetCodeCount),
codegen: make([]uint8, maxNumLit+offsetCodeCount+1),
literalEncoding: newHuffmanEncoder(maxNumLit),
codegenEncoding: newHuffmanEncoder(codegenCodeCount),
offsetEncoding: newHuffmanEncoder(offsetCodeCount),
}
}
func (w *huffmanBitWriter) reset(writer io.Writer) {
w.writer = writer
w.bits, w.nbits, w.nbytes, w.err = 0, 0, 0, nil
}
func (w *huffmanBitWriter) flush() {
if w.err != nil {
w.nbits = 0
return
}
n := w.nbytes
for w.nbits != 0 {
w.bytes[n] = byte(w.bits)
w.bits >>= 8
if w.nbits > 8 { // Avoid underflow
w.nbits -= 8
} else {
w.nbits = 0
}
n++
}
w.bits = 0
w.write(w.bytes[:n])
w.nbytes = 0
}
func (w *huffmanBitWriter) write(b []byte) {
if w.err != nil {
return
}
_, w.err = w.writer.Write(b)
}
func (w *huffmanBitWriter) writeBits(b int32, nb uint) {
if w.err != nil {
return
}
w.bits |= uint64(b) << w.nbits
w.nbits += nb
if w.nbits >= 48 {
bits := w.bits
w.bits >>= 48
w.nbits -= 48
n := w.nbytes
bytes := w.bytes[n : n+6]
bytes[0] = byte(bits)
bytes[1] = byte(bits >> 8)
bytes[2] = byte(bits >> 16)
bytes[3] = byte(bits >> 24)
bytes[4] = byte(bits >> 32)
bytes[5] = byte(bits >> 40)
n += 6
if n >= bufferFlushSize {
w.write(w.bytes[:n])
n = 0
}
w.nbytes = n
}
}
func (w *huffmanBitWriter) writeBytes(bytes []byte) {
if w.err != nil {
return
}
n := w.nbytes
if w.nbits&7 != 0 {
w.err = InternalError("writeBytes with unfinished bits")
return
}
for w.nbits != 0 {
w.bytes[n] = byte(w.bits)
w.bits >>= 8
w.nbits -= 8
n++
}
if n != 0 {
w.write(w.bytes[:n])
}
w.nbytes = 0
w.write(bytes)
}
// RFC 1951 3.2.7 specifies a special run-length encoding for specifying
// the literal and offset lengths arrays (which are concatenated into a single
// array). This method generates that run-length encoding.
//
// The result is written into the codegen array, and the frequencies
// of each code are written into the codegenFreq array.
// Codes 0-15 are single byte codes. Codes 16-18 are followed by additional
// information. Code badCode is an end marker
//
// numLiterals The number of literals in literalEncoding
// numOffsets The number of offsets in offsetEncoding
// litenc, offenc The literal and offset encoder to use
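//
// For example, a run of eight code lengths equal to 3 is reduced to a literal
// 3, then code 16 with a repeat count of 6 (stored as 6-3 = 3), then one more
// literal 3; a run of thirty zero lengths is reduced to a single code 18 with
// a repeat count of 30 (stored as 30-11 = 19).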
func (w *huffmanBitWriter) generateCodegen(numLiterals int, numOffsets int, litEnc, offEnc *huffmanEncoder) {
clear(w.codegenFreq[:])
// Note that we are using codegen both as a temporary variable for holding
// a copy of the frequencies, and as the place where we put the result.
// This is fine because the output is always shorter than the input used
// so far.
codegen := w.codegen // cache
// Copy the concatenated code sizes to codegen. Put a marker at the end.
cgnl := codegen[:numLiterals]
for i := range cgnl {
cgnl[i] = uint8(litEnc.codes[i].len)
}
cgnl = codegen[numLiterals : numLiterals+numOffsets]
for i := range cgnl {
cgnl[i] = uint8(offEnc.codes[i].len)
}
codegen[numLiterals+numOffsets] = badCode
size := codegen[0]
count := 1
outIndex := 0
for inIndex := 1; size != badCode; inIndex++ {
// INVARIANT: We have seen "count" copies of size that have not yet
// had output generated for them.
nextSize := codegen[inIndex]
if nextSize == size {
count++
continue
}
// We need to generate codegen indicating "count" of size.
if size != 0 {
codegen[outIndex] = size
outIndex++
w.codegenFreq[size]++
count--
for count >= 3 {
n := 6
if n > count {
n = count
}
codegen[outIndex] = 16
outIndex++
codegen[outIndex] = uint8(n - 3)
outIndex++
w.codegenFreq[16]++
count -= n
}
} else {
for count >= 11 {
n := 138
if n > count {
n = count
}
codegen[outIndex] = 18
outIndex++
codegen[outIndex] = uint8(n - 11)
outIndex++
w.codegenFreq[18]++
count -= n
}
if count >= 3 {
// count >= 3 && count <= 10
codegen[outIndex] = 17
outIndex++
codegen[outIndex] = uint8(count - 3)
outIndex++
w.codegenFreq[17]++
count = 0
}
}
count--
for ; count >= 0; count-- {
codegen[outIndex] = size
outIndex++
w.codegenFreq[size]++
}
// Set up invariant for next time through the loop.
size = nextSize
count = 1
}
// Marker indicating the end of the codegen.
codegen[outIndex] = badCode
}
// dynamicSize returns the size of dynamically encoded data in bits.
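// The header term below counts the 3-bit block header, the 5-bit HLIT,
// 5-bit HDIST and 4-bit HCLEN fields, 3 bits for each transmitted code
// length code, the Huffman-coded codegen sequence itself, and the extra
// bits carried by codegen codes 16 (2 bits), 17 (3 bits) and 18 (7 bits);
// see RFC 1951 section 3.2.7.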
func (w *huffmanBitWriter) dynamicSize(litEnc, offEnc *huffmanEncoder, extraBits int) (size, numCodegens int) {
numCodegens = len(w.codegenFreq)
for numCodegens > 4 && w.codegenFreq[codegenOrder[numCodegens-1]] == 0 {
numCodegens--
}
header := 3 + 5 + 5 + 4 + (3 * numCodegens) +
w.codegenEncoding.bitLength(w.codegenFreq[:]) +
int(w.codegenFreq[16])*2 +
int(w.codegenFreq[17])*3 +
int(w.codegenFreq[18])*7
size = header +
litEnc.bitLength(w.literalFreq) +
offEnc.bitLength(w.offsetFreq) +
extraBits
return size, numCodegens
}
// fixedSize returns the size of the fixed-Huffman-encoded data in bits.
func (w *huffmanBitWriter) fixedSize(extraBits int) int {
return 3 +
fixedLiteralEncoding.bitLength(w.literalFreq) +
fixedOffsetEncoding.bitLength(w.offsetFreq) +
extraBits
}
// storedSize calculates the stored size, including header.
// The function returns the size in bits and whether the data
// fits inside a single stored block.
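// The 5 bytes of per-block overhead are at most one byte to finish the
// current byte after the 3-bit block header, plus the 2-byte LEN and 2-byte
// NLEN fields of a stored block (RFC 1951 section 3.2.4).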
func (w *huffmanBitWriter) storedSize(in []byte) (int, bool) {
if in == nil {
return 0, false
}
if len(in) <= maxStoreBlockSize {
return (len(in) + 5) * 8, true
}
return 0, false
}
func (w *huffmanBitWriter) writeCode(c hcode) {
if w.err != nil {
return
}
w.bits |= uint64(c.code) << w.nbits
w.nbits += uint(c.len)
if w.nbits >= 48 {
bits := w.bits
w.bits >>= 48
w.nbits -= 48
n := w.nbytes
bytes := w.bytes[n : n+6]
bytes[0] = byte(bits)
bytes[1] = byte(bits >> 8)
bytes[2] = byte(bits >> 16)
bytes[3] = byte(bits >> 24)
bytes[4] = byte(bits >> 32)
bytes[5] = byte(bits >> 40)
n += 6
if n >= bufferFlushSize {
w.write(w.bytes[:n])
n = 0
}
w.nbytes = n
}
}
// Write the header of a dynamic Huffman block to the output stream.
//
// numLiterals The number of literals specified in codegen
// numOffsets The number of offsets specified in codegen
// numCodegens The number of codegens used in codegen
func (w *huffmanBitWriter) writeDynamicHeader(numLiterals int, numOffsets int, numCodegens int, isEof bool) {
if w.err != nil {
return
}
var firstBits int32 = 4
if isEof {
firstBits = 5
}
w.writeBits(firstBits, 3)
w.writeBits(int32(numLiterals-257), 5)
w.writeBits(int32(numOffsets-1), 5)
w.writeBits(int32(numCodegens-4), 4)
for i := 0; i < numCodegens; i++ {
value := uint(w.codegenEncoding.codes[codegenOrder[i]].len)
w.writeBits(int32(value), 3)
}
i := 0
for {
var codeWord int = int(w.codegen[i])
i++
if codeWord == badCode {
break
}
w.writeCode(w.codegenEncoding.codes[uint32(codeWord)])
switch codeWord {
case 16:
w.writeBits(int32(w.codegen[i]), 2)
i++
case 17:
w.writeBits(int32(w.codegen[i]), 3)
i++
case 18:
w.writeBits(int32(w.codegen[i]), 7)
i++
}
}
}
func (w *huffmanBitWriter) writeStoredHeader(length int, isEof bool) {
if w.err != nil {
return
}
var flag int32
if isEof {
flag = 1
}
w.writeBits(flag, 3)
w.flush()
w.writeBits(int32(length), 16)
w.writeBits(int32(^uint16(length)), 16)
}
func (w *huffmanBitWriter) writeFixedHeader(isEof bool) {
if w.err != nil {
return
}
// Indicate that we are a fixed Huffman block
var value int32 = 2
if isEof {
value = 3
}
w.writeBits(value, 3)
}
// writeBlock will write a block of tokens with the smallest encoding.
// The original input can be supplied, and if the huffman encoded data
// is larger than the original bytes, the data will be written as a
// stored block.
// If the input is nil, the tokens will always be Huffman encoded.
func (w *huffmanBitWriter) writeBlock(tokens []token, eof bool, input []byte) {
if w.err != nil {
return
}
tokens = append(tokens, endBlockMarker)
numLiterals, numOffsets := w.indexTokens(tokens)
var extraBits int
storedSize, storable := w.storedSize(input)
if storable {
// We only bother calculating the costs of the extra bits required by
// the length of offset fields (which will be the same for both fixed
// and dynamic encoding), if we need to compare those two encodings
// against stored encoding.
for lengthCode := lengthCodesStart + 8; lengthCode < numLiterals; lengthCode++ {
// First eight length codes have extra size = 0.
extraBits += int(w.literalFreq[lengthCode]) * int(lengthExtraBits[lengthCode-lengthCodesStart])
}
for offsetCode := 4; offsetCode < numOffsets; offsetCode++ {
// First four offset codes have extra size = 0.
extraBits += int(w.offsetFreq[offsetCode]) * int(offsetExtraBits[offsetCode])
}
}
// Figure out smallest code.
// Fixed Huffman baseline.
var literalEncoding = fixedLiteralEncoding
var offsetEncoding = fixedOffsetEncoding
var size = w.fixedSize(extraBits)
// Dynamic Huffman?
var numCodegens int
// Generate codegen and codegenFrequencies, which indicates how to encode
// the literalEncoding and the offsetEncoding.
w.generateCodegen(numLiterals, numOffsets, w.literalEncoding, w.offsetEncoding)
w.codegenEncoding.generate(w.codegenFreq[:], 7)
dynamicSize, numCodegens := w.dynamicSize(w.literalEncoding, w.offsetEncoding, extraBits)
if dynamicSize < size {
size = dynamicSize
literalEncoding = w.literalEncoding
offsetEncoding = w.offsetEncoding
}
// Stored bytes?
if storable && storedSize < size {
w.writeStoredHeader(len(input), eof)
w.writeBytes(input)
return
}
// Huffman.
if literalEncoding == fixedLiteralEncoding {
w.writeFixedHeader(eof)
} else {
w.writeDynamicHeader(numLiterals, numOffsets, numCodegens, eof)
}
// Write the tokens.
w.writeTokens(tokens, literalEncoding.codes, offsetEncoding.codes)
}
// writeBlockDynamic encodes a block using a dynamic Huffman table.
// This should be used if the symbols used have a disproportionate
// histogram distribution.
// If input is supplied and the compression savings are below 1/16th of the
// input size the block is stored.
func (w *huffmanBitWriter) writeBlockDynamic(tokens []token, eof bool, input []byte) {
if w.err != nil {
return
}
tokens = append(tokens, endBlockMarker)
numLiterals, numOffsets := w.indexTokens(tokens)
// Generate codegen and codegenFrequencies, which indicates how to encode
// the literalEncoding and the offsetEncoding.
w.generateCodegen(numLiterals, numOffsets, w.literalEncoding, w.offsetEncoding)
w.codegenEncoding.generate(w.codegenFreq[:], 7)
size, numCodegens := w.dynamicSize(w.literalEncoding, w.offsetEncoding, 0)
// Store bytes, if we don't get a reasonable improvement.
if ssize, storable := w.storedSize(input); storable && ssize < (size+size>>4) {
w.writeStoredHeader(len(input), eof)
w.writeBytes(input)
return
}
// Write Huffman table.
w.writeDynamicHeader(numLiterals, numOffsets, numCodegens, eof)
// Write the tokens.
w.writeTokens(tokens, w.literalEncoding.codes, w.offsetEncoding.codes)
}
// indexTokens indexes a slice of tokens, updates
// literalFreq and offsetFreq, and generates literalEncoding
// and offsetEncoding.
// The number of literal and offset tokens is returned.
func (w *huffmanBitWriter) indexTokens(tokens []token) (numLiterals, numOffsets int) {
clear(w.literalFreq)
clear(w.offsetFreq)
for _, t := range tokens {
if t < matchType {
w.literalFreq[t.literal()]++
continue
}
length := t.length()
offset := t.offset()
w.literalFreq[lengthCodesStart+lengthCode(length)]++
w.offsetFreq[offsetCode(offset)]++
}
// get the number of literals
numLiterals = len(w.literalFreq)
for w.literalFreq[numLiterals-1] == 0 {
numLiterals--
}
// get the number of offsets
numOffsets = len(w.offsetFreq)
for numOffsets > 0 && w.offsetFreq[numOffsets-1] == 0 {
numOffsets--
}
if numOffsets == 0 {
// We haven't found a single match. If we want to go with the dynamic encoding,
// we should count at least one offset to be sure that the offset huffman tree could be encoded.
w.offsetFreq[0] = 1
numOffsets = 1
}
w.literalEncoding.generate(w.literalFreq, 15)
w.offsetEncoding.generate(w.offsetFreq, 15)
return
}
// writeTokens writes a slice of tokens to the output.
// codes for literal and offset encoding must be supplied.
func (w *huffmanBitWriter) writeTokens(tokens []token, leCodes, oeCodes []hcode) {
if w.err != nil {
return
}
for _, t := range tokens {
if t < matchType {
w.writeCode(leCodes[t.literal()])
continue
}
// Write the length
length := t.length()
lengthCode := lengthCode(length)
w.writeCode(leCodes[lengthCode+lengthCodesStart])
extraLengthBits := uint(lengthExtraBits[lengthCode])
if extraLengthBits > 0 {
extraLength := int32(length - lengthBase[lengthCode])
w.writeBits(extraLength, extraLengthBits)
}
// Write the offset
offset := t.offset()
offsetCode := offsetCode(offset)
w.writeCode(oeCodes[offsetCode])
extraOffsetBits := uint(offsetExtraBits[offsetCode])
if extraOffsetBits > 0 {
extraOffset := int32(offset - offsetBase[offsetCode])
w.writeBits(extraOffset, extraOffsetBits)
}
}
}
// huffOffset is a static offset encoder used for huffman only encoding.
// It can be reused since we will not be encoding offset values.
var huffOffset *huffmanEncoder
func init() {
offsetFreq := make([]int32, offsetCodeCount)
offsetFreq[0] = 1
huffOffset = newHuffmanEncoder(offsetCodeCount)
huffOffset.generate(offsetFreq, 15)
}
// writeBlockHuff encodes a block of bytes as either
// Huffman encoded literals or uncompressed bytes if the
// result gains very little from compression.
func (w *huffmanBitWriter) writeBlockHuff(eof bool, input []byte) {
if w.err != nil {
return
}
// Clear histogram
clear(w.literalFreq)
// Add everything as literals
histogram(input, w.literalFreq)
w.literalFreq[endBlockMarker] = 1
const numLiterals = endBlockMarker + 1
w.offsetFreq[0] = 1
const numOffsets = 1
w.literalEncoding.generate(w.literalFreq, 15)
// Figure out smallest code.
// Always use dynamic Huffman or Store
var numCodegens int
// Generate codegen and codegenFrequencies, which indicates how to encode
// the literalEncoding and the offsetEncoding.
w.generateCodegen(numLiterals, numOffsets, w.literalEncoding, huffOffset)
w.codegenEncoding.generate(w.codegenFreq[:], 7)
size, numCodegens := w.dynamicSize(w.literalEncoding, huffOffset, 0)
// Store bytes, if we don't get a reasonable improvement.
if ssize, storable := w.storedSize(input); storable && ssize < (size+size>>4) {
w.writeStoredHeader(len(input), eof)
w.writeBytes(input)
return
}
// Huffman.
w.writeDynamicHeader(numLiterals, numOffsets, numCodegens, eof)
encoding := w.literalEncoding.codes[:257]
n := w.nbytes
for _, t := range input {
// Bitwriting inlined, ~30% speedup
c := encoding[t]
w.bits |= uint64(c.code) << w.nbits
w.nbits += uint(c.len)
if w.nbits < 48 {
continue
}
// Store 6 bytes
bits := w.bits
w.bits >>= 48
w.nbits -= 48
bytes := w.bytes[n : n+6]
bytes[0] = byte(bits)
bytes[1] = byte(bits >> 8)
bytes[2] = byte(bits >> 16)
bytes[3] = byte(bits >> 24)
bytes[4] = byte(bits >> 32)
bytes[5] = byte(bits >> 40)
n += 6
if n < bufferFlushSize {
continue
}
w.write(w.bytes[:n])
if w.err != nil {
return // Return early in the event of write failures
}
n = 0
}
w.nbytes = n
w.writeCode(encoding[endBlockMarker])
}
// histogram accumulates a histogram of b in h.
//
// len(h) must be >= 256, and h's elements must be all zeroes.
func histogram(b []byte, h []int32) {
h = h[:256]
for _, t := range b {
h[t]++
}
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package flate
import (
"math"
"math/bits"
"sort"
)
// hcode is a huffman code with a bit code and bit length.
type hcode struct {
code, len uint16
}
type huffmanEncoder struct {
codes []hcode
freqcache []literalNode
bitCount [17]int32
lns byLiteral // stored to avoid repeated allocation in generate
lfs byFreq // stored to avoid repeated allocation in generate
}
type literalNode struct {
literal uint16
freq int32
}
// A levelInfo describes the state of the constructed tree for a given depth.
type levelInfo struct {
// Our level, for better printing.
level int32
// The frequency of the last node at this level
lastFreq int32
// The frequency of the next character to add to this level
nextCharFreq int32
// The frequency of the next pair (from level below) to add to this level.
// Only valid if the "needed" value of the next lower level is 0.
nextPairFreq int32
// The number of chains remaining to generate for this level before moving
// up to the next level
needed int32
}
// set sets the code and length of an hcode.
func (h *hcode) set(code uint16, length uint16) {
h.len = length
h.code = code
}
func maxNode() literalNode { return literalNode{math.MaxUint16, math.MaxInt32} }
func newHuffmanEncoder(size int) *huffmanEncoder {
return &huffmanEncoder{codes: make([]hcode, size)}
}
// Generates a HuffmanCode corresponding to the fixed literal table.
func generateFixedLiteralEncoding() *huffmanEncoder {
h := newHuffmanEncoder(maxNumLit)
codes := h.codes
var ch uint16
for ch = 0; ch < maxNumLit; ch++ {
var bits uint16
var size uint16
switch {
case ch < 144:
// size 8, 00110000 .. 10111111
bits = ch + 48
size = 8
case ch < 256:
// size 9, 110010000 .. 111111111
bits = ch + 400 - 144
size = 9
case ch < 280:
// size 7, 0000000 .. 0010111
bits = ch - 256
size = 7
default:
// size 8, 11000000 .. 11000111
bits = ch + 192 - 280
size = 8
}
codes[ch] = hcode{code: reverseBits(bits, byte(size)), len: size}
}
return h
}
func generateFixedOffsetEncoding() *huffmanEncoder {
h := newHuffmanEncoder(30)
codes := h.codes
for ch := range codes {
codes[ch] = hcode{code: reverseBits(uint16(ch), 5), len: 5}
}
return h
}
var fixedLiteralEncoding *huffmanEncoder = generateFixedLiteralEncoding()
var fixedOffsetEncoding *huffmanEncoder = generateFixedOffsetEncoding()
func (h *huffmanEncoder) bitLength(freq []int32) int {
var total int
for i, f := range freq {
if f != 0 {
total += int(f) * int(h.codes[i].len)
}
}
return total
}
const maxBitsLimit = 16
// bitCounts computes the number of literals assigned to each bit size in the Huffman encoding.
// It is only called when len(list) >= 3.
// The cases of 0, 1, and 2 literals are handled by special case code.
//
// list is an array of the literals with non-zero frequencies
// and their associated frequencies. The array is in order of increasing
// frequency and has as its last element a special element with frequency
// MaxInt32.
//
// maxBits is the maximum number of bits that should be used to encode any literal.
// It must be less than 16.
//
// bitCounts returns an integer slice in which slice[i] indicates the number of literals
// that should be encoded in i bits.
func (h *huffmanEncoder) bitCounts(list []literalNode, maxBits int32) []int32 {
if maxBits >= maxBitsLimit {
panic("flate: maxBits too large")
}
n := int32(len(list))
list = list[0 : n+1]
list[n] = maxNode()
// The tree can't have greater depth than n - 1, no matter what. This
// saves a little bit of work in some small cases
if maxBits > n-1 {
maxBits = n - 1
}
// Create information about each of the levels.
// A bogus "Level 0" whose sole purpose is so that
// level1.prev.needed==0. This makes level1.nextPairFreq
// be a legitimate value that never gets chosen.
var levels [maxBitsLimit]levelInfo
// leafCounts[i] counts the number of literals at the left
// of ancestors of the rightmost node at level i.
// leafCounts[i][j] is the number of literals at the left
// of the level j ancestor.
var leafCounts [maxBitsLimit][maxBitsLimit]int32
for level := int32(1); level <= maxBits; level++ {
// For every level, the first two items are the first two characters.
// We initialize the levels as if we had already figured this out.
levels[level] = levelInfo{
level: level,
lastFreq: list[1].freq,
nextCharFreq: list[2].freq,
nextPairFreq: list[0].freq + list[1].freq,
}
leafCounts[level][level] = 2
if level == 1 {
levels[level].nextPairFreq = math.MaxInt32
}
}
// We need a total of 2*n - 2 items at top level and have already generated 2.
levels[maxBits].needed = 2*n - 4
level := maxBits
for {
l := &levels[level]
if l.nextPairFreq == math.MaxInt32 && l.nextCharFreq == math.MaxInt32 {
// We've run out of both leaves and pairs.
// End all calculations for this level.
// To make sure we never come back to this level or any lower level,
// set nextPairFreq impossibly large.
l.needed = 0
levels[level+1].nextPairFreq = math.MaxInt32
level++
continue
}
prevFreq := l.lastFreq
if l.nextCharFreq < l.nextPairFreq {
// The next item on this row is a leaf node.
n := leafCounts[level][level] + 1
l.lastFreq = l.nextCharFreq
// Lower leafCounts are the same as those of the previous node.
leafCounts[level][level] = n
l.nextCharFreq = list[n].freq
} else {
// The next item on this row is a pair from the previous row.
// nextPairFreq isn't valid until we generate two
// more values in the level below
l.lastFreq = l.nextPairFreq
// Take leaf counts from the lower level, except counts[level] remains the same.
copy(leafCounts[level][:level], leafCounts[level-1][:level])
levels[l.level-1].needed = 2
}
if l.needed--; l.needed == 0 {
// We've done everything we need to do for this level.
// Continue calculating one level up. Fill in nextPairFreq
// of that level with the sum of the two nodes we've just calculated on
// this level.
if l.level == maxBits {
// All done!
break
}
levels[l.level+1].nextPairFreq = prevFreq + l.lastFreq
level++
} else {
// If we stole from below, move down temporarily to replenish it.
for levels[level-1].needed > 0 {
level--
}
}
}
// Something is wrong if, at the end, the top level is null or hasn't used
// all of the leaves.
if leafCounts[maxBits][maxBits] != n {
panic("leafCounts[maxBits][maxBits] != n")
}
bitCount := h.bitCount[:maxBits+1]
bits := 1
counts := &leafCounts[maxBits]
for level := maxBits; level > 0; level-- {
// chain.leafCount gives the number of literals requiring at least "bits"
// bits to encode.
bitCount[bits] = counts[level] - counts[level-1]
bits++
}
return bitCount
}
// Look at the leaves and assign them a bit count and an encoding as specified
// in RFC 1951 3.2.2
func (h *huffmanEncoder) assignEncodingAndSize(bitCount []int32, list []literalNode) {
code := uint16(0)
for n, bits := range bitCount {
code <<= 1
if n == 0 || bits == 0 {
continue
}
// The literals list[len(list)-bits] .. list[len(list)-1]
// are encoded using "bits" bits, and get the values
// code, code + 1, .... The code values are
// assigned in literal order (not frequency order).
chunk := list[len(list)-int(bits):]
h.lns.sort(chunk)
for _, node := range chunk {
h.codes[node.literal] = hcode{code: reverseBits(code, uint8(n)), len: uint16(n)}
code++
}
list = list[0 : len(list)-int(bits)]
}
}
// Update this Huffman Code object to be the minimum code for the specified frequency count.
//
// freq is an array of frequencies, in which freq[i] gives the frequency of literal i.
// maxBits The maximum number of bits to use for any literal.
func (h *huffmanEncoder) generate(freq []int32, maxBits int32) {
if h.freqcache == nil {
// Allocate a reusable buffer with the longest possible frequency table.
// Possible lengths are codegenCodeCount, offsetCodeCount and maxNumLit.
// The largest of these is maxNumLit, so we allocate for that case.
h.freqcache = make([]literalNode, maxNumLit+1)
}
list := h.freqcache[:len(freq)+1]
// Number of non-zero literals
count := 0
// Set list to be the set of all non-zero literals and their frequencies
for i, f := range freq {
if f != 0 {
list[count] = literalNode{uint16(i), f}
count++
} else {
h.codes[i].len = 0
}
}
list = list[:count]
if count <= 2 {
// Handle the small cases here, because they are awkward for the general case code. With
// two or fewer literals, everything has bit length 1.
for i, node := range list {
// "list" is in order of increasing literal value.
h.codes[node.literal].set(uint16(i), 1)
}
return
}
h.lfs.sort(list)
// Get the number of literals for each bit count
bitCount := h.bitCounts(list, maxBits)
// And do the assignment
h.assignEncodingAndSize(bitCount, list)
}
type byLiteral []literalNode
func (s *byLiteral) sort(a []literalNode) {
*s = byLiteral(a)
sort.Sort(s)
}
func (s byLiteral) Len() int { return len(s) }
func (s byLiteral) Less(i, j int) bool {
return s[i].literal < s[j].literal
}
func (s byLiteral) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
type byFreq []literalNode
func (s *byFreq) sort(a []literalNode) {
*s = byFreq(a)
sort.Sort(s)
}
func (s byFreq) Len() int { return len(s) }
func (s byFreq) Less(i, j int) bool {
if s[i].freq == s[j].freq {
return s[i].literal < s[j].literal
}
return s[i].freq < s[j].freq
}
func (s byFreq) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
func reverseBits(number uint16, bitLength byte) uint16 {
return bits.Reverse16(number << (16 - bitLength))
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package flate implements the DEFLATE compressed data format, described in
// RFC 1951. The [compress/gzip] and [compress/zlib] packages implement access
// to DEFLATE-based file formats.
package flate
import (
"bufio"
"io"
"math/bits"
"strconv"
"sync"
)
const (
maxCodeLen = 16 // max length of Huffman code
// The next three numbers come from the RFC section 3.2.7, with the
// additional proviso in section 3.2.5 which implies that distance codes
// 30 and 31 should never occur in compressed data.
maxNumLit = 286
maxNumDist = 30
numCodes = 19 // number of codes in Huffman meta-code
)
// Initialize the fixedHuffmanDecoder only once upon first use.
var fixedOnce sync.Once
var fixedHuffmanDecoder huffmanDecoder
// A CorruptInputError reports the presence of corrupt input at a given offset.
type CorruptInputError int64
func (e CorruptInputError) Error() string {
return "flate: corrupt input before offset " + strconv.FormatInt(int64(e), 10)
}
// An InternalError reports an error in the flate code itself.
type InternalError string
func (e InternalError) Error() string { return "flate: internal error: " + string(e) }
// A ReadError reports an error encountered while reading input.
//
// Deprecated: No longer returned.
type ReadError struct {
Offset int64 // byte offset where error occurred
Err error // error returned by underlying Read
}
func (e *ReadError) Error() string {
return "flate: read error at offset " + strconv.FormatInt(e.Offset, 10) + ": " + e.Err.Error()
}
// A WriteError reports an error encountered while writing output.
//
// Deprecated: No longer returned.
type WriteError struct {
Offset int64 // byte offset where error occurred
Err error // error returned by underlying Write
}
func (e *WriteError) Error() string {
return "flate: write error at offset " + strconv.FormatInt(e.Offset, 10) + ": " + e.Err.Error()
}
// Resetter resets a ReadCloser returned by [NewReader] or [NewReaderDict]
// to switch to a new underlying [Reader]. This permits reusing a ReadCloser
// instead of allocating a new one.
type Resetter interface {
// Reset discards any buffered data and resets the Resetter as if it was
// newly initialized with the given reader.
Reset(r io.Reader, dict []byte) error
}
// The data structure for decoding Huffman tables is based on that of
// zlib. There is a lookup table of a fixed bit width (huffmanChunkBits).
// For codes smaller than the table width, there are multiple entries
// (each combination of trailing bits has the same value). For codes
// larger than the table width, the table contains a link to an overflow
// table. The width of each entry in the link table is the maximum code
// size minus the chunk width.
//
// Note that you can do a lookup in the table even without all bits
// filled. Since the extra bits are zero, and the DEFLATE Huffman codes
// have the property that shorter codes come before longer ones, the
// bit length estimate in the result is a lower bound on the actual
// number of bits.
//
// See the following:
// https://github.com/madler/zlib/raw/master/doc/algorithm.txt
// chunk & 15 is number of bits
// chunk >> 4 is value, including table link
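//
// In outline, one decode step consults the tables like this (b holds the
// buffered input bits, least significant bit first; refilling b when too few
// bits are available is omitted):
//
//	chunk := h.chunks[b&(huffmanNumChunks-1)]
//	n := chunk & huffmanCountMask // bits consumed by this code
//	if n > huffmanChunkBits {     // indirect entry: finish in the link table
//		chunk = h.links[chunk>>huffmanValueShift][(b>>huffmanChunkBits)&h.linkMask]
//		n = chunk & huffmanCountMask
//	}
//	sym := chunk >> huffmanValueShift // the decoded symbol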
const (
huffmanChunkBits = 9
huffmanNumChunks = 1 << huffmanChunkBits
huffmanCountMask = 15
huffmanValueShift = 4
)
type huffmanDecoder struct {
min int // the minimum code length
chunks [huffmanNumChunks]uint32 // chunks as described above
links [][]uint32 // overflow links
linkMask uint32 // mask the width of the link table
}
// Initialize Huffman decoding tables from array of code lengths.
// Following this function, h is guaranteed to be initialized into a complete
// tree (i.e., neither over-subscribed nor under-subscribed). The exception is a
// degenerate case where the tree has only a single symbol with length 1. Empty
// trees are permitted.
func (h *huffmanDecoder) init(lengths []int) bool {
// Sanity enables additional runtime tests during Huffman
// table construction. It's intended to be used during
// development to supplement the currently ad-hoc unit tests.
const sanity = false
if h.min != 0 {
*h = huffmanDecoder{}
}
// Count number of codes of each length,
// compute min and max length.
var count [maxCodeLen]int
var min, max int
for _, n := range lengths {
if n == 0 {
continue
}
if min == 0 || n < min {
min = n
}
if n > max {
max = n
}
count[n]++
}
// Empty tree. The decompressor.huffSym function will fail later if the tree
// is used. Technically, an empty tree is only valid for the HDIST tree and
// not the HCLEN and HLIT tree. However, a stream with an empty HCLEN tree
// is guaranteed to fail since it will attempt to use the tree to decode the
// codes for the HLIT and HDIST trees. Similarly, an empty HLIT tree is
// guaranteed to fail later since the compressed data section must be
// composed of at least one symbol (the end-of-block marker).
if max == 0 {
return true
}
code := 0
var nextcode [maxCodeLen]int
for i := min; i <= max; i++ {
code <<= 1
nextcode[i] = code
code += count[i]
}
// Check that the coding is complete (i.e., that we've
// assigned all 2-to-the-max possible bit sequences).
// Exception: To be compatible with zlib, we also need to
// accept degenerate single-code codings. See also
// TestDegenerateHuffmanCoding.
if code != 1<<uint(max) && !(code == 1 && max == 1) {
return false
}
h.min = min
if max > huffmanChunkBits {
numLinks := 1 << (uint(max) - huffmanChunkBits)
h.linkMask = uint32(numLinks - 1)
// create link tables
link := nextcode[huffmanChunkBits+1] >> 1
h.links = make([][]uint32, huffmanNumChunks-link)
for j := uint(link); j < huffmanNumChunks; j++ {
reverse := int(bits.Reverse16(uint16(j)))
reverse >>= uint(16 - huffmanChunkBits)
off := j - uint(link)
if sanity && h.chunks[reverse] != 0 {
panic("impossible: overwriting existing chunk")
}
h.chunks[reverse] = uint32(off<<huffmanValueShift | (huffmanChunkBits + 1))
h.links[off] = make([]uint32, numLinks)
}
}
for i, n := range lengths {
if n == 0 {
continue
}
code := nextcode[n]
nextcode[n]++
chunk := uint32(i<<huffmanValueShift | n)
reverse := int(bits.Reverse16(uint16(code)))
reverse >>= uint(16 - n)
if n <= huffmanChunkBits {
for off := reverse; off < len(h.chunks); off += 1 << uint(n) {
// We should never need to overwrite
// an existing chunk. Also, 0 is
// never a valid chunk, because the
// lower 4 "count" bits should be
// between 1 and 15.
if sanity && h.chunks[off] != 0 {
panic("impossible: overwriting existing chunk")
}
h.chunks[off] = chunk
}
} else {
j := reverse & (huffmanNumChunks - 1)
if sanity && h.chunks[j]&huffmanCountMask != huffmanChunkBits+1 {
// Longer codes should have been
// associated with a link table above.
panic("impossible: not an indirect chunk")
}
value := h.chunks[j] >> huffmanValueShift
linktab := h.links[value]
reverse >>= huffmanChunkBits
for off := reverse; off < len(linktab); off += 1 << uint(n-huffmanChunkBits) {
if sanity && linktab[off] != 0 {
panic("impossible: overwriting existing chunk")
}
linktab[off] = chunk
}
}
}
if sanity {
// Above we've sanity checked that we never overwrote
// an existing entry. Here we additionally check that
// we filled the tables completely.
for i, chunk := range h.chunks {
if chunk == 0 {
// As an exception, in the degenerate
// single-code case, we allow odd
// chunks to be missing.
if code == 1 && i%2 == 1 {
continue
}
panic("impossible: missing chunk")
}
}
for _, linktab := range h.links {
for _, chunk := range linktab {
if chunk == 0 {
panic("impossible: missing chunk")
}
}
}
}
return true
}
// The actual read interface needed by [NewReader].
// If the passed-in [io.Reader] does not also have ReadByte,
// [NewReader] will introduce its own buffering.
type Reader interface {
io.Reader
io.ByteReader
}
// Decompress state.
type decompressor struct {
// Input source.
r Reader
rBuf *bufio.Reader // created if provided io.Reader does not implement io.ByteReader
roffset int64
// Input bits, in top of b.
b uint32
nb uint
// Huffman decoders for literal/length, distance.
h1, h2 huffmanDecoder
// Length arrays used to define Huffman codes.
bits *[maxNumLit + maxNumDist]int
codebits *[numCodes]int
// Output history, buffer.
dict dictDecoder
// Temporary buffer (avoids repeated allocation).
buf [4]byte
// Next step in the decompression,
// and decompression state.
step func(*decompressor)
stepState int
final bool
err error
toRead []byte
hl, hd *huffmanDecoder
copyLen int
copyDist int
}
func (f *decompressor) nextBlock() {
for f.nb < 1+2 {
if f.err = f.moreBits(); f.err != nil {
return
}
}
f.final = f.b&1 == 1
f.b >>= 1
typ := f.b & 3
f.b >>= 2
f.nb -= 1 + 2
switch typ {
case 0:
f.dataBlock()
case 1:
// compressed, fixed Huffman tables
f.hl = &fixedHuffmanDecoder
f.hd = nil
f.huffmanBlock()
case 2:
// compressed, dynamic Huffman tables
if f.err = f.readHuffman(); f.err != nil {
break
}
f.hl = &f.h1
f.hd = &f.h2
f.huffmanBlock()
default:
// 3 is reserved.
f.err = CorruptInputError(f.roffset)
}
}
func (f *decompressor) Read(b []byte) (int, error) {
for {
if len(f.toRead) > 0 {
n := copy(b, f.toRead)
f.toRead = f.toRead[n:]
if len(f.toRead) == 0 {
return n, f.err
}
return n, nil
}
if f.err != nil {
return 0, f.err
}
f.step(f)
if f.err != nil && len(f.toRead) == 0 {
f.toRead = f.dict.readFlush() // Flush what's left in case of error
}
}
}
func (f *decompressor) Close() error {
if f.err == io.EOF {
return nil
}
return f.err
}
// RFC 1951 section 3.2.7.
// Compression with dynamic Huffman codes
var codeOrder = [...]int{16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15}
func (f *decompressor) readHuffman() error {
// HLIT[5], HDIST[5], HCLEN[4].
for f.nb < 5+5+4 {
if err := f.moreBits(); err != nil {
return err
}
}
nlit := int(f.b&0x1F) + 257
if nlit > maxNumLit {
return CorruptInputError(f.roffset)
}
f.b >>= 5
ndist := int(f.b&0x1F) + 1
if ndist > maxNumDist {
return CorruptInputError(f.roffset)
}
f.b >>= 5
nclen := int(f.b&0xF) + 4
// numCodes is 19, so nclen is always valid.
f.b >>= 4
f.nb -= 5 + 5 + 4
// (HCLEN+4)*3 bits: code lengths in the magic codeOrder order.
for i := 0; i < nclen; i++ {
for f.nb < 3 {
if err := f.moreBits(); err != nil {
return err
}
}
f.codebits[codeOrder[i]] = int(f.b & 0x7)
f.b >>= 3
f.nb -= 3
}
for i := nclen; i < len(codeOrder); i++ {
f.codebits[codeOrder[i]] = 0
}
if !f.h1.init(f.codebits[0:]) {
return CorruptInputError(f.roffset)
}
// HLIT + 257 code lengths, HDIST + 1 code lengths,
// using the code length Huffman code.
for i, n := 0, nlit+ndist; i < n; {
x, err := f.huffSym(&f.h1)
if err != nil {
return err
}
if x < 16 {
// Actual length.
f.bits[i] = x
i++
continue
}
// Repeat previous length or zero.
var rep int
var nb uint
var b int
switch x {
default:
return InternalError("unexpected length code")
case 16:
rep = 3
nb = 2
if i == 0 {
return CorruptInputError(f.roffset)
}
b = f.bits[i-1]
case 17:
rep = 3
nb = 3
b = 0
case 18:
rep = 11
nb = 7
b = 0
}
for f.nb < nb {
if err := f.moreBits(); err != nil {
return err
}
}
rep += int(f.b & uint32(1<<nb-1))
f.b >>= nb
f.nb -= nb
if i+rep > n {
return CorruptInputError(f.roffset)
}
for j := 0; j < rep; j++ {
f.bits[i] = b
i++
}
}
if !f.h1.init(f.bits[0:nlit]) || !f.h2.init(f.bits[nlit:nlit+ndist]) {
return CorruptInputError(f.roffset)
}
// As an optimization, we can initialize the min bits to read at a time
// for the HLIT tree to the length of the EOB marker since we know that
// every block must terminate with one. This preserves the property that
// we never read any extra bytes after the end of the DEFLATE stream.
if f.h1.min < f.bits[endBlockMarker] {
f.h1.min = f.bits[endBlockMarker]
}
return nil
}
// Decode a single Huffman block from f.
// hl and hd are the Huffman states for the lit/length values
// and the distance values, respectively. If hd == nil, using the
// fixed distance encoding associated with fixed Huffman blocks.
func (f *decompressor) huffmanBlock() {
const (
stateInit = iota // Zero value must be stateInit
stateDict
)
switch f.stepState {
case stateInit:
goto readLiteral
case stateDict:
goto copyHistory
}
readLiteral:
// Read literal and/or (length, distance) according to RFC section 3.2.3.
{
v, err := f.huffSym(f.hl)
if err != nil {
f.err = err
return
}
var n uint // number of bits extra
var length int
switch {
case v < 256:
f.dict.writeByte(byte(v))
if f.dict.availWrite() == 0 {
f.toRead = f.dict.readFlush()
f.step = (*decompressor).huffmanBlock
f.stepState = stateInit
return
}
goto readLiteral
case v == 256:
f.finishBlock()
return
// otherwise, reference to older data
case v < 265:
length = v - (257 - 3)
n = 0
case v < 269:
length = v*2 - (265*2 - 11)
n = 1
case v < 273:
length = v*4 - (269*4 - 19)
n = 2
case v < 277:
length = v*8 - (273*8 - 35)
n = 3
case v < 281:
length = v*16 - (277*16 - 67)
n = 4
case v < 285:
length = v*32 - (281*32 - 131)
n = 5
case v < maxNumLit:
length = 258
n = 0
default:
f.err = CorruptInputError(f.roffset)
return
}
if n > 0 {
for f.nb < n {
if err = f.moreBits(); err != nil {
f.err = err
return
}
}
length += int(f.b & uint32(1<<n-1))
f.b >>= n
f.nb -= n
}
var dist int
if f.hd == nil {
for f.nb < 5 {
if err = f.moreBits(); err != nil {
f.err = err
return
}
}
dist = int(bits.Reverse8(uint8(f.b & 0x1F << 3)))
f.b >>= 5
f.nb -= 5
} else {
if dist, err = f.huffSym(f.hd); err != nil {
f.err = err
return
}
}
switch {
case dist < 4:
dist++
case dist < maxNumDist:
nb := uint(dist-2) >> 1
// have 1 bit in bottom of dist, need nb more.
extra := (dist & 1) << nb
for f.nb < nb {
if err = f.moreBits(); err != nil {
f.err = err
return
}
}
extra |= int(f.b & uint32(1<<nb-1))
f.b >>= nb
f.nb -= nb
dist = 1<<(nb+1) + 1 + extra
default:
f.err = CorruptInputError(f.roffset)
return
}
// No check on length; encoding can be prescient.
if dist > f.dict.histSize() {
f.err = CorruptInputError(f.roffset)
return
}
f.copyLen, f.copyDist = length, dist
goto copyHistory
}
copyHistory:
// Perform a backwards copy according to RFC section 3.2.3.
{
cnt := f.dict.tryWriteCopy(f.copyDist, f.copyLen)
if cnt == 0 {
cnt = f.dict.writeCopy(f.copyDist, f.copyLen)
}
f.copyLen -= cnt
if f.dict.availWrite() == 0 || f.copyLen > 0 {
f.toRead = f.dict.readFlush()
f.step = (*decompressor).huffmanBlock // We need to continue this work
f.stepState = stateDict
return
}
goto readLiteral
}
}
// Copy a single uncompressed data block from input to output.
func (f *decompressor) dataBlock() {
// Uncompressed.
// Discard current half-byte.
f.nb = 0
f.b = 0
// Length then ones-complement of length.
nr, err := io.ReadFull(f.r, f.buf[0:4])
f.roffset += int64(nr)
if err != nil {
f.err = noEOF(err)
return
}
n := int(f.buf[0]) | int(f.buf[1])<<8
nn := int(f.buf[2]) | int(f.buf[3])<<8
if uint16(nn) != uint16(^n) {
f.err = CorruptInputError(f.roffset)
return
}
if n == 0 {
f.toRead = f.dict.readFlush()
f.finishBlock()
return
}
f.copyLen = n
f.copyData()
}
// copyData copies f.copyLen bytes from the underlying reader into f.dict.
// It pauses for reads when the write buffer of f.dict is full.
func (f *decompressor) copyData() {
buf := f.dict.writeSlice()
if len(buf) > f.copyLen {
buf = buf[:f.copyLen]
}
cnt, err := io.ReadFull(f.r, buf)
f.roffset += int64(cnt)
f.copyLen -= cnt
f.dict.writeMark(cnt)
if err != nil {
f.err = noEOF(err)
return
}
if f.dict.availWrite() == 0 || f.copyLen > 0 {
f.toRead = f.dict.readFlush()
f.step = (*decompressor).copyData
return
}
f.finishBlock()
}
func (f *decompressor) finishBlock() {
if f.final {
if f.dict.availRead() > 0 {
f.toRead = f.dict.readFlush()
}
f.err = io.EOF
}
f.step = (*decompressor).nextBlock
}
// noEOF returns err, unless err == io.EOF, in which case it returns io.ErrUnexpectedEOF.
func noEOF(e error) error {
if e == io.EOF {
return io.ErrUnexpectedEOF
}
return e
}
func (f *decompressor) moreBits() error {
c, err := f.r.ReadByte()
if err != nil {
return noEOF(err)
}
f.roffset++
f.b |= uint32(c) << f.nb
f.nb += 8
return nil
}
// Read the next Huffman-encoded symbol from f according to h.
func (f *decompressor) huffSym(h *huffmanDecoder) (int, error) {
// Since a huffmanDecoder can be empty or be composed of a degenerate tree
// with a single element, huffSym must error on these two edge cases. In both
// cases, the chunks slice will be 0 for the invalid sequence, leading it to
// satisfy the n == 0 check below.
n := uint(h.min)
// Optimization. Compiler isn't smart enough to keep f.b,f.nb in registers,
// but is smart enough to keep local variables in registers, so use nb and b,
// inline call to moreBits and reassign b,nb back to f on return.
nb, b := f.nb, f.b
for {
for nb < n {
c, err := f.r.ReadByte()
if err != nil {
f.b = b
f.nb = nb
return 0, noEOF(err)
}
f.roffset++
b |= uint32(c) << (nb & 31)
nb += 8
}
chunk := h.chunks[b&(huffmanNumChunks-1)]
n = uint(chunk & huffmanCountMask)
if n > huffmanChunkBits {
chunk = h.links[chunk>>huffmanValueShift][(b>>huffmanChunkBits)&h.linkMask]
n = uint(chunk & huffmanCountMask)
}
if n <= nb {
if n == 0 {
f.b = b
f.nb = nb
f.err = CorruptInputError(f.roffset)
return 0, f.err
}
f.b = b >> (n & 31)
f.nb = nb - n
return int(chunk >> huffmanValueShift), nil
}
}
}
func (f *decompressor) makeReader(r io.Reader) {
if rr, ok := r.(Reader); ok {
f.rBuf = nil
f.r = rr
return
}
// Reuse rBuf if possible. Invariant: rBuf is always created (and owned) by decompressor.
if f.rBuf != nil {
f.rBuf.Reset(r)
} else {
// bufio.NewReader will not return r itself: r does not implement flate.Reader, so it cannot be a *bufio.Reader.
f.rBuf = bufio.NewReader(r)
}
f.r = f.rBuf
}
func fixedHuffmanDecoderInit() {
fixedOnce.Do(func() {
// These come from the RFC section 3.2.6.
var bits [288]int
for i := 0; i < 144; i++ {
bits[i] = 8
}
for i := 144; i < 256; i++ {
bits[i] = 9
}
for i := 256; i < 280; i++ {
bits[i] = 7
}
for i := 280; i < 288; i++ {
bits[i] = 8
}
fixedHuffmanDecoder.init(bits[:])
})
}
func (f *decompressor) Reset(r io.Reader, dict []byte) error {
*f = decompressor{
rBuf: f.rBuf,
bits: f.bits,
codebits: f.codebits,
dict: f.dict,
step: (*decompressor).nextBlock,
}
f.makeReader(r)
f.dict.init(maxMatchOffset, dict)
return nil
}
// NewReader returns a new ReadCloser that can be used
// to read the uncompressed version of r.
// If r does not also implement [io.ByteReader],
// the decompressor may read more data than necessary from r.
// The reader returns [io.EOF] after the final block in the DEFLATE stream has
// been encountered. Any trailing data after the final block is ignored.
//
// The [io.ReadCloser] returned by NewReader also implements [Resetter].
func NewReader(r io.Reader) io.ReadCloser {
fixedHuffmanDecoderInit()
var f decompressor
f.makeReader(r)
f.bits = new([maxNumLit + maxNumDist]int)
f.codebits = new([numCodes]int)
f.step = (*decompressor).nextBlock
f.dict.init(maxMatchOffset, nil)
return &f
}
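// An illustrative usage sketch (not part of the original source): wrap a
// compressed source with [NewReader], copy out the decompressed bytes, and
// close the reader. The helper name and the "bytes", "io", and "os" imports
// are assumptions for the example.
//
//	func inflateToStdout(compressed []byte) error {
//		fr := flate.NewReader(bytes.NewReader(compressed))
//		defer fr.Close()
//		_, err := io.Copy(os.Stdout, fr)
//		return err
//	}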
// NewReaderDict is like [NewReader] but initializes the reader
// with a preset dictionary. The returned reader behaves as if
// the uncompressed data stream started with the given dictionary,
// which has already been read. NewReaderDict is typically used
// to read data compressed by [NewWriterDict].
//
// The ReadCloser returned by NewReaderDict also implements [Resetter].
func NewReaderDict(r io.Reader, dict []byte) io.ReadCloser {
fixedHuffmanDecoderInit()
var f decompressor
f.makeReader(r)
f.bits = new([maxNumLit + maxNumDist]int)
f.codebits = new([numCodes]int)
f.step = (*decompressor).nextBlock
f.dict.init(maxMatchOffset, dict)
return &f
}
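// A minimal round-trip sketch (illustrative, not from the original source)
// pairing [NewWriterDict] with [NewReaderDict]; both sides must be given the
// same preset dictionary. The helper name and the "bytes" and "io" imports
// are assumptions.
//
//	func roundTripWithDict(dict, data []byte) ([]byte, error) {
//		var buf bytes.Buffer
//		zw, err := flate.NewWriterDict(&buf, flate.DefaultCompression, dict)
//		if err != nil {
//			return nil, err
//		}
//		if _, err := zw.Write(data); err != nil {
//			return nil, err
//		}
//		if err := zw.Close(); err != nil {
//			return nil, err
//		}
//		zr := flate.NewReaderDict(&buf, dict)
//		defer zr.Close()
//		return io.ReadAll(zr)
//	}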
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package flate
const (
// 2 bits:  type     0 = literal  1 = EOF  2 = match  3 = unused
// 8 bits:  xlength = length - MIN_MATCH_LENGTH
// 22 bits: xoffset = offset - MIN_OFFSET_SIZE, or the literal value
lengthShift = 22
offsetMask = 1<<lengthShift - 1
typeMask = 3 << 30
literalType = 0 << 30
matchType = 1 << 30
)
// The length code for length X (MIN_MATCH_LENGTH <= X <= MAX_MATCH_LENGTH)
// is lengthCodes[X - MIN_MATCH_LENGTH].
var lengthCodes = [...]uint32{
0, 1, 2, 3, 4, 5, 6, 7, 8, 8,
9, 9, 10, 10, 11, 11, 12, 12, 12, 12,
13, 13, 13, 13, 14, 14, 14, 14, 15, 15,
15, 15, 16, 16, 16, 16, 16, 16, 16, 16,
17, 17, 17, 17, 17, 17, 17, 17, 18, 18,
18, 18, 18, 18, 18, 18, 19, 19, 19, 19,
19, 19, 19, 19, 20, 20, 20, 20, 20, 20,
20, 20, 20, 20, 20, 20, 20, 20, 20, 20,
21, 21, 21, 21, 21, 21, 21, 21, 21, 21,
21, 21, 21, 21, 21, 21, 22, 22, 22, 22,
22, 22, 22, 22, 22, 22, 22, 22, 22, 22,
22, 22, 23, 23, 23, 23, 23, 23, 23, 23,
23, 23, 23, 23, 23, 23, 23, 23, 24, 24,
24, 24, 24, 24, 24, 24, 24, 24, 24, 24,
24, 24, 24, 24, 24, 24, 24, 24, 24, 24,
24, 24, 24, 24, 24, 24, 24, 24, 24, 24,
25, 25, 25, 25, 25, 25, 25, 25, 25, 25,
25, 25, 25, 25, 25, 25, 25, 25, 25, 25,
25, 25, 25, 25, 25, 25, 25, 25, 25, 25,
25, 25, 26, 26, 26, 26, 26, 26, 26, 26,
26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
26, 26, 26, 26, 27, 27, 27, 27, 27, 27,
27, 27, 27, 27, 27, 27, 27, 27, 27, 27,
27, 27, 27, 27, 27, 27, 27, 27, 27, 27,
27, 27, 27, 27, 27, 28,
}
var offsetCodes = [...]uint32{
0, 1, 2, 3, 4, 4, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7,
8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9,
10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10,
11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11,
12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12,
12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12,
13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13,
13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13,
14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14,
14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14,
14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14,
14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14,
15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15,
15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15,
15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15,
15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15,
}
type token uint32
// Convert a literal into a literal token.
func literalToken(literal uint32) token { return token(literalType + literal) }
// Convert a < xlength, xoffset > pair into a match token.
func matchToken(xlength uint32, xoffset uint32) token {
return token(matchType + xlength<<lengthShift + xoffset)
}
// Returns the literal of a literal token.
func (t token) literal() uint32 { return uint32(t - literalType) }
// Returns the extra offset of a match token.
func (t token) offset() uint32 { return uint32(t) & offsetMask }
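// Returns the extra length (xlength) of a match token.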
func (t token) length() uint32 { return uint32((t - matchType) >> lengthShift) }
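// Returns the length code for a given xlength (length - MIN_MATCH_LENGTH).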
func lengthCode(len uint32) uint32 { return lengthCodes[len] }
// Returns the offset code corresponding to a specific offset.
func offsetCode(off uint32) uint32 {
if off < uint32(len(offsetCodes)) {
return offsetCodes[off]
}
if off>>7 < uint32(len(offsetCodes)) {
return offsetCodes[off>>7] + 14
}
return offsetCodes[off>>14] + 28
}
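// A small within-package sketch (illustrative only; the numbers are example
// values, not from the original source) of packing a match into a token and
// recovering its fields:
//
//	t := matchToken(7, 299) // xlength = 7, xoffset = 299
//	_ = t.length()          // 7
//	_ = t.offset()          // 299
//	_ = offsetCode(299)     // offset code used when Huffman-encoding the match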
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package gzip implements reading and writing of gzip format compressed files,
// as specified in RFC 1952.
package gzip
import (
"bufio"
"compress/flate"
"encoding/binary"
"errors"
"hash/crc32"
"io"
"time"
)
const (
gzipID1 = 0x1f
gzipID2 = 0x8b
gzipDeflate = 8
flagText = 1 << 0
flagHdrCrc = 1 << 1
flagExtra = 1 << 2
flagName = 1 << 3
flagComment = 1 << 4
)
var (
// ErrChecksum is returned when reading GZIP data that has an invalid checksum.
ErrChecksum = errors.New("gzip: invalid checksum")
// ErrHeader is returned when reading GZIP data that has an invalid header.
ErrHeader = errors.New("gzip: invalid header")
)
var le = binary.LittleEndian
// noEOF converts io.EOF to io.ErrUnexpectedEOF.
func noEOF(err error) error {
if err == io.EOF {
return io.ErrUnexpectedEOF
}
return err
}
// The gzip file stores a header giving metadata about the compressed file.
// That header is exposed as the fields of the [Writer] and [Reader] structs.
//
// Strings must be UTF-8 encoded and may only contain Unicode code points
// U+0001 through U+00FF, due to limitations of the GZIP file format.
type Header struct {
Comment string // comment
Extra []byte // "extra data"
ModTime time.Time // modification time
Name string // file name
OS byte // operating system type
}
// A Reader is an [io.Reader] that can be read to retrieve
// uncompressed data from a gzip-format compressed file.
//
// In general, a gzip file can be a concatenation of gzip files,
// each with its own header. Reads from the Reader
// return the concatenation of the uncompressed data of each.
// Only the first header is recorded in the Reader fields.
//
// Gzip files store a length and checksum of the uncompressed data.
// The Reader will return an [ErrChecksum] when [Reader.Read]
// reaches the end of the uncompressed data if it does not
// have the expected length or checksum. Clients should treat data
// returned by [Reader.Read] as tentative until they receive the [io.EOF]
// marking the end of the data.
type Reader struct {
Header // valid after NewReader or Reader.Reset
r flate.Reader
decompressor io.ReadCloser
digest uint32 // CRC-32, IEEE polynomial (section 8)
size uint32 // Uncompressed size (section 2.3.1)
buf [512]byte
err error
multistream bool
}
// NewReader creates a new [Reader] reading the given reader.
// If r does not also implement [io.ByteReader],
// the decompressor may read more data than necessary from r.
//
// It is the caller's responsibility to call [Reader.Close] when done.
//
// The [Reader.Header] fields will be valid in the [Reader] returned.
func NewReader(r io.Reader) (*Reader, error) {
z := new(Reader)
if err := z.Reset(r); err != nil {
return nil, err
}
return z, nil
}
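// An illustrative sketch (not part of the original source) of decompressing a
// gzip file with [NewReader]; the helper name and the "os" and "io" imports
// are assumptions. Reading until [io.EOF] also verifies the stored checksum
// and length.
//
//	func gunzipFile(path string) ([]byte, error) {
//		f, err := os.Open(path)
//		if err != nil {
//			return nil, err
//		}
//		defer f.Close()
//		zr, err := gzip.NewReader(f)
//		if err != nil {
//			return nil, err
//		}
//		defer zr.Close()
//		return io.ReadAll(zr)
//	}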
// Reset discards the [Reader] z's state and makes it equivalent to the
// result of its original state from [NewReader], but reading from r instead.
// This permits reusing a [Reader] rather than allocating a new one.
func (z *Reader) Reset(r io.Reader) error {
*z = Reader{
decompressor: z.decompressor,
multistream: true,
}
if rr, ok := r.(flate.Reader); ok {
z.r = rr
} else {
z.r = bufio.NewReader(r)
}
z.Header, z.err = z.readHeader()
return z.err
}
// Multistream controls whether the reader supports multistream files.
//
// If enabled (the default), the [Reader] expects the input to be a sequence
// of individually gzipped data streams, each with its own header and
// trailer, ending at EOF. The effect is that the concatenation of a sequence
// of gzipped files is treated as equivalent to the gzip of the concatenation
// of the sequence. This is standard behavior for gzip readers.
//
// Calling Multistream(false) disables this behavior; disabling the behavior
// can be useful when reading file formats that distinguish individual gzip
// data streams or mix gzip data streams with other data streams.
// In this mode, when the [Reader] reaches the end of the data stream,
// [Reader.Read] returns [io.EOF]. The underlying reader must implement [io.ByteReader]
// in order to be left positioned just after the gzip stream.
// To start the next stream, call z.Reset(r) followed by z.Multistream(false).
// If there is no next stream, z.Reset(r) will return [io.EOF].
func (z *Reader) Multistream(ok bool) {
z.multistream = ok
}
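// A sketch of the per-stream reading pattern described above (illustrative,
// not from the original source): br is a *bufio.Reader so it satisfies
// [io.ByteReader] and is left positioned at the start of the next stream.
// The helper name and the "bufio" and "io" imports are assumptions.
//
//	func readStreams(br *bufio.Reader, dst io.Writer) error {
//		zr, err := gzip.NewReader(br)
//		if err != nil {
//			return err
//		}
//		for {
//			zr.Multistream(false)
//			if _, err := io.Copy(dst, zr); err != nil {
//				return err
//			}
//			err = zr.Reset(br)
//			if err == io.EOF {
//				return nil // no further gzip stream follows
//			}
//			if err != nil {
//				return err
//			}
//		}
//	}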
// readString reads a NUL-terminated string from z.r.
// It treats the bytes read as being encoded as ISO 8859-1 (Latin-1) and
// will output a string encoded using UTF-8.
// This method always updates z.digest with the data read.
func (z *Reader) readString() (string, error) {
var err error
needConv := false
for i := 0; ; i++ {
if i >= len(z.buf) {
return "", ErrHeader
}
z.buf[i], err = z.r.ReadByte()
if err != nil {
return "", err
}
if z.buf[i] > 0x7f {
needConv = true
}
if z.buf[i] == 0 {
// Digest covers the NUL terminator.
z.digest = crc32.Update(z.digest, crc32.IEEETable, z.buf[:i+1])
// Strings are ISO 8859-1, Latin-1 (RFC 1952, section 2.3.1).
if needConv {
s := make([]rune, 0, i)
for _, v := range z.buf[:i] {
s = append(s, rune(v))
}
return string(s), nil
}
return string(z.buf[:i]), nil
}
}
}
// readHeader reads the GZIP header according to section 2.3.1.
// This method does not set z.err.
func (z *Reader) readHeader() (hdr Header, err error) {
if _, err = io.ReadFull(z.r, z.buf[:10]); err != nil {
// RFC 1952, section 2.2, says the following:
// A gzip file consists of a series of "members" (compressed data sets).
//
// Other than this, the specification does not clarify whether a
// "series" is defined as "one or more" or "zero or more". To err on the
// side of caution, Go interprets this to mean "zero or more".
// Thus, it is okay to return io.EOF here.
return hdr, err
}
if z.buf[0] != gzipID1 || z.buf[1] != gzipID2 || z.buf[2] != gzipDeflate {
return hdr, ErrHeader
}
flg := z.buf[3]
if t := int64(le.Uint32(z.buf[4:8])); t > 0 {
// Section 2.3.1, the zero value for MTIME means that the
// modified time is not set.
hdr.ModTime = time.Unix(t, 0)
}
// z.buf[8] is XFL and is currently ignored.
hdr.OS = z.buf[9]
z.digest = crc32.ChecksumIEEE(z.buf[:10])
if flg&flagExtra != 0 {
if _, err = io.ReadFull(z.r, z.buf[:2]); err != nil {
return hdr, noEOF(err)
}
z.digest = crc32.Update(z.digest, crc32.IEEETable, z.buf[:2])
data := make([]byte, le.Uint16(z.buf[:2]))
if _, err = io.ReadFull(z.r, data); err != nil {
return hdr, noEOF(err)
}
z.digest = crc32.Update(z.digest, crc32.IEEETable, data)
hdr.Extra = data
}
var s string
if flg&flagName != 0 {
if s, err = z.readString(); err != nil {
return hdr, noEOF(err)
}
hdr.Name = s
}
if flg&flagComment != 0 {
if s, err = z.readString(); err != nil {
return hdr, noEOF(err)
}
hdr.Comment = s
}
if flg&flagHdrCrc != 0 {
if _, err = io.ReadFull(z.r, z.buf[:2]); err != nil {
return hdr, noEOF(err)
}
digest := le.Uint16(z.buf[:2])
if digest != uint16(z.digest) {
return hdr, ErrHeader
}
}
z.digest = 0
if z.decompressor == nil {
z.decompressor = flate.NewReader(z.r)
} else {
z.decompressor.(flate.Resetter).Reset(z.r, nil)
}
return hdr, nil
}
// Read implements [io.Reader], reading uncompressed bytes from its underlying reader.
func (z *Reader) Read(p []byte) (n int, err error) {
if z.err != nil {
return 0, z.err
}
for n == 0 {
n, z.err = z.decompressor.Read(p)
z.digest = crc32.Update(z.digest, crc32.IEEETable, p[:n])
z.size += uint32(n)
if z.err != io.EOF {
// In the normal case we return here.
return n, z.err
}
// Finished file; check checksum and size.
if _, err := io.ReadFull(z.r, z.buf[:8]); err != nil {
z.err = noEOF(err)
return n, z.err
}
digest := le.Uint32(z.buf[:4])
size := le.Uint32(z.buf[4:8])
if digest != z.digest || size != z.size {
z.err = ErrChecksum
return n, z.err
}
z.digest, z.size = 0, 0
// File is ok; check if there is another.
if !z.multistream {
return n, io.EOF
}
z.err = nil // Remove io.EOF
if _, z.err = z.readHeader(); z.err != nil {
return n, z.err
}
}
return n, nil
}
// Close closes the [Reader]. It does not close the underlying reader.
// In order for the GZIP checksum to be verified, the reader must be
// fully consumed until the [io.EOF].
func (z *Reader) Close() error { return z.decompressor.Close() }
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package gzip
import (
"compress/flate"
"errors"
"fmt"
"hash/crc32"
"io"
"time"
)
// These constants are copied from the [flate] package, so that code that imports
// [compress/gzip] does not also have to import [compress/flate].
const (
NoCompression = flate.NoCompression
BestSpeed = flate.BestSpeed
BestCompression = flate.BestCompression
DefaultCompression = flate.DefaultCompression
HuffmanOnly = flate.HuffmanOnly
)
// A Writer is an [io.WriteCloser].
// Writes to a Writer are compressed and written to w.
type Writer struct {
Header // written at first call to Write, Flush, or Close
w io.Writer
level int
wroteHeader bool
closed bool
buf [10]byte
compressor *flate.Writer
digest uint32 // CRC-32, IEEE polynomial (section 8)
size uint32 // Uncompressed size (section 2.3.1)
err error
}
// NewWriter returns a new [Writer].
// Writes to the returned writer are compressed and written to w.
//
// It is the caller's responsibility to call Close on the [Writer] when done.
// Writes may be buffered and not flushed until Close.
//
// Callers that wish to set the fields in [Writer.Header] must do so before
// the first call to Write, Flush, or Close.
func NewWriter(w io.Writer) *Writer {
z, _ := NewWriterLevel(w, DefaultCompression)
return z
}
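// An illustrative sketch (not part of the original source) of compressing data
// and setting the optional [Writer] header fields before the first Write; the
// helper name and the "bytes" and "time" imports are assumptions.
//
//	func gzipWithName(name string, data []byte) ([]byte, error) {
//		var buf bytes.Buffer
//		zw := gzip.NewWriter(&buf)
//		zw.Name = name
//		zw.ModTime = time.Now()
//		if _, err := zw.Write(data); err != nil {
//			return nil, err
//		}
//		if err := zw.Close(); err != nil {
//			return nil, err
//		}
//		return buf.Bytes(), nil
//	}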
// NewWriterLevel is like [NewWriter] but specifies the compression level instead
// of assuming [DefaultCompression].
//
// The compression level can be [DefaultCompression], [NoCompression], [HuffmanOnly]
// or any integer value between [BestSpeed] and [BestCompression] inclusive.
// The error returned will be nil if the level is valid.
func NewWriterLevel(w io.Writer, level int) (*Writer, error) {
if level < HuffmanOnly || level > BestCompression {
return nil, fmt.Errorf("gzip: invalid compression level: %d", level)
}
z := new(Writer)
z.init(w, level)
return z, nil
}
func (z *Writer) init(w io.Writer, level int) {
compressor := z.compressor
if compressor != nil {
compressor.Reset(w)
}
*z = Writer{
Header: Header{
OS: 255, // unknown
},
w: w,
level: level,
compressor: compressor,
}
}
// Reset discards the [Writer] z's state and makes it equivalent to the
// result of its original state from [NewWriter] or [NewWriterLevel], but
// writing to w instead. This permits reusing a [Writer] rather than
// allocating a new one.
func (z *Writer) Reset(w io.Writer) {
z.init(w, z.level)
}
// writeBytes writes a length-prefixed byte slice to z.w.
func (z *Writer) writeBytes(b []byte) error {
if len(b) > 0xffff {
return errors.New("gzip.Write: Extra data is too large")
}
le.PutUint16(z.buf[:2], uint16(len(b)))
_, err := z.w.Write(z.buf[:2])
if err != nil {
return err
}
_, err = z.w.Write(b)
return err
}
// writeString writes a UTF-8 string s in GZIP's format to z.w.
// GZIP (RFC 1952) specifies that strings are NUL-terminated ISO 8859-1 (Latin-1).
func (z *Writer) writeString(s string) (err error) {
// GZIP stores Latin-1 strings; error if non-Latin-1; convert if non-ASCII.
needconv := false
for _, v := range s {
if v == 0 || v > 0xff {
return errors.New("gzip.Write: non-Latin-1 header string")
}
if v > 0x7f {
needconv = true
}
}
if needconv {
b := make([]byte, 0, len(s))
for _, v := range s {
b = append(b, byte(v))
}
_, err = z.w.Write(b)
} else {
_, err = io.WriteString(z.w, s)
}
if err != nil {
return err
}
// GZIP strings are NUL-terminated.
z.buf[0] = 0
_, err = z.w.Write(z.buf[:1])
return err
}
// Write writes a compressed form of p to the underlying [io.Writer]. The
// compressed bytes are not necessarily flushed until the [Writer] is closed.
func (z *Writer) Write(p []byte) (int, error) {
if z.err != nil {
return 0, z.err
}
var n int
// Write the GZIP header lazily.
if !z.wroteHeader {
z.wroteHeader = true
z.buf = [10]byte{0: gzipID1, 1: gzipID2, 2: gzipDeflate}
if z.Extra != nil {
z.buf[3] |= 0x04
}
if z.Name != "" {
z.buf[3] |= 0x08
}
if z.Comment != "" {
z.buf[3] |= 0x10
}
if z.ModTime.After(time.Unix(0, 0)) {
// Section 2.3.1, the zero value for MTIME means that the
// modified time is not set.
le.PutUint32(z.buf[4:8], uint32(z.ModTime.Unix()))
}
if z.level == BestCompression {
z.buf[8] = 2
} else if z.level == BestSpeed {
z.buf[8] = 4
}
z.buf[9] = z.OS
_, z.err = z.w.Write(z.buf[:10])
if z.err != nil {
return 0, z.err
}
if z.Extra != nil {
z.err = z.writeBytes(z.Extra)
if z.err != nil {
return 0, z.err
}
}
if z.Name != "" {
z.err = z.writeString(z.Name)
if z.err != nil {
return 0, z.err
}
}
if z.Comment != "" {
z.err = z.writeString(z.Comment)
if z.err != nil {
return 0, z.err
}
}
if z.compressor == nil {
z.compressor, _ = flate.NewWriter(z.w, z.level)
}
}
z.size += uint32(len(p))
z.digest = crc32.Update(z.digest, crc32.IEEETable, p)
n, z.err = z.compressor.Write(p)
return n, z.err
}
// Flush flushes any pending compressed data to the underlying writer.
//
// It is useful mainly in compressed network protocols, to ensure that
// a remote reader has enough data to reconstruct a packet. Flush does
// not return until the data has been written. If the underlying
// writer returns an error, Flush returns that error.
//
// In the terminology of the zlib library, Flush is equivalent to Z_SYNC_FLUSH.
func (z *Writer) Flush() error {
if z.err != nil {
return z.err
}
if z.closed {
return nil
}
if !z.wroteHeader {
z.Write(nil)
if z.err != nil {
return z.err
}
}
z.err = z.compressor.Flush()
return z.err
}
// Close closes the [Writer] by flushing any unwritten data to the underlying
// [io.Writer] and writing the GZIP footer.
// It does not close the underlying [io.Writer].
func (z *Writer) Close() error {
if z.err != nil {
return z.err
}
if z.closed {
return nil
}
z.closed = true
if !z.wroteHeader {
z.Write(nil)
if z.err != nil {
return z.err
}
}
z.err = z.compressor.Close()
if z.err != nil {
return z.err
}
le.PutUint32(z.buf[:4], z.digest)
le.PutUint32(z.buf[4:8], z.size)
_, z.err = z.w.Write(z.buf[:8])
return z.err
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package lzw implements the Lempel-Ziv-Welch compressed data format,
// described in T. A. Welch, “A Technique for High-Performance Data
// Compression”, Computer, 17(6) (June 1984), pp 8-19.
//
// In particular, it implements LZW as used by the GIF and PDF file
// formats, which means variable-width codes up to 12 bits and the first
// two non-literal codes are a clear code and an EOF code.
//
// The TIFF file format uses a similar but incompatible version of the LZW
// algorithm. See the [golang.org/x/image/tiff/lzw] package for an
// implementation.
package lzw
// TODO(nigeltao): check that PDF uses LZW in the same way as GIF,
// modulo LSB/MSB packing order.
import (
"bufio"
"errors"
"fmt"
"io"
)
// Order specifies the bit ordering in an LZW data stream.
type Order int
const (
// LSB means Least Significant Bits first, as used in the GIF file format.
LSB Order = iota
// MSB means Most Significant Bits first, as used in the TIFF and PDF
// file formats.
MSB
)
const (
maxWidth = 12
decoderInvalidCode = 0xffff
flushBuffer = 1 << maxWidth
)
// Reader is an [io.Reader] which can be used to read compressed data in the
// LZW format.
type Reader struct {
r io.ByteReader
bits uint32
nBits uint
width uint
read func(*Reader) (uint16, error) // readLSB or readMSB
litWidth int // width in bits of literal codes
err error
// The first 1<<litWidth codes are literal codes.
// The next two codes mean clear and EOF.
// Other valid codes are in the range [lo, hi] where lo := clear + 2,
// with the upper bound incrementing on each code seen.
//
// overflow is the code at which hi overflows the code width. It always
// equals 1 << width.
//
// last is the most recently seen code, or decoderInvalidCode.
//
// An invariant is that hi < overflow.
clear, eof, hi, overflow, last uint16
// Each code c in [lo, hi] expands to two or more bytes. For c != hi:
// suffix[c] is the last of these bytes.
// prefix[c] is the code for all but the last byte.
// This code can either be a literal code or another code in [lo, c).
// The c == hi case is a special case.
suffix [1 << maxWidth]uint8
prefix [1 << maxWidth]uint16
// output is the temporary output buffer.
// Literal codes are accumulated from the start of the buffer.
// Non-literal codes decode to a sequence of suffixes that are first
// written right-to-left from the end of the buffer before being copied
// to the start of the buffer.
// It is flushed when it contains >= 1<<maxWidth bytes,
// so that there is always room to decode an entire code.
output [2 * 1 << maxWidth]byte
o int // write index into output
toRead []byte // bytes to return from Read
}
// readLSB returns the next code for "Least Significant Bits first" data.
func (r *Reader) readLSB() (uint16, error) {
for r.nBits < r.width {
x, err := r.r.ReadByte()
if err != nil {
return 0, err
}
r.bits |= uint32(x) << r.nBits
r.nBits += 8
}
code := uint16(r.bits & (1<<r.width - 1))
r.bits >>= r.width
r.nBits -= r.width
return code, nil
}
// readMSB returns the next code for "Most Significant Bits first" data.
func (r *Reader) readMSB() (uint16, error) {
for r.nBits < r.width {
x, err := r.r.ReadByte()
if err != nil {
return 0, err
}
r.bits |= uint32(x) << (24 - r.nBits)
r.nBits += 8
}
code := uint16(r.bits >> (32 - r.width))
r.bits <<= r.width
r.nBits -= r.width
return code, nil
}
// Read implements io.Reader, reading uncompressed bytes from its underlying reader.
func (r *Reader) Read(b []byte) (int, error) {
for {
if len(r.toRead) > 0 {
n := copy(b, r.toRead)
r.toRead = r.toRead[n:]
return n, nil
}
if r.err != nil {
return 0, r.err
}
r.decode()
}
}
// decode decompresses bytes from r.r and leaves them in r.toRead.
// r.read specifies how to decode bytes into codes.
// r.litWidth is the width in bits of literal codes.
func (r *Reader) decode() {
// Loop over the code stream, converting codes into decompressed bytes.
loop:
for {
code, err := r.read(r)
if err != nil {
if err == io.EOF {
err = io.ErrUnexpectedEOF
}
r.err = err
break
}
switch {
case code < r.clear:
// We have a literal code.
r.output[r.o] = uint8(code)
r.o++
if r.last != decoderInvalidCode {
// Save what the hi code expands to.
r.suffix[r.hi] = uint8(code)
r.prefix[r.hi] = r.last
}
case code == r.clear:
r.width = 1 + uint(r.litWidth)
r.hi = r.eof
r.overflow = 1 << r.width
r.last = decoderInvalidCode
continue
case code == r.eof:
r.err = io.EOF
break loop
case code <= r.hi:
c, i := code, len(r.output)-1
if code == r.hi && r.last != decoderInvalidCode {
// code == hi is a special case which expands to the last expansion
// followed by the head of the last expansion. To find the head, we walk
// the prefix chain until we find a literal code.
c = r.last
for c >= r.clear {
c = r.prefix[c]
}
r.output[i] = uint8(c)
i--
c = r.last
}
// Copy the suffix chain into output and then write that to w.
for c >= r.clear {
r.output[i] = r.suffix[c]
i--
c = r.prefix[c]
}
r.output[i] = uint8(c)
r.o += copy(r.output[r.o:], r.output[i:])
if r.last != decoderInvalidCode {
// Save what the hi code expands to.
r.suffix[r.hi] = uint8(c)
r.prefix[r.hi] = r.last
}
default:
r.err = errors.New("lzw: invalid code")
break loop
}
r.last, r.hi = code, r.hi+1
if r.hi >= r.overflow {
if r.hi > r.overflow {
panic("unreachable")
}
if r.width == maxWidth {
r.last = decoderInvalidCode
// Undo the r.hi++ a few lines above, so that (1) we maintain
// the invariant that r.hi < r.overflow, and (2) r.hi does not
// eventually overflow a uint16.
r.hi--
} else {
r.width++
r.overflow = 1 << r.width
}
}
if r.o >= flushBuffer {
break
}
}
// Flush pending output.
r.toRead = r.output[:r.o]
r.o = 0
}
var errClosed = errors.New("lzw: reader/writer is closed")
// Close closes the [Reader] and returns an error for any future read operation.
// It does not close the underlying [io.Reader].
func (r *Reader) Close() error {
r.err = errClosed // in case any Reads come along
return nil
}
// Reset clears the [Reader]'s state and allows it to be reused again
// as a new [Reader].
func (r *Reader) Reset(src io.Reader, order Order, litWidth int) {
*r = Reader{}
r.init(src, order, litWidth)
}
// NewReader creates a new [io.ReadCloser].
// Reads from the returned [io.ReadCloser] read and decompress data from r.
// If r does not also implement [io.ByteReader],
// the decompressor may read more data than necessary from r.
// It is the caller's responsibility to call Close on the ReadCloser when
// finished reading.
// The number of bits to use for literal codes, litWidth, must be in the
// range [2,8] and is typically 8. It must equal the litWidth
// used during compression.
//
// It is guaranteed that the underlying type of the returned [io.ReadCloser]
// is a *[Reader].
func NewReader(r io.Reader, order Order, litWidth int) io.ReadCloser {
return newReader(r, order, litWidth)
}
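// A short decoding sketch (illustrative, not from the original source): the
// order and litWidth passed to [NewReader] must match those used by the
// encoder. Here compressed is assumed to hold GIF-style LSB-first data with
// 8-bit literals; the "bytes" and "io" imports are assumptions.
//
//	zr := lzw.NewReader(bytes.NewReader(compressed), lzw.LSB, 8)
//	defer zr.Close()
//	decompressed, err := io.ReadAll(zr)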
func newReader(src io.Reader, order Order, litWidth int) *Reader {
r := new(Reader)
r.init(src, order, litWidth)
return r
}
func (r *Reader) init(src io.Reader, order Order, litWidth int) {
switch order {
case LSB:
r.read = (*Reader).readLSB
case MSB:
r.read = (*Reader).readMSB
default:
r.err = errors.New("lzw: unknown order")
return
}
if litWidth < 2 || 8 < litWidth {
r.err = fmt.Errorf("lzw: litWidth %d out of range", litWidth)
return
}
br, ok := src.(io.ByteReader)
if !ok && src != nil {
br = bufio.NewReader(src)
}
r.r = br
r.litWidth = litWidth
r.width = 1 + uint(litWidth)
r.clear = uint16(1) << uint(litWidth)
r.eof, r.hi = r.clear+1, r.clear+1
r.overflow = uint16(1) << r.width
r.last = decoderInvalidCode
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package lzw
import (
"bufio"
"errors"
"fmt"
"io"
)
// A writer is a buffered, flushable writer.
type writer interface {
io.ByteWriter
Flush() error
}
const (
// A code is a 12 bit value, stored as a uint32 when encoding to avoid
// type conversions when shifting bits.
maxCode = 1<<12 - 1
invalidCode = 1<<32 - 1
// There are 1<<12 possible codes, which is an upper bound on the number of
// valid hash table entries at any given point in time. tableSize is 4x that.
tableSize = 4 * 1 << 12
tableMask = tableSize - 1
// A hash table entry is a uint32. Zero is an invalid entry since the
// lower 12 bits of a valid entry must be a non-literal code.
invalidEntry = 0
)
// Writer is an LZW compressor. It writes the compressed form of the data
// to an underlying writer (see [NewWriter]).
type Writer struct {
// w is the writer that compressed bytes are written to.
w writer
// litWidth is the width in bits of literal codes.
litWidth uint
// order, write, bits, nBits and width are the state for
// converting a code stream into a byte stream.
order Order
write func(*Writer, uint32) error
nBits uint
width uint
bits uint32
// hi is the code implied by the next code emission.
// overflow is the code at which hi overflows the code width.
hi, overflow uint32
// savedCode is the accumulated code at the end of the most recent Write
// call. It is equal to invalidCode if there was no such call.
savedCode uint32
// err is the first error encountered during writing. Closing the writer
// will make any future Write calls return errClosed.
err error
// table is the hash table from 20-bit keys to 12-bit values. Each table
// entry contains key<<12|val and collisions resolve by linear probing.
// The keys consist of a 12-bit code prefix and an 8-bit byte suffix.
// The values are a 12-bit code.
table [tableSize]uint32
}
// writeLSB writes the code c for "Least Significant Bits first" data.
func (w *Writer) writeLSB(c uint32) error {
w.bits |= c << w.nBits
w.nBits += w.width
for w.nBits >= 8 {
if err := w.w.WriteByte(uint8(w.bits)); err != nil {
return err
}
w.bits >>= 8
w.nBits -= 8
}
return nil
}
// writeMSB writes the code c for "Most Significant Bits first" data.
func (w *Writer) writeMSB(c uint32) error {
w.bits |= c << (32 - w.width - w.nBits)
w.nBits += w.width
for w.nBits >= 8 {
if err := w.w.WriteByte(uint8(w.bits >> 24)); err != nil {
return err
}
w.bits <<= 8
w.nBits -= 8
}
return nil
}
// errOutOfCodes is an internal error that means that the writer has run out
// of unused codes and a clear code needs to be sent next.
var errOutOfCodes = errors.New("lzw: out of codes")
// incHi increments w.hi and checks for both overflow and running out of
// unused codes. In the latter case, incHi sends a clear code, resets the
// writer state and returns errOutOfCodes.
func (w *Writer) incHi() error {
w.hi++
if w.hi == w.overflow {
w.width++
w.overflow <<= 1
}
if w.hi == maxCode {
clear := uint32(1) << w.litWidth
if err := w.write(w, clear); err != nil {
return err
}
w.width = w.litWidth + 1
w.hi = clear + 1
w.overflow = clear << 1
for i := range w.table {
w.table[i] = invalidEntry
}
return errOutOfCodes
}
return nil
}
// Write writes a compressed representation of p to w's underlying writer.
func (w *Writer) Write(p []byte) (n int, err error) {
if w.err != nil {
return 0, w.err
}
if len(p) == 0 {
return 0, nil
}
if maxLit := uint8(1<<w.litWidth - 1); maxLit != 0xff {
for _, x := range p {
if x > maxLit {
w.err = errors.New("lzw: input byte too large for the litWidth")
return 0, w.err
}
}
}
n = len(p)
code := w.savedCode
if code == invalidCode {
// This is the first write; send a clear code.
// https://www.w3.org/Graphics/GIF/spec-gif89a.txt Appendix F
// "Variable-Length-Code LZW Compression" says that "Encoders should
// output a Clear code as the first code of each image data stream".
//
// LZW compression isn't only used by GIF, but it's cheap to follow
// that directive unconditionally.
clear := uint32(1) << w.litWidth
if err := w.write(w, clear); err != nil {
return 0, err
}
// After the starting clear code, the next code sent (for non-empty
// input) is always a literal code.
code, p = uint32(p[0]), p[1:]
}
loop:
for _, x := range p {
literal := uint32(x)
key := code<<8 | literal
// If there is a hash table hit for this key then we continue the loop
// and do not emit a code yet.
hash := (key>>12 ^ key) & tableMask
for h, t := hash, w.table[hash]; t != invalidEntry; {
if key == t>>12 {
code = t & maxCode
continue loop
}
h = (h + 1) & tableMask
t = w.table[h]
}
// Otherwise, write the current code, and literal becomes the start of
// the next emitted code.
if w.err = w.write(w, code); w.err != nil {
return 0, w.err
}
code = literal
// Increment w.hi, the next implied code. If we run out of codes, reset
// the writer state (including clearing the hash table) and continue.
if err1 := w.incHi(); err1 != nil {
if err1 == errOutOfCodes {
continue
}
w.err = err1
return 0, w.err
}
// Otherwise, insert key -> w.hi into the map that w.table represents.
for {
if w.table[hash] == invalidEntry {
w.table[hash] = (key << 12) | w.hi
break
}
hash = (hash + 1) & tableMask
}
}
w.savedCode = code
return n, nil
}
// Close closes the [Writer], flushing any pending output. It does not close
// w's underlying writer.
func (w *Writer) Close() error {
if w.err != nil {
if w.err == errClosed {
return nil
}
return w.err
}
// Make any future calls to Write return errClosed.
w.err = errClosed
// Write the savedCode if valid.
if w.savedCode != invalidCode {
if err := w.write(w, w.savedCode); err != nil {
return err
}
if err := w.incHi(); err != nil && err != errOutOfCodes {
return err
}
} else {
// Write the starting clear code, as w.Write did not.
clear := uint32(1) << w.litWidth
if err := w.write(w, clear); err != nil {
return err
}
}
// Write the eof code.
eof := uint32(1)<<w.litWidth + 1
if err := w.write(w, eof); err != nil {
return err
}
// Write the final bits.
if w.nBits > 0 {
if w.order == MSB {
w.bits >>= 24
}
if err := w.w.WriteByte(uint8(w.bits)); err != nil {
return err
}
}
return w.w.Flush()
}
// Reset clears the [Writer]'s state and allows it to be reused again
// as a new [Writer].
func (w *Writer) Reset(dst io.Writer, order Order, litWidth int) {
*w = Writer{}
w.init(dst, order, litWidth)
}
// NewWriter creates a new [io.WriteCloser].
// Writes to the returned [io.WriteCloser] are compressed and written to w.
// It is the caller's responsibility to call Close on the WriteCloser when
// finished writing.
// The number of bits to use for literal codes, litWidth, must be in the
// range [2,8] and is typically 8. Input bytes must be less than 1<<litWidth.
//
// It is guaranteed that the underlying type of the returned [io.WriteCloser]
// is a *[Writer].
func NewWriter(w io.Writer, order Order, litWidth int) io.WriteCloser {
return newWriter(w, order, litWidth)
}
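// A minimal round-trip sketch (illustrative, not from the original source)
// pairing NewWriter with [NewReader]; order and litWidth must match on both
// sides. The helper name and the "bytes" and "io" imports are assumptions.
//
//	func lzwRoundTrip(data []byte) ([]byte, error) {
//		var buf bytes.Buffer
//		zw := lzw.NewWriter(&buf, lzw.LSB, 8)
//		if _, err := zw.Write(data); err != nil {
//			return nil, err
//		}
//		if err := zw.Close(); err != nil {
//			return nil, err
//		}
//		zr := lzw.NewReader(&buf, lzw.LSB, 8)
//		defer zr.Close()
//		return io.ReadAll(zr)
//	}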
func newWriter(dst io.Writer, order Order, litWidth int) *Writer {
w := new(Writer)
w.init(dst, order, litWidth)
return w
}
func (w *Writer) init(dst io.Writer, order Order, litWidth int) {
switch order {
case LSB:
w.write = (*Writer).writeLSB
case MSB:
w.write = (*Writer).writeMSB
default:
w.err = errors.New("lzw: unknown order")
return
}
if litWidth < 2 || 8 < litWidth {
w.err = fmt.Errorf("lzw: litWidth %d out of range", litWidth)
return
}
bw, ok := dst.(writer)
if !ok && dst != nil {
bw = bufio.NewWriter(dst)
}
w.w = bw
lw := uint(litWidth)
w.order = order
w.width = 1 + lw
w.litWidth = lw
w.hi = 1<<lw + 1
w.overflow = 1 << (lw + 1)
w.savedCode = invalidCode
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
/*
Package zlib implements reading and writing of zlib format compressed data,
as specified in RFC 1950.
The implementation provides filters that uncompress during reading
and compress during writing. For example, to write compressed data
to a buffer:
	var b bytes.Buffer
	w := zlib.NewWriter(&b)
	w.Write([]byte("hello, world\n"))
	w.Close()
and to read that data back:
	r, err := zlib.NewReader(&b)
	io.Copy(os.Stdout, r)
	r.Close()
*/
package zlib
import (
"bufio"
"compress/flate"
"encoding/binary"
"errors"
"hash"
"hash/adler32"
"io"
)
const (
zlibDeflate = 8
zlibMaxWindow = 7
)
var (
// ErrChecksum is returned when reading ZLIB data that has an invalid checksum.
ErrChecksum = errors.New("zlib: invalid checksum")
// ErrDictionary is returned when reading ZLIB data that has an invalid dictionary.
ErrDictionary = errors.New("zlib: invalid dictionary")
// ErrHeader is returned when reading ZLIB data that has an invalid header.
ErrHeader = errors.New("zlib: invalid header")
)
type reader struct {
r flate.Reader
decompressor io.ReadCloser
digest hash.Hash32
err error
scratch [4]byte
}
// Resetter resets a ReadCloser returned by [NewReader] or [NewReaderDict]
// to switch to a new underlying Reader. This permits reusing a ReadCloser
// instead of allocating a new one.
type Resetter interface {
// Reset discards any buffered data and resets the Resetter as if it was
// newly initialized with the given reader.
Reset(r io.Reader, dict []byte) error
}
// NewReader creates a new ReadCloser.
// Reads from the returned ReadCloser read and decompress data from r.
// If r does not implement [io.ByteReader], the decompressor may read more
// data than necessary from r.
// It is the caller's responsibility to call Close on the ReadCloser when done.
//
// The [io.ReadCloser] returned by NewReader also implements [Resetter].
func NewReader(r io.Reader) (io.ReadCloser, error) {
return NewReaderDict(r, nil)
}
// NewReaderDict is like [NewReader] but uses a preset dictionary.
// NewReaderDict ignores the dictionary if the compressed data does not refer to it.
// If the compressed data refers to a different dictionary, NewReaderDict returns [ErrDictionary].
//
// The ReadCloser returned by NewReaderDict also implements [Resetter].
func NewReaderDict(r io.Reader, dict []byte) (io.ReadCloser, error) {
z := new(reader)
err := z.Reset(r, dict)
if err != nil {
return nil, err
}
return z, nil
}
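// An illustrative sketch (not part of the original source) pairing
// [NewWriterLevelDict] with NewReaderDict; the same preset dictionary must be
// supplied on both sides. The helper name and the "bytes" and "io" imports
// are assumptions.
//
//	func zlibRoundTripWithDict(dict, data []byte) ([]byte, error) {
//		var buf bytes.Buffer
//		zw, err := zlib.NewWriterLevelDict(&buf, zlib.DefaultCompression, dict)
//		if err != nil {
//			return nil, err
//		}
//		if _, err := zw.Write(data); err != nil {
//			return nil, err
//		}
//		if err := zw.Close(); err != nil {
//			return nil, err
//		}
//		zr, err := zlib.NewReaderDict(&buf, dict)
//		if err != nil {
//			return nil, err
//		}
//		defer zr.Close()
//		return io.ReadAll(zr)
//	}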
func (z *reader) Read(p []byte) (int, error) {
if z.err != nil {
return 0, z.err
}
var n int
n, z.err = z.decompressor.Read(p)
z.digest.Write(p[0:n])
if z.err != io.EOF {
// In the normal case we return here.
return n, z.err
}
// Finished file; check checksum.
if _, err := io.ReadFull(z.r, z.scratch[0:4]); err != nil {
if err == io.EOF {
err = io.ErrUnexpectedEOF
}
z.err = err
return n, z.err
}
// ZLIB (RFC 1950) is big-endian, unlike GZIP (RFC 1952).
checksum := binary.BigEndian.Uint32(z.scratch[:4])
if checksum != z.digest.Sum32() {
z.err = ErrChecksum
return n, z.err
}
return n, io.EOF
}
// Calling Close does not close the wrapped [io.Reader] originally passed to [NewReader].
// In order for the ZLIB checksum to be verified, the reader must be
// fully consumed until the [io.EOF].
func (z *reader) Close() error {
if z.err != nil && z.err != io.EOF {
return z.err
}
z.err = z.decompressor.Close()
return z.err
}
func (z *reader) Reset(r io.Reader, dict []byte) error {
*z = reader{decompressor: z.decompressor}
if fr, ok := r.(flate.Reader); ok {
z.r = fr
} else {
z.r = bufio.NewReader(r)
}
// Read the header (RFC 1950 section 2.2.).
_, z.err = io.ReadFull(z.r, z.scratch[0:2])
if z.err != nil {
if z.err == io.EOF {
z.err = io.ErrUnexpectedEOF
}
return z.err
}
h := binary.BigEndian.Uint16(z.scratch[:2])
if (z.scratch[0]&0x0f != zlibDeflate) || (z.scratch[0]>>4 > zlibMaxWindow) || (h%31 != 0) {
z.err = ErrHeader
return z.err
}
haveDict := z.scratch[1]&0x20 != 0
if haveDict {
_, z.err = io.ReadFull(z.r, z.scratch[0:4])
if z.err != nil {
if z.err == io.EOF {
z.err = io.ErrUnexpectedEOF
}
return z.err
}
checksum := binary.BigEndian.Uint32(z.scratch[:4])
if checksum != adler32.Checksum(dict) {
z.err = ErrDictionary
return z.err
}
}
if z.decompressor == nil {
if haveDict {
z.decompressor = flate.NewReaderDict(z.r, dict)
} else {
z.decompressor = flate.NewReader(z.r)
}
} else {
z.decompressor.(flate.Resetter).Reset(z.r, dict)
}
z.digest = adler32.New()
return nil
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package zlib
import (
"compress/flate"
"encoding/binary"
"fmt"
"hash"
"hash/adler32"
"io"
)
// These constants are copied from the [flate] package, so that code that imports
// [compress/zlib] does not also have to import [compress/flate].
const (
NoCompression = flate.NoCompression
BestSpeed = flate.BestSpeed
BestCompression = flate.BestCompression
DefaultCompression = flate.DefaultCompression
HuffmanOnly = flate.HuffmanOnly
)
// A Writer takes data written to it and writes the compressed
// form of that data to an underlying writer (see [NewWriter]).
type Writer struct {
w io.Writer
level int
dict []byte
compressor *flate.Writer
digest hash.Hash32
err error
scratch [4]byte
wroteHeader bool
}
// NewWriter creates a new [Writer].
// Writes to the returned Writer are compressed and written to w.
//
// It is the caller's responsibility to call Close on the Writer when done.
// Writes may be buffered and not flushed until Close.
func NewWriter(w io.Writer) *Writer {
z, _ := NewWriterLevelDict(w, DefaultCompression, nil)
return z
}
// NewWriterLevel is like [NewWriter] but specifies the compression level instead
// of assuming [DefaultCompression].
//
// The compression level can be [DefaultCompression], [NoCompression], [HuffmanOnly]
// or any integer value between [BestSpeed] and [BestCompression] inclusive.
// The error returned will be nil if the level is valid.
func NewWriterLevel(w io.Writer, level int) (*Writer, error) {
return NewWriterLevelDict(w, level, nil)
}
// NewWriterLevelDict is like [NewWriterLevel] but specifies a dictionary to
// compress with.
//
// The dictionary may be nil. If not, its contents should not be modified until
// the Writer is closed.
func NewWriterLevelDict(w io.Writer, level int, dict []byte) (*Writer, error) {
if level < HuffmanOnly || level > BestCompression {
return nil, fmt.Errorf("zlib: invalid compression level: %d", level)
}
return &Writer{
w: w,
level: level,
dict: dict,
}, nil
}
// Reset clears the state of the [Writer] z such that it is equivalent to its
// initial state from [NewWriterLevel] or [NewWriterLevelDict], but instead writing
// to w.
func (z *Writer) Reset(w io.Writer) {
z.w = w
// z.level and z.dict left unchanged.
if z.compressor != nil {
z.compressor.Reset(w)
}
if z.digest != nil {
z.digest.Reset()
}
z.err = nil
z.scratch = [4]byte{}
z.wroteHeader = false
}
// writeHeader writes the ZLIB header.
func (z *Writer) writeHeader() (err error) {
z.wroteHeader = true
// ZLIB has a two-byte header (as documented in RFC 1950).
// The high four bits of the first byte are CINFO (compression info), which is 7 for the default deflate window size.
// The low four bits are CM (compression method), which is 8 for deflate.
z.scratch[0] = 0x78
// The top two bits of the second byte are FLEVEL (compression level). The four values are:
// 0=fastest, 1=fast, 2=default, 3=best.
// The next bit, FDICT, is set if a dictionary is given.
// The final five FCHECK bits form a mod-31 checksum over both header bytes.
switch z.level {
case -2, 0, 1:
z.scratch[1] = 0 << 6
case 2, 3, 4, 5:
z.scratch[1] = 1 << 6
case 6, -1:
z.scratch[1] = 2 << 6
case 7, 8, 9:
z.scratch[1] = 3 << 6
default:
panic("unreachable")
}
if z.dict != nil {
z.scratch[1] |= 1 << 5
}
z.scratch[1] += uint8(31 - binary.BigEndian.Uint16(z.scratch[:2])%31)
if _, err = z.w.Write(z.scratch[0:2]); err != nil {
return err
}
if z.dict != nil {
// The next four bytes are the Adler-32 checksum of the dictionary.
binary.BigEndian.PutUint32(z.scratch[:], adler32.Checksum(z.dict))
if _, err = z.w.Write(z.scratch[0:4]); err != nil {
return err
}
}
if z.compressor == nil {
// Initialize deflater unless the Writer is being reused
// after a Reset call.
z.compressor, err = flate.NewWriterDict(z.w, z.level, z.dict)
if err != nil {
return err
}
z.digest = adler32.New()
}
return nil
}
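// A small worked check (illustrative only) of the FCHECK rule above: FCHECK is
// chosen so that the 16-bit big-endian value formed by the two header bytes is
// a multiple of 31. For the default level with no dictionary this yields the
// familiar 0x78 0x9c header:
//
//	cmf := uint16(0x78)   // CINFO=7, CM=8
//	flg := uint16(2 << 6) // FLEVEL=2 (default), FDICT=0
//	flg += 31 - (cmf<<8|flg)%31
//	// cmf, flg == 0x78, 0x9c and (cmf<<8|flg)%31 == 0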
// Write writes a compressed form of p to the underlying [io.Writer]. The
// compressed bytes are not necessarily flushed until the [Writer] is closed or
// explicitly flushed.
func (z *Writer) Write(p []byte) (n int, err error) {
if !z.wroteHeader {
z.err = z.writeHeader()
}
if z.err != nil {
return 0, z.err
}
if len(p) == 0 {
return 0, nil
}
n, err = z.compressor.Write(p)
if err != nil {
z.err = err
return
}
z.digest.Write(p)
return
}
// Flush flushes the Writer to its underlying [io.Writer].
func (z *Writer) Flush() error {
if !z.wroteHeader {
z.err = z.writeHeader()
}
if z.err != nil {
return z.err
}
z.err = z.compressor.Flush()
return z.err
}
// Close closes the Writer, flushing any unwritten data to the underlying
// [io.Writer], but does not close the underlying io.Writer.
func (z *Writer) Close() error {
if !z.wroteHeader {
z.err = z.writeHeader()
}
if z.err != nil {
return z.err
}
z.err = z.compressor.Close()
if z.err != nil {
return z.err
}
checksum := z.digest.Sum32()
// ZLIB (RFC 1950) is big-endian, unlike GZIP (RFC 1952).
binary.BigEndian.PutUint32(z.scratch[:], checksum)
_, z.err = z.w.Write(z.scratch[0:4])
return z.err
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package list implements a doubly linked list.
//
// To iterate over a list (where l is a *List):
//
//	for e := l.Front(); e != nil; e = e.Next() {
//		// do something with e.Value
//	}
package list
// Element is an element of a linked list.
type Element struct {
// Next and previous pointers in the doubly-linked list of elements.
// To simplify the implementation, internally a list l is implemented
// as a ring, such that &l.root is both the next element of the last
// list element (l.Back()) and the previous element of the first list
// element (l.Front()).
next, prev *Element
// The list to which this element belongs.
list *List
// The value stored with this element.
Value any
}
// Next returns the next list element or nil.
func (e *Element) Next() *Element {
if p := e.next; e.list != nil && p != &e.list.root {
return p
}
return nil
}
// Prev returns the previous list element or nil.
func (e *Element) Prev() *Element {
if p := e.prev; e.list != nil && p != &e.list.root {
return p
}
return nil
}
// List represents a doubly linked list.
// The zero value for List is an empty list ready to use.
type List struct {
root Element // sentinel list element, only &root, root.prev, and root.next are used
len int // current list length excluding (this) sentinel element
}
// Init initializes or clears list l.
func (l *List) Init() *List {
l.root.next = &l.root
l.root.prev = &l.root
l.len = 0
return l
}
// New returns an initialized list.
func New() *List { return new(List).Init() }
// Len returns the number of elements of list l.
// The complexity is O(1).
func (l *List) Len() int { return l.len }
// Front returns the first element of list l or nil if the list is empty.
func (l *List) Front() *Element {
if l.len == 0 {
return nil
}
return l.root.next
}
// Back returns the last element of list l or nil if the list is empty.
func (l *List) Back() *Element {
if l.len == 0 {
return nil
}
return l.root.prev
}
// lazyInit lazily initializes a zero List value.
func (l *List) lazyInit() {
if l.root.next == nil {
l.Init()
}
}
// insert inserts e after at, increments l.len, and returns e.
func (l *List) insert(e, at *Element) *Element {
e.prev = at
e.next = at.next
e.prev.next = e
e.next.prev = e
e.list = l
l.len++
return e
}
// insertValue is a convenience wrapper for insert(&Element{Value: v}, at).
func (l *List) insertValue(v any, at *Element) *Element {
return l.insert(&Element{Value: v}, at)
}
// remove removes e from its list and decrements l.len.
func (l *List) remove(e *Element) {
e.prev.next = e.next
e.next.prev = e.prev
e.next = nil // avoid memory leaks
e.prev = nil // avoid memory leaks
e.list = nil
l.len--
}
// move moves e to be the element immediately after at.
func (l *List) move(e, at *Element) {
if e == at {
return
}
e.prev.next = e.next
e.next.prev = e.prev
e.prev = at
e.next = at.next
e.prev.next = e
e.next.prev = e
}
// Remove removes e from l if e is an element of list l.
// It returns the element value e.Value.
// The element must not be nil.
func (l *List) Remove(e *Element) any {
if e.list == l {
// if e.list == l, l must have been initialized when e was inserted
// in l or l == nil (e is a zero Element) and l.remove will crash
l.remove(e)
}
return e.Value
}
// PushFront inserts a new element e with value v at the front of list l and returns e.
func (l *List) PushFront(v any) *Element {
l.lazyInit()
return l.insertValue(v, &l.root)
}
// PushBack inserts a new element e with value v at the back of list l and returns e.
func (l *List) PushBack(v any) *Element {
l.lazyInit()
return l.insertValue(v, l.root.prev)
}
// InsertBefore inserts a new element e with value v immediately before mark and returns e.
// If mark is not an element of l, the list is not modified.
// The mark must not be nil.
func (l *List) InsertBefore(v any, mark *Element) *Element {
if mark.list != l {
return nil
}
// see comment in List.Remove about initialization of l
return l.insertValue(v, mark.prev)
}
// InsertAfter inserts a new element e with value v immediately after mark and returns e.
// If mark is not an element of l, the list is not modified.
// The mark must not be nil.
func (l *List) InsertAfter(v any, mark *Element) *Element {
if mark.list != l {
return nil
}
// see comment in List.Remove about initialization of l
return l.insertValue(v, mark)
}
// MoveToFront moves element e to the front of list l.
// If e is not an element of l, the list is not modified.
// The element must not be nil.
func (l *List) MoveToFront(e *Element) {
if e.list != l || l.root.next == e {
return
}
// see comment in List.Remove about initialization of l
l.move(e, &l.root)
}
// MoveToBack moves element e to the back of list l.
// If e is not an element of l, the list is not modified.
// The element must not be nil.
func (l *List) MoveToBack(e *Element) {
if e.list != l || l.root.prev == e {
return
}
// see comment in List.Remove about initialization of l
l.move(e, l.root.prev)
}
// MoveBefore moves element e to its new position before mark.
// If e or mark is not an element of l, or e == mark, the list is not modified.
// The element and mark must not be nil.
func (l *List) MoveBefore(e, mark *Element) {
if e.list != l || e == mark || mark.list != l {
return
}
l.move(e, mark.prev)
}
// MoveAfter moves element e to its new position after mark.
// If e or mark is not an element of l, or e == mark, the list is not modified.
// The element and mark must not be nil.
func (l *List) MoveAfter(e, mark *Element) {
if e.list != l || e == mark || mark.list != l {
return
}
l.move(e, mark)
}
// PushBackList inserts a copy of another list at the back of list l.
// The lists l and other may be the same. They must not be nil.
func (l *List) PushBackList(other *List) {
l.lazyInit()
for i, e := other.Len(), other.Front(); i > 0; i, e = i-1, e.Next() {
l.insertValue(e.Value, l.root.prev)
}
}
// PushFrontList inserts a copy of another list at the front of list l.
// The lists l and other may be the same. They must not be nil.
func (l *List) PushFrontList(other *List) {
l.lazyInit()
for i, e := other.Len(), other.Back(); i > 0; i, e = i-1, e.Prev() {
l.insertValue(e.Value, &l.root)
}
}
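// A minimal usage sketch from a caller's perspective; it assumes this package is
// imported as "container/list" alongside "fmt", and the function name is illustrative.
func listUsageSketch() {
	l := list.New()
	e4 := l.PushBack(4)
	e1 := l.PushFront(1)
	l.InsertBefore(3, e4)
	l.InsertAfter(2, e1)
	// Walk the list front to back; prints 1, 2, 3, 4.
	for e := l.Front(); e != nil; e = e.Next() {
		fmt.Println(e.Value)
	}
}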
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package ring implements operations on circular lists.
package ring
// A Ring is an element of a circular list, or ring.
// Rings do not have a beginning or end; a pointer to any ring element
// serves as reference to the entire ring. Empty rings are represented
// as nil Ring pointers. The zero value for a Ring is a one-element
// ring with a nil Value.
type Ring struct {
next, prev *Ring
Value any // for use by client; untouched by this library
}
func (r *Ring) init() *Ring {
r.next = r
r.prev = r
return r
}
// Next returns the next ring element. r must not be empty.
func (r *Ring) Next() *Ring {
if r.next == nil {
return r.init()
}
return r.next
}
// Prev returns the previous ring element. r must not be empty.
func (r *Ring) Prev() *Ring {
if r.next == nil {
return r.init()
}
return r.prev
}
// Move moves n % r.Len() elements backward (n < 0) or forward (n >= 0)
// in the ring and returns that ring element. r must not be empty.
func (r *Ring) Move(n int) *Ring {
if r.next == nil {
return r.init()
}
switch {
case n < 0:
for ; n < 0; n++ {
r = r.prev
}
case n > 0:
for ; n > 0; n-- {
r = r.next
}
}
return r
}
// New creates a ring of n elements.
func New(n int) *Ring {
if n <= 0 {
return nil
}
r := new(Ring)
p := r
for i := 1; i < n; i++ {
p.next = &Ring{prev: p}
p = p.next
}
p.next = r
r.prev = p
return r
}
// Link connects ring r with ring s such that r.Next()
// becomes s and returns the original value for r.Next().
// r must not be empty.
//
// If r and s point to the same ring, linking
// them removes the elements between r and s from the ring.
// The removed elements form a subring and the result is a
// reference to that subring (if no elements were removed,
// the result is still the original value for r.Next(),
// and not nil).
//
// If r and s point to different rings, linking
// them creates a single ring with the elements of s inserted
// after r. The result points to the element following the
// last element of s after insertion.
func (r *Ring) Link(s *Ring) *Ring {
n := r.Next()
if s != nil {
p := s.Prev()
// Note: Cannot use multiple assignment because
// evaluation order of LHS is not specified.
r.next = s
s.prev = r
n.prev = p
p.next = n
}
return n
}
// Unlink removes n % r.Len() elements from the ring r, starting
// at r.Next(). If n % r.Len() == 0, r remains unchanged.
// The result is the removed subring. r must not be empty.
func (r *Ring) Unlink(n int) *Ring {
if n <= 0 {
return nil
}
return r.Link(r.Move(n + 1))
}
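// A short sketch of the Link/Unlink semantics described above, written as a caller
// would (assumes "container/ring" and "fmt" are imported; the name is illustrative).
func ringUnlinkSketch() {
	r := ring.New(6)
	for i := 0; i < 6; i++ {
		r.Value = i
		r = r.Next()
	}
	// Remove the three elements following r; they form their own subring.
	sub := r.Unlink(3)
	fmt.Println(r.Len(), sub.Len()) // 3 3
}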
// Len computes the number of elements in ring r.
// It executes in time proportional to the number of elements.
func (r *Ring) Len() int {
n := 0
if r != nil {
n = 1
for p := r.Next(); p != r; p = p.next {
n++
}
}
return n
}
// Do calls function f on each element of the ring, in forward order.
// The behavior of Do is undefined if f changes *r.
func (r *Ring) Do(f func(any)) {
if r != nil {
f(r.Value)
for p := r.Next(); p != r; p = p.next {
f(p.Value)
}
}
}
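// A minimal sketch of building a ring and visiting each element with Do, again from
// a caller's perspective (assumes "container/ring" and "fmt" are imported).
func ringDoSketch() {
	r := ring.New(3)
	for i := 0; i < r.Len(); i++ {
		r.Value = i
		r = r.Next()
	}
	// Do visits the elements in forward order; prints 0, 1, 2.
	r.Do(func(v any) {
		fmt.Println(v)
	})
}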
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package context defines the Context type, which carries deadlines,
// cancellation signals, and other request-scoped values across API boundaries
// and between processes.
//
// Incoming requests to a server should create a [Context], and outgoing
// calls to servers should accept a Context. The chain of function
// calls between them must propagate the Context, optionally replacing
// it with a derived Context created using [WithCancel], [WithDeadline],
// [WithTimeout], or [WithValue].
//
// A Context may be canceled to indicate that work done on its behalf should stop.
// A Context with a deadline is canceled after the deadline passes.
// When a Context is canceled, all Contexts derived from it are also canceled.
//
// The [WithCancel], [WithDeadline], and [WithTimeout] functions take a
// Context (the parent) and return a derived Context (the child) and a
// [CancelFunc]. Calling the CancelFunc directly cancels the child and its
// children, removes the parent's reference to the child, and stops
// any associated timers. Failing to call the CancelFunc leaks the
// child and its children until the parent is canceled. The go vet tool
// checks that CancelFuncs are used on all control-flow paths.
//
// The [WithCancelCause], [WithDeadlineCause], and [WithTimeoutCause] functions
// return a [CancelCauseFunc], which takes an error and records it as
// the cancellation cause. Calling [Cause] on the canceled context
// or any of its children retrieves the cause. If no cause is specified,
// Cause(ctx) returns the same value as ctx.Err().
//
// Programs that use Contexts should follow these rules to keep interfaces
// consistent across packages and enable static analysis tools to check context
// propagation:
//
// Do not store Contexts inside a struct type; instead, pass a Context
// explicitly to each function that needs it. This is discussed further in
// https://go.dev/blog/context-and-structs. The Context should be the first
// parameter, typically named ctx:
//
// func DoSomething(ctx context.Context, arg Arg) error {
// // ... use ctx ...
// }
//
// Do not pass a nil [Context], even if a function permits it. Pass [context.TODO]
// if you are unsure about which Context to use.
//
// Use context Values only for request-scoped data that transits processes and
// APIs, not for passing optional parameters to functions.
//
// The same Context may be passed to functions running in different goroutines;
// Contexts are safe for simultaneous use by multiple goroutines.
//
// See https://go.dev/blog/context for example code for a server that uses
// Contexts.
package context
import (
"errors"
"internal/reflectlite"
"sync"
"sync/atomic"
"time"
)
// A Context carries a deadline, a cancellation signal, and other values across
// API boundaries.
//
// Context's methods may be called by multiple goroutines simultaneously.
type Context interface {
// Deadline returns the time when work done on behalf of this context
// should be canceled. Deadline returns ok==false when no deadline is
// set. Successive calls to Deadline return the same results.
Deadline() (deadline time.Time, ok bool)
// Done returns a channel that's closed when work done on behalf of this
// context should be canceled. Done may return nil if this context can
// never be canceled. Successive calls to Done return the same value.
// The close of the Done channel may happen asynchronously,
// after the cancel function returns.
//
// WithCancel arranges for Done to be closed when cancel is called;
// WithDeadline arranges for Done to be closed when the deadline
// expires; WithTimeout arranges for Done to be closed when the timeout
// elapses.
//
// Done is provided for use in select statements:
//
// // Stream generates values with DoSomething and sends them to out
// // until DoSomething returns an error or ctx.Done is closed.
// func Stream(ctx context.Context, out chan<- Value) error {
// for {
// v, err := DoSomething(ctx)
// if err != nil {
// return err
// }
// select {
// case <-ctx.Done():
// return ctx.Err()
// case out <- v:
// }
// }
// }
//
// See https://go.dev/blog/pipelines for more examples of how to use
// a Done channel for cancellation.
Done() <-chan struct{}
// If Done is not yet closed, Err returns nil.
// If Done is closed, Err returns a non-nil error explaining why:
// DeadlineExceeded if the context's deadline passed,
// or Canceled if the context was canceled for some other reason.
// After Err returns a non-nil error, successive calls to Err return the same error.
Err() error
// Value returns the value associated with this context for key, or nil
// if no value is associated with key. Successive calls to Value with
// the same key return the same result.
//
// Use context values only for request-scoped data that transits
// processes and API boundaries, not for passing optional parameters to
// functions.
//
// A key identifies a specific value in a Context. Functions that wish
// to store values in Context typically allocate a key in a global
// variable then use that key as the argument to context.WithValue and
// Context.Value. A key can be any type that supports equality;
// packages should define keys as an unexported type to avoid
// collisions.
//
// Packages that define a Context key should provide type-safe accessors
// for the values stored using that key:
//
// // Package user defines a User type that's stored in Contexts.
// package user
//
// import "context"
//
// // User is the type of value stored in the Contexts.
// type User struct {...}
//
// // key is an unexported type for keys defined in this package.
// // This prevents collisions with keys defined in other packages.
// type key int
//
// // userKey is the key for user.User values in Contexts. It is
// // unexported; clients use user.NewContext and user.FromContext
// // instead of using this key directly.
// var userKey key
//
// // NewContext returns a new Context that carries value u.
// func NewContext(ctx context.Context, u *User) context.Context {
// return context.WithValue(ctx, userKey, u)
// }
//
// // FromContext returns the User value stored in ctx, if any.
// func FromContext(ctx context.Context) (*User, bool) {
// u, ok := ctx.Value(userKey).(*User)
// return u, ok
// }
Value(key any) any
}
// Canceled is the error returned by [Context.Err] when the context is canceled
// for some reason other than its deadline passing.
var Canceled = errors.New("context canceled")
// DeadlineExceeded is the error returned by [Context.Err] when the context is canceled
// due to its deadline passing.
var DeadlineExceeded error = deadlineExceededError{}
type deadlineExceededError struct{}
func (deadlineExceededError) Error() string { return "context deadline exceeded" }
func (deadlineExceededError) Timeout() bool { return true }
func (deadlineExceededError) Temporary() bool { return true }
// An emptyCtx is never canceled, has no values, and has no deadline.
// It is the common base of backgroundCtx and todoCtx.
type emptyCtx struct{}
func (emptyCtx) Deadline() (deadline time.Time, ok bool) {
return
}
func (emptyCtx) Done() <-chan struct{} {
return nil
}
func (emptyCtx) Err() error {
return nil
}
func (emptyCtx) Value(key any) any {
return nil
}
type backgroundCtx struct{ emptyCtx }
func (backgroundCtx) String() string {
return "context.Background"
}
type todoCtx struct{ emptyCtx }
func (todoCtx) String() string {
return "context.TODO"
}
// Background returns a non-nil, empty [Context]. It is never canceled, has no
// values, and has no deadline. It is typically used by the main function,
// initialization, and tests, and as the top-level Context for incoming
// requests.
func Background() Context {
return backgroundCtx{}
}
// TODO returns a non-nil, empty [Context]. Code should use context.TODO when
// it's unclear which Context to use or it is not yet available (because the
// surrounding function has not yet been extended to accept a Context
// parameter).
func TODO() Context {
return todoCtx{}
}
// A CancelFunc tells an operation to abandon its work.
// A CancelFunc does not wait for the work to stop.
// A CancelFunc may be called by multiple goroutines simultaneously.
// After the first call, subsequent calls to a CancelFunc do nothing.
type CancelFunc func()
// WithCancel returns a derived context that points to the parent context
// but has a new Done channel. The returned context's Done channel is closed
// when the returned cancel function is called or when the parent context's
// Done channel is closed, whichever happens first.
//
// Canceling this context releases resources associated with it, so code should
// call cancel as soon as the operations running in this [Context] complete.
func WithCancel(parent Context) (ctx Context, cancel CancelFunc) {
c := withCancel(parent)
return c, func() { c.cancel(true, Canceled, nil) }
}
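// A brief sketch of the usual WithCancel pattern as a caller would write it
// (assumes "context" and "fmt" are imported; the function name is illustrative).
func withCancelSketch() {
	ctx, cancel := context.WithCancel(context.Background())
	done := make(chan struct{})
	go func() {
		<-ctx.Done()           // closed once cancel is called
		fmt.Println(ctx.Err()) // context.Canceled
		close(done)
	}()
	cancel()
	<-done
}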
// A CancelCauseFunc behaves like a [CancelFunc] but additionally sets the cancellation cause.
// This cause can be retrieved by calling [Cause] on the canceled Context or on
// any of its derived Contexts.
//
// If the context has already been canceled, CancelCauseFunc does not set the cause.
// For example, if childContext is derived from parentContext:
// - if parentContext is canceled with cause1 before childContext is canceled with cause2,
// then Cause(parentContext) == Cause(childContext) == cause1
// - if childContext is canceled with cause2 before parentContext is canceled with cause1,
// then Cause(parentContext) == cause1 and Cause(childContext) == cause2
type CancelCauseFunc func(cause error)
// WithCancelCause behaves like [WithCancel] but returns a [CancelCauseFunc] instead of a [CancelFunc].
// Calling cancel with a non-nil error (the "cause") records that error in ctx;
// it can then be retrieved using Cause(ctx).
// Calling cancel with nil sets the cause to Canceled.
//
// Example use:
//
// ctx, cancel := context.WithCancelCause(parent)
// cancel(myError)
// ctx.Err() // returns context.Canceled
// context.Cause(ctx) // returns myError
func WithCancelCause(parent Context) (ctx Context, cancel CancelCauseFunc) {
c := withCancel(parent)
return c, func(cause error) { c.cancel(true, Canceled, cause) }
}
func withCancel(parent Context) *cancelCtx {
if parent == nil {
panic("cannot create context from nil parent")
}
c := &cancelCtx{}
c.propagateCancel(parent, c)
return c
}
// Cause returns a non-nil error explaining why c was canceled.
// The first cancellation of c or one of its parents sets the cause.
// If that cancellation happened via a call to CancelCauseFunc(err),
// then [Cause] returns err.
// Otherwise Cause(c) returns the same value as c.Err().
// Cause returns nil if c has not been canceled yet.
func Cause(c Context) error {
if cc, ok := c.Value(&cancelCtxKey).(*cancelCtx); ok {
cc.mu.Lock()
cause := cc.cause
cc.mu.Unlock()
if cause != nil {
return cause
}
// Either this context is not canceled,
// or it is canceled and the cancellation happened in a
// custom context implementation rather than a *cancelCtx.
}
// There is no cancelCtxKey value with a cause, so we know that c is
// not a descendant of some canceled Context created by WithCancelCause.
// Therefore, there is no specific cause to return.
// If this is not one of the standard Context types,
// it might still have an error even though it won't have a cause.
return c.Err()
}
// AfterFunc arranges to call f in its own goroutine after ctx is canceled.
// If ctx is already canceled, AfterFunc calls f immediately in its own goroutine.
//
// Multiple calls to AfterFunc on a context operate independently;
// one does not replace another.
//
// Calling the returned stop function stops the association of ctx with f.
// It returns true if the call stopped f from being run.
// If stop returns false,
// either the context is canceled and f has been started in its own goroutine;
// or f was already stopped.
// The stop function does not wait for f to complete before returning.
// If the caller needs to know whether f is completed,
// it must coordinate with f explicitly.
//
// If ctx has an "AfterFunc(func()) func() bool" method,
// AfterFunc will use it to schedule the call.
func AfterFunc(ctx Context, f func()) (stop func() bool) {
a := &afterFuncCtx{
f: f,
}
a.cancelCtx.propagateCancel(ctx, a)
return func() bool {
stopped := false
a.once.Do(func() {
stopped = true
})
if stopped {
a.cancel(true, Canceled, nil)
}
return stopped
}
}
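// An illustrative sketch built on AfterFunc, from a caller's perspective:
// mergeCancel is a hypothetical helper (assuming "context" is imported) that
// returns a context canceled when either ctx or other is canceled.
func mergeCancel(ctx, other context.Context) (context.Context, context.CancelFunc) {
	merged, cancel := context.WithCancelCause(ctx)
	stop := context.AfterFunc(other, func() {
		cancel(context.Cause(other)) // propagate other's cancellation cause
	})
	return merged, func() {
		stop()
		cancel(context.Canceled)
	}
}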
type afterFuncer interface {
AfterFunc(func()) func() bool
}
type afterFuncCtx struct {
cancelCtx
once sync.Once // either starts running f or stops f from running
f func()
}
func (a *afterFuncCtx) cancel(removeFromParent bool, err, cause error) {
a.cancelCtx.cancel(false, err, cause)
if removeFromParent {
removeChild(a.Context, a)
}
a.once.Do(func() {
go a.f()
})
}
// A stopCtx is used as the parent context of a cancelCtx when
// an AfterFunc has been registered with the parent.
// It holds the stop function used to unregister the AfterFunc.
type stopCtx struct {
Context
stop func() bool
}
// goroutines counts the number of goroutines ever created; for testing.
var goroutines atomic.Int32
// &cancelCtxKey is the key that a cancelCtx returns itself for.
var cancelCtxKey int
// parentCancelCtx returns the underlying *cancelCtx for parent.
// It does this by looking up parent.Value(&cancelCtxKey) to find
// the innermost enclosing *cancelCtx and then checking whether
// parent.Done() matches that *cancelCtx. (If not, the *cancelCtx
// has been wrapped in a custom implementation providing a
// different done channel, in which case we should not bypass it.)
func parentCancelCtx(parent Context) (*cancelCtx, bool) {
done := parent.Done()
if done == closedchan || done == nil {
return nil, false
}
p, ok := parent.Value(&cancelCtxKey).(*cancelCtx)
if !ok {
return nil, false
}
pdone, _ := p.done.Load().(chan struct{})
if pdone != done {
return nil, false
}
return p, true
}
// removeChild removes a context from its parent.
func removeChild(parent Context, child canceler) {
if s, ok := parent.(stopCtx); ok {
s.stop()
return
}
p, ok := parentCancelCtx(parent)
if !ok {
return
}
p.mu.Lock()
if p.children != nil {
delete(p.children, child)
}
p.mu.Unlock()
}
// A canceler is a context type that can be canceled directly. The
// implementations are *cancelCtx and *timerCtx.
type canceler interface {
cancel(removeFromParent bool, err, cause error)
Done() <-chan struct{}
}
// closedchan is a reusable closed channel.
var closedchan = make(chan struct{})
func init() {
close(closedchan)
}
// A cancelCtx can be canceled. When canceled, it also cancels any children
// that implement canceler.
type cancelCtx struct {
Context
mu sync.Mutex // protects following fields
done atomic.Value // of chan struct{}, created lazily, closed by first cancel call
children map[canceler]struct{} // set to nil by the first cancel call
err atomic.Value // set to non-nil by the first cancel call
cause error // set to non-nil by the first cancel call
}
func (c *cancelCtx) Value(key any) any {
if key == &cancelCtxKey {
return c
}
return value(c.Context, key)
}
func (c *cancelCtx) Done() <-chan struct{} {
d := c.done.Load()
if d != nil {
return d.(chan struct{})
}
c.mu.Lock()
defer c.mu.Unlock()
d = c.done.Load()
if d == nil {
d = make(chan struct{})
c.done.Store(d)
}
return d.(chan struct{})
}
func (c *cancelCtx) Err() error {
// An atomic load is ~5x faster than a mutex, which can matter in tight loops.
if err := c.err.Load(); err != nil {
return err.(error)
}
return nil
}
// propagateCancel arranges for child to be canceled when parent is.
// It sets the parent context of cancelCtx.
func (c *cancelCtx) propagateCancel(parent Context, child canceler) {
c.Context = parent
done := parent.Done()
if done == nil {
return // parent is never canceled
}
select {
case <-done:
// parent is already canceled
child.cancel(false, parent.Err(), Cause(parent))
return
default:
}
if p, ok := parentCancelCtx(parent); ok {
// parent is a *cancelCtx, or derives from one.
p.mu.Lock()
if err := p.err.Load(); err != nil {
// parent has already been canceled
child.cancel(false, err.(error), p.cause)
} else {
if p.children == nil {
p.children = make(map[canceler]struct{})
}
p.children[child] = struct{}{}
}
p.mu.Unlock()
return
}
if a, ok := parent.(afterFuncer); ok {
// parent implements an AfterFunc method.
c.mu.Lock()
stop := a.AfterFunc(func() {
child.cancel(false, parent.Err(), Cause(parent))
})
c.Context = stopCtx{
Context: parent,
stop: stop,
}
c.mu.Unlock()
return
}
goroutines.Add(1)
go func() {
select {
case <-parent.Done():
child.cancel(false, parent.Err(), Cause(parent))
case <-child.Done():
}
}()
}
type stringer interface {
String() string
}
func contextName(c Context) string {
if s, ok := c.(stringer); ok {
return s.String()
}
return reflectlite.TypeOf(c).String()
}
func (c *cancelCtx) String() string {
return contextName(c.Context) + ".WithCancel"
}
// cancel closes c.done, cancels each of c's children, and, if
// removeFromParent is true, removes c from its parent's children.
// cancel sets c.cause to cause if this is the first time c is canceled.
func (c *cancelCtx) cancel(removeFromParent bool, err, cause error) {
if err == nil {
panic("context: internal error: missing cancel error")
}
if cause == nil {
cause = err
}
c.mu.Lock()
if c.err.Load() != nil {
c.mu.Unlock()
return // already canceled
}
c.err.Store(err)
c.cause = cause
d, _ := c.done.Load().(chan struct{})
if d == nil {
c.done.Store(closedchan)
} else {
close(d)
}
for child := range c.children {
// NOTE: acquiring the child's lock while holding parent's lock.
child.cancel(false, err, cause)
}
c.children = nil
c.mu.Unlock()
if removeFromParent {
removeChild(c.Context, c)
}
}
// WithoutCancel returns a derived context that points to the parent context
// and is not canceled when parent is canceled.
// The returned context returns no Deadline or Err, and its Done channel is nil.
// Calling [Cause] on the returned context returns nil.
func WithoutCancel(parent Context) Context {
if parent == nil {
panic("cannot create context from nil parent")
}
return withoutCancelCtx{parent}
}
type withoutCancelCtx struct {
c Context
}
func (withoutCancelCtx) Deadline() (deadline time.Time, ok bool) {
return
}
func (withoutCancelCtx) Done() <-chan struct{} {
return nil
}
func (withoutCancelCtx) Err() error {
return nil
}
func (c withoutCancelCtx) Value(key any) any {
return value(c, key)
}
func (c withoutCancelCtx) String() string {
return contextName(c.c) + ".WithoutCancel"
}
// WithDeadline returns a derived context that points to the parent context
// but has the deadline adjusted to be no later than d. If the parent's
// deadline is already earlier than d, WithDeadline(parent, d) is semantically
// equivalent to parent. The returned [Context.Done] channel is closed when
// the deadline expires, when the returned cancel function is called,
// or when the parent context's Done channel is closed, whichever happens first.
//
// Canceling this context releases resources associated with it, so code should
// call cancel as soon as the operations running in this [Context] complete.
func WithDeadline(parent Context, d time.Time) (Context, CancelFunc) {
return WithDeadlineCause(parent, d, nil)
}
// WithDeadlineCause behaves like [WithDeadline] but also sets the cause of the
// returned Context when the deadline is exceeded. The returned [CancelFunc] does
// not set the cause.
func WithDeadlineCause(parent Context, d time.Time, cause error) (Context, CancelFunc) {
if parent == nil {
panic("cannot create context from nil parent")
}
if cur, ok := parent.Deadline(); ok && cur.Before(d) {
// The current deadline is already sooner than the new one.
return WithCancel(parent)
}
c := &timerCtx{
deadline: d,
}
c.cancelCtx.propagateCancel(parent, c)
dur := time.Until(d)
if dur <= 0 {
c.cancel(true, DeadlineExceeded, cause) // deadline has already passed
return c, func() { c.cancel(false, Canceled, nil) }
}
c.mu.Lock()
defer c.mu.Unlock()
if c.err.Load() == nil {
c.timer = time.AfterFunc(dur, func() {
c.cancel(true, DeadlineExceeded, cause)
})
}
return c, func() { c.cancel(true, Canceled, nil) }
}
// A timerCtx carries a timer and a deadline. It embeds a cancelCtx to
// implement Done and Err. It implements cancel by stopping its timer then
// delegating to cancelCtx.cancel.
type timerCtx struct {
cancelCtx
timer *time.Timer // Under cancelCtx.mu.
deadline time.Time
}
func (c *timerCtx) Deadline() (deadline time.Time, ok bool) {
return c.deadline, true
}
func (c *timerCtx) String() string {
return contextName(c.cancelCtx.Context) + ".WithDeadline(" +
c.deadline.String() + " [" +
time.Until(c.deadline).String() + "])"
}
func (c *timerCtx) cancel(removeFromParent bool, err, cause error) {
c.cancelCtx.cancel(false, err, cause)
if removeFromParent {
// Remove this timerCtx from its parent cancelCtx's children.
removeChild(c.cancelCtx.Context, c)
}
c.mu.Lock()
if c.timer != nil {
c.timer.Stop()
c.timer = nil
}
c.mu.Unlock()
}
// WithTimeout returns WithDeadline(parent, time.Now().Add(timeout)).
//
// Canceling this context releases resources associated with it, so code should
// call cancel as soon as the operations running in this [Context] complete:
//
// func slowOperationWithTimeout(ctx context.Context) (Result, error) {
// ctx, cancel := context.WithTimeout(ctx, 100*time.Millisecond)
// defer cancel() // releases resources if slowOperation completes before timeout elapses
// return slowOperation(ctx)
// }
func WithTimeout(parent Context, timeout time.Duration) (Context, CancelFunc) {
return WithDeadline(parent, time.Now().Add(timeout))
}
// WithTimeoutCause behaves like [WithTimeout] but also sets the cause of the
// returned Context when the timeout expires. The returned [CancelFunc] does
// not set the cause.
func WithTimeoutCause(parent Context, timeout time.Duration, cause error) (Context, CancelFunc) {
return WithDeadlineCause(parent, time.Now().Add(timeout), cause)
}
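// A short sketch of how a timeout cause surfaces through Cause, written as a
// caller would (assumes "context", "errors", "fmt", and "time" are imported;
// errSlowBackend and the function name are illustrative).
func timeoutCauseSketch() {
	errSlowBackend := errors.New("backend too slow")
	ctx, cancel := context.WithTimeoutCause(context.Background(), 10*time.Millisecond, errSlowBackend)
	defer cancel()
	<-ctx.Done()
	fmt.Println(ctx.Err())          // context.DeadlineExceeded
	fmt.Println(context.Cause(ctx)) // backend too slow
}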
// WithValue returns a derived context that points to the parent Context.
// In the derived context, the value associated with key is val.
//
// Use context Values only for request-scoped data that transits processes and
// APIs, not for passing optional parameters to functions.
//
// The provided key must be comparable and should not be of type
// string or any other built-in type to avoid collisions between
// packages using context. Users of WithValue should define their own
// types for keys. To avoid allocating when assigning to an
// interface{}, context keys often have concrete type
// struct{}. Alternatively, exported context key variables' static
// type should be a pointer or interface.
func WithValue(parent Context, key, val any) Context {
if parent == nil {
panic("cannot create context from nil parent")
}
if key == nil {
panic("nil key")
}
if !reflectlite.TypeOf(key).Comparable() {
panic("key is not comparable")
}
return &valueCtx{parent, key, val}
}
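// A minimal sketch of the recommended key pattern (caller's perspective;
// assumes "context" and "fmt" are imported; requestIDKey is a hypothetical
// unexported key type).
type requestIDKey struct{}

func withValueSketch() {
	ctx := context.WithValue(context.Background(), requestIDKey{}, "req-42")
	if id, ok := ctx.Value(requestIDKey{}).(string); ok {
		fmt.Println(id) // req-42
	}
}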
// A valueCtx carries a key-value pair. It implements Value for that key and
// delegates all other calls to the embedded Context.
type valueCtx struct {
Context
key, val any
}
// stringify tries a bit to stringify v, without using fmt, since we don't
// want context depending on the unicode tables. This is only used by
// *valueCtx.String().
func stringify(v any) string {
switch s := v.(type) {
case stringer:
return s.String()
case string:
return s
case nil:
return "<nil>"
}
return reflectlite.TypeOf(v).String()
}
func (c *valueCtx) String() string {
return contextName(c.Context) + ".WithValue(" +
stringify(c.key) + ", " +
stringify(c.val) + ")"
}
func (c *valueCtx) Value(key any) any {
if c.key == key {
return c.val
}
return value(c.Context, key)
}
func value(c Context, key any) any {
for {
switch ctx := c.(type) {
case *valueCtx:
if key == ctx.key {
return ctx.val
}
c = ctx.Context
case *cancelCtx:
if key == &cancelCtxKey {
return c
}
c = ctx.Context
case withoutCancelCtx:
if key == &cancelCtxKey {
// This implements Cause(ctx) == nil
// when ctx is created using WithoutCancel.
return nil
}
c = ctx.c
case *timerCtx:
if key == &cancelCtxKey {
return &ctx.cancelCtx
}
c = ctx.Context
case backgroundCtx, todoCtx:
return nil
default:
return c.Value(key)
}
}
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package aes implements AES encryption (formerly Rijndael), as defined in
// U.S. Federal Information Processing Standards Publication 197.
//
// The AES operations in this package are not implemented using constant-time algorithms.
// An exception is when running on systems with hardware support for AES enabled,
// which makes these operations constant-time. Examples include amd64 systems using AES-NI
// extensions and s390x systems using Message-Security-Assist extensions.
// On such systems, when the result of NewCipher is passed to cipher.NewGCM,
// the GHASH operation used by GCM is also constant-time.
package aes
import (
"crypto/cipher"
"crypto/internal/boring"
"crypto/internal/fips140/aes"
"strconv"
)
// The AES block size in bytes.
const BlockSize = 16
type KeySizeError int
func (k KeySizeError) Error() string {
return "crypto/aes: invalid key size " + strconv.Itoa(int(k))
}
// NewCipher creates and returns a new [cipher.Block].
// The key argument must be the AES key,
// either 16, 24, or 32 bytes to select
// AES-128, AES-192, or AES-256.
func NewCipher(key []byte) (cipher.Block, error) {
k := len(key)
switch k {
default:
return nil, KeySizeError(k)
case 16, 24, 32:
break
}
if boring.Enabled {
return boring.NewAESCipher(key)
}
return aes.New(key)
}
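// A brief sketch of authenticated encryption with this package and [cipher.NewGCM],
// from a caller's perspective (assumes "crypto/aes", "crypto/cipher", and
// "crypto/rand" are imported; key handling is simplified for illustration).
func sealAndOpen(key, plaintext []byte) ([]byte, error) {
	block, err := aes.NewCipher(key) // key must be 16, 24, or 32 bytes
	if err != nil {
		return nil, err
	}
	gcm, err := cipher.NewGCM(block)
	if err != nil {
		return nil, err
	}
	nonce := make([]byte, gcm.NonceSize())
	if _, err := rand.Read(nonce); err != nil {
		return nil, err
	}
	sealed := gcm.Seal(nonce, nonce, plaintext, nil) // prepend the nonce to the ciphertext
	// Decrypt again to show the round trip.
	return gcm.Open(nil, sealed[:gcm.NonceSize()], sealed[gcm.NonceSize():], nil)
}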
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package des
import (
"internal/byteorder"
"sync"
)
func cryptBlock(subkeys []uint64, dst, src []byte, decrypt bool) {
b := byteorder.BEUint64(src)
b = permuteInitialBlock(b)
left, right := uint32(b>>32), uint32(b)
left = (left << 1) | (left >> 31)
right = (right << 1) | (right >> 31)
if decrypt {
for i := 0; i < 8; i++ {
left, right = feistel(left, right, subkeys[15-2*i], subkeys[15-(2*i+1)])
}
} else {
for i := 0; i < 8; i++ {
left, right = feistel(left, right, subkeys[2*i], subkeys[2*i+1])
}
}
left = (left << 31) | (left >> 1)
right = (right << 31) | (right >> 1)
// switch left & right and perform final permutation
preOutput := (uint64(right) << 32) | uint64(left)
byteorder.BEPutUint64(dst, permuteFinalBlock(preOutput))
}
// DES Feistel function. feistelBox must be initialized via
// feistelBoxOnce.Do(initFeistelBox) first.
func feistel(l, r uint32, k0, k1 uint64) (lout, rout uint32) {
var t uint32
t = r ^ uint32(k0>>32)
l ^= feistelBox[7][t&0x3f] ^
feistelBox[5][(t>>8)&0x3f] ^
feistelBox[3][(t>>16)&0x3f] ^
feistelBox[1][(t>>24)&0x3f]
t = ((r << 28) | (r >> 4)) ^ uint32(k0)
l ^= feistelBox[6][(t)&0x3f] ^
feistelBox[4][(t>>8)&0x3f] ^
feistelBox[2][(t>>16)&0x3f] ^
feistelBox[0][(t>>24)&0x3f]
t = l ^ uint32(k1>>32)
r ^= feistelBox[7][t&0x3f] ^
feistelBox[5][(t>>8)&0x3f] ^
feistelBox[3][(t>>16)&0x3f] ^
feistelBox[1][(t>>24)&0x3f]
t = ((l << 28) | (l >> 4)) ^ uint32(k1)
r ^= feistelBox[6][(t)&0x3f] ^
feistelBox[4][(t>>8)&0x3f] ^
feistelBox[2][(t>>16)&0x3f] ^
feistelBox[0][(t>>24)&0x3f]
return l, r
}
// feistelBox[s][16*i+j] contains the output of permutationFunction
// for sBoxes[s][i][j] << 4*(7-s)
var feistelBox [8][64]uint32
var feistelBoxOnce sync.Once
// permuteBlock is a general-purpose function to perform DES block permutations.
func permuteBlock(src uint64, permutation []uint8) (block uint64) {
for position, n := range permutation {
bit := (src >> n) & 1
block |= bit << uint((len(permutation)-1)-position)
}
return
}
func initFeistelBox() {
for s := range sBoxes {
for i := 0; i < 4; i++ {
for j := 0; j < 16; j++ {
f := uint64(sBoxes[s][i][j]) << (4 * (7 - uint(s)))
f = permuteBlock(f, permutationFunction[:])
// Row is determined by the 1st and 6th bit.
// Column is the middle four bits.
row := uint8(((i & 2) << 4) | i&1)
col := uint8(j << 1)
t := row | col
// The rotation normally performed in the Feistel rounds has been factored out and is mixed into feistelBox here.
f = (f << 1) | (f >> 31)
feistelBox[s][t] = uint32(f)
}
}
}
}
// permuteInitialBlock is equivalent to the permutation defined
// by initialPermutation.
func permuteInitialBlock(block uint64) uint64 {
// block = b7 b6 b5 b4 b3 b2 b1 b0 (8 bytes)
b1 := block >> 48
b2 := block << 48
block ^= b1 ^ b2 ^ b1<<48 ^ b2>>48
// block = b1 b0 b5 b4 b3 b2 b7 b6
b1 = block >> 32 & 0xff00ff
b2 = (block & 0xff00ff00)
block ^= b1<<32 ^ b2 ^ b1<<8 ^ b2<<24 // exchange b0 b4 with b3 b7
// block is now b1 b3 b5 b7 b0 b2 b4 b6, the permutation:
// ... 8
// ... 24
// ... 40
// ... 56
// 7 6 5 4 3 2 1 0
// 23 22 21 20 19 18 17 16
// ... 32
// ... 48
// exchange 4,5,6,7 with 32,33,34,35 etc.
b1 = block & 0x0f0f00000f0f0000
b2 = block & 0x0000f0f00000f0f0
block ^= b1 ^ b2 ^ b1>>12 ^ b2<<12
// block is the permutation:
//
// [+8] [+40]
//
// 7 6 5 4
// 23 22 21 20
// 3 2 1 0
// 19 18 17 16 [+32]
// exchange 0,1,4,5 with 18,19,22,23
b1 = block & 0x3300330033003300
b2 = block & 0x00cc00cc00cc00cc
block ^= b1 ^ b2 ^ b1>>6 ^ b2<<6
// block is the permutation:
// 15 14
// 13 12
// 11 10
// 9 8
// 7 6
// 5 4
// 3 2
// 1 0 [+16] [+32] [+64]
// exchange 0,2,4,6 with 9,11,13,15:
b1 = block & 0xaaaaaaaa55555555
block ^= b1 ^ b1>>33 ^ b1<<33
// block is the permutation:
// 6 14 22 30 38 46 54 62
// 4 12 20 28 36 44 52 60
// 2 10 18 26 34 42 50 58
// 0 8 16 24 32 40 48 56
// 7 15 23 31 39 47 55 63
// 5 13 21 29 37 45 53 61
// 3 11 19 27 35 43 51 59
// 1 9 17 25 33 41 49 57
return block
}
// permuteFinalBlock is equivalent to the permutation defined
// by finalPermutation.
func permuteFinalBlock(block uint64) uint64 {
// Perform the same bit exchanges as permuteInitialBlock
// but in reverse order.
b1 := block & 0xaaaaaaaa55555555
block ^= b1 ^ b1>>33 ^ b1<<33
b1 = block & 0x3300330033003300
b2 := block & 0x00cc00cc00cc00cc
block ^= b1 ^ b2 ^ b1>>6 ^ b2<<6
b1 = block & 0x0f0f00000f0f0000
b2 = block & 0x0000f0f00000f0f0
block ^= b1 ^ b2 ^ b1>>12 ^ b2<<12
b1 = block >> 32 & 0xff00ff
b2 = (block & 0xff00ff00)
block ^= b1<<32 ^ b2 ^ b1<<8 ^ b2<<24
b1 = block >> 48
b2 = block << 48
block ^= b1 ^ b2 ^ b1<<48 ^ b2>>48
return block
}
// ksRotate creates 16 28-bit blocks rotated according to the rotation schedule.
func ksRotate(in uint32) (out []uint32) {
out = make([]uint32, 16)
last := in
for i := 0; i < 16; i++ {
// 28-bit circular left shift
left := (last << (4 + ksRotations[i])) >> 4
right := (last << 4) >> (32 - ksRotations[i])
out[i] = left | right
last = out[i]
}
return
}
// generateSubkeys creates 16 56-bit subkeys from the original key.
func (c *desCipher) generateSubkeys(keyBytes []byte) {
feistelBoxOnce.Do(initFeistelBox)
// apply PC1 permutation to key
key := byteorder.BEUint64(keyBytes)
permutedKey := permuteBlock(key, permutedChoice1[:])
// rotate halves of permuted key according to the rotation schedule
leftRotations := ksRotate(uint32(permutedKey >> 28))
rightRotations := ksRotate(uint32(permutedKey<<4) >> 4)
// generate subkeys
for i := 0; i < 16; i++ {
// combine halves to form 56-bit input to PC2
pc2Input := uint64(leftRotations[i])<<28 | uint64(rightRotations[i])
// apply PC2 permutation to 7 byte input
c.subkeys[i] = unpack(permuteBlock(pc2Input, permutedChoice2[:]))
}
}
// unpack expands the 48-bit input to 64 bits, with each 6-bit block padded by two extra bits at the top.
// This keeps the input blocks (four bits each) and the key blocks (six bits each) well aligned,
// avoiding extra shifts or rotations for alignment.
func unpack(x uint64) uint64 {
return ((x>>(6*1))&0xff)<<(8*0) |
((x>>(6*3))&0xff)<<(8*1) |
((x>>(6*5))&0xff)<<(8*2) |
((x>>(6*7))&0xff)<<(8*3) |
((x>>(6*0))&0xff)<<(8*4) |
((x>>(6*2))&0xff)<<(8*5) |
((x>>(6*4))&0xff)<<(8*6) |
((x>>(6*6))&0xff)<<(8*7)
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package des
import (
"crypto/cipher"
"crypto/internal/fips140/alias"
"crypto/internal/fips140only"
"errors"
"internal/byteorder"
"strconv"
)
// The DES block size in bytes.
const BlockSize = 8
type KeySizeError int
func (k KeySizeError) Error() string {
return "crypto/des: invalid key size " + strconv.Itoa(int(k))
}
// desCipher is an instance of DES encryption.
type desCipher struct {
subkeys [16]uint64
}
// NewCipher creates and returns a new [cipher.Block].
func NewCipher(key []byte) (cipher.Block, error) {
if fips140only.Enabled {
return nil, errors.New("crypto/des: use of DES is not allowed in FIPS 140-only mode")
}
if len(key) != 8 {
return nil, KeySizeError(len(key))
}
c := new(desCipher)
c.generateSubkeys(key)
return c, nil
}
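// A minimal single-block round trip using the Block interface, as a caller would
// write it (assumes "crypto/des" is imported; the 8-byte key is illustrative, and
// DES itself is a legacy algorithm kept only for compatibility with old protocols).
func desRoundTrip(src []byte) ([]byte, error) {
	c, err := des.NewCipher([]byte("8bytekey")) // illustrative 8-byte key
	if err != nil {
		return nil, err
	}
	dst := make([]byte, des.BlockSize)
	back := make([]byte, des.BlockSize)
	c.Encrypt(dst, src)  // src must be exactly one 8-byte block
	c.Decrypt(back, dst) // back now equals src
	return dst, nil
}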
func (c *desCipher) BlockSize() int { return BlockSize }
func (c *desCipher) Encrypt(dst, src []byte) {
if len(src) < BlockSize {
panic("crypto/des: input not full block")
}
if len(dst) < BlockSize {
panic("crypto/des: output not full block")
}
if alias.InexactOverlap(dst[:BlockSize], src[:BlockSize]) {
panic("crypto/des: invalid buffer overlap")
}
cryptBlock(c.subkeys[:], dst, src, false)
}
func (c *desCipher) Decrypt(dst, src []byte) {
if len(src) < BlockSize {
panic("crypto/des: input not full block")
}
if len(dst) < BlockSize {
panic("crypto/des: output not full block")
}
if alias.InexactOverlap(dst[:BlockSize], src[:BlockSize]) {
panic("crypto/des: invalid buffer overlap")
}
cryptBlock(c.subkeys[:], dst, src, true)
}
// A tripleDESCipher is an instance of TripleDES encryption.
type tripleDESCipher struct {
cipher1, cipher2, cipher3 desCipher
}
// NewTripleDESCipher creates and returns a new [cipher.Block].
func NewTripleDESCipher(key []byte) (cipher.Block, error) {
if fips140only.Enabled {
return nil, errors.New("crypto/des: use of TripleDES is not allowed in FIPS 140-only mode")
}
if len(key) != 24 {
return nil, KeySizeError(len(key))
}
c := new(tripleDESCipher)
c.cipher1.generateSubkeys(key[:8])
c.cipher2.generateSubkeys(key[8:16])
c.cipher3.generateSubkeys(key[16:])
return c, nil
}
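// A short sketch of TripleDES in CBC mode via [cipher.NewCBCEncrypter], from a
// caller's perspective (assumes "crypto/des", "crypto/cipher", and "crypto/rand"
// are imported; the plaintext must already be padded to a multiple of the 8-byte block size).
func tripleDESEncryptCBC(key24, plaintext []byte) ([]byte, error) {
	block, err := des.NewTripleDESCipher(key24) // key24 must be 24 bytes
	if err != nil {
		return nil, err
	}
	iv := make([]byte, des.BlockSize)
	if _, err := rand.Read(iv); err != nil {
		return nil, err
	}
	ciphertext := make([]byte, len(plaintext))
	cipher.NewCBCEncrypter(block, iv).CryptBlocks(ciphertext, plaintext)
	return append(iv, ciphertext...), nil // prepend the IV for the decrypter
}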
func (c *tripleDESCipher) BlockSize() int { return BlockSize }
func (c *tripleDESCipher) Encrypt(dst, src []byte) {
if len(src) < BlockSize {
panic("crypto/des: input not full block")
}
if len(dst) < BlockSize {
panic("crypto/des: output not full block")
}
if alias.InexactOverlap(dst[:BlockSize], src[:BlockSize]) {
panic("crypto/des: invalid buffer overlap")
}
b := byteorder.BEUint64(src)
b = permuteInitialBlock(b)
left, right := uint32(b>>32), uint32(b)
left = (left << 1) | (left >> 31)
right = (right << 1) | (right >> 31)
for i := 0; i < 8; i++ {
left, right = feistel(left, right, c.cipher1.subkeys[2*i], c.cipher1.subkeys[2*i+1])
}
for i := 0; i < 8; i++ {
right, left = feistel(right, left, c.cipher2.subkeys[15-2*i], c.cipher2.subkeys[15-(2*i+1)])
}
for i := 0; i < 8; i++ {
left, right = feistel(left, right, c.cipher3.subkeys[2*i], c.cipher3.subkeys[2*i+1])
}
left = (left << 31) | (left >> 1)
right = (right << 31) | (right >> 1)
preOutput := (uint64(right) << 32) | uint64(left)
byteorder.BEPutUint64(dst, permuteFinalBlock(preOutput))
}
func (c *tripleDESCipher) Decrypt(dst, src []byte) {
if len(src) < BlockSize {
panic("crypto/des: input not full block")
}
if len(dst) < BlockSize {
panic("crypto/des: output not full block")
}
if alias.InexactOverlap(dst[:BlockSize], src[:BlockSize]) {
panic("crypto/des: invalid buffer overlap")
}
b := byteorder.BEUint64(src)
b = permuteInitialBlock(b)
left, right := uint32(b>>32), uint32(b)
left = (left << 1) | (left >> 31)
right = (right << 1) | (right >> 31)
for i := 0; i < 8; i++ {
left, right = feistel(left, right, c.cipher3.subkeys[15-2*i], c.cipher3.subkeys[15-(2*i+1)])
}
for i := 0; i < 8; i++ {
right, left = feistel(right, left, c.cipher2.subkeys[2*i], c.cipher2.subkeys[2*i+1])
}
for i := 0; i < 8; i++ {
left, right = feistel(left, right, c.cipher1.subkeys[15-2*i], c.cipher1.subkeys[15-(2*i+1)])
}
left = (left << 31) | (left >> 1)
right = (right << 31) | (right >> 1)
preOutput := (uint64(right) << 32) | uint64(left)
byteorder.BEPutUint64(dst, permuteFinalBlock(preOutput))
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package ecdsa implements the Elliptic Curve Digital Signature Algorithm, as
// defined in [FIPS 186-5].
//
// Signatures generated by this package are not deterministic, but entropy is
// mixed with the private key and the message, achieving the same level of
// security in case of randomness source failure.
//
// Operations involving private keys are implemented using constant-time
// algorithms, as long as an [elliptic.Curve] returned by [elliptic.P224],
// [elliptic.P256], [elliptic.P384], or [elliptic.P521] is used.
//
// [FIPS 186-5]: https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.186-5.pdf
package ecdsa
import (
"crypto"
"crypto/ecdh"
"crypto/elliptic"
"crypto/internal/boring"
"crypto/internal/boring/bbig"
"crypto/internal/fips140/ecdsa"
"crypto/internal/fips140/nistec"
"crypto/internal/fips140cache"
"crypto/internal/fips140hash"
"crypto/internal/fips140only"
"crypto/internal/randutil"
"crypto/sha512"
"crypto/subtle"
"errors"
"io"
"math/big"
"golang.org/x/crypto/cryptobyte"
"golang.org/x/crypto/cryptobyte/asn1"
)
// PublicKey represents an ECDSA public key.
type PublicKey struct {
elliptic.Curve
// X, Y are the coordinates of the public key point.
//
// Modifying the raw coordinates can produce invalid keys, and may
// invalidate internal optimizations; moreover, [big.Int] methods are not
// suitable for operating on cryptographic values. To encode and decode
// PublicKey values, use [PublicKey.Bytes] and [ParseUncompressedPublicKey]
// or [crypto/x509.MarshalPKIXPublicKey] and [crypto/x509.ParsePKIXPublicKey].
// For ECDH, use [crypto/ecdh]. For lower-level elliptic curve operations,
// use a third-party module like filippo.io/nistec.
//
// These fields will be deprecated in Go 1.26.
X, Y *big.Int
}
// Any methods implemented on PublicKey might need to also be implemented on
// PrivateKey, as the latter embeds the former and will expose its methods.
// ECDH returns k as an [ecdh.PublicKey]. It returns an error if the key is
// invalid according to the definition of [ecdh.Curve.NewPublicKey], or if the
// Curve is not supported by crypto/ecdh.
func (k *PublicKey) ECDH() (*ecdh.PublicKey, error) {
c := curveToECDH(k.Curve)
if c == nil {
return nil, errors.New("ecdsa: unsupported curve by crypto/ecdh")
}
if !k.Curve.IsOnCurve(k.X, k.Y) {
return nil, errors.New("ecdsa: invalid public key")
}
return c.NewPublicKey(elliptic.Marshal(k.Curve, k.X, k.Y))
}
// Equal reports whether pub and x have the same value.
//
// Two keys are only considered to have the same value if they have the same Curve value.
// Note that for example [elliptic.P256] and elliptic.P256().Params() are different
// values, as the latter is a generic, non-constant-time implementation.
func (pub *PublicKey) Equal(x crypto.PublicKey) bool {
xx, ok := x.(*PublicKey)
if !ok {
return false
}
return bigIntEqual(pub.X, xx.X) && bigIntEqual(pub.Y, xx.Y) &&
// Standard library Curve implementations are singletons, so this check
// will work for those. Other Curves might be equivalent even if not
// singletons, but there is no definitive way to check for that, and
// better to err on the side of safety.
pub.Curve == xx.Curve
}
// ParseUncompressedPublicKey parses a public key encoded as an uncompressed
// point according to SEC 1, Version 2.0, Section 2.3.3 (also known as the X9.62
// uncompressed format). It returns an error if the point is not in uncompressed
// form, is not on the curve, or is the point at infinity.
//
// curve must be one of [elliptic.P224], [elliptic.P256], [elliptic.P384], or
// [elliptic.P521], or ParseUncompressedPublicKey returns an error.
//
// ParseUncompressedPublicKey accepts the same format as
// [ecdh.Curve.NewPublicKey] does for NIST curves, but returns a [PublicKey]
// instead of an [ecdh.PublicKey].
//
// Note that public keys are more commonly encoded in DER (or PEM) format, which
// can be parsed with [crypto/x509.ParsePKIXPublicKey] (and [encoding/pem]).
func ParseUncompressedPublicKey(curve elliptic.Curve, data []byte) (*PublicKey, error) {
if len(data) < 1 || data[0] != 4 {
return nil, errors.New("ecdsa: invalid uncompressed public key")
}
switch curve {
case elliptic.P224():
return parseUncompressedPublicKey(ecdsa.P224(), curve, data)
case elliptic.P256():
return parseUncompressedPublicKey(ecdsa.P256(), curve, data)
case elliptic.P384():
return parseUncompressedPublicKey(ecdsa.P384(), curve, data)
case elliptic.P521():
return parseUncompressedPublicKey(ecdsa.P521(), curve, data)
default:
return nil, errors.New("ecdsa: curve not supported by ParseUncompressedPublicKey")
}
}
func parseUncompressedPublicKey[P ecdsa.Point[P]](c *ecdsa.Curve[P], curve elliptic.Curve, data []byte) (*PublicKey, error) {
k, err := ecdsa.NewPublicKey(c, data)
if err != nil {
return nil, err
}
return publicKeyFromFIPS(curve, k)
}
// Bytes encodes the public key as an uncompressed point according to SEC 1,
// Version 2.0, Section 2.3.3 (also known as the X9.62 uncompressed format).
// It returns an error if the public key is invalid.
//
// PublicKey.Curve must be one of [elliptic.P224], [elliptic.P256],
// [elliptic.P384], or [elliptic.P521], or Bytes returns an error.
//
// Bytes returns the same format as [ecdh.PublicKey.Bytes] does for NIST curves.
//
// Note that public keys are more commonly encoded in DER (or PEM) format, which
// can be generated with [crypto/x509.MarshalPKIXPublicKey] (and [encoding/pem]).
func (pub *PublicKey) Bytes() ([]byte, error) {
switch pub.Curve {
case elliptic.P224():
return publicKeyBytes(ecdsa.P224(), pub)
case elliptic.P256():
return publicKeyBytes(ecdsa.P256(), pub)
case elliptic.P384():
return publicKeyBytes(ecdsa.P384(), pub)
case elliptic.P521():
return publicKeyBytes(ecdsa.P521(), pub)
default:
return nil, errors.New("ecdsa: curve not supported by PublicKey.Bytes")
}
}
func publicKeyBytes[P ecdsa.Point[P]](c *ecdsa.Curve[P], pub *PublicKey) ([]byte, error) {
k, err := publicKeyToFIPS(c, pub)
if err != nil {
return nil, err
}
return k.Bytes(), nil
}
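// A brief round-trip sketch for the uncompressed point encoding described above,
// written as a caller would (assumes "crypto/ecdsa", "crypto/elliptic", and
// "crypto/rand" are imported; the function name is illustrative).
func publicKeyRoundTrip() error {
	priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		return err
	}
	raw, err := priv.PublicKey.Bytes() // 0x04 || X || Y, 65 bytes for P-256
	if err != nil {
		return err
	}
	pub, err := ecdsa.ParseUncompressedPublicKey(elliptic.P256(), raw)
	if err != nil {
		return err
	}
	_ = pub.Equal(&priv.PublicKey) // true
	return nil
}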
// PrivateKey represents an ECDSA private key.
type PrivateKey struct {
PublicKey
// D is the private scalar value.
//
// Modifying the raw value can produce invalid keys, and may
// invalidate internal optimizations; moreover, [big.Int] methods are not
// suitable for operating on cryptographic values. To encode and decode
// PrivateKey values, use [PrivateKey.Bytes] and [ParseRawPrivateKey] or
// [crypto/x509.MarshalPKCS8PrivateKey] and [crypto/x509.ParsePKCS8PrivateKey].
// For ECDH, use [crypto/ecdh].
//
// This field will be deprecated in Go 1.26.
D *big.Int
}
// ECDH returns k as an [ecdh.PrivateKey]. It returns an error if the key is
// invalid according to the definition of [ecdh.Curve.NewPrivateKey], or if the
// Curve is not supported by [crypto/ecdh].
func (k *PrivateKey) ECDH() (*ecdh.PrivateKey, error) {
c := curveToECDH(k.Curve)
if c == nil {
return nil, errors.New("ecdsa: unsupported curve by crypto/ecdh")
}
size := (k.Curve.Params().N.BitLen() + 7) / 8
if k.D.BitLen() > size*8 {
return nil, errors.New("ecdsa: invalid private key")
}
return c.NewPrivateKey(k.D.FillBytes(make([]byte, size)))
}
func curveToECDH(c elliptic.Curve) ecdh.Curve {
switch c {
case elliptic.P256():
return ecdh.P256()
case elliptic.P384():
return ecdh.P384()
case elliptic.P521():
return ecdh.P521()
default:
return nil
}
}
// Public returns the public key corresponding to priv.
func (priv *PrivateKey) Public() crypto.PublicKey {
return &priv.PublicKey
}
// Equal reports whether priv and x have the same value.
//
// See [PublicKey.Equal] for details on how Curve is compared.
func (priv *PrivateKey) Equal(x crypto.PrivateKey) bool {
xx, ok := x.(*PrivateKey)
if !ok {
return false
}
return priv.PublicKey.Equal(&xx.PublicKey) && bigIntEqual(priv.D, xx.D)
}
// bigIntEqual reports whether a and b are equal leaking only their bit length
// through timing side-channels.
func bigIntEqual(a, b *big.Int) bool {
return subtle.ConstantTimeCompare(a.Bytes(), b.Bytes()) == 1
}
// ParseRawPrivateKey parses a private key encoded as a fixed-length big-endian
// integer, according to SEC 1, Version 2.0, Section 2.3.6 (sometimes referred
// to as the raw format). It returns an error if the value is not reduced modulo
// the curve's order, or if it's zero.
//
// curve must be one of [elliptic.P224], [elliptic.P256], [elliptic.P384], or
// [elliptic.P521], or ParseRawPrivateKey returns an error.
//
// ParseRawPrivateKey accepts the same format as [ecdh.Curve.NewPrivateKey] does
// for NIST curves, but returns a [PrivateKey] instead of an [ecdh.PrivateKey].
//
// Note that private keys are more commonly encoded in ASN.1 or PKCS#8 format,
// which can be parsed with [crypto/x509.ParseECPrivateKey] or
// [crypto/x509.ParsePKCS8PrivateKey] (and [encoding/pem]).
func ParseRawPrivateKey(curve elliptic.Curve, data []byte) (*PrivateKey, error) {
switch curve {
case elliptic.P224():
return parseRawPrivateKey(ecdsa.P224(), nistec.NewP224Point, curve, data)
case elliptic.P256():
return parseRawPrivateKey(ecdsa.P256(), nistec.NewP256Point, curve, data)
case elliptic.P384():
return parseRawPrivateKey(ecdsa.P384(), nistec.NewP384Point, curve, data)
case elliptic.P521():
return parseRawPrivateKey(ecdsa.P521(), nistec.NewP521Point, curve, data)
default:
return nil, errors.New("ecdsa: curve not supported by ParseRawPrivateKey")
}
}
func parseRawPrivateKey[P ecdsa.Point[P]](c *ecdsa.Curve[P], newPoint func() P, curve elliptic.Curve, data []byte) (*PrivateKey, error) {
q, err := newPoint().ScalarBaseMult(data)
if err != nil {
return nil, err
}
k, err := ecdsa.NewPrivateKey(c, data, q.Bytes())
if err != nil {
return nil, err
}
return privateKeyFromFIPS(curve, k)
}
// Bytes encodes the private key as a fixed-length big-endian integer according
// to SEC 1, Version 2.0, Section 2.3.6 (sometimes referred to as the raw
// format). It returns an error if the private key is invalid.
//
// PrivateKey.Curve must be one of [elliptic.P224], [elliptic.P256],
// [elliptic.P384], or [elliptic.P521], or Bytes returns an error.
//
// Bytes returns the same format as [ecdh.PrivateKey.Bytes] does for NIST curves.
//
// Note that private keys are more commonly encoded in ASN.1 or PKCS#8 format,
// which can be generated with [crypto/x509.MarshalECPrivateKey] or
// [crypto/x509.MarshalPKCS8PrivateKey] (and [encoding/pem]).
func (priv *PrivateKey) Bytes() ([]byte, error) {
switch priv.Curve {
case elliptic.P224():
return privateKeyBytes(ecdsa.P224(), priv)
case elliptic.P256():
return privateKeyBytes(ecdsa.P256(), priv)
case elliptic.P384():
return privateKeyBytes(ecdsa.P384(), priv)
case elliptic.P521():
return privateKeyBytes(ecdsa.P521(), priv)
default:
return nil, errors.New("ecdsa: curve not supported by PrivateKey.Bytes")
}
}
func privateKeyBytes[P ecdsa.Point[P]](c *ecdsa.Curve[P], priv *PrivateKey) ([]byte, error) {
k, err := privateKeyToFIPS(c, priv)
if err != nil {
return nil, err
}
return k.Bytes(), nil
}
// Sign signs a hash (which should be the result of hashing a larger message
// with opts.HashFunc()) using the private key, priv. If the hash is longer than
// the bit-length of the private key's curve order, the hash will be truncated
// to that length. It returns the ASN.1 encoded signature, like [SignASN1].
//
// If rand is not nil, the signature is randomized. Most applications should use
// [crypto/rand.Reader] as rand. Note that the returned signature does not
// depend deterministically on the bytes read from rand, and may change between
// calls and/or between versions.
//
// If rand is nil, Sign will produce a deterministic signature according to RFC
// 6979. When producing a deterministic signature, opts.HashFunc() must be the
// function used to produce digest and priv.Curve must be one of
// [elliptic.P224], [elliptic.P256], [elliptic.P384], or [elliptic.P521].
func (priv *PrivateKey) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, error) {
if rand == nil {
return signRFC6979(priv, digest, opts)
}
return SignASN1(rand, priv, digest)
}
// GenerateKey generates a new ECDSA private key for the specified curve.
//
// Most applications should use [crypto/rand.Reader] as rand. Note that the
// returned key does not depend deterministically on the bytes read from rand,
// and may change between calls and/or between versions.
func GenerateKey(c elliptic.Curve, rand io.Reader) (*PrivateKey, error) {
randutil.MaybeReadByte(rand)
if boring.Enabled && rand == boring.RandReader {
x, y, d, err := boring.GenerateKeyECDSA(c.Params().Name)
if err != nil {
return nil, err
}
return &PrivateKey{PublicKey: PublicKey{Curve: c, X: bbig.Dec(x), Y: bbig.Dec(y)}, D: bbig.Dec(d)}, nil
}
boring.UnreachableExceptTests()
switch c.Params() {
case elliptic.P224().Params():
return generateFIPS(c, ecdsa.P224(), rand)
case elliptic.P256().Params():
return generateFIPS(c, ecdsa.P256(), rand)
case elliptic.P384().Params():
return generateFIPS(c, ecdsa.P384(), rand)
case elliptic.P521().Params():
return generateFIPS(c, ecdsa.P521(), rand)
default:
return generateLegacy(c, rand)
}
}
func generateFIPS[P ecdsa.Point[P]](curve elliptic.Curve, c *ecdsa.Curve[P], rand io.Reader) (*PrivateKey, error) {
if fips140only.Enabled && !fips140only.ApprovedRandomReader(rand) {
return nil, errors.New("crypto/ecdsa: only crypto/rand.Reader is allowed in FIPS 140-only mode")
}
privateKey, err := ecdsa.GenerateKey(c, rand)
if err != nil {
return nil, err
}
return privateKeyFromFIPS(curve, privateKey)
}
// errNoAsm is returned by signAsm and verifyAsm when the assembly
// implementation is not available.
var errNoAsm = errors.New("no assembly implementation available")
// SignASN1 signs a hash (which should be the result of hashing a larger message)
// using the private key, priv. If the hash is longer than the bit-length of the
// private key's curve order, the hash will be truncated to that length. It
// returns the ASN.1 encoded signature.
//
// The signature is randomized. Most applications should use [crypto/rand.Reader]
// as rand. Note that the returned signature does not depend deterministically on
// the bytes read from rand, and may change between calls and/or between versions.
func SignASN1(rand io.Reader, priv *PrivateKey, hash []byte) ([]byte, error) {
randutil.MaybeReadByte(rand)
if boring.Enabled && rand == boring.RandReader {
b, err := boringPrivateKey(priv)
if err != nil {
return nil, err
}
return boring.SignMarshalECDSA(b, hash)
}
boring.UnreachableExceptTests()
switch priv.Curve.Params() {
case elliptic.P224().Params():
return signFIPS(ecdsa.P224(), priv, rand, hash)
case elliptic.P256().Params():
return signFIPS(ecdsa.P256(), priv, rand, hash)
case elliptic.P384().Params():
return signFIPS(ecdsa.P384(), priv, rand, hash)
case elliptic.P521().Params():
return signFIPS(ecdsa.P521(), priv, rand, hash)
default:
return signLegacy(priv, rand, hash)
}
}
func signFIPS[P ecdsa.Point[P]](c *ecdsa.Curve[P], priv *PrivateKey, rand io.Reader, hash []byte) ([]byte, error) {
if fips140only.Enabled && !fips140only.ApprovedRandomReader(rand) {
return nil, errors.New("crypto/ecdsa: only crypto/rand.Reader is allowed in FIPS 140-only mode")
}
k, err := privateKeyToFIPS(c, priv)
if err != nil {
return nil, err
}
// Always using SHA-512 instead of the hash function that computed hash is
// technically a violation of draft-irtf-cfrg-det-sigs-with-noise-04 but in
// our API we don't get to know what it was, and this has no security impact.
sig, err := ecdsa.Sign(c, sha512.New, k, rand, hash)
if err != nil {
return nil, err
}
return encodeSignature(sig.R, sig.S)
}
func signRFC6979(priv *PrivateKey, hash []byte, opts crypto.SignerOpts) ([]byte, error) {
if opts == nil {
return nil, errors.New("ecdsa: Sign called with nil opts")
}
h := opts.HashFunc()
if h.Size() != len(hash) {
return nil, errors.New("ecdsa: hash length does not match hash function")
}
switch priv.Curve.Params() {
case elliptic.P224().Params():
return signFIPSDeterministic(ecdsa.P224(), h, priv, hash)
case elliptic.P256().Params():
return signFIPSDeterministic(ecdsa.P256(), h, priv, hash)
case elliptic.P384().Params():
return signFIPSDeterministic(ecdsa.P384(), h, priv, hash)
case elliptic.P521().Params():
return signFIPSDeterministic(ecdsa.P521(), h, priv, hash)
default:
return nil, errors.New("ecdsa: curve not supported by deterministic signatures")
}
}
func signFIPSDeterministic[P ecdsa.Point[P]](c *ecdsa.Curve[P], hashFunc crypto.Hash, priv *PrivateKey, hash []byte) ([]byte, error) {
k, err := privateKeyToFIPS(c, priv)
if err != nil {
return nil, err
}
h := fips140hash.UnwrapNew(hashFunc.New)
if fips140only.Enabled && !fips140only.ApprovedHash(h()) {
return nil, errors.New("crypto/ecdsa: use of hash functions other than SHA-2 or SHA-3 is not allowed in FIPS 140-only mode")
}
sig, err := ecdsa.SignDeterministic(c, h, k, hash)
if err != nil {
return nil, err
}
return encodeSignature(sig.R, sig.S)
}
func encodeSignature(r, s []byte) ([]byte, error) {
var b cryptobyte.Builder
b.AddASN1(asn1.SEQUENCE, func(b *cryptobyte.Builder) {
addASN1IntBytes(b, r)
addASN1IntBytes(b, s)
})
return b.Bytes()
}
// addASN1IntBytes encodes in ASN.1 a positive integer represented as
// a big-endian byte slice with zero or more leading zeroes.
func addASN1IntBytes(b *cryptobyte.Builder, bytes []byte) {
for len(bytes) > 0 && bytes[0] == 0 {
bytes = bytes[1:]
}
if len(bytes) == 0 {
b.SetError(errors.New("invalid integer"))
return
}
b.AddASN1(asn1.INTEGER, func(c *cryptobyte.Builder) {
if bytes[0]&0x80 != 0 {
c.AddUint8(0)
}
c.AddBytes(bytes)
})
}
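// As a worked sketch of the resulting DER layout (hypothetical byte values,
// for illustration only):
//
//	r = 0x80 0x01 ... -> INTEGER 0x00 0x80 0x01 ... (0x00 prepended: high bit set)
//	s = 0x00 0x7f ... -> INTEGER 0x7f ...           (leading zero stripped)
//
// encodeSignature then wraps both INTEGERs in a single SEQUENCE.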
// VerifyASN1 verifies the ASN.1 encoded signature, sig, of hash using the
// public key, pub. Its return value records whether the signature is valid.
//
// The inputs are not considered confidential, and may leak through timing side
// channels, or if an attacker has control of part of the inputs.
func VerifyASN1(pub *PublicKey, hash, sig []byte) bool {
if boring.Enabled {
key, err := boringPublicKey(pub)
if err != nil {
return false
}
return boring.VerifyECDSA(key, hash, sig)
}
boring.UnreachableExceptTests()
switch pub.Curve.Params() {
case elliptic.P224().Params():
return verifyFIPS(ecdsa.P224(), pub, hash, sig)
case elliptic.P256().Params():
return verifyFIPS(ecdsa.P256(), pub, hash, sig)
case elliptic.P384().Params():
return verifyFIPS(ecdsa.P384(), pub, hash, sig)
case elliptic.P521().Params():
return verifyFIPS(ecdsa.P521(), pub, hash, sig)
default:
return verifyLegacy(pub, hash, sig)
}
}
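// A minimal end-to-end sketch (illustrative only; error handling is elided,
// the message is hypothetical, and crypto/elliptic, crypto/rand, and
// crypto/sha256 are assumed to be imported):
//
//	priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
//	if err != nil {
//		// handle error
//	}
//	digest := sha256.Sum256([]byte("hello"))
//	sig, err := ecdsa.SignASN1(rand.Reader, priv, digest[:])
//	if err != nil {
//		// handle error
//	}
//	valid := ecdsa.VerifyASN1(&priv.PublicKey, digest[:], sig)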
func verifyFIPS[P ecdsa.Point[P]](c *ecdsa.Curve[P], pub *PublicKey, hash, sig []byte) bool {
r, s, err := parseSignature(sig)
if err != nil {
return false
}
k, err := publicKeyToFIPS(c, pub)
if err != nil {
return false
}
if err := ecdsa.Verify(c, k, hash, &ecdsa.Signature{R: r, S: s}); err != nil {
return false
}
return true
}
func parseSignature(sig []byte) (r, s []byte, err error) {
var inner cryptobyte.String
input := cryptobyte.String(sig)
if !input.ReadASN1(&inner, asn1.SEQUENCE) ||
!input.Empty() ||
!inner.ReadASN1Integer(&r) ||
!inner.ReadASN1Integer(&s) ||
!inner.Empty() {
return nil, nil, errors.New("invalid ASN.1")
}
return r, s, nil
}
func publicKeyFromFIPS(curve elliptic.Curve, pub *ecdsa.PublicKey) (*PublicKey, error) {
x, y, err := pointToAffine(curve, pub.Bytes())
if err != nil {
return nil, err
}
return &PublicKey{Curve: curve, X: x, Y: y}, nil
}
func privateKeyFromFIPS(curve elliptic.Curve, priv *ecdsa.PrivateKey) (*PrivateKey, error) {
pub, err := publicKeyFromFIPS(curve, priv.PublicKey())
if err != nil {
return nil, err
}
return &PrivateKey{PublicKey: *pub, D: new(big.Int).SetBytes(priv.Bytes())}, nil
}
func publicKeyToFIPS[P ecdsa.Point[P]](c *ecdsa.Curve[P], pub *PublicKey) (*ecdsa.PublicKey, error) {
Q, err := pointFromAffine(pub.Curve, pub.X, pub.Y)
if err != nil {
return nil, err
}
return ecdsa.NewPublicKey(c, Q)
}
var privateKeyCache fips140cache.Cache[PrivateKey, ecdsa.PrivateKey]
func privateKeyToFIPS[P ecdsa.Point[P]](c *ecdsa.Curve[P], priv *PrivateKey) (*ecdsa.PrivateKey, error) {
Q, err := pointFromAffine(priv.Curve, priv.X, priv.Y)
if err != nil {
return nil, err
}
return privateKeyCache.Get(priv, func() (*ecdsa.PrivateKey, error) {
return ecdsa.NewPrivateKey(c, priv.D.Bytes(), Q)
}, func(k *ecdsa.PrivateKey) bool {
return subtle.ConstantTimeCompare(k.PublicKey().Bytes(), Q) == 1 &&
leftPadBytesEqual(k.Bytes(), priv.D.Bytes())
})
}
func leftPadBytesEqual(a, b []byte) bool {
if len(a) < len(b) {
a, b = b, a
}
if len(a) > len(b) {
x := make([]byte, 0, 66 /* enough for a P-521 private key */)
x = append(x, make([]byte, len(a)-len(b))...)
x = append(x, b...)
b = x
}
return subtle.ConstantTimeCompare(a, b) == 1
}
// pointFromAffine is used to convert the PublicKey to a nistec SetBytes input.
func pointFromAffine(curve elliptic.Curve, x, y *big.Int) ([]byte, error) {
bitSize := curve.Params().BitSize
// Reject values that would not get correctly encoded.
if x.Sign() < 0 || y.Sign() < 0 {
return nil, errors.New("negative coordinate")
}
if x.BitLen() > bitSize || y.BitLen() > bitSize {
return nil, errors.New("overflowing coordinate")
}
// Encode the coordinates and let SetBytes reject invalid points.
byteLen := (bitSize + 7) / 8
buf := make([]byte, 1+2*byteLen)
buf[0] = 4 // uncompressed point
x.FillBytes(buf[1 : 1+byteLen])
y.FillBytes(buf[1+byteLen : 1+2*byteLen])
return buf, nil
}
// pointToAffine is used to convert a nistec Bytes encoding to a PublicKey.
func pointToAffine(curve elliptic.Curve, p []byte) (x, y *big.Int, err error) {
if len(p) == 1 && p[0] == 0 {
// This is the encoding of the point at infinity.
return nil, nil, errors.New("ecdsa: public key point is the infinity")
}
byteLen := (curve.Params().BitSize + 7) / 8
x = new(big.Int).SetBytes(p[1 : 1+byteLen])
y = new(big.Int).SetBytes(p[1+byteLen:])
return x, y, nil
}
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ecdsa
import (
"crypto/elliptic"
"crypto/internal/fips140only"
"errors"
"io"
"math/big"
"math/rand/v2"
"golang.org/x/crypto/cryptobyte"
"golang.org/x/crypto/cryptobyte/asn1"
)
// This file contains a math/big implementation of ECDSA that is only used for
// deprecated custom curves.
func generateLegacy(c elliptic.Curve, rand io.Reader) (*PrivateKey, error) {
if fips140only.Enabled {
return nil, errors.New("crypto/ecdsa: use of custom curves is not allowed in FIPS 140-only mode")
}
k, err := randFieldElement(c, rand)
if err != nil {
return nil, err
}
priv := new(PrivateKey)
priv.PublicKey.Curve = c
priv.D = k
priv.PublicKey.X, priv.PublicKey.Y = c.ScalarBaseMult(k.Bytes())
return priv, nil
}
// hashToInt converts a hash value to an integer. Per FIPS 186-4, Section 6.4,
// we use the left-most bits of the hash to match the bit-length of the order of
// the curve. This also performs Step 5 of SEC 1, Version 2.0, Section 4.1.3.
func hashToInt(hash []byte, c elliptic.Curve) *big.Int {
orderBits := c.Params().N.BitLen()
orderBytes := (orderBits + 7) / 8
if len(hash) > orderBytes {
hash = hash[:orderBytes]
}
ret := new(big.Int).SetBytes(hash)
excess := len(hash)*8 - orderBits
if excess > 0 {
ret.Rsh(ret, uint(excess))
}
return ret
}
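// Worked example (for illustration): with P-521 the order N is 521 bits long,
// so orderBytes is 66. A 64-byte SHA-512 digest is shorter than that and is
// used as-is, while a longer digest would first be truncated to 66 bytes and
// then shifted right by 66*8-521 = 7 bits, keeping only the left-most 521 bits.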
var errZeroParam = errors.New("zero parameter")
// Sign signs a hash (which should be the result of hashing a larger message)
// using the private key, priv. If the hash is longer than the bit-length of the
// private key's curve order, the hash will be truncated to that length. It
// returns the signature as a pair of integers. Most applications should use
// [SignASN1] instead of dealing directly with r, s.
func Sign(rand io.Reader, priv *PrivateKey, hash []byte) (r, s *big.Int, err error) {
sig, err := SignASN1(rand, priv, hash)
if err != nil {
return nil, nil, err
}
r, s = new(big.Int), new(big.Int)
var inner cryptobyte.String
input := cryptobyte.String(sig)
if !input.ReadASN1(&inner, asn1.SEQUENCE) ||
!input.Empty() ||
!inner.ReadASN1Integer(r) ||
!inner.ReadASN1Integer(s) ||
!inner.Empty() {
return nil, nil, errors.New("invalid ASN.1 from SignASN1")
}
return r, s, nil
}
func signLegacy(priv *PrivateKey, csprng io.Reader, hash []byte) (sig []byte, err error) {
if fips140only.Enabled {
return nil, errors.New("crypto/ecdsa: use of custom curves is not allowed in FIPS 140-only mode")
}
c := priv.Curve
// A cheap version of hedged signatures, for the deprecated path.
var seed [32]byte
if _, err := io.ReadFull(csprng, seed[:]); err != nil {
return nil, err
}
for i, b := range priv.D.Bytes() {
seed[i%32] ^= b
}
for i, b := range hash {
seed[i%32] ^= b
}
csprng = rand.NewChaCha8(seed)
// SEC 1, Version 2.0, Section 4.1.3
N := c.Params().N
if N.Sign() == 0 {
return nil, errZeroParam
}
var k, kInv, r, s *big.Int
for {
for {
k, err = randFieldElement(c, csprng)
if err != nil {
return nil, err
}
kInv = new(big.Int).ModInverse(k, N)
r, _ = c.ScalarBaseMult(k.Bytes())
r.Mod(r, N)
if r.Sign() != 0 {
break
}
}
e := hashToInt(hash, c)
s = new(big.Int).Mul(priv.D, r)
s.Add(s, e)
s.Mul(s, kInv)
s.Mod(s, N) // N != 0
if s.Sign() != 0 {
break
}
}
return encodeSignature(r.Bytes(), s.Bytes())
}
// Verify verifies the signature in r, s of hash using the public key, pub. Its
// return value records whether the signature is valid. Most applications should
// use [VerifyASN1] instead of dealing directly with r, s.
//
// The inputs are not considered confidential, and may leak through timing side
// channels, or if an attacker has control of part of the inputs.
func Verify(pub *PublicKey, hash []byte, r, s *big.Int) bool {
if r.Sign() <= 0 || s.Sign() <= 0 {
return false
}
sig, err := encodeSignature(r.Bytes(), s.Bytes())
if err != nil {
return false
}
return VerifyASN1(pub, hash, sig)
}
func verifyLegacy(pub *PublicKey, hash []byte, sig []byte) bool {
if fips140only.Enabled {
panic("crypto/ecdsa: use of custom curves is not allowed in FIPS 140-only mode")
}
rBytes, sBytes, err := parseSignature(sig)
if err != nil {
return false
}
r, s := new(big.Int).SetBytes(rBytes), new(big.Int).SetBytes(sBytes)
c := pub.Curve
N := c.Params().N
if r.Sign() <= 0 || s.Sign() <= 0 {
return false
}
if r.Cmp(N) >= 0 || s.Cmp(N) >= 0 {
return false
}
// SEC 1, Version 2.0, Section 4.1.4
e := hashToInt(hash, c)
w := new(big.Int).ModInverse(s, N)
u1 := e.Mul(e, w)
u1.Mod(u1, N)
u2 := w.Mul(r, w)
u2.Mod(u2, N)
x1, y1 := c.ScalarBaseMult(u1.Bytes())
x2, y2 := c.ScalarMult(pub.X, pub.Y, u2.Bytes())
x, y := c.Add(x1, y1, x2, y2)
if x.Sign() == 0 && y.Sign() == 0 {
return false
}
x.Mod(x, N)
return x.Cmp(r) == 0
}
var one = new(big.Int).SetInt64(1)
// randFieldElement returns a random scalar in [1, N-1], where N is the order
// of the given curve, using the procedure given in FIPS 186-4, Appendix B.5.2.
func randFieldElement(c elliptic.Curve, rand io.Reader) (k *big.Int, err error) {
for {
N := c.Params().N
b := make([]byte, (N.BitLen()+7)/8)
if _, err = io.ReadFull(rand, b); err != nil {
return
}
if excess := len(b)*8 - N.BitLen(); excess > 0 {
b[0] >>= excess
}
k = new(big.Int).SetBytes(b)
if k.Sign() != 0 && k.Cmp(N) < 0 {
return
}
}
}
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !boringcrypto
package ecdsa
import "crypto/internal/boring"
func boringPublicKey(*PublicKey) (*boring.PublicKeyECDSA, error) {
panic("boringcrypto: not available")
}
func boringPrivateKey(*PrivateKey) (*boring.PrivateKeyECDSA, error) {
panic("boringcrypto: not available")
}
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package ed25519 implements the Ed25519 signature algorithm. See
// https://ed25519.cr.yp.to/.
//
// These functions are also compatible with the “Ed25519” function defined in
// RFC 8032. However, unlike RFC 8032's formulation, this package's private key
// representation includes a public key suffix to make multiple signing
// operations with the same key more efficient. This package refers to the RFC
// 8032 private key as the “seed”.
//
// Operations involving private keys are implemented using constant-time
// algorithms.
package ed25519
import (
"crypto"
"crypto/internal/fips140/ed25519"
"crypto/internal/fips140cache"
"crypto/internal/fips140only"
cryptorand "crypto/rand"
"crypto/subtle"
"errors"
"io"
"strconv"
)
const (
// PublicKeySize is the size, in bytes, of public keys as used in this package.
PublicKeySize = 32
// PrivateKeySize is the size, in bytes, of private keys as used in this package.
PrivateKeySize = 64
// SignatureSize is the size, in bytes, of signatures generated and verified by this package.
SignatureSize = 64
// SeedSize is the size, in bytes, of private key seeds. These are the private key representations used by RFC 8032.
SeedSize = 32
)
// PublicKey is the type of Ed25519 public keys.
type PublicKey []byte
// Any methods implemented on PublicKey might need to also be implemented on
// PrivateKey, as the latter embeds the former and will expose its methods.
// Equal reports whether pub and x have the same value.
func (pub PublicKey) Equal(x crypto.PublicKey) bool {
xx, ok := x.(PublicKey)
if !ok {
return false
}
return subtle.ConstantTimeCompare(pub, xx) == 1
}
// PrivateKey is the type of Ed25519 private keys. It implements [crypto.Signer].
type PrivateKey []byte
// Public returns the [PublicKey] corresponding to priv.
func (priv PrivateKey) Public() crypto.PublicKey {
publicKey := make([]byte, PublicKeySize)
copy(publicKey, priv[32:])
return PublicKey(publicKey)
}
// Equal reports whether priv and x have the same value.
func (priv PrivateKey) Equal(x crypto.PrivateKey) bool {
xx, ok := x.(PrivateKey)
if !ok {
return false
}
return subtle.ConstantTimeCompare(priv, xx) == 1
}
// Seed returns the private key seed corresponding to priv. It is provided for
// interoperability with RFC 8032. RFC 8032's private keys correspond to seeds
// in this package.
func (priv PrivateKey) Seed() []byte {
return append(make([]byte, 0, SeedSize), priv[:SeedSize]...)
}
// privateKeyCache uses a pointer to the first byte of underlying storage as a
// key, because [PrivateKey] is a slice header passed around by value.
var privateKeyCache fips140cache.Cache[byte, ed25519.PrivateKey]
// Sign signs the given message with priv. rand is ignored and can be nil.
//
// If opts.HashFunc() is [crypto.SHA512], the pre-hashed variant Ed25519ph is used
// and message is expected to be a SHA-512 hash, otherwise opts.HashFunc() must
// be [crypto.Hash](0) and the message must not be hashed, as Ed25519 performs two
// passes over messages to be signed.
//
// A value of type [Options] can be used as opts, or crypto.Hash(0) or
// crypto.SHA512 directly to select plain Ed25519 or Ed25519ph, respectively.
func (priv PrivateKey) Sign(rand io.Reader, message []byte, opts crypto.SignerOpts) (signature []byte, err error) {
k, err := privateKeyCache.Get(&priv[0], func() (*ed25519.PrivateKey, error) {
return ed25519.NewPrivateKey(priv)
}, func(k *ed25519.PrivateKey) bool {
return subtle.ConstantTimeCompare(priv, k.Bytes()) == 1
})
if err != nil {
return nil, err
}
hash := opts.HashFunc()
context := ""
if opts, ok := opts.(*Options); ok {
context = opts.Context
}
switch {
case hash == crypto.SHA512: // Ed25519ph
return ed25519.SignPH(k, message, context)
case hash == crypto.Hash(0) && context != "": // Ed25519ctx
if fips140only.Enabled {
return nil, errors.New("crypto/ed25519: use of Ed25519ctx is not allowed in FIPS 140-only mode")
}
return ed25519.SignCtx(k, message, context)
case hash == crypto.Hash(0): // Ed25519
return ed25519.Sign(k, message), nil
default:
return nil, errors.New("ed25519: expected opts.HashFunc() zero (unhashed message, for standard Ed25519) or SHA-512 (for Ed25519ph)")
}
}
// Options can be used with [PrivateKey.Sign] or [VerifyWithOptions]
// to select Ed25519 variants.
type Options struct {
// Hash can be zero for regular Ed25519, or crypto.SHA512 for Ed25519ph.
Hash crypto.Hash
// Context, if not empty, selects Ed25519ctx or provides the context string
// for Ed25519ph. It can be at most 255 bytes in length.
Context string
}
// HashFunc returns o.Hash.
func (o *Options) HashFunc() crypto.Hash { return o.Hash }
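// A minimal Ed25519ph sketch (illustrative only; pub, priv, and message are
// assumed to be provided by the caller, and crypto/sha512 is imported):
//
//	digest := sha512.Sum512(message)
//	sig, err := priv.Sign(nil, digest[:], &Options{Hash: crypto.SHA512})
//	if err != nil {
//		// handle error
//	}
//	err = VerifyWithOptions(pub, digest[:], sig, &Options{Hash: crypto.SHA512})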
// GenerateKey generates a public/private key pair using entropy from rand.
// If rand is nil, [crypto/rand.Reader] will be used.
//
// The output of this function is deterministic, and equivalent to reading
// [SeedSize] bytes from rand, and passing them to [NewKeyFromSeed].
func GenerateKey(rand io.Reader) (PublicKey, PrivateKey, error) {
if rand == nil {
rand = cryptorand.Reader
}
seed := make([]byte, SeedSize)
if _, err := io.ReadFull(rand, seed); err != nil {
return nil, nil, err
}
privateKey := NewKeyFromSeed(seed)
publicKey := privateKey.Public().(PublicKey)
return publicKey, privateKey, nil
}
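// A minimal usage sketch (illustrative only; error handling is elided):
//
//	pub, priv, err := GenerateKey(nil) // nil selects crypto/rand.Reader
//	if err != nil {
//		// handle error
//	}
//	msg := []byte("example message")
//	sig := Sign(priv, msg)
//	ok := Verify(pub, msg, sig) // true for a matching key, message, and signature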
// NewKeyFromSeed calculates a private key from a seed. It will panic if
// len(seed) is not [SeedSize]. This function is provided for interoperability
// with RFC 8032. RFC 8032's private keys correspond to seeds in this
// package.
func NewKeyFromSeed(seed []byte) PrivateKey {
// Outline the function body so that the returned key can be stack-allocated.
privateKey := make([]byte, PrivateKeySize)
newKeyFromSeed(privateKey, seed)
return privateKey
}
func newKeyFromSeed(privateKey, seed []byte) {
k, err := ed25519.NewPrivateKeyFromSeed(seed)
if err != nil {
// NewPrivateKeyFromSeed only returns an error if the seed length is incorrect.
panic("ed25519: bad seed length: " + strconv.Itoa(len(seed)))
}
copy(privateKey, k.Bytes())
}
// Sign signs the message with privateKey and returns a signature. It will
// panic if len(privateKey) is not [PrivateKeySize].
func Sign(privateKey PrivateKey, message []byte) []byte {
// Outline the function body so that the returned signature can be
// stack-allocated.
signature := make([]byte, SignatureSize)
sign(signature, privateKey, message)
return signature
}
func sign(signature []byte, privateKey PrivateKey, message []byte) {
k, err := privateKeyCache.Get(&privateKey[0], func() (*ed25519.PrivateKey, error) {
return ed25519.NewPrivateKey(privateKey)
}, func(k *ed25519.PrivateKey) bool {
return subtle.ConstantTimeCompare(privateKey, k.Bytes()) == 1
})
if err != nil {
panic("ed25519: bad private key: " + err.Error())
}
sig := ed25519.Sign(k, message)
copy(signature, sig)
}
// Verify reports whether sig is a valid signature of message by publicKey. It
// will panic if len(publicKey) is not [PublicKeySize].
//
// The inputs are not considered confidential, and may leak through timing side
// channels, or if an attacker has control of part of the inputs.
func Verify(publicKey PublicKey, message, sig []byte) bool {
return VerifyWithOptions(publicKey, message, sig, &Options{Hash: crypto.Hash(0)}) == nil
}
// VerifyWithOptions reports whether sig is a valid signature of message by
// publicKey. A valid signature is indicated by returning a nil error. It will
// panic if len(publicKey) is not [PublicKeySize].
//
// If opts.Hash is [crypto.SHA512], the pre-hashed variant Ed25519ph is used and
// message is expected to be a SHA-512 hash, otherwise opts.Hash must be
// [crypto.Hash](0) and the message must not be hashed, as Ed25519 performs two
// passes over messages to be signed.
//
// The inputs are not considered confidential, and may leak through timing side
// channels, or if an attacker has control of part of the inputs.
func VerifyWithOptions(publicKey PublicKey, message, sig []byte, opts *Options) error {
if l := len(publicKey); l != PublicKeySize {
panic("ed25519: bad public key length: " + strconv.Itoa(l))
}
k, err := ed25519.NewPublicKey(publicKey)
if err != nil {
return err
}
switch {
case opts.Hash == crypto.SHA512: // Ed25519ph
return ed25519.VerifyPH(k, message, sig, opts.Context)
case opts.Hash == crypto.Hash(0) && opts.Context != "": // Ed25519ctx
if fips140only.Enabled {
return errors.New("crypto/ed25519: use of Ed25519ctx is not allowed in FIPS 140-only mode")
}
return ed25519.VerifyCtx(k, message, sig, opts.Context)
case opts.Hash == crypto.Hash(0): // Ed25519
return ed25519.Verify(k, message, sig)
default:
return errors.New("ed25519: expected opts.Hash zero (unhashed message, for standard Ed25519) or SHA-512 (for Ed25519ph)")
}
}
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package elliptic implements the standard NIST P-224, P-256, P-384, and P-521
// elliptic curves over prime fields.
//
// Direct use of this package is deprecated, beyond the [P224], [P256], [P384],
// and [P521] values necessary to use [crypto/ecdsa]. Most other uses
// should migrate to the more efficient and safer [crypto/ecdh], or to
// third-party modules for lower-level functionality.
package elliptic
import (
"io"
"math/big"
"sync"
)
// A Curve represents a short-form Weierstrass curve with a=-3.
//
// The behavior of Add, Double, and ScalarMult when the input is not a point on
// the curve is undefined.
//
// Note that the conventional point at infinity (0, 0) is not considered on the
// curve, although it can be returned by Add, Double, ScalarMult, or
// ScalarBaseMult (but not the [Unmarshal] or [UnmarshalCompressed] functions).
//
// Using Curve implementations besides those returned by [P224], [P256], [P384],
// and [P521] is deprecated.
type Curve interface {
// Params returns the parameters for the curve.
Params() *CurveParams
// IsOnCurve reports whether the given (x,y) lies on the curve.
//
// Deprecated: this is a low-level unsafe API. For ECDH, use the crypto/ecdh
// package. The NewPublicKey methods of NIST curves in crypto/ecdh accept
// the same encoding as the Unmarshal function, and perform on-curve checks.
IsOnCurve(x, y *big.Int) bool
// Add returns the sum of (x1,y1) and (x2,y2).
//
// Deprecated: this is a low-level unsafe API.
Add(x1, y1, x2, y2 *big.Int) (x, y *big.Int)
// Double returns 2*(x,y).
//
// Deprecated: this is a low-level unsafe API.
Double(x1, y1 *big.Int) (x, y *big.Int)
// ScalarMult returns k*(x,y) where k is an integer in big-endian form.
//
// Deprecated: this is a low-level unsafe API. For ECDH, use the crypto/ecdh
// package. Most uses of ScalarMult can be replaced by a call to the ECDH
// methods of NIST curves in crypto/ecdh.
ScalarMult(x1, y1 *big.Int, k []byte) (x, y *big.Int)
// ScalarBaseMult returns k*G, where G is the base point of the group
// and k is an integer in big-endian form.
//
// Deprecated: this is a low-level unsafe API. For ECDH, use the crypto/ecdh
// package. Most uses of ScalarBaseMult can be replaced by a call to the
// PrivateKey.PublicKey method in crypto/ecdh.
ScalarBaseMult(k []byte) (x, y *big.Int)
}
var mask = []byte{0xff, 0x1, 0x3, 0x7, 0xf, 0x1f, 0x3f, 0x7f}
// GenerateKey returns a public/private key pair. The private key is
// generated using the given reader, which must return random data.
//
// Deprecated: for ECDH, use the GenerateKey methods of the [crypto/ecdh] package;
// for ECDSA, use the GenerateKey function of the crypto/ecdsa package.
func GenerateKey(curve Curve, rand io.Reader) (priv []byte, x, y *big.Int, err error) {
N := curve.Params().N
bitSize := N.BitLen()
byteLen := (bitSize + 7) / 8
priv = make([]byte, byteLen)
for x == nil {
_, err = io.ReadFull(rand, priv)
if err != nil {
return
}
// We have to mask off any excess bits in the case that the size of the
// underlying field is not a whole number of bytes.
priv[0] &= mask[bitSize%8]
// This is because, in tests, rand will return all zeros and we don't
// want to get the point at infinity and loop forever.
priv[1] ^= 0x42
// If the scalar is out of range, sample another random number.
if new(big.Int).SetBytes(priv).Cmp(N) >= 0 {
continue
}
x, y = curve.ScalarBaseMult(priv)
}
return
}
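// As the deprecation note above suggests, new ECDH code should use the
// crypto/ecdh package instead. A minimal sketch of the replacement
// (illustrative only; error handling is elided):
//
//	key, err := ecdh.P256().GenerateKey(rand.Reader)
//	if err != nil {
//		// handle error
//	}
//	pubBytes := key.PublicKey().Bytes() // uncompressed SEC 1 encoding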
// Marshal converts a point on the curve into the uncompressed form specified in
// SEC 1, Version 2.0, Section 2.3.3. If the point is not on the curve (or is
// the conventional point at infinity), the behavior is undefined.
//
// Deprecated: for ECDH, use the crypto/ecdh package. This function returns an
// encoding equivalent to that of PublicKey.Bytes in crypto/ecdh.
func Marshal(curve Curve, x, y *big.Int) []byte {
panicIfNotOnCurve(curve, x, y)
byteLen := (curve.Params().BitSize + 7) / 8
ret := make([]byte, 1+2*byteLen)
ret[0] = 4 // uncompressed point
x.FillBytes(ret[1 : 1+byteLen])
y.FillBytes(ret[1+byteLen : 1+2*byteLen])
return ret
}
// MarshalCompressed converts a point on the curve into the compressed form
// specified in SEC 1, Version 2.0, Section 2.3.3. If the point is not on the
// curve (or is the conventional point at infinity), the behavior is undefined.
func MarshalCompressed(curve Curve, x, y *big.Int) []byte {
panicIfNotOnCurve(curve, x, y)
byteLen := (curve.Params().BitSize + 7) / 8
compressed := make([]byte, 1+byteLen)
compressed[0] = byte(y.Bit(0)) | 2
x.FillBytes(compressed[1:])
return compressed
}
// unmarshaler is implemented by curves with their own constant-time Unmarshal.
//
// There isn't an equivalent interface for Marshal/MarshalCompressed because
// that doesn't involve any mathematical operations, only FillBytes and Bit.
type unmarshaler interface {
Unmarshal([]byte) (x, y *big.Int)
UnmarshalCompressed([]byte) (x, y *big.Int)
}
// Assert that the known curves implement unmarshaler.
var _ = []unmarshaler{p224, p256, p384, p521}
// Unmarshal converts a point, serialized by [Marshal], into an x, y pair. It is
// an error if the point is not in uncompressed form, is not on the curve, or is
// the point at infinity. On error, x = nil.
//
// Deprecated: for ECDH, use the crypto/ecdh package. This function accepts an
// encoding equivalent to that of the NewPublicKey methods in crypto/ecdh.
func Unmarshal(curve Curve, data []byte) (x, y *big.Int) {
if c, ok := curve.(unmarshaler); ok {
return c.Unmarshal(data)
}
byteLen := (curve.Params().BitSize + 7) / 8
if len(data) != 1+2*byteLen {
return nil, nil
}
if data[0] != 4 { // uncompressed form
return nil, nil
}
p := curve.Params().P
x = new(big.Int).SetBytes(data[1 : 1+byteLen])
y = new(big.Int).SetBytes(data[1+byteLen:])
if x.Cmp(p) >= 0 || y.Cmp(p) >= 0 {
return nil, nil
}
if !curve.IsOnCurve(x, y) {
return nil, nil
}
return
}
// UnmarshalCompressed converts a point, serialized by [MarshalCompressed], into
// an x, y pair. It is an error if the point is not in compressed form, is not
// on the curve, or is the point at infinity. On error, x = nil.
func UnmarshalCompressed(curve Curve, data []byte) (x, y *big.Int) {
if c, ok := curve.(unmarshaler); ok {
return c.UnmarshalCompressed(data)
}
byteLen := (curve.Params().BitSize + 7) / 8
if len(data) != 1+byteLen {
return nil, nil
}
if data[0] != 2 && data[0] != 3 { // compressed form
return nil, nil
}
p := curve.Params().P
x = new(big.Int).SetBytes(data[1:])
if x.Cmp(p) >= 0 {
return nil, nil
}
// y² = x³ - 3x + b
y = curve.Params().polynomial(x)
y = y.ModSqrt(y, p)
if y == nil {
return nil, nil
}
if byte(y.Bit(0)) != data[0]&1 {
y.Neg(y).Mod(y, p)
}
if !curve.IsOnCurve(x, y) {
return nil, nil
}
return
}
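// A minimal round-trip sketch using the P-256 base point (illustrative only):
//
//	curve := P256()
//	gx, gy := curve.Params().Gx, curve.Params().Gy
//	enc := MarshalCompressed(curve, gx, gy) // 33 bytes: 0x02 or 0x03 prefix plus X
//	x, y := UnmarshalCompressed(curve, enc) // recovers gx, gy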
func panicIfNotOnCurve(curve Curve, x, y *big.Int) {
// (0, 0) is the point at infinity by convention. It's ok to operate on it,
// although IsOnCurve is documented to return false for it. See Issue 37294.
if x.Sign() == 0 && y.Sign() == 0 {
return
}
if !curve.IsOnCurve(x, y) {
panic("crypto/elliptic: attempted operation on invalid point")
}
}
var initonce sync.Once
func initAll() {
initP224()
initP256()
initP384()
initP521()
}
// P224 returns a [Curve] which implements NIST P-224 (FIPS 186-3, section D.2.2),
// also known as secp224r1. The CurveParams.Name of this [Curve] is "P-224".
//
// Multiple invocations of this function will return the same value, so it can
// be used for equality checks and switch statements.
//
// The cryptographic operations are implemented using constant-time algorithms.
func P224() Curve {
initonce.Do(initAll)
return p224
}
// P256 returns a [Curve] which implements NIST P-256 (FIPS 186-3, section D.2.3),
// also known as secp256r1 or prime256v1. The CurveParams.Name of this [Curve] is
// "P-256".
//
// Multiple invocations of this function will return the same value, so it can
// be used for equality checks and switch statements.
//
// The cryptographic operations are implemented using constant-time algorithms.
func P256() Curve {
initonce.Do(initAll)
return p256
}
// P384 returns a [Curve] which implements NIST P-384 (FIPS 186-3, section D.2.4),
// also known as secp384r1. The CurveParams.Name of this [Curve] is "P-384".
//
// Multiple invocations of this function will return the same value, so it can
// be used for equality checks and switch statements.
//
// The cryptographic operations are implemented using constant-time algorithms.
func P384() Curve {
initonce.Do(initAll)
return p384
}
// P521 returns a [Curve] which implements NIST P-521 (FIPS 186-3, section D.2.5),
// also known as secp521r1. The CurveParams.Name of this [Curve] is "P-521".
//
// Multiple invocations of this function will return the same value, so it can
// be used for equality checks and switch statements.
//
// The cryptographic operations are implemented using constant-time algorithms.
func P521() Curve {
initonce.Do(initAll)
return p521
}
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package elliptic
import (
"crypto/internal/fips140/nistec"
"errors"
"math/big"
)
var p224 = &nistCurve[*nistec.P224Point]{
newPoint: nistec.NewP224Point,
}
func initP224() {
p224.params = &CurveParams{
Name: "P-224",
BitSize: 224,
// SP 800-186, Section 3.2.1.2
P: bigFromDecimal("26959946667150639794667015087019630673557916260026308143510066298881"),
N: bigFromDecimal("26959946667150639794667015087019625940457807714424391721682722368061"),
B: bigFromHex("b4050a850c04b3abf54132565044b0b7d7bfd8ba270b39432355ffb4"),
Gx: bigFromHex("b70e0cbd6bb4bf7f321390b94a03c1d356c21122343280d6115c1d21"),
Gy: bigFromHex("bd376388b5f723fb4c22dfe6cd4375a05a07476444d5819985007e34"),
}
}
var p256 = &nistCurve[*nistec.P256Point]{
newPoint: nistec.NewP256Point,
}
func initP256() {
p256.params = &CurveParams{
Name: "P-256",
BitSize: 256,
// SP 800-186, Section 3.2.1.3
P: bigFromDecimal("115792089210356248762697446949407573530086143415290314195533631308867097853951"),
N: bigFromDecimal("115792089210356248762697446949407573529996955224135760342422259061068512044369"),
B: bigFromHex("5ac635d8aa3a93e7b3ebbd55769886bc651d06b0cc53b0f63bce3c3e27d2604b"),
Gx: bigFromHex("6b17d1f2e12c4247f8bce6e563a440f277037d812deb33a0f4a13945d898c296"),
Gy: bigFromHex("4fe342e2fe1a7f9b8ee7eb4a7c0f9e162bce33576b315ececbb6406837bf51f5"),
}
}
var p384 = &nistCurve[*nistec.P384Point]{
newPoint: nistec.NewP384Point,
}
func initP384() {
p384.params = &CurveParams{
Name: "P-384",
BitSize: 384,
// SP 800-186, Section 3.2.1.4
P: bigFromDecimal("394020061963944792122790401001436138050797392704654" +
"46667948293404245721771496870329047266088258938001861606973112319"),
N: bigFromDecimal("394020061963944792122790401001436138050797392704654" +
"46667946905279627659399113263569398956308152294913554433653942643"),
B: bigFromHex("b3312fa7e23ee7e4988e056be3f82d19181d9c6efe8141120314088" +
"f5013875ac656398d8a2ed19d2a85c8edd3ec2aef"),
Gx: bigFromHex("aa87ca22be8b05378eb1c71ef320ad746e1d3b628ba79b9859f741" +
"e082542a385502f25dbf55296c3a545e3872760ab7"),
Gy: bigFromHex("3617de4a96262c6f5d9e98bf9292dc29f8f41dbd289a147ce9da31" +
"13b5f0b8c00a60b1ce1d7e819d7a431d7c90ea0e5f"),
}
}
var p521 = &nistCurve[*nistec.P521Point]{
newPoint: nistec.NewP521Point,
}
func initP521() {
p521.params = &CurveParams{
Name: "P-521",
BitSize: 521,
// SP 800-186, Section 3.2.1.5
P: bigFromDecimal("68647976601306097149819007990813932172694353001433" +
"0540939446345918554318339765605212255964066145455497729631139148" +
"0858037121987999716643812574028291115057151"),
N: bigFromDecimal("68647976601306097149819007990813932172694353001433" +
"0540939446345918554318339765539424505774633321719753296399637136" +
"3321113864768612440380340372808892707005449"),
B: bigFromHex("0051953eb9618e1c9a1f929a21a0b68540eea2da725b99b315f3b8" +
"b489918ef109e156193951ec7e937b1652c0bd3bb1bf073573df883d2c34f1ef" +
"451fd46b503f00"),
Gx: bigFromHex("00c6858e06b70404e9cd9e3ecb662395b4429c648139053fb521f8" +
"28af606b4d3dbaa14b5e77efe75928fe1dc127a2ffa8de3348b3c1856a429bf9" +
"7e7e31c2e5bd66"),
Gy: bigFromHex("011839296a789a3bc0045c8a5fb42c7d1bd998f54449579b446817" +
"afbd17273e662c97ee72995ef42640c550b9013fad0761353c7086a272c24088" +
"be94769fd16650"),
}
}
// nistCurve is a Curve implementation based on a nistec Point.
//
// It's a wrapper that exposes the big.Int-based Curve interface and encodes the
// legacy idiosyncrasies it requires, such as invalid and infinity point
// handling.
//
// To interact with the nistec package, points are encoded into and decoded from
// properly formatted byte slices. All big.Int use is limited to this package.
// Encoding and decoding is 1/1000th of the runtime of a scalar multiplication,
// so the overhead is acceptable.
type nistCurve[Point nistPoint[Point]] struct {
newPoint func() Point
params *CurveParams
}
// nistPoint is a generic constraint for the nistec Point types.
type nistPoint[T any] interface {
Bytes() []byte
SetBytes([]byte) (T, error)
Add(T, T) T
Double(T) T
ScalarMult(T, []byte) (T, error)
ScalarBaseMult([]byte) (T, error)
}
func (curve *nistCurve[Point]) Params() *CurveParams {
return curve.params
}
func (curve *nistCurve[Point]) IsOnCurve(x, y *big.Int) bool {
// IsOnCurve is documented to reject (0, 0), the conventional point at
// infinity, which however is accepted by pointFromAffine.
if x.Sign() == 0 && y.Sign() == 0 {
return false
}
_, err := curve.pointFromAffine(x, y)
return err == nil
}
func (curve *nistCurve[Point]) pointFromAffine(x, y *big.Int) (p Point, err error) {
// (0, 0) is by convention the point at infinity, which can't be represented
// in affine coordinates. See Issue 37294.
if x.Sign() == 0 && y.Sign() == 0 {
return curve.newPoint(), nil
}
// Reject values that would not get correctly encoded.
if x.Sign() < 0 || y.Sign() < 0 {
return p, errors.New("negative coordinate")
}
if x.BitLen() > curve.params.BitSize || y.BitLen() > curve.params.BitSize {
return p, errors.New("overflowing coordinate")
}
// Encode the coordinates and let SetBytes reject invalid points.
byteLen := (curve.params.BitSize + 7) / 8
buf := make([]byte, 1+2*byteLen)
buf[0] = 4 // uncompressed point
x.FillBytes(buf[1 : 1+byteLen])
y.FillBytes(buf[1+byteLen : 1+2*byteLen])
return curve.newPoint().SetBytes(buf)
}
func (curve *nistCurve[Point]) pointToAffine(p Point) (x, y *big.Int) {
out := p.Bytes()
if len(out) == 1 && out[0] == 0 {
// This is the encoding of the point at infinity, which the affine
// coordinates API represents as (0, 0) by convention.
return new(big.Int), new(big.Int)
}
byteLen := (curve.params.BitSize + 7) / 8
x = new(big.Int).SetBytes(out[1 : 1+byteLen])
y = new(big.Int).SetBytes(out[1+byteLen:])
return x, y
}
func (curve *nistCurve[Point]) Add(x1, y1, x2, y2 *big.Int) (*big.Int, *big.Int) {
p1, err := curve.pointFromAffine(x1, y1)
if err != nil {
panic("crypto/elliptic: Add was called on an invalid point")
}
p2, err := curve.pointFromAffine(x2, y2)
if err != nil {
panic("crypto/elliptic: Add was called on an invalid point")
}
return curve.pointToAffine(p1.Add(p1, p2))
}
func (curve *nistCurve[Point]) Double(x1, y1 *big.Int) (*big.Int, *big.Int) {
p, err := curve.pointFromAffine(x1, y1)
if err != nil {
panic("crypto/elliptic: Double was called on an invalid point")
}
return curve.pointToAffine(p.Double(p))
}
// normalizeScalar brings the scalar within the byte size of the order of the
// curve, as expected by the nistec scalar multiplication functions.
func (curve *nistCurve[Point]) normalizeScalar(scalar []byte) []byte {
byteSize := (curve.params.N.BitLen() + 7) / 8
if len(scalar) == byteSize {
return scalar
}
s := new(big.Int).SetBytes(scalar)
if len(scalar) > byteSize {
s.Mod(s, curve.params.N)
}
out := make([]byte, byteSize)
return s.FillBytes(out)
}
func (curve *nistCurve[Point]) ScalarMult(Bx, By *big.Int, scalar []byte) (*big.Int, *big.Int) {
p, err := curve.pointFromAffine(Bx, By)
if err != nil {
panic("crypto/elliptic: ScalarMult was called on an invalid point")
}
scalar = curve.normalizeScalar(scalar)
p, err = p.ScalarMult(p, scalar)
if err != nil {
panic("crypto/elliptic: nistec rejected normalized scalar")
}
return curve.pointToAffine(p)
}
func (curve *nistCurve[Point]) ScalarBaseMult(scalar []byte) (*big.Int, *big.Int) {
scalar = curve.normalizeScalar(scalar)
p, err := curve.newPoint().ScalarBaseMult(scalar)
if err != nil {
panic("crypto/elliptic: nistec rejected normalized scalar")
}
return curve.pointToAffine(p)
}
func (curve *nistCurve[Point]) Unmarshal(data []byte) (x, y *big.Int) {
if len(data) == 0 || data[0] != 4 {
return nil, nil
}
// Use SetBytes to check that data encodes a valid point.
_, err := curve.newPoint().SetBytes(data)
if err != nil {
return nil, nil
}
// We don't use pointToAffine because it involves an expensive field
// inversion to convert from Jacobian to affine coordinates, which we
// already have.
byteLen := (curve.params.BitSize + 7) / 8
x = new(big.Int).SetBytes(data[1 : 1+byteLen])
y = new(big.Int).SetBytes(data[1+byteLen:])
return x, y
}
func (curve *nistCurve[Point]) UnmarshalCompressed(data []byte) (x, y *big.Int) {
if len(data) == 0 || (data[0] != 2 && data[0] != 3) {
return nil, nil
}
p, err := curve.newPoint().SetBytes(data)
if err != nil {
return nil, nil
}
return curve.pointToAffine(p)
}
func bigFromDecimal(s string) *big.Int {
b, ok := new(big.Int).SetString(s, 10)
if !ok {
panic("crypto/elliptic: internal error: invalid encoding")
}
return b
}
func bigFromHex(s string) *big.Int {
b, ok := new(big.Int).SetString(s, 16)
if !ok {
panic("crypto/elliptic: internal error: invalid encoding")
}
return b
}
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package elliptic
import "math/big"
// CurveParams contains the parameters of an elliptic curve and also provides
// a generic, non-constant time implementation of [Curve].
//
// The generic Curve implementation is deprecated, and using custom curves
// (those not returned by [P224], [P256], [P384], and [P521]) is not guaranteed
// to provide any security property.
type CurveParams struct {
P *big.Int // the order of the underlying field
N *big.Int // the order of the base point
B *big.Int // the constant of the curve equation
Gx, Gy *big.Int // (x,y) of the base point
BitSize int // the size of the underlying field
Name string // the canonical name of the curve
}
func (curve *CurveParams) Params() *CurveParams {
return curve
}
// CurveParams operates, internally, on Jacobian coordinates. For a given
// (x, y) position on the curve, the Jacobian coordinates are (x1, y1, z1)
// where x = x1/z1² and y = y1/z1³. The greatest speedups come when the whole
// calculation can be performed within the transform (as in ScalarMult and
// ScalarBaseMult). But even for Add and Double, it's faster to apply and
// reverse the transform than to operate in affine coordinates.
// polynomial returns x³ - 3x + b.
func (curve *CurveParams) polynomial(x *big.Int) *big.Int {
x3 := new(big.Int).Mul(x, x)
x3.Mul(x3, x)
threeX := new(big.Int).Lsh(x, 1)
threeX.Add(threeX, x)
x3.Sub(x3, threeX)
x3.Add(x3, curve.B)
x3.Mod(x3, curve.P)
return x3
}
// IsOnCurve implements [Curve.IsOnCurve].
//
// Deprecated: the [CurveParams] methods are deprecated and are not guaranteed to
// provide any security property. For ECDH, use the [crypto/ecdh] package.
// For ECDSA, use the [crypto/ecdsa] package with a [Curve] value returned directly
// from [P224], [P256], [P384], or [P521].
func (curve *CurveParams) IsOnCurve(x, y *big.Int) bool {
// If there is a dedicated constant-time implementation for this curve operation,
// use that instead of the generic one.
if specific, ok := matchesSpecificCurve(curve); ok {
return specific.IsOnCurve(x, y)
}
if x.Sign() < 0 || x.Cmp(curve.P) >= 0 ||
y.Sign() < 0 || y.Cmp(curve.P) >= 0 {
return false
}
// y² = x³ - 3x + b
y2 := new(big.Int).Mul(y, y)
y2.Mod(y2, curve.P)
return curve.polynomial(x).Cmp(y2) == 0
}
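// A small sanity-check sketch (illustrative only): the base point of any of
// the named curves satisfies the curve equation, so IsOnCurve reports true
// for it.
//
//	params := P256().Params()
//	onCurve := params.IsOnCurve(params.Gx, params.Gy) // true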
// zForAffine returns a Jacobian Z value for the affine point (x, y). If x and
// y are zero, it assumes that they represent the point at infinity because (0,
// 0) is not on any of the curves handled here.
func zForAffine(x, y *big.Int) *big.Int {
z := new(big.Int)
if x.Sign() != 0 || y.Sign() != 0 {
z.SetInt64(1)
}
return z
}
// affineFromJacobian reverses the Jacobian transform. See the comment at the
// top of the file. If the point is ∞ it returns 0, 0.
func (curve *CurveParams) affineFromJacobian(x, y, z *big.Int) (xOut, yOut *big.Int) {
if z.Sign() == 0 {
return new(big.Int), new(big.Int)
}
zinv := new(big.Int).ModInverse(z, curve.P)
zinvsq := new(big.Int).Mul(zinv, zinv)
xOut = new(big.Int).Mul(x, zinvsq)
xOut.Mod(xOut, curve.P)
zinvsq.Mul(zinvsq, zinv)
yOut = new(big.Int).Mul(y, zinvsq)
yOut.Mod(yOut, curve.P)
return
}
// Add implements [Curve.Add].
//
// Deprecated: the [CurveParams] methods are deprecated and are not guaranteed to
// provide any security property. For ECDH, use the [crypto/ecdh] package.
// For ECDSA, use the [crypto/ecdsa] package with a [Curve] value returned directly
// from [P224], [P256], [P384], or [P521].
func (curve *CurveParams) Add(x1, y1, x2, y2 *big.Int) (*big.Int, *big.Int) {
// If there is a dedicated constant-time implementation for this curve operation,
// use that instead of the generic one.
if specific, ok := matchesSpecificCurve(curve); ok {
return specific.Add(x1, y1, x2, y2)
}
panicIfNotOnCurve(curve, x1, y1)
panicIfNotOnCurve(curve, x2, y2)
z1 := zForAffine(x1, y1)
z2 := zForAffine(x2, y2)
return curve.affineFromJacobian(curve.addJacobian(x1, y1, z1, x2, y2, z2))
}
// addJacobian takes two points in Jacobian coordinates, (x1, y1, z1) and
// (x2, y2, z2) and returns their sum, also in Jacobian form.
func (curve *CurveParams) addJacobian(x1, y1, z1, x2, y2, z2 *big.Int) (*big.Int, *big.Int, *big.Int) {
// See https://hyperelliptic.org/EFD/g1p/auto-shortw-jacobian-3.html#addition-add-2007-bl
x3, y3, z3 := new(big.Int), new(big.Int), new(big.Int)
if z1.Sign() == 0 {
x3.Set(x2)
y3.Set(y2)
z3.Set(z2)
return x3, y3, z3
}
if z2.Sign() == 0 {
x3.Set(x1)
y3.Set(y1)
z3.Set(z1)
return x3, y3, z3
}
z1z1 := new(big.Int).Mul(z1, z1)
z1z1.Mod(z1z1, curve.P)
z2z2 := new(big.Int).Mul(z2, z2)
z2z2.Mod(z2z2, curve.P)
u1 := new(big.Int).Mul(x1, z2z2)
u1.Mod(u1, curve.P)
u2 := new(big.Int).Mul(x2, z1z1)
u2.Mod(u2, curve.P)
h := new(big.Int).Sub(u2, u1)
xEqual := h.Sign() == 0
if h.Sign() == -1 {
h.Add(h, curve.P)
}
i := new(big.Int).Lsh(h, 1)
i.Mul(i, i)
j := new(big.Int).Mul(h, i)
s1 := new(big.Int).Mul(y1, z2)
s1.Mul(s1, z2z2)
s1.Mod(s1, curve.P)
s2 := new(big.Int).Mul(y2, z1)
s2.Mul(s2, z1z1)
s2.Mod(s2, curve.P)
r := new(big.Int).Sub(s2, s1)
if r.Sign() == -1 {
r.Add(r, curve.P)
}
yEqual := r.Sign() == 0
if xEqual && yEqual {
return curve.doubleJacobian(x1, y1, z1)
}
r.Lsh(r, 1)
v := new(big.Int).Mul(u1, i)
x3.Set(r)
x3.Mul(x3, x3)
x3.Sub(x3, j)
x3.Sub(x3, v)
x3.Sub(x3, v)
x3.Mod(x3, curve.P)
y3.Set(r)
v.Sub(v, x3)
y3.Mul(y3, v)
s1.Mul(s1, j)
s1.Lsh(s1, 1)
y3.Sub(y3, s1)
y3.Mod(y3, curve.P)
z3.Add(z1, z2)
z3.Mul(z3, z3)
z3.Sub(z3, z1z1)
z3.Sub(z3, z2z2)
z3.Mul(z3, h)
z3.Mod(z3, curve.P)
return x3, y3, z3
}
// Double implements [Curve.Double].
//
// Deprecated: the [CurveParams] methods are deprecated and are not guaranteed to
// provide any security property. For ECDH, use the [crypto/ecdh] package.
// For ECDSA, use the [crypto/ecdsa] package with a [Curve] value returned directly
// from [P224], [P256], [P384], or [P521].
func (curve *CurveParams) Double(x1, y1 *big.Int) (*big.Int, *big.Int) {
// If there is a dedicated constant-time implementation for this curve operation,
// use that instead of the generic one.
if specific, ok := matchesSpecificCurve(curve); ok {
return specific.Double(x1, y1)
}
panicIfNotOnCurve(curve, x1, y1)
z1 := zForAffine(x1, y1)
return curve.affineFromJacobian(curve.doubleJacobian(x1, y1, z1))
}
// doubleJacobian takes a point in Jacobian coordinates, (x, y, z), and
// returns its double, also in Jacobian form.
func (curve *CurveParams) doubleJacobian(x, y, z *big.Int) (*big.Int, *big.Int, *big.Int) {
// See https://hyperelliptic.org/EFD/g1p/auto-shortw-jacobian-3.html#doubling-dbl-2001-b
delta := new(big.Int).Mul(z, z)
delta.Mod(delta, curve.P)
gamma := new(big.Int).Mul(y, y)
gamma.Mod(gamma, curve.P)
alpha := new(big.Int).Sub(x, delta)
if alpha.Sign() == -1 {
alpha.Add(alpha, curve.P)
}
alpha2 := new(big.Int).Add(x, delta)
alpha.Mul(alpha, alpha2)
alpha2.Set(alpha)
alpha.Lsh(alpha, 1)
alpha.Add(alpha, alpha2)
beta := alpha2.Mul(x, gamma)
x3 := new(big.Int).Mul(alpha, alpha)
beta8 := new(big.Int).Lsh(beta, 3)
beta8.Mod(beta8, curve.P)
x3.Sub(x3, beta8)
if x3.Sign() == -1 {
x3.Add(x3, curve.P)
}
x3.Mod(x3, curve.P)
z3 := new(big.Int).Add(y, z)
z3.Mul(z3, z3)
z3.Sub(z3, gamma)
if z3.Sign() == -1 {
z3.Add(z3, curve.P)
}
z3.Sub(z3, delta)
if z3.Sign() == -1 {
z3.Add(z3, curve.P)
}
z3.Mod(z3, curve.P)
beta.Lsh(beta, 2)
beta.Sub(beta, x3)
if beta.Sign() == -1 {
beta.Add(beta, curve.P)
}
y3 := alpha.Mul(alpha, beta)
gamma.Mul(gamma, gamma)
gamma.Lsh(gamma, 3)
gamma.Mod(gamma, curve.P)
y3.Sub(y3, gamma)
if y3.Sign() == -1 {
y3.Add(y3, curve.P)
}
y3.Mod(y3, curve.P)
return x3, y3, z3
}
// ScalarMult implements [Curve.ScalarMult].
//
// Deprecated: the [CurveParams] methods are deprecated and are not guaranteed to
// provide any security property. For ECDH, use the [crypto/ecdh] package.
// For ECDSA, use the [crypto/ecdsa] package with a [Curve] value returned directly
// from [P224], [P256], [P384], or [P521].
func (curve *CurveParams) ScalarMult(Bx, By *big.Int, k []byte) (*big.Int, *big.Int) {
// If there is a dedicated constant-time implementation for this curve operation,
// use that instead of the generic one.
if specific, ok := matchesSpecificCurve(curve); ok {
return specific.ScalarMult(Bx, By, k)
}
panicIfNotOnCurve(curve, Bx, By)
Bz := new(big.Int).SetInt64(1)
x, y, z := new(big.Int), new(big.Int), new(big.Int)
for _, b := range k {
for range 8 {
x, y, z = curve.doubleJacobian(x, y, z)
if b&0x80 == 0x80 {
x, y, z = curve.addJacobian(Bx, By, Bz, x, y, z)
}
b <<= 1
}
}
return curve.affineFromJacobian(x, y, z)
}
// ScalarBaseMult implements [Curve.ScalarBaseMult].
//
// Deprecated: the [CurveParams] methods are deprecated and are not guaranteed to
// provide any security property. For ECDH, use the [crypto/ecdh] package.
// For ECDSA, use the [crypto/ecdsa] package with a [Curve] value returned directly
// from [P224], [P256], [P384], or [P521].
func (curve *CurveParams) ScalarBaseMult(k []byte) (*big.Int, *big.Int) {
// If there is a dedicated constant-time implementation for this curve operation,
// use that instead of the generic one.
if specific, ok := matchesSpecificCurve(curve); ok {
return specific.ScalarBaseMult(k)
}
return curve.ScalarMult(curve.Gx, curve.Gy, k)
}
func matchesSpecificCurve(params *CurveParams) (Curve, bool) {
for _, c := range []Curve{p224, p256, p384, p521} {
if params == c.Params() {
return c, true
}
}
return nil, false
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
/*
Package hmac implements the Keyed-Hash Message Authentication Code (HMAC) as
defined in U.S. Federal Information Processing Standards Publication 198.
An HMAC is a cryptographic hash that uses a key to sign a message.
The receiver verifies the hash by recomputing it using the same key.
Receivers should be careful to use Equal to compare MACs in order to avoid
timing side-channels:
// ValidMAC reports whether messageMAC is a valid HMAC tag for message.
func ValidMAC(message, messageMAC, key []byte) bool {
mac := hmac.New(sha256.New, key)
mac.Write(message)
expectedMAC := mac.Sum(nil)
return hmac.Equal(messageMAC, expectedMAC)
}
*/
package hmac
import (
"crypto/internal/boring"
"crypto/internal/fips140/hmac"
"crypto/internal/fips140hash"
"crypto/internal/fips140only"
"crypto/subtle"
"hash"
)
// New returns a new HMAC hash using the given [hash.Hash] type and key.
// New functions like [crypto/sha256.New] can be used as h.
// h must return a new Hash every time it is called.
// Note that unlike other hash implementations in the standard library,
// the returned Hash does not implement [encoding.BinaryMarshaler]
// or [encoding.BinaryUnmarshaler].
func New(h func() hash.Hash, key []byte) hash.Hash {
if boring.Enabled {
hm := boring.NewHMAC(h, key)
if hm != nil {
return hm
}
// BoringCrypto did not recognize h, so fall through to standard Go code.
}
h = fips140hash.UnwrapNew(h)
if fips140only.Enabled {
if len(key) < 112/8 {
panic("crypto/hmac: use of keys shorter than 112 bits is not allowed in FIPS 140-only mode")
}
if !fips140only.ApprovedHash(h()) {
panic("crypto/hmac: use of hash functions other than SHA-2 or SHA-3 is not allowed in FIPS 140-only mode")
}
}
return hmac.New(h, key)
}
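// A minimal tag-computation sketch, complementing the ValidMAC example in the
// package documentation (illustrative only; key and message are provided by
// the caller, and crypto/sha256 is imported):
//
//	mac := New(sha256.New, key)
//	mac.Write(message) // hash.Hash.Write never returns an error
//	tag := mac.Sum(nil)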
// Equal compares two MACs for equality without leaking timing information.
func Equal(mac1, mac2 []byte) bool {
// We don't have to be constant time if the lengths of the MACs are
// different as that suggests that a completely different hash function
// was used.
return subtle.ConstantTimeCompare(mac1, mac2) == 1
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:generate go run gen.go -output md5block.go
// Package md5 implements the MD5 hash algorithm as defined in RFC 1321.
//
// MD5 is cryptographically broken and should not be used for secure
// applications.
package md5
import (
"crypto"
"crypto/internal/fips140only"
"errors"
"hash"
"internal/byteorder"
)
func init() {
crypto.RegisterHash(crypto.MD5, New)
}
// The size of an MD5 checksum in bytes.
const Size = 16
// The blocksize of MD5 in bytes.
const BlockSize = 64
// The maximum number of bytes that can be passed to block(). The limit exists
// because implementations that rely on assembly routines are not preemptible.
const maxAsmIters = 1024
const maxAsmSize = BlockSize * maxAsmIters // 64KiB
const (
init0 = 0x67452301
init1 = 0xEFCDAB89
init2 = 0x98BADCFE
init3 = 0x10325476
)
// digest represents the partial evaluation of a checksum.
type digest struct {
s [4]uint32
x [BlockSize]byte
nx int
len uint64
}
func (d *digest) Reset() {
d.s[0] = init0
d.s[1] = init1
d.s[2] = init2
d.s[3] = init3
d.nx = 0
d.len = 0
}
const (
magic = "md5\x01"
marshaledSize = len(magic) + 4*4 + BlockSize + 8
)
func (d *digest) MarshalBinary() ([]byte, error) {
return d.AppendBinary(make([]byte, 0, marshaledSize))
}
func (d *digest) AppendBinary(b []byte) ([]byte, error) {
b = append(b, magic...)
b = byteorder.BEAppendUint32(b, d.s[0])
b = byteorder.BEAppendUint32(b, d.s[1])
b = byteorder.BEAppendUint32(b, d.s[2])
b = byteorder.BEAppendUint32(b, d.s[3])
b = append(b, d.x[:d.nx]...)
b = append(b, make([]byte, len(d.x)-d.nx)...)
b = byteorder.BEAppendUint64(b, d.len)
return b, nil
}
func (d *digest) UnmarshalBinary(b []byte) error {
if len(b) < len(magic) || string(b[:len(magic)]) != magic {
return errors.New("crypto/md5: invalid hash state identifier")
}
if len(b) != marshaledSize {
return errors.New("crypto/md5: invalid hash state size")
}
b = b[len(magic):]
b, d.s[0] = consumeUint32(b)
b, d.s[1] = consumeUint32(b)
b, d.s[2] = consumeUint32(b)
b, d.s[3] = consumeUint32(b)
b = b[copy(d.x[:], b):]
b, d.len = consumeUint64(b)
d.nx = int(d.len % BlockSize)
return nil
}
func consumeUint64(b []byte) ([]byte, uint64) {
return b[8:], byteorder.BEUint64(b[0:8])
}
func consumeUint32(b []byte) ([]byte, uint32) {
return b[4:], byteorder.BEUint32(b[0:4])
}
func (d *digest) Clone() (hash.Cloner, error) {
r := *d
return &r, nil
}
// New returns a new [hash.Hash] computing the MD5 checksum. The Hash
// also implements [encoding.BinaryMarshaler], [encoding.BinaryAppender] and
// [encoding.BinaryUnmarshaler] to marshal and unmarshal the internal
// state of the hash.
func New() hash.Hash {
d := new(digest)
d.Reset()
return d
}
func (d *digest) Size() int { return Size }
func (d *digest) BlockSize() int { return BlockSize }
func (d *digest) Write(p []byte) (nn int, err error) {
if fips140only.Enabled {
return 0, errors.New("crypto/md5: use of MD5 is not allowed in FIPS 140-only mode")
}
// Note that we currently call block or blockGeneric
// directly (guarded using haveAsm) because this allows
// escape analysis to see that p and d don't escape.
nn = len(p)
d.len += uint64(nn)
if d.nx > 0 {
n := copy(d.x[d.nx:], p)
d.nx += n
if d.nx == BlockSize {
if haveAsm {
block(d, d.x[:])
} else {
blockGeneric(d, d.x[:])
}
d.nx = 0
}
p = p[n:]
}
if len(p) >= BlockSize {
n := len(p) &^ (BlockSize - 1)
if haveAsm {
for n > maxAsmSize {
block(d, p[:maxAsmSize])
p = p[maxAsmSize:]
n -= maxAsmSize
}
block(d, p[:n])
} else {
blockGeneric(d, p[:n])
}
p = p[n:]
}
if len(p) > 0 {
d.nx = copy(d.x[:], p)
}
return
}
func (d *digest) Sum(in []byte) []byte {
// Make a copy of d so that caller can keep writing and summing.
d0 := *d
hash := d0.checkSum()
return append(in, hash[:]...)
}
func (d *digest) checkSum() [Size]byte {
if fips140only.Enabled {
panic("crypto/md5: use of MD5 is not allowed in FIPS 140-only mode")
}
// Append 0x80 to the end of the message and then append zeros
// until the length is a multiple of 56 bytes. Finally append
// 8 bytes representing the message length in bits.
//
// 1 byte end marker :: 0-63 padding bytes :: 8 byte length
tmp := [1 + 63 + 8]byte{0x80}
pad := (55 - d.len) % 64 // calculate number of padding bytes
byteorder.LEPutUint64(tmp[1+pad:], d.len<<3) // append length in bits
d.Write(tmp[:1+pad+8])
// The previous write ensures that a whole number of
// blocks (i.e. a multiple of 64 bytes) have been hashed.
if d.nx != 0 {
panic("d.nx != 0")
}
var digest [Size]byte
byteorder.LEPutUint32(digest[0:], d.s[0])
byteorder.LEPutUint32(digest[4:], d.s[1])
byteorder.LEPutUint32(digest[8:], d.s[2])
byteorder.LEPutUint32(digest[12:], d.s[3])
return digest
}
// Sum returns the MD5 checksum of the data.
func Sum(data []byte) [Size]byte {
var d digest
d.Reset()
d.Write(data)
return d.checkSum()
}
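// A minimal usage sketch (illustrative only; fmt is imported):
//
//	sum := Sum([]byte("example data"))
//	fmt.Printf("%x\n", sum)
//
// For streaming input, use New, write the data incrementally, and call
// Sum(nil) at the end.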
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Code generated by go run gen.go -output md5block.go; DO NOT EDIT.
package md5
import (
"internal/byteorder"
"math/bits"
)
func blockGeneric(dig *digest, p []byte) {
// load state
a, b, c, d := dig.s[0], dig.s[1], dig.s[2], dig.s[3]
for i := 0; i <= len(p)-BlockSize; i += BlockSize {
// eliminate bounds checks on p
q := p[i:]
q = q[:BlockSize:BlockSize]
// save current state
aa, bb, cc, dd := a, b, c, d
// load input block
x0 := byteorder.LEUint32(q[4*0x0:])
x1 := byteorder.LEUint32(q[4*0x1:])
x2 := byteorder.LEUint32(q[4*0x2:])
x3 := byteorder.LEUint32(q[4*0x3:])
x4 := byteorder.LEUint32(q[4*0x4:])
x5 := byteorder.LEUint32(q[4*0x5:])
x6 := byteorder.LEUint32(q[4*0x6:])
x7 := byteorder.LEUint32(q[4*0x7:])
x8 := byteorder.LEUint32(q[4*0x8:])
x9 := byteorder.LEUint32(q[4*0x9:])
xa := byteorder.LEUint32(q[4*0xa:])
xb := byteorder.LEUint32(q[4*0xb:])
xc := byteorder.LEUint32(q[4*0xc:])
xd := byteorder.LEUint32(q[4*0xd:])
xe := byteorder.LEUint32(q[4*0xe:])
xf := byteorder.LEUint32(q[4*0xf:])
// round 1
a = b + bits.RotateLeft32((((c^d)&b)^d)+a+x0+0xd76aa478, 7)
d = a + bits.RotateLeft32((((b^c)&a)^c)+d+x1+0xe8c7b756, 12)
c = d + bits.RotateLeft32((((a^b)&d)^b)+c+x2+0x242070db, 17)
b = c + bits.RotateLeft32((((d^a)&c)^a)+b+x3+0xc1bdceee, 22)
a = b + bits.RotateLeft32((((c^d)&b)^d)+a+x4+0xf57c0faf, 7)
d = a + bits.RotateLeft32((((b^c)&a)^c)+d+x5+0x4787c62a, 12)
c = d + bits.RotateLeft32((((a^b)&d)^b)+c+x6+0xa8304613, 17)
b = c + bits.RotateLeft32((((d^a)&c)^a)+b+x7+0xfd469501, 22)
a = b + bits.RotateLeft32((((c^d)&b)^d)+a+x8+0x698098d8, 7)
d = a + bits.RotateLeft32((((b^c)&a)^c)+d+x9+0x8b44f7af, 12)
c = d + bits.RotateLeft32((((a^b)&d)^b)+c+xa+0xffff5bb1, 17)
b = c + bits.RotateLeft32((((d^a)&c)^a)+b+xb+0x895cd7be, 22)
a = b + bits.RotateLeft32((((c^d)&b)^d)+a+xc+0x6b901122, 7)
d = a + bits.RotateLeft32((((b^c)&a)^c)+d+xd+0xfd987193, 12)
c = d + bits.RotateLeft32((((a^b)&d)^b)+c+xe+0xa679438e, 17)
b = c + bits.RotateLeft32((((d^a)&c)^a)+b+xf+0x49b40821, 22)
// round 2
a = b + bits.RotateLeft32((((b^c)&d)^c)+a+x1+0xf61e2562, 5)
d = a + bits.RotateLeft32((((a^b)&c)^b)+d+x6+0xc040b340, 9)
c = d + bits.RotateLeft32((((d^a)&b)^a)+c+xb+0x265e5a51, 14)
b = c + bits.RotateLeft32((((c^d)&a)^d)+b+x0+0xe9b6c7aa, 20)
a = b + bits.RotateLeft32((((b^c)&d)^c)+a+x5+0xd62f105d, 5)
d = a + bits.RotateLeft32((((a^b)&c)^b)+d+xa+0x02441453, 9)
c = d + bits.RotateLeft32((((d^a)&b)^a)+c+xf+0xd8a1e681, 14)
b = c + bits.RotateLeft32((((c^d)&a)^d)+b+x4+0xe7d3fbc8, 20)
a = b + bits.RotateLeft32((((b^c)&d)^c)+a+x9+0x21e1cde6, 5)
d = a + bits.RotateLeft32((((a^b)&c)^b)+d+xe+0xc33707d6, 9)
c = d + bits.RotateLeft32((((d^a)&b)^a)+c+x3+0xf4d50d87, 14)
b = c + bits.RotateLeft32((((c^d)&a)^d)+b+x8+0x455a14ed, 20)
a = b + bits.RotateLeft32((((b^c)&d)^c)+a+xd+0xa9e3e905, 5)
d = a + bits.RotateLeft32((((a^b)&c)^b)+d+x2+0xfcefa3f8, 9)
c = d + bits.RotateLeft32((((d^a)&b)^a)+c+x7+0x676f02d9, 14)
b = c + bits.RotateLeft32((((c^d)&a)^d)+b+xc+0x8d2a4c8a, 20)
// round 3
a = b + bits.RotateLeft32((b^c^d)+a+x5+0xfffa3942, 4)
d = a + bits.RotateLeft32((a^b^c)+d+x8+0x8771f681, 11)
c = d + bits.RotateLeft32((d^a^b)+c+xb+0x6d9d6122, 16)
b = c + bits.RotateLeft32((c^d^a)+b+xe+0xfde5380c, 23)
a = b + bits.RotateLeft32((b^c^d)+a+x1+0xa4beea44, 4)
d = a + bits.RotateLeft32((a^b^c)+d+x4+0x4bdecfa9, 11)
c = d + bits.RotateLeft32((d^a^b)+c+x7+0xf6bb4b60, 16)
b = c + bits.RotateLeft32((c^d^a)+b+xa+0xbebfbc70, 23)
a = b + bits.RotateLeft32((b^c^d)+a+xd+0x289b7ec6, 4)
d = a + bits.RotateLeft32((a^b^c)+d+x0+0xeaa127fa, 11)
c = d + bits.RotateLeft32((d^a^b)+c+x3+0xd4ef3085, 16)
b = c + bits.RotateLeft32((c^d^a)+b+x6+0x04881d05, 23)
a = b + bits.RotateLeft32((b^c^d)+a+x9+0xd9d4d039, 4)
d = a + bits.RotateLeft32((a^b^c)+d+xc+0xe6db99e5, 11)
c = d + bits.RotateLeft32((d^a^b)+c+xf+0x1fa27cf8, 16)
b = c + bits.RotateLeft32((c^d^a)+b+x2+0xc4ac5665, 23)
// round 4
a = b + bits.RotateLeft32((c^(b|^d))+a+x0+0xf4292244, 6)
d = a + bits.RotateLeft32((b^(a|^c))+d+x7+0x432aff97, 10)
c = d + bits.RotateLeft32((a^(d|^b))+c+xe+0xab9423a7, 15)
b = c + bits.RotateLeft32((d^(c|^a))+b+x5+0xfc93a039, 21)
a = b + bits.RotateLeft32((c^(b|^d))+a+xc+0x655b59c3, 6)
d = a + bits.RotateLeft32((b^(a|^c))+d+x3+0x8f0ccc92, 10)
c = d + bits.RotateLeft32((a^(d|^b))+c+xa+0xffeff47d, 15)
b = c + bits.RotateLeft32((d^(c|^a))+b+x1+0x85845dd1, 21)
a = b + bits.RotateLeft32((c^(b|^d))+a+x8+0x6fa87e4f, 6)
d = a + bits.RotateLeft32((b^(a|^c))+d+xf+0xfe2ce6e0, 10)
c = d + bits.RotateLeft32((a^(d|^b))+c+x6+0xa3014314, 15)
b = c + bits.RotateLeft32((d^(c|^a))+b+xd+0x4e0811a1, 21)
a = b + bits.RotateLeft32((c^(b|^d))+a+x4+0xf7537e82, 6)
d = a + bits.RotateLeft32((b^(a|^c))+d+xb+0xbd3af235, 10)
c = d + bits.RotateLeft32((a^(d|^b))+c+x2+0x2ad7d2bb, 15)
b = c + bits.RotateLeft32((d^(c|^a))+b+x9+0xeb86d391, 21)
// add saved state
a += aa
b += bb
c += cc
d += dd
}
// save state
dig.s[0], dig.s[1], dig.s[2], dig.s[3] = a, b, c, d
}
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package mlkem implements the quantum-resistant key encapsulation method
// ML-KEM (formerly known as Kyber), as specified in [NIST FIPS 203].
//
// Most applications should use the ML-KEM-768 parameter set, as implemented by
// [DecapsulationKey768] and [EncapsulationKey768].
//
// [NIST FIPS 203]: https://doi.org/10.6028/NIST.FIPS.203
package mlkem
import "crypto/internal/fips140/mlkem"
const (
// SharedKeySize is the size of a shared key produced by ML-KEM.
SharedKeySize = 32
// SeedSize is the size of a seed used to generate a decapsulation key.
SeedSize = 64
// CiphertextSize768 is the size of a ciphertext produced by ML-KEM-768.
CiphertextSize768 = 1088
// EncapsulationKeySize768 is the size of an ML-KEM-768 encapsulation key.
EncapsulationKeySize768 = 1184
// CiphertextSize1024 is the size of a ciphertext produced by ML-KEM-1024.
CiphertextSize1024 = 1568
// EncapsulationKeySize1024 is the size of an ML-KEM-1024 encapsulation key.
EncapsulationKeySize1024 = 1568
)
// DecapsulationKey768 is the secret key used to decapsulate a shared key
// from a ciphertext. It includes various precomputed values.
type DecapsulationKey768 struct {
key *mlkem.DecapsulationKey768
}
// GenerateKey768 generates a new decapsulation key, drawing random bytes from
// the default crypto/rand source. The decapsulation key must be kept secret.
func GenerateKey768() (*DecapsulationKey768, error) {
key, err := mlkem.GenerateKey768()
if err != nil {
return nil, err
}
return &DecapsulationKey768{key}, nil
}
// NewDecapsulationKey768 expands a decapsulation key from a 64-byte seed in the
// "d || z" form. The seed must be uniformly random.
func NewDecapsulationKey768(seed []byte) (*DecapsulationKey768, error) {
key, err := mlkem.NewDecapsulationKey768(seed)
if err != nil {
return nil, err
}
return &DecapsulationKey768{key}, nil
}
// Bytes returns the decapsulation key as a 64-byte seed in the "d || z" form.
//
// The decapsulation key must be kept secret.
func (dk *DecapsulationKey768) Bytes() []byte {
return dk.key.Bytes()
}
// Decapsulate generates a shared key from a ciphertext and a decapsulation
// key. If the ciphertext is not valid, Decapsulate returns an error.
//
// The shared key must be kept secret.
func (dk *DecapsulationKey768) Decapsulate(ciphertext []byte) (sharedKey []byte, err error) {
return dk.key.Decapsulate(ciphertext)
}
// EncapsulationKey returns the public encapsulation key necessary to produce
// ciphertexts.
func (dk *DecapsulationKey768) EncapsulationKey() *EncapsulationKey768 {
return &EncapsulationKey768{dk.key.EncapsulationKey()}
}
// An EncapsulationKey768 is the public key used to produce ciphertexts to be
// decapsulated by the corresponding DecapsulationKey768.
type EncapsulationKey768 struct {
key *mlkem.EncapsulationKey768
}
// NewEncapsulationKey768 parses an encapsulation key from its encoded form. If
// the encapsulation key is not valid, NewEncapsulationKey768 returns an error.
func NewEncapsulationKey768(encapsulationKey []byte) (*EncapsulationKey768, error) {
key, err := mlkem.NewEncapsulationKey768(encapsulationKey)
if err != nil {
return nil, err
}
return &EncapsulationKey768{key}, nil
}
// Bytes returns the encapsulation key as a byte slice.
func (ek *EncapsulationKey768) Bytes() []byte {
return ek.key.Bytes()
}
// Encapsulate generates a shared key and an associated ciphertext from an
// encapsulation key, drawing random bytes from the default crypto/rand source.
//
// The shared key must be kept secret.
func (ek *EncapsulationKey768) Encapsulate() (sharedKey, ciphertext []byte) {
return ek.key.Encapsulate()
}
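// Illustrative sketch (not part of this package): a complete ML-KEM-768 round
// trip between two parties, using only the API documented above. Error
// handling is collapsed to panics for brevity.
package main

import (
	"bytes"
	"crypto/mlkem"
	"fmt"
)

func main() {
	// Receiver: generate a decapsulation key and publish its encapsulation key.
	dk, err := mlkem.GenerateKey768()
	if err != nil {
		panic(err)
	}
	ekBytes := dk.EncapsulationKey().Bytes()

	// Sender: parse the encapsulation key and derive a shared key + ciphertext.
	ek, err := mlkem.NewEncapsulationKey768(ekBytes)
	if err != nil {
		panic(err)
	}
	sharedKey, ciphertext := ek.Encapsulate()

	// Receiver: recover the same shared key from the ciphertext.
	recovered, err := dk.Decapsulate(ciphertext)
	if err != nil {
		panic(err)
	}
	fmt.Println(bytes.Equal(sharedKey, recovered)) // true
}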
// DecapsulationKey1024 is the secret key used to decapsulate a shared key
// from a ciphertext. It includes various precomputed values.
type DecapsulationKey1024 struct {
key *mlkem.DecapsulationKey1024
}
// GenerateKey1024 generates a new decapsulation key, drawing random bytes from
// the default crypto/rand source. The decapsulation key must be kept secret.
func GenerateKey1024() (*DecapsulationKey1024, error) {
key, err := mlkem.GenerateKey1024()
if err != nil {
return nil, err
}
return &DecapsulationKey1024{key}, nil
}
// NewDecapsulationKey1024 expands a decapsulation key from a 64-byte seed in the
// "d || z" form. The seed must be uniformly random.
func NewDecapsulationKey1024(seed []byte) (*DecapsulationKey1024, error) {
key, err := mlkem.NewDecapsulationKey1024(seed)
if err != nil {
return nil, err
}
return &DecapsulationKey1024{key}, nil
}
// Bytes returns the decapsulation key as a 64-byte seed in the "d || z" form.
//
// The decapsulation key must be kept secret.
func (dk *DecapsulationKey1024) Bytes() []byte {
return dk.key.Bytes()
}
// Decapsulate generates a shared key from a ciphertext and a decapsulation
// key. If the ciphertext is not valid, Decapsulate returns an error.
//
// The shared key must be kept secret.
func (dk *DecapsulationKey1024) Decapsulate(ciphertext []byte) (sharedKey []byte, err error) {
return dk.key.Decapsulate(ciphertext)
}
// EncapsulationKey returns the public encapsulation key necessary to produce
// ciphertexts.
func (dk *DecapsulationKey1024) EncapsulationKey() *EncapsulationKey1024 {
return &EncapsulationKey1024{dk.key.EncapsulationKey()}
}
// An EncapsulationKey1024 is the public key used to produce ciphertexts to be
// decapsulated by the corresponding DecapsulationKey1024.
type EncapsulationKey1024 struct {
key *mlkem.EncapsulationKey1024
}
// NewEncapsulationKey1024 parses an encapsulation key from its encoded form. If
// the encapsulation key is not valid, NewEncapsulationKey1024 returns an error.
func NewEncapsulationKey1024(encapsulationKey []byte) (*EncapsulationKey1024, error) {
key, err := mlkem.NewEncapsulationKey1024(encapsulationKey)
if err != nil {
return nil, err
}
return &EncapsulationKey1024{key}, nil
}
// Bytes returns the encapsulation key as a byte slice.
func (ek *EncapsulationKey1024) Bytes() []byte {
return ek.key.Bytes()
}
// Encapsulate generates a shared key and an associated ciphertext from an
// encapsulation key, drawing random bytes from the default crypto/rand source.
//
// The shared key must be kept secret.
func (ek *EncapsulationKey1024) Encapsulate() (sharedKey, ciphertext []byte) {
return ek.key.Encapsulate()
}
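// Illustrative sketch (not part of this package): persisting an ML-KEM-1024
// decapsulation key as its 64-byte seed and re-expanding it later. The seed
// form is what Bytes returns and what NewDecapsulationKey1024 accepts.
package main

import (
	"crypto/mlkem"
	"fmt"
)

func main() {
	dk, err := mlkem.GenerateKey1024()
	if err != nil {
		panic(err)
	}
	seed := dk.Bytes() // 64-byte "d || z" seed; must be kept secret

	// Later (or on another machine), re-derive the same decapsulation key.
	dk2, err := mlkem.NewDecapsulationKey1024(seed)
	if err != nil {
		panic(err)
	}
	fmt.Println(len(dk2.EncapsulationKey().Bytes()) == mlkem.EncapsulationKeySize1024) // true
}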
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package rand implements a cryptographically secure
// random number generator.
package rand
import (
"crypto/internal/boring"
"crypto/internal/fips140"
"crypto/internal/fips140/drbg"
"crypto/internal/sysrand"
"io"
_ "unsafe"
)
// Reader is a global, shared instance of a cryptographically
// secure random number generator. It is safe for concurrent use.
//
// - On Linux, FreeBSD, Dragonfly, and Solaris, Reader uses getrandom(2).
// - On legacy Linux (< 3.17), Reader opens /dev/urandom on first use.
// - On macOS, iOS, and OpenBSD, Reader uses arc4random_buf(3).
// - On NetBSD, Reader uses the kern.arandom sysctl.
// - On Windows, Reader uses the ProcessPrng API.
// - On js/wasm, Reader uses the Web Crypto API.
// - On wasip1/wasm, Reader uses random_get.
//
// In FIPS 140-3 mode, the output passes through an SP 800-90A Rev. 1
// Deterministic Random Bit Generator (DRBG).
var Reader io.Reader
func init() {
if boring.Enabled {
Reader = boring.RandReader
return
}
Reader = &reader{}
}
type reader struct {
drbg.DefaultReader
}
func (r *reader) Read(b []byte) (n int, err error) {
boring.Unreachable()
if fips140.Enabled {
drbg.Read(b)
} else {
sysrand.Read(b)
}
return len(b), nil
}
// fatal is [runtime.fatal], pushed via linkname.
//
//go:linkname fatal
func fatal(string)
// Read fills b with cryptographically secure random bytes. It never returns an
// error, and always fills b entirely.
//
// Read calls [io.ReadFull] on [Reader] and crashes the program irrecoverably if
// an error is returned. The default Reader uses operating system APIs that are
// documented to never return an error on all but legacy Linux systems.
func Read(b []byte) (n int, err error) {
// We don't want b to escape to the heap, but escape analysis can't see
// through a potentially overridden Reader, so we special-case the default
// case which we can keep non-escaping, and in the general case we read into
// a heap buffer and copy from it.
if r, ok := Reader.(*reader); ok {
_, err = r.Read(b)
} else {
bb := make([]byte, len(b))
_, err = io.ReadFull(Reader, bb)
copy(b, bb)
}
if err != nil {
fatal("crypto/rand: failed to read random data (see https://go.dev/issue/66821): " + err.Error())
panic("unreachable") // To be sure.
}
return len(b), nil
}
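// Illustrative sketch (not part of this package): generating key material with
// Read. Read never fails (it crashes the program instead), so the returned
// error can be ignored.
package main

import (
	"crypto/rand"
	"fmt"
)

func main() {
	key := make([]byte, 32) // e.g. a 256-bit symmetric key
	rand.Read(key)          // always fills key entirely
	fmt.Printf("%x\n", key)
}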
// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package rand
const base32alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZ234567"
// Text returns a cryptographically random string using the standard RFC 4648 base32 alphabet
// for use when a secret string, token, password, or other text is needed.
// The result contains at least 128 bits of randomness, enough to prevent brute force
// guessing attacks and to make the likelihood of collisions vanishingly small.
// A future version may return longer texts as needed to maintain those properties.
func Text() string {
// ⌈log₃₂ 2¹²⁸⌉ = 26 chars
src := make([]byte, 26)
Read(src)
for i := range src {
src[i] = base32alphabet[src[i]%32]
}
return string(src)
}
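// Illustrative sketch (not part of this package): using Text to mint an opaque
// API token or password-reset code with at least 128 bits of randomness.
package main

import (
	"crypto/rand"
	"fmt"
)

func main() {
	token := rand.Text() // 26 base32 characters in the current implementation
	fmt.Println(token, len(token))
}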
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package rand
import (
"crypto/internal/fips140only"
"crypto/internal/randutil"
"errors"
"io"
"math/big"
)
// Prime returns a number of the given bit length that is prime with high probability.
// Prime will return an error for any error returned by rand.Read or if bits < 2.
func Prime(rand io.Reader, bits int) (*big.Int, error) {
if fips140only.Enabled {
return nil, errors.New("crypto/rand: use of Prime is not allowed in FIPS 140-only mode")
}
if bits < 2 {
return nil, errors.New("crypto/rand: prime size must be at least 2-bit")
}
randutil.MaybeReadByte(rand)
b := uint(bits % 8)
if b == 0 {
b = 8
}
bytes := make([]byte, (bits+7)/8)
p := new(big.Int)
for {
if _, err := io.ReadFull(rand, bytes); err != nil {
return nil, err
}
// Clear bits in the first byte to make sure the candidate has a size <= bits.
bytes[0] &= uint8(int(1<<b) - 1)
// Don't let the value be too small, i.e., set the most significant two bits.
// Setting the top two bits, rather than just the top bit,
// means that when two of these values are multiplied together,
// the result isn't ever one bit short.
if b >= 2 {
bytes[0] |= 3 << (b - 2)
} else {
// Here b==1, because b cannot be zero.
bytes[0] |= 1
if len(bytes) > 1 {
bytes[1] |= 0x80
}
}
// Make the value odd since an even number this large certainly isn't prime.
bytes[len(bytes)-1] |= 1
p.SetBytes(bytes)
if p.ProbablyPrime(20) {
return p, nil
}
}
}
// Int returns a uniform random value in [0, max). It panics if max <= 0, and
// returns an error if rand.Read returns one.
func Int(rand io.Reader, max *big.Int) (n *big.Int, err error) {
if max.Sign() <= 0 {
panic("crypto/rand: argument to Int is <= 0")
}
n = new(big.Int)
n.Sub(max, n.SetUint64(1))
// bitLen is the maximum bit length needed to encode a value < max.
bitLen := n.BitLen()
if bitLen == 0 {
// the only valid result is 0
return
}
// k is the maximum byte length needed to encode a value < max.
k := (bitLen + 7) / 8
// b is the number of bits in the most significant byte of max-1.
b := uint(bitLen % 8)
if b == 0 {
b = 8
}
bytes := make([]byte, k)
for {
_, err = io.ReadFull(rand, bytes)
if err != nil {
return nil, err
}
// Clear bits in the first byte to increase the probability
// that the candidate is < max.
bytes[0] &= uint8(int(1<<b) - 1)
n.SetBytes(bytes)
if n.Cmp(max) < 0 {
return
}
}
}
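// Illustrative sketch (not part of this package): drawing a probable prime and
// a uniform value below a bound, both from crypto/rand.Reader.
package main

import (
	"crypto/rand"
	"fmt"
	"math/big"
)

func main() {
	// A 256-bit probable prime (ProbablyPrime with 20 rounds internally).
	p, err := rand.Prime(rand.Reader, 256)
	if err != nil {
		panic(err)
	}

	// A uniform value n with 0 <= n < p.
	n, err := rand.Int(rand.Reader, p)
	if err != nil {
		panic(err)
	}
	fmt.Println(p.BitLen(), n.Cmp(p) < 0) // 256 true

	_ = big.NewInt(0) // keep math/big imported in this standalone sketch
}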
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package rc4 implements RC4 encryption, as defined in Bruce Schneier's
// Applied Cryptography.
//
// RC4 is cryptographically broken and should not be used for secure
// applications.
package rc4
import (
"crypto/internal/fips140/alias"
"crypto/internal/fips140only"
"errors"
"strconv"
)
// A Cipher is an instance of RC4 using a particular key.
type Cipher struct {
s [256]uint32
i, j uint8
}
type KeySizeError int
func (k KeySizeError) Error() string {
return "crypto/rc4: invalid key size " + strconv.Itoa(int(k))
}
// NewCipher creates and returns a new [Cipher]. The key argument should be the
// RC4 key, at least 1 byte and at most 256 bytes.
func NewCipher(key []byte) (*Cipher, error) {
if fips140only.Enabled {
return nil, errors.New("crypto/rc4: use of RC4 is not allowed in FIPS 140-only mode")
}
k := len(key)
if k < 1 || k > 256 {
return nil, KeySizeError(k)
}
var c Cipher
for i := 0; i < 256; i++ {
c.s[i] = uint32(i)
}
var j uint8 = 0
for i := 0; i < 256; i++ {
j += uint8(c.s[i]) + key[i%k]
c.s[i], c.s[j] = c.s[j], c.s[i]
}
return &c, nil
}
// Reset zeros the key data and makes the [Cipher] unusable.
//
// Deprecated: Reset can't guarantee that the key will be entirely removed from
// the process's memory.
func (c *Cipher) Reset() {
for i := range c.s {
c.s[i] = 0
}
c.i, c.j = 0, 0
}
// XORKeyStream sets dst to the result of XORing src with the key stream.
// Dst and src must overlap entirely or not at all.
func (c *Cipher) XORKeyStream(dst, src []byte) {
if len(src) == 0 {
return
}
if alias.InexactOverlap(dst[:len(src)], src) {
panic("crypto/rc4: invalid buffer overlap")
}
i, j := c.i, c.j
_ = dst[len(src)-1]
dst = dst[:len(src)] // eliminate bounds check from loop
for k, v := range src {
i += 1
x := c.s[i]
j += uint8(x)
y := c.s[j]
c.s[i], c.s[j] = y, x
dst[k] = v ^ uint8(c.s[uint8(x+y)])
}
c.i, c.j = i, j
}
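// Illustrative sketch (not part of this package): encrypting and decrypting
// with RC4. Because RC4 is a pure key-stream XOR, applying XORKeyStream with
// two identically keyed ciphers recovers the plaintext. RC4 is broken and
// should only be used for compatibility with legacy protocols.
package main

import (
	"crypto/rc4"
	"fmt"
)

func main() {
	key := []byte("legacy-protocol-key")
	plaintext := []byte("attack at dawn")

	enc, _ := rc4.NewCipher(key)
	ciphertext := make([]byte, len(plaintext))
	enc.XORKeyStream(ciphertext, plaintext)

	dec, _ := rc4.NewCipher(key) // a fresh cipher: the key stream is stateful
	recovered := make([]byte, len(ciphertext))
	dec.XORKeyStream(recovered, ciphertext)
	fmt.Println(string(recovered)) // attack at dawn
}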
// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package rsa
import (
"crypto"
"crypto/internal/boring"
"crypto/internal/fips140/rsa"
"crypto/internal/fips140hash"
"crypto/internal/fips140only"
"errors"
"hash"
"io"
)
const (
// PSSSaltLengthAuto causes the salt in a PSS signature to be as large
// as possible when signing, and to be auto-detected when verifying.
//
// When signing in FIPS 140-3 mode, the salt length is capped at the length
// of the hash function used in the signature.
PSSSaltLengthAuto = 0
// PSSSaltLengthEqualsHash causes the salt length to equal the length
// of the hash used in the signature.
PSSSaltLengthEqualsHash = -1
)
// PSSOptions contains options for creating and verifying PSS signatures.
type PSSOptions struct {
// SaltLength controls the length of the salt used in the PSS signature. It
// can either be a positive number of bytes, or one of the special
// PSSSaltLength constants.
SaltLength int
// Hash is the hash function used to generate the message digest. If not
// zero, it overrides the hash function passed to SignPSS. It's required
// when using PrivateKey.Sign.
Hash crypto.Hash
}
// HashFunc returns opts.Hash so that [PSSOptions] implements [crypto.SignerOpts].
func (opts *PSSOptions) HashFunc() crypto.Hash {
return opts.Hash
}
func (opts *PSSOptions) saltLength() int {
if opts == nil {
return PSSSaltLengthAuto
}
return opts.SaltLength
}
// SignPSS calculates the signature of digest using PSS.
//
// digest must be the result of hashing the input message using the given hash
// function. The opts argument may be nil, in which case sensible defaults are
// used. If opts.Hash is set, it overrides hash.
//
// The signature is randomized depending on the message, key, and salt size,
// using bytes from rand. Most applications should use [crypto/rand.Reader] as
// rand.
func SignPSS(rand io.Reader, priv *PrivateKey, hash crypto.Hash, digest []byte, opts *PSSOptions) ([]byte, error) {
if err := checkPublicKeySize(&priv.PublicKey); err != nil {
return nil, err
}
if opts != nil && opts.Hash != 0 {
hash = opts.Hash
}
if boring.Enabled && rand == boring.RandReader {
bkey, err := boringPrivateKey(priv)
if err != nil {
return nil, err
}
return boring.SignRSAPSS(bkey, hash, digest, opts.saltLength())
}
boring.UnreachableExceptTests()
h := fips140hash.Unwrap(hash.New())
if err := checkFIPS140OnlyPrivateKey(priv); err != nil {
return nil, err
}
if fips140only.Enabled && !fips140only.ApprovedHash(h) {
return nil, errors.New("crypto/rsa: use of hash functions other than SHA-2 or SHA-3 is not allowed in FIPS 140-only mode")
}
if fips140only.Enabled && !fips140only.ApprovedRandomReader(rand) {
return nil, errors.New("crypto/rsa: only crypto/rand.Reader is allowed in FIPS 140-only mode")
}
k, err := fipsPrivateKey(priv)
if err != nil {
return nil, err
}
saltLength := opts.saltLength()
if fips140only.Enabled && saltLength > h.Size() {
return nil, errors.New("crypto/rsa: use of PSS salt longer than the hash is not allowed in FIPS 140-only mode")
}
switch saltLength {
case PSSSaltLengthAuto:
saltLength, err = rsa.PSSMaxSaltLength(k.PublicKey(), h)
if err != nil {
return nil, fipsError(err)
}
case PSSSaltLengthEqualsHash:
saltLength = h.Size()
default:
// If we get here, saltLength is either > 0 or < -1; in the latter
// case we fail out.
if saltLength <= 0 {
return nil, errors.New("crypto/rsa: invalid PSS salt length")
}
}
return fipsError2(rsa.SignPSS(rand, k, h, digest, saltLength))
}
// VerifyPSS verifies a PSS signature.
//
// A valid signature is indicated by returning a nil error. digest must be the
// result of hashing the input message using the given hash function. The opts
// argument may be nil, in which case sensible defaults are used. opts.Hash is
// ignored.
//
// The inputs are not considered confidential, and may leak through timing side
// channels, or if an attacker has control of part of the inputs.
func VerifyPSS(pub *PublicKey, hash crypto.Hash, digest []byte, sig []byte, opts *PSSOptions) error {
if err := checkPublicKeySize(pub); err != nil {
return err
}
if boring.Enabled {
bkey, err := boringPublicKey(pub)
if err != nil {
return err
}
if err := boring.VerifyRSAPSS(bkey, hash, digest, sig, opts.saltLength()); err != nil {
return ErrVerification
}
return nil
}
h := fips140hash.Unwrap(hash.New())
if err := checkFIPS140OnlyPublicKey(pub); err != nil {
return err
}
if fips140only.Enabled && !fips140only.ApprovedHash(h) {
return errors.New("crypto/rsa: use of hash functions other than SHA-2 or SHA-3 is not allowed in FIPS 140-only mode")
}
k, err := fipsPublicKey(pub)
if err != nil {
return err
}
saltLength := opts.saltLength()
if fips140only.Enabled && saltLength > h.Size() {
return errors.New("crypto/rsa: use of PSS salt longer than the hash is not allowed in FIPS 140-only mode")
}
switch saltLength {
case PSSSaltLengthAuto:
return fipsError(rsa.VerifyPSS(k, h, digest, sig))
case PSSSaltLengthEqualsHash:
return fipsError(rsa.VerifyPSSWithSaltLength(k, h, digest, sig, h.Size()))
default:
return fipsError(rsa.VerifyPSSWithSaltLength(k, h, digest, sig, saltLength))
}
}
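// Illustrative sketch (not part of this package): signing a message digest
// with PSS and verifying it. The key size and hash choice are examples only.
package main

import (
	"crypto"
	"crypto/rand"
	"crypto/rsa"
	"crypto/sha256"
	"fmt"
)

func main() {
	priv, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		panic(err)
	}

	digest := sha256.Sum256([]byte("message to be signed"))
	opts := &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthAuto, Hash: crypto.SHA256}

	sig, err := rsa.SignPSS(rand.Reader, priv, crypto.SHA256, digest[:], opts)
	if err != nil {
		panic(err)
	}
	// A nil error from VerifyPSS indicates a valid signature.
	fmt.Println(rsa.VerifyPSS(&priv.PublicKey, crypto.SHA256, digest[:], sig, opts))
}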
// EncryptOAEP encrypts the given message with RSA-OAEP.
//
// OAEP is parameterised by a hash function that is used as a random oracle.
// Encryption and decryption of a given message must use the same hash function,
// and sha256.New() is a reasonable choice.
//
// The random parameter is used as a source of entropy to ensure that
// encrypting the same message twice doesn't result in the same ciphertext.
// Most applications should use [crypto/rand.Reader] as random.
//
// The label parameter may contain arbitrary data that will not be encrypted,
// but which gives important context to the message. For example, if a given
// public key is used to encrypt two types of messages then distinct label
// values could be used to ensure that a ciphertext for one purpose cannot be
// used for another by an attacker. If not required it can be empty.
//
// The message must be no longer than the length of the public modulus minus
// twice the hash length, minus a further 2.
func EncryptOAEP(hash hash.Hash, random io.Reader, pub *PublicKey, msg []byte, label []byte) ([]byte, error) {
if err := checkPublicKeySize(pub); err != nil {
return nil, err
}
defer hash.Reset()
if boring.Enabled && random == boring.RandReader {
hash.Reset()
k := pub.Size()
if len(msg) > k-2*hash.Size()-2 {
return nil, ErrMessageTooLong
}
bkey, err := boringPublicKey(pub)
if err != nil {
return nil, err
}
return boring.EncryptRSAOAEP(hash, hash, bkey, msg, label)
}
boring.UnreachableExceptTests()
hash = fips140hash.Unwrap(hash)
if err := checkFIPS140OnlyPublicKey(pub); err != nil {
return nil, err
}
if fips140only.Enabled && !fips140only.ApprovedHash(hash) {
return nil, errors.New("crypto/rsa: use of hash functions other than SHA-2 or SHA-3 is not allowed in FIPS 140-only mode")
}
if fips140only.Enabled && !fips140only.ApprovedRandomReader(random) {
return nil, errors.New("crypto/rsa: only crypto/rand.Reader is allowed in FIPS 140-only mode")
}
k, err := fipsPublicKey(pub)
if err != nil {
return nil, err
}
return fipsError2(rsa.EncryptOAEP(hash, hash, random, k, msg, label))
}
// DecryptOAEP decrypts ciphertext using RSA-OAEP.
//
// OAEP is parameterised by a hash function that is used as a random oracle.
// Encryption and decryption of a given message must use the same hash function,
// and sha256.New() is a reasonable choice.
//
// The random parameter is legacy and ignored, and it can be nil.
//
// The label parameter must match the value given when encrypting. See
// [EncryptOAEP] for details.
func DecryptOAEP(hash hash.Hash, random io.Reader, priv *PrivateKey, ciphertext []byte, label []byte) ([]byte, error) {
defer hash.Reset()
return decryptOAEP(hash, hash, priv, ciphertext, label)
}
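// Illustrative sketch (not part of this package): an OAEP encrypt/decrypt
// round trip. The same hash (sha256 here) and label must be used on both
// sides; the random reader passed to DecryptOAEP is legacy and may be nil.
package main

import (
	"crypto/rand"
	"crypto/rsa"
	"crypto/sha256"
	"fmt"
)

func main() {
	priv, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		panic(err)
	}
	label := []byte("orders") // binds the ciphertext to one use of the key

	ciphertext, err := rsa.EncryptOAEP(sha256.New(), rand.Reader, &priv.PublicKey, []byte("secret"), label)
	if err != nil {
		panic(err)
	}
	plaintext, err := rsa.DecryptOAEP(sha256.New(), nil, priv, ciphertext, label)
	if err != nil {
		panic(err)
	}
	fmt.Println(string(plaintext)) // secret
}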
func decryptOAEP(hash, mgfHash hash.Hash, priv *PrivateKey, ciphertext []byte, label []byte) ([]byte, error) {
if err := checkPublicKeySize(&priv.PublicKey); err != nil {
return nil, err
}
if boring.Enabled {
k := priv.Size()
if len(ciphertext) > k ||
k < hash.Size()*2+2 {
return nil, ErrDecryption
}
bkey, err := boringPrivateKey(priv)
if err != nil {
return nil, err
}
out, err := boring.DecryptRSAOAEP(hash, mgfHash, bkey, ciphertext, label)
if err != nil {
return nil, ErrDecryption
}
return out, nil
}
hash = fips140hash.Unwrap(hash)
mgfHash = fips140hash.Unwrap(mgfHash)
if err := checkFIPS140OnlyPrivateKey(priv); err != nil {
return nil, err
}
if fips140only.Enabled {
if !fips140only.ApprovedHash(hash) || !fips140only.ApprovedHash(mgfHash) {
return nil, errors.New("crypto/rsa: use of hash functions other than SHA-2 or SHA-3 is not allowed in FIPS 140-only mode")
}
}
k, err := fipsPrivateKey(priv)
if err != nil {
return nil, err
}
return fipsError2(rsa.DecryptOAEP(hash, mgfHash, k, ciphertext, label))
}
// SignPKCS1v15 calculates the signature of hashed using
// RSASSA-PKCS1-V1_5-SIGN from RSA PKCS #1 v1.5. Note that hashed must
// be the result of hashing the input message using the given hash
// function. If hash is zero, hashed is signed directly. This isn't
// advisable except for interoperability.
//
// The random parameter is legacy and ignored, and it can be nil.
//
// This function is deterministic. Thus, if the set of possible
// messages is small, an attacker may be able to build a map from
// messages to signatures and identify the signed messages. As ever,
// signatures provide authenticity, not confidentiality.
func SignPKCS1v15(random io.Reader, priv *PrivateKey, hash crypto.Hash, hashed []byte) ([]byte, error) {
var hashName string
if hash != crypto.Hash(0) {
if len(hashed) != hash.Size() {
return nil, errors.New("crypto/rsa: input must be hashed message")
}
hashName = hash.String()
}
if err := checkPublicKeySize(&priv.PublicKey); err != nil {
return nil, err
}
if boring.Enabled {
bkey, err := boringPrivateKey(priv)
if err != nil {
return nil, err
}
return boring.SignRSAPKCS1v15(bkey, hash, hashed)
}
if err := checkFIPS140OnlyPrivateKey(priv); err != nil {
return nil, err
}
if fips140only.Enabled && !fips140only.ApprovedHash(fips140hash.Unwrap(hash.New())) {
return nil, errors.New("crypto/rsa: use of hash functions other than SHA-2 or SHA-3 is not allowed in FIPS 140-only mode")
}
k, err := fipsPrivateKey(priv)
if err != nil {
return nil, err
}
return fipsError2(rsa.SignPKCS1v15(k, hashName, hashed))
}
// VerifyPKCS1v15 verifies an RSA PKCS #1 v1.5 signature.
// hashed is the result of hashing the input message using the given hash
// function and sig is the signature. A valid signature is indicated by
// returning a nil error. If hash is zero then hashed is used directly. This
// isn't advisable except for interoperability.
//
// The inputs are not considered confidential, and may leak through timing side
// channels, or if an attacker has control of part of the inputs.
func VerifyPKCS1v15(pub *PublicKey, hash crypto.Hash, hashed []byte, sig []byte) error {
var hashName string
if hash != crypto.Hash(0) {
if len(hashed) != hash.Size() {
return errors.New("crypto/rsa: input must be hashed message")
}
hashName = hash.String()
}
if err := checkPublicKeySize(pub); err != nil {
return err
}
if boring.Enabled {
bkey, err := boringPublicKey(pub)
if err != nil {
return err
}
if err := boring.VerifyRSAPKCS1v15(bkey, hash, hashed, sig); err != nil {
return ErrVerification
}
return nil
}
if err := checkFIPS140OnlyPublicKey(pub); err != nil {
return err
}
if fips140only.Enabled && !fips140only.ApprovedHash(fips140hash.Unwrap(hash.New())) {
return errors.New("crypto/rsa: use of hash functions other than SHA-2 or SHA-3 is not allowed in FIPS 140-only mode")
}
k, err := fipsPublicKey(pub)
if err != nil {
return err
}
return fipsError(rsa.VerifyPKCS1v15(k, hashName, hashed, sig))
}
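// Illustrative sketch (not part of this package): a PKCS #1 v1.5 signature
// over a SHA-256 digest. The random parameter is ignored, so nil is passed.
package main

import (
	"crypto"
	"crypto/rand"
	"crypto/rsa"
	"crypto/sha256"
	"fmt"
)

func main() {
	priv, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		panic(err)
	}
	hashed := sha256.Sum256([]byte("message to be signed"))

	sig, err := rsa.SignPKCS1v15(nil, priv, crypto.SHA256, hashed[:])
	if err != nil {
		panic(err)
	}
	// nil means the signature is valid.
	fmt.Println(rsa.VerifyPKCS1v15(&priv.PublicKey, crypto.SHA256, hashed[:], sig))
}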
func fipsError(err error) error {
switch err {
case rsa.ErrDecryption:
return ErrDecryption
case rsa.ErrVerification:
return ErrVerification
case rsa.ErrMessageTooLong:
return ErrMessageTooLong
}
return err
}
func fipsError2[T any](x T, err error) (T, error) {
return x, fipsError(err)
}
func checkFIPS140OnlyPublicKey(pub *PublicKey) error {
if !fips140only.Enabled {
return nil
}
if pub.N == nil {
return errors.New("crypto/rsa: public key missing N")
}
if pub.N.BitLen() < 2048 {
return errors.New("crypto/rsa: use of keys smaller than 2048 bits is not allowed in FIPS 140-only mode")
}
if pub.N.BitLen()%2 == 1 {
return errors.New("crypto/rsa: use of keys with odd size is not allowed in FIPS 140-only mode")
}
if pub.E <= 1<<16 {
return errors.New("crypto/rsa: use of public exponent <= 2¹⁶ is not allowed in FIPS 140-only mode")
}
if pub.E&1 == 0 {
return errors.New("crypto/rsa: use of even public exponent is not allowed in FIPS 140-only mode")
}
return nil
}
func checkFIPS140OnlyPrivateKey(priv *PrivateKey) error {
if !fips140only.Enabled {
return nil
}
if err := checkFIPS140OnlyPublicKey(&priv.PublicKey); err != nil {
return err
}
if len(priv.Primes) != 2 {
return errors.New("crypto/rsa: use of multi-prime keys is not allowed in FIPS 140-only mode")
}
if priv.Primes[0] == nil || priv.Primes[1] == nil || priv.Primes[0].BitLen() != priv.Primes[1].BitLen() {
return errors.New("crypto/rsa: use of primes of different sizes is not allowed in FIPS 140-only mode")
}
return nil
}
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !boringcrypto
package rsa
import "crypto/internal/boring"
func boringPublicKey(*PublicKey) (*boring.PublicKeyRSA, error) {
panic("boringcrypto: not available")
}
func boringPrivateKey(*PrivateKey) (*boring.PrivateKeyRSA, error) {
panic("boringcrypto: not available")
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package rsa
import (
"crypto/internal/boring"
"crypto/internal/fips140/rsa"
"crypto/internal/fips140only"
"crypto/internal/randutil"
"crypto/subtle"
"errors"
"io"
)
// This file implements encryption and decryption using PKCS #1 v1.5 padding.
// PKCS1v15DecryptOptions is for passing options to PKCS #1 v1.5 decryption using
// the [crypto.Decrypter] interface.
type PKCS1v15DecryptOptions struct {
// SessionKeyLen is the length of the session key that is being
// decrypted. If not zero, then a padding error during decryption will
// cause a random plaintext of this length to be returned rather than
// an error. These alternatives happen in constant time.
SessionKeyLen int
}
// EncryptPKCS1v15 encrypts the given message with RSA and the padding
// scheme from PKCS #1 v1.5. The message must be no longer than the
// length of the public modulus minus 11 bytes.
//
// The random parameter is used as a source of entropy to ensure that
// encrypting the same message twice doesn't result in the same
// ciphertext. Most applications should use [crypto/rand.Reader]
// as random. Note that the returned ciphertext does not depend
// deterministically on the bytes read from random, and may change
// between calls and/or between versions.
//
// WARNING: use of this function to encrypt plaintexts other than
// session keys is dangerous. Use RSA OAEP in new protocols.
func EncryptPKCS1v15(random io.Reader, pub *PublicKey, msg []byte) ([]byte, error) {
if fips140only.Enabled {
return nil, errors.New("crypto/rsa: use of PKCS#1 v1.5 encryption is not allowed in FIPS 140-only mode")
}
if err := checkPublicKeySize(pub); err != nil {
return nil, err
}
randutil.MaybeReadByte(random)
k := pub.Size()
if len(msg) > k-11 {
return nil, ErrMessageTooLong
}
if boring.Enabled && random == boring.RandReader {
bkey, err := boringPublicKey(pub)
if err != nil {
return nil, err
}
return boring.EncryptRSAPKCS1(bkey, msg)
}
boring.UnreachableExceptTests()
// EM = 0x00 || 0x02 || PS || 0x00 || M
em := make([]byte, k)
em[1] = 2
ps, mm := em[2:len(em)-len(msg)-1], em[len(em)-len(msg):]
err := nonZeroRandomBytes(ps, random)
if err != nil {
return nil, err
}
em[len(em)-len(msg)-1] = 0
copy(mm, msg)
if boring.Enabled {
var bkey *boring.PublicKeyRSA
bkey, err = boringPublicKey(pub)
if err != nil {
return nil, err
}
return boring.EncryptRSANoPadding(bkey, em)
}
fk, err := fipsPublicKey(pub)
if err != nil {
return nil, err
}
return rsa.Encrypt(fk, em)
}
// DecryptPKCS1v15 decrypts a plaintext using RSA and the padding scheme from PKCS #1 v1.5.
// The random parameter is legacy and ignored, and it can be nil.
//
// Note that whether this function returns an error or not discloses secret
// information. If an attacker can cause this function to run repeatedly and
// learn whether each instance returned an error then they can decrypt and
// forge signatures as if they had the private key. See
// DecryptPKCS1v15SessionKey for a way of solving this problem.
func DecryptPKCS1v15(random io.Reader, priv *PrivateKey, ciphertext []byte) ([]byte, error) {
if err := checkPublicKeySize(&priv.PublicKey); err != nil {
return nil, err
}
if boring.Enabled {
bkey, err := boringPrivateKey(priv)
if err != nil {
return nil, err
}
out, err := boring.DecryptRSAPKCS1(bkey, ciphertext)
if err != nil {
return nil, ErrDecryption
}
return out, nil
}
valid, out, index, err := decryptPKCS1v15(priv, ciphertext)
if err != nil {
return nil, err
}
if valid == 0 {
return nil, ErrDecryption
}
return out[index:], nil
}
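// Illustrative sketch (not part of this package): PKCS #1 v1.5 encryption of a
// short session key. New protocols should prefer OAEP; this exists mainly for
// interoperability with older systems.
package main

import (
	"crypto/rand"
	"crypto/rsa"
	"fmt"
)

func main() {
	priv, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		panic(err)
	}

	sessionKey := make([]byte, 16)
	rand.Read(sessionKey)

	ciphertext, err := rsa.EncryptPKCS1v15(rand.Reader, &priv.PublicKey, sessionKey)
	if err != nil {
		panic(err)
	}
	// Note: whether DecryptPKCS1v15 returns an error leaks information; see
	// DecryptPKCS1v15SessionKey below for the constant-time alternative.
	recovered, err := rsa.DecryptPKCS1v15(nil, priv, ciphertext)
	if err != nil {
		panic(err)
	}
	fmt.Println(len(recovered) == 16) // true
}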
// DecryptPKCS1v15SessionKey decrypts a session key using RSA and the padding
// scheme from PKCS #1 v1.5. The random parameter is legacy and ignored, and it
// can be nil.
//
// DecryptPKCS1v15SessionKey returns an error if the ciphertext is the wrong
// length or if the ciphertext is greater than the public modulus. Otherwise, no
// error is returned. If the padding is valid, the resulting plaintext message
// is copied into key. Otherwise, key is unchanged. These alternatives occur in
// constant time. It is intended that the user of this function generate a
// random session key beforehand and continue the protocol with the resulting
// value.
//
// Note that if the session key is too small then it may be possible for an
// attacker to brute-force it. If they can do that then they can learn whether a
// random value was used (because it'll be different for the same ciphertext)
// and thus whether the padding was correct. This also defeats the point of this
// function. Using at least a 16-byte key will protect against this attack.
//
// This method implements protections against Bleichenbacher chosen ciphertext
// attacks [0] described in RFC 3218 Section 2.3.2 [1]. While these protections
// make a Bleichenbacher attack significantly more difficult, the protections
// are only effective if the rest of the protocol which uses
// DecryptPKCS1v15SessionKey is designed with these considerations in mind. In
// particular, if any subsequent operations which use the decrypted session key
// leak any information about the key (e.g. whether it is a static or random
// key) then the mitigations are defeated. This method must be used extremely
// carefully, and typically should only be used when absolutely necessary for
// compatibility with an existing protocol (such as TLS) that is designed with
// these properties in mind.
//
// - [0] “Chosen Ciphertext Attacks Against Protocols Based on the RSA Encryption
// Standard PKCS #1”, Daniel Bleichenbacher, Advances in Cryptology (Crypto '98)
// - [1] RFC 3218, Preventing the Million Message Attack on CMS,
// https://www.rfc-editor.org/rfc/rfc3218.html
func DecryptPKCS1v15SessionKey(random io.Reader, priv *PrivateKey, ciphertext []byte, key []byte) error {
if err := checkPublicKeySize(&priv.PublicKey); err != nil {
return err
}
k := priv.Size()
if k-(len(key)+3+8) < 0 {
return ErrDecryption
}
valid, em, index, err := decryptPKCS1v15(priv, ciphertext)
if err != nil {
return err
}
if len(em) != k {
// This should be impossible because decryptPKCS1v15 always
// returns the full slice.
return ErrDecryption
}
valid &= subtle.ConstantTimeEq(int32(len(em)-index), int32(len(key)))
subtle.ConstantTimeCopy(valid, key, em[len(em)-len(key):])
return nil
}
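// Illustrative sketch (not part of this package): the intended
// DecryptPKCS1v15SessionKey pattern. The caller generates a random fallback
// key first; if the padding is invalid, that random key is silently kept, so
// an attacker cannot distinguish good padding from bad.
package main

import (
	"crypto/rand"
	"crypto/rsa"
	"fmt"
)

func main() {
	priv, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		panic(err)
	}

	// Sender side: encrypt a 16-byte session key.
	realKey := make([]byte, 16)
	rand.Read(realKey)
	ciphertext, err := rsa.EncryptPKCS1v15(rand.Reader, &priv.PublicKey, realKey)
	if err != nil {
		panic(err)
	}

	// Receiver side: pre-fill key with random bytes, then decrypt in place.
	key := make([]byte, 16)
	rand.Read(key)
	if err := rsa.DecryptPKCS1v15SessionKey(nil, priv, ciphertext, key); err != nil {
		panic(err) // only for malformed ciphertext lengths, not bad padding
	}
	// key now holds realKey if the padding was valid, or the random fallback
	// otherwise; the protocol continues with key either way.
	fmt.Printf("%x\n", key)
}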
// decryptPKCS1v15 decrypts ciphertext using priv. It returns one or zero in
// valid that indicates whether the plaintext was correctly structured.
// In either case, the plaintext is returned in em so that it may be read
// independently of whether it was valid in order to maintain constant memory
// access patterns. If the plaintext was valid then index contains the index of
// the original message in em, to allow constant time padding removal.
func decryptPKCS1v15(priv *PrivateKey, ciphertext []byte) (valid int, em []byte, index int, err error) {
if fips140only.Enabled {
return 0, nil, 0, errors.New("crypto/rsa: use of PKCS#1 v1.5 encryption is not allowed in FIPS 140-only mode")
}
k := priv.Size()
if k < 11 {
err = ErrDecryption
return 0, nil, 0, err
}
if boring.Enabled {
var bkey *boring.PrivateKeyRSA
bkey, err = boringPrivateKey(priv)
if err != nil {
return 0, nil, 0, err
}
em, err = boring.DecryptRSANoPadding(bkey, ciphertext)
if err != nil {
return 0, nil, 0, ErrDecryption
}
} else {
fk, err := fipsPrivateKey(priv)
if err != nil {
return 0, nil, 0, err
}
em, err = rsa.DecryptWithoutCheck(fk, ciphertext)
if err != nil {
return 0, nil, 0, ErrDecryption
}
}
firstByteIsZero := subtle.ConstantTimeByteEq(em[0], 0)
secondByteIsTwo := subtle.ConstantTimeByteEq(em[1], 2)
// The remainder of the plaintext must be a string of non-zero random
// octets, followed by a 0, followed by the message.
// lookingForIndex: 1 iff we are still looking for the zero.
// index: the offset of the first zero byte.
lookingForIndex := 1
for i := 2; i < len(em); i++ {
equals0 := subtle.ConstantTimeByteEq(em[i], 0)
index = subtle.ConstantTimeSelect(lookingForIndex&equals0, i, index)
lookingForIndex = subtle.ConstantTimeSelect(equals0, 0, lookingForIndex)
}
// The PS padding must be at least 8 bytes long, and it starts two
// bytes into em.
validPS := subtle.ConstantTimeLessOrEq(2+8, index)
valid = firstByteIsZero & secondByteIsTwo & (^lookingForIndex & 1) & validPS
index = subtle.ConstantTimeSelect(valid, index+1, 0)
return valid, em, index, nil
}
// nonZeroRandomBytes fills the given slice with non-zero random octets.
func nonZeroRandomBytes(s []byte, random io.Reader) (err error) {
_, err = io.ReadFull(random, s)
if err != nil {
return
}
for i := 0; i < len(s); i++ {
for s[i] == 0 {
_, err = io.ReadFull(random, s[i:i+1])
if err != nil {
return
}
// In tests, the PRNG may return all zeros so we do
// this to break the loop.
s[i] ^= 0x42
}
}
return
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package rsa implements RSA encryption as specified in PKCS #1 and RFC 8017.
//
// RSA is a single, fundamental operation that is used in this package to
// implement either public-key encryption or public-key signatures.
//
// The original specification for encryption and signatures with RSA is PKCS #1
// and the terms "RSA encryption" and "RSA signatures" by default refer to
// PKCS #1 version 1.5. However, that specification has flaws, and new designs
// should use version 2, usually referred to simply as OAEP and PSS, where
// possible.
//
// Two sets of interfaces are included in this package. When a more abstract
// interface isn't necessary, there are functions for encrypting/decrypting
// with v1.5/OAEP and signing/verifying with v1.5/PSS. If one needs to abstract
// over the public key primitive, the PrivateKey type implements the
// Decrypter and Signer interfaces from the crypto package.
//
// Operations involving private keys are implemented using constant-time
// algorithms, except for [GenerateKey] and for some operations involving
// deprecated multi-prime keys.
//
// # Minimum key size
//
// [GenerateKey] returns an error if a key of less than 1024 bits is requested,
// and all Sign, Verify, Encrypt, and Decrypt methods return an error if used
// with a key smaller than 1024 bits. Such keys are insecure and should not be
// used.
//
// The rsa1024min=0 GODEBUG setting suppresses this error, but we recommend
// doing so only in tests, if necessary. Tests can set this option using
// [testing.T.Setenv] or by including "//go:debug rsa1024min=0" in a *_test.go
// source file.
//
// Alternatively, see the [GenerateKey (TestKey)] example for a pregenerated
// test-only 2048-bit key.
//
// [GenerateKey (TestKey)]: https://pkg.go.dev/crypto/rsa#example-GenerateKey-TestKey
package rsa
import (
"crypto"
"crypto/internal/boring"
"crypto/internal/boring/bbig"
"crypto/internal/fips140/bigmod"
"crypto/internal/fips140/rsa"
"crypto/internal/fips140only"
"crypto/internal/randutil"
"crypto/rand"
"crypto/subtle"
"errors"
"fmt"
"internal/godebug"
"io"
"math"
"math/big"
)
var bigOne = big.NewInt(1)
// A PublicKey represents the public part of an RSA key.
//
// The values of N and E are not considered confidential, and may leak through
// side channels, or could be mathematically derived from other public values.
type PublicKey struct {
N *big.Int // modulus
E int // public exponent
}
// Any methods implemented on PublicKey might need to also be implemented on
// PrivateKey, as the latter embeds the former and will expose its methods.
// Size returns the modulus size in bytes. Raw signatures and ciphertexts
// for or by this public key will have the same size.
func (pub *PublicKey) Size() int {
return (pub.N.BitLen() + 7) / 8
}
// Equal reports whether pub and x have the same value.
func (pub *PublicKey) Equal(x crypto.PublicKey) bool {
xx, ok := x.(*PublicKey)
if !ok {
return false
}
return bigIntEqual(pub.N, xx.N) && pub.E == xx.E
}
// OAEPOptions is an interface for passing options to OAEP decryption using the
// crypto.Decrypter interface.
type OAEPOptions struct {
// Hash is the hash function that will be used when generating the mask.
Hash crypto.Hash
// MGFHash is the hash function used for MGF1.
// If zero, Hash is used instead.
MGFHash crypto.Hash
// Label is an arbitrary byte string that must be equal to the value
// used when encrypting.
Label []byte
}
// A PrivateKey represents an RSA key.
type PrivateKey struct {
PublicKey // public part.
D *big.Int // private exponent
Primes []*big.Int // prime factors of N, has >= 2 elements.
// Precomputed contains precomputed values that speed up RSA operations,
// if available. It must be generated by calling PrivateKey.Precompute and
// must not be modified.
Precomputed PrecomputedValues
}
// Public returns the public key corresponding to priv.
func (priv *PrivateKey) Public() crypto.PublicKey {
return &priv.PublicKey
}
// Equal reports whether priv and x have equivalent values. It ignores
// Precomputed values.
func (priv *PrivateKey) Equal(x crypto.PrivateKey) bool {
xx, ok := x.(*PrivateKey)
if !ok {
return false
}
if !priv.PublicKey.Equal(&xx.PublicKey) || !bigIntEqual(priv.D, xx.D) {
return false
}
if len(priv.Primes) != len(xx.Primes) {
return false
}
for i := range priv.Primes {
if !bigIntEqual(priv.Primes[i], xx.Primes[i]) {
return false
}
}
return true
}
// bigIntEqual reports whether a and b are equal leaking only their bit length
// through timing side-channels.
func bigIntEqual(a, b *big.Int) bool {
return subtle.ConstantTimeCompare(a.Bytes(), b.Bytes()) == 1
}
// Sign signs digest with priv, reading randomness from rand. If opts is a
// *[PSSOptions] then the PSS algorithm will be used, otherwise PKCS #1 v1.5 will
// be used. digest must be the result of hashing the input message using
// opts.HashFunc().
//
// This method implements [crypto.Signer], which is an interface to support keys
// where the private part is kept in, for example, a hardware module. Common
// uses should use the Sign* functions in this package directly.
func (priv *PrivateKey) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, error) {
if pssOpts, ok := opts.(*PSSOptions); ok {
return SignPSS(rand, priv, pssOpts.Hash, digest, pssOpts)
}
return SignPKCS1v15(rand, priv, opts.HashFunc(), digest)
}
// Decrypt decrypts ciphertext with priv. If opts is nil or of type
// *[PKCS1v15DecryptOptions] then PKCS #1 v1.5 decryption is performed. Otherwise
// opts must have type *[OAEPOptions] and OAEP decryption is done.
func (priv *PrivateKey) Decrypt(rand io.Reader, ciphertext []byte, opts crypto.DecrypterOpts) (plaintext []byte, err error) {
if opts == nil {
return DecryptPKCS1v15(rand, priv, ciphertext)
}
switch opts := opts.(type) {
case *OAEPOptions:
if opts.MGFHash == 0 {
return decryptOAEP(opts.Hash.New(), opts.Hash.New(), priv, ciphertext, opts.Label)
} else {
return decryptOAEP(opts.Hash.New(), opts.MGFHash.New(), priv, ciphertext, opts.Label)
}
case *PKCS1v15DecryptOptions:
if l := opts.SessionKeyLen; l > 0 {
plaintext = make([]byte, l)
if _, err := io.ReadFull(rand, plaintext); err != nil {
return nil, err
}
if err := DecryptPKCS1v15SessionKey(rand, priv, ciphertext, plaintext); err != nil {
return nil, err
}
return plaintext, nil
} else {
return DecryptPKCS1v15(rand, priv, ciphertext)
}
default:
return nil, errors.New("crypto/rsa: invalid options for Decrypt")
}
}
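// Illustrative sketch (not part of this package): using a *PrivateKey through
// the generic crypto.Signer and crypto.Decrypter interfaces, as code that also
// supports hardware-backed keys typically does.
package main

import (
	"crypto"
	"crypto/rand"
	"crypto/rsa"
	"crypto/sha256"
	"fmt"
)

func main() {
	priv, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		panic(err)
	}
	var signer crypto.Signer = priv
	var decrypter crypto.Decrypter = priv

	// Sign selects PSS because the options are *rsa.PSSOptions.
	digest := sha256.Sum256([]byte("hello"))
	sig, err := signer.Sign(rand.Reader, digest[:], &rsa.PSSOptions{Hash: crypto.SHA256})
	if err != nil {
		panic(err)
	}

	// Decrypt selects OAEP because the options are *rsa.OAEPOptions.
	ct, err := rsa.EncryptOAEP(sha256.New(), rand.Reader, &priv.PublicKey, []byte("hello"), nil)
	if err != nil {
		panic(err)
	}
	pt, err := decrypter.Decrypt(nil, ct, &rsa.OAEPOptions{Hash: crypto.SHA256})
	if err != nil {
		panic(err)
	}
	fmt.Println(len(sig) > 0, string(pt)) // true hello
}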
type PrecomputedValues struct {
Dp, Dq *big.Int // D mod (P-1) (or mod Q-1)
Qinv *big.Int // Q^-1 mod P
// CRTValues is used for the 3rd and subsequent primes. Due to a
// historical accident, the CRT for the first two primes is handled
// differently in PKCS #1 and interoperability is sufficiently
// important that we mirror this.
//
// Deprecated: These values are still filled in by Precompute for
// backwards compatibility but are not used. Multi-prime RSA is very rare,
// and is implemented by this package without CRT optimizations to limit
// complexity.
CRTValues []CRTValue
fips *rsa.PrivateKey
}
// CRTValue contains the precomputed Chinese remainder theorem values.
type CRTValue struct {
Exp *big.Int // D mod (prime-1).
Coeff *big.Int // R·Coeff ≡ 1 mod Prime.
R *big.Int // product of primes prior to this (including p and q).
}
// Validate performs basic sanity checks on the key.
// It returns nil if the key is valid, or else an error describing a problem.
//
// It runs faster on valid keys if run after [PrivateKey.Precompute].
func (priv *PrivateKey) Validate() error {
// We can operate on keys based on d alone, but it isn't possible to encode
// with [crypto/x509.MarshalPKCS1PrivateKey], which unfortunately doesn't
// return an error.
if len(priv.Primes) < 2 {
return errors.New("crypto/rsa: missing primes")
}
// If Precomputed.fips is set, then the key has been validated by
// [rsa.NewPrivateKey] or [rsa.NewPrivateKeyWithoutCRT].
if priv.Precomputed.fips != nil {
return nil
}
_, err := priv.precompute()
return err
}
// rsa1024min is a GODEBUG that re-enables weak RSA keys if set to "0".
// See https://go.dev/issue/68762.
var rsa1024min = godebug.New("rsa1024min")
func checkKeySize(size int) error {
if size >= 1024 {
return nil
}
if rsa1024min.Value() == "0" {
rsa1024min.IncNonDefault()
return nil
}
return fmt.Errorf("crypto/rsa: %d-bit keys are insecure (see https://go.dev/pkg/crypto/rsa#hdr-Minimum_key_size)", size)
}
func checkPublicKeySize(k *PublicKey) error {
if k.N == nil {
return errors.New("crypto/rsa: missing public modulus")
}
return checkKeySize(k.N.BitLen())
}
// GenerateKey generates a random RSA private key of the given bit size.
//
// If bits is less than 1024, [GenerateKey] returns an error. See the "[Minimum
// key size]" section for further details.
//
// Most applications should use [crypto/rand.Reader] as rand. Note that the
// returned key does not depend deterministically on the bytes read from rand,
// and may change between calls and/or between versions.
//
// [Minimum key size]: https://pkg.go.dev/crypto/rsa#hdr-Minimum_key_size
func GenerateKey(random io.Reader, bits int) (*PrivateKey, error) {
if err := checkKeySize(bits); err != nil {
return nil, err
}
if boring.Enabled && random == boring.RandReader &&
(bits == 2048 || bits == 3072 || bits == 4096) {
bN, bE, bD, bP, bQ, bDp, bDq, bQinv, err := boring.GenerateKeyRSA(bits)
if err != nil {
return nil, err
}
N := bbig.Dec(bN)
E := bbig.Dec(bE)
D := bbig.Dec(bD)
P := bbig.Dec(bP)
Q := bbig.Dec(bQ)
Dp := bbig.Dec(bDp)
Dq := bbig.Dec(bDq)
Qinv := bbig.Dec(bQinv)
e64 := E.Int64()
if !E.IsInt64() || int64(int(e64)) != e64 {
return nil, errors.New("crypto/rsa: generated key exponent too large")
}
key := &PrivateKey{
PublicKey: PublicKey{
N: N,
E: int(e64),
},
D: D,
Primes: []*big.Int{P, Q},
Precomputed: PrecomputedValues{
Dp: Dp,
Dq: Dq,
Qinv: Qinv,
CRTValues: make([]CRTValue, 0), // non-nil, to match Precompute
},
}
return key, nil
}
if fips140only.Enabled && bits < 2048 {
return nil, errors.New("crypto/rsa: use of keys smaller than 2048 bits is not allowed in FIPS 140-only mode")
}
if fips140only.Enabled && bits%2 == 1 {
return nil, errors.New("crypto/rsa: use of keys with odd size is not allowed in FIPS 140-only mode")
}
if fips140only.Enabled && !fips140only.ApprovedRandomReader(random) {
return nil, errors.New("crypto/rsa: only crypto/rand.Reader is allowed in FIPS 140-only mode")
}
k, err := rsa.GenerateKey(random, bits)
if bits < 256 && err != nil {
// Toy-sized keys have a non-negligible chance of hitting two hard
// failure cases: p == q and d <= 2^(nlen / 2).
//
// Since these are impossible to hit for real keys, we don't want to
// make the production code path more complex and harder to think about
// to handle them.
//
// Instead, just rerun the whole process a total of 8 times, which
// brings the chance of failure for 32-bit keys down to the same as for
// 256-bit keys.
for i := 1; i < 8 && err != nil; i++ {
k, err = rsa.GenerateKey(random, bits)
}
}
if err != nil {
return nil, err
}
N, e, d, p, q, dP, dQ, qInv := k.Export()
key := &PrivateKey{
PublicKey: PublicKey{
N: new(big.Int).SetBytes(N),
E: e,
},
D: new(big.Int).SetBytes(d),
Primes: []*big.Int{
new(big.Int).SetBytes(p),
new(big.Int).SetBytes(q),
},
Precomputed: PrecomputedValues{
fips: k,
Dp: new(big.Int).SetBytes(dP),
Dq: new(big.Int).SetBytes(dQ),
Qinv: new(big.Int).SetBytes(qInv),
CRTValues: make([]CRTValue, 0), // non-nil, to match Precompute
},
}
return key, nil
}
// GenerateMultiPrimeKey generates a multi-prime RSA keypair of the given bit
// size and the given random source.
//
// Table 1 in "[On the Security of Multi-prime RSA]" suggests maximum numbers of
// primes for a given bit size.
//
// Although the public keys are compatible with (actually, indistinguishable
// from) those of the 2-prime case, the private keys are not. Thus it may not be
// possible to export multi-prime private keys in certain formats or to
// subsequently import them into other code.
//
// This package does not implement CRT optimizations for multi-prime RSA, so
// keys with more than two primes will have worse performance.
//
// Deprecated: The use of this function with a number of primes different from
// two is not recommended for the above security, compatibility, and performance
// reasons. Use [GenerateKey] instead.
//
// [On the Security of Multi-prime RSA]: http://www.cacr.math.uwaterloo.ca/techreports/2006/cacr2006-16.pdf
func GenerateMultiPrimeKey(random io.Reader, nprimes int, bits int) (*PrivateKey, error) {
if nprimes == 2 {
return GenerateKey(random, bits)
}
if fips140only.Enabled {
return nil, errors.New("crypto/rsa: multi-prime RSA is not allowed in FIPS 140-only mode")
}
randutil.MaybeReadByte(random)
priv := new(PrivateKey)
priv.E = 65537
if nprimes < 2 {
return nil, errors.New("crypto/rsa: GenerateMultiPrimeKey: nprimes must be >= 2")
}
if bits < 64 {
primeLimit := float64(uint64(1) << uint(bits/nprimes))
// pi approximates the number of primes less than primeLimit
pi := primeLimit / (math.Log(primeLimit) - 1)
// Generated primes start with 11 (in binary) so we can only
// use a quarter of them.
pi /= 4
// Use a factor of two to ensure that key generation terminates
// in a reasonable amount of time.
pi /= 2
if pi <= float64(nprimes) {
return nil, errors.New("crypto/rsa: too few primes of given length to generate an RSA key")
}
}
primes := make([]*big.Int, nprimes)
NextSetOfPrimes:
for {
todo := bits
// crypto/rand should set the top two bits in each prime.
// Thus each prime has the form
// p_i = 2^bitlen(p_i) × 0.11... (in base 2).
// And the product is:
// P = 2^todo × α
// where α is the product of nprimes numbers of the form 0.11...
//
// If α < 1/2 (which can happen for nprimes > 2), we need to
// shift todo to compensate for lost bits: the mean value of 0.11...
// is 7/8, so todo + shift - nprimes * log2(7/8) ~= bits - 1/2
// will give good results.
if nprimes >= 7 {
todo += (nprimes - 2) / 5
}
for i := 0; i < nprimes; i++ {
var err error
primes[i], err = rand.Prime(random, todo/(nprimes-i))
if err != nil {
return nil, err
}
todo -= primes[i].BitLen()
}
// Make sure that primes is pairwise unequal.
for i, prime := range primes {
for j := 0; j < i; j++ {
if prime.Cmp(primes[j]) == 0 {
continue NextSetOfPrimes
}
}
}
n := new(big.Int).Set(bigOne)
totient := new(big.Int).Set(bigOne)
pminus1 := new(big.Int)
for _, prime := range primes {
n.Mul(n, prime)
pminus1.Sub(prime, bigOne)
totient.Mul(totient, pminus1)
}
if n.BitLen() != bits {
// This should never happen for nprimes == 2 because
// crypto/rand should set the top two bits in each prime.
// For nprimes > 2 we hope it does not happen often.
continue NextSetOfPrimes
}
priv.D = new(big.Int)
e := big.NewInt(int64(priv.E))
ok := priv.D.ModInverse(e, totient)
if ok != nil {
priv.Primes = primes
priv.N = n
break
}
}
priv.Precompute()
if err := priv.Validate(); err != nil {
return nil, err
}
return priv, nil
}
// ErrMessageTooLong is returned when attempting to encrypt or sign a message
// which is too large for the size of the key. When using [SignPSS], this can also
// be returned if the size of the salt is too large.
var ErrMessageTooLong = errors.New("crypto/rsa: message too long for RSA key size")
// ErrDecryption represents a failure to decrypt a message.
// It is deliberately vague to avoid adaptive attacks.
var ErrDecryption = errors.New("crypto/rsa: decryption error")
// ErrVerification represents a failure to verify a signature.
// It is deliberately vague to avoid adaptive attacks.
var ErrVerification = errors.New("crypto/rsa: verification error")
// Precompute performs some calculations that speed up private key operations
// in the future. It is safe to run on non-validated private keys.
func (priv *PrivateKey) Precompute() {
if priv.Precomputed.fips != nil {
return
}
precomputed, err := priv.precompute()
if err != nil {
// We don't have a way to report errors, so just leave the key
// unmodified. Validate will re-run precompute.
return
}
priv.Precomputed = precomputed
}
func (priv *PrivateKey) precompute() (PrecomputedValues, error) {
var precomputed PrecomputedValues
if priv.N == nil {
return precomputed, errors.New("crypto/rsa: missing public modulus")
}
if priv.D == nil {
return precomputed, errors.New("crypto/rsa: missing private exponent")
}
if len(priv.Primes) != 2 {
return priv.precomputeLegacy()
}
if priv.Primes[0] == nil {
return precomputed, errors.New("crypto/rsa: prime P is nil")
}
if priv.Primes[1] == nil {
return precomputed, errors.New("crypto/rsa: prime Q is nil")
}
// If the CRT values are already set, use them.
if priv.Precomputed.Dp != nil && priv.Precomputed.Dq != nil && priv.Precomputed.Qinv != nil {
k, err := rsa.NewPrivateKeyWithPrecomputation(priv.N.Bytes(), priv.E, priv.D.Bytes(),
priv.Primes[0].Bytes(), priv.Primes[1].Bytes(),
priv.Precomputed.Dp.Bytes(), priv.Precomputed.Dq.Bytes(), priv.Precomputed.Qinv.Bytes())
if err != nil {
return precomputed, err
}
precomputed = priv.Precomputed
precomputed.fips = k
precomputed.CRTValues = make([]CRTValue, 0)
return precomputed, nil
}
k, err := rsa.NewPrivateKey(priv.N.Bytes(), priv.E, priv.D.Bytes(),
priv.Primes[0].Bytes(), priv.Primes[1].Bytes())
if err != nil {
return precomputed, err
}
precomputed.fips = k
_, _, _, _, _, dP, dQ, qInv := k.Export()
precomputed.Dp = new(big.Int).SetBytes(dP)
precomputed.Dq = new(big.Int).SetBytes(dQ)
precomputed.Qinv = new(big.Int).SetBytes(qInv)
precomputed.CRTValues = make([]CRTValue, 0)
return precomputed, nil
}
func (priv *PrivateKey) precomputeLegacy() (PrecomputedValues, error) {
var precomputed PrecomputedValues
k, err := rsa.NewPrivateKeyWithoutCRT(priv.N.Bytes(), priv.E, priv.D.Bytes())
if err != nil {
return precomputed, err
}
precomputed.fips = k
if len(priv.Primes) < 2 {
return precomputed, nil
}
// Ensure the Mod and ModInverse calls below don't panic.
for _, prime := range priv.Primes {
if prime == nil {
return precomputed, errors.New("crypto/rsa: prime factor is nil")
}
if prime.Cmp(bigOne) <= 0 {
return precomputed, errors.New("crypto/rsa: prime factor is <= 1")
}
}
precomputed.Dp = new(big.Int).Sub(priv.Primes[0], bigOne)
precomputed.Dp.Mod(priv.D, precomputed.Dp)
precomputed.Dq = new(big.Int).Sub(priv.Primes[1], bigOne)
precomputed.Dq.Mod(priv.D, precomputed.Dq)
precomputed.Qinv = new(big.Int).ModInverse(priv.Primes[1], priv.Primes[0])
if precomputed.Qinv == nil {
return precomputed, errors.New("crypto/rsa: prime factors are not relatively prime")
}
r := new(big.Int).Mul(priv.Primes[0], priv.Primes[1])
precomputed.CRTValues = make([]CRTValue, len(priv.Primes)-2)
for i := 2; i < len(priv.Primes); i++ {
prime := priv.Primes[i]
values := &precomputed.CRTValues[i-2]
values.Exp = new(big.Int).Sub(prime, bigOne)
values.Exp.Mod(priv.D, values.Exp)
values.R = new(big.Int).Set(r)
values.Coeff = new(big.Int).ModInverse(r, prime)
if values.Coeff == nil {
return precomputed, errors.New("crypto/rsa: prime factors are not relatively prime")
}
r.Mul(r, prime)
}
return precomputed, nil
}
func fipsPublicKey(pub *PublicKey) (*rsa.PublicKey, error) {
N, err := bigmod.NewModulus(pub.N.Bytes())
if err != nil {
return nil, err
}
return &rsa.PublicKey{N: N, E: pub.E}, nil
}
func fipsPrivateKey(priv *PrivateKey) (*rsa.PrivateKey, error) {
if priv.Precomputed.fips != nil {
return priv.Precomputed.fips, nil
}
precomputed, err := priv.precompute()
if err != nil {
return nil, err
}
return precomputed.fips, nil
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package sha1 implements the SHA-1 hash algorithm as defined in RFC 3174.
//
// SHA-1 is cryptographically broken and should not be used for secure
// applications.
package sha1
import (
"crypto"
"crypto/internal/boring"
"crypto/internal/fips140only"
"errors"
"hash"
"internal/byteorder"
)
func init() {
crypto.RegisterHash(crypto.SHA1, New)
}
// The size of a SHA-1 checksum in bytes.
const Size = 20
// The blocksize of SHA-1 in bytes.
const BlockSize = 64
const (
chunk = 64
init0 = 0x67452301
init1 = 0xEFCDAB89
init2 = 0x98BADCFE
init3 = 0x10325476
init4 = 0xC3D2E1F0
)
// digest represents the partial evaluation of a checksum.
type digest struct {
h [5]uint32
x [chunk]byte
nx int
len uint64
}
const (
magic = "sha\x01"
marshaledSize = len(magic) + 5*4 + chunk + 8
)
func (d *digest) MarshalBinary() ([]byte, error) {
return d.AppendBinary(make([]byte, 0, marshaledSize))
}
func (d *digest) AppendBinary(b []byte) ([]byte, error) {
b = append(b, magic...)
b = byteorder.BEAppendUint32(b, d.h[0])
b = byteorder.BEAppendUint32(b, d.h[1])
b = byteorder.BEAppendUint32(b, d.h[2])
b = byteorder.BEAppendUint32(b, d.h[3])
b = byteorder.BEAppendUint32(b, d.h[4])
b = append(b, d.x[:d.nx]...)
b = append(b, make([]byte, len(d.x)-d.nx)...)
b = byteorder.BEAppendUint64(b, d.len)
return b, nil
}
func (d *digest) UnmarshalBinary(b []byte) error {
if len(b) < len(magic) || string(b[:len(magic)]) != magic {
return errors.New("crypto/sha1: invalid hash state identifier")
}
if len(b) != marshaledSize {
return errors.New("crypto/sha1: invalid hash state size")
}
b = b[len(magic):]
b, d.h[0] = consumeUint32(b)
b, d.h[1] = consumeUint32(b)
b, d.h[2] = consumeUint32(b)
b, d.h[3] = consumeUint32(b)
b, d.h[4] = consumeUint32(b)
b = b[copy(d.x[:], b):]
b, d.len = consumeUint64(b)
d.nx = int(d.len % chunk)
return nil
}
func consumeUint64(b []byte) ([]byte, uint64) {
return b[8:], byteorder.BEUint64(b)
}
func consumeUint32(b []byte) ([]byte, uint32) {
return b[4:], byteorder.BEUint32(b)
}
func (d *digest) Clone() (hash.Cloner, error) {
r := *d
return &r, nil
}
func (d *digest) Reset() {
d.h[0] = init0
d.h[1] = init1
d.h[2] = init2
d.h[3] = init3
d.h[4] = init4
d.nx = 0
d.len = 0
}
// New returns a new [hash.Hash] computing the SHA1 checksum. The Hash
// also implements [encoding.BinaryMarshaler], [encoding.BinaryAppender] and
// [encoding.BinaryUnmarshaler] to marshal and unmarshal the internal
// state of the hash.
func New() hash.Hash {
if boring.Enabled {
return boring.NewSHA1()
}
d := new(digest)
d.Reset()
return d
}
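// Illustrative sketch (not part of the original source): because the value
// returned by New also implements encoding.BinaryMarshaler and
// encoding.BinaryUnmarshaler, a caller can checkpoint a partially written
// hash and resume it later. The snippet assumes the caller imports the
// sha1, encoding, fmt, and io packages.
//
//	h := sha1.New()
//	io.WriteString(h, "hello, ")
//	state, err := h.(encoding.BinaryMarshaler).MarshalBinary()
//	if err != nil {
//		panic(err)
//	}
//	h2 := sha1.New()
//	if err := h2.(encoding.BinaryUnmarshaler).UnmarshalBinary(state); err != nil {
//		panic(err)
//	}
//	io.WriteString(h2, "world")
//	fmt.Printf("%x\n", h2.Sum(nil)) // same digest as hashing "hello, world" in one pass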
func (d *digest) Size() int { return Size }
func (d *digest) BlockSize() int { return BlockSize }
func (d *digest) Write(p []byte) (nn int, err error) {
if fips140only.Enabled {
return 0, errors.New("crypto/sha1: use of SHA-1 is not allowed in FIPS 140-only mode")
}
boring.Unreachable()
nn = len(p)
d.len += uint64(nn)
if d.nx > 0 {
n := copy(d.x[d.nx:], p)
d.nx += n
if d.nx == chunk {
block(d, d.x[:])
d.nx = 0
}
p = p[n:]
}
if len(p) >= chunk {
n := len(p) &^ (chunk - 1)
block(d, p[:n])
p = p[n:]
}
if len(p) > 0 {
d.nx = copy(d.x[:], p)
}
return
}
func (d *digest) Sum(in []byte) []byte {
boring.Unreachable()
// Make a copy of d so that caller can keep writing and summing.
d0 := *d
hash := d0.checkSum()
return append(in, hash[:]...)
}
func (d *digest) checkSum() [Size]byte {
if fips140only.Enabled {
panic("crypto/sha1: use of SHA-1 is not allowed in FIPS 140-only mode")
}
len := d.len
// Padding. Add a 1 bit and 0 bits until 56 bytes mod 64.
var tmp [64 + 8]byte // padding + length buffer
tmp[0] = 0x80
var t uint64
if len%64 < 56 {
t = 56 - len%64
} else {
t = 64 + 56 - len%64
}
// Length in bits.
len <<= 3
padlen := tmp[:t+8]
byteorder.BEPutUint64(padlen[t:], len)
d.Write(padlen)
if d.nx != 0 {
panic("d.nx != 0")
}
var digest [Size]byte
byteorder.BEPutUint32(digest[0:], d.h[0])
byteorder.BEPutUint32(digest[4:], d.h[1])
byteorder.BEPutUint32(digest[8:], d.h[2])
byteorder.BEPutUint32(digest[12:], d.h[3])
byteorder.BEPutUint32(digest[16:], d.h[4])
return digest
}
// ConstantTimeSum computes the same result as [Sum], but in constant time.
func (d *digest) ConstantTimeSum(in []byte) []byte {
d0 := *d
hash := d0.constSum()
return append(in, hash[:]...)
}
func (d *digest) constSum() [Size]byte {
if fips140only.Enabled {
panic("crypto/sha1: use of SHA-1 is not allowed in FIPS 140-only mode")
}
var length [8]byte
l := d.len << 3
for i := uint(0); i < 8; i++ {
length[i] = byte(l >> (56 - 8*i))
}
nx := byte(d.nx)
t := nx - 56 // if nx < 56 then the MSB of t is one
mask1b := byte(int8(t) >> 7) // mask1b is 0xFF iff one block is enough
separator := byte(0x80) // gets reset to 0x00 once used
for i := byte(0); i < chunk; i++ {
mask := byte(int8(i-nx) >> 7) // 0x00 after the end of data
// if we reached the end of the data, replace with 0x80 or 0x00
d.x[i] = (^mask & separator) | (mask & d.x[i])
// zero the separator once used
separator &= mask
if i >= 56 {
// we might have to write the length here if all fit in one block
d.x[i] |= mask1b & length[i-56]
}
}
// compress, and only keep the digest if all fit in one block
block(d, d.x[:])
var digest [Size]byte
for i, s := range d.h {
digest[i*4] = mask1b & byte(s>>24)
digest[i*4+1] = mask1b & byte(s>>16)
digest[i*4+2] = mask1b & byte(s>>8)
digest[i*4+3] = mask1b & byte(s)
}
for i := byte(0); i < chunk; i++ {
// second block, it's always past the end of data, might start with 0x80
if i < 56 {
d.x[i] = separator
separator = 0
} else {
d.x[i] = length[i-56]
}
}
// compress, and only keep the digest if we actually needed the second block
block(d, d.x[:])
for i, s := range d.h {
digest[i*4] |= ^mask1b & byte(s>>24)
digest[i*4+1] |= ^mask1b & byte(s>>16)
digest[i*4+2] |= ^mask1b & byte(s>>8)
digest[i*4+3] |= ^mask1b & byte(s)
}
return digest
}
// Sum returns the SHA-1 checksum of the data.
func Sum(data []byte) [Size]byte {
if boring.Enabled {
return boring.SHA1(data)
}
if fips140only.Enabled {
panic("crypto/sha1: use of SHA-1 is not allowed in FIPS 140-only mode")
}
var d digest
d.Reset()
d.Write(data)
return d.checkSum()
}
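// Illustrative sketch (not part of the original source): the one-shot Sum is
// equivalent to streaming the same bytes through a hash created with New,
// which is convenient when the input is already in memory (assumes the
// bytes, fmt, and sha1 packages are imported).
//
//	data := []byte("hello, world")
//	oneShot := sha1.Sum(data)
//	h := sha1.New()
//	h.Write(data)
//	streaming := h.Sum(nil)
//	fmt.Println(bytes.Equal(oneShot[:], streaming)) // true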
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package sha1
import (
"math/bits"
)
const (
_K0 = 0x5A827999
_K1 = 0x6ED9EBA1
_K2 = 0x8F1BBCDC
_K3 = 0xCA62C1D6
)
// blockGeneric is a portable, pure Go version of the SHA-1 block step.
// It's used by sha1block_generic.go and tests.
func blockGeneric(dig *digest, p []byte) {
var w [16]uint32
h0, h1, h2, h3, h4 := dig.h[0], dig.h[1], dig.h[2], dig.h[3], dig.h[4]
for len(p) >= chunk {
// Can interlace the computation of w with the
// rounds below if needed for speed.
for i := 0; i < 16; i++ {
j := i * 4
w[i] = uint32(p[j])<<24 | uint32(p[j+1])<<16 | uint32(p[j+2])<<8 | uint32(p[j+3])
}
a, b, c, d, e := h0, h1, h2, h3, h4
// Each of the four 20-iteration rounds
// differs only in the computation of f and
// the choice of K (_K0, _K1, etc).
i := 0
for ; i < 16; i++ {
f := b&c | (^b)&d
t := bits.RotateLeft32(a, 5) + f + e + w[i&0xf] + _K0
a, b, c, d, e = t, a, bits.RotateLeft32(b, 30), c, d
}
for ; i < 20; i++ {
tmp := w[(i-3)&0xf] ^ w[(i-8)&0xf] ^ w[(i-14)&0xf] ^ w[(i)&0xf]
w[i&0xf] = bits.RotateLeft32(tmp, 1)
f := b&c | (^b)&d
t := bits.RotateLeft32(a, 5) + f + e + w[i&0xf] + _K0
a, b, c, d, e = t, a, bits.RotateLeft32(b, 30), c, d
}
for ; i < 40; i++ {
tmp := w[(i-3)&0xf] ^ w[(i-8)&0xf] ^ w[(i-14)&0xf] ^ w[(i)&0xf]
w[i&0xf] = bits.RotateLeft32(tmp, 1)
f := b ^ c ^ d
t := bits.RotateLeft32(a, 5) + f + e + w[i&0xf] + _K1
a, b, c, d, e = t, a, bits.RotateLeft32(b, 30), c, d
}
for ; i < 60; i++ {
tmp := w[(i-3)&0xf] ^ w[(i-8)&0xf] ^ w[(i-14)&0xf] ^ w[(i)&0xf]
w[i&0xf] = bits.RotateLeft32(tmp, 1)
f := ((b | c) & d) | (b & c)
t := bits.RotateLeft32(a, 5) + f + e + w[i&0xf] + _K2
a, b, c, d, e = t, a, bits.RotateLeft32(b, 30), c, d
}
for ; i < 80; i++ {
tmp := w[(i-3)&0xf] ^ w[(i-8)&0xf] ^ w[(i-14)&0xf] ^ w[(i)&0xf]
w[i&0xf] = bits.RotateLeft32(tmp, 1)
f := b ^ c ^ d
t := bits.RotateLeft32(a, 5) + f + e + w[i&0xf] + _K3
a, b, c, d, e = t, a, bits.RotateLeft32(b, 30), c, d
}
h0 += a
h1 += b
h2 += c
h3 += d
h4 += e
p = p[chunk:]
}
dig.h[0], dig.h[1], dig.h[2], dig.h[3], dig.h[4] = h0, h1, h2, h3, h4
}
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !purego
package sha1
import (
"crypto/internal/impl"
"internal/cpu"
)
//go:noescape
func blockAVX2(dig *digest, p []byte)
//go:noescape
func blockSHANI(dig *digest, p []byte)
var useAVX2 = cpu.X86.HasAVX && cpu.X86.HasAVX2 && cpu.X86.HasBMI1 && cpu.X86.HasBMI2
var useSHANI = cpu.X86.HasAVX && cpu.X86.HasSHA && cpu.X86.HasSSE41 && cpu.X86.HasSSSE3
func init() {
impl.Register("sha1", "AVX2", &useAVX2)
impl.Register("sha1", "SHA-NI", &useSHANI)
}
func block(dig *digest, p []byte) {
if useSHANI {
blockSHANI(dig, p)
} else if useAVX2 && len(p) >= 256 {
// blockAVX2 computes SHA-1 for two blocks per iteration and also
// interleaves precalculation for the next block, so it may read up to
// 192 bytes past the end of p. We could add bounds checks inside
// blockAVX2, but that would essentially turn it back into the old
// pre-AVX2 amd64 SHA-1 assembly, so the unsafe tail is handled by
// blockGeneric instead.
safeLen := len(p) - 128
if safeLen%128 != 0 {
safeLen -= 64
}
blockAVX2(dig, p[:safeLen])
blockGeneric(dig, p[safeLen:])
} else {
blockGeneric(dig, p)
}
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package sha256 implements the SHA224 and SHA256 hash algorithms as defined
// in FIPS 180-4.
package sha256
import (
"crypto"
"crypto/internal/boring"
"crypto/internal/fips140/sha256"
"hash"
)
func init() {
crypto.RegisterHash(crypto.SHA224, New224)
crypto.RegisterHash(crypto.SHA256, New)
}
// The size of a SHA256 checksum in bytes.
const Size = 32
// The size of a SHA224 checksum in bytes.
const Size224 = 28
// The blocksize of SHA256 and SHA224 in bytes.
const BlockSize = 64
// New returns a new [hash.Hash] computing the SHA256 checksum. The Hash
// also implements [encoding.BinaryMarshaler], [encoding.BinaryAppender] and
// [encoding.BinaryUnmarshaler] to marshal and unmarshal the internal
// state of the hash.
func New() hash.Hash {
if boring.Enabled {
return boring.NewSHA256()
}
return sha256.New()
}
// New224 returns a new [hash.Hash] computing the SHA224 checksum. The Hash
// also implements [encoding.BinaryMarshaler], [encoding.BinaryAppender] and
// [encoding.BinaryUnmarshaler] to marshal and unmarshal the internal
// state of the hash.
func New224() hash.Hash {
if boring.Enabled {
return boring.NewSHA224()
}
return sha256.New224()
}
// Sum256 returns the SHA256 checksum of the data.
func Sum256(data []byte) [Size]byte {
if boring.Enabled {
return boring.SHA256(data)
}
h := New()
h.Write(data)
var sum [Size]byte
h.Sum(sum[:0])
return sum
}
// Sum224 returns the SHA224 checksum of the data.
func Sum224(data []byte) [Size224]byte {
if boring.Enabled {
return boring.SHA224(data)
}
h := New224()
h.Write(data)
var sum [Size224]byte
h.Sum(sum[:0])
return sum
}
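// Illustrative sketch (not part of the original source): for large inputs the
// streaming interface avoids holding everything in memory; io.Copy feeds any
// io.Reader into the hash. The file name below is an arbitrary example, and
// the fmt, io, os, and sha256 packages are assumed to be imported.
//
//	f, err := os.Open("large-input.bin") // hypothetical input file
//	if err != nil {
//		panic(err)
//	}
//	defer f.Close()
//	h := sha256.New()
//	if _, err := io.Copy(h, f); err != nil {
//		panic(err)
//	}
//	fmt.Printf("%x\n", h.Sum(nil))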
// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package sha3 implements the SHA-3 hash algorithms and the SHAKE extendable
// output functions defined in FIPS 202.
package sha3
import (
"crypto"
"crypto/internal/fips140/sha3"
"hash"
_ "unsafe"
)
func init() {
crypto.RegisterHash(crypto.SHA3_224, func() hash.Hash { return New224() })
crypto.RegisterHash(crypto.SHA3_256, func() hash.Hash { return New256() })
crypto.RegisterHash(crypto.SHA3_384, func() hash.Hash { return New384() })
crypto.RegisterHash(crypto.SHA3_512, func() hash.Hash { return New512() })
}
// Sum224 returns the SHA3-224 hash of data.
func Sum224(data []byte) [28]byte {
var out [28]byte
h := sha3.New224()
h.Write(data)
h.Sum(out[:0])
return out
}
// Sum256 returns the SHA3-256 hash of data.
func Sum256(data []byte) [32]byte {
var out [32]byte
h := sha3.New256()
h.Write(data)
h.Sum(out[:0])
return out
}
// Sum384 returns the SHA3-384 hash of data.
func Sum384(data []byte) [48]byte {
var out [48]byte
h := sha3.New384()
h.Write(data)
h.Sum(out[:0])
return out
}
// Sum512 returns the SHA3-512 hash of data.
func Sum512(data []byte) [64]byte {
var out [64]byte
h := sha3.New512()
h.Write(data)
h.Sum(out[:0])
return out
}
// SumSHAKE128 applies the SHAKE128 extendable output function to data and
// returns an output of the given length in bytes.
func SumSHAKE128(data []byte, length int) []byte {
// Outline the allocation for up to 256 bits of output to the caller's stack.
out := make([]byte, 32)
return sumSHAKE128(out, data, length)
}
func sumSHAKE128(out, data []byte, length int) []byte {
if len(out) < length {
out = make([]byte, length)
} else {
out = out[:length]
}
h := sha3.NewShake128()
h.Write(data)
h.Read(out)
return out
}
// SumSHAKE256 applies the SHAKE256 extendable output function to data and
// returns an output of the given length in bytes.
func SumSHAKE256(data []byte, length int) []byte {
// Outline the allocation for up to 512 bits of output to the caller's stack.
out := make([]byte, 64)
return sumSHAKE256(out, data, length)
}
func sumSHAKE256(out, data []byte, length int) []byte {
if len(out) < length {
out = make([]byte, length)
} else {
out = out[:length]
}
h := sha3.NewShake256()
h.Write(data)
h.Read(out)
return out
}
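// Illustrative sketch (not part of the original source): the SHAKE functions
// produce as many output bytes as requested, either in one shot or by
// reading incrementally from the streaming XOF.
//
//	out16 := sha3.SumSHAKE128([]byte("input"), 16) // 16 bytes of output
//	out64 := sha3.SumSHAKE128([]byte("input"), 64) // 64 bytes from the same input
//	_, _ = out16, out64
//	x := sha3.NewSHAKE128()
//	x.Write([]byte("input"))
//	buf := make([]byte, 32)
//	x.Read(buf) // further Reads continue squeezing more output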
// SHA3 is an instance of a SHA-3 hash. It implements [hash.Hash].
type SHA3 struct {
s sha3.Digest
}
//go:linkname fips140hash_sha3Unwrap crypto/internal/fips140hash.sha3Unwrap
func fips140hash_sha3Unwrap(sha3 *SHA3) *sha3.Digest {
return &sha3.s
}
// New224 creates a new SHA3-224 hash.
func New224() *SHA3 {
return &SHA3{*sha3.New224()}
}
// New256 creates a new SHA3-256 hash.
func New256() *SHA3 {
return &SHA3{*sha3.New256()}
}
// New384 creates a new SHA3-384 hash.
func New384() *SHA3 {
return &SHA3{*sha3.New384()}
}
// New512 creates a new SHA3-512 hash.
func New512() *SHA3 {
return &SHA3{*sha3.New512()}
}
// Write absorbs more data into the hash's state.
func (s *SHA3) Write(p []byte) (n int, err error) {
return s.s.Write(p)
}
// Sum appends the current hash to b and returns the resulting slice.
func (s *SHA3) Sum(b []byte) []byte {
return s.s.Sum(b)
}
// Reset resets the hash to its initial state.
func (s *SHA3) Reset() {
s.s.Reset()
}
// Size returns the number of bytes Sum will produce.
func (s *SHA3) Size() int {
return s.s.Size()
}
// BlockSize returns the hash's rate.
func (s *SHA3) BlockSize() int {
return s.s.BlockSize()
}
// MarshalBinary implements [encoding.BinaryMarshaler].
func (s *SHA3) MarshalBinary() ([]byte, error) {
return s.s.MarshalBinary()
}
// AppendBinary implements [encoding.BinaryAppender].
func (s *SHA3) AppendBinary(p []byte) ([]byte, error) {
return s.s.AppendBinary(p)
}
// UnmarshalBinary implements [encoding.BinaryUnmarshaler].
func (s *SHA3) UnmarshalBinary(data []byte) error {
return s.s.UnmarshalBinary(data)
}
// Clone implements [hash.Cloner].
func (d *SHA3) Clone() (hash.Cloner, error) {
r := *d
return &r, nil
}
// SHAKE is an instance of a SHAKE extendable output function.
type SHAKE struct {
s sha3.SHAKE
}
// NewSHAKE128 creates a new SHAKE128 XOF.
func NewSHAKE128() *SHAKE {
return &SHAKE{*sha3.NewShake128()}
}
// NewSHAKE256 creates a new SHAKE256 XOF.
func NewSHAKE256() *SHAKE {
return &SHAKE{*sha3.NewShake256()}
}
// NewCSHAKE128 creates a new cSHAKE128 XOF.
//
// N is used to define functions based on cSHAKE; it can be empty when plain
// cSHAKE is desired. S is a customization byte string used for domain
// separation. When N and S are both empty, this is equivalent to NewSHAKE128.
func NewCSHAKE128(N, S []byte) *SHAKE {
return &SHAKE{*sha3.NewCShake128(N, S)}
}
// NewCSHAKE256 creates a new cSHAKE256 XOF.
//
// N is used to define functions based on cSHAKE; it can be empty when plain
// cSHAKE is desired. S is a customization byte string used for domain
// separation. When N and S are both empty, this is equivalent to NewSHAKE256.
func NewCSHAKE256(N, S []byte) *SHAKE {
return &SHAKE{*sha3.NewCShake256(N, S)}
}
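// Illustrative sketch (not part of the original source): the customization
// string S provides domain separation, so the same input absorbed under two
// different customization strings yields unrelated outputs. The strings
// below are arbitrary examples.
//
//	a := sha3.NewCSHAKE128(nil, []byte("key derivation"))
//	b := sha3.NewCSHAKE128(nil, []byte("commitment"))
//	a.Write([]byte("same input"))
//	b.Write([]byte("same input"))
//	outA := make([]byte, 32)
//	outB := make([]byte, 32)
//	a.Read(outA)
//	b.Read(outB)
//	// outA and outB are unrelated pseudorandom values.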
// Write absorbs more data into the XOF's state.
//
// It panics if any output has already been read.
func (s *SHAKE) Write(p []byte) (n int, err error) {
return s.s.Write(p)
}
// Read squeezes more output from the XOF.
//
// Any call to Write after a call to Read will panic.
func (s *SHAKE) Read(p []byte) (n int, err error) {
return s.s.Read(p)
}
// Reset resets the XOF to its initial state.
func (s *SHAKE) Reset() {
s.s.Reset()
}
// BlockSize returns the rate of the XOF.
func (s *SHAKE) BlockSize() int {
return s.s.BlockSize()
}
// MarshalBinary implements [encoding.BinaryMarshaler].
func (s *SHAKE) MarshalBinary() ([]byte, error) {
return s.s.MarshalBinary()
}
// AppendBinary implements [encoding.BinaryAppender].
func (s *SHAKE) AppendBinary(p []byte) ([]byte, error) {
return s.s.AppendBinary(p)
}
// UnmarshalBinary implements [encoding.BinaryUnmarshaler].
func (s *SHAKE) UnmarshalBinary(data []byte) error {
return s.s.UnmarshalBinary(data)
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package sha512 implements the SHA-384, SHA-512, SHA-512/224, and SHA-512/256
// hash algorithms as defined in FIPS 180-4.
//
// All the hash.Hash implementations returned by this package also
// implement encoding.BinaryMarshaler and encoding.BinaryUnmarshaler to
// marshal and unmarshal the internal state of the hash.
package sha512
import (
"crypto"
"crypto/internal/boring"
"crypto/internal/fips140/sha512"
"hash"
)
func init() {
crypto.RegisterHash(crypto.SHA384, New384)
crypto.RegisterHash(crypto.SHA512, New)
crypto.RegisterHash(crypto.SHA512_224, New512_224)
crypto.RegisterHash(crypto.SHA512_256, New512_256)
}
const (
// Size is the size, in bytes, of a SHA-512 checksum.
Size = 64
// Size224 is the size, in bytes, of a SHA-512/224 checksum.
Size224 = 28
// Size256 is the size, in bytes, of a SHA-512/256 checksum.
Size256 = 32
// Size384 is the size, in bytes, of a SHA-384 checksum.
Size384 = 48
// BlockSize is the block size, in bytes, of the SHA-512/224,
// SHA-512/256, SHA-384 and SHA-512 hash functions.
BlockSize = 128
)
// New returns a new [hash.Hash] computing the SHA-512 checksum. The Hash
// also implements [encoding.BinaryMarshaler], [encoding.BinaryAppender] and
// [encoding.BinaryUnmarshaler] to marshal and unmarshal the internal
// state of the hash.
func New() hash.Hash {
if boring.Enabled {
return boring.NewSHA512()
}
return sha512.New()
}
// New512_224 returns a new [hash.Hash] computing the SHA-512/224 checksum. The Hash
// also implements [encoding.BinaryMarshaler], [encoding.BinaryAppender] and
// [encoding.BinaryUnmarshaler] to marshal and unmarshal the internal
// state of the hash.
func New512_224() hash.Hash {
return sha512.New512_224()
}
// New512_256 returns a new [hash.Hash] computing the SHA-512/256 checksum. The Hash
// also implements [encoding.BinaryMarshaler], [encoding.BinaryAppender] and
// [encoding.BinaryUnmarshaler] to marshal and unmarshal the internal
// state of the hash.
func New512_256() hash.Hash {
return sha512.New512_256()
}
// New384 returns a new [hash.Hash] computing the SHA-384 checksum. The Hash
// also implements [encoding.BinaryMarshaler], [encoding.BinaryAppender] and
// [encoding.BinaryUnmarshaler] to marshal and unmarshal the internal
// state of the hash.
func New384() hash.Hash {
if boring.Enabled {
return boring.NewSHA384()
}
return sha512.New384()
}
// Sum512 returns the SHA512 checksum of the data.
func Sum512(data []byte) [Size]byte {
if boring.Enabled {
return boring.SHA512(data)
}
h := New()
h.Write(data)
var sum [Size]byte
h.Sum(sum[:0])
return sum
}
// Sum384 returns the SHA384 checksum of the data.
func Sum384(data []byte) [Size384]byte {
if boring.Enabled {
return boring.SHA384(data)
}
h := New384()
h.Write(data)
var sum [Size384]byte
h.Sum(sum[:0])
return sum
}
// Sum512_224 returns the SHA-512/224 checksum of the data.
func Sum512_224(data []byte) [Size224]byte {
h := New512_224()
h.Write(data)
var sum [Size224]byte
h.Sum(sum[:0])
return sum
}
// Sum512_256 returns the SHA-512/256 checksum of the data.
func Sum512_256(data []byte) [Size256]byte {
h := New512_256()
h.Write(data)
var sum [Size256]byte
h.Sum(sum[:0])
return sum
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package subtle implements functions that are often useful in cryptographic
// code but require careful thought to use correctly.
package subtle
import "crypto/internal/fips140/subtle"
// ConstantTimeCompare returns 1 if the two slices, x and y, have equal contents
// and 0 otherwise. The time taken is a function of the length of the slices and
// is independent of the contents. If the lengths of x and y do not match it
// returns 0 immediately.
func ConstantTimeCompare(x, y []byte) int {
return subtle.ConstantTimeCompare(x, y)
}
// ConstantTimeSelect returns x if v == 1 and y if v == 0.
// Its behavior is undefined if v takes any other value.
func ConstantTimeSelect(v, x, y int) int {
return subtle.ConstantTimeSelect(v, x, y)
}
// ConstantTimeByteEq returns 1 if x == y and 0 otherwise.
func ConstantTimeByteEq(x, y uint8) int {
return subtle.ConstantTimeByteEq(x, y)
}
// ConstantTimeEq returns 1 if x == y and 0 otherwise.
func ConstantTimeEq(x, y int32) int {
return subtle.ConstantTimeEq(x, y)
}
// ConstantTimeCopy copies the contents of y into x (a slice of equal length)
// if v == 1. If v == 0, x is left unchanged. Its behavior is undefined if v
// takes any other value.
func ConstantTimeCopy(v int, x, y []byte) {
subtle.ConstantTimeCopy(v, x, y)
}
// ConstantTimeLessOrEq returns 1 if x <= y and 0 otherwise.
// Its behavior is undefined if x or y are negative or > 2**31 - 1.
func ConstantTimeLessOrEq(x, y int) int {
return subtle.ConstantTimeLessOrEq(x, y)
}
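// Illustrative sketch (not part of the original source): a typical use of
// ConstantTimeCompare is checking a received authenticator against a locally
// computed one without leaking, via timing, how many leading bytes match.
// computeTag below is a hypothetical helper.
//
//	expected := computeTag(key, message) // hypothetical
//	if subtle.ConstantTimeCompare(expected, receivedTag) == 1 {
//		// authenticator is valid
//	} else {
//		// reject the message
//	}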
// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package subtle
import (
"internal/runtime/sys"
"runtime"
)
// WithDataIndependentTiming enables architecture-specific features which ensure
// that the timing of specific instructions is independent of their inputs
// before executing f. When f returns, it disables these features.
//
// WithDataIndependentTiming should only be used when f is written to make use
// of constant-time operations. WithDataIndependentTiming does not make
// variable-time code constant-time.
//
// WithDataIndependentTiming may lock the current goroutine to the OS thread for
// the duration of f. Calls to WithDataIndependentTiming may be nested.
//
// On Arm64 processors with FEAT_DIT, WithDataIndependentTiming enables
// PSTATE.DIT. See https://developer.arm.com/documentation/ka005181/1-0/?lang=en.
//
// Currently, on all other architectures WithDataIndependentTiming executes f immediately
// with no other side-effects.
//
//go:noinline
func WithDataIndependentTiming(f func()) {
if !sys.DITSupported {
f()
return
}
runtime.LockOSThread()
defer runtime.UnlockOSThread()
alreadyEnabled := sys.EnableDIT()
// disableDIT is called in a deferred function so that if f panics we will
// still disable DIT, in case the panic is recovered further up the stack.
defer func() {
if !alreadyEnabled {
sys.DisableDIT()
}
}()
f()
}
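// Illustrative sketch (not part of the original source): callers wrap only
// code that is already written to be constant-time; on supported CPUs the
// relevant instruction timings stay input-independent for the duration of f,
// and elsewhere f simply runs unchanged. processSecret below is a
// hypothetical constant-time routine.
//
//	subtle.WithDataIndependentTiming(func() {
//		processSecret(key, input) // hypothetical constant-time routine
//	})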
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package subtle
import "crypto/internal/fips140/subtle"
// XORBytes sets dst[i] = x[i] ^ y[i] for all i < n = min(len(x), len(y)),
// returning n, the number of bytes written to dst.
//
// If dst does not have length at least n,
// XORBytes panics without writing anything to dst.
//
// dst and x or y may overlap exactly or not at all,
// otherwise XORBytes may panic.
func XORBytes(dst, x, y []byte) int {
return subtle.XORBytes(dst, x, y)
}
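// Illustrative sketch (not part of the original source): XORBytes applies one
// buffer to another in a single call, for example XORing a keystream into a
// plaintext; here the destination is sized to the inputs, so n == 4.
//
//	keystream := []byte{0x1f, 0x2e, 0x3d, 0x4c}
//	plaintext := []byte{0xde, 0xad, 0xbe, 0xef}
//	ciphertext := make([]byte, len(plaintext))
//	n := subtle.XORBytes(ciphertext, plaintext, keystream)
//	_ = n // n == 4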
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tls
import "strconv"
// An AlertError is a TLS alert.
//
// When using a QUIC transport, QUICConn methods will return an error
// which wraps AlertError rather than sending a TLS alert.
type AlertError uint8
func (e AlertError) Error() string {
return alert(e).String()
}
type alert uint8
const (
// alert level
alertLevelWarning = 1
alertLevelError = 2
)
const (
alertCloseNotify alert = 0
alertUnexpectedMessage alert = 10
alertBadRecordMAC alert = 20
alertDecryptionFailed alert = 21
alertRecordOverflow alert = 22
alertDecompressionFailure alert = 30
alertHandshakeFailure alert = 40
alertBadCertificate alert = 42
alertUnsupportedCertificate alert = 43
alertCertificateRevoked alert = 44
alertCertificateExpired alert = 45
alertCertificateUnknown alert = 46
alertIllegalParameter alert = 47
alertUnknownCA alert = 48
alertAccessDenied alert = 49
alertDecodeError alert = 50
alertDecryptError alert = 51
alertExportRestriction alert = 60
alertProtocolVersion alert = 70
alertInsufficientSecurity alert = 71
alertInternalError alert = 80
alertInappropriateFallback alert = 86
alertUserCanceled alert = 90
alertNoRenegotiation alert = 100
alertMissingExtension alert = 109
alertUnsupportedExtension alert = 110
alertCertificateUnobtainable alert = 111
alertUnrecognizedName alert = 112
alertBadCertificateStatusResponse alert = 113
alertBadCertificateHashValue alert = 114
alertUnknownPSKIdentity alert = 115
alertCertificateRequired alert = 116
alertNoApplicationProtocol alert = 120
alertECHRequired alert = 121
)
var alertText = map[alert]string{
alertCloseNotify: "close notify",
alertUnexpectedMessage: "unexpected message",
alertBadRecordMAC: "bad record MAC",
alertDecryptionFailed: "decryption failed",
alertRecordOverflow: "record overflow",
alertDecompressionFailure: "decompression failure",
alertHandshakeFailure: "handshake failure",
alertBadCertificate: "bad certificate",
alertUnsupportedCertificate: "unsupported certificate",
alertCertificateRevoked: "revoked certificate",
alertCertificateExpired: "expired certificate",
alertCertificateUnknown: "unknown certificate",
alertIllegalParameter: "illegal parameter",
alertUnknownCA: "unknown certificate authority",
alertAccessDenied: "access denied",
alertDecodeError: "error decoding message",
alertDecryptError: "error decrypting message",
alertExportRestriction: "export restriction",
alertProtocolVersion: "protocol version not supported",
alertInsufficientSecurity: "insufficient security level",
alertInternalError: "internal error",
alertInappropriateFallback: "inappropriate fallback",
alertUserCanceled: "user canceled",
alertNoRenegotiation: "no renegotiation",
alertMissingExtension: "missing extension",
alertUnsupportedExtension: "unsupported extension",
alertCertificateUnobtainable: "certificate unobtainable",
alertUnrecognizedName: "unrecognized name",
alertBadCertificateStatusResponse: "bad certificate status response",
alertBadCertificateHashValue: "bad certificate hash value",
alertUnknownPSKIdentity: "unknown PSK identity",
alertCertificateRequired: "certificate required",
alertNoApplicationProtocol: "no application protocol",
alertECHRequired: "encrypted client hello required",
}
func (e alert) String() string {
s, ok := alertText[e]
if ok {
return "tls: " + s
}
return "tls: alert(" + strconv.Itoa(int(e)) + ")"
}
func (e alert) Error() string {
return e.String()
}
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tls
import (
"bytes"
"crypto"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/elliptic"
"crypto/rsa"
"errors"
"fmt"
"hash"
"io"
"slices"
)
// verifyHandshakeSignature verifies a signature against pre-hashed
// (if required) handshake contents.
func verifyHandshakeSignature(sigType uint8, pubkey crypto.PublicKey, hashFunc crypto.Hash, signed, sig []byte) error {
switch sigType {
case signatureECDSA:
pubKey, ok := pubkey.(*ecdsa.PublicKey)
if !ok {
return fmt.Errorf("expected an ECDSA public key, got %T", pubkey)
}
if !ecdsa.VerifyASN1(pubKey, signed, sig) {
return errors.New("ECDSA verification failure")
}
case signatureEd25519:
pubKey, ok := pubkey.(ed25519.PublicKey)
if !ok {
return fmt.Errorf("expected an Ed25519 public key, got %T", pubkey)
}
if !ed25519.Verify(pubKey, signed, sig) {
return errors.New("Ed25519 verification failure")
}
case signaturePKCS1v15:
pubKey, ok := pubkey.(*rsa.PublicKey)
if !ok {
return fmt.Errorf("expected an RSA public key, got %T", pubkey)
}
if err := rsa.VerifyPKCS1v15(pubKey, hashFunc, signed, sig); err != nil {
return err
}
case signatureRSAPSS:
pubKey, ok := pubkey.(*rsa.PublicKey)
if !ok {
return fmt.Errorf("expected an RSA public key, got %T", pubkey)
}
signOpts := &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash}
if err := rsa.VerifyPSS(pubKey, hashFunc, signed, sig, signOpts); err != nil {
return err
}
default:
return errors.New("internal error: unknown signature type")
}
return nil
}
const (
serverSignatureContext = "TLS 1.3, server CertificateVerify\x00"
clientSignatureContext = "TLS 1.3, client CertificateVerify\x00"
)
var signaturePadding = []byte{
0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20,
0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20,
0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20,
0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20,
0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20,
0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20,
0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20,
0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20,
}
// signedMessage returns the pre-hashed (if necessary) message to be signed by
// certificate keys in TLS 1.3. See RFC 8446, Section 4.4.3.
func signedMessage(sigHash crypto.Hash, context string, transcript hash.Hash) []byte {
if sigHash == directSigning {
b := &bytes.Buffer{}
b.Write(signaturePadding)
io.WriteString(b, context)
b.Write(transcript.Sum(nil))
return b.Bytes()
}
h := sigHash.New()
h.Write(signaturePadding)
io.WriteString(h, context)
h.Write(transcript.Sum(nil))
return h.Sum(nil)
}
// typeAndHashFromSignatureScheme returns the corresponding signature type and
// crypto.Hash for a given TLS SignatureScheme.
func typeAndHashFromSignatureScheme(signatureAlgorithm SignatureScheme) (sigType uint8, hash crypto.Hash, err error) {
switch signatureAlgorithm {
case PKCS1WithSHA1, PKCS1WithSHA256, PKCS1WithSHA384, PKCS1WithSHA512:
sigType = signaturePKCS1v15
case PSSWithSHA256, PSSWithSHA384, PSSWithSHA512:
sigType = signatureRSAPSS
case ECDSAWithSHA1, ECDSAWithP256AndSHA256, ECDSAWithP384AndSHA384, ECDSAWithP521AndSHA512:
sigType = signatureECDSA
case Ed25519:
sigType = signatureEd25519
default:
return 0, 0, fmt.Errorf("unsupported signature algorithm: %v", signatureAlgorithm)
}
switch signatureAlgorithm {
case PKCS1WithSHA1, ECDSAWithSHA1:
hash = crypto.SHA1
case PKCS1WithSHA256, PSSWithSHA256, ECDSAWithP256AndSHA256:
hash = crypto.SHA256
case PKCS1WithSHA384, PSSWithSHA384, ECDSAWithP384AndSHA384:
hash = crypto.SHA384
case PKCS1WithSHA512, PSSWithSHA512, ECDSAWithP521AndSHA512:
hash = crypto.SHA512
case Ed25519:
hash = directSigning
default:
return 0, 0, fmt.Errorf("unsupported signature algorithm: %v", signatureAlgorithm)
}
return sigType, hash, nil
}
// legacyTypeAndHashFromPublicKey returns the fixed signature type and crypto.Hash for
// a given public key used with TLS 1.0 and 1.1, before the introduction of
// signature algorithm negotiation.
func legacyTypeAndHashFromPublicKey(pub crypto.PublicKey) (sigType uint8, hash crypto.Hash, err error) {
switch pub.(type) {
case *rsa.PublicKey:
return signaturePKCS1v15, crypto.MD5SHA1, nil
case *ecdsa.PublicKey:
return signatureECDSA, crypto.SHA1, nil
case ed25519.PublicKey:
// RFC 8422 specifies support for Ed25519 in TLS 1.0 and 1.1,
// but it requires holding on to a handshake transcript to do a
// full signature, and not even OpenSSL bothers with the
// complexity, so we can't even test it properly.
return 0, 0, fmt.Errorf("tls: Ed25519 public keys are not supported before TLS 1.2")
default:
return 0, 0, fmt.Errorf("tls: unsupported public key: %T", pub)
}
}
var rsaSignatureSchemes = []struct {
scheme SignatureScheme
minModulusBytes int
}{
// RSA-PSS is used with PSSSaltLengthEqualsHash, and requires
// emLen >= hLen + sLen + 2
{PSSWithSHA256, crypto.SHA256.Size()*2 + 2},
{PSSWithSHA384, crypto.SHA384.Size()*2 + 2},
{PSSWithSHA512, crypto.SHA512.Size()*2 + 2},
// PKCS #1 v1.5 uses prefixes from hashPrefixes in crypto/rsa, and requires
// emLen >= len(prefix) + hLen + 11
{PKCS1WithSHA256, 19 + crypto.SHA256.Size() + 11},
{PKCS1WithSHA384, 19 + crypto.SHA384.Size() + 11},
{PKCS1WithSHA512, 19 + crypto.SHA512.Size() + 11},
{PKCS1WithSHA1, 15 + crypto.SHA1.Size() + 11},
}
func signatureSchemesForPublicKey(version uint16, pub crypto.PublicKey) []SignatureScheme {
switch pub := pub.(type) {
case *ecdsa.PublicKey:
if version < VersionTLS13 {
// In TLS 1.2 and earlier, ECDSA algorithms are not
// constrained to a single curve.
return []SignatureScheme{
ECDSAWithP256AndSHA256,
ECDSAWithP384AndSHA384,
ECDSAWithP521AndSHA512,
ECDSAWithSHA1,
}
}
switch pub.Curve {
case elliptic.P256():
return []SignatureScheme{ECDSAWithP256AndSHA256}
case elliptic.P384():
return []SignatureScheme{ECDSAWithP384AndSHA384}
case elliptic.P521():
return []SignatureScheme{ECDSAWithP521AndSHA512}
default:
return nil
}
case *rsa.PublicKey:
size := pub.Size()
sigAlgs := make([]SignatureScheme, 0, len(rsaSignatureSchemes))
for _, candidate := range rsaSignatureSchemes {
if size >= candidate.minModulusBytes {
sigAlgs = append(sigAlgs, candidate.scheme)
}
}
return sigAlgs
case ed25519.PublicKey:
return []SignatureScheme{Ed25519}
default:
return nil
}
}
// selectSignatureScheme picks a SignatureScheme from the peer's preference list
// that works with the selected certificate. It's only called for protocol
// versions that support signature algorithms, so TLS 1.2 and 1.3.
func selectSignatureScheme(vers uint16, c *Certificate, peerAlgs []SignatureScheme) (SignatureScheme, error) {
priv, ok := c.PrivateKey.(crypto.Signer)
if !ok {
return 0, unsupportedCertificateError(c)
}
supportedAlgs := signatureSchemesForPublicKey(vers, priv.Public())
if c.SupportedSignatureAlgorithms != nil {
supportedAlgs = slices.DeleteFunc(supportedAlgs, func(sigAlg SignatureScheme) bool {
return !isSupportedSignatureAlgorithm(sigAlg, c.SupportedSignatureAlgorithms)
})
}
// Filter out any unsupported signature algorithms, for example due to
// FIPS 140-3 policy, tlssha1=0, or protocol version.
supportedAlgs = slices.DeleteFunc(supportedAlgs, func(sigAlg SignatureScheme) bool {
return isDisabledSignatureAlgorithm(vers, sigAlg, false)
})
if len(supportedAlgs) == 0 {
return 0, unsupportedCertificateError(c)
}
if len(peerAlgs) == 0 && vers == VersionTLS12 {
// For TLS 1.2, if the client didn't send signature_algorithms then we
// can assume that it supports SHA1. See RFC 5246, Section 7.4.1.4.1.
// RFC 9155 made signature_algorithms mandatory in TLS 1.2, and we gated
// it behind the tlssha1 GODEBUG setting.
if tlssha1.Value() != "1" {
return 0, errors.New("tls: missing signature_algorithms from TLS 1.2 peer")
}
peerAlgs = []SignatureScheme{PKCS1WithSHA1, ECDSAWithSHA1}
}
// Pick signature scheme in the peer's preference order, as our
// preference order is not configurable.
for _, preferredAlg := range peerAlgs {
if isSupportedSignatureAlgorithm(preferredAlg, supportedAlgs) {
return preferredAlg, nil
}
}
return 0, errors.New("tls: peer doesn't support any of the certificate's signature algorithms")
}
// unsupportedCertificateError returns a helpful error for certificates with
// an unsupported private key.
func unsupportedCertificateError(cert *Certificate) error {
switch cert.PrivateKey.(type) {
case rsa.PrivateKey, ecdsa.PrivateKey:
return fmt.Errorf("tls: unsupported certificate: private key is %T, expected *%T",
cert.PrivateKey, cert.PrivateKey)
case *ed25519.PrivateKey:
return fmt.Errorf("tls: unsupported certificate: private key is *ed25519.PrivateKey, expected ed25519.PrivateKey")
}
signer, ok := cert.PrivateKey.(crypto.Signer)
if !ok {
return fmt.Errorf("tls: certificate private key (%T) does not implement crypto.Signer",
cert.PrivateKey)
}
switch pub := signer.Public().(type) {
case *ecdsa.PublicKey:
switch pub.Curve {
case elliptic.P256():
case elliptic.P384():
case elliptic.P521():
default:
return fmt.Errorf("tls: unsupported certificate curve (%s)", pub.Curve.Params().Name)
}
case *rsa.PublicKey:
return fmt.Errorf("tls: certificate RSA key size too small for supported signature algorithms")
case ed25519.PublicKey:
default:
return fmt.Errorf("tls: unsupported certificate key (%T)", pub)
}
if cert.SupportedSignatureAlgorithms != nil {
return fmt.Errorf("tls: peer doesn't support the certificate custom signature algorithms")
}
return fmt.Errorf("tls: internal error: unsupported key (%T)", cert.PrivateKey)
}
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tls
import (
"crypto/x509"
"runtime"
"sync"
"weak"
)
// weakCertCache provides a cache of *x509.Certificates, allowing multiple
// connections to reuse parsed certificates, instead of re-parsing the
// certificate for every connection, which is an expensive operation.
type weakCertCache struct{ sync.Map }
func (wcc *weakCertCache) newCert(der []byte) (*x509.Certificate, error) {
if entry, ok := wcc.Load(string(der)); ok {
if v := entry.(weak.Pointer[x509.Certificate]).Value(); v != nil {
return v, nil
}
}
cert, err := x509.ParseCertificate(der)
if err != nil {
return nil, err
}
wp := weak.Make(cert)
if entry, loaded := wcc.LoadOrStore(string(der), wp); !loaded {
runtime.AddCleanup(cert, func(_ any) { wcc.CompareAndDelete(string(der), entry) }, any(string(der)))
} else if v := entry.(weak.Pointer[x509.Certificate]).Value(); v != nil {
return v, nil
} else {
if wcc.CompareAndSwap(string(der), entry, wp) {
runtime.AddCleanup(cert, func(_ any) { wcc.CompareAndDelete(string(der), wp) }, any(string(der)))
}
}
return cert, nil
}
var globalCertCache = new(weakCertCache)
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tls
import (
"crypto"
"crypto/aes"
"crypto/cipher"
"crypto/des"
"crypto/hmac"
"crypto/internal/boring"
fipsaes "crypto/internal/fips140/aes"
"crypto/internal/fips140/aes/gcm"
"crypto/rc4"
"crypto/sha1"
"crypto/sha256"
"fmt"
"hash"
"internal/cpu"
"runtime"
_ "unsafe" // for linkname
"golang.org/x/crypto/chacha20poly1305"
)
// CipherSuite is a TLS cipher suite. Note that most functions in this package
// accept and expose cipher suite IDs instead of this type.
type CipherSuite struct {
ID uint16
Name string
// SupportedVersions is the list of TLS protocol versions that can
// negotiate this cipher suite.
SupportedVersions []uint16
// Insecure is true if the cipher suite has known security issues
// due to its primitives, design, or implementation.
Insecure bool
}
var (
supportedUpToTLS12 = []uint16{VersionTLS10, VersionTLS11, VersionTLS12}
supportedOnlyTLS12 = []uint16{VersionTLS12}
supportedOnlyTLS13 = []uint16{VersionTLS13}
)
// CipherSuites returns a list of cipher suites currently implemented by this
// package, excluding those with security issues, which are returned by
// [InsecureCipherSuites].
//
// The list is sorted by ID. Note that the default cipher suites selected by
// this package might depend on logic that can't be captured by a static list,
// and might not match those returned by this function.
func CipherSuites() []*CipherSuite {
return []*CipherSuite{
{TLS_AES_128_GCM_SHA256, "TLS_AES_128_GCM_SHA256", supportedOnlyTLS13, false},
{TLS_AES_256_GCM_SHA384, "TLS_AES_256_GCM_SHA384", supportedOnlyTLS13, false},
{TLS_CHACHA20_POLY1305_SHA256, "TLS_CHACHA20_POLY1305_SHA256", supportedOnlyTLS13, false},
{TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA", supportedUpToTLS12, false},
{TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, "TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA", supportedUpToTLS12, false},
{TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA", supportedUpToTLS12, false},
{TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA", supportedUpToTLS12, false},
{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256", supportedOnlyTLS12, false},
{TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384", supportedOnlyTLS12, false},
{TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256", supportedOnlyTLS12, false},
{TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384", supportedOnlyTLS12, false},
{TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256", supportedOnlyTLS12, false},
{TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256", supportedOnlyTLS12, false},
}
}
// InsecureCipherSuites returns a list of cipher suites currently implemented by
// this package and which have security issues.
//
// Most applications should not use the cipher suites in this list, and should
// only use those returned by [CipherSuites].
func InsecureCipherSuites() []*CipherSuite {
// This list includes legacy RSA kex, RC4, CBC_SHA256, and 3DES cipher
// suites. See cipherSuitesPreferenceOrder for details.
return []*CipherSuite{
{TLS_RSA_WITH_RC4_128_SHA, "TLS_RSA_WITH_RC4_128_SHA", supportedUpToTLS12, true},
{TLS_RSA_WITH_3DES_EDE_CBC_SHA, "TLS_RSA_WITH_3DES_EDE_CBC_SHA", supportedUpToTLS12, true},
{TLS_RSA_WITH_AES_128_CBC_SHA, "TLS_RSA_WITH_AES_128_CBC_SHA", supportedUpToTLS12, true},
{TLS_RSA_WITH_AES_256_CBC_SHA, "TLS_RSA_WITH_AES_256_CBC_SHA", supportedUpToTLS12, true},
{TLS_RSA_WITH_AES_128_CBC_SHA256, "TLS_RSA_WITH_AES_128_CBC_SHA256", supportedOnlyTLS12, true},
{TLS_RSA_WITH_AES_128_GCM_SHA256, "TLS_RSA_WITH_AES_128_GCM_SHA256", supportedOnlyTLS12, true},
{TLS_RSA_WITH_AES_256_GCM_SHA384, "TLS_RSA_WITH_AES_256_GCM_SHA384", supportedOnlyTLS12, true},
{TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, "TLS_ECDHE_ECDSA_WITH_RC4_128_SHA", supportedUpToTLS12, true},
{TLS_ECDHE_RSA_WITH_RC4_128_SHA, "TLS_ECDHE_RSA_WITH_RC4_128_SHA", supportedUpToTLS12, true},
{TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, "TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA", supportedUpToTLS12, true},
{TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256", supportedOnlyTLS12, true},
{TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256", supportedOnlyTLS12, true},
}
}
// CipherSuiteName returns the standard name for the passed cipher suite ID
// (e.g. "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256"), or a fallback representation
// of the ID value if the cipher suite is not implemented by this package.
func CipherSuiteName(id uint16) string {
for _, c := range CipherSuites() {
if c.ID == id {
return c.Name
}
}
for _, c := range InsecureCipherSuites() {
if c.ID == id {
return c.Name
}
}
return fmt.Sprintf("0x%04X", id)
}
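// Illustrative sketch (not part of the original source): CipherSuites,
// InsecureCipherSuites, and CipherSuiteName can be combined to present
// human-readable suite information, for example when logging a negotiated
// connection. conn below is assumed to be an established *tls.Conn.
//
//	state := conn.ConnectionState()
//	fmt.Println(tls.CipherSuiteName(state.CipherSuite))
//	for _, s := range tls.CipherSuites() {
//		fmt.Printf("%#04x %s (TLS versions: %v)\n", s.ID, s.Name, s.SupportedVersions)
//	}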
const (
// suiteECDHE indicates that the cipher suite involves elliptic curve
// Diffie-Hellman. This means that it should only be selected when the
// client indicates that it supports ECC with a curve and point format
// that we're happy with.
suiteECDHE = 1 << iota
// suiteECSign indicates that the cipher suite involves an ECDSA or
// EdDSA signature and therefore may only be selected when the server's
// certificate is ECDSA or EdDSA. If this is not set then the cipher suite
// is RSA based.
suiteECSign
// suiteTLS12 indicates that the cipher suite should only be advertised
// and accepted when using TLS 1.2.
suiteTLS12
// suiteSHA384 indicates that the cipher suite uses SHA384 as the
// handshake hash.
suiteSHA384
)
// A cipherSuite is a TLS 1.0–1.2 cipher suite, and defines the key exchange
// mechanism, as well as the cipher+MAC pair or the AEAD.
type cipherSuite struct {
id uint16
// the lengths, in bytes, of the key material needed for each component.
keyLen int
macLen int
ivLen int
ka func(version uint16) keyAgreement
// flags is a bitmask of the suite* values, above.
flags int
cipher func(key, iv []byte, isRead bool) any
mac func(key []byte) hash.Hash
aead func(key, fixedNonce []byte) aead
}
var cipherSuites = []*cipherSuite{ // TODO: replace with a map, since the order doesn't matter.
{TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, 32, 0, 12, ecdheRSAKA, suiteECDHE | suiteTLS12, nil, nil, aeadChaCha20Poly1305},
{TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, 32, 0, 12, ecdheECDSAKA, suiteECDHE | suiteECSign | suiteTLS12, nil, nil, aeadChaCha20Poly1305},
{TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, 16, 0, 4, ecdheRSAKA, suiteECDHE | suiteTLS12, nil, nil, aeadAESGCM},
{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, 16, 0, 4, ecdheECDSAKA, suiteECDHE | suiteECSign | suiteTLS12, nil, nil, aeadAESGCM},
{TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, 32, 0, 4, ecdheRSAKA, suiteECDHE | suiteTLS12 | suiteSHA384, nil, nil, aeadAESGCM},
{TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, 32, 0, 4, ecdheECDSAKA, suiteECDHE | suiteECSign | suiteTLS12 | suiteSHA384, nil, nil, aeadAESGCM},
{TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, 16, 32, 16, ecdheRSAKA, suiteECDHE | suiteTLS12, cipherAES, macSHA256, nil},
{TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, 16, 20, 16, ecdheRSAKA, suiteECDHE, cipherAES, macSHA1, nil},
{TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, 16, 32, 16, ecdheECDSAKA, suiteECDHE | suiteECSign | suiteTLS12, cipherAES, macSHA256, nil},
{TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, 16, 20, 16, ecdheECDSAKA, suiteECDHE | suiteECSign, cipherAES, macSHA1, nil},
{TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, 32, 20, 16, ecdheRSAKA, suiteECDHE, cipherAES, macSHA1, nil},
{TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, 32, 20, 16, ecdheECDSAKA, suiteECDHE | suiteECSign, cipherAES, macSHA1, nil},
{TLS_RSA_WITH_AES_128_GCM_SHA256, 16, 0, 4, rsaKA, suiteTLS12, nil, nil, aeadAESGCM},
{TLS_RSA_WITH_AES_256_GCM_SHA384, 32, 0, 4, rsaKA, suiteTLS12 | suiteSHA384, nil, nil, aeadAESGCM},
{TLS_RSA_WITH_AES_128_CBC_SHA256, 16, 32, 16, rsaKA, suiteTLS12, cipherAES, macSHA256, nil},
{TLS_RSA_WITH_AES_128_CBC_SHA, 16, 20, 16, rsaKA, 0, cipherAES, macSHA1, nil},
{TLS_RSA_WITH_AES_256_CBC_SHA, 32, 20, 16, rsaKA, 0, cipherAES, macSHA1, nil},
{TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, 24, 20, 8, ecdheRSAKA, suiteECDHE, cipher3DES, macSHA1, nil},
{TLS_RSA_WITH_3DES_EDE_CBC_SHA, 24, 20, 8, rsaKA, 0, cipher3DES, macSHA1, nil},
{TLS_RSA_WITH_RC4_128_SHA, 16, 20, 0, rsaKA, 0, cipherRC4, macSHA1, nil},
{TLS_ECDHE_RSA_WITH_RC4_128_SHA, 16, 20, 0, ecdheRSAKA, suiteECDHE, cipherRC4, macSHA1, nil},
{TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, 16, 20, 0, ecdheECDSAKA, suiteECDHE | suiteECSign, cipherRC4, macSHA1, nil},
}
// selectCipherSuite returns the first TLS 1.0–1.2 cipher suite from ids which
// is also in supportedIDs and passes the ok filter.
func selectCipherSuite(ids, supportedIDs []uint16, ok func(*cipherSuite) bool) *cipherSuite {
for _, id := range ids {
candidate := cipherSuiteByID(id)
if candidate == nil || !ok(candidate) {
continue
}
for _, suppID := range supportedIDs {
if id == suppID {
return candidate
}
}
}
return nil
}
// A cipherSuiteTLS13 defines only the pair of the AEAD algorithm and hash
// algorithm to be used with HKDF. See RFC 8446, Appendix B.4.
type cipherSuiteTLS13 struct {
id uint16
keyLen int
aead func(key, fixedNonce []byte) aead
hash crypto.Hash
}
// cipherSuitesTLS13 should be an internal detail,
// but widely used packages access it using linkname.
// Notable members of the hall of shame include:
// - github.com/quic-go/quic-go
// - github.com/sagernet/quic-go
//
// Do not remove or change the type signature.
// See go.dev/issue/67401.
//
//go:linkname cipherSuitesTLS13
var cipherSuitesTLS13 = []*cipherSuiteTLS13{ // TODO: replace with a map.
{TLS_AES_128_GCM_SHA256, 16, aeadAESGCMTLS13, crypto.SHA256},
{TLS_CHACHA20_POLY1305_SHA256, 32, aeadChaCha20Poly1305, crypto.SHA256},
{TLS_AES_256_GCM_SHA384, 32, aeadAESGCMTLS13, crypto.SHA384},
}
// cipherSuitesPreferenceOrder is the order in which we'll select (on the
// server) or advertise (on the client) TLS 1.0–1.2 cipher suites.
//
// Cipher suites are filtered but not reordered based on the application and
// peer's preferences, meaning we'll never select a suite lower in this list if
// any higher one is available. This makes it more defensible to keep weaker
// cipher suites enabled, especially on the server side where we get the last
// word, since there are no known downgrade attacks on cipher suite selection.
//
// The list is sorted by applying the following priority rules, stopping at the
// first (most important) applicable one:
//
// - Anything else comes before RC4
//
// RC4 has practically exploitable biases. See https://www.rc4nomore.com.
//
// - Anything else comes before CBC_SHA256
//
// SHA-256 variants of the CBC ciphersuites don't implement any Lucky13
// countermeasures. See https://www.isg.rhul.ac.uk/tls/Lucky13.html and
// https://www.imperialviolet.org/2013/02/04/luckythirteen.html.
//
// - Anything else comes before 3DES
//
// 3DES has 64-bit blocks, which makes it fundamentally susceptible to
// birthday attacks. See https://sweet32.info.
//
// - ECDHE comes before anything else
//
// Once the broken stuff is out of the way, the most important
// property a cipher suite can have is forward secrecy. We don't
// implement FFDHE, so that means ECDHE.
//
// - AEADs come before CBC ciphers
//
// Even with Lucky13 countermeasures, MAC-then-Encrypt CBC cipher suites
// are fundamentally fragile, and suffered from an endless sequence of
// padding oracle attacks. See https://eprint.iacr.org/2015/1129,
// https://www.imperialviolet.org/2014/12/08/poodleagain.html, and
// https://blog.cloudflare.com/yet-another-padding-oracle-in-openssl-cbc-ciphersuites/.
//
// - AES comes before ChaCha20
//
// When AES hardware is available, AES-128-GCM and AES-256-GCM are faster
// than ChaCha20Poly1305.
//
// When AES hardware is not available, AES-128-GCM is one or more of: much
// slower, way more complex, and less safe (because not constant time)
// than ChaCha20Poly1305.
//
// We use this list if we think both peers have AES hardware, and
// cipherSuitesPreferenceOrderNoAES otherwise.
//
// - AES-128 comes before AES-256
//
// The only potential advantages of AES-256 are better multi-target
// margins, and hypothetical post-quantum properties. Neither apply to
// TLS, and AES-256 is slower due to its four extra rounds (which don't
// contribute to the advantages above).
//
// - ECDSA comes before RSA
//
// The relative order of ECDSA and RSA cipher suites doesn't matter,
// as they depend on the certificate. Pick one to get a stable order.
var cipherSuitesPreferenceOrder = []uint16{
// AEADs w/ ECDHE
TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
// CBC w/ ECDHE
TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
// AEADs w/o ECDHE
TLS_RSA_WITH_AES_128_GCM_SHA256,
TLS_RSA_WITH_AES_256_GCM_SHA384,
// CBC w/o ECDHE
TLS_RSA_WITH_AES_128_CBC_SHA,
TLS_RSA_WITH_AES_256_CBC_SHA,
// 3DES
TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA,
TLS_RSA_WITH_3DES_EDE_CBC_SHA,
// CBC_SHA256
TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256,
TLS_RSA_WITH_AES_128_CBC_SHA256,
// RC4
TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA,
TLS_RSA_WITH_RC4_128_SHA,
}
var cipherSuitesPreferenceOrderNoAES = []uint16{
// ChaCha20Poly1305
TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
// AES-GCM w/ ECDHE
TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
// The rest of cipherSuitesPreferenceOrder.
TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
TLS_RSA_WITH_AES_128_GCM_SHA256,
TLS_RSA_WITH_AES_256_GCM_SHA384,
TLS_RSA_WITH_AES_128_CBC_SHA,
TLS_RSA_WITH_AES_256_CBC_SHA,
TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA,
TLS_RSA_WITH_3DES_EDE_CBC_SHA,
TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256,
TLS_RSA_WITH_AES_128_CBC_SHA256,
TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA,
TLS_RSA_WITH_RC4_128_SHA,
}
// disabledCipherSuites are not used unless explicitly listed in Config.CipherSuites.
var disabledCipherSuites = map[uint16]bool{
// CBC_SHA256
TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256: true,
TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256: true,
TLS_RSA_WITH_AES_128_CBC_SHA256: true,
// RC4
TLS_ECDHE_ECDSA_WITH_RC4_128_SHA: true,
TLS_ECDHE_RSA_WITH_RC4_128_SHA: true,
TLS_RSA_WITH_RC4_128_SHA: true,
}
// rsaKexCiphers contains the ciphers which use RSA based key exchange,
// which we also disable by default unless a GODEBUG is set.
var rsaKexCiphers = map[uint16]bool{
TLS_RSA_WITH_RC4_128_SHA: true,
TLS_RSA_WITH_3DES_EDE_CBC_SHA: true,
TLS_RSA_WITH_AES_128_CBC_SHA: true,
TLS_RSA_WITH_AES_256_CBC_SHA: true,
TLS_RSA_WITH_AES_128_CBC_SHA256: true,
TLS_RSA_WITH_AES_128_GCM_SHA256: true,
TLS_RSA_WITH_AES_256_GCM_SHA384: true,
}
// tdesCiphers contains 3DES ciphers,
// which we also disable by default unless a GODEBUG is set.
var tdesCiphers = map[uint16]bool{
TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA: true,
TLS_RSA_WITH_3DES_EDE_CBC_SHA: true,
}
var (
// Keep in sync with crypto/internal/fips140/aes/gcm.supportsAESGCM.
hasGCMAsmAMD64 = cpu.X86.HasAES && cpu.X86.HasPCLMULQDQ && cpu.X86.HasSSE41 && cpu.X86.HasSSSE3
hasGCMAsmARM64 = cpu.ARM64.HasAES && cpu.ARM64.HasPMULL
hasGCMAsmS390X = cpu.S390X.HasAES && cpu.S390X.HasAESCTR && cpu.S390X.HasGHASH
hasGCMAsmPPC64 = runtime.GOARCH == "ppc64" || runtime.GOARCH == "ppc64le"
hasAESGCMHardwareSupport = hasGCMAsmAMD64 || hasGCMAsmARM64 || hasGCMAsmS390X || hasGCMAsmPPC64
)
var aesgcmCiphers = map[uint16]bool{
// TLS 1.2
TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256: true,
TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384: true,
TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256: true,
TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384: true,
// TLS 1.3
TLS_AES_128_GCM_SHA256: true,
TLS_AES_256_GCM_SHA384: true,
}
// isAESGCMPreferred returns whether we have hardware support for AES-GCM, and the
// first known cipher in the peer's preference list is an AES-GCM cipher,
// implying the peer also has hardware support for it.
func isAESGCMPreferred(ciphers []uint16) bool {
if !hasAESGCMHardwareSupport {
return false
}
for _, cID := range ciphers {
if c := cipherSuiteByID(cID); c != nil {
return aesgcmCiphers[cID]
}
if c := cipherSuiteTLS13ByID(cID); c != nil {
return aesgcmCiphers[cID]
}
}
return false
}
func cipherRC4(key, iv []byte, isRead bool) any {
cipher, _ := rc4.NewCipher(key)
return cipher
}
func cipher3DES(key, iv []byte, isRead bool) any {
block, _ := des.NewTripleDESCipher(key)
if isRead {
return cipher.NewCBCDecrypter(block, iv)
}
return cipher.NewCBCEncrypter(block, iv)
}
func cipherAES(key, iv []byte, isRead bool) any {
block, _ := aes.NewCipher(key)
if isRead {
return cipher.NewCBCDecrypter(block, iv)
}
return cipher.NewCBCEncrypter(block, iv)
}
// macSHA1 returns a SHA-1 based constant time MAC.
func macSHA1(key []byte) hash.Hash {
h := sha1.New
// The BoringCrypto SHA1 does not have a constant-time
// checksum function, so don't try to use it.
if !boring.Enabled {
h = newConstantTimeHash(h)
}
return hmac.New(h, key)
}
// macSHA256 returns a SHA-256 based MAC. This is only supported in TLS 1.2 and
// is currently only used in disabled-by-default cipher suites.
func macSHA256(key []byte) hash.Hash {
return hmac.New(sha256.New, key)
}
type aead interface {
cipher.AEAD
// explicitNonceLen returns the number of bytes of explicit nonce
// included in each record. This is eight for older AEADs and
// zero for modern ones.
explicitNonceLen() int
}
const (
aeadNonceLength = 12
noncePrefixLength = 4
)
// prefixNonceAEAD wraps an AEAD and prefixes a fixed portion of the nonce to
// each call.
type prefixNonceAEAD struct {
// nonce contains the fixed part of the nonce in the first four bytes.
nonce [aeadNonceLength]byte
aead cipher.AEAD
}
func (f *prefixNonceAEAD) NonceSize() int { return aeadNonceLength - noncePrefixLength }
func (f *prefixNonceAEAD) Overhead() int { return f.aead.Overhead() }
func (f *prefixNonceAEAD) explicitNonceLen() int { return f.NonceSize() }
func (f *prefixNonceAEAD) Seal(out, nonce, plaintext, additionalData []byte) []byte {
copy(f.nonce[4:], nonce)
return f.aead.Seal(out, f.nonce[:], plaintext, additionalData)
}
func (f *prefixNonceAEAD) Open(out, nonce, ciphertext, additionalData []byte) ([]byte, error) {
copy(f.nonce[4:], nonce)
return f.aead.Open(out, f.nonce[:], ciphertext, additionalData)
}
// xorNonceAEAD wraps an AEAD by XORing in a fixed pattern to the nonce
// before each call.
type xorNonceAEAD struct {
nonceMask [aeadNonceLength]byte
aead cipher.AEAD
}
func (f *xorNonceAEAD) NonceSize() int { return 8 } // 64-bit sequence number
func (f *xorNonceAEAD) Overhead() int { return f.aead.Overhead() }
func (f *xorNonceAEAD) explicitNonceLen() int { return 0 }
func (f *xorNonceAEAD) Seal(out, nonce, plaintext, additionalData []byte) []byte {
for i, b := range nonce {
f.nonceMask[4+i] ^= b
}
result := f.aead.Seal(out, f.nonceMask[:], plaintext, additionalData)
for i, b := range nonce {
f.nonceMask[4+i] ^= b
}
return result
}
func (f *xorNonceAEAD) Open(out, nonce, ciphertext, additionalData []byte) ([]byte, error) {
for i, b := range nonce {
f.nonceMask[4+i] ^= b
}
result, err := f.aead.Open(out, f.nonceMask[:], ciphertext, additionalData)
for i, b := range nonce {
f.nonceMask[4+i] ^= b
}
return result, err
}
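// The XOR-in/XOR-out pattern above leaves nonceMask unchanged after each call,
// while the AEAD itself sees nonceMask XOR nonce. For TLS 1.3 (and the
// ChaCha20-Poly1305 suites in TLS 1.2) that is the fixed IV XORed with the
// 64-bit record sequence number, per RFC 8446, Section 5.3. A minimal sketch
// of the per-record nonce computation, for illustration only:
//
//	var iv [12]byte // fixed IV from the key schedule
//	var seq [8]byte // big-endian record sequence number
//	nonce := iv
//	for i, b := range seq {
//		nonce[4+i] ^= b // only the trailing eight bytes are masked
//	}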
func aeadAESGCM(key, noncePrefix []byte) aead {
if len(noncePrefix) != noncePrefixLength {
panic("tls: internal error: wrong nonce length")
}
aes, err := aes.NewCipher(key)
if err != nil {
panic(err)
}
var aead cipher.AEAD
if boring.Enabled {
aead, err = boring.NewGCMTLS(aes)
} else {
boring.Unreachable()
aead, err = gcm.NewGCMForTLS12(aes.(*fipsaes.Block))
}
if err != nil {
panic(err)
}
ret := &prefixNonceAEAD{aead: aead}
copy(ret.nonce[:], noncePrefix)
return ret
}
// aeadAESGCMTLS13 should be an internal detail,
// but widely used packages access it using linkname.
// Notable members of the hall of shame include:
// - github.com/xtls/xray-core
// - github.com/v2fly/v2ray-core
//
// Do not remove or change the type signature.
// See go.dev/issue/67401.
//
//go:linkname aeadAESGCMTLS13
func aeadAESGCMTLS13(key, nonceMask []byte) aead {
if len(nonceMask) != aeadNonceLength {
panic("tls: internal error: wrong nonce length")
}
aes, err := aes.NewCipher(key)
if err != nil {
panic(err)
}
var aead cipher.AEAD
if boring.Enabled {
aead, err = boring.NewGCMTLS13(aes)
} else {
boring.Unreachable()
aead, err = gcm.NewGCMForTLS13(aes.(*fipsaes.Block))
}
if err != nil {
panic(err)
}
ret := &xorNonceAEAD{aead: aead}
copy(ret.nonceMask[:], nonceMask)
return ret
}
func aeadChaCha20Poly1305(key, nonceMask []byte) aead {
if len(nonceMask) != aeadNonceLength {
panic("tls: internal error: wrong nonce length")
}
aead, err := chacha20poly1305.New(key)
if err != nil {
panic(err)
}
ret := &xorNonceAEAD{aead: aead}
copy(ret.nonceMask[:], nonceMask)
return ret
}
type constantTimeHash interface {
hash.Hash
ConstantTimeSum(b []byte) []byte
}
// cthWrapper wraps any hash.Hash that implements ConstantTimeSum, and replaces
// all calls to Sum with it. It's used to obtain a ConstantTimeSum-based HMAC.
type cthWrapper struct {
h constantTimeHash
}
func (c *cthWrapper) Size() int { return c.h.Size() }
func (c *cthWrapper) BlockSize() int { return c.h.BlockSize() }
func (c *cthWrapper) Reset() { c.h.Reset() }
func (c *cthWrapper) Write(p []byte) (int, error) { return c.h.Write(p) }
func (c *cthWrapper) Sum(b []byte) []byte { return c.h.ConstantTimeSum(b) }
func newConstantTimeHash(h func() hash.Hash) func() hash.Hash {
boring.Unreachable()
return func() hash.Hash {
return &cthWrapper{h().(constantTimeHash)}
}
}
// tls10MAC implements the TLS 1.0 MAC function. RFC 2246, Section 6.2.3.
func tls10MAC(h hash.Hash, out, seq, header, data, extra []byte) []byte {
h.Reset()
h.Write(seq)
h.Write(header)
h.Write(data)
res := h.Sum(out)
if extra != nil {
h.Write(extra)
}
return res
}
func rsaKA(version uint16) keyAgreement {
return rsaKeyAgreement{}
}
func ecdheECDSAKA(version uint16) keyAgreement {
return &ecdheKeyAgreement{
isRSA: false,
version: version,
}
}
func ecdheRSAKA(version uint16) keyAgreement {
return &ecdheKeyAgreement{
isRSA: true,
version: version,
}
}
// mutualCipherSuite returns a cipherSuite given a list of supported
// ciphersuites and the id requested by the peer.
func mutualCipherSuite(have []uint16, want uint16) *cipherSuite {
for _, id := range have {
if id == want {
return cipherSuiteByID(id)
}
}
return nil
}
func cipherSuiteByID(id uint16) *cipherSuite {
for _, cipherSuite := range cipherSuites {
if cipherSuite.id == id {
return cipherSuite
}
}
return nil
}
func mutualCipherSuiteTLS13(have []uint16, want uint16) *cipherSuiteTLS13 {
for _, id := range have {
if id == want {
return cipherSuiteTLS13ByID(id)
}
}
return nil
}
func cipherSuiteTLS13ByID(id uint16) *cipherSuiteTLS13 {
for _, cipherSuite := range cipherSuitesTLS13 {
if cipherSuite.id == id {
return cipherSuite
}
}
return nil
}
// A list of cipher suite IDs that are, or have been, implemented by this
// package.
//
// See https://www.iana.org/assignments/tls-parameters/tls-parameters.xml
const (
// TLS 1.0 - 1.2 cipher suites.
TLS_RSA_WITH_RC4_128_SHA uint16 = 0x0005
TLS_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0x000a
TLS_RSA_WITH_AES_128_CBC_SHA uint16 = 0x002f
TLS_RSA_WITH_AES_256_CBC_SHA uint16 = 0x0035
TLS_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0x003c
TLS_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0x009c
TLS_RSA_WITH_AES_256_GCM_SHA384 uint16 = 0x009d
TLS_ECDHE_ECDSA_WITH_RC4_128_SHA uint16 = 0xc007
TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA uint16 = 0xc009
TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA uint16 = 0xc00a
TLS_ECDHE_RSA_WITH_RC4_128_SHA uint16 = 0xc011
TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0xc012
TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA uint16 = 0xc013
TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA uint16 = 0xc014
TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256 uint16 = 0xc023
TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0xc027
TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0xc02f
TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 uint16 = 0xc02b
TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 uint16 = 0xc030
TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 uint16 = 0xc02c
TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 uint16 = 0xcca8
TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256 uint16 = 0xcca9
// TLS 1.3 cipher suites.
TLS_AES_128_GCM_SHA256 uint16 = 0x1301
TLS_AES_256_GCM_SHA384 uint16 = 0x1302
TLS_CHACHA20_POLY1305_SHA256 uint16 = 0x1303
// TLS_FALLBACK_SCSV isn't a standard cipher suite but an indicator
// that the client is doing version fallback. See RFC 7507.
TLS_FALLBACK_SCSV uint16 = 0x5600
// Legacy names for the corresponding cipher suites with the correct _SHA256
// suffix, retained for backward compatibility.
TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305 = TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256
TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305 = TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256
)
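// Illustrative sketch (not part of this package): explicitly selecting TLS
// 1.0–1.2 cipher suites on a Config. Listing a suite in Config.CipherSuites is
// the only way to enable one of the disabledCipherSuites above; TLS 1.3 suites
// are not configurable. The suites chosen below are examples, not a
// recommendation.
//
//	cfg := &tls.Config{
//		MinVersion: tls.VersionTLS12,
//		CipherSuites: []uint16{
//			tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
//			tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
//			tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256,
//		},
//	}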
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tls
import (
"bytes"
"container/list"
"context"
"crypto"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"crypto/sha512"
"crypto/tls/internal/fips140tls"
"crypto/x509"
"errors"
"fmt"
"internal/godebug"
"io"
"net"
"slices"
"strings"
"sync"
"time"
_ "unsafe" // for linkname
)
const (
VersionTLS10 = 0x0301
VersionTLS11 = 0x0302
VersionTLS12 = 0x0303
VersionTLS13 = 0x0304
// Deprecated: SSLv3 is cryptographically broken, and is no longer
// supported by this package. See golang.org/issue/32716.
VersionSSL30 = 0x0300
)
// VersionName returns the name for the provided TLS version number
// (e.g. "TLS 1.3"), or a fallback representation of the value if the
// version is not implemented by this package.
func VersionName(version uint16) string {
switch version {
case VersionSSL30:
return "SSLv3"
case VersionTLS10:
return "TLS 1.0"
case VersionTLS11:
return "TLS 1.1"
case VersionTLS12:
return "TLS 1.2"
case VersionTLS13:
return "TLS 1.3"
default:
return fmt.Sprintf("0x%04X", version)
}
}
const (
maxPlaintext = 16384 // maximum plaintext payload length
maxCiphertext = 16384 + 2048 // maximum ciphertext payload length
maxCiphertextTLS13 = 16384 + 256 // maximum ciphertext length in TLS 1.3
recordHeaderLen = 5 // record header length
maxHandshake = 65536 // maximum handshake we support (protocol max is 16 MB)
maxHandshakeCertificateMsg = 262144 // maximum certificate message size (256 KiB)
maxUselessRecords = 16 // maximum number of consecutive non-advancing records
)
// TLS record types.
type recordType uint8
const (
recordTypeChangeCipherSpec recordType = 20
recordTypeAlert recordType = 21
recordTypeHandshake recordType = 22
recordTypeApplicationData recordType = 23
)
// TLS handshake message types.
const (
typeHelloRequest uint8 = 0
typeClientHello uint8 = 1
typeServerHello uint8 = 2
typeNewSessionTicket uint8 = 4
typeEndOfEarlyData uint8 = 5
typeEncryptedExtensions uint8 = 8
typeCertificate uint8 = 11
typeServerKeyExchange uint8 = 12
typeCertificateRequest uint8 = 13
typeServerHelloDone uint8 = 14
typeCertificateVerify uint8 = 15
typeClientKeyExchange uint8 = 16
typeFinished uint8 = 20
typeCertificateStatus uint8 = 22
typeKeyUpdate uint8 = 24
typeMessageHash uint8 = 254 // synthetic message
)
// TLS compression types.
const (
compressionNone uint8 = 0
)
// TLS extension numbers
const (
extensionServerName uint16 = 0
extensionStatusRequest uint16 = 5
extensionSupportedCurves uint16 = 10 // supported_groups in TLS 1.3, see RFC 8446, Section 4.2.7
extensionSupportedPoints uint16 = 11
extensionSignatureAlgorithms uint16 = 13
extensionALPN uint16 = 16
extensionSCT uint16 = 18
extensionExtendedMasterSecret uint16 = 23
extensionSessionTicket uint16 = 35
extensionPreSharedKey uint16 = 41
extensionEarlyData uint16 = 42
extensionSupportedVersions uint16 = 43
extensionCookie uint16 = 44
extensionPSKModes uint16 = 45
extensionCertificateAuthorities uint16 = 47
extensionSignatureAlgorithmsCert uint16 = 50
extensionKeyShare uint16 = 51
extensionQUICTransportParameters uint16 = 57
extensionRenegotiationInfo uint16 = 0xff01
extensionECHOuterExtensions uint16 = 0xfd00
extensionEncryptedClientHello uint16 = 0xfe0d
)
// TLS signaling cipher suite values
const (
scsvRenegotiation uint16 = 0x00ff
)
// CurveID is the type of a TLS identifier for a key exchange mechanism. See
// https://www.iana.org/assignments/tls-parameters/tls-parameters.xml#tls-parameters-8.
//
// In TLS 1.2, this registry used to support only elliptic curves. In TLS 1.3,
// it was extended to other groups and renamed NamedGroup. See RFC 8446, Section
// 4.2.7. It was then also extended to other mechanisms, such as hybrid
// post-quantum KEMs.
type CurveID uint16
const (
CurveP256 CurveID = 23
CurveP384 CurveID = 24
CurveP521 CurveID = 25
X25519 CurveID = 29
X25519MLKEM768 CurveID = 4588
)
func isTLS13OnlyKeyExchange(curve CurveID) bool {
return curve == X25519MLKEM768
}
func isPQKeyExchange(curve CurveID) bool {
return curve == X25519MLKEM768
}
// TLS 1.3 Key Share. See RFC 8446, Section 4.2.8.
type keyShare struct {
group CurveID
data []byte
}
// TLS 1.3 PSK Key Exchange Modes. See RFC 8446, Section 4.2.9.
const (
pskModePlain uint8 = 0
pskModeDHE uint8 = 1
)
// TLS 1.3 PSK Identity. Can be a Session Ticket, or a reference to a saved
// session. See RFC 8446, Section 4.2.11.
type pskIdentity struct {
label []byte
obfuscatedTicketAge uint32
}
// TLS Elliptic Curve Point Formats
// https://www.iana.org/assignments/tls-parameters/tls-parameters.xml#tls-parameters-9
const (
pointFormatUncompressed uint8 = 0
)
// TLS CertificateStatusType (RFC 3546)
const (
statusTypeOCSP uint8 = 1
)
// Certificate types (for certificateRequestMsg)
const (
certTypeRSASign = 1
certTypeECDSASign = 64 // ECDSA or EdDSA keys, see RFC 8422, Section 3.
)
// Signature algorithms (for internal signaling use). Starting at 225 to avoid overlap with
// TLS 1.2 codepoints (RFC 5246, Appendix A.4.1), with which these have nothing to do.
const (
signaturePKCS1v15 uint8 = iota + 225
signatureRSAPSS
signatureECDSA
signatureEd25519
)
// directSigning is a standard Hash value that signals that no pre-hashing
// should be performed, and that the input should be signed directly. It is the
// hash function associated with the Ed25519 signature scheme.
var directSigning crypto.Hash = 0
// helloRetryRequestRandom is set as the Random value of a ServerHello
// to signal that the message is actually a HelloRetryRequest.
var helloRetryRequestRandom = []byte{ // See RFC 8446, Section 4.1.3.
0xCF, 0x21, 0xAD, 0x74, 0xE5, 0x9A, 0x61, 0x11,
0xBE, 0x1D, 0x8C, 0x02, 0x1E, 0x65, 0xB8, 0x91,
0xC2, 0xA2, 0x11, 0x16, 0x7A, 0xBB, 0x8C, 0x5E,
0x07, 0x9E, 0x09, 0xE2, 0xC8, 0xA8, 0x33, 0x9C,
}
const (
// downgradeCanaryTLS12 or downgradeCanaryTLS11 is embedded in the server
// random as a downgrade protection if the server would be capable of
// negotiating a higher version. See RFC 8446, Section 4.1.3.
downgradeCanaryTLS12 = "DOWNGRD\x01"
downgradeCanaryTLS11 = "DOWNGRD\x00"
)
// testingOnlyForceDowngradeCanary is set in tests to force the server side to
// include downgrade canaries even if it's using its highest supported version.
var testingOnlyForceDowngradeCanary bool
// ConnectionState records basic TLS details about the connection.
type ConnectionState struct {
// Version is the TLS version used by the connection (e.g. VersionTLS12).
Version uint16
// HandshakeComplete is true if the handshake has concluded.
HandshakeComplete bool
// DidResume is true if this connection was successfully resumed from a
// previous session with a session ticket or similar mechanism.
DidResume bool
// CipherSuite is the cipher suite negotiated for the connection (e.g.
// TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, TLS_AES_128_GCM_SHA256).
CipherSuite uint16
// CurveID is the key exchange mechanism used for the connection. The name
// refers to elliptic curves for legacy reasons, see [CurveID]. If a legacy
// RSA key exchange is used, this value is zero.
CurveID CurveID
// NegotiatedProtocol is the application protocol negotiated with ALPN.
NegotiatedProtocol string
// NegotiatedProtocolIsMutual used to indicate a mutual NPN negotiation.
//
// Deprecated: this value is always true.
NegotiatedProtocolIsMutual bool
// ServerName is the value of the Server Name Indication extension sent by
// the client. It's available both on the server and on the client side.
ServerName string
// PeerCertificates are the parsed certificates sent by the peer, in the
// order in which they were sent. The first element is the leaf certificate
// that the connection is verified against.
//
// On the client side, it can't be empty. On the server side, it can be
// empty if Config.ClientAuth is not RequireAnyClientCert or
// RequireAndVerifyClientCert.
//
// PeerCertificates and its contents should not be modified.
PeerCertificates []*x509.Certificate
// VerifiedChains is a list of one or more chains where the first element is
// PeerCertificates[0] and the last element is from Config.RootCAs (on the
// client side) or Config.ClientCAs (on the server side).
//
// On the client side, it's set if Config.InsecureSkipVerify is false. On
// the server side, it's set if Config.ClientAuth is VerifyClientCertIfGiven
// (and the peer provided a certificate) or RequireAndVerifyClientCert.
//
// VerifiedChains and its contents should not be modified.
VerifiedChains [][]*x509.Certificate
// SignedCertificateTimestamps is a list of SCTs provided by the peer
// through the TLS handshake for the leaf certificate, if any.
SignedCertificateTimestamps [][]byte
// OCSPResponse is a stapled Online Certificate Status Protocol (OCSP)
// response provided by the peer for the leaf certificate, if any.
OCSPResponse []byte
// TLSUnique contains the "tls-unique" channel binding value (see RFC 5929,
// Section 3). This value will be nil for TLS 1.3 connections and for
// resumed connections that don't support Extended Master Secret (RFC 7627).
TLSUnique []byte
// ECHAccepted indicates if Encrypted Client Hello was offered by the client
// and accepted by the server. ECH is configured on the client with
// EncryptedClientHelloConfigList and on the server with EncryptedClientHelloKeys.
ECHAccepted bool
// ekm is a closure exposed via ExportKeyingMaterial.
ekm func(label string, context []byte, length int) ([]byte, error)
// testingOnlyDidHRR is true if a HelloRetryRequest was sent/received.
testingOnlyDidHRR bool
// testingOnlyPeerSignatureAlgorithm is the signature algorithm used by the
// peer to sign the handshake. It is not set for resumed connections.
testingOnlyPeerSignatureAlgorithm SignatureScheme
}
// ExportKeyingMaterial returns length bytes of exported key material in a new
// slice as defined in RFC 5705. If context is nil, it is not used as part of
// the seed. If the connection was set to allow renegotiation via
// Config.Renegotiation, or if the connection supports neither TLS 1.3 nor
// Extended Master Secret, this function will return an error.
//
// Exporting key material without Extended Master Secret or TLS 1.3 was disabled
// in Go 1.22 due to security issues (see the Security Considerations sections
// of RFC 5705 and RFC 7627), but can be re-enabled with the GODEBUG setting
// tlsunsafeekm=1.
func (cs *ConnectionState) ExportKeyingMaterial(label string, context []byte, length int) ([]byte, error) {
return cs.ekm(label, context, length)
}
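// Illustrative sketch: exporting keying material for an application protocol
// from an established connection. Assumes conn is a handshaken *tls.Conn; the
// label and output length are hypothetical examples.
//
//	state := conn.ConnectionState()
//	key, err := state.ExportKeyingMaterial("EXPERIMENTAL my-protocol", nil, 32)
//	if err != nil {
//		// Renegotiation enabled, or TLS 1.2 without Extended Master Secret.
//	}
//	_ = key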
// ClientAuthType declares the policy the server will follow for
// TLS Client Authentication.
type ClientAuthType int
const (
// NoClientCert indicates that no client certificate should be requested
// during the handshake, and if any certificates are sent they will not
// be verified.
NoClientCert ClientAuthType = iota
// RequestClientCert indicates that a client certificate should be requested
// during the handshake, but does not require that the client send any
// certificates.
RequestClientCert
// RequireAnyClientCert indicates that a client certificate should be requested
// during the handshake, and that at least one certificate is required to be
// sent by the client, but that certificate is not required to be valid.
RequireAnyClientCert
// VerifyClientCertIfGiven indicates that a client certificate should be requested
// during the handshake, but does not require that the client sends a
// certificate. If the client does send a certificate it is required to be
// valid.
VerifyClientCertIfGiven
// RequireAndVerifyClientCert indicates that a client certificate should be requested
// during the handshake, and that at least one valid certificate is required
// to be sent by the client.
RequireAndVerifyClientCert
)
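// Illustrative sketch: a server Config that requests and verifies a client
// certificate against an application-supplied CA pool. clientCAPool is assumed
// to be an *x509.CertPool populated by the caller.
//
//	cfg := &tls.Config{
//		ClientAuth: tls.RequireAndVerifyClientCert,
//		ClientCAs:  clientCAPool,
//	}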
// requiresClientCert reports whether the ClientAuthType requires a client
// certificate to be provided.
func requiresClientCert(c ClientAuthType) bool {
switch c {
case RequireAnyClientCert, RequireAndVerifyClientCert:
return true
default:
return false
}
}
// ClientSessionCache is a cache of ClientSessionState objects that can be used
// by a client to resume a TLS session with a given server. ClientSessionCache
// implementations should expect to be called concurrently from different
// goroutines. Up to TLS 1.2, only ticket-based resumption is supported, not
// SessionID-based resumption. In TLS 1.3 they were merged into PSK modes, which
// are supported via this interface.
type ClientSessionCache interface {
// Get searches for a ClientSessionState associated with the given key.
// On return, ok is true if one was found.
Get(sessionKey string) (session *ClientSessionState, ok bool)
// Put adds the ClientSessionState to the cache with the given key. It might
// get called multiple times in a connection if a TLS 1.3 server provides
// more than one session ticket. If called with a nil *ClientSessionState,
// it should remove the cache entry.
Put(sessionKey string, cs *ClientSessionState)
}
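// A naive, unbounded ClientSessionCache could be sketched as below. This is
// for illustration only; [NewLRUClientSessionCache] is the implementation
// provided by this package and bounds memory use, so it should normally be
// preferred. The sketch assumes the sync and crypto/tls packages are imported.
//
//	type mapCache struct {
//		mu sync.Mutex
//		m  map[string]*tls.ClientSessionState
//	}
//
//	func (c *mapCache) Get(key string) (*tls.ClientSessionState, bool) {
//		c.mu.Lock()
//		defer c.mu.Unlock()
//		cs, ok := c.m[key]
//		return cs, ok
//	}
//
//	func (c *mapCache) Put(key string, cs *tls.ClientSessionState) {
//		c.mu.Lock()
//		defer c.mu.Unlock()
//		if cs == nil {
//			delete(c.m, key) // nil means remove the entry
//			return
//		}
//		c.m[key] = cs
//	}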
//go:generate stringer -linecomment -type=SignatureScheme,CurveID,ClientAuthType -output=common_string.go
// SignatureScheme identifies a signature algorithm supported by TLS. See
// RFC 8446, Section 4.2.3.
type SignatureScheme uint16
const (
// RSASSA-PKCS1-v1_5 algorithms.
PKCS1WithSHA256 SignatureScheme = 0x0401
PKCS1WithSHA384 SignatureScheme = 0x0501
PKCS1WithSHA512 SignatureScheme = 0x0601
// RSASSA-PSS algorithms with public key OID rsaEncryption.
PSSWithSHA256 SignatureScheme = 0x0804
PSSWithSHA384 SignatureScheme = 0x0805
PSSWithSHA512 SignatureScheme = 0x0806
// ECDSA algorithms. Only constrained to a specific curve in TLS 1.3.
ECDSAWithP256AndSHA256 SignatureScheme = 0x0403
ECDSAWithP384AndSHA384 SignatureScheme = 0x0503
ECDSAWithP521AndSHA512 SignatureScheme = 0x0603
// EdDSA algorithms.
Ed25519 SignatureScheme = 0x0807
// Legacy signature and hash algorithms for TLS 1.2.
PKCS1WithSHA1 SignatureScheme = 0x0201
ECDSAWithSHA1 SignatureScheme = 0x0203
)
// ClientHelloInfo contains information from a ClientHello message in order to
// guide application logic in the GetCertificate and GetConfigForClient callbacks.
type ClientHelloInfo struct {
// CipherSuites lists the CipherSuites supported by the client (e.g.
// TLS_AES_128_GCM_SHA256, TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256).
CipherSuites []uint16
// ServerName indicates the name of the server requested by the client
// in order to support virtual hosting. ServerName is only set if the
// client is using SNI (see RFC 4366, Section 3.1).
ServerName string
// SupportedCurves lists the key exchange mechanisms supported by the
// client. It was renamed to "supported groups" in TLS 1.3, see RFC 8446,
// Section 4.2.7 and [CurveID].
//
// SupportedCurves may be nil in TLS 1.2 and lower if the Supported Elliptic
// Curves Extension is not being used (see RFC 4492, Section 5.1.1).
SupportedCurves []CurveID
// SupportedPoints lists the point formats supported by the client.
// SupportedPoints is set only if the Supported Point Formats Extension
// is being used (see RFC 4492, Section 5.1.2).
SupportedPoints []uint8
// SignatureSchemes lists the signature and hash schemes that the client
// is willing to verify. SignatureSchemes is set only if the Signature
// Algorithms Extension is being used (see RFC 5246, Section 7.4.1.4.1).
SignatureSchemes []SignatureScheme
// SupportedProtos lists the application protocols supported by the client.
// SupportedProtos is set only if the Application-Layer Protocol
// Negotiation Extension is being used (see RFC 7301, Section 3.1).
//
// Servers can select a protocol by setting Config.NextProtos in a
// GetConfigForClient return value.
SupportedProtos []string
// SupportedVersions lists the TLS versions supported by the client.
// For TLS versions less than 1.3, this is extrapolated from the max
// version advertised by the client, so values other than the greatest
// might be rejected if used.
SupportedVersions []uint16
// Extensions lists the IDs of the extensions presented by the client
// in the ClientHello.
Extensions []uint16
// Conn is the underlying net.Conn for the connection. Do not read
// from, or write to, this connection; that will cause the TLS
// connection to fail.
Conn net.Conn
// config is embedded by the GetCertificate or GetConfigForClient caller,
// for use with SupportsCertificate.
config *Config
// ctx is the context of the handshake that is in progress.
ctx context.Context
}
// Context returns the context of the handshake that is in progress.
// This context is a child of the context passed to HandshakeContext,
// if any, and is canceled when the handshake concludes.
func (c *ClientHelloInfo) Context() context.Context {
return c.ctx
}
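// Illustrative sketch: selecting a certificate by SNI in a GetCertificate
// callback. certByName is a hypothetical map[string]*tls.Certificate owned by
// the application; returning (nil, nil) falls back to the normal selection
// from Config.Certificates.
//
//	cfg := &tls.Config{
//		GetCertificate: func(chi *tls.ClientHelloInfo) (*tls.Certificate, error) {
//			if cert, ok := certByName[chi.ServerName]; ok {
//				return cert, nil
//			}
//			return nil, nil
//		},
//	}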
// CertificateRequestInfo contains information from a server's
// CertificateRequest message, which is used to demand a certificate and proof
// of control from a client.
type CertificateRequestInfo struct {
// AcceptableCAs contains zero or more, DER-encoded, X.501
// Distinguished Names. These are the names of root or intermediate CAs
// that the server wishes the returned certificate to be signed by. An
// empty slice indicates that the server has no preference.
AcceptableCAs [][]byte
// SignatureSchemes lists the signature schemes that the server is
// willing to verify.
SignatureSchemes []SignatureScheme
// Version is the TLS version that was negotiated for this connection.
Version uint16
// ctx is the context of the handshake that is in progress.
ctx context.Context
}
// Context returns the context of the handshake that is in progress.
// This context is a child of the context passed to HandshakeContext,
// if any, and is canceled when the handshake concludes.
func (c *CertificateRequestInfo) Context() context.Context {
return c.ctx
}
// RenegotiationSupport enumerates the different levels of support for TLS
// renegotiation. TLS renegotiation is the act of performing subsequent
// handshakes on a connection after the first. This significantly complicates
// the state machine and has been the source of numerous, subtle security
// issues. Initiating a renegotiation is not supported, but support for
// accepting renegotiation requests may be enabled.
//
// Even when enabled, the server may not change its identity between handshakes
// (i.e. the leaf certificate must be the same). Additionally, concurrent
// handshake and application data flow is not permitted so renegotiation can
// only be used with protocols that synchronise with the renegotiation, such as
// HTTPS.
//
// Renegotiation is not defined in TLS 1.3.
type RenegotiationSupport int
const (
// RenegotiateNever disables renegotiation.
RenegotiateNever RenegotiationSupport = iota
// RenegotiateOnceAsClient allows a remote server to request
// renegotiation once per connection.
RenegotiateOnceAsClient
// RenegotiateFreelyAsClient allows a remote server to repeatedly
// request renegotiation.
RenegotiateFreelyAsClient
)
// A Config structure is used to configure a TLS client or server.
// After one has been passed to a TLS function it must not be
// modified. A Config may be reused; the tls package will also not
// modify it.
type Config struct {
// Rand provides the source of entropy for nonces and RSA blinding.
// If Rand is nil, TLS uses the cryptographic random reader in package
// crypto/rand.
// The Reader must be safe for use by multiple goroutines.
Rand io.Reader
// Time returns the current time as the number of seconds since the epoch.
// If Time is nil, TLS uses time.Now.
Time func() time.Time
// Certificates contains one or more certificate chains to present to the
// other side of the connection. The first certificate compatible with the
// peer's requirements is selected automatically.
//
// Server configurations must set one of Certificates, GetCertificate or
// GetConfigForClient. Clients doing client-authentication may set either
// Certificates or GetClientCertificate.
//
// Note: if there are multiple Certificates, and they don't have the
// optional field Leaf set, certificate selection will incur a significant
// per-handshake performance cost.
Certificates []Certificate
// NameToCertificate maps from a certificate name to an element of
// Certificates. Note that a certificate name can be of the form
// '*.example.com' and so doesn't have to be a domain name as such.
//
// Deprecated: NameToCertificate only allows associating a single
// certificate with a given name. Leave this field nil to let the library
// select the first compatible chain from Certificates.
NameToCertificate map[string]*Certificate
// GetCertificate returns a Certificate based on the given
// ClientHelloInfo. It will only be called if the client supplies SNI
// information or if Certificates is empty.
//
// If GetCertificate is nil or returns nil, then the certificate is
// retrieved from NameToCertificate. If NameToCertificate is nil, the
// best element of Certificates will be used.
//
// Once a Certificate is returned it should not be modified.
GetCertificate func(*ClientHelloInfo) (*Certificate, error)
// GetClientCertificate, if not nil, is called when a server requests a
// certificate from a client. If set, the contents of Certificates will
// be ignored.
//
// If GetClientCertificate returns an error, the handshake will be
// aborted and that error will be returned. Otherwise
// GetClientCertificate must return a non-nil Certificate. If
// Certificate.Certificate is empty then no certificate will be sent to
// the server. If this is unacceptable to the server then it may abort
// the handshake.
//
// GetClientCertificate may be called multiple times for the same
// connection if renegotiation occurs or if TLS 1.3 is in use.
//
// Once a Certificate is returned it should not be modified.
GetClientCertificate func(*CertificateRequestInfo) (*Certificate, error)
// GetConfigForClient, if not nil, is called after a ClientHello is
// received from a client. It may return a non-nil Config in order to
// change the Config that will be used to handle this connection. If
// the returned Config is nil, the original Config will be used. The
// Config returned by this callback may not be subsequently modified.
//
// If GetConfigForClient is nil, the Config passed to Server() will be
// used for all connections.
//
// If SessionTicketKey was explicitly set on the returned Config, or if
// SetSessionTicketKeys was called on the returned Config, those keys will
// be used. Otherwise, the original Config keys will be used (and possibly
// rotated if they are automatically managed).
GetConfigForClient func(*ClientHelloInfo) (*Config, error)
// VerifyPeerCertificate, if not nil, is called after normal
// certificate verification by either a TLS client or server. It
// receives the raw ASN.1 certificates provided by the peer and also
// any verified chains that normal processing found. If it returns a
// non-nil error, the handshake is aborted and that error results.
//
// If normal verification fails then the handshake will abort before
// considering this callback. If normal verification is disabled (on the
// client when InsecureSkipVerify is set, or on a server when ClientAuth is
// RequestClientCert or RequireAnyClientCert), then this callback will be
// considered but the verifiedChains argument will always be nil. When
// ClientAuth is NoClientCert, this callback is not called on the server.
// rawCerts may be empty on the server if ClientAuth is RequestClientCert or
// VerifyClientCertIfGiven.
//
// This callback is not invoked on resumed connections, as certificates are
// not re-verified on resumption.
//
// verifiedChains and its contents should not be modified.
VerifyPeerCertificate func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error
// VerifyConnection, if not nil, is called after normal certificate
// verification and after VerifyPeerCertificate by either a TLS client
// or server. If it returns a non-nil error, the handshake is aborted
// and that error results.
//
// If normal verification fails then the handshake will abort before
// considering this callback. This callback will run for all connections,
// including resumptions, regardless of InsecureSkipVerify or ClientAuth
// settings.
VerifyConnection func(ConnectionState) error
// RootCAs defines the set of root certificate authorities
// that clients use when verifying server certificates.
// If RootCAs is nil, TLS uses the host's root CA set.
RootCAs *x509.CertPool
// NextProtos is a list of supported application level protocols, in
// order of preference. If both peers support ALPN, the selected
// protocol will be one from this list, and the connection will fail
// if there is no mutually supported protocol. If NextProtos is empty
// or the peer doesn't support ALPN, the connection will succeed and
// ConnectionState.NegotiatedProtocol will be empty.
NextProtos []string
// ServerName is used to verify the hostname on the returned
// certificates unless InsecureSkipVerify is given. It is also included
// in the client's handshake to support virtual hosting unless it is
// an IP address.
ServerName string
// ClientAuth determines the server's policy for
// TLS Client Authentication. The default is NoClientCert.
ClientAuth ClientAuthType
// ClientCAs defines the set of root certificate authorities
// that servers use if required to verify a client certificate
// by the policy in ClientAuth.
ClientCAs *x509.CertPool
// InsecureSkipVerify controls whether a client verifies the server's
// certificate chain and host name. If InsecureSkipVerify is true, crypto/tls
// accepts any certificate presented by the server and any host name in that
// certificate. In this mode, TLS is susceptible to machine-in-the-middle
// attacks unless custom verification is used. This should be used only for
// testing or in combination with VerifyConnection or VerifyPeerCertificate.
InsecureSkipVerify bool
// CipherSuites is a list of enabled TLS 1.0–1.2 cipher suites. The order of
// the list is ignored. Note that TLS 1.3 ciphersuites are not configurable.
//
// If CipherSuites is nil, a safe default list is used. The default cipher
// suites might change over time. In Go 1.22 RSA key exchange based cipher
// suites were removed from the default list, but can be re-added with the
// GODEBUG setting tlsrsakex=1. In Go 1.23 3DES cipher suites were removed
// from the default list, but can be re-added with the GODEBUG setting
// tls3des=1.
CipherSuites []uint16
// PreferServerCipherSuites is a legacy field and has no effect.
//
// It used to control whether the server would follow the client's or the
// server's preference. Servers now select the best mutually supported
// cipher suite based on logic that takes into account inferred client
// hardware, server hardware, and security.
//
// Deprecated: PreferServerCipherSuites is ignored.
PreferServerCipherSuites bool
// SessionTicketsDisabled may be set to true to disable session ticket and
// PSK (resumption) support. Note that on clients, session ticket support is
// also disabled if ClientSessionCache is nil.
SessionTicketsDisabled bool
// SessionTicketKey is used by TLS servers to provide session resumption.
// See RFC 5077 and the PSK mode of RFC 8446. If zero, it will be filled
// with random data before the first server handshake.
//
// Deprecated: if this field is left at zero, session ticket keys will be
// automatically rotated every day and dropped after seven days. For
// customizing the rotation schedule or synchronizing servers that are
// terminating connections for the same host, use SetSessionTicketKeys.
SessionTicketKey [32]byte
// ClientSessionCache is a cache of ClientSessionState entries for TLS
// session resumption. It is only used by clients.
ClientSessionCache ClientSessionCache
// UnwrapSession is called on the server to turn a ticket/identity
// previously produced by [WrapSession] into a usable session.
//
// UnwrapSession will usually either decrypt a session state in the ticket
// (for example with [Config.DecryptTicket]), or use the ticket as a handle
// to recover a previously stored state. It must use [ParseSessionState] to
// deserialize the session state.
//
// If UnwrapSession returns an error, the connection is terminated. If it
// returns (nil, nil), the session is ignored. crypto/tls may still choose
// not to resume the returned session.
UnwrapSession func(identity []byte, cs ConnectionState) (*SessionState, error)
// WrapSession is called on the server to produce a session ticket/identity.
//
// WrapSession must serialize the session state with [SessionState.Bytes].
// It may then encrypt the serialized state (for example with
// [Config.EncryptTicket]) and use it as the ticket, or store the state and
// return a handle for it.
//
// If WrapSession returns an error, the connection is terminated.
//
// Warning: the return value will be exposed on the wire and to clients in
// plaintext. The application is in charge of encrypting and authenticating
// it (and rotating keys) or returning high-entropy identifiers. Failing to
// do so correctly can compromise current, previous, and future connections
// depending on the protocol version.
WrapSession func(ConnectionState, *SessionState) ([]byte, error)
// MinVersion contains the minimum TLS version that is acceptable.
//
// By default, TLS 1.2 is currently used as the minimum. TLS 1.0 is the
// minimum supported by this package.
//
// The server-side default can be reverted to TLS 1.0 by including the value
// "tls10server=1" in the GODEBUG environment variable.
MinVersion uint16
// MaxVersion contains the maximum TLS version that is acceptable.
//
// By default, the maximum version supported by this package is used,
// which is currently TLS 1.3.
MaxVersion uint16
// CurvePreferences contains a set of supported key exchange mechanisms.
// The name refers to elliptic curves for legacy reasons, see [CurveID].
// The order of the list is ignored, and key exchange mechanisms are chosen
// from this list using an internal preference order. If empty, the default
// will be used.
//
// From Go 1.24, the default includes the [X25519MLKEM768] hybrid
// post-quantum key exchange. To disable it, set CurvePreferences explicitly
// or use the GODEBUG=tlsmlkem=0 environment variable.
CurvePreferences []CurveID
// DynamicRecordSizingDisabled disables adaptive sizing of TLS records.
// When true, the largest possible TLS record size is always used. When
// false, the size of TLS records may be adjusted in an attempt to
// improve latency.
DynamicRecordSizingDisabled bool
// Renegotiation controls what types of renegotiation are supported.
// The default, none, is correct for the vast majority of applications.
Renegotiation RenegotiationSupport
// KeyLogWriter optionally specifies a destination for TLS master secrets
// in NSS key log format that can be used to allow external programs
// such as Wireshark to decrypt TLS connections.
// See https://developer.mozilla.org/en-US/docs/Mozilla/Projects/NSS/Key_Log_Format.
// Use of KeyLogWriter compromises security and should only be
// used for debugging.
KeyLogWriter io.Writer
// EncryptedClientHelloConfigList is a serialized ECHConfigList. If
// provided, clients will attempt to connect to servers using Encrypted
// Client Hello (ECH) using one of the provided ECHConfigs.
//
// Servers do not use this field. In order to configure ECH for servers, see
// the EncryptedClientHelloKeys field.
//
// If the list contains no valid ECH configs, the handshake will fail
// and return an error.
//
// If EncryptedClientHelloConfigList is set, MinVersion, if set, must
// be VersionTLS13.
//
// When EncryptedClientHelloConfigList is set, the handshake will only
// succeed if ECH is successfully negotiated. If the server rejects ECH,
// an ECHRejectionError error will be returned, which may contain a new
// ECHConfigList that the server suggests using.
//
// How this field is parsed may change in future Go versions, if the
// encoding described in the final Encrypted Client Hello RFC changes.
EncryptedClientHelloConfigList []byte
// EncryptedClientHelloRejectionVerify, if not nil, is called when ECH is
// rejected by the remote server, in order to verify the ECH provider
// certificate in the outer ClientHello. If it returns a non-nil error, the
// handshake is aborted and that error results.
//
// On the server side this field is not used.
//
// Unlike VerifyPeerCertificate and VerifyConnection, normal certificate
// verification will not be performed before calling
// EncryptedClientHelloRejectionVerify.
//
// If EncryptedClientHelloRejectionVerify is nil and ECH is rejected, the
// roots in RootCAs will be used to verify the ECH provider's public
// certificate. VerifyPeerCertificate and VerifyConnection are not called
// when ECH is rejected, even if set, and InsecureSkipVerify is ignored.
EncryptedClientHelloRejectionVerify func(ConnectionState) error
// GetEncryptedClientHelloKeys, if not nil, is called by a server when
// a client attempts ECH.
//
// If GetEncryptedClientHelloKeys is not nil, [EncryptedClientHelloKeys] is
// ignored.
//
// If GetEncryptedClientHelloKeys returns an error, the handshake will be
// aborted and the error will be returned. Otherwise,
// GetEncryptedClientHelloKeys must return a non-nil slice of
// [EncryptedClientHelloKey] that represents the acceptable ECH keys.
//
// For further details, see [EncryptedClientHelloKeys].
GetEncryptedClientHelloKeys func(*ClientHelloInfo) ([]EncryptedClientHelloKey, error)
// EncryptedClientHelloKeys are the ECH keys to use when a client
// attempts ECH.
//
// If EncryptedClientHelloKeys is set, MinVersion, if set, must be
// VersionTLS13.
//
// If a client attempts ECH, but it is rejected by the server, the server
// will send a list of configs to retry based on the set of
// EncryptedClientHelloKeys which have the SendAsRetry field set.
//
// If GetEncryptedClientHelloKeys is non-nil, EncryptedClientHelloKeys is
// ignored.
//
// On the client side, this field is ignored. In order to configure ECH for
// clients, see the EncryptedClientHelloConfigList field.
EncryptedClientHelloKeys []EncryptedClientHelloKey
// mutex protects sessionTicketKeys and autoSessionTicketKeys.
mutex sync.RWMutex
// sessionTicketKeys contains zero or more ticket keys. If set, it means
// the keys were set with SessionTicketKey or SetSessionTicketKeys. The
// first key is used for new tickets and any subsequent keys can be used to
// decrypt old tickets. The slice contents are not protected by the mutex
// and are immutable.
sessionTicketKeys []ticketKey
// autoSessionTicketKeys is like sessionTicketKeys but is owned by the
// auto-rotation logic. See Config.ticketKeys.
autoSessionTicketKeys []ticketKey
}
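// Illustrative sketch: the simplest WrapSession/UnwrapSession pair defers to
// the Config's own ticket encryption via the [Config.EncryptTicket] and
// [Config.DecryptTicket] helpers mentioned above, which is close to the
// default behavior when both callbacks are nil. A real implementation would
// typically swap these calls for its own key management or session store.
//
//	cfg := &tls.Config{}
//	cfg.WrapSession = func(cs tls.ConnectionState, ss *tls.SessionState) ([]byte, error) {
//		return cfg.EncryptTicket(cs, ss)
//	}
//	cfg.UnwrapSession = func(identity []byte, cs tls.ConnectionState) (*tls.SessionState, error) {
//		return cfg.DecryptTicket(identity, cs)
//	}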
// EncryptedClientHelloKey holds a private key that is associated
// with a specific ECH config known to a client.
type EncryptedClientHelloKey struct {
// Config should be a marshalled ECHConfig associated with PrivateKey. This
// must match the config provided to clients byte-for-byte. The config
// should only specify the DHKEM(X25519, HKDF-SHA256) KEM ID (0x0020), the
// HKDF-SHA256 KDF ID (0x0001), and a subset of the following AEAD IDs:
// AES-128-GCM (0x0001), AES-256-GCM (0x0002), ChaCha20Poly1305 (0x0003).
Config []byte
// PrivateKey should be a marshalled private key. Currently, we expect
// this to be the output of [ecdh.PrivateKey.Bytes].
PrivateKey []byte
// SendAsRetry indicates if Config should be sent as part of the list of
// retry configs when ECH is requested by the client but rejected by the
// server.
SendAsRetry bool
}
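// Illustrative sketch: enabling ECH on a client. echConfigList is a
// hypothetical serialized ECHConfigList obtained out of band, for example from
// a DNS HTTPS resource record; MinVersion, if set, must allow TLS 1.3.
//
//	cfg := &tls.Config{
//		MinVersion:                     tls.VersionTLS13,
//		EncryptedClientHelloConfigList: echConfigList,
//	}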
const (
// ticketKeyLifetime is how long a ticket key remains valid and can be used to
// resume a client connection.
ticketKeyLifetime = 7 * 24 * time.Hour // 7 days
// ticketKeyRotation is how often the server should rotate the session ticket key
// that is used for new tickets.
ticketKeyRotation = 24 * time.Hour
)
// ticketKey is the internal representation of a session ticket key.
type ticketKey struct {
aesKey [16]byte
hmacKey [16]byte
// created is the time at which this ticket key was created. See Config.ticketKeys.
created time.Time
}
// ticketKeyFromBytes converts from the external representation of a session
// ticket key to a ticketKey. Externally, session ticket keys are 32 random
// bytes and this function expands that into sufficient name and key material.
func (c *Config) ticketKeyFromBytes(b [32]byte) (key ticketKey) {
hashed := sha512.Sum512(b[:])
// The first 16 bytes of the hash used to be exposed on the wire as a ticket
// prefix. They MUST NOT be used as a secret. In the future, it would make
// sense to use a proper KDF here, like HKDF with a fixed salt.
const legacyTicketKeyNameLen = 16
copy(key.aesKey[:], hashed[legacyTicketKeyNameLen:])
copy(key.hmacKey[:], hashed[legacyTicketKeyNameLen+len(key.aesKey):])
key.created = c.time()
return key
}
// maxSessionTicketLifetime is the maximum allowed lifetime of a TLS 1.3 session
// ticket, and the lifetime we set for all tickets we send.
const maxSessionTicketLifetime = 7 * 24 * time.Hour
// Clone returns a shallow clone of c or nil if c is nil. It is safe to clone a [Config] that is
// being used concurrently by a TLS client or server.
func (c *Config) Clone() *Config {
if c == nil {
return nil
}
c.mutex.RLock()
defer c.mutex.RUnlock()
return &Config{
Rand: c.Rand,
Time: c.Time,
Certificates: c.Certificates,
NameToCertificate: c.NameToCertificate,
GetCertificate: c.GetCertificate,
GetClientCertificate: c.GetClientCertificate,
GetConfigForClient: c.GetConfigForClient,
GetEncryptedClientHelloKeys: c.GetEncryptedClientHelloKeys,
VerifyPeerCertificate: c.VerifyPeerCertificate,
VerifyConnection: c.VerifyConnection,
RootCAs: c.RootCAs,
NextProtos: c.NextProtos,
ServerName: c.ServerName,
ClientAuth: c.ClientAuth,
ClientCAs: c.ClientCAs,
InsecureSkipVerify: c.InsecureSkipVerify,
CipherSuites: c.CipherSuites,
PreferServerCipherSuites: c.PreferServerCipherSuites,
SessionTicketsDisabled: c.SessionTicketsDisabled,
SessionTicketKey: c.SessionTicketKey,
ClientSessionCache: c.ClientSessionCache,
UnwrapSession: c.UnwrapSession,
WrapSession: c.WrapSession,
MinVersion: c.MinVersion,
MaxVersion: c.MaxVersion,
CurvePreferences: c.CurvePreferences,
DynamicRecordSizingDisabled: c.DynamicRecordSizingDisabled,
Renegotiation: c.Renegotiation,
KeyLogWriter: c.KeyLogWriter,
EncryptedClientHelloConfigList: c.EncryptedClientHelloConfigList,
EncryptedClientHelloRejectionVerify: c.EncryptedClientHelloRejectionVerify,
EncryptedClientHelloKeys: c.EncryptedClientHelloKeys,
sessionTicketKeys: c.sessionTicketKeys,
autoSessionTicketKeys: c.autoSessionTicketKeys,
}
}
// deprecatedSessionTicketKey is set as the prefix of SessionTicketKey if it was
// randomized for backwards compatibility but is not in use.
var deprecatedSessionTicketKey = []byte("DEPRECATED")
// initLegacySessionTicketKeyRLocked ensures the legacy SessionTicketKey field is
// randomized if empty, and that sessionTicketKeys is populated from it otherwise.
func (c *Config) initLegacySessionTicketKeyRLocked() {
// Don't write if SessionTicketKey is already defined as our deprecated string,
// or if it is defined by the user but sessionTicketKeys is already set.
if c.SessionTicketKey != [32]byte{} &&
(bytes.HasPrefix(c.SessionTicketKey[:], deprecatedSessionTicketKey) || len(c.sessionTicketKeys) > 0) {
return
}
// We need to write some data, so get an exclusive lock and re-check any conditions.
c.mutex.RUnlock()
defer c.mutex.RLock()
c.mutex.Lock()
defer c.mutex.Unlock()
if c.SessionTicketKey == [32]byte{} {
if _, err := io.ReadFull(c.rand(), c.SessionTicketKey[:]); err != nil {
panic(fmt.Sprintf("tls: unable to generate random session ticket key: %v", err))
}
// Write the deprecated prefix at the beginning so we know we created
// it. This key with the DEPRECATED prefix isn't used as an actual
// session ticket key, and is only randomized in case the application
// reuses it for some reason.
copy(c.SessionTicketKey[:], deprecatedSessionTicketKey)
} else if !bytes.HasPrefix(c.SessionTicketKey[:], deprecatedSessionTicketKey) && len(c.sessionTicketKeys) == 0 {
c.sessionTicketKeys = []ticketKey{c.ticketKeyFromBytes(c.SessionTicketKey)}
}
}
// ticketKeys returns the ticketKeys for this connection.
// If configForClient has explicitly set keys, those will
// be returned. Otherwise, the keys on c will be used and
// may be rotated if auto-managed.
// During rotation, any expired session ticket keys are deleted from
// c.sessionTicketKeys. If the session ticket key that is currently
// encrypting tickets (i.e. the first ticketKey in c.sessionTicketKeys)
// is not fresh, then a new session ticket key will be
// created and prepended to c.sessionTicketKeys.
func (c *Config) ticketKeys(configForClient *Config) []ticketKey {
// If the ConfigForClient callback returned a Config with explicitly set
// keys, use those, otherwise just use the original Config.
if configForClient != nil {
configForClient.mutex.RLock()
if configForClient.SessionTicketsDisabled {
configForClient.mutex.RUnlock()
return nil
}
configForClient.initLegacySessionTicketKeyRLocked()
if len(configForClient.sessionTicketKeys) != 0 {
ret := configForClient.sessionTicketKeys
configForClient.mutex.RUnlock()
return ret
}
configForClient.mutex.RUnlock()
}
c.mutex.RLock()
defer c.mutex.RUnlock()
if c.SessionTicketsDisabled {
return nil
}
c.initLegacySessionTicketKeyRLocked()
if len(c.sessionTicketKeys) != 0 {
return c.sessionTicketKeys
}
// Fast path for the common case where the key is fresh enough.
if len(c.autoSessionTicketKeys) > 0 && c.time().Sub(c.autoSessionTicketKeys[0].created) < ticketKeyRotation {
return c.autoSessionTicketKeys
}
// autoSessionTicketKeys are managed by auto-rotation.
c.mutex.RUnlock()
defer c.mutex.RLock()
c.mutex.Lock()
defer c.mutex.Unlock()
// Re-check the condition in case it changed since obtaining the new lock.
if len(c.autoSessionTicketKeys) == 0 || c.time().Sub(c.autoSessionTicketKeys[0].created) >= ticketKeyRotation {
var newKey [32]byte
if _, err := io.ReadFull(c.rand(), newKey[:]); err != nil {
panic(fmt.Sprintf("unable to generate random session ticket key: %v", err))
}
valid := make([]ticketKey, 0, len(c.autoSessionTicketKeys)+1)
valid = append(valid, c.ticketKeyFromBytes(newKey))
for _, k := range c.autoSessionTicketKeys {
// While rotating the current key, also remove any expired ones.
if c.time().Sub(k.created) < ticketKeyLifetime {
valid = append(valid, k)
}
}
c.autoSessionTicketKeys = valid
}
return c.autoSessionTicketKeys
}
// SetSessionTicketKeys updates the session ticket keys for a server.
//
// The first key will be used when creating new tickets, while all keys can be
// used for decrypting tickets. It is safe to call this function while the
// server is running in order to rotate the session ticket keys. The function
// will panic if keys is empty.
//
// Calling this function will turn off automatic session ticket key rotation.
//
// If multiple servers are terminating connections for the same host they should
// all have the same session ticket keys. If the session ticket keys leak,
// previously recorded and future TLS connections using those keys might be
// compromised.
func (c *Config) SetSessionTicketKeys(keys [][32]byte) {
if len(keys) == 0 {
panic("tls: keys must have at least one key")
}
newKeys := make([]ticketKey, len(keys))
for i, bytes := range keys {
newKeys[i] = c.ticketKeyFromBytes(bytes)
}
c.mutex.Lock()
c.sessionTicketKeys = newKeys
c.mutex.Unlock()
}
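// Illustrative sketch: servers terminating TLS for the same host can share and
// rotate ticket keys out of band. loadSharedTicketKeys is a hypothetical helper
// returning the current and previous 32-byte keys, newest first; calling
// SetSessionTicketKeys disables the automatic rotation described above.
//
//	keys, err := loadSharedTicketKeys() // [][32]byte, newest first
//	if err == nil && len(keys) > 0 {
//		cfg.SetSessionTicketKeys(keys)
//	}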
func (c *Config) rand() io.Reader {
r := c.Rand
if r == nil {
return rand.Reader
}
return r
}
func (c *Config) time() time.Time {
t := c.Time
if t == nil {
t = time.Now
}
return t()
}
func (c *Config) cipherSuites(aesGCMPreferred bool) []uint16 {
var cipherSuites []uint16
if c.CipherSuites == nil {
cipherSuites = defaultCipherSuites(aesGCMPreferred)
} else {
cipherSuites = supportedCipherSuites(aesGCMPreferred)
cipherSuites = slices.DeleteFunc(cipherSuites, func(id uint16) bool {
return !slices.Contains(c.CipherSuites, id)
})
}
if fips140tls.Required() {
cipherSuites = slices.DeleteFunc(cipherSuites, func(id uint16) bool {
return !slices.Contains(allowedCipherSuitesFIPS, id)
})
}
return cipherSuites
}
// supportedCipherSuites returns the supported TLS 1.0–1.2 cipher suites in an
// undefined order. For preference ordering, use [Config.cipherSuites].
func (c *Config) supportedCipherSuites() []uint16 {
return c.cipherSuites(false)
}
var supportedVersions = []uint16{
VersionTLS13,
VersionTLS12,
VersionTLS11,
VersionTLS10,
}
// roleClient and roleServer are meant to be passed to supportedVersions and its
// callers for more readability at the callsite.
const roleClient = true
const roleServer = false
var tls10server = godebug.New("tls10server")
// supportedVersions returns the list of supported TLS versions, sorted from
// highest to lowest (and hence also in preference order).
func (c *Config) supportedVersions(isClient bool) []uint16 {
versions := make([]uint16, 0, len(supportedVersions))
for _, v := range supportedVersions {
if fips140tls.Required() && !slices.Contains(allowedSupportedVersionsFIPS, v) {
continue
}
if (c == nil || c.MinVersion == 0) && v < VersionTLS12 {
if isClient || tls10server.Value() != "1" {
continue
}
}
if isClient && c.EncryptedClientHelloConfigList != nil && v < VersionTLS13 {
continue
}
if c != nil && c.MinVersion != 0 && v < c.MinVersion {
continue
}
if c != nil && c.MaxVersion != 0 && v > c.MaxVersion {
continue
}
versions = append(versions, v)
}
return versions
}
func (c *Config) maxSupportedVersion(isClient bool) uint16 {
supportedVersions := c.supportedVersions(isClient)
if len(supportedVersions) == 0 {
return 0
}
return supportedVersions[0]
}
// supportedVersionsFromMax returns a list of supported versions derived from a
// legacy maximum version value. Note that only versions supported by this
// library are returned. Any newer peer will use supportedVersions anyway.
func supportedVersionsFromMax(maxVersion uint16) []uint16 {
versions := make([]uint16, 0, len(supportedVersions))
for _, v := range supportedVersions {
if v > maxVersion {
continue
}
versions = append(versions, v)
}
return versions
}
func (c *Config) curvePreferences(version uint16) []CurveID {
curvePreferences := defaultCurvePreferences()
if fips140tls.Required() {
curvePreferences = slices.DeleteFunc(curvePreferences, func(x CurveID) bool {
return !slices.Contains(allowedCurvePreferencesFIPS, x)
})
}
if c != nil && len(c.CurvePreferences) != 0 {
curvePreferences = slices.DeleteFunc(curvePreferences, func(x CurveID) bool {
return !slices.Contains(c.CurvePreferences, x)
})
}
if version < VersionTLS13 {
curvePreferences = slices.DeleteFunc(curvePreferences, isTLS13OnlyKeyExchange)
}
return curvePreferences
}
func (c *Config) supportsCurve(version uint16, curve CurveID) bool {
return slices.Contains(c.curvePreferences(version), curve)
}
// mutualVersion returns the protocol version to use given the advertised
// versions of the peer. The highest supported version is preferred.
func (c *Config) mutualVersion(isClient bool, peerVersions []uint16) (uint16, bool) {
supportedVersions := c.supportedVersions(isClient)
for _, v := range supportedVersions {
if slices.Contains(peerVersions, v) {
return v, true
}
}
return 0, false
}
// errNoCertificates should be an internal detail,
// but widely used packages access it using linkname.
// Notable members of the hall of shame include:
// - github.com/xtls/xray-core
//
// Do not remove or change the type signature.
// See go.dev/issue/67401.
//
//go:linkname errNoCertificates
var errNoCertificates = errors.New("tls: no certificates configured")
// getCertificate returns the best certificate for the given ClientHelloInfo,
// defaulting to the first element of c.Certificates.
func (c *Config) getCertificate(clientHello *ClientHelloInfo) (*Certificate, error) {
if c.GetCertificate != nil &&
(len(c.Certificates) == 0 || len(clientHello.ServerName) > 0) {
cert, err := c.GetCertificate(clientHello)
if cert != nil || err != nil {
return cert, err
}
}
if len(c.Certificates) == 0 {
return nil, errNoCertificates
}
if len(c.Certificates) == 1 {
// There's only one choice, so no point doing any work.
return &c.Certificates[0], nil
}
if c.NameToCertificate != nil {
name := strings.ToLower(clientHello.ServerName)
if cert, ok := c.NameToCertificate[name]; ok {
return cert, nil
}
if len(name) > 0 {
labels := strings.Split(name, ".")
labels[0] = "*"
wildcardName := strings.Join(labels, ".")
if cert, ok := c.NameToCertificate[wildcardName]; ok {
return cert, nil
}
}
}
for _, cert := range c.Certificates {
if err := clientHello.SupportsCertificate(&cert); err == nil {
return &cert, nil
}
}
// If nothing matches, return the first certificate.
return &c.Certificates[0], nil
}
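// Illustrative sketch (not part of the original source): a server with several
// chains can either populate Config.Certificates and let getCertificate pick a
// compatible one, or key a GetCertificate callback on the SNI value. The certs
// map below is hypothetical.
//
//	certs := map[string]*tls.Certificate{} // e.g. filled via tls.LoadX509KeyPair
//	cfg := &tls.Config{
//		GetCertificate: func(chi *tls.ClientHelloInfo) (*tls.Certificate, error) {
//			if c, ok := certs[chi.ServerName]; ok {
//				return c, nil
//			}
//			return nil, nil // fall back to Config.Certificates
//		},
//	}
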
// SupportsCertificate returns nil if the provided certificate is supported by
// the client that sent the ClientHello. Otherwise, it returns an error
// describing the reason for the incompatibility.
//
// If this [ClientHelloInfo] was passed to a GetConfigForClient or GetCertificate
// callback, this method will take into account the associated [Config]. Note that
// if GetConfigForClient returns a different [Config], the change can't be
// accounted for by this method.
//
// This function will call x509.ParseCertificate unless c.Leaf is set, which can
// incur a significant performance cost.
func (chi *ClientHelloInfo) SupportsCertificate(c *Certificate) error {
// Note that we don't currently support certificate_authorities or
// signature_algorithms_cert, and don't check the algorithms of the
// signatures on the chain (which anyway are a SHOULD, see RFC 8446,
// Section 4.4.2.2).
config := chi.config
if config == nil {
config = &Config{}
}
vers, ok := config.mutualVersion(roleServer, chi.SupportedVersions)
if !ok {
return errors.New("no mutually supported protocol versions")
}
// If the client specified the name they are trying to connect to, the
// certificate needs to be valid for it.
if chi.ServerName != "" {
x509Cert, err := c.leaf()
if err != nil {
return fmt.Errorf("failed to parse certificate: %w", err)
}
if err := x509Cert.VerifyHostname(chi.ServerName); err != nil {
return fmt.Errorf("certificate is not valid for requested server name: %w", err)
}
}
// supportsRSAFallback returns nil if the certificate and connection support
// the static RSA key exchange, and unsupported otherwise. The logic for
// supporting static RSA is completely disjoint from the logic for
// supporting signed key exchanges, so we just check it as a fallback.
supportsRSAFallback := func(unsupported error) error {
// TLS 1.3 dropped support for the static RSA key exchange.
if vers == VersionTLS13 {
return unsupported
}
// The static RSA key exchange works by decrypting a challenge with the
// RSA private key, not by signing, so check the PrivateKey implements
// crypto.Decrypter, like *rsa.PrivateKey does.
if priv, ok := c.PrivateKey.(crypto.Decrypter); ok {
if _, ok := priv.Public().(*rsa.PublicKey); !ok {
return unsupported
}
} else {
return unsupported
}
// Finally, there needs to be a mutual cipher suite that uses the static
// RSA key exchange instead of ECDHE.
rsaCipherSuite := selectCipherSuite(chi.CipherSuites, config.supportedCipherSuites(), func(c *cipherSuite) bool {
if c.flags&suiteECDHE != 0 {
return false
}
if vers < VersionTLS12 && c.flags&suiteTLS12 != 0 {
return false
}
return true
})
if rsaCipherSuite == nil {
return unsupported
}
return nil
}
// If the client sent the signature_algorithms extension, ensure it supports
// schemes we can use with this certificate and TLS version.
if len(chi.SignatureSchemes) > 0 {
if _, err := selectSignatureScheme(vers, c, chi.SignatureSchemes); err != nil {
return supportsRSAFallback(err)
}
}
// In TLS 1.3 we are done because supported_groups is only relevant to the
// ECDHE computation, point format negotiation is removed, cipher suites are
// only relevant to the AEAD choice, and static RSA does not exist.
if vers == VersionTLS13 {
return nil
}
// The only signed key exchange we support is ECDHE.
ecdheSupported, err := supportsECDHE(config, vers, chi.SupportedCurves, chi.SupportedPoints)
if err != nil {
return err
}
if !ecdheSupported {
return supportsRSAFallback(errors.New("client doesn't support ECDHE, can only use legacy RSA key exchange"))
}
var ecdsaCipherSuite bool
if priv, ok := c.PrivateKey.(crypto.Signer); ok {
switch pub := priv.Public().(type) {
case *ecdsa.PublicKey:
var curve CurveID
switch pub.Curve {
case elliptic.P256():
curve = CurveP256
case elliptic.P384():
curve = CurveP384
case elliptic.P521():
curve = CurveP521
default:
return supportsRSAFallback(unsupportedCertificateError(c))
}
var curveOk bool
for _, c := range chi.SupportedCurves {
if c == curve && config.supportsCurve(vers, c) {
curveOk = true
break
}
}
if !curveOk {
return errors.New("client doesn't support certificate curve")
}
ecdsaCipherSuite = true
case ed25519.PublicKey:
if vers < VersionTLS12 || len(chi.SignatureSchemes) == 0 {
return errors.New("connection doesn't support Ed25519")
}
ecdsaCipherSuite = true
case *rsa.PublicKey:
default:
return supportsRSAFallback(unsupportedCertificateError(c))
}
} else {
return supportsRSAFallback(unsupportedCertificateError(c))
}
// Make sure that there is a mutually supported cipher suite that works with
// this certificate. Cipher suite selection will then apply the logic in
// reverse to pick it. See also serverHandshakeState.cipherSuiteOk.
cipherSuite := selectCipherSuite(chi.CipherSuites, config.supportedCipherSuites(), func(c *cipherSuite) bool {
if c.flags&suiteECDHE == 0 {
return false
}
if c.flags&suiteECSign != 0 {
if !ecdsaCipherSuite {
return false
}
} else {
if ecdsaCipherSuite {
return false
}
}
if vers < VersionTLS12 && c.flags&suiteTLS12 != 0 {
return false
}
return true
})
if cipherSuite == nil {
return supportsRSAFallback(errors.New("client doesn't support any cipher suites compatible with the certificate"))
}
return nil
}
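// Illustrative sketch (not part of the original source): SupportsCertificate
// can be used inside a GetCertificate callback to return the first candidate
// the connecting client can actually negotiate. The candidates slice is
// hypothetical.
//
//	cfg := &tls.Config{
//		GetCertificate: func(chi *tls.ClientHelloInfo) (*tls.Certificate, error) {
//			for i := range candidates {
//				if err := chi.SupportsCertificate(&candidates[i]); err == nil {
//					return &candidates[i], nil
//				}
//			}
//			return nil, nil // fall back to Config.Certificates
//		},
//	}
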
// SupportsCertificate returns nil if the provided certificate is supported by
// the server that sent the CertificateRequest. Otherwise, it returns an error
// describing the reason for the incompatibility.
func (cri *CertificateRequestInfo) SupportsCertificate(c *Certificate) error {
if _, err := selectSignatureScheme(cri.Version, c, cri.SignatureSchemes); err != nil {
return err
}
if len(cri.AcceptableCAs) == 0 {
return nil
}
for j, cert := range c.Certificate {
x509Cert := c.Leaf
// Parse the certificate if this isn't the leaf node, or if
// chain.Leaf was nil.
if j != 0 || x509Cert == nil {
var err error
if x509Cert, err = x509.ParseCertificate(cert); err != nil {
return fmt.Errorf("failed to parse certificate #%d in the chain: %w", j, err)
}
}
for _, ca := range cri.AcceptableCAs {
if bytes.Equal(x509Cert.RawIssuer, ca) {
return nil
}
}
}
return errors.New("chain is not signed by an acceptable CA")
}
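// Illustrative sketch (not part of the original source): a client's
// GetClientCertificate callback can use SupportsCertificate to send a
// certificate only when the server's CertificateRequest can accept it.
// clientCert is hypothetical.
//
//	cfg := &tls.Config{
//		GetClientCertificate: func(cri *tls.CertificateRequestInfo) (*tls.Certificate, error) {
//			if err := cri.SupportsCertificate(&clientCert); err == nil {
//				return &clientCert, nil
//			}
//			// An empty Certificate means no certificate will be sent.
//			return &tls.Certificate{}, nil
//		},
//	}
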
// BuildNameToCertificate parses c.Certificates and builds c.NameToCertificate
// from the CommonName and Subject Alternative Name fields of each of the leaf
// certificates.
//
// Deprecated: NameToCertificate only allows associating a single certificate
// with a given name. Leave that field nil to let the library select the first
// compatible chain from Certificates.
func (c *Config) BuildNameToCertificate() {
c.NameToCertificate = make(map[string]*Certificate)
for i := range c.Certificates {
cert := &c.Certificates[i]
x509Cert, err := cert.leaf()
if err != nil {
continue
}
// If SANs are *not* present, some clients will consider the certificate
// valid for the name in the Common Name.
if x509Cert.Subject.CommonName != "" && len(x509Cert.DNSNames) == 0 {
c.NameToCertificate[x509Cert.Subject.CommonName] = cert
}
for _, san := range x509Cert.DNSNames {
c.NameToCertificate[san] = cert
}
}
}
const (
keyLogLabelTLS12 = "CLIENT_RANDOM"
keyLogLabelClientHandshake = "CLIENT_HANDSHAKE_TRAFFIC_SECRET"
keyLogLabelServerHandshake = "SERVER_HANDSHAKE_TRAFFIC_SECRET"
keyLogLabelClientTraffic = "CLIENT_TRAFFIC_SECRET_0"
keyLogLabelServerTraffic = "SERVER_TRAFFIC_SECRET_0"
)
func (c *Config) writeKeyLog(label string, clientRandom, secret []byte) error {
if c.KeyLogWriter == nil {
return nil
}
logLine := fmt.Appendf(nil, "%s %x %x\n", label, clientRandom, secret)
writerMutex.Lock()
_, err := c.KeyLogWriter.Write(logLine)
writerMutex.Unlock()
return err
}
// writerMutex protects all KeyLogWriters globally. It is rarely enabled,
// and is only for debugging, so a global mutex saves space.
var writerMutex sync.Mutex
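// Illustrative sketch (not part of the original source): KeyLogWriter is the
// only consumer of writeKeyLog. Pointing it at a file in NSS key log format
// lets tools such as Wireshark decrypt captured traffic; it should only be
// used for debugging. The file name below is hypothetical.
//
//	f, err := os.OpenFile("keylog.txt", os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0o600)
//	if err != nil {
//		// handle the error
//	}
//	cfg := &tls.Config{KeyLogWriter: f}
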
// A Certificate is a chain of one or more certificates, leaf first.
type Certificate struct {
Certificate [][]byte
// PrivateKey contains the private key corresponding to the public key in
// Leaf. This must implement crypto.Signer with an RSA, ECDSA or Ed25519 PublicKey.
// For a server up to TLS 1.2, it can also implement crypto.Decrypter with
// an RSA PublicKey.
PrivateKey crypto.PrivateKey
// SupportedSignatureAlgorithms is an optional list restricting what
// signature algorithms the PrivateKey can be used for.
SupportedSignatureAlgorithms []SignatureScheme
// OCSPStaple contains an optional OCSP response which will be served
// to clients that request it.
OCSPStaple []byte
// SignedCertificateTimestamps contains an optional list of Signed
// Certificate Timestamps which will be served to clients that request it.
SignedCertificateTimestamps [][]byte
// Leaf is the parsed form of the leaf certificate, which may be initialized
// using x509.ParseCertificate to reduce per-handshake processing. If nil,
// the leaf certificate will be parsed as needed.
Leaf *x509.Certificate
}
// leaf returns the parsed leaf certificate, either from c.Leaf or by parsing
// the corresponding c.Certificate[0].
func (c *Certificate) leaf() (*x509.Certificate, error) {
if c.Leaf != nil {
return c.Leaf, nil
}
return x509.ParseCertificate(c.Certificate[0])
}
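// Illustrative sketch (not part of the original source): a Certificate is
// usually built with LoadX509KeyPair or X509KeyPair. Pre-populating Leaf
// (which recent versions of those helpers may already do) avoids the parse in
// leaf() on every handshake. File names are hypothetical.
//
//	cert, err := tls.LoadX509KeyPair("server.crt", "server.key")
//	if err != nil {
//		// handle the error
//	}
//	if cert.Leaf == nil {
//		cert.Leaf, _ = x509.ParseCertificate(cert.Certificate[0])
//	}
//	cfg := &tls.Config{Certificates: []tls.Certificate{cert}}
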
type handshakeMessage interface {
marshal() ([]byte, error)
unmarshal([]byte) bool
}
type handshakeMessageWithOriginalBytes interface {
handshakeMessage
// originalBytes should return the original bytes that were passed to
// unmarshal to create the message. If the message was not produced by
// unmarshal, it should return nil.
originalBytes() []byte
}
// lruSessionCache is a ClientSessionCache implementation that uses an LRU
// caching strategy.
type lruSessionCache struct {
sync.Mutex
m map[string]*list.Element
q *list.List
capacity int
}
type lruSessionCacheEntry struct {
sessionKey string
state *ClientSessionState
}
// NewLRUClientSessionCache returns a [ClientSessionCache] with the given
// capacity that uses an LRU strategy. If capacity is < 1, a default capacity
// is used instead.
func NewLRUClientSessionCache(capacity int) ClientSessionCache {
const defaultSessionCacheCapacity = 64
if capacity < 1 {
capacity = defaultSessionCacheCapacity
}
return &lruSessionCache{
m: make(map[string]*list.Element),
q: list.New(),
capacity: capacity,
}
}
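// Illustrative sketch (not part of the original source): a client enables
// session resumption across connections by sharing one cache in its Config.
//
//	cfg := &tls.Config{
//		ClientSessionCache: tls.NewLRUClientSessionCache(128),
//	}
//	conn, err := tls.Dial("tcp", "example.com:443", cfg)
//	if err != nil {
//		// handle the error
//	}
//	defer conn.Close()
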
// Put adds the provided (sessionKey, cs) pair to the cache. If cs is nil, the entry
// corresponding to sessionKey is removed from the cache instead.
func (c *lruSessionCache) Put(sessionKey string, cs *ClientSessionState) {
c.Lock()
defer c.Unlock()
if elem, ok := c.m[sessionKey]; ok {
if cs == nil {
c.q.Remove(elem)
delete(c.m, sessionKey)
} else {
entry := elem.Value.(*lruSessionCacheEntry)
entry.state = cs
c.q.MoveToFront(elem)
}
return
}
if c.q.Len() < c.capacity {
entry := &lruSessionCacheEntry{sessionKey, cs}
c.m[sessionKey] = c.q.PushFront(entry)
return
}
elem := c.q.Back()
entry := elem.Value.(*lruSessionCacheEntry)
delete(c.m, entry.sessionKey)
entry.sessionKey = sessionKey
entry.state = cs
c.q.MoveToFront(elem)
c.m[sessionKey] = elem
}
// Get returns the [ClientSessionState] value associated with a given key. It
// returns (nil, false) if no value is found.
func (c *lruSessionCache) Get(sessionKey string) (*ClientSessionState, bool) {
c.Lock()
defer c.Unlock()
if elem, ok := c.m[sessionKey]; ok {
c.q.MoveToFront(elem)
return elem.Value.(*lruSessionCacheEntry).state, true
}
return nil, false
}
var emptyConfig Config
func defaultConfig() *Config {
return &emptyConfig
}
func unexpectedMessageError(wanted, got any) error {
return fmt.Errorf("tls: received unexpected handshake message of type %T when waiting for %T", got, wanted)
}
var testingOnlySupportedSignatureAlgorithms []SignatureScheme
// supportedSignatureAlgorithms returns the supported signature algorithms for
// the given minimum TLS version, to advertise in ClientHello and
// CertificateRequest messages.
func supportedSignatureAlgorithms(minVers uint16) []SignatureScheme {
sigAlgs := defaultSupportedSignatureAlgorithms()
if testingOnlySupportedSignatureAlgorithms != nil {
sigAlgs = slices.Clone(testingOnlySupportedSignatureAlgorithms)
}
return slices.DeleteFunc(sigAlgs, func(s SignatureScheme) bool {
return isDisabledSignatureAlgorithm(minVers, s, false)
})
}
var tlssha1 = godebug.New("tlssha1")
func isDisabledSignatureAlgorithm(version uint16, s SignatureScheme, isCert bool) bool {
if fips140tls.Required() && !slices.Contains(allowedSignatureAlgorithmsFIPS, s) {
return true
}
// For the _cert extension we include all algorithms, including SHA-1 and
// PKCS#1 v1.5, because it's more likely that something on our side will be
// willing to accept a *-with-SHA1 certificate (e.g. with a custom
// VerifyConnection or by a direct match with the CertPool), than that the
// peer would have a better certificate but is just choosing not to send it.
// crypto/x509 will refuse to verify important SHA-1 signatures anyway.
if isCert {
return false
}
// TLS 1.3 removed support for PKCS#1 v1.5 and SHA-1 signatures,
// and Go 1.25 removed support for SHA-1 signatures in TLS 1.2.
if version > VersionTLS12 {
sigType, sigHash, _ := typeAndHashFromSignatureScheme(s)
if sigType == signaturePKCS1v15 || sigHash == crypto.SHA1 {
return true
}
} else if tlssha1.Value() != "1" {
_, sigHash, _ := typeAndHashFromSignatureScheme(s)
if sigHash == crypto.SHA1 {
return true
}
}
return false
}
// supportedSignatureAlgorithmsCert returns the supported algorithms for
// signatures in certificates.
func supportedSignatureAlgorithmsCert() []SignatureScheme {
sigAlgs := defaultSupportedSignatureAlgorithms()
return slices.DeleteFunc(sigAlgs, func(s SignatureScheme) bool {
return isDisabledSignatureAlgorithm(0, s, true)
})
}
func isSupportedSignatureAlgorithm(sigAlg SignatureScheme, supportedSignatureAlgorithms []SignatureScheme) bool {
return slices.Contains(supportedSignatureAlgorithms, sigAlg)
}
// CertificateVerificationError is returned when certificate verification fails during the handshake.
type CertificateVerificationError struct {
// UnverifiedCertificates and its contents should not be modified.
UnverifiedCertificates []*x509.Certificate
Err error
}
func (e *CertificateVerificationError) Error() string {
return fmt.Sprintf("tls: failed to verify certificate: %s", e.Err)
}
func (e *CertificateVerificationError) Unwrap() error {
return e.Err
}
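// Illustrative sketch (not part of the original source): callers can use
// errors.As to recover the unverified chain when a handshake fails
// verification, e.g. for logging.
//
//	conn, err := tls.Dial("tcp", "example.com:443", nil)
//	var cve *tls.CertificateVerificationError
//	if errors.As(err, &cve) {
//		for _, cert := range cve.UnverifiedCertificates {
//			log.Printf("peer presented: %v", cert.Subject)
//		}
//	}
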
// fipsAllowedChains returns chains that are allowed to be used in a TLS connection
// based on the current fips140tls enforcement setting.
//
// If fips140tls is not required, the chains are returned as-is with no processing.
// Otherwise, the returned chains are filtered to only those allowed by FIPS 140-3.
// If this results in no chains it returns an error.
func fipsAllowedChains(chains [][]*x509.Certificate) ([][]*x509.Certificate, error) {
if !fips140tls.Required() {
return chains, nil
}
permittedChains := make([][]*x509.Certificate, 0, len(chains))
for _, chain := range chains {
if fipsAllowChain(chain) {
permittedChains = append(permittedChains, chain)
}
}
if len(permittedChains) == 0 {
return nil, errors.New("tls: no FIPS compatible certificate chains found")
}
return permittedChains, nil
}
func fipsAllowChain(chain []*x509.Certificate) bool {
if len(chain) == 0 {
return false
}
for _, cert := range chain {
if !isCertificateAllowedFIPS(cert) {
return false
}
}
return true
}
// Code generated by "stringer -linecomment -type=SignatureScheme,CurveID,ClientAuthType -output=common_string.go"; DO NOT EDIT.
package tls
import "strconv"
func _() {
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
var x [1]struct{}
_ = x[PKCS1WithSHA256-1025]
_ = x[PKCS1WithSHA384-1281]
_ = x[PKCS1WithSHA512-1537]
_ = x[PSSWithSHA256-2052]
_ = x[PSSWithSHA384-2053]
_ = x[PSSWithSHA512-2054]
_ = x[ECDSAWithP256AndSHA256-1027]
_ = x[ECDSAWithP384AndSHA384-1283]
_ = x[ECDSAWithP521AndSHA512-1539]
_ = x[Ed25519-2055]
_ = x[PKCS1WithSHA1-513]
_ = x[ECDSAWithSHA1-515]
}
const (
_SignatureScheme_name_0 = "PKCS1WithSHA1"
_SignatureScheme_name_1 = "ECDSAWithSHA1"
_SignatureScheme_name_2 = "PKCS1WithSHA256"
_SignatureScheme_name_3 = "ECDSAWithP256AndSHA256"
_SignatureScheme_name_4 = "PKCS1WithSHA384"
_SignatureScheme_name_5 = "ECDSAWithP384AndSHA384"
_SignatureScheme_name_6 = "PKCS1WithSHA512"
_SignatureScheme_name_7 = "ECDSAWithP521AndSHA512"
_SignatureScheme_name_8 = "PSSWithSHA256PSSWithSHA384PSSWithSHA512Ed25519"
)
var (
_SignatureScheme_index_8 = [...]uint8{0, 13, 26, 39, 46}
)
func (i SignatureScheme) String() string {
switch {
case i == 513:
return _SignatureScheme_name_0
case i == 515:
return _SignatureScheme_name_1
case i == 1025:
return _SignatureScheme_name_2
case i == 1027:
return _SignatureScheme_name_3
case i == 1281:
return _SignatureScheme_name_4
case i == 1283:
return _SignatureScheme_name_5
case i == 1537:
return _SignatureScheme_name_6
case i == 1539:
return _SignatureScheme_name_7
case 2052 <= i && i <= 2055:
i -= 2052
return _SignatureScheme_name_8[_SignatureScheme_index_8[i]:_SignatureScheme_index_8[i+1]]
default:
return "SignatureScheme(" + strconv.FormatInt(int64(i), 10) + ")"
}
}
func _() {
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
var x [1]struct{}
_ = x[CurveP256-23]
_ = x[CurveP384-24]
_ = x[CurveP521-25]
_ = x[X25519-29]
_ = x[X25519MLKEM768-4588]
}
const (
_CurveID_name_0 = "CurveP256CurveP384CurveP521"
_CurveID_name_1 = "X25519"
_CurveID_name_2 = "X25519MLKEM768"
)
var (
_CurveID_index_0 = [...]uint8{0, 9, 18, 27}
)
func (i CurveID) String() string {
switch {
case 23 <= i && i <= 25:
i -= 23
return _CurveID_name_0[_CurveID_index_0[i]:_CurveID_index_0[i+1]]
case i == 29:
return _CurveID_name_1
case i == 4588:
return _CurveID_name_2
default:
return "CurveID(" + strconv.FormatInt(int64(i), 10) + ")"
}
}
func _() {
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
var x [1]struct{}
_ = x[NoClientCert-0]
_ = x[RequestClientCert-1]
_ = x[RequireAnyClientCert-2]
_ = x[VerifyClientCertIfGiven-3]
_ = x[RequireAndVerifyClientCert-4]
}
const _ClientAuthType_name = "NoClientCertRequestClientCertRequireAnyClientCertVerifyClientCertIfGivenRequireAndVerifyClientCert"
var _ClientAuthType_index = [...]uint8{0, 12, 29, 49, 72, 98}
func (i ClientAuthType) String() string {
if i < 0 || i >= ClientAuthType(len(_ClientAuthType_index)-1) {
return "ClientAuthType(" + strconv.FormatInt(int64(i), 10) + ")"
}
return _ClientAuthType_name[_ClientAuthType_index[i]:_ClientAuthType_index[i+1]]
}
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// TLS low level connection and record layer
package tls
import (
"bytes"
"context"
"crypto/cipher"
"crypto/subtle"
"crypto/x509"
"errors"
"fmt"
"hash"
"internal/godebug"
"io"
"net"
"sync"
"sync/atomic"
"time"
)
// A Conn represents a secured connection.
// It implements the net.Conn interface.
type Conn struct {
// constant
conn net.Conn
isClient bool
handshakeFn func(context.Context) error // (*Conn).clientHandshake or serverHandshake
quic *quicState // nil for non-QUIC connections
// isHandshakeComplete is true if the connection is currently transferring
// application data (i.e. is not currently processing a handshake).
// isHandshakeComplete is true implies handshakeErr == nil.
isHandshakeComplete atomic.Bool
// constant after handshake; protected by handshakeMutex
handshakeMutex sync.Mutex
handshakeErr error // error resulting from handshake
vers uint16 // TLS version
haveVers bool // version has been negotiated
config *Config // configuration passed to constructor
// handshakes counts the number of handshakes performed on the
// connection so far. If renegotiation is disabled then this is either
// zero or one.
handshakes int
extMasterSecret bool
didResume bool // whether this connection was a session resumption
didHRR bool // whether a HelloRetryRequest was sent/received
cipherSuite uint16
curveID CurveID
peerSigAlg SignatureScheme
ocspResponse []byte // stapled OCSP response
scts [][]byte // signed certificate timestamps from server
peerCertificates []*x509.Certificate
// verifiedChains contains the certificate chains that we built, as
// opposed to the ones presented by the server.
verifiedChains [][]*x509.Certificate
// serverName contains the server name indicated by the client, if any.
serverName string
// secureRenegotiation is true if the server echoed the secure
// renegotiation extension. (This is meaningless as a server because
// renegotiation is not supported in that case.)
secureRenegotiation bool
// ekm is a closure for exporting keying material.
ekm func(label string, context []byte, length int) ([]byte, error)
// resumptionSecret is the resumption_master_secret for handling
// or sending NewSessionTicket messages.
resumptionSecret []byte
echAccepted bool
// ticketKeys is the set of active session ticket keys for this
// connection. The first one is used to encrypt new tickets and
// all are tried to decrypt tickets.
ticketKeys []ticketKey
// clientFinishedIsFirst is true if the client sent the first Finished
// message during the most recent handshake. This is recorded because
// the first transmitted Finished message is the tls-unique
// channel-binding value.
clientFinishedIsFirst bool
// closeNotifyErr is any error from sending the alertCloseNotify record.
closeNotifyErr error
// closeNotifySent is true if the Conn attempted to send an
// alertCloseNotify record.
closeNotifySent bool
// clientFinished and serverFinished contain the Finished message sent
// by the client or server in the most recent handshake. This is
// retained to support the renegotiation extension and tls-unique
// channel-binding.
clientFinished [12]byte
serverFinished [12]byte
// clientProtocol is the negotiated ALPN protocol.
clientProtocol string
// input/output
in, out halfConn
rawInput bytes.Buffer // raw input, starting with a record header
input bytes.Reader // application data waiting to be read, from rawInput.Next
hand bytes.Buffer // handshake data waiting to be read
buffering bool // whether records are buffered in sendBuf
sendBuf []byte // a buffer of records waiting to be sent
// bytesSent counts the bytes of application data sent.
// packetsSent counts packets.
bytesSent int64
packetsSent int64
// retryCount counts the number of consecutive non-advancing records
// received by Conn.readRecord. That is, records that neither advance the
// handshake, nor deliver application data. Protected by in.Mutex.
retryCount int
// activeCall indicates whether Close has been called, in the low bit.
// The rest of the bits are the number of goroutines in Conn.Write.
activeCall atomic.Int32
tmp [16]byte
}
// Access to net.Conn methods.
// Cannot just embed net.Conn because that would
// export the struct field too.
// LocalAddr returns the local network address.
func (c *Conn) LocalAddr() net.Addr {
return c.conn.LocalAddr()
}
// RemoteAddr returns the remote network address.
func (c *Conn) RemoteAddr() net.Addr {
return c.conn.RemoteAddr()
}
// SetDeadline sets the read and write deadlines associated with the connection.
// A zero value for t means [Conn.Read] and [Conn.Write] will not time out.
// After a Write has timed out, the TLS state is corrupt and all future writes will return the same error.
func (c *Conn) SetDeadline(t time.Time) error {
return c.conn.SetDeadline(t)
}
// SetReadDeadline sets the read deadline on the underlying connection.
// A zero value for t means [Conn.Read] will not time out.
func (c *Conn) SetReadDeadline(t time.Time) error {
return c.conn.SetReadDeadline(t)
}
// SetWriteDeadline sets the write deadline on the underlying connection.
// A zero value for t means [Conn.Write] will not time out.
// After a [Conn.Write] has timed out, the TLS state is corrupt and all future writes will return the same error.
func (c *Conn) SetWriteDeadline(t time.Time) error {
return c.conn.SetWriteDeadline(t)
}
// NetConn returns the underlying connection that is wrapped by c.
// Note that writing to or reading from this connection directly will corrupt the
// TLS session.
func (c *Conn) NetConn() net.Conn {
return c.conn
}
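// Illustrative sketch (not part of the original source): a Conn usually comes
// from Dial or a Listener, but Client and Server wrap any existing net.Conn;
// NetConn then exposes that underlying connection again.
//
//	raw, err := net.Dial("tcp", "example.com:443")
//	if err != nil {
//		// handle the error
//	}
//	conn := tls.Client(raw, &tls.Config{ServerName: "example.com"})
//	if err := conn.HandshakeContext(context.Background()); err != nil {
//		// handle the error
//	}
//	_ = conn.NetConn() // the same net.Conn passed to tls.Client
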
// A halfConn represents one direction of the record layer
// connection, either sending or receiving.
type halfConn struct {
sync.Mutex
err error // first permanent error
version uint16 // protocol version
cipher any // cipher algorithm
mac hash.Hash
seq [8]byte // 64-bit sequence number
scratchBuf [13]byte // to avoid allocs; interface method args escape
nextCipher any // next encryption state
nextMac hash.Hash // next MAC algorithm
level QUICEncryptionLevel // current QUIC encryption level
trafficSecret []byte // current TLS 1.3 traffic secret
}
type permanentError struct {
err net.Error
}
func (e *permanentError) Error() string { return e.err.Error() }
func (e *permanentError) Unwrap() error { return e.err }
func (e *permanentError) Timeout() bool { return e.err.Timeout() }
func (e *permanentError) Temporary() bool { return false }
func (hc *halfConn) setErrorLocked(err error) error {
if e, ok := err.(net.Error); ok {
hc.err = &permanentError{err: e}
} else {
hc.err = err
}
return hc.err
}
// prepareCipherSpec sets the encryption and MAC states
// that a subsequent changeCipherSpec will use.
func (hc *halfConn) prepareCipherSpec(version uint16, cipher any, mac hash.Hash) {
hc.version = version
hc.nextCipher = cipher
hc.nextMac = mac
}
// changeCipherSpec changes the encryption and MAC states
// to the ones previously passed to prepareCipherSpec.
func (hc *halfConn) changeCipherSpec() error {
if hc.nextCipher == nil || hc.version == VersionTLS13 {
return alertInternalError
}
hc.cipher = hc.nextCipher
hc.mac = hc.nextMac
hc.nextCipher = nil
hc.nextMac = nil
for i := range hc.seq {
hc.seq[i] = 0
}
return nil
}
func (hc *halfConn) setTrafficSecret(suite *cipherSuiteTLS13, level QUICEncryptionLevel, secret []byte) {
hc.trafficSecret = secret
hc.level = level
key, iv := suite.trafficKey(secret)
hc.cipher = suite.aead(key, iv)
for i := range hc.seq {
hc.seq[i] = 0
}
}
// incSeq increments the sequence number.
func (hc *halfConn) incSeq() {
for i := 7; i >= 0; i-- {
hc.seq[i]++
if hc.seq[i] != 0 {
return
}
}
// Not allowed to let sequence number wrap.
// Instead, must renegotiate before it does.
// Not likely enough to bother.
panic("TLS: sequence number wraparound")
}
// explicitNonceLen returns the number of bytes of explicit nonce or IV included
// in each record. Explicit nonces are present only in CBC modes after TLS 1.0
// and in certain AEAD modes in TLS 1.2.
func (hc *halfConn) explicitNonceLen() int {
if hc.cipher == nil {
return 0
}
switch c := hc.cipher.(type) {
case cipher.Stream:
return 0
case aead:
return c.explicitNonceLen()
case cbcMode:
// TLS 1.1 introduced a per-record explicit IV to fix the BEAST attack.
if hc.version >= VersionTLS11 {
return c.BlockSize()
}
return 0
default:
panic("unknown cipher type")
}
}
// extractPadding returns, in constant time, the length of the padding to remove
// from the end of payload. It also returns a byte which is equal to 255 if the
// padding was valid and 0 otherwise. See RFC 2246, Section 6.2.3.2.
func extractPadding(payload []byte) (toRemove int, good byte) {
if len(payload) < 1 {
return 0, 0
}
paddingLen := payload[len(payload)-1]
t := uint(len(payload)-1) - uint(paddingLen)
// if len(payload) >= (paddingLen + 1) then the MSB of t is zero
good = byte(int32(^t) >> 31)
// The maximum possible padding length plus the actual length field
toCheck := 256
// The length of the padded data is public, so we can use an if here
if toCheck > len(payload) {
toCheck = len(payload)
}
for i := 0; i < toCheck; i++ {
t := uint(paddingLen) - uint(i)
// if i <= paddingLen then the MSB of t is zero
mask := byte(int32(^t) >> 31)
b := payload[len(payload)-1-i]
good &^= mask&paddingLen ^ mask&b
}
// We AND together the bits of good and replicate the result across
// all the bits.
good &= good << 4
good &= good << 2
good &= good << 1
good = uint8(int8(good) >> 7)
// Zero the padding length on error. This ensures any unchecked bytes
// are included in the MAC. Otherwise, an attacker that could
// distinguish MAC failures from padding failures could mount an attack
// similar to POODLE in SSL 3.0: given a good ciphertext that uses a
// full block's worth of padding, replace the final block with another
// block. If the MAC check passed but the padding check failed, the
// last byte of that block decrypted to the block size.
//
// See also macAndPaddingGood logic below.
paddingLen &= good
toRemove = int(paddingLen) + 1
return
}
func roundUp(a, b int) int {
return a + (b-a%b)%b
}
// cbcMode is an interface for block ciphers using cipher block chaining.
type cbcMode interface {
cipher.BlockMode
SetIV([]byte)
}
// decrypt authenticates and decrypts the record if protection is active at
// this stage. The returned plaintext might overlap with the input.
func (hc *halfConn) decrypt(record []byte) ([]byte, recordType, error) {
var plaintext []byte
typ := recordType(record[0])
payload := record[recordHeaderLen:]
// In TLS 1.3, change_cipher_spec messages are to be ignored without being
// decrypted. See RFC 8446, Appendix D.4.
if hc.version == VersionTLS13 && typ == recordTypeChangeCipherSpec {
return payload, typ, nil
}
paddingGood := byte(255)
paddingLen := 0
explicitNonceLen := hc.explicitNonceLen()
if hc.cipher != nil {
switch c := hc.cipher.(type) {
case cipher.Stream:
c.XORKeyStream(payload, payload)
case aead:
if len(payload) < explicitNonceLen {
return nil, 0, alertBadRecordMAC
}
nonce := payload[:explicitNonceLen]
if len(nonce) == 0 {
nonce = hc.seq[:]
}
payload = payload[explicitNonceLen:]
var additionalData []byte
if hc.version == VersionTLS13 {
additionalData = record[:recordHeaderLen]
} else {
additionalData = append(hc.scratchBuf[:0], hc.seq[:]...)
additionalData = append(additionalData, record[:3]...)
n := len(payload) - c.Overhead()
additionalData = append(additionalData, byte(n>>8), byte(n))
}
var err error
plaintext, err = c.Open(payload[:0], nonce, payload, additionalData)
if err != nil {
return nil, 0, alertBadRecordMAC
}
case cbcMode:
blockSize := c.BlockSize()
minPayload := explicitNonceLen + roundUp(hc.mac.Size()+1, blockSize)
if len(payload)%blockSize != 0 || len(payload) < minPayload {
return nil, 0, alertBadRecordMAC
}
if explicitNonceLen > 0 {
c.SetIV(payload[:explicitNonceLen])
payload = payload[explicitNonceLen:]
}
c.CryptBlocks(payload, payload)
// In a limited attempt to protect against CBC padding oracles like
// Lucky13, the data past paddingLen (which is secret) is passed to
// the MAC function as extra data, to be fed into the HMAC after
// computing the digest. This makes the MAC roughly constant time as
// long as the digest computation is constant time and does not
// affect the subsequent write, modulo cache effects.
paddingLen, paddingGood = extractPadding(payload)
default:
panic("unknown cipher type")
}
if hc.version == VersionTLS13 {
if typ != recordTypeApplicationData {
return nil, 0, alertUnexpectedMessage
}
if len(plaintext) > maxPlaintext+1 {
return nil, 0, alertRecordOverflow
}
// Remove padding and find the ContentType scanning from the end.
for i := len(plaintext) - 1; i >= 0; i-- {
if plaintext[i] != 0 {
typ = recordType(plaintext[i])
plaintext = plaintext[:i]
break
}
if i == 0 {
return nil, 0, alertUnexpectedMessage
}
}
}
} else {
plaintext = payload
}
if hc.mac != nil {
macSize := hc.mac.Size()
if len(payload) < macSize {
return nil, 0, alertBadRecordMAC
}
n := len(payload) - macSize - paddingLen
n = subtle.ConstantTimeSelect(int(uint32(n)>>31), 0, n) // if n < 0 { n = 0 }
record[3] = byte(n >> 8)
record[4] = byte(n)
remoteMAC := payload[n : n+macSize]
localMAC := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload[:n], payload[n+macSize:])
// This is equivalent to checking the MACs and paddingGood
// separately, but in constant-time to prevent distinguishing
// padding failures from MAC failures. Depending on what value
// of paddingLen was returned on bad padding, distinguishing
// bad MAC from bad padding can lead to an attack.
//
// See also the logic at the end of extractPadding.
macAndPaddingGood := subtle.ConstantTimeCompare(localMAC, remoteMAC) & int(paddingGood)
if macAndPaddingGood != 1 {
return nil, 0, alertBadRecordMAC
}
plaintext = payload[:n]
}
hc.incSeq()
return plaintext, typ, nil
}
// sliceForAppend extends the input slice by n bytes. head is the full extended
// slice, while tail is the appended part. If the original slice has sufficient
// capacity no allocation is performed.
func sliceForAppend(in []byte, n int) (head, tail []byte) {
if total := len(in) + n; cap(in) >= total {
head = in[:total]
} else {
head = make([]byte, total)
copy(head, in)
}
tail = head[len(in):]
return
}
// encrypt encrypts payload, adding the appropriate nonce and/or MAC, and
// appends it to record, which must already contain the record header.
func (hc *halfConn) encrypt(record, payload []byte, rand io.Reader) ([]byte, error) {
if hc.cipher == nil {
return append(record, payload...), nil
}
var explicitNonce []byte
if explicitNonceLen := hc.explicitNonceLen(); explicitNonceLen > 0 {
record, explicitNonce = sliceForAppend(record, explicitNonceLen)
if _, isCBC := hc.cipher.(cbcMode); !isCBC && explicitNonceLen < 16 {
// The AES-GCM construction in TLS has an explicit nonce so that the
// nonce can be random. However, the nonce is only 8 bytes which is
// too small for a secure, random nonce. Therefore we use the
// sequence number as the nonce. The 3DES-CBC construction also has
// an 8-byte nonce, but its nonces must be unpredictable (see RFC
// 5246, Appendix F.3), forcing us to use randomness. That's not
// 3DES' biggest problem anyway because the birthday bound on block
// collision is reached first due to its similarly small block size
// (see the Sweet32 attack).
copy(explicitNonce, hc.seq[:])
} else {
if _, err := io.ReadFull(rand, explicitNonce); err != nil {
return nil, err
}
}
}
var dst []byte
switch c := hc.cipher.(type) {
case cipher.Stream:
mac := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload, nil)
record, dst = sliceForAppend(record, len(payload)+len(mac))
c.XORKeyStream(dst[:len(payload)], payload)
c.XORKeyStream(dst[len(payload):], mac)
case aead:
nonce := explicitNonce
if len(nonce) == 0 {
nonce = hc.seq[:]
}
if hc.version == VersionTLS13 {
record = append(record, payload...)
// Encrypt the actual ContentType and replace the plaintext one.
record = append(record, record[0])
record[0] = byte(recordTypeApplicationData)
n := len(payload) + 1 + c.Overhead()
record[3] = byte(n >> 8)
record[4] = byte(n)
record = c.Seal(record[:recordHeaderLen],
nonce, record[recordHeaderLen:], record[:recordHeaderLen])
} else {
additionalData := append(hc.scratchBuf[:0], hc.seq[:]...)
additionalData = append(additionalData, record[:recordHeaderLen]...)
record = c.Seal(record, nonce, payload, additionalData)
}
case cbcMode:
mac := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload, nil)
blockSize := c.BlockSize()
plaintextLen := len(payload) + len(mac)
paddingLen := blockSize - plaintextLen%blockSize
record, dst = sliceForAppend(record, plaintextLen+paddingLen)
copy(dst, payload)
copy(dst[len(payload):], mac)
for i := plaintextLen; i < len(dst); i++ {
dst[i] = byte(paddingLen - 1)
}
if len(explicitNonce) > 0 {
c.SetIV(explicitNonce)
}
c.CryptBlocks(dst, dst)
default:
panic("unknown cipher type")
}
// Update length to include nonce, MAC and any block padding needed.
n := len(record) - recordHeaderLen
record[3] = byte(n >> 8)
record[4] = byte(n)
hc.incSeq()
return record, nil
}
// RecordHeaderError is returned when a TLS record header is invalid.
type RecordHeaderError struct {
// Msg contains a human readable string that describes the error.
Msg string
// RecordHeader contains the five bytes of TLS record header that
// triggered the error.
RecordHeader [5]byte
// Conn provides the underlying net.Conn in the case that a client
// sent an initial handshake that didn't look like TLS.
// It is nil if there's already been a handshake or a TLS alert has
// been written to the connection.
Conn net.Conn
}
func (e RecordHeaderError) Error() string { return "tls: " + e.Msg }
func (c *Conn) newRecordHeaderError(conn net.Conn, msg string) (err RecordHeaderError) {
err.Msg = msg
err.Conn = conn
copy(err.RecordHeader[:], c.rawInput.Bytes())
return err
}
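// Illustrative sketch (not part of the original source): servers sometimes
// inspect RecordHeaderError after a failed handshake to detect clients that
// sent plaintext (for example HTTP) to a TLS port, since Conn then still
// carries the underlying connection for a best-effort reply. tlsConn is
// hypothetical.
//
//	var rhe tls.RecordHeaderError
//	if err := tlsConn.Handshake(); errors.As(err, &rhe) && rhe.Conn != nil {
//		io.WriteString(rhe.Conn, "HTTP/1.0 400 Bad Request\r\n\r\nplain HTTP request sent to TLS port\r\n")
//		rhe.Conn.Close()
//	}
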
func (c *Conn) readRecord() error {
return c.readRecordOrCCS(false)
}
func (c *Conn) readChangeCipherSpec() error {
return c.readRecordOrCCS(true)
}
// readRecordOrCCS reads one or more TLS records from the connection and
// updates the record layer state. Some invariants:
// - c.in must be locked
// - c.input must be empty
//
// During the handshake one and only one of the following will happen:
// - c.hand grows
// - c.in.changeCipherSpec is called
// - an error is returned
//
// After the handshake one and only one of the following will happen:
// - c.hand grows
// - c.input is set
// - an error is returned
func (c *Conn) readRecordOrCCS(expectChangeCipherSpec bool) error {
if c.in.err != nil {
return c.in.err
}
handshakeComplete := c.isHandshakeComplete.Load()
// This function modifies c.rawInput, which owns the c.input memory.
if c.input.Len() != 0 {
return c.in.setErrorLocked(errors.New("tls: internal error: attempted to read record with pending application data"))
}
c.input.Reset(nil)
if c.quic != nil {
return c.in.setErrorLocked(errors.New("tls: internal error: attempted to read record with QUIC transport"))
}
// Read header, payload.
if err := c.readFromUntil(c.conn, recordHeaderLen); err != nil {
// RFC 8446, Section 6.1 suggests that EOF without an alertCloseNotify
// is an error, but popular web sites seem to do this, so we accept it
// if and only if at the record boundary.
if err == io.ErrUnexpectedEOF && c.rawInput.Len() == 0 {
err = io.EOF
}
if e, ok := err.(net.Error); !ok || !e.Temporary() {
c.in.setErrorLocked(err)
}
return err
}
hdr := c.rawInput.Bytes()[:recordHeaderLen]
typ := recordType(hdr[0])
// No valid TLS record has a type of 0x80, however SSLv2 handshakes
// start with a uint16 length where the MSB is set and the first record
// is always < 256 bytes long. Therefore typ == 0x80 strongly suggests
// an SSLv2 client.
if !handshakeComplete && typ == 0x80 {
c.sendAlert(alertProtocolVersion)
return c.in.setErrorLocked(c.newRecordHeaderError(nil, "unsupported SSLv2 handshake received"))
}
vers := uint16(hdr[1])<<8 | uint16(hdr[2])
expectedVers := c.vers
if expectedVers == VersionTLS13 {
// All TLS 1.3 records are expected to have 0x0303 (1.2) after
// the initial hello (RFC 8446 Section 5.1).
expectedVers = VersionTLS12
}
n := int(hdr[3])<<8 | int(hdr[4])
if c.haveVers && vers != expectedVers {
c.sendAlert(alertProtocolVersion)
msg := fmt.Sprintf("received record with version %x when expecting version %x", vers, expectedVers)
return c.in.setErrorLocked(c.newRecordHeaderError(nil, msg))
}
if !c.haveVers {
// First message, be extra suspicious: this might not be a TLS
// client. Bail out before reading a full 'body', if possible.
// The current max version is 3.3 so if the version is >= 16.0,
// it's probably not real.
if (typ != recordTypeAlert && typ != recordTypeHandshake) || vers >= 0x1000 {
return c.in.setErrorLocked(c.newRecordHeaderError(c.conn, "first record does not look like a TLS handshake"))
}
}
if c.vers == VersionTLS13 && n > maxCiphertextTLS13 || n > maxCiphertext {
c.sendAlert(alertRecordOverflow)
msg := fmt.Sprintf("oversized record received with length %d", n)
return c.in.setErrorLocked(c.newRecordHeaderError(nil, msg))
}
if err := c.readFromUntil(c.conn, recordHeaderLen+n); err != nil {
if e, ok := err.(net.Error); !ok || !e.Temporary() {
c.in.setErrorLocked(err)
}
return err
}
// Process message.
record := c.rawInput.Next(recordHeaderLen + n)
data, typ, err := c.in.decrypt(record)
if err != nil {
return c.in.setErrorLocked(c.sendAlert(err.(alert)))
}
if len(data) > maxPlaintext {
return c.in.setErrorLocked(c.sendAlert(alertRecordOverflow))
}
// Application Data messages are always protected.
if c.in.cipher == nil && typ == recordTypeApplicationData {
return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
}
if typ != recordTypeAlert && typ != recordTypeChangeCipherSpec && len(data) > 0 {
// This is a state-advancing message: reset the retry count.
c.retryCount = 0
}
// Handshake messages MUST NOT be interleaved with other record types in TLS 1.3.
if c.vers == VersionTLS13 && typ != recordTypeHandshake && c.hand.Len() > 0 {
return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
}
switch typ {
default:
return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
case recordTypeAlert:
if c.quic != nil {
return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
}
if len(data) != 2 {
return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
}
if alert(data[1]) == alertCloseNotify {
return c.in.setErrorLocked(io.EOF)
}
if c.vers == VersionTLS13 {
// TLS 1.3 removed warning-level alerts except for alertUserCanceled
// (RFC 8446, § 6.1). Since at least one major implementation
// (https://bugs.openjdk.org/browse/JDK-8323517) misuses this alert,
// many TLS stacks now ignore it outright when seen in a TLS 1.3
// handshake (e.g. BoringSSL, NSS, Rustls).
if alert(data[1]) == alertUserCanceled {
// Like TLS 1.2 alertLevelWarning alerts, we drop the record and retry.
return c.retryReadRecord(expectChangeCipherSpec)
}
return c.in.setErrorLocked(&net.OpError{Op: "remote error", Err: alert(data[1])})
}
switch data[0] {
case alertLevelWarning:
// Drop the record on the floor and retry.
return c.retryReadRecord(expectChangeCipherSpec)
case alertLevelError:
return c.in.setErrorLocked(&net.OpError{Op: "remote error", Err: alert(data[1])})
default:
return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
}
case recordTypeChangeCipherSpec:
if len(data) != 1 || data[0] != 1 {
return c.in.setErrorLocked(c.sendAlert(alertDecodeError))
}
// Handshake messages are not allowed to fragment across the CCS.
if c.hand.Len() > 0 {
return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
}
// In TLS 1.3, change_cipher_spec records are ignored until the
// Finished. See RFC 8446, Appendix D.4. Note that according to Section
// 5, a server can send a ChangeCipherSpec before its ServerHello, when
// c.vers is still unset. That's not useful though and suspicious if the
// server then selects a lower protocol version, so don't allow that.
if c.vers == VersionTLS13 {
return c.retryReadRecord(expectChangeCipherSpec)
}
if !expectChangeCipherSpec {
return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
}
if err := c.in.changeCipherSpec(); err != nil {
return c.in.setErrorLocked(c.sendAlert(err.(alert)))
}
case recordTypeApplicationData:
if !handshakeComplete || expectChangeCipherSpec {
return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
}
// Some OpenSSL servers send empty records in order to randomize the
// CBC IV. Ignore a limited number of empty records.
if len(data) == 0 {
return c.retryReadRecord(expectChangeCipherSpec)
}
// Note that data is owned by c.rawInput, following the Next call above,
// to avoid copying the plaintext. This is safe because c.rawInput is
// not read from or written to until c.input is drained.
c.input.Reset(data)
case recordTypeHandshake:
if len(data) == 0 || expectChangeCipherSpec {
return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
}
c.hand.Write(data)
}
return nil
}
// retryReadRecord recurs into readRecordOrCCS to drop a non-advancing record, like
// a warning alert, empty application_data, or a change_cipher_spec in TLS 1.3.
func (c *Conn) retryReadRecord(expectChangeCipherSpec bool) error {
c.retryCount++
if c.retryCount > maxUselessRecords {
c.sendAlert(alertUnexpectedMessage)
return c.in.setErrorLocked(errors.New("tls: too many ignored records"))
}
return c.readRecordOrCCS(expectChangeCipherSpec)
}
// atLeastReader reads from R, stopping with EOF once at least N bytes have been
// read. It is different from an io.LimitedReader in that it doesn't cut short
// the last Read call, and in that it considers an early EOF an error.
type atLeastReader struct {
R io.Reader
N int64
}
func (r *atLeastReader) Read(p []byte) (int, error) {
if r.N <= 0 {
return 0, io.EOF
}
n, err := r.R.Read(p)
r.N -= int64(n) // won't underflow unless len(p) >= n > 9223372036854775809
if r.N > 0 && err == io.EOF {
return n, io.ErrUnexpectedEOF
}
if r.N <= 0 && err == nil {
return n, io.EOF
}
return n, err
}
// readFromUntil reads from r into c.rawInput until c.rawInput contains
// at least n bytes or else returns an error.
func (c *Conn) readFromUntil(r io.Reader, n int) error {
if c.rawInput.Len() >= n {
return nil
}
needs := n - c.rawInput.Len()
// There might be extra input waiting on the wire. Make a best effort
// attempt to fetch it so that it can be used in (*Conn).Read to
// "predict" closeNotify alerts.
c.rawInput.Grow(needs + bytes.MinRead)
_, err := c.rawInput.ReadFrom(&atLeastReader{r, int64(needs)})
return err
}
// sendAlertLocked sends a TLS alert message.
func (c *Conn) sendAlertLocked(err alert) error {
if c.quic != nil {
return c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err})
}
switch err {
case alertNoRenegotiation, alertCloseNotify:
c.tmp[0] = alertLevelWarning
default:
c.tmp[0] = alertLevelError
}
c.tmp[1] = byte(err)
_, writeErr := c.writeRecordLocked(recordTypeAlert, c.tmp[0:2])
if err == alertCloseNotify {
// closeNotify is a special case in that it isn't an error.
return writeErr
}
return c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err})
}
// sendAlert sends a TLS alert message.
func (c *Conn) sendAlert(err alert) error {
c.out.Lock()
defer c.out.Unlock()
return c.sendAlertLocked(err)
}
const (
// tcpMSSEstimate is a conservative estimate of the TCP maximum segment
// size (MSS). A constant is used, rather than querying the kernel for
// the actual MSS, to avoid complexity. The value here is the IPv6
// minimum MTU (1280 bytes) minus the overhead of an IPv6 header (40
// bytes) and a TCP header with timestamps (32 bytes).
tcpMSSEstimate = 1208
// recordSizeBoostThreshold is the number of bytes of application data
// sent after which the TLS record size will be increased to the
// maximum.
recordSizeBoostThreshold = 128 * 1024
)
// maxPayloadSizeForWrite returns the maximum TLS payload size to use for the
// next application data record. There is the following trade-off:
//
// - For latency-sensitive applications, such as web browsing, each TLS
// record should fit in one TCP segment.
// - For throughput-sensitive applications, such as large file transfers,
// larger TLS records better amortize framing and encryption overheads.
//
// A simple heuristic that works well in practice is to use small records for
// the first 1MB of data, then use larger records for subsequent data, and
// reset back to smaller records after the connection becomes idle. See "High
// Performance Web Networking", Chapter 4, or:
// https://www.igvita.com/2013/10/24/optimizing-tls-record-size-and-buffering-latency/
//
// In the interests of simplicity and determinism, this code does not attempt
// to reset the record size once the connection is idle, however.
func (c *Conn) maxPayloadSizeForWrite(typ recordType) int {
if c.config.DynamicRecordSizingDisabled || typ != recordTypeApplicationData {
return maxPlaintext
}
if c.bytesSent >= recordSizeBoostThreshold {
return maxPlaintext
}
// Subtract TLS overheads to get the maximum payload size.
payloadBytes := tcpMSSEstimate - recordHeaderLen - c.out.explicitNonceLen()
if c.out.cipher != nil {
switch ciph := c.out.cipher.(type) {
case cipher.Stream:
payloadBytes -= c.out.mac.Size()
case cipher.AEAD:
payloadBytes -= ciph.Overhead()
case cbcMode:
blockSize := ciph.BlockSize()
// The payload must fit in a multiple of blockSize, with
// room for at least one padding byte.
payloadBytes = (payloadBytes & ^(blockSize - 1)) - 1
// The MAC is appended before padding, so it affects the
// payload size directly.
payloadBytes -= c.out.mac.Size()
default:
panic("unknown cipher type")
}
}
if c.vers == VersionTLS13 {
payloadBytes-- // encrypted ContentType
}
// Allow packet growth in arithmetic progression up to max.
pkt := c.packetsSent
c.packetsSent++
if pkt > 1000 {
return maxPlaintext // avoid overflow in multiply below
}
n := payloadBytes * int(pkt+1)
if n > maxPlaintext {
n = maxPlaintext
}
return n
}
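// Illustrative sketch (not part of the original source): throughput-oriented
// callers that always want full-size records can opt out of the ramp-up
// heuristic above.
//
//	cfg := &tls.Config{
//		DynamicRecordSizingDisabled: true, // always use maximum-size records
//	}
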
func (c *Conn) write(data []byte) (int, error) {
if c.buffering {
c.sendBuf = append(c.sendBuf, data...)
return len(data), nil
}
n, err := c.conn.Write(data)
c.bytesSent += int64(n)
return n, err
}
func (c *Conn) flush() (int, error) {
if len(c.sendBuf) == 0 {
return 0, nil
}
n, err := c.conn.Write(c.sendBuf)
c.bytesSent += int64(n)
c.sendBuf = nil
c.buffering = false
return n, err
}
// outBufPool pools the record-sized scratch buffers used by writeRecordLocked.
var outBufPool = sync.Pool{
New: func() any {
return new([]byte)
},
}
// writeRecordLocked writes a TLS record with the given type and payload to the
// connection and updates the record layer state.
func (c *Conn) writeRecordLocked(typ recordType, data []byte) (int, error) {
if c.quic != nil {
if typ != recordTypeHandshake {
return 0, errors.New("tls: internal error: sending non-handshake message to QUIC transport")
}
c.quicWriteCryptoData(c.out.level, data)
if !c.buffering {
if _, err := c.flush(); err != nil {
return 0, err
}
}
return len(data), nil
}
outBufPtr := outBufPool.Get().(*[]byte)
outBuf := *outBufPtr
defer func() {
// You might be tempted to simplify this by just passing &outBuf to Put,
// but that would make the local copy of the outBuf slice header escape
// to the heap, causing an allocation. Instead, we keep around the
// pointer to the slice header returned by Get, which is already on the
// heap, and overwrite and return that.
*outBufPtr = outBuf
outBufPool.Put(outBufPtr)
}()
var n int
for len(data) > 0 {
m := len(data)
if maxPayload := c.maxPayloadSizeForWrite(typ); m > maxPayload {
m = maxPayload
}
_, outBuf = sliceForAppend(outBuf[:0], recordHeaderLen)
outBuf[0] = byte(typ)
vers := c.vers
if vers == 0 {
// Some TLS servers fail if the record version is
// greater than TLS 1.0 for the initial ClientHello.
vers = VersionTLS10
} else if vers == VersionTLS13 {
// TLS 1.3 froze the record layer version to 1.2.
// See RFC 8446, Section 5.1.
vers = VersionTLS12
}
outBuf[1] = byte(vers >> 8)
outBuf[2] = byte(vers)
outBuf[3] = byte(m >> 8)
outBuf[4] = byte(m)
var err error
outBuf, err = c.out.encrypt(outBuf, data[:m], c.config.rand())
if err != nil {
return n, err
}
if _, err := c.write(outBuf); err != nil {
return n, err
}
n += m
data = data[m:]
}
if typ == recordTypeChangeCipherSpec && c.vers != VersionTLS13 {
if err := c.out.changeCipherSpec(); err != nil {
return n, c.sendAlertLocked(err.(alert))
}
}
return n, nil
}
// writeHandshakeRecord writes a handshake message to the connection and updates
// the record layer state. If transcript is non-nil the marshaled message is
// written to it.
func (c *Conn) writeHandshakeRecord(msg handshakeMessage, transcript transcriptHash) (int, error) {
c.out.Lock()
defer c.out.Unlock()
data, err := msg.marshal()
if err != nil {
return 0, err
}
if transcript != nil {
transcript.Write(data)
}
return c.writeRecordLocked(recordTypeHandshake, data)
}
// writeChangeCipherRecord writes a ChangeCipherSpec message to the connection and
// updates the record layer state.
func (c *Conn) writeChangeCipherRecord() error {
c.out.Lock()
defer c.out.Unlock()
_, err := c.writeRecordLocked(recordTypeChangeCipherSpec, []byte{1})
return err
}
// readHandshakeBytes reads handshake data until c.hand contains at least n bytes.
func (c *Conn) readHandshakeBytes(n int) error {
if c.quic != nil {
return c.quicReadHandshakeBytes(n)
}
for c.hand.Len() < n {
if err := c.readRecord(); err != nil {
return err
}
}
return nil
}
// readHandshake reads the next handshake message from
// the record layer. If transcript is non-nil, the message
// is written to the passed transcriptHash.
func (c *Conn) readHandshake(transcript transcriptHash) (any, error) {
if err := c.readHandshakeBytes(4); err != nil {
return nil, err
}
data := c.hand.Bytes()
maxHandshakeSize := maxHandshake
// hasVers indicates we're past the first message, forcing someone trying to
// make us just allocate a large buffer to at least do the initial part of
// the handshake first.
if c.haveVers && data[0] == typeCertificate {
// Since certificate messages are likely to be the only messages that
// can be larger than maxHandshake, we use a special limit for just
// those messages.
maxHandshakeSize = maxHandshakeCertificateMsg
}
n := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
if n > maxHandshakeSize {
c.sendAlertLocked(alertInternalError)
return nil, c.in.setErrorLocked(fmt.Errorf("tls: handshake message of length %d bytes exceeds maximum of %d bytes", n, maxHandshakeSize))
}
if err := c.readHandshakeBytes(4 + n); err != nil {
return nil, err
}
data = c.hand.Next(4 + n)
return c.unmarshalHandshakeMessage(data, transcript)
}
func (c *Conn) unmarshalHandshakeMessage(data []byte, transcript transcriptHash) (handshakeMessage, error) {
var m handshakeMessage
switch data[0] {
case typeHelloRequest:
m = new(helloRequestMsg)
case typeClientHello:
m = new(clientHelloMsg)
case typeServerHello:
m = new(serverHelloMsg)
case typeNewSessionTicket:
if c.vers == VersionTLS13 {
m = new(newSessionTicketMsgTLS13)
} else {
m = new(newSessionTicketMsg)
}
case typeCertificate:
if c.vers == VersionTLS13 {
m = new(certificateMsgTLS13)
} else {
m = new(certificateMsg)
}
case typeCertificateRequest:
if c.vers == VersionTLS13 {
m = new(certificateRequestMsgTLS13)
} else {
m = &certificateRequestMsg{
hasSignatureAlgorithm: c.vers >= VersionTLS12,
}
}
case typeCertificateStatus:
m = new(certificateStatusMsg)
case typeServerKeyExchange:
m = new(serverKeyExchangeMsg)
case typeServerHelloDone:
m = new(serverHelloDoneMsg)
case typeClientKeyExchange:
m = new(clientKeyExchangeMsg)
case typeCertificateVerify:
m = &certificateVerifyMsg{
hasSignatureAlgorithm: c.vers >= VersionTLS12,
}
case typeFinished:
m = new(finishedMsg)
case typeEncryptedExtensions:
m = new(encryptedExtensionsMsg)
case typeEndOfEarlyData:
m = new(endOfEarlyDataMsg)
case typeKeyUpdate:
m = new(keyUpdateMsg)
default:
return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
}
// The handshake message unmarshalers
// expect to be able to keep references to data,
// so pass in a fresh copy that won't be overwritten.
data = append([]byte(nil), data...)
if !m.unmarshal(data) {
return nil, c.in.setErrorLocked(c.sendAlert(alertDecodeError))
}
if transcript != nil {
transcript.Write(data)
}
return m, nil
}
var (
errShutdown = errors.New("tls: protocol is shutdown")
)
// Write writes data to the connection.
//
// As Write calls [Conn.Handshake], in order to prevent indefinite blocking a deadline
// must be set for both [Conn.Read] and Write before Write is called when the handshake
// has not yet completed. See [Conn.SetDeadline], [Conn.SetReadDeadline], and
// [Conn.SetWriteDeadline].
func (c *Conn) Write(b []byte) (int, error) {
// interlock with Close below
for {
x := c.activeCall.Load()
if x&1 != 0 {
return 0, net.ErrClosed
}
if c.activeCall.CompareAndSwap(x, x+2) {
break
}
}
defer c.activeCall.Add(-2)
if err := c.Handshake(); err != nil {
return 0, err
}
c.out.Lock()
defer c.out.Unlock()
if err := c.out.err; err != nil {
return 0, err
}
if !c.isHandshakeComplete.Load() {
return 0, alertInternalError
}
if c.closeNotifySent {
return 0, errShutdown
}
// TLS 1.0 is susceptible to a chosen-plaintext
// attack when using block mode ciphers due to predictable IVs.
// This can be prevented by splitting each Application Data
// record into two records, effectively randomizing the IV.
//
// https://www.openssl.org/~bodo/tls-cbc.txt
// https://bugzilla.mozilla.org/show_bug.cgi?id=665814
// https://www.imperialviolet.org/2012/01/15/beastfollowup.html
var m int
if len(b) > 1 && c.vers == VersionTLS10 {
if _, ok := c.out.cipher.(cipher.BlockMode); ok {
n, err := c.writeRecordLocked(recordTypeApplicationData, b[:1])
if err != nil {
return n, c.out.setErrorLocked(err)
}
m, b = 1, b[1:]
}
}
n, err := c.writeRecordLocked(recordTypeApplicationData, b)
return n + m, c.out.setErrorLocked(err)
}
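// Illustrative sketch (not part of the original source): because Write (and
// Read) may transparently run the handshake, callers that cannot block
// indefinitely set deadlines on the Conn first.
//
//	conn.SetDeadline(time.Now().Add(10 * time.Second))
//	if _, err := conn.Write([]byte("hello")); err != nil {
//		// handle the error (possibly a timeout during the handshake)
//	}
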
// handleRenegotiation processes a HelloRequest handshake message.
func (c *Conn) handleRenegotiation() error {
if c.vers == VersionTLS13 {
return errors.New("tls: internal error: unexpected renegotiation")
}
msg, err := c.readHandshake(nil)
if err != nil {
return err
}
helloReq, ok := msg.(*helloRequestMsg)
if !ok {
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(helloReq, msg)
}
if !c.isClient {
return c.sendAlert(alertNoRenegotiation)
}
switch c.config.Renegotiation {
case RenegotiateNever:
return c.sendAlert(alertNoRenegotiation)
case RenegotiateOnceAsClient:
if c.handshakes > 1 {
return c.sendAlert(alertNoRenegotiation)
}
case RenegotiateFreelyAsClient:
// Ok.
default:
c.sendAlert(alertInternalError)
return errors.New("tls: unknown Renegotiation value")
}
c.handshakeMutex.Lock()
defer c.handshakeMutex.Unlock()
c.isHandshakeComplete.Store(false)
if c.handshakeErr = c.clientHandshake(context.Background()); c.handshakeErr == nil {
c.handshakes++
}
return c.handshakeErr
}
// handlePostHandshakeMessage processes a handshake message that arrived after the
// handshake is complete. Up to TLS 1.2, it indicates the start of a renegotiation.
func (c *Conn) handlePostHandshakeMessage() error {
if c.vers != VersionTLS13 {
return c.handleRenegotiation()
}
msg, err := c.readHandshake(nil)
if err != nil {
return err
}
c.retryCount++
if c.retryCount > maxUselessRecords {
c.sendAlert(alertUnexpectedMessage)
return c.in.setErrorLocked(errors.New("tls: too many non-advancing records"))
}
switch msg := msg.(type) {
case *newSessionTicketMsgTLS13:
return c.handleNewSessionTicket(msg)
case *keyUpdateMsg:
return c.handleKeyUpdate(msg)
}
// The QUIC layer is supposed to treat an unexpected post-handshake CertificateRequest
// as a QUIC-level PROTOCOL_VIOLATION error (RFC 9001, Section 4.4). Returning an
// unexpected_message alert here doesn't provide it with enough information to distinguish
// this condition from other unexpected messages. This is probably fine.
c.sendAlert(alertUnexpectedMessage)
return fmt.Errorf("tls: received unexpected handshake message of type %T", msg)
}
func (c *Conn) handleKeyUpdate(keyUpdate *keyUpdateMsg) error {
if c.quic != nil {
c.sendAlert(alertUnexpectedMessage)
return c.in.setErrorLocked(errors.New("tls: received unexpected key update message"))
}
cipherSuite := cipherSuiteTLS13ByID(c.cipherSuite)
if cipherSuite == nil {
return c.in.setErrorLocked(c.sendAlert(alertInternalError))
}
newSecret := cipherSuite.nextTrafficSecret(c.in.trafficSecret)
c.in.setTrafficSecret(cipherSuite, QUICEncryptionLevelInitial, newSecret)
if keyUpdate.updateRequested {
c.out.Lock()
defer c.out.Unlock()
msg := &keyUpdateMsg{}
msgBytes, err := msg.marshal()
if err != nil {
return err
}
_, err = c.writeRecordLocked(recordTypeHandshake, msgBytes)
if err != nil {
// Surface the error at the next write.
c.out.setErrorLocked(err)
return nil
}
newSecret := cipherSuite.nextTrafficSecret(c.out.trafficSecret)
c.out.setTrafficSecret(cipherSuite, QUICEncryptionLevelInitial, newSecret)
}
return nil
}
// Read reads data from the connection.
//
// As Read calls [Conn.Handshake], in order to prevent indefinite blocking a deadline
// must be set for both Read and [Conn.Write] before Read is called when the handshake
// has not yet completed. See [Conn.SetDeadline], [Conn.SetReadDeadline], and
// [Conn.SetWriteDeadline].
func (c *Conn) Read(b []byte) (int, error) {
if err := c.Handshake(); err != nil {
return 0, err
}
if len(b) == 0 {
// Put this after Handshake, in case people were calling
// Read(nil) for the side effect of the Handshake.
return 0, nil
}
c.in.Lock()
defer c.in.Unlock()
for c.input.Len() == 0 {
if err := c.readRecord(); err != nil {
return 0, err
}
for c.hand.Len() > 0 {
if err := c.handlePostHandshakeMessage(); err != nil {
return 0, err
}
}
}
n, _ := c.input.Read(b)
// If a close-notify alert is waiting, read it so that we can return (n,
// EOF) instead of (n, nil), to signal to the HTTP response reading
// goroutine that the connection is now closed. This eliminates a race
// where the HTTP response reading goroutine would otherwise not observe
// the EOF until its next read, by which time a client goroutine might
// have already tried to reuse the HTTP connection for a new request.
// See https://golang.org/cl/76400046 and https://golang.org/issue/3514
if n != 0 && c.input.Len() == 0 && c.rawInput.Len() > 0 &&
recordType(c.rawInput.Bytes()[0]) == recordTypeAlert {
if err := c.readRecord(); err != nil {
return n, err // will be io.EOF on closeNotify
}
}
return n, nil
}
// Close closes the connection.
func (c *Conn) Close() error {
// Interlock with Conn.Write above.
var x int32
for {
x = c.activeCall.Load()
if x&1 != 0 {
return net.ErrClosed
}
if c.activeCall.CompareAndSwap(x, x|1) {
break
}
}
if x != 0 {
// io.Writer and io.Closer should not be used concurrently.
// If Close is called while a Write is currently in-flight,
// interpret that as a sign that this Close is really just
// being used to break the Write and/or clean up resources and
// avoid sending the alertCloseNotify, which may block
// waiting on handshakeMutex or the c.out mutex.
return c.conn.Close()
}
var alertErr error
if c.isHandshakeComplete.Load() {
if err := c.closeNotify(); err != nil {
alertErr = fmt.Errorf("tls: failed to send closeNotify alert (but connection was closed anyway): %w", err)
}
}
if err := c.conn.Close(); err != nil {
return err
}
return alertErr
}
var errEarlyCloseWrite = errors.New("tls: CloseWrite called before handshake complete")
// CloseWrite shuts down the writing side of the connection. It should only be
// called once the handshake has completed and does not call CloseWrite on the
// underlying connection. Most callers should just use [Conn.Close].
func (c *Conn) CloseWrite() error {
if !c.isHandshakeComplete.Load() {
return errEarlyCloseWrite
}
return c.closeNotify()
}
func (c *Conn) closeNotify() error {
c.out.Lock()
defer c.out.Unlock()
if !c.closeNotifySent {
// Set a Write Deadline to prevent possibly blocking forever.
c.SetWriteDeadline(time.Now().Add(time.Second * 5))
c.closeNotifyErr = c.sendAlertLocked(alertCloseNotify)
c.closeNotifySent = true
// Any subsequent writes will fail.
c.SetWriteDeadline(time.Now())
}
return c.closeNotifyErr
}
// Handshake runs the client or server handshake
// protocol if it has not yet been run.
//
// Most uses of this package need not call Handshake explicitly: the
// first [Conn.Read] or [Conn.Write] will call it automatically.
//
// For control over canceling or setting a timeout on a handshake, use
// [Conn.HandshakeContext] or the [Dialer]'s DialContext method instead.
//
// In order to avoid denial of service attacks, the maximum RSA key size allowed
// in certificates sent by either the TLS server or client is limited to 8192
// bits. This limit can be overridden by setting tlsmaxrsasize in the GODEBUG
// environment variable (e.g. GODEBUG=tlsmaxrsasize=4096).
func (c *Conn) Handshake() error {
return c.HandshakeContext(context.Background())
}
// HandshakeContext runs the client or server handshake
// protocol if it has not yet been run.
//
// The provided Context must be non-nil. If the context is canceled before
// the handshake is complete, the handshake is interrupted and an error is returned.
// Once the handshake has completed, cancellation of the context will not affect the
// connection.
//
// Most uses of this package need not call HandshakeContext explicitly: the
// first [Conn.Read] or [Conn.Write] will call it automatically.
func (c *Conn) HandshakeContext(ctx context.Context) error {
// Delegate to unexported method for named return
// without confusing documented signature.
return c.handshakeContext(ctx)
}
func (c *Conn) handshakeContext(ctx context.Context) (ret error) {
// Fast sync/atomic-based exit if there is no handshake in flight and the
// last one succeeded without an error. Avoids the expensive context setup
// and mutex for most Read and Write calls.
if c.isHandshakeComplete.Load() {
return nil
}
handshakeCtx, cancel := context.WithCancel(ctx)
// Note: defer this before starting the "interrupter" goroutine
// so that we can tell the difference between the input being canceled and
// this cancellation. In the former case, we need to close the connection.
defer cancel()
if c.quic != nil {
c.quic.cancelc = handshakeCtx.Done()
c.quic.cancel = cancel
} else if ctx.Done() != nil {
// Start the "interrupter" goroutine, if this context might be canceled.
// (The background context cannot).
//
// The interrupter goroutine waits for the input context to be done and
// closes the connection if this happens before the function returns.
done := make(chan struct{})
interruptRes := make(chan error, 1)
defer func() {
close(done)
if ctxErr := <-interruptRes; ctxErr != nil {
// Return context error to user.
ret = ctxErr
}
}()
go func() {
select {
case <-handshakeCtx.Done():
// Close the connection, discarding the error
_ = c.conn.Close()
interruptRes <- handshakeCtx.Err()
case <-done:
interruptRes <- nil
}
}()
}
c.handshakeMutex.Lock()
defer c.handshakeMutex.Unlock()
if err := c.handshakeErr; err != nil {
return err
}
if c.isHandshakeComplete.Load() {
return nil
}
c.in.Lock()
defer c.in.Unlock()
c.handshakeErr = c.handshakeFn(handshakeCtx)
if c.handshakeErr == nil {
c.handshakes++
} else {
// If an error occurred during the handshake try to flush the
// alert that might be left in the buffer.
c.flush()
}
if c.handshakeErr == nil && !c.isHandshakeComplete.Load() {
c.handshakeErr = errors.New("tls: internal error: handshake should have had a result")
}
if c.handshakeErr != nil && c.isHandshakeComplete.Load() {
panic("tls: internal error: handshake returned an error but is marked successful")
}
if c.quic != nil {
if c.handshakeErr == nil {
c.quicHandshakeComplete()
// Provide the 1-RTT read secret now that the handshake is complete.
// The QUIC layer MUST NOT decrypt 1-RTT packets prior to completing
// the handshake (RFC 9001, Section 5.7).
c.quicSetReadSecret(QUICEncryptionLevelApplication, c.cipherSuite, c.in.trafficSecret)
} else {
var a alert
c.out.Lock()
if !errors.As(c.out.err, &a) {
a = alertInternalError
}
c.out.Unlock()
// Return an error which wraps both the handshake error and
// any alert error we may have sent, or alertInternalError
// if we didn't send an alert.
// Truncate the text of the alert to 0 characters.
c.handshakeErr = fmt.Errorf("%w%.0w", c.handshakeErr, AlertError(a))
}
close(c.quic.blockedc)
close(c.quic.signalc)
}
return c.handshakeErr
}
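// A minimal sketch of bounding the handshake with a context (conn is a
// hypothetical *Conn and the 5-second timeout is illustrative); per the
// interrupter logic above, cancellation mid-handshake closes the underlying
// connection, while cancellation after completion has no effect:
//
//	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
//	defer cancel()
//	if err := conn.HandshakeContext(ctx); err != nil {
//		// ctx.Err() is returned if the context expired first.
//	}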
// ConnectionState returns basic TLS details about the connection.
func (c *Conn) ConnectionState() ConnectionState {
c.handshakeMutex.Lock()
defer c.handshakeMutex.Unlock()
return c.connectionStateLocked()
}
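// A minimal sketch of inspecting the negotiated parameters once the handshake
// is complete (conn is a hypothetical *Conn):
//
//	state := conn.ConnectionState()
//	fmt.Printf("version=%x suite=%x alpn=%q resumed=%v\n",
//		state.Version, state.CipherSuite, state.NegotiatedProtocol, state.DidResume)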
var tlsunsafeekm = godebug.New("tlsunsafeekm")
func (c *Conn) connectionStateLocked() ConnectionState {
var state ConnectionState
state.HandshakeComplete = c.isHandshakeComplete.Load()
state.Version = c.vers
state.NegotiatedProtocol = c.clientProtocol
state.DidResume = c.didResume
state.testingOnlyDidHRR = c.didHRR
state.testingOnlyPeerSignatureAlgorithm = c.peerSigAlg
state.CurveID = c.curveID
state.NegotiatedProtocolIsMutual = true
state.ServerName = c.serverName
state.CipherSuite = c.cipherSuite
state.PeerCertificates = c.peerCertificates
state.VerifiedChains = c.verifiedChains
state.SignedCertificateTimestamps = c.scts
state.OCSPResponse = c.ocspResponse
if (!c.didResume || c.extMasterSecret) && c.vers != VersionTLS13 {
if c.clientFinishedIsFirst {
state.TLSUnique = c.clientFinished[:]
} else {
state.TLSUnique = c.serverFinished[:]
}
}
if c.config.Renegotiation != RenegotiateNever {
state.ekm = noEKMBecauseRenegotiation
} else if c.vers != VersionTLS13 && !c.extMasterSecret {
state.ekm = func(label string, context []byte, length int) ([]byte, error) {
if tlsunsafeekm.Value() == "1" {
tlsunsafeekm.IncNonDefault()
return c.ekm(label, context, length)
}
return noEKMBecauseNoEMS(label, context, length)
}
} else {
state.ekm = c.ekm
}
state.ECHAccepted = c.echAccepted
return state
}
// OCSPResponse returns the stapled OCSP response from the TLS server, if
// any. (Only valid for client connections.)
func (c *Conn) OCSPResponse() []byte {
c.handshakeMutex.Lock()
defer c.handshakeMutex.Unlock()
return c.ocspResponse
}
// VerifyHostname checks that the peer certificate chain is valid for
// connecting to host. If so, it returns nil; if not, it returns an error
// describing the problem.
func (c *Conn) VerifyHostname(host string) error {
c.handshakeMutex.Lock()
defer c.handshakeMutex.Unlock()
if !c.isClient {
return errors.New("tls: VerifyHostname called on TLS server connection")
}
if !c.isHandshakeComplete.Load() {
return errors.New("tls: handshake has not yet been performed")
}
if len(c.verifiedChains) == 0 {
return errors.New("tls: handshake did not verify certificate chain")
}
return c.peerCertificates[0].VerifyHostname(host)
}
// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tls
import (
"internal/godebug"
"slices"
_ "unsafe" // for linkname
)
// Defaults are collected in this file to allow distributions to more easily patch
// them to apply local policies.
var tlsmlkem = godebug.New("tlsmlkem")
// defaultCurvePreferences is the default set of supported key exchanges, as
// well as the preference order.
func defaultCurvePreferences() []CurveID {
if tlsmlkem.Value() == "0" {
return []CurveID{X25519, CurveP256, CurveP384, CurveP521}
}
return []CurveID{X25519MLKEM768, X25519, CurveP256, CurveP384, CurveP521}
}
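// A minimal sketch, not a package default, of how a caller could pin key
// exchange groups rather than rely on defaultCurvePreferences; the CurveIDs
// are the same ones listed above:
//
//	cfg := &Config{
//		CurvePreferences: []CurveID{X25519MLKEM768, X25519, CurveP256},
//	}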
// defaultSupportedSignatureAlgorithms returns the signature and hash algorithms that
// the code advertises and supports in a TLS 1.2+ ClientHello and in a TLS 1.2+
// CertificateRequest. The two fields are merged to match with TLS 1.3.
// Note that in TLS 1.2, the ECDSA algorithms are not constrained to P-256, etc.
func defaultSupportedSignatureAlgorithms() []SignatureScheme {
return []SignatureScheme{
PSSWithSHA256,
ECDSAWithP256AndSHA256,
Ed25519,
PSSWithSHA384,
PSSWithSHA512,
PKCS1WithSHA256,
PKCS1WithSHA384,
PKCS1WithSHA512,
ECDSAWithP384AndSHA384,
ECDSAWithP521AndSHA512,
PKCS1WithSHA1,
ECDSAWithSHA1,
}
}
var tlsrsakex = godebug.New("tlsrsakex")
var tls3des = godebug.New("tls3des")
func supportedCipherSuites(aesGCMPreferred bool) []uint16 {
if aesGCMPreferred {
return slices.Clone(cipherSuitesPreferenceOrder)
} else {
return slices.Clone(cipherSuitesPreferenceOrderNoAES)
}
}
func defaultCipherSuites(aesGCMPreferred bool) []uint16 {
cipherSuites := supportedCipherSuites(aesGCMPreferred)
return slices.DeleteFunc(cipherSuites, func(c uint16) bool {
return disabledCipherSuites[c] ||
tlsrsakex.Value() != "1" && rsaKexCiphers[c] ||
tls3des.Value() != "1" && tdesCiphers[c]
})
}
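// As the filter above shows, RSA key exchange and 3DES suites stay disabled
// unless the corresponding GODEBUG settings are "1". A hedged illustration of
// re-enabling them for one process (shown for completeness, not as a
// recommendation; the binary name is hypothetical):
//
//	GODEBUG=tlsrsakex=1,tls3des=1 ./myserver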
// defaultCipherSuitesTLS13 is also the preference order, since there are no
// disabled-by-default TLS 1.3 cipher suites. The same AES vs ChaCha20 logic as
// cipherSuitesPreferenceOrder applies.
//
// defaultCipherSuitesTLS13 should be an internal detail,
// but widely used packages access it using linkname.
// Notable members of the hall of shame include:
// - github.com/quic-go/quic-go
// - github.com/sagernet/quic-go
//
// Do not remove or change the type signature.
// See go.dev/issue/67401.
//
//go:linkname defaultCipherSuitesTLS13
var defaultCipherSuitesTLS13 = []uint16{
TLS_AES_128_GCM_SHA256,
TLS_AES_256_GCM_SHA384,
TLS_CHACHA20_POLY1305_SHA256,
}
// defaultCipherSuitesTLS13NoAES should be an internal detail,
// but widely used packages access it using linkname.
// Notable members of the hall of shame include:
// - github.com/quic-go/quic-go
// - github.com/sagernet/quic-go
//
// Do not remove or change the type signature.
// See go.dev/issue/67401.
//
//go:linkname defaultCipherSuitesTLS13NoAES
var defaultCipherSuitesTLS13NoAES = []uint16{
TLS_CHACHA20_POLY1305_SHA256,
TLS_AES_128_GCM_SHA256,
TLS_AES_256_GCM_SHA384,
}
// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !boringcrypto
package tls
import (
"crypto/ecdsa"
"crypto/ed25519"
"crypto/elliptic"
"crypto/rsa"
"crypto/x509"
)
// These FIPS 140-3 policies allow anything approved by SP 800-140C
// and SP 800-140D, and tested as part of the Go Cryptographic Module.
//
// Notably, not SHA-1, 3DES, RC4, ChaCha20Poly1305, RSA PKCS #1 v1.5 key
// transport, or TLS 1.0-1.1 (because we don't test their KDF).
//
// These are not default lists, but filters to apply to the default or
// configured lists. Missing items are treated as if they were not implemented.
//
// They are applied when the fips140 GODEBUG is "on" or "only".
var (
allowedSupportedVersionsFIPS = []uint16{
VersionTLS12,
VersionTLS13,
}
allowedCurvePreferencesFIPS = []CurveID{
X25519MLKEM768,
CurveP256,
CurveP384,
CurveP521,
}
allowedSignatureAlgorithmsFIPS = []SignatureScheme{
PSSWithSHA256,
ECDSAWithP256AndSHA256,
Ed25519,
PSSWithSHA384,
PSSWithSHA512,
PKCS1WithSHA256,
PKCS1WithSHA384,
PKCS1WithSHA512,
ECDSAWithP384AndSHA384,
ECDSAWithP521AndSHA512,
}
allowedCipherSuitesFIPS = []uint16{
TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256,
TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256,
}
allowedCipherSuitesTLS13FIPS = []uint16{
TLS_AES_128_GCM_SHA256,
TLS_AES_256_GCM_SHA384,
}
)
func isCertificateAllowedFIPS(c *x509.Certificate) bool {
switch k := c.PublicKey.(type) {
case *rsa.PublicKey:
return k.N.BitLen() >= 2048
case *ecdsa.PublicKey:
return k.Curve == elliptic.P256() || k.Curve == elliptic.P384() || k.Curve == elliptic.P521()
case ed25519.PublicKey:
return true
default:
return false
}
}
// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tls
import (
"bytes"
"crypto/internal/hpke"
"errors"
"fmt"
"slices"
"strings"
"golang.org/x/crypto/cryptobyte"
)
// sortedSupportedAEADs is just a sorted version of hpke.SupportedAEADs.
// We need this so that when we insert them into ECHConfigs the ordering
// is stable.
var sortedSupportedAEADs []uint16
func init() {
for aeadID := range hpke.SupportedAEADs {
sortedSupportedAEADs = append(sortedSupportedAEADs, aeadID)
}
slices.Sort(sortedSupportedAEADs)
}
type echCipher struct {
KDFID uint16
AEADID uint16
}
type echExtension struct {
Type uint16
Data []byte
}
type echConfig struct {
raw []byte
Version uint16
Length uint16
ConfigID uint8
KemID uint16
PublicKey []byte
SymmetricCipherSuite []echCipher
MaxNameLength uint8
PublicName []byte
Extensions []echExtension
}
var errMalformedECHConfigList = errors.New("tls: malformed ECHConfigList")
type echConfigErr struct {
field string
}
func (e *echConfigErr) Error() string {
if e.field == "" {
return "tls: malformed ECHConfig"
}
return fmt.Sprintf("tls: malformed ECHConfig, invalid %s field", e.field)
}
func parseECHConfig(enc []byte) (skip bool, ec echConfig, err error) {
s := cryptobyte.String(enc)
ec.raw = []byte(enc)
if !s.ReadUint16(&ec.Version) {
return false, echConfig{}, &echConfigErr{"version"}
}
if !s.ReadUint16(&ec.Length) {
return false, echConfig{}, &echConfigErr{"length"}
}
if len(ec.raw) < int(ec.Length)+4 {
return false, echConfig{}, &echConfigErr{"length"}
}
ec.raw = ec.raw[:ec.Length+4]
if ec.Version != extensionEncryptedClientHello {
s.Skip(int(ec.Length))
return true, echConfig{}, nil
}
if !s.ReadUint8(&ec.ConfigID) {
return false, echConfig{}, &echConfigErr{"config_id"}
}
if !s.ReadUint16(&ec.KemID) {
return false, echConfig{}, &echConfigErr{"kem_id"}
}
if !readUint16LengthPrefixed(&s, &ec.PublicKey) {
return false, echConfig{}, &echConfigErr{"public_key"}
}
var cipherSuites cryptobyte.String
if !s.ReadUint16LengthPrefixed(&cipherSuites) {
return false, echConfig{}, &echConfigErr{"cipher_suites"}
}
for !cipherSuites.Empty() {
var c echCipher
if !cipherSuites.ReadUint16(&c.KDFID) {
return false, echConfig{}, &echConfigErr{"cipher_suites kdf_id"}
}
if !cipherSuites.ReadUint16(&c.AEADID) {
return false, echConfig{}, &echConfigErr{"cipher_suites aead_id"}
}
ec.SymmetricCipherSuite = append(ec.SymmetricCipherSuite, c)
}
if !s.ReadUint8(&ec.MaxNameLength) {
return false, echConfig{}, &echConfigErr{"maximum_name_length"}
}
var publicName cryptobyte.String
if !s.ReadUint8LengthPrefixed(&publicName) {
return false, echConfig{}, &echConfigErr{"public_name"}
}
ec.PublicName = publicName
var extensions cryptobyte.String
if !s.ReadUint16LengthPrefixed(&extensions) {
return false, echConfig{}, &echConfigErr{"extensions"}
}
for !extensions.Empty() {
var e echExtension
if !extensions.ReadUint16(&e.Type) {
return false, echConfig{}, &echConfigErr{"extensions type"}
}
if !extensions.ReadUint16LengthPrefixed((*cryptobyte.String)(&e.Data)) {
return false, echConfig{}, &echConfigErr{"extensions data"}
}
ec.Extensions = append(ec.Extensions, e)
}
return false, ec, nil
}
// parseECHConfigList parses a draft-ietf-tls-esni-18 ECHConfigList, returning a
// slice of parsed ECHConfigs, in the same order they were parsed, or an error
// if the list is malformed.
func parseECHConfigList(data []byte) ([]echConfig, error) {
s := cryptobyte.String(data)
var length uint16
if !s.ReadUint16(&length) {
return nil, errMalformedECHConfigList
}
if length != uint16(len(data)-2) {
return nil, errMalformedECHConfigList
}
var configs []echConfig
for len(s) > 0 {
if len(s) < 4 {
return nil, errors.New("tls: malformed ECHConfig")
}
configLen := uint16(s[2])<<8 | uint16(s[3])
skip, ec, err := parseECHConfig(s)
if err != nil {
return nil, err
}
s = s[configLen+4:]
if !skip {
configs = append(configs, ec)
}
}
return configs, nil
}
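// For reference, the framing the parser above expects (this mirrors the reads
// performed in parseECHConfigList and parseECHConfig, not an independent
// restatement of the specification):
//
//	uint16 total length of the list (must equal len(data)-2)
//	repeated ECHConfig entries, each:
//		uint16 version
//		uint16 contents length
//		contents (that many bytes)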
func pickECHConfig(list []echConfig) *echConfig {
for _, ec := range list {
if _, ok := hpke.SupportedKEMs[ec.KemID]; !ok {
continue
}
var validSCS bool
for _, cs := range ec.SymmetricCipherSuite {
if _, ok := hpke.SupportedAEADs[cs.AEADID]; !ok {
continue
}
if _, ok := hpke.SupportedKDFs[cs.KDFID]; !ok {
continue
}
validSCS = true
break
}
if !validSCS {
continue
}
if !validDNSName(string(ec.PublicName)) {
continue
}
var unsupportedExt bool
for _, ext := range ec.Extensions {
// If high order bit is set to 1 the extension is mandatory.
// Since we don't support any extensions, if we see a mandatory
// bit, we skip the config.
if ext.Type&uint16(1<<15) != 0 {
unsupportedExt = true
}
}
if unsupportedExt {
continue
}
return &ec
}
return nil
}
func pickECHCipherSuite(suites []echCipher) (echCipher, error) {
for _, s := range suites {
// NOTE: all of the supported AEADs and KDFs are fine; rather than
// imposing some sort of preference here, we just pick the first valid
// suite.
if _, ok := hpke.SupportedAEADs[s.AEADID]; !ok {
continue
}
if _, ok := hpke.SupportedKDFs[s.KDFID]; !ok {
continue
}
return s, nil
}
return echCipher{}, errors.New("tls: no supported symmetric ciphersuites for ECH")
}
func encodeInnerClientHello(inner *clientHelloMsg, maxNameLength int) ([]byte, error) {
h, err := inner.marshalMsg(true)
if err != nil {
return nil, err
}
h = h[4:] // strip four byte prefix
var paddingLen int
if inner.serverName != "" {
paddingLen = max(0, maxNameLength-len(inner.serverName))
} else {
paddingLen = maxNameLength + 9
}
paddingLen = 31 - ((len(h) + paddingLen - 1) % 32)
return append(h, make([]byte, paddingLen)...), nil
}
func skipUint8LengthPrefixed(s *cryptobyte.String) bool {
var skip uint8
if !s.ReadUint8(&skip) {
return false
}
return s.Skip(int(skip))
}
func skipUint16LengthPrefixed(s *cryptobyte.String) bool {
var skip uint16
if !s.ReadUint16(&skip) {
return false
}
return s.Skip(int(skip))
}
type rawExtension struct {
extType uint16
data []byte
}
func extractRawExtensions(hello *clientHelloMsg) ([]rawExtension, error) {
s := cryptobyte.String(hello.original)
if !s.Skip(4+2+32) || // header, version, random
!skipUint8LengthPrefixed(&s) || // session ID
!skipUint16LengthPrefixed(&s) || // cipher suites
!skipUint8LengthPrefixed(&s) { // compression methods
return nil, errors.New("tls: malformed outer client hello")
}
var rawExtensions []rawExtension
var extensions cryptobyte.String
if !s.ReadUint16LengthPrefixed(&extensions) {
return nil, errors.New("tls: malformed outer client hello")
}
for !extensions.Empty() {
var extension uint16
var extData cryptobyte.String
if !extensions.ReadUint16(&extension) ||
!extensions.ReadUint16LengthPrefixed(&extData) {
return nil, errors.New("tls: invalid inner client hello")
}
rawExtensions = append(rawExtensions, rawExtension{extension, extData})
}
return rawExtensions, nil
}
func decodeInnerClientHello(outer *clientHelloMsg, encoded []byte) (*clientHelloMsg, error) {
// Reconstructing the inner client hello from its encoded form is somewhat
// complicated. It is missing its header (message type and length), session
// ID, and the extensions may be compressed. Since we need to put the
// extensions back in the same order as they were in the raw outer hello,
// and since we don't store the raw extensions, or the order we parsed them
// in, we need to reparse the raw extensions from the outer hello in order
// to properly insert them into the inner hello. This _should_ result in raw
// bytes which match the hello as it was generated by the client.
innerReader := cryptobyte.String(encoded)
var versionAndRandom, sessionID, cipherSuites, compressionMethods []byte
var extensions cryptobyte.String
if !innerReader.ReadBytes(&versionAndRandom, 2+32) ||
!readUint8LengthPrefixed(&innerReader, &sessionID) ||
len(sessionID) != 0 ||
!readUint16LengthPrefixed(&innerReader, &cipherSuites) ||
!readUint8LengthPrefixed(&innerReader, &compressionMethods) ||
!innerReader.ReadUint16LengthPrefixed(&extensions) {
return nil, errors.New("tls: invalid inner client hello")
}
// The specification says we must verify that the trailing padding is all
// zeros. This is kind of weird for TLS messages, where we generally just
// throw away any trailing garbage.
for _, p := range innerReader {
if p != 0 {
return nil, errors.New("tls: invalid inner client hello")
}
}
rawOuterExts, err := extractRawExtensions(outer)
if err != nil {
return nil, err
}
recon := cryptobyte.NewBuilder(nil)
recon.AddUint8(typeClientHello)
recon.AddUint24LengthPrefixed(func(recon *cryptobyte.Builder) {
recon.AddBytes(versionAndRandom)
recon.AddUint8LengthPrefixed(func(recon *cryptobyte.Builder) {
recon.AddBytes(outer.sessionId)
})
recon.AddUint16LengthPrefixed(func(recon *cryptobyte.Builder) {
recon.AddBytes(cipherSuites)
})
recon.AddUint8LengthPrefixed(func(recon *cryptobyte.Builder) {
recon.AddBytes(compressionMethods)
})
recon.AddUint16LengthPrefixed(func(recon *cryptobyte.Builder) {
for !extensions.Empty() {
var extension uint16
var extData cryptobyte.String
if !extensions.ReadUint16(&extension) ||
!extensions.ReadUint16LengthPrefixed(&extData) {
recon.SetError(errors.New("tls: invalid inner client hello"))
return
}
if extension == extensionECHOuterExtensions {
if !extData.ReadUint8LengthPrefixed(&extData) {
recon.SetError(errors.New("tls: invalid inner client hello"))
return
}
var i int
for !extData.Empty() {
var extType uint16
if !extData.ReadUint16(&extType) {
recon.SetError(errors.New("tls: invalid inner client hello"))
return
}
if extType == extensionEncryptedClientHello {
recon.SetError(errors.New("tls: invalid outer extensions"))
return
}
for ; i <= len(rawOuterExts); i++ {
if i == len(rawOuterExts) {
recon.SetError(errors.New("tls: invalid outer extensions"))
return
}
if rawOuterExts[i].extType == extType {
break
}
}
recon.AddUint16(rawOuterExts[i].extType)
recon.AddUint16LengthPrefixed(func(recon *cryptobyte.Builder) {
recon.AddBytes(rawOuterExts[i].data)
})
}
} else {
recon.AddUint16(extension)
recon.AddUint16LengthPrefixed(func(recon *cryptobyte.Builder) {
recon.AddBytes(extData)
})
}
}
})
})
reconBytes, err := recon.Bytes()
if err != nil {
return nil, err
}
inner := &clientHelloMsg{}
if !inner.unmarshal(reconBytes) {
return nil, errors.New("tls: invalid reconstructed inner client hello")
}
if !bytes.Equal(inner.encryptedClientHello, []byte{uint8(innerECHExt)}) {
return nil, errInvalidECHExt
}
hasTLS13 := false
for _, v := range inner.supportedVersions {
// Skip GREASE values. GREASE (Generate Random Extensions And Sustain
// Extensibility, RFC 8701) values are reserved codepoints that clients such as
// Chrome advertise to ensure TLS implementations correctly ignore unknown values.
// For TLS versions they have the form 0x?A?A with both bytes equal (0x0A0A,
// 0x1A1A, ..., 0xFAFA), which is exactly what the check below matches.
if v&0x0F0F == 0x0A0A && v&0xff == v>>8 {
continue
}
// Ensure at least TLS 1.3 is offered.
if v == VersionTLS13 {
hasTLS13 = true
} else if v < VersionTLS13 {
// Reject if any non-GREASE value is below TLS 1.3, as ECH requires TLS 1.3+.
return nil, errors.New("tls: client sent encrypted_client_hello extension with unsupported versions")
}
}
if !hasTLS13 {
return nil, errors.New("tls: client sent encrypted_client_hello extension but did not offer TLS 1.3")
}
return inner, nil
}
func decryptECHPayload(context *hpke.Recipient, hello, payload []byte) ([]byte, error) {
outerAAD := bytes.Replace(hello[4:], payload, make([]byte, len(payload)), 1)
return context.Open(outerAAD, payload)
}
func generateOuterECHExt(id uint8, kdfID, aeadID uint16, encodedKey []byte, payload []byte) ([]byte, error) {
var b cryptobyte.Builder
b.AddUint8(0) // outer
b.AddUint16(kdfID)
b.AddUint16(aeadID)
b.AddUint8(id)
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { b.AddBytes(encodedKey) })
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { b.AddBytes(payload) })
return b.Bytes()
}
func computeAndUpdateOuterECHExtension(outer, inner *clientHelloMsg, ech *echClientContext, useKey bool) error {
var encapKey []byte
if useKey {
encapKey = ech.encapsulatedKey
}
encodedInner, err := encodeInnerClientHello(inner, int(ech.config.MaxNameLength))
if err != nil {
return err
}
// NOTE: the tag lengths for all of the supported AEADs are the same (16
// bytes), so we have hardcoded it here. If we add support for another AEAD
// with a different tag length, we will need to change this.
encryptedLen := len(encodedInner) + 16 // AEAD tag length
outer.encryptedClientHello, err = generateOuterECHExt(ech.config.ConfigID, ech.kdfID, ech.aeadID, encapKey, make([]byte, encryptedLen))
if err != nil {
return err
}
serializedOuter, err := outer.marshal()
if err != nil {
return err
}
serializedOuter = serializedOuter[4:] // strip the four byte prefix
encryptedInner, err := ech.hpkeContext.Seal(serializedOuter, encodedInner)
if err != nil {
return err
}
outer.encryptedClientHello, err = generateOuterECHExt(ech.config.ConfigID, ech.kdfID, ech.aeadID, encapKey, encryptedInner)
if err != nil {
return err
}
return nil
}
// validDNSName is a rather rudimentary check for the validity of a DNS name.
// This is used to check if the public_name in an ECHConfig is valid when we are
// picking a config. This can be somewhat lax because even if we pick a
// valid-looking name, the DNS layer will later reject it anyway.
func validDNSName(name string) bool {
if len(name) > 253 {
return false
}
labels := strings.Split(name, ".")
if len(labels) <= 1 {
return false
}
for _, l := range labels {
labelLen := len(l)
if labelLen == 0 {
return false
}
for i, r := range l {
if r == '-' && (i == 0 || i == labelLen-1) {
return false
}
if (r < '0' || r > '9') && (r < 'a' || r > 'z') && (r < 'A' || r > 'Z') && r != '-' {
return false
}
}
}
return true
}
// ECHRejectionError is the error type returned when ECH is rejected by a remote
// server. If the server offered an ECHConfigList to use for retries, the
// RetryConfigList field will contain this list.
//
// The client may treat an ECHRejectionError with an empty set of RetryConfigs
// as a secure signal from the server.
type ECHRejectionError struct {
RetryConfigList []byte
}
func (e *ECHRejectionError) Error() string {
return "tls: server rejected ECH"
}
var errMalformedECHExt = errors.New("tls: malformed encrypted_client_hello extension")
var errInvalidECHExt = errors.New("tls: client sent invalid encrypted_client_hello extension")
type echExtType uint8
const (
innerECHExt echExtType = 1
outerECHExt echExtType = 0
)
func parseECHExt(ext []byte) (echType echExtType, cs echCipher, configID uint8, encap []byte, payload []byte, err error) {
data := make([]byte, len(ext))
copy(data, ext)
s := cryptobyte.String(data)
var echInt uint8
if !s.ReadUint8(&echInt) {
err = errMalformedECHExt
return
}
echType = echExtType(echInt)
if echType == innerECHExt {
if !s.Empty() {
err = errMalformedECHExt
return
}
return echType, cs, 0, nil, nil, nil
}
if echType != outerECHExt {
err = errInvalidECHExt
return
}
if !s.ReadUint16(&cs.KDFID) {
err = errMalformedECHExt
return
}
if !s.ReadUint16(&cs.AEADID) {
err = errMalformedECHExt
return
}
if !s.ReadUint8(&configID) {
err = errMalformedECHExt
return
}
if !readUint16LengthPrefixed(&s, &encap) {
err = errMalformedECHExt
return
}
if !readUint16LengthPrefixed(&s, &payload) {
err = errMalformedECHExt
return
}
// NOTE: clone encap and payload so that mutating them does not mutate the
// raw extension bytes.
return echType, cs, configID, bytes.Clone(encap), bytes.Clone(payload), nil
}
func marshalEncryptedClientHelloConfigList(configs []EncryptedClientHelloKey) ([]byte, error) {
builder := cryptobyte.NewBuilder(nil)
builder.AddUint16LengthPrefixed(func(builder *cryptobyte.Builder) {
for _, c := range configs {
builder.AddBytes(c.Config)
}
})
return builder.Bytes()
}
func (c *Conn) processECHClientHello(outer *clientHelloMsg, echKeys []EncryptedClientHelloKey) (*clientHelloMsg, *echServerContext, error) {
echType, echCiphersuite, configID, encap, payload, err := parseECHExt(outer.encryptedClientHello)
if err != nil {
if errors.Is(err, errInvalidECHExt) {
c.sendAlert(alertIllegalParameter)
} else {
c.sendAlert(alertDecodeError)
}
return nil, nil, errInvalidECHExt
}
if echType == innerECHExt {
return outer, &echServerContext{inner: true}, nil
}
if len(echKeys) == 0 {
return outer, nil, nil
}
for _, echKey := range echKeys {
skip, config, err := parseECHConfig(echKey.Config)
if err != nil || skip {
c.sendAlert(alertInternalError)
return nil, nil, fmt.Errorf("tls: invalid EncryptedClientHelloKeys Config: %s", err)
}
echPriv, err := hpke.ParseHPKEPrivateKey(config.KemID, echKey.PrivateKey)
if err != nil {
c.sendAlert(alertInternalError)
return nil, nil, fmt.Errorf("tls: invalid EncryptedClientHelloKeys PrivateKey: %s", err)
}
info := append([]byte("tls ech\x00"), echKey.Config...)
hpkeContext, err := hpke.SetupRecipient(hpke.DHKEM_X25519_HKDF_SHA256, echCiphersuite.KDFID, echCiphersuite.AEADID, echPriv, info, encap)
if err != nil {
// attempt next trial decryption
continue
}
encodedInner, err := decryptECHPayload(hpkeContext, outer.original, payload)
if err != nil {
// attempt next trial decryption
continue
}
// NOTE: we do not enforce that the sent server_name matches the ECH
// configs PublicName, since this is not particularly important, and
// the client already had to know what it was in order to properly
// encrypt the payload. This is only a MAY in the spec, so we're not
// doing anything revolutionary.
echInner, err := decodeInnerClientHello(outer, encodedInner)
if err != nil {
c.sendAlert(alertIllegalParameter)
return nil, nil, errInvalidECHExt
}
c.echAccepted = true
return echInner, &echServerContext{
hpkeContext: hpkeContext,
configID: configID,
ciphersuite: echCiphersuite,
}, nil
}
return outer, nil, nil
}
func buildRetryConfigList(keys []EncryptedClientHelloKey) ([]byte, error) {
var atLeastOneRetryConfig bool
var retryBuilder cryptobyte.Builder
retryBuilder.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
for _, c := range keys {
if !c.SendAsRetry {
continue
}
atLeastOneRetryConfig = true
b.AddBytes(c.Config)
}
})
if !atLeastOneRetryConfig {
return nil, nil
}
return retryBuilder.Bytes()
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tls
import (
"bytes"
"context"
"crypto"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/internal/fips140/mlkem"
"crypto/internal/fips140/tls13"
"crypto/internal/hpke"
"crypto/rsa"
"crypto/subtle"
"crypto/tls/internal/fips140tls"
"crypto/x509"
"errors"
"fmt"
"hash"
"internal/godebug"
"io"
"net"
"slices"
"strconv"
"strings"
"time"
)
type clientHandshakeState struct {
c *Conn
ctx context.Context
serverHello *serverHelloMsg
hello *clientHelloMsg
suite *cipherSuite
finishedHash finishedHash
masterSecret []byte
session *SessionState // the session being resumed
ticket []byte // a fresh ticket received during this handshake
}
func (c *Conn) makeClientHello() (*clientHelloMsg, *keySharePrivateKeys, *echClientContext, error) {
config := c.config
if len(config.ServerName) == 0 && !config.InsecureSkipVerify {
return nil, nil, nil, errors.New("tls: either ServerName or InsecureSkipVerify must be specified in the tls.Config")
}
nextProtosLength := 0
for _, proto := range config.NextProtos {
if l := len(proto); l == 0 || l > 255 {
return nil, nil, nil, errors.New("tls: invalid NextProtos value")
} else {
nextProtosLength += 1 + l
}
}
if nextProtosLength > 0xffff {
return nil, nil, nil, errors.New("tls: NextProtos values too large")
}
supportedVersions := config.supportedVersions(roleClient)
if len(supportedVersions) == 0 {
return nil, nil, nil, errors.New("tls: no supported versions satisfy MinVersion and MaxVersion")
}
// Since supportedVersions is sorted in descending order, the first element
// is the maximum version and the last element is the minimum version.
maxVersion := supportedVersions[0]
minVersion := supportedVersions[len(supportedVersions)-1]
hello := &clientHelloMsg{
vers: maxVersion,
compressionMethods: []uint8{compressionNone},
random: make([]byte, 32),
extendedMasterSecret: true,
ocspStapling: true,
scts: true,
serverName: hostnameInSNI(config.ServerName),
supportedCurves: config.curvePreferences(maxVersion),
supportedPoints: []uint8{pointFormatUncompressed},
secureRenegotiationSupported: true,
alpnProtocols: config.NextProtos,
supportedVersions: supportedVersions,
}
// The version at the beginning of the ClientHello was capped at TLS 1.2
// for compatibility reasons. The supported_versions extension is used
// to negotiate versions now. See RFC 8446, Section 4.2.1.
if hello.vers > VersionTLS12 {
hello.vers = VersionTLS12
}
if c.handshakes > 0 {
hello.secureRenegotiation = c.clientFinished[:]
}
hello.cipherSuites = config.cipherSuites(hasAESGCMHardwareSupport)
// Don't advertise TLS 1.2-only cipher suites unless we're attempting TLS 1.2.
if maxVersion < VersionTLS12 {
hello.cipherSuites = slices.DeleteFunc(hello.cipherSuites, func(id uint16) bool {
return cipherSuiteByID(id).flags&suiteTLS12 != 0
})
}
_, err := io.ReadFull(config.rand(), hello.random)
if err != nil {
return nil, nil, nil, errors.New("tls: short read from Rand: " + err.Error())
}
// A random session ID is used to detect when the server accepted a ticket
// and is resuming a session (see RFC 5077). In TLS 1.3, it's always set as
// a compatibility measure (see RFC 8446, Section 4.1.2).
//
// The session ID is not set for QUIC connections (see RFC 9001, Section 8.4).
if c.quic == nil {
hello.sessionId = make([]byte, 32)
if _, err := io.ReadFull(config.rand(), hello.sessionId); err != nil {
return nil, nil, nil, errors.New("tls: short read from Rand: " + err.Error())
}
}
if maxVersion >= VersionTLS12 {
hello.supportedSignatureAlgorithms = supportedSignatureAlgorithms(minVersion)
hello.supportedSignatureAlgorithmsCert = supportedSignatureAlgorithmsCert()
}
var keyShareKeys *keySharePrivateKeys
if maxVersion >= VersionTLS13 {
// Reset the list of ciphers when the client only supports TLS 1.3.
if minVersion >= VersionTLS13 {
hello.cipherSuites = nil
}
if fips140tls.Required() {
hello.cipherSuites = append(hello.cipherSuites, allowedCipherSuitesTLS13FIPS...)
} else if hasAESGCMHardwareSupport {
hello.cipherSuites = append(hello.cipherSuites, defaultCipherSuitesTLS13...)
} else {
hello.cipherSuites = append(hello.cipherSuites, defaultCipherSuitesTLS13NoAES...)
}
if len(hello.supportedCurves) == 0 {
return nil, nil, nil, errors.New("tls: no supported elliptic curves for ECDHE")
}
curveID := hello.supportedCurves[0]
keyShareKeys = &keySharePrivateKeys{curveID: curveID}
// Note that if X25519MLKEM768 is supported, it will be first because
// the preference order is fixed.
if curveID == X25519MLKEM768 {
keyShareKeys.ecdhe, err = generateECDHEKey(config.rand(), X25519)
if err != nil {
return nil, nil, nil, err
}
seed := make([]byte, mlkem.SeedSize)
if _, err := io.ReadFull(config.rand(), seed); err != nil {
return nil, nil, nil, err
}
keyShareKeys.mlkem, err = mlkem.NewDecapsulationKey768(seed)
if err != nil {
return nil, nil, nil, err
}
mlkemEncapsulationKey := keyShareKeys.mlkem.EncapsulationKey().Bytes()
x25519EphemeralKey := keyShareKeys.ecdhe.PublicKey().Bytes()
hello.keyShares = []keyShare{
{group: X25519MLKEM768, data: append(mlkemEncapsulationKey, x25519EphemeralKey...)},
}
// If both X25519MLKEM768 and X25519 are supported, we send both key
// shares (as a fallback) and we reuse the same X25519 ephemeral
// key, as allowed by draft-ietf-tls-hybrid-design-09, Section 3.2.
if slices.Contains(hello.supportedCurves, X25519) {
hello.keyShares = append(hello.keyShares, keyShare{group: X25519, data: x25519EphemeralKey})
}
} else {
if _, ok := curveForCurveID(curveID); !ok {
return nil, nil, nil, errors.New("tls: CurvePreferences includes unsupported curve")
}
keyShareKeys.ecdhe, err = generateECDHEKey(config.rand(), curveID)
if err != nil {
return nil, nil, nil, err
}
hello.keyShares = []keyShare{{group: curveID, data: keyShareKeys.ecdhe.PublicKey().Bytes()}}
}
}
if c.quic != nil {
p, err := c.quicGetTransportParameters()
if err != nil {
return nil, nil, nil, err
}
if p == nil {
p = []byte{}
}
hello.quicTransportParameters = p
}
var ech *echClientContext
if c.config.EncryptedClientHelloConfigList != nil {
if c.config.MinVersion != 0 && c.config.MinVersion < VersionTLS13 {
return nil, nil, nil, errors.New("tls: MinVersion must be >= VersionTLS13 if EncryptedClientHelloConfigList is populated")
}
if c.config.MaxVersion != 0 && c.config.MaxVersion <= VersionTLS12 {
return nil, nil, nil, errors.New("tls: MaxVersion must be >= VersionTLS13 if EncryptedClientHelloConfigList is populated")
}
echConfigs, err := parseECHConfigList(c.config.EncryptedClientHelloConfigList)
if err != nil {
return nil, nil, nil, err
}
echConfig := pickECHConfig(echConfigs)
if echConfig == nil {
return nil, nil, nil, errors.New("tls: EncryptedClientHelloConfigList contains no valid configs")
}
ech = &echClientContext{config: echConfig}
hello.encryptedClientHello = []byte{1} // indicate inner hello
// We need to explicitly set these 1.2 fields to nil, as we do not
// marshal them when encoding the inner hello, otherwise transcripts
// will later mismatch.
hello.supportedPoints = nil
hello.ticketSupported = false
hello.secureRenegotiationSupported = false
hello.extendedMasterSecret = false
echPK, err := hpke.ParseHPKEPublicKey(ech.config.KemID, ech.config.PublicKey)
if err != nil {
return nil, nil, nil, err
}
suite, err := pickECHCipherSuite(ech.config.SymmetricCipherSuite)
if err != nil {
return nil, nil, nil, err
}
ech.kdfID, ech.aeadID = suite.KDFID, suite.AEADID
info := append([]byte("tls ech\x00"), ech.config.raw...)
ech.encapsulatedKey, ech.hpkeContext, err = hpke.SetupSender(ech.config.KemID, suite.KDFID, suite.AEADID, echPK, info)
if err != nil {
return nil, nil, nil, err
}
}
return hello, keyShareKeys, ech, nil
}
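// A minimal client-side sketch of opting into ECH, assuming echConfigList
// holds an ECHConfigList obtained out of band (for example from DNS); as
// enforced above, MinVersion must be 0 or at least TLS 1.3:
//
//	cfg := &Config{
//		ServerName:                     "example.com",
//		MinVersion:                     VersionTLS13,
//		EncryptedClientHelloConfigList: echConfigList,
//	}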
type echClientContext struct {
config *echConfig
hpkeContext *hpke.Sender
encapsulatedKey []byte
innerHello *clientHelloMsg
innerTranscript hash.Hash
kdfID uint16
aeadID uint16
echRejected bool
retryConfigs []byte
}
func (c *Conn) clientHandshake(ctx context.Context) (err error) {
if c.config == nil {
c.config = defaultConfig()
}
// This may be a renegotiation handshake, in which case some fields
// need to be reset.
c.didResume = false
c.curveID = 0
hello, keyShareKeys, ech, err := c.makeClientHello()
if err != nil {
return err
}
session, earlySecret, binderKey, err := c.loadSession(hello)
if err != nil {
return err
}
if session != nil {
defer func() {
// If we got a handshake failure when resuming a session, throw away
// the session ticket. See RFC 5077, Section 3.2.
//
// RFC 8446 makes no mention of dropping tickets on failure, but it
// does require servers to abort on invalid binders, so we need to
// delete tickets to recover from a corrupted PSK.
if err != nil {
if cacheKey := c.clientSessionCacheKey(); cacheKey != "" {
c.config.ClientSessionCache.Put(cacheKey, nil)
}
}
}()
}
if ech != nil {
// Split hello into inner and outer
ech.innerHello = hello.clone()
// Overwrite the server name in the outer hello with the public facing
// name.
hello.serverName = string(ech.config.PublicName)
// Generate a new random for the outer hello.
hello.random = make([]byte, 32)
_, err = io.ReadFull(c.config.rand(), hello.random)
if err != nil {
return errors.New("tls: short read from Rand: " + err.Error())
}
// NOTE: we don't do PSK GREASE, in line with boringssl, it's meant to
// work around _possibly_ broken middleboxes, but there is little-to-no
// evidence that this is actually a problem.
if err := computeAndUpdateOuterECHExtension(hello, ech.innerHello, ech, true); err != nil {
return err
}
}
c.serverName = hello.serverName
if _, err := c.writeHandshakeRecord(hello, nil); err != nil {
return err
}
if hello.earlyData {
suite := cipherSuiteTLS13ByID(session.cipherSuite)
transcript := suite.hash.New()
if err := transcriptMsg(hello, transcript); err != nil {
return err
}
earlyTrafficSecret := earlySecret.ClientEarlyTrafficSecret(transcript)
c.quicSetWriteSecret(QUICEncryptionLevelEarly, suite.id, earlyTrafficSecret)
}
// serverHelloMsg is not included in the transcript
msg, err := c.readHandshake(nil)
if err != nil {
return err
}
serverHello, ok := msg.(*serverHelloMsg)
if !ok {
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(serverHello, msg)
}
if err := c.pickTLSVersion(serverHello); err != nil {
return err
}
// If we are negotiating a protocol version that's lower than what we
// support, check for the server downgrade canaries.
// See RFC 8446, Section 4.1.3.
maxVers := c.config.maxSupportedVersion(roleClient)
tls12Downgrade := string(serverHello.random[24:]) == downgradeCanaryTLS12
tls11Downgrade := string(serverHello.random[24:]) == downgradeCanaryTLS11
if maxVers == VersionTLS13 && c.vers <= VersionTLS12 && (tls12Downgrade || tls11Downgrade) ||
maxVers == VersionTLS12 && c.vers <= VersionTLS11 && tls11Downgrade {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: downgrade attempt detected, possibly due to a MitM attack or a broken middlebox")
}
if c.vers == VersionTLS13 {
hs := &clientHandshakeStateTLS13{
c: c,
ctx: ctx,
serverHello: serverHello,
hello: hello,
keyShareKeys: keyShareKeys,
session: session,
earlySecret: earlySecret,
binderKey: binderKey,
echContext: ech,
}
return hs.handshake()
}
hs := &clientHandshakeState{
c: c,
ctx: ctx,
serverHello: serverHello,
hello: hello,
session: session,
}
return hs.handshake()
}
func (c *Conn) loadSession(hello *clientHelloMsg) (
session *SessionState, earlySecret *tls13.EarlySecret, binderKey []byte, err error) {
if c.config.SessionTicketsDisabled || c.config.ClientSessionCache == nil {
return nil, nil, nil, nil
}
echInner := bytes.Equal(hello.encryptedClientHello, []byte{1})
// ticketSupported is a TLS 1.2 extension (as TLS 1.3 replaced tickets with PSK
// identities) and ECH requires and forces TLS 1.3.
hello.ticketSupported = !echInner
if hello.supportedVersions[0] == VersionTLS13 {
// Require DHE on resumption as it guarantees forward secrecy against
// compromise of the session ticket key. See RFC 8446, Section 4.2.9.
hello.pskModes = []uint8{pskModeDHE}
}
// Session resumption is not allowed if renegotiating because
// renegotiation is primarily used to allow a client to send a client
// certificate, which would be skipped if session resumption occurred.
if c.handshakes != 0 {
return nil, nil, nil, nil
}
// Try to resume a previously negotiated TLS session, if available.
cacheKey := c.clientSessionCacheKey()
if cacheKey == "" {
return nil, nil, nil, nil
}
cs, ok := c.config.ClientSessionCache.Get(cacheKey)
if !ok || cs == nil {
return nil, nil, nil, nil
}
session = cs.session
// Check that the version used for the previous session is still valid.
versOk := false
for _, v := range hello.supportedVersions {
if v == session.version {
versOk = true
break
}
}
if !versOk {
return nil, nil, nil, nil
}
// Check that the cached server certificate is not expired, and that it's
// valid for the ServerName. This should be ensured by the cache key, but
// protect the application from a faulty ClientSessionCache implementation.
if c.config.time().After(session.peerCertificates[0].NotAfter) {
// Expired certificate, delete the entry.
c.config.ClientSessionCache.Put(cacheKey, nil)
return nil, nil, nil, nil
}
if !c.config.InsecureSkipVerify {
if len(session.verifiedChains) == 0 {
// The original connection had InsecureSkipVerify, while this doesn't.
return nil, nil, nil, nil
}
if err := session.peerCertificates[0].VerifyHostname(c.config.ServerName); err != nil {
return nil, nil, nil, nil
}
}
if session.version != VersionTLS13 {
// In TLS 1.2 the cipher suite must match the resumed session. Ensure we
// are still offering it.
if mutualCipherSuite(hello.cipherSuites, session.cipherSuite) == nil {
return nil, nil, nil, nil
}
// FIPS 140-3 requires the use of Extended Master Secret.
if !session.extMasterSecret && fips140tls.Required() {
return nil, nil, nil, nil
}
hello.sessionTicket = session.ticket
return
}
// Check that the session ticket is not expired.
if c.config.time().After(time.Unix(int64(session.useBy), 0)) {
c.config.ClientSessionCache.Put(cacheKey, nil)
return nil, nil, nil, nil
}
// In TLS 1.3 the KDF hash must match the resumed session. Ensure we
// offer at least one cipher suite with that hash.
cipherSuite := cipherSuiteTLS13ByID(session.cipherSuite)
if cipherSuite == nil {
return nil, nil, nil, nil
}
cipherSuiteOk := false
for _, offeredID := range hello.cipherSuites {
offeredSuite := cipherSuiteTLS13ByID(offeredID)
if offeredSuite != nil && offeredSuite.hash == cipherSuite.hash {
cipherSuiteOk = true
break
}
}
if !cipherSuiteOk {
return nil, nil, nil, nil
}
if c.quic != nil {
if c.quic.enableSessionEvents {
c.quicResumeSession(session)
}
// For 0-RTT, the cipher suite has to match exactly, and we need to be
// offering the same ALPN.
if session.EarlyData && mutualCipherSuiteTLS13(hello.cipherSuites, session.cipherSuite) != nil {
for _, alpn := range hello.alpnProtocols {
if alpn == session.alpnProtocol {
hello.earlyData = true
break
}
}
}
}
// Set the pre_shared_key extension. See RFC 8446, Section 4.2.11.1.
ticketAge := c.config.time().Sub(time.Unix(int64(session.createdAt), 0))
identity := pskIdentity{
label: session.ticket,
obfuscatedTicketAge: uint32(ticketAge/time.Millisecond) + session.ageAdd,
}
hello.pskIdentities = []pskIdentity{identity}
hello.pskBinders = [][]byte{make([]byte, cipherSuite.hash.Size())}
// Compute the PSK binders. See RFC 8446, Section 4.2.11.2.
earlySecret = tls13.NewEarlySecret(cipherSuite.hash.New, session.secret)
binderKey = earlySecret.ResumptionBinderKey()
transcript := cipherSuite.hash.New()
if err := computeAndUpdatePSK(hello, binderKey, transcript, cipherSuite.finishedHash); err != nil {
return nil, nil, nil, err
}
return
}
func (c *Conn) pickTLSVersion(serverHello *serverHelloMsg) error {
peerVersion := serverHello.vers
if serverHello.supportedVersion != 0 {
peerVersion = serverHello.supportedVersion
}
vers, ok := c.config.mutualVersion(roleClient, []uint16{peerVersion})
if !ok {
c.sendAlert(alertProtocolVersion)
return fmt.Errorf("tls: server selected unsupported protocol version %x", peerVersion)
}
c.vers = vers
c.haveVers = true
c.in.version = vers
c.out.version = vers
return nil
}
// Does the handshake, either a full one or a resumption of an old session. Requires hs.c,
// hs.hello, hs.serverHello, and, optionally, hs.session to be set.
func (hs *clientHandshakeState) handshake() error {
c := hs.c
// If we did not load a session (hs.session == nil), but we did set a
// session ID in the transmitted client hello (hs.hello.sessionId != nil),
// it means we tried to negotiate TLS 1.3 and sent a random session ID as a
// compatibility measure (see RFC 8446, Section 4.1.2).
//
// Since we're now handshaking for TLS 1.2, if the server echoed the
// transmitted ID back to us, we know mischief is afoot: the session ID
// was random and can't possibly be recognized by the server.
if hs.session == nil && hs.hello.sessionId != nil && bytes.Equal(hs.hello.sessionId, hs.serverHello.sessionId) {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: server echoed TLS 1.3 compatibility session ID in TLS 1.2")
}
isResume, err := hs.processServerHello()
if err != nil {
return err
}
hs.finishedHash = newFinishedHash(c.vers, hs.suite)
// No signatures of the handshake are needed in a resumption.
// Otherwise, in a full handshake, if we don't have any certificates
// configured then we will never send a CertificateVerify message and
// thus no signatures are needed in that case either.
if isResume || (len(c.config.Certificates) == 0 && c.config.GetClientCertificate == nil) {
hs.finishedHash.discardHandshakeBuffer()
}
if err := transcriptMsg(hs.hello, &hs.finishedHash); err != nil {
return err
}
if err := transcriptMsg(hs.serverHello, &hs.finishedHash); err != nil {
return err
}
c.buffering = true
c.didResume = isResume
if isResume {
if err := hs.establishKeys(); err != nil {
return err
}
if err := hs.readSessionTicket(); err != nil {
return err
}
if err := hs.readFinished(c.serverFinished[:]); err != nil {
return err
}
c.clientFinishedIsFirst = false
// Make sure the connection is still being verified whether or not this
// is a resumption. Resumptions currently don't reverify certificates so
// they don't call verifyServerCertificate. See Issue 31641.
if c.config.VerifyConnection != nil {
if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil {
c.sendAlert(alertBadCertificate)
return err
}
}
if err := hs.sendFinished(c.clientFinished[:]); err != nil {
return err
}
if _, err := c.flush(); err != nil {
return err
}
} else {
if err := hs.doFullHandshake(); err != nil {
return err
}
if err := hs.establishKeys(); err != nil {
return err
}
if err := hs.sendFinished(c.clientFinished[:]); err != nil {
return err
}
if _, err := c.flush(); err != nil {
return err
}
c.clientFinishedIsFirst = true
if err := hs.readSessionTicket(); err != nil {
return err
}
if err := hs.readFinished(c.serverFinished[:]); err != nil {
return err
}
}
if err := hs.saveSessionTicket(); err != nil {
return err
}
c.ekm = ekmFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.hello.random, hs.serverHello.random)
c.isHandshakeComplete.Store(true)
return nil
}
func (hs *clientHandshakeState) pickCipherSuite() error {
if hs.suite = mutualCipherSuite(hs.hello.cipherSuites, hs.serverHello.cipherSuite); hs.suite == nil {
hs.c.sendAlert(alertHandshakeFailure)
return errors.New("tls: server chose an unconfigured cipher suite")
}
if hs.c.config.CipherSuites == nil && !fips140tls.Required() && rsaKexCiphers[hs.suite.id] {
tlsrsakex.Value() // ensure godebug is initialized
tlsrsakex.IncNonDefault()
}
if hs.c.config.CipherSuites == nil && !fips140tls.Required() && tdesCiphers[hs.suite.id] {
tls3des.Value() // ensure godebug is initialized
tls3des.IncNonDefault()
}
hs.c.cipherSuite = hs.suite.id
return nil
}
func (hs *clientHandshakeState) doFullHandshake() error {
c := hs.c
msg, err := c.readHandshake(&hs.finishedHash)
if err != nil {
return err
}
certMsg, ok := msg.(*certificateMsg)
if !ok || len(certMsg.certificates) == 0 {
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(certMsg, msg)
}
msg, err = c.readHandshake(&hs.finishedHash)
if err != nil {
return err
}
cs, ok := msg.(*certificateStatusMsg)
if ok {
// RFC4366 on Certificate Status Request:
// The server MAY return a "certificate_status" message.
if !hs.serverHello.ocspStapling {
// If a server returns a "CertificateStatus" message, then the
// server MUST have included an extension of type "status_request"
// with empty "extension_data" in the extended server hello.
c.sendAlert(alertUnexpectedMessage)
return errors.New("tls: received unexpected CertificateStatus message")
}
c.ocspResponse = cs.response
msg, err = c.readHandshake(&hs.finishedHash)
if err != nil {
return err
}
}
if c.handshakes == 0 {
// If this is the first handshake on a connection, process and
// (optionally) verify the server's certificates.
if err := c.verifyServerCertificate(certMsg.certificates); err != nil {
return err
}
} else {
// This is a renegotiation handshake. We require that the
// server's identity (i.e. leaf certificate) is unchanged and
// thus any previous trust decision is still valid.
//
// See https://mitls.org/pages/attacks/3SHAKE for the
// motivation behind this requirement.
if !bytes.Equal(c.peerCertificates[0].Raw, certMsg.certificates[0]) {
c.sendAlert(alertBadCertificate)
return errors.New("tls: server's identity changed during renegotiation")
}
}
keyAgreement := hs.suite.ka(c.vers)
skx, ok := msg.(*serverKeyExchangeMsg)
if ok {
err = keyAgreement.processServerKeyExchange(c.config, hs.hello, hs.serverHello, c.peerCertificates[0], skx)
if err != nil {
c.sendAlert(alertIllegalParameter)
return err
}
if keyAgreement, ok := keyAgreement.(*ecdheKeyAgreement); ok {
c.curveID = keyAgreement.curveID
c.peerSigAlg = keyAgreement.signatureAlgorithm
}
msg, err = c.readHandshake(&hs.finishedHash)
if err != nil {
return err
}
}
var chainToSend *Certificate
var certRequested bool
certReq, ok := msg.(*certificateRequestMsg)
if ok {
certRequested = true
cri := certificateRequestInfoFromMsg(hs.ctx, c.vers, certReq)
if chainToSend, err = c.getClientCertificate(cri); err != nil {
c.sendAlert(alertInternalError)
return err
}
msg, err = c.readHandshake(&hs.finishedHash)
if err != nil {
return err
}
}
shd, ok := msg.(*serverHelloDoneMsg)
if !ok {
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(shd, msg)
}
// If the server requested a certificate then we have to send a
// Certificate message, even if it's empty because we don't have a
// certificate to send.
if certRequested {
certMsg = new(certificateMsg)
certMsg.certificates = chainToSend.Certificate
if _, err := hs.c.writeHandshakeRecord(certMsg, &hs.finishedHash); err != nil {
return err
}
}
preMasterSecret, ckx, err := keyAgreement.generateClientKeyExchange(c.config, hs.hello, c.peerCertificates[0])
if err != nil {
c.sendAlert(alertInternalError)
return err
}
if ckx != nil {
if _, err := hs.c.writeHandshakeRecord(ckx, &hs.finishedHash); err != nil {
return err
}
}
if hs.serverHello.extendedMasterSecret {
c.extMasterSecret = true
hs.masterSecret = extMasterFromPreMasterSecret(c.vers, hs.suite, preMasterSecret,
hs.finishedHash.Sum())
} else {
if fips140tls.Required() {
c.sendAlert(alertHandshakeFailure)
return errors.New("tls: FIPS 140-3 requires the use of Extended Master Secret")
}
hs.masterSecret = masterFromPreMasterSecret(c.vers, hs.suite, preMasterSecret,
hs.hello.random, hs.serverHello.random)
}
if err := c.config.writeKeyLog(keyLogLabelTLS12, hs.hello.random, hs.masterSecret); err != nil {
c.sendAlert(alertInternalError)
return errors.New("tls: failed to write to key log: " + err.Error())
}
if chainToSend != nil && len(chainToSend.Certificate) > 0 {
certVerify := &certificateVerifyMsg{}
key, ok := chainToSend.PrivateKey.(crypto.Signer)
if !ok {
c.sendAlert(alertInternalError)
return fmt.Errorf("tls: client certificate private key of type %T does not implement crypto.Signer", chainToSend.PrivateKey)
}
var sigType uint8
var sigHash crypto.Hash
if c.vers >= VersionTLS12 {
signatureAlgorithm, err := selectSignatureScheme(c.vers, chainToSend, certReq.supportedSignatureAlgorithms)
if err != nil {
c.sendAlert(alertHandshakeFailure)
return err
}
sigType, sigHash, err = typeAndHashFromSignatureScheme(signatureAlgorithm)
if err != nil {
return c.sendAlert(alertInternalError)
}
certVerify.hasSignatureAlgorithm = true
certVerify.signatureAlgorithm = signatureAlgorithm
if sigHash == crypto.SHA1 {
tlssha1.Value() // ensure godebug is initialized
tlssha1.IncNonDefault()
}
} else {
sigType, sigHash, err = legacyTypeAndHashFromPublicKey(key.Public())
if err != nil {
c.sendAlert(alertIllegalParameter)
return err
}
}
signed := hs.finishedHash.hashForClientCertificate(sigType, sigHash)
signOpts := crypto.SignerOpts(sigHash)
if sigType == signatureRSAPSS {
signOpts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash}
}
certVerify.signature, err = key.Sign(c.config.rand(), signed, signOpts)
if err != nil {
c.sendAlert(alertInternalError)
return err
}
if _, err := hs.c.writeHandshakeRecord(certVerify, &hs.finishedHash); err != nil {
return err
}
}
hs.finishedHash.discardHandshakeBuffer()
return nil
}
func (hs *clientHandshakeState) establishKeys() error {
c := hs.c
clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV :=
keysFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.hello.random, hs.serverHello.random, hs.suite.macLen, hs.suite.keyLen, hs.suite.ivLen)
var clientCipher, serverCipher any
var clientHash, serverHash hash.Hash
if hs.suite.cipher != nil {
clientCipher = hs.suite.cipher(clientKey, clientIV, false /* not for reading */)
clientHash = hs.suite.mac(clientMAC)
serverCipher = hs.suite.cipher(serverKey, serverIV, true /* for reading */)
serverHash = hs.suite.mac(serverMAC)
} else {
clientCipher = hs.suite.aead(clientKey, clientIV)
serverCipher = hs.suite.aead(serverKey, serverIV)
}
c.in.prepareCipherSpec(c.vers, serverCipher, serverHash)
c.out.prepareCipherSpec(c.vers, clientCipher, clientHash)
return nil
}
func (hs *clientHandshakeState) serverResumedSession() bool {
// If the server responded with the same sessionId then it means the
// sessionTicket is being used to resume a TLS session.
return hs.session != nil && hs.hello.sessionId != nil &&
bytes.Equal(hs.serverHello.sessionId, hs.hello.sessionId)
}
func (hs *clientHandshakeState) processServerHello() (bool, error) {
c := hs.c
if err := hs.pickCipherSuite(); err != nil {
return false, err
}
if hs.serverHello.compressionMethod != compressionNone {
c.sendAlert(alertIllegalParameter)
return false, errors.New("tls: server selected unsupported compression format")
}
supportsPointFormat := false
offeredNonCompressedFormat := false
for _, format := range hs.serverHello.supportedPoints {
if format == pointFormatUncompressed {
supportsPointFormat = true
} else {
offeredNonCompressedFormat = true
}
}
if !supportsPointFormat && offeredNonCompressedFormat {
return false, errors.New("tls: server offered only incompatible point formats")
}
if c.handshakes == 0 && hs.serverHello.secureRenegotiationSupported {
c.secureRenegotiation = true
if len(hs.serverHello.secureRenegotiation) != 0 {
c.sendAlert(alertHandshakeFailure)
return false, errors.New("tls: initial handshake had non-empty renegotiation extension")
}
}
if c.handshakes > 0 && c.secureRenegotiation {
var expectedSecureRenegotiation [24]byte
copy(expectedSecureRenegotiation[:], c.clientFinished[:])
copy(expectedSecureRenegotiation[12:], c.serverFinished[:])
if !bytes.Equal(hs.serverHello.secureRenegotiation, expectedSecureRenegotiation[:]) {
c.sendAlert(alertHandshakeFailure)
return false, errors.New("tls: incorrect renegotiation extension contents")
}
}
if err := checkALPN(hs.hello.alpnProtocols, hs.serverHello.alpnProtocol, false); err != nil {
c.sendAlert(alertUnsupportedExtension)
return false, err
}
c.clientProtocol = hs.serverHello.alpnProtocol
c.scts = hs.serverHello.scts
if !hs.serverResumedSession() {
return false, nil
}
if hs.session.version != c.vers {
c.sendAlert(alertHandshakeFailure)
return false, errors.New("tls: server resumed a session with a different version")
}
if hs.session.cipherSuite != hs.suite.id {
c.sendAlert(alertHandshakeFailure)
return false, errors.New("tls: server resumed a session with a different cipher suite")
}
// RFC 7627, Section 5.3
if hs.session.extMasterSecret != hs.serverHello.extendedMasterSecret {
c.sendAlert(alertHandshakeFailure)
return false, errors.New("tls: server resumed a session with a different EMS extension")
}
// Restore master secret and certificates from previous state
hs.masterSecret = hs.session.secret
c.extMasterSecret = hs.session.extMasterSecret
c.peerCertificates = hs.session.peerCertificates
c.verifiedChains = hs.session.verifiedChains
c.ocspResponse = hs.session.ocspResponse
// Let the ServerHello SCTs override the session SCTs from the original
// connection, if any are provided.
if len(c.scts) == 0 && len(hs.session.scts) != 0 {
c.scts = hs.session.scts
}
c.curveID = hs.session.curveID
return true, nil
}
// checkALPN ensures that the server's choice of ALPN protocol is compatible with
// the protocols that we advertised in the ClientHello.
func checkALPN(clientProtos []string, serverProto string, quic bool) error {
if serverProto == "" {
if quic && len(clientProtos) > 0 {
// RFC 9001, Section 8.1
return errors.New("tls: server did not select an ALPN protocol")
}
return nil
}
if len(clientProtos) == 0 {
return errors.New("tls: server advertised unrequested ALPN extension")
}
for _, proto := range clientProtos {
if proto == serverProto {
return nil
}
}
return errors.New("tls: server selected unadvertised ALPN protocol")
}
func (hs *clientHandshakeState) readFinished(out []byte) error {
c := hs.c
if err := c.readChangeCipherSpec(); err != nil {
return err
}
// finishedMsg is included in the transcript, but not until after we
// verify it, since the transcript state before this message was sent
// is used during verification.
msg, err := c.readHandshake(nil)
if err != nil {
return err
}
serverFinished, ok := msg.(*finishedMsg)
if !ok {
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(serverFinished, msg)
}
verify := hs.finishedHash.serverSum(hs.masterSecret)
if len(verify) != len(serverFinished.verifyData) ||
subtle.ConstantTimeCompare(verify, serverFinished.verifyData) != 1 {
c.sendAlert(alertHandshakeFailure)
return errors.New("tls: server's Finished message was incorrect")
}
if err := transcriptMsg(serverFinished, &hs.finishedHash); err != nil {
return err
}
copy(out, verify)
return nil
}
func (hs *clientHandshakeState) readSessionTicket() error {
if !hs.serverHello.ticketSupported {
return nil
}
c := hs.c
if !hs.hello.ticketSupported {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: server sent unrequested session ticket")
}
msg, err := c.readHandshake(&hs.finishedHash)
if err != nil {
return err
}
sessionTicketMsg, ok := msg.(*newSessionTicketMsg)
if !ok {
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(sessionTicketMsg, msg)
}
hs.ticket = sessionTicketMsg.ticket
return nil
}
func (hs *clientHandshakeState) saveSessionTicket() error {
if hs.ticket == nil {
return nil
}
c := hs.c
cacheKey := c.clientSessionCacheKey()
if cacheKey == "" {
return nil
}
session := c.sessionState()
session.secret = hs.masterSecret
session.ticket = hs.ticket
cs := &ClientSessionState{session: session}
c.config.ClientSessionCache.Put(cacheKey, cs)
return nil
}
func (hs *clientHandshakeState) sendFinished(out []byte) error {
c := hs.c
if err := c.writeChangeCipherRecord(); err != nil {
return err
}
finished := new(finishedMsg)
finished.verifyData = hs.finishedHash.clientSum(hs.masterSecret)
if _, err := hs.c.writeHandshakeRecord(finished, &hs.finishedHash); err != nil {
return err
}
copy(out, finished.verifyData)
return nil
}
// defaultMaxRSAKeySize is the maximum RSA key size in bits that we are willing
// to verify the signatures of during a TLS handshake.
const defaultMaxRSAKeySize = 8192
var tlsmaxrsasize = godebug.New("tlsmaxrsasize")
func checkKeySize(n int) (max int, ok bool) {
if v := tlsmaxrsasize.Value(); v != "" {
if max, err := strconv.Atoi(v); err == nil {
if (n <= max) != (n <= defaultMaxRSAKeySize) {
tlsmaxrsasize.IncNonDefault()
}
return max, n <= max
}
}
return defaultMaxRSAKeySize, n <= defaultMaxRSAKeySize
}
// verifyServerCertificate parses and verifies the provided chain, setting
// c.verifiedChains and c.peerCertificates or sending the appropriate alert.
func (c *Conn) verifyServerCertificate(certificates [][]byte) error {
certs := make([]*x509.Certificate, len(certificates))
for i, asn1Data := range certificates {
cert, err := globalCertCache.newCert(asn1Data)
if err != nil {
c.sendAlert(alertDecodeError)
return errors.New("tls: failed to parse certificate from server: " + err.Error())
}
if cert.PublicKeyAlgorithm == x509.RSA {
n := cert.PublicKey.(*rsa.PublicKey).N.BitLen()
if max, ok := checkKeySize(n); !ok {
c.sendAlert(alertBadCertificate)
return fmt.Errorf("tls: server sent certificate containing RSA key larger than %d bits", max)
}
}
certs[i] = cert
}
echRejected := c.config.EncryptedClientHelloConfigList != nil && !c.echAccepted
if echRejected {
if c.config.EncryptedClientHelloRejectionVerify != nil {
if err := c.config.EncryptedClientHelloRejectionVerify(c.connectionStateLocked()); err != nil {
c.sendAlert(alertBadCertificate)
return err
}
} else {
opts := x509.VerifyOptions{
Roots: c.config.RootCAs,
CurrentTime: c.config.time(),
DNSName: c.serverName,
Intermediates: x509.NewCertPool(),
}
for _, cert := range certs[1:] {
opts.Intermediates.AddCert(cert)
}
chains, err := certs[0].Verify(opts)
if err != nil {
c.sendAlert(alertBadCertificate)
return &CertificateVerificationError{UnverifiedCertificates: certs, Err: err}
}
c.verifiedChains, err = fipsAllowedChains(chains)
if err != nil {
c.sendAlert(alertBadCertificate)
return &CertificateVerificationError{UnverifiedCertificates: certs, Err: err}
}
}
} else if !c.config.InsecureSkipVerify {
opts := x509.VerifyOptions{
Roots: c.config.RootCAs,
CurrentTime: c.config.time(),
DNSName: c.config.ServerName,
Intermediates: x509.NewCertPool(),
}
for _, cert := range certs[1:] {
opts.Intermediates.AddCert(cert)
}
chains, err := certs[0].Verify(opts)
if err != nil {
c.sendAlert(alertBadCertificate)
return &CertificateVerificationError{UnverifiedCertificates: certs, Err: err}
}
c.verifiedChains, err = fipsAllowedChains(chains)
if err != nil {
c.sendAlert(alertBadCertificate)
return &CertificateVerificationError{UnverifiedCertificates: certs, Err: err}
}
}
switch certs[0].PublicKey.(type) {
case *rsa.PublicKey, *ecdsa.PublicKey, ed25519.PublicKey:
break
default:
c.sendAlert(alertUnsupportedCertificate)
return fmt.Errorf("tls: server's certificate contains an unsupported type of public key: %T", certs[0].PublicKey)
}
c.peerCertificates = certs
if c.config.VerifyPeerCertificate != nil && !echRejected {
if err := c.config.VerifyPeerCertificate(certificates, c.verifiedChains); err != nil {
c.sendAlert(alertBadCertificate)
return err
}
}
if c.config.VerifyConnection != nil && !echRejected {
if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil {
c.sendAlert(alertBadCertificate)
return err
}
}
return nil
}
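// Illustrative sketch (calling-application side, hypothetical host name): a
// client that disables the built-in verification path above and replicates it
// in Config.VerifyPeerCertificate, e.g. to pin a private root. With
// InsecureSkipVerify set, verifyServerCertificate skips chain building, so
// the callback must do all of the work. Assumes imports of "crypto/tls" and
// "crypto/x509".
func examplePinnedRootConfig(roots *x509.CertPool) *tls.Config {
	const host = "example.com"
	return &tls.Config{
		ServerName:         host,
		InsecureSkipVerify: true, // verification happens in the callback below
		VerifyPeerCertificate: func(rawCerts [][]byte, _ [][]*x509.Certificate) error {
			certs := make([]*x509.Certificate, len(rawCerts))
			for i, raw := range rawCerts {
				c, err := x509.ParseCertificate(raw)
				if err != nil {
					return err
				}
				certs[i] = c
			}
			opts := x509.VerifyOptions{
				Roots:         roots,
				DNSName:       host,
				Intermediates: x509.NewCertPool(),
			}
			for _, c := range certs[1:] {
				opts.Intermediates.AddCert(c)
			}
			_, err := certs[0].Verify(opts)
			return err
		},
	}
}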
// certificateRequestInfoFromMsg generates a CertificateRequestInfo from a TLS
// <= 1.2 CertificateRequest, making an effort to fill in missing information.
func certificateRequestInfoFromMsg(ctx context.Context, vers uint16, certReq *certificateRequestMsg) *CertificateRequestInfo {
cri := &CertificateRequestInfo{
AcceptableCAs: certReq.certificateAuthorities,
Version: vers,
ctx: ctx,
}
var rsaAvail, ecAvail bool
for _, certType := range certReq.certificateTypes {
switch certType {
case certTypeRSASign:
rsaAvail = true
case certTypeECDSASign:
ecAvail = true
}
}
if !certReq.hasSignatureAlgorithm {
// Prior to TLS 1.2, signature schemes did not exist. In this case we
// make up a list based on the acceptable certificate types, to help
// GetClientCertificate and SupportsCertificate select the right certificate.
// The hash part of the SignatureScheme is a lie here, because
// TLS 1.0 and 1.1 always use MD5+SHA1 for RSA and SHA1 for ECDSA.
switch {
case rsaAvail && ecAvail:
cri.SignatureSchemes = []SignatureScheme{
ECDSAWithP256AndSHA256, ECDSAWithP384AndSHA384, ECDSAWithP521AndSHA512,
PKCS1WithSHA256, PKCS1WithSHA384, PKCS1WithSHA512, PKCS1WithSHA1,
}
case rsaAvail:
cri.SignatureSchemes = []SignatureScheme{
PKCS1WithSHA256, PKCS1WithSHA384, PKCS1WithSHA512, PKCS1WithSHA1,
}
case ecAvail:
cri.SignatureSchemes = []SignatureScheme{
ECDSAWithP256AndSHA256, ECDSAWithP384AndSHA384, ECDSAWithP521AndSHA512,
}
}
return cri
}
// Filter the signature schemes based on the certificate types.
// See RFC 5246, Section 7.4.4 (where it calls this "somewhat complicated").
cri.SignatureSchemes = make([]SignatureScheme, 0, len(certReq.supportedSignatureAlgorithms))
for _, sigScheme := range certReq.supportedSignatureAlgorithms {
sigType, _, err := typeAndHashFromSignatureScheme(sigScheme)
if err != nil {
continue
}
switch sigType {
case signatureECDSA, signatureEd25519:
if ecAvail {
cri.SignatureSchemes = append(cri.SignatureSchemes, sigScheme)
}
case signatureRSAPSS, signaturePKCS1v15:
if rsaAvail {
cri.SignatureSchemes = append(cri.SignatureSchemes, sigScheme)
}
}
}
return cri
}
func (c *Conn) getClientCertificate(cri *CertificateRequestInfo) (*Certificate, error) {
if c.config.GetClientCertificate != nil {
return c.config.GetClientCertificate(cri)
}
for _, chain := range c.config.Certificates {
if err := cri.SupportsCertificate(&chain); err != nil {
continue
}
return &chain, nil
}
// No acceptable certificate found. Don't send a certificate.
return new(Certificate), nil
}
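// Illustrative sketch (calling-application side, hypothetical names):
// supplying a client certificate lazily through Config.GetClientCertificate
// instead of Config.Certificates. The callback receives the
// CertificateRequestInfo built by certificateRequestInfoFromMsg above and can
// use SupportsCertificate to test compatibility; returning an empty
// Certificate declines, which mirrors the fallback in getClientCertificate.
// Assumes an import of "crypto/tls".
func exampleClientCertConfig(clientCert tls.Certificate) *tls.Config {
	return &tls.Config{
		GetClientCertificate: func(cri *tls.CertificateRequestInfo) (*tls.Certificate, error) {
			if err := cri.SupportsCertificate(&clientCert); err != nil {
				// Not acceptable to this server: send an empty Certificate
				// message rather than failing the handshake.
				return &tls.Certificate{}, nil
			}
			return &clientCert, nil
		},
	}
}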
// clientSessionCacheKey returns a key used to cache sessionTickets that could
// be used to resume previously negotiated TLS sessions with a server.
func (c *Conn) clientSessionCacheKey() string {
if len(c.config.ServerName) > 0 {
return c.config.ServerName
}
if c.conn != nil {
return c.conn.RemoteAddr().String()
}
return ""
}
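// Illustrative sketch (calling-application side, hypothetical host name):
// enabling client-side session resumption. Tickets saved by saveSessionTicket
// and handleNewSessionTicket are stored under the key computed by
// clientSessionCacheKey above, i.e. the configured ServerName or, failing
// that, the remote address. Assumes an import of "crypto/tls".
func exampleResumptionConfig() *tls.Config {
	return &tls.Config{
		ServerName:         "example.com",
		ClientSessionCache: tls.NewLRUClientSessionCache(64),
	}
}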
// hostnameInSNI converts name into an appropriate hostname for SNI.
// Literal IP addresses and absolute FQDNs are not permitted as SNI values.
// See RFC 6066, Section 3.
func hostnameInSNI(name string) string {
host := name
if len(host) > 0 && host[0] == '[' && host[len(host)-1] == ']' {
host = host[1 : len(host)-1]
}
if i := strings.LastIndex(host, "%"); i > 0 {
host = host[:i]
}
if net.ParseIP(host) != nil {
return ""
}
for len(name) > 0 && name[len(name)-1] == '.' {
name = name[:len(name)-1]
}
return name
}
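// Illustrative sketch of hostnameInSNI's behavior, written as a hypothetical
// table-driven test (not an actual test in this package; assumes a "testing"
// import). IP literals map to the empty string and trailing dots are removed.
func exampleHostnameInSNITest(t *testing.T) {
	cases := []struct{ in, want string }{
		{"example.com", "example.com"},
		{"example.com.", "example.com"}, // trailing dot stripped
		{"192.0.2.1", ""},               // literal IPv4 address: no SNI
		{"[2001:db8::1]", ""},           // bracketed IPv6 literal: no SNI
		{"[fe80::1%eth0]", ""},          // zone identifier removed before the IP check
	}
	for _, c := range cases {
		if got := hostnameInSNI(c.in); got != c.want {
			t.Errorf("hostnameInSNI(%q) = %q, want %q", c.in, got, c.want)
		}
	}
}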
func computeAndUpdatePSK(m *clientHelloMsg, binderKey []byte, transcript hash.Hash, finishedHash func([]byte, hash.Hash) []byte) error {
helloBytes, err := m.marshalWithoutBinders()
if err != nil {
return err
}
transcript.Write(helloBytes)
pskBinders := [][]byte{finishedHash(binderKey, transcript)}
return m.updateBinders(pskBinders)
}
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tls
import (
"bytes"
"context"
"crypto"
"crypto/hkdf"
"crypto/hmac"
"crypto/internal/fips140/mlkem"
"crypto/internal/fips140/tls13"
"crypto/rsa"
"crypto/subtle"
"errors"
"hash"
"slices"
"time"
)
type clientHandshakeStateTLS13 struct {
c *Conn
ctx context.Context
serverHello *serverHelloMsg
hello *clientHelloMsg
keyShareKeys *keySharePrivateKeys
session *SessionState
earlySecret *tls13.EarlySecret
binderKey []byte
certReq *certificateRequestMsgTLS13
usingPSK bool
sentDummyCCS bool
suite *cipherSuiteTLS13
transcript hash.Hash
masterSecret *tls13.MasterSecret
trafficSecret []byte // client_application_traffic_secret_0
echContext *echClientContext
}
// handshake requires hs.c, hs.hello, hs.serverHello, hs.keyShareKeys, and,
// optionally, hs.session, hs.earlySecret and hs.binderKey to be set.
func (hs *clientHandshakeStateTLS13) handshake() error {
c := hs.c
// The server must not select TLS 1.3 in a renegotiation. See RFC 8446,
// sections 4.1.2 and 4.1.3.
if c.handshakes > 0 {
c.sendAlert(alertProtocolVersion)
return errors.New("tls: server selected TLS 1.3 in a renegotiation")
}
// Consistency check on the presence of a keyShare and its parameters.
if hs.keyShareKeys == nil || hs.keyShareKeys.ecdhe == nil || len(hs.hello.keyShares) == 0 {
return c.sendAlert(alertInternalError)
}
if err := hs.checkServerHelloOrHRR(); err != nil {
return err
}
hs.transcript = hs.suite.hash.New()
if err := transcriptMsg(hs.hello, hs.transcript); err != nil {
return err
}
if hs.echContext != nil {
hs.echContext.innerTranscript = hs.suite.hash.New()
if err := transcriptMsg(hs.echContext.innerHello, hs.echContext.innerTranscript); err != nil {
return err
}
}
if bytes.Equal(hs.serverHello.random, helloRetryRequestRandom) {
if err := hs.sendDummyChangeCipherSpec(); err != nil {
return err
}
if err := hs.processHelloRetryRequest(); err != nil {
return err
}
}
if hs.echContext != nil {
confTranscript := cloneHash(hs.echContext.innerTranscript, hs.suite.hash)
confTranscript.Write(hs.serverHello.original[:30])
confTranscript.Write(make([]byte, 8))
confTranscript.Write(hs.serverHello.original[38:])
h := hs.suite.hash.New
prk, err := hkdf.Extract(h, hs.echContext.innerHello.random, nil)
if err != nil {
c.sendAlert(alertInternalError)
return err
}
acceptConfirmation := tls13.ExpandLabel(h, prk, "ech accept confirmation", confTranscript.Sum(nil), 8)
if subtle.ConstantTimeCompare(acceptConfirmation, hs.serverHello.random[len(hs.serverHello.random)-8:]) == 1 {
hs.hello = hs.echContext.innerHello
c.serverName = c.config.ServerName
hs.transcript = hs.echContext.innerTranscript
c.echAccepted = true
if hs.serverHello.encryptedClientHello != nil {
c.sendAlert(alertUnsupportedExtension)
return errors.New("tls: unexpected encrypted client hello extension in server hello despite ECH being accepted")
}
if hs.hello.serverName == "" && hs.serverHello.serverNameAck {
c.sendAlert(alertUnsupportedExtension)
return errors.New("tls: unexpected server_name extension in server hello")
}
} else {
hs.echContext.echRejected = true
}
}
if err := transcriptMsg(hs.serverHello, hs.transcript); err != nil {
return err
}
c.buffering = true
if err := hs.processServerHello(); err != nil {
return err
}
if err := hs.sendDummyChangeCipherSpec(); err != nil {
return err
}
if err := hs.establishHandshakeKeys(); err != nil {
return err
}
if err := hs.readServerParameters(); err != nil {
return err
}
if err := hs.readServerCertificate(); err != nil {
return err
}
if err := hs.readServerFinished(); err != nil {
return err
}
if err := hs.sendClientCertificate(); err != nil {
return err
}
if err := hs.sendClientFinished(); err != nil {
return err
}
if _, err := c.flush(); err != nil {
return err
}
if hs.echContext != nil && hs.echContext.echRejected {
c.sendAlert(alertECHRequired)
return &ECHRejectionError{hs.echContext.retryConfigs}
}
c.isHandshakeComplete.Store(true)
return nil
}
// checkServerHelloOrHRR does validity checks that apply to both ServerHello and
// HelloRetryRequest messages. It sets hs.suite.
func (hs *clientHandshakeStateTLS13) checkServerHelloOrHRR() error {
c := hs.c
if hs.serverHello.supportedVersion == 0 {
c.sendAlert(alertMissingExtension)
return errors.New("tls: server selected TLS 1.3 using the legacy version field")
}
if hs.serverHello.supportedVersion != VersionTLS13 {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: server selected an invalid version after a HelloRetryRequest")
}
if hs.serverHello.vers != VersionTLS12 {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: server sent an incorrect legacy version")
}
if hs.serverHello.ocspStapling ||
hs.serverHello.ticketSupported ||
hs.serverHello.extendedMasterSecret ||
hs.serverHello.secureRenegotiationSupported ||
len(hs.serverHello.secureRenegotiation) != 0 ||
len(hs.serverHello.alpnProtocol) != 0 ||
len(hs.serverHello.scts) != 0 {
c.sendAlert(alertUnsupportedExtension)
return errors.New("tls: server sent a ServerHello extension forbidden in TLS 1.3")
}
if !bytes.Equal(hs.hello.sessionId, hs.serverHello.sessionId) {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: server did not echo the legacy session ID")
}
if hs.serverHello.compressionMethod != compressionNone {
c.sendAlert(alertDecodeError)
return errors.New("tls: server sent non-zero legacy TLS compression method")
}
selectedSuite := mutualCipherSuiteTLS13(hs.hello.cipherSuites, hs.serverHello.cipherSuite)
if hs.suite != nil && selectedSuite != hs.suite {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: server changed cipher suite after a HelloRetryRequest")
}
if selectedSuite == nil {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: server chose an unconfigured cipher suite")
}
hs.suite = selectedSuite
c.cipherSuite = hs.suite.id
return nil
}
// sendDummyChangeCipherSpec sends a ChangeCipherSpec record for compatibility
// with middleboxes that didn't implement TLS correctly. See RFC 8446, Appendix D.4.
func (hs *clientHandshakeStateTLS13) sendDummyChangeCipherSpec() error {
if hs.c.quic != nil {
return nil
}
if hs.sentDummyCCS {
return nil
}
hs.sentDummyCCS = true
return hs.c.writeChangeCipherRecord()
}
// processHelloRetryRequest handles the HRR in hs.serverHello, modifies and
// resends hs.hello, and reads the new ServerHello into hs.serverHello.
func (hs *clientHandshakeStateTLS13) processHelloRetryRequest() error {
c := hs.c
// The first ClientHello gets double-hashed into the transcript upon a
// HelloRetryRequest. (The idea is that the server might offload transcript
// storage to the client in the cookie.) See RFC 8446, Section 4.4.1.
chHash := hs.transcript.Sum(nil)
hs.transcript.Reset()
hs.transcript.Write([]byte{typeMessageHash, 0, 0, uint8(len(chHash))})
hs.transcript.Write(chHash)
if err := transcriptMsg(hs.serverHello, hs.transcript); err != nil {
return err
}
var isInnerHello bool
hello := hs.hello
if hs.echContext != nil {
chHash = hs.echContext.innerTranscript.Sum(nil)
hs.echContext.innerTranscript.Reset()
hs.echContext.innerTranscript.Write([]byte{typeMessageHash, 0, 0, uint8(len(chHash))})
hs.echContext.innerTranscript.Write(chHash)
if hs.serverHello.encryptedClientHello != nil {
if len(hs.serverHello.encryptedClientHello) != 8 {
hs.c.sendAlert(alertDecodeError)
return errors.New("tls: malformed encrypted client hello extension")
}
confTranscript := cloneHash(hs.echContext.innerTranscript, hs.suite.hash)
hrrHello := make([]byte, len(hs.serverHello.original))
copy(hrrHello, hs.serverHello.original)
hrrHello = bytes.Replace(hrrHello, hs.serverHello.encryptedClientHello, make([]byte, 8), 1)
confTranscript.Write(hrrHello)
h := hs.suite.hash.New
prk, err := hkdf.Extract(h, hs.echContext.innerHello.random, nil)
if err != nil {
c.sendAlert(alertInternalError)
return err
}
acceptConfirmation := tls13.ExpandLabel(h, prk, "hrr ech accept confirmation", confTranscript.Sum(nil), 8)
if subtle.ConstantTimeCompare(acceptConfirmation, hs.serverHello.encryptedClientHello) == 1 {
hello = hs.echContext.innerHello
c.serverName = c.config.ServerName
isInnerHello = true
c.echAccepted = true
}
}
if err := transcriptMsg(hs.serverHello, hs.echContext.innerTranscript); err != nil {
return err
}
} else if hs.serverHello.encryptedClientHello != nil {
// Unsolicited ECH extension should be rejected
c.sendAlert(alertUnsupportedExtension)
return errors.New("tls: unexpected encrypted client hello extension in serverHello")
}
// The only HelloRetryRequest extensions we support are key_share and
// cookie, and clients must abort the handshake if the HRR would not result
// in any change in the ClientHello.
if hs.serverHello.selectedGroup == 0 && hs.serverHello.cookie == nil {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: server sent an unnecessary HelloRetryRequest message")
}
if hs.serverHello.cookie != nil {
hello.cookie = hs.serverHello.cookie
}
if hs.serverHello.serverShare.group != 0 {
c.sendAlert(alertDecodeError)
return errors.New("tls: received malformed key_share extension")
}
// If the server sent a key_share extension selecting a group, ensure it's
// a group we advertised but did not send a key share for, and send a key
// share for it this time.
if curveID := hs.serverHello.selectedGroup; curveID != 0 {
if !slices.Contains(hello.supportedCurves, curveID) {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: server selected unsupported group")
}
if slices.ContainsFunc(hs.hello.keyShares, func(ks keyShare) bool {
return ks.group == curveID
}) {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: server sent an unnecessary HelloRetryRequest key_share")
}
// Note: we don't support selecting X25519MLKEM768 in a HRR, because it
// is currently first in preference order, so if it's enabled we'll
// always send a key share for it.
//
// This will have to change once we support multiple hybrid KEMs.
if _, ok := curveForCurveID(curveID); !ok {
c.sendAlert(alertInternalError)
return errors.New("tls: CurvePreferences includes unsupported curve")
}
key, err := generateECDHEKey(c.config.rand(), curveID)
if err != nil {
c.sendAlert(alertInternalError)
return err
}
hs.keyShareKeys = &keySharePrivateKeys{curveID: curveID, ecdhe: key}
hello.keyShares = []keyShare{{group: curveID, data: key.PublicKey().Bytes()}}
}
if len(hello.pskIdentities) > 0 {
pskSuite := cipherSuiteTLS13ByID(hs.session.cipherSuite)
if pskSuite == nil {
return c.sendAlert(alertInternalError)
}
if pskSuite.hash == hs.suite.hash {
// Update binders and obfuscated_ticket_age.
ticketAge := c.config.time().Sub(time.Unix(int64(hs.session.createdAt), 0))
hello.pskIdentities[0].obfuscatedTicketAge = uint32(ticketAge/time.Millisecond) + hs.session.ageAdd
transcript := hs.suite.hash.New()
transcript.Write([]byte{typeMessageHash, 0, 0, uint8(len(chHash))})
transcript.Write(chHash)
if err := transcriptMsg(hs.serverHello, transcript); err != nil {
return err
}
if err := computeAndUpdatePSK(hello, hs.binderKey, transcript, hs.suite.finishedHash); err != nil {
return err
}
} else {
// Server selected a cipher suite incompatible with the PSK.
hello.pskIdentities = nil
hello.pskBinders = nil
}
}
if hello.earlyData {
hello.earlyData = false
c.quicRejectedEarlyData()
}
if isInnerHello {
// Any extensions which have changed in hello, but are mirrored in the
// outer hello and compressed, need to be copied to the outer hello, so
// they can be properly decompressed by the server. For now, the only
// extension which may have changed is keyShares.
hs.hello.keyShares = hello.keyShares
hs.echContext.innerHello = hello
if err := transcriptMsg(hs.echContext.innerHello, hs.echContext.innerTranscript); err != nil {
return err
}
if err := computeAndUpdateOuterECHExtension(hs.hello, hs.echContext.innerHello, hs.echContext, false); err != nil {
return err
}
} else {
hs.hello = hello
}
if _, err := hs.c.writeHandshakeRecord(hs.hello, hs.transcript); err != nil {
return err
}
// serverHelloMsg is not included in the transcript
msg, err := c.readHandshake(nil)
if err != nil {
return err
}
serverHello, ok := msg.(*serverHelloMsg)
if !ok {
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(serverHello, msg)
}
hs.serverHello = serverHello
if err := hs.checkServerHelloOrHRR(); err != nil {
return err
}
c.didHRR = true
return nil
}
func (hs *clientHandshakeStateTLS13) processServerHello() error {
c := hs.c
if bytes.Equal(hs.serverHello.random, helloRetryRequestRandom) {
c.sendAlert(alertUnexpectedMessage)
return errors.New("tls: server sent two HelloRetryRequest messages")
}
if len(hs.serverHello.cookie) != 0 {
c.sendAlert(alertUnsupportedExtension)
return errors.New("tls: server sent a cookie in a normal ServerHello")
}
if hs.serverHello.selectedGroup != 0 {
c.sendAlert(alertDecodeError)
return errors.New("tls: malformed key_share extension")
}
if hs.serverHello.serverShare.group == 0 {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: server did not send a key share")
}
if !slices.ContainsFunc(hs.hello.keyShares, func(ks keyShare) bool {
return ks.group == hs.serverHello.serverShare.group
}) {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: server selected unsupported group")
}
if !hs.serverHello.selectedIdentityPresent {
return nil
}
if int(hs.serverHello.selectedIdentity) >= len(hs.hello.pskIdentities) {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: server selected an invalid PSK")
}
if len(hs.hello.pskIdentities) != 1 || hs.session == nil {
return c.sendAlert(alertInternalError)
}
pskSuite := cipherSuiteTLS13ByID(hs.session.cipherSuite)
if pskSuite == nil {
return c.sendAlert(alertInternalError)
}
if pskSuite.hash != hs.suite.hash {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: server selected an invalid PSK and cipher suite pair")
}
hs.usingPSK = true
c.didResume = true
c.peerCertificates = hs.session.peerCertificates
c.verifiedChains = hs.session.verifiedChains
c.ocspResponse = hs.session.ocspResponse
c.scts = hs.session.scts
return nil
}
func (hs *clientHandshakeStateTLS13) establishHandshakeKeys() error {
c := hs.c
ecdhePeerData := hs.serverHello.serverShare.data
if hs.serverHello.serverShare.group == X25519MLKEM768 {
if len(ecdhePeerData) != mlkem.CiphertextSize768+x25519PublicKeySize {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: invalid server X25519MLKEM768 key share")
}
ecdhePeerData = hs.serverHello.serverShare.data[mlkem.CiphertextSize768:]
}
peerKey, err := hs.keyShareKeys.ecdhe.Curve().NewPublicKey(ecdhePeerData)
if err != nil {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: invalid server key share")
}
sharedKey, err := hs.keyShareKeys.ecdhe.ECDH(peerKey)
if err != nil {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: invalid server key share")
}
if hs.serverHello.serverShare.group == X25519MLKEM768 {
if hs.keyShareKeys.mlkem == nil {
return c.sendAlert(alertInternalError)
}
ciphertext := hs.serverHello.serverShare.data[:mlkem.CiphertextSize768]
mlkemShared, err := hs.keyShareKeys.mlkem.Decapsulate(ciphertext)
if err != nil {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: invalid X25519MLKEM768 server key share")
}
sharedKey = append(mlkemShared, sharedKey...)
}
c.curveID = hs.serverHello.serverShare.group
earlySecret := hs.earlySecret
if !hs.usingPSK {
earlySecret = tls13.NewEarlySecret(hs.suite.hash.New, nil)
}
handshakeSecret := earlySecret.HandshakeSecret(sharedKey)
clientSecret := handshakeSecret.ClientHandshakeTrafficSecret(hs.transcript)
c.out.setTrafficSecret(hs.suite, QUICEncryptionLevelHandshake, clientSecret)
serverSecret := handshakeSecret.ServerHandshakeTrafficSecret(hs.transcript)
c.in.setTrafficSecret(hs.suite, QUICEncryptionLevelHandshake, serverSecret)
if c.quic != nil {
if c.hand.Len() != 0 {
c.sendAlert(alertUnexpectedMessage)
}
c.quicSetWriteSecret(QUICEncryptionLevelHandshake, hs.suite.id, clientSecret)
c.quicSetReadSecret(QUICEncryptionLevelHandshake, hs.suite.id, serverSecret)
}
err = c.config.writeKeyLog(keyLogLabelClientHandshake, hs.hello.random, clientSecret)
if err != nil {
c.sendAlert(alertInternalError)
return err
}
err = c.config.writeKeyLog(keyLogLabelServerHandshake, hs.hello.random, serverSecret)
if err != nil {
c.sendAlert(alertInternalError)
return err
}
hs.masterSecret = handshakeSecret.MasterSecret()
return nil
}
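// Illustrative sketch (calling-application side, hypothetical file name): the
// writeKeyLog calls above, and the equivalent TLS 1.2 call in doFullHandshake,
// emit NSS key-log lines only when Config.KeyLogWriter is set, which lets
// tools such as Wireshark decrypt captured traffic. Debugging only: the file
// contains connection secrets. Assumes imports of "crypto/tls" and "os".
func exampleKeyLogConfig() (*tls.Config, error) {
	f, err := os.OpenFile("tls-keys.log", os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0o600)
	if err != nil {
		return nil, err
	}
	return &tls.Config{KeyLogWriter: f}, nil
}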
func (hs *clientHandshakeStateTLS13) readServerParameters() error {
c := hs.c
msg, err := c.readHandshake(hs.transcript)
if err != nil {
return err
}
encryptedExtensions, ok := msg.(*encryptedExtensionsMsg)
if !ok {
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(encryptedExtensions, msg)
}
if err := checkALPN(hs.hello.alpnProtocols, encryptedExtensions.alpnProtocol, c.quic != nil); err != nil {
// RFC 8446 specifies that no_application_protocol is sent by servers, but
// does not specify how clients handle the selection of an incompatible protocol.
// RFC 9001 Section 8.1 specifies that QUIC clients send no_application_protocol
// in this case. Always sending no_application_protocol seems reasonable.
c.sendAlert(alertNoApplicationProtocol)
return err
}
c.clientProtocol = encryptedExtensions.alpnProtocol
if c.quic != nil {
if encryptedExtensions.quicTransportParameters == nil {
// RFC 9001 Section 8.2.
c.sendAlert(alertMissingExtension)
return errors.New("tls: server did not send a quic_transport_parameters extension")
}
c.quicSetTransportParameters(encryptedExtensions.quicTransportParameters)
} else {
if encryptedExtensions.quicTransportParameters != nil {
c.sendAlert(alertUnsupportedExtension)
return errors.New("tls: server sent an unexpected quic_transport_parameters extension")
}
}
if !hs.hello.earlyData && encryptedExtensions.earlyData {
c.sendAlert(alertUnsupportedExtension)
return errors.New("tls: server sent an unexpected early_data extension")
}
if hs.hello.earlyData && !encryptedExtensions.earlyData {
c.quicRejectedEarlyData()
}
if encryptedExtensions.earlyData {
if hs.session.cipherSuite != c.cipherSuite {
c.sendAlert(alertHandshakeFailure)
return errors.New("tls: server accepted 0-RTT with the wrong cipher suite")
}
if hs.session.alpnProtocol != c.clientProtocol {
c.sendAlert(alertHandshakeFailure)
return errors.New("tls: server accepted 0-RTT with the wrong ALPN")
}
}
if hs.echContext != nil {
if hs.echContext.echRejected {
hs.echContext.retryConfigs = encryptedExtensions.echRetryConfigs
} else if encryptedExtensions.echRetryConfigs != nil {
c.sendAlert(alertUnsupportedExtension)
return errors.New("tls: server sent encrypted client hello retry configs after accepting encrypted client hello")
}
}
return nil
}
func (hs *clientHandshakeStateTLS13) readServerCertificate() error {
c := hs.c
// Either a PSK or a certificate is always used, but not both.
// See RFC 8446, Section 4.1.1.
if hs.usingPSK {
// Make sure the connection is still being verified whether or not this
// is a resumption. Resumptions currently don't reverify certificates so
// they don't call verifyServerCertificate. See Issue 31641.
if c.config.VerifyConnection != nil {
if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil {
c.sendAlert(alertBadCertificate)
return err
}
}
return nil
}
msg, err := c.readHandshake(hs.transcript)
if err != nil {
return err
}
certReq, ok := msg.(*certificateRequestMsgTLS13)
if ok {
hs.certReq = certReq
msg, err = c.readHandshake(hs.transcript)
if err != nil {
return err
}
}
certMsg, ok := msg.(*certificateMsgTLS13)
if !ok {
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(certMsg, msg)
}
if len(certMsg.certificate.Certificate) == 0 {
c.sendAlert(alertDecodeError)
return errors.New("tls: received empty certificates message")
}
c.scts = certMsg.certificate.SignedCertificateTimestamps
c.ocspResponse = certMsg.certificate.OCSPStaple
if err := c.verifyServerCertificate(certMsg.certificate.Certificate); err != nil {
return err
}
// certificateVerifyMsg is included in the transcript, but not until
// after we verify the handshake signature, since the state before
// this message was sent is used.
msg, err = c.readHandshake(nil)
if err != nil {
return err
}
certVerify, ok := msg.(*certificateVerifyMsg)
if !ok {
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(certVerify, msg)
}
// See RFC 8446, Section 4.4.3.
// We don't use hs.hello.supportedSignatureAlgorithms because it might
// include PKCS#1 v1.5 and SHA-1 if the ClientHello also supported TLS 1.2.
if !isSupportedSignatureAlgorithm(certVerify.signatureAlgorithm, supportedSignatureAlgorithms(c.vers)) ||
!isSupportedSignatureAlgorithm(certVerify.signatureAlgorithm, signatureSchemesForPublicKey(c.vers, c.peerCertificates[0].PublicKey)) {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: certificate used with invalid signature algorithm")
}
sigType, sigHash, err := typeAndHashFromSignatureScheme(certVerify.signatureAlgorithm)
if err != nil {
return c.sendAlert(alertInternalError)
}
if sigType == signaturePKCS1v15 || sigHash == crypto.SHA1 {
return c.sendAlert(alertInternalError)
}
signed := signedMessage(sigHash, serverSignatureContext, hs.transcript)
if err := verifyHandshakeSignature(sigType, c.peerCertificates[0].PublicKey,
sigHash, signed, certVerify.signature); err != nil {
c.sendAlert(alertDecryptError)
return errors.New("tls: invalid signature by the server certificate: " + err.Error())
}
c.peerSigAlg = certVerify.signatureAlgorithm
if err := transcriptMsg(certVerify, hs.transcript); err != nil {
return err
}
return nil
}
func (hs *clientHandshakeStateTLS13) readServerFinished() error {
c := hs.c
// finishedMsg is included in the transcript, but not until after we
// verify it, since the transcript state before this message was sent
// is used during verification.
msg, err := c.readHandshake(nil)
if err != nil {
return err
}
finished, ok := msg.(*finishedMsg)
if !ok {
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(finished, msg)
}
expectedMAC := hs.suite.finishedHash(c.in.trafficSecret, hs.transcript)
if !hmac.Equal(expectedMAC, finished.verifyData) {
c.sendAlert(alertDecryptError)
return errors.New("tls: invalid server finished hash")
}
if err := transcriptMsg(finished, hs.transcript); err != nil {
return err
}
// Derive secrets that take context through the server Finished.
hs.trafficSecret = hs.masterSecret.ClientApplicationTrafficSecret(hs.transcript)
serverSecret := hs.masterSecret.ServerApplicationTrafficSecret(hs.transcript)
c.in.setTrafficSecret(hs.suite, QUICEncryptionLevelApplication, serverSecret)
err = c.config.writeKeyLog(keyLogLabelClientTraffic, hs.hello.random, hs.trafficSecret)
if err != nil {
c.sendAlert(alertInternalError)
return err
}
err = c.config.writeKeyLog(keyLogLabelServerTraffic, hs.hello.random, serverSecret)
if err != nil {
c.sendAlert(alertInternalError)
return err
}
c.ekm = hs.suite.exportKeyingMaterial(hs.masterSecret, hs.transcript)
return nil
}
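// Illustrative sketch (calling-application side, hypothetical label): the
// exporter stored in c.ekm above is exposed as
// ConnectionState.ExportKeyingMaterial (RFC 8446, Section 7.5), letting an
// application protocol derive keys bound to this TLS connection. Assumes an
// import of "crypto/tls".
func exampleExportKeyingMaterial(conn *tls.Conn) ([]byte, error) {
	// 32 bytes tied to this connection's exporter master secret.
	return conn.ConnectionState().ExportKeyingMaterial("EXPERIMENTAL my-protocol", nil, 32)
}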
func (hs *clientHandshakeStateTLS13) sendClientCertificate() error {
c := hs.c
if hs.certReq == nil {
return nil
}
if hs.echContext != nil && hs.echContext.echRejected {
if _, err := hs.c.writeHandshakeRecord(&certificateMsgTLS13{}, hs.transcript); err != nil {
return err
}
return nil
}
cert, err := c.getClientCertificate(&CertificateRequestInfo{
AcceptableCAs: hs.certReq.certificateAuthorities,
SignatureSchemes: hs.certReq.supportedSignatureAlgorithms,
Version: c.vers,
ctx: hs.ctx,
})
if err != nil {
return err
}
certMsg := new(certificateMsgTLS13)
certMsg.certificate = *cert
certMsg.scts = hs.certReq.scts && len(cert.SignedCertificateTimestamps) > 0
certMsg.ocspStapling = hs.certReq.ocspStapling && len(cert.OCSPStaple) > 0
if _, err := hs.c.writeHandshakeRecord(certMsg, hs.transcript); err != nil {
return err
}
// If we sent an empty certificate message, skip the CertificateVerify.
if len(cert.Certificate) == 0 {
return nil
}
certVerifyMsg := new(certificateVerifyMsg)
certVerifyMsg.hasSignatureAlgorithm = true
certVerifyMsg.signatureAlgorithm, err = selectSignatureScheme(c.vers, cert, hs.certReq.supportedSignatureAlgorithms)
if err != nil {
// getClientCertificate returned a certificate incompatible with the
// CertificateRequestInfo supported signature algorithms.
c.sendAlert(alertHandshakeFailure)
return err
}
sigType, sigHash, err := typeAndHashFromSignatureScheme(certVerifyMsg.signatureAlgorithm)
if err != nil {
return c.sendAlert(alertInternalError)
}
signed := signedMessage(sigHash, clientSignatureContext, hs.transcript)
signOpts := crypto.SignerOpts(sigHash)
if sigType == signatureRSAPSS {
signOpts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash}
}
sig, err := cert.PrivateKey.(crypto.Signer).Sign(c.config.rand(), signed, signOpts)
if err != nil {
c.sendAlert(alertInternalError)
return errors.New("tls: failed to sign handshake: " + err.Error())
}
certVerifyMsg.signature = sig
if _, err := hs.c.writeHandshakeRecord(certVerifyMsg, hs.transcript); err != nil {
return err
}
return nil
}
func (hs *clientHandshakeStateTLS13) sendClientFinished() error {
c := hs.c
finished := &finishedMsg{
verifyData: hs.suite.finishedHash(c.out.trafficSecret, hs.transcript),
}
if _, err := hs.c.writeHandshakeRecord(finished, hs.transcript); err != nil {
return err
}
c.out.setTrafficSecret(hs.suite, QUICEncryptionLevelApplication, hs.trafficSecret)
if !c.config.SessionTicketsDisabled && c.config.ClientSessionCache != nil {
c.resumptionSecret = hs.masterSecret.ResumptionMasterSecret(hs.transcript)
}
if c.quic != nil {
if c.hand.Len() != 0 {
c.sendAlert(alertUnexpectedMessage)
}
c.quicSetWriteSecret(QUICEncryptionLevelApplication, hs.suite.id, hs.trafficSecret)
}
return nil
}
func (c *Conn) handleNewSessionTicket(msg *newSessionTicketMsgTLS13) error {
if !c.isClient {
c.sendAlert(alertUnexpectedMessage)
return errors.New("tls: received new session ticket from a client")
}
if c.config.SessionTicketsDisabled || c.config.ClientSessionCache == nil {
return nil
}
// See RFC 8446, Section 4.6.1.
if msg.lifetime == 0 {
return nil
}
lifetime := time.Duration(msg.lifetime) * time.Second
if lifetime > maxSessionTicketLifetime {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: received a session ticket with invalid lifetime")
}
if len(msg.label) == 0 {
c.sendAlert(alertDecodeError)
return errors.New("tls: received a session ticket with empty opaque ticket label")
}
// RFC 9001, Section 4.6.1
if c.quic != nil && msg.maxEarlyData != 0 && msg.maxEarlyData != 0xffffffff {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: invalid early data for QUIC connection")
}
cipherSuite := cipherSuiteTLS13ByID(c.cipherSuite)
if cipherSuite == nil || c.resumptionSecret == nil {
return c.sendAlert(alertInternalError)
}
psk := tls13.ExpandLabel(cipherSuite.hash.New, c.resumptionSecret, "resumption",
msg.nonce, cipherSuite.hash.Size())
session := c.sessionState()
session.secret = psk
session.useBy = uint64(c.config.time().Add(lifetime).Unix())
session.ageAdd = msg.ageAdd
session.EarlyData = c.quic != nil && msg.maxEarlyData == 0xffffffff // RFC 9001, Section 4.6.1
session.ticket = msg.label
if c.quic != nil && c.quic.enableSessionEvents {
c.quicStoreSession(session)
return nil
}
cs := &ClientSessionState{session: session}
if cacheKey := c.clientSessionCacheKey(); cacheKey != "" {
c.config.ClientSessionCache.Put(cacheKey, cs)
}
return nil
}
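// Illustrative sketch (calling-application side, hypothetical address):
// observing resumption from the tickets stored above. The Config must carry a
// ClientSessionCache; note that TLS 1.3 tickets arrive after the handshake,
// so they are typically cached only once the previous connection has read
// from the server. Assumes an import of "crypto/tls".
func exampleDidResume(cfg *tls.Config) (bool, error) {
	conn, err := tls.Dial("tcp", "example.com:443", cfg)
	if err != nil {
		return false, err
	}
	defer conn.Close()
	return conn.ConnectionState().DidResume, nil
}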
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tls
import (
"errors"
"fmt"
"slices"
"strings"
"golang.org/x/crypto/cryptobyte"
)
// The marshalingFunction type is an adapter to allow the use of ordinary
// functions as cryptobyte.MarshalingValue.
type marshalingFunction func(b *cryptobyte.Builder) error
func (f marshalingFunction) Marshal(b *cryptobyte.Builder) error {
return f(b)
}
// addBytesWithLength appends a sequence of bytes to the cryptobyte.Builder. If
// the length of the sequence is not the value specified, it produces an error.
func addBytesWithLength(b *cryptobyte.Builder, v []byte, n int) {
b.AddValue(marshalingFunction(func(b *cryptobyte.Builder) error {
if len(v) != n {
return fmt.Errorf("invalid value length: expected %d, got %d", n, len(v))
}
b.AddBytes(v)
return nil
}))
}
// addUint64 appends a big-endian, 64-bit value to the cryptobyte.Builder.
func addUint64(b *cryptobyte.Builder, v uint64) {
b.AddUint32(uint32(v >> 32))
b.AddUint32(uint32(v))
}
// readUint64 decodes a big-endian, 64-bit value into out and advances over it.
// It reports whether the read was successful.
func readUint64(s *cryptobyte.String, out *uint64) bool {
var hi, lo uint32
if !s.ReadUint32(&hi) || !s.ReadUint32(&lo) {
return false
}
*out = uint64(hi)<<32 | uint64(lo)
return true
}
// readUint8LengthPrefixed acts like s.ReadUint8LengthPrefixed, but targets a
// []byte instead of a cryptobyte.String.
func readUint8LengthPrefixed(s *cryptobyte.String, out *[]byte) bool {
return s.ReadUint8LengthPrefixed((*cryptobyte.String)(out))
}
// readUint16LengthPrefixed acts like s.ReadUint16LengthPrefixed, but targets a
// []byte instead of a cryptobyte.String.
func readUint16LengthPrefixed(s *cryptobyte.String, out *[]byte) bool {
return s.ReadUint16LengthPrefixed((*cryptobyte.String)(out))
}
// readUint24LengthPrefixed acts like s.ReadUint24LengthPrefixed, but targets a
// []byte instead of a cryptobyte.String.
func readUint24LengthPrefixed(s *cryptobyte.String, out *[]byte) bool {
return s.ReadUint24LengthPrefixed((*cryptobyte.String)(out))
}
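// Illustrative sketch of how the helpers above pair up: a value written with
// addUint64 reads back with readUint64, and a Builder length-prefixed field is
// consumed by the matching read helper. Compiles against the cryptobyte import
// already used in this file.
func exampleCryptobyteRoundTrip() bool {
	var b cryptobyte.Builder
	addUint64(&b, 1<<40)
	b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
		b.AddBytes([]byte("opaque"))
	})
	out, err := b.Bytes()
	if err != nil {
		return false
	}
	s := cryptobyte.String(out)
	var v uint64
	var field []byte
	return readUint64(&s, &v) && v == 1<<40 &&
		readUint16LengthPrefixed(&s, &field) && string(field) == "opaque" &&
		s.Empty()
}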
type clientHelloMsg struct {
original []byte
vers uint16
random []byte
sessionId []byte
cipherSuites []uint16
compressionMethods []uint8
serverName string
ocspStapling bool
supportedCurves []CurveID
supportedPoints []uint8
ticketSupported bool
sessionTicket []uint8
supportedSignatureAlgorithms []SignatureScheme
supportedSignatureAlgorithmsCert []SignatureScheme
secureRenegotiationSupported bool
secureRenegotiation []byte
extendedMasterSecret bool
alpnProtocols []string
scts bool
supportedVersions []uint16
cookie []byte
keyShares []keyShare
earlyData bool
pskModes []uint8
pskIdentities []pskIdentity
pskBinders [][]byte
quicTransportParameters []byte
encryptedClientHello []byte
// extensions are only populated on the server-side of a handshake
extensions []uint16
}
func (m *clientHelloMsg) marshalMsg(echInner bool) ([]byte, error) {
var exts cryptobyte.Builder
if len(m.serverName) > 0 {
// RFC 6066, Section 3
exts.AddUint16(extensionServerName)
exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
exts.AddUint8(0) // name_type = host_name
exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
exts.AddBytes([]byte(m.serverName))
})
})
})
}
if len(m.supportedPoints) > 0 && !echInner {
// RFC 4492, Section 5.1.2
exts.AddUint16(extensionSupportedPoints)
exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) {
exts.AddBytes(m.supportedPoints)
})
})
}
if m.ticketSupported && !echInner {
// RFC 5077, Section 3.2
exts.AddUint16(extensionSessionTicket)
exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
exts.AddBytes(m.sessionTicket)
})
}
if m.secureRenegotiationSupported && !echInner {
// RFC 5746, Section 3.2
exts.AddUint16(extensionRenegotiationInfo)
exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) {
exts.AddBytes(m.secureRenegotiation)
})
})
}
if m.extendedMasterSecret && !echInner {
// RFC 7627
exts.AddUint16(extensionExtendedMasterSecret)
exts.AddUint16(0) // empty extension_data
}
if m.scts {
// RFC 6962, Section 3.3.1
exts.AddUint16(extensionSCT)
exts.AddUint16(0) // empty extension_data
}
if m.earlyData {
// RFC 8446, Section 4.2.10
exts.AddUint16(extensionEarlyData)
exts.AddUint16(0) // empty extension_data
}
if m.quicTransportParameters != nil { // marshal zero-length parameters when present
// RFC 9001, Section 8.2
exts.AddUint16(extensionQUICTransportParameters)
exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
exts.AddBytes(m.quicTransportParameters)
})
}
if len(m.encryptedClientHello) > 0 {
exts.AddUint16(extensionEncryptedClientHello)
exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
exts.AddBytes(m.encryptedClientHello)
})
}
// Note that any extension that can be compressed during ECH must be
// contiguous. If any additional extensions are to be compressed they must
// be added to the following block, so that they can be properly
// decompressed on the other side.
var echOuterExts []uint16
if m.ocspStapling {
// RFC 4366, Section 3.6
if echInner {
echOuterExts = append(echOuterExts, extensionStatusRequest)
} else {
exts.AddUint16(extensionStatusRequest)
exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
exts.AddUint8(1) // status_type = ocsp
exts.AddUint16(0) // empty responder_id_list
exts.AddUint16(0) // empty request_extensions
})
}
}
if len(m.supportedCurves) > 0 {
// RFC 4492, sections 5.1.1 and RFC 8446, Section 4.2.7
if echInner {
echOuterExts = append(echOuterExts, extensionSupportedCurves)
} else {
exts.AddUint16(extensionSupportedCurves)
exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
for _, curve := range m.supportedCurves {
exts.AddUint16(uint16(curve))
}
})
})
}
}
if len(m.supportedSignatureAlgorithms) > 0 {
// RFC 5246, Section 7.4.1.4.1
if echInner {
echOuterExts = append(echOuterExts, extensionSignatureAlgorithms)
} else {
exts.AddUint16(extensionSignatureAlgorithms)
exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
for _, sigAlgo := range m.supportedSignatureAlgorithms {
exts.AddUint16(uint16(sigAlgo))
}
})
})
}
}
if len(m.supportedSignatureAlgorithmsCert) > 0 {
// RFC 8446, Section 4.2.3
if echInner {
echOuterExts = append(echOuterExts, extensionSignatureAlgorithmsCert)
} else {
exts.AddUint16(extensionSignatureAlgorithmsCert)
exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
for _, sigAlgo := range m.supportedSignatureAlgorithmsCert {
exts.AddUint16(uint16(sigAlgo))
}
})
})
}
}
if len(m.alpnProtocols) > 0 {
// RFC 7301, Section 3.1
if echInner {
echOuterExts = append(echOuterExts, extensionALPN)
} else {
exts.AddUint16(extensionALPN)
exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
for _, proto := range m.alpnProtocols {
exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) {
exts.AddBytes([]byte(proto))
})
}
})
})
}
}
if len(m.supportedVersions) > 0 {
// RFC 8446, Section 4.2.1
if echInner {
echOuterExts = append(echOuterExts, extensionSupportedVersions)
} else {
exts.AddUint16(extensionSupportedVersions)
exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) {
for _, vers := range m.supportedVersions {
exts.AddUint16(vers)
}
})
})
}
}
if len(m.cookie) > 0 {
// RFC 8446, Section 4.2.2
if echInner {
echOuterExts = append(echOuterExts, extensionCookie)
} else {
exts.AddUint16(extensionCookie)
exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
exts.AddBytes(m.cookie)
})
})
}
}
if len(m.keyShares) > 0 {
// RFC 8446, Section 4.2.8
if echInner {
echOuterExts = append(echOuterExts, extensionKeyShare)
} else {
exts.AddUint16(extensionKeyShare)
exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
for _, ks := range m.keyShares {
exts.AddUint16(uint16(ks.group))
exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
exts.AddBytes(ks.data)
})
}
})
})
}
}
if len(m.pskModes) > 0 {
// RFC 8446, Section 4.2.9
if echInner {
echOuterExts = append(echOuterExts, extensionPSKModes)
} else {
exts.AddUint16(extensionPSKModes)
exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) {
exts.AddBytes(m.pskModes)
})
})
}
}
if len(echOuterExts) > 0 && echInner {
exts.AddUint16(extensionECHOuterExtensions)
exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) {
for _, e := range echOuterExts {
exts.AddUint16(e)
}
})
})
}
if len(m.pskIdentities) > 0 { // pre_shared_key must be the last extension
// RFC 8446, Section 4.2.11
exts.AddUint16(extensionPreSharedKey)
exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
for _, psk := range m.pskIdentities {
exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
exts.AddBytes(psk.label)
})
exts.AddUint32(psk.obfuscatedTicketAge)
}
})
exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
for _, binder := range m.pskBinders {
exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) {
exts.AddBytes(binder)
})
}
})
})
}
extBytes, err := exts.Bytes()
if err != nil {
return nil, err
}
var b cryptobyte.Builder
b.AddUint8(typeClientHello)
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddUint16(m.vers)
addBytesWithLength(b, m.random, 32)
b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
if !echInner {
b.AddBytes(m.sessionId)
}
})
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
for _, suite := range m.cipherSuites {
b.AddUint16(suite)
}
})
b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes(m.compressionMethods)
})
if len(extBytes) > 0 {
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes(extBytes)
})
}
})
return b.Bytes()
}
func (m *clientHelloMsg) marshal() ([]byte, error) {
return m.marshalMsg(false)
}
// marshalWithoutBinders returns the ClientHello through the
// PreSharedKeyExtension.identities field, according to RFC 8446, Section
// 4.2.11.2. Note that m.pskBinders must be set to slices of the correct length.
func (m *clientHelloMsg) marshalWithoutBinders() ([]byte, error) {
bindersLen := 2 // uint16 length prefix
for _, binder := range m.pskBinders {
bindersLen += 1 // uint8 length prefix
bindersLen += len(binder)
}
var fullMessage []byte
if m.original != nil {
fullMessage = m.original
} else {
var err error
fullMessage, err = m.marshal()
if err != nil {
return nil, err
}
}
return fullMessage[:len(fullMessage)-bindersLen], nil
}
// updateBinders updates the m.pskBinders field. The supplied binders must have
// the same length as the current m.pskBinders.
func (m *clientHelloMsg) updateBinders(pskBinders [][]byte) error {
if len(pskBinders) != len(m.pskBinders) {
return errors.New("tls: internal error: pskBinders length mismatch")
}
for i := range m.pskBinders {
if len(pskBinders[i]) != len(m.pskBinders[i]) {
return errors.New("tls: internal error: pskBinders length mismatch")
}
}
m.pskBinders = pskBinders
return nil
}
func (m *clientHelloMsg) unmarshal(data []byte) bool {
*m = clientHelloMsg{original: data}
s := cryptobyte.String(data)
if !s.Skip(4) || // message type and uint24 length field
!s.ReadUint16(&m.vers) || !s.ReadBytes(&m.random, 32) ||
!readUint8LengthPrefixed(&s, &m.sessionId) {
return false
}
var cipherSuites cryptobyte.String
if !s.ReadUint16LengthPrefixed(&cipherSuites) {
return false
}
m.cipherSuites = []uint16{}
m.secureRenegotiationSupported = false
for !cipherSuites.Empty() {
var suite uint16
if !cipherSuites.ReadUint16(&suite) {
return false
}
if suite == scsvRenegotiation {
m.secureRenegotiationSupported = true
}
m.cipherSuites = append(m.cipherSuites, suite)
}
if !readUint8LengthPrefixed(&s, &m.compressionMethods) {
return false
}
if s.Empty() {
// ClientHello is optionally followed by extension data
return true
}
var extensions cryptobyte.String
if !s.ReadUint16LengthPrefixed(&extensions) || !s.Empty() {
return false
}
seenExts := make(map[uint16]bool)
for !extensions.Empty() {
var extension uint16
var extData cryptobyte.String
if !extensions.ReadUint16(&extension) ||
!extensions.ReadUint16LengthPrefixed(&extData) {
return false
}
if seenExts[extension] {
return false
}
seenExts[extension] = true
m.extensions = append(m.extensions, extension)
switch extension {
case extensionServerName:
// RFC 6066, Section 3
var nameList cryptobyte.String
if !extData.ReadUint16LengthPrefixed(&nameList) || nameList.Empty() {
return false
}
for !nameList.Empty() {
var nameType uint8
var serverName cryptobyte.String
if !nameList.ReadUint8(&nameType) ||
!nameList.ReadUint16LengthPrefixed(&serverName) ||
serverName.Empty() {
return false
}
if nameType != 0 {
continue
}
if len(m.serverName) != 0 {
// Multiple names of the same name_type are prohibited.
return false
}
m.serverName = string(serverName)
// An SNI value may not include a trailing dot.
if strings.HasSuffix(m.serverName, ".") {
return false
}
}
case extensionStatusRequest:
// RFC 4366, Section 3.6
var statusType uint8
var ignored cryptobyte.String
if !extData.ReadUint8(&statusType) ||
!extData.ReadUint16LengthPrefixed(&ignored) ||
!extData.ReadUint16LengthPrefixed(&ignored) {
return false
}
m.ocspStapling = statusType == statusTypeOCSP
case extensionSupportedCurves:
// RFC 4492, Section 5.1.1, and RFC 8446, Section 4.2.7
var curves cryptobyte.String
if !extData.ReadUint16LengthPrefixed(&curves) || curves.Empty() {
return false
}
for !curves.Empty() {
var curve uint16
if !curves.ReadUint16(&curve) {
return false
}
m.supportedCurves = append(m.supportedCurves, CurveID(curve))
}
case extensionSupportedPoints:
// RFC 4492, Section 5.1.2
if !readUint8LengthPrefixed(&extData, &m.supportedPoints) ||
len(m.supportedPoints) == 0 {
return false
}
case extensionSessionTicket:
// RFC 5077, Section 3.2
m.ticketSupported = true
extData.ReadBytes(&m.sessionTicket, len(extData))
case extensionSignatureAlgorithms:
// RFC 5246, Section 7.4.1.4.1
var sigAndAlgs cryptobyte.String
if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() {
return false
}
for !sigAndAlgs.Empty() {
var sigAndAlg uint16
if !sigAndAlgs.ReadUint16(&sigAndAlg) {
return false
}
m.supportedSignatureAlgorithms = append(
m.supportedSignatureAlgorithms, SignatureScheme(sigAndAlg))
}
case extensionSignatureAlgorithmsCert:
// RFC 8446, Section 4.2.3
var sigAndAlgs cryptobyte.String
if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() {
return false
}
for !sigAndAlgs.Empty() {
var sigAndAlg uint16
if !sigAndAlgs.ReadUint16(&sigAndAlg) {
return false
}
m.supportedSignatureAlgorithmsCert = append(
m.supportedSignatureAlgorithmsCert, SignatureScheme(sigAndAlg))
}
case extensionRenegotiationInfo:
// RFC 5746, Section 3.2
if !readUint8LengthPrefixed(&extData, &m.secureRenegotiation) {
return false
}
m.secureRenegotiationSupported = true
case extensionExtendedMasterSecret:
// RFC 7627
m.extendedMasterSecret = true
case extensionALPN:
// RFC 7301, Section 3.1
var protoList cryptobyte.String
if !extData.ReadUint16LengthPrefixed(&protoList) || protoList.Empty() {
return false
}
for !protoList.Empty() {
var proto cryptobyte.String
if !protoList.ReadUint8LengthPrefixed(&proto) || proto.Empty() {
return false
}
m.alpnProtocols = append(m.alpnProtocols, string(proto))
}
case extensionSCT:
// RFC 6962, Section 3.3.1
m.scts = true
case extensionSupportedVersions:
// RFC 8446, Section 4.2.1
var versList cryptobyte.String
if !extData.ReadUint8LengthPrefixed(&versList) || versList.Empty() {
return false
}
for !versList.Empty() {
var vers uint16
if !versList.ReadUint16(&vers) {
return false
}
m.supportedVersions = append(m.supportedVersions, vers)
}
case extensionCookie:
// RFC 8446, Section 4.2.2
if !readUint16LengthPrefixed(&extData, &m.cookie) ||
len(m.cookie) == 0 {
return false
}
case extensionKeyShare:
// RFC 8446, Section 4.2.8
var clientShares cryptobyte.String
if !extData.ReadUint16LengthPrefixed(&clientShares) {
return false
}
for !clientShares.Empty() {
var ks keyShare
if !clientShares.ReadUint16((*uint16)(&ks.group)) ||
!readUint16LengthPrefixed(&clientShares, &ks.data) ||
len(ks.data) == 0 {
return false
}
m.keyShares = append(m.keyShares, ks)
}
case extensionEarlyData:
// RFC 8446, Section 4.2.10
m.earlyData = true
case extensionPSKModes:
// RFC 8446, Section 4.2.9
if !readUint8LengthPrefixed(&extData, &m.pskModes) {
return false
}
case extensionQUICTransportParameters:
m.quicTransportParameters = make([]byte, len(extData))
if !extData.CopyBytes(m.quicTransportParameters) {
return false
}
case extensionPreSharedKey:
// RFC 8446, Section 4.2.11
if !extensions.Empty() {
return false // pre_shared_key must be the last extension
}
var identities cryptobyte.String
if !extData.ReadUint16LengthPrefixed(&identities) || identities.Empty() {
return false
}
for !identities.Empty() {
var psk pskIdentity
if !readUint16LengthPrefixed(&identities, &psk.label) ||
!identities.ReadUint32(&psk.obfuscatedTicketAge) ||
len(psk.label) == 0 {
return false
}
m.pskIdentities = append(m.pskIdentities, psk)
}
var binders cryptobyte.String
if !extData.ReadUint16LengthPrefixed(&binders) || binders.Empty() {
return false
}
for !binders.Empty() {
var binder []byte
if !readUint8LengthPrefixed(&binders, &binder) ||
len(binder) == 0 {
return false
}
m.pskBinders = append(m.pskBinders, binder)
}
case extensionEncryptedClientHello:
if !extData.ReadBytes(&m.encryptedClientHello, len(extData)) {
return false
}
default:
// Ignore unknown extensions.
continue
}
if !extData.Empty() {
return false
}
}
return true
}
func (m *clientHelloMsg) originalBytes() []byte {
return m.original
}
func (m *clientHelloMsg) clone() *clientHelloMsg {
return &clientHelloMsg{
original: slices.Clone(m.original),
vers: m.vers,
random: slices.Clone(m.random),
sessionId: slices.Clone(m.sessionId),
cipherSuites: slices.Clone(m.cipherSuites),
compressionMethods: slices.Clone(m.compressionMethods),
serverName: m.serverName,
ocspStapling: m.ocspStapling,
supportedCurves: slices.Clone(m.supportedCurves),
supportedPoints: slices.Clone(m.supportedPoints),
ticketSupported: m.ticketSupported,
sessionTicket: slices.Clone(m.sessionTicket),
supportedSignatureAlgorithms: slices.Clone(m.supportedSignatureAlgorithms),
supportedSignatureAlgorithmsCert: slices.Clone(m.supportedSignatureAlgorithmsCert),
secureRenegotiationSupported: m.secureRenegotiationSupported,
secureRenegotiation: slices.Clone(m.secureRenegotiation),
extendedMasterSecret: m.extendedMasterSecret,
alpnProtocols: slices.Clone(m.alpnProtocols),
scts: m.scts,
supportedVersions: slices.Clone(m.supportedVersions),
cookie: slices.Clone(m.cookie),
keyShares: slices.Clone(m.keyShares),
earlyData: m.earlyData,
pskModes: slices.Clone(m.pskModes),
pskIdentities: slices.Clone(m.pskIdentities),
pskBinders: slices.Clone(m.pskBinders),
quicTransportParameters: slices.Clone(m.quicTransportParameters),
encryptedClientHello: slices.Clone(m.encryptedClientHello),
}
}
type serverHelloMsg struct {
original []byte
vers uint16
random []byte
sessionId []byte
cipherSuite uint16
compressionMethod uint8
ocspStapling bool
ticketSupported bool
secureRenegotiationSupported bool
secureRenegotiation []byte
extendedMasterSecret bool
alpnProtocol string
scts [][]byte
supportedVersion uint16
serverShare keyShare
selectedIdentityPresent bool
selectedIdentity uint16
supportedPoints []uint8
encryptedClientHello []byte
serverNameAck bool
// HelloRetryRequest extensions
cookie []byte
selectedGroup CurveID
}
func (m *serverHelloMsg) marshal() ([]byte, error) {
var exts cryptobyte.Builder
if m.ocspStapling {
exts.AddUint16(extensionStatusRequest)
exts.AddUint16(0) // empty extension_data
}
if m.ticketSupported {
exts.AddUint16(extensionSessionTicket)
exts.AddUint16(0) // empty extension_data
}
if m.secureRenegotiationSupported {
exts.AddUint16(extensionRenegotiationInfo)
exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) {
exts.AddBytes(m.secureRenegotiation)
})
})
}
if m.extendedMasterSecret {
exts.AddUint16(extensionExtendedMasterSecret)
exts.AddUint16(0) // empty extension_data
}
if len(m.alpnProtocol) > 0 {
exts.AddUint16(extensionALPN)
exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) {
exts.AddBytes([]byte(m.alpnProtocol))
})
})
})
}
if len(m.scts) > 0 {
exts.AddUint16(extensionSCT)
exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
for _, sct := range m.scts {
exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
exts.AddBytes(sct)
})
}
})
})
}
if m.supportedVersion != 0 {
exts.AddUint16(extensionSupportedVersions)
exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
exts.AddUint16(m.supportedVersion)
})
}
if m.serverShare.group != 0 {
exts.AddUint16(extensionKeyShare)
exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
exts.AddUint16(uint16(m.serverShare.group))
exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
exts.AddBytes(m.serverShare.data)
})
})
}
if m.selectedIdentityPresent {
exts.AddUint16(extensionPreSharedKey)
exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
exts.AddUint16(m.selectedIdentity)
})
}
if len(m.cookie) > 0 {
exts.AddUint16(extensionCookie)
exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
exts.AddBytes(m.cookie)
})
})
}
if m.selectedGroup != 0 {
exts.AddUint16(extensionKeyShare)
exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
exts.AddUint16(uint16(m.selectedGroup))
})
}
if len(m.supportedPoints) > 0 {
exts.AddUint16(extensionSupportedPoints)
exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) {
exts.AddBytes(m.supportedPoints)
})
})
}
if len(m.encryptedClientHello) > 0 {
exts.AddUint16(extensionEncryptedClientHello)
exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
exts.AddBytes(m.encryptedClientHello)
})
}
if m.serverNameAck {
exts.AddUint16(extensionServerName)
exts.AddUint16(0)
}
extBytes, err := exts.Bytes()
if err != nil {
return nil, err
}
var b cryptobyte.Builder
b.AddUint8(typeServerHello)
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddUint16(m.vers)
addBytesWithLength(b, m.random, 32)
b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes(m.sessionId)
})
b.AddUint16(m.cipherSuite)
b.AddUint8(m.compressionMethod)
if len(extBytes) > 0 {
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes(extBytes)
})
}
})
return b.Bytes()
}
func (m *serverHelloMsg) unmarshal(data []byte) bool {
*m = serverHelloMsg{original: data}
s := cryptobyte.String(data)
if !s.Skip(4) || // message type and uint24 length field
!s.ReadUint16(&m.vers) || !s.ReadBytes(&m.random, 32) ||
!readUint8LengthPrefixed(&s, &m.sessionId) ||
!s.ReadUint16(&m.cipherSuite) ||
!s.ReadUint8(&m.compressionMethod) {
return false
}
if s.Empty() {
// ServerHello is optionally followed by extension data
return true
}
var extensions cryptobyte.String
if !s.ReadUint16LengthPrefixed(&extensions) || !s.Empty() {
return false
}
seenExts := make(map[uint16]bool)
for !extensions.Empty() {
var extension uint16
var extData cryptobyte.String
if !extensions.ReadUint16(&extension) ||
!extensions.ReadUint16LengthPrefixed(&extData) {
return false
}
if seenExts[extension] {
return false
}
seenExts[extension] = true
switch extension {
case extensionStatusRequest:
m.ocspStapling = true
case extensionSessionTicket:
m.ticketSupported = true
case extensionRenegotiationInfo:
if !readUint8LengthPrefixed(&extData, &m.secureRenegotiation) {
return false
}
m.secureRenegotiationSupported = true
case extensionExtendedMasterSecret:
m.extendedMasterSecret = true
case extensionALPN:
var protoList cryptobyte.String
if !extData.ReadUint16LengthPrefixed(&protoList) || protoList.Empty() {
return false
}
var proto cryptobyte.String
if !protoList.ReadUint8LengthPrefixed(&proto) ||
proto.Empty() || !protoList.Empty() {
return false
}
m.alpnProtocol = string(proto)
case extensionSCT:
var sctList cryptobyte.String
if !extData.ReadUint16LengthPrefixed(&sctList) || sctList.Empty() {
return false
}
for !sctList.Empty() {
var sct []byte
if !readUint16LengthPrefixed(&sctList, &sct) ||
len(sct) == 0 {
return false
}
m.scts = append(m.scts, sct)
}
case extensionSupportedVersions:
if !extData.ReadUint16(&m.supportedVersion) {
return false
}
case extensionCookie:
if !readUint16LengthPrefixed(&extData, &m.cookie) ||
len(m.cookie) == 0 {
return false
}
case extensionKeyShare:
// This extension has different formats in SH and HRR, accept either
// and let the handshake logic decide. See RFC 8446, Section 4.2.8.
if len(extData) == 2 {
if !extData.ReadUint16((*uint16)(&m.selectedGroup)) {
return false
}
} else {
if !extData.ReadUint16((*uint16)(&m.serverShare.group)) ||
!readUint16LengthPrefixed(&extData, &m.serverShare.data) {
return false
}
}
case extensionPreSharedKey:
m.selectedIdentityPresent = true
if !extData.ReadUint16(&m.selectedIdentity) {
return false
}
case extensionSupportedPoints:
// RFC 4492, Section 5.1.2
if !readUint8LengthPrefixed(&extData, &m.supportedPoints) ||
len(m.supportedPoints) == 0 {
return false
}
case extensionEncryptedClientHello: // encrypted_client_hello
m.encryptedClientHello = make([]byte, len(extData))
if !extData.CopyBytes(m.encryptedClientHello) {
return false
}
case extensionServerName:
if len(extData) != 0 {
return false
}
m.serverNameAck = true
default:
// Ignore unknown extensions.
continue
}
if !extData.Empty() {
return false
}
}
return true
}
func (m *serverHelloMsg) originalBytes() []byte {
return m.original
}
type encryptedExtensionsMsg struct {
alpnProtocol string
quicTransportParameters []byte
earlyData bool
echRetryConfigs []byte
serverNameAck bool
}
func (m *encryptedExtensionsMsg) marshal() ([]byte, error) {
var b cryptobyte.Builder
b.AddUint8(typeEncryptedExtensions)
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
if len(m.alpnProtocol) > 0 {
b.AddUint16(extensionALPN)
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes([]byte(m.alpnProtocol))
})
})
})
}
if m.quicTransportParameters != nil { // marshal zero-length parameters when present
// draft-ietf-quic-tls-32, Section 8.2
b.AddUint16(extensionQUICTransportParameters)
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes(m.quicTransportParameters)
})
}
if m.earlyData {
// RFC 8446, Section 4.2.10
b.AddUint16(extensionEarlyData)
b.AddUint16(0) // empty extension_data
}
if len(m.echRetryConfigs) > 0 {
b.AddUint16(extensionEncryptedClientHello)
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes(m.echRetryConfigs)
})
}
if m.serverNameAck {
b.AddUint16(extensionServerName)
b.AddUint16(0) // empty extension_data
}
})
})
return b.Bytes()
}
func (m *encryptedExtensionsMsg) unmarshal(data []byte) bool {
*m = encryptedExtensionsMsg{}
s := cryptobyte.String(data)
var extensions cryptobyte.String
if !s.Skip(4) || // message type and uint24 length field
!s.ReadUint16LengthPrefixed(&extensions) || !s.Empty() {
return false
}
seenExts := make(map[uint16]bool)
for !extensions.Empty() {
var extension uint16
var extData cryptobyte.String
if !extensions.ReadUint16(&extension) ||
!extensions.ReadUint16LengthPrefixed(&extData) {
return false
}
if seenExts[extension] {
return false
}
seenExts[extension] = true
switch extension {
case extensionALPN:
var protoList cryptobyte.String
if !extData.ReadUint16LengthPrefixed(&protoList) || protoList.Empty() {
return false
}
var proto cryptobyte.String
if !protoList.ReadUint8LengthPrefixed(&proto) ||
proto.Empty() || !protoList.Empty() {
return false
}
m.alpnProtocol = string(proto)
case extensionQUICTransportParameters:
m.quicTransportParameters = make([]byte, len(extData))
if !extData.CopyBytes(m.quicTransportParameters) {
return false
}
case extensionEarlyData:
// RFC 8446, Section 4.2.10
m.earlyData = true
case extensionEncryptedClientHello:
m.echRetryConfigs = make([]byte, len(extData))
if !extData.CopyBytes(m.echRetryConfigs) {
return false
}
case extensionServerName:
if len(extData) != 0 {
return false
}
m.serverNameAck = true
default:
// Ignore unknown extensions.
continue
}
if !extData.Empty() {
return false
}
}
return true
}
type endOfEarlyDataMsg struct{}
func (m *endOfEarlyDataMsg) marshal() ([]byte, error) {
x := make([]byte, 4)
x[0] = typeEndOfEarlyData
return x, nil
}
func (m *endOfEarlyDataMsg) unmarshal(data []byte) bool {
return len(data) == 4
}
type keyUpdateMsg struct {
updateRequested bool
}
func (m *keyUpdateMsg) marshal() ([]byte, error) {
var b cryptobyte.Builder
b.AddUint8(typeKeyUpdate)
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
if m.updateRequested {
b.AddUint8(1)
} else {
b.AddUint8(0)
}
})
return b.Bytes()
}
func (m *keyUpdateMsg) unmarshal(data []byte) bool {
s := cryptobyte.String(data)
var updateRequested uint8
if !s.Skip(4) || // message type and uint24 length field
!s.ReadUint8(&updateRequested) || !s.Empty() {
return false
}
switch updateRequested {
case 0:
m.updateRequested = false
case 1:
m.updateRequested = true
default:
return false
}
return true
}
type newSessionTicketMsgTLS13 struct {
lifetime uint32
ageAdd uint32
nonce []byte
label []byte
maxEarlyData uint32
}
func (m *newSessionTicketMsgTLS13) marshal() ([]byte, error) {
var b cryptobyte.Builder
b.AddUint8(typeNewSessionTicket)
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddUint32(m.lifetime)
b.AddUint32(m.ageAdd)
b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes(m.nonce)
})
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes(m.label)
})
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
if m.maxEarlyData > 0 {
b.AddUint16(extensionEarlyData)
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddUint32(m.maxEarlyData)
})
}
})
})
return b.Bytes()
}
func (m *newSessionTicketMsgTLS13) unmarshal(data []byte) bool {
*m = newSessionTicketMsgTLS13{}
s := cryptobyte.String(data)
var extensions cryptobyte.String
if !s.Skip(4) || // message type and uint24 length field
!s.ReadUint32(&m.lifetime) ||
!s.ReadUint32(&m.ageAdd) ||
!readUint8LengthPrefixed(&s, &m.nonce) ||
!readUint16LengthPrefixed(&s, &m.label) ||
!s.ReadUint16LengthPrefixed(&extensions) ||
!s.Empty() {
return false
}
for !extensions.Empty() {
var extension uint16
var extData cryptobyte.String
if !extensions.ReadUint16(&extension) ||
!extensions.ReadUint16LengthPrefixed(&extData) {
return false
}
switch extension {
case extensionEarlyData:
if !extData.ReadUint32(&m.maxEarlyData) {
return false
}
default:
// Ignore unknown extensions.
continue
}
if !extData.Empty() {
return false
}
}
return true
}
type certificateRequestMsgTLS13 struct {
ocspStapling bool
scts bool
supportedSignatureAlgorithms []SignatureScheme
supportedSignatureAlgorithmsCert []SignatureScheme
certificateAuthorities [][]byte
}
func (m *certificateRequestMsgTLS13) marshal() ([]byte, error) {
var b cryptobyte.Builder
b.AddUint8(typeCertificateRequest)
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
// certificate_request_context (SHALL be zero length unless used for
// post-handshake authentication)
b.AddUint8(0)
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
if m.ocspStapling {
b.AddUint16(extensionStatusRequest)
b.AddUint16(0) // empty extension_data
}
if m.scts {
// RFC 8446, Section 4.4.2.1 makes no mention of
// signed_certificate_timestamp in CertificateRequest, but
// "Extensions in the Certificate message from the client MUST
// correspond to extensions in the CertificateRequest message
// from the server." and it appears in the table in Section 4.2.
b.AddUint16(extensionSCT)
b.AddUint16(0) // empty extension_data
}
if len(m.supportedSignatureAlgorithms) > 0 {
b.AddUint16(extensionSignatureAlgorithms)
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
for _, sigAlgo := range m.supportedSignatureAlgorithms {
b.AddUint16(uint16(sigAlgo))
}
})
})
}
if len(m.supportedSignatureAlgorithmsCert) > 0 {
b.AddUint16(extensionSignatureAlgorithmsCert)
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
for _, sigAlgo := range m.supportedSignatureAlgorithmsCert {
b.AddUint16(uint16(sigAlgo))
}
})
})
}
if len(m.certificateAuthorities) > 0 {
b.AddUint16(extensionCertificateAuthorities)
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
for _, ca := range m.certificateAuthorities {
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes(ca)
})
}
})
})
}
})
})
return b.Bytes()
}
func (m *certificateRequestMsgTLS13) unmarshal(data []byte) bool {
*m = certificateRequestMsgTLS13{}
s := cryptobyte.String(data)
var context, extensions cryptobyte.String
if !s.Skip(4) || // message type and uint24 length field
!s.ReadUint8LengthPrefixed(&context) || !context.Empty() ||
!s.ReadUint16LengthPrefixed(&extensions) ||
!s.Empty() {
return false
}
for !extensions.Empty() {
var extension uint16
var extData cryptobyte.String
if !extensions.ReadUint16(&extension) ||
!extensions.ReadUint16LengthPrefixed(&extData) {
return false
}
switch extension {
case extensionStatusRequest:
m.ocspStapling = true
case extensionSCT:
m.scts = true
case extensionSignatureAlgorithms:
var sigAndAlgs cryptobyte.String
if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() {
return false
}
for !sigAndAlgs.Empty() {
var sigAndAlg uint16
if !sigAndAlgs.ReadUint16(&sigAndAlg) {
return false
}
m.supportedSignatureAlgorithms = append(
m.supportedSignatureAlgorithms, SignatureScheme(sigAndAlg))
}
case extensionSignatureAlgorithmsCert:
var sigAndAlgs cryptobyte.String
if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() {
return false
}
for !sigAndAlgs.Empty() {
var sigAndAlg uint16
if !sigAndAlgs.ReadUint16(&sigAndAlg) {
return false
}
m.supportedSignatureAlgorithmsCert = append(
m.supportedSignatureAlgorithmsCert, SignatureScheme(sigAndAlg))
}
case extensionCertificateAuthorities:
var auths cryptobyte.String
if !extData.ReadUint16LengthPrefixed(&auths) || auths.Empty() {
return false
}
for !auths.Empty() {
var ca []byte
if !readUint16LengthPrefixed(&auths, &ca) || len(ca) == 0 {
return false
}
m.certificateAuthorities = append(m.certificateAuthorities, ca)
}
default:
// Ignore unknown extensions.
continue
}
if !extData.Empty() {
return false
}
}
return true
}
type certificateMsg struct {
certificates [][]byte
}
func (m *certificateMsg) marshal() ([]byte, error) {
var i int
for _, slice := range m.certificates {
i += len(slice)
}
length := 3 + 3*len(m.certificates) + i
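// Wire layout (RFC 5246, Section 7.4.2): the 4-byte handshake header (type plus
// uint24 message length) is written below; length itself covers the 3-byte
// certificate_list length plus, for each certificate, a 3-byte length followed
// by its DER bytes, hence 3 + 3*len(m.certificates) + i.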
x := make([]byte, 4+length)
x[0] = typeCertificate
x[1] = uint8(length >> 16)
x[2] = uint8(length >> 8)
x[3] = uint8(length)
certificateOctets := length - 3
x[4] = uint8(certificateOctets >> 16)
x[5] = uint8(certificateOctets >> 8)
x[6] = uint8(certificateOctets)
y := x[7:]
for _, slice := range m.certificates {
y[0] = uint8(len(slice) >> 16)
y[1] = uint8(len(slice) >> 8)
y[2] = uint8(len(slice))
copy(y[3:], slice)
y = y[3+len(slice):]
}
return x, nil
}
func (m *certificateMsg) unmarshal(data []byte) bool {
if len(data) < 7 {
return false
}
certsLen := uint32(data[4])<<16 | uint32(data[5])<<8 | uint32(data[6])
if uint32(len(data)) != certsLen+7 {
return false
}
numCerts := 0
d := data[7:]
for certsLen > 0 {
if len(d) < 4 {
return false
}
certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2])
if uint32(len(d)) < 3+certLen {
return false
}
d = d[3+certLen:]
certsLen -= 3 + certLen
numCerts++
}
m.certificates = make([][]byte, numCerts)
d = data[7:]
for i := 0; i < numCerts; i++ {
certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2])
m.certificates[i] = d[3 : 3+certLen]
d = d[3+certLen:]
}
return true
}
type certificateMsgTLS13 struct {
certificate Certificate
ocspStapling bool
scts bool
}
func (m *certificateMsgTLS13) marshal() ([]byte, error) {
var b cryptobyte.Builder
b.AddUint8(typeCertificate)
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddUint8(0) // certificate_request_context
certificate := m.certificate
if !m.ocspStapling {
certificate.OCSPStaple = nil
}
if !m.scts {
certificate.SignedCertificateTimestamps = nil
}
marshalCertificate(b, certificate)
})
return b.Bytes()
}
func marshalCertificate(b *cryptobyte.Builder, certificate Certificate) {
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
for i, cert := range certificate.Certificate {
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes(cert)
})
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
if i > 0 {
// This library only supports OCSP and SCT for leaf certificates.
return
}
if certificate.OCSPStaple != nil {
b.AddUint16(extensionStatusRequest)
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddUint8(statusTypeOCSP)
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes(certificate.OCSPStaple)
})
})
}
if certificate.SignedCertificateTimestamps != nil {
b.AddUint16(extensionSCT)
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
for _, sct := range certificate.SignedCertificateTimestamps {
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes(sct)
})
}
})
})
}
})
}
})
}
func (m *certificateMsgTLS13) unmarshal(data []byte) bool {
*m = certificateMsgTLS13{}
s := cryptobyte.String(data)
var context cryptobyte.String
if !s.Skip(4) || // message type and uint24 length field
!s.ReadUint8LengthPrefixed(&context) || !context.Empty() ||
!unmarshalCertificate(&s, &m.certificate) ||
!s.Empty() {
return false
}
m.scts = m.certificate.SignedCertificateTimestamps != nil
m.ocspStapling = m.certificate.OCSPStaple != nil
return true
}
func unmarshalCertificate(s *cryptobyte.String, certificate *Certificate) bool {
var certList cryptobyte.String
if !s.ReadUint24LengthPrefixed(&certList) {
return false
}
for !certList.Empty() {
var cert []byte
var extensions cryptobyte.String
if !readUint24LengthPrefixed(&certList, &cert) ||
!certList.ReadUint16LengthPrefixed(&extensions) {
return false
}
certificate.Certificate = append(certificate.Certificate, cert)
for !extensions.Empty() {
var extension uint16
var extData cryptobyte.String
if !extensions.ReadUint16(&extension) ||
!extensions.ReadUint16LengthPrefixed(&extData) {
return false
}
if len(certificate.Certificate) > 1 {
// This library only supports OCSP and SCT for leaf certificates.
continue
}
switch extension {
case extensionStatusRequest:
var statusType uint8
if !extData.ReadUint8(&statusType) || statusType != statusTypeOCSP ||
!readUint24LengthPrefixed(&extData, &certificate.OCSPStaple) ||
len(certificate.OCSPStaple) == 0 {
return false
}
case extensionSCT:
var sctList cryptobyte.String
if !extData.ReadUint16LengthPrefixed(&sctList) || sctList.Empty() {
return false
}
for !sctList.Empty() {
var sct []byte
if !readUint16LengthPrefixed(&sctList, &sct) ||
len(sct) == 0 {
return false
}
certificate.SignedCertificateTimestamps = append(
certificate.SignedCertificateTimestamps, sct)
}
default:
// Ignore unknown extensions.
continue
}
if !extData.Empty() {
return false
}
}
}
return true
}
type serverKeyExchangeMsg struct {
key []byte
}
func (m *serverKeyExchangeMsg) marshal() ([]byte, error) {
length := len(m.key)
x := make([]byte, length+4)
x[0] = typeServerKeyExchange
x[1] = uint8(length >> 16)
x[2] = uint8(length >> 8)
x[3] = uint8(length)
copy(x[4:], m.key)
return x, nil
}
func (m *serverKeyExchangeMsg) unmarshal(data []byte) bool {
if len(data) < 4 {
return false
}
m.key = data[4:]
return true
}
type certificateStatusMsg struct {
response []byte
}
func (m *certificateStatusMsg) marshal() ([]byte, error) {
var b cryptobyte.Builder
b.AddUint8(typeCertificateStatus)
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddUint8(statusTypeOCSP)
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes(m.response)
})
})
return b.Bytes()
}
func (m *certificateStatusMsg) unmarshal(data []byte) bool {
s := cryptobyte.String(data)
var statusType uint8
if !s.Skip(4) || // message type and uint24 length field
!s.ReadUint8(&statusType) || statusType != statusTypeOCSP ||
!readUint24LengthPrefixed(&s, &m.response) ||
len(m.response) == 0 || !s.Empty() {
return false
}
return true
}
type serverHelloDoneMsg struct{}
func (m *serverHelloDoneMsg) marshal() ([]byte, error) {
x := make([]byte, 4)
x[0] = typeServerHelloDone
return x, nil
}
func (m *serverHelloDoneMsg) unmarshal(data []byte) bool {
return len(data) == 4
}
type clientKeyExchangeMsg struct {
ciphertext []byte
}
func (m *clientKeyExchangeMsg) marshal() ([]byte, error) {
length := len(m.ciphertext)
x := make([]byte, length+4)
x[0] = typeClientKeyExchange
x[1] = uint8(length >> 16)
x[2] = uint8(length >> 8)
x[3] = uint8(length)
copy(x[4:], m.ciphertext)
return x, nil
}
func (m *clientKeyExchangeMsg) unmarshal(data []byte) bool {
if len(data) < 4 {
return false
}
l := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
if l != len(data)-4 {
return false
}
m.ciphertext = data[4:]
return true
}
type finishedMsg struct {
verifyData []byte
}
func (m *finishedMsg) marshal() ([]byte, error) {
var b cryptobyte.Builder
b.AddUint8(typeFinished)
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes(m.verifyData)
})
return b.Bytes()
}
func (m *finishedMsg) unmarshal(data []byte) bool {
s := cryptobyte.String(data)
return s.Skip(1) &&
readUint24LengthPrefixed(&s, &m.verifyData) &&
s.Empty()
}
type certificateRequestMsg struct {
// hasSignatureAlgorithm indicates whether this message includes a list of
// supported signature algorithms. This change was introduced with TLS 1.2.
hasSignatureAlgorithm bool
certificateTypes []byte
supportedSignatureAlgorithms []SignatureScheme
certificateAuthorities [][]byte
}
func (m *certificateRequestMsg) marshal() ([]byte, error) {
// See RFC 4346, Section 7.4.4.
length := 1 + len(m.certificateTypes) + 2
casLength := 0
for _, ca := range m.certificateAuthorities {
casLength += 2 + len(ca)
}
length += casLength
if m.hasSignatureAlgorithm {
length += 2 + 2*len(m.supportedSignatureAlgorithms)
}
x := make([]byte, 4+length)
x[0] = typeCertificateRequest
x[1] = uint8(length >> 16)
x[2] = uint8(length >> 8)
x[3] = uint8(length)
x[4] = uint8(len(m.certificateTypes))
copy(x[5:], m.certificateTypes)
y := x[5+len(m.certificateTypes):]
if m.hasSignatureAlgorithm {
n := len(m.supportedSignatureAlgorithms) * 2
y[0] = uint8(n >> 8)
y[1] = uint8(n)
y = y[2:]
for _, sigAlgo := range m.supportedSignatureAlgorithms {
y[0] = uint8(sigAlgo >> 8)
y[1] = uint8(sigAlgo)
y = y[2:]
}
}
y[0] = uint8(casLength >> 8)
y[1] = uint8(casLength)
y = y[2:]
for _, ca := range m.certificateAuthorities {
y[0] = uint8(len(ca) >> 8)
y[1] = uint8(len(ca))
y = y[2:]
copy(y, ca)
y = y[len(ca):]
}
return x, nil
}
func (m *certificateRequestMsg) unmarshal(data []byte) bool {
if len(data) < 5 {
return false
}
length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3])
if uint32(len(data))-4 != length {
return false
}
numCertTypes := int(data[4])
data = data[5:]
if numCertTypes == 0 || len(data) <= numCertTypes {
return false
}
m.certificateTypes = make([]byte, numCertTypes)
if copy(m.certificateTypes, data) != numCertTypes {
return false
}
data = data[numCertTypes:]
if m.hasSignatureAlgorithm {
if len(data) < 2 {
return false
}
sigAndHashLen := uint16(data[0])<<8 | uint16(data[1])
data = data[2:]
if sigAndHashLen&1 != 0 || sigAndHashLen == 0 {
return false
}
if len(data) < int(sigAndHashLen) {
return false
}
numSigAlgos := sigAndHashLen / 2
m.supportedSignatureAlgorithms = make([]SignatureScheme, numSigAlgos)
for i := range m.supportedSignatureAlgorithms {
m.supportedSignatureAlgorithms[i] = SignatureScheme(data[0])<<8 | SignatureScheme(data[1])
data = data[2:]
}
}
if len(data) < 2 {
return false
}
casLength := uint16(data[0])<<8 | uint16(data[1])
data = data[2:]
if len(data) < int(casLength) {
return false
}
cas := make([]byte, casLength)
copy(cas, data)
data = data[casLength:]
m.certificateAuthorities = nil
for len(cas) > 0 {
if len(cas) < 2 {
return false
}
caLen := uint16(cas[0])<<8 | uint16(cas[1])
cas = cas[2:]
if len(cas) < int(caLen) {
return false
}
m.certificateAuthorities = append(m.certificateAuthorities, cas[:caLen])
cas = cas[caLen:]
}
return len(data) == 0
}
type certificateVerifyMsg struct {
hasSignatureAlgorithm bool // format change introduced in TLS 1.2
signatureAlgorithm SignatureScheme
signature []byte
}
func (m *certificateVerifyMsg) marshal() ([]byte, error) {
var b cryptobyte.Builder
b.AddUint8(typeCertificateVerify)
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
if m.hasSignatureAlgorithm {
b.AddUint16(uint16(m.signatureAlgorithm))
}
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes(m.signature)
})
})
return b.Bytes()
}
func (m *certificateVerifyMsg) unmarshal(data []byte) bool {
s := cryptobyte.String(data)
if !s.Skip(4) { // message type and uint24 length field
return false
}
if m.hasSignatureAlgorithm {
if !s.ReadUint16((*uint16)(&m.signatureAlgorithm)) {
return false
}
}
return readUint16LengthPrefixed(&s, &m.signature) && s.Empty()
}
type newSessionTicketMsg struct {
ticket []byte
}
func (m *newSessionTicketMsg) marshal() ([]byte, error) {
// See RFC 5077, Section 3.3.
ticketLen := len(m.ticket)
length := 2 + 4 + ticketLen
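// Per RFC 5077, Section 3.3 the body is a 4-byte ticket_lifetime_hint (left as
// zero here), a 2-byte ticket length, and the opaque ticket itself, hence
// length = 2 + 4 + ticketLen.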
x := make([]byte, 4+length)
x[0] = typeNewSessionTicket
x[1] = uint8(length >> 16)
x[2] = uint8(length >> 8)
x[3] = uint8(length)
x[8] = uint8(ticketLen >> 8)
x[9] = uint8(ticketLen)
copy(x[10:], m.ticket)
return x, nil
}
func (m *newSessionTicketMsg) unmarshal(data []byte) bool {
if len(data) < 10 {
return false
}
length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3])
if uint32(len(data))-4 != length {
return false
}
ticketLen := int(data[8])<<8 + int(data[9])
if len(data)-10 != ticketLen {
return false
}
m.ticket = data[10:]
return true
}
type helloRequestMsg struct {
}
func (*helloRequestMsg) marshal() ([]byte, error) {
return []byte{typeHelloRequest, 0, 0, 0}, nil
}
func (*helloRequestMsg) unmarshal(data []byte) bool {
return len(data) == 4
}
type transcriptHash interface {
Write([]byte) (int, error)
}
// transcriptMsg is a helper used to hash messages which are not hashed when
// they are read from, or written to, the wire. This is typically the case for
// messages which are either not sent, or need to be hashed out of order from
// when they are read/written.
//
// Most messages are marshaled using their marshal method, since their wire
// representation is idempotent. For clientHelloMsg and
// serverHelloMsg, we store the original wire representation of the message and
// use that for hashing, since unmarshal/marshal are not idempotent due to
// extension ordering and other malleable fields, which may cause differences
// between what was received and what we marshal.
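//
// For example, the TLS 1.2 server handshake below calls
// transcriptMsg(hs.clientHello, &hs.finishedHash) to record the ClientHello in
// the transcript before writing its own messages to the same hash.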
func transcriptMsg(msg handshakeMessage, h transcriptHash) error {
if msgWithOrig, ok := msg.(handshakeMessageWithOriginalBytes); ok {
if orig := msgWithOrig.originalBytes(); orig != nil {
h.Write(msgWithOrig.originalBytes())
return nil
}
}
data, err := msg.marshal()
if err != nil {
return err
}
h.Write(data)
return nil
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tls
import (
"context"
"crypto"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/rsa"
"crypto/subtle"
"crypto/tls/internal/fips140tls"
"crypto/x509"
"errors"
"fmt"
"hash"
"io"
"time"
)
// serverHandshakeState contains details of a server handshake in progress.
// It's discarded once the handshake has completed.
type serverHandshakeState struct {
c *Conn
ctx context.Context
clientHello *clientHelloMsg
hello *serverHelloMsg
suite *cipherSuite
ecdheOk bool
ecSignOk bool
rsaDecryptOk bool
rsaSignOk bool
sessionState *SessionState
finishedHash finishedHash
masterSecret []byte
cert *Certificate
}
// serverHandshake performs a TLS handshake as a server.
func (c *Conn) serverHandshake(ctx context.Context) error {
clientHello, ech, err := c.readClientHello(ctx)
if err != nil {
return err
}
if c.vers == VersionTLS13 {
hs := serverHandshakeStateTLS13{
c: c,
ctx: ctx,
clientHello: clientHello,
echContext: ech,
}
return hs.handshake()
}
hs := serverHandshakeState{
c: c,
ctx: ctx,
clientHello: clientHello,
}
return hs.handshake()
}
func (hs *serverHandshakeState) handshake() error {
c := hs.c
if err := hs.processClientHello(); err != nil {
return err
}
// For an overview of TLS handshaking, see RFC 5246, Section 7.3.
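// Roughly: the resumed path below sends ServerHello, NewSessionTicket,
// ChangeCipherSpec and Finished, then waits for the client's Finished, while
// the full path sends ServerHello through ServerHelloDone, reads the client's
// key exchange (and optional certificate and CertificateVerify) and Finished,
// and only then sends its own NewSessionTicket and Finished.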
c.buffering = true
if err := hs.checkForResumption(); err != nil {
return err
}
if hs.sessionState != nil {
// The client has included a session ticket and so we do an abbreviated handshake.
if err := hs.doResumeHandshake(); err != nil {
return err
}
if err := hs.establishKeys(); err != nil {
return err
}
if err := hs.sendSessionTicket(); err != nil {
return err
}
if err := hs.sendFinished(c.serverFinished[:]); err != nil {
return err
}
if _, err := c.flush(); err != nil {
return err
}
c.clientFinishedIsFirst = false
if err := hs.readFinished(nil); err != nil {
return err
}
} else {
// The client didn't include a session ticket, or it wasn't
// valid, so we do a full handshake.
if err := hs.pickCipherSuite(); err != nil {
return err
}
if err := hs.doFullHandshake(); err != nil {
return err
}
if err := hs.establishKeys(); err != nil {
return err
}
if err := hs.readFinished(c.clientFinished[:]); err != nil {
return err
}
c.clientFinishedIsFirst = true
c.buffering = true
if err := hs.sendSessionTicket(); err != nil {
return err
}
if err := hs.sendFinished(nil); err != nil {
return err
}
if _, err := c.flush(); err != nil {
return err
}
}
c.ekm = ekmFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.clientHello.random, hs.hello.random)
c.isHandshakeComplete.Store(true)
return nil
}
// readClientHello reads a ClientHello message and selects the protocol version.
func (c *Conn) readClientHello(ctx context.Context) (*clientHelloMsg, *echServerContext, error) {
// clientHelloMsg is included in the transcript, but we haven't initialized
// it yet. The respective handshake functions will record it themselves.
msg, err := c.readHandshake(nil)
if err != nil {
return nil, nil, err
}
clientHello, ok := msg.(*clientHelloMsg)
if !ok {
c.sendAlert(alertUnexpectedMessage)
return nil, nil, unexpectedMessageError(clientHello, msg)
}
// ECH processing has to be done before we do any other negotiation based on
// the contents of the client hello, since we may swap it out completely.
var ech *echServerContext
if len(clientHello.encryptedClientHello) != 0 {
echKeys := c.config.EncryptedClientHelloKeys
if c.config.GetEncryptedClientHelloKeys != nil {
echKeys, err = c.config.GetEncryptedClientHelloKeys(clientHelloInfo(ctx, c, clientHello))
if err != nil {
c.sendAlert(alertInternalError)
return nil, nil, err
}
}
clientHello, ech, err = c.processECHClientHello(clientHello, echKeys)
if err != nil {
return nil, nil, err
}
}
var configForClient *Config
originalConfig := c.config
if c.config.GetConfigForClient != nil {
chi := clientHelloInfo(ctx, c, clientHello)
if configForClient, err = c.config.GetConfigForClient(chi); err != nil {
c.sendAlert(alertInternalError)
return nil, nil, err
} else if configForClient != nil {
c.config = configForClient
}
}
c.ticketKeys = originalConfig.ticketKeys(configForClient)
clientVersions := clientHello.supportedVersions
if clientHello.vers >= VersionTLS13 && len(clientVersions) == 0 {
// RFC 8446, Section 4.2.1 indicates that when the supported_versions extension
// is not sent, compatible servers MUST negotiate TLS 1.2 or earlier if
// supported, even if the client legacy version is TLS 1.3 or later.
//
// Since we reject an empty supported_versions extension when unmarshaling the
// ClientHello, an empty supportedVersions here means the extension was absent.
clientVersions = supportedVersionsFromMax(VersionTLS12)
} else if len(clientVersions) == 0 {
clientVersions = supportedVersionsFromMax(clientHello.vers)
}
c.vers, ok = c.config.mutualVersion(roleServer, clientVersions)
if !ok {
c.sendAlert(alertProtocolVersion)
return nil, nil, fmt.Errorf("tls: client offered only unsupported versions: %x", clientVersions)
}
c.haveVers = true
c.in.version = c.vers
c.out.version = c.vers
// This check reflects some odd behavior implied by the specification. Client-facing servers
// are supposed to reject hellos with outer ECH and inner ECH that offers 1.2, but
// backend servers are allowed to accept hellos with inner ECH that offer 1.2, since
// they cannot expect client-facing servers to behave properly. Since we act as both
// a client-facing and backend server, we only enforce 1.3 being negotiated if we
// saw a hello with outer ECH first. The spec probably should've made this an error,
// but it didn't, and this matches the boringssl behavior.
if c.vers != VersionTLS13 && (ech != nil && !ech.inner) {
c.sendAlert(alertIllegalParameter)
return nil, nil, errors.New("tls: Encrypted Client Hello cannot be used pre-TLS 1.3")
}
if c.config.MinVersion == 0 && c.vers < VersionTLS12 {
tls10server.Value() // ensure godebug is initialized
tls10server.IncNonDefault()
}
return clientHello, ech, nil
}
func (hs *serverHandshakeState) processClientHello() error {
c := hs.c
hs.hello = new(serverHelloMsg)
hs.hello.vers = c.vers
foundCompression := false
// We only support null compression, so check that the client offered it.
for _, compression := range hs.clientHello.compressionMethods {
if compression == compressionNone {
foundCompression = true
break
}
}
if !foundCompression {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: client does not support uncompressed connections")
}
hs.hello.random = make([]byte, 32)
serverRandom := hs.hello.random
// Downgrade protection canaries. See RFC 8446, Section 4.1.3.
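// If this server supports a newer version than the one being negotiated, the
// last 8 bytes of ServerHello.random are overwritten with a fixed sentinel so
// that a client which also supports the newer version can detect an
// attacker-forced downgrade.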
maxVers := c.config.maxSupportedVersion(roleServer)
if maxVers >= VersionTLS12 && c.vers < maxVers || testingOnlyForceDowngradeCanary {
if c.vers == VersionTLS12 {
copy(serverRandom[24:], downgradeCanaryTLS12)
} else {
copy(serverRandom[24:], downgradeCanaryTLS11)
}
serverRandom = serverRandom[:24]
}
_, err := io.ReadFull(c.config.rand(), serverRandom)
if err != nil {
c.sendAlert(alertInternalError)
return err
}
if len(hs.clientHello.secureRenegotiation) != 0 {
c.sendAlert(alertHandshakeFailure)
return errors.New("tls: initial handshake had non-empty renegotiation extension")
}
hs.hello.extendedMasterSecret = hs.clientHello.extendedMasterSecret
hs.hello.secureRenegotiationSupported = hs.clientHello.secureRenegotiationSupported
hs.hello.compressionMethod = compressionNone
if len(hs.clientHello.serverName) > 0 {
c.serverName = hs.clientHello.serverName
}
selectedProto, err := negotiateALPN(c.config.NextProtos, hs.clientHello.alpnProtocols, false)
if err != nil {
c.sendAlert(alertNoApplicationProtocol)
return err
}
hs.hello.alpnProtocol = selectedProto
c.clientProtocol = selectedProto
hs.cert, err = c.config.getCertificate(clientHelloInfo(hs.ctx, c, hs.clientHello))
if err != nil {
if err == errNoCertificates {
c.sendAlert(alertUnrecognizedName)
} else {
c.sendAlert(alertInternalError)
}
return err
}
if hs.clientHello.scts {
hs.hello.scts = hs.cert.SignedCertificateTimestamps
}
hs.ecdheOk, err = supportsECDHE(c.config, c.vers, hs.clientHello.supportedCurves, hs.clientHello.supportedPoints)
if err != nil {
c.sendAlert(alertMissingExtension)
return err
}
if hs.ecdheOk && len(hs.clientHello.supportedPoints) > 0 {
// Although omitting the ec_point_formats extension is permitted, some
// old OpenSSL versions will refuse to handshake if it is not present.
//
// Per RFC 4492, section 5.1.2, implementations MUST support the
// uncompressed point format. See golang.org/issue/31943.
hs.hello.supportedPoints = []uint8{pointFormatUncompressed}
}
if priv, ok := hs.cert.PrivateKey.(crypto.Signer); ok {
switch priv.Public().(type) {
case *ecdsa.PublicKey:
hs.ecSignOk = true
case ed25519.PublicKey:
hs.ecSignOk = true
case *rsa.PublicKey:
hs.rsaSignOk = true
default:
c.sendAlert(alertInternalError)
return fmt.Errorf("tls: unsupported signing key type (%T)", priv.Public())
}
}
if priv, ok := hs.cert.PrivateKey.(crypto.Decrypter); ok {
switch priv.Public().(type) {
case *rsa.PublicKey:
hs.rsaDecryptOk = true
default:
c.sendAlert(alertInternalError)
return fmt.Errorf("tls: unsupported decryption key type (%T)", priv.Public())
}
}
return nil
}
// negotiateALPN picks a shared ALPN protocol that both sides support in server
// preference order. If ALPN is not configured or the peer doesn't support it,
// it returns "" and no error.
func negotiateALPN(serverProtos, clientProtos []string, quic bool) (string, error) {
if len(serverProtos) == 0 || len(clientProtos) == 0 {
if quic && len(serverProtos) != 0 {
// RFC 9001, Section 8.1
return "", fmt.Errorf("tls: client did not request an application protocol")
}
return "", nil
}
var http11fallback bool
for _, s := range serverProtos {
for _, c := range clientProtos {
if s == c {
return s, nil
}
if s == "h2" && c == "http/1.1" {
http11fallback = true
}
}
}
// As a special case, let http/1.1 clients connect to h2 servers as if they
// didn't support ALPN. We used not to enforce protocol overlap, so over
// time a number of HTTP servers were configured with only "h2", but
// expected to accept connections from "http/1.1" clients. See Issue 46310.
if http11fallback {
return "", nil
}
return "", fmt.Errorf("tls: client requested unsupported application protocols (%s)", clientProtos)
}
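// negotiateALPNExampleUsage is an illustrative sketch, not part of the TLS
// implementation, showing the outcomes of negotiateALPN documented above: a
// shared protocol is chosen in server preference order, an "h2"-only server
// still tolerates an "http/1.1"-only client, and disjoint lists are an error.
func negotiateALPNExampleUsage() {
proto, err := negotiateALPN([]string{"h2", "http/1.1"}, []string{"http/1.1", "h2"}, false)
_, _ = proto, err // proto == "h2", err == nil (server preference wins)
proto, err = negotiateALPN([]string{"h2"}, []string{"http/1.1"}, false)
_, _ = proto, err // proto == "", err == nil (http/1.1 fallback, Issue 46310)
_, err = negotiateALPN([]string{"h2"}, []string{"spdy/3"}, false)
_ = err // non-nil: no overlapping application protocol
}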
// supportsECDHE returns whether ECDHE key exchanges can be used with this
// pre-TLS 1.3 client.
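//
// For example, a client that offers only compressed point formats is rejected,
// while a client that omits the Supported Point Formats extension entirely is
// treated as supporting uncompressed points (see the comment on RFC 8422
// below).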
func supportsECDHE(c *Config, version uint16, supportedCurves []CurveID, supportedPoints []uint8) (bool, error) {
supportsCurve := false
for _, curve := range supportedCurves {
if c.supportsCurve(version, curve) {
supportsCurve = true
break
}
}
supportsPointFormat := false
offeredNonCompressedFormat := false
for _, pointFormat := range supportedPoints {
if pointFormat == pointFormatUncompressed {
supportsPointFormat = true
} else {
offeredNonCompressedFormat = true
}
}
// Per RFC 8422, Section 5.1.2, if the Supported Point Formats extension is
// missing, uncompressed points are supported. If supportedPoints is empty,
// the extension must be missing, as an empty extension body is rejected by
// the parser. See https://go.dev/issue/49126.
if len(supportedPoints) == 0 {
supportsPointFormat = true
} else if offeredNonCompressedFormat && !supportsPointFormat {
return false, errors.New("tls: client offered only incompatible point formats")
}
return supportsCurve && supportsPointFormat, nil
}
func (hs *serverHandshakeState) pickCipherSuite() error {
c := hs.c
preferenceList := c.config.cipherSuites(isAESGCMPreferred(hs.clientHello.cipherSuites))
hs.suite = selectCipherSuite(preferenceList, hs.clientHello.cipherSuites, hs.cipherSuiteOk)
if hs.suite == nil {
c.sendAlert(alertHandshakeFailure)
return fmt.Errorf("tls: no cipher suite supported by both client and server; client offered: %x",
hs.clientHello.cipherSuites)
}
c.cipherSuite = hs.suite.id
if c.config.CipherSuites == nil && !fips140tls.Required() && rsaKexCiphers[hs.suite.id] {
tlsrsakex.Value() // ensure godebug is initialized
tlsrsakex.IncNonDefault()
}
if c.config.CipherSuites == nil && !fips140tls.Required() && tdesCiphers[hs.suite.id] {
tls3des.Value() // ensure godebug is initialized
tls3des.IncNonDefault()
}
for _, id := range hs.clientHello.cipherSuites {
if id == TLS_FALLBACK_SCSV {
// The client is doing a fallback connection. See RFC 7507.
if hs.clientHello.vers < c.config.maxSupportedVersion(roleServer) {
c.sendAlert(alertInappropriateFallback)
return errors.New("tls: client using inappropriate protocol fallback")
}
break
}
}
return nil
}
func (hs *serverHandshakeState) cipherSuiteOk(c *cipherSuite) bool {
if c.flags&suiteECDHE != 0 {
if !hs.ecdheOk {
return false
}
if c.flags&suiteECSign != 0 {
if !hs.ecSignOk {
return false
}
} else if !hs.rsaSignOk {
return false
}
} else if !hs.rsaDecryptOk {
return false
}
if hs.c.vers < VersionTLS12 && c.flags&suiteTLS12 != 0 {
return false
}
return true
}
// checkForResumption decides whether to resume the session offered in the
// client's ticket. If resumption is possible, it records the decoded session
// in hs.sessionState; otherwise it leaves it nil and a full handshake follows.
func (hs *serverHandshakeState) checkForResumption() error {
c := hs.c
if c.config.SessionTicketsDisabled {
return nil
}
var sessionState *SessionState
if c.config.UnwrapSession != nil {
ss, err := c.config.UnwrapSession(hs.clientHello.sessionTicket, c.connectionStateLocked())
if err != nil {
return err
}
if ss == nil {
return nil
}
sessionState = ss
} else {
plaintext := c.config.decryptTicket(hs.clientHello.sessionTicket, c.ticketKeys)
if plaintext == nil {
return nil
}
ss, err := ParseSessionState(plaintext)
if err != nil {
return nil
}
sessionState = ss
}
// TLS 1.2 tickets don't natively have a lifetime, but we want to avoid
// re-wrapping the same master secret in different tickets over and over for
// too long, weakening forward secrecy.
createdAt := time.Unix(int64(sessionState.createdAt), 0)
if c.config.time().Sub(createdAt) > maxSessionTicketLifetime {
return nil
}
// Never resume a session for a different TLS version.
if c.vers != sessionState.version {
return nil
}
cipherSuiteOk := false
// Check that the client is still offering the ciphersuite in the session.
for _, id := range hs.clientHello.cipherSuites {
if id == sessionState.cipherSuite {
cipherSuiteOk = true
break
}
}
if !cipherSuiteOk {
return nil
}
// Check that we also support the ciphersuite from the session.
suite := selectCipherSuite([]uint16{sessionState.cipherSuite},
c.config.supportedCipherSuites(), hs.cipherSuiteOk)
if suite == nil {
return nil
}
sessionHasClientCerts := len(sessionState.peerCertificates) != 0
needClientCerts := requiresClientCert(c.config.ClientAuth)
if needClientCerts && !sessionHasClientCerts {
return nil
}
if sessionHasClientCerts && c.config.ClientAuth == NoClientCert {
return nil
}
if sessionHasClientCerts && c.config.time().After(sessionState.peerCertificates[0].NotAfter) {
return nil
}
if sessionHasClientCerts && c.config.ClientAuth >= VerifyClientCertIfGiven &&
len(sessionState.verifiedChains) == 0 {
return nil
}
// RFC 7627, Section 5.3
if !sessionState.extMasterSecret && hs.clientHello.extendedMasterSecret {
return nil
}
if sessionState.extMasterSecret && !hs.clientHello.extendedMasterSecret {
// Aborting is somewhat harsh, but it's a MUST and it would indicate a
// weird downgrade in client capabilities.
return errors.New("tls: session supported extended_master_secret but client does not")
}
if !sessionState.extMasterSecret && fips140tls.Required() {
// FIPS 140-3 requires the use of Extended Master Secret.
return nil
}
c.peerCertificates = sessionState.peerCertificates
c.ocspResponse = sessionState.ocspResponse
c.scts = sessionState.scts
c.verifiedChains = sessionState.verifiedChains
c.extMasterSecret = sessionState.extMasterSecret
hs.sessionState = sessionState
hs.suite = suite
c.curveID = sessionState.curveID
c.didResume = true
return nil
}
func (hs *serverHandshakeState) doResumeHandshake() error {
c := hs.c
hs.hello.cipherSuite = hs.suite.id
c.cipherSuite = hs.suite.id
// We echo the client's session ID in the ServerHello to let it know
// that we're doing a resumption.
hs.hello.sessionId = hs.clientHello.sessionId
// We always send a new session ticket, even if it wraps the same master
// secret and it's potentially encrypted with the same key, to help the
// client avoid cross-connection tracking from a network observer.
hs.hello.ticketSupported = true
hs.finishedHash = newFinishedHash(c.vers, hs.suite)
hs.finishedHash.discardHandshakeBuffer()
if err := transcriptMsg(hs.clientHello, &hs.finishedHash); err != nil {
return err
}
if _, err := hs.c.writeHandshakeRecord(hs.hello, &hs.finishedHash); err != nil {
return err
}
if c.config.VerifyConnection != nil {
if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil {
c.sendAlert(alertBadCertificate)
return err
}
}
hs.masterSecret = hs.sessionState.secret
return nil
}
func (hs *serverHandshakeState) doFullHandshake() error {
c := hs.c
if hs.clientHello.ocspStapling && len(hs.cert.OCSPStaple) > 0 {
hs.hello.ocspStapling = true
}
if hs.clientHello.serverName != "" {
hs.hello.serverNameAck = true
}
hs.hello.ticketSupported = hs.clientHello.ticketSupported && !c.config.SessionTicketsDisabled
hs.hello.cipherSuite = hs.suite.id
hs.finishedHash = newFinishedHash(hs.c.vers, hs.suite)
if c.config.ClientAuth == NoClientCert {
// No need to keep a full record of the handshake if client
// certificates won't be used.
hs.finishedHash.discardHandshakeBuffer()
}
if err := transcriptMsg(hs.clientHello, &hs.finishedHash); err != nil {
return err
}
if _, err := hs.c.writeHandshakeRecord(hs.hello, &hs.finishedHash); err != nil {
return err
}
certMsg := new(certificateMsg)
certMsg.certificates = hs.cert.Certificate
if _, err := hs.c.writeHandshakeRecord(certMsg, &hs.finishedHash); err != nil {
return err
}
if hs.hello.ocspStapling {
certStatus := new(certificateStatusMsg)
certStatus.response = hs.cert.OCSPStaple
if _, err := hs.c.writeHandshakeRecord(certStatus, &hs.finishedHash); err != nil {
return err
}
}
keyAgreement := hs.suite.ka(c.vers)
skx, err := keyAgreement.generateServerKeyExchange(c.config, hs.cert, hs.clientHello, hs.hello)
if err != nil {
c.sendAlert(alertHandshakeFailure)
return err
}
if skx != nil {
if keyAgreement, ok := keyAgreement.(*ecdheKeyAgreement); ok {
c.curveID = keyAgreement.curveID
c.peerSigAlg = keyAgreement.signatureAlgorithm
}
if _, err := hs.c.writeHandshakeRecord(skx, &hs.finishedHash); err != nil {
return err
}
}
var certReq *certificateRequestMsg
if c.config.ClientAuth >= RequestClientCert {
// Request a client certificate
certReq = new(certificateRequestMsg)
certReq.certificateTypes = []byte{
byte(certTypeRSASign),
byte(certTypeECDSASign),
}
if c.vers >= VersionTLS12 {
certReq.hasSignatureAlgorithm = true
certReq.supportedSignatureAlgorithms = supportedSignatureAlgorithms(c.vers)
}
// An empty list of certificateAuthorities signals to
// the client that it may send any certificate in response
// to our request. When we know the CAs we trust, then
// we can send them down, so that the client can choose
// an appropriate certificate to give to us.
if c.config.ClientCAs != nil {
certReq.certificateAuthorities = c.config.ClientCAs.Subjects()
}
if _, err := hs.c.writeHandshakeRecord(certReq, &hs.finishedHash); err != nil {
return err
}
}
helloDone := new(serverHelloDoneMsg)
if _, err := hs.c.writeHandshakeRecord(helloDone, &hs.finishedHash); err != nil {
return err
}
if _, err := c.flush(); err != nil {
return err
}
var pub crypto.PublicKey // public key for client auth, if any
msg, err := c.readHandshake(&hs.finishedHash)
if err != nil {
return err
}
// If we requested a client certificate, then the client must send a
// certificate message, even if it's empty.
if c.config.ClientAuth >= RequestClientCert {
certMsg, ok := msg.(*certificateMsg)
if !ok {
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(certMsg, msg)
}
if err := c.processCertsFromClient(Certificate{
Certificate: certMsg.certificates,
}); err != nil {
return err
}
if len(certMsg.certificates) != 0 {
pub = c.peerCertificates[0].PublicKey
}
msg, err = c.readHandshake(&hs.finishedHash)
if err != nil {
return err
}
}
if c.config.VerifyConnection != nil {
if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil {
c.sendAlert(alertBadCertificate)
return err
}
}
// Get client key exchange
ckx, ok := msg.(*clientKeyExchangeMsg)
if !ok {
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(ckx, msg)
}
preMasterSecret, err := keyAgreement.processClientKeyExchange(c.config, hs.cert, ckx, c.vers)
if err != nil {
c.sendAlert(alertIllegalParameter)
return err
}
if hs.hello.extendedMasterSecret {
c.extMasterSecret = true
hs.masterSecret = extMasterFromPreMasterSecret(c.vers, hs.suite, preMasterSecret,
hs.finishedHash.Sum())
} else {
if fips140tls.Required() {
c.sendAlert(alertHandshakeFailure)
return errors.New("tls: FIPS 140-3 requires the use of Extended Master Secret")
}
hs.masterSecret = masterFromPreMasterSecret(c.vers, hs.suite, preMasterSecret,
hs.clientHello.random, hs.hello.random)
}
if err := c.config.writeKeyLog(keyLogLabelTLS12, hs.clientHello.random, hs.masterSecret); err != nil {
c.sendAlert(alertInternalError)
return err
}
// If we received a client cert in response to our certificate request message,
// the client will send us a certificateVerifyMsg immediately after the
// clientKeyExchangeMsg. This message is a digest of all preceding
// handshake-layer messages that is signed using the private key corresponding
// to the client's certificate. This allows us to verify that the client is in
// possession of the private key of the certificate.
if len(c.peerCertificates) > 0 {
// certificateVerifyMsg is included in the transcript, but not until
// after we verify the handshake signature, since the state before
// this message was sent is used.
msg, err = c.readHandshake(nil)
if err != nil {
return err
}
certVerify, ok := msg.(*certificateVerifyMsg)
if !ok {
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(certVerify, msg)
}
var sigType uint8
var sigHash crypto.Hash
if c.vers >= VersionTLS12 {
if !isSupportedSignatureAlgorithm(certVerify.signatureAlgorithm, certReq.supportedSignatureAlgorithms) {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: client certificate used with invalid signature algorithm")
}
sigType, sigHash, err = typeAndHashFromSignatureScheme(certVerify.signatureAlgorithm)
if err != nil {
return c.sendAlert(alertInternalError)
}
if sigHash == crypto.SHA1 {
tlssha1.Value() // ensure godebug is initialized
tlssha1.IncNonDefault()
}
} else {
sigType, sigHash, err = legacyTypeAndHashFromPublicKey(pub)
if err != nil {
c.sendAlert(alertIllegalParameter)
return err
}
}
signed := hs.finishedHash.hashForClientCertificate(sigType, sigHash)
if err := verifyHandshakeSignature(sigType, pub, sigHash, signed, certVerify.signature); err != nil {
c.sendAlert(alertDecryptError)
return errors.New("tls: invalid signature by the client certificate: " + err.Error())
}
c.peerSigAlg = certVerify.signatureAlgorithm
if err := transcriptMsg(certVerify, &hs.finishedHash); err != nil {
return err
}
}
hs.finishedHash.discardHandshakeBuffer()
return nil
}
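// establishKeys derives the record-layer keys, IVs, and MACs from the master
// secret and stages them on the connection's incoming and outgoing halves.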
func (hs *serverHandshakeState) establishKeys() error {
c := hs.c
clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV :=
keysFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.clientHello.random, hs.hello.random, hs.suite.macLen, hs.suite.keyLen, hs.suite.ivLen)
var clientCipher, serverCipher any
var clientHash, serverHash hash.Hash
if hs.suite.aead == nil {
clientCipher = hs.suite.cipher(clientKey, clientIV, true /* for reading */)
clientHash = hs.suite.mac(clientMAC)
serverCipher = hs.suite.cipher(serverKey, serverIV, false /* not for reading */)
serverHash = hs.suite.mac(serverMAC)
} else {
clientCipher = hs.suite.aead(clientKey, clientIV)
serverCipher = hs.suite.aead(serverKey, serverIV)
}
c.in.prepareCipherSpec(c.vers, clientCipher, clientHash)
c.out.prepareCipherSpec(c.vers, serverCipher, serverHash)
return nil
}
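// readFinished reads the client's ChangeCipherSpec and Finished message,
// checks the Finished verify_data against the handshake transcript in
// constant time, and copies it into out.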
func (hs *serverHandshakeState) readFinished(out []byte) error {
c := hs.c
if err := c.readChangeCipherSpec(); err != nil {
return err
}
// finishedMsg is included in the transcript, but not until after we
// verify it, since the state before this message was sent is used
// during verification.
msg, err := c.readHandshake(nil)
if err != nil {
return err
}
clientFinished, ok := msg.(*finishedMsg)
if !ok {
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(clientFinished, msg)
}
verify := hs.finishedHash.clientSum(hs.masterSecret)
if len(verify) != len(clientFinished.verifyData) ||
subtle.ConstantTimeCompare(verify, clientFinished.verifyData) != 1 {
c.sendAlert(alertHandshakeFailure)
return errors.New("tls: client's Finished message is incorrect")
}
if err := transcriptMsg(clientFinished, &hs.finishedHash); err != nil {
return err
}
copy(out, verify)
return nil
}
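// sendSessionTicket sends a NewSessionTicket message if the client requested
// one, wrapping the current session state with Config.WrapSession when set
// and with the built-in ticket key encryption otherwise.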
func (hs *serverHandshakeState) sendSessionTicket() error {
if !hs.hello.ticketSupported {
return nil
}
c := hs.c
m := new(newSessionTicketMsg)
state := c.sessionState()
state.secret = hs.masterSecret
if hs.sessionState != nil {
// If this is re-wrapping an old key, then keep
// the original time it was created.
state.createdAt = hs.sessionState.createdAt
}
if c.config.WrapSession != nil {
var err error
m.ticket, err = c.config.WrapSession(c.connectionStateLocked(), state)
if err != nil {
return err
}
} else {
stateBytes, err := state.Bytes()
if err != nil {
return err
}
m.ticket, err = c.config.encryptTicket(stateBytes, c.ticketKeys)
if err != nil {
return err
}
}
if _, err := hs.c.writeHandshakeRecord(m, &hs.finishedHash); err != nil {
return err
}
return nil
}
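// sendFinished sends the server's ChangeCipherSpec and Finished message and
// copies the Finished verify_data into out.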
func (hs *serverHandshakeState) sendFinished(out []byte) error {
c := hs.c
if err := c.writeChangeCipherRecord(); err != nil {
return err
}
finished := new(finishedMsg)
finished.verifyData = hs.finishedHash.serverSum(hs.masterSecret)
if _, err := hs.c.writeHandshakeRecord(finished, &hs.finishedHash); err != nil {
return err
}
copy(out, finished.verifyData)
return nil
}
// processCertsFromClient takes a chain of client certificates either from a
// certificateMsg message or a certificateMsgTLS13 message and verifies them.
func (c *Conn) processCertsFromClient(certificate Certificate) error {
certificates := certificate.Certificate
certs := make([]*x509.Certificate, len(certificates))
var err error
for i, asn1Data := range certificates {
if certs[i], err = x509.ParseCertificate(asn1Data); err != nil {
c.sendAlert(alertDecodeError)
return errors.New("tls: failed to parse client certificate: " + err.Error())
}
if certs[i].PublicKeyAlgorithm == x509.RSA {
n := certs[i].PublicKey.(*rsa.PublicKey).N.BitLen()
if max, ok := checkKeySize(n); !ok {
c.sendAlert(alertBadCertificate)
return fmt.Errorf("tls: client sent certificate containing RSA key larger than %d bits", max)
}
}
}
if len(certs) == 0 && requiresClientCert(c.config.ClientAuth) {
if c.vers == VersionTLS13 {
c.sendAlert(alertCertificateRequired)
} else {
c.sendAlert(alertHandshakeFailure)
}
return errors.New("tls: client didn't provide a certificate")
}
if c.config.ClientAuth >= VerifyClientCertIfGiven && len(certs) > 0 {
opts := x509.VerifyOptions{
Roots: c.config.ClientCAs,
CurrentTime: c.config.time(),
Intermediates: x509.NewCertPool(),
KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
}
for _, cert := range certs[1:] {
opts.Intermediates.AddCert(cert)
}
chains, err := certs[0].Verify(opts)
if err != nil {
var errCertificateInvalid x509.CertificateInvalidError
if errors.As(err, &x509.UnknownAuthorityError{}) {
c.sendAlert(alertUnknownCA)
} else if errors.As(err, &errCertificateInvalid) && errCertificateInvalid.Reason == x509.Expired {
c.sendAlert(alertCertificateExpired)
} else {
c.sendAlert(alertBadCertificate)
}
return &CertificateVerificationError{UnverifiedCertificates: certs, Err: err}
}
c.verifiedChains, err = fipsAllowedChains(chains)
if err != nil {
c.sendAlert(alertBadCertificate)
return &CertificateVerificationError{UnverifiedCertificates: certs, Err: err}
}
}
c.peerCertificates = certs
c.ocspResponse = certificate.OCSPStaple
c.scts = certificate.SignedCertificateTimestamps
if len(certs) > 0 {
switch certs[0].PublicKey.(type) {
case *ecdsa.PublicKey, *rsa.PublicKey, ed25519.PublicKey:
default:
c.sendAlert(alertUnsupportedCertificate)
return fmt.Errorf("tls: client certificate contains an unsupported public key of type %T", certs[0].PublicKey)
}
}
if c.config.VerifyPeerCertificate != nil {
if err := c.config.VerifyPeerCertificate(certificates, c.verifiedChains); err != nil {
c.sendAlert(alertBadCertificate)
return err
}
}
return nil
}
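// clientHelloInfo builds the exported ClientHelloInfo from a parsed
// clientHelloMsg, deriving SupportedVersions from the legacy version field
// when the supported_versions extension is absent.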
func clientHelloInfo(ctx context.Context, c *Conn, clientHello *clientHelloMsg) *ClientHelloInfo {
supportedVersions := clientHello.supportedVersions
if len(clientHello.supportedVersions) == 0 {
supportedVersions = supportedVersionsFromMax(clientHello.vers)
}
return &ClientHelloInfo{
CipherSuites: clientHello.cipherSuites,
ServerName: clientHello.serverName,
SupportedCurves: clientHello.supportedCurves,
SupportedPoints: clientHello.supportedPoints,
SignatureSchemes: clientHello.supportedSignatureAlgorithms,
SupportedProtos: clientHello.alpnProtocols,
SupportedVersions: supportedVersions,
Extensions: clientHello.extensions,
Conn: c.conn,
config: c.config,
ctx: ctx,
}
}
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tls
import (
"bytes"
"context"
"crypto"
"crypto/hkdf"
"crypto/hmac"
"crypto/internal/fips140/mlkem"
"crypto/internal/fips140/tls13"
"crypto/internal/hpke"
"crypto/rsa"
"crypto/tls/internal/fips140tls"
"errors"
"fmt"
"hash"
"internal/byteorder"
"io"
"slices"
"sort"
"time"
)
// maxClientPSKIdentities is the number of client PSK identities the server will
// attempt to validate. It ignores the rest so that cheap ClientHello
// messages cannot cause too much work in session ticket decryption attempts.
const maxClientPSKIdentities = 5
type echServerContext struct {
hpkeContext *hpke.Recipient
configID uint8
ciphersuite echCipher
transcript hash.Hash
// inner indicates that the initial client_hello we received contained an
// encrypted_client_hello extension that indicated it was an "inner" hello.
// We don't do any additional processing of the hello in this case, so all
// fields above are unset.
inner bool
}
type serverHandshakeStateTLS13 struct {
c *Conn
ctx context.Context
clientHello *clientHelloMsg
hello *serverHelloMsg
sentDummyCCS bool
usingPSK bool
earlyData bool
suite *cipherSuiteTLS13
cert *Certificate
sigAlg SignatureScheme
earlySecret *tls13.EarlySecret
sharedKey []byte
handshakeSecret *tls13.HandshakeSecret
masterSecret *tls13.MasterSecret
trafficSecret []byte // client_application_traffic_secret_0
transcript hash.Hash
clientFinished []byte
echContext *echServerContext
}
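// handshake runs the server side of the TLS 1.3 handshake, from processing
// the ClientHello through reading the client's Finished.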
func (hs *serverHandshakeStateTLS13) handshake() error {
c := hs.c
// For an overview of the TLS 1.3 handshake, see RFC 8446, Section 2.
if err := hs.processClientHello(); err != nil {
return err
}
if err := hs.checkForResumption(); err != nil {
return err
}
if err := hs.pickCertificate(); err != nil {
return err
}
c.buffering = true
if err := hs.sendServerParameters(); err != nil {
return err
}
if err := hs.sendServerCertificate(); err != nil {
return err
}
if err := hs.sendServerFinished(); err != nil {
return err
}
// Note that at this point we could start sending application data without
// waiting for the client's second flight, but the application might not
// expect the lack of replay protection of the ClientHello parameters.
if _, err := c.flush(); err != nil {
return err
}
if err := hs.readClientCertificate(); err != nil {
return err
}
if err := hs.readClientFinished(); err != nil {
return err
}
c.isHandshakeComplete.Store(true)
return nil
}
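// processClientHello validates the ClientHello, negotiates the TLS 1.3 cipher
// suite, key exchange group, and ALPN protocol, and completes the (EC)DHE or
// hybrid key exchange to produce the shared secret.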
func (hs *serverHandshakeStateTLS13) processClientHello() error {
c := hs.c
hs.hello = new(serverHelloMsg)
// TLS 1.3 froze the ServerHello.legacy_version field, and uses
// supported_versions instead. See RFC 8446, sections 4.1.3 and 4.2.1.
hs.hello.vers = VersionTLS12
hs.hello.supportedVersion = c.vers
if len(hs.clientHello.supportedVersions) == 0 {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: client used the legacy version field to negotiate TLS 1.3")
}
// Abort if the client is doing a fallback and landing lower than what we
// support. See RFC 7507, which however does not specify the interaction
// with supported_versions. The only difference is that with
// supported_versions a client has a chance to attempt a [TLS 1.2, TLS 1.4]
// handshake in case TLS 1.3 is broken but 1.2 is not. Alas, in that case,
// it will have to drop the TLS_FALLBACK_SCSV protection if it falls back to
// TLS 1.2, because a TLS 1.3 server would abort here. The situation before
// supported_versions was not better because there was just no way to do a
// TLS 1.4 handshake without risking the server selecting TLS 1.3.
for _, id := range hs.clientHello.cipherSuites {
if id == TLS_FALLBACK_SCSV {
// Use c.vers instead of max(supported_versions) because an attacker
// could defeat this by adding an arbitrary high version otherwise.
if c.vers < c.config.maxSupportedVersion(roleServer) {
c.sendAlert(alertInappropriateFallback)
return errors.New("tls: client using inappropriate protocol fallback")
}
break
}
}
if len(hs.clientHello.compressionMethods) != 1 ||
hs.clientHello.compressionMethods[0] != compressionNone {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: TLS 1.3 client supports illegal compression methods")
}
hs.hello.random = make([]byte, 32)
if _, err := io.ReadFull(c.config.rand(), hs.hello.random); err != nil {
c.sendAlert(alertInternalError)
return err
}
if len(hs.clientHello.secureRenegotiation) != 0 {
c.sendAlert(alertHandshakeFailure)
return errors.New("tls: initial handshake had non-empty renegotiation extension")
}
if hs.clientHello.earlyData && c.quic != nil {
if len(hs.clientHello.pskIdentities) == 0 {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: early_data without pre_shared_key")
}
} else if hs.clientHello.earlyData {
// See RFC 8446, Section 4.2.10 for the complicated behavior required
// here. The scenario is that a different server at our address offered
// to accept early data in the past, which we can't handle. For now, all
// 0-RTT enabled session tickets need to expire before a Go server can
// replace a server or join a pool. That's the same requirement that
// applies to mixing or replacing with any TLS 1.2 server.
c.sendAlert(alertUnsupportedExtension)
return errors.New("tls: client sent unexpected early data")
}
hs.hello.sessionId = hs.clientHello.sessionId
hs.hello.compressionMethod = compressionNone
preferenceList := defaultCipherSuitesTLS13
if !hasAESGCMHardwareSupport || !isAESGCMPreferred(hs.clientHello.cipherSuites) {
preferenceList = defaultCipherSuitesTLS13NoAES
}
if fips140tls.Required() {
preferenceList = allowedCipherSuitesTLS13FIPS
}
for _, suiteID := range preferenceList {
hs.suite = mutualCipherSuiteTLS13(hs.clientHello.cipherSuites, suiteID)
if hs.suite != nil {
break
}
}
if hs.suite == nil {
c.sendAlert(alertHandshakeFailure)
return fmt.Errorf("tls: no cipher suite supported by both client and server; client offered: %x",
hs.clientHello.cipherSuites)
}
c.cipherSuite = hs.suite.id
hs.hello.cipherSuite = hs.suite.id
hs.transcript = hs.suite.hash.New()
// First, if a post-quantum key exchange is available, use one. See
// draft-ietf-tls-key-share-prediction-01, Section 4 for why this must be
// first.
//
// Second, if the client sent a key share for a group we support, use that,
// to avoid a HelloRetryRequest round-trip.
//
// Finally, pick in our fixed preference order.
preferredGroups := c.config.curvePreferences(c.vers)
preferredGroups = slices.DeleteFunc(preferredGroups, func(group CurveID) bool {
return !slices.Contains(hs.clientHello.supportedCurves, group)
})
if len(preferredGroups) == 0 {
c.sendAlert(alertHandshakeFailure)
return errors.New("tls: no key exchanges supported by both client and server")
}
hasKeyShare := func(group CurveID) bool {
for _, ks := range hs.clientHello.keyShares {
if ks.group == group {
return true
}
}
return false
}
sort.SliceStable(preferredGroups, func(i, j int) bool {
return hasKeyShare(preferredGroups[i]) && !hasKeyShare(preferredGroups[j])
})
sort.SliceStable(preferredGroups, func(i, j int) bool {
return isPQKeyExchange(preferredGroups[i]) && !isPQKeyExchange(preferredGroups[j])
})
selectedGroup := preferredGroups[0]
var clientKeyShare *keyShare
for _, ks := range hs.clientHello.keyShares {
if ks.group == selectedGroup {
clientKeyShare = &ks
break
}
}
if clientKeyShare == nil {
ks, err := hs.doHelloRetryRequest(selectedGroup)
if err != nil {
return err
}
clientKeyShare = ks
}
c.curveID = selectedGroup
ecdhGroup := selectedGroup
ecdhData := clientKeyShare.data
if selectedGroup == X25519MLKEM768 {
ecdhGroup = X25519
if len(ecdhData) != mlkem.EncapsulationKeySize768+x25519PublicKeySize {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: invalid X25519MLKEM768 client key share")
}
ecdhData = ecdhData[mlkem.EncapsulationKeySize768:]
}
if _, ok := curveForCurveID(ecdhGroup); !ok {
c.sendAlert(alertInternalError)
return errors.New("tls: CurvePreferences includes unsupported curve")
}
key, err := generateECDHEKey(c.config.rand(), ecdhGroup)
if err != nil {
c.sendAlert(alertInternalError)
return err
}
hs.hello.serverShare = keyShare{group: selectedGroup, data: key.PublicKey().Bytes()}
peerKey, err := key.Curve().NewPublicKey(ecdhData)
if err != nil {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: invalid client key share")
}
hs.sharedKey, err = key.ECDH(peerKey)
if err != nil {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: invalid client key share")
}
if selectedGroup == X25519MLKEM768 {
k, err := mlkem.NewEncapsulationKey768(clientKeyShare.data[:mlkem.EncapsulationKeySize768])
if err != nil {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: invalid X25519MLKEM768 client key share")
}
mlkemSharedSecret, ciphertext := k.Encapsulate()
// draft-kwiatkowski-tls-ecdhe-mlkem-02, Section 3.1.3: "For
// X25519MLKEM768, the shared secret is the concatenation of the ML-KEM
// shared secret and the X25519 shared secret. The shared secret is 64
// bytes (32 bytes for each part)."
hs.sharedKey = append(mlkemSharedSecret, hs.sharedKey...)
// draft-kwiatkowski-tls-ecdhe-mlkem-02, Section 3.1.2: "When the
// X25519MLKEM768 group is negotiated, the server's key exchange value
// is the concatenation of an ML-KEM ciphertext returned from
// encapsulation to the client's encapsulation key, and the server's
// ephemeral X25519 share."
hs.hello.serverShare.data = append(ciphertext, hs.hello.serverShare.data...)
}
selectedProto, err := negotiateALPN(c.config.NextProtos, hs.clientHello.alpnProtocols, c.quic != nil)
if err != nil {
c.sendAlert(alertNoApplicationProtocol)
return err
}
c.clientProtocol = selectedProto
if c.quic != nil {
// RFC 9001 Section 4.2: Clients MUST NOT offer TLS versions older than 1.3.
for _, v := range hs.clientHello.supportedVersions {
if v < VersionTLS13 {
c.sendAlert(alertProtocolVersion)
return errors.New("tls: client offered TLS version older than TLS 1.3")
}
}
// RFC 9001 Section 8.2.
if hs.clientHello.quicTransportParameters == nil {
c.sendAlert(alertMissingExtension)
return errors.New("tls: client did not send a quic_transport_parameters extension")
}
c.quicSetTransportParameters(hs.clientHello.quicTransportParameters)
} else {
if hs.clientHello.quicTransportParameters != nil {
c.sendAlert(alertUnsupportedExtension)
return errors.New("tls: client sent an unexpected quic_transport_parameters extension")
}
}
c.serverName = hs.clientHello.serverName
return nil
}
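// checkForResumption looks for an acceptable PSK among the client's offered
// identities, verifies its binder, and, if one matches, restores the session
// state and marks the connection as resumed.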
func (hs *serverHandshakeStateTLS13) checkForResumption() error {
c := hs.c
if c.config.SessionTicketsDisabled {
return nil
}
modeOK := false
for _, mode := range hs.clientHello.pskModes {
if mode == pskModeDHE {
modeOK = true
break
}
}
if !modeOK {
return nil
}
if len(hs.clientHello.pskIdentities) != len(hs.clientHello.pskBinders) {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: invalid or missing PSK binders")
}
if len(hs.clientHello.pskIdentities) == 0 {
return nil
}
for i, identity := range hs.clientHello.pskIdentities {
if i >= maxClientPSKIdentities {
break
}
var sessionState *SessionState
if c.config.UnwrapSession != nil {
var err error
sessionState, err = c.config.UnwrapSession(identity.label, c.connectionStateLocked())
if err != nil {
return err
}
if sessionState == nil {
continue
}
} else {
plaintext := c.config.decryptTicket(identity.label, c.ticketKeys)
if plaintext == nil {
continue
}
var err error
sessionState, err = ParseSessionState(plaintext)
if err != nil {
continue
}
}
if sessionState.version != VersionTLS13 {
continue
}
createdAt := time.Unix(int64(sessionState.createdAt), 0)
if c.config.time().Sub(createdAt) > maxSessionTicketLifetime {
continue
}
pskSuite := cipherSuiteTLS13ByID(sessionState.cipherSuite)
if pskSuite == nil || pskSuite.hash != hs.suite.hash {
continue
}
// PSK connections don't re-establish client certificates, but carry
// them over in the session ticket. Ensure the presence of client certs
// in the ticket is consistent with the configured requirements.
sessionHasClientCerts := len(sessionState.peerCertificates) != 0
needClientCerts := requiresClientCert(c.config.ClientAuth)
if needClientCerts && !sessionHasClientCerts {
continue
}
if sessionHasClientCerts && c.config.ClientAuth == NoClientCert {
continue
}
if sessionHasClientCerts && c.config.time().After(sessionState.peerCertificates[0].NotAfter) {
continue
}
if sessionHasClientCerts && c.config.ClientAuth >= VerifyClientCertIfGiven &&
len(sessionState.verifiedChains) == 0 {
continue
}
if c.quic != nil && c.quic.enableSessionEvents {
if err := c.quicResumeSession(sessionState); err != nil {
return err
}
}
hs.earlySecret = tls13.NewEarlySecret(hs.suite.hash.New, sessionState.secret)
binderKey := hs.earlySecret.ResumptionBinderKey()
// Clone the transcript in case a HelloRetryRequest was recorded.
transcript := cloneHash(hs.transcript, hs.suite.hash)
if transcript == nil {
c.sendAlert(alertInternalError)
return errors.New("tls: internal error: failed to clone hash")
}
clientHelloBytes, err := hs.clientHello.marshalWithoutBinders()
if err != nil {
c.sendAlert(alertInternalError)
return err
}
transcript.Write(clientHelloBytes)
pskBinder := hs.suite.finishedHash(binderKey, transcript)
if !hmac.Equal(hs.clientHello.pskBinders[i], pskBinder) {
c.sendAlert(alertDecryptError)
return errors.New("tls: invalid PSK binder")
}
if c.quic != nil && hs.clientHello.earlyData && i == 0 &&
sessionState.EarlyData && sessionState.cipherSuite == hs.suite.id &&
sessionState.alpnProtocol == c.clientProtocol {
hs.earlyData = true
transcript := hs.suite.hash.New()
if err := transcriptMsg(hs.clientHello, transcript); err != nil {
return err
}
earlyTrafficSecret := hs.earlySecret.ClientEarlyTrafficSecret(transcript)
c.quicSetReadSecret(QUICEncryptionLevelEarly, hs.suite.id, earlyTrafficSecret)
}
c.didResume = true
c.peerCertificates = sessionState.peerCertificates
c.ocspResponse = sessionState.ocspResponse
c.scts = sessionState.scts
c.verifiedChains = sessionState.verifiedChains
hs.hello.selectedIdentityPresent = true
hs.hello.selectedIdentity = uint16(i)
hs.usingPSK = true
return nil
}
return nil
}
// cloneHash uses the encoding.BinaryMarshaler and encoding.BinaryUnmarshaler
// interfaces implemented by standard library hashes to clone the state of in
// to a new instance of h. It returns nil if the operation fails.
func cloneHash(in hash.Hash, h crypto.Hash) hash.Hash {
// Recreate the interface to avoid importing encoding.
type binaryMarshaler interface {
MarshalBinary() (data []byte, err error)
UnmarshalBinary(data []byte) error
}
marshaler, ok := in.(binaryMarshaler)
if !ok {
return nil
}
state, err := marshaler.MarshalBinary()
if err != nil {
return nil
}
out := h.New()
unmarshaler, ok := out.(binaryMarshaler)
if !ok {
return nil
}
if err := unmarshaler.UnmarshalBinary(state); err != nil {
return nil
}
return out
}
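// cloneHashSketch is a minimal illustrative sketch (not part of the original
// sources) of how cloneHash is used: it snapshots a running transcript hash
// so the copy and the original can diverge independently.
func cloneHashSketch() {
h := crypto.SHA256.New()
h.Write([]byte("ClientHello"))
snapshot := cloneHash(h, crypto.SHA256) // state captured after "ClientHello"
h.Write([]byte("ServerHello"))          // only the original advances
_ = snapshot.Sum(nil)                   // digest of "ClientHello"
_ = h.Sum(nil)                          // digest of "ClientHello" + "ServerHello"
}
// pickCertificate selects the server certificate from the Config and a
// signature scheme compatible with the client's signature_algorithms; it is
// skipped when the handshake is resuming with a PSK.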
func (hs *serverHandshakeStateTLS13) pickCertificate() error {
c := hs.c
// Only one of PSK and certificates is used at a time.
if hs.usingPSK {
return nil
}
// signature_algorithms is required in TLS 1.3. See RFC 8446, Section 4.2.3.
if len(hs.clientHello.supportedSignatureAlgorithms) == 0 {
return c.sendAlert(alertMissingExtension)
}
certificate, err := c.config.getCertificate(clientHelloInfo(hs.ctx, c, hs.clientHello))
if err != nil {
if err == errNoCertificates {
c.sendAlert(alertUnrecognizedName)
} else {
c.sendAlert(alertInternalError)
}
return err
}
hs.sigAlg, err = selectSignatureScheme(c.vers, certificate, hs.clientHello.supportedSignatureAlgorithms)
if err != nil {
// getCertificate returned a certificate that is unsupported or
// incompatible with the client's signature algorithms.
c.sendAlert(alertHandshakeFailure)
return err
}
hs.cert = certificate
return nil
}
// sendDummyChangeCipherSpec sends a ChangeCipherSpec record for compatibility
// with middleboxes that didn't implement TLS correctly. See RFC 8446, Appendix D.4.
func (hs *serverHandshakeStateTLS13) sendDummyChangeCipherSpec() error {
if hs.c.quic != nil {
return nil
}
if hs.sentDummyCCS {
return nil
}
hs.sentDummyCCS = true
return hs.c.writeChangeCipherRecord()
}
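// doHelloRetryRequest sends a HelloRetryRequest selecting selectedGroup and
// processes the client's second ClientHello, returning its key share for
// that group.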
func (hs *serverHandshakeStateTLS13) doHelloRetryRequest(selectedGroup CurveID) (*keyShare, error) {
c := hs.c
// The first ClientHello gets double-hashed into the transcript upon a
// HelloRetryRequest. See RFC 8446, Section 4.4.1.
if err := transcriptMsg(hs.clientHello, hs.transcript); err != nil {
return nil, err
}
chHash := hs.transcript.Sum(nil)
hs.transcript.Reset()
hs.transcript.Write([]byte{typeMessageHash, 0, 0, uint8(len(chHash))})
hs.transcript.Write(chHash)
helloRetryRequest := &serverHelloMsg{
vers: hs.hello.vers,
random: helloRetryRequestRandom,
sessionId: hs.hello.sessionId,
cipherSuite: hs.hello.cipherSuite,
compressionMethod: hs.hello.compressionMethod,
supportedVersion: hs.hello.supportedVersion,
selectedGroup: selectedGroup,
}
if hs.echContext != nil {
// Compute the acceptance message.
helloRetryRequest.encryptedClientHello = make([]byte, 8)
confTranscript := cloneHash(hs.transcript, hs.suite.hash)
if err := transcriptMsg(helloRetryRequest, confTranscript); err != nil {
return nil, err
}
h := hs.suite.hash.New
prf, err := hkdf.Extract(h, hs.clientHello.random, nil)
if err != nil {
c.sendAlert(alertInternalError)
return nil, err
}
acceptConfirmation := tls13.ExpandLabel(h, prf, "hrr ech accept confirmation", confTranscript.Sum(nil), 8)
helloRetryRequest.encryptedClientHello = acceptConfirmation
}
if _, err := hs.c.writeHandshakeRecord(helloRetryRequest, hs.transcript); err != nil {
return nil, err
}
if err := hs.sendDummyChangeCipherSpec(); err != nil {
return nil, err
}
// clientHelloMsg is not included in the transcript.
msg, err := c.readHandshake(nil)
if err != nil {
return nil, err
}
clientHello, ok := msg.(*clientHelloMsg)
if !ok {
c.sendAlert(alertUnexpectedMessage)
return nil, unexpectedMessageError(clientHello, msg)
}
if hs.echContext != nil {
if len(clientHello.encryptedClientHello) == 0 {
c.sendAlert(alertMissingExtension)
return nil, errors.New("tls: second client hello missing encrypted client hello extension")
}
echType, echCiphersuite, configID, encap, payload, err := parseECHExt(clientHello.encryptedClientHello)
if err != nil {
c.sendAlert(alertDecodeError)
return nil, errors.New("tls: client sent invalid encrypted client hello extension")
}
if echType == outerECHExt && hs.echContext.inner || echType == innerECHExt && !hs.echContext.inner {
c.sendAlert(alertDecodeError)
return nil, errors.New("tls: unexpected switch in encrypted client hello extension type")
}
if echType == outerECHExt {
if echCiphersuite != hs.echContext.ciphersuite || configID != hs.echContext.configID || len(encap) != 0 {
c.sendAlert(alertIllegalParameter)
return nil, errors.New("tls: second client hello encrypted client hello extension does not match")
}
encodedInner, err := decryptECHPayload(hs.echContext.hpkeContext, clientHello.original, payload)
if err != nil {
c.sendAlert(alertDecryptError)
return nil, errors.New("tls: failed to decrypt second client hello encrypted client hello extension payload")
}
echInner, err := decodeInnerClientHello(clientHello, encodedInner)
if err != nil {
c.sendAlert(alertIllegalParameter)
return nil, errors.New("tls: client sent invalid encrypted client hello extension")
}
clientHello = echInner
}
}
if len(clientHello.keyShares) != 1 {
c.sendAlert(alertIllegalParameter)
return nil, errors.New("tls: client didn't send one key share in second ClientHello")
}
ks := &clientHello.keyShares[0]
if ks.group != selectedGroup {
c.sendAlert(alertIllegalParameter)
return nil, errors.New("tls: client sent unexpected key share in second ClientHello")
}
if clientHello.earlyData {
c.sendAlert(alertIllegalParameter)
return nil, errors.New("tls: client indicated early data in second ClientHello")
}
if illegalClientHelloChange(clientHello, hs.clientHello) {
c.sendAlert(alertIllegalParameter)
return nil, errors.New("tls: client illegally modified second ClientHello")
}
c.didHRR = true
hs.clientHello = clientHello
return ks, nil
}
// illegalClientHelloChange reports whether the two ClientHello messages are
// different, with the exception of the changes allowed before and after a
// HelloRetryRequest. See RFC 8446, Section 4.1.2.
func illegalClientHelloChange(ch, ch1 *clientHelloMsg) bool {
if len(ch.supportedVersions) != len(ch1.supportedVersions) ||
len(ch.cipherSuites) != len(ch1.cipherSuites) ||
len(ch.supportedCurves) != len(ch1.supportedCurves) ||
len(ch.supportedSignatureAlgorithms) != len(ch1.supportedSignatureAlgorithms) ||
len(ch.supportedSignatureAlgorithmsCert) != len(ch1.supportedSignatureAlgorithmsCert) ||
len(ch.alpnProtocols) != len(ch1.alpnProtocols) {
return true
}
for i := range ch.supportedVersions {
if ch.supportedVersions[i] != ch1.supportedVersions[i] {
return true
}
}
for i := range ch.cipherSuites {
if ch.cipherSuites[i] != ch1.cipherSuites[i] {
return true
}
}
for i := range ch.supportedCurves {
if ch.supportedCurves[i] != ch1.supportedCurves[i] {
return true
}
}
for i := range ch.supportedSignatureAlgorithms {
if ch.supportedSignatureAlgorithms[i] != ch1.supportedSignatureAlgorithms[i] {
return true
}
}
for i := range ch.supportedSignatureAlgorithmsCert {
if ch.supportedSignatureAlgorithmsCert[i] != ch1.supportedSignatureAlgorithmsCert[i] {
return true
}
}
for i := range ch.alpnProtocols {
if ch.alpnProtocols[i] != ch1.alpnProtocols[i] {
return true
}
}
return ch.vers != ch1.vers ||
!bytes.Equal(ch.random, ch1.random) ||
!bytes.Equal(ch.sessionId, ch1.sessionId) ||
!bytes.Equal(ch.compressionMethods, ch1.compressionMethods) ||
ch.serverName != ch1.serverName ||
ch.ocspStapling != ch1.ocspStapling ||
!bytes.Equal(ch.supportedPoints, ch1.supportedPoints) ||
ch.ticketSupported != ch1.ticketSupported ||
!bytes.Equal(ch.sessionTicket, ch1.sessionTicket) ||
ch.secureRenegotiationSupported != ch1.secureRenegotiationSupported ||
!bytes.Equal(ch.secureRenegotiation, ch1.secureRenegotiation) ||
ch.scts != ch1.scts ||
!bytes.Equal(ch.cookie, ch1.cookie) ||
!bytes.Equal(ch.pskModes, ch1.pskModes)
}
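// sendServerParameters sends the ServerHello and EncryptedExtensions, derives
// the handshake traffic secrets, and installs them on the record layer (or
// hands them to QUIC).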
func (hs *serverHandshakeStateTLS13) sendServerParameters() error {
c := hs.c
if hs.echContext != nil {
copy(hs.hello.random[32-8:], make([]byte, 8))
echTranscript := cloneHash(hs.transcript, hs.suite.hash)
echTranscript.Write(hs.clientHello.original)
if err := transcriptMsg(hs.hello, echTranscript); err != nil {
return err
}
// Compute the acceptance message.
h := hs.suite.hash.New
prk, err := hkdf.Extract(h, hs.clientHello.random, nil)
if err != nil {
c.sendAlert(alertInternalError)
return err
}
acceptConfirmation := tls13.ExpandLabel(h, prk, "ech accept confirmation", echTranscript.Sum(nil), 8)
copy(hs.hello.random[32-8:], acceptConfirmation)
}
if err := transcriptMsg(hs.clientHello, hs.transcript); err != nil {
return err
}
if _, err := hs.c.writeHandshakeRecord(hs.hello, hs.transcript); err != nil {
return err
}
if err := hs.sendDummyChangeCipherSpec(); err != nil {
return err
}
earlySecret := hs.earlySecret
if earlySecret == nil {
earlySecret = tls13.NewEarlySecret(hs.suite.hash.New, nil)
}
hs.handshakeSecret = earlySecret.HandshakeSecret(hs.sharedKey)
clientSecret := hs.handshakeSecret.ClientHandshakeTrafficSecret(hs.transcript)
c.in.setTrafficSecret(hs.suite, QUICEncryptionLevelHandshake, clientSecret)
serverSecret := hs.handshakeSecret.ServerHandshakeTrafficSecret(hs.transcript)
c.out.setTrafficSecret(hs.suite, QUICEncryptionLevelHandshake, serverSecret)
if c.quic != nil {
if c.hand.Len() != 0 {
c.sendAlert(alertUnexpectedMessage)
}
c.quicSetWriteSecret(QUICEncryptionLevelHandshake, hs.suite.id, serverSecret)
c.quicSetReadSecret(QUICEncryptionLevelHandshake, hs.suite.id, clientSecret)
}
err := c.config.writeKeyLog(keyLogLabelClientHandshake, hs.clientHello.random, clientSecret)
if err != nil {
c.sendAlert(alertInternalError)
return err
}
err = c.config.writeKeyLog(keyLogLabelServerHandshake, hs.clientHello.random, serverSecret)
if err != nil {
c.sendAlert(alertInternalError)
return err
}
encryptedExtensions := new(encryptedExtensionsMsg)
encryptedExtensions.alpnProtocol = c.clientProtocol
if c.quic != nil {
p, err := c.quicGetTransportParameters()
if err != nil {
return err
}
encryptedExtensions.quicTransportParameters = p
encryptedExtensions.earlyData = hs.earlyData
}
if !hs.c.didResume && hs.clientHello.serverName != "" {
encryptedExtensions.serverNameAck = true
}
// If client sent ECH extension, but we didn't accept it,
// send retry configs, if available.
echKeys := hs.c.config.EncryptedClientHelloKeys
if hs.c.config.GetEncryptedClientHelloKeys != nil {
echKeys, err = hs.c.config.GetEncryptedClientHelloKeys(clientHelloInfo(hs.ctx, c, hs.clientHello))
if err != nil {
c.sendAlert(alertInternalError)
return err
}
}
if len(echKeys) > 0 && len(hs.clientHello.encryptedClientHello) > 0 && hs.echContext == nil {
encryptedExtensions.echRetryConfigs, err = buildRetryConfigList(echKeys)
if err != nil {
c.sendAlert(alertInternalError)
return err
}
}
if _, err := hs.c.writeHandshakeRecord(encryptedExtensions, hs.transcript); err != nil {
return err
}
return nil
}
func (hs *serverHandshakeStateTLS13) requestClientCert() bool {
return hs.c.config.ClientAuth >= RequestClientCert && !hs.usingPSK
}
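// sendServerCertificate sends the optional CertificateRequest along with the
// Certificate and CertificateVerify messages; it is skipped entirely when a
// PSK is in use.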
func (hs *serverHandshakeStateTLS13) sendServerCertificate() error {
c := hs.c
// Only one of PSK and certificates is used at a time.
if hs.usingPSK {
return nil
}
if hs.requestClientCert() {
// Request a client certificate
certReq := new(certificateRequestMsgTLS13)
certReq.ocspStapling = true
certReq.scts = true
certReq.supportedSignatureAlgorithms = supportedSignatureAlgorithms(c.vers)
certReq.supportedSignatureAlgorithmsCert = supportedSignatureAlgorithmsCert()
if c.config.ClientCAs != nil {
certReq.certificateAuthorities = c.config.ClientCAs.Subjects()
}
if _, err := hs.c.writeHandshakeRecord(certReq, hs.transcript); err != nil {
return err
}
}
certMsg := new(certificateMsgTLS13)
certMsg.certificate = *hs.cert
certMsg.scts = hs.clientHello.scts && len(hs.cert.SignedCertificateTimestamps) > 0
certMsg.ocspStapling = hs.clientHello.ocspStapling && len(hs.cert.OCSPStaple) > 0
if _, err := hs.c.writeHandshakeRecord(certMsg, hs.transcript); err != nil {
return err
}
certVerifyMsg := new(certificateVerifyMsg)
certVerifyMsg.hasSignatureAlgorithm = true
certVerifyMsg.signatureAlgorithm = hs.sigAlg
sigType, sigHash, err := typeAndHashFromSignatureScheme(hs.sigAlg)
if err != nil {
return c.sendAlert(alertInternalError)
}
signed := signedMessage(sigHash, serverSignatureContext, hs.transcript)
signOpts := crypto.SignerOpts(sigHash)
if sigType == signatureRSAPSS {
signOpts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash}
}
sig, err := hs.cert.PrivateKey.(crypto.Signer).Sign(c.config.rand(), signed, signOpts)
if err != nil {
public := hs.cert.PrivateKey.(crypto.Signer).Public()
if rsaKey, ok := public.(*rsa.PublicKey); ok && sigType == signatureRSAPSS &&
rsaKey.N.BitLen()/8 < sigHash.Size()*2+2 { // key too small for RSA-PSS
c.sendAlert(alertHandshakeFailure)
} else {
c.sendAlert(alertInternalError)
}
return errors.New("tls: failed to sign handshake: " + err.Error())
}
certVerifyMsg.signature = sig
if _, err := hs.c.writeHandshakeRecord(certVerifyMsg, hs.transcript); err != nil {
return err
}
return nil
}
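// sendServerFinished sends the server Finished, derives the application
// traffic secrets and the exporter master secret, and, when no client
// certificate was requested, sends session tickets in the same flight.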
func (hs *serverHandshakeStateTLS13) sendServerFinished() error {
c := hs.c
finished := &finishedMsg{
verifyData: hs.suite.finishedHash(c.out.trafficSecret, hs.transcript),
}
if _, err := hs.c.writeHandshakeRecord(finished, hs.transcript); err != nil {
return err
}
// Derive secrets that take context through the server Finished.
hs.masterSecret = hs.handshakeSecret.MasterSecret()
hs.trafficSecret = hs.masterSecret.ClientApplicationTrafficSecret(hs.transcript)
serverSecret := hs.masterSecret.ServerApplicationTrafficSecret(hs.transcript)
c.out.setTrafficSecret(hs.suite, QUICEncryptionLevelApplication, serverSecret)
if c.quic != nil {
if c.hand.Len() != 0 {
// TODO: Handle this in setTrafficSecret?
c.sendAlert(alertUnexpectedMessage)
}
c.quicSetWriteSecret(QUICEncryptionLevelApplication, hs.suite.id, serverSecret)
}
err := c.config.writeKeyLog(keyLogLabelClientTraffic, hs.clientHello.random, hs.trafficSecret)
if err != nil {
c.sendAlert(alertInternalError)
return err
}
err = c.config.writeKeyLog(keyLogLabelServerTraffic, hs.clientHello.random, serverSecret)
if err != nil {
c.sendAlert(alertInternalError)
return err
}
c.ekm = hs.suite.exportKeyingMaterial(hs.masterSecret, hs.transcript)
// If we did not request client certificates, at this point we can
// precompute the client finished and roll the transcript forward to send
// session tickets in our first flight.
if !hs.requestClientCert() {
if err := hs.sendSessionTickets(); err != nil {
return err
}
}
return nil
}
func (hs *serverHandshakeStateTLS13) shouldSendSessionTickets() bool {
if hs.c.config.SessionTicketsDisabled {
return false
}
// QUIC tickets are sent by QUICConn.SendSessionTicket, not automatically.
if hs.c.quic != nil {
return false
}
// Don't send tickets the client wouldn't use. See RFC 8446, Section 4.2.9.
return slices.Contains(hs.clientHello.pskModes, pskModeDHE)
}
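// sendSessionTickets precomputes the expected client Finished, rolls the
// transcript forward with it, derives the resumption master secret, and sends
// a session ticket when appropriate.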
func (hs *serverHandshakeStateTLS13) sendSessionTickets() error {
c := hs.c
hs.clientFinished = hs.suite.finishedHash(c.in.trafficSecret, hs.transcript)
finishedMsg := &finishedMsg{
verifyData: hs.clientFinished,
}
if err := transcriptMsg(finishedMsg, hs.transcript); err != nil {
return err
}
c.resumptionSecret = hs.masterSecret.ResumptionMasterSecret(hs.transcript)
if !hs.shouldSendSessionTickets() {
return nil
}
return c.sendSessionTicket(false, nil)
}
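// sendSessionTicket sends a single TLS 1.3 NewSessionTicket message whose PSK
// is derived from the resumption master secret.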
func (c *Conn) sendSessionTicket(earlyData bool, extra [][]byte) error {
suite := cipherSuiteTLS13ByID(c.cipherSuite)
if suite == nil {
return errors.New("tls: internal error: unknown cipher suite")
}
// ticket_nonce, which must be unique per connection, is always left at
// zero because we only ever send one ticket per connection.
psk := tls13.ExpandLabel(suite.hash.New, c.resumptionSecret, "resumption",
nil, suite.hash.Size())
m := new(newSessionTicketMsgTLS13)
state := c.sessionState()
state.secret = psk
state.EarlyData = earlyData
state.Extra = extra
if c.config.WrapSession != nil {
var err error
m.label, err = c.config.WrapSession(c.connectionStateLocked(), state)
if err != nil {
return err
}
} else {
stateBytes, err := state.Bytes()
if err != nil {
c.sendAlert(alertInternalError)
return err
}
m.label, err = c.config.encryptTicket(stateBytes, c.ticketKeys)
if err != nil {
return err
}
}
m.lifetime = uint32(maxSessionTicketLifetime / time.Second)
// ticket_age_add is a random 32-bit value. See RFC 8446, Section 4.6.1.
// The value is not stored anywhere; we never need to check the ticket age
// because 0-RTT is not supported.
ageAdd := make([]byte, 4)
if _, err := c.config.rand().Read(ageAdd); err != nil {
return err
}
m.ageAdd = byteorder.LEUint32(ageAdd)
if earlyData {
// RFC 9001, Section 4.6.1
m.maxEarlyData = 0xffffffff
}
if _, err := c.writeHandshakeRecord(m, nil); err != nil {
return err
}
return nil
}
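// readClientCertificate reads and verifies the client's Certificate and
// CertificateVerify messages when a certificate was requested, and runs
// Config.VerifyConnection in either case.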
func (hs *serverHandshakeStateTLS13) readClientCertificate() error {
c := hs.c
if !hs.requestClientCert() {
// Make sure the connection is still being verified whether or not
// the server requested a client certificate.
if c.config.VerifyConnection != nil {
if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil {
c.sendAlert(alertBadCertificate)
return err
}
}
return nil
}
// If we requested a client certificate, then the client must send a
// certificate message. If it's empty, no CertificateVerify is sent.
msg, err := c.readHandshake(hs.transcript)
if err != nil {
return err
}
certMsg, ok := msg.(*certificateMsgTLS13)
if !ok {
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(certMsg, msg)
}
if err := c.processCertsFromClient(certMsg.certificate); err != nil {
return err
}
if c.config.VerifyConnection != nil {
if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil {
c.sendAlert(alertBadCertificate)
return err
}
}
if len(certMsg.certificate.Certificate) != 0 {
// certificateVerifyMsg is included in the transcript, but not until
// after we verify the handshake signature, since the state before
// this message was sent is used.
msg, err = c.readHandshake(nil)
if err != nil {
return err
}
certVerify, ok := msg.(*certificateVerifyMsg)
if !ok {
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(certVerify, msg)
}
// See RFC 8446, Section 4.4.3.
// We don't use certReq.supportedSignatureAlgorithms because it would
// require keeping the certificateRequestMsgTLS13 around in the hs.
if !isSupportedSignatureAlgorithm(certVerify.signatureAlgorithm, supportedSignatureAlgorithms(c.vers)) ||
!isSupportedSignatureAlgorithm(certVerify.signatureAlgorithm, signatureSchemesForPublicKey(c.vers, c.peerCertificates[0].PublicKey)) {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: client certificate used with invalid signature algorithm")
}
sigType, sigHash, err := typeAndHashFromSignatureScheme(certVerify.signatureAlgorithm)
if err != nil {
return c.sendAlert(alertInternalError)
}
if sigType == signaturePKCS1v15 || sigHash == crypto.SHA1 {
return c.sendAlert(alertInternalError)
}
signed := signedMessage(sigHash, clientSignatureContext, hs.transcript)
if err := verifyHandshakeSignature(sigType, c.peerCertificates[0].PublicKey,
sigHash, signed, certVerify.signature); err != nil {
c.sendAlert(alertDecryptError)
return errors.New("tls: invalid signature by the client certificate: " + err.Error())
}
c.peerSigAlg = certVerify.signatureAlgorithm
if err := transcriptMsg(certVerify, hs.transcript); err != nil {
return err
}
}
// If we waited for the client certificates before sending session tickets, we
// are ready to do it now.
if err := hs.sendSessionTickets(); err != nil {
return err
}
return nil
}
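// readClientFinished checks the client's Finished message against the
// precomputed value and installs the client application traffic secret for
// reading.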
func (hs *serverHandshakeStateTLS13) readClientFinished() error {
c := hs.c
// finishedMsg is not included in the transcript.
msg, err := c.readHandshake(nil)
if err != nil {
return err
}
finished, ok := msg.(*finishedMsg)
if !ok {
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(finished, msg)
}
if !hmac.Equal(hs.clientFinished, finished.verifyData) {
c.sendAlert(alertDecryptError)
return errors.New("tls: invalid client finished hash")
}
c.in.setTrafficSecret(hs.suite, QUICEncryptionLevelApplication, hs.trafficSecret)
return nil
}
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tls
import (
"crypto"
"crypto/ecdh"
"crypto/md5"
"crypto/rsa"
"crypto/sha1"
"crypto/x509"
"errors"
"fmt"
"io"
"slices"
)
// A keyAgreement implements the client and server side of a TLS 1.0–1.2 key
// agreement protocol by generating and processing key exchange messages.
type keyAgreement interface {
// On the server side, the first two methods are called in order.
// In the case that the key agreement protocol doesn't use a
// ServerKeyExchange message, generateServerKeyExchange can return nil,
// nil.
generateServerKeyExchange(*Config, *Certificate, *clientHelloMsg, *serverHelloMsg) (*serverKeyExchangeMsg, error)
processClientKeyExchange(*Config, *Certificate, *clientKeyExchangeMsg, uint16) ([]byte, error)
// On the client side, the next two methods are called in order.
// This method may not be called if the server doesn't send a
// ServerKeyExchange message.
processServerKeyExchange(*Config, *clientHelloMsg, *serverHelloMsg, *x509.Certificate, *serverKeyExchangeMsg) error
generateClientKeyExchange(*Config, *clientHelloMsg, *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error)
}
var errClientKeyExchange = errors.New("tls: invalid ClientKeyExchange message")
var errServerKeyExchange = errors.New("tls: invalid ServerKeyExchange message")
// rsaKeyAgreement implements the standard TLS key agreement where the client
// encrypts the pre-master secret to the server's public key.
type rsaKeyAgreement struct{}
func (ka rsaKeyAgreement) generateServerKeyExchange(config *Config, cert *Certificate, clientHello *clientHelloMsg, hello *serverHelloMsg) (*serverKeyExchangeMsg, error) {
return nil, nil
}
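// processClientKeyExchange recovers the premaster secret from the
// ClientKeyExchange message, whose body is a two-byte big-endian length
// followed by the PKCS #1 v1.5 RSA ciphertext.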
func (ka rsaKeyAgreement) processClientKeyExchange(config *Config, cert *Certificate, ckx *clientKeyExchangeMsg, version uint16) ([]byte, error) {
if len(ckx.ciphertext) < 2 {
return nil, errClientKeyExchange
}
ciphertextLen := int(ckx.ciphertext[0])<<8 | int(ckx.ciphertext[1])
if ciphertextLen != len(ckx.ciphertext)-2 {
return nil, errClientKeyExchange
}
ciphertext := ckx.ciphertext[2:]
priv, ok := cert.PrivateKey.(crypto.Decrypter)
if !ok {
return nil, errors.New("tls: certificate private key does not implement crypto.Decrypter")
}
// Perform constant time RSA PKCS #1 v1.5 decryption
preMasterSecret, err := priv.Decrypt(config.rand(), ciphertext, &rsa.PKCS1v15DecryptOptions{SessionKeyLen: 48})
if err != nil {
return nil, err
}
// We don't check the version number in the premaster secret. For one,
// by checking it, we would leak information about the validity of the
// encrypted pre-master secret. Secondly, it provides only a small
// benefit against a downgrade attack and some implementations send the
// wrong version anyway. See the discussion at the end of section
// 7.4.7.1 of RFC 4346.
return preMasterSecret, nil
}
func (ka rsaKeyAgreement) processServerKeyExchange(config *Config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, cert *x509.Certificate, skx *serverKeyExchangeMsg) error {
return errors.New("tls: unexpected ServerKeyExchange")
}
func (ka rsaKeyAgreement) generateClientKeyExchange(config *Config, clientHello *clientHelloMsg, cert *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error) {
preMasterSecret := make([]byte, 48)
preMasterSecret[0] = byte(clientHello.vers >> 8)
preMasterSecret[1] = byte(clientHello.vers)
_, err := io.ReadFull(config.rand(), preMasterSecret[2:])
if err != nil {
return nil, nil, err
}
rsaKey, ok := cert.PublicKey.(*rsa.PublicKey)
if !ok {
return nil, nil, errors.New("tls: server certificate contains incorrect key type for selected ciphersuite")
}
encrypted, err := rsa.EncryptPKCS1v15(config.rand(), rsaKey, preMasterSecret)
if err != nil {
return nil, nil, err
}
ckx := new(clientKeyExchangeMsg)
ckx.ciphertext = make([]byte, len(encrypted)+2)
ckx.ciphertext[0] = byte(len(encrypted) >> 8)
ckx.ciphertext[1] = byte(len(encrypted))
copy(ckx.ciphertext[2:], encrypted)
return preMasterSecret, ckx, nil
}
// sha1Hash calculates a SHA1 hash over the given byte slices.
func sha1Hash(slices [][]byte) []byte {
hsha1 := sha1.New()
for _, slice := range slices {
hsha1.Write(slice)
}
return hsha1.Sum(nil)
}
// md5SHA1Hash implements TLS 1.0's hybrid hash function which consists of the
// concatenation of an MD5 and SHA1 hash.
func md5SHA1Hash(slices [][]byte) []byte {
md5sha1 := make([]byte, md5.Size+sha1.Size)
hmd5 := md5.New()
for _, slice := range slices {
hmd5.Write(slice)
}
copy(md5sha1, hmd5.Sum(nil))
copy(md5sha1[md5.Size:], sha1Hash(slices))
return md5sha1
}
// hashForServerKeyExchange hashes the given slices and returns their digest
// using the given hash function (for TLS 1.2) or using a default based on
// the sigType (for earlier TLS versions). For Ed25519 signatures, which don't
// do pre-hashing, it returns the concatenation of the slices.
func hashForServerKeyExchange(sigType uint8, hashFunc crypto.Hash, version uint16, slices ...[]byte) []byte {
if sigType == signatureEd25519 {
var signed []byte
for _, slice := range slices {
signed = append(signed, slice...)
}
return signed
}
if version >= VersionTLS12 {
h := hashFunc.New()
for _, slice := range slices {
h.Write(slice)
}
digest := h.Sum(nil)
return digest
}
if sigType == signatureECDSA {
return sha1Hash(slices)
}
return md5SHA1Hash(slices)
}
// ecdheKeyAgreement implements a TLS key agreement where the server
// generates an ephemeral EC public/private key pair and signs it. The
// pre-master secret is then calculated using ECDH. The signature may
// be ECDSA, Ed25519 or RSA.
type ecdheKeyAgreement struct {
version uint16
isRSA bool
key *ecdh.PrivateKey
// ckx and preMasterSecret are generated in processServerKeyExchange
// and returned in generateClientKeyExchange.
ckx *clientKeyExchangeMsg
preMasterSecret []byte
// curveID and signatureAlgorithm are set by processServerKeyExchange and
// generateServerKeyExchange.
curveID CurveID
signatureAlgorithm SignatureScheme
}
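// generateServerKeyExchange picks a mutually supported curve, generates an
// ephemeral ECDHE key pair, and signs the ServerECDHParams together with the
// hello randoms using the certificate's private key.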
func (ka *ecdheKeyAgreement) generateServerKeyExchange(config *Config, cert *Certificate, clientHello *clientHelloMsg, hello *serverHelloMsg) (*serverKeyExchangeMsg, error) {
for _, c := range clientHello.supportedCurves {
if config.supportsCurve(ka.version, c) {
ka.curveID = c
break
}
}
if ka.curveID == 0 {
return nil, errors.New("tls: no supported elliptic curves offered")
}
if _, ok := curveForCurveID(ka.curveID); !ok {
return nil, errors.New("tls: CurvePreferences includes unsupported curve")
}
key, err := generateECDHEKey(config.rand(), ka.curveID)
if err != nil {
return nil, err
}
ka.key = key
// See RFC 4492, Section 5.4.
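// ServerECDHParams is curve_type (1 byte, 3 = named_curve), the 2-byte
// CurveID, and the ECDHE public point prefixed by its 1-byte length.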
ecdhePublic := key.PublicKey().Bytes()
serverECDHEParams := make([]byte, 1+2+1+len(ecdhePublic))
serverECDHEParams[0] = 3 // named curve
serverECDHEParams[1] = byte(ka.curveID >> 8)
serverECDHEParams[2] = byte(ka.curveID)
serverECDHEParams[3] = byte(len(ecdhePublic))
copy(serverECDHEParams[4:], ecdhePublic)
priv, ok := cert.PrivateKey.(crypto.Signer)
if !ok {
return nil, fmt.Errorf("tls: certificate private key of type %T does not implement crypto.Signer", cert.PrivateKey)
}
var sigType uint8
var sigHash crypto.Hash
if ka.version >= VersionTLS12 {
ka.signatureAlgorithm, err = selectSignatureScheme(ka.version, cert, clientHello.supportedSignatureAlgorithms)
if err != nil {
return nil, err
}
sigType, sigHash, err = typeAndHashFromSignatureScheme(ka.signatureAlgorithm)
if err != nil {
return nil, err
}
if sigHash == crypto.SHA1 {
tlssha1.Value() // ensure godebug is initialized
tlssha1.IncNonDefault()
}
} else {
sigType, sigHash, err = legacyTypeAndHashFromPublicKey(priv.Public())
if err != nil {
return nil, err
}
}
if (sigType == signaturePKCS1v15 || sigType == signatureRSAPSS) != ka.isRSA {
return nil, errors.New("tls: certificate cannot be used with the selected cipher suite")
}
signed := hashForServerKeyExchange(sigType, sigHash, ka.version, clientHello.random, hello.random, serverECDHEParams)
signOpts := crypto.SignerOpts(sigHash)
if sigType == signatureRSAPSS {
signOpts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash}
}
sig, err := priv.Sign(config.rand(), signed, signOpts)
if err != nil {
return nil, errors.New("tls: failed to sign ECDHE parameters: " + err.Error())
}
skx := new(serverKeyExchangeMsg)
sigAndHashLen := 0
if ka.version >= VersionTLS12 {
sigAndHashLen = 2
}
skx.key = make([]byte, len(serverECDHEParams)+sigAndHashLen+2+len(sig))
copy(skx.key, serverECDHEParams)
k := skx.key[len(serverECDHEParams):]
if ka.version >= VersionTLS12 {
k[0] = byte(ka.signatureAlgorithm >> 8)
k[1] = byte(ka.signatureAlgorithm)
k = k[2:]
}
k[0] = byte(len(sig) >> 8)
k[1] = byte(len(sig))
copy(k[2:], sig)
return skx, nil
}
func (ka *ecdheKeyAgreement) processClientKeyExchange(config *Config, cert *Certificate, ckx *clientKeyExchangeMsg, version uint16) ([]byte, error) {
if len(ckx.ciphertext) == 0 || int(ckx.ciphertext[0]) != len(ckx.ciphertext)-1 {
return nil, errClientKeyExchange
}
peerKey, err := ka.key.Curve().NewPublicKey(ckx.ciphertext[1:])
if err != nil {
return nil, errClientKeyExchange
}
preMasterSecret, err := ka.key.ECDH(peerKey)
if err != nil {
return nil, errClientKeyExchange
}
return preMasterSecret, nil
}
func (ka *ecdheKeyAgreement) processServerKeyExchange(config *Config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, cert *x509.Certificate, skx *serverKeyExchangeMsg) error {
if len(skx.key) < 4 {
return errServerKeyExchange
}
if skx.key[0] != 3 { // named curve
return errors.New("tls: server selected unsupported curve")
}
ka.curveID = CurveID(skx.key[1])<<8 | CurveID(skx.key[2])
publicLen := int(skx.key[3])
if publicLen+4 > len(skx.key) {
return errServerKeyExchange
}
serverECDHEParams := skx.key[:4+publicLen]
publicKey := serverECDHEParams[4:]
sig := skx.key[4+publicLen:]
if len(sig) < 2 {
return errServerKeyExchange
}
if !slices.Contains(clientHello.supportedCurves, ka.curveID) {
return errors.New("tls: server selected unoffered curve")
}
if _, ok := curveForCurveID(ka.curveID); !ok {
return errors.New("tls: server selected unsupported curve")
}
key, err := generateECDHEKey(config.rand(), ka.curveID)
if err != nil {
return err
}
ka.key = key
peerKey, err := key.Curve().NewPublicKey(publicKey)
if err != nil {
return errServerKeyExchange
}
ka.preMasterSecret, err = key.ECDH(peerKey)
if err != nil {
return errServerKeyExchange
}
ourPublicKey := key.PublicKey().Bytes()
ka.ckx = new(clientKeyExchangeMsg)
ka.ckx.ciphertext = make([]byte, 1+len(ourPublicKey))
ka.ckx.ciphertext[0] = byte(len(ourPublicKey))
copy(ka.ckx.ciphertext[1:], ourPublicKey)
var sigType uint8
var sigHash crypto.Hash
if ka.version >= VersionTLS12 {
ka.signatureAlgorithm = SignatureScheme(sig[0])<<8 | SignatureScheme(sig[1])
sig = sig[2:]
if len(sig) < 2 {
return errServerKeyExchange
}
if !isSupportedSignatureAlgorithm(ka.signatureAlgorithm, clientHello.supportedSignatureAlgorithms) {
return errors.New("tls: certificate used with invalid signature algorithm")
}
sigType, sigHash, err = typeAndHashFromSignatureScheme(ka.signatureAlgorithm)
if err != nil {
return err
}
if sigHash == crypto.SHA1 {
tlssha1.Value() // ensure godebug is initialized
tlssha1.IncNonDefault()
}
} else {
sigType, sigHash, err = legacyTypeAndHashFromPublicKey(cert.PublicKey)
if err != nil {
return err
}
}
if (sigType == signaturePKCS1v15 || sigType == signatureRSAPSS) != ka.isRSA {
return errServerKeyExchange
}
sigLen := int(sig[0])<<8 | int(sig[1])
if sigLen+2 != len(sig) {
return errServerKeyExchange
}
sig = sig[2:]
signed := hashForServerKeyExchange(sigType, sigHash, ka.version, clientHello.random, serverHello.random, serverECDHEParams)
if err := verifyHandshakeSignature(sigType, cert.PublicKey, sigHash, signed, sig); err != nil {
return errors.New("tls: invalid signature by the server certificate: " + err.Error())
}
return nil
}
func (ka *ecdheKeyAgreement) generateClientKeyExchange(config *Config, clientHello *clientHelloMsg, cert *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error) {
if ka.ckx == nil {
return nil, nil, errors.New("tls: missing ServerKeyExchange message")
}
return ka.preMasterSecret, ka.ckx, nil
}
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tls
import (
"crypto/ecdh"
"crypto/hmac"
"crypto/internal/fips140/mlkem"
"crypto/internal/fips140/tls13"
"errors"
"hash"
"io"
)
// This file contains the functions necessary to compute the TLS 1.3 key
// schedule. See RFC 8446, Section 7.
// nextTrafficSecret generates the next traffic secret, given the current one,
// according to RFC 8446, Section 7.2.
func (c *cipherSuiteTLS13) nextTrafficSecret(trafficSecret []byte) []byte {
return tls13.ExpandLabel(c.hash.New, trafficSecret, "traffic upd", nil, c.hash.Size())
}
// trafficKey generates traffic keys according to RFC 8446, Section 7.3.
func (c *cipherSuiteTLS13) trafficKey(trafficSecret []byte) (key, iv []byte) {
key = tls13.ExpandLabel(c.hash.New, trafficSecret, "key", nil, c.keyLen)
iv = tls13.ExpandLabel(c.hash.New, trafficSecret, "iv", nil, aeadNonceLength)
return
}
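// keyUpdateSketch is an illustrative sketch (not part of the original
// sources) of how the functions above combine when traffic keys are rotated
// after a KeyUpdate, assuming the TLS_AES_128_GCM_SHA256 suite.
func keyUpdateSketch(currentSecret []byte) (key, iv []byte) {
suite := cipherSuiteTLS13ByID(TLS_AES_128_GCM_SHA256)
next := suite.nextTrafficSecret(currentSecret) // traffic_secret_N+1
return suite.trafficKey(next)                  // fresh AEAD key and IV for the new epoch
}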
// finishedHash generates the Finished verify_data or PskBinderEntry according
// to RFC 8446, Section 4.4.4. See sections 4.4 and 4.2.11.2 for the baseKey
// selection.
func (c *cipherSuiteTLS13) finishedHash(baseKey []byte, transcript hash.Hash) []byte {
finishedKey := tls13.ExpandLabel(c.hash.New, baseKey, "finished", nil, c.hash.Size())
verifyData := hmac.New(c.hash.New, finishedKey)
verifyData.Write(transcript.Sum(nil))
return verifyData.Sum(nil)
}
// exportKeyingMaterial implements RFC 5705 exporters for TLS 1.3 according to
// RFC 8446, Section 7.5.
func (c *cipherSuiteTLS13) exportKeyingMaterial(s *tls13.MasterSecret, transcript hash.Hash) func(string, []byte, int) ([]byte, error) {
expMasterSecret := s.ExporterMasterSecret(transcript)
return func(label string, context []byte, length int) ([]byte, error) {
return expMasterSecret.Exporter(label, context, length), nil
}
}
type keySharePrivateKeys struct {
curveID CurveID
ecdhe *ecdh.PrivateKey
mlkem *mlkem.DecapsulationKey768
}
const x25519PublicKeySize = 32
// generateECDHEKey returns a PrivateKey that implements Diffie-Hellman
// according to RFC 8446, Section 4.2.8.2.
func generateECDHEKey(rand io.Reader, curveID CurveID) (*ecdh.PrivateKey, error) {
curve, ok := curveForCurveID(curveID)
if !ok {
return nil, errors.New("tls: internal error: unsupported curve")
}
return curve.GenerateKey(rand)
}
func curveForCurveID(id CurveID) (ecdh.Curve, bool) {
switch id {
case X25519:
return ecdh.X25519(), true
case CurveP256:
return ecdh.P256(), true
case CurveP384:
return ecdh.P384(), true
case CurveP521:
return ecdh.P521(), true
default:
return nil, false
}
}
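// The same crypto/ecdh API that backs generateECDHEKey can be used directly.
// A hedged sketch of one side of an X25519 exchange, where rand.Reader is
// crypto/rand and peerBytes is assumed to hold the peer's public key from its
// key_share extension:
//
//	priv, err := ecdh.X25519().GenerateKey(rand.Reader)
//	if err != nil {
//		return err
//	}
//	peerPub, err := ecdh.X25519().NewPublicKey(peerBytes)
//	if err != nil {
//		return err
//	}
//	shared, err := priv.ECDH(peerPub) // 32-byte shared secret
//	_ = shared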
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tls
import (
"crypto"
"crypto/hmac"
"crypto/internal/fips140/tls12"
"crypto/md5"
"crypto/sha1"
"crypto/sha256"
"crypto/sha512"
"errors"
"fmt"
"hash"
)
type prfFunc func(secret []byte, label string, seed []byte, keyLen int) []byte
// Split a premaster secret in two as specified in RFC 4346, Section 5.
func splitPreMasterSecret(secret []byte) (s1, s2 []byte) {
s1 = secret[0 : (len(secret)+1)/2]
s2 = secret[len(secret)/2:]
return
}
// pHash implements the P_hash function, as defined in RFC 4346, Section 5.
func pHash(result, secret, seed []byte, hash func() hash.Hash) {
h := hmac.New(hash, secret)
h.Write(seed)
a := h.Sum(nil)
j := 0
for j < len(result) {
h.Reset()
h.Write(a)
h.Write(seed)
b := h.Sum(nil)
copy(result[j:], b)
j += len(b)
h.Reset()
h.Write(a)
a = h.Sum(nil)
}
}
// prf10 implements the TLS 1.0 pseudo-random function, as defined in RFC 2246, Section 5.
func prf10(secret []byte, label string, seed []byte, keyLen int) []byte {
result := make([]byte, keyLen)
hashSHA1 := sha1.New
hashMD5 := md5.New
labelAndSeed := make([]byte, len(label)+len(seed))
copy(labelAndSeed, label)
copy(labelAndSeed[len(label):], seed)
s1, s2 := splitPreMasterSecret(secret)
pHash(result, s1, labelAndSeed, hashMD5)
result2 := make([]byte, len(result))
pHash(result2, s2, labelAndSeed, hashSHA1)
for i, b := range result2 {
result[i] ^= b
}
return result
}
// prf12 implements the TLS 1.2 pseudo-random function, as defined in RFC 5246, Section 5.
func prf12(hashFunc func() hash.Hash) prfFunc {
return func(secret []byte, label string, seed []byte, keyLen int) []byte {
return tls12.PRF(hashFunc, secret, label, seed, keyLen)
}
}
const (
masterSecretLength = 48 // Length of a master secret in TLS 1.0–1.2.
finishedVerifyLength = 12 // Length of verify_data in a Finished message.
)
const masterSecretLabel = "master secret"
const extendedMasterSecretLabel = "extended master secret"
const keyExpansionLabel = "key expansion"
const clientFinishedLabel = "client finished"
const serverFinishedLabel = "server finished"
func prfAndHashForVersion(version uint16, suite *cipherSuite) (prfFunc, crypto.Hash) {
switch version {
case VersionTLS10, VersionTLS11:
return prf10, crypto.Hash(0)
case VersionTLS12:
if suite.flags&suiteSHA384 != 0 {
return prf12(sha512.New384), crypto.SHA384
}
return prf12(sha256.New), crypto.SHA256
default:
panic("unknown version")
}
}
func prfForVersion(version uint16, suite *cipherSuite) prfFunc {
prf, _ := prfAndHashForVersion(version, suite)
return prf
}
// masterFromPreMasterSecret generates the master secret from the pre-master
// secret. See RFC 5246, Section 8.1.
func masterFromPreMasterSecret(version uint16, suite *cipherSuite, preMasterSecret, clientRandom, serverRandom []byte) []byte {
seed := make([]byte, 0, len(clientRandom)+len(serverRandom))
seed = append(seed, clientRandom...)
seed = append(seed, serverRandom...)
return prfForVersion(version, suite)(preMasterSecret, masterSecretLabel, seed, masterSecretLength)
}
// extMasterFromPreMasterSecret generates the extended master secret from the
// pre-master secret. See RFC 7627.
func extMasterFromPreMasterSecret(version uint16, suite *cipherSuite, preMasterSecret, transcript []byte) []byte {
prf, hash := prfAndHashForVersion(version, suite)
if version == VersionTLS12 {
// Use the FIPS 140-3 module only for TLS 1.2 with EMS, which is the
// only TLS 1.0-1.2 approved mode per IG D.Q.
return tls12.MasterSecret(hash.New, preMasterSecret, transcript)
}
return prf(preMasterSecret, extendedMasterSecretLabel, transcript, masterSecretLength)
}
// keysFromMasterSecret generates the connection keys from the master
// secret, given the lengths of the MAC key, cipher key and IV, as defined in
// RFC 2246, Section 6.3.
func keysFromMasterSecret(version uint16, suite *cipherSuite, masterSecret, clientRandom, serverRandom []byte, macLen, keyLen, ivLen int) (clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV []byte) {
seed := make([]byte, 0, len(serverRandom)+len(clientRandom))
seed = append(seed, serverRandom...)
seed = append(seed, clientRandom...)
n := 2*macLen + 2*keyLen + 2*ivLen
keyMaterial := prfForVersion(version, suite)(masterSecret, keyExpansionLabel, seed, n)
clientMAC = keyMaterial[:macLen]
keyMaterial = keyMaterial[macLen:]
serverMAC = keyMaterial[:macLen]
keyMaterial = keyMaterial[macLen:]
clientKey = keyMaterial[:keyLen]
keyMaterial = keyMaterial[keyLen:]
serverKey = keyMaterial[:keyLen]
keyMaterial = keyMaterial[keyLen:]
clientIV = keyMaterial[:ivLen]
keyMaterial = keyMaterial[ivLen:]
serverIV = keyMaterial[:ivLen]
return
}
func newFinishedHash(version uint16, cipherSuite *cipherSuite) finishedHash {
var buffer []byte
if version >= VersionTLS12 {
buffer = []byte{}
}
prf, hash := prfAndHashForVersion(version, cipherSuite)
if hash != 0 {
return finishedHash{hash.New(), hash.New(), nil, nil, buffer, version, prf}
}
return finishedHash{sha1.New(), sha1.New(), md5.New(), md5.New(), buffer, version, prf}
}
// A finishedHash calculates the hash of a set of handshake messages suitable
// for including in a Finished message.
type finishedHash struct {
client hash.Hash
server hash.Hash
// Prior to TLS 1.2, an additional MD5 hash is required.
clientMD5 hash.Hash
serverMD5 hash.Hash
// In TLS 1.2, a full buffer is sadly required.
buffer []byte
version uint16
prf prfFunc
}
func (h *finishedHash) Write(msg []byte) (n int, err error) {
h.client.Write(msg)
h.server.Write(msg)
if h.version < VersionTLS12 {
h.clientMD5.Write(msg)
h.serverMD5.Write(msg)
}
if h.buffer != nil {
h.buffer = append(h.buffer, msg...)
}
return len(msg), nil
}
func (h finishedHash) Sum() []byte {
if h.version >= VersionTLS12 {
return h.client.Sum(nil)
}
out := make([]byte, 0, md5.Size+sha1.Size)
out = h.clientMD5.Sum(out)
return h.client.Sum(out)
}
// clientSum returns the contents of the verify_data member of a client's
// Finished message.
func (h finishedHash) clientSum(masterSecret []byte) []byte {
return h.prf(masterSecret, clientFinishedLabel, h.Sum(), finishedVerifyLength)
}
// serverSum returns the contents of the verify_data member of a server's
// Finished message.
func (h finishedHash) serverSum(masterSecret []byte) []byte {
return h.prf(masterSecret, serverFinishedLabel, h.Sum(), finishedVerifyLength)
}
// hashForClientCertificate returns the handshake messages so far, pre-hashed if
// necessary, suitable for signing by a TLS client certificate.
func (h finishedHash) hashForClientCertificate(sigType uint8, hashAlg crypto.Hash) []byte {
if (h.version >= VersionTLS12 || sigType == signatureEd25519) && h.buffer == nil {
panic("tls: handshake hash for a client certificate requested after discarding the handshake buffer")
}
if sigType == signatureEd25519 {
return h.buffer
}
if h.version >= VersionTLS12 {
hash := hashAlg.New()
hash.Write(h.buffer)
return hash.Sum(nil)
}
if sigType == signatureECDSA {
return h.server.Sum(nil)
}
return h.Sum()
}
// discardHandshakeBuffer is called when there is no more need to
// buffer the entirety of the handshake messages.
func (h *finishedHash) discardHandshakeBuffer() {
h.buffer = nil
}
// noEKMBecauseRenegotiation is used as a value of
// ConnectionState.ekm when renegotiation is enabled and thus
// we wish to fail all key-material export requests.
func noEKMBecauseRenegotiation(label string, context []byte, length int) ([]byte, error) {
return nil, errors.New("crypto/tls: ExportKeyingMaterial is unavailable when renegotiation is enabled")
}
// noEKMBecauseNoEMS is used as a value of ConnectionState.ekm when Extended
// Master Secret is not negotiated and thus we wish to fail all key-material
// export requests.
func noEKMBecauseNoEMS(label string, context []byte, length int) ([]byte, error) {
return nil, errors.New("crypto/tls: ExportKeyingMaterial is unavailable when neither TLS 1.3 nor Extended Master Secret are negotiated; override with GODEBUG=tlsunsafeekm=1")
}
// ekmFromMasterSecret generates exported keying material as defined in RFC 5705.
func ekmFromMasterSecret(version uint16, suite *cipherSuite, masterSecret, clientRandom, serverRandom []byte) func(string, []byte, int) ([]byte, error) {
return func(label string, context []byte, length int) ([]byte, error) {
switch label {
case "client finished", "server finished", "master secret", "key expansion":
// These values are reserved and may not be used.
return nil, fmt.Errorf("crypto/tls: reserved ExportKeyingMaterial label: %s", label)
}
seedLen := len(serverRandom) + len(clientRandom)
if context != nil {
seedLen += 2 + len(context)
}
seed := make([]byte, 0, seedLen)
seed = append(seed, clientRandom...)
seed = append(seed, serverRandom...)
if context != nil {
if len(context) >= 1<<16 {
return nil, fmt.Errorf("crypto/tls: ExportKeyingMaterial context too long")
}
seed = append(seed, byte(len(context)>>8), byte(len(context)))
seed = append(seed, context...)
}
return prfForVersion(version, suite)(masterSecret, label, seed, length), nil
}
}
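// At the public API level, RFC 5705 exporters are reached through
// [ConnectionState.ExportKeyingMaterial]. A hedged sketch, assuming conn is an
// established *Conn that negotiated TLS 1.3 or Extended Master Secret:
//
//	cs := conn.ConnectionState()
//	ekm, err := cs.ExportKeyingMaterial("EXPORTER-example-usage", nil, 32)
//	if err != nil {
//		// Renegotiation enabled or EMS not negotiated; see the errors above.
//	}
//	_ = ekm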
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tls
import (
"context"
"errors"
"fmt"
)
// QUICEncryptionLevel represents a QUIC encryption level used to transmit
// handshake messages.
type QUICEncryptionLevel int
const (
QUICEncryptionLevelInitial = QUICEncryptionLevel(iota)
QUICEncryptionLevelEarly
QUICEncryptionLevelHandshake
QUICEncryptionLevelApplication
)
func (l QUICEncryptionLevel) String() string {
switch l {
case QUICEncryptionLevelInitial:
return "Initial"
case QUICEncryptionLevelEarly:
return "Early"
case QUICEncryptionLevelHandshake:
return "Handshake"
case QUICEncryptionLevelApplication:
return "Application"
default:
return fmt.Sprintf("QUICEncryptionLevel(%v)", int(l))
}
}
// A QUICConn represents a connection which uses a QUIC implementation as the underlying
// transport as described in RFC 9001.
//
// Methods of QUICConn are not safe for concurrent use.
type QUICConn struct {
conn *Conn
sessionTicketSent bool
}
// A QUICConfig configures a [QUICConn].
type QUICConfig struct {
TLSConfig *Config
// EnableSessionEvents may be set to true to enable the
// [QUICStoreSession] and [QUICResumeSession] events for client connections.
// When these events are enabled, sessions are not automatically
// stored in the client session cache.
// The application should use [QUICConn.StoreSession] to store sessions.
EnableSessionEvents bool
}
// A QUICEventKind is a type of operation on a QUIC connection.
type QUICEventKind int
const (
// QUICNoEvent indicates that there are no events available.
QUICNoEvent QUICEventKind = iota
// QUICSetReadSecret and QUICSetWriteSecret provide the read and write
// secrets for a given encryption level.
// QUICEvent.Level, QUICEvent.Data, and QUICEvent.Suite are set.
//
// Secrets for the Initial encryption level are derived from the initial
// destination connection ID, and are not provided by the QUICConn.
QUICSetReadSecret
QUICSetWriteSecret
// QUICWriteData provides data to send to the peer in CRYPTO frames.
// QUICEvent.Data is set.
QUICWriteData
// QUICTransportParameters provides the peer's QUIC transport parameters.
// QUICEvent.Data is set.
QUICTransportParameters
// QUICTransportParametersRequired indicates that the caller must provide
// QUIC transport parameters to send to the peer. The caller should set
// the transport parameters with QUICConn.SetTransportParameters and call
// QUICConn.NextEvent again.
//
// If transport parameters are set before calling QUICConn.Start, the
// connection will never generate a QUICTransportParametersRequired event.
QUICTransportParametersRequired
// QUICRejectedEarlyData indicates that the server rejected the 0-RTT data we
// offered. It is returned before the QUICEncryptionLevelApplication keys are
// returned.
// This event only occurs on client connections.
QUICRejectedEarlyData
// QUICHandshakeDone indicates that the TLS handshake has completed.
QUICHandshakeDone
// QUICResumeSession indicates that a client is attempting to resume a previous session.
// [QUICEvent.SessionState] is set.
//
// For client connections, this event occurs when the session ticket is selected.
// For server connections, this event occurs when receiving the client's session ticket.
//
// The application may set [QUICEvent.SessionState.EarlyData] to false before the
// next call to [QUICConn.NextEvent] to decline 0-RTT even if the session supports it.
QUICResumeSession
// QUICStoreSession indicates that the server has provided state permitting
// the client to resume the session.
// [QUICEvent.SessionState] is set.
// The application should use [QUICConn.StoreSession] to store the [SessionState].
// The application may modify the [SessionState] before storing it.
// This event only occurs on client connections.
QUICStoreSession
)
// A QUICEvent is an event occurring on a QUIC connection.
//
// The type of event is specified by the Kind field.
// The contents of the other fields are kind-specific.
type QUICEvent struct {
Kind QUICEventKind
// Set for QUICSetReadSecret, QUICSetWriteSecret, and QUICWriteData.
Level QUICEncryptionLevel
// Set for QUICTransportParameters, QUICSetReadSecret, QUICSetWriteSecret, and QUICWriteData.
// The contents are owned by crypto/tls, and are valid until the next NextEvent call.
Data []byte
// Set for QUICSetReadSecret and QUICSetWriteSecret.
Suite uint16
// Set for QUICResumeSession and QUICStoreSession.
SessionState *SessionState
}
type quicState struct {
events []QUICEvent
nextEvent int
// eventArr is a statically allocated event array, large enough to handle
// the usual maximum number of events resulting from a single call: transport
// parameters, Initial data, Early read secret, Handshake write and read
// secrets, Handshake data, Application write secret, Application data.
eventArr [8]QUICEvent
started bool
signalc chan struct{} // handshake data is available to be read
blockedc chan struct{} // handshake is waiting for data, closed when done
cancelc <-chan struct{} // handshake has been canceled
cancel context.CancelFunc
waitingForDrain bool
// readbuf is shared between HandleData and the handshake goroutine.
// HandshakeCryptoData passes ownership to the handshake goroutine by
// reading from signalc, and reclaims ownership by reading from blockedc.
readbuf []byte
transportParams []byte // to send to the peer
enableSessionEvents bool
}
// QUICClient returns a new TLS client side connection using a QUIC
// implementation as the underlying transport. The config cannot be nil.
//
// The config's MinVersion must be at least TLS 1.3.
func QUICClient(config *QUICConfig) *QUICConn {
return newQUICConn(Client(nil, config.TLSConfig), config)
}
// QUICServer returns a new TLS server side connection using a QUIC
// implementation as the underlying transport. The config cannot be nil.
//
// The config's MinVersion must be at least TLS 1.3.
func QUICServer(config *QUICConfig) *QUICConn {
return newQUICConn(Server(nil, config.TLSConfig), config)
}
func newQUICConn(conn *Conn, config *QUICConfig) *QUICConn {
conn.quic = &quicState{
signalc: make(chan struct{}),
blockedc: make(chan struct{}),
enableSessionEvents: config.EnableSessionEvents,
}
conn.quic.events = conn.quic.eventArr[:0]
return &QUICConn{
conn: conn,
}
}
// Start starts the client or server handshake protocol.
// It may produce connection events, which may be read with [QUICConn.NextEvent].
//
// Start must be called at most once.
func (q *QUICConn) Start(ctx context.Context) error {
if q.conn.quic.started {
return quicError(errors.New("tls: Start called more than once"))
}
q.conn.quic.started = true
if q.conn.config.MinVersion < VersionTLS13 {
return quicError(errors.New("tls: Config MinVersion must be at least TLS 1.3"))
}
go q.conn.HandshakeContext(ctx)
if _, ok := <-q.conn.quic.blockedc; !ok {
return q.conn.handshakeErr
}
return nil
}
// NextEvent returns the next event occurring on the connection.
// It returns an event with a Kind of [QUICNoEvent] when no events are available.
func (q *QUICConn) NextEvent() QUICEvent {
qs := q.conn.quic
if last := qs.nextEvent - 1; last >= 0 && len(qs.events[last].Data) > 0 {
// Write over some of the previous event's data,
// to catch callers erroneously retaining it.
qs.events[last].Data[0] = 0
}
if qs.nextEvent >= len(qs.events) && qs.waitingForDrain {
qs.waitingForDrain = false
<-qs.signalc
<-qs.blockedc
}
if qs.nextEvent >= len(qs.events) {
qs.events = qs.events[:0]
qs.nextEvent = 0
return QUICEvent{Kind: QUICNoEvent}
}
e := qs.events[qs.nextEvent]
qs.events[qs.nextEvent] = QUICEvent{} // zero out references to data
qs.nextEvent++
return e
}
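// A caller typically drains events after Start, HandleData, SendSessionTicket,
// or SetTransportParameters. A hedged sketch of such a loop, in which
// sendCryptoData and installSecret are hypothetical hooks into a QUIC stack and
// peerParams, localParams, and handshakeDone are the caller's own state:
//
//	for {
//		e := q.NextEvent()
//		switch e.Kind {
//		case QUICNoEvent:
//			return nil
//		case QUICWriteData:
//			sendCryptoData(e.Level, e.Data) // hypothetical
//		case QUICSetReadSecret, QUICSetWriteSecret:
//			installSecret(e.Kind, e.Level, e.Suite, e.Data) // hypothetical
//		case QUICTransportParameters:
//			peerParams = e.Data
//		case QUICTransportParametersRequired:
//			q.SetTransportParameters(localParams)
//		case QUICHandshakeDone:
//			handshakeDone = true
//		}
//	}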
// Close closes the connection and stops any in-progress handshake.
func (q *QUICConn) Close() error {
if q.conn.quic.cancel == nil {
return nil // never started
}
q.conn.quic.cancel()
for range q.conn.quic.blockedc {
// Wait for the handshake goroutine to return.
}
return q.conn.handshakeErr
}
// HandleData handles handshake bytes received from the peer.
// It may produce connection events, which may be read with [QUICConn.NextEvent].
func (q *QUICConn) HandleData(level QUICEncryptionLevel, data []byte) error {
c := q.conn
if c.in.level != level {
return quicError(c.in.setErrorLocked(errors.New("tls: handshake data received at wrong level")))
}
c.quic.readbuf = data
<-c.quic.signalc
_, ok := <-c.quic.blockedc
if ok {
// The handshake goroutine is waiting for more data.
return nil
}
// The handshake goroutine has exited.
c.handshakeMutex.Lock()
defer c.handshakeMutex.Unlock()
c.hand.Write(c.quic.readbuf)
c.quic.readbuf = nil
for q.conn.hand.Len() >= 4 && q.conn.handshakeErr == nil {
b := q.conn.hand.Bytes()
n := int(b[1])<<16 | int(b[2])<<8 | int(b[3])
if n > maxHandshake {
q.conn.handshakeErr = fmt.Errorf("tls: handshake message of length %d bytes exceeds maximum of %d bytes", n, maxHandshake)
break
}
if len(b) < 4+n {
return nil
}
if err := q.conn.handlePostHandshakeMessage(); err != nil {
q.conn.handshakeErr = err
}
}
if q.conn.handshakeErr != nil {
return quicError(q.conn.handshakeErr)
}
return nil
}
type QUICSessionTicketOptions struct {
// EarlyData specifies whether the ticket may be used for 0-RTT.
EarlyData bool
Extra [][]byte
}
// SendSessionTicket sends a session ticket to the client.
// It produces connection events, which may be read with [QUICConn.NextEvent].
// Currently, it can only be called once.
func (q *QUICConn) SendSessionTicket(opts QUICSessionTicketOptions) error {
c := q.conn
if c.config.SessionTicketsDisabled {
return nil
}
if !c.isHandshakeComplete.Load() {
return quicError(errors.New("tls: SendSessionTicket called before handshake completed"))
}
if c.isClient {
return quicError(errors.New("tls: SendSessionTicket called on the client"))
}
if q.sessionTicketSent {
return quicError(errors.New("tls: SendSessionTicket called multiple times"))
}
q.sessionTicketSent = true
return quicError(c.sendSessionTicket(opts.EarlyData, opts.Extra))
}
// StoreSession stores a session previously received in a QUICStoreSession event
// in the ClientSessionCache.
// The application may process additional events or modify the SessionState
// before storing the session.
func (q *QUICConn) StoreSession(session *SessionState) error {
c := q.conn
if !c.isClient {
return quicError(errors.New("tls: StoreSessionTicket called on the server"))
}
cacheKey := c.clientSessionCacheKey()
if cacheKey == "" {
return nil
}
cs := &ClientSessionState{session: session}
c.config.ClientSessionCache.Put(cacheKey, cs)
return nil
}
// ConnectionState returns basic TLS details about the connection.
func (q *QUICConn) ConnectionState() ConnectionState {
return q.conn.ConnectionState()
}
// SetTransportParameters sets the transport parameters to send to the peer.
//
// Server connections may delay setting the transport parameters until after
// receiving the client's transport parameters. See [QUICTransportParametersRequired].
func (q *QUICConn) SetTransportParameters(params []byte) {
if params == nil {
params = []byte{}
}
q.conn.quic.transportParams = params
if q.conn.quic.started {
<-q.conn.quic.signalc
<-q.conn.quic.blockedc
}
}
// quicError ensures err is an AlertError.
// If err is not already, quicError wraps it with alertInternalError.
func quicError(err error) error {
if err == nil {
return nil
}
var ae AlertError
if errors.As(err, &ae) {
return err
}
var a alert
if !errors.As(err, &a) {
a = alertInternalError
}
// Return an error wrapping the original error and an AlertError.
// Truncate the text of the alert to 0 characters.
return fmt.Errorf("%w%.0w", err, AlertError(a))
}
func (c *Conn) quicReadHandshakeBytes(n int) error {
for c.hand.Len() < n {
if err := c.quicWaitForSignal(); err != nil {
return err
}
}
return nil
}
func (c *Conn) quicSetReadSecret(level QUICEncryptionLevel, suite uint16, secret []byte) {
c.quic.events = append(c.quic.events, QUICEvent{
Kind: QUICSetReadSecret,
Level: level,
Suite: suite,
Data: secret,
})
}
func (c *Conn) quicSetWriteSecret(level QUICEncryptionLevel, suite uint16, secret []byte) {
c.quic.events = append(c.quic.events, QUICEvent{
Kind: QUICSetWriteSecret,
Level: level,
Suite: suite,
Data: secret,
})
}
func (c *Conn) quicWriteCryptoData(level QUICEncryptionLevel, data []byte) {
var last *QUICEvent
if len(c.quic.events) > 0 {
last = &c.quic.events[len(c.quic.events)-1]
}
if last == nil || last.Kind != QUICWriteData || last.Level != level {
c.quic.events = append(c.quic.events, QUICEvent{
Kind: QUICWriteData,
Level: level,
})
last = &c.quic.events[len(c.quic.events)-1]
}
last.Data = append(last.Data, data...)
}
func (c *Conn) quicResumeSession(session *SessionState) error {
c.quic.events = append(c.quic.events, QUICEvent{
Kind: QUICResumeSession,
SessionState: session,
})
c.quic.waitingForDrain = true
for c.quic.waitingForDrain {
if err := c.quicWaitForSignal(); err != nil {
return err
}
}
return nil
}
func (c *Conn) quicStoreSession(session *SessionState) {
c.quic.events = append(c.quic.events, QUICEvent{
Kind: QUICStoreSession,
SessionState: session,
})
}
func (c *Conn) quicSetTransportParameters(params []byte) {
c.quic.events = append(c.quic.events, QUICEvent{
Kind: QUICTransportParameters,
Data: params,
})
}
func (c *Conn) quicGetTransportParameters() ([]byte, error) {
if c.quic.transportParams == nil {
c.quic.events = append(c.quic.events, QUICEvent{
Kind: QUICTransportParametersRequired,
})
}
for c.quic.transportParams == nil {
if err := c.quicWaitForSignal(); err != nil {
return nil, err
}
}
return c.quic.transportParams, nil
}
func (c *Conn) quicHandshakeComplete() {
c.quic.events = append(c.quic.events, QUICEvent{
Kind: QUICHandshakeDone,
})
}
func (c *Conn) quicRejectedEarlyData() {
c.quic.events = append(c.quic.events, QUICEvent{
Kind: QUICRejectedEarlyData,
})
}
// quicWaitForSignal notifies the QUICConn that handshake progress is blocked,
// and waits for a signal that the handshake should proceed.
//
// The handshake may become blocked waiting for handshake bytes
// or for the user to provide transport parameters.
func (c *Conn) quicWaitForSignal() error {
// Drop the handshake mutex while blocked to allow the user
// to call ConnectionState before the handshake completes.
c.handshakeMutex.Unlock()
defer c.handshakeMutex.Lock()
// Send on blockedc to notify the QUICConn that the handshake is blocked.
// Exported methods of QUICConn wait for the handshake to become blocked
// before returning to the user.
select {
case c.quic.blockedc <- struct{}{}:
case <-c.quic.cancelc:
return c.sendAlertLocked(alertCloseNotify)
}
// The QUICConn reads from signalc to notify us that the handshake may
// be able to proceed. (The QUICConn reads, because we close signalc to
// indicate that the handshake has completed.)
select {
case c.quic.signalc <- struct{}{}:
c.hand.Write(c.quic.readbuf)
c.quic.readbuf = nil
case <-c.quic.cancelc:
return c.sendAlertLocked(alertCloseNotify)
}
return nil
}
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tls
import (
"crypto/aes"
"crypto/cipher"
"crypto/hmac"
"crypto/sha256"
"crypto/subtle"
"crypto/x509"
"errors"
"io"
"golang.org/x/crypto/cryptobyte"
)
// A SessionState is a resumable session.
type SessionState struct {
// Encoded as a SessionState (in the language of RFC 8446, Section 3).
//
// enum { server(1), client(2) } SessionStateType;
//
// opaque Certificate<1..2^24-1>;
//
// Certificate CertificateChain<0..2^24-1>;
//
// opaque Extra<0..2^24-1>;
//
// struct {
// uint16 version;
// SessionStateType type;
// uint16 cipher_suite;
// uint64 created_at;
// opaque secret<1..2^8-1>;
// Extra extra<0..2^24-1>;
// uint8 ext_master_secret = { 0, 1 };
// uint8 early_data = { 0, 1 };
// CertificateEntry certificate_list<0..2^24-1>;
// CertificateChain verified_chains<0..2^24-1>; /* excluding leaf */
// select (SessionState.early_data) {
// case 0: Empty;
// case 1: opaque alpn<1..2^8-1>;
// };
// select (SessionState.version) {
// case VersionTLS10..VersionTLS12: uint16 curve_id;
// case VersionTLS13: select (SessionState.type) {
// case server: Empty;
// case client: struct {
// uint64 use_by;
// uint32 age_add;
// };
// };
// };
// } SessionState;
//
// The format can be extended backwards-compatibly by adding new fields at
// the end. Otherwise, a new SessionStateType must be used, as different Go
// versions may share the same session ticket encryption key.
// Extra is ignored by crypto/tls, but is encoded by [SessionState.Bytes]
// and parsed by [ParseSessionState].
//
// This allows [Config.UnwrapSession]/[Config.WrapSession] and
// [ClientSessionCache] implementations to store and retrieve additional
// data alongside this session.
//
// To allow different layers in a protocol stack to share this field,
// applications must only append to it, not replace it, and must use entries
// that can be recognized even if out of order (for example, by starting
// with an id and version prefix).
Extra [][]byte
// EarlyData indicates whether the ticket can be used for 0-RTT in a QUIC
// connection. The application may set this to false if it is true to
// decline to offer 0-RTT even if supported.
EarlyData bool
version uint16
isClient bool
cipherSuite uint16
// createdAt is the generation time of the secret on the server (which for
// TLS 1.0–1.2 might be earlier than the current session) and the time at
// which the ticket was received on the client.
createdAt uint64 // seconds since UNIX epoch
secret []byte // master secret for TLS 1.2, or the PSK for TLS 1.3
extMasterSecret bool
peerCertificates []*x509.Certificate
ocspResponse []byte
scts [][]byte
verifiedChains [][]*x509.Certificate
alpnProtocol string // only set if EarlyData is true
// Client-side TLS 1.3-only fields.
useBy uint64 // seconds since UNIX epoch
ageAdd uint32
ticket []byte
// TLS 1.0–1.2 only fields.
curveID CurveID
}
// Bytes encodes the session, including any private fields, so that it can be
// parsed by [ParseSessionState]. The encoding contains secret values critical
// to the security of future and possibly past sessions.
//
// The specific encoding should be considered opaque and may change incompatibly
// between Go versions.
func (s *SessionState) Bytes() ([]byte, error) {
var b cryptobyte.Builder
b.AddUint16(s.version)
if s.isClient {
b.AddUint8(2) // client
} else {
b.AddUint8(1) // server
}
b.AddUint16(s.cipherSuite)
addUint64(&b, s.createdAt)
b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes(s.secret)
})
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
for _, extra := range s.Extra {
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes(extra)
})
}
})
if s.extMasterSecret {
b.AddUint8(1)
} else {
b.AddUint8(0)
}
if s.EarlyData {
b.AddUint8(1)
} else {
b.AddUint8(0)
}
marshalCertificate(&b, Certificate{
Certificate: certificatesToBytesSlice(s.peerCertificates),
OCSPStaple: s.ocspResponse,
SignedCertificateTimestamps: s.scts,
})
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
for _, chain := range s.verifiedChains {
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
// We elide the first certificate because it's always the leaf.
if len(chain) == 0 {
b.SetError(errors.New("tls: internal error: empty verified chain"))
return
}
for _, cert := range chain[1:] {
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes(cert.Raw)
})
}
})
}
})
if s.EarlyData {
b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes([]byte(s.alpnProtocol))
})
}
if s.version >= VersionTLS13 {
if s.isClient {
addUint64(&b, s.useBy)
b.AddUint32(s.ageAdd)
}
} else {
b.AddUint16(uint16(s.curveID))
}
return b.Bytes()
}
func certificatesToBytesSlice(certs []*x509.Certificate) [][]byte {
s := make([][]byte, 0, len(certs))
for _, c := range certs {
s = append(s, c.Raw)
}
return s
}
// ParseSessionState parses a [SessionState] encoded by [SessionState.Bytes].
func ParseSessionState(data []byte) (*SessionState, error) {
ss := &SessionState{}
s := cryptobyte.String(data)
var typ, extMasterSecret, earlyData uint8
var cert Certificate
var extra cryptobyte.String
if !s.ReadUint16(&ss.version) ||
!s.ReadUint8(&typ) ||
!s.ReadUint16(&ss.cipherSuite) ||
!readUint64(&s, &ss.createdAt) ||
!readUint8LengthPrefixed(&s, &ss.secret) ||
!s.ReadUint24LengthPrefixed(&extra) ||
!s.ReadUint8(&extMasterSecret) ||
!s.ReadUint8(&earlyData) ||
len(ss.secret) == 0 ||
!unmarshalCertificate(&s, &cert) {
return nil, errors.New("tls: invalid session encoding")
}
for !extra.Empty() {
var e []byte
if !readUint24LengthPrefixed(&extra, &e) {
return nil, errors.New("tls: invalid session encoding")
}
ss.Extra = append(ss.Extra, e)
}
switch typ {
case 1:
ss.isClient = false
case 2:
ss.isClient = true
default:
return nil, errors.New("tls: unknown session encoding")
}
switch extMasterSecret {
case 0:
ss.extMasterSecret = false
case 1:
ss.extMasterSecret = true
default:
return nil, errors.New("tls: invalid session encoding")
}
switch earlyData {
case 0:
ss.EarlyData = false
case 1:
ss.EarlyData = true
default:
return nil, errors.New("tls: invalid session encoding")
}
for _, cert := range cert.Certificate {
c, err := globalCertCache.newCert(cert)
if err != nil {
return nil, err
}
ss.peerCertificates = append(ss.peerCertificates, c)
}
if ss.isClient && len(ss.peerCertificates) == 0 {
return nil, errors.New("tls: no server certificates in client session")
}
ss.ocspResponse = cert.OCSPStaple
ss.scts = cert.SignedCertificateTimestamps
var chainList cryptobyte.String
if !s.ReadUint24LengthPrefixed(&chainList) {
return nil, errors.New("tls: invalid session encoding")
}
for !chainList.Empty() {
var certList cryptobyte.String
if !chainList.ReadUint24LengthPrefixed(&certList) {
return nil, errors.New("tls: invalid session encoding")
}
var chain []*x509.Certificate
if len(ss.peerCertificates) == 0 {
return nil, errors.New("tls: invalid session encoding")
}
chain = append(chain, ss.peerCertificates[0])
for !certList.Empty() {
var cert []byte
if !readUint24LengthPrefixed(&certList, &cert) {
return nil, errors.New("tls: invalid session encoding")
}
c, err := globalCertCache.newCert(cert)
if err != nil {
return nil, err
}
chain = append(chain, c)
}
ss.verifiedChains = append(ss.verifiedChains, chain)
}
if ss.EarlyData {
var alpn []byte
if !readUint8LengthPrefixed(&s, &alpn) {
return nil, errors.New("tls: invalid session encoding")
}
ss.alpnProtocol = string(alpn)
}
if ss.version >= VersionTLS13 {
if ss.isClient {
if !s.ReadUint64(&ss.useBy) || !s.ReadUint32(&ss.ageAdd) {
return nil, errors.New("tls: invalid session encoding")
}
}
} else {
if !s.ReadUint16((*uint16)(&ss.curveID)) {
return nil, errors.New("tls: invalid session encoding")
}
}
return ss, nil
}
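// Bytes and ParseSessionState are designed to round-trip, which is how
// [ClientSessionCache] or [Config.WrapSession] implementations can persist
// sessions outside the process. A hedged sketch, where ss is a *SessionState
// obtained from a QUICStoreSession event or from ResumptionState:
//
//	data, err := ss.Bytes()
//	if err != nil {
//		return err
//	}
//	// data contains key material and must be stored confidentially.
//	restored, err := ParseSessionState(data)
//	if err != nil {
//		return err
//	}
//	_ = restored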
// sessionState returns a partially filled-out [SessionState] with information
// from the current connection.
func (c *Conn) sessionState() *SessionState {
return &SessionState{
version: c.vers,
cipherSuite: c.cipherSuite,
createdAt: uint64(c.config.time().Unix()),
alpnProtocol: c.clientProtocol,
peerCertificates: c.peerCertificates,
ocspResponse: c.ocspResponse,
scts: c.scts,
isClient: c.isClient,
extMasterSecret: c.extMasterSecret,
verifiedChains: c.verifiedChains,
curveID: c.curveID,
}
}
// EncryptTicket encrypts a ticket with the [Config]'s configured (or default)
// session ticket keys. It can be used as a [Config.WrapSession] implementation.
func (c *Config) EncryptTicket(cs ConnectionState, ss *SessionState) ([]byte, error) {
ticketKeys := c.ticketKeys(nil)
stateBytes, err := ss.Bytes()
if err != nil {
return nil, err
}
return c.encryptTicket(stateBytes, ticketKeys)
}
func (c *Config) encryptTicket(state []byte, ticketKeys []ticketKey) ([]byte, error) {
if len(ticketKeys) == 0 {
return nil, errors.New("tls: internal error: session ticket keys unavailable")
}
encrypted := make([]byte, aes.BlockSize+len(state)+sha256.Size)
iv := encrypted[:aes.BlockSize]
ciphertext := encrypted[aes.BlockSize : len(encrypted)-sha256.Size]
authenticated := encrypted[:len(encrypted)-sha256.Size]
macBytes := encrypted[len(encrypted)-sha256.Size:]
if _, err := io.ReadFull(c.rand(), iv); err != nil {
return nil, err
}
key := ticketKeys[0]
block, err := aes.NewCipher(key.aesKey[:])
if err != nil {
return nil, errors.New("tls: failed to create cipher while encrypting ticket: " + err.Error())
}
cipher.NewCTR(block, iv).XORKeyStream(ciphertext, state)
mac := hmac.New(sha256.New, key.hmacKey[:])
mac.Write(authenticated)
mac.Sum(macBytes[:0])
return encrypted, nil
}
// DecryptTicket decrypts a ticket encrypted by [Config.EncryptTicket]. It can
// be used as a [Config.UnwrapSession] implementation.
//
// If the ticket can't be decrypted or parsed, DecryptTicket returns (nil, nil).
func (c *Config) DecryptTicket(identity []byte, cs ConnectionState) (*SessionState, error) {
ticketKeys := c.ticketKeys(nil)
stateBytes := c.decryptTicket(identity, ticketKeys)
if stateBytes == nil {
return nil, nil
}
s, err := ParseSessionState(stateBytes)
if err != nil {
return nil, nil // drop unparsable tickets on the floor
}
return s, nil
}
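// EncryptTicket and DecryptTicket line up with the [Config.WrapSession] and
// [Config.UnwrapSession] hooks. A hedged sketch that simply defers to the
// default ticket keys, with room for additional policy in either direction:
//
//	cfg.WrapSession = func(cs ConnectionState, ss *SessionState) ([]byte, error) {
//		return cfg.EncryptTicket(cs, ss)
//	}
//	cfg.UnwrapSession = func(identity []byte, cs ConnectionState) (*SessionState, error) {
//		return cfg.DecryptTicket(identity, cs)
//	}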
func (c *Config) decryptTicket(encrypted []byte, ticketKeys []ticketKey) []byte {
if len(encrypted) < aes.BlockSize+sha256.Size {
return nil
}
iv := encrypted[:aes.BlockSize]
ciphertext := encrypted[aes.BlockSize : len(encrypted)-sha256.Size]
authenticated := encrypted[:len(encrypted)-sha256.Size]
macBytes := encrypted[len(encrypted)-sha256.Size:]
for _, key := range ticketKeys {
mac := hmac.New(sha256.New, key.hmacKey[:])
mac.Write(authenticated)
expected := mac.Sum(nil)
if subtle.ConstantTimeCompare(macBytes, expected) != 1 {
continue
}
block, err := aes.NewCipher(key.aesKey[:])
if err != nil {
return nil
}
plaintext := make([]byte, len(ciphertext))
cipher.NewCTR(block, iv).XORKeyStream(plaintext, ciphertext)
return plaintext
}
return nil
}
// ClientSessionState contains the state needed by a client to
// resume a previous TLS session.
type ClientSessionState struct {
session *SessionState
}
// ResumptionState returns the session ticket sent by the server (also known as
// the session's identity) and the state necessary to resume this session.
//
// It can be called by [ClientSessionCache.Put] to serialize (with
// [SessionState.Bytes]) and store the session.
func (cs *ClientSessionState) ResumptionState() (ticket []byte, state *SessionState, err error) {
if cs == nil || cs.session == nil {
return nil, nil, nil
}
return cs.session.ticket, cs.session, nil
}
// NewResumptionState returns a state value that can be returned by
// [ClientSessionCache.Get] to resume a previous session.
//
// state needs to be returned by [ParseSessionState], and the ticket and session
// state must have been returned by [ClientSessionState.ResumptionState].
func NewResumptionState(ticket []byte, state *SessionState) (*ClientSessionState, error) {
state.ticket = ticket
return &ClientSessionState{
session: state,
}, nil
}
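// ResumptionState and NewResumptionState exist so a [ClientSessionCache] can
// serialize sessions rather than holding them in memory. A hedged sketch of
// both directions, where diskCache, save, and load are hypothetical:
//
//	func (c *diskCache) Put(key string, cs *ClientSessionState) {
//		ticket, state, err := cs.ResumptionState()
//		if err != nil || state == nil {
//			return
//		}
//		if data, err := state.Bytes(); err == nil {
//			save(key, ticket, data) // hypothetical persistence
//		}
//	}
//
//	func (c *diskCache) Get(key string) (*ClientSessionState, bool) {
//		ticket, data, ok := load(key) // hypothetical persistence
//		if !ok {
//			return nil, false
//		}
//		state, err := ParseSessionState(data)
//		if err != nil {
//			return nil, false
//		}
//		cs, err := NewResumptionState(ticket, state)
//		return cs, err == nil
//	}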
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package tls partially implements TLS 1.2, as specified in RFC 5246,
// and TLS 1.3, as specified in RFC 8446.
//
// # FIPS 140-3 mode
//
// When the program is in [FIPS 140-3 mode], this package behaves as if only
// SP 800-140C and SP 800-140D approved protocol versions, cipher suites,
// signature algorithms, certificate public key types and sizes, and key
// exchange and derivation algorithms were implemented. Others are silently
// ignored and not negotiated, or rejected. This set may depend on the
// algorithms supported by the FIPS 140-3 Go Cryptographic Module selected with
// GOFIPS140, and may change across Go versions.
//
// [FIPS 140-3 mode]: https://go.dev/doc/security/fips140
package tls
// BUG(agl): The crypto/tls package only implements some countermeasures
// against Lucky13 attacks on CBC-mode encryption, and only on SHA1
// variants. See http://www.isg.rhul.ac.uk/tls/TLStiming.pdf and
// https://www.imperialviolet.org/2013/02/04/luckythirteen.html.
import (
"bytes"
"context"
"crypto"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"errors"
"fmt"
"internal/godebug"
"net"
"os"
"strings"
)
// Server returns a new TLS server side connection
// using conn as the underlying transport.
// The configuration config must be non-nil and must include
// at least one certificate or else set GetCertificate.
func Server(conn net.Conn, config *Config) *Conn {
c := &Conn{
conn: conn,
config: config,
}
c.handshakeFn = c.serverHandshake
return c
}
// Client returns a new TLS client side connection
// using conn as the underlying transport.
// The config cannot be nil: users must set either ServerName or
// InsecureSkipVerify in the config.
func Client(conn net.Conn, config *Config) *Conn {
c := &Conn{
conn: conn,
config: config,
isClient: true,
}
c.handshakeFn = c.clientHandshake
return c
}
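// Server and Client are the constructors to use when the caller already holds
// a net.Conn (for example, one obtained through a proxy or an in-memory pipe).
// A hedged client-side sketch, assuming rawConn is connected to example.com:
//
//	conn := Client(rawConn, &Config{ServerName: "example.com"})
//	if err := conn.HandshakeContext(ctx); err != nil {
//		rawConn.Close()
//		return err
//	}
//	defer conn.Close()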
// A listener implements a network listener (net.Listener) for TLS connections.
type listener struct {
net.Listener
config *Config
}
// Accept waits for and returns the next incoming TLS connection.
// The returned connection is of type *Conn.
func (l *listener) Accept() (net.Conn, error) {
c, err := l.Listener.Accept()
if err != nil {
return nil, err
}
return Server(c, l.config), nil
}
// NewListener creates a Listener which accepts connections from an inner
// Listener and wraps each connection with [Server].
// The configuration config must be non-nil and must include
// at least one certificate or else set GetCertificate.
func NewListener(inner net.Listener, config *Config) net.Listener {
l := new(listener)
l.Listener = inner
l.config = config
return l
}
// Listen creates a TLS listener accepting connections on the
// given network address using net.Listen.
// The configuration config must be non-nil and must include
// at least one certificate or else set GetCertificate.
func Listen(network, laddr string, config *Config) (net.Listener, error) {
// If this condition changes, consider updating http.Server.ServeTLS too.
if config == nil || len(config.Certificates) == 0 &&
config.GetCertificate == nil && config.GetConfigForClient == nil {
return nil, errors.New("tls: neither Certificates, GetCertificate, nor GetConfigForClient set in Config")
}
l, err := net.Listen(network, laddr)
if err != nil {
return nil, err
}
return NewListener(l, config), nil
}
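// A hedged server sketch built on Listen, assuming cert was produced by
// [LoadX509KeyPair] or [X509KeyPair] and handle is a hypothetical
// per-connection handler:
//
//	ln, err := Listen("tcp", ":8443", &Config{Certificates: []Certificate{cert}})
//	if err != nil {
//		return err
//	}
//	for {
//		conn, err := ln.Accept()
//		if err != nil {
//			return err
//		}
//		go handle(conn)
//	}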
type timeoutError struct{}
func (timeoutError) Error() string { return "tls: DialWithDialer timed out" }
func (timeoutError) Timeout() bool { return true }
func (timeoutError) Temporary() bool { return true }
// DialWithDialer connects to the given network address using dialer.Dial and
// then initiates a TLS handshake, returning the resulting TLS connection. Any
// timeout or deadline given in the dialer apply to connection and TLS
// handshake as a whole.
//
// DialWithDialer interprets a nil configuration as equivalent to the zero
// configuration; see the documentation of [Config] for the defaults.
//
// DialWithDialer uses context.Background internally; to specify the context,
// use [Dialer.DialContext] with NetDialer set to the desired dialer.
func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config) (*Conn, error) {
return dial(context.Background(), dialer, network, addr, config)
}
func dial(ctx context.Context, netDialer *net.Dialer, network, addr string, config *Config) (*Conn, error) {
if netDialer.Timeout != 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, netDialer.Timeout)
defer cancel()
}
if !netDialer.Deadline.IsZero() {
var cancel context.CancelFunc
ctx, cancel = context.WithDeadline(ctx, netDialer.Deadline)
defer cancel()
}
rawConn, err := netDialer.DialContext(ctx, network, addr)
if err != nil {
return nil, err
}
colonPos := strings.LastIndex(addr, ":")
if colonPos == -1 {
colonPos = len(addr)
}
hostname := addr[:colonPos]
if config == nil {
config = defaultConfig()
}
// If no ServerName is set, infer the ServerName
// from the hostname we're connecting to.
if config.ServerName == "" {
// Make a copy to avoid polluting argument or default.
c := config.Clone()
c.ServerName = hostname
config = c
}
conn := Client(rawConn, config)
if err := conn.HandshakeContext(ctx); err != nil {
rawConn.Close()
return nil, err
}
return conn, nil
}
// Dial connects to the given network address using net.Dial
// and then initiates a TLS handshake, returning the resulting
// TLS connection.
// Dial interprets a nil configuration as equivalent to
// the zero configuration; see the documentation of Config
// for the defaults.
func Dial(network, addr string, config *Config) (*Conn, error) {
return DialWithDialer(new(net.Dialer), network, addr, config)
}
// Dialer dials TLS connections given a configuration and a Dialer for the
// underlying connection.
type Dialer struct {
// NetDialer is the optional dialer to use for the TLS connections'
// underlying TCP connections.
// A nil NetDialer is equivalent to the net.Dialer zero value.
NetDialer *net.Dialer
// Config is the TLS configuration to use for new connections.
// A nil configuration is equivalent to the zero
// configuration; see the documentation of Config for the
// defaults.
Config *Config
}
// Dial connects to the given network address and initiates a TLS
// handshake, returning the resulting TLS connection.
//
// The returned [Conn], if any, will always be of type *[Conn].
//
// Dial uses context.Background internally; to specify the context,
// use [Dialer.DialContext].
func (d *Dialer) Dial(network, addr string) (net.Conn, error) {
return d.DialContext(context.Background(), network, addr)
}
func (d *Dialer) netDialer() *net.Dialer {
if d.NetDialer != nil {
return d.NetDialer
}
return new(net.Dialer)
}
// DialContext connects to the given network address and initiates a TLS
// handshake, returning the resulting TLS connection.
//
// The provided Context must be non-nil. If the context expires before
// the connection is complete, an error is returned. Once successfully
// connected, any expiration of the context will not affect the
// connection.
//
// The returned [Conn], if any, will always be of type *[Conn].
func (d *Dialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
c, err := dial(ctx, d.netDialer(), network, addr, d.Config)
if err != nil {
// Don't return c (a typed nil) in an interface.
return nil, err
}
return c, nil
}
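// A hedged client sketch using DialContext with an overall deadline (the
// 10-second timeout is an arbitrary example value):
//
//	d := &Dialer{Config: &Config{MinVersion: VersionTLS12}}
//	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
//	defer cancel()
//	conn, err := d.DialContext(ctx, "tcp", "example.com:443")
//	if err != nil {
//		return err
//	}
//	defer conn.Close()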
// LoadX509KeyPair reads and parses a public/private key pair from a pair of
// files. The files must contain PEM encoded data. The certificate file may
// contain intermediate certificates following the leaf certificate to form a
// certificate chain. On successful return, Certificate.Leaf will be populated.
//
// Before Go 1.23 Certificate.Leaf was left nil, and the parsed certificate was
// discarded. This behavior can be re-enabled by setting "x509keypairleaf=0"
// in the GODEBUG environment variable.
func LoadX509KeyPair(certFile, keyFile string) (Certificate, error) {
certPEMBlock, err := os.ReadFile(certFile)
if err != nil {
return Certificate{}, err
}
keyPEMBlock, err := os.ReadFile(keyFile)
if err != nil {
return Certificate{}, err
}
return X509KeyPair(certPEMBlock, keyPEMBlock)
}
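// A hedged sketch of the usual way LoadX509KeyPair feeds a server Config,
// assuming cert.pem and key.pem exist on disk:
//
//	cert, err := LoadX509KeyPair("cert.pem", "key.pem")
//	if err != nil {
//		return err
//	}
//	cfg := &Config{Certificates: []Certificate{cert}}
//	_ = cfg // cert.Leaf is already populated, avoiding a later re-parse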
var x509keypairleaf = godebug.New("x509keypairleaf")
// X509KeyPair parses a public/private key pair from a pair of
// PEM encoded data. On successful return, Certificate.Leaf will be populated.
//
// Before Go 1.23 Certificate.Leaf was left nil, and the parsed certificate was
// discarded. This behavior can be re-enabled by setting "x509keypairleaf=0"
// in the GODEBUG environment variable.
func X509KeyPair(certPEMBlock, keyPEMBlock []byte) (Certificate, error) {
fail := func(err error) (Certificate, error) { return Certificate{}, err }
var cert Certificate
var skippedBlockTypes []string
for {
var certDERBlock *pem.Block
certDERBlock, certPEMBlock = pem.Decode(certPEMBlock)
if certDERBlock == nil {
break
}
if certDERBlock.Type == "CERTIFICATE" {
cert.Certificate = append(cert.Certificate, certDERBlock.Bytes)
} else {
skippedBlockTypes = append(skippedBlockTypes, certDERBlock.Type)
}
}
if len(cert.Certificate) == 0 {
if len(skippedBlockTypes) == 0 {
return fail(errors.New("tls: failed to find any PEM data in certificate input"))
}
if len(skippedBlockTypes) == 1 && strings.HasSuffix(skippedBlockTypes[0], "PRIVATE KEY") {
return fail(errors.New("tls: failed to find certificate PEM data in certificate input, but did find a private key; PEM inputs may have been switched"))
}
return fail(fmt.Errorf("tls: failed to find \"CERTIFICATE\" PEM block in certificate input after skipping PEM blocks of the following types: %v", skippedBlockTypes))
}
skippedBlockTypes = skippedBlockTypes[:0]
var keyDERBlock *pem.Block
for {
keyDERBlock, keyPEMBlock = pem.Decode(keyPEMBlock)
if keyDERBlock == nil {
if len(skippedBlockTypes) == 0 {
return fail(errors.New("tls: failed to find any PEM data in key input"))
}
if len(skippedBlockTypes) == 1 && skippedBlockTypes[0] == "CERTIFICATE" {
return fail(errors.New("tls: found a certificate rather than a key in the PEM for the private key"))
}
return fail(fmt.Errorf("tls: failed to find PEM block with type ending in \"PRIVATE KEY\" in key input after skipping PEM blocks of the following types: %v", skippedBlockTypes))
}
if keyDERBlock.Type == "PRIVATE KEY" || strings.HasSuffix(keyDERBlock.Type, " PRIVATE KEY") {
break
}
skippedBlockTypes = append(skippedBlockTypes, keyDERBlock.Type)
}
// We don't need to parse the public key for TLS, but we do so anyway
// to check that it looks sane and matches the private key.
x509Cert, err := x509.ParseCertificate(cert.Certificate[0])
if err != nil {
return fail(err)
}
if x509keypairleaf.Value() != "0" {
cert.Leaf = x509Cert
} else {
x509keypairleaf.IncNonDefault()
}
cert.PrivateKey, err = parsePrivateKey(keyDERBlock.Bytes)
if err != nil {
return fail(err)
}
switch pub := x509Cert.PublicKey.(type) {
case *rsa.PublicKey:
priv, ok := cert.PrivateKey.(*rsa.PrivateKey)
if !ok {
return fail(errors.New("tls: private key type does not match public key type"))
}
if pub.N.Cmp(priv.N) != 0 {
return fail(errors.New("tls: private key does not match public key"))
}
case *ecdsa.PublicKey:
priv, ok := cert.PrivateKey.(*ecdsa.PrivateKey)
if !ok {
return fail(errors.New("tls: private key type does not match public key type"))
}
if pub.X.Cmp(priv.X) != 0 || pub.Y.Cmp(priv.Y) != 0 {
return fail(errors.New("tls: private key does not match public key"))
}
case ed25519.PublicKey:
priv, ok := cert.PrivateKey.(ed25519.PrivateKey)
if !ok {
return fail(errors.New("tls: private key type does not match public key type"))
}
if !bytes.Equal(priv.Public().(ed25519.PublicKey), pub) {
return fail(errors.New("tls: private key does not match public key"))
}
default:
return fail(errors.New("tls: unknown public key algorithm"))
}
return cert, nil
}
// Attempt to parse the given private key DER block. OpenSSL 0.9.8 generates
// PKCS #1 private keys by default, while OpenSSL 1.0.0 generates PKCS #8 keys.
// OpenSSL ecparam generates SEC1 EC private keys for ECDSA. We try all three.
func parsePrivateKey(der []byte) (crypto.PrivateKey, error) {
if key, err := x509.ParsePKCS1PrivateKey(der); err == nil {
return key, nil
}
if key, err := x509.ParsePKCS8PrivateKey(der); err == nil {
switch key := key.(type) {
case *rsa.PrivateKey, *ecdsa.PrivateKey, ed25519.PrivateKey:
return key, nil
default:
return nil, errors.New("tls: found unknown private key type in PKCS#8 wrapping")
}
}
if key, err := x509.ParseECPrivateKey(der); err == nil {
return key, nil
}
return nil, errors.New("tls: failed to parse private key")
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package x509
import (
"bytes"
"crypto/sha256"
"encoding/pem"
"sync"
)
type sum224 [sha256.Size224]byte
// CertPool is a set of certificates.
type CertPool struct {
byName map[string][]int // cert.RawSubject => index into lazyCerts
// lazyCerts contains funcs that return a certificate,
// lazily parsing/decompressing it as needed.
lazyCerts []lazyCert
// haveSum maps from sum224(cert.Raw) to true. It's used only
// for AddCert duplicate detection, to avoid CertPool.contains
// calls in the AddCert path (because the contains method can
// call getCert and otherwise negate savings from lazy getCert
// funcs).
haveSum map[sum224]bool
// systemPool indicates whether this is a special pool derived from the
// system roots. If it includes additional roots, it requires doing two
// verifications, one using the roots provided by the caller, and one using
// the system platform verifier.
systemPool bool
}
// lazyCert is minimal metadata about a Cert and a func to retrieve it
// in its normal expanded *Certificate form.
type lazyCert struct {
// rawSubject is the Certificate.RawSubject value.
// It's the same as the CertPool.byName key, but in []byte
// form to make CertPool.Subjects (as used by crypto/tls) do
// fewer allocations.
rawSubject []byte
// constraint is a function to run against a candidate chain rooted in this
// certificate. It allows adding arbitrary constraints that are not specified
// in the certificate itself.
constraint func([]*Certificate) error
// getCert returns the certificate.
//
// It is not meant to do network operations or anything else
// where a failure is likely; the func is meant to lazily
// parse/decompress data that is already known to be good. The
// error in the signature is primarily meant for the case where a cert file
// that existed on local disk when the program started up is deleted later,
// before it's read.
getCert func() (*Certificate, error)
}
// NewCertPool returns a new, empty CertPool.
func NewCertPool() *CertPool {
return &CertPool{
byName: make(map[string][]int),
haveSum: make(map[sum224]bool),
}
}
// len returns the number of certs in the set.
// A nil set is a valid empty set.
func (s *CertPool) len() int {
if s == nil {
return 0
}
return len(s.lazyCerts)
}
// cert returns cert index n in s.
func (s *CertPool) cert(n int) (*Certificate, func([]*Certificate) error, error) {
cert, err := s.lazyCerts[n].getCert()
return cert, s.lazyCerts[n].constraint, err
}
// Clone returns a copy of s.
func (s *CertPool) Clone() *CertPool {
p := &CertPool{
byName: make(map[string][]int, len(s.byName)),
lazyCerts: make([]lazyCert, len(s.lazyCerts)),
haveSum: make(map[sum224]bool, len(s.haveSum)),
systemPool: s.systemPool,
}
for k, v := range s.byName {
indexes := make([]int, len(v))
copy(indexes, v)
p.byName[k] = indexes
}
for k := range s.haveSum {
p.haveSum[k] = true
}
copy(p.lazyCerts, s.lazyCerts)
return p
}
// SystemCertPool returns a copy of the system cert pool.
//
// On Unix systems other than macOS the environment variables SSL_CERT_FILE and
// SSL_CERT_DIR can be used to override the system default locations for the SSL
// certificate file and SSL certificate files directory, respectively. The
// latter can be a colon-separated list.
//
// Any mutations to the returned pool are not written to disk and do not affect
// any other pool returned by SystemCertPool.
//
// New changes in the system cert pool might not be reflected in subsequent calls.
func SystemCertPool() (*CertPool, error) {
if sysRoots := systemRootsPool(); sysRoots != nil {
return sysRoots.Clone(), nil
}
return loadSystemRoots()
}
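// A hedged sketch of extending the system roots with a private CA for
// client-side verification, assuming ca.pem exists on disk:
//
//	pool, err := x509.SystemCertPool()
//	if err != nil {
//		pool = x509.NewCertPool()
//	}
//	if pemData, err := os.ReadFile("ca.pem"); err == nil {
//		pool.AppendCertsFromPEM(pemData)
//	}
//	cfg := &tls.Config{RootCAs: pool}
//	_ = cfg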
type potentialParent struct {
cert *Certificate
constraint func([]*Certificate) error
}
// findPotentialParents returns the certificates in s which might have signed
// cert.
func (s *CertPool) findPotentialParents(cert *Certificate) []potentialParent {
if s == nil {
return nil
}
// Consider all candidates whose Subject matches cert's Issuer. When picking
// possible candidates, the list is built in order of match plausibility, to
// save cycles in buildChains:
//   - AKID and SKID match
//   - AKID present, SKID missing / AKID missing, SKID present
//   - AKID and SKID don't match
var matchingKeyID, oneKeyID, mismatchKeyID []potentialParent
for _, c := range s.byName[string(cert.RawIssuer)] {
candidate, constraint, err := s.cert(c)
if err != nil {
continue
}
kidMatch := bytes.Equal(candidate.SubjectKeyId, cert.AuthorityKeyId)
switch {
case kidMatch:
matchingKeyID = append(matchingKeyID, potentialParent{candidate, constraint})
case (len(candidate.SubjectKeyId) == 0 && len(cert.AuthorityKeyId) > 0) ||
(len(candidate.SubjectKeyId) > 0 && len(cert.AuthorityKeyId) == 0):
oneKeyID = append(oneKeyID, potentialParent{candidate, constraint})
default:
mismatchKeyID = append(mismatchKeyID, potentialParent{candidate, constraint})
}
}
found := len(matchingKeyID) + len(oneKeyID) + len(mismatchKeyID)
if found == 0 {
return nil
}
candidates := make([]potentialParent, 0, found)
candidates = append(candidates, matchingKeyID...)
candidates = append(candidates, oneKeyID...)
candidates = append(candidates, mismatchKeyID...)
return candidates
}
func (s *CertPool) contains(cert *Certificate) bool {
if s == nil {
return false
}
return s.haveSum[sha256.Sum224(cert.Raw)]
}
// AddCert adds a certificate to a pool.
func (s *CertPool) AddCert(cert *Certificate) {
if cert == nil {
panic("adding nil Certificate to CertPool")
}
s.addCertFunc(sha256.Sum224(cert.Raw), string(cert.RawSubject), func() (*Certificate, error) {
return cert, nil
}, nil)
}
// addCertFunc adds metadata about a certificate to a pool, along with
// a func to fetch that certificate later when needed.
//
// The rawSubject is Certificate.RawSubject and must be non-empty.
// The getCert func may be called 0 or more times.
func (s *CertPool) addCertFunc(rawSum224 sum224, rawSubject string, getCert func() (*Certificate, error), constraint func([]*Certificate) error) {
if getCert == nil {
panic("getCert can't be nil")
}
// Check that the certificate isn't being added twice.
if s.haveSum[rawSum224] {
return
}
s.haveSum[rawSum224] = true
s.lazyCerts = append(s.lazyCerts, lazyCert{
rawSubject: []byte(rawSubject),
getCert: getCert,
constraint: constraint,
})
s.byName[rawSubject] = append(s.byName[rawSubject], len(s.lazyCerts)-1)
}
// AppendCertsFromPEM attempts to parse a series of PEM encoded certificates.
// It appends any certificates found to s and reports whether any certificates
// were successfully parsed.
//
// On many Linux systems, /etc/ssl/cert.pem will contain the system wide set
// of root CAs in a format suitable for this function.
func (s *CertPool) AppendCertsFromPEM(pemCerts []byte) (ok bool) {
for len(pemCerts) > 0 {
var block *pem.Block
block, pemCerts = pem.Decode(pemCerts)
if block == nil {
break
}
if block.Type != "CERTIFICATE" || len(block.Headers) != 0 {
continue
}
certBytes := block.Bytes
cert, err := ParseCertificate(certBytes)
if err != nil {
continue
}
var lazyCert struct {
sync.Once
v *Certificate
}
s.addCertFunc(sha256.Sum224(cert.Raw), string(cert.RawSubject), func() (*Certificate, error) {
lazyCert.Do(func() {
// This can't fail, as the same bytes were already parsed above.
lazyCert.v, _ = ParseCertificate(certBytes)
certBytes = nil
})
return lazyCert.v, nil
}, nil)
ok = true
}
return ok
}
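// Illustrative sketch (not part of the original source): a common way to
// combine SystemCertPool and AppendCertsFromPEM when building a verification
// pool. The file name "extra-ca.pem" and the os.ReadFile call are assumptions
// for the example only.
//
//	pool, err := SystemCertPool()
//	if err != nil {
//		pool = NewCertPool()
//	}
//	if pemData, readErr := os.ReadFile("extra-ca.pem"); readErr == nil {
//		if !pool.AppendCertsFromPEM(pemData) {
//			// no certificates could be parsed from the file
//		}
//	}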
// Subjects returns a list of the DER-encoded subjects of
// all of the certificates in the pool.
//
// Deprecated: if s was returned by [SystemCertPool], Subjects
// will not include the system roots.
func (s *CertPool) Subjects() [][]byte {
res := make([][]byte, s.len())
for i, lc := range s.lazyCerts {
res[i] = lc.rawSubject
}
return res
}
// Equal reports whether s and other are equal.
func (s *CertPool) Equal(other *CertPool) bool {
if s == nil || other == nil {
return s == other
}
if s.systemPool != other.systemPool || len(s.haveSum) != len(other.haveSum) {
return false
}
for h := range s.haveSum {
if !other.haveSum[h] {
return false
}
}
return true
}
// AddCertWithConstraint adds a certificate to the pool with the additional
// constraint. When Certificate.Verify builds a chain which is rooted by cert,
// it will additionally pass the whole chain to constraint to determine its
// validity. If constraint returns a non-nil error, the chain will be discarded.
// constraint may be called concurrently from multiple goroutines.
func (s *CertPool) AddCertWithConstraint(cert *Certificate, constraint func([]*Certificate) error) {
if cert == nil {
panic("adding nil Certificate to CertPool")
}
s.addCertFunc(sha256.Sum224(cert.Raw), string(cert.RawSubject), func() (*Certificate, error) {
return cert, nil
}, constraint)
}
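// Illustrative sketch (not part of the original source): adding a root with a
// constraint that rejects chains longer than two certificates. The length
// limit and the root variable are arbitrary example values.
//
//	pool := NewCertPool()
//	pool.AddCertWithConstraint(root, func(chain []*Certificate) error {
//		if len(chain) > 2 {
//			return errors.New("chain too long for this root")
//		}
//		return nil
//	})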
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package x509
import (
"bytes"
"encoding/asn1"
"errors"
"math"
"math/big"
"math/bits"
"strconv"
"strings"
)
var (
errInvalidOID = errors.New("invalid oid")
)
// An OID represents an ASN.1 OBJECT IDENTIFIER.
type OID struct {
der []byte
}
// ParseOID parses an Object Identifier string, represented by ASCII numbers separated by dots.
func ParseOID(oid string) (OID, error) {
var o OID
return o, o.unmarshalOIDText(oid)
}
func newOIDFromDER(der []byte) (OID, bool) {
if len(der) == 0 || der[len(der)-1]&0x80 != 0 {
return OID{}, false
}
start := 0
for i, v := range der {
// ITU-T X.690, section 8.19.2:
// The subidentifier shall be encoded in the fewest possible octets,
// that is, the leading octet of the subidentifier shall not have the value 0x80.
if i == start && v == 0x80 {
return OID{}, false
}
if v&0x80 == 0 {
start = i + 1
}
}
return OID{der}, true
}
// OIDFromInts creates a new OID using ints; each integer is a separate component.
func OIDFromInts(oid []uint64) (OID, error) {
if len(oid) < 2 || oid[0] > 2 || (oid[0] < 2 && oid[1] >= 40) {
return OID{}, errInvalidOID
}
length := base128IntLength(oid[0]*40 + oid[1])
for _, v := range oid[2:] {
length += base128IntLength(v)
}
der := make([]byte, 0, length)
der = appendBase128Int(der, oid[0]*40+oid[1])
for _, v := range oid[2:] {
der = appendBase128Int(der, v)
}
return OID{der}, nil
}
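// Illustrative sketch (not part of the original source): building an OID from
// its integer components and formatting it again. 1.2.840.113549 is the
// RSA Data Security arc.
//
//	oid, err := OIDFromInts([]uint64{1, 2, 840, 113549})
//	if err == nil {
//		fmt.Println(oid.String()) // prints "1.2.840.113549"
//	}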
func base128IntLength(n uint64) int {
if n == 0 {
return 1
}
return (bits.Len64(n) + 6) / 7
}
func appendBase128Int(dst []byte, n uint64) []byte {
for i := base128IntLength(n) - 1; i >= 0; i-- {
o := byte(n >> uint(i*7))
o &= 0x7f
if i != 0 {
o |= 0x80
}
dst = append(dst, o)
}
return dst
}
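// For example, the subidentifier 840 spans two 7-bit groups (6 and 72), so
// appendBase128Int(nil, 840) produces the two bytes 0x86 0x48: the high bit
// is set on every output byte except the last.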
func base128BigIntLength(n *big.Int) int {
if n.Cmp(big.NewInt(0)) == 0 {
return 1
}
return (n.BitLen() + 6) / 7
}
func appendBase128BigInt(dst []byte, n *big.Int) []byte {
if n.Cmp(big.NewInt(0)) == 0 {
return append(dst, 0)
}
for i := base128BigIntLength(n) - 1; i >= 0; i-- {
o := byte(big.NewInt(0).Rsh(n, uint(i)*7).Bits()[0])
o &= 0x7f
if i != 0 {
o |= 0x80
}
dst = append(dst, o)
}
return dst
}
// AppendText implements [encoding.TextAppender]
func (o OID) AppendText(b []byte) ([]byte, error) {
return append(b, o.String()...), nil
}
// MarshalText implements [encoding.TextMarshaler]
func (o OID) MarshalText() ([]byte, error) {
return o.AppendText(nil)
}
// UnmarshalText implements [encoding.TextUnmarshaler]
func (o *OID) UnmarshalText(text []byte) error {
return o.unmarshalOIDText(string(text))
}
func (o *OID) unmarshalOIDText(oid string) error {
// (*big.Int).SetString allows +/- signs, but we don't want
// to allow them in the string representation of Object Identifier, so
// reject such encodings.
for _, c := range oid {
isDigit := c >= '0' && c <= '9'
if !isDigit && c != '.' {
return errInvalidOID
}
}
var (
firstNum string
secondNum string
)
var nextComponentExists bool
firstNum, oid, nextComponentExists = strings.Cut(oid, ".")
if !nextComponentExists {
return errInvalidOID
}
secondNum, oid, nextComponentExists = strings.Cut(oid, ".")
var (
first = big.NewInt(0)
second = big.NewInt(0)
)
if _, ok := first.SetString(firstNum, 10); !ok {
return errInvalidOID
}
if _, ok := second.SetString(secondNum, 10); !ok {
return errInvalidOID
}
if first.Cmp(big.NewInt(2)) > 0 || (first.Cmp(big.NewInt(2)) < 0 && second.Cmp(big.NewInt(40)) >= 0) {
return errInvalidOID
}
firstComponent := first.Mul(first, big.NewInt(40))
firstComponent.Add(firstComponent, second)
der := appendBase128BigInt(make([]byte, 0, 32), firstComponent)
for nextComponentExists {
var strNum string
strNum, oid, nextComponentExists = strings.Cut(oid, ".")
b, ok := big.NewInt(0).SetString(strNum, 10)
if !ok {
return errInvalidOID
}
der = appendBase128BigInt(der, b)
}
o.der = der
return nil
}
// AppendBinary implements [encoding.BinaryAppender]
func (o OID) AppendBinary(b []byte) ([]byte, error) {
return append(b, o.der...), nil
}
// MarshalBinary implements [encoding.BinaryMarshaler]
func (o OID) MarshalBinary() ([]byte, error) {
return o.AppendBinary(nil)
}
// UnmarshalBinary implements [encoding.BinaryUnmarshaler]
func (o *OID) UnmarshalBinary(b []byte) error {
oid, ok := newOIDFromDER(bytes.Clone(b))
if !ok {
return errInvalidOID
}
*o = oid
return nil
}
// Equal returns true when oid and other represent the same Object Identifier.
func (oid OID) Equal(other OID) bool {
// There is only one possible DER encoding of
// each unique Object Identifier.
return bytes.Equal(oid.der, other.der)
}
func parseBase128Int(bytes []byte, initOffset int) (ret, offset int, failed bool) {
offset = initOffset
var ret64 int64
for shifted := 0; offset < len(bytes); shifted++ {
// 5 * 7 bits per byte == 35 bits of data
// Thus the representation is either non-minimal or too large for an int32
if shifted == 5 {
failed = true
return
}
ret64 <<= 7
b := bytes[offset]
// integers should be minimally encoded, so the leading octet should
// never be 0x80
if shifted == 0 && b == 0x80 {
failed = true
return
}
ret64 |= int64(b & 0x7f)
offset++
if b&0x80 == 0 {
ret = int(ret64)
// Ensure that the returned value fits in an int on all platforms
if ret64 > math.MaxInt32 {
failed = true
}
return
}
}
failed = true
return
}
// EqualASN1OID returns whether an OID equals an asn1.ObjectIdentifier. If
// asn1.ObjectIdentifier cannot represent the OID specified by oid, because
// a component of OID requires more than 31 bits, it returns false.
func (oid OID) EqualASN1OID(other asn1.ObjectIdentifier) bool {
if len(other) < 2 {
return false
}
v, offset, failed := parseBase128Int(oid.der, 0)
if failed {
// This should never happen, since we've already parsed the OID,
// but just in case.
return false
}
if v < 80 {
a, b := v/40, v%40
if other[0] != a || other[1] != b {
return false
}
} else {
a, b := 2, v-80
if other[0] != a || other[1] != b {
return false
}
}
i := 2
for ; offset < len(oid.der); i++ {
v, offset, failed = parseBase128Int(oid.der, offset)
if failed {
// Again, shouldn't happen, since we've already parsed
// the OID, but better safe than sorry.
return false
}
if i >= len(other) || v != other[i] {
return false
}
}
return i == len(other)
}
// String returns the string representation of the Object Identifier.
func (oid OID) String() string {
var b strings.Builder
b.Grow(32)
const (
valSize = 64 // size in bits of val.
bitsPerByte = 7
maxValSafeShift = (1 << (valSize - bitsPerByte)) - 1
)
var (
start = 0
val = uint64(0)
numBuf = make([]byte, 0, 21)
bigVal *big.Int
overflow bool
)
for i, v := range oid.der {
curVal := v & 0x7F
valEnd := v&0x80 == 0
if valEnd {
if start != 0 {
b.WriteByte('.')
}
}
if !overflow && val > maxValSafeShift {
if bigVal == nil {
bigVal = new(big.Int)
}
bigVal = bigVal.SetUint64(val)
overflow = true
}
if overflow {
bigVal = bigVal.Lsh(bigVal, bitsPerByte).Or(bigVal, big.NewInt(int64(curVal)))
if valEnd {
if start == 0 {
b.WriteString("2.")
bigVal = bigVal.Sub(bigVal, big.NewInt(80))
}
numBuf = bigVal.Append(numBuf, 10)
b.Write(numBuf)
numBuf = numBuf[:0]
val = 0
start = i + 1
overflow = false
}
continue
}
val <<= bitsPerByte
val |= uint64(curVal)
if valEnd {
if start == 0 {
if val < 80 {
b.Write(strconv.AppendUint(numBuf, val/40, 10))
b.WriteByte('.')
b.Write(strconv.AppendUint(numBuf, val%40, 10))
} else {
b.WriteString("2.")
b.Write(strconv.AppendUint(numBuf, val-80, 10))
}
} else {
b.Write(strconv.AppendUint(numBuf, val, 10))
}
val = 0
start = i + 1
}
}
return b.String()
}
func (oid OID) toASN1OID() (asn1.ObjectIdentifier, bool) {
out := make([]int, 0, len(oid.der)+1)
const (
valSize = 31 // amount of usable bits of val for OIDs.
bitsPerByte = 7
maxValSafeShift = (1 << (valSize - bitsPerByte)) - 1
)
val := 0
for _, v := range oid.der {
if val > maxValSafeShift {
return nil, false
}
val <<= bitsPerByte
val |= int(v & 0x7F)
if v&0x80 == 0 {
if len(out) == 0 {
if val < 80 {
out = append(out, val/40)
out = append(out, val%40)
} else {
out = append(out, 2)
out = append(out, val-80)
}
val = 0
continue
}
out = append(out, val)
val = 0
}
}
return out, true
}
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package x509
import (
"bytes"
"crypto/dsa"
"crypto/ecdh"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/elliptic"
"crypto/rsa"
"crypto/x509/pkix"
"encoding/asn1"
"errors"
"fmt"
"internal/godebug"
"math"
"math/big"
"net"
"net/url"
"strconv"
"strings"
"time"
"unicode/utf16"
"unicode/utf8"
"golang.org/x/crypto/cryptobyte"
cryptobyte_asn1 "golang.org/x/crypto/cryptobyte/asn1"
)
// isPrintable reports whether the given b is in the ASN.1 PrintableString set.
// This is a simplified version of encoding/asn1.isPrintable.
func isPrintable(b byte) bool {
return 'a' <= b && b <= 'z' ||
'A' <= b && b <= 'Z' ||
'0' <= b && b <= '9' ||
'\'' <= b && b <= ')' ||
'+' <= b && b <= '/' ||
b == ' ' ||
b == ':' ||
b == '=' ||
b == '?' ||
// This is technically not allowed in a PrintableString.
// However, x509 certificates with wildcard strings don't
// always use the correct string type so we permit it.
b == '*' ||
// This is not technically allowed either. However, not
// only is it relatively common, but there are also a
// handful of CA certificates that contain it. At least
// one of which will not expire until 2027.
b == '&'
}
// parseASN1String parses the ASN.1 string types T61String, PrintableString,
// UTF8String, BMPString, IA5String, and NumericString. This is mostly copied
// from the respective encoding/asn1.parse... methods, rather than just
// increasing the API surface of that package.
func parseASN1String(tag cryptobyte_asn1.Tag, value []byte) (string, error) {
switch tag {
case cryptobyte_asn1.T61String:
// T.61 is a defunct ITU 8-bit character encoding which preceded Unicode.
// T.61 uses a code page layout that _almost_ exactly maps to the code
// page layout of the ISO 8859-1 (Latin-1) character encoding, with the
// exception that a number of characters in Latin-1 are not present
// in T.61.
//
// Instead of mapping which characters are present in Latin-1 but not T.61,
// we just treat these strings as being encoded using Latin-1. This matches
// what most of the world does, including BoringSSL.
buf := make([]byte, 0, len(value))
for _, v := range value {
// All the 1-byte UTF-8 runes map 1-1 with Latin-1.
buf = utf8.AppendRune(buf, rune(v))
}
return string(buf), nil
case cryptobyte_asn1.PrintableString:
for _, b := range value {
if !isPrintable(b) {
return "", errors.New("invalid PrintableString")
}
}
return string(value), nil
case cryptobyte_asn1.UTF8String:
if !utf8.Valid(value) {
return "", errors.New("invalid UTF-8 string")
}
return string(value), nil
case cryptobyte_asn1.Tag(asn1.TagBMPString):
// BMPString uses the defunct UCS-2 16-bit character encoding, which
// covers the Basic Multilingual Plane (BMP). UTF-16 was an extension of
// UCS-2, containing all of the same code points, but also including
// multi-code point characters (by using surrogate code points). We can
// treat a UCS-2 encoded string as a UTF-16 encoded string, as long as
// we reject the UTF-16-specific code points. This matches the
// BoringSSL behavior.
if len(value)%2 != 0 {
return "", errors.New("invalid BMPString")
}
// Strip terminator if present.
if l := len(value); l >= 2 && value[l-1] == 0 && value[l-2] == 0 {
value = value[:l-2]
}
s := make([]uint16, 0, len(value)/2)
for len(value) > 0 {
point := uint16(value[0])<<8 + uint16(value[1])
// Reject UTF-16 code points that are permanently reserved
// noncharacters (0xfffe, 0xffff, and 0xfdd0-0xfdef) and surrogates
// (0xd800-0xdfff).
if point == 0xfffe || point == 0xffff ||
(point >= 0xfdd0 && point <= 0xfdef) ||
(point >= 0xd800 && point <= 0xdfff) {
return "", errors.New("invalid BMPString")
}
s = append(s, point)
value = value[2:]
}
return string(utf16.Decode(s)), nil
case cryptobyte_asn1.IA5String:
s := string(value)
if isIA5String(s) != nil {
return "", errors.New("invalid IA5String")
}
return s, nil
case cryptobyte_asn1.Tag(asn1.TagNumericString):
for _, b := range value {
if !('0' <= b && b <= '9' || b == ' ') {
return "", errors.New("invalid NumericString")
}
}
return string(value), nil
}
return "", fmt.Errorf("unsupported string type: %v", tag)
}
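// Illustrative sketch (not part of the original source): decoding a BMPString
// value with the helper above. The trailing pair of zero bytes is an optional
// terminator and is stripped before decoding.
//
//	s, err := parseASN1String(cryptobyte_asn1.Tag(asn1.TagBMPString),
//		[]byte{0x00, 'H', 0x00, 'i', 0x00, 0x00})
//	// s == "Hi", err == nil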
// parseName parses a DER encoded Name as defined in RFC 5280. We may
// want to export this function in the future for use in crypto/tls.
func parseName(raw cryptobyte.String) (*pkix.RDNSequence, error) {
if !raw.ReadASN1(&raw, cryptobyte_asn1.SEQUENCE) {
return nil, errors.New("x509: invalid RDNSequence")
}
var rdnSeq pkix.RDNSequence
for !raw.Empty() {
var rdnSet pkix.RelativeDistinguishedNameSET
var set cryptobyte.String
if !raw.ReadASN1(&set, cryptobyte_asn1.SET) {
return nil, errors.New("x509: invalid RDNSequence")
}
for !set.Empty() {
var atav cryptobyte.String
if !set.ReadASN1(&atav, cryptobyte_asn1.SEQUENCE) {
return nil, errors.New("x509: invalid RDNSequence: invalid attribute")
}
var attr pkix.AttributeTypeAndValue
if !atav.ReadASN1ObjectIdentifier(&attr.Type) {
return nil, errors.New("x509: invalid RDNSequence: invalid attribute type")
}
var rawValue cryptobyte.String
var valueTag cryptobyte_asn1.Tag
if !atav.ReadAnyASN1(&rawValue, &valueTag) {
return nil, errors.New("x509: invalid RDNSequence: invalid attribute value")
}
var err error
attr.Value, err = parseASN1String(valueTag, rawValue)
if err != nil {
return nil, fmt.Errorf("x509: invalid RDNSequence: invalid attribute value: %s", err)
}
rdnSet = append(rdnSet, attr)
}
rdnSeq = append(rdnSeq, rdnSet)
}
return &rdnSeq, nil
}
func parseAI(der cryptobyte.String) (pkix.AlgorithmIdentifier, error) {
ai := pkix.AlgorithmIdentifier{}
if !der.ReadASN1ObjectIdentifier(&ai.Algorithm) {
return ai, errors.New("x509: malformed OID")
}
if der.Empty() {
return ai, nil
}
var params cryptobyte.String
var tag cryptobyte_asn1.Tag
if !der.ReadAnyASN1Element(&params, &tag) {
return ai, errors.New("x509: malformed parameters")
}
ai.Parameters.Tag = int(tag)
ai.Parameters.FullBytes = params
return ai, nil
}
func parseTime(der *cryptobyte.String) (time.Time, error) {
var t time.Time
switch {
case der.PeekASN1Tag(cryptobyte_asn1.UTCTime):
if !der.ReadASN1UTCTime(&t) {
return t, errors.New("x509: malformed UTCTime")
}
case der.PeekASN1Tag(cryptobyte_asn1.GeneralizedTime):
if !der.ReadASN1GeneralizedTime(&t) {
return t, errors.New("x509: malformed GeneralizedTime")
}
default:
return t, errors.New("x509: unsupported time format")
}
return t, nil
}
func parseValidity(der cryptobyte.String) (time.Time, time.Time, error) {
notBefore, err := parseTime(&der)
if err != nil {
return time.Time{}, time.Time{}, err
}
notAfter, err := parseTime(&der)
if err != nil {
return time.Time{}, time.Time{}, err
}
return notBefore, notAfter, nil
}
func parseExtension(der cryptobyte.String) (pkix.Extension, error) {
var ext pkix.Extension
if !der.ReadASN1ObjectIdentifier(&ext.Id) {
return ext, errors.New("x509: malformed extension OID field")
}
if der.PeekASN1Tag(cryptobyte_asn1.BOOLEAN) {
if !der.ReadASN1Boolean(&ext.Critical) {
return ext, errors.New("x509: malformed extension critical field")
}
}
var val cryptobyte.String
if !der.ReadASN1(&val, cryptobyte_asn1.OCTET_STRING) {
return ext, errors.New("x509: malformed extension value field")
}
ext.Value = val
return ext, nil
}
func parsePublicKey(keyData *publicKeyInfo) (any, error) {
oid := keyData.Algorithm.Algorithm
params := keyData.Algorithm.Parameters
der := cryptobyte.String(keyData.PublicKey.RightAlign())
switch {
case oid.Equal(oidPublicKeyRSA):
// RSA public keys must have a NULL in the parameters.
// See RFC 3279, Section 2.3.1.
if !bytes.Equal(params.FullBytes, asn1.NullBytes) {
return nil, errors.New("x509: RSA key missing NULL parameters")
}
p := &pkcs1PublicKey{N: new(big.Int)}
if !der.ReadASN1(&der, cryptobyte_asn1.SEQUENCE) {
return nil, errors.New("x509: invalid RSA public key")
}
if !der.ReadASN1Integer(p.N) {
return nil, errors.New("x509: invalid RSA modulus")
}
if !der.ReadASN1Integer(&p.E) {
return nil, errors.New("x509: invalid RSA public exponent")
}
if p.N.Sign() <= 0 {
return nil, errors.New("x509: RSA modulus is not a positive number")
}
if p.E <= 0 {
return nil, errors.New("x509: RSA public exponent is not a positive number")
}
pub := &rsa.PublicKey{
E: p.E,
N: p.N,
}
return pub, nil
case oid.Equal(oidPublicKeyECDSA):
paramsDer := cryptobyte.String(params.FullBytes)
namedCurveOID := new(asn1.ObjectIdentifier)
if !paramsDer.ReadASN1ObjectIdentifier(namedCurveOID) {
return nil, errors.New("x509: invalid ECDSA parameters")
}
namedCurve := namedCurveFromOID(*namedCurveOID)
if namedCurve == nil {
return nil, errors.New("x509: unsupported elliptic curve")
}
x, y := elliptic.Unmarshal(namedCurve, der)
if x == nil {
return nil, errors.New("x509: failed to unmarshal elliptic curve point")
}
pub := &ecdsa.PublicKey{
Curve: namedCurve,
X: x,
Y: y,
}
return pub, nil
case oid.Equal(oidPublicKeyEd25519):
// RFC 8410, Section 3
// > For all of the OIDs, the parameters MUST be absent.
if len(params.FullBytes) != 0 {
return nil, errors.New("x509: Ed25519 key encoded with illegal parameters")
}
if len(der) != ed25519.PublicKeySize {
return nil, errors.New("x509: wrong Ed25519 public key size")
}
return ed25519.PublicKey(der), nil
case oid.Equal(oidPublicKeyX25519):
// RFC 8410, Section 3
// > For all of the OIDs, the parameters MUST be absent.
if len(params.FullBytes) != 0 {
return nil, errors.New("x509: X25519 key encoded with illegal parameters")
}
return ecdh.X25519().NewPublicKey(der)
case oid.Equal(oidPublicKeyDSA):
y := new(big.Int)
if !der.ReadASN1Integer(y) {
return nil, errors.New("x509: invalid DSA public key")
}
pub := &dsa.PublicKey{
Y: y,
Parameters: dsa.Parameters{
P: new(big.Int),
Q: new(big.Int),
G: new(big.Int),
},
}
paramsDer := cryptobyte.String(params.FullBytes)
if !paramsDer.ReadASN1(&paramsDer, cryptobyte_asn1.SEQUENCE) ||
!paramsDer.ReadASN1Integer(pub.Parameters.P) ||
!paramsDer.ReadASN1Integer(pub.Parameters.Q) ||
!paramsDer.ReadASN1Integer(pub.Parameters.G) {
return nil, errors.New("x509: invalid DSA parameters")
}
if pub.Y.Sign() <= 0 || pub.Parameters.P.Sign() <= 0 ||
pub.Parameters.Q.Sign() <= 0 || pub.Parameters.G.Sign() <= 0 {
return nil, errors.New("x509: zero or negative DSA parameter")
}
return pub, nil
default:
return nil, errors.New("x509: unknown public key algorithm")
}
}
func parseKeyUsageExtension(der cryptobyte.String) (KeyUsage, error) {
var usageBits asn1.BitString
if !der.ReadASN1BitString(&usageBits) {
return 0, errors.New("x509: invalid key usage")
}
var usage int
for i := 0; i < 9; i++ {
if usageBits.At(i) != 0 {
usage |= 1 << uint(i)
}
}
return KeyUsage(usage), nil
}
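// For example, a KeyUsage BIT STRING with bits 0 and 2 set decodes to
// KeyUsageDigitalSignature | KeyUsageKeyEncipherment.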
func parseBasicConstraintsExtension(der cryptobyte.String) (bool, int, error) {
var isCA bool
if !der.ReadASN1(&der, cryptobyte_asn1.SEQUENCE) {
return false, 0, errors.New("x509: invalid basic constraints")
}
if der.PeekASN1Tag(cryptobyte_asn1.BOOLEAN) {
if !der.ReadASN1Boolean(&isCA) {
return false, 0, errors.New("x509: invalid basic constraints")
}
}
maxPathLen := -1
if der.PeekASN1Tag(cryptobyte_asn1.INTEGER) {
var mpl uint
if !der.ReadASN1Integer(&mpl) || mpl > math.MaxInt {
return false, 0, errors.New("x509: invalid basic constraints")
}
maxPathLen = int(mpl)
}
return isCA, maxPathLen, nil
}
func forEachSAN(der cryptobyte.String, callback func(tag int, data []byte) error) error {
if !der.ReadASN1(&der, cryptobyte_asn1.SEQUENCE) {
return errors.New("x509: invalid subject alternative names")
}
for !der.Empty() {
var san cryptobyte.String
var tag cryptobyte_asn1.Tag
if !der.ReadAnyASN1(&san, &tag) {
return errors.New("x509: invalid subject alternative name")
}
if err := callback(int(tag^0x80), san); err != nil {
return err
}
}
return nil
}
func parseSANExtension(der cryptobyte.String) (dnsNames, emailAddresses []string, ipAddresses []net.IP, uris []*url.URL, err error) {
err = forEachSAN(der, func(tag int, data []byte) error {
switch tag {
case nameTypeEmail:
email := string(data)
if err := isIA5String(email); err != nil {
return errors.New("x509: SAN rfc822Name is malformed")
}
emailAddresses = append(emailAddresses, email)
case nameTypeDNS:
name := string(data)
if err := isIA5String(name); err != nil {
return errors.New("x509: SAN dNSName is malformed")
}
dnsNames = append(dnsNames, string(name))
case nameTypeURI:
uriStr := string(data)
if err := isIA5String(uriStr); err != nil {
return errors.New("x509: SAN uniformResourceIdentifier is malformed")
}
uri, err := url.Parse(uriStr)
if err != nil {
return fmt.Errorf("x509: cannot parse URI %q: %s", uriStr, err)
}
if len(uri.Host) > 0 {
if _, ok := domainToReverseLabels(uri.Host); !ok {
return fmt.Errorf("x509: cannot parse URI %q: invalid domain", uriStr)
}
}
uris = append(uris, uri)
case nameTypeIP:
switch len(data) {
case net.IPv4len, net.IPv6len:
ipAddresses = append(ipAddresses, data)
default:
return errors.New("x509: cannot parse IP address of length " + strconv.Itoa(len(data)))
}
}
return nil
})
return
}
func parseAuthorityKeyIdentifier(e pkix.Extension) ([]byte, error) {
// RFC 5280, Section 4.2.1.1
if e.Critical {
// Conforming CAs MUST mark this extension as non-critical
return nil, errors.New("x509: authority key identifier incorrectly marked critical")
}
val := cryptobyte.String(e.Value)
var akid cryptobyte.String
if !val.ReadASN1(&akid, cryptobyte_asn1.SEQUENCE) {
return nil, errors.New("x509: invalid authority key identifier")
}
if akid.PeekASN1Tag(cryptobyte_asn1.Tag(0).ContextSpecific()) {
if !akid.ReadASN1(&akid, cryptobyte_asn1.Tag(0).ContextSpecific()) {
return nil, errors.New("x509: invalid authority key identifier")
}
return akid, nil
}
return nil, nil
}
func parseExtKeyUsageExtension(der cryptobyte.String) ([]ExtKeyUsage, []asn1.ObjectIdentifier, error) {
var extKeyUsages []ExtKeyUsage
var unknownUsages []asn1.ObjectIdentifier
if !der.ReadASN1(&der, cryptobyte_asn1.SEQUENCE) {
return nil, nil, errors.New("x509: invalid extended key usages")
}
for !der.Empty() {
var eku asn1.ObjectIdentifier
if !der.ReadASN1ObjectIdentifier(&eku) {
return nil, nil, errors.New("x509: invalid extended key usages")
}
if extKeyUsage, ok := extKeyUsageFromOID(eku); ok {
extKeyUsages = append(extKeyUsages, extKeyUsage)
} else {
unknownUsages = append(unknownUsages, eku)
}
}
return extKeyUsages, unknownUsages, nil
}
func parseCertificatePoliciesExtension(der cryptobyte.String) ([]OID, error) {
var oids []OID
seenOIDs := map[string]bool{}
if !der.ReadASN1(&der, cryptobyte_asn1.SEQUENCE) {
return nil, errors.New("x509: invalid certificate policies")
}
for !der.Empty() {
var cp cryptobyte.String
var OIDBytes cryptobyte.String
if !der.ReadASN1(&cp, cryptobyte_asn1.SEQUENCE) || !cp.ReadASN1(&OIDBytes, cryptobyte_asn1.OBJECT_IDENTIFIER) {
return nil, errors.New("x509: invalid certificate policies")
}
if seenOIDs[string(OIDBytes)] {
return nil, errors.New("x509: invalid certificate policies")
}
seenOIDs[string(OIDBytes)] = true
oid, ok := newOIDFromDER(OIDBytes)
if !ok {
return nil, errors.New("x509: invalid certificate policies")
}
oids = append(oids, oid)
}
return oids, nil
}
// isValidIPMask reports whether mask consists of zero or more 1 bits, followed by zero bits.
func isValidIPMask(mask []byte) bool {
seenZero := false
for _, b := range mask {
if seenZero {
if b != 0 {
return false
}
continue
}
switch b {
case 0x00, 0x80, 0xc0, 0xe0, 0xf0, 0xf8, 0xfc, 0xfe:
seenZero = true
case 0xff:
default:
return false
}
}
return true
}
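// For example, the /24 mask {0xff, 0xff, 0xff, 0x00} is accepted, while
// {0xff, 0x00, 0xff, 0x00} is rejected because a 1 bit follows a 0 bit.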
func parseNameConstraintsExtension(out *Certificate, e pkix.Extension) (unhandled bool, err error) {
// RFC 5280, 4.2.1.10
// NameConstraints ::= SEQUENCE {
// permittedSubtrees [0] GeneralSubtrees OPTIONAL,
// excludedSubtrees [1] GeneralSubtrees OPTIONAL }
//
// GeneralSubtrees ::= SEQUENCE SIZE (1..MAX) OF GeneralSubtree
//
// GeneralSubtree ::= SEQUENCE {
// base GeneralName,
// minimum [0] BaseDistance DEFAULT 0,
// maximum [1] BaseDistance OPTIONAL }
//
// BaseDistance ::= INTEGER (0..MAX)
outer := cryptobyte.String(e.Value)
var toplevel, permitted, excluded cryptobyte.String
var havePermitted, haveExcluded bool
if !outer.ReadASN1(&toplevel, cryptobyte_asn1.SEQUENCE) ||
!outer.Empty() ||
!toplevel.ReadOptionalASN1(&permitted, &havePermitted, cryptobyte_asn1.Tag(0).ContextSpecific().Constructed()) ||
!toplevel.ReadOptionalASN1(&excluded, &haveExcluded, cryptobyte_asn1.Tag(1).ContextSpecific().Constructed()) ||
!toplevel.Empty() {
return false, errors.New("x509: invalid NameConstraints extension")
}
if !havePermitted && !haveExcluded || len(permitted) == 0 && len(excluded) == 0 {
// From RFC 5280, Section 4.2.1.10:
// “either the permittedSubtrees field
// or the excludedSubtrees MUST be
// present”
return false, errors.New("x509: empty name constraints extension")
}
getValues := func(subtrees cryptobyte.String) (dnsNames []string, ips []*net.IPNet, emails, uriDomains []string, err error) {
for !subtrees.Empty() {
var seq, value cryptobyte.String
var tag cryptobyte_asn1.Tag
if !subtrees.ReadASN1(&seq, cryptobyte_asn1.SEQUENCE) ||
!seq.ReadAnyASN1(&value, &tag) {
return nil, nil, nil, nil, fmt.Errorf("x509: invalid NameConstraints extension")
}
var (
dnsTag = cryptobyte_asn1.Tag(2).ContextSpecific()
emailTag = cryptobyte_asn1.Tag(1).ContextSpecific()
ipTag = cryptobyte_asn1.Tag(7).ContextSpecific()
uriTag = cryptobyte_asn1.Tag(6).ContextSpecific()
)
switch tag {
case dnsTag:
domain := string(value)
if err := isIA5String(domain); err != nil {
return nil, nil, nil, nil, errors.New("x509: invalid constraint value: " + err.Error())
}
trimmedDomain := domain
if len(trimmedDomain) > 0 && trimmedDomain[0] == '.' {
// constraints can have a leading
// period to exclude the domain
// itself, but that's not valid in a
// normal domain name.
trimmedDomain = trimmedDomain[1:]
}
if _, ok := domainToReverseLabels(trimmedDomain); !ok {
return nil, nil, nil, nil, fmt.Errorf("x509: failed to parse dnsName constraint %q", domain)
}
dnsNames = append(dnsNames, domain)
case ipTag:
l := len(value)
var ip, mask []byte
switch l {
case 8:
ip = value[:4]
mask = value[4:]
case 32:
ip = value[:16]
mask = value[16:]
default:
return nil, nil, nil, nil, fmt.Errorf("x509: IP constraint contained value of length %d", l)
}
if !isValidIPMask(mask) {
return nil, nil, nil, nil, fmt.Errorf("x509: IP constraint contained invalid mask %x", mask)
}
ips = append(ips, &net.IPNet{IP: net.IP(ip), Mask: net.IPMask(mask)})
case emailTag:
constraint := string(value)
if err := isIA5String(constraint); err != nil {
return nil, nil, nil, nil, errors.New("x509: invalid constraint value: " + err.Error())
}
// If the constraint contains an @ then
// it specifies an exact mailbox name.
if strings.Contains(constraint, "@") {
if _, ok := parseRFC2821Mailbox(constraint); !ok {
return nil, nil, nil, nil, fmt.Errorf("x509: failed to parse rfc822Name constraint %q", constraint)
}
} else {
// Otherwise it's a domain name.
domain := constraint
if len(domain) > 0 && domain[0] == '.' {
domain = domain[1:]
}
if _, ok := domainToReverseLabels(domain); !ok {
return nil, nil, nil, nil, fmt.Errorf("x509: failed to parse rfc822Name constraint %q", constraint)
}
}
emails = append(emails, constraint)
case uriTag:
domain := string(value)
if err := isIA5String(domain); err != nil {
return nil, nil, nil, nil, errors.New("x509: invalid constraint value: " + err.Error())
}
if net.ParseIP(domain) != nil {
return nil, nil, nil, nil, fmt.Errorf("x509: failed to parse URI constraint %q: cannot be IP address", domain)
}
trimmedDomain := domain
if len(trimmedDomain) > 0 && trimmedDomain[0] == '.' {
// constraints can have a leading
// period to exclude the domain itself,
// but that's not valid in a normal
// domain name.
trimmedDomain = trimmedDomain[1:]
}
if _, ok := domainToReverseLabels(trimmedDomain); !ok {
return nil, nil, nil, nil, fmt.Errorf("x509: failed to parse URI constraint %q", domain)
}
uriDomains = append(uriDomains, domain)
default:
unhandled = true
}
}
return dnsNames, ips, emails, uriDomains, nil
}
if out.PermittedDNSDomains, out.PermittedIPRanges, out.PermittedEmailAddresses, out.PermittedURIDomains, err = getValues(permitted); err != nil {
return false, err
}
if out.ExcludedDNSDomains, out.ExcludedIPRanges, out.ExcludedEmailAddresses, out.ExcludedURIDomains, err = getValues(excluded); err != nil {
return false, err
}
out.PermittedDNSDomainsCritical = e.Critical
return unhandled, nil
}
func processExtensions(out *Certificate) error {
var err error
for _, e := range out.Extensions {
unhandled := false
if len(e.Id) == 4 && e.Id[0] == 2 && e.Id[1] == 5 && e.Id[2] == 29 {
switch e.Id[3] {
case 15:
out.KeyUsage, err = parseKeyUsageExtension(e.Value)
if err != nil {
return err
}
case 19:
out.IsCA, out.MaxPathLen, err = parseBasicConstraintsExtension(e.Value)
if err != nil {
return err
}
out.BasicConstraintsValid = true
out.MaxPathLenZero = out.MaxPathLen == 0
case 17:
out.DNSNames, out.EmailAddresses, out.IPAddresses, out.URIs, err = parseSANExtension(e.Value)
if err != nil {
return err
}
if len(out.DNSNames) == 0 && len(out.EmailAddresses) == 0 && len(out.IPAddresses) == 0 && len(out.URIs) == 0 {
// If we didn't parse anything then we do the critical check, below.
unhandled = true
}
case 30:
unhandled, err = parseNameConstraintsExtension(out, e)
if err != nil {
return err
}
case 31:
// RFC 5280, 4.2.1.13
// CRLDistributionPoints ::= SEQUENCE SIZE (1..MAX) OF DistributionPoint
//
// DistributionPoint ::= SEQUENCE {
// distributionPoint [0] DistributionPointName OPTIONAL,
// reasons [1] ReasonFlags OPTIONAL,
// cRLIssuer [2] GeneralNames OPTIONAL }
//
// DistributionPointName ::= CHOICE {
// fullName [0] GeneralNames,
// nameRelativeToCRLIssuer [1] RelativeDistinguishedName }
val := cryptobyte.String(e.Value)
if !val.ReadASN1(&val, cryptobyte_asn1.SEQUENCE) {
return errors.New("x509: invalid CRL distribution points")
}
for !val.Empty() {
var dpDER cryptobyte.String
if !val.ReadASN1(&dpDER, cryptobyte_asn1.SEQUENCE) {
return errors.New("x509: invalid CRL distribution point")
}
var dpNameDER cryptobyte.String
var dpNamePresent bool
if !dpDER.ReadOptionalASN1(&dpNameDER, &dpNamePresent, cryptobyte_asn1.Tag(0).Constructed().ContextSpecific()) {
return errors.New("x509: invalid CRL distribution point")
}
if !dpNamePresent {
continue
}
if !dpNameDER.ReadASN1(&dpNameDER, cryptobyte_asn1.Tag(0).Constructed().ContextSpecific()) {
return errors.New("x509: invalid CRL distribution point")
}
for !dpNameDER.Empty() {
if !dpNameDER.PeekASN1Tag(cryptobyte_asn1.Tag(6).ContextSpecific()) {
break
}
var uri cryptobyte.String
if !dpNameDER.ReadASN1(&uri, cryptobyte_asn1.Tag(6).ContextSpecific()) {
return errors.New("x509: invalid CRL distribution point")
}
out.CRLDistributionPoints = append(out.CRLDistributionPoints, string(uri))
}
}
case 35:
out.AuthorityKeyId, err = parseAuthorityKeyIdentifier(e)
if err != nil {
return err
}
case 36:
val := cryptobyte.String(e.Value)
if !val.ReadASN1(&val, cryptobyte_asn1.SEQUENCE) {
return errors.New("x509: invalid policy constraints extension")
}
if val.PeekASN1Tag(cryptobyte_asn1.Tag(0).ContextSpecific()) {
var v int64
if !val.ReadASN1Int64WithTag(&v, cryptobyte_asn1.Tag(0).ContextSpecific()) {
return errors.New("x509: invalid policy constraints extension")
}
out.RequireExplicitPolicy = int(v)
// Check for overflow.
if int64(out.RequireExplicitPolicy) != v {
return errors.New("x509: policy constraints requireExplicitPolicy field overflows int")
}
out.RequireExplicitPolicyZero = out.RequireExplicitPolicy == 0
}
if val.PeekASN1Tag(cryptobyte_asn1.Tag(1).ContextSpecific()) {
var v int64
if !val.ReadASN1Int64WithTag(&v, cryptobyte_asn1.Tag(1).ContextSpecific()) {
return errors.New("x509: invalid policy constraints extension")
}
out.InhibitPolicyMapping = int(v)
// Check for overflow.
if int64(out.InhibitPolicyMapping) != v {
return errors.New("x509: policy constraints inhibitPolicyMapping field overflows int")
}
out.InhibitPolicyMappingZero = out.InhibitPolicyMapping == 0
}
case 37:
out.ExtKeyUsage, out.UnknownExtKeyUsage, err = parseExtKeyUsageExtension(e.Value)
if err != nil {
return err
}
case 14: // RFC 5280, 4.2.1.2
if e.Critical {
// Conforming CAs MUST mark this extension as non-critical
return errors.New("x509: subject key identifier incorrectly marked critical")
}
val := cryptobyte.String(e.Value)
var skid cryptobyte.String
if !val.ReadASN1(&skid, cryptobyte_asn1.OCTET_STRING) {
return errors.New("x509: invalid subject key identifier")
}
out.SubjectKeyId = skid
case 32:
out.Policies, err = parseCertificatePoliciesExtension(e.Value)
if err != nil {
return err
}
out.PolicyIdentifiers = make([]asn1.ObjectIdentifier, 0, len(out.Policies))
for _, oid := range out.Policies {
if oid, ok := oid.toASN1OID(); ok {
out.PolicyIdentifiers = append(out.PolicyIdentifiers, oid)
}
}
case 33:
val := cryptobyte.String(e.Value)
if !val.ReadASN1(&val, cryptobyte_asn1.SEQUENCE) {
return errors.New("x509: invalid policy mappings extension")
}
for !val.Empty() {
var s cryptobyte.String
var issuer, subject cryptobyte.String
if !val.ReadASN1(&s, cryptobyte_asn1.SEQUENCE) ||
!s.ReadASN1(&issuer, cryptobyte_asn1.OBJECT_IDENTIFIER) ||
!s.ReadASN1(&subject, cryptobyte_asn1.OBJECT_IDENTIFIER) {
return errors.New("x509: invalid policy mappings extension")
}
out.PolicyMappings = append(out.PolicyMappings, PolicyMapping{OID{issuer}, OID{subject}})
}
case 54:
val := cryptobyte.String(e.Value)
if !val.ReadASN1Integer(&out.InhibitAnyPolicy) {
return errors.New("x509: invalid inhibit any policy extension")
}
out.InhibitAnyPolicyZero = out.InhibitAnyPolicy == 0
default:
// Unknown extensions are recorded if critical.
unhandled = true
}
} else if e.Id.Equal(oidExtensionAuthorityInfoAccess) {
// RFC 5280 4.2.2.1: Authority Information Access
if e.Critical {
// Conforming CAs MUST mark this extension as non-critical
return errors.New("x509: authority info access incorrectly marked critical")
}
val := cryptobyte.String(e.Value)
if !val.ReadASN1(&val, cryptobyte_asn1.SEQUENCE) {
return errors.New("x509: invalid authority info access")
}
for !val.Empty() {
var aiaDER cryptobyte.String
if !val.ReadASN1(&aiaDER, cryptobyte_asn1.SEQUENCE) {
return errors.New("x509: invalid authority info access")
}
var method asn1.ObjectIdentifier
if !aiaDER.ReadASN1ObjectIdentifier(&method) {
return errors.New("x509: invalid authority info access")
}
if !aiaDER.PeekASN1Tag(cryptobyte_asn1.Tag(6).ContextSpecific()) {
continue
}
if !aiaDER.ReadASN1(&aiaDER, cryptobyte_asn1.Tag(6).ContextSpecific()) {
return errors.New("x509: invalid authority info access")
}
switch {
case method.Equal(oidAuthorityInfoAccessOcsp):
out.OCSPServer = append(out.OCSPServer, string(aiaDER))
case method.Equal(oidAuthorityInfoAccessIssuers):
out.IssuingCertificateURL = append(out.IssuingCertificateURL, string(aiaDER))
}
}
} else {
// Unknown extensions are recorded if critical.
unhandled = true
}
if e.Critical && unhandled {
out.UnhandledCriticalExtensions = append(out.UnhandledCriticalExtensions, e.Id)
}
}
return nil
}
var x509negativeserial = godebug.New("x509negativeserial")
func parseCertificate(der []byte) (*Certificate, error) {
cert := &Certificate{}
input := cryptobyte.String(der)
// we read the SEQUENCE including length and tag bytes so that
// we can populate Certificate.Raw, before unwrapping the
// SEQUENCE so it can be operated on
if !input.ReadASN1Element(&input, cryptobyte_asn1.SEQUENCE) {
return nil, errors.New("x509: malformed certificate")
}
cert.Raw = input
if !input.ReadASN1(&input, cryptobyte_asn1.SEQUENCE) {
return nil, errors.New("x509: malformed certificate")
}
var tbs cryptobyte.String
// do the same trick again as above to extract the raw
// bytes for Certificate.RawTBSCertificate
if !input.ReadASN1Element(&tbs, cryptobyte_asn1.SEQUENCE) {
return nil, errors.New("x509: malformed tbs certificate")
}
cert.RawTBSCertificate = tbs
if !tbs.ReadASN1(&tbs, cryptobyte_asn1.SEQUENCE) {
return nil, errors.New("x509: malformed tbs certificate")
}
if !tbs.ReadOptionalASN1Integer(&cert.Version, cryptobyte_asn1.Tag(0).Constructed().ContextSpecific(), 0) {
return nil, errors.New("x509: malformed version")
}
if cert.Version < 0 {
return nil, errors.New("x509: malformed version")
}
// for backwards compat reasons Version is one-indexed,
// rather than zero-indexed as defined in RFC 5280
cert.Version++
if cert.Version > 3 {
return nil, errors.New("x509: invalid version")
}
serial := new(big.Int)
if !tbs.ReadASN1Integer(serial) {
return nil, errors.New("x509: malformed serial number")
}
if serial.Sign() == -1 {
if x509negativeserial.Value() != "1" {
return nil, errors.New("x509: negative serial number")
} else {
x509negativeserial.IncNonDefault()
}
}
cert.SerialNumber = serial
var sigAISeq cryptobyte.String
if !tbs.ReadASN1(&sigAISeq, cryptobyte_asn1.SEQUENCE) {
return nil, errors.New("x509: malformed signature algorithm identifier")
}
// Before parsing the inner algorithm identifier, extract
// the outer algorithm identifier and make sure that they
// match.
var outerSigAISeq cryptobyte.String
if !input.ReadASN1(&outerSigAISeq, cryptobyte_asn1.SEQUENCE) {
return nil, errors.New("x509: malformed algorithm identifier")
}
if !bytes.Equal(outerSigAISeq, sigAISeq) {
return nil, errors.New("x509: inner and outer signature algorithm identifiers don't match")
}
sigAI, err := parseAI(sigAISeq)
if err != nil {
return nil, err
}
cert.SignatureAlgorithm = getSignatureAlgorithmFromAI(sigAI)
var issuerSeq cryptobyte.String
if !tbs.ReadASN1Element(&issuerSeq, cryptobyte_asn1.SEQUENCE) {
return nil, errors.New("x509: malformed issuer")
}
cert.RawIssuer = issuerSeq
issuerRDNs, err := parseName(issuerSeq)
if err != nil {
return nil, err
}
cert.Issuer.FillFromRDNSequence(issuerRDNs)
var validity cryptobyte.String
if !tbs.ReadASN1(&validity, cryptobyte_asn1.SEQUENCE) {
return nil, errors.New("x509: malformed validity")
}
cert.NotBefore, cert.NotAfter, err = parseValidity(validity)
if err != nil {
return nil, err
}
var subjectSeq cryptobyte.String
if !tbs.ReadASN1Element(&subjectSeq, cryptobyte_asn1.SEQUENCE) {
return nil, errors.New("x509: malformed issuer")
}
cert.RawSubject = subjectSeq
subjectRDNs, err := parseName(subjectSeq)
if err != nil {
return nil, err
}
cert.Subject.FillFromRDNSequence(subjectRDNs)
var spki cryptobyte.String
if !tbs.ReadASN1Element(&spki, cryptobyte_asn1.SEQUENCE) {
return nil, errors.New("x509: malformed spki")
}
cert.RawSubjectPublicKeyInfo = spki
if !spki.ReadASN1(&spki, cryptobyte_asn1.SEQUENCE) {
return nil, errors.New("x509: malformed spki")
}
var pkAISeq cryptobyte.String
if !spki.ReadASN1(&pkAISeq, cryptobyte_asn1.SEQUENCE) {
return nil, errors.New("x509: malformed public key algorithm identifier")
}
pkAI, err := parseAI(pkAISeq)
if err != nil {
return nil, err
}
cert.PublicKeyAlgorithm = getPublicKeyAlgorithmFromOID(pkAI.Algorithm)
var spk asn1.BitString
if !spki.ReadASN1BitString(&spk) {
return nil, errors.New("x509: malformed subjectPublicKey")
}
if cert.PublicKeyAlgorithm != UnknownPublicKeyAlgorithm {
cert.PublicKey, err = parsePublicKey(&publicKeyInfo{
Algorithm: pkAI,
PublicKey: spk,
})
if err != nil {
return nil, err
}
}
if cert.Version > 1 {
if !tbs.SkipOptionalASN1(cryptobyte_asn1.Tag(1).ContextSpecific()) {
return nil, errors.New("x509: malformed issuerUniqueID")
}
if !tbs.SkipOptionalASN1(cryptobyte_asn1.Tag(2).ContextSpecific()) {
return nil, errors.New("x509: malformed subjectUniqueID")
}
if cert.Version == 3 {
var extensions cryptobyte.String
var present bool
if !tbs.ReadOptionalASN1(&extensions, &present, cryptobyte_asn1.Tag(3).Constructed().ContextSpecific()) {
return nil, errors.New("x509: malformed extensions")
}
if present {
seenExts := make(map[string]bool)
if !extensions.ReadASN1(&extensions, cryptobyte_asn1.SEQUENCE) {
return nil, errors.New("x509: malformed extensions")
}
for !extensions.Empty() {
var extension cryptobyte.String
if !extensions.ReadASN1(&extension, cryptobyte_asn1.SEQUENCE) {
return nil, errors.New("x509: malformed extension")
}
ext, err := parseExtension(extension)
if err != nil {
return nil, err
}
oidStr := ext.Id.String()
if seenExts[oidStr] {
return nil, fmt.Errorf("x509: certificate contains duplicate extension with OID %q", oidStr)
}
seenExts[oidStr] = true
cert.Extensions = append(cert.Extensions, ext)
}
err = processExtensions(cert)
if err != nil {
return nil, err
}
}
}
}
var signature asn1.BitString
if !input.ReadASN1BitString(&signature) {
return nil, errors.New("x509: malformed signature")
}
cert.Signature = signature.RightAlign()
return cert, nil
}
// ParseCertificate parses a single certificate from the given ASN.1 DER data.
//
// Before Go 1.23, ParseCertificate accepted certificates with negative serial
// numbers. This behavior can be restored by including "x509negativeserial=1" in
// the GODEBUG environment variable.
func ParseCertificate(der []byte) (*Certificate, error) {
cert, err := parseCertificate(der)
if err != nil {
return nil, err
}
if len(der) != len(cert.Raw) {
return nil, errors.New("x509: trailing data")
}
return cert, nil
}
// ParseCertificates parses one or more certificates from the given ASN.1 DER
// data. The certificates must be concatenated with no intermediate padding.
func ParseCertificates(der []byte) ([]*Certificate, error) {
var certs []*Certificate
for len(der) > 0 {
cert, err := parseCertificate(der)
if err != nil {
return nil, err
}
certs = append(certs, cert)
der = der[len(cert.Raw):]
}
return certs, nil
}
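// Illustrative sketch (not part of the original source): the usual pairing of
// pem.Decode with ParseCertificate. certPEM is a hypothetical []byte holding
// a "CERTIFICATE" block.
//
//	block, _ := pem.Decode(certPEM)
//	if block == nil || block.Type != "CERTIFICATE" {
//		// not a certificate PEM block
//	} else if cert, err := ParseCertificate(block.Bytes); err == nil {
//		_ = cert // use the parsed certificate
//	}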
// The X.509 standards confusingly 1-indexed the version names, but 0-indexed
// the actual encoded version, so the version for X.509v2 is 1.
const x509v2Version = 1
// ParseRevocationList parses an X.509 v2 [Certificate] Revocation List from the given
// ASN.1 DER data.
func ParseRevocationList(der []byte) (*RevocationList, error) {
rl := &RevocationList{}
input := cryptobyte.String(der)
// we read the SEQUENCE including length and tag bytes so that
// we can populate RevocationList.Raw, before unwrapping the
// SEQUENCE so it can be operated on
if !input.ReadASN1Element(&input, cryptobyte_asn1.SEQUENCE) {
return nil, errors.New("x509: malformed crl")
}
rl.Raw = input
if !input.ReadASN1(&input, cryptobyte_asn1.SEQUENCE) {
return nil, errors.New("x509: malformed crl")
}
var tbs cryptobyte.String
// do the same trick again as above to extract the raw
// bytes for RevocationList.RawTBSRevocationList
if !input.ReadASN1Element(&tbs, cryptobyte_asn1.SEQUENCE) {
return nil, errors.New("x509: malformed tbs crl")
}
rl.RawTBSRevocationList = tbs
if !tbs.ReadASN1(&tbs, cryptobyte_asn1.SEQUENCE) {
return nil, errors.New("x509: malformed tbs crl")
}
var version int
if !tbs.PeekASN1Tag(cryptobyte_asn1.INTEGER) {
return nil, errors.New("x509: unsupported crl version")
}
if !tbs.ReadASN1Integer(&version) {
return nil, errors.New("x509: malformed crl")
}
if version != x509v2Version {
return nil, fmt.Errorf("x509: unsupported crl version: %d", version)
}
var sigAISeq cryptobyte.String
if !tbs.ReadASN1(&sigAISeq, cryptobyte_asn1.SEQUENCE) {
return nil, errors.New("x509: malformed signature algorithm identifier")
}
// Before parsing the inner algorithm identifier, extract
// the outer algorithm identifier and make sure that they
// match.
var outerSigAISeq cryptobyte.String
if !input.ReadASN1(&outerSigAISeq, cryptobyte_asn1.SEQUENCE) {
return nil, errors.New("x509: malformed algorithm identifier")
}
if !bytes.Equal(outerSigAISeq, sigAISeq) {
return nil, errors.New("x509: inner and outer signature algorithm identifiers don't match")
}
sigAI, err := parseAI(sigAISeq)
if err != nil {
return nil, err
}
rl.SignatureAlgorithm = getSignatureAlgorithmFromAI(sigAI)
var signature asn1.BitString
if !input.ReadASN1BitString(&signature) {
return nil, errors.New("x509: malformed signature")
}
rl.Signature = signature.RightAlign()
var issuerSeq cryptobyte.String
if !tbs.ReadASN1Element(&issuerSeq, cryptobyte_asn1.SEQUENCE) {
return nil, errors.New("x509: malformed issuer")
}
rl.RawIssuer = issuerSeq
issuerRDNs, err := parseName(issuerSeq)
if err != nil {
return nil, err
}
rl.Issuer.FillFromRDNSequence(issuerRDNs)
rl.ThisUpdate, err = parseTime(&tbs)
if err != nil {
return nil, err
}
if tbs.PeekASN1Tag(cryptobyte_asn1.GeneralizedTime) || tbs.PeekASN1Tag(cryptobyte_asn1.UTCTime) {
rl.NextUpdate, err = parseTime(&tbs)
if err != nil {
return nil, err
}
}
if tbs.PeekASN1Tag(cryptobyte_asn1.SEQUENCE) {
var revokedSeq cryptobyte.String
if !tbs.ReadASN1(&revokedSeq, cryptobyte_asn1.SEQUENCE) {
return nil, errors.New("x509: malformed crl")
}
for !revokedSeq.Empty() {
rce := RevocationListEntry{}
var certSeq cryptobyte.String
if !revokedSeq.ReadASN1Element(&certSeq, cryptobyte_asn1.SEQUENCE) {
return nil, errors.New("x509: malformed crl")
}
rce.Raw = certSeq
if !certSeq.ReadASN1(&certSeq, cryptobyte_asn1.SEQUENCE) {
return nil, errors.New("x509: malformed crl")
}
rce.SerialNumber = new(big.Int)
if !certSeq.ReadASN1Integer(rce.SerialNumber) {
return nil, errors.New("x509: malformed serial number")
}
rce.RevocationTime, err = parseTime(&certSeq)
if err != nil {
return nil, err
}
var extensions cryptobyte.String
var present bool
if !certSeq.ReadOptionalASN1(&extensions, &present, cryptobyte_asn1.SEQUENCE) {
return nil, errors.New("x509: malformed extensions")
}
if present {
for !extensions.Empty() {
var extension cryptobyte.String
if !extensions.ReadASN1(&extension, cryptobyte_asn1.SEQUENCE) {
return nil, errors.New("x509: malformed extension")
}
ext, err := parseExtension(extension)
if err != nil {
return nil, err
}
if ext.Id.Equal(oidExtensionReasonCode) {
val := cryptobyte.String(ext.Value)
if !val.ReadASN1Enum(&rce.ReasonCode) {
return nil, fmt.Errorf("x509: malformed reasonCode extension")
}
}
rce.Extensions = append(rce.Extensions, ext)
}
}
rl.RevokedCertificateEntries = append(rl.RevokedCertificateEntries, rce)
rcDeprecated := pkix.RevokedCertificate{
SerialNumber: rce.SerialNumber,
RevocationTime: rce.RevocationTime,
Extensions: rce.Extensions,
}
rl.RevokedCertificates = append(rl.RevokedCertificates, rcDeprecated)
}
}
var extensions cryptobyte.String
var present bool
if !tbs.ReadOptionalASN1(&extensions, &present, cryptobyte_asn1.Tag(0).Constructed().ContextSpecific()) {
return nil, errors.New("x509: malformed extensions")
}
if present {
if !extensions.ReadASN1(&extensions, cryptobyte_asn1.SEQUENCE) {
return nil, errors.New("x509: malformed extensions")
}
for !extensions.Empty() {
var extension cryptobyte.String
if !extensions.ReadASN1(&extension, cryptobyte_asn1.SEQUENCE) {
return nil, errors.New("x509: malformed extension")
}
ext, err := parseExtension(extension)
if err != nil {
return nil, err
}
if ext.Id.Equal(oidExtensionAuthorityKeyId) {
rl.AuthorityKeyId, err = parseAuthorityKeyIdentifier(ext)
if err != nil {
return nil, err
}
} else if ext.Id.Equal(oidExtensionCRLNumber) {
value := cryptobyte.String(ext.Value)
rl.Number = new(big.Int)
if !value.ReadASN1Integer(rl.Number) {
return nil, errors.New("x509: malformed crl number")
}
}
rl.Extensions = append(rl.Extensions, ext)
}
}
return rl, nil
}
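// Illustrative sketch (not part of the original source): parsing a DER-encoded
// CRL and checking whether it is stale. crlDER is a hypothetical []byte.
//
//	rl, err := ParseRevocationList(crlDER)
//	if err != nil {
//		// handle the parse error
//	} else if !rl.NextUpdate.IsZero() && time.Now().After(rl.NextUpdate) {
//		// the CRL is past its NextUpdate time; fetch a fresh one
//	}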
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package x509
// RFC 1423 describes the encryption of PEM blocks. The algorithm used to
// generate a key from the password was derived by looking at the OpenSSL
// implementation.
import (
"crypto/aes"
"crypto/cipher"
"crypto/des"
"crypto/md5"
"encoding/hex"
"encoding/pem"
"errors"
"io"
"strings"
)
type PEMCipher int
// Possible values for the EncryptPEMBlock encryption algorithm.
const (
_ PEMCipher = iota
PEMCipherDES
PEMCipher3DES
PEMCipherAES128
PEMCipherAES192
PEMCipherAES256
)
// rfc1423Algo holds a method for enciphering a PEM block.
type rfc1423Algo struct {
cipher PEMCipher
name string
cipherFunc func(key []byte) (cipher.Block, error)
keySize int
blockSize int
}
// rfc1423Algos holds a slice of the possible ways to encrypt a PEM
// block. The ivSize numbers were taken from the OpenSSL source.
var rfc1423Algos = []rfc1423Algo{{
cipher: PEMCipherDES,
name: "DES-CBC",
cipherFunc: des.NewCipher,
keySize: 8,
blockSize: des.BlockSize,
}, {
cipher: PEMCipher3DES,
name: "DES-EDE3-CBC",
cipherFunc: des.NewTripleDESCipher,
keySize: 24,
blockSize: des.BlockSize,
}, {
cipher: PEMCipherAES128,
name: "AES-128-CBC",
cipherFunc: aes.NewCipher,
keySize: 16,
blockSize: aes.BlockSize,
}, {
cipher: PEMCipherAES192,
name: "AES-192-CBC",
cipherFunc: aes.NewCipher,
keySize: 24,
blockSize: aes.BlockSize,
}, {
cipher: PEMCipherAES256,
name: "AES-256-CBC",
cipherFunc: aes.NewCipher,
keySize: 32,
blockSize: aes.BlockSize,
},
}
// deriveKey uses a key derivation function to stretch the password into a key
// with the number of bits our cipher requires. This algorithm was derived from
// the OpenSSL source.
func (c rfc1423Algo) deriveKey(password, salt []byte) []byte {
hash := md5.New()
out := make([]byte, c.keySize)
var digest []byte
for i := 0; i < len(out); i += len(digest) {
hash.Reset()
hash.Write(digest)
hash.Write(password)
hash.Write(salt)
digest = hash.Sum(digest[:0])
copy(out[i:], digest)
}
return out
}
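// For a 32-byte AES-256 key the loop above runs twice: out[0:16] is
// MD5(password || salt) and out[16:32] is MD5(out[0:16] || password || salt).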
// IsEncryptedPEMBlock returns whether the PEM block is password encrypted
// according to RFC 1423.
//
// Deprecated: Legacy PEM encryption as specified in RFC 1423 is insecure by
// design. Since it does not authenticate the ciphertext, it is vulnerable to
// padding oracle attacks that can let an attacker recover the plaintext.
func IsEncryptedPEMBlock(b *pem.Block) bool {
_, ok := b.Headers["DEK-Info"]
return ok
}
// IncorrectPasswordError is returned when an incorrect password is detected.
var IncorrectPasswordError = errors.New("x509: decryption password incorrect")
// DecryptPEMBlock takes a PEM block encrypted according to RFC 1423 and the
// password used to encrypt it and returns a slice of decrypted DER encoded
// bytes. It inspects the DEK-Info header to determine the algorithm used for
// decryption. If no DEK-Info header is present, an error is returned. If an
// incorrect password is detected an [IncorrectPasswordError] is returned. Because
// of deficiencies in the format, it's not always possible to detect an
// incorrect password. In these cases no error will be returned but the
// decrypted DER bytes will be random noise.
//
// Deprecated: Legacy PEM encryption as specified in RFC 1423 is insecure by
// design. Since it does not authenticate the ciphertext, it is vulnerable to
// padding oracle attacks that can let an attacker recover the plaintext.
func DecryptPEMBlock(b *pem.Block, password []byte) ([]byte, error) {
dek, ok := b.Headers["DEK-Info"]
if !ok {
return nil, errors.New("x509: no DEK-Info header in block")
}
mode, hexIV, ok := strings.Cut(dek, ",")
if !ok {
return nil, errors.New("x509: malformed DEK-Info header")
}
ciph := cipherByName(mode)
if ciph == nil {
return nil, errors.New("x509: unknown encryption mode")
}
iv, err := hex.DecodeString(hexIV)
if err != nil {
return nil, err
}
if len(iv) != ciph.blockSize {
return nil, errors.New("x509: incorrect IV size")
}
// Based on the OpenSSL implementation. The salt is the first 8 bytes
// of the initialization vector.
key := ciph.deriveKey(password, iv[:8])
block, err := ciph.cipherFunc(key)
if err != nil {
return nil, err
}
if len(b.Bytes)%block.BlockSize() != 0 {
return nil, errors.New("x509: encrypted PEM data is not a multiple of the block size")
}
data := make([]byte, len(b.Bytes))
dec := cipher.NewCBCDecrypter(block, iv)
dec.CryptBlocks(data, b.Bytes)
// Blocks are padded using a scheme where the last n bytes of padding are all
// equal to n. It can pad from 1 to blocksize bytes inclusive. See RFC 1423.
// For example:
// [x y z 2 2]
// [x y 7 7 7 7 7 7 7]
// If we detect a bad padding, we assume it is an invalid password.
dlen := len(data)
if dlen == 0 || dlen%ciph.blockSize != 0 {
return nil, errors.New("x509: invalid padding")
}
last := int(data[dlen-1])
if dlen < last {
return nil, IncorrectPasswordError
}
if last == 0 || last > ciph.blockSize {
return nil, IncorrectPasswordError
}
for _, val := range data[dlen-last:] {
if int(val) != last {
return nil, IncorrectPasswordError
}
}
return data[:dlen-last], nil
}
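// Illustrative sketch (not part of the original source): decrypting a legacy
// RFC 1423 block before parsing the key inside. keyPEM and password are
// hypothetical inputs; this scheme is deprecated and insecure.
//
//	block, _ := pem.Decode(keyPEM)
//	if block != nil && IsEncryptedPEMBlock(block) {
//		der, err := DecryptPEMBlock(block, password)
//		if err == nil {
//			// der now holds the decrypted DER-encoded key bytes
//		}
//	}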
// EncryptPEMBlock returns a PEM block of the specified type holding the
// given DER encoded data encrypted with the specified algorithm and
// password according to RFC 1423.
//
// Deprecated: Legacy PEM encryption as specified in RFC 1423 is insecure by
// design. Since it does not authenticate the ciphertext, it is vulnerable to
// padding oracle attacks that can let an attacker recover the plaintext.
func EncryptPEMBlock(rand io.Reader, blockType string, data, password []byte, alg PEMCipher) (*pem.Block, error) {
ciph := cipherByKey(alg)
if ciph == nil {
return nil, errors.New("x509: unknown encryption mode")
}
iv := make([]byte, ciph.blockSize)
if _, err := io.ReadFull(rand, iv); err != nil {
return nil, errors.New("x509: cannot generate IV: " + err.Error())
}
// The salt is the first 8 bytes of the initialization vector,
// matching the key derivation in DecryptPEMBlock.
key := ciph.deriveKey(password, iv[:8])
block, err := ciph.cipherFunc(key)
if err != nil {
return nil, err
}
enc := cipher.NewCBCEncrypter(block, iv)
pad := ciph.blockSize - len(data)%ciph.blockSize
encrypted := make([]byte, len(data), len(data)+pad)
// We could save this copy by encrypting all the whole blocks in
// the data separately, but it doesn't seem worth the additional
// code.
copy(encrypted, data)
// See RFC 1423, Section 1.1.
for i := 0; i < pad; i++ {
encrypted = append(encrypted, byte(pad))
}
enc.CryptBlocks(encrypted, encrypted)
return &pem.Block{
Type: blockType,
Headers: map[string]string{
"Proc-Type": "4,ENCRYPTED",
"DEK-Info": ciph.name + "," + hex.EncodeToString(iv),
},
Bytes: encrypted,
}, nil
}
func cipherByName(name string) *rfc1423Algo {
for i := range rfc1423Algos {
alg := &rfc1423Algos[i]
if alg.name == name {
return alg
}
}
return nil
}
func cipherByKey(key PEMCipher) *rfc1423Algo {
for i := range rfc1423Algos {
alg := &rfc1423Algos[i]
if alg.cipher == key {
return alg
}
}
return nil
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package x509
import (
"crypto/rsa"
"encoding/asn1"
"errors"
"internal/godebug"
"math/big"
)
// pkcs1PrivateKey is a structure which mirrors the PKCS #1 ASN.1 for an RSA private key.
type pkcs1PrivateKey struct {
Version int
N *big.Int
E int
D *big.Int
P *big.Int
Q *big.Int
Dp *big.Int `asn1:"optional"`
Dq *big.Int `asn1:"optional"`
Qinv *big.Int `asn1:"optional"`
AdditionalPrimes []pkcs1AdditionalRSAPrime `asn1:"optional,omitempty"`
}
type pkcs1AdditionalRSAPrime struct {
Prime *big.Int
// We ignore these values because rsa will calculate them.
Exp *big.Int
Coeff *big.Int
}
// pkcs1PublicKey reflects the ASN.1 structure of a PKCS #1 public key.
type pkcs1PublicKey struct {
N *big.Int
E int
}
// x509rsacrt, if zero, makes ParsePKCS1PrivateKey ignore and recompute invalid
// CRT values in the RSA private key.
var x509rsacrt = godebug.New("x509rsacrt")
// ParsePKCS1PrivateKey parses an [RSA] private key in PKCS #1, ASN.1 DER form.
//
// This kind of key is commonly encoded in PEM blocks of type "RSA PRIVATE KEY".
//
// Before Go 1.24, the CRT parameters were ignored and recomputed. To restore
// the old behavior, use the GODEBUG=x509rsacrt=0 environment variable.
func ParsePKCS1PrivateKey(der []byte) (*rsa.PrivateKey, error) {
var priv pkcs1PrivateKey
rest, err := asn1.Unmarshal(der, &priv)
if len(rest) > 0 {
return nil, asn1.SyntaxError{Msg: "trailing data"}
}
if err != nil {
if _, err := asn1.Unmarshal(der, &ecPrivateKey{}); err == nil {
return nil, errors.New("x509: failed to parse private key (use ParseECPrivateKey instead for this key format)")
}
if _, err := asn1.Unmarshal(der, &pkcs8{}); err == nil {
return nil, errors.New("x509: failed to parse private key (use ParsePKCS8PrivateKey instead for this key format)")
}
return nil, err
}
if priv.Version > 1 {
return nil, errors.New("x509: unsupported private key version")
}
if priv.N.Sign() <= 0 || priv.D.Sign() <= 0 || priv.P.Sign() <= 0 || priv.Q.Sign() <= 0 ||
priv.Dp != nil && priv.Dp.Sign() <= 0 ||
priv.Dq != nil && priv.Dq.Sign() <= 0 ||
priv.Qinv != nil && priv.Qinv.Sign() <= 0 {
return nil, errors.New("x509: private key contains zero or negative value")
}
key := new(rsa.PrivateKey)
key.PublicKey = rsa.PublicKey{
E: priv.E,
N: priv.N,
}
key.D = priv.D
key.Primes = make([]*big.Int, 2+len(priv.AdditionalPrimes))
key.Primes[0] = priv.P
key.Primes[1] = priv.Q
key.Precomputed.Dp = priv.Dp
key.Precomputed.Dq = priv.Dq
key.Precomputed.Qinv = priv.Qinv
for i, a := range priv.AdditionalPrimes {
if a.Prime.Sign() <= 0 {
return nil, errors.New("x509: private key contains zero or negative prime")
}
key.Primes[i+2] = a.Prime
// We ignore the other two values because rsa will calculate
// them as needed.
}
key.Precompute()
if err := key.Validate(); err != nil {
// If x509rsacrt=0 is set, try dropping the CRT values and
// rerunning precomputation and key validation.
if x509rsacrt.Value() == "0" {
key.Precomputed.Dp = nil
key.Precomputed.Dq = nil
key.Precomputed.Qinv = nil
key.Precompute()
if err := key.Validate(); err == nil {
x509rsacrt.IncNonDefault()
return key, nil
}
}
return nil, err
}
return key, nil
}
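// A minimal caller-side sketch, assuming "encoding/pem" and a PEM-encoded key
// in pemBytes (a placeholder):
//
//	block, _ := pem.Decode(pemBytes)
//	if block == nil || block.Type != "RSA PRIVATE KEY" {
//		// not a PKCS #1 RSA private key block
//	}
//	key, err := x509.ParsePKCS1PrivateKey(block.Bytes)
//	if err != nil {
//		// handle parse error
//	}
//	_ = key // *rsa.PrivateKey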
// MarshalPKCS1PrivateKey converts an [RSA] private key to PKCS #1, ASN.1 DER form.
//
// This kind of key is commonly encoded in PEM blocks of type "RSA PRIVATE KEY".
// For a more flexible key format which is not [RSA] specific, use
// [MarshalPKCS8PrivateKey].
//
// The key must have passed validation by calling [rsa.PrivateKey.Validate]
// first. MarshalPKCS1PrivateKey calls [rsa.PrivateKey.Precompute], which may
// modify the key if not already precomputed.
func MarshalPKCS1PrivateKey(key *rsa.PrivateKey) []byte {
key.Precompute()
version := 0
if len(key.Primes) > 2 {
version = 1
}
priv := pkcs1PrivateKey{
Version: version,
N: key.N,
E: key.PublicKey.E,
D: key.D,
P: key.Primes[0],
Q: key.Primes[1],
Dp: key.Precomputed.Dp,
Dq: key.Precomputed.Dq,
Qinv: key.Precomputed.Qinv,
}
priv.AdditionalPrimes = make([]pkcs1AdditionalRSAPrime, len(key.Precomputed.CRTValues))
for i, values := range key.Precomputed.CRTValues {
priv.AdditionalPrimes[i].Prime = key.Primes[2+i]
priv.AdditionalPrimes[i].Exp = values.Exp
priv.AdditionalPrimes[i].Coeff = values.Coeff
}
b, _ := asn1.Marshal(priv)
return b
}
// ParsePKCS1PublicKey parses an [RSA] public key in PKCS #1, ASN.1 DER form.
//
// This kind of key is commonly encoded in PEM blocks of type "RSA PUBLIC KEY".
func ParsePKCS1PublicKey(der []byte) (*rsa.PublicKey, error) {
var pub pkcs1PublicKey
rest, err := asn1.Unmarshal(der, &pub)
if err != nil {
if _, err := asn1.Unmarshal(der, &publicKeyInfo{}); err == nil {
return nil, errors.New("x509: failed to parse public key (use ParsePKIXPublicKey instead for this key format)")
}
return nil, err
}
if len(rest) > 0 {
return nil, asn1.SyntaxError{Msg: "trailing data"}
}
if pub.N.Sign() <= 0 || pub.E <= 0 {
return nil, errors.New("x509: public key contains zero or negative value")
}
if pub.E > 1<<31-1 {
return nil, errors.New("x509: public key contains large public exponent")
}
return &rsa.PublicKey{
E: pub.E,
N: pub.N,
}, nil
}
// MarshalPKCS1PublicKey converts an [RSA] public key to PKCS #1, ASN.1 DER form.
//
// This kind of key is commonly encoded in PEM blocks of type "RSA PUBLIC KEY".
func MarshalPKCS1PublicKey(key *rsa.PublicKey) []byte {
derBytes, _ := asn1.Marshal(pkcs1PublicKey{
N: key.N,
E: key.E,
})
return derBytes
}
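// A minimal round-trip sketch for the PKCS #1 public key helpers, assuming
// "encoding/pem" and an existing *rsa.PrivateKey named key (a placeholder):
//
//	der := x509.MarshalPKCS1PublicKey(&key.PublicKey)
//	pemBytes := pem.EncodeToMemory(&pem.Block{Type: "RSA PUBLIC KEY", Bytes: der})
//	pub, err := x509.ParsePKCS1PublicKey(der)
//	if err != nil {
//		// handle parse error
//	}
//	_, _ = pub, pemBytes // pub is a *rsa.PublicKey equal to &key.PublicKey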
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package x509
import (
"crypto/ecdh"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/rsa"
"crypto/x509/pkix"
"encoding/asn1"
"errors"
"fmt"
)
// pkcs8 reflects an ASN.1, PKCS #8 PrivateKey. See
// ftp://ftp.rsasecurity.com/pub/pkcs/pkcs-8/pkcs-8v1_2.asn
// and RFC 5208.
type pkcs8 struct {
Version int
Algo pkix.AlgorithmIdentifier
PrivateKey []byte
// optional attributes omitted.
}
// ParsePKCS8PrivateKey parses an unencrypted private key in PKCS #8, ASN.1 DER form.
//
// It returns a *[rsa.PrivateKey], an *[ecdsa.PrivateKey], an [ed25519.PrivateKey] (not
// a pointer), or an *[ecdh.PrivateKey] (for X25519). More types might be supported
// in the future.
//
// This kind of key is commonly encoded in PEM blocks of type "PRIVATE KEY".
//
// Before Go 1.24, the CRT parameters of RSA keys were ignored and recomputed.
// To restore the old behavior, use the GODEBUG=x509rsacrt=0 environment variable.
func ParsePKCS8PrivateKey(der []byte) (key any, err error) {
var privKey pkcs8
if _, err := asn1.Unmarshal(der, &privKey); err != nil {
if _, err := asn1.Unmarshal(der, &ecPrivateKey{}); err == nil {
return nil, errors.New("x509: failed to parse private key (use ParseECPrivateKey instead for this key format)")
}
if _, err := asn1.Unmarshal(der, &pkcs1PrivateKey{}); err == nil {
return nil, errors.New("x509: failed to parse private key (use ParsePKCS1PrivateKey instead for this key format)")
}
return nil, err
}
switch {
case privKey.Algo.Algorithm.Equal(oidPublicKeyRSA):
key, err = ParsePKCS1PrivateKey(privKey.PrivateKey)
if err != nil {
return nil, errors.New("x509: failed to parse RSA private key embedded in PKCS#8: " + err.Error())
}
return key, nil
case privKey.Algo.Algorithm.Equal(oidPublicKeyECDSA):
bytes := privKey.Algo.Parameters.FullBytes
namedCurveOID := new(asn1.ObjectIdentifier)
if _, err := asn1.Unmarshal(bytes, namedCurveOID); err != nil {
namedCurveOID = nil
}
key, err = parseECPrivateKey(namedCurveOID, privKey.PrivateKey)
if err != nil {
return nil, errors.New("x509: failed to parse EC private key embedded in PKCS#8: " + err.Error())
}
return key, nil
case privKey.Algo.Algorithm.Equal(oidPublicKeyEd25519):
if l := len(privKey.Algo.Parameters.FullBytes); l != 0 {
return nil, errors.New("x509: invalid Ed25519 private key parameters")
}
var curvePrivateKey []byte
if _, err := asn1.Unmarshal(privKey.PrivateKey, &curvePrivateKey); err != nil {
return nil, fmt.Errorf("x509: invalid Ed25519 private key: %v", err)
}
if l := len(curvePrivateKey); l != ed25519.SeedSize {
return nil, fmt.Errorf("x509: invalid Ed25519 private key length: %d", l)
}
return ed25519.NewKeyFromSeed(curvePrivateKey), nil
case privKey.Algo.Algorithm.Equal(oidPublicKeyX25519):
if l := len(privKey.Algo.Parameters.FullBytes); l != 0 {
return nil, errors.New("x509: invalid X25519 private key parameters")
}
var curvePrivateKey []byte
if _, err := asn1.Unmarshal(privKey.PrivateKey, &curvePrivateKey); err != nil {
return nil, fmt.Errorf("x509: invalid X25519 private key: %v", err)
}
return ecdh.X25519().NewPrivateKey(curvePrivateKey)
default:
return nil, fmt.Errorf("x509: PKCS#8 wrapping contained private key with unknown algorithm: %v", privKey.Algo.Algorithm)
}
}
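// Because ParsePKCS8PrivateKey returns any, callers typically type-switch on
// the result. A minimal sketch, assuming der holds PKCS #8 DER (a placeholder):
//
//	key, err := x509.ParsePKCS8PrivateKey(der)
//	if err != nil {
//		// handle parse error
//	}
//	switch k := key.(type) {
//	case *rsa.PrivateKey:
//		_ = k
//	case *ecdsa.PrivateKey:
//		_ = k
//	case ed25519.PrivateKey:
//		_ = k
//	case *ecdh.PrivateKey: // X25519
//		_ = k
//	}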
// MarshalPKCS8PrivateKey converts a private key to PKCS #8, ASN.1 DER form.
//
// The following key types are currently supported: *[rsa.PrivateKey],
// *[ecdsa.PrivateKey], [ed25519.PrivateKey] (not a pointer), and *[ecdh.PrivateKey].
// Unsupported key types result in an error.
//
// This kind of key is commonly encoded in PEM blocks of type "PRIVATE KEY".
//
// MarshalPKCS8PrivateKey runs [rsa.PrivateKey.Precompute] on RSA keys.
func MarshalPKCS8PrivateKey(key any) ([]byte, error) {
var privKey pkcs8
switch k := key.(type) {
case *rsa.PrivateKey:
privKey.Algo = pkix.AlgorithmIdentifier{
Algorithm: oidPublicKeyRSA,
Parameters: asn1.NullRawValue,
}
k.Precompute()
if err := k.Validate(); err != nil {
return nil, err
}
privKey.PrivateKey = MarshalPKCS1PrivateKey(k)
case *ecdsa.PrivateKey:
oid, ok := oidFromNamedCurve(k.Curve)
if !ok {
return nil, errors.New("x509: unknown curve while marshaling to PKCS#8")
}
oidBytes, err := asn1.Marshal(oid)
if err != nil {
return nil, errors.New("x509: failed to marshal curve OID: " + err.Error())
}
privKey.Algo = pkix.AlgorithmIdentifier{
Algorithm: oidPublicKeyECDSA,
Parameters: asn1.RawValue{
FullBytes: oidBytes,
},
}
if privKey.PrivateKey, err = marshalECPrivateKeyWithOID(k, nil); err != nil {
return nil, errors.New("x509: failed to marshal EC private key while building PKCS#8: " + err.Error())
}
case ed25519.PrivateKey:
privKey.Algo = pkix.AlgorithmIdentifier{
Algorithm: oidPublicKeyEd25519,
}
curvePrivateKey, err := asn1.Marshal(k.Seed())
if err != nil {
return nil, fmt.Errorf("x509: failed to marshal private key: %v", err)
}
privKey.PrivateKey = curvePrivateKey
case *ecdh.PrivateKey:
if k.Curve() == ecdh.X25519() {
privKey.Algo = pkix.AlgorithmIdentifier{
Algorithm: oidPublicKeyX25519,
}
var err error
if privKey.PrivateKey, err = asn1.Marshal(k.Bytes()); err != nil {
return nil, fmt.Errorf("x509: failed to marshal private key: %v", err)
}
} else {
oid, ok := oidFromECDHCurve(k.Curve())
if !ok {
return nil, errors.New("x509: unknown curve while marshaling to PKCS#8")
}
oidBytes, err := asn1.Marshal(oid)
if err != nil {
return nil, errors.New("x509: failed to marshal curve OID: " + err.Error())
}
privKey.Algo = pkix.AlgorithmIdentifier{
Algorithm: oidPublicKeyECDSA,
Parameters: asn1.RawValue{
FullBytes: oidBytes,
},
}
if privKey.PrivateKey, err = marshalECDHPrivateKey(k); err != nil {
return nil, errors.New("x509: failed to marshal EC private key while building PKCS#8: " + err.Error())
}
}
default:
return nil, fmt.Errorf("x509: unknown key type while marshaling PKCS#8: %T", key)
}
return asn1.Marshal(privKey)
}
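// A minimal sketch producing a modern "PRIVATE KEY" PEM block, assuming
// "crypto/ed25519", "crypto/rand", and "encoding/pem":
//
//	_, priv, err := ed25519.GenerateKey(rand.Reader)
//	if err != nil {
//		// handle error
//	}
//	der, err := x509.MarshalPKCS8PrivateKey(priv)
//	if err != nil {
//		// handle error
//	}
//	pemBytes := pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: der})
//	_ = pemBytes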
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package x509
import (
"internal/godebug"
"sync"
_ "unsafe" // for linkname
)
// systemRoots should be an internal detail,
// but widely used packages access it using linkname.
// Notable members of the hall of shame include:
// - github.com/breml/rootcerts
//
// Do not remove or change the type signature.
// See go.dev/issue/67401.
//
//go:linkname systemRoots
var (
once sync.Once
systemRootsMu sync.RWMutex
systemRoots *CertPool
systemRootsErr error
fallbacksSet bool
)
func systemRootsPool() *CertPool {
once.Do(initSystemRoots)
systemRootsMu.RLock()
defer systemRootsMu.RUnlock()
return systemRoots
}
func initSystemRoots() {
systemRootsMu.Lock()
defer systemRootsMu.Unlock()
systemRoots, systemRootsErr = loadSystemRoots()
if systemRootsErr != nil {
systemRoots = nil
}
}
var x509usefallbackroots = godebug.New("x509usefallbackroots")
// SetFallbackRoots sets the roots to use during certificate verification, if no
// custom roots are specified and a platform verifier or a system certificate
// pool is not available (for instance in a container which does not have a root
// certificate bundle). SetFallbackRoots will panic if roots is nil.
//
// SetFallbackRoots may only be called once; if called multiple times it will
// panic.
//
// The fallback behavior can be forced on all platforms, even when there is a
// system certificate pool, by setting GODEBUG=x509usefallbackroots=1 (note that
// on Windows and macOS this will disable usage of the platform verification
// APIs and cause the pure Go verifier to be used). Setting
// x509usefallbackroots=1 without calling SetFallbackRoots has no effect.
func SetFallbackRoots(roots *CertPool) {
if roots == nil {
panic("roots must be non-nil")
}
// trigger initSystemRoots if it hasn't already been called before we
// take the lock
_ = systemRootsPool()
systemRootsMu.Lock()
defer systemRootsMu.Unlock()
if fallbacksSet {
panic("SetFallbackRoots has already been called")
}
fallbacksSet = true
if systemRoots != nil && (systemRoots.len() > 0 || systemRoots.systemPool) {
if x509usefallbackroots.Value() != "1" {
return
}
x509usefallbackroots.IncNonDefault()
}
systemRoots, systemRootsErr = roots, nil
}
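// A minimal sketch of wiring up fallback roots at program start, assuming the
// "embed" package is imported for //go:embed; the bundle file name is a
// placeholder:
//
//	//go:embed fallback-roots.pem
//	var fallbackPEM []byte
//
//	func init() {
//		pool := x509.NewCertPool()
//		if !pool.AppendCertsFromPEM(fallbackPEM) {
//			panic("invalid fallback roots bundle")
//		}
//		x509.SetFallbackRoots(pool)
//	}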
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package x509
import "internal/goos"
// Possible certificate files; stop after finding one.
var certFiles = []string{
"/etc/ssl/certs/ca-certificates.crt", // Debian/Ubuntu/Gentoo etc.
"/etc/pki/tls/certs/ca-bundle.crt", // Fedora/RHEL 6
"/etc/ssl/ca-bundle.pem", // OpenSUSE
"/etc/pki/tls/cacert.pem", // OpenELEC
"/etc/pki/ca-trust/extracted/pem/tls-ca-bundle.pem", // CentOS/RHEL 7
"/etc/ssl/cert.pem", // Alpine Linux
}
// Possible directories with certificate files; all will be read.
var certDirectories = []string{
"/etc/ssl/certs", // SLES10/SLES11, https://golang.org/issue/12139
"/etc/pki/tls/certs", // Fedora/RHEL
}
func init() {
if goos.IsAndroid == 1 {
certDirectories = append(certDirectories,
"/system/etc/security/cacerts", // Android system roots
"/data/misc/keychain/certs-added", // User trusted CA folder
)
}
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build aix || dragonfly || freebsd || (js && wasm) || linux || netbsd || openbsd || solaris || wasip1
package x509
import (
"io/fs"
"os"
"path/filepath"
"strings"
)
const (
// certFileEnv is the environment variable which identifies where to locate
// the SSL certificate file. If set this overrides the system default.
certFileEnv = "SSL_CERT_FILE"
// certDirEnv is the environment variable which identifies which directory
// to check for SSL certificate files. If set this overrides the system default.
// It is a colon separated list of directories.
// See https://www.openssl.org/docs/man1.0.2/man1/c_rehash.html.
certDirEnv = "SSL_CERT_DIR"
)
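// A minimal sketch of overriding the defaults on Unix-like systems, assuming
// "os" is imported; the path is a placeholder. The system roots are loaded
// once per process (guarded by a sync.Once), so the variable must be set
// before the pool is first used. SSL_CERT_DIR works the same way for
// directories:
//
//	os.Setenv("SSL_CERT_FILE", "/opt/certs/ca-bundle.pem")
//	pool, err := x509.SystemCertPool()
//	if err != nil {
//		// handle error
//	}
//	_ = pool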
func (c *Certificate) systemVerify(opts *VerifyOptions) (chains [][]*Certificate, err error) {
return nil, nil
}
func loadSystemRoots() (*CertPool, error) {
roots := NewCertPool()
files := certFiles
if f := os.Getenv(certFileEnv); f != "" {
files = []string{f}
}
var firstErr error
for _, file := range files {
data, err := os.ReadFile(file)
if err == nil {
roots.AppendCertsFromPEM(data)
break
}
if firstErr == nil && !os.IsNotExist(err) {
firstErr = err
}
}
dirs := certDirectories
if d := os.Getenv(certDirEnv); d != "" {
// OpenSSL and BoringSSL both use ":" as the SSL_CERT_DIR separator.
// See:
// * https://golang.org/issue/35325
// * https://www.openssl.org/docs/man1.0.2/man1/c_rehash.html
dirs = strings.Split(d, ":")
}
for _, directory := range dirs {
fis, err := readUniqueDirectoryEntries(directory)
if err != nil {
if firstErr == nil && !os.IsNotExist(err) {
firstErr = err
}
continue
}
for _, fi := range fis {
data, err := os.ReadFile(directory + "/" + fi.Name())
if err == nil {
roots.AppendCertsFromPEM(data)
}
}
}
if roots.len() > 0 || firstErr == nil {
return roots, nil
}
return nil, firstErr
}
// readUniqueDirectoryEntries is like os.ReadDir but omits
// symlinks that point within the directory.
func readUniqueDirectoryEntries(dir string) ([]fs.DirEntry, error) {
files, err := os.ReadDir(dir)
if err != nil {
return nil, err
}
uniq := files[:0]
for _, f := range files {
if !isSameDirSymlink(f, dir) {
uniq = append(uniq, f)
}
}
return uniq, nil
}
// isSameDirSymlink reports whether f in dir is a symlink with a
// target not containing a slash.
func isSameDirSymlink(f fs.DirEntry, dir string) bool {
if f.Type()&fs.ModeSymlink == 0 {
return false
}
target, err := os.Readlink(filepath.Join(dir, f.Name()))
return err == nil && !strings.Contains(target, "/")
}
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package x509
import (
"crypto/ecdh"
"crypto/ecdsa"
"crypto/elliptic"
"encoding/asn1"
"errors"
"fmt"
"math/big"
)
const ecPrivKeyVersion = 1
// ecPrivateKey reflects an ASN.1 Elliptic Curve Private Key Structure.
// References:
//
// RFC 5915
// SEC1 - http://www.secg.org/sec1-v2.pdf
//
// Per RFC 5915 the NamedCurveOID is marked as ASN.1 OPTIONAL, but in
// practice it is usually present.
type ecPrivateKey struct {
Version int
PrivateKey []byte
NamedCurveOID asn1.ObjectIdentifier `asn1:"optional,explicit,tag:0"`
PublicKey asn1.BitString `asn1:"optional,explicit,tag:1"`
}
// ParseECPrivateKey parses an EC private key in SEC 1, ASN.1 DER form.
//
// This kind of key is commonly encoded in PEM blocks of type "EC PRIVATE KEY".
func ParseECPrivateKey(der []byte) (*ecdsa.PrivateKey, error) {
return parseECPrivateKey(nil, der)
}
// MarshalECPrivateKey converts an EC private key to SEC 1, ASN.1 DER form.
//
// This kind of key is commonly encoded in PEM blocks of type "EC PRIVATE KEY".
// For a more flexible key format which is not EC specific, use
// [MarshalPKCS8PrivateKey].
func MarshalECPrivateKey(key *ecdsa.PrivateKey) ([]byte, error) {
oid, ok := oidFromNamedCurve(key.Curve)
if !ok {
return nil, errors.New("x509: unknown elliptic curve")
}
return marshalECPrivateKeyWithOID(key, oid)
}
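// A minimal SEC 1 round-trip sketch, assuming "crypto/ecdsa", "crypto/elliptic",
// "crypto/rand", and "encoding/pem":
//
//	key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
//	if err != nil {
//		// handle error
//	}
//	der, err := x509.MarshalECPrivateKey(key)
//	if err != nil {
//		// handle error
//	}
//	pemBytes := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: der})
//	key2, err := x509.ParseECPrivateKey(der)
//	_, _ = key2, pemBytes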
// marshalECPrivateKeyWithOID marshals an EC private key into ASN.1, DER format and
// sets the curve ID to the given OID, or omits it if OID is nil.
func marshalECPrivateKeyWithOID(key *ecdsa.PrivateKey, oid asn1.ObjectIdentifier) ([]byte, error) {
if !key.Curve.IsOnCurve(key.X, key.Y) {
return nil, errors.New("invalid elliptic key public key")
}
privateKey := make([]byte, (key.Curve.Params().N.BitLen()+7)/8)
return asn1.Marshal(ecPrivateKey{
Version: 1,
PrivateKey: key.D.FillBytes(privateKey),
NamedCurveOID: oid,
PublicKey: asn1.BitString{Bytes: elliptic.Marshal(key.Curve, key.X, key.Y)},
})
}
// marshalECDHPrivateKey marshals an EC private key into ASN.1, DER format
// suitable for NIST curves.
func marshalECDHPrivateKey(key *ecdh.PrivateKey) ([]byte, error) {
return asn1.Marshal(ecPrivateKey{
Version: 1,
PrivateKey: key.Bytes(),
PublicKey: asn1.BitString{Bytes: key.PublicKey().Bytes()},
})
}
// parseECPrivateKey parses an ASN.1 Elliptic Curve Private Key Structure.
// The OID for the named curve may be provided from another source (such as
// the PKCS8 container) - if it is provided then use this instead of the OID
// that may exist in the EC private key structure.
func parseECPrivateKey(namedCurveOID *asn1.ObjectIdentifier, der []byte) (key *ecdsa.PrivateKey, err error) {
var privKey ecPrivateKey
if _, err := asn1.Unmarshal(der, &privKey); err != nil {
if _, err := asn1.Unmarshal(der, &pkcs8{}); err == nil {
return nil, errors.New("x509: failed to parse private key (use ParsePKCS8PrivateKey instead for this key format)")
}
if _, err := asn1.Unmarshal(der, &pkcs1PrivateKey{}); err == nil {
return nil, errors.New("x509: failed to parse private key (use ParsePKCS1PrivateKey instead for this key format)")
}
return nil, errors.New("x509: failed to parse EC private key: " + err.Error())
}
if privKey.Version != ecPrivKeyVersion {
return nil, fmt.Errorf("x509: unknown EC private key version %d", privKey.Version)
}
var curve elliptic.Curve
if namedCurveOID != nil {
curve = namedCurveFromOID(*namedCurveOID)
} else {
curve = namedCurveFromOID(privKey.NamedCurveOID)
}
if curve == nil {
return nil, errors.New("x509: unknown elliptic curve")
}
k := new(big.Int).SetBytes(privKey.PrivateKey)
curveOrder := curve.Params().N
if k.Cmp(curveOrder) >= 0 {
return nil, errors.New("x509: invalid elliptic curve private key value")
}
priv := new(ecdsa.PrivateKey)
priv.Curve = curve
priv.D = k
privateKey := make([]byte, (curveOrder.BitLen()+7)/8)
// Some private keys have leading zero padding. This is invalid
// according to [SEC1], but this code will ignore it.
for len(privKey.PrivateKey) > len(privateKey) {
if privKey.PrivateKey[0] != 0 {
return nil, errors.New("x509: invalid private key length")
}
privKey.PrivateKey = privKey.PrivateKey[1:]
}
// Some private keys remove all leading zeros; this is also invalid
// according to [SEC1], but since OpenSSL used to do this, we ignore
// this too.
copy(privateKey[len(privateKey)-len(privKey.PrivateKey):], privKey.PrivateKey)
priv.X, priv.Y = curve.ScalarBaseMult(privateKey)
return priv, nil
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package x509
import (
"bytes"
"crypto"
"crypto/x509/pkix"
"errors"
"fmt"
"iter"
"maps"
"net"
"net/netip"
"net/url"
"reflect"
"runtime"
"strings"
"time"
"unicode/utf8"
)
type InvalidReason int
const (
// NotAuthorizedToSign results when a certificate is signed by another
// which isn't marked as a CA certificate.
NotAuthorizedToSign InvalidReason = iota
// Expired results when a certificate has expired, based on the time
// given in the VerifyOptions.
Expired
// CANotAuthorizedForThisName results when an intermediate or root
// certificate has a name constraint which doesn't permit a DNS or
// other name (including IP address) in the leaf certificate.
CANotAuthorizedForThisName
// TooManyIntermediates results when a path length constraint is
// violated.
TooManyIntermediates
// IncompatibleUsage results when the certificate's key usage indicates
// that it may only be used for a different purpose.
IncompatibleUsage
// NameMismatch results when the subject name of a parent certificate
// does not match the issuer name in the child.
NameMismatch
// NameConstraintsWithoutSANs is a legacy error and is no longer returned.
NameConstraintsWithoutSANs
// UnconstrainedName results when a CA certificate contains permitted
// name constraints, but leaf certificate contains a name of an
// unsupported or unconstrained type.
UnconstrainedName
// TooManyConstraints results when the number of comparison operations
// needed to check a certificate exceeds the limit set by
// VerifyOptions.MaxConstraintComparisions. This limit exists to
// prevent pathological certificates from consuming excessive amounts of
// CPU time to verify.
TooManyConstraints
// CANotAuthorizedForExtKeyUsage results when an intermediate or root
// certificate does not permit a requested extended key usage.
CANotAuthorizedForExtKeyUsage
// NoValidChains results when there are no valid chains to return.
NoValidChains
)
// CertificateInvalidError results when an odd error occurs. Users of this
// library probably want to handle all these errors uniformly.
type CertificateInvalidError struct {
Cert *Certificate
Reason InvalidReason
Detail string
}
func (e CertificateInvalidError) Error() string {
switch e.Reason {
case NotAuthorizedToSign:
return "x509: certificate is not authorized to sign other certificates"
case Expired:
return "x509: certificate has expired or is not yet valid: " + e.Detail
case CANotAuthorizedForThisName:
return "x509: a root or intermediate certificate is not authorized to sign for this name: " + e.Detail
case CANotAuthorizedForExtKeyUsage:
return "x509: a root or intermediate certificate is not authorized for an extended key usage: " + e.Detail
case TooManyIntermediates:
return "x509: too many intermediates for path length constraint"
case IncompatibleUsage:
return "x509: certificate specifies an incompatible key usage"
case NameMismatch:
return "x509: issuer name does not match subject from issuing certificate"
case NameConstraintsWithoutSANs:
return "x509: issuer has name constraints but leaf doesn't have a SAN extension"
case UnconstrainedName:
return "x509: issuer has name constraints but leaf contains unknown or unconstrained name: " + e.Detail
case NoValidChains:
s := "x509: no valid chains built"
if e.Detail != "" {
s = fmt.Sprintf("%s: %s", s, e.Detail)
}
return s
}
return "x509: unknown error"
}
// HostnameError results when the set of authorized names doesn't match the
// requested name.
type HostnameError struct {
Certificate *Certificate
Host string
}
func (h HostnameError) Error() string {
c := h.Certificate
if !c.hasSANExtension() && matchHostnames(c.Subject.CommonName, h.Host) {
return "x509: certificate relies on legacy Common Name field, use SANs instead"
}
var valid string
if ip := net.ParseIP(h.Host); ip != nil {
// Trying to validate an IP
if len(c.IPAddresses) == 0 {
return "x509: cannot validate certificate for " + h.Host + " because it doesn't contain any IP SANs"
}
for _, san := range c.IPAddresses {
if len(valid) > 0 {
valid += ", "
}
valid += san.String()
}
} else {
valid = strings.Join(c.DNSNames, ", ")
}
if len(valid) == 0 {
return "x509: certificate is not valid for any names, but wanted to match " + h.Host
}
return "x509: certificate is valid for " + valid + ", not " + h.Host
}
// UnknownAuthorityError results when the certificate issuer is unknown.
type UnknownAuthorityError struct {
Cert *Certificate
// hintErr contains an error that may be helpful in determining why an
// authority wasn't found.
hintErr error
// hintCert contains a possible authority certificate that was rejected
// because of the error in hintErr.
hintCert *Certificate
}
func (e UnknownAuthorityError) Error() string {
s := "x509: certificate signed by unknown authority"
if e.hintErr != nil {
certName := e.hintCert.Subject.CommonName
if len(certName) == 0 {
if len(e.hintCert.Subject.Organization) > 0 {
certName = e.hintCert.Subject.Organization[0]
} else {
certName = "serial:" + e.hintCert.SerialNumber.String()
}
}
s += fmt.Sprintf(" (possibly because of %q while trying to verify candidate authority certificate %q)", e.hintErr, certName)
}
return s
}
// SystemRootsError results when we fail to load the system root certificates.
type SystemRootsError struct {
Err error
}
func (se SystemRootsError) Error() string {
msg := "x509: failed to load system roots and no roots provided"
if se.Err != nil {
return msg + "; " + se.Err.Error()
}
return msg
}
func (se SystemRootsError) Unwrap() error { return se.Err }
// errNotParsed is returned when a certificate without ASN.1 contents is
// verified. Platform-specific verification needs the ASN.1 contents.
var errNotParsed = errors.New("x509: missing ASN.1 contents; use ParseCertificate")
// VerifyOptions contains parameters for Certificate.Verify.
type VerifyOptions struct {
// DNSName, if set, is checked against the leaf certificate with
// Certificate.VerifyHostname or the platform verifier.
DNSName string
// Intermediates is an optional pool of certificates that are not trust
// anchors, but can be used to form a chain from the leaf certificate to a
// root certificate.
Intermediates *CertPool
// Roots is the set of trusted root certificates the leaf certificate needs
// to chain up to. If nil, the system roots or the platform verifier are used.
Roots *CertPool
// CurrentTime is used to check the validity of all certificates in the
// chain. If zero, the current time is used.
CurrentTime time.Time
// KeyUsages specifies which Extended Key Usage values are acceptable. A
// chain is accepted if it allows any of the listed values. An empty list
// means ExtKeyUsageServerAuth. To accept any key usage, include ExtKeyUsageAny.
KeyUsages []ExtKeyUsage
// MaxConstraintComparisions is the maximum number of comparisons to
// perform when checking a given certificate's name constraints. If
// zero, a sensible default is used. This limit prevents pathological
// certificates from consuming excessive amounts of CPU time when
// validating. It does not apply to the platform verifier.
MaxConstraintComparisions int
// CertificatePolicies specifies which certificate policy OIDs are
// acceptable during policy validation. An empty CertificatePolicies
// field implies any valid policy is acceptable.
CertificatePolicies []OID
// The following policy fields are unexported, because we do not expect
// users to actually need to use them, but are useful for testing the
// policy validation code.
// inhibitPolicyMapping indicates if policy mapping should be allowed
// during path validation.
inhibitPolicyMapping bool
// requireExplicitPolicy indicates if explicit policies must be present
// for each certificate being validated.
requireExplicitPolicy bool
// inhibitAnyPolicy indicates if the anyPolicy policy should be
// processed if present in a certificate being validated.
inhibitAnyPolicy bool
}
const (
leafCertificate = iota
intermediateCertificate
rootCertificate
)
// rfc2821Mailbox represents a “mailbox” (which is an email address to most
// people) by breaking it into the “local” (i.e. before the '@') and “domain”
// parts.
type rfc2821Mailbox struct {
local, domain string
}
// parseRFC2821Mailbox parses an email address into local and domain parts,
// based on the ABNF for a “Mailbox” from RFC 2821. According to RFC 5280,
// Section 4.2.1.6 that's correct for an rfc822Name from a certificate: “The
// format of an rfc822Name is a "Mailbox" as defined in RFC 2821, Section 4.1.2”.
func parseRFC2821Mailbox(in string) (mailbox rfc2821Mailbox, ok bool) {
if len(in) == 0 {
return mailbox, false
}
localPartBytes := make([]byte, 0, len(in)/2)
if in[0] == '"' {
// Quoted-string = DQUOTE *qcontent DQUOTE
// non-whitespace-control = %d1-8 / %d11 / %d12 / %d14-31 / %d127
// qcontent = qtext / quoted-pair
// qtext = non-whitespace-control /
// %d33 / %d35-91 / %d93-126
// quoted-pair = ("\" text) / obs-qp
// text = %d1-9 / %d11 / %d12 / %d14-127 / obs-text
//
// (Names beginning with “obs-” are the obsolete syntax from RFC 2822,
// Section 4. Since it has been 16 years, we no longer accept that.)
in = in[1:]
QuotedString:
for {
if len(in) == 0 {
return mailbox, false
}
c := in[0]
in = in[1:]
switch {
case c == '"':
break QuotedString
case c == '\\':
// quoted-pair
if len(in) == 0 {
return mailbox, false
}
if in[0] == 11 ||
in[0] == 12 ||
(1 <= in[0] && in[0] <= 9) ||
(14 <= in[0] && in[0] <= 127) {
localPartBytes = append(localPartBytes, in[0])
in = in[1:]
} else {
return mailbox, false
}
case c == 11 ||
c == 12 ||
// Space (char 32) is not allowed based on the
// BNF, but RFC 3696 gives an example that
// assumes that it is. Several “verified”
// errata continue to argue about this point.
// We choose to accept it.
c == 32 ||
c == 33 ||
c == 127 ||
(1 <= c && c <= 8) ||
(14 <= c && c <= 31) ||
(35 <= c && c <= 91) ||
(93 <= c && c <= 126):
// qtext
localPartBytes = append(localPartBytes, c)
default:
return mailbox, false
}
}
} else {
// Atom ("." Atom)*
NextChar:
for len(in) > 0 {
// atext from RFC 2822, Section 3.2.4
c := in[0]
switch {
case c == '\\':
// Examples given in RFC 3696 suggest that
// escaped characters can appear outside of a
// quoted string. Several “verified” errata
// continue to argue the point. We choose to
// accept it.
in = in[1:]
if len(in) == 0 {
return mailbox, false
}
fallthrough
case ('0' <= c && c <= '9') ||
('a' <= c && c <= 'z') ||
('A' <= c && c <= 'Z') ||
c == '!' || c == '#' || c == '$' || c == '%' ||
c == '&' || c == '\'' || c == '*' || c == '+' ||
c == '-' || c == '/' || c == '=' || c == '?' ||
c == '^' || c == '_' || c == '`' || c == '{' ||
c == '|' || c == '}' || c == '~' || c == '.':
localPartBytes = append(localPartBytes, in[0])
in = in[1:]
default:
break NextChar
}
}
if len(localPartBytes) == 0 {
return mailbox, false
}
// From RFC 3696, Section 3:
// “period (".") may also appear, but may not be used to start
// or end the local part, nor may two or more consecutive
// periods appear.”
twoDots := []byte{'.', '.'}
if localPartBytes[0] == '.' ||
localPartBytes[len(localPartBytes)-1] == '.' ||
bytes.Contains(localPartBytes, twoDots) {
return mailbox, false
}
}
if len(in) == 0 || in[0] != '@' {
return mailbox, false
}
in = in[1:]
// The RFC specifies a format for domains, but that's known to be
// violated in practice so we accept that anything after an '@' is the
// domain part.
if _, ok := domainToReverseLabels(in); !ok {
return mailbox, false
}
mailbox.local = string(localPartBytes)
mailbox.domain = in
return mailbox, true
}
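// Worked examples of the parsing rules above (illustrative only):
//
//	parseRFC2821Mailbox("jdoe@example.com")       // local: "jdoe", domain: "example.com", ok
//	parseRFC2821Mailbox(`"john doe"@example.com`) // quoted local part; the space is accepted, ok
//	parseRFC2821Mailbox("example.com")            // ok == false: no '@'
//	parseRFC2821Mailbox(".jdoe@example.com")      // ok == false: local part may not start with '.'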
// domainToReverseLabels converts a textual domain name like foo.example.com to
// the list of labels in reverse order, e.g. ["com", "example", "foo"].
func domainToReverseLabels(domain string) (reverseLabels []string, ok bool) {
for len(domain) > 0 {
if i := strings.LastIndexByte(domain, '.'); i == -1 {
reverseLabels = append(reverseLabels, domain)
domain = ""
} else {
reverseLabels = append(reverseLabels, domain[i+1:])
domain = domain[:i]
if i == 0 { // domain == ""
// domain is prefixed with an empty label, append an empty
// string to reverseLabels to indicate this.
reverseLabels = append(reverseLabels, "")
}
}
}
if len(reverseLabels) > 0 && len(reverseLabels[0]) == 0 {
// An empty label at the end indicates an absolute (fully qualified) domain name.
return nil, false
}
for _, label := range reverseLabels {
if len(label) == 0 {
// Empty labels are otherwise invalid.
return nil, false
}
for _, c := range label {
if c < 33 || c > 126 {
// Invalid character.
return nil, false
}
}
}
return reverseLabels, true
}
func matchEmailConstraint(mailbox rfc2821Mailbox, constraint string) (bool, error) {
// If the constraint contains an @, then it specifies an exact mailbox
// name.
if strings.Contains(constraint, "@") {
constraintMailbox, ok := parseRFC2821Mailbox(constraint)
if !ok {
return false, fmt.Errorf("x509: internal error: cannot parse constraint %q", constraint)
}
return mailbox.local == constraintMailbox.local && strings.EqualFold(mailbox.domain, constraintMailbox.domain), nil
}
// Otherwise the constraint is like a DNS constraint of the domain part
// of the mailbox.
return matchDomainConstraint(mailbox.domain, constraint)
}
func matchURIConstraint(uri *url.URL, constraint string) (bool, error) {
// From RFC 5280, Section 4.2.1.10:
// “a uniformResourceIdentifier that does not include an authority
// component with a host name specified as a fully qualified domain
// name (e.g., if the URI either does not include an authority
// component or includes an authority component in which the host name
// is specified as an IP address), then the application MUST reject the
// certificate.”
host := uri.Host
if len(host) == 0 {
return false, fmt.Errorf("URI with empty host (%q) cannot be matched against constraints", uri.String())
}
if strings.Contains(host, ":") && !strings.HasSuffix(host, "]") {
var err error
host, _, err = net.SplitHostPort(uri.Host)
if err != nil {
return false, err
}
}
// netip.ParseAddr will reject the URI IPv6 literal form "[...]", so we
// check if _either_ the string parses as an IP, or if it is enclosed in
// square brackets.
if _, err := netip.ParseAddr(host); err == nil || (strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]")) {
return false, fmt.Errorf("URI with IP (%q) cannot be matched against constraints", uri.String())
}
return matchDomainConstraint(host, constraint)
}
func matchIPConstraint(ip net.IP, constraint *net.IPNet) (bool, error) {
if len(ip) != len(constraint.IP) {
return false, nil
}
for i := range ip {
if mask := constraint.Mask[i]; ip[i]&mask != constraint.IP[i]&mask {
return false, nil
}
}
return true, nil
}
func matchDomainConstraint(domain, constraint string) (bool, error) {
// The meaning of zero length constraints is not specified, but this
// code follows NSS and accepts them as matching everything.
if len(constraint) == 0 {
return true, nil
}
domainLabels, ok := domainToReverseLabels(domain)
if !ok {
return false, fmt.Errorf("x509: internal error: cannot parse domain %q", domain)
}
// RFC 5280 says that a leading period in a domain name means that at
// least one label must be prepended, but only for URI and email
// constraints, not DNS constraints. The code also supports that
// behaviour for DNS constraints.
mustHaveSubdomains := false
if constraint[0] == '.' {
mustHaveSubdomains = true
constraint = constraint[1:]
}
constraintLabels, ok := domainToReverseLabels(constraint)
if !ok {
return false, fmt.Errorf("x509: internal error: cannot parse domain %q", constraint)
}
if len(domainLabels) < len(constraintLabels) ||
(mustHaveSubdomains && len(domainLabels) == len(constraintLabels)) {
return false, nil
}
for i, constraintLabel := range constraintLabels {
if !strings.EqualFold(constraintLabel, domainLabels[i]) {
return false, nil
}
}
return true, nil
}
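// Worked examples of the constraint semantics above (first return value only,
// illustrative):
//
//	matchDomainConstraint("www.example.com", "example.com")  // true: trailing labels match
//	matchDomainConstraint("badexample.com", "example.com")   // false: comparison is label-wise, not a substring match
//	matchDomainConstraint("example.com", ".example.com")     // false: a leading period requires at least one extra label
//	matchDomainConstraint("www.example.com", ".example.com") // true
//	matchDomainConstraint("anything.at.all", "")             // true: empty constraints match everything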
// checkNameConstraints checks that c permits a child certificate to claim the
// given name, of type nameType. The argument parsedName contains the parsed
// form of name, suitable for passing to the match function. The total number
// of comparisons is tracked in the given count and should not exceed the given
// limit.
func (c *Certificate) checkNameConstraints(count *int,
maxConstraintComparisons int,
nameType string,
name string,
parsedName any,
match func(parsedName, constraint any) (match bool, err error),
permitted, excluded any) error {
excludedValue := reflect.ValueOf(excluded)
*count += excludedValue.Len()
if *count > maxConstraintComparisons {
return CertificateInvalidError{c, TooManyConstraints, ""}
}
for i := 0; i < excludedValue.Len(); i++ {
constraint := excludedValue.Index(i).Interface()
match, err := match(parsedName, constraint)
if err != nil {
return CertificateInvalidError{c, CANotAuthorizedForThisName, err.Error()}
}
if match {
return CertificateInvalidError{c, CANotAuthorizedForThisName, fmt.Sprintf("%s %q is excluded by constraint %q", nameType, name, constraint)}
}
}
permittedValue := reflect.ValueOf(permitted)
*count += permittedValue.Len()
if *count > maxConstraintComparisons {
return CertificateInvalidError{c, TooManyConstraints, ""}
}
ok := true
for i := 0; i < permittedValue.Len(); i++ {
constraint := permittedValue.Index(i).Interface()
var err error
if ok, err = match(parsedName, constraint); err != nil {
return CertificateInvalidError{c, CANotAuthorizedForThisName, err.Error()}
}
if ok {
break
}
}
if !ok {
return CertificateInvalidError{c, CANotAuthorizedForThisName, fmt.Sprintf("%s %q is not permitted by any constraint", nameType, name)}
}
return nil
}
// isValid performs validity checks on c given that it is a candidate to append
// to the chain in currentChain.
func (c *Certificate) isValid(certType int, currentChain []*Certificate, opts *VerifyOptions) error {
if len(c.UnhandledCriticalExtensions) > 0 {
return UnhandledCriticalExtension{}
}
if len(currentChain) > 0 {
child := currentChain[len(currentChain)-1]
if !bytes.Equal(child.RawIssuer, c.RawSubject) {
return CertificateInvalidError{c, NameMismatch, ""}
}
}
now := opts.CurrentTime
if now.IsZero() {
now = time.Now()
}
if now.Before(c.NotBefore) {
return CertificateInvalidError{
Cert: c,
Reason: Expired,
Detail: fmt.Sprintf("current time %s is before %s", now.Format(time.RFC3339), c.NotBefore.Format(time.RFC3339)),
}
} else if now.After(c.NotAfter) {
return CertificateInvalidError{
Cert: c,
Reason: Expired,
Detail: fmt.Sprintf("current time %s is after %s", now.Format(time.RFC3339), c.NotAfter.Format(time.RFC3339)),
}
}
maxConstraintComparisons := opts.MaxConstraintComparisions
if maxConstraintComparisons == 0 {
maxConstraintComparisons = 250000
}
comparisonCount := 0
if certType == intermediateCertificate || certType == rootCertificate {
if len(currentChain) == 0 {
return errors.New("x509: internal error: empty chain when appending CA cert")
}
}
if (certType == intermediateCertificate || certType == rootCertificate) &&
c.hasNameConstraints() {
toCheck := []*Certificate{}
for _, c := range currentChain {
if c.hasSANExtension() {
toCheck = append(toCheck, c)
}
}
for _, sanCert := range toCheck {
err := forEachSAN(sanCert.getSANExtension(), func(tag int, data []byte) error {
switch tag {
case nameTypeEmail:
name := string(data)
mailbox, ok := parseRFC2821Mailbox(name)
if !ok {
return fmt.Errorf("x509: cannot parse rfc822Name %q", mailbox)
}
if err := c.checkNameConstraints(&comparisonCount, maxConstraintComparisons, "email address", name, mailbox,
func(parsedName, constraint any) (bool, error) {
return matchEmailConstraint(parsedName.(rfc2821Mailbox), constraint.(string))
}, c.PermittedEmailAddresses, c.ExcludedEmailAddresses); err != nil {
return err
}
case nameTypeDNS:
name := string(data)
if _, ok := domainToReverseLabels(name); !ok {
return fmt.Errorf("x509: cannot parse dnsName %q", name)
}
if err := c.checkNameConstraints(&comparisonCount, maxConstraintComparisons, "DNS name", name, name,
func(parsedName, constraint any) (bool, error) {
return matchDomainConstraint(parsedName.(string), constraint.(string))
}, c.PermittedDNSDomains, c.ExcludedDNSDomains); err != nil {
return err
}
case nameTypeURI:
name := string(data)
uri, err := url.Parse(name)
if err != nil {
return fmt.Errorf("x509: internal error: URI SAN %q failed to parse", name)
}
if err := c.checkNameConstraints(&comparisonCount, maxConstraintComparisons, "URI", name, uri,
func(parsedName, constraint any) (bool, error) {
return matchURIConstraint(parsedName.(*url.URL), constraint.(string))
}, c.PermittedURIDomains, c.ExcludedURIDomains); err != nil {
return err
}
case nameTypeIP:
ip := net.IP(data)
if l := len(ip); l != net.IPv4len && l != net.IPv6len {
return fmt.Errorf("x509: internal error: IP SAN %x failed to parse", data)
}
if err := c.checkNameConstraints(&comparisonCount, maxConstraintComparisons, "IP address", ip.String(), ip,
func(parsedName, constraint any) (bool, error) {
return matchIPConstraint(parsedName.(net.IP), constraint.(*net.IPNet))
}, c.PermittedIPRanges, c.ExcludedIPRanges); err != nil {
return err
}
default:
// Unknown SAN types are ignored.
}
return nil
})
if err != nil {
return err
}
}
}
// KeyUsage status flags are ignored. From Engineering Security, Peter
// Gutmann: A European government CA marked its signing certificates as
// being valid for encryption only, but no-one noticed. Another
// European CA marked its signature keys as not being valid for
// signatures. A different CA marked its own trusted root certificate
// as being invalid for certificate signing. Another national CA
// distributed a certificate to be used to encrypt data for the
// country’s tax authority that was marked as only being usable for
// digital signatures but not for encryption. Yet another CA reversed
// the order of the bit flags in the keyUsage due to confusion over
// encoding endianness, essentially setting a random keyUsage in
// certificates that it issued. Another CA created a self-invalidating
// certificate by adding a certificate policy statement stipulating
// that the certificate had to be used strictly as specified in the
// keyUsage, and a keyUsage containing a flag indicating that the RSA
// encryption key could only be used for Diffie-Hellman key agreement.
if certType == intermediateCertificate && (!c.BasicConstraintsValid || !c.IsCA) {
return CertificateInvalidError{c, NotAuthorizedToSign, ""}
}
if c.BasicConstraintsValid && c.MaxPathLen >= 0 {
numIntermediates := len(currentChain) - 1
if numIntermediates > c.MaxPathLen {
return CertificateInvalidError{c, TooManyIntermediates, ""}
}
}
return nil
}
// Verify attempts to verify c by building one or more chains from c to a
// certificate in opts.Roots, using certificates in opts.Intermediates if
// needed. If successful, it returns one or more chains where the first
// element of the chain is c and the last element is from opts.Roots.
//
// If opts.Roots is nil, the platform verifier might be used, and
// verification details might differ from what is described below. If system
// roots are unavailable the returned error will be of type SystemRootsError.
//
// Name constraints in the intermediates will be applied to all names claimed
// in the chain, not just opts.DNSName. Thus it is invalid for a leaf to claim
// example.com if an intermediate doesn't permit it, even if example.com is not
// the name being validated. Note that DirectoryName constraints are not
// supported.
//
// Name constraint validation follows the rules from RFC 5280, with the
// addition that DNS name constraints may use the leading period format
// defined for emails and URIs. When a constraint has a leading period
// it indicates that at least one additional label must be prepended to
// the constrained name to be considered valid.
//
// Extended Key Usage values are enforced nested down a chain, so an intermediate
// or root that enumerates EKUs prevents a leaf from asserting an EKU not in that
// list. (While this is not specified, it is common practice in order to limit
// the types of certificates a CA can issue.)
//
// Certificates that use SHA1WithRSA and ECDSAWithSHA1 signatures are not supported,
// and will not be used to build chains.
//
// Certificates other than c in the returned chains should not be modified.
//
// WARNING: this function doesn't do any revocation checking.
func (c *Certificate) Verify(opts VerifyOptions) (chains [][]*Certificate, err error) {
// Platform-specific verification needs the ASN.1 contents so
// this makes the behavior consistent across platforms.
if len(c.Raw) == 0 {
return nil, errNotParsed
}
for i := 0; i < opts.Intermediates.len(); i++ {
c, _, err := opts.Intermediates.cert(i)
if err != nil {
return nil, fmt.Errorf("crypto/x509: error fetching intermediate: %w", err)
}
if len(c.Raw) == 0 {
return nil, errNotParsed
}
}
// Use platform verifiers, where available, if Roots is from SystemCertPool.
if runtime.GOOS == "windows" || runtime.GOOS == "darwin" || runtime.GOOS == "ios" {
// Don't use the system verifier if the system pool was replaced with a non-system pool,
// i.e. if SetFallbackRoots was called with x509usefallbackroots=1.
systemPool := systemRootsPool()
if opts.Roots == nil && (systemPool == nil || systemPool.systemPool) {
return c.systemVerify(&opts)
}
if opts.Roots != nil && opts.Roots.systemPool {
platformChains, err := c.systemVerify(&opts)
// If the platform verifier succeeded, or there are no additional
// roots, return the platform verifier result. Otherwise, continue
// with the Go verifier.
if err == nil || opts.Roots.len() == 0 {
return platformChains, err
}
}
}
if opts.Roots == nil {
opts.Roots = systemRootsPool()
if opts.Roots == nil {
return nil, SystemRootsError{systemRootsErr}
}
}
err = c.isValid(leafCertificate, nil, &opts)
if err != nil {
return
}
if len(opts.DNSName) > 0 {
err = c.VerifyHostname(opts.DNSName)
if err != nil {
return
}
}
var candidateChains [][]*Certificate
if opts.Roots.contains(c) {
candidateChains = [][]*Certificate{{c}}
} else {
candidateChains, err = c.buildChains([]*Certificate{c}, nil, &opts)
if err != nil {
return nil, err
}
}
chains = make([][]*Certificate, 0, len(candidateChains))
var invalidPoliciesChains int
for _, candidate := range candidateChains {
if !policiesValid(candidate, opts) {
invalidPoliciesChains++
continue
}
chains = append(chains, candidate)
}
if len(chains) == 0 {
return nil, CertificateInvalidError{c, NoValidChains, "all candidate chains have invalid policies"}
}
for _, eku := range opts.KeyUsages {
if eku == ExtKeyUsageAny {
// If any key usage is acceptable, no need to check the chain for
// key usages.
return chains, nil
}
}
if len(opts.KeyUsages) == 0 {
opts.KeyUsages = []ExtKeyUsage{ExtKeyUsageServerAuth}
}
candidateChains = chains
chains = chains[:0]
var incompatibleKeyUsageChains int
for _, candidate := range candidateChains {
if !checkChainForKeyUsage(candidate, opts.KeyUsages) {
incompatibleKeyUsageChains++
continue
}
chains = append(chains, candidate)
}
if len(chains) == 0 {
var details []string
if incompatibleKeyUsageChains > 0 {
if invalidPoliciesChains == 0 {
return nil, CertificateInvalidError{c, IncompatibleUsage, ""}
}
details = append(details, fmt.Sprintf("%d chains with incompatible key usage", incompatibleKeyUsageChains))
}
if invalidPoliciesChains > 0 {
details = append(details, fmt.Sprintf("%d chains with invalid policies", invalidPoliciesChains))
}
err = CertificateInvalidError{c, NoValidChains, strings.Join(details, ", ")}
return nil, err
}
return chains, nil
}
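// A minimal caller-side sketch of chain verification, assuming "encoding/pem";
// rootPEM, certPEM, and the host name are placeholders:
//
//	roots := x509.NewCertPool()
//	if !roots.AppendCertsFromPEM([]byte(rootPEM)) {
//		// failed to parse any root certificate
//	}
//	block, _ := pem.Decode([]byte(certPEM))
//	cert, err := x509.ParseCertificate(block.Bytes)
//	if err != nil {
//		// handle parse error
//	}
//	opts := x509.VerifyOptions{
//		DNSName: "www.example.com",
//		Roots:   roots,
//	}
//	chains, err := cert.Verify(opts)
//	if err != nil {
//		// verification failed; err may be a HostnameError, UnknownAuthorityError,
//		// CertificateInvalidError, or SystemRootsError
//	}
//	_ = chains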
func appendToFreshChain(chain []*Certificate, cert *Certificate) []*Certificate {
n := make([]*Certificate, len(chain)+1)
copy(n, chain)
n[len(chain)] = cert
return n
}
// alreadyInChain checks whether a candidate certificate is present in a chain.
// Rather than doing a direct byte-for-byte equality check, we check if the
// subject, public key, and SAN, if present, are equal. This prevents loops that
// are created by mutual cross-signatures, or other cross-signature bridge
// oddities.
func alreadyInChain(candidate *Certificate, chain []*Certificate) bool {
type pubKeyEqual interface {
Equal(crypto.PublicKey) bool
}
var candidateSAN *pkix.Extension
for _, ext := range candidate.Extensions {
if ext.Id.Equal(oidExtensionSubjectAltName) {
candidateSAN = &ext
break
}
}
for _, cert := range chain {
if !bytes.Equal(candidate.RawSubject, cert.RawSubject) {
continue
}
if !candidate.PublicKey.(pubKeyEqual).Equal(cert.PublicKey) {
continue
}
var certSAN *pkix.Extension
for _, ext := range cert.Extensions {
if ext.Id.Equal(oidExtensionSubjectAltName) {
certSAN = &ext
break
}
}
if candidateSAN == nil && certSAN == nil {
return true
} else if candidateSAN == nil || certSAN == nil {
return false
}
if bytes.Equal(candidateSAN.Value, certSAN.Value) {
return true
}
}
return false
}
// maxChainSignatureChecks is the maximum number of CheckSignatureFrom calls
// that an invocation of buildChains will (transitively) make. Most chains are
// less than 15 certificates long, so this leaves space for multiple chains and
// for failed checks due to different intermediates having the same Subject.
const maxChainSignatureChecks = 100
func (c *Certificate) buildChains(currentChain []*Certificate, sigChecks *int, opts *VerifyOptions) (chains [][]*Certificate, err error) {
var (
hintErr error
hintCert *Certificate
)
considerCandidate := func(certType int, candidate potentialParent) {
if candidate.cert.PublicKey == nil || alreadyInChain(candidate.cert, currentChain) {
return
}
if sigChecks == nil {
sigChecks = new(int)
}
*sigChecks++
if *sigChecks > maxChainSignatureChecks {
err = errors.New("x509: signature check attempts limit reached while verifying certificate chain")
return
}
if err := c.CheckSignatureFrom(candidate.cert); err != nil {
if hintErr == nil {
hintErr = err
hintCert = candidate.cert
}
return
}
err = candidate.cert.isValid(certType, currentChain, opts)
if err != nil {
if hintErr == nil {
hintErr = err
hintCert = candidate.cert
}
return
}
if candidate.constraint != nil {
if err := candidate.constraint(currentChain); err != nil {
if hintErr == nil {
hintErr = err
hintCert = candidate.cert
}
return
}
}
switch certType {
case rootCertificate:
chains = append(chains, appendToFreshChain(currentChain, candidate.cert))
case intermediateCertificate:
var childChains [][]*Certificate
childChains, err = candidate.cert.buildChains(appendToFreshChain(currentChain, candidate.cert), sigChecks, opts)
chains = append(chains, childChains...)
}
}
for _, root := range opts.Roots.findPotentialParents(c) {
considerCandidate(rootCertificate, root)
}
for _, intermediate := range opts.Intermediates.findPotentialParents(c) {
considerCandidate(intermediateCertificate, intermediate)
}
if len(chains) > 0 {
err = nil
}
if len(chains) == 0 && err == nil {
err = UnknownAuthorityError{c, hintErr, hintCert}
}
return
}
func validHostnamePattern(host string) bool { return validHostname(host, true) }
func validHostnameInput(host string) bool { return validHostname(host, false) }
// validHostname reports whether host is a valid hostname that can be matched or
// matched against according to RFC 6125 2.2, with some leniency to accommodate
// legacy values.
func validHostname(host string, isPattern bool) bool {
if !isPattern {
host = strings.TrimSuffix(host, ".")
}
if len(host) == 0 {
return false
}
if host == "*" {
// Bare wildcards are not allowed, they are not valid DNS names,
// nor are they allowed per RFC 6125.
return false
}
for i, part := range strings.Split(host, ".") {
if part == "" {
// Empty label.
return false
}
if isPattern && i == 0 && part == "*" {
// Only allow full left-most wildcards, as those are the only ones
// we match, and matching literal '*' characters is probably never
// the expected behavior.
continue
}
for j, c := range part {
if 'a' <= c && c <= 'z' {
continue
}
if '0' <= c && c <= '9' {
continue
}
if 'A' <= c && c <= 'Z' {
continue
}
if c == '-' && j != 0 {
continue
}
if c == '_' {
// Not a valid character in hostnames, but commonly
// found in deployments outside the WebPKI.
continue
}
return false
}
}
return true
}
func matchExactly(hostA, hostB string) bool {
if hostA == "" || hostA == "." || hostB == "" || hostB == "." {
return false
}
return toLowerCaseASCII(hostA) == toLowerCaseASCII(hostB)
}
func matchHostnames(pattern, host string) bool {
pattern = toLowerCaseASCII(pattern)
host = toLowerCaseASCII(strings.TrimSuffix(host, "."))
if len(pattern) == 0 || len(host) == 0 {
return false
}
patternParts := strings.Split(pattern, ".")
hostParts := strings.Split(host, ".")
if len(patternParts) != len(hostParts) {
return false
}
for i, patternPart := range patternParts {
if i == 0 && patternPart == "*" {
continue
}
if patternPart != hostParts[i] {
return false
}
}
return true
}
// toLowerCaseASCII returns a lower-case version of in. See RFC 6125 6.4.1. We use
// an explicitly ASCII function to avoid any sharp corners resulting from
// performing Unicode operations on DNS labels.
func toLowerCaseASCII(in string) string {
// If the string is already lower-case then there's nothing to do.
isAlreadyLowerCase := true
for _, c := range in {
if c == utf8.RuneError {
// If we get a UTF-8 error then there might be
// upper-case ASCII bytes in the invalid sequence.
isAlreadyLowerCase = false
break
}
if 'A' <= c && c <= 'Z' {
isAlreadyLowerCase = false
break
}
}
if isAlreadyLowerCase {
return in
}
out := []byte(in)
for i, c := range out {
if 'A' <= c && c <= 'Z' {
out[i] += 'a' - 'A'
}
}
return string(out)
}
// VerifyHostname returns nil if c is a valid certificate for the named host.
// Otherwise it returns an error describing the mismatch.
//
// IP addresses can be optionally enclosed in square brackets and are checked
// against the IPAddresses field. Other names are checked case insensitively
// against the DNSNames field. If the names are valid hostnames, the certificate
// fields can have a wildcard as the complete left-most label (e.g. *.example.com).
//
// Note that the legacy Common Name field is ignored.
func (c *Certificate) VerifyHostname(h string) error {
// IP addresses may be written in [ ].
candidateIP := h
if len(h) >= 3 && h[0] == '[' && h[len(h)-1] == ']' {
candidateIP = h[1 : len(h)-1]
}
if ip := net.ParseIP(candidateIP); ip != nil {
// We only match IP addresses against IP SANs.
// See RFC 6125, Appendix B.2.
for _, candidate := range c.IPAddresses {
if ip.Equal(candidate) {
return nil
}
}
return HostnameError{c, candidateIP}
}
candidateName := toLowerCaseASCII(h) // Save allocations inside the loop.
validCandidateName := validHostnameInput(candidateName)
for _, match := range c.DNSNames {
// Ideally, we'd only match valid hostnames according to RFC 6125 like
// browsers (more or less) do, but in practice Go is used in a wider
// array of contexts and can't even assume DNS resolution. Instead,
// always allow perfect matches, and only apply wildcard and trailing
// dot processing to valid hostnames.
if validCandidateName && validHostnamePattern(match) {
if matchHostnames(match, candidateName) {
return nil
}
} else {
if matchExactly(match, candidateName) {
return nil
}
}
}
return HostnameError{c, h}
}
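// A minimal sketch, assuming cert is a parsed *Certificate (a placeholder):
//
//	if err := cert.VerifyHostname("www.example.com"); err != nil {
//		// name does not match the DNS SANs (a wildcard is allowed as the left-most label)
//	}
//	if err := cert.VerifyHostname("[2001:db8::1]"); err != nil {
//		// bracketed IPs are matched against the IP SANs only
//	}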
func checkChainForKeyUsage(chain []*Certificate, keyUsages []ExtKeyUsage) bool {
usages := make([]ExtKeyUsage, len(keyUsages))
copy(usages, keyUsages)
if len(chain) == 0 {
return false
}
usagesRemaining := len(usages)
// We walk down the list and cross out any usages that aren't supported
// by each certificate. If we cross out all the usages, then the chain
// is unacceptable.
NextCert:
for i := len(chain) - 1; i >= 0; i-- {
cert := chain[i]
if len(cert.ExtKeyUsage) == 0 && len(cert.UnknownExtKeyUsage) == 0 {
// The certificate doesn't have any extended key usage specified.
continue
}
for _, usage := range cert.ExtKeyUsage {
if usage == ExtKeyUsageAny {
// The certificate is explicitly good for any usage.
continue NextCert
}
}
const invalidUsage ExtKeyUsage = -1
NextRequestedUsage:
for i, requestedUsage := range usages {
if requestedUsage == invalidUsage {
continue
}
for _, usage := range cert.ExtKeyUsage {
if requestedUsage == usage {
continue NextRequestedUsage
}
}
usages[i] = invalidUsage
usagesRemaining--
if usagesRemaining == 0 {
return false
}
}
}
return true
}
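// Illustrative sketch (not part of the original source): a leaf asserting only
// id-kp-serverAuth satisfies a request for ExtKeyUsageServerAuth but not one
// for ExtKeyUsageCodeSigning. The chain below is a hypothetical minimal value;
// real chains come from Verify.
func checkChainForKeyUsageSketch() (serverAuthOK, codeSigningOK bool) {
	leaf := &Certificate{ExtKeyUsage: []ExtKeyUsage{ExtKeyUsageServerAuth}}
	ca := &Certificate{ExtKeyUsage: []ExtKeyUsage{ExtKeyUsageAny}}
	chain := []*Certificate{leaf, ca}
	serverAuthOK = checkChainForKeyUsage(chain, []ExtKeyUsage{ExtKeyUsageServerAuth})   // true
	codeSigningOK = checkChainForKeyUsage(chain, []ExtKeyUsage{ExtKeyUsageCodeSigning}) // false
	return serverAuthOK, codeSigningOK
}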
func mustNewOIDFromInts(ints []uint64) OID {
oid, err := OIDFromInts(ints)
if err != nil {
panic(fmt.Sprintf("OIDFromInts(%v) unexpected error: %v", ints, err))
}
return oid
}
type policyGraphNode struct {
validPolicy OID
expectedPolicySet []OID
// we do not implement qualifiers, so we don't track qualifier_set
parents map[*policyGraphNode]bool
children map[*policyGraphNode]bool
}
func newPolicyGraphNode(valid OID, parents []*policyGraphNode) *policyGraphNode {
n := &policyGraphNode{
validPolicy: valid,
expectedPolicySet: []OID{valid},
children: map[*policyGraphNode]bool{},
parents: map[*policyGraphNode]bool{},
}
for _, p := range parents {
p.children[n] = true
n.parents[p] = true
}
return n
}
type policyGraph struct {
strata []map[string]*policyGraphNode
// map of OID -> nodes at strata[depth-1] with OID in their expectedPolicySet
parentIndex map[string][]*policyGraphNode
depth int
}
var anyPolicyOID = mustNewOIDFromInts([]uint64{2, 5, 29, 32, 0})
func newPolicyGraph() *policyGraph {
root := policyGraphNode{
validPolicy: anyPolicyOID,
expectedPolicySet: []OID{anyPolicyOID},
children: map[*policyGraphNode]bool{},
parents: map[*policyGraphNode]bool{},
}
return &policyGraph{
depth: 0,
strata: []map[string]*policyGraphNode{{string(anyPolicyOID.der): &root}},
}
}
func (pg *policyGraph) insert(n *policyGraphNode) {
pg.strata[pg.depth][string(n.validPolicy.der)] = n
}
func (pg *policyGraph) parentsWithExpected(expected OID) []*policyGraphNode {
if pg.depth == 0 {
return nil
}
return pg.parentIndex[string(expected.der)]
}
func (pg *policyGraph) parentWithAnyPolicy() *policyGraphNode {
if pg.depth == 0 {
return nil
}
return pg.strata[pg.depth-1][string(anyPolicyOID.der)]
}
func (pg *policyGraph) parents() iter.Seq[*policyGraphNode] {
if pg.depth == 0 {
return nil
}
return maps.Values(pg.strata[pg.depth-1])
}
func (pg *policyGraph) leaves() map[string]*policyGraphNode {
return pg.strata[pg.depth]
}
func (pg *policyGraph) leafWithPolicy(policy OID) *policyGraphNode {
return pg.strata[pg.depth][string(policy.der)]
}
func (pg *policyGraph) deleteLeaf(policy OID) {
n := pg.strata[pg.depth][string(policy.der)]
if n == nil {
return
}
for p := range n.parents {
delete(p.children, n)
}
for c := range n.children {
delete(c.parents, n)
}
delete(pg.strata[pg.depth], string(policy.der))
}
func (pg *policyGraph) validPolicyNodes() []*policyGraphNode {
var validNodes []*policyGraphNode
for i := pg.depth; i >= 0; i-- {
for _, n := range pg.strata[i] {
if n.validPolicy.Equal(anyPolicyOID) {
continue
}
if len(n.parents) == 1 {
for p := range n.parents {
if p.validPolicy.Equal(anyPolicyOID) {
validNodes = append(validNodes, n)
}
}
}
}
}
return validNodes
}
func (pg *policyGraph) prune() {
for i := pg.depth - 1; i > 0; i-- {
for _, n := range pg.strata[i] {
if len(n.children) == 0 {
for p := range n.parents {
delete(p.children, n)
}
delete(pg.strata[i], string(n.validPolicy.der))
}
}
}
}
func (pg *policyGraph) incrDepth() {
pg.parentIndex = map[string][]*policyGraphNode{}
for _, n := range pg.strata[pg.depth] {
for _, e := range n.expectedPolicySet {
pg.parentIndex[string(e.der)] = append(pg.parentIndex[string(e.der)], n)
}
}
pg.depth++
pg.strata = append(pg.strata, map[string]*policyGraphNode{})
}
func policiesValid(chain []*Certificate, opts VerifyOptions) bool {
// The following code implements the policy verification algorithm as
// specified in RFC 5280 and updated by RFC 9618. In particular the
// following sections are replaced by RFC 9618:
// * 6.1.2 (a)
// * 6.1.3 (d)
// * 6.1.3 (e)
// * 6.1.3 (f)
// * 6.1.4 (b)
// * 6.1.5 (g)
if len(chain) == 1 {
return true
}
// n is the length of the chain minus the trust anchor
n := len(chain) - 1
pg := newPolicyGraph()
var inhibitAnyPolicy, explicitPolicy, policyMapping int
if !opts.inhibitAnyPolicy {
inhibitAnyPolicy = n + 1
}
if !opts.requireExplicitPolicy {
explicitPolicy = n + 1
}
if !opts.inhibitPolicyMapping {
policyMapping = n + 1
}
initialUserPolicySet := map[string]bool{}
for _, p := range opts.CertificatePolicies {
initialUserPolicySet[string(p.der)] = true
}
// If the user does not pass any policies, we consider
// that equivalent to passing anyPolicyOID.
if len(initialUserPolicySet) == 0 {
initialUserPolicySet[string(anyPolicyOID.der)] = true
}
for i := n - 1; i >= 0; i-- {
cert := chain[i]
isSelfSigned := bytes.Equal(cert.RawIssuer, cert.RawSubject)
// 6.1.3 (e) -- as updated by RFC 9618
if len(cert.Policies) == 0 {
pg = nil
}
// 6.1.3 (f) -- as updated by RFC 9618
if explicitPolicy == 0 && pg == nil {
return false
}
if pg != nil {
pg.incrDepth()
policies := map[string]bool{}
// 6.1.3 (d) (1) -- as updated by RFC 9618
for _, policy := range cert.Policies {
policies[string(policy.der)] = true
if policy.Equal(anyPolicyOID) {
continue
}
// 6.1.3 (d) (1) (i) -- as updated by RFC 9618
parents := pg.parentsWithExpected(policy)
if len(parents) == 0 {
// 6.1.3 (d) (1) (ii) -- as updated by RFC 9618
if anyParent := pg.parentWithAnyPolicy(); anyParent != nil {
parents = []*policyGraphNode{anyParent}
}
}
if len(parents) > 0 {
pg.insert(newPolicyGraphNode(policy, parents))
}
}
// 6.1.3 (d) (2) -- as updated by RFC 9618
// NOTE: in the check "n-i < n" our i is different from the i in the specification.
// In the specification chains go from the trust anchor to the leaf, whereas our
// chains go from the leaf to the trust anchor, so our i's are inverted. Our
// check here matches the check "i < n" in the specification.
if policies[string(anyPolicyOID.der)] && (inhibitAnyPolicy > 0 || (n-i < n && isSelfSigned)) {
missing := map[string][]*policyGraphNode{}
leaves := pg.leaves()
for p := range pg.parents() {
for _, expected := range p.expectedPolicySet {
if leaves[string(expected.der)] == nil {
missing[string(expected.der)] = append(missing[string(expected.der)], p)
}
}
}
for oidStr, parents := range missing {
pg.insert(newPolicyGraphNode(OID{der: []byte(oidStr)}, parents))
}
}
// 6.1.3 (d) (3) -- as updated by RFC 9618
pg.prune()
if i != 0 {
// 6.1.4 (b) -- as updated by RFC 9618
if len(cert.PolicyMappings) > 0 {
// collect map of issuer -> []subject
mappings := map[string][]OID{}
for _, mapping := range cert.PolicyMappings {
if policyMapping > 0 {
if mapping.IssuerDomainPolicy.Equal(anyPolicyOID) || mapping.SubjectDomainPolicy.Equal(anyPolicyOID) {
// Invalid mapping
return false
}
mappings[string(mapping.IssuerDomainPolicy.der)] = append(mappings[string(mapping.IssuerDomainPolicy.der)], mapping.SubjectDomainPolicy)
} else {
// 6.1.4 (b) (3) (i) -- as updated by RFC 9618
pg.deleteLeaf(mapping.IssuerDomainPolicy)
// 6.1.4 (b) (3) (ii) -- as updated by RFC 9618
pg.prune()
}
}
for issuerStr, subjectPolicies := range mappings {
// 6.1.4 (b) (1) -- as updated by RFC 9618
if matching := pg.leafWithPolicy(OID{der: []byte(issuerStr)}); matching != nil {
matching.expectedPolicySet = subjectPolicies
} else if matching := pg.leafWithPolicy(anyPolicyOID); matching != nil {
// 6.1.4 (b) (2) -- as updated by RFC 9618
n := newPolicyGraphNode(OID{der: []byte(issuerStr)}, []*policyGraphNode{matching})
n.expectedPolicySet = subjectPolicies
pg.insert(n)
}
}
}
}
}
if i != 0 {
// 6.1.4 (h)
if !isSelfSigned {
if explicitPolicy > 0 {
explicitPolicy--
}
if policyMapping > 0 {
policyMapping--
}
if inhibitAnyPolicy > 0 {
inhibitAnyPolicy--
}
}
// 6.1.4 (i)
if (cert.RequireExplicitPolicy > 0 || cert.RequireExplicitPolicyZero) && cert.RequireExplicitPolicy < explicitPolicy {
explicitPolicy = cert.RequireExplicitPolicy
}
if (cert.InhibitPolicyMapping > 0 || cert.InhibitPolicyMappingZero) && cert.InhibitPolicyMapping < policyMapping {
policyMapping = cert.InhibitPolicyMapping
}
// 6.1.4 (j)
if (cert.InhibitAnyPolicy > 0 || cert.InhibitAnyPolicyZero) && cert.InhibitAnyPolicy < inhibitAnyPolicy {
inhibitAnyPolicy = cert.InhibitAnyPolicy
}
}
}
// 6.1.5 (a)
if explicitPolicy > 0 {
explicitPolicy--
}
// 6.1.5 (b)
if chain[0].RequireExplicitPolicyZero {
explicitPolicy = 0
}
// 6.1.5 (g) (1) -- as updated by RFC 9618
var validPolicyNodeSet []*policyGraphNode
// 6.1.5 (g) (2) -- as updated by RFC 9618
if pg != nil {
validPolicyNodeSet = pg.validPolicyNodes()
// 6.1.5 (g) (3) -- as updated by RFC 9618
if currentAny := pg.leafWithPolicy(anyPolicyOID); currentAny != nil {
validPolicyNodeSet = append(validPolicyNodeSet, currentAny)
}
}
// 6.1.5 (g) (4) -- as updated by RFC 9618
authorityConstrainedPolicySet := map[string]bool{}
for _, n := range validPolicyNodeSet {
authorityConstrainedPolicySet[string(n.validPolicy.der)] = true
}
// 6.1.5 (g) (5) -- as updated by RFC 9618
userConstrainedPolicySet := maps.Clone(authorityConstrainedPolicySet)
// 6.1.5 (g) (6) -- as updated by RFC 9618
if len(initialUserPolicySet) != 1 || !initialUserPolicySet[string(anyPolicyOID.der)] {
// 6.1.5 (g) (6) (i) -- as updated by RFC 9618
for p := range userConstrainedPolicySet {
if !initialUserPolicySet[p] {
delete(userConstrainedPolicySet, p)
}
}
// 6.1.5 (g) (6) (ii) -- as updated by RFC 9618
if authorityConstrainedPolicySet[string(anyPolicyOID.der)] {
for policy := range initialUserPolicySet {
userConstrainedPolicySet[policy] = true
}
}
}
if explicitPolicy == 0 && len(userConstrainedPolicySet) == 0 {
return false
}
return true
}
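// Illustrative sketch (not part of the original source): policiesValid runs as
// part of Certificate.Verify. A caller that needs the chain to be valid for a
// specific policy can set VerifyOptions.CertificatePolicies; the OID below
// (the CA/Browser Forum DV policy) is just an example value.
func verifyWithPolicySketch(leaf *Certificate, roots *CertPool) ([][]*Certificate, error) {
	dvPolicy, err := OIDFromInts([]uint64{2, 23, 140, 1, 2, 1})
	if err != nil {
		return nil, err
	}
	return leaf.Verify(VerifyOptions{
		Roots:               roots,
		CertificatePolicies: []OID{dvPolicy},
	})
}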
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package x509 implements a subset of the X.509 standard.
//
// It allows parsing and generating certificates, certificate signing
// requests, certificate revocation lists, and encoded public and private keys.
// It provides a certificate verifier, complete with a chain builder.
//
// The package targets the X.509 technical profile defined by the IETF (RFC
// 2459/3280/5280), and as further restricted by the CA/Browser Forum Baseline
// Requirements. There is minimal support for features outside of these
// profiles, as the primary goal of the package is to provide compatibility
// with the publicly trusted TLS certificate ecosystem and its policies and
// constraints.
//
// On macOS and Windows, certificate verification is handled by system APIs, but
// the package aims to apply consistent validation rules across operating
// systems.
package x509
import (
"bytes"
"crypto"
"crypto/ecdh"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/elliptic"
"crypto/rsa"
"crypto/sha1"
"crypto/sha256"
"crypto/x509/pkix"
"encoding/asn1"
"encoding/pem"
"errors"
"fmt"
"internal/godebug"
"io"
"math/big"
"net"
"net/url"
"strconv"
"time"
"unicode"
// Explicitly import these for their crypto.RegisterHash init side-effects.
// Keep these as blank imports, even if they're imported above.
_ "crypto/sha1"
_ "crypto/sha256"
_ "crypto/sha512"
"golang.org/x/crypto/cryptobyte"
cryptobyte_asn1 "golang.org/x/crypto/cryptobyte/asn1"
)
// pkixPublicKey reflects a PKIX public key structure. See SubjectPublicKeyInfo
// in RFC 3280.
type pkixPublicKey struct {
Algo pkix.AlgorithmIdentifier
BitString asn1.BitString
}
// ParsePKIXPublicKey parses a public key in PKIX, ASN.1 DER form. The encoded
// public key is a SubjectPublicKeyInfo structure (see RFC 5280, Section 4.1).
//
// It returns a *[rsa.PublicKey], *[dsa.PublicKey], *[ecdsa.PublicKey],
// [ed25519.PublicKey] (not a pointer), or *[ecdh.PublicKey] (for X25519).
// More types might be supported in the future.
//
// This kind of key is commonly encoded in PEM blocks of type "PUBLIC KEY".
func ParsePKIXPublicKey(derBytes []byte) (pub any, err error) {
var pki publicKeyInfo
if rest, err := asn1.Unmarshal(derBytes, &pki); err != nil {
if _, err := asn1.Unmarshal(derBytes, &pkcs1PublicKey{}); err == nil {
return nil, errors.New("x509: failed to parse public key (use ParsePKCS1PublicKey instead for this key format)")
}
return nil, err
} else if len(rest) != 0 {
return nil, errors.New("x509: trailing data after ASN.1 of public-key")
}
return parsePublicKey(&pki)
}
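// Illustrative sketch (not part of the original source): decoding a PEM
// "PUBLIC KEY" block and switching on the concrete key type. pemBytes is a
// hypothetical PEM-encoded input.
func parsePKIXPublicKeySketch(pemBytes []byte) (any, error) {
	block, _ := pem.Decode(pemBytes)
	if block == nil || block.Type != "PUBLIC KEY" {
		return nil, errors.New("no PUBLIC KEY block found")
	}
	pub, err := ParsePKIXPublicKey(block.Bytes)
	if err != nil {
		return nil, err
	}
	switch pub := pub.(type) {
	case *rsa.PublicKey, *ecdsa.PublicKey, ed25519.PublicKey, *ecdh.PublicKey:
		return pub, nil
	default:
		return nil, fmt.Errorf("unexpected public key type %T", pub)
	}
}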
func marshalPublicKey(pub any) (publicKeyBytes []byte, publicKeyAlgorithm pkix.AlgorithmIdentifier, err error) {
switch pub := pub.(type) {
case *rsa.PublicKey:
publicKeyBytes, err = asn1.Marshal(pkcs1PublicKey{
N: pub.N,
E: pub.E,
})
if err != nil {
return nil, pkix.AlgorithmIdentifier{}, err
}
publicKeyAlgorithm.Algorithm = oidPublicKeyRSA
// This is a NULL parameters value which is required by
// RFC 3279, Section 2.3.1.
publicKeyAlgorithm.Parameters = asn1.NullRawValue
case *ecdsa.PublicKey:
oid, ok := oidFromNamedCurve(pub.Curve)
if !ok {
return nil, pkix.AlgorithmIdentifier{}, errors.New("x509: unsupported elliptic curve")
}
if !pub.Curve.IsOnCurve(pub.X, pub.Y) {
return nil, pkix.AlgorithmIdentifier{}, errors.New("x509: invalid elliptic curve public key")
}
publicKeyBytes = elliptic.Marshal(pub.Curve, pub.X, pub.Y)
publicKeyAlgorithm.Algorithm = oidPublicKeyECDSA
var paramBytes []byte
paramBytes, err = asn1.Marshal(oid)
if err != nil {
return
}
publicKeyAlgorithm.Parameters.FullBytes = paramBytes
case ed25519.PublicKey:
publicKeyBytes = pub
publicKeyAlgorithm.Algorithm = oidPublicKeyEd25519
case *ecdh.PublicKey:
publicKeyBytes = pub.Bytes()
if pub.Curve() == ecdh.X25519() {
publicKeyAlgorithm.Algorithm = oidPublicKeyX25519
} else {
oid, ok := oidFromECDHCurve(pub.Curve())
if !ok {
return nil, pkix.AlgorithmIdentifier{}, errors.New("x509: unsupported elliptic curve")
}
publicKeyAlgorithm.Algorithm = oidPublicKeyECDSA
var paramBytes []byte
paramBytes, err = asn1.Marshal(oid)
if err != nil {
return
}
publicKeyAlgorithm.Parameters.FullBytes = paramBytes
}
default:
return nil, pkix.AlgorithmIdentifier{}, fmt.Errorf("x509: unsupported public key type: %T", pub)
}
return publicKeyBytes, publicKeyAlgorithm, nil
}
// MarshalPKIXPublicKey converts a public key to PKIX, ASN.1 DER form.
// The encoded public key is a SubjectPublicKeyInfo structure
// (see RFC 5280, Section 4.1).
//
// The following key types are currently supported: *[rsa.PublicKey],
// *[ecdsa.PublicKey], [ed25519.PublicKey] (not a pointer), and *[ecdh.PublicKey].
// Unsupported key types result in an error.
//
// This kind of key is commonly encoded in PEM blocks of type "PUBLIC KEY".
func MarshalPKIXPublicKey(pub any) ([]byte, error) {
var publicKeyBytes []byte
var publicKeyAlgorithm pkix.AlgorithmIdentifier
var err error
if publicKeyBytes, publicKeyAlgorithm, err = marshalPublicKey(pub); err != nil {
return nil, err
}
pkix := pkixPublicKey{
Algo: publicKeyAlgorithm,
BitString: asn1.BitString{
Bytes: publicKeyBytes,
BitLength: 8 * len(publicKeyBytes),
},
}
ret, _ := asn1.Marshal(pkix)
return ret, nil
}
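// Illustrative sketch (not part of the original source): wrapping a supported
// public key in a PEM "PUBLIC KEY" block. pub is assumed to be one of the key
// types listed above.
func marshalPKIXPublicKeySketch(pub any) ([]byte, error) {
	der, err := MarshalPKIXPublicKey(pub)
	if err != nil {
		return nil, err
	}
	return pem.EncodeToMemory(&pem.Block{Type: "PUBLIC KEY", Bytes: der}), nil
}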
// These structures reflect the ASN.1 structure of X.509 certificates:
type certificate struct {
TBSCertificate tbsCertificate
SignatureAlgorithm pkix.AlgorithmIdentifier
SignatureValue asn1.BitString
}
type tbsCertificate struct {
Raw asn1.RawContent
Version int `asn1:"optional,explicit,default:0,tag:0"`
SerialNumber *big.Int
SignatureAlgorithm pkix.AlgorithmIdentifier
Issuer asn1.RawValue
Validity validity
Subject asn1.RawValue
PublicKey publicKeyInfo
UniqueId asn1.BitString `asn1:"optional,tag:1"`
SubjectUniqueId asn1.BitString `asn1:"optional,tag:2"`
Extensions []pkix.Extension `asn1:"omitempty,optional,explicit,tag:3"`
}
type dsaAlgorithmParameters struct {
P, Q, G *big.Int
}
type validity struct {
NotBefore, NotAfter time.Time
}
type publicKeyInfo struct {
Raw asn1.RawContent
Algorithm pkix.AlgorithmIdentifier
PublicKey asn1.BitString
}
// RFC 5280, 4.2.1.1
type authKeyId struct {
Id []byte `asn1:"optional,tag:0"`
}
type SignatureAlgorithm int
const (
UnknownSignatureAlgorithm SignatureAlgorithm = iota
MD2WithRSA // Unsupported.
MD5WithRSA // Only supported for signing, not verification.
SHA1WithRSA // Only supported for signing, and verification of CRLs, CSRs, and OCSP responses.
SHA256WithRSA
SHA384WithRSA
SHA512WithRSA
DSAWithSHA1 // Unsupported.
DSAWithSHA256 // Unsupported.
ECDSAWithSHA1 // Only supported for signing, and verification of CRLs, CSRs, and OCSP responses.
ECDSAWithSHA256
ECDSAWithSHA384
ECDSAWithSHA512
SHA256WithRSAPSS
SHA384WithRSAPSS
SHA512WithRSAPSS
PureEd25519
)
func (algo SignatureAlgorithm) isRSAPSS() bool {
for _, details := range signatureAlgorithmDetails {
if details.algo == algo {
return details.isRSAPSS
}
}
return false
}
func (algo SignatureAlgorithm) hashFunc() crypto.Hash {
for _, details := range signatureAlgorithmDetails {
if details.algo == algo {
return details.hash
}
}
return crypto.Hash(0)
}
func (algo SignatureAlgorithm) String() string {
for _, details := range signatureAlgorithmDetails {
if details.algo == algo {
return details.name
}
}
return strconv.Itoa(int(algo))
}
type PublicKeyAlgorithm int
const (
UnknownPublicKeyAlgorithm PublicKeyAlgorithm = iota
RSA
DSA // Only supported for parsing.
ECDSA
Ed25519
)
var publicKeyAlgoName = [...]string{
RSA: "RSA",
DSA: "DSA",
ECDSA: "ECDSA",
Ed25519: "Ed25519",
}
func (algo PublicKeyAlgorithm) String() string {
if 0 < algo && int(algo) < len(publicKeyAlgoName) {
return publicKeyAlgoName[algo]
}
return strconv.Itoa(int(algo))
}
// OIDs for signature algorithms
//
// pkcs-1 OBJECT IDENTIFIER ::= {
// iso(1) member-body(2) us(840) rsadsi(113549) pkcs(1) 1 }
//
// RFC 3279 2.2.1 RSA Signature Algorithms
//
// md5WithRSAEncryption OBJECT IDENTIFIER ::= { pkcs-1 4 }
//
// sha-1WithRSAEncryption OBJECT IDENTIFIER ::= { pkcs-1 5 }
//
// dsaWithSha1 OBJECT IDENTIFIER ::= {
// iso(1) member-body(2) us(840) x9-57(10040) x9cm(4) 3 }
//
// RFC 3279 2.2.3 ECDSA Signature Algorithm
//
// ecdsa-with-SHA1 OBJECT IDENTIFIER ::= {
// iso(1) member-body(2) us(840) ansi-x962(10045)
// signatures(4) ecdsa-with-SHA1(1)}
//
// RFC 4055 5 PKCS #1 Version 1.5
//
// sha256WithRSAEncryption OBJECT IDENTIFIER ::= { pkcs-1 11 }
//
// sha384WithRSAEncryption OBJECT IDENTIFIER ::= { pkcs-1 12 }
//
// sha512WithRSAEncryption OBJECT IDENTIFIER ::= { pkcs-1 13 }
//
// RFC 5758 3.1 DSA Signature Algorithms
//
// dsaWithSha256 OBJECT IDENTIFIER ::= {
// joint-iso-ccitt(2) country(16) us(840) organization(1) gov(101)
// csor(3) algorithms(4) id-dsa-with-sha2(3) 2}
//
// RFC 5758 3.2 ECDSA Signature Algorithm
//
// ecdsa-with-SHA256 OBJECT IDENTIFIER ::= { iso(1) member-body(2)
// us(840) ansi-X9-62(10045) signatures(4) ecdsa-with-SHA2(3) 2 }
//
// ecdsa-with-SHA384 OBJECT IDENTIFIER ::= { iso(1) member-body(2)
// us(840) ansi-X9-62(10045) signatures(4) ecdsa-with-SHA2(3) 3 }
//
// ecdsa-with-SHA512 OBJECT IDENTIFIER ::= { iso(1) member-body(2)
// us(840) ansi-X9-62(10045) signatures(4) ecdsa-with-SHA2(3) 4 }
//
// RFC 8410 3 Curve25519 and Curve448 Algorithm Identifiers
//
// id-Ed25519 OBJECT IDENTIFIER ::= { 1 3 101 112 }
var (
oidSignatureMD5WithRSA = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 1, 4}
oidSignatureSHA1WithRSA = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 1, 5}
oidSignatureSHA256WithRSA = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 1, 11}
oidSignatureSHA384WithRSA = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 1, 12}
oidSignatureSHA512WithRSA = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 1, 13}
oidSignatureRSAPSS = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 1, 10}
oidSignatureDSAWithSHA1 = asn1.ObjectIdentifier{1, 2, 840, 10040, 4, 3}
oidSignatureDSAWithSHA256 = asn1.ObjectIdentifier{2, 16, 840, 1, 101, 3, 4, 3, 2}
oidSignatureECDSAWithSHA1 = asn1.ObjectIdentifier{1, 2, 840, 10045, 4, 1}
oidSignatureECDSAWithSHA256 = asn1.ObjectIdentifier{1, 2, 840, 10045, 4, 3, 2}
oidSignatureECDSAWithSHA384 = asn1.ObjectIdentifier{1, 2, 840, 10045, 4, 3, 3}
oidSignatureECDSAWithSHA512 = asn1.ObjectIdentifier{1, 2, 840, 10045, 4, 3, 4}
oidSignatureEd25519 = asn1.ObjectIdentifier{1, 3, 101, 112}
oidSHA256 = asn1.ObjectIdentifier{2, 16, 840, 1, 101, 3, 4, 2, 1}
oidSHA384 = asn1.ObjectIdentifier{2, 16, 840, 1, 101, 3, 4, 2, 2}
oidSHA512 = asn1.ObjectIdentifier{2, 16, 840, 1, 101, 3, 4, 2, 3}
oidMGF1 = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 1, 8}
// oidISOSignatureSHA1WithRSA means the same as oidSignatureSHA1WithRSA
// but it's specified by ISO. Microsoft's makecert.exe has been known
// to produce certificates with this OID.
oidISOSignatureSHA1WithRSA = asn1.ObjectIdentifier{1, 3, 14, 3, 2, 29}
)
var signatureAlgorithmDetails = []struct {
algo SignatureAlgorithm
name string
oid asn1.ObjectIdentifier
params asn1.RawValue
pubKeyAlgo PublicKeyAlgorithm
hash crypto.Hash
isRSAPSS bool
}{
{MD5WithRSA, "MD5-RSA", oidSignatureMD5WithRSA, asn1.NullRawValue, RSA, crypto.MD5, false},
{SHA1WithRSA, "SHA1-RSA", oidSignatureSHA1WithRSA, asn1.NullRawValue, RSA, crypto.SHA1, false},
{SHA1WithRSA, "SHA1-RSA", oidISOSignatureSHA1WithRSA, asn1.NullRawValue, RSA, crypto.SHA1, false},
{SHA256WithRSA, "SHA256-RSA", oidSignatureSHA256WithRSA, asn1.NullRawValue, RSA, crypto.SHA256, false},
{SHA384WithRSA, "SHA384-RSA", oidSignatureSHA384WithRSA, asn1.NullRawValue, RSA, crypto.SHA384, false},
{SHA512WithRSA, "SHA512-RSA", oidSignatureSHA512WithRSA, asn1.NullRawValue, RSA, crypto.SHA512, false},
{SHA256WithRSAPSS, "SHA256-RSAPSS", oidSignatureRSAPSS, pssParametersSHA256, RSA, crypto.SHA256, true},
{SHA384WithRSAPSS, "SHA384-RSAPSS", oidSignatureRSAPSS, pssParametersSHA384, RSA, crypto.SHA384, true},
{SHA512WithRSAPSS, "SHA512-RSAPSS", oidSignatureRSAPSS, pssParametersSHA512, RSA, crypto.SHA512, true},
{DSAWithSHA1, "DSA-SHA1", oidSignatureDSAWithSHA1, emptyRawValue, DSA, crypto.SHA1, false},
{DSAWithSHA256, "DSA-SHA256", oidSignatureDSAWithSHA256, emptyRawValue, DSA, crypto.SHA256, false},
{ECDSAWithSHA1, "ECDSA-SHA1", oidSignatureECDSAWithSHA1, emptyRawValue, ECDSA, crypto.SHA1, false},
{ECDSAWithSHA256, "ECDSA-SHA256", oidSignatureECDSAWithSHA256, emptyRawValue, ECDSA, crypto.SHA256, false},
{ECDSAWithSHA384, "ECDSA-SHA384", oidSignatureECDSAWithSHA384, emptyRawValue, ECDSA, crypto.SHA384, false},
{ECDSAWithSHA512, "ECDSA-SHA512", oidSignatureECDSAWithSHA512, emptyRawValue, ECDSA, crypto.SHA512, false},
{PureEd25519, "Ed25519", oidSignatureEd25519, emptyRawValue, Ed25519, crypto.Hash(0) /* no pre-hashing */, false},
}
var emptyRawValue = asn1.RawValue{}
// DER encoded RSA PSS parameters for the
// SHA256, SHA384, and SHA512 hashes as defined in RFC 3447, Appendix A.2.3.
// The parameters contain the following values:
// - hashAlgorithm contains the associated hash identifier with NULL parameters
// - maskGenAlgorithm contains the MGF1 identifier parameterized with the associated hash
// - saltLength contains the length of the associated hash
// - trailerField is omitted, implying the default trailerFieldBC value
var (
pssParametersSHA256 = asn1.RawValue{FullBytes: []byte{48, 52, 160, 15, 48, 13, 6, 9, 96, 134, 72, 1, 101, 3, 4, 2, 1, 5, 0, 161, 28, 48, 26, 6, 9, 42, 134, 72, 134, 247, 13, 1, 1, 8, 48, 13, 6, 9, 96, 134, 72, 1, 101, 3, 4, 2, 1, 5, 0, 162, 3, 2, 1, 32}}
pssParametersSHA384 = asn1.RawValue{FullBytes: []byte{48, 52, 160, 15, 48, 13, 6, 9, 96, 134, 72, 1, 101, 3, 4, 2, 2, 5, 0, 161, 28, 48, 26, 6, 9, 42, 134, 72, 134, 247, 13, 1, 1, 8, 48, 13, 6, 9, 96, 134, 72, 1, 101, 3, 4, 2, 2, 5, 0, 162, 3, 2, 1, 48}}
pssParametersSHA512 = asn1.RawValue{FullBytes: []byte{48, 52, 160, 15, 48, 13, 6, 9, 96, 134, 72, 1, 101, 3, 4, 2, 3, 5, 0, 161, 28, 48, 26, 6, 9, 42, 134, 72, 134, 247, 13, 1, 1, 8, 48, 13, 6, 9, 96, 134, 72, 1, 101, 3, 4, 2, 3, 5, 0, 162, 3, 2, 1, 64}}
)
// pssParameters reflects the parameters in an AlgorithmIdentifier that
// specifies RSA PSS. See RFC 3447, Appendix A.2.3.
type pssParameters struct {
// The following three fields are not marked as
// optional because the default values specify SHA-1,
// which is no longer suitable for use in signatures.
Hash pkix.AlgorithmIdentifier `asn1:"explicit,tag:0"`
MGF pkix.AlgorithmIdentifier `asn1:"explicit,tag:1"`
SaltLength int `asn1:"explicit,tag:2"`
TrailerField int `asn1:"optional,explicit,tag:3,default:1"`
}
func getSignatureAlgorithmFromAI(ai pkix.AlgorithmIdentifier) SignatureAlgorithm {
if ai.Algorithm.Equal(oidSignatureEd25519) {
// RFC 8410, Section 3
// > For all of the OIDs, the parameters MUST be absent.
if len(ai.Parameters.FullBytes) != 0 {
return UnknownSignatureAlgorithm
}
}
if !ai.Algorithm.Equal(oidSignatureRSAPSS) {
for _, details := range signatureAlgorithmDetails {
if ai.Algorithm.Equal(details.oid) {
return details.algo
}
}
return UnknownSignatureAlgorithm
}
// RSA PSS is special because it encodes important parameters
// in the Parameters.
var params pssParameters
if _, err := asn1.Unmarshal(ai.Parameters.FullBytes, &params); err != nil {
return UnknownSignatureAlgorithm
}
var mgf1HashFunc pkix.AlgorithmIdentifier
if _, err := asn1.Unmarshal(params.MGF.Parameters.FullBytes, &mgf1HashFunc); err != nil {
return UnknownSignatureAlgorithm
}
// PSS is greatly overburdened with options. This code forces them into
// three buckets by requiring that the MGF1 hash function always match the
// message hash function (as recommended in RFC 3447, Section 8.1), that the
// salt length matches the hash length, and that the trailer field has the
// default value.
if (len(params.Hash.Parameters.FullBytes) != 0 && !bytes.Equal(params.Hash.Parameters.FullBytes, asn1.NullBytes)) ||
!params.MGF.Algorithm.Equal(oidMGF1) ||
!mgf1HashFunc.Algorithm.Equal(params.Hash.Algorithm) ||
(len(mgf1HashFunc.Parameters.FullBytes) != 0 && !bytes.Equal(mgf1HashFunc.Parameters.FullBytes, asn1.NullBytes)) ||
params.TrailerField != 1 {
return UnknownSignatureAlgorithm
}
switch {
case params.Hash.Algorithm.Equal(oidSHA256) && params.SaltLength == 32:
return SHA256WithRSAPSS
case params.Hash.Algorithm.Equal(oidSHA384) && params.SaltLength == 48:
return SHA384WithRSAPSS
case params.Hash.Algorithm.Equal(oidSHA512) && params.SaltLength == 64:
return SHA512WithRSAPSS
}
return UnknownSignatureAlgorithm
}
var (
// RFC 3279, 2.3 Public Key Algorithms
//
// pkcs-1 OBJECT IDENTIFIER ::== { iso(1) member-body(2) us(840)
// rsadsi(113549) pkcs(1) 1 }
//
// rsaEncryption OBJECT IDENTIFIER ::== { pkcs1-1 1 }
//
// id-dsa OBJECT IDENTIFIER ::== { iso(1) member-body(2) us(840)
// x9-57(10040) x9cm(4) 1 }
oidPublicKeyRSA = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 1, 1}
oidPublicKeyDSA = asn1.ObjectIdentifier{1, 2, 840, 10040, 4, 1}
// RFC 5480, 2.1.1 Unrestricted Algorithm Identifier and Parameters
//
// id-ecPublicKey OBJECT IDENTIFIER ::= {
// iso(1) member-body(2) us(840) ansi-X9-62(10045) keyType(2) 1 }
oidPublicKeyECDSA = asn1.ObjectIdentifier{1, 2, 840, 10045, 2, 1}
// RFC 8410, Section 3
//
// id-X25519 OBJECT IDENTIFIER ::= { 1 3 101 110 }
// id-Ed25519 OBJECT IDENTIFIER ::= { 1 3 101 112 }
oidPublicKeyX25519 = asn1.ObjectIdentifier{1, 3, 101, 110}
oidPublicKeyEd25519 = asn1.ObjectIdentifier{1, 3, 101, 112}
)
// getPublicKeyAlgorithmFromOID returns the exposed PublicKeyAlgorithm
// identifier for public key types supported in certificates and CSRs. Marshal
// and Parse functions may support a different set of public key types.
func getPublicKeyAlgorithmFromOID(oid asn1.ObjectIdentifier) PublicKeyAlgorithm {
switch {
case oid.Equal(oidPublicKeyRSA):
return RSA
case oid.Equal(oidPublicKeyDSA):
return DSA
case oid.Equal(oidPublicKeyECDSA):
return ECDSA
case oid.Equal(oidPublicKeyEd25519):
return Ed25519
}
return UnknownPublicKeyAlgorithm
}
// RFC 5480, 2.1.1.1. Named Curve
//
// secp224r1 OBJECT IDENTIFIER ::= {
// iso(1) identified-organization(3) certicom(132) curve(0) 33 }
//
// secp256r1 OBJECT IDENTIFIER ::= {
// iso(1) member-body(2) us(840) ansi-X9-62(10045) curves(3)
// prime(1) 7 }
//
// secp384r1 OBJECT IDENTIFIER ::= {
// iso(1) identified-organization(3) certicom(132) curve(0) 34 }
//
// secp521r1 OBJECT IDENTIFIER ::= {
// iso(1) identified-organization(3) certicom(132) curve(0) 35 }
//
// NB: secp256r1 is equivalent to prime256v1
var (
oidNamedCurveP224 = asn1.ObjectIdentifier{1, 3, 132, 0, 33}
oidNamedCurveP256 = asn1.ObjectIdentifier{1, 2, 840, 10045, 3, 1, 7}
oidNamedCurveP384 = asn1.ObjectIdentifier{1, 3, 132, 0, 34}
oidNamedCurveP521 = asn1.ObjectIdentifier{1, 3, 132, 0, 35}
)
func namedCurveFromOID(oid asn1.ObjectIdentifier) elliptic.Curve {
switch {
case oid.Equal(oidNamedCurveP224):
return elliptic.P224()
case oid.Equal(oidNamedCurveP256):
return elliptic.P256()
case oid.Equal(oidNamedCurveP384):
return elliptic.P384()
case oid.Equal(oidNamedCurveP521):
return elliptic.P521()
}
return nil
}
func oidFromNamedCurve(curve elliptic.Curve) (asn1.ObjectIdentifier, bool) {
switch curve {
case elliptic.P224():
return oidNamedCurveP224, true
case elliptic.P256():
return oidNamedCurveP256, true
case elliptic.P384():
return oidNamedCurveP384, true
case elliptic.P521():
return oidNamedCurveP521, true
}
return nil, false
}
func oidFromECDHCurve(curve ecdh.Curve) (asn1.ObjectIdentifier, bool) {
switch curve {
case ecdh.X25519():
return oidPublicKeyX25519, true
case ecdh.P256():
return oidNamedCurveP256, true
case ecdh.P384():
return oidNamedCurveP384, true
case ecdh.P521():
return oidNamedCurveP521, true
}
return nil, false
}
// KeyUsage represents the set of actions that are valid for a given key. It's
// a bitmap of the KeyUsage* constants.
type KeyUsage int
const (
KeyUsageDigitalSignature KeyUsage = 1 << iota
KeyUsageContentCommitment
KeyUsageKeyEncipherment
KeyUsageDataEncipherment
KeyUsageKeyAgreement
KeyUsageCertSign
KeyUsageCRLSign
KeyUsageEncipherOnly
KeyUsageDecipherOnly
)
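// Illustrative sketch (not part of the original source): KeyUsage is a bitmap,
// so individual usages are combined with bitwise OR and tested with bitwise
// AND.
func keyUsageSketch() bool {
	ku := KeyUsageDigitalSignature | KeyUsageKeyEncipherment
	return ku&KeyUsageCertSign != 0 // false: cert signing was not granted
}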
// RFC 5280, 4.2.1.12 Extended Key Usage
//
// anyExtendedKeyUsage OBJECT IDENTIFIER ::= { id-ce-extKeyUsage 0 }
//
// id-kp OBJECT IDENTIFIER ::= { id-pkix 3 }
//
// id-kp-serverAuth OBJECT IDENTIFIER ::= { id-kp 1 }
// id-kp-clientAuth OBJECT IDENTIFIER ::= { id-kp 2 }
// id-kp-codeSigning OBJECT IDENTIFIER ::= { id-kp 3 }
// id-kp-emailProtection OBJECT IDENTIFIER ::= { id-kp 4 }
// id-kp-timeStamping OBJECT IDENTIFIER ::= { id-kp 8 }
// id-kp-OCSPSigning OBJECT IDENTIFIER ::= { id-kp 9 }
var (
oidExtKeyUsageAny = asn1.ObjectIdentifier{2, 5, 29, 37, 0}
oidExtKeyUsageServerAuth = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 3, 1}
oidExtKeyUsageClientAuth = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 3, 2}
oidExtKeyUsageCodeSigning = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 3, 3}
oidExtKeyUsageEmailProtection = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 3, 4}
oidExtKeyUsageIPSECEndSystem = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 3, 5}
oidExtKeyUsageIPSECTunnel = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 3, 6}
oidExtKeyUsageIPSECUser = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 3, 7}
oidExtKeyUsageTimeStamping = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 3, 8}
oidExtKeyUsageOCSPSigning = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 3, 9}
oidExtKeyUsageMicrosoftServerGatedCrypto = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 311, 10, 3, 3}
oidExtKeyUsageNetscapeServerGatedCrypto = asn1.ObjectIdentifier{2, 16, 840, 1, 113730, 4, 1}
oidExtKeyUsageMicrosoftCommercialCodeSigning = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 311, 2, 1, 22}
oidExtKeyUsageMicrosoftKernelCodeSigning = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 311, 61, 1, 1}
)
// ExtKeyUsage represents an extended set of actions that are valid for a given key.
// Each of the ExtKeyUsage* constants defines a unique action.
type ExtKeyUsage int
const (
ExtKeyUsageAny ExtKeyUsage = iota
ExtKeyUsageServerAuth
ExtKeyUsageClientAuth
ExtKeyUsageCodeSigning
ExtKeyUsageEmailProtection
ExtKeyUsageIPSECEndSystem
ExtKeyUsageIPSECTunnel
ExtKeyUsageIPSECUser
ExtKeyUsageTimeStamping
ExtKeyUsageOCSPSigning
ExtKeyUsageMicrosoftServerGatedCrypto
ExtKeyUsageNetscapeServerGatedCrypto
ExtKeyUsageMicrosoftCommercialCodeSigning
ExtKeyUsageMicrosoftKernelCodeSigning
)
// extKeyUsageOIDs contains the mapping between an ExtKeyUsage and its OID.
var extKeyUsageOIDs = []struct {
extKeyUsage ExtKeyUsage
oid asn1.ObjectIdentifier
}{
{ExtKeyUsageAny, oidExtKeyUsageAny},
{ExtKeyUsageServerAuth, oidExtKeyUsageServerAuth},
{ExtKeyUsageClientAuth, oidExtKeyUsageClientAuth},
{ExtKeyUsageCodeSigning, oidExtKeyUsageCodeSigning},
{ExtKeyUsageEmailProtection, oidExtKeyUsageEmailProtection},
{ExtKeyUsageIPSECEndSystem, oidExtKeyUsageIPSECEndSystem},
{ExtKeyUsageIPSECTunnel, oidExtKeyUsageIPSECTunnel},
{ExtKeyUsageIPSECUser, oidExtKeyUsageIPSECUser},
{ExtKeyUsageTimeStamping, oidExtKeyUsageTimeStamping},
{ExtKeyUsageOCSPSigning, oidExtKeyUsageOCSPSigning},
{ExtKeyUsageMicrosoftServerGatedCrypto, oidExtKeyUsageMicrosoftServerGatedCrypto},
{ExtKeyUsageNetscapeServerGatedCrypto, oidExtKeyUsageNetscapeServerGatedCrypto},
{ExtKeyUsageMicrosoftCommercialCodeSigning, oidExtKeyUsageMicrosoftCommercialCodeSigning},
{ExtKeyUsageMicrosoftKernelCodeSigning, oidExtKeyUsageMicrosoftKernelCodeSigning},
}
func extKeyUsageFromOID(oid asn1.ObjectIdentifier) (eku ExtKeyUsage, ok bool) {
for _, pair := range extKeyUsageOIDs {
if oid.Equal(pair.oid) {
return pair.extKeyUsage, true
}
}
return
}
func oidFromExtKeyUsage(eku ExtKeyUsage) (oid asn1.ObjectIdentifier, ok bool) {
for _, pair := range extKeyUsageOIDs {
if eku == pair.extKeyUsage {
return pair.oid, true
}
}
return
}
// A Certificate represents an X.509 certificate.
type Certificate struct {
Raw []byte // Complete ASN.1 DER content (certificate, signature algorithm and signature).
RawTBSCertificate []byte // Certificate part of raw ASN.1 DER content.
RawSubjectPublicKeyInfo []byte // DER encoded SubjectPublicKeyInfo.
RawSubject []byte // DER encoded Subject
RawIssuer []byte // DER encoded Issuer
Signature []byte
SignatureAlgorithm SignatureAlgorithm
PublicKeyAlgorithm PublicKeyAlgorithm
PublicKey any
Version int
SerialNumber *big.Int
Issuer pkix.Name
Subject pkix.Name
NotBefore, NotAfter time.Time // Validity bounds.
KeyUsage KeyUsage
// Extensions contains raw X.509 extensions. When parsing certificates,
// this can be used to extract non-critical extensions that are not
// parsed by this package. When marshaling certificates, the Extensions
// field is ignored, see ExtraExtensions.
Extensions []pkix.Extension
// ExtraExtensions contains extensions to be copied, raw, into any
// marshaled certificates. Values override any extensions that would
// otherwise be produced based on the other fields. The ExtraExtensions
// field is not populated when parsing certificates, see Extensions.
ExtraExtensions []pkix.Extension
// UnhandledCriticalExtensions contains a list of extension IDs that
// were not (fully) processed when parsing. Verify will fail if this
// slice is non-empty, unless verification is delegated to an OS
// library which understands all the critical extensions.
//
// Users can access these extensions using Extensions and can remove
// elements from this slice if they believe that they have been
// handled.
UnhandledCriticalExtensions []asn1.ObjectIdentifier
ExtKeyUsage []ExtKeyUsage // Sequence of extended key usages.
UnknownExtKeyUsage []asn1.ObjectIdentifier // Encountered extended key usages unknown to this package.
// BasicConstraintsValid indicates whether IsCA, MaxPathLen,
// and MaxPathLenZero are valid.
BasicConstraintsValid bool
IsCA bool
// MaxPathLen and MaxPathLenZero indicate the presence and
// value of the BasicConstraints' "pathLenConstraint".
//
// When parsing a certificate, a positive non-zero MaxPathLen
// means that the field was specified, -1 means it was unset,
// and MaxPathLenZero being true means that the field was
// explicitly set to zero. The case of MaxPathLen==0 with MaxPathLenZero==false
// should be treated as equivalent to -1 (unset).
//
// When generating a certificate, an unset pathLenConstraint
// can be requested with either MaxPathLen == -1 or using the
// zero value for both MaxPathLen and MaxPathLenZero.
MaxPathLen int
// MaxPathLenZero indicates that BasicConstraintsValid==true
// and MaxPathLen==0 should be interpreted as an actual
// maximum path length of zero. Otherwise, that combination is
// interpreted as MaxPathLen not being set.
MaxPathLenZero bool
SubjectKeyId []byte
AuthorityKeyId []byte
// RFC 5280, 4.2.2.1 (Authority Information Access)
OCSPServer []string
IssuingCertificateURL []string
// Subject Alternative Name values. (Note that these values may not be valid
// if invalid values were contained within a parsed certificate. For
// example, an element of DNSNames may not be a valid DNS domain name.)
DNSNames []string
EmailAddresses []string
IPAddresses []net.IP
URIs []*url.URL
// Name constraints
PermittedDNSDomainsCritical bool // if true then the name constraints are marked critical.
PermittedDNSDomains []string
ExcludedDNSDomains []string
PermittedIPRanges []*net.IPNet
ExcludedIPRanges []*net.IPNet
PermittedEmailAddresses []string
ExcludedEmailAddresses []string
PermittedURIDomains []string
ExcludedURIDomains []string
// CRL Distribution Points
CRLDistributionPoints []string
// PolicyIdentifiers contains asn1.ObjectIdentifiers, the components
// of which are limited to int32. If a certificate contains a policy which
// cannot be represented by asn1.ObjectIdentifier, it will not be included in
// PolicyIdentifiers, but will be present in Policies, which contains all parsed
// policy OIDs.
// See CreateCertificate for context about how this field and the Policies field
// interact.
PolicyIdentifiers []asn1.ObjectIdentifier
// Policies contains all policy identifiers included in the certificate.
// See CreateCertificate for context about how this field and the PolicyIdentifiers field
// interact.
// In Go 1.22, encoding/gob cannot handle and ignores this field.
Policies []OID
// InhibitAnyPolicy and InhibitAnyPolicyZero indicate the presence and value
// of the inhibitAnyPolicy extension.
//
// The value of InhibitAnyPolicy indicates the number of additional
// certificates in the path after this certificate that may use the
// anyPolicy policy OID to indicate a match with any other policy.
//
// When parsing a certificate, a positive non-zero InhibitAnyPolicy means
// that the field was specified, -1 means it was unset, and
// InhibitAnyPolicyZero being true means that the field was explicitly set to
// zero. The case of InhibitAnyPolicy==0 with InhibitAnyPolicyZero==false
// should be treated as equivalent to -1 (unset).
InhibitAnyPolicy int
// InhibitAnyPolicyZero indicates that InhibitAnyPolicy==0 should be
// interpreted as the field being explicitly set to zero. Otherwise, that
// combination is interpreted as InhibitAnyPolicy not being set.
InhibitAnyPolicyZero bool
// InhibitPolicyMapping and InhibitPolicyMappingZero indicate the presence
// and value of the inhibitPolicyMapping field of the policyConstraints
// extension.
//
// The value of InhibitPolicyMapping indicates the number of additional
// certificates in the path after this certificate that may use policy
// mapping.
//
// When parsing a certificate, a positive non-zero InhibitPolicyMapping
// means that the field was specified, -1 means it was unset, and
// InhibitPolicyMappingZero being true means that the field was explicitly
// set to zero. The case of InhibitPolicyMapping==0 with
// InhibitPolicyMappingZero==false should be treated as equivalent to -1
// (unset).
InhibitPolicyMapping int
// InhibitPolicyMappingZero indicates that InhibitPolicyMapping==0 should be
// interpreted as the field being explicitly set to zero. Otherwise, that
// combination is interpreted as InhibitPolicyMapping not being set.
InhibitPolicyMappingZero bool
// RequireExplicitPolicy and RequireExplicitPolicyZero indicate the presence
// and value of the requireExplicitPolicy field of the policyConstraints
// extension.
//
// The value of RequireExplicitPolicy indicates the number of additional
// certificates in the path after this certificate before an explicit policy
// is required for the rest of the path. When an explicit policy is required,
// each subsequent certificate in the path must contain a required policy OID,
// or a policy OID which has been declared as equivalent through the policy
// mapping extension.
//
// When parsing a certificate, a positive non-zero RequireExplicitPolicy
// means that the field was specified, -1 means it was unset, and
// RequireExplicitPolicyZero being true means that the field was explicitly
// set to zero. The case of RequireExplicitPolicy==0 with
// RequireExplicitPolicyZero==false should be treated as equivalent to -1
// (unset).
RequireExplicitPolicy int
// RequireExplicitPolicyZero indicates that RequireExplicitPolicy==0 should be
// interpreted as the field being explicitly set to zero. Otherwise, that
// combination is interpreted as RequireExplicitPolicy not being set.
RequireExplicitPolicyZero bool
// PolicyMappings contains a list of policy mappings included in the certificate.
PolicyMappings []PolicyMapping
}
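// Illustrative sketch (not part of the original source): a minimal CA template
// exercising the basic-constraints fields documented above. The subject,
// serial number, and validity window are hypothetical placeholder values; a
// real caller passes such a template to CreateCertificate.
func caTemplateSketch() *Certificate {
	return &Certificate{
		SerialNumber:          big.NewInt(1),
		Subject:               pkix.Name{CommonName: "Example Root CA"},
		NotBefore:             time.Now(),
		NotAfter:              time.Now().AddDate(10, 0, 0),
		KeyUsage:              KeyUsageCertSign | KeyUsageCRLSign,
		BasicConstraintsValid: true,
		IsCA:                  true,
		MaxPathLenZero:        true, // pathLenConstraint explicitly set to zero
	}
}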
// PolicyMapping represents a policy mapping entry in the policyMappings extension.
type PolicyMapping struct {
// IssuerDomainPolicy contains a policy OID the issuing certificate considers
// equivalent to SubjectDomainPolicy in the subject certificate.
IssuerDomainPolicy OID
// SubjectDomainPolicy contains an OID the issuing certificate considers
// equivalent to IssuerDomainPolicy in the subject certificate.
SubjectDomainPolicy OID
}
// ErrUnsupportedAlgorithm results from attempting to perform an operation that
// involves algorithms that are not currently implemented.
var ErrUnsupportedAlgorithm = errors.New("x509: cannot verify signature: algorithm unimplemented")
// An InsecureAlgorithmError indicates that the [SignatureAlgorithm] used to
// generate the signature is not secure, and the signature has been rejected.
type InsecureAlgorithmError SignatureAlgorithm
func (e InsecureAlgorithmError) Error() string {
return fmt.Sprintf("x509: cannot verify signature: insecure algorithm %v", SignatureAlgorithm(e))
}
// ConstraintViolationError results when a requested usage is not permitted by
// a certificate. For example: checking a signature when the public key isn't a
// certificate signing key.
type ConstraintViolationError struct{}
func (ConstraintViolationError) Error() string {
return "x509: invalid signature: parent certificate cannot sign this kind of certificate"
}
func (c *Certificate) Equal(other *Certificate) bool {
if c == nil || other == nil {
return c == other
}
return bytes.Equal(c.Raw, other.Raw)
}
func (c *Certificate) hasSANExtension() bool {
return oidInExtensions(oidExtensionSubjectAltName, c.Extensions)
}
// CheckSignatureFrom verifies that the signature on c is a valid signature from parent.
//
// This is a low-level API that performs very limited checks, and not a full
// path verifier. Most users should use [Certificate.Verify] instead.
func (c *Certificate) CheckSignatureFrom(parent *Certificate) error {
// RFC 5280, 4.2.1.9:
// "If the basic constraints extension is not present in a version 3
// certificate, or the extension is present but the cA boolean is not
// asserted, then the certified public key MUST NOT be used to verify
// certificate signatures."
if parent.Version == 3 && !parent.BasicConstraintsValid ||
parent.BasicConstraintsValid && !parent.IsCA {
return ConstraintViolationError{}
}
if parent.KeyUsage != 0 && parent.KeyUsage&KeyUsageCertSign == 0 {
return ConstraintViolationError{}
}
if parent.PublicKeyAlgorithm == UnknownPublicKeyAlgorithm {
return ErrUnsupportedAlgorithm
}
return checkSignature(c.SignatureAlgorithm, c.RawTBSCertificate, c.Signature, parent.PublicKey, false)
}
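// Illustrative sketch (not part of the original source): checking that a
// certificate was signed by a candidate issuer. This only performs the
// signature and basic CA checks described above; use Verify for full path
// validation. Both inputs are hypothetical DER-encoded certificates.
func checkSignatureFromSketch(childDER, issuerDER []byte) error {
	child, err := ParseCertificate(childDER)
	if err != nil {
		return err
	}
	issuer, err := ParseCertificate(issuerDER)
	if err != nil {
		return err
	}
	return child.CheckSignatureFrom(issuer)
}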
// CheckSignature verifies that signature is a valid signature over signed from
// c's public key.
//
// This is a low-level API that performs no validity checks on the certificate.
//
// [MD5WithRSA] signatures are rejected, while [SHA1WithRSA] and [ECDSAWithSHA1]
// signatures are currently accepted.
func (c *Certificate) CheckSignature(algo SignatureAlgorithm, signed, signature []byte) error {
return checkSignature(algo, signed, signature, c.PublicKey, true)
}
func (c *Certificate) hasNameConstraints() bool {
return oidInExtensions(oidExtensionNameConstraints, c.Extensions)
}
func (c *Certificate) getSANExtension() []byte {
for _, e := range c.Extensions {
if e.Id.Equal(oidExtensionSubjectAltName) {
return e.Value
}
}
return nil
}
func signaturePublicKeyAlgoMismatchError(expectedPubKeyAlgo PublicKeyAlgorithm, pubKey any) error {
return fmt.Errorf("x509: signature algorithm specifies an %s public key, but have public key of type %T", expectedPubKeyAlgo.String(), pubKey)
}
// checkSignature verifies that signature is a valid signature over signed from
// a crypto.PublicKey.
func checkSignature(algo SignatureAlgorithm, signed, signature []byte, publicKey crypto.PublicKey, allowSHA1 bool) (err error) {
var hashType crypto.Hash
var pubKeyAlgo PublicKeyAlgorithm
for _, details := range signatureAlgorithmDetails {
if details.algo == algo {
hashType = details.hash
pubKeyAlgo = details.pubKeyAlgo
break
}
}
switch hashType {
case crypto.Hash(0):
if pubKeyAlgo != Ed25519 {
return ErrUnsupportedAlgorithm
}
case crypto.MD5:
return InsecureAlgorithmError(algo)
case crypto.SHA1:
// SHA-1 signatures are only allowed for CRLs and CSRs.
if !allowSHA1 {
return InsecureAlgorithmError(algo)
}
fallthrough
default:
if !hashType.Available() {
return ErrUnsupportedAlgorithm
}
h := hashType.New()
h.Write(signed)
signed = h.Sum(nil)
}
switch pub := publicKey.(type) {
case *rsa.PublicKey:
if pubKeyAlgo != RSA {
return signaturePublicKeyAlgoMismatchError(pubKeyAlgo, pub)
}
if algo.isRSAPSS() {
return rsa.VerifyPSS(pub, hashType, signed, signature, &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash})
} else {
return rsa.VerifyPKCS1v15(pub, hashType, signed, signature)
}
case *ecdsa.PublicKey:
if pubKeyAlgo != ECDSA {
return signaturePublicKeyAlgoMismatchError(pubKeyAlgo, pub)
}
if !ecdsa.VerifyASN1(pub, signed, signature) {
return errors.New("x509: ECDSA verification failure")
}
return
case ed25519.PublicKey:
if pubKeyAlgo != Ed25519 {
return signaturePublicKeyAlgoMismatchError(pubKeyAlgo, pub)
}
if !ed25519.Verify(pub, signed, signature) {
return errors.New("x509: Ed25519 verification failure")
}
return
}
return ErrUnsupportedAlgorithm
}
// CheckCRLSignature checks that the signature in crl is from c.
//
// Deprecated: Use [RevocationList.CheckSignatureFrom] instead.
func (c *Certificate) CheckCRLSignature(crl *pkix.CertificateList) error {
algo := getSignatureAlgorithmFromAI(crl.SignatureAlgorithm)
return c.CheckSignature(algo, crl.TBSCertList.Raw, crl.SignatureValue.RightAlign())
}
type UnhandledCriticalExtension struct{}
func (h UnhandledCriticalExtension) Error() string {
return "x509: unhandled critical extension"
}
type basicConstraints struct {
IsCA bool `asn1:"optional"`
MaxPathLen int `asn1:"optional,default:-1"`
}
// RFC 5280 4.2.1.4
type policyInformation struct {
Policy asn1.ObjectIdentifier
// policyQualifiers omitted
}
const (
nameTypeEmail = 1
nameTypeDNS = 2
nameTypeURI = 6
nameTypeIP = 7
)
// RFC 5280, 4.2.2.1
type authorityInfoAccess struct {
Method asn1.ObjectIdentifier
Location asn1.RawValue
}
// RFC 5280, 4.2.1.14
type distributionPoint struct {
DistributionPoint distributionPointName `asn1:"optional,tag:0"`
Reason asn1.BitString `asn1:"optional,tag:1"`
CRLIssuer asn1.RawValue `asn1:"optional,tag:2"`
}
type distributionPointName struct {
FullName []asn1.RawValue `asn1:"optional,tag:0"`
RelativeName pkix.RDNSequence `asn1:"optional,tag:1"`
}
func reverseBitsInAByte(in byte) byte {
b1 := in>>4 | in<<4
b2 := b1>>2&0x33 | b1<<2&0xcc
b3 := b2>>1&0x55 | b2<<1&0xaa
return b3
}
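// Illustrative sketch (not part of the original source): 0b1000_0000 reverses
// to 0b0000_0001, matching the ASN.1 BIT STRING convention where bit 0 is the
// most significant bit of the first byte.
func reverseBitsSketch() byte {
	return reverseBitsInAByte(0b1000_0000) // 0b0000_0001
}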
// asn1BitLength returns the bit-length of bitString by considering the
// most-significant bit in a byte to be the "first" bit. This convention
// matches ASN.1, but differs from almost everything else.
func asn1BitLength(bitString []byte) int {
bitLen := len(bitString) * 8
for i := range bitString {
b := bitString[len(bitString)-i-1]
for bit := uint(0); bit < 8; bit++ {
if (b>>bit)&1 == 1 {
return bitLen
}
bitLen--
}
}
return 0
}
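// Illustrative sketch (not part of the original source): trailing zero bits
// are not counted, so {0b1010_0000} has an ASN.1 bit length of 3 and
// {0xff, 0x80} has a bit length of 9.
func asn1BitLengthSketch() (int, int) {
	return asn1BitLength([]byte{0b1010_0000}), asn1BitLength([]byte{0xff, 0x80})
}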
var (
oidExtensionSubjectKeyId = []int{2, 5, 29, 14}
oidExtensionKeyUsage = []int{2, 5, 29, 15}
oidExtensionExtendedKeyUsage = []int{2, 5, 29, 37}
oidExtensionAuthorityKeyId = []int{2, 5, 29, 35}
oidExtensionBasicConstraints = []int{2, 5, 29, 19}
oidExtensionSubjectAltName = []int{2, 5, 29, 17}
oidExtensionCertificatePolicies = []int{2, 5, 29, 32}
oidExtensionNameConstraints = []int{2, 5, 29, 30}
oidExtensionCRLDistributionPoints = []int{2, 5, 29, 31}
oidExtensionAuthorityInfoAccess = []int{1, 3, 6, 1, 5, 5, 7, 1, 1}
oidExtensionCRLNumber = []int{2, 5, 29, 20}
oidExtensionReasonCode = []int{2, 5, 29, 21}
)
var (
oidAuthorityInfoAccessOcsp = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 48, 1}
oidAuthorityInfoAccessIssuers = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 48, 2}
)
// oidInExtensions reports whether an extension with the given oid exists in
// extensions.
func oidInExtensions(oid asn1.ObjectIdentifier, extensions []pkix.Extension) bool {
for _, e := range extensions {
if e.Id.Equal(oid) {
return true
}
}
return false
}
// marshalSANs marshals a list of addresses into the contents of an X.509
// SubjectAlternativeName extension.
func marshalSANs(dnsNames, emailAddresses []string, ipAddresses []net.IP, uris []*url.URL) (derBytes []byte, err error) {
var rawValues []asn1.RawValue
for _, name := range dnsNames {
if err := isIA5String(name); err != nil {
return nil, err
}
rawValues = append(rawValues, asn1.RawValue{Tag: nameTypeDNS, Class: 2, Bytes: []byte(name)})
}
for _, email := range emailAddresses {
if err := isIA5String(email); err != nil {
return nil, err
}
rawValues = append(rawValues, asn1.RawValue{Tag: nameTypeEmail, Class: 2, Bytes: []byte(email)})
}
for _, rawIP := range ipAddresses {
// If possible, we always want to encode IPv4 addresses in 4 bytes.
ip := rawIP.To4()
if ip == nil {
ip = rawIP
}
rawValues = append(rawValues, asn1.RawValue{Tag: nameTypeIP, Class: 2, Bytes: ip})
}
for _, uri := range uris {
uriStr := uri.String()
if err := isIA5String(uriStr); err != nil {
return nil, err
}
rawValues = append(rawValues, asn1.RawValue{Tag: nameTypeURI, Class: 2, Bytes: []byte(uriStr)})
}
return asn1.Marshal(rawValues)
}
func isIA5String(s string) error {
for _, r := range s {
// Per RFC5280 "IA5String is limited to the set of ASCII characters"
if r > unicode.MaxASCII {
return fmt.Errorf("x509: %q cannot be encoded as an IA5String", s)
}
}
return nil
}
var x509usepolicies = godebug.New("x509usepolicies")
func buildCertExtensions(template *Certificate, subjectIsEmpty bool, authorityKeyId []byte, subjectKeyId []byte) (ret []pkix.Extension, err error) {
ret = make([]pkix.Extension, 10 /* maximum number of elements. */)
n := 0
if template.KeyUsage != 0 &&
!oidInExtensions(oidExtensionKeyUsage, template.ExtraExtensions) {
ret[n], err = marshalKeyUsage(template.KeyUsage)
if err != nil {
return nil, err
}
n++
}
if (len(template.ExtKeyUsage) > 0 || len(template.UnknownExtKeyUsage) > 0) &&
!oidInExtensions(oidExtensionExtendedKeyUsage, template.ExtraExtensions) {
ret[n], err = marshalExtKeyUsage(template.ExtKeyUsage, template.UnknownExtKeyUsage)
if err != nil {
return nil, err
}
n++
}
if template.BasicConstraintsValid && !oidInExtensions(oidExtensionBasicConstraints, template.ExtraExtensions) {
ret[n], err = marshalBasicConstraints(template.IsCA, template.MaxPathLen, template.MaxPathLenZero)
if err != nil {
return nil, err
}
n++
}
if len(subjectKeyId) > 0 && !oidInExtensions(oidExtensionSubjectKeyId, template.ExtraExtensions) {
ret[n].Id = oidExtensionSubjectKeyId
ret[n].Value, err = asn1.Marshal(subjectKeyId)
if err != nil {
return
}
n++
}
if len(authorityKeyId) > 0 && !oidInExtensions(oidExtensionAuthorityKeyId, template.ExtraExtensions) {
ret[n].Id = oidExtensionAuthorityKeyId
ret[n].Value, err = asn1.Marshal(authKeyId{authorityKeyId})
if err != nil {
return
}
n++
}
if (len(template.OCSPServer) > 0 || len(template.IssuingCertificateURL) > 0) &&
!oidInExtensions(oidExtensionAuthorityInfoAccess, template.ExtraExtensions) {
ret[n].Id = oidExtensionAuthorityInfoAccess
var aiaValues []authorityInfoAccess
for _, name := range template.OCSPServer {
aiaValues = append(aiaValues, authorityInfoAccess{
Method: oidAuthorityInfoAccessOcsp,
Location: asn1.RawValue{Tag: 6, Class: 2, Bytes: []byte(name)},
})
}
for _, name := range template.IssuingCertificateURL {
aiaValues = append(aiaValues, authorityInfoAccess{
Method: oidAuthorityInfoAccessIssuers,
Location: asn1.RawValue{Tag: 6, Class: 2, Bytes: []byte(name)},
})
}
ret[n].Value, err = asn1.Marshal(aiaValues)
if err != nil {
return
}
n++
}
if (len(template.DNSNames) > 0 || len(template.EmailAddresses) > 0 || len(template.IPAddresses) > 0 || len(template.URIs) > 0) &&
!oidInExtensions(oidExtensionSubjectAltName, template.ExtraExtensions) {
ret[n].Id = oidExtensionSubjectAltName
// From RFC 5280, Section 4.2.1.6:
// “If the subject field contains an empty sequence ... then
// subjectAltName extension ... is marked as critical”
ret[n].Critical = subjectIsEmpty
ret[n].Value, err = marshalSANs(template.DNSNames, template.EmailAddresses, template.IPAddresses, template.URIs)
if err != nil {
return
}
n++
}
usePolicies := x509usepolicies.Value() != "0"
if ((!usePolicies && len(template.PolicyIdentifiers) > 0) || (usePolicies && len(template.Policies) > 0)) &&
!oidInExtensions(oidExtensionCertificatePolicies, template.ExtraExtensions) {
ret[n], err = marshalCertificatePolicies(template.Policies, template.PolicyIdentifiers)
if err != nil {
return nil, err
}
n++
}
if (len(template.PermittedDNSDomains) > 0 || len(template.ExcludedDNSDomains) > 0 ||
len(template.PermittedIPRanges) > 0 || len(template.ExcludedIPRanges) > 0 ||
len(template.PermittedEmailAddresses) > 0 || len(template.ExcludedEmailAddresses) > 0 ||
len(template.PermittedURIDomains) > 0 || len(template.ExcludedURIDomains) > 0) &&
!oidInExtensions(oidExtensionNameConstraints, template.ExtraExtensions) {
ret[n].Id = oidExtensionNameConstraints
ret[n].Critical = template.PermittedDNSDomainsCritical
ipAndMask := func(ipNet *net.IPNet) []byte {
maskedIP := ipNet.IP.Mask(ipNet.Mask)
ipAndMask := make([]byte, 0, len(maskedIP)+len(ipNet.Mask))
ipAndMask = append(ipAndMask, maskedIP...)
ipAndMask = append(ipAndMask, ipNet.Mask...)
return ipAndMask
}
serialiseConstraints := func(dns []string, ips []*net.IPNet, emails []string, uriDomains []string) (der []byte, err error) {
var b cryptobyte.Builder
for _, name := range dns {
if err = isIA5String(name); err != nil {
return nil, err
}
b.AddASN1(cryptobyte_asn1.SEQUENCE, func(b *cryptobyte.Builder) {
b.AddASN1(cryptobyte_asn1.Tag(2).ContextSpecific(), func(b *cryptobyte.Builder) {
b.AddBytes([]byte(name))
})
})
}
for _, ipNet := range ips {
b.AddASN1(cryptobyte_asn1.SEQUENCE, func(b *cryptobyte.Builder) {
b.AddASN1(cryptobyte_asn1.Tag(7).ContextSpecific(), func(b *cryptobyte.Builder) {
b.AddBytes(ipAndMask(ipNet))
})
})
}
for _, email := range emails {
if err = isIA5String(email); err != nil {
return nil, err
}
b.AddASN1(cryptobyte_asn1.SEQUENCE, func(b *cryptobyte.Builder) {
b.AddASN1(cryptobyte_asn1.Tag(1).ContextSpecific(), func(b *cryptobyte.Builder) {
b.AddBytes([]byte(email))
})
})
}
for _, uriDomain := range uriDomains {
if err = isIA5String(uriDomain); err != nil {
return nil, err
}
b.AddASN1(cryptobyte_asn1.SEQUENCE, func(b *cryptobyte.Builder) {
b.AddASN1(cryptobyte_asn1.Tag(6).ContextSpecific(), func(b *cryptobyte.Builder) {
b.AddBytes([]byte(uriDomain))
})
})
}
return b.Bytes()
}
permitted, err := serialiseConstraints(template.PermittedDNSDomains, template.PermittedIPRanges, template.PermittedEmailAddresses, template.PermittedURIDomains)
if err != nil {
return nil, err
}
excluded, err := serialiseConstraints(template.ExcludedDNSDomains, template.ExcludedIPRanges, template.ExcludedEmailAddresses, template.ExcludedURIDomains)
if err != nil {
return nil, err
}
var b cryptobyte.Builder
b.AddASN1(cryptobyte_asn1.SEQUENCE, func(b *cryptobyte.Builder) {
if len(permitted) > 0 {
b.AddASN1(cryptobyte_asn1.Tag(0).ContextSpecific().Constructed(), func(b *cryptobyte.Builder) {
b.AddBytes(permitted)
})
}
if len(excluded) > 0 {
b.AddASN1(cryptobyte_asn1.Tag(1).ContextSpecific().Constructed(), func(b *cryptobyte.Builder) {
b.AddBytes(excluded)
})
}
})
ret[n].Value, err = b.Bytes()
if err != nil {
return nil, err
}
n++
}
if len(template.CRLDistributionPoints) > 0 &&
!oidInExtensions(oidExtensionCRLDistributionPoints, template.ExtraExtensions) {
ret[n].Id = oidExtensionCRLDistributionPoints
var crlDp []distributionPoint
for _, name := range template.CRLDistributionPoints {
dp := distributionPoint{
DistributionPoint: distributionPointName{
FullName: []asn1.RawValue{
{Tag: 6, Class: 2, Bytes: []byte(name)},
},
},
}
crlDp = append(crlDp, dp)
}
ret[n].Value, err = asn1.Marshal(crlDp)
if err != nil {
return
}
n++
}
// Adding another extension here? Remember to update the maximum number
// of elements in the make() at the top of the function and the list of
// template fields used in CreateCertificate documentation.
return append(ret[:n], template.ExtraExtensions...), nil
}
func marshalKeyUsage(ku KeyUsage) (pkix.Extension, error) {
ext := pkix.Extension{Id: oidExtensionKeyUsage, Critical: true}
var a [2]byte
a[0] = reverseBitsInAByte(byte(ku))
a[1] = reverseBitsInAByte(byte(ku >> 8))
l := 1
if a[1] != 0 {
l = 2
}
bitString := a[:l]
var err error
ext.Value, err = asn1.Marshal(asn1.BitString{Bytes: bitString, BitLength: asn1BitLength(bitString)})
return ext, err
}
func marshalExtKeyUsage(extUsages []ExtKeyUsage, unknownUsages []asn1.ObjectIdentifier) (pkix.Extension, error) {
ext := pkix.Extension{Id: oidExtensionExtendedKeyUsage}
oids := make([]asn1.ObjectIdentifier, len(extUsages)+len(unknownUsages))
for i, u := range extUsages {
if oid, ok := oidFromExtKeyUsage(u); ok {
oids[i] = oid
} else {
return ext, errors.New("x509: unknown extended key usage")
}
}
copy(oids[len(extUsages):], unknownUsages)
var err error
ext.Value, err = asn1.Marshal(oids)
return ext, err
}
func marshalBasicConstraints(isCA bool, maxPathLen int, maxPathLenZero bool) (pkix.Extension, error) {
ext := pkix.Extension{Id: oidExtensionBasicConstraints, Critical: true}
// Leaving MaxPathLen as zero indicates that no maximum path
// length is desired, unless MaxPathLenZero is set. A value of
// -1 causes encoding/asn1 to omit the value as desired.
if maxPathLen == 0 && !maxPathLenZero {
maxPathLen = -1
}
var err error
ext.Value, err = asn1.Marshal(basicConstraints{isCA, maxPathLen})
return ext, err
}
func marshalCertificatePolicies(policies []OID, policyIdentifiers []asn1.ObjectIdentifier) (pkix.Extension, error) {
ext := pkix.Extension{Id: oidExtensionCertificatePolicies}
b := cryptobyte.NewBuilder(make([]byte, 0, 128))
b.AddASN1(cryptobyte_asn1.SEQUENCE, func(child *cryptobyte.Builder) {
if x509usepolicies.Value() != "0" {
x509usepolicies.IncNonDefault()
for _, v := range policies {
child.AddASN1(cryptobyte_asn1.SEQUENCE, func(child *cryptobyte.Builder) {
child.AddASN1(cryptobyte_asn1.OBJECT_IDENTIFIER, func(child *cryptobyte.Builder) {
if len(v.der) == 0 {
child.SetError(errors.New("invalid policy object identifier"))
return
}
child.AddBytes(v.der)
})
})
}
} else {
for _, v := range policyIdentifiers {
child.AddASN1(cryptobyte_asn1.SEQUENCE, func(child *cryptobyte.Builder) {
child.AddASN1ObjectIdentifier(v)
})
}
}
})
var err error
ext.Value, err = b.Bytes()
return ext, err
}
func buildCSRExtensions(template *CertificateRequest) ([]pkix.Extension, error) {
var ret []pkix.Extension
if (len(template.DNSNames) > 0 || len(template.EmailAddresses) > 0 || len(template.IPAddresses) > 0 || len(template.URIs) > 0) &&
!oidInExtensions(oidExtensionSubjectAltName, template.ExtraExtensions) {
sanBytes, err := marshalSANs(template.DNSNames, template.EmailAddresses, template.IPAddresses, template.URIs)
if err != nil {
return nil, err
}
ret = append(ret, pkix.Extension{
Id: oidExtensionSubjectAltName,
Value: sanBytes,
})
}
return append(ret, template.ExtraExtensions...), nil
}
func subjectBytes(cert *Certificate) ([]byte, error) {
if len(cert.RawSubject) > 0 {
return cert.RawSubject, nil
}
return asn1.Marshal(cert.Subject.ToRDNSequence())
}
// signingParamsForKey returns the signature algorithm and its Algorithm
// Identifier to use for signing, based on the key type. If sigAlgo is not zero
// then it overrides the default.
func signingParamsForKey(key crypto.Signer, sigAlgo SignatureAlgorithm) (SignatureAlgorithm, pkix.AlgorithmIdentifier, error) {
var ai pkix.AlgorithmIdentifier
var pubType PublicKeyAlgorithm
var defaultAlgo SignatureAlgorithm
switch pub := key.Public().(type) {
case *rsa.PublicKey:
pubType = RSA
defaultAlgo = SHA256WithRSA
case *ecdsa.PublicKey:
pubType = ECDSA
switch pub.Curve {
case elliptic.P224(), elliptic.P256():
defaultAlgo = ECDSAWithSHA256
case elliptic.P384():
defaultAlgo = ECDSAWithSHA384
case elliptic.P521():
defaultAlgo = ECDSAWithSHA512
default:
return 0, ai, errors.New("x509: unsupported elliptic curve")
}
case ed25519.PublicKey:
pubType = Ed25519
defaultAlgo = PureEd25519
default:
return 0, ai, errors.New("x509: only RSA, ECDSA and Ed25519 keys supported")
}
if sigAlgo == 0 {
sigAlgo = defaultAlgo
}
for _, details := range signatureAlgorithmDetails {
if details.algo == sigAlgo {
if details.pubKeyAlgo != pubType {
return 0, ai, errors.New("x509: requested SignatureAlgorithm does not match private key type")
}
if details.hash == crypto.MD5 {
return 0, ai, errors.New("x509: signing with MD5 is not supported")
}
return sigAlgo, pkix.AlgorithmIdentifier{
Algorithm: details.oid,
Parameters: details.params,
}, nil
}
}
return 0, ai, errors.New("x509: unknown SignatureAlgorithm")
}
func signTBS(tbs []byte, key crypto.Signer, sigAlg SignatureAlgorithm, rand io.Reader) ([]byte, error) {
hashFunc := sigAlg.hashFunc()
var signerOpts crypto.SignerOpts = hashFunc
if sigAlg.isRSAPSS() {
signerOpts = &rsa.PSSOptions{
SaltLength: rsa.PSSSaltLengthEqualsHash,
Hash: hashFunc,
}
}
signature, err := crypto.SignMessage(key, rand, tbs, signerOpts)
if err != nil {
return nil, err
}
// Check the signature to ensure the crypto.Signer behaved correctly.
if err := checkSignature(sigAlg, tbs, signature, key.Public(), true); err != nil {
return nil, fmt.Errorf("x509: signature returned by signer is invalid: %w", err)
}
return signature, nil
}
// emptyASN1Subject is the ASN.1 DER encoding of an empty Subject, which is
// just an empty SEQUENCE.
var emptyASN1Subject = []byte{0x30, 0}
// CreateCertificate creates a new X.509 v3 certificate based on a template.
// The following members of template are currently used:
//
// - AuthorityKeyId
// - BasicConstraintsValid
// - CRLDistributionPoints
// - DNSNames
// - EmailAddresses
// - ExcludedDNSDomains
// - ExcludedEmailAddresses
// - ExcludedIPRanges
// - ExcludedURIDomains
// - ExtKeyUsage
// - ExtraExtensions
// - IPAddresses
// - IsCA
// - IssuingCertificateURL
// - KeyUsage
// - MaxPathLen
// - MaxPathLenZero
// - NotAfter
// - NotBefore
// - OCSPServer
// - PermittedDNSDomains
// - PermittedDNSDomainsCritical
// - PermittedEmailAddresses
// - PermittedIPRanges
// - PermittedURIDomains
// - PolicyIdentifiers (see note below)
// - Policies (see note below)
// - SerialNumber
// - SignatureAlgorithm
// - Subject
// - SubjectKeyId
// - URIs
// - UnknownExtKeyUsage
//
// The certificate is signed by parent. If parent is equal to template then the
// certificate is self-signed. The parameter pub is the public key of the
// certificate to be generated and priv is the private key of the signer.
//
// The returned slice is the certificate in DER encoding.
//
// The currently supported key types are *rsa.PublicKey, *ecdsa.PublicKey and
// ed25519.PublicKey. pub must be a supported key type, and priv must be a
// crypto.Signer or crypto.MessageSigner with a supported public key.
//
// The AuthorityKeyId will be taken from the SubjectKeyId of parent, if any,
// unless the resulting certificate is self-signed. Otherwise the value from
// template will be used.
//
// If SubjectKeyId from template is empty and the template is a CA, SubjectKeyId
// will be generated from the hash of the public key.
//
// If template.SerialNumber is nil, a serial number will be generated which
// conforms to RFC 5280, Section 4.1.2.2 using entropy from rand.
//
// The PolicyIdentifiers and Policies fields can both be used to marshal certificate
// policy OIDs. By default, only the Policies field is marshaled, but if the
// GODEBUG setting "x509usepolicies" has the value "0", the PolicyIdentifiers field will
// be marshaled instead of the Policies field. This changed in Go 1.24. The Policies field can
// be used to marshal policy OIDs which have components that are larger than 31
// bits.
func CreateCertificate(rand io.Reader, template, parent *Certificate, pub, priv any) ([]byte, error) {
key, ok := priv.(crypto.Signer)
if !ok {
return nil, errors.New("x509: certificate private key does not implement crypto.Signer")
}
serialNumber := template.SerialNumber
if serialNumber == nil {
// Generate a serial number following RFC 5280, Section 4.1.2.2 if one
// is not provided. The serial number must be positive and at most 20
// octets *when encoded*.
serialBytes := make([]byte, 20)
if _, err := io.ReadFull(rand, serialBytes); err != nil {
return nil, err
}
// If the top bit is set, the serial will be padded with a leading zero
// byte during encoding, so that it's not interpreted as a negative
// integer. This padding would make the serial 21 octets so we clear the
// top bit to ensure the correct length in all cases.
serialBytes[0] &= 0b0111_1111
serialNumber = new(big.Int).SetBytes(serialBytes)
}
// RFC 5280 Section 4.1.2.2: serial number must be positive
//
// We _should_ also restrict serials to <= 20 octets, but it turns out a lot of people
// get this wrong, in part because the encoding can itself alter the length of the
// serial. For now we accept these non-conformant serials.
if serialNumber.Sign() == -1 {
return nil, errors.New("x509: serial number must be positive")
}
if template.BasicConstraintsValid && template.MaxPathLen < -1 {
return nil, errors.New("x509: invalid MaxPathLen, must be greater or equal to -1")
}
if template.BasicConstraintsValid && !template.IsCA && template.MaxPathLen != -1 && (template.MaxPathLen != 0 || template.MaxPathLenZero) {
return nil, errors.New("x509: only CAs are allowed to specify MaxPathLen")
}
signatureAlgorithm, algorithmIdentifier, err := signingParamsForKey(key, template.SignatureAlgorithm)
if err != nil {
return nil, err
}
publicKeyBytes, publicKeyAlgorithm, err := marshalPublicKey(pub)
if err != nil {
return nil, err
}
if getPublicKeyAlgorithmFromOID(publicKeyAlgorithm.Algorithm) == UnknownPublicKeyAlgorithm {
return nil, fmt.Errorf("x509: unsupported public key type: %T", pub)
}
asn1Issuer, err := subjectBytes(parent)
if err != nil {
return nil, err
}
asn1Subject, err := subjectBytes(template)
if err != nil {
return nil, err
}
authorityKeyId := template.AuthorityKeyId
if !bytes.Equal(asn1Issuer, asn1Subject) && len(parent.SubjectKeyId) > 0 {
authorityKeyId = parent.SubjectKeyId
}
subjectKeyId := template.SubjectKeyId
if len(subjectKeyId) == 0 && template.IsCA {
if x509sha256skid.Value() == "0" {
x509sha256skid.IncNonDefault()
// SubjectKeyId generated using method 1 in RFC 5280, Section 4.2.1.2:
// (1) The keyIdentifier is composed of the 160-bit SHA-1 hash of the
// value of the BIT STRING subjectPublicKey (excluding the tag,
// length, and number of unused bits).
h := sha1.Sum(publicKeyBytes)
subjectKeyId = h[:]
} else {
// SubjectKeyId generated using method 1 in RFC 7093, Section 2:
// 1) The keyIdentifier is composed of the leftmost 160-bits of the
// SHA-256 hash of the value of the BIT STRING subjectPublicKey
// (excluding the tag, length, and number of unused bits).
h := sha256.Sum256(publicKeyBytes)
subjectKeyId = h[:20]
}
}
// Check that the signer's public key matches the parent certificate's public key, if available.
type privateKey interface {
Equal(crypto.PublicKey) bool
}
if privPub, ok := key.Public().(privateKey); !ok {
return nil, errors.New("x509: internal error: supported public key does not implement Equal")
} else if parent.PublicKey != nil && !privPub.Equal(parent.PublicKey) {
return nil, errors.New("x509: provided PrivateKey doesn't match parent's PublicKey")
}
extensions, err := buildCertExtensions(template, bytes.Equal(asn1Subject, emptyASN1Subject), authorityKeyId, subjectKeyId)
if err != nil {
return nil, err
}
encodedPublicKey := asn1.BitString{BitLength: len(publicKeyBytes) * 8, Bytes: publicKeyBytes}
c := tbsCertificate{
Version: 2,
SerialNumber: serialNumber,
SignatureAlgorithm: algorithmIdentifier,
Issuer: asn1.RawValue{FullBytes: asn1Issuer},
Validity: validity{template.NotBefore.UTC(), template.NotAfter.UTC()},
Subject: asn1.RawValue{FullBytes: asn1Subject},
PublicKey: publicKeyInfo{nil, publicKeyAlgorithm, encodedPublicKey},
Extensions: extensions,
}
tbsCertContents, err := asn1.Marshal(c)
if err != nil {
return nil, err
}
c.Raw = tbsCertContents
signature, err := signTBS(tbsCertContents, key, signatureAlgorithm, rand)
if err != nil {
return nil, err
}
return asn1.Marshal(certificate{
TBSCertificate: c,
SignatureAlgorithm: algorithmIdentifier,
SignatureValue: asn1.BitString{Bytes: signature, BitLength: len(signature) * 8},
})
}
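// Illustrative sketch (editor's addition): a minimal self-signed CA certificate
// created with CreateCertificate. It assumes the usual imports (crypto/ecdsa,
// crypto/elliptic, crypto/rand, crypto/x509/pkix, time) and elides error handling.
//
//	priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
//	tmpl := &x509.Certificate{
//		// SerialNumber is nil, so a random RFC 5280-conformant serial is generated.
//		Subject:               pkix.Name{CommonName: "example.com"},
//		NotBefore:             time.Now(),
//		NotAfter:              time.Now().Add(24 * time.Hour),
//		KeyUsage:              x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
//		BasicConstraintsValid: true,
//		IsCA:                  true,
//		DNSNames:              []string{"example.com"},
//	}
//	// parent == template, so the certificate is self-signed; SubjectKeyId is
//	// derived from the public key because the template is a CA and leaves it empty.
//	der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &priv.PublicKey, priv)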
var x509sha256skid = godebug.New("x509sha256skid")
// pemCRLPrefix is the magic string that indicates that we have a PEM encoded
// CRL.
var pemCRLPrefix = []byte("-----BEGIN X509 CRL")
// pemType is the type of a PEM encoded CRL.
var pemType = "X509 CRL"
// ParseCRL parses a CRL from the given bytes. It's often the case that PEM
// encoded CRLs will appear where they should be DER encoded, so this function
// will transparently handle PEM encoding as long as there isn't any leading
// garbage.
//
// Deprecated: Use [ParseRevocationList] instead.
func ParseCRL(crlBytes []byte) (*pkix.CertificateList, error) {
if bytes.HasPrefix(crlBytes, pemCRLPrefix) {
block, _ := pem.Decode(crlBytes)
if block != nil && block.Type == pemType {
crlBytes = block.Bytes
}
}
return ParseDERCRL(crlBytes)
}
// ParseDERCRL parses a DER encoded CRL from the given bytes.
//
// Deprecated: Use [ParseRevocationList] instead.
func ParseDERCRL(derBytes []byte) (*pkix.CertificateList, error) {
certList := new(pkix.CertificateList)
if rest, err := asn1.Unmarshal(derBytes, certList); err != nil {
return nil, err
} else if len(rest) != 0 {
return nil, errors.New("x509: trailing data after CRL")
}
return certList, nil
}
// CreateCRL returns a DER encoded CRL, signed by this Certificate, that
// contains the given list of revoked certificates.
//
// Deprecated: this method does not generate an RFC 5280 conformant X.509 v2 CRL.
// To generate a standards compliant CRL, use [CreateRevocationList] instead.
func (c *Certificate) CreateCRL(rand io.Reader, priv any, revokedCerts []pkix.RevokedCertificate, now, expiry time.Time) (crlBytes []byte, err error) {
key, ok := priv.(crypto.Signer)
if !ok {
return nil, errors.New("x509: certificate private key does not implement crypto.Signer")
}
signatureAlgorithm, algorithmIdentifier, err := signingParamsForKey(key, 0)
if err != nil {
return nil, err
}
// Force revocation times to UTC per RFC 5280.
revokedCertsUTC := make([]pkix.RevokedCertificate, len(revokedCerts))
for i, rc := range revokedCerts {
rc.RevocationTime = rc.RevocationTime.UTC()
revokedCertsUTC[i] = rc
}
tbsCertList := pkix.TBSCertificateList{
Version: 1,
Signature: algorithmIdentifier,
Issuer: c.Subject.ToRDNSequence(),
ThisUpdate: now.UTC(),
NextUpdate: expiry.UTC(),
RevokedCertificates: revokedCertsUTC,
}
// Authority Key Id
if len(c.SubjectKeyId) > 0 {
var aki pkix.Extension
aki.Id = oidExtensionAuthorityKeyId
aki.Value, err = asn1.Marshal(authKeyId{Id: c.SubjectKeyId})
if err != nil {
return nil, err
}
tbsCertList.Extensions = append(tbsCertList.Extensions, aki)
}
tbsCertListContents, err := asn1.Marshal(tbsCertList)
if err != nil {
return nil, err
}
tbsCertList.Raw = tbsCertListContents
signature, err := signTBS(tbsCertListContents, key, signatureAlgorithm, rand)
if err != nil {
return nil, err
}
return asn1.Marshal(pkix.CertificateList{
TBSCertList: tbsCertList,
SignatureAlgorithm: algorithmIdentifier,
SignatureValue: asn1.BitString{Bytes: signature, BitLength: len(signature) * 8},
})
}
// CertificateRequest represents a PKCS #10 certificate signature request.
type CertificateRequest struct {
Raw []byte // Complete ASN.1 DER content (CSR, signature algorithm and signature).
RawTBSCertificateRequest []byte // Certificate request info part of raw ASN.1 DER content.
RawSubjectPublicKeyInfo []byte // DER encoded SubjectPublicKeyInfo.
RawSubject []byte // DER encoded Subject.
Version int
Signature []byte
SignatureAlgorithm SignatureAlgorithm
PublicKeyAlgorithm PublicKeyAlgorithm
PublicKey any
Subject pkix.Name
// Attributes contains the CSR attributes that can be parsed as
// pkix.AttributeTypeAndValueSET.
//
// Deprecated: Use Extensions and ExtraExtensions instead for parsing and
// generating the requestedExtensions attribute.
Attributes []pkix.AttributeTypeAndValueSET
// Extensions contains all requested extensions, in raw form. When parsing
// CSRs, this can be used to extract extensions that are not parsed by this
// package.
Extensions []pkix.Extension
// ExtraExtensions contains extensions to be copied, raw, into any CSR
// marshaled by CreateCertificateRequest. Values override any extensions
// that would otherwise be produced based on the other fields but are
// overridden by any extensions specified in Attributes.
//
// The ExtraExtensions field is not populated by ParseCertificateRequest,
// see Extensions instead.
ExtraExtensions []pkix.Extension
// Subject Alternative Name values.
DNSNames []string
EmailAddresses []string
IPAddresses []net.IP
URIs []*url.URL
}
// These structures reflect the ASN.1 structure of X.509 certificate
// signature requests (see RFC 2986):
type tbsCertificateRequest struct {
Raw asn1.RawContent
Version int
Subject asn1.RawValue
PublicKey publicKeyInfo
RawAttributes []asn1.RawValue `asn1:"tag:0"`
}
type certificateRequest struct {
Raw asn1.RawContent
TBSCSR tbsCertificateRequest
SignatureAlgorithm pkix.AlgorithmIdentifier
SignatureValue asn1.BitString
}
// oidExtensionRequest is a PKCS #9 OBJECT IDENTIFIER that indicates requested
// extensions in a CSR.
var oidExtensionRequest = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 9, 14}
// newRawAttributes converts AttributeTypeAndValueSETs from a template
// CertificateRequest's Attributes into tbsCertificateRequest RawAttributes.
func newRawAttributes(attributes []pkix.AttributeTypeAndValueSET) ([]asn1.RawValue, error) {
var rawAttributes []asn1.RawValue
b, err := asn1.Marshal(attributes)
if err != nil {
return nil, err
}
rest, err := asn1.Unmarshal(b, &rawAttributes)
if err != nil {
return nil, err
}
if len(rest) != 0 {
return nil, errors.New("x509: failed to unmarshal raw CSR Attributes")
}
return rawAttributes, nil
}
// parseRawAttributes unmarshals RawAttributes into AttributeTypeAndValueSETs.
func parseRawAttributes(rawAttributes []asn1.RawValue) []pkix.AttributeTypeAndValueSET {
var attributes []pkix.AttributeTypeAndValueSET
for _, rawAttr := range rawAttributes {
var attr pkix.AttributeTypeAndValueSET
rest, err := asn1.Unmarshal(rawAttr.FullBytes, &attr)
// Ignore attributes that don't parse into pkix.AttributeTypeAndValueSET
// (i.e.: challengePassword or unstructuredName).
if err == nil && len(rest) == 0 {
attributes = append(attributes, attr)
}
}
return attributes
}
// parseCSRExtensions parses the attributes from a CSR and extracts any
// requested extensions.
func parseCSRExtensions(rawAttributes []asn1.RawValue) ([]pkix.Extension, error) {
// pkcs10Attribute reflects the Attribute structure from RFC 2986, Section 4.1.
type pkcs10Attribute struct {
Id asn1.ObjectIdentifier
Values []asn1.RawValue `asn1:"set"`
}
var ret []pkix.Extension
requestedExts := make(map[string]bool)
for _, rawAttr := range rawAttributes {
var attr pkcs10Attribute
if rest, err := asn1.Unmarshal(rawAttr.FullBytes, &attr); err != nil || len(rest) != 0 || len(attr.Values) == 0 {
// Ignore attributes that don't parse.
continue
}
if !attr.Id.Equal(oidExtensionRequest) {
continue
}
var extensions []pkix.Extension
if _, err := asn1.Unmarshal(attr.Values[0].FullBytes, &extensions); err != nil {
return nil, err
}
for _, ext := range extensions {
oidStr := ext.Id.String()
if requestedExts[oidStr] {
return nil, errors.New("x509: certificate request contains duplicate requested extensions")
}
requestedExts[oidStr] = true
}
ret = append(ret, extensions...)
}
return ret, nil
}
// CreateCertificateRequest creates a new certificate request based on a
// template. The following members of template are used:
//
// - SignatureAlgorithm
// - Subject
// - DNSNames
// - EmailAddresses
// - IPAddresses
// - URIs
// - ExtraExtensions
// - Attributes (deprecated)
//
// priv is the private key to sign the CSR with, and the corresponding public
// key will be included in the CSR. It must implement crypto.Signer or
// crypto.MessageSigner and its Public() method must return a *rsa.PublicKey,
// a *ecdsa.PublicKey, or an ed25519.PublicKey. (A *rsa.PrivateKey,
// *ecdsa.PrivateKey or ed25519.PrivateKey satisfies this.)
//
// The returned slice is the certificate request in DER encoding.
func CreateCertificateRequest(rand io.Reader, template *CertificateRequest, priv any) (csr []byte, err error) {
key, ok := priv.(crypto.Signer)
if !ok {
return nil, errors.New("x509: certificate private key does not implement crypto.Signer")
}
signatureAlgorithm, algorithmIdentifier, err := signingParamsForKey(key, template.SignatureAlgorithm)
if err != nil {
return nil, err
}
var publicKeyBytes []byte
var publicKeyAlgorithm pkix.AlgorithmIdentifier
publicKeyBytes, publicKeyAlgorithm, err = marshalPublicKey(key.Public())
if err != nil {
return nil, err
}
extensions, err := buildCSRExtensions(template)
if err != nil {
return nil, err
}
// Make a copy of template.Attributes because we may alter it below.
attributes := make([]pkix.AttributeTypeAndValueSET, 0, len(template.Attributes))
for _, attr := range template.Attributes {
values := make([][]pkix.AttributeTypeAndValue, len(attr.Value))
copy(values, attr.Value)
attributes = append(attributes, pkix.AttributeTypeAndValueSET{
Type: attr.Type,
Value: values,
})
}
extensionsAppended := false
if len(extensions) > 0 {
// Append the extensions to an existing attribute if possible.
for _, atvSet := range attributes {
if !atvSet.Type.Equal(oidExtensionRequest) || len(atvSet.Value) == 0 {
continue
}
// specifiedExtensions contains all the extensions that we
// found specified via template.Attributes.
specifiedExtensions := make(map[string]bool)
for _, atvs := range atvSet.Value {
for _, atv := range atvs {
specifiedExtensions[atv.Type.String()] = true
}
}
newValue := make([]pkix.AttributeTypeAndValue, 0, len(atvSet.Value[0])+len(extensions))
newValue = append(newValue, atvSet.Value[0]...)
for _, e := range extensions {
if specifiedExtensions[e.Id.String()] {
// Attributes already contained a value for
// this extension and it takes priority.
continue
}
newValue = append(newValue, pkix.AttributeTypeAndValue{
// There is no place for the critical
// flag in an AttributeTypeAndValue.
Type: e.Id,
Value: e.Value,
})
}
atvSet.Value[0] = newValue
extensionsAppended = true
break
}
}
rawAttributes, err := newRawAttributes(attributes)
if err != nil {
return nil, err
}
// If not included in attributes, add a new attribute for the
// extensions.
if len(extensions) > 0 && !extensionsAppended {
attr := struct {
Type asn1.ObjectIdentifier
Value [][]pkix.Extension `asn1:"set"`
}{
Type: oidExtensionRequest,
Value: [][]pkix.Extension{extensions},
}
b, err := asn1.Marshal(attr)
if err != nil {
return nil, errors.New("x509: failed to serialise extensions attribute: " + err.Error())
}
var rawValue asn1.RawValue
if _, err := asn1.Unmarshal(b, &rawValue); err != nil {
return nil, err
}
rawAttributes = append(rawAttributes, rawValue)
}
asn1Subject := template.RawSubject
if len(asn1Subject) == 0 {
asn1Subject, err = asn1.Marshal(template.Subject.ToRDNSequence())
if err != nil {
return nil, err
}
}
tbsCSR := tbsCertificateRequest{
Version: 0, // PKCS #10, RFC 2986
Subject: asn1.RawValue{FullBytes: asn1Subject},
PublicKey: publicKeyInfo{
Algorithm: publicKeyAlgorithm,
PublicKey: asn1.BitString{
Bytes: publicKeyBytes,
BitLength: len(publicKeyBytes) * 8,
},
},
RawAttributes: rawAttributes,
}
tbsCSRContents, err := asn1.Marshal(tbsCSR)
if err != nil {
return nil, err
}
tbsCSR.Raw = tbsCSRContents
signature, err := signTBS(tbsCSRContents, key, signatureAlgorithm, rand)
if err != nil {
return nil, err
}
return asn1.Marshal(certificateRequest{
TBSCSR: tbsCSR,
SignatureAlgorithm: algorithmIdentifier,
SignatureValue: asn1.BitString{Bytes: signature, BitLength: len(signature) * 8},
})
}
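// Illustrative sketch (editor's addition): creating, parsing, and verifying a
// CSR with an Ed25519 key. The names and the elided error handling are
// assumptions chosen for illustration.
//
//	_, priv, err := ed25519.GenerateKey(rand.Reader)
//	tmpl := &x509.CertificateRequest{
//		Subject:  pkix.Name{CommonName: "example.com"},
//		DNSNames: []string{"example.com", "www.example.com"},
//	}
//	csrDER, err := x509.CreateCertificateRequest(rand.Reader, tmpl, priv)
//	csr, err := x509.ParseCertificateRequest(csrDER)
//	err = csr.CheckSignature() // verifies the self-signature over the CSR body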
// ParseCertificateRequest parses a single certificate request from the
// given ASN.1 DER data.
func ParseCertificateRequest(asn1Data []byte) (*CertificateRequest, error) {
var csr certificateRequest
rest, err := asn1.Unmarshal(asn1Data, &csr)
if err != nil {
return nil, err
} else if len(rest) != 0 {
return nil, asn1.SyntaxError{Msg: "trailing data"}
}
return parseCertificateRequest(&csr)
}
func parseCertificateRequest(in *certificateRequest) (*CertificateRequest, error) {
out := &CertificateRequest{
Raw: in.Raw,
RawTBSCertificateRequest: in.TBSCSR.Raw,
RawSubjectPublicKeyInfo: in.TBSCSR.PublicKey.Raw,
RawSubject: in.TBSCSR.Subject.FullBytes,
Signature: in.SignatureValue.RightAlign(),
SignatureAlgorithm: getSignatureAlgorithmFromAI(in.SignatureAlgorithm),
PublicKeyAlgorithm: getPublicKeyAlgorithmFromOID(in.TBSCSR.PublicKey.Algorithm.Algorithm),
Version: in.TBSCSR.Version,
Attributes: parseRawAttributes(in.TBSCSR.RawAttributes),
}
var err error
if out.PublicKeyAlgorithm != UnknownPublicKeyAlgorithm {
out.PublicKey, err = parsePublicKey(&in.TBSCSR.PublicKey)
if err != nil {
return nil, err
}
}
var subject pkix.RDNSequence
if rest, err := asn1.Unmarshal(in.TBSCSR.Subject.FullBytes, &subject); err != nil {
return nil, err
} else if len(rest) != 0 {
return nil, errors.New("x509: trailing data after X.509 Subject")
}
out.Subject.FillFromRDNSequence(&subject)
if out.Extensions, err = parseCSRExtensions(in.TBSCSR.RawAttributes); err != nil {
return nil, err
}
for _, extension := range out.Extensions {
switch {
case extension.Id.Equal(oidExtensionSubjectAltName):
out.DNSNames, out.EmailAddresses, out.IPAddresses, out.URIs, err = parseSANExtension(extension.Value)
if err != nil {
return nil, err
}
}
}
return out, nil
}
// CheckSignature reports whether the signature on c is valid.
func (c *CertificateRequest) CheckSignature() error {
return checkSignature(c.SignatureAlgorithm, c.RawTBSCertificateRequest, c.Signature, c.PublicKey, true)
}
// RevocationListEntry represents an entry in the revokedCertificates
// sequence of a CRL.
type RevocationListEntry struct {
// Raw contains the raw bytes of the revokedCertificates entry. It is set when
// parsing a CRL; it is ignored when generating a CRL.
Raw []byte
// SerialNumber represents the serial number of a revoked certificate. It is
// both used when creating a CRL and populated when parsing a CRL. It must not
// be nil.
SerialNumber *big.Int
// RevocationTime represents the time at which the certificate was revoked. It
// is both used when creating a CRL and populated when parsing a CRL. It must
// not be the zero time.
RevocationTime time.Time
// ReasonCode represents the reason for revocation, using the integer enum
// values specified in RFC 5280 Section 5.3.1. When creating a CRL, the zero
// value will result in the reasonCode extension being omitted. When parsing a
// CRL, the zero value may represent either the reasonCode extension being
// absent (which implies the default revocation reason of 0/Unspecified), or
// it may represent the reasonCode extension being present and explicitly
// containing a value of 0/Unspecified (which should not happen according to
// the DER encoding rules, but can and does happen anyway).
ReasonCode int
// Extensions contains raw X.509 extensions. When parsing CRL entries,
// this can be used to extract non-critical extensions that are not
// parsed by this package. When marshaling CRL entries, the Extensions
// field is ignored, see ExtraExtensions.
Extensions []pkix.Extension
// ExtraExtensions contains extensions to be copied, raw, into any
// marshaled CRL entries. Values override any extensions that would
// otherwise be produced based on the other fields. The ExtraExtensions
// field is not populated when parsing CRL entries, see Extensions.
ExtraExtensions []pkix.Extension
}
// RevocationList represents a [Certificate] Revocation List (CRL) as specified
// by RFC 5280.
type RevocationList struct {
// Raw contains the complete ASN.1 DER content of the CRL (tbsCertList,
// signatureAlgorithm, and signatureValue.)
Raw []byte
// RawTBSRevocationList contains just the tbsCertList portion of the ASN.1
// DER.
RawTBSRevocationList []byte
// RawIssuer contains the DER encoded Issuer.
RawIssuer []byte
// Issuer contains the DN of the issuing certificate.
Issuer pkix.Name
// AuthorityKeyId is used to identify the public key associated with the
// issuing certificate. It is populated from the authorityKeyIdentifier
// extension when parsing a CRL. It is ignored when creating a CRL; the
// extension is populated from the issuing certificate itself.
AuthorityKeyId []byte
Signature []byte
// SignatureAlgorithm is used to determine the signature algorithm to be
// used when signing the CRL. If 0 the default algorithm for the signing
// key will be used.
SignatureAlgorithm SignatureAlgorithm
// RevokedCertificateEntries represents the revokedCertificates sequence in
// the CRL. It is used when creating a CRL and also populated when parsing a
// CRL. When creating a CRL, it may be empty or nil, in which case the
// revokedCertificates ASN.1 sequence will be omitted from the CRL entirely.
RevokedCertificateEntries []RevocationListEntry
// RevokedCertificates is used to populate the revokedCertificates
// sequence in the CRL if RevokedCertificateEntries is empty. It may be empty
// or nil, in which case an empty CRL will be created.
//
// Deprecated: Use RevokedCertificateEntries instead.
RevokedCertificates []pkix.RevokedCertificate
// Number is used to populate the X.509 v2 cRLNumber extension in the CRL,
// which should be a monotonically increasing sequence number for a given
// CRL scope and CRL issuer. It is also populated from the cRLNumber
// extension when parsing a CRL.
Number *big.Int
// ThisUpdate is used to populate the thisUpdate field in the CRL, which
// indicates the issuance date of the CRL.
ThisUpdate time.Time
// NextUpdate is used to populate the nextUpdate field in the CRL, which
// indicates the date by which the next CRL will be issued. NextUpdate
// must be greater than ThisUpdate.
NextUpdate time.Time
// Extensions contains raw X.509 extensions. When creating a CRL,
// the Extensions field is ignored, see ExtraExtensions.
Extensions []pkix.Extension
// ExtraExtensions contains any additional extensions to add directly to
// the CRL.
ExtraExtensions []pkix.Extension
}
// These structures reflect the ASN.1 structure of X.509 CRLs better than
// the existing crypto/x509/pkix variants do. These mirror the existing
// certificate structs in this file.
//
// Notably, we include issuer as an asn1.RawValue, mirroring the behavior of
// tbsCertificate and allowing raw (unparsed) subjects to be passed cleanly.
type certificateList struct {
TBSCertList tbsCertificateList
SignatureAlgorithm pkix.AlgorithmIdentifier
SignatureValue asn1.BitString
}
type tbsCertificateList struct {
Raw asn1.RawContent
Version int `asn1:"optional,default:0"`
Signature pkix.AlgorithmIdentifier
Issuer asn1.RawValue
ThisUpdate time.Time
NextUpdate time.Time `asn1:"optional"`
RevokedCertificates []pkix.RevokedCertificate `asn1:"optional"`
Extensions []pkix.Extension `asn1:"tag:0,optional,explicit"`
}
// CreateRevocationList creates a new X.509 v2 [Certificate] Revocation List,
// according to RFC 5280, based on template.
//
// The CRL is signed by priv which should be a crypto.Signer or
// crypto.MessageSigner associated with the public key in the issuer
// certificate.
//
// The issuer may not be nil, and the crlSign bit must be set in [KeyUsage] in
// order to use it as a CRL issuer.
//
// The issuer distinguished name CRL field and authority key identifier
// extension are populated using the issuer certificate. issuer must have
// SubjectKeyId set.
func CreateRevocationList(rand io.Reader, template *RevocationList, issuer *Certificate, priv crypto.Signer) ([]byte, error) {
if template == nil {
return nil, errors.New("x509: template can not be nil")
}
if issuer == nil {
return nil, errors.New("x509: issuer can not be nil")
}
if (issuer.KeyUsage & KeyUsageCRLSign) == 0 {
return nil, errors.New("x509: issuer must have the crlSign key usage bit set")
}
if len(issuer.SubjectKeyId) == 0 {
return nil, errors.New("x509: issuer certificate doesn't contain a subject key identifier")
}
if template.NextUpdate.Before(template.ThisUpdate) {
return nil, errors.New("x509: template.ThisUpdate is after template.NextUpdate")
}
if template.Number == nil {
return nil, errors.New("x509: template contains nil Number field")
}
signatureAlgorithm, algorithmIdentifier, err := signingParamsForKey(priv, template.SignatureAlgorithm)
if err != nil {
return nil, err
}
var revokedCerts []pkix.RevokedCertificate
// Only process the deprecated RevokedCertificates field if it is populated
// and the new RevokedCertificateEntries field is not populated.
if len(template.RevokedCertificates) > 0 && len(template.RevokedCertificateEntries) == 0 {
// Force revocation times to UTC per RFC 5280.
revokedCerts = make([]pkix.RevokedCertificate, len(template.RevokedCertificates))
for i, rc := range template.RevokedCertificates {
rc.RevocationTime = rc.RevocationTime.UTC()
revokedCerts[i] = rc
}
} else {
// Convert the ReasonCode field to a proper extension, and force revocation
// times to UTC per RFC 5280.
revokedCerts = make([]pkix.RevokedCertificate, len(template.RevokedCertificateEntries))
for i, rce := range template.RevokedCertificateEntries {
if rce.SerialNumber == nil {
return nil, errors.New("x509: template contains entry with nil SerialNumber field")
}
if rce.RevocationTime.IsZero() {
return nil, errors.New("x509: template contains entry with zero RevocationTime field")
}
rc := pkix.RevokedCertificate{
SerialNumber: rce.SerialNumber,
RevocationTime: rce.RevocationTime.UTC(),
}
// Copy over any extra extensions, except for a Reason Code extension,
// because we'll synthesize that ourselves to ensure it is correct.
exts := make([]pkix.Extension, 0, len(rce.ExtraExtensions))
for _, ext := range rce.ExtraExtensions {
if ext.Id.Equal(oidExtensionReasonCode) {
return nil, errors.New("x509: template contains entry with ReasonCode ExtraExtension; use ReasonCode field instead")
}
exts = append(exts, ext)
}
// Only add a reasonCode extension if the reason is non-zero, as per
// RFC 5280 Section 5.3.1.
if rce.ReasonCode != 0 {
reasonBytes, err := asn1.Marshal(asn1.Enumerated(rce.ReasonCode))
if err != nil {
return nil, err
}
exts = append(exts, pkix.Extension{
Id: oidExtensionReasonCode,
Value: reasonBytes,
})
}
if len(exts) > 0 {
rc.Extensions = exts
}
revokedCerts[i] = rc
}
}
aki, err := asn1.Marshal(authKeyId{Id: issuer.SubjectKeyId})
if err != nil {
return nil, err
}
if numBytes := template.Number.Bytes(); len(numBytes) > 20 || (len(numBytes) == 20 && numBytes[0]&0x80 != 0) {
return nil, errors.New("x509: CRL number exceeds 20 octets")
}
crlNum, err := asn1.Marshal(template.Number)
if err != nil {
return nil, err
}
// Correctly use the issuer's subject sequence if one is specified.
issuerSubject, err := subjectBytes(issuer)
if err != nil {
return nil, err
}
tbsCertList := tbsCertificateList{
Version: 1, // v2
Signature: algorithmIdentifier,
Issuer: asn1.RawValue{FullBytes: issuerSubject},
ThisUpdate: template.ThisUpdate.UTC(),
NextUpdate: template.NextUpdate.UTC(),
Extensions: []pkix.Extension{
{
Id: oidExtensionAuthorityKeyId,
Value: aki,
},
{
Id: oidExtensionCRLNumber,
Value: crlNum,
},
},
}
if len(revokedCerts) > 0 {
tbsCertList.RevokedCertificates = revokedCerts
}
if len(template.ExtraExtensions) > 0 {
tbsCertList.Extensions = append(tbsCertList.Extensions, template.ExtraExtensions...)
}
tbsCertListContents, err := asn1.Marshal(tbsCertList)
if err != nil {
return nil, err
}
// Optimization to only marshal this struct once, when signing and
// then embedding in certificateList below.
tbsCertList.Raw = tbsCertListContents
signature, err := signTBS(tbsCertListContents, priv, signatureAlgorithm, rand)
if err != nil {
return nil, err
}
return asn1.Marshal(certificateList{
TBSCertList: tbsCertList,
SignatureAlgorithm: algorithmIdentifier,
SignatureValue: asn1.BitString{Bytes: signature, BitLength: len(signature) * 8},
})
}
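// Illustrative sketch (editor's addition): issuing a CRL that revokes a single
// certificate. issuerCert and issuerKey are assumed to exist; issuerCert must
// have the cRLSign key usage and a SubjectKeyId, per the checks above.
//
//	rl := &x509.RevocationList{
//		Number:     big.NewInt(42),
//		ThisUpdate: time.Now(),
//		NextUpdate: time.Now().Add(24 * time.Hour),
//		RevokedCertificateEntries: []x509.RevocationListEntry{
//			{
//				SerialNumber:   big.NewInt(1000),
//				RevocationTime: time.Now(),
//				ReasonCode:     1, // keyCompromise; 0 would omit the reasonCode extension
//			},
//		},
//	}
//	crlDER, err := x509.CreateRevocationList(rand.Reader, rl, issuerCert, issuerKey)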
// CheckSignatureFrom verifies that the signature on rl is a valid signature
// from issuer.
func (rl *RevocationList) CheckSignatureFrom(parent *Certificate) error {
if parent.Version == 3 && !parent.BasicConstraintsValid ||
parent.BasicConstraintsValid && !parent.IsCA {
return ConstraintViolationError{}
}
if parent.KeyUsage != 0 && parent.KeyUsage&KeyUsageCRLSign == 0 {
return ConstraintViolationError{}
}
if parent.PublicKeyAlgorithm == UnknownPublicKeyAlgorithm {
return ErrUnsupportedAlgorithm
}
return parent.CheckSignature(rl.SignatureAlgorithm, rl.RawTBSRevocationList, rl.Signature)
}
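// Illustrative sketch (editor's addition): a typical parse-then-verify flow for
// a CRL, assuming crlDER and issuerCert are already available.
//
//	rl, err := x509.ParseRevocationList(crlDER)
//	if err != nil {
//		// handle the parse error
//	}
//	if err := rl.CheckSignatureFrom(issuerCert); err != nil {
//		// the CRL was not signed by issuerCert
//	}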
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Type conversions for Scan.
package sql
import (
"bytes"
"database/sql/driver"
"errors"
"fmt"
"reflect"
"strconv"
"time"
"unicode"
"unicode/utf8"
_ "unsafe" // for linkname
)
var errNilPtr = errors.New("destination pointer is nil") // embedded in descriptive error
func describeNamedValue(nv *driver.NamedValue) string {
if len(nv.Name) == 0 {
return fmt.Sprintf("$%d", nv.Ordinal)
}
return fmt.Sprintf("with name %q", nv.Name)
}
func validateNamedValueName(name string) error {
if len(name) == 0 {
return nil
}
r, _ := utf8.DecodeRuneInString(name)
if unicode.IsLetter(r) {
return nil
}
return fmt.Errorf("name %q does not begin with a letter", name)
}
// ccChecker wraps the driver.ColumnConverter and allows it to be used
// as if it were a NamedValueChecker. If the driver ColumnConverter
// is not present then the NamedValueChecker will return driver.ErrSkip.
type ccChecker struct {
cci driver.ColumnConverter
want int
}
func (c ccChecker) CheckNamedValue(nv *driver.NamedValue) error {
if c.cci == nil {
return driver.ErrSkip
}
// The column converter shouldn't be called on any index
// it isn't expecting. The final error will be thrown
// in the argument converter loop.
index := nv.Ordinal - 1
if c.want <= index {
return nil
}
// First, see if the value itself knows how to convert
// itself to a driver type. For example, a NullString
// struct changing into a string or nil.
if vr, ok := nv.Value.(driver.Valuer); ok {
sv, err := callValuerValue(vr)
if err != nil {
return err
}
if !driver.IsValue(sv) {
return fmt.Errorf("non-subset type %T returned from Value", sv)
}
nv.Value = sv
}
// Second, ask the column to sanity check itself. For
// example, drivers might use this to make sure that
// an int64 value being inserted into a 16-bit
// integer field is in range (before getting
// truncated), or that a nil can't go into a NOT NULL
// column before going across the network to get the
// same error.
var err error
arg := nv.Value
nv.Value, err = c.cci.ColumnConverter(index).ConvertValue(arg)
if err != nil {
return err
}
if !driver.IsValue(nv.Value) {
return fmt.Errorf("driver ColumnConverter error converted %T to unsupported type %T", arg, nv.Value)
}
return nil
}
// defaultCheckNamedValue wraps the default ColumnConverter to have the same
// function signature as the CheckNamedValue in the driver.NamedValueChecker
// interface.
func defaultCheckNamedValue(nv *driver.NamedValue) (err error) {
nv.Value, err = driver.DefaultParameterConverter.ConvertValue(nv.Value)
return err
}
// driverArgsConnLocked converts arguments from callers of Stmt.Exec and
// Stmt.Query into driver Values.
//
// The statement ds may be nil, if no statement is available.
//
// ci must be locked.
func driverArgsConnLocked(ci driver.Conn, ds *driverStmt, args []any) ([]driver.NamedValue, error) {
nvargs := make([]driver.NamedValue, len(args))
// -1 means the driver doesn't know how to count the number of
// placeholders, so we won't sanity check input here and instead let the
// driver deal with errors.
want := -1
var si driver.Stmt
var cc ccChecker
if ds != nil {
si = ds.si
want = ds.si.NumInput()
cc.want = want
}
// Check all types of interfaces from the start.
// Drivers may opt to use the NamedValueChecker for special
// argument types, then return driver.ErrSkip to pass it along
// to the column converter.
nvc, ok := si.(driver.NamedValueChecker)
if !ok {
nvc, _ = ci.(driver.NamedValueChecker)
}
cci, ok := si.(driver.ColumnConverter)
if ok {
cc.cci = cci
}
// Loop through all the arguments, checking each one.
// If no error is returned simply increment the index
// and continue. However, if driver.ErrRemoveArgument
// is returned the argument is not included in the query
// argument list.
var err error
var n int
for _, arg := range args {
nv := &nvargs[n]
if np, ok := arg.(NamedArg); ok {
if err = validateNamedValueName(np.Name); err != nil {
return nil, err
}
arg = np.Value
nv.Name = np.Name
}
nv.Ordinal = n + 1
nv.Value = arg
// Checking sequence has four routes:
// A: 1. Default
// B: 1. NamedValueChecker 2. Column Converter 3. Default
// C: 1. NamedValueChecker 2. Default
// D: 1. Column Converter 2. Default
//
// A Column Converter is only ever called first or after the
// NamedValueChecker. If it is called first, it is handled before the
// nextCheck label. Thus the Column Converter is only used on a retry
// when the NamedValueChecker was selected first.
checker := defaultCheckNamedValue
nextCC := false
switch {
case nvc != nil:
nextCC = cci != nil
checker = nvc.CheckNamedValue
case cci != nil:
checker = cc.CheckNamedValue
}
nextCheck:
err = checker(nv)
switch err {
case nil:
n++
continue
case driver.ErrRemoveArgument:
nvargs = nvargs[:len(nvargs)-1]
continue
case driver.ErrSkip:
if nextCC {
nextCC = false
checker = cc.CheckNamedValue
} else {
checker = defaultCheckNamedValue
}
goto nextCheck
default:
return nil, fmt.Errorf("sql: converting argument %s type: %w", describeNamedValue(nv), err)
}
}
// Check the length of arguments after conversion to allow for omitted
// arguments.
if want != -1 && len(nvargs) != want {
return nil, fmt.Errorf("sql: expected %d arguments, got %d", want, len(nvargs))
}
return nvargs, nil
}
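// Illustrative sketch (editor's addition): a hypothetical driver statement
// implementing driver.NamedValueChecker. It converts one custom argument type
// itself and returns driver.ErrSkip for everything else, which sends the value
// down the column-converter/default route handled above. jsonStmt and the use
// of encoding/json are assumptions.
//
//	type jsonStmt struct{ driver.Stmt }
//
//	func (s jsonStmt) CheckNamedValue(nv *driver.NamedValue) error {
//		if m, ok := nv.Value.(map[string]any); ok {
//			b, err := json.Marshal(m)
//			if err != nil {
//				return err
//			}
//			nv.Value = b // pass maps to the database as JSON bytes
//			return nil
//		}
//		return driver.ErrSkip // defer to the column converter or default converter
//	}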
// convertAssign is the same as convertAssignRows, but without the optional
// rows argument.
//
// convertAssign should be an internal detail,
// but widely used packages access it using linkname.
// Notable members of the hall of shame include:
// - ariga.io/entcache
//
// Do not remove or change the type signature.
// See go.dev/issue/67401.
//
//go:linkname convertAssign
func convertAssign(dest, src any) error {
return convertAssignRows(dest, src, nil)
}
// convertAssignRows copies to dest the value in src, converting it if possible.
// An error is returned if the copy would result in loss of information.
// dest should be a pointer type. If rows is passed in, the rows will
// be used as the parent for any cursor values converted from a
// driver.Rows to a *Rows.
func convertAssignRows(dest, src any, rows *Rows) error {
// Common cases, without reflect.
switch s := src.(type) {
case string:
switch d := dest.(type) {
case *string:
if d == nil {
return errNilPtr
}
*d = s
return nil
case *[]byte:
if d == nil {
return errNilPtr
}
*d = []byte(s)
return nil
case *RawBytes:
if d == nil {
return errNilPtr
}
*d = rows.setrawbuf(append(rows.rawbuf(), s...))
return nil
}
case []byte:
switch d := dest.(type) {
case *string:
if d == nil {
return errNilPtr
}
*d = string(s)
return nil
case *any:
if d == nil {
return errNilPtr
}
*d = bytes.Clone(s)
return nil
case *[]byte:
if d == nil {
return errNilPtr
}
*d = bytes.Clone(s)
return nil
case *RawBytes:
if d == nil {
return errNilPtr
}
*d = s
return nil
}
case time.Time:
switch d := dest.(type) {
case *time.Time:
*d = s
return nil
case *string:
*d = s.Format(time.RFC3339Nano)
return nil
case *[]byte:
if d == nil {
return errNilPtr
}
*d = s.AppendFormat(make([]byte, 0, len(time.RFC3339Nano)), time.RFC3339Nano)
return nil
case *RawBytes:
if d == nil {
return errNilPtr
}
*d = rows.setrawbuf(s.AppendFormat(rows.rawbuf(), time.RFC3339Nano))
return nil
}
case decimalDecompose:
switch d := dest.(type) {
case decimalCompose:
return d.Compose(s.Decompose(nil))
}
case nil:
switch d := dest.(type) {
case *any:
if d == nil {
return errNilPtr
}
*d = nil
return nil
case *[]byte:
if d == nil {
return errNilPtr
}
*d = nil
return nil
case *RawBytes:
if d == nil {
return errNilPtr
}
*d = nil
return nil
}
// The driver is returning a cursor the client may iterate over.
case driver.Rows:
switch d := dest.(type) {
case *Rows:
if d == nil {
return errNilPtr
}
if rows == nil {
return errors.New("invalid context to convert cursor rows, missing parent *Rows")
}
*d = Rows{
dc: rows.dc,
releaseConn: func(error) {},
rowsi: s,
}
// Chain the cancel function.
parentCancel := rows.cancel
rows.cancel = func() {
// When Rows.cancel is called, the closemu will be locked as well.
// So we can access rows.lasterr.
d.close(rows.lasterr)
if parentCancel != nil {
parentCancel()
}
}
return nil
}
}
var sv reflect.Value
switch d := dest.(type) {
case *string:
sv = reflect.ValueOf(src)
switch sv.Kind() {
case reflect.Bool,
reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
reflect.Float32, reflect.Float64:
*d = asString(src)
return nil
}
case *[]byte:
sv = reflect.ValueOf(src)
if b, ok := asBytes(nil, sv); ok {
*d = b
return nil
}
case *RawBytes:
sv = reflect.ValueOf(src)
if b, ok := asBytes(rows.rawbuf(), sv); ok {
*d = rows.setrawbuf(b)
return nil
}
case *bool:
bv, err := driver.Bool.ConvertValue(src)
if err == nil {
*d = bv.(bool)
}
return err
case *any:
*d = src
return nil
}
if scanner, ok := dest.(Scanner); ok {
return scanner.Scan(src)
}
dpv := reflect.ValueOf(dest)
if dpv.Kind() != reflect.Pointer {
return errors.New("destination not a pointer")
}
if dpv.IsNil() {
return errNilPtr
}
if !sv.IsValid() {
sv = reflect.ValueOf(src)
}
dv := reflect.Indirect(dpv)
if sv.IsValid() && sv.Type().AssignableTo(dv.Type()) {
switch b := src.(type) {
case []byte:
dv.Set(reflect.ValueOf(bytes.Clone(b)))
default:
dv.Set(sv)
}
return nil
}
if dv.Kind() == sv.Kind() && sv.Type().ConvertibleTo(dv.Type()) {
dv.Set(sv.Convert(dv.Type()))
return nil
}
// The following conversions use a string value as an intermediate representation
// to convert between various numeric types.
//
// This also allows scanning into user defined types such as "type Int int64".
// For symmetry, also check for string destination types.
switch dv.Kind() {
case reflect.Pointer:
if src == nil {
dv.SetZero()
return nil
}
dv.Set(reflect.New(dv.Type().Elem()))
return convertAssignRows(dv.Interface(), src, rows)
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
if src == nil {
return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind())
}
s := asString(src)
i64, err := strconv.ParseInt(s, 10, dv.Type().Bits())
if err != nil {
err = strconvErr(err)
return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err)
}
dv.SetInt(i64)
return nil
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
if src == nil {
return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind())
}
s := asString(src)
u64, err := strconv.ParseUint(s, 10, dv.Type().Bits())
if err != nil {
err = strconvErr(err)
return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err)
}
dv.SetUint(u64)
return nil
case reflect.Float32, reflect.Float64:
if src == nil {
return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind())
}
s := asString(src)
f64, err := strconv.ParseFloat(s, dv.Type().Bits())
if err != nil {
err = strconvErr(err)
return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err)
}
dv.SetFloat(f64)
return nil
case reflect.String:
if src == nil {
return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind())
}
switch v := src.(type) {
case string:
dv.SetString(v)
return nil
case []byte:
dv.SetString(string(v))
return nil
}
}
return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, dest)
}
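// Illustrative examples (editor's addition): how the rules above behave for a
// few common destinations, using the in-package convertAssign wrapper. The
// concrete values are assumptions chosen for illustration.
//
//	var n int64
//	_ = convertAssign(&n, "42")     // string parsed with strconv: n == 42
//
//	var s string
//	_ = convertAssign(&s, int64(7)) // numeric source formatted via asString: s == "7"
//
//	err := convertAssign(&s, nil)   // NULL into a plain string fails:
//	                                // "converting NULL to string is unsupported"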
func strconvErr(err error) error {
if ne, ok := err.(*strconv.NumError); ok {
return ne.Err
}
return err
}
func asString(src any) string {
switch v := src.(type) {
case string:
return v
case []byte:
return string(v)
}
rv := reflect.ValueOf(src)
switch rv.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return strconv.FormatInt(rv.Int(), 10)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return strconv.FormatUint(rv.Uint(), 10)
case reflect.Float64:
return strconv.FormatFloat(rv.Float(), 'g', -1, 64)
case reflect.Float32:
return strconv.FormatFloat(rv.Float(), 'g', -1, 32)
case reflect.Bool:
return strconv.FormatBool(rv.Bool())
}
return fmt.Sprintf("%v", src)
}
func asBytes(buf []byte, rv reflect.Value) (b []byte, ok bool) {
switch rv.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return strconv.AppendInt(buf, rv.Int(), 10), true
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return strconv.AppendUint(buf, rv.Uint(), 10), true
case reflect.Float32:
return strconv.AppendFloat(buf, rv.Float(), 'g', -1, 32), true
case reflect.Float64:
return strconv.AppendFloat(buf, rv.Float(), 'g', -1, 64), true
case reflect.Bool:
return strconv.AppendBool(buf, rv.Bool()), true
case reflect.String:
s := rv.String()
return append(buf, s...), true
}
return
}
var valuerReflectType = reflect.TypeFor[driver.Valuer]()
// callValuerValue returns vr.Value(), with one exception:
// If vr.Value is an auto-generated method on a pointer type and the
// pointer is nil, it would panic at runtime in the panicwrap
// method. Treat it like nil instead.
// Issue 8415.
//
// This is so people can implement driver.Value on value types and
// still use nil pointers to those types to mean nil/NULL, just like
// string/*string.
//
// This function is mirrored in the database/sql/driver package.
func callValuerValue(vr driver.Valuer) (v driver.Value, err error) {
if rv := reflect.ValueOf(vr); rv.Kind() == reflect.Pointer &&
rv.IsNil() &&
rv.Type().Elem().Implements(valuerReflectType) {
return nil, nil
}
return vr.Value()
}
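// Illustrative sketch (editor's addition): the nil-pointer case described
// above, with a hypothetical Point type whose Value method has a value receiver.
//
//	type Point struct{ X, Y int }
//
//	func (p Point) Value() (driver.Value, error) {
//		return fmt.Sprintf("(%d,%d)", p.X, p.Y), nil
//	}
//
//	// Elsewhere: a nil *Point stands in for SQL NULL. Calling Value through the
//	// generated pointer wrapper would panic, so callValuerValue returns nil, nil.
//	var p *Point
//	v, err := callValuerValue(p) // v == nil, err == nil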
// decimal composes or decomposes a decimal value to and from individual parts.
// There are four parts: a boolean negative flag, a form byte with three possible states
// (finite=0, infinite=1, NaN=2), a base-2 big-endian integer
// coefficient (also known as a significand) as a []byte, and an int32 exponent.
// These are composed into a final value as "decimal = (neg) (form=finite) coefficient * 10 ^ exponent".
// A zero length coefficient is a zero value.
// The big-endian integer coefficient stores the most significant byte first (at coefficient[0]).
// If the form is not finite the coefficient and exponent should be ignored.
// The negative parameter may be set to true for any form, although implementations are not required
// to respect the negative parameter in the non-finite form.
//
// Implementations may choose to set the negative parameter to true on a zero or NaN value,
// but implementations that do not differentiate between negative and positive
// zero or NaN values should ignore the negative parameter without error.
// If an implementation does not support Infinity it may be converted into a NaN without error.
// If a value is set that is larger than what is supported by an implementation,
// an error must be returned.
// Implementations must return an error if an attempt is made to set a NaN or
// Infinity value when neither is supported.
//
// NOTE(kardianos): This is an experimental interface. See https://golang.org/issue/30870
type decimal interface {
decimalDecompose
decimalCompose
}
type decimalDecompose interface {
// Decompose returns the internal decimal state in parts.
// If the provided buf has sufficient capacity, buf may be returned as the coefficient with
// the value set and length set as appropriate.
Decompose(buf []byte) (form byte, negative bool, coefficient []byte, exponent int32)
}
type decimalCompose interface {
// Compose sets the internal decimal value from parts. If the value cannot be
// represented then an error should be returned.
Compose(form byte, negative bool, coefficient []byte, exponent int32) error
}
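// Illustrative sketch (editor's addition): a minimal, hypothetical type
// satisfying both interfaces for finite values only, storing the value as
// coefficient * 10^exponent with a math/big coefficient.
//
//	type dec struct {
//		neg  bool
//		coef *big.Int // non-negative magnitude
//		exp  int32
//	}
//
//	func (d *dec) Decompose(buf []byte) (form byte, negative bool, coefficient []byte, exponent int32) {
//		return 0, d.neg, d.coef.Bytes(), d.exp // form 0 = finite; Bytes is big-endian
//	}
//
//	func (d *dec) Compose(form byte, negative bool, coefficient []byte, exponent int32) error {
//		if form != 0 {
//			return errors.New("dec: only finite values are supported")
//		}
//		d.neg = negative
//		d.coef = new(big.Int).SetBytes(coefficient)
//		d.exp = exponent
//		return nil
//	}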
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package sql
import (
"context"
"database/sql/driver"
"errors"
)
func ctxDriverPrepare(ctx context.Context, ci driver.Conn, query string) (driver.Stmt, error) {
if ciCtx, is := ci.(driver.ConnPrepareContext); is {
return ciCtx.PrepareContext(ctx, query)
}
si, err := ci.Prepare(query)
if err == nil {
select {
default:
case <-ctx.Done():
si.Close()
return nil, ctx.Err()
}
}
return si, err
}
func ctxDriverExec(ctx context.Context, execerCtx driver.ExecerContext, execer driver.Execer, query string, nvdargs []driver.NamedValue) (driver.Result, error) {
if execerCtx != nil {
return execerCtx.ExecContext(ctx, query, nvdargs)
}
dargs, err := namedValueToValue(nvdargs)
if err != nil {
return nil, err
}
select {
default:
case <-ctx.Done():
return nil, ctx.Err()
}
return execer.Exec(query, dargs)
}
func ctxDriverQuery(ctx context.Context, queryerCtx driver.QueryerContext, queryer driver.Queryer, query string, nvdargs []driver.NamedValue) (driver.Rows, error) {
if queryerCtx != nil {
return queryerCtx.QueryContext(ctx, query, nvdargs)
}
dargs, err := namedValueToValue(nvdargs)
if err != nil {
return nil, err
}
select {
default:
case <-ctx.Done():
return nil, ctx.Err()
}
return queryer.Query(query, dargs)
}
func ctxDriverStmtExec(ctx context.Context, si driver.Stmt, nvdargs []driver.NamedValue) (driver.Result, error) {
if siCtx, is := si.(driver.StmtExecContext); is {
return siCtx.ExecContext(ctx, nvdargs)
}
dargs, err := namedValueToValue(nvdargs)
if err != nil {
return nil, err
}
select {
default:
case <-ctx.Done():
return nil, ctx.Err()
}
return si.Exec(dargs)
}
func ctxDriverStmtQuery(ctx context.Context, si driver.Stmt, nvdargs []driver.NamedValue) (driver.Rows, error) {
if siCtx, is := si.(driver.StmtQueryContext); is {
return siCtx.QueryContext(ctx, nvdargs)
}
dargs, err := namedValueToValue(nvdargs)
if err != nil {
return nil, err
}
select {
default:
case <-ctx.Done():
return nil, ctx.Err()
}
return si.Query(dargs)
}
func ctxDriverBegin(ctx context.Context, opts *TxOptions, ci driver.Conn) (driver.Tx, error) {
if ciCtx, is := ci.(driver.ConnBeginTx); is {
dopts := driver.TxOptions{}
if opts != nil {
dopts.Isolation = driver.IsolationLevel(opts.Isolation)
dopts.ReadOnly = opts.ReadOnly
}
return ciCtx.BeginTx(ctx, dopts)
}
if opts != nil {
// Check the transaction level. If the transaction level is non-default
// then return an error here, as the driver does not support ConnBeginTx
// and therefore cannot honor the requested isolation level.
if opts.Isolation != LevelDefault {
return nil, errors.New("sql: driver does not support non-default isolation level")
}
// If a read-only transaction is requested, return an error, as the driver
// does not support ConnBeginTx and cannot honor the read-only option.
if opts.ReadOnly {
return nil, errors.New("sql: driver does not support read-only transactions")
}
}
if ctx.Done() == nil {
return ci.Begin()
}
txi, err := ci.Begin()
if err == nil {
select {
default:
case <-ctx.Done():
txi.Rollback()
return nil, ctx.Err()
}
}
return txi, err
}
func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) {
dargs := make([]driver.Value, len(named))
for n, param := range named {
if len(param.Name) > 0 {
return nil, errors.New("sql: driver does not support the use of Named Parameters")
}
dargs[n] = param.Value
}
return dargs, nil
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package driver defines interfaces to be implemented by database
// drivers as used by package sql.
//
// Most code should use the [database/sql] package.
//
// The driver interface has evolved over time. Drivers should implement
// [Connector] and [DriverContext] interfaces.
// The Connector.Connect and Driver.Open methods should never return [ErrBadConn].
// [ErrBadConn] should only be returned from [Validator], [SessionResetter], or
// a query method if the connection is already in an invalid (e.g. closed) state.
//
// All [Conn] implementations should implement the following interfaces:
// [Pinger], [SessionResetter], and [Validator].
//
// If named parameters or context are supported, the driver's [Conn] should implement:
// [ExecerContext], [QueryerContext], [ConnPrepareContext], and [ConnBeginTx].
//
// To support custom data types, implement [NamedValueChecker]. [NamedValueChecker]
// also allows queries to accept per-query options as a parameter by returning
// [ErrRemoveArgument] from CheckNamedValue.
//
// If multiple result sets are supported, [Rows] should implement [RowsNextResultSet].
// If the driver knows how to describe the types present in the returned result
// it should implement the following interfaces: [RowsColumnTypeScanType],
// [RowsColumnTypeDatabaseTypeName], [RowsColumnTypeLength], [RowsColumnTypeNullable],
// and [RowsColumnTypePrecisionScale]. A given row value may also return a [Rows]
// type, which may represent a database cursor value.
//
// If a [Conn] implements [Validator], then the IsValid method is called
// before returning the connection to the connection pool. If an entry in the
// connection pool implements [SessionResetter], then ResetSession
// is called before reusing the connection for another query. If a connection is
// never returned to the connection pool but is immediately reused, then
// ResetSession is called prior to reuse but IsValid is not called.
package driver
import (
"context"
"errors"
"reflect"
)
// Value is a value that drivers must be able to handle.
// It is either nil, a type handled by a database driver's [NamedValueChecker]
// interface, or an instance of one of these types:
//
// int64
// float64
// bool
// []byte
// string
// time.Time
//
// If the driver supports cursors, a returned Value may also implement the [Rows] interface
// in this package. This is used, for example, when a user selects a cursor
// such as "select cursor(select * from my_table) from dual". If the [Rows]
// from the select is closed, the cursor [Rows] will also be closed.
type Value any
// NamedValue holds both the value name and value.
type NamedValue struct {
// If the Name is not empty it should be used for the parameter identifier and
// not the ordinal position.
//
// Name will not have a symbol prefix.
Name string
// Ordinal is the position of the parameter, starting from one.
// It is always set.
Ordinal int
// Value is the parameter value.
Value Value
}
// Driver is the interface that must be implemented by a database
// driver.
//
// Database drivers may implement [DriverContext] for access
// to contexts and to parse the name only once for a pool of connections,
// instead of once per connection.
type Driver interface {
// Open returns a new connection to the database.
// The name is a string in a driver-specific format.
//
// Open may return a cached connection (one previously
// closed), but doing so is unnecessary; the sql package
// maintains a pool of idle connections for efficient re-use.
//
// The returned connection is only used by one goroutine at a
// time.
Open(name string) (Conn, error)
}
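// As an illustrative sketch only (not part of this package), a minimal
// driver's Open parses the data source name and returns a ready Conn.
// The fakeDriver, fakeConn, and dial names below are hypothetical:
//
//	type fakeDriver struct{}
//
//	func (fakeDriver) Open(name string) (driver.Conn, error) {
//		c, err := dial(name) // parse the DSN and connect (not shown)
//		if err != nil {
//			return nil, err
//		}
//		return &fakeConn{c: c}, nil
//	}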
// If a [Driver] implements DriverContext, then [database/sql.DB] will call
// OpenConnector to obtain a [Connector] and then invoke
// that [Connector]'s Connect method to obtain each needed connection,
// instead of invoking the [Driver]'s Open method for each connection.
// The two-step sequence allows drivers to parse the name just once
// and also provides access to per-[Conn] contexts.
type DriverContext interface {
// OpenConnector must parse the name in the same format that Driver.Open
// parses the name parameter.
OpenConnector(name string) (Connector, error)
}
// A Connector represents a driver in a fixed configuration
// and can create any number of equivalent Conns for use
// by multiple goroutines.
//
// A Connector can be passed to [database/sql.OpenDB], to allow drivers
// to implement their own [database/sql.DB] constructors, or returned by
// [DriverContext]'s OpenConnector method, to allow drivers
// access to context and to avoid repeated parsing of driver
// configuration.
//
// If a Connector implements [io.Closer], the [database/sql.DB.Close]
// method will call the Close method and return error (if any).
type Connector interface {
// Connect returns a connection to the database.
// Connect may return a cached connection (one previously
// closed), but doing so is unnecessary; the sql package
// maintains a pool of idle connections for efficient re-use.
//
// The provided context.Context is for dialing purposes only
// (see net.DialContext) and should not be stored or used for
// other purposes. A default timeout should still be used
// when dialing as a connection pool may call Connect
// asynchronously to any query.
//
// The returned connection is only used by one goroutine at a
// time.
Connect(context.Context) (Conn, error)
// Driver returns the underlying Driver of the Connector,
// mainly to maintain compatibility with the Driver method
// on sql.DB.
Driver() Driver
}
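// As an illustrative sketch only, a hypothetical connector pairs a
// pre-parsed configuration with its Driver; the config, connect, and
// fakeDriver names are assumptions, not part of this package:
//
//	type connector struct {
//		cfg *config // parsed once by OpenConnector
//		drv *fakeDriver
//	}
//
//	func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
//		return connect(ctx, c.cfg) // dial using the already-parsed config
//	}
//
//	func (c *connector) Driver() driver.Driver { return c.drv }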
// ErrSkip may be returned by some optional interfaces' methods to
// indicate at runtime that the fast path is unavailable and the sql
// package should continue as if the optional interface was not
// implemented. ErrSkip is only supported where explicitly
// documented.
var ErrSkip = errors.New("driver: skip fast-path; continue as if unimplemented")
// ErrBadConn should be returned by a driver to signal to the [database/sql]
// package that a driver.[Conn] is in a bad state (such as the server
// having earlier closed the connection) and the [database/sql] package should
// retry on a new connection.
//
// To prevent duplicate operations, ErrBadConn should NOT be returned
// if there's a possibility that the database server might have
// performed the operation. Even if the server sends back an error,
// you shouldn't return ErrBadConn.
//
// Errors will be checked using [errors.Is]. An error may
// wrap ErrBadConn or implement the Is(error) bool method.
var ErrBadConn = errors.New("driver: bad connection")
// Pinger is an optional interface that may be implemented by a [Conn].
//
// If a [Conn] does not implement Pinger, the [database/sql.DB.Ping] and
// [database/sql.DB.PingContext] will check if there is at least one [Conn] available.
//
// If Conn.Ping returns [ErrBadConn], [database/sql.DB.Ping] and [database/sql.DB.PingContext] will remove
// the [Conn] from pool.
type Pinger interface {
Ping(ctx context.Context) error
}
// Execer is an optional interface that may be implemented by a [Conn].
//
// If a [Conn] implements neither [ExecerContext] nor [Execer],
// the [database/sql.DB.Exec] will first prepare a query, execute the statement,
// and then close the statement.
//
// Exec may return [ErrSkip].
//
// Deprecated: Drivers should implement [ExecerContext] instead.
type Execer interface {
Exec(query string, args []Value) (Result, error)
}
// ExecerContext is an optional interface that may be implemented by a [Conn].
//
// If a [Conn] does not implement [ExecerContext], the [database/sql.DB.Exec]
// will fall back to [Execer]; if the Conn does not implement Execer either,
// [database/sql.DB.Exec] will first prepare a query, execute the statement, and then
// close the statement.
//
// ExecContext may return [ErrSkip].
//
// ExecContext must honor the context timeout and return when the context is canceled.
type ExecerContext interface {
ExecContext(ctx context.Context, query string, args []NamedValue) (Result, error)
}
// Queryer is an optional interface that may be implemented by a [Conn].
//
// If a [Conn] implements neither [QueryerContext] nor [Queryer],
// the [database/sql.DB.Query] will first prepare a query, execute the statement,
// and then close the statement.
//
// Query may return [ErrSkip].
//
// Deprecated: Drivers should implement [QueryerContext] instead.
type Queryer interface {
Query(query string, args []Value) (Rows, error)
}
// QueryerContext is an optional interface that may be implemented by a [Conn].
//
// If a [Conn] does not implement QueryerContext, the [database/sql.DB.Query]
// will fall back to [Queryer]; if the [Conn] does not implement [Queryer] either,
// [database/sql.DB.Query] will first prepare a query, execute the statement, and then
// close the statement.
//
// QueryContext may return [ErrSkip].
//
// QueryContext must honor the context timeout and return when the context is canceled.
type QueryerContext interface {
QueryContext(ctx context.Context, query string, args []NamedValue) (Rows, error)
}
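// As an illustrative sketch only (fakeConn and its exec method are
// hypothetical), a Conn may execute simple statements directly and return
// ErrSkip when it wants the sql package to fall back to Prepare/Exec:
//
//	func (c *fakeConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
//		if len(args) > 0 {
//			return nil, driver.ErrSkip // let database/sql prepare the statement
//		}
//		return c.exec(ctx, query)
//	}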
// Conn is a connection to a database. It is not used concurrently
// by multiple goroutines.
//
// Conn is assumed to be stateful.
type Conn interface {
// Prepare returns a prepared statement, bound to this connection.
Prepare(query string) (Stmt, error)
// Close invalidates and potentially stops any current
// prepared statements and transactions, marking this
// connection as no longer in use.
//
// Because the sql package maintains a free pool of
// connections and only calls Close when there's a surplus of
// idle connections, it shouldn't be necessary for drivers to
// do their own connection caching.
//
// Drivers must ensure all network calls made by Close
// do not block indefinitely (e.g. apply a timeout).
Close() error
// Begin starts and returns a new transaction.
//
// Deprecated: Drivers should implement ConnBeginTx instead (or additionally).
Begin() (Tx, error)
}
// ConnPrepareContext enhances the [Conn] interface with context.
type ConnPrepareContext interface {
// PrepareContext returns a prepared statement, bound to this connection.
// The provided context is for the preparation of the statement only;
// the driver must not store the context within the statement itself.
PrepareContext(ctx context.Context, query string) (Stmt, error)
}
// IsolationLevel is the transaction isolation level stored in [TxOptions].
//
// This type should be considered identical to [database/sql.IsolationLevel] along
// with any values defined on it.
type IsolationLevel int
// TxOptions holds the transaction options.
//
// This type should be considered identical to [database/sql.TxOptions].
type TxOptions struct {
Isolation IsolationLevel
ReadOnly bool
}
// ConnBeginTx enhances the [Conn] interface with context and [TxOptions].
type ConnBeginTx interface {
// BeginTx starts and returns a new transaction.
// If the context is canceled by the user the sql package will
// call Tx.Rollback before discarding and closing the connection.
//
// This must check opts.Isolation to determine if there is a set
// isolation level. If the driver does not support a non-default
// level and one is set, or if the requested non-default level is
// not supported, an error must be returned.
//
// This must also check opts.ReadOnly to determine if the read-only
// value is true to either set the read-only transaction property if supported
// or return an error if it is not supported.
BeginTx(ctx context.Context, opts TxOptions) (Tx, error)
}
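// As an illustrative sketch only (fakeConn and beginTx are hypothetical),
// a BeginTx that rejects options it does not support might look like:
//
//	func (c *fakeConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
//		if opts.Isolation != driver.IsolationLevel(sql.LevelDefault) {
//			return nil, errors.New("fakedb: only the default isolation level is supported")
//		}
//		if opts.ReadOnly {
//			return nil, errors.New("fakedb: read-only transactions are not supported")
//		}
//		return c.beginTx(ctx)
//	}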
// SessionResetter may be implemented by [Conn] to allow drivers to reset the
// session state associated with the connection and to signal a bad connection.
type SessionResetter interface {
// ResetSession is called prior to executing a query on the connection
// if the connection has been used before. If the driver returns ErrBadConn
// the connection is discarded.
ResetSession(ctx context.Context) error
}
// Validator may be implemented by [Conn] to allow drivers to
// signal if a connection is valid or if it should be discarded.
//
// If implemented, drivers may return the underlying error from queries,
// even if the connection should be discarded by the connection pool.
type Validator interface {
// IsValid is called prior to placing the connection into the
// connection pool. The connection will be discarded if false is returned.
IsValid() bool
}
// Result is the result of a query execution.
type Result interface {
// LastInsertId returns the database's auto-generated ID
// after, for example, an INSERT into a table with primary
// key.
LastInsertId() (int64, error)
// RowsAffected returns the number of rows affected by the
// query.
RowsAffected() (int64, error)
}
// Stmt is a prepared statement. It is bound to a [Conn] and not
// used by multiple goroutines concurrently.
type Stmt interface {
// Close closes the statement.
//
// As of Go 1.1, a Stmt will not be closed if it's in use
// by any queries.
//
// Drivers must ensure all network calls made by Close
// do not block indefinitely (e.g. apply a timeout).
Close() error
// NumInput returns the number of placeholder parameters.
//
// If NumInput returns >= 0, the sql package will sanity check
// argument counts from callers and return errors to the caller
// before the statement's Exec or Query methods are called.
//
// NumInput may also return -1, if the driver doesn't know
// its number of placeholders. In that case, the sql package
// will not sanity check Exec or Query argument counts.
NumInput() int
// Exec executes a query that doesn't return rows, such
// as an INSERT or UPDATE.
//
// Deprecated: Drivers should implement StmtExecContext instead (or additionally).
Exec(args []Value) (Result, error)
// Query executes a query that may return rows, such as a
// SELECT.
//
// Deprecated: Drivers should implement StmtQueryContext instead (or additionally).
Query(args []Value) (Rows, error)
}
// StmtExecContext enhances the [Stmt] interface by providing Exec with context.
type StmtExecContext interface {
// ExecContext executes a query that doesn't return rows, such
// as an INSERT or UPDATE.
//
// ExecContext must honor the context timeout and return when it is canceled.
ExecContext(ctx context.Context, args []NamedValue) (Result, error)
}
// StmtQueryContext enhances the [Stmt] interface by providing Query with context.
type StmtQueryContext interface {
// QueryContext executes a query that may return rows, such as a
// SELECT.
//
// QueryContext must honor the context timeout and return when it is canceled.
QueryContext(ctx context.Context, args []NamedValue) (Rows, error)
}
// ErrRemoveArgument may be returned from [NamedValueChecker] to instruct the
// [database/sql] package to not pass the argument to the driver query interface.
// Return when accepting query specific options or structures that aren't
// SQL query arguments.
var ErrRemoveArgument = errors.New("driver: remove argument from query")
// NamedValueChecker may be optionally implemented by [Conn] or [Stmt]. It provides
// the driver more control to handle Go and database types beyond the default
// [Value] types allowed.
//
// The [database/sql] package checks for value checkers in the following order,
// stopping at the first found match: Stmt.NamedValueChecker, Conn.NamedValueChecker,
// Stmt.ColumnConverter, [DefaultParameterConverter].
//
// If CheckNamedValue returns [ErrRemoveArgument], the [NamedValue] will not be included in
// the final query arguments. This may be used to pass special options to
// the query itself.
//
// If [ErrSkip] is returned the column converter error checking
// path is used for the argument. Drivers may wish to return [ErrSkip] after
// they have exhausted their own special cases.
type NamedValueChecker interface {
// CheckNamedValue is called before passing arguments to the driver
// and is called in place of any ColumnConverter. CheckNamedValue must do type
// validation and conversion as appropriate for the driver.
CheckNamedValue(*NamedValue) error
}
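// As an illustrative sketch only, a hypothetical per-query option type can
// be consumed by CheckNamedValue and removed from the argument list via
// [ErrRemoveArgument]; queryTimeout and fakeConn are assumptions:
//
//	type queryTimeout time.Duration // driver-specific per-query option
//
//	func (c *fakeConn) CheckNamedValue(nv *driver.NamedValue) error {
//		if t, ok := nv.Value.(queryTimeout); ok {
//			c.timeout = time.Duration(t)    // record the option on the connection
//			return driver.ErrRemoveArgument // and drop it from the query arguments
//		}
//		return driver.ErrSkip // use the default conversion path for other values
//	}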
// ColumnConverter may be optionally implemented by [Stmt] if the
// statement is aware of its own columns' types and can convert from
// any type to a driver [Value].
//
// Deprecated: Drivers should implement [NamedValueChecker].
type ColumnConverter interface {
// ColumnConverter returns a ValueConverter for the provided
// column index. If the type of a specific column isn't known
// or shouldn't be handled specially, [DefaultParameterConverter]
// can be returned.
ColumnConverter(idx int) ValueConverter
}
// Rows is an iterator over an executed query's results.
type Rows interface {
// Columns returns the names of the columns. The number of
// columns of the result is inferred from the length of the
// slice. If a particular column name isn't known, an empty
// string should be returned for that entry.
Columns() []string
// Close closes the rows iterator.
Close() error
// Next is called to populate the next row of data into
// the provided slice. The provided slice will be the same
// size as the Columns() are wide.
//
// Next should return io.EOF when there are no more rows.
//
// The dest should not be written to outside of Next. Care
// should be taken when closing Rows not to modify
// a buffer held in dest.
Next(dest []Value) error
}
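// As an illustrative sketch only (memRows is hypothetical), a Rows
// implementation over an in-memory result returns io.EOF once exhausted:
//
//	type memRows struct {
//		cols []string
//		data [][]driver.Value
//		i    int
//	}
//
//	func (r *memRows) Columns() []string { return r.cols }
//	func (r *memRows) Close() error      { return nil }
//
//	func (r *memRows) Next(dest []driver.Value) error {
//		if r.i >= len(r.data) {
//			return io.EOF
//		}
//		copy(dest, r.data[r.i])
//		r.i++
//		return nil
//	}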
// RowsNextResultSet extends the [Rows] interface by providing a way to signal
// the driver to advance to the next result set.
type RowsNextResultSet interface {
Rows
// HasNextResultSet is called at the end of the current result set and
// reports whether there is another result set after the current one.
HasNextResultSet() bool
// NextResultSet advances the driver to the next result set even
// if there are remaining rows in the current result set.
//
// NextResultSet should return io.EOF when there are no more result sets.
NextResultSet() error
}
// RowsColumnTypeScanType may be implemented by [Rows]. It should return
// the value type that can be used to scan types into. For example, for the database
// column type "bigint", this should return "[reflect.TypeOf](int64(0))".
type RowsColumnTypeScanType interface {
Rows
ColumnTypeScanType(index int) reflect.Type
}
// RowsColumnTypeDatabaseTypeName may be implemented by [Rows]. It should return the
// database system type name without the length. Type names should be uppercase.
// Examples of returned types: "VARCHAR", "NVARCHAR", "VARCHAR2", "CHAR", "TEXT",
// "DECIMAL", "SMALLINT", "INT", "BIGINT", "BOOL", "[]BIGINT", "JSONB", "XML",
// "TIMESTAMP".
type RowsColumnTypeDatabaseTypeName interface {
Rows
ColumnTypeDatabaseTypeName(index int) string
}
// RowsColumnTypeLength may be implemented by [Rows]. It should return the length
// of the column type if the column is a variable length type. If the column is
// not a variable length type, ok should be false.
// If length is not limited other than system limits, it should return [math.MaxInt64].
// The following are examples of returned values for various types:
//
// TEXT (math.MaxInt64, true)
// varchar(10) (10, true)
// nvarchar(10) (10, true)
// decimal (0, false)
// int (0, false)
// bytea(30) (30, true)
type RowsColumnTypeLength interface {
Rows
ColumnTypeLength(index int) (length int64, ok bool)
}
// RowsColumnTypeNullable may be implemented by [Rows]. The nullable value should
// be true if it is known the column may be null, or false if the column is known
// to be not nullable.
// If the column nullability is unknown, ok should be false.
type RowsColumnTypeNullable interface {
Rows
ColumnTypeNullable(index int) (nullable, ok bool)
}
// RowsColumnTypePrecisionScale may be implemented by [Rows]. It should return
// the precision and scale for decimal types. If not applicable, ok should be false.
// The following are examples of returned values for various types:
//
// decimal(38, 4) (38, 4, true)
// int (0, 0, false)
// decimal (math.MaxInt64, math.MaxInt64, true)
type RowsColumnTypePrecisionScale interface {
Rows
ColumnTypePrecisionScale(index int) (precision, scale int64, ok bool)
}
// Tx is a transaction.
type Tx interface {
Commit() error
Rollback() error
}
// RowsAffected implements [Result] for an INSERT or UPDATE operation
// which mutates a number of rows.
type RowsAffected int64
var _ Result = RowsAffected(0)
func (RowsAffected) LastInsertId() (int64, error) {
return 0, errors.New("LastInsertId is not supported by this driver")
}
func (v RowsAffected) RowsAffected() (int64, error) {
return int64(v), nil
}
// ResultNoRows is a pre-defined [Result] for drivers to return when a DDL
// command (such as a CREATE TABLE) succeeds. It returns an error for both
// LastInsertId and [RowsAffected].
var ResultNoRows noRows
type noRows struct{}
var _ Result = noRows{}
func (noRows) LastInsertId() (int64, error) {
return 0, errors.New("no LastInsertId available after DDL statement")
}
func (noRows) RowsAffected() (int64, error) {
return 0, errors.New("no RowsAffected available after DDL statement")
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package driver
import (
"fmt"
"reflect"
"strconv"
"time"
)
// ValueConverter is the interface providing the ConvertValue method.
//
// Various implementations of ValueConverter are provided by the
// driver package to provide consistent implementations of conversions
// between drivers. The ValueConverters have several uses:
//
// - converting from the [Value] types as provided by the sql package
// into a database table's specific column type and making sure it
// fits, such as making sure a particular int64 fits in a
// table's uint16 column.
//
// - converting a value as given from the database into one of the
// driver [Value] types.
//
// - by the [database/sql] package, for converting from a driver's [Value] type
// to a user's type in a scan.
type ValueConverter interface {
// ConvertValue converts a value to a driver Value.
ConvertValue(v any) (Value, error)
}
// Valuer is the interface providing the Value method.
//
// Errors returned by the [Value] method are wrapped by the database/sql package.
// This allows callers to use [errors.Is] for precise error handling after operations
// like [database/sql.Query], [database/sql.Exec], or [database/sql.QueryRow].
//
// Types implementing Valuer interface are able to convert
// themselves to a driver [Value].
type Valuer interface {
// Value returns a driver Value.
// Value must not panic.
Value() (Value, error)
}
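// As an illustrative sketch only, a user-defined type can convert itself
// to one of the allowed [Value] types; the Point type below is hypothetical:
//
//	type Point struct{ X, Y int }
//
//	func (p Point) Value() (driver.Value, error) {
//		return fmt.Sprintf("(%d,%d)", p.X, p.Y), nil // stored as a string
//	}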
// Bool is a [ValueConverter] that converts input values to bool.
//
// The conversion rules are:
// - booleans are returned unchanged
// - for integer types,
// 1 is true
// 0 is false,
// other integers are an error
// - for strings and []byte, same rules as [strconv.ParseBool]
// - all other types are an error
var Bool boolType
type boolType struct{}
var _ ValueConverter = boolType{}
func (boolType) String() string { return "Bool" }
func (boolType) ConvertValue(src any) (Value, error) {
switch s := src.(type) {
case bool:
return s, nil
case string:
b, err := strconv.ParseBool(s)
if err != nil {
return nil, fmt.Errorf("sql/driver: couldn't convert %q into type bool", s)
}
return b, nil
case []byte:
b, err := strconv.ParseBool(string(s))
if err != nil {
return nil, fmt.Errorf("sql/driver: couldn't convert %q into type bool", s)
}
return b, nil
}
sv := reflect.ValueOf(src)
switch sv.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
iv := sv.Int()
if iv == 1 || iv == 0 {
return iv == 1, nil
}
return nil, fmt.Errorf("sql/driver: couldn't convert %d into type bool", iv)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
uv := sv.Uint()
if uv == 1 || uv == 0 {
return uv == 1, nil
}
return nil, fmt.Errorf("sql/driver: couldn't convert %d into type bool", uv)
}
return nil, fmt.Errorf("sql/driver: couldn't convert %v (%T) into type bool", src, src)
}
// Int32 is a [ValueConverter] that converts input values to int64,
// respecting the limits of an int32 value.
var Int32 int32Type
type int32Type struct{}
var _ ValueConverter = int32Type{}
func (int32Type) ConvertValue(v any) (Value, error) {
rv := reflect.ValueOf(v)
switch rv.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
i64 := rv.Int()
if i64 > (1<<31)-1 || i64 < -(1<<31) {
return nil, fmt.Errorf("sql/driver: value %d overflows int32", v)
}
return i64, nil
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
u64 := rv.Uint()
if u64 > (1<<31)-1 {
return nil, fmt.Errorf("sql/driver: value %d overflows int32", v)
}
return int64(u64), nil
case reflect.String:
i, err := strconv.Atoi(rv.String())
if err != nil {
return nil, fmt.Errorf("sql/driver: value %q can't be converted to int32", v)
}
return int64(i), nil
}
return nil, fmt.Errorf("sql/driver: unsupported value %v (type %T) converting to int32", v, v)
}
// String is a [ValueConverter] that converts its input to a string.
// If the value is already a string or []byte, it's unchanged.
// If the value is of another type, conversion to string is done
// with fmt.Sprintf("%v", v).
var String stringType
type stringType struct{}
func (stringType) ConvertValue(v any) (Value, error) {
switch v.(type) {
case string, []byte:
return v, nil
}
return fmt.Sprintf("%v", v), nil
}
// Null is a type that implements [ValueConverter] by allowing nil
// values but otherwise delegating to another [ValueConverter].
type Null struct {
Converter ValueConverter
}
func (n Null) ConvertValue(v any) (Value, error) {
if v == nil {
return nil, nil
}
return n.Converter.ConvertValue(v)
}
// NotNull is a type that implements [ValueConverter] by disallowing nil
// values but otherwise delegating to another [ValueConverter].
type NotNull struct {
Converter ValueConverter
}
func (n NotNull) ConvertValue(v any) (Value, error) {
if v == nil {
return nil, fmt.Errorf("nil value not allowed")
}
return n.Converter.ConvertValue(v)
}
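// For example, a driver that accepts NULL for an int32 column can compose
// the converters (an illustrative sketch from the driver's point of view):
//
//	nullableInt32 := driver.Null{Converter: driver.Int32}
//	v, err := nullableInt32.ConvertValue(nil) // v == nil, err == nil
//	v, err = nullableInt32.ConvertValue(7)    // v == int64(7)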
// IsValue reports whether v is a valid [Value] parameter type.
func IsValue(v any) bool {
if v == nil {
return true
}
switch v.(type) {
case []byte, bool, float64, int64, string, time.Time:
return true
case decimalDecompose:
return true
}
return false
}
// IsScanValue is equivalent to [IsValue].
// It exists for compatibility.
func IsScanValue(v any) bool {
return IsValue(v)
}
// DefaultParameterConverter is the default implementation of
// [ValueConverter] that's used when a [Stmt] doesn't implement
// [ColumnConverter].
//
// DefaultParameterConverter returns its argument directly if
// IsValue(arg). Otherwise, if the argument implements [Valuer], its
// Value method is used to return a [Value]. As a fallback, the provided
// argument's underlying type is used to convert it to a [Value]:
// underlying integer types are converted to int64, floats to float64,
// bool, string, and []byte to themselves. If the argument is a nil
// pointer, defaultConverter.ConvertValue returns a nil [Value].
// If the argument is a non-nil pointer, it is dereferenced and
// defaultConverter.ConvertValue is called recursively. Other types
// are an error.
var DefaultParameterConverter defaultConverter
type defaultConverter struct{}
var _ ValueConverter = defaultConverter{}
var valuerReflectType = reflect.TypeFor[Valuer]()
// callValuerValue returns vr.Value(), with one exception:
// If vr.Value is an auto-generated method on a pointer type and the
// pointer is nil, it would panic at runtime in the panicwrap
// method. Treat it like nil instead.
// Issue 8415.
//
// This is so people can implement driver.Value on value types and
// still use nil pointers to those types to mean nil/NULL, just like
// string/*string.
//
// This function is mirrored in the database/sql package.
func callValuerValue(vr Valuer) (v Value, err error) {
if rv := reflect.ValueOf(vr); rv.Kind() == reflect.Pointer &&
rv.IsNil() &&
rv.Type().Elem().Implements(valuerReflectType) {
return nil, nil
}
return vr.Value()
}
func (defaultConverter) ConvertValue(v any) (Value, error) {
if IsValue(v) {
return v, nil
}
switch vr := v.(type) {
case Valuer:
sv, err := callValuerValue(vr)
if err != nil {
return nil, err
}
if !IsValue(sv) {
return nil, fmt.Errorf("non-Value type %T returned from Value", sv)
}
return sv, nil
// For now, continue to prefer the Valuer interface over the decimal decompose interface.
case decimalDecompose:
return vr, nil
}
rv := reflect.ValueOf(v)
switch rv.Kind() {
case reflect.Pointer:
// indirect pointers
if rv.IsNil() {
return nil, nil
} else {
return defaultConverter{}.ConvertValue(rv.Elem().Interface())
}
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return rv.Int(), nil
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32:
return int64(rv.Uint()), nil
case reflect.Uint64:
u64 := rv.Uint()
if u64 >= 1<<63 {
return nil, fmt.Errorf("uint64 values with high bit set are not supported")
}
return int64(u64), nil
case reflect.Float32, reflect.Float64:
return rv.Float(), nil
case reflect.Bool:
return rv.Bool(), nil
case reflect.Slice:
ek := rv.Type().Elem().Kind()
if ek == reflect.Uint8 {
return rv.Bytes(), nil
}
return nil, fmt.Errorf("unsupported type %T, a slice of %s", v, ek)
case reflect.String:
return rv.String(), nil
}
return nil, fmt.Errorf("unsupported type %T, a %s", v, rv.Kind())
}
type decimalDecompose interface {
// Decompose returns the internal decimal state into parts.
// If the provided buf has sufficient capacity, buf may be returned as the coefficient with
// the value set and length set as appropriate.
Decompose(buf []byte) (form byte, negative bool, coefficient []byte, exponent int32)
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package sql provides a generic interface around SQL (or SQL-like)
// databases.
//
// The sql package must be used in conjunction with a database driver.
// See https://golang.org/s/sqldrivers for a list of drivers.
//
// Drivers that do not support context cancellation will not return until
// after the query is completed.
//
// For usage examples, see the wiki page at
// https://golang.org/s/sqlwiki.
package sql
import (
"context"
"database/sql/driver"
"errors"
"fmt"
"io"
"maps"
"math/rand/v2"
"reflect"
"runtime"
"slices"
"strconv"
"sync"
"sync/atomic"
"time"
_ "unsafe"
)
var driversMu sync.RWMutex
// drivers should be an internal detail,
// but widely used packages access it using linkname.
// (It is extra wrong that they linkname drivers but not driversMu.)
// Notable members of the hall of shame include:
// - github.com/instana/go-sensor
//
// Do not remove or change the type signature.
// See go.dev/issue/67401.
//
//go:linkname drivers
var drivers = make(map[string]driver.Driver)
// nowFunc returns the current time; it's overridden in tests.
var nowFunc = time.Now
// Register makes a database driver available by the provided name.
// If Register is called twice with the same name or if driver is nil,
// it panics.
func Register(name string, driver driver.Driver) {
driversMu.Lock()
defer driversMu.Unlock()
if driver == nil {
panic("sql: Register driver is nil")
}
if _, dup := drivers[name]; dup {
panic("sql: Register called twice for driver " + name)
}
drivers[name] = driver
}
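// Database drivers typically register themselves from an init function, so
// that importing the driver package for its side effects is enough to make
// it available. An illustrative sketch (the driver name and import path are
// hypothetical):
//
//	// In the driver package:
//	func init() {
//		sql.Register("fakedb", &fakeDriver{})
//	}
//
//	// In the application:
//	import _ "example.com/fakedb"
//
//	db, err := sql.Open("fakedb", dataSourceName)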
func unregisterAllDrivers() {
driversMu.Lock()
defer driversMu.Unlock()
// For tests.
drivers = make(map[string]driver.Driver)
}
// Drivers returns a sorted list of the names of the registered drivers.
func Drivers() []string {
driversMu.RLock()
defer driversMu.RUnlock()
return slices.Sorted(maps.Keys(drivers))
}
// A NamedArg is a named argument. NamedArg values may be used as
// arguments to [DB.Query] or [DB.Exec] and bind to the corresponding named
// parameter in the SQL statement.
//
// For a more concise way to create NamedArg values, see
// the [Named] function.
type NamedArg struct {
_NamedFieldsRequired struct{}
// Name is the name of the parameter placeholder.
//
// If empty, the ordinal position in the argument list will be
// used.
//
// Name must omit any symbol prefix.
Name string
// Value is the value of the parameter.
// It may be assigned the same value types as the query
// arguments.
Value any
}
// Named provides a more concise way to create [NamedArg] values.
//
// Example usage:
//
// db.ExecContext(ctx, `
// delete from Invoice
// where
// TimeCreated < @end
// and TimeCreated >= @start;`,
// sql.Named("start", startTime),
// sql.Named("end", endTime),
// )
func Named(name string, value any) NamedArg {
// This method exists because the go1compat promise
// doesn't guarantee that structs don't grow more fields,
// so unkeyed struct literals are a vet error. Thus, we don't
// want to allow sql.NamedArg{name, value}.
return NamedArg{Name: name, Value: value}
}
// IsolationLevel is the transaction isolation level used in [TxOptions].
type IsolationLevel int
// Various isolation levels that drivers may support in [DB.BeginTx].
// If a driver does not support a given isolation level an error may be returned.
//
// See https://en.wikipedia.org/wiki/Isolation_(database_systems)#Isolation_levels.
const (
LevelDefault IsolationLevel = iota
LevelReadUncommitted
LevelReadCommitted
LevelWriteCommitted
LevelRepeatableRead
LevelSnapshot
LevelSerializable
LevelLinearizable
)
// String returns the name of the transaction isolation level.
func (i IsolationLevel) String() string {
switch i {
case LevelDefault:
return "Default"
case LevelReadUncommitted:
return "Read Uncommitted"
case LevelReadCommitted:
return "Read Committed"
case LevelWriteCommitted:
return "Write Committed"
case LevelRepeatableRead:
return "Repeatable Read"
case LevelSnapshot:
return "Snapshot"
case LevelSerializable:
return "Serializable"
case LevelLinearizable:
return "Linearizable"
default:
return "IsolationLevel(" + strconv.Itoa(int(i)) + ")"
}
}
var _ fmt.Stringer = LevelDefault
// TxOptions holds the transaction options to be used in [DB.BeginTx].
type TxOptions struct {
// Isolation is the transaction isolation level.
// If zero, the driver or database's default level is used.
Isolation IsolationLevel
ReadOnly bool
}
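// For example, a caller needing a serializable, read-only transaction can
// request it when beginning the transaction (a sketch; error handling elided):
//
//	tx, err := db.BeginTx(ctx, &sql.TxOptions{
//		Isolation: sql.LevelSerializable,
//		ReadOnly:  true,
//	})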
// RawBytes is a byte slice that holds a reference to memory owned by
// the database itself. After a [Rows.Scan] into a RawBytes, the slice is only
// valid until the next call to [Rows.Next], [Rows.Scan], or [Rows.Close].
type RawBytes []byte
// NullString represents a string that may be null.
// NullString implements the [Scanner] interface so
// it can be used as a scan destination:
//
// var s NullString
// err := db.QueryRow("SELECT name FROM foo WHERE id=?", id).Scan(&s)
// ...
// if s.Valid {
// // use s.String
// } else {
// // NULL value
// }
type NullString struct {
String string
Valid bool // Valid is true if String is not NULL
}
// Scan implements the [Scanner] interface.
func (ns *NullString) Scan(value any) error {
if value == nil {
ns.String, ns.Valid = "", false
return nil
}
ns.Valid = true
return convertAssign(&ns.String, value)
}
// Value implements the [driver.Valuer] interface.
func (ns NullString) Value() (driver.Value, error) {
if !ns.Valid {
return nil, nil
}
return ns.String, nil
}
// NullInt64 represents an int64 that may be null.
// NullInt64 implements the [Scanner] interface so
// it can be used as a scan destination, similar to [NullString].
type NullInt64 struct {
Int64 int64
Valid bool // Valid is true if Int64 is not NULL
}
// Scan implements the [Scanner] interface.
func (n *NullInt64) Scan(value any) error {
if value == nil {
n.Int64, n.Valid = 0, false
return nil
}
n.Valid = true
return convertAssign(&n.Int64, value)
}
// Value implements the [driver.Valuer] interface.
func (n NullInt64) Value() (driver.Value, error) {
if !n.Valid {
return nil, nil
}
return n.Int64, nil
}
// NullInt32 represents an int32 that may be null.
// NullInt32 implements the [Scanner] interface so
// it can be used as a scan destination, similar to [NullString].
type NullInt32 struct {
Int32 int32
Valid bool // Valid is true if Int32 is not NULL
}
// Scan implements the [Scanner] interface.
func (n *NullInt32) Scan(value any) error {
if value == nil {
n.Int32, n.Valid = 0, false
return nil
}
n.Valid = true
return convertAssign(&n.Int32, value)
}
// Value implements the [driver.Valuer] interface.
func (n NullInt32) Value() (driver.Value, error) {
if !n.Valid {
return nil, nil
}
return int64(n.Int32), nil
}
// NullInt16 represents an int16 that may be null.
// NullInt16 implements the [Scanner] interface so
// it can be used as a scan destination, similar to [NullString].
type NullInt16 struct {
Int16 int16
Valid bool // Valid is true if Int16 is not NULL
}
// Scan implements the [Scanner] interface.
func (n *NullInt16) Scan(value any) error {
if value == nil {
n.Int16, n.Valid = 0, false
return nil
}
err := convertAssign(&n.Int16, value)
n.Valid = err == nil
return err
}
// Value implements the [driver.Valuer] interface.
func (n NullInt16) Value() (driver.Value, error) {
if !n.Valid {
return nil, nil
}
return int64(n.Int16), nil
}
// NullByte represents a byte that may be null.
// NullByte implements the [Scanner] interface so
// it can be used as a scan destination, similar to [NullString].
type NullByte struct {
Byte byte
Valid bool // Valid is true if Byte is not NULL
}
// Scan implements the [Scanner] interface.
func (n *NullByte) Scan(value any) error {
if value == nil {
n.Byte, n.Valid = 0, false
return nil
}
err := convertAssign(&n.Byte, value)
n.Valid = err == nil
return err
}
// Value implements the [driver.Valuer] interface.
func (n NullByte) Value() (driver.Value, error) {
if !n.Valid {
return nil, nil
}
return int64(n.Byte), nil
}
// NullFloat64 represents a float64 that may be null.
// NullFloat64 implements the [Scanner] interface so
// it can be used as a scan destination, similar to [NullString].
type NullFloat64 struct {
Float64 float64
Valid bool // Valid is true if Float64 is not NULL
}
// Scan implements the [Scanner] interface.
func (n *NullFloat64) Scan(value any) error {
if value == nil {
n.Float64, n.Valid = 0, false
return nil
}
n.Valid = true
return convertAssign(&n.Float64, value)
}
// Value implements the [driver.Valuer] interface.
func (n NullFloat64) Value() (driver.Value, error) {
if !n.Valid {
return nil, nil
}
return n.Float64, nil
}
// NullBool represents a bool that may be null.
// NullBool implements the [Scanner] interface so
// it can be used as a scan destination, similar to [NullString].
type NullBool struct {
Bool bool
Valid bool // Valid is true if Bool is not NULL
}
// Scan implements the [Scanner] interface.
func (n *NullBool) Scan(value any) error {
if value == nil {
n.Bool, n.Valid = false, false
return nil
}
n.Valid = true
return convertAssign(&n.Bool, value)
}
// Value implements the [driver.Valuer] interface.
func (n NullBool) Value() (driver.Value, error) {
if !n.Valid {
return nil, nil
}
return n.Bool, nil
}
// NullTime represents a [time.Time] that may be null.
// NullTime implements the [Scanner] interface so
// it can be used as a scan destination, similar to [NullString].
type NullTime struct {
Time time.Time
Valid bool // Valid is true if Time is not NULL
}
// Scan implements the [Scanner] interface.
func (n *NullTime) Scan(value any) error {
if value == nil {
n.Time, n.Valid = time.Time{}, false
return nil
}
n.Valid = true
return convertAssign(&n.Time, value)
}
// Value implements the [driver.Valuer] interface.
func (n NullTime) Value() (driver.Value, error) {
if !n.Valid {
return nil, nil
}
return n.Time, nil
}
// Null represents a value that may be null.
// Null implements the [Scanner] interface so
// it can be used as a scan destination:
//
// var s Null[string]
// err := db.QueryRow("SELECT name FROM foo WHERE id=?", id).Scan(&s)
// ...
// if s.Valid {
// // use s.V
// } else {
// // NULL value
// }
//
// T should be one of the types accepted by [driver.Value].
type Null[T any] struct {
V T
Valid bool
}
// Scan implements the [Scanner] interface.
func (n *Null[T]) Scan(value any) error {
if value == nil {
n.V, n.Valid = *new(T), false
return nil
}
n.Valid = true
return convertAssign(&n.V, value)
}
// Value implements the [driver.Valuer] interface.
func (n Null[T]) Value() (driver.Value, error) {
if !n.Valid {
return nil, nil
}
v := any(n.V)
// See issue 69728.
if valuer, ok := v.(driver.Valuer); ok {
val, err := callValuerValue(valuer)
if err != nil {
return val, err
}
v = val
}
// See issue 69837.
return driver.DefaultParameterConverter.ConvertValue(v)
}
// Scanner is an interface used by [Rows.Scan].
type Scanner interface {
// Scan assigns a value from a database driver.
//
// The src value will be of one of the following types:
//
// int64
// float64
// bool
// []byte
// string
// time.Time
// nil - for NULL values
//
// An error should be returned if the value cannot be stored
// without loss of information.
//
// Reference types such as []byte are only valid until the next call to Scan
// and should not be retained. Their underlying memory is owned by the driver.
// If retention is necessary, copy their values before the next call to Scan.
Scan(src any) error
}
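// As an illustrative sketch only, a user-defined Scanner should copy a
// []byte source before retaining it, as described above; the Blob type is
// hypothetical:
//
//	type Blob []byte
//
//	func (b *Blob) Scan(src any) error {
//		switch v := src.(type) {
//		case nil:
//			*b = nil
//		case []byte:
//			*b = append((*b)[:0], v...) // copy: the driver owns v
//		case string:
//			*b = []byte(v)
//		default:
//			return fmt.Errorf("sql: cannot scan %T into Blob", src)
//		}
//		return nil
//	}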
// Out may be used to retrieve OUTPUT value parameters from stored procedures.
//
// Not all drivers and databases support OUTPUT value parameters.
//
// Example usage:
//
// var outArg string
// _, err := db.ExecContext(ctx, "ProcName", sql.Named("Arg1", sql.Out{Dest: &outArg}))
type Out struct {
_NamedFieldsRequired struct{}
// Dest is a pointer to the value that will be set to the result of the
// stored procedure's OUTPUT parameter.
Dest any
// In is whether the parameter is an INOUT parameter. If so, the input value to the stored
// procedure is the dereferenced value of Dest's pointer, which is then replaced with
// the output value.
In bool
}
// ErrNoRows is returned by [Row.Scan] when [DB.QueryRow] doesn't return a
// row. In such a case, QueryRow returns a placeholder [*Row] value that
// defers this error until a Scan.
var ErrNoRows = errors.New("sql: no rows in result set")
// DB is a database handle representing a pool of zero or more
// underlying connections. It's safe for concurrent use by multiple
// goroutines.
//
// The sql package creates and frees connections automatically; it
// also maintains a free pool of idle connections. If the database has
// a concept of per-connection state, such state can be reliably observed
// within a transaction ([Tx]) or connection ([Conn]). Once [DB.Begin] is called, the
// returned [Tx] is bound to a single connection. Once [Tx.Commit] or
// [Tx.Rollback] is called on the transaction, that transaction's
// connection is returned to [DB]'s idle connection pool. The pool size
// can be controlled with [DB.SetMaxIdleConns].
type DB struct {
// Total time waited for new connections.
waitDuration atomic.Int64
connector driver.Connector
// numClosed is an atomic counter which represents a total number of
// closed connections. Stmt.openStmt checks it before cleaning closed
// connections in Stmt.css.
numClosed atomic.Uint64
mu sync.Mutex // protects following fields
freeConn []*driverConn // free connections ordered by returnedAt oldest to newest
connRequests connRequestSet
numOpen int // number of opened and pending open connections
// Used to signal the need for new connections:
// a goroutine running connectionOpener() reads on this chan and
// maybeOpenNewConnections sends on the chan (one send per needed connection).
// It is closed during db.Close(). The close tells the connectionOpener
// goroutine to exit.
openerCh chan struct{}
closed bool
dep map[finalCloser]depSet
lastPut map[*driverConn]string // stacktrace of last conn's put; debug only
maxIdleCount int // zero means defaultMaxIdleConns; negative means 0
maxOpen int // <= 0 means unlimited
maxLifetime time.Duration // maximum amount of time a connection may be reused
maxIdleTime time.Duration // maximum amount of time a connection may be idle before being closed
cleanerCh chan struct{}
waitCount int64 // Total number of connections waited for.
maxIdleClosed int64 // Total number of connections closed due to idle count.
maxIdleTimeClosed int64 // Total number of connections closed due to idle time.
maxLifetimeClosed int64 // Total number of connections closed due to max connection lifetime limit.
stop func() // stop cancels the connection opener.
}
// connReuseStrategy determines how (*DB).conn returns database connections.
type connReuseStrategy uint8
const (
// alwaysNewConn forces a new connection to the database.
alwaysNewConn connReuseStrategy = iota
// cachedOrNewConn returns a cached connection, if available, else waits
// for one to become available (if MaxOpenConns has been reached) or
// creates a new database connection.
cachedOrNewConn
)
// driverConn wraps a driver.Conn with a mutex, to
// be held during all calls into the Conn (including any calls onto
// interfaces returned via that Conn, such as calls on Tx, Stmt,
// Result, and Rows).
type driverConn struct {
db *DB
createdAt time.Time
sync.Mutex // guards following
ci driver.Conn
needReset bool // The connection session should be reset before use if true.
closed bool
finalClosed bool // ci.Close has been called
openStmt map[*driverStmt]bool
// guarded by db.mu
inUse bool
dbmuClosed bool // same as closed, but guarded by db.mu, for removeClosedStmtLocked
returnedAt time.Time // Time the connection was created or returned.
onPut []func() // code (with db.mu held) run when conn is next returned
}
func (dc *driverConn) releaseConn(err error) {
dc.db.putConn(dc, err, true)
}
func (dc *driverConn) removeOpenStmt(ds *driverStmt) {
dc.Lock()
defer dc.Unlock()
delete(dc.openStmt, ds)
}
func (dc *driverConn) expired(timeout time.Duration) bool {
if timeout <= 0 {
return false
}
return dc.createdAt.Add(timeout).Before(nowFunc())
}
// resetSession checks if the driver connection needs the
// session to be reset and if required, resets it.
func (dc *driverConn) resetSession(ctx context.Context) error {
dc.Lock()
defer dc.Unlock()
if !dc.needReset {
return nil
}
if cr, ok := dc.ci.(driver.SessionResetter); ok {
return cr.ResetSession(ctx)
}
return nil
}
// validateConnection checks if the connection is valid and can
// still be used. It also marks the session for reset if required.
func (dc *driverConn) validateConnection(needsReset bool) bool {
dc.Lock()
defer dc.Unlock()
if needsReset {
dc.needReset = true
}
if cv, ok := dc.ci.(driver.Validator); ok {
return cv.IsValid()
}
return true
}
// prepareLocked prepares the query on dc. When cg == nil the dc must keep track of
// the prepared statements in a pool.
func (dc *driverConn) prepareLocked(ctx context.Context, cg stmtConnGrabber, query string) (*driverStmt, error) {
si, err := ctxDriverPrepare(ctx, dc.ci, query)
if err != nil {
return nil, err
}
ds := &driverStmt{Locker: dc, si: si}
// No need to manage open statements if there is a single connection grabber.
if cg != nil {
return ds, nil
}
// Track each driverConn's open statements, so we can close them
// before closing the conn.
//
// Wrap all driver.Stmt in *driverStmt to ensure they are only closed once.
if dc.openStmt == nil {
dc.openStmt = make(map[*driverStmt]bool)
}
dc.openStmt[ds] = true
return ds, nil
}
// the dc.db's Mutex is held.
func (dc *driverConn) closeDBLocked() func() error {
dc.Lock()
defer dc.Unlock()
if dc.closed {
return func() error { return errors.New("sql: duplicate driverConn close") }
}
dc.closed = true
return dc.db.removeDepLocked(dc, dc)
}
func (dc *driverConn) Close() error {
dc.Lock()
if dc.closed {
dc.Unlock()
return errors.New("sql: duplicate driverConn close")
}
dc.closed = true
dc.Unlock() // not defer; removeDep finalClose calls may need to lock
// And now updates that require holding dc.db.mu.
dc.db.mu.Lock()
dc.dbmuClosed = true
fn := dc.db.removeDepLocked(dc, dc)
dc.db.mu.Unlock()
return fn()
}
func (dc *driverConn) finalClose() error {
var err error
// Each *driverStmt has a lock to the dc. Copy the list out of the dc
// before calling close on each stmt.
var openStmt []*driverStmt
withLock(dc, func() {
openStmt = make([]*driverStmt, 0, len(dc.openStmt))
for ds := range dc.openStmt {
openStmt = append(openStmt, ds)
}
dc.openStmt = nil
})
for _, ds := range openStmt {
ds.Close()
}
withLock(dc, func() {
dc.finalClosed = true
err = dc.ci.Close()
dc.ci = nil
})
dc.db.mu.Lock()
dc.db.numOpen--
dc.db.maybeOpenNewConnections()
dc.db.mu.Unlock()
dc.db.numClosed.Add(1)
return err
}
// driverStmt associates a driver.Stmt with the
// *driverConn from which it came, so the driverConn's lock can be
// held during calls.
type driverStmt struct {
sync.Locker // the *driverConn
si driver.Stmt
closed bool
closeErr error // return value of previous Close call
}
// Close ensures driver.Stmt is only closed once and always returns the same
// result.
func (ds *driverStmt) Close() error {
ds.Lock()
defer ds.Unlock()
if ds.closed {
return ds.closeErr
}
ds.closed = true
ds.closeErr = ds.si.Close()
return ds.closeErr
}
// depSet is a finalCloser's outstanding dependencies
type depSet map[any]bool // set of true bools
// The finalCloser interface is used by (*DB).addDep and related
// dependency reference counting.
type finalCloser interface {
// finalClose is called when the reference count of an object
// goes to zero. (*DB).mu is not held while calling it.
finalClose() error
}
// addDep notes that x now depends on dep, and x's finalClose won't be
// called until all of x's dependencies are removed with removeDep.
func (db *DB) addDep(x finalCloser, dep any) {
db.mu.Lock()
defer db.mu.Unlock()
db.addDepLocked(x, dep)
}
func (db *DB) addDepLocked(x finalCloser, dep any) {
if db.dep == nil {
db.dep = make(map[finalCloser]depSet)
}
xdep := db.dep[x]
if xdep == nil {
xdep = make(depSet)
db.dep[x] = xdep
}
xdep[dep] = true
}
// removeDep notes that x no longer depends on dep.
// If x still has dependencies, nil is returned.
// If x no longer has any dependencies, its finalClose method will be
// called and its error value will be returned.
func (db *DB) removeDep(x finalCloser, dep any) error {
db.mu.Lock()
fn := db.removeDepLocked(x, dep)
db.mu.Unlock()
return fn()
}
func (db *DB) removeDepLocked(x finalCloser, dep any) func() error {
xdep, ok := db.dep[x]
if !ok {
panic(fmt.Sprintf("unpaired removeDep: no deps for %T", x))
}
l0 := len(xdep)
delete(xdep, dep)
switch len(xdep) {
case l0:
// Nothing removed. Shouldn't happen.
panic(fmt.Sprintf("unpaired removeDep: no %T dep on %T", dep, x))
case 0:
// No more dependencies.
delete(db.dep, x)
return x.finalClose
default:
// Dependencies remain.
return func() error { return nil }
}
}
// This is the size of the connectionOpener request chan (DB.openerCh).
// This value should be larger than the maximum typical value
// used for DB.maxOpen. If maxOpen is significantly larger than
// connectionRequestQueueSize then it is possible for ALL calls into the *DB
// to block until the connectionOpener can satisfy the backlog of requests.
var connectionRequestQueueSize = 1000000
type dsnConnector struct {
dsn string
driver driver.Driver
}
func (t dsnConnector) Connect(_ context.Context) (driver.Conn, error) {
return t.driver.Open(t.dsn)
}
func (t dsnConnector) Driver() driver.Driver {
return t.driver
}
// OpenDB opens a database using a [driver.Connector], allowing drivers to
// bypass a string based data source name.
//
// Most users will open a database via a driver-specific connection
// helper function that returns a [*DB]. No database drivers are included
// in the Go standard library. See https://golang.org/s/sqldrivers for
// a list of third-party drivers.
//
// OpenDB may just validate its arguments without creating a connection
// to the database. To verify that the data source name is valid, call
// [DB.Ping].
//
// The returned [DB] is safe for concurrent use by multiple goroutines
// and maintains its own pool of idle connections. Thus, the OpenDB
// function should be called just once. It is rarely necessary to
// close a [DB].
func OpenDB(c driver.Connector) *DB {
ctx, cancel := context.WithCancel(context.Background())
db := &DB{
connector: c,
openerCh: make(chan struct{}, connectionRequestQueueSize),
lastPut: make(map[*driverConn]string),
stop: cancel,
}
go db.connectionOpener(ctx)
return db
}
// Open opens a database specified by its database driver name and a
// driver-specific data source name, usually consisting of at least a
// database name and connection information.
//
// Most users will open a database via a driver-specific connection
// helper function that returns a [*DB]. No database drivers are included
// in the Go standard library. See https://golang.org/s/sqldrivers for
// a list of third-party drivers.
//
// Open may just validate its arguments without creating a connection
// to the database. To verify that the data source name is valid, call
// [DB.Ping].
//
// The returned [DB] is safe for concurrent use by multiple goroutines
// and maintains its own pool of idle connections. Thus, the Open
// function should be called just once. It is rarely necessary to
// close a [DB].
func Open(driverName, dataSourceName string) (*DB, error) {
driversMu.RLock()
driveri, ok := drivers[driverName]
driversMu.RUnlock()
if !ok {
return nil, fmt.Errorf("sql: unknown driver %q (forgotten import?)", driverName)
}
if driverCtx, ok := driveri.(driver.DriverContext); ok {
connector, err := driverCtx.OpenConnector(dataSourceName)
if err != nil {
return nil, err
}
return OpenDB(connector), nil
}
return OpenDB(dsnConnector{dsn: dataSourceName, driver: driveri}), nil
}
func (db *DB) pingDC(ctx context.Context, dc *driverConn, release func(error)) error {
var err error
if pinger, ok := dc.ci.(driver.Pinger); ok {
withLock(dc, func() {
err = pinger.Ping(ctx)
})
}
release(err)
return err
}
// PingContext verifies a connection to the database is still alive,
// establishing a connection if necessary.
func (db *DB) PingContext(ctx context.Context) error {
var dc *driverConn
var err error
err = db.retry(func(strategy connReuseStrategy) error {
dc, err = db.conn(ctx, strategy)
return err
})
if err != nil {
return err
}
return db.pingDC(ctx, dc, dc.releaseConn)
}
// Ping verifies a connection to the database is still alive,
// establishing a connection if necessary.
//
// Ping uses [context.Background] internally; to specify the context, use
// [DB.PingContext].
func (db *DB) Ping() error {
return db.PingContext(context.Background())
}
// Close closes the database and prevents new queries from starting.
// Close then waits for all queries that have started processing on the server
// to finish.
//
// It is rare to Close a [DB], as the [DB] handle is meant to be
// long-lived and shared between many goroutines.
func (db *DB) Close() error {
db.mu.Lock()
if db.closed { // Make DB.Close idempotent
db.mu.Unlock()
return nil
}
if db.cleanerCh != nil {
close(db.cleanerCh)
}
var err error
fns := make([]func() error, 0, len(db.freeConn))
for _, dc := range db.freeConn {
fns = append(fns, dc.closeDBLocked())
}
db.freeConn = nil
db.closed = true
db.connRequests.CloseAndRemoveAll()
db.mu.Unlock()
for _, fn := range fns {
err1 := fn()
if err1 != nil {
err = err1
}
}
db.stop()
if c, ok := db.connector.(io.Closer); ok {
err1 := c.Close()
if err1 != nil {
err = err1
}
}
return err
}
const defaultMaxIdleConns = 2
func (db *DB) maxIdleConnsLocked() int {
n := db.maxIdleCount
switch {
case n == 0:
// TODO(bradfitz): ask driver, if supported, for its default preference
return defaultMaxIdleConns
case n < 0:
return 0
default:
return n
}
}
func (db *DB) shortestIdleTimeLocked() time.Duration {
if db.maxIdleTime <= 0 {
return db.maxLifetime
}
if db.maxLifetime <= 0 {
return db.maxIdleTime
}
return min(db.maxIdleTime, db.maxLifetime)
}
// SetMaxIdleConns sets the maximum number of connections in the idle
// connection pool.
//
// If MaxOpenConns is greater than 0 but less than the new MaxIdleConns,
// then the new MaxIdleConns will be reduced to match the MaxOpenConns limit.
//
// If n <= 0, no idle connections are retained.
//
// The default max idle connections is currently 2. This may change in
// a future release.
func (db *DB) SetMaxIdleConns(n int) {
db.mu.Lock()
if n > 0 {
db.maxIdleCount = n
} else {
// No idle connections.
db.maxIdleCount = -1
}
// Make sure maxIdle doesn't exceed maxOpen
if db.maxOpen > 0 && db.maxIdleConnsLocked() > db.maxOpen {
db.maxIdleCount = db.maxOpen
}
var closing []*driverConn
idleCount := len(db.freeConn)
maxIdle := db.maxIdleConnsLocked()
if idleCount > maxIdle {
closing = db.freeConn[maxIdle:]
db.freeConn = db.freeConn[:maxIdle]
}
db.maxIdleClosed += int64(len(closing))
db.mu.Unlock()
for _, c := range closing {
c.Close()
}
}
// SetMaxOpenConns sets the maximum number of open connections to the database.
//
// If MaxIdleConns is greater than 0 and the new MaxOpenConns is less than
// MaxIdleConns, then MaxIdleConns will be reduced to match the new
// MaxOpenConns limit.
//
// If n <= 0, then there is no limit on the number of open connections.
// The default is 0 (unlimited).
func (db *DB) SetMaxOpenConns(n int) {
db.mu.Lock()
db.maxOpen = n
if n < 0 {
db.maxOpen = 0
}
syncMaxIdle := db.maxOpen > 0 && db.maxIdleConnsLocked() > db.maxOpen
db.mu.Unlock()
if syncMaxIdle {
db.SetMaxIdleConns(n)
}
}
// SetConnMaxLifetime sets the maximum amount of time a connection may be reused.
//
// Expired connections may be closed lazily before reuse.
//
// If d <= 0, connections are not closed due to a connection's age.
func (db *DB) SetConnMaxLifetime(d time.Duration) {
if d < 0 {
d = 0
}
db.mu.Lock()
// Wake cleaner up when lifetime is shortened.
if d > 0 && d < db.shortestIdleTimeLocked() && db.cleanerCh != nil {
select {
case db.cleanerCh <- struct{}{}:
default:
}
}
db.maxLifetime = d
db.startCleanerLocked()
db.mu.Unlock()
}
// SetConnMaxIdleTime sets the maximum amount of time a connection may be idle.
//
// Expired connections may be closed lazily before reuse.
//
// If d <= 0, connections are not closed due to a connection's idle time.
func (db *DB) SetConnMaxIdleTime(d time.Duration) {
if d < 0 {
d = 0
}
db.mu.Lock()
defer db.mu.Unlock()
// Wake cleaner up when idle time is shortened.
if d > 0 && d < db.shortestIdleTimeLocked() && db.cleanerCh != nil {
select {
case db.cleanerCh <- struct{}{}:
default:
}
}
db.maxIdleTime = d
db.startCleanerLocked()
}
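// A typical configuration applies these pool limits together; the values
// below are illustrative only:
//
//	db.SetMaxOpenConns(25)
//	db.SetMaxIdleConns(25)
//	db.SetConnMaxLifetime(5 * time.Minute)
//	db.SetConnMaxIdleTime(time.Minute)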
// startCleanerLocked starts connectionCleaner if needed.
func (db *DB) startCleanerLocked() {
if (db.maxLifetime > 0 || db.maxIdleTime > 0) && db.numOpen > 0 && db.cleanerCh == nil {
db.cleanerCh = make(chan struct{}, 1)
go db.connectionCleaner(db.shortestIdleTimeLocked())
}
}
func (db *DB) connectionCleaner(d time.Duration) {
const minInterval = time.Second
if d < minInterval {
d = minInterval
}
t := time.NewTimer(d)
for {
select {
case <-t.C:
case <-db.cleanerCh: // maxLifetime was changed or db was closed.
}
db.mu.Lock()
d = db.shortestIdleTimeLocked()
if db.closed || db.numOpen == 0 || d <= 0 {
db.cleanerCh = nil
db.mu.Unlock()
return
}
d, closing := db.connectionCleanerRunLocked(d)
db.mu.Unlock()
for _, c := range closing {
c.Close()
}
if d < minInterval {
d = minInterval
}
if !t.Stop() {
select {
case <-t.C:
default:
}
}
t.Reset(d)
}
}
// connectionCleanerRunLocked removes connections that should be closed from
// freeConn and returns them, alongside an updated duration until the next
// check if a quicker check is required to ensure connections are closed promptly.
func (db *DB) connectionCleanerRunLocked(d time.Duration) (time.Duration, []*driverConn) {
var idleClosing int64
var closing []*driverConn
if db.maxIdleTime > 0 {
// As freeConn is ordered by returnedAt, process it
// in reverse order to minimise the work needed.
idleSince := nowFunc().Add(-db.maxIdleTime)
last := len(db.freeConn) - 1
for i := last; i >= 0; i-- {
c := db.freeConn[i]
if c.returnedAt.Before(idleSince) {
i++
closing = db.freeConn[:i:i]
db.freeConn = db.freeConn[i:]
idleClosing = int64(len(closing))
db.maxIdleTimeClosed += idleClosing
break
}
}
if len(db.freeConn) > 0 {
c := db.freeConn[0]
if d2 := c.returnedAt.Sub(idleSince); d2 < d {
// Ensure idle connections are cleaned up as soon as
// possible.
d = d2
}
}
}
if db.maxLifetime > 0 {
expiredSince := nowFunc().Add(-db.maxLifetime)
for i := 0; i < len(db.freeConn); i++ {
c := db.freeConn[i]
if c.createdAt.Before(expiredSince) {
closing = append(closing, c)
last := len(db.freeConn) - 1
// Use a slow delete, as the order must be preserved to ensure
// connections are reused least idle time first.
copy(db.freeConn[i:], db.freeConn[i+1:])
db.freeConn[last] = nil
db.freeConn = db.freeConn[:last]
i--
} else if d2 := c.createdAt.Sub(expiredSince); d2 < d {
// Prevent connections from sitting in freeConn after they
// have expired by updating our next deadline d.
d = d2
}
}
db.maxLifetimeClosed += int64(len(closing)) - idleClosing
}
return d, closing
}
// DBStats contains database statistics.
type DBStats struct {
MaxOpenConnections int // Maximum number of open connections to the database.
// Pool Status
OpenConnections int // The number of established connections both in use and idle.
InUse int // The number of connections currently in use.
Idle int // The number of idle connections.
// Counters
WaitCount int64 // The total number of connections waited for.
WaitDuration time.Duration // The total time blocked waiting for a new connection.
MaxIdleClosed int64 // The total number of connections closed due to SetMaxIdleConns.
MaxIdleTimeClosed int64 // The total number of connections closed due to SetConnMaxIdleTime.
MaxLifetimeClosed int64 // The total number of connections closed due to SetConnMaxLifetime.
}
// Stats returns database statistics.
func (db *DB) Stats() DBStats {
wait := db.waitDuration.Load()
db.mu.Lock()
defer db.mu.Unlock()
stats := DBStats{
MaxOpenConnections: db.maxOpen,
Idle: len(db.freeConn),
OpenConnections: db.numOpen,
InUse: db.numOpen - len(db.freeConn),
WaitCount: db.waitCount,
WaitDuration: time.Duration(wait),
MaxIdleClosed: db.maxIdleClosed,
MaxIdleTimeClosed: db.maxIdleTimeClosed,
MaxLifetimeClosed: db.maxLifetimeClosed,
}
return stats
}
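// A sketch of sampling these statistics periodically, for example to export
// pool metrics; the logging shown is illustrative only:
//
//	s := db.Stats()
//	log.Printf("open=%d inuse=%d idle=%d waited=%d waitdur=%s",
//		s.OpenConnections, s.InUse, s.Idle, s.WaitCount, s.WaitDuration)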
// Assumes db.mu is locked.
// If there are connRequests and the connection limit hasn't been reached,
// then tell the connectionOpener to open new connections.
func (db *DB) maybeOpenNewConnections() {
numRequests := db.connRequests.Len()
if db.maxOpen > 0 {
numCanOpen := db.maxOpen - db.numOpen
if numRequests > numCanOpen {
numRequests = numCanOpen
}
}
for numRequests > 0 {
db.numOpen++ // optimistically
numRequests--
if db.closed {
return
}
db.openerCh <- struct{}{}
}
}
// Runs in a separate goroutine, opens new connections when requested.
func (db *DB) connectionOpener(ctx context.Context) {
for {
select {
case <-ctx.Done():
return
case <-db.openerCh:
db.openNewConnection(ctx)
}
}
}
// Open one new connection
func (db *DB) openNewConnection(ctx context.Context) {
// maybeOpenNewConnections has already executed db.numOpen++ before it sent
// on db.openerCh. This function must execute db.numOpen-- if the
// connection fails or is closed before returning.
ci, err := db.connector.Connect(ctx)
db.mu.Lock()
defer db.mu.Unlock()
if db.closed {
if err == nil {
ci.Close()
}
db.numOpen--
return
}
if err != nil {
db.numOpen--
db.putConnDBLocked(nil, err)
db.maybeOpenNewConnections()
return
}
dc := &driverConn{
db: db,
createdAt: nowFunc(),
returnedAt: nowFunc(),
ci: ci,
}
if db.putConnDBLocked(dc, err) {
db.addDepLocked(dc, dc)
} else {
db.numOpen--
ci.Close()
}
}
// connRequest represents one request for a new connection.
// When there are no idle connections available, DB.conn will create
// a new connRequest and put it on the db.connRequests list.
type connRequest struct {
conn *driverConn
err error
}
var errDBClosed = errors.New("sql: database is closed")
// conn returns a newly-opened or cached *driverConn.
func (db *DB) conn(ctx context.Context, strategy connReuseStrategy) (*driverConn, error) {
db.mu.Lock()
if db.closed {
db.mu.Unlock()
return nil, errDBClosed
}
// Check if the context is expired.
select {
default:
case <-ctx.Done():
db.mu.Unlock()
return nil, ctx.Err()
}
lifetime := db.maxLifetime
// Prefer a free connection, if possible.
last := len(db.freeConn) - 1
if strategy == cachedOrNewConn && last >= 0 {
// Reuse the lowest idle time connection so we can close
// connections which remain idle as soon as possible.
conn := db.freeConn[last]
db.freeConn = db.freeConn[:last]
conn.inUse = true
if conn.expired(lifetime) {
db.maxLifetimeClosed++
db.mu.Unlock()
conn.Close()
return nil, driver.ErrBadConn
}
db.mu.Unlock()
// Reset the session if required.
if err := conn.resetSession(ctx); errors.Is(err, driver.ErrBadConn) {
conn.Close()
return nil, err
}
return conn, nil
}
// Out of free connections or we were asked not to use one. If we're not
// allowed to open any more connections, make a request and wait.
if db.maxOpen > 0 && db.numOpen >= db.maxOpen {
// Make the connRequest channel. It's buffered so that the
// connectionOpener doesn't block while waiting for the req to be read.
req := make(chan connRequest, 1)
delHandle := db.connRequests.Add(req)
db.waitCount++
db.mu.Unlock()
waitStart := nowFunc()
// Timeout the connection request with the context.
select {
case <-ctx.Done():
// Remove the connection request and ensure no value has been sent
// on it after removing.
db.mu.Lock()
deleted := db.connRequests.Delete(delHandle)
db.mu.Unlock()
db.waitDuration.Add(int64(time.Since(waitStart)))
// If we failed to delete it, that means either the DB was closed or
// something else grabbed it and is about to send on it.
if !deleted {
// TODO(bradfitz): rather than this best effort select, we
// should probably start a goroutine to read from req. This best
// effort select existed before the change to check 'deleted'.
// But if we know for sure it wasn't deleted and a sender is
// outstanding, we should probably block on req (in a new
// goroutine) to get the connection back.
select {
default:
case ret, ok := <-req:
if ok && ret.conn != nil {
db.putConn(ret.conn, ret.err, false)
}
}
}
return nil, ctx.Err()
case ret, ok := <-req:
db.waitDuration.Add(int64(time.Since(waitStart)))
if !ok {
return nil, errDBClosed
}
// Only check if the connection is expired if the strategy is cachedOrNewConn.
// If we require a new connection, just re-use the connection without looking
// at the expiry time. If it is expired, it will be checked when it is placed
// back into the connection pool.
// This prioritizes giving a valid connection to a client over the exact connection
// lifetime, which could expire exactly after this point anyway.
if strategy == cachedOrNewConn && ret.err == nil && ret.conn.expired(lifetime) {
db.mu.Lock()
db.maxLifetimeClosed++
db.mu.Unlock()
ret.conn.Close()
return nil, driver.ErrBadConn
}
if ret.conn == nil {
return nil, ret.err
}
// Reset the session if required.
if err := ret.conn.resetSession(ctx); errors.Is(err, driver.ErrBadConn) {
ret.conn.Close()
return nil, err
}
return ret.conn, ret.err
}
}
db.numOpen++ // optimistically
db.mu.Unlock()
ci, err := db.connector.Connect(ctx)
if err != nil {
db.mu.Lock()
db.numOpen-- // correct for earlier optimism
db.maybeOpenNewConnections()
db.mu.Unlock()
return nil, err
}
db.mu.Lock()
dc := &driverConn{
db: db,
createdAt: nowFunc(),
returnedAt: nowFunc(),
ci: ci,
inUse: true,
}
db.addDepLocked(dc, dc)
db.mu.Unlock()
return dc, nil
}
// putConnHook is a hook for testing.
var putConnHook func(*DB, *driverConn)
// noteUnusedDriverStatement notes that ds is no longer used and should
// be closed whenever possible (when c is next not in use), unless c is
// already closed.
func (db *DB) noteUnusedDriverStatement(c *driverConn, ds *driverStmt) {
db.mu.Lock()
defer db.mu.Unlock()
if c.inUse {
c.onPut = append(c.onPut, func() {
ds.Close()
})
} else {
c.Lock()
fc := c.finalClosed
c.Unlock()
if !fc {
ds.Close()
}
}
}
// debugGetPut determines whether getConn & putConn calls' stack traces
// are returned for more verbose crashes.
const debugGetPut = false
// putConn adds a connection to the db's free pool.
// err is optionally the last error that occurred on this connection.
func (db *DB) putConn(dc *driverConn, err error, resetSession bool) {
if !errors.Is(err, driver.ErrBadConn) {
if !dc.validateConnection(resetSession) {
err = driver.ErrBadConn
}
}
db.mu.Lock()
if !dc.inUse {
db.mu.Unlock()
if debugGetPut {
fmt.Printf("putConn(%v) DUPLICATE was: %s\n\nPREVIOUS was: %s", dc, stack(), db.lastPut[dc])
}
panic("sql: connection returned that was never out")
}
if !errors.Is(err, driver.ErrBadConn) && dc.expired(db.maxLifetime) {
db.maxLifetimeClosed++
err = driver.ErrBadConn
}
if debugGetPut {
db.lastPut[dc] = stack()
}
dc.inUse = false
dc.returnedAt = nowFunc()
for _, fn := range dc.onPut {
fn()
}
dc.onPut = nil
if errors.Is(err, driver.ErrBadConn) {
// Don't reuse bad connections.
// Since the conn is considered bad and is being discarded, treat it
// as closed. Don't decrement the open count here, finalClose will
// take care of that.
db.maybeOpenNewConnections()
db.mu.Unlock()
dc.Close()
return
}
if putConnHook != nil {
putConnHook(db, dc)
}
added := db.putConnDBLocked(dc, nil)
db.mu.Unlock()
if !added {
dc.Close()
return
}
}
// Satisfy a connRequest or put the driverConn in the idle pool, returning true
// on success or false otherwise.
// putConnDBLocked will satisfy a connRequest if there is one, or it will
// return the *driverConn to the freeConn list if err == nil and the idle
// connection limit will not be exceeded.
// If err != nil, the value of dc is ignored.
// If err == nil, then dc must not equal nil.
// If a connRequest was fulfilled or the *driverConn was placed in the
// freeConn list, then true is returned, otherwise false is returned.
func (db *DB) putConnDBLocked(dc *driverConn, err error) bool {
if db.closed {
return false
}
if db.maxOpen > 0 && db.numOpen > db.maxOpen {
return false
}
if req, ok := db.connRequests.TakeRandom(); ok {
if err == nil {
dc.inUse = true
}
req <- connRequest{
conn: dc,
err: err,
}
return true
} else if err == nil && !db.closed {
if db.maxIdleConnsLocked() > len(db.freeConn) {
db.freeConn = append(db.freeConn, dc)
db.startCleanerLocked()
return true
}
db.maxIdleClosed++
}
return false
}
// maxBadConnRetries is the maximum number of retries if the driver returns
// driver.ErrBadConn to signal a broken connection before forcing a new
// connection to be opened.
const maxBadConnRetries = 2
func (db *DB) retry(fn func(strategy connReuseStrategy) error) error {
for i := int64(0); i < maxBadConnRetries; i++ {
err := fn(cachedOrNewConn)
// retry if err is driver.ErrBadConn
if err == nil || !errors.Is(err, driver.ErrBadConn) {
return err
}
}
return fn(alwaysNewConn)
}
// PrepareContext creates a prepared statement for later queries or executions.
// Multiple queries or executions may be run concurrently from the
// returned statement.
// The caller must call the statement's [*Stmt.Close] method
// when the statement is no longer needed.
//
// The provided context is used for the preparation of the statement, not for the
// execution of the statement.
func (db *DB) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
var stmt *Stmt
var err error
err = db.retry(func(strategy connReuseStrategy) error {
stmt, err = db.prepare(ctx, query, strategy)
return err
})
return stmt, err
}
// Prepare creates a prepared statement for later queries or executions.
// Multiple queries or executions may be run concurrently from the
// returned statement.
// The caller must call the statement's [*Stmt.Close] method
// when the statement is no longer needed.
//
// Prepare uses [context.Background] internally; to specify the context, use
// [DB.PrepareContext].
func (db *DB) Prepare(query string) (*Stmt, error) {
return db.PrepareContext(context.Background(), query)
}
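// A caller-side sketch of preparing a statement once and reusing it; the
// table, columns, and '?' placeholder syntax are illustrative and driver-specific:
//
//	stmt, err := db.PrepareContext(ctx, "SELECT name FROM users WHERE id = ?")
//	if err != nil {
//		return err
//	}
//	defer stmt.Close()
//	var name string
//	if err := stmt.QueryRowContext(ctx, 42).Scan(&name); err != nil {
//		return err
//	}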
func (db *DB) prepare(ctx context.Context, query string, strategy connReuseStrategy) (*Stmt, error) {
// TODO: check if db.driver supports an optional
// driver.Preparer interface and call that instead, if so,
// otherwise we make a prepared statement that's bound
// to a connection, and to execute this prepared statement
// we either need to use this connection (if it's free), else
// get a new connection + re-prepare + execute on that one.
dc, err := db.conn(ctx, strategy)
if err != nil {
return nil, err
}
return db.prepareDC(ctx, dc, dc.releaseConn, nil, query)
}
// prepareDC prepares a query on the driverConn and calls release before
// returning. When cg == nil it implies that a connection pool is used, and
// when cg != nil only a single driver connection is used.
func (db *DB) prepareDC(ctx context.Context, dc *driverConn, release func(error), cg stmtConnGrabber, query string) (*Stmt, error) {
var ds *driverStmt
var err error
defer func() {
release(err)
}()
withLock(dc, func() {
ds, err = dc.prepareLocked(ctx, cg, query)
})
if err != nil {
return nil, err
}
stmt := &Stmt{
db: db,
query: query,
cg: cg,
cgds: ds,
}
// When cg == nil this statement will need to keep track of the various
// connections it is prepared on and record the stmt dependency on
// the DB.
if cg == nil {
stmt.css = []connStmt{{dc, ds}}
stmt.lastNumClosed = db.numClosed.Load()
db.addDep(stmt, stmt)
}
return stmt, nil
}
// ExecContext executes a query without returning any rows.
// The args are for any placeholder parameters in the query.
func (db *DB) ExecContext(ctx context.Context, query string, args ...any) (Result, error) {
var res Result
var err error
err = db.retry(func(strategy connReuseStrategy) error {
res, err = db.exec(ctx, query, args, strategy)
return err
})
return res, err
}
// Exec executes a query without returning any rows.
// The args are for any placeholder parameters in the query.
//
// Exec uses [context.Background] internally; to specify the context, use
// [DB.ExecContext].
func (db *DB) Exec(query string, args ...any) (Result, error) {
return db.ExecContext(context.Background(), query, args...)
}
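// A caller-side sketch of a write that returns no rows; the query and the
// '?' placeholder syntax are illustrative and driver-specific:
//
//	res, err := db.ExecContext(ctx, "UPDATE users SET active = ? WHERE id = ?", true, 42)
//	if err != nil {
//		return err
//	}
//	n, err := res.RowsAffected() // optional; not all drivers support it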
func (db *DB) exec(ctx context.Context, query string, args []any, strategy connReuseStrategy) (Result, error) {
dc, err := db.conn(ctx, strategy)
if err != nil {
return nil, err
}
return db.execDC(ctx, dc, dc.releaseConn, query, args)
}
func (db *DB) execDC(ctx context.Context, dc *driverConn, release func(error), query string, args []any) (res Result, err error) {
defer func() {
release(err)
}()
execerCtx, ok := dc.ci.(driver.ExecerContext)
var execer driver.Execer
if !ok {
execer, ok = dc.ci.(driver.Execer)
}
if ok {
var nvdargs []driver.NamedValue
var resi driver.Result
withLock(dc, func() {
nvdargs, err = driverArgsConnLocked(dc.ci, nil, args)
if err != nil {
return
}
resi, err = ctxDriverExec(ctx, execerCtx, execer, query, nvdargs)
})
if err != driver.ErrSkip {
if err != nil {
return nil, err
}
return driverResult{dc, resi}, nil
}
}
var si driver.Stmt
withLock(dc, func() {
si, err = ctxDriverPrepare(ctx, dc.ci, query)
})
if err != nil {
return nil, err
}
ds := &driverStmt{Locker: dc, si: si}
defer ds.Close()
return resultFromStatement(ctx, dc.ci, ds, args...)
}
// QueryContext executes a query that returns rows, typically a SELECT.
// The args are for any placeholder parameters in the query.
func (db *DB) QueryContext(ctx context.Context, query string, args ...any) (*Rows, error) {
var rows *Rows
var err error
err = db.retry(func(strategy connReuseStrategy) error {
rows, err = db.query(ctx, query, args, strategy)
return err
})
return rows, err
}
// Query executes a query that returns rows, typically a SELECT.
// The args are for any placeholder parameters in the query.
//
// Query uses [context.Background] internally; to specify the context, use
// [DB.QueryContext].
func (db *DB) Query(query string, args ...any) (*Rows, error) {
return db.QueryContext(context.Background(), query, args...)
}
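// A caller-side sketch of iterating a result set; the query is a placeholder.
// Always close the rows and check rows.Err after the loop:
//
//	rows, err := db.QueryContext(ctx, "SELECT id, name FROM users")
//	if err != nil {
//		return err
//	}
//	defer rows.Close()
//	for rows.Next() {
//		var id int64
//		var name string
//		if err := rows.Scan(&id, &name); err != nil {
//			return err
//		}
//	}
//	if err := rows.Err(); err != nil {
//		return err
//	}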
func (db *DB) query(ctx context.Context, query string, args []any, strategy connReuseStrategy) (*Rows, error) {
dc, err := db.conn(ctx, strategy)
if err != nil {
return nil, err
}
return db.queryDC(ctx, nil, dc, dc.releaseConn, query, args)
}
// queryDC executes a query on the given connection.
// The connection gets released by the releaseConn function.
// The ctx context is from a query method and the txctx context is from an
// optional transaction context.
func (db *DB) queryDC(ctx, txctx context.Context, dc *driverConn, releaseConn func(error), query string, args []any) (*Rows, error) {
queryerCtx, ok := dc.ci.(driver.QueryerContext)
var queryer driver.Queryer
if !ok {
queryer, ok = dc.ci.(driver.Queryer)
}
if ok {
var nvdargs []driver.NamedValue
var rowsi driver.Rows
var err error
withLock(dc, func() {
nvdargs, err = driverArgsConnLocked(dc.ci, nil, args)
if err != nil {
return
}
rowsi, err = ctxDriverQuery(ctx, queryerCtx, queryer, query, nvdargs)
})
if err != driver.ErrSkip {
if err != nil {
releaseConn(err)
return nil, err
}
// Note: ownership of dc passes to the *Rows, to be freed
// with releaseConn.
rows := &Rows{
dc: dc,
releaseConn: releaseConn,
rowsi: rowsi,
}
rows.initContextClose(ctx, txctx)
return rows, nil
}
}
var si driver.Stmt
var err error
withLock(dc, func() {
si, err = ctxDriverPrepare(ctx, dc.ci, query)
})
if err != nil {
releaseConn(err)
return nil, err
}
ds := &driverStmt{Locker: dc, si: si}
rowsi, err := rowsiFromStatement(ctx, dc.ci, ds, args...)
if err != nil {
ds.Close()
releaseConn(err)
return nil, err
}
// Note: ownership of ci passes to the *Rows, to be freed
// with releaseConn.
rows := &Rows{
dc: dc,
releaseConn: releaseConn,
rowsi: rowsi,
closeStmt: ds,
}
rows.initContextClose(ctx, txctx)
return rows, nil
}
// QueryRowContext executes a query that is expected to return at most one row.
// QueryRowContext always returns a non-nil value. Errors are deferred until
// [Row]'s Scan method is called.
// If the query selects no rows, the [*Row.Scan] will return [ErrNoRows].
// Otherwise, [*Row.Scan] scans the first selected row and discards
// the rest.
func (db *DB) QueryRowContext(ctx context.Context, query string, args ...any) *Row {
rows, err := db.QueryContext(ctx, query, args...)
return &Row{rows: rows, err: err}
}
// QueryRow executes a query that is expected to return at most one row.
// QueryRow always returns a non-nil value. Errors are deferred until
// [Row]'s Scan method is called.
// If the query selects no rows, the [*Row.Scan] will return [ErrNoRows].
// Otherwise, [*Row.Scan] scans the first selected row and discards
// the rest.
//
// QueryRow uses [context.Background] internally; to specify the context, use
// [DB.QueryRowContext].
func (db *DB) QueryRow(query string, args ...any) *Row {
return db.QueryRowContext(context.Background(), query, args...)
}
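// A caller-side sketch of a single-row lookup that distinguishes "no rows"
// from other errors; the query is a placeholder:
//
//	var name string
//	err := db.QueryRowContext(ctx, "SELECT name FROM users WHERE id = ?", 42).Scan(&name)
//	switch {
//	case errors.Is(err, sql.ErrNoRows):
//		// no matching row
//	case err != nil:
//		return err
//	}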
// BeginTx starts a transaction.
//
// The provided context is used until the transaction is committed or rolled back.
// If the context is canceled, the sql package will roll back
// the transaction. [Tx.Commit] will return an error if the context provided to
// BeginTx is canceled.
//
// The provided [TxOptions] is optional and may be nil if defaults should be used.
// If a non-default isolation level is used that the driver doesn't support,
// an error will be returned.
func (db *DB) BeginTx(ctx context.Context, opts *TxOptions) (*Tx, error) {
var tx *Tx
var err error
err = db.retry(func(strategy connReuseStrategy) error {
tx, err = db.begin(ctx, opts, strategy)
return err
})
return tx, err
}
// Begin starts a transaction. The default isolation level is dependent on
// the driver.
//
// Begin uses [context.Background] internally; to specify the context, use
// [DB.BeginTx].
func (db *DB) Begin() (*Tx, error) {
return db.BeginTx(context.Background(), nil)
}
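// A caller-side transaction sketch; the queries are placeholders. Deferring
// Rollback cleans up on any early return; after a successful Commit it simply
// returns ErrTxDone, which the deferred call ignores:
//
//	tx, err := db.BeginTx(ctx, nil)
//	if err != nil {
//		return err
//	}
//	defer tx.Rollback()
//	if _, err := tx.ExecContext(ctx, "UPDATE accounts SET balance = balance - ? WHERE id = ?", amt, from); err != nil {
//		return err
//	}
//	if _, err := tx.ExecContext(ctx, "UPDATE accounts SET balance = balance + ? WHERE id = ?", amt, to); err != nil {
//		return err
//	}
//	return tx.Commit()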
func (db *DB) begin(ctx context.Context, opts *TxOptions, strategy connReuseStrategy) (tx *Tx, err error) {
dc, err := db.conn(ctx, strategy)
if err != nil {
return nil, err
}
return db.beginDC(ctx, dc, dc.releaseConn, opts)
}
// beginDC starts a transaction. The provided dc must be valid and ready to use.
func (db *DB) beginDC(ctx context.Context, dc *driverConn, release func(error), opts *TxOptions) (tx *Tx, err error) {
var txi driver.Tx
keepConnOnRollback := false
withLock(dc, func() {
_, hasSessionResetter := dc.ci.(driver.SessionResetter)
_, hasConnectionValidator := dc.ci.(driver.Validator)
keepConnOnRollback = hasSessionResetter && hasConnectionValidator
txi, err = ctxDriverBegin(ctx, opts, dc.ci)
})
if err != nil {
release(err)
return nil, err
}
// Schedule the transaction to rollback when the context is canceled.
// The cancel function in Tx will be called after done is set to true.
ctx, cancel := context.WithCancel(ctx)
tx = &Tx{
db: db,
dc: dc,
releaseConn: release,
txi: txi,
cancel: cancel,
keepConnOnRollback: keepConnOnRollback,
ctx: ctx,
}
go tx.awaitDone()
return tx, nil
}
// Driver returns the database's underlying driver.
func (db *DB) Driver() driver.Driver {
return db.connector.Driver()
}
// ErrConnDone is returned by any operation that is performed on a connection
// that has already been returned to the connection pool.
var ErrConnDone = errors.New("sql: connection is already closed")
// Conn returns a single connection by either opening a new connection
// or returning an existing connection from the connection pool. Conn will
// block until either a connection is returned or ctx is canceled.
// Queries run on the same Conn will be run in the same database session.
//
// Every Conn must be returned to the database pool after use by
// calling [Conn.Close].
func (db *DB) Conn(ctx context.Context) (*Conn, error) {
var dc *driverConn
var err error
err = db.retry(func(strategy connReuseStrategy) error {
dc, err = db.conn(ctx, strategy)
return err
})
if err != nil {
return nil, err
}
conn := &Conn{
db: db,
dc: dc,
}
return conn, nil
}
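// A caller-side sketch of pinning work to a single session; the session-level
// statement shown is purely illustrative and database-specific:
//
//	conn, err := db.Conn(ctx)
//	if err != nil {
//		return err
//	}
//	defer conn.Close()
//	if _, err := conn.ExecContext(ctx, "SET search_path TO app"); err != nil {
//		return err
//	}
//	// Subsequent queries on conn run in the same database session.
//	rows, err := conn.QueryContext(ctx, "SELECT id FROM users")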
type releaseConn func(error)
// Conn represents a single database connection rather than a pool of database
// connections. Prefer running queries from [DB] unless there is a specific
// need for a continuous single database connection.
//
// A Conn must call [Conn.Close] to return the connection to the database pool
// and may do so concurrently with a running query.
//
// After a call to [Conn.Close], all operations on the
// connection fail with [ErrConnDone].
type Conn struct {
db *DB
// closemu prevents the connection from closing while there
// is an active query. It is held for read during queries
// and exclusively during close.
closemu sync.RWMutex
// dc is owned until close, at which point
// it's returned to the connection pool.
dc *driverConn
// done transitions from false to true exactly once, on close.
// Once done, all operations fail with ErrConnDone.
done atomic.Bool
releaseConnOnce sync.Once
// releaseConnCache is a cache of c.closemuRUnlockCondReleaseConn
// to save allocations in a call to grabConn.
releaseConnCache releaseConn
}
// grabConn takes a context to implement stmtConnGrabber,
// but the context is not used.
func (c *Conn) grabConn(context.Context) (*driverConn, releaseConn, error) {
if c.done.Load() {
return nil, nil, ErrConnDone
}
c.releaseConnOnce.Do(func() {
c.releaseConnCache = c.closemuRUnlockCondReleaseConn
})
c.closemu.RLock()
return c.dc, c.releaseConnCache, nil
}
// PingContext verifies the connection to the database is still alive.
func (c *Conn) PingContext(ctx context.Context) error {
dc, release, err := c.grabConn(ctx)
if err != nil {
return err
}
return c.db.pingDC(ctx, dc, release)
}
// ExecContext executes a query without returning any rows.
// The args are for any placeholder parameters in the query.
func (c *Conn) ExecContext(ctx context.Context, query string, args ...any) (Result, error) {
dc, release, err := c.grabConn(ctx)
if err != nil {
return nil, err
}
return c.db.execDC(ctx, dc, release, query, args)
}
// QueryContext executes a query that returns rows, typically a SELECT.
// The args are for any placeholder parameters in the query.
func (c *Conn) QueryContext(ctx context.Context, query string, args ...any) (*Rows, error) {
dc, release, err := c.grabConn(ctx)
if err != nil {
return nil, err
}
return c.db.queryDC(ctx, nil, dc, release, query, args)
}
// QueryRowContext executes a query that is expected to return at most one row.
// QueryRowContext always returns a non-nil value. Errors are deferred until
// the [*Row.Scan] method is called.
// If the query selects no rows, the [*Row.Scan] will return [ErrNoRows].
// Otherwise, the [*Row.Scan] scans the first selected row and discards
// the rest.
func (c *Conn) QueryRowContext(ctx context.Context, query string, args ...any) *Row {
rows, err := c.QueryContext(ctx, query, args...)
return &Row{rows: rows, err: err}
}
// PrepareContext creates a prepared statement for later queries or executions.
// Multiple queries or executions may be run concurrently from the
// returned statement.
// The caller must call the statement's [*Stmt.Close] method
// when the statement is no longer needed.
//
// The provided context is used for the preparation of the statement, not for the
// execution of the statement.
func (c *Conn) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
dc, release, err := c.grabConn(ctx)
if err != nil {
return nil, err
}
return c.db.prepareDC(ctx, dc, release, c, query)
}
// Raw executes f exposing the underlying driver connection for the
// duration of f. The driverConn must not be used outside of f.
//
// Once f returns and err is not [driver.ErrBadConn], the [Conn] will continue to be usable
// until [Conn.Close] is called.
func (c *Conn) Raw(f func(driverConn any) error) (err error) {
var dc *driverConn
var release releaseConn
// grabConn takes a context to implement stmtConnGrabber, but the context is not used.
dc, release, err = c.grabConn(nil)
if err != nil {
return
}
fPanic := true
dc.Mutex.Lock()
defer func() {
dc.Mutex.Unlock()
// If f panics fPanic will remain true.
// Ensure an error is passed to release so the connection
// may be discarded.
if fPanic {
err = driver.ErrBadConn
}
release(err)
}()
err = f(dc.ci)
fPanic = false
return
}
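// A sketch of using Raw from the caller's side; the concrete type asserted
// for driverConn is entirely driver-dependent and the method shown is hypothetical:
//
//	err := conn.Raw(func(driverConn any) error {
//		dc, ok := driverConn.(interface{ DriverSpecific() error }) // hypothetical driver interface
//		if !ok {
//			return errors.New("unexpected driver connection type")
//		}
//		return dc.DriverSpecific()
//	})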
// BeginTx starts a transaction.
//
// The provided context is used until the transaction is committed or rolled back.
// If the context is canceled, the sql package will roll back
// the transaction. [Tx.Commit] will return an error if the context provided to
// BeginTx is canceled.
//
// The provided [TxOptions] is optional and may be nil if defaults should be used.
// If a non-default isolation level is used that the driver doesn't support,
// an error will be returned.
func (c *Conn) BeginTx(ctx context.Context, opts *TxOptions) (*Tx, error) {
dc, release, err := c.grabConn(ctx)
if err != nil {
return nil, err
}
return c.db.beginDC(ctx, dc, release, opts)
}
// closemuRUnlockCondReleaseConn read-unlocks closemu
// once the sql operation is done with the dc.
func (c *Conn) closemuRUnlockCondReleaseConn(err error) {
c.closemu.RUnlock()
if errors.Is(err, driver.ErrBadConn) {
c.close(err)
}
}
func (c *Conn) txCtx() context.Context {
return nil
}
func (c *Conn) close(err error) error {
if !c.done.CompareAndSwap(false, true) {
return ErrConnDone
}
// Lock around releasing the driver connection
// to ensure all queries have been stopped before doing so.
c.closemu.Lock()
defer c.closemu.Unlock()
c.dc.releaseConn(err)
c.dc = nil
c.db = nil
return err
}
// Close returns the connection to the connection pool.
// All operations after a Close will return with [ErrConnDone].
// Close is safe to call concurrently with other operations and will
// block until all other operations finish. It may be useful to first
// cancel any used context and then call Close directly after.
func (c *Conn) Close() error {
return c.close(nil)
}
// Tx is an in-progress database transaction.
//
// A transaction must end with a call to [Tx.Commit] or [Tx.Rollback].
//
// After a call to [Tx.Commit] or [Tx.Rollback], all operations on the
// transaction fail with [ErrTxDone].
//
// The statements prepared for a transaction by calling
// the transaction's [Tx.Prepare] or [Tx.Stmt] methods are closed
// by the call to [Tx.Commit] or [Tx.Rollback].
type Tx struct {
db *DB
// closemu prevents the transaction from closing while there
// is an active query. It is held for read during queries
// and exclusively during close.
closemu sync.RWMutex
// dc is owned exclusively until Commit or Rollback, at which point
// it's returned with putConn.
dc *driverConn
txi driver.Tx
// releaseConn is called once the Tx is closed to release
// any held driverConn back to the pool.
releaseConn func(error)
// done transitions from false to true exactly once, on Commit
// or Rollback. Once done, all operations fail with
// ErrTxDone.
done atomic.Bool
// keepConnOnRollback is true if the driver knows
// how to reset the connection's session and if need be discard
// the connection.
keepConnOnRollback bool
// All Stmts prepared for this transaction. These will be closed after the
// transaction has been committed or rolled back.
stmts struct {
sync.Mutex
v []*Stmt
}
// cancel is called after done transitions from 0 to 1.
cancel func()
// ctx lives for the life of the transaction.
ctx context.Context
}
// awaitDone blocks until the context in Tx is canceled and rolls back
// the transaction if it's not already done.
func (tx *Tx) awaitDone() {
// Wait for either the transaction to be committed or rolled
// back, or for the associated context to be closed.
<-tx.ctx.Done()
// Discard and close the connection used to ensure the
// transaction is closed and the resources are released. This
// rollback does nothing if the transaction has already been
// committed or rolled back.
// Do not discard the connection if the connection knows
// how to reset the session.
discardConnection := !tx.keepConnOnRollback
tx.rollback(discardConnection)
}
func (tx *Tx) isDone() bool {
return tx.done.Load()
}
// ErrTxDone is returned by any operation that is performed on a transaction
// that has already been committed or rolled back.
var ErrTxDone = errors.New("sql: transaction has already been committed or rolled back")
// close returns the connection to the pool and
// must only be called by Tx.rollback or Tx.Commit while
// tx is already canceled and won't be executed concurrently.
func (tx *Tx) close(err error) {
tx.releaseConn(err)
tx.dc = nil
tx.txi = nil
}
// hookTxGrabConn specifies an optional hook to be called on
// a successful call to (*Tx).grabConn. For tests.
var hookTxGrabConn func()
func (tx *Tx) grabConn(ctx context.Context) (*driverConn, releaseConn, error) {
select {
default:
case <-ctx.Done():
return nil, nil, ctx.Err()
}
// closemu.RLock must come before the check for isDone to prevent the Tx from
// closing while a query is executing.
tx.closemu.RLock()
if tx.isDone() {
tx.closemu.RUnlock()
return nil, nil, ErrTxDone
}
if hookTxGrabConn != nil { // test hook
hookTxGrabConn()
}
return tx.dc, tx.closemuRUnlockRelease, nil
}
func (tx *Tx) txCtx() context.Context {
return tx.ctx
}
// closemuRUnlockRelease is used as a func(error) method value in
// [Tx.ExecContext] and [Tx.QueryContext]. Unlocking in the releaseConn keeps
// the driver conn from being returned to the connection pool until
// the Rows has been closed.
func (tx *Tx) closemuRUnlockRelease(error) {
tx.closemu.RUnlock()
}
// Closes all Stmts prepared for this transaction.
func (tx *Tx) closePrepared() {
tx.stmts.Lock()
defer tx.stmts.Unlock()
for _, stmt := range tx.stmts.v {
stmt.Close()
}
}
// Commit commits the transaction.
func (tx *Tx) Commit() error {
// Check context first to avoid transaction leak.
// If the check were placed after the tx.done CompareAndSwap, we could not
// ensure consistency between tx.done and the real COMMIT operation.
select {
default:
case <-tx.ctx.Done():
if tx.done.Load() {
return ErrTxDone
}
return tx.ctx.Err()
}
if !tx.done.CompareAndSwap(false, true) {
return ErrTxDone
}
// Cancel the Tx to release any active R-closemu locks.
// This is safe to do because tx.done has already transitioned
// from 0 to 1. Hold the W-closemu lock prior to rollback
// to ensure no other connection has an active query.
tx.cancel()
tx.closemu.Lock()
tx.closemu.Unlock()
var err error
withLock(tx.dc, func() {
err = tx.txi.Commit()
})
if !errors.Is(err, driver.ErrBadConn) {
tx.closePrepared()
}
tx.close(err)
return err
}
var rollbackHook func()
// rollback aborts the transaction and optionally forces the pool to discard
// the connection.
func (tx *Tx) rollback(discardConn bool) error {
if !tx.done.CompareAndSwap(false, true) {
return ErrTxDone
}
if rollbackHook != nil {
rollbackHook()
}
// Cancel the Tx to release any active R-closemu locks.
// This is safe to do because tx.done has already transitioned
// from 0 to 1. Hold the W-closemu lock prior to rollback
// to ensure no other connection has an active query.
tx.cancel()
tx.closemu.Lock()
tx.closemu.Unlock()
var err error
withLock(tx.dc, func() {
err = tx.txi.Rollback()
})
if !errors.Is(err, driver.ErrBadConn) {
tx.closePrepared()
}
if discardConn {
err = driver.ErrBadConn
}
tx.close(err)
return err
}
// Rollback aborts the transaction.
func (tx *Tx) Rollback() error {
return tx.rollback(false)
}
// PrepareContext creates a prepared statement for use within a transaction.
//
// The returned statement operates within the transaction and will be closed
// when the transaction has been committed or rolled back.
//
// To use an existing prepared statement on this transaction, see [Tx.Stmt].
//
// The provided context will be used for the preparation of the statement, not
// for the execution of the returned statement. The returned statement
// will run in the transaction context.
func (tx *Tx) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
dc, release, err := tx.grabConn(ctx)
if err != nil {
return nil, err
}
stmt, err := tx.db.prepareDC(ctx, dc, release, tx, query)
if err != nil {
return nil, err
}
tx.stmts.Lock()
tx.stmts.v = append(tx.stmts.v, stmt)
tx.stmts.Unlock()
return stmt, nil
}
// Prepare creates a prepared statement for use within a transaction.
//
// The returned statement operates within the transaction and will be closed
// when the transaction has been committed or rolled back.
//
// To use an existing prepared statement on this transaction, see [Tx.Stmt].
//
// Prepare uses [context.Background] internally; to specify the context, use
// [Tx.PrepareContext].
func (tx *Tx) Prepare(query string) (*Stmt, error) {
return tx.PrepareContext(context.Background(), query)
}
// StmtContext returns a transaction-specific prepared statement from
// an existing statement.
//
// Example:
//
// updateMoney, err := db.Prepare("UPDATE balance SET money=money+? WHERE id=?")
// ...
// tx, err := db.Begin()
// ...
// res, err := tx.StmtContext(ctx, updateMoney).Exec(123.45, 98293203)
//
// The provided context is used for the preparation of the statement, not for the
// execution of the statement.
//
// The returned statement operates within the transaction and will be closed
// when the transaction has been committed or rolled back.
func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt {
dc, release, err := tx.grabConn(ctx)
if err != nil {
return &Stmt{stickyErr: err}
}
defer release(nil)
if tx.db != stmt.db {
return &Stmt{stickyErr: errors.New("sql: Tx.Stmt: statement from different database used")}
}
var si driver.Stmt
var parentStmt *Stmt
stmt.mu.Lock()
if stmt.closed || stmt.cg != nil {
// If the statement has been closed or already belongs to a
// transaction, we can't reuse it in this connection.
// Since tx.StmtContext should never need to be called with a
// Stmt already belonging to tx, we ignore this edge case and
// re-prepare the statement in this case. No need to add
// code-complexity for this.
stmt.mu.Unlock()
withLock(dc, func() {
si, err = ctxDriverPrepare(ctx, dc.ci, stmt.query)
})
if err != nil {
return &Stmt{stickyErr: err}
}
} else {
stmt.removeClosedStmtLocked()
// See if the statement has already been prepared on this connection,
// and reuse it if possible.
for _, v := range stmt.css {
if v.dc == dc {
si = v.ds.si
break
}
}
stmt.mu.Unlock()
if si == nil {
var ds *driverStmt
withLock(dc, func() {
ds, err = stmt.prepareOnConnLocked(ctx, dc)
})
if err != nil {
return &Stmt{stickyErr: err}
}
si = ds.si
}
parentStmt = stmt
}
txs := &Stmt{
db: tx.db,
cg: tx,
cgds: &driverStmt{
Locker: dc,
si: si,
},
parentStmt: parentStmt,
query: stmt.query,
}
if parentStmt != nil {
tx.db.addDep(parentStmt, txs)
}
tx.stmts.Lock()
tx.stmts.v = append(tx.stmts.v, txs)
tx.stmts.Unlock()
return txs
}
// Stmt returns a transaction-specific prepared statement from
// an existing statement.
//
// Example:
//
// updateMoney, err := db.Prepare("UPDATE balance SET money=money+? WHERE id=?")
// ...
// tx, err := db.Begin()
// ...
// res, err := tx.Stmt(updateMoney).Exec(123.45, 98293203)
//
// The returned statement operates within the transaction and will be closed
// when the transaction has been committed or rolled back.
//
// Stmt uses [context.Background] internally; to specify the context, use
// [Tx.StmtContext].
func (tx *Tx) Stmt(stmt *Stmt) *Stmt {
return tx.StmtContext(context.Background(), stmt)
}
// ExecContext executes a query that doesn't return rows.
// For example: an INSERT or UPDATE.
func (tx *Tx) ExecContext(ctx context.Context, query string, args ...any) (Result, error) {
dc, release, err := tx.grabConn(ctx)
if err != nil {
return nil, err
}
return tx.db.execDC(ctx, dc, release, query, args)
}
// Exec executes a query that doesn't return rows.
// For example: an INSERT or UPDATE.
//
// Exec uses [context.Background] internally; to specify the context, use
// [Tx.ExecContext].
func (tx *Tx) Exec(query string, args ...any) (Result, error) {
return tx.ExecContext(context.Background(), query, args...)
}
// QueryContext executes a query that returns rows, typically a SELECT.
func (tx *Tx) QueryContext(ctx context.Context, query string, args ...any) (*Rows, error) {
dc, release, err := tx.grabConn(ctx)
if err != nil {
return nil, err
}
return tx.db.queryDC(ctx, tx.ctx, dc, release, query, args)
}
// Query executes a query that returns rows, typically a SELECT.
//
// Query uses [context.Background] internally; to specify the context, use
// [Tx.QueryContext].
func (tx *Tx) Query(query string, args ...any) (*Rows, error) {
return tx.QueryContext(context.Background(), query, args...)
}
// QueryRowContext executes a query that is expected to return at most one row.
// QueryRowContext always returns a non-nil value. Errors are deferred until
// [Row]'s Scan method is called.
// If the query selects no rows, the [*Row.Scan] will return [ErrNoRows].
// Otherwise, the [*Row.Scan] scans the first selected row and discards
// the rest.
func (tx *Tx) QueryRowContext(ctx context.Context, query string, args ...any) *Row {
rows, err := tx.QueryContext(ctx, query, args...)
return &Row{rows: rows, err: err}
}
// QueryRow executes a query that is expected to return at most one row.
// QueryRow always returns a non-nil value. Errors are deferred until
// [Row]'s Scan method is called.
// If the query selects no rows, the [*Row.Scan] will return [ErrNoRows].
// Otherwise, the [*Row.Scan] scans the first selected row and discards
// the rest.
//
// QueryRow uses [context.Background] internally; to specify the context, use
// [Tx.QueryRowContext].
func (tx *Tx) QueryRow(query string, args ...any) *Row {
return tx.QueryRowContext(context.Background(), query, args...)
}
// connStmt is a prepared statement on a particular connection.
type connStmt struct {
dc *driverConn
ds *driverStmt
}
// stmtConnGrabber represents a Tx or Conn that will return the underlying
// driverConn and release function.
type stmtConnGrabber interface {
// grabConn returns the driverConn and the associated release function
// that must be called when the operation completes.
grabConn(context.Context) (*driverConn, releaseConn, error)
// txCtx returns the transaction context if available.
// The returned context should be selected on along with
// any query context when awaiting a cancel.
txCtx() context.Context
}
var (
_ stmtConnGrabber = &Tx{}
_ stmtConnGrabber = &Conn{}
)
// Stmt is a prepared statement.
// A Stmt is safe for concurrent use by multiple goroutines.
//
// If a Stmt is prepared on a [Tx] or [Conn], it will be bound to a single
// underlying connection forever. If the [Tx] or [Conn] closes, the Stmt will
// become unusable and all operations will return an error.
// If a Stmt is prepared on a [DB], it will remain usable for the lifetime of the
// [DB]. When the Stmt needs to execute on a new underlying connection, it will
// prepare itself on the new connection automatically.
type Stmt struct {
// Immutable:
db *DB // where we came from
query string // that created the Stmt
stickyErr error // if non-nil, this error is returned for all operations
closemu sync.RWMutex // held exclusively during close, for read otherwise.
// If Stmt is prepared on a Tx or Conn then cg is present and will
// only ever grab a connection from cg.
// If cg is nil then the Stmt must grab an arbitrary connection
// from db and determine if it must prepare the stmt again by
// inspecting css.
cg stmtConnGrabber
cgds *driverStmt
// parentStmt is set when a transaction-specific statement
// is requested from an identical statement prepared on the same
// conn. parentStmt is used to track the dependency of this statement
// on its originating ("parent") statement so that parentStmt may
// be closed by the user without them having to know whether or not
// any transactions are still using it.
parentStmt *Stmt
mu sync.Mutex // protects the rest of the fields
closed bool
// css is a list of underlying driver statement interfaces
// that are valid on particular connections. This is only
// used if cg == nil and one is found that has idle
// connections. If cg != nil, cgds is always used.
css []connStmt
// lastNumClosed is copied from db.numClosed when Stmt is created
// without tx and closed connections in css are removed.
lastNumClosed uint64
}
// ExecContext executes a prepared statement with the given arguments and
// returns a [Result] summarizing the effect of the statement.
func (s *Stmt) ExecContext(ctx context.Context, args ...any) (Result, error) {
s.closemu.RLock()
defer s.closemu.RUnlock()
var res Result
err := s.db.retry(func(strategy connReuseStrategy) error {
dc, releaseConn, ds, err := s.connStmt(ctx, strategy)
if err != nil {
return err
}
res, err = resultFromStatement(ctx, dc.ci, ds, args...)
releaseConn(err)
return err
})
return res, err
}
// Exec executes a prepared statement with the given arguments and
// returns a [Result] summarizing the effect of the statement.
//
// Exec uses [context.Background] internally; to specify the context, use
// [Stmt.ExecContext].
func (s *Stmt) Exec(args ...any) (Result, error) {
return s.ExecContext(context.Background(), args...)
}
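// A caller-side sketch of reusing one prepared statement for many executions,
// here inside a transaction; the query and data are placeholders:
//
//	stmt, err := tx.PrepareContext(ctx, "INSERT INTO events(kind) VALUES (?)")
//	if err != nil {
//		return err
//	}
//	defer stmt.Close()
//	for _, k := range kinds {
//		if _, err := stmt.ExecContext(ctx, k); err != nil {
//			return err
//		}
//	}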
func resultFromStatement(ctx context.Context, ci driver.Conn, ds *driverStmt, args ...any) (Result, error) {
ds.Lock()
defer ds.Unlock()
dargs, err := driverArgsConnLocked(ci, ds, args)
if err != nil {
return nil, err
}
resi, err := ctxDriverStmtExec(ctx, ds.si, dargs)
if err != nil {
return nil, err
}
return driverResult{ds.Locker, resi}, nil
}
// removeClosedStmtLocked removes closed conns in s.css.
//
// To avoid lock contention on DB.mu, we do it only when
// s.db.numClosed - s.lastNumClosed is large enough.
func (s *Stmt) removeClosedStmtLocked() {
t := len(s.css)/2 + 1
if t > 10 {
t = 10
}
dbClosed := s.db.numClosed.Load()
if dbClosed-s.lastNumClosed < uint64(t) {
return
}
s.db.mu.Lock()
for i := 0; i < len(s.css); i++ {
if s.css[i].dc.dbmuClosed {
s.css[i] = s.css[len(s.css)-1]
// Zero out the last element (for GC) before shrinking the slice.
s.css[len(s.css)-1] = connStmt{}
s.css = s.css[:len(s.css)-1]
i--
}
}
s.db.mu.Unlock()
s.lastNumClosed = dbClosed
}
// connStmt returns a free driver connection on which to execute the
// statement, a function to call to release the connection, and a
// statement bound to that connection.
func (s *Stmt) connStmt(ctx context.Context, strategy connReuseStrategy) (dc *driverConn, releaseConn func(error), ds *driverStmt, err error) {
if err = s.stickyErr; err != nil {
return
}
s.mu.Lock()
if s.closed {
s.mu.Unlock()
err = errors.New("sql: statement is closed")
return
}
// In a transaction or connection, we always use the connection that the
// stmt was created on.
if s.cg != nil {
s.mu.Unlock()
dc, releaseConn, err = s.cg.grabConn(ctx) // blocks, waiting for the connection.
if err != nil {
return
}
return dc, releaseConn, s.cgds, nil
}
s.removeClosedStmtLocked()
s.mu.Unlock()
dc, err = s.db.conn(ctx, strategy)
if err != nil {
return nil, nil, nil, err
}
s.mu.Lock()
for _, v := range s.css {
if v.dc == dc {
s.mu.Unlock()
return dc, dc.releaseConn, v.ds, nil
}
}
s.mu.Unlock()
// No luck; we need to prepare the statement on this connection
withLock(dc, func() {
ds, err = s.prepareOnConnLocked(ctx, dc)
})
if err != nil {
dc.releaseConn(err)
return nil, nil, nil, err
}
return dc, dc.releaseConn, ds, nil
}
// prepareOnConnLocked prepares the query in Stmt s on dc and adds it to the list of
// open connStmt on the statement. It assumes the caller is holding the lock on dc.
func (s *Stmt) prepareOnConnLocked(ctx context.Context, dc *driverConn) (*driverStmt, error) {
si, err := dc.prepareLocked(ctx, s.cg, s.query)
if err != nil {
return nil, err
}
cs := connStmt{dc, si}
s.mu.Lock()
s.css = append(s.css, cs)
s.mu.Unlock()
return cs.ds, nil
}
// QueryContext executes a prepared query statement with the given arguments
// and returns the query results as a [*Rows].
func (s *Stmt) QueryContext(ctx context.Context, args ...any) (*Rows, error) {
s.closemu.RLock()
defer s.closemu.RUnlock()
var rowsi driver.Rows
var rows *Rows
err := s.db.retry(func(strategy connReuseStrategy) error {
dc, releaseConn, ds, err := s.connStmt(ctx, strategy)
if err != nil {
return err
}
rowsi, err = rowsiFromStatement(ctx, dc.ci, ds, args...)
if err == nil {
// Note: ownership of ci passes to the *Rows, to be freed
// with releaseConn.
rows = &Rows{
dc: dc,
rowsi: rowsi,
// releaseConn set below
}
// addDep must be added before initContextClose or it could attempt
// to removeDep before it has been added.
s.db.addDep(s, rows)
// releaseConn must be set before initContextClose or it could
// release the connection before it is set.
rows.releaseConn = func(err error) {
releaseConn(err)
s.db.removeDep(s, rows)
}
var txctx context.Context
if s.cg != nil {
txctx = s.cg.txCtx()
}
rows.initContextClose(ctx, txctx)
return nil
}
releaseConn(err)
return err
})
return rows, err
}
// Query executes a prepared query statement with the given arguments
// and returns the query results as a *Rows.
//
// Query uses [context.Background] internally; to specify the context, use
// [Stmt.QueryContext].
func (s *Stmt) Query(args ...any) (*Rows, error) {
return s.QueryContext(context.Background(), args...)
}
func rowsiFromStatement(ctx context.Context, ci driver.Conn, ds *driverStmt, args ...any) (driver.Rows, error) {
ds.Lock()
defer ds.Unlock()
dargs, err := driverArgsConnLocked(ci, ds, args)
if err != nil {
return nil, err
}
return ctxDriverStmtQuery(ctx, ds.si, dargs)
}
// QueryRowContext executes a prepared query statement with the given arguments.
// If an error occurs during the execution of the statement, that error will
// be returned by a call to Scan on the returned [*Row], which is always non-nil.
// If the query selects no rows, the [*Row.Scan] will return [ErrNoRows].
// Otherwise, the [*Row.Scan] scans the first selected row and discards
// the rest.
func (s *Stmt) QueryRowContext(ctx context.Context, args ...any) *Row {
rows, err := s.QueryContext(ctx, args...)
if err != nil {
return &Row{err: err}
}
return &Row{rows: rows}
}
// QueryRow executes a prepared query statement with the given arguments.
// If an error occurs during the execution of the statement, that error will
// be returned by a call to Scan on the returned [*Row], which is always non-nil.
// If the query selects no rows, the [*Row.Scan] will return [ErrNoRows].
// Otherwise, the [*Row.Scan] scans the first selected row and discards
// the rest.
//
// Example usage:
//
// var name string
// err := nameByUseridStmt.QueryRow(id).Scan(&name)
//
// QueryRow uses [context.Background] internally; to specify the context, use
// [Stmt.QueryRowContext].
func (s *Stmt) QueryRow(args ...any) *Row {
return s.QueryRowContext(context.Background(), args...)
}
// Close closes the statement.
func (s *Stmt) Close() error {
s.closemu.Lock()
defer s.closemu.Unlock()
if s.stickyErr != nil {
return s.stickyErr
}
s.mu.Lock()
if s.closed {
s.mu.Unlock()
return nil
}
s.closed = true
txds := s.cgds
s.cgds = nil
s.mu.Unlock()
if s.cg == nil {
return s.db.removeDep(s, s)
}
if s.parentStmt != nil {
// If parentStmt is set, we must not close s.txds since it's stored
// in the css array of the parentStmt.
return s.db.removeDep(s.parentStmt, s)
}
return txds.Close()
}
func (s *Stmt) finalClose() error {
s.mu.Lock()
defer s.mu.Unlock()
if s.css != nil {
for _, v := range s.css {
s.db.noteUnusedDriverStatement(v.dc, v.ds)
v.dc.removeOpenStmt(v.ds)
}
s.css = nil
}
return nil
}
// Rows is the result of a query. Its cursor starts before the first row
// of the result set. Use [Rows.Next] to advance from row to row.
type Rows struct {
dc *driverConn // owned; must call releaseConn when closed to release
releaseConn func(error)
rowsi driver.Rows
cancel func() // called when Rows is closed, may be nil.
closeStmt *driverStmt // if non-nil, statement to Close on close
contextDone atomic.Pointer[error] // error that awaitDone saw; set before close attempt
// closemu prevents Rows from closing while there
// is an active streaming result. It is held for read during non-close operations
// and exclusively during close.
//
// closemu guards lasterr and closed.
closemu sync.RWMutex
lasterr error // non-nil only if closed is true
closed bool
// closemuScanHold is whether the previous call to Scan kept closemu RLock'ed
// without unlocking it. It does that when the user passes a *RawBytes scan
// target. In that case, we need to prevent awaitDone from closing the Rows
// while the user's still using the memory. See go.dev/issue/60304.
//
// It is only used by Scan, Next, and NextResultSet which are expected
// not to be called concurrently.
closemuScanHold bool
// hitEOF is whether Next hit the end of the rows without
// encountering an error. It's set in Next before
// returning. It's only used by Next and Err which are
// expected not to be called concurrently.
hitEOF bool
// lastcols is only used in Scan, Next, and NextResultSet which are expected
// not to be called concurrently.
lastcols []driver.Value
// raw is a buffer for RawBytes that persists between Scan calls.
// This is used when the driver returns a mismatched type that requires
// a cloning allocation. For example, if the driver returns a *string and
// the user is scanning into a *RawBytes, we need to copy the string.
// The raw buffer here lets us reuse the memory for that copy across Scan calls.
raw []byte
}
// lasterrOrErrLocked returns either lasterr or the provided err.
// rs.closemu must be read-locked.
func (rs *Rows) lasterrOrErrLocked(err error) error {
if rs.lasterr != nil && rs.lasterr != io.EOF {
return rs.lasterr
}
return err
}
// bypassRowsAwaitDone is only used for testing.
// If true, it will not close the Rows automatically from the context.
var bypassRowsAwaitDone = false
func (rs *Rows) initContextClose(ctx, txctx context.Context) {
if ctx.Done() == nil && (txctx == nil || txctx.Done() == nil) {
return
}
if bypassRowsAwaitDone {
return
}
closectx, cancel := context.WithCancel(ctx)
rs.cancel = cancel
go rs.awaitDone(ctx, txctx, closectx)
}
// awaitDone blocks until ctx, txctx, or closectx is canceled.
// The ctx is provided from the query context.
// If the query was issued in a transaction, the transaction's context
// is also provided in txctx, to ensure Rows is closed if the Tx is closed.
// The closectx is closed by an explicit call to rs.Close.
func (rs *Rows) awaitDone(ctx, txctx, closectx context.Context) {
var txctxDone <-chan struct{}
if txctx != nil {
txctxDone = txctx.Done()
}
select {
case <-ctx.Done():
err := ctx.Err()
rs.contextDone.Store(&err)
case <-txctxDone:
err := txctx.Err()
rs.contextDone.Store(&err)
case <-closectx.Done():
// rs.cancel was called via Close(); don't store this into contextDone
// to ensure Err() is unaffected.
}
rs.close(ctx.Err())
}
// Next prepares the next result row for reading with the [Rows.Scan] method. It
// returns true on success, or false if there is no next result row or an error
// happened while preparing it. [Rows.Err] should be consulted to distinguish between
// the two cases.
//
// Every call to [Rows.Scan], even the first one, must be preceded by a call to [Rows.Next].
func (rs *Rows) Next() bool {
// If the user's calling Next, they're done with their previous row's Scan
// results (any RawBytes memory), so we can release the read lock that would
// be preventing awaitDone from calling close.
rs.closemuRUnlockIfHeldByScan()
if rs.contextDone.Load() != nil {
return false
}
var doClose, ok bool
withLock(rs.closemu.RLocker(), func() {
doClose, ok = rs.nextLocked()
})
if doClose {
rs.Close()
}
if doClose && !ok {
rs.hitEOF = true
}
return ok
}
func (rs *Rows) nextLocked() (doClose, ok bool) {
if rs.closed {
return false, false
}
// Lock the driver connection before calling the driver interface
// rowsi to prevent a Tx from rolling back the connection at the same time.
rs.dc.Lock()
defer rs.dc.Unlock()
if rs.lastcols == nil {
rs.lastcols = make([]driver.Value, len(rs.rowsi.Columns()))
}
rs.lasterr = rs.rowsi.Next(rs.lastcols)
if rs.lasterr != nil {
// Close the connection if there is a driver error.
if rs.lasterr != io.EOF {
return true, false
}
nextResultSet, ok := rs.rowsi.(driver.RowsNextResultSet)
if !ok {
return true, false
}
// The driver is at the end of the current result set.
// Test to see if there is another result set after the current one.
// Only close Rows if there are no further result sets to read.
if !nextResultSet.HasNextResultSet() {
doClose = true
}
return doClose, false
}
return false, true
}
// NextResultSet prepares the next result set for reading. It reports whether
// there is a further result set, or false if there is no further result set
// or if there is an error advancing to it. The [Rows.Err] method should be consulted
// to distinguish between the two cases.
//
// After calling NextResultSet, the [Rows.Next] method should always be called before
// scanning. If there are further result sets they may not have rows in the result
// set.
func (rs *Rows) NextResultSet() bool {
// If the user's calling NextResultSet, they're done with their previous
// row's Scan results (any RawBytes memory), so we can release the read lock
// that would be preventing awaitDone from calling close.
rs.closemuRUnlockIfHeldByScan()
var doClose bool
defer func() {
if doClose {
rs.Close()
}
}()
rs.closemu.RLock()
defer rs.closemu.RUnlock()
if rs.closed {
return false
}
rs.lastcols = nil
nextResultSet, ok := rs.rowsi.(driver.RowsNextResultSet)
if !ok {
doClose = true
return false
}
// Lock the driver connection before calling the driver interface
// rowsi to prevent a Tx from rolling back the connection at the same time.
rs.dc.Lock()
defer rs.dc.Unlock()
rs.lasterr = nextResultSet.NextResultSet()
if rs.lasterr != nil {
doClose = true
return false
}
return true
}
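// exampleMultiResultSet is an editor's sketch, not part of the original
// source: it shows the NextResultSet pattern described above, assuming a
// driver and a (hypothetical) multi-statement query that return more than one
// result set. Next must still be called within each result set before Scan.
func exampleMultiResultSet(ctx context.Context, db *DB) error {
	rows, err := db.QueryContext(ctx, "SELECT 1; SELECT 2") // hypothetical query
	if err != nil {
		return err
	}
	defer rows.Close()
	for {
		for rows.Next() {
			var v int64
			if err := rows.Scan(&v); err != nil {
				return err
			}
		}
		if !rows.NextResultSet() {
			break
		}
	}
	return rows.Err()
}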
// Err returns the error, if any, that was encountered during iteration.
// Err may be called after an explicit or implicit [Rows.Close].
func (rs *Rows) Err() error {
// Return any context error that might've happened during row iteration,
// but only if we haven't reported the final Next() = false after rows
// are done, in which case the user might've canceled their own context
// before calling Rows.Err.
if !rs.hitEOF {
if errp := rs.contextDone.Load(); errp != nil {
return *errp
}
}
rs.closemu.RLock()
defer rs.closemu.RUnlock()
return rs.lasterrOrErrLocked(nil)
}
// rawbuf returns the buffer to append RawBytes values to.
// This buffer is reused across calls to Rows.Scan.
//
// Usage:
//
// rawBytes = rows.setrawbuf(append(rows.rawbuf(), value...))
func (rs *Rows) rawbuf() []byte {
if rs == nil {
// convertAssignRows can take a nil *Rows; for simplicity handle it here
return nil
}
return rs.raw
}
// setrawbuf updates the RawBytes buffer with the result of appending a new value to it.
// It returns the new value.
func (rs *Rows) setrawbuf(b []byte) RawBytes {
if rs == nil {
// convertAssignRows can take a nil *Rows; for simplicity handle it here
return RawBytes(b)
}
off := len(rs.raw)
rs.raw = b
return RawBytes(rs.raw[off:])
}
var errRowsClosed = errors.New("sql: Rows are closed")
var errNoRows = errors.New("sql: no Rows available")
// Columns returns the column names.
// Columns returns an error if the rows are closed.
func (rs *Rows) Columns() ([]string, error) {
rs.closemu.RLock()
defer rs.closemu.RUnlock()
if rs.closed {
return nil, rs.lasterrOrErrLocked(errRowsClosed)
}
if rs.rowsi == nil {
return nil, rs.lasterrOrErrLocked(errNoRows)
}
rs.dc.Lock()
defer rs.dc.Unlock()
return rs.rowsi.Columns(), nil
}
// ColumnTypes returns column information such as column type, length,
// and nullable. Some information may not be available from some drivers.
func (rs *Rows) ColumnTypes() ([]*ColumnType, error) {
rs.closemu.RLock()
defer rs.closemu.RUnlock()
if rs.closed {
return nil, rs.lasterrOrErrLocked(errRowsClosed)
}
if rs.rowsi == nil {
return nil, rs.lasterrOrErrLocked(errNoRows)
}
rs.dc.Lock()
defer rs.dc.Unlock()
return rowsColumnInfoSetupConnLocked(rs.rowsi), nil
}
// ColumnType contains the name and type of a column.
type ColumnType struct {
name string
hasNullable bool
hasLength bool
hasPrecisionScale bool
nullable bool
length int64
databaseType string
precision int64
scale int64
scanType reflect.Type
}
// Name returns the name or alias of the column.
func (ci *ColumnType) Name() string {
return ci.name
}
// Length returns the column type length for variable length column types such
// as text and binary field types. If the type length is unbounded the value will
// be [math.MaxInt64] (any database limits will still apply).
// If the column type is not variable length, such as an int, or if not supported
// by the driver, ok is false.
func (ci *ColumnType) Length() (length int64, ok bool) {
return ci.length, ci.hasLength
}
// DecimalSize returns the scale and precision of a decimal type.
// If not applicable or if not supported by the driver, ok is false.
func (ci *ColumnType) DecimalSize() (precision, scale int64, ok bool) {
return ci.precision, ci.scale, ci.hasPrecisionScale
}
// ScanType returns a Go type suitable for scanning into using [Rows.Scan].
// If a driver does not support this property, ScanType will return
// the type of an empty interface.
func (ci *ColumnType) ScanType() reflect.Type {
return ci.scanType
}
// Nullable reports whether the column may be null.
// If a driver does not support this property, ok will be false.
func (ci *ColumnType) Nullable() (nullable, ok bool) {
return ci.nullable, ci.hasNullable
}
// DatabaseTypeName returns the database system name of the column type. If an empty
// string is returned, then the driver type name is not supported.
// Consult your driver documentation for a list of driver data types. [ColumnType.Length] specifiers
// are not included.
// Common type names include "VARCHAR", "TEXT", "NVARCHAR", "DECIMAL", "BOOL",
// "INT", and "BIGINT".
func (ci *ColumnType) DatabaseTypeName() string {
return ci.databaseType
}
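// exampleInspectColumns is an editor's illustration, not part of the original
// source: it shows how the ColumnType accessors above are typically used,
// checking each ok value because drivers are not required to report every
// property. The db value and the query are assumptions.
func exampleInspectColumns(ctx context.Context, db *DB) error {
	rows, err := db.QueryContext(ctx, "SELECT id, name FROM users LIMIT 0")
	if err != nil {
		return err
	}
	defer rows.Close()
	cts, err := rows.ColumnTypes()
	if err != nil {
		return err
	}
	for _, ct := range cts {
		line := fmt.Sprintf("%s %s scan=%v", ct.Name(), ct.DatabaseTypeName(), ct.ScanType())
		if nullable, ok := ct.Nullable(); ok {
			line += fmt.Sprintf(" nullable=%t", nullable)
		}
		if length, ok := ct.Length(); ok {
			line += fmt.Sprintf(" length=%d", length)
		}
		fmt.Println(line)
	}
	return rows.Err()
}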
func rowsColumnInfoSetupConnLocked(rowsi driver.Rows) []*ColumnType {
names := rowsi.Columns()
list := make([]*ColumnType, len(names))
for i := range list {
ci := &ColumnType{
name: names[i],
}
list[i] = ci
if prop, ok := rowsi.(driver.RowsColumnTypeScanType); ok {
ci.scanType = prop.ColumnTypeScanType(i)
} else {
ci.scanType = reflect.TypeFor[any]()
}
if prop, ok := rowsi.(driver.RowsColumnTypeDatabaseTypeName); ok {
ci.databaseType = prop.ColumnTypeDatabaseTypeName(i)
}
if prop, ok := rowsi.(driver.RowsColumnTypeLength); ok {
ci.length, ci.hasLength = prop.ColumnTypeLength(i)
}
if prop, ok := rowsi.(driver.RowsColumnTypeNullable); ok {
ci.nullable, ci.hasNullable = prop.ColumnTypeNullable(i)
}
if prop, ok := rowsi.(driver.RowsColumnTypePrecisionScale); ok {
ci.precision, ci.scale, ci.hasPrecisionScale = prop.ColumnTypePrecisionScale(i)
}
}
return list
}
// Scan copies the columns in the current row into the values pointed
// at by dest. The number of values in dest must be the same as the
// number of columns in [Rows].
//
// Scan converts columns read from the database into the following
// common Go types and special types provided by the sql package:
//
// *string
// *[]byte
// *int, *int8, *int16, *int32, *int64
// *uint, *uint8, *uint16, *uint32, *uint64
// *bool
// *float32, *float64
// *interface{}
// *RawBytes
// *Rows (cursor value)
// any type implementing Scanner (see Scanner docs)
//
// In the simplest case, if the type of the value from the source
// column is an integer, bool or string type T and dest is of type *T,
// Scan simply assigns the value through the pointer.
//
// Scan also converts between string and numeric types, as long as no
// information would be lost. While Scan stringifies all numbers
// scanned from numeric database columns into *string, scans into
// numeric types are checked for overflow. For example, a float64 with
// value 300 or a string with value "300" can scan into a uint16, but
// not into a uint8, though float64(255) or "255" can scan into a
// uint8. One exception is that scans of some float64 numbers to
// strings may lose information when stringifying. In general, scan
// floating point columns into *float64.
//
// If a dest argument has type *[]byte, Scan saves in that argument a
// copy of the corresponding data. The copy is owned by the caller and
// can be modified and held indefinitely. The copy can be avoided by
// using an argument of type [*RawBytes] instead; see the documentation
// for [RawBytes] for restrictions on its use.
//
// If an argument has type *interface{}, Scan copies the value
// provided by the underlying driver without conversion. When scanning
// from a source value of type []byte to *interface{}, a copy of the
// slice is made and the caller owns the result.
//
// Source values of type [time.Time] may be scanned into values of type
// *time.Time, *interface{}, *string, or *[]byte. When converting to
// the latter two, [time.RFC3339Nano] is used.
//
// Source values of type bool may be scanned into types *bool,
// *interface{}, *string, *[]byte, or [*RawBytes].
//
// For scanning into *bool, the source may be true, false, 1, 0, or
// string inputs parseable by [strconv.ParseBool].
//
// Scan can also convert a cursor returned from a query, such as
// "select cursor(select * from my_table) from dual", into a
// [*Rows] value that can itself be scanned from. The parent
// select query will close any cursor [*Rows] if the parent [*Rows] is closed.
//
// If any of the first arguments implementing [Scanner] returns an error,
// that error will be wrapped in the returned error.
func (rs *Rows) Scan(dest ...any) error {
if rs.closemuScanHold {
// This should only be possible if the user calls Scan twice in a row
// without calling Next.
return fmt.Errorf("sql: Scan called without calling Next (closemuScanHold)")
}
rs.closemu.RLock()
rs.raw = rs.raw[:0]
err := rs.scanLocked(dest...)
if err == nil && scanArgsContainRawBytes(dest) {
rs.closemuScanHold = true
} else {
rs.closemu.RUnlock()
}
return err
}
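// exampleScanDest is an editor's sketch of the conversion rules documented on
// Scan above; it is not part of the original source. It assumes the current
// row has three columns: an integer, a numeric value, and a blob.
func exampleScanDest(rows *Rows) error {
	var (
		id    int64
		price string // numeric columns may be stringified into a *string
		blob  []byte // a caller-owned copy; use *RawBytes to avoid the copy
	)
	return rows.Scan(&id, &price, &blob)
}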
func (rs *Rows) scanLocked(dest ...any) error {
if rs.lasterr != nil && rs.lasterr != io.EOF {
return rs.lasterr
}
if rs.closed {
return rs.lasterrOrErrLocked(errRowsClosed)
}
if rs.lastcols == nil {
return errors.New("sql: Scan called without calling Next")
}
if len(dest) != len(rs.lastcols) {
return fmt.Errorf("sql: expected %d destination arguments in Scan, not %d", len(rs.lastcols), len(dest))
}
for i, sv := range rs.lastcols {
err := convertAssignRows(dest[i], sv, rs)
if err != nil {
return fmt.Errorf(`sql: Scan error on column index %d, name %q: %w`, i, rs.rowsi.Columns()[i], err)
}
}
return nil
}
// closemuRUnlockIfHeldByScan releases any closemu.RLock held open by a previous
// call to Scan with *RawBytes.
func (rs *Rows) closemuRUnlockIfHeldByScan() {
if rs.closemuScanHold {
rs.closemuScanHold = false
rs.closemu.RUnlock()
}
}
func scanArgsContainRawBytes(args []any) bool {
for _, a := range args {
if _, ok := a.(*RawBytes); ok {
return true
}
}
return false
}
// rowsCloseHook returns a function so tests may install the
// hook through a test-only mutex.
var rowsCloseHook = func() func(*Rows, *error) { return nil }
// Close closes the [Rows], preventing further enumeration. If [Rows.Next] is called
// and returns false and there are no further result sets,
// the [Rows] are closed automatically and it will suffice to check the
// result of [Rows.Err]. Close is idempotent and does not affect the result of [Rows.Err].
func (rs *Rows) Close() error {
// If the user's calling Close, they're done with their previous row's Scan
// results (any RawBytes memory), so we can release the read lock that would
// be preventing awaitDone from calling the unexported close before we do so.
rs.closemuRUnlockIfHeldByScan()
return rs.close(nil)
}
func (rs *Rows) close(err error) error {
rs.closemu.Lock()
defer rs.closemu.Unlock()
if rs.closed {
return nil
}
rs.closed = true
if rs.lasterr == nil {
rs.lasterr = err
}
withLock(rs.dc, func() {
err = rs.rowsi.Close()
})
if fn := rowsCloseHook(); fn != nil {
fn(rs, &err)
}
if rs.cancel != nil {
rs.cancel()
}
if rs.closeStmt != nil {
rs.closeStmt.Close()
}
rs.releaseConn(err)
rs.lasterr = rs.lasterrOrErrLocked(err)
return err
}
// Row is the result of calling [DB.QueryRow] to select a single row.
type Row struct {
// One of these two will be non-nil:
err error // deferred error for easy chaining
rows *Rows
}
// Scan copies the columns from the matched row into the values
// pointed at by dest. See the documentation on [Rows.Scan] for details.
// If more than one row matches the query,
// Scan uses the first row and discards the rest. If no row matches
// the query, Scan returns [ErrNoRows].
func (r *Row) Scan(dest ...any) error {
if r.err != nil {
return r.err
}
// TODO(bradfitz): for now we need to defensively clone all
// []byte that the driver returned (not permitting
// *RawBytes in Rows.Scan), since we're about to close
// the Rows in our defer, when we return from this function.
// the contract with the driver.Next(...) interface is that it
// can return slices into read-only temporary memory that's
// only valid until the next Scan/Close. But the TODO is that
// for a lot of drivers, this copy will be unnecessary. We
// should provide an optional interface for drivers to
// implement to say, "don't worry, the []bytes that I return
// from Next will not be modified again." (for instance, if
// they were obtained from the network anyway) But for now we
// don't care.
defer r.rows.Close()
if scanArgsContainRawBytes(dest) {
return errors.New("sql: RawBytes isn't allowed on Row.Scan")
}
if !r.rows.Next() {
if err := r.rows.Err(); err != nil {
return err
}
return ErrNoRows
}
err := r.rows.Scan(dest...)
if err != nil {
return err
}
// Make sure the query can be processed to completion with no errors.
return r.rows.Close()
}
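// exampleQueryRow is an editor's illustration, not part of the original
// source: the usual single-row pattern built on Row.Scan above, separating
// the "no rows" case from other errors. The db value, query, and schema are
// assumptions.
func exampleQueryRow(ctx context.Context, db *DB, id int64) (string, error) {
	var name string
	err := db.QueryRowContext(ctx, "SELECT name FROM users WHERE id = ?", id).Scan(&name)
	if errors.Is(err, ErrNoRows) {
		return "", fmt.Errorf("user %d: %w", id, err)
	}
	if err != nil {
		return "", err
	}
	return name, nil
}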
// Err provides a way for wrapping packages to check for
// query errors without calling [Row.Scan].
// Err returns the error, if any, that was encountered while running the query.
// If this error is not nil, this error will also be returned from [Row.Scan].
func (r *Row) Err() error {
return r.err
}
// A Result summarizes an executed SQL command.
type Result interface {
// LastInsertId returns the integer generated by the database
// in response to a command. Typically this will be from an
// "auto increment" column when inserting a new row. Not all
// databases support this feature, and the syntax of such
// statements varies.
LastInsertId() (int64, error)
// RowsAffected returns the number of rows affected by an
// update, insert, or delete. Not every database or database
// driver may support this.
RowsAffected() (int64, error)
}
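// exampleExecResult is an editor's sketch, not part of the original source:
// it shows how a Result returned by Exec is typically consulted, keeping in
// mind that either method may be unsupported by a given driver. The db value
// and statement are assumptions.
func exampleExecResult(ctx context.Context, db *DB) error {
	res, err := db.ExecContext(ctx, "DELETE FROM sessions WHERE expired = 1")
	if err != nil {
		return err
	}
	if n, err := res.RowsAffected(); err == nil {
		fmt.Println("rows affected:", n)
	}
	return nil
}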
type driverResult struct {
sync.Locker // the *driverConn
resi driver.Result
}
func (dr driverResult) LastInsertId() (int64, error) {
dr.Lock()
defer dr.Unlock()
return dr.resi.LastInsertId()
}
func (dr driverResult) RowsAffected() (int64, error) {
dr.Lock()
defer dr.Unlock()
return dr.resi.RowsAffected()
}
func stack() string {
var buf [2 << 10]byte
return string(buf[:runtime.Stack(buf[:], false)])
}
// withLock runs while holding lk.
func withLock(lk sync.Locker, fn func()) {
lk.Lock()
defer lk.Unlock() // in case fn panics
fn()
}
// connRequestSet is a set of chan connRequest that's
// optimized for:
//
// - adding an element
// - removing an element (only by the caller who added it)
// - taking (get + delete) a random element
//
// We previously used a map for this, but taking a random element
// was expensive because it required creating a map iterator.
// This type avoids a map entirely and just uses a slice.
type connRequestSet struct {
// s are the elements in the set.
s []connRequestAndIndex
}
type connRequestAndIndex struct {
// req is the element in the set.
req chan connRequest
// curIdx points to the current location of this element in
// connRequestSet.s. It gets set to -1 upon removal.
curIdx *int
}
// CloseAndRemoveAll closes all channels in the set
// and clears the set.
func (s *connRequestSet) CloseAndRemoveAll() {
for _, v := range s.s {
*v.curIdx = -1
close(v.req)
}
s.s = nil
}
// Len returns the length of the set.
func (s *connRequestSet) Len() int { return len(s.s) }
// connRequestDelHandle is an opaque handle for deleting an
// item that was previously added by a call to Add.
type connRequestDelHandle struct {
idx *int // pointer to index; or -1 if not in slice
}
// Add adds v to the set of waiting requests.
// The returned connRequestDelHandle can be used to remove the item from
// the set.
func (s *connRequestSet) Add(v chan connRequest) connRequestDelHandle {
idx := len(s.s)
// TODO(bradfitz): for simplicity, this always makes a new int-sized
// allocation to store the index. But generally the set will be small and
// under a scannable threshold. As an optimization, we could permit the *int
// to be nil when the set is small and should be scanned. This works even if
// the set grows over the threshold with delete handles outstanding because
// an element can only move to a lower index. So if it starts with a nil
// position, it'll always be in a low index and thus scannable. But that
// can be done in a follow-up change.
idxPtr := &idx
s.s = append(s.s, connRequestAndIndex{v, idxPtr})
return connRequestDelHandle{idxPtr}
}
// Delete removes an element from the set.
//
// It reports whether the element was deleted. (It can return false if a caller
// of TakeRandom took it meanwhile, or upon the second call to Delete.)
func (s *connRequestSet) Delete(h connRequestDelHandle) bool {
idx := *h.idx
if idx < 0 {
return false
}
s.deleteIndex(idx)
return true
}
func (s *connRequestSet) deleteIndex(idx int) {
// Mark item as deleted.
*(s.s[idx].curIdx) = -1
// Copy last element, updating its position
// to its new home.
if idx < len(s.s)-1 {
last := s.s[len(s.s)-1]
*last.curIdx = idx
s.s[idx] = last
}
// Zero out last element (for GC) before shrinking the slice.
s.s[len(s.s)-1] = connRequestAndIndex{}
s.s = s.s[:len(s.s)-1]
}
// TakeRandom returns and removes a random element from s
// and reports whether there was one to take. (It returns ok=false
// if the set is empty.)
func (s *connRequestSet) TakeRandom() (v chan connRequest, ok bool) {
if len(s.s) == 0 {
return nil, false
}
pick := rand.IntN(len(s.s))
e := s.s[pick]
s.deleteIndex(pick)
return e.req, true
}
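// exampleConnRequestSet is an editor's illustration of the Add/TakeRandom/
// Delete life cycle described above; it is not part of the original source.
// A waiter registers its channel, a connection releaser hands a connection to
// one random waiter, and a Delete after the take reports false.
func exampleConnRequestSet() {
	var set connRequestSet
	req := make(chan connRequest, 1) // buffered so the send below cannot block
	h := set.Add(req)
	if ch, ok := set.TakeRandom(); ok {
		ch <- connRequest{} // a releaser would populate conn or err here
	}
	_ = set.Delete(h) // false: the entry was already taken above
}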
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package buildinfo provides access to information embedded in a Go binary
// about how it was built. This includes the Go toolchain version, and the
// set of modules used (for binaries built in module mode).
//
// Build information is available for the currently running binary in
// runtime/debug.ReadBuildInfo.
package buildinfo
import (
"bytes"
"debug/elf"
"debug/macho"
"debug/pe"
"debug/plan9obj"
"encoding/binary"
"errors"
"fmt"
"internal/saferio"
"internal/xcoff"
"io"
"io/fs"
"os"
"runtime/debug"
_ "unsafe" // for linkname
)
// Type alias for build info. We cannot move the types here, since
// runtime/debug would need to import this package, which would make it
// a much larger dependency.
type BuildInfo = debug.BuildInfo
// errUnrecognizedFormat is returned when a given executable file doesn't
// appear to be in a known format, or it breaks the rules of that format,
// or when there are I/O errors reading the file.
var errUnrecognizedFormat = errors.New("unrecognized file format")
// errNotGoExe is returned when a given executable file is valid but does
// not contain Go build information.
//
// errNotGoExe should be an internal detail,
// but widely used packages access it using linkname.
// Notable members of the hall of shame include:
// - github.com/quay/claircore
//
// Do not remove or change the type signature.
// See go.dev/issue/67401.
//
//go:linkname errNotGoExe
var errNotGoExe = errors.New("not a Go executable")
// The build info blob left by the linker is identified by a 32-byte header,
// consisting of buildInfoMagic (14 bytes), followed by version-dependent
// fields.
var buildInfoMagic = []byte("\xff Go buildinf:")
const (
buildInfoAlign = 16
buildInfoHeaderSize = 32
)
// ReadFile returns build information embedded in a Go binary
// file at the given path. Most information is only available for binaries built
// with module support.
func ReadFile(name string) (info *BuildInfo, err error) {
defer func() {
if pathErr := (*fs.PathError)(nil); errors.As(err, &pathErr) {
err = fmt.Errorf("could not read Go build info: %w", err)
} else if err != nil {
err = fmt.Errorf("could not read Go build info from %s: %w", name, err)
}
}()
f, err := os.Open(name)
if err != nil {
return nil, err
}
defer f.Close()
return Read(f)
}
// Read returns build information embedded in a Go binary file
// accessed through the given ReaderAt. Most information is only available for
// binaries built with module support.
func Read(r io.ReaderAt) (*BuildInfo, error) {
vers, mod, err := readRawBuildInfo(r)
if err != nil {
return nil, err
}
bi, err := debug.ParseBuildInfo(mod)
if err != nil {
return nil, err
}
bi.GoVersion = vers
return bi, nil
}
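// examplePrintBuildInfo is an editor's sketch, not part of the original
// source: it shows the intended use of ReadFile above. The binary path is a
// hypothetical placeholder.
func examplePrintBuildInfo() error {
	info, err := ReadFile("/usr/local/bin/sometool") // hypothetical path
	if err != nil {
		return err
	}
	fmt.Println("go version:", info.GoVersion)
	fmt.Println("main module:", info.Main.Path, info.Main.Version)
	for _, dep := range info.Deps {
		fmt.Println("dep:", dep.Path, dep.Version)
	}
	return nil
}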
type exe interface {
// DataStart returns the virtual address and size of the segment or section that
// should contain build information. This is either a specially named section
// or the first writable non-zero data segment.
DataStart() (uint64, uint64)
// DataReader returns an io.ReaderAt that reads from addr until the end
// of segment or section that contains addr.
DataReader(addr uint64) (io.ReaderAt, error)
}
// readRawBuildInfo extracts the Go toolchain version and module information
// strings from a Go binary. On success, vers should be non-empty. mod
// is empty if the binary was not built with modules enabled.
func readRawBuildInfo(r io.ReaderAt) (vers, mod string, err error) {
// Read the first bytes of the file to identify the format, then delegate to
// a format-specific function to load segment and section headers.
ident := make([]byte, 16)
if n, err := r.ReadAt(ident, 0); n < len(ident) || err != nil {
return "", "", errUnrecognizedFormat
}
var x exe
switch {
case bytes.HasPrefix(ident, []byte("\x7FELF")):
f, err := elf.NewFile(r)
if err != nil {
return "", "", errUnrecognizedFormat
}
x = &elfExe{f}
case bytes.HasPrefix(ident, []byte("MZ")):
f, err := pe.NewFile(r)
if err != nil {
return "", "", errUnrecognizedFormat
}
x = &peExe{f}
case bytes.HasPrefix(ident, []byte("\xFE\xED\xFA")) || bytes.HasPrefix(ident[1:], []byte("\xFA\xED\xFE")):
f, err := macho.NewFile(r)
if err != nil {
return "", "", errUnrecognizedFormat
}
x = &machoExe{f}
case bytes.HasPrefix(ident, []byte("\xCA\xFE\xBA\xBE")) || bytes.HasPrefix(ident, []byte("\xCA\xFE\xBA\xBF")):
f, err := macho.NewFatFile(r)
if err != nil || len(f.Arches) == 0 {
return "", "", errUnrecognizedFormat
}
x = &machoExe{f.Arches[0].File}
case bytes.HasPrefix(ident, []byte{0x01, 0xDF}) || bytes.HasPrefix(ident, []byte{0x01, 0xF7}):
f, err := xcoff.NewFile(r)
if err != nil {
return "", "", errUnrecognizedFormat
}
x = &xcoffExe{f}
case hasPlan9Magic(ident):
f, err := plan9obj.NewFile(r)
if err != nil {
return "", "", errUnrecognizedFormat
}
x = &plan9objExe{f}
default:
return "", "", errUnrecognizedFormat
}
// Read segment or section to find the build info blob.
// On some platforms, the blob will be in its own section, and DataStart
// returns the address of that section. On others, it's somewhere in the
// data segment; the linker puts it near the beginning.
// See cmd/link/internal/ld.Link.buildinfo.
dataAddr, dataSize := x.DataStart()
if dataSize == 0 {
return "", "", errNotGoExe
}
addr, err := searchMagic(x, dataAddr, dataSize)
if err != nil {
return "", "", err
}
// Read in the full header first.
header, err := readData(x, addr, buildInfoHeaderSize)
if err == io.EOF {
return "", "", errNotGoExe
} else if err != nil {
return "", "", err
}
if len(header) < buildInfoHeaderSize {
return "", "", errNotGoExe
}
const (
ptrSizeOffset = 14
flagsOffset = 15
versPtrOffset = 16
flagsEndianMask = 0x1
flagsEndianLittle = 0x0
flagsEndianBig = 0x1
flagsVersionMask = 0x2
flagsVersionPtr = 0x0
flagsVersionInl = 0x2
)
// Decode the blob. The blob is a 32-byte header, optionally followed
// by 2 varint-prefixed string contents.
//
// type buildInfoHeader struct {
// magic [14]byte
// ptrSize uint8 // used if flagsVersionPtr
// flags uint8
// versPtr targetUintptr // used if flagsVersionPtr
// modPtr targetUintptr // used if flagsVersionPtr
// }
//
// The version bit of the flags field determines the details of the format.
//
// Prior to 1.18, the flags version bit is flagsVersionPtr. In this
// case, the header includes pointers to the version and modinfo Go
// strings in the header. The ptrSize field indicates the size of the
// pointers and the endian bit of the flag indicates the pointer
// endianness.
//
// Since 1.18, the flags version bit is flagsVersionInl. In this case,
// the header is followed by the string contents inline as
// length-prefixed (as varint) string contents. First is the version
// string, followed immediately by the modinfo string.
flags := header[flagsOffset]
if flags&flagsVersionMask == flagsVersionInl {
vers, addr, err = decodeString(x, addr+buildInfoHeaderSize)
if err != nil {
return "", "", err
}
mod, _, err = decodeString(x, addr)
if err != nil {
return "", "", err
}
} else {
// flagsVersionPtr (<1.18)
ptrSize := int(header[ptrSizeOffset])
bigEndian := flags&flagsEndianMask == flagsEndianBig
var bo binary.ByteOrder
if bigEndian {
bo = binary.BigEndian
} else {
bo = binary.LittleEndian
}
var readPtr func([]byte) uint64
if ptrSize == 4 {
readPtr = func(b []byte) uint64 { return uint64(bo.Uint32(b)) }
} else if ptrSize == 8 {
readPtr = bo.Uint64
} else {
return "", "", errNotGoExe
}
vers = readString(x, ptrSize, readPtr, readPtr(header[versPtrOffset:]))
mod = readString(x, ptrSize, readPtr, readPtr(header[versPtrOffset+ptrSize:]))
}
if vers == "" {
return "", "", errNotGoExe
}
if len(mod) >= 33 && mod[len(mod)-17] == '\n' {
// Strip module framing: sentinel strings delimiting the module info.
// These are cmd/go/internal/modload.infoStart and infoEnd.
mod = mod[16 : len(mod)-16]
} else {
mod = ""
}
return vers, mod, nil
}
func hasPlan9Magic(magic []byte) bool {
if len(magic) >= 4 {
m := binary.BigEndian.Uint32(magic)
switch m {
case plan9obj.Magic386, plan9obj.MagicAMD64, plan9obj.MagicARM:
return true
}
}
return false
}
func decodeString(x exe, addr uint64) (string, uint64, error) {
// varint length followed by length bytes of data.
// N.B. ReadData reads _up to_ size bytes from the section containing
// addr. So we don't need to check that size doesn't overflow the
// section.
b, err := readData(x, addr, binary.MaxVarintLen64)
if err == io.EOF {
return "", 0, errNotGoExe
} else if err != nil {
return "", 0, err
}
length, n := binary.Uvarint(b)
if n <= 0 {
return "", 0, errNotGoExe
}
addr += uint64(n)
b, err = readData(x, addr, length)
if err == io.EOF {
return "", 0, errNotGoExe
} else if err == io.ErrUnexpectedEOF {
// Length too large to allocate. Clearly bogus value.
return "", 0, errNotGoExe
} else if err != nil {
return "", 0, err
}
if uint64(len(b)) < length {
// Section ended before we could read the full string.
return "", 0, errNotGoExe
}
return string(b), addr + length, nil
}
// readString returns the string at address addr in the executable x.
func readString(x exe, ptrSize int, readPtr func([]byte) uint64, addr uint64) string {
hdr, err := readData(x, addr, uint64(2*ptrSize))
if err != nil || len(hdr) < 2*ptrSize {
return ""
}
dataAddr := readPtr(hdr)
dataLen := readPtr(hdr[ptrSize:])
data, err := readData(x, dataAddr, dataLen)
if err != nil || uint64(len(data)) < dataLen {
return ""
}
return string(data)
}
const searchChunkSize = 1 << 20 // 1 MB
// searchMagic returns the address of the first aligned instance of
// buildInfoMagic in the data range [start, start+size). It returns
// errNotGoExe if the magic is not found.
func searchMagic(x exe, start, size uint64) (uint64, error) {
end := start + size
if end < start {
// Overflow.
return 0, errUnrecognizedFormat
}
// Round up start; magic can't occur in the initial unaligned portion.
start = (start + buildInfoAlign - 1) &^ (buildInfoAlign - 1)
if start >= end {
return 0, errNotGoExe
}
var buf []byte
for start < end {
// Read in chunks to avoid consuming too much memory if data is large.
//
// Normally it would be somewhat painful to handle the magic crossing a
// chunk boundary, but since it must be 16-byte aligned we know it will
// fall within a single chunk.
remaining := end - start
chunkSize := uint64(searchChunkSize)
if chunkSize > remaining {
chunkSize = remaining
}
if buf == nil {
buf = make([]byte, chunkSize)
} else {
// N.B. chunkSize can only decrease, and only on the
// last chunk.
buf = buf[:chunkSize]
clear(buf)
}
n, err := readDataInto(x, start, buf)
if err == io.EOF {
// EOF before finding the magic; must not be a Go executable.
return 0, errNotGoExe
} else if err != nil {
return 0, err
}
data := buf[:n]
for len(data) > 0 {
i := bytes.Index(data, buildInfoMagic)
if i < 0 {
break
}
if remaining-uint64(i) < buildInfoHeaderSize {
// Found magic, but not enough space left for the full header.
return 0, errNotGoExe
}
if i%buildInfoAlign != 0 {
// Found magic, but misaligned. Keep searching.
next := (i + buildInfoAlign - 1) &^ (buildInfoAlign - 1)
if next > len(data) {
// Corrupt object file: the remaining
// count says there is more data,
// but we didn't read it.
return 0, errNotGoExe
}
data = data[next:]
continue
}
// Good match!
return start + uint64(i), nil
}
start += chunkSize
}
return 0, errNotGoExe
}
func readData(x exe, addr, size uint64) ([]byte, error) {
r, err := x.DataReader(addr)
if err != nil {
return nil, err
}
b, err := saferio.ReadDataAt(r, size, 0)
if len(b) > 0 && err == io.EOF {
err = nil
}
return b, err
}
func readDataInto(x exe, addr uint64, b []byte) (int, error) {
r, err := x.DataReader(addr)
if err != nil {
return 0, err
}
n, err := r.ReadAt(b, 0)
if n > 0 && err == io.EOF {
err = nil
}
return n, err
}
// elfExe is the ELF implementation of the exe interface.
type elfExe struct {
f *elf.File
}
func (x *elfExe) DataReader(addr uint64) (io.ReaderAt, error) {
for _, prog := range x.f.Progs {
if prog.Vaddr <= addr && addr <= prog.Vaddr+prog.Filesz-1 {
remaining := prog.Vaddr + prog.Filesz - addr
return io.NewSectionReader(prog, int64(addr-prog.Vaddr), int64(remaining)), nil
}
}
return nil, errUnrecognizedFormat
}
func (x *elfExe) DataStart() (uint64, uint64) {
for _, s := range x.f.Sections {
if s.Name == ".go.buildinfo" {
return s.Addr, s.Size
}
}
for _, p := range x.f.Progs {
if p.Type == elf.PT_LOAD && p.Flags&(elf.PF_X|elf.PF_W) == elf.PF_W {
return p.Vaddr, p.Memsz
}
}
return 0, 0
}
// peExe is the PE (Windows Portable Executable) implementation of the exe interface.
type peExe struct {
f *pe.File
}
func (x *peExe) imageBase() uint64 {
switch oh := x.f.OptionalHeader.(type) {
case *pe.OptionalHeader32:
return uint64(oh.ImageBase)
case *pe.OptionalHeader64:
return oh.ImageBase
}
return 0
}
func (x *peExe) DataReader(addr uint64) (io.ReaderAt, error) {
addr -= x.imageBase()
for _, sect := range x.f.Sections {
if uint64(sect.VirtualAddress) <= addr && addr <= uint64(sect.VirtualAddress+sect.Size-1) {
remaining := uint64(sect.VirtualAddress+sect.Size) - addr
return io.NewSectionReader(sect, int64(addr-uint64(sect.VirtualAddress)), int64(remaining)), nil
}
}
return nil, errUnrecognizedFormat
}
func (x *peExe) DataStart() (uint64, uint64) {
// Assume data is first writable section.
const (
IMAGE_SCN_CNT_CODE = 0x00000020
IMAGE_SCN_CNT_INITIALIZED_DATA = 0x00000040
IMAGE_SCN_CNT_UNINITIALIZED_DATA = 0x00000080
IMAGE_SCN_MEM_EXECUTE = 0x20000000
IMAGE_SCN_MEM_READ = 0x40000000
IMAGE_SCN_MEM_WRITE = 0x80000000
IMAGE_SCN_MEM_DISCARDABLE = 0x2000000
IMAGE_SCN_LNK_NRELOC_OVFL = 0x1000000
IMAGE_SCN_ALIGN_32BYTES = 0x600000
)
for _, sect := range x.f.Sections {
if sect.VirtualAddress != 0 && sect.Size != 0 &&
sect.Characteristics&^IMAGE_SCN_ALIGN_32BYTES == IMAGE_SCN_CNT_INITIALIZED_DATA|IMAGE_SCN_MEM_READ|IMAGE_SCN_MEM_WRITE {
return uint64(sect.VirtualAddress) + x.imageBase(), uint64(sect.VirtualSize)
}
}
return 0, 0
}
// machoExe is the Mach-O (Apple macOS/iOS) implementation of the exe interface.
type machoExe struct {
f *macho.File
}
func (x *machoExe) DataReader(addr uint64) (io.ReaderAt, error) {
for _, load := range x.f.Loads {
seg, ok := load.(*macho.Segment)
if !ok {
continue
}
if seg.Addr <= addr && addr <= seg.Addr+seg.Filesz-1 {
if seg.Name == "__PAGEZERO" {
continue
}
remaining := seg.Addr + seg.Filesz - addr
return io.NewSectionReader(seg, int64(addr-seg.Addr), int64(remaining)), nil
}
}
return nil, errUnrecognizedFormat
}
func (x *machoExe) DataStart() (uint64, uint64) {
// Look for section named "__go_buildinfo".
for _, sec := range x.f.Sections {
if sec.Name == "__go_buildinfo" {
return sec.Addr, sec.Size
}
}
// Try the first non-empty writable segment.
const RW = 3
for _, load := range x.f.Loads {
seg, ok := load.(*macho.Segment)
if ok && seg.Addr != 0 && seg.Filesz != 0 && seg.Prot == RW && seg.Maxprot == RW {
return seg.Addr, seg.Memsz
}
}
return 0, 0
}
// xcoffExe is the XCOFF (AIX eXtended COFF) implementation of the exe interface.
type xcoffExe struct {
f *xcoff.File
}
func (x *xcoffExe) DataReader(addr uint64) (io.ReaderAt, error) {
for _, sect := range x.f.Sections {
if sect.VirtualAddress <= addr && addr <= sect.VirtualAddress+sect.Size-1 {
remaining := sect.VirtualAddress + sect.Size - addr
return io.NewSectionReader(sect, int64(addr-sect.VirtualAddress), int64(remaining)), nil
}
}
return nil, errors.New("address not mapped")
}
func (x *xcoffExe) DataStart() (uint64, uint64) {
if s := x.f.SectionByType(xcoff.STYP_DATA); s != nil {
return s.VirtualAddress, s.Size
}
return 0, 0
}
// plan9objExe is the Plan 9 a.out implementation of the exe interface.
type plan9objExe struct {
f *plan9obj.File
}
func (x *plan9objExe) DataStart() (uint64, uint64) {
if s := x.f.Section("data"); s != nil {
return uint64(s.Offset), uint64(s.Size)
}
return 0, 0
}
func (x *plan9objExe) DataReader(addr uint64) (io.ReaderAt, error) {
for _, sect := range x.f.Sections {
if uint64(sect.Offset) <= addr && addr <= uint64(sect.Offset+sect.Size-1) {
remaining := uint64(sect.Offset+sect.Size) - addr
return io.NewSectionReader(sect, int64(addr-uint64(sect.Offset)), int64(remaining)), nil
}
}
return nil, errors.New("address not mapped")
}
// Code generated by "stringer -type Attr -trimprefix=Attr"; DO NOT EDIT.
package dwarf
import "strconv"
func _() {
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
var x [1]struct{}
_ = x[AttrSibling-1]
_ = x[AttrLocation-2]
_ = x[AttrName-3]
_ = x[AttrOrdering-9]
_ = x[AttrByteSize-11]
_ = x[AttrBitOffset-12]
_ = x[AttrBitSize-13]
_ = x[AttrStmtList-16]
_ = x[AttrLowpc-17]
_ = x[AttrHighpc-18]
_ = x[AttrLanguage-19]
_ = x[AttrDiscr-21]
_ = x[AttrDiscrValue-22]
_ = x[AttrVisibility-23]
_ = x[AttrImport-24]
_ = x[AttrStringLength-25]
_ = x[AttrCommonRef-26]
_ = x[AttrCompDir-27]
_ = x[AttrConstValue-28]
_ = x[AttrContainingType-29]
_ = x[AttrDefaultValue-30]
_ = x[AttrInline-32]
_ = x[AttrIsOptional-33]
_ = x[AttrLowerBound-34]
_ = x[AttrProducer-37]
_ = x[AttrPrototyped-39]
_ = x[AttrReturnAddr-42]
_ = x[AttrStartScope-44]
_ = x[AttrStrideSize-46]
_ = x[AttrUpperBound-47]
_ = x[AttrAbstractOrigin-49]
_ = x[AttrAccessibility-50]
_ = x[AttrAddrClass-51]
_ = x[AttrArtificial-52]
_ = x[AttrBaseTypes-53]
_ = x[AttrCalling-54]
_ = x[AttrCount-55]
_ = x[AttrDataMemberLoc-56]
_ = x[AttrDeclColumn-57]
_ = x[AttrDeclFile-58]
_ = x[AttrDeclLine-59]
_ = x[AttrDeclaration-60]
_ = x[AttrDiscrList-61]
_ = x[AttrEncoding-62]
_ = x[AttrExternal-63]
_ = x[AttrFrameBase-64]
_ = x[AttrFriend-65]
_ = x[AttrIdentifierCase-66]
_ = x[AttrMacroInfo-67]
_ = x[AttrNamelistItem-68]
_ = x[AttrPriority-69]
_ = x[AttrSegment-70]
_ = x[AttrSpecification-71]
_ = x[AttrStaticLink-72]
_ = x[AttrType-73]
_ = x[AttrUseLocation-74]
_ = x[AttrVarParam-75]
_ = x[AttrVirtuality-76]
_ = x[AttrVtableElemLoc-77]
_ = x[AttrAllocated-78]
_ = x[AttrAssociated-79]
_ = x[AttrDataLocation-80]
_ = x[AttrStride-81]
_ = x[AttrEntrypc-82]
_ = x[AttrUseUTF8-83]
_ = x[AttrExtension-84]
_ = x[AttrRanges-85]
_ = x[AttrTrampoline-86]
_ = x[AttrCallColumn-87]
_ = x[AttrCallFile-88]
_ = x[AttrCallLine-89]
_ = x[AttrDescription-90]
_ = x[AttrBinaryScale-91]
_ = x[AttrDecimalScale-92]
_ = x[AttrSmall-93]
_ = x[AttrDecimalSign-94]
_ = x[AttrDigitCount-95]
_ = x[AttrPictureString-96]
_ = x[AttrMutable-97]
_ = x[AttrThreadsScaled-98]
_ = x[AttrExplicit-99]
_ = x[AttrObjectPointer-100]
_ = x[AttrEndianity-101]
_ = x[AttrElemental-102]
_ = x[AttrPure-103]
_ = x[AttrRecursive-104]
_ = x[AttrSignature-105]
_ = x[AttrMainSubprogram-106]
_ = x[AttrDataBitOffset-107]
_ = x[AttrConstExpr-108]
_ = x[AttrEnumClass-109]
_ = x[AttrLinkageName-110]
_ = x[AttrStringLengthBitSize-111]
_ = x[AttrStringLengthByteSize-112]
_ = x[AttrRank-113]
_ = x[AttrStrOffsetsBase-114]
_ = x[AttrAddrBase-115]
_ = x[AttrRnglistsBase-116]
_ = x[AttrDwoName-118]
_ = x[AttrReference-119]
_ = x[AttrRvalueReference-120]
_ = x[AttrMacros-121]
_ = x[AttrCallAllCalls-122]
_ = x[AttrCallAllSourceCalls-123]
_ = x[AttrCallAllTailCalls-124]
_ = x[AttrCallReturnPC-125]
_ = x[AttrCallValue-126]
_ = x[AttrCallOrigin-127]
_ = x[AttrCallParameter-128]
_ = x[AttrCallPC-129]
_ = x[AttrCallTailCall-130]
_ = x[AttrCallTarget-131]
_ = x[AttrCallTargetClobbered-132]
_ = x[AttrCallDataLocation-133]
_ = x[AttrCallDataValue-134]
_ = x[AttrNoreturn-135]
_ = x[AttrAlignment-136]
_ = x[AttrExportSymbols-137]
_ = x[AttrDeleted-138]
_ = x[AttrDefaulted-139]
_ = x[AttrLoclistsBase-140]
}
const _Attr_name = "SiblingLocationNameOrderingByteSizeBitOffsetBitSizeStmtListLowpcHighpcLanguageDiscrDiscrValueVisibilityImportStringLengthCommonRefCompDirConstValueContainingTypeDefaultValueInlineIsOptionalLowerBoundProducerPrototypedReturnAddrStartScopeStrideSizeUpperBoundAbstractOriginAccessibilityAddrClassArtificialBaseTypesCallingCountDataMemberLocDeclColumnDeclFileDeclLineDeclarationDiscrListEncodingExternalFrameBaseFriendIdentifierCaseMacroInfoNamelistItemPrioritySegmentSpecificationStaticLinkTypeUseLocationVarParamVirtualityVtableElemLocAllocatedAssociatedDataLocationStrideEntrypcUseUTF8ExtensionRangesTrampolineCallColumnCallFileCallLineDescriptionBinaryScaleDecimalScaleSmallDecimalSignDigitCountPictureStringMutableThreadsScaledExplicitObjectPointerEndianityElementalPureRecursiveSignatureMainSubprogramDataBitOffsetConstExprEnumClassLinkageNameStringLengthBitSizeStringLengthByteSizeRankStrOffsetsBaseAddrBaseRnglistsBaseDwoNameReferenceRvalueReferenceMacrosCallAllCallsCallAllSourceCallsCallAllTailCallsCallReturnPCCallValueCallOriginCallParameterCallPCCallTailCallCallTargetCallTargetClobberedCallDataLocationCallDataValueNoreturnAlignmentExportSymbolsDeletedDefaultedLoclistsBase"
var _Attr_map = map[Attr]string{
1: _Attr_name[0:7],
2: _Attr_name[7:15],
3: _Attr_name[15:19],
9: _Attr_name[19:27],
11: _Attr_name[27:35],
12: _Attr_name[35:44],
13: _Attr_name[44:51],
16: _Attr_name[51:59],
17: _Attr_name[59:64],
18: _Attr_name[64:70],
19: _Attr_name[70:78],
21: _Attr_name[78:83],
22: _Attr_name[83:93],
23: _Attr_name[93:103],
24: _Attr_name[103:109],
25: _Attr_name[109:121],
26: _Attr_name[121:130],
27: _Attr_name[130:137],
28: _Attr_name[137:147],
29: _Attr_name[147:161],
30: _Attr_name[161:173],
32: _Attr_name[173:179],
33: _Attr_name[179:189],
34: _Attr_name[189:199],
37: _Attr_name[199:207],
39: _Attr_name[207:217],
42: _Attr_name[217:227],
44: _Attr_name[227:237],
46: _Attr_name[237:247],
47: _Attr_name[247:257],
49: _Attr_name[257:271],
50: _Attr_name[271:284],
51: _Attr_name[284:293],
52: _Attr_name[293:303],
53: _Attr_name[303:312],
54: _Attr_name[312:319],
55: _Attr_name[319:324],
56: _Attr_name[324:337],
57: _Attr_name[337:347],
58: _Attr_name[347:355],
59: _Attr_name[355:363],
60: _Attr_name[363:374],
61: _Attr_name[374:383],
62: _Attr_name[383:391],
63: _Attr_name[391:399],
64: _Attr_name[399:408],
65: _Attr_name[408:414],
66: _Attr_name[414:428],
67: _Attr_name[428:437],
68: _Attr_name[437:449],
69: _Attr_name[449:457],
70: _Attr_name[457:464],
71: _Attr_name[464:477],
72: _Attr_name[477:487],
73: _Attr_name[487:491],
74: _Attr_name[491:502],
75: _Attr_name[502:510],
76: _Attr_name[510:520],
77: _Attr_name[520:533],
78: _Attr_name[533:542],
79: _Attr_name[542:552],
80: _Attr_name[552:564],
81: _Attr_name[564:570],
82: _Attr_name[570:577],
83: _Attr_name[577:584],
84: _Attr_name[584:593],
85: _Attr_name[593:599],
86: _Attr_name[599:609],
87: _Attr_name[609:619],
88: _Attr_name[619:627],
89: _Attr_name[627:635],
90: _Attr_name[635:646],
91: _Attr_name[646:657],
92: _Attr_name[657:669],
93: _Attr_name[669:674],
94: _Attr_name[674:685],
95: _Attr_name[685:695],
96: _Attr_name[695:708],
97: _Attr_name[708:715],
98: _Attr_name[715:728],
99: _Attr_name[728:736],
100: _Attr_name[736:749],
101: _Attr_name[749:758],
102: _Attr_name[758:767],
103: _Attr_name[767:771],
104: _Attr_name[771:780],
105: _Attr_name[780:789],
106: _Attr_name[789:803],
107: _Attr_name[803:816],
108: _Attr_name[816:825],
109: _Attr_name[825:834],
110: _Attr_name[834:845],
111: _Attr_name[845:864],
112: _Attr_name[864:884],
113: _Attr_name[884:888],
114: _Attr_name[888:902],
115: _Attr_name[902:910],
116: _Attr_name[910:922],
118: _Attr_name[922:929],
119: _Attr_name[929:938],
120: _Attr_name[938:953],
121: _Attr_name[953:959],
122: _Attr_name[959:971],
123: _Attr_name[971:989],
124: _Attr_name[989:1005],
125: _Attr_name[1005:1017],
126: _Attr_name[1017:1026],
127: _Attr_name[1026:1036],
128: _Attr_name[1036:1049],
129: _Attr_name[1049:1055],
130: _Attr_name[1055:1067],
131: _Attr_name[1067:1077],
132: _Attr_name[1077:1096],
133: _Attr_name[1096:1112],
134: _Attr_name[1112:1125],
135: _Attr_name[1125:1133],
136: _Attr_name[1133:1142],
137: _Attr_name[1142:1155],
138: _Attr_name[1155:1162],
139: _Attr_name[1162:1171],
140: _Attr_name[1171:1183],
}
func (i Attr) String() string {
if str, ok := _Attr_map[i]; ok {
return str
}
return "Attr(" + strconv.FormatInt(int64(i), 10) + ")"
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Buffered reading and decoding of DWARF data streams.
package dwarf
import (
"bytes"
"encoding/binary"
"strconv"
)
// Data buffer being decoded.
type buf struct {
dwarf *Data
order binary.ByteOrder
format dataFormat
name string
off Offset
data []byte
err error
}
// Data format, other than byte order. This affects the handling of
// certain field formats.
type dataFormat interface {
// DWARF version number. Zero means unknown.
version() int
// 64-bit DWARF format?
dwarf64() (dwarf64 bool, isKnown bool)
// Size of an address, in bytes. Zero means unknown.
addrsize() int
}
// Some parts of DWARF have no data format, e.g., abbrevs.
type unknownFormat struct{}
func (u unknownFormat) version() int {
return 0
}
func (u unknownFormat) dwarf64() (bool, bool) {
return false, false
}
func (u unknownFormat) addrsize() int {
return 0
}
func makeBuf(d *Data, format dataFormat, name string, off Offset, data []byte) buf {
return buf{d, d.order, format, name, off, data, nil}
}
func (b *buf) uint8() uint8 {
if len(b.data) < 1 {
b.error("underflow")
return 0
}
val := b.data[0]
b.data = b.data[1:]
b.off++
return val
}
func (b *buf) bytes(n int) []byte {
if n < 0 || len(b.data) < n {
b.error("underflow")
return nil
}
data := b.data[0:n]
b.data = b.data[n:]
b.off += Offset(n)
return data
}
func (b *buf) skip(n int) { b.bytes(n) }
func (b *buf) string() string {
i := bytes.IndexByte(b.data, 0)
if i < 0 {
b.error("underflow")
return ""
}
s := string(b.data[0:i])
b.data = b.data[i+1:]
b.off += Offset(i + 1)
return s
}
func (b *buf) uint16() uint16 {
a := b.bytes(2)
if a == nil {
return 0
}
return b.order.Uint16(a)
}
func (b *buf) uint24() uint32 {
a := b.bytes(3)
if a == nil {
return 0
}
if b.dwarf.bigEndian {
return uint32(a[2]) | uint32(a[1])<<8 | uint32(a[0])<<16
} else {
return uint32(a[0]) | uint32(a[1])<<8 | uint32(a[2])<<16
}
}
func (b *buf) uint32() uint32 {
a := b.bytes(4)
if a == nil {
return 0
}
return b.order.Uint32(a)
}
func (b *buf) uint64() uint64 {
a := b.bytes(8)
if a == nil {
return 0
}
return b.order.Uint64(a)
}
// Read a varint, which is 7 bits per byte, little endian.
// The 0x80 bit means read another byte.
func (b *buf) varint() (c uint64, bits uint) {
for i := 0; i < len(b.data); i++ {
byte := b.data[i]
c |= uint64(byte&0x7F) << bits
bits += 7
if byte&0x80 == 0 {
b.off += Offset(i + 1)
b.data = b.data[i+1:]
return c, bits
}
}
return 0, 0
}
// Unsigned int is just a varint.
func (b *buf) uint() uint64 {
x, _ := b.varint()
return x
}
// Signed int is a sign-extended varint.
func (b *buf) int() int64 {
ux, bits := b.varint()
x := int64(ux)
if x&(1<<(bits-1)) != 0 {
x |= -1 << bits
}
return x
}
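// exampleVarint is an editor's illustration, not part of the original source:
// it decodes the classic LEB128 example bytes 0xE5 0x8E 0x26, which encode
// the unsigned value 624485 (0x65 | 0x0E<<7 | 0x26<<14), followed by the
// single signed byte 0x7F, which int() sign-extends to -1. The *Data value is
// an assumption; only its byte order field is consulted by makeBuf.
func exampleVarint(d *Data) (u uint64, s int64) {
	b := makeBuf(d, unknownFormat{}, "example", 0, []byte{0xE5, 0x8E, 0x26, 0x7F})
	u = b.uint() // 624485
	s = b.int()  // -1
	return u, s
}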
// Address-sized uint.
func (b *buf) addr() uint64 {
switch b.format.addrsize() {
case 1:
return uint64(b.uint8())
case 2:
return uint64(b.uint16())
case 4:
return uint64(b.uint32())
case 8:
return b.uint64()
}
b.error("unknown address size")
return 0
}
func (b *buf) unitLength() (length Offset, dwarf64 bool) {
length = Offset(b.uint32())
if length == 0xffffffff {
dwarf64 = true
length = Offset(b.uint64())
} else if length >= 0xfffffff0 {
b.error("unit length has reserved value")
}
return
}
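// Editor's note (not in the original source): in the DWARF spec the initial
// length 0xffffffff is an escape meaning the unit uses the 64-bit DWARF
// format, with the true length following as an 8-byte value, while
// 0xfffffff0 through 0xfffffffe are reserved; that is exactly the case split
// unitLength performs above.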
func (b *buf) error(s string) {
if b.err == nil {
b.data = nil
b.err = DecodeError{b.name, b.off, s}
}
}
type DecodeError struct {
Name string
Offset Offset
Err string
}
func (e DecodeError) Error() string {
return "decoding dwarf section " + e.Name + " at offset 0x" + strconv.FormatInt(int64(e.Offset), 16) + ": " + e.Err
}
// Code generated by "stringer -type=Class"; DO NOT EDIT.
package dwarf
import "strconv"
func _() {
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
var x [1]struct{}
_ = x[ClassUnknown-0]
_ = x[ClassAddress-1]
_ = x[ClassBlock-2]
_ = x[ClassConstant-3]
_ = x[ClassExprLoc-4]
_ = x[ClassFlag-5]
_ = x[ClassLinePtr-6]
_ = x[ClassLocListPtr-7]
_ = x[ClassMacPtr-8]
_ = x[ClassRangeListPtr-9]
_ = x[ClassReference-10]
_ = x[ClassReferenceSig-11]
_ = x[ClassString-12]
_ = x[ClassReferenceAlt-13]
_ = x[ClassStringAlt-14]
_ = x[ClassAddrPtr-15]
_ = x[ClassLocList-16]
_ = x[ClassRngList-17]
_ = x[ClassRngListsPtr-18]
_ = x[ClassStrOffsetsPtr-19]
}
const _Class_name = "ClassUnknownClassAddressClassBlockClassConstantClassExprLocClassFlagClassLinePtrClassLocListPtrClassMacPtrClassRangeListPtrClassReferenceClassReferenceSigClassStringClassReferenceAltClassStringAltClassAddrPtrClassLocListClassRngListClassRngListsPtrClassStrOffsetsPtr"
var _Class_index = [...]uint16{0, 12, 24, 34, 47, 59, 68, 80, 95, 106, 123, 137, 154, 165, 182, 196, 208, 220, 232, 248, 266}
func (i Class) String() string {
if i < 0 || i >= Class(len(_Class_index)-1) {
return "Class(" + strconv.FormatInt(int64(i), 10) + ")"
}
return _Class_name[_Class_index[i]:_Class_index[i+1]]
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Constants
package dwarf
//go:generate stringer -type Attr -trimprefix=Attr
// An Attr identifies the attribute type in a DWARF [Entry.Field].
type Attr uint32
const (
AttrSibling Attr = 0x01
AttrLocation Attr = 0x02
AttrName Attr = 0x03
AttrOrdering Attr = 0x09
AttrByteSize Attr = 0x0B
AttrBitOffset Attr = 0x0C
AttrBitSize Attr = 0x0D
AttrStmtList Attr = 0x10
AttrLowpc Attr = 0x11
AttrHighpc Attr = 0x12
AttrLanguage Attr = 0x13
AttrDiscr Attr = 0x15
AttrDiscrValue Attr = 0x16
AttrVisibility Attr = 0x17
AttrImport Attr = 0x18
AttrStringLength Attr = 0x19
AttrCommonRef Attr = 0x1A
AttrCompDir Attr = 0x1B
AttrConstValue Attr = 0x1C
AttrContainingType Attr = 0x1D
AttrDefaultValue Attr = 0x1E
AttrInline Attr = 0x20
AttrIsOptional Attr = 0x21
AttrLowerBound Attr = 0x22
AttrProducer Attr = 0x25
AttrPrototyped Attr = 0x27
AttrReturnAddr Attr = 0x2A
AttrStartScope Attr = 0x2C
AttrStrideSize Attr = 0x2E
AttrUpperBound Attr = 0x2F
AttrAbstractOrigin Attr = 0x31
AttrAccessibility Attr = 0x32
AttrAddrClass Attr = 0x33
AttrArtificial Attr = 0x34
AttrBaseTypes Attr = 0x35
AttrCalling Attr = 0x36
AttrCount Attr = 0x37
AttrDataMemberLoc Attr = 0x38
AttrDeclColumn Attr = 0x39
AttrDeclFile Attr = 0x3A
AttrDeclLine Attr = 0x3B
AttrDeclaration Attr = 0x3C
AttrDiscrList Attr = 0x3D
AttrEncoding Attr = 0x3E
AttrExternal Attr = 0x3F
AttrFrameBase Attr = 0x40
AttrFriend Attr = 0x41
AttrIdentifierCase Attr = 0x42
AttrMacroInfo Attr = 0x43
AttrNamelistItem Attr = 0x44
AttrPriority Attr = 0x45
AttrSegment Attr = 0x46
AttrSpecification Attr = 0x47
AttrStaticLink Attr = 0x48
AttrType Attr = 0x49
AttrUseLocation Attr = 0x4A
AttrVarParam Attr = 0x4B
AttrVirtuality Attr = 0x4C
AttrVtableElemLoc Attr = 0x4D
// The following are new in DWARF 3.
AttrAllocated Attr = 0x4E
AttrAssociated Attr = 0x4F
AttrDataLocation Attr = 0x50
AttrStride Attr = 0x51
AttrEntrypc Attr = 0x52
AttrUseUTF8 Attr = 0x53
AttrExtension Attr = 0x54
AttrRanges Attr = 0x55
AttrTrampoline Attr = 0x56
AttrCallColumn Attr = 0x57
AttrCallFile Attr = 0x58
AttrCallLine Attr = 0x59
AttrDescription Attr = 0x5A
AttrBinaryScale Attr = 0x5B
AttrDecimalScale Attr = 0x5C
AttrSmall Attr = 0x5D
AttrDecimalSign Attr = 0x5E
AttrDigitCount Attr = 0x5F
AttrPictureString Attr = 0x60
AttrMutable Attr = 0x61
AttrThreadsScaled Attr = 0x62
AttrExplicit Attr = 0x63
AttrObjectPointer Attr = 0x64
AttrEndianity Attr = 0x65
AttrElemental Attr = 0x66
AttrPure Attr = 0x67
AttrRecursive Attr = 0x68
// The following are new in DWARF 4.
AttrSignature Attr = 0x69
AttrMainSubprogram Attr = 0x6A
AttrDataBitOffset Attr = 0x6B
AttrConstExpr Attr = 0x6C
AttrEnumClass Attr = 0x6D
AttrLinkageName Attr = 0x6E
// The following are new in DWARF 5.
AttrStringLengthBitSize Attr = 0x6F
AttrStringLengthByteSize Attr = 0x70
AttrRank Attr = 0x71
AttrStrOffsetsBase Attr = 0x72
AttrAddrBase Attr = 0x73
AttrRnglistsBase Attr = 0x74
AttrDwoName Attr = 0x76
AttrReference Attr = 0x77
AttrRvalueReference Attr = 0x78
AttrMacros Attr = 0x79
AttrCallAllCalls Attr = 0x7A
AttrCallAllSourceCalls Attr = 0x7B
AttrCallAllTailCalls Attr = 0x7C
AttrCallReturnPC Attr = 0x7D
AttrCallValue Attr = 0x7E
AttrCallOrigin Attr = 0x7F
AttrCallParameter Attr = 0x80
AttrCallPC Attr = 0x81
AttrCallTailCall Attr = 0x82
AttrCallTarget Attr = 0x83
AttrCallTargetClobbered Attr = 0x84
AttrCallDataLocation Attr = 0x85
AttrCallDataValue Attr = 0x86
AttrNoreturn Attr = 0x87
AttrAlignment Attr = 0x88
AttrExportSymbols Attr = 0x89
AttrDeleted Attr = 0x8A
AttrDefaulted Attr = 0x8B
AttrLoclistsBase Attr = 0x8C
)
func (a Attr) GoString() string {
if str, ok := _Attr_map[a]; ok {
return "dwarf.Attr" + str
}
return "dwarf." + a.String()
}
// A format is a DWARF data encoding format.
type format uint32
const (
// value formats
formAddr format = 0x01
formDwarfBlock2 format = 0x03
formDwarfBlock4 format = 0x04
formData2 format = 0x05
formData4 format = 0x06
formData8 format = 0x07
formString format = 0x08
formDwarfBlock format = 0x09
formDwarfBlock1 format = 0x0A
formData1 format = 0x0B
formFlag format = 0x0C
formSdata format = 0x0D
formStrp format = 0x0E
formUdata format = 0x0F
formRefAddr format = 0x10
formRef1 format = 0x11
formRef2 format = 0x12
formRef4 format = 0x13
formRef8 format = 0x14
formRefUdata format = 0x15
formIndirect format = 0x16
// The following are new in DWARF 4.
formSecOffset format = 0x17
formExprloc format = 0x18
formFlagPresent format = 0x19
formRefSig8 format = 0x20
// The following are new in DWARF 5.
formStrx format = 0x1A
formAddrx format = 0x1B
formRefSup4 format = 0x1C
formStrpSup format = 0x1D
formData16 format = 0x1E
formLineStrp format = 0x1F
formImplicitConst format = 0x21
formLoclistx format = 0x22
formRnglistx format = 0x23
formRefSup8 format = 0x24
formStrx1 format = 0x25
formStrx2 format = 0x26
formStrx3 format = 0x27
formStrx4 format = 0x28
formAddrx1 format = 0x29
formAddrx2 format = 0x2A
formAddrx3 format = 0x2B
formAddrx4 format = 0x2C
// Extensions for multi-file compression (.dwz)
// http://www.dwarfstd.org/ShowIssue.php?issue=120604.1
formGnuRefAlt format = 0x1f20
formGnuStrpAlt format = 0x1f21
)
//go:generate stringer -type Tag -trimprefix=Tag
// A Tag is the classification (the type) of an [Entry].
type Tag uint32
const (
TagArrayType Tag = 0x01
TagClassType Tag = 0x02
TagEntryPoint Tag = 0x03
TagEnumerationType Tag = 0x04
TagFormalParameter Tag = 0x05
TagImportedDeclaration Tag = 0x08
TagLabel Tag = 0x0A
TagLexDwarfBlock Tag = 0x0B
TagMember Tag = 0x0D
TagPointerType Tag = 0x0F
TagReferenceType Tag = 0x10
TagCompileUnit Tag = 0x11
TagStringType Tag = 0x12
TagStructType Tag = 0x13
TagSubroutineType Tag = 0x15
TagTypedef Tag = 0x16
TagUnionType Tag = 0x17
TagUnspecifiedParameters Tag = 0x18
TagVariant Tag = 0x19
TagCommonDwarfBlock Tag = 0x1A
TagCommonInclusion Tag = 0x1B
TagInheritance Tag = 0x1C
TagInlinedSubroutine Tag = 0x1D
TagModule Tag = 0x1E
TagPtrToMemberType Tag = 0x1F
TagSetType Tag = 0x20
TagSubrangeType Tag = 0x21
TagWithStmt Tag = 0x22
TagAccessDeclaration Tag = 0x23
TagBaseType Tag = 0x24
TagCatchDwarfBlock Tag = 0x25
TagConstType Tag = 0x26
TagConstant Tag = 0x27
TagEnumerator Tag = 0x28
TagFileType Tag = 0x29
TagFriend Tag = 0x2A
TagNamelist Tag = 0x2B
TagNamelistItem Tag = 0x2C
TagPackedType Tag = 0x2D
TagSubprogram Tag = 0x2E
TagTemplateTypeParameter Tag = 0x2F
TagTemplateValueParameter Tag = 0x30
TagThrownType Tag = 0x31
TagTryDwarfBlock Tag = 0x32
TagVariantPart Tag = 0x33
TagVariable Tag = 0x34
TagVolatileType Tag = 0x35
// The following are new in DWARF 3.
TagDwarfProcedure Tag = 0x36
TagRestrictType Tag = 0x37
TagInterfaceType Tag = 0x38
TagNamespace Tag = 0x39
TagImportedModule Tag = 0x3A
TagUnspecifiedType Tag = 0x3B
TagPartialUnit Tag = 0x3C
TagImportedUnit Tag = 0x3D
TagMutableType Tag = 0x3E // Later removed from DWARF.
TagCondition Tag = 0x3F
TagSharedType Tag = 0x40
// The following are new in DWARF 4.
TagTypeUnit Tag = 0x41
TagRvalueReferenceType Tag = 0x42
TagTemplateAlias Tag = 0x43
// The following are new in DWARF 5.
TagCoarrayType Tag = 0x44
TagGenericSubrange Tag = 0x45
TagDynamicType Tag = 0x46
TagAtomicType Tag = 0x47
TagCallSite Tag = 0x48
TagCallSiteParameter Tag = 0x49
TagSkeletonUnit Tag = 0x4A
TagImmutableType Tag = 0x4B
)
func (t Tag) GoString() string {
if t <= TagTemplateAlias {
return "dwarf.Tag" + t.String()
}
return "dwarf." + t.String()
}
// Location expression operators.
// The debug info encodes value locations like 8(R3)
// as a sequence of these op codes.
// This package does not implement full expressions;
// the opPlusUconst operator is expected by the type parser.
const (
opAddr = 0x03 /* 1 op, const addr */
opDeref = 0x06
opConst1u = 0x08 /* 1 op, 1 byte const */
opConst1s = 0x09 /* " signed */
opConst2u = 0x0A /* 1 op, 2 byte const */
opConst2s = 0x0B /* " signed */
opConst4u = 0x0C /* 1 op, 4 byte const */
opConst4s = 0x0D /* " signed */
opConst8u = 0x0E /* 1 op, 8 byte const */
opConst8s = 0x0F /* " signed */
opConstu = 0x10 /* 1 op, LEB128 const */
opConsts = 0x11 /* " signed */
opDup = 0x12
opDrop = 0x13
opOver = 0x14
opPick = 0x15 /* 1 op, 1 byte stack index */
opSwap = 0x16
opRot = 0x17
opXderef = 0x18
opAbs = 0x19
opAnd = 0x1A
opDiv = 0x1B
opMinus = 0x1C
opMod = 0x1D
opMul = 0x1E
opNeg = 0x1F
opNot = 0x20
opOr = 0x21
opPlus = 0x22
opPlusUconst = 0x23 /* 1 op, ULEB128 addend */
opShl = 0x24
opShr = 0x25
opShra = 0x26
opXor = 0x27
opSkip = 0x2F /* 1 op, signed 2-byte constant */
opBra = 0x28 /* 1 op, signed 2-byte constant */
opEq = 0x29
opGe = 0x2A
opGt = 0x2B
opLe = 0x2C
opLt = 0x2D
opNe = 0x2E
opLit0 = 0x30
/* OpLitN = OpLit0 + N for N = 0..31 */
opReg0 = 0x50
/* OpRegN = OpReg0 + N for N = 0..31 */
opBreg0 = 0x70 /* 1 op, signed LEB128 constant */
/* OpBregN = OpBreg0 + N for N = 0..31 */
opRegx = 0x90 /* 1 op, ULEB128 register */
opFbreg = 0x91 /* 1 op, SLEB128 offset */
opBregx = 0x92 /* 2 op, ULEB128 reg; SLEB128 off */
opPiece = 0x93 /* 1 op, ULEB128 size of piece */
opDerefSize = 0x94 /* 1-byte size of data retrieved */
opXderefSize = 0x95 /* 1-byte size of data retrieved */
opNop = 0x96
// The following are new in DWARF 3.
opPushObjAddr = 0x97
opCall2 = 0x98 /* 2-byte offset of DIE */
opCall4 = 0x99 /* 4-byte offset of DIE */
opCallRef = 0x9A /* 4- or 8- byte offset of DIE */
opFormTLSAddress = 0x9B
opCallFrameCFA = 0x9C
opBitPiece = 0x9D
// The following are new in DWARF 4.
opImplicitValue = 0x9E
opStackValue = 0x9F
// The following are new in DWARF 5.
opImplicitPointer = 0xA0
opAddrx = 0xA1
opConstx = 0xA2
opEntryValue = 0xA3
opConstType = 0xA4
opRegvalType = 0xA5
opDerefType = 0xA6
opXderefType = 0xA7
opConvert = 0xA8
opReinterpret = 0xA9
/* 0xE0-0xFF reserved for user-specific */
)
// Basic type encodings -- the value for AttrEncoding in a TagBaseType Entry.
const (
encAddress = 0x01
encBoolean = 0x02
encComplexFloat = 0x03
encFloat = 0x04
encSigned = 0x05
encSignedChar = 0x06
encUnsigned = 0x07
encUnsignedChar = 0x08
// The following are new in DWARF 3.
encImaginaryFloat = 0x09
encPackedDecimal = 0x0A
encNumericString = 0x0B
encEdited = 0x0C
encSignedFixed = 0x0D
encUnsignedFixed = 0x0E
encDecimalFloat = 0x0F
// The following are new in DWARF 4.
encUTF = 0x10
// The following are new in DWARF 5.
encUCS = 0x11
encASCII = 0x12
)
// Statement program standard opcode encodings.
const (
lnsCopy = 1
lnsAdvancePC = 2
lnsAdvanceLine = 3
lnsSetFile = 4
lnsSetColumn = 5
lnsNegateStmt = 6
lnsSetBasicBlock = 7
lnsConstAddPC = 8
lnsFixedAdvancePC = 9
// DWARF 3
lnsSetPrologueEnd = 10
lnsSetEpilogueBegin = 11
lnsSetISA = 12
)
// Statement program extended opcode encodings.
const (
lneEndSequence = 1
lneSetAddress = 2
lneDefineFile = 3
// DWARF 4
lneSetDiscriminator = 4
)
// Line table directory and file name entry formats.
// These are new in DWARF 5.
const (
lnctPath = 0x01
lnctDirectoryIndex = 0x02
lnctTimestamp = 0x03
lnctSize = 0x04
lnctMD5 = 0x05
)
// Location list entry codes.
// These are new in DWARF 5.
const (
lleEndOfList = 0x00
lleBaseAddressx = 0x01
lleStartxEndx = 0x02
lleStartxLength = 0x03
lleOffsetPair = 0x04
lleDefaultLocation = 0x05
lleBaseAddress = 0x06
lleStartEnd = 0x07
lleStartLength = 0x08
)
// Unit header unit type encodings.
// These are new in DWARF 5.
const (
utCompile = 0x01
utType = 0x02
utPartial = 0x03
utSkeleton = 0x04
utSplitCompile = 0x05
utSplitType = 0x06
)
// Opcodes for DWARFv5 debug_rnglists section.
const (
rleEndOfList = 0x0
rleBaseAddressx = 0x1
rleStartxEndx = 0x2
rleStartxLength = 0x3
rleOffsetPair = 0x4
rleBaseAddress = 0x5
rleStartEnd = 0x6
rleStartLength = 0x7
)
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// DWARF debug information entry parser.
// An entry is a sequence of data items of a given format.
// The first word in the entry is an index into what DWARF
// calls the ``abbreviation table.'' An abbreviation is really
// just a type descriptor: it's an array of attribute tag/value format pairs.
package dwarf
import (
"encoding/binary"
"errors"
"fmt"
"strconv"
)
// a single entry's description: a sequence of attributes
type abbrev struct {
tag Tag
children bool
field []afield
}
type afield struct {
attr Attr
fmt format
class Class
val int64 // for formImplicitConst
}
// a map from entry format ids to their descriptions
type abbrevTable map[uint32]abbrev
// parseAbbrev returns the abbreviation table that starts at byte off
// in the .debug_abbrev section.
func (d *Data) parseAbbrev(off uint64, vers int) (abbrevTable, error) {
if m, ok := d.abbrevCache[off]; ok {
return m, nil
}
data := d.abbrev
if off > uint64(len(data)) {
data = nil
} else {
data = data[off:]
}
b := makeBuf(d, unknownFormat{}, "abbrev", 0, data)
// Error handling is simplified by the buf getters
// returning an endless stream of 0s after an error.
m := make(abbrevTable)
for {
// Table ends with id == 0.
id := uint32(b.uint())
if id == 0 {
break
}
// Walk over attributes, counting.
n := 0
b1 := b // Read from copy of b.
b1.uint()
b1.uint8()
for {
tag := b1.uint()
fmt := b1.uint()
if tag == 0 && fmt == 0 {
break
}
if format(fmt) == formImplicitConst {
b1.int()
}
n++
}
if b1.err != nil {
return nil, b1.err
}
// Walk over attributes again, this time writing them down.
var a abbrev
a.tag = Tag(b.uint())
a.children = b.uint8() != 0
a.field = make([]afield, n)
for i := range a.field {
a.field[i].attr = Attr(b.uint())
a.field[i].fmt = format(b.uint())
a.field[i].class = formToClass(a.field[i].fmt, a.field[i].attr, vers, &b)
if a.field[i].fmt == formImplicitConst {
a.field[i].val = b.int()
}
}
b.uint()
b.uint()
m[id] = a
}
if b.err != nil {
return nil, b.err
}
d.abbrevCache[off] = m
return m, nil
}
// attrIsExprloc indicates attributes that allow exprloc values that
// are encoded as block values in DWARF 2 and 3. See DWARF 4, Figure
// 20.
var attrIsExprloc = map[Attr]bool{
AttrLocation: true,
AttrByteSize: true,
AttrBitOffset: true,
AttrBitSize: true,
AttrStringLength: true,
AttrLowerBound: true,
AttrReturnAddr: true,
AttrStrideSize: true,
AttrUpperBound: true,
AttrCount: true,
AttrDataMemberLoc: true,
AttrFrameBase: true,
AttrSegment: true,
AttrStaticLink: true,
AttrUseLocation: true,
AttrVtableElemLoc: true,
AttrAllocated: true,
AttrAssociated: true,
AttrDataLocation: true,
AttrStride: true,
}
// attrPtrClass indicates the *ptr class of attributes that have
// encoding formSecOffset in DWARF 4 or formData* in DWARF 2 and 3.
var attrPtrClass = map[Attr]Class{
AttrLocation: ClassLocListPtr,
AttrStmtList: ClassLinePtr,
AttrStringLength: ClassLocListPtr,
AttrReturnAddr: ClassLocListPtr,
AttrStartScope: ClassRangeListPtr,
AttrDataMemberLoc: ClassLocListPtr,
AttrFrameBase: ClassLocListPtr,
AttrMacroInfo: ClassMacPtr,
AttrSegment: ClassLocListPtr,
AttrStaticLink: ClassLocListPtr,
AttrUseLocation: ClassLocListPtr,
AttrVtableElemLoc: ClassLocListPtr,
AttrRanges: ClassRangeListPtr,
// The following are new in DWARF 5.
AttrStrOffsetsBase: ClassStrOffsetsPtr,
AttrAddrBase: ClassAddrPtr,
AttrRnglistsBase: ClassRngListsPtr,
AttrLoclistsBase: ClassLocListPtr,
}
// formToClass returns the DWARF 4 Class for the given form. If the
// DWARF version is less than 4, it will disambiguate some forms
// depending on the attribute.
func formToClass(form format, attr Attr, vers int, b *buf) Class {
switch form {
default:
b.error("cannot determine class of unknown attribute form")
return 0
case formIndirect:
return ClassUnknown
case formAddr, formAddrx, formAddrx1, formAddrx2, formAddrx3, formAddrx4:
return ClassAddress
case formDwarfBlock1, formDwarfBlock2, formDwarfBlock4, formDwarfBlock:
// In DWARF 2 and 3, ClassExprLoc was encoded as a
// block. DWARF 4 distinguishes ClassBlock and
// ClassExprLoc, but there are no attributes that can
// be both, so we also promote ClassBlock values in
// DWARF 4 that should be ClassExprLoc in case
// producers get this wrong.
if attrIsExprloc[attr] {
return ClassExprLoc
}
return ClassBlock
case formData1, formData2, formData4, formData8, formSdata, formUdata, formData16, formImplicitConst:
// In DWARF 2 and 3, ClassPtr was encoded as a
// constant. Unlike ClassExprLoc/ClassBlock, some
// DWARF 4 attributes need to distinguish Class*Ptr
// from ClassConstant, so we only do this promotion
// for versions 2 and 3.
if class, ok := attrPtrClass[attr]; vers < 4 && ok {
return class
}
return ClassConstant
case formFlag, formFlagPresent:
return ClassFlag
case formRefAddr, formRef1, formRef2, formRef4, formRef8, formRefUdata, formRefSup4, formRefSup8:
return ClassReference
case formRefSig8:
return ClassReferenceSig
case formString, formStrp, formStrx, formStrpSup, formLineStrp, formStrx1, formStrx2, formStrx3, formStrx4:
return ClassString
case formSecOffset:
// DWARF 4 defines four *ptr classes, but doesn't
// distinguish them in the encoding. Disambiguate
// these classes using the attribute.
if class, ok := attrPtrClass[attr]; ok {
return class
}
return ClassUnknown
case formExprloc:
return ClassExprLoc
case formGnuRefAlt:
return ClassReferenceAlt
case formGnuStrpAlt:
return ClassStringAlt
case formLoclistx:
return ClassLocList
case formRnglistx:
return ClassRngList
}
}
// An entry is a sequence of attribute/value pairs.
type Entry struct {
Offset Offset // offset of Entry in DWARF info
Tag Tag // tag (kind of Entry)
Children bool // whether Entry is followed by children
Field []Field
}
// A Field is a single attribute/value pair in an [Entry].
//
// A value can be one of several "attribute classes" defined by DWARF.
// The Go types corresponding to each class are:
//
// DWARF class Go type Class
// ----------- ------- -----
// address uint64 ClassAddress
// block []byte ClassBlock
// constant int64 ClassConstant
// flag bool ClassFlag
// reference
// to info dwarf.Offset ClassReference
// to type unit uint64 ClassReferenceSig
// string string ClassString
// exprloc []byte ClassExprLoc
// lineptr int64 ClassLinePtr
// loclistptr int64 ClassLocListPtr
// macptr int64 ClassMacPtr
// rangelistptr int64 ClassRangeListPtr
//
// For unrecognized or vendor-defined attributes, [Class] may be
// [ClassUnknown].
type Field struct {
Attr Attr
Val any
Class Class
}
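// For example, a caller could check a field's class before asserting the
// dynamic type of its value (illustrative sketch; AttrName is just one
// possible attribute):
//
//	if f := e.AttrField(AttrName); f != nil && f.Class == ClassString {
//		name := f.Val.(string)
//		_ = name
//	}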
// A Class is the DWARF 4 class of an attribute value.
//
// In general, a given attribute's value may take on one of several
// possible classes defined by DWARF, each of which leads to a
// slightly different interpretation of the attribute.
//
// DWARF version 4 distinguishes attribute value classes more finely
// than previous versions of DWARF. The reader will disambiguate
// coarser classes from earlier versions of DWARF into the appropriate
// DWARF 4 class. For example, DWARF 2 uses "constant" for constants
// as well as all types of section offsets, but the reader will
// canonicalize attributes in DWARF 2 files that refer to section
// offsets to one of the Class*Ptr classes, even though these classes
// were only defined in DWARF 3.
type Class int
const (
// ClassUnknown represents values of unknown DWARF class.
ClassUnknown Class = iota
// ClassAddress represents values of type uint64 that are
// addresses on the target machine.
ClassAddress
// ClassBlock represents values of type []byte whose
// interpretation depends on the attribute.
ClassBlock
// ClassConstant represents values of type int64 that are
// constants. The interpretation of this constant depends on
// the attribute.
ClassConstant
// ClassExprLoc represents values of type []byte that contain
// an encoded DWARF expression or location description.
ClassExprLoc
// ClassFlag represents values of type bool.
ClassFlag
// ClassLinePtr represents values that are an int64 offset
// into the "line" section.
ClassLinePtr
// ClassLocListPtr represents values that are an int64 offset
// into the "loclist" section.
ClassLocListPtr
// ClassMacPtr represents values that are an int64 offset into
// the "mac" section.
ClassMacPtr
// ClassRangeListPtr represents values that are an int64 offset into
// the "rangelist" section.
ClassRangeListPtr
// ClassReference represents values that are an Offset offset
// of an Entry in the info section (for use with Reader.Seek).
// The DWARF specification combines ClassReference and
// ClassReferenceSig into class "reference".
ClassReference
// ClassReferenceSig represents values that are a uint64 type
// signature referencing a type Entry.
ClassReferenceSig
// ClassString represents values that are strings. If the
// compilation unit specifies the AttrUseUTF8 flag (strongly
// recommended), the string value will be encoded in UTF-8.
// Otherwise, the encoding is unspecified.
ClassString
// ClassReferenceAlt represents values of type int64 that are
// an offset into the DWARF "info" section of an alternate
// object file.
ClassReferenceAlt
// ClassStringAlt represents values of type int64 that are an
// offset into the DWARF string section of an alternate object
// file.
ClassStringAlt
// ClassAddrPtr represents values that are an int64 offset
// into the "addr" section.
ClassAddrPtr
// ClassLocList represents values that are an int64 offset
// into the "loclists" section.
ClassLocList
// ClassRngList represents values that are a uint64 offset
// from the base of the "rnglists" section.
ClassRngList
// ClassRngListsPtr represents values that are an int64 offset
// into the "rnglists" section. These are used as the base for
// ClassRngList values.
ClassRngListsPtr
// ClassStrOffsetsPtr represents values that are an int64
// offset into the "str_offsets" section.
ClassStrOffsetsPtr
)
//go:generate stringer -type=Class
func (i Class) GoString() string {
return "dwarf." + i.String()
}
// Val returns the value associated with attribute [Attr] in [Entry],
// or nil if there is no such attribute.
//
// A common idiom is to merge the check for nil return with
// the check that the value has the expected dynamic type, as in:
//
// v, ok := e.Val(AttrSibling).(int64)
func (e *Entry) Val(a Attr) any {
if f := e.AttrField(a); f != nil {
return f.Val
}
return nil
}
// AttrField returns the [Field] associated with attribute [Attr] in
// [Entry], or nil if there is no such attribute.
func (e *Entry) AttrField(a Attr) *Field {
for i, f := range e.Field {
if f.Attr == a {
return &e.Field[i]
}
}
return nil
}
// An Offset represents the location of an [Entry] within the DWARF info.
// (See [Reader.Seek].)
type Offset uint32
// Entry reads a single entry from buf, decoding
// according to the given abbreviation table.
func (b *buf) entry(cu *Entry, u *unit) *Entry {
atab, ubase, vers := u.atable, u.base, u.vers
off := b.off
id := uint32(b.uint())
if id == 0 {
return &Entry{}
}
a, ok := atab[id]
if !ok {
b.error("unknown abbreviation table index")
return nil
}
e := &Entry{
Offset: off,
Tag: a.tag,
Children: a.children,
Field: make([]Field, len(a.field)),
}
resolveStrx := func(strBase, off uint64) string {
off += strBase
if uint64(int(off)) != off {
b.error("DW_FORM_strx offset out of range")
}
b1 := makeBuf(b.dwarf, b.format, "str_offsets", 0, b.dwarf.strOffsets)
b1.skip(int(off))
is64, _ := b.format.dwarf64()
if is64 {
off = b1.uint64()
} else {
off = uint64(b1.uint32())
}
if b1.err != nil {
b.err = b1.err
return ""
}
if uint64(int(off)) != off {
b.error("DW_FORM_strx indirect offset out of range")
}
b1 = makeBuf(b.dwarf, b.format, "str", 0, b.dwarf.str)
b1.skip(int(off))
val := b1.string()
if b1.err != nil {
b.err = b1.err
}
return val
}
resolveRnglistx := func(rnglistsBase, off uint64) uint64 {
is64, _ := b.format.dwarf64()
if is64 {
off *= 8
} else {
off *= 4
}
off += rnglistsBase
if uint64(int(off)) != off {
b.error("DW_FORM_rnglistx offset out of range")
}
b1 := makeBuf(b.dwarf, b.format, "rnglists", 0, b.dwarf.rngLists)
b1.skip(int(off))
if is64 {
off = b1.uint64()
} else {
off = uint64(b1.uint32())
}
if b1.err != nil {
b.err = b1.err
return 0
}
if uint64(int(off)) != off {
b.error("DW_FORM_rnglistx indirect offset out of range")
}
return rnglistsBase + off
}
for i := range e.Field {
e.Field[i].Attr = a.field[i].attr
e.Field[i].Class = a.field[i].class
fmt := a.field[i].fmt
if fmt == formIndirect {
fmt = format(b.uint())
e.Field[i].Class = formToClass(fmt, a.field[i].attr, vers, b)
}
var val any
switch fmt {
default:
b.error("unknown entry attr format 0x" + strconv.FormatInt(int64(fmt), 16))
// address
case formAddr:
val = b.addr()
case formAddrx, formAddrx1, formAddrx2, formAddrx3, formAddrx4:
var off uint64
switch fmt {
case formAddrx:
off = b.uint()
case formAddrx1:
off = uint64(b.uint8())
case formAddrx2:
off = uint64(b.uint16())
case formAddrx3:
off = uint64(b.uint24())
case formAddrx4:
off = uint64(b.uint32())
}
if b.dwarf.addr == nil {
b.error("DW_FORM_addrx with no .debug_addr section")
}
if b.err != nil {
return nil
}
addrBase := int64(u.addrBase())
var err error
val, err = b.dwarf.debugAddr(b.format, uint64(addrBase), off)
if err != nil {
if b.err == nil {
b.err = err
}
return nil
}
// block
case formDwarfBlock1:
val = b.bytes(int(b.uint8()))
case formDwarfBlock2:
val = b.bytes(int(b.uint16()))
case formDwarfBlock4:
val = b.bytes(int(b.uint32()))
case formDwarfBlock:
val = b.bytes(int(b.uint()))
// constant
case formData1:
val = int64(b.uint8())
case formData2:
val = int64(b.uint16())
case formData4:
val = int64(b.uint32())
case formData8:
val = int64(b.uint64())
case formData16:
val = b.bytes(16)
case formSdata:
val = b.int()
case formUdata:
val = int64(b.uint())
case formImplicitConst:
val = a.field[i].val
// flag
case formFlag:
val = b.uint8() == 1
// New in DWARF 4.
case formFlagPresent:
// The attribute is implicitly indicated as present, and no value is
// encoded in the debugging information entry itself.
val = true
// reference to other entry
case formRefAddr:
vers := b.format.version()
if vers == 0 {
b.error("unknown version for DW_FORM_ref_addr")
} else if vers == 2 {
val = Offset(b.addr())
} else {
is64, known := b.format.dwarf64()
if !known {
b.error("unknown size for DW_FORM_ref_addr")
} else if is64 {
val = Offset(b.uint64())
} else {
val = Offset(b.uint32())
}
}
case formRef1:
val = Offset(b.uint8()) + ubase
case formRef2:
val = Offset(b.uint16()) + ubase
case formRef4:
val = Offset(b.uint32()) + ubase
case formRef8:
val = Offset(b.uint64()) + ubase
case formRefUdata:
val = Offset(b.uint()) + ubase
// string
case formString:
val = b.string()
case formStrp, formLineStrp:
var off uint64 // offset into .debug_str
is64, known := b.format.dwarf64()
if !known {
b.error("unknown size for DW_FORM_strp/line_strp")
} else if is64 {
off = b.uint64()
} else {
off = uint64(b.uint32())
}
if uint64(int(off)) != off {
b.error("DW_FORM_strp/line_strp offset out of range")
}
if b.err != nil {
return nil
}
var b1 buf
if fmt == formStrp {
b1 = makeBuf(b.dwarf, b.format, "str", 0, b.dwarf.str)
} else {
if len(b.dwarf.lineStr) == 0 {
b.error("DW_FORM_line_strp with no .debug_line_str section")
return nil
}
b1 = makeBuf(b.dwarf, b.format, "line_str", 0, b.dwarf.lineStr)
}
b1.skip(int(off))
val = b1.string()
if b1.err != nil {
b.err = b1.err
return nil
}
case formStrx, formStrx1, formStrx2, formStrx3, formStrx4:
var off uint64
switch fmt {
case formStrx:
off = b.uint()
case formStrx1:
off = uint64(b.uint8())
case formStrx2:
off = uint64(b.uint16())
case formStrx3:
off = uint64(b.uint24())
case formStrx4:
off = uint64(b.uint32())
}
if len(b.dwarf.strOffsets) == 0 {
b.error("DW_FORM_strx with no .debug_str_offsets section")
}
is64, known := b.format.dwarf64()
if !known {
b.error("unknown offset size for DW_FORM_strx")
}
if b.err != nil {
return nil
}
if is64 {
off *= 8
} else {
off *= 4
}
strBase := int64(u.strOffsetsBase())
val = resolveStrx(uint64(strBase), off)
case formStrpSup:
is64, known := b.format.dwarf64()
if !known {
b.error("unknown size for DW_FORM_strp_sup")
} else if is64 {
val = b.uint64()
} else {
val = b.uint32()
}
// lineptr, loclistptr, macptr, rangelistptr
// New in DWARF 4, but clang can generate them with -gdwarf-2.
// Section reference, replacing use of formData4 and formData8.
case formSecOffset, formGnuRefAlt, formGnuStrpAlt:
is64, known := b.format.dwarf64()
if !known {
b.error("unknown size for form 0x" + strconv.FormatInt(int64(fmt), 16))
} else if is64 {
val = int64(b.uint64())
} else {
val = int64(b.uint32())
}
// exprloc
// New in DWARF 4.
case formExprloc:
val = b.bytes(int(b.uint()))
// reference
// New in DWARF 4.
case formRefSig8:
// 64-bit type signature.
val = b.uint64()
case formRefSup4:
val = b.uint32()
case formRefSup8:
val = b.uint64()
// loclist
case formLoclistx:
val = b.uint()
// rnglist
case formRnglistx:
off := b.uint()
rnglistsBase := int64(u.rngListsBase())
val = resolveRnglistx(uint64(rnglistsBase), off)
}
e.Field[i].Val = val
}
if b.err != nil {
return nil
}
return e
}
// A Reader allows reading [Entry] structures from a DWARF “info” section.
// The [Entry] structures are arranged in a tree. The [Reader.Next] function
// returns successive entries from a pre-order traversal of the tree.
// If an entry has children, its Children field will be true, and the children
// follow, terminated by an [Entry] with [Tag] 0.
type Reader struct {
b buf
d *Data
err error
unit int
lastUnit bool // set if last entry returned by Next is TagCompileUnit/TagPartialUnit
lastChildren bool // .Children of last entry returned by Next
lastSibling Offset // .Val(AttrSibling) of last entry returned by Next
cu *Entry // current compilation unit
}
// Reader returns a new Reader for [Data].
// The reader is positioned at byte offset 0 in the DWARF “info” section.
func (d *Data) Reader() *Reader {
r := &Reader{d: d}
r.Seek(0)
return r
}
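// As a rough usage sketch (assumes a *Data obtained elsewhere, e.g. from
// (*elf.File).DWARF()), walking every entry looks like:
//
//	r := d.Reader()
//	for {
//		e, err := r.Next()
//		if err != nil {
//			return err
//		}
//		if e == nil {
//			break // end of the info section
//		}
//		// An entry with e.Tag == 0 marks the end of a list of children.
//	}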
// AddressSize returns the size in bytes of addresses in the current compilation
// unit.
func (r *Reader) AddressSize() int {
return r.d.unit[r.unit].asize
}
// ByteOrder returns the byte order in the current compilation unit.
func (r *Reader) ByteOrder() binary.ByteOrder {
return r.b.order
}
// Seek positions the [Reader] at offset off in the encoded entry stream.
// Offset 0 can be used to denote the first entry.
func (r *Reader) Seek(off Offset) {
d := r.d
r.err = nil
r.lastChildren = false
if off == 0 {
if len(d.unit) == 0 {
return
}
u := &d.unit[0]
r.unit = 0
r.b = makeBuf(r.d, u, "info", u.off, u.data)
r.collectDwarf5BaseOffsets(u)
r.cu = nil
return
}
i := d.offsetToUnit(off)
if i == -1 {
r.err = errors.New("offset out of range")
return
}
if i != r.unit {
r.cu = nil
}
u := &d.unit[i]
r.unit = i
r.b = makeBuf(r.d, u, "info", off, u.data[off-u.off:])
r.collectDwarf5BaseOffsets(u)
}
// maybeNextUnit advances to the next unit if this one is finished.
func (r *Reader) maybeNextUnit() {
for len(r.b.data) == 0 && r.unit+1 < len(r.d.unit) {
r.nextUnit()
}
}
// nextUnit advances to the next unit.
func (r *Reader) nextUnit() {
r.unit++
u := &r.d.unit[r.unit]
r.b = makeBuf(r.d, u, "info", u.off, u.data)
r.cu = nil
r.collectDwarf5BaseOffsets(u)
}
func (r *Reader) collectDwarf5BaseOffsets(u *unit) {
if u.vers < 5 || u.unit5 != nil {
return
}
u.unit5 = new(unit5)
if err := r.d.collectDwarf5BaseOffsets(u); err != nil {
r.err = err
}
}
// Next reads the next entry from the encoded entry stream.
// It returns nil, nil when it reaches the end of the section.
// It returns an error if the current offset is invalid or the data at the
// offset cannot be decoded as a valid [Entry].
func (r *Reader) Next() (*Entry, error) {
if r.err != nil {
return nil, r.err
}
r.maybeNextUnit()
if len(r.b.data) == 0 {
return nil, nil
}
u := &r.d.unit[r.unit]
e := r.b.entry(r.cu, u)
if r.b.err != nil {
r.err = r.b.err
return nil, r.err
}
r.lastUnit = false
if e != nil {
r.lastChildren = e.Children
if r.lastChildren {
r.lastSibling, _ = e.Val(AttrSibling).(Offset)
}
if e.Tag == TagCompileUnit || e.Tag == TagPartialUnit {
r.lastUnit = true
r.cu = e
}
} else {
r.lastChildren = false
}
return e, nil
}
// SkipChildren skips over the child entries associated with
// the last [Entry] returned by [Reader.Next]. If that [Entry] did not have
// children or [Reader.Next] has not been called, SkipChildren is a no-op.
func (r *Reader) SkipChildren() {
if r.err != nil || !r.lastChildren {
return
}
// If the last entry had a sibling attribute,
// that attribute gives the offset of the next
// sibling, so we can avoid decoding the
// child subtrees.
if r.lastSibling >= r.b.off {
r.Seek(r.lastSibling)
return
}
if r.lastUnit && r.unit+1 < len(r.d.unit) {
r.nextUnit()
return
}
for {
e, err := r.Next()
if err != nil || e == nil || e.Tag == 0 {
break
}
if e.Children {
r.SkipChildren()
}
}
}
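// For example, to visit only each compilation unit and its immediate
// children, a caller can skip every deeper subtree (illustrative sketch):
//
//	for {
//		e, err := r.Next()
//		if err != nil || e == nil {
//			break
//		}
//		if e.Tag != TagCompileUnit {
//			r.SkipChildren() // don't descend below the unit's direct children
//		}
//	}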
// clone returns a copy of the reader. This is used by the typeReader
// interface.
func (r *Reader) clone() typeReader {
return r.d.Reader()
}
// offset returns the current buffer offset. This is used by the
// typeReader interface.
func (r *Reader) offset() Offset {
return r.b.off
}
// SeekPC returns the [Entry] for the compilation unit that includes pc,
// and positions the reader to read the children of that unit. If pc
// is not covered by any unit, SeekPC returns [ErrUnknownPC] and the
// position of the reader is undefined.
//
// Because compilation units can describe multiple regions of the
// executable, in the worst case SeekPC must search through all the
// ranges in all the compilation units. Each call to SeekPC starts the
// search at the compilation unit of the last call, so in general
// looking up a series of PCs will be faster if they are sorted. If
// the caller wishes to do repeated fast PC lookups, it should build
// an appropriate index using the Ranges method.
func (r *Reader) SeekPC(pc uint64) (*Entry, error) {
unit := r.unit
for i := 0; i < len(r.d.unit); i++ {
if unit >= len(r.d.unit) {
unit = 0
}
r.err = nil
r.lastChildren = false
r.unit = unit
r.cu = nil
u := &r.d.unit[unit]
r.b = makeBuf(r.d, u, "info", u.off, u.data)
r.collectDwarf5BaseOffsets(u)
e, err := r.Next()
if err != nil {
return nil, err
}
if e == nil || e.Tag == 0 {
return nil, ErrUnknownPC
}
ranges, err := r.d.Ranges(e)
if err != nil {
return nil, err
}
for _, pcs := range ranges {
if pcs[0] <= pc && pc < pcs[1] {
return e, nil
}
}
unit++
}
return nil, ErrUnknownPC
}
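// A short usage sketch (pc is an assumed address of interest):
//
//	cu, err := r.SeekPC(pc)
//	if err != nil {
//		return err // ErrUnknownPC if no unit covers pc
//	}
//	// cu is the TagCompileUnit entry; r is positioned at its children.
//	_ = cu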
// Ranges returns the PC ranges covered by e, a slice of [low,high) pairs.
// Only some entry types, such as [TagCompileUnit] or [TagSubprogram], have PC
// ranges; for others, this will return nil with no error.
func (d *Data) Ranges(e *Entry) ([][2]uint64, error) {
var ret [][2]uint64
low, lowOK := e.Val(AttrLowpc).(uint64)
var high uint64
var highOK bool
highField := e.AttrField(AttrHighpc)
if highField != nil {
switch highField.Class {
case ClassAddress:
high, highOK = highField.Val.(uint64)
case ClassConstant:
off, ok := highField.Val.(int64)
if ok {
high = low + uint64(off)
highOK = true
}
}
}
if lowOK && highOK {
ret = append(ret, [2]uint64{low, high})
}
var u *unit
if uidx := d.offsetToUnit(e.Offset); uidx >= 0 && uidx < len(d.unit) {
u = &d.unit[uidx]
}
if u != nil && u.vers >= 5 && d.rngLists != nil {
// DWARF version 5 and later
field := e.AttrField(AttrRanges)
if field == nil {
return ret, nil
}
switch field.Class {
case ClassRangeListPtr:
ranges, rangesOK := field.Val.(int64)
if !rangesOK {
return ret, nil
}
cu, base, err := d.baseAddressForEntry(e)
if err != nil {
return nil, err
}
return d.dwarf5Ranges(u, cu, base, ranges, ret)
case ClassRngList:
rnglist, ok := field.Val.(uint64)
if !ok {
return ret, nil
}
cu, base, err := d.baseAddressForEntry(e)
if err != nil {
return nil, err
}
return d.dwarf5Ranges(u, cu, base, int64(rnglist), ret)
default:
return ret, nil
}
}
// DWARF version 2 through 4
ranges, rangesOK := e.Val(AttrRanges).(int64)
if rangesOK && d.ranges != nil {
_, base, err := d.baseAddressForEntry(e)
if err != nil {
return nil, err
}
return d.dwarf2Ranges(u, base, ranges, ret)
}
return ret, nil
}
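// For instance, given an entry e for a subprogram, a caller could test
// whether a PC falls inside it (sketch; e and pc are assumed inputs):
//
//	rngs, err := d.Ranges(e)
//	if err != nil {
//		return err
//	}
//	for _, rng := range rngs {
//		if rng[0] <= pc && pc < rng[1] {
//			// pc lies within e
//		}
//	}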
// baseAddressForEntry returns the initial base address to be used when
// looking up the range list of entry e.
// DWARF specifies that this should be the lowpc attribute of the enclosing
// compilation unit, however comments in gdb/dwarf2read.c say that some
// versions of GCC use the entrypc attribute, so we check that too.
func (d *Data) baseAddressForEntry(e *Entry) (*Entry, uint64, error) {
var cu *Entry
if e.Tag == TagCompileUnit {
cu = e
} else {
i := d.offsetToUnit(e.Offset)
if i == -1 {
return nil, 0, errors.New("no unit for entry")
}
u := &d.unit[i]
b := makeBuf(d, u, "info", u.off, u.data)
cu = b.entry(nil, u)
if b.err != nil {
return nil, 0, b.err
}
}
if cuEntry, cuEntryOK := cu.Val(AttrEntrypc).(uint64); cuEntryOK {
return cu, cuEntry, nil
} else if cuLow, cuLowOK := cu.Val(AttrLowpc).(uint64); cuLowOK {
return cu, cuLow, nil
}
return cu, 0, nil
}
func (d *Data) dwarf2Ranges(u *unit, base uint64, ranges int64, ret [][2]uint64) ([][2]uint64, error) {
if ranges < 0 || ranges > int64(len(d.ranges)) {
return nil, fmt.Errorf("invalid range offset %d (max %d)", ranges, len(d.ranges))
}
buf := makeBuf(d, u, "ranges", Offset(ranges), d.ranges[ranges:])
for len(buf.data) > 0 {
low := buf.addr()
high := buf.addr()
if low == 0 && high == 0 {
break
}
if low == ^uint64(0)>>uint((8-u.addrsize())*8) {
base = high
} else {
ret = append(ret, [2]uint64{base + low, base + high})
}
}
return ret, nil
}
// dwarf5Ranges interprets a debug_rnglists sequence, see DWARFv5 section
// 2.17.3 (page 53).
func (d *Data) dwarf5Ranges(u *unit, cu *Entry, base uint64, ranges int64, ret [][2]uint64) ([][2]uint64, error) {
if ranges < 0 || ranges > int64(len(d.rngLists)) {
return nil, fmt.Errorf("invalid rnglist offset %d (max %d)", ranges, len(d.ranges))
}
var addrBase int64
if cu != nil {
addrBase, _ = cu.Val(AttrAddrBase).(int64)
}
buf := makeBuf(d, u, "rnglists", 0, d.rngLists)
buf.skip(int(ranges))
for {
opcode := buf.uint8()
switch opcode {
case rleEndOfList:
if buf.err != nil {
return nil, buf.err
}
return ret, nil
case rleBaseAddressx:
baseIdx := buf.uint()
var err error
base, err = d.debugAddr(u, uint64(addrBase), baseIdx)
if err != nil {
return nil, err
}
case rleStartxEndx:
startIdx := buf.uint()
endIdx := buf.uint()
start, err := d.debugAddr(u, uint64(addrBase), startIdx)
if err != nil {
return nil, err
}
end, err := d.debugAddr(u, uint64(addrBase), endIdx)
if err != nil {
return nil, err
}
ret = append(ret, [2]uint64{start, end})
case rleStartxLength:
startIdx := buf.uint()
len := buf.uint()
start, err := d.debugAddr(u, uint64(addrBase), startIdx)
if err != nil {
return nil, err
}
ret = append(ret, [2]uint64{start, start + len})
case rleOffsetPair:
off1 := buf.uint()
off2 := buf.uint()
ret = append(ret, [2]uint64{base + off1, base + off2})
case rleBaseAddress:
base = buf.addr()
case rleStartEnd:
start := buf.addr()
end := buf.addr()
ret = append(ret, [2]uint64{start, end})
case rleStartLength:
start := buf.addr()
len := buf.uint()
ret = append(ret, [2]uint64{start, start + len})
}
}
}
// debugAddr returns the address at idx in debug_addr
func (d *Data) debugAddr(format dataFormat, addrBase, idx uint64) (uint64, error) {
off := idx*uint64(format.addrsize()) + addrBase
if uint64(int(off)) != off {
return 0, errors.New("offset out of range")
}
b := makeBuf(d, format, "addr", 0, d.addr)
b.skip(int(off))
val := b.addr()
if b.err != nil {
return 0, b.err
}
return val, nil
}
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package dwarf
import (
"errors"
"fmt"
"io"
"path"
"strings"
)
// A LineReader reads a sequence of [LineEntry] structures from a DWARF
// "line" section for a single compilation unit. LineEntries occur in
// order of increasing PC and each [LineEntry] gives metadata for the
// instructions from that [LineEntry]'s PC to just before the next
// [LineEntry]'s PC. The last entry will have the [LineEntry.EndSequence] field set.
type LineReader struct {
buf buf
// Original .debug_line section data. Used by Seek.
section []byte
str []byte // .debug_str
lineStr []byte // .debug_line_str
// Header information
version uint16
addrsize int
segmentSelectorSize int
minInstructionLength int
maxOpsPerInstruction int
defaultIsStmt bool
lineBase int
lineRange int
opcodeBase int
opcodeLengths []int
directories []string
fileEntries []*LineFile
programOffset Offset // section offset of line number program
endOffset Offset // section offset of byte following program
initialFileEntries int // initial length of fileEntries
// Current line number program state machine registers
state LineEntry // public state
fileIndex int // private state
}
// A LineEntry is a row in a DWARF line table.
type LineEntry struct {
// Address is the program-counter value of a machine
// instruction generated by the compiler. This LineEntry
// applies to each instruction from Address to just before the
// Address of the next LineEntry.
Address uint64
// OpIndex is the index of an operation within a VLIW
// instruction. The index of the first operation is 0. For
// non-VLIW architectures, it will always be 0. Address and
// OpIndex together form an operation pointer that can
// reference any individual operation within the instruction
// stream.
OpIndex int
// File is the source file corresponding to these
// instructions.
File *LineFile
// Line is the source code line number corresponding to these
// instructions. Lines are numbered beginning at 1. It may be
// 0 if these instructions cannot be attributed to any source
// line.
Line int
// Column is the column number within the source line of these
// instructions. Columns are numbered beginning at 1. It may
// be 0 to indicate the "left edge" of the line.
Column int
// IsStmt indicates that Address is a recommended breakpoint
// location, such as the beginning of a line, statement, or a
// distinct subpart of a statement.
IsStmt bool
// BasicBlock indicates that Address is the beginning of a
// basic block.
BasicBlock bool
// PrologueEnd indicates that Address is one (of possibly
// many) PCs where execution should be suspended for a
// breakpoint on entry to the containing function.
//
// Added in DWARF 3.
PrologueEnd bool
// EpilogueBegin indicates that Address is one (of possibly
// many) PCs where execution should be suspended for a
// breakpoint on exit from this function.
//
// Added in DWARF 3.
EpilogueBegin bool
// ISA is the instruction set architecture for these
// instructions. Possible ISA values should be defined by the
// applicable ABI specification.
//
// Added in DWARF 3.
ISA int
// Discriminator is an arbitrary integer indicating the block
// to which these instructions belong. It serves to
// distinguish among multiple blocks that may all be associated
// with the same source file, line, and column. Where only one
// block exists for a given source position, it should be 0.
//
// Added in DWARF 4.
Discriminator int
// EndSequence indicates that Address is the first byte after
// the end of a sequence of target machine instructions. If it
// is set, only this and the Address field are meaningful. A
// line number table may contain information for multiple
// potentially disjoint instruction sequences. The last entry
// in a line table should always have EndSequence set.
EndSequence bool
}
// A LineFile is a source file referenced by a DWARF line table entry.
type LineFile struct {
Name string
Mtime uint64 // Implementation defined modification time, or 0 if unknown
Length int // File length, or 0 if unknown
}
// LineReader returns a new reader for the line table of compilation
// unit cu, which must be an [Entry] with tag [TagCompileUnit].
//
// If this compilation unit has no line table, it returns nil, nil.
func (d *Data) LineReader(cu *Entry) (*LineReader, error) {
if d.line == nil {
// No line tables available.
return nil, nil
}
// Get line table information from cu.
off, ok := cu.Val(AttrStmtList).(int64)
if !ok {
// cu has no line table.
return nil, nil
}
if off < 0 || off > int64(len(d.line)) {
return nil, errors.New("AttrStmtList value out of range")
}
// AttrCompDir is optional if all file names are absolute. Use
// the empty string if it's not present.
compDir, _ := cu.Val(AttrCompDir).(string)
// Create the LineReader.
u := &d.unit[d.offsetToUnit(cu.Offset)]
buf := makeBuf(d, u, "line", Offset(off), d.line[off:])
// The compilation directory is implicitly directories[0].
r := LineReader{
buf: buf,
section: d.line,
str: d.str,
lineStr: d.lineStr,
}
// Read the header.
if err := r.readHeader(compDir); err != nil {
return nil, err
}
// Initialize line reader state.
r.Reset()
return &r, nil
}
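// A minimal usage sketch (cu is assumed to be a TagCompileUnit entry read
// from this Data's Reader):
//
//	lr, err := d.LineReader(cu)
//	if err != nil {
//		return err
//	}
//	if lr == nil {
//		return nil // this unit has no line table
//	}
//	var le LineEntry
//	for {
//		if err := lr.Next(&le); err != nil {
//			if err == io.EOF {
//				break
//			}
//			return err
//		}
//		// le.Address, le.File and le.Line describe one row.
//	}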
// readHeader reads the line number program header from r.buf and sets
// all of the header fields in r.
func (r *LineReader) readHeader(compDir string) error {
buf := &r.buf
// Read basic header fields [DWARF2 6.2.4].
hdrOffset := buf.off
unitLength, dwarf64 := buf.unitLength()
r.endOffset = buf.off + unitLength
if r.endOffset > buf.off+Offset(len(buf.data)) {
return DecodeError{"line", hdrOffset, fmt.Sprintf("line table end %d exceeds section size %d", r.endOffset, buf.off+Offset(len(buf.data)))}
}
r.version = buf.uint16()
if buf.err == nil && (r.version < 2 || r.version > 5) {
// DWARF goes to all this effort to make new opcodes
// backward-compatible, and then adds fields right in
// the middle of the header in new versions, so we're
// picky about only supporting known line table
// versions.
return DecodeError{"line", hdrOffset, fmt.Sprintf("unknown line table version %d", r.version)}
}
if r.version >= 5 {
r.addrsize = int(buf.uint8())
r.segmentSelectorSize = int(buf.uint8())
} else {
r.addrsize = buf.format.addrsize()
r.segmentSelectorSize = 0
}
var headerLength Offset
if dwarf64 {
headerLength = Offset(buf.uint64())
} else {
headerLength = Offset(buf.uint32())
}
programOffset := buf.off + headerLength
if programOffset > r.endOffset {
return DecodeError{"line", hdrOffset, fmt.Sprintf("malformed line table: program offset %d exceeds end offset %d", programOffset, r.endOffset)}
}
r.programOffset = programOffset
r.minInstructionLength = int(buf.uint8())
if r.version >= 4 {
// [DWARF4 6.2.4]
r.maxOpsPerInstruction = int(buf.uint8())
} else {
r.maxOpsPerInstruction = 1
}
r.defaultIsStmt = buf.uint8() != 0
r.lineBase = int(int8(buf.uint8()))
r.lineRange = int(buf.uint8())
// Validate header.
if buf.err != nil {
return buf.err
}
if r.maxOpsPerInstruction == 0 {
return DecodeError{"line", hdrOffset, "invalid maximum operations per instruction: 0"}
}
if r.lineRange == 0 {
return DecodeError{"line", hdrOffset, "invalid line range: 0"}
}
// Read standard opcode length table. This table starts with opcode 1.
r.opcodeBase = int(buf.uint8())
r.opcodeLengths = make([]int, r.opcodeBase)
for i := 1; i < r.opcodeBase; i++ {
r.opcodeLengths[i] = int(buf.uint8())
}
// Validate opcode lengths.
if buf.err != nil {
return buf.err
}
for i, length := range r.opcodeLengths {
if known, ok := knownOpcodeLengths[i]; ok && known != length {
return DecodeError{"line", hdrOffset, fmt.Sprintf("opcode %d expected to have length %d, but has length %d", i, known, length)}
}
}
if r.version < 5 {
// Read include directories table.
r.directories = []string{compDir}
for {
directory := buf.string()
if buf.err != nil {
return buf.err
}
if len(directory) == 0 {
break
}
if !pathIsAbs(directory) {
// Relative paths are implicitly relative to
// the compilation directory.
directory = pathJoin(compDir, directory)
}
r.directories = append(r.directories, directory)
}
// Read file name list. File numbering starts with 1,
// so leave the first entry nil.
r.fileEntries = make([]*LineFile, 1)
for {
if done, err := r.readFileEntry(); err != nil {
return err
} else if done {
break
}
}
} else {
dirFormat := r.readLNCTFormat()
c := buf.uint()
r.directories = make([]string, c)
for i := range r.directories {
dir, _, _, err := r.readLNCT(dirFormat, dwarf64)
if err != nil {
return err
}
r.directories[i] = dir
}
fileFormat := r.readLNCTFormat()
c = buf.uint()
r.fileEntries = make([]*LineFile, c)
for i := range r.fileEntries {
name, mtime, size, err := r.readLNCT(fileFormat, dwarf64)
if err != nil {
return err
}
r.fileEntries[i] = &LineFile{name, mtime, int(size)}
}
}
r.initialFileEntries = len(r.fileEntries)
return buf.err
}
// lnctForm is a pair of an LNCT code and a form. This represents an
// entry in the directory name or file name description in the DWARF 5
// line number program header.
type lnctForm struct {
lnct int
form format
}
// readLNCTFormat reads an LNCT format description.
func (r *LineReader) readLNCTFormat() []lnctForm {
c := r.buf.uint8()
ret := make([]lnctForm, c)
for i := range ret {
ret[i].lnct = int(r.buf.uint())
ret[i].form = format(r.buf.uint())
}
return ret
}
// readLNCT reads a sequence of LNCT entries and returns path information.
func (r *LineReader) readLNCT(s []lnctForm, dwarf64 bool) (path string, mtime uint64, size uint64, err error) {
var dir string
for _, lf := range s {
var str string
var val uint64
switch lf.form {
case formString:
str = r.buf.string()
case formStrp, formLineStrp:
var off uint64
if dwarf64 {
off = r.buf.uint64()
} else {
off = uint64(r.buf.uint32())
}
if uint64(int(off)) != off {
return "", 0, 0, DecodeError{"line", r.buf.off, "strp/line_strp offset out of range"}
}
var b1 buf
if lf.form == formStrp {
b1 = makeBuf(r.buf.dwarf, r.buf.format, "str", 0, r.str)
} else {
b1 = makeBuf(r.buf.dwarf, r.buf.format, "line_str", 0, r.lineStr)
}
b1.skip(int(off))
str = b1.string()
if b1.err != nil {
return "", 0, 0, DecodeError{"line", r.buf.off, b1.err.Error()}
}
case formStrpSup:
// Supplemental sections not yet supported.
if dwarf64 {
r.buf.uint64()
} else {
r.buf.uint32()
}
case formStrx:
// .debug_line.dwo sections not yet supported.
r.buf.uint()
case formStrx1:
r.buf.uint8()
case formStrx2:
r.buf.uint16()
case formStrx3:
r.buf.uint24()
case formStrx4:
r.buf.uint32()
case formData1:
val = uint64(r.buf.uint8())
case formData2:
val = uint64(r.buf.uint16())
case formData4:
val = uint64(r.buf.uint32())
case formData8:
val = r.buf.uint64()
case formData16:
r.buf.bytes(16)
case formDwarfBlock:
r.buf.bytes(int(r.buf.uint()))
case formUdata:
val = r.buf.uint()
}
switch lf.lnct {
case lnctPath:
path = str
case lnctDirectoryIndex:
if val >= uint64(len(r.directories)) {
return "", 0, 0, DecodeError{"line", r.buf.off, "directory index out of range"}
}
dir = r.directories[val]
case lnctTimestamp:
mtime = val
case lnctSize:
size = val
case lnctMD5:
// Ignored.
}
}
if dir != "" && path != "" {
path = pathJoin(dir, path)
}
return path, mtime, size, nil
}
// readFileEntry reads a file entry from either the header or a
// DW_LNE_define_file extended opcode and adds it to r.fileEntries. A
// true return value indicates that there are no more entries to read.
func (r *LineReader) readFileEntry() (bool, error) {
name := r.buf.string()
if r.buf.err != nil {
return false, r.buf.err
}
if len(name) == 0 {
return true, nil
}
off := r.buf.off
dirIndex := int(r.buf.uint())
if !pathIsAbs(name) {
if dirIndex >= len(r.directories) {
return false, DecodeError{"line", off, "directory index too large"}
}
name = pathJoin(r.directories[dirIndex], name)
}
mtime := r.buf.uint()
length := int(r.buf.uint())
// If this is a dynamically added path and the cursor was
// backed up, we may have already added this entry. Avoid
// updating existing line table entries in this case. This
// avoids an allocation and potential racy access to the slice
// backing store if the user called Files.
if len(r.fileEntries) < cap(r.fileEntries) {
fe := r.fileEntries[:len(r.fileEntries)+1]
if fe[len(fe)-1] != nil {
// We already processed this addition.
r.fileEntries = fe
return false, nil
}
}
r.fileEntries = append(r.fileEntries, &LineFile{name, mtime, length})
return false, nil
}
// updateFile updates r.state.File after r.fileIndex has
// changed or r.fileEntries has changed.
func (r *LineReader) updateFile() {
if r.fileIndex < len(r.fileEntries) {
r.state.File = r.fileEntries[r.fileIndex]
} else {
r.state.File = nil
}
}
// Next sets *entry to the next row in this line table and moves to
// the next row. If there are no more entries and the line table is
// properly terminated, it returns [io.EOF].
//
// Rows are always in order of increasing entry.Address, but
// entry.Line may go forward or backward.
func (r *LineReader) Next(entry *LineEntry) error {
if r.buf.err != nil {
return r.buf.err
}
// Execute opcodes until we reach an opcode that emits a line
// table entry.
for {
if len(r.buf.data) == 0 {
return io.EOF
}
emit := r.step(entry)
if r.buf.err != nil {
return r.buf.err
}
if emit {
return nil
}
}
}
// knownOpcodeLengths gives the opcode lengths (in varint arguments)
// of known standard opcodes.
var knownOpcodeLengths = map[int]int{
lnsCopy: 0,
lnsAdvancePC: 1,
lnsAdvanceLine: 1,
lnsSetFile: 1,
lnsNegateStmt: 0,
lnsSetBasicBlock: 0,
lnsConstAddPC: 0,
lnsSetPrologueEnd: 0,
lnsSetEpilogueBegin: 0,
lnsSetISA: 1,
// lnsFixedAdvancePC takes a uint8 rather than a varint; it's
// unclear what length the header is supposed to claim, so
// ignore it.
}
// step processes the next opcode and updates r.state. If the opcode
// emits a row in the line table, this updates *entry and returns
// true.
func (r *LineReader) step(entry *LineEntry) bool {
opcode := int(r.buf.uint8())
if opcode >= r.opcodeBase {
// Special opcode [DWARF2 6.2.5.1, DWARF4 6.2.5.1]
adjustedOpcode := opcode - r.opcodeBase
r.advancePC(adjustedOpcode / r.lineRange)
lineDelta := r.lineBase + adjustedOpcode%r.lineRange
r.state.Line += lineDelta
goto emit
}
switch opcode {
case 0:
// Extended opcode [DWARF2 6.2.5.3]
length := Offset(r.buf.uint())
startOff := r.buf.off
opcode := r.buf.uint8()
switch opcode {
case lneEndSequence:
r.state.EndSequence = true
*entry = r.state
r.resetState()
case lneSetAddress:
switch r.addrsize {
case 1:
r.state.Address = uint64(r.buf.uint8())
case 2:
r.state.Address = uint64(r.buf.uint16())
case 4:
r.state.Address = uint64(r.buf.uint32())
case 8:
r.state.Address = r.buf.uint64()
default:
r.buf.error("unknown address size")
}
case lneDefineFile:
if done, err := r.readFileEntry(); err != nil {
r.buf.err = err
return false
} else if done {
r.buf.err = DecodeError{"line", startOff, "malformed DW_LNE_define_file operation"}
return false
}
r.updateFile()
case lneSetDiscriminator:
// [DWARF4 6.2.5.3]
r.state.Discriminator = int(r.buf.uint())
}
r.buf.skip(int(startOff + length - r.buf.off))
if opcode == lneEndSequence {
return true
}
// Standard opcodes [DWARF2 6.2.5.2]
case lnsCopy:
goto emit
case lnsAdvancePC:
r.advancePC(int(r.buf.uint()))
case lnsAdvanceLine:
r.state.Line += int(r.buf.int())
case lnsSetFile:
r.fileIndex = int(r.buf.uint())
r.updateFile()
case lnsSetColumn:
r.state.Column = int(r.buf.uint())
case lnsNegateStmt:
r.state.IsStmt = !r.state.IsStmt
case lnsSetBasicBlock:
r.state.BasicBlock = true
case lnsConstAddPC:
r.advancePC((255 - r.opcodeBase) / r.lineRange)
case lnsFixedAdvancePC:
r.state.Address += uint64(r.buf.uint16())
// DWARF3 standard opcodes [DWARF3 6.2.5.2]
case lnsSetPrologueEnd:
r.state.PrologueEnd = true
case lnsSetEpilogueBegin:
r.state.EpilogueBegin = true
case lnsSetISA:
r.state.ISA = int(r.buf.uint())
default:
// Unhandled standard opcode. Skip the number of
// arguments that the prologue says this opcode has.
for i := 0; i < r.opcodeLengths[opcode]; i++ {
r.buf.uint()
}
}
return false
emit:
*entry = r.state
r.state.BasicBlock = false
r.state.PrologueEnd = false
r.state.EpilogueBegin = false
r.state.Discriminator = 0
return true
}
// advancePC advances "operation pointer" (the combination of Address
// and OpIndex) in r.state by opAdvance steps.
func (r *LineReader) advancePC(opAdvance int) {
opIndex := r.state.OpIndex + opAdvance
r.state.Address += uint64(r.minInstructionLength * (opIndex / r.maxOpsPerInstruction))
r.state.OpIndex = opIndex % r.maxOpsPerInstruction
}
// A LineReaderPos represents a position in a line table.
type LineReaderPos struct {
// off is the current offset in the DWARF line section.
off Offset
// numFileEntries is the length of fileEntries.
numFileEntries int
// state and fileIndex are the statement machine state at
// offset off.
state LineEntry
fileIndex int
}
// Tell returns the current position in the line table.
func (r *LineReader) Tell() LineReaderPos {
return LineReaderPos{r.buf.off, len(r.fileEntries), r.state, r.fileIndex}
}
// Seek restores the line table reader to a position returned by [LineReader.Tell].
//
// The argument pos must have been returned by a call to [LineReader.Tell] on this
// line table.
func (r *LineReader) Seek(pos LineReaderPos) {
r.buf.off = pos.off
r.buf.data = r.section[r.buf.off:r.endOffset]
r.fileEntries = r.fileEntries[:pos.numFileEntries]
r.state = pos.state
r.fileIndex = pos.fileIndex
}
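// Tell and Seek together let a caller back up after reading too far
// (sketch; wantPC is an assumed address of interest):
//
//	pos := lr.Tell()
//	var le LineEntry
//	if err := lr.Next(&le); err == nil && le.Address > wantPC {
//		lr.Seek(pos) // rewind so the row can be re-read later
//	}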
// Reset repositions the line table reader at the beginning of the
// line table.
func (r *LineReader) Reset() {
// Reset buffer to the line number program offset.
r.buf.off = r.programOffset
r.buf.data = r.section[r.buf.off:r.endOffset]
// Reset file entries list.
r.fileEntries = r.fileEntries[:r.initialFileEntries]
// Reset line number program state.
r.resetState()
}
// resetState resets r.state to its default values
func (r *LineReader) resetState() {
// Reset the state machine registers to the defaults given in
// [DWARF4 6.2.2].
r.state = LineEntry{
Address: 0,
OpIndex: 0,
File: nil,
Line: 1,
Column: 0,
IsStmt: r.defaultIsStmt,
BasicBlock: false,
PrologueEnd: false,
EpilogueBegin: false,
ISA: 0,
Discriminator: 0,
}
r.fileIndex = 1
r.updateFile()
}
// Files returns the file name table of this compilation unit as of
// the current position in the line table. The file name table may be
// referenced from attributes in this compilation unit such as
// [AttrDeclFile].
//
// Entry 0 is always nil, since file index 0 represents "no file".
//
// The file name table of a compilation unit is not fixed. Files
// returns the file table as of the current position in the line
// table. This may contain more entries than the file table at an
// earlier position in the line table, though existing entries never
// change.
func (r *LineReader) Files() []*LineFile {
return r.fileEntries
}
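// For example, an AttrDeclFile attribute elsewhere in this compilation unit
// indexes into this table (sketch; e is an assumed Entry and lr this reader):
//
//	if idx, ok := e.Val(AttrDeclFile).(int64); ok {
//		files := lr.Files()
//		if idx >= 0 && idx < int64(len(files)) && files[idx] != nil {
//			declName := files[idx].Name
//			_ = declName
//		}
//	}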
// ErrUnknownPC is the error returned by LineReader.SeekPC when the
// seek PC is not covered by any entry in the line table.
var ErrUnknownPC = errors.New("ErrUnknownPC")
// SeekPC sets *entry to the [LineEntry] that includes pc and positions
// the reader on the next entry in the line table. If necessary, this
// will seek backwards to find pc.
//
// If pc is not covered by any entry in this line table, SeekPC
// returns [ErrUnknownPC]. In this case, *entry and the final seek
// position are unspecified.
//
// Note that DWARF line tables only permit sequential, forward scans.
// Hence, in the worst case, this takes time linear in the size of the
// line table. If the caller wishes to do repeated fast PC lookups, it
// should build an appropriate index of the line table.
func (r *LineReader) SeekPC(pc uint64, entry *LineEntry) error {
if err := r.Next(entry); err != nil {
return err
}
if entry.Address > pc {
// We're too far. Start at the beginning of the table.
r.Reset()
if err := r.Next(entry); err != nil {
return err
}
if entry.Address > pc {
// The whole table starts after pc.
r.Reset()
return ErrUnknownPC
}
}
// Scan until we pass pc, then back up one.
for {
var next LineEntry
pos := r.Tell()
if err := r.Next(&next); err != nil {
if err == io.EOF {
return ErrUnknownPC
}
return err
}
if next.Address > pc {
if entry.EndSequence {
// pc is in a hole in the table.
return ErrUnknownPC
}
// entry is the desired entry. Back up the
// cursor to "next" and return success.
r.Seek(pos)
return nil
}
*entry = next
}
}
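// A short sketch of mapping a PC to a source position (lr is a LineReader
// obtained from Data.LineReader; pc is the address of interest):
//
//	var le LineEntry
//	switch err := lr.SeekPC(pc, &le); err {
//	case nil:
//		// le.File.Name and le.Line locate pc in the source.
//	case ErrUnknownPC:
//		// pc is not covered by this unit's line table.
//	default:
//		return err
//	}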
// pathIsAbs reports whether path is an absolute path (or "full path
// name" in DWARF parlance). This is in "whatever form makes sense for
// the host system", so this accepts both UNIX-style and DOS-style
// absolute paths. We avoid the filepath package because we want this
// to behave the same regardless of our host system and because we
// don't know what system the paths came from.
func pathIsAbs(path string) bool {
_, path = splitDrive(path)
return len(path) > 0 && (path[0] == '/' || path[0] == '\\')
}
// pathJoin joins dirname and filename. filename must be relative.
// DWARF paths can be UNIX-style or DOS-style, so this handles both.
func pathJoin(dirname, filename string) string {
if len(dirname) == 0 {
return filename
}
// dirname should be absolute, which means we can determine
// whether it's a DOS path reasonably reliably by looking for
// a drive letter or UNC path.
drive, dirname := splitDrive(dirname)
if drive == "" {
// UNIX-style path.
return path.Join(dirname, filename)
}
// DOS-style path.
drive2, filename := splitDrive(filename)
if drive2 != "" {
if !strings.EqualFold(drive, drive2) {
// Different drives. There's not much we can
// do here, so just ignore the directory.
return drive2 + filename
}
// Drives are the same. Ignore drive on filename.
}
if !(strings.HasSuffix(dirname, "/") || strings.HasSuffix(dirname, `\`)) && dirname != "" {
sep := `\`
if strings.HasPrefix(dirname, "/") {
sep = `/`
}
dirname += sep
}
return drive + dirname + filename
}
// splitDrive splits the DOS drive letter or UNC share point from
// path, if any. path == drive + rest
func splitDrive(path string) (drive, rest string) {
if len(path) >= 2 && path[1] == ':' {
if c := path[0]; 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z' {
return path[:2], path[2:]
}
}
if len(path) > 3 && (path[0] == '\\' || path[0] == '/') && (path[1] == '\\' || path[1] == '/') {
// Normalize the path so we can search for just \ below.
npath := strings.ReplaceAll(path, "/", `\`)
// Get the host part, which must be non-empty.
slash1 := strings.IndexByte(npath[2:], '\\') + 2
if slash1 > 2 {
// Get the mount-point part, which must be non-empty.
slash2 := strings.IndexByte(npath[slash1+1:], '\\') + slash1 + 1
if slash2 > slash1 {
return path[:slash2], path[slash2:]
}
}
}
return "", path
}
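// Illustrative expectations for these helpers (not exhaustive):
//
//	pathJoin(`C:\dir`, `file.c`)        // `C:\dir\file.c`
//	pathJoin("/usr/include", "stdio.h") // "/usr/include/stdio.h"
//	splitDrive(`\\host\share\x`)        // `\\host\share`, `\x`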
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
/*
Package dwarf provides access to DWARF debugging information loaded from
executable files, as defined in the DWARF 2.0 Standard at
http://dwarfstd.org/doc/dwarf-2.0.0.pdf.
# Security
This package is not designed to be hardened against adversarial inputs, and is
outside the scope of https://go.dev/security/policy. In particular, only basic
validation is done when parsing object files. As such, care should be taken when
parsing untrusted inputs, as parsing malformed files may consume significant
resources, or cause panics.
*/
package dwarf
import (
"encoding/binary"
"errors"
)
// Data represents the DWARF debugging information
// loaded from an executable file (for example, an ELF or Mach-O executable).
type Data struct {
// raw data
abbrev []byte
aranges []byte
frame []byte
info []byte
line []byte
pubnames []byte
ranges []byte
str []byte
// New sections added in DWARF 5.
addr []byte
lineStr []byte
strOffsets []byte
rngLists []byte
// parsed data
abbrevCache map[uint64]abbrevTable
bigEndian bool
order binary.ByteOrder
typeCache map[Offset]Type
typeSigs map[uint64]*typeUnit
unit []unit
}
var errSegmentSelector = errors.New("non-zero segment_selector size not supported")
// New returns a new [Data] object initialized from the given parameters.
// Rather than calling this function directly, clients should typically use
// the DWARF method of the File type of the appropriate package [debug/elf],
// [debug/macho], or [debug/pe].
//
// The []byte arguments are the data from the corresponding debug section
// in the object file; for example, for an ELF object, abbrev is the contents of
// the ".debug_abbrev" section.
func New(abbrev, aranges, frame, info, line, pubnames, ranges, str []byte) (*Data, error) {
d := &Data{
abbrev: abbrev,
aranges: aranges,
frame: frame,
info: info,
line: line,
pubnames: pubnames,
ranges: ranges,
str: str,
abbrevCache: make(map[uint64]abbrevTable),
typeCache: make(map[Offset]Type),
typeSigs: make(map[uint64]*typeUnit),
}
// Sniff .debug_info to figure out byte order.
// 32-bit DWARF: 4 byte length, 2 byte version.
// 64-bit DWARF: 4 bytes of 0xff, 8 byte length, 2 byte version.
if len(d.info) < 6 {
return nil, DecodeError{"info", Offset(len(d.info)), "too short"}
}
offset := 4
if d.info[0] == 0xff && d.info[1] == 0xff && d.info[2] == 0xff && d.info[3] == 0xff {
if len(d.info) < 14 {
return nil, DecodeError{"info", Offset(len(d.info)), "too short"}
}
offset = 12
}
// Fetch the version, a tiny 16-bit number (1, 2, 3, 4, 5).
x, y := d.info[offset], d.info[offset+1]
switch {
case x == 0 && y == 0:
return nil, DecodeError{"info", 4, "unsupported version 0"}
case x == 0:
d.bigEndian = true
d.order = binary.BigEndian
case y == 0:
d.bigEndian = false
d.order = binary.LittleEndian
default:
return nil, DecodeError{"info", 4, "cannot determine byte order"}
}
u, err := d.parseUnits()
if err != nil {
return nil, err
}
d.unit = u
return d, nil
}
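// In practice, callers usually obtain a *Data through the object-file
// packages rather than by calling New directly; a sketch using debug/elf
// (debug/macho and debug/pe expose the same DWARF method):
//
//	f, err := elf.Open("a.out")
//	if err != nil {
//		return err
//	}
//	defer f.Close()
//	d, err := f.DWARF()
//	if err != nil {
//		return err
//	}
//	_ = d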
// AddTypes will add one .debug_types section to the DWARF data. A
// typical object with DWARF version 4 debug info will have multiple
// .debug_types sections. The name is used for error reporting only,
// and serves to distinguish one .debug_types section from another.
func (d *Data) AddTypes(name string, types []byte) error {
return d.parseTypes(name, types)
}
// AddSection adds another DWARF section by name. The name should be a
// DWARF section name such as ".debug_addr", ".debug_str_offsets", and
// so forth. This approach is used for new DWARF sections added in
// DWARF 5 and later.
func (d *Data) AddSection(name string, contents []byte) error {
var err error
switch name {
case ".debug_addr":
d.addr = contents
case ".debug_line_str":
d.lineStr = contents
case ".debug_str_offsets":
d.strOffsets = contents
case ".debug_rnglists":
d.rngLists = contents
}
// Just ignore names that we don't yet support.
return err
}
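// For example, a caller that has already extracted the raw DWARF 5 section
// contents from an object file could register them like this (sketch; the
// *Contents variables are assumed to come from the object-file reader):
//
//	_ = d.AddSection(".debug_addr", addrContents)
//	_ = d.AddSection(".debug_str_offsets", strOffsetsContents)
//	_ = d.AddSection(".debug_rnglists", rngListsContents)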
// Code generated by "stringer -type Tag -trimprefix=Tag"; DO NOT EDIT.
package dwarf
import "strconv"
func _() {
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
var x [1]struct{}
_ = x[TagArrayType-1]
_ = x[TagClassType-2]
_ = x[TagEntryPoint-3]
_ = x[TagEnumerationType-4]
_ = x[TagFormalParameter-5]
_ = x[TagImportedDeclaration-8]
_ = x[TagLabel-10]
_ = x[TagLexDwarfBlock-11]
_ = x[TagMember-13]
_ = x[TagPointerType-15]
_ = x[TagReferenceType-16]
_ = x[TagCompileUnit-17]
_ = x[TagStringType-18]
_ = x[TagStructType-19]
_ = x[TagSubroutineType-21]
_ = x[TagTypedef-22]
_ = x[TagUnionType-23]
_ = x[TagUnspecifiedParameters-24]
_ = x[TagVariant-25]
_ = x[TagCommonDwarfBlock-26]
_ = x[TagCommonInclusion-27]
_ = x[TagInheritance-28]
_ = x[TagInlinedSubroutine-29]
_ = x[TagModule-30]
_ = x[TagPtrToMemberType-31]
_ = x[TagSetType-32]
_ = x[TagSubrangeType-33]
_ = x[TagWithStmt-34]
_ = x[TagAccessDeclaration-35]
_ = x[TagBaseType-36]
_ = x[TagCatchDwarfBlock-37]
_ = x[TagConstType-38]
_ = x[TagConstant-39]
_ = x[TagEnumerator-40]
_ = x[TagFileType-41]
_ = x[TagFriend-42]
_ = x[TagNamelist-43]
_ = x[TagNamelistItem-44]
_ = x[TagPackedType-45]
_ = x[TagSubprogram-46]
_ = x[TagTemplateTypeParameter-47]
_ = x[TagTemplateValueParameter-48]
_ = x[TagThrownType-49]
_ = x[TagTryDwarfBlock-50]
_ = x[TagVariantPart-51]
_ = x[TagVariable-52]
_ = x[TagVolatileType-53]
_ = x[TagDwarfProcedure-54]
_ = x[TagRestrictType-55]
_ = x[TagInterfaceType-56]
_ = x[TagNamespace-57]
_ = x[TagImportedModule-58]
_ = x[TagUnspecifiedType-59]
_ = x[TagPartialUnit-60]
_ = x[TagImportedUnit-61]
_ = x[TagMutableType-62]
_ = x[TagCondition-63]
_ = x[TagSharedType-64]
_ = x[TagTypeUnit-65]
_ = x[TagRvalueReferenceType-66]
_ = x[TagTemplateAlias-67]
_ = x[TagCoarrayType-68]
_ = x[TagGenericSubrange-69]
_ = x[TagDynamicType-70]
_ = x[TagAtomicType-71]
_ = x[TagCallSite-72]
_ = x[TagCallSiteParameter-73]
_ = x[TagSkeletonUnit-74]
_ = x[TagImmutableType-75]
}
const (
_Tag_name_0 = "ArrayTypeClassTypeEntryPointEnumerationTypeFormalParameter"
_Tag_name_1 = "ImportedDeclaration"
_Tag_name_2 = "LabelLexDwarfBlock"
_Tag_name_3 = "Member"
_Tag_name_4 = "PointerTypeReferenceTypeCompileUnitStringTypeStructType"
_Tag_name_5 = "SubroutineTypeTypedefUnionTypeUnspecifiedParametersVariantCommonDwarfBlockCommonInclusionInheritanceInlinedSubroutineModulePtrToMemberTypeSetTypeSubrangeTypeWithStmtAccessDeclarationBaseTypeCatchDwarfBlockConstTypeConstantEnumeratorFileTypeFriendNamelistNamelistItemPackedTypeSubprogramTemplateTypeParameterTemplateValueParameterThrownTypeTryDwarfBlockVariantPartVariableVolatileTypeDwarfProcedureRestrictTypeInterfaceTypeNamespaceImportedModuleUnspecifiedTypePartialUnitImportedUnitMutableTypeConditionSharedTypeTypeUnitRvalueReferenceTypeTemplateAliasCoarrayTypeGenericSubrangeDynamicTypeAtomicTypeCallSiteCallSiteParameterSkeletonUnitImmutableType"
)
var (
_Tag_index_0 = [...]uint8{0, 9, 18, 28, 43, 58}
_Tag_index_2 = [...]uint8{0, 5, 18}
_Tag_index_4 = [...]uint8{0, 11, 24, 35, 45, 55}
_Tag_index_5 = [...]uint16{0, 14, 21, 30, 51, 58, 74, 89, 100, 117, 123, 138, 145, 157, 165, 182, 190, 205, 214, 222, 232, 240, 246, 254, 266, 276, 286, 307, 329, 339, 352, 363, 371, 383, 397, 409, 422, 431, 445, 460, 471, 483, 494, 503, 513, 521, 540, 553, 564, 579, 590, 600, 608, 625, 637, 650}
)
func (i Tag) String() string {
switch {
case 1 <= i && i <= 5:
i -= 1
return _Tag_name_0[_Tag_index_0[i]:_Tag_index_0[i+1]]
case i == 8:
return _Tag_name_1
case 10 <= i && i <= 11:
i -= 10
return _Tag_name_2[_Tag_index_2[i]:_Tag_index_2[i+1]]
case i == 13:
return _Tag_name_3
case 15 <= i && i <= 19:
i -= 15
return _Tag_name_4[_Tag_index_4[i]:_Tag_index_4[i+1]]
case 21 <= i && i <= 75:
i -= 21
return _Tag_name_5[_Tag_index_5[i]:_Tag_index_5[i+1]]
default:
return "Tag(" + strconv.FormatInt(int64(i), 10) + ")"
}
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// DWARF type information structures.
// The format is heavily biased toward C, but for simplicity
// the String methods use a pseudo-Go syntax.
package dwarf
import "strconv"
// A Type conventionally represents a pointer to any of the
// specific Type structures ([CharType], [StructType], etc.).
type Type interface {
Common() *CommonType
String() string
Size() int64
}
// A CommonType holds fields common to multiple types.
// If a field is not known or not applicable for a given type,
// the zero value is used.
type CommonType struct {
ByteSize int64 // size of value of this type, in bytes
Name string // name that can be used to refer to type
}
func (c *CommonType) Common() *CommonType { return c }
func (c *CommonType) Size() int64 { return c.ByteSize }
// Basic types
// A BasicType holds fields common to all basic types.
//
// See the documentation for [StructField] for more info on the interpretation of
// the BitSize/BitOffset/DataBitOffset fields.
type BasicType struct {
CommonType
BitSize int64
BitOffset int64
DataBitOffset int64
}
func (b *BasicType) Basic() *BasicType { return b }
func (t *BasicType) String() string {
if t.Name != "" {
return t.Name
}
return "?"
}
// A CharType represents a signed character type.
type CharType struct {
BasicType
}
// A UcharType represents an unsigned character type.
type UcharType struct {
BasicType
}
// An IntType represents a signed integer type.
type IntType struct {
BasicType
}
// A UintType represents an unsigned integer type.
type UintType struct {
BasicType
}
// A FloatType represents a floating point type.
type FloatType struct {
BasicType
}
// A ComplexType represents a complex floating point type.
type ComplexType struct {
BasicType
}
// A BoolType represents a boolean type.
type BoolType struct {
BasicType
}
// An AddrType represents a machine address type.
type AddrType struct {
BasicType
}
// An UnspecifiedType represents an implicit, unknown, ambiguous or nonexistent type.
type UnspecifiedType struct {
BasicType
}
// qualifiers
// A QualType represents a type that has the C/C++ "const", "restrict", or "volatile" qualifier.
type QualType struct {
CommonType
Qual string
Type Type
}
func (t *QualType) String() string { return t.Qual + " " + t.Type.String() }
func (t *QualType) Size() int64 { return t.Type.Size() }
// An ArrayType represents a fixed size array type.
type ArrayType struct {
CommonType
Type Type
StrideBitSize int64 // if > 0, number of bits to hold each element
Count int64 // if == -1, an incomplete array, like char x[].
}
func (t *ArrayType) String() string {
return "[" + strconv.FormatInt(t.Count, 10) + "]" + t.Type.String()
}
func (t *ArrayType) Size() int64 {
if t.Count == -1 {
return 0
}
return t.Count * t.Type.Size()
}
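// A quick illustration of the Count/Size semantics above (elemType stands in
// for a previously parsed element type that reports a 4-byte size, such as a
// 32-bit *IntType named "int"; this is a sketch, not package API):
//
//	a := &ArrayType{Type: elemType, Count: 3}
//	_ = a.String() // "[3]int"
//	_ = a.Size()   // 12
//
//	open := &ArrayType{Type: elemType, Count: -1} // incomplete, like char x[]
//	_ = open.Size()                               // 0, length unknown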
// A VoidType represents the C void type.
type VoidType struct {
CommonType
}
func (t *VoidType) String() string { return "void" }
// A PtrType represents a pointer type.
type PtrType struct {
CommonType
Type Type
}
func (t *PtrType) String() string { return "*" + t.Type.String() }
// A StructType represents a struct, union, or C++ class type.
type StructType struct {
CommonType
StructName string
Kind string // "struct", "union", or "class".
Field []*StructField
Incomplete bool // if true, struct, union, class is declared but not defined
}
// A StructField represents a field in a struct, union, or C++ class type.
//
// # Bit Fields
//
// The BitSize, BitOffset, and DataBitOffset fields describe the bit
// size and offset of data members declared as bit fields in C/C++
// struct/union/class types.
//
// BitSize is the number of bits in the bit field.
//
// DataBitOffset, if non-zero, is the number of bits from the start of
// the enclosing entity (e.g. containing struct/class/union) to the
// start of the bit field. This corresponds to the DW_AT_data_bit_offset
// DWARF attribute that was introduced in DWARF 4.
//
// BitOffset, if non-zero, is the number of bits between the most
// significant bit of the storage unit holding the bit field to the
// most significant bit of the bit field. Here "storage unit" is the
// type name before the bit field (for a field "unsigned x:17", the
// storage unit is "unsigned"). BitOffset values can vary depending on
// the endianness of the system. BitOffset corresponds to the
// DW_AT_bit_offset DWARF attribute that was deprecated in DWARF 4 and
// removed in DWARF 5.
//
// At most one of DataBitOffset and BitOffset will be non-zero;
// DataBitOffset/BitOffset will only be non-zero if BitSize is
// non-zero. Whether a C compiler uses one or the other
// will depend on compiler vintage and command line options.
//
// Here is an example of C/C++ bit field use, along with what to
// expect in terms of DWARF bit offset info. Consider this code:
//
// struct S {
// int q;
// int j:5;
// int k:6;
// int m:5;
// int n:8;
// } s;
//
// For the code above, one would expect to see the following for
// DW_AT_bit_offset values (using GCC 8):
//
// Little | Big
// Endian | Endian
// |
// "j": 27 | 0
// "k": 21 | 5
// "m": 16 | 11
// "n": 8 | 16
//
// Note that in the above the offsets are purely with respect to the
// containing storage unit for j/k/m/n -- these values won't vary based
// on the size of prior data members in the containing struct.
//
// If the compiler emits DW_AT_data_bit_offset, the expected values
// would be:
//
// "j": 32
// "k": 37
// "m": 43
// "n": 48
//
// Here the value 32 for "j" reflects the fact that the bit field is
// preceded by other data members (recall that DW_AT_data_bit_offset
// values are relative to the start of the containing struct). Hence
// DW_AT_data_bit_offset values can be quite large for structs with
// many fields.
//
// DWARF also allows for the possibility of base types that have
// non-zero bit size and bit offset, so this information is also
// captured for base types, but it is worth noting that it is not
// possible to trigger this behavior using mainstream languages.
type StructField struct {
Name string
Type Type
ByteOffset int64
ByteSize int64 // usually zero; use Type.Size() for normal fields
BitOffset int64
DataBitOffset int64
BitSize int64 // zero if not a bit field
}
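// A hedged sketch of reading the bit-field layout described above from a
// parsed type (st is assumed to be a *StructType obtained from (*Data).Type
// via a type assertion; remember that at most one of DataBitOffset and
// BitOffset is set per field):
//
//	for _, f := range st.Field {
//		if f.BitSize == 0 {
//			fmt.Printf("%s %s at byte %d\n", f.Name, f.Type, f.ByteOffset)
//			continue
//		}
//		if f.DataBitOffset != 0 {
//			fmt.Printf("%s: %d-bit field at bit %d from the start of the struct\n",
//				f.Name, f.BitSize, f.DataBitOffset)
//		} else {
//			fmt.Printf("%s: %d-bit field, BitOffset %d within the storage unit at byte %d\n",
//				f.Name, f.BitSize, f.BitOffset, f.ByteOffset)
//		}
//	}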
func (t *StructType) String() string {
if t.StructName != "" {
return t.Kind + " " + t.StructName
}
return t.Defn()
}
func (f *StructField) bitOffset() int64 {
if f.BitOffset != 0 {
return f.BitOffset
}
return f.DataBitOffset
}
func (t *StructType) Defn() string {
s := t.Kind
if t.StructName != "" {
s += " " + t.StructName
}
if t.Incomplete {
s += " /*incomplete*/"
return s
}
s += " {"
for i, f := range t.Field {
if i > 0 {
s += "; "
}
s += f.Name + " " + f.Type.String()
s += "@" + strconv.FormatInt(f.ByteOffset, 10)
if f.BitSize > 0 {
s += " : " + strconv.FormatInt(f.BitSize, 10)
s += "@" + strconv.FormatInt(f.bitOffset(), 10)
}
}
s += "}"
return s
}
// An EnumType represents an enumerated type.
// The only indication of its native integer type is its ByteSize
// (inside [CommonType]).
type EnumType struct {
CommonType
EnumName string
Val []*EnumValue
}
// An EnumValue represents a single enumeration value.
type EnumValue struct {
Name string
Val int64
}
func (t *EnumType) String() string {
s := "enum"
if t.EnumName != "" {
s += " " + t.EnumName
}
s += " {"
for i, v := range t.Val {
if i > 0 {
s += "; "
}
s += v.Name + "=" + strconv.FormatInt(v.Val, 10)
}
s += "}"
return s
}
// A FuncType represents a function type.
type FuncType struct {
CommonType
ReturnType Type
ParamType []Type
}
func (t *FuncType) String() string {
s := "func("
for i, t := range t.ParamType {
if i > 0 {
s += ", "
}
s += t.String()
}
s += ")"
if t.ReturnType != nil {
s += " " + t.ReturnType.String()
}
return s
}
// A DotDotDotType represents the variadic ... function parameter.
type DotDotDotType struct {
CommonType
}
func (t *DotDotDotType) String() string { return "..." }
// A TypedefType represents a named type.
type TypedefType struct {
CommonType
Type Type
}
func (t *TypedefType) String() string { return t.Name }
func (t *TypedefType) Size() int64 { return t.Type.Size() }
// An UnsupportedType is a placeholder returned in situations where we
// encounter a type that isn't supported.
type UnsupportedType struct {
CommonType
Tag Tag
}
func (t *UnsupportedType) String() string {
if t.Name != "" {
return t.Name
}
return t.Name + "(unsupported type " + t.Tag.String() + ")"
}
// typeReader is used to read from either the info section or the
// types section.
type typeReader interface {
Seek(Offset)
Next() (*Entry, error)
clone() typeReader
offset() Offset
// AddressSize returns the size in bytes of addresses in the current
// compilation unit.
AddressSize() int
}
// Type reads the type at off in the DWARF “info” section.
func (d *Data) Type(off Offset) (Type, error) {
return d.readType("info", d.Reader(), off, d.typeCache, nil)
}
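// A usage sketch for Type (d is a *Data, typically obtained through the DWARF
// accessors in debug/elf, debug/macho, or debug/pe): walk the info entries
// with a Reader and resolve any AttrType reference encountered.
//
//	r := d.Reader()
//	for {
//		e, err := r.Next()
//		if err != nil || e == nil {
//			break
//		}
//		if off, ok := e.Val(AttrType).(Offset); ok {
//			if t, err := d.Type(off); err == nil {
//				fmt.Printf("%v -> %s\n", e.Tag, t)
//			}
//		}
//	}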
type typeFixer struct {
typedefs []*TypedefType
arraytypes []*Type
}
func (tf *typeFixer) recordArrayType(t *Type) {
if t == nil {
return
}
_, ok := (*t).(*ArrayType)
if ok {
tf.arraytypes = append(tf.arraytypes, t)
}
}
func (tf *typeFixer) apply() {
for _, t := range tf.typedefs {
t.Common().ByteSize = t.Type.Size()
}
for _, t := range tf.arraytypes {
zeroArray(t)
}
}
// readType reads a type from r at offset off within the section called name.
// It adds new types to typeCache, records typedef and zero-length-array
// fixups in fixups, and computes the sizes of types. Callers should pass nil
// for fixups; a non-nil value is only used for internal recursion.
func (d *Data) readType(name string, r typeReader, off Offset, typeCache map[Offset]Type, fixups *typeFixer) (Type, error) {
if t, ok := typeCache[off]; ok {
return t, nil
}
r.Seek(off)
e, err := r.Next()
if err != nil {
return nil, err
}
addressSize := r.AddressSize()
if e == nil || e.Offset != off {
return nil, DecodeError{name, off, "no type at offset"}
}
// If this is the root of the recursion, prepare to resolve
// typedef sizes and perform other fixups once the recursion is
// done. This must be done after the type graph is constructed
// because it may need to resolve cycles in a different order than
// readType encounters them.
if fixups == nil {
var fixer typeFixer
defer func() {
fixer.apply()
}()
fixups = &fixer
}
// Parse type from Entry.
// Must always set typeCache[off] before calling
// d.readType recursively, to handle circular types correctly.
var typ Type
nextDepth := 0
// Get next child; set err if error happens.
next := func() *Entry {
if !e.Children {
return nil
}
// Only return direct children.
// Skip over composite entries that happen to be nested
// inside this one. Most DWARF generators wouldn't generate
// such a thing, but clang does.
// See golang.org/issue/6472.
for {
kid, err1 := r.Next()
if err1 != nil {
err = err1
return nil
}
if kid == nil {
err = DecodeError{name, r.offset(), "unexpected end of DWARF entries"}
return nil
}
if kid.Tag == 0 {
if nextDepth > 0 {
nextDepth--
continue
}
return nil
}
if kid.Children {
nextDepth++
}
if nextDepth > 0 {
continue
}
return kid
}
}
// Get Type referred to by Entry's AttrType field.
// Set err if error happens. Not having a type is an error.
typeOf := func(e *Entry) Type {
tval := e.Val(AttrType)
var t Type
switch toff := tval.(type) {
case Offset:
if t, err = d.readType(name, r.clone(), toff, typeCache, fixups); err != nil {
return nil
}
case uint64:
if t, err = d.sigToType(toff); err != nil {
return nil
}
default:
// It appears that no Type means "void".
return new(VoidType)
}
return t
}
switch e.Tag {
case TagArrayType:
// Multi-dimensional array. (DWARF v2 §5.4)
// Attributes:
// AttrType:subtype [required]
// AttrStrideSize: size in bits of each element of the array
// AttrByteSize: size of entire array
// Children:
// TagSubrangeType or TagEnumerationType giving one dimension.
// dimensions are in left to right order.
t := new(ArrayType)
typ = t
typeCache[off] = t
if t.Type = typeOf(e); err != nil {
goto Error
}
t.StrideBitSize, _ = e.Val(AttrStrideSize).(int64)
// Accumulate dimensions.
var dims []int64
for kid := next(); kid != nil; kid = next() {
// TODO(rsc): Can also be TagEnumerationType
// but haven't seen that in the wild yet.
switch kid.Tag {
case TagSubrangeType:
count, ok := kid.Val(AttrCount).(int64)
if !ok {
// Old binaries may have an upper bound instead.
count, ok = kid.Val(AttrUpperBound).(int64)
if ok {
count++ // Length is one more than upper bound.
} else if len(dims) == 0 {
count = -1 // As in x[].
}
}
dims = append(dims, count)
case TagEnumerationType:
err = DecodeError{name, kid.Offset, "cannot handle enumeration type as array bound"}
goto Error
}
}
if len(dims) == 0 {
// LLVM generates this for x[].
dims = []int64{-1}
}
t.Count = dims[0]
for i := len(dims) - 1; i >= 1; i-- {
t.Type = &ArrayType{Type: t.Type, Count: dims[i]}
}
case TagBaseType:
// Basic type. (DWARF v2 §5.1)
// Attributes:
// AttrName: name of base type in programming language of the compilation unit [required]
// AttrEncoding: encoding value for type (encFloat etc) [required]
// AttrByteSize: size of type in bytes [required]
// AttrBitOffset: bit offset of value within containing storage unit (pre-DWARF 4)
// AttrDataBitOffset: bit offset of value from the start of the containing entity (DWARF 4 and later)
// AttrBitSize: size in bits
//
// For most languages BitOffset/DataBitOffset/BitSize will not be present
// for base types.
name, _ := e.Val(AttrName).(string)
enc, ok := e.Val(AttrEncoding).(int64)
if !ok {
err = DecodeError{name, e.Offset, "missing encoding attribute for " + name}
goto Error
}
switch enc {
default:
err = DecodeError{name, e.Offset, "unrecognized encoding attribute value"}
goto Error
case encAddress:
typ = new(AddrType)
case encBoolean:
typ = new(BoolType)
case encComplexFloat:
typ = new(ComplexType)
if name == "complex" {
// clang writes out 'complex' instead of 'complex float' or 'complex double'.
// clang also writes out a byte size that we can use to distinguish.
// See issue 8694.
switch byteSize, _ := e.Val(AttrByteSize).(int64); byteSize {
case 8:
name = "complex float"
case 16:
name = "complex double"
}
}
case encFloat:
typ = new(FloatType)
case encSigned:
typ = new(IntType)
case encUnsigned:
typ = new(UintType)
case encSignedChar:
typ = new(CharType)
case encUnsignedChar:
typ = new(UcharType)
}
typeCache[off] = typ
t := typ.(interface {
Basic() *BasicType
}).Basic()
t.Name = name
t.BitSize, _ = e.Val(AttrBitSize).(int64)
haveBitOffset := false
haveDataBitOffset := false
t.BitOffset, haveBitOffset = e.Val(AttrBitOffset).(int64)
t.DataBitOffset, haveDataBitOffset = e.Val(AttrDataBitOffset).(int64)
if haveBitOffset && haveDataBitOffset {
err = DecodeError{name, e.Offset, "duplicate bit offset attributes"}
goto Error
}
case TagClassType, TagStructType, TagUnionType:
// Structure, union, or class type. (DWARF v2 §5.5)
// Attributes:
// AttrName: name of struct, union, or class
// AttrByteSize: byte size [required]
// AttrDeclaration: if true, struct/union/class is incomplete
// Children:
// TagMember to describe one member.
// AttrName: name of member [required]
// AttrType: type of member [required]
// AttrByteSize: size in bytes
// AttrBitOffset: bit offset within bytes for bit fields
// AttrDataBitOffset: field bit offset relative to struct start
// AttrBitSize: bit size for bit fields
// AttrDataMemberLoc: location within struct [required for struct, class]
// There is much more needed to handle C++; it is all ignored for now.
t := new(StructType)
typ = t
typeCache[off] = t
switch e.Tag {
case TagClassType:
t.Kind = "class"
case TagStructType:
t.Kind = "struct"
case TagUnionType:
t.Kind = "union"
}
t.StructName, _ = e.Val(AttrName).(string)
t.Incomplete = e.Val(AttrDeclaration) != nil
t.Field = make([]*StructField, 0, 8)
var lastFieldType *Type
var lastFieldBitSize int64
var lastFieldByteOffset int64
for kid := next(); kid != nil; kid = next() {
if kid.Tag != TagMember {
continue
}
f := new(StructField)
if f.Type = typeOf(kid); err != nil {
goto Error
}
switch loc := kid.Val(AttrDataMemberLoc).(type) {
case []byte:
// TODO: Should have original compilation
// unit here, not unknownFormat.
b := makeBuf(d, unknownFormat{}, "location", 0, loc)
if b.uint8() != opPlusUconst {
err = DecodeError{name, kid.Offset, "unexpected opcode"}
goto Error
}
f.ByteOffset = int64(b.uint())
if b.err != nil {
err = b.err
goto Error
}
case int64:
f.ByteOffset = loc
}
f.Name, _ = kid.Val(AttrName).(string)
f.ByteSize, _ = kid.Val(AttrByteSize).(int64)
haveBitOffset := false
haveDataBitOffset := false
f.BitOffset, haveBitOffset = kid.Val(AttrBitOffset).(int64)
f.DataBitOffset, haveDataBitOffset = kid.Val(AttrDataBitOffset).(int64)
if haveBitOffset && haveDataBitOffset {
err = DecodeError{name, e.Offset, "duplicate bit offset attributes"}
goto Error
}
f.BitSize, _ = kid.Val(AttrBitSize).(int64)
t.Field = append(t.Field, f)
if lastFieldBitSize == 0 && lastFieldByteOffset == f.ByteOffset && t.Kind != "union" {
// Last field was zero width. Fix array length.
// (DWARF writes out 0-length arrays as if they were 1-length arrays.)
fixups.recordArrayType(lastFieldType)
}
lastFieldType = &f.Type
lastFieldByteOffset = f.ByteOffset
lastFieldBitSize = f.BitSize
}
if t.Kind != "union" {
b, ok := e.Val(AttrByteSize).(int64)
if ok && b == lastFieldByteOffset {
// Final field must be zero width. Fix array length.
fixups.recordArrayType(lastFieldType)
}
}
case TagConstType, TagVolatileType, TagRestrictType:
// Type modifier (DWARF v2 §5.2)
// Attributes:
// AttrType: subtype
t := new(QualType)
typ = t
typeCache[off] = t
if t.Type = typeOf(e); err != nil {
goto Error
}
switch e.Tag {
case TagConstType:
t.Qual = "const"
case TagRestrictType:
t.Qual = "restrict"
case TagVolatileType:
t.Qual = "volatile"
}
case TagEnumerationType:
// Enumeration type (DWARF v2 §5.6)
// Attributes:
// AttrName: enum name if any
// AttrByteSize: bytes required to represent largest value
// Children:
// TagEnumerator:
// AttrName: name of constant
// AttrConstValue: value of constant
t := new(EnumType)
typ = t
typeCache[off] = t
t.EnumName, _ = e.Val(AttrName).(string)
t.Val = make([]*EnumValue, 0, 8)
for kid := next(); kid != nil; kid = next() {
if kid.Tag == TagEnumerator {
f := new(EnumValue)
f.Name, _ = kid.Val(AttrName).(string)
f.Val, _ = kid.Val(AttrConstValue).(int64)
n := len(t.Val)
if n >= cap(t.Val) {
val := make([]*EnumValue, n, n*2)
copy(val, t.Val)
t.Val = val
}
t.Val = t.Val[0 : n+1]
t.Val[n] = f
}
}
case TagPointerType:
// Type modifier (DWARF v2 §5.2)
// Attributes:
// AttrType: subtype [not required! void* has no AttrType]
// AttrAddrClass: address class [ignored]
t := new(PtrType)
typ = t
typeCache[off] = t
if e.Val(AttrType) == nil {
t.Type = &VoidType{}
break
}
t.Type = typeOf(e)
case TagSubroutineType:
// Subroutine type. (DWARF v2 §5.7)
// Attributes:
// AttrType: type of return value if any
// AttrName: possible name of type [ignored]
// AttrPrototyped: whether an ANSI C prototype was used [ignored]
// Children:
// TagFormalParameter: typed parameter
// AttrType: type of parameter
// TagUnspecifiedParameter: final ...
t := new(FuncType)
typ = t
typeCache[off] = t
if t.ReturnType = typeOf(e); err != nil {
goto Error
}
t.ParamType = make([]Type, 0, 8)
for kid := next(); kid != nil; kid = next() {
var tkid Type
switch kid.Tag {
default:
continue
case TagFormalParameter:
if tkid = typeOf(kid); err != nil {
goto Error
}
case TagUnspecifiedParameters:
tkid = &DotDotDotType{}
}
t.ParamType = append(t.ParamType, tkid)
}
case TagTypedef:
// Typedef (DWARF v2 §5.3)
// Attributes:
// AttrName: name [required]
// AttrType: type definition [required]
t := new(TypedefType)
typ = t
typeCache[off] = t
t.Name, _ = e.Val(AttrName).(string)
t.Type = typeOf(e)
case TagUnspecifiedType:
// Unspecified type (DWARF v3 §5.2)
// Attributes:
// AttrName: name
t := new(UnspecifiedType)
typ = t
typeCache[off] = t
t.Name, _ = e.Val(AttrName).(string)
default:
// This is some other type DIE that we're currently not
// equipped to handle. Return an abstract "unsupported type"
// object in such cases.
t := new(UnsupportedType)
typ = t
typeCache[off] = t
t.Tag = e.Tag
t.Name, _ = e.Val(AttrName).(string)
}
if err != nil {
goto Error
}
{
b, ok := e.Val(AttrByteSize).(int64)
if !ok {
b = -1
switch t := typ.(type) {
case *TypedefType:
// Record that we need to resolve this
// type's size once the type graph is
// constructed.
fixups.typedefs = append(fixups.typedefs, t)
case *PtrType:
b = int64(addressSize)
}
}
typ.Common().ByteSize = b
}
return typ, nil
Error:
// If the parse fails, take the type out of the cache
// so that the next call with this offset doesn't hit
// the cache and return success.
delete(typeCache, off)
return nil, err
}
func zeroArray(t *Type) {
at := (*t).(*ArrayType)
if at.Type.Size() == 0 {
return
}
// Make a copy to avoid invalidating typeCache.
tt := *at
tt.Count = 0
*t = &tt
}
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package dwarf
import (
"fmt"
"strconv"
)
// Parse the type units stored in a DWARF4 .debug_types section. Each
// type unit defines a single primary type and an 8-byte signature.
// Other sections may then use formRefSig8 to refer to the type.
// The typeUnit format is a single type with a signature. It holds
// the same data as a compilation unit.
type typeUnit struct {
unit
toff Offset // Offset to signature type within data.
name string // Name of the .debug_types section.
cache Type // Cache the type, nil to start.
}
// Parse a .debug_types section.
func (d *Data) parseTypes(name string, types []byte) error {
b := makeBuf(d, unknownFormat{}, name, 0, types)
for len(b.data) > 0 {
base := b.off
n, dwarf64 := b.unitLength()
if n != Offset(uint32(n)) {
b.error("type unit length overflow")
return b.err
}
hdroff := b.off
vers := int(b.uint16())
if vers != 4 {
b.error("unsupported DWARF version " + strconv.Itoa(vers))
return b.err
}
var ao uint64
if !dwarf64 {
ao = uint64(b.uint32())
} else {
ao = b.uint64()
}
atable, err := d.parseAbbrev(ao, vers)
if err != nil {
return err
}
asize := b.uint8()
sig := b.uint64()
var toff uint32
if !dwarf64 {
toff = b.uint32()
} else {
to64 := b.uint64()
if to64 != uint64(uint32(to64)) {
b.error("type unit type offset overflow")
return b.err
}
toff = uint32(to64)
}
boff := b.off
d.typeSigs[sig] = &typeUnit{
unit: unit{
base: base,
off: boff,
data: b.bytes(int(n - (b.off - hdroff))),
atable: atable,
asize: int(asize),
vers: vers,
is64: dwarf64,
},
toff: Offset(toff),
name: name,
}
if b.err != nil {
return b.err
}
}
return nil
}
// Return the type for a type signature.
func (d *Data) sigToType(sig uint64) (Type, error) {
tu := d.typeSigs[sig]
if tu == nil {
return nil, fmt.Errorf("no type unit with signature %v", sig)
}
if tu.cache != nil {
return tu.cache, nil
}
b := makeBuf(d, tu, tu.name, tu.off, tu.data)
r := &typeUnitReader{d: d, tu: tu, b: b}
t, err := d.readType(tu.name, r, tu.toff, make(map[Offset]Type), nil)
if err != nil {
return nil, err
}
tu.cache = t
return t, nil
}
// typeUnitReader is a typeReader for a typeUnit.
type typeUnitReader struct {
d *Data
tu *typeUnit
b buf
err error
}
// Seek to a new position in the type unit.
func (tur *typeUnitReader) Seek(off Offset) {
tur.err = nil
doff := off - tur.tu.off
if doff < 0 || doff >= Offset(len(tur.tu.data)) {
tur.err = fmt.Errorf("%s: offset %d out of range; max %d", tur.tu.name, doff, len(tur.tu.data))
return
}
tur.b = makeBuf(tur.d, tur.tu, tur.tu.name, off, tur.tu.data[doff:])
}
// AddressSize returns the size in bytes of addresses in the current type unit.
func (tur *typeUnitReader) AddressSize() int {
return tur.tu.unit.asize
}
// Next reads the next [Entry] from the type unit.
func (tur *typeUnitReader) Next() (*Entry, error) {
if tur.err != nil {
return nil, tur.err
}
if len(tur.tu.data) == 0 {
return nil, nil
}
e := tur.b.entry(nil, &tur.tu.unit)
if tur.b.err != nil {
tur.err = tur.b.err
return nil, tur.err
}
return e, nil
}
// clone returns a new reader for the type unit.
func (tur *typeUnitReader) clone() typeReader {
return &typeUnitReader{
d: tur.d,
tu: tur.tu,
b: makeBuf(tur.d, tur.tu, tur.tu.name, tur.tu.off, tur.tu.data),
}
}
// offset returns the current offset.
func (tur *typeUnitReader) offset() Offset {
return tur.b.off
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package dwarf
import (
"sort"
"strconv"
)
// DWARF debug info is split into a sequence of compilation units.
// Each unit has its own abbreviation table and address size.
type unit struct {
base Offset // byte offset of header within the aggregate info
off Offset // byte offset of data within the aggregate info
data []byte
atable abbrevTable
*unit5 // info specific to DWARF 5 units
asize int
vers int
is64 bool // True for 64-bit DWARF format
utype uint8 // DWARF 5 unit type
}
type unit5 struct {
addrBase uint64
strOffsetsBase uint64
rngListsBase uint64
locListsBase uint64
}
// Implement the dataFormat interface.
func (u *unit) version() int {
return u.vers
}
func (u *unit) dwarf64() (bool, bool) {
return u.is64, true
}
func (u *unit) addrsize() int {
return u.asize
}
func (u *unit) addrBase() uint64 {
if u.unit5 != nil {
return u.unit5.addrBase
}
return 0
}
func (u *unit) strOffsetsBase() uint64 {
if u.unit5 != nil {
return u.unit5.strOffsetsBase
}
return 0
}
func (u *unit) rngListsBase() uint64 {
if u.unit5 != nil {
return u.unit5.rngListsBase
}
return 0
}
func (u *unit) locListsBase() uint64 {
if u.unit5 != nil {
return u.unit5.locListsBase
}
return 0
}
func (d *Data) parseUnits() ([]unit, error) {
// Count units.
nunit := 0
b := makeBuf(d, unknownFormat{}, "info", 0, d.info)
for len(b.data) > 0 {
len, _ := b.unitLength()
if len != Offset(uint32(len)) {
b.error("unit length overflow")
break
}
b.skip(int(len))
if len > 0 {
nunit++
}
}
if b.err != nil {
return nil, b.err
}
// Again, this time writing them down.
b = makeBuf(d, unknownFormat{}, "info", 0, d.info)
units := make([]unit, nunit)
for i := range units {
u := &units[i]
u.base = b.off
var n Offset
if b.err != nil {
return nil, b.err
}
for n == 0 {
n, u.is64 = b.unitLength()
}
dataOff := b.off
vers := b.uint16()
if vers < 2 || vers > 5 {
b.error("unsupported DWARF version " + strconv.Itoa(int(vers)))
break
}
u.vers = int(vers)
if vers >= 5 {
u.utype = b.uint8()
u.asize = int(b.uint8())
}
var abbrevOff uint64
if u.is64 {
abbrevOff = b.uint64()
} else {
abbrevOff = uint64(b.uint32())
}
atable, err := d.parseAbbrev(abbrevOff, u.vers)
if err != nil {
if b.err == nil {
b.err = err
}
break
}
u.atable = atable
if vers < 5 {
u.asize = int(b.uint8())
}
switch u.utype {
case utSkeleton, utSplitCompile:
b.uint64() // unit ID
case utType, utSplitType:
b.uint64() // type signature
if u.is64 { // type offset
b.uint64()
} else {
b.uint32()
}
}
u.off = b.off
u.data = b.bytes(int(n - (b.off - dataOff)))
}
if b.err != nil {
return nil, b.err
}
return units, nil
}
// offsetToUnit returns the index of the unit containing offset off.
// It returns -1 if no unit contains this offset.
func (d *Data) offsetToUnit(off Offset) int {
// Find the unit after off
next := sort.Search(len(d.unit), func(i int) bool {
return d.unit[i].off > off
})
if next == 0 {
return -1
}
u := &d.unit[next-1]
if u.off <= off && off < u.off+Offset(len(u.data)) {
return next - 1
}
return -1
}
func (d *Data) collectDwarf5BaseOffsets(u *unit) error {
if u.unit5 == nil {
panic("expected unit5 to be set up already")
}
b := makeBuf(d, u, "info", u.off, u.data)
cu := b.entry(nil, u)
if cu == nil {
// Unknown abbreviation table entry or some other fatal
// problem; bail early on the assumption that this will be
// detected at some later point.
return b.err
}
if iAddrBase, ok := cu.Val(AttrAddrBase).(int64); ok {
u.unit5.addrBase = uint64(iAddrBase)
}
if iStrOffsetsBase, ok := cu.Val(AttrStrOffsetsBase).(int64); ok {
u.unit5.strOffsetsBase = uint64(iStrOffsetsBase)
}
if iRngListsBase, ok := cu.Val(AttrRnglistsBase).(int64); ok {
u.unit5.rngListsBase = uint64(iRngListsBase)
}
if iLocListsBase, ok := cu.Val(AttrLoclistsBase).(int64); ok {
u.unit5.locListsBase = uint64(iLocListsBase)
}
return nil
}
/*
* ELF constants and data structures
*
* Derived from:
* $FreeBSD: src/sys/sys/elf32.h,v 1.8.14.1 2005/12/30 22:13:58 marcel Exp $
* $FreeBSD: src/sys/sys/elf64.h,v 1.10.14.1 2005/12/30 22:13:58 marcel Exp $
* $FreeBSD: src/sys/sys/elf_common.h,v 1.15.8.1 2005/12/30 22:13:58 marcel Exp $
* $FreeBSD: src/sys/alpha/include/elf.h,v 1.14 2003/09/25 01:10:22 peter Exp $
* $FreeBSD: src/sys/amd64/include/elf.h,v 1.18 2004/08/03 08:21:48 dfr Exp $
* $FreeBSD: src/sys/arm/include/elf.h,v 1.5.2.1 2006/06/30 21:42:52 cognet Exp $
* $FreeBSD: src/sys/i386/include/elf.h,v 1.16 2004/08/02 19:12:17 dfr Exp $
* $FreeBSD: src/sys/powerpc/include/elf.h,v 1.7 2004/11/02 09:47:01 ssouhlal Exp $
* $FreeBSD: src/sys/sparc64/include/elf.h,v 1.12 2003/09/25 01:10:26 peter Exp $
* "System V ABI" (http://www.sco.com/developers/gabi/latest/ch4.eheader.html)
* "ELF for the ARM® 64-bit Architecture (AArch64)" (ARM IHI 0056B)
* "RISC-V ELF psABI specification" (https://github.com/riscv-non-isa/riscv-elf-psabi-doc/blob/master/riscv-elf.adoc)
* llvm/BinaryFormat/ELF.h - ELF constants and structures
*
* Copyright (c) 1996-1998 John D. Polstra. All rights reserved.
* Copyright (c) 2001 David E. O'Brien
* Portions Copyright 2009 The Go Authors. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions
* are met:
* 1. Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* 2. Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
*
* THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
* ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
* OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
* HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
* LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
* OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
* SUCH DAMAGE.
*/
package elf
import "strconv"
/*
* Constants
*/
// Indexes into the Header.Ident array.
const (
EI_CLASS = 4 /* Class of machine. */
EI_DATA = 5 /* Data format. */
EI_VERSION = 6 /* ELF format version. */
EI_OSABI = 7 /* Operating system / ABI identification */
EI_ABIVERSION = 8 /* ABI version */
EI_PAD = 9 /* Start of padding (per SVR4 ABI). */
EI_NIDENT = 16 /* Size of e_ident array. */
)
// Initial magic number for ELF files.
const ELFMAG = "\177ELF"
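// A sketch of how the identification indexes and magic number above are
// typically consumed when inspecting a raw ELF header by hand (ident stands
// in for the first EI_NIDENT bytes of a file; the high-level File API does
// this for you):
//
//	if string(ident[0:4]) != ELFMAG {
//		return errors.New("not an ELF file")
//	}
//	class := Class(ident[EI_CLASS]) // ELFCLASS32 or ELFCLASS64
//	data := Data(ident[EI_DATA])    // byte order
//	osabi := OSABI(ident[EI_OSABI]) // operating system / ABI
//	_, _, _ = class, data, osabi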
// Version is found in Header.Ident[EI_VERSION] and Header.Version.
type Version byte
const (
EV_NONE Version = 0
EV_CURRENT Version = 1
)
var versionStrings = []intName{
{0, "EV_NONE"},
{1, "EV_CURRENT"},
}
func (i Version) String() string { return stringName(uint32(i), versionStrings, false) }
func (i Version) GoString() string { return stringName(uint32(i), versionStrings, true) }
// Class is found in Header.Ident[EI_CLASS] and Header.Class.
type Class byte
const (
ELFCLASSNONE Class = 0 /* Unknown class. */
ELFCLASS32 Class = 1 /* 32-bit architecture. */
ELFCLASS64 Class = 2 /* 64-bit architecture. */
)
var classStrings = []intName{
{0, "ELFCLASSNONE"},
{1, "ELFCLASS32"},
{2, "ELFCLASS64"},
}
func (i Class) String() string { return stringName(uint32(i), classStrings, false) }
func (i Class) GoString() string { return stringName(uint32(i), classStrings, true) }
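// For reference, the String/GoString pair above is expected to behave roughly
// like this (the same stringName helper backs the other enum-like types in
// this file):
//
//	_ = ELFCLASS64.String()   // "ELFCLASS64"
//	_ = ELFCLASS64.GoString() // "elf.ELFCLASS64"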
// Data is found in Header.Ident[EI_DATA] and Header.Data.
type Data byte
const (
ELFDATANONE Data = 0 /* Unknown data format. */
ELFDATA2LSB Data = 1 /* 2's complement little-endian. */
ELFDATA2MSB Data = 2 /* 2's complement big-endian. */
)
var dataStrings = []intName{
{0, "ELFDATANONE"},
{1, "ELFDATA2LSB"},
{2, "ELFDATA2MSB"},
}
func (i Data) String() string { return stringName(uint32(i), dataStrings, false) }
func (i Data) GoString() string { return stringName(uint32(i), dataStrings, true) }
// OSABI is found in Header.Ident[EI_OSABI] and Header.OSABI.
type OSABI byte
const (
ELFOSABI_NONE OSABI = 0 /* UNIX System V ABI */
ELFOSABI_HPUX OSABI = 1 /* HP-UX operating system */
ELFOSABI_NETBSD OSABI = 2 /* NetBSD */
ELFOSABI_LINUX OSABI = 3 /* Linux */
ELFOSABI_HURD OSABI = 4 /* Hurd */
ELFOSABI_86OPEN OSABI = 5 /* 86Open common IA32 ABI */
ELFOSABI_SOLARIS OSABI = 6 /* Solaris */
ELFOSABI_AIX OSABI = 7 /* AIX */
ELFOSABI_IRIX OSABI = 8 /* IRIX */
ELFOSABI_FREEBSD OSABI = 9 /* FreeBSD */
ELFOSABI_TRU64 OSABI = 10 /* TRU64 UNIX */
ELFOSABI_MODESTO OSABI = 11 /* Novell Modesto */
ELFOSABI_OPENBSD OSABI = 12 /* OpenBSD */
ELFOSABI_OPENVMS OSABI = 13 /* Open VMS */
ELFOSABI_NSK OSABI = 14 /* HP Non-Stop Kernel */
ELFOSABI_AROS OSABI = 15 /* Amiga Research OS */
ELFOSABI_FENIXOS OSABI = 16 /* The FenixOS highly scalable multi-core OS */
ELFOSABI_CLOUDABI OSABI = 17 /* Nuxi CloudABI */
ELFOSABI_ARM OSABI = 97 /* ARM */
ELFOSABI_STANDALONE OSABI = 255 /* Standalone (embedded) application */
)
var osabiStrings = []intName{
{0, "ELFOSABI_NONE"},
{1, "ELFOSABI_HPUX"},
{2, "ELFOSABI_NETBSD"},
{3, "ELFOSABI_LINUX"},
{4, "ELFOSABI_HURD"},
{5, "ELFOSABI_86OPEN"},
{6, "ELFOSABI_SOLARIS"},
{7, "ELFOSABI_AIX"},
{8, "ELFOSABI_IRIX"},
{9, "ELFOSABI_FREEBSD"},
{10, "ELFOSABI_TRU64"},
{11, "ELFOSABI_MODESTO"},
{12, "ELFOSABI_OPENBSD"},
{13, "ELFOSABI_OPENVMS"},
{14, "ELFOSABI_NSK"},
{15, "ELFOSABI_AROS"},
{16, "ELFOSABI_FENIXOS"},
{17, "ELFOSABI_CLOUDABI"},
{97, "ELFOSABI_ARM"},
{255, "ELFOSABI_STANDALONE"},
}
func (i OSABI) String() string { return stringName(uint32(i), osabiStrings, false) }
func (i OSABI) GoString() string { return stringName(uint32(i), osabiStrings, true) }
// Type is found in Header.Type.
type Type uint16
const (
ET_NONE Type = 0 /* Unknown type. */
ET_REL Type = 1 /* Relocatable. */
ET_EXEC Type = 2 /* Executable. */
ET_DYN Type = 3 /* Shared object. */
ET_CORE Type = 4 /* Core file. */
ET_LOOS Type = 0xfe00 /* First operating system specific. */
ET_HIOS Type = 0xfeff /* Last operating system-specific. */
ET_LOPROC Type = 0xff00 /* First processor-specific. */
ET_HIPROC Type = 0xffff /* Last processor-specific. */
)
var typeStrings = []intName{
{0, "ET_NONE"},
{1, "ET_REL"},
{2, "ET_EXEC"},
{3, "ET_DYN"},
{4, "ET_CORE"},
{0xfe00, "ET_LOOS"},
{0xfeff, "ET_HIOS"},
{0xff00, "ET_LOPROC"},
{0xffff, "ET_HIPROC"},
}
func (i Type) String() string { return stringName(uint32(i), typeStrings, false) }
func (i Type) GoString() string { return stringName(uint32(i), typeStrings, true) }
// Machine is found in Header.Machine.
type Machine uint16
const (
EM_NONE Machine = 0 /* Unknown machine. */
EM_M32 Machine = 1 /* AT&T WE32100. */
EM_SPARC Machine = 2 /* Sun SPARC. */
EM_386 Machine = 3 /* Intel i386. */
EM_68K Machine = 4 /* Motorola 68000. */
EM_88K Machine = 5 /* Motorola 88000. */
EM_860 Machine = 7 /* Intel i860. */
EM_MIPS Machine = 8 /* MIPS R3000 Big-Endian only. */
EM_S370 Machine = 9 /* IBM System/370. */
EM_MIPS_RS3_LE Machine = 10 /* MIPS R3000 Little-Endian. */
EM_PARISC Machine = 15 /* HP PA-RISC. */
EM_VPP500 Machine = 17 /* Fujitsu VPP500. */
EM_SPARC32PLUS Machine = 18 /* SPARC v8plus. */
EM_960 Machine = 19 /* Intel 80960. */
EM_PPC Machine = 20 /* PowerPC 32-bit. */
EM_PPC64 Machine = 21 /* PowerPC 64-bit. */
EM_S390 Machine = 22 /* IBM System/390. */
EM_V800 Machine = 36 /* NEC V800. */
EM_FR20 Machine = 37 /* Fujitsu FR20. */
EM_RH32 Machine = 38 /* TRW RH-32. */
EM_RCE Machine = 39 /* Motorola RCE. */
EM_ARM Machine = 40 /* ARM. */
EM_SH Machine = 42 /* Hitachi SH. */
EM_SPARCV9 Machine = 43 /* SPARC v9 64-bit. */
EM_TRICORE Machine = 44 /* Siemens TriCore embedded processor. */
EM_ARC Machine = 45 /* Argonaut RISC Core. */
EM_H8_300 Machine = 46 /* Hitachi H8/300. */
EM_H8_300H Machine = 47 /* Hitachi H8/300H. */
EM_H8S Machine = 48 /* Hitachi H8S. */
EM_H8_500 Machine = 49 /* Hitachi H8/500. */
EM_IA_64 Machine = 50 /* Intel IA-64 Processor. */
EM_MIPS_X Machine = 51 /* Stanford MIPS-X. */
EM_COLDFIRE Machine = 52 /* Motorola ColdFire. */
EM_68HC12 Machine = 53 /* Motorola M68HC12. */
EM_MMA Machine = 54 /* Fujitsu MMA. */
EM_PCP Machine = 55 /* Siemens PCP. */
EM_NCPU Machine = 56 /* Sony nCPU. */
EM_NDR1 Machine = 57 /* Denso NDR1 microprocessor. */
EM_STARCORE Machine = 58 /* Motorola Star*Core processor. */
EM_ME16 Machine = 59 /* Toyota ME16 processor. */
EM_ST100 Machine = 60 /* STMicroelectronics ST100 processor. */
EM_TINYJ Machine = 61 /* Advanced Logic Corp. TinyJ processor. */
EM_X86_64 Machine = 62 /* Advanced Micro Devices x86-64 */
EM_PDSP Machine = 63 /* Sony DSP Processor */
EM_PDP10 Machine = 64 /* Digital Equipment Corp. PDP-10 */
EM_PDP11 Machine = 65 /* Digital Equipment Corp. PDP-11 */
EM_FX66 Machine = 66 /* Siemens FX66 microcontroller */
EM_ST9PLUS Machine = 67 /* STMicroelectronics ST9+ 8/16 bit microcontroller */
EM_ST7 Machine = 68 /* STMicroelectronics ST7 8-bit microcontroller */
EM_68HC16 Machine = 69 /* Motorola MC68HC16 Microcontroller */
EM_68HC11 Machine = 70 /* Motorola MC68HC11 Microcontroller */
EM_68HC08 Machine = 71 /* Motorola MC68HC08 Microcontroller */
EM_68HC05 Machine = 72 /* Motorola MC68HC05 Microcontroller */
EM_SVX Machine = 73 /* Silicon Graphics SVx */
EM_ST19 Machine = 74 /* STMicroelectronics ST19 8-bit microcontroller */
EM_VAX Machine = 75 /* Digital VAX */
EM_CRIS Machine = 76 /* Axis Communications 32-bit embedded processor */
EM_JAVELIN Machine = 77 /* Infineon Technologies 32-bit embedded processor */
EM_FIREPATH Machine = 78 /* Element 14 64-bit DSP Processor */
EM_ZSP Machine = 79 /* LSI Logic 16-bit DSP Processor */
EM_MMIX Machine = 80 /* Donald Knuth's educational 64-bit processor */
EM_HUANY Machine = 81 /* Harvard University machine-independent object files */
EM_PRISM Machine = 82 /* SiTera Prism */
EM_AVR Machine = 83 /* Atmel AVR 8-bit microcontroller */
EM_FR30 Machine = 84 /* Fujitsu FR30 */
EM_D10V Machine = 85 /* Mitsubishi D10V */
EM_D30V Machine = 86 /* Mitsubishi D30V */
EM_V850 Machine = 87 /* NEC v850 */
EM_M32R Machine = 88 /* Mitsubishi M32R */
EM_MN10300 Machine = 89 /* Matsushita MN10300 */
EM_MN10200 Machine = 90 /* Matsushita MN10200 */
EM_PJ Machine = 91 /* picoJava */
EM_OPENRISC Machine = 92 /* OpenRISC 32-bit embedded processor */
EM_ARC_COMPACT Machine = 93 /* ARC International ARCompact processor (old spelling/synonym: EM_ARC_A5) */
EM_XTENSA Machine = 94 /* Tensilica Xtensa Architecture */
EM_VIDEOCORE Machine = 95 /* Alphamosaic VideoCore processor */
EM_TMM_GPP Machine = 96 /* Thompson Multimedia General Purpose Processor */
EM_NS32K Machine = 97 /* National Semiconductor 32000 series */
EM_TPC Machine = 98 /* Tenor Network TPC processor */
EM_SNP1K Machine = 99 /* Trebia SNP 1000 processor */
EM_ST200 Machine = 100 /* STMicroelectronics (www.st.com) ST200 microcontroller */
EM_IP2K Machine = 101 /* Ubicom IP2xxx microcontroller family */
EM_MAX Machine = 102 /* MAX Processor */
EM_CR Machine = 103 /* National Semiconductor CompactRISC microprocessor */
EM_F2MC16 Machine = 104 /* Fujitsu F2MC16 */
EM_MSP430 Machine = 105 /* Texas Instruments embedded microcontroller msp430 */
EM_BLACKFIN Machine = 106 /* Analog Devices Blackfin (DSP) processor */
EM_SE_C33 Machine = 107 /* S1C33 Family of Seiko Epson processors */
EM_SEP Machine = 108 /* Sharp embedded microprocessor */
EM_ARCA Machine = 109 /* Arca RISC Microprocessor */
EM_UNICORE Machine = 110 /* Microprocessor series from PKU-Unity Ltd. and MPRC of Peking University */
EM_EXCESS Machine = 111 /* eXcess: 16/32/64-bit configurable embedded CPU */
EM_DXP Machine = 112 /* Icera Semiconductor Inc. Deep Execution Processor */
EM_ALTERA_NIOS2 Machine = 113 /* Altera Nios II soft-core processor */
EM_CRX Machine = 114 /* National Semiconductor CompactRISC CRX microprocessor */
EM_XGATE Machine = 115 /* Motorola XGATE embedded processor */
EM_C166 Machine = 116 /* Infineon C16x/XC16x processor */
EM_M16C Machine = 117 /* Renesas M16C series microprocessors */
EM_DSPIC30F Machine = 118 /* Microchip Technology dsPIC30F Digital Signal Controller */
EM_CE Machine = 119 /* Freescale Communication Engine RISC core */
EM_M32C Machine = 120 /* Renesas M32C series microprocessors */
EM_TSK3000 Machine = 131 /* Altium TSK3000 core */
EM_RS08 Machine = 132 /* Freescale RS08 embedded processor */
EM_SHARC Machine = 133 /* Analog Devices SHARC family of 32-bit DSP processors */
EM_ECOG2 Machine = 134 /* Cyan Technology eCOG2 microprocessor */
EM_SCORE7 Machine = 135 /* Sunplus S+core7 RISC processor */
EM_DSP24 Machine = 136 /* New Japan Radio (NJR) 24-bit DSP Processor */
EM_VIDEOCORE3 Machine = 137 /* Broadcom VideoCore III processor */
EM_LATTICEMICO32 Machine = 138 /* RISC processor for Lattice FPGA architecture */
EM_SE_C17 Machine = 139 /* Seiko Epson C17 family */
EM_TI_C6000 Machine = 140 /* The Texas Instruments TMS320C6000 DSP family */
EM_TI_C2000 Machine = 141 /* The Texas Instruments TMS320C2000 DSP family */
EM_TI_C5500 Machine = 142 /* The Texas Instruments TMS320C55x DSP family */
EM_TI_ARP32 Machine = 143 /* Texas Instruments Application Specific RISC Processor, 32bit fetch */
EM_TI_PRU Machine = 144 /* Texas Instruments Programmable Realtime Unit */
EM_MMDSP_PLUS Machine = 160 /* STMicroelectronics 64bit VLIW Data Signal Processor */
EM_CYPRESS_M8C Machine = 161 /* Cypress M8C microprocessor */
EM_R32C Machine = 162 /* Renesas R32C series microprocessors */
EM_TRIMEDIA Machine = 163 /* NXP Semiconductors TriMedia architecture family */
EM_QDSP6 Machine = 164 /* QUALCOMM DSP6 Processor */
EM_8051 Machine = 165 /* Intel 8051 and variants */
EM_STXP7X Machine = 166 /* STMicroelectronics STxP7x family of configurable and extensible RISC processors */
EM_NDS32 Machine = 167 /* Andes Technology compact code size embedded RISC processor family */
EM_ECOG1 Machine = 168 /* Cyan Technology eCOG1X family */
EM_ECOG1X Machine = 168 /* Cyan Technology eCOG1X family */
EM_MAXQ30 Machine = 169 /* Dallas Semiconductor MAXQ30 Core Micro-controllers */
EM_XIMO16 Machine = 170 /* New Japan Radio (NJR) 16-bit DSP Processor */
EM_MANIK Machine = 171 /* M2000 Reconfigurable RISC Microprocessor */
EM_CRAYNV2 Machine = 172 /* Cray Inc. NV2 vector architecture */
EM_RX Machine = 173 /* Renesas RX family */
EM_METAG Machine = 174 /* Imagination Technologies META processor architecture */
EM_MCST_ELBRUS Machine = 175 /* MCST Elbrus general purpose hardware architecture */
EM_ECOG16 Machine = 176 /* Cyan Technology eCOG16 family */
EM_CR16 Machine = 177 /* National Semiconductor CompactRISC CR16 16-bit microprocessor */
EM_ETPU Machine = 178 /* Freescale Extended Time Processing Unit */
EM_SLE9X Machine = 179 /* Infineon Technologies SLE9X core */
EM_L10M Machine = 180 /* Intel L10M */
EM_K10M Machine = 181 /* Intel K10M */
EM_AARCH64 Machine = 183 /* ARM 64-bit Architecture (AArch64) */
EM_AVR32 Machine = 185 /* Atmel Corporation 32-bit microprocessor family */
EM_STM8 Machine = 186 /* STMicroelectronics STM8 8-bit microcontroller */
EM_TILE64 Machine = 187 /* Tilera TILE64 multicore architecture family */
EM_TILEPRO Machine = 188 /* Tilera TILEPro multicore architecture family */
EM_MICROBLAZE Machine = 189 /* Xilinx MicroBlaze 32-bit RISC soft processor core */
EM_CUDA Machine = 190 /* NVIDIA CUDA architecture */
EM_TILEGX Machine = 191 /* Tilera TILE-Gx multicore architecture family */
EM_CLOUDSHIELD Machine = 192 /* CloudShield architecture family */
EM_COREA_1ST Machine = 193 /* KIPO-KAIST Core-A 1st generation processor family */
EM_COREA_2ND Machine = 194 /* KIPO-KAIST Core-A 2nd generation processor family */
EM_ARC_COMPACT2 Machine = 195 /* Synopsys ARCompact V2 */
EM_OPEN8 Machine = 196 /* Open8 8-bit RISC soft processor core */
EM_RL78 Machine = 197 /* Renesas RL78 family */
EM_VIDEOCORE5 Machine = 198 /* Broadcom VideoCore V processor */
EM_78KOR Machine = 199 /* Renesas 78KOR family */
EM_56800EX Machine = 200 /* Freescale 56800EX Digital Signal Controller (DSC) */
EM_BA1 Machine = 201 /* Beyond BA1 CPU architecture */
EM_BA2 Machine = 202 /* Beyond BA2 CPU architecture */
EM_XCORE Machine = 203 /* XMOS xCORE processor family */
EM_MCHP_PIC Machine = 204 /* Microchip 8-bit PIC(r) family */
EM_INTEL205 Machine = 205 /* Reserved by Intel */
EM_INTEL206 Machine = 206 /* Reserved by Intel */
EM_INTEL207 Machine = 207 /* Reserved by Intel */
EM_INTEL208 Machine = 208 /* Reserved by Intel */
EM_INTEL209 Machine = 209 /* Reserved by Intel */
EM_KM32 Machine = 210 /* KM211 KM32 32-bit processor */
EM_KMX32 Machine = 211 /* KM211 KMX32 32-bit processor */
EM_KMX16 Machine = 212 /* KM211 KMX16 16-bit processor */
EM_KMX8 Machine = 213 /* KM211 KMX8 8-bit processor */
EM_KVARC Machine = 214 /* KM211 KVARC processor */
EM_CDP Machine = 215 /* Paneve CDP architecture family */
EM_COGE Machine = 216 /* Cognitive Smart Memory Processor */
EM_COOL Machine = 217 /* Bluechip Systems CoolEngine */
EM_NORC Machine = 218 /* Nanoradio Optimized RISC */
EM_CSR_KALIMBA Machine = 219 /* CSR Kalimba architecture family */
EM_Z80 Machine = 220 /* Zilog Z80 */
EM_VISIUM Machine = 221 /* Controls and Data Services VISIUMcore processor */
EM_FT32 Machine = 222 /* FTDI Chip FT32 high performance 32-bit RISC architecture */
EM_MOXIE Machine = 223 /* Moxie processor family */
EM_AMDGPU Machine = 224 /* AMD GPU architecture */
EM_RISCV Machine = 243 /* RISC-V */
EM_LANAI Machine = 244 /* Lanai 32-bit processor */
EM_BPF Machine = 247 /* Linux BPF – in-kernel virtual machine */
EM_LOONGARCH Machine = 258 /* LoongArch */
/* Non-standard or deprecated. */
EM_486 Machine = 6 /* Intel i486. */
EM_MIPS_RS4_BE Machine = 10 /* MIPS R4000 Big-Endian */
EM_ALPHA_STD Machine = 41 /* Digital Alpha (standard value). */
EM_ALPHA Machine = 0x9026 /* Alpha (written in the absence of an ABI) */
)
var machineStrings = []intName{
{0, "EM_NONE"},
{1, "EM_M32"},
{2, "EM_SPARC"},
{3, "EM_386"},
{4, "EM_68K"},
{5, "EM_88K"},
{7, "EM_860"},
{8, "EM_MIPS"},
{9, "EM_S370"},
{10, "EM_MIPS_RS3_LE"},
{15, "EM_PARISC"},
{17, "EM_VPP500"},
{18, "EM_SPARC32PLUS"},
{19, "EM_960"},
{20, "EM_PPC"},
{21, "EM_PPC64"},
{22, "EM_S390"},
{36, "EM_V800"},
{37, "EM_FR20"},
{38, "EM_RH32"},
{39, "EM_RCE"},
{40, "EM_ARM"},
{42, "EM_SH"},
{43, "EM_SPARCV9"},
{44, "EM_TRICORE"},
{45, "EM_ARC"},
{46, "EM_H8_300"},
{47, "EM_H8_300H"},
{48, "EM_H8S"},
{49, "EM_H8_500"},
{50, "EM_IA_64"},
{51, "EM_MIPS_X"},
{52, "EM_COLDFIRE"},
{53, "EM_68HC12"},
{54, "EM_MMA"},
{55, "EM_PCP"},
{56, "EM_NCPU"},
{57, "EM_NDR1"},
{58, "EM_STARCORE"},
{59, "EM_ME16"},
{60, "EM_ST100"},
{61, "EM_TINYJ"},
{62, "EM_X86_64"},
{63, "EM_PDSP"},
{64, "EM_PDP10"},
{65, "EM_PDP11"},
{66, "EM_FX66"},
{67, "EM_ST9PLUS"},
{68, "EM_ST7"},
{69, "EM_68HC16"},
{70, "EM_68HC11"},
{71, "EM_68HC08"},
{72, "EM_68HC05"},
{73, "EM_SVX"},
{74, "EM_ST19"},
{75, "EM_VAX"},
{76, "EM_CRIS"},
{77, "EM_JAVELIN"},
{78, "EM_FIREPATH"},
{79, "EM_ZSP"},
{80, "EM_MMIX"},
{81, "EM_HUANY"},
{82, "EM_PRISM"},
{83, "EM_AVR"},
{84, "EM_FR30"},
{85, "EM_D10V"},
{86, "EM_D30V"},
{87, "EM_V850"},
{88, "EM_M32R"},
{89, "EM_MN10300"},
{90, "EM_MN10200"},
{91, "EM_PJ"},
{92, "EM_OPENRISC"},
{93, "EM_ARC_COMPACT"},
{94, "EM_XTENSA"},
{95, "EM_VIDEOCORE"},
{96, "EM_TMM_GPP"},
{97, "EM_NS32K"},
{98, "EM_TPC"},
{99, "EM_SNP1K"},
{100, "EM_ST200"},
{101, "EM_IP2K"},
{102, "EM_MAX"},
{103, "EM_CR"},
{104, "EM_F2MC16"},
{105, "EM_MSP430"},
{106, "EM_BLACKFIN"},
{107, "EM_SE_C33"},
{108, "EM_SEP"},
{109, "EM_ARCA"},
{110, "EM_UNICORE"},
{111, "EM_EXCESS"},
{112, "EM_DXP"},
{113, "EM_ALTERA_NIOS2"},
{114, "EM_CRX"},
{115, "EM_XGATE"},
{116, "EM_C166"},
{117, "EM_M16C"},
{118, "EM_DSPIC30F"},
{119, "EM_CE"},
{120, "EM_M32C"},
{131, "EM_TSK3000"},
{132, "EM_RS08"},
{133, "EM_SHARC"},
{134, "EM_ECOG2"},
{135, "EM_SCORE7"},
{136, "EM_DSP24"},
{137, "EM_VIDEOCORE3"},
{138, "EM_LATTICEMICO32"},
{139, "EM_SE_C17"},
{140, "EM_TI_C6000"},
{141, "EM_TI_C2000"},
{142, "EM_TI_C5500"},
{143, "EM_TI_ARP32"},
{144, "EM_TI_PRU"},
{160, "EM_MMDSP_PLUS"},
{161, "EM_CYPRESS_M8C"},
{162, "EM_R32C"},
{163, "EM_TRIMEDIA"},
{164, "EM_QDSP6"},
{165, "EM_8051"},
{166, "EM_STXP7X"},
{167, "EM_NDS32"},
{168, "EM_ECOG1"},
{168, "EM_ECOG1X"},
{169, "EM_MAXQ30"},
{170, "EM_XIMO16"},
{171, "EM_MANIK"},
{172, "EM_CRAYNV2"},
{173, "EM_RX"},
{174, "EM_METAG"},
{175, "EM_MCST_ELBRUS"},
{176, "EM_ECOG16"},
{177, "EM_CR16"},
{178, "EM_ETPU"},
{179, "EM_SLE9X"},
{180, "EM_L10M"},
{181, "EM_K10M"},
{183, "EM_AARCH64"},
{185, "EM_AVR32"},
{186, "EM_STM8"},
{187, "EM_TILE64"},
{188, "EM_TILEPRO"},
{189, "EM_MICROBLAZE"},
{190, "EM_CUDA"},
{191, "EM_TILEGX"},
{192, "EM_CLOUDSHIELD"},
{193, "EM_COREA_1ST"},
{194, "EM_COREA_2ND"},
{195, "EM_ARC_COMPACT2"},
{196, "EM_OPEN8"},
{197, "EM_RL78"},
{198, "EM_VIDEOCORE5"},
{199, "EM_78KOR"},
{200, "EM_56800EX"},
{201, "EM_BA1"},
{202, "EM_BA2"},
{203, "EM_XCORE"},
{204, "EM_MCHP_PIC"},
{205, "EM_INTEL205"},
{206, "EM_INTEL206"},
{207, "EM_INTEL207"},
{208, "EM_INTEL208"},
{209, "EM_INTEL209"},
{210, "EM_KM32"},
{211, "EM_KMX32"},
{212, "EM_KMX16"},
{213, "EM_KMX8"},
{214, "EM_KVARC"},
{215, "EM_CDP"},
{216, "EM_COGE"},
{217, "EM_COOL"},
{218, "EM_NORC"},
{219, "EM_CSR_KALIMBA "},
{220, "EM_Z80 "},
{221, "EM_VISIUM "},
{222, "EM_FT32 "},
{223, "EM_MOXIE"},
{224, "EM_AMDGPU"},
{243, "EM_RISCV"},
{244, "EM_LANAI"},
{247, "EM_BPF"},
{258, "EM_LOONGARCH"},
/* Non-standard or deprecated. */
{6, "EM_486"},
{10, "EM_MIPS_RS4_BE"},
{41, "EM_ALPHA_STD"},
{0x9026, "EM_ALPHA"},
}
func (i Machine) String() string { return stringName(uint32(i), machineStrings, false) }
func (i Machine) GoString() string { return stringName(uint32(i), machineStrings, true) }
// Special section indices.
type SectionIndex int
const (
SHN_UNDEF SectionIndex = 0 /* Undefined, missing, irrelevant. */
SHN_LORESERVE SectionIndex = 0xff00 /* First of reserved range. */
SHN_LOPROC SectionIndex = 0xff00 /* First processor-specific. */
SHN_HIPROC SectionIndex = 0xff1f /* Last processor-specific. */
SHN_LOOS SectionIndex = 0xff20 /* First operating system-specific. */
SHN_HIOS SectionIndex = 0xff3f /* Last operating system-specific. */
SHN_ABS SectionIndex = 0xfff1 /* Absolute values. */
SHN_COMMON SectionIndex = 0xfff2 /* Common data. */
SHN_XINDEX SectionIndex = 0xffff /* Escape; index stored elsewhere. */
SHN_HIRESERVE SectionIndex = 0xffff /* Last of reserved range. */
)
var shnStrings = []intName{
{0, "SHN_UNDEF"},
{0xff00, "SHN_LOPROC"},
{0xff20, "SHN_LOOS"},
{0xfff1, "SHN_ABS"},
{0xfff2, "SHN_COMMON"},
{0xffff, "SHN_XINDEX"},
}
func (i SectionIndex) String() string { return stringName(uint32(i), shnStrings, false) }
func (i SectionIndex) GoString() string { return stringName(uint32(i), shnStrings, true) }
// Section type.
type SectionType uint32
const (
SHT_NULL SectionType = 0 /* inactive */
SHT_PROGBITS SectionType = 1 /* program defined information */
SHT_SYMTAB SectionType = 2 /* symbol table section */
SHT_STRTAB SectionType = 3 /* string table section */
SHT_RELA SectionType = 4 /* relocation section with addends */
SHT_HASH SectionType = 5 /* symbol hash table section */
SHT_DYNAMIC SectionType = 6 /* dynamic section */
SHT_NOTE SectionType = 7 /* note section */
SHT_NOBITS SectionType = 8 /* no space section */
SHT_REL SectionType = 9 /* relocation section - no addends */
SHT_SHLIB SectionType = 10 /* reserved - purpose unknown */
SHT_DYNSYM SectionType = 11 /* dynamic symbol table section */
SHT_INIT_ARRAY SectionType = 14 /* Initialization function pointers. */
SHT_FINI_ARRAY SectionType = 15 /* Termination function pointers. */
SHT_PREINIT_ARRAY SectionType = 16 /* Pre-initialization function ptrs. */
SHT_GROUP SectionType = 17 /* Section group. */
SHT_SYMTAB_SHNDX SectionType = 18 /* Section indexes (see SHN_XINDEX). */
SHT_LOOS SectionType = 0x60000000 /* First of OS specific semantics */
SHT_GNU_ATTRIBUTES SectionType = 0x6ffffff5 /* GNU object attributes */
SHT_GNU_HASH SectionType = 0x6ffffff6 /* GNU hash table */
SHT_GNU_LIBLIST SectionType = 0x6ffffff7 /* GNU prelink library list */
SHT_GNU_VERDEF SectionType = 0x6ffffffd /* GNU version definition section */
SHT_GNU_VERNEED SectionType = 0x6ffffffe /* GNU version needs section */
SHT_GNU_VERSYM SectionType = 0x6fffffff /* GNU version symbol table */
SHT_HIOS SectionType = 0x6fffffff /* Last of OS specific semantics */
SHT_LOPROC SectionType = 0x70000000 /* First processor-specific type. */
SHT_RISCV_ATTRIBUTES SectionType = 0x70000003 /* RISCV object attributes */
SHT_MIPS_ABIFLAGS SectionType = 0x7000002a /* .MIPS.abiflags */
SHT_HIPROC SectionType = 0x7fffffff /* Last processor-specific type. */
SHT_LOUSER SectionType = 0x80000000 /* First application-specific type. */
SHT_HIUSER SectionType = 0xffffffff /* Last application-specific type. */
)
var shtStrings = []intName{
{0, "SHT_NULL"},
{1, "SHT_PROGBITS"},
{2, "SHT_SYMTAB"},
{3, "SHT_STRTAB"},
{4, "SHT_RELA"},
{5, "SHT_HASH"},
{6, "SHT_DYNAMIC"},
{7, "SHT_NOTE"},
{8, "SHT_NOBITS"},
{9, "SHT_REL"},
{10, "SHT_SHLIB"},
{11, "SHT_DYNSYM"},
{14, "SHT_INIT_ARRAY"},
{15, "SHT_FINI_ARRAY"},
{16, "SHT_PREINIT_ARRAY"},
{17, "SHT_GROUP"},
{18, "SHT_SYMTAB_SHNDX"},
{0x60000000, "SHT_LOOS"},
{0x6ffffff5, "SHT_GNU_ATTRIBUTES"},
{0x6ffffff6, "SHT_GNU_HASH"},
{0x6ffffff7, "SHT_GNU_LIBLIST"},
{0x6ffffffd, "SHT_GNU_VERDEF"},
{0x6ffffffe, "SHT_GNU_VERNEED"},
{0x6fffffff, "SHT_GNU_VERSYM"},
{0x70000000, "SHT_LOPROC"},
// We don't list the processor-dependent SectionType,
// as the values overlap.
{0x7000002a, "SHT_MIPS_ABIFLAGS"},
{0x7fffffff, "SHT_HIPROC"},
{0x80000000, "SHT_LOUSER"},
{0xffffffff, "SHT_HIUSER"},
}
func (i SectionType) String() string { return stringName(uint32(i), shtStrings, false) }
func (i SectionType) GoString() string { return stringName(uint32(i), shtStrings, true) }
// Section flags.
type SectionFlag uint32
const (
SHF_WRITE SectionFlag = 0x1 /* Section contains writable data. */
SHF_ALLOC SectionFlag = 0x2 /* Section occupies memory. */
SHF_EXECINSTR SectionFlag = 0x4 /* Section contains instructions. */
SHF_MERGE SectionFlag = 0x10 /* Section may be merged. */
SHF_STRINGS SectionFlag = 0x20 /* Section contains strings. */
SHF_INFO_LINK SectionFlag = 0x40 /* sh_info holds section index. */
SHF_LINK_ORDER SectionFlag = 0x80 /* Special ordering requirements. */
SHF_OS_NONCONFORMING SectionFlag = 0x100 /* OS-specific processing required. */
SHF_GROUP SectionFlag = 0x200 /* Member of section group. */
SHF_TLS SectionFlag = 0x400 /* Section contains TLS data. */
SHF_COMPRESSED SectionFlag = 0x800 /* Section is compressed. */
SHF_MASKOS SectionFlag = 0x0ff00000 /* OS-specific semantics. */
SHF_MASKPROC SectionFlag = 0xf0000000 /* Processor-specific semantics. */
)
var shfStrings = []intName{
{0x1, "SHF_WRITE"},
{0x2, "SHF_ALLOC"},
{0x4, "SHF_EXECINSTR"},
{0x10, "SHF_MERGE"},
{0x20, "SHF_STRINGS"},
{0x40, "SHF_INFO_LINK"},
{0x80, "SHF_LINK_ORDER"},
{0x100, "SHF_OS_NONCONFORMING"},
{0x200, "SHF_GROUP"},
{0x400, "SHF_TLS"},
{0x800, "SHF_COMPRESSED"},
}
func (i SectionFlag) String() string { return flagName(uint32(i), shfStrings, false) }
func (i SectionFlag) GoString() string { return flagName(uint32(i), shfStrings, true) }
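// SectionFlag values form a bit mask, so callers combine and test them with
// bitwise operators. A small sketch (s stands in for a Section or
// SectionHeader taken from an opened File):
//
//	if s.Flags&SHF_ALLOC != 0 && s.Flags&SHF_EXECINSTR != 0 {
//		// a mapped, executable section such as .text
//	}
//	isTLS := s.Flags&SHF_TLS != 0
//	_ = isTLS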
// Section compression type.
type CompressionType int
const (
COMPRESS_ZLIB CompressionType = 1 /* ZLIB compression. */
COMPRESS_ZSTD CompressionType = 2 /* ZSTD compression. */
COMPRESS_LOOS CompressionType = 0x60000000 /* First OS-specific. */
COMPRESS_HIOS CompressionType = 0x6fffffff /* Last OS-specific. */
COMPRESS_LOPROC CompressionType = 0x70000000 /* First processor-specific type. */
COMPRESS_HIPROC CompressionType = 0x7fffffff /* Last processor-specific type. */
)
var compressionStrings = []intName{
{1, "COMPRESS_ZLIB"},
{2, "COMPRESS_ZSTD"},
{0x60000000, "COMPRESS_LOOS"},
{0x6fffffff, "COMPRESS_HIOS"},
{0x70000000, "COMPRESS_LOPROC"},
{0x7fffffff, "COMPRESS_HIPROC"},
}
func (i CompressionType) String() string { return stringName(uint32(i), compressionStrings, false) }
func (i CompressionType) GoString() string { return stringName(uint32(i), compressionStrings, true) }
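// The sketch below is illustrative and not part of the original source: a
// section whose SHF_COMPRESSED flag is set begins with a compression header
// whose type is one of the CompressionType values above, and (*Section).Data
// is expected to decompress the supported formats (COMPRESS_ZLIB, and
// COMPRESS_ZSTD in newer releases) transparently, so callers read compressed
// and uncompressed sections the same way. The function name and the choice of
// the ".debug_info" section are assumptions made for the example.
func sketchSectionBytes(f *File) ([]byte, error) {
	s := f.Section(".debug_info") // debug sections are commonly compressed
	if s == nil {
		return nil, nil // no such section; nothing to read
	}
	return s.Data()
}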
// Prog.Type
type ProgType int
const (
PT_NULL ProgType = 0 /* Unused entry. */
PT_LOAD ProgType = 1 /* Loadable segment. */
PT_DYNAMIC ProgType = 2 /* Dynamic linking information segment. */
PT_INTERP ProgType = 3 /* Pathname of interpreter. */
PT_NOTE ProgType = 4 /* Auxiliary information. */
PT_SHLIB ProgType = 5 /* Reserved (not used). */
PT_PHDR ProgType = 6 /* Location of program header itself. */
PT_TLS ProgType = 7 /* Thread local storage segment */
PT_LOOS ProgType = 0x60000000 /* First OS-specific. */
PT_GNU_EH_FRAME ProgType = 0x6474e550 /* Frame unwind information */
PT_GNU_STACK ProgType = 0x6474e551 /* Stack flags */
PT_GNU_RELRO ProgType = 0x6474e552 /* Read only after relocs */
PT_GNU_PROPERTY ProgType = 0x6474e553 /* GNU property */
PT_GNU_MBIND_LO ProgType = 0x6474e555 /* Mbind segments start */
PT_GNU_MBIND_HI ProgType = 0x6474f554 /* Mbind segments finish */
PT_PAX_FLAGS ProgType = 0x65041580 /* PAX flags */
PT_OPENBSD_RANDOMIZE ProgType = 0x65a3dbe6 /* Random data */
PT_OPENBSD_WXNEEDED ProgType = 0x65a3dbe7 /* W^X violations */
PT_OPENBSD_NOBTCFI ProgType = 0x65a3dbe8 /* No branch target CFI */
PT_OPENBSD_BOOTDATA ProgType = 0x65a41be6 /* Boot arguments */
PT_SUNW_EH_FRAME ProgType = 0x6474e550 /* Frame unwind information */
PT_SUNWSTACK ProgType = 0x6ffffffb /* Stack segment */
PT_HIOS ProgType = 0x6fffffff /* Last OS-specific. */
PT_LOPROC ProgType = 0x70000000 /* First processor-specific type. */
PT_ARM_ARCHEXT ProgType = 0x70000000 /* Architecture compatibility */
PT_ARM_EXIDX ProgType = 0x70000001 /* Exception unwind tables */
PT_AARCH64_ARCHEXT ProgType = 0x70000000 /* Architecture compatibility */
PT_AARCH64_UNWIND ProgType = 0x70000001 /* Exception unwind tables */
PT_MIPS_REGINFO ProgType = 0x70000000 /* Register usage */
PT_MIPS_RTPROC ProgType = 0x70000001 /* Runtime procedures */
PT_MIPS_OPTIONS ProgType = 0x70000002 /* Options */
PT_MIPS_ABIFLAGS ProgType = 0x70000003 /* ABI flags */
PT_RISCV_ATTRIBUTES ProgType = 0x70000003 /* RISC-V ELF attribute section. */
PT_S390_PGSTE ProgType = 0x70000000 /* 4k page table size */
PT_HIPROC ProgType = 0x7fffffff /* Last processor-specific type. */
)
var ptStrings = []intName{
{0, "PT_NULL"},
{1, "PT_LOAD"},
{2, "PT_DYNAMIC"},
{3, "PT_INTERP"},
{4, "PT_NOTE"},
{5, "PT_SHLIB"},
{6, "PT_PHDR"},
{7, "PT_TLS"},
{0x60000000, "PT_LOOS"},
{0x6474e550, "PT_GNU_EH_FRAME"},
{0x6474e551, "PT_GNU_STACK"},
{0x6474e552, "PT_GNU_RELRO"},
{0x6474e553, "PT_GNU_PROPERTY"},
{0x65041580, "PT_PAX_FLAGS"},
{0x65a3dbe6, "PT_OPENBSD_RANDOMIZE"},
{0x65a3dbe7, "PT_OPENBSD_WXNEEDED"},
{0x65a41be6, "PT_OPENBSD_BOOTDATA"},
{0x6ffffffb, "PT_SUNWSTACK"},
{0x6fffffff, "PT_HIOS"},
{0x70000000, "PT_LOPROC"},
// We don't list the processor-dependent ProgTypes,
// as the values overlap.
{0x7fffffff, "PT_HIPROC"},
}
func (i ProgType) String() string { return stringName(uint32(i), ptStrings, false) }
func (i ProgType) GoString() string { return stringName(uint32(i), ptStrings, true) }
// Prog.Flag
type ProgFlag uint32
const (
PF_X ProgFlag = 0x1 /* Executable. */
PF_W ProgFlag = 0x2 /* Writable. */
PF_R ProgFlag = 0x4 /* Readable. */
PF_MASKOS ProgFlag = 0x0ff00000 /* Operating system-specific. */
PF_MASKPROC ProgFlag = 0xf0000000 /* Processor-specific. */
)
var pfStrings = []intName{
{0x1, "PF_X"},
{0x2, "PF_W"},
{0x4, "PF_R"},
}
func (i ProgFlag) String() string { return flagName(uint32(i), pfStrings, false) }
func (i ProgFlag) GoString() string { return flagName(uint32(i), pfStrings, true) }
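// The sketch below is illustrative and not part of the original source: it
// walks the program headers and collects the virtual address of every
// loadable, executable segment. ProgType values are plain enumerations and
// are compared directly, while ProgFlag values are bit flags and are masked
// with a bitwise AND. The function name is an assumption made for the example.
func sketchExecutableLoadSegments(f *File) []uint64 {
	var addrs []uint64
	for _, p := range f.Progs {
		if p.Type == PT_LOAD && p.Flags&PF_X != 0 {
			addrs = append(addrs, p.Vaddr)
		}
	}
	return addrs
}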
// Dyn.Tag
type DynTag int
const (
DT_NULL DynTag = 0 /* Terminating entry. */
DT_NEEDED DynTag = 1 /* String table offset of a needed shared library. */
DT_PLTRELSZ DynTag = 2 /* Total size in bytes of PLT relocations. */
DT_PLTGOT DynTag = 3 /* Processor-dependent address. */
DT_HASH DynTag = 4 /* Address of symbol hash table. */
DT_STRTAB DynTag = 5 /* Address of string table. */
DT_SYMTAB DynTag = 6 /* Address of symbol table. */
DT_RELA DynTag = 7 /* Address of ElfNN_Rela relocations. */
DT_RELASZ DynTag = 8 /* Total size of ElfNN_Rela relocations. */
DT_RELAENT DynTag = 9 /* Size of each ElfNN_Rela relocation entry. */
DT_STRSZ DynTag = 10 /* Size of string table. */
DT_SYMENT DynTag = 11 /* Size of each symbol table entry. */
DT_INIT DynTag = 12 /* Address of initialization function. */
DT_FINI DynTag = 13 /* Address of finalization function. */
DT_SONAME DynTag = 14 /* String table offset of shared object name. */
DT_RPATH DynTag = 15 /* String table offset of library path. [sup] */
DT_SYMBOLIC DynTag = 16 /* Indicates "symbolic" linking. [sup] */
DT_REL DynTag = 17 /* Address of ElfNN_Rel relocations. */
DT_RELSZ DynTag = 18 /* Total size of ElfNN_Rel relocations. */
DT_RELENT DynTag = 19 /* Size of each ElfNN_Rel relocation. */
DT_PLTREL DynTag = 20 /* Type of relocation used for PLT. */
DT_DEBUG DynTag = 21 /* Reserved (not used). */
DT_TEXTREL DynTag = 22 /* Indicates there may be relocations in non-writable segments. [sup] */
DT_JMPREL DynTag = 23 /* Address of PLT relocations. */
DT_BIND_NOW DynTag = 24 /* [sup] */
DT_INIT_ARRAY DynTag = 25 /* Address of the array of pointers to initialization functions */
DT_FINI_ARRAY DynTag = 26 /* Address of the array of pointers to termination functions */
DT_INIT_ARRAYSZ DynTag = 27 /* Size in bytes of the array of initialization functions. */
DT_FINI_ARRAYSZ DynTag = 28 /* Size in bytes of the array of termination functions. */
DT_RUNPATH DynTag = 29 /* String table offset of a null-terminated library search path string. */
DT_FLAGS DynTag = 30 /* Object specific flag values. */
DT_ENCODING DynTag = 32 /* Values greater than or equal to DT_ENCODING
and less than DT_LOOS follow the rules for
the interpretation of the d_un union
as follows: even == 'd_ptr', odd == 'd_val'
or none */
DT_PREINIT_ARRAY DynTag = 32 /* Address of the array of pointers to pre-initialization functions. */
DT_PREINIT_ARRAYSZ DynTag = 33 /* Size in bytes of the array of pre-initialization functions. */
DT_SYMTAB_SHNDX DynTag = 34 /* Address of SHT_SYMTAB_SHNDX section. */
DT_LOOS DynTag = 0x6000000d /* First OS-specific */
DT_HIOS DynTag = 0x6ffff000 /* Last OS-specific */
DT_VALRNGLO DynTag = 0x6ffffd00
DT_GNU_PRELINKED DynTag = 0x6ffffdf5
DT_GNU_CONFLICTSZ DynTag = 0x6ffffdf6
DT_GNU_LIBLISTSZ DynTag = 0x6ffffdf7
DT_CHECKSUM DynTag = 0x6ffffdf8
DT_PLTPADSZ DynTag = 0x6ffffdf9
DT_MOVEENT DynTag = 0x6ffffdfa
DT_MOVESZ DynTag = 0x6ffffdfb
DT_FEATURE DynTag = 0x6ffffdfc
DT_POSFLAG_1 DynTag = 0x6ffffdfd
DT_SYMINSZ DynTag = 0x6ffffdfe
DT_SYMINENT DynTag = 0x6ffffdff
DT_VALRNGHI DynTag = 0x6ffffdff
DT_ADDRRNGLO DynTag = 0x6ffffe00
DT_GNU_HASH DynTag = 0x6ffffef5
DT_TLSDESC_PLT DynTag = 0x6ffffef6
DT_TLSDESC_GOT DynTag = 0x6ffffef7
DT_GNU_CONFLICT DynTag = 0x6ffffef8
DT_GNU_LIBLIST DynTag = 0x6ffffef9
DT_CONFIG DynTag = 0x6ffffefa
DT_DEPAUDIT DynTag = 0x6ffffefb
DT_AUDIT DynTag = 0x6ffffefc
DT_PLTPAD DynTag = 0x6ffffefd
DT_MOVETAB DynTag = 0x6ffffefe
DT_SYMINFO DynTag = 0x6ffffeff
DT_ADDRRNGHI DynTag = 0x6ffffeff
DT_VERSYM DynTag = 0x6ffffff0
DT_RELACOUNT DynTag = 0x6ffffff9
DT_RELCOUNT DynTag = 0x6ffffffa
DT_FLAGS_1 DynTag = 0x6ffffffb
DT_VERDEF DynTag = 0x6ffffffc
DT_VERDEFNUM DynTag = 0x6ffffffd
DT_VERNEED DynTag = 0x6ffffffe
DT_VERNEEDNUM DynTag = 0x6fffffff
DT_LOPROC DynTag = 0x70000000 /* First processor-specific type. */
DT_MIPS_RLD_VERSION DynTag = 0x70000001
DT_MIPS_TIME_STAMP DynTag = 0x70000002
DT_MIPS_ICHECKSUM DynTag = 0x70000003
DT_MIPS_IVERSION DynTag = 0x70000004
DT_MIPS_FLAGS DynTag = 0x70000005
DT_MIPS_BASE_ADDRESS DynTag = 0x70000006
DT_MIPS_MSYM DynTag = 0x70000007
DT_MIPS_CONFLICT DynTag = 0x70000008
DT_MIPS_LIBLIST DynTag = 0x70000009
DT_MIPS_LOCAL_GOTNO DynTag = 0x7000000a
DT_MIPS_CONFLICTNO DynTag = 0x7000000b
DT_MIPS_LIBLISTNO DynTag = 0x70000010
DT_MIPS_SYMTABNO DynTag = 0x70000011
DT_MIPS_UNREFEXTNO DynTag = 0x70000012
DT_MIPS_GOTSYM DynTag = 0x70000013
DT_MIPS_HIPAGENO DynTag = 0x70000014
DT_MIPS_RLD_MAP DynTag = 0x70000016
DT_MIPS_DELTA_CLASS DynTag = 0x70000017
DT_MIPS_DELTA_CLASS_NO DynTag = 0x70000018
DT_MIPS_DELTA_INSTANCE DynTag = 0x70000019
DT_MIPS_DELTA_INSTANCE_NO DynTag = 0x7000001a
DT_MIPS_DELTA_RELOC DynTag = 0x7000001b
DT_MIPS_DELTA_RELOC_NO DynTag = 0x7000001c
DT_MIPS_DELTA_SYM DynTag = 0x7000001d
DT_MIPS_DELTA_SYM_NO DynTag = 0x7000001e
DT_MIPS_DELTA_CLASSSYM DynTag = 0x70000020
DT_MIPS_DELTA_CLASSSYM_NO DynTag = 0x70000021
DT_MIPS_CXX_FLAGS DynTag = 0x70000022
DT_MIPS_PIXIE_INIT DynTag = 0x70000023
DT_MIPS_SYMBOL_LIB DynTag = 0x70000024
DT_MIPS_LOCALPAGE_GOTIDX DynTag = 0x70000025
DT_MIPS_LOCAL_GOTIDX DynTag = 0x70000026
DT_MIPS_HIDDEN_GOTIDX DynTag = 0x70000027
DT_MIPS_PROTECTED_GOTIDX DynTag = 0x70000028
DT_MIPS_OPTIONS DynTag = 0x70000029
DT_MIPS_INTERFACE DynTag = 0x7000002a
DT_MIPS_DYNSTR_ALIGN DynTag = 0x7000002b
DT_MIPS_INTERFACE_SIZE DynTag = 0x7000002c
DT_MIPS_RLD_TEXT_RESOLVE_ADDR DynTag = 0x7000002d
DT_MIPS_PERF_SUFFIX DynTag = 0x7000002e
DT_MIPS_COMPACT_SIZE DynTag = 0x7000002f
DT_MIPS_GP_VALUE DynTag = 0x70000030
DT_MIPS_AUX_DYNAMIC DynTag = 0x70000031
DT_MIPS_PLTGOT DynTag = 0x70000032
DT_MIPS_RWPLT DynTag = 0x70000034
DT_MIPS_RLD_MAP_REL DynTag = 0x70000035
DT_PPC_GOT DynTag = 0x70000000
DT_PPC_OPT DynTag = 0x70000001
DT_PPC64_GLINK DynTag = 0x70000000
DT_PPC64_OPD DynTag = 0x70000001
DT_PPC64_OPDSZ DynTag = 0x70000002
DT_PPC64_OPT DynTag = 0x70000003
DT_SPARC_REGISTER DynTag = 0x70000001
DT_AUXILIARY DynTag = 0x7ffffffd
DT_USED DynTag = 0x7ffffffe
DT_FILTER DynTag = 0x7fffffff
DT_HIPROC DynTag = 0x7fffffff /* Last processor-specific type. */
)
var dtStrings = []intName{
{0, "DT_NULL"},
{1, "DT_NEEDED"},
{2, "DT_PLTRELSZ"},
{3, "DT_PLTGOT"},
{4, "DT_HASH"},
{5, "DT_STRTAB"},
{6, "DT_SYMTAB"},
{7, "DT_RELA"},
{8, "DT_RELASZ"},
{9, "DT_RELAENT"},
{10, "DT_STRSZ"},
{11, "DT_SYMENT"},
{12, "DT_INIT"},
{13, "DT_FINI"},
{14, "DT_SONAME"},
{15, "DT_RPATH"},
{16, "DT_SYMBOLIC"},
{17, "DT_REL"},
{18, "DT_RELSZ"},
{19, "DT_RELENT"},
{20, "DT_PLTREL"},
{21, "DT_DEBUG"},
{22, "DT_TEXTREL"},
{23, "DT_JMPREL"},
{24, "DT_BIND_NOW"},
{25, "DT_INIT_ARRAY"},
{26, "DT_FINI_ARRAY"},
{27, "DT_INIT_ARRAYSZ"},
{28, "DT_FINI_ARRAYSZ"},
{29, "DT_RUNPATH"},
{30, "DT_FLAGS"},
{32, "DT_ENCODING"},
{32, "DT_PREINIT_ARRAY"},
{33, "DT_PREINIT_ARRAYSZ"},
{34, "DT_SYMTAB_SHNDX"},
{0x6000000d, "DT_LOOS"},
{0x6ffff000, "DT_HIOS"},
{0x6ffffd00, "DT_VALRNGLO"},
{0x6ffffdf5, "DT_GNU_PRELINKED"},
{0x6ffffdf6, "DT_GNU_CONFLICTSZ"},
{0x6ffffdf7, "DT_GNU_LIBLISTSZ"},
{0x6ffffdf8, "DT_CHECKSUM"},
{0x6ffffdf9, "DT_PLTPADSZ"},
{0x6ffffdfa, "DT_MOVEENT"},
{0x6ffffdfb, "DT_MOVESZ"},
{0x6ffffdfc, "DT_FEATURE"},
{0x6ffffdfd, "DT_POSFLAG_1"},
{0x6ffffdfe, "DT_SYMINSZ"},
{0x6ffffdff, "DT_SYMINENT"},
{0x6ffffdff, "DT_VALRNGHI"},
{0x6ffffe00, "DT_ADDRRNGLO"},
{0x6ffffef5, "DT_GNU_HASH"},
{0x6ffffef6, "DT_TLSDESC_PLT"},
{0x6ffffef7, "DT_TLSDESC_GOT"},
{0x6ffffef8, "DT_GNU_CONFLICT"},
{0x6ffffef9, "DT_GNU_LIBLIST"},
{0x6ffffefa, "DT_CONFIG"},
{0x6ffffefb, "DT_DEPAUDIT"},
{0x6ffffefc, "DT_AUDIT"},
{0x6ffffefd, "DT_PLTPAD"},
{0x6ffffefe, "DT_MOVETAB"},
{0x6ffffeff, "DT_SYMINFO"},
{0x6ffffeff, "DT_ADDRRNGHI"},
{0x6ffffff0, "DT_VERSYM"},
{0x6ffffff9, "DT_RELACOUNT"},
{0x6ffffffa, "DT_RELCOUNT"},
{0x6ffffffb, "DT_FLAGS_1"},
{0x6ffffffc, "DT_VERDEF"},
{0x6ffffffd, "DT_VERDEFNUM"},
{0x6ffffffe, "DT_VERNEED"},
{0x6fffffff, "DT_VERNEEDNUM"},
{0x70000000, "DT_LOPROC"},
// We don't list the processor-dependent DynTags,
// as the values overlap.
{0x7ffffffd, "DT_AUXILIARY"},
{0x7ffffffe, "DT_USED"},
{0x7fffffff, "DT_FILTER"},
}
func (i DynTag) String() string { return stringName(uint32(i), dtStrings, false) }
func (i DynTag) GoString() string { return stringName(uint32(i), dtStrings, true) }
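// The sketch below is illustrative and not part of the original source:
// string-valued dynamic tags such as DT_NEEDED, DT_SONAME, DT_RPATH and
// DT_RUNPATH are indexes into the dynamic string table, and (*File).DynString
// resolves them to the referenced strings. The function name is an assumption
// made for the example; for DT_NEEDED this is essentially what
// (*File).ImportedLibraries reports.
func sketchNeededLibraries(f *File) ([]string, error) {
	return f.DynString(DT_NEEDED)
}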
// DT_FLAGS values.
type DynFlag int
const (
DF_ORIGIN DynFlag = 0x0001 /* Indicates that the object being loaded may
make reference to the
$ORIGIN substitution string */
DF_SYMBOLIC DynFlag = 0x0002 /* Indicates "symbolic" linking. */
DF_TEXTREL DynFlag = 0x0004 /* Indicates there may be relocations in non-writable segments. */
DF_BIND_NOW DynFlag = 0x0008 /* Indicates that the dynamic linker should
process all relocations for the object
containing this entry before transferring
control to the program. */
DF_STATIC_TLS DynFlag = 0x0010 /* Indicates that the shared object or
executable contains code using a static
thread-local storage scheme. */
)
var dflagStrings = []intName{
{0x0001, "DF_ORIGIN"},
{0x0002, "DF_SYMBOLIC"},
{0x0004, "DF_TEXTREL"},
{0x0008, "DF_BIND_NOW"},
{0x0010, "DF_STATIC_TLS"},
}
func (i DynFlag) String() string { return flagName(uint32(i), dflagStrings, false) }
func (i DynFlag) GoString() string { return flagName(uint32(i), dflagStrings, true) }
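// The sketch below is illustrative and not part of the original source, and
// assumes a Go release that provides (*File).DynValue (added in Go 1.21): the
// value carried by a DT_FLAGS entry is a bit mask of the DynFlag constants
// above, so individual flags are tested with a bitwise AND. The function name
// is an assumption made for the example.
func sketchHasBindNow(f *File) (bool, error) {
	vals, err := f.DynValue(DT_FLAGS)
	if err != nil {
		return false, err
	}
	for _, v := range vals {
		if DynFlag(v)&DF_BIND_NOW != 0 {
			return true, nil
		}
	}
	return false, nil
}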
// DT_FLAGS_1 values.
type DynFlag1 uint32
const (
// Indicates that all relocations for this object must be processed before
// returning control to the program.
DF_1_NOW DynFlag1 = 0x00000001
// Unused.
DF_1_GLOBAL DynFlag1 = 0x00000002
// Indicates that the object is a member of a group.
DF_1_GROUP DynFlag1 = 0x00000004
// Indicates that the object cannot be deleted from a process.
DF_1_NODELETE DynFlag1 = 0x00000008
// Meaningful only for filters. Indicates that all associated filtees be
// processed immediately.
DF_1_LOADFLTR DynFlag1 = 0x00000010
// Indicates that this object's initialization section be run before any other
// objects loaded.
DF_1_INITFIRST DynFlag1 = 0x00000020
// Indicates that the object cannot be added to a running process with dlopen.
DF_1_NOOPEN DynFlag1 = 0x00000040
// Indicates the object requires $ORIGIN processing.
DF_1_ORIGIN DynFlag1 = 0x00000080
// Indicates that the object should use direct binding information.
DF_1_DIRECT DynFlag1 = 0x00000100
// Unused.
DF_1_TRANS DynFlag1 = 0x00000200
// Indicates that the object's symbol table is to interpose before all symbols
// except the primary load object, which is typically the executable.
DF_1_INTERPOSE DynFlag1 = 0x00000400
// Indicates that the search for dependencies of this object ignores any
// default library search paths.
DF_1_NODEFLIB DynFlag1 = 0x00000800
// Indicates that this object is not dumped by dldump. Candidates are objects
// with no relocations that might get included when generating alternative
// objects using crle.
DF_1_NODUMP DynFlag1 = 0x00001000
// Identifies this object as a configuration alternative object generated by
// crle. Triggers the runtime linker to search for a configuration file $ORIGIN/ld.config.app-name.
DF_1_CONFALT DynFlag1 = 0x00002000
// Meaningful only for filtees. Terminates a filter's search for any
// further filtees.
DF_1_ENDFILTEE DynFlag1 = 0x00004000
// Indicates that this object has displacement relocations applied.
DF_1_DISPRELDNE DynFlag1 = 0x00008000
// Indicates that this object has displacement relocations pending.
DF_1_DISPRELPND DynFlag1 = 0x00010000
// Indicates that this object contains symbols that cannot be directly
// bound to.
DF_1_NODIRECT DynFlag1 = 0x00020000
// Reserved for internal use by the kernel runtime-linker.
DF_1_IGNMULDEF DynFlag1 = 0x00040000
// Reserved for internal use by the kernel runtime-linker.
DF_1_NOKSYMS DynFlag1 = 0x00080000
// Reserved for internal use by the kernel runtime-linker.
DF_1_NOHDR DynFlag1 = 0x00100000
// Indicates that this object has been edited or has been modified since the
// object's original construction by the link-editor.
DF_1_EDITED DynFlag1 = 0x00200000
// Reserved for internal use by the kernel runtime-linker.
DF_1_NORELOC DynFlag1 = 0x00400000
// Indicates that the object contains individual symbols that should interpose
// before all symbols except the primary load object, which is typically the
// executable.
DF_1_SYMINTPOSE DynFlag1 = 0x00800000
// Indicates that the executable requires global auditing.
DF_1_GLOBAUDIT DynFlag1 = 0x01000000
// Indicates that the object defines, or makes reference to, singleton symbols.
DF_1_SINGLETON DynFlag1 = 0x02000000
// Indicates that the object is a stub.
DF_1_STUB DynFlag1 = 0x04000000
// Indicates that the object is a position-independent executable.
DF_1_PIE DynFlag1 = 0x08000000
// Indicates that the object is a kernel module.
DF_1_KMOD DynFlag1 = 0x10000000
// Indicates that the object is a weak standard filter.
DF_1_WEAKFILTER DynFlag1 = 0x20000000
// Unused.
DF_1_NOCOMMON DynFlag1 = 0x40000000
)
var dflag1Strings = []intName{
{0x00000001, "DF_1_NOW"},
{0x00000002, "DF_1_GLOBAL"},
{0x00000004, "DF_1_GROUP"},
{0x00000008, "DF_1_NODELETE"},
{0x00000010, "DF_1_LOADFLTR"},
{0x00000020, "DF_1_INITFIRST"},
{0x00000040, "DF_1_NOOPEN"},
{0x00000080, "DF_1_ORIGIN"},
{0x00000100, "DF_1_DIRECT"},
{0x00000200, "DF_1_TRANS"},
{0x00000400, "DF_1_INTERPOSE"},
{0x00000800, "DF_1_NODEFLIB"},
{0x00001000, "DF_1_NODUMP"},
{0x00002000, "DF_1_CONFALT"},
{0x00004000, "DF_1_ENDFILTEE"},
{0x00008000, "DF_1_DISPRELDNE"},
{0x00010000, "DF_1_DISPRELPND"},
{0x00020000, "DF_1_NODIRECT"},
{0x00040000, "DF_1_IGNMULDEF"},
{0x00080000, "DF_1_NOKSYMS"},
{0x00100000, "DF_1_NOHDR"},
{0x00200000, "DF_1_EDITED"},
{0x00400000, "DF_1_NORELOC"},
{0x00800000, "DF_1_SYMINTPOSE"},
{0x01000000, "DF_1_GLOBAUDIT"},
{0x02000000, "DF_1_SINGLETON"},
{0x04000000, "DF_1_STUB"},
{0x08000000, "DF_1_PIE"},
{0x10000000, "DF_1_KMOD"},
{0x20000000, "DF_1_WEAKFILTER"},
{0x40000000, "DF_1_NOCOMMON"},
}
func (i DynFlag1) String() string { return flagName(uint32(i), dflag1Strings, false) }
func (i DynFlag1) GoString() string { return flagName(uint32(i), dflag1Strings, true) }
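// The sketch below is illustrative and not part of the original source, and
// assumes (*File).DynValue (Go 1.21 and later): a position-independent
// executable is conventionally an ET_DYN object whose DT_FLAGS_1 entry has
// the DF_1_PIE bit set, so the two checks together give a reasonable PIE
// test. The function name is an assumption made for the example.
func sketchIsPIE(f *File) (bool, error) {
	if f.Type != ET_DYN {
		return false, nil
	}
	vals, err := f.DynValue(DT_FLAGS_1)
	if err != nil {
		return false, err
	}
	for _, v := range vals {
		if DynFlag1(v)&DF_1_PIE != 0 {
			return true, nil
		}
	}
	return false, nil
}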
// NType values; used in core files.
type NType int
const (
NT_PRSTATUS NType = 1 /* Process status. */
NT_FPREGSET NType = 2 /* Floating point registers. */
NT_PRPSINFO NType = 3 /* Process state info. */
)
var ntypeStrings = []intName{
{1, "NT_PRSTATUS"},
{2, "NT_FPREGSET"},
{3, "NT_PRPSINFO"},
}
func (i NType) String() string { return stringName(uint32(i), ntypeStrings, false) }
func (i NType) GoString() string { return stringName(uint32(i), ntypeStrings, true) }
/* Symbol Binding - ELFNN_ST_BIND - st_info */
type SymBind int
const (
STB_LOCAL SymBind = 0 /* Local symbol */
STB_GLOBAL SymBind = 1 /* Global symbol */
STB_WEAK SymBind = 2 /* like global - lower precedence */
STB_LOOS SymBind = 10 /* Reserved range for operating system */
STB_HIOS SymBind = 12 /* specific semantics. */
STB_LOPROC SymBind = 13 /* reserved range for processor */
STB_HIPROC SymBind = 15 /* specific semantics. */
)
var stbStrings = []intName{
{0, "STB_LOCAL"},
{1, "STB_GLOBAL"},
{2, "STB_WEAK"},
{10, "STB_LOOS"},
{12, "STB_HIOS"},
{13, "STB_LOPROC"},
{15, "STB_HIPROC"},
}
func (i SymBind) String() string { return stringName(uint32(i), stbStrings, false) }
func (i SymBind) GoString() string { return stringName(uint32(i), stbStrings, true) }
/* Symbol type - ELFNN_ST_TYPE - st_info */
type SymType int
const (
STT_NOTYPE SymType = 0 /* Unspecified type. */
STT_OBJECT SymType = 1 /* Data object. */
STT_FUNC SymType = 2 /* Function. */
STT_SECTION SymType = 3 /* Section. */
STT_FILE SymType = 4 /* Source file. */
STT_COMMON SymType = 5 /* Uninitialized common block. */
STT_TLS SymType = 6 /* TLS object. */
STT_LOOS SymType = 10 /* Reserved range for operating system */
STT_HIOS SymType = 12 /* specific semantics. */
STT_LOPROC SymType = 13 /* reserved range for processor */
STT_HIPROC SymType = 15 /* specific semantics. */
/* Non-standard symbol types. */
STT_RELC SymType = 8 /* Complex relocation expression. */
STT_SRELC SymType = 9 /* Signed complex relocation expression. */
STT_GNU_IFUNC SymType = 10 /* Indirect code object. */
)
var sttStrings = []intName{
{0, "STT_NOTYPE"},
{1, "STT_OBJECT"},
{2, "STT_FUNC"},
{3, "STT_SECTION"},
{4, "STT_FILE"},
{5, "STT_COMMON"},
{6, "STT_TLS"},
{8, "STT_RELC"},
{9, "STT_SRELC"},
{10, "STT_LOOS"},
{12, "STT_HIOS"},
{13, "STT_LOPROC"},
{15, "STT_HIPROC"},
}
func (i SymType) String() string { return stringName(uint32(i), sttStrings, false) }
func (i SymType) GoString() string { return stringName(uint32(i), sttStrings, true) }
/* Symbol visibility - ELFNN_ST_VISIBILITY - st_other */
type SymVis int
const (
STV_DEFAULT SymVis = 0x0 /* Default visibility (see binding). */
STV_INTERNAL SymVis = 0x1 /* Special meaning in relocatable objects. */
STV_HIDDEN SymVis = 0x2 /* Not visible. */
STV_PROTECTED SymVis = 0x3 /* Visible but not preemptible. */
)
var stvStrings = []intName{
{0x0, "STV_DEFAULT"},
{0x1, "STV_INTERNAL"},
{0x2, "STV_HIDDEN"},
{0x3, "STV_PROTECTED"},
}
func (i SymVis) String() string { return stringName(uint32(i), stvStrings, false) }
func (i SymVis) GoString() string { return stringName(uint32(i), stvStrings, true) }
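// The sketch below is illustrative and not part of the original source: in
// the ELF symbol table the binding and the type share the single st_info
// byte (binding in the high four bits, type in the low four) and the
// visibility occupies the low two bits of st_other, which is exactly what
// the package-level ST_BIND, ST_TYPE and ST_VISIBILITY helpers decode. The
// function name is an assumption made for the example.
func sketchSymbolInfo(sym Symbol) (SymBind, SymType, SymVis) {
	return ST_BIND(sym.Info), ST_TYPE(sym.Info), ST_VISIBILITY(sym.Other)
}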
/*
* Relocation types.
*/
// Relocation types for x86-64.
type R_X86_64 int
const (
R_X86_64_NONE R_X86_64 = 0 /* No relocation. */
R_X86_64_64 R_X86_64 = 1 /* Add 64 bit symbol value. */
R_X86_64_PC32 R_X86_64 = 2 /* PC-relative 32 bit signed sym value. */
R_X86_64_GOT32 R_X86_64 = 3 /* PC-relative 32 bit GOT offset. */
R_X86_64_PLT32 R_X86_64 = 4 /* PC-relative 32 bit PLT offset. */
R_X86_64_COPY R_X86_64 = 5 /* Copy data from shared object. */
R_X86_64_GLOB_DAT R_X86_64 = 6 /* Set GOT entry to data address. */
R_X86_64_JMP_SLOT R_X86_64 = 7 /* Set GOT entry to code address. */
R_X86_64_RELATIVE R_X86_64 = 8 /* Add load address of shared object. */
R_X86_64_GOTPCREL R_X86_64 = 9 /* Add 32 bit signed pcrel offset to GOT. */
R_X86_64_32 R_X86_64 = 10 /* Add 32 bit zero extended symbol value */
R_X86_64_32S R_X86_64 = 11 /* Add 32 bit sign extended symbol value */
R_X86_64_16 R_X86_64 = 12 /* Add 16 bit zero extended symbol value */
R_X86_64_PC16 R_X86_64 = 13 /* Add 16 bit sign extended pc relative symbol value */
R_X86_64_8 R_X86_64 = 14 /* Add 8 bit zero extended symbol value */
R_X86_64_PC8 R_X86_64 = 15 /* Add 8 bit sign extended pc relative symbol value */
R_X86_64_DTPMOD64 R_X86_64 = 16 /* ID of module containing symbol */
R_X86_64_DTPOFF64 R_X86_64 = 17 /* Offset in TLS block */
R_X86_64_TPOFF64 R_X86_64 = 18 /* Offset in static TLS block */
R_X86_64_TLSGD R_X86_64 = 19 /* PC relative offset to GD GOT entry */
R_X86_64_TLSLD R_X86_64 = 20 /* PC relative offset to LD GOT entry */
R_X86_64_DTPOFF32 R_X86_64 = 21 /* Offset in TLS block */
R_X86_64_GOTTPOFF R_X86_64 = 22 /* PC relative offset to IE GOT entry */
R_X86_64_TPOFF32 R_X86_64 = 23 /* Offset in static TLS block */
R_X86_64_PC64 R_X86_64 = 24 /* PC relative 64-bit sign extended symbol value. */
R_X86_64_GOTOFF64 R_X86_64 = 25
R_X86_64_GOTPC32 R_X86_64 = 26
R_X86_64_GOT64 R_X86_64 = 27
R_X86_64_GOTPCREL64 R_X86_64 = 28
R_X86_64_GOTPC64 R_X86_64 = 29
R_X86_64_GOTPLT64 R_X86_64 = 30
R_X86_64_PLTOFF64 R_X86_64 = 31
R_X86_64_SIZE32 R_X86_64 = 32
R_X86_64_SIZE64 R_X86_64 = 33
R_X86_64_GOTPC32_TLSDESC R_X86_64 = 34
R_X86_64_TLSDESC_CALL R_X86_64 = 35
R_X86_64_TLSDESC R_X86_64 = 36
R_X86_64_IRELATIVE R_X86_64 = 37
R_X86_64_RELATIVE64 R_X86_64 = 38
R_X86_64_PC32_BND R_X86_64 = 39
R_X86_64_PLT32_BND R_X86_64 = 40
R_X86_64_GOTPCRELX R_X86_64 = 41
R_X86_64_REX_GOTPCRELX R_X86_64 = 42
)
var rx86_64Strings = []intName{
{0, "R_X86_64_NONE"},
{1, "R_X86_64_64"},
{2, "R_X86_64_PC32"},
{3, "R_X86_64_GOT32"},
{4, "R_X86_64_PLT32"},
{5, "R_X86_64_COPY"},
{6, "R_X86_64_GLOB_DAT"},
{7, "R_X86_64_JMP_SLOT"},
{8, "R_X86_64_RELATIVE"},
{9, "R_X86_64_GOTPCREL"},
{10, "R_X86_64_32"},
{11, "R_X86_64_32S"},
{12, "R_X86_64_16"},
{13, "R_X86_64_PC16"},
{14, "R_X86_64_8"},
{15, "R_X86_64_PC8"},
{16, "R_X86_64_DTPMOD64"},
{17, "R_X86_64_DTPOFF64"},
{18, "R_X86_64_TPOFF64"},
{19, "R_X86_64_TLSGD"},
{20, "R_X86_64_TLSLD"},
{21, "R_X86_64_DTPOFF32"},
{22, "R_X86_64_GOTTPOFF"},
{23, "R_X86_64_TPOFF32"},
{24, "R_X86_64_PC64"},
{25, "R_X86_64_GOTOFF64"},
{26, "R_X86_64_GOTPC32"},
{27, "R_X86_64_GOT64"},
{28, "R_X86_64_GOTPCREL64"},
{29, "R_X86_64_GOTPC64"},
{30, "R_X86_64_GOTPLT64"},
{31, "R_X86_64_PLTOFF64"},
{32, "R_X86_64_SIZE32"},
{33, "R_X86_64_SIZE64"},
{34, "R_X86_64_GOTPC32_TLSDESC"},
{35, "R_X86_64_TLSDESC_CALL"},
{36, "R_X86_64_TLSDESC"},
{37, "R_X86_64_IRELATIVE"},
{38, "R_X86_64_RELATIVE64"},
{39, "R_X86_64_PC32_BND"},
{40, "R_X86_64_PLT32_BND"},
{41, "R_X86_64_GOTPCRELX"},
{42, "R_X86_64_REX_GOTPCRELX"},
}
func (i R_X86_64) String() string { return stringName(uint32(i), rx86_64Strings, false) }
func (i R_X86_64) GoString() string { return stringName(uint32(i), rx86_64Strings, true) }
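// The sketch below is illustrative and not part of the original source: in a
// 64-bit ELF file the Info word of an Rela entry packs the symbol table index
// in its upper 32 bits and the relocation type in its lower 32 bits, so on
// x86-64 the type is recovered by converting the result of R_TYPE64 to
// R_X86_64. The function name is an assumption made for the example.
func sketchRelaAMD64(rel Rela64) (symIndex uint32, typ R_X86_64) {
	return R_SYM64(rel.Info), R_X86_64(R_TYPE64(rel.Info))
}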
// Relocation types for AArch64 (aka arm64)
type R_AARCH64 int
const (
R_AARCH64_NONE R_AARCH64 = 0
R_AARCH64_P32_ABS32 R_AARCH64 = 1
R_AARCH64_P32_ABS16 R_AARCH64 = 2
R_AARCH64_P32_PREL32 R_AARCH64 = 3
R_AARCH64_P32_PREL16 R_AARCH64 = 4
R_AARCH64_P32_MOVW_UABS_G0 R_AARCH64 = 5
R_AARCH64_P32_MOVW_UABS_G0_NC R_AARCH64 = 6
R_AARCH64_P32_MOVW_UABS_G1 R_AARCH64 = 7
R_AARCH64_P32_MOVW_SABS_G0 R_AARCH64 = 8
R_AARCH64_P32_LD_PREL_LO19 R_AARCH64 = 9
R_AARCH64_P32_ADR_PREL_LO21 R_AARCH64 = 10
R_AARCH64_P32_ADR_PREL_PG_HI21 R_AARCH64 = 11
R_AARCH64_P32_ADD_ABS_LO12_NC R_AARCH64 = 12
R_AARCH64_P32_LDST8_ABS_LO12_NC R_AARCH64 = 13
R_AARCH64_P32_LDST16_ABS_LO12_NC R_AARCH64 = 14
R_AARCH64_P32_LDST32_ABS_LO12_NC R_AARCH64 = 15
R_AARCH64_P32_LDST64_ABS_LO12_NC R_AARCH64 = 16
R_AARCH64_P32_LDST128_ABS_LO12_NC R_AARCH64 = 17
R_AARCH64_P32_TSTBR14 R_AARCH64 = 18
R_AARCH64_P32_CONDBR19 R_AARCH64 = 19
R_AARCH64_P32_JUMP26 R_AARCH64 = 20
R_AARCH64_P32_CALL26 R_AARCH64 = 21
R_AARCH64_P32_GOT_LD_PREL19 R_AARCH64 = 25
R_AARCH64_P32_ADR_GOT_PAGE R_AARCH64 = 26
R_AARCH64_P32_LD32_GOT_LO12_NC R_AARCH64 = 27
R_AARCH64_P32_TLSGD_ADR_PAGE21 R_AARCH64 = 81
R_AARCH64_P32_TLSGD_ADD_LO12_NC R_AARCH64 = 82
R_AARCH64_P32_TLSIE_ADR_GOTTPREL_PAGE21 R_AARCH64 = 103
R_AARCH64_P32_TLSIE_LD32_GOTTPREL_LO12_NC R_AARCH64 = 104
R_AARCH64_P32_TLSIE_LD_GOTTPREL_PREL19 R_AARCH64 = 105
R_AARCH64_P32_TLSLE_MOVW_TPREL_G1 R_AARCH64 = 106
R_AARCH64_P32_TLSLE_MOVW_TPREL_G0 R_AARCH64 = 107
R_AARCH64_P32_TLSLE_MOVW_TPREL_G0_NC R_AARCH64 = 108
R_AARCH64_P32_TLSLE_ADD_TPREL_HI12 R_AARCH64 = 109
R_AARCH64_P32_TLSLE_ADD_TPREL_LO12 R_AARCH64 = 110
R_AARCH64_P32_TLSLE_ADD_TPREL_LO12_NC R_AARCH64 = 111
R_AARCH64_P32_TLSDESC_LD_PREL19 R_AARCH64 = 122
R_AARCH64_P32_TLSDESC_ADR_PREL21 R_AARCH64 = 123
R_AARCH64_P32_TLSDESC_ADR_PAGE21 R_AARCH64 = 124
R_AARCH64_P32_TLSDESC_LD32_LO12_NC R_AARCH64 = 125
R_AARCH64_P32_TLSDESC_ADD_LO12_NC R_AARCH64 = 126
R_AARCH64_P32_TLSDESC_CALL R_AARCH64 = 127
R_AARCH64_P32_COPY R_AARCH64 = 180
R_AARCH64_P32_GLOB_DAT R_AARCH64 = 181
R_AARCH64_P32_JUMP_SLOT R_AARCH64 = 182
R_AARCH64_P32_RELATIVE R_AARCH64 = 183
R_AARCH64_P32_TLS_DTPMOD R_AARCH64 = 184
R_AARCH64_P32_TLS_DTPREL R_AARCH64 = 185
R_AARCH64_P32_TLS_TPREL R_AARCH64 = 186
R_AARCH64_P32_TLSDESC R_AARCH64 = 187
R_AARCH64_P32_IRELATIVE R_AARCH64 = 188
R_AARCH64_NULL R_AARCH64 = 256
R_AARCH64_ABS64 R_AARCH64 = 257
R_AARCH64_ABS32 R_AARCH64 = 258
R_AARCH64_ABS16 R_AARCH64 = 259
R_AARCH64_PREL64 R_AARCH64 = 260
R_AARCH64_PREL32 R_AARCH64 = 261
R_AARCH64_PREL16 R_AARCH64 = 262
R_AARCH64_MOVW_UABS_G0 R_AARCH64 = 263
R_AARCH64_MOVW_UABS_G0_NC R_AARCH64 = 264
R_AARCH64_MOVW_UABS_G1 R_AARCH64 = 265
R_AARCH64_MOVW_UABS_G1_NC R_AARCH64 = 266
R_AARCH64_MOVW_UABS_G2 R_AARCH64 = 267
R_AARCH64_MOVW_UABS_G2_NC R_AARCH64 = 268
R_AARCH64_MOVW_UABS_G3 R_AARCH64 = 269
R_AARCH64_MOVW_SABS_G0 R_AARCH64 = 270
R_AARCH64_MOVW_SABS_G1 R_AARCH64 = 271
R_AARCH64_MOVW_SABS_G2 R_AARCH64 = 272
R_AARCH64_LD_PREL_LO19 R_AARCH64 = 273
R_AARCH64_ADR_PREL_LO21 R_AARCH64 = 274
R_AARCH64_ADR_PREL_PG_HI21 R_AARCH64 = 275
R_AARCH64_ADR_PREL_PG_HI21_NC R_AARCH64 = 276
R_AARCH64_ADD_ABS_LO12_NC R_AARCH64 = 277
R_AARCH64_LDST8_ABS_LO12_NC R_AARCH64 = 278
R_AARCH64_TSTBR14 R_AARCH64 = 279
R_AARCH64_CONDBR19 R_AARCH64 = 280
R_AARCH64_JUMP26 R_AARCH64 = 282
R_AARCH64_CALL26 R_AARCH64 = 283
R_AARCH64_LDST16_ABS_LO12_NC R_AARCH64 = 284
R_AARCH64_LDST32_ABS_LO12_NC R_AARCH64 = 285
R_AARCH64_LDST64_ABS_LO12_NC R_AARCH64 = 286
R_AARCH64_LDST128_ABS_LO12_NC R_AARCH64 = 299
R_AARCH64_GOT_LD_PREL19 R_AARCH64 = 309
R_AARCH64_LD64_GOTOFF_LO15 R_AARCH64 = 310
R_AARCH64_ADR_GOT_PAGE R_AARCH64 = 311
R_AARCH64_LD64_GOT_LO12_NC R_AARCH64 = 312
R_AARCH64_LD64_GOTPAGE_LO15 R_AARCH64 = 313
R_AARCH64_TLSGD_ADR_PREL21 R_AARCH64 = 512
R_AARCH64_TLSGD_ADR_PAGE21 R_AARCH64 = 513
R_AARCH64_TLSGD_ADD_LO12_NC R_AARCH64 = 514
R_AARCH64_TLSGD_MOVW_G1 R_AARCH64 = 515
R_AARCH64_TLSGD_MOVW_G0_NC R_AARCH64 = 516
R_AARCH64_TLSLD_ADR_PREL21 R_AARCH64 = 517
R_AARCH64_TLSLD_ADR_PAGE21 R_AARCH64 = 518
R_AARCH64_TLSIE_MOVW_GOTTPREL_G1 R_AARCH64 = 539
R_AARCH64_TLSIE_MOVW_GOTTPREL_G0_NC R_AARCH64 = 540
R_AARCH64_TLSIE_ADR_GOTTPREL_PAGE21 R_AARCH64 = 541
R_AARCH64_TLSIE_LD64_GOTTPREL_LO12_NC R_AARCH64 = 542
R_AARCH64_TLSIE_LD_GOTTPREL_PREL19 R_AARCH64 = 543
R_AARCH64_TLSLE_MOVW_TPREL_G2 R_AARCH64 = 544
R_AARCH64_TLSLE_MOVW_TPREL_G1 R_AARCH64 = 545
R_AARCH64_TLSLE_MOVW_TPREL_G1_NC R_AARCH64 = 546
R_AARCH64_TLSLE_MOVW_TPREL_G0 R_AARCH64 = 547
R_AARCH64_TLSLE_MOVW_TPREL_G0_NC R_AARCH64 = 548
R_AARCH64_TLSLE_ADD_TPREL_HI12 R_AARCH64 = 549
R_AARCH64_TLSLE_ADD_TPREL_LO12 R_AARCH64 = 550
R_AARCH64_TLSLE_ADD_TPREL_LO12_NC R_AARCH64 = 551
R_AARCH64_TLSDESC_LD_PREL19 R_AARCH64 = 560
R_AARCH64_TLSDESC_ADR_PREL21 R_AARCH64 = 561
R_AARCH64_TLSDESC_ADR_PAGE21 R_AARCH64 = 562
R_AARCH64_TLSDESC_LD64_LO12_NC R_AARCH64 = 563
R_AARCH64_TLSDESC_ADD_LO12_NC R_AARCH64 = 564
R_AARCH64_TLSDESC_OFF_G1 R_AARCH64 = 565
R_AARCH64_TLSDESC_OFF_G0_NC R_AARCH64 = 566
R_AARCH64_TLSDESC_LDR R_AARCH64 = 567
R_AARCH64_TLSDESC_ADD R_AARCH64 = 568
R_AARCH64_TLSDESC_CALL R_AARCH64 = 569
R_AARCH64_TLSLE_LDST128_TPREL_LO12 R_AARCH64 = 570
R_AARCH64_TLSLE_LDST128_TPREL_LO12_NC R_AARCH64 = 571
R_AARCH64_TLSLD_LDST128_DTPREL_LO12 R_AARCH64 = 572
R_AARCH64_TLSLD_LDST128_DTPREL_LO12_NC R_AARCH64 = 573
R_AARCH64_COPY R_AARCH64 = 1024
R_AARCH64_GLOB_DAT R_AARCH64 = 1025
R_AARCH64_JUMP_SLOT R_AARCH64 = 1026
R_AARCH64_RELATIVE R_AARCH64 = 1027
R_AARCH64_TLS_DTPMOD64 R_AARCH64 = 1028
R_AARCH64_TLS_DTPREL64 R_AARCH64 = 1029
R_AARCH64_TLS_TPREL64 R_AARCH64 = 1030
R_AARCH64_TLSDESC R_AARCH64 = 1031
R_AARCH64_IRELATIVE R_AARCH64 = 1032
)
var raarch64Strings = []intName{
{0, "R_AARCH64_NONE"},
{1, "R_AARCH64_P32_ABS32"},
{2, "R_AARCH64_P32_ABS16"},
{3, "R_AARCH64_P32_PREL32"},
{4, "R_AARCH64_P32_PREL16"},
{5, "R_AARCH64_P32_MOVW_UABS_G0"},
{6, "R_AARCH64_P32_MOVW_UABS_G0_NC"},
{7, "R_AARCH64_P32_MOVW_UABS_G1"},
{8, "R_AARCH64_P32_MOVW_SABS_G0"},
{9, "R_AARCH64_P32_LD_PREL_LO19"},
{10, "R_AARCH64_P32_ADR_PREL_LO21"},
{11, "R_AARCH64_P32_ADR_PREL_PG_HI21"},
{12, "R_AARCH64_P32_ADD_ABS_LO12_NC"},
{13, "R_AARCH64_P32_LDST8_ABS_LO12_NC"},
{14, "R_AARCH64_P32_LDST16_ABS_LO12_NC"},
{15, "R_AARCH64_P32_LDST32_ABS_LO12_NC"},
{16, "R_AARCH64_P32_LDST64_ABS_LO12_NC"},
{17, "R_AARCH64_P32_LDST128_ABS_LO12_NC"},
{18, "R_AARCH64_P32_TSTBR14"},
{19, "R_AARCH64_P32_CONDBR19"},
{20, "R_AARCH64_P32_JUMP26"},
{21, "R_AARCH64_P32_CALL26"},
{25, "R_AARCH64_P32_GOT_LD_PREL19"},
{26, "R_AARCH64_P32_ADR_GOT_PAGE"},
{27, "R_AARCH64_P32_LD32_GOT_LO12_NC"},
{81, "R_AARCH64_P32_TLSGD_ADR_PAGE21"},
{82, "R_AARCH64_P32_TLSGD_ADD_LO12_NC"},
{103, "R_AARCH64_P32_TLSIE_ADR_GOTTPREL_PAGE21"},
{104, "R_AARCH64_P32_TLSIE_LD32_GOTTPREL_LO12_NC"},
{105, "R_AARCH64_P32_TLSIE_LD_GOTTPREL_PREL19"},
{106, "R_AARCH64_P32_TLSLE_MOVW_TPREL_G1"},
{107, "R_AARCH64_P32_TLSLE_MOVW_TPREL_G0"},
{108, "R_AARCH64_P32_TLSLE_MOVW_TPREL_G0_NC"},
{109, "R_AARCH64_P32_TLSLE_ADD_TPREL_HI12"},
{110, "R_AARCH64_P32_TLSLE_ADD_TPREL_LO12"},
{111, "R_AARCH64_P32_TLSLE_ADD_TPREL_LO12_NC"},
{122, "R_AARCH64_P32_TLSDESC_LD_PREL19"},
{123, "R_AARCH64_P32_TLSDESC_ADR_PREL21"},
{124, "R_AARCH64_P32_TLSDESC_ADR_PAGE21"},
{125, "R_AARCH64_P32_TLSDESC_LD32_LO12_NC"},
{126, "R_AARCH64_P32_TLSDESC_ADD_LO12_NC"},
{127, "R_AARCH64_P32_TLSDESC_CALL"},
{180, "R_AARCH64_P32_COPY"},
{181, "R_AARCH64_P32_GLOB_DAT"},
{182, "R_AARCH64_P32_JUMP_SLOT"},
{183, "R_AARCH64_P32_RELATIVE"},
{184, "R_AARCH64_P32_TLS_DTPMOD"},
{185, "R_AARCH64_P32_TLS_DTPREL"},
{186, "R_AARCH64_P32_TLS_TPREL"},
{187, "R_AARCH64_P32_TLSDESC"},
{188, "R_AARCH64_P32_IRELATIVE"},
{256, "R_AARCH64_NULL"},
{257, "R_AARCH64_ABS64"},
{258, "R_AARCH64_ABS32"},
{259, "R_AARCH64_ABS16"},
{260, "R_AARCH64_PREL64"},
{261, "R_AARCH64_PREL32"},
{262, "R_AARCH64_PREL16"},
{263, "R_AARCH64_MOVW_UABS_G0"},
{264, "R_AARCH64_MOVW_UABS_G0_NC"},
{265, "R_AARCH64_MOVW_UABS_G1"},
{266, "R_AARCH64_MOVW_UABS_G1_NC"},
{267, "R_AARCH64_MOVW_UABS_G2"},
{268, "R_AARCH64_MOVW_UABS_G2_NC"},
{269, "R_AARCH64_MOVW_UABS_G3"},
{270, "R_AARCH64_MOVW_SABS_G0"},
{271, "R_AARCH64_MOVW_SABS_G1"},
{272, "R_AARCH64_MOVW_SABS_G2"},
{273, "R_AARCH64_LD_PREL_LO19"},
{274, "R_AARCH64_ADR_PREL_LO21"},
{275, "R_AARCH64_ADR_PREL_PG_HI21"},
{276, "R_AARCH64_ADR_PREL_PG_HI21_NC"},
{277, "R_AARCH64_ADD_ABS_LO12_NC"},
{278, "R_AARCH64_LDST8_ABS_LO12_NC"},
{279, "R_AARCH64_TSTBR14"},
{280, "R_AARCH64_CONDBR19"},
{282, "R_AARCH64_JUMP26"},
{283, "R_AARCH64_CALL26"},
{284, "R_AARCH64_LDST16_ABS_LO12_NC"},
{285, "R_AARCH64_LDST32_ABS_LO12_NC"},
{286, "R_AARCH64_LDST64_ABS_LO12_NC"},
{299, "R_AARCH64_LDST128_ABS_LO12_NC"},
{309, "R_AARCH64_GOT_LD_PREL19"},
{310, "R_AARCH64_LD64_GOTOFF_LO15"},
{311, "R_AARCH64_ADR_GOT_PAGE"},
{312, "R_AARCH64_LD64_GOT_LO12_NC"},
{313, "R_AARCH64_LD64_GOTPAGE_LO15"},
{512, "R_AARCH64_TLSGD_ADR_PREL21"},
{513, "R_AARCH64_TLSGD_ADR_PAGE21"},
{514, "R_AARCH64_TLSGD_ADD_LO12_NC"},
{515, "R_AARCH64_TLSGD_MOVW_G1"},
{516, "R_AARCH64_TLSGD_MOVW_G0_NC"},
{517, "R_AARCH64_TLSLD_ADR_PREL21"},
{518, "R_AARCH64_TLSLD_ADR_PAGE21"},
{539, "R_AARCH64_TLSIE_MOVW_GOTTPREL_G1"},
{540, "R_AARCH64_TLSIE_MOVW_GOTTPREL_G0_NC"},
{541, "R_AARCH64_TLSIE_ADR_GOTTPREL_PAGE21"},
{542, "R_AARCH64_TLSIE_LD64_GOTTPREL_LO12_NC"},
{543, "R_AARCH64_TLSIE_LD_GOTTPREL_PREL19"},
{544, "R_AARCH64_TLSLE_MOVW_TPREL_G2"},
{545, "R_AARCH64_TLSLE_MOVW_TPREL_G1"},
{546, "R_AARCH64_TLSLE_MOVW_TPREL_G1_NC"},
{547, "R_AARCH64_TLSLE_MOVW_TPREL_G0"},
{548, "R_AARCH64_TLSLE_MOVW_TPREL_G0_NC"},
{549, "R_AARCH64_TLSLE_ADD_TPREL_HI12"},
{550, "R_AARCH64_TLSLE_ADD_TPREL_LO12"},
{551, "R_AARCH64_TLSLE_ADD_TPREL_LO12_NC"},
{560, "R_AARCH64_TLSDESC_LD_PREL19"},
{561, "R_AARCH64_TLSDESC_ADR_PREL21"},
{562, "R_AARCH64_TLSDESC_ADR_PAGE21"},
{563, "R_AARCH64_TLSDESC_LD64_LO12_NC"},
{564, "R_AARCH64_TLSDESC_ADD_LO12_NC"},
{565, "R_AARCH64_TLSDESC_OFF_G1"},
{566, "R_AARCH64_TLSDESC_OFF_G0_NC"},
{567, "R_AARCH64_TLSDESC_LDR"},
{568, "R_AARCH64_TLSDESC_ADD"},
{569, "R_AARCH64_TLSDESC_CALL"},
{570, "R_AARCH64_TLSLE_LDST128_TPREL_LO12"},
{571, "R_AARCH64_TLSLE_LDST128_TPREL_LO12_NC"},
{572, "R_AARCH64_TLSLD_LDST128_DTPREL_LO12"},
{573, "R_AARCH64_TLSLD_LDST128_DTPREL_LO12_NC"},
{1024, "R_AARCH64_COPY"},
{1025, "R_AARCH64_GLOB_DAT"},
{1026, "R_AARCH64_JUMP_SLOT"},
{1027, "R_AARCH64_RELATIVE"},
{1028, "R_AARCH64_TLS_DTPMOD64"},
{1029, "R_AARCH64_TLS_DTPREL64"},
{1030, "R_AARCH64_TLS_TPREL64"},
{1031, "R_AARCH64_TLSDESC"},
{1032, "R_AARCH64_IRELATIVE"},
}
func (i R_AARCH64) String() string { return stringName(uint32(i), raarch64Strings, false) }
func (i R_AARCH64) GoString() string { return stringName(uint32(i), raarch64Strings, true) }
// Relocation types for Alpha.
type R_ALPHA int
const (
R_ALPHA_NONE R_ALPHA = 0 /* No reloc */
R_ALPHA_REFLONG R_ALPHA = 1 /* Direct 32 bit */
R_ALPHA_REFQUAD R_ALPHA = 2 /* Direct 64 bit */
R_ALPHA_GPREL32 R_ALPHA = 3 /* GP relative 32 bit */
R_ALPHA_LITERAL R_ALPHA = 4 /* GP relative 16 bit w/optimization */
R_ALPHA_LITUSE R_ALPHA = 5 /* Optimization hint for LITERAL */
R_ALPHA_GPDISP R_ALPHA = 6 /* Add displacement to GP */
R_ALPHA_BRADDR R_ALPHA = 7 /* PC+4 relative 23 bit shifted */
R_ALPHA_HINT R_ALPHA = 8 /* PC+4 relative 16 bit shifted */
R_ALPHA_SREL16 R_ALPHA = 9 /* PC relative 16 bit */
R_ALPHA_SREL32 R_ALPHA = 10 /* PC relative 32 bit */
R_ALPHA_SREL64 R_ALPHA = 11 /* PC relative 64 bit */
R_ALPHA_OP_PUSH R_ALPHA = 12 /* OP stack push */
R_ALPHA_OP_STORE R_ALPHA = 13 /* OP stack pop and store */
R_ALPHA_OP_PSUB R_ALPHA = 14 /* OP stack subtract */
R_ALPHA_OP_PRSHIFT R_ALPHA = 15 /* OP stack right shift */
R_ALPHA_GPVALUE R_ALPHA = 16
R_ALPHA_GPRELHIGH R_ALPHA = 17
R_ALPHA_GPRELLOW R_ALPHA = 18
R_ALPHA_IMMED_GP_16 R_ALPHA = 19
R_ALPHA_IMMED_GP_HI32 R_ALPHA = 20
R_ALPHA_IMMED_SCN_HI32 R_ALPHA = 21
R_ALPHA_IMMED_BR_HI32 R_ALPHA = 22
R_ALPHA_IMMED_LO32 R_ALPHA = 23
R_ALPHA_COPY R_ALPHA = 24 /* Copy symbol at runtime */
R_ALPHA_GLOB_DAT R_ALPHA = 25 /* Create GOT entry */
R_ALPHA_JMP_SLOT R_ALPHA = 26 /* Create PLT entry */
R_ALPHA_RELATIVE R_ALPHA = 27 /* Adjust by program base */
)
var ralphaStrings = []intName{
{0, "R_ALPHA_NONE"},
{1, "R_ALPHA_REFLONG"},
{2, "R_ALPHA_REFQUAD"},
{3, "R_ALPHA_GPREL32"},
{4, "R_ALPHA_LITERAL"},
{5, "R_ALPHA_LITUSE"},
{6, "R_ALPHA_GPDISP"},
{7, "R_ALPHA_BRADDR"},
{8, "R_ALPHA_HINT"},
{9, "R_ALPHA_SREL16"},
{10, "R_ALPHA_SREL32"},
{11, "R_ALPHA_SREL64"},
{12, "R_ALPHA_OP_PUSH"},
{13, "R_ALPHA_OP_STORE"},
{14, "R_ALPHA_OP_PSUB"},
{15, "R_ALPHA_OP_PRSHIFT"},
{16, "R_ALPHA_GPVALUE"},
{17, "R_ALPHA_GPRELHIGH"},
{18, "R_ALPHA_GPRELLOW"},
{19, "R_ALPHA_IMMED_GP_16"},
{20, "R_ALPHA_IMMED_GP_HI32"},
{21, "R_ALPHA_IMMED_SCN_HI32"},
{22, "R_ALPHA_IMMED_BR_HI32"},
{23, "R_ALPHA_IMMED_LO32"},
{24, "R_ALPHA_COPY"},
{25, "R_ALPHA_GLOB_DAT"},
{26, "R_ALPHA_JMP_SLOT"},
{27, "R_ALPHA_RELATIVE"},
}
func (i R_ALPHA) String() string { return stringName(uint32(i), ralphaStrings, false) }
func (i R_ALPHA) GoString() string { return stringName(uint32(i), ralphaStrings, true) }
// Relocation types for ARM.
type R_ARM int
const (
R_ARM_NONE R_ARM = 0 /* No relocation. */
R_ARM_PC24 R_ARM = 1
R_ARM_ABS32 R_ARM = 2
R_ARM_REL32 R_ARM = 3
R_ARM_PC13 R_ARM = 4
R_ARM_ABS16 R_ARM = 5
R_ARM_ABS12 R_ARM = 6
R_ARM_THM_ABS5 R_ARM = 7
R_ARM_ABS8 R_ARM = 8
R_ARM_SBREL32 R_ARM = 9
R_ARM_THM_PC22 R_ARM = 10
R_ARM_THM_PC8 R_ARM = 11
R_ARM_AMP_VCALL9 R_ARM = 12
R_ARM_SWI24 R_ARM = 13
R_ARM_THM_SWI8 R_ARM = 14
R_ARM_XPC25 R_ARM = 15
R_ARM_THM_XPC22 R_ARM = 16
R_ARM_TLS_DTPMOD32 R_ARM = 17
R_ARM_TLS_DTPOFF32 R_ARM = 18
R_ARM_TLS_TPOFF32 R_ARM = 19
R_ARM_COPY R_ARM = 20 /* Copy data from shared object. */
R_ARM_GLOB_DAT R_ARM = 21 /* Set GOT entry to data address. */
R_ARM_JUMP_SLOT R_ARM = 22 /* Set GOT entry to code address. */
R_ARM_RELATIVE R_ARM = 23 /* Add load address of shared object. */
R_ARM_GOTOFF R_ARM = 24 /* Add GOT-relative symbol address. */
R_ARM_GOTPC R_ARM = 25 /* Add PC-relative GOT table address. */
R_ARM_GOT32 R_ARM = 26 /* Add PC-relative GOT offset. */
R_ARM_PLT32 R_ARM = 27 /* Add PC-relative PLT offset. */
R_ARM_CALL R_ARM = 28
R_ARM_JUMP24 R_ARM = 29
R_ARM_THM_JUMP24 R_ARM = 30
R_ARM_BASE_ABS R_ARM = 31
R_ARM_ALU_PCREL_7_0 R_ARM = 32
R_ARM_ALU_PCREL_15_8 R_ARM = 33
R_ARM_ALU_PCREL_23_15 R_ARM = 34
R_ARM_LDR_SBREL_11_10_NC R_ARM = 35
R_ARM_ALU_SBREL_19_12_NC R_ARM = 36
R_ARM_ALU_SBREL_27_20_CK R_ARM = 37
R_ARM_TARGET1 R_ARM = 38
R_ARM_SBREL31 R_ARM = 39
R_ARM_V4BX R_ARM = 40
R_ARM_TARGET2 R_ARM = 41
R_ARM_PREL31 R_ARM = 42
R_ARM_MOVW_ABS_NC R_ARM = 43
R_ARM_MOVT_ABS R_ARM = 44
R_ARM_MOVW_PREL_NC R_ARM = 45
R_ARM_MOVT_PREL R_ARM = 46
R_ARM_THM_MOVW_ABS_NC R_ARM = 47
R_ARM_THM_MOVT_ABS R_ARM = 48
R_ARM_THM_MOVW_PREL_NC R_ARM = 49
R_ARM_THM_MOVT_PREL R_ARM = 50
R_ARM_THM_JUMP19 R_ARM = 51
R_ARM_THM_JUMP6 R_ARM = 52
R_ARM_THM_ALU_PREL_11_0 R_ARM = 53
R_ARM_THM_PC12 R_ARM = 54
R_ARM_ABS32_NOI R_ARM = 55
R_ARM_REL32_NOI R_ARM = 56
R_ARM_ALU_PC_G0_NC R_ARM = 57
R_ARM_ALU_PC_G0 R_ARM = 58
R_ARM_ALU_PC_G1_NC R_ARM = 59
R_ARM_ALU_PC_G1 R_ARM = 60
R_ARM_ALU_PC_G2 R_ARM = 61
R_ARM_LDR_PC_G1 R_ARM = 62
R_ARM_LDR_PC_G2 R_ARM = 63
R_ARM_LDRS_PC_G0 R_ARM = 64
R_ARM_LDRS_PC_G1 R_ARM = 65
R_ARM_LDRS_PC_G2 R_ARM = 66
R_ARM_LDC_PC_G0 R_ARM = 67
R_ARM_LDC_PC_G1 R_ARM = 68
R_ARM_LDC_PC_G2 R_ARM = 69
R_ARM_ALU_SB_G0_NC R_ARM = 70
R_ARM_ALU_SB_G0 R_ARM = 71
R_ARM_ALU_SB_G1_NC R_ARM = 72
R_ARM_ALU_SB_G1 R_ARM = 73
R_ARM_ALU_SB_G2 R_ARM = 74
R_ARM_LDR_SB_G0 R_ARM = 75
R_ARM_LDR_SB_G1 R_ARM = 76
R_ARM_LDR_SB_G2 R_ARM = 77
R_ARM_LDRS_SB_G0 R_ARM = 78
R_ARM_LDRS_SB_G1 R_ARM = 79
R_ARM_LDRS_SB_G2 R_ARM = 80
R_ARM_LDC_SB_G0 R_ARM = 81
R_ARM_LDC_SB_G1 R_ARM = 82
R_ARM_LDC_SB_G2 R_ARM = 83
R_ARM_MOVW_BREL_NC R_ARM = 84
R_ARM_MOVT_BREL R_ARM = 85
R_ARM_MOVW_BREL R_ARM = 86
R_ARM_THM_MOVW_BREL_NC R_ARM = 87
R_ARM_THM_MOVT_BREL R_ARM = 88
R_ARM_THM_MOVW_BREL R_ARM = 89
R_ARM_TLS_GOTDESC R_ARM = 90
R_ARM_TLS_CALL R_ARM = 91
R_ARM_TLS_DESCSEQ R_ARM = 92
R_ARM_THM_TLS_CALL R_ARM = 93
R_ARM_PLT32_ABS R_ARM = 94
R_ARM_GOT_ABS R_ARM = 95
R_ARM_GOT_PREL R_ARM = 96
R_ARM_GOT_BREL12 R_ARM = 97
R_ARM_GOTOFF12 R_ARM = 98
R_ARM_GOTRELAX R_ARM = 99
R_ARM_GNU_VTENTRY R_ARM = 100
R_ARM_GNU_VTINHERIT R_ARM = 101
R_ARM_THM_JUMP11 R_ARM = 102
R_ARM_THM_JUMP8 R_ARM = 103
R_ARM_TLS_GD32 R_ARM = 104
R_ARM_TLS_LDM32 R_ARM = 105
R_ARM_TLS_LDO32 R_ARM = 106
R_ARM_TLS_IE32 R_ARM = 107
R_ARM_TLS_LE32 R_ARM = 108
R_ARM_TLS_LDO12 R_ARM = 109
R_ARM_TLS_LE12 R_ARM = 110
R_ARM_TLS_IE12GP R_ARM = 111
R_ARM_PRIVATE_0 R_ARM = 112
R_ARM_PRIVATE_1 R_ARM = 113
R_ARM_PRIVATE_2 R_ARM = 114
R_ARM_PRIVATE_3 R_ARM = 115
R_ARM_PRIVATE_4 R_ARM = 116
R_ARM_PRIVATE_5 R_ARM = 117
R_ARM_PRIVATE_6 R_ARM = 118
R_ARM_PRIVATE_7 R_ARM = 119
R_ARM_PRIVATE_8 R_ARM = 120
R_ARM_PRIVATE_9 R_ARM = 121
R_ARM_PRIVATE_10 R_ARM = 122
R_ARM_PRIVATE_11 R_ARM = 123
R_ARM_PRIVATE_12 R_ARM = 124
R_ARM_PRIVATE_13 R_ARM = 125
R_ARM_PRIVATE_14 R_ARM = 126
R_ARM_PRIVATE_15 R_ARM = 127
R_ARM_ME_TOO R_ARM = 128
R_ARM_THM_TLS_DESCSEQ16 R_ARM = 129
R_ARM_THM_TLS_DESCSEQ32 R_ARM = 130
R_ARM_THM_GOT_BREL12 R_ARM = 131
R_ARM_THM_ALU_ABS_G0_NC R_ARM = 132
R_ARM_THM_ALU_ABS_G1_NC R_ARM = 133
R_ARM_THM_ALU_ABS_G2_NC R_ARM = 134
R_ARM_THM_ALU_ABS_G3 R_ARM = 135
R_ARM_IRELATIVE R_ARM = 160
R_ARM_RXPC25 R_ARM = 249
R_ARM_RSBREL32 R_ARM = 250
R_ARM_THM_RPC22 R_ARM = 251
R_ARM_RREL32 R_ARM = 252
R_ARM_RABS32 R_ARM = 253
R_ARM_RPC24 R_ARM = 254
R_ARM_RBASE R_ARM = 255
)
var rarmStrings = []intName{
{0, "R_ARM_NONE"},
{1, "R_ARM_PC24"},
{2, "R_ARM_ABS32"},
{3, "R_ARM_REL32"},
{4, "R_ARM_PC13"},
{5, "R_ARM_ABS16"},
{6, "R_ARM_ABS12"},
{7, "R_ARM_THM_ABS5"},
{8, "R_ARM_ABS8"},
{9, "R_ARM_SBREL32"},
{10, "R_ARM_THM_PC22"},
{11, "R_ARM_THM_PC8"},
{12, "R_ARM_AMP_VCALL9"},
{13, "R_ARM_SWI24"},
{14, "R_ARM_THM_SWI8"},
{15, "R_ARM_XPC25"},
{16, "R_ARM_THM_XPC22"},
{17, "R_ARM_TLS_DTPMOD32"},
{18, "R_ARM_TLS_DTPOFF32"},
{19, "R_ARM_TLS_TPOFF32"},
{20, "R_ARM_COPY"},
{21, "R_ARM_GLOB_DAT"},
{22, "R_ARM_JUMP_SLOT"},
{23, "R_ARM_RELATIVE"},
{24, "R_ARM_GOTOFF"},
{25, "R_ARM_GOTPC"},
{26, "R_ARM_GOT32"},
{27, "R_ARM_PLT32"},
{28, "R_ARM_CALL"},
{29, "R_ARM_JUMP24"},
{30, "R_ARM_THM_JUMP24"},
{31, "R_ARM_BASE_ABS"},
{32, "R_ARM_ALU_PCREL_7_0"},
{33, "R_ARM_ALU_PCREL_15_8"},
{34, "R_ARM_ALU_PCREL_23_15"},
{35, "R_ARM_LDR_SBREL_11_10_NC"},
{36, "R_ARM_ALU_SBREL_19_12_NC"},
{37, "R_ARM_ALU_SBREL_27_20_CK"},
{38, "R_ARM_TARGET1"},
{39, "R_ARM_SBREL31"},
{40, "R_ARM_V4BX"},
{41, "R_ARM_TARGET2"},
{42, "R_ARM_PREL31"},
{43, "R_ARM_MOVW_ABS_NC"},
{44, "R_ARM_MOVT_ABS"},
{45, "R_ARM_MOVW_PREL_NC"},
{46, "R_ARM_MOVT_PREL"},
{47, "R_ARM_THM_MOVW_ABS_NC"},
{48, "R_ARM_THM_MOVT_ABS"},
{49, "R_ARM_THM_MOVW_PREL_NC"},
{50, "R_ARM_THM_MOVT_PREL"},
{51, "R_ARM_THM_JUMP19"},
{52, "R_ARM_THM_JUMP6"},
{53, "R_ARM_THM_ALU_PREL_11_0"},
{54, "R_ARM_THM_PC12"},
{55, "R_ARM_ABS32_NOI"},
{56, "R_ARM_REL32_NOI"},
{57, "R_ARM_ALU_PC_G0_NC"},
{58, "R_ARM_ALU_PC_G0"},
{59, "R_ARM_ALU_PC_G1_NC"},
{60, "R_ARM_ALU_PC_G1"},
{61, "R_ARM_ALU_PC_G2"},
{62, "R_ARM_LDR_PC_G1"},
{63, "R_ARM_LDR_PC_G2"},
{64, "R_ARM_LDRS_PC_G0"},
{65, "R_ARM_LDRS_PC_G1"},
{66, "R_ARM_LDRS_PC_G2"},
{67, "R_ARM_LDC_PC_G0"},
{68, "R_ARM_LDC_PC_G1"},
{69, "R_ARM_LDC_PC_G2"},
{70, "R_ARM_ALU_SB_G0_NC"},
{71, "R_ARM_ALU_SB_G0"},
{72, "R_ARM_ALU_SB_G1_NC"},
{73, "R_ARM_ALU_SB_G1"},
{74, "R_ARM_ALU_SB_G2"},
{75, "R_ARM_LDR_SB_G0"},
{76, "R_ARM_LDR_SB_G1"},
{77, "R_ARM_LDR_SB_G2"},
{78, "R_ARM_LDRS_SB_G0"},
{79, "R_ARM_LDRS_SB_G1"},
{80, "R_ARM_LDRS_SB_G2"},
{81, "R_ARM_LDC_SB_G0"},
{82, "R_ARM_LDC_SB_G1"},
{83, "R_ARM_LDC_SB_G2"},
{84, "R_ARM_MOVW_BREL_NC"},
{85, "R_ARM_MOVT_BREL"},
{86, "R_ARM_MOVW_BREL"},
{87, "R_ARM_THM_MOVW_BREL_NC"},
{88, "R_ARM_THM_MOVT_BREL"},
{89, "R_ARM_THM_MOVW_BREL"},
{90, "R_ARM_TLS_GOTDESC"},
{91, "R_ARM_TLS_CALL"},
{92, "R_ARM_TLS_DESCSEQ"},
{93, "R_ARM_THM_TLS_CALL"},
{94, "R_ARM_PLT32_ABS"},
{95, "R_ARM_GOT_ABS"},
{96, "R_ARM_GOT_PREL"},
{97, "R_ARM_GOT_BREL12"},
{98, "R_ARM_GOTOFF12"},
{99, "R_ARM_GOTRELAX"},
{100, "R_ARM_GNU_VTENTRY"},
{101, "R_ARM_GNU_VTINHERIT"},
{102, "R_ARM_THM_JUMP11"},
{103, "R_ARM_THM_JUMP8"},
{104, "R_ARM_TLS_GD32"},
{105, "R_ARM_TLS_LDM32"},
{106, "R_ARM_TLS_LDO32"},
{107, "R_ARM_TLS_IE32"},
{108, "R_ARM_TLS_LE32"},
{109, "R_ARM_TLS_LDO12"},
{110, "R_ARM_TLS_LE12"},
{111, "R_ARM_TLS_IE12GP"},
{112, "R_ARM_PRIVATE_0"},
{113, "R_ARM_PRIVATE_1"},
{114, "R_ARM_PRIVATE_2"},
{115, "R_ARM_PRIVATE_3"},
{116, "R_ARM_PRIVATE_4"},
{117, "R_ARM_PRIVATE_5"},
{118, "R_ARM_PRIVATE_6"},
{119, "R_ARM_PRIVATE_7"},
{120, "R_ARM_PRIVATE_8"},
{121, "R_ARM_PRIVATE_9"},
{122, "R_ARM_PRIVATE_10"},
{123, "R_ARM_PRIVATE_11"},
{124, "R_ARM_PRIVATE_12"},
{125, "R_ARM_PRIVATE_13"},
{126, "R_ARM_PRIVATE_14"},
{127, "R_ARM_PRIVATE_15"},
{128, "R_ARM_ME_TOO"},
{129, "R_ARM_THM_TLS_DESCSEQ16"},
{130, "R_ARM_THM_TLS_DESCSEQ32"},
{131, "R_ARM_THM_GOT_BREL12"},
{132, "R_ARM_THM_ALU_ABS_G0_NC"},
{133, "R_ARM_THM_ALU_ABS_G1_NC"},
{134, "R_ARM_THM_ALU_ABS_G2_NC"},
{135, "R_ARM_THM_ALU_ABS_G3"},
{160, "R_ARM_IRELATIVE"},
{249, "R_ARM_RXPC25"},
{250, "R_ARM_RSBREL32"},
{251, "R_ARM_THM_RPC22"},
{252, "R_ARM_RREL32"},
{253, "R_ARM_RABS32"},
{254, "R_ARM_RPC24"},
{255, "R_ARM_RBASE"},
}
func (i R_ARM) String() string { return stringName(uint32(i), rarmStrings, false) }
func (i R_ARM) GoString() string { return stringName(uint32(i), rarmStrings, true) }
// Relocation types for 386.
type R_386 int
const (
R_386_NONE R_386 = 0 /* No relocation. */
R_386_32 R_386 = 1 /* Add symbol value. */
R_386_PC32 R_386 = 2 /* Add PC-relative symbol value. */
R_386_GOT32 R_386 = 3 /* Add PC-relative GOT offset. */
R_386_PLT32 R_386 = 4 /* Add PC-relative PLT offset. */
R_386_COPY R_386 = 5 /* Copy data from shared object. */
R_386_GLOB_DAT R_386 = 6 /* Set GOT entry to data address. */
R_386_JMP_SLOT R_386 = 7 /* Set GOT entry to code address. */
R_386_RELATIVE R_386 = 8 /* Add load address of shared object. */
R_386_GOTOFF R_386 = 9 /* Add GOT-relative symbol address. */
R_386_GOTPC R_386 = 10 /* Add PC-relative GOT table address. */
R_386_32PLT R_386 = 11
R_386_TLS_TPOFF R_386 = 14 /* Negative offset in static TLS block */
R_386_TLS_IE R_386 = 15 /* Absolute address of GOT for -ve static TLS */
R_386_TLS_GOTIE R_386 = 16 /* GOT entry for negative static TLS block */
R_386_TLS_LE R_386 = 17 /* Negative offset relative to static TLS */
R_386_TLS_GD R_386 = 18 /* 32 bit offset to GOT (index,off) pair */
R_386_TLS_LDM R_386 = 19 /* 32 bit offset to GOT (index,zero) pair */
R_386_16 R_386 = 20
R_386_PC16 R_386 = 21
R_386_8 R_386 = 22
R_386_PC8 R_386 = 23
R_386_TLS_GD_32 R_386 = 24 /* 32 bit offset to GOT (index,off) pair */
R_386_TLS_GD_PUSH R_386 = 25 /* pushl instruction for Sun ABI GD sequence */
R_386_TLS_GD_CALL R_386 = 26 /* call instruction for Sun ABI GD sequence */
R_386_TLS_GD_POP R_386 = 27 /* popl instruction for Sun ABI GD sequence */
R_386_TLS_LDM_32 R_386 = 28 /* 32 bit offset to GOT (index,zero) pair */
R_386_TLS_LDM_PUSH R_386 = 29 /* pushl instruction for Sun ABI LD sequence */
R_386_TLS_LDM_CALL R_386 = 30 /* call instruction for Sun ABI LD sequence */
R_386_TLS_LDM_POP R_386 = 31 /* popl instruction for Sun ABI LD sequence */
R_386_TLS_LDO_32 R_386 = 32 /* 32 bit offset from start of TLS block */
R_386_TLS_IE_32 R_386 = 33 /* 32 bit offset to GOT static TLS offset entry */
R_386_TLS_LE_32 R_386 = 34 /* 32 bit offset within static TLS block */
R_386_TLS_DTPMOD32 R_386 = 35 /* GOT entry containing TLS index */
R_386_TLS_DTPOFF32 R_386 = 36 /* GOT entry containing TLS offset */
R_386_TLS_TPOFF32 R_386 = 37 /* GOT entry of -ve static TLS offset */
R_386_SIZE32 R_386 = 38
R_386_TLS_GOTDESC R_386 = 39
R_386_TLS_DESC_CALL R_386 = 40
R_386_TLS_DESC R_386 = 41
R_386_IRELATIVE R_386 = 42
R_386_GOT32X R_386 = 43
)
var r386Strings = []intName{
{0, "R_386_NONE"},
{1, "R_386_32"},
{2, "R_386_PC32"},
{3, "R_386_GOT32"},
{4, "R_386_PLT32"},
{5, "R_386_COPY"},
{6, "R_386_GLOB_DAT"},
{7, "R_386_JMP_SLOT"},
{8, "R_386_RELATIVE"},
{9, "R_386_GOTOFF"},
{10, "R_386_GOTPC"},
{11, "R_386_32PLT"},
{14, "R_386_TLS_TPOFF"},
{15, "R_386_TLS_IE"},
{16, "R_386_TLS_GOTIE"},
{17, "R_386_TLS_LE"},
{18, "R_386_TLS_GD"},
{19, "R_386_TLS_LDM"},
{20, "R_386_16"},
{21, "R_386_PC16"},
{22, "R_386_8"},
{23, "R_386_PC8"},
{24, "R_386_TLS_GD_32"},
{25, "R_386_TLS_GD_PUSH"},
{26, "R_386_TLS_GD_CALL"},
{27, "R_386_TLS_GD_POP"},
{28, "R_386_TLS_LDM_32"},
{29, "R_386_TLS_LDM_PUSH"},
{30, "R_386_TLS_LDM_CALL"},
{31, "R_386_TLS_LDM_POP"},
{32, "R_386_TLS_LDO_32"},
{33, "R_386_TLS_IE_32"},
{34, "R_386_TLS_LE_32"},
{35, "R_386_TLS_DTPMOD32"},
{36, "R_386_TLS_DTPOFF32"},
{37, "R_386_TLS_TPOFF32"},
{38, "R_386_SIZE32"},
{39, "R_386_TLS_GOTDESC"},
{40, "R_386_TLS_DESC_CALL"},
{41, "R_386_TLS_DESC"},
{42, "R_386_IRELATIVE"},
{43, "R_386_GOT32X"},
}
func (i R_386) String() string { return stringName(uint32(i), r386Strings, false) }
func (i R_386) GoString() string { return stringName(uint32(i), r386Strings, true) }
// Relocation types for MIPS.
type R_MIPS int
const (
R_MIPS_NONE R_MIPS = 0
R_MIPS_16 R_MIPS = 1
R_MIPS_32 R_MIPS = 2
R_MIPS_REL32 R_MIPS = 3
R_MIPS_26 R_MIPS = 4
R_MIPS_HI16 R_MIPS = 5 /* high 16 bits of symbol value */
R_MIPS_LO16 R_MIPS = 6 /* low 16 bits of symbol value */
R_MIPS_GPREL16 R_MIPS = 7 /* GP-relative reference */
R_MIPS_LITERAL R_MIPS = 8 /* Reference to literal section */
R_MIPS_GOT16 R_MIPS = 9 /* Reference to global offset table */
R_MIPS_PC16 R_MIPS = 10 /* 16 bit PC relative reference */
R_MIPS_CALL16 R_MIPS = 11 /* 16 bit call through glbl offset tbl */
R_MIPS_GPREL32 R_MIPS = 12
R_MIPS_SHIFT5 R_MIPS = 16
R_MIPS_SHIFT6 R_MIPS = 17
R_MIPS_64 R_MIPS = 18
R_MIPS_GOT_DISP R_MIPS = 19
R_MIPS_GOT_PAGE R_MIPS = 20
R_MIPS_GOT_OFST R_MIPS = 21
R_MIPS_GOT_HI16 R_MIPS = 22
R_MIPS_GOT_LO16 R_MIPS = 23
R_MIPS_SUB R_MIPS = 24
R_MIPS_INSERT_A R_MIPS = 25
R_MIPS_INSERT_B R_MIPS = 26
R_MIPS_DELETE R_MIPS = 27
R_MIPS_HIGHER R_MIPS = 28
R_MIPS_HIGHEST R_MIPS = 29
R_MIPS_CALL_HI16 R_MIPS = 30
R_MIPS_CALL_LO16 R_MIPS = 31
R_MIPS_SCN_DISP R_MIPS = 32
R_MIPS_REL16 R_MIPS = 33
R_MIPS_ADD_IMMEDIATE R_MIPS = 34
R_MIPS_PJUMP R_MIPS = 35
R_MIPS_RELGOT R_MIPS = 36
R_MIPS_JALR R_MIPS = 37
R_MIPS_TLS_DTPMOD32 R_MIPS = 38 /* Module number 32 bit */
R_MIPS_TLS_DTPREL32 R_MIPS = 39 /* Module-relative offset 32 bit */
R_MIPS_TLS_DTPMOD64 R_MIPS = 40 /* Module number 64 bit */
R_MIPS_TLS_DTPREL64 R_MIPS = 41 /* Module-relative offset 64 bit */
R_MIPS_TLS_GD R_MIPS = 42 /* 16 bit GOT offset for GD */
R_MIPS_TLS_LDM R_MIPS = 43 /* 16 bit GOT offset for LDM */
R_MIPS_TLS_DTPREL_HI16 R_MIPS = 44 /* Module-relative offset, high 16 bits */
R_MIPS_TLS_DTPREL_LO16 R_MIPS = 45 /* Module-relative offset, low 16 bits */
R_MIPS_TLS_GOTTPREL R_MIPS = 46 /* 16 bit GOT offset for IE */
R_MIPS_TLS_TPREL32 R_MIPS = 47 /* TP-relative offset, 32 bit */
R_MIPS_TLS_TPREL64 R_MIPS = 48 /* TP-relative offset, 64 bit */
R_MIPS_TLS_TPREL_HI16 R_MIPS = 49 /* TP-relative offset, high 16 bits */
R_MIPS_TLS_TPREL_LO16 R_MIPS = 50 /* TP-relative offset, low 16 bits */
R_MIPS_PC32 R_MIPS = 248 /* 32 bit PC relative reference */
)
var rmipsStrings = []intName{
{0, "R_MIPS_NONE"},
{1, "R_MIPS_16"},
{2, "R_MIPS_32"},
{3, "R_MIPS_REL32"},
{4, "R_MIPS_26"},
{5, "R_MIPS_HI16"},
{6, "R_MIPS_LO16"},
{7, "R_MIPS_GPREL16"},
{8, "R_MIPS_LITERAL"},
{9, "R_MIPS_GOT16"},
{10, "R_MIPS_PC16"},
{11, "R_MIPS_CALL16"},
{12, "R_MIPS_GPREL32"},
{16, "R_MIPS_SHIFT5"},
{17, "R_MIPS_SHIFT6"},
{18, "R_MIPS_64"},
{19, "R_MIPS_GOT_DISP"},
{20, "R_MIPS_GOT_PAGE"},
{21, "R_MIPS_GOT_OFST"},
{22, "R_MIPS_GOT_HI16"},
{23, "R_MIPS_GOT_LO16"},
{24, "R_MIPS_SUB"},
{25, "R_MIPS_INSERT_A"},
{26, "R_MIPS_INSERT_B"},
{27, "R_MIPS_DELETE"},
{28, "R_MIPS_HIGHER"},
{29, "R_MIPS_HIGHEST"},
{30, "R_MIPS_CALL_HI16"},
{31, "R_MIPS_CALL_LO16"},
{32, "R_MIPS_SCN_DISP"},
{33, "R_MIPS_REL16"},
{34, "R_MIPS_ADD_IMMEDIATE"},
{35, "R_MIPS_PJUMP"},
{36, "R_MIPS_RELGOT"},
{37, "R_MIPS_JALR"},
{38, "R_MIPS_TLS_DTPMOD32"},
{39, "R_MIPS_TLS_DTPREL32"},
{40, "R_MIPS_TLS_DTPMOD64"},
{41, "R_MIPS_TLS_DTPREL64"},
{42, "R_MIPS_TLS_GD"},
{43, "R_MIPS_TLS_LDM"},
{44, "R_MIPS_TLS_DTPREL_HI16"},
{45, "R_MIPS_TLS_DTPREL_LO16"},
{46, "R_MIPS_TLS_GOTTPREL"},
{47, "R_MIPS_TLS_TPREL32"},
{48, "R_MIPS_TLS_TPREL64"},
{49, "R_MIPS_TLS_TPREL_HI16"},
{50, "R_MIPS_TLS_TPREL_LO16"},
{248, "R_MIPS_PC32"},
}
func (i R_MIPS) String() string { return stringName(uint32(i), rmipsStrings, false) }
func (i R_MIPS) GoString() string { return stringName(uint32(i), rmipsStrings, true) }
// Relocation types for LoongArch.
type R_LARCH int
const (
R_LARCH_NONE R_LARCH = 0
R_LARCH_32 R_LARCH = 1
R_LARCH_64 R_LARCH = 2
R_LARCH_RELATIVE R_LARCH = 3
R_LARCH_COPY R_LARCH = 4
R_LARCH_JUMP_SLOT R_LARCH = 5
R_LARCH_TLS_DTPMOD32 R_LARCH = 6
R_LARCH_TLS_DTPMOD64 R_LARCH = 7
R_LARCH_TLS_DTPREL32 R_LARCH = 8
R_LARCH_TLS_DTPREL64 R_LARCH = 9
R_LARCH_TLS_TPREL32 R_LARCH = 10
R_LARCH_TLS_TPREL64 R_LARCH = 11
R_LARCH_IRELATIVE R_LARCH = 12
R_LARCH_MARK_LA R_LARCH = 20
R_LARCH_MARK_PCREL R_LARCH = 21
R_LARCH_SOP_PUSH_PCREL R_LARCH = 22
R_LARCH_SOP_PUSH_ABSOLUTE R_LARCH = 23
R_LARCH_SOP_PUSH_DUP R_LARCH = 24
R_LARCH_SOP_PUSH_GPREL R_LARCH = 25
R_LARCH_SOP_PUSH_TLS_TPREL R_LARCH = 26
R_LARCH_SOP_PUSH_TLS_GOT R_LARCH = 27
R_LARCH_SOP_PUSH_TLS_GD R_LARCH = 28
R_LARCH_SOP_PUSH_PLT_PCREL R_LARCH = 29
R_LARCH_SOP_ASSERT R_LARCH = 30
R_LARCH_SOP_NOT R_LARCH = 31
R_LARCH_SOP_SUB R_LARCH = 32
R_LARCH_SOP_SL R_LARCH = 33
R_LARCH_SOP_SR R_LARCH = 34
R_LARCH_SOP_ADD R_LARCH = 35
R_LARCH_SOP_AND R_LARCH = 36
R_LARCH_SOP_IF_ELSE R_LARCH = 37
R_LARCH_SOP_POP_32_S_10_5 R_LARCH = 38
R_LARCH_SOP_POP_32_U_10_12 R_LARCH = 39
R_LARCH_SOP_POP_32_S_10_12 R_LARCH = 40
R_LARCH_SOP_POP_32_S_10_16 R_LARCH = 41
R_LARCH_SOP_POP_32_S_10_16_S2 R_LARCH = 42
R_LARCH_SOP_POP_32_S_5_20 R_LARCH = 43
R_LARCH_SOP_POP_32_S_0_5_10_16_S2 R_LARCH = 44
R_LARCH_SOP_POP_32_S_0_10_10_16_S2 R_LARCH = 45
R_LARCH_SOP_POP_32_U R_LARCH = 46
R_LARCH_ADD8 R_LARCH = 47
R_LARCH_ADD16 R_LARCH = 48
R_LARCH_ADD24 R_LARCH = 49
R_LARCH_ADD32 R_LARCH = 50
R_LARCH_ADD64 R_LARCH = 51
R_LARCH_SUB8 R_LARCH = 52
R_LARCH_SUB16 R_LARCH = 53
R_LARCH_SUB24 R_LARCH = 54
R_LARCH_SUB32 R_LARCH = 55
R_LARCH_SUB64 R_LARCH = 56
R_LARCH_GNU_VTINHERIT R_LARCH = 57
R_LARCH_GNU_VTENTRY R_LARCH = 58
R_LARCH_B16 R_LARCH = 64
R_LARCH_B21 R_LARCH = 65
R_LARCH_B26 R_LARCH = 66
R_LARCH_ABS_HI20 R_LARCH = 67
R_LARCH_ABS_LO12 R_LARCH = 68
R_LARCH_ABS64_LO20 R_LARCH = 69
R_LARCH_ABS64_HI12 R_LARCH = 70
R_LARCH_PCALA_HI20 R_LARCH = 71
R_LARCH_PCALA_LO12 R_LARCH = 72
R_LARCH_PCALA64_LO20 R_LARCH = 73
R_LARCH_PCALA64_HI12 R_LARCH = 74
R_LARCH_GOT_PC_HI20 R_LARCH = 75
R_LARCH_GOT_PC_LO12 R_LARCH = 76
R_LARCH_GOT64_PC_LO20 R_LARCH = 77
R_LARCH_GOT64_PC_HI12 R_LARCH = 78
R_LARCH_GOT_HI20 R_LARCH = 79
R_LARCH_GOT_LO12 R_LARCH = 80
R_LARCH_GOT64_LO20 R_LARCH = 81
R_LARCH_GOT64_HI12 R_LARCH = 82
R_LARCH_TLS_LE_HI20 R_LARCH = 83
R_LARCH_TLS_LE_LO12 R_LARCH = 84
R_LARCH_TLS_LE64_LO20 R_LARCH = 85
R_LARCH_TLS_LE64_HI12 R_LARCH = 86
R_LARCH_TLS_IE_PC_HI20 R_LARCH = 87
R_LARCH_TLS_IE_PC_LO12 R_LARCH = 88
R_LARCH_TLS_IE64_PC_LO20 R_LARCH = 89
R_LARCH_TLS_IE64_PC_HI12 R_LARCH = 90
R_LARCH_TLS_IE_HI20 R_LARCH = 91
R_LARCH_TLS_IE_LO12 R_LARCH = 92
R_LARCH_TLS_IE64_LO20 R_LARCH = 93
R_LARCH_TLS_IE64_HI12 R_LARCH = 94
R_LARCH_TLS_LD_PC_HI20 R_LARCH = 95
R_LARCH_TLS_LD_HI20 R_LARCH = 96
R_LARCH_TLS_GD_PC_HI20 R_LARCH = 97
R_LARCH_TLS_GD_HI20 R_LARCH = 98
R_LARCH_32_PCREL R_LARCH = 99
R_LARCH_RELAX R_LARCH = 100
R_LARCH_DELETE R_LARCH = 101
R_LARCH_ALIGN R_LARCH = 102
R_LARCH_PCREL20_S2 R_LARCH = 103
R_LARCH_CFA R_LARCH = 104
R_LARCH_ADD6 R_LARCH = 105
R_LARCH_SUB6 R_LARCH = 106
R_LARCH_ADD_ULEB128 R_LARCH = 107
R_LARCH_SUB_ULEB128 R_LARCH = 108
R_LARCH_64_PCREL R_LARCH = 109
)
var rlarchStrings = []intName{
{0, "R_LARCH_NONE"},
{1, "R_LARCH_32"},
{2, "R_LARCH_64"},
{3, "R_LARCH_RELATIVE"},
{4, "R_LARCH_COPY"},
{5, "R_LARCH_JUMP_SLOT"},
{6, "R_LARCH_TLS_DTPMOD32"},
{7, "R_LARCH_TLS_DTPMOD64"},
{8, "R_LARCH_TLS_DTPREL32"},
{9, "R_LARCH_TLS_DTPREL64"},
{10, "R_LARCH_TLS_TPREL32"},
{11, "R_LARCH_TLS_TPREL64"},
{12, "R_LARCH_IRELATIVE"},
{20, "R_LARCH_MARK_LA"},
{21, "R_LARCH_MARK_PCREL"},
{22, "R_LARCH_SOP_PUSH_PCREL"},
{23, "R_LARCH_SOP_PUSH_ABSOLUTE"},
{24, "R_LARCH_SOP_PUSH_DUP"},
{25, "R_LARCH_SOP_PUSH_GPREL"},
{26, "R_LARCH_SOP_PUSH_TLS_TPREL"},
{27, "R_LARCH_SOP_PUSH_TLS_GOT"},
{28, "R_LARCH_SOP_PUSH_TLS_GD"},
{29, "R_LARCH_SOP_PUSH_PLT_PCREL"},
{30, "R_LARCH_SOP_ASSERT"},
{31, "R_LARCH_SOP_NOT"},
{32, "R_LARCH_SOP_SUB"},
{33, "R_LARCH_SOP_SL"},
{34, "R_LARCH_SOP_SR"},
{35, "R_LARCH_SOP_ADD"},
{36, "R_LARCH_SOP_AND"},
{37, "R_LARCH_SOP_IF_ELSE"},
{38, "R_LARCH_SOP_POP_32_S_10_5"},
{39, "R_LARCH_SOP_POP_32_U_10_12"},
{40, "R_LARCH_SOP_POP_32_S_10_12"},
{41, "R_LARCH_SOP_POP_32_S_10_16"},
{42, "R_LARCH_SOP_POP_32_S_10_16_S2"},
{43, "R_LARCH_SOP_POP_32_S_5_20"},
{44, "R_LARCH_SOP_POP_32_S_0_5_10_16_S2"},
{45, "R_LARCH_SOP_POP_32_S_0_10_10_16_S2"},
{46, "R_LARCH_SOP_POP_32_U"},
{47, "R_LARCH_ADD8"},
{48, "R_LARCH_ADD16"},
{49, "R_LARCH_ADD24"},
{50, "R_LARCH_ADD32"},
{51, "R_LARCH_ADD64"},
{52, "R_LARCH_SUB8"},
{53, "R_LARCH_SUB16"},
{54, "R_LARCH_SUB24"},
{55, "R_LARCH_SUB32"},
{56, "R_LARCH_SUB64"},
{57, "R_LARCH_GNU_VTINHERIT"},
{58, "R_LARCH_GNU_VTENTRY"},
{64, "R_LARCH_B16"},
{65, "R_LARCH_B21"},
{66, "R_LARCH_B26"},
{67, "R_LARCH_ABS_HI20"},
{68, "R_LARCH_ABS_LO12"},
{69, "R_LARCH_ABS64_LO20"},
{70, "R_LARCH_ABS64_HI12"},
{71, "R_LARCH_PCALA_HI20"},
{72, "R_LARCH_PCALA_LO12"},
{73, "R_LARCH_PCALA64_LO20"},
{74, "R_LARCH_PCALA64_HI12"},
{75, "R_LARCH_GOT_PC_HI20"},
{76, "R_LARCH_GOT_PC_LO12"},
{77, "R_LARCH_GOT64_PC_LO20"},
{78, "R_LARCH_GOT64_PC_HI12"},
{79, "R_LARCH_GOT_HI20"},
{80, "R_LARCH_GOT_LO12"},
{81, "R_LARCH_GOT64_LO20"},
{82, "R_LARCH_GOT64_HI12"},
{83, "R_LARCH_TLS_LE_HI20"},
{84, "R_LARCH_TLS_LE_LO12"},
{85, "R_LARCH_TLS_LE64_LO20"},
{86, "R_LARCH_TLS_LE64_HI12"},
{87, "R_LARCH_TLS_IE_PC_HI20"},
{88, "R_LARCH_TLS_IE_PC_LO12"},
{89, "R_LARCH_TLS_IE64_PC_LO20"},
{90, "R_LARCH_TLS_IE64_PC_HI12"},
{91, "R_LARCH_TLS_IE_HI20"},
{92, "R_LARCH_TLS_IE_LO12"},
{93, "R_LARCH_TLS_IE64_LO20"},
{94, "R_LARCH_TLS_IE64_HI12"},
{95, "R_LARCH_TLS_LD_PC_HI20"},
{96, "R_LARCH_TLS_LD_HI20"},
{97, "R_LARCH_TLS_GD_PC_HI20"},
{98, "R_LARCH_TLS_GD_HI20"},
{99, "R_LARCH_32_PCREL"},
{100, "R_LARCH_RELAX"},
{101, "R_LARCH_DELETE"},
{102, "R_LARCH_ALIGN"},
{103, "R_LARCH_PCREL20_S2"},
{104, "R_LARCH_CFA"},
{105, "R_LARCH_ADD6"},
{106, "R_LARCH_SUB6"},
{107, "R_LARCH_ADD_ULEB128"},
{108, "R_LARCH_SUB_ULEB128"},
{109, "R_LARCH_64_PCREL"},
}
func (i R_LARCH) String() string { return stringName(uint32(i), rlarchStrings, false) }
func (i R_LARCH) GoString() string { return stringName(uint32(i), rlarchStrings, true) }
// Relocation types for PowerPC.
//
// Values that are shared by both R_PPC and R_PPC64 are prefixed with
// R_POWERPC_ in the ELF standard. For the R_PPC type, the relevant
// shared relocations have been renamed with the prefix R_PPC_.
// The original name follows the value in a comment.
type R_PPC int
const (
R_PPC_NONE R_PPC = 0 // R_POWERPC_NONE
R_PPC_ADDR32 R_PPC = 1 // R_POWERPC_ADDR32
R_PPC_ADDR24 R_PPC = 2 // R_POWERPC_ADDR24
R_PPC_ADDR16 R_PPC = 3 // R_POWERPC_ADDR16
R_PPC_ADDR16_LO R_PPC = 4 // R_POWERPC_ADDR16_LO
R_PPC_ADDR16_HI R_PPC = 5 // R_POWERPC_ADDR16_HI
R_PPC_ADDR16_HA R_PPC = 6 // R_POWERPC_ADDR16_HA
R_PPC_ADDR14 R_PPC = 7 // R_POWERPC_ADDR14
R_PPC_ADDR14_BRTAKEN R_PPC = 8 // R_POWERPC_ADDR14_BRTAKEN
R_PPC_ADDR14_BRNTAKEN R_PPC = 9 // R_POWERPC_ADDR14_BRNTAKEN
R_PPC_REL24 R_PPC = 10 // R_POWERPC_REL24
R_PPC_REL14 R_PPC = 11 // R_POWERPC_REL14
R_PPC_REL14_BRTAKEN R_PPC = 12 // R_POWERPC_REL14_BRTAKEN
R_PPC_REL14_BRNTAKEN R_PPC = 13 // R_POWERPC_REL14_BRNTAKEN
R_PPC_GOT16 R_PPC = 14 // R_POWERPC_GOT16
R_PPC_GOT16_LO R_PPC = 15 // R_POWERPC_GOT16_LO
R_PPC_GOT16_HI R_PPC = 16 // R_POWERPC_GOT16_HI
R_PPC_GOT16_HA R_PPC = 17 // R_POWERPC_GOT16_HA
R_PPC_PLTREL24 R_PPC = 18
R_PPC_COPY R_PPC = 19 // R_POWERPC_COPY
R_PPC_GLOB_DAT R_PPC = 20 // R_POWERPC_GLOB_DAT
R_PPC_JMP_SLOT R_PPC = 21 // R_POWERPC_JMP_SLOT
R_PPC_RELATIVE R_PPC = 22 // R_POWERPC_RELATIVE
R_PPC_LOCAL24PC R_PPC = 23
R_PPC_UADDR32 R_PPC = 24 // R_POWERPC_UADDR32
R_PPC_UADDR16 R_PPC = 25 // R_POWERPC_UADDR16
R_PPC_REL32 R_PPC = 26 // R_POWERPC_REL32
R_PPC_PLT32 R_PPC = 27 // R_POWERPC_PLT32
R_PPC_PLTREL32 R_PPC = 28 // R_POWERPC_PLTREL32
R_PPC_PLT16_LO R_PPC = 29 // R_POWERPC_PLT16_LO
R_PPC_PLT16_HI R_PPC = 30 // R_POWERPC_PLT16_HI
R_PPC_PLT16_HA R_PPC = 31 // R_POWERPC_PLT16_HA
R_PPC_SDAREL16 R_PPC = 32
R_PPC_SECTOFF R_PPC = 33 // R_POWERPC_SECTOFF
R_PPC_SECTOFF_LO R_PPC = 34 // R_POWERPC_SECTOFF_LO
R_PPC_SECTOFF_HI R_PPC = 35 // R_POWERPC_SECTOFF_HI
R_PPC_SECTOFF_HA R_PPC = 36 // R_POWERPC_SECTOFF_HA
R_PPC_TLS R_PPC = 67 // R_POWERPC_TLS
R_PPC_DTPMOD32 R_PPC = 68 // R_POWERPC_DTPMOD32
R_PPC_TPREL16 R_PPC = 69 // R_POWERPC_TPREL16
R_PPC_TPREL16_LO R_PPC = 70 // R_POWERPC_TPREL16_LO
R_PPC_TPREL16_HI R_PPC = 71 // R_POWERPC_TPREL16_HI
R_PPC_TPREL16_HA R_PPC = 72 // R_POWERPC_TPREL16_HA
R_PPC_TPREL32 R_PPC = 73 // R_POWERPC_TPREL32
R_PPC_DTPREL16 R_PPC = 74 // R_POWERPC_DTPREL16
R_PPC_DTPREL16_LO R_PPC = 75 // R_POWERPC_DTPREL16_LO
R_PPC_DTPREL16_HI R_PPC = 76 // R_POWERPC_DTPREL16_HI
R_PPC_DTPREL16_HA R_PPC = 77 // R_POWERPC_DTPREL16_HA
R_PPC_DTPREL32 R_PPC = 78 // R_POWERPC_DTPREL32
R_PPC_GOT_TLSGD16 R_PPC = 79 // R_POWERPC_GOT_TLSGD16
R_PPC_GOT_TLSGD16_LO R_PPC = 80 // R_POWERPC_GOT_TLSGD16_LO
R_PPC_GOT_TLSGD16_HI R_PPC = 81 // R_POWERPC_GOT_TLSGD16_HI
R_PPC_GOT_TLSGD16_HA R_PPC = 82 // R_POWERPC_GOT_TLSGD16_HA
R_PPC_GOT_TLSLD16 R_PPC = 83 // R_POWERPC_GOT_TLSLD16
R_PPC_GOT_TLSLD16_LO R_PPC = 84 // R_POWERPC_GOT_TLSLD16_LO
R_PPC_GOT_TLSLD16_HI R_PPC = 85 // R_POWERPC_GOT_TLSLD16_HI
R_PPC_GOT_TLSLD16_HA R_PPC = 86 // R_POWERPC_GOT_TLSLD16_HA
R_PPC_GOT_TPREL16 R_PPC = 87 // R_POWERPC_GOT_TPREL16
R_PPC_GOT_TPREL16_LO R_PPC = 88 // R_POWERPC_GOT_TPREL16_LO
R_PPC_GOT_TPREL16_HI R_PPC = 89 // R_POWERPC_GOT_TPREL16_HI
R_PPC_GOT_TPREL16_HA R_PPC = 90 // R_POWERPC_GOT_TPREL16_HA
R_PPC_EMB_NADDR32 R_PPC = 101
R_PPC_EMB_NADDR16 R_PPC = 102
R_PPC_EMB_NADDR16_LO R_PPC = 103
R_PPC_EMB_NADDR16_HI R_PPC = 104
R_PPC_EMB_NADDR16_HA R_PPC = 105
R_PPC_EMB_SDAI16 R_PPC = 106
R_PPC_EMB_SDA2I16 R_PPC = 107
R_PPC_EMB_SDA2REL R_PPC = 108
R_PPC_EMB_SDA21 R_PPC = 109
R_PPC_EMB_MRKREF R_PPC = 110
R_PPC_EMB_RELSEC16 R_PPC = 111
R_PPC_EMB_RELST_LO R_PPC = 112
R_PPC_EMB_RELST_HI R_PPC = 113
R_PPC_EMB_RELST_HA R_PPC = 114
R_PPC_EMB_BIT_FLD R_PPC = 115
R_PPC_EMB_RELSDA R_PPC = 116
)
var rppcStrings = []intName{
{0, "R_PPC_NONE"},
{1, "R_PPC_ADDR32"},
{2, "R_PPC_ADDR24"},
{3, "R_PPC_ADDR16"},
{4, "R_PPC_ADDR16_LO"},
{5, "R_PPC_ADDR16_HI"},
{6, "R_PPC_ADDR16_HA"},
{7, "R_PPC_ADDR14"},
{8, "R_PPC_ADDR14_BRTAKEN"},
{9, "R_PPC_ADDR14_BRNTAKEN"},
{10, "R_PPC_REL24"},
{11, "R_PPC_REL14"},
{12, "R_PPC_REL14_BRTAKEN"},
{13, "R_PPC_REL14_BRNTAKEN"},
{14, "R_PPC_GOT16"},
{15, "R_PPC_GOT16_LO"},
{16, "R_PPC_GOT16_HI"},
{17, "R_PPC_GOT16_HA"},
{18, "R_PPC_PLTREL24"},
{19, "R_PPC_COPY"},
{20, "R_PPC_GLOB_DAT"},
{21, "R_PPC_JMP_SLOT"},
{22, "R_PPC_RELATIVE"},
{23, "R_PPC_LOCAL24PC"},
{24, "R_PPC_UADDR32"},
{25, "R_PPC_UADDR16"},
{26, "R_PPC_REL32"},
{27, "R_PPC_PLT32"},
{28, "R_PPC_PLTREL32"},
{29, "R_PPC_PLT16_LO"},
{30, "R_PPC_PLT16_HI"},
{31, "R_PPC_PLT16_HA"},
{32, "R_PPC_SDAREL16"},
{33, "R_PPC_SECTOFF"},
{34, "R_PPC_SECTOFF_LO"},
{35, "R_PPC_SECTOFF_HI"},
{36, "R_PPC_SECTOFF_HA"},
{67, "R_PPC_TLS"},
{68, "R_PPC_DTPMOD32"},
{69, "R_PPC_TPREL16"},
{70, "R_PPC_TPREL16_LO"},
{71, "R_PPC_TPREL16_HI"},
{72, "R_PPC_TPREL16_HA"},
{73, "R_PPC_TPREL32"},
{74, "R_PPC_DTPREL16"},
{75, "R_PPC_DTPREL16_LO"},
{76, "R_PPC_DTPREL16_HI"},
{77, "R_PPC_DTPREL16_HA"},
{78, "R_PPC_DTPREL32"},
{79, "R_PPC_GOT_TLSGD16"},
{80, "R_PPC_GOT_TLSGD16_LO"},
{81, "R_PPC_GOT_TLSGD16_HI"},
{82, "R_PPC_GOT_TLSGD16_HA"},
{83, "R_PPC_GOT_TLSLD16"},
{84, "R_PPC_GOT_TLSLD16_LO"},
{85, "R_PPC_GOT_TLSLD16_HI"},
{86, "R_PPC_GOT_TLSLD16_HA"},
{87, "R_PPC_GOT_TPREL16"},
{88, "R_PPC_GOT_TPREL16_LO"},
{89, "R_PPC_GOT_TPREL16_HI"},
{90, "R_PPC_GOT_TPREL16_HA"},
{101, "R_PPC_EMB_NADDR32"},
{102, "R_PPC_EMB_NADDR16"},
{103, "R_PPC_EMB_NADDR16_LO"},
{104, "R_PPC_EMB_NADDR16_HI"},
{105, "R_PPC_EMB_NADDR16_HA"},
{106, "R_PPC_EMB_SDAI16"},
{107, "R_PPC_EMB_SDA2I16"},
{108, "R_PPC_EMB_SDA2REL"},
{109, "R_PPC_EMB_SDA21"},
{110, "R_PPC_EMB_MRKREF"},
{111, "R_PPC_EMB_RELSEC16"},
{112, "R_PPC_EMB_RELST_LO"},
{113, "R_PPC_EMB_RELST_HI"},
{114, "R_PPC_EMB_RELST_HA"},
{115, "R_PPC_EMB_BIT_FLD"},
{116, "R_PPC_EMB_RELSDA"},
}
func (i R_PPC) String() string { return stringName(uint32(i), rppcStrings, false) }
func (i R_PPC) GoString() string { return stringName(uint32(i), rppcStrings, true) }
// Relocation types for 64-bit PowerPC or Power Architecture processors.
//
// Values that are shared by both R_PPC and R_PPC64 are prefixed with
// R_POWERPC_ in the ELF standard. For the R_PPC64 type, the relevant
// shared relocations have been renamed with the prefix R_PPC64_.
// The original name follows the value in a comment.
type R_PPC64 int
const (
R_PPC64_NONE R_PPC64 = 0 // R_POWERPC_NONE
R_PPC64_ADDR32 R_PPC64 = 1 // R_POWERPC_ADDR32
R_PPC64_ADDR24 R_PPC64 = 2 // R_POWERPC_ADDR24
R_PPC64_ADDR16 R_PPC64 = 3 // R_POWERPC_ADDR16
R_PPC64_ADDR16_LO R_PPC64 = 4 // R_POWERPC_ADDR16_LO
R_PPC64_ADDR16_HI R_PPC64 = 5 // R_POWERPC_ADDR16_HI
R_PPC64_ADDR16_HA R_PPC64 = 6 // R_POWERPC_ADDR16_HA
R_PPC64_ADDR14 R_PPC64 = 7 // R_POWERPC_ADDR14
R_PPC64_ADDR14_BRTAKEN R_PPC64 = 8 // R_POWERPC_ADDR14_BRTAKEN
R_PPC64_ADDR14_BRNTAKEN R_PPC64 = 9 // R_POWERPC_ADDR14_BRNTAKEN
R_PPC64_REL24 R_PPC64 = 10 // R_POWERPC_REL24
R_PPC64_REL14 R_PPC64 = 11 // R_POWERPC_REL14
R_PPC64_REL14_BRTAKEN R_PPC64 = 12 // R_POWERPC_REL14_BRTAKEN
R_PPC64_REL14_BRNTAKEN R_PPC64 = 13 // R_POWERPC_REL14_BRNTAKEN
R_PPC64_GOT16 R_PPC64 = 14 // R_POWERPC_GOT16
R_PPC64_GOT16_LO R_PPC64 = 15 // R_POWERPC_GOT16_LO
R_PPC64_GOT16_HI R_PPC64 = 16 // R_POWERPC_GOT16_HI
R_PPC64_GOT16_HA R_PPC64 = 17 // R_POWERPC_GOT16_HA
R_PPC64_COPY R_PPC64 = 19 // R_POWERPC_COPY
R_PPC64_GLOB_DAT R_PPC64 = 20 // R_POWERPC_GLOB_DAT
R_PPC64_JMP_SLOT R_PPC64 = 21 // R_POWERPC_JMP_SLOT
R_PPC64_RELATIVE R_PPC64 = 22 // R_POWERPC_RELATIVE
R_PPC64_UADDR32 R_PPC64 = 24 // R_POWERPC_UADDR32
R_PPC64_UADDR16 R_PPC64 = 25 // R_POWERPC_UADDR16
R_PPC64_REL32 R_PPC64 = 26 // R_POWERPC_REL32
R_PPC64_PLT32 R_PPC64 = 27 // R_POWERPC_PLT32
R_PPC64_PLTREL32 R_PPC64 = 28 // R_POWERPC_PLTREL32
R_PPC64_PLT16_LO R_PPC64 = 29 // R_POWERPC_PLT16_LO
R_PPC64_PLT16_HI R_PPC64 = 30 // R_POWERPC_PLT16_HI
R_PPC64_PLT16_HA R_PPC64 = 31 // R_POWERPC_PLT16_HA
R_PPC64_SECTOFF R_PPC64 = 33 // R_POWERPC_SECTOFF
R_PPC64_SECTOFF_LO R_PPC64 = 34 // R_POWERPC_SECTOFF_LO
R_PPC64_SECTOFF_HI R_PPC64 = 35 // R_POWERPC_SECTOFF_HI
R_PPC64_SECTOFF_HA R_PPC64 = 36 // R_POWERPC_SECTOFF_HA
R_PPC64_REL30 R_PPC64 = 37 // R_POWERPC_ADDR30
R_PPC64_ADDR64 R_PPC64 = 38
R_PPC64_ADDR16_HIGHER R_PPC64 = 39
R_PPC64_ADDR16_HIGHERA R_PPC64 = 40
R_PPC64_ADDR16_HIGHEST R_PPC64 = 41
R_PPC64_ADDR16_HIGHESTA R_PPC64 = 42
R_PPC64_UADDR64 R_PPC64 = 43
R_PPC64_REL64 R_PPC64 = 44
R_PPC64_PLT64 R_PPC64 = 45
R_PPC64_PLTREL64 R_PPC64 = 46
R_PPC64_TOC16 R_PPC64 = 47
R_PPC64_TOC16_LO R_PPC64 = 48
R_PPC64_TOC16_HI R_PPC64 = 49
R_PPC64_TOC16_HA R_PPC64 = 50
R_PPC64_TOC R_PPC64 = 51
R_PPC64_PLTGOT16 R_PPC64 = 52
R_PPC64_PLTGOT16_LO R_PPC64 = 53
R_PPC64_PLTGOT16_HI R_PPC64 = 54
R_PPC64_PLTGOT16_HA R_PPC64 = 55
R_PPC64_ADDR16_DS R_PPC64 = 56
R_PPC64_ADDR16_LO_DS R_PPC64 = 57
R_PPC64_GOT16_DS R_PPC64 = 58
R_PPC64_GOT16_LO_DS R_PPC64 = 59
R_PPC64_PLT16_LO_DS R_PPC64 = 60
R_PPC64_SECTOFF_DS R_PPC64 = 61
R_PPC64_SECTOFF_LO_DS R_PPC64 = 62
R_PPC64_TOC16_DS R_PPC64 = 63
R_PPC64_TOC16_LO_DS R_PPC64 = 64
R_PPC64_PLTGOT16_DS R_PPC64 = 65
R_PPC64_PLTGOT_LO_DS R_PPC64 = 66
R_PPC64_TLS R_PPC64 = 67 // R_POWERPC_TLS
R_PPC64_DTPMOD64 R_PPC64 = 68 // R_POWERPC_DTPMOD64
R_PPC64_TPREL16 R_PPC64 = 69 // R_POWERPC_TPREL16
R_PPC64_TPREL16_LO R_PPC64 = 70 // R_POWERPC_TPREL16_LO
R_PPC64_TPREL16_HI R_PPC64 = 71 // R_POWERPC_TPREL16_HI
R_PPC64_TPREL16_HA R_PPC64 = 72 // R_POWERPC_TPREL16_HA
R_PPC64_TPREL64 R_PPC64 = 73 // R_POWERPC_TPREL64
R_PPC64_DTPREL16 R_PPC64 = 74 // R_POWERPC_DTPREL16
R_PPC64_DTPREL16_LO R_PPC64 = 75 // R_POWERPC_DTPREL16_LO
R_PPC64_DTPREL16_HI R_PPC64 = 76 // R_POWERPC_DTPREL16_HI
R_PPC64_DTPREL16_HA R_PPC64 = 77 // R_POWERPC_DTPREL16_HA
R_PPC64_DTPREL64 R_PPC64 = 78 // R_POWERPC_DTPREL64
R_PPC64_GOT_TLSGD16 R_PPC64 = 79 // R_POWERPC_GOT_TLSGD16
R_PPC64_GOT_TLSGD16_LO R_PPC64 = 80 // R_POWERPC_GOT_TLSGD16_LO
R_PPC64_GOT_TLSGD16_HI R_PPC64 = 81 // R_POWERPC_GOT_TLSGD16_HI
R_PPC64_GOT_TLSGD16_HA R_PPC64 = 82 // R_POWERPC_GOT_TLSGD16_HA
R_PPC64_GOT_TLSLD16 R_PPC64 = 83 // R_POWERPC_GOT_TLSLD16
R_PPC64_GOT_TLSLD16_LO R_PPC64 = 84 // R_POWERPC_GOT_TLSLD16_LO
R_PPC64_GOT_TLSLD16_HI R_PPC64 = 85 // R_POWERPC_GOT_TLSLD16_HI
R_PPC64_GOT_TLSLD16_HA R_PPC64 = 86 // R_POWERPC_GOT_TLSLD16_HA
R_PPC64_GOT_TPREL16_DS R_PPC64 = 87 // R_POWERPC_GOT_TPREL16_DS
R_PPC64_GOT_TPREL16_LO_DS R_PPC64 = 88 // R_POWERPC_GOT_TPREL16_LO_DS
R_PPC64_GOT_TPREL16_HI R_PPC64 = 89 // R_POWERPC_GOT_TPREL16_HI
R_PPC64_GOT_TPREL16_HA R_PPC64 = 90 // R_POWERPC_GOT_TPREL16_HA
R_PPC64_GOT_DTPREL16_DS R_PPC64 = 91 // R_POWERPC_GOT_DTPREL16_DS
R_PPC64_GOT_DTPREL16_LO_DS R_PPC64 = 92 // R_POWERPC_GOT_DTPREL16_LO_DS
R_PPC64_GOT_DTPREL16_HI R_PPC64 = 93 // R_POWERPC_GOT_DTPREL16_HI
R_PPC64_GOT_DTPREL16_HA R_PPC64 = 94 // R_POWERPC_GOT_DTPREL16_HA
R_PPC64_TPREL16_DS R_PPC64 = 95
R_PPC64_TPREL16_LO_DS R_PPC64 = 96
R_PPC64_TPREL16_HIGHER R_PPC64 = 97
R_PPC64_TPREL16_HIGHERA R_PPC64 = 98
R_PPC64_TPREL16_HIGHEST R_PPC64 = 99
R_PPC64_TPREL16_HIGHESTA R_PPC64 = 100
R_PPC64_DTPREL16_DS R_PPC64 = 101
R_PPC64_DTPREL16_LO_DS R_PPC64 = 102
R_PPC64_DTPREL16_HIGHER R_PPC64 = 103
R_PPC64_DTPREL16_HIGHERA R_PPC64 = 104
R_PPC64_DTPREL16_HIGHEST R_PPC64 = 105
R_PPC64_DTPREL16_HIGHESTA R_PPC64 = 106
R_PPC64_TLSGD R_PPC64 = 107
R_PPC64_TLSLD R_PPC64 = 108
R_PPC64_TOCSAVE R_PPC64 = 109
R_PPC64_ADDR16_HIGH R_PPC64 = 110
R_PPC64_ADDR16_HIGHA R_PPC64 = 111
R_PPC64_TPREL16_HIGH R_PPC64 = 112
R_PPC64_TPREL16_HIGHA R_PPC64 = 113
R_PPC64_DTPREL16_HIGH R_PPC64 = 114
R_PPC64_DTPREL16_HIGHA R_PPC64 = 115
R_PPC64_REL24_NOTOC R_PPC64 = 116
R_PPC64_ADDR64_LOCAL R_PPC64 = 117
R_PPC64_ENTRY R_PPC64 = 118
R_PPC64_PLTSEQ R_PPC64 = 119
R_PPC64_PLTCALL R_PPC64 = 120
R_PPC64_PLTSEQ_NOTOC R_PPC64 = 121
R_PPC64_PLTCALL_NOTOC R_PPC64 = 122
R_PPC64_PCREL_OPT R_PPC64 = 123
R_PPC64_REL24_P9NOTOC R_PPC64 = 124
R_PPC64_D34 R_PPC64 = 128
R_PPC64_D34_LO R_PPC64 = 129
R_PPC64_D34_HI30 R_PPC64 = 130
R_PPC64_D34_HA30 R_PPC64 = 131
R_PPC64_PCREL34 R_PPC64 = 132
R_PPC64_GOT_PCREL34 R_PPC64 = 133
R_PPC64_PLT_PCREL34 R_PPC64 = 134
R_PPC64_PLT_PCREL34_NOTOC R_PPC64 = 135
R_PPC64_ADDR16_HIGHER34 R_PPC64 = 136
R_PPC64_ADDR16_HIGHERA34 R_PPC64 = 137
R_PPC64_ADDR16_HIGHEST34 R_PPC64 = 138
R_PPC64_ADDR16_HIGHESTA34 R_PPC64 = 139
R_PPC64_REL16_HIGHER34 R_PPC64 = 140
R_PPC64_REL16_HIGHERA34 R_PPC64 = 141
R_PPC64_REL16_HIGHEST34 R_PPC64 = 142
R_PPC64_REL16_HIGHESTA34 R_PPC64 = 143
R_PPC64_D28 R_PPC64 = 144
R_PPC64_PCREL28 R_PPC64 = 145
R_PPC64_TPREL34 R_PPC64 = 146
R_PPC64_DTPREL34 R_PPC64 = 147
R_PPC64_GOT_TLSGD_PCREL34 R_PPC64 = 148
R_PPC64_GOT_TLSLD_PCREL34 R_PPC64 = 149
R_PPC64_GOT_TPREL_PCREL34 R_PPC64 = 150
R_PPC64_GOT_DTPREL_PCREL34 R_PPC64 = 151
R_PPC64_REL16_HIGH R_PPC64 = 240
R_PPC64_REL16_HIGHA R_PPC64 = 241
R_PPC64_REL16_HIGHER R_PPC64 = 242
R_PPC64_REL16_HIGHERA R_PPC64 = 243
R_PPC64_REL16_HIGHEST R_PPC64 = 244
R_PPC64_REL16_HIGHESTA R_PPC64 = 245
R_PPC64_REL16DX_HA R_PPC64 = 246 // R_POWERPC_REL16DX_HA
R_PPC64_JMP_IREL R_PPC64 = 247
R_PPC64_IRELATIVE R_PPC64 = 248 // R_POWERPC_IRELATIVE
R_PPC64_REL16 R_PPC64 = 249 // R_POWERPC_REL16
R_PPC64_REL16_LO R_PPC64 = 250 // R_POWERPC_REL16_LO
R_PPC64_REL16_HI R_PPC64 = 251 // R_POWERPC_REL16_HI
R_PPC64_REL16_HA R_PPC64 = 252 // R_POWERPC_REL16_HA
R_PPC64_GNU_VTINHERIT R_PPC64 = 253
R_PPC64_GNU_VTENTRY R_PPC64 = 254
)
var rppc64Strings = []intName{
{0, "R_PPC64_NONE"},
{1, "R_PPC64_ADDR32"},
{2, "R_PPC64_ADDR24"},
{3, "R_PPC64_ADDR16"},
{4, "R_PPC64_ADDR16_LO"},
{5, "R_PPC64_ADDR16_HI"},
{6, "R_PPC64_ADDR16_HA"},
{7, "R_PPC64_ADDR14"},
{8, "R_PPC64_ADDR14_BRTAKEN"},
{9, "R_PPC64_ADDR14_BRNTAKEN"},
{10, "R_PPC64_REL24"},
{11, "R_PPC64_REL14"},
{12, "R_PPC64_REL14_BRTAKEN"},
{13, "R_PPC64_REL14_BRNTAKEN"},
{14, "R_PPC64_GOT16"},
{15, "R_PPC64_GOT16_LO"},
{16, "R_PPC64_GOT16_HI"},
{17, "R_PPC64_GOT16_HA"},
{19, "R_PPC64_COPY"},
{20, "R_PPC64_GLOB_DAT"},
{21, "R_PPC64_JMP_SLOT"},
{22, "R_PPC64_RELATIVE"},
{24, "R_PPC64_UADDR32"},
{25, "R_PPC64_UADDR16"},
{26, "R_PPC64_REL32"},
{27, "R_PPC64_PLT32"},
{28, "R_PPC64_PLTREL32"},
{29, "R_PPC64_PLT16_LO"},
{30, "R_PPC64_PLT16_HI"},
{31, "R_PPC64_PLT16_HA"},
{33, "R_PPC64_SECTOFF"},
{34, "R_PPC64_SECTOFF_LO"},
{35, "R_PPC64_SECTOFF_HI"},
{36, "R_PPC64_SECTOFF_HA"},
{37, "R_PPC64_REL30"},
{38, "R_PPC64_ADDR64"},
{39, "R_PPC64_ADDR16_HIGHER"},
{40, "R_PPC64_ADDR16_HIGHERA"},
{41, "R_PPC64_ADDR16_HIGHEST"},
{42, "R_PPC64_ADDR16_HIGHESTA"},
{43, "R_PPC64_UADDR64"},
{44, "R_PPC64_REL64"},
{45, "R_PPC64_PLT64"},
{46, "R_PPC64_PLTREL64"},
{47, "R_PPC64_TOC16"},
{48, "R_PPC64_TOC16_LO"},
{49, "R_PPC64_TOC16_HI"},
{50, "R_PPC64_TOC16_HA"},
{51, "R_PPC64_TOC"},
{52, "R_PPC64_PLTGOT16"},
{53, "R_PPC64_PLTGOT16_LO"},
{54, "R_PPC64_PLTGOT16_HI"},
{55, "R_PPC64_PLTGOT16_HA"},
{56, "R_PPC64_ADDR16_DS"},
{57, "R_PPC64_ADDR16_LO_DS"},
{58, "R_PPC64_GOT16_DS"},
{59, "R_PPC64_GOT16_LO_DS"},
{60, "R_PPC64_PLT16_LO_DS"},
{61, "R_PPC64_SECTOFF_DS"},
{62, "R_PPC64_SECTOFF_LO_DS"},
{63, "R_PPC64_TOC16_DS"},
{64, "R_PPC64_TOC16_LO_DS"},
{65, "R_PPC64_PLTGOT16_DS"},
{66, "R_PPC64_PLTGOT_LO_DS"},
{67, "R_PPC64_TLS"},
{68, "R_PPC64_DTPMOD64"},
{69, "R_PPC64_TPREL16"},
{70, "R_PPC64_TPREL16_LO"},
{71, "R_PPC64_TPREL16_HI"},
{72, "R_PPC64_TPREL16_HA"},
{73, "R_PPC64_TPREL64"},
{74, "R_PPC64_DTPREL16"},
{75, "R_PPC64_DTPREL16_LO"},
{76, "R_PPC64_DTPREL16_HI"},
{77, "R_PPC64_DTPREL16_HA"},
{78, "R_PPC64_DTPREL64"},
{79, "R_PPC64_GOT_TLSGD16"},
{80, "R_PPC64_GOT_TLSGD16_LO"},
{81, "R_PPC64_GOT_TLSGD16_HI"},
{82, "R_PPC64_GOT_TLSGD16_HA"},
{83, "R_PPC64_GOT_TLSLD16"},
{84, "R_PPC64_GOT_TLSLD16_LO"},
{85, "R_PPC64_GOT_TLSLD16_HI"},
{86, "R_PPC64_GOT_TLSLD16_HA"},
{87, "R_PPC64_GOT_TPREL16_DS"},
{88, "R_PPC64_GOT_TPREL16_LO_DS"},
{89, "R_PPC64_GOT_TPREL16_HI"},
{90, "R_PPC64_GOT_TPREL16_HA"},
{91, "R_PPC64_GOT_DTPREL16_DS"},
{92, "R_PPC64_GOT_DTPREL16_LO_DS"},
{93, "R_PPC64_GOT_DTPREL16_HI"},
{94, "R_PPC64_GOT_DTPREL16_HA"},
{95, "R_PPC64_TPREL16_DS"},
{96, "R_PPC64_TPREL16_LO_DS"},
{97, "R_PPC64_TPREL16_HIGHER"},
{98, "R_PPC64_TPREL16_HIGHERA"},
{99, "R_PPC64_TPREL16_HIGHEST"},
{100, "R_PPC64_TPREL16_HIGHESTA"},
{101, "R_PPC64_DTPREL16_DS"},
{102, "R_PPC64_DTPREL16_LO_DS"},
{103, "R_PPC64_DTPREL16_HIGHER"},
{104, "R_PPC64_DTPREL16_HIGHERA"},
{105, "R_PPC64_DTPREL16_HIGHEST"},
{106, "R_PPC64_DTPREL16_HIGHESTA"},
{107, "R_PPC64_TLSGD"},
{108, "R_PPC64_TLSLD"},
{109, "R_PPC64_TOCSAVE"},
{110, "R_PPC64_ADDR16_HIGH"},
{111, "R_PPC64_ADDR16_HIGHA"},
{112, "R_PPC64_TPREL16_HIGH"},
{113, "R_PPC64_TPREL16_HIGHA"},
{114, "R_PPC64_DTPREL16_HIGH"},
{115, "R_PPC64_DTPREL16_HIGHA"},
{116, "R_PPC64_REL24_NOTOC"},
{117, "R_PPC64_ADDR64_LOCAL"},
{118, "R_PPC64_ENTRY"},
{119, "R_PPC64_PLTSEQ"},
{120, "R_PPC64_PLTCALL"},
{121, "R_PPC64_PLTSEQ_NOTOC"},
{122, "R_PPC64_PLTCALL_NOTOC"},
{123, "R_PPC64_PCREL_OPT"},
{124, "R_PPC64_REL24_P9NOTOC"},
{128, "R_PPC64_D34"},
{129, "R_PPC64_D34_LO"},
{130, "R_PPC64_D34_HI30"},
{131, "R_PPC64_D34_HA30"},
{132, "R_PPC64_PCREL34"},
{133, "R_PPC64_GOT_PCREL34"},
{134, "R_PPC64_PLT_PCREL34"},
{135, "R_PPC64_PLT_PCREL34_NOTOC"},
{136, "R_PPC64_ADDR16_HIGHER34"},
{137, "R_PPC64_ADDR16_HIGHERA34"},
{138, "R_PPC64_ADDR16_HIGHEST34"},
{139, "R_PPC64_ADDR16_HIGHESTA34"},
{140, "R_PPC64_REL16_HIGHER34"},
{141, "R_PPC64_REL16_HIGHERA34"},
{142, "R_PPC64_REL16_HIGHEST34"},
{143, "R_PPC64_REL16_HIGHESTA34"},
{144, "R_PPC64_D28"},
{145, "R_PPC64_PCREL28"},
{146, "R_PPC64_TPREL34"},
{147, "R_PPC64_DTPREL34"},
{148, "R_PPC64_GOT_TLSGD_PCREL34"},
{149, "R_PPC64_GOT_TLSLD_PCREL34"},
{150, "R_PPC64_GOT_TPREL_PCREL34"},
{151, "R_PPC64_GOT_DTPREL_PCREL34"},
{240, "R_PPC64_REL16_HIGH"},
{241, "R_PPC64_REL16_HIGHA"},
{242, "R_PPC64_REL16_HIGHER"},
{243, "R_PPC64_REL16_HIGHERA"},
{244, "R_PPC64_REL16_HIGHEST"},
{245, "R_PPC64_REL16_HIGHESTA"},
{246, "R_PPC64_REL16DX_HA"},
{247, "R_PPC64_JMP_IREL"},
{248, "R_PPC64_IRELATIVE"},
{249, "R_PPC64_REL16"},
{250, "R_PPC64_REL16_LO"},
{251, "R_PPC64_REL16_HI"},
{252, "R_PPC64_REL16_HA"},
{253, "R_PPC64_GNU_VTINHERIT"},
{254, "R_PPC64_GNU_VTENTRY"},
}
func (i R_PPC64) String() string { return stringName(uint32(i), rppc64Strings, false) }
func (i R_PPC64) GoString() string { return stringName(uint32(i), rppc64Strings, true) }
// Relocation types for RISC-V processors.
type R_RISCV int
const (
R_RISCV_NONE R_RISCV = 0 /* No relocation. */
R_RISCV_32 R_RISCV = 1 /* Add 32 bit zero extended symbol value */
R_RISCV_64 R_RISCV = 2 /* Add 64 bit symbol value. */
R_RISCV_RELATIVE R_RISCV = 3 /* Add load address of shared object. */
R_RISCV_COPY R_RISCV = 4 /* Copy data from shared object. */
R_RISCV_JUMP_SLOT R_RISCV = 5 /* Set GOT entry to code address. */
R_RISCV_TLS_DTPMOD32 R_RISCV = 6 /* 32 bit ID of module containing symbol */
R_RISCV_TLS_DTPMOD64 R_RISCV = 7 /* ID of module containing symbol */
R_RISCV_TLS_DTPREL32 R_RISCV = 8 /* 32 bit relative offset in TLS block */
R_RISCV_TLS_DTPREL64 R_RISCV = 9 /* Relative offset in TLS block */
R_RISCV_TLS_TPREL32 R_RISCV = 10 /* 32 bit relative offset in static TLS block */
R_RISCV_TLS_TPREL64 R_RISCV = 11 /* Relative offset in static TLS block */
R_RISCV_BRANCH R_RISCV = 16 /* PC-relative branch */
R_RISCV_JAL R_RISCV = 17 /* PC-relative jump */
R_RISCV_CALL R_RISCV = 18 /* PC-relative call */
R_RISCV_CALL_PLT R_RISCV = 19 /* PC-relative call (PLT) */
R_RISCV_GOT_HI20 R_RISCV = 20 /* PC-relative GOT reference */
R_RISCV_TLS_GOT_HI20 R_RISCV = 21 /* PC-relative TLS IE GOT offset */
R_RISCV_TLS_GD_HI20 R_RISCV = 22 /* PC-relative TLS GD reference */
R_RISCV_PCREL_HI20 R_RISCV = 23 /* PC-relative reference */
R_RISCV_PCREL_LO12_I R_RISCV = 24 /* PC-relative reference */
R_RISCV_PCREL_LO12_S R_RISCV = 25 /* PC-relative reference */
R_RISCV_HI20 R_RISCV = 26 /* Absolute address */
R_RISCV_LO12_I R_RISCV = 27 /* Absolute address */
R_RISCV_LO12_S R_RISCV = 28 /* Absolute address */
R_RISCV_TPREL_HI20 R_RISCV = 29 /* TLS LE thread offset */
R_RISCV_TPREL_LO12_I R_RISCV = 30 /* TLS LE thread offset */
R_RISCV_TPREL_LO12_S R_RISCV = 31 /* TLS LE thread offset */
R_RISCV_TPREL_ADD R_RISCV = 32 /* TLS LE thread usage */
R_RISCV_ADD8 R_RISCV = 33 /* 8-bit label addition */
R_RISCV_ADD16 R_RISCV = 34 /* 16-bit label addition */
R_RISCV_ADD32 R_RISCV = 35 /* 32-bit label addition */
R_RISCV_ADD64 R_RISCV = 36 /* 64-bit label addition */
R_RISCV_SUB8 R_RISCV = 37 /* 8-bit label subtraction */
R_RISCV_SUB16 R_RISCV = 38 /* 16-bit label subtraction */
R_RISCV_SUB32 R_RISCV = 39 /* 32-bit label subtraction */
R_RISCV_SUB64 R_RISCV = 40 /* 64-bit label subtraction */
R_RISCV_GNU_VTINHERIT R_RISCV = 41 /* GNU C++ vtable hierarchy */
R_RISCV_GNU_VTENTRY R_RISCV = 42 /* GNU C++ vtable member usage */
R_RISCV_ALIGN R_RISCV = 43 /* Alignment statement */
R_RISCV_RVC_BRANCH R_RISCV = 44 /* PC-relative branch offset */
R_RISCV_RVC_JUMP R_RISCV = 45 /* PC-relative jump offset */
R_RISCV_RVC_LUI R_RISCV = 46 /* Absolute address */
R_RISCV_GPREL_I R_RISCV = 47 /* GP-relative reference */
R_RISCV_GPREL_S R_RISCV = 48 /* GP-relative reference */
R_RISCV_TPREL_I R_RISCV = 49 /* TP-relative TLS LE load */
R_RISCV_TPREL_S R_RISCV = 50 /* TP-relative TLS LE store */
R_RISCV_RELAX R_RISCV = 51 /* Instruction pair can be relaxed */
R_RISCV_SUB6 R_RISCV = 52 /* Local label subtraction */
R_RISCV_SET6 R_RISCV = 53 /* Local label assignment */
R_RISCV_SET8 R_RISCV = 54 /* Local label assignment */
R_RISCV_SET16 R_RISCV = 55 /* Local label assignment */
R_RISCV_SET32 R_RISCV = 56 /* Local label assignment */
R_RISCV_32_PCREL R_RISCV = 57 /* 32-bit PC relative */
)
var rriscvStrings = []intName{
{0, "R_RISCV_NONE"},
{1, "R_RISCV_32"},
{2, "R_RISCV_64"},
{3, "R_RISCV_RELATIVE"},
{4, "R_RISCV_COPY"},
{5, "R_RISCV_JUMP_SLOT"},
{6, "R_RISCV_TLS_DTPMOD32"},
{7, "R_RISCV_TLS_DTPMOD64"},
{8, "R_RISCV_TLS_DTPREL32"},
{9, "R_RISCV_TLS_DTPREL64"},
{10, "R_RISCV_TLS_TPREL32"},
{11, "R_RISCV_TLS_TPREL64"},
{16, "R_RISCV_BRANCH"},
{17, "R_RISCV_JAL"},
{18, "R_RISCV_CALL"},
{19, "R_RISCV_CALL_PLT"},
{20, "R_RISCV_GOT_HI20"},
{21, "R_RISCV_TLS_GOT_HI20"},
{22, "R_RISCV_TLS_GD_HI20"},
{23, "R_RISCV_PCREL_HI20"},
{24, "R_RISCV_PCREL_LO12_I"},
{25, "R_RISCV_PCREL_LO12_S"},
{26, "R_RISCV_HI20"},
{27, "R_RISCV_LO12_I"},
{28, "R_RISCV_LO12_S"},
{29, "R_RISCV_TPREL_HI20"},
{30, "R_RISCV_TPREL_LO12_I"},
{31, "R_RISCV_TPREL_LO12_S"},
{32, "R_RISCV_TPREL_ADD"},
{33, "R_RISCV_ADD8"},
{34, "R_RISCV_ADD16"},
{35, "R_RISCV_ADD32"},
{36, "R_RISCV_ADD64"},
{37, "R_RISCV_SUB8"},
{38, "R_RISCV_SUB16"},
{39, "R_RISCV_SUB32"},
{40, "R_RISCV_SUB64"},
{41, "R_RISCV_GNU_VTINHERIT"},
{42, "R_RISCV_GNU_VTENTRY"},
{43, "R_RISCV_ALIGN"},
{44, "R_RISCV_RVC_BRANCH"},
{45, "R_RISCV_RVC_JUMP"},
{46, "R_RISCV_RVC_LUI"},
{47, "R_RISCV_GPREL_I"},
{48, "R_RISCV_GPREL_S"},
{49, "R_RISCV_TPREL_I"},
{50, "R_RISCV_TPREL_S"},
{51, "R_RISCV_RELAX"},
{52, "R_RISCV_SUB6"},
{53, "R_RISCV_SET6"},
{54, "R_RISCV_SET8"},
{55, "R_RISCV_SET16"},
{56, "R_RISCV_SET32"},
{57, "R_RISCV_32_PCREL"},
}
func (i R_RISCV) String() string { return stringName(uint32(i), rriscvStrings, false) }
func (i R_RISCV) GoString() string { return stringName(uint32(i), rriscvStrings, true) }
// Relocation types for s390x processors.
type R_390 int
const (
R_390_NONE R_390 = 0
R_390_8 R_390 = 1
R_390_12 R_390 = 2
R_390_16 R_390 = 3
R_390_32 R_390 = 4
R_390_PC32 R_390 = 5
R_390_GOT12 R_390 = 6
R_390_GOT32 R_390 = 7
R_390_PLT32 R_390 = 8
R_390_COPY R_390 = 9
R_390_GLOB_DAT R_390 = 10
R_390_JMP_SLOT R_390 = 11
R_390_RELATIVE R_390 = 12
R_390_GOTOFF R_390 = 13
R_390_GOTPC R_390 = 14
R_390_GOT16 R_390 = 15
R_390_PC16 R_390 = 16
R_390_PC16DBL R_390 = 17
R_390_PLT16DBL R_390 = 18
R_390_PC32DBL R_390 = 19
R_390_PLT32DBL R_390 = 20
R_390_GOTPCDBL R_390 = 21
R_390_64 R_390 = 22
R_390_PC64 R_390 = 23
R_390_GOT64 R_390 = 24
R_390_PLT64 R_390 = 25
R_390_GOTENT R_390 = 26
R_390_GOTOFF16 R_390 = 27
R_390_GOTOFF64 R_390 = 28
R_390_GOTPLT12 R_390 = 29
R_390_GOTPLT16 R_390 = 30
R_390_GOTPLT32 R_390 = 31
R_390_GOTPLT64 R_390 = 32
R_390_GOTPLTENT R_390 = 33
R_390_GOTPLTOFF16 R_390 = 34
R_390_GOTPLTOFF32 R_390 = 35
R_390_GOTPLTOFF64 R_390 = 36
R_390_TLS_LOAD R_390 = 37
R_390_TLS_GDCALL R_390 = 38
R_390_TLS_LDCALL R_390 = 39
R_390_TLS_GD32 R_390 = 40
R_390_TLS_GD64 R_390 = 41
R_390_TLS_GOTIE12 R_390 = 42
R_390_TLS_GOTIE32 R_390 = 43
R_390_TLS_GOTIE64 R_390 = 44
R_390_TLS_LDM32 R_390 = 45
R_390_TLS_LDM64 R_390 = 46
R_390_TLS_IE32 R_390 = 47
R_390_TLS_IE64 R_390 = 48
R_390_TLS_IEENT R_390 = 49
R_390_TLS_LE32 R_390 = 50
R_390_TLS_LE64 R_390 = 51
R_390_TLS_LDO32 R_390 = 52
R_390_TLS_LDO64 R_390 = 53
R_390_TLS_DTPMOD R_390 = 54
R_390_TLS_DTPOFF R_390 = 55
R_390_TLS_TPOFF R_390 = 56
R_390_20 R_390 = 57
R_390_GOT20 R_390 = 58
R_390_GOTPLT20 R_390 = 59
R_390_TLS_GOTIE20 R_390 = 60
)
var r390Strings = []intName{
{0, "R_390_NONE"},
{1, "R_390_8"},
{2, "R_390_12"},
{3, "R_390_16"},
{4, "R_390_32"},
{5, "R_390_PC32"},
{6, "R_390_GOT12"},
{7, "R_390_GOT32"},
{8, "R_390_PLT32"},
{9, "R_390_COPY"},
{10, "R_390_GLOB_DAT"},
{11, "R_390_JMP_SLOT"},
{12, "R_390_RELATIVE"},
{13, "R_390_GOTOFF"},
{14, "R_390_GOTPC"},
{15, "R_390_GOT16"},
{16, "R_390_PC16"},
{17, "R_390_PC16DBL"},
{18, "R_390_PLT16DBL"},
{19, "R_390_PC32DBL"},
{20, "R_390_PLT32DBL"},
{21, "R_390_GOTPCDBL"},
{22, "R_390_64"},
{23, "R_390_PC64"},
{24, "R_390_GOT64"},
{25, "R_390_PLT64"},
{26, "R_390_GOTENT"},
{27, "R_390_GOTOFF16"},
{28, "R_390_GOTOFF64"},
{29, "R_390_GOTPLT12"},
{30, "R_390_GOTPLT16"},
{31, "R_390_GOTPLT32"},
{32, "R_390_GOTPLT64"},
{33, "R_390_GOTPLTENT"},
{34, "R_390_GOTPLTOFF16"},
{35, "R_390_GOTPLTOFF32"},
{36, "R_390_GOTPLTOFF64"},
{37, "R_390_TLS_LOAD"},
{38, "R_390_TLS_GDCALL"},
{39, "R_390_TLS_LDCALL"},
{40, "R_390_TLS_GD32"},
{41, "R_390_TLS_GD64"},
{42, "R_390_TLS_GOTIE12"},
{43, "R_390_TLS_GOTIE32"},
{44, "R_390_TLS_GOTIE64"},
{45, "R_390_TLS_LDM32"},
{46, "R_390_TLS_LDM64"},
{47, "R_390_TLS_IE32"},
{48, "R_390_TLS_IE64"},
{49, "R_390_TLS_IEENT"},
{50, "R_390_TLS_LE32"},
{51, "R_390_TLS_LE64"},
{52, "R_390_TLS_LDO32"},
{53, "R_390_TLS_LDO64"},
{54, "R_390_TLS_DTPMOD"},
{55, "R_390_TLS_DTPOFF"},
{56, "R_390_TLS_TPOFF"},
{57, "R_390_20"},
{58, "R_390_GOT20"},
{59, "R_390_GOTPLT20"},
{60, "R_390_TLS_GOTIE20"},
}
func (i R_390) String() string { return stringName(uint32(i), r390Strings, false) }
func (i R_390) GoString() string { return stringName(uint32(i), r390Strings, true) }
// Relocation types for SPARC.
type R_SPARC int
const (
R_SPARC_NONE R_SPARC = 0
R_SPARC_8 R_SPARC = 1
R_SPARC_16 R_SPARC = 2
R_SPARC_32 R_SPARC = 3
R_SPARC_DISP8 R_SPARC = 4
R_SPARC_DISP16 R_SPARC = 5
R_SPARC_DISP32 R_SPARC = 6
R_SPARC_WDISP30 R_SPARC = 7
R_SPARC_WDISP22 R_SPARC = 8
R_SPARC_HI22 R_SPARC = 9
R_SPARC_22 R_SPARC = 10
R_SPARC_13 R_SPARC = 11
R_SPARC_LO10 R_SPARC = 12
R_SPARC_GOT10 R_SPARC = 13
R_SPARC_GOT13 R_SPARC = 14
R_SPARC_GOT22 R_SPARC = 15
R_SPARC_PC10 R_SPARC = 16
R_SPARC_PC22 R_SPARC = 17
R_SPARC_WPLT30 R_SPARC = 18
R_SPARC_COPY R_SPARC = 19
R_SPARC_GLOB_DAT R_SPARC = 20
R_SPARC_JMP_SLOT R_SPARC = 21
R_SPARC_RELATIVE R_SPARC = 22
R_SPARC_UA32 R_SPARC = 23
R_SPARC_PLT32 R_SPARC = 24
R_SPARC_HIPLT22 R_SPARC = 25
R_SPARC_LOPLT10 R_SPARC = 26
R_SPARC_PCPLT32 R_SPARC = 27
R_SPARC_PCPLT22 R_SPARC = 28
R_SPARC_PCPLT10 R_SPARC = 29
R_SPARC_10 R_SPARC = 30
R_SPARC_11 R_SPARC = 31
R_SPARC_64 R_SPARC = 32
R_SPARC_OLO10 R_SPARC = 33
R_SPARC_HH22 R_SPARC = 34
R_SPARC_HM10 R_SPARC = 35
R_SPARC_LM22 R_SPARC = 36
R_SPARC_PC_HH22 R_SPARC = 37
R_SPARC_PC_HM10 R_SPARC = 38
R_SPARC_PC_LM22 R_SPARC = 39
R_SPARC_WDISP16 R_SPARC = 40
R_SPARC_WDISP19 R_SPARC = 41
R_SPARC_GLOB_JMP R_SPARC = 42
R_SPARC_7 R_SPARC = 43
R_SPARC_5 R_SPARC = 44
R_SPARC_6 R_SPARC = 45
R_SPARC_DISP64 R_SPARC = 46
R_SPARC_PLT64 R_SPARC = 47
R_SPARC_HIX22 R_SPARC = 48
R_SPARC_LOX10 R_SPARC = 49
R_SPARC_H44 R_SPARC = 50
R_SPARC_M44 R_SPARC = 51
R_SPARC_L44 R_SPARC = 52
R_SPARC_REGISTER R_SPARC = 53
R_SPARC_UA64 R_SPARC = 54
R_SPARC_UA16 R_SPARC = 55
)
var rsparcStrings = []intName{
{0, "R_SPARC_NONE"},
{1, "R_SPARC_8"},
{2, "R_SPARC_16"},
{3, "R_SPARC_32"},
{4, "R_SPARC_DISP8"},
{5, "R_SPARC_DISP16"},
{6, "R_SPARC_DISP32"},
{7, "R_SPARC_WDISP30"},
{8, "R_SPARC_WDISP22"},
{9, "R_SPARC_HI22"},
{10, "R_SPARC_22"},
{11, "R_SPARC_13"},
{12, "R_SPARC_LO10"},
{13, "R_SPARC_GOT10"},
{14, "R_SPARC_GOT13"},
{15, "R_SPARC_GOT22"},
{16, "R_SPARC_PC10"},
{17, "R_SPARC_PC22"},
{18, "R_SPARC_WPLT30"},
{19, "R_SPARC_COPY"},
{20, "R_SPARC_GLOB_DAT"},
{21, "R_SPARC_JMP_SLOT"},
{22, "R_SPARC_RELATIVE"},
{23, "R_SPARC_UA32"},
{24, "R_SPARC_PLT32"},
{25, "R_SPARC_HIPLT22"},
{26, "R_SPARC_LOPLT10"},
{27, "R_SPARC_PCPLT32"},
{28, "R_SPARC_PCPLT22"},
{29, "R_SPARC_PCPLT10"},
{30, "R_SPARC_10"},
{31, "R_SPARC_11"},
{32, "R_SPARC_64"},
{33, "R_SPARC_OLO10"},
{34, "R_SPARC_HH22"},
{35, "R_SPARC_HM10"},
{36, "R_SPARC_LM22"},
{37, "R_SPARC_PC_HH22"},
{38, "R_SPARC_PC_HM10"},
{39, "R_SPARC_PC_LM22"},
{40, "R_SPARC_WDISP16"},
{41, "R_SPARC_WDISP19"},
{42, "R_SPARC_GLOB_JMP"},
{43, "R_SPARC_7"},
{44, "R_SPARC_5"},
{45, "R_SPARC_6"},
{46, "R_SPARC_DISP64"},
{47, "R_SPARC_PLT64"},
{48, "R_SPARC_HIX22"},
{49, "R_SPARC_LOX10"},
{50, "R_SPARC_H44"},
{51, "R_SPARC_M44"},
{52, "R_SPARC_L44"},
{53, "R_SPARC_REGISTER"},
{54, "R_SPARC_UA64"},
{55, "R_SPARC_UA16"},
}
func (i R_SPARC) String() string { return stringName(uint32(i), rsparcStrings, false) }
func (i R_SPARC) GoString() string { return stringName(uint32(i), rsparcStrings, true) }
// Magic number for the elf trampoline, chosen wisely to be an immediate value.
const ARM_MAGIC_TRAMP_NUMBER = 0x5c000003
// ELF32 File header.
type Header32 struct {
Ident [EI_NIDENT]byte /* File identification. */
Type uint16 /* File type. */
Machine uint16 /* Machine architecture. */
Version uint32 /* ELF format version. */
Entry uint32 /* Entry point. */
Phoff uint32 /* Program header file offset. */
Shoff uint32 /* Section header file offset. */
Flags uint32 /* Architecture-specific flags. */
Ehsize uint16 /* Size of ELF header in bytes. */
Phentsize uint16 /* Size of program header entry. */
Phnum uint16 /* Number of program header entries. */
Shentsize uint16 /* Size of section header entry. */
Shnum uint16 /* Number of section header entries. */
Shstrndx uint16 /* Section name strings section. */
}
// ELF32 Section header.
type Section32 struct {
Name uint32 /* Section name (index into the section header string table). */
Type uint32 /* Section type. */
Flags uint32 /* Section flags. */
Addr uint32 /* Address in memory image. */
Off uint32 /* Offset in file. */
Size uint32 /* Size in bytes. */
Link uint32 /* Index of a related section. */
Info uint32 /* Depends on section type. */
Addralign uint32 /* Alignment in bytes. */
Entsize uint32 /* Size of each entry in section. */
}
// ELF32 Program header.
type Prog32 struct {
Type uint32 /* Entry type. */
Off uint32 /* File offset of contents. */
Vaddr uint32 /* Virtual address in memory image. */
Paddr uint32 /* Physical address (not used). */
Filesz uint32 /* Size of contents in file. */
Memsz uint32 /* Size of contents in memory. */
Flags uint32 /* Access permission flags. */
Align uint32 /* Alignment in memory and file. */
}
// ELF32 Dynamic structure. The ".dynamic" section contains an array of them.
type Dyn32 struct {
Tag int32 /* Entry type. */
Val uint32 /* Integer/Address value. */
}
// ELF32 Compression header.
type Chdr32 struct {
Type uint32
Size uint32
Addralign uint32
}
/*
* Relocation entries.
*/
// ELF32 Relocations that don't need an addend field.
type Rel32 struct {
Off uint32 /* Location to be relocated. */
Info uint32 /* Relocation type and symbol index. */
}
// ELF32 Relocations that need an addend field.
type Rela32 struct {
Off uint32 /* Location to be relocated. */
Info uint32 /* Relocation type and symbol index. */
Addend int32 /* Addend. */
}
func R_SYM32(info uint32) uint32 { return info >> 8 }
func R_TYPE32(info uint32) uint32 { return info & 0xff }
func R_INFO32(sym, typ uint32) uint32 { return sym<<8 | typ }
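// relocInfo32Sketch is an illustrative sketch, not part of the original
// source: it shows how the three helpers above round-trip an ELF32
// relocation info word, with the symbol index in the upper 24 bits and the
// relocation type in the low 8 bits. The symbol index 5 is a hypothetical
// value chosen only for the example.
func relocInfo32Sketch() (sym, typ uint32) {
	info := R_INFO32(5, uint32(R_PPC_ADDR32)) // 5<<8 | 1 == 0x501
	return R_SYM32(info), R_TYPE32(info)      // (5, 1)
}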
// ELF32 Symbol.
type Sym32 struct {
Name uint32
Value uint32
Size uint32
Info uint8
Other uint8
Shndx uint16
}
const Sym32Size = 16
func ST_BIND(info uint8) SymBind { return SymBind(info >> 4) }
func ST_TYPE(info uint8) SymType { return SymType(info & 0xF) }
func ST_INFO(bind SymBind, typ SymType) uint8 {
return uint8(bind)<<4 | uint8(typ)&0xf
}
func ST_VISIBILITY(other uint8) SymVis { return SymVis(other & 3) }
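// symInfoSketch is an illustrative sketch, not part of the original source:
// it packs a symbol binding and type into the Info byte used by Sym32 and
// Sym64, then decodes them again with ST_BIND and ST_TYPE.
func symInfoSketch() (SymBind, SymType) {
	info := ST_INFO(STB_GLOBAL, STT_FUNC) // bind<<4 | type
	return ST_BIND(info), ST_TYPE(info)   // (STB_GLOBAL, STT_FUNC)
}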
/*
* ELF64
*/
// ELF64 file header.
type Header64 struct {
Ident [EI_NIDENT]byte /* File identification. */
Type uint16 /* File type. */
Machine uint16 /* Machine architecture. */
Version uint32 /* ELF format version. */
Entry uint64 /* Entry point. */
Phoff uint64 /* Program header file offset. */
Shoff uint64 /* Section header file offset. */
Flags uint32 /* Architecture-specific flags. */
Ehsize uint16 /* Size of ELF header in bytes. */
Phentsize uint16 /* Size of program header entry. */
Phnum uint16 /* Number of program header entries. */
Shentsize uint16 /* Size of section header entry. */
Shnum uint16 /* Number of section header entries. */
Shstrndx uint16 /* Section name strings section. */
}
// ELF64 Section header.
type Section64 struct {
Name uint32 /* Section name (index into the section header string table). */
Type uint32 /* Section type. */
Flags uint64 /* Section flags. */
Addr uint64 /* Address in memory image. */
Off uint64 /* Offset in file. */
Size uint64 /* Size in bytes. */
Link uint32 /* Index of a related section. */
Info uint32 /* Depends on section type. */
Addralign uint64 /* Alignment in bytes. */
Entsize uint64 /* Size of each entry in section. */
}
// ELF64 Program header.
type Prog64 struct {
Type uint32 /* Entry type. */
Flags uint32 /* Access permission flags. */
Off uint64 /* File offset of contents. */
Vaddr uint64 /* Virtual address in memory image. */
Paddr uint64 /* Physical address (not used). */
Filesz uint64 /* Size of contents in file. */
Memsz uint64 /* Size of contents in memory. */
Align uint64 /* Alignment in memory and file. */
}
// ELF64 Dynamic structure. The ".dynamic" section contains an array of them.
type Dyn64 struct {
Tag int64 /* Entry type. */
Val uint64 /* Integer/address value */
}
// ELF64 Compression header.
type Chdr64 struct {
Type uint32
_ uint32 /* Reserved. */
Size uint64
Addralign uint64
}
/*
* Relocation entries.
*/
/* ELF64 relocations that don't need an addend field. */
type Rel64 struct {
Off uint64 /* Location to be relocated. */
Info uint64 /* Relocation type and symbol index. */
}
/* ELF64 relocations that need an addend field. */
type Rela64 struct {
Off uint64 /* Location to be relocated. */
Info uint64 /* Relocation type and symbol index. */
Addend int64 /* Addend. */
}
func R_SYM64(info uint64) uint32 { return uint32(info >> 32) }
func R_TYPE64(info uint64) uint32 { return uint32(info) }
func R_INFO(sym, typ uint32) uint64 { return uint64(sym)<<32 | uint64(typ) }
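// relocInfo64Sketch is an illustrative sketch, not part of the original
// source: for ELF64 the symbol index occupies the high 32 bits of the info
// word and the relocation type the low 32 bits. The symbol index 7 is a
// hypothetical value chosen only for the example.
func relocInfo64Sketch() (sym, typ uint32) {
	info := R_INFO(7, uint32(R_PPC64_ADDR64))
	return R_SYM64(info), R_TYPE64(info) // (7, 38)
}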
// ELF64 symbol table entries.
type Sym64 struct {
Name uint32 /* String table index of name. */
Info uint8 /* Type and binding information. */
Other uint8 /* Reserved (not used). */
Shndx uint16 /* Section index of symbol. */
Value uint64 /* Symbol value. */
Size uint64 /* Size of associated object. */
}
const Sym64Size = 24
type intName struct {
i uint32
s string
}
// Dynamic version flags.
type DynamicVersionFlag uint16
const (
VER_FLG_BASE DynamicVersionFlag = 0x1 /* Version definition of the file. */
VER_FLG_WEAK DynamicVersionFlag = 0x2 /* Weak version identifier. */
VER_FLG_INFO DynamicVersionFlag = 0x4 /* Reference exists for informational purposes. */
)
func stringName(i uint32, names []intName, goSyntax bool) string {
for _, n := range names {
if n.i == i {
if goSyntax {
return "elf." + n.s
}
return n.s
}
}
// Second pass: find the largest named value below i and
// express i as that name plus an offset.
// Assumes names is sorted by value in increasing order.
for j := len(names) - 1; j >= 0; j-- {
n := names[j]
if n.i < i {
s := n.s
if goSyntax {
s = "elf." + s
}
return s + "+" + strconv.FormatUint(uint64(i-n.i), 10)
}
}
return strconv.FormatUint(uint64(i), 10)
}
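// stringNameSketch is an illustrative sketch, not part of the original
// source: exact values map to their names, and unknown values fall back to
// the nearest smaller name plus an offset.
func stringNameSketch() (known, unknown string) {
	known = stringName(38, rppc64Strings, true)    // "elf.R_PPC64_ADDR64"
	unknown = stringName(18, rppc64Strings, false) // "R_PPC64_GOT16_HA+1"
	return known, unknown
}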
func flagName(i uint32, names []intName, goSyntax bool) string {
s := ""
for _, n := range names {
if n.i&i == n.i {
if len(s) > 0 {
s += "+"
}
if goSyntax {
s += "elf."
}
s += n.s
i -= n.i
}
}
if len(s) == 0 {
return "0x" + strconv.FormatUint(uint64(i), 16)
}
if i != 0 {
s += "+0x" + strconv.FormatUint(uint64(i), 16)
}
return s
}
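// flagNameSketch is an illustrative sketch, not part of the original source:
// flagName peels off each named bit in turn and appends any leftover bits in
// hex. The FLAG_A/FLAG_B table is hypothetical and exists only for the
// example.
func flagNameSketch() string {
	names := []intName{{0x1, "FLAG_A"}, {0x4, "FLAG_B"}}
	return flagName(0x7, names, false) // "FLAG_A+FLAG_B+0x2"
}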
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
/*
Package elf implements access to ELF object files.
# Security
This package is not designed to be hardened against adversarial inputs, and is
outside the scope of https://go.dev/security/policy. In particular, only basic
validation is done when parsing object files. As such, care should be taken when
parsing untrusted inputs, as parsing malformed files may consume significant
resources, or cause panics.
*/
package elf
import (
"bytes"
"compress/zlib"
"debug/dwarf"
"encoding/binary"
"errors"
"fmt"
"internal/saferio"
"internal/zstd"
"io"
"os"
"strings"
"unsafe"
)
// TODO: error reporting detail
/*
* Internal ELF representation
*/
// A FileHeader represents an ELF file header.
type FileHeader struct {
Class Class
Data Data
Version Version
OSABI OSABI
ABIVersion uint8
ByteOrder binary.ByteOrder
Type Type
Machine Machine
Entry uint64
}
// A File represents an open ELF file.
type File struct {
FileHeader
Sections []*Section
Progs []*Prog
closer io.Closer
dynVers []DynamicVersion
dynVerNeeds []DynamicVersionNeed
gnuVersym []byte
}
// A SectionHeader represents a single ELF section header.
type SectionHeader struct {
Name string
Type SectionType
Flags SectionFlag
Addr uint64
Offset uint64
Size uint64
Link uint32
Info uint32
Addralign uint64
Entsize uint64
// FileSize is the size of this section in the file in bytes.
// If a section is compressed, FileSize is the size of the
// compressed data, while Size (above) is the size of the
// uncompressed data.
FileSize uint64
}
// A Section represents a single section in an ELF file.
type Section struct {
SectionHeader
// Embed ReaderAt for ReadAt method.
// Do not embed SectionReader directly
// to avoid having Read and Seek.
// If a client wants Read and Seek it must use
// Open() to avoid fighting over the seek offset
// with other clients.
//
// ReaderAt may be nil if the section is not easily available
// in a random-access form. For example, a compressed section
// may have a nil ReaderAt.
io.ReaderAt
sr *io.SectionReader
compressionType CompressionType
compressionOffset int64
}
// Data reads and returns the contents of the ELF section.
// Even if the section is stored compressed in the ELF file,
// Data returns uncompressed data.
//
// For an [SHT_NOBITS] section, Data always returns a non-nil error.
func (s *Section) Data() ([]byte, error) {
return saferio.ReadData(s.Open(), s.Size)
}
// stringTable reads and returns the string table given by the
// specified link value.
func (f *File) stringTable(link uint32) ([]byte, error) {
if link <= 0 || link >= uint32(len(f.Sections)) {
return nil, errors.New("section has invalid string table link")
}
return f.Sections[link].Data()
}
// Open returns a new ReadSeeker reading the ELF section.
// Even if the section is stored compressed in the ELF file,
// the ReadSeeker reads uncompressed data.
//
// For an [SHT_NOBITS] section, all calls to the opened reader
// will return a non-nil error.
func (s *Section) Open() io.ReadSeeker {
if s.Type == SHT_NOBITS {
return io.NewSectionReader(&nobitsSectionReader{}, 0, int64(s.Size))
}
var zrd func(io.Reader) (io.ReadCloser, error)
if s.Flags&SHF_COMPRESSED == 0 {
if !strings.HasPrefix(s.Name, ".zdebug") {
return io.NewSectionReader(s.sr, 0, 1<<63-1)
}
b := make([]byte, 12)
n, _ := s.sr.ReadAt(b, 0)
if n != 12 || string(b[:4]) != "ZLIB" {
return io.NewSectionReader(s.sr, 0, 1<<63-1)
}
s.compressionOffset = 12
s.compressionType = COMPRESS_ZLIB
s.Size = binary.BigEndian.Uint64(b[4:12])
zrd = zlib.NewReader
} else if s.Flags&SHF_ALLOC != 0 {
return errorReader{&FormatError{int64(s.Offset),
"SHF_COMPRESSED applies only to non-allocable sections", s.compressionType}}
}
switch s.compressionType {
case COMPRESS_ZLIB:
zrd = zlib.NewReader
case COMPRESS_ZSTD:
zrd = func(r io.Reader) (io.ReadCloser, error) {
return io.NopCloser(zstd.NewReader(r)), nil
}
}
if zrd == nil {
return errorReader{&FormatError{int64(s.Offset), "unknown compression type", s.compressionType}}
}
return &readSeekerFromReader{
reset: func() (io.Reader, error) {
fr := io.NewSectionReader(s.sr, s.compressionOffset, int64(s.FileSize)-s.compressionOffset)
return zrd(fr)
},
size: int64(s.Size),
}
}
// A ProgHeader represents a single ELF program header.
type ProgHeader struct {
Type ProgType
Flags ProgFlag
Off uint64
Vaddr uint64
Paddr uint64
Filesz uint64
Memsz uint64
Align uint64
}
// A Prog represents a single ELF program header in an ELF binary.
type Prog struct {
ProgHeader
// Embed ReaderAt for ReadAt method.
// Do not embed SectionReader directly
// to avoid having Read and Seek.
// If a client wants Read and Seek it must use
// Open() to avoid fighting over the seek offset
// with other clients.
io.ReaderAt
sr *io.SectionReader
}
// Open returns a new ReadSeeker reading the ELF program body.
func (p *Prog) Open() io.ReadSeeker { return io.NewSectionReader(p.sr, 0, 1<<63-1) }
// A Symbol represents an entry in an ELF symbol table section.
type Symbol struct {
Name string
Info, Other byte
// HasVersion reports whether the symbol has any version information.
// This will only be true for the dynamic symbol table.
HasVersion bool
// VersionIndex is the symbol's version index.
// Use the methods of the [VersionIndex] type to access it.
// This field is only meaningful if HasVersion is true.
VersionIndex VersionIndex
Section SectionIndex
Value, Size uint64
// These fields are present only for the dynamic symbol table.
Version string
Library string
}
/*
* ELF reader
*/
type FormatError struct {
off int64
msg string
val any
}
func (e *FormatError) Error() string {
msg := e.msg
if e.val != nil {
msg += fmt.Sprintf(" '%v' ", e.val)
}
msg += fmt.Sprintf("in record at byte %#x", e.off)
return msg
}
// Open opens the named file using [os.Open] and prepares it for use as an ELF binary.
func Open(name string) (*File, error) {
f, err := os.Open(name)
if err != nil {
return nil, err
}
ff, err := NewFile(f)
if err != nil {
f.Close()
return nil, err
}
ff.closer = f
return ff, nil
}
// Close closes the [File].
// If the [File] was created using [NewFile] directly instead of [Open],
// Close has no effect.
func (f *File) Close() error {
var err error
if f.closer != nil {
err = f.closer.Close()
f.closer = nil
}
return err
}
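// openSketch is an illustrative sketch, not part of the original source: a
// typical caller opens a binary, reads the (decompressed) contents of one
// section, and closes the File. The path is hypothetical.
func openSketch() ([]byte, error) {
	f, err := Open("/tmp/prog")
	if err != nil {
		return nil, err
	}
	defer f.Close()
	s := f.Section(".text")
	if s == nil {
		return nil, errors.New("no .text section")
	}
	return s.Data()
}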
// SectionByType returns the first section in f with the
// given type, or nil if there is no such section.
func (f *File) SectionByType(typ SectionType) *Section {
for _, s := range f.Sections {
if s.Type == typ {
return s
}
}
return nil
}
// NewFile creates a new [File] for accessing an ELF binary in an underlying reader.
// The ELF binary is expected to start at position 0 in the ReaderAt.
func NewFile(r io.ReaderAt) (*File, error) {
sr := io.NewSectionReader(r, 0, 1<<63-1)
// Read and decode ELF identifier
var ident [16]uint8
if _, err := r.ReadAt(ident[0:], 0); err != nil {
return nil, err
}
if ident[0] != '\x7f' || ident[1] != 'E' || ident[2] != 'L' || ident[3] != 'F' {
return nil, &FormatError{0, "bad magic number", ident[0:4]}
}
f := new(File)
f.Class = Class(ident[EI_CLASS])
switch f.Class {
case ELFCLASS32:
case ELFCLASS64:
// ok
default:
return nil, &FormatError{0, "unknown ELF class", f.Class}
}
f.Data = Data(ident[EI_DATA])
var bo binary.ByteOrder
switch f.Data {
case ELFDATA2LSB:
bo = binary.LittleEndian
case ELFDATA2MSB:
bo = binary.BigEndian
default:
return nil, &FormatError{0, "unknown ELF data encoding", f.Data}
}
f.ByteOrder = bo
f.Version = Version(ident[EI_VERSION])
if f.Version != EV_CURRENT {
return nil, &FormatError{0, "unknown ELF version", f.Version}
}
f.OSABI = OSABI(ident[EI_OSABI])
f.ABIVersion = ident[EI_ABIVERSION]
// Read ELF file header
var phoff int64
var phentsize, phnum int
var shoff int64
var shentsize, shnum, shstrndx int
switch f.Class {
case ELFCLASS32:
var hdr Header32
data := make([]byte, unsafe.Sizeof(hdr))
if _, err := sr.ReadAt(data, 0); err != nil {
return nil, err
}
f.Type = Type(bo.Uint16(data[unsafe.Offsetof(hdr.Type):]))
f.Machine = Machine(bo.Uint16(data[unsafe.Offsetof(hdr.Machine):]))
f.Entry = uint64(bo.Uint32(data[unsafe.Offsetof(hdr.Entry):]))
if v := Version(bo.Uint32(data[unsafe.Offsetof(hdr.Version):])); v != f.Version {
return nil, &FormatError{0, "mismatched ELF version", v}
}
phoff = int64(bo.Uint32(data[unsafe.Offsetof(hdr.Phoff):]))
phentsize = int(bo.Uint16(data[unsafe.Offsetof(hdr.Phentsize):]))
phnum = int(bo.Uint16(data[unsafe.Offsetof(hdr.Phnum):]))
shoff = int64(bo.Uint32(data[unsafe.Offsetof(hdr.Shoff):]))
shentsize = int(bo.Uint16(data[unsafe.Offsetof(hdr.Shentsize):]))
shnum = int(bo.Uint16(data[unsafe.Offsetof(hdr.Shnum):]))
shstrndx = int(bo.Uint16(data[unsafe.Offsetof(hdr.Shstrndx):]))
case ELFCLASS64:
var hdr Header64
data := make([]byte, unsafe.Sizeof(hdr))
if _, err := sr.ReadAt(data, 0); err != nil {
return nil, err
}
f.Type = Type(bo.Uint16(data[unsafe.Offsetof(hdr.Type):]))
f.Machine = Machine(bo.Uint16(data[unsafe.Offsetof(hdr.Machine):]))
f.Entry = bo.Uint64(data[unsafe.Offsetof(hdr.Entry):])
if v := Version(bo.Uint32(data[unsafe.Offsetof(hdr.Version):])); v != f.Version {
return nil, &FormatError{0, "mismatched ELF version", v}
}
phoff = int64(bo.Uint64(data[unsafe.Offsetof(hdr.Phoff):]))
phentsize = int(bo.Uint16(data[unsafe.Offsetof(hdr.Phentsize):]))
phnum = int(bo.Uint16(data[unsafe.Offsetof(hdr.Phnum):]))
shoff = int64(bo.Uint64(data[unsafe.Offsetof(hdr.Shoff):]))
shentsize = int(bo.Uint16(data[unsafe.Offsetof(hdr.Shentsize):]))
shnum = int(bo.Uint16(data[unsafe.Offsetof(hdr.Shnum):]))
shstrndx = int(bo.Uint16(data[unsafe.Offsetof(hdr.Shstrndx):]))
}
if shoff < 0 {
return nil, &FormatError{0, "invalid shoff", shoff}
}
if phoff < 0 {
return nil, &FormatError{0, "invalid phoff", phoff}
}
if shoff == 0 && shnum != 0 {
return nil, &FormatError{0, "invalid ELF shnum for shoff=0", shnum}
}
if shnum > 0 && shstrndx >= shnum {
return nil, &FormatError{0, "invalid ELF shstrndx", shstrndx}
}
var wantPhentsize, wantShentsize int
switch f.Class {
case ELFCLASS32:
wantPhentsize = 8 * 4
wantShentsize = 10 * 4
case ELFCLASS64:
wantPhentsize = 2*4 + 6*8
wantShentsize = 4*4 + 6*8
}
if phnum > 0 && phentsize < wantPhentsize {
return nil, &FormatError{0, "invalid ELF phentsize", phentsize}
}
// Read program headers
f.Progs = make([]*Prog, phnum)
phdata, err := saferio.ReadDataAt(sr, uint64(phnum)*uint64(phentsize), phoff)
if err != nil {
return nil, err
}
for i := 0; i < phnum; i++ {
off := uintptr(i) * uintptr(phentsize)
p := new(Prog)
switch f.Class {
case ELFCLASS32:
var ph Prog32
p.ProgHeader = ProgHeader{
Type: ProgType(bo.Uint32(phdata[off+unsafe.Offsetof(ph.Type):])),
Flags: ProgFlag(bo.Uint32(phdata[off+unsafe.Offsetof(ph.Flags):])),
Off: uint64(bo.Uint32(phdata[off+unsafe.Offsetof(ph.Off):])),
Vaddr: uint64(bo.Uint32(phdata[off+unsafe.Offsetof(ph.Vaddr):])),
Paddr: uint64(bo.Uint32(phdata[off+unsafe.Offsetof(ph.Paddr):])),
Filesz: uint64(bo.Uint32(phdata[off+unsafe.Offsetof(ph.Filesz):])),
Memsz: uint64(bo.Uint32(phdata[off+unsafe.Offsetof(ph.Memsz):])),
Align: uint64(bo.Uint32(phdata[off+unsafe.Offsetof(ph.Align):])),
}
case ELFCLASS64:
var ph Prog64
p.ProgHeader = ProgHeader{
Type: ProgType(bo.Uint32(phdata[off+unsafe.Offsetof(ph.Type):])),
Flags: ProgFlag(bo.Uint32(phdata[off+unsafe.Offsetof(ph.Flags):])),
Off: bo.Uint64(phdata[off+unsafe.Offsetof(ph.Off):]),
Vaddr: bo.Uint64(phdata[off+unsafe.Offsetof(ph.Vaddr):]),
Paddr: bo.Uint64(phdata[off+unsafe.Offsetof(ph.Paddr):]),
Filesz: bo.Uint64(phdata[off+unsafe.Offsetof(ph.Filesz):]),
Memsz: bo.Uint64(phdata[off+unsafe.Offsetof(ph.Memsz):]),
Align: bo.Uint64(phdata[off+unsafe.Offsetof(ph.Align):]),
}
}
if int64(p.Off) < 0 {
return nil, &FormatError{phoff + int64(off), "invalid program header offset", p.Off}
}
if int64(p.Filesz) < 0 {
return nil, &FormatError{phoff + int64(off), "invalid program header file size", p.Filesz}
}
p.sr = io.NewSectionReader(r, int64(p.Off), int64(p.Filesz))
p.ReaderAt = p.sr
f.Progs[i] = p
}
// If the number of sections is greater than or equal to SHN_LORESERVE
// (0xff00), shnum has the value zero and the actual number of section
// header table entries is contained in the sh_size field of the section
// header at index 0.
if shoff > 0 && shnum == 0 {
var typ, link uint32
sr.Seek(shoff, io.SeekStart)
switch f.Class {
case ELFCLASS32:
sh := new(Section32)
if err := binary.Read(sr, bo, sh); err != nil {
return nil, err
}
shnum = int(sh.Size)
typ = sh.Type
link = sh.Link
case ELFCLASS64:
sh := new(Section64)
if err := binary.Read(sr, bo, sh); err != nil {
return nil, err
}
shnum = int(sh.Size)
typ = sh.Type
link = sh.Link
}
if SectionType(typ) != SHT_NULL {
return nil, &FormatError{shoff, "invalid type of the initial section", SectionType(typ)}
}
if shnum < int(SHN_LORESERVE) {
return nil, &FormatError{shoff, "invalid ELF shnum contained in sh_size", shnum}
}
// If the section name string table section index is greater than or
// equal to SHN_LORESERVE (0xff00), this member has the value
// SHN_XINDEX (0xffff) and the actual index of the section name
// string table section is contained in the sh_link field of the
// section header at index 0.
if shstrndx == int(SHN_XINDEX) {
shstrndx = int(link)
if shstrndx < int(SHN_LORESERVE) {
return nil, &FormatError{shoff, "invalid ELF shstrndx contained in sh_link", shstrndx}
}
}
}
if shnum > 0 && shentsize < wantShentsize {
return nil, &FormatError{0, "invalid ELF shentsize", shentsize}
}
// Read section headers
c := saferio.SliceCap[Section](uint64(shnum))
if c < 0 {
return nil, &FormatError{0, "too many sections", shnum}
}
if shnum > 0 && ((1<<64)-1)/uint64(shnum) < uint64(shentsize) {
return nil, &FormatError{0, "section header overflow", shnum}
}
f.Sections = make([]*Section, 0, c)
names := make([]uint32, 0, c)
shdata, err := saferio.ReadDataAt(sr, uint64(shnum)*uint64(shentsize), shoff)
if err != nil {
return nil, err
}
for i := 0; i < shnum; i++ {
off := uintptr(i) * uintptr(shentsize)
s := new(Section)
switch f.Class {
case ELFCLASS32:
var sh Section32
names = append(names, bo.Uint32(shdata[off+unsafe.Offsetof(sh.Name):]))
s.SectionHeader = SectionHeader{
Type: SectionType(bo.Uint32(shdata[off+unsafe.Offsetof(sh.Type):])),
Flags: SectionFlag(bo.Uint32(shdata[off+unsafe.Offsetof(sh.Flags):])),
Addr: uint64(bo.Uint32(shdata[off+unsafe.Offsetof(sh.Addr):])),
Offset: uint64(bo.Uint32(shdata[off+unsafe.Offsetof(sh.Off):])),
FileSize: uint64(bo.Uint32(shdata[off+unsafe.Offsetof(sh.Size):])),
Link: bo.Uint32(shdata[off+unsafe.Offsetof(sh.Link):]),
Info: bo.Uint32(shdata[off+unsafe.Offsetof(sh.Info):]),
Addralign: uint64(bo.Uint32(shdata[off+unsafe.Offsetof(sh.Addralign):])),
Entsize: uint64(bo.Uint32(shdata[off+unsafe.Offsetof(sh.Entsize):])),
}
case ELFCLASS64:
var sh Section64
names = append(names, bo.Uint32(shdata[off+unsafe.Offsetof(sh.Name):]))
s.SectionHeader = SectionHeader{
Type: SectionType(bo.Uint32(shdata[off+unsafe.Offsetof(sh.Type):])),
Flags: SectionFlag(bo.Uint64(shdata[off+unsafe.Offsetof(sh.Flags):])),
Offset: bo.Uint64(shdata[off+unsafe.Offsetof(sh.Off):]),
FileSize: bo.Uint64(shdata[off+unsafe.Offsetof(sh.Size):]),
Addr: bo.Uint64(shdata[off+unsafe.Offsetof(sh.Addr):]),
Link: bo.Uint32(shdata[off+unsafe.Offsetof(sh.Link):]),
Info: bo.Uint32(shdata[off+unsafe.Offsetof(sh.Info):]),
Addralign: bo.Uint64(shdata[off+unsafe.Offsetof(sh.Addralign):]),
Entsize: bo.Uint64(shdata[off+unsafe.Offsetof(sh.Entsize):]),
}
}
if int64(s.Offset) < 0 {
return nil, &FormatError{shoff + int64(off), "invalid section offset", int64(s.Offset)}
}
if int64(s.FileSize) < 0 {
return nil, &FormatError{shoff + int64(off), "invalid section size", int64(s.FileSize)}
}
s.sr = io.NewSectionReader(r, int64(s.Offset), int64(s.FileSize))
if s.Flags&SHF_COMPRESSED == 0 {
s.ReaderAt = s.sr
s.Size = s.FileSize
} else {
// Read the compression header.
switch f.Class {
case ELFCLASS32:
var ch Chdr32
chdata := make([]byte, unsafe.Sizeof(ch))
if _, err := s.sr.ReadAt(chdata, 0); err != nil {
return nil, err
}
s.compressionType = CompressionType(bo.Uint32(chdata[unsafe.Offsetof(ch.Type):]))
s.Size = uint64(bo.Uint32(chdata[unsafe.Offsetof(ch.Size):]))
s.Addralign = uint64(bo.Uint32(chdata[unsafe.Offsetof(ch.Addralign):]))
s.compressionOffset = int64(unsafe.Sizeof(ch))
case ELFCLASS64:
var ch Chdr64
chdata := make([]byte, unsafe.Sizeof(ch))
if _, err := s.sr.ReadAt(chdata, 0); err != nil {
return nil, err
}
s.compressionType = CompressionType(bo.Uint32(chdata[unsafe.Offsetof(ch.Type):]))
s.Size = bo.Uint64(chdata[unsafe.Offsetof(ch.Size):])
s.Addralign = bo.Uint64(chdata[unsafe.Offsetof(ch.Addralign):])
s.compressionOffset = int64(unsafe.Sizeof(ch))
}
}
f.Sections = append(f.Sections, s)
}
if len(f.Sections) == 0 {
return f, nil
}
// Load section header string table.
if shstrndx == 0 {
// If the file has no section name string table,
// shstrndx holds the value SHN_UNDEF (0).
return f, nil
}
shstr := f.Sections[shstrndx]
if shstr.Type != SHT_STRTAB {
return nil, &FormatError{shoff + int64(shstrndx*shentsize), "invalid ELF section name string table type", shstr.Type}
}
shstrtab, err := shstr.Data()
if err != nil {
return nil, err
}
for i, s := range f.Sections {
var ok bool
s.Name, ok = getString(shstrtab, int(names[i]))
if !ok {
return nil, &FormatError{shoff + int64(i*shentsize), "bad section name index", names[i]}
}
}
return f, nil
}
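// newFileSketch is an illustrative sketch, not part of the original source:
// NewFile accepts any io.ReaderAt, so an ELF image already held in memory
// can be parsed without touching the filesystem.
func newFileSketch(image []byte) (*File, error) {
	return NewFile(bytes.NewReader(image))
}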
// getSymbols returns a slice of Symbols from parsing the symbol table
// with the given type, along with the associated string table.
func (f *File) getSymbols(typ SectionType) ([]Symbol, []byte, error) {
switch f.Class {
case ELFCLASS64:
return f.getSymbols64(typ)
case ELFCLASS32:
return f.getSymbols32(typ)
}
return nil, nil, errors.New("not implemented")
}
// ErrNoSymbols is returned by [File.Symbols] and [File.DynamicSymbols]
// if there is no such section in the File.
var ErrNoSymbols = errors.New("no symbol section")
func (f *File) getSymbols32(typ SectionType) ([]Symbol, []byte, error) {
symtabSection := f.SectionByType(typ)
if symtabSection == nil {
return nil, nil, ErrNoSymbols
}
data, err := symtabSection.Data()
if err != nil {
return nil, nil, fmt.Errorf("cannot load symbol section: %w", err)
}
if len(data) == 0 {
return nil, nil, errors.New("symbol section is empty")
}
if len(data)%Sym32Size != 0 {
return nil, nil, errors.New("length of symbol section is not a multiple of SymSize")
}
strdata, err := f.stringTable(symtabSection.Link)
if err != nil {
return nil, nil, fmt.Errorf("cannot load string table section: %w", err)
}
// The first entry is all zeros.
data = data[Sym32Size:]
symbols := make([]Symbol, len(data)/Sym32Size)
i := 0
var sym Sym32
for len(data) > 0 {
sym.Name = f.ByteOrder.Uint32(data[0:4])
sym.Value = f.ByteOrder.Uint32(data[4:8])
sym.Size = f.ByteOrder.Uint32(data[8:12])
sym.Info = data[12]
sym.Other = data[13]
sym.Shndx = f.ByteOrder.Uint16(data[14:16])
str, _ := getString(strdata, int(sym.Name))
symbols[i].Name = str
symbols[i].Info = sym.Info
symbols[i].Other = sym.Other
symbols[i].Section = SectionIndex(sym.Shndx)
symbols[i].Value = uint64(sym.Value)
symbols[i].Size = uint64(sym.Size)
i++
data = data[Sym32Size:]
}
return symbols, strdata, nil
}
func (f *File) getSymbols64(typ SectionType) ([]Symbol, []byte, error) {
symtabSection := f.SectionByType(typ)
if symtabSection == nil {
return nil, nil, ErrNoSymbols
}
data, err := symtabSection.Data()
if err != nil {
return nil, nil, fmt.Errorf("cannot load symbol section: %w", err)
}
if len(data)%Sym64Size != 0 {
return nil, nil, errors.New("length of symbol section is not a multiple of Sym64Size")
}
strdata, err := f.stringTable(symtabSection.Link)
if err != nil {
return nil, nil, fmt.Errorf("cannot load string table section: %w", err)
}
// The first entry is all zeros.
data = data[Sym64Size:]
symbols := make([]Symbol, len(data)/Sym64Size)
i := 0
var sym Sym64
for len(data) > 0 {
sym.Name = f.ByteOrder.Uint32(data[0:4])
sym.Info = data[4]
sym.Other = data[5]
sym.Shndx = f.ByteOrder.Uint16(data[6:8])
sym.Value = f.ByteOrder.Uint64(data[8:16])
sym.Size = f.ByteOrder.Uint64(data[16:24])
str, _ := getString(strdata, int(sym.Name))
symbols[i].Name = str
symbols[i].Info = sym.Info
symbols[i].Other = sym.Other
symbols[i].Section = SectionIndex(sym.Shndx)
symbols[i].Value = sym.Value
symbols[i].Size = sym.Size
i++
data = data[Sym64Size:]
}
return symbols, strdata, nil
}
// getString extracts a string from an ELF string table.
func getString(section []byte, start int) (string, bool) {
if start < 0 || start >= len(section) {
return "", false
}
for end := start; end < len(section); end++ {
if section[end] == 0 {
return string(section[start:end]), true
}
}
return "", false
}
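// Illustration (not part of the original source): with a string table laid out as
//
//	tab := []byte("\x00.text\x00.data\x00")
//
// getString(tab, 1) returns (".text", true), getString(tab, 7) returns
// (".data", true), and any start at or beyond len(tab) returns ("", false).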
// Section returns a section with the given name, or nil if no such
// section exists.
func (f *File) Section(name string) *Section {
for _, s := range f.Sections {
if s.Name == name {
return s
}
}
return nil
}
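// Illustrative usage from a client package (a sketch, not part of the original
// source; the path below is only a placeholder):
//
//	f, err := elf.Open("/usr/bin/true")
//	if err != nil {
//		log.Fatal(err)
//	}
//	defer f.Close()
//	if s := f.Section(".text"); s != nil {
//		data, err := s.Data()
//		if err != nil {
//			log.Fatal(err)
//		}
//		fmt.Printf(".text: %d bytes at %#x\n", len(data), s.Addr)
//	}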
// applyRelocations applies relocations to dst. rels is a relocations section
// in REL or RELA format.
func (f *File) applyRelocations(dst []byte, rels []byte) error {
switch {
case f.Class == ELFCLASS64 && f.Machine == EM_X86_64:
return f.applyRelocationsAMD64(dst, rels)
case f.Class == ELFCLASS32 && f.Machine == EM_386:
return f.applyRelocations386(dst, rels)
case f.Class == ELFCLASS32 && f.Machine == EM_ARM:
return f.applyRelocationsARM(dst, rels)
case f.Class == ELFCLASS64 && f.Machine == EM_AARCH64:
return f.applyRelocationsARM64(dst, rels)
case f.Class == ELFCLASS32 && f.Machine == EM_PPC:
return f.applyRelocationsPPC(dst, rels)
case f.Class == ELFCLASS64 && f.Machine == EM_PPC64:
return f.applyRelocationsPPC64(dst, rels)
case f.Class == ELFCLASS32 && f.Machine == EM_MIPS:
return f.applyRelocationsMIPS(dst, rels)
case f.Class == ELFCLASS64 && f.Machine == EM_MIPS:
return f.applyRelocationsMIPS64(dst, rels)
case f.Class == ELFCLASS64 && f.Machine == EM_LOONGARCH:
return f.applyRelocationsLOONG64(dst, rels)
case f.Class == ELFCLASS64 && f.Machine == EM_RISCV:
return f.applyRelocationsRISCV64(dst, rels)
case f.Class == ELFCLASS64 && f.Machine == EM_S390:
return f.applyRelocationss390x(dst, rels)
case f.Class == ELFCLASS64 && f.Machine == EM_SPARCV9:
return f.applyRelocationsSPARC64(dst, rels)
default:
return errors.New("applyRelocations: not implemented")
}
}
// canApplyRelocation reports whether we should try to apply a
// relocation to a DWARF data section, given a pointer to the symbol
// targeted by the relocation.
// Most relocations in DWARF data tend to be section-relative, but
// some target non-section symbols (for example, low_PC attrs on
// subprogram or compilation unit DIEs that target function symbols).
func canApplyRelocation(sym *Symbol) bool {
return sym.Section != SHN_UNDEF && sym.Section < SHN_LORESERVE
}
func (f *File) applyRelocationsAMD64(dst []byte, rels []byte) error {
// 24 is the size of Rela64.
if len(rels)%24 != 0 {
return errors.New("length of relocation section is not a multiple of 24")
}
symbols, _, err := f.getSymbols(SHT_SYMTAB)
if err != nil {
return err
}
b := bytes.NewReader(rels)
var rela Rela64
for b.Len() > 0 {
binary.Read(b, f.ByteOrder, &rela)
symNo := rela.Info >> 32
t := R_X86_64(rela.Info & 0xffff)
if symNo == 0 || symNo > uint64(len(symbols)) {
continue
}
sym := &symbols[symNo-1]
if !canApplyRelocation(sym) {
continue
}
// There are relocations, so this must be a normal
// object file. The code below handles only basic relocations
// of the form S + A (symbol plus addend).
switch t {
case R_X86_64_64:
if rela.Off+8 >= uint64(len(dst)) || rela.Addend < 0 {
continue
}
val64 := sym.Value + uint64(rela.Addend)
f.ByteOrder.PutUint64(dst[rela.Off:rela.Off+8], val64)
case R_X86_64_32:
if rela.Off+4 >= uint64(len(dst)) || rela.Addend < 0 {
continue
}
val32 := uint32(sym.Value) + uint32(rela.Addend)
f.ByteOrder.PutUint32(dst[rela.Off:rela.Off+4], val32)
}
}
return nil
}
func (f *File) applyRelocations386(dst []byte, rels []byte) error {
// 8 is the size of Rel32.
if len(rels)%8 != 0 {
return errors.New("length of relocation section is not a multiple of 8")
}
symbols, _, err := f.getSymbols(SHT_SYMTAB)
if err != nil {
return err
}
b := bytes.NewReader(rels)
var rel Rel32
for b.Len() > 0 {
binary.Read(b, f.ByteOrder, &rel)
symNo := rel.Info >> 8
t := R_386(rel.Info & 0xff)
if symNo == 0 || symNo > uint32(len(symbols)) {
continue
}
sym := &symbols[symNo-1]
if t == R_386_32 {
if rel.Off+4 >= uint32(len(dst)) {
continue
}
val := f.ByteOrder.Uint32(dst[rel.Off : rel.Off+4])
val += uint32(sym.Value)
f.ByteOrder.PutUint32(dst[rel.Off:rel.Off+4], val)
}
}
return nil
}
func (f *File) applyRelocationsARM(dst []byte, rels []byte) error {
// 8 is the size of Rel32.
if len(rels)%8 != 0 {
return errors.New("length of relocation section is not a multiple of 8")
}
symbols, _, err := f.getSymbols(SHT_SYMTAB)
if err != nil {
return err
}
b := bytes.NewReader(rels)
var rel Rel32
for b.Len() > 0 {
binary.Read(b, f.ByteOrder, &rel)
symNo := rel.Info >> 8
t := R_ARM(rel.Info & 0xff)
if symNo == 0 || symNo > uint32(len(symbols)) {
continue
}
sym := &symbols[symNo-1]
switch t {
case R_ARM_ABS32:
if rel.Off+4 >= uint32(len(dst)) {
continue
}
val := f.ByteOrder.Uint32(dst[rel.Off : rel.Off+4])
val += uint32(sym.Value)
f.ByteOrder.PutUint32(dst[rel.Off:rel.Off+4], val)
}
}
return nil
}
func (f *File) applyRelocationsARM64(dst []byte, rels []byte) error {
// 24 is the size of Rela64.
if len(rels)%24 != 0 {
return errors.New("length of relocation section is not a multiple of 24")
}
symbols, _, err := f.getSymbols(SHT_SYMTAB)
if err != nil {
return err
}
b := bytes.NewReader(rels)
var rela Rela64
for b.Len() > 0 {
binary.Read(b, f.ByteOrder, &rela)
symNo := rela.Info >> 32
t := R_AARCH64(rela.Info & 0xffff)
if symNo == 0 || symNo > uint64(len(symbols)) {
continue
}
sym := &symbols[symNo-1]
if !canApplyRelocation(sym) {
continue
}
// There are relocations, so this must be a normal
// object file. The code below handles only basic relocations
// of the form S + A (symbol plus addend).
switch t {
case R_AARCH64_ABS64:
if rela.Off+8 >= uint64(len(dst)) || rela.Addend < 0 {
continue
}
val64 := sym.Value + uint64(rela.Addend)
f.ByteOrder.PutUint64(dst[rela.Off:rela.Off+8], val64)
case R_AARCH64_ABS32:
if rela.Off+4 >= uint64(len(dst)) || rela.Addend < 0 {
continue
}
val32 := uint32(sym.Value) + uint32(rela.Addend)
f.ByteOrder.PutUint32(dst[rela.Off:rela.Off+4], val32)
}
}
return nil
}
func (f *File) applyRelocationsPPC(dst []byte, rels []byte) error {
// 12 is the size of Rela32.
if len(rels)%12 != 0 {
return errors.New("length of relocation section is not a multiple of 12")
}
symbols, _, err := f.getSymbols(SHT_SYMTAB)
if err != nil {
return err
}
b := bytes.NewReader(rels)
var rela Rela32
for b.Len() > 0 {
binary.Read(b, f.ByteOrder, &rela)
symNo := rela.Info >> 8
t := R_PPC(rela.Info & 0xff)
if symNo == 0 || symNo > uint32(len(symbols)) {
continue
}
sym := &symbols[symNo-1]
if !canApplyRelocation(sym) {
continue
}
switch t {
case R_PPC_ADDR32:
if rela.Off+4 >= uint32(len(dst)) || rela.Addend < 0 {
continue
}
val32 := uint32(sym.Value) + uint32(rela.Addend)
f.ByteOrder.PutUint32(dst[rela.Off:rela.Off+4], val32)
}
}
return nil
}
func (f *File) applyRelocationsPPC64(dst []byte, rels []byte) error {
// 24 is the size of Rela64.
if len(rels)%24 != 0 {
return errors.New("length of relocation section is not a multiple of 24")
}
symbols, _, err := f.getSymbols(SHT_SYMTAB)
if err != nil {
return err
}
b := bytes.NewReader(rels)
var rela Rela64
for b.Len() > 0 {
binary.Read(b, f.ByteOrder, &rela)
symNo := rela.Info >> 32
t := R_PPC64(rela.Info & 0xffff)
if symNo == 0 || symNo > uint64(len(symbols)) {
continue
}
sym := &symbols[symNo-1]
if !canApplyRelocation(sym) {
continue
}
switch t {
case R_PPC64_ADDR64:
if rela.Off+8 >= uint64(len(dst)) || rela.Addend < 0 {
continue
}
val64 := sym.Value + uint64(rela.Addend)
f.ByteOrder.PutUint64(dst[rela.Off:rela.Off+8], val64)
case R_PPC64_ADDR32:
if rela.Off+4 >= uint64(len(dst)) || rela.Addend < 0 {
continue
}
val32 := uint32(sym.Value) + uint32(rela.Addend)
f.ByteOrder.PutUint32(dst[rela.Off:rela.Off+4], val32)
}
}
return nil
}
func (f *File) applyRelocationsMIPS(dst []byte, rels []byte) error {
// 8 is the size of Rel32.
if len(rels)%8 != 0 {
return errors.New("length of relocation section is not a multiple of 8")
}
symbols, _, err := f.getSymbols(SHT_SYMTAB)
if err != nil {
return err
}
b := bytes.NewReader(rels)
var rel Rel32
for b.Len() > 0 {
binary.Read(b, f.ByteOrder, &rel)
symNo := rel.Info >> 8
t := R_MIPS(rel.Info & 0xff)
if symNo == 0 || symNo > uint32(len(symbols)) {
continue
}
sym := &symbols[symNo-1]
switch t {
case R_MIPS_32:
if rel.Off+4 >= uint32(len(dst)) {
continue
}
val := f.ByteOrder.Uint32(dst[rel.Off : rel.Off+4])
val += uint32(sym.Value)
f.ByteOrder.PutUint32(dst[rel.Off:rel.Off+4], val)
}
}
return nil
}
func (f *File) applyRelocationsMIPS64(dst []byte, rels []byte) error {
// 24 is the size of Rela64.
if len(rels)%24 != 0 {
return errors.New("length of relocation section is not a multiple of 24")
}
symbols, _, err := f.getSymbols(SHT_SYMTAB)
if err != nil {
return err
}
b := bytes.NewReader(rels)
var rela Rela64
for b.Len() > 0 {
binary.Read(b, f.ByteOrder, &rela)
var symNo uint64
var t R_MIPS
if f.ByteOrder == binary.BigEndian {
symNo = rela.Info >> 32
t = R_MIPS(rela.Info & 0xff)
} else {
symNo = rela.Info & 0xffffffff
t = R_MIPS(rela.Info >> 56)
}
if symNo == 0 || symNo > uint64(len(symbols)) {
continue
}
sym := &symbols[symNo-1]
if !canApplyRelocation(sym) {
continue
}
switch t {
case R_MIPS_64:
if rela.Off+8 >= uint64(len(dst)) || rela.Addend < 0 {
continue
}
val64 := sym.Value + uint64(rela.Addend)
f.ByteOrder.PutUint64(dst[rela.Off:rela.Off+8], val64)
case R_MIPS_32:
if rela.Off+4 >= uint64(len(dst)) || rela.Addend < 0 {
continue
}
val32 := uint32(sym.Value) + uint32(rela.Addend)
f.ByteOrder.PutUint32(dst[rela.Off:rela.Off+4], val32)
}
}
return nil
}
func (f *File) applyRelocationsLOONG64(dst []byte, rels []byte) error {
// 24 is the size of Rela64.
if len(rels)%24 != 0 {
return errors.New("length of relocation section is not a multiple of 24")
}
symbols, _, err := f.getSymbols(SHT_SYMTAB)
if err != nil {
return err
}
b := bytes.NewReader(rels)
var rela Rela64
for b.Len() > 0 {
binary.Read(b, f.ByteOrder, &rela)
var symNo uint64
var t R_LARCH
symNo = rela.Info >> 32
t = R_LARCH(rela.Info & 0xffff)
if symNo == 0 || symNo > uint64(len(symbols)) {
continue
}
sym := &symbols[symNo-1]
if !canApplyRelocation(sym) {
continue
}
switch t {
case R_LARCH_64:
if rela.Off+8 >= uint64(len(dst)) || rela.Addend < 0 {
continue
}
val64 := sym.Value + uint64(rela.Addend)
f.ByteOrder.PutUint64(dst[rela.Off:rela.Off+8], val64)
case R_LARCH_32:
if rela.Off+4 >= uint64(len(dst)) || rela.Addend < 0 {
continue
}
val32 := uint32(sym.Value) + uint32(rela.Addend)
f.ByteOrder.PutUint32(dst[rela.Off:rela.Off+4], val32)
}
}
return nil
}
func (f *File) applyRelocationsRISCV64(dst []byte, rels []byte) error {
// 24 is the size of Rela64.
if len(rels)%24 != 0 {
return errors.New("length of relocation section is not a multiple of 24")
}
symbols, _, err := f.getSymbols(SHT_SYMTAB)
if err != nil {
return err
}
b := bytes.NewReader(rels)
var rela Rela64
for b.Len() > 0 {
binary.Read(b, f.ByteOrder, &rela)
symNo := rela.Info >> 32
t := R_RISCV(rela.Info & 0xffff)
if symNo == 0 || symNo > uint64(len(symbols)) {
continue
}
sym := &symbols[symNo-1]
if !canApplyRelocation(sym) {
continue
}
switch t {
case R_RISCV_64:
if rela.Off+8 >= uint64(len(dst)) || rela.Addend < 0 {
continue
}
val64 := sym.Value + uint64(rela.Addend)
f.ByteOrder.PutUint64(dst[rela.Off:rela.Off+8], val64)
case R_RISCV_32:
if rela.Off+4 >= uint64(len(dst)) || rela.Addend < 0 {
continue
}
val32 := uint32(sym.Value) + uint32(rela.Addend)
f.ByteOrder.PutUint32(dst[rela.Off:rela.Off+4], val32)
}
}
return nil
}
func (f *File) applyRelocationss390x(dst []byte, rels []byte) error {
// 24 is the size of Rela64.
if len(rels)%24 != 0 {
return errors.New("length of relocation section is not a multiple of 24")
}
symbols, _, err := f.getSymbols(SHT_SYMTAB)
if err != nil {
return err
}
b := bytes.NewReader(rels)
var rela Rela64
for b.Len() > 0 {
binary.Read(b, f.ByteOrder, &rela)
symNo := rela.Info >> 32
t := R_390(rela.Info & 0xffff)
if symNo == 0 || symNo > uint64(len(symbols)) {
continue
}
sym := &symbols[symNo-1]
if !canApplyRelocation(sym) {
continue
}
switch t {
case R_390_64:
if rela.Off+8 >= uint64(len(dst)) || rela.Addend < 0 {
continue
}
val64 := sym.Value + uint64(rela.Addend)
f.ByteOrder.PutUint64(dst[rela.Off:rela.Off+8], val64)
case R_390_32:
if rela.Off+4 >= uint64(len(dst)) || rela.Addend < 0 {
continue
}
val32 := uint32(sym.Value) + uint32(rela.Addend)
f.ByteOrder.PutUint32(dst[rela.Off:rela.Off+4], val32)
}
}
return nil
}
func (f *File) applyRelocationsSPARC64(dst []byte, rels []byte) error {
// 24 is the size of Rela64.
if len(rels)%24 != 0 {
return errors.New("length of relocation section is not a multiple of 24")
}
symbols, _, err := f.getSymbols(SHT_SYMTAB)
if err != nil {
return err
}
b := bytes.NewReader(rels)
var rela Rela64
for b.Len() > 0 {
binary.Read(b, f.ByteOrder, &rela)
symNo := rela.Info >> 32
t := R_SPARC(rela.Info & 0xff)
if symNo == 0 || symNo > uint64(len(symbols)) {
continue
}
sym := &symbols[symNo-1]
if !canApplyRelocation(sym) {
continue
}
switch t {
case R_SPARC_64, R_SPARC_UA64:
if rela.Off+8 >= uint64(len(dst)) || rela.Addend < 0 {
continue
}
val64 := sym.Value + uint64(rela.Addend)
f.ByteOrder.PutUint64(dst[rela.Off:rela.Off+8], val64)
case R_SPARC_32, R_SPARC_UA32:
if rela.Off+4 >= uint64(len(dst)) || rela.Addend < 0 {
continue
}
val32 := uint32(sym.Value) + uint32(rela.Addend)
f.ByteOrder.PutUint32(dst[rela.Off:rela.Off+4], val32)
}
}
return nil
}
func (f *File) DWARF() (*dwarf.Data, error) {
dwarfSuffix := func(s *Section) string {
switch {
case strings.HasPrefix(s.Name, ".debug_"):
return s.Name[7:]
case strings.HasPrefix(s.Name, ".zdebug_"):
return s.Name[8:]
default:
return ""
}
}
// sectionData gets the data for s, checks its size, and
// applies any applicable relocations.
sectionData := func(i int, s *Section) ([]byte, error) {
b, err := s.Data()
if err != nil && uint64(len(b)) < s.Size {
return nil, err
}
if f.Type == ET_EXEC {
// Do not apply relocations to DWARF sections for ET_EXEC binaries.
// Relocations should already be applied, and .rela sections may
// contain incorrect data.
return b, nil
}
for _, r := range f.Sections {
if r.Type != SHT_RELA && r.Type != SHT_REL {
continue
}
if int(r.Info) != i {
continue
}
rd, err := r.Data()
if err != nil {
return nil, err
}
err = f.applyRelocations(b, rd)
if err != nil {
return nil, err
}
}
return b, nil
}
// There are many DWARF sections, but these are the ones
// the debug/dwarf package started with.
var dat = map[string][]byte{"abbrev": nil, "info": nil, "str": nil, "line": nil, "ranges": nil}
for i, s := range f.Sections {
suffix := dwarfSuffix(s)
if suffix == "" {
continue
}
if _, ok := dat[suffix]; !ok {
continue
}
b, err := sectionData(i, s)
if err != nil {
return nil, err
}
dat[suffix] = b
}
d, err := dwarf.New(dat["abbrev"], nil, nil, dat["info"], dat["line"], nil, dat["ranges"], dat["str"])
if err != nil {
return nil, err
}
// Look for DWARF4 .debug_types sections and DWARF5 sections.
for i, s := range f.Sections {
suffix := dwarfSuffix(s)
if suffix == "" {
continue
}
if _, ok := dat[suffix]; ok {
// Already handled.
continue
}
b, err := sectionData(i, s)
if err != nil {
return nil, err
}
if suffix == "types" {
if err := d.AddTypes(fmt.Sprintf("types-%d", i), b); err != nil {
return nil, err
}
} else {
if err := d.AddSection(".debug_"+suffix, b); err != nil {
return nil, err
}
}
}
return d, nil
}
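// Illustrative usage from a client package (a sketch, not part of the original
// source; error handling is trimmed, the path is a placeholder, and the
// debug/elf and debug/dwarf packages are assumed to be imported):
//
//	f, _ := elf.Open("/path/to/binary")
//	defer f.Close()
//	d, err := f.DWARF()
//	if err != nil {
//		log.Fatal(err)
//	}
//	r := d.Reader()
//	for {
//		e, err := r.Next()
//		if err != nil || e == nil {
//			break
//		}
//		if e.Tag == dwarf.TagCompileUnit {
//			fmt.Println(e.Val(dwarf.AttrName))
//		}
//	}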
// Symbols returns the symbol table for f. The symbols will be listed in the order
// they appear in f.
//
// For compatibility with Go 1.0, Symbols omits the null symbol at index 0.
// After retrieving the symbols as symtab, an externally supplied index x
// corresponds to symtab[x-1], not symtab[x].
func (f *File) Symbols() ([]Symbol, error) {
sym, _, err := f.getSymbols(SHT_SYMTAB)
return sym, err
}
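// Illustrative sketch (not part of the original source), assuming f is an
// *elf.File and symNo is a nonzero ELF symbol table index taken from, say, a
// relocation's info field: because the null symbol is omitted, that index
// maps to syms[symNo-1].
//
//	syms, err := f.Symbols()
//	if err != nil {
//		log.Fatal(err)
//	}
//	sym := syms[symNo-1] // symNo is 1-based with respect to the ELF symbol table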
// DynamicSymbols returns the dynamic symbol table for f. The symbols
// will be listed in the order they appear in f.
//
// If f has a symbol version table, the returned [File.Symbols] will have
// initialized Version and Library fields.
//
// For compatibility with [File.Symbols], [File.DynamicSymbols] omits the null symbol at index 0.
// After retrieving the symbols as symtab, an externally supplied index x
// corresponds to symtab[x-1], not symtab[x].
func (f *File) DynamicSymbols() ([]Symbol, error) {
sym, str, err := f.getSymbols(SHT_DYNSYM)
if err != nil {
return nil, err
}
hasVersions, err := f.gnuVersionInit(str)
if err != nil {
return nil, err
}
if hasVersions {
for i := range sym {
sym[i].HasVersion, sym[i].VersionIndex, sym[i].Version, sym[i].Library = f.gnuVersion(i)
}
}
return sym, nil
}
type ImportedSymbol struct {
Name string
Version string
Library string
}
// ImportedSymbols returns the names of all symbols
// referred to by the binary f that are expected to be
// satisfied by other libraries at dynamic load time.
// It does not return weak symbols.
func (f *File) ImportedSymbols() ([]ImportedSymbol, error) {
sym, str, err := f.getSymbols(SHT_DYNSYM)
if err != nil {
return nil, err
}
if _, err := f.gnuVersionInit(str); err != nil {
return nil, err
}
var all []ImportedSymbol
for i, s := range sym {
if ST_BIND(s.Info) == STB_GLOBAL && s.Section == SHN_UNDEF {
all = append(all, ImportedSymbol{Name: s.Name})
sym := &all[len(all)-1]
_, _, sym.Version, sym.Library = f.gnuVersion(i)
}
}
return all, nil
}
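// Illustrative usage (a sketch, not part of the original source), assuming f
// is an *elf.File for a dynamically linked binary:
//
//	imports, err := f.ImportedSymbols()
//	if err != nil {
//		log.Fatal(err)
//	}
//	for _, is := range imports {
//		fmt.Printf("%s (version %q, from %q)\n", is.Name, is.Version, is.Library)
//	}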
// VersionIndex is the type of a [Symbol] version index.
type VersionIndex uint16
// IsHidden reports whether the symbol is hidden within the version.
// This means that the symbol can only be seen by specifying the exact version.
func (vi VersionIndex) IsHidden() bool {
return vi&0x8000 != 0
}
// Index returns the version index.
// If this is the value 0, it means that the symbol is local,
// and is not visible externally.
// If this is the value 1, it means that the symbol is in the base version,
// and has no specific version; it may or may not match a
// [DynamicVersion.Index] in the slice returned by [File.DynamicVersions].
// Other values will match either [DynamicVersion.Index]
// in the slice returned by [File.DynamicVersions],
// or [DynamicVersionDep.Index] in the Needs field
// of the elements of the slice returned by [File.DynamicVersionNeeds].
// In general, a defined symbol will have an index referring
// to DynamicVersions, and an undefined symbol will have an index
// referring to some version in DynamicVersionNeeds.
func (vi VersionIndex) Index() uint16 {
return uint16(vi & 0x7fff)
}
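// Illustration (not from the original source): the raw value 0x8003 decodes as
// VersionIndex(0x8003).IsHidden() == true and VersionIndex(0x8003).Index() == 3,
// that is, version index 3 with the hidden bit set.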
// DynamicVersion is a version defined by a dynamic object.
// This describes entries in the ELF SHT_GNU_verdef section.
// We assume that the vd_version field is 1.
// Note that the name of the version appears here;
// it is not in the first Deps entry as it is in the ELF file.
type DynamicVersion struct {
Name string // Name of version defined by this index.
Index uint16 // Version index.
Flags DynamicVersionFlag
Deps []string // Names of versions that this version depends upon.
}
// DynamicVersionNeed describes a shared library needed by a dynamic object,
// with a list of the versions needed from that shared library.
// This describes entries in the ELF SHT_GNU_verneed section.
// We assume that the vn_version field is 1.
type DynamicVersionNeed struct {
Name string // Shared library name.
Needs []DynamicVersionDep // Dependencies.
}
// DynamicVersionDep is a version needed from some shared library.
type DynamicVersionDep struct {
Flags DynamicVersionFlag
Index uint16 // Version index.
Dep string // Name of required version.
}
// dynamicVersions parses the version definitions (SHT_GNU_VERDEF) for a
// dynamic object and caches them in f.dynVers.
func (f *File) dynamicVersions(str []byte) error {
if f.dynVers != nil {
// Already initialized.
return nil
}
// Accumulate verdef information.
vd := f.SectionByType(SHT_GNU_VERDEF)
if vd == nil {
return nil
}
d, _ := vd.Data()
var dynVers []DynamicVersion
i := 0
for {
if i+20 > len(d) {
break
}
version := f.ByteOrder.Uint16(d[i : i+2])
if version != 1 {
return &FormatError{int64(vd.Offset + uint64(i)), "unexpected dynamic version", version}
}
flags := DynamicVersionFlag(f.ByteOrder.Uint16(d[i+2 : i+4]))
ndx := f.ByteOrder.Uint16(d[i+4 : i+6])
cnt := f.ByteOrder.Uint16(d[i+6 : i+8])
aux := f.ByteOrder.Uint32(d[i+12 : i+16])
next := f.ByteOrder.Uint32(d[i+16 : i+20])
if cnt == 0 {
return &FormatError{int64(vd.Offset + uint64(i)), "dynamic version has no name", nil}
}
var name string
var depName string
var deps []string
j := i + int(aux)
for c := 0; c < int(cnt); c++ {
if j+8 > len(d) {
break
}
vname := f.ByteOrder.Uint32(d[j : j+4])
vnext := f.ByteOrder.Uint32(d[j+4 : j+8])
depName, _ = getString(str, int(vname))
if c == 0 {
name = depName
} else {
deps = append(deps, depName)
}
j += int(vnext)
}
dynVers = append(dynVers, DynamicVersion{
Name: name,
Index: ndx,
Flags: flags,
Deps: deps,
})
if next == 0 {
break
}
i += int(next)
}
f.dynVers = dynVers
return nil
}
// DynamicVersions returns version information for a dynamic object.
func (f *File) DynamicVersions() ([]DynamicVersion, error) {
if f.dynVers == nil {
_, str, err := f.getSymbols(SHT_DYNSYM)
if err != nil {
return nil, err
}
hasVersions, err := f.gnuVersionInit(str)
if err != nil {
return nil, err
}
if !hasVersions {
return nil, errors.New("DynamicVersions: missing version table")
}
}
return f.dynVers, nil
}
// dynamicVersionNeeds parses the version dependencies (SHT_GNU_VERNEED) for a
// dynamic object and caches them in f.dynVerNeeds.
func (f *File) dynamicVersionNeeds(str []byte) error {
if f.dynVerNeeds != nil {
// Already initialized.
return nil
}
// Accumulate verneed information.
vn := f.SectionByType(SHT_GNU_VERNEED)
if vn == nil {
return nil
}
d, _ := vn.Data()
var dynVerNeeds []DynamicVersionNeed
i := 0
for {
if i+16 > len(d) {
break
}
vers := f.ByteOrder.Uint16(d[i : i+2])
if vers != 1 {
return &FormatError{int64(vn.Offset + uint64(i)), "unexpected dynamic need version", vers}
}
cnt := f.ByteOrder.Uint16(d[i+2 : i+4])
fileoff := f.ByteOrder.Uint32(d[i+4 : i+8])
aux := f.ByteOrder.Uint32(d[i+8 : i+12])
next := f.ByteOrder.Uint32(d[i+12 : i+16])
file, _ := getString(str, int(fileoff))
var deps []DynamicVersionDep
j := i + int(aux)
for c := 0; c < int(cnt); c++ {
if j+16 > len(d) {
break
}
flags := DynamicVersionFlag(f.ByteOrder.Uint16(d[j+4 : j+6]))
index := f.ByteOrder.Uint16(d[j+6 : j+8])
nameoff := f.ByteOrder.Uint32(d[j+8 : j+12])
next := f.ByteOrder.Uint32(d[j+12 : j+16])
depName, _ := getString(str, int(nameoff))
deps = append(deps, DynamicVersionDep{
Flags: flags,
Index: index,
Dep: depName,
})
if next == 0 {
break
}
j += int(next)
}
dynVerNeeds = append(dynVerNeeds, DynamicVersionNeed{
Name: file,
Needs: deps,
})
if next == 0 {
break
}
i += int(next)
}
f.dynVerNeeds = dynVerNeeds
return nil
}
// DynamicVersionNeeds returns version dependencies for a dynamic object.
func (f *File) DynamicVersionNeeds() ([]DynamicVersionNeed, error) {
if f.dynVerNeeds == nil {
_, str, err := f.getSymbols(SHT_DYNSYM)
if err != nil {
return nil, err
}
hasVersions, err := f.gnuVersionInit(str)
if err != nil {
return nil, err
}
if !hasVersions {
return nil, errors.New("DynamicVersionNeeds: missing version table")
}
}
return f.dynVerNeeds, nil
}
// gnuVersionInit parses the GNU version tables
// for use by calls to gnuVersion.
// It reports whether any version tables were found.
func (f *File) gnuVersionInit(str []byte) (bool, error) {
// Versym parallels symbol table, indexing into verneed.
vs := f.SectionByType(SHT_GNU_VERSYM)
if vs == nil {
return false, nil
}
d, _ := vs.Data()
f.gnuVersym = d
if err := f.dynamicVersions(str); err != nil {
return false, err
}
if err := f.dynamicVersionNeeds(str); err != nil {
return false, err
}
return true, nil
}
// gnuVersion returns the version and library information for the symbol
// at index i of the symbol table.
func (f *File) gnuVersion(i int) (hasVersion bool, versionIndex VersionIndex, version string, library string) {
// Each entry is two bytes; skip undef entry at beginning.
i = (i + 1) * 2
if i >= len(f.gnuVersym) {
return false, 0, "", ""
}
s := f.gnuVersym[i:]
if len(s) < 2 {
return false, 0, "", ""
}
vi := VersionIndex(f.ByteOrder.Uint16(s))
ndx := vi.Index()
if ndx == 0 || ndx == 1 {
return true, vi, "", ""
}
for _, v := range f.dynVerNeeds {
for _, n := range v.Needs {
if ndx == n.Index {
return true, vi, n.Dep, v.Name
}
}
}
for _, v := range f.dynVers {
if ndx == v.Index {
return true, vi, v.Name, ""
}
}
return false, 0, "", ""
}
// ImportedLibraries returns the names of all libraries
// referred to by the binary f that are expected to be
// linked with the binary at dynamic link time.
func (f *File) ImportedLibraries() ([]string, error) {
return f.DynString(DT_NEEDED)
}
// DynString returns the strings listed for the given tag in the file's dynamic
// section.
//
// The tag must be one that takes string values: [DT_NEEDED], [DT_SONAME], [DT_RPATH], or
// [DT_RUNPATH].
func (f *File) DynString(tag DynTag) ([]string, error) {
switch tag {
case DT_NEEDED, DT_SONAME, DT_RPATH, DT_RUNPATH:
default:
return nil, fmt.Errorf("non-string-valued tag %v", tag)
}
ds := f.SectionByType(SHT_DYNAMIC)
if ds == nil {
// not dynamic, so no libraries
return nil, nil
}
d, err := ds.Data()
if err != nil {
return nil, err
}
dynSize := 8
if f.Class == ELFCLASS64 {
dynSize = 16
}
if len(d)%dynSize != 0 {
return nil, errors.New("length of dynamic section is not a multiple of dynamic entry size")
}
str, err := f.stringTable(ds.Link)
if err != nil {
return nil, err
}
var all []string
for len(d) > 0 {
var t DynTag
var v uint64
switch f.Class {
case ELFCLASS32:
t = DynTag(f.ByteOrder.Uint32(d[0:4]))
v = uint64(f.ByteOrder.Uint32(d[4:8]))
d = d[8:]
case ELFCLASS64:
t = DynTag(f.ByteOrder.Uint64(d[0:8]))
v = f.ByteOrder.Uint64(d[8:16])
d = d[16:]
}
if t == tag {
s, ok := getString(str, int(v))
if ok {
all = append(all, s)
}
}
}
return all, nil
}
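// Illustrative usage from a client package (a sketch, not part of the original
// source), assuming f is an *elf.File for a dynamically linked binary:
//
//	if sonames, err := f.DynString(elf.DT_SONAME); err == nil && len(sonames) > 0 {
//		fmt.Println("soname:", sonames[0])
//	}
//	needed, _ := f.DynString(elf.DT_NEEDED) // the same data ImportedLibraries reports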
// DynValue returns the values listed for the given tag in the file's dynamic
// section.
func (f *File) DynValue(tag DynTag) ([]uint64, error) {
ds := f.SectionByType(SHT_DYNAMIC)
if ds == nil {
return nil, nil
}
d, err := ds.Data()
if err != nil {
return nil, err
}
dynSize := 8
if f.Class == ELFCLASS64 {
dynSize = 16
}
if len(d)%dynSize != 0 {
return nil, errors.New("length of dynamic section is not a multiple of dynamic entry size")
}
// Parse the .dynamic section as a string of bytes.
var vals []uint64
for len(d) > 0 {
var t DynTag
var v uint64
switch f.Class {
case ELFCLASS32:
t = DynTag(f.ByteOrder.Uint32(d[0:4]))
v = uint64(f.ByteOrder.Uint32(d[4:8]))
d = d[8:]
case ELFCLASS64:
t = DynTag(f.ByteOrder.Uint64(d[0:8]))
v = f.ByteOrder.Uint64(d[8:16])
d = d[16:]
}
if t == tag {
vals = append(vals, v)
}
}
return vals, nil
}
type nobitsSectionReader struct{}
func (*nobitsSectionReader) ReadAt(p []byte, off int64) (n int, err error) {
return 0, errors.New("unexpected read from SHT_NOBITS section")
}
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package elf
import (
"io"
"os"
)
// errorReader returns error from all operations.
type errorReader struct {
error
}
func (r errorReader) Read(p []byte) (n int, err error) {
return 0, r.error
}
func (r errorReader) ReadAt(p []byte, off int64) (n int, err error) {
return 0, r.error
}
func (r errorReader) Seek(offset int64, whence int) (int64, error) {
return 0, r.error
}
func (r errorReader) Close() error {
return r.error
}
// readSeekerFromReader converts an io.Reader into an io.ReadSeeker.
// In general Seek may not be efficient, but it is optimized for
// common cases such as seeking to the end to find the length of the
// data.
type readSeekerFromReader struct {
reset func() (io.Reader, error)
r io.Reader
size int64
offset int64
}
func (r *readSeekerFromReader) start() {
x, err := r.reset()
if err != nil {
r.r = errorReader{err}
} else {
r.r = x
}
r.offset = 0
}
func (r *readSeekerFromReader) Read(p []byte) (n int, err error) {
if r.r == nil {
r.start()
}
n, err = r.r.Read(p)
r.offset += int64(n)
return n, err
}
func (r *readSeekerFromReader) Seek(offset int64, whence int) (int64, error) {
var newOffset int64
switch whence {
case io.SeekStart:
newOffset = offset
case io.SeekCurrent:
newOffset = r.offset + offset
case io.SeekEnd:
newOffset = r.size + offset
default:
return 0, os.ErrInvalid
}
switch {
case newOffset == r.offset:
return newOffset, nil
case newOffset < 0, newOffset > r.size:
return 0, os.ErrInvalid
case newOffset == 0:
r.r = nil
case newOffset == r.size:
r.r = errorReader{io.EOF}
default:
if newOffset < r.offset {
// Restart at the beginning.
r.start()
}
// Read until we reach offset.
var buf [512]byte
for r.offset < newOffset {
b := buf[:]
if newOffset-r.offset < int64(len(buf)) {
b = buf[:newOffset-r.offset]
}
if _, err := r.Read(b); err != nil {
return 0, err
}
}
}
r.offset = newOffset
return r.offset, nil
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
/*
* Line tables
*/
package gosym
import (
"bytes"
"encoding/binary"
"sort"
"sync"
)
// version of the pclntab
type version int
const (
verUnknown version = iota
ver11
ver12
ver116
ver118
ver120
)
// A LineTable is a data structure mapping program counters to line numbers.
//
// In Go 1.1 and earlier, each function (represented by a [Func]) had its own LineTable,
// and the line number corresponded to a numbering of all source lines in the
// program, across all files. That absolute line number would then have to be
// converted separately to a file name and line number within the file.
//
// In Go 1.2, the format of the data changed so that there is a single LineTable
// for the entire program, shared by all Funcs, and there are no absolute line
// numbers, just line numbers within specific files.
//
// For the most part, LineTable's methods should be treated as an internal
// detail of the package; callers should use the methods on [Table] instead.
type LineTable struct {
Data []byte
PC uint64
Line int
// This mutex is used to keep parsing of pclntab synchronous.
mu sync.Mutex
// Contains the version of the pclntab section.
version version
// Go 1.2/1.16/1.18 state
binary binary.ByteOrder
quantum uint32
ptrsize uint32
textStart uint64 // address of runtime.text symbol (1.18+)
funcnametab []byte
cutab []byte
funcdata []byte
functab []byte
nfunctab uint32
filetab []byte
pctab []byte // points to the pctables.
nfiletab uint32
funcNames map[uint32]string // cache the function names
strings map[uint32]string // interned substrings of Data, keyed by offset
// fileMap varies depending on the version of the object file.
// For ver12, it maps the name to the index in the file table.
// For ver116, it maps the name to the offset in filetab.
fileMap map[string]uint32
}
// NOTE(rsc): This is wrong for GOARCH=arm, which uses a quantum of 4,
// but we have no idea whether we're using arm or not. This only
// matters in the old (pre-Go 1.2) symbol table format, so it's not worth
// fixing.
const oldQuantum = 1
func (t *LineTable) parse(targetPC uint64, targetLine int) (b []byte, pc uint64, line int) {
// The PC/line table can be thought of as a sequence of
// <pc update>* <line update>
// batches. Each update batch results in a (pc, line) pair,
// where line applies to every PC from pc up to but not
// including the pc of the next pair.
//
// Here we process each update individually, which simplifies
// the code, but makes the corner cases more confusing.
b, pc, line = t.Data, t.PC, t.Line
for pc <= targetPC && line != targetLine && len(b) > 0 {
code := b[0]
b = b[1:]
switch {
case code == 0:
if len(b) < 4 {
b = b[0:0]
break
}
val := binary.BigEndian.Uint32(b)
b = b[4:]
line += int(val)
case code <= 64:
line += int(code)
case code <= 128:
line -= int(code - 64)
default:
pc += oldQuantum * uint64(code-128)
continue
}
pc += oldQuantum
}
return b, pc, line
}
func (t *LineTable) slice(pc uint64) *LineTable {
data, pc, line := t.parse(pc, -1)
return &LineTable{Data: data, PC: pc, Line: line}
}
// PCToLine returns the line number for the given program counter.
//
// Deprecated: Use Table's PCToLine method instead.
func (t *LineTable) PCToLine(pc uint64) int {
if t.isGo12() {
return t.go12PCToLine(pc)
}
_, _, line := t.parse(pc, -1)
return line
}
// LineToPC returns the program counter for the given line number,
// considering only program counters before maxpc.
//
// Deprecated: Use Table's LineToPC method instead.
func (t *LineTable) LineToPC(line int, maxpc uint64) uint64 {
if t.isGo12() {
return 0
}
_, pc, line1 := t.parse(maxpc, line)
if line1 != line {
return 0
}
// Subtract quantum from PC to account for post-line increment
return pc - oldQuantum
}
// NewLineTable returns a new PC/line table
// corresponding to the encoded data.
// Text must be the start address of the
// corresponding text segment, with the exact
// value stored in the 'runtime.text' symbol.
// This value may differ from the start
// address of the text segment if
// the binary was built with cgo enabled.
func NewLineTable(data []byte, text uint64) *LineTable {
return &LineTable{Data: data, PC: text, Line: 0, funcNames: make(map[uint32]string), strings: make(map[uint32]string)}
}
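// Illustrative usage (a sketch, not part of the original source): a LineTable
// is typically built from the .gopclntab section of a Go ELF binary, using the
// .text section address as the usual stand-in for runtime.text. Nil checks and
// error handling are trimmed, the path is a placeholder, and someAddr is a
// hypothetical program counter.
//
//	f, _ := elf.Open("/path/to/gobinary")
//	defer f.Close()
//	pclndat, _ := f.Section(".gopclntab").Data()
//	pcln := gosym.NewLineTable(pclndat, f.Section(".text").Addr)
//	symdat, _ := f.Section(".gosymtab").Data()
//	tab, err := gosym.NewTable(symdat, pcln)
//	if err != nil {
//		log.Fatal(err)
//	}
//	file, line, fn := tab.PCToLine(someAddr)
//	fmt.Println(file, line, fn.Name)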
// Go 1.2 symbol table format.
// See golang.org/s/go12symtab.
//
// A general note about the methods here: rather than try to avoid
// index out of bounds errors, we trust Go to detect them, and then
// we recover from the panics and treat them as indicative of a malformed
// or incomplete table.
//
// The methods called by symtab.go, which begin with "go12" prefixes,
// are expected to have that recovery logic.
// isGo12 reports whether this is a Go 1.2 (or later) symbol table.
func (t *LineTable) isGo12() bool {
t.parsePclnTab()
return t.version >= ver12
}
const (
go12magic = 0xfffffffb
go116magic = 0xfffffffa
go118magic = 0xfffffff0
go120magic = 0xfffffff1
)
// uintptr returns the pointer-sized value encoded at b.
// The pointer size is dictated by the table being read.
func (t *LineTable) uintptr(b []byte) uint64 {
if t.ptrsize == 4 {
return uint64(t.binary.Uint32(b))
}
return t.binary.Uint64(b)
}
// parsePclnTab parses the pclntab, setting the version.
func (t *LineTable) parsePclnTab() {
t.mu.Lock()
defer t.mu.Unlock()
if t.version != verUnknown {
return
}
// Note that during this function, setting the version is the last thing we do.
// If we set the version too early, and parsing failed (likely as a panic on
// slice lookups), we'd have a mistaken version.
//
// Error paths through this code will default the version to 1.1.
t.version = ver11
if !disableRecover {
defer func() {
// If we panic parsing, assume it's a Go 1.1 pclntab.
recover()
}()
}
// Check header: 4-byte magic, two zeros, pc quantum, pointer size.
if len(t.Data) < 16 || t.Data[4] != 0 || t.Data[5] != 0 ||
(t.Data[6] != 1 && t.Data[6] != 2 && t.Data[6] != 4) || // pc quantum
(t.Data[7] != 4 && t.Data[7] != 8) { // pointer size
return
}
var possibleVersion version
leMagic := binary.LittleEndian.Uint32(t.Data)
beMagic := binary.BigEndian.Uint32(t.Data)
switch {
case leMagic == go12magic:
t.binary, possibleVersion = binary.LittleEndian, ver12
case beMagic == go12magic:
t.binary, possibleVersion = binary.BigEndian, ver12
case leMagic == go116magic:
t.binary, possibleVersion = binary.LittleEndian, ver116
case beMagic == go116magic:
t.binary, possibleVersion = binary.BigEndian, ver116
case leMagic == go118magic:
t.binary, possibleVersion = binary.LittleEndian, ver118
case beMagic == go118magic:
t.binary, possibleVersion = binary.BigEndian, ver118
case leMagic == go120magic:
t.binary, possibleVersion = binary.LittleEndian, ver120
case beMagic == go120magic:
t.binary, possibleVersion = binary.BigEndian, ver120
default:
return
}
t.version = possibleVersion
// quantum and ptrSize are the same between 1.2, 1.16, and 1.18
t.quantum = uint32(t.Data[6])
t.ptrsize = uint32(t.Data[7])
offset := func(word uint32) uint64 {
return t.uintptr(t.Data[8+word*t.ptrsize:])
}
data := func(word uint32) []byte {
return t.Data[offset(word):]
}
switch possibleVersion {
case ver118, ver120:
t.nfunctab = uint32(offset(0))
t.nfiletab = uint32(offset(1))
t.textStart = t.PC // use the start PC instead of reading from the table, which may be unrelocated
t.funcnametab = data(3)
t.cutab = data(4)
t.filetab = data(5)
t.pctab = data(6)
t.funcdata = data(7)
t.functab = data(7)
functabsize := (int(t.nfunctab)*2 + 1) * t.functabFieldSize()
t.functab = t.functab[:functabsize]
case ver116:
t.nfunctab = uint32(offset(0))
t.nfiletab = uint32(offset(1))
t.funcnametab = data(2)
t.cutab = data(3)
t.filetab = data(4)
t.pctab = data(5)
t.funcdata = data(6)
t.functab = data(6)
functabsize := (int(t.nfunctab)*2 + 1) * t.functabFieldSize()
t.functab = t.functab[:functabsize]
case ver12:
t.nfunctab = uint32(t.uintptr(t.Data[8:]))
t.funcdata = t.Data
t.funcnametab = t.Data
t.functab = t.Data[8+t.ptrsize:]
t.pctab = t.Data
functabsize := (int(t.nfunctab)*2 + 1) * t.functabFieldSize()
fileoff := t.binary.Uint32(t.functab[functabsize:])
t.functab = t.functab[:functabsize]
t.filetab = t.Data[fileoff:]
t.nfiletab = t.binary.Uint32(t.filetab)
t.filetab = t.filetab[:t.nfiletab*4]
default:
panic("unreachable")
}
}
// go12Funcs returns a slice of Funcs derived from the Go 1.2+ pcln table.
func (t *LineTable) go12Funcs() []Func {
// Assume it is malformed and return nil on error.
if !disableRecover {
defer func() {
recover()
}()
}
ft := t.funcTab()
funcs := make([]Func, ft.Count())
syms := make([]Sym, len(funcs))
for i := range funcs {
f := &funcs[i]
f.Entry = ft.pc(i)
f.End = ft.pc(i + 1)
info := t.funcData(uint32(i))
f.LineTable = t
f.FrameSize = int(info.deferreturn())
syms[i] = Sym{
Value: f.Entry,
Type: 'T',
Name: t.funcName(info.nameOff()),
GoType: 0,
Func: f,
goVersion: t.version,
}
f.Sym = &syms[i]
}
return funcs
}
// findFunc returns the funcData corresponding to the given program counter.
func (t *LineTable) findFunc(pc uint64) funcData {
ft := t.funcTab()
if pc < ft.pc(0) || pc >= ft.pc(ft.Count()) {
return funcData{}
}
idx := sort.Search(int(t.nfunctab), func(i int) bool {
return ft.pc(i) > pc
})
idx--
return t.funcData(uint32(idx))
}
// readvarint reads, removes, and returns a varint from *pp.
func (t *LineTable) readvarint(pp *[]byte) uint32 {
var v, shift uint32
p := *pp
for shift = 0; ; shift += 7 {
b := p[0]
p = p[1:]
v |= (uint32(b) & 0x7F) << shift
if b&0x80 == 0 {
break
}
}
*pp = p
return v
}
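// For illustration (not part of the original source): this is the usual
// little-endian base-128 varint, so the byte sequence {0x96, 0x01} decodes to
// 0x16 | 0x01<<7 = 150, and readvarint consumes both bytes from *pp.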
// funcName returns the name of the function found at off.
func (t *LineTable) funcName(off uint32) string {
if s, ok := t.funcNames[off]; ok {
return s
}
i := bytes.IndexByte(t.funcnametab[off:], 0)
s := string(t.funcnametab[off : off+uint32(i)])
t.funcNames[off] = s
return s
}
// stringFrom returns the NUL-terminated string found at offset off within arr,
// caching the result in t.strings.
func (t *LineTable) stringFrom(arr []byte, off uint32) string {
if s, ok := t.strings[off]; ok {
return s
}
i := bytes.IndexByte(arr[off:], 0)
s := string(arr[off : off+uint32(i)])
t.strings[off] = s
return s
}
// string returns a Go string found at off.
func (t *LineTable) string(off uint32) string {
return t.stringFrom(t.funcdata, off)
}
// functabFieldSize returns the size in bytes of a single functab field.
func (t *LineTable) functabFieldSize() int {
if t.version >= ver118 {
return 4
}
return int(t.ptrsize)
}
// funcTab returns t's funcTab.
func (t *LineTable) funcTab() funcTab {
return funcTab{LineTable: t, sz: t.functabFieldSize()}
}
// funcTab is memory corresponding to a slice of functab structs, followed by an invalid PC.
// A functab struct is a PC and a func offset.
type funcTab struct {
*LineTable
sz int // cached result of t.functabFieldSize
}
// Count returns the number of func entries in f.
func (f funcTab) Count() int {
return int(f.nfunctab)
}
// pc returns the PC of the i'th func in f.
func (f funcTab) pc(i int) uint64 {
u := f.uint(f.functab[2*i*f.sz:])
if f.version >= ver118 {
u += f.textStart
}
return u
}
// funcOff returns the funcdata offset of the i'th func in f.
func (f funcTab) funcOff(i int) uint64 {
return f.uint(f.functab[(2*i+1)*f.sz:])
}
// uint returns the uint stored at b.
func (f funcTab) uint(b []byte) uint64 {
if f.sz == 4 {
return uint64(f.binary.Uint32(b))
}
return f.binary.Uint64(b)
}
// funcData is memory corresponding to an _func struct.
type funcData struct {
t *LineTable // LineTable this data is a part of
data []byte // raw memory for the function
}
// funcData returns the ith funcData in t.functab.
func (t *LineTable) funcData(i uint32) funcData {
data := t.funcdata[t.funcTab().funcOff(int(i)):]
return funcData{t: t, data: data}
}
// IsZero reports whether f is the zero value.
func (f funcData) IsZero() bool {
return f.t == nil && f.data == nil
}
// entryPC returns the func's entry PC.
func (f *funcData) entryPC() uint64 {
// In Go 1.18, the first field of _func changed
// from a uintptr entry PC to a uint32 entry offset.
if f.t.version >= ver118 {
// TODO: support multiple text sections.
// See runtime/symtab.go:(*moduledata).textAddr.
return uint64(f.t.binary.Uint32(f.data)) + f.t.textStart
}
return f.t.uintptr(f.data)
}
func (f funcData) nameOff() uint32 { return f.field(1) }
func (f funcData) deferreturn() uint32 { return f.field(3) }
func (f funcData) pcfile() uint32 { return f.field(5) }
func (f funcData) pcln() uint32 { return f.field(6) }
func (f funcData) cuOffset() uint32 { return f.field(8) }
// field returns the nth field of the _func struct.
// It panics if n == 0 or n > 9; for n == 0, call f.entryPC.
// Most callers should use a named field accessor (just above).
func (f funcData) field(n uint32) uint32 {
if n == 0 || n > 9 {
panic("bad funcdata field")
}
// In Go 1.18, the first field of _func changed
// from a uintptr entry PC to a uint32 entry offset.
sz0 := f.t.ptrsize
if f.t.version >= ver118 {
sz0 = 4
}
off := sz0 + (n-1)*4 // subsequent fields are 4 bytes each
data := f.data[off:]
return f.t.binary.Uint32(data)
}
// step advances to the next pc, value pair in the encoded table.
func (t *LineTable) step(p *[]byte, pc *uint64, val *int32, first bool) bool {
uvdelta := t.readvarint(p)
if uvdelta == 0 && !first {
return false
}
if uvdelta&1 != 0 {
uvdelta = ^(uvdelta >> 1)
} else {
uvdelta >>= 1
}
vdelta := int32(uvdelta)
pcdelta := t.readvarint(p) * t.quantum
*pc += uint64(pcdelta)
*val += vdelta
return true
}
// pcvalue reports the value associated with the target pc.
// off is the offset to the beginning of the pc-value table,
// and entry is the start PC for the corresponding function.
func (t *LineTable) pcvalue(off uint32, entry, targetpc uint64) int32 {
p := t.pctab[off:]
val := int32(-1)
pc := entry
for t.step(&p, &pc, &val, pc == entry) {
if targetpc < pc {
return val
}
}
return -1
}
// findFileLine scans one function in the binary looking for a
// program counter in the given file on the given line.
// It does so by running the pc-value tables mapping program counter
// to file number. Since most functions come from a single file, these
// are usually short and quick to scan. If a file match is found, then the
// code goes to the expense of looking for a simultaneous line number match.
func (t *LineTable) findFileLine(entry uint64, filetab, linetab uint32, filenum, line int32, cutab []byte) uint64 {
if filetab == 0 || linetab == 0 {
return 0
}
fp := t.pctab[filetab:]
fl := t.pctab[linetab:]
fileVal := int32(-1)
filePC := entry
lineVal := int32(-1)
linePC := entry
fileStartPC := filePC
for t.step(&fp, &filePC, &fileVal, filePC == entry) {
fileIndex := fileVal
if t.version == ver116 || t.version == ver118 || t.version == ver120 {
fileIndex = int32(t.binary.Uint32(cutab[fileVal*4:]))
}
if fileIndex == filenum && fileStartPC < filePC {
// fileIndex is in effect starting at fileStartPC up to
// but not including filePC, and it's the file we want.
// Run the PC table looking for a matching line number
// or until we reach filePC.
lineStartPC := linePC
for linePC < filePC && t.step(&fl, &linePC, &lineVal, linePC == entry) {
// lineVal is in effect until linePC, and lineStartPC < filePC.
if lineVal == line {
if fileStartPC <= lineStartPC {
return lineStartPC
}
if fileStartPC < linePC {
return fileStartPC
}
}
lineStartPC = linePC
}
}
fileStartPC = filePC
}
return 0
}
// go12PCToLine maps program counter to line number for the Go 1.2+ pcln table.
func (t *LineTable) go12PCToLine(pc uint64) (line int) {
defer func() {
if !disableRecover && recover() != nil {
line = -1
}
}()
f := t.findFunc(pc)
if f.IsZero() {
return -1
}
entry := f.entryPC()
linetab := f.pcln()
return int(t.pcvalue(linetab, entry, pc))
}
// go12PCToFile maps program counter to file name for the Go 1.2+ pcln table.
func (t *LineTable) go12PCToFile(pc uint64) (file string) {
defer func() {
if !disableRecover && recover() != nil {
file = ""
}
}()
f := t.findFunc(pc)
if f.IsZero() {
return ""
}
entry := f.entryPC()
filetab := f.pcfile()
fno := t.pcvalue(filetab, entry, pc)
if t.version == ver12 {
if fno <= 0 {
return ""
}
return t.string(t.binary.Uint32(t.filetab[4*fno:]))
}
// Go ≥ 1.16
if fno < 0 { // 0 is valid for ≥ 1.16
return ""
}
cuoff := f.cuOffset()
if fnoff := t.binary.Uint32(t.cutab[(cuoff+uint32(fno))*4:]); fnoff != ^uint32(0) {
return t.stringFrom(t.filetab, fnoff)
}
return ""
}
// go12LineToPC maps a (file, line) pair to a program counter for the Go 1.2+ pcln table.
func (t *LineTable) go12LineToPC(file string, line int) (pc uint64) {
defer func() {
if !disableRecover && recover() != nil {
pc = 0
}
}()
t.initFileMap()
filenum, ok := t.fileMap[file]
if !ok {
return 0
}
// Scan all functions.
// If this turns out to be a bottleneck, we could build a map[int32][]int32
// mapping file number to a list of functions with code from that file.
var cutab []byte
for i := uint32(0); i < t.nfunctab; i++ {
f := t.funcData(i)
entry := f.entryPC()
filetab := f.pcfile()
linetab := f.pcln()
if t.version == ver116 || t.version == ver118 || t.version == ver120 {
if f.cuOffset() == ^uint32(0) {
// Skip functions without a compilation unit (not a real function, or linker-generated).
continue
}
cutab = t.cutab[f.cuOffset()*4:]
}
pc := t.findFileLine(entry, filetab, linetab, int32(filenum), int32(line), cutab)
if pc != 0 {
return pc
}
}
return 0
}
// initFileMap initializes the map from file name to file number.
func (t *LineTable) initFileMap() {
t.mu.Lock()
defer t.mu.Unlock()
if t.fileMap != nil {
return
}
m := make(map[string]uint32)
if t.version == ver12 {
for i := uint32(1); i < t.nfiletab; i++ {
s := t.string(t.binary.Uint32(t.filetab[4*i:]))
m[s] = i
}
} else {
var pos uint32
for i := uint32(0); i < t.nfiletab; i++ {
s := t.stringFrom(t.filetab, pos)
m[s] = pos
pos += uint32(len(s) + 1)
}
}
t.fileMap = m
}
// go12MapFiles adds to m a key for every file in the Go 1.2 LineTable.
// Every key maps to obj. That's not a very interesting map, but it provides
// a way for callers to obtain the list of files in the program.
func (t *LineTable) go12MapFiles(m map[string]*Obj, obj *Obj) {
if !disableRecover {
defer func() {
recover()
}()
}
t.initFileMap()
for file := range t.fileMap {
m[file] = obj
}
}
// disableRecover causes this package not to swallow panics.
// This is useful when making changes.
const disableRecover = false
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package gosym implements access to the Go symbol
// and line number tables embedded in Go binaries generated
// by the gc compilers.
package gosym
import (
"bytes"
"encoding/binary"
"fmt"
"strconv"
"strings"
)
/*
* Symbols
*/
// A Sym represents a single symbol table entry.
type Sym struct {
Value uint64
Type byte
Name string
GoType uint64
// If this symbol is a function symbol, the corresponding Func
Func *Func
goVersion version
}
// Static reports whether this symbol is static (not visible outside its file).
func (s *Sym) Static() bool { return s.Type >= 'a' }
// nameWithoutInst returns s.Name if s.Name has no brackets (does not reference an
// instantiated type, function, or method). If s.Name contains brackets, then it
// returns s.Name with all the contents between (and including) the outermost left
// and right bracket removed. This is useful to ignore any extra slashes or dots
// inside the brackets from the string searches below, where needed.
func (s *Sym) nameWithoutInst() string {
start := strings.Index(s.Name, "[")
if start < 0 {
return s.Name
}
end := strings.LastIndex(s.Name, "]")
if end < 0 {
// Malformed name, should contain closing bracket too.
return s.Name
}
return s.Name[0:start] + s.Name[end+1:]
}
// PackageName returns the package part of the symbol name,
// or the empty string if there is none.
func (s *Sym) PackageName() string {
name := s.nameWithoutInst()
// Since go1.20, a prefix of "type:" and "go:" is a compiler-generated symbol,
// they do not belong to any package.
//
// See cmd/compile/internal/base/link.go:ReservedImports variable.
if s.goVersion >= ver120 && (strings.HasPrefix(name, "go:") || strings.HasPrefix(name, "type:")) {
return ""
}
// For go1.18 and below, the prefix are "type." and "go." instead.
if s.goVersion <= ver118 && (strings.HasPrefix(name, "go.") || strings.HasPrefix(name, "type.")) {
return ""
}
pathend := strings.LastIndex(name, "/")
if pathend < 0 {
pathend = 0
}
if i := strings.Index(name[pathend:], "."); i != -1 {
return name[:pathend+i]
}
return ""
}
// ReceiverName returns the receiver type name of this symbol,
// or the empty string if there is none. A receiver name is only detected in
// the case that s.Name is fully-specified with a package name.
func (s *Sym) ReceiverName() string {
name := s.nameWithoutInst()
// If we find a slash in name, it should precede any bracketed expression
// that was removed, so pathend will apply correctly to name and s.Name.
pathend := strings.LastIndex(name, "/")
if pathend < 0 {
pathend = 0
}
// Find the first dot after pathend (or from the beginning, if there was
// no slash in name).
l := strings.Index(name[pathend:], ".")
// Find the last dot after pathend (or the beginning).
r := strings.LastIndex(name[pathend:], ".")
if l == -1 || r == -1 || l == r {
// There is no receiver if we didn't find two distinct dots after pathend.
return ""
}
// Given there is a trailing '.' that is in name, find it now in s.Name.
// pathend+l should apply to s.Name, because it should be the dot in the
// package name.
r = strings.LastIndex(s.Name[pathend:], ".")
return s.Name[pathend+l+1 : pathend+r]
}
// BaseName returns the symbol name without the package or receiver name.
func (s *Sym) BaseName() string {
name := s.nameWithoutInst()
if i := strings.LastIndex(name, "."); i != -1 {
if s.Name != name {
brack := strings.Index(s.Name, "[")
if i > brack {
// BaseName is a method name after the brackets, so
// recalculate for s.Name. Otherwise, i applies
// correctly to s.Name, since it is before the
// brackets.
i = strings.LastIndex(s.Name, ".")
}
}
return s.Name[i+1:]
}
return s.Name
}
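// Illustration (not from the original source): for a symbol named
// "encoding/json.(*Decoder).Decode", PackageName returns "encoding/json",
// ReceiverName returns "(*Decoder)", and BaseName returns "Decode".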
// A Func collects information about a single function.
type Func struct {
Entry uint64
*Sym
End uint64
Params []*Sym // nil for Go 1.3 and later binaries
Locals []*Sym // nil for Go 1.3 and later binaries
FrameSize int
LineTable *LineTable
Obj *Obj
}
// An Obj represents a collection of functions in a symbol table.
//
// The exact method of division of a binary into separate Objs is an internal detail
// of the symbol table format.
//
// In early versions of Go each source file became a different Obj.
//
// In Go 1 and Go 1.1, each package produced one Obj for all Go sources
// and one Obj per C source file.
//
// In Go 1.2, there is a single Obj for the entire program.
type Obj struct {
// Funcs is a list of functions in the Obj.
Funcs []Func
// In Go 1.1 and earlier, Paths is a list of symbols corresponding
// to the source file names that produced the Obj.
// In Go 1.2, Paths is nil.
// Use the keys of Table.Files to obtain a list of source files.
Paths []Sym // meta
}
/*
* Symbol tables
*/
// Table represents a Go symbol table. It stores all of the
// symbols decoded from the program and provides methods to translate
// between symbols, names, and addresses.
type Table struct {
Syms []Sym // nil for Go 1.3 and later binaries
Funcs []Func
Files map[string]*Obj // for Go 1.2 and later all files map to one Obj
Objs []Obj // for Go 1.2 and later only one Obj in slice
go12line *LineTable // Go 1.2 line number table
}
type sym struct {
value uint64
gotype uint64
typ byte
name []byte
}
var (
littleEndianSymtab = []byte{0xFD, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00}
bigEndianSymtab = []byte{0xFF, 0xFF, 0xFF, 0xFD, 0x00, 0x00, 0x00}
oldLittleEndianSymtab = []byte{0xFE, 0xFF, 0xFF, 0xFF, 0x00, 0x00}
)
func walksymtab(data []byte, fn func(sym) error) error {
if len(data) == 0 { // missing symtab is okay
return nil
}
var order binary.ByteOrder = binary.BigEndian
newTable := false
switch {
case bytes.HasPrefix(data, oldLittleEndianSymtab):
// Same as Go 1.0, but little endian.
// Format was used during interim development between Go 1.0 and Go 1.1.
// Should not be widespread, but easy to support.
data = data[6:]
order = binary.LittleEndian
case bytes.HasPrefix(data, bigEndianSymtab):
newTable = true
case bytes.HasPrefix(data, littleEndianSymtab):
newTable = true
order = binary.LittleEndian
}
var ptrsz int
if newTable {
if len(data) < 8 {
return &DecodingError{len(data), "unexpected EOF", nil}
}
ptrsz = int(data[7])
if ptrsz != 4 && ptrsz != 8 {
return &DecodingError{7, "invalid pointer size", ptrsz}
}
data = data[8:]
}
var s sym
p := data
for len(p) >= 4 {
var typ byte
if newTable {
// Symbol type, value, Go type.
typ = p[0] & 0x3F
wideValue := p[0]&0x40 != 0
goType := p[0]&0x80 != 0
if typ < 26 {
typ += 'A'
} else {
typ += 'a' - 26
}
s.typ = typ
p = p[1:]
if wideValue {
if len(p) < ptrsz {
return &DecodingError{len(data), "unexpected EOF", nil}
}
// fixed-width value
if ptrsz == 8 {
s.value = order.Uint64(p[0:8])
p = p[8:]
} else {
s.value = uint64(order.Uint32(p[0:4]))
p = p[4:]
}
} else {
// varint value
s.value = 0
shift := uint(0)
for len(p) > 0 && p[0]&0x80 != 0 {
s.value |= uint64(p[0]&0x7F) << shift
shift += 7
p = p[1:]
}
if len(p) == 0 {
return &DecodingError{len(data), "unexpected EOF", nil}
}
s.value |= uint64(p[0]) << shift
p = p[1:]
}
if goType {
if len(p) < ptrsz {
return &DecodingError{len(data), "unexpected EOF", nil}
}
// fixed-width go type
if ptrsz == 8 {
s.gotype = order.Uint64(p[0:8])
p = p[8:]
} else {
s.gotype = uint64(order.Uint32(p[0:4]))
p = p[4:]
}
}
} else {
// Value, symbol type.
s.value = uint64(order.Uint32(p[0:4]))
if len(p) < 5 {
return &DecodingError{len(data), "unexpected EOF", nil}
}
typ = p[4]
if typ&0x80 == 0 {
return &DecodingError{len(data) - len(p) + 4, "bad symbol type", typ}
}
typ &^= 0x80
s.typ = typ
p = p[5:]
}
// Name.
var i int
var nnul int
for i = 0; i < len(p); i++ {
if p[i] == 0 {
nnul = 1
break
}
}
switch typ {
case 'z', 'Z':
p = p[i+nnul:]
for i = 0; i+2 <= len(p); i += 2 {
if p[i] == 0 && p[i+1] == 0 {
nnul = 2
break
}
}
}
if len(p) < i+nnul {
return &DecodingError{len(data), "unexpected EOF", nil}
}
s.name = p[0:i]
i += nnul
p = p[i:]
if !newTable {
if len(p) < 4 {
return &DecodingError{len(data), "unexpected EOF", nil}
}
// Go type.
s.gotype = uint64(order.Uint32(p[:4]))
p = p[4:]
}
if err := fn(s); err != nil {
return err
}
}
return nil
}
// NewTable decodes the Go symbol table (the ".gosymtab" section in ELF),
// returning an in-memory representation.
// Starting with Go 1.3, the Go symbol table no longer includes symbol data.
func NewTable(symtab []byte, pcln *LineTable) (*Table, error) {
var n int
err := walksymtab(symtab, func(s sym) error {
n++
return nil
})
if err != nil {
return nil, err
}
var t Table
if pcln.isGo12() {
t.go12line = pcln
}
fname := make(map[uint16]string)
t.Syms = make([]Sym, 0, n)
nf := 0
nz := 0
lasttyp := uint8(0)
err = walksymtab(symtab, func(s sym) error {
n := len(t.Syms)
t.Syms = t.Syms[0 : n+1]
ts := &t.Syms[n]
ts.Type = s.typ
ts.Value = s.value
ts.GoType = s.gotype
ts.goVersion = pcln.version
switch s.typ {
default:
// rewrite name to use . instead of · (c2 b7)
w := 0
b := s.name
for i := 0; i < len(b); i++ {
if b[i] == 0xc2 && i+1 < len(b) && b[i+1] == 0xb7 {
i++
b[i] = '.'
}
b[w] = b[i]
w++
}
ts.Name = string(s.name[0:w])
case 'z', 'Z':
if lasttyp != 'z' && lasttyp != 'Z' {
nz++
}
for i := 0; i < len(s.name); i += 2 {
eltIdx := binary.BigEndian.Uint16(s.name[i : i+2])
elt, ok := fname[eltIdx]
if !ok {
return &DecodingError{-1, "bad filename code", eltIdx}
}
if n := len(ts.Name); n > 0 && ts.Name[n-1] != '/' {
ts.Name += "/"
}
ts.Name += elt
}
}
switch s.typ {
case 'T', 't', 'L', 'l':
nf++
case 'f':
fname[uint16(s.value)] = ts.Name
}
lasttyp = s.typ
return nil
})
if err != nil {
return nil, err
}
t.Funcs = make([]Func, 0, nf)
t.Files = make(map[string]*Obj)
var obj *Obj
if t.go12line != nil {
// Put all functions into one Obj.
t.Objs = make([]Obj, 1)
obj = &t.Objs[0]
t.go12line.go12MapFiles(t.Files, obj)
} else {
t.Objs = make([]Obj, 0, nz)
}
// Count text symbols and attach frame sizes, parameters, and
// locals to them. Also, find object file boundaries.
lastf := 0
for i := 0; i < len(t.Syms); i++ {
sym := &t.Syms[i]
switch sym.Type {
case 'Z', 'z': // path symbol
if t.go12line != nil {
// Go 1.2 binaries have the file information elsewhere. Ignore.
break
}
// Finish the current object
if obj != nil {
obj.Funcs = t.Funcs[lastf:]
}
lastf = len(t.Funcs)
// Start new object
n := len(t.Objs)
t.Objs = t.Objs[0 : n+1]
obj = &t.Objs[n]
// Count & copy path symbols
var end int
for end = i + 1; end < len(t.Syms); end++ {
if c := t.Syms[end].Type; c != 'Z' && c != 'z' {
break
}
}
obj.Paths = t.Syms[i:end]
i = end - 1 // loop will i++
// Record file names
depth := 0
for j := range obj.Paths {
s := &obj.Paths[j]
if s.Name == "" {
depth--
} else {
if depth == 0 {
t.Files[s.Name] = obj
}
depth++
}
}
case 'T', 't', 'L', 'l': // text symbol
if n := len(t.Funcs); n > 0 {
t.Funcs[n-1].End = sym.Value
}
if sym.Name == "runtime.etext" || sym.Name == "etext" {
continue
}
// Count parameter and local (auto) syms
var np, na int
var end int
countloop:
for end = i + 1; end < len(t.Syms); end++ {
switch t.Syms[end].Type {
case 'T', 't', 'L', 'l', 'Z', 'z':
break countloop
case 'p':
np++
case 'a':
na++
}
}
// Fill in the function symbol
n := len(t.Funcs)
t.Funcs = t.Funcs[0 : n+1]
fn := &t.Funcs[n]
sym.Func = fn
fn.Params = make([]*Sym, 0, np)
fn.Locals = make([]*Sym, 0, na)
fn.Sym = sym
fn.Entry = sym.Value
fn.Obj = obj
if t.go12line != nil {
// All functions share the same line table.
// It knows how to narrow down to a specific
// function quickly.
fn.LineTable = t.go12line
} else if pcln != nil {
fn.LineTable = pcln.slice(fn.Entry)
pcln = fn.LineTable
}
for j := i; j < end; j++ {
s := &t.Syms[j]
switch s.Type {
case 'm':
fn.FrameSize = int(s.Value)
case 'p':
n := len(fn.Params)
fn.Params = fn.Params[0 : n+1]
fn.Params[n] = s
case 'a':
n := len(fn.Locals)
fn.Locals = fn.Locals[0 : n+1]
fn.Locals[n] = s
}
}
i = end - 1 // loop will i++
}
}
if t.go12line != nil && nf == 0 {
t.Funcs = t.go12line.go12Funcs()
}
if obj != nil {
obj.Funcs = t.Funcs[lastf:]
}
return &t, nil
}
// PCToFunc returns the function containing the program counter pc,
// or nil if there is no such function.
func (t *Table) PCToFunc(pc uint64) *Func {
funcs := t.Funcs
for len(funcs) > 0 {
m := len(funcs) / 2
fn := &funcs[m]
switch {
case pc < fn.Entry:
funcs = funcs[0:m]
case fn.Entry <= pc && pc < fn.End:
return fn
default:
funcs = funcs[m+1:]
}
}
return nil
}
// PCToLine looks up line number information for a program counter.
// If there is no information, it returns fn == nil.
func (t *Table) PCToLine(pc uint64) (file string, line int, fn *Func) {
if fn = t.PCToFunc(pc); fn == nil {
return
}
if t.go12line != nil {
file = t.go12line.go12PCToFile(pc)
line = t.go12line.go12PCToLine(pc)
} else {
file, line = fn.Obj.lineFromAline(fn.LineTable.PCToLine(pc))
}
return
}
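// A minimal usage sketch (not part of this package): given the raw .gosymtab
// and .gopclntab contents of a binary plus the start address of its text
// segment, resolve a program counter to a source location. The symtab,
// pclntab, textStart, and pc values are assumed to be supplied by the caller.
//
//	lt := gosym.NewLineTable(pclntab, textStart)
//	tab, err := gosym.NewTable(symtab, lt)
//	if err != nil {
//		log.Fatal(err)
//	}
//	file, line, fn := tab.PCToLine(pc)
//	if fn != nil {
//		fmt.Printf("%#x -> %s:%d in %s\n", pc, file, line, fn.Name)
//	}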
// LineToPC looks up the first program counter on the given line in
// the named file. It returns [UnknownFileError] or [UnknownLineError] if
// there is an error looking up this line.
func (t *Table) LineToPC(file string, line int) (pc uint64, fn *Func, err error) {
obj, ok := t.Files[file]
if !ok {
return 0, nil, UnknownFileError(file)
}
if t.go12line != nil {
pc := t.go12line.go12LineToPC(file, line)
if pc == 0 {
return 0, nil, &UnknownLineError{file, line}
}
return pc, t.PCToFunc(pc), nil
}
abs, err := obj.alineFromLine(file, line)
if err != nil {
return
}
for i := range obj.Funcs {
f := &obj.Funcs[i]
pc := f.LineTable.LineToPC(abs, f.End)
if pc != 0 {
return pc, f, nil
}
}
return 0, nil, &UnknownLineError{file, line}
}
// LookupSym returns the text, data, or bss symbol with the given name,
// or nil if no such symbol is found.
func (t *Table) LookupSym(name string) *Sym {
// TODO(austin) Maybe make a map
for i := range t.Syms {
s := &t.Syms[i]
switch s.Type {
case 'T', 't', 'L', 'l', 'D', 'd', 'B', 'b':
if s.Name == name {
return s
}
}
}
return nil
}
// LookupFunc returns the function with the given name,
// or nil if no such function is found.
func (t *Table) LookupFunc(name string) *Func {
for i := range t.Funcs {
f := &t.Funcs[i]
if f.Sym.Name == name {
return f
}
}
return nil
}
// SymByAddr returns the text, data, or bss symbol starting at the given address.
func (t *Table) SymByAddr(addr uint64) *Sym {
for i := range t.Syms {
s := &t.Syms[i]
switch s.Type {
case 'T', 't', 'L', 'l', 'D', 'd', 'B', 'b':
if s.Value == addr {
return s
}
}
}
return nil
}
/*
* Object files
*/
// This is legacy code for Go 1.1 and earlier, which used the
// Plan 9 format for pc-line tables. This code was never quite
// correct. It's probably very close, and it's usually correct, but
// we never quite found all the corner cases.
//
// Go 1.2 and later use a simpler format, documented at golang.org/s/go12symtab.
func (o *Obj) lineFromAline(aline int) (string, int) {
type stackEnt struct {
path string
start int
offset int
prev *stackEnt
}
noPath := &stackEnt{"", 0, 0, nil}
tos := noPath
pathloop:
for _, s := range o.Paths {
val := int(s.Value)
switch {
case val > aline:
break pathloop
case val == 1:
// Start a new stack
tos = &stackEnt{s.Name, val, 0, noPath}
case s.Name == "":
// Pop
if tos == noPath {
return "<malformed symbol table>", 0
}
tos.prev.offset += val - tos.start
tos = tos.prev
default:
// Push
tos = &stackEnt{s.Name, val, 0, tos}
}
}
if tos == noPath {
return "", 0
}
return tos.path, aline - tos.start - tos.offset + 1
}
func (o *Obj) alineFromLine(path string, line int) (int, error) {
if line < 1 {
return 0, &UnknownLineError{path, line}
}
for i, s := range o.Paths {
// Find this path
if s.Name != path {
continue
}
// Find this line at this stack level
depth := 0
var incstart int
line += int(s.Value)
pathloop:
for _, s := range o.Paths[i:] {
val := int(s.Value)
switch {
case depth == 1 && val >= line:
return line - 1, nil
case s.Name == "":
depth--
if depth == 0 {
break pathloop
} else if depth == 1 {
line += val - incstart
}
default:
if depth == 1 {
incstart = val
}
depth++
}
}
return 0, &UnknownLineError{path, line}
}
return 0, UnknownFileError(path)
}
/*
* Errors
*/
// UnknownFileError represents a failure to find the specific file in
// the symbol table.
type UnknownFileError string
func (e UnknownFileError) Error() string { return "unknown file: " + string(e) }
// UnknownLineError represents a failure to map a line to a program
// counter, either because the line is beyond the bounds of the file
// or because there is no code on the given line.
type UnknownLineError struct {
File string
Line int
}
func (e *UnknownLineError) Error() string {
return "no code at " + e.File + ":" + strconv.Itoa(e.Line)
}
// DecodingError represents an error during the decoding of
// the symbol table.
type DecodingError struct {
off int
msg string
val any
}
func (e *DecodingError) Error() string {
msg := e.msg
if e.val != nil {
msg += fmt.Sprintf(" '%v'", e.val)
}
msg += fmt.Sprintf(" at byte %#x", e.off)
return msg
}
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package macho
import (
"encoding/binary"
"fmt"
"internal/saferio"
"io"
"os"
)
// A FatFile is a Mach-O universal binary that contains at least one architecture.
type FatFile struct {
Magic uint32
Arches []FatArch
closer io.Closer
}
// A FatArchHeader represents a fat header for a specific image architecture.
type FatArchHeader struct {
Cpu Cpu
SubCpu uint32
Offset uint32
Size uint32
Align uint32
}
const fatArchHeaderSize = 5 * 4
// A FatArch is a Mach-O File inside a FatFile.
type FatArch struct {
FatArchHeader
*File
}
// ErrNotFat is returned from [NewFatFile] or [OpenFat] when the file is not a
// universal binary but may be a thin binary, based on its magic number.
var ErrNotFat = &FormatError{0, "not a fat Mach-O file", nil}
// NewFatFile creates a new [FatFile] for accessing all the Mach-O images in a
// universal binary. The Mach-O binary is expected to start at position 0 in
// the ReaderAt.
func NewFatFile(r io.ReaderAt) (*FatFile, error) {
var ff FatFile
sr := io.NewSectionReader(r, 0, 1<<63-1)
// Read the fat_header struct, which is always in big endian.
// Start with the magic number.
err := binary.Read(sr, binary.BigEndian, &ff.Magic)
if err != nil {
return nil, &FormatError{0, "error reading magic number", nil}
} else if ff.Magic != MagicFat {
// See if this is a Mach-O file via its magic number. The magic
// must be converted to little endian first though.
var buf [4]byte
binary.BigEndian.PutUint32(buf[:], ff.Magic)
leMagic := binary.LittleEndian.Uint32(buf[:])
if leMagic == Magic32 || leMagic == Magic64 {
return nil, ErrNotFat
} else {
return nil, &FormatError{0, "invalid magic number", nil}
}
}
offset := int64(4)
// Read the number of FatArchHeaders that come after the fat_header.
var narch uint32
err = binary.Read(sr, binary.BigEndian, &narch)
if err != nil {
return nil, &FormatError{offset, "invalid fat_header", nil}
}
offset += 4
if narch < 1 {
return nil, &FormatError{offset, "file contains no images", nil}
}
// Combine the Cpu and SubCpu (both uint32) into a uint64 to make sure
// there are no duplicate architectures.
seenArches := make(map[uint64]bool)
// Make sure that all images are for the same MH_ type.
var machoType Type
// Following the fat_header come narch fat_arch structs that index
// Mach-O images further in the file.
c := saferio.SliceCap[FatArch](uint64(narch))
if c < 0 {
return nil, &FormatError{offset, "too many images", nil}
}
ff.Arches = make([]FatArch, 0, c)
for i := uint32(0); i < narch; i++ {
var fa FatArch
err = binary.Read(sr, binary.BigEndian, &fa.FatArchHeader)
if err != nil {
return nil, &FormatError{offset, "invalid fat_arch header", nil}
}
offset += fatArchHeaderSize
fr := io.NewSectionReader(r, int64(fa.Offset), int64(fa.Size))
fa.File, err = NewFile(fr)
if err != nil {
return nil, err
}
// Make sure the architecture for this image is not a duplicate.
seenArch := (uint64(fa.Cpu) << 32) | uint64(fa.SubCpu)
if o, k := seenArches[seenArch]; o || k {
return nil, &FormatError{offset, fmt.Sprintf("duplicate architecture cpu=%v, subcpu=%#x", fa.Cpu, fa.SubCpu), nil}
}
seenArches[seenArch] = true
// Make sure the Mach-O type matches that of the first image.
if i == 0 {
machoType = fa.Type
} else {
if fa.Type != machoType {
return nil, &FormatError{offset, fmt.Sprintf("Mach-O type for architecture #%d (type=%#x) does not match first (type=%#x)", i, fa.Type, machoType), nil}
}
}
ff.Arches = append(ff.Arches, fa)
}
return &ff, nil
}
// OpenFat opens the named file using [os.Open] and prepares it for use as a Mach-O
// universal binary.
func OpenFat(name string) (*FatFile, error) {
f, err := os.Open(name)
if err != nil {
return nil, err
}
ff, err := NewFatFile(f)
if err != nil {
f.Close()
return nil, err
}
ff.closer = f
return ff, nil
}
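// A minimal usage sketch, assuming "a.out" is a universal (fat) binary on disk:
//
//	ff, err := macho.OpenFat("a.out")
//	if err != nil {
//		log.Fatal(err)
//	}
//	defer ff.Close()
//	for _, arch := range ff.Arches {
//		fmt.Printf("cpu=%v type=%v offset=%#x size=%d\n",
//			arch.Cpu, arch.Type, arch.Offset, arch.Size)
//	}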
func (ff *FatFile) Close() error {
var err error
if ff.closer != nil {
err = ff.closer.Close()
ff.closer = nil
}
return err
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
/*
Package macho implements access to Mach-O object files.
# Security
This package is not designed to be hardened against adversarial inputs, and is
outside the scope of https://go.dev/security/policy. In particular, only basic
validation is done when parsing object files. As such, care should be taken when
parsing untrusted inputs, as parsing malformed files may consume significant
resources, or cause panics.
*/
package macho
// High level access to low level data structures.
import (
"bytes"
"compress/zlib"
"debug/dwarf"
"encoding/binary"
"fmt"
"internal/saferio"
"io"
"os"
"strings"
)
// A File represents an open Mach-O file.
type File struct {
FileHeader
ByteOrder binary.ByteOrder
Loads []Load
Sections []*Section
Symtab *Symtab
Dysymtab *Dysymtab
closer io.Closer
}
// A Load represents any Mach-O load command.
type Load interface {
Raw() []byte
}
// A LoadBytes is the uninterpreted bytes of a Mach-O load command.
type LoadBytes []byte
func (b LoadBytes) Raw() []byte { return b }
// A SegmentHeader is the header for a Mach-O 32-bit or 64-bit load segment command.
type SegmentHeader struct {
Cmd LoadCmd
Len uint32
Name string
Addr uint64
Memsz uint64
Offset uint64
Filesz uint64
Maxprot uint32
Prot uint32
Nsect uint32
Flag uint32
}
// A Segment represents a Mach-O 32-bit or 64-bit load segment command.
type Segment struct {
LoadBytes
SegmentHeader
// Embed ReaderAt for ReadAt method.
// Do not embed SectionReader directly
// to avoid having Read and Seek.
// If a client wants Read and Seek it must use
// Open() to avoid fighting over the seek offset
// with other clients.
io.ReaderAt
sr *io.SectionReader
}
// Data reads and returns the contents of the segment.
func (s *Segment) Data() ([]byte, error) {
return saferio.ReadDataAt(s.sr, s.Filesz, 0)
}
// Open returns a new ReadSeeker reading the segment.
func (s *Segment) Open() io.ReadSeeker { return io.NewSectionReader(s.sr, 0, 1<<63-1) }
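// A sketch illustrating the embedding note above: ReadAt carries no shared seek
// offset, while each Open call returns an independent ReadSeeker, so concurrent
// readers never fight over position. The *macho.File f and the segment name are
// assumed to exist in the binary being inspected.
//
//	seg := f.Segment("__TEXT")
//	head := make([]byte, 16)
//	if _, err := seg.ReadAt(head, 0); err != nil {
//		log.Fatal(err)
//	}
//	rs := seg.Open() // private offset for this caller only
//	if _, err := rs.Seek(0x100, io.SeekStart); err != nil {
//		log.Fatal(err)
//	}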
type SectionHeader struct {
Name string
Seg string
Addr uint64
Size uint64
Offset uint32
Align uint32
Reloff uint32
Nreloc uint32
Flags uint32
}
// A Reloc represents a Mach-O relocation.
type Reloc struct {
Addr uint32
Value uint32
// when Scattered == false && Extern == true, Value is the symbol number.
// when Scattered == false && Extern == false, Value is the section number.
// when Scattered == true, Value is the value that this reloc refers to.
Type uint8
Len uint8 // 0=byte, 1=word, 2=long, 3=quad
Pcrel bool
Extern bool // valid if Scattered == false
Scattered bool
}
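// A sketch of how a caller might interpret Value according to the rules above
// (assumes f is a *File whose symbol table has been parsed; Mach-O section
// ordinals are 1-based):
//
//	func relocTarget(f *File, r Reloc) string {
//		switch {
//		case r.Scattered:
//			return fmt.Sprintf("value %#x", r.Value)
//		case r.Extern:
//			return f.Symtab.Syms[r.Value].Name // Value is a symbol number
//		default:
//			return f.Sections[r.Value-1].Name // Value is a section number
//		}
//	}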
type Section struct {
SectionHeader
Relocs []Reloc
// Embed ReaderAt for ReadAt method.
// Do not embed SectionReader directly
// to avoid having Read and Seek.
// If a client wants Read and Seek it must use
// Open() to avoid fighting over the seek offset
// with other clients.
io.ReaderAt
sr *io.SectionReader
}
// Data reads and returns the contents of the Mach-O section.
func (s *Section) Data() ([]byte, error) {
return saferio.ReadDataAt(s.sr, s.Size, 0)
}
// Open returns a new ReadSeeker reading the Mach-O section.
func (s *Section) Open() io.ReadSeeker { return io.NewSectionReader(s.sr, 0, 1<<63-1) }
// A Dylib represents a Mach-O load dynamic library command.
type Dylib struct {
LoadBytes
Name string
Time uint32
CurrentVersion uint32
CompatVersion uint32
}
// A Symtab represents a Mach-O symbol table command.
type Symtab struct {
LoadBytes
SymtabCmd
Syms []Symbol
}
// A Dysymtab represents a Mach-O dynamic symbol table command.
type Dysymtab struct {
LoadBytes
DysymtabCmd
IndirectSyms []uint32 // indices into Symtab.Syms
}
// A Rpath represents a Mach-O rpath command.
type Rpath struct {
LoadBytes
Path string
}
// A Symbol is a Mach-O 32-bit or 64-bit symbol table entry.
type Symbol struct {
Name string
Type uint8
Sect uint8
Desc uint16
Value uint64
}
/*
* Mach-O reader
*/
// FormatError is returned by some operations if the data does
// not have the correct format for an object file.
type FormatError struct {
off int64
msg string
val any
}
func (e *FormatError) Error() string {
msg := e.msg
if e.val != nil {
msg += fmt.Sprintf(" '%v'", e.val)
}
msg += fmt.Sprintf(" in record at byte %#x", e.off)
return msg
}
// Open opens the named file using [os.Open] and prepares it for use as a Mach-O binary.
func Open(name string) (*File, error) {
f, err := os.Open(name)
if err != nil {
return nil, err
}
ff, err := NewFile(f)
if err != nil {
f.Close()
return nil, err
}
ff.closer = f
return ff, nil
}
// Close closes the [File].
// If the [File] was created using [NewFile] directly instead of [Open],
// Close has no effect.
func (f *File) Close() error {
var err error
if f.closer != nil {
err = f.closer.Close()
f.closer = nil
}
return err
}
// NewFile creates a new [File] for accessing a Mach-O binary in an underlying reader.
// The Mach-O binary is expected to start at position 0 in the ReaderAt.
func NewFile(r io.ReaderAt) (*File, error) {
f := new(File)
sr := io.NewSectionReader(r, 0, 1<<63-1)
// Read and decode Mach magic to determine byte order, size.
// Magic32 and Magic64 differ only in the bottom bit.
var ident [4]byte
if _, err := r.ReadAt(ident[0:], 0); err != nil {
return nil, err
}
be := binary.BigEndian.Uint32(ident[0:])
le := binary.LittleEndian.Uint32(ident[0:])
switch Magic32 &^ 1 {
case be &^ 1:
f.ByteOrder = binary.BigEndian
f.Magic = be
case le &^ 1:
f.ByteOrder = binary.LittleEndian
f.Magic = le
default:
return nil, &FormatError{0, "invalid magic number", nil}
}
// Read entire file header.
if err := binary.Read(sr, f.ByteOrder, &f.FileHeader); err != nil {
return nil, err
}
// Then load commands.
offset := int64(fileHeaderSize32)
if f.Magic == Magic64 {
offset = fileHeaderSize64
}
dat, err := saferio.ReadDataAt(r, uint64(f.Cmdsz), offset)
if err != nil {
return nil, err
}
c := saferio.SliceCap[Load](uint64(f.Ncmd))
if c < 0 {
return nil, &FormatError{offset, "too many load commands", nil}
}
f.Loads = make([]Load, 0, c)
bo := f.ByteOrder
for i := uint32(0); i < f.Ncmd; i++ {
// Each load command begins with uint32 command and length.
if len(dat) < 8 {
return nil, &FormatError{offset, "command block too small", nil}
}
cmd, siz := LoadCmd(bo.Uint32(dat[0:4])), bo.Uint32(dat[4:8])
if siz < 8 || siz > uint32(len(dat)) {
return nil, &FormatError{offset, "invalid command block size", nil}
}
var cmddat []byte
cmddat, dat = dat[0:siz], dat[siz:]
offset += int64(siz)
var s *Segment
switch cmd {
default:
f.Loads = append(f.Loads, LoadBytes(cmddat))
case LoadCmdRpath:
var hdr RpathCmd
b := bytes.NewReader(cmddat)
if err := binary.Read(b, bo, &hdr); err != nil {
return nil, err
}
l := new(Rpath)
if hdr.Path >= uint32(len(cmddat)) {
return nil, &FormatError{offset, "invalid path in rpath command", hdr.Path}
}
l.Path = cstring(cmddat[hdr.Path:])
l.LoadBytes = LoadBytes(cmddat)
f.Loads = append(f.Loads, l)
case LoadCmdDylib:
var hdr DylibCmd
b := bytes.NewReader(cmddat)
if err := binary.Read(b, bo, &hdr); err != nil {
return nil, err
}
l := new(Dylib)
if hdr.Name >= uint32(len(cmddat)) {
return nil, &FormatError{offset, "invalid name in dynamic library command", hdr.Name}
}
l.Name = cstring(cmddat[hdr.Name:])
l.Time = hdr.Time
l.CurrentVersion = hdr.CurrentVersion
l.CompatVersion = hdr.CompatVersion
l.LoadBytes = LoadBytes(cmddat)
f.Loads = append(f.Loads, l)
case LoadCmdSymtab:
var hdr SymtabCmd
b := bytes.NewReader(cmddat)
if err := binary.Read(b, bo, &hdr); err != nil {
return nil, err
}
strtab, err := saferio.ReadDataAt(r, uint64(hdr.Strsize), int64(hdr.Stroff))
if err != nil {
return nil, err
}
var symsz int
if f.Magic == Magic64 {
symsz = 16
} else {
symsz = 12
}
symdat, err := saferio.ReadDataAt(r, uint64(hdr.Nsyms)*uint64(symsz), int64(hdr.Symoff))
if err != nil {
return nil, err
}
st, err := f.parseSymtab(symdat, strtab, cmddat, &hdr, offset)
if err != nil {
return nil, err
}
f.Loads = append(f.Loads, st)
f.Symtab = st
case LoadCmdDysymtab:
var hdr DysymtabCmd
b := bytes.NewReader(cmddat)
if err := binary.Read(b, bo, &hdr); err != nil {
return nil, err
}
if f.Symtab == nil {
return nil, &FormatError{offset, "dynamic symbol table seen before any ordinary symbol table", nil}
} else if hdr.Iundefsym > uint32(len(f.Symtab.Syms)) {
return nil, &FormatError{offset, fmt.Sprintf(
"undefined symbols index in dynamic symbol table command is greater than symbol table length (%d > %d)",
hdr.Iundefsym, len(f.Symtab.Syms)), nil}
} else if hdr.Iundefsym+hdr.Nundefsym > uint32(len(f.Symtab.Syms)) {
return nil, &FormatError{offset, fmt.Sprintf(
"number of undefined symbols after index in dynamic symbol table command is greater than symbol table length (%d > %d)",
hdr.Iundefsym+hdr.Nundefsym, len(f.Symtab.Syms)), nil}
}
dat, err := saferio.ReadDataAt(r, uint64(hdr.Nindirectsyms)*4, int64(hdr.Indirectsymoff))
if err != nil {
return nil, err
}
x := make([]uint32, hdr.Nindirectsyms)
if err := binary.Read(bytes.NewReader(dat), bo, x); err != nil {
return nil, err
}
st := new(Dysymtab)
st.LoadBytes = LoadBytes(cmddat)
st.DysymtabCmd = hdr
st.IndirectSyms = x
f.Loads = append(f.Loads, st)
f.Dysymtab = st
case LoadCmdSegment:
var seg32 Segment32
b := bytes.NewReader(cmddat)
if err := binary.Read(b, bo, &seg32); err != nil {
return nil, err
}
s = new(Segment)
s.LoadBytes = cmddat
s.Cmd = cmd
s.Len = siz
s.Name = cstring(seg32.Name[0:])
s.Addr = uint64(seg32.Addr)
s.Memsz = uint64(seg32.Memsz)
s.Offset = uint64(seg32.Offset)
s.Filesz = uint64(seg32.Filesz)
s.Maxprot = seg32.Maxprot
s.Prot = seg32.Prot
s.Nsect = seg32.Nsect
s.Flag = seg32.Flag
f.Loads = append(f.Loads, s)
for i := 0; i < int(s.Nsect); i++ {
var sh32 Section32
if err := binary.Read(b, bo, &sh32); err != nil {
return nil, err
}
sh := new(Section)
sh.Name = cstring(sh32.Name[0:])
sh.Seg = cstring(sh32.Seg[0:])
sh.Addr = uint64(sh32.Addr)
sh.Size = uint64(sh32.Size)
sh.Offset = sh32.Offset
sh.Align = sh32.Align
sh.Reloff = sh32.Reloff
sh.Nreloc = sh32.Nreloc
sh.Flags = sh32.Flags
if err := f.pushSection(sh, r); err != nil {
return nil, err
}
}
case LoadCmdSegment64:
var seg64 Segment64
b := bytes.NewReader(cmddat)
if err := binary.Read(b, bo, &seg64); err != nil {
return nil, err
}
s = new(Segment)
s.LoadBytes = cmddat
s.Cmd = cmd
s.Len = siz
s.Name = cstring(seg64.Name[0:])
s.Addr = seg64.Addr
s.Memsz = seg64.Memsz
s.Offset = seg64.Offset
s.Filesz = seg64.Filesz
s.Maxprot = seg64.Maxprot
s.Prot = seg64.Prot
s.Nsect = seg64.Nsect
s.Flag = seg64.Flag
f.Loads = append(f.Loads, s)
for i := 0; i < int(s.Nsect); i++ {
var sh64 Section64
if err := binary.Read(b, bo, &sh64); err != nil {
return nil, err
}
sh := new(Section)
sh.Name = cstring(sh64.Name[0:])
sh.Seg = cstring(sh64.Seg[0:])
sh.Addr = sh64.Addr
sh.Size = sh64.Size
sh.Offset = sh64.Offset
sh.Align = sh64.Align
sh.Reloff = sh64.Reloff
sh.Nreloc = sh64.Nreloc
sh.Flags = sh64.Flags
if err := f.pushSection(sh, r); err != nil {
return nil, err
}
}
}
if s != nil {
if int64(s.Offset) < 0 {
return nil, &FormatError{offset, "invalid section offset", s.Offset}
}
if int64(s.Filesz) < 0 {
return nil, &FormatError{offset, "invalid section file size", s.Filesz}
}
s.sr = io.NewSectionReader(r, int64(s.Offset), int64(s.Filesz))
s.ReaderAt = s.sr
}
}
return f, nil
}
func (f *File) parseSymtab(symdat, strtab, cmddat []byte, hdr *SymtabCmd, offset int64) (*Symtab, error) {
bo := f.ByteOrder
c := saferio.SliceCap[Symbol](uint64(hdr.Nsyms))
if c < 0 {
return nil, &FormatError{offset, "too many symbols", nil}
}
symtab := make([]Symbol, 0, c)
b := bytes.NewReader(symdat)
for i := 0; i < int(hdr.Nsyms); i++ {
var n Nlist64
if f.Magic == Magic64 {
if err := binary.Read(b, bo, &n); err != nil {
return nil, err
}
} else {
var n32 Nlist32
if err := binary.Read(b, bo, &n32); err != nil {
return nil, err
}
n.Name = n32.Name
n.Type = n32.Type
n.Sect = n32.Sect
n.Desc = n32.Desc
n.Value = uint64(n32.Value)
}
if n.Name >= uint32(len(strtab)) {
return nil, &FormatError{offset, "invalid name in symbol table", n.Name}
}
// We add "_" to Go symbols. Strip it here. See issue 33808.
name := cstring(strtab[n.Name:])
if strings.Contains(name, ".") && name[0] == '_' {
name = name[1:]
}
symtab = append(symtab, Symbol{
Name: name,
Type: n.Type,
Sect: n.Sect,
Desc: n.Desc,
Value: n.Value,
})
}
st := new(Symtab)
st.LoadBytes = LoadBytes(cmddat)
st.Syms = symtab
return st, nil
}
type relocInfo struct {
Addr uint32
Symnum uint32
}
func (f *File) pushSection(sh *Section, r io.ReaderAt) error {
f.Sections = append(f.Sections, sh)
sh.sr = io.NewSectionReader(r, int64(sh.Offset), int64(sh.Size))
sh.ReaderAt = sh.sr
if sh.Nreloc > 0 {
reldat, err := saferio.ReadDataAt(r, uint64(sh.Nreloc)*8, int64(sh.Reloff))
if err != nil {
return err
}
b := bytes.NewReader(reldat)
bo := f.ByteOrder
sh.Relocs = make([]Reloc, sh.Nreloc)
for i := range sh.Relocs {
rel := &sh.Relocs[i]
var ri relocInfo
if err := binary.Read(b, bo, &ri); err != nil {
return err
}
if ri.Addr&(1<<31) != 0 { // scattered
rel.Addr = ri.Addr & (1<<24 - 1)
rel.Type = uint8((ri.Addr >> 24) & (1<<4 - 1))
rel.Len = uint8((ri.Addr >> 28) & (1<<2 - 1))
rel.Pcrel = ri.Addr&(1<<30) != 0
rel.Value = ri.Symnum
rel.Scattered = true
} else {
switch bo {
case binary.LittleEndian:
rel.Addr = ri.Addr
rel.Value = ri.Symnum & (1<<24 - 1)
rel.Pcrel = ri.Symnum&(1<<24) != 0
rel.Len = uint8((ri.Symnum >> 25) & (1<<2 - 1))
rel.Extern = ri.Symnum&(1<<27) != 0
rel.Type = uint8((ri.Symnum >> 28) & (1<<4 - 1))
case binary.BigEndian:
rel.Addr = ri.Addr
rel.Value = ri.Symnum >> 8
rel.Pcrel = ri.Symnum&(1<<7) != 0
rel.Len = uint8((ri.Symnum >> 5) & (1<<2 - 1))
rel.Extern = ri.Symnum&(1<<4) != 0
rel.Type = uint8(ri.Symnum & (1<<4 - 1))
default:
panic("unreachable")
}
}
}
}
return nil
}
func cstring(b []byte) string {
i := bytes.IndexByte(b, 0)
if i == -1 {
i = len(b)
}
return string(b[0:i])
}
// Segment returns the first Segment with the given name, or nil if no such segment exists.
func (f *File) Segment(name string) *Segment {
for _, l := range f.Loads {
if s, ok := l.(*Segment); ok && s.Name == name {
return s
}
}
return nil
}
// Section returns the first section with the given name, or nil if no such
// section exists.
func (f *File) Section(name string) *Section {
for _, s := range f.Sections {
if s.Name == name {
return s
}
}
return nil
}
// DWARF returns the DWARF debug information for the Mach-O file.
func (f *File) DWARF() (*dwarf.Data, error) {
dwarfSuffix := func(s *Section) string {
sectname := s.Name
var pfx int
switch {
case strings.HasPrefix(sectname, "__debug_"):
pfx = 8
case strings.HasPrefix(sectname, "__zdebug_"):
pfx = 9
default:
return ""
}
// Mach-O executables truncate section names to 16 characters, mangling some DWARF sections.
// As of DWARFv5 these are the only problematic section names (see DWARFv5 Appendix G).
for _, longname := range []string{
"__debug_str_offsets",
"__zdebug_line_str",
"__zdebug_loclists",
"__zdebug_pubnames",
"__zdebug_pubtypes",
"__zdebug_rnglists",
"__zdebug_str_offsets",
} {
if sectname == longname[:16] {
sectname = longname
break
}
}
return sectname[pfx:]
}
sectionData := func(s *Section) ([]byte, error) {
b, err := s.Data()
if err != nil && uint64(len(b)) < s.Size {
return nil, err
}
if len(b) >= 12 && string(b[:4]) == "ZLIB" {
dlen := binary.BigEndian.Uint64(b[4:12])
dbuf := make([]byte, dlen)
r, err := zlib.NewReader(bytes.NewBuffer(b[12:]))
if err != nil {
return nil, err
}
if _, err := io.ReadFull(r, dbuf); err != nil {
return nil, err
}
if err := r.Close(); err != nil {
return nil, err
}
b = dbuf
}
return b, nil
}
// There are many other DWARF sections, but these
// are the ones the debug/dwarf package uses.
// Don't bother loading others.
var dat = map[string][]byte{"abbrev": nil, "info": nil, "str": nil, "line": nil, "ranges": nil}
for _, s := range f.Sections {
suffix := dwarfSuffix(s)
if suffix == "" {
continue
}
if _, ok := dat[suffix]; !ok {
continue
}
b, err := sectionData(s)
if err != nil {
return nil, err
}
dat[suffix] = b
}
d, err := dwarf.New(dat["abbrev"], nil, nil, dat["info"], dat["line"], nil, dat["ranges"], dat["str"])
if err != nil {
return nil, err
}
// Look for DWARF4 .debug_types sections and DWARF5 sections.
for i, s := range f.Sections {
suffix := dwarfSuffix(s)
if suffix == "" {
continue
}
if _, ok := dat[suffix]; ok {
// Already handled.
continue
}
b, err := sectionData(s)
if err != nil {
return nil, err
}
if suffix == "types" {
err = d.AddTypes(fmt.Sprintf("types-%d", i), b)
} else {
err = d.AddSection(".debug_"+suffix, b)
}
if err != nil {
return nil, err
}
}
return d, nil
}
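// A minimal usage sketch, assuming "a.out" is a Mach-O binary that carries
// DWARF debug sections:
//
//	f, err := macho.Open("a.out")
//	if err != nil {
//		log.Fatal(err)
//	}
//	defer f.Close()
//	d, err := f.DWARF()
//	if err != nil {
//		log.Fatal(err)
//	}
//	r := d.Reader()
//	for {
//		e, err := r.Next()
//		if err != nil || e == nil {
//			break
//		}
//		if e.Tag == dwarf.TagCompileUnit {
//			fmt.Println("compile unit:", e.Val(dwarf.AttrName))
//		}
//	}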
// ImportedSymbols returns the names of all symbols
// referred to by the binary f that are expected to be
// satisfied by other libraries at dynamic load time.
func (f *File) ImportedSymbols() ([]string, error) {
if f.Symtab == nil {
return nil, &FormatError{0, "missing symbol table", nil}
}
st := f.Symtab
dt := f.Dysymtab
var all []string
if dt != nil {
for _, s := range st.Syms[dt.Iundefsym : dt.Iundefsym+dt.Nundefsym] {
all = append(all, s.Name)
}
} else {
// From Darwin's include/mach-o/nlist.h
const (
N_TYPE = 0x0e
N_UNDF = 0x0
)
for _, s := range st.Syms {
if s.Type&N_TYPE == N_UNDF && s.Sect == 0 {
all = append(all, s.Name)
}
}
}
return all, nil
}
// ImportedLibraries returns the paths of all libraries
// referred to by the binary f that are expected to be
// linked with the binary at dynamic link time.
func (f *File) ImportedLibraries() ([]string, error) {
var all []string
for _, l := range f.Loads {
if lib, ok := l.(*Dylib); ok {
all = append(all, lib.Name)
}
}
return all, nil
}
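// A sketch combining the two queries above (f is assumed to be an open
// *macho.File for a dynamically linked binary):
//
//	syms, err := f.ImportedSymbols()
//	if err != nil {
//		log.Fatal(err)
//	}
//	libs, err := f.ImportedLibraries()
//	if err != nil {
//		log.Fatal(err)
//	}
//	fmt.Printf("%d undefined symbols resolved from %v\n", len(syms), libs)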
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Mach-O header data structures
// Originally at:
// http://developer.apple.com/mac/library/documentation/DeveloperTools/Conceptual/MachORuntime/Reference/reference.html (since deleted by Apple)
// Archived copy at:
// https://web.archive.org/web/20090819232456/http://developer.apple.com/documentation/DeveloperTools/Conceptual/MachORuntime/index.html
// For cloned PDF see:
// https://github.com/aidansteele/osx-abi-macho-file-format-reference
package macho
import "strconv"
// A FileHeader represents a Mach-O file header.
type FileHeader struct {
Magic uint32
Cpu Cpu
SubCpu uint32
Type Type
Ncmd uint32
Cmdsz uint32
Flags uint32
}
const (
fileHeaderSize32 = 7 * 4
fileHeaderSize64 = 8 * 4
)
const (
Magic32 uint32 = 0xfeedface
Magic64 uint32 = 0xfeedfacf
MagicFat uint32 = 0xcafebabe
)
// A Type is the Mach-O file type, e.g. an object file, executable, or dynamic library.
type Type uint32
const (
TypeObj Type = 1
TypeExec Type = 2
TypeDylib Type = 6
TypeBundle Type = 8
)
var typeStrings = []intName{
{uint32(TypeObj), "Obj"},
{uint32(TypeExec), "Exec"},
{uint32(TypeDylib), "Dylib"},
{uint32(TypeBundle), "Bundle"},
}
func (t Type) String() string { return stringName(uint32(t), typeStrings, false) }
func (t Type) GoString() string { return stringName(uint32(t), typeStrings, true) }
// A Cpu is a Mach-O cpu type.
type Cpu uint32
const cpuArch64 = 0x01000000
const (
Cpu386 Cpu = 7
CpuAmd64 Cpu = Cpu386 | cpuArch64
CpuArm Cpu = 12
CpuArm64 Cpu = CpuArm | cpuArch64
CpuPpc Cpu = 18
CpuPpc64 Cpu = CpuPpc | cpuArch64
)
var cpuStrings = []intName{
{uint32(Cpu386), "Cpu386"},
{uint32(CpuAmd64), "CpuAmd64"},
{uint32(CpuArm), "CpuArm"},
{uint32(CpuArm64), "CpuArm64"},
{uint32(CpuPpc), "CpuPpc"},
{uint32(CpuPpc64), "CpuPpc64"},
}
func (i Cpu) String() string { return stringName(uint32(i), cpuStrings, false) }
func (i Cpu) GoString() string { return stringName(uint32(i), cpuStrings, true) }
// A LoadCmd is a Mach-O load command.
type LoadCmd uint32
const (
LoadCmdSegment LoadCmd = 0x1
LoadCmdSymtab LoadCmd = 0x2
LoadCmdThread LoadCmd = 0x4
LoadCmdUnixThread LoadCmd = 0x5 // thread+stack
LoadCmdDysymtab LoadCmd = 0xb
LoadCmdDylib LoadCmd = 0xc // load dylib command
LoadCmdDylinker LoadCmd = 0xf // id dylinker command (not load dylinker command)
LoadCmdSegment64 LoadCmd = 0x19
LoadCmdRpath LoadCmd = 0x8000001c
)
var cmdStrings = []intName{
{uint32(LoadCmdSegment), "LoadCmdSegment"},
{uint32(LoadCmdThread), "LoadCmdThread"},
{uint32(LoadCmdUnixThread), "LoadCmdUnixThread"},
{uint32(LoadCmdDylib), "LoadCmdDylib"},
{uint32(LoadCmdSegment64), "LoadCmdSegment64"},
{uint32(LoadCmdRpath), "LoadCmdRpath"},
}
func (i LoadCmd) String() string { return stringName(uint32(i), cmdStrings, false) }
func (i LoadCmd) GoString() string { return stringName(uint32(i), cmdStrings, true) }
type (
// A Segment32 is a 32-bit Mach-O segment load command.
Segment32 struct {
Cmd LoadCmd
Len uint32
Name [16]byte
Addr uint32
Memsz uint32
Offset uint32
Filesz uint32
Maxprot uint32
Prot uint32
Nsect uint32
Flag uint32
}
// A Segment64 is a 64-bit Mach-O segment load command.
Segment64 struct {
Cmd LoadCmd
Len uint32
Name [16]byte
Addr uint64
Memsz uint64
Offset uint64
Filesz uint64
Maxprot uint32
Prot uint32
Nsect uint32
Flag uint32
}
// A SymtabCmd is a Mach-O symbol table command.
SymtabCmd struct {
Cmd LoadCmd
Len uint32
Symoff uint32
Nsyms uint32
Stroff uint32
Strsize uint32
}
// A DysymtabCmd is a Mach-O dynamic symbol table command.
DysymtabCmd struct {
Cmd LoadCmd
Len uint32
Ilocalsym uint32
Nlocalsym uint32
Iextdefsym uint32
Nextdefsym uint32
Iundefsym uint32
Nundefsym uint32
Tocoffset uint32
Ntoc uint32
Modtaboff uint32
Nmodtab uint32
Extrefsymoff uint32
Nextrefsyms uint32
Indirectsymoff uint32
Nindirectsyms uint32
Extreloff uint32
Nextrel uint32
Locreloff uint32
Nlocrel uint32
}
// A DylibCmd is a Mach-O load dynamic library command.
DylibCmd struct {
Cmd LoadCmd
Len uint32
Name uint32
Time uint32
CurrentVersion uint32
CompatVersion uint32
}
// A RpathCmd is a Mach-O rpath command.
RpathCmd struct {
Cmd LoadCmd
Len uint32
Path uint32
}
// A Thread is a Mach-O thread state command.
Thread struct {
Cmd LoadCmd
Len uint32
Type uint32
Data []uint32
}
)
const (
FlagNoUndefs uint32 = 0x1
FlagIncrLink uint32 = 0x2
FlagDyldLink uint32 = 0x4
FlagBindAtLoad uint32 = 0x8
FlagPrebound uint32 = 0x10
FlagSplitSegs uint32 = 0x20
FlagLazyInit uint32 = 0x40
FlagTwoLevel uint32 = 0x80
FlagForceFlat uint32 = 0x100
FlagNoMultiDefs uint32 = 0x200
FlagNoFixPrebinding uint32 = 0x400
FlagPrebindable uint32 = 0x800
FlagAllModsBound uint32 = 0x1000
FlagSubsectionsViaSymbols uint32 = 0x2000
FlagCanonical uint32 = 0x4000
FlagWeakDefines uint32 = 0x8000
FlagBindsToWeak uint32 = 0x10000
FlagAllowStackExecution uint32 = 0x20000
FlagRootSafe uint32 = 0x40000
FlagSetuidSafe uint32 = 0x80000
FlagNoReexportedDylibs uint32 = 0x100000
FlagPIE uint32 = 0x200000
FlagDeadStrippableDylib uint32 = 0x400000
FlagHasTLVDescriptors uint32 = 0x800000
FlagNoHeapExecution uint32 = 0x1000000
FlagAppExtensionSafe uint32 = 0x2000000
)
// A Section32 is a 32-bit Mach-O section header.
type Section32 struct {
Name [16]byte
Seg [16]byte
Addr uint32
Size uint32
Offset uint32
Align uint32
Reloff uint32
Nreloc uint32
Flags uint32
Reserve1 uint32
Reserve2 uint32
}
// A Section64 is a 64-bit Mach-O section header.
type Section64 struct {
Name [16]byte
Seg [16]byte
Addr uint64
Size uint64
Offset uint32
Align uint32
Reloff uint32
Nreloc uint32
Flags uint32
Reserve1 uint32
Reserve2 uint32
Reserve3 uint32
}
// An Nlist32 is a Mach-O 32-bit symbol table entry.
type Nlist32 struct {
Name uint32
Type uint8
Sect uint8
Desc uint16
Value uint32
}
// An Nlist64 is a Mach-O 64-bit symbol table entry.
type Nlist64 struct {
Name uint32
Type uint8
Sect uint8
Desc uint16
Value uint64
}
// Regs386 is the Mach-O 386 register structure.
type Regs386 struct {
AX uint32
BX uint32
CX uint32
DX uint32
DI uint32
SI uint32
BP uint32
SP uint32
SS uint32
FLAGS uint32
IP uint32
CS uint32
DS uint32
ES uint32
FS uint32
GS uint32
}
// RegsAMD64 is the Mach-O AMD64 register structure.
type RegsAMD64 struct {
AX uint64
BX uint64
CX uint64
DX uint64
DI uint64
SI uint64
BP uint64
SP uint64
R8 uint64
R9 uint64
R10 uint64
R11 uint64
R12 uint64
R13 uint64
R14 uint64
R15 uint64
IP uint64
FLAGS uint64
CS uint64
FS uint64
GS uint64
}
type intName struct {
i uint32
s string
}
func stringName(i uint32, names []intName, goSyntax bool) string {
for _, n := range names {
if n.i == i {
if goSyntax {
return "macho." + n.s
}
return n.s
}
}
return strconv.FormatUint(uint64(i), 10)
}
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package macho
//go:generate stringer -type=RelocTypeGeneric,RelocTypeX86_64,RelocTypeARM,RelocTypeARM64 -output reloctype_string.go
type RelocTypeGeneric int
const (
GENERIC_RELOC_VANILLA RelocTypeGeneric = 0
GENERIC_RELOC_PAIR RelocTypeGeneric = 1
GENERIC_RELOC_SECTDIFF RelocTypeGeneric = 2
GENERIC_RELOC_PB_LA_PTR RelocTypeGeneric = 3
GENERIC_RELOC_LOCAL_SECTDIFF RelocTypeGeneric = 4
GENERIC_RELOC_TLV RelocTypeGeneric = 5
)
func (r RelocTypeGeneric) GoString() string { return "macho." + r.String() }
type RelocTypeX86_64 int
const (
X86_64_RELOC_UNSIGNED RelocTypeX86_64 = 0
X86_64_RELOC_SIGNED RelocTypeX86_64 = 1
X86_64_RELOC_BRANCH RelocTypeX86_64 = 2
X86_64_RELOC_GOT_LOAD RelocTypeX86_64 = 3
X86_64_RELOC_GOT RelocTypeX86_64 = 4
X86_64_RELOC_SUBTRACTOR RelocTypeX86_64 = 5
X86_64_RELOC_SIGNED_1 RelocTypeX86_64 = 6
X86_64_RELOC_SIGNED_2 RelocTypeX86_64 = 7
X86_64_RELOC_SIGNED_4 RelocTypeX86_64 = 8
X86_64_RELOC_TLV RelocTypeX86_64 = 9
)
func (r RelocTypeX86_64) GoString() string { return "macho." + r.String() }
type RelocTypeARM int
const (
ARM_RELOC_VANILLA RelocTypeARM = 0
ARM_RELOC_PAIR RelocTypeARM = 1
ARM_RELOC_SECTDIFF RelocTypeARM = 2
ARM_RELOC_LOCAL_SECTDIFF RelocTypeARM = 3
ARM_RELOC_PB_LA_PTR RelocTypeARM = 4
ARM_RELOC_BR24 RelocTypeARM = 5
ARM_THUMB_RELOC_BR22 RelocTypeARM = 6
ARM_THUMB_32BIT_BRANCH RelocTypeARM = 7
ARM_RELOC_HALF RelocTypeARM = 8
ARM_RELOC_HALF_SECTDIFF RelocTypeARM = 9
)
func (r RelocTypeARM) GoString() string { return "macho." + r.String() }
type RelocTypeARM64 int
const (
ARM64_RELOC_UNSIGNED RelocTypeARM64 = 0
ARM64_RELOC_SUBTRACTOR RelocTypeARM64 = 1
ARM64_RELOC_BRANCH26 RelocTypeARM64 = 2
ARM64_RELOC_PAGE21 RelocTypeARM64 = 3
ARM64_RELOC_PAGEOFF12 RelocTypeARM64 = 4
ARM64_RELOC_GOT_LOAD_PAGE21 RelocTypeARM64 = 5
ARM64_RELOC_GOT_LOAD_PAGEOFF12 RelocTypeARM64 = 6
ARM64_RELOC_POINTER_TO_GOT RelocTypeARM64 = 7
ARM64_RELOC_TLVP_LOAD_PAGE21 RelocTypeARM64 = 8
ARM64_RELOC_TLVP_LOAD_PAGEOFF12 RelocTypeARM64 = 9
ARM64_RELOC_ADDEND RelocTypeARM64 = 10
)
func (r RelocTypeARM64) GoString() string { return "macho." + r.String() }
// Code generated by "stringer -type=RelocTypeGeneric,RelocTypeX86_64,RelocTypeARM,RelocTypeARM64 -output reloctype_string.go"; DO NOT EDIT.
package macho
import "strconv"
func _() {
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
var x [1]struct{}
_ = x[GENERIC_RELOC_VANILLA-0]
_ = x[GENERIC_RELOC_PAIR-1]
_ = x[GENERIC_RELOC_SECTDIFF-2]
_ = x[GENERIC_RELOC_PB_LA_PTR-3]
_ = x[GENERIC_RELOC_LOCAL_SECTDIFF-4]
_ = x[GENERIC_RELOC_TLV-5]
}
const _RelocTypeGeneric_name = "GENERIC_RELOC_VANILLAGENERIC_RELOC_PAIRGENERIC_RELOC_SECTDIFFGENERIC_RELOC_PB_LA_PTRGENERIC_RELOC_LOCAL_SECTDIFFGENERIC_RELOC_TLV"
var _RelocTypeGeneric_index = [...]uint8{0, 21, 39, 61, 84, 112, 129}
func (i RelocTypeGeneric) String() string {
if i < 0 || i >= RelocTypeGeneric(len(_RelocTypeGeneric_index)-1) {
return "RelocTypeGeneric(" + strconv.FormatInt(int64(i), 10) + ")"
}
return _RelocTypeGeneric_name[_RelocTypeGeneric_index[i]:_RelocTypeGeneric_index[i+1]]
}
func _() {
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
var x [1]struct{}
_ = x[X86_64_RELOC_UNSIGNED-0]
_ = x[X86_64_RELOC_SIGNED-1]
_ = x[X86_64_RELOC_BRANCH-2]
_ = x[X86_64_RELOC_GOT_LOAD-3]
_ = x[X86_64_RELOC_GOT-4]
_ = x[X86_64_RELOC_SUBTRACTOR-5]
_ = x[X86_64_RELOC_SIGNED_1-6]
_ = x[X86_64_RELOC_SIGNED_2-7]
_ = x[X86_64_RELOC_SIGNED_4-8]
_ = x[X86_64_RELOC_TLV-9]
}
const _RelocTypeX86_64_name = "X86_64_RELOC_UNSIGNEDX86_64_RELOC_SIGNEDX86_64_RELOC_BRANCHX86_64_RELOC_GOT_LOADX86_64_RELOC_GOTX86_64_RELOC_SUBTRACTORX86_64_RELOC_SIGNED_1X86_64_RELOC_SIGNED_2X86_64_RELOC_SIGNED_4X86_64_RELOC_TLV"
var _RelocTypeX86_64_index = [...]uint8{0, 21, 40, 59, 80, 96, 119, 140, 161, 182, 198}
func (i RelocTypeX86_64) String() string {
if i < 0 || i >= RelocTypeX86_64(len(_RelocTypeX86_64_index)-1) {
return "RelocTypeX86_64(" + strconv.FormatInt(int64(i), 10) + ")"
}
return _RelocTypeX86_64_name[_RelocTypeX86_64_index[i]:_RelocTypeX86_64_index[i+1]]
}
func _() {
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
var x [1]struct{}
_ = x[ARM_RELOC_VANILLA-0]
_ = x[ARM_RELOC_PAIR-1]
_ = x[ARM_RELOC_SECTDIFF-2]
_ = x[ARM_RELOC_LOCAL_SECTDIFF-3]
_ = x[ARM_RELOC_PB_LA_PTR-4]
_ = x[ARM_RELOC_BR24-5]
_ = x[ARM_THUMB_RELOC_BR22-6]
_ = x[ARM_THUMB_32BIT_BRANCH-7]
_ = x[ARM_RELOC_HALF-8]
_ = x[ARM_RELOC_HALF_SECTDIFF-9]
}
const _RelocTypeARM_name = "ARM_RELOC_VANILLAARM_RELOC_PAIRARM_RELOC_SECTDIFFARM_RELOC_LOCAL_SECTDIFFARM_RELOC_PB_LA_PTRARM_RELOC_BR24ARM_THUMB_RELOC_BR22ARM_THUMB_32BIT_BRANCHARM_RELOC_HALFARM_RELOC_HALF_SECTDIFF"
var _RelocTypeARM_index = [...]uint8{0, 17, 31, 49, 73, 92, 106, 126, 148, 162, 185}
func (i RelocTypeARM) String() string {
if i < 0 || i >= RelocTypeARM(len(_RelocTypeARM_index)-1) {
return "RelocTypeARM(" + strconv.FormatInt(int64(i), 10) + ")"
}
return _RelocTypeARM_name[_RelocTypeARM_index[i]:_RelocTypeARM_index[i+1]]
}
func _() {
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
var x [1]struct{}
_ = x[ARM64_RELOC_UNSIGNED-0]
_ = x[ARM64_RELOC_SUBTRACTOR-1]
_ = x[ARM64_RELOC_BRANCH26-2]
_ = x[ARM64_RELOC_PAGE21-3]
_ = x[ARM64_RELOC_PAGEOFF12-4]
_ = x[ARM64_RELOC_GOT_LOAD_PAGE21-5]
_ = x[ARM64_RELOC_GOT_LOAD_PAGEOFF12-6]
_ = x[ARM64_RELOC_POINTER_TO_GOT-7]
_ = x[ARM64_RELOC_TLVP_LOAD_PAGE21-8]
_ = x[ARM64_RELOC_TLVP_LOAD_PAGEOFF12-9]
_ = x[ARM64_RELOC_ADDEND-10]
}
const _RelocTypeARM64_name = "ARM64_RELOC_UNSIGNEDARM64_RELOC_SUBTRACTORARM64_RELOC_BRANCH26ARM64_RELOC_PAGE21ARM64_RELOC_PAGEOFF12ARM64_RELOC_GOT_LOAD_PAGE21ARM64_RELOC_GOT_LOAD_PAGEOFF12ARM64_RELOC_POINTER_TO_GOTARM64_RELOC_TLVP_LOAD_PAGE21ARM64_RELOC_TLVP_LOAD_PAGEOFF12ARM64_RELOC_ADDEND"
var _RelocTypeARM64_index = [...]uint16{0, 20, 42, 62, 80, 101, 128, 158, 184, 212, 243, 261}
func (i RelocTypeARM64) String() string {
if i < 0 || i >= RelocTypeARM64(len(_RelocTypeARM64_index)-1) {
return "RelocTypeARM64(" + strconv.FormatInt(int64(i), 10) + ")"
}
return _RelocTypeARM64_name[_RelocTypeARM64_index[i]:_RelocTypeARM64_index[i+1]]
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
/*
Package pe implements access to PE (Microsoft Windows Portable Executable) files.
# Security
This package is not designed to be hardened against adversarial inputs, and is
outside the scope of https://go.dev/security/policy. In particular, only basic
validation is done when parsing object files. As such, care should be taken when
parsing untrusted inputs, as parsing malformed files may consume significant
resources, or cause panics.
*/
package pe
import (
"bytes"
"compress/zlib"
"debug/dwarf"
"encoding/binary"
"errors"
"fmt"
"io"
"os"
"strings"
)
// A File represents an open PE file.
type File struct {
FileHeader
OptionalHeader any // of type *OptionalHeader32 or *OptionalHeader64
Sections []*Section
Symbols []*Symbol // COFF symbols with auxiliary symbol records removed
COFFSymbols []COFFSymbol // all COFF symbols (including auxiliary symbol records)
StringTable StringTable
closer io.Closer
}
// Open opens the named file using [os.Open] and prepares it for use as a PE binary.
func Open(name string) (*File, error) {
f, err := os.Open(name)
if err != nil {
return nil, err
}
ff, err := NewFile(f)
if err != nil {
f.Close()
return nil, err
}
ff.closer = f
return ff, nil
}
// Close closes the [File].
// If the [File] was created using [NewFile] directly instead of [Open],
// Close has no effect.
func (f *File) Close() error {
var err error
if f.closer != nil {
err = f.closer.Close()
f.closer = nil
}
return err
}
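// A minimal usage sketch, assuming "prog.exe" is a PE binary on disk:
//
//	f, err := pe.Open("prog.exe")
//	if err != nil {
//		log.Fatal(err)
//	}
//	defer f.Close()
//	for _, s := range f.Sections {
//		fmt.Printf("%-8s va=%#x size=%d\n", s.Name, s.VirtualAddress, s.Size)
//	}
//	if oh, ok := f.OptionalHeader.(*pe.OptionalHeader64); ok {
//		fmt.Printf("entry point RVA: %#x\n", oh.AddressOfEntryPoint)
//	}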
// TODO(brainman): add Load function, as a replacement for NewFile, that does not call removeAuxSymbols (for performance)
// NewFile creates a new [File] for accessing a PE binary in an underlying reader.
func NewFile(r io.ReaderAt) (*File, error) {
f := new(File)
sr := io.NewSectionReader(r, 0, 1<<63-1)
var dosheader [96]byte
if _, err := r.ReadAt(dosheader[0:], 0); err != nil {
return nil, err
}
var base int64
if dosheader[0] == 'M' && dosheader[1] == 'Z' {
signoff := int64(binary.LittleEndian.Uint32(dosheader[0x3c:]))
var sign [4]byte
r.ReadAt(sign[:], signoff)
if !(sign[0] == 'P' && sign[1] == 'E' && sign[2] == 0 && sign[3] == 0) {
return nil, fmt.Errorf("invalid PE file signature: % x", sign)
}
base = signoff + 4
} else {
base = int64(0)
}
sr.Seek(base, io.SeekStart)
if err := binary.Read(sr, binary.LittleEndian, &f.FileHeader); err != nil {
return nil, err
}
switch f.FileHeader.Machine {
case IMAGE_FILE_MACHINE_AMD64,
IMAGE_FILE_MACHINE_ARM64,
IMAGE_FILE_MACHINE_ARMNT,
IMAGE_FILE_MACHINE_I386,
IMAGE_FILE_MACHINE_RISCV32,
IMAGE_FILE_MACHINE_RISCV64,
IMAGE_FILE_MACHINE_RISCV128,
IMAGE_FILE_MACHINE_UNKNOWN:
// ok
default:
return nil, fmt.Errorf("unrecognized PE machine: %#x", f.FileHeader.Machine)
}
var err error
// Read string table.
f.StringTable, err = readStringTable(&f.FileHeader, sr)
if err != nil {
return nil, err
}
// Read symbol table.
f.COFFSymbols, err = readCOFFSymbols(&f.FileHeader, sr)
if err != nil {
return nil, err
}
f.Symbols, err = removeAuxSymbols(f.COFFSymbols, f.StringTable)
if err != nil {
return nil, err
}
// Seek past file header.
_, err = sr.Seek(base+int64(binary.Size(f.FileHeader)), io.SeekStart)
if err != nil {
return nil, err
}
// Read optional header.
f.OptionalHeader, err = readOptionalHeader(sr, f.FileHeader.SizeOfOptionalHeader)
if err != nil {
return nil, err
}
// Process sections.
f.Sections = make([]*Section, f.FileHeader.NumberOfSections)
for i := 0; i < int(f.FileHeader.NumberOfSections); i++ {
sh := new(SectionHeader32)
if err := binary.Read(sr, binary.LittleEndian, sh); err != nil {
return nil, err
}
name, err := sh.fullName(f.StringTable)
if err != nil {
return nil, err
}
s := new(Section)
s.SectionHeader = SectionHeader{
Name: name,
VirtualSize: sh.VirtualSize,
VirtualAddress: sh.VirtualAddress,
Size: sh.SizeOfRawData,
Offset: sh.PointerToRawData,
PointerToRelocations: sh.PointerToRelocations,
PointerToLineNumbers: sh.PointerToLineNumbers,
NumberOfRelocations: sh.NumberOfRelocations,
NumberOfLineNumbers: sh.NumberOfLineNumbers,
Characteristics: sh.Characteristics,
}
r2 := r
if sh.PointerToRawData == 0 { // .bss must have all 0s
r2 = &nobitsSectionReader{}
}
s.sr = io.NewSectionReader(r2, int64(s.SectionHeader.Offset), int64(s.SectionHeader.Size))
s.ReaderAt = s.sr
f.Sections[i] = s
}
for i := range f.Sections {
var err error
f.Sections[i].Relocs, err = readRelocs(&f.Sections[i].SectionHeader, sr)
if err != nil {
return nil, err
}
}
return f, nil
}
type nobitsSectionReader struct{}
func (*nobitsSectionReader) ReadAt(p []byte, off int64) (n int, err error) {
return 0, errors.New("unexpected read from section with uninitialized data")
}
// getString extracts a string from symbol string table.
func getString(section []byte, start int) (string, bool) {
if start < 0 || start >= len(section) {
return "", false
}
for end := start; end < len(section); end++ {
if section[end] == 0 {
return string(section[start:end]), true
}
}
return "", false
}
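// For example (a sketch of the intended behavior):
//
//	s, ok := getString([]byte("foo\x00bar\x00"), 4) // s == "bar", ok == true
//	_, ok = getString([]byte("foo\x00"), 10)        // ok == false: start out of range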
// Section returns the first section with the given name, or nil if no such
// section exists.
func (f *File) Section(name string) *Section {
for _, s := range f.Sections {
if s.Name == name {
return s
}
}
return nil
}
func (f *File) DWARF() (*dwarf.Data, error) {
dwarfSuffix := func(s *Section) string {
switch {
case strings.HasPrefix(s.Name, ".debug_"):
return s.Name[7:]
case strings.HasPrefix(s.Name, ".zdebug_"):
return s.Name[8:]
default:
return ""
}
}
// sectionData gets the data for s and checks its size.
sectionData := func(s *Section) ([]byte, error) {
b, err := s.Data()
if err != nil && uint32(len(b)) < s.Size {
return nil, err
}
if 0 < s.VirtualSize && s.VirtualSize < s.Size {
b = b[:s.VirtualSize]
}
if len(b) >= 12 && string(b[:4]) == "ZLIB" {
dlen := binary.BigEndian.Uint64(b[4:12])
dbuf := make([]byte, dlen)
r, err := zlib.NewReader(bytes.NewBuffer(b[12:]))
if err != nil {
return nil, err
}
if _, err := io.ReadFull(r, dbuf); err != nil {
return nil, err
}
if err := r.Close(); err != nil {
return nil, err
}
b = dbuf
}
return b, nil
}
// There are many other DWARF sections, but these
// are the ones the debug/dwarf package uses.
// Don't bother loading others.
var dat = map[string][]byte{"abbrev": nil, "info": nil, "str": nil, "line": nil, "ranges": nil}
for _, s := range f.Sections {
suffix := dwarfSuffix(s)
if suffix == "" {
continue
}
if _, ok := dat[suffix]; !ok {
continue
}
b, err := sectionData(s)
if err != nil {
return nil, err
}
dat[suffix] = b
}
d, err := dwarf.New(dat["abbrev"], nil, nil, dat["info"], dat["line"], nil, dat["ranges"], dat["str"])
if err != nil {
return nil, err
}
// Look for DWARF4 .debug_types sections and DWARF5 sections.
for i, s := range f.Sections {
suffix := dwarfSuffix(s)
if suffix == "" {
continue
}
if _, ok := dat[suffix]; ok {
// Already handled.
continue
}
b, err := sectionData(s)
if err != nil {
return nil, err
}
if suffix == "types" {
err = d.AddTypes(fmt.Sprintf("types-%d", i), b)
} else {
err = d.AddSection(".debug_"+suffix, b)
}
if err != nil {
return nil, err
}
}
return d, nil
}
// TODO(brainman): document ImportDirectory once we decide what to do with it.
type ImportDirectory struct {
OriginalFirstThunk uint32
TimeDateStamp uint32
ForwarderChain uint32
Name uint32
FirstThunk uint32
dll string
}
// ImportedSymbols returns the names of all symbols
// referred to by the binary f that are expected to be
// satisfied by other libraries at dynamic load time.
// It does not return weak symbols.
func (f *File) ImportedSymbols() ([]string, error) {
if f.OptionalHeader == nil {
return nil, nil
}
_, pe64 := f.OptionalHeader.(*OptionalHeader64)
// grab the number of data directory entries
var dd_length uint32
if pe64 {
dd_length = f.OptionalHeader.(*OptionalHeader64).NumberOfRvaAndSizes
} else {
dd_length = f.OptionalHeader.(*OptionalHeader32).NumberOfRvaAndSizes
}
// check that the length of data directory entries is large
// enough to include the import directory.
if dd_length < IMAGE_DIRECTORY_ENTRY_IMPORT+1 {
return nil, nil
}
// grab the import data directory entry
var idd DataDirectory
if pe64 {
idd = f.OptionalHeader.(*OptionalHeader64).DataDirectory[IMAGE_DIRECTORY_ENTRY_IMPORT]
} else {
idd = f.OptionalHeader.(*OptionalHeader32).DataDirectory[IMAGE_DIRECTORY_ENTRY_IMPORT]
}
// figure out which section contains the import directory table
var ds *Section
for _, s := range f.Sections {
if s.Offset == 0 {
continue
}
// We use the distance between s.VirtualAddress and idd.VirtualAddress
// to avoid a potential uint32 overflow from adding s.VirtualSize
// to s.VirtualAddress.
if s.VirtualAddress <= idd.VirtualAddress && idd.VirtualAddress-s.VirtualAddress < s.VirtualSize {
ds = s
break
}
}
// didn't find a section, so no import libraries were found
if ds == nil {
return nil, nil
}
d, err := ds.Data()
if err != nil {
return nil, err
}
// seek to the virtual address specified in the import data directory
d = d[idd.VirtualAddress-ds.VirtualAddress:]
// start decoding the import directory
var ida []ImportDirectory
for len(d) >= 20 {
var dt ImportDirectory
dt.OriginalFirstThunk = binary.LittleEndian.Uint32(d[0:4])
dt.TimeDateStamp = binary.LittleEndian.Uint32(d[4:8])
dt.ForwarderChain = binary.LittleEndian.Uint32(d[8:12])
dt.Name = binary.LittleEndian.Uint32(d[12:16])
dt.FirstThunk = binary.LittleEndian.Uint32(d[16:20])
d = d[20:]
if dt.OriginalFirstThunk == 0 {
break
}
ida = append(ida, dt)
}
// TODO(brainman): this needs to be rewritten
// ds.Data() returns the contents of the section containing the import table. Why store it in a variable called "names"?
// Why are we retrieving it a second time? We already have it in "d", and it is not modified anywhere.
// getString does not extract a string from the symbol string table (contrary to what its doc comment says).
// Why is ds.Data() called again and again in the loop?
// Needs a test before rewriting.
names, _ := ds.Data()
var all []string
for _, dt := range ida {
dt.dll, _ = getString(names, int(dt.Name-ds.VirtualAddress))
d, _ = ds.Data()
// seek to OriginalFirstThunk
d = d[dt.OriginalFirstThunk-ds.VirtualAddress:]
for len(d) > 0 {
if pe64 { // 64bit
va := binary.LittleEndian.Uint64(d[0:8])
d = d[8:]
if va == 0 {
break
}
if va&0x8000000000000000 > 0 { // is Ordinal
// TODO add dynimport ordinal support.
} else {
fn, _ := getString(names, int(uint32(va)-ds.VirtualAddress+2))
all = append(all, fn+":"+dt.dll)
}
} else { // 32bit
va := binary.LittleEndian.Uint32(d[0:4])
d = d[4:]
if va == 0 {
break
}
if va&0x80000000 > 0 { // is Ordinal
// TODO add dynimport ordinal support.
//ord := va&0x0000FFFF
} else {
fn, _ := getString(names, int(va-ds.VirtualAddress+2))
all = append(all, fn+":"+dt.dll)
}
}
}
}
return all, nil
}
// ImportedLibraries returns the names of all libraries
// referred to by the binary f that are expected to be
// linked with the binary at dynamic link time.
func (f *File) ImportedLibraries() ([]string, error) {
// TODO
// cgo -dynimport doesn't use this for Windows PE, so just return.
return nil, nil
}
// FormatError is unused.
// The type is retained for compatibility.
type FormatError struct {
}
func (e *FormatError) Error() string {
return "unknown error"
}
// readOptionalHeader accepts an io.ReadSeeker pointing to the optional header in the PE file
// and the header's size as recorded in the file header.
// It parses the given number of bytes and returns the optional header. It infers whether the
// bytes being parsed describe the 32-bit or the 64-bit version of the optional header.
func readOptionalHeader(r io.ReadSeeker, sz uint16) (any, error) {
// If the optional header size is 0, there is no optional header; return nil.
if sz == 0 {
return nil, nil
}
var (
// The first couple of bytes in the optional header state its type.
// We need to read them first to determine the type and
// validity of the optional header.
ohMagic uint16
ohMagicSz = binary.Size(ohMagic)
)
// If the optional header size is greater than 0 but less than the size of its magic field, return an error.
if sz < uint16(ohMagicSz) {
return nil, fmt.Errorf("optional header size is less than optional header magic size")
}
// read reads from the io.ReadSeeker r into data.
var err error
read := func(data any) bool {
err = binary.Read(r, binary.LittleEndian, data)
return err == nil
}
if !read(&ohMagic) {
return nil, fmt.Errorf("failure to read optional header magic: %v", err)
}
switch ohMagic {
case 0x10b: // PE32
var (
oh32 OptionalHeader32
// There can be 0 or more data directories. So the minimum size of optional
// header is calculated by subtracting oh32.DataDirectory size from oh32 size.
oh32MinSz = binary.Size(oh32) - binary.Size(oh32.DataDirectory)
)
if sz < uint16(oh32MinSz) {
return nil, fmt.Errorf("optional header size(%d) is less minimum size (%d) of PE32 optional header", sz, oh32MinSz)
}
// Init oh32 fields
oh32.Magic = ohMagic
if !read(&oh32.MajorLinkerVersion) ||
!read(&oh32.MinorLinkerVersion) ||
!read(&oh32.SizeOfCode) ||
!read(&oh32.SizeOfInitializedData) ||
!read(&oh32.SizeOfUninitializedData) ||
!read(&oh32.AddressOfEntryPoint) ||
!read(&oh32.BaseOfCode) ||
!read(&oh32.BaseOfData) ||
!read(&oh32.ImageBase) ||
!read(&oh32.SectionAlignment) ||
!read(&oh32.FileAlignment) ||
!read(&oh32.MajorOperatingSystemVersion) ||
!read(&oh32.MinorOperatingSystemVersion) ||
!read(&oh32.MajorImageVersion) ||
!read(&oh32.MinorImageVersion) ||
!read(&oh32.MajorSubsystemVersion) ||
!read(&oh32.MinorSubsystemVersion) ||
!read(&oh32.Win32VersionValue) ||
!read(&oh32.SizeOfImage) ||
!read(&oh32.SizeOfHeaders) ||
!read(&oh32.CheckSum) ||
!read(&oh32.Subsystem) ||
!read(&oh32.DllCharacteristics) ||
!read(&oh32.SizeOfStackReserve) ||
!read(&oh32.SizeOfStackCommit) ||
!read(&oh32.SizeOfHeapReserve) ||
!read(&oh32.SizeOfHeapCommit) ||
!read(&oh32.LoaderFlags) ||
!read(&oh32.NumberOfRvaAndSizes) {
return nil, fmt.Errorf("failure to read PE32 optional header: %v", err)
}
dd, err := readDataDirectories(r, sz-uint16(oh32MinSz), oh32.NumberOfRvaAndSizes)
if err != nil {
return nil, err
}
copy(oh32.DataDirectory[:], dd)
return &oh32, nil
case 0x20b: // PE32+
var (
oh64 OptionalHeader64
// There can be 0 or more data directories. So the minimum size of optional
// header is calculated by subtracting oh64.DataDirectory size from oh64 size.
oh64MinSz = binary.Size(oh64) - binary.Size(oh64.DataDirectory)
)
if sz < uint16(oh64MinSz) {
return nil, fmt.Errorf("optional header size(%d) is less minimum size (%d) for PE32+ optional header", sz, oh64MinSz)
}
// Init oh64 fields
oh64.Magic = ohMagic
if !read(&oh64.MajorLinkerVersion) ||
!read(&oh64.MinorLinkerVersion) ||
!read(&oh64.SizeOfCode) ||
!read(&oh64.SizeOfInitializedData) ||
!read(&oh64.SizeOfUninitializedData) ||
!read(&oh64.AddressOfEntryPoint) ||
!read(&oh64.BaseOfCode) ||
!read(&oh64.ImageBase) ||
!read(&oh64.SectionAlignment) ||
!read(&oh64.FileAlignment) ||
!read(&oh64.MajorOperatingSystemVersion) ||
!read(&oh64.MinorOperatingSystemVersion) ||
!read(&oh64.MajorImageVersion) ||
!read(&oh64.MinorImageVersion) ||
!read(&oh64.MajorSubsystemVersion) ||
!read(&oh64.MinorSubsystemVersion) ||
!read(&oh64.Win32VersionValue) ||
!read(&oh64.SizeOfImage) ||
!read(&oh64.SizeOfHeaders) ||
!read(&oh64.CheckSum) ||
!read(&oh64.Subsystem) ||
!read(&oh64.DllCharacteristics) ||
!read(&oh64.SizeOfStackReserve) ||
!read(&oh64.SizeOfStackCommit) ||
!read(&oh64.SizeOfHeapReserve) ||
!read(&oh64.SizeOfHeapCommit) ||
!read(&oh64.LoaderFlags) ||
!read(&oh64.NumberOfRvaAndSizes) {
return nil, fmt.Errorf("failure to read PE32+ optional header: %v", err)
}
dd, err := readDataDirectories(r, sz-uint16(oh64MinSz), oh64.NumberOfRvaAndSizes)
if err != nil {
return nil, err
}
copy(oh64.DataDirectory[:], dd)
return &oh64, nil
default:
return nil, fmt.Errorf("optional header has unexpected Magic of 0x%x", ohMagic)
}
}
// readDataDirectories accepts an io.ReadSeeker pointing to the data directories in the PE file,
// their total size, and the number of data directories as recorded in the optional header.
// It parses the given number of bytes and returns the requested number of data directories.
func readDataDirectories(r io.ReadSeeker, sz uint16, n uint32) ([]DataDirectory, error) {
ddSz := uint64(binary.Size(DataDirectory{}))
if uint64(sz) != uint64(n)*ddSz {
return nil, fmt.Errorf("size of data directories(%d) is inconsistent with number of data directories(%d)", sz, n)
}
dd := make([]DataDirectory, n)
if err := binary.Read(r, binary.LittleEndian, dd); err != nil {
return nil, fmt.Errorf("failure to read data directories: %v", err)
}
return dd, nil
}
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package pe
import (
"encoding/binary"
"fmt"
"internal/saferio"
"io"
"strconv"
)
// SectionHeader32 represents a real PE COFF section header.
type SectionHeader32 struct {
Name [8]uint8
VirtualSize uint32
VirtualAddress uint32
SizeOfRawData uint32
PointerToRawData uint32
PointerToRelocations uint32
PointerToLineNumbers uint32
NumberOfRelocations uint16
NumberOfLineNumbers uint16
Characteristics uint32
}
// fullName finds the real name of section sh. Normally the name is stored
// in sh.Name, but if it is longer than 8 characters, it is stored
// in the COFF string table st instead.
func (sh *SectionHeader32) fullName(st StringTable) (string, error) {
if sh.Name[0] != '/' {
return cstring(sh.Name[:]), nil
}
i, err := strconv.Atoi(cstring(sh.Name[1:]))
if err != nil {
return "", err
}
return st.String(uint32(i))
}
// TODO(brainman): copy all IMAGE_REL_* consts from ldpe.go here
// Reloc represents a PE COFF relocation.
// Each section contains its own relocation list.
type Reloc struct {
VirtualAddress uint32
SymbolTableIndex uint32
Type uint16
}
func readRelocs(sh *SectionHeader, r io.ReadSeeker) ([]Reloc, error) {
if sh.NumberOfRelocations <= 0 {
return nil, nil
}
_, err := r.Seek(int64(sh.PointerToRelocations), io.SeekStart)
if err != nil {
return nil, fmt.Errorf("fail to seek to %q section relocations: %v", sh.Name, err)
}
relocs := make([]Reloc, sh.NumberOfRelocations)
err = binary.Read(r, binary.LittleEndian, relocs)
if err != nil {
return nil, fmt.Errorf("fail to read section relocations: %v", err)
}
return relocs, nil
}
// SectionHeader is similar to [SectionHeader32] with Name
// field replaced by Go string.
type SectionHeader struct {
Name string
VirtualSize uint32
VirtualAddress uint32
Size uint32
Offset uint32
PointerToRelocations uint32
PointerToLineNumbers uint32
NumberOfRelocations uint16
NumberOfLineNumbers uint16
Characteristics uint32
}
// Section provides access to PE COFF section.
type Section struct {
SectionHeader
Relocs []Reloc
// Embed ReaderAt for ReadAt method.
// Do not embed SectionReader directly
// to avoid having Read and Seek.
// If a client wants Read and Seek it must use
// Open() to avoid fighting over the seek offset
// with other clients.
io.ReaderAt
sr *io.SectionReader
}
// Data reads and returns the contents of the PE section s.
//
// If s.Offset is 0, the section has no contents,
// and Data will always return a non-nil error.
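//
// A minimal usage sketch (assuming s is a *Section obtained from an
// already-opened File; the names are illustrative only):
//
//	b, err := s.Data()
//	if err != nil {
//		// the section contents could not be read (for example, s.Offset is 0)
//	}
//	_ = b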
func (s *Section) Data() ([]byte, error) {
return saferio.ReadDataAt(s.sr, uint64(s.Size), 0)
}
// Open returns a new ReadSeeker reading the PE section s.
//
// If s.Offset is 0, the section has no contents, and all calls
// to the returned reader will return a non-nil error.
func (s *Section) Open() io.ReadSeeker {
return io.NewSectionReader(s.sr, 0, 1<<63-1)
}
// Section characteristics flags.
const (
IMAGE_SCN_CNT_CODE = 0x00000020
IMAGE_SCN_CNT_INITIALIZED_DATA = 0x00000040
IMAGE_SCN_CNT_UNINITIALIZED_DATA = 0x00000080
IMAGE_SCN_LNK_COMDAT = 0x00001000
IMAGE_SCN_MEM_DISCARDABLE = 0x02000000
IMAGE_SCN_MEM_EXECUTE = 0x20000000
IMAGE_SCN_MEM_READ = 0x40000000
IMAGE_SCN_MEM_WRITE = 0x80000000
)
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package pe
import (
"bytes"
"encoding/binary"
"fmt"
"internal/saferio"
"io"
)
// cstring converts ASCII byte sequence b to string.
// It stops once it finds 0 or reaches end of b.
func cstring(b []byte) string {
i := bytes.IndexByte(b, 0)
if i == -1 {
i = len(b)
}
return string(b[:i])
}
// StringTable is a COFF string table.
type StringTable []byte
func readStringTable(fh *FileHeader, r io.ReadSeeker) (StringTable, error) {
// COFF string table is located right after COFF symbol table.
if fh.PointerToSymbolTable <= 0 {
return nil, nil
}
offset := fh.PointerToSymbolTable + COFFSymbolSize*fh.NumberOfSymbols
_, err := r.Seek(int64(offset), io.SeekStart)
if err != nil {
return nil, fmt.Errorf("fail to seek to string table: %v", err)
}
var l uint32
err = binary.Read(r, binary.LittleEndian, &l)
if err != nil {
return nil, fmt.Errorf("fail to read string table length: %v", err)
}
// string table length includes itself
if l <= 4 {
return nil, nil
}
l -= 4
buf, err := saferio.ReadData(r, uint64(l))
if err != nil {
return nil, fmt.Errorf("fail to read string table: %v", err)
}
return StringTable(buf), nil
}
// TODO(brainman): decide if start parameter should be int instead of uint32
// String extracts string from COFF string table st at offset start.
func (st StringTable) String(start uint32) (string, error) {
// start includes 4 bytes of string table length
if start < 4 {
return "", fmt.Errorf("offset %d is before the start of string table", start)
}
start -= 4
if int(start) > len(st) {
return "", fmt.Errorf("offset %d is beyond the end of string table", start)
}
return cstring(st[start:]), nil
}
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package pe
import (
"encoding/binary"
"errors"
"fmt"
"internal/saferio"
"io"
"unsafe"
)
const COFFSymbolSize = 18
// COFFSymbol represents single COFF symbol table record.
type COFFSymbol struct {
Name [8]uint8
Value uint32
SectionNumber int16
Type uint16
StorageClass uint8
NumberOfAuxSymbols uint8
}
// readCOFFSymbols reads in the symbol table for a PE file, returning
// a slice of COFFSymbol objects. The PE format includes both primary
// symbols (whose fields are described by COFFSymbol above) and
// auxiliary symbols; all symbols are 18 bytes in size. The auxiliary
// symbols for a given primary symbol are placed following it in the
// array, e.g.
//
// ...
// k+0: regular sym k
// k+1: 1st aux symbol for k
// k+2: 2nd aux symbol for k
// k+3: regular sym k+3
// k+4: 1st aux symbol for k+3
// k+5: regular sym k+5
// k+6: regular sym k+6
//
// The PE format allows for several possible aux symbol formats. For
// more info see:
//
// https://docs.microsoft.com/en-us/windows/win32/debug/pe-format#auxiliary-symbol-records
//
// At the moment this package only provides APIs for looking at
// aux symbols of format 5 (associated with section definition symbols).
func readCOFFSymbols(fh *FileHeader, r io.ReadSeeker) ([]COFFSymbol, error) {
if fh.PointerToSymbolTable == 0 {
return nil, nil
}
if fh.NumberOfSymbols <= 0 {
return nil, nil
}
_, err := r.Seek(int64(fh.PointerToSymbolTable), io.SeekStart)
if err != nil {
return nil, fmt.Errorf("fail to seek to symbol table: %v", err)
}
c := saferio.SliceCap[COFFSymbol](uint64(fh.NumberOfSymbols))
if c < 0 {
return nil, errors.New("too many symbols; file may be corrupt")
}
syms := make([]COFFSymbol, 0, c)
naux := 0
for k := uint32(0); k < fh.NumberOfSymbols; k++ {
var sym COFFSymbol
if naux == 0 {
// Read a primary symbol.
err = binary.Read(r, binary.LittleEndian, &sym)
if err != nil {
return nil, fmt.Errorf("fail to read symbol table: %v", err)
}
// Record how many auxiliary symbols it has.
naux = int(sym.NumberOfAuxSymbols)
} else {
// Read an aux symbol. At the moment we assume all
// aux symbols are format 5 (obviously this doesn't always
// hold; more cases will be needed below if more aux formats
// are supported in the future).
naux--
aux := (*COFFSymbolAuxFormat5)(unsafe.Pointer(&sym))
err = binary.Read(r, binary.LittleEndian, aux)
if err != nil {
return nil, fmt.Errorf("fail to read symbol table: %v", err)
}
}
syms = append(syms, sym)
}
if naux != 0 {
return nil, fmt.Errorf("fail to read symbol table: %d aux symbols unread", naux)
}
return syms, nil
}
// isSymNameOffset reports whether the symbol name is encoded as an offset into the string table and, if so, returns that offset.
func isSymNameOffset(name [8]byte) (bool, uint32) {
if name[0] == 0 && name[1] == 0 && name[2] == 0 && name[3] == 0 {
return true, binary.LittleEndian.Uint32(name[4:])
}
return false, 0
}
// FullName finds the real name of symbol sym. Normally the name is stored
// in sym.Name, but if it is longer than 8 characters, it is stored
// in the COFF string table st instead.
func (sym *COFFSymbol) FullName(st StringTable) (string, error) {
if ok, offset := isSymNameOffset(sym.Name); ok {
return st.String(offset)
}
return cstring(sym.Name[:]), nil
}
func removeAuxSymbols(allsyms []COFFSymbol, st StringTable) ([]*Symbol, error) {
if len(allsyms) == 0 {
return nil, nil
}
syms := make([]*Symbol, 0)
aux := uint8(0)
for _, sym := range allsyms {
if aux > 0 {
aux--
continue
}
name, err := sym.FullName(st)
if err != nil {
return nil, err
}
aux = sym.NumberOfAuxSymbols
s := &Symbol{
Name: name,
Value: sym.Value,
SectionNumber: sym.SectionNumber,
Type: sym.Type,
StorageClass: sym.StorageClass,
}
syms = append(syms, s)
}
return syms, nil
}
// Symbol is similar to [COFFSymbol] with Name field replaced
// by Go string. Symbol also does not have NumberOfAuxSymbols.
type Symbol struct {
Name string
Value uint32
SectionNumber int16
Type uint16
StorageClass uint8
}
// COFFSymbolAuxFormat5 describes the expected form of an aux symbol
// attached to a section definition symbol. The PE format defines a
// number of different aux symbol formats: format 1 for function
// definitions, format 2 for .be and .ef symbols, and so on. Format 5
// holds extra info associated with a section definition, including
// number of relocations + line numbers, as well as COMDAT info. See
// https://docs.microsoft.com/en-us/windows/win32/debug/pe-format#auxiliary-format-5-section-definitions
// for more on what's going on here.
type COFFSymbolAuxFormat5 struct {
Size uint32
NumRelocs uint16
NumLineNumbers uint16
Checksum uint32
SecNum uint16
Selection uint8
_ [3]uint8 // padding
}
// These constants make up the possible values for the 'Selection'
// field in an AuxFormat5.
const (
IMAGE_COMDAT_SELECT_NODUPLICATES = 1
IMAGE_COMDAT_SELECT_ANY = 2
IMAGE_COMDAT_SELECT_SAME_SIZE = 3
IMAGE_COMDAT_SELECT_EXACT_MATCH = 4
IMAGE_COMDAT_SELECT_ASSOCIATIVE = 5
IMAGE_COMDAT_SELECT_LARGEST = 6
)
// COFFSymbolReadSectionDefAux returns a blob of auxiliary information
// (including COMDAT info) for a section definition symbol. Here 'idx'
// is the index of a section symbol in the main [COFFSymbol] array for
// the File. Return value is a pointer to the appropriate aux symbol
// struct. For more info, see:
//
// auxiliary symbols: https://docs.microsoft.com/en-us/windows/win32/debug/pe-format#auxiliary-symbol-records
// COMDAT sections: https://docs.microsoft.com/en-us/windows/win32/debug/pe-format#comdat-sections-object-only
// auxiliary info for section definitions: https://docs.microsoft.com/en-us/windows/win32/debug/pe-format#auxiliary-format-5-section-definitions
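//
// A minimal usage sketch (assuming f is an open *File and idx is the index of
// a section definition symbol; both names are illustrative only):
//
//	aux, err := f.COFFSymbolReadSectionDefAux(idx)
//	if err == nil && aux.Selection == IMAGE_COMDAT_SELECT_ASSOCIATIVE {
//		// the COMDAT section is associated with the section numbered aux.SecNum
//	}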
func (f *File) COFFSymbolReadSectionDefAux(idx int) (*COFFSymbolAuxFormat5, error) {
var rv *COFFSymbolAuxFormat5
if idx < 0 || idx >= len(f.COFFSymbols) {
return rv, fmt.Errorf("invalid symbol index")
}
pesym := &f.COFFSymbols[idx]
const IMAGE_SYM_CLASS_STATIC = 3
if pesym.StorageClass != uint8(IMAGE_SYM_CLASS_STATIC) {
return rv, fmt.Errorf("incorrect symbol storage class")
}
if pesym.NumberOfAuxSymbols == 0 || idx+1 >= len(f.COFFSymbols) {
return rv, fmt.Errorf("aux symbol unavailable")
}
// Locate and return a pointer to the successor aux symbol.
pesymn := &f.COFFSymbols[idx+1]
rv = (*COFFSymbolAuxFormat5)(unsafe.Pointer(pesymn))
return rv, nil
}
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
/*
Package plan9obj implements access to Plan 9 a.out object files.
# Security
This package is not designed to be hardened against adversarial inputs, and is
outside the scope of https://go.dev/security/policy. In particular, only basic
validation is done when parsing object files. As such, care should be taken when
parsing untrusted inputs, as parsing malformed files may consume significant
resources, or cause panics.
*/
package plan9obj
import (
"encoding/binary"
"errors"
"fmt"
"internal/saferio"
"io"
"os"
)
// A FileHeader represents a Plan 9 a.out file header.
type FileHeader struct {
Magic uint32
Bss uint32
Entry uint64
PtrSize int
LoadAddress uint64
HdrSize uint64
}
// A File represents an open Plan 9 a.out file.
type File struct {
FileHeader
Sections []*Section
closer io.Closer
}
// A SectionHeader represents a single Plan 9 a.out section header.
// This structure doesn't exist on-disk, but eases navigation
// through the object file.
type SectionHeader struct {
Name string
Size uint32
Offset uint32
}
// A Section represents a single section in a Plan 9 a.out file.
type Section struct {
SectionHeader
// Embed ReaderAt for ReadAt method.
// Do not embed SectionReader directly
// to avoid having Read and Seek.
// If a client wants Read and Seek it must use
// Open() to avoid fighting over the seek offset
// with other clients.
io.ReaderAt
sr *io.SectionReader
}
// Data reads and returns the contents of the Plan 9 a.out section.
func (s *Section) Data() ([]byte, error) {
return saferio.ReadDataAt(s.sr, uint64(s.Size), 0)
}
// Open returns a new ReadSeeker reading the Plan 9 a.out section.
func (s *Section) Open() io.ReadSeeker { return io.NewSectionReader(s.sr, 0, 1<<63-1) }
// A Sym represents an entry in a Plan 9 a.out symbol table section.
type Sym struct {
Value uint64
Type rune
Name string
}
/*
* Plan 9 a.out reader
*/
// formatError is returned by some operations if the data does
// not have the correct format for an object file.
type formatError struct {
off int
msg string
val any
}
func (e *formatError) Error() string {
msg := e.msg
if e.val != nil {
msg += fmt.Sprintf(" '%v'", e.val)
}
msg += fmt.Sprintf(" in record at byte %#x", e.off)
return msg
}
// Open opens the named file using [os.Open] and prepares it for use as a Plan 9 a.out binary.
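//
// A minimal usage sketch (the path below is hypothetical):
//
//	f, err := Open("/path/to/a.out")
//	if err == nil {
//		defer f.Close()
//		syms, _ := f.Symbols()
//		_ = syms
//	}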
func Open(name string) (*File, error) {
f, err := os.Open(name)
if err != nil {
return nil, err
}
ff, err := NewFile(f)
if err != nil {
f.Close()
return nil, err
}
ff.closer = f
return ff, nil
}
// Close closes the [File].
// If the [File] was created using [NewFile] directly instead of [Open],
// Close has no effect.
func (f *File) Close() error {
var err error
if f.closer != nil {
err = f.closer.Close()
f.closer = nil
}
return err
}
func parseMagic(magic []byte) (uint32, error) {
m := binary.BigEndian.Uint32(magic)
switch m {
case Magic386, MagicAMD64, MagicARM:
return m, nil
}
return 0, &formatError{0, "bad magic number", magic}
}
// NewFile creates a new [File] for accessing a Plan 9 binary in an underlying reader.
// The Plan 9 binary is expected to start at position 0 in the ReaderAt.
func NewFile(r io.ReaderAt) (*File, error) {
sr := io.NewSectionReader(r, 0, 1<<63-1)
// Read and decode Plan 9 magic
var magic [4]byte
if _, err := r.ReadAt(magic[:], 0); err != nil {
return nil, err
}
_, err := parseMagic(magic[:])
if err != nil {
return nil, err
}
ph := new(prog)
if err := binary.Read(sr, binary.BigEndian, ph); err != nil {
return nil, err
}
f := &File{FileHeader: FileHeader{
Magic: ph.Magic,
Bss: ph.Bss,
Entry: uint64(ph.Entry),
PtrSize: 4,
LoadAddress: 0x1000,
HdrSize: 4 * 8,
}}
if ph.Magic&Magic64 != 0 {
if err := binary.Read(sr, binary.BigEndian, &f.Entry); err != nil {
return nil, err
}
f.PtrSize = 8
f.LoadAddress = 0x200000
f.HdrSize += 8
}
var sects = []struct {
name string
size uint32
}{
{"text", ph.Text},
{"data", ph.Data},
{"syms", ph.Syms},
{"spsz", ph.Spsz},
{"pcsz", ph.Pcsz},
}
f.Sections = make([]*Section, 5)
off := uint32(f.HdrSize)
for i, sect := range sects {
s := new(Section)
s.SectionHeader = SectionHeader{
Name: sect.name,
Size: sect.size,
Offset: off,
}
off += sect.size
s.sr = io.NewSectionReader(r, int64(s.Offset), int64(s.Size))
s.ReaderAt = s.sr
f.Sections[i] = s
}
return f, nil
}
func walksymtab(data []byte, ptrsz int, fn func(sym) error) error {
var order binary.ByteOrder = binary.BigEndian
var s sym
p := data
for len(p) >= 4 {
// Symbol type, value.
if len(p) < ptrsz {
return &formatError{len(data), "unexpected EOF", nil}
}
// fixed-width value
if ptrsz == 8 {
s.value = order.Uint64(p[0:8])
p = p[8:]
} else {
s.value = uint64(order.Uint32(p[0:4]))
p = p[4:]
}
if len(p) < 1 {
return &formatError{len(data), "unexpected EOF", nil}
}
typ := p[0] & 0x7F
s.typ = typ
p = p[1:]
// Name.
var i int
var nnul int
for i = 0; i < len(p); i++ {
if p[i] == 0 {
nnul = 1
break
}
}
switch typ {
case 'z', 'Z':
p = p[i+nnul:]
for i = 0; i+2 <= len(p); i += 2 {
if p[i] == 0 && p[i+1] == 0 {
nnul = 2
break
}
}
}
if len(p) < i+nnul {
return &formatError{len(data), "unexpected EOF", nil}
}
s.name = p[0:i]
i += nnul
p = p[i:]
fn(s)
}
return nil
}
// newTable decodes the Go symbol table in symtab,
// returning an in-memory representation.
func newTable(symtab []byte, ptrsz int) ([]Sym, error) {
var n int
err := walksymtab(symtab, ptrsz, func(s sym) error {
n++
return nil
})
if err != nil {
return nil, err
}
fname := make(map[uint16]string)
syms := make([]Sym, 0, n)
err = walksymtab(symtab, ptrsz, func(s sym) error {
n := len(syms)
syms = syms[0 : n+1]
ts := &syms[n]
ts.Type = rune(s.typ)
ts.Value = s.value
switch s.typ {
default:
ts.Name = string(s.name)
case 'z', 'Z':
for i := 0; i < len(s.name); i += 2 {
eltIdx := binary.BigEndian.Uint16(s.name[i : i+2])
elt, ok := fname[eltIdx]
if !ok {
return &formatError{-1, "bad filename code", eltIdx}
}
if n := len(ts.Name); n > 0 && ts.Name[n-1] != '/' {
ts.Name += "/"
}
ts.Name += elt
}
}
switch s.typ {
case 'f':
fname[uint16(s.value)] = ts.Name
}
return nil
})
if err != nil {
return nil, err
}
return syms, nil
}
// ErrNoSymbols is returned by [File.Symbols] if there is no such section
// in the File.
var ErrNoSymbols = errors.New("no symbol section")
// Symbols returns the symbol table for f.
func (f *File) Symbols() ([]Sym, error) {
symtabSection := f.Section("syms")
if symtabSection == nil {
return nil, ErrNoSymbols
}
symtab, err := symtabSection.Data()
if err != nil {
return nil, errors.New("cannot load symbol section")
}
return newTable(symtab, f.PtrSize)
}
// Section returns a section with the given name, or nil if no such
// section exists.
func (f *File) Section(name string) *Section {
for _, s := range f.Sections {
if s.Name == name {
return s
}
}
return nil
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package ascii85 implements the ascii85 data encoding
// as used in the btoa tool and Adobe's PostScript and PDF document formats.
package ascii85
import (
"io"
"strconv"
)
/*
* Encoder
*/
// Encode encodes src into at most [MaxEncodedLen](len(src))
// bytes of dst, returning the actual number of bytes written.
//
// The encoding handles 4-byte chunks, using a special encoding
// for the last fragment, so Encode is not appropriate for use on
// individual blocks of a large data stream. Use [NewEncoder] instead.
//
// Often, ascii85-encoded data is wrapped in <~ and ~> symbols.
// Encode does not add these.
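//
// A minimal usage sketch (src is whatever data is being encoded):
//
//	dst := make([]byte, MaxEncodedLen(len(src)))
//	n := Encode(dst, src)
//	encoded := dst[:n]
//	_ = encoded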
func Encode(dst, src []byte) int {
if len(src) == 0 {
return 0
}
n := 0
for len(src) > 0 {
dst[0] = 0
dst[1] = 0
dst[2] = 0
dst[3] = 0
dst[4] = 0
// Unpack 4 bytes into uint32 to repack into base 85 5-byte.
var v uint32
switch len(src) {
default:
v |= uint32(src[3])
fallthrough
case 3:
v |= uint32(src[2]) << 8
fallthrough
case 2:
v |= uint32(src[1]) << 16
fallthrough
case 1:
v |= uint32(src[0]) << 24
}
// Special case: zero (!!!!!) shortens to z.
if v == 0 && len(src) >= 4 {
dst[0] = 'z'
dst = dst[1:]
src = src[4:]
n++
continue
}
// Otherwise, 5 base 85 digits starting at !.
for i := 4; i >= 0; i-- {
dst[i] = '!' + byte(v%85)
v /= 85
}
// If src was short, discard the low destination bytes.
m := 5
if len(src) < 4 {
m -= 4 - len(src)
src = nil
} else {
src = src[4:]
}
dst = dst[m:]
n += m
}
return n
}
// MaxEncodedLen returns the maximum length of an encoding of n source bytes.
func MaxEncodedLen(n int) int { return (n + 3) / 4 * 5 }
// NewEncoder returns a new ascii85 stream encoder. Data written to
// the returned writer will be encoded and then written to w.
// Ascii85 encodings operate in 32-bit blocks; when finished
// writing, the caller must Close the returned encoder to flush any
// trailing partial block.
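//
// A minimal streaming sketch (w is any io.Writer and data is the payload to
// encode; both names are illustrative only):
//
//	enc := NewEncoder(w)
//	_, err := enc.Write(data)
//	if err == nil {
//		err = enc.Close() // Close flushes the trailing partial block
//	}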
func NewEncoder(w io.Writer) io.WriteCloser { return &encoder{w: w} }
type encoder struct {
err error
w io.Writer
buf [4]byte // buffered data waiting to be encoded
nbuf int // number of bytes in buf
out [1024]byte // output buffer
}
func (e *encoder) Write(p []byte) (n int, err error) {
if e.err != nil {
return 0, e.err
}
// Leading fringe.
if e.nbuf > 0 {
var i int
for i = 0; i < len(p) && e.nbuf < 4; i++ {
e.buf[e.nbuf] = p[i]
e.nbuf++
}
n += i
p = p[i:]
if e.nbuf < 4 {
return
}
nout := Encode(e.out[0:], e.buf[0:])
if _, e.err = e.w.Write(e.out[0:nout]); e.err != nil {
return n, e.err
}
e.nbuf = 0
}
// Large interior chunks.
for len(p) >= 4 {
nn := len(e.out) / 5 * 4
if nn > len(p) {
nn = len(p)
}
nn -= nn % 4
if nn > 0 {
nout := Encode(e.out[0:], p[0:nn])
if _, e.err = e.w.Write(e.out[0:nout]); e.err != nil {
return n, e.err
}
}
n += nn
p = p[nn:]
}
// Trailing fringe.
copy(e.buf[:], p)
e.nbuf = len(p)
n += len(p)
return
}
// Close flushes any pending output from the encoder.
// It is an error to call Write after calling Close.
func (e *encoder) Close() error {
// If there's anything left in the buffer, flush it out
if e.err == nil && e.nbuf > 0 {
nout := Encode(e.out[0:], e.buf[0:e.nbuf])
e.nbuf = 0
_, e.err = e.w.Write(e.out[0:nout])
}
return e.err
}
/*
* Decoder
*/
type CorruptInputError int64
func (e CorruptInputError) Error() string {
return "illegal ascii85 data at input byte " + strconv.FormatInt(int64(e), 10)
}
// Decode decodes src into dst, returning both the number
// of bytes written to dst and the number consumed from src.
// If src contains invalid ascii85 data, Decode will return the
// number of bytes successfully written and a [CorruptInputError].
// Decode ignores space and control characters in src.
// Often, ascii85-encoded data is wrapped in <~ and ~> symbols.
// Decode expects these to have been stripped by the caller.
//
// If flush is true, Decode assumes that src represents the
// end of the input stream and processes it completely rather
// than wait for the completion of another 32-bit block.
//
// [NewDecoder] wraps an [io.Reader] interface around Decode.
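//
// A minimal usage sketch (src holds ascii85 text with any <~ and ~> wrappers
// already stripped; the generous dst size is a safe upper bound):
//
//	dst := make([]byte, 4*len(src))
//	ndst, _, err := Decode(dst, src, true)
//	if err == nil {
//		decoded := dst[:ndst]
//		_ = decoded
//	}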
func Decode(dst, src []byte, flush bool) (ndst, nsrc int, err error) {
var v uint32
var nb int
for i, b := range src {
if len(dst)-ndst < 4 {
return
}
switch {
case b <= ' ':
continue
case b == 'z' && nb == 0:
nb = 5
v = 0
case '!' <= b && b <= 'u':
v = v*85 + uint32(b-'!')
nb++
default:
return 0, 0, CorruptInputError(i)
}
if nb == 5 {
nsrc = i + 1
dst[ndst] = byte(v >> 24)
dst[ndst+1] = byte(v >> 16)
dst[ndst+2] = byte(v >> 8)
dst[ndst+3] = byte(v)
ndst += 4
nb = 0
v = 0
}
}
if flush {
nsrc = len(src)
if nb > 0 {
// The number of output bytes in the last fragment
// is the number of leftover input bytes - 1:
// the extra byte provides enough bits to cover
// the inefficiency of the encoding for the block.
if nb == 1 {
return 0, 0, CorruptInputError(len(src))
}
for i := nb; i < 5; i++ {
// The short encoding truncated the output value.
// We have to assume the worst case values (digit 84)
// in order to ensure that the top bits are correct.
v = v*85 + 84
}
for i := 0; i < nb-1; i++ {
dst[ndst] = byte(v >> 24)
v <<= 8
ndst++
}
}
}
return
}
// NewDecoder constructs a new ascii85 stream decoder.
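//
// A minimal usage sketch (r is any io.Reader producing ascii85 text):
//
//	data, err := io.ReadAll(NewDecoder(r))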
func NewDecoder(r io.Reader) io.Reader { return &decoder{r: r} }
type decoder struct {
err error
readErr error
r io.Reader
buf [1024]byte // leftover input
nbuf int
out []byte // leftover decoded output
outbuf [1024]byte
}
func (d *decoder) Read(p []byte) (n int, err error) {
if len(p) == 0 {
return 0, nil
}
if d.err != nil {
return 0, d.err
}
for {
// Copy leftover output from last decode.
if len(d.out) > 0 {
n = copy(p, d.out)
d.out = d.out[n:]
return
}
// Decode leftover input from last read.
var nn, nsrc, ndst int
if d.nbuf > 0 {
ndst, nsrc, d.err = Decode(d.outbuf[0:], d.buf[0:d.nbuf], d.readErr != nil)
if ndst > 0 {
d.out = d.outbuf[0:ndst]
d.nbuf = copy(d.buf[0:], d.buf[nsrc:d.nbuf])
continue // copy out and return
}
if ndst == 0 && d.err == nil {
// Special case: input buffer is mostly filled with non-data bytes.
// Filter out such bytes to make room for more input.
off := 0
for i := 0; i < d.nbuf; i++ {
if d.buf[i] > ' ' {
d.buf[off] = d.buf[i]
off++
}
}
d.nbuf = off
}
}
// Out of input, out of decoded output. Check errors.
if d.err != nil {
return 0, d.err
}
if d.readErr != nil {
d.err = d.readErr
return 0, d.err
}
// Read more data.
nn, d.readErr = d.r.Read(d.buf[d.nbuf:])
d.nbuf += nn
}
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package asn1 implements parsing of DER-encoded ASN.1 data structures,
// as defined in ITU-T Rec X.690.
//
// See also “A Layman's Guide to a Subset of ASN.1, BER, and DER,”
// http://luca.ntop.org/Teaching/Appunti/asn1.html.
package asn1
// ASN.1 is a syntax for specifying abstract objects and BER, DER, PER, XER etc
// are different encoding formats for those objects. Here, we'll be dealing
// with DER, the Distinguished Encoding Rules. DER is used in X.509 because
// it's fast to parse and, unlike BER, has a unique encoding for every object.
// When calculating hashes over objects, it's important that the resulting
// bytes be the same at both ends and DER removes this margin of error.
//
// ASN.1 is very complex and this package doesn't attempt to implement
// everything by any means.
import (
"errors"
"fmt"
"math"
"math/big"
"reflect"
"slices"
"strconv"
"strings"
"time"
"unicode/utf16"
"unicode/utf8"
)
// A StructuralError suggests that the ASN.1 data is valid, but the Go type
// which is receiving it doesn't match.
type StructuralError struct {
Msg string
}
func (e StructuralError) Error() string { return "asn1: structure error: " + e.Msg }
// A SyntaxError suggests that the ASN.1 data is invalid.
type SyntaxError struct {
Msg string
}
func (e SyntaxError) Error() string { return "asn1: syntax error: " + e.Msg }
// We start by dealing with each of the primitive types in turn.
// BOOLEAN
func parseBool(bytes []byte) (ret bool, err error) {
if len(bytes) != 1 {
err = SyntaxError{"invalid boolean"}
return
}
// DER demands that "If the encoding represents the boolean value TRUE,
// its single contents octet shall have all eight bits set to one."
// Thus only 0 and 255 are valid encoded values.
switch bytes[0] {
case 0:
ret = false
case 0xff:
ret = true
default:
err = SyntaxError{"invalid boolean"}
}
return
}
// INTEGER
// checkInteger returns nil if the given bytes are a valid DER-encoded
// INTEGER and an error otherwise.
func checkInteger(bytes []byte) error {
if len(bytes) == 0 {
return StructuralError{"empty integer"}
}
if len(bytes) == 1 {
return nil
}
if (bytes[0] == 0 && bytes[1]&0x80 == 0) || (bytes[0] == 0xff && bytes[1]&0x80 == 0x80) {
return StructuralError{"integer not minimally-encoded"}
}
return nil
}
// parseInt64 treats the given bytes as a big-endian, signed integer and
// returns the result.
func parseInt64(bytes []byte) (ret int64, err error) {
err = checkInteger(bytes)
if err != nil {
return
}
if len(bytes) > 8 {
// We'll overflow an int64 in this case.
err = StructuralError{"integer too large"}
return
}
for bytesRead := 0; bytesRead < len(bytes); bytesRead++ {
ret <<= 8
ret |= int64(bytes[bytesRead])
}
// Shift up and down in order to sign extend the result.
ret <<= 64 - uint8(len(bytes))*8
ret >>= 64 - uint8(len(bytes))*8
return
}
// parseInt32 treats the given bytes as a big-endian, signed integer and returns
// the result.
func parseInt32(bytes []byte) (int32, error) {
if err := checkInteger(bytes); err != nil {
return 0, err
}
ret64, err := parseInt64(bytes)
if err != nil {
return 0, err
}
if ret64 != int64(int32(ret64)) {
return 0, StructuralError{"integer too large"}
}
return int32(ret64), nil
}
var bigOne = big.NewInt(1)
// parseBigInt treats the given bytes as a big-endian, signed integer and returns
// the result.
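// For example, the single byte 0x01 decodes to 1, the single byte 0xff decodes
// to -1, and the two bytes 0x00 0x80 decode to 128.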
func parseBigInt(bytes []byte) (*big.Int, error) {
if err := checkInteger(bytes); err != nil {
return nil, err
}
ret := new(big.Int)
if len(bytes) > 0 && bytes[0]&0x80 == 0x80 {
// This is a negative number.
notBytes := make([]byte, len(bytes))
for i := range notBytes {
notBytes[i] = ^bytes[i]
}
ret.SetBytes(notBytes)
ret.Add(ret, bigOne)
ret.Neg(ret)
return ret, nil
}
ret.SetBytes(bytes)
return ret, nil
}
// BIT STRING
// BitString is the structure to use when you want an ASN.1 BIT STRING type. A
// bit string is padded up to the nearest byte in memory and the number of
// valid bits is recorded. Padding bits will be zero.
type BitString struct {
Bytes []byte // bits packed into bytes.
BitLength int // length in bits.
}
// At returns the bit at the given index. If the index is out of range it
// returns 0.
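// For example, BitString{Bytes: []byte{0x80}, BitLength: 1}.At(0) == 1.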
func (b BitString) At(i int) int {
if i < 0 || i >= b.BitLength {
return 0
}
x := i / 8
y := 7 - uint(i%8)
return int(b.Bytes[x]>>y) & 1
}
// RightAlign returns a slice where the padding bits are at the beginning. The
// slice may share memory with the BitString.
func (b BitString) RightAlign() []byte {
shift := uint(8 - (b.BitLength % 8))
if shift == 8 || len(b.Bytes) == 0 {
return b.Bytes
}
a := make([]byte, len(b.Bytes))
a[0] = b.Bytes[0] >> shift
for i := 1; i < len(b.Bytes); i++ {
a[i] = b.Bytes[i-1] << (8 - shift)
a[i] |= b.Bytes[i] >> shift
}
return a
}
// parseBitString parses an ASN.1 bit string from the given byte slice and returns it.
func parseBitString(bytes []byte) (ret BitString, err error) {
if len(bytes) == 0 {
err = SyntaxError{"zero length BIT STRING"}
return
}
paddingBits := int(bytes[0])
if paddingBits > 7 ||
len(bytes) == 1 && paddingBits > 0 ||
bytes[len(bytes)-1]&((1<<bytes[0])-1) != 0 {
err = SyntaxError{"invalid padding bits in BIT STRING"}
return
}
ret.BitLength = (len(bytes)-1)*8 - paddingBits
ret.Bytes = bytes[1:]
return
}
// NULL
// NullRawValue is a [RawValue] with its Tag set to the ASN.1 NULL type tag (5).
var NullRawValue = RawValue{Tag: TagNull}
// NullBytes contains bytes representing the DER-encoded ASN.1 NULL type.
var NullBytes = []byte{TagNull, 0}
// OBJECT IDENTIFIER
// An ObjectIdentifier represents an ASN.1 OBJECT IDENTIFIER.
type ObjectIdentifier []int
// Equal reports whether oi and other represent the same identifier.
func (oi ObjectIdentifier) Equal(other ObjectIdentifier) bool {
return slices.Equal(oi, other)
}
func (oi ObjectIdentifier) String() string {
var s strings.Builder
s.Grow(32)
buf := make([]byte, 0, 19)
for i, v := range oi {
if i > 0 {
s.WriteByte('.')
}
s.Write(strconv.AppendInt(buf, int64(v), 10))
}
return s.String()
}
// parseObjectIdentifier parses an OBJECT IDENTIFIER from the given bytes and
// returns it. An object identifier is a sequence of variable length integers
// that are assigned in a hierarchy.
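// For example, the encoded bytes 0x2a 0x86 0x48 decode to the identifier
// 1.2.840: the first byte 42 packs 40*1 + 2, and 0x86 0x48 is the base-128
// encoding of 840.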
func parseObjectIdentifier(bytes []byte) (s ObjectIdentifier, err error) {
if len(bytes) == 0 {
err = SyntaxError{"zero length OBJECT IDENTIFIER"}
return
}
// In the worst case, we get two elements from the first byte (which is
// encoded differently) and then every varint is a single byte long.
s = make([]int, len(bytes)+1)
// The first varint is 40*value1 + value2:
// According to this packing, value1 can take the values 0, 1 and 2 only.
// When value1 = 0 or value1 = 1, then value2 is <= 39. When value1 = 2,
// then there are no restrictions on value2.
v, offset, err := parseBase128Int(bytes, 0)
if err != nil {
return
}
if v < 80 {
s[0] = v / 40
s[1] = v % 40
} else {
s[0] = 2
s[1] = v - 80
}
i := 2
for ; offset < len(bytes); i++ {
v, offset, err = parseBase128Int(bytes, offset)
if err != nil {
return
}
s[i] = v
}
s = s[0:i]
return
}
// ENUMERATED
// An Enumerated is represented as a plain int.
type Enumerated int
// FLAG
// A Flag accepts any data and is set to true if present.
type Flag bool
// parseBase128Int parses a base-128 encoded int from the given offset in the
// given byte slice. It returns the value and the new offset.
func parseBase128Int(bytes []byte, initOffset int) (ret, offset int, err error) {
offset = initOffset
var ret64 int64
for shifted := 0; offset < len(bytes); shifted++ {
// 5 * 7 bits per byte == 35 bits of data
// Thus the representation is either non-minimal or too large for an int32
if shifted == 5 {
err = StructuralError{"base 128 integer too large"}
return
}
ret64 <<= 7
b := bytes[offset]
// integers should be minimally encoded, so the leading octet should
// never be 0x80
if shifted == 0 && b == 0x80 {
err = SyntaxError{"integer is not minimally encoded"}
return
}
ret64 |= int64(b & 0x7f)
offset++
if b&0x80 == 0 {
ret = int(ret64)
// Ensure that the returned value fits in an int on all platforms
if ret64 > math.MaxInt32 {
err = StructuralError{"base 128 integer too large"}
}
return
}
}
err = SyntaxError{"truncated base 128 integer"}
return
}
// UTCTime
func parseUTCTime(bytes []byte) (ret time.Time, err error) {
s := string(bytes)
formatStr := "0601021504Z0700"
ret, err = time.Parse(formatStr, s)
if err != nil {
formatStr = "060102150405Z0700"
ret, err = time.Parse(formatStr, s)
}
if err != nil {
return
}
if serialized := ret.Format(formatStr); serialized != s {
err = fmt.Errorf("asn1: time did not serialize back to the original value and may be invalid: given %q, but serialized as %q", s, serialized)
return
}
if ret.Year() >= 2050 {
// UTCTime only encodes times prior to 2050. See https://tools.ietf.org/html/rfc5280#section-4.1.2.5.1
ret = ret.AddDate(-100, 0, 0)
}
return
}
// parseGeneralizedTime parses the GeneralizedTime from the given byte slice
// and returns the resulting time.
func parseGeneralizedTime(bytes []byte) (ret time.Time, err error) {
const formatStr = "20060102150405.999999999Z0700"
s := string(bytes)
if ret, err = time.Parse(formatStr, s); err != nil {
return
}
if serialized := ret.Format(formatStr); serialized != s {
err = fmt.Errorf("asn1: time did not serialize back to the original value and may be invalid: given %q, but serialized as %q", s, serialized)
}
return
}
// NumericString
// parseNumericString parses an ASN.1 NumericString from the given byte array
// and returns it.
func parseNumericString(bytes []byte) (ret string, err error) {
for _, b := range bytes {
if !isNumeric(b) {
return "", SyntaxError{"NumericString contains invalid character"}
}
}
return string(bytes), nil
}
// isNumeric reports whether the given b is in the ASN.1 NumericString set.
func isNumeric(b byte) bool {
return '0' <= b && b <= '9' ||
b == ' '
}
// PrintableString
// parsePrintableString parses an ASN.1 PrintableString from the given byte
// array and returns it.
func parsePrintableString(bytes []byte) (ret string, err error) {
for _, b := range bytes {
if !isPrintable(b, allowAsterisk, allowAmpersand) {
err = SyntaxError{"PrintableString contains invalid character"}
return
}
}
ret = string(bytes)
return
}
type asteriskFlag bool
type ampersandFlag bool
const (
allowAsterisk asteriskFlag = true
rejectAsterisk asteriskFlag = false
allowAmpersand ampersandFlag = true
rejectAmpersand ampersandFlag = false
)
// isPrintable reports whether the given b is in the ASN.1 PrintableString set.
// If asterisk is allowAsterisk then '*' is also allowed, reflecting existing
// practice. If ampersand is allowAmpersand then '&' is allowed as well.
func isPrintable(b byte, asterisk asteriskFlag, ampersand ampersandFlag) bool {
return 'a' <= b && b <= 'z' ||
'A' <= b && b <= 'Z' ||
'0' <= b && b <= '9' ||
'\'' <= b && b <= ')' ||
'+' <= b && b <= '/' ||
b == ' ' ||
b == ':' ||
b == '=' ||
b == '?' ||
// This is technically not allowed in a PrintableString.
// However, x509 certificates with wildcard strings don't
// always use the correct string type so we permit it.
(bool(asterisk) && b == '*') ||
// This is not technically allowed either. However, not
// only is it relatively common, but there are also a
// handful of CA certificates that contain it. At least
// one of which will not expire until 2027.
(bool(ampersand) && b == '&')
}
// IA5String
// parseIA5String parses an ASN.1 IA5String (ASCII string) from the given
// byte slice and returns it.
func parseIA5String(bytes []byte) (ret string, err error) {
for _, b := range bytes {
if b >= utf8.RuneSelf {
err = SyntaxError{"IA5String contains invalid character"}
return
}
}
ret = string(bytes)
return
}
// T61String
// parseT61String parses an ASN.1 T61String (8-bit clean string) from the given
// byte slice and returns it.
func parseT61String(bytes []byte) (ret string, err error) {
// T.61 is a defunct ITU 8-bit character encoding which preceded Unicode.
// T.61 uses a code page layout that _almost_ exactly maps to the code
// page layout of the ISO 8859-1 (Latin-1) character encoding, with the
// exception that a number of characters in Latin-1 are not present
// in T.61.
//
// Instead of mapping which characters are present in Latin-1 but not T.61,
// we just treat these strings as being encoded using Latin-1. This matches
// what most of the world does, including BoringSSL.
buf := make([]byte, 0, len(bytes))
for _, v := range bytes {
// All the 1-byte UTF-8 runes map 1-1 with Latin-1.
buf = utf8.AppendRune(buf, rune(v))
}
return string(buf), nil
}
// UTF8String
// parseUTF8String parses an ASN.1 UTF8String (raw UTF-8) from the given byte
// array and returns it.
func parseUTF8String(bytes []byte) (ret string, err error) {
if !utf8.Valid(bytes) {
return "", errors.New("asn1: invalid UTF-8 string")
}
return string(bytes), nil
}
// BMPString
// parseBMPString parses an ASN.1 BMPString (Basic Multilingual Plane of
// ISO/IEC/ITU 10646-1) from the given byte slice and returns it.
func parseBMPString(bmpString []byte) (string, error) {
// BMPString uses the defunct UCS-2 16-bit character encoding, which
// covers the Basic Multilingual Plane (BMP). UTF-16 was an extension of
// UCS-2, containing all of the same code points, but also including
// multi-code point characters (by using surrogate code points). We can
// treat a UCS-2 encoded string as a UTF-16 encoded string, as long as
// we reject out the UTF-16 specific code points. This matches the
// BoringSSL behavior.
if len(bmpString)%2 != 0 {
return "", errors.New("invalid BMPString")
}
// Strip terminator if present.
if l := len(bmpString); l >= 2 && bmpString[l-1] == 0 && bmpString[l-2] == 0 {
bmpString = bmpString[:l-2]
}
s := make([]uint16, 0, len(bmpString)/2)
for len(bmpString) > 0 {
point := uint16(bmpString[0])<<8 + uint16(bmpString[1])
// Reject UTF-16 code points that are permanently reserved
// noncharacters (0xfffe, 0xffff, and 0xfdd0-0xfdef) and surrogates
// (0xd800-0xdfff).
if point == 0xfffe || point == 0xffff ||
(point >= 0xfdd0 && point <= 0xfdef) ||
(point >= 0xd800 && point <= 0xdfff) {
return "", errors.New("invalid BMPString")
}
s = append(s, point)
bmpString = bmpString[2:]
}
return string(utf16.Decode(s)), nil
}
// A RawValue represents an undecoded ASN.1 object.
type RawValue struct {
Class, Tag int
IsCompound bool
Bytes []byte
FullBytes []byte // includes the tag and length
}
// RawContent is used to signal that the undecoded, DER data needs to be
// preserved for a struct. To use it, the first field of the struct must have
// this type. It's an error for any of the other fields to have this type.
type RawContent []byte
// Tagging
// parseTagAndLength parses an ASN.1 tag and length pair from the given offset
// into a byte slice. It returns the parsed data and the new offset. SET and
// SET OF (tag 17) are mapped to SEQUENCE and SEQUENCE OF (tag 16) since we
// don't distinguish between ordered and unordered objects in this code.
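// For example, the two bytes 0x30 0x03 parse as class 0 (universal),
// isCompound true, tag 16 (SEQUENCE) and length 3.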
func parseTagAndLength(bytes []byte, initOffset int) (ret tagAndLength, offset int, err error) {
offset = initOffset
// parseTagAndLength should not be called without at least a single
// byte to read. Thus this check is for robustness:
if offset >= len(bytes) {
err = errors.New("asn1: internal error in parseTagAndLength")
return
}
b := bytes[offset]
offset++
ret.class = int(b >> 6)
ret.isCompound = b&0x20 == 0x20
ret.tag = int(b & 0x1f)
// If the bottom five bits are set, then the tag number is actually base 128
// encoded afterwards
if ret.tag == 0x1f {
ret.tag, offset, err = parseBase128Int(bytes, offset)
if err != nil {
return
}
// Tags should be encoded in minimal form.
if ret.tag < 0x1f {
err = SyntaxError{"non-minimal tag"}
return
}
}
if offset >= len(bytes) {
err = SyntaxError{"truncated tag or length"}
return
}
b = bytes[offset]
offset++
if b&0x80 == 0 {
// The length is encoded in the bottom 7 bits.
ret.length = int(b & 0x7f)
} else {
// Bottom 7 bits give the number of length bytes to follow.
numBytes := int(b & 0x7f)
if numBytes == 0 {
err = SyntaxError{"indefinite length found (not DER)"}
return
}
ret.length = 0
for i := 0; i < numBytes; i++ {
if offset >= len(bytes) {
err = SyntaxError{"truncated tag or length"}
return
}
b = bytes[offset]
offset++
if ret.length >= 1<<23 {
// We can't shift ret.length up without
// overflowing.
err = StructuralError{"length too large"}
return
}
ret.length <<= 8
ret.length |= int(b)
if ret.length == 0 {
// DER requires that lengths be minimal.
err = StructuralError{"superfluous leading zeros in length"}
return
}
}
// Short lengths must be encoded in short form.
if ret.length < 0x80 {
err = StructuralError{"non-minimal length"}
return
}
}
return
}
// parseSequenceOf is used for SEQUENCE OF and SET OF values. It tries to parse
// a number of ASN.1 values from the given byte slice and returns them as a
// slice of Go values of the given type.
func parseSequenceOf(bytes []byte, sliceType reflect.Type, elemType reflect.Type) (ret reflect.Value, err error) {
matchAny, expectedTag, compoundType, ok := getUniversalType(elemType)
if !ok {
err = StructuralError{"unknown Go type for slice"}
return
}
// First we iterate over the input and count the number of elements,
// checking that the types are correct in each case.
numElements := 0
for offset := 0; offset < len(bytes); {
var t tagAndLength
t, offset, err = parseTagAndLength(bytes, offset)
if err != nil {
return
}
switch t.tag {
case TagIA5String, TagGeneralString, TagT61String, TagUTF8String, TagNumericString, TagBMPString:
// We pretend that various other string types are
// PRINTABLE STRINGs so that a sequence of them can be
// parsed into a []string.
t.tag = TagPrintableString
case TagGeneralizedTime, TagUTCTime:
// Likewise, both time types are treated the same.
t.tag = TagUTCTime
}
if !matchAny && (t.class != ClassUniversal || t.isCompound != compoundType || t.tag != expectedTag) {
err = StructuralError{"sequence tag mismatch"}
return
}
if invalidLength(offset, t.length, len(bytes)) {
err = SyntaxError{"truncated sequence"}
return
}
offset += t.length
numElements++
}
ret = reflect.MakeSlice(sliceType, numElements, numElements)
params := fieldParameters{}
offset := 0
for i := 0; i < numElements; i++ {
offset, err = parseField(ret.Index(i), bytes, offset, params)
if err != nil {
return
}
}
return
}
var (
bitStringType = reflect.TypeFor[BitString]()
objectIdentifierType = reflect.TypeFor[ObjectIdentifier]()
enumeratedType = reflect.TypeFor[Enumerated]()
flagType = reflect.TypeFor[Flag]()
timeType = reflect.TypeFor[time.Time]()
rawValueType = reflect.TypeFor[RawValue]()
rawContentsType = reflect.TypeFor[RawContent]()
bigIntType = reflect.TypeFor[*big.Int]()
)
// invalidLength reports whether offset + length > sliceLength, or if the
// addition would overflow.
func invalidLength(offset, length, sliceLength int) bool {
return offset+length < offset || offset+length > sliceLength
}
// parseField is the main parsing function. Given a byte slice and an offset
// into the array, it will try to parse a suitable ASN.1 value out and store it
// in the given Value.
func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParameters) (offset int, err error) {
offset = initOffset
fieldType := v.Type()
// If we have run out of data, it may be that there are optional elements at the end.
if offset == len(bytes) {
if !setDefaultValue(v, params) {
err = SyntaxError{"sequence truncated"}
}
return
}
// Deal with the ANY type.
if ifaceType := fieldType; ifaceType.Kind() == reflect.Interface && ifaceType.NumMethod() == 0 {
var t tagAndLength
t, offset, err = parseTagAndLength(bytes, offset)
if err != nil {
return
}
if invalidLength(offset, t.length, len(bytes)) {
err = SyntaxError{"data truncated"}
return
}
var result any
if !t.isCompound && t.class == ClassUniversal {
innerBytes := bytes[offset : offset+t.length]
switch t.tag {
case TagBoolean:
result, err = parseBool(innerBytes)
case TagPrintableString:
result, err = parsePrintableString(innerBytes)
case TagNumericString:
result, err = parseNumericString(innerBytes)
case TagIA5String:
result, err = parseIA5String(innerBytes)
case TagT61String:
result, err = parseT61String(innerBytes)
case TagUTF8String:
result, err = parseUTF8String(innerBytes)
case TagInteger:
result, err = parseInt64(innerBytes)
case TagBitString:
result, err = parseBitString(innerBytes)
case TagOID:
result, err = parseObjectIdentifier(innerBytes)
case TagUTCTime:
result, err = parseUTCTime(innerBytes)
case TagGeneralizedTime:
result, err = parseGeneralizedTime(innerBytes)
case TagOctetString:
result = innerBytes
case TagBMPString:
result, err = parseBMPString(innerBytes)
default:
// If we don't know how to handle the type, we just leave Value as nil.
}
}
offset += t.length
if err != nil {
return
}
if result != nil {
v.Set(reflect.ValueOf(result))
}
return
}
t, offset, err := parseTagAndLength(bytes, offset)
if err != nil {
return
}
if params.explicit {
expectedClass := ClassContextSpecific
if params.application {
expectedClass = ClassApplication
}
if offset == len(bytes) {
err = StructuralError{"explicit tag has no child"}
return
}
if t.class == expectedClass && t.tag == *params.tag && (t.length == 0 || t.isCompound) {
if fieldType == rawValueType {
// The inner element should not be parsed for RawValues.
} else if t.length > 0 {
t, offset, err = parseTagAndLength(bytes, offset)
if err != nil {
return
}
} else {
if fieldType != flagType {
err = StructuralError{"zero length explicit tag was not an asn1.Flag"}
return
}
v.SetBool(true)
return
}
} else {
// The tags didn't match, it might be an optional element.
ok := setDefaultValue(v, params)
if ok {
offset = initOffset
} else {
err = StructuralError{"explicitly tagged member didn't match"}
}
return
}
}
matchAny, universalTag, compoundType, ok1 := getUniversalType(fieldType)
if !ok1 {
err = StructuralError{fmt.Sprintf("unknown Go type: %v", fieldType)}
return
}
// Special case for strings: all the ASN.1 string types map to the Go
// type string. getUniversalType returns the tag for PrintableString
// when it sees a string, so if we see a different string type on the
// wire, we change the universal type to match.
if universalTag == TagPrintableString {
if t.class == ClassUniversal {
switch t.tag {
case TagIA5String, TagGeneralString, TagT61String, TagUTF8String, TagNumericString, TagBMPString:
universalTag = t.tag
}
} else if params.stringType != 0 {
universalTag = params.stringType
}
}
// Special case for time: UTCTime and GeneralizedTime both map to the
// Go type time.Time. getUniversalType returns the tag for UTCTime when
// it sees a time.Time, so if we see a different time type on the wire,
// or the field is tagged with a different type, we change the universal
// type to match.
if universalTag == TagUTCTime {
if t.class == ClassUniversal {
if t.tag == TagGeneralizedTime {
universalTag = t.tag
}
} else if params.timeType != 0 {
universalTag = params.timeType
}
}
if params.set {
universalTag = TagSet
}
matchAnyClassAndTag := matchAny
expectedClass := ClassUniversal
expectedTag := universalTag
if !params.explicit && params.tag != nil {
expectedClass = ClassContextSpecific
expectedTag = *params.tag
matchAnyClassAndTag = false
}
if !params.explicit && params.application && params.tag != nil {
expectedClass = ClassApplication
expectedTag = *params.tag
matchAnyClassAndTag = false
}
if !params.explicit && params.private && params.tag != nil {
expectedClass = ClassPrivate
expectedTag = *params.tag
matchAnyClassAndTag = false
}
// We have unwrapped any explicit tagging at this point.
if !matchAnyClassAndTag && (t.class != expectedClass || t.tag != expectedTag) ||
(!matchAny && t.isCompound != compoundType) {
// Tags don't match. Again, it could be an optional element.
ok := setDefaultValue(v, params)
if ok {
offset = initOffset
} else {
err = StructuralError{fmt.Sprintf("tags don't match (%d vs %+v) %+v %s @%d", expectedTag, t, params, fieldType.Name(), offset)}
}
return
}
if invalidLength(offset, t.length, len(bytes)) {
err = SyntaxError{"data truncated"}
return
}
innerBytes := bytes[offset : offset+t.length]
offset += t.length
// We deal with the structures defined in this package first.
switch v := v.Addr().Interface().(type) {
case *RawValue:
*v = RawValue{t.class, t.tag, t.isCompound, innerBytes, bytes[initOffset:offset]}
return
case *ObjectIdentifier:
*v, err = parseObjectIdentifier(innerBytes)
return
case *BitString:
*v, err = parseBitString(innerBytes)
return
case *time.Time:
if universalTag == TagUTCTime {
*v, err = parseUTCTime(innerBytes)
return
}
*v, err = parseGeneralizedTime(innerBytes)
return
case *Enumerated:
parsedInt, err1 := parseInt32(innerBytes)
if err1 == nil {
*v = Enumerated(parsedInt)
}
err = err1
return
case *Flag:
*v = true
return
case **big.Int:
parsedInt, err1 := parseBigInt(innerBytes)
if err1 == nil {
*v = parsedInt
}
err = err1
return
}
switch val := v; val.Kind() {
case reflect.Bool:
parsedBool, err1 := parseBool(innerBytes)
if err1 == nil {
val.SetBool(parsedBool)
}
err = err1
return
case reflect.Int, reflect.Int32, reflect.Int64:
if val.Type().Size() == 4 {
parsedInt, err1 := parseInt32(innerBytes)
if err1 == nil {
val.SetInt(int64(parsedInt))
}
err = err1
} else {
parsedInt, err1 := parseInt64(innerBytes)
if err1 == nil {
val.SetInt(parsedInt)
}
err = err1
}
return
// TODO(dfc) Add support for the remaining integer types
case reflect.Struct:
structType := fieldType
for i := 0; i < structType.NumField(); i++ {
if !structType.Field(i).IsExported() {
err = StructuralError{"struct contains unexported fields"}
return
}
}
if structType.NumField() > 0 &&
structType.Field(0).Type == rawContentsType {
bytes := bytes[initOffset:offset]
val.Field(0).Set(reflect.ValueOf(RawContent(bytes)))
}
innerOffset := 0
for i := 0; i < structType.NumField(); i++ {
field := structType.Field(i)
if i == 0 && field.Type == rawContentsType {
continue
}
innerOffset, err = parseField(val.Field(i), innerBytes, innerOffset, parseFieldParameters(field.Tag.Get("asn1")))
if err != nil {
return
}
}
// We allow extra bytes at the end of the SEQUENCE because
// adding elements to the end has been used in X.509 as the
// version numbers have increased.
return
case reflect.Slice:
sliceType := fieldType
if sliceType.Elem().Kind() == reflect.Uint8 {
val.Set(reflect.MakeSlice(sliceType, len(innerBytes), len(innerBytes)))
reflect.Copy(val, reflect.ValueOf(innerBytes))
return
}
newSlice, err1 := parseSequenceOf(innerBytes, sliceType, sliceType.Elem())
if err1 == nil {
val.Set(newSlice)
}
err = err1
return
case reflect.String:
var v string
switch universalTag {
case TagPrintableString:
v, err = parsePrintableString(innerBytes)
case TagNumericString:
v, err = parseNumericString(innerBytes)
case TagIA5String:
v, err = parseIA5String(innerBytes)
case TagT61String:
v, err = parseT61String(innerBytes)
case TagUTF8String:
v, err = parseUTF8String(innerBytes)
case TagGeneralString:
// GeneralString is specified in ISO-2022/ECMA-35.
// A brief review suggests that it includes structures
// that allow the encoding to change midstring and
// such. We give up and pass it as an 8-bit string.
v, err = parseT61String(innerBytes)
case TagBMPString:
v, err = parseBMPString(innerBytes)
default:
err = SyntaxError{fmt.Sprintf("internal error: unknown string type %d", universalTag)}
}
if err == nil {
val.SetString(v)
}
return
}
err = StructuralError{"unsupported: " + v.Type().String()}
return
}
// canHaveDefaultValue reports whether k is a Kind that we will set a default
// value for. (A signed integer, essentially.)
func canHaveDefaultValue(k reflect.Kind) bool {
switch k {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return true
}
return false
}
// setDefaultValue is used to install a default value, from a tag string, into
// a Value. It is successful if the field was optional, even if a default value
// wasn't provided or it failed to install it into the Value.
func setDefaultValue(v reflect.Value, params fieldParameters) (ok bool) {
if !params.optional {
return
}
ok = true
if params.defaultValue == nil {
return
}
if canHaveDefaultValue(v.Kind()) {
v.SetInt(*params.defaultValue)
}
return
}
// Unmarshal parses the DER-encoded ASN.1 data structure b
// and uses the reflect package to fill in an arbitrary value pointed at by val.
// Because Unmarshal uses the reflect package, the structs
// being written to must use upper case field names. If val
// is nil or not a pointer, Unmarshal returns an error.
//
// After parsing b, any bytes that were leftover and not used to fill
// val will be returned in rest. When parsing a SEQUENCE into a struct,
// any trailing elements of the SEQUENCE that do not have matching
// fields in val will not be included in rest, as these are considered
// valid elements of the SEQUENCE and not trailing data.
//
// - An ASN.1 INTEGER can be written to an int, int32, int64,
// or *[big.Int].
// If the encoded value does not fit in the Go type,
// Unmarshal returns a parse error.
//
// - An ASN.1 BIT STRING can be written to a [BitString].
//
// - An ASN.1 OCTET STRING can be written to a []byte.
//
// - An ASN.1 OBJECT IDENTIFIER can be written to an [ObjectIdentifier].
//
// - An ASN.1 ENUMERATED can be written to an [Enumerated].
//
// - An ASN.1 UTCTIME or GENERALIZEDTIME can be written to a [time.Time].
//
// - An ASN.1 PrintableString, IA5String, or NumericString can be written to a string.
//
// - Any of the above ASN.1 values can be written to an interface{}.
// The value stored in the interface has the corresponding Go type.
// For integers, that type is int64.
//
// - An ASN.1 SEQUENCE OF x or SET OF x can be written
// to a slice if an x can be written to the slice's element type.
//
// - An ASN.1 SEQUENCE or SET can be written to a struct
// if each of the elements in the sequence can be
// written to the corresponding element in the struct.
//
// The following tags on struct fields have special meaning to Unmarshal:
//
// application specifies that an APPLICATION tag is used
// private specifies that a PRIVATE tag is used
// default:x sets the default value for optional integer fields (only used if optional is also present)
// explicit specifies that an additional, explicit tag wraps the implicit one
// optional marks the field as ASN.1 OPTIONAL
// set causes a SET, rather than a SEQUENCE type to be expected
// tag:x specifies the ASN.1 tag number; implies ASN.1 CONTEXT SPECIFIC
//
// When decoding an ASN.1 value with an IMPLICIT tag into a string field,
// Unmarshal will default to a PrintableString, which doesn't support
// characters such as '@' and '&'. To force other encodings, use the following
// tags:
//
// ia5 causes strings to be unmarshaled as ASN.1 IA5String values
// numeric causes strings to be unmarshaled as ASN.1 NumericString values
// utf8 causes strings to be unmarshaled as ASN.1 UTF8String values
//
// When decoding an ASN.1 value with an IMPLICIT tag into a time.Time field,
// Unmarshal will default to a UTCTime, which doesn't support time zones or
// fractional seconds. To force usage of GeneralizedTime, use the following
// tag:
//
// generalized causes time.Times to be unmarshaled as ASN.1 GeneralizedTime values
//
// If the type of the first field of a structure is RawContent then the raw
// ASN1 contents of the struct will be stored in it.
//
// If the name of a slice type ends with "SET" then it's treated as if
// the "set" tag was set on it. This results in interpreting the type as a
// SET OF x rather than a SEQUENCE OF x. This can be used with nested slices
// where a struct tag cannot be given.
//
// Other ASN.1 types are not supported; if it encounters them,
// Unmarshal returns a parse error.
func Unmarshal(b []byte, val any) (rest []byte, err error) {
return UnmarshalWithParams(b, val, "")
}
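// exampleUnmarshalSketch is an added illustration, not part of the original
// source: a minimal sketch of driving Unmarshal with the struct tags described
// above. The type, field names and DER bytes are hypothetical.
func exampleUnmarshalSketch() {
	type point struct {
		X int
		Y int `asn1:"optional,default:1"`
	}
	// DER for SEQUENCE { INTEGER 7 }; the missing optional Y should fall
	// back to its declared default.
	der := []byte{0x30, 0x03, 0x02, 0x01, 0x07}
	var p point
	if _, err := Unmarshal(der, &p); err == nil {
		_ = p // expected: p.X == 7, p.Y == 1
	}
}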
// An invalidUnmarshalError describes an invalid argument passed to Unmarshal.
// (The argument to Unmarshal must be a non-nil pointer.)
type invalidUnmarshalError struct {
Type reflect.Type
}
func (e *invalidUnmarshalError) Error() string {
if e.Type == nil {
return "asn1: Unmarshal recipient value is nil"
}
if e.Type.Kind() != reflect.Pointer {
return "asn1: Unmarshal recipient value is non-pointer " + e.Type.String()
}
return "asn1: Unmarshal recipient value is nil " + e.Type.String()
}
// UnmarshalWithParams allows field parameters to be specified for the
// top-level element. The form of the params is the same as the field tags.
func UnmarshalWithParams(b []byte, val any, params string) (rest []byte, err error) {
v := reflect.ValueOf(val)
if v.Kind() != reflect.Pointer || v.IsNil() {
return nil, &invalidUnmarshalError{reflect.TypeOf(val)}
}
offset, err := parseField(v.Elem(), b, 0, parseFieldParameters(params))
if err != nil {
return nil, err
}
return b[offset:], nil
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package asn1
import (
"reflect"
"strconv"
"strings"
)
// ASN.1 objects have metadata preceding them:
// the tag: the type of the object
// a flag denoting if this object is compound or not
// the class type: the namespace of the tag
// the length of the object, in bytes
// Here are some standard tags and classes
// ASN.1 tags represent the type of the following object.
const (
TagBoolean = 1
TagInteger = 2
TagBitString = 3
TagOctetString = 4
TagNull = 5
TagOID = 6
TagEnum = 10
TagUTF8String = 12
TagSequence = 16
TagSet = 17
TagNumericString = 18
TagPrintableString = 19
TagT61String = 20
TagIA5String = 22
TagUTCTime = 23
TagGeneralizedTime = 24
TagGeneralString = 27
TagBMPString = 30
)
// ASN.1 class types represent the namespace of the tag.
const (
ClassUniversal = 0
ClassApplication = 1
ClassContextSpecific = 2
ClassPrivate = 3
)
type tagAndLength struct {
class, tag, length int
isCompound bool
}
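// exampleIdentifierOctetSketch is an added illustration, not part of the
// original source: it splits a single short-form DER identifier octet into the
// class, compound flag and tag number described in the comment above. The
// octet 0x30 is a constructed UNIVERSAL SEQUENCE; length octets follow
// separately and are not covered here.
func exampleIdentifierOctetSketch() tagAndLength {
	b := byte(0x30)
	return tagAndLength{
		class:      int(b >> 6),   // 0 == ClassUniversal
		isCompound: b&0x20 != 0,   // true: constructed encoding
		tag:        int(b & 0x1f), // 16 == TagSequence (valid only for tag numbers < 31)
	}
}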
// ASN.1 has IMPLICIT and EXPLICIT tags, which can be translated as "instead
// of" and "in addition to". When not specified, every primitive type has a
// default tag in the UNIVERSAL class.
//
// For example: a BIT STRING is tagged [UNIVERSAL 3] by default (although ASN.1
// doesn't actually have a UNIVERSAL keyword). However, tagging it [IMPLICIT
// CONTEXT-SPECIFIC 42] replaces the default tag with that one.
//
// On the other hand, if it said [EXPLICIT CONTEXT-SPECIFIC 10], then an
// /additional/ tag would wrap the default tag. This explicit tag will have the
// compound flag set.
//
// (This is used in order to remove ambiguity with optional elements.)
//
// You can layer EXPLICIT and IMPLICIT tags to an arbitrary depth; however, we
// don't support that here. We support a single layer of EXPLICIT or IMPLICIT
// tagging with tag strings on the fields of a structure.
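// exampleTaggingSketch is an added illustration, not part of the original
// source: a sketch of how the "tag" and "explicit" field tags map onto
// IMPLICIT and EXPLICIT tagging. The struct and values are hypothetical; the
// commented bytes are the encoding Marshal should produce under these
// assumptions.
func exampleTaggingSketch() ([]byte, error) {
	type tagged struct {
		A int `asn1:"tag:0"`          // IMPLICIT [0]: replaces the INTEGER tag
		B int `asn1:"explicit,tag:1"` // EXPLICIT [1]: wraps the INTEGER tag
	}
	// Expected encoding: 30 08 80 01 01 a1 03 02 01 02
	// i.e. SEQUENCE { [0] 1, [1] { INTEGER 2 } }.
	return Marshal(tagged{A: 1, B: 2})
}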
// fieldParameters is the parsed representation of tag string from a structure field.
type fieldParameters struct {
optional bool // true iff the field is OPTIONAL
explicit bool // true iff an EXPLICIT tag is in use.
application bool // true iff an APPLICATION tag is in use.
private bool // true iff a PRIVATE tag is in use.
defaultValue *int64 // a default value for INTEGER typed fields (maybe nil).
tag *int // the EXPLICIT or IMPLICIT tag (maybe nil).
stringType int // the string tag to use when marshaling.
timeType int // the time tag to use when marshaling.
set bool // true iff this should be encoded as a SET
omitEmpty bool // true iff this should be omitted if empty when marshaling.
// Invariants:
// if explicit is set, tag is non-nil.
}
// Given a tag string with the format specified in the package comment,
// parseFieldParameters will parse it into a fieldParameters structure,
// ignoring unknown parts of the string.
func parseFieldParameters(str string) (ret fieldParameters) {
var part string
for len(str) > 0 {
part, str, _ = strings.Cut(str, ",")
switch {
case part == "optional":
ret.optional = true
case part == "explicit":
ret.explicit = true
if ret.tag == nil {
ret.tag = new(int)
}
case part == "generalized":
ret.timeType = TagGeneralizedTime
case part == "utc":
ret.timeType = TagUTCTime
case part == "ia5":
ret.stringType = TagIA5String
case part == "printable":
ret.stringType = TagPrintableString
case part == "numeric":
ret.stringType = TagNumericString
case part == "utf8":
ret.stringType = TagUTF8String
case strings.HasPrefix(part, "default:"):
i, err := strconv.ParseInt(part[8:], 10, 64)
if err == nil {
ret.defaultValue = new(int64)
*ret.defaultValue = i
}
case strings.HasPrefix(part, "tag:"):
i, err := strconv.Atoi(part[4:])
if err == nil {
ret.tag = new(int)
*ret.tag = i
}
case part == "set":
ret.set = true
case part == "application":
ret.application = true
if ret.tag == nil {
ret.tag = new(int)
}
case part == "private":
ret.private = true
if ret.tag == nil {
ret.tag = new(int)
}
case part == "omitempty":
ret.omitEmpty = true
}
}
return
}
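// exampleFieldParametersSketch is an added illustration, not part of the
// original source: it shows the fieldParameters that the tag string
// "optional,explicit,tag:0,default:1" should parse into, per the cases above.
func exampleFieldParametersSketch() fieldParameters {
	p := parseFieldParameters("optional,explicit,tag:0,default:1")
	// Expected: p.optional == true, p.explicit == true,
	// *p.tag == 0 and *p.defaultValue == 1.
	return p
}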
// Given a reflected Go type, getUniversalType returns the default tag number
// and expected compound flag.
func getUniversalType(t reflect.Type) (matchAny bool, tagNumber int, isCompound, ok bool) {
switch t {
case rawValueType:
return true, -1, false, true
case objectIdentifierType:
return false, TagOID, false, true
case bitStringType:
return false, TagBitString, false, true
case timeType:
return false, TagUTCTime, false, true
case enumeratedType:
return false, TagEnum, false, true
case bigIntType:
return false, TagInteger, false, true
}
switch t.Kind() {
case reflect.Bool:
return false, TagBoolean, false, true
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return false, TagInteger, false, true
case reflect.Struct:
return false, TagSequence, true, true
case reflect.Slice:
if t.Elem().Kind() == reflect.Uint8 {
return false, TagOctetString, false, true
}
if strings.HasSuffix(t.Name(), "SET") {
return false, TagSet, true, true
}
return false, TagSequence, true, true
case reflect.String:
return false, TagPrintableString, false, true
}
return false, 0, false, false
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package asn1
import (
"bytes"
"errors"
"fmt"
"math/big"
"reflect"
"slices"
"time"
"unicode/utf8"
)
var (
byte00Encoder encoder = byteEncoder(0x00)
byteFFEncoder encoder = byteEncoder(0xff)
)
// encoder represents an ASN.1 element that is waiting to be marshaled.
type encoder interface {
// Len returns the number of bytes needed to marshal this element.
Len() int
// Encode encodes this element by writing Len() bytes to dst.
Encode(dst []byte)
}
type byteEncoder byte
func (c byteEncoder) Len() int {
return 1
}
func (c byteEncoder) Encode(dst []byte) {
dst[0] = byte(c)
}
type bytesEncoder []byte
func (b bytesEncoder) Len() int {
return len(b)
}
func (b bytesEncoder) Encode(dst []byte) {
if copy(dst, b) != len(b) {
panic("internal error")
}
}
type stringEncoder string
func (s stringEncoder) Len() int {
return len(s)
}
func (s stringEncoder) Encode(dst []byte) {
if copy(dst, s) != len(s) {
panic("internal error")
}
}
type multiEncoder []encoder
func (m multiEncoder) Len() int {
var size int
for _, e := range m {
size += e.Len()
}
return size
}
func (m multiEncoder) Encode(dst []byte) {
var off int
for _, e := range m {
e.Encode(dst[off:])
off += e.Len()
}
}
type setEncoder []encoder
func (s setEncoder) Len() int {
var size int
for _, e := range s {
size += e.Len()
}
return size
}
func (s setEncoder) Encode(dst []byte) {
// Per X690 Section 11.6: The encodings of the component values of a
// set-of value shall appear in ascending order, the encodings being
// compared as octet strings with the shorter components being padded
// at their trailing end with 0-octets.
//
// First we encode each element to its TLV encoding and then use
// octetSort to get the ordering expected by X690 DER rules before
// writing the sorted encodings out to dst.
l := make([][]byte, len(s))
for i, e := range s {
l[i] = make([]byte, e.Len())
e.Encode(l[i])
}
// Since we are using bytes.Compare to compare TLV encodings we
// don't need to right pad s[i] and s[j] to the same length as
// suggested in X690. If len(s[i]) < len(s[j]) the length octet of
// s[i], which is the first determining byte, will inherently be
// smaller than the length octet of s[j]. This lets us skip the
// padding step.
slices.SortFunc(l, bytes.Compare)
var off int
for _, b := range l {
copy(dst[off:], b)
off += len(b)
}
}
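// exampleSetOfSketch is an added illustration, not part of the original
// source: a sketch of the DER SET OF ordering implemented above. The input
// values are hypothetical; the commented bytes are the encoding
// MarshalWithParams should produce under these assumptions.
func exampleSetOfSketch() ([]byte, error) {
	// The elements 3, 1, 2 are each encoded as INTEGER TLVs and then sorted
	// bytewise, so the SET OF body comes out in the order 1, 2, 3.
	// Expected encoding: 31 09 02 01 01 02 01 02 02 01 03
	return MarshalWithParams([]int{3, 1, 2}, "set")
}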
type taggedEncoder struct {
// scratch contains temporary space for encoding the tag and length of
// an element in order to avoid extra allocations.
scratch [8]byte
tag encoder
body encoder
}
func (t *taggedEncoder) Len() int {
return t.tag.Len() + t.body.Len()
}
func (t *taggedEncoder) Encode(dst []byte) {
t.tag.Encode(dst)
t.body.Encode(dst[t.tag.Len():])
}
type int64Encoder int64
func (i int64Encoder) Len() int {
n := 1
for i > 127 {
n++
i >>= 8
}
for i < -128 {
n++
i >>= 8
}
return n
}
func (i int64Encoder) Encode(dst []byte) {
n := i.Len()
for j := 0; j < n; j++ {
dst[j] = byte(i >> uint((n-1-j)*8))
}
}
func base128IntLength(n int64) int {
if n == 0 {
return 1
}
l := 0
for i := n; i > 0; i >>= 7 {
l++
}
return l
}
func appendBase128Int(dst []byte, n int64) []byte {
l := base128IntLength(n)
for i := l - 1; i >= 0; i-- {
o := byte(n >> uint(i*7))
o &= 0x7f
if i != 0 {
o |= 0x80
}
dst = append(dst, o)
}
return dst
}
func makeBigInt(n *big.Int) (encoder, error) {
if n == nil {
return nil, StructuralError{"empty integer"}
}
if n.Sign() < 0 {
// A negative number has to be converted to two's-complement
// form. So we'll invert and subtract 1. If the
// most-significant-bit isn't set then we'll need to pad the
// beginning with 0xff in order to keep the number negative.
nMinus1 := new(big.Int).Neg(n)
nMinus1.Sub(nMinus1, bigOne)
bytes := nMinus1.Bytes()
for i := range bytes {
bytes[i] ^= 0xff
}
if len(bytes) == 0 || bytes[0]&0x80 == 0 {
return multiEncoder([]encoder{byteFFEncoder, bytesEncoder(bytes)}), nil
}
return bytesEncoder(bytes), nil
} else if n.Sign() == 0 {
// Zero is written as a single zero byte rather than no bytes.
return byte00Encoder, nil
} else {
bytes := n.Bytes()
if len(bytes) > 0 && bytes[0]&0x80 != 0 {
// We'll have to pad this with 0x00 in order to stop it
// looking like a negative number.
return multiEncoder([]encoder{byte00Encoder, bytesEncoder(bytes)}), nil
}
return bytesEncoder(bytes), nil
}
}
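// exampleBigIntSketch is an added illustration, not part of the original
// source: it spells out the two's-complement padding rules above for two
// hypothetical values.
func exampleBigIntSketch() {
	// 128 has its top bit set, so a 0x00 byte is prepended to keep the value
	// positive: expected contents 00 80.
	_, _ = makeBigInt(big.NewInt(128))
	// -129 inverts to 0x7f, whose top bit is clear, so a 0xff byte is
	// prepended to keep the value negative: expected contents ff 7f.
	_, _ = makeBigInt(big.NewInt(-129))
}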
func appendLength(dst []byte, i int) []byte {
n := lengthLength(i)
for ; n > 0; n-- {
dst = append(dst, byte(i>>uint((n-1)*8)))
}
return dst
}
func lengthLength(i int) (numBytes int) {
numBytes = 1
for i > 255 {
numBytes++
i >>= 8
}
return
}
func appendTagAndLength(dst []byte, t tagAndLength) []byte {
b := uint8(t.class) << 6
if t.isCompound {
b |= 0x20
}
if t.tag >= 31 {
b |= 0x1f
dst = append(dst, b)
dst = appendBase128Int(dst, int64(t.tag))
} else {
b |= uint8(t.tag)
dst = append(dst, b)
}
if t.length >= 128 {
l := lengthLength(t.length)
dst = append(dst, 0x80|byte(l))
dst = appendLength(dst, t.length)
} else {
dst = append(dst, byte(t.length))
}
return dst
}
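// exampleTagAndLengthSketch is an added illustration, not part of the original
// source: it shows the long-form length encoding produced above for a
// hypothetical constructed [0] element with 200 content bytes.
func exampleTagAndLengthSketch() []byte {
	// Expected: a0 81 c8 -- context-specific constructed tag 0, then a
	// one-byte long-form length (0x81) carrying the value 200 (0xc8).
	return appendTagAndLength(nil, tagAndLength{
		class:      ClassContextSpecific,
		tag:        0,
		length:     200,
		isCompound: true,
	})
}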
type bitStringEncoder BitString
func (b bitStringEncoder) Len() int {
return len(b.Bytes) + 1
}
func (b bitStringEncoder) Encode(dst []byte) {
dst[0] = byte((8 - b.BitLength%8) % 8)
if copy(dst[1:], b.Bytes) != len(b.Bytes) {
panic("internal error")
}
}
type oidEncoder []int
func (oid oidEncoder) Len() int {
l := base128IntLength(int64(oid[0]*40 + oid[1]))
for i := 2; i < len(oid); i++ {
l += base128IntLength(int64(oid[i]))
}
return l
}
func (oid oidEncoder) Encode(dst []byte) {
dst = appendBase128Int(dst[:0], int64(oid[0]*40+oid[1]))
for i := 2; i < len(oid); i++ {
dst = appendBase128Int(dst, int64(oid[i]))
}
}
func makeObjectIdentifier(oid []int) (e encoder, err error) {
if len(oid) < 2 || oid[0] > 2 || (oid[0] < 2 && oid[1] >= 40) {
return nil, StructuralError{"invalid object identifier"}
}
return oidEncoder(oid), nil
}
func makePrintableString(s string) (e encoder, err error) {
for i := 0; i < len(s); i++ {
// The asterisk is often used in PrintableString, even though
// it is invalid. If a PrintableString was specifically
// requested then the asterisk is permitted by this code.
// Ampersand is allowed in parsing due to a handful of CA
// certificates; however, when making new certificates
// it is rejected.
if !isPrintable(s[i], allowAsterisk, rejectAmpersand) {
return nil, StructuralError{"PrintableString contains invalid character"}
}
}
return stringEncoder(s), nil
}
func makeIA5String(s string) (e encoder, err error) {
for i := 0; i < len(s); i++ {
if s[i] > 127 {
return nil, StructuralError{"IA5String contains invalid character"}
}
}
return stringEncoder(s), nil
}
func makeNumericString(s string) (e encoder, err error) {
for i := 0; i < len(s); i++ {
if !isNumeric(s[i]) {
return nil, StructuralError{"NumericString contains invalid character"}
}
}
return stringEncoder(s), nil
}
func makeUTF8String(s string) encoder {
return stringEncoder(s)
}
func appendTwoDigits(dst []byte, v int) []byte {
return append(dst, byte('0'+(v/10)%10), byte('0'+v%10))
}
func appendFourDigits(dst []byte, v int) []byte {
return append(dst,
byte('0'+(v/1000)%10),
byte('0'+(v/100)%10),
byte('0'+(v/10)%10),
byte('0'+v%10))
}
func outsideUTCRange(t time.Time) bool {
year := t.Year()
return year < 1950 || year >= 2050
}
func makeUTCTime(t time.Time) (e encoder, err error) {
dst := make([]byte, 0, 18)
dst, err = appendUTCTime(dst, t)
if err != nil {
return nil, err
}
return bytesEncoder(dst), nil
}
func makeGeneralizedTime(t time.Time) (e encoder, err error) {
dst := make([]byte, 0, 20)
dst, err = appendGeneralizedTime(dst, t)
if err != nil {
return nil, err
}
return bytesEncoder(dst), nil
}
func appendUTCTime(dst []byte, t time.Time) (ret []byte, err error) {
year := t.Year()
switch {
case 1950 <= year && year < 2000:
dst = appendTwoDigits(dst, year-1900)
case 2000 <= year && year < 2050:
dst = appendTwoDigits(dst, year-2000)
default:
return nil, StructuralError{"cannot represent time as UTCTime"}
}
return appendTimeCommon(dst, t), nil
}
func appendGeneralizedTime(dst []byte, t time.Time) (ret []byte, err error) {
year := t.Year()
if year < 0 || year > 9999 {
return nil, StructuralError{"cannot represent time as GeneralizedTime"}
}
dst = appendFourDigits(dst, year)
return appendTimeCommon(dst, t), nil
}
func appendTimeCommon(dst []byte, t time.Time) []byte {
_, month, day := t.Date()
dst = appendTwoDigits(dst, int(month))
dst = appendTwoDigits(dst, day)
hour, min, sec := t.Clock()
dst = appendTwoDigits(dst, hour)
dst = appendTwoDigits(dst, min)
dst = appendTwoDigits(dst, sec)
_, offset := t.Zone()
switch {
case offset/60 == 0:
return append(dst, 'Z')
case offset > 0:
dst = append(dst, '+')
case offset < 0:
dst = append(dst, '-')
}
offsetMinutes := offset / 60
if offsetMinutes < 0 {
offsetMinutes = -offsetMinutes
}
dst = appendTwoDigits(dst, offsetMinutes/60)
dst = appendTwoDigits(dst, offsetMinutes%60)
return dst
}
func stripTagAndLength(in []byte) []byte {
_, offset, err := parseTagAndLength(in, 0)
if err != nil {
return in
}
return in[offset:]
}
func makeBody(value reflect.Value, params fieldParameters) (e encoder, err error) {
switch value.Type() {
case flagType:
return bytesEncoder(nil), nil
case timeType:
t := value.Interface().(time.Time)
if params.timeType == TagGeneralizedTime || outsideUTCRange(t) {
return makeGeneralizedTime(t)
}
return makeUTCTime(t)
case bitStringType:
return bitStringEncoder(value.Interface().(BitString)), nil
case objectIdentifierType:
return makeObjectIdentifier(value.Interface().(ObjectIdentifier))
case bigIntType:
return makeBigInt(value.Interface().(*big.Int))
}
switch v := value; v.Kind() {
case reflect.Bool:
if v.Bool() {
return byteFFEncoder, nil
}
return byte00Encoder, nil
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return int64Encoder(v.Int()), nil
case reflect.Struct:
t := v.Type()
for i := 0; i < t.NumField(); i++ {
if !t.Field(i).IsExported() {
return nil, StructuralError{"struct contains unexported fields"}
}
}
startingField := 0
n := t.NumField()
if n == 0 {
return bytesEncoder(nil), nil
}
// If the first element of the structure is a non-empty
// RawContents, then we don't bother serializing the rest.
if t.Field(0).Type == rawContentsType {
s := v.Field(0)
if s.Len() > 0 {
bytes := s.Bytes()
/* The RawContents will contain the tag and
* length fields but we'll also be writing
* those ourselves, so we strip them out of
* bytes */
return bytesEncoder(stripTagAndLength(bytes)), nil
}
startingField = 1
}
switch n1 := n - startingField; n1 {
case 0:
return bytesEncoder(nil), nil
case 1:
return makeField(v.Field(startingField), parseFieldParameters(t.Field(startingField).Tag.Get("asn1")))
default:
m := make([]encoder, n1)
for i := 0; i < n1; i++ {
m[i], err = makeField(v.Field(i+startingField), parseFieldParameters(t.Field(i+startingField).Tag.Get("asn1")))
if err != nil {
return nil, err
}
}
return multiEncoder(m), nil
}
case reflect.Slice:
sliceType := v.Type()
if sliceType.Elem().Kind() == reflect.Uint8 {
return bytesEncoder(v.Bytes()), nil
}
var fp fieldParameters
switch l := v.Len(); l {
case 0:
return bytesEncoder(nil), nil
case 1:
return makeField(v.Index(0), fp)
default:
m := make([]encoder, l)
for i := 0; i < l; i++ {
m[i], err = makeField(v.Index(i), fp)
if err != nil {
return nil, err
}
}
if params.set {
return setEncoder(m), nil
}
return multiEncoder(m), nil
}
case reflect.String:
switch params.stringType {
case TagIA5String:
return makeIA5String(v.String())
case TagPrintableString:
return makePrintableString(v.String())
case TagNumericString:
return makeNumericString(v.String())
default:
return makeUTF8String(v.String()), nil
}
}
return nil, StructuralError{"unknown Go type"}
}
func makeField(v reflect.Value, params fieldParameters) (e encoder, err error) {
if !v.IsValid() {
return nil, fmt.Errorf("asn1: cannot marshal nil value")
}
// If the field is an interface{} then recurse into it.
if v.Kind() == reflect.Interface && v.Type().NumMethod() == 0 {
return makeField(v.Elem(), params)
}
if v.Kind() == reflect.Slice && v.Len() == 0 && params.omitEmpty {
return bytesEncoder(nil), nil
}
if params.optional && params.defaultValue != nil && canHaveDefaultValue(v.Kind()) {
defaultValue := reflect.New(v.Type()).Elem()
defaultValue.SetInt(*params.defaultValue)
if reflect.DeepEqual(v.Interface(), defaultValue.Interface()) {
return bytesEncoder(nil), nil
}
}
// If no default value is given then the zero value for the type is
// assumed to be the default value. This isn't obviously the correct
// behavior, but it's what Go has traditionally done.
if params.optional && params.defaultValue == nil {
if reflect.DeepEqual(v.Interface(), reflect.Zero(v.Type()).Interface()) {
return bytesEncoder(nil), nil
}
}
if v.Type() == rawValueType {
rv := v.Interface().(RawValue)
if len(rv.FullBytes) != 0 {
return bytesEncoder(rv.FullBytes), nil
}
t := new(taggedEncoder)
t.tag = bytesEncoder(appendTagAndLength(t.scratch[:0], tagAndLength{rv.Class, rv.Tag, len(rv.Bytes), rv.IsCompound}))
t.body = bytesEncoder(rv.Bytes)
return t, nil
}
matchAny, tag, isCompound, ok := getUniversalType(v.Type())
if !ok || matchAny {
return nil, StructuralError{fmt.Sprintf("unknown Go type: %v", v.Type())}
}
if params.timeType != 0 && tag != TagUTCTime {
return nil, StructuralError{"explicit time type given to non-time member"}
}
if params.stringType != 0 && tag != TagPrintableString {
return nil, StructuralError{"explicit string type given to non-string member"}
}
switch tag {
case TagPrintableString:
if params.stringType == 0 {
// This is a string without an explicit string type. We'll use
// a PrintableString if the character set in the string is
// sufficiently limited, otherwise we'll use a UTF8String.
for _, r := range v.String() {
if r >= utf8.RuneSelf || !isPrintable(byte(r), rejectAsterisk, rejectAmpersand) {
if !utf8.ValidString(v.String()) {
return nil, errors.New("asn1: string not valid UTF-8")
}
tag = TagUTF8String
break
}
}
} else {
tag = params.stringType
}
case TagUTCTime:
if params.timeType == TagGeneralizedTime || outsideUTCRange(v.Interface().(time.Time)) {
tag = TagGeneralizedTime
}
}
if params.set {
if tag != TagSequence {
return nil, StructuralError{"non sequence tagged as set"}
}
tag = TagSet
}
// makeField can be called for a slice that should be treated as a SET
// but doesn't have params.set set, for instance when using a slice
// with the SET type name suffix. In this case getUniversalType returns
// TagSet, but makeBody doesn't know about that so will treat the slice
// as a sequence. To work around this we set params.set.
if tag == TagSet && !params.set {
params.set = true
}
t := new(taggedEncoder)
t.body, err = makeBody(v, params)
if err != nil {
return nil, err
}
bodyLen := t.body.Len()
class := ClassUniversal
if params.tag != nil {
if params.application {
class = ClassApplication
} else if params.private {
class = ClassPrivate
} else {
class = ClassContextSpecific
}
if params.explicit {
t.tag = bytesEncoder(appendTagAndLength(t.scratch[:0], tagAndLength{ClassUniversal, tag, bodyLen, isCompound}))
tt := new(taggedEncoder)
tt.body = t
tt.tag = bytesEncoder(appendTagAndLength(tt.scratch[:0], tagAndLength{
class: class,
tag: *params.tag,
length: bodyLen + t.tag.Len(),
isCompound: true,
}))
return tt, nil
}
// implicit tag.
tag = *params.tag
}
t.tag = bytesEncoder(appendTagAndLength(t.scratch[:0], tagAndLength{class, tag, bodyLen, isCompound}))
return t, nil
}
// Marshal returns the ASN.1 encoding of val.
//
// In addition to the struct tags recognized by Unmarshal, the following can be
// used:
//
// ia5: causes strings to be marshaled as ASN.1, IA5String values
// omitempty: causes empty slices to be skipped
// printable: causes strings to be marshaled as ASN.1, PrintableString values
// utf8: causes strings to be marshaled as ASN.1, UTF8String values
// numeric: causes strings to be marshaled as ASN.1, NumericString values
// utc: causes time.Time to be marshaled as ASN.1, UTCTime values
// generalized: causes time.Time to be marshaled as ASN.1, GeneralizedTime values
func Marshal(val any) ([]byte, error) {
return MarshalWithParams(val, "")
}
// MarshalWithParams allows field parameters to be specified for the
// top-level element. The form of the params is the same as the field tags.
func MarshalWithParams(val any, params string) ([]byte, error) {
e, err := makeField(reflect.ValueOf(val), parseFieldParameters(params))
if err != nil {
return nil, err
}
b := make([]byte, e.Len())
e.Encode(b)
return b, nil
}
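// exampleMarshalSketch is an added illustration, not part of the original
// source: a minimal sketch of Marshal on a bare integer. The value is
// hypothetical; the commented bytes are the expected DER under these
// assumptions.
func exampleMarshalSketch() ([]byte, error) {
	// 65537 is 0x010001, so the result should be the INTEGER TLV
	// 02 03 01 00 01.
	return Marshal(65537)
}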
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package base32 implements base32 encoding as specified by RFC 4648.
package base32
import (
"io"
"slices"
"strconv"
)
/*
* Encodings
*/
// An Encoding is a radix 32 encoding/decoding scheme, defined by a
// 32-character alphabet. The most common is the "base32" encoding
// introduced for SASL GSSAPI and standardized in RFC 4648.
// The alternate "base32hex" encoding is used in DNSSEC.
type Encoding struct {
encode [32]byte // mapping of symbol index to symbol byte value
decodeMap [256]uint8 // mapping of symbol byte value to symbol index
padChar rune
}
const (
StdPadding rune = '=' // Standard padding character
NoPadding rune = -1 // No padding
)
const (
decodeMapInitialize = "" +
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff"
invalidIndex = '\xff'
)
// NewEncoding returns a new padded Encoding defined by the given alphabet,
// which must be a 32-byte string that contains unique byte values and
// does not contain the padding character or CR / LF ('\r', '\n').
// The alphabet is treated as a sequence of byte values
// without any special treatment for multi-byte UTF-8.
// The resulting Encoding uses the default padding character ('='),
// which may be changed or disabled via [Encoding.WithPadding].
func NewEncoding(encoder string) *Encoding {
if len(encoder) != 32 {
panic("encoding alphabet is not 32-bytes long")
}
e := new(Encoding)
e.padChar = StdPadding
copy(e.encode[:], encoder)
copy(e.decodeMap[:], decodeMapInitialize)
for i := 0; i < len(encoder); i++ {
// Note: While we document that the alphabet cannot contain
// the padding character, we do not enforce it since we do not know
// if the caller intends to switch the padding from StdPadding later.
switch {
case encoder[i] == '\n' || encoder[i] == '\r':
panic("encoding alphabet contains newline character")
case e.decodeMap[encoder[i]] != invalidIndex:
panic("encoding alphabet includes duplicate symbols")
}
e.decodeMap[encoder[i]] = uint8(i)
}
return e
}
// StdEncoding is the standard base32 encoding, as defined in RFC 4648.
var StdEncoding = NewEncoding("ABCDEFGHIJKLMNOPQRSTUVWXYZ234567")
// HexEncoding is the “Extended Hex Alphabet” defined in RFC 4648.
// It is typically used in DNS.
var HexEncoding = NewEncoding("0123456789ABCDEFGHIJKLMNOPQRSTUV")
// WithPadding creates a new encoding identical to enc except
// with a specified padding character, or NoPadding to disable padding.
// The padding character must not be '\r' or '\n',
// must not be contained in the encoding's alphabet,
// must not be negative, and must be a rune equal or below '\xff'.
// Padding characters above '\x7f' are encoded as their exact byte value
// rather than using the UTF-8 representation of the codepoint.
func (enc Encoding) WithPadding(padding rune) *Encoding {
switch {
case padding < NoPadding || padding == '\r' || padding == '\n' || padding > 0xff:
panic("invalid padding")
case padding != NoPadding && enc.decodeMap[byte(padding)] != invalidIndex:
panic("padding contained in alphabet")
}
enc.padChar = padding
return &enc
}
/*
* Encoder
*/
// Encode encodes src using the encoding enc,
// writing [Encoding.EncodedLen](len(src)) bytes to dst.
//
// The encoding pads the output to a multiple of 8 bytes,
// so Encode is not appropriate for use on individual blocks
// of a large data stream. Use [NewEncoder] instead.
func (enc *Encoding) Encode(dst, src []byte) {
if len(src) == 0 {
return
}
// enc is a pointer receiver, so the use of enc.encode within the hot
// loop below means a nil check at every operation. Lift that nil check
// outside of the loop to speed up the encoder.
_ = enc.encode
di, si := 0, 0
n := (len(src) / 5) * 5
for si < n {
// Combining two 32 bit loads allows the same code to be used
// for 32 and 64 bit platforms.
hi := uint32(src[si+0])<<24 | uint32(src[si+1])<<16 | uint32(src[si+2])<<8 | uint32(src[si+3])
lo := hi<<8 | uint32(src[si+4])
dst[di+0] = enc.encode[(hi>>27)&0x1F]
dst[di+1] = enc.encode[(hi>>22)&0x1F]
dst[di+2] = enc.encode[(hi>>17)&0x1F]
dst[di+3] = enc.encode[(hi>>12)&0x1F]
dst[di+4] = enc.encode[(hi>>7)&0x1F]
dst[di+5] = enc.encode[(hi>>2)&0x1F]
dst[di+6] = enc.encode[(lo>>5)&0x1F]
dst[di+7] = enc.encode[(lo)&0x1F]
si += 5
di += 8
}
// Add the remaining small block
remain := len(src) - si
if remain == 0 {
return
}
// Encode the remaining bytes in reverse order.
val := uint32(0)
switch remain {
case 4:
val |= uint32(src[si+3])
dst[di+6] = enc.encode[val<<3&0x1F]
dst[di+5] = enc.encode[val>>2&0x1F]
fallthrough
case 3:
val |= uint32(src[si+2]) << 8
dst[di+4] = enc.encode[val>>7&0x1F]
fallthrough
case 2:
val |= uint32(src[si+1]) << 16
dst[di+3] = enc.encode[val>>12&0x1F]
dst[di+2] = enc.encode[val>>17&0x1F]
fallthrough
case 1:
val |= uint32(src[si+0]) << 24
dst[di+1] = enc.encode[val>>22&0x1F]
dst[di+0] = enc.encode[val>>27&0x1F]
}
// Pad the final quantum
if enc.padChar != NoPadding {
nPad := (remain * 8 / 5) + 1
for i := nPad; i < 8; i++ {
dst[di+i] = byte(enc.padChar)
}
}
}
// AppendEncode appends the base32 encoded src to dst
// and returns the extended buffer.
func (enc *Encoding) AppendEncode(dst, src []byte) []byte {
n := enc.EncodedLen(len(src))
dst = slices.Grow(dst, n)
enc.Encode(dst[len(dst):][:n], src)
return dst[:len(dst)+n]
}
// EncodeToString returns the base32 encoding of src.
func (enc *Encoding) EncodeToString(src []byte) string {
buf := make([]byte, enc.EncodedLen(len(src)))
enc.Encode(buf, src)
return string(buf)
}
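// exampleEncodeToStringSketch is an added illustration, not part of the
// original source: it pairs EncodeToString with the RFC 4648 section 10 test
// vectors for the standard alphabet.
func exampleEncodeToStringSketch() {
	_ = StdEncoding.EncodeToString([]byte("foo"))    // "MZXW6==="
	_ = StdEncoding.EncodeToString([]byte("foobar")) // "MZXW6YTBOI======"
}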
type encoder struct {
err error
enc *Encoding
w io.Writer
buf [5]byte // buffered data waiting to be encoded
nbuf int // number of bytes in buf
out [1024]byte // output buffer
}
func (e *encoder) Write(p []byte) (n int, err error) {
if e.err != nil {
return 0, e.err
}
// Leading fringe.
if e.nbuf > 0 {
var i int
for i = 0; i < len(p) && e.nbuf < 5; i++ {
e.buf[e.nbuf] = p[i]
e.nbuf++
}
n += i
p = p[i:]
if e.nbuf < 5 {
return
}
e.enc.Encode(e.out[0:], e.buf[0:])
if _, e.err = e.w.Write(e.out[0:8]); e.err != nil {
return n, e.err
}
e.nbuf = 0
}
// Large interior chunks.
for len(p) >= 5 {
nn := len(e.out) / 8 * 5
if nn > len(p) {
nn = len(p)
nn -= nn % 5
}
e.enc.Encode(e.out[0:], p[0:nn])
if _, e.err = e.w.Write(e.out[0 : nn/5*8]); e.err != nil {
return n, e.err
}
n += nn
p = p[nn:]
}
// Trailing fringe.
copy(e.buf[:], p)
e.nbuf = len(p)
n += len(p)
return
}
// Close flushes any pending output from the encoder.
// It is an error to call Write after calling Close.
func (e *encoder) Close() error {
// If there's anything left in the buffer, flush it out
if e.err == nil && e.nbuf > 0 {
e.enc.Encode(e.out[0:], e.buf[0:e.nbuf])
encodedLen := e.enc.EncodedLen(e.nbuf)
e.nbuf = 0
_, e.err = e.w.Write(e.out[0:encodedLen])
}
return e.err
}
// NewEncoder returns a new base32 stream encoder. Data written to
// the returned writer will be encoded using enc and then written to w.
// Base32 encodings operate in 5-byte blocks; when finished
// writing, the caller must Close the returned encoder to flush any
// partially written blocks.
func NewEncoder(enc *Encoding, w io.Writer) io.WriteCloser {
return &encoder{enc: enc, w: w}
}
// EncodedLen returns the length in bytes of the base32 encoding
// of an input buffer of length n.
func (enc *Encoding) EncodedLen(n int) int {
if enc.padChar == NoPadding {
return n/5*8 + (n%5*8+4)/5
}
return (n + 4) / 5 * 8
}
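// exampleEncodedLenSketch is an added illustration, not part of the original
// source: it contrasts the padded and unpadded length formulas above for a
// hypothetical 6-byte input.
func exampleEncodedLenSketch() {
	_ = StdEncoding.EncodedLen(6)                        // 16: two padded 8-byte quanta
	_ = StdEncoding.WithPadding(NoPadding).EncodedLen(6) // 10: 8 symbols for the first 5 bytes, 2 for the last
}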
/*
* Decoder
*/
type CorruptInputError int64
func (e CorruptInputError) Error() string {
return "illegal base32 data at input byte " + strconv.FormatInt(int64(e), 10)
}
// decode is like Decode but returns an additional 'end' value, which
// indicates if end-of-message padding was encountered and thus any
// additional data is an error. This method assumes that src has been
// stripped of all supported whitespace ('\r' and '\n').
func (enc *Encoding) decode(dst, src []byte) (n int, end bool, err error) {
// Lift the nil check outside of the loop.
_ = enc.decodeMap
dsti := 0
olen := len(src)
for len(src) > 0 && !end {
// Decode quantum using the base32 alphabet
var dbuf [8]byte
dlen := 8
for j := 0; j < 8; {
if len(src) == 0 {
if enc.padChar != NoPadding {
// We have reached the end and are missing padding
return n, false, CorruptInputError(olen - len(src) - j)
}
// We have reached the end and are not expecting any padding
dlen, end = j, true
break
}
in := src[0]
src = src[1:]
if in == byte(enc.padChar) && j >= 2 && len(src) < 8 {
// We've reached the end and there's padding
if len(src)+j < 8-1 {
// not enough padding
return n, false, CorruptInputError(olen)
}
for k := 0; k < 8-1-j; k++ {
if len(src) > k && src[k] != byte(enc.padChar) {
// incorrect padding
return n, false, CorruptInputError(olen - len(src) + k - 1)
}
}
dlen, end = j, true
// 7, 5 and 2 are not valid padding lengths, and so 1, 3 and 6 are not
// valid dlen values. See RFC 4648 Section 6 "Base 32 Encoding" listing
// the five valid padding lengths, and Section 9 "Illustrations and
// Examples" for an illustration for how the 1st, 3rd and 6th base32
// src bytes do not yield enough information to decode a dst byte.
if dlen == 1 || dlen == 3 || dlen == 6 {
return n, false, CorruptInputError(olen - len(src) - 1)
}
break
}
dbuf[j] = enc.decodeMap[in]
if dbuf[j] == 0xFF {
return n, false, CorruptInputError(olen - len(src) - 1)
}
j++
}
// Pack 8x 5-bit source blocks into 5 byte destination
// quantum
switch dlen {
case 8:
dst[dsti+4] = dbuf[6]<<5 | dbuf[7]
n++
fallthrough
case 7:
dst[dsti+3] = dbuf[4]<<7 | dbuf[5]<<2 | dbuf[6]>>3
n++
fallthrough
case 5:
dst[dsti+2] = dbuf[3]<<4 | dbuf[4]>>1
n++
fallthrough
case 4:
dst[dsti+1] = dbuf[1]<<6 | dbuf[2]<<1 | dbuf[3]>>4
n++
fallthrough
case 2:
dst[dsti+0] = dbuf[0]<<3 | dbuf[1]>>2
n++
}
dsti += 5
}
return n, end, nil
}
// Decode decodes src using the encoding enc. It writes at most
// [Encoding.DecodedLen](len(src)) bytes to dst and returns the number of bytes
// written. The caller must ensure that dst is large enough to hold all
// the decoded data. If src contains invalid base32 data, it will return the
// number of bytes successfully written and [CorruptInputError].
// Newline characters (\r and \n) are ignored.
func (enc *Encoding) Decode(dst, src []byte) (n int, err error) {
buf := make([]byte, len(src))
l := stripNewlines(buf, src)
n, _, err = enc.decode(dst, buf[:l])
return
}
// AppendDecode appends the base32 decoded src to dst
// and returns the extended buffer.
// If the input is malformed, it returns the partially decoded src and an error.
// New line characters (\r and \n) are ignored.
func (enc *Encoding) AppendDecode(dst, src []byte) ([]byte, error) {
// Compute the output size without padding to avoid over allocating.
n := len(src)
for n > 0 && rune(src[n-1]) == enc.padChar {
n--
}
n = decodedLen(n, NoPadding)
dst = slices.Grow(dst, n)
n, err := enc.Decode(dst[len(dst):][:n], src)
return dst[:len(dst)+n], err
}
// DecodeString returns the bytes represented by the base32 string s.
// If the input is malformed, it returns the partially decoded data and
// [CorruptInputError]. New line characters (\r and \n) are ignored.
func (enc *Encoding) DecodeString(s string) ([]byte, error) {
buf := []byte(s)
l := stripNewlines(buf, buf)
n, _, err := enc.decode(buf, buf[:l])
return buf[:n], err
}
type decoder struct {
err error
enc *Encoding
r io.Reader
end bool // saw end of message
buf [1024]byte // leftover input
nbuf int
out []byte // leftover decoded output
outbuf [1024 / 8 * 5]byte
}
func readEncodedData(r io.Reader, buf []byte, min int, expectsPadding bool) (n int, err error) {
for n < min && err == nil {
var nn int
nn, err = r.Read(buf[n:])
n += nn
}
// Some data was read, but fewer than min bytes arrived before EOF.
if n < min && n > 0 && err == io.EOF {
err = io.ErrUnexpectedEOF
}
// No new data was read even though the buffer already holds a partial quantum.
// When padding is expected this is an unexpected EOF; when padding is disabled
// it is not an error, as the message can be of any length.
if expectsPadding && min < 8 && n == 0 && err == io.EOF {
err = io.ErrUnexpectedEOF
}
return
}
func (d *decoder) Read(p []byte) (n int, err error) {
// Use leftover decoded output from last read.
if len(d.out) > 0 {
n = copy(p, d.out)
d.out = d.out[n:]
if len(d.out) == 0 {
return n, d.err
}
return n, nil
}
if d.err != nil {
return 0, d.err
}
// Read a chunk.
nn := (len(p) + 4) / 5 * 8
if nn < 8 {
nn = 8
}
if nn > len(d.buf) {
nn = len(d.buf)
}
// Minimum amount of bytes that needs to be read each cycle
var min int
var expectsPadding bool
if d.enc.padChar == NoPadding {
min = 1
expectsPadding = false
} else {
min = 8 - d.nbuf
expectsPadding = true
}
nn, d.err = readEncodedData(d.r, d.buf[d.nbuf:nn], min, expectsPadding)
d.nbuf += nn
if d.nbuf < min {
return 0, d.err
}
if nn > 0 && d.end {
return 0, CorruptInputError(0)
}
// Decode chunk into p, or d.out and then p if p is too small.
var nr int
if d.enc.padChar == NoPadding {
nr = d.nbuf
} else {
nr = d.nbuf / 8 * 8
}
nw := d.enc.DecodedLen(d.nbuf)
if nw > len(p) {
nw, d.end, err = d.enc.decode(d.outbuf[0:], d.buf[0:nr])
d.out = d.outbuf[0:nw]
n = copy(p, d.out)
d.out = d.out[n:]
} else {
n, d.end, err = d.enc.decode(p, d.buf[0:nr])
}
d.nbuf -= nr
for i := 0; i < d.nbuf; i++ {
d.buf[i] = d.buf[i+nr]
}
if err != nil && (d.err == nil || d.err == io.EOF) {
d.err = err
}
if len(d.out) > 0 {
// We cannot return all the decoded bytes to the caller in this
// invocation of Read, so we return a nil error to ensure that Read
// will be called again. The error stored in d.err, if any, will be
// returned with the last set of decoded bytes.
return n, nil
}
return n, d.err
}
type newlineFilteringReader struct {
wrapped io.Reader
}
// stripNewlines removes newline characters and returns the number
// of non-newline characters copied to dst.
func stripNewlines(dst, src []byte) int {
offset := 0
for _, b := range src {
if b == '\r' || b == '\n' {
continue
}
dst[offset] = b
offset++
}
return offset
}
func (r *newlineFilteringReader) Read(p []byte) (int, error) {
n, err := r.wrapped.Read(p)
for n > 0 {
s := p[0:n]
offset := stripNewlines(s, s)
if err != nil || offset > 0 {
return offset, err
}
// Previous buffer entirely whitespace, read again
n, err = r.wrapped.Read(p)
}
return n, err
}
// NewDecoder constructs a new base32 stream decoder.
func NewDecoder(enc *Encoding, r io.Reader) io.Reader {
return &decoder{enc: enc, r: &newlineFilteringReader{r}}
}
// DecodedLen returns the maximum length in bytes of the decoded data
// corresponding to n bytes of base32-encoded data.
func (enc *Encoding) DecodedLen(n int) int {
return decodedLen(n, enc.padChar)
}
func decodedLen(n int, padChar rune) int {
if padChar == NoPadding {
return n/8*5 + n%8*5/8
}
return n / 8 * 5
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package base64 implements base64 encoding as specified by RFC 4648.
package base64
import (
"internal/byteorder"
"io"
"slices"
"strconv"
)
/*
* Encodings
*/
// An Encoding is a radix 64 encoding/decoding scheme, defined by a
// 64-character alphabet. The most common encoding is the "base64"
// encoding defined in RFC 4648 and used in MIME (RFC 2045) and PEM
// (RFC 1421). RFC 4648 also defines an alternate encoding, which is
// the standard encoding with - and _ substituted for + and /.
type Encoding struct {
encode [64]byte // mapping of symbol index to symbol byte value
decodeMap [256]uint8 // mapping of symbol byte value to symbol index
padChar rune
strict bool
}
const (
StdPadding rune = '=' // Standard padding character
NoPadding rune = -1 // No padding
)
const (
decodeMapInitialize = "" +
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff"
invalidIndex = '\xff'
)
// NewEncoding returns a new padded Encoding defined by the given alphabet,
// which must be a 64-byte string that contains unique byte values and
// does not contain the padding character or CR / LF ('\r', '\n').
// The alphabet is treated as a sequence of byte values
// without any special treatment for multi-byte UTF-8.
// The resulting Encoding uses the default padding character ('='),
// which may be changed or disabled via [Encoding.WithPadding].
func NewEncoding(encoder string) *Encoding {
if len(encoder) != 64 {
panic("encoding alphabet is not 64-bytes long")
}
e := new(Encoding)
e.padChar = StdPadding
copy(e.encode[:], encoder)
copy(e.decodeMap[:], decodeMapInitialize)
for i := 0; i < len(encoder); i++ {
// Note: While we document that the alphabet cannot contain
// the padding character, we do not enforce it since we do not know
// if the caller intends to switch the padding from StdPadding later.
switch {
case encoder[i] == '\n' || encoder[i] == '\r':
panic("encoding alphabet contains newline character")
case e.decodeMap[encoder[i]] != invalidIndex:
panic("encoding alphabet includes duplicate symbols")
}
e.decodeMap[encoder[i]] = uint8(i)
}
return e
}
// WithPadding creates a new encoding identical to enc except
// with a specified padding character, or [NoPadding] to disable padding.
// The padding character must not be '\r' or '\n',
// must not be contained in the encoding's alphabet,
// must not be negative, and must be a rune equal or below '\xff'.
// Padding characters above '\x7f' are encoded as their exact byte value
// rather than using the UTF-8 representation of the codepoint.
func (enc Encoding) WithPadding(padding rune) *Encoding {
switch {
case padding < NoPadding || padding == '\r' || padding == '\n' || padding > 0xff:
panic("invalid padding")
case padding != NoPadding && enc.decodeMap[byte(padding)] != invalidIndex:
panic("padding contained in alphabet")
}
enc.padChar = padding
return &enc
}
// Strict creates a new encoding identical to enc except with
// strict decoding enabled. In this mode, the decoder requires that
// trailing padding bits are zero, as described in RFC 4648 section 3.5.
//
// Note that the input is still malleable, as new line characters
// (CR and LF) are still ignored.
func (enc Encoding) Strict() *Encoding {
enc.strict = true
return &enc
}
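// exampleStrictSketch is an added illustration, not part of the original
// source: a sketch of the difference strict mode makes for a hypothetical
// input whose trailing padding bits are non-zero.
func exampleStrictSketch() {
	// "dGVzdA==" is the canonical encoding of "test"; "dGVzdE==" decodes to
	// the same bytes but leaves non-zero bits before the padding.
	_, _ = StdEncoding.DecodeString("dGVzdE==")          // "test", nil error
	_, _ = StdEncoding.Strict().DecodeString("dGVzdE==") // rejected with CorruptInputError
}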
// StdEncoding is the standard base64 encoding, as defined in RFC 4648.
var StdEncoding = NewEncoding("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/")
// URLEncoding is the alternate base64 encoding defined in RFC 4648.
// It is typically used in URLs and file names.
var URLEncoding = NewEncoding("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_")
// RawStdEncoding is the standard raw, unpadded base64 encoding,
// as defined in RFC 4648 section 3.2.
// This is the same as [StdEncoding] but omits padding characters.
var RawStdEncoding = StdEncoding.WithPadding(NoPadding)
// RawURLEncoding is the unpadded alternate base64 encoding defined in RFC 4648.
// It is typically used in URLs and file names.
// This is the same as [URLEncoding] but omits padding characters.
var RawURLEncoding = URLEncoding.WithPadding(NoPadding)
/*
* Encoder
*/
// Encode encodes src using the encoding enc,
// writing [Encoding.EncodedLen](len(src)) bytes to dst.
//
// The encoding pads the output to a multiple of 4 bytes,
// so Encode is not appropriate for use on individual blocks
// of a large data stream. Use [NewEncoder] instead.
func (enc *Encoding) Encode(dst, src []byte) {
if len(src) == 0 {
return
}
// enc is a pointer receiver, so the use of enc.encode within the hot
// loop below means a nil check at every operation. Lift that nil check
// outside of the loop to speed up the encoder.
_ = enc.encode
di, si := 0, 0
n := (len(src) / 3) * 3
for si < n {
// Convert 3x 8bit source bytes into 4 bytes
val := uint(src[si+0])<<16 | uint(src[si+1])<<8 | uint(src[si+2])
dst[di+0] = enc.encode[val>>18&0x3F]
dst[di+1] = enc.encode[val>>12&0x3F]
dst[di+2] = enc.encode[val>>6&0x3F]
dst[di+3] = enc.encode[val&0x3F]
si += 3
di += 4
}
remain := len(src) - si
if remain == 0 {
return
}
// Add the remaining small block
val := uint(src[si+0]) << 16
if remain == 2 {
val |= uint(src[si+1]) << 8
}
dst[di+0] = enc.encode[val>>18&0x3F]
dst[di+1] = enc.encode[val>>12&0x3F]
switch remain {
case 2:
dst[di+2] = enc.encode[val>>6&0x3F]
if enc.padChar != NoPadding {
dst[di+3] = byte(enc.padChar)
}
case 1:
if enc.padChar != NoPadding {
dst[di+2] = byte(enc.padChar)
dst[di+3] = byte(enc.padChar)
}
}
}
// AppendEncode appends the base64 encoded src to dst
// and returns the extended buffer.
func (enc *Encoding) AppendEncode(dst, src []byte) []byte {
n := enc.EncodedLen(len(src))
dst = slices.Grow(dst, n)
enc.Encode(dst[len(dst):][:n], src)
return dst[:len(dst)+n]
}
// EncodeToString returns the base64 encoding of src.
func (enc *Encoding) EncodeToString(src []byte) string {
buf := make([]byte, enc.EncodedLen(len(src)))
enc.Encode(buf, src)
return string(buf)
}
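// exampleEncodeToStringSketch is an added illustration, not part of the
// original source: it pairs EncodeToString with RFC 4648 section 10 test
// vectors for the standard and raw (unpadded) encodings.
func exampleEncodeToStringSketch() {
	_ = StdEncoding.EncodeToString([]byte("foob"))    // "Zm9vYg=="
	_ = RawStdEncoding.EncodeToString([]byte("foob")) // "Zm9vYg"
}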
type encoder struct {
err error
enc *Encoding
w io.Writer
buf [3]byte // buffered data waiting to be encoded
nbuf int // number of bytes in buf
out [1024]byte // output buffer
}
func (e *encoder) Write(p []byte) (n int, err error) {
if e.err != nil {
return 0, e.err
}
// Leading fringe.
if e.nbuf > 0 {
var i int
for i = 0; i < len(p) && e.nbuf < 3; i++ {
e.buf[e.nbuf] = p[i]
e.nbuf++
}
n += i
p = p[i:]
if e.nbuf < 3 {
return
}
e.enc.Encode(e.out[:], e.buf[:])
if _, e.err = e.w.Write(e.out[:4]); e.err != nil {
return n, e.err
}
e.nbuf = 0
}
// Large interior chunks.
for len(p) >= 3 {
nn := len(e.out) / 4 * 3
if nn > len(p) {
nn = len(p)
nn -= nn % 3
}
e.enc.Encode(e.out[:], p[:nn])
if _, e.err = e.w.Write(e.out[0 : nn/3*4]); e.err != nil {
return n, e.err
}
n += nn
p = p[nn:]
}
// Trailing fringe.
copy(e.buf[:], p)
e.nbuf = len(p)
n += len(p)
return
}
// Close flushes any pending output from the encoder.
// It is an error to call Write after calling Close.
func (e *encoder) Close() error {
// If there's anything left in the buffer, flush it out
if e.err == nil && e.nbuf > 0 {
e.enc.Encode(e.out[:], e.buf[:e.nbuf])
_, e.err = e.w.Write(e.out[:e.enc.EncodedLen(e.nbuf)])
e.nbuf = 0
}
return e.err
}
// NewEncoder returns a new base64 stream encoder. Data written to
// the returned writer will be encoded using enc and then written to w.
// Base64 encodings operate in 4-byte blocks; when finished
// writing, the caller must Close the returned encoder to flush any
// partially written blocks.
func NewEncoder(enc *Encoding, w io.Writer) io.WriteCloser {
return &encoder{enc: enc, w: w}
}
// EncodedLen returns the length in bytes of the base64 encoding
// of an input buffer of length n.
func (enc *Encoding) EncodedLen(n int) int {
if enc.padChar == NoPadding {
return n/3*4 + (n%3*8+5)/6 // minimum # chars at 6 bits per char
}
return (n + 2) / 3 * 4 // minimum # 4-char quanta, 3 bytes each
}
/*
* Decoder
*/
type CorruptInputError int64
func (e CorruptInputError) Error() string {
return "illegal base64 data at input byte " + strconv.FormatInt(int64(e), 10)
}
// decodeQuantum decodes up to 4 base64 bytes. The received parameters are
// the destination buffer dst, the source buffer src and an index in the
// source buffer si.
// It returns the number of bytes read from src, the number of bytes written
// to dst, and an error, if any.
func (enc *Encoding) decodeQuantum(dst, src []byte, si int) (nsi, n int, err error) {
// Decode quantum using the base64 alphabet
var dbuf [4]byte
dlen := 4
// Lift the nil check outside of the loop.
_ = enc.decodeMap
for j := 0; j < len(dbuf); j++ {
if len(src) == si {
switch {
case j == 0:
return si, 0, nil
case j == 1, enc.padChar != NoPadding:
return si, 0, CorruptInputError(si - j)
}
dlen = j
break
}
in := src[si]
si++
out := enc.decodeMap[in]
if out != 0xff {
dbuf[j] = out
continue
}
if in == '\n' || in == '\r' {
j--
continue
}
if rune(in) != enc.padChar {
return si, 0, CorruptInputError(si - 1)
}
// We've reached the end and there's padding
switch j {
case 0, 1:
// incorrect padding
return si, 0, CorruptInputError(si - 1)
case 2:
// "==" is expected, the first "=" is already consumed.
// skip over newlines
for si < len(src) && (src[si] == '\n' || src[si] == '\r') {
si++
}
if si == len(src) {
// not enough padding
return si, 0, CorruptInputError(len(src))
}
if rune(src[si]) != enc.padChar {
// incorrect padding
return si, 0, CorruptInputError(si - 1)
}
si++
}
// skip over newlines
for si < len(src) && (src[si] == '\n' || src[si] == '\r') {
si++
}
if si < len(src) {
// trailing garbage
err = CorruptInputError(si)
}
dlen = j
break
}
// Convert 4x 6bit source bytes into 3 bytes
val := uint(dbuf[0])<<18 | uint(dbuf[1])<<12 | uint(dbuf[2])<<6 | uint(dbuf[3])
dbuf[2], dbuf[1], dbuf[0] = byte(val>>0), byte(val>>8), byte(val>>16)
switch dlen {
case 4:
dst[2] = dbuf[2]
dbuf[2] = 0
fallthrough
case 3:
dst[1] = dbuf[1]
if enc.strict && dbuf[2] != 0 {
return si, 0, CorruptInputError(si - 1)
}
dbuf[1] = 0
fallthrough
case 2:
dst[0] = dbuf[0]
if enc.strict && (dbuf[1] != 0 || dbuf[2] != 0) {
return si, 0, CorruptInputError(si - 2)
}
}
return si, dlen - 1, err
}
// AppendDecode appends the base64 decoded src to dst
// and returns the extended buffer.
// If the input is malformed, it returns the partially decoded src and an error.
// New line characters (\r and \n) are ignored.
func (enc *Encoding) AppendDecode(dst, src []byte) ([]byte, error) {
// Compute the output size without padding to avoid over allocating.
n := len(src)
for n > 0 && rune(src[n-1]) == enc.padChar {
n--
}
n = decodedLen(n, NoPadding)
dst = slices.Grow(dst, n)
n, err := enc.Decode(dst[len(dst):][:n], src)
return dst[:len(dst)+n], err
}
// DecodeString returns the bytes represented by the base64 string s.
// If the input is malformed, it returns the partially decoded data and
// [CorruptInputError]. New line characters (\r and \n) are ignored.
func (enc *Encoding) DecodeString(s string) ([]byte, error) {
dbuf := make([]byte, enc.DecodedLen(len(s)))
n, err := enc.Decode(dbuf, []byte(s))
return dbuf[:n], err
}
type decoder struct {
err error
readErr error // error from r.Read
enc *Encoding
r io.Reader
buf [1024]byte // leftover input
nbuf int
out []byte // leftover decoded output
outbuf [1024 / 4 * 3]byte
}
func (d *decoder) Read(p []byte) (n int, err error) {
// Use leftover decoded output from last read.
if len(d.out) > 0 {
n = copy(p, d.out)
d.out = d.out[n:]
return n, nil
}
if d.err != nil {
return 0, d.err
}
// This code assumes that d.r strips supported whitespace ('\r' and '\n').
// Refill buffer.
for d.nbuf < 4 && d.readErr == nil {
nn := len(p) / 3 * 4
if nn < 4 {
nn = 4
}
if nn > len(d.buf) {
nn = len(d.buf)
}
nn, d.readErr = d.r.Read(d.buf[d.nbuf:nn])
d.nbuf += nn
}
if d.nbuf < 4 {
if d.enc.padChar == NoPadding && d.nbuf > 0 {
// Decode final fragment, without padding.
var nw int
nw, d.err = d.enc.Decode(d.outbuf[:], d.buf[:d.nbuf])
d.nbuf = 0
d.out = d.outbuf[:nw]
n = copy(p, d.out)
d.out = d.out[n:]
if n > 0 || len(p) == 0 && len(d.out) > 0 {
return n, nil
}
if d.err != nil {
return 0, d.err
}
}
d.err = d.readErr
if d.err == io.EOF && d.nbuf > 0 {
d.err = io.ErrUnexpectedEOF
}
return 0, d.err
}
// Decode chunk into p, or d.out and then p if p is too small.
nr := d.nbuf / 4 * 4
nw := d.nbuf / 4 * 3
if nw > len(p) {
nw, d.err = d.enc.Decode(d.outbuf[:], d.buf[:nr])
d.out = d.outbuf[:nw]
n = copy(p, d.out)
d.out = d.out[n:]
} else {
n, d.err = d.enc.Decode(p, d.buf[:nr])
}
d.nbuf -= nr
copy(d.buf[:d.nbuf], d.buf[nr:])
return n, d.err
}
// Decode decodes src using the encoding enc. It writes at most
// [Encoding.DecodedLen](len(src)) bytes to dst and returns the number of bytes
// written. The caller must ensure that dst is large enough to hold all
// the decoded data. If src contains invalid base64 data, it will return the
// number of bytes successfully written and [CorruptInputError].
// New line characters (\r and \n) are ignored.
func (enc *Encoding) Decode(dst, src []byte) (n int, err error) {
if len(src) == 0 {
return 0, nil
}
// Lift the nil check outside of the loop. enc.decodeMap is directly
// used later in this function, to let the compiler know that the
// receiver can't be nil.
_ = enc.decodeMap
si := 0
for strconv.IntSize >= 64 && len(src)-si >= 8 && len(dst)-n >= 8 {
src2 := src[si : si+8]
if dn, ok := assemble64(
enc.decodeMap[src2[0]],
enc.decodeMap[src2[1]],
enc.decodeMap[src2[2]],
enc.decodeMap[src2[3]],
enc.decodeMap[src2[4]],
enc.decodeMap[src2[5]],
enc.decodeMap[src2[6]],
enc.decodeMap[src2[7]],
); ok {
byteorder.BEPutUint64(dst[n:], dn)
n += 6
si += 8
} else {
var ninc int
si, ninc, err = enc.decodeQuantum(dst[n:], src, si)
n += ninc
if err != nil {
return n, err
}
}
}
for len(src)-si >= 4 && len(dst)-n >= 4 {
src2 := src[si : si+4]
if dn, ok := assemble32(
enc.decodeMap[src2[0]],
enc.decodeMap[src2[1]],
enc.decodeMap[src2[2]],
enc.decodeMap[src2[3]],
); ok {
byteorder.BEPutUint32(dst[n:], dn)
n += 3
si += 4
} else {
var ninc int
si, ninc, err = enc.decodeQuantum(dst[n:], src, si)
n += ninc
if err != nil {
return n, err
}
}
}
for si < len(src) {
var ninc int
si, ninc, err = enc.decodeQuantum(dst[n:], src, si)
n += ninc
if err != nil {
return n, err
}
}
return n, err
}
// assemble32 assembles 4 base64 digits into 3 bytes.
// Each digit comes from the decode map, and will be 0xff
// if it came from an invalid character.
func assemble32(n1, n2, n3, n4 byte) (dn uint32, ok bool) {
// Check that all the digits are valid. If any of them was 0xff, their
// bitwise OR will be 0xff.
if n1|n2|n3|n4 == 0xff {
return 0, false
}
return uint32(n1)<<26 |
uint32(n2)<<20 |
uint32(n3)<<14 |
uint32(n4)<<8,
true
}
// assemble64 assembles 8 base64 digits into 6 bytes.
// Each digit comes from the decode map, and will be 0xff
// if it came from an invalid character.
func assemble64(n1, n2, n3, n4, n5, n6, n7, n8 byte) (dn uint64, ok bool) {
// Check that all the digits are valid. If any of them was 0xff, their
// bitwise OR will be 0xff.
if n1|n2|n3|n4|n5|n6|n7|n8 == 0xff {
return 0, false
}
return uint64(n1)<<58 |
uint64(n2)<<52 |
uint64(n3)<<46 |
uint64(n4)<<40 |
uint64(n5)<<34 |
uint64(n6)<<28 |
uint64(n7)<<22 |
uint64(n8)<<16,
true
}
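// Layout note (descriptive, derived from the code above): assemble32 packs the
// four 6-bit digits into the top 24 bits of a uint32 (bits 31..8), so that
// byteorder.BEPutUint32 writes the three decoded bytes first and the caller
// advances n by only 3, ignoring the trailing zero byte. assemble64 does the
// same for eight digits, filling bits 63..16 of a uint64 so that BEPutUint64
// yields six decoded bytes followed by two ignored zero bytes.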
type newlineFilteringReader struct {
wrapped io.Reader
}
func (r *newlineFilteringReader) Read(p []byte) (int, error) {
n, err := r.wrapped.Read(p)
for n > 0 {
offset := 0
for i, b := range p[:n] {
if b != '\r' && b != '\n' {
if i != offset {
p[offset] = b
}
offset++
}
}
if offset > 0 {
return offset, err
}
// Previous buffer entirely whitespace, read again
n, err = r.wrapped.Read(p)
}
return n, err
}
// NewDecoder constructs a new base64 stream decoder.
func NewDecoder(enc *Encoding, r io.Reader) io.Reader {
return &decoder{enc: enc, r: &newlineFilteringReader{r}}
}
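// Illustrative usage sketch (not part of this package): a typical caller wraps
// an input stream with NewDecoder and reads the decoded bytes with io.ReadAll.
// Written from a client's point of view, assuming the standard encoding/base64
// and io packages:
//
//	func decodeStream(r io.Reader) ([]byte, error) {
//		dec := base64.NewDecoder(base64.StdEncoding, r)
//		return io.ReadAll(dec) // decoded payload, or the first decode error
//	}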
// DecodedLen returns the maximum length in bytes of the decoded data
// corresponding to n bytes of base64-encoded data.
func (enc *Encoding) DecodedLen(n int) int {
return decodedLen(n, enc.padChar)
}
func decodedLen(n int, padChar rune) int {
if padChar == NoPadding {
// Unpadded data may end with partial block of 2-3 characters.
return n/4*3 + n%4*6/8
}
// Padded base64 should always be a multiple of 4 characters in length.
return n / 4 * 3
}
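// Worked example: with padding, 8 input characters decode to at most
// 8/4*3 = 6 bytes. Without padding, 10 characters decode to at most
// 10/4*3 + 10%4*6/8 = 6 + 1 = 7 bytes, since the trailing 2-character
// fragment carries 12 bits, i.e. one full byte.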
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package binary implements simple translation between numbers and byte
// sequences and encoding and decoding of varints.
//
// Numbers are translated by reading and writing fixed-size values.
// A fixed-size value is either a fixed-size arithmetic
// type (bool, int8, uint8, int16, float32, complex64, ...)
// or an array or struct containing only fixed-size values.
//
// The varint functions encode and decode single integer values using
// a variable-length encoding; smaller values require fewer bytes.
// For a specification, see
// https://developers.google.com/protocol-buffers/docs/encoding.
//
// This package favors simplicity over efficiency. Clients that require
// high-performance serialization, especially for large data structures,
// should look at more advanced solutions such as the [encoding/gob]
// package or [google.golang.org/protobuf] for protocol buffers.
package binary
import (
"errors"
"io"
"math"
"reflect"
"slices"
"sync"
)
var errBufferTooSmall = errors.New("buffer too small")
// A ByteOrder specifies how to convert byte slices into
// 16-, 32-, or 64-bit unsigned integers.
//
// It is implemented by [LittleEndian], [BigEndian], and [NativeEndian].
type ByteOrder interface {
Uint16([]byte) uint16
Uint32([]byte) uint32
Uint64([]byte) uint64
PutUint16([]byte, uint16)
PutUint32([]byte, uint32)
PutUint64([]byte, uint64)
String() string
}
// AppendByteOrder specifies how to append 16-, 32-, or 64-bit unsigned integers
// into a byte slice.
//
// It is implemented by [LittleEndian], [BigEndian], and [NativeEndian].
type AppendByteOrder interface {
AppendUint16([]byte, uint16) []byte
AppendUint32([]byte, uint32) []byte
AppendUint64([]byte, uint64) []byte
String() string
}
// LittleEndian is the little-endian implementation of [ByteOrder] and [AppendByteOrder].
var LittleEndian littleEndian
// BigEndian is the big-endian implementation of [ByteOrder] and [AppendByteOrder].
var BigEndian bigEndian
type littleEndian struct{}
// Uint16 returns the uint16 representation of b[0:2].
func (littleEndian) Uint16(b []byte) uint16 {
_ = b[1] // bounds check hint to compiler; see golang.org/issue/14808
return uint16(b[0]) | uint16(b[1])<<8
}
// PutUint16 stores v into b[0:2].
func (littleEndian) PutUint16(b []byte, v uint16) {
_ = b[1] // early bounds check to guarantee safety of writes below
b[0] = byte(v)
b[1] = byte(v >> 8)
}
// AppendUint16 appends the bytes of v to b and returns the appended slice.
func (littleEndian) AppendUint16(b []byte, v uint16) []byte {
return append(b,
byte(v),
byte(v>>8),
)
}
// Uint32 returns the uint32 representation of b[0:4].
func (littleEndian) Uint32(b []byte) uint32 {
_ = b[3] // bounds check hint to compiler; see golang.org/issue/14808
return uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24
}
// PutUint32 stores v into b[0:4].
func (littleEndian) PutUint32(b []byte, v uint32) {
_ = b[3] // early bounds check to guarantee safety of writes below
b[0] = byte(v)
b[1] = byte(v >> 8)
b[2] = byte(v >> 16)
b[3] = byte(v >> 24)
}
// AppendUint32 appends the bytes of v to b and returns the appended slice.
func (littleEndian) AppendUint32(b []byte, v uint32) []byte {
return append(b,
byte(v),
byte(v>>8),
byte(v>>16),
byte(v>>24),
)
}
// Uint64 returns the uint64 representation of b[0:8].
func (littleEndian) Uint64(b []byte) uint64 {
_ = b[7] // bounds check hint to compiler; see golang.org/issue/14808
return uint64(b[0]) | uint64(b[1])<<8 | uint64(b[2])<<16 | uint64(b[3])<<24 |
uint64(b[4])<<32 | uint64(b[5])<<40 | uint64(b[6])<<48 | uint64(b[7])<<56
}
// PutUint64 stores v into b[0:8].
func (littleEndian) PutUint64(b []byte, v uint64) {
_ = b[7] // early bounds check to guarantee safety of writes below
b[0] = byte(v)
b[1] = byte(v >> 8)
b[2] = byte(v >> 16)
b[3] = byte(v >> 24)
b[4] = byte(v >> 32)
b[5] = byte(v >> 40)
b[6] = byte(v >> 48)
b[7] = byte(v >> 56)
}
// AppendUint64 appends the bytes of v to b and returns the appended slice.
func (littleEndian) AppendUint64(b []byte, v uint64) []byte {
return append(b,
byte(v),
byte(v>>8),
byte(v>>16),
byte(v>>24),
byte(v>>32),
byte(v>>40),
byte(v>>48),
byte(v>>56),
)
}
func (littleEndian) String() string { return "LittleEndian" }
func (littleEndian) GoString() string { return "binary.LittleEndian" }
type bigEndian struct{}
// Uint16 returns the uint16 representation of b[0:2].
func (bigEndian) Uint16(b []byte) uint16 {
_ = b[1] // bounds check hint to compiler; see golang.org/issue/14808
return uint16(b[1]) | uint16(b[0])<<8
}
// PutUint16 stores v into b[0:2].
func (bigEndian) PutUint16(b []byte, v uint16) {
_ = b[1] // early bounds check to guarantee safety of writes below
b[0] = byte(v >> 8)
b[1] = byte(v)
}
// AppendUint16 appends the bytes of v to b and returns the appended slice.
func (bigEndian) AppendUint16(b []byte, v uint16) []byte {
return append(b,
byte(v>>8),
byte(v),
)
}
// Uint32 returns the uint32 representation of b[0:4].
func (bigEndian) Uint32(b []byte) uint32 {
_ = b[3] // bounds check hint to compiler; see golang.org/issue/14808
return uint32(b[3]) | uint32(b[2])<<8 | uint32(b[1])<<16 | uint32(b[0])<<24
}
// PutUint32 stores v into b[0:4].
func (bigEndian) PutUint32(b []byte, v uint32) {
_ = b[3] // early bounds check to guarantee safety of writes below
b[0] = byte(v >> 24)
b[1] = byte(v >> 16)
b[2] = byte(v >> 8)
b[3] = byte(v)
}
// AppendUint32 appends the bytes of v to b and returns the appended slice.
func (bigEndian) AppendUint32(b []byte, v uint32) []byte {
return append(b,
byte(v>>24),
byte(v>>16),
byte(v>>8),
byte(v),
)
}
// Uint64 returns the uint64 representation of b[0:8].
func (bigEndian) Uint64(b []byte) uint64 {
_ = b[7] // bounds check hint to compiler; see golang.org/issue/14808
return uint64(b[7]) | uint64(b[6])<<8 | uint64(b[5])<<16 | uint64(b[4])<<24 |
uint64(b[3])<<32 | uint64(b[2])<<40 | uint64(b[1])<<48 | uint64(b[0])<<56
}
// PutUint64 stores v into b[0:8].
func (bigEndian) PutUint64(b []byte, v uint64) {
_ = b[7] // early bounds check to guarantee safety of writes below
b[0] = byte(v >> 56)
b[1] = byte(v >> 48)
b[2] = byte(v >> 40)
b[3] = byte(v >> 32)
b[4] = byte(v >> 24)
b[5] = byte(v >> 16)
b[6] = byte(v >> 8)
b[7] = byte(v)
}
// AppendUint64 appends the bytes of v to b and returns the appended slice.
func (bigEndian) AppendUint64(b []byte, v uint64) []byte {
return append(b,
byte(v>>56),
byte(v>>48),
byte(v>>40),
byte(v>>32),
byte(v>>24),
byte(v>>16),
byte(v>>8),
byte(v),
)
}
func (bigEndian) String() string { return "BigEndian" }
func (bigEndian) GoString() string { return "binary.BigEndian" }
func (nativeEndian) String() string { return "NativeEndian" }
func (nativeEndian) GoString() string { return "binary.NativeEndian" }
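// Illustrative usage sketch (not part of this package): encoding and decoding
// a fixed-size integer with an explicit byte order, as a client would write it:
//
//	buf := make([]byte, 4)
//	binary.BigEndian.PutUint32(buf, 0x01020304)
//	// buf is now [0x01, 0x02, 0x03, 0x04]; LittleEndian would give the reverse order.
//	v := binary.BigEndian.Uint32(buf) // v == 0x01020304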
// Read reads structured binary data from r into data.
// Data must be a pointer to a fixed-size value or a slice
// of fixed-size values.
// Bytes read from r are decoded using the specified byte order
// and written to successive fields of the data.
// When decoding boolean values, a zero byte is decoded as false, and
// any other non-zero byte is decoded as true.
// When reading into structs, the field data for fields with
// blank (_) field names is skipped; i.e., blank field names
// may be used for padding.
// When reading into a struct, all non-blank fields must be exported
// or Read may panic.
//
// The error is [io.EOF] only if no bytes were read.
// If an [io.EOF] happens after reading some but not all the bytes,
// Read returns [io.ErrUnexpectedEOF].
func Read(r io.Reader, order ByteOrder, data any) error {
// Fast path for basic types and slices.
if n, _ := intDataSize(data); n != 0 {
bs := make([]byte, n)
if _, err := io.ReadFull(r, bs); err != nil {
return err
}
if decodeFast(bs, order, data) {
return nil
}
}
// Fallback to reflect-based decoding.
v := reflect.ValueOf(data)
size := -1
switch v.Kind() {
case reflect.Pointer:
v = v.Elem()
size = dataSize(v)
case reflect.Slice:
size = dataSize(v)
}
if size < 0 {
return errors.New("binary.Read: invalid type " + reflect.TypeOf(data).String())
}
d := &decoder{order: order, buf: make([]byte, size)}
if _, err := io.ReadFull(r, d.buf); err != nil {
return err
}
d.value(v)
return nil
}
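// Illustrative usage sketch (not part of this package): reading a fixed-size
// struct from a stream r. The header type below is hypothetical and exists only
// for this example; any struct containing solely fixed-size fields works the
// same way:
//
//	type header struct {
//		Magic   uint32
//		Version uint16
//		Flags   uint16
//	}
//
//	var h header
//	if err := binary.Read(r, binary.LittleEndian, &h); err != nil {
//		// io.EOF means no bytes were read; io.ErrUnexpectedEOF means a short read.
//		return err
//	}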
// Decode decodes binary data from buf into data according to
// the given byte order.
// It returns the number of bytes consumed from buf,
// or an error if buf is too small to hold the value.
func Decode(buf []byte, order ByteOrder, data any) (int, error) {
if n, _ := intDataSize(data); n != 0 {
if len(buf) < n {
return 0, errBufferTooSmall
}
if decodeFast(buf, order, data) {
return n, nil
}
}
// Fallback to reflect-based decoding.
v := reflect.ValueOf(data)
size := -1
switch v.Kind() {
case reflect.Pointer:
v = v.Elem()
size = dataSize(v)
case reflect.Slice:
size = dataSize(v)
}
if size < 0 {
return 0, errors.New("binary.Decode: invalid type " + reflect.TypeOf(data).String())
}
if len(buf) < size {
return 0, errBufferTooSmall
}
d := &decoder{order: order, buf: buf[:size]}
d.value(v)
return size, nil
}
func decodeFast(bs []byte, order ByteOrder, data any) bool {
switch data := data.(type) {
case *bool:
*data = bs[0] != 0
case *int8:
*data = int8(bs[0])
case *uint8:
*data = bs[0]
case *int16:
*data = int16(order.Uint16(bs))
case *uint16:
*data = order.Uint16(bs)
case *int32:
*data = int32(order.Uint32(bs))
case *uint32:
*data = order.Uint32(bs)
case *int64:
*data = int64(order.Uint64(bs))
case *uint64:
*data = order.Uint64(bs)
case *float32:
*data = math.Float32frombits(order.Uint32(bs))
case *float64:
*data = math.Float64frombits(order.Uint64(bs))
case []bool:
for i, x := range bs { // Easier to loop over the input for 8-bit values.
data[i] = x != 0
}
case []int8:
for i, x := range bs {
data[i] = int8(x)
}
case []uint8:
copy(data, bs)
case []int16:
for i := range data {
data[i] = int16(order.Uint16(bs[2*i:]))
}
case []uint16:
for i := range data {
data[i] = order.Uint16(bs[2*i:])
}
case []int32:
for i := range data {
data[i] = int32(order.Uint32(bs[4*i:]))
}
case []uint32:
for i := range data {
data[i] = order.Uint32(bs[4*i:])
}
case []int64:
for i := range data {
data[i] = int64(order.Uint64(bs[8*i:]))
}
case []uint64:
for i := range data {
data[i] = order.Uint64(bs[8*i:])
}
case []float32:
for i := range data {
data[i] = math.Float32frombits(order.Uint32(bs[4*i:]))
}
case []float64:
for i := range data {
data[i] = math.Float64frombits(order.Uint64(bs[8*i:]))
}
default:
return false
}
return true
}
// Write writes the binary representation of data into w.
// Data must be a fixed-size value or a slice of fixed-size
// values, or a pointer to such data.
// Boolean values encode as one byte: 1 for true, and 0 for false.
// Bytes written to w are encoded using the specified byte order
// and read from successive fields of the data.
// When writing structs, zero values are written for fields
// with blank (_) field names.
func Write(w io.Writer, order ByteOrder, data any) error {
// Fast path for basic types and slices.
if n, bs := intDataSize(data); n != 0 {
if bs == nil {
bs = make([]byte, n)
encodeFast(bs, order, data)
}
_, err := w.Write(bs)
return err
}
// Fallback to reflect-based encoding.
v := reflect.Indirect(reflect.ValueOf(data))
size := dataSize(v)
if size < 0 {
return errors.New("binary.Write: some values are not fixed-sized in type " + reflect.TypeOf(data).String())
}
buf := make([]byte, size)
e := &encoder{order: order, buf: buf}
e.value(v)
_, err := w.Write(buf)
return err
}
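// Illustrative usage sketch (not part of this package): writing the same
// hypothetical header shown above to a bytes.Buffer:
//
//	var buf bytes.Buffer
//	h := header{Magic: 0xCAFEBABE, Version: 1}
//	if err := binary.Write(&buf, binary.LittleEndian, h); err != nil {
//		return err
//	}
//	// buf.Len() == 8, the value reported by binary.Size(h).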
// Encode encodes the binary representation of data into buf according to
// the given byte order.
// It returns the number of bytes written into buf,
// or an error if buf is too small to hold the encoded value.
func Encode(buf []byte, order ByteOrder, data any) (int, error) {
// Fast path for basic types and slices.
if n, _ := intDataSize(data); n != 0 {
if len(buf) < n {
return 0, errBufferTooSmall
}
encodeFast(buf, order, data)
return n, nil
}
// Fallback to reflect-based encoding.
v := reflect.Indirect(reflect.ValueOf(data))
size := dataSize(v)
if size < 0 {
return 0, errors.New("binary.Encode: some values are not fixed-sized in type " + reflect.TypeOf(data).String())
}
if len(buf) < size {
return 0, errBufferTooSmall
}
e := &encoder{order: order, buf: buf}
e.value(v)
return size, nil
}
// Append appends the binary representation of data to buf.
// buf may be nil, in which case a new buffer will be allocated.
// See [Write] on which data are acceptable.
// It returns the (possibly extended) buffer containing data or an error.
func Append(buf []byte, order ByteOrder, data any) ([]byte, error) {
// Fast path for basic types and slices.
if n, _ := intDataSize(data); n != 0 {
buf, pos := ensure(buf, n)
encodeFast(pos, order, data)
return buf, nil
}
// Fallback to reflect-based encoding.
v := reflect.Indirect(reflect.ValueOf(data))
size := dataSize(v)
if size < 0 {
return nil, errors.New("binary.Append: some values are not fixed-sized in type " + reflect.TypeOf(data).String())
}
buf, pos := ensure(buf, size)
e := &encoder{order: order, buf: pos}
e.value(v)
return buf, nil
}
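// Illustrative usage sketch (not part of this package): Append avoids the
// intermediate buffer that Write allocates by extending a caller-owned slice:
//
//	var out []byte
//	out, err := binary.Append(out, binary.BigEndian, uint16(7))
//	if err != nil {
//		return err
//	}
//	out, err = binary.Append(out, binary.BigEndian, []uint32{1, 2}) // out is now 2+8 = 10 bytes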
func encodeFast(bs []byte, order ByteOrder, data any) {
switch v := data.(type) {
case *bool:
if *v {
bs[0] = 1
} else {
bs[0] = 0
}
case bool:
if v {
bs[0] = 1
} else {
bs[0] = 0
}
case []bool:
for i, x := range v {
if x {
bs[i] = 1
} else {
bs[i] = 0
}
}
case *int8:
bs[0] = byte(*v)
case int8:
bs[0] = byte(v)
case []int8:
for i, x := range v {
bs[i] = byte(x)
}
case *uint8:
bs[0] = *v
case uint8:
bs[0] = v
case []uint8:
copy(bs, v)
case *int16:
order.PutUint16(bs, uint16(*v))
case int16:
order.PutUint16(bs, uint16(v))
case []int16:
for i, x := range v {
order.PutUint16(bs[2*i:], uint16(x))
}
case *uint16:
order.PutUint16(bs, *v)
case uint16:
order.PutUint16(bs, v)
case []uint16:
for i, x := range v {
order.PutUint16(bs[2*i:], x)
}
case *int32:
order.PutUint32(bs, uint32(*v))
case int32:
order.PutUint32(bs, uint32(v))
case []int32:
for i, x := range v {
order.PutUint32(bs[4*i:], uint32(x))
}
case *uint32:
order.PutUint32(bs, *v)
case uint32:
order.PutUint32(bs, v)
case []uint32:
for i, x := range v {
order.PutUint32(bs[4*i:], x)
}
case *int64:
order.PutUint64(bs, uint64(*v))
case int64:
order.PutUint64(bs, uint64(v))
case []int64:
for i, x := range v {
order.PutUint64(bs[8*i:], uint64(x))
}
case *uint64:
order.PutUint64(bs, *v)
case uint64:
order.PutUint64(bs, v)
case []uint64:
for i, x := range v {
order.PutUint64(bs[8*i:], x)
}
case *float32:
order.PutUint32(bs, math.Float32bits(*v))
case float32:
order.PutUint32(bs, math.Float32bits(v))
case []float32:
for i, x := range v {
order.PutUint32(bs[4*i:], math.Float32bits(x))
}
case *float64:
order.PutUint64(bs, math.Float64bits(*v))
case float64:
order.PutUint64(bs, math.Float64bits(v))
case []float64:
for i, x := range v {
order.PutUint64(bs[8*i:], math.Float64bits(x))
}
}
}
// Size returns how many bytes [Write] would generate to encode the value v, which
// must be a fixed-size value or a slice of fixed-size values, or a pointer to such data.
// If v is neither of these, Size returns -1.
func Size(v any) int {
switch data := v.(type) {
case bool, int8, uint8:
return 1
case *bool:
if data == nil {
return -1
}
return 1
case *int8:
if data == nil {
return -1
}
return 1
case *uint8:
if data == nil {
return -1
}
return 1
case []bool:
return len(data)
case []int8:
return len(data)
case []uint8:
return len(data)
case int16, uint16:
return 2
case *int16:
if data == nil {
return -1
}
return 2
case *uint16:
if data == nil {
return -1
}
return 2
case []int16:
return 2 * len(data)
case []uint16:
return 2 * len(data)
case int32, uint32:
return 4
case *int32:
if data == nil {
return -1
}
return 4
case *uint32:
if data == nil {
return -1
}
return 4
case []int32:
return 4 * len(data)
case []uint32:
return 4 * len(data)
case int64, uint64:
return 8
case *int64:
if data == nil {
return -1
}
return 8
case *uint64:
if data == nil {
return -1
}
return 8
case []int64:
return 8 * len(data)
case []uint64:
return 8 * len(data)
case float32:
return 4
case *float32:
if data == nil {
return -1
}
return 4
case float64:
return 8
case *float64:
if data == nil {
return -1
}
return 8
case []float32:
return 4 * len(data)
case []float64:
return 8 * len(data)
}
return dataSize(reflect.Indirect(reflect.ValueOf(v)))
}
var structSize sync.Map // map[reflect.Type]int
// dataSize returns the number of bytes the actual data represented by v occupies in memory.
// For compound structures, it sums the sizes of the elements. Thus, for instance, for a slice
// it returns the length of the slice times the element size and does not count the memory
// occupied by the header. If the type of v is not acceptable, dataSize returns -1.
func dataSize(v reflect.Value) int {
switch v.Kind() {
case reflect.Slice, reflect.Array:
t := v.Type().Elem()
if size, ok := structSize.Load(t); ok {
return size.(int) * v.Len()
}
size := sizeof(t)
if size >= 0 {
if t.Kind() == reflect.Struct {
structSize.Store(t, size)
}
return size * v.Len()
}
case reflect.Struct:
t := v.Type()
if size, ok := structSize.Load(t); ok {
return size.(int)
}
size := sizeof(t)
structSize.Store(t, size)
return size
default:
if v.IsValid() {
return sizeof(v.Type())
}
}
return -1
}
// sizeof returns the size >= 0 of variables for the given type or -1 if the type is not acceptable.
func sizeof(t reflect.Type) int {
switch t.Kind() {
case reflect.Array:
if s := sizeof(t.Elem()); s >= 0 {
return s * t.Len()
}
case reflect.Struct:
sum := 0
for i, n := 0, t.NumField(); i < n; i++ {
s := sizeof(t.Field(i).Type)
if s < 0 {
return -1
}
sum += s
}
return sum
case reflect.Bool,
reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128:
return int(t.Size())
}
return -1
}
type coder struct {
order ByteOrder
buf []byte
offset int
}
type decoder coder
type encoder coder
func (d *decoder) bool() bool {
x := d.buf[d.offset]
d.offset++
return x != 0
}
func (e *encoder) bool(x bool) {
if x {
e.buf[e.offset] = 1
} else {
e.buf[e.offset] = 0
}
e.offset++
}
func (d *decoder) uint8() uint8 {
x := d.buf[d.offset]
d.offset++
return x
}
func (e *encoder) uint8(x uint8) {
e.buf[e.offset] = x
e.offset++
}
func (d *decoder) uint16() uint16 {
x := d.order.Uint16(d.buf[d.offset : d.offset+2])
d.offset += 2
return x
}
func (e *encoder) uint16(x uint16) {
e.order.PutUint16(e.buf[e.offset:e.offset+2], x)
e.offset += 2
}
func (d *decoder) uint32() uint32 {
x := d.order.Uint32(d.buf[d.offset : d.offset+4])
d.offset += 4
return x
}
func (e *encoder) uint32(x uint32) {
e.order.PutUint32(e.buf[e.offset:e.offset+4], x)
e.offset += 4
}
func (d *decoder) uint64() uint64 {
x := d.order.Uint64(d.buf[d.offset : d.offset+8])
d.offset += 8
return x
}
func (e *encoder) uint64(x uint64) {
e.order.PutUint64(e.buf[e.offset:e.offset+8], x)
e.offset += 8
}
func (d *decoder) int8() int8 { return int8(d.uint8()) }
func (e *encoder) int8(x int8) { e.uint8(uint8(x)) }
func (d *decoder) int16() int16 { return int16(d.uint16()) }
func (e *encoder) int16(x int16) { e.uint16(uint16(x)) }
func (d *decoder) int32() int32 { return int32(d.uint32()) }
func (e *encoder) int32(x int32) { e.uint32(uint32(x)) }
func (d *decoder) int64() int64 { return int64(d.uint64()) }
func (e *encoder) int64(x int64) { e.uint64(uint64(x)) }
func (d *decoder) value(v reflect.Value) {
switch v.Kind() {
case reflect.Array:
l := v.Len()
for i := 0; i < l; i++ {
d.value(v.Index(i))
}
case reflect.Struct:
t := v.Type()
l := v.NumField()
for i := 0; i < l; i++ {
// Note: Calling v.CanSet() below is an optimization.
// It would be sufficient to check the field name,
// but creating the StructField info for each field is
// costly (run "go test -bench=ReadStruct" and compare
// results when making changes to this code).
if v := v.Field(i); v.CanSet() || t.Field(i).Name != "_" {
d.value(v)
} else {
d.skip(v)
}
}
case reflect.Slice:
l := v.Len()
for i := 0; i < l; i++ {
d.value(v.Index(i))
}
case reflect.Bool:
v.SetBool(d.bool())
case reflect.Int8:
v.SetInt(int64(d.int8()))
case reflect.Int16:
v.SetInt(int64(d.int16()))
case reflect.Int32:
v.SetInt(int64(d.int32()))
case reflect.Int64:
v.SetInt(d.int64())
case reflect.Uint8:
v.SetUint(uint64(d.uint8()))
case reflect.Uint16:
v.SetUint(uint64(d.uint16()))
case reflect.Uint32:
v.SetUint(uint64(d.uint32()))
case reflect.Uint64:
v.SetUint(d.uint64())
case reflect.Float32:
v.SetFloat(float64(math.Float32frombits(d.uint32())))
case reflect.Float64:
v.SetFloat(math.Float64frombits(d.uint64()))
case reflect.Complex64:
v.SetComplex(complex(
float64(math.Float32frombits(d.uint32())),
float64(math.Float32frombits(d.uint32())),
))
case reflect.Complex128:
v.SetComplex(complex(
math.Float64frombits(d.uint64()),
math.Float64frombits(d.uint64()),
))
}
}
func (e *encoder) value(v reflect.Value) {
switch v.Kind() {
case reflect.Array:
l := v.Len()
for i := 0; i < l; i++ {
e.value(v.Index(i))
}
case reflect.Struct:
t := v.Type()
l := v.NumField()
for i := 0; i < l; i++ {
// see comment for corresponding code in decoder.value()
if v := v.Field(i); v.CanSet() || t.Field(i).Name != "_" {
e.value(v)
} else {
e.skip(v)
}
}
case reflect.Slice:
l := v.Len()
for i := 0; i < l; i++ {
e.value(v.Index(i))
}
case reflect.Bool:
e.bool(v.Bool())
case reflect.Int8:
e.int8(int8(v.Int()))
case reflect.Int16:
e.int16(int16(v.Int()))
case reflect.Int32:
e.int32(int32(v.Int()))
case reflect.Int64:
e.int64(v.Int())
case reflect.Uint8:
e.uint8(uint8(v.Uint()))
case reflect.Uint16:
e.uint16(uint16(v.Uint()))
case reflect.Uint32:
e.uint32(uint32(v.Uint()))
case reflect.Uint64:
e.uint64(v.Uint())
case reflect.Float32:
e.uint32(math.Float32bits(float32(v.Float())))
case reflect.Float64:
e.uint64(math.Float64bits(v.Float()))
case reflect.Complex64:
x := v.Complex()
e.uint32(math.Float32bits(float32(real(x))))
e.uint32(math.Float32bits(float32(imag(x))))
case reflect.Complex128:
x := v.Complex()
e.uint64(math.Float64bits(real(x)))
e.uint64(math.Float64bits(imag(x)))
}
}
func (d *decoder) skip(v reflect.Value) {
d.offset += dataSize(v)
}
func (e *encoder) skip(v reflect.Value) {
n := dataSize(v)
clear(e.buf[e.offset : e.offset+n])
e.offset += n
}
// intDataSize returns the number of bytes required to represent data when encoded,
// and optionally a byte slice containing the already-encoded data if no conversion is necessary.
// It returns 0, nil if the fast path in Read or Write cannot handle the type.
func intDataSize(data any) (int, []byte) {
switch data := data.(type) {
case bool, int8, uint8, *bool, *int8, *uint8:
return 1, nil
case []bool:
return len(data), nil
case []int8:
return len(data), nil
case []uint8:
return len(data), data
case int16, uint16, *int16, *uint16:
return 2, nil
case []int16:
return 2 * len(data), nil
case []uint16:
return 2 * len(data), nil
case int32, uint32, *int32, *uint32:
return 4, nil
case []int32:
return 4 * len(data), nil
case []uint32:
return 4 * len(data), nil
case int64, uint64, *int64, *uint64:
return 8, nil
case []int64:
return 8 * len(data), nil
case []uint64:
return 8 * len(data), nil
case float32, *float32:
return 4, nil
case float64, *float64:
return 8, nil
case []float32:
return 4 * len(data), nil
case []float64:
return 8 * len(data), nil
}
return 0, nil
}
// ensure grows buf to length len(buf) + n and returns the grown buffer
// and a slice starting at the original length of buf (that is, buf2[len(buf):]).
func ensure(buf []byte, n int) (buf2, pos []byte) {
l := len(buf)
buf = slices.Grow(buf, n)[:l+n]
return buf, buf[l:]
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package binary
// This file implements "varint" encoding of 64-bit integers.
// The encoding is:
// - unsigned integers are serialized 7 bits at a time, starting with the
// least significant bits
// - the most significant bit (msb) in each output byte indicates if there
// is a continuation byte (msb = 1)
// - signed integers are mapped to unsigned integers using "zig-zag"
// encoding: Positive values x are written as 2*x + 0, negative values
// are written as 2*(^x) + 1; that is, negative numbers are complemented
// and whether to complement is encoded in bit 0.
//
// Design note:
// At most 10 bytes are needed for 64-bit values. The encoding could
// be more dense: a full 64-bit value needs an extra byte just to hold bit 63.
// Instead, the msb of the previous byte could be used to hold bit 63 since we
// know there can't be more than 64 bits. This is a trivial improvement and
// would reduce the maximum encoding length to 9 bytes. However, it breaks the
// invariant that the msb is always the "continuation bit" and thus makes the
// format incompatible with a varint encoding for larger numbers (say 128-bit).
import (
"errors"
"io"
)
// MaxVarintLenN is the maximum length of a varint-encoded N-bit integer.
const (
MaxVarintLen16 = 3
MaxVarintLen32 = 5
MaxVarintLen64 = 10
)
// AppendUvarint appends the varint-encoded form of x,
// as generated by [PutUvarint], to buf and returns the extended buffer.
func AppendUvarint(buf []byte, x uint64) []byte {
for x >= 0x80 {
buf = append(buf, byte(x)|0x80)
x >>= 7
}
return append(buf, byte(x))
}
// PutUvarint encodes a uint64 into buf and returns the number of bytes written.
// If the buffer is too small, PutUvarint will panic.
func PutUvarint(buf []byte, x uint64) int {
i := 0
for x >= 0x80 {
buf[i] = byte(x) | 0x80
x >>= 7
i++
}
buf[i] = byte(x)
return i + 1
}
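// Worked example: PutUvarint(buf, 300) emits two bytes. 300 is 0b1_0010_1100;
// the low 7 bits (0b010_1100 = 0x2C) are written with the continuation bit set,
// giving 0xAC, and the remaining value 2 fits in the final byte 0x02.
// AppendUvarint produces the same bytes by appending to a slice.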
// Uvarint decodes a uint64 from buf and returns that value and the
// number of bytes read (> 0). If an error occurred, the value is 0
// and the number of bytes n is <= 0 meaning:
// - n == 0: buf too small;
// - n < 0: value larger than 64 bits (overflow) and -n is the number of
// bytes read.
func Uvarint(buf []byte) (uint64, int) {
var x uint64
var s uint
for i, b := range buf {
if i == MaxVarintLen64 {
// Catch byte reads past MaxVarintLen64.
// See issue https://golang.org/issues/41185
return 0, -(i + 1) // overflow
}
if b < 0x80 {
if i == MaxVarintLen64-1 && b > 1 {
return 0, -(i + 1) // overflow
}
return x | uint64(b)<<s, i + 1
}
x |= uint64(b&0x7f) << s
s += 7
}
return 0, 0
}
// AppendVarint appends the varint-encoded form of x,
// as generated by [PutVarint], to buf and returns the extended buffer.
func AppendVarint(buf []byte, x int64) []byte {
ux := uint64(x) << 1
if x < 0 {
ux = ^ux
}
return AppendUvarint(buf, ux)
}
// PutVarint encodes an int64 into buf and returns the number of bytes written.
// If the buffer is too small, PutVarint will panic.
func PutVarint(buf []byte, x int64) int {
ux := uint64(x) << 1
if x < 0 {
ux = ^ux
}
return PutUvarint(buf, ux)
}
// Varint decodes an int64 from buf and returns that value and the
// number of bytes read (> 0). If an error occurred, the value is 0
// and the number of bytes n is <= 0 with the following meaning:
// - n == 0: buf too small;
// - n < 0: value larger than 64 bits (overflow)
// and -n is the number of bytes read.
func Varint(buf []byte) (int64, int) {
ux, n := Uvarint(buf) // ok to continue in presence of error
x := int64(ux >> 1)
if ux&1 != 0 {
x = ^x
}
return x, n
}
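// Worked example of the zig-zag mapping used by PutVarint and Varint:
// 0 → 0, -1 → 1, 1 → 2, -2 → 3, 2 → 4, so small magnitudes of either sign
// stay small after encoding. For instance PutVarint(buf, -3) maps -3 to the
// unsigned value 5, which encodes as the single byte 0x05; Varint reverses
// the mapping (5>>1 == 2, low bit set, so the result is ^2 == -3).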
var errOverflow = errors.New("binary: varint overflows a 64-bit integer")
// ReadUvarint reads an encoded unsigned integer from r and returns it as a uint64.
// The error is [io.EOF] only if no bytes were read.
// If an [io.EOF] happens after reading some but not all the bytes,
// ReadUvarint returns [io.ErrUnexpectedEOF].
func ReadUvarint(r io.ByteReader) (uint64, error) {
var x uint64
var s uint
for i := 0; i < MaxVarintLen64; i++ {
b, err := r.ReadByte()
if err != nil {
if i > 0 && err == io.EOF {
err = io.ErrUnexpectedEOF
}
return x, err
}
if b < 0x80 {
if i == MaxVarintLen64-1 && b > 1 {
return x, errOverflow
}
return x | uint64(b)<<s, nil
}
x |= uint64(b&0x7f) << s
s += 7
}
return x, errOverflow
}
// ReadVarint reads an encoded signed integer from r and returns it as an int64.
// The error is [io.EOF] only if no bytes were read.
// If an [io.EOF] happens after reading some but not all the bytes,
// ReadVarint returns [io.ErrUnexpectedEOF].
func ReadVarint(r io.ByteReader) (int64, error) {
ux, err := ReadUvarint(r) // ok to continue in presence of error
x := int64(ux >> 1)
if ux&1 != 0 {
x = ^x
}
return x, err
}
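// Illustrative usage sketch (not part of this package): decoding a varint from
// any stream that implements io.ByteReader, such as a *bufio.Reader or a
// *bytes.Reader:
//
//	r := bytes.NewReader([]byte{0xAC, 0x02})
//	x, err := binary.ReadUvarint(r) // x == 300, err == nil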
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package csv reads and writes comma-separated values (CSV) files.
// There are many kinds of CSV files; this package supports the format
// described in RFC 4180, except that [Writer] uses LF
// instead of CRLF as newline character by default.
//
// A csv file contains zero or more records of one or more fields per record.
// Each record is separated by the newline character. The final record may
// optionally be followed by a newline character.
//
// field1,field2,field3
//
// White space is considered part of a field.
//
// Carriage returns before newline characters are silently removed.
//
// Blank lines are ignored. A line with only whitespace characters (excluding
// the ending newline character) is not considered a blank line.
//
// Fields which start and stop with the quote character " are called
// quoted-fields. The beginning and ending quote are not part of the
// field.
//
// The source:
//
// normal string,"quoted-field"
//
// results in the fields
//
// {`normal string`, `quoted-field`}
//
// Within a quoted-field a quote character followed by a second quote
// character is considered a single quote.
//
// "the ""word"" is true","a ""quoted-field"""
//
// results in
//
// {`the "word" is true`, `a "quoted-field"`}
//
// Newlines and commas may be included in a quoted-field
//
// "Multi-line
// field","comma is ,"
//
// results in
//
// {`Multi-line
// field`, `comma is ,`}
package csv
import (
"bufio"
"bytes"
"errors"
"fmt"
"io"
"unicode"
"unicode/utf8"
)
// A ParseError is returned for parsing errors.
// Line and column numbers are 1-indexed.
type ParseError struct {
StartLine int // Line where the record starts
Line int // Line where the error occurred
Column int // Column (1-based byte index) where the error occurred
Err error // The actual error
}
func (e *ParseError) Error() string {
if e.Err == ErrFieldCount {
return fmt.Sprintf("record on line %d: %v", e.Line, e.Err)
}
if e.StartLine != e.Line {
return fmt.Sprintf("record on line %d; parse error on line %d, column %d: %v", e.StartLine, e.Line, e.Column, e.Err)
}
return fmt.Sprintf("parse error on line %d, column %d: %v", e.Line, e.Column, e.Err)
}
func (e *ParseError) Unwrap() error { return e.Err }
// These are the errors that can be returned in [ParseError.Err].
var (
ErrBareQuote = errors.New("bare \" in non-quoted-field")
ErrQuote = errors.New("extraneous or missing \" in quoted-field")
ErrFieldCount = errors.New("wrong number of fields")
// Deprecated: ErrTrailingComma is no longer used.
ErrTrailingComma = errors.New("extra delimiter at end of line")
)
var errInvalidDelim = errors.New("csv: invalid field or comment delimiter")
func validDelim(r rune) bool {
return r != 0 && r != '"' && r != '\r' && r != '\n' && utf8.ValidRune(r) && r != utf8.RuneError
}
// A Reader reads records from a CSV-encoded file.
//
// As returned by [NewReader], a Reader expects input conforming to RFC 4180.
// The exported fields can be changed to customize the details before the
// first call to [Reader.Read] or [Reader.ReadAll].
//
// The Reader converts all \r\n sequences in its input to plain \n,
// including in multiline field values, so that the returned data does
// not depend on which line-ending convention an input file uses.
type Reader struct {
// Comma is the field delimiter.
// It is set to comma (',') by NewReader.
// Comma must be a valid rune and must not be \r, \n,
// or the Unicode replacement character (0xFFFD).
Comma rune
// Comment, if not 0, is the comment character. Lines beginning with the
// Comment character without preceding whitespace are ignored.
// With leading whitespace the Comment character becomes part of the
// field, even if TrimLeadingSpace is true.
// Comment must be a valid rune and must not be \r, \n,
// or the Unicode replacement character (0xFFFD).
// It must also not be equal to Comma.
Comment rune
// FieldsPerRecord is the number of expected fields per record.
// If FieldsPerRecord is positive, Read requires each record to
// have the given number of fields. If FieldsPerRecord is 0, Read sets it to
// the number of fields in the first record, so that future records must
// have the same field count. If FieldsPerRecord is negative, no check is
// made and records may have a variable number of fields.
FieldsPerRecord int
// If LazyQuotes is true, a quote may appear in an unquoted field and a
// non-doubled quote may appear in a quoted field.
LazyQuotes bool
// If TrimLeadingSpace is true, leading white space in a field is ignored.
// This is done even if the field delimiter, Comma, is white space.
TrimLeadingSpace bool
// ReuseRecord controls whether calls to Read may return a slice sharing
// the backing array of the previous call's returned slice for performance.
// By default, each call to Read returns newly allocated memory owned by the caller.
ReuseRecord bool
// Deprecated: TrailingComma is no longer used.
TrailingComma bool
r *bufio.Reader
// numLine is the current line being read in the CSV file.
numLine int
// offset is the input stream byte offset of the current reader position.
offset int64
// rawBuffer is a line buffer only used by the readLine method.
rawBuffer []byte
// recordBuffer holds the unescaped fields, one after another.
// The fields can be accessed by using the indexes in fieldIndexes.
// E.g., For the row `a,"b","c""d",e`, recordBuffer will contain `abc"de`
// and fieldIndexes will contain the indexes [1, 2, 5, 6].
recordBuffer []byte
// fieldIndexes is an index of fields inside recordBuffer.
// The i'th field ends at offset fieldIndexes[i] in recordBuffer.
fieldIndexes []int
// fieldPositions is an index of field positions for the
// last record returned by Read.
fieldPositions []position
// lastRecord is a record cache and only used when ReuseRecord == true.
lastRecord []string
}
// NewReader returns a new Reader that reads from r.
func NewReader(r io.Reader) *Reader {
return &Reader{
Comma: ',',
r: bufio.NewReader(r),
}
}
// Read reads one record (a slice of fields) from r.
// If the record has an unexpected number of fields,
// Read returns the record along with the error [ErrFieldCount].
// If the record contains a field that cannot be parsed,
// Read returns a partial record along with the parse error.
// The partial record contains all fields read before the error.
// If there is no data left to be read, Read returns nil, [io.EOF].
// If [Reader.ReuseRecord] is true, the returned slice may be shared
// between multiple calls to Read.
func (r *Reader) Read() (record []string, err error) {
if r.ReuseRecord {
record, err = r.readRecord(r.lastRecord)
r.lastRecord = record
} else {
record, err = r.readRecord(nil)
}
return record, err
}
// FieldPos returns the line and column corresponding to
// the start of the field with the given index in the slice most recently
// returned by [Reader.Read]. Numbering of lines and columns starts at 1;
// columns are counted in bytes, not runes.
//
// If this is called with an out-of-bounds index, it panics.
func (r *Reader) FieldPos(field int) (line, column int) {
if field < 0 || field >= len(r.fieldPositions) {
panic("out of range index passed to FieldPos")
}
p := &r.fieldPositions[field]
return p.line, p.col
}
// InputOffset returns the input stream byte offset of the current reader
// position. The offset gives the location of the end of the most recently
// read row and the beginning of the next row.
func (r *Reader) InputOffset() int64 {
return r.offset
}
// pos holds the position of a field in the current line.
type position struct {
line, col int
}
// ReadAll reads all the remaining records from r.
// Each record is a slice of fields.
// A successful call returns err == nil, not err == [io.EOF]. Because ReadAll is
// defined to read until EOF, it does not treat end of file as an error to be
// reported.
func (r *Reader) ReadAll() (records [][]string, err error) {
for {
record, err := r.readRecord(nil)
if err == io.EOF {
return records, nil
}
if err != nil {
return nil, err
}
records = append(records, record)
}
}
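// Illustrative usage sketch (not part of this package): parsing an in-memory
// CSV document as a client would, using strings.NewReader as the input:
//
//	in := "name,age\n\"Smith, Jane\",42\n"
//	r := csv.NewReader(strings.NewReader(in))
//	records, err := r.ReadAll()
//	if err != nil {
//		return err
//	}
//	// records == [][]string{{"name", "age"}, {"Smith, Jane", "42"}}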
// readLine reads the next line (with the trailing endline).
// If EOF is hit without a trailing endline, it will be omitted.
// If some bytes were read, then the error is never [io.EOF].
// The result is only valid until the next call to readLine.
func (r *Reader) readLine() ([]byte, error) {
line, err := r.r.ReadSlice('\n')
if err == bufio.ErrBufferFull {
r.rawBuffer = append(r.rawBuffer[:0], line...)
for err == bufio.ErrBufferFull {
line, err = r.r.ReadSlice('\n')
r.rawBuffer = append(r.rawBuffer, line...)
}
line = r.rawBuffer
}
readSize := len(line)
if readSize > 0 && err == io.EOF {
err = nil
// For backwards compatibility, drop trailing \r before EOF.
if line[readSize-1] == '\r' {
line = line[:readSize-1]
}
}
r.numLine++
r.offset += int64(readSize)
// Normalize \r\n to \n on all input lines.
if n := len(line); n >= 2 && line[n-2] == '\r' && line[n-1] == '\n' {
line[n-2] = '\n'
line = line[:n-1]
}
return line, err
}
// lengthNL reports the number of bytes for the trailing \n.
func lengthNL(b []byte) int {
if len(b) > 0 && b[len(b)-1] == '\n' {
return 1
}
return 0
}
// nextRune returns the next rune in b or utf8.RuneError.
func nextRune(b []byte) rune {
r, _ := utf8.DecodeRune(b)
return r
}
func (r *Reader) readRecord(dst []string) ([]string, error) {
if r.Comma == r.Comment || !validDelim(r.Comma) || (r.Comment != 0 && !validDelim(r.Comment)) {
return nil, errInvalidDelim
}
// Read line (automatically skipping past empty lines and any comments).
var line []byte
var errRead error
for errRead == nil {
line, errRead = r.readLine()
if r.Comment != 0 && nextRune(line) == r.Comment {
line = nil
continue // Skip comment lines
}
if errRead == nil && len(line) == lengthNL(line) {
line = nil
continue // Skip empty lines
}
break
}
if errRead == io.EOF {
return nil, errRead
}
// Parse each field in the record.
var err error
const quoteLen = len(`"`)
commaLen := utf8.RuneLen(r.Comma)
recLine := r.numLine // Starting line for record
r.recordBuffer = r.recordBuffer[:0]
r.fieldIndexes = r.fieldIndexes[:0]
r.fieldPositions = r.fieldPositions[:0]
pos := position{line: r.numLine, col: 1}
parseField:
for {
if r.TrimLeadingSpace {
i := bytes.IndexFunc(line, func(r rune) bool {
return !unicode.IsSpace(r)
})
if i < 0 {
i = len(line)
pos.col -= lengthNL(line)
}
line = line[i:]
pos.col += i
}
if len(line) == 0 || line[0] != '"' {
// Non-quoted string field
i := bytes.IndexRune(line, r.Comma)
field := line
if i >= 0 {
field = field[:i]
} else {
field = field[:len(field)-lengthNL(field)]
}
// Check to make sure a quote does not appear in field.
if !r.LazyQuotes {
if j := bytes.IndexByte(field, '"'); j >= 0 {
col := pos.col + j
err = &ParseError{StartLine: recLine, Line: r.numLine, Column: col, Err: ErrBareQuote}
break parseField
}
}
r.recordBuffer = append(r.recordBuffer, field...)
r.fieldIndexes = append(r.fieldIndexes, len(r.recordBuffer))
r.fieldPositions = append(r.fieldPositions, pos)
if i >= 0 {
line = line[i+commaLen:]
pos.col += i + commaLen
continue parseField
}
break parseField
} else {
// Quoted string field
fieldPos := pos
line = line[quoteLen:]
pos.col += quoteLen
for {
i := bytes.IndexByte(line, '"')
if i >= 0 {
// Hit next quote.
r.recordBuffer = append(r.recordBuffer, line[:i]...)
line = line[i+quoteLen:]
pos.col += i + quoteLen
switch rn := nextRune(line); {
case rn == '"':
// `""` sequence (append quote).
r.recordBuffer = append(r.recordBuffer, '"')
line = line[quoteLen:]
pos.col += quoteLen
case rn == r.Comma:
// `",` sequence (end of field).
line = line[commaLen:]
pos.col += commaLen
r.fieldIndexes = append(r.fieldIndexes, len(r.recordBuffer))
r.fieldPositions = append(r.fieldPositions, fieldPos)
continue parseField
case lengthNL(line) == len(line):
// `"\n` sequence (end of line).
r.fieldIndexes = append(r.fieldIndexes, len(r.recordBuffer))
r.fieldPositions = append(r.fieldPositions, fieldPos)
break parseField
case r.LazyQuotes:
// `"` sequence (bare quote).
r.recordBuffer = append(r.recordBuffer, '"')
default:
// `"*` sequence (invalid non-escaped quote).
err = &ParseError{StartLine: recLine, Line: r.numLine, Column: pos.col - quoteLen, Err: ErrQuote}
break parseField
}
} else if len(line) > 0 {
// Hit end of line (copy all data so far).
r.recordBuffer = append(r.recordBuffer, line...)
if errRead != nil {
break parseField
}
pos.col += len(line)
line, errRead = r.readLine()
if len(line) > 0 {
pos.line++
pos.col = 1
}
if errRead == io.EOF {
errRead = nil
}
} else {
// Abrupt end of file (EOF or error).
if !r.LazyQuotes && errRead == nil {
err = &ParseError{StartLine: recLine, Line: pos.line, Column: pos.col, Err: ErrQuote}
break parseField
}
r.fieldIndexes = append(r.fieldIndexes, len(r.recordBuffer))
r.fieldPositions = append(r.fieldPositions, fieldPos)
break parseField
}
}
}
}
if err == nil {
err = errRead
}
// Create a single string and create slices out of it.
// This pins the memory of the fields together, but allocates once.
str := string(r.recordBuffer) // Convert to string once to batch allocations
dst = dst[:0]
if cap(dst) < len(r.fieldIndexes) {
dst = make([]string, len(r.fieldIndexes))
}
dst = dst[:len(r.fieldIndexes)]
var preIdx int
for i, idx := range r.fieldIndexes {
dst[i] = str[preIdx:idx]
preIdx = idx
}
// Check or update the expected fields per record.
if r.FieldsPerRecord > 0 {
if len(dst) != r.FieldsPerRecord && err == nil {
err = &ParseError{
StartLine: recLine,
Line: recLine,
Column: 1,
Err: ErrFieldCount,
}
}
} else if r.FieldsPerRecord == 0 {
r.FieldsPerRecord = len(dst)
}
return dst, err
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package csv
import (
"bufio"
"io"
"strings"
"unicode"
"unicode/utf8"
)
// A Writer writes records using CSV encoding.
//
// As returned by [NewWriter], a Writer writes records terminated by a
// newline and uses ',' as the field delimiter. The exported fields can be
// changed to customize the details before
// the first call to [Writer.Write] or [Writer.WriteAll].
//
// [Writer.Comma] is the field delimiter.
//
// If [Writer.UseCRLF] is true,
// the Writer ends each output line with \r\n instead of \n.
//
// The writes of individual records are buffered.
// After all data has been written, the client should call the
// [Writer.Flush] method to guarantee all data has been forwarded to
// the underlying [io.Writer]. Any errors that occurred should
// be checked by calling the [Writer.Error] method.
type Writer struct {
Comma rune // Field delimiter (set to ',' by NewWriter)
UseCRLF bool // True to use \r\n as the line terminator
w *bufio.Writer
}
// NewWriter returns a new Writer that writes to w.
func NewWriter(w io.Writer) *Writer {
return &Writer{
Comma: ',',
w: bufio.NewWriter(w),
}
}
// Write writes a single CSV record to w along with any necessary quoting.
// A record is a slice of strings with each string being one field.
// Writes are buffered, so [Writer.Flush] must eventually be called to ensure
// that the record is written to the underlying [io.Writer].
func (w *Writer) Write(record []string) error {
if !validDelim(w.Comma) {
return errInvalidDelim
}
for n, field := range record {
if n > 0 {
if _, err := w.w.WriteRune(w.Comma); err != nil {
return err
}
}
// If we don't have to have a quoted field then just
// write out the field and continue to the next field.
if !w.fieldNeedsQuotes(field) {
if _, err := w.w.WriteString(field); err != nil {
return err
}
continue
}
if err := w.w.WriteByte('"'); err != nil {
return err
}
for len(field) > 0 {
// Search for special characters.
i := strings.IndexAny(field, "\"\r\n")
if i < 0 {
i = len(field)
}
// Copy verbatim everything before the special character.
if _, err := w.w.WriteString(field[:i]); err != nil {
return err
}
field = field[i:]
// Encode the special character.
if len(field) > 0 {
var err error
switch field[0] {
case '"':
_, err = w.w.WriteString(`""`)
case '\r':
if !w.UseCRLF {
err = w.w.WriteByte('\r')
}
case '\n':
if w.UseCRLF {
_, err = w.w.WriteString("\r\n")
} else {
err = w.w.WriteByte('\n')
}
}
field = field[1:]
if err != nil {
return err
}
}
}
if err := w.w.WriteByte('"'); err != nil {
return err
}
}
var err error
if w.UseCRLF {
_, err = w.w.WriteString("\r\n")
} else {
err = w.w.WriteByte('\n')
}
return err
}
// Flush writes any buffered data to the underlying [io.Writer].
// To check if an error occurred during Flush, call [Writer.Error].
func (w *Writer) Flush() {
w.w.Flush()
}
// Error reports any error that has occurred during
// a previous [Writer.Write] or [Writer.Flush].
func (w *Writer) Error() error {
_, err := w.w.Write(nil)
return err
}
// WriteAll writes multiple CSV records to w using [Writer.Write] and
// then calls [Writer.Flush], returning any error from the Flush.
func (w *Writer) WriteAll(records [][]string) error {
for _, record := range records {
err := w.Write(record)
if err != nil {
return err
}
}
return w.w.Flush()
}
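// Illustrative usage sketch (not part of this package): writing records and
// checking for deferred write errors after flushing:
//
//	w := csv.NewWriter(os.Stdout)
//	w.Write([]string{"name", "age"})
//	w.Write([]string{"Smith, Jane", "42"}) // the comma forces quoting: "Smith, Jane",42
//	w.Flush()
//	if err := w.Error(); err != nil {
//		return err
//	}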
// fieldNeedsQuotes reports whether our field must be enclosed in quotes.
// Fields with a Comma, fields with a quote or newline, and
// fields which start with a space must be enclosed in quotes.
// We used to quote empty strings, but we do not anymore (as of Go 1.4).
// The two representations should be equivalent, but Postgres distinguishes
// quoted vs non-quoted empty string during database imports, and it has
// an option to force the quoted behavior for non-quoted CSV but it has
// no option to force the non-quoted behavior for quoted CSV, making
// CSV with quoted empty strings strictly less useful.
// Not quoting the empty string also makes this package match the behavior
// of Microsoft Excel and Google Drive.
// For Postgres, quote the data terminating string `\.`.
func (w *Writer) fieldNeedsQuotes(field string) bool {
if field == "" {
return false
}
if field == `\.` {
return true
}
if w.Comma < utf8.RuneSelf {
for i := 0; i < len(field); i++ {
c := field[i]
if c == '\n' || c == '\r' || c == '"' || c == byte(w.Comma) {
return true
}
}
} else {
if strings.ContainsRune(field, w.Comma) || strings.ContainsAny(field, "\"\r\n") {
return true
}
}
r1, _ := utf8.DecodeRuneInString(field)
return unicode.IsSpace(r1)
}
// Code generated by go run decgen.go -output dec_helpers.go; DO NOT EDIT.
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package gob
import (
"math"
"reflect"
)
var decArrayHelper = map[reflect.Kind]decHelper{
reflect.Bool: decBoolArray,
reflect.Complex64: decComplex64Array,
reflect.Complex128: decComplex128Array,
reflect.Float32: decFloat32Array,
reflect.Float64: decFloat64Array,
reflect.Int: decIntArray,
reflect.Int16: decInt16Array,
reflect.Int32: decInt32Array,
reflect.Int64: decInt64Array,
reflect.Int8: decInt8Array,
reflect.String: decStringArray,
reflect.Uint: decUintArray,
reflect.Uint16: decUint16Array,
reflect.Uint32: decUint32Array,
reflect.Uint64: decUint64Array,
reflect.Uintptr: decUintptrArray,
}
var decSliceHelper = map[reflect.Kind]decHelper{
reflect.Bool: decBoolSlice,
reflect.Complex64: decComplex64Slice,
reflect.Complex128: decComplex128Slice,
reflect.Float32: decFloat32Slice,
reflect.Float64: decFloat64Slice,
reflect.Int: decIntSlice,
reflect.Int16: decInt16Slice,
reflect.Int32: decInt32Slice,
reflect.Int64: decInt64Slice,
reflect.Int8: decInt8Slice,
reflect.String: decStringSlice,
reflect.Uint: decUintSlice,
reflect.Uint16: decUint16Slice,
reflect.Uint32: decUint32Slice,
reflect.Uint64: decUint64Slice,
reflect.Uintptr: decUintptrSlice,
}
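// These tables are consulted by the decoder to pick a type-specialized fast
// path for arrays and slices of common element kinds, bypassing the generic
// per-element reflection loop. Each helper returns false when the value's
// concrete type is not the plain builtin type (for example a named type whose
// kind is int), in which case the caller falls back to the reflection-based
// decode path.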
func decBoolArray(state *decoderState, v reflect.Value, length int, ovfl error) bool {
// Can only slice if it is addressable.
if !v.CanAddr() {
return false
}
return decBoolSlice(state, v.Slice(0, v.Len()), length, ovfl)
}
func decBoolSlice(state *decoderState, v reflect.Value, length int, ovfl error) bool {
slice, ok := v.Interface().([]bool)
if !ok {
// It is kind bool but not type bool. TODO: We can handle this unsafely.
return false
}
for i := 0; i < length; i++ {
if state.b.Len() == 0 {
errorf("decoding bool array or slice: length exceeds input size (%d elements)", length)
}
if i >= len(slice) {
// This is a slice that we only partially allocated.
growSlice(v, &slice, length)
}
slice[i] = state.decodeUint() != 0
}
return true
}
func decComplex64Array(state *decoderState, v reflect.Value, length int, ovfl error) bool {
// Can only slice if it is addressable.
if !v.CanAddr() {
return false
}
return decComplex64Slice(state, v.Slice(0, v.Len()), length, ovfl)
}
func decComplex64Slice(state *decoderState, v reflect.Value, length int, ovfl error) bool {
slice, ok := v.Interface().([]complex64)
if !ok {
// It is kind complex64 but not type complex64. TODO: We can handle this unsafely.
return false
}
for i := 0; i < length; i++ {
if state.b.Len() == 0 {
errorf("decoding complex64 array or slice: length exceeds input size (%d elements)", length)
}
if i >= len(slice) {
// This is a slice that we only partially allocated.
growSlice(v, &slice, length)
}
real := float32FromBits(state.decodeUint(), ovfl)
imag := float32FromBits(state.decodeUint(), ovfl)
slice[i] = complex(float32(real), float32(imag))
}
return true
}
func decComplex128Array(state *decoderState, v reflect.Value, length int, ovfl error) bool {
// Can only slice if it is addressable.
if !v.CanAddr() {
return false
}
return decComplex128Slice(state, v.Slice(0, v.Len()), length, ovfl)
}
func decComplex128Slice(state *decoderState, v reflect.Value, length int, ovfl error) bool {
slice, ok := v.Interface().([]complex128)
if !ok {
// It is kind complex128 but not type complex128. TODO: We can handle this unsafely.
return false
}
for i := 0; i < length; i++ {
if state.b.Len() == 0 {
errorf("decoding complex128 array or slice: length exceeds input size (%d elements)", length)
}
if i >= len(slice) {
// This is a slice that we only partially allocated.
growSlice(v, &slice, length)
}
real := float64FromBits(state.decodeUint())
imag := float64FromBits(state.decodeUint())
slice[i] = complex(real, imag)
}
return true
}
func decFloat32Array(state *decoderState, v reflect.Value, length int, ovfl error) bool {
// Can only slice if it is addressable.
if !v.CanAddr() {
return false
}
return decFloat32Slice(state, v.Slice(0, v.Len()), length, ovfl)
}
func decFloat32Slice(state *decoderState, v reflect.Value, length int, ovfl error) bool {
slice, ok := v.Interface().([]float32)
if !ok {
// It is kind float32 but not type float32. TODO: We can handle this unsafely.
return false
}
for i := 0; i < length; i++ {
if state.b.Len() == 0 {
errorf("decoding float32 array or slice: length exceeds input size (%d elements)", length)
}
if i >= len(slice) {
// This is a slice that we only partially allocated.
growSlice(v, &slice, length)
}
slice[i] = float32(float32FromBits(state.decodeUint(), ovfl))
}
return true
}
func decFloat64Array(state *decoderState, v reflect.Value, length int, ovfl error) bool {
// Can only slice if it is addressable.
if !v.CanAddr() {
return false
}
return decFloat64Slice(state, v.Slice(0, v.Len()), length, ovfl)
}
func decFloat64Slice(state *decoderState, v reflect.Value, length int, ovfl error) bool {
slice, ok := v.Interface().([]float64)
if !ok {
// It is kind float64 but not type float64. TODO: We can handle this unsafely.
return false
}
for i := 0; i < length; i++ {
if state.b.Len() == 0 {
errorf("decoding float64 array or slice: length exceeds input size (%d elements)", length)
}
if i >= len(slice) {
// This is a slice that we only partially allocated.
growSlice(v, &slice, length)
}
slice[i] = float64FromBits(state.decodeUint())
}
return true
}
func decIntArray(state *decoderState, v reflect.Value, length int, ovfl error) bool {
// Can only slice if it is addressable.
if !v.CanAddr() {
return false
}
return decIntSlice(state, v.Slice(0, v.Len()), length, ovfl)
}
func decIntSlice(state *decoderState, v reflect.Value, length int, ovfl error) bool {
slice, ok := v.Interface().([]int)
if !ok {
// It is kind int but not type int. TODO: We can handle this unsafely.
return false
}
for i := 0; i < length; i++ {
if state.b.Len() == 0 {
errorf("decoding int array or slice: length exceeds input size (%d elements)", length)
}
if i >= len(slice) {
// This is a slice that we only partially allocated.
growSlice(v, &slice, length)
}
x := state.decodeInt()
// MinInt and MaxInt
if x < ^int64(^uint(0)>>1) || int64(^uint(0)>>1) < x {
error_(ovfl)
}
slice[i] = int(x)
}
return true
}
func decInt16Array(state *decoderState, v reflect.Value, length int, ovfl error) bool {
// Can only slice if it is addressable.
if !v.CanAddr() {
return false
}
return decInt16Slice(state, v.Slice(0, v.Len()), length, ovfl)
}
func decInt16Slice(state *decoderState, v reflect.Value, length int, ovfl error) bool {
slice, ok := v.Interface().([]int16)
if !ok {
// It is kind int16 but not type int16. TODO: We can handle this unsafely.
return false
}
for i := 0; i < length; i++ {
if state.b.Len() == 0 {
errorf("decoding int16 array or slice: length exceeds input size (%d elements)", length)
}
if i >= len(slice) {
// This is a slice that we only partially allocated.
growSlice(v, &slice, length)
}
x := state.decodeInt()
if x < math.MinInt16 || math.MaxInt16 < x {
error_(ovfl)
}
slice[i] = int16(x)
}
return true
}
func decInt32Array(state *decoderState, v reflect.Value, length int, ovfl error) bool {
// Can only slice if it is addressable.
if !v.CanAddr() {
return false
}
return decInt32Slice(state, v.Slice(0, v.Len()), length, ovfl)
}
func decInt32Slice(state *decoderState, v reflect.Value, length int, ovfl error) bool {
slice, ok := v.Interface().([]int32)
if !ok {
// It is kind int32 but not type int32. TODO: We can handle this unsafely.
return false
}
for i := 0; i < length; i++ {
if state.b.Len() == 0 {
errorf("decoding int32 array or slice: length exceeds input size (%d elements)", length)
}
if i >= len(slice) {
// This is a slice that we only partially allocated.
growSlice(v, &slice, length)
}
x := state.decodeInt()
if x < math.MinInt32 || math.MaxInt32 < x {
error_(ovfl)
}
slice[i] = int32(x)
}
return true
}
func decInt64Array(state *decoderState, v reflect.Value, length int, ovfl error) bool {
// Can only slice if it is addressable.
if !v.CanAddr() {
return false
}
return decInt64Slice(state, v.Slice(0, v.Len()), length, ovfl)
}
func decInt64Slice(state *decoderState, v reflect.Value, length int, ovfl error) bool {
slice, ok := v.Interface().([]int64)
if !ok {
// It is kind int64 but not type int64. TODO: We can handle this unsafely.
return false
}
for i := 0; i < length; i++ {
if state.b.Len() == 0 {
errorf("decoding int64 array or slice: length exceeds input size (%d elements)", length)
}
if i >= len(slice) {
// This is a slice that we only partially allocated.
growSlice(v, &slice, length)
}
slice[i] = state.decodeInt()
}
return true
}
func decInt8Array(state *decoderState, v reflect.Value, length int, ovfl error) bool {
// Can only slice if it is addressable.
if !v.CanAddr() {
return false
}
return decInt8Slice(state, v.Slice(0, v.Len()), length, ovfl)
}
func decInt8Slice(state *decoderState, v reflect.Value, length int, ovfl error) bool {
slice, ok := v.Interface().([]int8)
if !ok {
// It is kind int8 but not type int8. TODO: We can handle this unsafely.
return false
}
for i := 0; i < length; i++ {
if state.b.Len() == 0 {
errorf("decoding int8 array or slice: length exceeds input size (%d elements)", length)
}
if i >= len(slice) {
// This is a slice that we only partially allocated.
growSlice(v, &slice, length)
}
x := state.decodeInt()
if x < math.MinInt8 || math.MaxInt8 < x {
error_(ovfl)
}
slice[i] = int8(x)
}
return true
}
func decStringArray(state *decoderState, v reflect.Value, length int, ovfl error) bool {
// Can only slice if it is addressable.
if !v.CanAddr() {
return false
}
return decStringSlice(state, v.Slice(0, v.Len()), length, ovfl)
}
func decStringSlice(state *decoderState, v reflect.Value, length int, ovfl error) bool {
slice, ok := v.Interface().([]string)
if !ok {
// It is kind string but not type string. TODO: We can handle this unsafely.
return false
}
for i := 0; i < length; i++ {
if state.b.Len() == 0 {
errorf("decoding string array or slice: length exceeds input size (%d elements)", length)
}
if i >= len(slice) {
// This is a slice that we only partially allocated.
growSlice(v, &slice, length)
}
u := state.decodeUint()
n := int(u)
if n < 0 || uint64(n) != u || n > state.b.Len() {
errorf("length of string exceeds input size (%d bytes)", u)
}
if n > state.b.Len() {
errorf("string data too long for buffer: %d", n)
}
// Read the data.
data := state.b.Bytes()
if len(data) < n {
errorf("invalid string length %d: exceeds input size %d", n, len(data))
}
slice[i] = string(data[:n])
state.b.Drop(n)
}
return true
}
func decUintArray(state *decoderState, v reflect.Value, length int, ovfl error) bool {
// Can only slice if it is addressable.
if !v.CanAddr() {
return false
}
return decUintSlice(state, v.Slice(0, v.Len()), length, ovfl)
}
func decUintSlice(state *decoderState, v reflect.Value, length int, ovfl error) bool {
slice, ok := v.Interface().([]uint)
if !ok {
// It is kind uint but not type uint. TODO: We can handle this unsafely.
return false
}
for i := 0; i < length; i++ {
if state.b.Len() == 0 {
errorf("decoding uint array or slice: length exceeds input size (%d elements)", length)
}
if i >= len(slice) {
// This is a slice that we only partially allocated.
growSlice(v, &slice, length)
}
x := state.decodeUint()
/*TODO if math.MaxUint32 < x {
error_(ovfl)
}*/
slice[i] = uint(x)
}
return true
}
func decUint16Array(state *decoderState, v reflect.Value, length int, ovfl error) bool {
// Can only slice if it is addressable.
if !v.CanAddr() {
return false
}
return decUint16Slice(state, v.Slice(0, v.Len()), length, ovfl)
}
func decUint16Slice(state *decoderState, v reflect.Value, length int, ovfl error) bool {
slice, ok := v.Interface().([]uint16)
if !ok {
// It is kind uint16 but not type uint16. TODO: We can handle this unsafely.
return false
}
for i := 0; i < length; i++ {
if state.b.Len() == 0 {
errorf("decoding uint16 array or slice: length exceeds input size (%d elements)", length)
}
if i >= len(slice) {
// This is a slice that we only partially allocated.
growSlice(v, &slice, length)
}
x := state.decodeUint()
if math.MaxUint16 < x {
error_(ovfl)
}
slice[i] = uint16(x)
}
return true
}
func decUint32Array(state *decoderState, v reflect.Value, length int, ovfl error) bool {
// Can only slice if it is addressable.
if !v.CanAddr() {
return false
}
return decUint32Slice(state, v.Slice(0, v.Len()), length, ovfl)
}
func decUint32Slice(state *decoderState, v reflect.Value, length int, ovfl error) bool {
slice, ok := v.Interface().([]uint32)
if !ok {
// It is kind uint32 but not type uint32. TODO: We can handle this unsafely.
return false
}
for i := 0; i < length; i++ {
if state.b.Len() == 0 {
errorf("decoding uint32 array or slice: length exceeds input size (%d elements)", length)
}
if i >= len(slice) {
// This is a slice that we only partially allocated.
growSlice(v, &slice, length)
}
x := state.decodeUint()
if math.MaxUint32 < x {
error_(ovfl)
}
slice[i] = uint32(x)
}
return true
}
func decUint64Array(state *decoderState, v reflect.Value, length int, ovfl error) bool {
// Can only slice if it is addressable.
if !v.CanAddr() {
return false
}
return decUint64Slice(state, v.Slice(0, v.Len()), length, ovfl)
}
func decUint64Slice(state *decoderState, v reflect.Value, length int, ovfl error) bool {
slice, ok := v.Interface().([]uint64)
if !ok {
// It is kind uint64 but not type uint64. TODO: We can handle this unsafely.
return false
}
for i := 0; i < length; i++ {
if state.b.Len() == 0 {
errorf("decoding uint64 array or slice: length exceeds input size (%d elements)", length)
}
if i >= len(slice) {
// This is a slice that we only partially allocated.
growSlice(v, &slice, length)
}
slice[i] = state.decodeUint()
}
return true
}
func decUintptrArray(state *decoderState, v reflect.Value, length int, ovfl error) bool {
// Can only slice if it is addressable.
if !v.CanAddr() {
return false
}
return decUintptrSlice(state, v.Slice(0, v.Len()), length, ovfl)
}
func decUintptrSlice(state *decoderState, v reflect.Value, length int, ovfl error) bool {
slice, ok := v.Interface().([]uintptr)
if !ok {
// It is kind uintptr but not type uintptr. TODO: We can handle this unsafely.
return false
}
for i := 0; i < length; i++ {
if state.b.Len() == 0 {
errorf("decoding uintptr array or slice: length exceeds input size (%d elements)", length)
}
if i >= len(slice) {
// This is a slice that we only partially allocated.
growSlice(v, &slice, length)
}
x := state.decodeUint()
if uint64(^uintptr(0)) < x {
error_(ovfl)
}
slice[i] = uintptr(x)
}
return true
}
// growSlice is called for a slice that we only partially allocated,
// to grow it up to length.
func growSlice[E any](v reflect.Value, ps *[]E, length int) {
var zero E
s := *ps
s = append(s, zero)
cp := cap(s)
if cp > length {
cp = length
}
s = s[:cp]
v.Set(reflect.ValueOf(s))
*ps = s
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:generate go run decgen.go -output dec_helpers.go
package gob
import (
"encoding"
"errors"
"internal/saferio"
"io"
"math"
"math/bits"
"reflect"
)
var (
errBadUint = errors.New("gob: encoded unsigned integer out of range")
errBadType = errors.New("gob: unknown type id or corrupted data")
errRange = errors.New("gob: bad data: field numbers out of bounds")
)
type decHelper func(state *decoderState, v reflect.Value, length int, ovfl error) bool
// decoderState is the execution state of an instance of the decoder. A new state
// is created for nested objects.
type decoderState struct {
dec *Decoder
// The buffer is stored with an extra indirection because it may be replaced
// if we load a type during decode (when reading an interface value).
b *decBuffer
fieldnum int // the last field number read.
next *decoderState // for free list
}
// decBuffer is an extremely simple, fast implementation of a read-only byte buffer.
// It is initialized by calling SetBytes with the message data; callers then consume it through Read, ReadByte, Bytes and Drop.
type decBuffer struct {
data []byte
offset int // Read offset.
}
func (d *decBuffer) Read(p []byte) (int, error) {
n := copy(p, d.data[d.offset:])
if n == 0 && len(p) != 0 {
return 0, io.EOF
}
d.offset += n
return n, nil
}
func (d *decBuffer) Drop(n int) {
if n > d.Len() {
panic("drop")
}
d.offset += n
}
func (d *decBuffer) ReadByte() (byte, error) {
if d.offset >= len(d.data) {
return 0, io.EOF
}
c := d.data[d.offset]
d.offset++
return c, nil
}
func (d *decBuffer) Len() int {
return len(d.data) - d.offset
}
func (d *decBuffer) Bytes() []byte {
return d.data[d.offset:]
}
// SetBytes sets the buffer to the bytes, discarding any existing data.
func (d *decBuffer) SetBytes(data []byte) {
d.data = data
d.offset = 0
}
func (d *decBuffer) Reset() {
d.data = d.data[0:0]
d.offset = 0
}
// We pass the buffer separately for easier testing of the infrastructure
// without requiring a full Decoder.
func (dec *Decoder) newDecoderState(buf *decBuffer) *decoderState {
d := dec.freeList
if d == nil {
d = new(decoderState)
d.dec = dec
} else {
dec.freeList = d.next
}
d.b = buf
return d
}
func (dec *Decoder) freeDecoderState(d *decoderState) {
d.next = dec.freeList
dec.freeList = d
}
func overflow(name string) error {
return errors.New(`value for "` + name + `" out of range`)
}
// decodeUintReader reads an encoded unsigned integer from an io.Reader.
// Used only by the Decoder to read the message length.
func decodeUintReader(r io.Reader, buf []byte) (x uint64, width int, err error) {
width = 1
n, err := io.ReadFull(r, buf[0:width])
if n == 0 {
return
}
b := buf[0]
if b <= 0x7f {
return uint64(b), width, nil
}
n = -int(int8(b))
if n > uint64Size {
err = errBadUint
return
}
width, err = io.ReadFull(r, buf[0:n])
if err != nil {
if err == io.EOF {
err = io.ErrUnexpectedEOF
}
return
}
// Could check that the high byte is zero but it's not worth it.
for _, b := range buf[0:width] {
x = x<<8 | uint64(b)
}
width++ // +1 for length byte
return
}
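// As a worked example of the format read above (illustrative only; this
// variable is an assumption for documentation and is not used by the package):
// a value <= 0x7f occupies a single byte, while a larger value is a negated
// byte count followed by the value in big-endian order.
var exampleEncodedUints = map[uint64][]byte{
	0x05:  {0x05},             // fits in one byte
	0x80:  {0xff, 0x80},       // count byte -1, then one data byte
	0x100: {0xfe, 0x01, 0x00}, // count byte -2, then two data bytes
}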
// decodeUint reads an encoded unsigned integer from state.r.
// Does not check for overflow.
func (state *decoderState) decodeUint() (x uint64) {
b, err := state.b.ReadByte()
if err != nil {
error_(err)
}
if b <= 0x7f {
return uint64(b)
}
n := -int(int8(b))
if n > uint64Size {
error_(errBadUint)
}
buf := state.b.Bytes()
if len(buf) < n {
errorf("invalid uint data length %d: exceeds input size %d", n, len(buf))
}
// Don't need to check error; it's safe to loop regardless.
// Could check that the high byte is zero but it's not worth it.
for _, b := range buf[0:n] {
x = x<<8 | uint64(b)
}
state.b.Drop(n)
return x
}
// decodeInt reads an encoded signed integer from state.r.
// Does not check for overflow.
func (state *decoderState) decodeInt() int64 {
x := state.decodeUint()
if x&1 != 0 {
return ^int64(x >> 1)
}
return int64(x >> 1)
}
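// A worked example of the signed mapping (illustrative only; this variable is
// an assumption for documentation and is not used by the package): the encoder
// stores the sign in the low bit, so decoded values alternate around zero as
// the encoded uint grows.
var exampleDecodedInts = map[uint64]int64{
	0: 0, 1: -1, 2: 1, 3: -2, 4: 2,
}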
// getLength decodes the next uint and makes sure it is a possible
// size for a data item that follows, which means it must fit in a
// non-negative int and fit in the buffer.
func (state *decoderState) getLength() (int, bool) {
n := int(state.decodeUint())
if n < 0 || state.b.Len() < n || tooBig <= n {
return 0, false
}
return n, true
}
// decOp is the signature of a decoding operator for a given type.
type decOp func(i *decInstr, state *decoderState, v reflect.Value)
// The 'instructions' of the decoding machine
type decInstr struct {
op decOp
field int // field number of the wire type
index []int // field access indices for destination type
ovfl error // error message for overflow/underflow (for arrays, of the elements)
}
// ignoreUint discards a uint value with no destination.
func ignoreUint(i *decInstr, state *decoderState, v reflect.Value) {
state.decodeUint()
}
// ignoreTwoUints discards a uint value with no destination. It's used to skip
// complex values.
func ignoreTwoUints(i *decInstr, state *decoderState, v reflect.Value) {
state.decodeUint()
state.decodeUint()
}
// Since the encoder writes no zeros, if we arrive at a decoder we have
// a value to extract and store. The field number has already been read
// (it's how we knew to call this decoder).
// Each decoder is responsible for handling any indirections associated
// with the data structure. If any pointer so reached is nil, allocation must
// be done.
// decAlloc takes a value and returns a settable value that can
// be assigned to. If the value is a pointer, decAlloc guarantees it points to storage.
// The callers to the individual decoders are expected to have used decAlloc.
// The individual decoders don't need it.
func decAlloc(v reflect.Value) reflect.Value {
for v.Kind() == reflect.Pointer {
if v.IsNil() {
v.Set(reflect.New(v.Type().Elem()))
}
v = v.Elem()
}
return v
}
// decBool decodes a uint and stores it as a boolean in value.
func decBool(i *decInstr, state *decoderState, value reflect.Value) {
value.SetBool(state.decodeUint() != 0)
}
// decInt8 decodes an integer and stores it as an int8 in value.
func decInt8(i *decInstr, state *decoderState, value reflect.Value) {
v := state.decodeInt()
if v < math.MinInt8 || math.MaxInt8 < v {
error_(i.ovfl)
}
value.SetInt(v)
}
// decUint8 decodes an unsigned integer and stores it as a uint8 in value.
func decUint8(i *decInstr, state *decoderState, value reflect.Value) {
v := state.decodeUint()
if math.MaxUint8 < v {
error_(i.ovfl)
}
value.SetUint(v)
}
// decInt16 decodes an integer and stores it as an int16 in value.
func decInt16(i *decInstr, state *decoderState, value reflect.Value) {
v := state.decodeInt()
if v < math.MinInt16 || math.MaxInt16 < v {
error_(i.ovfl)
}
value.SetInt(v)
}
// decUint16 decodes an unsigned integer and stores it as a uint16 in value.
func decUint16(i *decInstr, state *decoderState, value reflect.Value) {
v := state.decodeUint()
if math.MaxUint16 < v {
error_(i.ovfl)
}
value.SetUint(v)
}
// decInt32 decodes an integer and stores it as an int32 in value.
func decInt32(i *decInstr, state *decoderState, value reflect.Value) {
v := state.decodeInt()
if v < math.MinInt32 || math.MaxInt32 < v {
error_(i.ovfl)
}
value.SetInt(v)
}
// decUint32 decodes an unsigned integer and stores it as a uint32 in value.
func decUint32(i *decInstr, state *decoderState, value reflect.Value) {
v := state.decodeUint()
if math.MaxUint32 < v {
error_(i.ovfl)
}
value.SetUint(v)
}
// decInt64 decodes an integer and stores it as an int64 in value.
func decInt64(i *decInstr, state *decoderState, value reflect.Value) {
v := state.decodeInt()
value.SetInt(v)
}
// decUint64 decodes an unsigned integer and stores it as a uint64 in value.
func decUint64(i *decInstr, state *decoderState, value reflect.Value) {
v := state.decodeUint()
value.SetUint(v)
}
// Floating-point numbers are transmitted as uint64s holding the bits
// of the underlying representation. They are sent byte-reversed, with
// the exponent end coming out first, so integer floating point numbers
// (for example) transmit more compactly. This routine does the
// unswizzling.
func float64FromBits(u uint64) float64 {
v := bits.ReverseBytes64(u)
return math.Float64frombits(v)
}
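// A worked example of the unswizzling (informal): 17.0 has IEEE 754 bits
// 0x4031000000000000, which the encoder sends byte-reversed as the uint64
// 0x3140, needing only two data bytes on the wire. float64FromBits reverses
// the bytes again, recovering 0x4031000000000000 and hence 17.0.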
// float32FromBits decodes an unsigned integer, treats it as a 32-bit floating-point
// number, and returns it. It's a helper function for float32 and complex64.
// It returns a float64 because that's what reflection needs, but its return
// value is known to be accurately representable in a float32.
func float32FromBits(u uint64, ovfl error) float64 {
v := float64FromBits(u)
av := v
if av < 0 {
av = -av
}
// +Inf is OK in both 32- and 64-bit floats. Underflow is always OK.
if math.MaxFloat32 < av && av <= math.MaxFloat64 {
error_(ovfl)
}
return v
}
// decFloat32 decodes an unsigned integer, treats it as a 32-bit floating-point
// number, and stores it in value.
func decFloat32(i *decInstr, state *decoderState, value reflect.Value) {
value.SetFloat(float32FromBits(state.decodeUint(), i.ovfl))
}
// decFloat64 decodes an unsigned integer, treats it as a 64-bit floating-point
// number, and stores it in value.
func decFloat64(i *decInstr, state *decoderState, value reflect.Value) {
value.SetFloat(float64FromBits(state.decodeUint()))
}
// decComplex64 decodes a pair of unsigned integers, treats them as a
// pair of floating point numbers, and stores them as a complex64 in value.
// The real part comes first.
func decComplex64(i *decInstr, state *decoderState, value reflect.Value) {
real := float32FromBits(state.decodeUint(), i.ovfl)
imag := float32FromBits(state.decodeUint(), i.ovfl)
value.SetComplex(complex(real, imag))
}
// decComplex128 decodes a pair of unsigned integers, treats them as a
// pair of floating point numbers, and stores them as a complex128 in value.
// The real part comes first.
func decComplex128(i *decInstr, state *decoderState, value reflect.Value) {
real := float64FromBits(state.decodeUint())
imag := float64FromBits(state.decodeUint())
value.SetComplex(complex(real, imag))
}
// decUint8Slice decodes a byte slice and stores in value a slice header
// describing the data.
// uint8 slices are encoded as an unsigned count followed by the raw bytes.
func decUint8Slice(i *decInstr, state *decoderState, value reflect.Value) {
n, ok := state.getLength()
if !ok {
errorf("bad %s slice length: %d", value.Type(), n)
}
if value.Cap() < n {
safe := saferio.SliceCap[byte](uint64(n))
if safe < 0 {
errorf("%s slice too big: %d elements", value.Type(), n)
}
value.Set(reflect.MakeSlice(value.Type(), safe, safe))
ln := safe
i := 0
for i < n {
if i >= ln {
// We didn't allocate the entire slice,
// due to using saferio.SliceCap.
// Grow the slice for one more element.
// The slice is full, so this should
// bump up the capacity.
value.Grow(1)
}
// Copy into s up to the capacity or n,
// whichever is less.
ln = value.Cap()
if ln > n {
ln = n
}
value.SetLen(ln)
sub := value.Slice(i, ln)
if _, err := state.b.Read(sub.Bytes()); err != nil {
errorf("error decoding []byte at %d: %s", i, err)
}
i = ln
}
} else {
value.SetLen(n)
if _, err := state.b.Read(value.Bytes()); err != nil {
errorf("error decoding []byte: %s", err)
}
}
}
// decString decodes byte array and stores in value a string header
// describing the data.
// Strings are encoded as an unsigned count followed by the raw bytes.
func decString(i *decInstr, state *decoderState, value reflect.Value) {
n, ok := state.getLength()
if !ok {
errorf("bad %s slice length: %d", value.Type(), n)
}
// Read the data.
data := state.b.Bytes()
if len(data) < n {
errorf("invalid string length %d: exceeds input size %d", n, len(data))
}
s := string(data[:n])
state.b.Drop(n)
value.SetString(s)
}
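// A worked example of the string layout (illustrative only; this variable is
// an assumption for documentation and is not used by the package): "hi"
// arrives as its byte count followed by the raw bytes, so decString reads the
// count 2, copies the two bytes, and drops them from the buffer.
var exampleEncodedString = []byte{0x02, 'h', 'i'}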
// ignoreUint8Array skips over the data for a byte slice value with no destination.
func ignoreUint8Array(i *decInstr, state *decoderState, value reflect.Value) {
n, ok := state.getLength()
if !ok {
errorf("slice length too large")
}
bn := state.b.Len()
if bn < n {
errorf("invalid slice length %d: exceeds input size %d", n, bn)
}
state.b.Drop(n)
}
// Execution engine
// The decoder engine is an array of instructions indexed by the field number of the
// incoming data. It is executed with random access according to field number.
type decEngine struct {
instr []decInstr
numInstr int // the number of active instructions
}
// decodeSingle decodes a top-level value that is not a struct and stores it in value.
// Such values are preceded by a zero, making them have the memory layout of a
// struct field (although with an illegal field number).
func (dec *Decoder) decodeSingle(engine *decEngine, value reflect.Value) {
state := dec.newDecoderState(&dec.buf)
defer dec.freeDecoderState(state)
state.fieldnum = singletonField
if state.decodeUint() != 0 {
errorf("decode: corrupted data: non-zero delta for singleton")
}
instr := &engine.instr[singletonField]
instr.op(instr, state, value)
}
// decodeStruct decodes a top-level struct and stores it in value.
// Indir is for the value, not the type. At the time of the call it may
// differ from ut.indir, which was computed when the engine was built.
// This state cannot arise for decodeSingle, which is called directly
// from the user's value, not from the innards of an engine.
func (dec *Decoder) decodeStruct(engine *decEngine, value reflect.Value) {
state := dec.newDecoderState(&dec.buf)
defer dec.freeDecoderState(state)
state.fieldnum = -1
for state.b.Len() > 0 {
delta := int(state.decodeUint())
if delta < 0 {
errorf("decode: corrupted data: negative delta")
}
if delta == 0 { // struct terminator is zero delta fieldnum
break
}
if state.fieldnum >= len(engine.instr)-delta { // subtract to compare without overflow
error_(errRange)
}
fieldnum := state.fieldnum + delta
instr := &engine.instr[fieldnum]
var field reflect.Value
if instr.index != nil {
// Otherwise the field is unknown to us and instr.op is an ignore op.
field = value.FieldByIndex(instr.index)
if field.Kind() == reflect.Pointer {
field = decAlloc(field)
}
}
instr.op(instr, state, field)
state.fieldnum = fieldnum
}
}
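// A worked example of the field-delta scheme (illustrative only; this variable
// is an assumption for documentation and is not used by the package): for a
// struct { A, B, C int } in which only B is non-zero (B == 7), the struct body
// is the delta 2 (selecting field number 1, counting from -1), the signed
// encoding of 7 (0x0e), and a zero delta terminating the struct.
var exampleStructBody = []byte{0x02, 0x0e, 0x00}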
var noValue reflect.Value
// ignoreStruct discards the data for a struct with no destination.
func (dec *Decoder) ignoreStruct(engine *decEngine) {
state := dec.newDecoderState(&dec.buf)
defer dec.freeDecoderState(state)
state.fieldnum = -1
for state.b.Len() > 0 {
delta := int(state.decodeUint())
if delta < 0 {
errorf("ignore decode: corrupted data: negative delta")
}
if delta == 0 { // struct terminator is zero delta fieldnum
break
}
fieldnum := state.fieldnum + delta
if fieldnum >= len(engine.instr) {
error_(errRange)
}
instr := &engine.instr[fieldnum]
instr.op(instr, state, noValue)
state.fieldnum = fieldnum
}
}
// ignoreSingle discards the data for a top-level non-struct value with no
// destination. It's used when calling Decode with a nil value.
func (dec *Decoder) ignoreSingle(engine *decEngine) {
state := dec.newDecoderState(&dec.buf)
defer dec.freeDecoderState(state)
state.fieldnum = singletonField
delta := int(state.decodeUint())
if delta != 0 {
errorf("decode: corrupted data: non-zero delta for singleton")
}
instr := &engine.instr[singletonField]
instr.op(instr, state, noValue)
}
// decodeArrayHelper does the work for decoding arrays and slices.
func (dec *Decoder) decodeArrayHelper(state *decoderState, value reflect.Value, elemOp decOp, length int, ovfl error, helper decHelper) {
if helper != nil && helper(state, value, length, ovfl) {
return
}
instr := &decInstr{elemOp, 0, nil, ovfl}
isPtr := value.Type().Elem().Kind() == reflect.Pointer
ln := value.Len()
for i := 0; i < length; i++ {
if state.b.Len() == 0 {
errorf("decoding array or slice: length exceeds input size (%d elements)", length)
}
if i >= ln {
// This is a slice that we only partially allocated.
// Grow it up to length.
value.Grow(1)
cp := value.Cap()
if cp > length {
cp = length
}
value.SetLen(cp)
ln = cp
}
v := value.Index(i)
if isPtr {
v = decAlloc(v)
}
elemOp(instr, state, v)
}
}
// decodeArray decodes an array and stores it in value.
// The length is an unsigned integer preceding the elements. Even though the length is redundant
// (it's part of the type), it's a useful check and is included in the encoding.
func (dec *Decoder) decodeArray(state *decoderState, value reflect.Value, elemOp decOp, length int, ovfl error, helper decHelper) {
if n := state.decodeUint(); n != uint64(length) {
errorf("length mismatch in decodeArray")
}
dec.decodeArrayHelper(state, value, elemOp, length, ovfl, helper)
}
// decodeIntoValue is a helper for map decoding.
func decodeIntoValue(state *decoderState, op decOp, isPtr bool, value reflect.Value, instr *decInstr) reflect.Value {
v := value
if isPtr {
v = decAlloc(value)
}
op(instr, state, v)
return value
}
// decodeMap decodes a map and stores it in value.
// Maps are encoded as a length followed by key:value pairs.
// Because the internals of maps are not visible to us, we must
// use reflection rather than pointer magic.
func (dec *Decoder) decodeMap(mtyp reflect.Type, state *decoderState, value reflect.Value, keyOp, elemOp decOp, ovfl error) {
n := int(state.decodeUint())
if value.IsNil() {
value.Set(reflect.MakeMapWithSize(mtyp, n))
}
keyIsPtr := mtyp.Key().Kind() == reflect.Pointer
elemIsPtr := mtyp.Elem().Kind() == reflect.Pointer
keyInstr := &decInstr{keyOp, 0, nil, ovfl}
elemInstr := &decInstr{elemOp, 0, nil, ovfl}
keyP := reflect.New(mtyp.Key())
elemP := reflect.New(mtyp.Elem())
for i := 0; i < n; i++ {
key := decodeIntoValue(state, keyOp, keyIsPtr, keyP.Elem(), keyInstr)
elem := decodeIntoValue(state, elemOp, elemIsPtr, elemP.Elem(), elemInstr)
value.SetMapIndex(key, elem)
keyP.Elem().SetZero()
elemP.Elem().SetZero()
}
}
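// A worked example of the map layout (illustrative only; this variable is an
// assumption for documentation and is not used by the package):
// map[string]int{"a": 1}, ignoring the surrounding message framing, is the
// pair count followed by alternating keys and elements: count 1, the key "a"
// as a count-prefixed string, then the element 1, whose signed encoding is 2.
var exampleMapBody = []byte{0x01, 0x01, 'a', 0x02}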
// ignoreArrayHelper does the work for discarding arrays and slices.
func (dec *Decoder) ignoreArrayHelper(state *decoderState, elemOp decOp, length int) {
instr := &decInstr{elemOp, 0, nil, errors.New("no error")}
for i := 0; i < length; i++ {
if state.b.Len() == 0 {
errorf("decoding array or slice: length exceeds input size (%d elements)", length)
}
elemOp(instr, state, noValue)
}
}
// ignoreArray discards the data for an array value with no destination.
func (dec *Decoder) ignoreArray(state *decoderState, elemOp decOp, length int) {
if n := state.decodeUint(); n != uint64(length) {
errorf("length mismatch in ignoreArray")
}
dec.ignoreArrayHelper(state, elemOp, length)
}
// ignoreMap discards the data for a map value with no destination.
func (dec *Decoder) ignoreMap(state *decoderState, keyOp, elemOp decOp) {
n := int(state.decodeUint())
keyInstr := &decInstr{keyOp, 0, nil, errors.New("no error")}
elemInstr := &decInstr{elemOp, 0, nil, errors.New("no error")}
for i := 0; i < n; i++ {
keyOp(keyInstr, state, noValue)
elemOp(elemInstr, state, noValue)
}
}
// decodeSlice decodes a slice and stores it in value.
// Slices are encoded as an unsigned length followed by the elements.
func (dec *Decoder) decodeSlice(state *decoderState, value reflect.Value, elemOp decOp, ovfl error, helper decHelper) {
u := state.decodeUint()
typ := value.Type()
size := uint64(typ.Elem().Size())
nBytes := u * size
n := int(u)
// Take care with overflow in this calculation.
if n < 0 || uint64(n) != u || nBytes > tooBig || (size > 0 && nBytes/size != u) {
// We don't check n against buffer length here because if it's a slice
// of interfaces, there will be buffer reloads.
errorf("%s slice too big: %d elements of %d bytes", typ.Elem(), u, size)
}
if value.Cap() < n {
safe := saferio.SliceCapWithSize(size, uint64(n))
if safe < 0 {
errorf("%s slice too big: %d elements of %d bytes", typ.Elem(), u, size)
}
value.Set(reflect.MakeSlice(typ, safe, safe))
} else {
value.SetLen(n)
}
dec.decodeArrayHelper(state, value, elemOp, n, ovfl, helper)
}
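// A worked example of the slice layout (illustrative only; this variable is an
// assumption for documentation and is not used by the package): []int{1, 2, 3},
// ignoring the surrounding message framing, is the element count followed by
// each element's signed encoding (1, 2, 3 encode as 2, 4, 6).
var exampleSliceBody = []byte{0x03, 0x02, 0x04, 0x06}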
// ignoreSlice skips over the data for a slice value with no destination.
func (dec *Decoder) ignoreSlice(state *decoderState, elemOp decOp) {
dec.ignoreArrayHelper(state, elemOp, int(state.decodeUint()))
}
// decodeInterface decodes an interface value and stores it in value.
// Interfaces are encoded as the name of a concrete type followed by a value.
// If the name is empty, the value is nil and no value is sent.
func (dec *Decoder) decodeInterface(ityp reflect.Type, state *decoderState, value reflect.Value) {
// Read the name of the concrete type.
nr := state.decodeUint()
if nr > 1<<31 { // zero is permissible for anonymous types
errorf("invalid type name length %d", nr)
}
if nr > uint64(state.b.Len()) {
errorf("invalid type name length %d: exceeds input size", nr)
}
n := int(nr)
name := state.b.Bytes()[:n]
state.b.Drop(n)
// Allocate the destination interface value.
if len(name) == 0 {
// Copy the nil interface value to the target.
value.SetZero()
return
}
if len(name) > 1024 {
errorf("name too long (%d bytes): %.20q...", len(name), name)
}
// The concrete type must be registered.
typi, ok := nameToConcreteType.Load(string(name))
if !ok {
errorf("name not registered for interface: %q", name)
}
typ := typi.(reflect.Type)
// Read the type id of the concrete value.
concreteId := dec.decodeTypeSequence(true)
if concreteId < 0 {
error_(dec.err)
}
// Byte count of value is next; we don't care what it is (it's there
// in case we want to ignore the value by skipping it completely).
state.decodeUint()
// Read the concrete value.
v := allocValue(typ)
dec.decodeValue(concreteId, v)
if dec.err != nil {
error_(dec.err)
}
// Assign the concrete value to the interface.
// Tread carefully; it might not satisfy the interface.
if !typ.AssignableTo(ityp) {
errorf("%s is not assignable to type %s", typ, ityp)
}
// Copy the interface value to the target.
value.Set(v)
}
// ignoreInterface discards the data for an interface value with no destination.
func (dec *Decoder) ignoreInterface(state *decoderState) {
// Read the name of the concrete type.
n, ok := state.getLength()
if !ok {
errorf("bad interface encoding: name too large for buffer")
}
bn := state.b.Len()
if bn < n {
errorf("invalid interface value length %d: exceeds input size %d", n, bn)
}
state.b.Drop(n)
id := dec.decodeTypeSequence(true)
if id < 0 {
error_(dec.err)
}
// At this point, the decoder buffer contains a delimited value. Just toss it.
n, ok = state.getLength()
if !ok {
errorf("bad interface encoding: data length too large for buffer")
}
state.b.Drop(n)
}
// decodeGobDecoder decodes something implementing the GobDecoder interface.
// The data is encoded as a byte slice.
func (dec *Decoder) decodeGobDecoder(ut *userTypeInfo, state *decoderState, value reflect.Value) {
// Read the bytes for the value.
n, ok := state.getLength()
if !ok {
errorf("GobDecoder: length too large for buffer")
}
b := state.b.Bytes()
if len(b) < n {
errorf("GobDecoder: invalid data length %d: exceeds input size %d", n, len(b))
}
b = b[:n]
state.b.Drop(n)
var err error
// We know it's one of these.
switch ut.externalDec {
case xGob:
err = value.Interface().(GobDecoder).GobDecode(b)
case xBinary:
err = value.Interface().(encoding.BinaryUnmarshaler).UnmarshalBinary(b)
case xText:
err = value.Interface().(encoding.TextUnmarshaler).UnmarshalText(b)
}
if err != nil {
error_(err)
}
}
// ignoreGobDecoder discards the data for a GobDecoder value with no destination.
func (dec *Decoder) ignoreGobDecoder(state *decoderState) {
// Read the bytes for the value.
n, ok := state.getLength()
if !ok {
errorf("GobDecoder: length too large for buffer")
}
bn := state.b.Len()
if bn < n {
errorf("GobDecoder: invalid data length %d: exceeds input size %d", n, bn)
}
state.b.Drop(n)
}
// Indexed by Go types.
var decOpTable = [...]decOp{
reflect.Bool: decBool,
reflect.Int8: decInt8,
reflect.Int16: decInt16,
reflect.Int32: decInt32,
reflect.Int64: decInt64,
reflect.Uint8: decUint8,
reflect.Uint16: decUint16,
reflect.Uint32: decUint32,
reflect.Uint64: decUint64,
reflect.Float32: decFloat32,
reflect.Float64: decFloat64,
reflect.Complex64: decComplex64,
reflect.Complex128: decComplex128,
reflect.String: decString,
}
// Indexed by gob types. tComplex will be added during type.init().
var decIgnoreOpMap = map[typeId]decOp{
tBool: ignoreUint,
tInt: ignoreUint,
tUint: ignoreUint,
tFloat: ignoreUint,
tBytes: ignoreUint8Array,
tString: ignoreUint8Array,
tComplex: ignoreTwoUints,
}
// decOpFor returns a pointer to the decoding op for the base type under rt.
func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string, inProgress map[reflect.Type]*decOp) *decOp {
ut := userType(rt)
// If the type implements GobEncoder, we handle it without further processing.
if ut.externalDec != 0 {
return dec.gobDecodeOpFor(ut)
}
// If this type is already in progress, it's a recursive type (e.g. map[string]*T).
// Return the pointer to the op we're already building.
if opPtr := inProgress[rt]; opPtr != nil {
return opPtr
}
typ := ut.base
var op decOp
k := typ.Kind()
if int(k) < len(decOpTable) {
op = decOpTable[k]
}
if op == nil {
inProgress[rt] = &op
// Special cases
switch t := typ; t.Kind() {
case reflect.Array:
name = "element of " + name
elemId := dec.wireType[wireId].ArrayT.Elem
elemOp := dec.decOpFor(elemId, t.Elem(), name, inProgress)
ovfl := overflow(name)
helper := decArrayHelper[t.Elem().Kind()]
op = func(i *decInstr, state *decoderState, value reflect.Value) {
state.dec.decodeArray(state, value, *elemOp, t.Len(), ovfl, helper)
}
case reflect.Map:
keyId := dec.wireType[wireId].MapT.Key
elemId := dec.wireType[wireId].MapT.Elem
keyOp := dec.decOpFor(keyId, t.Key(), "key of "+name, inProgress)
elemOp := dec.decOpFor(elemId, t.Elem(), "element of "+name, inProgress)
ovfl := overflow(name)
op = func(i *decInstr, state *decoderState, value reflect.Value) {
state.dec.decodeMap(t, state, value, *keyOp, *elemOp, ovfl)
}
case reflect.Slice:
name = "element of " + name
if t.Elem().Kind() == reflect.Uint8 {
op = decUint8Slice
break
}
var elemId typeId
if tt := builtinIdToType(wireId); tt != nil {
elemId = tt.(*sliceType).Elem
} else {
elemId = dec.wireType[wireId].SliceT.Elem
}
elemOp := dec.decOpFor(elemId, t.Elem(), name, inProgress)
ovfl := overflow(name)
helper := decSliceHelper[t.Elem().Kind()]
op = func(i *decInstr, state *decoderState, value reflect.Value) {
state.dec.decodeSlice(state, value, *elemOp, ovfl, helper)
}
case reflect.Struct:
// Generate a closure that calls out to the engine for the nested type.
ut := userType(typ)
enginePtr, err := dec.getDecEnginePtr(wireId, ut)
if err != nil {
error_(err)
}
op = func(i *decInstr, state *decoderState, value reflect.Value) {
// indirect through enginePtr to delay evaluation for recursive structs.
dec.decodeStruct(*enginePtr, value)
}
case reflect.Interface:
op = func(i *decInstr, state *decoderState, value reflect.Value) {
state.dec.decodeInterface(t, state, value)
}
}
}
if op == nil {
errorf("decode can't handle type %s", rt)
}
return &op
}
var maxIgnoreNestingDepth = 10000
// decIgnoreOpFor returns the decoding op for a field that has no destination.
func (dec *Decoder) decIgnoreOpFor(wireId typeId, inProgress map[typeId]*decOp) *decOp {
// Track how deep we've recursed trying to skip nested ignored fields.
dec.ignoreDepth++
defer func() { dec.ignoreDepth-- }()
if dec.ignoreDepth > maxIgnoreNestingDepth {
error_(errors.New("invalid nesting depth"))
}
// If this type is already in progress, it's a recursive type (e.g. map[string]*T).
// Return the pointer to the op we're already building.
if opPtr := inProgress[wireId]; opPtr != nil {
return opPtr
}
op, ok := decIgnoreOpMap[wireId]
if !ok {
inProgress[wireId] = &op
if wireId == tInterface {
// Special case because it's a method: the ignored item might
// define types and we need to record their state in the decoder.
op = func(i *decInstr, state *decoderState, value reflect.Value) {
state.dec.ignoreInterface(state)
}
return &op
}
// Special cases
wire := dec.wireType[wireId]
switch {
case wire == nil:
errorf("bad data: undefined type %s", wireId.string())
case wire.ArrayT != nil:
elemId := wire.ArrayT.Elem
elemOp := dec.decIgnoreOpFor(elemId, inProgress)
op = func(i *decInstr, state *decoderState, value reflect.Value) {
state.dec.ignoreArray(state, *elemOp, wire.ArrayT.Len)
}
case wire.MapT != nil:
keyId := dec.wireType[wireId].MapT.Key
elemId := dec.wireType[wireId].MapT.Elem
keyOp := dec.decIgnoreOpFor(keyId, inProgress)
elemOp := dec.decIgnoreOpFor(elemId, inProgress)
op = func(i *decInstr, state *decoderState, value reflect.Value) {
state.dec.ignoreMap(state, *keyOp, *elemOp)
}
case wire.SliceT != nil:
elemId := wire.SliceT.Elem
elemOp := dec.decIgnoreOpFor(elemId, inProgress)
op = func(i *decInstr, state *decoderState, value reflect.Value) {
state.dec.ignoreSlice(state, *elemOp)
}
case wire.StructT != nil:
// Generate a closure that calls out to the engine for the nested type.
enginePtr, err := dec.getIgnoreEnginePtr(wireId)
if err != nil {
error_(err)
}
op = func(i *decInstr, state *decoderState, value reflect.Value) {
// indirect through enginePtr to delay evaluation for recursive structs
state.dec.ignoreStruct(*enginePtr)
}
case wire.GobEncoderT != nil, wire.BinaryMarshalerT != nil, wire.TextMarshalerT != nil:
op = func(i *decInstr, state *decoderState, value reflect.Value) {
state.dec.ignoreGobDecoder(state)
}
}
}
if op == nil {
errorf("bad data: ignore can't handle type %s", wireId.string())
}
return &op
}
// gobDecodeOpFor returns the op for a type that is known to implement
// GobDecoder.
func (dec *Decoder) gobDecodeOpFor(ut *userTypeInfo) *decOp {
rcvrType := ut.user
if ut.decIndir == -1 {
rcvrType = reflect.PointerTo(rcvrType)
} else if ut.decIndir > 0 {
for i := int8(0); i < ut.decIndir; i++ {
rcvrType = rcvrType.Elem()
}
}
var op decOp
op = func(i *decInstr, state *decoderState, value reflect.Value) {
// We now have the base type. We need its address if the receiver is a pointer.
if value.Kind() != reflect.Pointer && rcvrType.Kind() == reflect.Pointer {
value = value.Addr()
}
state.dec.decodeGobDecoder(ut, state, value)
}
return &op
}
// compatibleType asks: Are these two gob Types compatible?
// Answers the question for basic types, arrays, maps and slices, plus
// GobEncoder/Decoder pairs.
// Structs are considered ok; fields will be checked later.
func (dec *Decoder) compatibleType(fr reflect.Type, fw typeId, inProgress map[reflect.Type]typeId) bool {
if rhs, ok := inProgress[fr]; ok {
return rhs == fw
}
inProgress[fr] = fw
ut := userType(fr)
wire, ok := dec.wireType[fw]
// If wire was encoded with an encoding method, fr must have that method.
// And if not, it must not.
// At most one of the booleans in ut is set.
// We could possibly relax this constraint in the future in order to
// choose the decoding method using the data in the wireType.
// The parentheses look odd but are correct.
if (ut.externalDec == xGob) != (ok && wire.GobEncoderT != nil) ||
(ut.externalDec == xBinary) != (ok && wire.BinaryMarshalerT != nil) ||
(ut.externalDec == xText) != (ok && wire.TextMarshalerT != nil) {
return false
}
if ut.externalDec != 0 { // This test trumps all others.
return true
}
switch t := ut.base; t.Kind() {
default:
// chan, etc: cannot handle.
return false
case reflect.Bool:
return fw == tBool
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return fw == tInt
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return fw == tUint
case reflect.Float32, reflect.Float64:
return fw == tFloat
case reflect.Complex64, reflect.Complex128:
return fw == tComplex
case reflect.String:
return fw == tString
case reflect.Interface:
return fw == tInterface
case reflect.Array:
if !ok || wire.ArrayT == nil {
return false
}
array := wire.ArrayT
return t.Len() == array.Len && dec.compatibleType(t.Elem(), array.Elem, inProgress)
case reflect.Map:
if !ok || wire.MapT == nil {
return false
}
MapType := wire.MapT
return dec.compatibleType(t.Key(), MapType.Key, inProgress) && dec.compatibleType(t.Elem(), MapType.Elem, inProgress)
case reflect.Slice:
// Is it an array of bytes?
if t.Elem().Kind() == reflect.Uint8 {
return fw == tBytes
}
// Extract and compare element types.
var sw *sliceType
if tt := builtinIdToType(fw); tt != nil {
sw, _ = tt.(*sliceType)
} else if wire != nil {
sw = wire.SliceT
}
elem := userType(t.Elem()).base
return sw != nil && dec.compatibleType(elem, sw.Elem, inProgress)
case reflect.Struct:
return true
}
}
// typeString returns a human-readable description of the type identified by remoteId.
func (dec *Decoder) typeString(remoteId typeId) string {
typeLock.Lock()
defer typeLock.Unlock()
if t := idToType(remoteId); t != nil {
// globally known type.
return t.string()
}
return dec.wireType[remoteId].string()
}
// compileSingle compiles the decoder engine for a non-struct top-level value, including
// GobDecoders.
func (dec *Decoder) compileSingle(remoteId typeId, ut *userTypeInfo) (engine *decEngine, err error) {
rt := ut.user
engine = new(decEngine)
engine.instr = make([]decInstr, 1) // one item
name := rt.String() // best we can do
if !dec.compatibleType(rt, remoteId, make(map[reflect.Type]typeId)) {
remoteType := dec.typeString(remoteId)
// Common confusing case: local interface type, remote concrete type.
if ut.base.Kind() == reflect.Interface && remoteId != tInterface {
return nil, errors.New("gob: local interface type " + name + " can only be decoded from remote interface type; received concrete type " + remoteType)
}
return nil, errors.New("gob: decoding into local type " + name + ", received remote type " + remoteType)
}
op := dec.decOpFor(remoteId, rt, name, make(map[reflect.Type]*decOp))
ovfl := errors.New(`value for "` + name + `" out of range`)
engine.instr[singletonField] = decInstr{*op, singletonField, nil, ovfl}
engine.numInstr = 1
return
}
// compileIgnoreSingle compiles the decoder engine for a non-struct top-level value that will be discarded.
func (dec *Decoder) compileIgnoreSingle(remoteId typeId) *decEngine {
engine := new(decEngine)
engine.instr = make([]decInstr, 1) // one item
op := dec.decIgnoreOpFor(remoteId, make(map[typeId]*decOp))
ovfl := overflow(dec.typeString(remoteId))
engine.instr[0] = decInstr{*op, 0, nil, ovfl}
engine.numInstr = 1
return engine
}
// compileDec compiles the decoder engine for a value. If the value is not a struct,
// it calls out to compileSingle.
func (dec *Decoder) compileDec(remoteId typeId, ut *userTypeInfo) (engine *decEngine, err error) {
defer catchError(&err)
rt := ut.base
srt := rt
if srt.Kind() != reflect.Struct || ut.externalDec != 0 {
return dec.compileSingle(remoteId, ut)
}
var wireStruct *structType
// Builtin types can come from global pool; the rest must be defined by the decoder.
// Also we know we're decoding a struct now, so the client must have sent one.
if t := builtinIdToType(remoteId); t != nil {
wireStruct, _ = t.(*structType)
} else {
wire := dec.wireType[remoteId]
if wire == nil {
error_(errBadType)
}
wireStruct = wire.StructT
}
if wireStruct == nil {
errorf("type mismatch in decoder: want struct type %s; got non-struct", rt)
}
engine = new(decEngine)
engine.instr = make([]decInstr, len(wireStruct.Field))
seen := make(map[reflect.Type]*decOp)
// Loop over the fields of the wire type.
for fieldnum := 0; fieldnum < len(wireStruct.Field); fieldnum++ {
wireField := wireStruct.Field[fieldnum]
if wireField.Name == "" {
errorf("empty name for remote field of type %s", wireStruct.Name)
}
ovfl := overflow(wireField.Name)
// Find the field of the local type with the same name.
localField, present := srt.FieldByName(wireField.Name)
// TODO(r): anonymous names
if !present || !isExported(wireField.Name) {
op := dec.decIgnoreOpFor(wireField.Id, make(map[typeId]*decOp))
engine.instr[fieldnum] = decInstr{*op, fieldnum, nil, ovfl}
continue
}
if !dec.compatibleType(localField.Type, wireField.Id, make(map[reflect.Type]typeId)) {
errorf("wrong type (%s) for received field %s.%s", localField.Type, wireStruct.Name, wireField.Name)
}
op := dec.decOpFor(wireField.Id, localField.Type, localField.Name, seen)
engine.instr[fieldnum] = decInstr{*op, fieldnum, localField.Index, ovfl}
engine.numInstr++
}
return
}
// getDecEnginePtr returns the engine for the specified type.
func (dec *Decoder) getDecEnginePtr(remoteId typeId, ut *userTypeInfo) (enginePtr **decEngine, err error) {
rt := ut.user
decoderMap, ok := dec.decoderCache[rt]
if !ok {
decoderMap = make(map[typeId]**decEngine)
dec.decoderCache[rt] = decoderMap
}
if enginePtr, ok = decoderMap[remoteId]; !ok {
// To handle recursive types, mark this engine as underway before compiling.
enginePtr = new(*decEngine)
decoderMap[remoteId] = enginePtr
*enginePtr, err = dec.compileDec(remoteId, ut)
if err != nil {
delete(decoderMap, remoteId)
}
}
return
}
// emptyStruct is the type we compile into when ignoring a struct value.
type emptyStruct struct{}
var emptyStructType = reflect.TypeFor[emptyStruct]()
// getIgnoreEnginePtr returns the engine for the specified type when the value is to be discarded.
func (dec *Decoder) getIgnoreEnginePtr(wireId typeId) (enginePtr **decEngine, err error) {
var ok bool
if enginePtr, ok = dec.ignorerCache[wireId]; !ok {
// To handle recursive types, mark this engine as underway before compiling.
enginePtr = new(*decEngine)
dec.ignorerCache[wireId] = enginePtr
wire := dec.wireType[wireId]
if wire != nil && wire.StructT != nil {
*enginePtr, err = dec.compileDec(wireId, userType(emptyStructType))
} else {
*enginePtr = dec.compileIgnoreSingle(wireId)
}
if err != nil {
delete(dec.ignorerCache, wireId)
}
}
return
}
// decodeValue decodes the data stream representing a value and stores it in value.
func (dec *Decoder) decodeValue(wireId typeId, value reflect.Value) {
defer catchError(&dec.err)
// If the value is nil, it means we should just ignore this item.
if !value.IsValid() {
dec.decodeIgnoredValue(wireId)
return
}
// Dereference down to the underlying type.
ut := userType(value.Type())
base := ut.base
var enginePtr **decEngine
enginePtr, dec.err = dec.getDecEnginePtr(wireId, ut)
if dec.err != nil {
return
}
value = decAlloc(value)
engine := *enginePtr
if st := base; st.Kind() == reflect.Struct && ut.externalDec == 0 {
wt := dec.wireType[wireId]
if engine.numInstr == 0 && st.NumField() > 0 &&
wt != nil && len(wt.StructT.Field) > 0 {
name := base.Name()
errorf("type mismatch: no fields matched compiling decoder for %s", name)
}
dec.decodeStruct(engine, value)
} else {
dec.decodeSingle(engine, value)
}
}
// decodeIgnoredValue decodes the data stream representing a value of the specified type and discards it.
func (dec *Decoder) decodeIgnoredValue(wireId typeId) {
var enginePtr **decEngine
enginePtr, dec.err = dec.getIgnoreEnginePtr(wireId)
if dec.err != nil {
return
}
wire := dec.wireType[wireId]
if wire != nil && wire.StructT != nil {
dec.ignoreStruct(*enginePtr)
} else {
dec.ignoreSingle(*enginePtr)
}
}
const (
intBits = 32 << (^uint(0) >> 63)
uintptrBits = 32 << (^uintptr(0) >> 63)
)
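// For example, on a 64-bit platform ^uint(0)>>63 is 1, so intBits is
// 32 << 1 == 64; on a 32-bit platform the shift yields 0 and intBits stays 32.
// uintptrBits computes the width of uintptr the same way.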
func init() {
var iop, uop decOp
switch intBits {
case 32:
iop = decInt32
uop = decUint32
case 64:
iop = decInt64
uop = decUint64
default:
panic("gob: unknown size of int/uint")
}
decOpTable[reflect.Int] = iop
decOpTable[reflect.Uint] = uop
// Finally uintptr
switch uintptrBits {
case 32:
uop = decUint32
case 64:
uop = decUint64
default:
panic("gob: unknown size of uintptr")
}
decOpTable[reflect.Uintptr] = uop
}
// Gob depends on being able to take the address
// of zeroed Values it creates, so use this wrapper instead
// of the standard reflect.Zero.
// Each call allocates once.
func allocValue(t reflect.Type) reflect.Value {
return reflect.New(t).Elem()
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package gob
import (
"bufio"
"errors"
"internal/saferio"
"io"
"reflect"
"sync"
)
// tooBig provides a sanity check for sizes; used in several places. The upper
// limit is 1GB on 32-bit systems and 8GB on 64-bit, allowing room to grow a
// little without overflow.
const tooBig = (1 << 30) << (^uint(0) >> 62)
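// As a quick check of that expression: on a 64-bit platform ^uint(0)>>62 is 3,
// so tooBig is 1<<30 shifted left by 3, i.e. 1<<33 (8GB); on a 32-bit platform
// the shift is 0 and tooBig is 1<<30 (1GB).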
// A Decoder manages the receipt of type and data information read from the
// remote side of a connection. It is safe for concurrent use by multiple
// goroutines.
//
// The Decoder does only basic sanity checking on decoded input sizes,
// and its limits are not configurable. Take caution when decoding gob data
// from untrusted sources.
type Decoder struct {
mutex sync.Mutex // each item must be received atomically
r io.Reader // source of the data
buf decBuffer // buffer for more efficient i/o from r
wireType map[typeId]*wireType // map from remote ID to local description
decoderCache map[reflect.Type]map[typeId]**decEngine // cache of compiled engines
ignorerCache map[typeId]**decEngine // ditto for ignored objects
freeList *decoderState // list of free decoderStates; avoids reallocation
countBuf []byte // used for decoding integers while parsing messages
err error
// ignoreDepth tracks the depth of recursively parsed ignored fields
ignoreDepth int
}
// NewDecoder returns a new decoder that reads from the [io.Reader].
// If r does not also implement [io.ByteReader], it will be wrapped in a
// [bufio.Reader].
func NewDecoder(r io.Reader) *Decoder {
dec := new(Decoder)
// We use the ability to read bytes as a plausible surrogate for buffering.
if _, ok := r.(io.ByteReader); !ok {
r = bufio.NewReader(r)
}
dec.r = r
dec.wireType = make(map[typeId]*wireType)
dec.decoderCache = make(map[reflect.Type]map[typeId]**decEngine)
dec.ignorerCache = make(map[typeId]**decEngine)
dec.countBuf = make([]byte, 9) // counts may be uint64s (unlikely!), require 9 bytes
return dec
}
// recvType loads the definition of a type.
func (dec *Decoder) recvType(id typeId) {
// Have we already seen this type? That's an error
if id < firstUserId || dec.wireType[id] != nil {
dec.err = errors.New("gob: duplicate type received")
return
}
// Type:
wire := new(wireType)
dec.decodeValue(tWireType, reflect.ValueOf(wire))
if dec.err != nil {
return
}
// Remember we've seen this type.
dec.wireType[id] = wire
}
var errBadCount = errors.New("invalid message length")
// recvMessage reads the next count-delimited item from the input. It is the converse
// of Encoder.writeMessage. It returns false on EOF or other error reading the message.
func (dec *Decoder) recvMessage() bool {
// Read a count.
nbytes, _, err := decodeUintReader(dec.r, dec.countBuf)
if err != nil {
dec.err = err
return false
}
if nbytes >= tooBig {
dec.err = errBadCount
return false
}
dec.readMessage(int(nbytes))
return dec.err == nil
}
// readMessage reads the next nbytes bytes from the input.
func (dec *Decoder) readMessage(nbytes int) {
if dec.buf.Len() != 0 {
// The buffer should always be empty now.
panic("non-empty decoder buffer")
}
// Read the data
var buf []byte
buf, dec.err = saferio.ReadData(dec.r, uint64(nbytes))
dec.buf.SetBytes(buf)
if dec.err == io.EOF {
dec.err = io.ErrUnexpectedEOF
}
}
// toInt turns an encoded uint64 into an int, according to the marshaling rules.
func toInt(x uint64) int64 {
i := int64(x >> 1)
if x&1 != 0 {
i = ^i
}
return i
}
func (dec *Decoder) nextInt() int64 {
n, _, err := decodeUintReader(&dec.buf, dec.countBuf)
if err != nil {
dec.err = err
}
return toInt(n)
}
func (dec *Decoder) nextUint() uint64 {
n, _, err := decodeUintReader(&dec.buf, dec.countBuf)
if err != nil {
dec.err = err
}
return n
}
// decodeTypeSequence parses:
// TypeSequence
//
// (TypeDefinition DelimitedTypeDefinition*)?
//
// and returns the type id of the next value. It returns -1 at
// EOF. Upon return, the remainder of dec.buf is the value to be
// decoded. If this is an interface value, it can be ignored by
// resetting that buffer.
func (dec *Decoder) decodeTypeSequence(isInterface bool) typeId {
firstMessage := true
for dec.err == nil {
if dec.buf.Len() == 0 {
if !dec.recvMessage() {
// We can only return io.EOF if the input was empty.
// If we read one or more type spec messages,
// require a data item message to follow.
// If we hit an EOF before that, then give ErrUnexpectedEOF.
if !firstMessage && dec.err == io.EOF {
dec.err = io.ErrUnexpectedEOF
}
break
}
}
// Receive a type id.
id := typeId(dec.nextInt())
if id >= 0 {
// Value follows.
return id
}
// Type definition for (-id) follows.
dec.recvType(-id)
if dec.err != nil {
break
}
// When decoding an interface, after a type there may be a
// DelimitedValue still in the buffer. Skip its count.
// (Alternatively, the buffer is empty and the byte count
// will be absorbed by recvMessage.)
if dec.buf.Len() > 0 {
if !isInterface {
dec.err = errors.New("extra data in buffer")
break
}
dec.nextUint()
}
firstMessage = false
}
return -1
}
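// An informal sketch of what this parses: each length-delimited message begins
// with a signed type id. A negative id -n introduces the wire description of
// type n; a non-negative id says a value of that type follows. A stream
// carrying one value of a newly defined type therefore looks like
//
//	[len] -n <definition of type n> [len] n <value>
//
// where the bracketed counts are the message lengths read by recvMessage.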
// Decode reads the next value from the input stream and stores
// it in the data represented by the empty interface value.
// If e is nil, the value will be discarded. Otherwise,
// the value underlying e must be a pointer to the
// correct type for the next data item received.
// If the input is at EOF, Decode returns [io.EOF] and
// does not modify e.
func (dec *Decoder) Decode(e any) error {
if e == nil {
return dec.DecodeValue(reflect.Value{})
}
value := reflect.ValueOf(e)
// If e represents a value as opposed to a pointer, the answer won't
// get back to the caller. Make sure it's a pointer.
if value.Kind() != reflect.Pointer {
dec.err = errors.New("gob: attempt to decode into a non-pointer")
return dec.err
}
return dec.DecodeValue(value)
}
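// A minimal usage sketch (illustrative only; the names network, point and p
// are hypothetical): an Encoder and a Decoder sharing a bytes.Buffer give a
// round trip through Decode.
//
//	var network bytes.Buffer
//	enc := NewEncoder(&network)
//	dec := NewDecoder(&network)
//	type point struct{ X, Y int }
//	if err := enc.Encode(point{1, 2}); err != nil {
//		// handle the encode error
//	}
//	var p point
//	if err := dec.Decode(&p); err != nil {
//		// handle the decode error
//	}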
// DecodeValue reads the next value from the input stream.
// If v is the zero reflect.Value (v.Kind() == Invalid), DecodeValue discards the value.
// Otherwise, it stores the value into v. In that case, v must represent
// a non-nil pointer to data or be an assignable reflect.Value (v.CanSet()).
// If the input is at EOF, DecodeValue returns [io.EOF] and
// does not modify v.
func (dec *Decoder) DecodeValue(v reflect.Value) error {
if v.IsValid() {
if v.Kind() == reflect.Pointer && !v.IsNil() {
// That's okay, we'll store through the pointer.
} else if !v.CanSet() {
return errors.New("gob: DecodeValue of unassignable value")
}
}
// Make sure we're single-threaded through here.
dec.mutex.Lock()
defer dec.mutex.Unlock()
dec.buf.Reset() // In case data lingers from previous invocation.
dec.err = nil
id := dec.decodeTypeSequence(false)
if dec.err == nil {
dec.decodeValue(id, v)
}
return dec.err
}
// If debug.go is compiled into the program, debugFunc prints a human-readable
// representation of the gob data read from r by calling that file's Debug function.
// Otherwise it is nil.
var debugFunc func(io.Reader)
// Code generated by go run encgen.go -output enc_helpers.go; DO NOT EDIT.
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package gob
import (
"reflect"
)
var encArrayHelper = map[reflect.Kind]encHelper{
reflect.Bool: encBoolArray,
reflect.Complex64: encComplex64Array,
reflect.Complex128: encComplex128Array,
reflect.Float32: encFloat32Array,
reflect.Float64: encFloat64Array,
reflect.Int: encIntArray,
reflect.Int16: encInt16Array,
reflect.Int32: encInt32Array,
reflect.Int64: encInt64Array,
reflect.Int8: encInt8Array,
reflect.String: encStringArray,
reflect.Uint: encUintArray,
reflect.Uint16: encUint16Array,
reflect.Uint32: encUint32Array,
reflect.Uint64: encUint64Array,
reflect.Uintptr: encUintptrArray,
}
var encSliceHelper = map[reflect.Kind]encHelper{
reflect.Bool: encBoolSlice,
reflect.Complex64: encComplex64Slice,
reflect.Complex128: encComplex128Slice,
reflect.Float32: encFloat32Slice,
reflect.Float64: encFloat64Slice,
reflect.Int: encIntSlice,
reflect.Int16: encInt16Slice,
reflect.Int32: encInt32Slice,
reflect.Int64: encInt64Slice,
reflect.Int8: encInt8Slice,
reflect.String: encStringSlice,
reflect.Uint: encUintSlice,
reflect.Uint16: encUint16Slice,
reflect.Uint32: encUint32Slice,
reflect.Uint64: encUint64Slice,
reflect.Uintptr: encUintptrSlice,
}
func encBoolArray(state *encoderState, v reflect.Value) bool {
// Can only slice if it is addressable.
if !v.CanAddr() {
return false
}
return encBoolSlice(state, v.Slice(0, v.Len()))
}
func encBoolSlice(state *encoderState, v reflect.Value) bool {
slice, ok := v.Interface().([]bool)
if !ok {
// It is kind bool but not type bool. TODO: We can handle this unsafely.
return false
}
for _, x := range slice {
if x != false || state.sendZero {
if x {
state.encodeUint(1)
} else {
state.encodeUint(0)
}
}
}
return true
}
func encComplex64Array(state *encoderState, v reflect.Value) bool {
// Can only slice if it is addressable.
if !v.CanAddr() {
return false
}
return encComplex64Slice(state, v.Slice(0, v.Len()))
}
func encComplex64Slice(state *encoderState, v reflect.Value) bool {
slice, ok := v.Interface().([]complex64)
if !ok {
// It is kind complex64 but not type complex64. TODO: We can handle this unsafely.
return false
}
for _, x := range slice {
if x != 0+0i || state.sendZero {
rpart := floatBits(float64(real(x)))
ipart := floatBits(float64(imag(x)))
state.encodeUint(rpart)
state.encodeUint(ipart)
}
}
return true
}
func encComplex128Array(state *encoderState, v reflect.Value) bool {
// Can only slice if it is addressable.
if !v.CanAddr() {
return false
}
return encComplex128Slice(state, v.Slice(0, v.Len()))
}
func encComplex128Slice(state *encoderState, v reflect.Value) bool {
slice, ok := v.Interface().([]complex128)
if !ok {
// It is kind complex128 but not type complex128. TODO: We can handle this unsafely.
return false
}
for _, x := range slice {
if x != 0+0i || state.sendZero {
rpart := floatBits(real(x))
ipart := floatBits(imag(x))
state.encodeUint(rpart)
state.encodeUint(ipart)
}
}
return true
}
func encFloat32Array(state *encoderState, v reflect.Value) bool {
// Can only slice if it is addressable.
if !v.CanAddr() {
return false
}
return encFloat32Slice(state, v.Slice(0, v.Len()))
}
func encFloat32Slice(state *encoderState, v reflect.Value) bool {
slice, ok := v.Interface().([]float32)
if !ok {
// It is kind float32 but not type float32. TODO: We can handle this unsafely.
return false
}
for _, x := range slice {
if x != 0 || state.sendZero {
bits := floatBits(float64(x))
state.encodeUint(bits)
}
}
return true
}
func encFloat64Array(state *encoderState, v reflect.Value) bool {
// Can only slice if it is addressable.
if !v.CanAddr() {
return false
}
return encFloat64Slice(state, v.Slice(0, v.Len()))
}
func encFloat64Slice(state *encoderState, v reflect.Value) bool {
slice, ok := v.Interface().([]float64)
if !ok {
// It is kind float64 but not type float64. TODO: We can handle this unsafely.
return false
}
for _, x := range slice {
if x != 0 || state.sendZero {
bits := floatBits(x)
state.encodeUint(bits)
}
}
return true
}
func encIntArray(state *encoderState, v reflect.Value) bool {
// Can only slice if it is addressable.
if !v.CanAddr() {
return false
}
return encIntSlice(state, v.Slice(0, v.Len()))
}
func encIntSlice(state *encoderState, v reflect.Value) bool {
slice, ok := v.Interface().([]int)
if !ok {
// It is kind int but not type int. TODO: We can handle this unsafely.
return false
}
for _, x := range slice {
if x != 0 || state.sendZero {
state.encodeInt(int64(x))
}
}
return true
}
func encInt16Array(state *encoderState, v reflect.Value) bool {
// Can only slice if it is addressable.
if !v.CanAddr() {
return false
}
return encInt16Slice(state, v.Slice(0, v.Len()))
}
func encInt16Slice(state *encoderState, v reflect.Value) bool {
slice, ok := v.Interface().([]int16)
if !ok {
// It is kind int16 but not type int16. TODO: We can handle this unsafely.
return false
}
for _, x := range slice {
if x != 0 || state.sendZero {
state.encodeInt(int64(x))
}
}
return true
}
func encInt32Array(state *encoderState, v reflect.Value) bool {
// Can only slice if it is addressable.
if !v.CanAddr() {
return false
}
return encInt32Slice(state, v.Slice(0, v.Len()))
}
func encInt32Slice(state *encoderState, v reflect.Value) bool {
slice, ok := v.Interface().([]int32)
if !ok {
// It is kind int32 but not type int32. TODO: We can handle this unsafely.
return false
}
for _, x := range slice {
if x != 0 || state.sendZero {
state.encodeInt(int64(x))
}
}
return true
}
func encInt64Array(state *encoderState, v reflect.Value) bool {
// Can only slice if it is addressable.
if !v.CanAddr() {
return false
}
return encInt64Slice(state, v.Slice(0, v.Len()))
}
func encInt64Slice(state *encoderState, v reflect.Value) bool {
slice, ok := v.Interface().([]int64)
if !ok {
// It is kind int64 but not type int64. TODO: We can handle this unsafely.
return false
}
for _, x := range slice {
if x != 0 || state.sendZero {
state.encodeInt(x)
}
}
return true
}
func encInt8Array(state *encoderState, v reflect.Value) bool {
// Can only slice if it is addressable.
if !v.CanAddr() {
return false
}
return encInt8Slice(state, v.Slice(0, v.Len()))
}
func encInt8Slice(state *encoderState, v reflect.Value) bool {
slice, ok := v.Interface().([]int8)
if !ok {
// It is kind int8 but not type int8. TODO: We can handle this unsafely.
return false
}
for _, x := range slice {
if x != 0 || state.sendZero {
state.encodeInt(int64(x))
}
}
return true
}
func encStringArray(state *encoderState, v reflect.Value) bool {
// Can only slice if it is addressable.
if !v.CanAddr() {
return false
}
return encStringSlice(state, v.Slice(0, v.Len()))
}
func encStringSlice(state *encoderState, v reflect.Value) bool {
slice, ok := v.Interface().([]string)
if !ok {
// It is kind string but not type string. TODO: We can handle this unsafely.
return false
}
for _, x := range slice {
if x != "" || state.sendZero {
state.encodeUint(uint64(len(x)))
state.b.WriteString(x)
}
}
return true
}
func encUintArray(state *encoderState, v reflect.Value) bool {
// Can only slice if it is addressable.
if !v.CanAddr() {
return false
}
return encUintSlice(state, v.Slice(0, v.Len()))
}
func encUintSlice(state *encoderState, v reflect.Value) bool {
slice, ok := v.Interface().([]uint)
if !ok {
// It is kind uint but not type uint. TODO: We can handle this unsafely.
return false
}
for _, x := range slice {
if x != 0 || state.sendZero {
state.encodeUint(uint64(x))
}
}
return true
}
func encUint16Array(state *encoderState, v reflect.Value) bool {
// Can only slice if it is addressable.
if !v.CanAddr() {
return false
}
return encUint16Slice(state, v.Slice(0, v.Len()))
}
func encUint16Slice(state *encoderState, v reflect.Value) bool {
slice, ok := v.Interface().([]uint16)
if !ok {
// It is kind uint16 but not type uint16. TODO: We can handle this unsafely.
return false
}
for _, x := range slice {
if x != 0 || state.sendZero {
state.encodeUint(uint64(x))
}
}
return true
}
func encUint32Array(state *encoderState, v reflect.Value) bool {
// Can only slice if it is addressable.
if !v.CanAddr() {
return false
}
return encUint32Slice(state, v.Slice(0, v.Len()))
}
func encUint32Slice(state *encoderState, v reflect.Value) bool {
slice, ok := v.Interface().([]uint32)
if !ok {
// It is kind uint32 but not type uint32. TODO: We can handle this unsafely.
return false
}
for _, x := range slice {
if x != 0 || state.sendZero {
state.encodeUint(uint64(x))
}
}
return true
}
func encUint64Array(state *encoderState, v reflect.Value) bool {
// Can only slice if it is addressable.
if !v.CanAddr() {
return false
}
return encUint64Slice(state, v.Slice(0, v.Len()))
}
func encUint64Slice(state *encoderState, v reflect.Value) bool {
slice, ok := v.Interface().([]uint64)
if !ok {
// It is kind uint64 but not type uint64. TODO: We can handle this unsafely.
return false
}
for _, x := range slice {
if x != 0 || state.sendZero {
state.encodeUint(x)
}
}
return true
}
func encUintptrArray(state *encoderState, v reflect.Value) bool {
// Can only slice if it is addressable.
if !v.CanAddr() {
return false
}
return encUintptrSlice(state, v.Slice(0, v.Len()))
}
func encUintptrSlice(state *encoderState, v reflect.Value) bool {
slice, ok := v.Interface().([]uintptr)
if !ok {
// It is kind uintptr but not type uintptr. TODO: We can handle this unsafely.
return false
}
for _, x := range slice {
if x != 0 || state.sendZero {
state.encodeUint(uint64(x))
}
}
return true
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:generate go run encgen.go -output enc_helpers.go
package gob
import (
"encoding"
"encoding/binary"
"math"
"math/bits"
"reflect"
"sync"
)
const uint64Size = 8
type encHelper func(state *encoderState, v reflect.Value) bool
// encoderState is the global execution state of an instance of the encoder.
// Field numbers are delta encoded and always increase. The field
// number is initialized to -1 so 0 comes out as delta(1). A delta of
// 0 terminates the structure.
type encoderState struct {
enc *Encoder
b *encBuffer
sendZero bool // encoding an array element or map key/value pair; send zero values
fieldnum int // the last field number written.
buf [1 + uint64Size]byte // buffer used by the encoder; here to avoid allocation.
next *encoderState // for free list
}
// encBuffer is an extremely simple, fast implementation of a write-only byte buffer.
// It never returns a non-nil error, but Write returns an error value so it matches io.Writer.
type encBuffer struct {
data []byte
scratch [64]byte
}
var encBufferPool = sync.Pool{
New: func() any {
e := new(encBuffer)
e.data = e.scratch[0:0]
return e
},
}
func (e *encBuffer) writeByte(c byte) {
e.data = append(e.data, c)
}
func (e *encBuffer) Write(p []byte) (int, error) {
e.data = append(e.data, p...)
return len(p), nil
}
func (e *encBuffer) WriteString(s string) {
e.data = append(e.data, s...)
}
func (e *encBuffer) Len() int {
return len(e.data)
}
func (e *encBuffer) Bytes() []byte {
return e.data
}
func (e *encBuffer) Reset() {
if len(e.data) >= tooBig {
e.data = e.scratch[0:0]
} else {
e.data = e.data[0:0]
}
}
func (enc *Encoder) newEncoderState(b *encBuffer) *encoderState {
e := enc.freeList
if e == nil {
e = new(encoderState)
e.enc = enc
} else {
enc.freeList = e.next
}
e.sendZero = false
e.fieldnum = 0
e.b = b
if len(b.data) == 0 {
b.data = b.scratch[0:0]
}
return e
}
func (enc *Encoder) freeEncoderState(e *encoderState) {
e.next = enc.freeList
enc.freeList = e
}
// Unsigned integers have a two-state encoding. If the number is less
// than 128 (0 through 0x7F), its value is written directly.
// Otherwise the value is written in big-endian byte order preceded
// by the byte length, negated.
// encodeUint writes an encoded unsigned integer to state.b.
func (state *encoderState) encodeUint(x uint64) {
if x <= 0x7F {
state.b.writeByte(uint8(x))
return
}
binary.BigEndian.PutUint64(state.buf[1:], x)
bc := bits.LeadingZeros64(x) >> 3 // 8 - bytelen(x)
state.buf[bc] = uint8(bc - uint64Size) // and then we subtract 8 to get -bytelen(x)
state.b.Write(state.buf[bc : uint64Size+1])
}
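// For illustration, a few values as they appear on the wire after encodeUint
// (worked out from the routine above):
//
//	0x40  -> 40        single byte, since the value is below 0x80
//	0x80  -> FF 80     byte count 1, stored negated as 0xFF, then the value
//	0x100 -> FE 01 00  byte count 2, stored negated as 0xFE, then big-endian bytes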
// encodeInt writes an encoded signed integer to state.w.
// The low bit of the encoding says whether to bit complement the (other bits of the)
// uint to recover the int.
func (state *encoderState) encodeInt(i int64) {
var x uint64
if i < 0 {
x = uint64(^i<<1) | 1
} else {
x = uint64(i << 1)
}
state.encodeUint(x)
}
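// For illustration, the signed mapping above sends non-negative i as 2*i and
// negative i as 2*^i + 1, so the low bit marks the complemented values:
//
//	 0 -> uint 0,  1 -> uint 2,  -1 -> uint 1,  -2 -> uint 3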
// encOp is the signature of an encoding operator for a given type.
type encOp func(i *encInstr, state *encoderState, v reflect.Value)
// The 'instructions' of the encoding machine
type encInstr struct {
op encOp
field int // field number in input
index []int // struct index
indir int // how many pointer indirections to reach the value in the struct
}
// update emits a field number and updates the state to record its value for delta encoding.
// If the instruction pointer is nil, it does nothing.
func (state *encoderState) update(instr *encInstr) {
if instr != nil {
state.encodeUint(uint64(instr.field - state.fieldnum))
state.fieldnum = instr.field
}
}
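// Worked example of the delta encoding: with fieldnum starting at -1, a struct
// whose fields 0 and 2 are the only non-zero fields is framed as
//
//	delta 1, <field 0 data>, delta 2, <field 2 data>, delta 0 (terminator)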
// Each encoder for a composite is responsible for handling any
// indirections associated with the elements of the data structure.
// If any pointer so reached is nil, no bytes are written. If the
// data item is zero, no bytes are written. Single values - ints,
// strings etc. - are indirected before calling their encoders.
// Otherwise, the output (for a scalar) is the field number, as an
// encoded integer, followed by the field data in its appropriate
// format.
// encIndirect dereferences pv indir times and returns the result.
func encIndirect(pv reflect.Value, indir int) reflect.Value {
for ; indir > 0; indir-- {
if pv.IsNil() {
break
}
pv = pv.Elem()
}
return pv
}
// encBool encodes the bool referenced by v as an unsigned 0 or 1.
func encBool(i *encInstr, state *encoderState, v reflect.Value) {
b := v.Bool()
if b || state.sendZero {
state.update(i)
if b {
state.encodeUint(1)
} else {
state.encodeUint(0)
}
}
}
// encInt encodes the signed integer (int int8 int16 int32 int64) referenced by v.
func encInt(i *encInstr, state *encoderState, v reflect.Value) {
value := v.Int()
if value != 0 || state.sendZero {
state.update(i)
state.encodeInt(value)
}
}
// encUint encodes the unsigned integer (uint uint8 uint16 uint32 uint64 uintptr) referenced by v.
func encUint(i *encInstr, state *encoderState, v reflect.Value) {
value := v.Uint()
if value != 0 || state.sendZero {
state.update(i)
state.encodeUint(value)
}
}
// floatBits returns a uint64 holding the bits of a floating-point number.
// Floating-point numbers are transmitted as uint64s holding the bits
// of the underlying representation. They are sent byte-reversed, with
// the exponent end coming out first, so integer floating point numbers
// (for example) transmit more compactly. This routine does the
// swizzling.
func floatBits(f float64) uint64 {
u := math.Float64bits(f)
return bits.ReverseBytes64(u)
}
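// For example, 17.0 has IEEE-754 bits 0x4031000000000000; after the byte
// reversal above it becomes 0x3140, which encodeUint then emits as the three
// bytes FE 31 40.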
// encFloat encodes the floating point value (float32 float64) referenced by v.
func encFloat(i *encInstr, state *encoderState, v reflect.Value) {
f := v.Float()
if f != 0 || state.sendZero {
bits := floatBits(f)
state.update(i)
state.encodeUint(bits)
}
}
// encComplex encodes the complex value (complex64 complex128) referenced by v.
// Complex numbers are just a pair of floating-point numbers, real part first.
func encComplex(i *encInstr, state *encoderState, v reflect.Value) {
c := v.Complex()
if c != 0+0i || state.sendZero {
rpart := floatBits(real(c))
ipart := floatBits(imag(c))
state.update(i)
state.encodeUint(rpart)
state.encodeUint(ipart)
}
}
// encUint8Array encodes the byte array referenced by v.
// Byte arrays are encoded as an unsigned count followed by the raw bytes.
func encUint8Array(i *encInstr, state *encoderState, v reflect.Value) {
b := v.Bytes()
if len(b) > 0 || state.sendZero {
state.update(i)
state.encodeUint(uint64(len(b)))
state.b.Write(b)
}
}
// encString encodes the string referenced by v.
// Strings are encoded as an unsigned count followed by the raw bytes.
func encString(i *encInstr, state *encoderState, v reflect.Value) {
s := v.String()
if len(s) > 0 || state.sendZero {
state.update(i)
state.encodeUint(uint64(len(s)))
state.b.WriteString(s)
}
}
// encStructTerminator encodes the end of an encoded struct
// as delta field number of 0.
func encStructTerminator(i *encInstr, state *encoderState, v reflect.Value) {
state.encodeUint(0)
}
// Execution engine
// encEngine is an array of instructions indexed by field number of the encoding
// data, typically a struct. It is executed top to bottom, walking the struct.
type encEngine struct {
instr []encInstr
}
const singletonField = 0
// valid reports whether the value is valid and a non-nil pointer.
// (Slices, maps, and chans take care of themselves.)
func valid(v reflect.Value) bool {
switch v.Kind() {
case reflect.Invalid:
return false
case reflect.Pointer:
return !v.IsNil()
}
return true
}
// encodeSingle encodes a single top-level non-struct value.
func (enc *Encoder) encodeSingle(b *encBuffer, engine *encEngine, value reflect.Value) {
state := enc.newEncoderState(b)
defer enc.freeEncoderState(state)
state.fieldnum = singletonField
// There is no surrounding struct to frame the transmission, so we must
// generate data even if the item is zero. To do this, set sendZero.
state.sendZero = true
instr := &engine.instr[singletonField]
if instr.indir > 0 {
value = encIndirect(value, instr.indir)
}
if valid(value) {
instr.op(instr, state, value)
}
}
// encodeStruct encodes a single struct value.
func (enc *Encoder) encodeStruct(b *encBuffer, engine *encEngine, value reflect.Value) {
if !valid(value) {
return
}
state := enc.newEncoderState(b)
defer enc.freeEncoderState(state)
state.fieldnum = -1
for i := 0; i < len(engine.instr); i++ {
instr := &engine.instr[i]
if i >= value.NumField() {
// encStructTerminator
instr.op(instr, state, reflect.Value{})
break
}
field := value.FieldByIndex(instr.index)
if instr.indir > 0 {
field = encIndirect(field, instr.indir)
// TODO: Is field guaranteed valid? If so we could avoid this check.
if !valid(field) {
continue
}
}
instr.op(instr, state, field)
}
}
// encodeArray encodes an array.
func (enc *Encoder) encodeArray(b *encBuffer, value reflect.Value, op encOp, elemIndir int, length int, helper encHelper) {
state := enc.newEncoderState(b)
defer enc.freeEncoderState(state)
state.fieldnum = -1
state.sendZero = true
state.encodeUint(uint64(length))
if helper != nil && helper(state, value) {
return
}
for i := 0; i < length; i++ {
elem := value.Index(i)
if elemIndir > 0 {
elem = encIndirect(elem, elemIndir)
// TODO: Is elem guaranteed valid? If so we could avoid this check.
if !valid(elem) {
errorf("encodeArray: nil element")
}
}
op(nil, state, elem)
}
}
// encodeReflectValue is a helper for maps. It encodes the value v.
func encodeReflectValue(state *encoderState, v reflect.Value, op encOp, indir int) {
for i := 0; i < indir && v.IsValid(); i++ {
v = reflect.Indirect(v)
}
if !v.IsValid() {
errorf("encodeReflectValue: nil element")
}
op(nil, state, v)
}
// encodeMap encodes a map as unsigned count followed by key:value pairs.
func (enc *Encoder) encodeMap(b *encBuffer, mv reflect.Value, keyOp, elemOp encOp, keyIndir, elemIndir int) {
state := enc.newEncoderState(b)
state.fieldnum = -1
state.sendZero = true
state.encodeUint(uint64(mv.Len()))
mi := mv.MapRange()
for mi.Next() {
encodeReflectValue(state, mi.Key(), keyOp, keyIndir)
encodeReflectValue(state, mi.Value(), elemOp, elemIndir)
}
enc.freeEncoderState(state)
}
// encodeInterface encodes the interface value iv.
// To send an interface, we send a string identifying the concrete type, followed
// by the type identifier (which might require defining that type right now), followed
// by the concrete value. A nil value gets sent as the empty string for the name,
// followed by no value.
func (enc *Encoder) encodeInterface(b *encBuffer, iv reflect.Value) {
// Gobs can encode nil interface values but not typed interface
// values holding nil pointers, since nil pointers point to no value.
elem := iv.Elem()
if elem.Kind() == reflect.Pointer && elem.IsNil() {
errorf("gob: cannot encode nil pointer of type %s inside interface", iv.Elem().Type())
}
state := enc.newEncoderState(b)
state.fieldnum = -1
state.sendZero = true
if iv.IsNil() {
state.encodeUint(0)
return
}
ut := userType(iv.Elem().Type())
namei, ok := concreteTypeToName.Load(ut.base)
if !ok {
errorf("type not registered for interface: %s", ut.base)
}
name := namei.(string)
// Send the name.
state.encodeUint(uint64(len(name)))
state.b.WriteString(name)
// Define the type id if necessary.
enc.sendTypeDescriptor(enc.writer(), state, ut)
// Send the type id.
enc.sendTypeId(state, ut)
// Encode the value into a new buffer. Any nested type definitions
// should be written to b, before the encoded value.
enc.pushWriter(b)
data := encBufferPool.Get().(*encBuffer)
data.Write(spaceForLength)
enc.encode(data, elem, ut)
if enc.err != nil {
error_(enc.err)
}
enc.popWriter()
enc.writeMessage(b, data)
data.Reset()
encBufferPool.Put(data)
if enc.err != nil {
error_(enc.err)
}
enc.freeEncoderState(state)
}
// encodeGobEncoder encodes a value that implements the GobEncoder interface.
// The data is sent as a byte array.
func (enc *Encoder) encodeGobEncoder(b *encBuffer, ut *userTypeInfo, v reflect.Value) {
// TODO: should we catch panics from the called method?
var data []byte
var err error
// We know it's one of these.
switch ut.externalEnc {
case xGob:
data, err = v.Interface().(GobEncoder).GobEncode()
case xBinary:
data, err = v.Interface().(encoding.BinaryMarshaler).MarshalBinary()
case xText:
data, err = v.Interface().(encoding.TextMarshaler).MarshalText()
}
if err != nil {
error_(err)
}
state := enc.newEncoderState(b)
state.fieldnum = -1
state.encodeUint(uint64(len(data)))
state.b.Write(data)
enc.freeEncoderState(state)
}
var encOpTable = [...]encOp{
reflect.Bool: encBool,
reflect.Int: encInt,
reflect.Int8: encInt,
reflect.Int16: encInt,
reflect.Int32: encInt,
reflect.Int64: encInt,
reflect.Uint: encUint,
reflect.Uint8: encUint,
reflect.Uint16: encUint,
reflect.Uint32: encUint,
reflect.Uint64: encUint,
reflect.Uintptr: encUint,
reflect.Float32: encFloat,
reflect.Float64: encFloat,
reflect.Complex64: encComplex,
reflect.Complex128: encComplex,
reflect.String: encString,
}
// encOpFor returns (a pointer to) the encoding op for the base type under rt and
// the indirection count to reach it.
func encOpFor(rt reflect.Type, inProgress map[reflect.Type]*encOp, building map[*typeInfo]bool) (*encOp, int) {
ut := userType(rt)
// If the type implements GobEncoder, we handle it without further processing.
if ut.externalEnc != 0 {
return gobEncodeOpFor(ut)
}
// If this type is already in progress, it's a recursive type (e.g. map[string]*T).
// Return the pointer to the op we're already building.
if opPtr := inProgress[rt]; opPtr != nil {
return opPtr, ut.indir
}
typ := ut.base
indir := ut.indir
k := typ.Kind()
var op encOp
if int(k) < len(encOpTable) {
op = encOpTable[k]
}
if op == nil {
inProgress[rt] = &op
// Special cases
switch t := typ; t.Kind() {
case reflect.Slice:
if t.Elem().Kind() == reflect.Uint8 {
op = encUint8Array
break
}
// Slices have a header; we decode it to find the underlying array.
elemOp, elemIndir := encOpFor(t.Elem(), inProgress, building)
helper := encSliceHelper[t.Elem().Kind()]
op = func(i *encInstr, state *encoderState, slice reflect.Value) {
if !state.sendZero && slice.Len() == 0 {
return
}
state.update(i)
state.enc.encodeArray(state.b, slice, *elemOp, elemIndir, slice.Len(), helper)
}
case reflect.Array:
// True arrays have size in the type.
elemOp, elemIndir := encOpFor(t.Elem(), inProgress, building)
helper := encArrayHelper[t.Elem().Kind()]
op = func(i *encInstr, state *encoderState, array reflect.Value) {
state.update(i)
state.enc.encodeArray(state.b, array, *elemOp, elemIndir, array.Len(), helper)
}
case reflect.Map:
keyOp, keyIndir := encOpFor(t.Key(), inProgress, building)
elemOp, elemIndir := encOpFor(t.Elem(), inProgress, building)
op = func(i *encInstr, state *encoderState, mv reflect.Value) {
// We send zero-length (but non-nil) maps because the
// receiver might want to use the map. (Maps don't use append.)
if !state.sendZero && mv.IsNil() {
return
}
state.update(i)
state.enc.encodeMap(state.b, mv, *keyOp, *elemOp, keyIndir, elemIndir)
}
case reflect.Struct:
// Generate a closure that calls out to the engine for the nested type.
getEncEngine(userType(typ), building)
info := mustGetTypeInfo(typ)
op = func(i *encInstr, state *encoderState, sv reflect.Value) {
state.update(i)
// indirect through info to delay evaluation for recursive structs
enc := info.encoder.Load()
state.enc.encodeStruct(state.b, enc, sv)
}
case reflect.Interface:
op = func(i *encInstr, state *encoderState, iv reflect.Value) {
if !state.sendZero && (!iv.IsValid() || iv.IsNil()) {
return
}
state.update(i)
state.enc.encodeInterface(state.b, iv)
}
}
}
if op == nil {
errorf("can't happen: encode type %s", rt)
}
return &op, indir
}
// gobEncodeOpFor returns the op for a type that is known to implement GobEncoder.
func gobEncodeOpFor(ut *userTypeInfo) (*encOp, int) {
rt := ut.user
if ut.encIndir == -1 {
rt = reflect.PointerTo(rt)
} else if ut.encIndir > 0 {
for i := int8(0); i < ut.encIndir; i++ {
rt = rt.Elem()
}
}
var op encOp
op = func(i *encInstr, state *encoderState, v reflect.Value) {
if ut.encIndir == -1 {
// Need to climb up one level to turn value into pointer.
if !v.CanAddr() {
errorf("unaddressable value of type %s", rt)
}
v = v.Addr()
}
if !state.sendZero && v.IsZero() {
return
}
state.update(i)
state.enc.encodeGobEncoder(state.b, ut, v)
}
return &op, int(ut.encIndir) // encIndir: op will get called with p == address of receiver.
}
// compileEnc returns the engine to compile the type.
func compileEnc(ut *userTypeInfo, building map[*typeInfo]bool) *encEngine {
srt := ut.base
engine := new(encEngine)
seen := make(map[reflect.Type]*encOp)
rt := ut.base
if ut.externalEnc != 0 {
rt = ut.user
}
if ut.externalEnc == 0 && srt.Kind() == reflect.Struct {
for fieldNum, wireFieldNum := 0, 0; fieldNum < srt.NumField(); fieldNum++ {
f := srt.Field(fieldNum)
if !isSent(&f) {
continue
}
op, indir := encOpFor(f.Type, seen, building)
engine.instr = append(engine.instr, encInstr{*op, wireFieldNum, f.Index, indir})
wireFieldNum++
}
if srt.NumField() > 0 && len(engine.instr) == 0 {
errorf("type %s has no exported fields", rt)
}
engine.instr = append(engine.instr, encInstr{encStructTerminator, 0, nil, 0})
} else {
engine.instr = make([]encInstr, 1)
op, indir := encOpFor(rt, seen, building)
engine.instr[0] = encInstr{*op, singletonField, nil, indir}
}
return engine
}
// getEncEngine returns the engine to compile the type.
func getEncEngine(ut *userTypeInfo, building map[*typeInfo]bool) *encEngine {
info, err := getTypeInfo(ut)
if err != nil {
error_(err)
}
enc := info.encoder.Load()
if enc == nil {
enc = buildEncEngine(info, ut, building)
}
return enc
}
func buildEncEngine(info *typeInfo, ut *userTypeInfo, building map[*typeInfo]bool) *encEngine {
// Check for recursive types.
if building != nil && building[info] {
return nil
}
info.encInit.Lock()
defer info.encInit.Unlock()
enc := info.encoder.Load()
if enc == nil {
if building == nil {
building = make(map[*typeInfo]bool)
}
building[info] = true
enc = compileEnc(ut, building)
info.encoder.Store(enc)
}
return enc
}
func (enc *Encoder) encode(b *encBuffer, value reflect.Value, ut *userTypeInfo) {
defer catchError(&enc.err)
engine := getEncEngine(ut, nil)
indir := ut.indir
if ut.externalEnc != 0 {
indir = int(ut.encIndir)
}
for i := 0; i < indir; i++ {
value = reflect.Indirect(value)
}
if ut.externalEnc == 0 && value.Kind() == reflect.Struct {
enc.encodeStruct(b, engine, value)
} else {
enc.encodeSingle(b, engine, value)
}
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package gob
import (
"errors"
"io"
"reflect"
"sync"
)
// An Encoder manages the transmission of type and data information to the
// other side of a connection. It is safe for concurrent use by multiple
// goroutines.
type Encoder struct {
mutex sync.Mutex // each item must be sent atomically
w []io.Writer // where to send the data
sent map[reflect.Type]typeId // which types we've already sent
countState *encoderState // stage for writing counts
freeList *encoderState // list of free encoderStates; avoids reallocation
byteBuf encBuffer // buffer for top-level encoderState
err error
}
// Before we encode a message, we reserve space at the head of the
// buffer in which to encode its length. This means we can use the
// buffer to assemble the message without another allocation.
const maxLength = 9 // Maximum size of an encoded length.
var spaceForLength = make([]byte, maxLength)
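// A sketch of the resulting framing (as implemented by writeMessage below):
// the body of each message is assembled after maxLength reserved bytes, the
// encoded byte count is then copied into the tail of that reserved space, and
// a single Write emits <encoded count><message body> to the underlying writer.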
// NewEncoder returns a new encoder that will transmit on the [io.Writer].
func NewEncoder(w io.Writer) *Encoder {
enc := new(Encoder)
enc.w = []io.Writer{w}
enc.sent = make(map[reflect.Type]typeId)
enc.countState = enc.newEncoderState(new(encBuffer))
return enc
}
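// A minimal usage sketch from the client side (assuming a bytes.Buffer as the
// transport; any io.Writer/io.Reader pair works the same way):
//
//	var network bytes.Buffer
//	enc := gob.NewEncoder(&network)
//	if err := enc.Encode(map[string]int{"a": 1}); err != nil {
//		log.Fatal(err)
//	}
//	var m map[string]int
//	if err := gob.NewDecoder(&network).Decode(&m); err != nil {
//		log.Fatal(err)
//	}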
// writer returns the innermost writer the encoder is using.
func (enc *Encoder) writer() io.Writer {
return enc.w[len(enc.w)-1]
}
// pushWriter adds a writer to the encoder.
func (enc *Encoder) pushWriter(w io.Writer) {
enc.w = append(enc.w, w)
}
// popWriter pops the innermost writer.
func (enc *Encoder) popWriter() {
enc.w = enc.w[0 : len(enc.w)-1]
}
func (enc *Encoder) setError(err error) {
if enc.err == nil { // remember the first.
enc.err = err
}
}
// writeMessage sends the data item preceded by an unsigned count of its length.
func (enc *Encoder) writeMessage(w io.Writer, b *encBuffer) {
// Space has been reserved for the length at the head of the message.
// This is a little dirty: we grab the slice from the bytes.Buffer and massage
// it by hand.
message := b.Bytes()
messageLen := len(message) - maxLength
// Length cannot be bigger than the decoder can handle.
if messageLen >= tooBig {
enc.setError(errors.New("gob: encoder: message too big"))
return
}
// Encode the length.
enc.countState.b.Reset()
enc.countState.encodeUint(uint64(messageLen))
// Copy the length to be a prefix of the message.
offset := maxLength - enc.countState.b.Len()
copy(message[offset:], enc.countState.b.Bytes())
// Write the data.
_, err := w.Write(message[offset:])
// Drain the buffer and restore the space at the front for the count of the next message.
b.Reset()
b.Write(spaceForLength)
if err != nil {
enc.setError(err)
}
}
// sendActualType sends the requested type, without further investigation, unless
// it's been sent before.
func (enc *Encoder) sendActualType(w io.Writer, state *encoderState, ut *userTypeInfo, actual reflect.Type) (sent bool) {
if _, alreadySent := enc.sent[actual]; alreadySent {
return false
}
info, err := getTypeInfo(ut)
if err != nil {
enc.setError(err)
return
}
// Send the pair (-id, type)
// Id:
state.encodeInt(-int64(info.id))
// Type:
enc.encode(state.b, reflect.ValueOf(info.wire), wireTypeUserInfo)
enc.writeMessage(w, state.b)
if enc.err != nil {
return
}
// Remember we've sent this type, both what the user gave us and the base type.
enc.sent[ut.base] = info.id
if ut.user != ut.base {
enc.sent[ut.user] = info.id
}
// Now send the inner types
switch st := actual; st.Kind() {
case reflect.Struct:
for i := 0; i < st.NumField(); i++ {
if isExported(st.Field(i).Name) {
enc.sendType(w, state, st.Field(i).Type)
}
}
case reflect.Array, reflect.Slice:
enc.sendType(w, state, st.Elem())
case reflect.Map:
enc.sendType(w, state, st.Key())
enc.sendType(w, state, st.Elem())
}
return true
}
// sendType sends the type info to the other side, if necessary.
func (enc *Encoder) sendType(w io.Writer, state *encoderState, origt reflect.Type) (sent bool) {
ut := userType(origt)
if ut.externalEnc != 0 {
// The rules are different: regardless of the underlying type's representation,
// we need to tell the other side that the base type is a GobEncoder.
return enc.sendActualType(w, state, ut, ut.base)
}
// It's a concrete value, so drill down to the base type.
switch rt := ut.base; rt.Kind() {
default:
// Basic types and interfaces do not need to be described.
return
case reflect.Slice:
// If it's []uint8, don't send; it's considered basic.
if rt.Elem().Kind() == reflect.Uint8 {
return
}
// Otherwise we do send.
break
case reflect.Array:
// arrays must be sent so we know their lengths and element types.
break
case reflect.Map:
// maps must be sent so we know their lengths and key/value types.
break
case reflect.Struct:
// structs must be sent so we know their fields.
break
case reflect.Chan, reflect.Func:
// If we get here, it's a field of a struct; ignore it.
return
}
return enc.sendActualType(w, state, ut, ut.base)
}
// Encode transmits the data item represented by the empty interface value,
// guaranteeing that all necessary type information has been transmitted first.
// Passing a nil pointer to Encoder will panic, as nil pointers cannot be transmitted by gob.
func (enc *Encoder) Encode(e any) error {
return enc.EncodeValue(reflect.ValueOf(e))
}
// sendTypeDescriptor makes sure the remote side knows about this type.
// It will send a descriptor if this is the first time the type has been
// sent.
func (enc *Encoder) sendTypeDescriptor(w io.Writer, state *encoderState, ut *userTypeInfo) {
// Make sure the type is known to the other side.
// First, have we already sent this type?
rt := ut.base
if ut.externalEnc != 0 {
rt = ut.user
}
if _, alreadySent := enc.sent[rt]; !alreadySent {
// No, so send it.
sent := enc.sendType(w, state, rt)
if enc.err != nil {
return
}
// If the type info has still not been transmitted, it means we have
// a singleton basic type (int, []byte etc.) at top level. We don't
// need to send the type info but we do need to update enc.sent.
if !sent {
info, err := getTypeInfo(ut)
if err != nil {
enc.setError(err)
return
}
enc.sent[rt] = info.id
}
}
}
// sendTypeId sends the id, which must have already been defined.
func (enc *Encoder) sendTypeId(state *encoderState, ut *userTypeInfo) {
// Identify the type of this top-level value.
state.encodeInt(int64(enc.sent[ut.base]))
}
// EncodeValue transmits the data item represented by the reflection value,
// guaranteeing that all necessary type information has been transmitted first.
// Passing a nil pointer to EncodeValue will panic, as nil pointers cannot be transmitted by gob.
func (enc *Encoder) EncodeValue(value reflect.Value) error {
if value.Kind() == reflect.Invalid {
return errors.New("gob: cannot encode nil value")
}
if value.Kind() == reflect.Pointer && value.IsNil() {
panic("gob: cannot encode nil pointer of type " + value.Type().String())
}
// Make sure we're single-threaded through here, so multiple
// goroutines can share an encoder.
enc.mutex.Lock()
defer enc.mutex.Unlock()
// Remove any nested writers remaining due to previous errors.
enc.w = enc.w[0:1]
ut, err := validUserType(value.Type())
if err != nil {
return err
}
enc.err = nil
enc.byteBuf.Reset()
enc.byteBuf.Write(spaceForLength)
state := enc.newEncoderState(&enc.byteBuf)
enc.sendTypeDescriptor(enc.writer(), state, ut)
enc.sendTypeId(state, ut)
if enc.err != nil {
return enc.err
}
// Encode the object.
enc.encode(state.b, value, ut)
if enc.err == nil {
enc.writeMessage(enc.writer(), state.b)
}
enc.freeEncoderState(state)
return enc.err
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package gob
import "fmt"
// Errors in decoding and encoding are handled using panic and recover.
// Panics caused by user error (that is, everything except run-time panics
// such as "index out of bounds" errors) do not leave the file that caused
// them, but are instead turned into plain error returns. Encoding and
// decoding functions and methods that do not return an error either use
// panic to report an error or are guaranteed error-free.
// A gobError is used to distinguish errors (panics) generated in this package.
type gobError struct {
err error
}
// errorf is like error_ but takes Printf-style arguments to construct an error.
// It always prefixes the message with "gob: ".
func errorf(format string, args ...any) {
error_(fmt.Errorf("gob: "+format, args...))
}
// error_ wraps the argument error and uses it as the argument to panic.
func error_(err error) {
panic(gobError{err})
}
// catchError is meant to be used as a deferred function to turn a panic(gobError) into a
// plain error. It overwrites the error return of the function that deferred its call.
func catchError(err *error) {
if e := recover(); e != nil {
ge, ok := e.(gobError)
if !ok {
panic(e)
}
*err = ge.err
}
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package gob
import (
"encoding"
"errors"
"fmt"
"maps"
"os"
"reflect"
"sync"
"sync/atomic"
"unicode"
"unicode/utf8"
)
// userTypeInfo stores the information associated with a type the user has handed
// to the package. It's computed once and stored in a map keyed by reflection
// type.
type userTypeInfo struct {
user reflect.Type // the type the user handed us
base reflect.Type // the base type after all indirections
indir int // number of indirections to reach the base type
externalEnc int // xGob, xBinary, or xText
externalDec int // xGob, xBinary, or xText
encIndir int8 // number of indirections to reach the receiver type; may be negative
decIndir int8 // number of indirections to reach the receiver type; may be negative
}
// externalEncoding bits
const (
xGob = 1 + iota // GobEncoder or GobDecoder
xBinary // encoding.BinaryMarshaler or encoding.BinaryUnmarshaler
xText // encoding.TextMarshaler or encoding.TextUnmarshaler
)
var userTypeCache sync.Map // map[reflect.Type]*userTypeInfo
// validUserType returns, and saves, the information associated with user-provided type rt.
// If the user type is not valid, err will be non-nil. To be used when the error handler
// is not set up.
func validUserType(rt reflect.Type) (*userTypeInfo, error) {
if ui, ok := userTypeCache.Load(rt); ok {
return ui.(*userTypeInfo), nil
}
// Construct a new userTypeInfo and atomically add it to the userTypeCache.
// If we lose the race, we'll waste a little CPU and create a little garbage
// but return the existing value anyway.
ut := new(userTypeInfo)
ut.base = rt
ut.user = rt
// A type that is just a cycle of pointers (such as type T *T) cannot
// be represented in gobs, which need some concrete data. We use a
// cycle detection algorithm from Knuth, Vol 2, Section 3.1, Ex 6,
// pp 539-540. As we step through indirections, run another type at
// half speed. If they meet up, there's a cycle.
slowpoke := ut.base // walks half as fast as ut.base
for {
pt := ut.base
if pt.Kind() != reflect.Pointer {
break
}
ut.base = pt.Elem()
if ut.base == slowpoke { // ut.base lapped slowpoke
// recursive pointer type.
return nil, errors.New("can't represent recursive pointer type " + ut.base.String())
}
if ut.indir%2 == 0 {
slowpoke = slowpoke.Elem()
}
ut.indir++
}
if ok, indir := implementsInterface(ut.user, gobEncoderInterfaceType); ok {
ut.externalEnc, ut.encIndir = xGob, indir
} else if ok, indir := implementsInterface(ut.user, binaryMarshalerInterfaceType); ok {
ut.externalEnc, ut.encIndir = xBinary, indir
}
// NOTE(rsc): Would like to allow MarshalText here, but results in incompatibility
// with older encodings for net.IP. See golang.org/issue/6760.
// } else if ok, indir := implementsInterface(ut.user, textMarshalerInterfaceType); ok {
// ut.externalEnc, ut.encIndir = xText, indir
// }
if ok, indir := implementsInterface(ut.user, gobDecoderInterfaceType); ok {
ut.externalDec, ut.decIndir = xGob, indir
} else if ok, indir := implementsInterface(ut.user, binaryUnmarshalerInterfaceType); ok {
ut.externalDec, ut.decIndir = xBinary, indir
}
// See note above.
// } else if ok, indir := implementsInterface(ut.user, textUnmarshalerInterfaceType); ok {
// ut.externalDec, ut.decIndir = xText, indir
// }
ui, _ := userTypeCache.LoadOrStore(rt, ut)
return ui.(*userTypeInfo), nil
}
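// For illustration, a pure pointer cycle such as
//
//	type Loop *Loop
//
// never reaches a concrete base type, so the fast walker laps slowpoke and
// validUserType reports the recursive pointer type as an error.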
var (
gobEncoderInterfaceType = reflect.TypeFor[GobEncoder]()
gobDecoderInterfaceType = reflect.TypeFor[GobDecoder]()
binaryMarshalerInterfaceType = reflect.TypeFor[encoding.BinaryMarshaler]()
binaryUnmarshalerInterfaceType = reflect.TypeFor[encoding.BinaryUnmarshaler]()
textMarshalerInterfaceType = reflect.TypeFor[encoding.TextMarshaler]()
textUnmarshalerInterfaceType = reflect.TypeFor[encoding.TextUnmarshaler]()
wireTypeType = reflect.TypeFor[wireType]()
)
// implementsInterface reports whether the type implements the
// gobEncoder/gobDecoder interface.
// It also returns the number of indirections required to get to the
// implementation.
func implementsInterface(typ, gobEncDecType reflect.Type) (success bool, indir int8) {
if typ == nil {
return
}
rt := typ
// The type might be a pointer and we need to keep
// dereferencing to the base type until we find an implementation.
for {
if rt.Implements(gobEncDecType) {
return true, indir
}
if p := rt; p.Kind() == reflect.Pointer {
indir++
if indir > 100 { // insane number of indirections
return false, 0
}
rt = p.Elem()
continue
}
break
}
// No luck yet, but if this is a base type (non-pointer), the pointer might satisfy.
if typ.Kind() != reflect.Pointer {
// Not a pointer, but does the pointer work?
if reflect.PointerTo(typ).Implements(gobEncDecType) {
return true, -1
}
}
return false, 0
}
// userType returns, and saves, the information associated with user-provided type rt.
// If the user type is not valid, it calls error.
func userType(rt reflect.Type) *userTypeInfo {
ut, err := validUserType(rt)
if err != nil {
error_(err)
}
return ut
}
// A typeId represents a gob Type as an integer that can be passed on the wire.
// Internally, typeIds are used as keys to a map to recover the underlying type info.
type typeId int32
var typeLock sync.Mutex // set while building a type
const firstUserId = 64 // lowest id number granted to user
type gobType interface {
id() typeId
setId(id typeId)
name() string
string() string // not public; only for debugging
safeString(seen map[typeId]bool) string
}
var (
types = make(map[reflect.Type]gobType, 32)
idToTypeSlice = make([]gobType, 1, firstUserId)
builtinIdToTypeSlice [firstUserId]gobType // set in init() after builtins are established
)
func idToType(id typeId) gobType {
if id < 0 || int(id) >= len(idToTypeSlice) {
return nil
}
return idToTypeSlice[id]
}
func builtinIdToType(id typeId) gobType {
if id < 0 || int(id) >= len(builtinIdToTypeSlice) {
return nil
}
return builtinIdToTypeSlice[id]
}
func setTypeId(typ gobType) {
// When building recursive types, someone may get there before us.
if typ.id() != 0 {
return
}
nextId := typeId(len(idToTypeSlice))
typ.setId(nextId)
idToTypeSlice = append(idToTypeSlice, typ)
}
func (t typeId) gobType() gobType {
if t == 0 {
return nil
}
return idToType(t)
}
// string returns the string representation of the type associated with the typeId.
func (t typeId) string() string {
if t.gobType() == nil {
return "<nil>"
}
return t.gobType().string()
}
// name returns the name of the type associated with the typeId.
func (t typeId) name() string {
if t.gobType() == nil {
return "<nil>"
}
return t.gobType().name()
}
// CommonType holds elements of all types.
// It is a historical artifact, kept for binary compatibility and exported
// only for the benefit of the package's encoding of type descriptors. It is
// not intended for direct use by clients.
type CommonType struct {
Name string
Id typeId
}
func (t *CommonType) id() typeId { return t.Id }
func (t *CommonType) setId(id typeId) { t.Id = id }
func (t *CommonType) string() string { return t.Name }
func (t *CommonType) safeString(seen map[typeId]bool) string {
return t.Name
}
func (t *CommonType) name() string { return t.Name }
// Create and check predefined types
// The string for tBytes is "bytes" not "[]byte" to signify its specialness.
var (
// Primordial types, needed during initialization.
// Always passed as pointers so the interface{} type
// goes through without losing its interfaceness.
tBool = bootstrapType("bool", (*bool)(nil))
tInt = bootstrapType("int", (*int)(nil))
tUint = bootstrapType("uint", (*uint)(nil))
tFloat = bootstrapType("float", (*float64)(nil))
tBytes = bootstrapType("bytes", (*[]byte)(nil))
tString = bootstrapType("string", (*string)(nil))
tComplex = bootstrapType("complex", (*complex128)(nil))
tInterface = bootstrapType("interface", (*any)(nil))
// Reserve some Ids for compatible expansion
tReserved7 = bootstrapType("_reserved1", (*struct{ r7 int })(nil))
tReserved6 = bootstrapType("_reserved1", (*struct{ r6 int })(nil))
tReserved5 = bootstrapType("_reserved1", (*struct{ r5 int })(nil))
tReserved4 = bootstrapType("_reserved1", (*struct{ r4 int })(nil))
tReserved3 = bootstrapType("_reserved1", (*struct{ r3 int })(nil))
tReserved2 = bootstrapType("_reserved1", (*struct{ r2 int })(nil))
tReserved1 = bootstrapType("_reserved1", (*struct{ r1 int })(nil))
)
// Predefined because it's needed by the Decoder
var tWireType = mustGetTypeInfo(wireTypeType).id
var wireTypeUserInfo *userTypeInfo // userTypeInfo of wireType
func init() {
// Some magic numbers to make sure there are no surprises.
checkId(16, tWireType)
checkId(17, mustGetTypeInfo(reflect.TypeFor[arrayType]()).id)
checkId(18, mustGetTypeInfo(reflect.TypeFor[CommonType]()).id)
checkId(19, mustGetTypeInfo(reflect.TypeFor[sliceType]()).id)
checkId(20, mustGetTypeInfo(reflect.TypeFor[structType]()).id)
checkId(21, mustGetTypeInfo(reflect.TypeFor[fieldType]()).id)
checkId(23, mustGetTypeInfo(reflect.TypeFor[mapType]()).id)
copy(builtinIdToTypeSlice[:], idToTypeSlice)
// Move the id space upwards to allow for growth in the predefined world
// without breaking existing files.
if nextId := len(idToTypeSlice); nextId > firstUserId {
panic(fmt.Sprintln("nextId too large:", nextId))
}
idToTypeSlice = idToTypeSlice[:firstUserId]
registerBasics()
wireTypeUserInfo = userType(wireTypeType)
}
// Array type
type arrayType struct {
CommonType
Elem typeId
Len int
}
func newArrayType(name string) *arrayType {
a := &arrayType{CommonType{Name: name}, 0, 0}
return a
}
func (a *arrayType) init(elem gobType, len int) {
// Set our type id before evaluating the element's, in case it's our own.
setTypeId(a)
a.Elem = elem.id()
a.Len = len
}
func (a *arrayType) safeString(seen map[typeId]bool) string {
if seen[a.Id] {
return a.Name
}
seen[a.Id] = true
return fmt.Sprintf("[%d]%s", a.Len, a.Elem.gobType().safeString(seen))
}
func (a *arrayType) string() string { return a.safeString(make(map[typeId]bool)) }
// GobEncoder type (something that implements the GobEncoder interface)
type gobEncoderType struct {
CommonType
}
func newGobEncoderType(name string) *gobEncoderType {
g := &gobEncoderType{CommonType{Name: name}}
setTypeId(g)
return g
}
func (g *gobEncoderType) safeString(seen map[typeId]bool) string {
return g.Name
}
func (g *gobEncoderType) string() string { return g.Name }
// Map type
type mapType struct {
CommonType
Key typeId
Elem typeId
}
func newMapType(name string) *mapType {
m := &mapType{CommonType{Name: name}, 0, 0}
return m
}
func (m *mapType) init(key, elem gobType) {
// Set our type id before evaluating the element's, in case it's our own.
setTypeId(m)
m.Key = key.id()
m.Elem = elem.id()
}
func (m *mapType) safeString(seen map[typeId]bool) string {
if seen[m.Id] {
return m.Name
}
seen[m.Id] = true
key := m.Key.gobType().safeString(seen)
elem := m.Elem.gobType().safeString(seen)
return fmt.Sprintf("map[%s]%s", key, elem)
}
func (m *mapType) string() string { return m.safeString(make(map[typeId]bool)) }
// Slice type
type sliceType struct {
CommonType
Elem typeId
}
func newSliceType(name string) *sliceType {
s := &sliceType{CommonType{Name: name}, 0}
return s
}
func (s *sliceType) init(elem gobType) {
// Set our type id before evaluating the element's, in case it's our own.
setTypeId(s)
// See the comments about ids in newTypeObject. Only slices and
// structs have mutual recursion.
if elem.id() == 0 {
setTypeId(elem)
}
s.Elem = elem.id()
}
func (s *sliceType) safeString(seen map[typeId]bool) string {
if seen[s.Id] {
return s.Name
}
seen[s.Id] = true
return fmt.Sprintf("[]%s", s.Elem.gobType().safeString(seen))
}
func (s *sliceType) string() string { return s.safeString(make(map[typeId]bool)) }
// Struct type
type fieldType struct {
Name string
Id typeId
}
type structType struct {
CommonType
Field []fieldType
}
func (s *structType) safeString(seen map[typeId]bool) string {
if s == nil {
return "<nil>"
}
if _, ok := seen[s.Id]; ok {
return s.Name
}
seen[s.Id] = true
str := s.Name + " = struct { "
for _, f := range s.Field {
str += fmt.Sprintf("%s %s; ", f.Name, f.Id.gobType().safeString(seen))
}
str += "}"
return str
}
func (s *structType) string() string { return s.safeString(make(map[typeId]bool)) }
func newStructType(name string) *structType {
s := &structType{CommonType{Name: name}, nil}
// For historical reasons we set the id here rather than init.
// See the comment in newTypeObject for details.
setTypeId(s)
return s
}
// newTypeObject allocates a gobType for the reflection type rt.
// Unless ut represents a GobEncoder, rt should be the base type
// of ut.
// This is only called from the encoding side. The decoding side
// works through typeIds and userTypeInfos alone.
func newTypeObject(name string, ut *userTypeInfo, rt reflect.Type) (gobType, error) {
// Does this type implement GobEncoder?
if ut.externalEnc != 0 {
return newGobEncoderType(name), nil
}
var err error
var type0, type1 gobType
defer func() {
if err != nil {
delete(types, rt)
}
}()
// Install the top-level type before the subtypes (e.g. struct before
// fields) so recursive types can be constructed safely.
switch t := rt; t.Kind() {
// All basic types are easy: they are predefined.
case reflect.Bool:
return tBool.gobType(), nil
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return tInt.gobType(), nil
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return tUint.gobType(), nil
case reflect.Float32, reflect.Float64:
return tFloat.gobType(), nil
case reflect.Complex64, reflect.Complex128:
return tComplex.gobType(), nil
case reflect.String:
return tString.gobType(), nil
case reflect.Interface:
return tInterface.gobType(), nil
case reflect.Array:
at := newArrayType(name)
types[rt] = at
type0, err = getBaseType("", t.Elem())
if err != nil {
return nil, err
}
// Historical aside:
// For arrays, maps, and slices, we set the type id after the elements
// are constructed. This is to retain the order of type id allocation after
// a fix made to handle recursive types, which changed the order in
// which types are built. Delaying the setting in this way preserves
// type ids while allowing recursive types to be described. Structs,
// done below, were already handling recursion correctly so they
// assign the top-level id before those of the fields.
at.init(type0, t.Len())
return at, nil
case reflect.Map:
mt := newMapType(name)
types[rt] = mt
type0, err = getBaseType("", t.Key())
if err != nil {
return nil, err
}
type1, err = getBaseType("", t.Elem())
if err != nil {
return nil, err
}
mt.init(type0, type1)
return mt, nil
case reflect.Slice:
// []byte == []uint8 is a special case
if t.Elem().Kind() == reflect.Uint8 {
return tBytes.gobType(), nil
}
st := newSliceType(name)
types[rt] = st
type0, err = getBaseType(t.Elem().Name(), t.Elem())
if err != nil {
return nil, err
}
st.init(type0)
return st, nil
case reflect.Struct:
st := newStructType(name)
types[rt] = st
idToTypeSlice[st.id()] = st
for i := 0; i < t.NumField(); i++ {
f := t.Field(i)
if !isSent(&f) {
continue
}
typ := userType(f.Type).base
tname := typ.Name()
if tname == "" {
t := userType(f.Type).base
tname = t.String()
}
gt, err := getBaseType(tname, f.Type)
if err != nil {
return nil, err
}
// Some mutually recursive types can cause us to be here while
// still defining the element. Fix the element type id here.
// We could do this more neatly by setting the id at the start of
// building every type, but that would break binary compatibility.
if gt.id() == 0 {
setTypeId(gt)
}
st.Field = append(st.Field, fieldType{f.Name, gt.id()})
}
return st, nil
default:
return nil, errors.New("gob NewTypeObject can't handle type: " + rt.String())
}
}
// isExported reports whether this is an exported - upper case - name.
func isExported(name string) bool {
rune, _ := utf8.DecodeRuneInString(name)
return unicode.IsUpper(rune)
}
// isSent reports whether this struct field is to be transmitted.
// It will be transmitted only if it is exported and not a chan or func field
// or pointer to chan or func.
func isSent(field *reflect.StructField) bool {
if !isExported(field.Name) {
return false
}
// If the field is a chan or func or pointer thereto, don't send it.
// That is, treat it like an unexported field.
typ := field.Type
for typ.Kind() == reflect.Pointer {
typ = typ.Elem()
}
if typ.Kind() == reflect.Chan || typ.Kind() == reflect.Func {
return false
}
return true
}
// getBaseType returns the Gob type describing the given reflect.Type's base type.
// typeLock must be held.
func getBaseType(name string, rt reflect.Type) (gobType, error) {
ut := userType(rt)
return getType(name, ut, ut.base)
}
// getType returns the Gob type describing the given reflect.Type.
// Should be called only when handling GobEncoders/Decoders,
// which may be pointers. All other types are handled through the
// base type, never a pointer.
// typeLock must be held.
func getType(name string, ut *userTypeInfo, rt reflect.Type) (gobType, error) {
typ, present := types[rt]
if present {
return typ, nil
}
typ, err := newTypeObject(name, ut, rt)
if err == nil {
types[rt] = typ
}
return typ, err
}
func checkId(want, got typeId) {
if want != got {
fmt.Fprintf(os.Stderr, "checkId: %d should be %d\n", int(got), int(want))
panic("bootstrap type wrong id: " + got.name() + " " + got.string() + " not " + want.string())
}
}
// Used for building the basic types; called only from init(). The incoming
// interface always refers to a pointer.
func bootstrapType(name string, e any) typeId {
rt := reflect.TypeOf(e).Elem()
_, present := types[rt]
if present {
panic("bootstrap type already present: " + name + ", " + rt.String())
}
typ := &CommonType{Name: name}
types[rt] = typ
setTypeId(typ)
return typ.id()
}
// Representation of the information we send and receive about this type.
// Each value we send is preceded by its type definition: an encoded int.
// However, the very first time we send the value, we first send the pair
// (-id, wireType).
// For bootstrapping purposes, we assume that the recipient knows how
// to decode a wireType; it is exactly the wireType struct here, interpreted
// using the gob rules for sending a structure, except that we assume the
// ids for wireType and structType etc. are known. The relevant pieces
// are built in encode.go's init() function.
// To maintain binary compatibility, if you extend this type, always put
// the new fields last.
type wireType struct {
ArrayT *arrayType
SliceT *sliceType
StructT *structType
MapT *mapType
GobEncoderT *gobEncoderType
BinaryMarshalerT *gobEncoderType
TextMarshalerT *gobEncoderType
}
func (w *wireType) string() string {
const unknown = "unknown type"
if w == nil {
return unknown
}
switch {
case w.ArrayT != nil:
return w.ArrayT.Name
case w.SliceT != nil:
return w.SliceT.Name
case w.StructT != nil:
return w.StructT.Name
case w.MapT != nil:
return w.MapT.Name
case w.GobEncoderT != nil:
return w.GobEncoderT.Name
case w.BinaryMarshalerT != nil:
return w.BinaryMarshalerT.Name
case w.TextMarshalerT != nil:
return w.TextMarshalerT.Name
}
return unknown
}
type typeInfo struct {
id typeId
encInit sync.Mutex // protects creation of encoder
encoder atomic.Pointer[encEngine]
wire wireType
}
// typeInfoMap is an atomic pointer to map[reflect.Type]*typeInfo.
// It's updated copy-on-write. Readers just do an atomic load
// to get the current version of the map. Writers make a full copy of
// the map and atomically update the pointer to point to the new map.
// Under heavy read contention, this is significantly faster than a map
// protected by a mutex.
var typeInfoMap atomic.Value
// typeInfoMapInit is used instead of typeInfoMap during init time,
// as types are registered sequentially during init and we can save
// the overhead of making map copies.
// It is saved to typeInfoMap and set to nil before init finishes.
var typeInfoMapInit = make(map[reflect.Type]*typeInfo, 16)
func lookupTypeInfo(rt reflect.Type) *typeInfo {
if m := typeInfoMapInit; m != nil {
return m[rt]
}
m, _ := typeInfoMap.Load().(map[reflect.Type]*typeInfo)
return m[rt]
}
func getTypeInfo(ut *userTypeInfo) (*typeInfo, error) {
rt := ut.base
if ut.externalEnc != 0 {
// We want the user type, not the base type.
rt = ut.user
}
if info := lookupTypeInfo(rt); info != nil {
return info, nil
}
return buildTypeInfo(ut, rt)
}
// buildTypeInfo constructs the type information for the type
// and stores it in the type info map.
func buildTypeInfo(ut *userTypeInfo, rt reflect.Type) (*typeInfo, error) {
typeLock.Lock()
defer typeLock.Unlock()
if info := lookupTypeInfo(rt); info != nil {
return info, nil
}
gt, err := getBaseType(rt.Name(), rt)
if err != nil {
return nil, err
}
info := &typeInfo{id: gt.id()}
if ut.externalEnc != 0 {
userType, err := getType(rt.Name(), ut, rt)
if err != nil {
return nil, err
}
gt := userType.id().gobType().(*gobEncoderType)
switch ut.externalEnc {
case xGob:
info.wire.GobEncoderT = gt
case xBinary:
info.wire.BinaryMarshalerT = gt
case xText:
info.wire.TextMarshalerT = gt
}
rt = ut.user
} else {
t := info.id.gobType()
switch typ := rt; typ.Kind() {
case reflect.Array:
info.wire.ArrayT = t.(*arrayType)
case reflect.Map:
info.wire.MapT = t.(*mapType)
case reflect.Slice:
// []byte == []uint8 is a special case handled separately
if typ.Elem().Kind() != reflect.Uint8 {
info.wire.SliceT = t.(*sliceType)
}
case reflect.Struct:
info.wire.StructT = t.(*structType)
}
}
if m := typeInfoMapInit; m != nil {
m[rt] = info
return info, nil
}
// Create new map with old contents plus new entry.
m, _ := typeInfoMap.Load().(map[reflect.Type]*typeInfo)
newm := maps.Clone(m)
newm[rt] = info
typeInfoMap.Store(newm)
return info, nil
}
// Called only when a panic is acceptable and unexpected.
func mustGetTypeInfo(rt reflect.Type) *typeInfo {
t, err := getTypeInfo(userType(rt))
if err != nil {
panic("getTypeInfo: " + err.Error())
}
return t
}
// GobEncoder is the interface describing data that provides its own
// representation for encoding values for transmission to a GobDecoder.
// A type that implements GobEncoder and GobDecoder has complete
// control over the representation of its data and may therefore
// contain things such as private fields, channels, and functions,
// which are not usually transmissible in gob streams.
//
// Note: Since gobs can be stored permanently, it is good design
// to guarantee the encoding used by a GobEncoder is stable as the
// software evolves. For instance, it might make sense for GobEncode
// to include a version number in the encoding.
type GobEncoder interface {
// GobEncode returns a byte slice representing the encoding of the
// receiver for transmission to a GobDecoder, usually of the same
// concrete type.
GobEncode() ([]byte, error)
}
// GobDecoder is the interface describing data that provides its own
// routine for decoding transmitted values sent by a GobEncoder.
type GobDecoder interface {
// GobDecode overwrites the receiver, which must be a pointer,
// with the value represented by the byte slice, which was written
// by GobEncode, usually for the same concrete type.
GobDecode([]byte) error
}
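// A minimal sketch of a type implementing both interfaces (hypothetical
// example, not a type defined by this package):
//
//	type Point struct{ X, Y int }
//
//	func (p Point) GobEncode() ([]byte, error) {
//		return []byte(fmt.Sprintf("%d,%d", p.X, p.Y)), nil
//	}
//
//	func (p *Point) GobDecode(b []byte) error {
//		_, err := fmt.Sscanf(string(b), "%d,%d", &p.X, &p.Y)
//		return err
//	}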
var (
nameToConcreteType sync.Map // map[string]reflect.Type
concreteTypeToName sync.Map // map[reflect.Type]string
)
// RegisterName is like [Register] but uses the provided name rather than the
// type's default.
func RegisterName(name string, value any) {
if name == "" {
// reserved for nil
panic("attempt to register empty name")
}
ut := userType(reflect.TypeOf(value))
// Check for incompatible duplicates. The name must refer to the
// same user type, and vice versa.
// Store the name and type provided by the user....
if t, dup := nameToConcreteType.LoadOrStore(name, reflect.TypeOf(value)); dup && t != ut.user {
panic(fmt.Sprintf("gob: registering duplicate types for %q: %s != %s", name, t, ut.user))
}
// but the flattened type in the type table, since that's what decode needs.
if n, dup := concreteTypeToName.LoadOrStore(ut.base, name); dup && n != name {
nameToConcreteType.Delete(name)
panic(fmt.Sprintf("gob: registering duplicate names for %s: %q != %q", ut.user, n, name))
}
}
// Register records a type, identified by a value for that type, under its
// internal type name. That name will identify the concrete type of a value
// sent or received as an interface variable. Only types that will be
// transferred as implementations of interface values need to be registered.
// Expecting to be used only during initialization, it panics if the mapping
// between types and names is not a bijection.
func Register(value any) {
// Default to printed representation for unnamed types
rt := reflect.TypeOf(value)
name := rt.String()
// But for named types (or pointers to them), qualify with import path (but see inner comment).
// Dereference one pointer looking for a named type.
star := ""
if rt.Name() == "" {
if pt := rt; pt.Kind() == reflect.Pointer {
star = "*"
// NOTE: The following line should be rt = pt.Elem() to implement
// what the comment above claims, but fixing it would break compatibility
// with existing gobs.
//
// Given package p imported as "full/p" with these definitions:
// package p
// type T1 struct { ... }
// this table shows the intended and actual strings used by gob to
// name the types:
//
// Type Correct string Actual string
//
// T1 full/p.T1 full/p.T1
// *T1 *full/p.T1 *p.T1
//
// The missing full path cannot be fixed without breaking existing gob decoders.
rt = pt
}
}
if rt.Name() != "" {
if rt.PkgPath() == "" {
name = star + rt.Name()
} else {
name = star + rt.PkgPath() + "." + rt.Name()
}
}
RegisterName(name, value)
}
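// A minimal usage sketch from a client package (the Shape and Circle names
// are assumptions for illustration): concrete types that will travel inside
// interface-typed values must be registered on both ends, typically in init:
//
//	type Shape interface{ Area() float64 }
//	type Circle struct{ R float64 }
//
//	func (c Circle) Area() float64 { return math.Pi * c.R * c.R }
//
//	func init() {
//		gob.Register(Circle{}) // or gob.RegisterName("myapp.Circle", Circle{})
//	}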
func registerBasics() {
Register(int(0))
Register(int8(0))
Register(int16(0))
Register(int32(0))
Register(int64(0))
Register(uint(0))
Register(uint8(0))
Register(uint16(0))
Register(uint32(0))
Register(uint64(0))
Register(float32(0))
Register(float64(0))
Register(complex64(0i))
Register(complex128(0i))
Register(uintptr(0))
Register(false)
Register("")
Register([]byte(nil))
Register([]int(nil))
Register([]int8(nil))
Register([]int16(nil))
Register([]int32(nil))
Register([]int64(nil))
Register([]uint(nil))
Register([]uint8(nil))
Register([]uint16(nil))
Register([]uint32(nil))
Register([]uint64(nil))
Register([]float32(nil))
Register([]float64(nil))
Register([]complex64(nil))
Register([]complex128(nil))
Register([]uintptr(nil))
Register([]bool(nil))
Register([]string(nil))
}
func init() {
typeInfoMap.Store(typeInfoMapInit)
typeInfoMapInit = nil
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package hex implements hexadecimal encoding and decoding.
package hex
import (
"errors"
"fmt"
"io"
"slices"
"strings"
)
const (
hextable = "0123456789abcdef"
reverseHexTable = "" +
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\xff\xff\xff\xff\xff\xff" +
"\xff\x0a\x0b\x0c\x0d\x0e\x0f\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
"\xff\x0a\x0b\x0c\x0d\x0e\x0f\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff"
)
// EncodedLen returns the length of an encoding of n source bytes.
// Specifically, it returns n * 2.
func EncodedLen(n int) int { return n * 2 }
// Encode encodes src into [EncodedLen](len(src))
// bytes of dst. As a convenience, it returns the number
// of bytes written to dst, but this value is always [EncodedLen](len(src)).
// Encode implements hexadecimal encoding.
func Encode(dst, src []byte) int {
j := 0
for _, v := range src {
dst[j] = hextable[v>>4]
dst[j+1] = hextable[v&0x0f]
j += 2
}
return len(src) * 2
}
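// A minimal sketch of Encode with a caller-managed buffer (as seen from a
// client package):
//
//	src := []byte{0xde, 0xad, 0xbe, 0xef}
//	dst := make([]byte, hex.EncodedLen(len(src)))
//	hex.Encode(dst, src)
//	// dst now holds []byte("deadbeef")
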
// AppendEncode appends the hexadecimally encoded src to dst
// and returns the extended buffer.
func AppendEncode(dst, src []byte) []byte {
n := EncodedLen(len(src))
dst = slices.Grow(dst, n)
Encode(dst[len(dst):][:n], src)
return dst[:len(dst)+n]
}
// ErrLength reports an attempt to decode an odd-length input
// using [Decode] or [DecodeString].
// The stream-based Decoder returns [io.ErrUnexpectedEOF] instead of ErrLength.
var ErrLength = errors.New("encoding/hex: odd length hex string")
// InvalidByteError values describe errors resulting from an invalid byte in a hex string.
type InvalidByteError byte
func (e InvalidByteError) Error() string {
return fmt.Sprintf("encoding/hex: invalid byte: %#U", rune(e))
}
// DecodedLen returns the length of a decoding of x source bytes.
// Specifically, it returns x / 2.
func DecodedLen(x int) int { return x / 2 }
// Decode decodes src into [DecodedLen](len(src)) bytes,
// returning the actual number of bytes written to dst.
//
// Decode expects that src contains only hexadecimal
// characters and that src has even length.
// If the input is malformed, Decode returns the number
// of bytes decoded before the error.
func Decode(dst, src []byte) (int, error) {
i, j := 0, 1
for ; j < len(src); j += 2 {
p := src[j-1]
q := src[j]
a := reverseHexTable[p]
b := reverseHexTable[q]
if a > 0x0f {
return i, InvalidByteError(p)
}
if b > 0x0f {
return i, InvalidByteError(q)
}
dst[i] = (a << 4) | b
i++
}
if len(src)%2 == 1 {
// Check for invalid char before reporting bad length,
// since the invalid char (if present) is an earlier problem.
if reverseHexTable[src[j-1]] > 0x0f {
return i, InvalidByteError(src[j-1])
}
return i, ErrLength
}
return i, nil
}
// AppendDecode appends the hexadecimally decoded src to dst
// and returns the extended buffer.
// If the input is malformed, it returns the partially decoded src and an error.
func AppendDecode(dst, src []byte) ([]byte, error) {
n := DecodedLen(len(src))
dst = slices.Grow(dst, n)
n, err := Decode(dst[len(dst):][:n], src)
return dst[:len(dst)+n], err
}
// EncodeToString returns the hexadecimal encoding of src.
func EncodeToString(src []byte) string {
dst := make([]byte, EncodedLen(len(src)))
Encode(dst, src)
return string(dst)
}
// DecodeString returns the bytes represented by the hexadecimal string s.
//
// DecodeString expects that s contains only hexadecimal
// characters and that s has even length.
// If the input is malformed, DecodeString returns
// the bytes decoded before the error.
func DecodeString(s string) ([]byte, error) {
dst := make([]byte, DecodedLen(len(s)))
n, err := Decode(dst, []byte(s))
return dst[:n], err
}
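// A minimal round-trip sketch using the string helpers (as seen from a
// client package):
//
//	s := hex.EncodeToString([]byte("Go")) // "476f"
//	b, err := hex.DecodeString(s)         // []byte("Go"), nil
//	_, _ = b, err
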
// Dump returns a string that contains a hex dump of the given data. The format
// of the hex dump matches the output of `hexdump -C` on the command line.
func Dump(data []byte) string {
if len(data) == 0 {
return ""
}
var buf strings.Builder
// Dumper will write 79 bytes per complete 16 byte chunk, and at least
// 64 bytes for whatever remains. Round the allocation up, since only a
// maximum of 15 bytes will be wasted.
buf.Grow((1 + ((len(data) - 1) / 16)) * 79)
dumper := Dumper(&buf)
dumper.Write(data)
dumper.Close()
return buf.String()
}
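// A minimal sketch of Dump output for a full 16-byte line (as seen from a
// client package):
//
//	fmt.Print(hex.Dump([]byte("Go is an open so")))
//	// Output:
//	// 00000000  47 6f 20 69 73 20 61 6e  20 6f 70 65 6e 20 73 6f  |Go is an open so|
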
// bufferSize is the number of hexadecimal characters to buffer in encoder and decoder.
const bufferSize = 1024
type encoder struct {
w io.Writer
err error
out [bufferSize]byte // output buffer
}
// NewEncoder returns an [io.Writer] that writes lowercase hexadecimal characters to w.
func NewEncoder(w io.Writer) io.Writer {
return &encoder{w: w}
}
func (e *encoder) Write(p []byte) (n int, err error) {
for len(p) > 0 && e.err == nil {
chunkSize := bufferSize / 2
if len(p) < chunkSize {
chunkSize = len(p)
}
var written int
encoded := Encode(e.out[:], p[:chunkSize])
written, e.err = e.w.Write(e.out[:encoded])
n += written / 2
p = p[chunkSize:]
}
return n, e.err
}
type decoder struct {
r io.Reader
err error
in []byte // input buffer (encoded form)
arr [bufferSize]byte // backing array for in
}
// NewDecoder returns an [io.Reader] that decodes hexadecimal characters from r.
// NewDecoder expects that r contain only an even number of hexadecimal characters.
func NewDecoder(r io.Reader) io.Reader {
return &decoder{r: r}
}
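// A minimal streaming sketch (as seen from a client package):
//
//	decoded, err := io.ReadAll(hex.NewDecoder(strings.NewReader("48656c6c6f")))
//	// decoded == []byte("Hello"), err == nil
//	_, _ = decoded, err
//
//	var buf bytes.Buffer
//	hex.NewEncoder(&buf).Write([]byte("Hi"))
//	// buf.String() == "4869"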
func (d *decoder) Read(p []byte) (n int, err error) {
// Fill internal buffer with sufficient bytes to decode
if len(d.in) < 2 && d.err == nil {
var numCopy, numRead int
numCopy = copy(d.arr[:], d.in) // Copies either 0 or 1 bytes
numRead, d.err = d.r.Read(d.arr[numCopy:])
d.in = d.arr[:numCopy+numRead]
if d.err == io.EOF && len(d.in)%2 != 0 {
if a := reverseHexTable[d.in[len(d.in)-1]]; a > 0x0f {
d.err = InvalidByteError(d.in[len(d.in)-1])
} else {
d.err = io.ErrUnexpectedEOF
}
}
}
// Decode internal buffer into output buffer
if numAvail := len(d.in) / 2; len(p) > numAvail {
p = p[:numAvail]
}
numDec, err := Decode(p, d.in[:len(p)*2])
d.in = d.in[2*numDec:]
if err != nil {
d.in, d.err = nil, err // Decode error; discard input remainder
}
if len(d.in) < 2 {
return numDec, d.err // Only expose errors when buffer fully consumed
}
return numDec, nil
}
// Dumper returns an [io.WriteCloser] that writes a hex dump of all written data to
// w. The format of the dump matches the output of `hexdump -C` on the command
// line.
func Dumper(w io.Writer) io.WriteCloser {
return &dumper{w: w}
}
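// A minimal streaming sketch (as seen from a client package): a Dumper may be
// written to incrementally and must be closed to flush a partial final line:
//
//	d := hex.Dumper(os.Stdout)
//	d.Write([]byte("hello, "))
//	d.Write([]byte("world"))
//	d.Close()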
type dumper struct {
w io.Writer
rightChars [18]byte
buf [14]byte
used int // number of bytes in the current line
n uint // number of bytes, total
closed bool
}
func toChar(b byte) byte {
if b < 32 || b > 126 {
return '.'
}
return b
}
func (h *dumper) Write(data []byte) (n int, err error) {
if h.closed {
return 0, errors.New("encoding/hex: dumper closed")
}
// Output lines look like:
// 00000010 2e 2f 30 31 32 33 34 35 36 37 38 39 3a 3b 3c 3d |./0123456789:;<=|
// ^ offset ^ extra space ^ ASCII of line.
for i := range data {
if h.used == 0 {
// At the beginning of a line we print the current
// offset in hex.
h.buf[0] = byte(h.n >> 24)
h.buf[1] = byte(h.n >> 16)
h.buf[2] = byte(h.n >> 8)
h.buf[3] = byte(h.n)
Encode(h.buf[4:], h.buf[:4])
h.buf[12] = ' '
h.buf[13] = ' '
_, err = h.w.Write(h.buf[4:])
if err != nil {
return
}
}
Encode(h.buf[:], data[i:i+1])
h.buf[2] = ' '
l := 3
if h.used == 7 {
// There's an additional space after the 8th byte.
h.buf[3] = ' '
l = 4
} else if h.used == 15 {
// At the end of the line there's an extra space and
// the bar for the right column.
h.buf[3] = ' '
h.buf[4] = '|'
l = 5
}
_, err = h.w.Write(h.buf[:l])
if err != nil {
return
}
n++
h.rightChars[h.used] = toChar(data[i])
h.used++
h.n++
if h.used == 16 {
h.rightChars[16] = '|'
h.rightChars[17] = '\n'
_, err = h.w.Write(h.rightChars[:])
if err != nil {
return
}
h.used = 0
}
}
return
}
func (h *dumper) Close() (err error) {
// See the comments in Write() for the details of this format.
if h.closed {
return
}
h.closed = true
if h.used == 0 {
return
}
h.buf[0] = ' '
h.buf[1] = ' '
h.buf[2] = ' '
h.buf[3] = ' '
h.buf[4] = '|'
nBytes := h.used
for h.used < 16 {
l := 3
if h.used == 7 {
l = 4
} else if h.used == 15 {
l = 5
}
_, err = h.w.Write(h.buf[:l])
if err != nil {
return
}
h.used++
}
h.rightChars[nBytes] = '|'
h.rightChars[nBytes+1] = '\n'
_, err = h.w.Write(h.rightChars[:nBytes+2])
return
}
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Represents JSON data structure using native Go types: booleans, floats,
// strings, arrays, and maps.
//go:build !goexperiment.jsonv2
package json
import (
"encoding"
"encoding/base64"
"fmt"
"reflect"
"strconv"
"strings"
"unicode"
"unicode/utf16"
"unicode/utf8"
_ "unsafe" // for linkname
)
// Unmarshal parses the JSON-encoded data and stores the result
// in the value pointed to by v. If v is nil or not a pointer,
// Unmarshal returns an [InvalidUnmarshalError].
//
// Unmarshal uses the inverse of the encodings that
// [Marshal] uses, allocating maps, slices, and pointers as necessary,
// with the following additional rules:
//
// To unmarshal JSON into a pointer, Unmarshal first handles the case of
// the JSON being the JSON literal null. In that case, Unmarshal sets
// the pointer to nil. Otherwise, Unmarshal unmarshals the JSON into
// the value pointed at by the pointer. If the pointer is nil, Unmarshal
// allocates a new value for it to point to.
//
// To unmarshal JSON into a value implementing [Unmarshaler],
// Unmarshal calls that value's [Unmarshaler.UnmarshalJSON] method, including
// when the input is a JSON null.
// Otherwise, if the value implements [encoding.TextUnmarshaler]
// and the input is a JSON quoted string, Unmarshal calls
// [encoding.TextUnmarshaler.UnmarshalText] with the unquoted form of the string.
//
// To unmarshal JSON into a struct, Unmarshal matches incoming object keys to
// the keys used by [Marshal] (either the struct field name or its tag),
// ignoring case. If multiple struct fields match an object key, an exact case
// match is preferred over a case-insensitive one.
//
// Incoming object members are processed in the order observed. If an object
// includes duplicate keys, later duplicates will replace or be merged into
// prior values.
//
// To unmarshal JSON into an interface value,
// Unmarshal stores one of these in the interface value:
//
// - bool, for JSON booleans
// - float64, for JSON numbers
// - string, for JSON strings
// - []any, for JSON arrays
// - map[string]any, for JSON objects
// - nil for JSON null
//
// To unmarshal a JSON array into a slice, Unmarshal resets the slice length
// to zero and then appends each element to the slice.
// As a special case, to unmarshal an empty JSON array into a slice,
// Unmarshal replaces the slice with a new empty slice.
//
// To unmarshal a JSON array into a Go array, Unmarshal decodes
// JSON array elements into corresponding Go array elements.
// If the Go array is smaller than the JSON array,
// the additional JSON array elements are discarded.
// If the JSON array is smaller than the Go array,
// the additional Go array elements are set to zero values.
//
// To unmarshal a JSON object into a map, Unmarshal first establishes a map to
// use. If the map is nil, Unmarshal allocates a new map. Otherwise Unmarshal
// reuses the existing map, keeping existing entries. Unmarshal then stores
// key-value pairs from the JSON object into the map. The map's key type must
// either be any string type, an integer, or implement [encoding.TextUnmarshaler].
//
// If the JSON-encoded data contain a syntax error, Unmarshal returns a [SyntaxError].
//
// If a JSON value is not appropriate for a given target type,
// or if a JSON number overflows the target type, Unmarshal
// skips that field and completes the unmarshaling as best it can.
// If no more serious errors are encountered, Unmarshal returns
// an [UnmarshalTypeError] describing the earliest such error. In any
// case, it's not guaranteed that all the remaining fields following
// the problematic one will be unmarshaled into the target object.
//
// The JSON null value unmarshals into an interface, map, pointer, or slice
// by setting that Go value to nil. Because null is often used in JSON to mean
// “not present,” unmarshaling a JSON null into any other Go type has no effect
// on the value and produces no error.
//
// When unmarshaling quoted strings, invalid UTF-8 or
// invalid UTF-16 surrogate pairs are not treated as an error.
// Instead, they are replaced by the Unicode replacement
// character U+FFFD.
func Unmarshal(data []byte, v any) error {
// Check for well-formedness.
// Avoids filling out half a data structure
// before discovering a JSON syntax error.
var d decodeState
err := checkValid(data, &d.scan)
if err != nil {
return err
}
d.init(data)
return d.unmarshal(v)
}
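// A minimal usage sketch (as seen from a client package; the Animal type is
// an assumption for illustration):
//
//	type Animal struct {
//		Name  string `json:"name"`
//		Order string `json:"order"`
//	}
//
//	var a Animal
//	err := json.Unmarshal([]byte(`{"name":"Platypus","order":"Monotremata"}`), &a)
//	// a == Animal{Name: "Platypus", Order: "Monotremata"}, err == nil
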
// Unmarshaler is the interface implemented by types
// that can unmarshal a JSON description of themselves.
// The input can be assumed to be a valid encoding of
// a JSON value. UnmarshalJSON must copy the JSON data
// if it wishes to retain the data after returning.
type Unmarshaler interface {
UnmarshalJSON([]byte) error
}
// An UnmarshalTypeError describes a JSON value that was
// not appropriate for a value of a specific Go type.
type UnmarshalTypeError struct {
Value string // description of JSON value - "bool", "array", "number -5"
Type reflect.Type // type of Go value it could not be assigned to
Offset int64 // error occurred after reading Offset bytes
Struct string // name of the struct type containing the field
Field string // the full path from root node to the field, include embedded struct
}
func (e *UnmarshalTypeError) Error() string {
if e.Struct != "" || e.Field != "" {
return "json: cannot unmarshal " + e.Value + " into Go struct field " + e.Struct + "." + e.Field + " of type " + e.Type.String()
}
return "json: cannot unmarshal " + e.Value + " into Go value of type " + e.Type.String()
}
// An UnmarshalFieldError describes a JSON object key that
// led to an unexported (and therefore unwritable) struct field.
//
// Deprecated: No longer used; kept for compatibility.
type UnmarshalFieldError struct {
Key string
Type reflect.Type
Field reflect.StructField
}
func (e *UnmarshalFieldError) Error() string {
return "json: cannot unmarshal object key " + strconv.Quote(e.Key) + " into unexported field " + e.Field.Name + " of type " + e.Type.String()
}
// An InvalidUnmarshalError describes an invalid argument passed to [Unmarshal].
// (The argument to [Unmarshal] must be a non-nil pointer.)
type InvalidUnmarshalError struct {
Type reflect.Type
}
func (e *InvalidUnmarshalError) Error() string {
if e.Type == nil {
return "json: Unmarshal(nil)"
}
if e.Type.Kind() != reflect.Pointer {
return "json: Unmarshal(non-pointer " + e.Type.String() + ")"
}
return "json: Unmarshal(nil " + e.Type.String() + ")"
}
func (d *decodeState) unmarshal(v any) error {
rv := reflect.ValueOf(v)
if rv.Kind() != reflect.Pointer || rv.IsNil() {
return &InvalidUnmarshalError{reflect.TypeOf(v)}
}
d.scan.reset()
d.scanWhile(scanSkipSpace)
// We decode rv not rv.Elem because the Unmarshaler interface
// test must be applied at the top level of the value.
err := d.value(rv)
if err != nil {
return d.addErrorContext(err)
}
return d.savedError
}
// A Number represents a JSON number literal.
type Number string
// String returns the literal text of the number.
func (n Number) String() string { return string(n) }
// Float64 returns the number as a float64.
func (n Number) Float64() (float64, error) {
return strconv.ParseFloat(string(n), 64)
}
// Int64 returns the number as an int64.
func (n Number) Int64() (int64, error) {
return strconv.ParseInt(string(n), 10, 64)
}
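// A minimal sketch of preserving integer precision with Number via
// [Decoder.UseNumber] (as seen from a client package):
//
//	dec := json.NewDecoder(strings.NewReader(`{"id": 9007199254740993}`))
//	dec.UseNumber()
//	var m map[string]any
//	_ = dec.Decode(&m)
//	id, err := m["id"].(json.Number).Int64()
//	// id == 9007199254740993, err == nil; decoding to float64 would round the value
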
// An errorContext provides context for type errors during decoding.
type errorContext struct {
Struct reflect.Type
FieldStack []string
}
// decodeState represents the state while decoding a JSON value.
type decodeState struct {
data []byte
off int // next read offset in data
opcode int // last read result
scan scanner
errorContext *errorContext
savedError error
useNumber bool
disallowUnknownFields bool
}
// readIndex returns the position of the last byte read.
func (d *decodeState) readIndex() int {
return d.off - 1
}
// phasePanicMsg is used as a panic message when we end up with something that
// shouldn't happen. It can indicate a bug in the JSON decoder, or that
// something is editing the data slice while the decoder executes.
const phasePanicMsg = "JSON decoder out of sync - data changing underfoot?"
func (d *decodeState) init(data []byte) *decodeState {
d.data = data
d.off = 0
d.savedError = nil
if d.errorContext != nil {
d.errorContext.Struct = nil
// Reuse the allocated space for the FieldStack slice.
d.errorContext.FieldStack = d.errorContext.FieldStack[:0]
}
return d
}
// saveError saves the first err it is called with,
// for reporting at the end of the unmarshal.
func (d *decodeState) saveError(err error) {
if d.savedError == nil {
d.savedError = d.addErrorContext(err)
}
}
// addErrorContext returns a new error enhanced with information from d.errorContext
func (d *decodeState) addErrorContext(err error) error {
if d.errorContext != nil && (d.errorContext.Struct != nil || len(d.errorContext.FieldStack) > 0) {
switch err := err.(type) {
case *UnmarshalTypeError:
err.Struct = d.errorContext.Struct.Name()
fieldStack := d.errorContext.FieldStack
if err.Field != "" {
fieldStack = append(fieldStack, err.Field)
}
err.Field = strings.Join(fieldStack, ".")
}
}
return err
}
// skip scans to the end of what was started.
func (d *decodeState) skip() {
s, data, i := &d.scan, d.data, d.off
depth := len(s.parseState)
for {
op := s.step(s, data[i])
i++
if len(s.parseState) < depth {
d.off = i
d.opcode = op
return
}
}
}
// scanNext processes the byte at d.data[d.off].
func (d *decodeState) scanNext() {
if d.off < len(d.data) {
d.opcode = d.scan.step(&d.scan, d.data[d.off])
d.off++
} else {
d.opcode = d.scan.eof()
d.off = len(d.data) + 1 // mark processed EOF with len+1
}
}
// scanWhile processes bytes in d.data[d.off:] until it
// receives a scan code not equal to op.
func (d *decodeState) scanWhile(op int) {
s, data, i := &d.scan, d.data, d.off
for i < len(data) {
newOp := s.step(s, data[i])
i++
if newOp != op {
d.opcode = newOp
d.off = i
return
}
}
d.off = len(data) + 1 // mark processed EOF with len+1
d.opcode = d.scan.eof()
}
// rescanLiteral is similar to scanWhile(scanContinue), but it specialises the
// common case where we're decoding a literal. The decoder scans the input
// twice, once for syntax errors and to check the length of the value, and the
// second to perform the decoding.
//
// Only in the second step do we use decodeState to tokenize literals, so we
// know there aren't any syntax errors. We can take advantage of that knowledge,
// and scan a literal's bytes much more quickly.
func (d *decodeState) rescanLiteral() {
data, i := d.data, d.off
Switch:
switch data[i-1] {
case '"': // string
for ; i < len(data); i++ {
switch data[i] {
case '\\':
i++ // escaped char
case '"':
i++ // tokenize the closing quote too
break Switch
}
}
case '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '-': // number
for ; i < len(data); i++ {
switch data[i] {
case '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
'.', 'e', 'E', '+', '-':
default:
break Switch
}
}
case 't': // true
i += len("rue")
case 'f': // false
i += len("alse")
case 'n': // null
i += len("ull")
}
if i < len(data) {
d.opcode = stateEndValue(&d.scan, data[i])
} else {
d.opcode = scanEnd
}
d.off = i + 1
}
// value consumes a JSON value from d.data[d.off-1:], decoding into v, and
// reads the following byte ahead. If v is invalid, the value is discarded.
// The first byte of the value has been read already.
func (d *decodeState) value(v reflect.Value) error {
switch d.opcode {
default:
panic(phasePanicMsg)
case scanBeginArray:
if v.IsValid() {
if err := d.array(v); err != nil {
return err
}
} else {
d.skip()
}
d.scanNext()
case scanBeginObject:
if v.IsValid() {
if err := d.object(v); err != nil {
return err
}
} else {
d.skip()
}
d.scanNext()
case scanBeginLiteral:
// All bytes inside literal return scanContinue op code.
start := d.readIndex()
d.rescanLiteral()
if v.IsValid() {
if err := d.literalStore(d.data[start:d.readIndex()], v, false); err != nil {
return err
}
}
}
return nil
}
type unquotedValue struct{}
// valueQuoted is like value but decodes a
// quoted string literal or literal null into an interface value.
// If it finds anything other than a quoted string literal or null,
// valueQuoted returns unquotedValue{}.
func (d *decodeState) valueQuoted() any {
switch d.opcode {
default:
panic(phasePanicMsg)
case scanBeginArray, scanBeginObject:
d.skip()
d.scanNext()
case scanBeginLiteral:
v := d.literalInterface()
switch v.(type) {
case nil, string:
return v
}
}
return unquotedValue{}
}
// indirect walks down v allocating pointers as needed,
// until it gets to a non-pointer.
// If it encounters an Unmarshaler, indirect stops and returns that.
// If decodingNull is true, indirect stops at the first settable pointer so it
// can be set to nil.
func indirect(v reflect.Value, decodingNull bool) (Unmarshaler, encoding.TextUnmarshaler, reflect.Value) {
// Issue #24153 indicates that it is generally not a guaranteed property
// that you may round-trip a reflect.Value by calling Value.Addr().Elem()
// and expect the value to still be settable for values derived from
// unexported embedded struct fields.
//
// The logic below effectively does this when it first addresses the value
// (to satisfy possible pointer methods) and continues to dereference
// subsequent pointers as necessary.
//
// After the first round-trip, we set v back to the original value to
// preserve the original RW flags contained in reflect.Value.
v0 := v
haveAddr := false
// If v is a named type and is addressable,
// start with its address, so that if the type has pointer methods,
// we find them.
if v.Kind() != reflect.Pointer && v.Type().Name() != "" && v.CanAddr() {
haveAddr = true
v = v.Addr()
}
for {
// Load value from interface, but only if the result will be
// usefully addressable.
if v.Kind() == reflect.Interface && !v.IsNil() {
e := v.Elem()
if e.Kind() == reflect.Pointer && !e.IsNil() && (!decodingNull || e.Elem().Kind() == reflect.Pointer) {
haveAddr = false
v = e
continue
}
}
if v.Kind() != reflect.Pointer {
break
}
if decodingNull && v.CanSet() {
break
}
// Prevent infinite loop if v is an interface pointing to its own address:
// var v any
// v = &v
if v.Elem().Kind() == reflect.Interface && v.Elem().Elem().Equal(v) {
v = v.Elem()
break
}
if v.IsNil() {
v.Set(reflect.New(v.Type().Elem()))
}
if v.Type().NumMethod() > 0 && v.CanInterface() {
if u, ok := v.Interface().(Unmarshaler); ok {
return u, nil, reflect.Value{}
}
if !decodingNull {
if u, ok := v.Interface().(encoding.TextUnmarshaler); ok {
return nil, u, reflect.Value{}
}
}
}
if haveAddr {
v = v0 // restore original value after round-trip Value.Addr().Elem()
haveAddr = false
} else {
v = v.Elem()
}
}
return nil, nil, v
}
// array consumes an array from d.data[d.off-1:], decoding into v.
// The first byte of the array ('[') has been read already.
func (d *decodeState) array(v reflect.Value) error {
// Check for unmarshaler.
u, ut, pv := indirect(v, false)
if u != nil {
start := d.readIndex()
d.skip()
return u.UnmarshalJSON(d.data[start:d.off])
}
if ut != nil {
d.saveError(&UnmarshalTypeError{Value: "array", Type: v.Type(), Offset: int64(d.off)})
d.skip()
return nil
}
v = pv
// Check type of target.
switch v.Kind() {
case reflect.Interface:
if v.NumMethod() == 0 {
// Decoding into nil interface? Switch to non-reflect code.
ai := d.arrayInterface()
v.Set(reflect.ValueOf(ai))
return nil
}
// Otherwise it's invalid.
fallthrough
default:
d.saveError(&UnmarshalTypeError{Value: "array", Type: v.Type(), Offset: int64(d.off)})
d.skip()
return nil
case reflect.Array, reflect.Slice:
break
}
i := 0
for {
// Look ahead for ] - can only happen on first iteration.
d.scanWhile(scanSkipSpace)
if d.opcode == scanEndArray {
break
}
// Expand slice length, growing the slice if necessary.
if v.Kind() == reflect.Slice {
if i >= v.Cap() {
v.Grow(1)
}
if i >= v.Len() {
v.SetLen(i + 1)
}
}
if i < v.Len() {
// Decode into element.
if err := d.value(v.Index(i)); err != nil {
return err
}
} else {
// Ran out of fixed array: skip.
if err := d.value(reflect.Value{}); err != nil {
return err
}
}
i++
// Next token must be , or ].
if d.opcode == scanSkipSpace {
d.scanWhile(scanSkipSpace)
}
if d.opcode == scanEndArray {
break
}
if d.opcode != scanArrayValue {
panic(phasePanicMsg)
}
}
if i < v.Len() {
if v.Kind() == reflect.Array {
for ; i < v.Len(); i++ {
v.Index(i).SetZero() // zero remainder of array
}
} else {
v.SetLen(i) // truncate the slice
}
}
if i == 0 && v.Kind() == reflect.Slice {
v.Set(reflect.MakeSlice(v.Type(), 0, 0))
}
return nil
}
var nullLiteral = []byte("null")
var textUnmarshalerType = reflect.TypeFor[encoding.TextUnmarshaler]()
// object consumes an object from d.data[d.off-1:], decoding into v.
// The first byte ('{') of the object has been read already.
func (d *decodeState) object(v reflect.Value) error {
// Check for unmarshaler.
u, ut, pv := indirect(v, false)
if u != nil {
start := d.readIndex()
d.skip()
return u.UnmarshalJSON(d.data[start:d.off])
}
if ut != nil {
d.saveError(&UnmarshalTypeError{Value: "object", Type: v.Type(), Offset: int64(d.off)})
d.skip()
return nil
}
v = pv
t := v.Type()
// Decoding into nil interface? Switch to non-reflect code.
if v.Kind() == reflect.Interface && v.NumMethod() == 0 {
oi := d.objectInterface()
v.Set(reflect.ValueOf(oi))
return nil
}
var fields structFields
// Check type of target:
// struct or
// map[T1]T2 where T1 is string, an integer type,
// or an encoding.TextUnmarshaler
switch v.Kind() {
case reflect.Map:
// Map key must either have string kind, have an integer kind,
// or be an encoding.TextUnmarshaler.
switch t.Key().Kind() {
case reflect.String,
reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
default:
if !reflect.PointerTo(t.Key()).Implements(textUnmarshalerType) {
d.saveError(&UnmarshalTypeError{Value: "object", Type: t, Offset: int64(d.off)})
d.skip()
return nil
}
}
if v.IsNil() {
v.Set(reflect.MakeMap(t))
}
case reflect.Struct:
fields = cachedTypeFields(t)
// ok
default:
d.saveError(&UnmarshalTypeError{Value: "object", Type: t, Offset: int64(d.off)})
d.skip()
return nil
}
var mapElem reflect.Value
var origErrorContext errorContext
if d.errorContext != nil {
origErrorContext = *d.errorContext
}
for {
// Read opening " of string key or closing }.
d.scanWhile(scanSkipSpace)
if d.opcode == scanEndObject {
// closing } - can only happen on first iteration.
break
}
if d.opcode != scanBeginLiteral {
panic(phasePanicMsg)
}
// Read key.
start := d.readIndex()
d.rescanLiteral()
item := d.data[start:d.readIndex()]
key, ok := unquoteBytes(item)
if !ok {
panic(phasePanicMsg)
}
// Figure out field corresponding to key.
var subv reflect.Value
destring := false // whether the value is wrapped in a string to be decoded first
if v.Kind() == reflect.Map {
elemType := t.Elem()
if !mapElem.IsValid() {
mapElem = reflect.New(elemType).Elem()
} else {
mapElem.SetZero()
}
subv = mapElem
} else {
f := fields.byExactName[string(key)]
if f == nil {
f = fields.byFoldedName[string(foldName(key))]
}
if f != nil {
subv = v
destring = f.quoted
if d.errorContext == nil {
d.errorContext = new(errorContext)
}
for i, ind := range f.index {
if subv.Kind() == reflect.Pointer {
if subv.IsNil() {
// If a struct embeds a pointer to an unexported type,
// it is not possible to set a newly allocated value
// since the field is unexported.
//
// See https://golang.org/issue/21357
if !subv.CanSet() {
d.saveError(fmt.Errorf("json: cannot set embedded pointer to unexported struct: %v", subv.Type().Elem()))
// Invalidate subv to ensure d.value(subv) skips over
// the JSON value without assigning it to subv.
subv = reflect.Value{}
destring = false
break
}
subv.Set(reflect.New(subv.Type().Elem()))
}
subv = subv.Elem()
}
if i < len(f.index)-1 {
d.errorContext.FieldStack = append(
d.errorContext.FieldStack,
subv.Type().Field(ind).Name,
)
}
subv = subv.Field(ind)
}
d.errorContext.Struct = t
d.errorContext.FieldStack = append(d.errorContext.FieldStack, f.name)
} else if d.disallowUnknownFields {
d.saveError(fmt.Errorf("json: unknown field %q", key))
}
}
// Read : before value.
if d.opcode == scanSkipSpace {
d.scanWhile(scanSkipSpace)
}
if d.opcode != scanObjectKey {
panic(phasePanicMsg)
}
d.scanWhile(scanSkipSpace)
if destring {
switch qv := d.valueQuoted().(type) {
case nil:
if err := d.literalStore(nullLiteral, subv, false); err != nil {
return err
}
case string:
if err := d.literalStore([]byte(qv), subv, true); err != nil {
return err
}
default:
d.saveError(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal unquoted value into %v", subv.Type()))
}
} else {
if err := d.value(subv); err != nil {
return err
}
}
// Write value back to map;
// if using struct, subv points into struct already.
if v.Kind() == reflect.Map {
kt := t.Key()
var kv reflect.Value
if reflect.PointerTo(kt).Implements(textUnmarshalerType) {
kv = reflect.New(kt)
if err := d.literalStore(item, kv, true); err != nil {
return err
}
kv = kv.Elem()
} else {
switch kt.Kind() {
case reflect.String:
kv = reflect.New(kt).Elem()
kv.SetString(string(key))
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
s := string(key)
n, err := strconv.ParseInt(s, 10, 64)
if err != nil || kt.OverflowInt(n) {
d.saveError(&UnmarshalTypeError{Value: "number " + s, Type: kt, Offset: int64(start + 1)})
break
}
kv = reflect.New(kt).Elem()
kv.SetInt(n)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
s := string(key)
n, err := strconv.ParseUint(s, 10, 64)
if err != nil || kt.OverflowUint(n) {
d.saveError(&UnmarshalTypeError{Value: "number " + s, Type: kt, Offset: int64(start + 1)})
break
}
kv = reflect.New(kt).Elem()
kv.SetUint(n)
default:
panic("json: Unexpected key type") // should never occur
}
}
if kv.IsValid() {
v.SetMapIndex(kv, subv)
}
}
// Next token must be , or }.
if d.opcode == scanSkipSpace {
d.scanWhile(scanSkipSpace)
}
if d.errorContext != nil {
// Reset errorContext to its original state.
// Keep the same underlying array for FieldStack, to reuse the
// space and avoid unnecessary allocs.
d.errorContext.FieldStack = d.errorContext.FieldStack[:len(origErrorContext.FieldStack)]
d.errorContext.Struct = origErrorContext.Struct
}
if d.opcode == scanEndObject {
break
}
if d.opcode != scanObjectValue {
panic(phasePanicMsg)
}
}
return nil
}
// convertNumber converts the number literal s to a float64 or a Number
// depending on the setting of d.useNumber.
func (d *decodeState) convertNumber(s string) (any, error) {
if d.useNumber {
return Number(s), nil
}
f, err := strconv.ParseFloat(s, 64)
if err != nil {
return nil, &UnmarshalTypeError{Value: "number " + s, Type: reflect.TypeFor[float64](), Offset: int64(d.off)}
}
return f, nil
}
var numberType = reflect.TypeFor[Number]()
// literalStore decodes a literal stored in item into v.
//
// fromQuoted indicates whether this literal came from unwrapping a
// string from the ",string" struct tag option. This is used only to
// produce more helpful error messages.
func (d *decodeState) literalStore(item []byte, v reflect.Value, fromQuoted bool) error {
// Check for unmarshaler.
if len(item) == 0 {
// Empty string given.
d.saveError(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type()))
return nil
}
isNull := item[0] == 'n' // null
u, ut, pv := indirect(v, isNull)
if u != nil {
return u.UnmarshalJSON(item)
}
if ut != nil {
if item[0] != '"' {
if fromQuoted {
d.saveError(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type()))
return nil
}
val := "number"
switch item[0] {
case 'n':
val = "null"
case 't', 'f':
val = "bool"
}
d.saveError(&UnmarshalTypeError{Value: val, Type: v.Type(), Offset: int64(d.readIndex())})
return nil
}
s, ok := unquoteBytes(item)
if !ok {
if fromQuoted {
return fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type())
}
panic(phasePanicMsg)
}
return ut.UnmarshalText(s)
}
v = pv
switch c := item[0]; c {
case 'n': // null
// The main parser checks that only null can reach here,
// but if this was a quoted string input, it could be anything.
if fromQuoted && string(item) != "null" {
d.saveError(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type()))
break
}
switch v.Kind() {
case reflect.Interface, reflect.Pointer, reflect.Map, reflect.Slice:
v.SetZero()
// otherwise, ignore null for primitives/string
}
case 't', 'f': // true, false
value := item[0] == 't'
// The main parser checks that only true and false can reach here,
// but if this was a quoted string input, it could be anything.
if fromQuoted && string(item) != "true" && string(item) != "false" {
d.saveError(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type()))
break
}
switch v.Kind() {
default:
if fromQuoted {
d.saveError(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type()))
} else {
d.saveError(&UnmarshalTypeError{Value: "bool", Type: v.Type(), Offset: int64(d.readIndex())})
}
case reflect.Bool:
v.SetBool(value)
case reflect.Interface:
if v.NumMethod() == 0 {
v.Set(reflect.ValueOf(value))
} else {
d.saveError(&UnmarshalTypeError{Value: "bool", Type: v.Type(), Offset: int64(d.readIndex())})
}
}
case '"': // string
s, ok := unquoteBytes(item)
if !ok {
if fromQuoted {
return fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type())
}
panic(phasePanicMsg)
}
switch v.Kind() {
default:
d.saveError(&UnmarshalTypeError{Value: "string", Type: v.Type(), Offset: int64(d.readIndex())})
case reflect.Slice:
if v.Type().Elem().Kind() != reflect.Uint8 {
d.saveError(&UnmarshalTypeError{Value: "string", Type: v.Type(), Offset: int64(d.readIndex())})
break
}
b := make([]byte, base64.StdEncoding.DecodedLen(len(s)))
n, err := base64.StdEncoding.Decode(b, s)
if err != nil {
d.saveError(err)
break
}
v.SetBytes(b[:n])
case reflect.String:
t := string(s)
if v.Type() == numberType && !isValidNumber(t) {
return fmt.Errorf("json: invalid number literal, trying to unmarshal %q into Number", item)
}
v.SetString(t)
case reflect.Interface:
if v.NumMethod() == 0 {
v.Set(reflect.ValueOf(string(s)))
} else {
d.saveError(&UnmarshalTypeError{Value: "string", Type: v.Type(), Offset: int64(d.readIndex())})
}
}
default: // number
if c != '-' && (c < '0' || c > '9') {
if fromQuoted {
return fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type())
}
panic(phasePanicMsg)
}
switch v.Kind() {
default:
if v.Kind() == reflect.String && v.Type() == numberType {
// s must be a valid number, because it's
// already been tokenized.
v.SetString(string(item))
break
}
if fromQuoted {
return fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type())
}
d.saveError(&UnmarshalTypeError{Value: "number", Type: v.Type(), Offset: int64(d.readIndex())})
case reflect.Interface:
n, err := d.convertNumber(string(item))
if err != nil {
d.saveError(err)
break
}
if v.NumMethod() != 0 {
d.saveError(&UnmarshalTypeError{Value: "number", Type: v.Type(), Offset: int64(d.readIndex())})
break
}
v.Set(reflect.ValueOf(n))
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
n, err := strconv.ParseInt(string(item), 10, 64)
if err != nil || v.OverflowInt(n) {
d.saveError(&UnmarshalTypeError{Value: "number " + string(item), Type: v.Type(), Offset: int64(d.readIndex())})
break
}
v.SetInt(n)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
n, err := strconv.ParseUint(string(item), 10, 64)
if err != nil || v.OverflowUint(n) {
d.saveError(&UnmarshalTypeError{Value: "number " + string(item), Type: v.Type(), Offset: int64(d.readIndex())})
break
}
v.SetUint(n)
case reflect.Float32, reflect.Float64:
n, err := strconv.ParseFloat(string(item), v.Type().Bits())
if err != nil || v.OverflowFloat(n) {
d.saveError(&UnmarshalTypeError{Value: "number " + string(item), Type: v.Type(), Offset: int64(d.readIndex())})
break
}
v.SetFloat(n)
}
}
return nil
}
// The xxxInterface routines build up a value to be stored
// in an empty interface. They are not strictly necessary,
// but they avoid the weight of reflection in this common case.
// valueInterface is like value but returns any.
func (d *decodeState) valueInterface() (val any) {
switch d.opcode {
default:
panic(phasePanicMsg)
case scanBeginArray:
val = d.arrayInterface()
d.scanNext()
case scanBeginObject:
val = d.objectInterface()
d.scanNext()
case scanBeginLiteral:
val = d.literalInterface()
}
return
}
// arrayInterface is like array but returns []any.
func (d *decodeState) arrayInterface() []any {
var v = make([]any, 0)
for {
// Look ahead for ] - can only happen on first iteration.
d.scanWhile(scanSkipSpace)
if d.opcode == scanEndArray {
break
}
v = append(v, d.valueInterface())
// Next token must be , or ].
if d.opcode == scanSkipSpace {
d.scanWhile(scanSkipSpace)
}
if d.opcode == scanEndArray {
break
}
if d.opcode != scanArrayValue {
panic(phasePanicMsg)
}
}
return v
}
// objectInterface is like object but returns map[string]any.
func (d *decodeState) objectInterface() map[string]any {
m := make(map[string]any)
for {
// Read opening " of string key or closing }.
d.scanWhile(scanSkipSpace)
if d.opcode == scanEndObject {
// closing } - can only happen on first iteration.
break
}
if d.opcode != scanBeginLiteral {
panic(phasePanicMsg)
}
// Read string key.
start := d.readIndex()
d.rescanLiteral()
item := d.data[start:d.readIndex()]
key, ok := unquote(item)
if !ok {
panic(phasePanicMsg)
}
// Read : before value.
if d.opcode == scanSkipSpace {
d.scanWhile(scanSkipSpace)
}
if d.opcode != scanObjectKey {
panic(phasePanicMsg)
}
d.scanWhile(scanSkipSpace)
// Read value.
m[key] = d.valueInterface()
// Next token must be , or }.
if d.opcode == scanSkipSpace {
d.scanWhile(scanSkipSpace)
}
if d.opcode == scanEndObject {
break
}
if d.opcode != scanObjectValue {
panic(phasePanicMsg)
}
}
return m
}
// literalInterface consumes and returns a literal from d.data[d.off-1:] and
// it reads the following byte ahead. The first byte of the literal has been
// read already (that's how the caller knows it's a literal).
func (d *decodeState) literalInterface() any {
// All bytes inside literal return scanContinue op code.
start := d.readIndex()
d.rescanLiteral()
item := d.data[start:d.readIndex()]
switch c := item[0]; c {
case 'n': // null
return nil
case 't', 'f': // true, false
return c == 't'
case '"': // string
s, ok := unquote(item)
if !ok {
panic(phasePanicMsg)
}
return s
default: // number
if c != '-' && (c < '0' || c > '9') {
panic(phasePanicMsg)
}
n, err := d.convertNumber(string(item))
if err != nil {
d.saveError(err)
}
return n
}
}
// getu4 decodes \uXXXX from the beginning of s, returning the hex value,
// or it returns -1.
func getu4(s []byte) rune {
if len(s) < 6 || s[0] != '\\' || s[1] != 'u' {
return -1
}
var r rune
for _, c := range s[2:6] {
switch {
case '0' <= c && c <= '9':
c = c - '0'
case 'a' <= c && c <= 'f':
c = c - 'a' + 10
case 'A' <= c && c <= 'F':
c = c - 'A' + 10
default:
return -1
}
r = r*16 + rune(c)
}
return r
}
// unquote converts a quoted JSON string literal s into an actual string t.
// The rules are different from Go's, so we cannot use strconv.Unquote.
func unquote(s []byte) (t string, ok bool) {
s, ok = unquoteBytes(s)
t = string(s)
return
}
// unquoteBytes should be an internal detail,
// but widely used packages access it using linkname.
// Notable members of the hall of shame include:
// - github.com/bytedance/sonic
//
// Do not remove or change the type signature.
// See go.dev/issue/67401.
//
//go:linkname unquoteBytes
func unquoteBytes(s []byte) (t []byte, ok bool) {
if len(s) < 2 || s[0] != '"' || s[len(s)-1] != '"' {
return
}
s = s[1 : len(s)-1]
// Check for unusual characters. If there are none,
// then no unquoting is needed, so return a slice of the
// original bytes.
r := 0
for r < len(s) {
c := s[r]
if c == '\\' || c == '"' || c < ' ' {
break
}
if c < utf8.RuneSelf {
r++
continue
}
rr, size := utf8.DecodeRune(s[r:])
if rr == utf8.RuneError && size == 1 {
break
}
r += size
}
if r == len(s) {
return s, true
}
b := make([]byte, len(s)+2*utf8.UTFMax)
w := copy(b, s[0:r])
for r < len(s) {
// Out of room? Can only happen if s is full of
// malformed UTF-8 and we're replacing each
// byte with RuneError.
if w >= len(b)-2*utf8.UTFMax {
nb := make([]byte, (len(b)+utf8.UTFMax)*2)
copy(nb, b[0:w])
b = nb
}
switch c := s[r]; {
case c == '\\':
r++
if r >= len(s) {
return
}
switch s[r] {
default:
return
case '"', '\\', '/', '\'':
b[w] = s[r]
r++
w++
case 'b':
b[w] = '\b'
r++
w++
case 'f':
b[w] = '\f'
r++
w++
case 'n':
b[w] = '\n'
r++
w++
case 'r':
b[w] = '\r'
r++
w++
case 't':
b[w] = '\t'
r++
w++
case 'u':
r--
rr := getu4(s[r:])
if rr < 0 {
return
}
r += 6
if utf16.IsSurrogate(rr) {
rr1 := getu4(s[r:])
if dec := utf16.DecodeRune(rr, rr1); dec != unicode.ReplacementChar {
// A valid pair; consume.
r += 6
w += utf8.EncodeRune(b[w:], dec)
break
}
// Invalid surrogate; fall back to replacement rune.
rr = unicode.ReplacementChar
}
w += utf8.EncodeRune(b[w:], rr)
}
// Quote, control characters are invalid.
case c == '"', c < ' ':
return
// ASCII
case c < utf8.RuneSelf:
b[w] = c
r++
w++
// Coerce to well-formed UTF-8.
default:
rr, size := utf8.DecodeRune(s[r:])
r += size
w += utf8.EncodeRune(b[w:], rr)
}
}
return b[0:w], true
}
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !goexperiment.jsonv2
// Package json implements encoding and decoding of JSON as defined in RFC 7159.
// The mapping between JSON and Go values is described in the documentation for
// the Marshal and Unmarshal functions.
//
// See "JSON and Go" for an introduction to this package:
// https://golang.org/doc/articles/json_and_go.html
//
// # Security Considerations
//
// The JSON standard (RFC 7159) is lax in its definition of a number of parser
// behaviors. As such, many JSON parsers behave differently in various
// scenarios. These differences in parsers mean that systems that use multiple
// independent JSON parser implementations may parse the same JSON object in
// differing ways.
//
// Systems that rely on a JSON object being parsed consistently for security
// purposes should be careful to understand the behaviors of this parser, as
// well as how these behaviors may cause interoperability issues with other
// parser implementations.
//
// Due to the Go Backwards Compatibility promise (https://go.dev/doc/go1compat)
// there are a number of behaviors this package exhibits that may cause
// interoperability issues, but cannot be changed. In particular the following
// parsing behaviors may cause issues:
//
// - If a JSON object contains duplicate keys, keys are processed in the order
// they are observed, meaning later values will replace or be merged into
// prior values, depending on the field type (in particular maps and structs
// will have values merged, while other types have values replaced).
// - When parsing a JSON object into a Go struct, keys are considered in a
// case-insensitive fashion.
// - When parsing a JSON object into a Go struct, unknown keys in the JSON
// object are ignored (unless a [Decoder] is used and
// [Decoder.DisallowUnknownFields] has been called).
// - Invalid UTF-8 bytes in JSON strings are replaced by the Unicode
// replacement character.
// - Large JSON number integers will lose precision when unmarshaled into
// floating-point types.
package json
import (
"bytes"
"cmp"
"encoding"
"encoding/base64"
"fmt"
"math"
"reflect"
"slices"
"strconv"
"strings"
"sync"
"unicode"
"unicode/utf8"
_ "unsafe" // for linkname
)
// Marshal returns the JSON encoding of v.
//
// Marshal traverses the value v recursively.
// If an encountered value implements [Marshaler]
// and is not a nil pointer, Marshal calls [Marshaler.MarshalJSON]
// to produce JSON. If no [Marshaler.MarshalJSON] method is present but the
// value implements [encoding.TextMarshaler] instead, Marshal calls
// [encoding.TextMarshaler.MarshalText] and encodes the result as a JSON string.
// The nil pointer exception is not strictly necessary
// but mimics a similar, necessary exception in the behavior of
// [Unmarshaler.UnmarshalJSON].
//
// Otherwise, Marshal uses the following type-dependent default encodings:
//
// Boolean values encode as JSON booleans.
//
// Floating point, integer, and [Number] values encode as JSON numbers.
// NaN and +/-Inf values will return an [UnsupportedValueError].
//
// String values encode as JSON strings coerced to valid UTF-8,
// replacing invalid bytes with the Unicode replacement rune.
// So that the JSON will be safe to embed inside HTML <script> tags,
// the string is encoded using [HTMLEscape],
// which replaces "<", ">", "&", U+2028, and U+2029 with
// "\u003c", "\u003e", "\u0026", "\u2028", and "\u2029".
// This replacement can be disabled when using an [Encoder],
// by calling [Encoder.SetEscapeHTML](false).
//
// Array and slice values encode as JSON arrays, except that
// []byte encodes as a base64-encoded string, and a nil slice
// encodes as the null JSON value.
//
// Struct values encode as JSON objects.
// Each exported struct field becomes a member of the object, using the
// field name as the object key, unless the field is omitted for one of the
// reasons given below.
//
// The encoding of each struct field can be customized by the format string
// stored under the "json" key in the struct field's tag.
// The format string gives the name of the field, possibly followed by a
// comma-separated list of options. The name may be empty in order to
// specify options without overriding the default field name.
//
// The "omitempty" option specifies that the field should be omitted
// from the encoding if the field has an empty value, defined as
// false, 0, a nil pointer, a nil interface value, and any array,
// slice, map, or string of length zero.
//
// As a special case, if the field tag is "-", the field is always omitted.
// Note that a field with name "-" can still be generated using the tag "-,".
//
// Examples of struct field tags and their meanings:
//
// // Field appears in JSON as key "myName".
// Field int `json:"myName"`
//
// // Field appears in JSON as key "myName" and
// // the field is omitted from the object if its value is empty,
// // as defined above.
// Field int `json:"myName,omitempty"`
//
// // Field appears in JSON as key "Field" (the default), but
// // the field is skipped if empty.
// // Note the leading comma.
// Field int `json:",omitempty"`
//
// // Field is ignored by this package.
// Field int `json:"-"`
//
// // Field appears in JSON as key "-".
// Field int `json:"-,"`
//
// The "omitzero" option specifies that the field should be omitted
// from the encoding if the field has a zero value, according to these rules:
//
// 1) If the field type has an "IsZero() bool" method, that will be used to
// determine whether the value is zero.
//
// 2) Otherwise, the value is zero if it is the zero value for its type.
//
// If both "omitempty" and "omitzero" are specified, the field will be omitted
// if the value is either empty or zero (or both).
//
// The "string" option signals that a field is stored as JSON inside a
// JSON-encoded string. It applies only to fields of string, floating point,
// integer, or boolean types. This extra level of encoding is sometimes used
// when communicating with JavaScript programs:
//
// Int64String int64 `json:",string"`
//
// The key name will be used if it's a non-empty string consisting of
// only Unicode letters, digits, and ASCII punctuation except quotation
// marks, backslash, and comma.
//
// Embedded struct fields are usually marshaled as if their inner exported fields
// were fields in the outer struct, subject to the usual Go visibility rules amended
// as described in the next paragraph.
// An anonymous struct field with a name given in its JSON tag is treated as
// having that name, rather than being anonymous.
// An anonymous struct field of interface type is treated the same as having
// that type as its name, rather than being anonymous.
//
// The Go visibility rules for struct fields are amended for JSON when
// deciding which field to marshal or unmarshal. If there are
// multiple fields at the same level, and that level is the least
// nested (and would therefore be the nesting level selected by the
// usual Go rules), the following extra rules apply:
//
// 1) Of those fields, if any are JSON-tagged, only tagged fields are considered,
// even if there are multiple untagged fields that would otherwise conflict.
//
// 2) If there is exactly one field (tagged or not according to the first rule), that is selected.
//
// 3) Otherwise there are multiple fields, and all are ignored; no error occurs.
//
// Handling of anonymous struct fields is new in Go 1.1.
// Prior to Go 1.1, anonymous struct fields were ignored. To force ignoring of
// an anonymous struct field in both current and earlier versions, give the field
// a JSON tag of "-".
//
// Map values encode as JSON objects. The map's key type must either be a
// string, an integer type, or implement [encoding.TextMarshaler]. The map keys
// are sorted and used as JSON object keys by applying the following rules,
// subject to the UTF-8 coercion described for string values above:
// - keys of any string type are used directly
// - keys that implement [encoding.TextMarshaler] are marshaled
// - integer keys are converted to strings
//
// Pointer values encode as the value pointed to.
// A nil pointer encodes as the null JSON value.
//
// Interface values encode as the value contained in the interface.
// A nil interface value encodes as the null JSON value.
//
// Channel, complex, and function values cannot be encoded in JSON.
// Attempting to encode such a value causes Marshal to return
// an [UnsupportedTypeError].
//
// JSON cannot represent cyclic data structures and Marshal does not
// handle them. Passing cyclic structures to Marshal will result in
// an error.
func Marshal(v any) ([]byte, error) {
e := newEncodeState()
defer encodeStatePool.Put(e)
err := e.marshal(v, encOpts{escapeHTML: true})
if err != nil {
return nil, err
}
buf := append([]byte(nil), e.Bytes()...)
return buf, nil
}
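// A minimal usage sketch (as seen from a client package; the ColorGroup type
// is an assumption for illustration):
//
//	type ColorGroup struct {
//		ID     int      `json:"id"`
//		Name   string   `json:"name"`
//		Colors []string `json:"colors,omitempty"`
//	}
//
//	b, err := json.Marshal(ColorGroup{ID: 1, Name: "Reds"})
//	// b == []byte(`{"id":1,"name":"Reds"}`), err == nil; Colors is omitted because it is empty
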
// MarshalIndent is like [Marshal] but applies [Indent] to format the output.
// Each JSON element in the output will begin on a new line beginning with prefix
// followed by one or more copies of indent according to the indentation nesting.
func MarshalIndent(v any, prefix, indent string) ([]byte, error) {
b, err := Marshal(v)
if err != nil {
return nil, err
}
b2 := make([]byte, 0, indentGrowthFactor*len(b))
b2, err = appendIndent(b2, b, prefix, indent)
if err != nil {
return nil, err
}
return b2, nil
}
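// A minimal sketch of MarshalIndent (as seen from a client package):
//
//	b, _ := json.MarshalIndent(map[string]int{"a": 1}, "", "  ")
//	// string(b) ==
//	// {
//	//   "a": 1
//	// }
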
// Marshaler is the interface implemented by types that
// can marshal themselves into valid JSON.
type Marshaler interface {
MarshalJSON() ([]byte, error)
}
// An UnsupportedTypeError is returned by [Marshal] when attempting
// to encode an unsupported value type.
type UnsupportedTypeError struct {
Type reflect.Type
}
func (e *UnsupportedTypeError) Error() string {
return "json: unsupported type: " + e.Type.String()
}
// An UnsupportedValueError is returned by [Marshal] when attempting
// to encode an unsupported value.
type UnsupportedValueError struct {
Value reflect.Value
Str string
}
func (e *UnsupportedValueError) Error() string {
return "json: unsupported value: " + e.Str
}
// Before Go 1.2, an InvalidUTF8Error was returned by [Marshal] when
// attempting to encode a string value with invalid UTF-8 sequences.
// As of Go 1.2, [Marshal] instead coerces the string to valid UTF-8 by
// replacing invalid bytes with the Unicode replacement rune U+FFFD.
//
// Deprecated: No longer used; kept for compatibility.
type InvalidUTF8Error struct {
S string // the whole string value that caused the error
}
func (e *InvalidUTF8Error) Error() string {
return "json: invalid UTF-8 in string: " + strconv.Quote(e.S)
}
// A MarshalerError represents an error from calling a
// [Marshaler.MarshalJSON] or [encoding.TextMarshaler.MarshalText] method.
type MarshalerError struct {
Type reflect.Type
Err error
sourceFunc string
}
func (e *MarshalerError) Error() string {
srcFunc := e.sourceFunc
if srcFunc == "" {
srcFunc = "MarshalJSON"
}
return "json: error calling " + srcFunc +
" for type " + e.Type.String() +
": " + e.Err.Error()
}
// Unwrap returns the underlying error.
func (e *MarshalerError) Unwrap() error { return e.Err }
const hex = "0123456789abcdef"
// An encodeState encodes JSON into a bytes.Buffer.
type encodeState struct {
bytes.Buffer // accumulated output
// Keep track of what pointers we've seen in the current recursive call
// path, to avoid cycles that could lead to a stack overflow. Only do
// the relatively expensive map operations if ptrLevel is larger than
// startDetectingCyclesAfter, so that we skip the work when the nesting of
// pointers is still within a reasonable depth.
ptrLevel uint
ptrSeen map[any]struct{}
}
const startDetectingCyclesAfter = 1000
var encodeStatePool sync.Pool
func newEncodeState() *encodeState {
if v := encodeStatePool.Get(); v != nil {
e := v.(*encodeState)
e.Reset()
if len(e.ptrSeen) > 0 {
panic("ptrEncoder.encode should have emptied ptrSeen via defers")
}
e.ptrLevel = 0
return e
}
return &encodeState{ptrSeen: make(map[any]struct{})}
}
// jsonError is an error wrapper type for internal use only.
// Panics with errors are wrapped in jsonError so that the top-level recover
// can distinguish intentional panics from this package.
type jsonError struct{ error }
func (e *encodeState) marshal(v any, opts encOpts) (err error) {
defer func() {
if r := recover(); r != nil {
if je, ok := r.(jsonError); ok {
err = je.error
} else {
panic(r)
}
}
}()
e.reflectValue(reflect.ValueOf(v), opts)
return nil
}
// error aborts the encoding by panicking with err wrapped in jsonError.
func (e *encodeState) error(err error) {
panic(jsonError{err})
}
func isEmptyValue(v reflect.Value) bool {
switch v.Kind() {
case reflect.Array, reflect.Map, reflect.Slice, reflect.String:
return v.Len() == 0
case reflect.Bool,
reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr,
reflect.Float32, reflect.Float64,
reflect.Interface, reflect.Pointer:
return v.IsZero()
}
return false
}
func (e *encodeState) reflectValue(v reflect.Value, opts encOpts) {
valueEncoder(v)(e, v, opts)
}
type encOpts struct {
// quoted causes primitive fields to be encoded inside JSON strings.
quoted bool
// escapeHTML causes '<', '>', and '&' to be escaped in JSON strings.
escapeHTML bool
}
type encoderFunc func(e *encodeState, v reflect.Value, opts encOpts)
var encoderCache sync.Map // map[reflect.Type]encoderFunc
func valueEncoder(v reflect.Value) encoderFunc {
if !v.IsValid() {
return invalidValueEncoder
}
return typeEncoder(v.Type())
}
func typeEncoder(t reflect.Type) encoderFunc {
if fi, ok := encoderCache.Load(t); ok {
return fi.(encoderFunc)
}
// To deal with recursive types, populate the map with an
// indirect func before we build it. If the type is recursive,
// the second lookup for the type will return the indirect func.
//
// This indirect func is only used for recursive types,
// and briefly during racing calls to typeEncoder.
indirect := sync.OnceValue(func() encoderFunc {
return newTypeEncoder(t, true)
})
fi, loaded := encoderCache.LoadOrStore(t, encoderFunc(func(e *encodeState, v reflect.Value, opts encOpts) {
indirect()(e, v, opts)
}))
if loaded {
return fi.(encoderFunc)
}
f := indirect()
encoderCache.Store(t, f)
return f
}
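// exampleRecursiveTypeSketch is a hedged sketch, not part of the original
// source: a self-referential type exercises the indirect-encoder path
// described above. The node type is an assumption for demonstration.
func exampleRecursiveTypeSketch() ([]byte, error) {
	type node struct {
		V    int   `json:"v"`
		Next *node `json:"next,omitempty"`
	}
	// Expected encoding: {"v":1,"next":{"v":2}}
	return Marshal(&node{V: 1, Next: &node{V: 2}})
}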
var (
marshalerType = reflect.TypeFor[Marshaler]()
textMarshalerType = reflect.TypeFor[encoding.TextMarshaler]()
)
// newTypeEncoder constructs an encoderFunc for a type.
// The returned encoder only checks CanAddr when allowAddr is true.
func newTypeEncoder(t reflect.Type, allowAddr bool) encoderFunc {
// If we have a non-pointer value whose type implements
// Marshaler with a value receiver, then we're better off taking
// the address of the value - otherwise we end up with an
// allocation as we cast the value to an interface.
if t.Kind() != reflect.Pointer && allowAddr && reflect.PointerTo(t).Implements(marshalerType) {
return newCondAddrEncoder(addrMarshalerEncoder, newTypeEncoder(t, false))
}
if t.Implements(marshalerType) {
return marshalerEncoder
}
if t.Kind() != reflect.Pointer && allowAddr && reflect.PointerTo(t).Implements(textMarshalerType) {
return newCondAddrEncoder(addrTextMarshalerEncoder, newTypeEncoder(t, false))
}
if t.Implements(textMarshalerType) {
return textMarshalerEncoder
}
switch t.Kind() {
case reflect.Bool:
return boolEncoder
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return intEncoder
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return uintEncoder
case reflect.Float32:
return float32Encoder
case reflect.Float64:
return float64Encoder
case reflect.String:
return stringEncoder
case reflect.Interface:
return interfaceEncoder
case reflect.Struct:
return newStructEncoder(t)
case reflect.Map:
return newMapEncoder(t)
case reflect.Slice:
return newSliceEncoder(t)
case reflect.Array:
return newArrayEncoder(t)
case reflect.Pointer:
return newPtrEncoder(t)
default:
return unsupportedTypeEncoder
}
}
func invalidValueEncoder(e *encodeState, v reflect.Value, _ encOpts) {
e.WriteString("null")
}
func marshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) {
if v.Kind() == reflect.Pointer && v.IsNil() {
e.WriteString("null")
return
}
m, ok := v.Interface().(Marshaler)
if !ok {
e.WriteString("null")
return
}
b, err := m.MarshalJSON()
if err == nil {
e.Grow(len(b))
out := e.AvailableBuffer()
out, err = appendCompact(out, b, opts.escapeHTML)
e.Buffer.Write(out)
}
if err != nil {
e.error(&MarshalerError{v.Type(), err, "MarshalJSON"})
}
}
func addrMarshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) {
va := v.Addr()
if va.IsNil() {
e.WriteString("null")
return
}
m := va.Interface().(Marshaler)
b, err := m.MarshalJSON()
if err == nil {
e.Grow(len(b))
out := e.AvailableBuffer()
out, err = appendCompact(out, b, opts.escapeHTML)
e.Buffer.Write(out)
}
if err != nil {
e.error(&MarshalerError{v.Type(), err, "MarshalJSON"})
}
}
func textMarshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) {
if v.Kind() == reflect.Pointer && v.IsNil() {
e.WriteString("null")
return
}
m, ok := v.Interface().(encoding.TextMarshaler)
if !ok {
e.WriteString("null")
return
}
b, err := m.MarshalText()
if err != nil {
e.error(&MarshalerError{v.Type(), err, "MarshalText"})
}
e.Write(appendString(e.AvailableBuffer(), b, opts.escapeHTML))
}
func addrTextMarshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) {
va := v.Addr()
if va.IsNil() {
e.WriteString("null")
return
}
m := va.Interface().(encoding.TextMarshaler)
b, err := m.MarshalText()
if err != nil {
e.error(&MarshalerError{v.Type(), err, "MarshalText"})
}
e.Write(appendString(e.AvailableBuffer(), b, opts.escapeHTML))
}
func boolEncoder(e *encodeState, v reflect.Value, opts encOpts) {
b := e.AvailableBuffer()
b = mayAppendQuote(b, opts.quoted)
b = strconv.AppendBool(b, v.Bool())
b = mayAppendQuote(b, opts.quoted)
e.Write(b)
}
func intEncoder(e *encodeState, v reflect.Value, opts encOpts) {
b := e.AvailableBuffer()
b = mayAppendQuote(b, opts.quoted)
b = strconv.AppendInt(b, v.Int(), 10)
b = mayAppendQuote(b, opts.quoted)
e.Write(b)
}
func uintEncoder(e *encodeState, v reflect.Value, opts encOpts) {
b := e.AvailableBuffer()
b = mayAppendQuote(b, opts.quoted)
b = strconv.AppendUint(b, v.Uint(), 10)
b = mayAppendQuote(b, opts.quoted)
e.Write(b)
}
type floatEncoder int // number of bits
func (bits floatEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) {
f := v.Float()
if math.IsInf(f, 0) || math.IsNaN(f) {
e.error(&UnsupportedValueError{v, strconv.FormatFloat(f, 'g', -1, int(bits))})
}
// Convert as if by ES6 number to string conversion.
// This matches most other JSON generators.
// See golang.org/issue/6384 and golang.org/issue/14135.
// Like fmt %g, but the exponent cutoffs are different
// and exponents themselves are not padded to two digits.
b := e.AvailableBuffer()
b = mayAppendQuote(b, opts.quoted)
abs := math.Abs(f)
fmt := byte('f')
// Note: Must use float32 comparisons for underlying float32 value to get precise cutoffs right.
if abs != 0 {
if bits == 64 && (abs < 1e-6 || abs >= 1e21) || bits == 32 && (float32(abs) < 1e-6 || float32(abs) >= 1e21) {
fmt = 'e'
}
}
b = strconv.AppendFloat(b, f, fmt, -1, int(bits))
if fmt == 'e' {
// clean up e-09 to e-9
n := len(b)
if n >= 4 && b[n-4] == 'e' && b[n-3] == '-' && b[n-2] == '0' {
b[n-2] = b[n-1]
b = b[:n-1]
}
}
b = mayAppendQuote(b, opts.quoted)
e.Write(b)
}
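// exampleFloatFormattingSketch is a hedged sketch, not part of the original
// source, of the exponent cutoffs implemented above; the expected outputs in
// the comments are assumptions based on that description.
func exampleFloatFormattingSketch() {
	_, _ = Marshal(1e20)       // "100000000000000000000": still plain decimal form
	_, _ = Marshal(1e21)       // "1e+21": at 1e21 and above, exponent form is used
	_, _ = Marshal(1e-7)       // "1e-7": below 1e-6, exponent form with the padding zero trimmed
	_, _ = Marshal(math.NaN()) // returns an *UnsupportedValueError
}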
var (
float32Encoder = (floatEncoder(32)).encode
float64Encoder = (floatEncoder(64)).encode
)
func stringEncoder(e *encodeState, v reflect.Value, opts encOpts) {
if v.Type() == numberType {
numStr := v.String()
// In Go 1.5 the empty string encoded to "0". While that is not a valid number
// literal, we keep the behavior for compatibility, so check validity after this.
if numStr == "" {
numStr = "0" // Number's zero-val
}
if !isValidNumber(numStr) {
e.error(fmt.Errorf("json: invalid number literal %q", numStr))
}
b := e.AvailableBuffer()
b = mayAppendQuote(b, opts.quoted)
b = append(b, numStr...)
b = mayAppendQuote(b, opts.quoted)
e.Write(b)
return
}
if opts.quoted {
b := appendString(nil, v.String(), opts.escapeHTML)
e.Write(appendString(e.AvailableBuffer(), b, false)) // no need to escape again since it is already escaped
} else {
e.Write(appendString(e.AvailableBuffer(), v.String(), opts.escapeHTML))
}
}
// isValidNumber reports whether s is a valid JSON number literal.
//
// isValidNumber should be an internal detail,
// but widely used packages access it using linkname.
// Notable members of the hall of shame include:
// - github.com/bytedance/sonic
//
// Do not remove or change the type signature.
// See go.dev/issue/67401.
//
//go:linkname isValidNumber
func isValidNumber(s string) bool {
// This function implements the JSON numbers grammar.
// See https://tools.ietf.org/html/rfc7159#section-6
// and https://www.json.org/img/number.png
if s == "" {
return false
}
// Optional -
if s[0] == '-' {
s = s[1:]
if s == "" {
return false
}
}
// Digits
switch {
default:
return false
case s[0] == '0':
s = s[1:]
case '1' <= s[0] && s[0] <= '9':
s = s[1:]
for len(s) > 0 && '0' <= s[0] && s[0] <= '9' {
s = s[1:]
}
}
// . followed by 1 or more digits.
if len(s) >= 2 && s[0] == '.' && '0' <= s[1] && s[1] <= '9' {
s = s[2:]
for len(s) > 0 && '0' <= s[0] && s[0] <= '9' {
s = s[1:]
}
}
// e or E followed by an optional - or + and
// 1 or more digits.
if len(s) >= 2 && (s[0] == 'e' || s[0] == 'E') {
s = s[1:]
if s[0] == '+' || s[0] == '-' {
s = s[1:]
if s == "" {
return false
}
}
for len(s) > 0 && '0' <= s[0] && s[0] <= '9' {
s = s[1:]
}
}
// Make sure we are at the end.
return s == ""
}
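// exampleValidNumberSketch is a hedged sketch, not part of the original
// source, of literals accepted and rejected by the grammar above.
func exampleValidNumberSketch() {
	_ = isValidNumber("-0.5e+2") // true: sign, fraction, and exponent are all allowed
	_ = isValidNumber("01")      // false: leading zeros are not permitted
	_ = isValidNumber("1.")      // false: a fraction requires at least one digit
}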
func interfaceEncoder(e *encodeState, v reflect.Value, opts encOpts) {
if v.IsNil() {
e.WriteString("null")
return
}
e.reflectValue(v.Elem(), opts)
}
func unsupportedTypeEncoder(e *encodeState, v reflect.Value, _ encOpts) {
e.error(&UnsupportedTypeError{v.Type()})
}
type structEncoder struct {
fields structFields
}
type structFields struct {
list []field
byExactName map[string]*field
byFoldedName map[string]*field
}
func (se structEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) {
next := byte('{')
FieldLoop:
for i := range se.fields.list {
f := &se.fields.list[i]
// Find the nested struct field by following f.index.
fv := v
for _, i := range f.index {
if fv.Kind() == reflect.Pointer {
if fv.IsNil() {
continue FieldLoop
}
fv = fv.Elem()
}
fv = fv.Field(i)
}
if (f.omitEmpty && isEmptyValue(fv)) ||
(f.omitZero && (f.isZero == nil && fv.IsZero() || (f.isZero != nil && f.isZero(fv)))) {
continue
}
e.WriteByte(next)
next = ','
if opts.escapeHTML {
e.WriteString(f.nameEscHTML)
} else {
e.WriteString(f.nameNonEsc)
}
opts.quoted = f.quoted
f.encoder(e, fv, opts)
}
if next == '{' {
e.WriteString("{}")
} else {
e.WriteByte('}')
}
}
func newStructEncoder(t reflect.Type) encoderFunc {
se := structEncoder{fields: cachedTypeFields(t)}
return se.encode
}
type mapEncoder struct {
elemEnc encoderFunc
}
func (me mapEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) {
if v.IsNil() {
e.WriteString("null")
return
}
if e.ptrLevel++; e.ptrLevel > startDetectingCyclesAfter {
// We're a large number of nested ptrEncoder.encode calls deep;
// start checking if we've run into a pointer cycle.
ptr := v.UnsafePointer()
if _, ok := e.ptrSeen[ptr]; ok {
e.error(&UnsupportedValueError{v, fmt.Sprintf("encountered a cycle via %s", v.Type())})
}
e.ptrSeen[ptr] = struct{}{}
defer delete(e.ptrSeen, ptr)
}
e.WriteByte('{')
// Extract and sort the keys.
var (
sv = make([]reflectWithString, v.Len())
mi = v.MapRange()
err error
)
for i := 0; mi.Next(); i++ {
if sv[i].ks, err = resolveKeyName(mi.Key()); err != nil {
e.error(fmt.Errorf("json: encoding error for type %q: %q", v.Type().String(), err.Error()))
}
sv[i].v = mi.Value()
}
slices.SortFunc(sv, func(i, j reflectWithString) int {
return strings.Compare(i.ks, j.ks)
})
for i, kv := range sv {
if i > 0 {
e.WriteByte(',')
}
e.Write(appendString(e.AvailableBuffer(), kv.ks, opts.escapeHTML))
e.WriteByte(':')
me.elemEnc(e, kv.v, opts)
}
e.WriteByte('}')
e.ptrLevel--
}
func newMapEncoder(t reflect.Type) encoderFunc {
switch t.Key().Kind() {
case reflect.String,
reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
default:
if !t.Key().Implements(textMarshalerType) {
return unsupportedTypeEncoder
}
}
me := mapEncoder{typeEncoder(t.Elem())}
return me.encode
}
func encodeByteSlice(e *encodeState, v reflect.Value, _ encOpts) {
if v.IsNil() {
e.WriteString("null")
return
}
s := v.Bytes()
b := e.AvailableBuffer()
b = append(b, '"')
b = base64.StdEncoding.AppendEncode(b, s)
b = append(b, '"')
e.Write(b)
}
// sliceEncoder just wraps an arrayEncoder, checking to make sure the value isn't nil.
type sliceEncoder struct {
arrayEnc encoderFunc
}
func (se sliceEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) {
if v.IsNil() {
e.WriteString("null")
return
}
if e.ptrLevel++; e.ptrLevel > startDetectingCyclesAfter {
// We're a large number of nested ptrEncoder.encode calls deep;
// start checking if we've run into a pointer cycle.
// Here we use a struct to remember the pointer to the first element of the slice
// and its length.
ptr := struct {
ptr any // always an unsafe.Pointer, but avoids a dependency on package unsafe
len int
}{v.UnsafePointer(), v.Len()}
if _, ok := e.ptrSeen[ptr]; ok {
e.error(&UnsupportedValueError{v, fmt.Sprintf("encountered a cycle via %s", v.Type())})
}
e.ptrSeen[ptr] = struct{}{}
defer delete(e.ptrSeen, ptr)
}
se.arrayEnc(e, v, opts)
e.ptrLevel--
}
func newSliceEncoder(t reflect.Type) encoderFunc {
// Byte slices get special treatment; arrays don't.
if t.Elem().Kind() == reflect.Uint8 {
p := reflect.PointerTo(t.Elem())
if !p.Implements(marshalerType) && !p.Implements(textMarshalerType) {
return encodeByteSlice
}
}
enc := sliceEncoder{newArrayEncoder(t)}
return enc.encode
}
type arrayEncoder struct {
elemEnc encoderFunc
}
func (ae arrayEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) {
e.WriteByte('[')
n := v.Len()
for i := 0; i < n; i++ {
if i > 0 {
e.WriteByte(',')
}
ae.elemEnc(e, v.Index(i), opts)
}
e.WriteByte(']')
}
func newArrayEncoder(t reflect.Type) encoderFunc {
enc := arrayEncoder{typeEncoder(t.Elem())}
return enc.encode
}
type ptrEncoder struct {
elemEnc encoderFunc
}
func (pe ptrEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) {
if v.IsNil() {
e.WriteString("null")
return
}
if e.ptrLevel++; e.ptrLevel > startDetectingCyclesAfter {
// We're a large number of nested ptrEncoder.encode calls deep;
// start checking if we've run into a pointer cycle.
ptr := v.Interface()
if _, ok := e.ptrSeen[ptr]; ok {
e.error(&UnsupportedValueError{v, fmt.Sprintf("encountered a cycle via %s", v.Type())})
}
e.ptrSeen[ptr] = struct{}{}
defer delete(e.ptrSeen, ptr)
}
pe.elemEnc(e, v.Elem(), opts)
e.ptrLevel--
}
func newPtrEncoder(t reflect.Type) encoderFunc {
enc := ptrEncoder{typeEncoder(t.Elem())}
return enc.encode
}
type condAddrEncoder struct {
canAddrEnc, elseEnc encoderFunc
}
func (ce condAddrEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) {
if v.CanAddr() {
ce.canAddrEnc(e, v, opts)
} else {
ce.elseEnc(e, v, opts)
}
}
// newCondAddrEncoder returns an encoder that checks whether its value
// CanAddr and delegates to canAddrEnc if so, else to elseEnc.
func newCondAddrEncoder(canAddrEnc, elseEnc encoderFunc) encoderFunc {
enc := condAddrEncoder{canAddrEnc: canAddrEnc, elseEnc: elseEnc}
return enc.encode
}
func isValidTag(s string) bool {
if s == "" {
return false
}
for _, c := range s {
switch {
case strings.ContainsRune("!#$%&()*+-./:;<=>?@[]^_{|}~ ", c):
// Backslash and quote chars are reserved, but
// otherwise any punctuation chars are allowed
// in a tag name.
case !unicode.IsLetter(c) && !unicode.IsDigit(c):
return false
}
}
return true
}
func typeByIndex(t reflect.Type, index []int) reflect.Type {
for _, i := range index {
if t.Kind() == reflect.Pointer {
t = t.Elem()
}
t = t.Field(i).Type
}
return t
}
type reflectWithString struct {
v reflect.Value
ks string
}
func resolveKeyName(k reflect.Value) (string, error) {
if k.Kind() == reflect.String {
return k.String(), nil
}
if tm, ok := k.Interface().(encoding.TextMarshaler); ok {
if k.Kind() == reflect.Pointer && k.IsNil() {
return "", nil
}
buf, err := tm.MarshalText()
return string(buf), err
}
switch k.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return strconv.FormatInt(k.Int(), 10), nil
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return strconv.FormatUint(k.Uint(), 10), nil
}
panic("unexpected map key type")
}
func appendString[Bytes []byte | string](dst []byte, src Bytes, escapeHTML bool) []byte {
dst = append(dst, '"')
start := 0
for i := 0; i < len(src); {
if b := src[i]; b < utf8.RuneSelf {
if htmlSafeSet[b] || (!escapeHTML && safeSet[b]) {
i++
continue
}
dst = append(dst, src[start:i]...)
switch b {
case '\\', '"':
dst = append(dst, '\\', b)
case '\b':
dst = append(dst, '\\', 'b')
case '\f':
dst = append(dst, '\\', 'f')
case '\n':
dst = append(dst, '\\', 'n')
case '\r':
dst = append(dst, '\\', 'r')
case '\t':
dst = append(dst, '\\', 't')
default:
// This encodes bytes < 0x20 except for \b, \f, \n, \r and \t.
// If escapeHTML is set, it also escapes <, >, and &
// because they can lead to security holes when
// user-controlled strings are rendered into JSON
// and served to some browsers.
dst = append(dst, '\\', 'u', '0', '0', hex[b>>4], hex[b&0xF])
}
i++
start = i
continue
}
// TODO(https://go.dev/issue/56948): Use generic utf8 functionality.
// For now, cast only a small portion of byte slices to a string
// so that it can be stack allocated. This slows down []byte slightly
// due to the extra copy, but keeps string performance roughly the same.
n := min(len(src)-i, utf8.UTFMax)
c, size := utf8.DecodeRuneInString(string(src[i : i+n]))
if c == utf8.RuneError && size == 1 {
dst = append(dst, src[start:i]...)
dst = append(dst, `\ufffd`...)
i += size
start = i
continue
}
// U+2028 is LINE SEPARATOR.
// U+2029 is PARAGRAPH SEPARATOR.
// They are both technically valid characters in JSON strings,
// but don't work in JSONP, which has to be evaluated as JavaScript,
// and can lead to security holes there. It is valid JSON to
// escape them, so we do so unconditionally.
// See https://en.wikipedia.org/wiki/JSON#Safety.
if c == '\u2028' || c == '\u2029' {
dst = append(dst, src[start:i]...)
dst = append(dst, '\\', 'u', '2', '0', '2', hex[c&0xF])
i += size
start = i
continue
}
i += size
}
dst = append(dst, src[start:]...)
dst = append(dst, '"')
return dst
}
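// exampleAppendStringSketch is a hedged sketch, not part of the original
// source, of the escaping rules above; the expected results in the comments
// are assumptions based on those rules.
func exampleAppendStringSketch() {
	_ = string(appendString(nil, "a\tb", false))   // `"a\tb"`: control characters use short escapes
	_ = string(appendString(nil, "<p>", true))     // `"\u003cp\u003e"`: HTML-sensitive bytes escaped
	_ = string(appendString(nil, "x\xffy", false)) // `"x\ufffdy"`: invalid UTF-8 becomes U+FFFD
}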
// A field represents a single field found in a struct.
type field struct {
name string
nameBytes []byte // []byte(name)
nameNonEsc string // `"` + name + `":`
nameEscHTML string // `"` + HTMLEscape(name) + `":`
tag bool
index []int
typ reflect.Type
omitEmpty bool
omitZero bool
isZero func(reflect.Value) bool
quoted bool
encoder encoderFunc
}
type isZeroer interface {
IsZero() bool
}
var isZeroerType = reflect.TypeFor[isZeroer]()
// typeFields returns a list of fields that JSON should recognize for the given type.
// The algorithm is breadth-first search over the set of structs to include - the top struct
// and then any reachable anonymous structs.
//
// typeFields should be an internal detail,
// but widely used packages access it using linkname.
// Notable members of the hall of shame include:
// - github.com/bytedance/sonic
//
// Do not remove or change the type signature.
// See go.dev/issue/67401.
//
//go:linkname typeFields
func typeFields(t reflect.Type) structFields {
// Anonymous fields to explore at the current level and the next.
current := []field{}
next := []field{{typ: t}}
// Count of queued names for current level and the next.
var count, nextCount map[reflect.Type]int
// Types already visited at an earlier level.
visited := map[reflect.Type]bool{}
// Fields found.
var fields []field
// Buffer to run appendHTMLEscape on field names.
var nameEscBuf []byte
for len(next) > 0 {
current, next = next, current[:0]
count, nextCount = nextCount, map[reflect.Type]int{}
for _, f := range current {
if visited[f.typ] {
continue
}
visited[f.typ] = true
// Scan f.typ for fields to include.
for i := 0; i < f.typ.NumField(); i++ {
sf := f.typ.Field(i)
if sf.Anonymous {
t := sf.Type
if t.Kind() == reflect.Pointer {
t = t.Elem()
}
if !sf.IsExported() && t.Kind() != reflect.Struct {
// Ignore embedded fields of unexported non-struct types.
continue
}
// Do not ignore embedded fields of unexported struct types
// since they may have exported fields.
} else if !sf.IsExported() {
// Ignore unexported non-embedded fields.
continue
}
tag := sf.Tag.Get("json")
if tag == "-" {
continue
}
name, opts := parseTag(tag)
if !isValidTag(name) {
name = ""
}
index := make([]int, len(f.index)+1)
copy(index, f.index)
index[len(f.index)] = i
ft := sf.Type
if ft.Name() == "" && ft.Kind() == reflect.Pointer {
// Follow pointer.
ft = ft.Elem()
}
// Only strings, floats, integers, and booleans can be quoted.
quoted := false
if opts.Contains("string") {
switch ft.Kind() {
case reflect.Bool,
reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr,
reflect.Float32, reflect.Float64,
reflect.String:
quoted = true
}
}
// Record found field and index sequence.
if name != "" || !sf.Anonymous || ft.Kind() != reflect.Struct {
tagged := name != ""
if name == "" {
name = sf.Name
}
field := field{
name: name,
tag: tagged,
index: index,
typ: ft,
omitEmpty: opts.Contains("omitempty"),
omitZero: opts.Contains("omitzero"),
quoted: quoted,
}
field.nameBytes = []byte(field.name)
// Build nameEscHTML and nameNonEsc ahead of time.
nameEscBuf = appendHTMLEscape(nameEscBuf[:0], field.nameBytes)
field.nameEscHTML = `"` + string(nameEscBuf) + `":`
field.nameNonEsc = `"` + field.name + `":`
if field.omitZero {
t := sf.Type
// Provide a function that uses a type's IsZero method.
switch {
case t.Kind() == reflect.Interface && t.Implements(isZeroerType):
field.isZero = func(v reflect.Value) bool {
// Avoid panics calling IsZero on a nil interface or
// non-nil interface with nil pointer.
return v.IsNil() ||
(v.Elem().Kind() == reflect.Pointer && v.Elem().IsNil()) ||
v.Interface().(isZeroer).IsZero()
}
case t.Kind() == reflect.Pointer && t.Implements(isZeroerType):
field.isZero = func(v reflect.Value) bool {
// Avoid panics calling IsZero on nil pointer.
return v.IsNil() || v.Interface().(isZeroer).IsZero()
}
case t.Implements(isZeroerType):
field.isZero = func(v reflect.Value) bool {
return v.Interface().(isZeroer).IsZero()
}
case reflect.PointerTo(t).Implements(isZeroerType):
field.isZero = func(v reflect.Value) bool {
if !v.CanAddr() {
// Temporarily box v so we can take the address.
v2 := reflect.New(v.Type()).Elem()
v2.Set(v)
v = v2
}
return v.Addr().Interface().(isZeroer).IsZero()
}
}
}
fields = append(fields, field)
if count[f.typ] > 1 {
// If there were multiple instances, add a second,
// so that the annihilation code will see a duplicate.
// It only cares about the distinction between 1 and 2,
// so don't bother generating any more copies.
fields = append(fields, fields[len(fields)-1])
}
continue
}
// Record new anonymous struct to explore in next round.
nextCount[ft]++
if nextCount[ft] == 1 {
next = append(next, field{name: ft.Name(), index: index, typ: ft})
}
}
}
}
slices.SortFunc(fields, func(a, b field) int {
// sort fields by name, breaking ties with depth, then
// breaking ties with "name came from json tag", then
// breaking ties with index sequence.
if c := strings.Compare(a.name, b.name); c != 0 {
return c
}
if c := cmp.Compare(len(a.index), len(b.index)); c != 0 {
return c
}
if a.tag != b.tag {
if a.tag {
return -1
}
return +1
}
return slices.Compare(a.index, b.index)
})
// Delete all fields that are hidden by the Go rules for embedded fields,
// except that fields with JSON tags are promoted.
// The fields are sorted in primary order of name, secondary order
// of field index length. Loop over names; for each name, delete
// hidden fields by choosing the one dominant field that survives.
out := fields[:0]
for advance, i := 0, 0; i < len(fields); i += advance {
// One iteration per name.
// Find the sequence of fields with the name of this first field.
fi := fields[i]
name := fi.name
for advance = 1; i+advance < len(fields); advance++ {
fj := fields[i+advance]
if fj.name != name {
break
}
}
if advance == 1 { // Only one field with this name
out = append(out, fi)
continue
}
dominant, ok := dominantField(fields[i : i+advance])
if ok {
out = append(out, dominant)
}
}
fields = out
slices.SortFunc(fields, func(i, j field) int {
return slices.Compare(i.index, j.index)
})
for i := range fields {
f := &fields[i]
f.encoder = typeEncoder(typeByIndex(t, f.index))
}
exactNameIndex := make(map[string]*field, len(fields))
foldedNameIndex := make(map[string]*field, len(fields))
for i, field := range fields {
exactNameIndex[field.name] = &fields[i]
// For historical reasons, first folded match takes precedence.
if _, ok := foldedNameIndex[string(foldName(field.nameBytes))]; !ok {
foldedNameIndex[string(foldName(field.nameBytes))] = &fields[i]
}
}
return structFields{fields, exactNameIndex, foldedNameIndex}
}
// dominantField looks through the fields, all of which are known to
// have the same name, to find the single field that dominates the
// others using Go's embedding rules, modified by the presence of
// JSON tags. If there are multiple top-level fields, the boolean
// will be false: This condition is an error in Go and we skip all
// the fields.
func dominantField(fields []field) (field, bool) {
// The fields are sorted in increasing index-length order, then by presence of tag.
// That means that the first field is the dominant one. We need only check
// for error cases: two fields at top level, either both tagged or neither tagged.
if len(fields) > 1 && len(fields[0].index) == len(fields[1].index) && fields[0].tag == fields[1].tag {
return field{}, false
}
return fields[0], true
}
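// exampleFieldDominanceSketch is a hedged sketch, not part of the original
// source, of the visibility rules applied by typeFields and dominantField;
// the inner and outer types are assumptions for demonstration.
func exampleFieldDominanceSketch() ([]byte, error) {
	type inner struct{ Name string }
	type outer struct {
		inner
		Name string // shallower than the promoted inner.Name, so it dominates
	}
	// Expected encoding: {"Name":"top"}; the embedded inner.Name is hidden.
	return Marshal(outer{inner: inner{Name: "embedded"}, Name: "top"})
}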
var fieldCache sync.Map // map[reflect.Type]structFields
// cachedTypeFields is like typeFields but uses a cache to avoid repeated work.
func cachedTypeFields(t reflect.Type) structFields {
if f, ok := fieldCache.Load(t); ok {
return f.(structFields)
}
f, _ := fieldCache.LoadOrStore(t, typeFields(t))
return f.(structFields)
}
func mayAppendQuote(b []byte, quoted bool) []byte {
if quoted {
b = append(b, '"')
}
return b
}
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !goexperiment.jsonv2
package json
import (
"unicode"
"unicode/utf8"
)
// foldName returns a folded string such that foldName(x) == foldName(y)
// is identical to bytes.EqualFold(x, y).
func foldName(in []byte) []byte {
// This is inlinable to take advantage of "function outlining".
var arr [32]byte // large enough for most JSON names
return appendFoldedName(arr[:0], in)
}
func appendFoldedName(out, in []byte) []byte {
for i := 0; i < len(in); {
// Handle single-byte ASCII.
if c := in[i]; c < utf8.RuneSelf {
if 'a' <= c && c <= 'z' {
c -= 'a' - 'A'
}
out = append(out, c)
i++
continue
}
// Handle multi-byte Unicode.
r, n := utf8.DecodeRune(in[i:])
out = utf8.AppendRune(out, foldRune(r))
i += n
}
return out
}
// foldRune returns the smallest rune for all runes in the same fold set.
func foldRune(r rune) rune {
for {
r2 := unicode.SimpleFold(r)
if r2 <= r {
return r2
}
r = r2
}
}
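// exampleFoldNameSketch is a hedged sketch, not part of the original source:
// two names that differ only per bytes.EqualFold fold to the same byte string.
func exampleFoldNameSketch() bool {
	// Both fold to "HTTPSERVER", so the comparison is expected to be true.
	return string(foldName([]byte("HTTPServer"))) == string(foldName([]byte("httpserver")))
}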
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !goexperiment.jsonv2
package json
import "bytes"
// HTMLEscape appends to dst the JSON-encoded src with <, >, &, U+2028 and U+2029
// characters inside string literals changed to \u003c, \u003e, \u0026, \u2028, \u2029
// so that the JSON will be safe to embed inside HTML <script> tags.
// For historical reasons, web browsers don't honor standard HTML
// escaping within <script> tags, so an alternative JSON encoding must be used.
func HTMLEscape(dst *bytes.Buffer, src []byte) {
dst.Grow(len(src))
dst.Write(appendHTMLEscape(dst.AvailableBuffer(), src))
}
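// exampleHTMLEscapeSketch is a hedged usage sketch, not part of the original
// source, showing the substitutions HTMLEscape performs on a small document.
func exampleHTMLEscapeSketch() string {
	var buf bytes.Buffer
	HTMLEscape(&buf, []byte(`{"html":"<b>&</b>"}`))
	// Expected contents: {"html":"\u003cb\u003e\u0026\u003c/b\u003e"}
	return buf.String()
}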
func appendHTMLEscape(dst, src []byte) []byte {
// The characters can only appear in string literals,
// so just scan the string one byte at a time.
start := 0
for i, c := range src {
if c == '<' || c == '>' || c == '&' {
dst = append(dst, src[start:i]...)
dst = append(dst, '\\', 'u', '0', '0', hex[c>>4], hex[c&0xF])
start = i + 1
}
// Convert U+2028 and U+2029 (E2 80 A8 and E2 80 A9).
if c == 0xE2 && i+2 < len(src) && src[i+1] == 0x80 && src[i+2]&^1 == 0xA8 {
dst = append(dst, src[start:i]...)
dst = append(dst, '\\', 'u', '2', '0', '2', hex[src[i+2]&0xF])
start = i + len("\u2029")
}
}
return append(dst, src[start:]...)
}
// Compact appends to dst the JSON-encoded src with
// insignificant space characters elided.
func Compact(dst *bytes.Buffer, src []byte) error {
dst.Grow(len(src))
b := dst.AvailableBuffer()
b, err := appendCompact(b, src, false)
dst.Write(b)
return err
}
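// exampleCompactSketch is a hedged usage sketch, not part of the original
// source, showing Compact removing insignificant whitespace.
func exampleCompactSketch() (string, error) {
	var dst bytes.Buffer
	err := Compact(&dst, []byte("{\n\t\"a\": [1, 2, 3]\n}"))
	// Expected contents of dst: {"a":[1,2,3]}
	return dst.String(), err
}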
func appendCompact(dst, src []byte, escape bool) ([]byte, error) {
origLen := len(dst)
scan := newScanner()
defer freeScanner(scan)
start := 0
for i, c := range src {
if escape && (c == '<' || c == '>' || c == '&') {
if start < i {
dst = append(dst, src[start:i]...)
}
dst = append(dst, '\\', 'u', '0', '0', hex[c>>4], hex[c&0xF])
start = i + 1
}
// Convert U+2028 and U+2029 (E2 80 A8 and E2 80 A9).
if escape && c == 0xE2 && i+2 < len(src) && src[i+1] == 0x80 && src[i+2]&^1 == 0xA8 {
if start < i {
dst = append(dst, src[start:i]...)
}
dst = append(dst, '\\', 'u', '2', '0', '2', hex[src[i+2]&0xF])
start = i + 3
}
v := scan.step(scan, c)
if v >= scanSkipSpace {
if v == scanError {
break
}
if start < i {
dst = append(dst, src[start:i]...)
}
start = i + 1
}
}
if scan.eof() == scanError {
return dst[:origLen], scan.err
}
if start < len(src) {
dst = append(dst, src[start:]...)
}
return dst, nil
}
func appendNewline(dst []byte, prefix, indent string, depth int) []byte {
dst = append(dst, '\n')
dst = append(dst, prefix...)
for i := 0; i < depth; i++ {
dst = append(dst, indent...)
}
return dst
}
// indentGrowthFactor specifies the growth factor of indenting JSON input.
// Empirically, the growth factor was measured to be between 1.4x and 1.8x
// for some set of compacted JSON with the indent being a single tab.
// Specify a growth factor slightly larger than what is observed
// to reduce probability of allocation in appendIndent.
// A factor no higher than 2 ensures that wasted space never exceeds 50%.
const indentGrowthFactor = 2
// Indent appends to dst an indented form of the JSON-encoded src.
// Each element in a JSON object or array begins on a new,
// indented line beginning with prefix followed by one or more
// copies of indent according to the indentation nesting.
// The data appended to dst does not begin with the prefix nor
// any indentation, to make it easier to embed inside other formatted JSON data.
// Although leading space characters (space, tab, carriage return, newline)
// at the beginning of src are dropped, trailing space characters
// at the end of src are preserved and copied to dst.
// For example, if src has no trailing spaces, neither will dst;
// if src ends in a trailing newline, so will dst.
func Indent(dst *bytes.Buffer, src []byte, prefix, indent string) error {
dst.Grow(indentGrowthFactor * len(src))
b := dst.AvailableBuffer()
b, err := appendIndent(b, src, prefix, indent)
dst.Write(b)
return err
}
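// exampleIndentSketch is a hedged usage sketch, not part of the original
// source: it re-indents a compact document with a two-space indent and no
// prefix.
func exampleIndentSketch() (string, error) {
	var dst bytes.Buffer
	err := Indent(&dst, []byte(`{"a":[1,2]}`), "", "  ")
	// Expected contents of dst:
	// {
	//   "a": [
	//     1,
	//     2
	//   ]
	// }
	return dst.String(), err
}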
func appendIndent(dst, src []byte, prefix, indent string) ([]byte, error) {
origLen := len(dst)
scan := newScanner()
defer freeScanner(scan)
needIndent := false
depth := 0
for _, c := range src {
scan.bytes++
v := scan.step(scan, c)
if v == scanSkipSpace {
continue
}
if v == scanError {
break
}
if needIndent && v != scanEndObject && v != scanEndArray {
needIndent = false
depth++
dst = appendNewline(dst, prefix, indent, depth)
}
// Emit semantically uninteresting bytes
// (in particular, punctuation in strings) unmodified.
if v == scanContinue {
dst = append(dst, c)
continue
}
// Add spacing around real punctuation.
switch c {
case '{', '[':
// delay indent so that empty object and array are formatted as {} and [].
needIndent = true
dst = append(dst, c)
case ',':
dst = append(dst, c)
dst = appendNewline(dst, prefix, indent, depth)
case ':':
dst = append(dst, c, ' ')
case '}', ']':
if needIndent {
// suppress indent in empty object/array
needIndent = false
} else {
depth--
dst = appendNewline(dst, prefix, indent, depth)
}
dst = append(dst, c)
default:
dst = append(dst, c)
}
}
if scan.eof() == scanError {
return dst[:origLen], scan.err
}
return dst, nil
}
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !goexperiment.jsonv2
package json
// JSON value parser state machine.
// Just about at the limit of what is reasonable to write by hand.
// Some parts are a bit tedious, but overall it nicely factors out the
// otherwise common code from the multiple scanning functions
// in this package (Compact, Indent, checkValid, etc).
//
// This file starts with two simple examples using the scanner
// before diving into the scanner itself.
import (
"strconv"
"sync"
)
// Valid reports whether data is a valid JSON encoding.
func Valid(data []byte) bool {
scan := newScanner()
defer freeScanner(scan)
return checkValid(data, scan) == nil
}
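// exampleValidSketch is a hedged sketch, not part of the original source,
// showing Valid on well-formed and malformed input.
func exampleValidSketch() {
	_ = Valid([]byte(`{"ok":true}`)) // true
	_ = Valid([]byte(`{"ok":true`))  // false: unexpected end of JSON input
	_ = Valid([]byte(`[1}`))         // false: mismatched delimiters
}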
// checkValid verifies that data is valid JSON-encoded data.
// scan is passed in for use by checkValid to avoid an allocation.
// checkValid returns nil or a SyntaxError.
func checkValid(data []byte, scan *scanner) error {
scan.reset()
for _, c := range data {
scan.bytes++
if scan.step(scan, c) == scanError {
return scan.err
}
}
if scan.eof() == scanError {
return scan.err
}
return nil
}
// A SyntaxError is a description of a JSON syntax error.
// [Unmarshal] will return a SyntaxError if the JSON can't be parsed.
type SyntaxError struct {
msg string // description of error
Offset int64 // error occurred after reading Offset bytes
}
func (e *SyntaxError) Error() string { return e.msg }
// A scanner is a JSON scanning state machine.
// Callers call scan.reset and then pass bytes in one at a time
// by calling scan.step(&scan, c) for each byte.
// The return value, referred to as an opcode, tells the
// caller about significant parsing events like beginning
// and ending literals, objects, and arrays, so that the
// caller can follow along if it wishes.
// The return value scanEnd indicates that a single top-level
// JSON value has been completed, *before* the byte that
// just got passed in. (The indication must be delayed in order
// to recognize the end of numbers: is 123 a whole value or
// the beginning of 12345e+6?).
type scanner struct {
// The step is a func to be called to execute the next transition.
// Also tried using an integer constant and a single func
// with a switch, but using the func directly was 10% faster
// on a 64-bit Mac Mini, and it's nicer to read.
step func(*scanner, byte) int
// Reached end of top-level value.
endTop bool
// Stack of what we're in the middle of - array values, object keys, object values.
parseState []int
// Error that happened, if any.
err error
// total bytes consumed, updated by decoder.Decode (and deliberately
// not set to zero by scan.reset)
bytes int64
}
var scannerPool = sync.Pool{
New: func() any {
return &scanner{}
},
}
func newScanner() *scanner {
scan := scannerPool.Get().(*scanner)
// scan.reset by design doesn't set bytes to zero
scan.bytes = 0
scan.reset()
return scan
}
func freeScanner(scan *scanner) {
// Avoid hanging on to too much memory in extreme cases.
if len(scan.parseState) > 1024 {
scan.parseState = nil
}
scannerPool.Put(scan)
}
// These values are returned by the state transition functions
// assigned to scanner.state and the method scanner.eof.
// They give details about the current state of the scan that
// callers might be interested to know about.
// It is okay to ignore the return value of any particular
// call to scanner.state: if one call returns scanError,
// every subsequent call will return scanError too.
const (
// Continue.
scanContinue = iota // uninteresting byte
scanBeginLiteral // end implied by next result != scanContinue
scanBeginObject // begin object
scanObjectKey // just finished object key (string)
scanObjectValue // just finished non-last object value
scanEndObject // end object (implies scanObjectValue if possible)
scanBeginArray // begin array
scanArrayValue // just finished array value
scanEndArray // end array (implies scanArrayValue if possible)
scanSkipSpace // space byte; can skip; known to be last "continue" result
// Stop.
scanEnd // top-level value ended *before* this byte; known to be first "stop" result
scanError // hit an error, scanner.err.
)
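// exampleScanOpcodesSketch is a hedged sketch, not part of the original
// source: it drives the scanner byte by byte and counts the scanObjectKey
// opcodes, illustrating how a caller can follow along with the events above.
func exampleScanOpcodesSketch(data []byte) (keys int, err error) {
	scan := newScanner()
	defer freeScanner(scan)
	for _, c := range data {
		switch scan.step(scan, c) {
		case scanObjectKey:
			keys++ // reported once the colon ending an object key is consumed
		case scanError:
			return 0, scan.err
		}
	}
	if scan.eof() == scanError {
		return 0, scan.err
	}
	return keys, nil
}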
// These values are stored in the parseState stack.
// They give the current state of a composite value
// being scanned. If the parser is inside a nested value
// the parseState describes the nested state, outermost at entry 0.
const (
parseObjectKey = iota // parsing object key (before colon)
parseObjectValue // parsing object value (after colon)
parseArrayValue // parsing array value
)
// This limits the max nesting depth to prevent stack overflow.
// This is permitted by https://tools.ietf.org/html/rfc7159#section-9
const maxNestingDepth = 10000
// reset prepares the scanner for use.
// It must be called before calling s.step.
func (s *scanner) reset() {
s.step = stateBeginValue
s.parseState = s.parseState[0:0]
s.err = nil
s.endTop = false
}
// eof tells the scanner that the end of input has been reached.
// It returns a scan status just as s.step does.
func (s *scanner) eof() int {
if s.err != nil {
return scanError
}
if s.endTop {
return scanEnd
}
s.step(s, ' ')
if s.endTop {
return scanEnd
}
if s.err == nil {
s.err = &SyntaxError{"unexpected end of JSON input", s.bytes}
}
return scanError
}
// pushParseState pushes a new parse state newParseState onto the parse stack.
// An error state is returned if maxNestingDepth is exceeded; otherwise successState is returned.
func (s *scanner) pushParseState(c byte, newParseState int, successState int) int {
s.parseState = append(s.parseState, newParseState)
if len(s.parseState) <= maxNestingDepth {
return successState
}
return s.error(c, "exceeded max depth")
}
// popParseState pops a parse state (already obtained) off the stack
// and updates s.step accordingly.
func (s *scanner) popParseState() {
n := len(s.parseState) - 1
s.parseState = s.parseState[0:n]
if n == 0 {
s.step = stateEndTop
s.endTop = true
} else {
s.step = stateEndValue
}
}
func isSpace(c byte) bool {
return c <= ' ' && (c == ' ' || c == '\t' || c == '\r' || c == '\n')
}
// stateBeginValueOrEmpty is the state after reading `[`.
func stateBeginValueOrEmpty(s *scanner, c byte) int {
if isSpace(c) {
return scanSkipSpace
}
if c == ']' {
return stateEndValue(s, c)
}
return stateBeginValue(s, c)
}
// stateBeginValue is the state at the beginning of the input.
func stateBeginValue(s *scanner, c byte) int {
if isSpace(c) {
return scanSkipSpace
}
switch c {
case '{':
s.step = stateBeginStringOrEmpty
return s.pushParseState(c, parseObjectKey, scanBeginObject)
case '[':
s.step = stateBeginValueOrEmpty
return s.pushParseState(c, parseArrayValue, scanBeginArray)
case '"':
s.step = stateInString
return scanBeginLiteral
case '-':
s.step = stateNeg
return scanBeginLiteral
case '0': // beginning of 0.123
s.step = state0
return scanBeginLiteral
case 't': // beginning of true
s.step = stateT
return scanBeginLiteral
case 'f': // beginning of false
s.step = stateF
return scanBeginLiteral
case 'n': // beginning of null
s.step = stateN
return scanBeginLiteral
}
if '1' <= c && c <= '9' { // beginning of 1234.5
s.step = state1
return scanBeginLiteral
}
return s.error(c, "looking for beginning of value")
}
// stateBeginStringOrEmpty is the state after reading `{`.
func stateBeginStringOrEmpty(s *scanner, c byte) int {
if isSpace(c) {
return scanSkipSpace
}
if c == '}' {
n := len(s.parseState)
s.parseState[n-1] = parseObjectValue
return stateEndValue(s, c)
}
return stateBeginString(s, c)
}
// stateBeginString is the state after reading `{"key": value,`.
func stateBeginString(s *scanner, c byte) int {
if isSpace(c) {
return scanSkipSpace
}
if c == '"' {
s.step = stateInString
return scanBeginLiteral
}
return s.error(c, "looking for beginning of object key string")
}
// stateEndValue is the state after completing a value,
// such as after reading `{}` or `true` or `["x"`.
func stateEndValue(s *scanner, c byte) int {
n := len(s.parseState)
if n == 0 {
// Completed top-level before the current byte.
s.step = stateEndTop
s.endTop = true
return stateEndTop(s, c)
}
if isSpace(c) {
s.step = stateEndValue
return scanSkipSpace
}
ps := s.parseState[n-1]
switch ps {
case parseObjectKey:
if c == ':' {
s.parseState[n-1] = parseObjectValue
s.step = stateBeginValue
return scanObjectKey
}
return s.error(c, "after object key")
case parseObjectValue:
if c == ',' {
s.parseState[n-1] = parseObjectKey
s.step = stateBeginString
return scanObjectValue
}
if c == '}' {
s.popParseState()
return scanEndObject
}
return s.error(c, "after object key:value pair")
case parseArrayValue:
if c == ',' {
s.step = stateBeginValue
return scanArrayValue
}
if c == ']' {
s.popParseState()
return scanEndArray
}
return s.error(c, "after array element")
}
return s.error(c, "")
}
// stateEndTop is the state after finishing the top-level value,
// such as after reading `{}` or `[1,2,3]`.
// Only space characters should be seen now.
func stateEndTop(s *scanner, c byte) int {
if !isSpace(c) {
// Complain about non-space byte on next call.
s.error(c, "after top-level value")
}
return scanEnd
}
// stateInString is the state after reading `"`.
func stateInString(s *scanner, c byte) int {
if c == '"' {
s.step = stateEndValue
return scanContinue
}
if c == '\\' {
s.step = stateInStringEsc
return scanContinue
}
if c < 0x20 {
return s.error(c, "in string literal")
}
return scanContinue
}
// stateInStringEsc is the state after reading `"\` during a quoted string.
func stateInStringEsc(s *scanner, c byte) int {
switch c {
case 'b', 'f', 'n', 'r', 't', '\\', '/', '"':
s.step = stateInString
return scanContinue
case 'u':
s.step = stateInStringEscU
return scanContinue
}
return s.error(c, "in string escape code")
}
// stateInStringEscU is the state after reading `"\u` during a quoted string.
func stateInStringEscU(s *scanner, c byte) int {
if '0' <= c && c <= '9' || 'a' <= c && c <= 'f' || 'A' <= c && c <= 'F' {
s.step = stateInStringEscU1
return scanContinue
}
// numbers
return s.error(c, "in \\u hexadecimal character escape")
}
// stateInStringEscU1 is the state after reading `"\u1` during a quoted string.
func stateInStringEscU1(s *scanner, c byte) int {
if '0' <= c && c <= '9' || 'a' <= c && c <= 'f' || 'A' <= c && c <= 'F' {
s.step = stateInStringEscU12
return scanContinue
}
// numbers
return s.error(c, "in \\u hexadecimal character escape")
}
// stateInStringEscU12 is the state after reading `"\u12` during a quoted string.
func stateInStringEscU12(s *scanner, c byte) int {
if '0' <= c && c <= '9' || 'a' <= c && c <= 'f' || 'A' <= c && c <= 'F' {
s.step = stateInStringEscU123
return scanContinue
}
// numbers
return s.error(c, "in \\u hexadecimal character escape")
}
// stateInStringEscU123 is the state after reading `"\u123` during a quoted string.
func stateInStringEscU123(s *scanner, c byte) int {
if '0' <= c && c <= '9' || 'a' <= c && c <= 'f' || 'A' <= c && c <= 'F' {
s.step = stateInString
return scanContinue
}
// numbers
return s.error(c, "in \\u hexadecimal character escape")
}
// stateNeg is the state after reading `-` during a number.
func stateNeg(s *scanner, c byte) int {
if c == '0' {
s.step = state0
return scanContinue
}
if '1' <= c && c <= '9' {
s.step = state1
return scanContinue
}
return s.error(c, "in numeric literal")
}
// state1 is the state after reading a non-zero integer during a number,
// such as after reading `1` or `100` but not `0`.
func state1(s *scanner, c byte) int {
if '0' <= c && c <= '9' {
s.step = state1
return scanContinue
}
return state0(s, c)
}
// state0 is the state after reading `0` during a number.
func state0(s *scanner, c byte) int {
if c == '.' {
s.step = stateDot
return scanContinue
}
if c == 'e' || c == 'E' {
s.step = stateE
return scanContinue
}
return stateEndValue(s, c)
}
// stateDot is the state after reading the integer and decimal point in a number,
// such as after reading `1.`.
func stateDot(s *scanner, c byte) int {
if '0' <= c && c <= '9' {
s.step = stateDot0
return scanContinue
}
return s.error(c, "after decimal point in numeric literal")
}
// stateDot0 is the state after reading the integer, decimal point, and subsequent
// digits of a number, such as after reading `3.14`.
func stateDot0(s *scanner, c byte) int {
if '0' <= c && c <= '9' {
return scanContinue
}
if c == 'e' || c == 'E' {
s.step = stateE
return scanContinue
}
return stateEndValue(s, c)
}
// stateE is the state after reading the mantissa and e in a number,
// such as after reading `314e` or `0.314e`.
func stateE(s *scanner, c byte) int {
if c == '+' || c == '-' {
s.step = stateESign
return scanContinue
}
return stateESign(s, c)
}
// stateESign is the state after reading the mantissa, e, and sign in a number,
// such as after reading `314e-` or `0.314e+`.
func stateESign(s *scanner, c byte) int {
if '0' <= c && c <= '9' {
s.step = stateE0
return scanContinue
}
return s.error(c, "in exponent of numeric literal")
}
// stateE0 is the state after reading the mantissa, e, optional sign,
// and at least one digit of the exponent in a number,
// such as after reading `314e-2` or `0.314e+1` or `3.14e0`.
func stateE0(s *scanner, c byte) int {
if '0' <= c && c <= '9' {
return scanContinue
}
return stateEndValue(s, c)
}
// stateT is the state after reading `t`.
func stateT(s *scanner, c byte) int {
if c == 'r' {
s.step = stateTr
return scanContinue
}
return s.error(c, "in literal true (expecting 'r')")
}
// stateTr is the state after reading `tr`.
func stateTr(s *scanner, c byte) int {
if c == 'u' {
s.step = stateTru
return scanContinue
}
return s.error(c, "in literal true (expecting 'u')")
}
// stateTru is the state after reading `tru`.
func stateTru(s *scanner, c byte) int {
if c == 'e' {
s.step = stateEndValue
return scanContinue
}
return s.error(c, "in literal true (expecting 'e')")
}
// stateF is the state after reading `f`.
func stateF(s *scanner, c byte) int {
if c == 'a' {
s.step = stateFa
return scanContinue
}
return s.error(c, "in literal false (expecting 'a')")
}
// stateFa is the state after reading `fa`.
func stateFa(s *scanner, c byte) int {
if c == 'l' {
s.step = stateFal
return scanContinue
}
return s.error(c, "in literal false (expecting 'l')")
}
// stateFal is the state after reading `fal`.
func stateFal(s *scanner, c byte) int {
if c == 's' {
s.step = stateFals
return scanContinue
}
return s.error(c, "in literal false (expecting 's')")
}
// stateFals is the state after reading `fals`.
func stateFals(s *scanner, c byte) int {
if c == 'e' {
s.step = stateEndValue
return scanContinue
}
return s.error(c, "in literal false (expecting 'e')")
}
// stateN is the state after reading `n`.
func stateN(s *scanner, c byte) int {
if c == 'u' {
s.step = stateNu
return scanContinue
}
return s.error(c, "in literal null (expecting 'u')")
}
// stateNu is the state after reading `nu`.
func stateNu(s *scanner, c byte) int {
if c == 'l' {
s.step = stateNul
return scanContinue
}
return s.error(c, "in literal null (expecting 'l')")
}
// stateNul is the state after reading `nul`.
func stateNul(s *scanner, c byte) int {
if c == 'l' {
s.step = stateEndValue
return scanContinue
}
return s.error(c, "in literal null (expecting 'l')")
}
// stateError is the state after reaching a syntax error,
// such as after reading `[1}` or `5.1.2`.
func stateError(s *scanner, c byte) int {
return scanError
}
// error records an error and switches to the error state.
func (s *scanner) error(c byte, context string) int {
s.step = stateError
s.err = &SyntaxError{"invalid character " + quoteChar(c) + " " + context, s.bytes}
return scanError
}
// quoteChar formats c as a quoted character literal.
func quoteChar(c byte) string {
// special cases - different from quoted strings
if c == '\'' {
return `'\''`
}
if c == '"' {
return `'"'`
}
// use quoted string with different quotation marks
s := strconv.Quote(string(c))
return "'" + s[1:len(s)-1] + "'"
}
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !goexperiment.jsonv2
package json
import (
"bytes"
"errors"
"io"
)
// A Decoder reads and decodes JSON values from an input stream.
type Decoder struct {
r io.Reader
buf []byte
d decodeState
scanp int // start of unread data in buf
scanned int64 // amount of data already scanned
scan scanner
err error
tokenState int
tokenStack []int
}
// NewDecoder returns a new decoder that reads from r.
//
// The decoder introduces its own buffering and may
// read data from r beyond the JSON values requested.
func NewDecoder(r io.Reader) *Decoder {
return &Decoder{r: r}
}
// UseNumber causes the Decoder to unmarshal a number into an
// interface value as a [Number] instead of as a float64.
func (dec *Decoder) UseNumber() { dec.d.useNumber = true }
// DisallowUnknownFields causes the Decoder to return an error when the destination
// is a struct and the input contains object keys which do not match any
// non-ignored, exported fields in the destination.
func (dec *Decoder) DisallowUnknownFields() { dec.d.disallowUnknownFields = true }
// Decode reads the next JSON-encoded value from its
// input and stores it in the value pointed to by v.
//
// See the documentation for [Unmarshal] for details about
// the conversion of JSON into a Go value.
func (dec *Decoder) Decode(v any) error {
if dec.err != nil {
return dec.err
}
if err := dec.tokenPrepareForDecode(); err != nil {
return err
}
if !dec.tokenValueAllowed() {
return &SyntaxError{msg: "not at beginning of value", Offset: dec.InputOffset()}
}
// Read whole value into buffer.
n, err := dec.readValue()
if err != nil {
return err
}
dec.d.init(dec.buf[dec.scanp : dec.scanp+n])
dec.scanp += n
// Don't save err from unmarshal into dec.err:
// the connection is still usable since we read a complete JSON
// object from it before the error happened.
err = dec.d.unmarshal(v)
// fixup token streaming state
dec.tokenValueEnd()
return err
}
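// exampleDecoderSketch is a hedged usage sketch, not part of the original
// source: it decodes a stream of concatenated JSON values one at a time,
// stopping cleanly at io.EOF.
func exampleDecoderSketch() ([]int, error) {
	dec := NewDecoder(bytes.NewReader([]byte("1 2 3")))
	var out []int
	for {
		var n int
		if err := dec.Decode(&n); err == io.EOF {
			return out, nil // out is expected to be [1 2 3]
		} else if err != nil {
			return nil, err
		}
		out = append(out, n)
	}
}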
// Buffered returns a reader of the data remaining in the Decoder's
// buffer. The reader is valid until the next call to [Decoder.Decode].
func (dec *Decoder) Buffered() io.Reader {
return bytes.NewReader(dec.buf[dec.scanp:])
}
// readValue reads a JSON value into dec.buf.
// It returns the length of the encoding.
func (dec *Decoder) readValue() (int, error) {
dec.scan.reset()
scanp := dec.scanp
var err error
Input:
// help the compiler see that scanp is never negative, so it can remove
// some bounds checks below.
for scanp >= 0 {
// Look in the buffer for a new value.
for ; scanp < len(dec.buf); scanp++ {
c := dec.buf[scanp]
dec.scan.bytes++
switch dec.scan.step(&dec.scan, c) {
case scanEnd:
// scanEnd is delayed one byte so we decrement
// the scanner bytes count by 1 to ensure that
// this value is correct in the next call of Decode.
dec.scan.bytes--
break Input
case scanEndObject, scanEndArray:
// scanEnd is delayed one byte.
// We might block trying to get that byte from src,
// so instead invent a space byte.
if stateEndValue(&dec.scan, ' ') == scanEnd {
scanp++
break Input
}
case scanError:
dec.err = dec.scan.err
return 0, dec.scan.err
}
}
// Did the last read have an error?
// Delayed until now to allow buffer scan.
if err != nil {
if err == io.EOF {
if dec.scan.step(&dec.scan, ' ') == scanEnd {
break Input
}
if nonSpace(dec.buf) {
err = io.ErrUnexpectedEOF
}
}
dec.err = err
return 0, err
}
n := scanp - dec.scanp
err = dec.refill()
scanp = dec.scanp + n
}
return scanp - dec.scanp, nil
}
func (dec *Decoder) refill() error {
// Make room to read more into the buffer.
// First slide down data already consumed.
if dec.scanp > 0 {
dec.scanned += int64(dec.scanp)
n := copy(dec.buf, dec.buf[dec.scanp:])
dec.buf = dec.buf[:n]
dec.scanp = 0
}
// Grow buffer if not large enough.
const minRead = 512
if cap(dec.buf)-len(dec.buf) < minRead {
newBuf := make([]byte, len(dec.buf), 2*cap(dec.buf)+minRead)
copy(newBuf, dec.buf)
dec.buf = newBuf
}
// Read. Delay error for next iteration (after scan).
n, err := dec.r.Read(dec.buf[len(dec.buf):cap(dec.buf)])
dec.buf = dec.buf[0 : len(dec.buf)+n]
return err
}
func nonSpace(b []byte) bool {
for _, c := range b {
if !isSpace(c) {
return true
}
}
return false
}
// An Encoder writes JSON values to an output stream.
type Encoder struct {
w io.Writer
err error
escapeHTML bool
indentBuf []byte
indentPrefix string
indentValue string
}
// NewEncoder returns a new encoder that writes to w.
func NewEncoder(w io.Writer) *Encoder {
return &Encoder{w: w, escapeHTML: true}
}
// Encode writes the JSON encoding of v to the stream,
// with insignificant space characters elided,
// followed by a newline character.
//
// See the documentation for [Marshal] for details about the
// conversion of Go values to JSON.
func (enc *Encoder) Encode(v any) error {
if enc.err != nil {
return enc.err
}
e := newEncodeState()
defer encodeStatePool.Put(e)
err := e.marshal(v, encOpts{escapeHTML: enc.escapeHTML})
if err != nil {
return err
}
// Terminate each value with a newline.
// This makes the output look a little nicer
// when debugging, and some kind of space
// is required if the encoded value was a number,
// so that the reader knows there aren't more
// digits coming.
e.WriteByte('\n')
b := e.Bytes()
if enc.indentPrefix != "" || enc.indentValue != "" {
enc.indentBuf, err = appendIndent(enc.indentBuf[:0], b, enc.indentPrefix, enc.indentValue)
if err != nil {
return err
}
b = enc.indentBuf
}
if _, err = enc.w.Write(b); err != nil {
enc.err = err
}
return err
}
// SetIndent instructs the encoder to format each subsequent encoded
// value as if indented by the package-level function Indent(dst, src, prefix, indent).
// Calling SetIndent("", "") disables indentation.
func (enc *Encoder) SetIndent(prefix, indent string) {
enc.indentPrefix = prefix
enc.indentValue = indent
}
// SetEscapeHTML specifies whether problematic HTML characters
// should be escaped inside JSON quoted strings.
// The default behavior is to escape &, <, and > to \u0026, \u003c, and \u003e
// to avoid certain safety problems that can arise when embedding JSON in HTML.
//
// In non-HTML settings where the escaping interferes with the readability
// of the output, SetEscapeHTML(false) disables this behavior.
func (enc *Encoder) SetEscapeHTML(on bool) {
enc.escapeHTML = on
}
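// exampleEncoderSketch is a hedged usage sketch, not part of the original
// source, combining NewEncoder, SetIndent, and SetEscapeHTML; the sample
// value is an assumption for demonstration.
func exampleEncoderSketch() (string, error) {
	var out bytes.Buffer
	enc := NewEncoder(&out)
	enc.SetIndent("", "  ")  // pretty-print each encoded value
	enc.SetEscapeHTML(false) // keep <, >, and & readable
	err := enc.Encode(map[string]string{"tag": "<b>"})
	// Expected contents of out (followed by a newline):
	// {
	//   "tag": "<b>"
	// }
	return out.String(), err
}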
// RawMessage is a raw encoded JSON value.
// It implements [Marshaler] and [Unmarshaler] and can
// be used to delay JSON decoding or precompute a JSON encoding.
type RawMessage []byte
// MarshalJSON returns m as the JSON encoding of m.
func (m RawMessage) MarshalJSON() ([]byte, error) {
if m == nil {
return []byte("null"), nil
}
return m, nil
}
// UnmarshalJSON sets *m to a copy of data.
func (m *RawMessage) UnmarshalJSON(data []byte) error {
if m == nil {
return errors.New("json.RawMessage: UnmarshalJSON on nil pointer")
}
*m = append((*m)[0:0], data...)
return nil
}
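// exampleRawMessageSketch is a hedged sketch, not part of the original
// source: RawMessage defers decoding of a field until its type is known.
// The envelope shape here is an assumption for demonstration.
func exampleRawMessageSketch() error {
	type envelope struct {
		Kind string     `json:"kind"`
		Data RawMessage `json:"data"`
	}
	var env envelope
	if err := Unmarshal([]byte(`{"kind":"point","data":{"x":1,"y":2}}`), &env); err != nil {
		return err
	}
	// env.Data still holds the raw bytes {"x":1,"y":2}; decode them only
	// after env.Kind has been inspected.
	var pt struct{ X, Y int }
	return Unmarshal(env.Data, &pt)
}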
var _ Marshaler = (*RawMessage)(nil)
var _ Unmarshaler = (*RawMessage)(nil)
// A Token holds a value of one of these types:
//
// - [Delim], for the four JSON delimiters [ ] { }
// - bool, for JSON booleans
// - float64, for JSON numbers
// - [Number], for JSON numbers
// - string, for JSON string literals
// - nil, for JSON null
type Token any
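// exampleTokenSketch is a hedged sketch, not part of the original source, of
// the token stream produced for a small document; the concrete Go types of
// the tokens follow the list above.
func exampleTokenSketch() ([]Token, error) {
	dec := NewDecoder(bytes.NewReader([]byte(`{"n":[1,true,null]}`)))
	var toks []Token
	for {
		t, err := dec.Token()
		if err == io.EOF {
			// Expected tokens: Delim('{'), "n", Delim('['), float64(1),
			// true, nil, Delim(']'), Delim('}').
			return toks, nil
		}
		if err != nil {
			return nil, err
		}
		toks = append(toks, t)
	}
}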
const (
tokenTopValue = iota
tokenArrayStart
tokenArrayValue
tokenArrayComma
tokenObjectStart
tokenObjectKey
tokenObjectColon
tokenObjectValue
tokenObjectComma
)
// tokenPrepareForDecode advances the token state from a separator state to a value state.
func (dec *Decoder) tokenPrepareForDecode() error {
// Note: Not calling peek before switch, to avoid
// putting peek into the standard Decode path.
// peek is only called when using the Token API.
switch dec.tokenState {
case tokenArrayComma:
c, err := dec.peek()
if err != nil {
return err
}
if c != ',' {
return &SyntaxError{"expected comma after array element", dec.InputOffset()}
}
dec.scanp++
dec.tokenState = tokenArrayValue
case tokenObjectColon:
c, err := dec.peek()
if err != nil {
return err
}
if c != ':' {
return &SyntaxError{"expected colon after object key", dec.InputOffset()}
}
dec.scanp++
dec.tokenState = tokenObjectValue
}
return nil
}
func (dec *Decoder) tokenValueAllowed() bool {
switch dec.tokenState {
case tokenTopValue, tokenArrayStart, tokenArrayValue, tokenObjectValue:
return true
}
return false
}
func (dec *Decoder) tokenValueEnd() {
switch dec.tokenState {
case tokenArrayStart, tokenArrayValue:
dec.tokenState = tokenArrayComma
case tokenObjectValue:
dec.tokenState = tokenObjectComma
}
}
// A Delim is a JSON array or object delimiter, one of [ ] { or }.
type Delim rune
func (d Delim) String() string {
return string(d)
}
// Token returns the next JSON token in the input stream.
// At the end of the input stream, Token returns nil, [io.EOF].
//
// Token guarantees that the delimiters [ ] { } it returns are
// properly nested and matched: if Token encounters an unexpected
// delimiter in the input, it will return an error.
//
// The input stream consists of basic JSON values—bool, string,
// number, and null—along with delimiters [ ] { } of type [Delim]
// to mark the start and end of arrays and objects.
// Commas and colons are elided.
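//
// A typical streaming loop, shown as a sketch (the input literal is a
// placeholder):
//
//	dec := NewDecoder(strings.NewReader(`[{"n":1},{"n":2}]`))
//	if _, err := dec.Token(); err != nil { // opening '['
//		return err
//	}
//	for dec.More() {
//		var m map[string]int
//		if err := dec.Decode(&m); err != nil {
//			return err
//		}
//	}
//	if _, err := dec.Token(); err != nil { // closing ']'
//		return err
//	}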
func (dec *Decoder) Token() (Token, error) {
for {
c, err := dec.peek()
if err != nil {
return nil, err
}
switch c {
case '[':
if !dec.tokenValueAllowed() {
return dec.tokenError(c)
}
dec.scanp++
dec.tokenStack = append(dec.tokenStack, dec.tokenState)
dec.tokenState = tokenArrayStart
return Delim('['), nil
case ']':
if dec.tokenState != tokenArrayStart && dec.tokenState != tokenArrayComma {
return dec.tokenError(c)
}
dec.scanp++
dec.tokenState = dec.tokenStack[len(dec.tokenStack)-1]
dec.tokenStack = dec.tokenStack[:len(dec.tokenStack)-1]
dec.tokenValueEnd()
return Delim(']'), nil
case '{':
if !dec.tokenValueAllowed() {
return dec.tokenError(c)
}
dec.scanp++
dec.tokenStack = append(dec.tokenStack, dec.tokenState)
dec.tokenState = tokenObjectStart
return Delim('{'), nil
case '}':
if dec.tokenState != tokenObjectStart && dec.tokenState != tokenObjectComma {
return dec.tokenError(c)
}
dec.scanp++
dec.tokenState = dec.tokenStack[len(dec.tokenStack)-1]
dec.tokenStack = dec.tokenStack[:len(dec.tokenStack)-1]
dec.tokenValueEnd()
return Delim('}'), nil
case ':':
if dec.tokenState != tokenObjectColon {
return dec.tokenError(c)
}
dec.scanp++
dec.tokenState = tokenObjectValue
continue
case ',':
if dec.tokenState == tokenArrayComma {
dec.scanp++
dec.tokenState = tokenArrayValue
continue
}
if dec.tokenState == tokenObjectComma {
dec.scanp++
dec.tokenState = tokenObjectKey
continue
}
return dec.tokenError(c)
case '"':
if dec.tokenState == tokenObjectStart || dec.tokenState == tokenObjectKey {
var x string
old := dec.tokenState
dec.tokenState = tokenTopValue
err := dec.Decode(&x)
dec.tokenState = old
if err != nil {
return nil, err
}
dec.tokenState = tokenObjectColon
return x, nil
}
fallthrough
default:
if !dec.tokenValueAllowed() {
return dec.tokenError(c)
}
var x any
if err := dec.Decode(&x); err != nil {
return nil, err
}
return x, nil
}
}
}
func (dec *Decoder) tokenError(c byte) (Token, error) {
var context string
switch dec.tokenState {
case tokenTopValue:
context = " looking for beginning of value"
case tokenArrayStart, tokenArrayValue, tokenObjectValue:
context = " looking for beginning of value"
case tokenArrayComma:
context = " after array element"
case tokenObjectKey:
context = " looking for beginning of object key string"
case tokenObjectColon:
context = " after object key"
case tokenObjectComma:
context = " after object key:value pair"
}
return nil, &SyntaxError{"invalid character " + quoteChar(c) + context, dec.InputOffset()}
}
// More reports whether there is another element in the
// current array or object being parsed.
func (dec *Decoder) More() bool {
c, err := dec.peek()
return err == nil && c != ']' && c != '}'
}
func (dec *Decoder) peek() (byte, error) {
var err error
for {
for i := dec.scanp; i < len(dec.buf); i++ {
c := dec.buf[i]
if isSpace(c) {
continue
}
dec.scanp = i
return c, nil
}
// buffer has been scanned, now report any error
if err != nil {
return 0, err
}
err = dec.refill()
}
}
// InputOffset returns the input stream byte offset of the current decoder position.
// The offset gives the location of the end of the most recently returned token
// and the beginning of the next token.
func (dec *Decoder) InputOffset() int64 {
return dec.scanned + int64(dec.scanp)
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !goexperiment.jsonv2
package json
import (
"strings"
)
// tagOptions is the string following a comma in a struct field's "json"
// tag, or the empty string. It does not include the leading comma.
type tagOptions string
// parseTag splits a struct field's json tag into its name and
// comma-separated options.
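//
// For example (a small sketch of the expected split):
//
//	name, opts := parseTag("field,omitempty,string")
//	// name == "field" and opts.Contains("omitempty") == true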
func parseTag(tag string) (string, tagOptions) {
tag, opt, _ := strings.Cut(tag, ",")
return tag, tagOptions(opt)
}
// Contains reports whether a comma-separated list of options
// contains a particular optionName flag. optionName must be surrounded by a
// string boundary or commas.
func (o tagOptions) Contains(optionName string) bool {
if len(o) == 0 {
return false
}
s := string(o)
for s != "" {
var name string
name, s, _ = strings.Cut(s, ",")
if name == optionName {
return true
}
}
return false
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package pem implements the PEM data encoding, which originated in Privacy
// Enhanced Mail. The most common use of PEM encoding today is in TLS keys and
// certificates. See RFC 1421.
package pem
import (
"bytes"
"encoding/base64"
"errors"
"io"
"slices"
"strings"
)
// A Block represents a PEM encoded structure.
//
// The encoded form is:
//
// -----BEGIN Type-----
// Headers
// base64-encoded Bytes
// -----END Type-----
//
// where [Block.Headers] is a possibly empty sequence of Key: Value lines.
type Block struct {
Type string // The type, taken from the preamble (e.g. "RSA PRIVATE KEY").
Headers map[string]string // Optional headers.
Bytes []byte // The decoded bytes of the contents. Typically a DER encoded ASN.1 structure.
}
// getLine returns the first \r\n or \n delimited line from the given byte
// array. The line does not include trailing whitespace or the trailing new
// line bytes. The remainder of the byte array (also not including the new
// line bytes) is also returned, and it is always smaller than the original
// argument.
func getLine(data []byte) (line, rest []byte) {
i := bytes.IndexByte(data, '\n')
var j int
if i < 0 {
i = len(data)
j = i
} else {
j = i + 1
if i > 0 && data[i-1] == '\r' {
i--
}
}
return bytes.TrimRight(data[0:i], " \t"), data[j:]
}
// removeSpacesAndTabs returns a copy of its input with all spaces and tabs
// removed, if there were any. Otherwise, the input is returned unchanged.
//
// The base64 decoder already skips newline characters, so we don't need to
// filter them out here.
func removeSpacesAndTabs(data []byte) []byte {
if !bytes.ContainsAny(data, " \t") {
// Fast path; most base64 data within PEM contains newlines, but
// no spaces nor tabs. Skip the extra alloc and work.
return data
}
result := make([]byte, len(data))
n := 0
for _, b := range data {
if b == ' ' || b == '\t' {
continue
}
result[n] = b
n++
}
return result[0:n]
}
var pemStart = []byte("\n-----BEGIN ")
var pemEnd = []byte("\n-----END ")
var pemEndOfLine = []byte("-----")
var colon = []byte(":")
// Decode will find the next PEM formatted block (certificate, private key,
// etc.) in the input. It returns that block and the remainder of the input.
// If no PEM data is found, p is nil and the whole of the input is returned
// in rest. Blocks must start at the beginning of a line and end at the end
// of a line.
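//
// A common pattern is to loop until no block remains, as sketched here
// (data holds the raw PEM input):
//
//	for {
//		var block *Block
//		block, data = Decode(data)
//		if block == nil {
//			break
//		}
//		// use block.Type, block.Headers, and block.Bytes
//	}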
func Decode(data []byte) (p *Block, rest []byte) {
// pemStart begins with a newline. However, at the very beginning of
// the byte array, we'll accept the start string without it.
rest = data
for {
if bytes.HasPrefix(rest, pemStart[1:]) {
rest = rest[len(pemStart)-1:]
} else if _, after, ok := bytes.Cut(rest, pemStart); ok {
rest = after
} else {
return nil, data
}
var typeLine []byte
typeLine, rest = getLine(rest)
if !bytes.HasSuffix(typeLine, pemEndOfLine) {
continue
}
typeLine = typeLine[0 : len(typeLine)-len(pemEndOfLine)]
p = &Block{
Headers: make(map[string]string),
Type: string(typeLine),
}
for {
// This loop terminates because getLine's second result is
// always smaller than its argument.
if len(rest) == 0 {
return nil, data
}
line, next := getLine(rest)
key, val, ok := bytes.Cut(line, colon)
if !ok {
break
}
// TODO(agl): need to cope with values that spread across lines.
key = bytes.TrimSpace(key)
val = bytes.TrimSpace(val)
p.Headers[string(key)] = string(val)
rest = next
}
var endIndex, endTrailerIndex int
// If there were no headers, the END line might occur
// immediately, without a leading newline.
if len(p.Headers) == 0 && bytes.HasPrefix(rest, pemEnd[1:]) {
endIndex = 0
endTrailerIndex = len(pemEnd) - 1
} else {
endIndex = bytes.Index(rest, pemEnd)
endTrailerIndex = endIndex + len(pemEnd)
}
if endIndex < 0 {
continue
}
// After the "-----" of the ending line, there should be the same type
// and then a final five dashes.
endTrailer := rest[endTrailerIndex:]
endTrailerLen := len(typeLine) + len(pemEndOfLine)
if len(endTrailer) < endTrailerLen {
continue
}
restOfEndLine := endTrailer[endTrailerLen:]
endTrailer = endTrailer[:endTrailerLen]
if !bytes.HasPrefix(endTrailer, typeLine) ||
!bytes.HasSuffix(endTrailer, pemEndOfLine) {
continue
}
// The line must end with only whitespace.
if s, _ := getLine(restOfEndLine); len(s) != 0 {
continue
}
base64Data := removeSpacesAndTabs(rest[:endIndex])
p.Bytes = make([]byte, base64.StdEncoding.DecodedLen(len(base64Data)))
n, err := base64.StdEncoding.Decode(p.Bytes, base64Data)
if err != nil {
continue
}
p.Bytes = p.Bytes[:n]
// the -1 is because we might have only matched pemEnd without the
// leading newline if the PEM block was empty.
_, rest = getLine(rest[endIndex+len(pemEnd)-1:])
return p, rest
}
}
const pemLineLength = 64
type lineBreaker struct {
line [pemLineLength]byte
used int
out io.Writer
}
var nl = []byte{'\n'}
func (l *lineBreaker) Write(b []byte) (n int, err error) {
if l.used+len(b) < pemLineLength {
copy(l.line[l.used:], b)
l.used += len(b)
return len(b), nil
}
n, err = l.out.Write(l.line[0:l.used])
if err != nil {
return
}
excess := pemLineLength - l.used
l.used = 0
n, err = l.out.Write(b[0:excess])
if err != nil {
return
}
n, err = l.out.Write(nl)
if err != nil {
return
}
return l.Write(b[excess:])
}
func (l *lineBreaker) Close() (err error) {
if l.used > 0 {
_, err = l.out.Write(l.line[0:l.used])
if err != nil {
return
}
_, err = l.out.Write(nl)
}
return
}
func writeHeader(out io.Writer, k, v string) error {
_, err := out.Write([]byte(k + ": " + v + "\n"))
return err
}
// Encode writes the PEM encoding of b to out.
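//
// A minimal sketch (the block contents are placeholders):
//
//	err := Encode(os.Stdout, &Block{
//		Type:  "MESSAGE",
//		Bytes: []byte("hello"),
//	})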
func Encode(out io.Writer, b *Block) error {
// Check for invalid block before writing any output.
for k := range b.Headers {
if strings.Contains(k, ":") {
return errors.New("pem: cannot encode a header key that contains a colon")
}
}
// All errors below are relayed from underlying io.Writer,
// so it is now safe to write data.
if _, err := out.Write(pemStart[1:]); err != nil {
return err
}
if _, err := out.Write([]byte(b.Type + "-----\n")); err != nil {
return err
}
if len(b.Headers) > 0 {
const procType = "Proc-Type"
h := make([]string, 0, len(b.Headers))
hasProcType := false
for k := range b.Headers {
if k == procType {
hasProcType = true
continue
}
h = append(h, k)
}
// The Proc-Type header must be written first.
// See RFC 1421, section 4.6.1.1
if hasProcType {
if err := writeHeader(out, procType, b.Headers[procType]); err != nil {
return err
}
}
// For consistency of output, write other headers sorted by key.
slices.Sort(h)
for _, k := range h {
if err := writeHeader(out, k, b.Headers[k]); err != nil {
return err
}
}
if _, err := out.Write(nl); err != nil {
return err
}
}
var breaker lineBreaker
breaker.out = out
b64 := base64.NewEncoder(base64.StdEncoding, &breaker)
if _, err := b64.Write(b.Bytes); err != nil {
return err
}
b64.Close()
breaker.Close()
if _, err := out.Write(pemEnd[1:]); err != nil {
return err
}
_, err := out.Write([]byte(b.Type + "-----\n"))
return err
}
// EncodeToMemory returns the PEM encoding of b.
//
// If b has invalid headers and cannot be encoded,
// EncodeToMemory returns nil. If it is important to
// report details about this error case, use [Encode] instead.
func EncodeToMemory(b *Block) []byte {
var buf bytes.Buffer
if err := Encode(&buf, b); err != nil {
return nil
}
return buf.Bytes()
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package xml
import (
"bufio"
"bytes"
"encoding"
"errors"
"fmt"
"io"
"reflect"
"strconv"
"strings"
)
const (
// Header is a generic XML header suitable for use with the output of [Marshal].
// This is not automatically added to any output of this package;
// it is provided as a convenience.
Header = `<?xml version="1.0" encoding="UTF-8"?>` + "\n"
)
// Marshal returns the XML encoding of v.
//
// Marshal handles an array or slice by marshaling each of the elements.
// Marshal handles a pointer by marshaling the value it points at or, if the
// pointer is nil, by writing nothing. Marshal handles an interface value by
// marshaling the value it contains or, if the interface value is nil, by
// writing nothing. Marshal handles all other data by writing one or more XML
// elements containing the data.
//
// The name for the XML elements is taken from, in order of preference:
// - the tag on the XMLName field, if the data is a struct
// - the value of the XMLName field of type [Name]
// - the tag of the struct field used to obtain the data
// - the name of the struct field used to obtain the data
// - the name of the marshaled type
//
// The XML element for a struct contains marshaled elements for each of the
// exported fields of the struct, with these exceptions:
// - the XMLName field, described above, is omitted.
// - a field with tag "-" is omitted.
// - a field with tag "name,attr" becomes an attribute with
// the given name in the XML element.
// - a field with tag ",attr" becomes an attribute with the
// field name in the XML element.
// - a field with tag ",chardata" is written as character data,
// not as an XML element.
// - a field with tag ",cdata" is written as character data
// wrapped in one or more <![CDATA[ ... ]]> tags, not as an XML element.
// - a field with tag ",innerxml" is written verbatim, not subject
// to the usual marshaling procedure.
// - a field with tag ",comment" is written as an XML comment, not
// subject to the usual marshaling procedure. It must not contain
// the "--" string within it.
// - a field with a tag including the "omitempty" option is omitted
// if the field value is empty. The empty values are false, 0, any
// nil pointer or interface value, and any array, slice, map, or
// string of length zero.
// - an anonymous struct field is handled as if the fields of its
// value were part of the outer struct.
// - an anonymous struct field of interface type is treated the same as having
// that type as its name, rather than being anonymous.
// - a field implementing [Marshaler] is written by calling its MarshalXML
// method.
// - a field implementing [encoding.TextMarshaler] is written by encoding the
// result of its MarshalText method as text.
//
// If a field uses a tag "a>b>c", then the element c will be nested inside
// parent elements a and b. Fields that appear next to each other that name
// the same parent will be enclosed in one XML element.
//
// If the XML name for a struct field is defined by both the field tag and the
// struct's XMLName field, the names must match.
//
// See [MarshalIndent] for an example.
//
// Marshal will return an error if asked to marshal a channel, function, or map.
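//
// A brief sketch of a tagged struct and its output (the type is hypothetical):
//
//	type Person struct {
//		XMLName Name   `xml:"person"`
//		Age     int    `xml:"age,attr"`
//		Name    string `xml:"name"`
//	}
//	out, err := Marshal(Person{Age: 36, Name: "Ada"})
//	// out is approximately: <person age="36"><name>Ada</name></person>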
func Marshal(v any) ([]byte, error) {
var b bytes.Buffer
enc := NewEncoder(&b)
if err := enc.Encode(v); err != nil {
return nil, err
}
if err := enc.Close(); err != nil {
return nil, err
}
return b.Bytes(), nil
}
// Marshaler is the interface implemented by objects that can marshal
// themselves into valid XML elements.
//
// MarshalXML encodes the receiver as zero or more XML elements.
// By convention, arrays or slices are typically encoded as a sequence
// of elements, one per entry.
// Using start as the element tag is not required, but doing so
// will enable [Unmarshal] to match the XML elements to the correct
// struct field.
// One common implementation strategy is to construct a separate
// value with a layout corresponding to the desired XML and then
// to encode it using e.EncodeElement.
// Another common strategy is to use repeated calls to e.EncodeToken
// to generate the XML output one token at a time.
// The sequence of encoded tokens must make up zero or more valid
// XML elements.
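//
// A hedged sketch of a custom marshaler built on EncodeElement
// (the Celsius type is hypothetical):
//
//	type Celsius float64
//
//	func (c Celsius) MarshalXML(e *Encoder, start StartElement) error {
//		return e.EncodeElement(fmt.Sprintf("%.1f", float64(c)), start)
//	}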
type Marshaler interface {
MarshalXML(e *Encoder, start StartElement) error
}
// MarshalerAttr is the interface implemented by objects that can marshal
// themselves into valid XML attributes.
//
// MarshalXMLAttr returns an XML attribute with the encoded value of the receiver.
// Using name as the attribute name is not required, but doing so
// will enable [Unmarshal] to match the attribute to the correct
// struct field.
// If MarshalXMLAttr returns the zero attribute [Attr]{}, no attribute
// will be generated in the output.
// MarshalXMLAttr is used only for struct fields with the
// "attr" option in the field tag.
type MarshalerAttr interface {
MarshalXMLAttr(name Name) (Attr, error)
}
// MarshalIndent works like [Marshal], but each XML element begins on a new
// indented line that starts with prefix and is followed by one or more
// copies of indent according to the nesting depth.
func MarshalIndent(v any, prefix, indent string) ([]byte, error) {
var b bytes.Buffer
enc := NewEncoder(&b)
enc.Indent(prefix, indent)
if err := enc.Encode(v); err != nil {
return nil, err
}
if err := enc.Close(); err != nil {
return nil, err
}
return b.Bytes(), nil
}
// An Encoder writes XML data to an output stream.
type Encoder struct {
p printer
}
// NewEncoder returns a new encoder that writes to w.
func NewEncoder(w io.Writer) *Encoder {
e := &Encoder{printer{w: bufio.NewWriter(w)}}
e.p.encoder = e
return e
}
// Indent sets the encoder to generate XML in which each element
// begins on a new indented line that starts with prefix and is followed by
// one or more copies of indent according to the nesting depth.
func (enc *Encoder) Indent(prefix, indent string) {
enc.p.prefix = prefix
enc.p.indent = indent
}
// Encode writes the XML encoding of v to the stream.
//
// See the documentation for [Marshal] for details about the conversion
// of Go values to XML.
//
// Encode calls [Encoder.Flush] before returning.
func (enc *Encoder) Encode(v any) error {
err := enc.p.marshalValue(reflect.ValueOf(v), nil, nil)
if err != nil {
return err
}
return enc.p.w.Flush()
}
// EncodeElement writes the XML encoding of v to the stream,
// using start as the outermost tag in the encoding.
//
// See the documentation for [Marshal] for details about the conversion
// of Go values to XML.
//
// EncodeElement calls [Encoder.Flush] before returning.
func (enc *Encoder) EncodeElement(v any, start StartElement) error {
err := enc.p.marshalValue(reflect.ValueOf(v), nil, &start)
if err != nil {
return err
}
return enc.p.w.Flush()
}
var (
begComment = []byte("<!--")
endComment = []byte("-->")
endProcInst = []byte("?>")
)
// EncodeToken writes the given XML token to the stream.
// It returns an error if [StartElement] and [EndElement] tokens are not properly matched.
//
// EncodeToken does not call [Encoder.Flush], because usually it is part of a larger operation
// such as [Encoder.Encode] or [Encoder.EncodeElement] (or a custom [Marshaler]'s MarshalXML invoked
// during those), and those will call Flush when finished.
// Callers that create an Encoder and then invoke EncodeToken directly, without
// using Encode or EncodeElement, need to call Flush when finished to ensure
// that the XML is written to the underlying writer.
//
// EncodeToken allows writing a [ProcInst] with Target set to "xml" only as the first token
// in the stream.
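//
// A small sketch of direct token encoding (the element name is a placeholder):
//
//	enc := NewEncoder(os.Stdout)
//	start := StartElement{Name: Name{Local: "item"}}
//	_ = enc.EncodeToken(start)
//	_ = enc.EncodeToken(CharData("hello"))
//	_ = enc.EncodeToken(start.End())
//	_ = enc.Flush()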
func (enc *Encoder) EncodeToken(t Token) error {
p := &enc.p
switch t := t.(type) {
case StartElement:
if err := p.writeStart(&t); err != nil {
return err
}
case EndElement:
if err := p.writeEnd(t.Name); err != nil {
return err
}
case CharData:
escapeText(p, t, false)
case Comment:
if bytes.Contains(t, endComment) {
return fmt.Errorf("xml: EncodeToken of Comment containing --> marker")
}
p.WriteString("<!--")
p.Write(t)
p.WriteString("-->")
return p.cachedWriteError()
case ProcInst:
// A ProcInst whose target is "xml" is the XML declaration, which is
// only allowed as the very first token in the stream.
if t.Target == "xml" && p.w.Buffered() != 0 {
return fmt.Errorf("xml: EncodeToken of ProcInst xml target only valid for xml declaration, first token encoded")
}
if !isNameString(t.Target) {
return fmt.Errorf("xml: EncodeToken of ProcInst with invalid Target")
}
if bytes.Contains(t.Inst, endProcInst) {
return fmt.Errorf("xml: EncodeToken of ProcInst containing ?> marker")
}
p.WriteString("<?")
p.WriteString(t.Target)
if len(t.Inst) > 0 {
p.WriteByte(' ')
p.Write(t.Inst)
}
p.WriteString("?>")
case Directive:
if !isValidDirective(t) {
return fmt.Errorf("xml: EncodeToken of Directive containing wrong < or > markers")
}
p.WriteString("<!")
p.Write(t)
p.WriteString(">")
default:
return fmt.Errorf("xml: EncodeToken of invalid token type")
}
return p.cachedWriteError()
}
// isValidDirective reports whether dir is a valid directive text,
// meaning angle brackets are matched, ignoring comments and strings.
func isValidDirective(dir Directive) bool {
var (
depth int
inquote uint8
incomment bool
)
for i, c := range dir {
switch {
case incomment:
if c == '>' {
if n := 1 + i - len(endComment); n >= 0 && bytes.Equal(dir[n:i+1], endComment) {
incomment = false
}
}
// Just ignore anything in comment
case inquote != 0:
if c == inquote {
inquote = 0
}
// Just ignore anything within quotes
case c == '\'' || c == '"':
inquote = c
case c == '<':
if i+len(begComment) < len(dir) && bytes.Equal(dir[i:i+len(begComment)], begComment) {
incomment = true
} else {
depth++
}
case c == '>':
if depth == 0 {
return false
}
depth--
}
}
return depth == 0 && inquote == 0 && !incomment
}
// Flush flushes any buffered XML to the underlying writer.
// See the [Encoder.EncodeToken] documentation for details about when it is necessary.
func (enc *Encoder) Flush() error {
return enc.p.w.Flush()
}
// Close the Encoder, indicating that no more data will be written. It flushes
// any buffered XML to the underlying writer and returns an error if the
// written XML is invalid (e.g. by containing unclosed elements).
func (enc *Encoder) Close() error {
return enc.p.Close()
}
type printer struct {
w *bufio.Writer
encoder *Encoder
seq int
indent string
prefix string
depth int
indentedIn bool
putNewline bool
attrNS map[string]string // map prefix -> name space
attrPrefix map[string]string // map name space -> prefix
prefixes []string
tags []Name
closed bool
err error
}
// createAttrPrefix finds the name space prefix attribute to use for the given name space,
// defining a new prefix if necessary. It returns the prefix.
func (p *printer) createAttrPrefix(url string) string {
if prefix := p.attrPrefix[url]; prefix != "" {
return prefix
}
// The "http://www.w3.org/XML/1998/namespace" name space is predefined as "xml"
// and must be referred to that way.
// (The "http://www.w3.org/2000/xmlns/" name space is also predefined as "xmlns",
// but users should not be trying to use that one directly - that's our job.)
if url == xmlURL {
return xmlPrefix
}
// Need to define a new name space.
if p.attrPrefix == nil {
p.attrPrefix = make(map[string]string)
p.attrNS = make(map[string]string)
}
// Pick a name. We try to use the final element of the path
// but fall back to _.
prefix := strings.TrimRight(url, "/")
if i := strings.LastIndex(prefix, "/"); i >= 0 {
prefix = prefix[i+1:]
}
if prefix == "" || !isName([]byte(prefix)) || strings.Contains(prefix, ":") {
prefix = "_"
}
// xmlanything is reserved and any variant of it regardless of
// case should be matched, so:
// (('X'|'x') ('M'|'m') ('L'|'l'))
// See Section 2.3 of https://www.w3.org/TR/REC-xml/
if len(prefix) >= 3 && strings.EqualFold(prefix[:3], "xml") {
prefix = "_" + prefix
}
if p.attrNS[prefix] != "" {
// Name is taken. Find a better one.
for p.seq++; ; p.seq++ {
if id := prefix + "_" + strconv.Itoa(p.seq); p.attrNS[id] == "" {
prefix = id
break
}
}
}
p.attrPrefix[url] = prefix
p.attrNS[prefix] = url
p.WriteString(`xmlns:`)
p.WriteString(prefix)
p.WriteString(`="`)
EscapeText(p, []byte(url))
p.WriteString(`" `)
p.prefixes = append(p.prefixes, prefix)
return prefix
}
// deleteAttrPrefix removes an attribute name space prefix.
func (p *printer) deleteAttrPrefix(prefix string) {
delete(p.attrPrefix, p.attrNS[prefix])
delete(p.attrNS, prefix)
}
func (p *printer) markPrefix() {
p.prefixes = append(p.prefixes, "")
}
func (p *printer) popPrefix() {
for len(p.prefixes) > 0 {
prefix := p.prefixes[len(p.prefixes)-1]
p.prefixes = p.prefixes[:len(p.prefixes)-1]
if prefix == "" {
break
}
p.deleteAttrPrefix(prefix)
}
}
var (
marshalerType = reflect.TypeFor[Marshaler]()
marshalerAttrType = reflect.TypeFor[MarshalerAttr]()
textMarshalerType = reflect.TypeFor[encoding.TextMarshaler]()
)
// marshalValue writes one or more XML elements representing val.
// If val was obtained from a struct field, finfo must have its details.
func (p *printer) marshalValue(val reflect.Value, finfo *fieldInfo, startTemplate *StartElement) error {
if startTemplate != nil && startTemplate.Name.Local == "" {
return fmt.Errorf("xml: EncodeElement of StartElement with missing name")
}
if !val.IsValid() {
return nil
}
if finfo != nil && finfo.flags&fOmitEmpty != 0 && isEmptyValue(val) {
return nil
}
// Drill into interfaces and pointers.
// This can turn into an infinite loop given a cyclic chain,
// but it matches the Go 1 behavior.
for val.Kind() == reflect.Interface || val.Kind() == reflect.Pointer {
if val.IsNil() {
return nil
}
val = val.Elem()
}
kind := val.Kind()
typ := val.Type()
// Check for marshaler.
if val.CanInterface() && typ.Implements(marshalerType) {
return p.marshalInterface(val.Interface().(Marshaler), defaultStart(typ, finfo, startTemplate))
}
if val.CanAddr() {
pv := val.Addr()
if pv.CanInterface() && pv.Type().Implements(marshalerType) {
return p.marshalInterface(pv.Interface().(Marshaler), defaultStart(pv.Type(), finfo, startTemplate))
}
}
// Check for text marshaler.
if val.CanInterface() && typ.Implements(textMarshalerType) {
return p.marshalTextInterface(val.Interface().(encoding.TextMarshaler), defaultStart(typ, finfo, startTemplate))
}
if val.CanAddr() {
pv := val.Addr()
if pv.CanInterface() && pv.Type().Implements(textMarshalerType) {
return p.marshalTextInterface(pv.Interface().(encoding.TextMarshaler), defaultStart(pv.Type(), finfo, startTemplate))
}
}
// Slices and arrays iterate over the elements. They do not have an enclosing tag.
if (kind == reflect.Slice || kind == reflect.Array) && typ.Elem().Kind() != reflect.Uint8 {
for i, n := 0, val.Len(); i < n; i++ {
if err := p.marshalValue(val.Index(i), finfo, startTemplate); err != nil {
return err
}
}
return nil
}
tinfo, err := getTypeInfo(typ)
if err != nil {
return err
}
// Create start element.
// Precedence for the XML element name is:
// 0. startTemplate
// 1. XMLName field in underlying struct;
// 2. field name/tag in the struct field; and
// 3. type name
var start StartElement
if startTemplate != nil {
start.Name = startTemplate.Name
start.Attr = append(start.Attr, startTemplate.Attr...)
} else if tinfo.xmlname != nil {
xmlname := tinfo.xmlname
if xmlname.name != "" {
start.Name.Space, start.Name.Local = xmlname.xmlns, xmlname.name
} else {
fv := xmlname.value(val, dontInitNilPointers)
if v, ok := fv.Interface().(Name); ok && v.Local != "" {
start.Name = v
}
}
}
if start.Name.Local == "" && finfo != nil {
start.Name.Space, start.Name.Local = finfo.xmlns, finfo.name
}
if start.Name.Local == "" {
name := typ.Name()
if i := strings.IndexByte(name, '['); i >= 0 {
// Truncate generic instantiation name. See issue 48318.
name = name[:i]
}
if name == "" {
return &UnsupportedTypeError{typ}
}
start.Name.Local = name
}
// Attributes
for i := range tinfo.fields {
finfo := &tinfo.fields[i]
if finfo.flags&fAttr == 0 {
continue
}
fv := finfo.value(val, dontInitNilPointers)
if finfo.flags&fOmitEmpty != 0 && (!fv.IsValid() || isEmptyValue(fv)) {
continue
}
if fv.Kind() == reflect.Interface && fv.IsNil() {
continue
}
name := Name{Space: finfo.xmlns, Local: finfo.name}
if err := p.marshalAttr(&start, name, fv); err != nil {
return err
}
}
// If the XMLName field specified an empty name and namespace while the
// enclosing element carries a default namespace, reset it with an empty
// xmlns="" attribute.
if tinfo.xmlname != nil && start.Name.Space == "" &&
tinfo.xmlname.xmlns == "" && tinfo.xmlname.name == "" &&
len(p.tags) != 0 && p.tags[len(p.tags)-1].Space != "" {
start.Attr = append(start.Attr, Attr{Name{"", xmlnsPrefix}, ""})
}
if err := p.writeStart(&start); err != nil {
return err
}
if val.Kind() == reflect.Struct {
err = p.marshalStruct(tinfo, val)
} else {
s, b, err1 := p.marshalSimple(typ, val)
if err1 != nil {
err = err1
} else if b != nil {
EscapeText(p, b)
} else {
p.EscapeString(s)
}
}
if err != nil {
return err
}
if err := p.writeEnd(start.Name); err != nil {
return err
}
return p.cachedWriteError()
}
// marshalAttr marshals an attribute with the given name and value, adding to start.Attr.
func (p *printer) marshalAttr(start *StartElement, name Name, val reflect.Value) error {
if val.CanInterface() && val.Type().Implements(marshalerAttrType) {
attr, err := val.Interface().(MarshalerAttr).MarshalXMLAttr(name)
if err != nil {
return err
}
if attr.Name.Local != "" {
start.Attr = append(start.Attr, attr)
}
return nil
}
if val.CanAddr() {
pv := val.Addr()
if pv.CanInterface() && pv.Type().Implements(marshalerAttrType) {
attr, err := pv.Interface().(MarshalerAttr).MarshalXMLAttr(name)
if err != nil {
return err
}
if attr.Name.Local != "" {
start.Attr = append(start.Attr, attr)
}
return nil
}
}
if val.CanInterface() && val.Type().Implements(textMarshalerType) {
text, err := val.Interface().(encoding.TextMarshaler).MarshalText()
if err != nil {
return err
}
start.Attr = append(start.Attr, Attr{name, string(text)})
return nil
}
if val.CanAddr() {
pv := val.Addr()
if pv.CanInterface() && pv.Type().Implements(textMarshalerType) {
text, err := pv.Interface().(encoding.TextMarshaler).MarshalText()
if err != nil {
return err
}
start.Attr = append(start.Attr, Attr{name, string(text)})
return nil
}
}
// Dereference or skip nil pointer, interface values.
switch val.Kind() {
case reflect.Pointer, reflect.Interface:
if val.IsNil() {
return nil
}
val = val.Elem()
}
// Walk slices.
if val.Kind() == reflect.Slice && val.Type().Elem().Kind() != reflect.Uint8 {
n := val.Len()
for i := 0; i < n; i++ {
if err := p.marshalAttr(start, name, val.Index(i)); err != nil {
return err
}
}
return nil
}
if val.Type() == attrType {
start.Attr = append(start.Attr, val.Interface().(Attr))
return nil
}
s, b, err := p.marshalSimple(val.Type(), val)
if err != nil {
return err
}
if b != nil {
s = string(b)
}
start.Attr = append(start.Attr, Attr{name, s})
return nil
}
// defaultStart returns the default start element to use,
// given the reflect type, field info, and start template.
func defaultStart(typ reflect.Type, finfo *fieldInfo, startTemplate *StartElement) StartElement {
var start StartElement
// Precedence for the XML element name is as above,
// except that we do not look inside structs for the first field.
if startTemplate != nil {
start.Name = startTemplate.Name
start.Attr = append(start.Attr, startTemplate.Attr...)
} else if finfo != nil && finfo.name != "" {
start.Name.Local = finfo.name
start.Name.Space = finfo.xmlns
} else if typ.Name() != "" {
start.Name.Local = typ.Name()
} else {
// Must be a pointer to a named type,
// since it has the Marshaler methods.
start.Name.Local = typ.Elem().Name()
}
return start
}
// marshalInterface marshals a Marshaler interface value.
func (p *printer) marshalInterface(val Marshaler, start StartElement) error {
// Push a marker onto the tag stack so that MarshalXML
// cannot close the XML tags that it did not open.
p.tags = append(p.tags, Name{})
n := len(p.tags)
err := val.MarshalXML(p.encoder, start)
if err != nil {
return err
}
// Make sure MarshalXML closed all its tags. p.tags[n-1] is the mark.
if len(p.tags) > n {
return fmt.Errorf("xml: %s.MarshalXML wrote invalid XML: <%s> not closed", receiverType(val), p.tags[len(p.tags)-1].Local)
}
p.tags = p.tags[:n-1]
return nil
}
// marshalTextInterface marshals a TextMarshaler interface value.
func (p *printer) marshalTextInterface(val encoding.TextMarshaler, start StartElement) error {
if err := p.writeStart(&start); err != nil {
return err
}
text, err := val.MarshalText()
if err != nil {
return err
}
EscapeText(p, text)
return p.writeEnd(start.Name)
}
// writeStart writes the given start element.
func (p *printer) writeStart(start *StartElement) error {
if start.Name.Local == "" {
return fmt.Errorf("xml: start tag with no name")
}
p.tags = append(p.tags, start.Name)
p.markPrefix()
p.writeIndent(1)
p.WriteByte('<')
p.WriteString(start.Name.Local)
if start.Name.Space != "" {
p.WriteString(` xmlns="`)
p.EscapeString(start.Name.Space)
p.WriteByte('"')
}
// Attributes
for _, attr := range start.Attr {
name := attr.Name
if name.Local == "" {
continue
}
p.WriteByte(' ')
if name.Space != "" {
p.WriteString(p.createAttrPrefix(name.Space))
p.WriteByte(':')
}
p.WriteString(name.Local)
p.WriteString(`="`)
p.EscapeString(attr.Value)
p.WriteByte('"')
}
p.WriteByte('>')
return nil
}
func (p *printer) writeEnd(name Name) error {
if name.Local == "" {
return fmt.Errorf("xml: end tag with no name")
}
if len(p.tags) == 0 || p.tags[len(p.tags)-1].Local == "" {
return fmt.Errorf("xml: end tag </%s> without start tag", name.Local)
}
if top := p.tags[len(p.tags)-1]; top != name {
if top.Local != name.Local {
return fmt.Errorf("xml: end tag </%s> does not match start tag <%s>", name.Local, top.Local)
}
return fmt.Errorf("xml: end tag </%s> in namespace %s does not match start tag <%s> in namespace %s", name.Local, name.Space, top.Local, top.Space)
}
p.tags = p.tags[:len(p.tags)-1]
p.writeIndent(-1)
p.WriteByte('<')
p.WriteByte('/')
p.WriteString(name.Local)
p.WriteByte('>')
p.popPrefix()
return nil
}
func (p *printer) marshalSimple(typ reflect.Type, val reflect.Value) (string, []byte, error) {
switch val.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return strconv.FormatInt(val.Int(), 10), nil, nil
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return strconv.FormatUint(val.Uint(), 10), nil, nil
case reflect.Float32, reflect.Float64:
return strconv.FormatFloat(val.Float(), 'g', -1, val.Type().Bits()), nil, nil
case reflect.String:
return val.String(), nil, nil
case reflect.Bool:
return strconv.FormatBool(val.Bool()), nil, nil
case reflect.Array:
if typ.Elem().Kind() != reflect.Uint8 {
break
}
// [...]byte
var bytes []byte
if val.CanAddr() {
bytes = val.Bytes()
} else {
bytes = make([]byte, val.Len())
reflect.Copy(reflect.ValueOf(bytes), val)
}
return "", bytes, nil
case reflect.Slice:
if typ.Elem().Kind() != reflect.Uint8 {
break
}
// []byte
return "", val.Bytes(), nil
}
return "", nil, &UnsupportedTypeError{typ}
}
var ddBytes = []byte("--")
// indirect drills into interfaces and pointers, returning the pointed-at value.
// If it encounters a nil interface or pointer, indirect returns that nil value.
// This can turn into an infinite loop given a cyclic chain,
// but it matches the Go 1 behavior.
func indirect(vf reflect.Value) reflect.Value {
for vf.Kind() == reflect.Interface || vf.Kind() == reflect.Pointer {
if vf.IsNil() {
return vf
}
vf = vf.Elem()
}
return vf
}
func (p *printer) marshalStruct(tinfo *typeInfo, val reflect.Value) error {
s := parentStack{p: p}
for i := range tinfo.fields {
finfo := &tinfo.fields[i]
if finfo.flags&fAttr != 0 {
continue
}
vf := finfo.value(val, dontInitNilPointers)
if !vf.IsValid() {
// The field is behind an anonymous struct field that's
// nil. Skip it.
continue
}
switch finfo.flags & fMode {
case fCDATA, fCharData:
emit := EscapeText
if finfo.flags&fMode == fCDATA {
emit = emitCDATA
}
if err := s.trim(finfo.parents); err != nil {
return err
}
if vf.CanInterface() && vf.Type().Implements(textMarshalerType) {
data, err := vf.Interface().(encoding.TextMarshaler).MarshalText()
if err != nil {
return err
}
if err := emit(p, data); err != nil {
return err
}
continue
}
if vf.CanAddr() {
pv := vf.Addr()
if pv.CanInterface() && pv.Type().Implements(textMarshalerType) {
data, err := pv.Interface().(encoding.TextMarshaler).MarshalText()
if err != nil {
return err
}
if err := emit(p, data); err != nil {
return err
}
continue
}
}
var scratch [64]byte
vf = indirect(vf)
switch vf.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
if err := emit(p, strconv.AppendInt(scratch[:0], vf.Int(), 10)); err != nil {
return err
}
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
if err := emit(p, strconv.AppendUint(scratch[:0], vf.Uint(), 10)); err != nil {
return err
}
case reflect.Float32, reflect.Float64:
if err := emit(p, strconv.AppendFloat(scratch[:0], vf.Float(), 'g', -1, vf.Type().Bits())); err != nil {
return err
}
case reflect.Bool:
if err := emit(p, strconv.AppendBool(scratch[:0], vf.Bool())); err != nil {
return err
}
case reflect.String:
if err := emit(p, []byte(vf.String())); err != nil {
return err
}
case reflect.Slice:
if elem, ok := vf.Interface().([]byte); ok {
if err := emit(p, elem); err != nil {
return err
}
}
}
continue
case fComment:
if err := s.trim(finfo.parents); err != nil {
return err
}
vf = indirect(vf)
k := vf.Kind()
if !(k == reflect.String || k == reflect.Slice && vf.Type().Elem().Kind() == reflect.Uint8) {
return fmt.Errorf("xml: bad type for comment field of %s", val.Type())
}
if vf.Len() == 0 {
continue
}
p.writeIndent(0)
p.WriteString("<!--")
dashDash := false
dashLast := false
switch k {
case reflect.String:
s := vf.String()
dashDash = strings.Contains(s, "--")
dashLast = s[len(s)-1] == '-'
if !dashDash {
p.WriteString(s)
}
case reflect.Slice:
b := vf.Bytes()
dashDash = bytes.Contains(b, ddBytes)
dashLast = b[len(b)-1] == '-'
if !dashDash {
p.Write(b)
}
default:
panic("can't happen")
}
if dashDash {
return fmt.Errorf(`xml: comments must not contain "--"`)
}
if dashLast {
// "--->" is invalid grammar. Make it "- -->"
p.WriteByte(' ')
}
p.WriteString("-->")
continue
case fInnerXML:
vf = indirect(vf)
iface := vf.Interface()
switch raw := iface.(type) {
case []byte:
p.Write(raw)
continue
case string:
p.WriteString(raw)
continue
}
case fElement, fElement | fAny:
if err := s.trim(finfo.parents); err != nil {
return err
}
if len(finfo.parents) > len(s.stack) {
if vf.Kind() != reflect.Pointer && vf.Kind() != reflect.Interface || !vf.IsNil() {
if err := s.push(finfo.parents[len(s.stack):]); err != nil {
return err
}
}
}
}
if err := p.marshalValue(vf, finfo, nil); err != nil {
return err
}
}
s.trim(nil)
return p.cachedWriteError()
}
// Write implements io.Writer
func (p *printer) Write(b []byte) (n int, err error) {
if p.closed && p.err == nil {
p.err = errors.New("use of closed Encoder")
}
if p.err == nil {
n, p.err = p.w.Write(b)
}
return n, p.err
}
// WriteString implements io.StringWriter
func (p *printer) WriteString(s string) (n int, err error) {
if p.closed && p.err == nil {
p.err = errors.New("use of closed Encoder")
}
if p.err == nil {
n, p.err = p.w.WriteString(s)
}
return n, p.err
}
// WriteByte implements io.ByteWriter
func (p *printer) WriteByte(c byte) error {
if p.closed && p.err == nil {
p.err = errors.New("use of closed Encoder")
}
if p.err == nil {
p.err = p.w.WriteByte(c)
}
return p.err
}
// Close the Encoder, indicating that no more data will be written. It flushes
// any buffered XML to the underlying writer and returns an error if the
// written XML is invalid (e.g. by containing unclosed elements).
func (p *printer) Close() error {
if p.closed {
return nil
}
p.closed = true
if err := p.w.Flush(); err != nil {
return err
}
if len(p.tags) > 0 {
return fmt.Errorf("unclosed tag <%s>", p.tags[len(p.tags)-1].Local)
}
return nil
}
// return the bufio Writer's cached write error
func (p *printer) cachedWriteError() error {
_, err := p.Write(nil)
return err
}
func (p *printer) writeIndent(depthDelta int) {
if len(p.prefix) == 0 && len(p.indent) == 0 {
return
}
if depthDelta < 0 {
p.depth--
if p.indentedIn {
p.indentedIn = false
return
}
p.indentedIn = false
}
if p.putNewline {
p.WriteByte('\n')
} else {
p.putNewline = true
}
if len(p.prefix) > 0 {
p.WriteString(p.prefix)
}
if len(p.indent) > 0 {
for i := 0; i < p.depth; i++ {
p.WriteString(p.indent)
}
}
if depthDelta > 0 {
p.depth++
p.indentedIn = true
}
}
type parentStack struct {
p *printer
stack []string
}
// trim updates the XML context to match the longest common prefix of the stack
// and the given parents. A closing tag will be written for every parent
// popped. Passing an empty slice or nil closes all the elements.
func (s *parentStack) trim(parents []string) error {
split := 0
for ; split < len(parents) && split < len(s.stack); split++ {
if parents[split] != s.stack[split] {
break
}
}
for i := len(s.stack) - 1; i >= split; i-- {
if err := s.p.writeEnd(Name{Local: s.stack[i]}); err != nil {
return err
}
}
s.stack = s.stack[:split]
return nil
}
// push adds parent elements to the stack and writes open tags.
func (s *parentStack) push(parents []string) error {
for i := 0; i < len(parents); i++ {
if err := s.p.writeStart(&StartElement{Name: Name{Local: parents[i]}}); err != nil {
return err
}
}
s.stack = append(s.stack, parents...)
return nil
}
// UnsupportedTypeError is returned when [Marshal] encounters a type
// that cannot be converted into XML.
type UnsupportedTypeError struct {
Type reflect.Type
}
func (e *UnsupportedTypeError) Error() string {
return "xml: unsupported type: " + e.Type.String()
}
func isEmptyValue(v reflect.Value) bool {
switch v.Kind() {
case reflect.Array, reflect.Map, reflect.Slice, reflect.String:
return v.Len() == 0
case reflect.Bool,
reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr,
reflect.Float32, reflect.Float64,
reflect.Interface, reflect.Pointer:
return v.IsZero()
}
return false
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package xml
import (
"bytes"
"encoding"
"errors"
"fmt"
"reflect"
"runtime"
"strconv"
"strings"
)
// BUG(rsc): Mapping between XML elements and data structures is inherently flawed:
// an XML element is an order-dependent collection of anonymous
// values, while a data structure is an order-independent collection
// of named values.
// See [encoding/json] for a textual representation more suitable
// to data structures.
// Unmarshal parses the XML-encoded data and stores the result in
// the value pointed to by v, which must be an arbitrary struct,
// slice, or string. Well-formed data that does not fit into v is
// discarded.
//
// Because Unmarshal uses the reflect package, it can only assign
// to exported (upper case) fields. Unmarshal uses a case-sensitive
// comparison to match XML element names to tag values and struct
// field names.
//
// Unmarshal maps an XML element to a struct using the following rules.
// In the rules, the tag of a field refers to the value associated with the
// key 'xml' in the struct field's tag (see the example above).
//
// - If the struct has a field of type []byte or string with tag
// ",innerxml", Unmarshal accumulates the raw XML nested inside the
// element in that field. The rest of the rules still apply.
//
// - If the struct has a field named XMLName of type Name,
// Unmarshal records the element name in that field.
//
// - If the XMLName field has an associated tag of the form
// "name" or "namespace-URL name", the XML element must have
// the given name (and, optionally, name space) or else Unmarshal
// returns an error.
//
// - If the XML element has an attribute whose name matches a
// struct field name with an associated tag containing ",attr" or
// the explicit name in a struct field tag of the form "name,attr",
// Unmarshal records the attribute value in that field.
//
// - If the XML element has an attribute not handled by the previous
// rule and the struct has a field with an associated tag containing
// ",any,attr", Unmarshal records the attribute value in the first
// such field.
//
// - If the XML element contains character data, that data is
// accumulated in the first struct field that has tag ",chardata".
// The struct field may have type []byte or string.
// If there is no such field, the character data is discarded.
//
// - If the XML element contains comments, they are accumulated in
// the first struct field that has tag ",comment". The struct
// field may have type []byte or string. If there is no such
// field, the comments are discarded.
//
// - If the XML element contains a sub-element whose name matches
// the prefix of a tag formatted as "a" or "a>b>c", unmarshal
// will descend into the XML structure looking for elements with the
// given names, and will map the innermost elements to that struct
// field. A tag starting with ">" is equivalent to one starting
// with the field name followed by ">".
//
// - If the XML element contains a sub-element whose name matches
// a struct field's XMLName tag and the struct field has no
// explicit name tag as per the previous rule, unmarshal maps
// the sub-element to that struct field.
//
// - If the XML element contains a sub-element whose name matches a
// field without any mode flags (",attr", ",chardata", etc), Unmarshal
// maps the sub-element to that struct field.
//
// - If the XML element contains a sub-element that hasn't matched any
// of the above rules and the struct has a field with tag ",any",
// unmarshal maps the sub-element to that struct field.
//
// - An anonymous struct field is handled as if the fields of its
// value were part of the outer struct.
//
// - A struct field with tag "-" is never unmarshaled into.
//
// If Unmarshal encounters a field type that implements the Unmarshaler
// interface, Unmarshal calls its UnmarshalXML method to produce the value from
// the XML element. Otherwise, if the value implements
// [encoding.TextUnmarshaler], Unmarshal calls that value's UnmarshalText method.
//
// Unmarshal maps an XML element to a string or []byte by saving the
// concatenation of that element's character data in the string or
// []byte. The saved []byte is never nil.
//
// Unmarshal maps an attribute value to a string or []byte by saving
// the value in the string or slice.
//
// Unmarshal maps an attribute value to an [Attr] by saving the attribute,
// including its name, in the Attr.
//
// Unmarshal maps an XML element or attribute value to a slice by
// extending the length of the slice and mapping the element or attribute
// to the newly created value.
//
// Unmarshal maps an XML element or attribute value to a bool by
// setting it to the boolean value represented by the string. Whitespace
// is trimmed and ignored.
//
// Unmarshal maps an XML element or attribute value to an integer or
// floating-point field by setting the field to the result of
// interpreting the string value in decimal. There is no check for
// overflow. Whitespace is trimmed and ignored.
//
// Unmarshal maps an XML element to a Name by recording the element
// name.
//
// Unmarshal maps an XML element to a pointer by setting the pointer
// to a freshly allocated value and then mapping the element to that value.
//
// A missing element or empty attribute value will be unmarshaled as a zero value.
// If the field is a slice, a zero value will be appended to the field. Otherwise, the
// field will be set to its zero value.
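//
// A compact sketch (the struct and the input literal are illustrative only):
//
//	type Result struct {
//		XMLName Name   `xml:"result"`
//		Code    int    `xml:"code,attr"`
//		Msg     string `xml:"msg"`
//	}
//	var r Result
//	err := Unmarshal([]byte(`<result code="7"><msg>ok</msg></result>`), &r)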
func Unmarshal(data []byte, v any) error {
return NewDecoder(bytes.NewReader(data)).Decode(v)
}
// Decode works like [Unmarshal], except it reads the decoder
// stream to find the start element.
func (d *Decoder) Decode(v any) error {
return d.DecodeElement(v, nil)
}
// DecodeElement works like [Unmarshal] except that it takes
// a pointer to the start XML element to decode into v.
// It is useful when a client reads some raw XML tokens itself
// but also wants to defer to [Unmarshal] for some elements.
func (d *Decoder) DecodeElement(v any, start *StartElement) error {
val := reflect.ValueOf(v)
if val.Kind() != reflect.Pointer {
return errors.New("non-pointer passed to Unmarshal")
}
if val.IsNil() {
return errors.New("nil pointer passed to Unmarshal")
}
return d.unmarshal(val.Elem(), start, 0)
}
// An UnmarshalError represents an error in the unmarshaling process.
type UnmarshalError string
func (e UnmarshalError) Error() string { return string(e) }
// Unmarshaler is the interface implemented by objects that can unmarshal
// an XML element description of themselves.
//
// UnmarshalXML decodes a single XML element
// beginning with the given start element.
// If it returns an error, the outer call to Unmarshal stops and
// returns that error.
// UnmarshalXML must consume exactly one XML element.
// One common implementation strategy is to unmarshal into
// a separate value with a layout matching the expected XML
// using d.DecodeElement, and then to copy the data from
// that value into the receiver.
// Another common strategy is to use d.Token to process the
// XML object one token at a time.
// UnmarshalXML may not use d.RawToken.
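//
// A hedged sketch using DecodeElement into an auxiliary value
// (the Celsius type is hypothetical):
//
//	func (c *Celsius) UnmarshalXML(d *Decoder, start StartElement) error {
//		var s string
//		if err := d.DecodeElement(&s, &start); err != nil {
//			return err
//		}
//		v, err := strconv.ParseFloat(strings.TrimSpace(s), 64)
//		if err != nil {
//			return err
//		}
//		*c = Celsius(v)
//		return nil
//	}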
type Unmarshaler interface {
UnmarshalXML(d *Decoder, start StartElement) error
}
// UnmarshalerAttr is the interface implemented by objects that can unmarshal
// an XML attribute description of themselves.
//
// UnmarshalXMLAttr decodes a single XML attribute.
// If it returns an error, the outer call to [Unmarshal] stops and
// returns that error.
// UnmarshalXMLAttr is used only for struct fields with the
// "attr" option in the field tag.
type UnmarshalerAttr interface {
UnmarshalXMLAttr(attr Attr) error
}
// receiverType returns the receiver type to use in an expression like "%s.MethodName".
func receiverType(val any) string {
t := reflect.TypeOf(val)
if t.Name() != "" {
return t.String()
}
return "(" + t.String() + ")"
}
// unmarshalInterface unmarshals a single XML element into val.
// start is the opening tag of the element.
func (d *Decoder) unmarshalInterface(val Unmarshaler, start *StartElement) error {
// Record that decoder must stop at end tag corresponding to start.
d.pushEOF()
d.unmarshalDepth++
err := val.UnmarshalXML(d, *start)
d.unmarshalDepth--
if err != nil {
d.popEOF()
return err
}
if !d.popEOF() {
return fmt.Errorf("xml: %s.UnmarshalXML did not consume entire <%s> element", receiverType(val), start.Name.Local)
}
return nil
}
// unmarshalTextInterface unmarshals a single XML element into val.
// The chardata contained in the element (but not its children)
// is passed to the text unmarshaler.
func (d *Decoder) unmarshalTextInterface(val encoding.TextUnmarshaler) error {
var buf []byte
depth := 1
for depth > 0 {
t, err := d.Token()
if err != nil {
return err
}
switch t := t.(type) {
case CharData:
if depth == 1 {
buf = append(buf, t...)
}
case StartElement:
depth++
case EndElement:
depth--
}
}
return val.UnmarshalText(buf)
}
// unmarshalAttr unmarshals a single XML attribute into val.
func (d *Decoder) unmarshalAttr(val reflect.Value, attr Attr) error {
if val.Kind() == reflect.Pointer {
if val.IsNil() {
val.Set(reflect.New(val.Type().Elem()))
}
val = val.Elem()
}
if val.CanInterface() && val.Type().Implements(unmarshalerAttrType) {
// This is an unmarshaler with a non-pointer receiver,
// so it's likely to be incorrect, but we do what we're told.
return val.Interface().(UnmarshalerAttr).UnmarshalXMLAttr(attr)
}
if val.CanAddr() {
pv := val.Addr()
if pv.CanInterface() && pv.Type().Implements(unmarshalerAttrType) {
return pv.Interface().(UnmarshalerAttr).UnmarshalXMLAttr(attr)
}
}
// Not an UnmarshalerAttr; try encoding.TextUnmarshaler.
if val.CanInterface() && val.Type().Implements(textUnmarshalerType) {
// This is an unmarshaler with a non-pointer receiver,
// so it's likely to be incorrect, but we do what we're told.
return val.Interface().(encoding.TextUnmarshaler).UnmarshalText([]byte(attr.Value))
}
if val.CanAddr() {
pv := val.Addr()
if pv.CanInterface() && pv.Type().Implements(textUnmarshalerType) {
return pv.Interface().(encoding.TextUnmarshaler).UnmarshalText([]byte(attr.Value))
}
}
if val.Kind() == reflect.Slice && val.Type().Elem().Kind() != reflect.Uint8 {
// Slice of element values.
// Grow slice.
n := val.Len()
val.Grow(1)
val.SetLen(n + 1)
// Recur to read element into slice.
if err := d.unmarshalAttr(val.Index(n), attr); err != nil {
val.SetLen(n)
return err
}
return nil
}
if val.Type() == attrType {
val.Set(reflect.ValueOf(attr))
return nil
}
return copyValue(val, []byte(attr.Value))
}
var (
attrType = reflect.TypeFor[Attr]()
unmarshalerType = reflect.TypeFor[Unmarshaler]()
unmarshalerAttrType = reflect.TypeFor[UnmarshalerAttr]()
textUnmarshalerType = reflect.TypeFor[encoding.TextUnmarshaler]()
)
const (
maxUnmarshalDepth = 10000
maxUnmarshalDepthWasm = 5000 // go.dev/issue/56498
)
var errUnmarshalDepth = errors.New("exceeded max depth")
// Unmarshal a single XML element into val.
func (d *Decoder) unmarshal(val reflect.Value, start *StartElement, depth int) error {
if depth >= maxUnmarshalDepth || runtime.GOARCH == "wasm" && depth >= maxUnmarshalDepthWasm {
return errUnmarshalDepth
}
// Find start element if we need it.
if start == nil {
for {
tok, err := d.Token()
if err != nil {
return err
}
if t, ok := tok.(StartElement); ok {
start = &t
break
}
}
}
// Load value from interface, but only if the result will be
// usefully addressable.
if val.Kind() == reflect.Interface && !val.IsNil() {
e := val.Elem()
if e.Kind() == reflect.Pointer && !e.IsNil() {
val = e
}
}
if val.Kind() == reflect.Pointer {
if val.IsNil() {
val.Set(reflect.New(val.Type().Elem()))
}
val = val.Elem()
}
if val.CanInterface() && val.Type().Implements(unmarshalerType) {
// This is an unmarshaler with a non-pointer receiver,
// so it's likely to be incorrect, but we do what we're told.
return d.unmarshalInterface(val.Interface().(Unmarshaler), start)
}
if val.CanAddr() {
pv := val.Addr()
if pv.CanInterface() && pv.Type().Implements(unmarshalerType) {
return d.unmarshalInterface(pv.Interface().(Unmarshaler), start)
}
}
if val.CanInterface() && val.Type().Implements(textUnmarshalerType) {
return d.unmarshalTextInterface(val.Interface().(encoding.TextUnmarshaler))
}
if val.CanAddr() {
pv := val.Addr()
if pv.CanInterface() && pv.Type().Implements(textUnmarshalerType) {
return d.unmarshalTextInterface(pv.Interface().(encoding.TextUnmarshaler))
}
}
var (
data []byte
saveData reflect.Value
comment []byte
saveComment reflect.Value
saveXML reflect.Value
saveXMLIndex int
saveXMLData []byte
saveAny reflect.Value
sv reflect.Value
tinfo *typeInfo
err error
)
switch v := val; v.Kind() {
default:
return errors.New("unknown type " + v.Type().String())
case reflect.Interface:
// TODO: For now, simply ignore the field. In the near
// future we may choose to unmarshal the start
// element on it, if not nil.
return d.Skip()
case reflect.Slice:
typ := v.Type()
if typ.Elem().Kind() == reflect.Uint8 {
// []byte
saveData = v
break
}
// Slice of element values.
// Grow slice.
n := v.Len()
v.Grow(1)
v.SetLen(n + 1)
// Recur to read element into slice.
if err := d.unmarshal(v.Index(n), start, depth+1); err != nil {
v.SetLen(n)
return err
}
return nil
case reflect.Bool, reflect.Float32, reflect.Float64, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr, reflect.String:
saveData = v
case reflect.Struct:
typ := v.Type()
if typ == nameType {
v.Set(reflect.ValueOf(start.Name))
break
}
sv = v
tinfo, err = getTypeInfo(typ)
if err != nil {
return err
}
// Validate and assign element name.
if tinfo.xmlname != nil {
finfo := tinfo.xmlname
if finfo.name != "" && finfo.name != start.Name.Local {
return UnmarshalError("expected element type <" + finfo.name + "> but have <" + start.Name.Local + ">")
}
if finfo.xmlns != "" && finfo.xmlns != start.Name.Space {
e := "expected element <" + finfo.name + "> in name space " + finfo.xmlns + " but have "
if start.Name.Space == "" {
e += "no name space"
} else {
e += start.Name.Space
}
return UnmarshalError(e)
}
fv := finfo.value(sv, initNilPointers)
if _, ok := fv.Interface().(Name); ok {
fv.Set(reflect.ValueOf(start.Name))
}
}
// Assign attributes.
for _, a := range start.Attr {
handled := false
any := -1
for i := range tinfo.fields {
finfo := &tinfo.fields[i]
switch finfo.flags & fMode {
case fAttr:
strv := finfo.value(sv, initNilPointers)
if a.Name.Local == finfo.name && (finfo.xmlns == "" || finfo.xmlns == a.Name.Space) {
if err := d.unmarshalAttr(strv, a); err != nil {
return err
}
handled = true
}
case fAny | fAttr:
if any == -1 {
any = i
}
}
}
if !handled && any >= 0 {
finfo := &tinfo.fields[any]
strv := finfo.value(sv, initNilPointers)
if err := d.unmarshalAttr(strv, a); err != nil {
return err
}
}
}
// Determine whether we need to save character data or comments.
for i := range tinfo.fields {
finfo := &tinfo.fields[i]
switch finfo.flags & fMode {
case fCDATA, fCharData:
if !saveData.IsValid() {
saveData = finfo.value(sv, initNilPointers)
}
case fComment:
if !saveComment.IsValid() {
saveComment = finfo.value(sv, initNilPointers)
}
case fAny, fAny | fElement:
if !saveAny.IsValid() {
saveAny = finfo.value(sv, initNilPointers)
}
case fInnerXML:
if !saveXML.IsValid() {
saveXML = finfo.value(sv, initNilPointers)
if d.saved == nil {
saveXMLIndex = 0
d.saved = new(bytes.Buffer)
} else {
saveXMLIndex = d.savedOffset()
}
}
}
}
}
// Find end element.
// Process sub-elements along the way.
Loop:
for {
var savedOffset int
if saveXML.IsValid() {
savedOffset = d.savedOffset()
}
tok, err := d.Token()
if err != nil {
return err
}
switch t := tok.(type) {
case StartElement:
consumed := false
if sv.IsValid() {
// unmarshalPath can call unmarshal, so we need to pass the depth through so that
// we can continue to enforce the maximum recursion limit.
consumed, err = d.unmarshalPath(tinfo, sv, nil, &t, depth)
if err != nil {
return err
}
if !consumed && saveAny.IsValid() {
consumed = true
if err := d.unmarshal(saveAny, &t, depth+1); err != nil {
return err
}
}
}
if !consumed {
if err := d.Skip(); err != nil {
return err
}
}
case EndElement:
if saveXML.IsValid() {
saveXMLData = d.saved.Bytes()[saveXMLIndex:savedOffset]
if saveXMLIndex == 0 {
d.saved = nil
}
}
break Loop
case CharData:
if saveData.IsValid() {
data = append(data, t...)
}
case Comment:
if saveComment.IsValid() {
comment = append(comment, t...)
}
}
}
if saveData.IsValid() && saveData.CanInterface() && saveData.Type().Implements(textUnmarshalerType) {
if err := saveData.Interface().(encoding.TextUnmarshaler).UnmarshalText(data); err != nil {
return err
}
saveData = reflect.Value{}
}
if saveData.IsValid() && saveData.CanAddr() {
pv := saveData.Addr()
if pv.CanInterface() && pv.Type().Implements(textUnmarshalerType) {
if err := pv.Interface().(encoding.TextUnmarshaler).UnmarshalText(data); err != nil {
return err
}
saveData = reflect.Value{}
}
}
if err := copyValue(saveData, data); err != nil {
return err
}
switch t := saveComment; t.Kind() {
case reflect.String:
t.SetString(string(comment))
case reflect.Slice:
t.Set(reflect.ValueOf(comment))
}
switch t := saveXML; t.Kind() {
case reflect.String:
t.SetString(string(saveXMLData))
case reflect.Slice:
if t.Type().Elem().Kind() == reflect.Uint8 {
t.Set(reflect.ValueOf(saveXMLData))
}
}
return nil
}
func copyValue(dst reflect.Value, src []byte) (err error) {
dst0 := dst
if dst.Kind() == reflect.Pointer {
if dst.IsNil() {
dst.Set(reflect.New(dst.Type().Elem()))
}
dst = dst.Elem()
}
// Save accumulated data.
switch dst.Kind() {
case reflect.Invalid:
// Probably a comment.
default:
return errors.New("cannot unmarshal into " + dst0.Type().String())
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
if len(src) == 0 {
dst.SetInt(0)
return nil
}
itmp, err := strconv.ParseInt(strings.TrimSpace(string(src)), 10, dst.Type().Bits())
if err != nil {
return err
}
dst.SetInt(itmp)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
if len(src) == 0 {
dst.SetUint(0)
return nil
}
utmp, err := strconv.ParseUint(strings.TrimSpace(string(src)), 10, dst.Type().Bits())
if err != nil {
return err
}
dst.SetUint(utmp)
case reflect.Float32, reflect.Float64:
if len(src) == 0 {
dst.SetFloat(0)
return nil
}
ftmp, err := strconv.ParseFloat(strings.TrimSpace(string(src)), dst.Type().Bits())
if err != nil {
return err
}
dst.SetFloat(ftmp)
case reflect.Bool:
if len(src) == 0 {
dst.SetBool(false)
return nil
}
value, err := strconv.ParseBool(strings.TrimSpace(string(src)))
if err != nil {
return err
}
dst.SetBool(value)
case reflect.String:
dst.SetString(string(src))
case reflect.Slice:
if len(src) == 0 {
// non-nil to flag presence
src = []byte{}
}
dst.SetBytes(src)
}
return nil
}
// unmarshalPath walks down an XML structure looking for wanted
// paths, and calls unmarshal on them.
// The consumed result tells whether XML elements have been consumed
// from the Decoder until start's matching end element, or if it's
// still untouched because start is uninteresting for sv's fields.
func (d *Decoder) unmarshalPath(tinfo *typeInfo, sv reflect.Value, parents []string, start *StartElement, depth int) (consumed bool, err error) {
recurse := false
Loop:
for i := range tinfo.fields {
finfo := &tinfo.fields[i]
if finfo.flags&fElement == 0 || len(finfo.parents) < len(parents) || finfo.xmlns != "" && finfo.xmlns != start.Name.Space {
continue
}
for j := range parents {
if parents[j] != finfo.parents[j] {
continue Loop
}
}
if len(finfo.parents) == len(parents) && finfo.name == start.Name.Local {
// It's a perfect match, unmarshal the field.
return true, d.unmarshal(finfo.value(sv, initNilPointers), start, depth+1)
}
if len(finfo.parents) > len(parents) && finfo.parents[len(parents)] == start.Name.Local {
// It's a prefix for the field. Break and recurse
// since it's not ok for one field path to be itself
// the prefix for another field path.
recurse = true
// We can reuse the same slice as long as we
// don't try to append to it.
parents = finfo.parents[:len(parents)+1]
break
}
}
if !recurse {
// We have no business with this element.
return false, nil
}
// The element is not a perfect match for any field, but one
// or more fields have the path to this element as a parent
// prefix. Recurse and attempt to match these.
for {
var tok Token
tok, err = d.Token()
if err != nil {
return true, err
}
switch t := tok.(type) {
case StartElement:
// the recursion depth of unmarshalPath is limited to the path length specified
// by the struct field tag, so we don't increment the depth here.
consumed2, err := d.unmarshalPath(tinfo, sv, parents, &t, depth)
if err != nil {
return true, err
}
if !consumed2 {
if err := d.Skip(); err != nil {
return true, err
}
}
case EndElement:
return true, nil
}
}
}
// Skip reads tokens until it has consumed the end element
// matching the most recent start element already consumed,
// skipping nested structures.
// It returns nil if it finds an end element matching the start
// element; otherwise it returns an error describing the problem.
func (d *Decoder) Skip() error {
var depth int64
for {
tok, err := d.Token()
if err != nil {
return err
}
switch tok.(type) {
case StartElement:
depth++
case EndElement:
if depth == 0 {
return nil
}
depth--
}
}
}
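// Usage sketch (illustrative only, not part of this package): a custom
// UnmarshalXML implementation commonly calls Skip to discard an element it
// does not care about. The Ignored type below is hypothetical.
//
//	type Ignored struct{}
//
//	func (Ignored) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error {
//		// Consume everything up to and including start's matching end element.
//		return d.Skip()
//	}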
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package xml
import (
"fmt"
"reflect"
"strings"
"sync"
)
// typeInfo holds details for the xml representation of a type.
type typeInfo struct {
xmlname *fieldInfo
fields []fieldInfo
}
// fieldInfo holds details for the xml representation of a single field.
type fieldInfo struct {
idx []int
name string
xmlns string
flags fieldFlags
parents []string
}
type fieldFlags int
const (
fElement fieldFlags = 1 << iota
fAttr
fCDATA
fCharData
fInnerXML
fComment
fAny
fOmitEmpty
fMode = fElement | fAttr | fCDATA | fCharData | fInnerXML | fComment | fAny
xmlName = "XMLName"
)
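// Illustrative sketch (hypothetical user-side type, not part of this package)
// of how `xml` struct tags map onto the flags above:
//
//	type Person struct {
//		XMLName xml.Name   `xml:"person"`       // recorded as typeInfo.xmlname
//		ID      int        `xml:"id,attr"`      // fAttr
//		Name    string     `xml:"name"`         // fElement
//		Bio     string     `xml:",chardata"`    // fCharData
//		Raw     string     `xml:",innerxml"`    // fInnerXML
//		Extra   []xml.Attr `xml:",any,attr"`    // fAny | fAttr
//		City    string     `xml:"address>city"` // fElement, parents ["address"]
//	}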
var tinfoMap sync.Map // map[reflect.Type]*typeInfo
var nameType = reflect.TypeFor[Name]()
// getTypeInfo returns the typeInfo structure with details necessary
// for marshaling and unmarshaling typ.
func getTypeInfo(typ reflect.Type) (*typeInfo, error) {
if ti, ok := tinfoMap.Load(typ); ok {
return ti.(*typeInfo), nil
}
tinfo := &typeInfo{}
if typ.Kind() == reflect.Struct && typ != nameType {
n := typ.NumField()
for i := 0; i < n; i++ {
f := typ.Field(i)
if (!f.IsExported() && !f.Anonymous) || f.Tag.Get("xml") == "-" {
continue // Private field
}
// For embedded structs, embed its fields.
if f.Anonymous {
t := f.Type
if t.Kind() == reflect.Pointer {
t = t.Elem()
}
if t.Kind() == reflect.Struct {
inner, err := getTypeInfo(t)
if err != nil {
return nil, err
}
if tinfo.xmlname == nil {
tinfo.xmlname = inner.xmlname
}
for _, finfo := range inner.fields {
finfo.idx = append([]int{i}, finfo.idx...)
if err := addFieldInfo(typ, tinfo, &finfo); err != nil {
return nil, err
}
}
continue
}
}
finfo, err := structFieldInfo(typ, &f)
if err != nil {
return nil, err
}
if f.Name == xmlName {
tinfo.xmlname = finfo
continue
}
// Add the field if it doesn't conflict with other fields.
if err := addFieldInfo(typ, tinfo, finfo); err != nil {
return nil, err
}
}
}
ti, _ := tinfoMap.LoadOrStore(typ, tinfo)
return ti.(*typeInfo), nil
}
// structFieldInfo builds and returns a fieldInfo for f.
func structFieldInfo(typ reflect.Type, f *reflect.StructField) (*fieldInfo, error) {
finfo := &fieldInfo{idx: f.Index}
// Split the tag from the xml namespace if necessary.
tag := f.Tag.Get("xml")
if ns, t, ok := strings.Cut(tag, " "); ok {
finfo.xmlns, tag = ns, t
}
// Parse flags.
tokens := strings.Split(tag, ",")
if len(tokens) == 1 {
finfo.flags = fElement
} else {
tag = tokens[0]
for _, flag := range tokens[1:] {
switch flag {
case "attr":
finfo.flags |= fAttr
case "cdata":
finfo.flags |= fCDATA
case "chardata":
finfo.flags |= fCharData
case "innerxml":
finfo.flags |= fInnerXML
case "comment":
finfo.flags |= fComment
case "any":
finfo.flags |= fAny
case "omitempty":
finfo.flags |= fOmitEmpty
}
}
// Validate the flags used.
valid := true
switch mode := finfo.flags & fMode; mode {
case 0:
finfo.flags |= fElement
case fAttr, fCDATA, fCharData, fInnerXML, fComment, fAny, fAny | fAttr:
if f.Name == xmlName || tag != "" && mode != fAttr {
valid = false
}
default:
// This will also catch multiple modes in a single field.
valid = false
}
if finfo.flags&fMode == fAny {
finfo.flags |= fElement
}
if finfo.flags&fOmitEmpty != 0 && finfo.flags&(fElement|fAttr) == 0 {
valid = false
}
if !valid {
return nil, fmt.Errorf("xml: invalid tag in field %s of type %s: %q",
f.Name, typ, f.Tag.Get("xml"))
}
}
// Use of xmlns without a name is not allowed.
if finfo.xmlns != "" && tag == "" {
return nil, fmt.Errorf("xml: namespace without name in field %s of type %s: %q",
f.Name, typ, f.Tag.Get("xml"))
}
if f.Name == xmlName {
// The XMLName field records the XML element name. Don't
// process it as usual because its name should default to
// empty rather than to the field name.
finfo.name = tag
return finfo, nil
}
if tag == "" {
// If the name part of the tag is completely empty, get
// default from XMLName of underlying struct if feasible,
// or field name otherwise.
if xmlname := lookupXMLName(f.Type); xmlname != nil {
finfo.xmlns, finfo.name = xmlname.xmlns, xmlname.name
} else {
finfo.name = f.Name
}
return finfo, nil
}
// Prepare field name and parents.
parents := strings.Split(tag, ">")
if parents[0] == "" {
parents[0] = f.Name
}
if parents[len(parents)-1] == "" {
return nil, fmt.Errorf("xml: trailing '>' in field %s of type %s", f.Name, typ)
}
finfo.name = parents[len(parents)-1]
if len(parents) > 1 {
if (finfo.flags & fElement) == 0 {
return nil, fmt.Errorf("xml: %s chain not valid with %s flag", tag, strings.Join(tokens[1:], ","))
}
finfo.parents = parents[:len(parents)-1]
}
// If the field type has an XMLName field, the names must match
// so that the behavior of both marshaling and unmarshaling
// is straightforward and unambiguous.
if finfo.flags&fElement != 0 {
ftyp := f.Type
xmlname := lookupXMLName(ftyp)
if xmlname != nil && xmlname.name != finfo.name {
return nil, fmt.Errorf("xml: name %q in tag of %s.%s conflicts with name %q in %s.XMLName",
finfo.name, typ, f.Name, xmlname.name, ftyp)
}
}
return finfo, nil
}
// lookupXMLName returns the fieldInfo for typ's XMLName field
// in case it exists and has a valid xml field tag, otherwise
// it returns nil.
func lookupXMLName(typ reflect.Type) (xmlname *fieldInfo) {
for typ.Kind() == reflect.Pointer {
typ = typ.Elem()
}
if typ.Kind() != reflect.Struct {
return nil
}
for i, n := 0, typ.NumField(); i < n; i++ {
f := typ.Field(i)
if f.Name != xmlName {
continue
}
finfo, err := structFieldInfo(typ, &f)
if err == nil && finfo.name != "" {
return finfo
}
// Also consider errors as a non-existent field tag
// and let getTypeInfo itself report the error.
break
}
return nil
}
// addFieldInfo adds finfo to tinfo.fields if there are no
// conflicts, or if conflicts arise from previous fields that were
// obtained from deeper embedded structures than finfo. In the latter
// case, the conflicting entries are dropped.
// A conflict occurs when the path (parent + name) to a field is
// itself a prefix of another path, or when two paths match exactly.
// It is okay for field paths to share a common, shorter prefix.
func addFieldInfo(typ reflect.Type, tinfo *typeInfo, newf *fieldInfo) error {
var conflicts []int
Loop:
// First, figure all conflicts. Most working code will have none.
for i := range tinfo.fields {
oldf := &tinfo.fields[i]
if oldf.flags&fMode != newf.flags&fMode {
continue
}
if oldf.xmlns != "" && newf.xmlns != "" && oldf.xmlns != newf.xmlns {
continue
}
minl := min(len(newf.parents), len(oldf.parents))
for p := 0; p < minl; p++ {
if oldf.parents[p] != newf.parents[p] {
continue Loop
}
}
if len(oldf.parents) > len(newf.parents) {
if oldf.parents[len(newf.parents)] == newf.name {
conflicts = append(conflicts, i)
}
} else if len(oldf.parents) < len(newf.parents) {
if newf.parents[len(oldf.parents)] == oldf.name {
conflicts = append(conflicts, i)
}
} else {
if newf.name == oldf.name && newf.xmlns == oldf.xmlns {
conflicts = append(conflicts, i)
}
}
}
// Without conflicts, add the new field and return.
if conflicts == nil {
tinfo.fields = append(tinfo.fields, *newf)
return nil
}
// If any conflict is shallower, ignore the new field.
// This matches the Go field resolution on embedding.
for _, i := range conflicts {
if len(tinfo.fields[i].idx) < len(newf.idx) {
return nil
}
}
// Otherwise, if any of them is at the same depth level, it's an error.
for _, i := range conflicts {
oldf := &tinfo.fields[i]
if len(oldf.idx) == len(newf.idx) {
f1 := typ.FieldByIndex(oldf.idx)
f2 := typ.FieldByIndex(newf.idx)
return &TagPathError{typ, f1.Name, f1.Tag.Get("xml"), f2.Name, f2.Tag.Get("xml")}
}
}
// Otherwise, the new field is shallower, and thus takes precedence,
// so drop the conflicting fields from tinfo and append the new one.
for c := len(conflicts) - 1; c >= 0; c-- {
i := conflicts[c]
copy(tinfo.fields[i:], tinfo.fields[i+1:])
tinfo.fields = tinfo.fields[:len(tinfo.fields)-1]
}
tinfo.fields = append(tinfo.fields, *newf)
return nil
}
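// Illustrative sketch (hypothetical types, not part of this package): with
// embedding, the shallower field wins, mirroring Go's own field resolution.
//
//	type Base struct {
//		Name string `xml:"name"` // idx [0 0] after embedding; dropped
//	}
//	type Outer struct {
//		Base
//		Name string `xml:"name"` // idx [1]; kept
//	}
//
// Two conflicting fields at the same depth would instead produce a
// *TagPathError.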
// A TagPathError represents an error in the unmarshaling process
// caused by the use of field tags with conflicting paths.
type TagPathError struct {
Struct reflect.Type
Field1, Tag1 string
Field2, Tag2 string
}
func (e *TagPathError) Error() string {
return fmt.Sprintf("%s field %q with tag %q conflicts with field %q with tag %q", e.Struct, e.Field1, e.Tag1, e.Field2, e.Tag2)
}
const (
initNilPointers = true
dontInitNilPointers = false
)
// value returns v's field value corresponding to finfo.
// It's equivalent to v.FieldByIndex(finfo.idx), but when passed
// initNilPointers, it initializes and dereferences pointers as necessary.
// When passed dontInitNilPointers and a nil pointer is reached, the function
// returns a zero reflect.Value.
func (finfo *fieldInfo) value(v reflect.Value, shouldInitNilPointers bool) reflect.Value {
for i, x := range finfo.idx {
if i > 0 {
t := v.Type()
if t.Kind() == reflect.Pointer && t.Elem().Kind() == reflect.Struct {
if v.IsNil() {
if !shouldInitNilPointers {
return reflect.Value{}
}
v.Set(reflect.New(v.Type().Elem()))
}
v = v.Elem()
}
}
v = v.Field(x)
}
return v
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package xml implements a simple XML 1.0 parser that
// understands XML name spaces.
package xml
// References:
// Annotated XML spec: https://www.xml.com/axml/testaxml.htm
// XML name spaces: https://www.w3.org/TR/REC-xml-names/
import (
"bufio"
"bytes"
"errors"
"fmt"
"io"
"strconv"
"strings"
"unicode"
"unicode/utf8"
)
// A SyntaxError represents a syntax error in the XML input stream.
type SyntaxError struct {
Msg string
Line int
}
func (e *SyntaxError) Error() string {
return "XML syntax error on line " + strconv.Itoa(e.Line) + ": " + e.Msg
}
// A Name represents an XML name (Local) annotated
// with a name space identifier (Space).
// In tokens returned by [Decoder.Token], the Space identifier
// is given as a canonical URL, not the short prefix used
// in the document being parsed.
type Name struct {
Space, Local string
}
// An Attr represents an attribute in an XML element (Name=Value).
type Attr struct {
Name Name
Value string
}
// A Token is an interface holding one of the token types:
// [StartElement], [EndElement], [CharData], [Comment], [ProcInst], or [Directive].
type Token any
// A StartElement represents an XML start element.
type StartElement struct {
Name Name
Attr []Attr
}
// Copy creates a new copy of StartElement.
func (e StartElement) Copy() StartElement {
attrs := make([]Attr, len(e.Attr))
copy(attrs, e.Attr)
e.Attr = attrs
return e
}
// End returns the corresponding XML end element.
func (e StartElement) End() EndElement {
return EndElement{e.Name}
}
// An EndElement represents an XML end element.
type EndElement struct {
Name Name
}
// A CharData represents XML character data (raw text),
// in which XML escape sequences have been replaced by
// the characters they represent.
type CharData []byte
// Copy creates a new copy of CharData.
func (c CharData) Copy() CharData { return CharData(bytes.Clone(c)) }
// A Comment represents an XML comment of the form <!--comment-->.
// The bytes do not include the <!-- and --> comment markers.
type Comment []byte
// Copy creates a new copy of Comment.
func (c Comment) Copy() Comment { return Comment(bytes.Clone(c)) }
// A ProcInst represents an XML processing instruction of the form <?target inst?>
type ProcInst struct {
Target string
Inst []byte
}
// Copy creates a new copy of ProcInst.
func (p ProcInst) Copy() ProcInst {
p.Inst = bytes.Clone(p.Inst)
return p
}
// A Directive represents an XML directive of the form <!text>.
// The bytes do not include the <! and > markers.
type Directive []byte
// Copy creates a new copy of Directive.
func (d Directive) Copy() Directive { return Directive(bytes.Clone(d)) }
// CopyToken returns a copy of a Token.
func CopyToken(t Token) Token {
switch v := t.(type) {
case CharData:
return v.Copy()
case Comment:
return v.Copy()
case Directive:
return v.Copy()
case ProcInst:
return v.Copy()
case StartElement:
return v.Copy()
}
return t
}
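// Usage sketch (illustrative only): CharData and Comment tokens alias the
// decoder's internal buffer and remain valid only until the next call to
// Token, so copy them before retaining them. The dec variable is hypothetical.
//
//	var kept []xml.Token
//	for {
//		tok, err := dec.Token()
//		if err != nil {
//			break
//		}
//		kept = append(kept, xml.CopyToken(tok)) // safe to keep across Token calls
//	}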
// A TokenReader is anything that can decode a stream of XML tokens, including a
// [Decoder].
//
// When Token encounters an error or end-of-file condition after successfully
// reading a token, it returns the token. It may return the (non-nil) error from
// the same call or return the error (and a nil token) from a subsequent call.
// An instance of this general case is that a TokenReader returning a non-nil
// token at the end of the token stream may return either io.EOF or a nil error.
// The next call to Token should return nil, [io.EOF].
//
// Implementations of Token are discouraged from returning a nil token with a
// nil error. Callers should treat a return of nil, nil as indicating that
// nothing happened; in particular it does not indicate EOF.
type TokenReader interface {
Token() (Token, error)
}
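// A minimal sketch of a custom TokenReader (not part of this package): it
// wraps another TokenReader and drops Comment tokens. The commentFilter name
// is hypothetical; such a value can be passed to NewTokenDecoder.
//
//	type commentFilter struct{ tr xml.TokenReader }
//
//	func (f commentFilter) Token() (xml.Token, error) {
//		for {
//			tok, err := f.tr.Token()
//			if err != nil || tok == nil {
//				return tok, err
//			}
//			if _, isComment := tok.(xml.Comment); !isComment {
//				return tok, nil
//			}
//		}
//	}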
// A Decoder represents an XML parser reading a particular input stream.
// The parser assumes that its input is encoded in UTF-8.
type Decoder struct {
// Strict defaults to true, enforcing the requirements
// of the XML specification.
// If set to false, the parser allows input containing common
// mistakes:
// * If an element is missing an end tag, the parser invents
// end tags as necessary to keep the return values from Token
// properly balanced.
// * In attribute values and character data, unknown or malformed
// character entities (sequences beginning with &) are left alone.
//
// Setting:
//
// d.Strict = false
// d.AutoClose = xml.HTMLAutoClose
// d.Entity = xml.HTMLEntity
//
// creates a parser that can handle typical HTML.
//
// Strict mode does not enforce the requirements of the XML name spaces TR.
// In particular it does not reject name space tags using undefined prefixes.
// Such tags are recorded with the unknown prefix as the name space URL.
Strict bool
// When Strict == false, AutoClose indicates a set of elements to
// consider closed immediately after they are opened, regardless
// of whether an end element is present.
AutoClose []string
// Entity can be used to map non-standard entity names to string replacements.
// The parser behaves as if these standard mappings are present in the map,
// regardless of the actual map content:
//
// "lt": "<",
// "gt": ">",
// "amp": "&",
// "apos": "'",
// "quot": `"`,
Entity map[string]string
// CharsetReader, if non-nil, defines a function to generate
// charset-conversion readers, converting from the provided
// non-UTF-8 charset into UTF-8. If CharsetReader is nil or
// returns an error, parsing stops with an error. One of the
// CharsetReader's result values must be non-nil.
CharsetReader func(charset string, input io.Reader) (io.Reader, error)
// DefaultSpace sets the default name space used for unadorned tags,
// as if the entire XML stream were wrapped in an element containing
// the attribute xmlns="DefaultSpace".
DefaultSpace string
r io.ByteReader
t TokenReader
buf bytes.Buffer
saved *bytes.Buffer
stk *stack
free *stack
needClose bool
toClose Name
nextToken Token
nextByte int
ns map[string]string
err error
line int
linestart int64
offset int64
unmarshalDepth int
}
// NewDecoder creates a new XML parser reading from r.
// If r does not implement [io.ByteReader], NewDecoder will
// do its own buffering.
func NewDecoder(r io.Reader) *Decoder {
d := &Decoder{
ns: make(map[string]string),
nextByte: -1,
line: 1,
Strict: true,
}
d.switchToReader(r)
return d
}
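// Configuration sketch (illustrative only): combining the settings described
// in the Strict field documentation yields a decoder that tolerates typical
// HTML. The page variable is a hypothetical HTML string.
//
//	d := xml.NewDecoder(strings.NewReader(page))
//	d.Strict = false
//	d.AutoClose = xml.HTMLAutoClose
//	d.Entity = xml.HTMLEntity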
// NewTokenDecoder creates a new XML parser using an underlying token stream.
func NewTokenDecoder(t TokenReader) *Decoder {
// Is it already a Decoder?
if d, ok := t.(*Decoder); ok {
return d
}
d := &Decoder{
ns: make(map[string]string),
t: t,
nextByte: -1,
line: 1,
Strict: true,
}
return d
}
// Token returns the next XML token in the input stream.
// At the end of the input stream, Token returns nil, [io.EOF].
//
// Slices of bytes in the returned token data refer to the
// parser's internal buffer and remain valid only until the next
// call to Token. To acquire a copy of the bytes, call [CopyToken]
// or the token's Copy method.
//
// Token expands self-closing elements such as <br>
// into separate start and end elements returned by successive calls.
//
// Token guarantees that the [StartElement] and [EndElement]
// tokens it returns are properly nested and matched:
// if Token encounters an unexpected end element
// or EOF before all expected end elements,
// it will return an error.
//
// If [Decoder.CharsetReader] is called and returns an error,
// the error is wrapped and returned.
//
// Token implements XML name spaces as described by
// https://www.w3.org/TR/REC-xml-names/. Each of the
// [Name] structures contained in the Token has the Space
// set to the URL identifying its name space when known.
// If Token encounters an unrecognized name space prefix,
// it uses the prefix as the Space rather than report an error.
func (d *Decoder) Token() (Token, error) {
var t Token
var err error
if d.stk != nil && d.stk.kind == stkEOF {
return nil, io.EOF
}
if d.nextToken != nil {
t = d.nextToken
d.nextToken = nil
} else {
if t, err = d.rawToken(); t == nil && err != nil {
if err == io.EOF && d.stk != nil && d.stk.kind != stkEOF {
err = d.syntaxError("unexpected EOF")
}
return nil, err
}
// We still have a token to process, so clear any
// errors (e.g. EOF) and proceed.
err = nil
}
if !d.Strict {
if t1, ok := d.autoClose(t); ok {
d.nextToken = t
t = t1
}
}
switch t1 := t.(type) {
case StartElement:
// In XML name spaces, the translations listed in the
// attributes apply to the element name and
// to the other attribute names, so process
// the translations first.
for _, a := range t1.Attr {
if a.Name.Space == xmlnsPrefix {
v, ok := d.ns[a.Name.Local]
d.pushNs(a.Name.Local, v, ok)
d.ns[a.Name.Local] = a.Value
}
if a.Name.Space == "" && a.Name.Local == xmlnsPrefix {
// Default space for untagged names
v, ok := d.ns[""]
d.pushNs("", v, ok)
d.ns[""] = a.Value
}
}
d.pushElement(t1.Name)
d.translate(&t1.Name, true)
for i := range t1.Attr {
d.translate(&t1.Attr[i].Name, false)
}
t = t1
case EndElement:
if !d.popElement(&t1) {
return nil, d.err
}
t = t1
}
return t, err
}
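// Usage sketch (illustrative only): a typical token loop reads until io.EOF
// and switches on the concrete token type. The dec variable is hypothetical.
//
//	for {
//		tok, err := dec.Token()
//		if err == io.EOF {
//			break
//		}
//		if err != nil {
//			return err
//		}
//		switch t := tok.(type) {
//		case xml.StartElement:
//			fmt.Println("start:", t.Name.Local)
//		case xml.EndElement:
//			fmt.Println("end:", t.Name.Local)
//		case xml.CharData:
//			_ = string(t) // copy before retaining; see CopyToken
//		}
//	}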
const (
xmlURL = "http://www.w3.org/XML/1998/namespace"
xmlnsPrefix = "xmlns"
xmlPrefix = "xml"
)
// Apply name space translation to name n.
// The default name space (for Space=="")
// applies only to element names, not to attribute names.
func (d *Decoder) translate(n *Name, isElementName bool) {
switch {
case n.Space == xmlnsPrefix:
return
case n.Space == "" && !isElementName:
return
case n.Space == xmlPrefix:
n.Space = xmlURL
case n.Space == "" && n.Local == xmlnsPrefix:
return
}
if v, ok := d.ns[n.Space]; ok {
n.Space = v
} else if n.Space == "" {
n.Space = d.DefaultSpace
}
}
func (d *Decoder) switchToReader(r io.Reader) {
// Get efficient byte at a time reader.
// Assume that if reader has its own
// ReadByte, it's efficient enough.
// Otherwise, use bufio.
if rb, ok := r.(io.ByteReader); ok {
d.r = rb
} else {
d.r = bufio.NewReader(r)
}
}
// Parsing state - stack holds old name space translations
// and the current set of open elements. The translations to pop when
// ending a given tag are *below* it on the stack, which is
// more work but forced on us by XML.
type stack struct {
next *stack
kind int
name Name
ok bool
}
const (
stkStart = iota
stkNs
stkEOF
)
func (d *Decoder) push(kind int) *stack {
s := d.free
if s != nil {
d.free = s.next
} else {
s = new(stack)
}
s.next = d.stk
s.kind = kind
d.stk = s
return s
}
func (d *Decoder) pop() *stack {
s := d.stk
if s != nil {
d.stk = s.next
s.next = d.free
d.free = s
}
return s
}
// Record that after the current element is finished
// (that element is already pushed on the stack)
// Token should return EOF until popEOF is called.
func (d *Decoder) pushEOF() {
// Walk down stack to find Start.
// It might not be the top, because there might be stkNs
// entries above it.
start := d.stk
for start.kind != stkStart {
start = start.next
}
// The stkNs entries below a start are associated with that
// element too; skip over them.
for start.next != nil && start.next.kind == stkNs {
start = start.next
}
s := d.free
if s != nil {
d.free = s.next
} else {
s = new(stack)
}
s.kind = stkEOF
s.next = start.next
start.next = s
}
// Undo a pushEOF.
// The element must have been finished, so the EOF should be at the top of the stack.
func (d *Decoder) popEOF() bool {
if d.stk == nil || d.stk.kind != stkEOF {
return false
}
d.pop()
return true
}
// Record that we are starting an element with the given name.
func (d *Decoder) pushElement(name Name) {
s := d.push(stkStart)
s.name = name
}
// Record that we are changing the value of ns[local].
// The old value is url, ok.
func (d *Decoder) pushNs(local string, url string, ok bool) {
s := d.push(stkNs)
s.name.Local = local
s.name.Space = url
s.ok = ok
}
// Creates a SyntaxError with the current line number.
func (d *Decoder) syntaxError(msg string) error {
return &SyntaxError{Msg: msg, Line: d.line}
}
// Record that we are ending an element with the given name.
// The name must match the record at the top of the stack,
// which must be a pushElement record.
// After popping the element, apply any undo records from
// the stack to restore the name translations that existed
// before we saw this element.
func (d *Decoder) popElement(t *EndElement) bool {
s := d.pop()
name := t.Name
switch {
case s == nil || s.kind != stkStart:
d.err = d.syntaxError("unexpected end element </" + name.Local + ">")
return false
case s.name.Local != name.Local:
if !d.Strict {
d.needClose = true
d.toClose = t.Name
t.Name = s.name
return true
}
d.err = d.syntaxError("element <" + s.name.Local + "> closed by </" + name.Local + ">")
return false
case s.name.Space != name.Space:
ns := name.Space
if name.Space == "" {
ns = `""`
}
d.err = d.syntaxError("element <" + s.name.Local + "> in space " + s.name.Space +
" closed by </" + name.Local + "> in space " + ns)
return false
}
d.translate(&t.Name, true)
// Pop stack until a Start or EOF is on the top, undoing the
// translations that were associated with the element we just closed.
for d.stk != nil && d.stk.kind != stkStart && d.stk.kind != stkEOF {
s := d.pop()
if s.ok {
d.ns[s.name.Local] = s.name.Space
} else {
delete(d.ns, s.name.Local)
}
}
return true
}
// If the top element on the stack is autoclosing and
// t is not the end tag, invent the end tag.
func (d *Decoder) autoClose(t Token) (Token, bool) {
if d.stk == nil || d.stk.kind != stkStart {
return nil, false
}
for _, s := range d.AutoClose {
if strings.EqualFold(s, d.stk.name.Local) {
// This one should be auto closed if t doesn't close it.
et, ok := t.(EndElement)
if !ok || !strings.EqualFold(et.Name.Local, d.stk.name.Local) {
return EndElement{d.stk.name}, true
}
break
}
}
return nil, false
}
var errRawToken = errors.New("xml: cannot use RawToken from UnmarshalXML method")
// RawToken is like [Decoder.Token] but does not verify that
// start and end elements match and does not translate
// name space prefixes to their corresponding URLs.
func (d *Decoder) RawToken() (Token, error) {
if d.unmarshalDepth > 0 {
return nil, errRawToken
}
return d.rawToken()
}
func (d *Decoder) rawToken() (Token, error) {
if d.t != nil {
return d.t.Token()
}
if d.err != nil {
return nil, d.err
}
if d.needClose {
// The last element we read was self-closing and
// we returned just the StartElement half.
// Return the EndElement half now.
d.needClose = false
return EndElement{d.toClose}, nil
}
b, ok := d.getc()
if !ok {
return nil, d.err
}
if b != '<' {
// Text section.
d.ungetc(b)
data := d.text(-1, false)
if data == nil {
return nil, d.err
}
return CharData(data), nil
}
if b, ok = d.mustgetc(); !ok {
return nil, d.err
}
switch b {
case '/':
// </: End element
var name Name
if name, ok = d.nsname(); !ok {
if d.err == nil {
d.err = d.syntaxError("expected element name after </")
}
return nil, d.err
}
d.space()
if b, ok = d.mustgetc(); !ok {
return nil, d.err
}
if b != '>' {
d.err = d.syntaxError("invalid characters between </" + name.Local + " and >")
return nil, d.err
}
return EndElement{name}, nil
case '?':
// <?: Processing instruction.
var target string
if target, ok = d.name(); !ok {
if d.err == nil {
d.err = d.syntaxError("expected target name after <?")
}
return nil, d.err
}
d.space()
d.buf.Reset()
var b0 byte
for {
if b, ok = d.mustgetc(); !ok {
return nil, d.err
}
d.buf.WriteByte(b)
if b0 == '?' && b == '>' {
break
}
b0 = b
}
data := d.buf.Bytes()
data = data[0 : len(data)-2] // chop ?>
if target == "xml" {
content := string(data)
ver := procInst("version", content)
if ver != "" && ver != "1.0" {
d.err = fmt.Errorf("xml: unsupported version %q; only version 1.0 is supported", ver)
return nil, d.err
}
enc := procInst("encoding", content)
if enc != "" && enc != "utf-8" && enc != "UTF-8" && !strings.EqualFold(enc, "utf-8") {
if d.CharsetReader == nil {
d.err = fmt.Errorf("xml: encoding %q declared but Decoder.CharsetReader is nil", enc)
return nil, d.err
}
newr, err := d.CharsetReader(enc, d.r.(io.Reader))
if err != nil {
d.err = fmt.Errorf("xml: opening charset %q: %w", enc, err)
return nil, d.err
}
if newr == nil {
panic("CharsetReader returned a nil Reader for charset " + enc)
}
d.switchToReader(newr)
}
}
return ProcInst{target, data}, nil
case '!':
// <!: Maybe comment, maybe CDATA.
if b, ok = d.mustgetc(); !ok {
return nil, d.err
}
switch b {
case '-': // <!-
// Probably <!-- for a comment.
if b, ok = d.mustgetc(); !ok {
return nil, d.err
}
if b != '-' {
d.err = d.syntaxError("invalid sequence <!- not part of <!--")
return nil, d.err
}
// Look for terminator.
d.buf.Reset()
var b0, b1 byte
for {
if b, ok = d.mustgetc(); !ok {
return nil, d.err
}
d.buf.WriteByte(b)
if b0 == '-' && b1 == '-' {
if b != '>' {
d.err = d.syntaxError(
`invalid sequence "--" not allowed in comments`)
return nil, d.err
}
break
}
b0, b1 = b1, b
}
data := d.buf.Bytes()
data = data[0 : len(data)-3] // chop -->
return Comment(data), nil
case '[': // <![
// Probably <![CDATA[.
for i := 0; i < 6; i++ {
if b, ok = d.mustgetc(); !ok {
return nil, d.err
}
if b != "CDATA["[i] {
d.err = d.syntaxError("invalid <![ sequence")
return nil, d.err
}
}
// Have <![CDATA[. Read text until ]]>.
data := d.text(-1, true)
if data == nil {
return nil, d.err
}
return CharData(data), nil
}
// Probably a directive: <!DOCTYPE ...>, <!ENTITY ...>, etc.
// We don't care, but accumulate for caller. Quoted angle
// brackets do not count for nesting.
d.buf.Reset()
d.buf.WriteByte(b)
inquote := uint8(0)
depth := 0
for {
if b, ok = d.mustgetc(); !ok {
return nil, d.err
}
if inquote == 0 && b == '>' && depth == 0 {
break
}
HandleB:
d.buf.WriteByte(b)
switch {
case b == inquote:
inquote = 0
case inquote != 0:
// in quotes, no special action
case b == '\'' || b == '"':
inquote = b
case b == '>' && inquote == 0:
depth--
case b == '<' && inquote == 0:
// Look for <!-- to begin comment.
s := "!--"
for i := 0; i < len(s); i++ {
if b, ok = d.mustgetc(); !ok {
return nil, d.err
}
if b != s[i] {
for j := 0; j < i; j++ {
d.buf.WriteByte(s[j])
}
depth++
goto HandleB
}
}
// Remove < that was written above.
d.buf.Truncate(d.buf.Len() - 1)
// Look for terminator.
var b0, b1 byte
for {
if b, ok = d.mustgetc(); !ok {
return nil, d.err
}
if b0 == '-' && b1 == '-' && b == '>' {
break
}
b0, b1 = b1, b
}
// Replace the comment with a space in the returned Directive
// body, so that markup parts that were separated by the comment
// (like a "<" and a "!") don't get joined when re-encoding the
// Directive, taking new semantic meaning.
d.buf.WriteByte(' ')
}
}
return Directive(d.buf.Bytes()), nil
}
// Must be an open element like <a href="foo">
d.ungetc(b)
var (
name Name
empty bool
attr []Attr
)
if name, ok = d.nsname(); !ok {
if d.err == nil {
d.err = d.syntaxError("expected element name after <")
}
return nil, d.err
}
attr = []Attr{}
for {
d.space()
if b, ok = d.mustgetc(); !ok {
return nil, d.err
}
if b == '/' {
empty = true
if b, ok = d.mustgetc(); !ok {
return nil, d.err
}
if b != '>' {
d.err = d.syntaxError("expected /> in element")
return nil, d.err
}
break
}
if b == '>' {
break
}
d.ungetc(b)
a := Attr{}
if a.Name, ok = d.nsname(); !ok {
if d.err == nil {
d.err = d.syntaxError("expected attribute name in element")
}
return nil, d.err
}
d.space()
if b, ok = d.mustgetc(); !ok {
return nil, d.err
}
if b != '=' {
if d.Strict {
d.err = d.syntaxError("attribute name without = in element")
return nil, d.err
}
d.ungetc(b)
a.Value = a.Name.Local
} else {
d.space()
data := d.attrval()
if data == nil {
return nil, d.err
}
a.Value = string(data)
}
attr = append(attr, a)
}
if empty {
d.needClose = true
d.toClose = name
}
return StartElement{name, attr}, nil
}
func (d *Decoder) attrval() []byte {
b, ok := d.mustgetc()
if !ok {
return nil
}
// Handle quoted attribute values
if b == '"' || b == '\'' {
return d.text(int(b), false)
}
// Handle unquoted attribute values for strict parsers
if d.Strict {
d.err = d.syntaxError("unquoted or missing attribute value in element")
return nil
}
// Handle unquoted attribute values for unstrict parsers
d.ungetc(b)
d.buf.Reset()
for {
b, ok = d.mustgetc()
if !ok {
return nil
}
// https://www.w3.org/TR/REC-html40/intro/sgmltut.html#h-3.2.2
if 'a' <= b && b <= 'z' || 'A' <= b && b <= 'Z' ||
'0' <= b && b <= '9' || b == '_' || b == ':' || b == '-' {
d.buf.WriteByte(b)
} else {
d.ungetc(b)
break
}
}
return d.buf.Bytes()
}
// Skip spaces if any
func (d *Decoder) space() {
for {
b, ok := d.getc()
if !ok {
return
}
switch b {
case ' ', '\r', '\n', '\t':
default:
d.ungetc(b)
return
}
}
}
// Read a single byte.
// If there is no byte to read, return ok==false
// and leave the error in d.err.
// Maintain line number.
func (d *Decoder) getc() (b byte, ok bool) {
if d.err != nil {
return 0, false
}
if d.nextByte >= 0 {
b = byte(d.nextByte)
d.nextByte = -1
} else {
b, d.err = d.r.ReadByte()
if d.err != nil {
return 0, false
}
if d.saved != nil {
d.saved.WriteByte(b)
}
}
if b == '\n' {
d.line++
d.linestart = d.offset + 1
}
d.offset++
return b, true
}
// InputOffset returns the input stream byte offset of the current decoder position.
// The offset gives the location of the end of the most recently returned token
// and the beginning of the next token.
func (d *Decoder) InputOffset() int64 {
return d.offset
}
// InputPos returns the line of the current decoder position and the 1-based
// column within that line. The position gives the location of the end of the
// most recently returned token.
func (d *Decoder) InputPos() (line, column int) {
return d.line, int(d.offset-d.linestart) + 1
}
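// Usage sketch (illustrative only): InputPos and InputOffset are convenient
// for reporting where in the input a problem was noticed. The dec and tok
// variables are hypothetical.
//
//	line, col := dec.InputPos()
//	return fmt.Errorf("unexpected token %#v at line %d, column %d (byte offset %d)",
//		tok, line, col, dec.InputOffset())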
// Return saved offset.
// If we did ungetc (nextByte >= 0), have to back up one.
func (d *Decoder) savedOffset() int {
n := d.saved.Len()
if d.nextByte >= 0 {
n--
}
return n
}
// Must read a single byte.
// If there is no byte to read,
// set d.err to SyntaxError("unexpected EOF")
// and return ok==false
func (d *Decoder) mustgetc() (b byte, ok bool) {
if b, ok = d.getc(); !ok {
if d.err == io.EOF {
d.err = d.syntaxError("unexpected EOF")
}
}
return
}
// Unread a single byte.
func (d *Decoder) ungetc(b byte) {
if b == '\n' {
d.line--
}
d.nextByte = int(b)
d.offset--
}
var entity = map[string]rune{
"lt": '<',
"gt": '>',
"amp": '&',
"apos": '\'',
"quot": '"',
}
// Read plain text section (XML calls it character data).
// If quote >= 0, we are in a quoted string and need to find the matching quote.
// If cdata == true, we are in a <![CDATA[ section and need to find ]]>.
// On failure return nil and leave the error in d.err.
func (d *Decoder) text(quote int, cdata bool) []byte {
var b0, b1 byte
var trunc int
d.buf.Reset()
Input:
for {
b, ok := d.getc()
if !ok {
if cdata {
if d.err == io.EOF {
d.err = d.syntaxError("unexpected EOF in CDATA section")
}
return nil
}
break Input
}
// <![CDATA[ section ends with ]]>.
// It is an error for ]]> to appear in ordinary text,
// but it is allowed in quoted strings.
if quote < 0 && b0 == ']' && b1 == ']' && b == '>' {
if cdata {
trunc = 2
break Input
}
d.err = d.syntaxError("unescaped ]]> not in CDATA section")
return nil
}
// Stop reading text if we see a <.
if b == '<' && !cdata {
if quote >= 0 {
d.err = d.syntaxError("unescaped < inside quoted string")
return nil
}
d.ungetc('<')
break Input
}
if quote >= 0 && b == byte(quote) {
break Input
}
if b == '&' && !cdata {
// Read escaped character expression up to semicolon.
// XML in all its glory allows a document to define and use
// its own character names with <!ENTITY ...> directives.
// Parsers are required to recognize lt, gt, amp, apos, and quot
// even if they have not been declared.
before := d.buf.Len()
d.buf.WriteByte('&')
var ok bool
var text string
var haveText bool
if b, ok = d.mustgetc(); !ok {
return nil
}
if b == '#' {
d.buf.WriteByte(b)
if b, ok = d.mustgetc(); !ok {
return nil
}
base := 10
if b == 'x' {
base = 16
d.buf.WriteByte(b)
if b, ok = d.mustgetc(); !ok {
return nil
}
}
start := d.buf.Len()
for '0' <= b && b <= '9' ||
base == 16 && 'a' <= b && b <= 'f' ||
base == 16 && 'A' <= b && b <= 'F' {
d.buf.WriteByte(b)
if b, ok = d.mustgetc(); !ok {
return nil
}
}
if b != ';' {
d.ungetc(b)
} else {
s := string(d.buf.Bytes()[start:])
d.buf.WriteByte(';')
n, err := strconv.ParseUint(s, base, 64)
if err == nil && n <= unicode.MaxRune {
text = string(rune(n))
haveText = true
}
}
} else {
d.ungetc(b)
if !d.readName() {
if d.err != nil {
return nil
}
}
if b, ok = d.mustgetc(); !ok {
return nil
}
if b != ';' {
d.ungetc(b)
} else {
name := d.buf.Bytes()[before+1:]
d.buf.WriteByte(';')
if isName(name) {
s := string(name)
if r, ok := entity[s]; ok {
text = string(r)
haveText = true
} else if d.Entity != nil {
text, haveText = d.Entity[s]
}
}
}
}
if haveText {
d.buf.Truncate(before)
d.buf.WriteString(text)
b0, b1 = 0, 0
continue Input
}
if !d.Strict {
b0, b1 = 0, 0
continue Input
}
ent := string(d.buf.Bytes()[before:])
if ent[len(ent)-1] != ';' {
ent += " (no semicolon)"
}
d.err = d.syntaxError("invalid character entity " + ent)
return nil
}
// We must rewrite unescaped \r and \r\n into \n.
if b == '\r' {
d.buf.WriteByte('\n')
} else if b1 == '\r' && b == '\n' {
// Skip \r\n--we already wrote \n.
} else {
d.buf.WriteByte(b)
}
b0, b1 = b1, b
}
data := d.buf.Bytes()
data = data[0 : len(data)-trunc]
// Inspect each rune for being a disallowed character.
buf := data
for len(buf) > 0 {
r, size := utf8.DecodeRune(buf)
if r == utf8.RuneError && size == 1 {
d.err = d.syntaxError("invalid UTF-8")
return nil
}
buf = buf[size:]
if !isInCharacterRange(r) {
d.err = d.syntaxError(fmt.Sprintf("illegal character code %U", r))
return nil
}
}
return data
}
// Decide whether the given rune is in the XML Character Range, per
// the Char production of https://www.xml.com/axml/testaxml.htm,
// Section 2.2 Characters.
func isInCharacterRange(r rune) (inrange bool) {
return r == 0x09 ||
r == 0x0A ||
r == 0x0D ||
r >= 0x20 && r <= 0xD7FF ||
r >= 0xE000 && r <= 0xFFFD ||
r >= 0x10000 && r <= 0x10FFFF
}
// Get name space name: name with a : stuck in the middle.
// The part before the : is the name space identifier.
func (d *Decoder) nsname() (name Name, ok bool) {
s, ok := d.name()
if !ok {
return
}
if strings.Count(s, ":") > 1 {
return name, false
} else if space, local, ok := strings.Cut(s, ":"); !ok || space == "" || local == "" {
name.Local = s
} else {
name.Space = space
name.Local = local
}
return name, true
}
// Get name: /first(first|second)*/
// Do not set d.err if the name is missing (unless unexpected EOF is received):
// let the caller provide better context.
func (d *Decoder) name() (s string, ok bool) {
d.buf.Reset()
if !d.readName() {
return "", false
}
// Now we check the characters.
b := d.buf.Bytes()
if !isName(b) {
d.err = d.syntaxError("invalid XML name: " + string(b))
return "", false
}
return string(b), true
}
// Read a name and append its bytes to d.buf.
// The name is delimited by any single-byte character not valid in names.
// All multi-byte characters are accepted; the caller must check their validity.
func (d *Decoder) readName() (ok bool) {
var b byte
if b, ok = d.mustgetc(); !ok {
return
}
if b < utf8.RuneSelf && !isNameByte(b) {
d.ungetc(b)
return false
}
d.buf.WriteByte(b)
for {
if b, ok = d.mustgetc(); !ok {
return
}
if b < utf8.RuneSelf && !isNameByte(b) {
d.ungetc(b)
break
}
d.buf.WriteByte(b)
}
return true
}
func isNameByte(c byte) bool {
return 'A' <= c && c <= 'Z' ||
'a' <= c && c <= 'z' ||
'0' <= c && c <= '9' ||
c == '_' || c == ':' || c == '.' || c == '-'
}
func isName(s []byte) bool {
if len(s) == 0 {
return false
}
c, n := utf8.DecodeRune(s)
if c == utf8.RuneError && n == 1 {
return false
}
if !unicode.Is(first, c) {
return false
}
for n < len(s) {
s = s[n:]
c, n = utf8.DecodeRune(s)
if c == utf8.RuneError && n == 1 {
return false
}
if !unicode.Is(first, c) && !unicode.Is(second, c) {
return false
}
}
return true
}
func isNameString(s string) bool {
if len(s) == 0 {
return false
}
c, n := utf8.DecodeRuneInString(s)
if c == utf8.RuneError && n == 1 {
return false
}
if !unicode.Is(first, c) {
return false
}
for n < len(s) {
s = s[n:]
c, n = utf8.DecodeRuneInString(s)
if c == utf8.RuneError && n == 1 {
return false
}
if !unicode.Is(first, c) && !unicode.Is(second, c) {
return false
}
}
return true
}
// These tables were generated by cut and paste from Appendix B of
// the XML spec at https://www.xml.com/axml/testaxml.htm
// and then reformatting. First corresponds to (Letter | '_' | ':')
// and second corresponds to NameChar.
var first = &unicode.RangeTable{
R16: []unicode.Range16{
{0x003A, 0x003A, 1},
{0x0041, 0x005A, 1},
{0x005F, 0x005F, 1},
{0x0061, 0x007A, 1},
{0x00C0, 0x00D6, 1},
{0x00D8, 0x00F6, 1},
{0x00F8, 0x00FF, 1},
{0x0100, 0x0131, 1},
{0x0134, 0x013E, 1},
{0x0141, 0x0148, 1},
{0x014A, 0x017E, 1},
{0x0180, 0x01C3, 1},
{0x01CD, 0x01F0, 1},
{0x01F4, 0x01F5, 1},
{0x01FA, 0x0217, 1},
{0x0250, 0x02A8, 1},
{0x02BB, 0x02C1, 1},
{0x0386, 0x0386, 1},
{0x0388, 0x038A, 1},
{0x038C, 0x038C, 1},
{0x038E, 0x03A1, 1},
{0x03A3, 0x03CE, 1},
{0x03D0, 0x03D6, 1},
{0x03DA, 0x03E0, 2},
{0x03E2, 0x03F3, 1},
{0x0401, 0x040C, 1},
{0x040E, 0x044F, 1},
{0x0451, 0x045C, 1},
{0x045E, 0x0481, 1},
{0x0490, 0x04C4, 1},
{0x04C7, 0x04C8, 1},
{0x04CB, 0x04CC, 1},
{0x04D0, 0x04EB, 1},
{0x04EE, 0x04F5, 1},
{0x04F8, 0x04F9, 1},
{0x0531, 0x0556, 1},
{0x0559, 0x0559, 1},
{0x0561, 0x0586, 1},
{0x05D0, 0x05EA, 1},
{0x05F0, 0x05F2, 1},
{0x0621, 0x063A, 1},
{0x0641, 0x064A, 1},
{0x0671, 0x06B7, 1},
{0x06BA, 0x06BE, 1},
{0x06C0, 0x06CE, 1},
{0x06D0, 0x06D3, 1},
{0x06D5, 0x06D5, 1},
{0x06E5, 0x06E6, 1},
{0x0905, 0x0939, 1},
{0x093D, 0x093D, 1},
{0x0958, 0x0961, 1},
{0x0985, 0x098C, 1},
{0x098F, 0x0990, 1},
{0x0993, 0x09A8, 1},
{0x09AA, 0x09B0, 1},
{0x09B2, 0x09B2, 1},
{0x09B6, 0x09B9, 1},
{0x09DC, 0x09DD, 1},
{0x09DF, 0x09E1, 1},
{0x09F0, 0x09F1, 1},
{0x0A05, 0x0A0A, 1},
{0x0A0F, 0x0A10, 1},
{0x0A13, 0x0A28, 1},
{0x0A2A, 0x0A30, 1},
{0x0A32, 0x0A33, 1},
{0x0A35, 0x0A36, 1},
{0x0A38, 0x0A39, 1},
{0x0A59, 0x0A5C, 1},
{0x0A5E, 0x0A5E, 1},
{0x0A72, 0x0A74, 1},
{0x0A85, 0x0A8B, 1},
{0x0A8D, 0x0A8D, 1},
{0x0A8F, 0x0A91, 1},
{0x0A93, 0x0AA8, 1},
{0x0AAA, 0x0AB0, 1},
{0x0AB2, 0x0AB3, 1},
{0x0AB5, 0x0AB9, 1},
{0x0ABD, 0x0AE0, 0x23},
{0x0B05, 0x0B0C, 1},
{0x0B0F, 0x0B10, 1},
{0x0B13, 0x0B28, 1},
{0x0B2A, 0x0B30, 1},
{0x0B32, 0x0B33, 1},
{0x0B36, 0x0B39, 1},
{0x0B3D, 0x0B3D, 1},
{0x0B5C, 0x0B5D, 1},
{0x0B5F, 0x0B61, 1},
{0x0B85, 0x0B8A, 1},
{0x0B8E, 0x0B90, 1},
{0x0B92, 0x0B95, 1},
{0x0B99, 0x0B9A, 1},
{0x0B9C, 0x0B9C, 1},
{0x0B9E, 0x0B9F, 1},
{0x0BA3, 0x0BA4, 1},
{0x0BA8, 0x0BAA, 1},
{0x0BAE, 0x0BB5, 1},
{0x0BB7, 0x0BB9, 1},
{0x0C05, 0x0C0C, 1},
{0x0C0E, 0x0C10, 1},
{0x0C12, 0x0C28, 1},
{0x0C2A, 0x0C33, 1},
{0x0C35, 0x0C39, 1},
{0x0C60, 0x0C61, 1},
{0x0C85, 0x0C8C, 1},
{0x0C8E, 0x0C90, 1},
{0x0C92, 0x0CA8, 1},
{0x0CAA, 0x0CB3, 1},
{0x0CB5, 0x0CB9, 1},
{0x0CDE, 0x0CDE, 1},
{0x0CE0, 0x0CE1, 1},
{0x0D05, 0x0D0C, 1},
{0x0D0E, 0x0D10, 1},
{0x0D12, 0x0D28, 1},
{0x0D2A, 0x0D39, 1},
{0x0D60, 0x0D61, 1},
{0x0E01, 0x0E2E, 1},
{0x0E30, 0x0E30, 1},
{0x0E32, 0x0E33, 1},
{0x0E40, 0x0E45, 1},
{0x0E81, 0x0E82, 1},
{0x0E84, 0x0E84, 1},
{0x0E87, 0x0E88, 1},
{0x0E8A, 0x0E8D, 3},
{0x0E94, 0x0E97, 1},
{0x0E99, 0x0E9F, 1},
{0x0EA1, 0x0EA3, 1},
{0x0EA5, 0x0EA7, 2},
{0x0EAA, 0x0EAB, 1},
{0x0EAD, 0x0EAE, 1},
{0x0EB0, 0x0EB0, 1},
{0x0EB2, 0x0EB3, 1},
{0x0EBD, 0x0EBD, 1},
{0x0EC0, 0x0EC4, 1},
{0x0F40, 0x0F47, 1},
{0x0F49, 0x0F69, 1},
{0x10A0, 0x10C5, 1},
{0x10D0, 0x10F6, 1},
{0x1100, 0x1100, 1},
{0x1102, 0x1103, 1},
{0x1105, 0x1107, 1},
{0x1109, 0x1109, 1},
{0x110B, 0x110C, 1},
{0x110E, 0x1112, 1},
{0x113C, 0x1140, 2},
{0x114C, 0x1150, 2},
{0x1154, 0x1155, 1},
{0x1159, 0x1159, 1},
{0x115F, 0x1161, 1},
{0x1163, 0x1169, 2},
{0x116D, 0x116E, 1},
{0x1172, 0x1173, 1},
{0x1175, 0x119E, 0x119E - 0x1175},
{0x11A8, 0x11AB, 0x11AB - 0x11A8},
{0x11AE, 0x11AF, 1},
{0x11B7, 0x11B8, 1},
{0x11BA, 0x11BA, 1},
{0x11BC, 0x11C2, 1},
{0x11EB, 0x11F0, 0x11F0 - 0x11EB},
{0x11F9, 0x11F9, 1},
{0x1E00, 0x1E9B, 1},
{0x1EA0, 0x1EF9, 1},
{0x1F00, 0x1F15, 1},
{0x1F18, 0x1F1D, 1},
{0x1F20, 0x1F45, 1},
{0x1F48, 0x1F4D, 1},
{0x1F50, 0x1F57, 1},
{0x1F59, 0x1F5B, 0x1F5B - 0x1F59},
{0x1F5D, 0x1F5D, 1},
{0x1F5F, 0x1F7D, 1},
{0x1F80, 0x1FB4, 1},
{0x1FB6, 0x1FBC, 1},
{0x1FBE, 0x1FBE, 1},
{0x1FC2, 0x1FC4, 1},
{0x1FC6, 0x1FCC, 1},
{0x1FD0, 0x1FD3, 1},
{0x1FD6, 0x1FDB, 1},
{0x1FE0, 0x1FEC, 1},
{0x1FF2, 0x1FF4, 1},
{0x1FF6, 0x1FFC, 1},
{0x2126, 0x2126, 1},
{0x212A, 0x212B, 1},
{0x212E, 0x212E, 1},
{0x2180, 0x2182, 1},
{0x3007, 0x3007, 1},
{0x3021, 0x3029, 1},
{0x3041, 0x3094, 1},
{0x30A1, 0x30FA, 1},
{0x3105, 0x312C, 1},
{0x4E00, 0x9FA5, 1},
{0xAC00, 0xD7A3, 1},
},
}
var second = &unicode.RangeTable{
R16: []unicode.Range16{
{0x002D, 0x002E, 1},
{0x0030, 0x0039, 1},
{0x00B7, 0x00B7, 1},
{0x02D0, 0x02D1, 1},
{0x0300, 0x0345, 1},
{0x0360, 0x0361, 1},
{0x0387, 0x0387, 1},
{0x0483, 0x0486, 1},
{0x0591, 0x05A1, 1},
{0x05A3, 0x05B9, 1},
{0x05BB, 0x05BD, 1},
{0x05BF, 0x05BF, 1},
{0x05C1, 0x05C2, 1},
{0x05C4, 0x0640, 0x0640 - 0x05C4},
{0x064B, 0x0652, 1},
{0x0660, 0x0669, 1},
{0x0670, 0x0670, 1},
{0x06D6, 0x06DC, 1},
{0x06DD, 0x06DF, 1},
{0x06E0, 0x06E4, 1},
{0x06E7, 0x06E8, 1},
{0x06EA, 0x06ED, 1},
{0x06F0, 0x06F9, 1},
{0x0901, 0x0903, 1},
{0x093C, 0x093C, 1},
{0x093E, 0x094C, 1},
{0x094D, 0x094D, 1},
{0x0951, 0x0954, 1},
{0x0962, 0x0963, 1},
{0x0966, 0x096F, 1},
{0x0981, 0x0983, 1},
{0x09BC, 0x09BC, 1},
{0x09BE, 0x09BF, 1},
{0x09C0, 0x09C4, 1},
{0x09C7, 0x09C8, 1},
{0x09CB, 0x09CD, 1},
{0x09D7, 0x09D7, 1},
{0x09E2, 0x09E3, 1},
{0x09E6, 0x09EF, 1},
{0x0A02, 0x0A3C, 0x3A},
{0x0A3E, 0x0A3F, 1},
{0x0A40, 0x0A42, 1},
{0x0A47, 0x0A48, 1},
{0x0A4B, 0x0A4D, 1},
{0x0A66, 0x0A6F, 1},
{0x0A70, 0x0A71, 1},
{0x0A81, 0x0A83, 1},
{0x0ABC, 0x0ABC, 1},
{0x0ABE, 0x0AC5, 1},
{0x0AC7, 0x0AC9, 1},
{0x0ACB, 0x0ACD, 1},
{0x0AE6, 0x0AEF, 1},
{0x0B01, 0x0B03, 1},
{0x0B3C, 0x0B3C, 1},
{0x0B3E, 0x0B43, 1},
{0x0B47, 0x0B48, 1},
{0x0B4B, 0x0B4D, 1},
{0x0B56, 0x0B57, 1},
{0x0B66, 0x0B6F, 1},
{0x0B82, 0x0B83, 1},
{0x0BBE, 0x0BC2, 1},
{0x0BC6, 0x0BC8, 1},
{0x0BCA, 0x0BCD, 1},
{0x0BD7, 0x0BD7, 1},
{0x0BE7, 0x0BEF, 1},
{0x0C01, 0x0C03, 1},
{0x0C3E, 0x0C44, 1},
{0x0C46, 0x0C48, 1},
{0x0C4A, 0x0C4D, 1},
{0x0C55, 0x0C56, 1},
{0x0C66, 0x0C6F, 1},
{0x0C82, 0x0C83, 1},
{0x0CBE, 0x0CC4, 1},
{0x0CC6, 0x0CC8, 1},
{0x0CCA, 0x0CCD, 1},
{0x0CD5, 0x0CD6, 1},
{0x0CE6, 0x0CEF, 1},
{0x0D02, 0x0D03, 1},
{0x0D3E, 0x0D43, 1},
{0x0D46, 0x0D48, 1},
{0x0D4A, 0x0D4D, 1},
{0x0D57, 0x0D57, 1},
{0x0D66, 0x0D6F, 1},
{0x0E31, 0x0E31, 1},
{0x0E34, 0x0E3A, 1},
{0x0E46, 0x0E46, 1},
{0x0E47, 0x0E4E, 1},
{0x0E50, 0x0E59, 1},
{0x0EB1, 0x0EB1, 1},
{0x0EB4, 0x0EB9, 1},
{0x0EBB, 0x0EBC, 1},
{0x0EC6, 0x0EC6, 1},
{0x0EC8, 0x0ECD, 1},
{0x0ED0, 0x0ED9, 1},
{0x0F18, 0x0F19, 1},
{0x0F20, 0x0F29, 1},
{0x0F35, 0x0F39, 2},
{0x0F3E, 0x0F3F, 1},
{0x0F71, 0x0F84, 1},
{0x0F86, 0x0F8B, 1},
{0x0F90, 0x0F95, 1},
{0x0F97, 0x0F97, 1},
{0x0F99, 0x0FAD, 1},
{0x0FB1, 0x0FB7, 1},
{0x0FB9, 0x0FB9, 1},
{0x20D0, 0x20DC, 1},
{0x20E1, 0x3005, 0x3005 - 0x20E1},
{0x302A, 0x302F, 1},
{0x3031, 0x3035, 1},
{0x3099, 0x309A, 1},
{0x309D, 0x309E, 1},
{0x30FC, 0x30FE, 1},
},
}
// HTMLEntity is an entity map containing translations for the
// standard HTML entity characters.
//
// See the [Decoder.Strict] and [Decoder.Entity] fields' documentation.
var HTMLEntity map[string]string = htmlEntity
var htmlEntity = map[string]string{
/*
hget http://www.w3.org/TR/html4/sgml/entities.html |
ssam '
,y /\>/ x/\<(.|\n)+/ s/\n/ /g
,x v/^\<!ENTITY/d
,s/\<!ENTITY ([^ ]+) .*U\+([0-9A-F][0-9A-F][0-9A-F][0-9A-F]) .+/ "\1": "\\u\2",/g
'
*/
"nbsp": "\u00A0",
"iexcl": "\u00A1",
"cent": "\u00A2",
"pound": "\u00A3",
"curren": "\u00A4",
"yen": "\u00A5",
"brvbar": "\u00A6",
"sect": "\u00A7",
"uml": "\u00A8",
"copy": "\u00A9",
"ordf": "\u00AA",
"laquo": "\u00AB",
"not": "\u00AC",
"shy": "\u00AD",
"reg": "\u00AE",
"macr": "\u00AF",
"deg": "\u00B0",
"plusmn": "\u00B1",
"sup2": "\u00B2",
"sup3": "\u00B3",
"acute": "\u00B4",
"micro": "\u00B5",
"para": "\u00B6",
"middot": "\u00B7",
"cedil": "\u00B8",
"sup1": "\u00B9",
"ordm": "\u00BA",
"raquo": "\u00BB",
"frac14": "\u00BC",
"frac12": "\u00BD",
"frac34": "\u00BE",
"iquest": "\u00BF",
"Agrave": "\u00C0",
"Aacute": "\u00C1",
"Acirc": "\u00C2",
"Atilde": "\u00C3",
"Auml": "\u00C4",
"Aring": "\u00C5",
"AElig": "\u00C6",
"Ccedil": "\u00C7",
"Egrave": "\u00C8",
"Eacute": "\u00C9",
"Ecirc": "\u00CA",
"Euml": "\u00CB",
"Igrave": "\u00CC",
"Iacute": "\u00CD",
"Icirc": "\u00CE",
"Iuml": "\u00CF",
"ETH": "\u00D0",
"Ntilde": "\u00D1",
"Ograve": "\u00D2",
"Oacute": "\u00D3",
"Ocirc": "\u00D4",
"Otilde": "\u00D5",
"Ouml": "\u00D6",
"times": "\u00D7",
"Oslash": "\u00D8",
"Ugrave": "\u00D9",
"Uacute": "\u00DA",
"Ucirc": "\u00DB",
"Uuml": "\u00DC",
"Yacute": "\u00DD",
"THORN": "\u00DE",
"szlig": "\u00DF",
"agrave": "\u00E0",
"aacute": "\u00E1",
"acirc": "\u00E2",
"atilde": "\u00E3",
"auml": "\u00E4",
"aring": "\u00E5",
"aelig": "\u00E6",
"ccedil": "\u00E7",
"egrave": "\u00E8",
"eacute": "\u00E9",
"ecirc": "\u00EA",
"euml": "\u00EB",
"igrave": "\u00EC",
"iacute": "\u00ED",
"icirc": "\u00EE",
"iuml": "\u00EF",
"eth": "\u00F0",
"ntilde": "\u00F1",
"ograve": "\u00F2",
"oacute": "\u00F3",
"ocirc": "\u00F4",
"otilde": "\u00F5",
"ouml": "\u00F6",
"divide": "\u00F7",
"oslash": "\u00F8",
"ugrave": "\u00F9",
"uacute": "\u00FA",
"ucirc": "\u00FB",
"uuml": "\u00FC",
"yacute": "\u00FD",
"thorn": "\u00FE",
"yuml": "\u00FF",
"fnof": "\u0192",
"Alpha": "\u0391",
"Beta": "\u0392",
"Gamma": "\u0393",
"Delta": "\u0394",
"Epsilon": "\u0395",
"Zeta": "\u0396",
"Eta": "\u0397",
"Theta": "\u0398",
"Iota": "\u0399",
"Kappa": "\u039A",
"Lambda": "\u039B",
"Mu": "\u039C",
"Nu": "\u039D",
"Xi": "\u039E",
"Omicron": "\u039F",
"Pi": "\u03A0",
"Rho": "\u03A1",
"Sigma": "\u03A3",
"Tau": "\u03A4",
"Upsilon": "\u03A5",
"Phi": "\u03A6",
"Chi": "\u03A7",
"Psi": "\u03A8",
"Omega": "\u03A9",
"alpha": "\u03B1",
"beta": "\u03B2",
"gamma": "\u03B3",
"delta": "\u03B4",
"epsilon": "\u03B5",
"zeta": "\u03B6",
"eta": "\u03B7",
"theta": "\u03B8",
"iota": "\u03B9",
"kappa": "\u03BA",
"lambda": "\u03BB",
"mu": "\u03BC",
"nu": "\u03BD",
"xi": "\u03BE",
"omicron": "\u03BF",
"pi": "\u03C0",
"rho": "\u03C1",
"sigmaf": "\u03C2",
"sigma": "\u03C3",
"tau": "\u03C4",
"upsilon": "\u03C5",
"phi": "\u03C6",
"chi": "\u03C7",
"psi": "\u03C8",
"omega": "\u03C9",
"thetasym": "\u03D1",
"upsih": "\u03D2",
"piv": "\u03D6",
"bull": "\u2022",
"hellip": "\u2026",
"prime": "\u2032",
"Prime": "\u2033",
"oline": "\u203E",
"frasl": "\u2044",
"weierp": "\u2118",
"image": "\u2111",
"real": "\u211C",
"trade": "\u2122",
"alefsym": "\u2135",
"larr": "\u2190",
"uarr": "\u2191",
"rarr": "\u2192",
"darr": "\u2193",
"harr": "\u2194",
"crarr": "\u21B5",
"lArr": "\u21D0",
"uArr": "\u21D1",
"rArr": "\u21D2",
"dArr": "\u21D3",
"hArr": "\u21D4",
"forall": "\u2200",
"part": "\u2202",
"exist": "\u2203",
"empty": "\u2205",
"nabla": "\u2207",
"isin": "\u2208",
"notin": "\u2209",
"ni": "\u220B",
"prod": "\u220F",
"sum": "\u2211",
"minus": "\u2212",
"lowast": "\u2217",
"radic": "\u221A",
"prop": "\u221D",
"infin": "\u221E",
"ang": "\u2220",
"and": "\u2227",
"or": "\u2228",
"cap": "\u2229",
"cup": "\u222A",
"int": "\u222B",
"there4": "\u2234",
"sim": "\u223C",
"cong": "\u2245",
"asymp": "\u2248",
"ne": "\u2260",
"equiv": "\u2261",
"le": "\u2264",
"ge": "\u2265",
"sub": "\u2282",
"sup": "\u2283",
"nsub": "\u2284",
"sube": "\u2286",
"supe": "\u2287",
"oplus": "\u2295",
"otimes": "\u2297",
"perp": "\u22A5",
"sdot": "\u22C5",
"lceil": "\u2308",
"rceil": "\u2309",
"lfloor": "\u230A",
"rfloor": "\u230B",
"lang": "\u2329",
"rang": "\u232A",
"loz": "\u25CA",
"spades": "\u2660",
"clubs": "\u2663",
"hearts": "\u2665",
"diams": "\u2666",
"quot": "\u0022",
"amp": "\u0026",
"lt": "\u003C",
"gt": "\u003E",
"OElig": "\u0152",
"oelig": "\u0153",
"Scaron": "\u0160",
"scaron": "\u0161",
"Yuml": "\u0178",
"circ": "\u02C6",
"tilde": "\u02DC",
"ensp": "\u2002",
"emsp": "\u2003",
"thinsp": "\u2009",
"zwnj": "\u200C",
"zwj": "\u200D",
"lrm": "\u200E",
"rlm": "\u200F",
"ndash": "\u2013",
"mdash": "\u2014",
"lsquo": "\u2018",
"rsquo": "\u2019",
"sbquo": "\u201A",
"ldquo": "\u201C",
"rdquo": "\u201D",
"bdquo": "\u201E",
"dagger": "\u2020",
"Dagger": "\u2021",
"permil": "\u2030",
"lsaquo": "\u2039",
"rsaquo": "\u203A",
"euro": "\u20AC",
}
// HTMLAutoClose is the set of HTML elements that
// should be considered to close automatically.
//
// See the [Decoder.Strict] and [Decoder.Entity] fields' documentation.
var HTMLAutoClose []string = htmlAutoClose
var htmlAutoClose = []string{
/*
hget http://www.w3.org/TR/html4/loose.dtd |
9 sed -n 's/<!ELEMENT ([^ ]*) +- O EMPTY.+/ "\1",/p' | tr A-Z a-z
*/
"basefont",
"br",
"area",
"link",
"img",
"param",
"hr",
"input",
"col",
"frame",
"isindex",
"base",
"meta",
}
var (
escQuot = []byte("&#34;") // shorter than "&quot;"
escApos = []byte("&#39;") // shorter than "&apos;"
escAmp = []byte("&amp;")
escLT = []byte("&lt;")
escGT = []byte("&gt;")
escTab = []byte("&#x9;")
escNL = []byte("&#xA;")
escCR = []byte("&#xD;")
escFFFD = []byte("\uFFFD") // Unicode replacement character
)
// EscapeText writes to w the properly escaped XML equivalent
// of the plain text data s.
func EscapeText(w io.Writer, s []byte) error {
return escapeText(w, s, true)
}
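// For example, a minimal runnable sketch of escaping character data with
// EscapeText (the input text is hypothetical; the output follows the escape
// table above):
//
//	package main
//
//	import (
//		"encoding/xml"
//		"os"
//	)
//
//	func main() {
//		// Prints: a &lt;b&gt; &amp; &#34;c&#34;
//		xml.EscapeText(os.Stdout, []byte(`a <b> & "c"`))
//	}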
// escapeText writes to w the properly escaped XML equivalent
// of the plain text data s. If escapeNewline is true, newline
// characters will be escaped.
func escapeText(w io.Writer, s []byte, escapeNewline bool) error {
var esc []byte
last := 0
for i := 0; i < len(s); {
r, width := utf8.DecodeRune(s[i:])
i += width
switch r {
case '"':
esc = escQuot
case '\'':
esc = escApos
case '&':
esc = escAmp
case '<':
esc = escLT
case '>':
esc = escGT
case '\t':
esc = escTab
case '\n':
if !escapeNewline {
continue
}
esc = escNL
case '\r':
esc = escCR
default:
if !isInCharacterRange(r) || (r == 0xFFFD && width == 1) {
esc = escFFFD
break
}
continue
}
if _, err := w.Write(s[last : i-width]); err != nil {
return err
}
if _, err := w.Write(esc); err != nil {
return err
}
last = i
}
_, err := w.Write(s[last:])
return err
}
// EscapeString writes to p the properly escaped XML equivalent
// of the plain text data s.
func (p *printer) EscapeString(s string) {
var esc []byte
last := 0
for i := 0; i < len(s); {
r, width := utf8.DecodeRuneInString(s[i:])
i += width
switch r {
case '"':
esc = escQuot
case '\'':
esc = escApos
case '&':
esc = escAmp
case '<':
esc = escLT
case '>':
esc = escGT
case '\t':
esc = escTab
case '\n':
esc = escNL
case '\r':
esc = escCR
default:
if !isInCharacterRange(r) || (r == 0xFFFD && width == 1) {
esc = escFFFD
break
}
continue
}
p.WriteString(s[last : i-width])
p.Write(esc)
last = i
}
p.WriteString(s[last:])
}
// Escape is like [EscapeText] but omits the error return value.
// It is provided for backwards compatibility with Go 1.0.
// Code targeting Go 1.1 or later should use [EscapeText].
func Escape(w io.Writer, s []byte) {
EscapeText(w, s)
}
var (
cdataStart = []byte("<![CDATA[")
cdataEnd = []byte("]]>")
cdataEscape = []byte("]]]]><![CDATA[>")
)
// emitCDATA writes to w the CDATA-wrapped plain text data s.
// It escapes CDATA directives nested in s.
func emitCDATA(w io.Writer, s []byte) error {
if len(s) == 0 {
return nil
}
if _, err := w.Write(cdataStart); err != nil {
return err
}
for {
before, after, ok := bytes.Cut(s, cdataEnd)
if !ok {
break
}
// Found a nested CDATA directive end.
if _, err := w.Write(before); err != nil {
return err
}
if _, err := w.Write(cdataEscape); err != nil {
return err
}
s = after
}
if _, err := w.Write(s); err != nil {
return err
}
_, err := w.Write(cdataEnd)
return err
}
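// A minimal sketch of what emitCDATA writes, usable from inside this package
// (the input is hypothetical):
//
//	var buf bytes.Buffer
//	_ = emitCDATA(&buf, []byte("a]]>b"))
//	// buf.String() == "<![CDATA[a]]]]><![CDATA[>b]]>"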
// procInst parses the `param="..."` or `param='...'`
// value out of the provided string, returning "" if not found.
func procInst(param, s string) string {
// TODO: this parsing is somewhat lame and not exact.
// It works for all actual cases, though.
param = param + "="
lenp := len(param)
i := 0
var sep byte
for i < len(s) {
sub := s[i:]
k := strings.Index(sub, param)
if k < 0 || lenp+k >= len(sub) {
return ""
}
i += lenp + k + 1
if c := sub[lenp+k]; c == '\'' || c == '"' {
sep = c
break
}
}
if sep == 0 {
return ""
}
j := strings.IndexByte(s[i:], sep)
if j < 0 {
return ""
}
return s[i : i+j]
}
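// For example, a sketch of pulling attributes out of an XML declaration
// (the inputs are hypothetical):
//
//	procInst("encoding", `version="1.0" encoding="UTF-8"`) // "UTF-8"
//	procInst("standalone", `version="1.0"`)                // "" (not found)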
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package errors implements functions to manipulate errors.
//
// The [New] function creates errors whose only content is a text message.
//
// An error e wraps another error if e's type has one of the methods
//
// Unwrap() error
// Unwrap() []error
//
// If e.Unwrap() returns a non-nil error w or a slice containing w,
// then we say that e wraps w. A nil error returned from e.Unwrap()
// indicates that e does not wrap any error. It is invalid for an
// Unwrap method to return an []error containing a nil error value.
//
// An easy way to create wrapped errors is to call [fmt.Errorf] and apply
// the %w verb to the error argument:
//
// wrapsErr := fmt.Errorf("... %w ...", ..., err, ...)
//
// Successive unwrapping of an error creates a tree. The [Is] and [As]
// functions inspect an error's tree by examining first the error
// itself followed by the tree of each of its children in turn
// (pre-order, depth-first traversal).
//
// See https://go.dev/blog/go1.13-errors for a deeper discussion of the
// philosophy of wrapping and when to wrap.
//
// [Is] examines the tree of its first argument looking for an error that
// matches the second. It reports whether it finds a match. It should be
// used in preference to simple equality checks:
//
// if errors.Is(err, fs.ErrExist)
//
// is preferable to
//
// if err == fs.ErrExist
//
// because the former will succeed if err wraps [io/fs.ErrExist].
//
// [As] examines the tree of its first argument looking for an error that can be
// assigned to its second argument, which must be a pointer. If it succeeds, it
// performs the assignment and returns true. Otherwise, it returns false. The form
//
// var perr *fs.PathError
// if errors.As(err, &perr) {
// fmt.Println(perr.Path)
// }
//
// is preferable to
//
// if perr, ok := err.(*fs.PathError); ok {
// fmt.Println(perr.Path)
// }
//
// because the former will succeed if err wraps an [*io/fs.PathError].
package errors
// New returns an error that formats as the given text.
// Each call to New returns a distinct error value even if the text is identical.
func New(text string) error {
return &errorString{text}
}
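// For example, two calls with identical text still compare unequal:
//
//	errors.New("boom") == errors.New("boom") // false: distinct values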
// errorString is a trivial implementation of error.
type errorString struct {
s string
}
func (e *errorString) Error() string {
return e.s
}
// ErrUnsupported indicates that a requested operation cannot be performed,
// because it is unsupported. For example, a call to [os.Link] when using a
// file system that does not support hard links.
//
// Functions and methods should not return this error but should instead
// return an error including appropriate context that satisfies
//
// errors.Is(err, errors.ErrUnsupported)
//
// either by directly wrapping ErrUnsupported or by implementing an [Is] method.
//
// Functions and methods should document the cases in which an error
// wrapping this will be returned.
var ErrUnsupported = New("unsupported operation")
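// For example, a sketch of the intended check (the wrapping message is
// hypothetical):
//
//	err := fmt.Errorf("hard link not supported: %w", errors.ErrUnsupported)
//	if errors.Is(err, errors.ErrUnsupported) {
//		// fall back to copying the file instead of linking
//	}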
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package errors
import (
"unsafe"
)
// Join returns an error that wraps the given errors.
// Any nil error values are discarded.
// Join returns nil if every value in errs is nil.
// The error formats as the concatenation of the strings obtained
// by calling the Error method of each element of errs, with a newline
// between each string.
//
// A non-nil error returned by Join implements the Unwrap() []error method.
func Join(errs ...error) error {
n := 0
for _, err := range errs {
if err != nil {
n++
}
}
if n == 0 {
return nil
}
if n == 1 {
for _, err := range errs {
if _, ok := err.(interface {
Unwrap() []error
}); ok {
return err
}
}
}
e := &joinError{
errs: make([]error, 0, n),
}
for _, err := range errs {
if err != nil {
e.errs = append(e.errs, err)
}
}
return e
}
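// For example, a sketch of joining and re-splitting errors (the error values
// are hypothetical):
//
//	err := errors.Join(errors.New("read failed"), nil, errors.New("close failed"))
//	fmt.Println(err) // prints "read failed" and "close failed" on separate lines
//	len(err.(interface{ Unwrap() []error }).Unwrap()) // 2: the nil was discarded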
type joinError struct {
errs []error
}
func (e *joinError) Error() string {
// Since Join returns nil if every value in errs is nil,
// e.errs cannot be empty.
if len(e.errs) == 1 {
return e.errs[0].Error()
}
b := []byte(e.errs[0].Error())
for _, err := range e.errs[1:] {
b = append(b, '\n')
b = append(b, err.Error()...)
}
// At this point, b has at least one byte '\n'.
return unsafe.String(&b[0], len(b))
}
func (e *joinError) Unwrap() []error {
return e.errs
}
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package errors
import (
"internal/reflectlite"
)
// Unwrap returns the result of calling the Unwrap method on err, if err's
// type contains an Unwrap method returning error.
// Otherwise, Unwrap returns nil.
//
// Unwrap only calls a method of the form "Unwrap() error".
// In particular Unwrap does not unwrap errors returned by [Join].
func Unwrap(err error) error {
u, ok := err.(interface {
Unwrap() error
})
if !ok {
return nil
}
return u.Unwrap()
}
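// For example (the wrapped error is hypothetical):
//
//	wrapped := fmt.Errorf("opening config: %w", fs.ErrNotExist)
//	errors.Unwrap(wrapped) == fs.ErrNotExist // true
//	errors.Unwrap(fs.ErrNotExist)            // nil: nothing left to unwrap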
// Is reports whether any error in err's tree matches target.
//
// The tree consists of err itself, followed by the errors obtained by repeatedly
// calling its Unwrap() error or Unwrap() []error method. When err wraps multiple
// errors, Is examines err followed by a depth-first traversal of its children.
//
// An error is considered to match a target if it is equal to that target or if
// it implements a method Is(error) bool such that Is(target) returns true.
//
// An error type might provide an Is method so it can be treated as equivalent
// to an existing error. For example, if MyError defines
//
// func (m MyError) Is(target error) bool { return target == fs.ErrExist }
//
// then Is(MyError{}, fs.ErrExist) returns true. See [syscall.Errno.Is] for
// an example in the standard library. An Is method should only shallowly
// compare err and the target and not call [Unwrap] on either.
func Is(err, target error) bool {
if err == nil || target == nil {
return err == target
}
isComparable := reflectlite.TypeOf(target).Comparable()
return is(err, target, isComparable)
}
func is(err, target error, targetComparable bool) bool {
for {
if targetComparable && err == target {
return true
}
if x, ok := err.(interface{ Is(error) bool }); ok && x.Is(target) {
return true
}
switch x := err.(type) {
case interface{ Unwrap() error }:
err = x.Unwrap()
if err == nil {
return false
}
case interface{ Unwrap() []error }:
for _, err := range x.Unwrap() {
if is(err, target, targetComparable) {
return true
}
}
return false
default:
return false
}
}
}
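// For example, a sketch of the preferred check described above (the file name
// is hypothetical):
//
//	_, err := os.Open("no-such-file")
//	errors.Is(err, fs.ErrNotExist) // true: the *fs.PathError wraps fs.ErrNotExist
//	err == fs.ErrNotExist          // false: plain equality misses the wrapping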
// As finds the first error in err's tree that matches target, and if one is found, sets
// target to that error value and returns true. Otherwise, it returns false.
//
// The tree consists of err itself, followed by the errors obtained by repeatedly
// calling its Unwrap() error or Unwrap() []error method. When err wraps multiple
// errors, As examines err followed by a depth-first traversal of its children.
//
// An error matches target if the error's concrete value is assignable to the value
// pointed to by target, or if the error has a method As(any) bool such that
// As(target) returns true. In the latter case, the As method is responsible for
// setting target.
//
// An error type might provide an As method so it can be treated as if it were a
// different error type.
//
// As panics if target is not a non-nil pointer to either a type that implements
// error, or to any interface type.
func As(err error, target any) bool {
if err == nil {
return false
}
if target == nil {
panic("errors: target cannot be nil")
}
val := reflectlite.ValueOf(target)
typ := val.Type()
if typ.Kind() != reflectlite.Ptr || val.IsNil() {
panic("errors: target must be a non-nil pointer")
}
targetType := typ.Elem()
if targetType.Kind() != reflectlite.Interface && !targetType.Implements(errorType) {
panic("errors: *target must be interface or implement error")
}
return as(err, target, val, targetType)
}
func as(err error, target any, targetVal reflectlite.Value, targetType reflectlite.Type) bool {
for {
if reflectlite.TypeOf(err).AssignableTo(targetType) {
targetVal.Elem().Set(reflectlite.ValueOf(err))
return true
}
if x, ok := err.(interface{ As(any) bool }); ok && x.As(target) {
return true
}
switch x := err.(type) {
case interface{ Unwrap() error }:
err = x.Unwrap()
if err == nil {
return false
}
case interface{ Unwrap() []error }:
for _, err := range x.Unwrap() {
if err == nil {
continue
}
if as(err, target, targetVal, targetType) {
return true
}
}
return false
default:
return false
}
}
}
var errorType = reflectlite.TypeOf((*error)(nil)).Elem()
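// For example, a sketch mirroring the package documentation (the file name is
// hypothetical):
//
//	_, err := os.Open("no-such-file")
//	var perr *fs.PathError
//	if errors.As(err, &perr) {
//		fmt.Println(perr.Path) // "no-such-file"
//	}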
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package expvar provides a standardized interface to public variables, such
// as operation counters in servers. It exposes these variables via HTTP at
// /debug/vars in JSON format. As of Go 1.22, the /debug/vars request must
// use GET.
//
// Operations to set or modify these public variables are atomic.
//
// In addition to adding the HTTP handler, this package registers the
// following variables:
//
// cmdline os.Args
// memstats runtime.MemStats
//
// The package is sometimes only imported for the side effect of
// registering its HTTP handler and the above variables. To use it
// this way, link this package into your program:
//
// import _ "expvar"
package expvar
import (
"encoding/json"
"internal/godebug"
"log"
"math"
"net/http"
"os"
"runtime"
"slices"
"strconv"
"sync"
"sync/atomic"
"unicode/utf8"
)
// Var is an abstract type for all exported variables.
type Var interface {
// String returns a valid JSON value for the variable.
// Types with String methods that do not return valid JSON
// (such as time.Time) must not be used as a Var.
String() string
}
type jsonVar interface {
// appendJSON appends the JSON representation of the receiver to b.
appendJSON(b []byte) []byte
}
// Int is a 64-bit integer variable that satisfies the [Var] interface.
type Int struct {
i atomic.Int64
}
func (v *Int) Value() int64 {
return v.i.Load()
}
func (v *Int) String() string {
return string(v.appendJSON(nil))
}
func (v *Int) appendJSON(b []byte) []byte {
return strconv.AppendInt(b, v.i.Load(), 10)
}
func (v *Int) Add(delta int64) {
v.i.Add(delta)
}
func (v *Int) Set(value int64) {
v.i.Store(value)
}
// Float is a 64-bit float variable that satisfies the [Var] interface.
type Float struct {
f atomic.Uint64
}
func (v *Float) Value() float64 {
return math.Float64frombits(v.f.Load())
}
func (v *Float) String() string {
return string(v.appendJSON(nil))
}
func (v *Float) appendJSON(b []byte) []byte {
return strconv.AppendFloat(b, math.Float64frombits(v.f.Load()), 'g', -1, 64)
}
// Add adds delta to v.
func (v *Float) Add(delta float64) {
for {
cur := v.f.Load()
curVal := math.Float64frombits(cur)
nxtVal := curVal + delta
nxt := math.Float64bits(nxtVal)
if v.f.CompareAndSwap(cur, nxt) {
return
}
}
}
// Set sets v to value.
func (v *Float) Set(value float64) {
v.f.Store(math.Float64bits(value))
}
// Map is a string-to-Var map variable that satisfies the [Var] interface.
type Map struct {
m sync.Map // map[string]Var
keysMu sync.RWMutex
keys []string // sorted
}
// KeyValue represents a single entry in a [Map].
type KeyValue struct {
Key string
Value Var
}
func (v *Map) String() string {
return string(v.appendJSON(nil))
}
func (v *Map) appendJSON(b []byte) []byte {
return v.appendJSONMayExpand(b, false)
}
func (v *Map) appendJSONMayExpand(b []byte, expand bool) []byte {
afterCommaDelim := byte(' ')
mayAppendNewline := func(b []byte) []byte { return b }
if expand {
afterCommaDelim = '\n'
mayAppendNewline = func(b []byte) []byte { return append(b, '\n') }
}
b = append(b, '{')
b = mayAppendNewline(b)
first := true
v.Do(func(kv KeyValue) {
if !first {
b = append(b, ',', afterCommaDelim)
}
first = false
b = appendJSONQuote(b, kv.Key)
b = append(b, ':', ' ')
switch v := kv.Value.(type) {
case nil:
b = append(b, "null"...)
case jsonVar:
b = v.appendJSON(b)
default:
b = append(b, v.String()...)
}
})
b = mayAppendNewline(b)
b = append(b, '}')
b = mayAppendNewline(b)
return b
}
// Init removes all keys from the map.
func (v *Map) Init() *Map {
v.keysMu.Lock()
defer v.keysMu.Unlock()
v.keys = v.keys[:0]
v.m.Clear()
return v
}
// addKey updates the sorted list of keys in v.keys.
func (v *Map) addKey(key string) {
v.keysMu.Lock()
defer v.keysMu.Unlock()
// Using insertion sort to place key into the already-sorted v.keys.
i, found := slices.BinarySearch(v.keys, key)
if found {
return
}
v.keys = slices.Insert(v.keys, i, key)
}
func (v *Map) Get(key string) Var {
i, _ := v.m.Load(key)
av, _ := i.(Var)
return av
}
func (v *Map) Set(key string, av Var) {
// Before we store the value, check to see whether the key is new. Try a Load
// before LoadOrStore: LoadOrStore causes the key interface to escape even on
// the Load path.
if _, ok := v.m.Load(key); !ok {
if _, dup := v.m.LoadOrStore(key, av); !dup {
v.addKey(key)
return
}
}
v.m.Store(key, av)
}
// Add adds delta to the *[Int] value stored under the given map key.
func (v *Map) Add(key string, delta int64) {
i, ok := v.m.Load(key)
if !ok {
var dup bool
i, dup = v.m.LoadOrStore(key, new(Int))
if !dup {
v.addKey(key)
}
}
// Add to Int; ignore otherwise.
if iv, ok := i.(*Int); ok {
iv.Add(delta)
}
}
// AddFloat adds delta to the *[Float] value stored under the given map key.
func (v *Map) AddFloat(key string, delta float64) {
i, ok := v.m.Load(key)
if !ok {
var dup bool
i, dup = v.m.LoadOrStore(key, new(Float))
if !dup {
v.addKey(key)
}
}
// Add to Float; ignore otherwise.
if iv, ok := i.(*Float); ok {
iv.Add(delta)
}
}
// Delete deletes the given key from the map.
func (v *Map) Delete(key string) {
v.keysMu.Lock()
defer v.keysMu.Unlock()
i, found := slices.BinarySearch(v.keys, key)
if found {
v.keys = slices.Delete(v.keys, i, i+1)
v.m.Delete(key)
}
}
// Do calls f for each entry in the map.
// The map is locked during the iteration,
// but existing entries may be concurrently updated.
func (v *Map) Do(f func(KeyValue)) {
v.keysMu.RLock()
defer v.keysMu.RUnlock()
for _, k := range v.keys {
i, _ := v.m.Load(k)
val, _ := i.(Var)
f(KeyValue{k, val})
}
}
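// For example, a sketch of a per-key counter built on Map (the variable name
// and keys are hypothetical):
//
//	hits := expvar.NewMap("hits")
//	hits.Add("/index", 1)
//	hits.Add("/index", 1)
//	hits.Add("/about", 1)
//	hits.String() // {"/about": 1, "/index": 2}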
// String is a string variable, and satisfies the [Var] interface.
type String struct {
s atomic.Value // string
}
func (v *String) Value() string {
p, _ := v.s.Load().(string)
return p
}
// String implements the [Var] interface. To get the unquoted string
// use [String.Value].
func (v *String) String() string {
return string(v.appendJSON(nil))
}
func (v *String) appendJSON(b []byte) []byte {
return appendJSONQuote(b, v.Value())
}
func (v *String) Set(value string) {
v.s.Store(value)
}
// Func implements [Var] by calling the function
// and formatting the returned value using JSON.
type Func func() any
func (f Func) Value() any {
return f()
}
func (f Func) String() string {
v, _ := json.Marshal(f())
return string(v)
}
// All published variables.
var vars Map
// Publish declares a named exported variable. This should be called from a
// package's init function when it creates its Vars. If the name is already
// registered then this will log.Panic.
func Publish(name string, v Var) {
if _, dup := vars.m.LoadOrStore(name, v); dup {
log.Panicln("Reuse of exported var name:", name)
}
vars.keysMu.Lock()
defer vars.keysMu.Unlock()
vars.keys = append(vars.keys, name)
slices.Sort(vars.keys)
}
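// For example, a sketch of publishing a computed value with [Func] (the
// variable name is hypothetical):
//
//	start := time.Now()
//	expvar.Publish("uptime_seconds", expvar.Func(func() any {
//		return time.Since(start).Seconds()
//	}))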
// Get retrieves a named exported variable. It returns nil if the name has
// not been registered.
func Get(name string) Var {
return vars.Get(name)
}
// Convenience functions for creating new exported variables.
func NewInt(name string) *Int {
v := new(Int)
Publish(name, v)
return v
}
func NewFloat(name string) *Float {
v := new(Float)
Publish(name, v)
return v
}
func NewMap(name string) *Map {
v := new(Map).Init()
Publish(name, v)
return v
}
func NewString(name string) *String {
v := new(String)
Publish(name, v)
return v
}
// Do calls f for each exported variable.
// The global variable map is locked during the iteration,
// but existing entries may be concurrently updated.
func Do(f func(KeyValue)) {
vars.Do(f)
}
func expvarHandler(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json; charset=utf-8")
w.Write(vars.appendJSONMayExpand(nil, true))
}
// Handler returns the expvar HTTP Handler.
//
// This is only needed to install the handler in a non-standard location.
func Handler() http.Handler {
return http.HandlerFunc(expvarHandler)
}
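// For example, a sketch of serving the variables somewhere other than
// /debug/vars (the path and mux are hypothetical):
//
//	mux := http.NewServeMux()
//	mux.Handle("/internal/vars", expvar.Handler())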
func cmdline() any {
return os.Args
}
func memstats() any {
stats := new(runtime.MemStats)
runtime.ReadMemStats(stats)
return *stats
}
func init() {
if godebug.New("httpmuxgo121").Value() == "1" {
http.HandleFunc("/debug/vars", expvarHandler)
} else {
http.HandleFunc("GET /debug/vars", expvarHandler)
}
Publish("cmdline", Func(cmdline))
Publish("memstats", Func(memstats))
}
// TODO: Use json.appendString instead.
func appendJSONQuote(b []byte, s string) []byte {
const hex = "0123456789abcdef"
b = append(b, '"')
for _, r := range s {
switch {
case r < ' ' || r == '\\' || r == '"' || r == '<' || r == '>' || r == '&' || r == '\u2028' || r == '\u2029':
switch r {
case '\\', '"':
b = append(b, '\\', byte(r))
case '\n':
b = append(b, '\\', 'n')
case '\r':
b = append(b, '\\', 'r')
case '\t':
b = append(b, '\\', 't')
default:
b = append(b, '\\', 'u', hex[(r>>12)&0xf], hex[(r>>8)&0xf], hex[(r>>4)&0xf], hex[(r>>0)&0xf])
}
case r < utf8.RuneSelf:
b = append(b, byte(r))
default:
b = utf8.AppendRune(b, r)
}
}
b = append(b, '"')
return b
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package build
import (
"bytes"
"errors"
"fmt"
"go/ast"
"go/build/constraint"
"go/doc"
"go/token"
"internal/buildcfg"
"internal/godebug"
"internal/goroot"
"internal/goversion"
"internal/platform"
"internal/syslist"
"io"
"io/fs"
"os"
"os/exec"
pathpkg "path"
"path/filepath"
"runtime"
"slices"
"strconv"
"strings"
"unicode"
"unicode/utf8"
_ "unsafe" // for linkname
)
// A Context specifies the supporting context for a build.
type Context struct {
GOARCH string // target architecture
GOOS string // target operating system
GOROOT string // Go root
GOPATH string // Go paths
// Dir is the caller's working directory, or the empty string to use
// the current directory of the running process. In module mode, this is used
// to locate the main module.
//
// If Dir is non-empty, directories passed to Import and ImportDir must
// be absolute.
Dir string
CgoEnabled bool // whether cgo files are included
UseAllFiles bool // use files regardless of go:build lines, file names
Compiler string // compiler to assume when computing target paths
// The build, tool, and release tags specify build constraints
// that should be considered satisfied when processing go:build lines.
// Clients creating a new context may customize BuildTags, which
// defaults to empty, but it is usually an error to customize ToolTags or ReleaseTags.
// ToolTags defaults to build tags appropriate to the current Go toolchain configuration.
// ReleaseTags defaults to the list of Go releases the current release is compatible with.
// BuildTags is not set for the Default build Context.
// In addition to the BuildTags, ToolTags, and ReleaseTags, build constraints
// consider the values of GOARCH and GOOS as satisfied tags.
// The last element in ReleaseTags is assumed to be the current release.
BuildTags []string
ToolTags []string
ReleaseTags []string
// The install suffix specifies a suffix to use in the name of the installation
// directory. By default it is empty, but custom builds that need to keep
// their outputs separate can set InstallSuffix to do so. For example, when
// using the race detector, the go command uses InstallSuffix = "race", so
// that on a Linux/386 system, packages are written to a directory named
// "linux_386_race" instead of the usual "linux_386".
InstallSuffix string
// By default, Import uses the operating system's file system calls
// to read directories and files. To read from other sources,
// callers can set the following functions. They all have default
// behaviors that use the local file system, so clients need only set
// the functions whose behaviors they wish to change.
// JoinPath joins the sequence of path fragments into a single path.
// If JoinPath is nil, Import uses filepath.Join.
JoinPath func(elem ...string) string
// SplitPathList splits the path list into a slice of individual paths.
// If SplitPathList is nil, Import uses filepath.SplitList.
SplitPathList func(list string) []string
// IsAbsPath reports whether path is an absolute path.
// If IsAbsPath is nil, Import uses filepath.IsAbs.
IsAbsPath func(path string) bool
// IsDir reports whether the path names a directory.
// If IsDir is nil, Import calls os.Stat and uses the result's IsDir method.
IsDir func(path string) bool
// HasSubdir reports whether dir is lexically a subdirectory of
// root, perhaps multiple levels below. It does not try to check
// whether dir exists.
// If so, HasSubdir sets rel to a slash-separated path that
// can be joined to root to produce a path equivalent to dir.
// If HasSubdir is nil, Import uses an implementation built on
// filepath.EvalSymlinks.
HasSubdir func(root, dir string) (rel string, ok bool)
// ReadDir returns a slice of fs.FileInfo, sorted by Name,
// describing the content of the named directory.
// If ReadDir is nil, Import uses os.ReadDir.
ReadDir func(dir string) ([]fs.FileInfo, error)
// OpenFile opens a file (not a directory) for reading.
// If OpenFile is nil, Import uses os.Open.
OpenFile func(path string) (io.ReadCloser, error)
}
// joinPath calls ctxt.JoinPath (if not nil) or else filepath.Join.
func (ctxt *Context) joinPath(elem ...string) string {
if f := ctxt.JoinPath; f != nil {
return f(elem...)
}
return filepath.Join(elem...)
}
// splitPathList calls ctxt.SplitPathList (if not nil) or else filepath.SplitList.
func (ctxt *Context) splitPathList(s string) []string {
if f := ctxt.SplitPathList; f != nil {
return f(s)
}
return filepath.SplitList(s)
}
// isAbsPath calls ctxt.IsAbsPath (if not nil) or else filepath.IsAbs.
func (ctxt *Context) isAbsPath(path string) bool {
if f := ctxt.IsAbsPath; f != nil {
return f(path)
}
return filepath.IsAbs(path)
}
// isDir calls ctxt.IsDir (if not nil) or else uses os.Stat.
func (ctxt *Context) isDir(path string) bool {
if f := ctxt.IsDir; f != nil {
return f(path)
}
fi, err := os.Stat(path)
return err == nil && fi.IsDir()
}
// hasSubdir calls ctxt.HasSubdir (if not nil) or else uses
// the local file system to answer the question.
func (ctxt *Context) hasSubdir(root, dir string) (rel string, ok bool) {
if f := ctxt.HasSubdir; f != nil {
return f(root, dir)
}
// Try using paths we received.
if rel, ok = hasSubdir(root, dir); ok {
return
}
// Try expanding symlinks and comparing
// expanded against unexpanded and
// expanded against expanded.
rootSym, _ := filepath.EvalSymlinks(root)
dirSym, _ := filepath.EvalSymlinks(dir)
if rel, ok = hasSubdir(rootSym, dir); ok {
return
}
if rel, ok = hasSubdir(root, dirSym); ok {
return
}
return hasSubdir(rootSym, dirSym)
}
// hasSubdir reports if dir is within root by performing lexical analysis only.
func hasSubdir(root, dir string) (rel string, ok bool) {
const sep = string(filepath.Separator)
root = filepath.Clean(root)
if !strings.HasSuffix(root, sep) {
root += sep
}
dir = filepath.Clean(dir)
after, found := strings.CutPrefix(dir, root)
if !found {
return "", false
}
return filepath.ToSlash(after), true
}
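// A sketch of the lexical check, from inside this package (paths shown
// Unix-style; the real separator is filepath.Separator):
//
//	hasSubdir("/home/me/go", "/home/me/go/src/x") // "src/x", true
//	hasSubdir("/home/me/go", "/srv/www")          // "", false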
// readDir calls ctxt.ReadDir (if not nil) or else os.ReadDir.
func (ctxt *Context) readDir(path string) ([]fs.DirEntry, error) {
// TODO: add a fs.DirEntry version of Context.ReadDir
if f := ctxt.ReadDir; f != nil {
fis, err := f(path)
if err != nil {
return nil, err
}
des := make([]fs.DirEntry, len(fis))
for i, fi := range fis {
des[i] = fs.FileInfoToDirEntry(fi)
}
return des, nil
}
return os.ReadDir(path)
}
// openFile calls ctxt.OpenFile (if not nil) or else os.Open.
func (ctxt *Context) openFile(path string) (io.ReadCloser, error) {
if fn := ctxt.OpenFile; fn != nil {
return fn(path)
}
f, err := os.Open(path)
if err != nil {
return nil, err // nil interface
}
return f, nil
}
// isFile determines whether path is a file by trying to open it.
// It reuses openFile instead of adding another function to the
// list in Context.
func (ctxt *Context) isFile(path string) bool {
f, err := ctxt.openFile(path)
if err != nil {
return false
}
f.Close()
return true
}
// gopath returns the list of Go path directories.
func (ctxt *Context) gopath() []string {
var all []string
for _, p := range ctxt.splitPathList(ctxt.GOPATH) {
if p == "" || p == ctxt.GOROOT {
// Empty paths are uninteresting.
// If the path is the GOROOT, ignore it.
// People sometimes set GOPATH=$GOROOT.
// Do not get confused by this common mistake.
continue
}
if strings.HasPrefix(p, "~") {
// Path segments starting with ~ on Unix are almost always
// users who have incorrectly quoted ~ while setting GOPATH,
// preventing it from expanding to $HOME.
// The situation is made more confusing by the fact that
// bash allows quoted ~ in $PATH (most shells do not).
// Do not get confused by this, and do not try to use the path.
// It does not exist, and printing errors about it confuses
// those users even more, because they think "sure ~ exists!".
// The go command diagnoses this situation and prints a
// useful error.
// On Windows, ~ is used in short names, such as c:\progra~1
// for c:\program files.
continue
}
all = append(all, p)
}
return all
}
// SrcDirs returns a list of package source root directories.
// It draws from the current Go root and Go path but omits directories
// that do not exist.
func (ctxt *Context) SrcDirs() []string {
var all []string
if ctxt.GOROOT != "" && ctxt.Compiler != "gccgo" {
dir := ctxt.joinPath(ctxt.GOROOT, "src")
if ctxt.isDir(dir) {
all = append(all, dir)
}
}
for _, p := range ctxt.gopath() {
dir := ctxt.joinPath(p, "src")
if ctxt.isDir(dir) {
all = append(all, dir)
}
}
return all
}
// Default is the default Context for builds.
// It uses the GOARCH, GOOS, GOROOT, and GOPATH environment variables
// if set, or else the compiled code's GOARCH, GOOS, and GOROOT.
var Default Context = defaultContext()
// Keep consistent with cmd/go/internal/cfg.defaultGOPATH.
func defaultGOPATH() string {
env := "HOME"
if runtime.GOOS == "windows" {
env = "USERPROFILE"
} else if runtime.GOOS == "plan9" {
env = "home"
}
if home := os.Getenv(env); home != "" {
def := filepath.Join(home, "go")
if filepath.Clean(def) == filepath.Clean(runtime.GOROOT()) {
// Don't set the default GOPATH to GOROOT,
// as that will trigger warnings from the go tool.
return ""
}
return def
}
return ""
}
// defaultToolTags should be an internal detail,
// but widely used packages access it using linkname.
// Notable members of the hall of shame include:
// - github.com/gopherjs/gopherjs
//
// Do not remove or change the type signature.
// See go.dev/issue/67401.
//
//go:linkname defaultToolTags
var defaultToolTags []string
// defaultReleaseTags should be an internal detail,
// but widely used packages access it using linkname.
// Notable members of the hall of shame include:
// - github.com/gopherjs/gopherjs
//
// Do not remove or change the type signature.
// See go.dev/issue/67401.
//
//go:linkname defaultReleaseTags
var defaultReleaseTags []string
func defaultContext() Context {
var c Context
c.GOARCH = buildcfg.GOARCH
c.GOOS = buildcfg.GOOS
if goroot := runtime.GOROOT(); goroot != "" {
c.GOROOT = filepath.Clean(goroot)
}
c.GOPATH = envOr("GOPATH", defaultGOPATH())
c.Compiler = runtime.Compiler
c.ToolTags = append(c.ToolTags, buildcfg.ToolTags...)
defaultToolTags = append([]string{}, c.ToolTags...) // our own private copy
// Each major Go release in the Go 1.x series adds a new
// "go1.x" release tag. That is, the go1.x tag is present in
// all releases >= Go 1.x. Code that requires Go 1.x or later
// should say "go:build go1.x", and code that should only be
// built before Go 1.x (perhaps it is the stub to use in that
// case) should say "go:build !go1.x".
// The last element in ReleaseTags is the current release.
for i := 1; i <= goversion.Version; i++ {
c.ReleaseTags = append(c.ReleaseTags, "go1."+strconv.Itoa(i))
}
defaultReleaseTags = append([]string{}, c.ReleaseTags...) // our own private copy
env := os.Getenv("CGO_ENABLED")
if env == "" {
env = buildcfg.DefaultCGO_ENABLED
}
switch env {
case "1":
c.CgoEnabled = true
case "0":
c.CgoEnabled = false
default:
// cgo must be explicitly enabled for cross compilation builds
if runtime.GOARCH == c.GOARCH && runtime.GOOS == c.GOOS {
c.CgoEnabled = platform.CgoSupported(c.GOOS, c.GOARCH)
break
}
c.CgoEnabled = false
}
return c
}
func envOr(name, def string) string {
s := os.Getenv(name)
if s == "" {
return def
}
return s
}
// An ImportMode controls the behavior of the Import method.
type ImportMode uint
const (
// If FindOnly is set, Import stops after locating the directory
// that should contain the sources for a package. It does not
// read any files in the directory.
FindOnly ImportMode = 1 << iota
// If AllowBinary is set, Import can be satisfied by a compiled
// package object without corresponding sources.
//
// Deprecated:
// The supported way to create a compiled-only package is to
// write source code containing a //go:binary-only-package comment at
// the top of the file. Such a package will be recognized
// regardless of this flag setting (because it has source code)
// and will have BinaryOnly set to true in the returned Package.
AllowBinary
// If ImportComment is set, parse import comments on package statements.
// Import returns an error if it finds a comment it cannot understand
// or finds conflicting comments in multiple source files.
// See golang.org/s/go14customimport for more information.
ImportComment
// By default, Import searches vendor directories
// that apply in the given source directory before searching
// the GOROOT and GOPATH roots.
// If an Import finds and returns a package using a vendor
// directory, the resulting ImportPath is the complete path
// to the package, including the path elements leading up
// to and including "vendor".
// For example, if Import("y", "x/subdir", 0) finds
// "x/vendor/y", the returned package's ImportPath is "x/vendor/y",
// not plain "y".
// See golang.org/s/go15vendor for more information.
//
// Setting IgnoreVendor ignores vendor directories.
//
// In contrast to the package's ImportPath,
// the returned package's Imports, TestImports, and XTestImports
// are always the exact import paths from the source files:
// Import makes no attempt to resolve or check those paths.
IgnoreVendor
)
// A Package describes the Go package found in a directory.
type Package struct {
Dir string // directory containing package sources
Name string // package name
ImportComment string // path in import comment on package statement
Doc string // documentation synopsis
ImportPath string // import path of package ("" if unknown)
Root string // root of Go tree where this package lives
SrcRoot string // package source root directory ("" if unknown)
PkgRoot string // package install root directory ("" if unknown)
PkgTargetRoot string // architecture dependent install root directory ("" if unknown)
BinDir string // command install directory ("" if unknown)
Goroot bool // package found in Go root
PkgObj string // installed .a file
AllTags []string // tags that can influence file selection in this directory
ConflictDir string // this directory shadows Dir in $GOPATH
BinaryOnly bool // cannot be rebuilt from source (has //go:binary-only-package comment)
// Source files
GoFiles []string // .go source files (excluding CgoFiles, TestGoFiles, XTestGoFiles)
CgoFiles []string // .go source files that import "C"
IgnoredGoFiles []string // .go source files ignored for this build (including ignored _test.go files)
InvalidGoFiles []string // .go source files with detected problems (parse error, wrong package name, and so on)
IgnoredOtherFiles []string // non-.go source files ignored for this build
CFiles []string // .c source files
CXXFiles []string // .cc, .cpp and .cxx source files
MFiles []string // .m (Objective-C) source files
HFiles []string // .h, .hh, .hpp and .hxx source files
FFiles []string // .f, .F, .for and .f90 Fortran source files
SFiles []string // .s source files
SwigFiles []string // .swig files
SwigCXXFiles []string // .swigcxx files
SysoFiles []string // .syso system object files to add to archive
// Cgo directives
CgoCFLAGS []string // Cgo CFLAGS directives
CgoCPPFLAGS []string // Cgo CPPFLAGS directives
CgoCXXFLAGS []string // Cgo CXXFLAGS directives
CgoFFLAGS []string // Cgo FFLAGS directives
CgoLDFLAGS []string // Cgo LDFLAGS directives
CgoPkgConfig []string // Cgo pkg-config directives
// Test information
TestGoFiles []string // _test.go files in package
XTestGoFiles []string // _test.go files outside package
// Go directive comments (//go:zzz...) found in source files.
Directives []Directive
TestDirectives []Directive
XTestDirectives []Directive
// Dependency information
Imports []string // import paths from GoFiles, CgoFiles
ImportPos map[string][]token.Position // line information for Imports
TestImports []string // import paths from TestGoFiles
TestImportPos map[string][]token.Position // line information for TestImports
XTestImports []string // import paths from XTestGoFiles
XTestImportPos map[string][]token.Position // line information for XTestImports
// //go:embed patterns found in Go source files
// For example, if a source file says
// //go:embed a* b.c
// then the list will contain those two strings as separate entries.
// (See package embed for more details about //go:embed.)
EmbedPatterns []string // patterns from GoFiles, CgoFiles
EmbedPatternPos map[string][]token.Position // line information for EmbedPatterns
TestEmbedPatterns []string // patterns from TestGoFiles
TestEmbedPatternPos map[string][]token.Position // line information for TestEmbedPatterns
XTestEmbedPatterns []string // patterns from XTestGoFiles
XTestEmbedPatternPos map[string][]token.Position // line information for XTestEmbedPatterns
}
// A Directive is a Go directive comment (//go:zzz...) found in a source file.
type Directive struct {
Text string // full line comment including leading slashes
Pos token.Position // position of comment
}
// IsCommand reports whether the package is considered a
// command to be installed (not just a library).
// Packages named "main" are treated as commands.
func (p *Package) IsCommand() bool {
return p.Name == "main"
}
// ImportDir is like [Import] but processes the Go package found in
// the named directory.
func (ctxt *Context) ImportDir(dir string, mode ImportMode) (*Package, error) {
return ctxt.Import(".", dir, mode)
}
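// For example, a sketch of locating a standard-library package with the
// default context (the printed directory depends on the local GOROOT):
//
//	pkg, err := build.Import("fmt", "", build.FindOnly)
//	if err == nil {
//		fmt.Println(pkg.Dir) // e.g. /usr/local/go/src/fmt
//	}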
// NoGoError is the error used by [Import] to describe a directory
// containing no buildable Go source files. (It may still contain
// test files, files hidden by build tags, and so on.)
type NoGoError struct {
Dir string
}
func (e *NoGoError) Error() string {
return "no buildable Go source files in " + e.Dir
}
// MultiplePackageError describes a directory containing
// multiple buildable Go source files for multiple packages.
type MultiplePackageError struct {
Dir string // directory containing files
Packages []string // package names found
Files []string // corresponding files: Files[i] declares package Packages[i]
}
func (e *MultiplePackageError) Error() string {
// Error string limited to two entries for compatibility.
return fmt.Sprintf("found packages %s (%s) and %s (%s) in %s", e.Packages[0], e.Files[0], e.Packages[1], e.Files[1], e.Dir)
}
func nameExt(name string) string {
i := strings.LastIndex(name, ".")
if i < 0 {
return ""
}
return name[i:]
}
var installgoroot = godebug.New("installgoroot")
// Import returns details about the Go package named by the import path,
// interpreting local import paths relative to the srcDir directory.
// If the path is a local import path naming a package that can be imported
// using a standard import path, the returned package will set p.ImportPath
// to that path.
//
// In the directory containing the package, .go, .c, .h, and .s files are
// considered part of the package except for:
//
// - .go files in package documentation
// - files starting with _ or . (likely editor temporary files)
// - files with build constraints not satisfied by the context
//
// If an error occurs, Import returns a non-nil error and a non-nil
// *[Package] containing partial information.
func (ctxt *Context) Import(path string, srcDir string, mode ImportMode) (*Package, error) {
p := &Package{
ImportPath: path,
}
if path == "" {
return p, fmt.Errorf("import %q: invalid import path", path)
}
var pkgtargetroot string
var pkga string
var pkgerr error
suffix := ""
if ctxt.InstallSuffix != "" {
suffix = "_" + ctxt.InstallSuffix
}
switch ctxt.Compiler {
case "gccgo":
pkgtargetroot = "pkg/gccgo_" + ctxt.GOOS + "_" + ctxt.GOARCH + suffix
case "gc":
pkgtargetroot = "pkg/" + ctxt.GOOS + "_" + ctxt.GOARCH + suffix
default:
// Save error for end of function.
pkgerr = fmt.Errorf("import %q: unknown compiler %q", path, ctxt.Compiler)
}
setPkga := func() {
switch ctxt.Compiler {
case "gccgo":
dir, elem := pathpkg.Split(p.ImportPath)
pkga = pkgtargetroot + "/" + dir + "lib" + elem + ".a"
case "gc":
pkga = pkgtargetroot + "/" + p.ImportPath + ".a"
}
}
setPkga()
binaryOnly := false
if IsLocalImport(path) {
pkga = "" // local imports have no installed path
if srcDir == "" {
return p, fmt.Errorf("import %q: import relative to unknown directory", path)
}
if !ctxt.isAbsPath(path) {
p.Dir = ctxt.joinPath(srcDir, path)
}
// p.Dir directory may or may not exist. Gather partial information first, check if it exists later.
// Determine canonical import path, if any.
// Exclude results where the import path would include /testdata/.
inTestdata := func(sub string) bool {
return strings.Contains(sub, "/testdata/") || strings.HasSuffix(sub, "/testdata") || strings.HasPrefix(sub, "testdata/") || sub == "testdata"
}
if ctxt.GOROOT != "" {
root := ctxt.joinPath(ctxt.GOROOT, "src")
if sub, ok := ctxt.hasSubdir(root, p.Dir); ok && !inTestdata(sub) {
p.Goroot = true
p.ImportPath = sub
p.Root = ctxt.GOROOT
setPkga() // p.ImportPath changed
goto Found
}
}
all := ctxt.gopath()
for i, root := range all {
rootsrc := ctxt.joinPath(root, "src")
if sub, ok := ctxt.hasSubdir(rootsrc, p.Dir); ok && !inTestdata(sub) {
// We found a potential import path for dir,
// but check that using it wouldn't find something
// else first.
if ctxt.GOROOT != "" && ctxt.Compiler != "gccgo" {
if dir := ctxt.joinPath(ctxt.GOROOT, "src", sub); ctxt.isDir(dir) {
p.ConflictDir = dir
goto Found
}
}
for _, earlyRoot := range all[:i] {
if dir := ctxt.joinPath(earlyRoot, "src", sub); ctxt.isDir(dir) {
p.ConflictDir = dir
goto Found
}
}
// sub would not name some other directory instead of this one.
// Record it.
p.ImportPath = sub
p.Root = root
setPkga() // p.ImportPath changed
goto Found
}
}
// It's okay that we didn't find a root containing dir.
// Keep going with the information we have.
} else {
if strings.HasPrefix(path, "/") {
return p, fmt.Errorf("import %q: cannot import absolute path", path)
}
if err := ctxt.importGo(p, path, srcDir, mode); err == nil {
goto Found
} else if err != errNoModules {
return p, err
}
gopath := ctxt.gopath() // needed twice below; avoid computing many times
// tried records the location of unsuccessful package lookups
var tried struct {
vendor []string
goroot string
gopath []string
}
// Vendor directories get first chance to satisfy import.
if mode&IgnoreVendor == 0 && srcDir != "" {
searchVendor := func(root string, isGoroot bool) bool {
sub, ok := ctxt.hasSubdir(root, srcDir)
if !ok || !strings.HasPrefix(sub, "src/") || strings.Contains(sub, "/testdata/") {
return false
}
for {
vendor := ctxt.joinPath(root, sub, "vendor")
if ctxt.isDir(vendor) {
dir := ctxt.joinPath(vendor, path)
if ctxt.isDir(dir) && hasGoFiles(ctxt, dir) {
p.Dir = dir
p.ImportPath = strings.TrimPrefix(pathpkg.Join(sub, "vendor", path), "src/")
p.Goroot = isGoroot
p.Root = root
setPkga() // p.ImportPath changed
return true
}
tried.vendor = append(tried.vendor, dir)
}
i := strings.LastIndex(sub, "/")
if i < 0 {
break
}
sub = sub[:i]
}
return false
}
if ctxt.Compiler != "gccgo" && ctxt.GOROOT != "" && searchVendor(ctxt.GOROOT, true) {
goto Found
}
for _, root := range gopath {
if searchVendor(root, false) {
goto Found
}
}
}
// Determine directory from import path.
if ctxt.GOROOT != "" {
// If the package path starts with "vendor/", only search GOROOT before
// GOPATH if the importer is also within GOROOT. That way, if the user has
// vendored in a package that is subsequently included in the standard
// distribution, they'll continue to pick up their own vendored copy.
gorootFirst := srcDir == "" || !strings.HasPrefix(path, "vendor/")
if !gorootFirst {
_, gorootFirst = ctxt.hasSubdir(ctxt.GOROOT, srcDir)
}
if gorootFirst {
dir := ctxt.joinPath(ctxt.GOROOT, "src", path)
if ctxt.Compiler != "gccgo" {
isDir := ctxt.isDir(dir)
binaryOnly = !isDir && mode&AllowBinary != 0 && pkga != "" && ctxt.isFile(ctxt.joinPath(ctxt.GOROOT, pkga))
if isDir || binaryOnly {
p.Dir = dir
p.Goroot = true
p.Root = ctxt.GOROOT
goto Found
}
}
tried.goroot = dir
}
if ctxt.Compiler == "gccgo" && goroot.IsStandardPackage(ctxt.GOROOT, ctxt.Compiler, path) {
// TODO(bcmills): Setting p.Dir here is misleading, because gccgo
// doesn't actually load its standard-library packages from this
// directory. See if we can leave it unset.
p.Dir = ctxt.joinPath(ctxt.GOROOT, "src", path)
p.Goroot = true
p.Root = ctxt.GOROOT
goto Found
}
}
for _, root := range gopath {
dir := ctxt.joinPath(root, "src", path)
isDir := ctxt.isDir(dir)
binaryOnly = !isDir && mode&AllowBinary != 0 && pkga != "" && ctxt.isFile(ctxt.joinPath(root, pkga))
if isDir || binaryOnly {
p.Dir = dir
p.Root = root
goto Found
}
tried.gopath = append(tried.gopath, dir)
}
// If we tried GOPATH first due to a "vendor/" prefix, fall back to GOROOT.
// That way, the user can still get useful results from 'go list' for
// standard-vendored paths passed on the command line.
if ctxt.GOROOT != "" && tried.goroot == "" {
dir := ctxt.joinPath(ctxt.GOROOT, "src", path)
if ctxt.Compiler != "gccgo" {
isDir := ctxt.isDir(dir)
binaryOnly = !isDir && mode&AllowBinary != 0 && pkga != "" && ctxt.isFile(ctxt.joinPath(ctxt.GOROOT, pkga))
if isDir || binaryOnly {
p.Dir = dir
p.Goroot = true
p.Root = ctxt.GOROOT
goto Found
}
}
tried.goroot = dir
}
// package was not found
var paths []string
format := "\t%s (vendor tree)"
for _, dir := range tried.vendor {
paths = append(paths, fmt.Sprintf(format, dir))
format = "\t%s"
}
if tried.goroot != "" {
paths = append(paths, fmt.Sprintf("\t%s (from $GOROOT)", tried.goroot))
} else {
paths = append(paths, "\t($GOROOT not set)")
}
format = "\t%s (from $GOPATH)"
for _, dir := range tried.gopath {
paths = append(paths, fmt.Sprintf(format, dir))
format = "\t%s"
}
if len(tried.gopath) == 0 {
paths = append(paths, "\t($GOPATH not set. For more details see: 'go help gopath')")
}
return p, fmt.Errorf("cannot find package %q in any of:\n%s", path, strings.Join(paths, "\n"))
}
Found:
if p.Root != "" {
p.SrcRoot = ctxt.joinPath(p.Root, "src")
p.PkgRoot = ctxt.joinPath(p.Root, "pkg")
p.BinDir = ctxt.joinPath(p.Root, "bin")
if pkga != "" {
// Always set PkgTargetRoot. It might be used when building in shared
// mode.
p.PkgTargetRoot = ctxt.joinPath(p.Root, pkgtargetroot)
// Set the install target if applicable.
if !p.Goroot || (installgoroot.Value() == "all" && p.ImportPath != "unsafe" && p.ImportPath != "builtin") {
if p.Goroot {
installgoroot.IncNonDefault()
}
p.PkgObj = ctxt.joinPath(p.Root, pkga)
}
}
}
// If it's a local import path, by the time we get here, we still haven't checked
// that p.Dir directory exists. This is the right time to do that check.
// We can't do it earlier, because we want to gather partial information for the
// non-nil *Package returned when an error occurs.
// We need to do this before we return early on FindOnly flag.
if IsLocalImport(path) && !ctxt.isDir(p.Dir) {
if ctxt.Compiler == "gccgo" && p.Goroot {
// gccgo has no sources for GOROOT packages.
return p, nil
}
// package was not found
return p, fmt.Errorf("cannot find package %q in:\n\t%s", p.ImportPath, p.Dir)
}
if mode&FindOnly != 0 {
return p, pkgerr
}
if binaryOnly && (mode&AllowBinary) != 0 {
return p, pkgerr
}
if ctxt.Compiler == "gccgo" && p.Goroot {
// gccgo has no sources for GOROOT packages.
return p, nil
}
dirs, err := ctxt.readDir(p.Dir)
if err != nil {
return p, err
}
var badGoError error
badGoFiles := make(map[string]bool)
badGoFile := func(name string, err error) {
if badGoError == nil {
badGoError = err
}
if !badGoFiles[name] {
p.InvalidGoFiles = append(p.InvalidGoFiles, name)
badGoFiles[name] = true
}
}
var Sfiles []string // files with ".S" (capital S) or ".sx" (its equivalent on case-insensitive filesystems)
var firstFile, firstCommentFile string
embedPos := make(map[string][]token.Position)
testEmbedPos := make(map[string][]token.Position)
xTestEmbedPos := make(map[string][]token.Position)
importPos := make(map[string][]token.Position)
testImportPos := make(map[string][]token.Position)
xTestImportPos := make(map[string][]token.Position)
allTags := make(map[string]bool)
fset := token.NewFileSet()
for _, d := range dirs {
if d.IsDir() {
continue
}
if d.Type() == fs.ModeSymlink {
if ctxt.isDir(ctxt.joinPath(p.Dir, d.Name())) {
// Symlinks to directories are not source files.
continue
}
}
name := d.Name()
ext := nameExt(name)
info, err := ctxt.matchFile(p.Dir, name, allTags, &p.BinaryOnly, fset)
if err != nil && strings.HasSuffix(name, ".go") {
badGoFile(name, err)
continue
}
if info == nil {
if strings.HasPrefix(name, "_") || strings.HasPrefix(name, ".") {
// not due to build constraints - don't report
} else if ext == ".go" {
p.IgnoredGoFiles = append(p.IgnoredGoFiles, name)
} else if fileListForExt(p, ext) != nil {
p.IgnoredOtherFiles = append(p.IgnoredOtherFiles, name)
}
continue
}
// Going to save the file. For non-Go files, can stop here.
switch ext {
case ".go":
// keep going
case ".S", ".sx":
// special case for cgo, handled at end
Sfiles = append(Sfiles, name)
continue
default:
if list := fileListForExt(p, ext); list != nil {
*list = append(*list, name)
}
continue
}
data, filename := info.header, info.name
if info.parseErr != nil {
badGoFile(name, info.parseErr)
// Fall through: we might still have a partial AST in info.parsed,
// and we want to list files with parse errors anyway.
}
var pkg string
if info.parsed != nil {
pkg = info.parsed.Name.Name
if pkg == "documentation" {
p.IgnoredGoFiles = append(p.IgnoredGoFiles, name)
continue
}
}
isTest := strings.HasSuffix(name, "_test.go")
isXTest := false
if isTest && strings.HasSuffix(pkg, "_test") && p.Name != pkg {
isXTest = true
pkg = pkg[:len(pkg)-len("_test")]
}
if p.Name == "" {
p.Name = pkg
firstFile = name
} else if pkg != p.Name {
// TODO(#45999): The choice of p.Name is arbitrary based on file iteration
// order. Instead of resolving p.Name arbitrarily, we should clear out the
// existing name and mark the existing files as also invalid.
badGoFile(name, &MultiplePackageError{
Dir: p.Dir,
Packages: []string{p.Name, pkg},
Files: []string{firstFile, name},
})
}
// Grab the first package comment as docs, provided it is not from a test file.
if info.parsed != nil && info.parsed.Doc != nil && p.Doc == "" && !isTest && !isXTest {
p.Doc = doc.Synopsis(info.parsed.Doc.Text())
}
if mode&ImportComment != 0 {
qcom, line := findImportComment(data)
if line != 0 {
com, err := strconv.Unquote(qcom)
if err != nil {
badGoFile(name, fmt.Errorf("%s:%d: cannot parse import comment", filename, line))
} else if p.ImportComment == "" {
p.ImportComment = com
firstCommentFile = name
} else if p.ImportComment != com {
badGoFile(name, fmt.Errorf("found import comments %q (%s) and %q (%s) in %s", p.ImportComment, firstCommentFile, com, name, p.Dir))
}
}
}
// Record imports and information about cgo.
isCgo := false
for _, imp := range info.imports {
if imp.path == "C" {
if isTest {
badGoFile(name, fmt.Errorf("use of cgo in test %s not supported", filename))
continue
}
isCgo = true
if imp.doc != nil {
if err := ctxt.saveCgo(filename, p, imp.doc); err != nil {
badGoFile(name, err)
}
}
}
}
var fileList *[]string
var importMap, embedMap map[string][]token.Position
var directives *[]Directive
switch {
case isCgo:
allTags["cgo"] = true
if ctxt.CgoEnabled {
fileList = &p.CgoFiles
importMap = importPos
embedMap = embedPos
directives = &p.Directives
} else {
// Ignore imports and embeds from cgo files if cgo is disabled.
fileList = &p.IgnoredGoFiles
}
case isXTest:
fileList = &p.XTestGoFiles
importMap = xTestImportPos
embedMap = xTestEmbedPos
directives = &p.XTestDirectives
case isTest:
fileList = &p.TestGoFiles
importMap = testImportPos
embedMap = testEmbedPos
directives = &p.TestDirectives
default:
fileList = &p.GoFiles
importMap = importPos
embedMap = embedPos
directives = &p.Directives
}
*fileList = append(*fileList, name)
if importMap != nil {
for _, imp := range info.imports {
importMap[imp.path] = append(importMap[imp.path], fset.Position(imp.pos))
}
}
if embedMap != nil {
for _, emb := range info.embeds {
embedMap[emb.pattern] = append(embedMap[emb.pattern], emb.pos)
}
}
if directives != nil {
*directives = append(*directives, info.directives...)
}
}
for tag := range allTags {
p.AllTags = append(p.AllTags, tag)
}
slices.Sort(p.AllTags)
p.EmbedPatterns, p.EmbedPatternPos = cleanDecls(embedPos)
p.TestEmbedPatterns, p.TestEmbedPatternPos = cleanDecls(testEmbedPos)
p.XTestEmbedPatterns, p.XTestEmbedPatternPos = cleanDecls(xTestEmbedPos)
p.Imports, p.ImportPos = cleanDecls(importPos)
p.TestImports, p.TestImportPos = cleanDecls(testImportPos)
p.XTestImports, p.XTestImportPos = cleanDecls(xTestImportPos)
// add the .S/.sx files only if we are using cgo
// (which means gcc will compile them).
// The standard assemblers expect .s files.
if len(p.CgoFiles) > 0 {
p.SFiles = append(p.SFiles, Sfiles...)
slices.Sort(p.SFiles)
} else {
p.IgnoredOtherFiles = append(p.IgnoredOtherFiles, Sfiles...)
slices.Sort(p.IgnoredOtherFiles)
}
if badGoError != nil {
return p, badGoError
}
if len(p.GoFiles)+len(p.CgoFiles)+len(p.TestGoFiles)+len(p.XTestGoFiles) == 0 {
return p, &NoGoError{p.Dir}
}
return p, pkgerr
}
func fileListForExt(p *Package, ext string) *[]string {
switch ext {
case ".c":
return &p.CFiles
case ".cc", ".cpp", ".cxx":
return &p.CXXFiles
case ".m":
return &p.MFiles
case ".h", ".hh", ".hpp", ".hxx":
return &p.HFiles
case ".f", ".F", ".for", ".f90":
return &p.FFiles
case ".s", ".S", ".sx":
return &p.SFiles
case ".swig":
return &p.SwigFiles
case ".swigcxx":
return &p.SwigCXXFiles
case ".syso":
return &p.SysoFiles
}
return nil
}
func uniq(list []string) []string {
if list == nil {
return nil
}
out := make([]string, len(list))
copy(out, list)
slices.Sort(out)
uniq := out[:0]
for _, x := range out {
if len(uniq) == 0 || uniq[len(uniq)-1] != x {
uniq = append(uniq, x)
}
}
return uniq
}
var errNoModules = errors.New("not using modules")
// importGo checks whether it can use the go command to find the directory for path.
// If using the go command is not appropriate, importGo returns errNoModules.
// Otherwise, importGo tries using the go command and reports whether that succeeded.
// Using the go command lets build.Import and build.Context.Import find code
// in Go modules. In the long term we want tools to use go/packages (currently golang.org/x/tools/go/packages),
// which will also use the go command.
// Invoking the go command here is not very efficient in that it computes information
// about the requested package and all dependencies and then only reports about the requested package.
// Then we reinvoke it for every dependency. But this is still better than not working at all.
// See golang.org/issue/26504.
func (ctxt *Context) importGo(p *Package, path, srcDir string, mode ImportMode) error {
// To invoke the go command,
// we must not be doing special things like AllowBinary or IgnoreVendor,
// and all the file system callbacks must be nil (we're meant to use the local file system).
if mode&AllowBinary != 0 || mode&IgnoreVendor != 0 ||
ctxt.JoinPath != nil || ctxt.SplitPathList != nil || ctxt.IsAbsPath != nil || ctxt.IsDir != nil || ctxt.HasSubdir != nil || ctxt.ReadDir != nil || ctxt.OpenFile != nil || !equal(ctxt.ToolTags, defaultToolTags) || !equal(ctxt.ReleaseTags, defaultReleaseTags) {
return errNoModules
}
// If ctxt.GOROOT is not set, we don't know which go command to invoke,
// and even if we did we might return packages in GOROOT that we wouldn't otherwise find
// (because we don't know to search in 'go env GOROOT' otherwise).
if ctxt.GOROOT == "" {
return errNoModules
}
// Predict whether module aware mode is enabled by checking the value of
// GO111MODULE and looking for a go.mod file in the source directory or
// one of its parents. Running 'go env GOMOD' in the source directory would
// give a canonical answer, but we'd prefer not to execute another command.
go111Module := os.Getenv("GO111MODULE")
switch go111Module {
case "off":
return errNoModules
default: // "", "on", "auto", anything else
// Maybe use modules.
}
if srcDir != "" {
var absSrcDir string
if filepath.IsAbs(srcDir) {
absSrcDir = srcDir
} else if ctxt.Dir != "" {
return fmt.Errorf("go/build: Dir is non-empty, so relative srcDir is not allowed: %v", srcDir)
} else {
// Find the absolute source directory. hasSubdir does not handle
// relative paths (and can't because the callbacks don't support this).
var err error
absSrcDir, err = filepath.Abs(srcDir)
if err != nil {
return errNoModules
}
}
// If the source directory is in GOROOT, then the in-process code works fine
// and we should keep using it. Moreover, the 'go list' approach below doesn't
// take standard-library vendoring into account and will fail.
if _, ok := ctxt.hasSubdir(filepath.Join(ctxt.GOROOT, "src"), absSrcDir); ok {
return errNoModules
}
}
// For efficiency, if path is a standard library package, let the usual lookup code handle it.
if dir := ctxt.joinPath(ctxt.GOROOT, "src", path); ctxt.isDir(dir) {
return errNoModules
}
// If GO111MODULE=auto, look to see if there is a go.mod.
// Since go1.13, it doesn't matter if we're inside GOPATH.
if go111Module == "auto" {
var (
parent string
err error
)
if ctxt.Dir == "" {
parent, err = os.Getwd()
if err != nil {
// A nonexistent working directory can't be in a module.
return errNoModules
}
} else {
parent, err = filepath.Abs(ctxt.Dir)
if err != nil {
// If the caller passed a bogus Dir explicitly, that's materially
// different from not having modules enabled.
return err
}
}
for {
if f, err := ctxt.openFile(ctxt.joinPath(parent, "go.mod")); err == nil {
buf := make([]byte, 100)
_, err := f.Read(buf)
f.Close()
if err == nil || err == io.EOF {
// go.mod exists and is readable (is a file, not a directory).
break
}
}
d := filepath.Dir(parent)
if len(d) >= len(parent) {
return errNoModules // reached top of file system, no go.mod
}
parent = d
}
}
goCmd := filepath.Join(ctxt.GOROOT, "bin", "go")
cmd := exec.Command(goCmd, "list", "-e", "-compiler="+ctxt.Compiler, "-tags="+strings.Join(ctxt.BuildTags, ","), "-installsuffix="+ctxt.InstallSuffix, "-f={{.Dir}}\n{{.ImportPath}}\n{{.Root}}\n{{.Goroot}}\n{{if .Error}}{{.Error}}{{end}}\n", "--", path)
if ctxt.Dir != "" {
cmd.Dir = ctxt.Dir
}
var stdout, stderr strings.Builder
cmd.Stdout = &stdout
cmd.Stderr = &stderr
cgo := "0"
if ctxt.CgoEnabled {
cgo = "1"
}
cmd.Env = append(cmd.Environ(),
"GOOS="+ctxt.GOOS,
"GOARCH="+ctxt.GOARCH,
"GOROOT="+ctxt.GOROOT,
"GOPATH="+ctxt.GOPATH,
"CGO_ENABLED="+cgo,
)
if err := cmd.Run(); err != nil {
return fmt.Errorf("go/build: go list %s: %v\n%s\n", path, err, stderr.String())
}
f := strings.SplitN(stdout.String(), "\n", 5)
if len(f) != 5 {
return fmt.Errorf("go/build: importGo %s: unexpected output:\n%s\n", path, stdout.String())
}
dir := f[0]
errStr := strings.TrimSpace(f[4])
if errStr != "" && dir == "" {
// If 'go list' could not locate the package (dir is empty),
// return the same error that 'go list' reported.
return errors.New(errStr)
}
// If 'go list' did locate the package, ignore the error.
// It was probably related to loading source files, and we'll
// encounter it ourselves shortly if the FindOnly flag isn't set.
p.Dir = dir
p.ImportPath = f[1]
p.Root = f[2]
p.Goroot = f[3] == "true"
return nil
}
func equal(x, y []string) bool {
if len(x) != len(y) {
return false
}
for i, xi := range x {
if xi != y[i] {
return false
}
}
return true
}
// hasGoFiles reports whether dir contains any files with names ending in .go.
// For a vendor check we must exclude directories that contain no .go files.
// Otherwise it is not possible to vendor just a/b/c and still import the
// non-vendored a/b. See golang.org/issue/13832.
func hasGoFiles(ctxt *Context, dir string) bool {
ents, _ := ctxt.readDir(dir)
for _, ent := range ents {
if !ent.IsDir() && strings.HasSuffix(ent.Name(), ".go") {
return true
}
}
return false
}
func findImportComment(data []byte) (s string, line int) {
// expect keyword package
word, data := parseWord(data)
if string(word) != "package" {
return "", 0
}
// expect package name
_, data = parseWord(data)
// now ready for import comment, a // or /* */ comment
// beginning and ending on the current line.
for len(data) > 0 && (data[0] == ' ' || data[0] == '\t' || data[0] == '\r') {
data = data[1:]
}
var comment []byte
switch {
case bytes.HasPrefix(data, slashSlash):
comment, _, _ = bytes.Cut(data[2:], newline)
case bytes.HasPrefix(data, slashStar):
var ok bool
comment, _, ok = bytes.Cut(data[2:], starSlash)
if !ok {
// malformed comment
return "", 0
}
if bytes.Contains(comment, newline) {
return "", 0
}
}
comment = bytes.TrimSpace(comment)
// split comment into `import`, `"pkg"`
word, arg := parseWord(comment)
if string(word) != "import" {
return "", 0
}
line = 1 + bytes.Count(data[:cap(data)-cap(arg)], newline)
return strings.TrimSpace(string(arg)), line
}
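// Illustrative sketch (not from the original sources) of findImportComment on a
// typical package clause with an import comment. The returned string keeps the
// surrounding quotes; the import path "example.com/foo" is a hypothetical value
// used only for demonstration.
func exampleFindImportComment() {
	data := []byte(`package foo // import "example.com/foo"` + "\n")
	s, line := findImportComment(data)
	// s == `"example.com/foo"` (quotes included), line == 1
	_, _ = s, line
}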
var (
slashSlash = []byte("//")
slashStar = []byte("/*")
starSlash = []byte("*/")
newline = []byte("\n")
)
// skipSpaceOrComment returns data with any leading spaces or comments removed.
func skipSpaceOrComment(data []byte) []byte {
for len(data) > 0 {
switch data[0] {
case ' ', '\t', '\r', '\n':
data = data[1:]
continue
case '/':
if bytes.HasPrefix(data, slashSlash) {
i := bytes.Index(data, newline)
if i < 0 {
return nil
}
data = data[i+1:]
continue
}
if bytes.HasPrefix(data, slashStar) {
data = data[2:]
i := bytes.Index(data, starSlash)
if i < 0 {
return nil
}
data = data[i+2:]
continue
}
}
break
}
return data
}
// parseWord skips any leading spaces or comments in data
// and then parses the beginning of data as an identifier or keyword,
// returning that word and what remains after the word.
func parseWord(data []byte) (word, rest []byte) {
data = skipSpaceOrComment(data)
// Parse past leading word characters.
rest = data
for {
r, size := utf8.DecodeRune(rest)
if unicode.IsLetter(r) || '0' <= r && r <= '9' || r == '_' {
rest = rest[size:]
continue
}
break
}
word = data[:len(data)-len(rest)]
if len(word) == 0 {
return nil, nil
}
return word, rest
}
// MatchFile reports whether the file with the given name in the given directory
// matches the context and would be included in a [Package] created by [ImportDir]
// of that directory.
//
// MatchFile considers the name of the file and may use ctxt.OpenFile to
// read some or all of the file's content.
func (ctxt *Context) MatchFile(dir, name string) (match bool, err error) {
info, err := ctxt.matchFile(dir, name, nil, nil, nil)
return info != nil, err
}
var dummyPkg Package
// fileInfo records information learned about a file included in a build.
type fileInfo struct {
name string // full name including dir
header []byte
fset *token.FileSet
parsed *ast.File
parseErr error
imports []fileImport
embeds []fileEmbed
directives []Directive
}
type fileImport struct {
path string
pos token.Pos
doc *ast.CommentGroup
}
type fileEmbed struct {
pattern string
pos token.Position
}
// matchFile determines whether the file with the given name in the given directory
// should be included in the package being constructed.
// If the file should be included, matchFile returns a non-nil *fileInfo (and a nil error).
// Non-nil errors are reserved for unexpected problems.
//
// If name denotes a Go program, matchFile reads until the end of the
// imports and returns that section of the file in the fileInfo's header field,
// even though it only considers text up to the first non-comment
// when looking for go:build lines.
//
// If allTags is non-nil, matchFile records any encountered build tag
// by setting allTags[tag] = true.
func (ctxt *Context) matchFile(dir, name string, allTags map[string]bool, binaryOnly *bool, fset *token.FileSet) (*fileInfo, error) {
if strings.HasPrefix(name, "_") ||
strings.HasPrefix(name, ".") {
return nil, nil
}
i := strings.LastIndex(name, ".")
if i < 0 {
i = len(name)
}
ext := name[i:]
if ext != ".go" && fileListForExt(&dummyPkg, ext) == nil {
// skip
return nil, nil
}
if !ctxt.goodOSArchFile(name, allTags) && !ctxt.UseAllFiles {
return nil, nil
}
info := &fileInfo{name: ctxt.joinPath(dir, name), fset: fset}
if ext == ".syso" {
// binary, no reading
return info, nil
}
f, err := ctxt.openFile(info.name)
if err != nil {
return nil, err
}
if strings.HasSuffix(name, ".go") {
err = readGoInfo(f, info)
if strings.HasSuffix(name, "_test.go") {
binaryOnly = nil // ignore //go:binary-only-package comments in _test.go files
}
} else {
binaryOnly = nil // ignore //go:binary-only-package comments in non-Go sources
info.header, err = readComments(f)
}
f.Close()
if err != nil {
return info, fmt.Errorf("read %s: %v", info.name, err)
}
// Look for go:build comments to accept or reject the file.
ok, sawBinaryOnly, err := ctxt.shouldBuild(info.header, allTags)
if err != nil {
return nil, fmt.Errorf("%s: %v", name, err)
}
if !ok && !ctxt.UseAllFiles {
return nil, nil
}
if binaryOnly != nil && sawBinaryOnly {
*binaryOnly = true
}
return info, nil
}
func cleanDecls(m map[string][]token.Position) ([]string, map[string][]token.Position) {
all := make([]string, 0, len(m))
for path := range m {
all = append(all, path)
}
slices.Sort(all)
return all, m
}
// Import is shorthand for Default.Import.
func Import(path, srcDir string, mode ImportMode) (*Package, error) {
return Default.Import(path, srcDir, mode)
}
// ImportDir is shorthand for Default.ImportDir.
func ImportDir(dir string, mode ImportMode) (*Package, error) {
return Default.ImportDir(dir, mode)
}
var (
plusBuild = []byte("+build")
goBuildComment = []byte("//go:build")
errMultipleGoBuild = errors.New("multiple //go:build comments")
)
func isGoBuildComment(line []byte) bool {
if !bytes.HasPrefix(line, goBuildComment) {
return false
}
line = bytes.TrimSpace(line)
rest := line[len(goBuildComment):]
return len(rest) == 0 || len(bytes.TrimSpace(rest)) < len(rest)
}
// Special comment denoting a binary-only package.
// See https://golang.org/design/2775-binary-only-packages
// for more about the design of binary-only packages.
var binaryOnlyComment = []byte("//go:binary-only-package")
// shouldBuild reports whether it is okay to use this file.
// The rule is that in the file's leading run of // comments
// and blank lines, which must be followed by a blank line
// (to avoid including a Go package clause doc comment),
// lines beginning with '//go:build' are taken as build directives.
//
// The file is accepted only if each such line lists something
// matching the file. For example:
//
//go:build windows || linux
//
// marks the file as applicable only on Windows or Linux.
//
// For each build tag it consults, shouldBuild sets allTags[tag] = true.
//
// shouldBuild reports whether the file should be built
// and whether a //go:binary-only-package comment was found.
func (ctxt *Context) shouldBuild(content []byte, allTags map[string]bool) (shouldBuild, binaryOnly bool, err error) {
// Identify leading run of // comments and blank lines,
// which must be followed by a blank line.
// Also identify any //go:build comments.
content, goBuild, sawBinaryOnly, err := parseFileHeader(content)
if err != nil {
return false, false, err
}
// If //go:build line is present, it controls.
// Otherwise fall back to +build processing.
switch {
case goBuild != nil:
x, err := constraint.Parse(string(goBuild))
if err != nil {
return false, false, fmt.Errorf("parsing //go:build line: %v", err)
}
shouldBuild = ctxt.eval(x, allTags)
default:
shouldBuild = true
p := content
for len(p) > 0 {
line := p
if i := bytes.IndexByte(line, '\n'); i >= 0 {
line, p = line[:i], p[i+1:]
} else {
p = p[len(p):]
}
line = bytes.TrimSpace(line)
if !bytes.HasPrefix(line, slashSlash) || !bytes.Contains(line, plusBuild) {
continue
}
text := string(line)
if !constraint.IsPlusBuild(text) {
continue
}
if x, err := constraint.Parse(text); err == nil {
if !ctxt.eval(x, allTags) {
shouldBuild = false
}
}
}
}
return shouldBuild, sawBinaryOnly, nil
}
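// Illustrative sketch (not from the original sources) of shouldBuild: a
// //go:build line in the file header is evaluated against the Context. The
// GOOS/GOARCH values below are assumptions chosen for the example.
func exampleShouldBuild() {
	ctxt := Context{GOOS: "linux", GOARCH: "amd64"}
	src := []byte("//go:build linux && amd64\n\npackage p\n")
	ok, binaryOnly, err := ctxt.shouldBuild(src, nil)
	// ok == true, binaryOnly == false, err == nil
	_, _, _ = ok, binaryOnly, err
}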
// parseFileHeader should be an internal detail,
// but widely used packages access it using linkname.
// Notable members of the hall of shame include:
// - github.com/bazelbuild/bazel-gazelle
//
// Do not remove or change the type signature.
// See go.dev/issue/67401.
//
//go:linkname parseFileHeader
func parseFileHeader(content []byte) (trimmed, goBuild []byte, sawBinaryOnly bool, err error) {
end := 0
p := content
ended := false // found non-blank, non-// line, so stopped accepting //go:build lines
inSlashStar := false // in /* */ comment
Lines:
for len(p) > 0 {
line := p
if i := bytes.IndexByte(line, '\n'); i >= 0 {
line, p = line[:i], p[i+1:]
} else {
p = p[len(p):]
}
line = bytes.TrimSpace(line)
if len(line) == 0 && !ended { // Blank line
// Remember position of most recent blank line.
// When we find the first non-blank, non-// line,
// this "end" position marks the latest file position
// where a //go:build line can appear.
// (It must appear _before_ a blank line before the non-blank, non-// line.
// Yes, that's confusing, which is part of why we moved to //go:build lines.)
// Note that ended==false here means that inSlashStar==false,
// since seeing a /* would have set ended==true.
end = len(content) - len(p)
continue Lines
}
if !bytes.HasPrefix(line, slashSlash) { // Not comment line
ended = true
}
if !inSlashStar && isGoBuildComment(line) {
if goBuild != nil {
return nil, nil, false, errMultipleGoBuild
}
goBuild = line
}
if !inSlashStar && bytes.Equal(line, binaryOnlyComment) {
sawBinaryOnly = true
}
Comments:
for len(line) > 0 {
if inSlashStar {
if i := bytes.Index(line, starSlash); i >= 0 {
inSlashStar = false
line = bytes.TrimSpace(line[i+len(starSlash):])
continue Comments
}
continue Lines
}
if bytes.HasPrefix(line, slashSlash) {
continue Lines
}
if bytes.HasPrefix(line, slashStar) {
inSlashStar = true
line = bytes.TrimSpace(line[len(slashStar):])
continue Comments
}
// Found non-comment text.
break Lines
}
}
return content[:end], goBuild, sawBinaryOnly, nil
}
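// Illustrative sketch (not from the original sources) of parseFileHeader: the
// leading run of comments and blank lines is returned as trimmed, and the
// //go:build line, if any, is returned separately.
func exampleParseFileHeader() {
	src := []byte("// Package doc.\n\n//go:build linux\n\npackage p\n")
	trimmed, goBuild, sawBinaryOnly, err := parseFileHeader(src)
	// trimmed ends just after the final blank line before "package p",
	// goBuild == []byte("//go:build linux"), sawBinaryOnly == false, err == nil
	_, _, _, _ = trimmed, goBuild, sawBinaryOnly, err
}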
// saveCgo saves the information from the #cgo lines in the import "C" comment.
// These lines set CFLAGS, CPPFLAGS, CXXFLAGS and LDFLAGS and pkg-config directives
// that affect the way cgo's C code is built.
func (ctxt *Context) saveCgo(filename string, di *Package, cg *ast.CommentGroup) error {
text := cg.Text()
for _, line := range strings.Split(text, "\n") {
orig := line
// Line is
// #cgo [GOOS/GOARCH...] LDFLAGS: stuff
//
line = strings.TrimSpace(line)
if len(line) < 5 || line[:4] != "#cgo" || (line[4] != ' ' && line[4] != '\t') {
continue
}
// #cgo (nocallback|noescape) <function name>
if fields := strings.Fields(line); len(fields) == 3 && (fields[1] == "nocallback" || fields[1] == "noescape") {
continue
}
// Split at colon.
line, argstr, ok := strings.Cut(strings.TrimSpace(line[4:]), ":")
if !ok {
return fmt.Errorf("%s: invalid #cgo line: %s", filename, orig)
}
// Parse GOOS/GOARCH stuff.
f := strings.Fields(line)
if len(f) < 1 {
return fmt.Errorf("%s: invalid #cgo line: %s", filename, orig)
}
cond, verb := f[:len(f)-1], f[len(f)-1]
if len(cond) > 0 {
ok := false
for _, c := range cond {
if ctxt.matchAuto(c, nil) {
ok = true
break
}
}
if !ok {
continue
}
}
args, err := splitQuoted(argstr)
if err != nil {
return fmt.Errorf("%s: invalid #cgo line: %s", filename, orig)
}
for i, arg := range args {
if arg, ok = expandSrcDir(arg, di.Dir); !ok {
return fmt.Errorf("%s: malformed #cgo argument: %s", filename, arg)
}
args[i] = arg
}
switch verb {
case "CFLAGS", "CPPFLAGS", "CXXFLAGS", "FFLAGS", "LDFLAGS":
// Change relative paths to absolute.
ctxt.makePathsAbsolute(args, di.Dir)
}
switch verb {
case "CFLAGS":
di.CgoCFLAGS = append(di.CgoCFLAGS, args...)
case "CPPFLAGS":
di.CgoCPPFLAGS = append(di.CgoCPPFLAGS, args...)
case "CXXFLAGS":
di.CgoCXXFLAGS = append(di.CgoCXXFLAGS, args...)
case "FFLAGS":
di.CgoFFLAGS = append(di.CgoFFLAGS, args...)
case "LDFLAGS":
di.CgoLDFLAGS = append(di.CgoLDFLAGS, args...)
case "pkg-config":
di.CgoPkgConfig = append(di.CgoPkgConfig, args...)
default:
return fmt.Errorf("%s: invalid #cgo verb: %s", filename, orig)
}
}
return nil
}
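// Illustrative sketch (not from the original sources) of saveCgo handling a
// single #cgo line. The package directory and comment text are assumptions;
// ${SRCDIR} is expanded and relative -I/-L paths are made absolute before the
// flags are recorded on the Package.
func exampleSaveCgo() {
	ctxt := Default
	di := &Package{Dir: "/home/user/pkg"} // hypothetical package directory
	cg := &ast.CommentGroup{List: []*ast.Comment{
		{Text: "// #cgo LDFLAGS: -L${SRCDIR}/lib -lfoo"},
	}}
	err := ctxt.saveCgo("x.go", di, cg)
	// di.CgoLDFLAGS == []string{"-L/home/user/pkg/lib", "-lfoo"}, err == nil
	_ = err
}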
// expandSrcDir expands any occurrence of ${SRCDIR}, making sure
// the result is safe for the shell.
func expandSrcDir(str string, srcdir string) (string, bool) {
// "\" delimited paths cause safeCgoName to fail
// so convert native paths with a different delimiter
// to "/" before starting (eg: on windows).
srcdir = filepath.ToSlash(srcdir)
chunks := strings.Split(str, "${SRCDIR}")
if len(chunks) < 2 {
return str, safeCgoName(str)
}
ok := true
for _, chunk := range chunks {
ok = ok && (chunk == "" || safeCgoName(chunk))
}
ok = ok && (srcdir == "" || safeCgoName(srcdir))
res := strings.Join(chunks, srcdir)
return res, ok && res != ""
}
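// Illustrative sketch (not from the original sources) of expandSrcDir. The
// source directory below is an assumption; the second result reports whether
// the expanded argument is safe to pass to the build tools.
func exampleExpandSrcDir() {
	arg, ok := expandSrcDir("-I${SRCDIR}/include", "/home/user/pkg")
	// arg == "-I/home/user/pkg/include", ok == true
	_, _ = arg, ok
}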
// makePathsAbsolute looks for compiler options that take paths and
// makes them absolute. We do this because through the 1.8 release we
// ran the compiler in the package directory, so any relative -I or -L
// options would be relative to that directory. In 1.9 we changed to
// running the compiler in the build directory, to get consistent
// build results (issue #19964). To keep builds working, we change any
// relative -I or -L options to be absolute.
//
// Using filepath.IsAbs and filepath.Join here means the results will be
// different on different systems, but that's OK: -I and -L options are
// inherently system-dependent.
func (ctxt *Context) makePathsAbsolute(args []string, srcDir string) {
nextPath := false
for i, arg := range args {
if nextPath {
if !filepath.IsAbs(arg) {
args[i] = filepath.Join(srcDir, arg)
}
nextPath = false
} else if strings.HasPrefix(arg, "-I") || strings.HasPrefix(arg, "-L") {
if len(arg) == 2 {
nextPath = true
} else {
if !filepath.IsAbs(arg[2:]) {
args[i] = arg[:2] + filepath.Join(srcDir, arg[2:])
}
}
}
}
}
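// Illustrative sketch (not from the original sources) of makePathsAbsolute on a
// Unix-like system: both the separate-argument form (-I dir) and the joined
// form (-Ldir) are rewritten relative to the assumed source directory.
func exampleMakePathsAbsolute() {
	args := []string{"-I", "include", "-Llib"}
	Default.makePathsAbsolute(args, "/home/user/pkg")
	// args == []string{"-I", "/home/user/pkg/include", "-L/home/user/pkg/lib"}
	_ = args
}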
// NOTE: $ is not safe for the shell, but it is allowed here because of linker options like -Wl,$ORIGIN.
// We never pass these arguments to a shell (just to programs we construct argv for), so this should be okay.
// See golang.org/issue/6038.
// The @ is for OS X. See golang.org/issue/13720.
// The % is for Jenkins. See golang.org/issue/16959.
// The ! is because module paths may use it. See golang.org/issue/26716.
// The ~ and ^ are for sr.ht. See golang.org/issue/32260.
const safeString = "+-.,/0123456789=ABCDEFGHIJKLMNOPQRSTUVWXYZ_abcdefghijklmnopqrstuvwxyz:$@%! ~^"
func safeCgoName(s string) bool {
if s == "" {
return false
}
for i := 0; i < len(s); i++ {
if c := s[i]; c < utf8.RuneSelf && strings.IndexByte(safeString, c) < 0 {
return false
}
}
return true
}
// splitQuoted splits the string s around each instance of one or more consecutive
// white space characters while taking into account quotes and escaping, and
// returns an array of substrings of s or an empty list if s contains only white space.
// Single quotes and double quotes are recognized to prevent splitting within the
// quoted region, and are removed from the resulting substrings. If a quote in s
// isn't closed err will be set and r will have the unclosed argument as the
// last element. The backslash is used for escaping.
//
// For example, the following string:
//
// a b:"c d" 'e''f' "g\""
//
// Would be parsed as:
//
// []string{"a", "b:c d", "ef", `g"`}
func splitQuoted(s string) (r []string, err error) {
var args []string
arg := make([]rune, len(s))
escaped := false
quoted := false
quote := '\x00'
i := 0
for _, rune := range s {
switch {
case escaped:
escaped = false
case rune == '\\':
escaped = true
continue
case quote != '\x00':
if rune == quote {
quote = '\x00'
continue
}
case rune == '"' || rune == '\'':
quoted = true
quote = rune
continue
case unicode.IsSpace(rune):
if quoted || i > 0 {
quoted = false
args = append(args, string(arg[:i]))
i = 0
}
continue
}
arg[i] = rune
i++
}
if quoted || i > 0 {
args = append(args, string(arg[:i]))
}
if quote != 0 {
err = errors.New("unclosed quote")
} else if escaped {
err = errors.New("unfinished escaping")
}
return args, err
}
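// Illustrative sketch (not from the original sources) of splitQuoted on the
// string from the doc comment above: quotes group words, escaped quotes are
// preserved, and the quotes themselves are stripped from the results.
func exampleSplitQuoted() {
	args, err := splitQuoted(`a b:"c d" 'e''f' "g\""`)
	// args == []string{"a", "b:c d", "ef", `g"`}, err == nil
	_, _ = args, err
}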
// matchAuto interprets text as either a +build or //go:build expression (whichever works),
// reporting whether the expression matches the build context.
//
// matchAuto is only used for testing of tag evaluation
// and in #cgo lines, which accept either syntax.
func (ctxt *Context) matchAuto(text string, allTags map[string]bool) bool {
if strings.ContainsAny(text, "&|()") {
text = "//go:build " + text
} else {
text = "// +build " + text
}
x, err := constraint.Parse(text)
if err != nil {
return false
}
return ctxt.eval(x, allTags)
}
func (ctxt *Context) eval(x constraint.Expr, allTags map[string]bool) bool {
return x.Eval(func(tag string) bool { return ctxt.matchTag(tag, allTags) })
}
// matchTag reports whether the name is one of:
//
// cgo (if cgo is enabled)
// $GOOS
// $GOARCH
// ctxt.Compiler
// linux (if GOOS = android)
// solaris (if GOOS = illumos)
// darwin (if GOOS = ios)
// unix (if this is a Unix GOOS)
// boringcrypto (if GOEXPERIMENT=boringcrypto is enabled)
// tag (if tag is listed in ctxt.BuildTags, ctxt.ToolTags, or ctxt.ReleaseTags)
//
// It records all consulted tags in allTags.
func (ctxt *Context) matchTag(name string, allTags map[string]bool) bool {
if allTags != nil {
allTags[name] = true
}
// special tags
if ctxt.CgoEnabled && name == "cgo" {
return true
}
if name == ctxt.GOOS || name == ctxt.GOARCH || name == ctxt.Compiler {
return true
}
if ctxt.GOOS == "android" && name == "linux" {
return true
}
if ctxt.GOOS == "illumos" && name == "solaris" {
return true
}
if ctxt.GOOS == "ios" && name == "darwin" {
return true
}
if name == "unix" && syslist.UnixOS[ctxt.GOOS] {
return true
}
if name == "boringcrypto" {
name = "goexperiment.boringcrypto" // boringcrypto is an old name for goexperiment.boringcrypto
}
// other tags
return slices.Contains(ctxt.BuildTags, name) || slices.Contains(ctxt.ToolTags, name) ||
slices.Contains(ctxt.ReleaseTags, name)
}
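// Illustrative sketch (not from the original sources) of matchTag and its
// allTags side effect. The Context values are assumptions: with GOOS=android
// the "linux" tag also matches, and every consulted tag is recorded in allTags
// whether or not it matched.
func exampleMatchTag() {
	ctxt := Context{GOOS: "android", GOARCH: "arm64"}
	allTags := map[string]bool{}
	linux := ctxt.matchTag("linux", allTags) // true: android implies linux
	cgo := ctxt.matchTag("cgo", allTags)     // false: CgoEnabled is unset
	// allTags == map[string]bool{"linux": true, "cgo": true}
	_, _ = linux, cgo
}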
// goodOSArchFile returns false if the name contains a $GOOS or $GOARCH
// suffix which does not match the current system.
// The recognized name formats are:
//
// name_$(GOOS).*
// name_$(GOARCH).*
// name_$(GOOS)_$(GOARCH).*
// name_$(GOOS)_test.*
// name_$(GOARCH)_test.*
// name_$(GOOS)_$(GOARCH)_test.*
//
// Exceptions:
// if GOOS=android, then files with GOOS=linux are also matched.
// if GOOS=illumos, then files with GOOS=solaris are also matched.
// if GOOS=ios, then files with GOOS=darwin are also matched.
func (ctxt *Context) goodOSArchFile(name string, allTags map[string]bool) bool {
name, _, _ = strings.Cut(name, ".")
// Before Go 1.4, a file called "linux.go" would be equivalent to having a
// build tag "linux" in that file. For Go 1.4 and beyond, we require this
// auto-tagging to apply only to files with a non-empty prefix, so
// "foo_linux.go" is tagged but "linux.go" is not. This allows new operating
// systems, such as android, to arrive without breaking existing code with
// innocuous source code in "android.go". The easiest fix: cut everything
// in the name before the initial _.
i := strings.Index(name, "_")
if i < 0 {
return true
}
name = name[i:] // ignore everything before first _
l := strings.Split(name, "_")
if n := len(l); n > 0 && l[n-1] == "test" {
l = l[:n-1]
}
n := len(l)
if n >= 2 && syslist.KnownOS[l[n-2]] && syslist.KnownArch[l[n-1]] {
if allTags != nil {
// In case we short-circuit on l[n-1].
allTags[l[n-2]] = true
}
return ctxt.matchTag(l[n-1], allTags) && ctxt.matchTag(l[n-2], allTags)
}
if n >= 1 && (syslist.KnownOS[l[n-1]] || syslist.KnownArch[l[n-1]]) {
return ctxt.matchTag(l[n-1], allTags)
}
return true
}
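// Illustrative sketch (not from the original sources) of goodOSArchFile for an
// assumed linux/amd64 Context: only the suffix after the first underscore is
// treated as implicit GOOS/GOARCH tags, so "linux.go" (no prefix) is not filtered.
func exampleGoodOSArchFile() {
	ctxt := Context{GOOS: "linux", GOARCH: "amd64"}
	a := ctxt.goodOSArchFile("foo_windows_amd64.go", nil) // false: GOOS does not match
	b := ctxt.goodOSArchFile("foo_linux.go", nil)         // true
	c := ctxt.goodOSArchFile("linux.go", nil)             // true: no _GOOS suffix, so no filtering
	_, _, _ = a, b, c
}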
// ToolDir is the directory containing build tools.
var ToolDir = getToolDir()
// IsLocalImport reports whether the import path is
// a local import path, like ".", "..", "./foo", or "../foo".
func IsLocalImport(path string) bool {
return path == "." || path == ".." ||
strings.HasPrefix(path, "./") || strings.HasPrefix(path, "../")
}
// ArchChar returns "?" and an error.
// In earlier versions of Go, the returned string was used to derive
// the compiler and linker tool names, the default object file suffix,
// and the default linker output name. As of Go 1.5, those strings
// no longer vary by architecture; they are compile, link, .o, and a.out, respectively.
func ArchChar(goarch string) (string, error) {
return "?", errors.New("architecture letter no longer used")
}
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package constraint implements parsing and evaluation of build constraint lines.
// See https://golang.org/cmd/go/#hdr-Build_constraints for documentation about build constraints themselves.
//
// This package parses both the original “// +build” syntax and the “//go:build” syntax that was added in Go 1.17.
// See https://golang.org/design/draft-gobuild for details about the “//go:build” syntax.
package constraint
import (
"errors"
"strings"
"unicode"
"unicode/utf8"
)
// maxSize is a limit used to control the complexity of expressions, in order
// to prevent stack exhaustion issues due to recursion.
const maxSize = 1000
// An Expr is a build tag constraint expression.
// The underlying concrete type is *[AndExpr], *[OrExpr], *[NotExpr], or *[TagExpr].
type Expr interface {
// String returns the string form of the expression,
// using the boolean syntax used in //go:build lines.
String() string
// Eval reports whether the expression evaluates to true.
// It calls ok(tag) as needed to find out whether a given build tag
// is satisfied by the current build configuration.
Eval(ok func(tag string) bool) bool
// The presence of an isExpr method explicitly marks the type as an Expr.
// Only implementations in this package should be used as Exprs.
isExpr()
}
// A TagExpr is an [Expr] for the single tag Tag.
type TagExpr struct {
Tag string // for example, “linux” or “cgo”
}
func (x *TagExpr) isExpr() {}
func (x *TagExpr) Eval(ok func(tag string) bool) bool {
return ok(x.Tag)
}
func (x *TagExpr) String() string {
return x.Tag
}
func tag(tag string) Expr { return &TagExpr{tag} }
// A NotExpr represents the expression !X (the negation of X).
type NotExpr struct {
X Expr
}
func (x *NotExpr) isExpr() {}
func (x *NotExpr) Eval(ok func(tag string) bool) bool {
return !x.X.Eval(ok)
}
func (x *NotExpr) String() string {
s := x.X.String()
switch x.X.(type) {
case *AndExpr, *OrExpr:
s = "(" + s + ")"
}
return "!" + s
}
func not(x Expr) Expr { return &NotExpr{x} }
// An AndExpr represents the expression X && Y.
type AndExpr struct {
X, Y Expr
}
func (x *AndExpr) isExpr() {}
func (x *AndExpr) Eval(ok func(tag string) bool) bool {
// Note: Eval both, to make sure ok func observes all tags.
xok := x.X.Eval(ok)
yok := x.Y.Eval(ok)
return xok && yok
}
func (x *AndExpr) String() string {
return andArg(x.X) + " && " + andArg(x.Y)
}
func andArg(x Expr) string {
s := x.String()
if _, ok := x.(*OrExpr); ok {
s = "(" + s + ")"
}
return s
}
func and(x, y Expr) Expr {
return &AndExpr{x, y}
}
// An OrExpr represents the expression X || Y.
type OrExpr struct {
X, Y Expr
}
func (x *OrExpr) isExpr() {}
func (x *OrExpr) Eval(ok func(tag string) bool) bool {
// Note: Eval both, to make sure ok func observes all tags.
xok := x.X.Eval(ok)
yok := x.Y.Eval(ok)
return xok || yok
}
func (x *OrExpr) String() string {
return orArg(x.X) + " || " + orArg(x.Y)
}
func orArg(x Expr) string {
s := x.String()
if _, ok := x.(*AndExpr); ok {
s = "(" + s + ")"
}
return s
}
func or(x, y Expr) Expr {
return &OrExpr{x, y}
}
// A SyntaxError reports a syntax error in a parsed build expression.
type SyntaxError struct {
Offset int // byte offset in input where error was detected
Err string // description of error
}
func (e *SyntaxError) Error() string {
return e.Err
}
var errNotConstraint = errors.New("not a build constraint")
// Parse parses a single build constraint line of the form “//go:build ...” or “// +build ...”
// and returns the corresponding boolean expression.
func Parse(line string) (Expr, error) {
if text, ok := splitGoBuild(line); ok {
return parseExpr(text)
}
if text, ok := splitPlusBuild(line); ok {
return parsePlusBuildExpr(text)
}
return nil, errNotConstraint
}
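// Illustrative sketch (not from the original sources) of Parse accepting both
// syntaxes. The two lines below describe the same constraint; Eval reports
// whether it is satisfied by a given set of enabled tags (here a hypothetical map).
func exampleParse() {
	x, _ := Parse("//go:build linux && (amd64 || arm64)")
	y, _ := Parse("// +build linux,amd64 linux,arm64")
	tags := map[string]bool{"linux": true, "arm64": true}
	ok := func(tag string) bool { return tags[tag] }
	// x.Eval(ok) == true, y.Eval(ok) == true
	_, _ = x.Eval(ok), y.Eval(ok)
}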
// IsGoBuild reports whether the line of text is a “//go:build” constraint.
// It only checks the prefix of the text, not that the expression itself parses.
func IsGoBuild(line string) bool {
_, ok := splitGoBuild(line)
return ok
}
// splitGoBuild splits apart the leading //go:build prefix in line from the build expression itself.
// It returns "", false if the input is not a //go:build line or if the input contains multiple lines.
func splitGoBuild(line string) (expr string, ok bool) {
// A single trailing newline is OK; otherwise multiple lines are not.
if len(line) > 0 && line[len(line)-1] == '\n' {
line = line[:len(line)-1]
}
if strings.Contains(line, "\n") {
return "", false
}
if !strings.HasPrefix(line, "//go:build") {
return "", false
}
line = strings.TrimSpace(line)
line = line[len("//go:build"):]
// If strings.TrimSpace finds more to trim after removing the //go:build prefix,
// it means that the prefix was followed by a space, making this a //go:build line
// (as opposed to a //go:buildsomethingelse line).
// If line is empty, we had "//go:build" by itself, which also counts.
trim := strings.TrimSpace(line)
if len(line) == len(trim) && line != "" {
return "", false
}
return trim, true
}
// An exprParser holds state for parsing a build expression.
type exprParser struct {
s string // input string
i int // next read location in s
tok string // last token read
isTag bool
pos int // position (start) of last token
size int
}
// parseExpr parses a boolean build tag expression.
func parseExpr(text string) (x Expr, err error) {
defer func() {
if e := recover(); e != nil {
if e, ok := e.(*SyntaxError); ok {
err = e
return
}
panic(e) // unreachable unless parser has a bug
}
}()
p := &exprParser{s: text}
x = p.or()
if p.tok != "" {
panic(&SyntaxError{Offset: p.pos, Err: "unexpected token " + p.tok})
}
return x, nil
}
// or parses a sequence of || expressions.
// On entry, the next input token has not yet been lexed.
// On exit, the next input token has been lexed and is in p.tok.
func (p *exprParser) or() Expr {
x := p.and()
for p.tok == "||" {
x = or(x, p.and())
}
return x
}
// and parses a sequence of && expressions.
// On entry, the next input token has not yet been lexed.
// On exit, the next input token has been lexed and is in p.tok.
func (p *exprParser) and() Expr {
x := p.not()
for p.tok == "&&" {
x = and(x, p.not())
}
return x
}
// not parses a ! expression.
// On entry, the next input token has not yet been lexed.
// On exit, the next input token has been lexed and is in p.tok.
func (p *exprParser) not() Expr {
p.size++
if p.size > maxSize {
panic(&SyntaxError{Offset: p.pos, Err: "build expression too large"})
}
p.lex()
if p.tok == "!" {
p.lex()
if p.tok == "!" {
panic(&SyntaxError{Offset: p.pos, Err: "double negation not allowed"})
}
return not(p.atom())
}
return p.atom()
}
// atom parses a tag or a parenthesized expression.
// On entry, the next input token HAS been lexed.
// On exit, the next input token has been lexed and is in p.tok.
func (p *exprParser) atom() Expr {
// first token already in p.tok
if p.tok == "(" {
pos := p.pos
defer func() {
if e := recover(); e != nil {
if e, ok := e.(*SyntaxError); ok && e.Err == "unexpected end of expression" {
e.Err = "missing close paren"
}
panic(e)
}
}()
x := p.or()
if p.tok != ")" {
panic(&SyntaxError{Offset: pos, Err: "missing close paren"})
}
p.lex()
return x
}
if !p.isTag {
if p.tok == "" {
panic(&SyntaxError{Offset: p.pos, Err: "unexpected end of expression"})
}
panic(&SyntaxError{Offset: p.pos, Err: "unexpected token " + p.tok})
}
tok := p.tok
p.lex()
return tag(tok)
}
// lex finds and consumes the next token in the input stream.
// On return, p.tok is set to the token text,
// p.isTag reports whether the token was a tag,
// and p.pos records the byte offset of the start of the token in the input stream.
// If lex reaches the end of the input, p.tok is set to the empty string.
// For any other syntax error, lex panics with a SyntaxError.
func (p *exprParser) lex() {
p.isTag = false
for p.i < len(p.s) && (p.s[p.i] == ' ' || p.s[p.i] == '\t') {
p.i++
}
if p.i >= len(p.s) {
p.tok = ""
p.pos = p.i
return
}
switch p.s[p.i] {
case '(', ')', '!':
p.pos = p.i
p.i++
p.tok = p.s[p.pos:p.i]
return
case '&', '|':
if p.i+1 >= len(p.s) || p.s[p.i+1] != p.s[p.i] {
panic(&SyntaxError{Offset: p.i, Err: "invalid syntax at " + string(rune(p.s[p.i]))})
}
p.pos = p.i
p.i += 2
p.tok = p.s[p.pos:p.i]
return
}
tag := p.s[p.i:]
for i, c := range tag {
if !unicode.IsLetter(c) && !unicode.IsDigit(c) && c != '_' && c != '.' {
tag = tag[:i]
break
}
}
if tag == "" {
c, _ := utf8.DecodeRuneInString(p.s[p.i:])
panic(&SyntaxError{Offset: p.i, Err: "invalid syntax at " + string(c)})
}
p.pos = p.i
p.i += len(tag)
p.tok = p.s[p.pos:p.i]
p.isTag = true
}
// IsPlusBuild reports whether the line of text is a “// +build” constraint.
// It only checks the prefix of the text, not that the expression itself parses.
func IsPlusBuild(line string) bool {
_, ok := splitPlusBuild(line)
return ok
}
// splitPlusBuild splits apart the leading // +build prefix in line from the build expression itself.
// It returns "", false if the input is not a // +build line or if the input contains multiple lines.
func splitPlusBuild(line string) (expr string, ok bool) {
// A single trailing newline is OK; otherwise multiple lines are not.
if len(line) > 0 && line[len(line)-1] == '\n' {
line = line[:len(line)-1]
}
if strings.Contains(line, "\n") {
return "", false
}
if !strings.HasPrefix(line, "//") {
return "", false
}
line = line[len("//"):]
// Note the space is optional; "//+build" is recognized too.
line = strings.TrimSpace(line)
if !strings.HasPrefix(line, "+build") {
return "", false
}
line = line[len("+build"):]
// If strings.TrimSpace finds more to trim after removing the +build prefix,
// it means that the prefix was followed by a space, making this a +build line
// (as opposed to a +buildsomethingelse line).
// If line is empty, we had "// +build" by itself, which also counts.
trim := strings.TrimSpace(line)
if len(line) == len(trim) && line != "" {
return "", false
}
return trim, true
}
// parsePlusBuildExpr parses a legacy build tag expression (as used with “// +build”).
func parsePlusBuildExpr(text string) (Expr, error) {
// Only allow up to 100 AND/OR operators for "old" syntax.
// This is much less than the limit for "new" syntax,
// but uses of old syntax were always very simple.
const maxOldSize = 100
size := 0
var x Expr
for _, clause := range strings.Fields(text) {
var y Expr
for _, lit := range strings.Split(clause, ",") {
var z Expr
var neg bool
if strings.HasPrefix(lit, "!!") || lit == "!" {
z = tag("ignore")
} else {
if strings.HasPrefix(lit, "!") {
neg = true
lit = lit[len("!"):]
}
if isValidTag(lit) {
z = tag(lit)
} else {
z = tag("ignore")
}
if neg {
z = not(z)
}
}
if y == nil {
y = z
} else {
if size++; size > maxOldSize {
return nil, errComplex
}
y = and(y, z)
}
}
if x == nil {
x = y
} else {
if size++; size > maxOldSize {
return nil, errComplex
}
x = or(x, y)
}
}
if x == nil {
x = tag("ignore")
}
return x, nil
}
// isValidTag reports whether the word is a valid build tag.
// Tags must be letters, digits, underscores or dots.
// Unlike in Go identifiers, all digits are fine (e.g., "386").
func isValidTag(word string) bool {
if word == "" {
return false
}
for _, c := range word {
if !unicode.IsLetter(c) && !unicode.IsDigit(c) && c != '_' && c != '.' {
return false
}
}
return true
}
var errComplex = errors.New("expression too complex for // +build lines")
// PlusBuildLines returns a sequence of “// +build” lines that evaluate to the build expression x.
// If the expression is too complex to convert directly to “// +build” lines, PlusBuildLines returns an error.
func PlusBuildLines(x Expr) ([]string, error) {
// Push all NOTs to the expression leaves, so that //go:build !(x && y) can be treated as !x || !y.
// This rewrite is both efficient and commonly needed, so it's worth doing.
// Essentially all other possible rewrites are too expensive and too rarely needed.
x = pushNot(x, false)
// Split into AND of ORs of ANDs of literals (tag or NOT tag).
var split [][][]Expr
for _, or := range appendSplitAnd(nil, x) {
var ands [][]Expr
for _, and := range appendSplitOr(nil, or) {
var lits []Expr
for _, lit := range appendSplitAnd(nil, and) {
switch lit.(type) {
case *TagExpr, *NotExpr:
lits = append(lits, lit)
default:
return nil, errComplex
}
}
ands = append(ands, lits)
}
split = append(split, ands)
}
// If all the ORs have length 1 (no actual OR'ing going on),
// push the top-level ANDs to the bottom level, so that we get
// one // +build line instead of many.
maxOr := 0
for _, or := range split {
if maxOr < len(or) {
maxOr = len(or)
}
}
if maxOr == 1 {
var lits []Expr
for _, or := range split {
lits = append(lits, or[0]...)
}
split = [][][]Expr{{lits}}
}
// Prepare the +build lines.
var lines []string
for _, or := range split {
line := "// +build"
for _, and := range or {
clause := ""
for i, lit := range and {
if i > 0 {
clause += ","
}
clause += lit.String()
}
line += " " + clause
}
lines = append(lines, line)
}
return lines, nil
}
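// Illustrative sketch (not from the original sources) of PlusBuildLines: a
// //go:build expression in disjunctive form converts to a single // +build
// line, with commas for AND and spaces for OR.
func examplePlusBuildLines() {
	x, _ := Parse("//go:build (linux && 386) || (darwin && !cgo)")
	lines, err := PlusBuildLines(x)
	// lines == []string{"// +build linux,386 darwin,!cgo"}, err == nil
	_, _ = lines, err
}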
// pushNot applies DeMorgan's law to push negations down the expression,
// so that only tags are negated in the result.
// (It applies the rewrites !(X && Y) => (!X || !Y) and !(X || Y) => (!X && !Y).)
func pushNot(x Expr, not bool) Expr {
switch x := x.(type) {
default:
// unreachable
return x
case *NotExpr:
if _, ok := x.X.(*TagExpr); ok && !not {
return x
}
return pushNot(x.X, !not)
case *TagExpr:
if not {
return &NotExpr{X: x}
}
return x
case *AndExpr:
x1 := pushNot(x.X, not)
y1 := pushNot(x.Y, not)
if not {
return or(x1, y1)
}
if x1 == x.X && y1 == x.Y {
return x
}
return and(x1, y1)
case *OrExpr:
x1 := pushNot(x.X, not)
y1 := pushNot(x.Y, not)
if not {
return and(x1, y1)
}
if x1 == x.X && y1 == x.Y {
return x
}
return or(x1, y1)
}
}
// appendSplitAnd appends x to list while splitting apart any top-level && expressions.
// For example, appendSplitAnd({W}, X && Y && Z) = {W, X, Y, Z}.
func appendSplitAnd(list []Expr, x Expr) []Expr {
if x, ok := x.(*AndExpr); ok {
list = appendSplitAnd(list, x.X)
list = appendSplitAnd(list, x.Y)
return list
}
return append(list, x)
}
// appendSplitOr appends x to list while splitting apart any top-level || expressions.
// For example, appendSplitOr({W}, X || Y || Z) = {W, X, Y, Z}.
func appendSplitOr(list []Expr, x Expr) []Expr {
if x, ok := x.(*OrExpr); ok {
list = appendSplitOr(list, x.X)
list = appendSplitOr(list, x.Y)
return list
}
return append(list, x)
}
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package constraint
import (
"strconv"
"strings"
)
// GoVersion returns the minimum Go version implied by a given build expression.
// If the expression can be satisfied without any Go version tags, GoVersion returns an empty string.
//
// For example:
//
// GoVersion(linux && go1.22) = "go1.22"
GoVersion((linux && go1.22) || (windows && go1.20)) = "go1.20"
// GoVersion(linux) = ""
// GoVersion(linux || (windows && go1.22)) = ""
// GoVersion(!go1.22) = ""
//
// GoVersion assumes that any tag or negated tag may independently be true,
// so that its analysis can be purely structural, without SAT solving.
// “Impossible” subexpressions may therefore affect the result.
//
// For example:
//
// GoVersion((linux && !linux && go1.20) || go1.21) = "go1.20"
func GoVersion(x Expr) string {
v := minVersion(x, +1)
if v < 0 {
return ""
}
if v == 0 {
return "go1"
}
return "go1." + strconv.Itoa(v)
}
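// Illustrative sketch (not from the original sources) of GoVersion: a go1.N tag
// required on every branch sets the minimum version, while a constraint that is
// satisfiable without any go1.N tag yields the empty string.
func exampleGoVersion() {
	x, _ := Parse("//go:build linux && go1.22")
	y, _ := Parse("//go:build linux || (windows && go1.22)")
	// GoVersion(x) == "go1.22", GoVersion(y) == ""
	_, _ = GoVersion(x), GoVersion(y)
}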
// minVersion returns the minimum Go major version (9 for go1.9)
// implied by expression z, or if sign < 0, by expression !z.
func minVersion(z Expr, sign int) int {
switch z := z.(type) {
default:
return -1
case *AndExpr:
op := andVersion
if sign < 0 {
op = orVersion
}
return op(minVersion(z.X, sign), minVersion(z.Y, sign))
case *OrExpr:
op := orVersion
if sign < 0 {
op = andVersion
}
return op(minVersion(z.X, sign), minVersion(z.Y, sign))
case *NotExpr:
return minVersion(z.X, -sign)
case *TagExpr:
if sign < 0 {
// !foo implies nothing
return -1
}
if z.Tag == "go1" {
return 0
}
_, v, _ := strings.Cut(z.Tag, "go1.")
n, err := strconv.Atoi(v)
if err != nil {
// not a go1.N tag
return -1
}
return n
}
}
// andVersion returns the minimum Go version
// implied by the AND of two minimum Go versions,
// which is the max of the versions.
func andVersion(x, y int) int {
if x > y {
return x
}
return y
}
// orVersion returns the minimum Go version
// implied by the OR of two minimum Go versions,
// which is the min of the versions.
func orVersion(x, y int) int {
if x < y {
return x
}
return y
}
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build gc
package build
import (
"path/filepath"
"runtime"
)
// getToolDir returns the default value of ToolDir.
func getToolDir() string {
return filepath.Join(runtime.GOROOT(), "pkg/tool/"+runtime.GOOS+"_"+runtime.GOARCH)
}
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package build
import (
"bufio"
"bytes"
"errors"
"fmt"
"go/ast"
"go/parser"
"go/scanner"
"go/token"
"io"
"strconv"
"strings"
"unicode"
"unicode/utf8"
_ "unsafe" // for linkname
)
type importReader struct {
b *bufio.Reader
buf []byte
peek byte
err error
eof bool
nerr int
pos token.Position
}
var bom = []byte{0xef, 0xbb, 0xbf}
func newImportReader(name string, r io.Reader) *importReader {
b := bufio.NewReader(r)
// Remove leading UTF-8 BOM.
// Per https://golang.org/ref/spec#Source_code_representation:
// a compiler may ignore a UTF-8-encoded byte order mark (U+FEFF)
// if it is the first Unicode code point in the source text.
if leadingBytes, err := b.Peek(3); err == nil && bytes.Equal(leadingBytes, bom) {
b.Discard(3)
}
return &importReader{
b: b,
pos: token.Position{
Filename: name,
Line: 1,
Column: 1,
},
}
}
func isIdent(c byte) bool {
return 'A' <= c && c <= 'Z' || 'a' <= c && c <= 'z' || '0' <= c && c <= '9' || c == '_' || c >= utf8.RuneSelf
}
var (
errSyntax = errors.New("syntax error")
errNUL = errors.New("unexpected NUL in input")
)
// syntaxError records a syntax error, but only if an I/O error has not already been recorded.
func (r *importReader) syntaxError() {
if r.err == nil {
r.err = errSyntax
}
}
// readByte reads the next byte from the input, saves it in buf, and returns it.
// If an error occurs, readByte records the error in r.err and returns 0.
func (r *importReader) readByte() byte {
c, err := r.b.ReadByte()
if err == nil {
r.buf = append(r.buf, c)
if c == 0 {
err = errNUL
}
}
if err != nil {
if err == io.EOF {
r.eof = true
} else if r.err == nil {
r.err = err
}
c = 0
}
return c
}
// readByteNoBuf is like readByte but doesn't buffer the byte.
// It exhausts r.buf before reading from r.b.
func (r *importReader) readByteNoBuf() byte {
var c byte
var err error
if len(r.buf) > 0 {
c = r.buf[0]
r.buf = r.buf[1:]
} else {
c, err = r.b.ReadByte()
if err == nil && c == 0 {
err = errNUL
}
}
if err != nil {
if err == io.EOF {
r.eof = true
} else if r.err == nil {
r.err = err
}
return 0
}
r.pos.Offset++
if c == '\n' {
r.pos.Line++
r.pos.Column = 1
} else {
r.pos.Column++
}
return c
}
// peekByte returns the next byte from the input reader but does not advance beyond it.
// If skipSpace is set, peekByte skips leading spaces and comments.
func (r *importReader) peekByte(skipSpace bool) byte {
if r.err != nil {
if r.nerr++; r.nerr > 10000 {
panic("go/build: import reader looping")
}
return 0
}
// Use r.peek as first input byte.
// Don't just return r.peek here: it might have been left by peekByte(false)
// and this might be peekByte(true).
c := r.peek
if c == 0 {
c = r.readByte()
}
for r.err == nil && !r.eof {
if skipSpace {
// For the purposes of this reader, semicolons are never necessary to
// understand the input and are treated as spaces.
switch c {
case ' ', '\f', '\t', '\r', '\n', ';':
c = r.readByte()
continue
case '/':
c = r.readByte()
if c == '/' {
for c != '\n' && r.err == nil && !r.eof {
c = r.readByte()
}
} else if c == '*' {
var c1 byte
for (c != '*' || c1 != '/') && r.err == nil {
if r.eof {
r.syntaxError()
}
c, c1 = c1, r.readByte()
}
} else {
r.syntaxError()
}
c = r.readByte()
continue
}
}
break
}
r.peek = c
return r.peek
}
// nextByte is like peekByte but advances beyond the returned byte.
func (r *importReader) nextByte(skipSpace bool) byte {
c := r.peekByte(skipSpace)
r.peek = 0
return c
}
var goEmbed = []byte("go:embed")
// findEmbed advances the input reader to the next //go:embed comment.
// It reports whether it found a comment.
// (Otherwise it found an error or EOF.)
func (r *importReader) findEmbed(first bool) bool {
// The import block scan stopped after a non-space character,
// so the reader is not at the start of a line on the first call.
// After that, each //go:embed extraction leaves the reader
// at the end of a line.
startLine := !first
var c byte
for r.err == nil && !r.eof {
c = r.readByteNoBuf()
Reswitch:
switch c {
default:
startLine = false
case '\n':
startLine = true
case ' ', '\t':
// leave startLine alone
case '"':
startLine = false
for r.err == nil {
if r.eof {
r.syntaxError()
}
c = r.readByteNoBuf()
if c == '\\' {
r.readByteNoBuf()
if r.err != nil {
r.syntaxError()
return false
}
continue
}
if c == '"' {
c = r.readByteNoBuf()
goto Reswitch
}
}
goto Reswitch
case '`':
startLine = false
for r.err == nil {
if r.eof {
r.syntaxError()
}
c = r.readByteNoBuf()
if c == '`' {
c = r.readByteNoBuf()
goto Reswitch
}
}
case '\'':
startLine = false
for r.err == nil {
if r.eof {
r.syntaxError()
}
c = r.readByteNoBuf()
if c == '\\' {
r.readByteNoBuf()
if r.err != nil {
r.syntaxError()
return false
}
continue
}
if c == '\'' {
c = r.readByteNoBuf()
goto Reswitch
}
}
case '/':
c = r.readByteNoBuf()
switch c {
default:
startLine = false
goto Reswitch
case '*':
var c1 byte
for (c != '*' || c1 != '/') && r.err == nil {
if r.eof {
r.syntaxError()
}
c, c1 = c1, r.readByteNoBuf()
}
startLine = false
case '/':
if startLine {
// Try to read this as a //go:embed comment.
for i := range goEmbed {
c = r.readByteNoBuf()
if c != goEmbed[i] {
goto SkipSlashSlash
}
}
c = r.readByteNoBuf()
if c == ' ' || c == '\t' {
// Found one!
return true
}
}
SkipSlashSlash:
for c != '\n' && r.err == nil && !r.eof {
c = r.readByteNoBuf()
}
startLine = true
}
}
}
return false
}
// readKeyword reads the given keyword from the input.
// If the keyword is not present, readKeyword records a syntax error.
func (r *importReader) readKeyword(kw string) {
r.peekByte(true)
for i := 0; i < len(kw); i++ {
if r.nextByte(false) != kw[i] {
r.syntaxError()
return
}
}
if isIdent(r.peekByte(false)) {
r.syntaxError()
}
}
// readIdent reads an identifier from the input.
// If an identifier is not present, readIdent records a syntax error.
func (r *importReader) readIdent() {
c := r.peekByte(true)
if !isIdent(c) {
r.syntaxError()
return
}
for isIdent(r.peekByte(false)) {
r.peek = 0
}
}
// readString reads a quoted string literal from the input.
// If a string is not present, readString records a syntax error.
func (r *importReader) readString() {
switch r.nextByte(true) {
case '`':
for r.err == nil {
if r.nextByte(false) == '`' {
break
}
if r.eof {
r.syntaxError()
}
}
case '"':
for r.err == nil {
c := r.nextByte(false)
if c == '"' {
break
}
if r.eof || c == '\n' {
r.syntaxError()
}
if c == '\\' {
r.nextByte(false)
}
}
default:
r.syntaxError()
}
}
// readImport reads an import clause - optional identifier followed by quoted string -
// from the input.
func (r *importReader) readImport() {
c := r.peekByte(true)
if c == '.' {
r.peek = 0
} else if isIdent(c) {
r.readIdent()
}
r.readString()
}
// readComments is like io.ReadAll, except that it only reads the leading
// block of comments in the file.
//
// readComments should be an internal detail,
// but widely used packages access it using linkname.
// Notable members of the hall of shame include:
// - github.com/bazelbuild/bazel-gazelle
//
// Do not remove or change the type signature.
// See go.dev/issue/67401.
//
//go:linkname readComments
func readComments(f io.Reader) ([]byte, error) {
r := newImportReader("", f)
r.peekByte(true)
if r.err == nil && !r.eof {
// Didn't reach EOF, so must have found a non-space byte. Remove it.
r.buf = r.buf[:len(r.buf)-1]
}
return r.buf, r.err
}
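// Illustrative sketch (not from the original sources) of readComments: only the
// leading comment block is consumed, and the first non-space, non-comment byte
// is not included in the returned slice.
func exampleReadComments() {
	hdr, err := readComments(strings.NewReader("// Package doc.\n\npackage p\n"))
	// hdr == []byte("// Package doc.\n\n"), err == nil
	_, _ = hdr, err
}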
// readGoInfo expects a Go file as input and reads the file up to and including the import section.
// It records what it learned in *info.
// If info.fset is non-nil, readGoInfo parses the file and sets info.parsed, info.parseErr,
// info.imports and info.embeds.
//
// It only returns an error if there are problems reading the file,
// not for syntax errors in the file itself.
func readGoInfo(f io.Reader, info *fileInfo) error {
r := newImportReader(info.name, f)
r.readKeyword("package")
r.readIdent()
for r.peekByte(true) == 'i' {
r.readKeyword("import")
if r.peekByte(true) == '(' {
r.nextByte(false)
for r.peekByte(true) != ')' && r.err == nil {
r.readImport()
}
r.nextByte(false)
} else {
r.readImport()
}
}
info.header = r.buf
// If we stopped successfully before EOF, we read a byte that told us we were done.
// Return all but that last byte, which would cause a syntax error if we let it through.
if r.err == nil && !r.eof {
info.header = r.buf[:len(r.buf)-1]
}
// If we stopped for a syntax error, consume the whole file so that
// we are sure we don't change the errors that go/parser returns.
if r.err == errSyntax {
r.err = nil
for r.err == nil && !r.eof {
r.readByte()
}
info.header = r.buf
}
if r.err != nil {
return r.err
}
if info.fset == nil {
return nil
}
// Parse file header & record imports.
info.parsed, info.parseErr = parser.ParseFile(info.fset, info.name, info.header, parser.ImportsOnly|parser.ParseComments)
if info.parseErr != nil {
return nil
}
hasEmbed := false
for _, decl := range info.parsed.Decls {
d, ok := decl.(*ast.GenDecl)
if !ok {
continue
}
for _, dspec := range d.Specs {
spec, ok := dspec.(*ast.ImportSpec)
if !ok {
continue
}
quoted := spec.Path.Value
path, err := strconv.Unquote(quoted)
if err != nil {
return fmt.Errorf("parser returned invalid quoted string: <%s>", quoted)
}
if !isValidImport(path) {
// The parser used to return a parse error for invalid import paths, but
// no longer does, so check for and create the error here instead.
info.parseErr = scanner.Error{Pos: info.fset.Position(spec.Pos()), Msg: "invalid import path: " + path}
info.imports = nil
return nil
}
if path == "embed" {
hasEmbed = true
}
doc := spec.Doc
if doc == nil && len(d.Specs) == 1 {
doc = d.Doc
}
info.imports = append(info.imports, fileImport{path, spec.Pos(), doc})
}
}
// Extract directives.
for _, group := range info.parsed.Comments {
if group.Pos() >= info.parsed.Package {
break
}
for _, c := range group.List {
if strings.HasPrefix(c.Text, "//go:") {
info.directives = append(info.directives, Directive{c.Text, info.fset.Position(c.Slash)})
}
}
}
// If the file imports "embed",
// we have to look for //go:embed comments
// in the remainder of the file.
// The compiler will enforce the mapping of comments to
// declared variables. We just need to know the patterns.
// If there were //go:embed comments earlier in the file
// (near the package statement or imports), the compiler
// will reject them. They can be (and have already been) ignored.
if hasEmbed {
var line []byte
for first := true; r.findEmbed(first); first = false {
line = line[:0]
pos := r.pos
for {
c := r.readByteNoBuf()
if c == '\n' || r.err != nil || r.eof {
break
}
line = append(line, c)
}
// Add args if line is well-formed.
// Ignore badly-formed lines - the compiler will report them when it finds them,
// and we can pretend they are not there to help go list succeed with what it knows.
embs, err := parseGoEmbed(string(line), pos)
if err == nil {
info.embeds = append(info.embeds, embs...)
}
}
}
return nil
}
// isValidImport checks if the import is a valid import using the more strict
// checks allowed by the implementation restriction in https://go.dev/ref/spec#Import_declarations.
// It was ported from the function of the same name that was removed from the
// parser in CL 424855, when the parser stopped doing these checks.
func isValidImport(s string) bool {
const illegalChars = `!"#$%&'()*,:;<=>?[\]^{|}` + "`\uFFFD"
for _, r := range s {
if !unicode.IsGraphic(r) || unicode.IsSpace(r) || strings.ContainsRune(illegalChars, r) {
return false
}
}
return s != ""
}
// parseGoEmbed parses the text following "//go:embed" to extract the glob patterns.
// It accepts unquoted space-separated patterns as well as double-quoted and back-quoted Go strings.
// This is based on a similar function in cmd/compile/internal/gc/noder.go;
// this version calculates position information as well.
func parseGoEmbed(args string, pos token.Position) ([]fileEmbed, error) {
trimBytes := func(n int) {
pos.Offset += n
pos.Column += utf8.RuneCountInString(args[:n])
args = args[n:]
}
trimSpace := func() {
trim := strings.TrimLeftFunc(args, unicode.IsSpace)
trimBytes(len(args) - len(trim))
}
var list []fileEmbed
for trimSpace(); args != ""; trimSpace() {
var path string
pathPos := pos
Switch:
switch args[0] {
default:
i := len(args)
for j, c := range args {
if unicode.IsSpace(c) {
i = j
break
}
}
path = args[:i]
trimBytes(i)
case '`':
var ok bool
path, _, ok = strings.Cut(args[1:], "`")
if !ok {
return nil, fmt.Errorf("invalid quoted string in //go:embed: %s", args)
}
trimBytes(1 + len(path) + 1)
case '"':
i := 1
for ; i < len(args); i++ {
if args[i] == '\\' {
i++
continue
}
if args[i] == '"' {
q, err := strconv.Unquote(args[:i+1])
if err != nil {
return nil, fmt.Errorf("invalid quoted string in //go:embed: %s", args[:i+1])
}
path = q
trimBytes(i + 1)
break Switch
}
}
if i >= len(args) {
return nil, fmt.Errorf("invalid quoted string in //go:embed: %s", args)
}
}
if args != "" {
r, _ := utf8.DecodeRuneInString(args)
if !unicode.IsSpace(r) {
return nil, fmt.Errorf("invalid quoted string in //go:embed: %s", args)
}
}
list = append(list, fileEmbed{path, pathPos})
}
return list, nil
}
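// Illustrative sketch (not from the original sources) of parseGoEmbed:
// unquoted, back-quoted, and double-quoted patterns may be mixed on one
// //go:embed line. The starting position is an assumption used only to derive
// each pattern's recorded position.
func exampleParseGoEmbed() {
	pos := token.Position{Filename: "x.go", Line: 10, Column: 12}
	embs, err := parseGoEmbed(`images/*.png "file with space.txt"`, pos)
	// embs has two entries, with patterns "images/*.png" and "file with space.txt"; err == nil
	_, _ = embs, err
}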
// Code generated by "stringer -type Kind"; DO NOT EDIT.
package constant
import "strconv"
func _() {
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
var x [1]struct{}
_ = x[Unknown-0]
_ = x[Bool-1]
_ = x[String-2]
_ = x[Int-3]
_ = x[Float-4]
_ = x[Complex-5]
}
const _Kind_name = "UnknownBoolStringIntFloatComplex"
var _Kind_index = [...]uint8{0, 7, 11, 17, 20, 25, 32}
func (i Kind) String() string {
if i < 0 || i >= Kind(len(_Kind_index)-1) {
return "Kind(" + strconv.FormatInt(int64(i), 10) + ")"
}
return _Kind_name[_Kind_index[i]:_Kind_index[i+1]]
}
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package constant implements Values representing untyped
// Go constants and their corresponding operations.
//
// A special Unknown value may be used when a value
// is unknown due to an error. Operations on unknown
// values produce unknown values unless specified
// otherwise.
package constant
import (
"fmt"
"go/token"
"math"
"math/big"
"math/bits"
"strconv"
"strings"
"sync"
"unicode/utf8"
)
//go:generate stringer -type Kind
// Kind specifies the kind of value represented by a [Value].
type Kind int
const (
// unknown values
Unknown Kind = iota
// non-numeric values
Bool
String
// numeric values
Int
Float
Complex
)
// A Value represents the value of a Go constant.
type Value interface {
// Kind returns the value kind.
Kind() Kind
// String returns a short, quoted (human-readable) form of the value.
// For numeric values, the result may be an approximation;
// for String values the result may be a shortened string.
// Use ExactString for a string representing a value exactly.
String() string
// ExactString returns an exact, quoted (human-readable) form of the value.
// If the Value is of Kind String, use StringVal to obtain the unquoted string.
ExactString() string
// Prevent external implementations.
implementsValue()
}
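// As an informal sketch of the String/ExactString distinction (using only the
// exported API defined below):
//
//	v := BinaryOp(MakeInt64(1), token.QUO, MakeInt64(3)) // the Float value 1/3
//	v.String()      // a short decimal approximation such as "0.333333"
//	v.ExactString() // the exact form "1/3"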
// ----------------------------------------------------------------------------
// Implementations
// Maximum supported mantissa precision.
// The spec requires at least 256 bits; typical implementations use 512 bits.
const prec = 512
// TODO(gri) Consider storing "error" information in an unknownVal so clients
// can provide better error messages. For instance, if a number is
// too large (incl. infinity), that could be recorded in unknownVal.
// See also #20583 and #42695 for use cases.
// Representation of values:
//
// Values of Int and Float Kind have two different representations each: int64Val
// and intVal, and ratVal and floatVal. When possible, the "smaller", respectively
// more precise (for Floats) representation is chosen. However, once a Float value
// is represented as a floatVal, any subsequent results remain floatVals (unless
// explicitly converted); i.e., no attempt is made to convert a floatVal back into
// a ratVal. The reasoning is that all representations but floatVal are mathematically
// exact, but once that precision is lost (by moving to floatVal), moving back to
// a different representation implies a precision that's not actually there.
type (
unknownVal struct{}
boolVal bool
stringVal struct {
// Lazy value: either a string (l,r==nil) or an addition (l,r!=nil).
mu sync.Mutex
s string
l, r *stringVal
}
int64Val int64 // Int values representable as an int64
intVal struct{ val *big.Int } // Int values not representable as an int64
ratVal struct{ val *big.Rat } // Float values representable as a fraction
floatVal struct{ val *big.Float } // Float values not representable as a fraction
complexVal struct{ re, im Value }
)
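// As a rough illustration of which internal representation a few of the
// exported constructors end up using:
//
//	MakeInt64(1)                                            // int64Val
//	MakeFromLiteral("123456789012345678901", token.INT, 0)  // intVal (too large for int64)
//	MakeFloat64(0.5)                                        // ratVal (the exact fraction 1/2)
//	MakeFromLiteral("1e100000", token.FLOAT, 0)             // floatVal (too large for a fraction)
//
// These concrete types are not exported; callers observe only the Kind.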
func (unknownVal) Kind() Kind { return Unknown }
func (boolVal) Kind() Kind { return Bool }
func (*stringVal) Kind() Kind { return String }
func (int64Val) Kind() Kind { return Int }
func (intVal) Kind() Kind { return Int }
func (ratVal) Kind() Kind { return Float }
func (floatVal) Kind() Kind { return Float }
func (complexVal) Kind() Kind { return Complex }
func (unknownVal) String() string { return "unknown" }
func (x boolVal) String() string { return strconv.FormatBool(bool(x)) }
// String returns a possibly shortened quoted form of the String value.
func (x *stringVal) String() string {
const maxLen = 72 // a reasonable length
s := strconv.Quote(x.string())
if utf8.RuneCountInString(s) > maxLen {
// The string without the enclosing quotes is greater than maxLen-2 runes
// long. Remove the last 3 runes (including the closing '"') by keeping
// only the first maxLen-3 runes; then add "...".
i := 0
for n := 0; n < maxLen-3; n++ {
_, size := utf8.DecodeRuneInString(s[i:])
i += size
}
s = s[:i] + "..."
}
return s
}
// string constructs and returns the actual string literal value.
// If x represents an addition, then it rewrites x to be a single
// string, to speed future calls. This lazy construction avoids
// building different string values for all subpieces of a large
// concatenation. See golang.org/issue/23348.
func (x *stringVal) string() string {
x.mu.Lock()
if x.l != nil {
x.s = strings.Join(reverse(x.appendReverse(nil)), "")
x.l = nil
x.r = nil
}
s := x.s
x.mu.Unlock()
return s
}
// reverse reverses x in place and returns it.
func reverse(x []string) []string {
n := len(x)
for i := 0; i+i < n; i++ {
x[i], x[n-1-i] = x[n-1-i], x[i]
}
return x
}
// appendReverse appends to list all of x's subpieces, but in reverse,
// and returns the result. Appending the reversal allows processing
// the right side in a recursive call and the left side in a loop.
// Because a chain like a + b + c + d + e is actually represented
// as ((((a + b) + c) + d) + e), the left-side loop avoids deep recursion.
// x must be locked.
func (x *stringVal) appendReverse(list []string) []string {
y := x
for y.r != nil {
y.r.mu.Lock()
list = y.r.appendReverse(list)
y.r.mu.Unlock()
l := y.l
if y != x {
y.mu.Unlock()
}
l.mu.Lock()
y = l
}
s := y.s
if y != x {
y.mu.Unlock()
}
return append(list, s)
}
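// As an informal illustration of the lazy concatenation: adding String values
// with BinaryOp only records the two operands, and the flattened text is
// produced on first use.
//
//	a := MakeString("hello, ")
//	b := MakeString("world")
//	c := BinaryOp(a, token.ADD, b) // stores l=a, r=b; no string copying yet
//	StringVal(c)                   // flattens the tree and returns "hello, world"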
func (x int64Val) String() string { return strconv.FormatInt(int64(x), 10) }
func (x intVal) String() string { return x.val.String() }
func (x ratVal) String() string { return rtof(x).String() }
// String returns a decimal approximation of the Float value.
func (x floatVal) String() string {
f := x.val
// Don't try to convert infinities (will not terminate).
if f.IsInf() {
return f.String()
}
// Use exact fmt formatting if in float64 range (common case):
// proceed if f doesn't underflow to 0 or overflow to inf.
if x, _ := f.Float64(); f.Sign() == 0 == (x == 0) && !math.IsInf(x, 0) {
s := fmt.Sprintf("%.6g", x)
if !f.IsInt() && strings.IndexByte(s, '.') < 0 {
// f is not an integer, but its string representation
// doesn't reflect that. Use more digits. See issue 56220.
s = fmt.Sprintf("%g", x)
}
return s
}
// Out of float64 range. Do approximate manual to decimal
// conversion to avoid precise but possibly slow Float
// formatting.
// f = mant * 2**exp
var mant big.Float
exp := f.MantExp(&mant) // 0.5 <= |mant| < 1.0
// approximate float64 mantissa m and decimal exponent d
// f ~ m * 10**d
m, _ := mant.Float64() // 0.5 <= |m| < 1.0
d := float64(exp) * (math.Ln2 / math.Ln10) // log_10(2)
// adjust m for truncated (integer) decimal exponent e
e := int64(d)
m *= math.Pow(10, d-float64(e))
// ensure 1 <= |m| < 10
switch am := math.Abs(m); {
case am < 1-0.5e-6:
// The %.6g format below rounds m to 5 digits after the
// decimal point. Make sure that m*10 < 10 even after
// rounding up: m*10 + 0.5e-5 < 10 => m < 1 - 0.5e-6.
m *= 10
e--
case am >= 10:
m /= 10
e++
}
return fmt.Sprintf("%.6ge%+d", m, e)
}
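// As a rough worked example of the conversion above: for f = 2**2000,
// MantExp yields mant = 0.5 and exp = 2001, so d = 2001*log10(2) (about 602.36),
// e = 602, and m is adjusted to about 1.148, giving a result close to "1.14813e+602".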
func (x complexVal) String() string { return fmt.Sprintf("(%s + %si)", x.re, x.im) }
func (x unknownVal) ExactString() string { return x.String() }
func (x boolVal) ExactString() string { return x.String() }
func (x *stringVal) ExactString() string { return strconv.Quote(x.string()) }
func (x int64Val) ExactString() string { return x.String() }
func (x intVal) ExactString() string { return x.String() }
func (x ratVal) ExactString() string {
r := x.val
if r.IsInt() {
return r.Num().String()
}
return r.String()
}
func (x floatVal) ExactString() string { return x.val.Text('p', 0) }
func (x complexVal) ExactString() string {
return fmt.Sprintf("(%s + %si)", x.re.ExactString(), x.im.ExactString())
}
func (unknownVal) implementsValue() {}
func (boolVal) implementsValue() {}
func (*stringVal) implementsValue() {}
func (int64Val) implementsValue() {}
func (ratVal) implementsValue() {}
func (intVal) implementsValue() {}
func (floatVal) implementsValue() {}
func (complexVal) implementsValue() {}
func newInt() *big.Int { return new(big.Int) }
func newRat() *big.Rat { return new(big.Rat) }
func newFloat() *big.Float { return new(big.Float).SetPrec(prec) }
func i64toi(x int64Val) intVal { return intVal{newInt().SetInt64(int64(x))} }
func i64tor(x int64Val) ratVal { return ratVal{newRat().SetInt64(int64(x))} }
func i64tof(x int64Val) floatVal { return floatVal{newFloat().SetInt64(int64(x))} }
func itor(x intVal) ratVal { return ratVal{newRat().SetInt(x.val)} }
func itof(x intVal) floatVal { return floatVal{newFloat().SetInt(x.val)} }
func rtof(x ratVal) floatVal { return floatVal{newFloat().SetRat(x.val)} }
func vtoc(x Value) complexVal { return complexVal{x, int64Val(0)} }
func makeInt(x *big.Int) Value {
if x.IsInt64() {
return int64Val(x.Int64())
}
return intVal{x}
}
func makeRat(x *big.Rat) Value {
a := x.Num()
b := x.Denom()
if smallInt(a) && smallInt(b) {
// ok to remain fraction
return ratVal{x}
}
// components too large => switch to float
return floatVal{newFloat().SetRat(x)}
}
var floatVal0 = floatVal{newFloat()}
func makeFloat(x *big.Float) Value {
// convert -0
if x.Sign() == 0 {
return floatVal0
}
if x.IsInf() {
return unknownVal{}
}
// No attempt is made to "go back" to ratVal, even if possible,
// to avoid providing the illusion of a mathematically exact
// representation.
return floatVal{x}
}
func makeComplex(re, im Value) Value {
if re.Kind() == Unknown || im.Kind() == Unknown {
return unknownVal{}
}
return complexVal{re, im}
}
func makeFloatFromLiteral(lit string) Value {
if f, ok := newFloat().SetString(lit); ok {
if smallFloat(f) {
// ok to use rationals
if f.Sign() == 0 {
// Issue 20228: If the float underflowed to zero, parse just "0".
// Otherwise, lit might contain a value with a large negative exponent,
// such as -6e-1886451601. As a float, that will underflow to 0,
// but it'll take forever to parse as a Rat.
lit = "0"
}
if r, ok := newRat().SetString(lit); ok {
return ratVal{r}
}
}
// otherwise use floats
return makeFloat(f)
}
return nil
}
// Permit fractions with component sizes up to maxExp
// before switching to using floating-point numbers.
const maxExp = 4 << 10
// smallInt reports whether x would lead to a "reasonably"-sized fraction
// if converted to a *big.Rat.
func smallInt(x *big.Int) bool {
return x.BitLen() < maxExp
}
// smallFloat64 reports whether x would lead to a "reasonably"-sized fraction
// if converted to a *big.Rat.
func smallFloat64(x float64) bool {
if math.IsInf(x, 0) {
return false
}
_, e := math.Frexp(x)
return -maxExp < e && e < maxExp
}
// smallFloat reports whether x would lead to a "reasonably"-sized fraction
// if converted to a *big.Rat.
func smallFloat(x *big.Float) bool {
if x.IsInf() {
return false
}
e := x.MantExp(nil)
return -maxExp < e && e < maxExp
}
// ----------------------------------------------------------------------------
// Factories
// MakeUnknown returns the [Unknown] value.
func MakeUnknown() Value { return unknownVal{} }
// MakeBool returns the [Bool] value for b.
func MakeBool(b bool) Value { return boolVal(b) }
// MakeString returns the [String] value for s.
func MakeString(s string) Value {
if s == "" {
return &emptyString // common case
}
return &stringVal{s: s}
}
var emptyString stringVal
// MakeInt64 returns the [Int] value for x.
func MakeInt64(x int64) Value { return int64Val(x) }
// MakeUint64 returns the [Int] value for x.
func MakeUint64(x uint64) Value {
if x < 1<<63 {
return int64Val(int64(x))
}
return intVal{newInt().SetUint64(x)}
}
// MakeFloat64 returns the [Float] value for x.
// If x is -0.0, the result is 0.0.
// If x is not finite, the result is an [Unknown].
func MakeFloat64(x float64) Value {
if math.IsInf(x, 0) || math.IsNaN(x) {
return unknownVal{}
}
if smallFloat64(x) {
return ratVal{newRat().SetFloat64(x + 0)} // convert -0 to 0
}
return floatVal{newFloat().SetFloat64(x + 0)}
}
// MakeFromLiteral returns the corresponding integer, floating-point,
// imaginary, character, or string value for a Go literal string. The
// tok value must be one of [token.INT], [token.FLOAT], [token.IMAG],
// [token.CHAR], or [token.STRING]. The final argument must be zero.
// If the literal string syntax is invalid, the result is an [Unknown].
func MakeFromLiteral(lit string, tok token.Token, zero uint) Value {
if zero != 0 {
panic("MakeFromLiteral called with non-zero last argument")
}
switch tok {
case token.INT:
if x, err := strconv.ParseInt(lit, 0, 64); err == nil {
return int64Val(x)
}
if x, ok := newInt().SetString(lit, 0); ok {
return intVal{x}
}
case token.FLOAT:
if x := makeFloatFromLiteral(lit); x != nil {
return x
}
case token.IMAG:
if n := len(lit); n > 0 && lit[n-1] == 'i' {
if im := makeFloatFromLiteral(lit[:n-1]); im != nil {
return makeComplex(int64Val(0), im)
}
}
case token.CHAR:
if n := len(lit); n >= 2 {
if code, _, _, err := strconv.UnquoteChar(lit[1:n-1], '\''); err == nil {
return MakeInt64(int64(code))
}
}
case token.STRING:
if s, err := strconv.Unquote(lit); err == nil {
return MakeString(s)
}
default:
panic(fmt.Sprintf("%v is not a valid token", tok))
}
return unknownVal{}
}
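// A few informal examples of literal conversion:
//
//	MakeFromLiteral("0x2a", token.INT, 0)    // the Int value 42
//	MakeFromLiteral("1.5", token.FLOAT, 0)   // the Float value 3/2
//	MakeFromLiteral("2i", token.IMAG, 0)     // the Complex value 0 + 2i
//	MakeFromLiteral("'A'", token.CHAR, 0)    // the Int value 65
//	MakeFromLiteral(`"hi"`, token.STRING, 0) // the String value "hi"
//	MakeFromLiteral("abc", token.INT, 0)     // invalid syntax: an Unknown value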
// ----------------------------------------------------------------------------
// Accessors
//
// For unknown arguments the result is the zero value for the respective
// accessor type, except for Sign, where the result is 1.
// BoolVal returns the Go boolean value of x, which must be a [Bool] or an [Unknown].
// If x is [Unknown], the result is false.
func BoolVal(x Value) bool {
switch x := x.(type) {
case boolVal:
return bool(x)
case unknownVal:
return false
default:
panic(fmt.Sprintf("%v not a Bool", x))
}
}
// StringVal returns the Go string value of x, which must be a [String] or an [Unknown].
// If x is [Unknown], the result is "".
func StringVal(x Value) string {
switch x := x.(type) {
case *stringVal:
return x.string()
case unknownVal:
return ""
default:
panic(fmt.Sprintf("%v not a String", x))
}
}
// Int64Val returns the Go int64 value of x and whether the result is exact;
// x must be an [Int] or an [Unknown]. If the result is not exact, its value is undefined.
// If x is [Unknown], the result is (0, false).
func Int64Val(x Value) (int64, bool) {
switch x := x.(type) {
case int64Val:
return int64(x), true
case intVal:
return x.val.Int64(), false // not an int64Val and thus not exact
case unknownVal:
return 0, false
default:
panic(fmt.Sprintf("%v not an Int", x))
}
}
// Uint64Val returns the Go uint64 value of x and whether the result is exact;
// x must be an [Int] or an [Unknown]. If the result is not exact, its value is undefined.
// If x is [Unknown], the result is (0, false).
func Uint64Val(x Value) (uint64, bool) {
switch x := x.(type) {
case int64Val:
return uint64(x), x >= 0
case intVal:
return x.val.Uint64(), x.val.IsUint64()
case unknownVal:
return 0, false
default:
panic(fmt.Sprintf("%v not an Int", x))
}
}
// Float32Val is like [Float64Val] but for float32 instead of float64.
func Float32Val(x Value) (float32, bool) {
switch x := x.(type) {
case int64Val:
f := float32(x)
return f, int64Val(f) == x
case intVal:
f, acc := newFloat().SetInt(x.val).Float32()
return f, acc == big.Exact
case ratVal:
return x.val.Float32()
case floatVal:
f, acc := x.val.Float32()
return f, acc == big.Exact
case unknownVal:
return 0, false
default:
panic(fmt.Sprintf("%v not a Float", x))
}
}
// Float64Val returns the nearest Go float64 value of x and whether the result is exact;
// x must be numeric or an [Unknown], but not [Complex]. For values too small (too close to 0)
// to represent as float64, [Float64Val] silently underflows to 0. The result sign always
// matches the sign of x, even for 0.
// If x is [Unknown], the result is (0, false).
func Float64Val(x Value) (float64, bool) {
switch x := x.(type) {
case int64Val:
f := float64(int64(x))
return f, int64Val(f) == x
case intVal:
f, acc := newFloat().SetInt(x.val).Float64()
return f, acc == big.Exact
case ratVal:
return x.val.Float64()
case floatVal:
f, acc := x.val.Float64()
return f, acc == big.Exact
case unknownVal:
return 0, false
default:
panic(fmt.Sprintf("%v not a Float", x))
}
}
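// As an informal illustration of the exactness results:
//
//	Int64Val(MakeInt64(3))       // 3, true
//	Uint64Val(MakeInt64(-1))     // second result is false: -1 is not a uint64
//	Float64Val(MakeFloat64(0.5)) // 0.5, true
//	Float64Val(BinaryOp(MakeInt64(1), token.QUO, MakeInt64(3))) // inexact: second result is false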
// Val returns the underlying value for a given constant. Since it returns an
// interface, it is up to the caller to type assert the result to the expected
// type. The possible dynamic return types are:
//
// x Kind type of result
// -----------------------------------------
// Bool bool
// String string
// Int int64 or *big.Int
// Float *big.Float or *big.Rat
// everything else nil
func Val(x Value) any {
switch x := x.(type) {
case boolVal:
return bool(x)
case *stringVal:
return x.string()
case int64Val:
return int64(x)
case intVal:
return x.val
case ratVal:
return x.val
case floatVal:
return x.val
default:
return nil
}
}
// Make returns the [Value] for x.
//
// type of x result Kind
// ----------------------------
// bool Bool
// string String
// int64 Int
// *big.Int Int
// *big.Float Float
// *big.Rat Float
// anything else Unknown
func Make(x any) Value {
switch x := x.(type) {
case bool:
return boolVal(x)
case string:
return &stringVal{s: x}
case int64:
return int64Val(x)
case *big.Int:
return makeInt(x)
case *big.Rat:
return makeRat(x)
case *big.Float:
return makeFloat(x)
default:
return unknownVal{}
}
}
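// Informally, Make and Val round-trip for the types listed above:
//
//	v := MakeUint64(1 << 63)  // an Int backed by a *big.Int
//	x := Val(v)               // x has dynamic type *big.Int
//	w := Make(x)              // again an Int with the same value
//	Compare(v, token.EQL, w)  // true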
// BitLen returns the number of bits required to represent
// the absolute value of x in binary representation; x must be an [Int] or an [Unknown].
// If x is [Unknown], the result is 0.
func BitLen(x Value) int {
switch x := x.(type) {
case int64Val:
u := uint64(x)
if x < 0 {
u = uint64(-x)
}
return 64 - bits.LeadingZeros64(u)
case intVal:
return x.val.BitLen()
case unknownVal:
return 0
default:
panic(fmt.Sprintf("%v not an Int", x))
}
}
// Sign returns -1, 0, or 1 depending on whether x < 0, x == 0, or x > 0;
// x must be numeric or [Unknown]. For complex values x, the sign is 0 if x == 0,
// otherwise it is != 0. If x is [Unknown], the result is 1.
func Sign(x Value) int {
switch x := x.(type) {
case int64Val:
switch {
case x < 0:
return -1
case x > 0:
return 1
}
return 0
case intVal:
return x.val.Sign()
case ratVal:
return x.val.Sign()
case floatVal:
return x.val.Sign()
case complexVal:
return Sign(x.re) | Sign(x.im)
case unknownVal:
return 1 // avoid spurious division by zero errors
default:
panic(fmt.Sprintf("%v not numeric", x))
}
}
// ----------------------------------------------------------------------------
// Support for assembling/disassembling numeric values
const (
// Compute the size of a Word in bytes.
_m = ^big.Word(0)
_log = _m>>8&1 + _m>>16&1 + _m>>32&1
wordSize = 1 << _log
)
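// For illustration: ^big.Word(0) has all bits set, and each shift-and-mask
// term contributes 1 only if big.Word is wide enough to survive the shift.
// On a 64-bit platform all three terms are 1, so _log = 3 and wordSize = 8;
// on a 32-bit platform the last term is 0, so _log = 2 and wordSize = 4.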
// Bytes returns the bytes for the absolute value of x in little-
// endian binary representation; x must be an [Int].
func Bytes(x Value) []byte {
var t intVal
switch x := x.(type) {
case int64Val:
t = i64toi(x)
case intVal:
t = x
default:
panic(fmt.Sprintf("%v not an Int", x))
}
words := t.val.Bits()
bytes := make([]byte, len(words)*wordSize)
i := 0
for _, w := range words {
for j := 0; j < wordSize; j++ {
bytes[i] = byte(w)
w >>= 8
i++
}
}
// remove leading 0's
for i > 0 && bytes[i-1] == 0 {
i--
}
return bytes[:i]
}
// MakeFromBytes returns the [Int] value given the bytes of its little-endian
// binary representation. An empty byte slice argument represents 0.
func MakeFromBytes(bytes []byte) Value {
words := make([]big.Word, (len(bytes)+(wordSize-1))/wordSize)
i := 0
var w big.Word
var s uint
for _, b := range bytes {
w |= big.Word(b) << s
if s += 8; s == wordSize*8 {
words[i] = w
i++
w = 0
s = 0
}
}
// store last word
if i < len(words) {
words[i] = w
i++
}
// remove leading 0's
for i > 0 && words[i-1] == 0 {
i--
}
return makeInt(newInt().SetBits(words[:i]))
}
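// As an informal round-trip example (Bytes discards the sign and uses
// little-endian order):
//
//	b := Bytes(MakeInt64(0x0102)) // []byte{0x02, 0x01}
//	MakeFromBytes(b)              // the Int value 258
//	MakeFromBytes(nil)            // the Int value 0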
// Num returns the numerator of x; x must be [Int], [Float], or [Unknown].
// If x is [Unknown], or if it is too large or small to represent as a
// fraction, the result is [Unknown]. Otherwise the result is an [Int]
// with the same sign as x.
func Num(x Value) Value {
switch x := x.(type) {
case int64Val, intVal:
return x
case ratVal:
return makeInt(x.val.Num())
case floatVal:
if smallFloat(x.val) {
r, _ := x.val.Rat(nil)
return makeInt(r.Num())
}
case unknownVal:
break
default:
panic(fmt.Sprintf("%v not Int or Float", x))
}
return unknownVal{}
}
// Denom returns the denominator of x; x must be [Int], [Float], or [Unknown].
// If x is [Unknown], or if it is too large or small to represent as a
// fraction, the result is [Unknown]. Otherwise the result is an [Int] >= 1.
func Denom(x Value) Value {
switch x := x.(type) {
case int64Val, intVal:
return int64Val(1)
case ratVal:
return makeInt(x.val.Denom())
case floatVal:
if smallFloat(x.val) {
r, _ := x.val.Rat(nil)
return makeInt(r.Denom())
}
case unknownVal:
break
default:
panic(fmt.Sprintf("%v not Int or Float", x))
}
return unknownVal{}
}
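// For example, informally:
//
//	x := MakeFloat64(0.75) // stored as the fraction 3/4
//	Num(x)                 // the Int value 3
//	Denom(x)               // the Int value 4
//	Denom(MakeInt64(5))    // the Int value 1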
// MakeImag returns the [Complex] value x*i;
// x must be [Int], [Float], or [Unknown].
// If x is [Unknown], the result is [Unknown].
func MakeImag(x Value) Value {
switch x.(type) {
case unknownVal:
return x
case int64Val, intVal, ratVal, floatVal:
return makeComplex(int64Val(0), x)
default:
panic(fmt.Sprintf("%v not Int or Float", x))
}
}
// Real returns the real part of x, which must be a numeric or unknown value.
// If x is [Unknown], the result is [Unknown].
func Real(x Value) Value {
switch x := x.(type) {
case unknownVal, int64Val, intVal, ratVal, floatVal:
return x
case complexVal:
return x.re
default:
panic(fmt.Sprintf("%v not numeric", x))
}
}
// Imag returns the imaginary part of x, which must be a numeric or unknown value.
// If x is [Unknown], the result is [Unknown].
func Imag(x Value) Value {
switch x := x.(type) {
case unknownVal:
return x
case int64Val, intVal, ratVal, floatVal:
return int64Val(0)
case complexVal:
return x.im
default:
panic(fmt.Sprintf("%v not numeric", x))
}
}
// ----------------------------------------------------------------------------
// Numeric conversions
// ToInt converts x to an [Int] value if x is representable as an [Int].
// Otherwise it returns an [Unknown].
func ToInt(x Value) Value {
switch x := x.(type) {
case int64Val, intVal:
return x
case ratVal:
if x.val.IsInt() {
return makeInt(x.val.Num())
}
case floatVal:
// avoid creation of huge integers
// (Existing tests require permitting exponents of at least 1024;
// allow any value that would also be permissible as a fraction.)
if smallFloat(x.val) {
i := newInt()
if _, acc := x.val.Int(i); acc == big.Exact {
return makeInt(i)
}
// If we can get an integer by rounding up or down,
// assume x is not an integer because of rounding
// errors in prior computations.
const delta = 4 // a small number of bits > 0
var t big.Float
t.SetPrec(prec - delta)
// try rounding down a little
t.SetMode(big.ToZero)
t.Set(x.val)
if _, acc := t.Int(i); acc == big.Exact {
return makeInt(i)
}
// try rounding up a little
t.SetMode(big.AwayFromZero)
t.Set(x.val)
if _, acc := t.Int(i); acc == big.Exact {
return makeInt(i)
}
}
case complexVal:
if re := ToFloat(x); re.Kind() == Float {
return ToInt(re)
}
}
return unknownVal{}
}
// ToFloat converts x to a [Float] value if x is representable as a [Float].
// Otherwise it returns an [Unknown].
func ToFloat(x Value) Value {
switch x := x.(type) {
case int64Val:
return i64tor(x) // x is always a small int
case intVal:
if smallInt(x.val) {
return itor(x)
}
return itof(x)
case ratVal, floatVal:
return x
case complexVal:
if Sign(x.im) == 0 {
return ToFloat(x.re)
}
}
return unknownVal{}
}
// ToComplex converts x to a [Complex] value if x is representable as a [Complex].
// Otherwise it returns an [Unknown].
func ToComplex(x Value) Value {
switch x := x.(type) {
case int64Val, intVal, ratVal, floatVal:
return vtoc(x)
case complexVal:
return x
}
return unknownVal{}
}
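// A few informal conversion examples:
//
//	ToInt(MakeFloat64(2.0))         // the Int value 2
//	ToInt(MakeFloat64(2.5))         // Unknown: 2.5 is not an integer
//	ToFloat(MakeInt64(3))           // the Float value 3
//	ToComplex(MakeInt64(3))         // the Complex value 3 + 0i
//	ToFloat(MakeImag(MakeInt64(1))) // Unknown: the imaginary part is not zero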
// ----------------------------------------------------------------------------
// Operations
// is32bit reports whether x can be represented using 32 bits.
func is32bit(x int64) bool {
const s = 32
return -1<<(s-1) <= x && x <= 1<<(s-1)-1
}
// is63bit reports whether x can be represented using 63 bits.
func is63bit(x int64) bool {
const s = 63
return -1<<(s-1) <= x && x <= 1<<(s-1)-1
}
// UnaryOp returns the result of the unary expression op y.
// The operation must be defined for the operand.
// If prec > 0 it specifies the ^ (xor) result size in bits.
// If y is [Unknown], the result is [Unknown].
func UnaryOp(op token.Token, y Value, prec uint) Value {
switch op {
case token.ADD:
switch y.(type) {
case unknownVal, int64Val, intVal, ratVal, floatVal, complexVal:
return y
}
case token.SUB:
switch y := y.(type) {
case unknownVal:
return y
case int64Val:
if z := -y; z != y {
return z // no overflow
}
return makeInt(newInt().Neg(big.NewInt(int64(y))))
case intVal:
return makeInt(newInt().Neg(y.val))
case ratVal:
return makeRat(newRat().Neg(y.val))
case floatVal:
return makeFloat(newFloat().Neg(y.val))
case complexVal:
re := UnaryOp(token.SUB, y.re, 0)
im := UnaryOp(token.SUB, y.im, 0)
return makeComplex(re, im)
}
case token.XOR:
z := newInt()
switch y := y.(type) {
case unknownVal:
return y
case int64Val:
z.Not(big.NewInt(int64(y)))
case intVal:
z.Not(y.val)
default:
goto Error
}
// For unsigned types, the result will be negative and
// thus "too large": We must limit the result precision
// to the type's precision.
if prec > 0 {
z.AndNot(z, newInt().Lsh(big.NewInt(-1), prec)) // z &^= (-1)<<prec
}
return makeInt(z)
case token.NOT:
switch y := y.(type) {
case unknownVal:
return y
case boolVal:
return !y
}
}
Error:
panic(fmt.Sprintf("invalid unary operation %s%v", op, y))
}
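// As an informal illustration of the prec parameter for ^:
//
//	UnaryOp(token.XOR, MakeInt64(1), 0) // the Int value -2 (signed complement)
//	UnaryOp(token.XOR, MakeInt64(1), 8) // the Int value 254 (result truncated to 8 bits)
//	UnaryOp(token.SUB, MakeInt64(3), 0) // the Int value -3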
func ord(x Value) int {
switch x.(type) {
default:
// force invalid value into "x position" in match
// (don't panic here so that callers can provide a better error message)
return -1
case unknownVal:
return 0
case boolVal, *stringVal:
return 1
case int64Val:
return 2
case intVal:
return 3
case ratVal:
return 4
case floatVal:
return 5
case complexVal:
return 6
}
}
// match returns the matching representation (same type) with the
// smallest complexity for two values x and y. If one of them is
// numeric, both of them must be numeric. If one of them is Unknown
// or invalid (say, nil), both results are that value.
func match(x, y Value) (_, _ Value) {
switch ox, oy := ord(x), ord(y); {
case ox < oy:
x, y = match0(x, y)
case ox > oy:
y, x = match0(y, x)
}
return x, y
}
// match0 must only be called by match.
// Invariant: ord(x) < ord(y)
func match0(x, y Value) (_, _ Value) {
// Prefer to return the original x and y arguments when possible,
// to avoid unnecessary heap allocations.
switch y.(type) {
case intVal:
switch x1 := x.(type) {
case int64Val:
return i64toi(x1), y
}
case ratVal:
switch x1 := x.(type) {
case int64Val:
return i64tor(x1), y
case intVal:
return itor(x1), y
}
case floatVal:
switch x1 := x.(type) {
case int64Val:
return i64tof(x1), y
case intVal:
return itof(x1), y
case ratVal:
return rtof(x1), y
}
case complexVal:
return vtoc(x), y
}
// force unknown and invalid values into "x position" in callers of match
// (don't panic here so that callers can provide a better error message)
return x, x
}
// BinaryOp returns the result of the binary expression x op y.
// The operation must be defined for the operands. If one of the
// operands is [Unknown], the result is [Unknown].
// BinaryOp doesn't handle comparisons or shifts; use [Compare]
// or [Shift] instead.
//
// To force integer division of [Int] operands, use op == [token.QUO_ASSIGN]
// instead of [token.QUO]; the result is guaranteed to be [Int] in this case.
// Division by zero leads to a run-time panic.
func BinaryOp(x_ Value, op token.Token, y_ Value) Value {
x, y := match(x_, y_)
switch x := x.(type) {
case unknownVal:
return x
case boolVal:
y := y.(boolVal)
switch op {
case token.LAND:
return x && y
case token.LOR:
return x || y
}
case int64Val:
a := int64(x)
b := int64(y.(int64Val))
var c int64
switch op {
case token.ADD:
if !is63bit(a) || !is63bit(b) {
return makeInt(newInt().Add(big.NewInt(a), big.NewInt(b)))
}
c = a + b
case token.SUB:
if !is63bit(a) || !is63bit(b) {
return makeInt(newInt().Sub(big.NewInt(a), big.NewInt(b)))
}
c = a - b
case token.MUL:
if !is32bit(a) || !is32bit(b) {
return makeInt(newInt().Mul(big.NewInt(a), big.NewInt(b)))
}
c = a * b
case token.QUO:
return makeRat(big.NewRat(a, b))
case token.QUO_ASSIGN: // force integer division
c = a / b
case token.REM:
c = a % b
case token.AND:
c = a & b
case token.OR:
c = a | b
case token.XOR:
c = a ^ b
case token.AND_NOT:
c = a &^ b
default:
goto Error
}
return int64Val(c)
case intVal:
a := x.val
b := y.(intVal).val
c := newInt()
switch op {
case token.ADD:
c.Add(a, b)
case token.SUB:
c.Sub(a, b)
case token.MUL:
c.Mul(a, b)
case token.QUO:
return makeRat(newRat().SetFrac(a, b))
case token.QUO_ASSIGN: // force integer division
c.Quo(a, b)
case token.REM:
c.Rem(a, b)
case token.AND:
c.And(a, b)
case token.OR:
c.Or(a, b)
case token.XOR:
c.Xor(a, b)
case token.AND_NOT:
c.AndNot(a, b)
default:
goto Error
}
return makeInt(c)
case ratVal:
a := x.val
b := y.(ratVal).val
c := newRat()
switch op {
case token.ADD:
c.Add(a, b)
case token.SUB:
c.Sub(a, b)
case token.MUL:
c.Mul(a, b)
case token.QUO:
c.Quo(a, b)
default:
goto Error
}
return makeRat(c)
case floatVal:
a := x.val
b := y.(floatVal).val
c := newFloat()
switch op {
case token.ADD:
c.Add(a, b)
case token.SUB:
c.Sub(a, b)
case token.MUL:
c.Mul(a, b)
case token.QUO:
c.Quo(a, b)
default:
goto Error
}
return makeFloat(c)
case complexVal:
y := y.(complexVal)
a, b := x.re, x.im
c, d := y.re, y.im
var re, im Value
switch op {
case token.ADD:
// (a+c) + i(b+d)
re = add(a, c)
im = add(b, d)
case token.SUB:
// (a-c) + i(b-d)
re = sub(a, c)
im = sub(b, d)
case token.MUL:
// (ac-bd) + i(bc+ad)
ac := mul(a, c)
bd := mul(b, d)
bc := mul(b, c)
ad := mul(a, d)
re = sub(ac, bd)
im = add(bc, ad)
case token.QUO:
// (ac+bd)/s + i(bc-ad)/s, with s = cc + dd
ac := mul(a, c)
bd := mul(b, d)
bc := mul(b, c)
ad := mul(a, d)
cc := mul(c, c)
dd := mul(d, d)
s := add(cc, dd)
re = add(ac, bd)
re = quo(re, s)
im = sub(bc, ad)
im = quo(im, s)
default:
goto Error
}
return makeComplex(re, im)
case *stringVal:
if op == token.ADD {
return &stringVal{l: x, r: y.(*stringVal)}
}
}
Error:
panic(fmt.Sprintf("invalid binary operation %v %s %v", x_, op, y_))
}
func add(x, y Value) Value { return BinaryOp(x, token.ADD, y) }
func sub(x, y Value) Value { return BinaryOp(x, token.SUB, y) }
func mul(x, y Value) Value { return BinaryOp(x, token.MUL, y) }
func quo(x, y Value) Value { return BinaryOp(x, token.QUO, y) }
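// As an informal illustration of the two division forms and of REM:
//
//	BinaryOp(MakeInt64(7), token.QUO, MakeInt64(2))        // the Float value 7/2
//	BinaryOp(MakeInt64(7), token.QUO_ASSIGN, MakeInt64(2)) // the Int value 3 (truncated)
//	BinaryOp(MakeInt64(-7), token.REM, MakeInt64(2))       // the Int value -1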
// Shift returns the result of the shift expression x op s
// with op == [token.SHL] or [token.SHR] (<< or >>). x must be
// an [Int] or an [Unknown]. If x is [Unknown], the result is x.
func Shift(x Value, op token.Token, s uint) Value {
switch x := x.(type) {
case unknownVal:
return x
case int64Val:
if s == 0 {
return x
}
switch op {
case token.SHL:
z := i64toi(x).val
return makeInt(z.Lsh(z, s))
case token.SHR:
return x >> s
}
case intVal:
if s == 0 {
return x
}
z := newInt()
switch op {
case token.SHL:
return makeInt(z.Lsh(x.val, s))
case token.SHR:
return makeInt(z.Rsh(x.val, s))
}
}
panic(fmt.Sprintf("invalid shift %v %s %d", x, op, s))
}
func cmpZero(x int, op token.Token) bool {
switch op {
case token.EQL:
return x == 0
case token.NEQ:
return x != 0
case token.LSS:
return x < 0
case token.LEQ:
return x <= 0
case token.GTR:
return x > 0
case token.GEQ:
return x >= 0
}
panic(fmt.Sprintf("invalid comparison %v %s 0", x, op))
}
// Compare returns the result of the comparison x op y.
// The comparison must be defined for the operands.
// If one of the operands is [Unknown], the result is
// false.
func Compare(x_ Value, op token.Token, y_ Value) bool {
x, y := match(x_, y_)
switch x := x.(type) {
case unknownVal:
return false
case boolVal:
y := y.(boolVal)
switch op {
case token.EQL:
return x == y
case token.NEQ:
return x != y
}
case int64Val:
y := y.(int64Val)
switch op {
case token.EQL:
return x == y
case token.NEQ:
return x != y
case token.LSS:
return x < y
case token.LEQ:
return x <= y
case token.GTR:
return x > y
case token.GEQ:
return x >= y
}
case intVal:
return cmpZero(x.val.Cmp(y.(intVal).val), op)
case ratVal:
return cmpZero(x.val.Cmp(y.(ratVal).val), op)
case floatVal:
return cmpZero(x.val.Cmp(y.(floatVal).val), op)
case complexVal:
y := y.(complexVal)
re := Compare(x.re, token.EQL, y.re)
im := Compare(x.im, token.EQL, y.im)
switch op {
case token.EQL:
return re && im
case token.NEQ:
return !re || !im
}
case *stringVal:
xs := x.string()
ys := y.(*stringVal).string()
switch op {
case token.EQL:
return xs == ys
case token.NEQ:
return xs != ys
case token.LSS:
return xs < ys
case token.LEQ:
return xs <= ys
case token.GTR:
return xs > ys
case token.GEQ:
return xs >= ys
}
}
panic(fmt.Sprintf("invalid comparison %v %s %v", x_, op, y_))
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package doc
import (
"go/doc/comment"
"io"
)
// ToHTML converts comment text to formatted HTML.
//
// Deprecated: ToHTML cannot identify documentation links
// in the doc comment, because they depend on knowing what
// package the text came from, which is not included in this API.
//
// Given the *[doc.Package] p where text was found,
// ToHTML(w, text, nil) can be replaced by:
//
// w.Write(p.HTML(text))
//
// which is in turn shorthand for:
//
// w.Write(p.Printer().HTML(p.Parser().Parse(text)))
//
// If words may be non-nil, the longer replacement is:
//
// parser := p.Parser()
// parser.Words = words
// w.Write(p.Printer().HTML(parser.Parse(text)))
func ToHTML(w io.Writer, text string, words map[string]string) {
p := new(Package).Parser()
p.Words = words
d := p.Parse(text)
pr := new(comment.Printer)
w.Write(pr.HTML(d))
}
// ToText converts comment text to formatted text.
//
// Deprecated: ToText cannot identify documentation links
// in the doc comment, because they depend on knowing what
// package the text came from, which is not included in this API.
//
// Given the *[doc.Package] p where text was found,
// ToText(w, text, "", "\t", 80) can be replaced by:
//
// w.Write(p.Text(text))
//
// In the general case, ToText(w, text, prefix, codePrefix, width)
// can be replaced by:
//
// d := p.Parser().Parse(text)
// pr := p.Printer()
// pr.TextPrefix = prefix
// pr.TextCodePrefix = codePrefix
// pr.TextWidth = width
// w.Write(pr.Text(d))
//
// See the documentation for [Package.Text] and [comment.Printer.Text]
// for more details.
func ToText(w io.Writer, text string, prefix, codePrefix string, width int) {
d := new(Package).Parser().Parse(text)
pr := &comment.Printer{
TextPrefix: prefix,
TextCodePrefix: codePrefix,
TextWidth: width,
}
w.Write(pr.Text(d))
}
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package comment
import (
"bytes"
"fmt"
"strconv"
)
// An htmlPrinter holds the state needed for printing a [Doc] as HTML.
type htmlPrinter struct {
*Printer
tight bool
}
// HTML returns an HTML formatting of the [Doc].
// See the [Printer] documentation for ways to customize the HTML output.
func (p *Printer) HTML(d *Doc) []byte {
hp := &htmlPrinter{Printer: p}
var out bytes.Buffer
for _, x := range d.Content {
hp.block(&out, x)
}
return out.Bytes()
}
// block prints the block x to out.
func (p *htmlPrinter) block(out *bytes.Buffer, x Block) {
switch x := x.(type) {
default:
fmt.Fprintf(out, "?%T", x)
case *Paragraph:
if !p.tight {
out.WriteString("<p>")
}
p.text(out, x.Text)
out.WriteString("\n")
case *Heading:
out.WriteString("<h")
h := strconv.Itoa(p.headingLevel())
out.WriteString(h)
if id := p.headingID(x); id != "" {
out.WriteString(` id="`)
p.escape(out, id)
out.WriteString(`"`)
}
out.WriteString(">")
p.text(out, x.Text)
out.WriteString("</h")
out.WriteString(h)
out.WriteString(">\n")
case *Code:
out.WriteString("<pre>")
p.escape(out, x.Text)
out.WriteString("</pre>\n")
case *List:
kind := "ol>\n"
if x.Items[0].Number == "" {
kind = "ul>\n"
}
out.WriteString("<")
out.WriteString(kind)
next := "1"
for _, item := range x.Items {
out.WriteString("<li")
if n := item.Number; n != "" {
if n != next {
out.WriteString(` value="`)
out.WriteString(n)
out.WriteString(`"`)
next = n
}
next = inc(next)
}
out.WriteString(">")
p.tight = !x.BlankBetween()
for _, blk := range item.Content {
p.block(out, blk)
}
p.tight = false
}
out.WriteString("</")
out.WriteString(kind)
}
}
// inc increments the decimal string s.
// For example, inc("1199") == "1200".
func inc(s string) string {
b := []byte(s)
for i := len(b) - 1; i >= 0; i-- {
if b[i] < '9' {
b[i]++
return string(b)
}
b[i] = '0'
}
return "1" + string(b)
}
// text prints the text sequence x to out.
func (p *htmlPrinter) text(out *bytes.Buffer, x []Text) {
for _, t := range x {
switch t := t.(type) {
case Plain:
p.escape(out, string(t))
case Italic:
out.WriteString("<i>")
p.escape(out, string(t))
out.WriteString("</i>")
case *Link:
out.WriteString(`<a href="`)
p.escape(out, t.URL)
out.WriteString(`">`)
p.text(out, t.Text)
out.WriteString("</a>")
case *DocLink:
url := p.docLinkURL(t)
if url != "" {
out.WriteString(`<a href="`)
p.escape(out, url)
out.WriteString(`">`)
}
p.text(out, t.Text)
if url != "" {
out.WriteString("</a>")
}
}
}
}
// escape prints s to out as plain text,
// escaping < & " ' and > to avoid being misinterpreted
// in larger HTML constructs.
func (p *htmlPrinter) escape(out *bytes.Buffer, s string) {
start := 0
for i := 0; i < len(s); i++ {
switch s[i] {
case '<':
out.WriteString(s[start:i])
out.WriteString("&lt;")
start = i + 1
case '&':
out.WriteString(s[start:i])
out.WriteString("&amp;")
start = i + 1
case '"':
out.WriteString(s[start:i])
out.WriteString("&quot;")
start = i + 1
case '\'':
out.WriteString(s[start:i])
out.WriteString("&#39;")
start = i + 1
case '>':
out.WriteString(s[start:i])
out.WriteString("&gt;")
start = i + 1
}
}
out.WriteString(s[start:])
}
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package comment
import (
"bytes"
"fmt"
"strings"
)
// An mdPrinter holds the state needed for printing a Doc as Markdown.
type mdPrinter struct {
*Printer
headingPrefix string
raw bytes.Buffer
}
// Markdown returns a Markdown formatting of the Doc.
// See the [Printer] documentation for ways to customize the Markdown output.
func (p *Printer) Markdown(d *Doc) []byte {
mp := &mdPrinter{
Printer: p,
headingPrefix: strings.Repeat("#", p.headingLevel()) + " ",
}
var out bytes.Buffer
for i, x := range d.Content {
if i > 0 {
out.WriteByte('\n')
}
mp.block(&out, x)
}
return out.Bytes()
}
// block prints the block x to out.
func (p *mdPrinter) block(out *bytes.Buffer, x Block) {
switch x := x.(type) {
default:
fmt.Fprintf(out, "?%T", x)
case *Paragraph:
p.text(out, x.Text)
out.WriteString("\n")
case *Heading:
out.WriteString(p.headingPrefix)
p.text(out, x.Text)
if id := p.headingID(x); id != "" {
out.WriteString(" {#")
out.WriteString(id)
out.WriteString("}")
}
out.WriteString("\n")
case *Code:
md := x.Text
for md != "" {
var line string
line, md, _ = strings.Cut(md, "\n")
if line != "" {
out.WriteString("\t")
out.WriteString(line)
}
out.WriteString("\n")
}
case *List:
loose := x.BlankBetween()
for i, item := range x.Items {
if i > 0 && loose {
out.WriteString("\n")
}
if n := item.Number; n != "" {
out.WriteString(" ")
out.WriteString(n)
out.WriteString(". ")
} else {
out.WriteString(" - ") // SP SP - SP
}
for i, blk := range item.Content {
const fourSpace = "    "
if i > 0 {
out.WriteString("\n" + fourSpace)
}
p.text(out, blk.(*Paragraph).Text)
out.WriteString("\n")
}
}
}
}
// text prints the text sequence x to out.
func (p *mdPrinter) text(out *bytes.Buffer, x []Text) {
p.raw.Reset()
p.rawText(&p.raw, x)
line := bytes.TrimSpace(p.raw.Bytes())
if len(line) == 0 {
return
}
switch line[0] {
case '+', '-', '*', '#':
// Escape what would be the start of an unordered list or heading.
out.WriteByte('\\')
case '0', '1', '2', '3', '4', '5', '6', '7', '8', '9':
i := 1
for i < len(line) && '0' <= line[i] && line[i] <= '9' {
i++
}
if i < len(line) && (line[i] == '.' || line[i] == ')') {
// Escape what would be the start of an ordered list.
out.Write(line[:i])
out.WriteByte('\\')
line = line[i:]
}
}
out.Write(line)
}
// rawText prints the text sequence x to out,
// without worrying about escaping characters
// that have special meaning at the start of a Markdown line.
func (p *mdPrinter) rawText(out *bytes.Buffer, x []Text) {
for _, t := range x {
switch t := t.(type) {
case Plain:
p.escape(out, string(t))
case Italic:
out.WriteString("*")
p.escape(out, string(t))
out.WriteString("*")
case *Link:
out.WriteString("[")
p.rawText(out, t.Text)
out.WriteString("](")
out.WriteString(t.URL)
out.WriteString(")")
case *DocLink:
url := p.docLinkURL(t)
if url != "" {
out.WriteString("[")
}
p.rawText(out, t.Text)
if url != "" {
out.WriteString("](")
url = strings.ReplaceAll(url, "(", "%28")
url = strings.ReplaceAll(url, ")", "%29")
out.WriteString(url)
out.WriteString(")")
}
}
}
}
// escape prints s to out as plain text,
// escaping special characters to avoid being misinterpreted
// as Markdown markup sequences.
func (p *mdPrinter) escape(out *bytes.Buffer, s string) {
start := 0
for i := 0; i < len(s); i++ {
switch s[i] {
case '\n':
// Turn all \n into spaces, for a few reasons:
// - Avoid introducing paragraph breaks accidentally.
// - Avoid the need to reindent after the newline.
// - Avoid problems with Markdown renderers treating
// every mid-paragraph newline as a <br>.
out.WriteString(s[start:i])
out.WriteByte(' ')
start = i + 1
continue
case '`', '_', '*', '[', '<', '\\':
// Not all of these need to be escaped all the time,
// but it is valid and easy to do so.
// We assume the Markdown is being passed to a
// Markdown renderer, not edited by a person,
// so it's fine to have escapes that are not strictly
// necessary in some cases.
out.WriteString(s[start:i])
out.WriteByte('\\')
out.WriteByte(s[i])
start = i + 1
}
}
out.WriteString(s[start:])
}
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package comment
import (
"slices"
"strings"
"unicode"
"unicode/utf8"
)
// A Doc is a parsed Go doc comment.
type Doc struct {
// Content is the sequence of content blocks in the comment.
Content []Block
// Links is the link definitions in the comment.
Links []*LinkDef
}
// A LinkDef is a single link definition.
type LinkDef struct {
Text string // the link text
URL string // the link URL
Used bool // whether the comment uses the definition
}
// A Block is block-level content in a doc comment,
// one of [*Code], [*Heading], [*List], or [*Paragraph].
type Block interface {
block()
}
// A Heading is a doc comment heading.
type Heading struct {
Text []Text // the heading text
}
func (*Heading) block() {}
// A List is a numbered or bullet list.
// Lists are always non-empty: len(Items) > 0.
// In a numbered list, every Items[i].Number is a non-empty string.
// In a bullet list, every Items[i].Number is an empty string.
type List struct {
// Items is the list items.
Items []*ListItem
// ForceBlankBefore indicates that the list must be
// preceded by a blank line when reformatting the comment,
// overriding the usual conditions. See the BlankBefore method.
//
// The comment parser sets ForceBlankBefore for any list
// that is preceded by a blank line, to make sure
// the blank line is preserved when printing.
ForceBlankBefore bool
// ForceBlankBetween indicates that list items must be
// separated by blank lines when reformatting the comment,
// overriding the usual conditions. See the BlankBetween method.
//
// The comment parser sets ForceBlankBetween for any list
// that has a blank line between any two of its items, to make sure
// the blank lines are preserved when printing.
ForceBlankBetween bool
}
func (*List) block() {}
// BlankBefore reports whether a reformatting of the comment
// should include a blank line before the list.
// The default rule is the same as for [BlankBetween]:
// if the list item content contains any blank lines
// (meaning at least one item has multiple paragraphs)
// then the list itself must be preceded by a blank line.
// A preceding blank line can be forced by setting [List].ForceBlankBefore.
func (l *List) BlankBefore() bool {
return l.ForceBlankBefore || l.BlankBetween()
}
// BlankBetween reports whether a reformatting of the comment
// should include a blank line between each pair of list items.
// The default rule is that if the list item content contains any blank lines
// (meaning at least one item has multiple paragraphs)
// then list items must themselves be separated by blank lines.
// Blank line separators can be forced by setting [List].ForceBlankBetween.
func (l *List) BlankBetween() bool {
if l.ForceBlankBetween {
return true
}
for _, item := range l.Items {
if len(item.Content) != 1 {
// Unreachable for parsed comments today,
// since the only way to get multiple item.Content
// is multiple paragraphs, which must have been
// separated by a blank line.
return true
}
}
return false
}
// A ListItem is a single item in a numbered or bullet list.
type ListItem struct {
// Number is a decimal string in a numbered list
// or an empty string in a bullet list.
Number string // "1", "2", ...; "" for bullet list
// Content is the list content.
// Currently, restrictions in the parser and printer
// require every element of Content to be a *Paragraph.
Content []Block // Content of this item.
}
// A Paragraph is a paragraph of text.
type Paragraph struct {
Text []Text
}
func (*Paragraph) block() {}
// A Code is a preformatted code block.
type Code struct {
// Text is the preformatted text, ending with a newline character.
// It may be multiple lines, each of which ends with a newline character.
// It is never empty, nor does it start or end with a blank line.
Text string
}
func (*Code) block() {}
// A Text is text-level content in a doc comment,
// one of [Plain], [Italic], [*Link], or [*DocLink].
type Text interface {
text()
}
// A Plain is a string rendered as plain text (not italicized).
type Plain string
func (Plain) text() {}
// An Italic is a string rendered as italicized text.
type Italic string
func (Italic) text() {}
// A Link is a link to a specific URL.
type Link struct {
Auto bool // is this an automatic (implicit) link of a literal URL?
Text []Text // text of link
URL string // target URL of link
}
func (*Link) text() {}
// A DocLink is a link to documentation for a Go package or symbol.
type DocLink struct {
Text []Text // text of link
// ImportPath, Recv, and Name identify the Go package or symbol
// that is the link target. The potential combinations of
// non-empty fields are:
// - ImportPath: a link to another package
// - ImportPath, Name: a link to a const, func, type, or var in another package
// - ImportPath, Recv, Name: a link to a method in another package
// - Name: a link to a const, func, type, or var in this package
// - Recv, Name: a link to a method in this package
ImportPath string // import path
Recv string // receiver type, without any pointer star, for methods
Name string // const, func, type, var, or method name
}
func (*DocLink) text() {}
// A Parser is a doc comment parser.
// The fields in the struct can be filled in before calling [Parser.Parse]
// in order to customize the details of the parsing process.
type Parser struct {
// Words is a map of Go identifier words that
// should be italicized and potentially linked.
// If Words[w] is the empty string, then the word w
// is only italicized. Otherwise it is linked, using
// Words[w] as the link target.
// Words corresponds to the [go/doc.ToHTML] words parameter.
Words map[string]string
// LookupPackage resolves a package name to an import path.
//
// If LookupPackage(name) returns ok == true, then [name]
// (or [name.Sym] or [name.Sym.Method])
// is considered a documentation link to importPath's package docs.
// It is valid to return "", true, in which case name is considered
// to refer to the current package.
//
// If LookupPackage(name) returns ok == false,
// then [name] (or [name.Sym] or [name.Sym.Method])
// will not be considered a documentation link,
// except in the case where name is the full (but single-element) import path
// of a package in the standard library, such as in [math] or [io.Reader].
// LookupPackage is still called for such names,
// in order to permit references to imports of other packages
// with the same package names.
//
// Setting LookupPackage to nil is equivalent to setting it to
// a function that always returns "", false.
LookupPackage func(name string) (importPath string, ok bool)
// LookupSym reports whether a symbol name or method name
// exists in the current package.
//
// If LookupSym("", "Name") returns true, then [Name]
// is considered a documentation link for a const, func, type, or var.
//
// Similarly, if LookupSym("Recv", "Name") returns true,
// then [Recv.Name] is considered a documentation link for
// type Recv's method Name.
//
// Setting LookupSym to nil is equivalent to setting it to a function
// that always returns false.
LookupSym func(recv, name string) (ok bool)
}
// parseDoc is parsing state for a single doc comment.
type parseDoc struct {
*Parser
*Doc
links map[string]*LinkDef
lines []string
lookupSym func(recv, name string) bool
}
// lookupPkg is called to look up the pkg in [pkg], [pkg.Name], and [pkg.Name.Recv].
// If pkg has a slash, it is assumed to be the full import path and is returned with ok = true.
//
// Otherwise, pkg is probably a simple package name like "rand" (not "crypto/rand" or "math/rand").
// d.LookupPackage provides a way for the caller to allow resolving such names with reference
// to the imports in the surrounding package.
//
// There is one collision between these two cases: single-element standard library names
// like "math" are full import paths but don't contain slashes. We let d.LookupPackage have
// the first chance to resolve it, in case there's a different package imported as math,
// and otherwise we refer to a built-in list of single-element standard library package names.
func (d *parseDoc) lookupPkg(pkg string) (importPath string, ok bool) {
if strings.Contains(pkg, "/") { // assume a full import path
if validImportPath(pkg) {
return pkg, true
}
return "", false
}
if d.LookupPackage != nil {
// Give LookupPackage a chance.
if path, ok := d.LookupPackage(pkg); ok {
return path, true
}
}
return DefaultLookupPackage(pkg)
}
func isStdPkg(path string) bool {
_, ok := slices.BinarySearch(stdPkgs, path)
return ok
}
// DefaultLookupPackage is the default package lookup
// function, used when [Parser.LookupPackage] is nil.
// It recognizes names of the packages from the standard
// library with single-element import paths, such as math,
// which would otherwise be impossible to name.
//
// Note that the go/doc package provides a more sophisticated
// lookup based on the imports used in the current package.
func DefaultLookupPackage(name string) (importPath string, ok bool) {
if isStdPkg(name) {
return name, true
}
return "", false
}
// Parse parses the doc comment text and returns the *[Doc] form.
// Comment markers (/* // and */) in the text must have already been removed.
func (p *Parser) Parse(text string) *Doc {
lines := unindent(strings.Split(text, "\n"))
d := &parseDoc{
Parser: p,
Doc: new(Doc),
links: make(map[string]*LinkDef),
lines: lines,
lookupSym: func(recv, name string) bool { return false },
}
if p.LookupSym != nil {
d.lookupSym = p.LookupSym
}
// First pass: break into block structure and collect known links.
// The text is all recorded as Plain for now.
var prev span
for _, s := range parseSpans(lines) {
var b Block
switch s.kind {
default:
panic("go/doc/comment: internal error: unknown span kind")
case spanList:
b = d.list(lines[s.start:s.end], prev.end < s.start)
case spanCode:
b = d.code(lines[s.start:s.end])
case spanOldHeading:
b = d.oldHeading(lines[s.start])
case spanHeading:
b = d.heading(lines[s.start])
case spanPara:
b = d.paragraph(lines[s.start:s.end])
}
if b != nil {
d.Content = append(d.Content, b)
}
prev = s
}
// Second pass: interpret all the Plain text now that we know the links.
for _, b := range d.Content {
switch b := b.(type) {
case *Paragraph:
b.Text = d.parseLinkedText(string(b.Text[0].(Plain)))
case *List:
for _, i := range b.Items {
for _, c := range i.Content {
p := c.(*Paragraph)
p.Text = d.parseLinkedText(string(p.Text[0].(Plain)))
}
}
}
}
return d.Doc
}
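// As an informal usage sketch from a client package that imports
// go/doc/comment (the input text is made up for illustration):
//
//	var p comment.Parser
//	d := p.Parse("Hello, world.\n\n# Usage\n\nCall Run to begin.\n")
//	// d.Content now holds a *Paragraph, a *Heading, and another *Paragraph.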
// A span represents a single span of comment lines (lines[start:end])
// of an identified kind (code, heading, paragraph, and so on).
type span struct {
start int
end int
kind spanKind
}
// A spanKind describes the kind of span.
type spanKind int
const (
_ spanKind = iota
spanCode
spanHeading
spanList
spanOldHeading
spanPara
)
func parseSpans(lines []string) []span {
var spans []span
// The loop may process a line twice: once as unindented
// and again forced indented. So the maximum expected
// number of iterations is 2*len(lines). The repeating logic
// can be subtle, though, and to protect against introduction
// of infinite loops in future changes, we watch to see that
// we are not looping too much. A panic is better than a
// quiet infinite loop.
watchdog := 2 * len(lines)
i := 0
forceIndent := 0
Spans:
for {
// Skip blank lines.
for i < len(lines) && lines[i] == "" {
i++
}
if i >= len(lines) {
break
}
if watchdog--; watchdog < 0 {
panic("go/doc/comment: internal error: not making progress")
}
var kind spanKind
start := i
end := i
if i < forceIndent || indented(lines[i]) {
// Indented (or force indented).
// Ends before next unindented. (Blank lines are OK.)
// If this is an unindented list that we are heuristically treating as indented,
// then accept unindented list item lines up to the first blank lines.
// The heuristic is disabled at blank lines to contain its effect
// to non-gofmt'ed sections of the comment.
unindentedListOK := isList(lines[i]) && i < forceIndent
i++
for i < len(lines) && (lines[i] == "" || i < forceIndent || indented(lines[i]) || (unindentedListOK && isList(lines[i]))) {
if lines[i] == "" {
unindentedListOK = false
}
i++
}
// Drop trailing blank lines.
end = i
for end > start && lines[end-1] == "" {
end--
}
// If indented lines are followed (without a blank line)
// by an unindented line ending in a brace,
// take that one line too. This fixes the common mistake
// of pasting in something like
//
// func main() {
// fmt.Println("hello, world")
// }
//
// and forgetting to indent it.
// The heuristic will never trigger on a gofmt'ed comment,
// because any gofmt'ed code block or list would be
// followed by a blank line or end of comment.
if end < len(lines) && strings.HasPrefix(lines[end], "}") {
end++
}
if isList(lines[start]) {
kind = spanList
} else {
kind = spanCode
}
} else {
// Unindented. Ends at next blank or indented line.
i++
for i < len(lines) && lines[i] != "" && !indented(lines[i]) {
i++
}
end = i
// If unindented lines are followed (without a blank line)
// by an indented line that would start a code block,
// check whether the final unindented lines
// should be left for the indented section.
// This can happen for the common mistakes of
// unindented code or unindented lists.
// The heuristic will never trigger on a gofmt'ed comment,
// because any gofmt'ed code block would have a blank line
// preceding it after the unindented lines.
if i < len(lines) && lines[i] != "" && !isList(lines[i]) {
switch {
case isList(lines[i-1]):
// If the final unindented line looks like a list item,
// this may be the first indented line wrap of
// a mistakenly unindented list.
// Leave all the unindented list items.
forceIndent = end
end--
for end > start && isList(lines[end-1]) {
end--
}
case strings.HasSuffix(lines[i-1], "{") || strings.HasSuffix(lines[i-1], `\`):
// If the final unindented line ended in { or \
// it is probably the start of a misindented code block.
// Give the user a single line fix.
// Often that's enough; if not, the user can fix the others themselves.
forceIndent = end
end--
}
if start == end && forceIndent > start {
i = start
continue Spans
}
}
// Span is either paragraph or heading.
if end-start == 1 && isHeading(lines[start]) {
kind = spanHeading
} else if end-start == 1 && isOldHeading(lines[start], lines, start) {
kind = spanOldHeading
} else {
kind = spanPara
}
}
spans = append(spans, span{start, end, kind})
i = end
}
return spans
}
// indented reports whether line is indented
// (starts with a leading space or tab).
func indented(line string) bool {
return line != "" && (line[0] == ' ' || line[0] == '\t')
}
// unindent removes any common space/tab prefix
// from each line in lines, returning a copy of lines in which
// those prefixes have been trimmed from each line.
// It also replaces any lines containing only spaces with blank lines (empty strings).
func unindent(lines []string) []string {
// Trim leading and trailing blank lines.
for len(lines) > 0 && isBlank(lines[0]) {
lines = lines[1:]
}
for len(lines) > 0 && isBlank(lines[len(lines)-1]) {
lines = lines[:len(lines)-1]
}
if len(lines) == 0 {
return nil
}
// Compute and remove common indentation.
prefix := leadingSpace(lines[0])
for _, line := range lines[1:] {
if !isBlank(line) {
prefix = commonPrefix(prefix, leadingSpace(line))
}
}
out := make([]string, len(lines))
for i, line := range lines {
line = strings.TrimPrefix(line, prefix)
if strings.TrimSpace(line) == "" {
line = ""
}
out[i] = line
}
for len(out) > 0 && out[0] == "" {
out = out[1:]
}
for len(out) > 0 && out[len(out)-1] == "" {
out = out[:len(out)-1]
}
return out
}
// isBlank reports whether s is a blank line.
func isBlank(s string) bool {
return len(s) == 0 || (len(s) == 1 && s[0] == '\n')
}
// commonPrefix returns the longest common prefix of a and b.
func commonPrefix(a, b string) string {
i := 0
for i < len(a) && i < len(b) && a[i] == b[i] {
i++
}
return a[0:i]
}
// leadingSpace returns the longest prefix of s consisting of spaces and tabs.
func leadingSpace(s string) string {
i := 0
for i < len(s) && (s[i] == ' ' || s[i] == '\t') {
i++
}
return s[:i]
}
// isOldHeading reports whether line is an old-style section heading.
// line is all[off].
func isOldHeading(line string, all []string, off int) bool {
if off <= 0 || all[off-1] != "" || off+2 >= len(all) || all[off+1] != "" || leadingSpace(all[off+2]) != "" {
return false
}
line = strings.TrimSpace(line)
// a heading must start with an uppercase letter
r, _ := utf8.DecodeRuneInString(line)
if !unicode.IsLetter(r) || !unicode.IsUpper(r) {
return false
}
// it must end in a letter or digit:
r, _ = utf8.DecodeLastRuneInString(line)
if !unicode.IsLetter(r) && !unicode.IsDigit(r) {
return false
}
// exclude lines with illegal characters. we allow "(),"
if strings.ContainsAny(line, ";:!?+*/=[]{}_^°&§~%#@<\">\\") {
return false
}
// allow "'" for possessive "'s" only
for b := line; ; {
var ok bool
if _, b, ok = strings.Cut(b, "'"); !ok {
break
}
if b != "s" && !strings.HasPrefix(b, "s ") {
return false // ' not followed by s and then end-of-word
}
}
// allow "." when followed by non-space
for b := line; ; {
var ok bool
if _, b, ok = strings.Cut(b, "."); !ok {
break
}
if b == "" || strings.HasPrefix(b, " ") {
return false // not followed by non-space
}
}
return true
}
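// For illustration (assuming the blank-line requirements above are also met),
// "Security Considerations" and "Go 1.22" qualify as old-style headings,
// while lines that start with a lower-case letter, end in a period,
// or contain characters such as '?' or '!' do not.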
// oldHeading returns the *Heading for the given old-style section heading line.
func (d *parseDoc) oldHeading(line string) Block {
return &Heading{Text: []Text{Plain(strings.TrimSpace(line))}}
}
// isHeading reports whether line is a new-style section heading.
func isHeading(line string) bool {
return len(line) >= 2 &&
line[0] == '#' &&
(line[1] == ' ' || line[1] == '\t') &&
strings.TrimSpace(line) != "#"
}
// heading returns the *Heading for the given new-style section heading line.
func (d *parseDoc) heading(line string) Block {
return &Heading{Text: []Text{Plain(strings.TrimSpace(line[1:]))}}
}
// code returns a code block built from the lines.
func (d *parseDoc) code(lines []string) *Code {
body := unindent(lines)
body = append(body, "") // to get final \n from Join
return &Code{Text: strings.Join(body, "\n")}
}
// paragraph returns a paragraph block built from the lines.
// If the lines are link definitions, paragraph adds them to d and returns nil.
func (d *parseDoc) paragraph(lines []string) Block {
// Is this a block of known links? Handle.
var defs []*LinkDef
for _, line := range lines {
def, ok := parseLink(line)
if !ok {
goto NoDefs
}
defs = append(defs, def)
}
for _, def := range defs {
d.Links = append(d.Links, def)
if d.links[def.Text] == nil {
d.links[def.Text] = def
}
}
return nil
NoDefs:
return &Paragraph{Text: []Text{Plain(strings.Join(lines, "\n"))}}
}
// parseLink parses a single link definition line:
//
// [text]: url
//
// It returns the link definition and whether the line was well formed.
func parseLink(line string) (*LinkDef, bool) {
if line == "" || line[0] != '[' {
return nil, false
}
i := strings.Index(line, "]:")
if i < 0 || i+3 >= len(line) || (line[i+2] != ' ' && line[i+2] != '\t') {
return nil, false
}
text := line[1:i]
url := strings.TrimSpace(line[i+3:])
j := strings.Index(url, "://")
if j < 0 || !isScheme(url[:j]) {
return nil, false
}
// Line has right form and has valid scheme://.
// That's good enough for us - we are not as picky
// about the characters beyond the :// as we are
// when extracting inline URLs from text.
return &LinkDef{Text: text, URL: url}, true
}
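// For illustration (not exhaustive):
//
//	parseLink("[Go home page]: https://go.dev")  → &LinkDef{Text: "Go home page", URL: "https://go.dev"}, true
//	parseLink("[Go home page]:https://go.dev")   → nil, false (no space or tab after "]:")
//	parseLink("[issue tracker]: go.dev/issue")   → nil, false (no recognized scheme://)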
// list returns a list built from the indented lines,
// using forceBlankBefore as the value of the List's ForceBlankBefore field.
func (d *parseDoc) list(lines []string, forceBlankBefore bool) *List {
num, _, _ := listMarker(lines[0])
var (
list *List = &List{ForceBlankBefore: forceBlankBefore}
item *ListItem
text []string
)
flush := func() {
if item != nil {
if para := d.paragraph(text); para != nil {
item.Content = append(item.Content, para)
}
}
text = nil
}
for _, line := range lines {
if n, after, ok := listMarker(line); ok && (n != "") == (num != "") {
// start new list item
flush()
item = &ListItem{Number: n}
list.Items = append(list.Items, item)
line = after
}
line = strings.TrimSpace(line)
if line == "" {
list.ForceBlankBetween = true
flush()
continue
}
text = append(text, strings.TrimSpace(line))
}
flush()
return list
}
// listMarker parses the line as beginning with a list marker.
// If it can do that, it returns the numeric marker ("" for a bullet list),
// the rest of the line, and ok == true.
// Otherwise, it returns "", "", false.
func listMarker(line string) (num, rest string, ok bool) {
line = strings.TrimSpace(line)
if line == "" {
return "", "", false
}
// Can we find a marker?
if r, n := utf8.DecodeRuneInString(line); r == '•' || r == '*' || r == '+' || r == '-' {
num, rest = "", line[n:]
} else if '0' <= line[0] && line[0] <= '9' {
n := 1
for n < len(line) && '0' <= line[n] && line[n] <= '9' {
n++
}
if n >= len(line) || (line[n] != '.' && line[n] != ')') {
return "", "", false
}
num, rest = line[:n], line[n+1:]
} else {
return "", "", false
}
if !indented(rest) || strings.TrimSpace(rest) == "" {
return "", "", false
}
return num, rest, true
}
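// For illustration (not exhaustive):
//
//	listMarker("  - item")    → "", " item", true (bullet list)
//	listMarker("  3) third")  → "3", " third", true (numbered list)
//	listMarker("  -item")     → "", "", false (no space after the marker)
//	listMarker("  word")      → "", "", false (no marker)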
// isList reports whether the line is the first line of a list,
// meaning starts with a list marker after any indentation.
// (The caller is responsible for checking the line is indented, as appropriate.)
func isList(line string) bool {
_, _, ok := listMarker(line)
return ok
}
// parseLinkedText parses text that is allowed to contain explicit links,
// such as [math.Sin] or [Go home page], into a slice of Text items.
//
// A “pkg” is only assumed to be a full import path if it starts with
// a domain name (a path element with a dot) or is one of the packages
// from the standard library (“[os]”, “[encoding/json]”, and so on).
// To avoid problems with maps, generics, and array types, doc links
// must be both preceded and followed by punctuation, spaces, tabs,
// or the start or end of a line. An example problem would be treating
// map[ast.Expr]TypeAndValue as containing a link.
func (d *parseDoc) parseLinkedText(text string) []Text {
var out []Text
wrote := 0
flush := func(i int) {
if wrote < i {
out = d.parseText(out, text[wrote:i], true)
wrote = i
}
}
start := -1
var buf []byte
for i := 0; i < len(text); i++ {
c := text[i]
if c == '\n' || c == '\t' {
c = ' '
}
switch c {
case '[':
start = i
case ']':
if start >= 0 {
if def, ok := d.links[string(buf)]; ok {
def.Used = true
flush(start)
out = append(out, &Link{
Text: d.parseText(nil, text[start+1:i], false),
URL: def.URL,
})
wrote = i + 1
} else if link, ok := d.docLink(text[start+1:i], text[:start], text[i+1:]); ok {
flush(start)
link.Text = d.parseText(nil, text[start+1:i], false)
out = append(out, link)
wrote = i + 1
}
}
start = -1
buf = buf[:0]
}
if start >= 0 && i != start {
buf = append(buf, c)
}
}
flush(len(text))
return out
}
// docLink parses text, which was found inside [ ] brackets,
// as a doc link if possible, returning the DocLink and ok == true
// or else nil, false.
// The before and after strings are the text before the [ and after the ]
// on the same line. Doc links must be preceded and followed by
// punctuation, spaces, tabs, or the start or end of a line.
func (d *parseDoc) docLink(text, before, after string) (link *DocLink, ok bool) {
if before != "" {
r, _ := utf8.DecodeLastRuneInString(before)
if !unicode.IsPunct(r) && r != ' ' && r != '\t' && r != '\n' {
return nil, false
}
}
if after != "" {
r, _ := utf8.DecodeRuneInString(after)
if !unicode.IsPunct(r) && r != ' ' && r != '\t' && r != '\n' {
return nil, false
}
}
text = strings.TrimPrefix(text, "*")
pkg, name, ok := splitDocName(text)
var recv string
if ok {
pkg, recv, _ = splitDocName(pkg)
}
if pkg != "" {
if pkg, ok = d.lookupPkg(pkg); !ok {
return nil, false
}
} else {
if ok = d.lookupSym(recv, name); !ok {
return nil, false
}
}
link = &DocLink{
ImportPath: pkg,
Recv: recv,
Name: name,
}
return link, true
}
// If text is of the form before.Name, where Name is a capitalized Go identifier,
// then splitDocName returns before, name, true.
// Otherwise it returns text, "", false.
func splitDocName(text string) (before, name string, foundDot bool) {
i := strings.LastIndex(text, ".")
name = text[i+1:]
if !isName(name) {
return text, "", false
}
if i >= 0 {
before = text[:i]
}
return before, name, true
}
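// For illustration (not exhaustive):
//
//	splitDocName("Sin")               → "", "Sin", true
//	splitDocName("math.Sin")          → "math", "Sin", true
//	splitDocName("bytes.Buffer.Len")  → "bytes.Buffer", "Len", true
//	splitDocName("math.pi")           → "math.pi", "", false (name not capitalized)
//
// docLink calls splitDocName twice, so "[bytes.Buffer.Len]" resolves to
// ImportPath "bytes", Recv "Buffer", Name "Len", assuming the package lookup succeeds.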
// parseText parses s as text and returns the result of appending
// those parsed Text elements to out.
// parseText does not handle explicit links like [math.Sin] or [Go home page]:
// those are handled by parseLinkedText.
// If autoLink is true, then parseText recognizes URLs and words from d.Words
// and converts those to links as appropriate.
func (d *parseDoc) parseText(out []Text, s string, autoLink bool) []Text {
var w strings.Builder
wrote := 0
writeUntil := func(i int) {
w.WriteString(s[wrote:i])
wrote = i
}
flush := func(i int) {
writeUntil(i)
if w.Len() > 0 {
out = append(out, Plain(w.String()))
w.Reset()
}
}
for i := 0; i < len(s); {
t := s[i:]
if autoLink {
if url, ok := autoURL(t); ok {
flush(i)
// Note: The old comment parser would look up the URL in words
// and replace the target with words[URL] if it was non-empty.
// That would allow creating links that display as one URL but
// when clicked go to a different URL. Not sure what the point
// of that is, so we're not doing that lookup here.
out = append(out, &Link{Auto: true, Text: []Text{Plain(url)}, URL: url})
i += len(url)
wrote = i
continue
}
if id, ok := ident(t); ok {
url, italics := d.Words[id]
if !italics {
i += len(id)
continue
}
flush(i)
if url == "" {
out = append(out, Italic(id))
} else {
out = append(out, &Link{Auto: true, Text: []Text{Italic(id)}, URL: url})
}
i += len(id)
wrote = i
continue
}
}
switch {
case strings.HasPrefix(t, "``"):
if len(t) >= 3 && t[2] == '`' {
// Do not convert `` inside ```, in case people are mistakenly writing Markdown.
i += 3
for i < len(t) && t[i] == '`' {
i++
}
break
}
writeUntil(i)
w.WriteRune('“')
i += 2
wrote = i
case strings.HasPrefix(t, "''"):
writeUntil(i)
w.WriteRune('”')
i += 2
wrote = i
default:
i++
}
}
flush(len(s))
return out
}
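// For illustration, parseText turns Go quote syntax into Unicode quotes:
// the input "``hello, world''" renders as “hello, world”,
// while a ``` run is left alone to avoid mangling accidental Markdown.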
// autoURL checks whether s begins with a URL that should be hyperlinked.
// If so, it returns the URL, which is a prefix of s, and ok == true.
// Otherwise it returns "", false.
// The caller should skip over the first len(url) bytes of s
// before further processing.
func autoURL(s string) (url string, ok bool) {
// Find the ://. Fast path to pick off non-URL,
// since we call this at every position in the string.
// The shortest possible URL is ftp://x, 7 bytes.
var i int
switch {
case len(s) < 7:
return "", false
case s[3] == ':':
i = 3
case s[4] == ':':
i = 4
case s[5] == ':':
i = 5
case s[6] == ':':
i = 6
default:
return "", false
}
if i+3 > len(s) || s[i:i+3] != "://" {
return "", false
}
// Check valid scheme.
if !isScheme(s[:i]) {
return "", false
}
// Scan host part. Must have at least one byte,
// and must start and end in non-punctuation.
i += 3
if i >= len(s) || !isHost(s[i]) || isPunct(s[i]) {
return "", false
}
i++
end := i
for i < len(s) && isHost(s[i]) {
if !isPunct(s[i]) {
end = i + 1
}
i++
}
i = end
// At this point we are definitely returning a URL (scheme://host).
// We just have to find the longest path we can add to it.
// Heuristics abound.
// We allow parens, braces, and brackets,
// but only if they match (#5043, #22285).
// We allow .,:;?! in the path but not at the end,
// to avoid end-of-sentence punctuation (#18139, #16565).
stk := []byte{}
end = i
Path:
for ; i < len(s); i++ {
if isPunct(s[i]) {
continue
}
if !isPath(s[i]) {
break
}
switch s[i] {
case '(':
stk = append(stk, ')')
case '{':
stk = append(stk, '}')
case '[':
stk = append(stk, ']')
case ')', '}', ']':
if len(stk) == 0 || stk[len(stk)-1] != s[i] {
break Path
}
stk = stk[:len(stk)-1]
}
if len(stk) == 0 {
end = i + 1
}
}
return s[:end], true
}
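// For illustration (not exhaustive):
//
//	autoURL("https://go.dev/doc/. Next")   → "https://go.dev/doc/", true (sentence-ending '.' dropped)
//	autoURL("https://go.dev/s/(match) x")  → "https://go.dev/s/(match)", true (matched parens kept)
//	autoURL("https://go.dev/s/(open x")    → "https://go.dev/s/", true (unmatched paren and what follows dropped)
//	autoURL("see https://go.dev")          → "", false (the URL must start at the beginning of s)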
// isScheme reports whether s is a recognized URL scheme.
// Note that if strings of new length (beyond 3-7)
// are added here, the fast path at the top of autoURL will need updating.
func isScheme(s string) bool {
switch s {
case "file",
"ftp",
"gopher",
"http",
"https",
"mailto",
"nntp":
return true
}
return false
}
// isHost reports whether c is a byte that can appear in a URL host,
// like www.example.com or user@[::1]:8080.
func isHost(c byte) bool {
// mask is a 128-bit bitmap with 1s for allowed bytes,
// so that the byte c can be tested with a shift and an and.
// If c > 128, then 1<<c and 1<<(c-64) will both be zero,
// and this function will return false.
const mask = 0 |
(1<<26-1)<<'A' |
(1<<26-1)<<'a' |
(1<<10-1)<<'0' |
1<<'_' |
1<<'@' |
1<<'-' |
1<<'.' |
1<<'[' |
1<<']' |
1<<':'
return ((uint64(1)<<c)&(mask&(1<<64-1)) |
(uint64(1)<<(c-64))&(mask>>64)) != 0
}
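// For illustration: for c = 'a' (97), bit 97 of mask is set by (1<<26-1)<<'a';
// the low-word term uint64(1)<<97 is 0 (shift count ≥ 64), while the high-word
// term (uint64(1)<<(97-64)) & (mask>>64) is nonzero, so isHost('a') is true.
// For c = '/' (47), neither word has bit 47 set, so isHost('/') is false.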
// isPunct reports whether c is a punctuation byte that can appear
// inside a path but not at the end.
func isPunct(c byte) bool {
// mask is a 128-bit bitmap with 1s for allowed bytes,
// so that the byte c can be tested with a shift and an and.
// If c > 128, then 1<<c and 1<<(c-64) will both be zero,
// and this function will return false.
const mask = 0 |
1<<'.' |
1<<',' |
1<<':' |
1<<';' |
1<<'?' |
1<<'!'
return ((uint64(1)<<c)&(mask&(1<<64-1)) |
(uint64(1)<<(c-64))&(mask>>64)) != 0
}
// isPath reports whether c is a (non-punctuation) path byte.
func isPath(c byte) bool {
// mask is a 128-bit bitmap with 1s for allowed bytes,
// so that the byte c can be tested with a shift and an and.
// If c > 128, then 1<<c and 1<<(c-64) will both be zero,
// and this function will return false.
const mask = 0 |
(1<<26-1)<<'A' |
(1<<26-1)<<'a' |
(1<<10-1)<<'0' |
1<<'$' |
1<<'\'' |
1<<'(' |
1<<')' |
1<<'*' |
1<<'+' |
1<<'&' |
1<<'#' |
1<<'=' |
1<<'@' |
1<<'~' |
1<<'_' |
1<<'/' |
1<<'-' |
1<<'[' |
1<<']' |
1<<'{' |
1<<'}' |
1<<'%'
return ((uint64(1)<<c)&(mask&(1<<64-1)) |
(uint64(1)<<(c-64))&(mask>>64)) != 0
}
// isName reports whether s is a capitalized Go identifier (like Name).
func isName(s string) bool {
t, ok := ident(s)
if !ok || t != s {
return false
}
r, _ := utf8.DecodeRuneInString(s)
return unicode.IsUpper(r)
}
// ident checks whether s begins with a Go identifier.
// If so, it returns the identifier, which is a prefix of s, and ok == true.
// Otherwise it returns "", false.
// The caller should skip over the first len(id) bytes of s
// before further processing.
func ident(s string) (id string, ok bool) {
// Scan [\pL_][\pL_0-9]*
n := 0
for n < len(s) {
if c := s[n]; c < utf8.RuneSelf {
if isIdentASCII(c) && (n > 0 || c < '0' || c > '9') {
n++
continue
}
break
}
r, nr := utf8.DecodeRuneInString(s[n:])
if unicode.IsLetter(r) {
n += nr
continue
}
break
}
return s[:n], n > 0
}
// isIdentASCII reports whether c is an ASCII identifier byte.
func isIdentASCII(c byte) bool {
// mask is a 128-bit bitmap with 1s for allowed bytes,
// so that the byte c can be tested with a shift and an and.
// If c > 128, then 1<<c and 1<<(c-64) will both be zero,
// and this function will return false.
const mask = 0 |
(1<<26-1)<<'A' |
(1<<26-1)<<'a' |
(1<<10-1)<<'0' |
1<<'_'
return ((uint64(1)<<c)&(mask&(1<<64-1)) |
(uint64(1)<<(c-64))&(mask>>64)) != 0
}
// validImportPath reports whether path is a valid import path.
// It is a lightly edited copy of golang.org/x/mod/module.CheckImportPath.
func validImportPath(path string) bool {
if !utf8.ValidString(path) {
return false
}
if path == "" {
return false
}
if path[0] == '-' {
return false
}
if strings.Contains(path, "//") {
return false
}
if path[len(path)-1] == '/' {
return false
}
elemStart := 0
for i, r := range path {
if r == '/' {
if !validImportPathElem(path[elemStart:i]) {
return false
}
elemStart = i + 1
}
}
return validImportPathElem(path[elemStart:])
}
func validImportPathElem(elem string) bool {
if elem == "" || elem[0] == '.' || elem[len(elem)-1] == '.' {
return false
}
for i := 0; i < len(elem); i++ {
if !importPathOK(elem[i]) {
return false
}
}
return true
}
func importPathOK(c byte) bool {
// mask is a 128-bit bitmap with 1s for allowed bytes,
// so that the byte c can be tested with a shift and an and.
// If c > 128, then 1<<c and 1<<(c-64) will both be zero,
// and this function will return false.
const mask = 0 |
(1<<26-1)<<'A' |
(1<<26-1)<<'a' |
(1<<10-1)<<'0' |
1<<'-' |
1<<'.' |
1<<'~' |
1<<'_' |
1<<'+'
return ((uint64(1)<<c)&(mask&(1<<64-1)) |
(uint64(1)<<(c-64))&(mask>>64)) != 0
}
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package comment
import (
"bytes"
"fmt"
"strings"
)
// A Printer is a doc comment printer.
// The fields in the struct can be filled in before calling
// any of the printing methods
// in order to customize the details of the printing process.
type Printer struct {
// HeadingLevel is the nesting level used for
// HTML and Markdown headings.
// If HeadingLevel is zero, it defaults to level 3,
// meaning to use <h3> and ###.
HeadingLevel int
// HeadingID is a function that computes the heading ID
// (anchor tag) to use for the heading h when generating
// HTML and Markdown. If HeadingID returns an empty string,
// then the heading ID is omitted.
// If HeadingID is nil, h.DefaultID is used.
HeadingID func(h *Heading) string
// DocLinkURL is a function that computes the URL for the given DocLink.
// If DocLinkURL is nil, then link.DefaultURL(p.DocLinkBaseURL) is used.
DocLinkURL func(link *DocLink) string
// DocLinkBaseURL is used when DocLinkURL is nil,
// passed to [DocLink.DefaultURL] to construct a DocLink's URL.
// See that method's documentation for details.
DocLinkBaseURL string
// TextPrefix is a prefix to print at the start of every line
// when generating text output using the Text method.
TextPrefix string
// TextCodePrefix is the prefix to print at the start of each
// preformatted (code block) line when generating text output,
// instead of (not in addition to) TextPrefix.
// If TextCodePrefix is the empty string, it defaults to TextPrefix+"\t".
TextCodePrefix string
// TextWidth is the maximum width text line to generate,
// measured in Unicode code points,
// excluding TextPrefix and the newline character.
// If TextWidth is zero, it defaults to 80 minus the number of code points in TextPrefix.
// If TextWidth is negative, there is no limit.
TextWidth int
}
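// A minimal usage sketch (illustrative; assumes a *Doc named d obtained from a Parser):
//
//	p := &Printer{
//		TextPrefix: "// ",
//		TextWidth:  60,
//	}
//	os.Stdout.Write(p.Text(d))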
func (p *Printer) headingLevel() int {
if p.HeadingLevel <= 0 {
return 3
}
return p.HeadingLevel
}
func (p *Printer) headingID(h *Heading) string {
if p.HeadingID == nil {
return h.DefaultID()
}
return p.HeadingID(h)
}
func (p *Printer) docLinkURL(link *DocLink) string {
if p.DocLinkURL != nil {
return p.DocLinkURL(link)
}
return link.DefaultURL(p.DocLinkBaseURL)
}
// DefaultURL constructs and returns the documentation URL for l,
// using baseURL as a prefix for links to other packages.
//
// The possible forms returned by DefaultURL are:
// - baseURL/ImportPath, for a link to another package
// - baseURL/ImportPath#Name, for a link to a const, func, type, or var in another package
// - baseURL/ImportPath#Recv.Name, for a link to a method in another package
// - #Name, for a link to a const, func, type, or var in this package
// - #Recv.Name, for a link to a method in this package
//
// If baseURL ends in a trailing slash, then DefaultURL inserts
// a slash between ImportPath and # in the anchored forms.
// For example, here are some baseURL values and URLs they can generate:
//
// "/pkg/" → "/pkg/math/#Sqrt"
// "/pkg" → "/pkg/math#Sqrt"
// "/" → "/math/#Sqrt"
// "" → "/math#Sqrt"
func (l *DocLink) DefaultURL(baseURL string) string {
if l.ImportPath != "" {
slash := ""
if strings.HasSuffix(baseURL, "/") {
slash = "/"
} else {
baseURL += "/"
}
switch {
case l.Name == "":
return baseURL + l.ImportPath + slash
case l.Recv != "":
return baseURL + l.ImportPath + slash + "#" + l.Recv + "." + l.Name
default:
return baseURL + l.ImportPath + slash + "#" + l.Name
}
}
if l.Recv != "" {
return "#" + l.Recv + "." + l.Name
}
return "#" + l.Name
}
// DefaultID returns the default anchor ID for the heading h.
//
// The default anchor ID is constructed by converting every
// rune that is not alphanumeric ASCII to an underscore
// and then adding the prefix “hdr-”.
// For example, if the heading text is “Go Doc Comments”,
// the default ID is “hdr-Go_Doc_Comments”.
func (h *Heading) DefaultID() string {
// Note: The “hdr-” prefix is important to avoid DOM clobbering attacks.
// See https://pkg.go.dev/github.com/google/safehtml#Identifier.
var out strings.Builder
var p textPrinter
p.oneLongLine(&out, h.Text)
s := strings.TrimSpace(out.String())
if s == "" {
return ""
}
out.Reset()
out.WriteString("hdr-")
for _, r := range s {
if r < 0x80 && isIdentASCII(byte(r)) {
out.WriteByte(byte(r))
} else {
out.WriteByte('_')
}
}
return out.String()
}
type commentPrinter struct {
*Printer
}
// Comment returns the standard Go formatting of the [Doc],
// without any comment markers.
func (p *Printer) Comment(d *Doc) []byte {
cp := &commentPrinter{Printer: p}
var out bytes.Buffer
for i, x := range d.Content {
if i > 0 && blankBefore(x) {
out.WriteString("\n")
}
cp.block(&out, x)
}
// Print one block containing all the link definitions that were used,
// and then a second block containing all the unused ones.
// This makes it easy to clean up the unused ones: gofmt and
// delete the final block. And it's a nice visual signal without
// affecting the way the comment formats for users.
for i := 0; i < 2; i++ {
used := i == 0
first := true
for _, def := range d.Links {
if def.Used == used {
if first {
out.WriteString("\n")
first = false
}
out.WriteString("[")
out.WriteString(def.Text)
out.WriteString("]: ")
out.WriteString(def.URL)
out.WriteString("\n")
}
}
}
return out.Bytes()
}
// blankBefore reports whether the block x requires a blank line before it.
// All blocks do, except for Lists that return false from x.BlankBefore().
func blankBefore(x Block) bool {
if x, ok := x.(*List); ok {
return x.BlankBefore()
}
return true
}
// block prints the block x to out.
func (p *commentPrinter) block(out *bytes.Buffer, x Block) {
switch x := x.(type) {
default:
fmt.Fprintf(out, "?%T", x)
case *Paragraph:
p.text(out, "", x.Text)
out.WriteString("\n")
case *Heading:
out.WriteString("# ")
p.text(out, "", x.Text)
out.WriteString("\n")
case *Code:
md := x.Text
for md != "" {
var line string
line, md, _ = strings.Cut(md, "\n")
if line != "" {
out.WriteString("\t")
out.WriteString(line)
}
out.WriteString("\n")
}
case *List:
loose := x.BlankBetween()
for i, item := range x.Items {
if i > 0 && loose {
out.WriteString("\n")
}
out.WriteString(" ")
if item.Number == "" {
out.WriteString(" - ")
} else {
out.WriteString(item.Number)
out.WriteString(". ")
}
for i, blk := range item.Content {
const fourSpace = " "
if i > 0 {
out.WriteString("\n" + fourSpace)
}
p.text(out, fourSpace, blk.(*Paragraph).Text)
out.WriteString("\n")
}
}
}
}
// text prints the text sequence x to out.
func (p *commentPrinter) text(out *bytes.Buffer, indent string, x []Text) {
for _, t := range x {
switch t := t.(type) {
case Plain:
p.indent(out, indent, string(t))
case Italic:
p.indent(out, indent, string(t))
case *Link:
if t.Auto {
p.text(out, indent, t.Text)
} else {
out.WriteString("[")
p.text(out, indent, t.Text)
out.WriteString("]")
}
case *DocLink:
out.WriteString("[")
p.text(out, indent, t.Text)
out.WriteString("]")
}
}
}
// indent prints s to out, indenting with the indent string
// after each newline in s.
func (p *commentPrinter) indent(out *bytes.Buffer, indent, s string) {
for s != "" {
line, rest, ok := strings.Cut(s, "\n")
out.WriteString(line)
if ok {
out.WriteString("\n")
out.WriteString(indent)
}
s = rest
}
}
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package comment
import (
"bytes"
"fmt"
"sort"
"strings"
"unicode/utf8"
)
// A textPrinter holds the state needed for printing a Doc as plain text.
type textPrinter struct {
*Printer
long strings.Builder
prefix string
codePrefix string
width int
}
// Text returns a textual formatting of the [Doc].
// See the [Printer] documentation for ways to customize the text output.
func (p *Printer) Text(d *Doc) []byte {
tp := &textPrinter{
Printer: p,
prefix: p.TextPrefix,
codePrefix: p.TextCodePrefix,
width: p.TextWidth,
}
if tp.codePrefix == "" {
tp.codePrefix = p.TextPrefix + "\t"
}
if tp.width == 0 {
tp.width = 80 - utf8.RuneCountInString(tp.prefix)
}
var out bytes.Buffer
for i, x := range d.Content {
if i > 0 && blankBefore(x) {
out.WriteString(tp.prefix)
writeNL(&out)
}
tp.block(&out, x)
}
anyUsed := false
for _, def := range d.Links {
if def.Used {
anyUsed = true
break
}
}
if anyUsed {
writeNL(&out)
for _, def := range d.Links {
if def.Used {
fmt.Fprintf(&out, "[%s]: %s\n", def.Text, def.URL)
}
}
}
return out.Bytes()
}
// writeNL calls out.WriteByte('\n')
// but first trims trailing spaces on the previous line.
func writeNL(out *bytes.Buffer) {
// Trim trailing spaces.
data := out.Bytes()
n := 0
for n < len(data) && (data[len(data)-n-1] == ' ' || data[len(data)-n-1] == '\t') {
n++
}
if n > 0 {
out.Truncate(len(data) - n)
}
out.WriteByte('\n')
}
// block prints the block x to out.
func (p *textPrinter) block(out *bytes.Buffer, x Block) {
switch x := x.(type) {
default:
fmt.Fprintf(out, "?%T\n", x)
case *Paragraph:
out.WriteString(p.prefix)
p.text(out, "", x.Text)
case *Heading:
out.WriteString(p.prefix)
out.WriteString("# ")
p.text(out, "", x.Text)
case *Code:
text := x.Text
for text != "" {
var line string
line, text, _ = strings.Cut(text, "\n")
if line != "" {
out.WriteString(p.codePrefix)
out.WriteString(line)
}
writeNL(out)
}
case *List:
loose := x.BlankBetween()
for i, item := range x.Items {
if i > 0 && loose {
out.WriteString(p.prefix)
writeNL(out)
}
out.WriteString(p.prefix)
out.WriteString(" ")
if item.Number == "" {
out.WriteString(" - ")
} else {
out.WriteString(item.Number)
out.WriteString(". ")
}
for i, blk := range item.Content {
const fourSpace = " "
if i > 0 {
writeNL(out)
out.WriteString(p.prefix)
out.WriteString(fourSpace)
}
p.text(out, fourSpace, blk.(*Paragraph).Text)
}
}
}
}
// text prints the text sequence x to out.
func (p *textPrinter) text(out *bytes.Buffer, indent string, x []Text) {
p.oneLongLine(&p.long, x)
words := strings.Fields(p.long.String())
p.long.Reset()
var seq []int
if p.width < 0 || len(words) == 0 {
seq = []int{0, len(words)} // one long line
} else {
seq = wrap(words, p.width-utf8.RuneCountInString(indent))
}
for i := 0; i+1 < len(seq); i++ {
if i > 0 {
out.WriteString(p.prefix)
out.WriteString(indent)
}
for j, w := range words[seq[i]:seq[i+1]] {
if j > 0 {
out.WriteString(" ")
}
out.WriteString(w)
}
writeNL(out)
}
}
// oneLongLine prints the text sequence x to out as one long line,
// without worrying about line wrapping.
// Explicit links have the [ ] dropped to improve readability.
func (p *textPrinter) oneLongLine(out *strings.Builder, x []Text) {
for _, t := range x {
switch t := t.(type) {
case Plain:
out.WriteString(string(t))
case Italic:
out.WriteString(string(t))
case *Link:
p.oneLongLine(out, t.Text)
case *DocLink:
p.oneLongLine(out, t.Text)
}
}
}
// wrap wraps words into lines of at most max runes,
// minimizing the sum of the squares of the leftover lengths
// at the end of each line (except the last, of course),
// with a preference for ending lines at punctuation (.,:;).
//
// The returned slice gives the indexes of the first words
// on each line in the wrapped text with a final entry of len(words).
// Thus the lines are words[seq[0]:seq[1]], words[seq[1]:seq[2]],
// ..., words[seq[len(seq)-2]:seq[len(seq)-1]].
//
// The implementation runs in O(n log n) time, where n = len(words),
// using the algorithm described in D. S. Hirschberg and L. L. Larmore,
// “[The least weight subsequence problem],” FOCS 1985, pp. 137-143.
//
// [The least weight subsequence problem]: https://doi.org/10.1109/SFCS.1985.60
func wrap(words []string, max int) (seq []int) {
// The algorithm requires that our scoring function be concave,
// meaning that for all i₀ ≤ i₁ < j₀ ≤ j₁,
// weight(i₀, j₀) + weight(i₁, j₁) ≤ weight(i₀, j₁) + weight(i₁, j₀).
//
// Our weights are two-element pairs [hi, lo]
// ordered by elementwise comparison.
// The hi entry counts the weight for lines that are longer than max,
// and the lo entry counts the weight for lines that are not.
// This forces the algorithm to first minimize the number of lines
// that are longer than max, which correspond to lines with
// single very long words. Having done that, it can move on to
// minimizing the lo score, which is more interesting.
//
// The lo score is the sum for each line of the square of the
// number of spaces remaining at the end of the line and a
// penalty of 64 given out for not ending the line in a
// punctuation character (.,:;).
// The penalty is somewhat arbitrarily chosen by trying
// different amounts and judging how nice the wrapped text looks.
// Roughly speaking, using 64 means that we are willing to
// end a line with eight blank spaces in order to end at a
// punctuation character, even if the next word would fit in
// those spaces.
//
// We care about ending in punctuation characters because
// it makes the text easier to skim if not too many sentences
// or phrases begin with a single word on the previous line.
// A score is the score (also called weight) for a given line.
// add and cmp add and compare scores.
type score struct {
hi int64
lo int64
}
add := func(s, t score) score { return score{s.hi + t.hi, s.lo + t.lo} }
cmp := func(s, t score) int {
switch {
case s.hi < t.hi:
return -1
case s.hi > t.hi:
return +1
case s.lo < t.lo:
return -1
case s.lo > t.lo:
return +1
}
return 0
}
// total[j] is the total number of runes
// (including separating spaces) in words[:j].
total := make([]int, len(words)+1)
total[0] = 0
for i, s := range words {
total[1+i] = total[i] + utf8.RuneCountInString(s) + 1
}
// weight returns weight(i, j).
weight := func(i, j int) score {
// On the last line, there is zero weight for being too short.
n := total[j] - 1 - total[i]
if j == len(words) && n <= max {
return score{0, 0}
}
// Otherwise the weight is the penalty plus the square of the number of
// characters remaining on the line or by which the line goes over.
// In the latter case, that value goes in the hi part of the score.
// (See note above.)
p := wrapPenalty(words[j-1])
v := int64(max-n) * int64(max-n)
if n > max {
return score{v, p}
}
return score{0, v + p}
}
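// For illustration: with max = 30, a non-final line of 24 runes whose last
// word ends in '.' scores {0, 36} ((30-24)² with no penalty), while the same
// line ending in an ordinary word scores {0, 100} (36 plus the 64 penalty).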
// The rest of this function is “The Basic Algorithm” from
// Hirschberg and Larmore's conference paper,
// using the same names as in the paper.
f := []score{{0, 0}}
g := func(i, j int) score { return add(f[i], weight(i, j)) }
bridge := func(a, b, c int) bool {
k := c + sort.Search(len(words)+1-c, func(k int) bool {
k += c
return cmp(g(a, k), g(b, k)) > 0
})
if k > len(words) {
return true
}
return cmp(g(c, k), g(b, k)) <= 0
}
// d is a one-ended deque implemented as a slice.
d := make([]int, 1, len(words))
d[0] = 0
bestleft := make([]int, 1, len(words))
bestleft[0] = -1
for m := 1; m < len(words); m++ {
f = append(f, g(d[0], m))
bestleft = append(bestleft, d[0])
for len(d) > 1 && cmp(g(d[1], m+1), g(d[0], m+1)) <= 0 {
d = d[1:] // “Retire”
}
for len(d) > 1 && bridge(d[len(d)-2], d[len(d)-1], m) {
d = d[:len(d)-1] // “Fire”
}
if cmp(g(m, len(words)), g(d[len(d)-1], len(words))) < 0 {
d = append(d, m) // “Hire”
// The next few lines are not in the paper but are necessary
// to handle two-word inputs correctly. It appears to be
// just a bug in the paper's pseudocode.
if len(d) == 2 && cmp(g(d[1], m+1), g(d[0], m+1)) <= 0 {
d = d[1:]
}
}
}
bestleft = append(bestleft, d[0])
// Recover least weight sequence from bestleft.
n := 1
for m := len(words); m > 0; m = bestleft[m] {
n++
}
seq = make([]int, n)
for m := len(words); m > 0; m = bestleft[m] {
n--
seq[n] = m
}
return seq
}
// wrapPenalty is the penalty for inserting a line break after word s.
func wrapPenalty(s string) int64 {
switch s[len(s)-1] {
case '.', ',', ':', ';':
return 0
}
return 64
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package doc extracts source code documentation from a Go AST.
package doc
import (
"fmt"
"go/ast"
"go/doc/comment"
"go/token"
"strings"
)
// Package is the documentation for an entire package.
type Package struct {
Doc string
Name string
ImportPath string
Imports []string
Filenames []string
Notes map[string][]*Note
// Deprecated: For backward compatibility Bugs is still populated,
// but all new code should use Notes instead.
Bugs []string
// declarations
Consts []*Value
Types []*Type
Vars []*Value
Funcs []*Func
// Examples is a sorted list of examples associated with
// the package. Examples are extracted from _test.go files
// provided to NewFromFiles.
Examples []*Example
importByName map[string]string
syms map[string]bool
}
// Value is the documentation for a (possibly grouped) var or const declaration.
type Value struct {
Doc string
Names []string // var or const names in declaration order
Decl *ast.GenDecl
order int
}
// Type is the documentation for a type declaration.
type Type struct {
Doc string
Name string
Decl *ast.GenDecl
// associated declarations
Consts []*Value // sorted list of constants of (mostly) this type
Vars []*Value // sorted list of variables of (mostly) this type
Funcs []*Func // sorted list of functions returning this type
Methods []*Func // sorted list of methods (including embedded ones) of this type
// Examples is a sorted list of examples associated with
// this type. Examples are extracted from _test.go files
// provided to NewFromFiles.
Examples []*Example
}
// Func is the documentation for a func declaration.
type Func struct {
Doc string
Name string
Decl *ast.FuncDecl
// methods
// (for functions, these fields have the respective zero value)
Recv string // actual receiver "T" or "*T" possibly followed by type parameters [P1, ..., Pn]
Orig string // original receiver "T" or "*T"
Level int // embedding level; 0 means not embedded
// Examples is a sorted list of examples associated with this
// function or method. Examples are extracted from _test.go files
// provided to NewFromFiles.
Examples []*Example
}
// A Note represents a marked comment starting with "MARKER(uid): note body".
// Any note with a marker of 2 or more upper case [A-Z] letters and a uid of
// at least one character is recognized. The ":" following the uid is optional.
// Notes are collected in the Package.Notes map indexed by the notes marker.
type Note struct {
Pos, End token.Pos // position range of the comment containing the marker
UID string // uid found with the marker
Body string // note body text
}
// Mode values control the operation of [New] and [NewFromFiles].
type Mode int
const (
// AllDecls says to extract documentation for all package-level
// declarations, not just exported ones.
AllDecls Mode = 1 << iota
// AllMethods says to show all embedded methods, not just the ones of
// invisible (unexported) anonymous fields.
AllMethods
// PreserveAST says to leave the AST unmodified. Originally, pieces of
// the AST such as function bodies were nil-ed out to save memory in
// godoc, but not all programs want that behavior.
PreserveAST
)
// New computes the package documentation for the given package AST.
// New takes ownership of the AST pkg and may edit or overwrite it.
// To have the [Examples] fields populated, use [NewFromFiles] and include
// the package's _test.go files.
func New(pkg *ast.Package, importPath string, mode Mode) *Package {
var r reader
r.readPackage(pkg, mode)
r.computeMethodSets()
r.cleanupTypes()
p := &Package{
Doc: r.doc,
Name: pkg.Name,
ImportPath: importPath,
Imports: sortedKeys(r.imports),
Filenames: r.filenames,
Notes: r.notes,
Bugs: noteBodies(r.notes["BUG"]),
Consts: sortedValues(r.values, token.CONST),
Types: sortedTypes(r.types, mode&AllMethods != 0),
Vars: sortedValues(r.values, token.VAR),
Funcs: sortedFuncs(r.funcs, true),
importByName: r.importByName,
syms: make(map[string]bool),
}
p.collectValues(p.Consts)
p.collectValues(p.Vars)
p.collectTypes(p.Types)
p.collectFuncs(p.Funcs)
return p
}
func (p *Package) collectValues(values []*Value) {
for _, v := range values {
for _, name := range v.Names {
p.syms[name] = true
}
}
}
func (p *Package) collectTypes(types []*Type) {
for _, t := range types {
if p.syms[t.Name] {
// Shouldn't be any cycles but stop just in case.
continue
}
p.syms[t.Name] = true
p.collectValues(t.Consts)
p.collectValues(t.Vars)
p.collectFuncs(t.Funcs)
p.collectFuncs(t.Methods)
}
}
func (p *Package) collectFuncs(funcs []*Func) {
for _, f := range funcs {
if f.Recv != "" {
r := strings.TrimPrefix(f.Recv, "*")
if i := strings.IndexByte(r, '['); i >= 0 {
r = r[:i] // remove type parameters
}
p.syms[r+"."+f.Name] = true
} else {
p.syms[f.Name] = true
}
}
}
// NewFromFiles computes documentation for a package.
//
// The package is specified by a list of *ast.Files and corresponding
// file set, which must not be nil.
//
// NewFromFiles uses all provided files when computing documentation,
// so it is the caller's responsibility to provide only the files that
// match the desired build context. "go/build".Context.MatchFile can
// be used for determining whether a file matches a build context with
// the desired GOOS and GOARCH values, and other build constraints.
// The import path of the package is specified by importPath.
//
// Examples found in _test.go files are associated with the corresponding
// type, function, method, or the package, based on their name.
// If the example has a suffix in its name, it is set in the
// [Example.Suffix] field. [Examples] with malformed names are skipped.
//
// Optionally, a single extra argument of type [Mode] can be provided to
// control low-level aspects of the documentation extraction behavior.
//
// NewFromFiles takes ownership of the AST files and may edit them,
// unless the PreserveAST Mode bit is on.
func NewFromFiles(fset *token.FileSet, files []*ast.File, importPath string, opts ...any) (*Package, error) {
// Check for invalid API usage.
if fset == nil {
panic(fmt.Errorf("doc.NewFromFiles: no token.FileSet provided (fset == nil)"))
}
var mode Mode
switch len(opts) { // There can only be 0 or 1 options, so a simple switch works for now.
case 0:
// Nothing to do.
case 1:
m, ok := opts[0].(Mode)
if !ok {
panic(fmt.Errorf("doc.NewFromFiles: option argument type must be doc.Mode"))
}
mode = m
default:
panic(fmt.Errorf("doc.NewFromFiles: there must not be more than 1 option argument"))
}
// Collect .go and _test.go files.
var (
pkgName string
goFiles = make(map[string]*ast.File)
testGoFiles []*ast.File
)
for i, file := range files {
f := fset.File(file.Pos())
if f == nil {
return nil, fmt.Errorf("file files[%d] is not found in the provided file set", i)
}
switch filename := f.Name(); {
case strings.HasSuffix(filename, "_test.go"):
testGoFiles = append(testGoFiles, file)
case strings.HasSuffix(filename, ".go"):
pkgName = file.Name.Name
goFiles[filename] = file
default:
return nil, fmt.Errorf("file files[%d] filename %q does not have a .go extension", i, filename)
}
}
// Compute package documentation.
//
// Since this package doesn't need Package.{Scope,Imports}, or
// handle errors, and ast.File's Scope field is unset in files
// parsed with parser.SkipObjectResolution, we construct the
// Package directly instead of calling [ast.NewPackage].
pkg := &ast.Package{Name: pkgName, Files: goFiles}
p := New(pkg, importPath, mode)
classifyExamples(p, Examples(testGoFiles...))
return p, nil
}
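// A minimal client-side usage sketch (illustrative; src is hypothetical source
// text and the go/parser package is imported):
//
//	fset := token.NewFileSet()
//	f, err := parser.ParseFile(fset, "example.go", src, parser.ParseComments)
//	if err != nil {
//		// handle the parse error
//	}
//	pkg, err := doc.NewFromFiles(fset, []*ast.File{f}, "example.com/pkg")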
// lookupSym reports whether the package has a given symbol or method.
//
// If recv == "", HasSym reports whether the package has a top-level
// const, func, type, or var named name.
//
// If recv != "", HasSym reports whether the package has a type
// named recv with a method named name.
func (p *Package) lookupSym(recv, name string) bool {
if recv != "" {
return p.syms[recv+"."+name]
}
return p.syms[name]
}
// lookupPackage returns the import path identified by name
// in the given package. If name uniquely identifies a single import,
// then lookupPackage returns that import.
// If multiple packages are imported as name, lookupPackage returns "", false.
// Otherwise, if name is the name of p itself, lookupPackage returns "", true,
// to signal a reference to p.
// Otherwise, lookupPackage returns "", false.
func (p *Package) lookupPackage(name string) (importPath string, ok bool) {
if path, ok := p.importByName[name]; ok {
if path == "" {
return "", false // multiple imports used the name
}
return path, true // found import
}
if p.Name == name {
return "", true // allow reference to this package
}
return "", false // unknown name
}
// Parser returns a doc comment parser configured
// for parsing doc comments from package p.
// Each call returns a new parser, so that the caller may
// customize it before use.
func (p *Package) Parser() *comment.Parser {
return &comment.Parser{
LookupPackage: p.lookupPackage,
LookupSym: p.lookupSym,
}
}
// Printer returns a doc comment printer configured
// for printing doc comments from package p.
// Each call returns a new printer, so that the caller may
// customize it before use.
func (p *Package) Printer() *comment.Printer {
// No customization today, but having p.Printer()
// gives us flexibility in the future, and it is convenient for callers.
return &comment.Printer{}
}
// HTML returns formatted HTML for the doc comment text.
//
// To customize details of the HTML, use [Package.Printer]
// to obtain a [comment.Printer], and configure it
// before calling its HTML method.
func (p *Package) HTML(text string) []byte {
return p.Printer().HTML(p.Parser().Parse(text))
}
// Markdown returns formatted Markdown for the doc comment text.
//
// To customize details of the Markdown, use [Package.Printer]
// to obtain a [comment.Printer], and configure it
// before calling its Markdown method.
func (p *Package) Markdown(text string) []byte {
return p.Printer().Markdown(p.Parser().Parse(text))
}
// Text returns formatted text for the doc comment text,
// wrapped to 80 Unicode code points and using tabs for
// code block indentation.
//
// To customize details of the formatting, use [Package.Printer]
// to obtain a [comment.Printer], and configure it
// before calling its Text method.
func (p *Package) Text(text string) []byte {
return p.Printer().Text(p.Parser().Parse(text))
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Extract example functions from file ASTs.
package doc
import (
"cmp"
"go/ast"
"go/token"
"internal/lazyregexp"
"path"
"slices"
"strconv"
"strings"
"unicode"
"unicode/utf8"
)
// An Example represents an example function found in a test source file.
type Example struct {
Name string // name of the item being exemplified (including optional suffix)
Suffix string // example suffix, without leading '_' (only populated by NewFromFiles)
Doc string // example function doc string
Code ast.Node
Play *ast.File // a whole program version of the example
Comments []*ast.CommentGroup
Output string // expected output
Unordered bool
EmptyOutput bool // expect empty output
Order int // original source code order
}
// Examples returns the examples found in testFiles, sorted by Name field.
// The Order fields record the order in which the examples were encountered.
// The Suffix field is not populated when Examples is called directly; it is
// only populated by [NewFromFiles] for examples it finds in _test.go files.
//
// Playable Examples must be in a package whose name ends in "_test".
// An Example is "playable" (the Play field is non-nil) in either of these
// circumstances:
// - The example function is self-contained: the function references only
// identifiers from other packages (or predeclared identifiers, such as
// "int") and the test file does not include a dot import.
// - The entire test file is the example: the file contains exactly one
// example function, zero test, fuzz test, or benchmark functions, and at
// least one top-level function, type, variable, or constant declaration
// other than the example function.
func Examples(testFiles ...*ast.File) []*Example {
var list []*Example
for _, file := range testFiles {
hasTests := false // file contains tests, fuzz tests, or benchmarks
numDecl := 0 // number of non-import declarations in the file
var flist []*Example
for _, decl := range file.Decls {
if g, ok := decl.(*ast.GenDecl); ok && g.Tok != token.IMPORT {
numDecl++
continue
}
f, ok := decl.(*ast.FuncDecl)
if !ok || f.Recv != nil {
continue
}
numDecl++
name := f.Name.Name
if isTest(name, "Test") || isTest(name, "Benchmark") || isTest(name, "Fuzz") {
hasTests = true
continue
}
if !isTest(name, "Example") {
continue
}
if params := f.Type.Params; len(params.List) != 0 {
continue // function has params; not a valid example
}
if f.Body == nil { // ast.File.Body nil dereference (see issue 28044)
continue
}
var doc string
if f.Doc != nil {
doc = f.Doc.Text()
}
output, unordered, hasOutput := exampleOutput(f.Body, file.Comments)
flist = append(flist, &Example{
Name: name[len("Example"):],
Doc: doc,
Code: f.Body,
Play: playExample(file, f),
Comments: file.Comments,
Output: output,
Unordered: unordered,
EmptyOutput: output == "" && hasOutput,
Order: len(flist),
})
}
if !hasTests && numDecl > 1 && len(flist) == 1 {
// If this file only has one example function, some
// other top-level declarations, and no tests or
// benchmarks, use the whole file as the example.
flist[0].Code = file
flist[0].Play = playExampleFile(file)
}
list = append(list, flist...)
}
// sort by name
slices.SortFunc(list, func(a, b *Example) int {
return cmp.Compare(a.Name, b.Name)
})
return list
}
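// For illustration, a test file that would be treated as a whole-file example
// (one example function, no tests or benchmarks, and at least one extra
// top-level declaration) might look like:
//
//	package sort_test
//
//	import "fmt"
//
//	type byLength []string
//
//	func ExampleSort() {
//		fmt.Println(len(byLength{"a"}))
//		// Output: 1
//	}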
var outputPrefix = lazyregexp.New(`(?i)^[[:space:]]*(unordered )?output:`)
// Extracts the expected output and whether there was a valid output comment.
func exampleOutput(b *ast.BlockStmt, comments []*ast.CommentGroup) (output string, unordered, ok bool) {
if _, last := lastComment(b, comments); last != nil {
// test that it begins with the correct prefix
text := last.Text()
if loc := outputPrefix.FindStringSubmatchIndex(text); loc != nil {
if loc[2] != -1 {
unordered = true
}
text = text[loc[1]:]
// Strip zero or more spaces followed by \n or a single space.
text = strings.TrimLeft(text, " ")
if len(text) > 0 && text[0] == '\n' {
text = text[1:]
}
return text, unordered, true
}
}
return "", false, false // no suitable comment found
}
// isTest tells whether name looks like a test, example, fuzz test, or
// benchmark. It is a Test (say) if there is a character after Test that is not
// a lower-case letter. (We don't want Testiness.)
func isTest(name, prefix string) bool {
if !strings.HasPrefix(name, prefix) {
return false
}
if len(name) == len(prefix) { // "Test" is ok
return true
}
rune, _ := utf8.DecodeRuneInString(name[len(prefix):])
return !unicode.IsLower(rune)
}
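// For illustration:
//
//	isTest("TestFoo", "Test")           → true
//	isTest("Test", "Test")              → true (bare prefix is allowed)
//	isTest("Testiness", "Test")         → false ('i' after the prefix is lower case)
//	isTest("ExampleFoo_bar", "Example") → true ('F' after the prefix is not lower case)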
// playExample synthesizes a new *ast.File based on the provided
// file with the provided function body as the body of main.
func playExample(file *ast.File, f *ast.FuncDecl) *ast.File {
body := f.Body
if !strings.HasSuffix(file.Name.Name, "_test") {
// We don't support examples that are part of the
// greater package (yet).
return nil
}
// Collect top-level declarations in the file.
topDecls := make(map[*ast.Object]ast.Decl)
typMethods := make(map[string][]ast.Decl)
for _, decl := range file.Decls {
switch d := decl.(type) {
case *ast.FuncDecl:
if d.Recv == nil {
topDecls[d.Name.Obj] = d
} else {
if len(d.Recv.List) == 1 {
t := d.Recv.List[0].Type
tname, _ := baseTypeName(t)
typMethods[tname] = append(typMethods[tname], d)
}
}
case *ast.GenDecl:
for _, spec := range d.Specs {
switch s := spec.(type) {
case *ast.TypeSpec:
topDecls[s.Name.Obj] = d
case *ast.ValueSpec:
for _, name := range s.Names {
topDecls[name.Obj] = d
}
}
}
}
}
// Find unresolved identifiers and uses of top-level declarations.
depDecls, unresolved := findDeclsAndUnresolved(body, topDecls, typMethods)
// Use unresolved identifiers to determine the imports used by this
// example. The heuristic assumes package names match base import
// paths for imports without renames (should be good enough most of the time).
var namedImports []ast.Spec
var blankImports []ast.Spec // _ imports
// To preserve the blank lines between groups of imports, find the
// start position of each group, and assign that position to all
// imports from that group.
groupStarts := findImportGroupStarts(file.Imports)
groupStart := func(s *ast.ImportSpec) token.Pos {
for i, start := range groupStarts {
if s.Path.ValuePos < start {
return groupStarts[i-1]
}
}
return groupStarts[len(groupStarts)-1]
}
for _, s := range file.Imports {
p, err := strconv.Unquote(s.Path.Value)
if err != nil {
continue
}
if p == "syscall/js" {
// We don't support examples that import syscall/js,
// because the package syscall/js is not available in the playground.
return nil
}
n := path.Base(p)
if s.Name != nil {
n = s.Name.Name
switch n {
case "_":
blankImports = append(blankImports, s)
continue
case ".":
// We can't resolve dot imports (yet).
return nil
}
}
if unresolved[n] {
// Copy the spec and its path to avoid modifying the original.
spec := *s
path := *s.Path
spec.Path = &path
spec.Path.ValuePos = groupStart(&spec)
namedImports = append(namedImports, &spec)
delete(unresolved, n)
}
}
// Remove predeclared identifiers from unresolved list.
for n := range unresolved {
if predeclaredTypes[n] || predeclaredConstants[n] || predeclaredFuncs[n] {
delete(unresolved, n)
}
}
// If there are other unresolved identifiers, give up because this
// synthesized file is not going to build.
if len(unresolved) > 0 {
return nil
}
// Include documentation belonging to blank imports.
var comments []*ast.CommentGroup
for _, s := range blankImports {
if c := s.(*ast.ImportSpec).Doc; c != nil {
comments = append(comments, c)
}
}
// Include comments that are inside the function body.
for _, c := range file.Comments {
if body.Pos() <= c.Pos() && c.End() <= body.End() {
comments = append(comments, c)
}
}
// Strip the "Output:" or "Unordered output:" comment and adjust body
// end position.
body, comments = stripOutputComment(body, comments)
// Include documentation belonging to dependent declarations.
for _, d := range depDecls {
switch d := d.(type) {
case *ast.GenDecl:
if d.Doc != nil {
comments = append(comments, d.Doc)
}
case *ast.FuncDecl:
if d.Doc != nil {
comments = append(comments, d.Doc)
}
}
}
// Synthesize import declaration.
importDecl := &ast.GenDecl{
Tok: token.IMPORT,
Lparen: 1, // Need non-zero Lparen and Rparen so that printer
Rparen: 1, // treats this as a factored import.
}
importDecl.Specs = append(namedImports, blankImports...)
// Synthesize main function.
funcDecl := &ast.FuncDecl{
Name: ast.NewIdent("main"),
Type: f.Type,
Body: body,
}
decls := make([]ast.Decl, 0, 2+len(depDecls))
decls = append(decls, importDecl)
decls = append(decls, depDecls...)
decls = append(decls, funcDecl)
slices.SortFunc(decls, func(a, b ast.Decl) int {
return cmp.Compare(a.Pos(), b.Pos())
})
slices.SortFunc(comments, func(a, b *ast.CommentGroup) int {
return cmp.Compare(a.Pos(), b.Pos())
})
// Synthesize file.
return &ast.File{
Name: ast.NewIdent("main"),
Decls: decls,
Comments: comments,
}
}
// findDeclsAndUnresolved returns all the top-level declarations mentioned in
// the body, and a set of unresolved symbols (those that appear in the body but
// have no declaration in the program).
//
// topDecls maps objects to the top-level declaration declaring them (not
// necessarily obj.Decl, as obj.Decl will be a Spec for GenDecls, but
// topDecls[obj] will be the GenDecl itself).
func findDeclsAndUnresolved(body ast.Node, topDecls map[*ast.Object]ast.Decl, typMethods map[string][]ast.Decl) ([]ast.Decl, map[string]bool) {
// This function recursively finds every top-level declaration used
// transitively by the body, populating usedDecls and usedObjs. Then it
// trims down the declarations to include only the symbols actually
// referenced by the body.
unresolved := make(map[string]bool)
var depDecls []ast.Decl
usedDecls := make(map[ast.Decl]bool) // set of top-level decls reachable from the body
usedObjs := make(map[*ast.Object]bool) // set of objects reachable from the body (each declared by a usedDecl)
var inspectFunc func(ast.Node) bool
inspectFunc = func(n ast.Node) bool {
switch e := n.(type) {
case *ast.Ident:
if e.Obj == nil && e.Name != "_" {
unresolved[e.Name] = true
} else if d := topDecls[e.Obj]; d != nil {
usedObjs[e.Obj] = true
if !usedDecls[d] {
usedDecls[d] = true
depDecls = append(depDecls, d)
}
}
return true
case *ast.SelectorExpr:
// For selector expressions, only inspect the left hand side.
// (For an expression like fmt.Println, only add "fmt" to the
// set of unresolved names, not "Println".)
ast.Inspect(e.X, inspectFunc)
return false
case *ast.KeyValueExpr:
// For key value expressions, only inspect the value
// as the key should be resolved by the type of the
// composite literal.
ast.Inspect(e.Value, inspectFunc)
return false
}
return true
}
inspectFieldList := func(fl *ast.FieldList) {
if fl != nil {
for _, f := range fl.List {
ast.Inspect(f.Type, inspectFunc)
}
}
}
// Find the decls immediately referenced by body.
ast.Inspect(body, inspectFunc)
// Now loop over them, adding to the list when we find a new decl that the
// body depends on. Keep going until we don't find anything new.
for i := 0; i < len(depDecls); i++ {
switch d := depDecls[i].(type) {
case *ast.FuncDecl:
// Inspect type parameters.
inspectFieldList(d.Type.TypeParams)
// Inspect types of parameters and results. See #28492.
inspectFieldList(d.Type.Params)
inspectFieldList(d.Type.Results)
// Functions might not have a body. See #42706.
if d.Body != nil {
ast.Inspect(d.Body, inspectFunc)
}
case *ast.GenDecl:
for _, spec := range d.Specs {
switch s := spec.(type) {
case *ast.TypeSpec:
inspectFieldList(s.TypeParams)
ast.Inspect(s.Type, inspectFunc)
depDecls = append(depDecls, typMethods[s.Name.Name]...)
case *ast.ValueSpec:
if s.Type != nil {
ast.Inspect(s.Type, inspectFunc)
}
for _, val := range s.Values {
ast.Inspect(val, inspectFunc)
}
}
}
}
}
// Some decls include multiple specs, such as a variable declaration with
// multiple variables on the same line, or a parenthesized declaration. Trim
// the declarations to include only the specs that are actually mentioned.
// However, if there is a constant group with iota, leave it all: later
// constant declarations in the group may have no value and so cannot stand
// on their own, and removing any constant from the group could change the
// values of subsequent ones.
// See testdata/examples/iota.go for a minimal example.
var ds []ast.Decl
for _, d := range depDecls {
switch d := d.(type) {
case *ast.FuncDecl:
ds = append(ds, d)
case *ast.GenDecl:
containsIota := false // does any spec have iota?
// Collect all Specs that were mentioned in the example.
var specs []ast.Spec
for _, s := range d.Specs {
switch s := s.(type) {
case *ast.TypeSpec:
if usedObjs[s.Name.Obj] {
specs = append(specs, s)
}
case *ast.ValueSpec:
if !containsIota {
containsIota = hasIota(s)
}
// A ValueSpec may have multiple names (e.g. "var a, b int").
// Keep only the names that were mentioned in the example.
// Exception: if the multiple names share a single initializer (which
// would be a function call with multiple return values), keep everything.
if len(s.Names) > 1 && len(s.Values) == 1 {
specs = append(specs, s)
continue
}
ns := *s
ns.Names = nil
ns.Values = nil
for i, n := range s.Names {
if usedObjs[n.Obj] {
ns.Names = append(ns.Names, n)
if s.Values != nil {
ns.Values = append(ns.Values, s.Values[i])
}
}
}
if len(ns.Names) > 0 {
specs = append(specs, &ns)
}
}
}
if len(specs) > 0 {
// Constant with iota? Keep it all.
if d.Tok == token.CONST && containsIota {
ds = append(ds, d)
} else {
// Synthesize a GenDecl with just the Specs we need.
nd := *d // copy the GenDecl
nd.Specs = specs
if len(specs) == 1 {
// Remove grouping parens if there is only one spec.
nd.Lparen = 0
}
ds = append(ds, &nd)
}
}
}
}
return ds, unresolved
}
func hasIota(s ast.Spec) bool {
for n := range ast.Preorder(s) {
// Check that this is the special built-in "iota" identifier, not
// a user-defined shadow.
if id, ok := n.(*ast.Ident); ok && id.Name == "iota" && id.Obj == nil {
return true
}
}
return false
}
// findImportGroupStarts finds the start positions of each sequence of import
// specs that are not separated by a blank line.
func findImportGroupStarts(imps []*ast.ImportSpec) []token.Pos {
startImps := findImportGroupStarts1(imps)
groupStarts := make([]token.Pos, len(startImps))
for i, imp := range startImps {
groupStarts[i] = imp.Pos()
}
return groupStarts
}
// Helper for findImportGroupStarts to ease testing.
func findImportGroupStarts1(origImps []*ast.ImportSpec) []*ast.ImportSpec {
// Copy to avoid mutation.
imps := make([]*ast.ImportSpec, len(origImps))
copy(imps, origImps)
// Assume the imports are sorted by position.
slices.SortFunc(imps, func(a, b *ast.ImportSpec) int {
return cmp.Compare(a.Pos(), b.Pos())
})
// Assume gofmt has been applied, so there is a blank line between adjacent imps
// if and only if they are more than 2 positions apart (newline, tab).
var groupStarts []*ast.ImportSpec
prevEnd := token.Pos(-2)
for _, imp := range imps {
if imp.Pos()-prevEnd > 2 {
groupStarts = append(groupStarts, imp)
}
prevEnd = imp.End()
// Account for end-of-line comments.
if imp.Comment != nil {
prevEnd = imp.Comment.End()
}
}
return groupStarts
}
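// Illustrative sketch (not part of the original source): with gofmt-formatted
// source, a spec that follows a blank line starts a new group, so for the
// block below findImportGroupStarts1 would report the specs for "fmt" and
// "example.com/extra" (a hypothetical path) as group starts:
//
//	import (
//		"fmt"
//		"strings"
//
//		"example.com/extra"
//	)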
// playExampleFile takes a whole file example and synthesizes a new *ast.File
// such that the example is function main in package main.
func playExampleFile(file *ast.File) *ast.File {
// Strip copyright comment if present.
comments := file.Comments
if len(comments) > 0 && strings.HasPrefix(comments[0].Text(), "Copyright") {
comments = comments[1:]
}
// Copy declaration slice, rewriting the ExampleX function to main.
var decls []ast.Decl
for _, d := range file.Decls {
if f, ok := d.(*ast.FuncDecl); ok && isTest(f.Name.Name, "Example") {
// Copy the FuncDecl, as it may be used elsewhere.
newF := *f
newF.Name = ast.NewIdent("main")
newF.Body, comments = stripOutputComment(f.Body, comments)
d = &newF
}
decls = append(decls, d)
}
// Copy the File, as it may be used elsewhere.
f := *file
f.Name = ast.NewIdent("main")
f.Decls = decls
f.Comments = comments
return &f
}
// stripOutputComment finds and removes the "Output:" or "Unordered output:"
// comment from body and comments, and adjusts the body block's end position.
func stripOutputComment(body *ast.BlockStmt, comments []*ast.CommentGroup) (*ast.BlockStmt, []*ast.CommentGroup) {
// Do nothing if there is no "Output:" or "Unordered output:" comment.
i, last := lastComment(body, comments)
if last == nil || !outputPrefix.MatchString(last.Text()) {
return body, comments
}
// Copy body and comments, as the originals may be used elsewhere.
newBody := &ast.BlockStmt{
Lbrace: body.Lbrace,
List: body.List,
Rbrace: last.Pos(),
}
newComments := make([]*ast.CommentGroup, len(comments)-1)
copy(newComments, comments[:i])
copy(newComments[i:], comments[i+1:])
return newBody, newComments
}
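// Illustrative sketch (not part of the original source): for a whole-file
// example, a trailing output comment like the one below is removed from the
// synthesized main function, and the block's closing brace is moved up to
// where the comment began:
//
//	func ExampleHello() {
//		fmt.Println("hello")
//		// Output: hello
//	}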
// lastComment returns the last comment inside the provided block.
func lastComment(b *ast.BlockStmt, c []*ast.CommentGroup) (i int, last *ast.CommentGroup) {
if b == nil {
return
}
pos, end := b.Pos(), b.End()
for j, cg := range c {
if cg.Pos() < pos {
continue
}
if cg.End() > end {
break
}
i, last = j, cg
}
return
}
// classifyExamples classifies examples and assigns them to the Examples field
// of the relevant Func, Type, or Package that the example is associated with.
//
// The classification process is ambiguous in some cases:
//
// - ExampleFoo_Bar matches a type named Foo_Bar
// or a method named Foo.Bar.
// - ExampleFoo_bar matches a type named Foo_bar
// or Foo (with a "bar" suffix).
//
// Examples with malformed names are not associated with anything.
func classifyExamples(p *Package, examples []*Example) {
if len(examples) == 0 {
return
}
// Mapping of names for funcs, types, and methods to the example listing.
ids := make(map[string]*[]*Example)
ids[""] = &p.Examples // package-level examples have an empty name
for _, f := range p.Funcs {
if !token.IsExported(f.Name) {
continue
}
ids[f.Name] = &f.Examples
}
for _, t := range p.Types {
if !token.IsExported(t.Name) {
continue
}
ids[t.Name] = &t.Examples
for _, f := range t.Funcs {
if !token.IsExported(f.Name) {
continue
}
ids[f.Name] = &f.Examples
}
for _, m := range t.Methods {
if !token.IsExported(m.Name) {
continue
}
ids[strings.TrimPrefix(nameWithoutInst(m.Recv), "*")+"_"+m.Name] = &m.Examples
}
}
// Group each example with the associated func, type, or method.
for _, ex := range examples {
// Consider all possible split points for the suffix
// by starting at the end of string (no suffix case),
// then trying all positions that contain a '_' character.
//
// An association is made on the first successful match.
// Examples with malformed names that match nothing are skipped.
for i := len(ex.Name); i >= 0; i = strings.LastIndexByte(ex.Name[:i], '_') {
prefix, suffix, ok := splitExampleName(ex.Name, i)
if !ok {
continue
}
exs, ok := ids[prefix]
if !ok {
continue
}
ex.Suffix = suffix
*exs = append(*exs, ex)
break
}
}
// Sort each list of examples according to the user-specified suffix name.
for _, exs := range ids {
slices.SortFunc(*exs, func(a, b *Example) int {
return cmp.Compare(a.Suffix, b.Suffix)
})
}
}
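// Illustrative sketch (not part of the original source): how example names are
// matched against the identifier map above, assuming a package that declares a
// type T with a method T.M. A suffix, if present, must start with a lower-case
// letter:
//
//	Example        -> package-level examples (empty name)
//	ExampleT       -> type T
//	ExampleT_M     -> method T.M, Suffix ""
//	ExampleT_M_big -> method T.M, Suffix "big"
//	ExampleT_Bad   -> no match ("Bad" is not a valid suffix), skipped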
// nameWithoutInst returns name if name has no brackets. If name contains
// brackets, then it returns name with all the contents between (and including)
// the outermost left and right bracket removed.
//
// Adapted from debug/gosym/symtab.go:Sym.nameWithoutInst.
func nameWithoutInst(name string) string {
start := strings.Index(name, "[")
if start < 0 {
return name
}
end := strings.LastIndex(name, "]")
if end < 0 {
// Malformed name, should contain closing bracket too.
return name
}
return name[0:start] + name[end+1:]
}
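// Illustrative sketch (not part of the original source): expected results of
// nameWithoutInst for a few hypothetical receiver strings:
//
//	nameWithoutInst("List[T]")   == "List"
//	nameWithoutInst("Map[K, V]") == "Map"
//	nameWithoutInst("Plain")     == "Plain"
//	nameWithoutInst("Bad[T")     == "Bad[T" // no closing bracket: returned unchanged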
// splitExampleName attempts to split example name s at index i,
// and reports if that produces a valid split. The suffix may be
// absent. Otherwise, it must start with a lower-case letter and
// be preceded by '_'.
//
// One of i == len(s) or s[i] == '_' must be true.
func splitExampleName(s string, i int) (prefix, suffix string, ok bool) {
if i == len(s) {
return s, "", true
}
if i == len(s)-1 {
return "", "", false
}
prefix, suffix = s[:i], s[i+1:]
return prefix, suffix, isExampleSuffix(suffix)
}
func isExampleSuffix(s string) bool {
r, size := utf8.DecodeRuneInString(s)
return size > 0 && unicode.IsLower(r)
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file implements export filtering of an AST.
package doc
import (
"go/ast"
"go/token"
)
// filterIdentList removes unexported names from list in place
// and returns the resulting list.
func filterIdentList(list []*ast.Ident) []*ast.Ident {
j := 0
for _, x := range list {
if token.IsExported(x.Name) {
list[j] = x
j++
}
}
return list[0:j]
}
var underscore = ast.NewIdent("_")
func filterCompositeLit(lit *ast.CompositeLit, filter Filter, export bool) {
n := len(lit.Elts)
lit.Elts = filterExprList(lit.Elts, filter, export)
if len(lit.Elts) < n {
lit.Incomplete = true
}
}
func filterExprList(list []ast.Expr, filter Filter, export bool) []ast.Expr {
j := 0
for _, exp := range list {
switch x := exp.(type) {
case *ast.CompositeLit:
filterCompositeLit(x, filter, export)
case *ast.KeyValueExpr:
if x, ok := x.Key.(*ast.Ident); ok && !filter(x.Name) {
continue
}
if x, ok := x.Value.(*ast.CompositeLit); ok {
filterCompositeLit(x, filter, export)
}
}
list[j] = exp
j++
}
return list[0:j]
}
// updateIdentList replaces all unexported identifiers with underscore
// and reports whether at least one exported name exists.
func updateIdentList(list []*ast.Ident) (hasExported bool) {
for i, x := range list {
if token.IsExported(x.Name) {
hasExported = true
} else {
list[i] = underscore
}
}
return hasExported
}
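// Illustrative sketch (not part of the original source): when a value spec has
// initializers, unexported names are replaced by "_" rather than dropped, so
// the left-hand side stays aligned with the right-hand side (computeSizes is a
// hypothetical function):
//
//	var MaxSize, minSize = computeSizes() // before filtering
//	var MaxSize, _ = computeSizes()       // after filtering (minSize is unexported)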
// hasExportedName reports whether list contains any exported names.
func hasExportedName(list []*ast.Ident) bool {
for _, x := range list {
if x.IsExported() {
return true
}
}
return false
}
// removeAnonymousField removes anonymous fields named name from an interface.
func removeAnonymousField(name string, ityp *ast.InterfaceType) {
list := ityp.Methods.List // we know that ityp.Methods != nil
j := 0
for _, field := range list {
keepField := true
if n := len(field.Names); n == 0 {
// anonymous field
if fname, _ := baseTypeName(field.Type); fname == name {
keepField = false
}
}
if keepField {
list[j] = field
j++
}
}
if j < len(list) {
ityp.Incomplete = true
}
ityp.Methods.List = list[0:j]
}
// filterFieldList removes unexported fields (field names) from the field list
// in place and reports whether fields were removed. Anonymous fields are
// recorded with the parent type. filterType is called with the types of
// all remaining fields.
func (r *reader) filterFieldList(parent *namedType, fields *ast.FieldList, ityp *ast.InterfaceType) (removedFields bool) {
if fields == nil {
return
}
list := fields.List
j := 0
for _, field := range list {
keepField := false
if n := len(field.Names); n == 0 {
// anonymous field or embedded type or union element
fname := r.recordAnonymousField(parent, field.Type)
if fname != "" {
if token.IsExported(fname) {
keepField = true
} else if ityp != nil && predeclaredTypes[fname] {
// possibly an embedded predeclared type; keep it for now but
// remember this interface so that it can be fixed if name is also
// defined locally
keepField = true
r.remember(fname, ityp)
}
} else {
// If we're operating on an interface, assume that this is an embedded
// type or union element.
//
// TODO(rfindley): consider traversing into approximation/unions
// elements to see if they are entirely unexported.
keepField = ityp != nil
}
} else {
field.Names = filterIdentList(field.Names)
if len(field.Names) < n {
removedFields = true
}
if len(field.Names) > 0 {
keepField = true
}
}
if keepField {
r.filterType(nil, field.Type)
list[j] = field
j++
}
}
if j < len(list) {
removedFields = true
}
fields.List = list[0:j]
return
}
// filterParamList applies filterType to each parameter type in fields.
func (r *reader) filterParamList(fields *ast.FieldList) {
if fields != nil {
for _, f := range fields.List {
r.filterType(nil, f.Type)
}
}
}
// filterType strips any unexported struct fields or method types from typ
// in place. If fields (or methods) have been removed, the corresponding
// struct or interface type has the Incomplete field set to true.
func (r *reader) filterType(parent *namedType, typ ast.Expr) {
switch t := typ.(type) {
case *ast.Ident:
// nothing to do
case *ast.ParenExpr:
r.filterType(nil, t.X)
case *ast.StarExpr: // possibly an embedded type literal
r.filterType(nil, t.X)
case *ast.UnaryExpr:
if t.Op == token.TILDE { // approximation element
r.filterType(nil, t.X)
}
case *ast.BinaryExpr:
if t.Op == token.OR { // union
r.filterType(nil, t.X)
r.filterType(nil, t.Y)
}
case *ast.ArrayType:
r.filterType(nil, t.Elt)
case *ast.StructType:
if r.filterFieldList(parent, t.Fields, nil) {
t.Incomplete = true
}
case *ast.FuncType:
r.filterParamList(t.TypeParams)
r.filterParamList(t.Params)
r.filterParamList(t.Results)
case *ast.InterfaceType:
if r.filterFieldList(parent, t.Methods, t) {
t.Incomplete = true
}
case *ast.MapType:
r.filterType(nil, t.Key)
r.filterType(nil, t.Value)
case *ast.ChanType:
r.filterType(nil, t.Value)
}
}
func (r *reader) filterSpec(spec ast.Spec) bool {
switch s := spec.(type) {
case *ast.ImportSpec:
// always keep imports so we can collect them
return true
case *ast.ValueSpec:
s.Values = filterExprList(s.Values, token.IsExported, true)
if len(s.Values) > 0 || s.Type == nil && len(s.Values) == 0 {
// If there are values declared on the RHS, just replace the unexported
// identifiers on the LHS with underscore, so that the LHS matches
// the sequence of expressions on the RHS.
//
// Similarly, if there is neither a type nor values, then this spec
// must be continuing an iota expression, where order matters.
if updateIdentList(s.Names) {
r.filterType(nil, s.Type)
return true
}
} else {
s.Names = filterIdentList(s.Names)
if len(s.Names) > 0 {
r.filterType(nil, s.Type)
return true
}
}
case *ast.TypeSpec:
// Don't filter type parameters here, by analogy with function parameters
// which are not filtered for top-level function declarations.
if name := s.Name.Name; token.IsExported(name) {
r.filterType(r.lookupType(s.Name.Name), s.Type)
return true
} else if IsPredeclared(name) {
if r.shadowedPredecl == nil {
r.shadowedPredecl = make(map[string]bool)
}
r.shadowedPredecl[name] = true
}
}
return false
}
// copyConstType returns a copy of typ with position pos.
// typ must be a valid constant type.
// In practice, only (possibly qualified) identifiers are possible.
func copyConstType(typ ast.Expr, pos token.Pos) ast.Expr {
switch typ := typ.(type) {
case *ast.Ident:
return &ast.Ident{Name: typ.Name, NamePos: pos}
case *ast.SelectorExpr:
if id, ok := typ.X.(*ast.Ident); ok {
// presumably a qualified identifier
return &ast.SelectorExpr{
Sel: ast.NewIdent(typ.Sel.Name),
X: &ast.Ident{Name: id.Name, NamePos: pos},
}
}
}
return nil // shouldn't happen, but be conservative and don't panic
}
func (r *reader) filterSpecList(list []ast.Spec, tok token.Token) []ast.Spec {
if tok == token.CONST {
// Propagate any type information that would get lost otherwise
// when unexported constants are filtered.
var prevType ast.Expr
for _, spec := range list {
spec := spec.(*ast.ValueSpec)
if spec.Type == nil && len(spec.Values) == 0 && prevType != nil {
// provide current spec with an explicit type
spec.Type = copyConstType(prevType, spec.Pos())
}
if hasExportedName(spec.Names) {
// exported names are preserved so there's no need to propagate the type
prevType = nil
} else {
prevType = spec.Type
}
}
}
j := 0
for _, s := range list {
if r.filterSpec(s) {
list[j] = s
j++
}
}
return list[0:j]
}
func (r *reader) filterDecl(decl ast.Decl) bool {
switch d := decl.(type) {
case *ast.GenDecl:
d.Specs = r.filterSpecList(d.Specs, d.Tok)
return len(d.Specs) > 0
case *ast.FuncDecl:
// ok to filter these methods early because any
// conflicting method will be filtered here, too -
// thus, removing these methods early will not lead
// to the false removal of possible conflicts
return token.IsExported(d.Name.Name)
}
return false
}
// fileExports removes unexported declarations from src in place.
func (r *reader) fileExports(src *ast.File) {
j := 0
for _, d := range src.Decls {
if r.filterDecl(d) {
src.Decls[j] = d
j++
}
}
src.Decls = src.Decls[0:j]
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package doc
import "go/ast"
type Filter func(string) bool
func matchFields(fields *ast.FieldList, f Filter) bool {
if fields != nil {
for _, field := range fields.List {
for _, name := range field.Names {
if f(name.Name) {
return true
}
}
}
}
return false
}
func matchDecl(d *ast.GenDecl, f Filter) bool {
for _, d := range d.Specs {
switch v := d.(type) {
case *ast.ValueSpec:
for _, name := range v.Names {
if f(name.Name) {
return true
}
}
case *ast.TypeSpec:
if f(v.Name.Name) {
return true
}
// We don't match ordinary parameters in filterFuncs, so by analogy don't
// match type parameters here.
switch t := v.Type.(type) {
case *ast.StructType:
if matchFields(t.Fields, f) {
return true
}
case *ast.InterfaceType:
if matchFields(t.Methods, f) {
return true
}
}
}
}
return false
}
func filterValues(a []*Value, f Filter) []*Value {
w := 0
for _, vd := range a {
if matchDecl(vd.Decl, f) {
a[w] = vd
w++
}
}
return a[0:w]
}
func filterFuncs(a []*Func, f Filter) []*Func {
w := 0
for _, fd := range a {
if f(fd.Name) {
a[w] = fd
w++
}
}
return a[0:w]
}
func filterTypes(a []*Type, f Filter) []*Type {
w := 0
for _, td := range a {
n := 0 // number of matches
if matchDecl(td.Decl, f) {
n = 1
} else {
// type name doesn't match, but we may have matching consts, vars, factories or methods
td.Consts = filterValues(td.Consts, f)
td.Vars = filterValues(td.Vars, f)
td.Funcs = filterFuncs(td.Funcs, f)
td.Methods = filterFuncs(td.Methods, f)
n += len(td.Consts) + len(td.Vars) + len(td.Funcs) + len(td.Methods)
}
if n > 0 {
a[w] = td
w++
}
}
return a[0:w]
}
// Filter eliminates documentation for names that don't pass through the filter f.
// TODO(gri): Recognize "Type.Method" as a name.
func (p *Package) Filter(f Filter) {
p.Consts = filterValues(p.Consts, f)
p.Vars = filterValues(p.Vars, f)
p.Types = filterTypes(p.Types, f)
p.Funcs = filterFuncs(p.Funcs, f)
p.Doc = "" // don't show top-level package doc
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package doc
import (
"cmp"
"fmt"
"go/ast"
"go/token"
"internal/lazyregexp"
"path"
"slices"
"strconv"
"strings"
"unicode"
"unicode/utf8"
)
// ----------------------------------------------------------------------------
// function/method sets
//
// Internally, we treat functions like methods and collect them in method sets.
// A methodSet describes a set of methods. Entries where Decl == nil are conflict
// entries (more than one method with the same name at the same embedding level).
type methodSet map[string]*Func
// recvString returns a string representation of recv of the form "T", "*T",
// "T[A, ...]", "*T[A, ...]" or "BADRECV" (if not a proper receiver type).
func recvString(recv ast.Expr) string {
switch t := recv.(type) {
case *ast.Ident:
return t.Name
case *ast.StarExpr:
return "*" + recvString(t.X)
case *ast.IndexExpr:
// Generic type with one parameter.
return fmt.Sprintf("%s[%s]", recvString(t.X), recvParam(t.Index))
case *ast.IndexListExpr:
// Generic type with multiple parameters.
if len(t.Indices) > 0 {
var b strings.Builder
b.WriteString(recvString(t.X))
b.WriteByte('[')
b.WriteString(recvParam(t.Indices[0]))
for _, e := range t.Indices[1:] {
b.WriteString(", ")
b.WriteString(recvParam(e))
}
b.WriteByte(']')
return b.String()
}
}
return "BADRECV"
}
func recvParam(p ast.Expr) string {
if id, ok := p.(*ast.Ident); ok {
return id.Name
}
return "BADPARAM"
}
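// Illustrative sketch (not part of the original source): receiver strings
// produced by recvString for a few hypothetical method declarations:
//
//	func (t T) M()            -> "T"
//	func (t *T) M()           -> "*T"
//	func (l List[T]) Len()    -> "List[T]"
//	func (m *Map[K, V]) Get() -> "*Map[K, V]"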
// set creates the corresponding Func for f and adds it to mset.
// If there are multiple f's with the same name, set keeps the first
// one with documentation; conflicts are ignored. The boolean
// specifies whether to leave the AST untouched.
func (mset methodSet) set(f *ast.FuncDecl, preserveAST bool) {
name := f.Name.Name
if g := mset[name]; g != nil && g.Doc != "" {
// A function with the same name has already been registered;
// since it has documentation, assume f is simply another
// implementation and ignore it. This does not happen if the
// caller is using go/build.ScanDir to determine the list of
// files implementing a package.
return
}
// function doesn't exist or has no documentation; use f
recv := ""
if f.Recv != nil {
var typ ast.Expr
// be careful in case of incorrect ASTs
if list := f.Recv.List; len(list) == 1 {
typ = list[0].Type
}
recv = recvString(typ)
}
mset[name] = &Func{
Doc: f.Doc.Text(),
Name: name,
Decl: f,
Recv: recv,
Orig: recv,
}
if !preserveAST {
f.Doc = nil // doc consumed - remove from AST
}
}
// add adds method m to the method set; m is ignored if the method set
// already contains a method with the same name at the same or a higher
// level than m.
func (mset methodSet) add(m *Func) {
old := mset[m.Name]
if old == nil || m.Level < old.Level {
mset[m.Name] = m
return
}
if m.Level == old.Level {
// conflict - mark it using a method with nil Decl
mset[m.Name] = &Func{
Name: m.Name,
Level: m.Level,
}
}
}
// ----------------------------------------------------------------------------
// Named types
// baseTypeName returns the name of the base type of x (or "")
// and whether the type is imported or not.
func baseTypeName(x ast.Expr) (name string, imported bool) {
switch t := x.(type) {
case *ast.Ident:
return t.Name, false
case *ast.IndexExpr:
return baseTypeName(t.X)
case *ast.IndexListExpr:
return baseTypeName(t.X)
case *ast.SelectorExpr:
if _, ok := t.X.(*ast.Ident); ok {
// only possible for qualified type names;
// assume type is imported
return t.Sel.Name, true
}
case *ast.ParenExpr:
return baseTypeName(t.X)
case *ast.StarExpr:
return baseTypeName(t.X)
}
return "", false
}
// An embeddedSet describes a set of embedded types.
type embeddedSet map[*namedType]bool
// A namedType represents a named unqualified (package local, or possibly
// predeclared) type. The namedType for a type name is always found via
// reader.lookupType.
type namedType struct {
doc string // doc comment for type
name string // type name
decl *ast.GenDecl // nil if declaration hasn't been seen yet
isEmbedded bool // true if this type is embedded
isStruct bool // true if this type is a struct
embedded embeddedSet // embedded types; the set value is true if the type is embedded as a pointer
// associated declarations
values []*Value // consts and vars
funcs methodSet
methods methodSet
}
// ----------------------------------------------------------------------------
// AST reader
// reader accumulates documentation for a single package.
// It modifies the AST: Comments (declaration documentation)
// that have been collected by the reader are set to nil
// in the respective AST nodes so that they are not printed
// twice (once when printing the documentation and once when
// printing the corresponding AST node).
type reader struct {
mode Mode
// package properties
doc string // package documentation, if any
filenames []string
notes map[string][]*Note
// imports
imports map[string]int
hasDotImp bool // if set, package contains a dot import
importByName map[string]string
// declarations
values []*Value // consts and vars
order int // sort order of const and var declarations (when we can't use a name)
types map[string]*namedType
funcs methodSet
// support for package-local shadowing of predeclared types
shadowedPredecl map[string]bool
fixmap map[string][]*ast.InterfaceType
}
func (r *reader) isVisible(name string) bool {
return r.mode&AllDecls != 0 || token.IsExported(name)
}
// lookupType returns the base type with the given name.
// If the base type has not been encountered yet, a new
// type with the given name but no associated declaration
// is added to the type map.
func (r *reader) lookupType(name string) *namedType {
if name == "" || name == "_" {
return nil // no type docs for anonymous types
}
if typ, found := r.types[name]; found {
return typ
}
// type not found - add one without declaration
typ := &namedType{
name: name,
embedded: make(embeddedSet),
funcs: make(methodSet),
methods: make(methodSet),
}
r.types[name] = typ
return typ
}
// recordAnonymousField registers fieldType as the type of an
// anonymous field in the parent type. If the field is imported
// (qualified name) or the parent is nil, the field is ignored.
// The function returns the field name.
func (r *reader) recordAnonymousField(parent *namedType, fieldType ast.Expr) (fname string) {
fname, imp := baseTypeName(fieldType)
if parent == nil || imp {
return
}
if ftype := r.lookupType(fname); ftype != nil {
ftype.isEmbedded = true
_, ptr := fieldType.(*ast.StarExpr)
parent.embedded[ftype] = ptr
}
return
}
func (r *reader) readDoc(comment *ast.CommentGroup) {
// By convention there should be only one package comment,
// but collect all of them if there is more than one.
text := comment.Text()
if r.doc == "" {
r.doc = text
return
}
r.doc += "\n" + text
}
func (r *reader) remember(predecl string, typ *ast.InterfaceType) {
if r.fixmap == nil {
r.fixmap = make(map[string][]*ast.InterfaceType)
}
r.fixmap[predecl] = append(r.fixmap[predecl], typ)
}
func specNames(specs []ast.Spec) []string {
names := make([]string, 0, len(specs)) // reasonable estimate
for _, s := range specs {
// s guaranteed to be an *ast.ValueSpec by readValue
for _, ident := range s.(*ast.ValueSpec).Names {
names = append(names, ident.Name)
}
}
return names
}
// readValue processes a const or var declaration.
func (r *reader) readValue(decl *ast.GenDecl) {
// determine if decl should be associated with a type
// Heuristic: For each typed entry, determine the type name, if any.
// If there is exactly one type name that is sufficiently
// frequent, associate the decl with the respective type.
domName := ""
domFreq := 0
prev := ""
n := 0
for _, spec := range decl.Specs {
s, ok := spec.(*ast.ValueSpec)
if !ok {
continue // should not happen, but be conservative
}
name := ""
switch {
case s.Type != nil:
// a type is present; determine its name
if n, imp := baseTypeName(s.Type); !imp {
name = n
}
case decl.Tok == token.CONST && len(s.Values) == 0:
// no type or value is present but we have a constant declaration;
// use the previous type name (possibly the empty string)
name = prev
}
if name != "" {
// entry has a named type
if domName != "" && domName != name {
// more than one type name - do not associate
// with any type
domName = ""
break
}
domName = name
domFreq++
}
prev = name
n++
}
// nothing to do w/o a legal declaration
if n == 0 {
return
}
// determine values list with which to associate the Value for this decl
values := &r.values
const threshold = 0.75
if domName != "" && r.isVisible(domName) && domFreq >= int(float64(len(decl.Specs))*threshold) {
// typed entries are sufficiently frequent
if typ := r.lookupType(domName); typ != nil {
values = &typ.values // associate with that type
}
}
*values = append(*values, &Value{
Doc: decl.Doc.Text(),
Names: specNames(decl.Specs),
Decl: decl,
order: r.order,
})
if r.mode&PreserveAST == 0 {
decl.Doc = nil // doc consumed - remove from AST
}
// Note: It's important that the order used here is global because the cleanupTypes
// methods may move values associated with types back into the global list. If the
// order is list-specific, sorting is not deterministic because the same order value
// may appear multiple times (was bug, found when fixing #16153).
r.order++
}
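// Illustrative sketch (not part of the original source): the dominance
// heuristic above. In the group below every spec resolves to the type name
// Weekday (the untyped specs of a const group inherit the previous type name),
// so, assuming Weekday is an exported type declared in the package, the whole
// declaration is attached to Weekday's documentation rather than to the
// package-level value list:
//
//	const (
//		Sunday Weekday = iota
//		Monday
//		Tuesday
//	)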
// fields returns a struct's fields or an interface's methods.
func fields(typ ast.Expr) (list []*ast.Field, isStruct bool) {
var fields *ast.FieldList
switch t := typ.(type) {
case *ast.StructType:
fields = t.Fields
isStruct = true
case *ast.InterfaceType:
fields = t.Methods
}
if fields != nil {
list = fields.List
}
return
}
// readType processes a type declaration.
func (r *reader) readType(decl *ast.GenDecl, spec *ast.TypeSpec) {
typ := r.lookupType(spec.Name.Name)
if typ == nil {
return // no name or blank name - ignore the type
}
// A type should be added at most once, so typ.decl
// should be nil - if it is not, simply overwrite it.
typ.decl = decl
// compute documentation
doc := spec.Doc
if doc == nil {
// no doc associated with the spec, use the declaration doc, if any
doc = decl.Doc
}
if r.mode&PreserveAST == 0 {
spec.Doc = nil // doc consumed - remove from AST
decl.Doc = nil // doc consumed - remove from AST
}
typ.doc = doc.Text()
// record anonymous fields (they may contribute methods)
// (some fields may have been recorded already when filtering
// exports, but that's ok)
var list []*ast.Field
list, typ.isStruct = fields(spec.Type)
for _, field := range list {
if len(field.Names) == 0 {
r.recordAnonymousField(typ, field.Type)
}
}
}
// isPredeclared reports whether n denotes a predeclared type.
func (r *reader) isPredeclared(n string) bool {
return predeclaredTypes[n] && r.types[n] == nil
}
// readFunc processes a func or method declaration.
func (r *reader) readFunc(fun *ast.FuncDecl) {
// strip function body if requested.
if r.mode&PreserveAST == 0 {
fun.Body = nil
}
// associate methods with the receiver type, if any
if fun.Recv != nil {
// method
if len(fun.Recv.List) == 0 {
// should not happen (incorrect AST; see issue 17788);
// don't show this method
return
}
recvTypeName, imp := baseTypeName(fun.Recv.List[0].Type)
if imp {
// should not happen (incorrect AST);
// don't show this method
return
}
if typ := r.lookupType(recvTypeName); typ != nil {
typ.methods.set(fun, r.mode&PreserveAST != 0)
}
// otherwise ignore the method
// TODO(gri): There may be exported methods of non-exported types
// that can be called because of exported values (consts, vars, or
// function results) of that type. Could determine if that is the
// case and then show those methods in an appropriate section.
return
}
// Associate factory functions with the first visible result type, as long as
// others are predeclared types.
if fun.Type.Results.NumFields() >= 1 {
var typ *namedType // type to associate the function with
numResultTypes := 0
for _, res := range fun.Type.Results.List {
factoryType := res.Type
if t, ok := factoryType.(*ast.ArrayType); ok {
// We consider functions that return slices or arrays of type
// T (or pointers to T) as factory functions of T.
factoryType = t.Elt
}
if n, imp := baseTypeName(factoryType); !imp && r.isVisible(n) && !r.isPredeclared(n) {
if lookupTypeParam(n, fun.Type.TypeParams) != nil {
// Issue #49477: don't associate fun with its type parameter result.
// A type parameter is not a defined type.
continue
}
if t := r.lookupType(n); t != nil {
typ = t
numResultTypes++
if numResultTypes > 1 {
break
}
}
}
}
// If there is exactly one result type,
// associate the function with that type.
if numResultTypes == 1 {
typ.funcs.set(fun, r.mode&PreserveAST != 0)
return
}
}
// just an ordinary function
r.funcs.set(fun, r.mode&PreserveAST != 0)
}
// lookupTypeParam searches for type parameters named name within the tparams
// field list, returning the relevant identifier if found, or nil if not.
func lookupTypeParam(name string, tparams *ast.FieldList) *ast.Ident {
if tparams == nil {
return nil
}
for _, field := range tparams.List {
for _, id := range field.Names {
if id.Name == name {
return id
}
}
}
return nil
}
var (
noteMarker = `([A-Z][A-Z]+)\(([^)]+)\):?` // MARKER(uid), MARKER at least 2 chars, uid at least 1 char
noteMarkerRx = lazyregexp.New(`^[ \t]*` + noteMarker) // MARKER(uid) at text start
noteCommentRx = lazyregexp.New(`^/[/*][ \t]*` + noteMarker) // MARKER(uid) at comment start
)
// clean replaces each sequence of space, \r, or \t characters
// with a single space and removes any trailing and leading spaces.
func clean(s string) string {
var b []byte
p := byte(' ')
for i := 0; i < len(s); i++ {
q := s[i]
if q == '\r' || q == '\t' {
q = ' '
}
if q != ' ' || p != ' ' {
b = append(b, q)
p = q
}
}
// remove trailing blank, if any
if n := len(b); n > 0 && p == ' ' {
b = b[0 : n-1]
}
return string(b)
}
// readNote collects a single note from a sequence of comments.
func (r *reader) readNote(list []*ast.Comment) {
text := (&ast.CommentGroup{List: list}).Text()
if m := noteMarkerRx.FindStringSubmatchIndex(text); m != nil {
// The note body starts after the marker.
// We remove any formatting so that we don't
// get spurious line breaks/indentation when
// showing the TODO body.
body := clean(text[m[1]:])
if body != "" {
marker := text[m[2]:m[3]]
r.notes[marker] = append(r.notes[marker], &Note{
Pos: list[0].Pos(),
End: list[len(list)-1].End(),
UID: text[m[4]:m[5]],
Body: body,
})
}
}
}
// readNotes extracts notes from comments.
// A note must start at the beginning of a comment with "MARKER(uid):"
// and is followed by the note body (e.g., "// BUG(gri): fix this").
// The note ends at the end of the comment group or at the start of
// another note in the same comment group, whichever comes first.
func (r *reader) readNotes(comments []*ast.CommentGroup) {
for _, group := range comments {
i := -1 // comment index of most recent note start, valid if >= 0
list := group.List
for j, c := range list {
if noteCommentRx.MatchString(c.Text) {
if i >= 0 {
r.readNote(list[i:j])
}
i = j
}
}
if i >= 0 {
r.readNote(list[i:])
}
}
}
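// Illustrative sketch (not part of the original source): a comment such as
//
//	// TODO(gri): simplify this once the parser is fixed.
//
// is collected as a note under the "TODO" marker with UID "gri" and body
// "simplify this once the parser is fixed.".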
// readFile adds the AST for a source file to the reader.
func (r *reader) readFile(src *ast.File) {
// add package documentation
if src.Doc != nil {
r.readDoc(src.Doc)
if r.mode&PreserveAST == 0 {
src.Doc = nil // doc consumed - remove from AST
}
}
// add all declarations but for functions which are processed in a separate pass
for _, decl := range src.Decls {
switch d := decl.(type) {
case *ast.GenDecl:
switch d.Tok {
case token.IMPORT:
// imports are handled individually
for _, spec := range d.Specs {
if s, ok := spec.(*ast.ImportSpec); ok {
if import_, err := strconv.Unquote(s.Path.Value); err == nil {
r.imports[import_] = 1
var name string
if s.Name != nil {
name = s.Name.Name
if name == "." {
r.hasDotImp = true
}
}
if name != "." {
if name == "" {
name = assumedPackageName(import_)
}
old, ok := r.importByName[name]
if !ok {
r.importByName[name] = import_
} else if old != import_ && old != "" {
r.importByName[name] = "" // ambiguous
}
}
}
}
}
case token.CONST, token.VAR:
// constants and variables are always handled as a group
r.readValue(d)
case token.TYPE:
// types are handled individually
if len(d.Specs) == 1 && !d.Lparen.IsValid() {
// common case: single declaration w/o parentheses
// (if a single declaration is parenthesized,
// create a new fake declaration below, so that
// go/doc type declarations always appear w/o
// parentheses)
if s, ok := d.Specs[0].(*ast.TypeSpec); ok {
r.readType(d, s)
}
break
}
for _, spec := range d.Specs {
if s, ok := spec.(*ast.TypeSpec); ok {
// use an individual (possibly fake) declaration
// for each type; this also ensures that each type
// gets to (re-)use the declaration documentation
// if there's none associated with the spec itself
fake := &ast.GenDecl{
Doc: d.Doc,
// don't use the existing TokPos because it
// will lead to the wrong selection range for
// the fake declaration if there are more
// than one type in the group (this affects
// src/cmd/godoc/godoc.go's posLink_urlFunc)
TokPos: s.Pos(),
Tok: token.TYPE,
Specs: []ast.Spec{s},
}
r.readType(fake, s)
}
}
}
}
}
// collect MARKER(...): annotations
r.readNotes(src.Comments)
if r.mode&PreserveAST == 0 {
src.Comments = nil // consumed unassociated comments - remove from AST
}
}
func (r *reader) readPackage(pkg *ast.Package, mode Mode) {
// initialize reader
r.filenames = make([]string, len(pkg.Files))
r.imports = make(map[string]int)
r.mode = mode
r.types = make(map[string]*namedType)
r.funcs = make(methodSet)
r.notes = make(map[string][]*Note)
r.importByName = make(map[string]string)
// sort package files before reading them so that the
// result does not depend on map iteration order
i := 0
for filename := range pkg.Files {
r.filenames[i] = filename
i++
}
slices.Sort(r.filenames)
// process files in sorted order
for _, filename := range r.filenames {
f := pkg.Files[filename]
if mode&AllDecls == 0 {
r.fileExports(f)
}
r.readFile(f)
}
for name, path := range r.importByName {
if path == "" {
delete(r.importByName, name)
}
}
// process functions now that we have better type information
for _, f := range pkg.Files {
for _, decl := range f.Decls {
if d, ok := decl.(*ast.FuncDecl); ok {
r.readFunc(d)
}
}
}
}
// ----------------------------------------------------------------------------
// Types
func customizeRecv(f *Func, recvTypeName string, embeddedIsPtr bool, level int) *Func {
if f == nil || f.Decl == nil || f.Decl.Recv == nil || len(f.Decl.Recv.List) != 1 {
return f // shouldn't happen, but be safe
}
// copy existing receiver field and set new type
newField := *f.Decl.Recv.List[0]
origPos := newField.Type.Pos()
_, origRecvIsPtr := newField.Type.(*ast.StarExpr)
newIdent := &ast.Ident{NamePos: origPos, Name: recvTypeName}
var typ ast.Expr = newIdent
if !embeddedIsPtr && origRecvIsPtr {
newIdent.NamePos++ // '*' is one character
typ = &ast.StarExpr{Star: origPos, X: newIdent}
}
newField.Type = typ
// copy existing receiver field list and set new receiver field
newFieldList := *f.Decl.Recv
newFieldList.List = []*ast.Field{&newField}
// copy existing function declaration and set new receiver field list
newFuncDecl := *f.Decl
newFuncDecl.Recv = &newFieldList
// copy existing function documentation and set new declaration
newF := *f
newF.Decl = &newFuncDecl
newF.Recv = recvString(typ)
// the Orig field never changes
newF.Level = level
return &newF
}
// collectEmbeddedMethods collects the embedded methods of typ in mset.
func (r *reader) collectEmbeddedMethods(mset methodSet, typ *namedType, recvTypeName string, embeddedIsPtr bool, level int, visited embeddedSet) {
visited[typ] = true
for embedded, isPtr := range typ.embedded {
// Once an embedded type is embedded as a pointer type
// all embedded types in those types are treated like
// pointer types for the purpose of the receiver type
// computation; i.e., embeddedIsPtr is sticky for this
// embedding hierarchy.
thisEmbeddedIsPtr := embeddedIsPtr || isPtr
for _, m := range embedded.methods {
// only top-level methods are embedded
if m.Level == 0 {
mset.add(customizeRecv(m, recvTypeName, thisEmbeddedIsPtr, level))
}
}
if !visited[embedded] {
r.collectEmbeddedMethods(mset, embedded, recvTypeName, thisEmbeddedIsPtr, level+1, visited)
}
}
delete(visited, typ)
}
// computeMethodSets determines the actual method sets for each type encountered.
func (r *reader) computeMethodSets() {
for _, t := range r.types {
// collect embedded methods for t
if t.isStruct {
// struct
r.collectEmbeddedMethods(t.methods, t, t.name, false, 1, make(embeddedSet))
} else {
// interface
// TODO(gri) fix this
}
}
// For any predeclared names that are declared locally, don't treat them as
// exported fields anymore.
for predecl := range r.shadowedPredecl {
for _, ityp := range r.fixmap[predecl] {
removeAnonymousField(predecl, ityp)
}
}
}
// cleanupTypes removes the association of functions and methods with
// types that have no declaration. Instead, these functions and methods
// are shown at the package level. It also removes types with missing
// declarations or which are not visible.
func (r *reader) cleanupTypes() {
for _, t := range r.types {
visible := r.isVisible(t.name)
predeclared := predeclaredTypes[t.name]
if t.decl == nil && (predeclared || visible && (t.isEmbedded || r.hasDotImp)) {
// t.name is a predeclared type (and was not redeclared in this package),
// or it was embedded somewhere but its declaration is missing (because
// the AST is incomplete), or we have a dot-import (and all bets are off):
// move any associated values, funcs, and methods back to the top-level so
// that they are not lost.
// 1) move values
r.values = append(r.values, t.values...)
// 2) move factory functions
for name, f := range t.funcs {
// in a correct AST, package-level function names
// are all different - no need to check for conflicts
r.funcs[name] = f
}
// 3) move methods
if !predeclared {
for name, m := range t.methods {
// don't overwrite functions with the same name - drop them
if _, found := r.funcs[name]; !found {
r.funcs[name] = m
}
}
}
}
// remove types w/o declaration or which are not visible
if t.decl == nil || !visible {
delete(r.types, t.name)
}
}
}
// ----------------------------------------------------------------------------
// Sorting
func sortedKeys(m map[string]int) []string {
list := make([]string, len(m))
i := 0
for key := range m {
list[i] = key
i++
}
slices.Sort(list)
return list
}
// sortingName returns the name to use when sorting d into place.
func sortingName(d *ast.GenDecl) string {
if len(d.Specs) == 1 {
if s, ok := d.Specs[0].(*ast.ValueSpec); ok {
return s.Names[0].Name
}
}
return ""
}
func sortedValues(m []*Value, tok token.Token) []*Value {
list := make([]*Value, len(m)) // big enough in any case
i := 0
for _, val := range m {
if val.Decl.Tok == tok {
list[i] = val
i++
}
}
list = list[0:i]
slices.SortFunc(list, func(a, b *Value) int {
r := strings.Compare(sortingName(a.Decl), sortingName(b.Decl))
if r != 0 {
return r
}
return cmp.Compare(a.order, b.order)
})
return list
}
func sortedTypes(m map[string]*namedType, allMethods bool) []*Type {
list := make([]*Type, len(m))
i := 0
for _, t := range m {
list[i] = &Type{
Doc: t.doc,
Name: t.name,
Decl: t.decl,
Consts: sortedValues(t.values, token.CONST),
Vars: sortedValues(t.values, token.VAR),
Funcs: sortedFuncs(t.funcs, true),
Methods: sortedFuncs(t.methods, allMethods),
}
i++
}
slices.SortFunc(list, func(a, b *Type) int {
return strings.Compare(a.Name, b.Name)
})
return list
}
func removeStar(s string) string {
if len(s) > 0 && s[0] == '*' {
return s[1:]
}
return s
}
func sortedFuncs(m methodSet, allMethods bool) []*Func {
list := make([]*Func, len(m))
i := 0
for _, m := range m {
// determine which methods to include
switch {
case m.Decl == nil:
// exclude conflict entry
case allMethods, m.Level == 0, !token.IsExported(removeStar(m.Orig)):
// forced inclusion, method not embedded, or method
// embedded but original receiver type not exported
list[i] = m
i++
}
}
list = list[0:i]
slices.SortFunc(list, func(a, b *Func) int {
return strings.Compare(a.Name, b.Name)
})
return list
}
// noteBodies returns a list of note body strings given a list of notes.
// This is only used to populate the deprecated Package.Bugs field.
func noteBodies(notes []*Note) []string {
var list []string
for _, n := range notes {
list = append(list, n.Body)
}
return list
}
// ----------------------------------------------------------------------------
// Predeclared identifiers
// IsPredeclared reports whether s is a predeclared identifier.
func IsPredeclared(s string) bool {
return predeclaredTypes[s] || predeclaredFuncs[s] || predeclaredConstants[s]
}
var predeclaredTypes = map[string]bool{
"any": true,
"bool": true,
"byte": true,
"comparable": true,
"complex64": true,
"complex128": true,
"error": true,
"float32": true,
"float64": true,
"int": true,
"int8": true,
"int16": true,
"int32": true,
"int64": true,
"rune": true,
"string": true,
"uint": true,
"uint8": true,
"uint16": true,
"uint32": true,
"uint64": true,
"uintptr": true,
}
var predeclaredFuncs = map[string]bool{
"append": true,
"cap": true,
"clear": true,
"close": true,
"complex": true,
"copy": true,
"delete": true,
"imag": true,
"len": true,
"make": true,
"max": true,
"min": true,
"new": true,
"panic": true,
"print": true,
"println": true,
"real": true,
"recover": true,
}
var predeclaredConstants = map[string]bool{
"false": true,
"iota": true,
"nil": true,
"true": true,
}
// assumedPackageName returns the assumed package name
// for a given import path. This is a copy of
// golang.org/x/tools/internal/imports.ImportPathToAssumedName.
func assumedPackageName(importPath string) string {
notIdentifier := func(ch rune) bool {
return !('a' <= ch && ch <= 'z' || 'A' <= ch && ch <= 'Z' ||
'0' <= ch && ch <= '9' ||
ch == '_' ||
ch >= utf8.RuneSelf && (unicode.IsLetter(ch) || unicode.IsDigit(ch)))
}
base := path.Base(importPath)
if strings.HasPrefix(base, "v") {
if _, err := strconv.Atoi(base[1:]); err == nil {
dir := path.Dir(importPath)
if dir != "." {
base = path.Base(dir)
}
}
}
base = strings.TrimPrefix(base, "go-")
if i := strings.IndexFunc(base, notIdentifier); i >= 0 {
base = base[:i]
}
return base
}
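// assumedPackageNameSketch is an illustrative sketch (not part of the original
// source) of assumedPackageName's heuristics on a few hypothetical import
// paths: a "go-" prefix is dropped, a major-version suffix is skipped in favor
// of the parent directory, and the name is cut at the first non-identifier
// character.
func assumedPackageNameSketch() []string {
	return []string{
		assumedPackageName("github.com/user/go-yaml"), // "yaml"
		assumedPackageName("example.com/mod/v2"),      // "mod"
		assumedPackageName("example.com/foo.bar"),     // "foo"
	}
}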
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package doc
import (
"go/doc/comment"
"strings"
"unicode"
)
// firstSentence returns the first sentence in s.
// The sentence ends after the first period followed by space and
// not preceded by exactly one uppercase letter.
func firstSentence(s string) string {
var ppp, pp, p rune
for i, q := range s {
if q == '\n' || q == '\r' || q == '\t' {
q = ' '
}
if q == ' ' && p == '.' && (!unicode.IsUpper(pp) || unicode.IsUpper(ppp)) {
return s[:i]
}
if p == '。' || p == '．' {
return s[:i]
}
ppp, pp, p = pp, p, q
}
return s
}
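// Illustrative sketch (not part of the original source): where firstSentence
// cuts, given the rule above. A period preceded by exactly one uppercase
// letter is treated as an initial, not a sentence end:
//
//	firstSentence("Package doc extracts documentation. It works on ASTs.")
//		== "Package doc extracts documentation."
//	firstSentence("Written by W. Shakespeare. Edited later.")
//		== "Written by W. Shakespeare."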
// Synopsis returns a cleaned version of the first sentence in text.
//
// Deprecated: New programs should use [Package.Synopsis] instead,
// which handles links in text properly.
func Synopsis(text string) string {
var p Package
return p.Synopsis(text)
}
// IllegalPrefixes is a list of lower-case prefixes that identify
// a comment as not being a doc comment.
// This helps to avoid misinterpreting the common mistake
// of a copyright notice immediately before a package statement
// as being a doc comment.
var IllegalPrefixes = []string{
"copyright",
"all rights",
"author",
}
// Synopsis returns a cleaned version of the first sentence in text.
// That sentence ends after the first period followed by space and not
// preceded by exactly one uppercase letter, or at the first paragraph break.
// The result string has no \n, \r, or \t characters and uses only single
// spaces between words. If text starts with any of the [IllegalPrefixes],
// the result is the empty string.
func (p *Package) Synopsis(text string) string {
text = firstSentence(text)
lower := strings.ToLower(text)
for _, prefix := range IllegalPrefixes {
if strings.HasPrefix(lower, prefix) {
return ""
}
}
pr := p.Printer()
pr.TextWidth = -1
d := p.Parser().Parse(text)
if len(d.Content) == 0 {
return ""
}
if _, ok := d.Content[0].(*comment.Paragraph); !ok {
return ""
}
d.Content = d.Content[:1] // might be blank lines, code blocks, etc in “first sentence”
return strings.TrimSpace(string(pr.Text(d)))
}
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package format implements standard formatting of Go source.
//
// Note that formatting of Go source code changes over time, so tools relying on
// consistent formatting should execute a specific version of the gofmt binary
// instead of using this package. That way, the formatting will be stable, and
// the tools won't need to be recompiled each time gofmt changes.
//
// For example, pre-submit checks that use this package directly would behave
// differently depending on what Go version each developer uses, causing the
// check to be inherently fragile.
package format
import (
"bytes"
"fmt"
"go/ast"
"go/parser"
"go/printer"
"go/token"
"io"
)
// Keep these in sync with cmd/gofmt/gofmt.go.
const (
tabWidth = 8
printerMode = printer.UseSpaces | printer.TabIndent | printerNormalizeNumbers
// printerNormalizeNumbers means to canonicalize number literal prefixes
// and exponents while printing. See https://golang.org/doc/go1.13#gofmt.
//
// This value is defined in go/printer specifically for go/format and cmd/gofmt.
printerNormalizeNumbers = 1 << 30
)
var config = printer.Config{Mode: printerMode, Tabwidth: tabWidth}
const parserMode = parser.ParseComments | parser.SkipObjectResolution
// Node formats node in canonical gofmt style and writes the result to dst.
//
// The node type must be *[ast.File], *[printer.CommentedNode], [][ast.Decl],
// [][ast.Stmt], or assignment-compatible to [ast.Expr], [ast.Decl], [ast.Spec],
// or [ast.Stmt]. Node does not modify node. Imports are not sorted for
// nodes representing partial source files (for instance, if the node is
// not an *[ast.File] or a *[printer.CommentedNode] not wrapping an *[ast.File]).
//
// The function may return early (before the entire result is written)
// and return a formatting error, for instance due to an incorrect AST.
func Node(dst io.Writer, fset *token.FileSet, node any) error {
// Determine if we have a complete source file (file != nil).
var file *ast.File
var cnode *printer.CommentedNode
switch n := node.(type) {
case *ast.File:
file = n
case *printer.CommentedNode:
if f, ok := n.Node.(*ast.File); ok {
file = f
cnode = n
}
}
// Sort imports if necessary.
if file != nil && hasUnsortedImports(file) {
// Make a copy of the AST because ast.SortImports is destructive.
// TODO(gri) Do this more efficiently.
var buf bytes.Buffer
err := config.Fprint(&buf, fset, file)
if err != nil {
return err
}
file, err = parser.ParseFile(fset, "", buf.Bytes(), parserMode)
if err != nil {
// We should never get here. If we do, provide a good diagnostic.
return fmt.Errorf("format.Node internal error (%s)", err)
}
ast.SortImports(fset, file)
// Use new file with sorted imports.
node = file
if cnode != nil {
node = &printer.CommentedNode{Node: file, Comments: cnode.Comments}
}
}
return config.Fprint(dst, fset, node)
}
// Source formats src in canonical gofmt style and returns the result
// or an (I/O or syntax) error. src is expected to be a syntactically
// correct Go source file, or a list of Go declarations or statements.
//
// If src is a partial source file, the leading and trailing space of src
// is applied to the result (such that it has the same leading and trailing
// space as src), and the result is indented by the same amount as the first
// line of src containing code. Imports are not sorted for partial source files.
func Source(src []byte) ([]byte, error) {
fset := token.NewFileSet()
file, sourceAdj, indentAdj, err := parse(fset, "", src, true)
if err != nil {
return nil, err
}
if sourceAdj == nil {
// Complete source file.
// TODO(gri) consider doing this always.
ast.SortImports(fset, file)
}
return format(fset, file, sourceAdj, indentAdj, src, config)
}
func hasUnsortedImports(file *ast.File) bool {
for _, d := range file.Decls {
d, ok := d.(*ast.GenDecl)
if !ok || d.Tok != token.IMPORT {
// Not an import declaration, so we're done.
// Imports are always first.
return false
}
if d.Lparen.IsValid() {
// For now assume all grouped imports are unsorted.
// TODO(gri) Should check if they are sorted already.
return true
}
// Ungrouped imports are sorted by default.
}
return false
}
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// TODO(gri): This file and the file src/cmd/gofmt/internal.go are
// the same (but for this comment and the package name). Do not modify
// one without the other. Determine if we can factor out functionality
// in a public API. See also #11844 for context.
package format
import (
"bytes"
"go/ast"
"go/parser"
"go/printer"
"go/token"
"strings"
)
// parse parses src, which was read from the named file,
// as a Go source file, declaration, or statement list.
func parse(fset *token.FileSet, filename string, src []byte, fragmentOk bool) (
file *ast.File,
sourceAdj func(src []byte, indent int) []byte,
indentAdj int,
err error,
) {
// Try as whole source file.
file, err = parser.ParseFile(fset, filename, src, parserMode)
// If there's no error, return. If the error is that the source file didn't begin with a
// package line and source fragments are ok, fall through to
// try as a source fragment. Stop and return on any other error.
if err == nil || !fragmentOk || !strings.Contains(err.Error(), "expected 'package'") {
return
}
// If this is a declaration list, make it a source file
// by inserting a package clause.
// Insert using a ';', not a newline, so that the line numbers
// in psrc match the ones in src.
psrc := append([]byte("package p;"), src...)
file, err = parser.ParseFile(fset, filename, psrc, parserMode)
if err == nil {
sourceAdj = func(src []byte, indent int) []byte {
// Remove the package clause.
// Gofmt has turned the ';' into a '\n'.
src = src[indent+len("package p\n"):]
return bytes.TrimSpace(src)
}
return
}
// If the error is that the source file didn't begin with a
// declaration, fall through to try as a statement list.
// Stop and return on any other error.
if !strings.Contains(err.Error(), "expected declaration") {
return
}
// If this is a statement list, make it a source file
// by inserting a package clause and turning the list
// into a function body. This handles expressions too.
// Insert using a ';', not a newline, so that the line numbers
// in fsrc match the ones in src. Add an extra '\n' before the '}'
// to make sure comments are flushed before the '}'.
fsrc := append(append([]byte("package p; func _() {"), src...), '\n', '\n', '}')
file, err = parser.ParseFile(fset, filename, fsrc, parserMode)
if err == nil {
sourceAdj = func(src []byte, indent int) []byte {
// Cap adjusted indent to zero.
if indent < 0 {
indent = 0
}
// Remove the wrapping.
// Gofmt has turned the "; " into a "\n\n".
// There will be two non-blank lines with indent, hence 2*indent.
src = src[2*indent+len("package p\n\nfunc _() {"):]
// Remove only the "}\n" suffix: remaining whitespaces will be trimmed anyway
src = src[:len(src)-len("}\n")]
return bytes.TrimSpace(src)
}
// Gofmt has also indented the function body one level.
// Adjust that with indentAdj.
indentAdj = -1
}
// Succeeded, or out of options.
return
}
// format formats the given package file originally obtained from src
// and adjusts the result based on the original source via sourceAdj
// and indentAdj.
func format(
fset *token.FileSet,
file *ast.File,
sourceAdj func(src []byte, indent int) []byte,
indentAdj int,
src []byte,
cfg printer.Config,
) ([]byte, error) {
if sourceAdj == nil {
// Complete source file.
var buf bytes.Buffer
err := cfg.Fprint(&buf, fset, file)
if err != nil {
return nil, err
}
return buf.Bytes(), nil
}
// Partial source file.
// Determine and prepend leading space.
i, j := 0, 0
for j < len(src) && isSpace(src[j]) {
if src[j] == '\n' {
i = j + 1 // byte offset of last line in leading space
}
j++
}
var res []byte
res = append(res, src[:i]...)
// Determine and prepend indentation of first code line.
// Spaces are ignored unless there are no tabs,
// in which case spaces count as one tab.
indent := 0
hasSpace := false
for _, b := range src[i:j] {
switch b {
case ' ':
hasSpace = true
case '\t':
indent++
}
}
if indent == 0 && hasSpace {
indent = 1
}
for i := 0; i < indent; i++ {
res = append(res, '\t')
}
// Format the source.
// Write it without any leading and trailing space.
cfg.Indent = indent + indentAdj
var buf bytes.Buffer
err := cfg.Fprint(&buf, fset, file)
if err != nil {
return nil, err
}
out := sourceAdj(buf.Bytes(), cfg.Indent)
// If the adjusted output is empty, the source
// was empty but (possibly) for white space.
// The result is the incoming source.
if len(out) == 0 {
return src, nil
}
// Otherwise, append output to leading space.
res = append(res, out...)
// Determine and append trailing space.
i = len(src)
for i > 0 && isSpace(src[i-1]) {
i--
}
return append(res, src[i:]...), nil
}
// isSpace reports whether the byte is a space character.
// isSpace defines a space as being among the following bytes: ' ', '\t', '\n' and '\r'.
func isSpace(b byte) bool {
return b == ' ' || b == '\t' || b == '\n' || b == '\r'
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file contains the exported entry points for invoking the parser.
package parser
import (
"bytes"
"errors"
"go/ast"
"go/token"
"io"
"io/fs"
"os"
"path/filepath"
"strings"
)
// If src != nil, readSource converts src to a []byte if possible;
// otherwise it returns an error. If src == nil, readSource returns
// the result of reading the file specified by filename.
func readSource(filename string, src any) ([]byte, error) {
if src != nil {
switch s := src.(type) {
case string:
return []byte(s), nil
case []byte:
return s, nil
case *bytes.Buffer:
// is io.Reader, but src is already available in []byte form
if s != nil {
return s.Bytes(), nil
}
case io.Reader:
return io.ReadAll(s)
}
return nil, errors.New("invalid source")
}
return os.ReadFile(filename)
}
// A Mode value is a set of flags (or 0).
// They control the amount of source code parsed and other optional
// parser functionality.
type Mode uint
const (
PackageClauseOnly Mode = 1 << iota // stop parsing after package clause
ImportsOnly // stop parsing after import declarations
ParseComments // parse comments and add them to AST
Trace // print a trace of parsed productions
DeclarationErrors // report declaration errors
SpuriousErrors // same as AllErrors, for backward-compatibility
SkipObjectResolution // skip deprecated identifier resolution; see ParseFile
AllErrors = SpuriousErrors // report all errors (not just the first 10 on different lines)
)
// ParseFile parses the source code of a single Go source file and returns
// the corresponding [ast.File] node. The source code may be provided via
// the filename of the source file, or via the src parameter.
//
// If src != nil, ParseFile parses the source from src and the filename is
// only used when recording position information. The type of the argument
// for the src parameter must be string, []byte, or [io.Reader].
// If src == nil, ParseFile parses the file specified by filename.
//
// The mode parameter controls the amount of source text parsed and
// other optional parser functionality. If the [SkipObjectResolution]
// mode bit is set (recommended), the object resolution phase of
// parsing will be skipped, causing File.Scope, File.Unresolved, and
// all Ident.Obj fields to be nil. Those fields are deprecated; see
// [ast.Object] for details.
//
// Position information is recorded in the file set fset, which must not be
// nil.
//
// If the source couldn't be read, the returned AST is nil and the error
// indicates the specific failure. If the source was read but syntax
// errors were found, the result is a partial AST (with [ast.Bad]* nodes
// representing the fragments of erroneous source code). Multiple errors
// are returned via a scanner.ErrorList which is sorted by source position.
func ParseFile(fset *token.FileSet, filename string, src any, mode Mode) (f *ast.File, err error) {
if fset == nil {
panic("parser.ParseFile: no token.FileSet provided (fset == nil)")
}
// get source
text, err := readSource(filename, src)
if err != nil {
return nil, err
}
file := fset.AddFile(filename, -1, len(text))
var p parser
defer func() {
if e := recover(); e != nil {
// resume same panic if it's not a bailout
bail, ok := e.(bailout)
if !ok {
panic(e)
} else if bail.msg != "" {
p.errors.Add(p.file.Position(bail.pos), bail.msg)
}
}
// set result values
if f == nil {
// source is not a valid Go source file - satisfy
// ParseFile API and return a valid but empty
// *ast.File
f = &ast.File{
Name: new(ast.Ident),
Scope: ast.NewScope(nil),
}
}
// Ensure the start/end are consistent,
// whether parsing succeeded or not.
f.FileStart = token.Pos(file.Base())
f.FileEnd = token.Pos(file.Base() + file.Size())
p.errors.Sort()
err = p.errors.Err()
}()
// parse source
p.init(file, text, mode)
f = p.parseFile()
return
}
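// exampleParseAndListImports is an illustrative sketch, not part of the
// original source (the function name and source text are hypothetical).
// It shows the typical use of ParseFile: parse Go source held in a string,
// with SkipObjectResolution set as recommended above, and inspect the
// resulting *ast.File.
func exampleParseAndListImports() ([]string, error) {
	const src = `package main

import "fmt"

func main() { fmt.Println("hello") }
`
	fset := token.NewFileSet()
	f, err := ParseFile(fset, "example.go", src, ParseComments|SkipObjectResolution)
	if err != nil {
		return nil, err
	}
	var paths []string
	for _, imp := range f.Imports {
		paths = append(paths, imp.Path.Value) // quoted import paths, e.g. "\"fmt\""
	}
	return paths, nil
}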
// ParseDir calls [ParseFile] for all files with names ending in ".go" in the
// directory specified by path and returns a map of package name -> package
// AST with all the packages found.
//
// If filter != nil, only the files with [fs.FileInfo] entries passing through
// the filter (and ending in ".go") are considered. The mode bits are passed
// to [ParseFile] unchanged. Position information is recorded in fset, which
// must not be nil.
//
// If the directory couldn't be read, a nil map and the respective error are
// returned. If a parse error occurred, a non-nil but incomplete map and the
// first error encountered are returned.
//
// Deprecated: ParseDir does not consider build tags when associating
// files with packages. For precise information about the relationship
// between packages and files, use golang.org/x/tools/go/packages,
// which can also optionally parse and type-check the files too.
func ParseDir(fset *token.FileSet, path string, filter func(fs.FileInfo) bool, mode Mode) (pkgs map[string]*ast.Package, first error) {
list, err := os.ReadDir(path)
if err != nil {
return nil, err
}
pkgs = make(map[string]*ast.Package)
for _, d := range list {
if d.IsDir() || !strings.HasSuffix(d.Name(), ".go") {
continue
}
if filter != nil {
info, err := d.Info()
if err != nil {
return nil, err
}
if !filter(info) {
continue
}
}
filename := filepath.Join(path, d.Name())
if src, err := ParseFile(fset, filename, nil, mode); err == nil {
name := src.Name.Name
pkg, found := pkgs[name]
if !found {
pkg = &ast.Package{
Name: name,
Files: make(map[string]*ast.File),
}
pkgs[name] = pkg
}
pkg.Files[filename] = src
} else if first == nil {
first = err
}
}
return
}
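// exampleParseDirSkippingTests is an illustrative sketch, not part of the
// original source (the function name and the "testdata" directory are
// hypothetical). It shows how the filter argument can exclude _test.go
// files from the (deprecated) directory-level parse.
func exampleParseDirSkippingTests() (map[string]*ast.Package, error) {
	fset := token.NewFileSet()
	notTest := func(fi fs.FileInfo) bool {
		return !strings.HasSuffix(fi.Name(), "_test.go")
	}
	return ParseDir(fset, "testdata", notTest, ParseComments)
}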
// ParseExprFrom is a convenience function for parsing an expression.
// The arguments have the same meaning as for [ParseFile], but the source must
// be a valid Go (type or value) expression. Specifically, fset must not
// be nil.
//
// If the source couldn't be read, the returned AST is nil and the error
// indicates the specific failure. If the source was read but syntax
// errors were found, the result is a partial AST (with [ast.Bad]* nodes
// representing the fragments of erroneous source code). Multiple errors
// are returned via a scanner.ErrorList which is sorted by source position.
func ParseExprFrom(fset *token.FileSet, filename string, src any, mode Mode) (expr ast.Expr, err error) {
if fset == nil {
panic("parser.ParseExprFrom: no token.FileSet provided (fset == nil)")
}
// get source
text, err := readSource(filename, src)
if err != nil {
return nil, err
}
var p parser
defer func() {
if e := recover(); e != nil {
// resume same panic if it's not a bailout
bail, ok := e.(bailout)
if !ok {
panic(e)
} else if bail.msg != "" {
p.errors.Add(p.file.Position(bail.pos), bail.msg)
}
}
p.errors.Sort()
err = p.errors.Err()
}()
// parse expr
file := fset.AddFile(filename, -1, len(text))
p.init(file, text, mode)
expr = p.parseRhs()
// If a semicolon was inserted, consume it;
// report an error if there are more tokens.
if p.tok == token.SEMICOLON && p.lit == "\n" {
p.next()
}
p.expect(token.EOF)
return
}
// ParseExpr is a convenience function for obtaining the AST of an expression x.
// The position information recorded in the AST is undefined. The filename used
// in error messages is the empty string.
//
// If syntax errors were found, the result is a partial AST (with [ast.Bad]* nodes
// representing the fragments of erroneous source code). Multiple errors are
// returned via a scanner.ErrorList which is sorted by source position.
func ParseExpr(x string) (ast.Expr, error) {
return ParseExprFrom(token.NewFileSet(), "", []byte(x), 0)
}
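// exampleParseExpr is an illustrative sketch, not part of the original source
// (the function name is hypothetical). It shows ParseExpr on a small
// arithmetic expression and a type assertion on the resulting ast.Expr.
func exampleParseExpr() (*ast.BinaryExpr, error) {
	x, err := ParseExpr("1 + 2*3")
	if err != nil {
		return nil, err
	}
	bin, ok := x.(*ast.BinaryExpr)
	if !ok {
		return nil, errors.New("expected a binary expression")
	}
	return bin, nil // Op is token.ADD; Y is the nested 2*3 expression
}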
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package parser implements a parser for Go source files.
//
// The [ParseFile] function reads file input from a string, []byte, or
// io.Reader, and produces an [ast.File] representing the complete
// abstract syntax tree of the file.
//
// The [ParseExprFrom] function reads a single source-level expression and
// produces an [ast.Expr], the syntax tree of the expression.
//
// The parser accepts a larger language than is syntactically permitted by
// the Go spec, for simplicity, and for improved robustness in the presence
// of syntax errors. For instance, in method declarations, the receiver is
// treated like an ordinary parameter list and thus may contain multiple
// entries where the spec permits exactly one. Consequently, the corresponding
// field in the AST (ast.FuncDecl.Recv) is not restricted to one entry.
//
// Applications that need to parse one or more complete packages of Go
// source code may find it more convenient not to interact directly
// with the parser but instead to use the Load function in package
// [golang.org/x/tools/go/packages].
package parser
import (
"fmt"
"go/ast"
"go/build/constraint"
"go/scanner"
"go/token"
"strings"
)
// The parser structure holds the parser's internal state.
type parser struct {
file *token.File
errors scanner.ErrorList
scanner scanner.Scanner
// Tracing/debugging
mode Mode // parsing mode
trace bool // == (mode&Trace != 0)
indent int // indentation used for tracing output
// Comments
comments []*ast.CommentGroup
leadComment *ast.CommentGroup // last lead comment
lineComment *ast.CommentGroup // last line comment
top bool // in top of file (before package clause)
goVersion string // minimum Go version found in //go:build comment
// Next token
pos token.Pos // token position
tok token.Token // one token look-ahead
lit string // token literal
// Error recovery
// (used to limit the number of calls to parser.advance
// w/o making scanning progress - avoids potential endless
// loops across multiple parser functions during error recovery)
syncPos token.Pos // last synchronization position
syncCnt int // number of parser.advance calls without progress
// Non-syntactic parser control
exprLev int // < 0: in control clause, >= 0: in expression
inRhs bool // if set, the parser is parsing a rhs expression
imports []*ast.ImportSpec // list of imports
// nestLev is used to track and limit the recursion depth
// during parsing.
nestLev int
}
func (p *parser) init(file *token.File, src []byte, mode Mode) {
p.file = file
eh := func(pos token.Position, msg string) { p.errors.Add(pos, msg) }
p.scanner.Init(p.file, src, eh, scanner.ScanComments)
p.top = true
p.mode = mode
p.trace = mode&Trace != 0 // for convenience (p.trace is used frequently)
p.next()
}
// ----------------------------------------------------------------------------
// Parsing support
func (p *parser) printTrace(a ...any) {
const dots = ". . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . "
const n = len(dots)
pos := p.file.Position(p.pos)
fmt.Printf("%5d:%3d: ", pos.Line, pos.Column)
i := 2 * p.indent
for i > n {
fmt.Print(dots)
i -= n
}
// i <= n
fmt.Print(dots[0:i])
fmt.Println(a...)
}
func trace(p *parser, msg string) *parser {
p.printTrace(msg, "(")
p.indent++
return p
}
// Usage pattern: defer un(trace(p, "..."))
func un(p *parser) {
p.indent--
p.printTrace(")")
}
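// exampleTracedProduction is an illustrative sketch, not part of the original
// source (the name is hypothetical). It spells out the usage pattern noted
// above: trace prints the opening marker and increments the indent, and the
// deferred un restores the indent and prints the closing marker.
func (p *parser) exampleTracedProduction() {
	if p.trace {
		defer un(trace(p, "ExampleProduction"))
	}
	// production-specific parsing would go here
}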
// maxNestLev is the deepest we're willing to recurse during parsing
const maxNestLev int = 1e5
func incNestLev(p *parser) *parser {
p.nestLev++
if p.nestLev > maxNestLev {
p.error(p.pos, "exceeded max nesting depth")
panic(bailout{})
}
return p
}
// decNestLev is used to track nesting depth during parsing to prevent stack exhaustion.
// It is used along with incNestLev in a similar fashion to how un and trace are used.
func decNestLev(p *parser) {
p.nestLev--
}
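// exampleBoundedRecursion is an illustrative sketch, not part of the original
// source (the name is hypothetical), showing the pairing described above:
// incNestLev runs on entry and panics with a bailout past maxNestLev, and the
// deferred decNestLev unwinds the count when the production returns.
func (p *parser) exampleBoundedRecursion() {
	defer decNestLev(incNestLev(p))
	// recursive parsing work would go here
}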
// Advance to the next token.
func (p *parser) next0() {
// Because of one-token look-ahead, print the previous token
// when tracing as it provides a more readable output. The
// very first token (!p.pos.IsValid()) is not initialized
// (it is token.ILLEGAL), so don't print it.
if p.trace && p.pos.IsValid() {
s := p.tok.String()
switch {
case p.tok.IsLiteral():
p.printTrace(s, p.lit)
case p.tok.IsOperator(), p.tok.IsKeyword():
p.printTrace("\"" + s + "\"")
default:
p.printTrace(s)
}
}
for {
p.pos, p.tok, p.lit = p.scanner.Scan()
if p.tok == token.COMMENT {
if p.top && strings.HasPrefix(p.lit, "//go:build") {
if x, err := constraint.Parse(p.lit); err == nil {
p.goVersion = constraint.GoVersion(x)
}
}
if p.mode&ParseComments == 0 {
continue
}
} else {
// Found a non-comment; top of file is over.
p.top = false
}
break
}
}
// lineFor returns the line of pos, ignoring line directive adjustments.
func (p *parser) lineFor(pos token.Pos) int {
return p.file.PositionFor(pos, false).Line
}
// Consume a comment and return it and the line on which it ends.
func (p *parser) consumeComment() (comment *ast.Comment, endline int) {
// /*-style comments may end on a different line than where they start.
// Scan the comment for '\n' chars and adjust endline accordingly.
endline = p.lineFor(p.pos)
if p.lit[1] == '*' {
// don't use range here - no need to decode Unicode code points
for i := 0; i < len(p.lit); i++ {
if p.lit[i] == '\n' {
endline++
}
}
}
comment = &ast.Comment{Slash: p.pos, Text: p.lit}
p.next0()
return
}
// Consume a group of adjacent comments, add it to the parser's
// comments list, and return it together with the line at which
// the last comment in the group ends. A non-comment token or n
// empty lines terminate a comment group.
func (p *parser) consumeCommentGroup(n int) (comments *ast.CommentGroup, endline int) {
var list []*ast.Comment
endline = p.lineFor(p.pos)
for p.tok == token.COMMENT && p.lineFor(p.pos) <= endline+n {
var comment *ast.Comment
comment, endline = p.consumeComment()
list = append(list, comment)
}
// add comment group to the comments list
comments = &ast.CommentGroup{List: list}
p.comments = append(p.comments, comments)
return
}
// Advance to the next non-comment token. In the process, collect
// any comment groups encountered, and remember the last lead and
// line comments.
//
// A lead comment is a comment group that starts and ends in a
// line without any other tokens and that is followed by a non-comment
// token on the line immediately after the comment group.
//
// A line comment is a comment group that follows a non-comment
// token on the same line, and that has no tokens after it on the line
// where it ends.
//
// Lead and line comments may be considered documentation that is
// stored in the AST.
func (p *parser) next() {
p.leadComment = nil
p.lineComment = nil
prev := p.pos
p.next0()
if p.tok == token.COMMENT {
var comment *ast.CommentGroup
var endline int
if p.lineFor(p.pos) == p.lineFor(prev) {
// The comment is on same line as the previous token; it
// cannot be a lead comment but may be a line comment.
comment, endline = p.consumeCommentGroup(0)
if p.lineFor(p.pos) != endline || p.tok == token.SEMICOLON || p.tok == token.EOF {
// The next token is on a different line, thus
// the last comment group is a line comment.
p.lineComment = comment
}
}
// consume successor comments, if any
endline = -1
for p.tok == token.COMMENT {
comment, endline = p.consumeCommentGroup(1)
}
if endline+1 == p.lineFor(p.pos) {
// The next token is following on the line immediately after the
// comment group, thus the last comment group is a lead comment.
p.leadComment = comment
}
}
}
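// Illustrative sketch (not part of the original source): given the input
//
//	// doc for x
//	var x int // trailing note
//
// the group "// doc for x" is recorded as the lead comment and
// "// trailing note" as the line comment of the var declaration,
// per the rules described above.
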
// A bailout panic is raised to indicate early termination. pos and msg are
// only populated when bailing out of object resolution.
type bailout struct {
pos token.Pos
msg string
}
func (p *parser) error(pos token.Pos, msg string) {
if p.trace {
defer un(trace(p, "error: "+msg))
}
epos := p.file.Position(pos)
// If AllErrors is not set, discard errors reported on the same line
// as the last recorded error and stop parsing if there are more than
// 10 errors.
if p.mode&AllErrors == 0 {
n := len(p.errors)
if n > 0 && p.errors[n-1].Pos.Line == epos.Line {
return // discard - likely a spurious error
}
if n > 10 {
panic(bailout{})
}
}
p.errors.Add(epos, msg)
}
func (p *parser) errorExpected(pos token.Pos, msg string) {
msg = "expected " + msg
if pos == p.pos {
// the error happened at the current position;
// make the error message more specific
switch {
case p.tok == token.SEMICOLON && p.lit == "\n":
msg += ", found newline"
case p.tok.IsLiteral():
// print 123 rather than 'INT', etc.
msg += ", found " + p.lit
default:
msg += ", found '" + p.tok.String() + "'"
}
}
p.error(pos, msg)
}
func (p *parser) expect(tok token.Token) token.Pos {
pos := p.pos
if p.tok != tok {
p.errorExpected(pos, "'"+tok.String()+"'")
}
p.next() // make progress
return pos
}
// expect2 is like expect, but it returns an invalid position
// if the expected token is not found.
func (p *parser) expect2(tok token.Token) (pos token.Pos) {
if p.tok == tok {
pos = p.pos
} else {
p.errorExpected(p.pos, "'"+tok.String()+"'")
}
p.next() // make progress
return
}
// expectClosing is like expect but provides a better error message
// for the common case of a missing comma before a newline.
func (p *parser) expectClosing(tok token.Token, context string) token.Pos {
if p.tok != tok && p.tok == token.SEMICOLON && p.lit == "\n" {
p.error(p.pos, "missing ',' before newline in "+context)
p.next()
}
return p.expect(tok)
}
// expectSemi consumes a semicolon and returns the applicable line comment.
func (p *parser) expectSemi() (comment *ast.CommentGroup) {
// semicolon is optional before a closing ')' or '}'
if p.tok != token.RPAREN && p.tok != token.RBRACE {
switch p.tok {
case token.COMMA:
// permit a ',' instead of a ';' but complain
p.errorExpected(p.pos, "';'")
fallthrough
case token.SEMICOLON:
if p.lit == ";" {
// explicit semicolon
p.next()
comment = p.lineComment // use following comments
} else {
// artificial semicolon
comment = p.lineComment // use preceding comments
p.next()
}
return comment
default:
p.errorExpected(p.pos, "';'")
p.advance(stmtStart)
}
}
return nil
}
func (p *parser) atComma(context string, follow token.Token) bool {
if p.tok == token.COMMA {
return true
}
if p.tok != follow {
msg := "missing ','"
if p.tok == token.SEMICOLON && p.lit == "\n" {
msg += " before newline"
}
p.error(p.pos, msg+" in "+context)
return true // "insert" comma and continue
}
return false
}
func assert(cond bool, msg string) {
if !cond {
panic("go/parser internal error: " + msg)
}
}
// advance consumes tokens until the current token p.tok
// is in the 'to' set, or token.EOF. For error recovery.
func (p *parser) advance(to map[token.Token]bool) {
for ; p.tok != token.EOF; p.next() {
if to[p.tok] {
// Return only if parser made some progress since last
// sync or if it has not reached 10 advance calls without
// progress. Otherwise consume at least one token to
// avoid an endless parser loop (it is possible that
// both parseOperand and parseStmt call advance and
// correctly do not advance, thus the need for the
// invocation limit p.syncCnt).
if p.pos == p.syncPos && p.syncCnt < 10 {
p.syncCnt++
return
}
if p.pos > p.syncPos {
p.syncPos = p.pos
p.syncCnt = 0
return
}
// Reaching here indicates a parser bug, likely an
// incorrect token list in this function, but it only
// leads to skipping of possibly correct code if a
// previous error is present, and thus is preferred
// over a non-terminating parse.
}
}
}
var stmtStart = map[token.Token]bool{
token.BREAK: true,
token.CONST: true,
token.CONTINUE: true,
token.DEFER: true,
token.FALLTHROUGH: true,
token.FOR: true,
token.GO: true,
token.GOTO: true,
token.IF: true,
token.RETURN: true,
token.SELECT: true,
token.SWITCH: true,
token.TYPE: true,
token.VAR: true,
}
var declStart = map[token.Token]bool{
token.IMPORT: true,
token.CONST: true,
token.TYPE: true,
token.VAR: true,
}
var exprEnd = map[token.Token]bool{
token.COMMA: true,
token.COLON: true,
token.SEMICOLON: true,
token.RPAREN: true,
token.RBRACK: true,
token.RBRACE: true,
}
// ----------------------------------------------------------------------------
// Identifiers
func (p *parser) parseIdent() *ast.Ident {
pos := p.pos
name := "_"
if p.tok == token.IDENT {
name = p.lit
p.next()
} else {
p.expect(token.IDENT) // use expect() error handling
}
return &ast.Ident{NamePos: pos, Name: name}
}
func (p *parser) parseIdentList() (list []*ast.Ident) {
if p.trace {
defer un(trace(p, "IdentList"))
}
list = append(list, p.parseIdent())
for p.tok == token.COMMA {
p.next()
list = append(list, p.parseIdent())
}
return
}
// ----------------------------------------------------------------------------
// Common productions
// parseExprList parses a comma-separated list of expressions.
func (p *parser) parseExprList() (list []ast.Expr) {
if p.trace {
defer un(trace(p, "ExpressionList"))
}
list = append(list, p.parseExpr())
for p.tok == token.COMMA {
p.next()
list = append(list, p.parseExpr())
}
return
}
func (p *parser) parseList(inRhs bool) []ast.Expr {
old := p.inRhs
p.inRhs = inRhs
list := p.parseExprList()
p.inRhs = old
return list
}
// ----------------------------------------------------------------------------
// Types
func (p *parser) parseType() ast.Expr {
if p.trace {
defer un(trace(p, "Type"))
}
typ := p.tryIdentOrType()
if typ == nil {
pos := p.pos
p.errorExpected(pos, "type")
p.advance(exprEnd)
return &ast.BadExpr{From: pos, To: p.pos}
}
return typ
}
func (p *parser) parseQualifiedIdent(ident *ast.Ident) ast.Expr {
if p.trace {
defer un(trace(p, "QualifiedIdent"))
}
typ := p.parseTypeName(ident)
if p.tok == token.LBRACK {
typ = p.parseTypeInstance(typ)
}
return typ
}
// If the result is an identifier, it is not resolved.
func (p *parser) parseTypeName(ident *ast.Ident) ast.Expr {
if p.trace {
defer un(trace(p, "TypeName"))
}
if ident == nil {
ident = p.parseIdent()
}
if p.tok == token.PERIOD {
// ident is a package name
p.next()
sel := p.parseIdent()
return &ast.SelectorExpr{X: ident, Sel: sel}
}
return ident
}
// "[" has already been consumed, and lbrack is its position.
// If len != nil it is the already consumed array length.
func (p *parser) parseArrayType(lbrack token.Pos, len ast.Expr) *ast.ArrayType {
if p.trace {
defer un(trace(p, "ArrayType"))
}
if len == nil {
p.exprLev++
// always permit ellipsis for more fault-tolerant parsing
if p.tok == token.ELLIPSIS {
len = &ast.Ellipsis{Ellipsis: p.pos}
p.next()
} else if p.tok != token.RBRACK {
len = p.parseRhs()
}
p.exprLev--
}
if p.tok == token.COMMA {
// Trailing commas are accepted in type parameter
// lists but not in array type declarations.
// Accept for better error handling but complain.
p.error(p.pos, "unexpected comma; expecting ]")
p.next()
}
p.expect(token.RBRACK)
elt := p.parseType()
return &ast.ArrayType{Lbrack: lbrack, Len: len, Elt: elt}
}
func (p *parser) parseArrayFieldOrTypeInstance(x *ast.Ident) (*ast.Ident, ast.Expr) {
if p.trace {
defer un(trace(p, "ArrayFieldOrTypeInstance"))
}
lbrack := p.expect(token.LBRACK)
trailingComma := token.NoPos // if valid, the position of a trailing comma preceding the ']'
var args []ast.Expr
if p.tok != token.RBRACK {
p.exprLev++
args = append(args, p.parseRhs())
for p.tok == token.COMMA {
comma := p.pos
p.next()
if p.tok == token.RBRACK {
trailingComma = comma
break
}
args = append(args, p.parseRhs())
}
p.exprLev--
}
rbrack := p.expect(token.RBRACK)
if len(args) == 0 {
// x []E
elt := p.parseType()
return x, &ast.ArrayType{Lbrack: lbrack, Elt: elt}
}
// x [P]E or x[P]
if len(args) == 1 {
elt := p.tryIdentOrType()
if elt != nil {
// x [P]E
if trailingComma.IsValid() {
// Trailing commas are invalid in array type fields.
p.error(trailingComma, "unexpected comma; expecting ]")
}
return x, &ast.ArrayType{Lbrack: lbrack, Len: args[0], Elt: elt}
}
}
// x[P], x[P1, P2], ...
return nil, packIndexExpr(x, lbrack, args, rbrack)
}
func (p *parser) parseFieldDecl() *ast.Field {
if p.trace {
defer un(trace(p, "FieldDecl"))
}
doc := p.leadComment
var names []*ast.Ident
var typ ast.Expr
switch p.tok {
case token.IDENT:
name := p.parseIdent()
if p.tok == token.PERIOD || p.tok == token.STRING || p.tok == token.SEMICOLON || p.tok == token.RBRACE {
// embedded type
typ = name
if p.tok == token.PERIOD {
typ = p.parseQualifiedIdent(name)
}
} else {
// name1, name2, ... T
names = []*ast.Ident{name}
for p.tok == token.COMMA {
p.next()
names = append(names, p.parseIdent())
}
// Careful dance: We don't know if we have an embedded instantiated
// type T[P1, P2, ...] or a field T of array type []E or [P]E.
if len(names) == 1 && p.tok == token.LBRACK {
name, typ = p.parseArrayFieldOrTypeInstance(name)
if name == nil {
names = nil
}
} else {
// T P
typ = p.parseType()
}
}
case token.MUL:
star := p.pos
p.next()
if p.tok == token.LPAREN {
// *(T)
p.error(p.pos, "cannot parenthesize embedded type")
p.next()
typ = p.parseQualifiedIdent(nil)
// expect closing ')' but no need to complain if missing
if p.tok == token.RPAREN {
p.next()
}
} else {
// *T
typ = p.parseQualifiedIdent(nil)
}
typ = &ast.StarExpr{Star: star, X: typ}
case token.LPAREN:
p.error(p.pos, "cannot parenthesize embedded type")
p.next()
if p.tok == token.MUL {
// (*T)
star := p.pos
p.next()
typ = &ast.StarExpr{Star: star, X: p.parseQualifiedIdent(nil)}
} else {
// (T)
typ = p.parseQualifiedIdent(nil)
}
// expect closing ')' but no need to complain if missing
if p.tok == token.RPAREN {
p.next()
}
default:
pos := p.pos
p.errorExpected(pos, "field name or embedded type")
p.advance(exprEnd)
typ = &ast.BadExpr{From: pos, To: p.pos}
}
var tag *ast.BasicLit
if p.tok == token.STRING {
tag = &ast.BasicLit{ValuePos: p.pos, Kind: p.tok, Value: p.lit}
p.next()
}
comment := p.expectSemi()
field := &ast.Field{Doc: doc, Names: names, Type: typ, Tag: tag, Comment: comment}
return field
}
func (p *parser) parseStructType() *ast.StructType {
if p.trace {
defer un(trace(p, "StructType"))
}
pos := p.expect(token.STRUCT)
lbrace := p.expect(token.LBRACE)
var list []*ast.Field
for p.tok == token.IDENT || p.tok == token.MUL || p.tok == token.LPAREN {
// a field declaration cannot start with a '(' but we accept
// it here for more robust parsing and better error messages
// (parseFieldDecl will check and complain if necessary)
list = append(list, p.parseFieldDecl())
}
rbrace := p.expect(token.RBRACE)
return &ast.StructType{
Struct: pos,
Fields: &ast.FieldList{
Opening: lbrace,
List: list,
Closing: rbrace,
},
}
}
func (p *parser) parsePointerType() *ast.StarExpr {
if p.trace {
defer un(trace(p, "PointerType"))
}
star := p.expect(token.MUL)
base := p.parseType()
return &ast.StarExpr{Star: star, X: base}
}
func (p *parser) parseDotsType() *ast.Ellipsis {
if p.trace {
defer un(trace(p, "DotsType"))
}
pos := p.expect(token.ELLIPSIS)
elt := p.parseType()
return &ast.Ellipsis{Ellipsis: pos, Elt: elt}
}
type field struct {
name *ast.Ident
typ ast.Expr
}
func (p *parser) parseParamDecl(name *ast.Ident, typeSetsOK bool) (f field) {
// TODO(rFindley) refactor to be more similar to paramDeclOrNil in the syntax
// package
if p.trace {
defer un(trace(p, "ParamDecl"))
}
ptok := p.tok
if name != nil {
p.tok = token.IDENT // force token.IDENT case in switch below
} else if typeSetsOK && p.tok == token.TILDE {
// "~" ...
return field{nil, p.embeddedElem(nil)}
}
switch p.tok {
case token.IDENT:
// name
if name != nil {
f.name = name
p.tok = ptok
} else {
f.name = p.parseIdent()
}
switch p.tok {
case token.IDENT, token.MUL, token.ARROW, token.FUNC, token.CHAN, token.MAP, token.STRUCT, token.INTERFACE, token.LPAREN:
// name type
f.typ = p.parseType()
case token.LBRACK:
// name "[" type1, ..., typeN "]" or name "[" n "]" type
f.name, f.typ = p.parseArrayFieldOrTypeInstance(f.name)
case token.ELLIPSIS:
// name "..." type
f.typ = p.parseDotsType()
return // don't allow ...type "|" ...
case token.PERIOD:
// name "." ...
f.typ = p.parseQualifiedIdent(f.name)
f.name = nil
case token.TILDE:
if typeSetsOK {
f.typ = p.embeddedElem(nil)
return
}
case token.OR:
if typeSetsOK {
// name "|" typeset
f.typ = p.embeddedElem(f.name)
f.name = nil
return
}
}
case token.MUL, token.ARROW, token.FUNC, token.LBRACK, token.CHAN, token.MAP, token.STRUCT, token.INTERFACE, token.LPAREN:
// type
f.typ = p.parseType()
case token.ELLIPSIS:
// "..." type
// (always accepted)
f.typ = p.parseDotsType()
return // don't allow ...type "|" ...
default:
// TODO(rfindley): this is incorrect in the case of type parameter lists
// (should be "']'" in that case)
p.errorExpected(p.pos, "')'")
p.advance(exprEnd)
}
// [name] type "|"
if typeSetsOK && p.tok == token.OR && f.typ != nil {
f.typ = p.embeddedElem(f.typ)
}
return
}
func (p *parser) parseParameterList(name0 *ast.Ident, typ0 ast.Expr, closing token.Token, dddok bool) (params []*ast.Field) {
if p.trace {
defer un(trace(p, "ParameterList"))
}
// Type parameters are the only parameter list closed by ']'.
tparams := closing == token.RBRACK
pos0 := p.pos
if name0 != nil {
pos0 = name0.Pos()
} else if typ0 != nil {
pos0 = typ0.Pos()
}
// Note: The code below matches the corresponding code in the syntax
// parser closely. Changes must be reflected in either parser.
// For the code to match, we use the local []field list that
// corresponds to []syntax.Field. At the end, the list must be
// converted into an []*ast.Field.
var list []field
var named int // number of parameters that have an explicit name and type
var typed int // number of parameters that have an explicit type
for name0 != nil || p.tok != closing && p.tok != token.EOF {
var par field
if typ0 != nil {
if tparams {
typ0 = p.embeddedElem(typ0)
}
par = field{name0, typ0}
} else {
par = p.parseParamDecl(name0, tparams)
}
name0 = nil // 1st name was consumed if present
typ0 = nil // 1st typ was consumed if present
if par.name != nil || par.typ != nil {
list = append(list, par)
if par.name != nil && par.typ != nil {
named++
}
if par.typ != nil {
typed++
}
}
if !p.atComma("parameter list", closing) {
break
}
p.next()
}
if len(list) == 0 {
return // not uncommon
}
// distribute parameter types (len(list) > 0)
if named == 0 {
// all unnamed => found names are type names
for i := range list {
par := &list[i]
if typ := par.name; typ != nil {
par.typ = typ
par.name = nil
}
}
if tparams {
// This is the same error handling as below, adjusted for type parameters only.
// See comment below for details. (go.dev/issue/64534)
var errPos token.Pos
var msg string
if named == typed /* same as typed == 0 */ {
errPos = p.pos // position error at closing ]
msg = "missing type constraint"
} else {
errPos = pos0 // position at opening [ or first name
msg = "missing type parameter name"
if len(list) == 1 {
msg += " or invalid array length"
}
}
p.error(errPos, msg)
}
} else if named != len(list) {
// some named or we're in a type parameter list => all must be named
var errPos token.Pos // left-most error position (or invalid)
var typ ast.Expr // current type (from right to left)
for i := range list {
if par := &list[len(list)-i-1]; par.typ != nil {
typ = par.typ
if par.name == nil {
errPos = typ.Pos()
n := ast.NewIdent("_")
n.NamePos = errPos // correct position
par.name = n
}
} else if typ != nil {
par.typ = typ
} else {
// par.typ == nil && typ == nil => we only have a par.name
errPos = par.name.Pos()
par.typ = &ast.BadExpr{From: errPos, To: p.pos}
}
}
if errPos.IsValid() {
// Not all parameters are named because named != len(list).
// If named == typed, there must be parameters that have no types.
// They must be at the end of the parameter list, otherwise types
// would have been filled in by the right-to-left sweep above and
// there would be no error.
// If tparams is set, the parameter list is a type parameter list.
var msg string
if named == typed {
errPos = p.pos // position error at closing token ) or ]
if tparams {
msg = "missing type constraint"
} else {
msg = "missing parameter type"
}
} else {
if tparams {
msg = "missing type parameter name"
// go.dev/issue/60812
if len(list) == 1 {
msg += " or invalid array length"
}
} else {
msg = "missing parameter name"
}
}
p.error(errPos, msg)
}
}
// check use of ...
first := true // only report first occurrence
for i := range list {
f := &list[i]
if t, _ := f.typ.(*ast.Ellipsis); t != nil && (!dddok || i+1 < len(list)) {
if first {
first = false
if dddok {
p.error(t.Ellipsis, "can only use ... with final parameter")
} else {
p.error(t.Ellipsis, "invalid use of ...")
}
}
// use T instead of invalid ...T
// TODO(gri) would like to use `f.typ = t.Elt` but that causes problems
// with the resolver in cases of reuse of the same identifier
f.typ = &ast.BadExpr{From: t.Pos(), To: t.End()}
}
}
// Convert list to []*ast.Field.
// If list contains types only, each type gets its own ast.Field.
if named == 0 {
// parameter list consists of types only
for _, par := range list {
assert(par.typ != nil, "nil type in unnamed parameter list")
params = append(params, &ast.Field{Type: par.typ})
}
return
}
// If the parameter list consists of named parameters with types,
// collect all names with the same types into a single ast.Field.
var names []*ast.Ident
var typ ast.Expr
addParams := func() {
assert(typ != nil, "nil type in named parameter list")
field := &ast.Field{Names: names, Type: typ}
params = append(params, field)
names = nil
}
for _, par := range list {
if par.typ != typ {
if len(names) > 0 {
addParams()
}
typ = par.typ
}
names = append(names, par.name)
}
if len(names) > 0 {
addParams()
}
return
}
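// Illustrative sketch (not part of the original source): for a signature such
// as
//
//	func f(a, b int, c string)
//
// the name-only entry for a receives the same type expression as b during the
// right-to-left distribution above, so the final conversion collapses a and b
// into one *ast.Field with Names [a b] and Type int, followed by a separate
// field for c string.
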
func (p *parser) parseTypeParameters() *ast.FieldList {
if p.trace {
defer un(trace(p, "TypeParameters"))
}
lbrack := p.expect(token.LBRACK)
var list []*ast.Field
if p.tok != token.RBRACK {
list = p.parseParameterList(nil, nil, token.RBRACK, false)
}
rbrack := p.expect(token.RBRACK)
if len(list) == 0 {
p.error(rbrack, "empty type parameter list")
return nil // avoid follow-on errors
}
return &ast.FieldList{Opening: lbrack, List: list, Closing: rbrack}
}
func (p *parser) parseParameters(result bool) *ast.FieldList {
if p.trace {
defer un(trace(p, "Parameters"))
}
if !result || p.tok == token.LPAREN {
lparen := p.expect(token.LPAREN)
var list []*ast.Field
if p.tok != token.RPAREN {
list = p.parseParameterList(nil, nil, token.RPAREN, !result)
}
rparen := p.expect(token.RPAREN)
return &ast.FieldList{Opening: lparen, List: list, Closing: rparen}
}
if typ := p.tryIdentOrType(); typ != nil {
list := make([]*ast.Field, 1)
list[0] = &ast.Field{Type: typ}
return &ast.FieldList{List: list}
}
return nil
}
func (p *parser) parseFuncType() *ast.FuncType {
if p.trace {
defer un(trace(p, "FuncType"))
}
pos := p.expect(token.FUNC)
// accept type parameters for more tolerant parsing but complain
if p.tok == token.LBRACK {
tparams := p.parseTypeParameters()
if tparams != nil {
p.error(tparams.Opening, "function type must have no type parameters")
}
}
params := p.parseParameters(false)
results := p.parseParameters(true)
return &ast.FuncType{Func: pos, Params: params, Results: results}
}
func (p *parser) parseMethodSpec() *ast.Field {
if p.trace {
defer un(trace(p, "MethodSpec"))
}
doc := p.leadComment
var idents []*ast.Ident
var typ ast.Expr
x := p.parseTypeName(nil)
if ident, _ := x.(*ast.Ident); ident != nil {
switch {
case p.tok == token.LBRACK:
// generic method or embedded instantiated type
lbrack := p.pos
p.next()
p.exprLev++
x := p.parseExpr()
p.exprLev--
if name0, _ := x.(*ast.Ident); name0 != nil && p.tok != token.COMMA && p.tok != token.RBRACK {
// generic method m[T any]
//
// Interface methods do not have type parameters. We parse them for a
// better error message and improved error recovery.
_ = p.parseParameterList(name0, nil, token.RBRACK, false)
_ = p.expect(token.RBRACK)
p.error(lbrack, "interface method must have no type parameters")
// TODO(rfindley) refactor to share code with parseFuncType.
params := p.parseParameters(false)
results := p.parseParameters(true)
idents = []*ast.Ident{ident}
typ = &ast.FuncType{
Func: token.NoPos,
Params: params,
Results: results,
}
} else {
// embedded instantiated type
// TODO(rfindley) should resolve all identifiers in x.
list := []ast.Expr{x}
if p.atComma("type argument list", token.RBRACK) {
p.exprLev++
p.next()
for p.tok != token.RBRACK && p.tok != token.EOF {
list = append(list, p.parseType())
if !p.atComma("type argument list", token.RBRACK) {
break
}
p.next()
}
p.exprLev--
}
rbrack := p.expectClosing(token.RBRACK, "type argument list")
typ = packIndexExpr(ident, lbrack, list, rbrack)
}
case p.tok == token.LPAREN:
// ordinary method
// TODO(rfindley) refactor to share code with parseFuncType.
params := p.parseParameters(false)
results := p.parseParameters(true)
idents = []*ast.Ident{ident}
typ = &ast.FuncType{Func: token.NoPos, Params: params, Results: results}
default:
// embedded type
typ = x
}
} else {
// embedded, possibly instantiated type
typ = x
if p.tok == token.LBRACK {
// embedded instantiated interface
typ = p.parseTypeInstance(typ)
}
}
// Comment is added at the callsite: the field below may be joined with
// additional type specs using '|'.
// TODO(rfindley) this should be refactored.
// TODO(rfindley) add more tests for comment handling.
return &ast.Field{Doc: doc, Names: idents, Type: typ}
}
func (p *parser) embeddedElem(x ast.Expr) ast.Expr {
if p.trace {
defer un(trace(p, "EmbeddedElem"))
}
if x == nil {
x = p.embeddedTerm()
}
for p.tok == token.OR {
t := new(ast.BinaryExpr)
t.OpPos = p.pos
t.Op = token.OR
p.next()
t.X = x
t.Y = p.embeddedTerm()
x = t
}
return x
}
func (p *parser) embeddedTerm() ast.Expr {
if p.trace {
defer un(trace(p, "EmbeddedTerm"))
}
if p.tok == token.TILDE {
t := new(ast.UnaryExpr)
t.OpPos = p.pos
t.Op = token.TILDE
p.next()
t.X = p.parseType()
return t
}
t := p.tryIdentOrType()
if t == nil {
pos := p.pos
p.errorExpected(pos, "~ term or type")
p.advance(exprEnd)
return &ast.BadExpr{From: pos, To: p.pos}
}
return t
}
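// Illustrative sketch (not part of the original source): for an interface
// element such as
//
//	~int | ~string
//
// embeddedTerm parses each ~T term as an *ast.UnaryExpr with Op token.TILDE,
// and embeddedElem folds the terms left to right into *ast.BinaryExpr nodes
// with Op token.OR.
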
func (p *parser) parseInterfaceType() *ast.InterfaceType {
if p.trace {
defer un(trace(p, "InterfaceType"))
}
pos := p.expect(token.INTERFACE)
lbrace := p.expect(token.LBRACE)
var list []*ast.Field
parseElements:
for {
switch {
case p.tok == token.IDENT:
f := p.parseMethodSpec()
if f.Names == nil {
f.Type = p.embeddedElem(f.Type)
}
f.Comment = p.expectSemi()
list = append(list, f)
case p.tok == token.TILDE:
typ := p.embeddedElem(nil)
comment := p.expectSemi()
list = append(list, &ast.Field{Type: typ, Comment: comment})
default:
if t := p.tryIdentOrType(); t != nil {
typ := p.embeddedElem(t)
comment := p.expectSemi()
list = append(list, &ast.Field{Type: typ, Comment: comment})
} else {
break parseElements
}
}
}
// TODO(rfindley): the error produced here could be improved, since we could
// accept an identifier, 'type', or a '}' at this point.
rbrace := p.expect(token.RBRACE)
return &ast.InterfaceType{
Interface: pos,
Methods: &ast.FieldList{
Opening: lbrace,
List: list,
Closing: rbrace,
},
}
}
func (p *parser) parseMapType() *ast.MapType {
if p.trace {
defer un(trace(p, "MapType"))
}
pos := p.expect(token.MAP)
p.expect(token.LBRACK)
key := p.parseType()
p.expect(token.RBRACK)
value := p.parseType()
return &ast.MapType{Map: pos, Key: key, Value: value}
}
func (p *parser) parseChanType() *ast.ChanType {
if p.trace {
defer un(trace(p, "ChanType"))
}
pos := p.pos
dir := ast.SEND | ast.RECV
var arrow token.Pos
if p.tok == token.CHAN {
p.next()
if p.tok == token.ARROW {
arrow = p.pos
p.next()
dir = ast.SEND
}
} else {
arrow = p.expect(token.ARROW)
p.expect(token.CHAN)
dir = ast.RECV
}
value := p.parseType()
return &ast.ChanType{Begin: pos, Arrow: arrow, Dir: dir, Value: value}
}
func (p *parser) parseTypeInstance(typ ast.Expr) ast.Expr {
if p.trace {
defer un(trace(p, "TypeInstance"))
}
opening := p.expect(token.LBRACK)
p.exprLev++
var list []ast.Expr
for p.tok != token.RBRACK && p.tok != token.EOF {
list = append(list, p.parseType())
if !p.atComma("type argument list", token.RBRACK) {
break
}
p.next()
}
p.exprLev--
closing := p.expectClosing(token.RBRACK, "type argument list")
if len(list) == 0 {
p.errorExpected(closing, "type argument list")
return &ast.IndexExpr{
X: typ,
Lbrack: opening,
Index: &ast.BadExpr{From: opening + 1, To: closing},
Rbrack: closing,
}
}
return packIndexExpr(typ, opening, list, closing)
}
func (p *parser) tryIdentOrType() ast.Expr {
defer decNestLev(incNestLev(p))
switch p.tok {
case token.IDENT:
typ := p.parseTypeName(nil)
if p.tok == token.LBRACK {
typ = p.parseTypeInstance(typ)
}
return typ
case token.LBRACK:
lbrack := p.expect(token.LBRACK)
return p.parseArrayType(lbrack, nil)
case token.STRUCT:
return p.parseStructType()
case token.MUL:
return p.parsePointerType()
case token.FUNC:
return p.parseFuncType()
case token.INTERFACE:
return p.parseInterfaceType()
case token.MAP:
return p.parseMapType()
case token.CHAN, token.ARROW:
return p.parseChanType()
case token.LPAREN:
lparen := p.pos
p.next()
typ := p.parseType()
rparen := p.expect(token.RPAREN)
return &ast.ParenExpr{Lparen: lparen, X: typ, Rparen: rparen}
}
// no type found
return nil
}
// ----------------------------------------------------------------------------
// Blocks
func (p *parser) parseStmtList() (list []ast.Stmt) {
if p.trace {
defer un(trace(p, "StatementList"))
}
for p.tok != token.CASE && p.tok != token.DEFAULT && p.tok != token.RBRACE && p.tok != token.EOF {
list = append(list, p.parseStmt())
}
return
}
func (p *parser) parseBody() *ast.BlockStmt {
if p.trace {
defer un(trace(p, "Body"))
}
lbrace := p.expect(token.LBRACE)
list := p.parseStmtList()
rbrace := p.expect2(token.RBRACE)
return &ast.BlockStmt{Lbrace: lbrace, List: list, Rbrace: rbrace}
}
func (p *parser) parseBlockStmt() *ast.BlockStmt {
if p.trace {
defer un(trace(p, "BlockStmt"))
}
lbrace := p.expect(token.LBRACE)
list := p.parseStmtList()
rbrace := p.expect2(token.RBRACE)
return &ast.BlockStmt{Lbrace: lbrace, List: list, Rbrace: rbrace}
}
// ----------------------------------------------------------------------------
// Expressions
func (p *parser) parseFuncTypeOrLit() ast.Expr {
if p.trace {
defer un(trace(p, "FuncTypeOrLit"))
}
typ := p.parseFuncType()
if p.tok != token.LBRACE {
// function type only
return typ
}
p.exprLev++
body := p.parseBody()
p.exprLev--
return &ast.FuncLit{Type: typ, Body: body}
}
// parseOperand may return an expression or a raw type (incl. array
// types of the form [...]T). Callers must verify the result.
func (p *parser) parseOperand() ast.Expr {
if p.trace {
defer un(trace(p, "Operand"))
}
switch p.tok {
case token.IDENT:
x := p.parseIdent()
return x
case token.INT, token.FLOAT, token.IMAG, token.CHAR, token.STRING:
x := &ast.BasicLit{ValuePos: p.pos, Kind: p.tok, Value: p.lit}
p.next()
return x
case token.LPAREN:
lparen := p.pos
p.next()
p.exprLev++
x := p.parseRhs() // types may be parenthesized: (some type)
p.exprLev--
rparen := p.expect(token.RPAREN)
return &ast.ParenExpr{Lparen: lparen, X: x, Rparen: rparen}
case token.FUNC:
return p.parseFuncTypeOrLit()
}
if typ := p.tryIdentOrType(); typ != nil { // do not consume trailing type parameters
// could be type for composite literal or conversion
_, isIdent := typ.(*ast.Ident)
assert(!isIdent, "type cannot be identifier")
return typ
}
// we have an error
pos := p.pos
p.errorExpected(pos, "operand")
p.advance(stmtStart)
return &ast.BadExpr{From: pos, To: p.pos}
}
func (p *parser) parseSelector(x ast.Expr) ast.Expr {
if p.trace {
defer un(trace(p, "Selector"))
}
sel := p.parseIdent()
return &ast.SelectorExpr{X: x, Sel: sel}
}
func (p *parser) parseTypeAssertion(x ast.Expr) ast.Expr {
if p.trace {
defer un(trace(p, "TypeAssertion"))
}
lparen := p.expect(token.LPAREN)
var typ ast.Expr
if p.tok == token.TYPE {
// type switch: typ == nil
p.next()
} else {
typ = p.parseType()
}
rparen := p.expect(token.RPAREN)
return &ast.TypeAssertExpr{X: x, Type: typ, Lparen: lparen, Rparen: rparen}
}
func (p *parser) parseIndexOrSliceOrInstance(x ast.Expr) ast.Expr {
if p.trace {
defer un(trace(p, "parseIndexOrSliceOrInstance"))
}
lbrack := p.expect(token.LBRACK)
if p.tok == token.RBRACK {
// empty index, slice or instance expressions are not permitted;
// accept them for parsing tolerance, but complain
p.errorExpected(p.pos, "operand")
rbrack := p.pos
p.next()
return &ast.IndexExpr{
X: x,
Lbrack: lbrack,
Index: &ast.BadExpr{From: rbrack, To: rbrack},
Rbrack: rbrack,
}
}
p.exprLev++
const N = 3 // change the 3 to 2 to disable 3-index slices
var args []ast.Expr
var index [N]ast.Expr
var colons [N - 1]token.Pos
if p.tok != token.COLON {
// We can't know if we have an index expression or a type instantiation;
// so even if we see a (named) type we are not going to be in type context.
index[0] = p.parseRhs()
}
ncolons := 0
switch p.tok {
case token.COLON:
// slice expression
for p.tok == token.COLON && ncolons < len(colons) {
colons[ncolons] = p.pos
ncolons++
p.next()
if p.tok != token.COLON && p.tok != token.RBRACK && p.tok != token.EOF {
index[ncolons] = p.parseRhs()
}
}
case token.COMMA:
// instance expression
args = append(args, index[0])
for p.tok == token.COMMA {
p.next()
if p.tok != token.RBRACK && p.tok != token.EOF {
args = append(args, p.parseType())
}
}
}
p.exprLev--
rbrack := p.expect(token.RBRACK)
if ncolons > 0 {
// slice expression
slice3 := false
if ncolons == 2 {
slice3 = true
// Check presence of middle and final index here rather than during type-checking
// to prevent erroneous programs from passing through gofmt (was go.dev/issue/7305).
if index[1] == nil {
p.error(colons[0], "middle index required in 3-index slice")
index[1] = &ast.BadExpr{From: colons[0] + 1, To: colons[1]}
}
if index[2] == nil {
p.error(colons[1], "final index required in 3-index slice")
index[2] = &ast.BadExpr{From: colons[1] + 1, To: rbrack}
}
}
return &ast.SliceExpr{X: x, Lbrack: lbrack, Low: index[0], High: index[1], Max: index[2], Slice3: slice3, Rbrack: rbrack}
}
if len(args) == 0 {
// index expression
return &ast.IndexExpr{X: x, Lbrack: lbrack, Index: index[0], Rbrack: rbrack}
}
// instance expression
return packIndexExpr(x, lbrack, args, rbrack)
}
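// Illustrative sketch (not part of the original source): after the '[' the
// outcome is determined by what follows the first expression, e.g.
//
//	a[i]       index expression   (*ast.IndexExpr)
//	a[i:j:k]   3-index slice      (*ast.SliceExpr with Slice3 set)
//	T[U, V]    type instantiation (*ast.IndexListExpr via packIndexExpr)
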
func (p *parser) parseCallOrConversion(fun ast.Expr) *ast.CallExpr {
if p.trace {
defer un(trace(p, "CallOrConversion"))
}
lparen := p.expect(token.LPAREN)
p.exprLev++
var list []ast.Expr
var ellipsis token.Pos
for p.tok != token.RPAREN && p.tok != token.EOF && !ellipsis.IsValid() {
list = append(list, p.parseRhs()) // builtins may expect a type: make(some type, ...)
if p.tok == token.ELLIPSIS {
ellipsis = p.pos
p.next()
}
if !p.atComma("argument list", token.RPAREN) {
break
}
p.next()
}
p.exprLev--
rparen := p.expectClosing(token.RPAREN, "argument list")
return &ast.CallExpr{Fun: fun, Lparen: lparen, Args: list, Ellipsis: ellipsis, Rparen: rparen}
}
func (p *parser) parseValue() ast.Expr {
if p.trace {
defer un(trace(p, "Element"))
}
if p.tok == token.LBRACE {
return p.parseLiteralValue(nil)
}
x := p.parseExpr()
return x
}
func (p *parser) parseElement() ast.Expr {
if p.trace {
defer un(trace(p, "Element"))
}
x := p.parseValue()
if p.tok == token.COLON {
colon := p.pos
p.next()
x = &ast.KeyValueExpr{Key: x, Colon: colon, Value: p.parseValue()}
}
return x
}
func (p *parser) parseElementList() (list []ast.Expr) {
if p.trace {
defer un(trace(p, "ElementList"))
}
for p.tok != token.RBRACE && p.tok != token.EOF {
list = append(list, p.parseElement())
if !p.atComma("composite literal", token.RBRACE) {
break
}
p.next()
}
return
}
func (p *parser) parseLiteralValue(typ ast.Expr) ast.Expr {
defer decNestLev(incNestLev(p))
if p.trace {
defer un(trace(p, "LiteralValue"))
}
lbrace := p.expect(token.LBRACE)
var elts []ast.Expr
p.exprLev++
if p.tok != token.RBRACE {
elts = p.parseElementList()
}
p.exprLev--
rbrace := p.expectClosing(token.RBRACE, "composite literal")
return &ast.CompositeLit{Type: typ, Lbrace: lbrace, Elts: elts, Rbrace: rbrace}
}
func (p *parser) parsePrimaryExpr(x ast.Expr) ast.Expr {
if p.trace {
defer un(trace(p, "PrimaryExpr"))
}
if x == nil {
x = p.parseOperand()
}
// We track the nesting here rather than at the entry for the function,
// since it can iteratively produce a nested output, and we want to
// limit how deep a structure we generate.
var n int
defer func() { p.nestLev -= n }()
for n = 1; ; n++ {
incNestLev(p)
switch p.tok {
case token.PERIOD:
p.next()
switch p.tok {
case token.IDENT:
x = p.parseSelector(x)
case token.LPAREN:
x = p.parseTypeAssertion(x)
default:
pos := p.pos
p.errorExpected(pos, "selector or type assertion")
// TODO(rFindley) The check for token.RBRACE below is a targeted fix
// to error recovery sufficient to make the x/tools tests
// pass with the new parsing logic introduced for type
// parameters. Remove this once error recovery has been
// more generally reconsidered.
if p.tok != token.RBRACE {
p.next() // make progress
}
sel := &ast.Ident{NamePos: pos, Name: "_"}
x = &ast.SelectorExpr{X: x, Sel: sel}
}
case token.LBRACK:
x = p.parseIndexOrSliceOrInstance(x)
case token.LPAREN:
x = p.parseCallOrConversion(x)
case token.LBRACE:
// operand may have returned a parenthesized complit
// type; accept it but complain if we have a complit
t := ast.Unparen(x)
// determine if '{' belongs to a composite literal or a block statement
switch t.(type) {
case *ast.BadExpr, *ast.Ident, *ast.SelectorExpr:
if p.exprLev < 0 {
return x
}
// x is possibly a composite literal type
case *ast.IndexExpr, *ast.IndexListExpr:
if p.exprLev < 0 {
return x
}
// x is possibly a composite literal type
case *ast.ArrayType, *ast.StructType, *ast.MapType:
// x is a composite literal type
default:
return x
}
if t != x {
p.error(t.Pos(), "cannot parenthesize type in composite literal")
// already progressed, no need to advance
}
x = p.parseLiteralValue(x)
default:
return x
}
}
}
func (p *parser) parseUnaryExpr() ast.Expr {
defer decNestLev(incNestLev(p))
if p.trace {
defer un(trace(p, "UnaryExpr"))
}
switch p.tok {
case token.ADD, token.SUB, token.NOT, token.XOR, token.AND, token.TILDE:
pos, op := p.pos, p.tok
p.next()
x := p.parseUnaryExpr()
return &ast.UnaryExpr{OpPos: pos, Op: op, X: x}
case token.ARROW:
// channel type or receive expression
arrow := p.pos
p.next()
// If the next token is token.CHAN we still don't know if it
// is a channel type or a receive operation - we only know
// once we have found the end of the unary expression. There
// are two cases:
//
// <- type => (<-type) must be channel type
// <- expr => <-(expr) is a receive from an expression
//
// In the first case, the arrow must be re-associated with
// the channel type parsed already:
//
// <- (chan type) => (<-chan type)
// <- (chan<- type) => (<-chan (<-type))
x := p.parseUnaryExpr()
// determine which case we have
if typ, ok := x.(*ast.ChanType); ok {
// (<-type)
// re-associate position info and <-
dir := ast.SEND
for ok && dir == ast.SEND {
if typ.Dir == ast.RECV {
// error: (<-type) is (<-(<-chan T))
p.errorExpected(typ.Arrow, "'chan'")
}
arrow, typ.Begin, typ.Arrow = typ.Arrow, arrow, arrow
dir, typ.Dir = typ.Dir, ast.RECV
typ, ok = typ.Value.(*ast.ChanType)
}
if dir == ast.SEND {
p.errorExpected(arrow, "channel type")
}
return x
}
// <-(expr)
return &ast.UnaryExpr{OpPos: arrow, Op: token.ARROW, X: x}
case token.MUL:
// pointer type or unary "*" expression
pos := p.pos
p.next()
x := p.parseUnaryExpr()
return &ast.StarExpr{Star: pos, X: x}
}
return p.parsePrimaryExpr(nil)
}
func (p *parser) tokPrec() (token.Token, int) {
tok := p.tok
if p.inRhs && tok == token.ASSIGN {
tok = token.EQL
}
return tok, tok.Precedence()
}
// parseBinaryExpr parses a (possibly) binary expression.
// If x is non-nil, it is used as the left operand.
//
// TODO(rfindley): parseBinaryExpr has become overloaded. Consider refactoring.
func (p *parser) parseBinaryExpr(x ast.Expr, prec1 int) ast.Expr {
if p.trace {
defer un(trace(p, "BinaryExpr"))
}
if x == nil {
x = p.parseUnaryExpr()
}
// We track the nesting here rather than at the entry for the function,
// since it can iteratively produce a nested output, and we want to
// limit how deep a structure we generate.
var n int
defer func() { p.nestLev -= n }()
for n = 1; ; n++ {
incNestLev(p)
op, oprec := p.tokPrec()
if oprec < prec1 {
return x
}
pos := p.expect(op)
y := p.parseBinaryExpr(nil, oprec+1)
x = &ast.BinaryExpr{X: x, OpPos: pos, Op: op, Y: y}
}
}
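// Illustrative sketch (not part of the original source): because the right
// operand is parsed with oprec+1, higher-precedence operators bind tighter.
// For the input 1 + 2*3 the result is
//
//	&ast.BinaryExpr{Op: token.ADD, X: 1, Y: &ast.BinaryExpr{Op: token.MUL, X: 2, Y: 3}}
//
// with the multiplication nested inside the addition.
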
// The result may be a type or even a raw type ([...]int).
func (p *parser) parseExpr() ast.Expr {
if p.trace {
defer un(trace(p, "Expression"))
}
return p.parseBinaryExpr(nil, token.LowestPrec+1)
}
func (p *parser) parseRhs() ast.Expr {
old := p.inRhs
p.inRhs = true
x := p.parseExpr()
p.inRhs = old
return x
}
// ----------------------------------------------------------------------------
// Statements
// Parsing modes for parseSimpleStmt.
const (
basic = iota
labelOk
rangeOk
)
// parseSimpleStmt returns true as 2nd result if it parsed the assignment
// of a range clause (with mode == rangeOk). The returned statement is an
// assignment with a right-hand side that is a single unary expression of
// the form "range x". No guarantees are given for the left-hand side.
func (p *parser) parseSimpleStmt(mode int) (ast.Stmt, bool) {
if p.trace {
defer un(trace(p, "SimpleStmt"))
}
x := p.parseList(false)
switch p.tok {
case
token.DEFINE, token.ASSIGN, token.ADD_ASSIGN,
token.SUB_ASSIGN, token.MUL_ASSIGN, token.QUO_ASSIGN,
token.REM_ASSIGN, token.AND_ASSIGN, token.OR_ASSIGN,
token.XOR_ASSIGN, token.SHL_ASSIGN, token.SHR_ASSIGN, token.AND_NOT_ASSIGN:
// assignment statement, possibly part of a range clause
pos, tok := p.pos, p.tok
p.next()
var y []ast.Expr
isRange := false
if mode == rangeOk && p.tok == token.RANGE && (tok == token.DEFINE || tok == token.ASSIGN) {
pos := p.pos
p.next()
y = []ast.Expr{&ast.UnaryExpr{OpPos: pos, Op: token.RANGE, X: p.parseRhs()}}
isRange = true
} else {
y = p.parseList(true)
}
return &ast.AssignStmt{Lhs: x, TokPos: pos, Tok: tok, Rhs: y}, isRange
}
if len(x) > 1 {
p.errorExpected(x[0].Pos(), "1 expression")
// continue with first expression
}
switch p.tok {
case token.COLON:
// labeled statement
colon := p.pos
p.next()
if label, isIdent := x[0].(*ast.Ident); mode == labelOk && isIdent {
// Go spec: The scope of a label is the body of the function
// in which it is declared and excludes the body of any nested
// function.
stmt := &ast.LabeledStmt{Label: label, Colon: colon, Stmt: p.parseStmt()}
return stmt, false
}
// The label declaration typically starts at x[0].Pos(), but the label
// declaration may be erroneous due to a token after that position (and
// before the ':'). If SpuriousErrors is not set, the (only) error
// reported for the line is the illegal label error instead of the token
// before the ':' that caused the problem. Thus, use the (latest) colon
// position for error reporting.
p.error(colon, "illegal label declaration")
return &ast.BadStmt{From: x[0].Pos(), To: colon + 1}, false
case token.ARROW:
// send statement
arrow := p.pos
p.next()
y := p.parseRhs()
return &ast.SendStmt{Chan: x[0], Arrow: arrow, Value: y}, false
case token.INC, token.DEC:
// increment or decrement
s := &ast.IncDecStmt{X: x[0], TokPos: p.pos, Tok: p.tok}
p.next()
return s, false
}
// expression
return &ast.ExprStmt{X: x[0]}, false
}
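// Illustrative sketch (not part of the original source): for a range clause
// such as
//
//	for k, v := range m { ... }
//
// parseSimpleStmt(rangeOk) returns an *ast.AssignStmt with Lhs [k v], Tok
// token.DEFINE, and a single Rhs entry that is the unary expression
// "range m" (Op token.RANGE), together with true as its second result.
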
func (p *parser) parseCallExpr(callType string) *ast.CallExpr {
x := p.parseRhs() // could be a conversion: (some type)(x)
if t := ast.Unparen(x); t != x {
p.error(x.Pos(), fmt.Sprintf("expression in %s must not be parenthesized", callType))
x = t
}
if call, isCall := x.(*ast.CallExpr); isCall {
return call
}
if _, isBad := x.(*ast.BadExpr); !isBad {
// only report error if it's a new one
p.error(x.End(), fmt.Sprintf("expression in %s must be function call", callType))
}
return nil
}
func (p *parser) parseGoStmt() ast.Stmt {
if p.trace {
defer un(trace(p, "GoStmt"))
}
pos := p.expect(token.GO)
call := p.parseCallExpr("go")
p.expectSemi()
if call == nil {
return &ast.BadStmt{From: pos, To: pos + 2} // len("go")
}
return &ast.GoStmt{Go: pos, Call: call}
}
func (p *parser) parseDeferStmt() ast.Stmt {
if p.trace {
defer un(trace(p, "DeferStmt"))
}
pos := p.expect(token.DEFER)
call := p.parseCallExpr("defer")
p.expectSemi()
if call == nil {
return &ast.BadStmt{From: pos, To: pos + 5} // len("defer")
}
return &ast.DeferStmt{Defer: pos, Call: call}
}
func (p *parser) parseReturnStmt() *ast.ReturnStmt {
if p.trace {
defer un(trace(p, "ReturnStmt"))
}
pos := p.pos
p.expect(token.RETURN)
var x []ast.Expr
if p.tok != token.SEMICOLON && p.tok != token.RBRACE {
x = p.parseList(true)
}
p.expectSemi()
return &ast.ReturnStmt{Return: pos, Results: x}
}
func (p *parser) parseBranchStmt(tok token.Token) *ast.BranchStmt {
if p.trace {
defer un(trace(p, "BranchStmt"))
}
pos := p.expect(tok)
var label *ast.Ident
if tok == token.GOTO || ((tok == token.CONTINUE || tok == token.BREAK) && p.tok == token.IDENT) {
label = p.parseIdent()
}
p.expectSemi()
return &ast.BranchStmt{TokPos: pos, Tok: tok, Label: label}
}
func (p *parser) makeExpr(s ast.Stmt, want string) ast.Expr {
if s == nil {
return nil
}
if es, isExpr := s.(*ast.ExprStmt); isExpr {
return es.X
}
found := "simple statement"
if _, isAss := s.(*ast.AssignStmt); isAss {
found = "assignment"
}
p.error(s.Pos(), fmt.Sprintf("expected %s, found %s (missing parentheses around composite literal?)", want, found))
return &ast.BadExpr{From: s.Pos(), To: s.End()}
}
// parseIfHeader is an adjusted version of parser.header
// in cmd/compile/internal/syntax/parser.go, which has
// been tuned for better error handling.
func (p *parser) parseIfHeader() (init ast.Stmt, cond ast.Expr) {
if p.tok == token.LBRACE {
p.error(p.pos, "missing condition in if statement")
cond = &ast.BadExpr{From: p.pos, To: p.pos}
return
}
// p.tok != token.LBRACE
prevLev := p.exprLev
p.exprLev = -1
if p.tok != token.SEMICOLON {
// accept potential variable declaration but complain
if p.tok == token.VAR {
p.next()
p.error(p.pos, "var declaration not allowed in if initializer")
}
init, _ = p.parseSimpleStmt(basic)
}
var condStmt ast.Stmt
var semi struct {
pos token.Pos
lit string // ";" or "\n"; valid if pos.IsValid()
}
if p.tok != token.LBRACE {
if p.tok == token.SEMICOLON {
semi.pos = p.pos
semi.lit = p.lit
p.next()
} else {
p.expect(token.SEMICOLON)
}
if p.tok != token.LBRACE {
condStmt, _ = p.parseSimpleStmt(basic)
}
} else {
condStmt = init
init = nil
}
if condStmt != nil {
cond = p.makeExpr(condStmt, "boolean expression")
} else if semi.pos.IsValid() {
if semi.lit == "\n" {
p.error(semi.pos, "unexpected newline, expecting { after if clause")
} else {
p.error(semi.pos, "missing condition in if statement")
}
}
// make sure we have a valid AST
if cond == nil {
cond = &ast.BadExpr{From: p.pos, To: p.pos}
}
p.exprLev = prevLev
return
}
func (p *parser) parseIfStmt() *ast.IfStmt {
defer decNestLev(incNestLev(p))
if p.trace {
defer un(trace(p, "IfStmt"))
}
pos := p.expect(token.IF)
init, cond := p.parseIfHeader()
body := p.parseBlockStmt()
var else_ ast.Stmt
if p.tok == token.ELSE {
p.next()
switch p.tok {
case token.IF:
else_ = p.parseIfStmt()
case token.LBRACE:
else_ = p.parseBlockStmt()
p.expectSemi()
default:
p.errorExpected(p.pos, "if statement or block")
else_ = &ast.BadStmt{From: p.pos, To: p.pos}
}
} else {
p.expectSemi()
}
return &ast.IfStmt{If: pos, Init: init, Cond: cond, Body: body, Else: else_}
}
func (p *parser) parseCaseClause() *ast.CaseClause {
if p.trace {
defer un(trace(p, "CaseClause"))
}
pos := p.pos
var list []ast.Expr
if p.tok == token.CASE {
p.next()
list = p.parseList(true)
} else {
p.expect(token.DEFAULT)
}
colon := p.expect(token.COLON)
body := p.parseStmtList()
return &ast.CaseClause{Case: pos, List: list, Colon: colon, Body: body}
}
func isTypeSwitchAssert(x ast.Expr) bool {
a, ok := x.(*ast.TypeAssertExpr)
return ok && a.Type == nil
}
func (p *parser) isTypeSwitchGuard(s ast.Stmt) bool {
switch t := s.(type) {
case *ast.ExprStmt:
// x.(type)
return isTypeSwitchAssert(t.X)
case *ast.AssignStmt:
// v := x.(type)
if len(t.Lhs) == 1 && len(t.Rhs) == 1 && isTypeSwitchAssert(t.Rhs[0]) {
switch t.Tok {
case token.ASSIGN:
// permit v = x.(type) but complain
p.error(t.TokPos, "expected ':=', found '='")
fallthrough
case token.DEFINE:
return true
}
}
}
return false
}
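// Illustrative sketch (not part of the original source): the guard of
//
//	switch v := x.(type) { ... }
//
// is an *ast.AssignStmt whose single Rhs entry is an *ast.TypeAssertExpr with
// a nil Type, which is exactly what isTypeSwitchAssert checks for above.
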
func (p *parser) parseSwitchStmt() ast.Stmt {
if p.trace {
defer un(trace(p, "SwitchStmt"))
}
pos := p.expect(token.SWITCH)
var s1, s2 ast.Stmt
if p.tok != token.LBRACE {
prevLev := p.exprLev
p.exprLev = -1
if p.tok != token.SEMICOLON {
s2, _ = p.parseSimpleStmt(basic)
}
if p.tok == token.SEMICOLON {
p.next()
s1 = s2
s2 = nil
if p.tok != token.LBRACE {
// A TypeSwitchGuard may declare a variable in addition
// to the variable declared in the initial SimpleStmt.
// Introduce extra scope to avoid redeclaration errors:
//
// switch t := 0; t := x.(T) { ... }
//
// (this code is not valid Go because the first t
// cannot be accessed and thus is never used; the extra
// scope is needed for the correct error message).
//
// If we don't have a type switch, s2 must be an expression.
// Having the extra nested but empty scope won't affect it.
s2, _ = p.parseSimpleStmt(basic)
}
}
p.exprLev = prevLev
}
typeSwitch := p.isTypeSwitchGuard(s2)
lbrace := p.expect(token.LBRACE)
var list []ast.Stmt
for p.tok == token.CASE || p.tok == token.DEFAULT {
list = append(list, p.parseCaseClause())
}
rbrace := p.expect(token.RBRACE)
p.expectSemi()
body := &ast.BlockStmt{Lbrace: lbrace, List: list, Rbrace: rbrace}
if typeSwitch {
return &ast.TypeSwitchStmt{Switch: pos, Init: s1, Assign: s2, Body: body}
}
return &ast.SwitchStmt{Switch: pos, Init: s1, Tag: p.makeExpr(s2, "switch expression"), Body: body}
}
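// Illustrative sketch (not part of the original source): an expression switch
// and a type switch guard, the two cases that isTypeSwitchGuard separates above.
func switchFormsExample(x any) {
	switch n := len("abc"); n { // expression switch: Init, then Tag expression
	case 3:
	default:
	}
	switch v := x.(type) { // type switch guard: v := x.(type)
	case int:
		_ = v
	}
}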
func (p *parser) parseCommClause() *ast.CommClause {
if p.trace {
defer un(trace(p, "CommClause"))
}
pos := p.pos
var comm ast.Stmt
if p.tok == token.CASE {
p.next()
lhs := p.parseList(false)
if p.tok == token.ARROW {
// SendStmt
if len(lhs) > 1 {
p.errorExpected(lhs[0].Pos(), "1 expression")
// continue with first expression
}
arrow := p.pos
p.next()
rhs := p.parseRhs()
comm = &ast.SendStmt{Chan: lhs[0], Arrow: arrow, Value: rhs}
} else {
// RecvStmt
if tok := p.tok; tok == token.ASSIGN || tok == token.DEFINE {
// RecvStmt with assignment
if len(lhs) > 2 {
p.errorExpected(lhs[0].Pos(), "1 or 2 expressions")
// continue with first two expressions
lhs = lhs[0:2]
}
pos := p.pos
p.next()
rhs := p.parseRhs()
comm = &ast.AssignStmt{Lhs: lhs, TokPos: pos, Tok: tok, Rhs: []ast.Expr{rhs}}
} else {
// lhs must be single receive operation
if len(lhs) > 1 {
p.errorExpected(lhs[0].Pos(), "1 expression")
// continue with first expression
}
comm = &ast.ExprStmt{X: lhs[0]}
}
}
} else {
p.expect(token.DEFAULT)
}
colon := p.expect(token.COLON)
body := p.parseStmtList()
return &ast.CommClause{Case: pos, Comm: comm, Colon: colon, Body: body}
}
func (p *parser) parseSelectStmt() *ast.SelectStmt {
if p.trace {
defer un(trace(p, "SelectStmt"))
}
pos := p.expect(token.SELECT)
lbrace := p.expect(token.LBRACE)
var list []ast.Stmt
for p.tok == token.CASE || p.tok == token.DEFAULT {
list = append(list, p.parseCommClause())
}
rbrace := p.expect(token.RBRACE)
p.expectSemi()
body := &ast.BlockStmt{Lbrace: lbrace, List: list, Rbrace: rbrace}
return &ast.SelectStmt{Select: pos, Body: body}
}
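// Illustrative sketch (not part of the original source): the comm-clause
// shapes handled by parseCommClause above.
func selectFormsExample(ch chan int) {
	select {
	case ch <- 1: // SendStmt
	case v, ok := <-ch: // RecvStmt with ":=" assignment
		_, _ = v, ok
	case <-ch: // bare receive expression
	default:
	}
}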
func (p *parser) parseForStmt() ast.Stmt {
if p.trace {
defer un(trace(p, "ForStmt"))
}
pos := p.expect(token.FOR)
var s1, s2, s3 ast.Stmt
var isRange bool
if p.tok != token.LBRACE {
prevLev := p.exprLev
p.exprLev = -1
if p.tok != token.SEMICOLON {
if p.tok == token.RANGE {
// "for range x" (nil lhs in assignment)
pos := p.pos
p.next()
y := []ast.Expr{&ast.UnaryExpr{OpPos: pos, Op: token.RANGE, X: p.parseRhs()}}
s2 = &ast.AssignStmt{Rhs: y}
isRange = true
} else {
s2, isRange = p.parseSimpleStmt(rangeOk)
}
}
if !isRange && p.tok == token.SEMICOLON {
p.next()
s1 = s2
s2 = nil
if p.tok != token.SEMICOLON {
s2, _ = p.parseSimpleStmt(basic)
}
p.expectSemi()
if p.tok != token.LBRACE {
s3, _ = p.parseSimpleStmt(basic)
}
}
p.exprLev = prevLev
}
body := p.parseBlockStmt()
p.expectSemi()
if isRange {
as := s2.(*ast.AssignStmt)
// check lhs
var key, value ast.Expr
switch len(as.Lhs) {
case 0:
// nothing to do
case 1:
key = as.Lhs[0]
case 2:
key, value = as.Lhs[0], as.Lhs[1]
default:
p.errorExpected(as.Lhs[len(as.Lhs)-1].Pos(), "at most 2 expressions")
return &ast.BadStmt{From: pos, To: body.End()}
}
// parseSimpleStmt returned a right-hand side that
// is a single unary expression of the form "range x"
x := as.Rhs[0].(*ast.UnaryExpr).X
return &ast.RangeStmt{
For: pos,
Key: key,
Value: value,
TokPos: as.TokPos,
Tok: as.Tok,
Range: as.Rhs[0].Pos(),
X: x,
Body: body,
}
}
// regular for statement
return &ast.ForStmt{
For: pos,
Init: s1,
Cond: p.makeExpr(s2, "boolean or range expression"),
Post: s3,
Body: body,
}
}
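// Illustrative sketch (not part of the original source): the for-statement
// forms recognized by parseForStmt above.
func forFormsExample(xs []int) int {
	sum := 0
	for i := 0; i < len(xs); i++ { // ForStmt: init; cond; post
		sum += xs[i]
	}
	for _, v := range xs { // RangeStmt with key and value
		sum += v
	}
	for range xs { // RangeStmt with no left-hand side ("for range x")
		sum++
	}
	return sum
}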
func (p *parser) parseStmt() (s ast.Stmt) {
defer decNestLev(incNestLev(p))
if p.trace {
defer un(trace(p, "Statement"))
}
switch p.tok {
case token.CONST, token.TYPE, token.VAR:
s = &ast.DeclStmt{Decl: p.parseDecl(stmtStart)}
case
// tokens that may start an expression
token.IDENT, token.INT, token.FLOAT, token.IMAG, token.CHAR, token.STRING, token.FUNC, token.LPAREN, // operands
token.LBRACK, token.STRUCT, token.MAP, token.CHAN, token.INTERFACE, // composite types
token.ADD, token.SUB, token.MUL, token.AND, token.XOR, token.ARROW, token.NOT: // unary operators
s, _ = p.parseSimpleStmt(labelOk)
// because of the required look-ahead, labeled statements are
// parsed by parseSimpleStmt - don't expect a semicolon after
// them
if _, isLabeledStmt := s.(*ast.LabeledStmt); !isLabeledStmt {
p.expectSemi()
}
case token.GO:
s = p.parseGoStmt()
case token.DEFER:
s = p.parseDeferStmt()
case token.RETURN:
s = p.parseReturnStmt()
case token.BREAK, token.CONTINUE, token.GOTO, token.FALLTHROUGH:
s = p.parseBranchStmt(p.tok)
case token.LBRACE:
s = p.parseBlockStmt()
p.expectSemi()
case token.IF:
s = p.parseIfStmt()
case token.SWITCH:
s = p.parseSwitchStmt()
case token.SELECT:
s = p.parseSelectStmt()
case token.FOR:
s = p.parseForStmt()
case token.SEMICOLON:
// Is it ever possible to have an implicit semicolon
// producing an empty statement in a valid program?
// (handle correctly anyway)
s = &ast.EmptyStmt{Semicolon: p.pos, Implicit: p.lit == "\n"}
p.next()
case token.RBRACE:
// a semicolon may be omitted before a closing "}"
s = &ast.EmptyStmt{Semicolon: p.pos, Implicit: true}
default:
// no statement found
pos := p.pos
p.errorExpected(pos, "statement")
p.advance(stmtStart)
s = &ast.BadStmt{From: pos, To: p.pos}
}
return
}
// ----------------------------------------------------------------------------
// Declarations
type parseSpecFunction func(doc *ast.CommentGroup, keyword token.Token, iota int) ast.Spec
func (p *parser) parseImportSpec(doc *ast.CommentGroup, _ token.Token, _ int) ast.Spec {
if p.trace {
defer un(trace(p, "ImportSpec"))
}
var ident *ast.Ident
switch p.tok {
case token.IDENT:
ident = p.parseIdent()
case token.PERIOD:
ident = &ast.Ident{NamePos: p.pos, Name: "."}
p.next()
}
pos := p.pos
var path string
if p.tok == token.STRING {
path = p.lit
p.next()
} else if p.tok.IsLiteral() {
p.error(pos, "import path must be a string")
p.next()
} else {
p.error(pos, "missing import path")
p.advance(exprEnd)
}
comment := p.expectSemi()
// collect imports
spec := &ast.ImportSpec{
Doc: doc,
Name: ident,
Path: &ast.BasicLit{ValuePos: pos, Kind: token.STRING, Value: path},
Comment: comment,
}
p.imports = append(p.imports, spec)
return spec
}
func (p *parser) parseValueSpec(doc *ast.CommentGroup, keyword token.Token, iota int) ast.Spec {
if p.trace {
defer un(trace(p, keyword.String()+"Spec"))
}
idents := p.parseIdentList()
var typ ast.Expr
var values []ast.Expr
switch keyword {
case token.CONST:
// always permit optional type and initialization for more tolerant parsing
if p.tok != token.EOF && p.tok != token.SEMICOLON && p.tok != token.RPAREN {
typ = p.tryIdentOrType()
if p.tok == token.ASSIGN {
p.next()
values = p.parseList(true)
}
}
case token.VAR:
if p.tok != token.ASSIGN {
typ = p.parseType()
}
if p.tok == token.ASSIGN {
p.next()
values = p.parseList(true)
}
default:
panic("unreachable")
}
comment := p.expectSemi()
spec := &ast.ValueSpec{
Doc: doc,
Names: idents,
Type: typ,
Values: values,
Comment: comment,
}
return spec
}
func (p *parser) parseGenericType(spec *ast.TypeSpec, openPos token.Pos, name0 *ast.Ident, typ0 ast.Expr) {
if p.trace {
defer un(trace(p, "parseGenericType"))
}
list := p.parseParameterList(name0, typ0, token.RBRACK, false)
closePos := p.expect(token.RBRACK)
spec.TypeParams = &ast.FieldList{Opening: openPos, List: list, Closing: closePos}
if p.tok == token.ASSIGN {
// type alias
spec.Assign = p.pos
p.next()
}
spec.Type = p.parseType()
}
func (p *parser) parseTypeSpec(doc *ast.CommentGroup, _ token.Token, _ int) ast.Spec {
if p.trace {
defer un(trace(p, "TypeSpec"))
}
name := p.parseIdent()
spec := &ast.TypeSpec{Doc: doc, Name: name}
if p.tok == token.LBRACK {
// spec.Name "[" ...
// array/slice type or type parameter list
lbrack := p.pos
p.next()
if p.tok == token.IDENT {
// We may have an array type or a type parameter list.
// In either case we expect an expression x (which may
// just be a name, or a more complex expression) which
// we can analyze further.
//
// A type parameter list may have a type bound starting
// with a "[" as in: P []E. In that case, simply parsing
// an expression would lead to an error: P[] is invalid.
// But since index or slice expressions are never constant
// and thus invalid array length expressions, if the name
// is followed by "[" it must be the start of an array or
// slice constraint. Only if we don't see a "[" do we
// need to parse a full expression. Notably, name <- x
// is not a concern because name <- x is a statement and
// not an expression.
var x ast.Expr = p.parseIdent()
if p.tok != token.LBRACK {
// To parse the expression starting with name, expand
// the call sequence we would get by passing in name
// to parser.expr, and pass in name to parsePrimaryExpr.
p.exprLev++
lhs := p.parsePrimaryExpr(x)
x = p.parseBinaryExpr(lhs, token.LowestPrec+1)
p.exprLev--
}
// Analyze expression x. If we can split x into a type parameter
// name, possibly followed by a type parameter type, we consider
// this the start of a type parameter list, with some caveats:
// a single name followed by "]" tilts the decision towards an
// array declaration; a type parameter type that could also be
// an ordinary expression but which is followed by a comma tilts
// the decision towards a type parameter list.
if pname, ptype := extractName(x, p.tok == token.COMMA); pname != nil && (ptype != nil || p.tok != token.RBRACK) {
// spec.Name "[" pname ...
// spec.Name "[" pname ptype ...
// spec.Name "[" pname ptype "," ...
p.parseGenericType(spec, lbrack, pname, ptype) // ptype may be nil
} else {
// spec.Name "[" pname "]" ...
// spec.Name "[" x ...
spec.Type = p.parseArrayType(lbrack, x)
}
} else {
// array type
spec.Type = p.parseArrayType(lbrack, nil)
}
} else {
// no type parameters
if p.tok == token.ASSIGN {
// type alias
spec.Assign = p.pos
p.next()
}
spec.Type = p.parseType()
}
spec.Comment = p.expectSemi()
return spec
}
// extractName splits the expression x into (name, expr) if syntactically
// x can be written as name expr. The split only happens if expr is a type
// element (per the isTypeElem predicate) or if force is set.
// If x is just a name, the result is (name, nil). If the split succeeds,
// the result is (name, expr). Otherwise the result is (nil, x).
// Examples:
//
// x force name expr
// ------------------------------------
// P*[]int T/F P *[]int
// P*E T P *E
// P*E F nil P*E
// P([]int) T/F P ([]int)
// P(E) T P (E)
// P(E) F nil P(E)
// P*E|F|~G T/F P *E|F|~G
// P*E|F|G T P *E|F|G
// P*E|F|G F nil P*E|F|G
func extractName(x ast.Expr, force bool) (*ast.Ident, ast.Expr) {
switch x := x.(type) {
case *ast.Ident:
return x, nil
case *ast.BinaryExpr:
switch x.Op {
case token.MUL:
if name, _ := x.X.(*ast.Ident); name != nil && (force || isTypeElem(x.Y)) {
// x = name *x.Y
return name, &ast.StarExpr{Star: x.OpPos, X: x.Y}
}
case token.OR:
if name, lhs := extractName(x.X, force || isTypeElem(x.Y)); name != nil && lhs != nil {
// x = name lhs|x.Y
op := *x
op.X = lhs
return name, &op
}
}
case *ast.CallExpr:
if name, _ := x.Fun.(*ast.Ident); name != nil {
if len(x.Args) == 1 && x.Ellipsis == token.NoPos && (force || isTypeElem(x.Args[0])) {
// x = name (x.Args[0])
// (Note that the cmd/compile/internal/syntax parser does not care
// about syntax tree fidelity and does not preserve parentheses here.)
return name, &ast.ParenExpr{
Lparen: x.Lparen,
X: x.Args[0],
Rparen: x.Rparen,
}
}
}
}
return nil, x
}
// isTypeElem reports whether x is a (possibly parenthesized) type element expression.
// The result is false if x could be a type element OR an ordinary (value) expression.
func isTypeElem(x ast.Expr) bool {
switch x := x.(type) {
case *ast.ArrayType, *ast.StructType, *ast.FuncType, *ast.InterfaceType, *ast.MapType, *ast.ChanType:
return true
case *ast.BinaryExpr:
return isTypeElem(x.X) || isTypeElem(x.Y)
case *ast.UnaryExpr:
return x.Op == token.TILDE
case *ast.ParenExpr:
return isTypeElem(x.X)
}
return false
}
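// Illustrative sketch (not part of the original source): type declarations
// whose headers exercise the array-type vs. type-parameter-list split made by
// parseTypeSpec with extractName. Names are hypothetical.
const _exampleN = 3
type exampleArray [_exampleN]int             // Name "[" x "]" ...        -> array type
type exampleList[P any] []P                  // Name "[" pname ptype "]"  -> type parameter list
type exampleBound[P *int | ~string] struct{} // union bound: a type element per isTypeElem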
func (p *parser) parseGenDecl(keyword token.Token, f parseSpecFunction) *ast.GenDecl {
if p.trace {
defer un(trace(p, "GenDecl("+keyword.String()+")"))
}
doc := p.leadComment
pos := p.expect(keyword)
var lparen, rparen token.Pos
var list []ast.Spec
if p.tok == token.LPAREN {
lparen = p.pos
p.next()
for iota := 0; p.tok != token.RPAREN && p.tok != token.EOF; iota++ {
list = append(list, f(p.leadComment, keyword, iota))
}
rparen = p.expect(token.RPAREN)
p.expectSemi()
} else {
list = append(list, f(nil, keyword, 0))
}
return &ast.GenDecl{
Doc: doc,
TokPos: pos,
Tok: keyword,
Lparen: lparen,
Specs: list,
Rparen: rparen,
}
}
func (p *parser) parseFuncDecl() *ast.FuncDecl {
if p.trace {
defer un(trace(p, "FunctionDecl"))
}
doc := p.leadComment
pos := p.expect(token.FUNC)
var recv *ast.FieldList
if p.tok == token.LPAREN {
recv = p.parseParameters(false)
}
ident := p.parseIdent()
var tparams *ast.FieldList
if p.tok == token.LBRACK {
tparams = p.parseTypeParameters()
if recv != nil && tparams != nil {
// Method declarations do not have type parameters. We parse them for a
// better error message and improved error recovery.
p.error(tparams.Opening, "method must have no type parameters")
tparams = nil
}
}
params := p.parseParameters(false)
results := p.parseParameters(true)
var body *ast.BlockStmt
switch p.tok {
case token.LBRACE:
body = p.parseBody()
p.expectSemi()
case token.SEMICOLON:
p.next()
if p.tok == token.LBRACE {
// opening { of function declaration on next line
p.error(p.pos, "unexpected semicolon or newline before {")
body = p.parseBody()
p.expectSemi()
}
default:
p.expectSemi()
}
decl := &ast.FuncDecl{
Doc: doc,
Recv: recv,
Name: ident,
Type: &ast.FuncType{
Func: pos,
TypeParams: tparams,
Params: params,
Results: results,
},
Body: body,
}
return decl
}
func (p *parser) parseDecl(sync map[token.Token]bool) ast.Decl {
if p.trace {
defer un(trace(p, "Declaration"))
}
var f parseSpecFunction
switch p.tok {
case token.IMPORT:
f = p.parseImportSpec
case token.CONST, token.VAR:
f = p.parseValueSpec
case token.TYPE:
f = p.parseTypeSpec
case token.FUNC:
return p.parseFuncDecl()
default:
pos := p.pos
p.errorExpected(pos, "declaration")
p.advance(sync)
return &ast.BadDecl{From: pos, To: p.pos}
}
return p.parseGenDecl(p.tok, f)
}
// ----------------------------------------------------------------------------
// Source files
func (p *parser) parseFile() *ast.File {
if p.trace {
defer un(trace(p, "File"))
}
// Don't bother parsing the rest if we had errors scanning the first token.
// Likely not a Go source file at all.
if p.errors.Len() != 0 {
return nil
}
// package clause
doc := p.leadComment
pos := p.expect(token.PACKAGE)
// Go spec: The package clause is not a declaration;
// the package name does not appear in any scope.
ident := p.parseIdent()
if ident.Name == "_" && p.mode&DeclarationErrors != 0 {
p.error(p.pos, "invalid package name _")
}
p.expectSemi()
// Don't bother parsing the rest if we had errors parsing the package clause.
// Likely not a Go source file at all.
if p.errors.Len() != 0 {
return nil
}
var decls []ast.Decl
if p.mode&PackageClauseOnly == 0 {
// import decls
for p.tok == token.IMPORT {
decls = append(decls, p.parseGenDecl(token.IMPORT, p.parseImportSpec))
}
if p.mode&ImportsOnly == 0 {
// rest of package body
prev := token.IMPORT
for p.tok != token.EOF {
// Continue to accept import declarations for error tolerance, but complain.
if p.tok == token.IMPORT && prev != token.IMPORT {
p.error(p.pos, "imports must appear before other declarations")
}
prev = p.tok
decls = append(decls, p.parseDecl(declStart))
}
}
}
f := &ast.File{
Doc: doc,
Package: pos,
Name: ident,
Decls: decls,
// File{Start,End} are set by the defer in the caller.
Imports: p.imports,
Comments: p.comments,
GoVersion: p.goVersion,
}
var declErr func(token.Pos, string)
if p.mode&DeclarationErrors != 0 {
declErr = p.error
}
if p.mode&SkipObjectResolution == 0 {
resolveFile(f, p.file, declErr)
}
return f
}
// packIndexExpr returns an IndexExpr x[expr0] or IndexListExpr x[expr0, ...].
func packIndexExpr(x ast.Expr, lbrack token.Pos, exprs []ast.Expr, rbrack token.Pos) ast.Expr {
switch len(exprs) {
case 0:
panic("internal error: packIndexExpr with empty expr slice")
case 1:
return &ast.IndexExpr{
X: x,
Lbrack: lbrack,
Index: exprs[0],
Rbrack: rbrack,
}
default:
return &ast.IndexListExpr{
X: x,
Lbrack: lbrack,
Indices: exprs,
Rbrack: rbrack,
}
}
}
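// Illustrative sketch (not part of the original source): exercising this
// package's public ParseFile entry point, which drives parseFile above.
// Assumes the ast and token imports used elsewhere in this package.
func parseFileExample() (*ast.File, error) {
	fset := token.NewFileSet()
	src := "package p\n\nfunc add(a, b int) int { return a + b }\n"
	// SkipObjectResolution skips the resolveFile pass defined below.
	return ParseFile(fset, "p.go", src, SkipObjectResolution)
}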
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package parser
import (
"fmt"
"go/ast"
"go/token"
"strings"
)
const debugResolve = false
// resolveFile walks the given file to resolve identifiers within the file
// scope, updating ast.Ident.Obj fields with declaration information.
//
// If declErr is non-nil, it is used to report declaration errors during
// resolution. tok is used to format position in error messages.
func resolveFile(file *ast.File, handle *token.File, declErr func(token.Pos, string)) {
pkgScope := ast.NewScope(nil)
r := &resolver{
handle: handle,
declErr: declErr,
topScope: pkgScope,
pkgScope: pkgScope,
depth: 1,
}
for _, decl := range file.Decls {
ast.Walk(r, decl)
}
r.closeScope()
assert(r.topScope == nil, "unbalanced scopes")
assert(r.labelScope == nil, "unbalanced label scopes")
// resolve global identifiers within the same file
i := 0
for _, ident := range r.unresolved {
// i <= index for current ident
assert(ident.Obj == unresolved, "object already resolved")
ident.Obj = r.pkgScope.Lookup(ident.Name) // also removes unresolved sentinel
if ident.Obj == nil {
r.unresolved[i] = ident
i++
} else if debugResolve {
pos := ident.Obj.Decl.(interface{ Pos() token.Pos }).Pos()
r.trace("resolved %s@%v to package object %v", ident.Name, ident.Pos(), pos)
}
}
file.Scope = r.pkgScope
file.Unresolved = r.unresolved[0:i]
}
const maxScopeDepth int = 1e3
type resolver struct {
handle *token.File
declErr func(token.Pos, string)
// Ordinary identifier scopes
pkgScope *ast.Scope // pkgScope.Outer == nil
topScope *ast.Scope // top-most scope; may be pkgScope
unresolved []*ast.Ident // unresolved identifiers
depth int // scope depth
// Label scopes
// (maintained by open/close LabelScope)
labelScope *ast.Scope // label scope for current function
targetStack [][]*ast.Ident // stack of unresolved labels
}
func (r *resolver) trace(format string, args ...any) {
fmt.Println(strings.Repeat(". ", r.depth) + r.sprintf(format, args...))
}
func (r *resolver) sprintf(format string, args ...any) string {
for i, arg := range args {
switch arg := arg.(type) {
case token.Pos:
args[i] = r.handle.Position(arg)
}
}
return fmt.Sprintf(format, args...)
}
func (r *resolver) openScope(pos token.Pos) {
r.depth++
if r.depth > maxScopeDepth {
panic(bailout{pos: pos, msg: "exceeded max scope depth during object resolution"})
}
if debugResolve {
r.trace("opening scope @%v", pos)
}
r.topScope = ast.NewScope(r.topScope)
}
func (r *resolver) closeScope() {
r.depth--
if debugResolve {
r.trace("closing scope")
}
r.topScope = r.topScope.Outer
}
func (r *resolver) openLabelScope() {
r.labelScope = ast.NewScope(r.labelScope)
r.targetStack = append(r.targetStack, nil)
}
func (r *resolver) closeLabelScope() {
// resolve labels
n := len(r.targetStack) - 1
scope := r.labelScope
for _, ident := range r.targetStack[n] {
ident.Obj = scope.Lookup(ident.Name)
if ident.Obj == nil && r.declErr != nil {
r.declErr(ident.Pos(), fmt.Sprintf("label %s undefined", ident.Name))
}
}
// pop label scope
r.targetStack = r.targetStack[0:n]
r.labelScope = r.labelScope.Outer
}
func (r *resolver) declare(decl, data any, scope *ast.Scope, kind ast.ObjKind, idents ...*ast.Ident) {
for _, ident := range idents {
if ident.Obj != nil {
panic(fmt.Sprintf("%v: identifier %s already declared or resolved", ident.Pos(), ident.Name))
}
obj := ast.NewObj(kind, ident.Name)
// remember the corresponding declaration for redeclaration
// errors and global variable resolution/typechecking phase
obj.Decl = decl
obj.Data = data
// Identifiers (for receiver type parameters) are written to the scope, but
// never set as the resolved object. See go.dev/issue/50956.
if _, ok := decl.(*ast.Ident); !ok {
ident.Obj = obj
}
if ident.Name != "_" {
if debugResolve {
r.trace("declaring %s@%v", ident.Name, ident.Pos())
}
if alt := scope.Insert(obj); alt != nil && r.declErr != nil {
prevDecl := ""
if pos := alt.Pos(); pos.IsValid() {
prevDecl = r.sprintf("\n\tprevious declaration at %v", pos)
}
r.declErr(ident.Pos(), fmt.Sprintf("%s redeclared in this block%s", ident.Name, prevDecl))
}
}
}
}
func (r *resolver) shortVarDecl(decl *ast.AssignStmt) {
// Go spec: A short variable declaration may redeclare variables
// provided they were originally declared in the same block with
// the same type, and at least one of the non-blank variables is new.
n := 0 // number of new variables
for _, x := range decl.Lhs {
if ident, isIdent := x.(*ast.Ident); isIdent {
assert(ident.Obj == nil, "identifier already declared or resolved")
obj := ast.NewObj(ast.Var, ident.Name)
// remember corresponding assignment for other tools
obj.Decl = decl
ident.Obj = obj
if ident.Name != "_" {
if debugResolve {
r.trace("declaring %s@%v", ident.Name, ident.Pos())
}
if alt := r.topScope.Insert(obj); alt != nil {
ident.Obj = alt // redeclaration
} else {
n++ // new declaration
}
}
}
}
if n == 0 && r.declErr != nil {
r.declErr(decl.Lhs[0].Pos(), "no new variables on left side of :=")
}
}
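// Illustrative sketch (not part of the original source): the spec rule that
// shortVarDecl enforces. At least one non-blank variable on the left of ":="
// must be new in the current block.
func shortVarDeclExample() {
	a, b := 1, 2
	a, c := 3, 4 // ok: c is new; a is merely reassigned
	_, _, _ = a, b, c
	// a, b := 5, 6 // would be rejected: "no new variables on left side of :="
}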
// The unresolved object is a sentinel to mark identifiers that have been added
// to the list of unresolved identifiers. The sentinel is only used for verifying
// internal consistency.
var unresolved = new(ast.Object)
// If x is an identifier, resolve attempts to resolve x by looking up
// the object it denotes. If no object is found and collectUnresolved is
// set, x is marked as unresolved and collected in the list of unresolved
// identifiers.
func (r *resolver) resolve(ident *ast.Ident, collectUnresolved bool) {
if ident.Obj != nil {
panic(r.sprintf("%v: identifier %s already declared or resolved", ident.Pos(), ident.Name))
}
// '_' should never refer to existing declarations, because it has special
// handling in the spec.
if ident.Name == "_" {
return
}
for s := r.topScope; s != nil; s = s.Outer {
if obj := s.Lookup(ident.Name); obj != nil {
if debugResolve {
r.trace("resolved %v:%s to %v", ident.Pos(), ident.Name, obj)
}
assert(obj.Name != "", "obj with no name")
// Identifiers (for receiver type parameters) are written to the scope,
// but never set as the resolved object. See go.dev/issue/50956.
if _, ok := obj.Decl.(*ast.Ident); !ok {
ident.Obj = obj
}
return
}
}
// all local scopes are known, so any unresolved identifier
// must be found either in the file scope, package scope
// (perhaps in another file), or universe scope --- collect
// them so that they can be resolved later
if collectUnresolved {
ident.Obj = unresolved
r.unresolved = append(r.unresolved, ident)
}
}
func (r *resolver) walkExprs(list []ast.Expr) {
for _, node := range list {
ast.Walk(r, node)
}
}
func (r *resolver) walkLHS(list []ast.Expr) {
for _, expr := range list {
expr := ast.Unparen(expr)
if _, ok := expr.(*ast.Ident); !ok && expr != nil {
ast.Walk(r, expr)
}
}
}
func (r *resolver) walkStmts(list []ast.Stmt) {
for _, stmt := range list {
ast.Walk(r, stmt)
}
}
func (r *resolver) Visit(node ast.Node) ast.Visitor {
if debugResolve && node != nil {
r.trace("node %T@%v", node, node.Pos())
}
switch n := node.(type) {
// Expressions.
case *ast.Ident:
r.resolve(n, true)
case *ast.FuncLit:
r.openScope(n.Pos())
defer r.closeScope()
r.walkFuncType(n.Type)
r.walkBody(n.Body)
case *ast.SelectorExpr:
ast.Walk(r, n.X)
// Note: don't try to resolve n.Sel, as we don't support qualified
// resolution.
case *ast.StructType:
r.openScope(n.Pos())
defer r.closeScope()
r.walkFieldList(n.Fields, ast.Var)
case *ast.FuncType:
r.openScope(n.Pos())
defer r.closeScope()
r.walkFuncType(n)
case *ast.CompositeLit:
if n.Type != nil {
ast.Walk(r, n.Type)
}
for _, e := range n.Elts {
if kv, _ := e.(*ast.KeyValueExpr); kv != nil {
// See go.dev/issue/45160: try to resolve composite lit keys, but don't
// collect them as unresolved if resolution failed. This replicates
// existing behavior when resolving during parsing.
if ident, _ := kv.Key.(*ast.Ident); ident != nil {
r.resolve(ident, false)
} else {
ast.Walk(r, kv.Key)
}
ast.Walk(r, kv.Value)
} else {
ast.Walk(r, e)
}
}
case *ast.InterfaceType:
r.openScope(n.Pos())
defer r.closeScope()
r.walkFieldList(n.Methods, ast.Fun)
// Statements
case *ast.LabeledStmt:
r.declare(n, nil, r.labelScope, ast.Lbl, n.Label)
ast.Walk(r, n.Stmt)
case *ast.AssignStmt:
r.walkExprs(n.Rhs)
if n.Tok == token.DEFINE {
r.shortVarDecl(n)
} else {
r.walkExprs(n.Lhs)
}
case *ast.BranchStmt:
// add to list of unresolved targets
if n.Tok != token.FALLTHROUGH && n.Label != nil {
depth := len(r.targetStack) - 1
r.targetStack[depth] = append(r.targetStack[depth], n.Label)
}
case *ast.BlockStmt:
r.openScope(n.Pos())
defer r.closeScope()
r.walkStmts(n.List)
case *ast.IfStmt:
r.openScope(n.Pos())
defer r.closeScope()
if n.Init != nil {
ast.Walk(r, n.Init)
}
ast.Walk(r, n.Cond)
ast.Walk(r, n.Body)
if n.Else != nil {
ast.Walk(r, n.Else)
}
case *ast.CaseClause:
r.walkExprs(n.List)
r.openScope(n.Pos())
defer r.closeScope()
r.walkStmts(n.Body)
case *ast.SwitchStmt:
r.openScope(n.Pos())
defer r.closeScope()
if n.Init != nil {
ast.Walk(r, n.Init)
}
if n.Tag != nil {
// The scope below reproduces some unnecessary behavior of the parser,
// opening an extra scope in case this is a type switch. It's not needed
// for expression switches.
// TODO: remove this once we've matched the parser resolution exactly.
if n.Init != nil {
r.openScope(n.Tag.Pos())
defer r.closeScope()
}
ast.Walk(r, n.Tag)
}
if n.Body != nil {
r.walkStmts(n.Body.List)
}
case *ast.TypeSwitchStmt:
if n.Init != nil {
r.openScope(n.Pos())
defer r.closeScope()
ast.Walk(r, n.Init)
}
r.openScope(n.Assign.Pos())
defer r.closeScope()
ast.Walk(r, n.Assign)
// s.Body consists only of case clauses, so does not get its own
// scope.
if n.Body != nil {
r.walkStmts(n.Body.List)
}
case *ast.CommClause:
r.openScope(n.Pos())
defer r.closeScope()
if n.Comm != nil {
ast.Walk(r, n.Comm)
}
r.walkStmts(n.Body)
case *ast.SelectStmt:
// as for switch statements, select statement bodies don't get their own
// scope.
if n.Body != nil {
r.walkStmts(n.Body.List)
}
case *ast.ForStmt:
r.openScope(n.Pos())
defer r.closeScope()
if n.Init != nil {
ast.Walk(r, n.Init)
}
if n.Cond != nil {
ast.Walk(r, n.Cond)
}
if n.Post != nil {
ast.Walk(r, n.Post)
}
ast.Walk(r, n.Body)
case *ast.RangeStmt:
r.openScope(n.Pos())
defer r.closeScope()
ast.Walk(r, n.X)
var lhs []ast.Expr
if n.Key != nil {
lhs = append(lhs, n.Key)
}
if n.Value != nil {
lhs = append(lhs, n.Value)
}
if len(lhs) > 0 {
if n.Tok == token.DEFINE {
// Note: we can't exactly match the behavior of object resolution
// during the parsing pass here, as it uses the position of the RANGE
// token for the RHS OpPos. That information is not contained within
// the AST.
as := &ast.AssignStmt{
Lhs: lhs,
Tok: token.DEFINE,
TokPos: n.TokPos,
Rhs: []ast.Expr{&ast.UnaryExpr{Op: token.RANGE, X: n.X}},
}
// TODO(rFindley): this walkLHS reproduced the parser resolution, but
// is it necessary? By comparison, for a normal AssignStmt we don't
// walk the LHS in case there is an invalid identifier list.
r.walkLHS(lhs)
r.shortVarDecl(as)
} else {
r.walkExprs(lhs)
}
}
ast.Walk(r, n.Body)
// Declarations
case *ast.GenDecl:
switch n.Tok {
case token.CONST, token.VAR:
for i, spec := range n.Specs {
spec := spec.(*ast.ValueSpec)
kind := ast.Con
if n.Tok == token.VAR {
kind = ast.Var
}
r.walkExprs(spec.Values)
if spec.Type != nil {
ast.Walk(r, spec.Type)
}
r.declare(spec, i, r.topScope, kind, spec.Names...)
}
case token.TYPE:
for _, spec := range n.Specs {
spec := spec.(*ast.TypeSpec)
// Go spec: The scope of a type identifier declared inside a function begins
// at the identifier in the TypeSpec and ends at the end of the innermost
// containing block.
r.declare(spec, nil, r.topScope, ast.Typ, spec.Name)
if spec.TypeParams != nil {
r.openScope(spec.Pos())
defer r.closeScope()
r.walkTParams(spec.TypeParams)
}
ast.Walk(r, spec.Type)
}
}
case *ast.FuncDecl:
// Open the function scope.
r.openScope(n.Pos())
defer r.closeScope()
r.walkRecv(n.Recv)
// Type parameters are walked normally: they can reference each other, and
// can be referenced by normal parameters.
if n.Type.TypeParams != nil {
r.walkTParams(n.Type.TypeParams)
// TODO(rFindley): need to address receiver type parameters.
}
// Resolve and declare parameters in a specific order to get duplicate
// declaration errors in the correct location.
r.resolveList(n.Type.Params)
r.resolveList(n.Type.Results)
r.declareList(n.Recv, ast.Var)
r.declareList(n.Type.Params, ast.Var)
r.declareList(n.Type.Results, ast.Var)
r.walkBody(n.Body)
if n.Recv == nil && n.Name.Name != "init" {
r.declare(n, nil, r.pkgScope, ast.Fun, n.Name)
}
default:
return r
}
return nil
}
func (r *resolver) walkFuncType(typ *ast.FuncType) {
// typ.TypeParams must be walked separately for FuncDecls.
r.resolveList(typ.Params)
r.resolveList(typ.Results)
r.declareList(typ.Params, ast.Var)
r.declareList(typ.Results, ast.Var)
}
func (r *resolver) resolveList(list *ast.FieldList) {
if list == nil {
return
}
for _, f := range list.List {
if f.Type != nil {
ast.Walk(r, f.Type)
}
}
}
func (r *resolver) declareList(list *ast.FieldList, kind ast.ObjKind) {
if list == nil {
return
}
for _, f := range list.List {
r.declare(f, nil, r.topScope, kind, f.Names...)
}
}
func (r *resolver) walkRecv(recv *ast.FieldList) {
// If our receiver has receiver type parameters, we must declare them before
// trying to resolve the rest of the receiver, and avoid re-resolving the
// type parameter identifiers.
if recv == nil || len(recv.List) == 0 {
return // nothing to do
}
typ := recv.List[0].Type
if ptr, ok := typ.(*ast.StarExpr); ok {
typ = ptr.X
}
var declareExprs []ast.Expr // exprs to declare
var resolveExprs []ast.Expr // exprs to resolve
switch typ := typ.(type) {
case *ast.IndexExpr:
declareExprs = []ast.Expr{typ.Index}
resolveExprs = append(resolveExprs, typ.X)
case *ast.IndexListExpr:
declareExprs = typ.Indices
resolveExprs = append(resolveExprs, typ.X)
default:
resolveExprs = append(resolveExprs, typ)
}
for _, expr := range declareExprs {
if id, _ := expr.(*ast.Ident); id != nil {
r.declare(expr, nil, r.topScope, ast.Typ, id)
} else {
// The receiver type parameter expression is invalid, but try to resolve
// it anyway for consistency.
resolveExprs = append(resolveExprs, expr)
}
}
for _, expr := range resolveExprs {
if expr != nil {
ast.Walk(r, expr)
}
}
// The receiver is invalid, but try to resolve it anyway for consistency.
for _, f := range recv.List[1:] {
if f.Type != nil {
ast.Walk(r, f.Type)
}
}
}
func (r *resolver) walkFieldList(list *ast.FieldList, kind ast.ObjKind) {
if list == nil {
return
}
r.resolveList(list)
r.declareList(list, kind)
}
// walkTParams is like walkFieldList, but declares type parameters eagerly so
// that they may be resolved in the constraint expressions held in the field
// Type.
func (r *resolver) walkTParams(list *ast.FieldList) {
r.declareList(list, ast.Typ)
r.resolveList(list)
}
func (r *resolver) walkBody(body *ast.BlockStmt) {
if body == nil {
return
}
r.openLabelScope()
defer r.closeLabelScope()
r.walkStmts(body.List)
}
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package token
import (
"cmp"
"fmt"
"slices"
"strconv"
"sync"
"sync/atomic"
)
// If debug is set, invalid offset and position values cause a panic
// (go.dev/issue/57490).
const debug = false
// -----------------------------------------------------------------------------
// Positions
// Position describes an arbitrary source position
// including the file, line, and column location.
// A Position is valid if the line number is > 0.
type Position struct {
Filename string // filename, if any
Offset int // offset, starting at 0
Line int // line number, starting at 1
Column int // column number, starting at 1 (byte count)
}
// IsValid reports whether the position is valid.
func (pos *Position) IsValid() bool { return pos.Line > 0 }
// String returns a string in one of several forms:
//
// file:line:column valid position with file name
// file:line valid position with file name but no column (column == 0)
// line:column valid position without file name
// line valid position without file name and no column (column == 0)
// file invalid position with file name
// - invalid position without file name
func (pos Position) String() string {
s := pos.Filename
if pos.IsValid() {
if s != "" {
s += ":"
}
s += strconv.Itoa(pos.Line)
if pos.Column != 0 {
s += fmt.Sprintf(":%d", pos.Column)
}
}
if s == "" {
s = "-"
}
return s
}
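// Illustrative sketch (not part of the original source): the output forms
// listed above.
func positionStringExample() {
	p := Position{Filename: "main.go", Line: 10, Column: 4}
	fmt.Println(p) // "main.go:10:4"
	p.Column = 0
	fmt.Println(p)          // "main.go:10"
	fmt.Println(Position{}) // "-"
}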
// Pos is a compact encoding of a source position within a file set.
// It can be converted into a [Position] for a more convenient, but much
// larger, representation.
//
// The Pos value for a given file is a number in the range [base, base+size],
// where base and size are specified when a file is added to the file set.
// The difference between a Pos value and the corresponding file base
// corresponds to the byte offset of that position (represented by the Pos value)
// from the beginning of the file. Thus, the file base offset is the Pos value
// representing the first byte in the file.
//
// To create the Pos value for a specific source offset (measured in bytes),
// first add the respective file to the current file set using [FileSet.AddFile]
// and then call [File.Pos](offset) for that file. Given a Pos value p
// for a specific file set fset, the corresponding [Position] value is
// obtained by calling fset.Position(p).
//
// Pos values can be compared directly with the usual comparison operators:
// If two Pos values p and q are in the same file, comparing p and q is
// equivalent to comparing the respective source file offsets. If p and q
// are in different files, p < q is true if the file implied by p was added
// to the respective file set before the file implied by q.
type Pos int
// The zero value for [Pos] is NoPos; there is no file and line information
// associated with it, and NoPos.IsValid() is false. NoPos is always
// smaller than any other [Pos] value. The corresponding [Position] value
// for NoPos is the zero value for [Position].
const NoPos Pos = 0
// IsValid reports whether the position is valid.
func (p Pos) IsValid() bool {
return p != NoPos
}
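// Illustrative sketch (not part of the original source): converting a byte
// offset to a Pos and back to a Position via a FileSet, as described above.
func posRoundTripExample() {
	fset := NewFileSet()
	f := fset.AddFile("hello.go", fset.Base(), 20) // 20-byte file
	f.AddLine(6)                                   // second line starts at offset 6
	p := f.Pos(8)                                  // Pos for byte offset 8
	fmt.Println(fset.Position(p))                  // "hello.go:2:3"
}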
// -----------------------------------------------------------------------------
// File
// A File is a handle for a file belonging to a [FileSet].
// A File has a name, size, and line offset table.
//
// Use [FileSet.AddFile] to create a File.
// A File may belong to more than one FileSet; see [FileSet.AddExistingFiles].
type File struct {
name string // file name as provided to AddFile
base int // Pos value range for this file is [base...base+size]
size int // file size as provided to AddFile
// lines and infos are protected by mutex
mutex sync.Mutex
lines []int // lines contains the offset of the first character for each line (the first entry is always 0)
infos []lineInfo
}
// Name returns the file name of file f as registered with AddFile.
func (f *File) Name() string {
return f.name
}
// Base returns the base offset of file f as registered with AddFile.
func (f *File) Base() int {
return f.base
}
// Size returns the size of file f as registered with AddFile.
func (f *File) Size() int {
return f.size
}
// LineCount returns the number of lines in file f.
func (f *File) LineCount() int {
f.mutex.Lock()
n := len(f.lines)
f.mutex.Unlock()
return n
}
// AddLine adds the line offset for a new line.
// The line offset must be larger than the offset for the previous line
// and smaller than the file size; otherwise the line offset is ignored.
func (f *File) AddLine(offset int) {
f.mutex.Lock()
if i := len(f.lines); (i == 0 || f.lines[i-1] < offset) && offset < f.size {
f.lines = append(f.lines, offset)
}
f.mutex.Unlock()
}
// MergeLine merges a line with the following line. It is akin to replacing
// the newline character at the end of the line with a space (to not change the
// remaining offsets). To obtain the line number, consult e.g. [Position.Line].
// MergeLine will panic if given an invalid line number.
func (f *File) MergeLine(line int) {
if line < 1 {
panic(fmt.Sprintf("invalid line number %d (should be >= 1)", line))
}
f.mutex.Lock()
defer f.mutex.Unlock()
if line >= len(f.lines) {
panic(fmt.Sprintf("invalid line number %d (should be < %d)", line, len(f.lines)))
}
// To merge the line numbered <line> with the line numbered <line+1>,
// we need to remove the entry in lines corresponding to the line
// numbered <line+1>. The entry in lines corresponding to the line
// numbered <line+1> is located at index <line>, since indices in lines
// are 0-based and line numbers are 1-based.
copy(f.lines[line:], f.lines[line+1:])
f.lines = f.lines[:len(f.lines)-1]
}
// Lines returns the effective line offset table of the form described by [File.SetLines].
// Callers must not mutate the result.
func (f *File) Lines() []int {
f.mutex.Lock()
lines := f.lines
f.mutex.Unlock()
return lines
}
// SetLines sets the line offsets for a file and reports whether it succeeded.
// The line offsets are the offsets of the first character of each line;
// for instance for the content "ab\nc\n" the line offsets are {0, 3}.
// An empty file has an empty line offset table.
// Each line offset must be larger than the offset for the previous line
// and smaller than the file size; otherwise SetLines fails and returns
// false.
// Callers must not mutate the provided slice after SetLines returns.
func (f *File) SetLines(lines []int) bool {
// verify validity of lines table
size := f.size
for i, offset := range lines {
if i > 0 && offset <= lines[i-1] || size <= offset {
return false
}
}
// set lines table
f.mutex.Lock()
f.lines = lines
f.mutex.Unlock()
return true
}
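// Illustrative sketch (not part of the original source): SetLines with the
// offsets from the doc comment above ("ab\nc\n" has line offsets {0, 3}).
func setLinesExample() {
	fset := NewFileSet()
	f := fset.AddFile("x.go", fset.Base(), len("ab\nc\n"))
	ok := f.SetLines([]int{0, 3})
	fmt.Println(ok, f.LineCount()) // true 2
}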
// SetLinesForContent sets the line offsets for the given file content.
// It ignores position-altering //line comments.
func (f *File) SetLinesForContent(content []byte) {
var lines []int
line := 0
for offset, b := range content {
if line >= 0 {
lines = append(lines, line)
}
line = -1
if b == '\n' {
line = offset + 1
}
}
// set lines table
f.mutex.Lock()
f.lines = lines
f.mutex.Unlock()
}
// LineStart returns the [Pos] value of the start of the specified line.
// It ignores any alternative positions set using [File.AddLineColumnInfo].
// LineStart panics if the 1-based line number is invalid.
func (f *File) LineStart(line int) Pos {
if line < 1 {
panic(fmt.Sprintf("invalid line number %d (should be >= 1)", line))
}
f.mutex.Lock()
defer f.mutex.Unlock()
if line > len(f.lines) {
panic(fmt.Sprintf("invalid line number %d (should be < %d)", line, len(f.lines)))
}
return Pos(f.base + f.lines[line-1])
}
// A lineInfo object describes alternative file, line, and column
// number information (such as provided via a //line directive)
// for a given file offset.
type lineInfo struct {
// fields are exported to make them accessible to gob
Offset int
Filename string
Line, Column int
}
// AddLineInfo is like [File.AddLineColumnInfo] with a column = 1 argument.
// It exists for backward compatibility with code written before Go 1.11.
func (f *File) AddLineInfo(offset int, filename string, line int) {
f.AddLineColumnInfo(offset, filename, line, 1)
}
// AddLineColumnInfo adds alternative file, line, and column number
// information for a given file offset. The offset must be larger
// than the offset for the previously added alternative line info
// and smaller than the file size; otherwise the information is
// ignored.
//
// AddLineColumnInfo is typically used to register alternative position
// information for line directives such as //line filename:line:column.
func (f *File) AddLineColumnInfo(offset int, filename string, line, column int) {
f.mutex.Lock()
if i := len(f.infos); (i == 0 || f.infos[i-1].Offset < offset) && offset < f.size {
f.infos = append(f.infos, lineInfo{offset, filename, line, column})
}
f.mutex.Unlock()
}
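// Illustrative sketch (not part of the original source): registering the
// effect of a //line directive and observing adjusted vs. unadjusted positions.
func lineDirectiveExample() {
	fset := NewFileSet()
	src := "package p\n//line gen.go:100:1\nvar x int\n"
	f := fset.AddFile("p.go", fset.Base(), len(src))
	f.SetLinesForContent([]byte(src))
	// The third line starts at offset 30; report it as gen.go:100 from there on.
	f.AddLineColumnInfo(30, "gen.go", 100, 1)
	fmt.Println(f.PositionFor(f.Pos(30), true))  // "gen.go:100:1"
	fmt.Println(f.PositionFor(f.Pos(30), false)) // "p.go:3:1"
}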
// fixOffset fixes an out-of-bounds offset such that 0 <= offset <= f.size.
func (f *File) fixOffset(offset int) int {
switch {
case offset < 0:
if !debug {
return 0
}
case offset > f.size:
if !debug {
return f.size
}
default:
return offset
}
// only generate this code if needed
if debug {
panic(fmt.Sprintf("offset %d out of bounds [%d, %d] (position %d out of bounds [%d, %d])",
0 /* for symmetry */, offset, f.size,
f.base+offset, f.base, f.base+f.size))
}
return 0
}
// Pos returns the Pos value for the given file offset.
//
// If offset is negative, the result is the file's start
// position; if the offset is too large, the result is
// the file's end position (see also go.dev/issue/57490).
//
// The following invariant, though not true for Pos values
// in general, holds for the result p:
// f.Pos(f.Offset(p)) == p.
func (f *File) Pos(offset int) Pos {
return Pos(f.base + f.fixOffset(offset))
}
// Offset returns the offset for the given file position p.
//
// If p is before the file's start position (or if p is NoPos),
// the result is 0; if p is past the file's end position,
// the result is the file size (see also go.dev/issue/57490).
//
// The following invariant, though not true for offset values
// in general, holds for the result offset:
// f.Offset(f.Pos(offset)) == offset
func (f *File) Offset(p Pos) int {
return f.fixOffset(int(p) - f.base)
}
// Line returns the line number for the given file position p;
// p must be a [Pos] value in that file or [NoPos].
func (f *File) Line(p Pos) int {
return f.Position(p).Line
}
func searchLineInfos(a []lineInfo, x int) int {
i, found := slices.BinarySearchFunc(a, x, func(a lineInfo, x int) int {
return cmp.Compare(a.Offset, x)
})
if !found {
// We want the lineInfo containing x, but if we didn't
// find x then i is the next one.
i--
}
return i
}
// unpack returns the filename and line and column number for a file offset.
// If adjusted is set, unpack will return the filename and line information
// possibly adjusted by //line comments; otherwise those comments are ignored.
func (f *File) unpack(offset int, adjusted bool) (filename string, line, column int) {
f.mutex.Lock()
filename = f.name
if i := searchInts(f.lines, offset); i >= 0 {
line, column = i+1, offset-f.lines[i]+1
}
if adjusted && len(f.infos) > 0 {
// few files have extra line infos
if i := searchLineInfos(f.infos, offset); i >= 0 {
alt := &f.infos[i]
filename = alt.Filename
if i := searchInts(f.lines, alt.Offset); i >= 0 {
// i+1 is the line at which the alternative position was recorded
d := line - (i + 1) // line distance from alternative position base
line = alt.Line + d
if alt.Column == 0 {
// alternative column is unknown => relative column is unknown
// (the current specification for line directives requires
// this to apply until the next PosBase/line directive,
// not just until the new newline)
column = 0
} else if d == 0 {
// the alternative position base is on the current line
// => column is relative to alternative column
column = alt.Column + (offset - alt.Offset)
}
}
}
}
// TODO(mvdan): move Unlock back under Lock with a defer statement once
// https://go.dev/issue/38471 is fixed to remove the performance penalty.
f.mutex.Unlock()
return
}
func (f *File) position(p Pos, adjusted bool) (pos Position) {
offset := f.fixOffset(int(p) - f.base)
pos.Offset = offset
pos.Filename, pos.Line, pos.Column = f.unpack(offset, adjusted)
return
}
// PositionFor returns the Position value for the given file position p.
// If p is out of bounds, it is adjusted to match the File.Offset behavior.
// If adjusted is set, the position may be adjusted by position-altering
// //line comments; otherwise those comments are ignored.
// p must be a Pos value in f or NoPos.
func (f *File) PositionFor(p Pos, adjusted bool) (pos Position) {
if p != NoPos {
pos = f.position(p, adjusted)
}
return
}
// Position returns the Position value for the given file position p.
// If p is out of bounds, it is adjusted to match the File.Offset behavior.
// Calling f.Position(p) is equivalent to calling f.PositionFor(p, true).
func (f *File) Position(p Pos) (pos Position) {
return f.PositionFor(p, true)
}
// -----------------------------------------------------------------------------
// FileSet
// A FileSet represents a set of source files.
// Methods of file sets are synchronized; multiple goroutines
// may invoke them concurrently.
//
// The byte offsets for each file in a file set are mapped into
// distinct (integer) intervals, one interval [base, base+size]
// per file. [FileSet.Base] represents the first byte in the file, and size
// is the corresponding file size. A [Pos] value is a value in such
// an interval. By determining the interval a [Pos] value belongs
// to, the file, its file base, and thus the byte offset (position)
// the [Pos] value is representing can be computed.
//
// When adding a new file, a file base must be provided. That can
// be any integer value that is past the end of any interval of any
// file already in the file set. For convenience, [FileSet.Base] provides
// such a value, which is simply the end of the Pos interval of the most
// recently added file, plus one. Unless there is a need to extend an
// interval later, [FileSet.Base] should be used as the argument
// for [FileSet.AddFile].
//
// A [File] may be removed from a FileSet when it is no longer needed.
// This may reduce memory usage in a long-running application.
type FileSet struct {
mutex sync.RWMutex // protects the file set
base int // base offset for the next file
tree tree // tree of files in ascending base order
last atomic.Pointer[File] // cache of last file looked up
}
// NewFileSet creates a new file set.
func NewFileSet() *FileSet {
return &FileSet{
base: 1, // 0 == NoPos
}
}
// Base returns the minimum base offset that must be provided to
// [FileSet.AddFile] when adding the next file.
func (s *FileSet) Base() int {
s.mutex.RLock()
b := s.base
s.mutex.RUnlock()
return b
}
// AddFile adds a new file with a given filename, base offset, and file size
// to the file set s and returns the file. Multiple files may have the same
// name. The base offset must not be smaller than the [FileSet.Base], and
// size must not be negative. As a special case, if a negative base is provided,
// the current value of the [FileSet.Base] is used instead.
//
// Adding the file will set the file set's [FileSet.Base] value to base + size + 1
// as the minimum base value for the next file. The following relationship
// exists between a [Pos] value p for a given file offset offs:
//
// int(p) = base + offs
//
// with offs in the range [0, size] and thus p in the range [base, base+size].
// For convenience, [File.Pos] may be used to create file-specific position
// values from a file offset.
func (s *FileSet) AddFile(filename string, base, size int) *File {
// Allocate f outside the critical section.
f := &File{name: filename, size: size, lines: []int{0}}
s.mutex.Lock()
defer s.mutex.Unlock()
if base < 0 {
base = s.base
}
if base < s.base {
panic(fmt.Sprintf("invalid base %d (should be >= %d)", base, s.base))
}
f.base = base
if size < 0 {
panic(fmt.Sprintf("invalid size %d (should be >= 0)", size))
}
// base >= s.base && size >= 0
base += size + 1 // +1 because EOF also has a position
if base < 0 {
panic("token.Pos offset overflow (> 2G of source code in file set)")
}
// add the file to the file set
s.base = base
s.tree.add(f)
s.last.Store(f)
return f
}
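// Illustrative sketch (not part of the original source): the base/offset
// relationship documented above, int(p) == base + offs.
func addFileBaseExample() {
	fset := NewFileSet()
	base := fset.Base()
	f := fset.AddFile("a.go", base, 100)
	p := f.Pos(42)
	fmt.Println(int(p) == base+42)         // true
	fmt.Println(fset.Base() == base+100+1) // true: next base is base + size + 1
}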
// AddExistingFiles adds the specified files to the
// FileSet if they are not already present.
// The caller must ensure that no pair of Files that
// would appear in the resulting FileSet overlap.
func (s *FileSet) AddExistingFiles(files ...*File) {
// This function cannot be implemented as:
//
// for _, file := range files {
// if prev := fset.File(token.Pos(file.Base())); prev != nil {
// if prev != file {
// panic("FileSet contains a different file at the same base")
// }
// continue
// }
// file2 := fset.AddFile(file.Name(), file.Base(), file.Size())
// file2.SetLines(file.Lines())
// }
//
// because all calls to AddFile must be made in order of increasing base.
// AddExistingFiles lets us augment an existing FileSet
// sequentially, so long as all sets of files have disjoint ranges.
// This approach also does not preserve line directives.
s.mutex.Lock()
defer s.mutex.Unlock()
for _, f := range files {
s.tree.add(f)
s.base = max(s.base, f.Base()+f.Size()+1)
}
}
// RemoveFile removes a file from the [FileSet] so that subsequent
// queries for its [Pos] interval yield a negative result.
// This reduces the memory usage of a long-lived [FileSet] that
// encounters an unbounded stream of files.
//
// Removing a file that does not belong to the set has no effect.
func (s *FileSet) RemoveFile(file *File) {
s.last.CompareAndSwap(file, nil) // clear last file cache
s.mutex.Lock()
defer s.mutex.Unlock()
pn, _ := s.tree.locate(file.key())
if *pn != nil && (*pn).file == file {
s.tree.delete(pn)
}
}
// Iterate calls yield for the files in the file set in ascending Base
// order until yield returns false.
func (s *FileSet) Iterate(yield func(*File) bool) {
s.mutex.RLock()
defer s.mutex.RUnlock()
// Unlock around user code.
// The iterator is robust to modification by yield.
// Avoid range here, so we can use defer.
s.tree.all()(func(f *File) bool {
s.mutex.RUnlock()
defer s.mutex.RLock()
return yield(f)
})
}
func (s *FileSet) file(p Pos) *File {
// common case: p is in last file.
if f := s.last.Load(); f != nil && f.base <= int(p) && int(p) <= f.base+f.size {
return f
}
s.mutex.RLock()
defer s.mutex.RUnlock()
pn, _ := s.tree.locate(key{int(p), int(p)})
if n := *pn; n != nil {
// Update cache of last file. A race is ok,
// but an exclusive lock causes heavy contention.
s.last.Store(n.file)
return n.file
}
return nil
}
// File returns the file that contains the position p.
// If no such file is found (for instance for p == [NoPos]),
// the result is nil.
func (s *FileSet) File(p Pos) (f *File) {
if p != NoPos {
f = s.file(p)
}
return
}
// PositionFor converts a [Pos] p in the fileset into a [Position] value.
// If adjusted is set, the position may be adjusted by position-altering
// //line comments; otherwise those comments are ignored.
// p must be a [Pos] value in s or [NoPos].
func (s *FileSet) PositionFor(p Pos, adjusted bool) (pos Position) {
if p != NoPos {
if f := s.file(p); f != nil {
return f.position(p, adjusted)
}
}
return
}
// Position converts a [Pos] p in the fileset into a Position value.
// Calling s.Position(p) is equivalent to calling s.PositionFor(p, true).
func (s *FileSet) Position(p Pos) (pos Position) {
return s.PositionFor(p, true)
}
// -----------------------------------------------------------------------------
// Helper functions
func searchInts(a []int, x int) int {
// This function body is a manually inlined version of:
//
// return sort.Search(len(a), func(i int) bool { return a[i] > x }) - 1
//
// With better compiler optimizations, this may not be needed in the
// future, but at the moment this change improves the go/printer
// benchmark performance by ~30%. This has a direct impact on the
// speed of gofmt and thus seems worthwhile (2011-04-29).
// TODO(gri): Remove this when compilers have caught up.
i, j := 0, len(a)
for i < j {
h := int(uint(i+j) >> 1) // avoid overflow when computing h
// i ≤ h < j
if a[h] <= x {
i = h + 1
} else {
j = h
}
}
return i - 1
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package token
import "slices"
type serializedFile struct {
// fields correspond 1:1 to fields with same (lower-case) name in File
Name string
Base int
Size int
Lines []int
Infos []lineInfo
}
type serializedFileSet struct {
Base int
Files []serializedFile
}
// Read calls decode to deserialize a file set into s; s must not be nil.
func (s *FileSet) Read(decode func(any) error) error {
var ss serializedFileSet
if err := decode(&ss); err != nil {
return err
}
s.mutex.Lock()
s.base = ss.Base
for _, f := range ss.Files {
s.tree.add(&File{
name: f.Name,
base: f.Base,
size: f.Size,
lines: f.Lines,
infos: f.Infos,
})
}
s.last.Store(nil)
s.mutex.Unlock()
return nil
}
// Write calls encode to serialize the file set s.
func (s *FileSet) Write(encode func(any) error) error {
var ss serializedFileSet
s.mutex.Lock()
ss.Base = s.base
var files []serializedFile
for f := range s.tree.all() {
f.mutex.Lock()
files = append(files, serializedFile{
Name: f.name,
Base: f.base,
Size: f.size,
Lines: slices.Clone(f.lines),
Infos: slices.Clone(f.infos),
})
f.mutex.Unlock()
}
ss.Files = files
s.mutex.Unlock()
return encode(ss)
}
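// Illustrative sketch (not part of the original source): round-tripping a
// FileSet through Write and Read with encoding/gob. The bytes and gob
// imports are assumed here; they are not part of this file.
func gobRoundTripExample() error {
	fset := NewFileSet()
	fset.AddFile("a.go", fset.Base(), 10)
	var buf bytes.Buffer
	if err := fset.Write(gob.NewEncoder(&buf).Encode); err != nil {
		return err
	}
	restored := NewFileSet()
	return restored.Read(gob.NewDecoder(&buf).Decode)
}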
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package token defines constants representing the lexical tokens of the Go
// programming language and basic operations on tokens (printing, predicates).
package token
import (
"strconv"
"unicode"
"unicode/utf8"
)
// Token is the set of lexical tokens of the Go programming language.
type Token int
// The list of tokens.
const (
// Special tokens
ILLEGAL Token = iota
EOF
COMMENT
literal_beg
// Identifiers and basic type literals
// (these tokens stand for classes of literals)
IDENT // main
INT // 12345
FLOAT // 123.45
IMAG // 123.45i
CHAR // 'a'
STRING // "abc"
literal_end
operator_beg
// Operators and delimiters
ADD // +
SUB // -
MUL // *
QUO // /
REM // %
AND // &
OR // |
XOR // ^
SHL // <<
SHR // >>
AND_NOT // &^
ADD_ASSIGN // +=
SUB_ASSIGN // -=
MUL_ASSIGN // *=
QUO_ASSIGN // /=
REM_ASSIGN // %=
AND_ASSIGN // &=
OR_ASSIGN // |=
XOR_ASSIGN // ^=
SHL_ASSIGN // <<=
SHR_ASSIGN // >>=
AND_NOT_ASSIGN // &^=
LAND // &&
LOR // ||
ARROW // <-
INC // ++
DEC // --
EQL // ==
LSS // <
GTR // >
ASSIGN // =
NOT // !
NEQ // !=
LEQ // <=
GEQ // >=
DEFINE // :=
ELLIPSIS // ...
LPAREN // (
LBRACK // [
LBRACE // {
COMMA // ,
PERIOD // .
RPAREN // )
RBRACK // ]
RBRACE // }
SEMICOLON // ;
COLON // :
operator_end
keyword_beg
// Keywords
BREAK
CASE
CHAN
CONST
CONTINUE
DEFAULT
DEFER
ELSE
FALLTHROUGH
FOR
FUNC
GO
GOTO
IF
IMPORT
INTERFACE
MAP
PACKAGE
RANGE
RETURN
SELECT
STRUCT
SWITCH
TYPE
VAR
keyword_end
additional_beg
// additional tokens, handled in an ad-hoc manner
TILDE
additional_end
)
var tokens = [...]string{
ILLEGAL: "ILLEGAL",
EOF: "EOF",
COMMENT: "COMMENT",
IDENT: "IDENT",
INT: "INT",
FLOAT: "FLOAT",
IMAG: "IMAG",
CHAR: "CHAR",
STRING: "STRING",
ADD: "+",
SUB: "-",
MUL: "*",
QUO: "/",
REM: "%",
AND: "&",
OR: "|",
XOR: "^",
SHL: "<<",
SHR: ">>",
AND_NOT: "&^",
ADD_ASSIGN: "+=",
SUB_ASSIGN: "-=",
MUL_ASSIGN: "*=",
QUO_ASSIGN: "/=",
REM_ASSIGN: "%=",
AND_ASSIGN: "&=",
OR_ASSIGN: "|=",
XOR_ASSIGN: "^=",
SHL_ASSIGN: "<<=",
SHR_ASSIGN: ">>=",
AND_NOT_ASSIGN: "&^=",
LAND: "&&",
LOR: "||",
ARROW: "<-",
INC: "++",
DEC: "--",
EQL: "==",
LSS: "<",
GTR: ">",
ASSIGN: "=",
NOT: "!",
NEQ: "!=",
LEQ: "<=",
GEQ: ">=",
DEFINE: ":=",
ELLIPSIS: "...",
LPAREN: "(",
LBRACK: "[",
LBRACE: "{",
COMMA: ",",
PERIOD: ".",
RPAREN: ")",
RBRACK: "]",
RBRACE: "}",
SEMICOLON: ";",
COLON: ":",
BREAK: "break",
CASE: "case",
CHAN: "chan",
CONST: "const",
CONTINUE: "continue",
DEFAULT: "default",
DEFER: "defer",
ELSE: "else",
FALLTHROUGH: "fallthrough",
FOR: "for",
FUNC: "func",
GO: "go",
GOTO: "goto",
IF: "if",
IMPORT: "import",
INTERFACE: "interface",
MAP: "map",
PACKAGE: "package",
RANGE: "range",
RETURN: "return",
SELECT: "select",
STRUCT: "struct",
SWITCH: "switch",
TYPE: "type",
VAR: "var",
TILDE: "~",
}
// String returns the string corresponding to the token tok.
// For operators, delimiters, and keywords the string is the actual
// token character sequence (e.g., for the token [ADD], the string is
// "+"). For all other tokens the string corresponds to the token
// constant name (e.g. for the token [IDENT], the string is "IDENT").
func (tok Token) String() string {
s := ""
if 0 <= tok && tok < Token(len(tokens)) {
s = tokens[tok]
}
if s == "" {
s = "token(" + strconv.Itoa(int(tok)) + ")"
}
return s
}
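// Illustrative sketch (not part of the original source), showing the
// behavior documented above: operators and keywords print their spelling,
// other tokens print their constant name, and out-of-range values fall
// back to a numeric form.
func exampleString() {
	_ = ADD.String()         // "+"
	_ = RETURN.String()      // "return"
	_ = IDENT.String()       // "IDENT"
	_ = Token(1234).String() // "token(1234)"
}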
// A set of constants for precedence-based expression parsing.
// Non-operators have lowest precedence, followed by operators
// starting with precedence 1 up to unary operators. The highest
// precedence serves as "catch-all" precedence for selector,
// indexing, and other operator and delimiter tokens.
const (
LowestPrec = 0 // non-operators
UnaryPrec = 6
HighestPrec = 7
)
// Precedence returns the operator precedence of the binary
// operator op. If op is not a binary operator, the result
// is [LowestPrec].
func (op Token) Precedence() int {
switch op {
case LOR:
return 1
case LAND:
return 2
case EQL, NEQ, LSS, LEQ, GTR, GEQ:
return 3
case ADD, SUB, OR, XOR:
return 4
case MUL, QUO, REM, SHL, SHR, AND, AND_NOT:
return 5
}
return LowestPrec
}
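// Illustrative sketch (not part of the original source): the levels above
// mean a*b+c parses as (a*b)+c, because MUL binds more tightly than ADD,
// while non-operators report LowestPrec.
func examplePrecedence() {
	_ = MUL.Precedence()   // 5
	_ = ADD.Precedence()   // 4
	_ = EQL.Precedence()   // 3
	_ = IDENT.Precedence() // LowestPrec (0): not a binary operator
}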
var keywords map[string]Token
func init() {
keywords = make(map[string]Token, keyword_end-(keyword_beg+1))
for i := keyword_beg + 1; i < keyword_end; i++ {
keywords[tokens[i]] = i
}
}
// Lookup maps an identifier to its keyword token or [IDENT] (if not a keyword).
func Lookup(ident string) Token {
if tok, is_keyword := keywords[ident]; is_keyword {
return tok
}
return IDENT
}
// Predicates
// IsLiteral returns true for tokens corresponding to identifiers
// and basic type literals; it returns false otherwise.
func (tok Token) IsLiteral() bool { return literal_beg < tok && tok < literal_end }
// IsOperator returns true for tokens corresponding to operators and
// delimiters; it returns false otherwise.
func (tok Token) IsOperator() bool {
return (operator_beg < tok && tok < operator_end) || tok == TILDE
}
// IsKeyword returns true for tokens corresponding to keywords;
// it returns false otherwise.
func (tok Token) IsKeyword() bool { return keyword_beg < tok && tok < keyword_end }
// IsExported reports whether name starts with an upper-case letter.
func IsExported(name string) bool {
ch, _ := utf8.DecodeRuneInString(name)
return unicode.IsUpper(ch)
}
// IsKeyword reports whether name is a Go keyword, such as "func" or "return".
func IsKeyword(name string) bool {
// TODO: opt: use a perfect hash function instead of a global map.
_, ok := keywords[name]
return ok
}
// IsIdentifier reports whether name is a Go identifier, that is, a non-empty
// string made up of letters, digits, and underscores, where the first character
// is not a digit. Keywords are not identifiers.
func IsIdentifier(name string) bool {
if name == "" || IsKeyword(name) {
return false
}
for i, c := range name {
if !unicode.IsLetter(c) && c != '_' && (i == 0 || !unicode.IsDigit(c)) {
return false
}
}
return true
}
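// Illustrative sketch (not part of the original source), tying the
// predicates above together: keywords come back from Lookup as their
// keyword token and are rejected by IsIdentifier, while ordinary names
// map to IDENT.
func exampleLookup() {
	_ = Lookup("func")        // FUNC
	_ = Lookup("funky")       // IDENT
	_ = IsKeyword("return")   // true
	_ = IsIdentifier("x1")    // true
	_ = IsIdentifier("1x")    // false: cannot start with a digit
	_ = IsIdentifier("func")  // false: keywords are not identifiers
	_ = IsExported("Fprintf") // true: starts with an upper-case letter
}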
// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package token
// tree is a self-balancing AVL tree; see
// Lewis & Denenberg, Data Structures and Their Algorithms.
//
// An AVL tree is a binary tree in which the difference between the
// heights of a node's two subtrees--the node's "balance factor"--is
// at most one. It is more strictly balanced than a red/black tree,
// and thus favors lookups at the expense of updates, which is the
// appropriate trade-off for FileSet.
//
// Insertion at a node may cause its ancestors' balance factors to
// temporarily reach ±2, requiring rebalancing of each such ancestor
// by a rotation.
//
// Each key is the pos-end range of a single File.
// All Files in the tree must have disjoint ranges.
//
// The implementation is simplified from Russ Cox's github.com/rsc/omap.
import (
"fmt"
"iter"
)
// A tree is a tree-based ordered map:
// each value is a *File, keyed by its Pos range.
// All map entries cover disjoint ranges.
//
// The zero value of tree is an empty map ready to use.
type tree struct {
root *node
}
type node struct {
// We use the notation (parent left right) in many comments.
parent *node
left *node
right *node
file *File
key key // = file.key(), but improves locality (25% faster)
balance int32 // at most ±2
height int32
}
// A key represents the Pos range of a File.
type key struct{ start, end int }
func (f *File) key() key {
return key{f.base, f.base + f.size}
}
// compareKey reports whether x is before y (-1),
// after y (+1), or overlapping y (0).
// This is a total order so long as all
// files in the tree have disjoint ranges.
//
// All files are separated by at least one unit.
// This allows us to use strict < comparisons.
// Use key{p, p} to search for a zero-width position
// even at the start or end of a file.
func compareKey(x, y key) int {
switch {
case x.end < y.start:
return -1
case y.end < x.start:
return +1
}
return 0
}
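// Illustrative sketch (not part of the original source): how compareKey
// orders disjoint file ranges and how a zero-width key{p, p} matches the
// file containing position p, as the comment above describes.
func exampleCompareKey() {
	a := key{100, 200} // a file occupying [100, 200]
	b := key{201, 300} // the next file; ranges are disjoint
	_ = compareKey(a, b)             // -1: a lies entirely before b
	_ = compareKey(b, a)             // +1
	_ = compareKey(key{150, 150}, a) // 0: the zero-width key overlaps a
}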
// check asserts that each node's height, subtree, and parent link are
// correct.
func (n *node) check(parent *node) {
const debugging = false
if debugging {
if n == nil {
return
}
if n.parent != parent {
panic("bad parent")
}
n.left.check(n)
n.right.check(n)
n.checkBalance()
}
}
func (n *node) checkBalance() {
lheight, rheight := n.left.safeHeight(), n.right.safeHeight()
balance := rheight - lheight
if balance != n.balance {
panic("bad node.balance")
}
if !(-2 <= balance && balance <= +2) {
panic(fmt.Sprintf("node.balance out of range: %d", balance))
}
h := 1 + max(lheight, rheight)
if h != n.height {
panic("bad node.height")
}
}
// locate returns a pointer to the variable that holds the node
// identified by k, along with its parent, if any. If the key is not
// present, it returns a pointer to the node where the key should be
// inserted by a subsequent call to [tree.set].
func (t *tree) locate(k key) (pos **node, parent *node) {
pos, x := &t.root, t.root
for x != nil {
sign := compareKey(k, x.key)
if sign < 0 {
pos, x, parent = &x.left, x.left, x
} else if sign > 0 {
pos, x, parent = &x.right, x.right, x
} else {
break
}
}
return pos, parent
}
// all returns an iterator over the tree t.
// If t is modified during the iteration,
// some files may not be visited.
// No file will be visited multiple times.
func (t *tree) all() iter.Seq[*File] {
return func(yield func(*File) bool) {
if t == nil {
return
}
x := t.root
if x != nil {
for x.left != nil {
x = x.left
}
}
for x != nil && yield(x.file) {
if x.height >= 0 {
// still in tree
x = x.next()
} else {
// deleted
x = t.nextAfter(t.locate(x.key))
}
}
}
}
// nextAfter returns the node in the key sequence following
// (pos, parent), a result pair from [tree.locate].
func (t *tree) nextAfter(pos **node, parent *node) *node {
switch {
case *pos != nil:
return (*pos).next()
case parent == nil:
return nil
case pos == &parent.left:
return parent
default:
return parent.next()
}
}
// next returns the node that follows x in an in-order traversal
// (its in-order successor), or nil if x is the last node.
func (x *node) next() *node {
if x.right == nil {
for x.parent != nil && x.parent.right == x {
x = x.parent
}
return x.parent
}
x = x.right
for x.left != nil {
x = x.left
}
return x
}
func (t *tree) setRoot(x *node) {
t.root = x
if x != nil {
x.parent = nil
}
}
func (x *node) setLeft(y *node) {
x.left = y
if y != nil {
y.parent = x
}
}
func (x *node) setRight(y *node) {
x.right = y
if y != nil {
y.parent = x
}
}
func (n *node) safeHeight() int32 {
if n == nil {
return -1
}
return n.height
}
// update recomputes n's height and balance factor from its children.
func (n *node) update() {
lheight, rheight := n.left.safeHeight(), n.right.safeHeight()
n.height = max(lheight, rheight) + 1
n.balance = rheight - lheight
}
// replaceChild replaces old with new as the indicated child of parent,
// or as the root if parent is nil.
func (t *tree) replaceChild(parent, old, new *node) {
switch {
case parent == nil:
if t.root != old {
panic("corrupt tree")
}
t.setRoot(new)
case parent.left == old:
parent.setLeft(new)
case parent.right == old:
parent.setRight(new)
default:
panic("corrupt tree")
}
}
// rebalanceUp visits each excessively unbalanced ancestor
// of x, restoring balance by rotating it.
//
// x is a node that has just been mutated, and so the height and
// balance of x and its ancestors may be stale, but the children of x
// must be in a valid state.
func (t *tree) rebalanceUp(x *node) {
for x != nil {
h := x.height
x.update()
switch x.balance {
case -2:
if x.left.balance == 1 {
t.rotateLeft(x.left)
}
x = t.rotateRight(x)
case +2:
if x.right.balance == -1 {
t.rotateRight(x.right)
}
x = t.rotateLeft(x)
}
if x.height == h {
// x's height has not changed, so the height
// and balance of its ancestors have not changed;
// no further rebalancing is required.
return
}
x = x.parent
}
}
// rotateRight rotates the subtree rooted at node y,
// turning (y (x a b) c) into (x a (y b c)).
func (t *tree) rotateRight(y *node) *node {
// p -> (y (x a b) c)
p := y.parent
x := y.left
b := x.right
x.checkBalance()
y.checkBalance()
x.setRight(y)
y.setLeft(b)
t.replaceChild(p, y, x)
y.update()
x.update()
return x
}
// rotateLeft rotates the subtree rooted at node x,
// turning (x a (y b c)) into (y (x a b) c).
func (t *tree) rotateLeft(x *node) *node {
// p -> (x a (y b c))
p := x.parent
y := x.right
b := y.left
x.checkBalance()
y.checkBalance()
y.setLeft(x)
x.setRight(b)
t.replaceChild(p, x, y)
x.update()
y.update()
return y
}
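// Illustrative sketch (not part of the original source): a single right
// rotation on the smallest tree exhibiting the (y (x a b) c) shape described
// above. Heights and balance factors are established with update() first;
// keys are irrelevant to the rotation itself and left at their zero values.
func exampleRotateRight() {
	a, b, c := &node{}, &node{}, &node{}
	a.update() // leaves: height 0, balance 0
	b.update()
	c.update()

	x := &node{}
	x.setLeft(a)
	x.setRight(b)
	x.update() // height 1, balance 0

	y := &node{}
	y.setLeft(x)
	y.setRight(c)
	y.update() // height 2, balance -1 (left-heavy)

	t := &tree{}
	t.setRoot(y)
	root := t.rotateRight(y) // now (x a (y b c))
	_ = root                 // root == x and t.root == x
}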
// add inserts file into the tree, if not present.
// It panics if file overlaps with another.
func (t *tree) add(file *File) {
pos, parent := t.locate(file.key())
if *pos == nil {
t.set(file, pos, parent) // missing; insert
return
}
if prev := (*pos).file; prev != file {
panic(fmt.Sprintf("file %s (%d-%d) overlaps with file %s (%d-%d)",
prev.Name(), prev.Base(), prev.Base()+prev.Size(),
file.Name(), file.Base(), file.Base()+file.Size()))
}
}
// set updates the existing node at (pos, parent) if present, or
// inserts a new node if not, so that it refers to file.
func (t *tree) set(file *File, pos **node, parent *node) {
if x := *pos; x != nil {
// This code path isn't currently needed
// because FileSet never updates an existing entry.
// Remove this assertion if things change.
if true {
panic("unreachable according to current FileSet requirements")
}
x.file = file
return
}
x := &node{file: file, key: file.key(), parent: parent, height: -1}
*pos = x
t.rebalanceUp(x)
}
// delete deletes the node at pos.
func (t *tree) delete(pos **node) {
t.root.check(nil)
x := *pos
switch {
case x == nil:
// This code path isn't currently needed because FileSet
// only calls delete after a positive locate.
// Remove this assertion if things change.
if true {
panic("unreachable according to current FileSet requirements")
}
return
case x.left == nil:
if *pos = x.right; *pos != nil {
(*pos).parent = x.parent
}
t.rebalanceUp(x.parent)
case x.right == nil:
*pos = x.left
x.left.parent = x.parent
t.rebalanceUp(x.parent)
default:
t.deleteSwap(pos)
}
x.balance = -100
x.parent = nil
x.left = nil
x.right = nil
x.height = -1
t.root.check(nil)
}
// deleteSwap deletes a node that has two children by replacing
// it by its in-order successor, then triggers a rebalance.
func (t *tree) deleteSwap(pos **node) {
x := *pos
z := t.deleteMin(&x.right)
*pos = z
unbalanced := z.parent // lowest potentially unbalanced node
if unbalanced == x {
unbalanced = z // (x a (z nil b)) -> (z a b)
}
z.parent = x.parent
z.height = x.height
z.balance = x.balance
z.setLeft(x.left)
z.setRight(x.right)
t.rebalanceUp(unbalanced)
}
// deleteMin updates the subtree rooted at *zpos to delete its minimum
// (leftmost) element, which may be *zpos itself. It returns the
// deleted node.
func (t *tree) deleteMin(zpos **node) (z *node) {
for (*zpos).left != nil {
zpos = &(*zpos).left
}
z = *zpos
*zpos = z.right
if *zpos != nil {
(*zpos).parent = z.parent
}
return z
}
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package version provides operations on [Go versions]
// in [Go toolchain name syntax]: strings like
// "go1.20", "go1.21.0", "go1.22rc2", and "go1.23.4-bigcorp".
//
// [Go versions]: https://go.dev/doc/toolchain#version
// [Go toolchain name syntax]: https://go.dev/doc/toolchain#name
package version // import "go/version"
import (
"internal/gover"
"strings"
)
// stripGo converts from a "go1.21-bigcorp" version to a "1.21" version.
// If v does not start with "go", stripGo returns the empty string (a known invalid version).
func stripGo(v string) string {
v, _, _ = strings.Cut(v, "-") // strip -bigcorp suffix.
if len(v) < 2 || v[:2] != "go" {
return ""
}
return v[2:]
}
// Lang returns the Go language version for version x.
// If x is not a valid version, Lang returns the empty string.
// For example:
//
// Lang("go1.21rc2") = "go1.21"
// Lang("go1.21.2") = "go1.21"
// Lang("go1.21") = "go1.21"
// Lang("go1") = "go1"
// Lang("bad") = ""
// Lang("1.21") = ""
func Lang(x string) string {
v := gover.Lang(stripGo(x))
if v == "" {
return ""
}
if strings.HasPrefix(x[2:], v) {
return x[:2+len(v)] // "go"+v without allocation
} else {
return "go" + v
}
}
// Compare returns -1, 0, or +1 depending on whether
// x < y, x == y, or x > y, interpreted as Go versions.
// The versions x and y must begin with a "go" prefix: "go1.21" not "1.21".
// Invalid versions, including the empty string, compare less than
// valid versions and equal to each other.
// The language version "go1.21" compares less than the
// release candidate and eventual releases "go1.21rc1" and "go1.21.0".
func Compare(x, y string) int {
return gover.Compare(stripGo(x), stripGo(y))
}
// IsValid reports whether the version x is valid.
func IsValid(x string) bool {
return gover.IsValid(stripGo(x))
}
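// Illustrative usage sketch (not part of the original source). Outside this
// package the calls below would be qualified with "version." after importing
// go/version; the results follow from the documentation above.
func exampleVersion() {
	_ = Lang("go1.21rc2")             // "go1.21"
	_ = IsValid("go1.21.2")           // true
	_ = IsValid("1.21")               // false: the "go" prefix is required
	_ = Compare("go1.21", "go1.21.0") // -1: a language version precedes its releases
}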
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package adler32 implements the Adler-32 checksum.
//
// It is defined in RFC 1950:
//
// Adler-32 is composed of two sums accumulated per byte: s1 is
// the sum of all bytes, s2 is the sum of all s1 values. Both sums
// are done modulo 65521. s1 is initialized to 1, s2 to zero. The
// Adler-32 checksum is stored as s2*65536 + s1 in most-
// significant-byte first (network) order.
package adler32
import (
"errors"
"hash"
"internal/byteorder"
)
const (
// mod is the largest prime that is less than 65536.
mod = 65521
// nmax is the largest n such that
// 255 * n * (n+1) / 2 + (n+1) * (mod-1) <= 2^32-1.
// It is mentioned in RFC 1950 (search for "5552").
nmax = 5552
)
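// Illustrative sketch (not part of the original source): the nmax bound
// spelled out with 64-bit arithmetic. worstCase(n) is the unreduced s2 after
// n bytes of 0xFF starting from s1 = s2 = mod-1; it still fits in 32 bits
// for n = nmax but not for n = nmax+1.
func worstCase(n uint64) uint64 {
	return 255*n*(n+1)/2 + (n+1)*(mod-1)
}

func nmaxIsLargest() bool {
	const max32 = 1<<32 - 1
	return worstCase(nmax) <= max32 && worstCase(nmax+1) > max32
}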
// The size of an Adler-32 checksum in bytes.
const Size = 4
// digest represents the partial evaluation of a checksum.
// The low 16 bits are s1, the high 16 bits are s2.
type digest uint32
func (d *digest) Reset() { *d = 1 }
// New returns a new hash.Hash32 computing the Adler-32 checksum. Its
// Sum method will lay the value out in big-endian byte order. The
// returned Hash32 also implements [encoding.BinaryMarshaler] and
// [encoding.BinaryUnmarshaler] to marshal and unmarshal the internal
// state of the hash.
func New() hash.Hash32 {
d := new(digest)
d.Reset()
return d
}
func (d *digest) Size() int { return Size }
func (d *digest) BlockSize() int { return 4 }
const (
magic = "adl\x01"
marshaledSize = len(magic) + 4
)
func (d *digest) AppendBinary(b []byte) ([]byte, error) {
b = append(b, magic...)
b = byteorder.BEAppendUint32(b, uint32(*d))
return b, nil
}
func (d *digest) MarshalBinary() ([]byte, error) {
return d.AppendBinary(make([]byte, 0, marshaledSize))
}
func (d *digest) UnmarshalBinary(b []byte) error {
if len(b) < len(magic) || string(b[:len(magic)]) != magic {
return errors.New("hash/adler32: invalid hash state identifier")
}
if len(b) != marshaledSize {
return errors.New("hash/adler32: invalid hash state size")
}
*d = digest(byteorder.BEUint32(b[len(magic):]))
return nil
}
func (d *digest) Clone() (hash.Cloner, error) {
r := *d
return &r, nil
}
// Add p to the running checksum d.
func update(d digest, p []byte) digest {
s1, s2 := uint32(d&0xffff), uint32(d>>16)
for len(p) > 0 {
var q []byte
if len(p) > nmax {
p, q = p[:nmax], p[nmax:]
}
for len(p) >= 4 {
s1 += uint32(p[0])
s2 += s1
s1 += uint32(p[1])
s2 += s1
s1 += uint32(p[2])
s2 += s1
s1 += uint32(p[3])
s2 += s1
p = p[4:]
}
for _, x := range p {
s1 += uint32(x)
s2 += s1
}
s1 %= mod
s2 %= mod
p = q
}
return digest(s2<<16 | s1)
}
func (d *digest) Write(p []byte) (nn int, err error) {
*d = update(*d, p)
return len(p), nil
}
func (d *digest) Sum32() uint32 { return uint32(*d) }
func (d *digest) Sum(in []byte) []byte {
s := uint32(*d)
return append(in, byte(s>>24), byte(s>>16), byte(s>>8), byte(s))
}
// Checksum returns the Adler-32 checksum of data.
func Checksum(data []byte) uint32 { return uint32(update(1, data)) }
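// Illustrative sketch (not part of the original source): the RFC 1950
// definition quoted in the package comment, transcribed literally without
// the nmax batching or 4-byte unrolling used by update. It should agree
// with Checksum for any input.
func naiveChecksum(data []byte) uint32 {
	s1, s2 := uint32(1), uint32(0)
	for _, b := range data {
		s1 = (s1 + uint32(b)) % mod
		s2 = (s2 + s1) % mod
	}
	return s2<<16 | s1
}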
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package crc32 implements the 32-bit cyclic redundancy check, or CRC-32,
// checksum. See https://en.wikipedia.org/wiki/Cyclic_redundancy_check for
// information.
//
// Polynomials are represented in LSB-first form, also known as reversed representation.
//
// See https://en.wikipedia.org/wiki/Mathematics_of_cyclic_redundancy_checks#Reversed_representations_and_reciprocal_polynomials
// for information.
package crc32
import (
"errors"
"hash"
"internal/byteorder"
"sync"
"sync/atomic"
)
// The size of a CRC-32 checksum in bytes.
const Size = 4
// Predefined polynomials.
const (
// IEEE is by far and away the most common CRC-32 polynomial.
// Used by ethernet (IEEE 802.3), v.42, fddi, gzip, zip, png, ...
IEEE = 0xedb88320
// Castagnoli's polynomial, used in iSCSI.
// Has better error detection characteristics than IEEE.
// https://dx.doi.org/10.1109/26.231911
Castagnoli = 0x82f63b78
// Koopman's polynomial.
// Also has better error detection characteristics than IEEE.
// https://dx.doi.org/10.1109/DSN.2002.1028931
Koopman = 0xeb31d82e
)
// Table is a 256-word table representing the polynomial for efficient processing.
type Table [256]uint32
// This file makes use of functions implemented in architecture-specific files.
// The interface that they implement is as follows:
//
// // archAvailableIEEE reports whether an architecture-specific CRC32-IEEE
// // algorithm is available.
// archAvailableIEEE() bool
//
// // archInitIEEE initializes the architecture-specific CRC32-IEEE algorithm.
// // It can only be called if archAvailableIEEE() returns true.
// archInitIEEE()
//
// // archUpdateIEEE updates the given CRC32-IEEE. It can only be called if
// // archInitIEEE() was previously called.
// archUpdateIEEE(crc uint32, p []byte) uint32
//
// // archAvailableCastagnoli reports whether an architecture-specific
// // CRC32-C algorithm is available.
// archAvailableCastagnoli() bool
//
// // archInitCastagnoli initializes the architecture-specific CRC32-C
// // algorithm. It can only be called if archAvailableCastagnoli() returns
// // true.
// archInitCastagnoli()
//
// // archUpdateCastagnoli updates the given CRC32-C. It can only be called
// // if archInitCastagnoli() was previously called.
// archUpdateCastagnoli(crc uint32, p []byte) uint32
// castagnoliTable points to a lazily initialized Table for the Castagnoli
// polynomial. MakeTable will always return this value when asked to make a
// Castagnoli table so we can compare against it to find when the caller is
// using this polynomial.
var castagnoliTable *Table
var castagnoliTable8 *slicing8Table
var updateCastagnoli func(crc uint32, p []byte) uint32
var haveCastagnoli atomic.Bool
var castagnoliInitOnce = sync.OnceFunc(func() {
castagnoliTable = simpleMakeTable(Castagnoli)
if archAvailableCastagnoli() {
archInitCastagnoli()
updateCastagnoli = archUpdateCastagnoli
} else {
// Initialize the slicing-by-8 table.
castagnoliTable8 = slicingMakeTable(Castagnoli)
updateCastagnoli = func(crc uint32, p []byte) uint32 {
return slicingUpdate(crc, castagnoliTable8, p)
}
}
haveCastagnoli.Store(true)
})
// IEEETable is the table for the [IEEE] polynomial.
var IEEETable = simpleMakeTable(IEEE)
// ieeeTable8 is the slicing8Table for IEEE
var ieeeTable8 *slicing8Table
var updateIEEE func(crc uint32, p []byte) uint32
var ieeeInitOnce = sync.OnceFunc(func() {
if archAvailableIEEE() {
archInitIEEE()
updateIEEE = archUpdateIEEE
} else {
// Initialize the slicing-by-8 table.
ieeeTable8 = slicingMakeTable(IEEE)
updateIEEE = func(crc uint32, p []byte) uint32 {
return slicingUpdate(crc, ieeeTable8, p)
}
}
})
// MakeTable returns a [Table] constructed from the specified polynomial.
// The contents of this [Table] must not be modified.
func MakeTable(poly uint32) *Table {
switch poly {
case IEEE:
ieeeInitOnce()
return IEEETable
case Castagnoli:
castagnoliInitOnce()
return castagnoliTable
default:
return simpleMakeTable(poly)
}
}
// digest represents the partial evaluation of a checksum.
type digest struct {
crc uint32
tab *Table
}
// New creates a new [hash.Hash32] computing the CRC-32 checksum using the
// polynomial represented by the [Table]. Its Sum method will lay the
// value out in big-endian byte order. The returned Hash32 also
// implements [encoding.BinaryMarshaler] and [encoding.BinaryUnmarshaler] to
// marshal and unmarshal the internal state of the hash.
func New(tab *Table) hash.Hash32 {
if tab == IEEETable {
ieeeInitOnce()
}
return &digest{0, tab}
}
// NewIEEE creates a new [hash.Hash32] computing the CRC-32 checksum using
// the [IEEE] polynomial. Its Sum method will lay the value out in
// big-endian byte order. The returned Hash32 also implements
// [encoding.BinaryMarshaler] and [encoding.BinaryUnmarshaler] to marshal
// and unmarshal the internal state of the hash.
func NewIEEE() hash.Hash32 { return New(IEEETable) }
func (d *digest) Size() int { return Size }
func (d *digest) BlockSize() int { return 1 }
func (d *digest) Reset() { d.crc = 0 }
const (
magic = "crc\x01"
marshaledSize = len(magic) + 4 + 4
)
func (d *digest) AppendBinary(b []byte) ([]byte, error) {
b = append(b, magic...)
b = byteorder.BEAppendUint32(b, tableSum(d.tab))
b = byteorder.BEAppendUint32(b, d.crc)
return b, nil
}
func (d *digest) MarshalBinary() ([]byte, error) {
return d.AppendBinary(make([]byte, 0, marshaledSize))
}
func (d *digest) UnmarshalBinary(b []byte) error {
if len(b) < len(magic) || string(b[:len(magic)]) != magic {
return errors.New("hash/crc32: invalid hash state identifier")
}
if len(b) != marshaledSize {
return errors.New("hash/crc32: invalid hash state size")
}
if tableSum(d.tab) != byteorder.BEUint32(b[4:]) {
return errors.New("hash/crc32: tables do not match")
}
d.crc = byteorder.BEUint32(b[8:])
return nil
}
func (d *digest) Clone() (hash.Cloner, error) {
r := *d
return &r, nil
}
func update(crc uint32, tab *Table, p []byte, checkInitIEEE bool) uint32 {
switch {
case haveCastagnoli.Load() && tab == castagnoliTable:
return updateCastagnoli(crc, p)
case tab == IEEETable:
if checkInitIEEE {
ieeeInitOnce()
}
return updateIEEE(crc, p)
default:
return simpleUpdate(crc, tab, p)
}
}
// Update returns the result of adding the bytes in p to the crc.
func Update(crc uint32, tab *Table, p []byte) uint32 {
// Unfortunately, because IEEETable is exported, IEEE may be used without a
// call to MakeTable. We have to make sure it gets initialized in that case.
return update(crc, tab, p, true)
}
func (d *digest) Write(p []byte) (n int, err error) {
// We only create digest objects through New() which takes care of
// initialization in this case.
d.crc = update(d.crc, d.tab, p, false)
return len(p), nil
}
func (d *digest) Sum32() uint32 { return d.crc }
func (d *digest) Sum(in []byte) []byte {
s := d.Sum32()
return append(in, byte(s>>24), byte(s>>16), byte(s>>8), byte(s))
}
// Checksum returns the CRC-32 checksum of data
// using the polynomial represented by the [Table].
func Checksum(data []byte, tab *Table) uint32 { return Update(0, tab, data) }
// ChecksumIEEE returns the CRC-32 checksum of data
// using the [IEEE] polynomial.
func ChecksumIEEE(data []byte) uint32 {
ieeeInitOnce()
return updateIEEE(0, data)
}
// tableSum returns the IEEE checksum of table t.
func tableSum(t *Table) uint32 {
var a [1024]byte
b := a[:0]
if t != nil {
for _, x := range t {
b = byteorder.BEAppendUint32(b, x)
}
}
return ChecksumIEEE(b)
}
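// Illustrative usage sketch (not part of the original source). Outside this
// package the calls below would be qualified with "crc32." after importing
// hash/crc32. The expected values for "123456789" are the standard published
// check values for CRC-32/IEEE and CRC-32C.
func exampleChecksums() {
	data := []byte("123456789")
	_ = ChecksumIEEE(data) // 0xCBF43926

	// Streaming use via hash.Hash32 yields the same value.
	h := NewIEEE()
	h.Write(data)
	_ = h.Sum32() // 0xCBF43926

	// A non-IEEE polynomial needs an explicit table.
	castagnoli := MakeTable(Castagnoli)
	_ = Checksum(data, castagnoli) // 0xE3069283
}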
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// AMD64-specific hardware-assisted CRC32 algorithms. See crc32.go for a
// description of the interface that each architecture-specific file
// implements.
package crc32
import (
"internal/cpu"
"unsafe"
)
// Offset into internal/cpu records for use in assembly.
const (
offsetX86HasAVX512VPCLMULQDQL = unsafe.Offsetof(cpu.X86.HasAVX512VPCLMULQDQ)
)
// This file contains the code to call the SSE 4.2 version of the Castagnoli
// and IEEE CRC.
// castagnoliSSE42 is defined in crc32_amd64.s and uses the SSE 4.2 CRC32
// instruction.
//
//go:noescape
func castagnoliSSE42(crc uint32, p []byte) uint32
// castagnoliSSE42Triple is defined in crc32_amd64.s and uses the SSE 4.2 CRC32
// instruction.
//
//go:noescape
func castagnoliSSE42Triple(
crcA, crcB, crcC uint32,
a, b, c []byte,
rounds uint32,
) (retA uint32, retB uint32, retC uint32)
// ieeeCLMUL is defined in crc32_amd64.s and uses the PCLMULQDQ
// instruction as well as SSE 4.1.
//
//go:noescape
func ieeeCLMUL(crc uint32, p []byte) uint32
const castagnoliK1 = 168
const castagnoliK2 = 1344
type sse42Table [4]Table
var castagnoliSSE42TableK1 *sse42Table
var castagnoliSSE42TableK2 *sse42Table
func archAvailableCastagnoli() bool {
return cpu.X86.HasSSE42
}
func archInitCastagnoli() {
if !cpu.X86.HasSSE42 {
panic("arch-specific Castagnoli not available")
}
castagnoliSSE42TableK1 = new(sse42Table)
castagnoliSSE42TableK2 = new(sse42Table)
// See description in updateCastagnoli.
// t[0][i] = CRC(i000, O)
// t[1][i] = CRC(0i00, O)
// t[2][i] = CRC(00i0, O)
// t[3][i] = CRC(000i, O)
// where O is a sequence of K zeros.
var tmp [castagnoliK2]byte
for b := 0; b < 4; b++ {
for i := 0; i < 256; i++ {
val := uint32(i) << uint32(b*8)
castagnoliSSE42TableK1[b][i] = castagnoliSSE42(val, tmp[:castagnoliK1])
castagnoliSSE42TableK2[b][i] = castagnoliSSE42(val, tmp[:])
}
}
}
// castagnoliShift computes the CRC32-C of K1 or K2 zeroes (depending on the
// table given) with the given initial crc value. This corresponds to
// CRC(crc, O) in the description in updateCastagnoli.
func castagnoliShift(table *sse42Table, crc uint32) uint32 {
return table[3][crc>>24] ^
table[2][(crc>>16)&0xFF] ^
table[1][(crc>>8)&0xFF] ^
table[0][crc&0xFF]
}
func archUpdateCastagnoli(crc uint32, p []byte) uint32 {
if !cpu.X86.HasSSE42 {
panic("not available")
}
// This method is inspired from the algorithm in Intel's white paper:
// "Fast CRC Computation for iSCSI Polynomial Using CRC32 Instruction"
// The same strategy of splitting the buffer in three is used but the
// combining calculation is different; the complete derivation is explained
// below.
//
// -- The basic idea --
//
// The CRC32 instruction (available in SSE4.2) can process 8 bytes at a
// time. In recent Intel architectures the instruction takes 3 cycles;
// however the processor can pipeline up to three instructions if they
// don't depend on each other.
//
// Roughly this means that we can process three buffers in about the same
// time we can process one buffer.
//
// The idea is then to split the buffer in three, CRC the three pieces
// separately and then combine the results.
//
// Combining the results requires precomputed tables, so we must choose a
// fixed buffer length to optimize. The longer the length, the faster; but
// only buffers longer than this length will use the optimization. We choose
// two cutoffs and compute tables for both:
// - one around 512: 168*3=504
// - one around 4KB: 1344*3=4032
//
// -- The nitty gritty --
//
// Let CRC(I, X) be the non-inverted CRC32-C of the sequence X (with
// initial non-inverted CRC I). This function has the following properties:
// (a) CRC(I, AB) = CRC(CRC(I, A), B)
// (b) CRC(I, A xor B) = CRC(I, A) xor CRC(0, B)
//
// Say we want to compute CRC(I, ABC) where A, B, C are three sequences of
// K bytes each, where K is a fixed constant. Let O be the sequence of K zero
// bytes.
//
// CRC(I, ABC) = CRC(I, ABO xor C)
// = CRC(I, ABO) xor CRC(0, C)
// = CRC(CRC(I, AB), O) xor CRC(0, C)
// = CRC(CRC(I, AO xor B), O) xor CRC(0, C)
// = CRC(CRC(I, AO) xor CRC(0, B), O) xor CRC(0, C)
// = CRC(CRC(CRC(I, A), O) xor CRC(0, B), O) xor CRC(0, C)
//
// The castagnoliSSE42Triple function can compute CRC(I, A), CRC(0, B),
// and CRC(0, C) efficiently. We just need to find a way to quickly compute
// CRC(uvwx, O) given a 4-byte initial value uvwx. We can precompute these
// values; since we can't have a 32-bit table, we break it up into four
// 8-bit tables:
//
// CRC(uvwx, O) = CRC(u000, O) xor
// CRC(0v00, O) xor
// CRC(00w0, O) xor
// CRC(000x, O)
//
// We can compute tables corresponding to the four terms for all 8-bit
// values.
crc = ^crc
// If a buffer is long enough to use the optimization, process the first few
// bytes to align the buffer to an 8 byte boundary (if necessary).
if len(p) >= castagnoliK1*3 {
delta := int(uintptr(unsafe.Pointer(&p[0])) & 7)
if delta != 0 {
delta = 8 - delta
crc = castagnoliSSE42(crc, p[:delta])
p = p[delta:]
}
}
// Process 3*K2 at a time.
for len(p) >= castagnoliK2*3 {
// Compute CRC(I, A), CRC(0, B), and CRC(0, C).
crcA, crcB, crcC := castagnoliSSE42Triple(
crc, 0, 0,
p, p[castagnoliK2:], p[castagnoliK2*2:],
castagnoliK2/24)
// CRC(I, AB) = CRC(CRC(I, A), O) xor CRC(0, B)
crcAB := castagnoliShift(castagnoliSSE42TableK2, crcA) ^ crcB
// CRC(I, ABC) = CRC(CRC(I, AB), O) xor CRC(0, C)
crc = castagnoliShift(castagnoliSSE42TableK2, crcAB) ^ crcC
p = p[castagnoliK2*3:]
}
// Process 3*K1 at a time.
for len(p) >= castagnoliK1*3 {
// Compute CRC(I, A), CRC(0, B), and CRC(0, C).
crcA, crcB, crcC := castagnoliSSE42Triple(
crc, 0, 0,
p, p[castagnoliK1:], p[castagnoliK1*2:],
castagnoliK1/24)
// CRC(I, AB) = CRC(CRC(I, A), O) xor CRC(0, B)
crcAB := castagnoliShift(castagnoliSSE42TableK1, crcA) ^ crcB
// CRC(I, ABC) = CRC(CRC(I, AB), O) xor CRC(0, C)
crc = castagnoliShift(castagnoliSSE42TableK1, crcAB) ^ crcC
p = p[castagnoliK1*3:]
}
// Use the simple implementation for what's left.
crc = castagnoliSSE42(crc, p)
return ^crc
}
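// Illustrative sketch (not part of the original source): the combining
// identity derived above, checked with the portable table-driven update.
// rawCRC and verifyCombine are hypothetical helpers, not part of this
// package's API; rawCRC is the non-inverted CRC(I, X) of the derivation.
func rawCRC(crc uint32, tab *Table, p []byte) uint32 {
	for _, v := range p {
		crc = tab[byte(crc)^v] ^ (crc >> 8)
	}
	return crc
}

// verifyCombine reports whether
//
//	CRC(I, ABC) == CRC(CRC(CRC(I, A), O) xor CRC(0, B), O) xor CRC(0, C)
//
// holds for sample equal-length blocks A, B, C and a zero block O.
func verifyCombine() bool {
	const k = 64
	tab := simpleMakeTable(Castagnoli)

	a, b, c := make([]byte, k), make([]byte, k), make([]byte, k)
	for i := 0; i < k; i++ {
		a[i], b[i], c[i] = byte(i), byte(3*i+1), byte(7*i+5)
	}
	zero := make([]byte, k)

	abc := make([]byte, 0, 3*k)
	abc = append(abc, a...)
	abc = append(abc, b...)
	abc = append(abc, c...)

	const initial = 0x12345678
	whole := rawCRC(initial, tab, abc)

	crcA := rawCRC(initial, tab, a)
	crcB := rawCRC(0, tab, b)
	crcC := rawCRC(0, tab, c)
	crcAB := rawCRC(crcA, tab, zero) ^ crcB
	combined := rawCRC(crcAB, tab, zero) ^ crcC
	return whole == combined
}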
func archAvailableIEEE() bool {
return cpu.X86.HasPCLMULQDQ && cpu.X86.HasSSE41
}
var archIeeeTable8 *slicing8Table
func archInitIEEE() {
if !cpu.X86.HasPCLMULQDQ || !cpu.X86.HasSSE41 {
panic("not available")
}
// We still use slicing-by-8 for small buffers.
archIeeeTable8 = slicingMakeTable(IEEE)
}
func archUpdateIEEE(crc uint32, p []byte) uint32 {
if !cpu.X86.HasPCLMULQDQ || !cpu.X86.HasSSE41 {
panic("not available")
}
if len(p) >= 64 {
left := len(p) & 15
do := len(p) - left
crc = ^ieeeCLMUL(^crc, p[:do])
p = p[do:]
}
if len(p) == 0 {
return crc
}
return slicingUpdate(crc, archIeeeTable8, p)
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file contains CRC32 algorithms that are not specific to any architecture
// and don't use hardware acceleration.
//
// The simple (and slow) CRC32 implementation only uses a 256*4 bytes table.
//
// The slicing-by-8 algorithm is a faster implementation that uses a bigger
// table (8*256*4 bytes).
package crc32
import "internal/byteorder"
// simpleMakeTable allocates and constructs a Table for the specified
// polynomial. The table is suitable for use with the simple algorithm
// (simpleUpdate).
func simpleMakeTable(poly uint32) *Table {
t := new(Table)
simplePopulateTable(poly, t)
return t
}
// simplePopulateTable constructs a Table for the specified polynomial, suitable
// for use with simpleUpdate.
func simplePopulateTable(poly uint32, t *Table) {
for i := 0; i < 256; i++ {
crc := uint32(i)
for j := 0; j < 8; j++ {
if crc&1 == 1 {
crc = (crc >> 1) ^ poly
} else {
crc >>= 1
}
}
t[i] = crc
}
}
// simpleUpdate uses the simple algorithm to update the CRC, given a table that
// was previously computed using simpleMakeTable.
func simpleUpdate(crc uint32, tab *Table, p []byte) uint32 {
crc = ^crc
for _, v := range p {
crc = tab[byte(crc)^v] ^ (crc >> 8)
}
return ^crc
}
// Use slicing-by-8 when payload >= this value.
const slicing8Cutoff = 16
// slicing8Table is an array of 8 Tables, used by the slicing-by-8 algorithm.
type slicing8Table [8]Table
// slicingMakeTable constructs a slicing8Table for the specified polynomial. The
// table is suitable for use with the slicing-by-8 algorithm (slicingUpdate).
func slicingMakeTable(poly uint32) *slicing8Table {
t := new(slicing8Table)
simplePopulateTable(poly, &t[0])
for i := 0; i < 256; i++ {
crc := t[0][i]
for j := 1; j < 8; j++ {
crc = t[0][crc&0xFF] ^ (crc >> 8)
t[j][i] = crc
}
}
return t
}
// slicingUpdate uses the slicing-by-8 algorithm to update the CRC, given a
// table that was previously computed using slicingMakeTable.
func slicingUpdate(crc uint32, tab *slicing8Table, p []byte) uint32 {
if len(p) >= slicing8Cutoff {
crc = ^crc
for len(p) > 8 {
crc ^= byteorder.LEUint32(p)
crc = tab[0][p[7]] ^ tab[1][p[6]] ^ tab[2][p[5]] ^ tab[3][p[4]] ^
tab[4][crc>>24] ^ tab[5][(crc>>16)&0xFF] ^
tab[6][(crc>>8)&0xFF] ^ tab[7][crc&0xFF]
p = p[8:]
}
crc = ^crc
}
if len(p) == 0 {
return crc
}
return simpleUpdate(crc, &tab[0], p)
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package crc64 implements the 64-bit cyclic redundancy check, or CRC-64,
// checksum. See https://en.wikipedia.org/wiki/Cyclic_redundancy_check for
// information.
package crc64
import (
"errors"
"hash"
"internal/byteorder"
"sync"
)
// The size of a CRC-64 checksum in bytes.
const Size = 8
// Predefined polynomials.
const (
// The ISO polynomial, defined in ISO 3309 and used in HDLC.
ISO = 0xD800000000000000
// The ECMA polynomial, defined in ECMA 182.
ECMA = 0xC96C5795D7870F42
)
// Table is a 256-word table representing the polynomial for efficient processing.
type Table [256]uint64
var (
slicing8TableISO *[8]Table
slicing8TableECMA *[8]Table
)
var buildSlicing8TablesOnce = sync.OnceFunc(buildSlicing8Tables)
func buildSlicing8Tables() {
slicing8TableISO = makeSlicingBy8Table(makeTable(ISO))
slicing8TableECMA = makeSlicingBy8Table(makeTable(ECMA))
}
// MakeTable returns a [Table] constructed from the specified polynomial.
// The contents of this [Table] must not be modified.
func MakeTable(poly uint64) *Table {
buildSlicing8TablesOnce()
switch poly {
case ISO:
return &slicing8TableISO[0]
case ECMA:
return &slicing8TableECMA[0]
default:
return makeTable(poly)
}
}
func makeTable(poly uint64) *Table {
t := new(Table)
for i := 0; i < 256; i++ {
crc := uint64(i)
for j := 0; j < 8; j++ {
if crc&1 == 1 {
crc = (crc >> 1) ^ poly
} else {
crc >>= 1
}
}
t[i] = crc
}
return t
}
func makeSlicingBy8Table(t *Table) *[8]Table {
var helperTable [8]Table
helperTable[0] = *t
for i := 0; i < 256; i++ {
crc := t[i]
for j := 1; j < 8; j++ {
crc = t[crc&0xff] ^ (crc >> 8)
helperTable[j][i] = crc
}
}
return &helperTable
}
// digest represents the partial evaluation of a checksum.
type digest struct {
crc uint64
tab *Table
}
// New creates a new hash.Hash64 computing the CRC-64 checksum using the
// polynomial represented by the [Table]. Its Sum method will lay the
// value out in big-endian byte order. The returned Hash64 also
// implements [encoding.BinaryMarshaler] and [encoding.BinaryUnmarshaler] to
// marshal and unmarshal the internal state of the hash.
func New(tab *Table) hash.Hash64 { return &digest{0, tab} }
func (d *digest) Size() int { return Size }
func (d *digest) BlockSize() int { return 1 }
func (d *digest) Reset() { d.crc = 0 }
const (
magic = "crc\x02"
marshaledSize = len(magic) + 8 + 8
)
func (d *digest) AppendBinary(b []byte) ([]byte, error) {
b = append(b, magic...)
b = byteorder.BEAppendUint64(b, tableSum(d.tab))
b = byteorder.BEAppendUint64(b, d.crc)
return b, nil
}
func (d *digest) MarshalBinary() ([]byte, error) {
return d.AppendBinary(make([]byte, 0, marshaledSize))
}
func (d *digest) UnmarshalBinary(b []byte) error {
if len(b) < len(magic) || string(b[:len(magic)]) != magic {
return errors.New("hash/crc64: invalid hash state identifier")
}
if len(b) != marshaledSize {
return errors.New("hash/crc64: invalid hash state size")
}
if tableSum(d.tab) != byteorder.BEUint64(b[4:]) {
return errors.New("hash/crc64: tables do not match")
}
d.crc = byteorder.BEUint64(b[12:])
return nil
}
func (d *digest) Clone() (hash.Cloner, error) {
r := *d
return &r, nil
}
func update(crc uint64, tab *Table, p []byte) uint64 {
buildSlicing8TablesOnce()
crc = ^crc
// Table comparison is somewhat expensive, so avoid it for small sizes
for len(p) >= 64 {
var helperTable *[8]Table
if *tab == slicing8TableECMA[0] {
helperTable = slicing8TableECMA
} else if *tab == slicing8TableISO[0] {
helperTable = slicing8TableISO
// For smaller sizes, creating an extended table takes too much time
} else if len(p) >= 2048 {
// According to tests on various x86 and ARM CPUs, 2k is a reasonable
// threshold for now. This may change in the future.
helperTable = makeSlicingBy8Table(tab)
} else {
break
}
// Update using slicing-by-8
for len(p) > 8 {
crc ^= byteorder.LEUint64(p)
crc = helperTable[7][crc&0xff] ^
helperTable[6][(crc>>8)&0xff] ^
helperTable[5][(crc>>16)&0xff] ^
helperTable[4][(crc>>24)&0xff] ^
helperTable[3][(crc>>32)&0xff] ^
helperTable[2][(crc>>40)&0xff] ^
helperTable[1][(crc>>48)&0xff] ^
helperTable[0][crc>>56]
p = p[8:]
}
}
// For remainders or small sizes
for _, v := range p {
crc = tab[byte(crc)^v] ^ (crc >> 8)
}
return ^crc
}
// Update returns the result of adding the bytes in p to the crc.
func Update(crc uint64, tab *Table, p []byte) uint64 {
return update(crc, tab, p)
}
func (d *digest) Write(p []byte) (n int, err error) {
d.crc = update(d.crc, d.tab, p)
return len(p), nil
}
func (d *digest) Sum64() uint64 { return d.crc }
func (d *digest) Sum(in []byte) []byte {
s := d.Sum64()
return append(in, byte(s>>56), byte(s>>48), byte(s>>40), byte(s>>32), byte(s>>24), byte(s>>16), byte(s>>8), byte(s))
}
// Checksum returns the CRC-64 checksum of data
// using the polynomial represented by the [Table].
func Checksum(data []byte, tab *Table) uint64 { return update(0, tab, data) }
// tableSum returns the ISO checksum of table t.
func tableSum(t *Table) uint64 {
var a [2048]byte
b := a[:0]
if t != nil {
for _, x := range t {
b = byteorder.BEAppendUint64(b, x)
}
}
return Checksum(b, MakeTable(ISO))
}
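// Illustrative usage sketch (not part of the original source). Outside this
// package the calls below would be qualified with "crc64." after importing
// hash/crc64. 0x995DC9BBDF1939FA is the widely published CRC-64/XZ check
// value for "123456789", which this ECMA configuration should reproduce.
func exampleCRC64() {
	data := []byte("123456789")
	ecma := MakeTable(ECMA)
	_ = Checksum(data, ecma) // 0x995DC9BBDF1939FA

	// Streaming use via hash.Hash64 yields the same value.
	h := New(ecma)
	h.Write(data)
	_ = h.Sum64()
}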
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package html
import "sync"
// All entities that do not end with ';' are 6 or fewer bytes long.
const longestEntityWithoutSemicolon = 6
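// Illustrative sketch (not part of the original source): entityMaps
// (declared below) is built lazily on first use, and lookups include any
// trailing semicolon in the name, per the doc comment that follows.
func exampleEntityLookup() rune {
	entity, _ := entityMaps()
	return entity["amp;"] // '&'
}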
// entityMaps returns entity and entity2.
//
// entity is a map from HTML entity names to their values. The semicolon matters:
// https://html.spec.whatwg.org/multipage/named-characters.html
// lists both "amp" and "amp;" as two separate entries.
// Note that the HTML5 list is larger than the HTML4 list at
// http://www.w3.org/TR/html4/sgml/entities.html
//
// entity2 is a map of HTML entities to two Unicode code points.
var entityMaps = sync.OnceValues(func() (entity map[string]rune, entity2 map[string][2]rune) {
entity = map[string]rune{
"AElig;": '\U000000C6',
"AMP;": '\U00000026',
"Aacute;": '\U000000C1',
"Abreve;": '\U00000102',
"Acirc;": '\U000000C2',
"Acy;": '\U00000410',
"Afr;": '\U0001D504',
"Agrave;": '\U000000C0',
"Alpha;": '\U00000391',
"Amacr;": '\U00000100',
"And;": '\U00002A53',
"Aogon;": '\U00000104',
"Aopf;": '\U0001D538',
"ApplyFunction;": '\U00002061',
"Aring;": '\U000000C5',
"Ascr;": '\U0001D49C',
"Assign;": '\U00002254',
"Atilde;": '\U000000C3',
"Auml;": '\U000000C4',
"Backslash;": '\U00002216',
"Barv;": '\U00002AE7',
"Barwed;": '\U00002306',
"Bcy;": '\U00000411',
"Because;": '\U00002235',
"Bernoullis;": '\U0000212C',
"Beta;": '\U00000392',
"Bfr;": '\U0001D505',
"Bopf;": '\U0001D539',
"Breve;": '\U000002D8',
"Bscr;": '\U0000212C',
"Bumpeq;": '\U0000224E',
"CHcy;": '\U00000427',
"COPY;": '\U000000A9',
"Cacute;": '\U00000106',
"Cap;": '\U000022D2',
"CapitalDifferentialD;": '\U00002145',
"Cayleys;": '\U0000212D',
"Ccaron;": '\U0000010C',
"Ccedil;": '\U000000C7',
"Ccirc;": '\U00000108',
"Cconint;": '\U00002230',
"Cdot;": '\U0000010A',
"Cedilla;": '\U000000B8',
"CenterDot;": '\U000000B7',
"Cfr;": '\U0000212D',
"Chi;": '\U000003A7',
"CircleDot;": '\U00002299',
"CircleMinus;": '\U00002296',
"CirclePlus;": '\U00002295',
"CircleTimes;": '\U00002297',
"ClockwiseContourIntegral;": '\U00002232',
"CloseCurlyDoubleQuote;": '\U0000201D',
"CloseCurlyQuote;": '\U00002019',
"Colon;": '\U00002237',
"Colone;": '\U00002A74',
"Congruent;": '\U00002261',
"Conint;": '\U0000222F',
"ContourIntegral;": '\U0000222E',
"Copf;": '\U00002102',
"Coproduct;": '\U00002210',
"CounterClockwiseContourIntegral;": '\U00002233',
"Cross;": '\U00002A2F',
"Cscr;": '\U0001D49E',
"Cup;": '\U000022D3',
"CupCap;": '\U0000224D',
"DD;": '\U00002145',
"DDotrahd;": '\U00002911',
"DJcy;": '\U00000402',
"DScy;": '\U00000405',
"DZcy;": '\U0000040F',
"Dagger;": '\U00002021',
"Darr;": '\U000021A1',
"Dashv;": '\U00002AE4',
"Dcaron;": '\U0000010E',
"Dcy;": '\U00000414',
"Del;": '\U00002207',
"Delta;": '\U00000394',
"Dfr;": '\U0001D507',
"DiacriticalAcute;": '\U000000B4',
"DiacriticalDot;": '\U000002D9',
"DiacriticalDoubleAcute;": '\U000002DD',
"DiacriticalGrave;": '\U00000060',
"DiacriticalTilde;": '\U000002DC',
"Diamond;": '\U000022C4',
"DifferentialD;": '\U00002146',
"Dopf;": '\U0001D53B',
"Dot;": '\U000000A8',
"DotDot;": '\U000020DC',
"DotEqual;": '\U00002250',
"DoubleContourIntegral;": '\U0000222F',
"DoubleDot;": '\U000000A8',
"DoubleDownArrow;": '\U000021D3',
"DoubleLeftArrow;": '\U000021D0',
"DoubleLeftRightArrow;": '\U000021D4',
"DoubleLeftTee;": '\U00002AE4',
"DoubleLongLeftArrow;": '\U000027F8',
"DoubleLongLeftRightArrow;": '\U000027FA',
"DoubleLongRightArrow;": '\U000027F9',
"DoubleRightArrow;": '\U000021D2',
"DoubleRightTee;": '\U000022A8',
"DoubleUpArrow;": '\U000021D1',
"DoubleUpDownArrow;": '\U000021D5',
"DoubleVerticalBar;": '\U00002225',
"DownArrow;": '\U00002193',
"DownArrowBar;": '\U00002913',
"DownArrowUpArrow;": '\U000021F5',
"DownBreve;": '\U00000311',
"DownLeftRightVector;": '\U00002950',
"DownLeftTeeVector;": '\U0000295E',
"DownLeftVector;": '\U000021BD',
"DownLeftVectorBar;": '\U00002956',
"DownRightTeeVector;": '\U0000295F',
"DownRightVector;": '\U000021C1',
"DownRightVectorBar;": '\U00002957',
"DownTee;": '\U000022A4',
"DownTeeArrow;": '\U000021A7',
"Downarrow;": '\U000021D3',
"Dscr;": '\U0001D49F',
"Dstrok;": '\U00000110',
"ENG;": '\U0000014A',
"ETH;": '\U000000D0',
"Eacute;": '\U000000C9',
"Ecaron;": '\U0000011A',
"Ecirc;": '\U000000CA',
"Ecy;": '\U0000042D',
"Edot;": '\U00000116',
"Efr;": '\U0001D508',
"Egrave;": '\U000000C8',
"Element;": '\U00002208',
"Emacr;": '\U00000112',
"EmptySmallSquare;": '\U000025FB',
"EmptyVerySmallSquare;": '\U000025AB',
"Eogon;": '\U00000118',
"Eopf;": '\U0001D53C',
"Epsilon;": '\U00000395',
"Equal;": '\U00002A75',
"EqualTilde;": '\U00002242',
"Equilibrium;": '\U000021CC',
"Escr;": '\U00002130',
"Esim;": '\U00002A73',
"Eta;": '\U00000397',
"Euml;": '\U000000CB',
"Exists;": '\U00002203',
"ExponentialE;": '\U00002147',
"Fcy;": '\U00000424',
"Ffr;": '\U0001D509',
"FilledSmallSquare;": '\U000025FC',
"FilledVerySmallSquare;": '\U000025AA',
"Fopf;": '\U0001D53D',
"ForAll;": '\U00002200',
"Fouriertrf;": '\U00002131',
"Fscr;": '\U00002131',
"GJcy;": '\U00000403',
"GT;": '\U0000003E',
"Gamma;": '\U00000393',
"Gammad;": '\U000003DC',
"Gbreve;": '\U0000011E',
"Gcedil;": '\U00000122',
"Gcirc;": '\U0000011C',
"Gcy;": '\U00000413',
"Gdot;": '\U00000120',
"Gfr;": '\U0001D50A',
"Gg;": '\U000022D9',
"Gopf;": '\U0001D53E',
"GreaterEqual;": '\U00002265',
"GreaterEqualLess;": '\U000022DB',
"GreaterFullEqual;": '\U00002267',
"GreaterGreater;": '\U00002AA2',
"GreaterLess;": '\U00002277',
"GreaterSlantEqual;": '\U00002A7E',
"GreaterTilde;": '\U00002273',
"Gscr;": '\U0001D4A2',
"Gt;": '\U0000226B',
"HARDcy;": '\U0000042A',
"Hacek;": '\U000002C7',
"Hat;": '\U0000005E',
"Hcirc;": '\U00000124',
"Hfr;": '\U0000210C',
"HilbertSpace;": '\U0000210B',
"Hopf;": '\U0000210D',
"HorizontalLine;": '\U00002500',
"Hscr;": '\U0000210B',
"Hstrok;": '\U00000126',
"HumpDownHump;": '\U0000224E',
"HumpEqual;": '\U0000224F',
"IEcy;": '\U00000415',
"IJlig;": '\U00000132',
"IOcy;": '\U00000401',
"Iacute;": '\U000000CD',
"Icirc;": '\U000000CE',
"Icy;": '\U00000418',
"Idot;": '\U00000130',
"Ifr;": '\U00002111',
"Igrave;": '\U000000CC',
"Im;": '\U00002111',
"Imacr;": '\U0000012A',
"ImaginaryI;": '\U00002148',
"Implies;": '\U000021D2',
"Int;": '\U0000222C',
"Integral;": '\U0000222B',
"Intersection;": '\U000022C2',
"InvisibleComma;": '\U00002063',
"InvisibleTimes;": '\U00002062',
"Iogon;": '\U0000012E',
"Iopf;": '\U0001D540',
"Iota;": '\U00000399',
"Iscr;": '\U00002110',
"Itilde;": '\U00000128',
"Iukcy;": '\U00000406',
"Iuml;": '\U000000CF',
"Jcirc;": '\U00000134',
"Jcy;": '\U00000419',
"Jfr;": '\U0001D50D',
"Jopf;": '\U0001D541',
"Jscr;": '\U0001D4A5',
"Jsercy;": '\U00000408',
"Jukcy;": '\U00000404',
"KHcy;": '\U00000425',
"KJcy;": '\U0000040C',
"Kappa;": '\U0000039A',
"Kcedil;": '\U00000136',
"Kcy;": '\U0000041A',
"Kfr;": '\U0001D50E',
"Kopf;": '\U0001D542',
"Kscr;": '\U0001D4A6',
"LJcy;": '\U00000409',
"LT;": '\U0000003C',
"Lacute;": '\U00000139',
"Lambda;": '\U0000039B',
"Lang;": '\U000027EA',
"Laplacetrf;": '\U00002112',
"Larr;": '\U0000219E',
"Lcaron;": '\U0000013D',
"Lcedil;": '\U0000013B',
"Lcy;": '\U0000041B',
"LeftAngleBracket;": '\U000027E8',
"LeftArrow;": '\U00002190',
"LeftArrowBar;": '\U000021E4',
"LeftArrowRightArrow;": '\U000021C6',
"LeftCeiling;": '\U00002308',
"LeftDoubleBracket;": '\U000027E6',
"LeftDownTeeVector;": '\U00002961',
"LeftDownVector;": '\U000021C3',
"LeftDownVectorBar;": '\U00002959',
"LeftFloor;": '\U0000230A',
"LeftRightArrow;": '\U00002194',
"LeftRightVector;": '\U0000294E',
"LeftTee;": '\U000022A3',
"LeftTeeArrow;": '\U000021A4',
"LeftTeeVector;": '\U0000295A',
"LeftTriangle;": '\U000022B2',
"LeftTriangleBar;": '\U000029CF',
"LeftTriangleEqual;": '\U000022B4',
"LeftUpDownVector;": '\U00002951',
"LeftUpTeeVector;": '\U00002960',
"LeftUpVector;": '\U000021BF',
"LeftUpVectorBar;": '\U00002958',
"LeftVector;": '\U000021BC',
"LeftVectorBar;": '\U00002952',
"Leftarrow;": '\U000021D0',
"Leftrightarrow;": '\U000021D4',
"LessEqualGreater;": '\U000022DA',
"LessFullEqual;": '\U00002266',
"LessGreater;": '\U00002276',
"LessLess;": '\U00002AA1',
"LessSlantEqual;": '\U00002A7D',
"LessTilde;": '\U00002272',
"Lfr;": '\U0001D50F',
"Ll;": '\U000022D8',
"Lleftarrow;": '\U000021DA',
"Lmidot;": '\U0000013F',
"LongLeftArrow;": '\U000027F5',
"LongLeftRightArrow;": '\U000027F7',
"LongRightArrow;": '\U000027F6',
"Longleftarrow;": '\U000027F8',
"Longleftrightarrow;": '\U000027FA',
"Longrightarrow;": '\U000027F9',
"Lopf;": '\U0001D543',
"LowerLeftArrow;": '\U00002199',
"LowerRightArrow;": '\U00002198',
"Lscr;": '\U00002112',
"Lsh;": '\U000021B0',
"Lstrok;": '\U00000141',
"Lt;": '\U0000226A',
"Map;": '\U00002905',
"Mcy;": '\U0000041C',
"MediumSpace;": '\U0000205F',
"Mellintrf;": '\U00002133',
"Mfr;": '\U0001D510',
"MinusPlus;": '\U00002213',
"Mopf;": '\U0001D544',
"Mscr;": '\U00002133',
"Mu;": '\U0000039C',
"NJcy;": '\U0000040A',
"Nacute;": '\U00000143',
"Ncaron;": '\U00000147',
"Ncedil;": '\U00000145',
"Ncy;": '\U0000041D',
"NegativeMediumSpace;": '\U0000200B',
"NegativeThickSpace;": '\U0000200B',
"NegativeThinSpace;": '\U0000200B',
"NegativeVeryThinSpace;": '\U0000200B',
"NestedGreaterGreater;": '\U0000226B',
"NestedLessLess;": '\U0000226A',
"NewLine;": '\U0000000A',
"Nfr;": '\U0001D511',
"NoBreak;": '\U00002060',
"NonBreakingSpace;": '\U000000A0',
"Nopf;": '\U00002115',
"Not;": '\U00002AEC',
"NotCongruent;": '\U00002262',
"NotCupCap;": '\U0000226D',
"NotDoubleVerticalBar;": '\U00002226',
"NotElement;": '\U00002209',
"NotEqual;": '\U00002260',
"NotExists;": '\U00002204',
"NotGreater;": '\U0000226F',
"NotGreaterEqual;": '\U00002271',
"NotGreaterLess;": '\U00002279',
"NotGreaterTilde;": '\U00002275',
"NotLeftTriangle;": '\U000022EA',
"NotLeftTriangleEqual;": '\U000022EC',
"NotLess;": '\U0000226E',
"NotLessEqual;": '\U00002270',
"NotLessGreater;": '\U00002278',
"NotLessTilde;": '\U00002274',
"NotPrecedes;": '\U00002280',
"NotPrecedesSlantEqual;": '\U000022E0',
"NotReverseElement;": '\U0000220C',
"NotRightTriangle;": '\U000022EB',
"NotRightTriangleEqual;": '\U000022ED',
"NotSquareSubsetEqual;": '\U000022E2',
"NotSquareSupersetEqual;": '\U000022E3',
"NotSubsetEqual;": '\U00002288',
"NotSucceeds;": '\U00002281',
"NotSucceedsSlantEqual;": '\U000022E1',
"NotSupersetEqual;": '\U00002289',
"NotTilde;": '\U00002241',
"NotTildeEqual;": '\U00002244',
"NotTildeFullEqual;": '\U00002247',
"NotTildeTilde;": '\U00002249',
"NotVerticalBar;": '\U00002224',
"Nscr;": '\U0001D4A9',
"Ntilde;": '\U000000D1',
"Nu;": '\U0000039D',
"OElig;": '\U00000152',
"Oacute;": '\U000000D3',
"Ocirc;": '\U000000D4',
"Ocy;": '\U0000041E',
"Odblac;": '\U00000150',
"Ofr;": '\U0001D512',
"Ograve;": '\U000000D2',
"Omacr;": '\U0000014C',
"Omega;": '\U000003A9',
"Omicron;": '\U0000039F',
"Oopf;": '\U0001D546',
"OpenCurlyDoubleQuote;": '\U0000201C',
"OpenCurlyQuote;": '\U00002018',
"Or;": '\U00002A54',
"Oscr;": '\U0001D4AA',
"Oslash;": '\U000000D8',
"Otilde;": '\U000000D5',
"Otimes;": '\U00002A37',
"Ouml;": '\U000000D6',
"OverBar;": '\U0000203E',
"OverBrace;": '\U000023DE',
"OverBracket;": '\U000023B4',
"OverParenthesis;": '\U000023DC',
"PartialD;": '\U00002202',
"Pcy;": '\U0000041F',
"Pfr;": '\U0001D513',
"Phi;": '\U000003A6',
"Pi;": '\U000003A0',
"PlusMinus;": '\U000000B1',
"Poincareplane;": '\U0000210C',
"Popf;": '\U00002119',
"Pr;": '\U00002ABB',
"Precedes;": '\U0000227A',
"PrecedesEqual;": '\U00002AAF',
"PrecedesSlantEqual;": '\U0000227C',
"PrecedesTilde;": '\U0000227E',
"Prime;": '\U00002033',
"Product;": '\U0000220F',
"Proportion;": '\U00002237',
"Proportional;": '\U0000221D',
"Pscr;": '\U0001D4AB',
"Psi;": '\U000003A8',
"QUOT;": '\U00000022',
"Qfr;": '\U0001D514',
"Qopf;": '\U0000211A',
"Qscr;": '\U0001D4AC',
"RBarr;": '\U00002910',
"REG;": '\U000000AE',
"Racute;": '\U00000154',
"Rang;": '\U000027EB',
"Rarr;": '\U000021A0',
"Rarrtl;": '\U00002916',
"Rcaron;": '\U00000158',
"Rcedil;": '\U00000156',
"Rcy;": '\U00000420',
"Re;": '\U0000211C',
"ReverseElement;": '\U0000220B',
"ReverseEquilibrium;": '\U000021CB',
"ReverseUpEquilibrium;": '\U0000296F',
"Rfr;": '\U0000211C',
"Rho;": '\U000003A1',
"RightAngleBracket;": '\U000027E9',
"RightArrow;": '\U00002192',
"RightArrowBar;": '\U000021E5',
"RightArrowLeftArrow;": '\U000021C4',
"RightCeiling;": '\U00002309',
"RightDoubleBracket;": '\U000027E7',
"RightDownTeeVector;": '\U0000295D',
"RightDownVector;": '\U000021C2',
"RightDownVectorBar;": '\U00002955',
"RightFloor;": '\U0000230B',
"RightTee;": '\U000022A2',
"RightTeeArrow;": '\U000021A6',
"RightTeeVector;": '\U0000295B',
"RightTriangle;": '\U000022B3',
"RightTriangleBar;": '\U000029D0',
"RightTriangleEqual;": '\U000022B5',
"RightUpDownVector;": '\U0000294F',
"RightUpTeeVector;": '\U0000295C',
"RightUpVector;": '\U000021BE',
"RightUpVectorBar;": '\U00002954',
"RightVector;": '\U000021C0',
"RightVectorBar;": '\U00002953',
"Rightarrow;": '\U000021D2',
"Ropf;": '\U0000211D',
"RoundImplies;": '\U00002970',
"Rrightarrow;": '\U000021DB',
"Rscr;": '\U0000211B',
"Rsh;": '\U000021B1',
"RuleDelayed;": '\U000029F4',
"SHCHcy;": '\U00000429',
"SHcy;": '\U00000428',
"SOFTcy;": '\U0000042C',
"Sacute;": '\U0000015A',
"Sc;": '\U00002ABC',
"Scaron;": '\U00000160',
"Scedil;": '\U0000015E',
"Scirc;": '\U0000015C',
"Scy;": '\U00000421',
"Sfr;": '\U0001D516',
"ShortDownArrow;": '\U00002193',
"ShortLeftArrow;": '\U00002190',
"ShortRightArrow;": '\U00002192',
"ShortUpArrow;": '\U00002191',
"Sigma;": '\U000003A3',
"SmallCircle;": '\U00002218',
"Sopf;": '\U0001D54A',
"Sqrt;": '\U0000221A',
"Square;": '\U000025A1',
"SquareIntersection;": '\U00002293',
"SquareSubset;": '\U0000228F',
"SquareSubsetEqual;": '\U00002291',
"SquareSuperset;": '\U00002290',
"SquareSupersetEqual;": '\U00002292',
"SquareUnion;": '\U00002294',
"Sscr;": '\U0001D4AE',
"Star;": '\U000022C6',
"Sub;": '\U000022D0',
"Subset;": '\U000022D0',
"SubsetEqual;": '\U00002286',
"Succeeds;": '\U0000227B',
"SucceedsEqual;": '\U00002AB0',
"SucceedsSlantEqual;": '\U0000227D',
"SucceedsTilde;": '\U0000227F',
"SuchThat;": '\U0000220B',
"Sum;": '\U00002211',
"Sup;": '\U000022D1',
"Superset;": '\U00002283',
"SupersetEqual;": '\U00002287',
"Supset;": '\U000022D1',
"THORN;": '\U000000DE',
"TRADE;": '\U00002122',
"TSHcy;": '\U0000040B',
"TScy;": '\U00000426',
"Tab;": '\U00000009',
"Tau;": '\U000003A4',
"Tcaron;": '\U00000164',
"Tcedil;": '\U00000162',
"Tcy;": '\U00000422',
"Tfr;": '\U0001D517',
"Therefore;": '\U00002234',
"Theta;": '\U00000398',
"ThinSpace;": '\U00002009',
"Tilde;": '\U0000223C',
"TildeEqual;": '\U00002243',
"TildeFullEqual;": '\U00002245',
"TildeTilde;": '\U00002248',
"Topf;": '\U0001D54B',
"TripleDot;": '\U000020DB',
"Tscr;": '\U0001D4AF',
"Tstrok;": '\U00000166',
"Uacute;": '\U000000DA',
"Uarr;": '\U0000219F',
"Uarrocir;": '\U00002949',
"Ubrcy;": '\U0000040E',
"Ubreve;": '\U0000016C',
"Ucirc;": '\U000000DB',
"Ucy;": '\U00000423',
"Udblac;": '\U00000170',
"Ufr;": '\U0001D518',
"Ugrave;": '\U000000D9',
"Umacr;": '\U0000016A',
"UnderBar;": '\U0000005F',
"UnderBrace;": '\U000023DF',
"UnderBracket;": '\U000023B5',
"UnderParenthesis;": '\U000023DD',
"Union;": '\U000022C3',
"UnionPlus;": '\U0000228E',
"Uogon;": '\U00000172',
"Uopf;": '\U0001D54C',
"UpArrow;": '\U00002191',
"UpArrowBar;": '\U00002912',
"UpArrowDownArrow;": '\U000021C5',
"UpDownArrow;": '\U00002195',
"UpEquilibrium;": '\U0000296E',
"UpTee;": '\U000022A5',
"UpTeeArrow;": '\U000021A5',
"Uparrow;": '\U000021D1',
"Updownarrow;": '\U000021D5',
"UpperLeftArrow;": '\U00002196',
"UpperRightArrow;": '\U00002197',
"Upsi;": '\U000003D2',
"Upsilon;": '\U000003A5',
"Uring;": '\U0000016E',
"Uscr;": '\U0001D4B0',
"Utilde;": '\U00000168',
"Uuml;": '\U000000DC',
"VDash;": '\U000022AB',
"Vbar;": '\U00002AEB',
"Vcy;": '\U00000412',
"Vdash;": '\U000022A9',
"Vdashl;": '\U00002AE6',
"Vee;": '\U000022C1',
"Verbar;": '\U00002016',
"Vert;": '\U00002016',
"VerticalBar;": '\U00002223',
"VerticalLine;": '\U0000007C',
"VerticalSeparator;": '\U00002758',
"VerticalTilde;": '\U00002240',
"VeryThinSpace;": '\U0000200A',
"Vfr;": '\U0001D519',
"Vopf;": '\U0001D54D',
"Vscr;": '\U0001D4B1',
"Vvdash;": '\U000022AA',
"Wcirc;": '\U00000174',
"Wedge;": '\U000022C0',
"Wfr;": '\U0001D51A',
"Wopf;": '\U0001D54E',
"Wscr;": '\U0001D4B2',
"Xfr;": '\U0001D51B',
"Xi;": '\U0000039E',
"Xopf;": '\U0001D54F',
"Xscr;": '\U0001D4B3',
"YAcy;": '\U0000042F',
"YIcy;": '\U00000407',
"YUcy;": '\U0000042E',
"Yacute;": '\U000000DD',
"Ycirc;": '\U00000176',
"Ycy;": '\U0000042B',
"Yfr;": '\U0001D51C',
"Yopf;": '\U0001D550',
"Yscr;": '\U0001D4B4',
"Yuml;": '\U00000178',
"ZHcy;": '\U00000416',
"Zacute;": '\U00000179',
"Zcaron;": '\U0000017D',
"Zcy;": '\U00000417',
"Zdot;": '\U0000017B',
"ZeroWidthSpace;": '\U0000200B',
"Zeta;": '\U00000396',
"Zfr;": '\U00002128',
"Zopf;": '\U00002124',
"Zscr;": '\U0001D4B5',
"aacute;": '\U000000E1',
"abreve;": '\U00000103',
"ac;": '\U0000223E',
"acd;": '\U0000223F',
"acirc;": '\U000000E2',
"acute;": '\U000000B4',
"acy;": '\U00000430',
"aelig;": '\U000000E6',
"af;": '\U00002061',
"afr;": '\U0001D51E',
"agrave;": '\U000000E0',
"alefsym;": '\U00002135',
"aleph;": '\U00002135',
"alpha;": '\U000003B1',
"amacr;": '\U00000101',
"amalg;": '\U00002A3F',
"amp;": '\U00000026',
"and;": '\U00002227',
"andand;": '\U00002A55',
"andd;": '\U00002A5C',
"andslope;": '\U00002A58',
"andv;": '\U00002A5A',
"ang;": '\U00002220',
"ange;": '\U000029A4',
"angle;": '\U00002220',
"angmsd;": '\U00002221',
"angmsdaa;": '\U000029A8',
"angmsdab;": '\U000029A9',
"angmsdac;": '\U000029AA',
"angmsdad;": '\U000029AB',
"angmsdae;": '\U000029AC',
"angmsdaf;": '\U000029AD',
"angmsdag;": '\U000029AE',
"angmsdah;": '\U000029AF',
"angrt;": '\U0000221F',
"angrtvb;": '\U000022BE',
"angrtvbd;": '\U0000299D',
"angsph;": '\U00002222',
"angst;": '\U000000C5',
"angzarr;": '\U0000237C',
"aogon;": '\U00000105',
"aopf;": '\U0001D552',
"ap;": '\U00002248',
"apE;": '\U00002A70',
"apacir;": '\U00002A6F',
"ape;": '\U0000224A',
"apid;": '\U0000224B',
"apos;": '\U00000027',
"approx;": '\U00002248',
"approxeq;": '\U0000224A',
"aring;": '\U000000E5',
"ascr;": '\U0001D4B6',
"ast;": '\U0000002A',
"asymp;": '\U00002248',
"asympeq;": '\U0000224D',
"atilde;": '\U000000E3',
"auml;": '\U000000E4',
"awconint;": '\U00002233',
"awint;": '\U00002A11',
"bNot;": '\U00002AED',
"backcong;": '\U0000224C',
"backepsilon;": '\U000003F6',
"backprime;": '\U00002035',
"backsim;": '\U0000223D',
"backsimeq;": '\U000022CD',
"barvee;": '\U000022BD',
"barwed;": '\U00002305',
"barwedge;": '\U00002305',
"bbrk;": '\U000023B5',
"bbrktbrk;": '\U000023B6',
"bcong;": '\U0000224C',
"bcy;": '\U00000431',
"bdquo;": '\U0000201E',
"becaus;": '\U00002235',
"because;": '\U00002235',
"bemptyv;": '\U000029B0',
"bepsi;": '\U000003F6',
"bernou;": '\U0000212C',
"beta;": '\U000003B2',
"beth;": '\U00002136',
"between;": '\U0000226C',
"bfr;": '\U0001D51F',
"bigcap;": '\U000022C2',
"bigcirc;": '\U000025EF',
"bigcup;": '\U000022C3',
"bigodot;": '\U00002A00',
"bigoplus;": '\U00002A01',
"bigotimes;": '\U00002A02',
"bigsqcup;": '\U00002A06',
"bigstar;": '\U00002605',
"bigtriangledown;": '\U000025BD',
"bigtriangleup;": '\U000025B3',
"biguplus;": '\U00002A04',
"bigvee;": '\U000022C1',
"bigwedge;": '\U000022C0',
"bkarow;": '\U0000290D',
"blacklozenge;": '\U000029EB',
"blacksquare;": '\U000025AA',
"blacktriangle;": '\U000025B4',
"blacktriangledown;": '\U000025BE',
"blacktriangleleft;": '\U000025C2',
"blacktriangleright;": '\U000025B8',
"blank;": '\U00002423',
"blk12;": '\U00002592',
"blk14;": '\U00002591',
"blk34;": '\U00002593',
"block;": '\U00002588',
"bnot;": '\U00002310',
"bopf;": '\U0001D553',
"bot;": '\U000022A5',
"bottom;": '\U000022A5',
"bowtie;": '\U000022C8',
"boxDL;": '\U00002557',
"boxDR;": '\U00002554',
"boxDl;": '\U00002556',
"boxDr;": '\U00002553',
"boxH;": '\U00002550',
"boxHD;": '\U00002566',
"boxHU;": '\U00002569',
"boxHd;": '\U00002564',
"boxHu;": '\U00002567',
"boxUL;": '\U0000255D',
"boxUR;": '\U0000255A',
"boxUl;": '\U0000255C',
"boxUr;": '\U00002559',
"boxV;": '\U00002551',
"boxVH;": '\U0000256C',
"boxVL;": '\U00002563',
"boxVR;": '\U00002560',
"boxVh;": '\U0000256B',
"boxVl;": '\U00002562',
"boxVr;": '\U0000255F',
"boxbox;": '\U000029C9',
"boxdL;": '\U00002555',
"boxdR;": '\U00002552',
"boxdl;": '\U00002510',
"boxdr;": '\U0000250C',
"boxh;": '\U00002500',
"boxhD;": '\U00002565',
"boxhU;": '\U00002568',
"boxhd;": '\U0000252C',
"boxhu;": '\U00002534',
"boxminus;": '\U0000229F',
"boxplus;": '\U0000229E',
"boxtimes;": '\U000022A0',
"boxuL;": '\U0000255B',
"boxuR;": '\U00002558',
"boxul;": '\U00002518',
"boxur;": '\U00002514',
"boxv;": '\U00002502',
"boxvH;": '\U0000256A',
"boxvL;": '\U00002561',
"boxvR;": '\U0000255E',
"boxvh;": '\U0000253C',
"boxvl;": '\U00002524',
"boxvr;": '\U0000251C',
"bprime;": '\U00002035',
"breve;": '\U000002D8',
"brvbar;": '\U000000A6',
"bscr;": '\U0001D4B7',
"bsemi;": '\U0000204F',
"bsim;": '\U0000223D',
"bsime;": '\U000022CD',
"bsol;": '\U0000005C',
"bsolb;": '\U000029C5',
"bsolhsub;": '\U000027C8',
"bull;": '\U00002022',
"bullet;": '\U00002022',
"bump;": '\U0000224E',
"bumpE;": '\U00002AAE',
"bumpe;": '\U0000224F',
"bumpeq;": '\U0000224F',
"cacute;": '\U00000107',
"cap;": '\U00002229',
"capand;": '\U00002A44',
"capbrcup;": '\U00002A49',
"capcap;": '\U00002A4B',
"capcup;": '\U00002A47',
"capdot;": '\U00002A40',
"caret;": '\U00002041',
"caron;": '\U000002C7',
"ccaps;": '\U00002A4D',
"ccaron;": '\U0000010D',
"ccedil;": '\U000000E7',
"ccirc;": '\U00000109',
"ccups;": '\U00002A4C',
"ccupssm;": '\U00002A50',
"cdot;": '\U0000010B',
"cedil;": '\U000000B8',
"cemptyv;": '\U000029B2',
"cent;": '\U000000A2',
"centerdot;": '\U000000B7',
"cfr;": '\U0001D520',
"chcy;": '\U00000447',
"check;": '\U00002713',
"checkmark;": '\U00002713',
"chi;": '\U000003C7',
"cir;": '\U000025CB',
"cirE;": '\U000029C3',
"circ;": '\U000002C6',
"circeq;": '\U00002257',
"circlearrowleft;": '\U000021BA',
"circlearrowright;": '\U000021BB',
"circledR;": '\U000000AE',
"circledS;": '\U000024C8',
"circledast;": '\U0000229B',
"circledcirc;": '\U0000229A',
"circleddash;": '\U0000229D',
"cire;": '\U00002257',
"cirfnint;": '\U00002A10',
"cirmid;": '\U00002AEF',
"cirscir;": '\U000029C2',
"clubs;": '\U00002663',
"clubsuit;": '\U00002663',
"colon;": '\U0000003A',
"colone;": '\U00002254',
"coloneq;": '\U00002254',
"comma;": '\U0000002C',
"commat;": '\U00000040',
"comp;": '\U00002201',
"compfn;": '\U00002218',
"complement;": '\U00002201',
"complexes;": '\U00002102',
"cong;": '\U00002245',
"congdot;": '\U00002A6D',
"conint;": '\U0000222E',
"copf;": '\U0001D554',
"coprod;": '\U00002210',
"copy;": '\U000000A9',
"copysr;": '\U00002117',
"crarr;": '\U000021B5',
"cross;": '\U00002717',
"cscr;": '\U0001D4B8',
"csub;": '\U00002ACF',
"csube;": '\U00002AD1',
"csup;": '\U00002AD0',
"csupe;": '\U00002AD2',
"ctdot;": '\U000022EF',
"cudarrl;": '\U00002938',
"cudarrr;": '\U00002935',
"cuepr;": '\U000022DE',
"cuesc;": '\U000022DF',
"cularr;": '\U000021B6',
"cularrp;": '\U0000293D',
"cup;": '\U0000222A',
"cupbrcap;": '\U00002A48',
"cupcap;": '\U00002A46',
"cupcup;": '\U00002A4A',
"cupdot;": '\U0000228D',
"cupor;": '\U00002A45',
"curarr;": '\U000021B7',
"curarrm;": '\U0000293C',
"curlyeqprec;": '\U000022DE',
"curlyeqsucc;": '\U000022DF',
"curlyvee;": '\U000022CE',
"curlywedge;": '\U000022CF',
"curren;": '\U000000A4',
"curvearrowleft;": '\U000021B6',
"curvearrowright;": '\U000021B7',
"cuvee;": '\U000022CE',
"cuwed;": '\U000022CF',
"cwconint;": '\U00002232',
"cwint;": '\U00002231',
"cylcty;": '\U0000232D',
"dArr;": '\U000021D3',
"dHar;": '\U00002965',
"dagger;": '\U00002020',
"daleth;": '\U00002138',
"darr;": '\U00002193',
"dash;": '\U00002010',
"dashv;": '\U000022A3',
"dbkarow;": '\U0000290F',
"dblac;": '\U000002DD',
"dcaron;": '\U0000010F',
"dcy;": '\U00000434',
"dd;": '\U00002146',
"ddagger;": '\U00002021',
"ddarr;": '\U000021CA',
"ddotseq;": '\U00002A77',
"deg;": '\U000000B0',
"delta;": '\U000003B4',
"demptyv;": '\U000029B1',
"dfisht;": '\U0000297F',
"dfr;": '\U0001D521',
"dharl;": '\U000021C3',
"dharr;": '\U000021C2',
"diam;": '\U000022C4',
"diamond;": '\U000022C4',
"diamondsuit;": '\U00002666',
"diams;": '\U00002666',
"die;": '\U000000A8',
"digamma;": '\U000003DD',
"disin;": '\U000022F2',
"div;": '\U000000F7',
"divide;": '\U000000F7',
"divideontimes;": '\U000022C7',
"divonx;": '\U000022C7',
"djcy;": '\U00000452',
"dlcorn;": '\U0000231E',
"dlcrop;": '\U0000230D',
"dollar;": '\U00000024',
"dopf;": '\U0001D555',
"dot;": '\U000002D9',
"doteq;": '\U00002250',
"doteqdot;": '\U00002251',
"dotminus;": '\U00002238',
"dotplus;": '\U00002214',
"dotsquare;": '\U000022A1',
"doublebarwedge;": '\U00002306',
"downarrow;": '\U00002193',
"downdownarrows;": '\U000021CA',
"downharpoonleft;": '\U000021C3',
"downharpoonright;": '\U000021C2',
"drbkarow;": '\U00002910',
"drcorn;": '\U0000231F',
"drcrop;": '\U0000230C',
"dscr;": '\U0001D4B9',
"dscy;": '\U00000455',
"dsol;": '\U000029F6',
"dstrok;": '\U00000111',
"dtdot;": '\U000022F1',
"dtri;": '\U000025BF',
"dtrif;": '\U000025BE',
"duarr;": '\U000021F5',
"duhar;": '\U0000296F',
"dwangle;": '\U000029A6',
"dzcy;": '\U0000045F',
"dzigrarr;": '\U000027FF',
"eDDot;": '\U00002A77',
"eDot;": '\U00002251',
"eacute;": '\U000000E9',
"easter;": '\U00002A6E',
"ecaron;": '\U0000011B',
"ecir;": '\U00002256',
"ecirc;": '\U000000EA',
"ecolon;": '\U00002255',
"ecy;": '\U0000044D',
"edot;": '\U00000117',
"ee;": '\U00002147',
"efDot;": '\U00002252',
"efr;": '\U0001D522',
"eg;": '\U00002A9A',
"egrave;": '\U000000E8',
"egs;": '\U00002A96',
"egsdot;": '\U00002A98',
"el;": '\U00002A99',
"elinters;": '\U000023E7',
"ell;": '\U00002113',
"els;": '\U00002A95',
"elsdot;": '\U00002A97',
"emacr;": '\U00000113',
"empty;": '\U00002205',
"emptyset;": '\U00002205',
"emptyv;": '\U00002205',
"emsp;": '\U00002003',
"emsp13;": '\U00002004',
"emsp14;": '\U00002005',
"eng;": '\U0000014B',
"ensp;": '\U00002002',
"eogon;": '\U00000119',
"eopf;": '\U0001D556',
"epar;": '\U000022D5',
"eparsl;": '\U000029E3',
"eplus;": '\U00002A71',
"epsi;": '\U000003B5',
"epsilon;": '\U000003B5',
"epsiv;": '\U000003F5',
"eqcirc;": '\U00002256',
"eqcolon;": '\U00002255',
"eqsim;": '\U00002242',
"eqslantgtr;": '\U00002A96',
"eqslantless;": '\U00002A95',
"equals;": '\U0000003D',
"equest;": '\U0000225F',
"equiv;": '\U00002261',
"equivDD;": '\U00002A78',
"eqvparsl;": '\U000029E5',
"erDot;": '\U00002253',
"erarr;": '\U00002971',
"escr;": '\U0000212F',
"esdot;": '\U00002250',
"esim;": '\U00002242',
"eta;": '\U000003B7',
"eth;": '\U000000F0',
"euml;": '\U000000EB',
"euro;": '\U000020AC',
"excl;": '\U00000021',
"exist;": '\U00002203',
"expectation;": '\U00002130',
"exponentiale;": '\U00002147',
"fallingdotseq;": '\U00002252',
"fcy;": '\U00000444',
"female;": '\U00002640',
"ffilig;": '\U0000FB03',
"fflig;": '\U0000FB00',
"ffllig;": '\U0000FB04',
"ffr;": '\U0001D523',
"filig;": '\U0000FB01',
"flat;": '\U0000266D',
"fllig;": '\U0000FB02',
"fltns;": '\U000025B1',
"fnof;": '\U00000192',
"fopf;": '\U0001D557',
"forall;": '\U00002200',
"fork;": '\U000022D4',
"forkv;": '\U00002AD9',
"fpartint;": '\U00002A0D',
"frac12;": '\U000000BD',
"frac13;": '\U00002153',
"frac14;": '\U000000BC',
"frac15;": '\U00002155',
"frac16;": '\U00002159',
"frac18;": '\U0000215B',
"frac23;": '\U00002154',
"frac25;": '\U00002156',
"frac34;": '\U000000BE',
"frac35;": '\U00002157',
"frac38;": '\U0000215C',
"frac45;": '\U00002158',
"frac56;": '\U0000215A',
"frac58;": '\U0000215D',
"frac78;": '\U0000215E',
"frasl;": '\U00002044',
"frown;": '\U00002322',
"fscr;": '\U0001D4BB',
"gE;": '\U00002267',
"gEl;": '\U00002A8C',
"gacute;": '\U000001F5',
"gamma;": '\U000003B3',
"gammad;": '\U000003DD',
"gap;": '\U00002A86',
"gbreve;": '\U0000011F',
"gcirc;": '\U0000011D',
"gcy;": '\U00000433',
"gdot;": '\U00000121',
"ge;": '\U00002265',
"gel;": '\U000022DB',
"geq;": '\U00002265',
"geqq;": '\U00002267',
"geqslant;": '\U00002A7E',
"ges;": '\U00002A7E',
"gescc;": '\U00002AA9',
"gesdot;": '\U00002A80',
"gesdoto;": '\U00002A82',
"gesdotol;": '\U00002A84',
"gesles;": '\U00002A94',
"gfr;": '\U0001D524',
"gg;": '\U0000226B',
"ggg;": '\U000022D9',
"gimel;": '\U00002137',
"gjcy;": '\U00000453',
"gl;": '\U00002277',
"glE;": '\U00002A92',
"gla;": '\U00002AA5',
"glj;": '\U00002AA4',
"gnE;": '\U00002269',
"gnap;": '\U00002A8A',
"gnapprox;": '\U00002A8A',
"gne;": '\U00002A88',
"gneq;": '\U00002A88',
"gneqq;": '\U00002269',
"gnsim;": '\U000022E7',
"gopf;": '\U0001D558',
"grave;": '\U00000060',
"gscr;": '\U0000210A',
"gsim;": '\U00002273',
"gsime;": '\U00002A8E',
"gsiml;": '\U00002A90',
"gt;": '\U0000003E',
"gtcc;": '\U00002AA7',
"gtcir;": '\U00002A7A',
"gtdot;": '\U000022D7',
"gtlPar;": '\U00002995',
"gtquest;": '\U00002A7C',
"gtrapprox;": '\U00002A86',
"gtrarr;": '\U00002978',
"gtrdot;": '\U000022D7',
"gtreqless;": '\U000022DB',
"gtreqqless;": '\U00002A8C',
"gtrless;": '\U00002277',
"gtrsim;": '\U00002273',
"hArr;": '\U000021D4',
"hairsp;": '\U0000200A',
"half;": '\U000000BD',
"hamilt;": '\U0000210B',
"hardcy;": '\U0000044A',
"harr;": '\U00002194',
"harrcir;": '\U00002948',
"harrw;": '\U000021AD',
"hbar;": '\U0000210F',
"hcirc;": '\U00000125',
"hearts;": '\U00002665',
"heartsuit;": '\U00002665',
"hellip;": '\U00002026',
"hercon;": '\U000022B9',
"hfr;": '\U0001D525',
"hksearow;": '\U00002925',
"hkswarow;": '\U00002926',
"hoarr;": '\U000021FF',
"homtht;": '\U0000223B',
"hookleftarrow;": '\U000021A9',
"hookrightarrow;": '\U000021AA',
"hopf;": '\U0001D559',
"horbar;": '\U00002015',
"hscr;": '\U0001D4BD',
"hslash;": '\U0000210F',
"hstrok;": '\U00000127',
"hybull;": '\U00002043',
"hyphen;": '\U00002010',
"iacute;": '\U000000ED',
"ic;": '\U00002063',
"icirc;": '\U000000EE',
"icy;": '\U00000438',
"iecy;": '\U00000435',
"iexcl;": '\U000000A1',
"iff;": '\U000021D4',
"ifr;": '\U0001D526',
"igrave;": '\U000000EC',
"ii;": '\U00002148',
"iiiint;": '\U00002A0C',
"iiint;": '\U0000222D',
"iinfin;": '\U000029DC',
"iiota;": '\U00002129',
"ijlig;": '\U00000133',
"imacr;": '\U0000012B',
"image;": '\U00002111',
"imagline;": '\U00002110',
"imagpart;": '\U00002111',
"imath;": '\U00000131',
"imof;": '\U000022B7',
"imped;": '\U000001B5',
"in;": '\U00002208',
"incare;": '\U00002105',
"infin;": '\U0000221E',
"infintie;": '\U000029DD',
"inodot;": '\U00000131',
"int;": '\U0000222B',
"intcal;": '\U000022BA',
"integers;": '\U00002124',
"intercal;": '\U000022BA',
"intlarhk;": '\U00002A17',
"intprod;": '\U00002A3C',
"iocy;": '\U00000451',
"iogon;": '\U0000012F',
"iopf;": '\U0001D55A',
"iota;": '\U000003B9',
"iprod;": '\U00002A3C',
"iquest;": '\U000000BF',
"iscr;": '\U0001D4BE',
"isin;": '\U00002208',
"isinE;": '\U000022F9',
"isindot;": '\U000022F5',
"isins;": '\U000022F4',
"isinsv;": '\U000022F3',
"isinv;": '\U00002208',
"it;": '\U00002062',
"itilde;": '\U00000129',
"iukcy;": '\U00000456',
"iuml;": '\U000000EF',
"jcirc;": '\U00000135',
"jcy;": '\U00000439',
"jfr;": '\U0001D527',
"jmath;": '\U00000237',
"jopf;": '\U0001D55B',
"jscr;": '\U0001D4BF',
"jsercy;": '\U00000458',
"jukcy;": '\U00000454',
"kappa;": '\U000003BA',
"kappav;": '\U000003F0',
"kcedil;": '\U00000137',
"kcy;": '\U0000043A',
"kfr;": '\U0001D528',
"kgreen;": '\U00000138',
"khcy;": '\U00000445',
"kjcy;": '\U0000045C',
"kopf;": '\U0001D55C',
"kscr;": '\U0001D4C0',
"lAarr;": '\U000021DA',
"lArr;": '\U000021D0',
"lAtail;": '\U0000291B',
"lBarr;": '\U0000290E',
"lE;": '\U00002266',
"lEg;": '\U00002A8B',
"lHar;": '\U00002962',
"lacute;": '\U0000013A',
"laemptyv;": '\U000029B4',
"lagran;": '\U00002112',
"lambda;": '\U000003BB',
"lang;": '\U000027E8',
"langd;": '\U00002991',
"langle;": '\U000027E8',
"lap;": '\U00002A85',
"laquo;": '\U000000AB',
"larr;": '\U00002190',
"larrb;": '\U000021E4',
"larrbfs;": '\U0000291F',
"larrfs;": '\U0000291D',
"larrhk;": '\U000021A9',
"larrlp;": '\U000021AB',
"larrpl;": '\U00002939',
"larrsim;": '\U00002973',
"larrtl;": '\U000021A2',
"lat;": '\U00002AAB',
"latail;": '\U00002919',
"late;": '\U00002AAD',
"lbarr;": '\U0000290C',
"lbbrk;": '\U00002772',
"lbrace;": '\U0000007B',
"lbrack;": '\U0000005B',
"lbrke;": '\U0000298B',
"lbrksld;": '\U0000298F',
"lbrkslu;": '\U0000298D',
"lcaron;": '\U0000013E',
"lcedil;": '\U0000013C',
"lceil;": '\U00002308',
"lcub;": '\U0000007B',
"lcy;": '\U0000043B',
"ldca;": '\U00002936',
"ldquo;": '\U0000201C',
"ldquor;": '\U0000201E',
"ldrdhar;": '\U00002967',
"ldrushar;": '\U0000294B',
"ldsh;": '\U000021B2',
"le;": '\U00002264',
"leftarrow;": '\U00002190',
"leftarrowtail;": '\U000021A2',
"leftharpoondown;": '\U000021BD',
"leftharpoonup;": '\U000021BC',
"leftleftarrows;": '\U000021C7',
"leftrightarrow;": '\U00002194',
"leftrightarrows;": '\U000021C6',
"leftrightharpoons;": '\U000021CB',
"leftrightsquigarrow;": '\U000021AD',
"leftthreetimes;": '\U000022CB',
"leg;": '\U000022DA',
"leq;": '\U00002264',
"leqq;": '\U00002266',
"leqslant;": '\U00002A7D',
"les;": '\U00002A7D',
"lescc;": '\U00002AA8',
"lesdot;": '\U00002A7F',
"lesdoto;": '\U00002A81',
"lesdotor;": '\U00002A83',
"lesges;": '\U00002A93',
"lessapprox;": '\U00002A85',
"lessdot;": '\U000022D6',
"lesseqgtr;": '\U000022DA',
"lesseqqgtr;": '\U00002A8B',
"lessgtr;": '\U00002276',
"lesssim;": '\U00002272',
"lfisht;": '\U0000297C',
"lfloor;": '\U0000230A',
"lfr;": '\U0001D529',
"lg;": '\U00002276',
"lgE;": '\U00002A91',
"lhard;": '\U000021BD',
"lharu;": '\U000021BC',
"lharul;": '\U0000296A',
"lhblk;": '\U00002584',
"ljcy;": '\U00000459',
"ll;": '\U0000226A',
"llarr;": '\U000021C7',
"llcorner;": '\U0000231E',
"llhard;": '\U0000296B',
"lltri;": '\U000025FA',
"lmidot;": '\U00000140',
"lmoust;": '\U000023B0',
"lmoustache;": '\U000023B0',
"lnE;": '\U00002268',
"lnap;": '\U00002A89',
"lnapprox;": '\U00002A89',
"lne;": '\U00002A87',
"lneq;": '\U00002A87',
"lneqq;": '\U00002268',
"lnsim;": '\U000022E6',
"loang;": '\U000027EC',
"loarr;": '\U000021FD',
"lobrk;": '\U000027E6',
"longleftarrow;": '\U000027F5',
"longleftrightarrow;": '\U000027F7',
"longmapsto;": '\U000027FC',
"longrightarrow;": '\U000027F6',
"looparrowleft;": '\U000021AB',
"looparrowright;": '\U000021AC',
"lopar;": '\U00002985',
"lopf;": '\U0001D55D',
"loplus;": '\U00002A2D',
"lotimes;": '\U00002A34',
"lowast;": '\U00002217',
"lowbar;": '\U0000005F',
"loz;": '\U000025CA',
"lozenge;": '\U000025CA',
"lozf;": '\U000029EB',
"lpar;": '\U00000028',
"lparlt;": '\U00002993',
"lrarr;": '\U000021C6',
"lrcorner;": '\U0000231F',
"lrhar;": '\U000021CB',
"lrhard;": '\U0000296D',
"lrm;": '\U0000200E',
"lrtri;": '\U000022BF',
"lsaquo;": '\U00002039',
"lscr;": '\U0001D4C1',
"lsh;": '\U000021B0',
"lsim;": '\U00002272',
"lsime;": '\U00002A8D',
"lsimg;": '\U00002A8F',
"lsqb;": '\U0000005B',
"lsquo;": '\U00002018',
"lsquor;": '\U0000201A',
"lstrok;": '\U00000142',
"lt;": '\U0000003C',
"ltcc;": '\U00002AA6',
"ltcir;": '\U00002A79',
"ltdot;": '\U000022D6',
"lthree;": '\U000022CB',
"ltimes;": '\U000022C9',
"ltlarr;": '\U00002976',
"ltquest;": '\U00002A7B',
"ltrPar;": '\U00002996',
"ltri;": '\U000025C3',
"ltrie;": '\U000022B4',
"ltrif;": '\U000025C2',
"lurdshar;": '\U0000294A',
"luruhar;": '\U00002966',
"mDDot;": '\U0000223A',
"macr;": '\U000000AF',
"male;": '\U00002642',
"malt;": '\U00002720',
"maltese;": '\U00002720',
"map;": '\U000021A6',
"mapsto;": '\U000021A6',
"mapstodown;": '\U000021A7',
"mapstoleft;": '\U000021A4',
"mapstoup;": '\U000021A5',
"marker;": '\U000025AE',
"mcomma;": '\U00002A29',
"mcy;": '\U0000043C',
"mdash;": '\U00002014',
"measuredangle;": '\U00002221',
"mfr;": '\U0001D52A',
"mho;": '\U00002127',
"micro;": '\U000000B5',
"mid;": '\U00002223',
"midast;": '\U0000002A',
"midcir;": '\U00002AF0',
"middot;": '\U000000B7',
"minus;": '\U00002212',
"minusb;": '\U0000229F',
"minusd;": '\U00002238',
"minusdu;": '\U00002A2A',
"mlcp;": '\U00002ADB',
"mldr;": '\U00002026',
"mnplus;": '\U00002213',
"models;": '\U000022A7',
"mopf;": '\U0001D55E',
"mp;": '\U00002213',
"mscr;": '\U0001D4C2',
"mstpos;": '\U0000223E',
"mu;": '\U000003BC',
"multimap;": '\U000022B8',
"mumap;": '\U000022B8',
"nLeftarrow;": '\U000021CD',
"nLeftrightarrow;": '\U000021CE',
"nRightarrow;": '\U000021CF',
"nVDash;": '\U000022AF',
"nVdash;": '\U000022AE',
"nabla;": '\U00002207',
"nacute;": '\U00000144',
"nap;": '\U00002249',
"napos;": '\U00000149',
"napprox;": '\U00002249',
"natur;": '\U0000266E',
"natural;": '\U0000266E',
"naturals;": '\U00002115',
"nbsp;": '\U000000A0',
"ncap;": '\U00002A43',
"ncaron;": '\U00000148',
"ncedil;": '\U00000146',
"ncong;": '\U00002247',
"ncup;": '\U00002A42',
"ncy;": '\U0000043D',
"ndash;": '\U00002013',
"ne;": '\U00002260',
"neArr;": '\U000021D7',
"nearhk;": '\U00002924',
"nearr;": '\U00002197',
"nearrow;": '\U00002197',
"nequiv;": '\U00002262',
"nesear;": '\U00002928',
"nexist;": '\U00002204',
"nexists;": '\U00002204',
"nfr;": '\U0001D52B',
"nge;": '\U00002271',
"ngeq;": '\U00002271',
"ngsim;": '\U00002275',
"ngt;": '\U0000226F',
"ngtr;": '\U0000226F',
"nhArr;": '\U000021CE',
"nharr;": '\U000021AE',
"nhpar;": '\U00002AF2',
"ni;": '\U0000220B',
"nis;": '\U000022FC',
"nisd;": '\U000022FA',
"niv;": '\U0000220B',
"njcy;": '\U0000045A',
"nlArr;": '\U000021CD',
"nlarr;": '\U0000219A',
"nldr;": '\U00002025',
"nle;": '\U00002270',
"nleftarrow;": '\U0000219A',
"nleftrightarrow;": '\U000021AE',
"nleq;": '\U00002270',
"nless;": '\U0000226E',
"nlsim;": '\U00002274',
"nlt;": '\U0000226E',
"nltri;": '\U000022EA',
"nltrie;": '\U000022EC',
"nmid;": '\U00002224',
"nopf;": '\U0001D55F',
"not;": '\U000000AC',
"notin;": '\U00002209',
"notinva;": '\U00002209',
"notinvb;": '\U000022F7',
"notinvc;": '\U000022F6',
"notni;": '\U0000220C',
"notniva;": '\U0000220C',
"notnivb;": '\U000022FE',
"notnivc;": '\U000022FD',
"npar;": '\U00002226',
"nparallel;": '\U00002226',
"npolint;": '\U00002A14',
"npr;": '\U00002280',
"nprcue;": '\U000022E0',
"nprec;": '\U00002280',
"nrArr;": '\U000021CF',
"nrarr;": '\U0000219B',
"nrightarrow;": '\U0000219B',
"nrtri;": '\U000022EB',
"nrtrie;": '\U000022ED',
"nsc;": '\U00002281',
"nsccue;": '\U000022E1',
"nscr;": '\U0001D4C3',
"nshortmid;": '\U00002224',
"nshortparallel;": '\U00002226',
"nsim;": '\U00002241',
"nsime;": '\U00002244',
"nsimeq;": '\U00002244',
"nsmid;": '\U00002224',
"nspar;": '\U00002226',
"nsqsube;": '\U000022E2',
"nsqsupe;": '\U000022E3',
"nsub;": '\U00002284',
"nsube;": '\U00002288',
"nsubseteq;": '\U00002288',
"nsucc;": '\U00002281',
"nsup;": '\U00002285',
"nsupe;": '\U00002289',
"nsupseteq;": '\U00002289',
"ntgl;": '\U00002279',
"ntilde;": '\U000000F1',
"ntlg;": '\U00002278',
"ntriangleleft;": '\U000022EA',
"ntrianglelefteq;": '\U000022EC',
"ntriangleright;": '\U000022EB',
"ntrianglerighteq;": '\U000022ED',
"nu;": '\U000003BD',
"num;": '\U00000023',
"numero;": '\U00002116',
"numsp;": '\U00002007',
"nvDash;": '\U000022AD',
"nvHarr;": '\U00002904',
"nvdash;": '\U000022AC',
"nvinfin;": '\U000029DE',
"nvlArr;": '\U00002902',
"nvrArr;": '\U00002903',
"nwArr;": '\U000021D6',
"nwarhk;": '\U00002923',
"nwarr;": '\U00002196',
"nwarrow;": '\U00002196',
"nwnear;": '\U00002927',
"oS;": '\U000024C8',
"oacute;": '\U000000F3',
"oast;": '\U0000229B',
"ocir;": '\U0000229A',
"ocirc;": '\U000000F4',
"ocy;": '\U0000043E',
"odash;": '\U0000229D',
"odblac;": '\U00000151',
"odiv;": '\U00002A38',
"odot;": '\U00002299',
"odsold;": '\U000029BC',
"oelig;": '\U00000153',
"ofcir;": '\U000029BF',
"ofr;": '\U0001D52C',
"ogon;": '\U000002DB',
"ograve;": '\U000000F2',
"ogt;": '\U000029C1',
"ohbar;": '\U000029B5',
"ohm;": '\U000003A9',
"oint;": '\U0000222E',
"olarr;": '\U000021BA',
"olcir;": '\U000029BE',
"olcross;": '\U000029BB',
"oline;": '\U0000203E',
"olt;": '\U000029C0',
"omacr;": '\U0000014D',
"omega;": '\U000003C9',
"omicron;": '\U000003BF',
"omid;": '\U000029B6',
"ominus;": '\U00002296',
"oopf;": '\U0001D560',
"opar;": '\U000029B7',
"operp;": '\U000029B9',
"oplus;": '\U00002295',
"or;": '\U00002228',
"orarr;": '\U000021BB',
"ord;": '\U00002A5D',
"order;": '\U00002134',
"orderof;": '\U00002134',
"ordf;": '\U000000AA',
"ordm;": '\U000000BA',
"origof;": '\U000022B6',
"oror;": '\U00002A56',
"orslope;": '\U00002A57',
"orv;": '\U00002A5B',
"oscr;": '\U00002134',
"oslash;": '\U000000F8',
"osol;": '\U00002298',
"otilde;": '\U000000F5',
"otimes;": '\U00002297',
"otimesas;": '\U00002A36',
"ouml;": '\U000000F6',
"ovbar;": '\U0000233D',
"par;": '\U00002225',
"para;": '\U000000B6',
"parallel;": '\U00002225',
"parsim;": '\U00002AF3',
"parsl;": '\U00002AFD',
"part;": '\U00002202',
"pcy;": '\U0000043F',
"percnt;": '\U00000025',
"period;": '\U0000002E',
"permil;": '\U00002030',
"perp;": '\U000022A5',
"pertenk;": '\U00002031',
"pfr;": '\U0001D52D',
"phi;": '\U000003C6',
"phiv;": '\U000003D5',
"phmmat;": '\U00002133',
"phone;": '\U0000260E',
"pi;": '\U000003C0',
"pitchfork;": '\U000022D4',
"piv;": '\U000003D6',
"planck;": '\U0000210F',
"planckh;": '\U0000210E',
"plankv;": '\U0000210F',
"plus;": '\U0000002B',
"plusacir;": '\U00002A23',
"plusb;": '\U0000229E',
"pluscir;": '\U00002A22',
"plusdo;": '\U00002214',
"plusdu;": '\U00002A25',
"pluse;": '\U00002A72',
"plusmn;": '\U000000B1',
"plussim;": '\U00002A26',
"plustwo;": '\U00002A27',
"pm;": '\U000000B1',
"pointint;": '\U00002A15',
"popf;": '\U0001D561',
"pound;": '\U000000A3',
"pr;": '\U0000227A',
"prE;": '\U00002AB3',
"prap;": '\U00002AB7',
"prcue;": '\U0000227C',
"pre;": '\U00002AAF',
"prec;": '\U0000227A',
"precapprox;": '\U00002AB7',
"preccurlyeq;": '\U0000227C',
"preceq;": '\U00002AAF',
"precnapprox;": '\U00002AB9',
"precneqq;": '\U00002AB5',
"precnsim;": '\U000022E8',
"precsim;": '\U0000227E',
"prime;": '\U00002032',
"primes;": '\U00002119',
"prnE;": '\U00002AB5',
"prnap;": '\U00002AB9',
"prnsim;": '\U000022E8',
"prod;": '\U0000220F',
"profalar;": '\U0000232E',
"profline;": '\U00002312',
"profsurf;": '\U00002313',
"prop;": '\U0000221D',
"propto;": '\U0000221D',
"prsim;": '\U0000227E',
"prurel;": '\U000022B0',
"pscr;": '\U0001D4C5',
"psi;": '\U000003C8',
"puncsp;": '\U00002008',
"qfr;": '\U0001D52E',
"qint;": '\U00002A0C',
"qopf;": '\U0001D562',
"qprime;": '\U00002057',
"qscr;": '\U0001D4C6',
"quaternions;": '\U0000210D',
"quatint;": '\U00002A16',
"quest;": '\U0000003F',
"questeq;": '\U0000225F',
"quot;": '\U00000022',
"rAarr;": '\U000021DB',
"rArr;": '\U000021D2',
"rAtail;": '\U0000291C',
"rBarr;": '\U0000290F',
"rHar;": '\U00002964',
"racute;": '\U00000155',
"radic;": '\U0000221A',
"raemptyv;": '\U000029B3',
"rang;": '\U000027E9',
"rangd;": '\U00002992',
"range;": '\U000029A5',
"rangle;": '\U000027E9',
"raquo;": '\U000000BB',
"rarr;": '\U00002192',
"rarrap;": '\U00002975',
"rarrb;": '\U000021E5',
"rarrbfs;": '\U00002920',
"rarrc;": '\U00002933',
"rarrfs;": '\U0000291E',
"rarrhk;": '\U000021AA',
"rarrlp;": '\U000021AC',
"rarrpl;": '\U00002945',
"rarrsim;": '\U00002974',
"rarrtl;": '\U000021A3',
"rarrw;": '\U0000219D',
"ratail;": '\U0000291A',
"ratio;": '\U00002236',
"rationals;": '\U0000211A',
"rbarr;": '\U0000290D',
"rbbrk;": '\U00002773',
"rbrace;": '\U0000007D',
"rbrack;": '\U0000005D',
"rbrke;": '\U0000298C',
"rbrksld;": '\U0000298E',
"rbrkslu;": '\U00002990',
"rcaron;": '\U00000159',
"rcedil;": '\U00000157',
"rceil;": '\U00002309',
"rcub;": '\U0000007D',
"rcy;": '\U00000440',
"rdca;": '\U00002937',
"rdldhar;": '\U00002969',
"rdquo;": '\U0000201D',
"rdquor;": '\U0000201D',
"rdsh;": '\U000021B3',
"real;": '\U0000211C',
"realine;": '\U0000211B',
"realpart;": '\U0000211C',
"reals;": '\U0000211D',
"rect;": '\U000025AD',
"reg;": '\U000000AE',
"rfisht;": '\U0000297D',
"rfloor;": '\U0000230B',
"rfr;": '\U0001D52F',
"rhard;": '\U000021C1',
"rharu;": '\U000021C0',
"rharul;": '\U0000296C',
"rho;": '\U000003C1',
"rhov;": '\U000003F1',
"rightarrow;": '\U00002192',
"rightarrowtail;": '\U000021A3',
"rightharpoondown;": '\U000021C1',
"rightharpoonup;": '\U000021C0',
"rightleftarrows;": '\U000021C4',
"rightleftharpoons;": '\U000021CC',
"rightrightarrows;": '\U000021C9',
"rightsquigarrow;": '\U0000219D',
"rightthreetimes;": '\U000022CC',
"ring;": '\U000002DA',
"risingdotseq;": '\U00002253',
"rlarr;": '\U000021C4',
"rlhar;": '\U000021CC',
"rlm;": '\U0000200F',
"rmoust;": '\U000023B1',
"rmoustache;": '\U000023B1',
"rnmid;": '\U00002AEE',
"roang;": '\U000027ED',
"roarr;": '\U000021FE',
"robrk;": '\U000027E7',
"ropar;": '\U00002986',
"ropf;": '\U0001D563',
"roplus;": '\U00002A2E',
"rotimes;": '\U00002A35',
"rpar;": '\U00000029',
"rpargt;": '\U00002994',
"rppolint;": '\U00002A12',
"rrarr;": '\U000021C9',
"rsaquo;": '\U0000203A',
"rscr;": '\U0001D4C7',
"rsh;": '\U000021B1',
"rsqb;": '\U0000005D',
"rsquo;": '\U00002019',
"rsquor;": '\U00002019',
"rthree;": '\U000022CC',
"rtimes;": '\U000022CA',
"rtri;": '\U000025B9',
"rtrie;": '\U000022B5',
"rtrif;": '\U000025B8',
"rtriltri;": '\U000029CE',
"ruluhar;": '\U00002968',
"rx;": '\U0000211E',
"sacute;": '\U0000015B',
"sbquo;": '\U0000201A',
"sc;": '\U0000227B',
"scE;": '\U00002AB4',
"scap;": '\U00002AB8',
"scaron;": '\U00000161',
"sccue;": '\U0000227D',
"sce;": '\U00002AB0',
"scedil;": '\U0000015F',
"scirc;": '\U0000015D',
"scnE;": '\U00002AB6',
"scnap;": '\U00002ABA',
"scnsim;": '\U000022E9',
"scpolint;": '\U00002A13',
"scsim;": '\U0000227F',
"scy;": '\U00000441',
"sdot;": '\U000022C5',
"sdotb;": '\U000022A1',
"sdote;": '\U00002A66',
"seArr;": '\U000021D8',
"searhk;": '\U00002925',
"searr;": '\U00002198',
"searrow;": '\U00002198',
"sect;": '\U000000A7',
"semi;": '\U0000003B',
"seswar;": '\U00002929',
"setminus;": '\U00002216',
"setmn;": '\U00002216',
"sext;": '\U00002736',
"sfr;": '\U0001D530',
"sfrown;": '\U00002322',
"sharp;": '\U0000266F',
"shchcy;": '\U00000449',
"shcy;": '\U00000448',
"shortmid;": '\U00002223',
"shortparallel;": '\U00002225',
"shy;": '\U000000AD',
"sigma;": '\U000003C3',
"sigmaf;": '\U000003C2',
"sigmav;": '\U000003C2',
"sim;": '\U0000223C',
"simdot;": '\U00002A6A',
"sime;": '\U00002243',
"simeq;": '\U00002243',
"simg;": '\U00002A9E',
"simgE;": '\U00002AA0',
"siml;": '\U00002A9D',
"simlE;": '\U00002A9F',
"simne;": '\U00002246',
"simplus;": '\U00002A24',
"simrarr;": '\U00002972',
"slarr;": '\U00002190',
"smallsetminus;": '\U00002216',
"smashp;": '\U00002A33',
"smeparsl;": '\U000029E4',
"smid;": '\U00002223',
"smile;": '\U00002323',
"smt;": '\U00002AAA',
"smte;": '\U00002AAC',
"softcy;": '\U0000044C',
"sol;": '\U0000002F',
"solb;": '\U000029C4',
"solbar;": '\U0000233F',
"sopf;": '\U0001D564',
"spades;": '\U00002660',
"spadesuit;": '\U00002660',
"spar;": '\U00002225',
"sqcap;": '\U00002293',
"sqcup;": '\U00002294',
"sqsub;": '\U0000228F',
"sqsube;": '\U00002291',
"sqsubset;": '\U0000228F',
"sqsubseteq;": '\U00002291',
"sqsup;": '\U00002290',
"sqsupe;": '\U00002292',
"sqsupset;": '\U00002290',
"sqsupseteq;": '\U00002292',
"squ;": '\U000025A1',
"square;": '\U000025A1',
"squarf;": '\U000025AA',
"squf;": '\U000025AA',
"srarr;": '\U00002192',
"sscr;": '\U0001D4C8',
"ssetmn;": '\U00002216',
"ssmile;": '\U00002323',
"sstarf;": '\U000022C6',
"star;": '\U00002606',
"starf;": '\U00002605',
"straightepsilon;": '\U000003F5',
"straightphi;": '\U000003D5',
"strns;": '\U000000AF',
"sub;": '\U00002282',
"subE;": '\U00002AC5',
"subdot;": '\U00002ABD',
"sube;": '\U00002286',
"subedot;": '\U00002AC3',
"submult;": '\U00002AC1',
"subnE;": '\U00002ACB',
"subne;": '\U0000228A',
"subplus;": '\U00002ABF',
"subrarr;": '\U00002979',
"subset;": '\U00002282',
"subseteq;": '\U00002286',
"subseteqq;": '\U00002AC5',
"subsetneq;": '\U0000228A',
"subsetneqq;": '\U00002ACB',
"subsim;": '\U00002AC7',
"subsub;": '\U00002AD5',
"subsup;": '\U00002AD3',
"succ;": '\U0000227B',
"succapprox;": '\U00002AB8',
"succcurlyeq;": '\U0000227D',
"succeq;": '\U00002AB0',
"succnapprox;": '\U00002ABA',
"succneqq;": '\U00002AB6',
"succnsim;": '\U000022E9',
"succsim;": '\U0000227F',
"sum;": '\U00002211',
"sung;": '\U0000266A',
"sup;": '\U00002283',
"sup1;": '\U000000B9',
"sup2;": '\U000000B2',
"sup3;": '\U000000B3',
"supE;": '\U00002AC6',
"supdot;": '\U00002ABE',
"supdsub;": '\U00002AD8',
"supe;": '\U00002287',
"supedot;": '\U00002AC4',
"suphsol;": '\U000027C9',
"suphsub;": '\U00002AD7',
"suplarr;": '\U0000297B',
"supmult;": '\U00002AC2',
"supnE;": '\U00002ACC',
"supne;": '\U0000228B',
"supplus;": '\U00002AC0',
"supset;": '\U00002283',
"supseteq;": '\U00002287',
"supseteqq;": '\U00002AC6',
"supsetneq;": '\U0000228B',
"supsetneqq;": '\U00002ACC',
"supsim;": '\U00002AC8',
"supsub;": '\U00002AD4',
"supsup;": '\U00002AD6',
"swArr;": '\U000021D9',
"swarhk;": '\U00002926',
"swarr;": '\U00002199',
"swarrow;": '\U00002199',
"swnwar;": '\U0000292A',
"szlig;": '\U000000DF',
"target;": '\U00002316',
"tau;": '\U000003C4',
"tbrk;": '\U000023B4',
"tcaron;": '\U00000165',
"tcedil;": '\U00000163',
"tcy;": '\U00000442',
"tdot;": '\U000020DB',
"telrec;": '\U00002315',
"tfr;": '\U0001D531',
"there4;": '\U00002234',
"therefore;": '\U00002234',
"theta;": '\U000003B8',
"thetasym;": '\U000003D1',
"thetav;": '\U000003D1',
"thickapprox;": '\U00002248',
"thicksim;": '\U0000223C',
"thinsp;": '\U00002009',
"thkap;": '\U00002248',
"thksim;": '\U0000223C',
"thorn;": '\U000000FE',
"tilde;": '\U000002DC',
"times;": '\U000000D7',
"timesb;": '\U000022A0',
"timesbar;": '\U00002A31',
"timesd;": '\U00002A30',
"tint;": '\U0000222D',
"toea;": '\U00002928',
"top;": '\U000022A4',
"topbot;": '\U00002336',
"topcir;": '\U00002AF1',
"topf;": '\U0001D565',
"topfork;": '\U00002ADA',
"tosa;": '\U00002929',
"tprime;": '\U00002034',
"trade;": '\U00002122',
"triangle;": '\U000025B5',
"triangledown;": '\U000025BF',
"triangleleft;": '\U000025C3',
"trianglelefteq;": '\U000022B4',
"triangleq;": '\U0000225C',
"triangleright;": '\U000025B9',
"trianglerighteq;": '\U000022B5',
"tridot;": '\U000025EC',
"trie;": '\U0000225C',
"triminus;": '\U00002A3A',
"triplus;": '\U00002A39',
"trisb;": '\U000029CD',
"tritime;": '\U00002A3B',
"trpezium;": '\U000023E2',
"tscr;": '\U0001D4C9',
"tscy;": '\U00000446',
"tshcy;": '\U0000045B',
"tstrok;": '\U00000167',
"twixt;": '\U0000226C',
"twoheadleftarrow;": '\U0000219E',
"twoheadrightarrow;": '\U000021A0',
"uArr;": '\U000021D1',
"uHar;": '\U00002963',
"uacute;": '\U000000FA',
"uarr;": '\U00002191',
"ubrcy;": '\U0000045E',
"ubreve;": '\U0000016D',
"ucirc;": '\U000000FB',
"ucy;": '\U00000443',
"udarr;": '\U000021C5',
"udblac;": '\U00000171',
"udhar;": '\U0000296E',
"ufisht;": '\U0000297E',
"ufr;": '\U0001D532',
"ugrave;": '\U000000F9',
"uharl;": '\U000021BF',
"uharr;": '\U000021BE',
"uhblk;": '\U00002580',
"ulcorn;": '\U0000231C',
"ulcorner;": '\U0000231C',
"ulcrop;": '\U0000230F',
"ultri;": '\U000025F8',
"umacr;": '\U0000016B',
"uml;": '\U000000A8',
"uogon;": '\U00000173',
"uopf;": '\U0001D566',
"uparrow;": '\U00002191',
"updownarrow;": '\U00002195',
"upharpoonleft;": '\U000021BF',
"upharpoonright;": '\U000021BE',
"uplus;": '\U0000228E',
"upsi;": '\U000003C5',
"upsih;": '\U000003D2',
"upsilon;": '\U000003C5',
"upuparrows;": '\U000021C8',
"urcorn;": '\U0000231D',
"urcorner;": '\U0000231D',
"urcrop;": '\U0000230E',
"uring;": '\U0000016F',
"urtri;": '\U000025F9',
"uscr;": '\U0001D4CA',
"utdot;": '\U000022F0',
"utilde;": '\U00000169',
"utri;": '\U000025B5',
"utrif;": '\U000025B4',
"uuarr;": '\U000021C8',
"uuml;": '\U000000FC',
"uwangle;": '\U000029A7',
"vArr;": '\U000021D5',
"vBar;": '\U00002AE8',
"vBarv;": '\U00002AE9',
"vDash;": '\U000022A8',
"vangrt;": '\U0000299C',
"varepsilon;": '\U000003F5',
"varkappa;": '\U000003F0',
"varnothing;": '\U00002205',
"varphi;": '\U000003D5',
"varpi;": '\U000003D6',
"varpropto;": '\U0000221D',
"varr;": '\U00002195',
"varrho;": '\U000003F1',
"varsigma;": '\U000003C2',
"vartheta;": '\U000003D1',
"vartriangleleft;": '\U000022B2',
"vartriangleright;": '\U000022B3',
"vcy;": '\U00000432',
"vdash;": '\U000022A2',
"vee;": '\U00002228',
"veebar;": '\U000022BB',
"veeeq;": '\U0000225A',
"vellip;": '\U000022EE',
"verbar;": '\U0000007C',
"vert;": '\U0000007C',
"vfr;": '\U0001D533',
"vltri;": '\U000022B2',
"vopf;": '\U0001D567',
"vprop;": '\U0000221D',
"vrtri;": '\U000022B3',
"vscr;": '\U0001D4CB',
"vzigzag;": '\U0000299A',
"wcirc;": '\U00000175',
"wedbar;": '\U00002A5F',
"wedge;": '\U00002227',
"wedgeq;": '\U00002259',
"weierp;": '\U00002118',
"wfr;": '\U0001D534',
"wopf;": '\U0001D568',
"wp;": '\U00002118',
"wr;": '\U00002240',
"wreath;": '\U00002240',
"wscr;": '\U0001D4CC',
"xcap;": '\U000022C2',
"xcirc;": '\U000025EF',
"xcup;": '\U000022C3',
"xdtri;": '\U000025BD',
"xfr;": '\U0001D535',
"xhArr;": '\U000027FA',
"xharr;": '\U000027F7',
"xi;": '\U000003BE',
"xlArr;": '\U000027F8',
"xlarr;": '\U000027F5',
"xmap;": '\U000027FC',
"xnis;": '\U000022FB',
"xodot;": '\U00002A00',
"xopf;": '\U0001D569',
"xoplus;": '\U00002A01',
"xotime;": '\U00002A02',
"xrArr;": '\U000027F9',
"xrarr;": '\U000027F6',
"xscr;": '\U0001D4CD',
"xsqcup;": '\U00002A06',
"xuplus;": '\U00002A04',
"xutri;": '\U000025B3',
"xvee;": '\U000022C1',
"xwedge;": '\U000022C0',
"yacute;": '\U000000FD',
"yacy;": '\U0000044F',
"ycirc;": '\U00000177',
"ycy;": '\U0000044B',
"yen;": '\U000000A5',
"yfr;": '\U0001D536',
"yicy;": '\U00000457',
"yopf;": '\U0001D56A',
"yscr;": '\U0001D4CE',
"yucy;": '\U0000044E',
"yuml;": '\U000000FF',
"zacute;": '\U0000017A',
"zcaron;": '\U0000017E',
"zcy;": '\U00000437',
"zdot;": '\U0000017C',
"zeetrf;": '\U00002128',
"zeta;": '\U000003B6',
"zfr;": '\U0001D537',
"zhcy;": '\U00000436',
"zigrarr;": '\U000021DD',
"zopf;": '\U0001D56B',
"zscr;": '\U0001D4CF',
"zwj;": '\U0000200D',
"zwnj;": '\U0000200C',
"AElig": '\U000000C6',
"AMP": '\U00000026',
"Aacute": '\U000000C1',
"Acirc": '\U000000C2',
"Agrave": '\U000000C0',
"Aring": '\U000000C5',
"Atilde": '\U000000C3',
"Auml": '\U000000C4',
"COPY": '\U000000A9',
"Ccedil": '\U000000C7',
"ETH": '\U000000D0',
"Eacute": '\U000000C9',
"Ecirc": '\U000000CA',
"Egrave": '\U000000C8',
"Euml": '\U000000CB',
"GT": '\U0000003E',
"Iacute": '\U000000CD',
"Icirc": '\U000000CE',
"Igrave": '\U000000CC',
"Iuml": '\U000000CF',
"LT": '\U0000003C',
"Ntilde": '\U000000D1',
"Oacute": '\U000000D3',
"Ocirc": '\U000000D4',
"Ograve": '\U000000D2',
"Oslash": '\U000000D8',
"Otilde": '\U000000D5',
"Ouml": '\U000000D6',
"QUOT": '\U00000022',
"REG": '\U000000AE',
"THORN": '\U000000DE',
"Uacute": '\U000000DA',
"Ucirc": '\U000000DB',
"Ugrave": '\U000000D9',
"Uuml": '\U000000DC',
"Yacute": '\U000000DD',
"aacute": '\U000000E1',
"acirc": '\U000000E2',
"acute": '\U000000B4',
"aelig": '\U000000E6',
"agrave": '\U000000E0',
"amp": '\U00000026',
"aring": '\U000000E5',
"atilde": '\U000000E3',
"auml": '\U000000E4',
"brvbar": '\U000000A6',
"ccedil": '\U000000E7',
"cedil": '\U000000B8',
"cent": '\U000000A2',
"copy": '\U000000A9',
"curren": '\U000000A4',
"deg": '\U000000B0',
"divide": '\U000000F7',
"eacute": '\U000000E9',
"ecirc": '\U000000EA',
"egrave": '\U000000E8',
"eth": '\U000000F0',
"euml": '\U000000EB',
"frac12": '\U000000BD',
"frac14": '\U000000BC',
"frac34": '\U000000BE',
"gt": '\U0000003E',
"iacute": '\U000000ED',
"icirc": '\U000000EE',
"iexcl": '\U000000A1',
"igrave": '\U000000EC',
"iquest": '\U000000BF',
"iuml": '\U000000EF',
"laquo": '\U000000AB',
"lt": '\U0000003C',
"macr": '\U000000AF',
"micro": '\U000000B5',
"middot": '\U000000B7',
"nbsp": '\U000000A0',
"not": '\U000000AC',
"ntilde": '\U000000F1',
"oacute": '\U000000F3',
"ocirc": '\U000000F4',
"ograve": '\U000000F2',
"ordf": '\U000000AA',
"ordm": '\U000000BA',
"oslash": '\U000000F8',
"otilde": '\U000000F5',
"ouml": '\U000000F6',
"para": '\U000000B6',
"plusmn": '\U000000B1',
"pound": '\U000000A3',
"quot": '\U00000022',
"raquo": '\U000000BB',
"reg": '\U000000AE',
"sect": '\U000000A7',
"shy": '\U000000AD',
"sup1": '\U000000B9',
"sup2": '\U000000B2',
"sup3": '\U000000B3',
"szlig": '\U000000DF',
"thorn": '\U000000FE',
"times": '\U000000D7',
"uacute": '\U000000FA',
"ucirc": '\U000000FB',
"ugrave": '\U000000F9',
"uml": '\U000000A8',
"uuml": '\U000000FC',
"yacute": '\U000000FD',
"yen": '\U000000A5',
"yuml": '\U000000FF',
}
entity2 = map[string][2]rune{
// TODO(nigeltao): Handle replacements that are wider than their names.
// "nLt;": {'\u226A', '\u20D2'},
// "nGt;": {'\u226B', '\u20D2'},
"NotEqualTilde;": {'\u2242', '\u0338'},
"NotGreaterFullEqual;": {'\u2267', '\u0338'},
"NotGreaterGreater;": {'\u226B', '\u0338'},
"NotGreaterSlantEqual;": {'\u2A7E', '\u0338'},
"NotHumpDownHump;": {'\u224E', '\u0338'},
"NotHumpEqual;": {'\u224F', '\u0338'},
"NotLeftTriangleBar;": {'\u29CF', '\u0338'},
"NotLessLess;": {'\u226A', '\u0338'},
"NotLessSlantEqual;": {'\u2A7D', '\u0338'},
"NotNestedGreaterGreater;": {'\u2AA2', '\u0338'},
"NotNestedLessLess;": {'\u2AA1', '\u0338'},
"NotPrecedesEqual;": {'\u2AAF', '\u0338'},
"NotRightTriangleBar;": {'\u29D0', '\u0338'},
"NotSquareSubset;": {'\u228F', '\u0338'},
"NotSquareSuperset;": {'\u2290', '\u0338'},
"NotSubset;": {'\u2282', '\u20D2'},
"NotSucceedsEqual;": {'\u2AB0', '\u0338'},
"NotSucceedsTilde;": {'\u227F', '\u0338'},
"NotSuperset;": {'\u2283', '\u20D2'},
"ThickSpace;": {'\u205F', '\u200A'},
"acE;": {'\u223E', '\u0333'},
"bne;": {'\u003D', '\u20E5'},
"bnequiv;": {'\u2261', '\u20E5'},
"caps;": {'\u2229', '\uFE00'},
"cups;": {'\u222A', '\uFE00'},
"fjlig;": {'\u0066', '\u006A'},
"gesl;": {'\u22DB', '\uFE00'},
"gvertneqq;": {'\u2269', '\uFE00'},
"gvnE;": {'\u2269', '\uFE00'},
"lates;": {'\u2AAD', '\uFE00'},
"lesg;": {'\u22DA', '\uFE00'},
"lvertneqq;": {'\u2268', '\uFE00'},
"lvnE;": {'\u2268', '\uFE00'},
"nGg;": {'\u22D9', '\u0338'},
"nGtv;": {'\u226B', '\u0338'},
"nLl;": {'\u22D8', '\u0338'},
"nLtv;": {'\u226A', '\u0338'},
"nang;": {'\u2220', '\u20D2'},
"napE;": {'\u2A70', '\u0338'},
"napid;": {'\u224B', '\u0338'},
"nbump;": {'\u224E', '\u0338'},
"nbumpe;": {'\u224F', '\u0338'},
"ncongdot;": {'\u2A6D', '\u0338'},
"nedot;": {'\u2250', '\u0338'},
"nesim;": {'\u2242', '\u0338'},
"ngE;": {'\u2267', '\u0338'},
"ngeqq;": {'\u2267', '\u0338'},
"ngeqslant;": {'\u2A7E', '\u0338'},
"nges;": {'\u2A7E', '\u0338'},
"nlE;": {'\u2266', '\u0338'},
"nleqq;": {'\u2266', '\u0338'},
"nleqslant;": {'\u2A7D', '\u0338'},
"nles;": {'\u2A7D', '\u0338'},
"notinE;": {'\u22F9', '\u0338'},
"notindot;": {'\u22F5', '\u0338'},
"nparsl;": {'\u2AFD', '\u20E5'},
"npart;": {'\u2202', '\u0338'},
"npre;": {'\u2AAF', '\u0338'},
"npreceq;": {'\u2AAF', '\u0338'},
"nrarrc;": {'\u2933', '\u0338'},
"nrarrw;": {'\u219D', '\u0338'},
"nsce;": {'\u2AB0', '\u0338'},
"nsubE;": {'\u2AC5', '\u0338'},
"nsubset;": {'\u2282', '\u20D2'},
"nsubseteqq;": {'\u2AC5', '\u0338'},
"nsucceq;": {'\u2AB0', '\u0338'},
"nsupE;": {'\u2AC6', '\u0338'},
"nsupset;": {'\u2283', '\u20D2'},
"nsupseteqq;": {'\u2AC6', '\u0338'},
"nvap;": {'\u224D', '\u20D2'},
"nvge;": {'\u2265', '\u20D2'},
"nvgt;": {'\u003E', '\u20D2'},
"nvle;": {'\u2264', '\u20D2'},
"nvlt;": {'\u003C', '\u20D2'},
"nvltrie;": {'\u22B4', '\u20D2'},
"nvrtrie;": {'\u22B5', '\u20D2'},
"nvsim;": {'\u223C', '\u20D2'},
"race;": {'\u223D', '\u0331'},
"smtes;": {'\u2AAC', '\uFE00'},
"sqcaps;": {'\u2293', '\uFE00'},
"sqcups;": {'\u2294', '\uFE00'},
"varsubsetneq;": {'\u228A', '\uFE00'},
"varsubsetneqq;": {'\u2ACB', '\uFE00'},
"varsupsetneq;": {'\u228B', '\uFE00'},
"varsupsetneqq;": {'\u2ACC', '\uFE00'},
"vnsub;": {'\u2282', '\u20D2'},
"vnsup;": {'\u2283', '\u20D2'},
"vsubnE;": {'\u2ACB', '\uFE00'},
"vsubne;": {'\u228A', '\uFE00'},
"vsupnE;": {'\u2ACC', '\uFE00'},
"vsupne;": {'\u228B', '\uFE00'},
}
return entity, entity2
})
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package html provides functions for escaping and unescaping HTML text.
package html
import (
"strings"
"unicode/utf8"
)
// These replacements permit compatibility with old numeric entities that
// assumed Windows-1252 encoding.
// https://html.spec.whatwg.org/multipage/parsing.html#numeric-character-reference-end-state
var replacementTable = [...]rune{
'\u20AC', // First entry is what 0x80 should be replaced with.
'\u0081',
'\u201A',
'\u0192',
'\u201E',
'\u2026',
'\u2020',
'\u2021',
'\u02C6',
'\u2030',
'\u0160',
'\u2039',
'\u0152',
'\u008D',
'\u017D',
'\u008F',
'\u0090',
'\u2018',
'\u2019',
'\u201C',
'\u201D',
'\u2022',
'\u2013',
'\u2014',
'\u02DC',
'\u2122',
'\u0161',
'\u203A',
'\u0153',
'\u009D',
'\u017E',
'\u0178', // Last entry is 0x9F.
// 0x00->'\uFFFD' is handled programmatically.
// 0x0D->'\u000D' is a no-op.
}
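// windows1252Sketch is a minimal illustrative sketch of the table above:
// a numeric character reference in the 0x80-0x9F range is not emitted as a
// C1 control character but is remapped through replacementTable, so "&#150;"
// (0x96) decodes to an en dash rather than U+0096.
func windows1252Sketch() string {
	return UnescapeString("a&#150;b") // "a\u2013b", i.e. "a–b"
}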
// unescapeEntity reads an entity like "&lt;" from b[src:] and writes the
// corresponding "<" to b[dst:], returning the incremented dst and src cursors.
// Precondition: b[src] == '&' && dst <= src.
func unescapeEntity(b []byte, dst, src int, entity map[string]rune, entity2 map[string][2]rune) (dst1, src1 int) {
const attribute = false
// http://www.whatwg.org/specs/web-apps/current-work/multipage/tokenization.html#consume-a-character-reference
// i starts at 1 because we already know that s[0] == '&'.
i, s := 1, b[src:]
if len(s) <= 1 {
b[dst] = b[src]
return dst + 1, src + 1
}
if s[i] == '#' {
if len(s) <= 3 { // We need to have at least "&#.".
b[dst] = b[src]
return dst + 1, src + 1
}
i++
c := s[i]
hex := false
if c == 'x' || c == 'X' {
hex = true
i++
}
x := '\x00'
for i < len(s) {
c = s[i]
i++
if hex {
if '0' <= c && c <= '9' {
x = 16*x + rune(c) - '0'
continue
} else if 'a' <= c && c <= 'f' {
x = 16*x + rune(c) - 'a' + 10
continue
} else if 'A' <= c && c <= 'F' {
x = 16*x + rune(c) - 'A' + 10
continue
}
} else if '0' <= c && c <= '9' {
x = 10*x + rune(c) - '0'
continue
}
if c != ';' {
i--
}
break
}
if i <= 3 { // No characters matched.
b[dst] = b[src]
return dst + 1, src + 1
}
if 0x80 <= x && x <= 0x9F {
// Replace characters from Windows-1252 with UTF-8 equivalents.
x = replacementTable[x-0x80]
} else if x == 0 || (0xD800 <= x && x <= 0xDFFF) || x > 0x10FFFF {
// Replace invalid characters with the replacement character.
x = '\uFFFD'
}
return dst + utf8.EncodeRune(b[dst:], x), src + i
}
// Consume the maximum number of characters possible, with the
// consumed characters matching one of the named references.
for i < len(s) {
c := s[i]
i++
// Lower-cased characters are more common in entities, so we check for them first.
if 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z' || '0' <= c && c <= '9' {
continue
}
if c != ';' {
i--
}
break
}
entityName := s[1:i]
if len(entityName) == 0 {
// No-op.
} else if attribute && entityName[len(entityName)-1] != ';' && len(s) > i && s[i] == '=' {
// No-op.
} else if x := entity[string(entityName)]; x != 0 {
return dst + utf8.EncodeRune(b[dst:], x), src + i
} else if x := entity2[string(entityName)]; x[0] != 0 {
dst1 := dst + utf8.EncodeRune(b[dst:], x[0])
return dst1 + utf8.EncodeRune(b[dst1:], x[1]), src + i
} else if !attribute {
maxLen := len(entityName) - 1
if maxLen > longestEntityWithoutSemicolon {
maxLen = longestEntityWithoutSemicolon
}
for j := maxLen; j > 1; j-- {
if x := entity[string(entityName[:j])]; x != 0 {
return dst + utf8.EncodeRune(b[dst:], x), src + j + 1
}
}
}
dst1, src1 = dst+i, src+i
copy(b[dst:dst1], b[src:src1])
return dst1, src1
}
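// unescapeSketch is a minimal illustrative sketch of the lookup order above:
// exact names (with or without a trailing semicolon) are tried first, and
// outside of attribute values a name missing its semicolon falls back to the
// longest matching prefix, so "&notin" decodes as "&not" plus the literal "in".
func unescapeSketch() string {
	return UnescapeString("x &notin; y, x &notin y") // "x ∉ y, x ¬in y"
}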
var htmlEscaper = strings.NewReplacer(
`&`, "&amp;",
`'`, "&#39;", // "&#39;" is shorter than "&apos;" and apos was not in HTML until HTML5.
`<`, "&lt;",
`>`, "&gt;",
`"`, "&#34;", // "&#34;" is shorter than "&quot;".
)
// EscapeString escapes special characters like "<" to become "&lt;". It
// escapes only five such characters: <, >, &, ' and ".
// [UnescapeString](EscapeString(s)) == s always holds, but the converse isn't
// always true.
func EscapeString(s string) string {
return htmlEscaper.Replace(s)
}
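// escapeSketch is a minimal illustrative sketch: EscapeString touches only
// the five characters listed above and leaves everything else alone.
func escapeSketch() string {
	return EscapeString(`<a href="x">&</a>`)
	// "&lt;a href=&#34;x&#34;&gt;&amp;&lt;/a&gt;"
}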
// UnescapeString unescapes entities like "&lt;" to become "<". It unescapes a
// larger range of entities than [EscapeString] escapes. For example, "&aacute;"
// unescapes to "á", as does "&#225;" and "&#xE1;".
// UnescapeString([EscapeString](s)) == s always holds, but the converse isn't
// always true.
func UnescapeString(s string) string {
i := strings.IndexByte(s, '&')
if i < 0 {
return s
}
b := []byte(s)
entity, entity2 := entityMaps()
dst, src := unescapeEntity(b, i, i, entity, entity2)
for len(s[src:]) > 0 {
if s[src] == '&' {
i = 0
} else {
i = strings.IndexByte(s[src:], '&')
}
if i < 0 {
dst += copy(b[dst:], s[src:])
break
}
if i > 0 {
copy(b[dst:], s[src:src+i])
}
dst, src = unescapeEntity(b, dst+i, src+i, entity, entity2)
}
return string(b[:dst])
}
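// unescapeFormsSketch is a minimal illustrative sketch: named, decimal, and
// hexadecimal references all decode to the same rune, and a known name
// without its trailing semicolon is still recognized.
func unescapeFormsSketch() string {
	return UnescapeString("&aacute; &#225; &#xE1; &amp") // "á á á &"
}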
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package template
import (
"strings"
)
// attrTypeMap[n] describes the value of the given attribute.
// If an attribute affects (or can mask) the encoding or interpretation of
// other content, or affects the contents, idempotency, or credentials of a
// network message, then the value in this map is contentTypeUnsafe.
// This map is derived from HTML5, specifically
// https://www.w3.org/TR/html5/Overview.html#attributes-1
// as well as "%URI"-typed attributes from
// https://www.w3.org/TR/html4/index/attributes.html
var attrTypeMap = map[string]contentType{
"accept": contentTypePlain,
"accept-charset": contentTypeUnsafe,
"action": contentTypeURL,
"alt": contentTypePlain,
"archive": contentTypeURL,
"async": contentTypeUnsafe,
"autocomplete": contentTypePlain,
"autofocus": contentTypePlain,
"autoplay": contentTypePlain,
"background": contentTypeURL,
"border": contentTypePlain,
"checked": contentTypePlain,
"cite": contentTypeURL,
"challenge": contentTypeUnsafe,
"charset": contentTypeUnsafe,
"class": contentTypePlain,
"classid": contentTypeURL,
"codebase": contentTypeURL,
"cols": contentTypePlain,
"colspan": contentTypePlain,
"content": contentTypeUnsafe,
"contenteditable": contentTypePlain,
"contextmenu": contentTypePlain,
"controls": contentTypePlain,
"coords": contentTypePlain,
"crossorigin": contentTypeUnsafe,
"data": contentTypeURL,
"datetime": contentTypePlain,
"default": contentTypePlain,
"defer": contentTypeUnsafe,
"dir": contentTypePlain,
"dirname": contentTypePlain,
"disabled": contentTypePlain,
"draggable": contentTypePlain,
"dropzone": contentTypePlain,
"enctype": contentTypeUnsafe,
"for": contentTypePlain,
"form": contentTypeUnsafe,
"formaction": contentTypeURL,
"formenctype": contentTypeUnsafe,
"formmethod": contentTypeUnsafe,
"formnovalidate": contentTypeUnsafe,
"formtarget": contentTypePlain,
"headers": contentTypePlain,
"height": contentTypePlain,
"hidden": contentTypePlain,
"high": contentTypePlain,
"href": contentTypeURL,
"hreflang": contentTypePlain,
"http-equiv": contentTypeUnsafe,
"icon": contentTypeURL,
"id": contentTypePlain,
"ismap": contentTypePlain,
"keytype": contentTypeUnsafe,
"kind": contentTypePlain,
"label": contentTypePlain,
"lang": contentTypePlain,
"language": contentTypeUnsafe,
"list": contentTypePlain,
"longdesc": contentTypeURL,
"loop": contentTypePlain,
"low": contentTypePlain,
"manifest": contentTypeURL,
"max": contentTypePlain,
"maxlength": contentTypePlain,
"media": contentTypePlain,
"mediagroup": contentTypePlain,
"method": contentTypeUnsafe,
"min": contentTypePlain,
"multiple": contentTypePlain,
"name": contentTypePlain,
"novalidate": contentTypeUnsafe,
// Skip handler names from
// https://www.w3.org/TR/html5/webappapis.html#event-handlers-on-elements,-document-objects,-and-window-objects
// since we have special handling in attrType.
"open": contentTypePlain,
"optimum": contentTypePlain,
"pattern": contentTypeUnsafe,
"placeholder": contentTypePlain,
"poster": contentTypeURL,
"profile": contentTypeURL,
"preload": contentTypePlain,
"pubdate": contentTypePlain,
"radiogroup": contentTypePlain,
"readonly": contentTypePlain,
"rel": contentTypeUnsafe,
"required": contentTypePlain,
"reversed": contentTypePlain,
"rows": contentTypePlain,
"rowspan": contentTypePlain,
"sandbox": contentTypeUnsafe,
"spellcheck": contentTypePlain,
"scope": contentTypePlain,
"scoped": contentTypePlain,
"seamless": contentTypePlain,
"selected": contentTypePlain,
"shape": contentTypePlain,
"size": contentTypePlain,
"sizes": contentTypePlain,
"span": contentTypePlain,
"src": contentTypeURL,
"srcdoc": contentTypeHTML,
"srclang": contentTypePlain,
"srcset": contentTypeSrcset,
"start": contentTypePlain,
"step": contentTypePlain,
"style": contentTypeCSS,
"tabindex": contentTypePlain,
"target": contentTypePlain,
"title": contentTypePlain,
"type": contentTypeUnsafe,
"usemap": contentTypeURL,
"value": contentTypeUnsafe,
"width": contentTypePlain,
"wrap": contentTypePlain,
"xmlns": contentTypeURL,
}
// attrType returns a conservative (upper-bound on authority) guess at the
// type of the lowercase named attribute.
func attrType(name string) contentType {
if strings.HasPrefix(name, "data-") {
// Strip data- so that custom attribute heuristics below are
// widely applied.
// Treat data-action as URL below.
name = name[5:]
} else if prefix, short, ok := strings.Cut(name, ":"); ok {
if prefix == "xmlns" {
return contentTypeURL
}
// Treat svg:href and xlink:href as href below.
name = short
}
if t, ok := attrTypeMap[name]; ok {
return t
}
// Treat partial event handler names as script.
if strings.HasPrefix(name, "on") {
return contentTypeJS
}
// Heuristics to prevent "javascript:..." injection in custom
// data attributes and custom attributes like g:tweetUrl.
// https://www.w3.org/TR/html5/dom.html#embedding-custom-non-visible-data-with-the-data-*-attributes
// "Custom data attributes are intended to store custom data
// private to the page or application, for which there are no
// more appropriate attributes or elements."
// Developers seem to store URL content in data URLs that start
// or end with "URI" or "URL".
if strings.Contains(name, "src") ||
strings.Contains(name, "uri") ||
strings.Contains(name, "url") {
return contentTypeURL
}
return contentTypePlain
}
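// attrTypeSketch is a minimal illustrative sketch of the rules above: known
// names come from attrTypeMap, a "data-" prefix is stripped before lookup,
// unknown "on*" names are treated as event-handler script, and unknown names
// containing "src", "uri", or "url" are treated as URLs.
func attrTypeSketch() []contentType {
	return []contentType{
		attrType("href"),          // contentTypeURL, from attrTypeMap
		attrType("data-href"),     // contentTypeURL, "data-" stripped first
		attrType("onclick"),       // contentTypeJS, "on" prefix
		attrType("data-tweeturl"), // contentTypeURL, contains "url"
		attrType("title"),         // contentTypePlain, from attrTypeMap
	}
}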
// Code generated by "stringer -type attr"; DO NOT EDIT.
package template
import "strconv"
func _() {
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
var x [1]struct{}
_ = x[attrNone-0]
_ = x[attrScript-1]
_ = x[attrScriptType-2]
_ = x[attrStyle-3]
_ = x[attrURL-4]
_ = x[attrSrcset-5]
}
const _attr_name = "attrNoneattrScriptattrScriptTypeattrStyleattrURLattrSrcset"
var _attr_index = [...]uint8{0, 8, 18, 32, 41, 48, 58}
func (i attr) String() string {
if i >= attr(len(_attr_index)-1) {
return "attr(" + strconv.FormatInt(int64(i), 10) + ")"
}
return _attr_name[_attr_index[i]:_attr_index[i+1]]
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package template
import (
"fmt"
"reflect"
)
// Strings of content from a trusted source.
type (
// CSS encapsulates known safe content that matches any of:
// 1. The CSS3 stylesheet production, such as `p { color: purple }`.
// 2. The CSS3 rule production, such as `a[href=~"https:"].foo#bar`.
// 3. CSS3 declaration productions, such as `color: red; margin: 2px`.
// 4. The CSS3 value production, such as `rgba(0, 0, 255, 127)`.
// See https://www.w3.org/TR/css3-syntax/#parsing and
// https://web.archive.org/web/20090211114933/http://w3.org/TR/css3-syntax#style
//
// Use of this type presents a security risk:
// the encapsulated content should come from a trusted source,
// as it will be included verbatim in the template output.
CSS string
// HTML encapsulates a known safe HTML document fragment.
// It should not be used for HTML from a third-party, or HTML with
// unclosed tags or comments. The outputs of a sound HTML sanitizer
// and a template escaped by this package are fine for use with HTML.
//
// Use of this type presents a security risk:
// the encapsulated content should come from a trusted source,
// as it will be included verbatim in the template output.
HTML string
// HTMLAttr encapsulates an HTML attribute from a trusted source,
// for example, ` dir="ltr"`.
//
// Use of this type presents a security risk:
// the encapsulated content should come from a trusted source,
// as it will be included verbatim in the template output.
HTMLAttr string
// JS encapsulates a known safe EcmaScript5 Expression, for example,
// `(x + y * z())`.
// Template authors are responsible for ensuring that typed expressions
// do not break the intended precedence and that there is no
// statement/expression ambiguity as when passing an expression like
// "{ foo: bar() }\n['foo']()", which is both a valid Expression and a
// valid Program with a very different meaning.
//
// Use of this type presents a security risk:
// the encapsulated content should come from a trusted source,
// as it will be included verbatim in the template output.
//
// Using JS to include valid but untrusted JSON is not safe.
// A safe alternative is to parse the JSON with json.Unmarshal and then
// pass the resultant object into the template, where it will be
// converted to sanitized JSON when presented in a JavaScript context.
JS string
// JSStr encapsulates a sequence of characters meant to be embedded
// between quotes in a JavaScript expression.
// The string must match a series of StringCharacters:
// StringCharacter :: SourceCharacter but not `\` or LineTerminator
// | EscapeSequence
// Note that LineContinuations are not allowed.
// JSStr("foo\\nbar") is fine, but JSStr("foo\\\nbar") is not.
//
// Use of this type presents a security risk:
// the encapsulated content should come from a trusted source,
// as it will be included verbatim in the template output.
JSStr string
// URL encapsulates a known safe URL or URL substring (see RFC 3986).
// A URL like `javascript:checkThatFormNotEditedBeforeLeavingPage()`
// from a trusted source should go in the page, but by default dynamic
// `javascript:` URLs are filtered out since they are a frequently
// exploited injection vector.
//
// Use of this type presents a security risk:
// the encapsulated content should come from a trusted source,
// as it will be included verbatim in the template output.
URL string
// Srcset encapsulates a known safe srcset attribute
// (see https://w3c.github.io/html/semantics-embedded-content.html#element-attrdef-img-srcset).
//
// Use of this type presents a security risk:
// the encapsulated content should come from a trusted source,
// as it will be included verbatim in the template output.
Srcset string
)
type contentType uint8
const (
contentTypePlain contentType = iota
contentTypeCSS
contentTypeHTML
contentTypeHTMLAttr
contentTypeJS
contentTypeJSStr
contentTypeURL
contentTypeSrcset
// contentTypeUnsafe is used in attr.go for values that affect how
// embedded content and network messages are formed, vetted,
// or interpreted; or which credentials network messages carry.
contentTypeUnsafe
)
// indirect returns the value, after dereferencing as many times
// as necessary to reach the base type (or nil).
func indirect(a any) any {
if a == nil {
return nil
}
if t := reflect.TypeOf(a); t.Kind() != reflect.Pointer {
// Avoid creating a reflect.Value if it's not a pointer.
return a
}
v := reflect.ValueOf(a)
for v.Kind() == reflect.Pointer && !v.IsNil() {
v = v.Elem()
}
return v.Interface()
}
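// indirectSketch is a minimal illustrative sketch: indirect follows any
// number of pointers down to the underlying value.
func indirectSketch() any {
	s := "hi"
	p := &s
	return indirect(&p) // "hi" (a **string dereferenced twice)
}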
var (
errorType = reflect.TypeFor[error]()
fmtStringerType = reflect.TypeFor[fmt.Stringer]()
)
// indirectToStringerOrError returns the value, after dereferencing as many times
// as necessary to reach the base type (or nil) or an implementation of fmt.Stringer
// or error.
func indirectToStringerOrError(a any) any {
if a == nil {
return nil
}
v := reflect.ValueOf(a)
for !v.Type().Implements(fmtStringerType) && !v.Type().Implements(errorType) && v.Kind() == reflect.Pointer && !v.IsNil() {
v = v.Elem()
}
return v.Interface()
}
// stringify converts its arguments to a string and the type of the content.
// All pointers are dereferenced, as in the text/template package.
func stringify(args ...any) (string, contentType) {
if len(args) == 1 {
switch s := indirect(args[0]).(type) {
case string:
return s, contentTypePlain
case CSS:
return string(s), contentTypeCSS
case HTML:
return string(s), contentTypeHTML
case HTMLAttr:
return string(s), contentTypeHTMLAttr
case JS:
return string(s), contentTypeJS
case JSStr:
return string(s), contentTypeJSStr
case URL:
return string(s), contentTypeURL
case Srcset:
return string(s), contentTypeSrcset
}
}
i := 0
for _, arg := range args {
// We skip untyped nil arguments for backward compatibility.
// Without this they would be output as <nil>, escaped.
// See issue 25875.
if arg == nil {
continue
}
args[i] = indirectToStringerOrError(arg)
i++
}
return fmt.Sprint(args[:i]...), contentTypePlain
}
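// As a rough illustration of the behavior above, stringify preserves the
// content type of a single typed argument and otherwise falls back to plain
// text:
//
//	stringify(HTML("<b>bold</b>")) // "<b>bold</b>", contentTypeHTML
//	stringify(1, 2)                // "1 2", contentTypePlain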
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package template
import (
"fmt"
"text/template/parse"
)
// context describes the state an HTML parser must be in when it reaches the
// portion of HTML produced by evaluating a particular template node.
//
// The zero value of type context is the start context for a template that
// produces an HTML fragment as defined at
// https://www.w3.org/TR/html5/syntax.html#the-end
// where the context element is null.
type context struct {
state state
delim delim
urlPart urlPart
jsCtx jsCtx
// jsBraceDepth contains the current depth, for each JS template literal
// string interpolation expression, of braces we've seen. This is used to
// determine if the next } will close a JS template literal string
// interpolation expression or not.
jsBraceDepth []int
attr attr
element element
n parse.Node // for range break/continue
err *Error
}
func (c context) String() string {
var err error
if c.err != nil {
err = c.err
}
return fmt.Sprintf("{%v %v %v %v %v %v %v}", c.state, c.delim, c.urlPart, c.jsCtx, c.attr, c.element, err)
}
// eq reports whether two contexts are equal.
func (c context) eq(d context) bool {
return c.state == d.state &&
c.delim == d.delim &&
c.urlPart == d.urlPart &&
c.jsCtx == d.jsCtx &&
c.attr == d.attr &&
c.element == d.element &&
c.err == d.err
}
// mangle produces an identifier that includes a suffix that distinguishes it
// from template names mangled with different contexts.
func (c context) mangle(templateName string) string {
// The mangled name for the default context is the input templateName.
if c.state == stateText {
return templateName
}
s := templateName + "$htmltemplate_" + c.state.String()
if c.delim != delimNone {
s += "_" + c.delim.String()
}
if c.urlPart != urlPartNone {
s += "_" + c.urlPart.String()
}
if c.jsCtx != jsCtxRegexp {
s += "_" + c.jsCtx.String()
}
if c.attr != attrNone {
s += "_" + c.attr.String()
}
if c.element != elementNone {
s += "_" + c.element.String()
}
return s
}
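// For example, a template reached in the default text context keeps its name,
// while one reached inside a <script> element should get a suffixed name
// (illustrative sketch):
//
//	context{state: stateText}.mangle("T") // "T"
//	context{state: stateJS}.mangle("T")   // "T$htmltemplate_stateJS"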
// state describes a high-level HTML parser state.
//
// It bounds the top of the element stack, and by extension the HTML insertion
// mode, but also contains state that does not correspond to anything in the
// HTML5 parsing algorithm because a single token production in the HTML
// grammar may contain embedded actions in a template. For instance, the quoted
// HTML attribute produced by
//
// <div title="Hello {{.World}}">
//
// is a single token in HTML's grammar but in a template spans several nodes.
type state uint8
//go:generate stringer -type state
const (
// stateText is parsed character data. An HTML parser is in
// this state when its parse position is outside an HTML tag,
// directive, comment, and special element body.
stateText state = iota
// stateTag occurs before an HTML attribute or the end of a tag.
stateTag
// stateAttrName occurs inside an attribute name.
// It occurs between the ^'s in ` ^name^ = value`.
stateAttrName
// stateAfterName occurs after an attr name has ended but before any
// equals sign. It occurs between the ^'s in ` name^ ^= value`.
stateAfterName
// stateBeforeValue occurs after the equals sign but before the value.
// It occurs between the ^'s in ` name =^ ^value`.
stateBeforeValue
// stateHTMLCmt occurs inside an <!-- HTML comment -->.
stateHTMLCmt
// stateRCDATA occurs inside an RCDATA element (<textarea> or <title>)
// as described at https://www.w3.org/TR/html5/syntax.html#elements-0
stateRCDATA
// stateAttr occurs inside an HTML attribute whose content is text.
stateAttr
// stateURL occurs inside an HTML attribute whose content is a URL.
stateURL
// stateSrcset occurs inside an HTML srcset attribute.
stateSrcset
// stateJS occurs inside an event handler or script element.
stateJS
// stateJSDqStr occurs inside a JavaScript double quoted string.
stateJSDqStr
// stateJSSqStr occurs inside a JavaScript single quoted string.
stateJSSqStr
// stateJSTmplLit occurs inside a JavaScript back quoted string.
stateJSTmplLit
// stateJSRegexp occurs inside a JavaScript regexp literal.
stateJSRegexp
// stateJSBlockCmt occurs inside a JavaScript /* block comment */.
stateJSBlockCmt
// stateJSLineCmt occurs inside a JavaScript // line comment.
stateJSLineCmt
// stateJSHTMLOpenCmt occurs inside a JavaScript <!-- HTML-like comment.
stateJSHTMLOpenCmt
// stateJSHTMLCloseCmt occurs inside a JavaScript --> HTML-like comment.
stateJSHTMLCloseCmt
// stateCSS occurs inside a <style> element or style attribute.
stateCSS
// stateCSSDqStr occurs inside a CSS double quoted string.
stateCSSDqStr
// stateCSSSqStr occurs inside a CSS single quoted string.
stateCSSSqStr
// stateCSSDqURL occurs inside a CSS double quoted url("...").
stateCSSDqURL
// stateCSSSqURL occurs inside a CSS single quoted url('...').
stateCSSSqURL
// stateCSSURL occurs inside a CSS unquoted url(...).
stateCSSURL
// stateCSSBlockCmt occurs inside a CSS /* block comment */.
stateCSSBlockCmt
// stateCSSLineCmt occurs inside a CSS // line comment.
stateCSSLineCmt
// stateError is an infectious error state outside any valid
// HTML/CSS/JS construct.
stateError
// stateDead marks unreachable code after a {{break}} or {{continue}}.
stateDead
)
// isComment is true for any state that contains content meant for template
// authors & maintainers, not for end-users or machines.
func isComment(s state) bool {
switch s {
case stateHTMLCmt, stateJSBlockCmt, stateJSLineCmt, stateJSHTMLOpenCmt, stateJSHTMLCloseCmt, stateCSSBlockCmt, stateCSSLineCmt:
return true
}
return false
}
// isInTag reports whether s occurs solely inside an HTML tag.
func isInTag(s state) bool {
switch s {
case stateTag, stateAttrName, stateAfterName, stateBeforeValue, stateAttr:
return true
}
return false
}
// isInScriptLiteral returns true if s is one of the literal states within a
// <script> tag, and as such occurrences of "<!--", "<script", and "</script"
// need to be treated specially.
func isInScriptLiteral(s state) bool {
// Ignore the comment states (stateJSBlockCmt, stateJSLineCmt,
// stateJSHTMLOpenCmt, stateJSHTMLCloseCmt) because their content is already
// omitted from the output.
switch s {
case stateJSDqStr, stateJSSqStr, stateJSTmplLit, stateJSRegexp:
return true
}
return false
}
// delim is the delimiter that will end the current HTML attribute.
type delim uint8
//go:generate stringer -type delim
const (
// delimNone occurs outside any attribute.
delimNone delim = iota
// delimDoubleQuote occurs when a double quote (") closes the attribute.
delimDoubleQuote
// delimSingleQuote occurs when a single quote (') closes the attribute.
delimSingleQuote
// delimSpaceOrTagEnd occurs when a space or right angle bracket (>)
// closes the attribute.
delimSpaceOrTagEnd
)
// urlPart identifies a part in an RFC 3986 hierarchical URL to allow different
// encoding strategies.
type urlPart uint8
//go:generate stringer -type urlPart
const (
// urlPartNone occurs when not in a URL, or possibly at the start:
// ^ in "^http://auth/path?k=v#frag".
urlPartNone urlPart = iota
// urlPartPreQuery occurs in the scheme, authority, or path; between the
// ^s in "h^ttp://auth/path^?k=v#frag".
urlPartPreQuery
// urlPartQueryOrFrag occurs in the query portion between the ^s in
// "http://auth/path?^k=v#frag^".
urlPartQueryOrFrag
// urlPartUnknown occurs due to joining of contexts both before and
// after the query separator.
urlPartUnknown
)
// jsCtx determines whether a '/' starts a regular expression literal or a
// division operator.
type jsCtx uint8
//go:generate stringer -type jsCtx
const (
// jsCtxRegexp occurs where a '/' would start a regexp literal.
jsCtxRegexp jsCtx = iota
// jsCtxDivOp occurs where a '/' would start a division operator.
jsCtxDivOp
// jsCtxUnknown occurs where a '/' is ambiguous due to context joining.
jsCtxUnknown
)
// element identifies the HTML element when inside a start tag or special body.
// Certain HTML elements (for example <script> and <style>) have bodies that
// are treated differently from stateText, so the element type is necessary to
// transition into the correct context at the end of a tag and to identify the
// end delimiter for the body.
type element uint8
//go:generate stringer -type element
const (
// elementNone occurs outside a special tag or special element body.
elementNone element = iota
// elementScript corresponds to the raw text <script> element
// with JS MIME type or no type attribute.
elementScript
// elementStyle corresponds to the raw text <style> element.
elementStyle
// elementTextarea corresponds to the RCDATA <textarea> element.
elementTextarea
// elementTitle corresponds to the RCDATA <title> element.
elementTitle
)
//go:generate stringer -type attr
// attr identifies the current HTML attribute when inside the attribute,
// that is, starting from stateAttrName until stateTag/stateText (exclusive).
type attr uint8
const (
// attrNone corresponds to a normal attribute or no attribute.
attrNone attr = iota
// attrScript corresponds to an event handler attribute.
attrScript
// attrScriptType corresponds to the type attribute of the script HTML element.
attrScriptType
// attrStyle corresponds to the style attribute whose value is CSS.
attrStyle
// attrURL corresponds to an attribute whose value is a URL.
attrURL
// attrSrcset corresponds to a srcset attribute.
attrSrcset
)
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package template
import (
"bytes"
"fmt"
"strings"
"unicode"
"unicode/utf8"
)
// endsWithCSSKeyword reports whether b ends with an ident that
// case-insensitively matches the lower-case kw.
func endsWithCSSKeyword(b []byte, kw string) bool {
i := len(b) - len(kw)
if i < 0 {
// Too short.
return false
}
if i != 0 {
r, _ := utf8.DecodeLastRune(b[:i])
if isCSSNmchar(r) {
// Too long.
return false
}
}
// Many CSS keywords, such as "!important" can have characters encoded,
// but the URI production does not allow that according to
// https://www.w3.org/TR/css3-syntax/#TOK-URI
// This does not attempt to recognize encoded keywords. For example,
// given "\75\72\6c" and "url" this returns false.
return string(bytes.ToLower(b[i:])) == kw
}
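// A brief illustration of the checks above:
//
//	endsWithCSSKeyword([]byte("background: URL"), "url") // true (case-insensitive match)
//	endsWithCSSKeyword([]byte("curl"), "url")            // false ('c' is an nmchar, so the ident is too long)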
// isCSSNmchar reports whether rune is allowed anywhere in a CSS identifier.
func isCSSNmchar(r rune) bool {
// Based on the CSS3 nmchar production but ignores multi-rune escape
// sequences.
// https://www.w3.org/TR/css3-syntax/#SUBTOK-nmchar
return 'a' <= r && r <= 'z' ||
'A' <= r && r <= 'Z' ||
'0' <= r && r <= '9' ||
r == '-' ||
r == '_' ||
// Non-ASCII cases below.
0x80 <= r && r <= 0xd7ff ||
0xe000 <= r && r <= 0xfffd ||
0x10000 <= r && r <= 0x10ffff
}
// decodeCSS decodes CSS3 escapes given a sequence of stringchars.
// If there is no change, it returns the input, otherwise it returns a slice
// backed by a new array.
// https://www.w3.org/TR/css3-syntax/#SUBTOK-stringchar defines stringchar.
func decodeCSS(s []byte) []byte {
i := bytes.IndexByte(s, '\\')
if i == -1 {
return s
}
// The UTF-8 sequence for a codepoint is never longer than 1 + the number
// of hex digits needed to represent that codepoint, so len(s) is an
// upper bound on the output length.
b := make([]byte, 0, len(s))
for len(s) != 0 {
i := bytes.IndexByte(s, '\\')
if i == -1 {
i = len(s)
}
b, s = append(b, s[:i]...), s[i:]
if len(s) < 2 {
break
}
// https://www.w3.org/TR/css3-syntax/#SUBTOK-escape
// escape ::= unicode | '\' [#x20-#x7E#x80-#xD7FF#xE000-#xFFFD#x10000-#x10FFFF]
if isHex(s[1]) {
// https://www.w3.org/TR/css3-syntax/#SUBTOK-unicode
// unicode ::= '\' [0-9a-fA-F]{1,6} wc?
j := 2
for j < len(s) && j < 7 && isHex(s[j]) {
j++
}
r := hexDecode(s[1:j])
if r > unicode.MaxRune {
r, j = r/16, j-1
}
n := utf8.EncodeRune(b[len(b):cap(b)], r)
// The optional space at the end allows a hex
// sequence to be followed by a literal hex.
// string(decodeCSS([]byte(`\A B`))) == "\nB"
b, s = b[:len(b)+n], skipCSSSpace(s[j:])
} else {
// `\\` decodes to `\` and `\"` to `"`.
_, n := utf8.DecodeRune(s[1:])
b, s = append(b, s[1:1+n]...), s[1+n:]
}
}
return b
}
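// For illustration, a hex escape decodes to the corresponding rune and
// consumes one optional trailing space (rough sketch):
//
//	string(decodeCSS([]byte(`\75rl`)))   // "url"
//	string(decodeCSS([]byte(`\22 foo`))) // `"foo`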
// isHex reports whether the given character is a hex digit.
func isHex(c byte) bool {
return '0' <= c && c <= '9' || 'a' <= c && c <= 'f' || 'A' <= c && c <= 'F'
}
// hexDecode decodes a short hex digit sequence: "10" -> 16.
func hexDecode(s []byte) rune {
n := '\x00'
for _, c := range s {
n <<= 4
switch {
case '0' <= c && c <= '9':
n |= rune(c - '0')
case 'a' <= c && c <= 'f':
n |= rune(c-'a') + 10
case 'A' <= c && c <= 'F':
n |= rune(c-'A') + 10
default:
panic(fmt.Sprintf("Bad hex digit in %q", s))
}
}
return n
}
// skipCSSSpace returns a suffix of c, skipping over a single space.
func skipCSSSpace(c []byte) []byte {
if len(c) == 0 {
return c
}
// wc ::= #x9 | #xA | #xC | #xD | #x20
switch c[0] {
case '\t', '\n', '\f', ' ':
return c[1:]
case '\r':
// This differs from CSS3's wc production because it contains a
// probable spec error whereby wc contains all the single byte
// sequences in nl (newline) but not CRLF.
if len(c) >= 2 && c[1] == '\n' {
return c[2:]
}
return c[1:]
}
return c
}
// isCSSSpace reports whether b is a CSS space char as defined in wc.
func isCSSSpace(b byte) bool {
switch b {
case '\t', '\n', '\f', '\r', ' ':
return true
}
return false
}
// cssEscaper escapes HTML and CSS special characters using \<hex>+ escapes.
func cssEscaper(args ...any) string {
s, _ := stringify(args...)
var b strings.Builder
r, w, written := rune(0), 0, 0
for i := 0; i < len(s); i += w {
// See comment in htmlEscaper.
r, w = utf8.DecodeRuneInString(s[i:])
var repl string
switch {
case int(r) < len(cssReplacementTable) && cssReplacementTable[r] != "":
repl = cssReplacementTable[r]
default:
continue
}
if written == 0 {
b.Grow(len(s))
}
b.WriteString(s[written:i])
b.WriteString(repl)
written = i + w
if repl != `\\` && (written == len(s) || isHex(s[written]) || isCSSSpace(s[written])) {
b.WriteByte(' ')
}
}
if written == 0 {
return s
}
b.WriteString(s[written:])
return b.String()
}
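// Roughly, given the replacement table below, a space is appended after an
// escape whenever the next character could otherwise extend the hex sequence
// (illustrative sketch):
//
//	cssEscaper(`a<b"c`) // `a\3c b\22 c`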
var cssReplacementTable = []string{
0: `\0`,
'\t': `\9`,
'\n': `\a`,
'\f': `\c`,
'\r': `\d`,
// Encode HTML specials as hex so the output can be embedded
// in HTML attributes without further encoding.
'"': `\22`,
'&': `\26`,
'\'': `\27`,
'(': `\28`,
')': `\29`,
'+': `\2b`,
'/': `\2f`,
':': `\3a`,
';': `\3b`,
'<': `\3c`,
'>': `\3e`,
'\\': `\\`,
'{': `\7b`,
'}': `\7d`,
}
var expressionBytes = []byte("expression")
var mozBindingBytes = []byte("mozbinding")
// cssValueFilter allows innocuous CSS values in the output including CSS
// quantities (10px or 25%), ID or class literals (#foo, .bar), keyword values
// (inherit, blue), and colors (#888).
// It filters out unsafe values, such as those that affect token boundaries,
// and anything that might execute scripts.
func cssValueFilter(args ...any) string {
s, t := stringify(args...)
if t == contentTypeCSS {
return s
}
b, id := decodeCSS([]byte(s)), make([]byte, 0, 64)
// CSS3 error handling is specified as honoring string boundaries per
// https://www.w3.org/TR/css3-syntax/#error-handling :
// Malformed declarations. User agents must handle unexpected
// tokens encountered while parsing a declaration by reading until
// the end of the declaration, while observing the rules for
// matching pairs of (), [], {}, "", and '', and correctly handling
// escapes. For example, a malformed declaration may be missing a
// property, colon (:) or value.
// So we need to make sure that values do not have mismatched bracket
// or quote characters to prevent the browser from restarting parsing
// inside a string that might embed JavaScript source.
for i, c := range b {
switch c {
case 0, '"', '\'', '(', ')', '/', ';', '@', '[', '\\', ']', '`', '{', '}', '<', '>':
return filterFailsafe
case '-':
// Disallow <!-- or -->.
// -- should not appear in valid identifiers.
if i != 0 && b[i-1] == '-' {
return filterFailsafe
}
default:
if c < utf8.RuneSelf && isCSSNmchar(rune(c)) {
id = append(id, c)
}
}
}
id = bytes.ToLower(id)
if bytes.Contains(id, expressionBytes) || bytes.Contains(id, mozBindingBytes) {
return filterFailsafe
}
return string(b)
}
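// For example (illustrative sketch):
//
//	cssValueFilter("10px")                    // "10px"
//	cssValueFilter("expression(alert(1337))") // "ZgotmplZ": '(' and ')' break token boundaries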
// Code generated by "stringer -type delim"; DO NOT EDIT.
package template
import "strconv"
func _() {
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
var x [1]struct{}
_ = x[delimNone-0]
_ = x[delimDoubleQuote-1]
_ = x[delimSingleQuote-2]
_ = x[delimSpaceOrTagEnd-3]
}
const _delim_name = "delimNonedelimDoubleQuotedelimSingleQuotedelimSpaceOrTagEnd"
var _delim_index = [...]uint8{0, 9, 25, 41, 59}
func (i delim) String() string {
if i >= delim(len(_delim_index)-1) {
return "delim(" + strconv.FormatInt(int64(i), 10) + ")"
}
return _delim_name[_delim_index[i]:_delim_index[i+1]]
}
// Code generated by "stringer -type element"; DO NOT EDIT.
package template
import "strconv"
func _() {
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
var x [1]struct{}
_ = x[elementNone-0]
_ = x[elementScript-1]
_ = x[elementStyle-2]
_ = x[elementTextarea-3]
_ = x[elementTitle-4]
}
const _element_name = "elementNoneelementScriptelementStyleelementTextareaelementTitle"
var _element_index = [...]uint8{0, 11, 24, 36, 51, 63}
func (i element) String() string {
if i >= element(len(_element_index)-1) {
return "element(" + strconv.FormatInt(int64(i), 10) + ")"
}
return _element_name[_element_index[i]:_element_index[i+1]]
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package template
import (
"fmt"
"text/template/parse"
)
// Error describes a problem encountered during template Escaping.
type Error struct {
// ErrorCode describes the kind of error.
ErrorCode ErrorCode
// Node is the node that caused the problem, if known.
// If not nil, it overrides Name and Line.
Node parse.Node
// Name is the name of the template in which the error was encountered.
Name string
// Line is the line number of the error in the template source or 0.
Line int
// Description is a human-readable description of the problem.
Description string
}
// ErrorCode is a code for a kind of error.
type ErrorCode int
// We define codes for each error that manifests while escaping templates, but
// escaped templates may also fail at runtime.
//
// Output: "ZgotmplZ"
// Example:
//
// <img src="{{.X}}">
// where {{.X}} evaluates to `javascript:...`
//
// Discussion:
//
// "ZgotmplZ" is a special value that indicates that unsafe content reached a
// CSS or URL context at runtime. The output of the example will be
// <img src="#ZgotmplZ">
// If the data comes from a trusted source, use content types to exempt it
// from filtering: URL(`javascript:...`).
const (
// OK indicates the lack of an error.
OK ErrorCode = iota
// ErrAmbigContext: "... appears in an ambiguous context within a URL"
// Example:
// <a href="
// {{if .C}}
// /path/
// {{else}}
// /search?q=
// {{end}}
// {{.X}}
// ">
// Discussion:
// {{.X}} is in an ambiguous URL context since, depending on {{.C}},
// it may be either a URL suffix or a query parameter.
// Moving {{.X}} into the condition removes the ambiguity:
// <a href="{{if .C}}/path/{{.X}}{{else}}/search?q={{.X}}">
ErrAmbigContext
// ErrBadHTML: "expected space, attr name, or end of tag, but got ...",
// "... in unquoted attr", "... in attribute name"
// Example:
// <a href = /search?q=foo>
// <href=foo>
// <form na<e=...>
// <option selected<
// Discussion:
// This is often due to a typo in an HTML element, but some runes
// are banned in tag names, attribute names, and unquoted attribute
// values because they can tickle parser ambiguities.
// Quoting all attributes is the best policy.
ErrBadHTML
// ErrBranchEnd: "{{if}} branches end in different contexts"
// Examples:
// {{if .C}}<a href="{{end}}{{.X}}
// <script {{with .T}}type="{{.}}"{{end}}>
// Discussion:
// Package html/template statically examines each path through an
// {{if}}, {{range}}, or {{with}} to escape any following pipelines.
// The first example is ambiguous since {{.X}} might be an HTML text node,
// or a URL prefix in an HTML attribute. The context of {{.X}} is
// used to figure out how to escape it, but that context depends on
// the run-time value of {{.C}} which is not statically known.
// The second example is ambiguous as the script type attribute
// can change the type of escaping needed for the script contents.
//
// The problem is usually something like missing quotes or angle
// brackets, or can be avoided by refactoring to put the two contexts
// into different branches of an if, range or with. If the problem
// is in a {{range}} over a collection that should never be empty,
// adding a dummy {{else}} can help.
ErrBranchEnd
// ErrEndContext: "... ends in a non-text context: ..."
// Examples:
// <div
// <div title="no close quote>
// <script>f()
// Discussion:
// Executed templates should produce a DocumentFragment of HTML.
// Templates that end without closing tags will trigger this error.
// Templates that should not be used in an HTML context or that
// produce incomplete Fragments should not be executed directly.
//
// {{define "main"}} <script>{{template "helper"}}</script> {{end}}
// {{define "helper"}} document.write(' <div title=" ') {{end}}
//
// "helper" does not produce a valid document fragment, so should
// not be Executed directly.
ErrEndContext
// ErrNoSuchTemplate: "no such template ..."
// Examples:
// {{define "main"}}<div {{template "attrs"}}>{{end}}
// {{define "attrs"}}href="{{.URL}}"{{end}}
// Discussion:
// Package html/template looks through template calls to compute the
// context.
// Here the {{.URL}} in "attrs" must be treated as a URL when called
// from "main", but you will get this error if "attrs" is not defined
// when "main" is parsed.
ErrNoSuchTemplate
// ErrOutputContext: "cannot compute output context for template ..."
// Examples:
// {{define "t"}}{{if .T}}{{template "t" .T}}{{end}}{{.H}}",{{end}}
// Discussion:
// A recursive template does not end in the same context in which it
// starts, and a reliable output context cannot be computed.
// Look for typos in the named template.
// If the template should not be called in the named start context,
// look for calls to that template in unexpected contexts.
// Maybe refactor recursive templates to not be recursive.
ErrOutputContext
// ErrPartialCharset: "unfinished JS regexp charset in ..."
// Example:
// <script>var pattern = /foo[{{.Chars}}]/</script>
// Discussion:
// Package html/template does not support interpolation into regular
// expression literal character sets.
ErrPartialCharset
// ErrPartialEscape: "unfinished escape sequence in ..."
// Example:
// <script>alert("\{{.X}}")</script>
// Discussion:
// Package html/template does not support actions following a
// backslash.
// This is usually an error and there are better solutions; for
// example
// <script>alert("{{.X}}")</script>
// should work, and if {{.X}} is a partial escape sequence such as
// "xA0", mark the whole sequence as safe content: JSStr(`\xA0`)
ErrPartialEscape
// ErrRangeLoopReentry: "on range loop re-entry: ..."
// Example:
// <script>var x = [{{range .}}'{{.}},{{end}}]</script>
// Discussion:
// If an iteration through a range would cause it to end in a
// different context than an earlier pass, there is no single context.
// In the example, a quote is missing, so it is not clear
// whether {{.}} is meant to be inside a JS string or in a JS value
// context. The second iteration would produce something like
//
// <script>var x = ['firstValue,'secondValue]</script>
ErrRangeLoopReentry
// ErrSlashAmbig: '/' could start a division or regexp.
// Example:
// <script>
// {{if .C}}var x = 1{{end}}
// /-{{.N}}/i.test(x) ? doThis : doThat();
// </script>
// Discussion:
// The example above could produce `var x = 1/-2/i.test(s)...`
// in which the first '/' is a mathematical division operator or it
// could produce `/-2/i.test(s)` in which the first '/' starts a
// regexp literal.
// Look for missing semicolons inside branches, and maybe add
// parentheses to make it clear which interpretation you intend.
ErrSlashAmbig
// ErrPredefinedEscaper: "predefined escaper ... disallowed in template"
// Example:
// <div class={{. | html}}>Hello<div>
// Discussion:
// Package html/template already contextually escapes all pipelines to
// produce HTML output safe against code injection. Manually escaping
// pipeline output using the predefined escapers "html" or "urlquery" is
// unnecessary, and may affect the correctness or safety of the escaped
// pipeline output in Go 1.8 and earlier.
//
// In most cases, such as the given example, this error can be resolved by
// simply removing the predefined escaper from the pipeline and letting the
// contextual autoescaper handle the escaping of the pipeline. In other
// instances, where the predefined escaper occurs in the middle of a
// pipeline where subsequent commands expect escaped input, e.g.
// {{.X | html | makeALink}}
// where makeALink does
// return `<a href="`+input+`">link</a>`
// consider refactoring the surrounding template to make use of the
// contextual autoescaper, i.e.
// <a href="{{.X}}">link</a>
//
// To ease migration to Go 1.9 and beyond, "html" and "urlquery" will
// continue to be allowed as the last command in a pipeline. However, if the
// pipeline occurs in an unquoted attribute value context, "html" is
// disallowed. Avoid using "html" and "urlquery" entirely in new templates.
ErrPredefinedEscaper
// ErrJSTemplate: "... appears in a JS template literal"
// Example:
// <script>var tmpl = `{{.Interp}}`</script>
// Discussion:
// Package html/template does not support actions inside of JS template
// literals.
//
// Deprecated: ErrJSTemplate is no longer returned when an action is present
// in a JS template literal. Actions inside of JS template literals are now
// escaped as expected.
ErrJSTemplate
)
func (e *Error) Error() string {
switch {
case e.Node != nil:
loc, _ := (*parse.Tree)(nil).ErrorContext(e.Node)
return fmt.Sprintf("html/template:%s: %s", loc, e.Description)
case e.Line != 0:
return fmt.Sprintf("html/template:%s:%d: %s", e.Name, e.Line, e.Description)
case e.Name != "":
return fmt.Sprintf("html/template:%s: %s", e.Name, e.Description)
}
return "html/template: " + e.Description
}
// errorf creates an error given a format string f and args.
// The template Name still needs to be supplied.
func errorf(k ErrorCode, node parse.Node, line int, f string, args ...any) *Error {
return &Error{k, node, "", line, fmt.Sprintf(f, args...)}
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package template
import (
"bytes"
"fmt"
"html"
"internal/godebug"
"io"
"maps"
"regexp"
"text/template"
"text/template/parse"
)
// escapeTemplate rewrites the named template, which must be
// associated with t, to guarantee that the output of any of the named
// templates is properly escaped. If no error is returned, then the named templates have
// been modified. Otherwise the named templates have been rendered
// unusable.
func escapeTemplate(tmpl *Template, node parse.Node, name string) error {
c, _ := tmpl.esc.escapeTree(context{}, node, name, 0)
var err error
if c.err != nil {
err, c.err.Name = c.err, name
} else if c.state != stateText {
err = &Error{ErrEndContext, nil, name, 0, fmt.Sprintf("ends in a non-text context: %v", c)}
}
if err != nil {
// Prevent execution of unsafe templates.
if t := tmpl.set[name]; t != nil {
t.escapeErr = err
t.text.Tree = nil
t.Tree = nil
}
return err
}
tmpl.esc.commit()
if t := tmpl.set[name]; t != nil {
t.escapeErr = escapeOK
t.Tree = t.text.Tree
}
return nil
}
// evalArgs formats the list of arguments into a string. It is equivalent to
// fmt.Sprint(args...), except that it dereferences all pointers.
func evalArgs(args ...any) string {
// Optimization for simple common case of a single string argument.
if len(args) == 1 {
if s, ok := args[0].(string); ok {
return s
}
}
for i, arg := range args {
args[i] = indirectToStringerOrError(arg)
}
return fmt.Sprint(args...)
}
// funcMap maps command names to functions that render their inputs safe.
var funcMap = template.FuncMap{
"_html_template_attrescaper": attrEscaper,
"_html_template_commentescaper": commentEscaper,
"_html_template_cssescaper": cssEscaper,
"_html_template_cssvaluefilter": cssValueFilter,
"_html_template_htmlnamefilter": htmlNameFilter,
"_html_template_htmlescaper": htmlEscaper,
"_html_template_jsregexpescaper": jsRegexpEscaper,
"_html_template_jsstrescaper": jsStrEscaper,
"_html_template_jstmpllitescaper": jsTmplLitEscaper,
"_html_template_jsvalescaper": jsValEscaper,
"_html_template_nospaceescaper": htmlNospaceEscaper,
"_html_template_rcdataescaper": rcdataEscaper,
"_html_template_srcsetescaper": srcsetFilterAndEscaper,
"_html_template_urlescaper": urlEscaper,
"_html_template_urlfilter": urlFilter,
"_html_template_urlnormalizer": urlNormalizer,
"_eval_args_": evalArgs,
}
// escaper collects type inferences about templates and changes needed to make
// templates injection safe.
type escaper struct {
// ns is the nameSpace that this escaper is associated with.
ns *nameSpace
// output[templateName] is the output context for a templateName that
// has been mangled to include its input context.
output map[string]context
// derived[c.mangle(name)] maps to a template derived from the template
// named name for the start context c.
derived map[string]*template.Template
// called[templateName] is a set of called mangled template names.
called map[string]bool
// xxxNodeEdits are the accumulated edits to apply during commit.
// Such edits are not applied immediately in case a template set
// executes a given template in different escaping contexts.
actionNodeEdits map[*parse.ActionNode][]string
templateNodeEdits map[*parse.TemplateNode]string
textNodeEdits map[*parse.TextNode][]byte
// rangeContext holds context about the current range loop.
rangeContext *rangeContext
}
// rangeContext holds information about the current range loop.
type rangeContext struct {
outer *rangeContext // outer loop
breaks []context // context at each break action
continues []context // context at each continue action
}
// makeEscaper creates a blank escaper for the given set.
func makeEscaper(n *nameSpace) escaper {
return escaper{
n,
map[string]context{},
map[string]*template.Template{},
map[string]bool{},
map[*parse.ActionNode][]string{},
map[*parse.TemplateNode]string{},
map[*parse.TextNode][]byte{},
nil,
}
}
// filterFailsafe is an innocuous word that is emitted in place of unsafe values
// by sanitizer functions. It is not a keyword in any programming language,
// contains no special characters, is not empty, and when it appears in output
// it is distinct enough that a developer can find the source of the problem
// via a search engine.
const filterFailsafe = "ZgotmplZ"
// escape escapes a template node.
func (e *escaper) escape(c context, n parse.Node) context {
switch n := n.(type) {
case *parse.ActionNode:
return e.escapeAction(c, n)
case *parse.BreakNode:
c.n = n
e.rangeContext.breaks = append(e.rangeContext.breaks, c)
return context{state: stateDead}
case *parse.CommentNode:
return c
case *parse.ContinueNode:
c.n = n
e.rangeContext.continues = append(e.rangeContext.continues, c)
return context{state: stateDead}
case *parse.IfNode:
return e.escapeBranch(c, &n.BranchNode, "if")
case *parse.ListNode:
return e.escapeList(c, n)
case *parse.RangeNode:
return e.escapeBranch(c, &n.BranchNode, "range")
case *parse.TemplateNode:
return e.escapeTemplate(c, n)
case *parse.TextNode:
return e.escapeText(c, n)
case *parse.WithNode:
return e.escapeBranch(c, &n.BranchNode, "with")
}
panic("escaping " + n.String() + " is unimplemented")
}
var debugAllowActionJSTmpl = godebug.New("jstmpllitinterp")
// escapeAction escapes an action template node.
func (e *escaper) escapeAction(c context, n *parse.ActionNode) context {
if len(n.Pipe.Decl) != 0 {
// A local variable assignment, not an interpolation.
return c
}
c = nudge(c)
// Check for disallowed use of predefined escapers in the pipeline.
for pos, idNode := range n.Pipe.Cmds {
node, ok := idNode.Args[0].(*parse.IdentifierNode)
if !ok {
// A predefined escaper "esc" will never be found as an identifier in a
// Chain or Field node, since:
// - "esc.x ..." is invalid, since predefined escapers return strings, and
// strings do not have methods, keys or fields.
// - "... .esc" is invalid, since predefined escapers are global functions,
// not methods or fields of any types.
// Therefore, it is safe to ignore these two node types.
continue
}
ident := node.Ident
if _, ok := predefinedEscapers[ident]; ok {
if pos < len(n.Pipe.Cmds)-1 ||
c.state == stateAttr && c.delim == delimSpaceOrTagEnd && ident == "html" {
return context{
state: stateError,
err: errorf(ErrPredefinedEscaper, n, n.Line, "predefined escaper %q disallowed in template", ident),
}
}
}
}
s := make([]string, 0, 3)
switch c.state {
case stateError:
return c
case stateURL, stateCSSDqStr, stateCSSSqStr, stateCSSDqURL, stateCSSSqURL, stateCSSURL:
switch c.urlPart {
case urlPartNone:
s = append(s, "_html_template_urlfilter")
fallthrough
case urlPartPreQuery:
switch c.state {
case stateCSSDqStr, stateCSSSqStr:
s = append(s, "_html_template_cssescaper")
default:
s = append(s, "_html_template_urlnormalizer")
}
case urlPartQueryOrFrag:
s = append(s, "_html_template_urlescaper")
case urlPartUnknown:
return context{
state: stateError,
err: errorf(ErrAmbigContext, n, n.Line, "%s appears in an ambiguous context within a URL", n),
}
default:
panic(c.urlPart.String())
}
case stateJS:
s = append(s, "_html_template_jsvalescaper")
// A slash after a value starts a div operator.
c.jsCtx = jsCtxDivOp
case stateJSDqStr, stateJSSqStr:
s = append(s, "_html_template_jsstrescaper")
case stateJSTmplLit:
s = append(s, "_html_template_jstmpllitescaper")
case stateJSRegexp:
s = append(s, "_html_template_jsregexpescaper")
case stateCSS:
s = append(s, "_html_template_cssvaluefilter")
case stateText:
s = append(s, "_html_template_htmlescaper")
case stateRCDATA:
s = append(s, "_html_template_rcdataescaper")
case stateAttr:
// Handled below in delim check.
case stateAttrName, stateTag:
c.state = stateAttrName
s = append(s, "_html_template_htmlnamefilter")
case stateSrcset:
s = append(s, "_html_template_srcsetescaper")
default:
if isComment(c.state) {
s = append(s, "_html_template_commentescaper")
} else {
panic("unexpected state " + c.state.String())
}
}
switch c.delim {
case delimNone:
// No extra-escaping needed for raw text content.
case delimSpaceOrTagEnd:
s = append(s, "_html_template_nospaceescaper")
default:
s = append(s, "_html_template_attrescaper")
}
e.editActionNode(n, s)
return c
}
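// As a sketch of the selection above, an action inside a quoted URL attribute
// such as
//
//	<a href="{{.X}}">
//
// picks up "_html_template_urlfilter" and "_html_template_urlnormalizer" from
// the URL state and "_html_template_attrescaper" from the double-quote delim,
// so the pipeline should be rewritten to the equivalent of
//
//	{{.X | _html_template_urlfilter | _html_template_urlnormalizer | _html_template_attrescaper}}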
// ensurePipelineContains ensures that the pipeline ends with the commands with
// the identifiers in s in order. If the pipeline ends with a predefined escaper
// (i.e. "html" or "urlquery"), merge it with the identifiers in s.
func ensurePipelineContains(p *parse.PipeNode, s []string) {
if len(s) == 0 {
// Do not rewrite pipeline if we have no escapers to insert.
return
}
// Precondition: p.Cmds contains at most one predefined escaper and the
// escaper will be present at p.Cmds[len(p.Cmds)-1]. This precondition is
// always true because of the checks in escapeAction.
pipelineLen := len(p.Cmds)
if pipelineLen > 0 {
lastCmd := p.Cmds[pipelineLen-1]
if idNode, ok := lastCmd.Args[0].(*parse.IdentifierNode); ok {
if esc := idNode.Ident; predefinedEscapers[esc] {
// Pipeline ends with a predefined escaper.
if len(p.Cmds) == 1 && len(lastCmd.Args) > 1 {
// Special case: pipeline is of the form {{ esc arg1 arg2 ... argN }},
// where esc is the predefined escaper, and arg1...argN are its arguments.
// Convert this into the equivalent form
// {{ _eval_args_ arg1 arg2 ... argN | esc }}, so that esc can be easily
// merged with the escapers in s.
lastCmd.Args[0] = parse.NewIdentifier("_eval_args_").SetTree(nil).SetPos(lastCmd.Args[0].Position())
p.Cmds = appendCmd(p.Cmds, newIdentCmd(esc, p.Position()))
pipelineLen++
}
// If any of the commands in s that we are about to insert is equivalent
// to the predefined escaper, use the predefined escaper instead.
dup := false
for i, escaper := range s {
if escFnsEq(esc, escaper) {
s[i] = idNode.Ident
dup = true
}
}
if dup {
// The predefined escaper will already be inserted along with the
// escapers in s, so do not copy it to the rewritten pipeline.
pipelineLen--
}
}
}
}
// Rewrite the pipeline, creating the escapers in s at the end of the pipeline.
newCmds := make([]*parse.CommandNode, pipelineLen, pipelineLen+len(s))
insertedIdents := make(map[string]bool)
for i := 0; i < pipelineLen; i++ {
cmd := p.Cmds[i]
newCmds[i] = cmd
if idNode, ok := cmd.Args[0].(*parse.IdentifierNode); ok {
insertedIdents[normalizeEscFn(idNode.Ident)] = true
}
}
for _, name := range s {
if !insertedIdents[normalizeEscFn(name)] {
// When two templates share an underlying parse tree via the use of
// AddParseTree and one template is executed after the other, this check
// ensures that escapers that were already inserted into the pipeline on
// the first escaping pass do not get inserted again.
newCmds = appendCmd(newCmds, newIdentCmd(name, p.Position()))
}
}
p.Cmds = newCmds
}
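// For instance, if escapeAction recorded ["_html_template_htmlescaper"] for the
// pipeline {{.X | html}}, the trailing predefined "html" is recognized as
// equivalent and reused, so the pipeline should stay {{.X | html}} rather than
// gain a duplicate escaper (illustrative sketch).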
// predefinedEscapers contains template predefined escapers that are equivalent
// to some contextual escapers. Keep in sync with equivEscapers.
var predefinedEscapers = map[string]bool{
"html": true,
"urlquery": true,
}
// equivEscapers matches contextual escapers to equivalent predefined
// template escapers.
var equivEscapers = map[string]string{
// The following pairs of HTML escapers provide equivalent security
// guarantees, since they all escape '\000', '\'', '"', '&', '<', and '>'.
"_html_template_attrescaper": "html",
"_html_template_htmlescaper": "html",
"_html_template_rcdataescaper": "html",
// These two URL escapers produce URLs safe for embedding in a URL query by
// percent-encoding all the reserved characters specified in RFC 3986 Section
// 2.2
"_html_template_urlescaper": "urlquery",
// These two functions are not actually equivalent; urlquery is stricter as it
// escapes reserved characters (e.g. '#'), while _html_template_urlnormalizer
// does not. It is therefore only safe to replace _html_template_urlnormalizer
// with urlquery (this happens in ensurePipelineContains), but not the other
// way around. We keep this entry around to preserve the behavior of templates
// written before Go 1.9, which might depend on this substitution taking place.
"_html_template_urlnormalizer": "urlquery",
}
// escFnsEq reports whether the two escaping functions are equivalent.
func escFnsEq(a, b string) bool {
return normalizeEscFn(a) == normalizeEscFn(b)
}
// normalizeEscFn(a) is equal to normalizeEscFn(b) for any pair of names of
// escaper functions a and b that are equivalent.
func normalizeEscFn(e string) string {
if norm := equivEscapers[e]; norm != "" {
return norm
}
return e
}
// redundantFuncs[a][b] implies that funcMap[b](funcMap[a](x)) == funcMap[a](x)
// for all x.
var redundantFuncs = map[string]map[string]bool{
"_html_template_commentescaper": {
"_html_template_attrescaper": true,
"_html_template_htmlescaper": true,
},
"_html_template_cssescaper": {
"_html_template_attrescaper": true,
},
"_html_template_jsregexpescaper": {
"_html_template_attrescaper": true,
},
"_html_template_jsstrescaper": {
"_html_template_attrescaper": true,
},
"_html_template_jstmpllitescaper": {
"_html_template_attrescaper": true,
},
"_html_template_urlescaper": {
"_html_template_urlnormalizer": true,
},
}
// appendCmd appends the given command to the end of the command pipeline
// unless it is redundant with the last command.
func appendCmd(cmds []*parse.CommandNode, cmd *parse.CommandNode) []*parse.CommandNode {
if n := len(cmds); n != 0 {
last, okLast := cmds[n-1].Args[0].(*parse.IdentifierNode)
next, okNext := cmd.Args[0].(*parse.IdentifierNode)
if okLast && okNext && redundantFuncs[last.Ident][next.Ident] {
return cmds
}
}
return append(cmds, cmd)
}
// newIdentCmd produces a command containing a single identifier node.
func newIdentCmd(identifier string, pos parse.Pos) *parse.CommandNode {
return &parse.CommandNode{
NodeType: parse.NodeCommand,
Args: []parse.Node{parse.NewIdentifier(identifier).SetTree(nil).SetPos(pos)}, // TODO: SetTree.
}
}
// nudge returns the context that would result from following empty string
// transitions from the input context.
// For example, parsing:
//
// `<a href=`
//
// will end in context{stateBeforeValue, attrURL}, but parsing one extra rune:
//
// `<a href=x`
//
// will end in context{stateURL, delimSpaceOrTagEnd, ...}.
// There are two transitions that happen when the 'x' is seen:
//  (1) Transition from a before-value state to a start-of-value state without
//      consuming any character.
//  (2) Consume 'x' and transition past the first value character.
// In this case, nudging produces the context after (1) happens.
func nudge(c context) context {
switch c.state {
case stateTag:
// In `<foo {{.}}`, the action should emit an attribute.
c.state = stateAttrName
case stateBeforeValue:
// In `<foo bar={{.}}`, the action is an undelimited value.
c.state, c.delim, c.attr = attrStartStates[c.attr], delimSpaceOrTagEnd, attrNone
case stateAfterName:
// In `<foo bar {{.}}`, the action is an attribute name.
c.state, c.attr = stateAttrName, attrNone
}
return c
}
// join joins the two contexts of a branch template node. The result is an
// error context if either of the input contexts are error contexts, or if the
// input contexts differ.
func join(a, b context, node parse.Node, nodeName string) context {
if a.state == stateError {
return a
}
if b.state == stateError {
return b
}
if a.state == stateDead {
return b
}
if b.state == stateDead {
return a
}
if a.eq(b) {
return a
}
c := a
c.urlPart = b.urlPart
if c.eq(b) {
// The contexts differ only by urlPart.
c.urlPart = urlPartUnknown
return c
}
c = a
c.jsCtx = b.jsCtx
if c.eq(b) {
// The contexts differ only by jsCtx.
c.jsCtx = jsCtxUnknown
return c
}
// Allow a nudged context to join with an unnudged one.
// This means that
// <p title={{if .C}}{{.}}{{end}}
// ends in an unquoted value state even though the else branch
// ends in stateBeforeValue.
if c, d := nudge(a), nudge(b); !(c.eq(a) && d.eq(b)) {
if e := join(c, d, node, nodeName); e.state != stateError {
return e
}
}
return context{
state: stateError,
err: errorf(ErrBranchEnd, node, 0, "{{%s}} branches end in different contexts: %v, %v", nodeName, a, b),
}
}
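// For illustration, two branch contexts that differ only in jsCtx collapse to
// jsCtxUnknown rather than erroring (n stands for whatever branch node is
// being joined):
//
//	a := context{state: stateJS, jsCtx: jsCtxRegexp}
//	b := context{state: stateJS, jsCtx: jsCtxDivOp}
//	join(a, b, n, "if") // context{state: stateJS, jsCtx: jsCtxUnknown}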
// escapeBranch escapes a branch template node: "if", "range" and "with".
func (e *escaper) escapeBranch(c context, n *parse.BranchNode, nodeName string) context {
if nodeName == "range" {
e.rangeContext = &rangeContext{outer: e.rangeContext}
}
c0 := e.escapeList(c, n.List)
if nodeName == "range" {
if c0.state != stateError {
c0 = joinRange(c0, e.rangeContext)
}
e.rangeContext = e.rangeContext.outer
if c0.state == stateError {
return c0
}
// The "true" branch of a "range" node can execute multiple times.
// We check that executing n.List once results in the same context
// as executing n.List twice.
e.rangeContext = &rangeContext{outer: e.rangeContext}
c1, _ := e.escapeListConditionally(c0, n.List, nil)
c0 = join(c0, c1, n, nodeName)
if c0.state == stateError {
e.rangeContext = e.rangeContext.outer
// Make clear that this is a problem on loop re-entry
// since developers tend to overlook that branch when
// debugging templates.
c0.err.Line = n.Line
c0.err.Description = "on range loop re-entry: " + c0.err.Description
return c0
}
c0 = joinRange(c0, e.rangeContext)
e.rangeContext = e.rangeContext.outer
if c0.state == stateError {
return c0
}
}
c1 := e.escapeList(c, n.ElseList)
return join(c0, c1, n, nodeName)
}
func joinRange(c0 context, rc *rangeContext) context {
// Merge contexts at break and continue statements into overall body context.
// In theory we could treat breaks differently from continues, but for now it is
// enough to treat them both as going back to the start of the loop (which may then stop).
for _, c := range rc.breaks {
c0 = join(c0, c, c.n, "range")
if c0.state == stateError {
c0.err.Line = c.n.(*parse.BreakNode).Line
c0.err.Description = "at range loop break: " + c0.err.Description
return c0
}
}
for _, c := range rc.continues {
c0 = join(c0, c, c.n, "range")
if c0.state == stateError {
c0.err.Line = c.n.(*parse.ContinueNode).Line
c0.err.Description = "at range loop continue: " + c0.err.Description
return c0
}
}
return c0
}
// escapeList escapes a list template node.
func (e *escaper) escapeList(c context, n *parse.ListNode) context {
if n == nil {
return c
}
for _, m := range n.Nodes {
c = e.escape(c, m)
if c.state == stateDead {
break
}
}
return c
}
// escapeListConditionally escapes a list node but only preserves edits and
// inferences in e if the inferences and output context satisfy filter.
// It returns the best guess at an output context, and the result of the filter
// which is the same as whether e was updated.
func (e *escaper) escapeListConditionally(c context, n *parse.ListNode, filter func(*escaper, context) bool) (context, bool) {
e1 := makeEscaper(e.ns)
e1.rangeContext = e.rangeContext
// Make type inferences available to f.
maps.Copy(e1.output, e.output)
c = e1.escapeList(c, n)
ok := filter != nil && filter(&e1, c)
if ok {
// Copy inferences and edits from e1 back into e.
maps.Copy(e.output, e1.output)
maps.Copy(e.derived, e1.derived)
maps.Copy(e.called, e1.called)
for k, v := range e1.actionNodeEdits {
e.editActionNode(k, v)
}
for k, v := range e1.templateNodeEdits {
e.editTemplateNode(k, v)
}
for k, v := range e1.textNodeEdits {
e.editTextNode(k, v)
}
}
return c, ok
}
// escapeTemplate escapes a {{template}} call node.
func (e *escaper) escapeTemplate(c context, n *parse.TemplateNode) context {
c, name := e.escapeTree(c, n, n.Name, n.Line)
if name != n.Name {
e.editTemplateNode(n, name)
}
return c
}
// escapeTree escapes the named template starting in the given context as
// necessary and returns its output context.
func (e *escaper) escapeTree(c context, node parse.Node, name string, line int) (context, string) {
// Mangle the template name with the input context to produce a reliable
// identifier.
dname := c.mangle(name)
e.called[dname] = true
if out, ok := e.output[dname]; ok {
// Already escaped.
return out, dname
}
t := e.template(name)
if t == nil {
// Two cases: The template exists but is empty, or has never been mentioned at
// all. Distinguish the cases in the error messages.
if e.ns.set[name] != nil {
return context{
state: stateError,
err: errorf(ErrNoSuchTemplate, node, line, "%q is an incomplete or empty template", name),
}, dname
}
return context{
state: stateError,
err: errorf(ErrNoSuchTemplate, node, line, "no such template %q", name),
}, dname
}
if dname != name {
// Use any template derived during an earlier call to escapeTemplate
// with different top level templates, or clone if necessary.
dt := e.template(dname)
if dt == nil {
dt = template.New(dname)
dt.Tree = &parse.Tree{Name: dname, Root: t.Root.CopyList()}
e.derived[dname] = dt
}
t = dt
}
return e.computeOutCtx(c, t), dname
}
// computeOutCtx takes a template and its start context and computes the output
// context while storing any inferences in e.
func (e *escaper) computeOutCtx(c context, t *template.Template) context {
// Propagate context over the body.
c1, ok := e.escapeTemplateBody(c, t)
if !ok {
// Look for a fixed point by assuming c1 as the output context.
if c2, ok2 := e.escapeTemplateBody(c1, t); ok2 {
c1, ok = c2, true
}
// Use c1 as the error context if neither assumption worked.
}
if !ok && c1.state != stateError {
return context{
state: stateError,
err: errorf(ErrOutputContext, t.Tree.Root, 0, "cannot compute output context for template %s", t.Name()),
}
}
return c1
}
// escapeTemplateBody escapes the given template assuming the given output
// context, and returns the best guess at the output context and whether the
// assumption was correct.
func (e *escaper) escapeTemplateBody(c context, t *template.Template) (context, bool) {
filter := func(e1 *escaper, c1 context) bool {
if c1.state == stateError {
// Do not update the input escaper, e.
return false
}
if !e1.called[t.Name()] {
// If t is not recursively called, then c1 is an
// accurate output context.
return true
}
// c1 is accurate if it matches our assumed output context.
return c.eq(c1)
}
// We need to assume an output context so that recursive template calls
// take the fast path out of escapeTree instead of infinitely recurring.
// Naively assuming that the input context is the same as the output
// works >90% of the time.
e.output[t.Name()] = c
return e.escapeListConditionally(c, t.Tree.Root, filter)
}
// delimEnds maps each delim to a string of characters that terminate it.
var delimEnds = [...]string{
delimDoubleQuote: `"`,
delimSingleQuote: "'",
// Determined empirically by running the below in various browsers.
// var div = document.createElement("DIV");
// for (var i = 0; i < 0x10000; ++i) {
// div.innerHTML = "<span title=x" + String.fromCharCode(i) + "-bar>";
// if (div.getElementsByTagName("SPAN")[0].title.indexOf("bar") < 0)
// document.write("<p>U+" + i.toString(16));
// }
delimSpaceOrTagEnd: " \t\n\f\r>",
}
var (
// Per WHATWG HTML specification, section 4.12.1.3, there are extremely
// complicated rules for how to handle the set of opening tags <!--,
// <script, and </script when they appear in JS literals (i.e. strings,
// regexs, and comments). The specification suggests a simple solution,
// rather than implementing the arcane ABNF, which involves simply escaping
// the opening bracket with \x3C. We use the below regex for this, since it
// makes doing the case-insensitive find-replace much simpler.
specialScriptTagRE = regexp.MustCompile("(?i)<(script|/script|!--)")
specialScriptTagReplacement = []byte("\\x3C$1")
)
func containsSpecialScriptTag(s []byte) bool {
return specialScriptTagRE.Match(s)
}
func escapeSpecialScriptTags(s []byte) []byte {
return specialScriptTagRE.ReplaceAll(s, specialScriptTagReplacement)
}
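// For example, the opening angle bracket of each sensitive sequence is
// replaced case-insensitively (illustrative sketch):
//
//	escapeSpecialScriptTags([]byte(`"</script>"`)) // `"\x3C/script>"`
//	escapeSpecialScriptTags([]byte(`"<!--"`))      // `"\x3C!--"`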
var doctypeBytes = []byte("<!DOCTYPE")
// escapeText escapes a text template node.
func (e *escaper) escapeText(c context, n *parse.TextNode) context {
s, written, i, b := n.Text, 0, 0, new(bytes.Buffer)
for i != len(s) {
c1, nread := contextAfterText(c, s[i:])
i1 := i + nread
if c.state == stateText || c.state == stateRCDATA {
end := i1
if c1.state != c.state {
for j := end - 1; j >= i; j-- {
if s[j] == '<' {
end = j
break
}
}
}
for j := i; j < end; j++ {
if s[j] == '<' && !bytes.HasPrefix(bytes.ToUpper(s[j:]), doctypeBytes) {
b.Write(s[written:j])
b.WriteString("&lt;")
written = j + 1
}
}
} else if isComment(c.state) && c.delim == delimNone {
switch c.state {
case stateJSBlockCmt:
// https://es5.github.io/#x7.4:
// "Comments behave like white space and are
// discarded except that, if a MultiLineComment
// contains a line terminator character, then
// the entire comment is considered to be a
// LineTerminator for purposes of parsing by
// the syntactic grammar."
if bytes.ContainsAny(s[written:i1], "\n\r\u2028\u2029") {
b.WriteByte('\n')
} else {
b.WriteByte(' ')
}
case stateCSSBlockCmt:
b.WriteByte(' ')
}
written = i1
}
if c.state != c1.state && isComment(c1.state) && c1.delim == delimNone {
// Preserve the portion between written and the comment start.
cs := i1 - 2
if c1.state == stateHTMLCmt || c1.state == stateJSHTMLOpenCmt {
// "<!--" instead of "/*" or "//"
cs -= 2
} else if c1.state == stateJSHTMLCloseCmt {
// "-->" instead of "/*" or "//"
cs -= 1
}
b.Write(s[written:cs])
written = i1
}
if isInScriptLiteral(c.state) && containsSpecialScriptTag(s[i:i1]) {
b.Write(s[written:i])
b.Write(escapeSpecialScriptTags(s[i:i1]))
written = i1
}
if i == i1 && c.state == c1.state {
panic(fmt.Sprintf("infinite loop from %v to %v on %q..%q", c, c1, s[:i], s[i:]))
}
c, i = c1, i1
}
if written != 0 && c.state != stateError {
if !isComment(c.state) || c.delim != delimNone {
b.Write(n.Text[written:])
}
e.editTextNode(n, b.Bytes())
}
return c
}
// contextAfterText starts in context c, consumes some tokens from the front of
// s, then returns the context after those tokens and the unprocessed suffix.
func contextAfterText(c context, s []byte) (context, int) {
if c.delim == delimNone {
c1, i := tSpecialTagEnd(c, s)
if i == 0 {
// A special end tag (`</script>`) has been seen and
// all content preceding it has been consumed.
return c1, 0
}
// Consider all content up to any end tag.
return transitionFunc[c.state](c, s[:i])
}
// We are at the beginning of an attribute value.
i := bytes.IndexAny(s, delimEnds[c.delim])
if i == -1 {
i = len(s)
}
if c.delim == delimSpaceOrTagEnd {
// https://www.w3.org/TR/html5/syntax.html#attribute-value-(unquoted)-state
// lists the runes below as error characters.
// Error out because HTML parsers may differ on whether
// "<a id= onclick=f(" ends inside id's or onclick's value,
// "<a class=`foo " ends inside a value,
// "<a style=font:'Arial'" needs open-quote fixup.
// IE treats '`' as a quotation character.
if j := bytes.IndexAny(s[:i], "\"'<=`"); j >= 0 {
return context{
state: stateError,
err: errorf(ErrBadHTML, nil, 0, "%q in unquoted attr: %q", s[j:j+1], s[:i]),
}, len(s)
}
}
if i == len(s) {
// Remain inside the attribute.
// Decode the value so non-HTML rules can easily handle
// <button onclick="alert(&quot;Hi!&quot;)">
// without having to entity decode token boundaries.
for u := []byte(html.UnescapeString(string(s))); len(u) != 0; {
c1, i1 := transitionFunc[c.state](c, u)
c, u = c1, u[i1:]
}
return c, len(s)
}
element := c.element
// If this is a non-JS "type" attribute inside "script" tag, do not treat the contents as JS.
if c.state == stateAttr && c.element == elementScript && c.attr == attrScriptType && !isJSType(string(s[:i])) {
element = elementNone
}
if c.delim != delimSpaceOrTagEnd {
// Consume any quote.
i++
}
// On exiting an attribute, we discard all state information
// except the state and element.
return context{state: stateTag, element: element}, i
}
// editActionNode records a change to an action pipeline for later commit.
func (e *escaper) editActionNode(n *parse.ActionNode, cmds []string) {
if _, ok := e.actionNodeEdits[n]; ok {
panic(fmt.Sprintf("node %s shared between templates", n))
}
e.actionNodeEdits[n] = cmds
}
// editTemplateNode records a change to a {{template}} callee for later commit.
func (e *escaper) editTemplateNode(n *parse.TemplateNode, callee string) {
if _, ok := e.templateNodeEdits[n]; ok {
panic(fmt.Sprintf("node %s shared between templates", n))
}
e.templateNodeEdits[n] = callee
}
// editTextNode records a change to a text node for later commit.
func (e *escaper) editTextNode(n *parse.TextNode, text []byte) {
if _, ok := e.textNodeEdits[n]; ok {
panic(fmt.Sprintf("node %s shared between templates", n))
}
e.textNodeEdits[n] = text
}
// commit applies changes to actions and template calls needed to contextually
// autoescape content and adds any derived templates to the set.
func (e *escaper) commit() {
for name := range e.output {
e.template(name).Funcs(funcMap)
}
// Any template from the name space associated with this escaper can be used
// to add derived templates to the underlying text/template name space.
tmpl := e.arbitraryTemplate()
for _, t := range e.derived {
if _, err := tmpl.text.AddParseTree(t.Name(), t.Tree); err != nil {
panic("error adding derived template")
}
}
for n, s := range e.actionNodeEdits {
ensurePipelineContains(n.Pipe, s)
}
for n, name := range e.templateNodeEdits {
n.Name = name
}
for n, s := range e.textNodeEdits {
n.Text = s
}
// Reset state that is specific to this commit so that the same changes are
// not re-applied to the template on subsequent calls to commit.
e.called = make(map[string]bool)
e.actionNodeEdits = make(map[*parse.ActionNode][]string)
e.templateNodeEdits = make(map[*parse.TemplateNode]string)
e.textNodeEdits = make(map[*parse.TextNode][]byte)
}
// template returns the named template given a mangled template name.
func (e *escaper) template(name string) *template.Template {
// Any template from the name space associated with this escaper can be used
// to look up templates in the underlying text/template name space.
t := e.arbitraryTemplate().text.Lookup(name)
if t == nil {
t = e.derived[name]
}
return t
}
// arbitraryTemplate returns an arbitrary template from the name space
// associated with e and panics if no templates are found.
func (e *escaper) arbitraryTemplate() *Template {
for _, t := range e.ns.set {
return t
}
panic("no templates in name space")
}
// Forwarding functions so that clients need only import this package
// to reach the general escaping functions of text/template.
// HTMLEscape writes to w the escaped HTML equivalent of the plain text data b.
func HTMLEscape(w io.Writer, b []byte) {
template.HTMLEscape(w, b)
}
// HTMLEscapeString returns the escaped HTML equivalent of the plain text data s.
func HTMLEscapeString(s string) string {
return template.HTMLEscapeString(s)
}
// HTMLEscaper returns the escaped HTML equivalent of the textual
// representation of its arguments.
func HTMLEscaper(args ...any) string {
return template.HTMLEscaper(args...)
}
// JSEscape writes to w the escaped JavaScript equivalent of the plain text data b.
func JSEscape(w io.Writer, b []byte) {
template.JSEscape(w, b)
}
// JSEscapeString returns the escaped JavaScript equivalent of the plain text data s.
func JSEscapeString(s string) string {
return template.JSEscapeString(s)
}
// JSEscaper returns the escaped JavaScript equivalent of the textual
// representation of its arguments.
func JSEscaper(args ...any) string {
return template.JSEscaper(args...)
}
// URLQueryEscaper returns the escaped value of the textual representation of
// its arguments in a form suitable for embedding in a URL query.
func URLQueryEscaper(args ...any) string {
return template.URLQueryEscaper(args...)
}
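// Illustrative sketch (not part of the original source): because these are
// thin forwarding wrappers, a caller that imports only this package can
// write, for example:
//
// s := HTMLEscapeString(`<a href="x">&</a>`) // "&lt;a href=&#34;x&#34;&gt;&amp;&lt;/a&gt;"
// q := URLQueryEscaper("a b&c")              // "a+b%26c"
//
// The exact escape sequences are whatever the underlying text/template
// functions produce; they are shown here only as an illustration.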
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package template
import (
"bytes"
"fmt"
"strings"
"unicode/utf8"
)
// htmlNospaceEscaper escapes for inclusion in unquoted attribute values.
func htmlNospaceEscaper(args ...any) string {
s, t := stringify(args...)
if s == "" {
return filterFailsafe
}
if t == contentTypeHTML {
return htmlReplacer(stripTags(s), htmlNospaceNormReplacementTable, false)
}
return htmlReplacer(s, htmlNospaceReplacementTable, false)
}
// attrEscaper escapes for inclusion in quoted attribute values.
func attrEscaper(args ...any) string {
s, t := stringify(args...)
if t == contentTypeHTML {
return htmlReplacer(stripTags(s), htmlNormReplacementTable, true)
}
return htmlReplacer(s, htmlReplacementTable, true)
}
// rcdataEscaper escapes for inclusion in an RCDATA element body.
func rcdataEscaper(args ...any) string {
s, t := stringify(args...)
if t == contentTypeHTML {
return htmlReplacer(s, htmlNormReplacementTable, true)
}
return htmlReplacer(s, htmlReplacementTable, true)
}
// htmlEscaper escapes for inclusion in HTML text.
func htmlEscaper(args ...any) string {
s, t := stringify(args...)
if t == contentTypeHTML {
return s
}
return htmlReplacer(s, htmlReplacementTable, true)
}
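// Illustrative sketch (not part of the original source): with the replacement
// tables defined below, plain-text arguments are entity-encoded, while values
// already typed as contentTypeHTML pass through (or are tag-stripped in
// attribute contexts):
//
// htmlEscaper("a < b & c")  // "a &lt; b &amp; c"
// attrEscaper(`say "hi"`)   // "say &#34;hi&#34;"
// htmlNospaceEscaper("a b") // "a&#32;b"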
// htmlReplacementTable contains the runes that need to be escaped
// inside a quoted attribute value or in a text node.
var htmlReplacementTable = []string{
// https://www.w3.org/TR/html5/syntax.html#attribute-value-(unquoted)-state
// U+0000 NULL Parse error. Append a U+FFFD REPLACEMENT
// CHARACTER character to the current attribute's value.
// "
// and similarly
// https://www.w3.org/TR/html5/syntax.html#before-attribute-value-state
0: "\uFFFD",
'"': """,
'&': "&",
'\'': "'",
'+': "+",
'<': "<",
'>': ">",
}
// htmlNormReplacementTable is like htmlReplacementTable but without '&' to
// avoid over-encoding existing entities.
var htmlNormReplacementTable = []string{
0: "\uFFFD",
'"': """,
'\'': "'",
'+': "+",
'<': "<",
'>': ">",
}
// htmlNospaceReplacementTable contains the runes that need to be escaped
// inside an unquoted attribute value.
// The set of runes escaped is the union of the HTML specials and
// those determined by running the JS below in browsers:
// <div id=d></div>
// <script>(function () {
// var a = [], d = document.getElementById("d"), i, c, s;
// for (i = 0; i < 0x10000; ++i) {
//
// c = String.fromCharCode(i);
// d.innerHTML = "<span title=" + c + "lt" + c + "></span>"
// s = d.getElementsByTagName("SPAN")[0];
// if (!s || s.title !== c + "lt" + c) { a.push(i.toString(16)); }
//
// }
// document.write(a.join(", "));
// })()</script>
var htmlNospaceReplacementTable = []string{
0: "�",
'\t': "	",
'\n': " ",
'\v': "",
'\f': "",
'\r': " ",
' ': " ",
'"': """,
'&': "&",
'\'': "'",
'+': "+",
'<': "<",
'=': "=",
'>': ">",
// A parse error in the attribute value (unquoted) and
// before attribute value states.
// Treated as a quoting character by IE.
'`': "`",
}
// htmlNospaceNormReplacementTable is like htmlNospaceReplacementTable but
// without '&' to avoid over-encoding existing entities.
var htmlNospaceNormReplacementTable = []string{
0: "�",
'\t': "	",
'\n': " ",
'\v': "",
'\f': "",
'\r': " ",
' ': " ",
'"': """,
'\'': "'",
'+': "+",
'<': "<",
'=': "=",
'>': ">",
// A parse error in the attribute value (unquoted) and
// before attribute value states.
// Treated as a quoting character by IE.
'`': "`",
}
// htmlReplacer returns s with runes replaced according to replacementTable
// and when badRunes is true, certain bad runes are allowed through unescaped.
func htmlReplacer(s string, replacementTable []string, badRunes bool) string {
written, b := 0, new(strings.Builder)
r, w := rune(0), 0
for i := 0; i < len(s); i += w {
// Cannot use 'for range s' because we need to preserve the width
// of the runes in the input. If we see a decoding error, the input
// width will not be utf8.RuneLen(r) and we will overrun the buffer.
r, w = utf8.DecodeRuneInString(s[i:])
if int(r) < len(replacementTable) {
if repl := replacementTable[r]; len(repl) != 0 {
if written == 0 {
b.Grow(len(s))
}
b.WriteString(s[written:i])
b.WriteString(repl)
written = i + w
}
} else if badRunes {
// No-op.
// IE does not allow these ranges in unquoted attrs.
} else if 0xfdd0 <= r && r <= 0xfdef || 0xfff0 <= r && r <= 0xffff {
if written == 0 {
b.Grow(len(s))
}
fmt.Fprintf(b, "%s&#x%x;", s[written:i], r)
written = i + w
}
}
if written == 0 {
return s
}
b.WriteString(s[written:])
return b.String()
}
// stripTags takes a snippet of HTML and returns only the text content.
// For example, `<b>¡Hi!</b> <script>...</script>` -> `¡Hi! `.
func stripTags(html string) string {
var b strings.Builder
s, c, i, allText := []byte(html), context{}, 0, true
// Using the transition funcs helps us avoid mangling
// `<div title="1>2">` or `I <3 Ponies!`.
for i != len(s) {
if c.delim == delimNone {
st := c.state
// Use RCDATA instead of parsing into JS or CSS styles.
if c.element != elementNone && !isInTag(st) {
st = stateRCDATA
}
d, nread := transitionFunc[st](c, s[i:])
i1 := i + nread
if c.state == stateText || c.state == stateRCDATA {
// Emit text up to the start of the tag or comment.
j := i1
if d.state != c.state {
for j1 := j - 1; j1 >= i; j1-- {
if s[j1] == '<' {
j = j1
break
}
}
}
b.Write(s[i:j])
} else {
allText = false
}
c, i = d, i1
continue
}
i1 := i + bytes.IndexAny(s[i:], delimEnds[c.delim])
if i1 < i {
break
}
if c.delim != delimSpaceOrTagEnd {
// Consume any quote.
i1++
}
c, i = context{state: stateTag, element: c.element}, i1
}
if allText {
return html
} else if c.state == stateText || c.state == stateRCDATA {
b.Write(s[i:])
}
return b.String()
}
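// Illustrative sketch (not part of the original source): stripTags is applied
// when HTML-typed content reaches a context that only admits text, e.g.
//
// stripTags(`Hello <b onclick="alert(1)">World</b>!`) // "Hello World!"
//
// Tags and their attributes are dropped; only the text content survives.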
// htmlNameFilter accepts valid parts of an HTML attribute or tag name or
// a known-safe HTML attribute.
func htmlNameFilter(args ...any) string {
s, t := stringify(args...)
if t == contentTypeHTMLAttr {
return s
}
if len(s) == 0 {
// Avoid violation of structure preservation.
// <input checked {{.K}}={{.V}}>.
// Without this, if .K is empty then .V is the value of
// checked, but otherwise .V is the value of the attribute
// named .K.
return filterFailsafe
}
s = strings.ToLower(s)
if t := attrType(s); t != contentTypePlain {
// TODO: Split attr and element name part filters so we can recognize known attributes.
return filterFailsafe
}
for _, r := range s {
switch {
case '0' <= r && r <= '9':
case 'a' <= r && r <= 'z':
default:
return filterFailsafe
}
}
return s
}
// commentEscaper returns the empty string regardless of input.
// Comment content does not correspond to any parsed structure or
// human-readable content, so the simplest and most secure policy is to drop
// content interpolated into comments.
// This approach is equally valid whether or not static comment content is
// removed from the template.
func commentEscaper(args ...any) string {
return ""
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package template
import (
"bytes"
"encoding/json"
"fmt"
"reflect"
"regexp"
"strings"
"unicode/utf8"
)
// jsWhitespace contains all of the JS whitespace characters, as defined
// by the \s character class.
// See https://developer.mozilla.org/en-US/docs/Web/JavaScript/Guide/Regular_expressions/Character_classes.
const jsWhitespace = "\f\n\r\t\v\u0020\u00a0\u1680\u2000\u2001\u2002\u2003\u2004\u2005\u2006\u2007\u2008\u2009\u200a\u2028\u2029\u202f\u205f\u3000\ufeff"
// nextJSCtx returns the context that determines whether a slash after the
// given run of tokens starts a regular expression instead of a division
// operator: / or /=.
//
// This assumes that the token run does not include any string tokens, comment
// tokens, regular expression literal tokens, or division operators.
//
// This fails on some valid but nonsensical JavaScript programs like
// "x = ++/foo/i" which is quite different than "x++/foo/i", but is not known to
// fail on any known useful programs. It is based on the draft
// JavaScript 2.0 lexical grammar and requires one token of lookbehind:
// https://www.mozilla.org/js/language/js20-2000-07/rationale/syntax.html
func nextJSCtx(s []byte, preceding jsCtx) jsCtx {
// Trim all JS whitespace characters
s = bytes.TrimRight(s, jsWhitespace)
if len(s) == 0 {
return preceding
}
// All cases below are in the single-byte UTF-8 group.
switch c, n := s[len(s)-1], len(s); c {
case '+', '-':
// ++ and -- are not regexp preceders, but + and - are whether
// they are used as infix or prefix operators.
start := n - 1
// Count the number of adjacent dashes or pluses.
for start > 0 && s[start-1] == c {
start--
}
if (n-start)&1 == 1 {
// Reached for trailing minus signs since "---" is the
// same as "-- -".
return jsCtxRegexp
}
return jsCtxDivOp
case '.':
// Handle "42."
if n != 1 && '0' <= s[n-2] && s[n-2] <= '9' {
return jsCtxDivOp
}
return jsCtxRegexp
// Suffixes for all punctuators from section 7.7 of the language spec
// that only end binary operators not handled above.
case ',', '<', '>', '=', '*', '%', '&', '|', '^', '?':
return jsCtxRegexp
// Suffixes for all punctuators from section 7.7 of the language spec
// that are prefix operators not handled above.
case '!', '~':
return jsCtxRegexp
// Matches all the punctuators from section 7.7 of the language spec
// that are open brackets not handled above.
case '(', '[':
return jsCtxRegexp
// Matches all the punctuators from section 7.7 of the language spec
// that precede expression starts.
case ':', ';', '{':
return jsCtxRegexp
// CAVEAT: the close punctuators ('}', ']', ')') precede div ops and
// are handled in the default except for '}' which can precede a
// division op as in
// ({ valueOf: function () { return 42 } } / 2
// which is valid, but, in practice, developers don't divide object
// literals, so our heuristic works well for code like
// function () { ... } /foo/.test(x) && sideEffect();
// The ')' punctuator can precede a regular expression as in
// if (b) /foo/.test(x) && ...
// but this is much less likely than
// (a + b) / c
case '}':
return jsCtxRegexp
default:
// Look for an IdentifierName and see if it is a keyword that
// can precede a regular expression.
j := n
for j > 0 && isJSIdentPart(rune(s[j-1])) {
j--
}
if regexpPrecederKeywords[string(s[j:])] {
return jsCtxRegexp
}
}
// Otherwise is a punctuator not listed above, or
// a string which precedes a div op, or an identifier
// which precedes a div op.
return jsCtxDivOp
}
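// Illustrative sketch (not part of the original source) of how the lookbehind
// heuristic classifies a following '/':
//
// nextJSCtx([]byte("var x = "), jsCtxRegexp) // jsCtxRegexp: '=' precedes a regexp
// nextJSCtx([]byte("x"), jsCtxRegexp)        // jsCtxDivOp:  an identifier precedes a division
// nextJSCtx([]byte("return "), jsCtxDivOp)   // jsCtxRegexp: keyword "return" precedes a regexp
// nextJSCtx([]byte("   "), jsCtxDivOp)       // jsCtxDivOp:  all whitespace, keep the prior context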
// regexpPrecederKeywords is a set of reserved JS keywords that can precede a
// regular expression in JS source.
var regexpPrecederKeywords = map[string]bool{
"break": true,
"case": true,
"continue": true,
"delete": true,
"do": true,
"else": true,
"finally": true,
"in": true,
"instanceof": true,
"return": true,
"throw": true,
"try": true,
"typeof": true,
"void": true,
}
var jsonMarshalType = reflect.TypeFor[json.Marshaler]()
// indirectToJSONMarshaler returns the value, after dereferencing as many times
// as necessary to reach the base type (or nil) or an implementation of json.Marshaler.
func indirectToJSONMarshaler(a any) any {
// text/template now supports passing untyped nil as a func call
// argument, so we must support it. Otherwise we'd panic below, as one
// cannot call the Type or Interface methods on an invalid
// reflect.Value. See golang.org/issue/18716.
if a == nil {
return nil
}
v := reflect.ValueOf(a)
for !v.Type().Implements(jsonMarshalType) && v.Kind() == reflect.Pointer && !v.IsNil() {
v = v.Elem()
}
return v.Interface()
}
var scriptTagRe = regexp.MustCompile("(?i)<(/?)script")
// jsValEscaper escapes its inputs to a JS Expression (section 11.14) that has
// neither side-effects nor free variables outside (NaN, Infinity).
func jsValEscaper(args ...any) string {
var a any
if len(args) == 1 {
a = indirectToJSONMarshaler(args[0])
switch t := a.(type) {
case JS:
return string(t)
case JSStr:
// TODO: normalize quotes.
return `"` + string(t) + `"`
case json.Marshaler:
// Do not treat as a Stringer.
case fmt.Stringer:
a = t.String()
}
} else {
for i, arg := range args {
args[i] = indirectToJSONMarshaler(arg)
}
a = fmt.Sprint(args...)
}
// TODO: detect cycles before calling Marshal which loops infinitely on
// cyclic data. This may be an unacceptable DoS risk.
b, err := json.Marshal(a)
if err != nil {
// While the standard JSON marshaler does not include user controlled
// information in the error message, if a type has a MarshalJSON method,
// the content of the error message is not guaranteed. Since we insert
// the error into the template, as part of a comment, we attempt to
// prevent the error from either terminating the comment, or the script
// block itself.
//
// In particular we:
// * replace "*/" comment end tokens with "* /", which does not
// terminate the comment
// * replace "<script" and "</script" with "\x3Cscript" and "\x3C/script"
// (case insensitively), and "<!--" with "\x3C!--", which prevents
// confusing script block termination semantics
//
// We also put a space before the comment so that if it is flush against
// a division operator it is not turned into a line comment:
// x/{{y}}
// turning into
// x//* error marshaling y:
// second line of error message */null
errStr := err.Error()
errStr = string(scriptTagRe.ReplaceAll([]byte(errStr), []byte(`\x3C${1}script`)))
errStr = strings.ReplaceAll(errStr, "*/", "* /")
errStr = strings.ReplaceAll(errStr, "<!--", `\x3C!--`)
return fmt.Sprintf(" /* %s */null ", errStr)
}
// TODO: maybe post-process output to prevent it from containing
// "<!--", "-->", "<![CDATA[", "]]>", or "</script"
// in case custom marshalers produce output containing those.
// Note: Do not use \x escaping to save bytes because it is not JSON compatible and this escaper
// supports ld+json content-type.
if len(b) == 0 {
// In `x=y/{{.}}*z`, a json.Marshaler that produces "" should
// not cause the output `x=y/*z`.
return " null "
}
first, _ := utf8.DecodeRune(b)
last, _ := utf8.DecodeLastRune(b)
var buf strings.Builder
// Prevent IdentifierNames and NumericLiterals from running into
// keywords: in, instanceof, typeof, void
pad := isJSIdentPart(first) || isJSIdentPart(last)
if pad {
buf.WriteByte(' ')
}
written := 0
// Make sure that json.Marshal escapes codepoints U+2028 & U+2029
// so it falls within the subset of JSON which is valid JS.
for i := 0; i < len(b); {
rune, n := utf8.DecodeRune(b[i:])
repl := ""
if rune == 0x2028 {
repl = `\u2028`
} else if rune == 0x2029 {
repl = `\u2029`
}
if repl != "" {
buf.Write(b[written:i])
buf.WriteString(repl)
written = i + n
}
i += n
}
if buf.Len() != 0 {
buf.Write(b[written:])
if pad {
buf.WriteByte(' ')
}
return buf.String()
}
return string(b)
}
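// Illustrative sketch (not part of the original source) of typical outputs:
//
// jsValEscaper("f<o")            // `"f\u003co"` (json.Marshal escapes '<' by default)
// jsValEscaper(42)               // " 42 " (padded so it cannot fuse with adjacent tokens)
// jsValEscaper(map[string]int{}) // "{}"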
// jsStrEscaper produces a string that can be included between quotes in
// JavaScript source, in JavaScript embedded in an HTML5 <script> element,
// or in an HTML5 event handler attribute such as onclick.
func jsStrEscaper(args ...any) string {
s, t := stringify(args...)
if t == contentTypeJSStr {
return replace(s, jsStrNormReplacementTable)
}
return replace(s, jsStrReplacementTable)
}
func jsTmplLitEscaper(args ...any) string {
s, _ := stringify(args...)
return replace(s, jsBqStrReplacementTable)
}
// jsRegexpEscaper behaves like jsStrEscaper but escapes regular expression
// specials so the result is treated literally when included in a regular
// expression literal. /foo{{.X}}bar/ matches the string "foo" followed by
// the literal text of {{.X}} followed by the string "bar".
func jsRegexpEscaper(args ...any) string {
s, _ := stringify(args...)
s = replace(s, jsRegexpReplacementTable)
if s == "" {
// /{{.X}}/ should not produce a line comment when .X == "".
return "(?:)"
}
return s
}
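// Illustrative sketch (not part of the original source): typical outputs of
// the string and regexp escapers, following the replacement tables below:
//
// jsStrEscaper(`va"l`)    // `va\u0022l`
// jsRegexpEscaper("a.b*") // `a\.b\*`
// jsRegexpEscaper("")     // "(?:)"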
// replace replaces each rune r of s with replacementTable[r], provided that
// r < len(replacementTable). If replacementTable[r] is the empty string then
// no replacement is made.
// It also replaces runes U+2028 and U+2029 with the raw strings `\u2028` and
// `\u2029`.
func replace(s string, replacementTable []string) string {
var b strings.Builder
r, w, written := rune(0), 0, 0
for i := 0; i < len(s); i += w {
// See comment in htmlEscaper.
r, w = utf8.DecodeRuneInString(s[i:])
var repl string
switch {
case int(r) < len(lowUnicodeReplacementTable):
repl = lowUnicodeReplacementTable[r]
case int(r) < len(replacementTable) && replacementTable[r] != "":
repl = replacementTable[r]
case r == '\u2028':
repl = `\u2028`
case r == '\u2029':
repl = `\u2029`
default:
continue
}
if written == 0 {
b.Grow(len(s))
}
b.WriteString(s[written:i])
b.WriteString(repl)
written = i + w
}
if written == 0 {
return s
}
b.WriteString(s[written:])
return b.String()
}
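// Illustrative sketch (not part of the original source): control characters
// are handled by lowUnicodeReplacementTable before the per-context table
// applies, so for example
//
// replace("a\nb", jsStrReplacementTable) // `a\nb` (a backslash followed by 'n')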
var lowUnicodeReplacementTable = []string{
0: `\u0000`, 1: `\u0001`, 2: `\u0002`, 3: `\u0003`, 4: `\u0004`, 5: `\u0005`, 6: `\u0006`,
'\a': `\u0007`,
'\b': `\u0008`,
'\t': `\t`,
'\n': `\n`,
'\v': `\u000b`, // "\v" == "v" on IE 6.
'\f': `\f`,
'\r': `\r`,
0xe: `\u000e`, 0xf: `\u000f`, 0x10: `\u0010`, 0x11: `\u0011`, 0x12: `\u0012`, 0x13: `\u0013`,
0x14: `\u0014`, 0x15: `\u0015`, 0x16: `\u0016`, 0x17: `\u0017`, 0x18: `\u0018`, 0x19: `\u0019`,
0x1a: `\u001a`, 0x1b: `\u001b`, 0x1c: `\u001c`, 0x1d: `\u001d`, 0x1e: `\u001e`, 0x1f: `\u001f`,
}
var jsStrReplacementTable = []string{
0: `\u0000`,
'\t': `\t`,
'\n': `\n`,
'\v': `\u000b`, // "\v" == "v" on IE 6.
'\f': `\f`,
'\r': `\r`,
// Encode HTML specials as hex so the output can be embedded
// in HTML attributes without further encoding.
'"': `\u0022`,
'`': `\u0060`,
'&': `\u0026`,
'\'': `\u0027`,
'+': `\u002b`,
'/': `\/`,
'<': `\u003c`,
'>': `\u003e`,
'\\': `\\`,
}
// jsBqStrReplacementTable is like jsStrReplacementTable except it also contains
// the special characters for JS template literals: $, {, and }.
var jsBqStrReplacementTable = []string{
0: `\u0000`,
'\t': `\t`,
'\n': `\n`,
'\v': `\u000b`, // "\v" == "v" on IE 6.
'\f': `\f`,
'\r': `\r`,
// Encode HTML specials as hex so the output can be embedded
// in HTML attributes without further encoding.
'"': `\u0022`,
'`': `\u0060`,
'&': `\u0026`,
'\'': `\u0027`,
'+': `\u002b`,
'/': `\/`,
'<': `\u003c`,
'>': `\u003e`,
'\\': `\\`,
'$': `\u0024`,
'{': `\u007b`,
'}': `\u007d`,
}
// jsStrNormReplacementTable is like jsStrReplacementTable but does not
// overencode existing escapes since this table has no entry for `\`.
var jsStrNormReplacementTable = []string{
0: `\u0000`,
'\t': `\t`,
'\n': `\n`,
'\v': `\u000b`, // "\v" == "v" on IE 6.
'\f': `\f`,
'\r': `\r`,
// Encode HTML specials as hex so the output can be embedded
// in HTML attributes without further encoding.
'"': `\u0022`,
'&': `\u0026`,
'\'': `\u0027`,
'`': `\u0060`,
'+': `\u002b`,
'/': `\/`,
'<': `\u003c`,
'>': `\u003e`,
}
var jsRegexpReplacementTable = []string{
0: `\u0000`,
'\t': `\t`,
'\n': `\n`,
'\v': `\u000b`, // "\v" == "v" on IE 6.
'\f': `\f`,
'\r': `\r`,
// Encode HTML specials as hex so the output can be embedded
// in HTML attributes without further encoding.
'"': `\u0022`,
'$': `\$`,
'&': `\u0026`,
'\'': `\u0027`,
'(': `\(`,
')': `\)`,
'*': `\*`,
'+': `\u002b`,
'-': `\-`,
'.': `\.`,
'/': `\/`,
'<': `\u003c`,
'>': `\u003e`,
'?': `\?`,
'[': `\[`,
'\\': `\\`,
']': `\]`,
'^': `\^`,
'{': `\{`,
'|': `\|`,
'}': `\}`,
}
// isJSIdentPart reports whether the given rune is a JS identifier part.
// It does not handle all the non-Latin letters, joiners, and combining marks,
// but it does handle every codepoint that can occur in a numeric literal or
// a keyword.
func isJSIdentPart(r rune) bool {
switch {
case r == '$':
return true
case '0' <= r && r <= '9':
return true
case 'A' <= r && r <= 'Z':
return true
case r == '_':
return true
case 'a' <= r && r <= 'z':
return true
}
return false
}
// isJSType reports whether the given MIME type should be considered JavaScript.
//
// It is used to determine whether a script tag with a type attribute is a javascript container.
func isJSType(mimeType string) bool {
// per
// https://www.w3.org/TR/html5/scripting-1.html#attr-script-type
// https://tools.ietf.org/html/rfc7231#section-3.1.1
// https://tools.ietf.org/html/rfc4329#section-3
// https://www.ietf.org/rfc/rfc4627.txt
// discard parameters
mimeType, _, _ = strings.Cut(mimeType, ";")
mimeType = strings.ToLower(mimeType)
mimeType = strings.TrimSpace(mimeType)
switch mimeType {
case
"application/ecmascript",
"application/javascript",
"application/json",
"application/ld+json",
"application/x-ecmascript",
"application/x-javascript",
"module",
"text/ecmascript",
"text/javascript",
"text/javascript1.0",
"text/javascript1.1",
"text/javascript1.2",
"text/javascript1.3",
"text/javascript1.4",
"text/javascript1.5",
"text/jscript",
"text/livescript",
"text/x-ecmascript",
"text/x-javascript":
return true
default:
return false
}
}
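// Illustrative sketch (not part of the original source): media type
// parameters are stripped before matching, so
//
// isJSType("text/javascript; charset=utf-8") // true
// isJSType("module")                         // true
// isJSType("text/template")                  // false: the script body is not escaped as JS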
// Code generated by "stringer -type jsCtx"; DO NOT EDIT.
package template
import "strconv"
func _() {
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
var x [1]struct{}
_ = x[jsCtxRegexp-0]
_ = x[jsCtxDivOp-1]
_ = x[jsCtxUnknown-2]
}
const _jsCtx_name = "jsCtxRegexpjsCtxDivOpjsCtxUnknown"
var _jsCtx_index = [...]uint8{0, 11, 21, 33}
func (i jsCtx) String() string {
if i >= jsCtx(len(_jsCtx_index)-1) {
return "jsCtx(" + strconv.FormatInt(int64(i), 10) + ")"
}
return _jsCtx_name[_jsCtx_index[i]:_jsCtx_index[i+1]]
}
// Code generated by "stringer -type state"; DO NOT EDIT.
package template
import "strconv"
func _() {
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
var x [1]struct{}
_ = x[stateText-0]
_ = x[stateTag-1]
_ = x[stateAttrName-2]
_ = x[stateAfterName-3]
_ = x[stateBeforeValue-4]
_ = x[stateHTMLCmt-5]
_ = x[stateRCDATA-6]
_ = x[stateAttr-7]
_ = x[stateURL-8]
_ = x[stateSrcset-9]
_ = x[stateJS-10]
_ = x[stateJSDqStr-11]
_ = x[stateJSSqStr-12]
_ = x[stateJSTmplLit-13]
_ = x[stateJSRegexp-14]
_ = x[stateJSBlockCmt-15]
_ = x[stateJSLineCmt-16]
_ = x[stateJSHTMLOpenCmt-17]
_ = x[stateJSHTMLCloseCmt-18]
_ = x[stateCSS-19]
_ = x[stateCSSDqStr-20]
_ = x[stateCSSSqStr-21]
_ = x[stateCSSDqURL-22]
_ = x[stateCSSSqURL-23]
_ = x[stateCSSURL-24]
_ = x[stateCSSBlockCmt-25]
_ = x[stateCSSLineCmt-26]
_ = x[stateError-27]
_ = x[stateDead-28]
}
const _state_name = "stateTextstateTagstateAttrNamestateAfterNamestateBeforeValuestateHTMLCmtstateRCDATAstateAttrstateURLstateSrcsetstateJSstateJSDqStrstateJSSqStrstateJSTmplLitstateJSRegexpstateJSBlockCmtstateJSLineCmtstateJSHTMLOpenCmtstateJSHTMLCloseCmtstateCSSstateCSSDqStrstateCSSSqStrstateCSSDqURLstateCSSSqURLstateCSSURLstateCSSBlockCmtstateCSSLineCmtstateErrorstateDead"
var _state_index = [...]uint16{0, 9, 17, 30, 44, 60, 72, 83, 92, 100, 111, 118, 130, 142, 156, 169, 184, 198, 216, 235, 243, 256, 269, 282, 295, 306, 322, 337, 347, 356}
func (i state) String() string {
if i >= state(len(_state_index)-1) {
return "state(" + strconv.FormatInt(int64(i), 10) + ")"
}
return _state_name[_state_index[i]:_state_index[i+1]]
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package template
import (
"fmt"
"io"
"io/fs"
"os"
"path"
"path/filepath"
"sync"
"text/template"
"text/template/parse"
)
// Template is a specialized Template from "text/template" that produces a safe
// HTML document fragment.
type Template struct {
// Sticky error if escaping fails, or escapeOK if succeeded.
escapeErr error
// We could embed the text/template field, but it's safer not to because
// we need to keep our version of the name space and the underlying
// template's in sync.
text *template.Template
// The underlying template's parse tree, updated to be HTML-safe.
Tree *parse.Tree
*nameSpace // common to all associated templates
}
// escapeOK is a sentinel value used to indicate valid escaping.
var escapeOK = fmt.Errorf("template escaped correctly")
// nameSpace is the data structure shared by all templates in an association.
type nameSpace struct {
mu sync.Mutex
set map[string]*Template
escaped bool
esc escaper
}
// Templates returns a slice of the templates associated with t, including t
// itself.
func (t *Template) Templates() []*Template {
ns := t.nameSpace
ns.mu.Lock()
defer ns.mu.Unlock()
// Return a slice so we don't expose the map.
m := make([]*Template, 0, len(ns.set))
for _, v := range ns.set {
m = append(m, v)
}
return m
}
// Option sets options for the template. Options are described by
// strings, either a simple string or "key=value". There can be at
// most one equals sign in an option string. If the option string
// is unrecognized or otherwise invalid, Option panics.
//
// Known options:
//
// missingkey: Control the behavior during execution if a map is
// indexed with a key that is not present in the map.
//
// "missingkey=default" or "missingkey=invalid"
// The default behavior: Do nothing and continue execution.
// If printed, the result of the index operation is the string
// "<no value>".
// "missingkey=zero"
// The operation returns the zero value for the map type's element.
// "missingkey=error"
// Execution stops immediately with an error.
func (t *Template) Option(opt ...string) *Template {
t.text.Option(opt...)
return t
}
// checkCanParse checks whether it is OK to parse templates.
// If not, it returns an error.
func (t *Template) checkCanParse() error {
if t == nil {
return nil
}
t.nameSpace.mu.Lock()
defer t.nameSpace.mu.Unlock()
if t.nameSpace.escaped {
return fmt.Errorf("html/template: cannot Parse after Execute")
}
return nil
}
// escape escapes all associated templates.
func (t *Template) escape() error {
t.nameSpace.mu.Lock()
defer t.nameSpace.mu.Unlock()
t.nameSpace.escaped = true
if t.escapeErr == nil {
if t.Tree == nil {
return fmt.Errorf("template: %q is an incomplete or empty template", t.Name())
}
if err := escapeTemplate(t, t.text.Root, t.Name()); err != nil {
return err
}
} else if t.escapeErr != escapeOK {
return t.escapeErr
}
return nil
}
// Execute applies a parsed template to the specified data object,
// writing the output to wr.
// If an error occurs executing the template or writing its output,
// execution stops, but partial results may already have been written to
// the output writer.
// A template may be executed safely in parallel, although if parallel
// executions share a Writer the output may be interleaved.
func (t *Template) Execute(wr io.Writer, data any) error {
if err := t.escape(); err != nil {
return err
}
return t.text.Execute(wr, data)
}
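// Illustrative sketch (not part of the original source): a typical call site,
// assuming os.Stdout as the io.Writer and a hypothetical data struct:
//
// t := Must(New("link").Parse(`<a href="{{.URL}}">{{.Label}}</a>`))
// err := t.Execute(os.Stdout, struct{ URL, Label string }{
//     URL:   "https://example.com/?q=a&b",
//     Label: "<click me>",
// })
// // The label is HTML-escaped and the URL is escaped for its attribute context.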
// ExecuteTemplate applies the template associated with t that has the given
// name to the specified data object and writes the output to wr.
// If an error occurs executing the template or writing its output,
// execution stops, but partial results may already have been written to
// the output writer.
// A template may be executed safely in parallel, although if parallel
// executions share a Writer the output may be interleaved.
func (t *Template) ExecuteTemplate(wr io.Writer, name string, data any) error {
tmpl, err := t.lookupAndEscapeTemplate(name)
if err != nil {
return err
}
return tmpl.text.Execute(wr, data)
}
// lookupAndEscapeTemplate guarantees that the template with the given name
// is escaped, or returns an error if it cannot be. It returns the named
// template.
func (t *Template) lookupAndEscapeTemplate(name string) (tmpl *Template, err error) {
t.nameSpace.mu.Lock()
defer t.nameSpace.mu.Unlock()
t.nameSpace.escaped = true
tmpl = t.set[name]
if tmpl == nil {
return nil, fmt.Errorf("html/template: %q is undefined", name)
}
if tmpl.escapeErr != nil && tmpl.escapeErr != escapeOK {
return nil, tmpl.escapeErr
}
if tmpl.text.Tree == nil || tmpl.text.Root == nil {
return nil, fmt.Errorf("html/template: %q is an incomplete template", name)
}
if t.text.Lookup(name) == nil {
panic("html/template internal error: template escaping out of sync")
}
if tmpl.escapeErr == nil {
err = escapeTemplate(tmpl, tmpl.text.Root, name)
}
return tmpl, err
}
// DefinedTemplates returns a string listing the defined templates,
// prefixed by the string "; defined templates are: ". If there are none,
// it returns the empty string. Used to generate an error message.
func (t *Template) DefinedTemplates() string {
return t.text.DefinedTemplates()
}
// Parse parses text as a template body for t.
// Named template definitions ({{define ...}} or {{block ...}} statements) in text
// define additional templates associated with t and are removed from the
// definition of t itself.
//
// Templates can be redefined in successive calls to Parse,
// before the first use of [Template.Execute] on t or any associated template.
// A template definition with a body containing only white space and comments
// is considered empty and will not replace an existing template's body.
// This allows using Parse to add new named template definitions without
// overwriting the main template body.
func (t *Template) Parse(text string) (*Template, error) {
if err := t.checkCanParse(); err != nil {
return nil, err
}
ret, err := t.text.Parse(text)
if err != nil {
return nil, err
}
// In general, all the named templates might have changed underfoot.
// Regardless, some new ones may have been defined.
// The template.Template set has been updated; update ours.
t.nameSpace.mu.Lock()
defer t.nameSpace.mu.Unlock()
for _, v := range ret.Templates() {
name := v.Name()
tmpl := t.set[name]
if tmpl == nil {
tmpl = t.new(name)
}
tmpl.text = v
tmpl.Tree = v.Tree
}
return t, nil
}
// AddParseTree creates a new template with the name and parse tree
// and associates it with t.
//
// It returns an error if t or any associated template has already been executed.
func (t *Template) AddParseTree(name string, tree *parse.Tree) (*Template, error) {
if err := t.checkCanParse(); err != nil {
return nil, err
}
t.nameSpace.mu.Lock()
defer t.nameSpace.mu.Unlock()
text, err := t.text.AddParseTree(name, tree)
if err != nil {
return nil, err
}
ret := &Template{
nil,
text,
text.Tree,
t.nameSpace,
}
t.set[name] = ret
return ret, nil
}
// Clone returns a duplicate of the template, including all associated
// templates. The actual representation is not copied, but the name space of
// associated templates is, so further calls to [Template.Parse] in the copy will add
// templates to the copy but not to the original. [Template.Clone] can be used to prepare
// common templates and use them with variant definitions for other templates
// by adding the variants after the clone is made.
//
// It returns an error if t has already been executed.
func (t *Template) Clone() (*Template, error) {
t.nameSpace.mu.Lock()
defer t.nameSpace.mu.Unlock()
if t.escapeErr != nil {
return nil, fmt.Errorf("html/template: cannot Clone %q after it has executed", t.Name())
}
textClone, err := t.text.Clone()
if err != nil {
return nil, err
}
ns := &nameSpace{set: make(map[string]*Template)}
ns.esc = makeEscaper(ns)
ret := &Template{
nil,
textClone,
textClone.Tree,
ns,
}
ret.set[ret.Name()] = ret
for _, x := range textClone.Templates() {
name := x.Name()
src := t.set[name]
if src == nil || src.escapeErr != nil {
return nil, fmt.Errorf("html/template: cannot Clone %q after it has executed", t.Name())
}
x.Tree = x.Tree.Copy()
ret.set[name] = &Template{
nil,
x,
x.Tree,
ret.nameSpace,
}
}
// Return the template associated with the name of this template.
return ret.set[ret.Name()], nil
}
// New allocates a new HTML template with the given name.
func New(name string) *Template {
ns := &nameSpace{set: make(map[string]*Template)}
ns.esc = makeEscaper(ns)
tmpl := &Template{
nil,
template.New(name),
nil,
ns,
}
tmpl.set[name] = tmpl
return tmpl
}
// New allocates a new HTML template associated with the given one
// and with the same delimiters. The association, which is transitive,
// allows one template to invoke another with a {{template}} action.
//
// If a template with the given name already exists, the new HTML template
// will replace it. The existing template will be reset and disassociated with
// t.
func (t *Template) New(name string) *Template {
t.nameSpace.mu.Lock()
defer t.nameSpace.mu.Unlock()
return t.new(name)
}
// new is the implementation of New, without the lock.
func (t *Template) new(name string) *Template {
tmpl := &Template{
nil,
t.text.New(name),
nil,
t.nameSpace,
}
if existing, ok := tmpl.set[name]; ok {
emptyTmpl := New(existing.Name())
*existing = *emptyTmpl
}
tmpl.set[name] = tmpl
return tmpl
}
// Name returns the name of the template.
func (t *Template) Name() string {
return t.text.Name()
}
// FuncMap is the type of the map defining the mapping from names to
// functions. It is an alias of text/template's FuncMap, re-exported here so
// that clients need not import "text/template".
type FuncMap = template.FuncMap
// Funcs adds the elements of the argument map to the template's function map.
// It must be called before the template is parsed.
// It panics if a value in the map is not a function with appropriate return
// type. However, it is legal to overwrite elements of the map. The return
// value is the template, so calls can be chained.
func (t *Template) Funcs(funcMap FuncMap) *Template {
t.text.Funcs(template.FuncMap(funcMap))
return t
}
// Delims sets the action delimiters to the specified strings, to be used in
// subsequent calls to [Template.Parse], [ParseFiles], or [ParseGlob]. Nested template
// definitions will inherit the settings. An empty delimiter stands for the
// corresponding default: {{ or }}.
// The return value is the template, so calls can be chained.
func (t *Template) Delims(left, right string) *Template {
t.text.Delims(left, right)
return t
}
// Lookup returns the template with the given name that is associated with t,
// or nil if there is no such template.
func (t *Template) Lookup(name string) *Template {
t.nameSpace.mu.Lock()
defer t.nameSpace.mu.Unlock()
return t.set[name]
}
// Must is a helper that wraps a call to a function returning ([*Template], error)
// and panics if the error is non-nil. It is intended for use in variable initializations
// such as
//
// var t = template.Must(template.New("name").Parse("html"))
func Must(t *Template, err error) *Template {
if err != nil {
panic(err)
}
return t
}
// ParseFiles creates a new [Template] and parses the template definitions from
// the named files. The returned template's name will have the (base) name and
// (parsed) contents of the first file. There must be at least one file.
// If an error occurs, parsing stops and the returned [*Template] is nil.
//
// When parsing multiple files with the same name in different directories,
// the last one mentioned will be the one that results.
// For instance, ParseFiles("a/foo", "b/foo") stores "b/foo" as the template
// named "foo", while "a/foo" is unavailable.
func ParseFiles(filenames ...string) (*Template, error) {
return parseFiles(nil, readFileOS, filenames...)
}
// ParseFiles parses the named files and associates the resulting templates with
// t. If an error occurs, parsing stops and the returned template is nil;
// otherwise it is t. There must be at least one file.
//
// When parsing multiple files with the same name in different directories,
// the last one mentioned will be the one that results.
//
// ParseFiles returns an error if t or any associated template has already been executed.
func (t *Template) ParseFiles(filenames ...string) (*Template, error) {
return parseFiles(t, readFileOS, filenames...)
}
// parseFiles is the helper for the method and function. If the argument
// template is nil, it is created from the first file.
func parseFiles(t *Template, readFile func(string) (string, []byte, error), filenames ...string) (*Template, error) {
if err := t.checkCanParse(); err != nil {
return nil, err
}
if len(filenames) == 0 {
// Not really a problem, but be consistent.
return nil, fmt.Errorf("html/template: no files named in call to ParseFiles")
}
for _, filename := range filenames {
name, b, err := readFile(filename)
if err != nil {
return nil, err
}
s := string(b)
// First template becomes return value if not already defined,
// and we use that one for subsequent New calls to associate
// all the templates together. Also, if this file has the same name
// as t, this file becomes the contents of t, so
// t, err := New(name).Funcs(xxx).ParseFiles(name)
// works. Otherwise we create a new template associated with t.
var tmpl *Template
if t == nil {
t = New(name)
}
if name == t.Name() {
tmpl = t
} else {
tmpl = t.New(name)
}
_, err = tmpl.Parse(s)
if err != nil {
return nil, err
}
}
return t, nil
}
// ParseGlob creates a new [Template] and parses the template definitions from
// the files identified by the pattern. The files are matched according to the
// semantics of filepath.Match, and the pattern must match at least one file.
// The returned template will have the (base) name and (parsed) contents of the
// first file matched by the pattern. ParseGlob is equivalent to calling
// [ParseFiles] with the list of files matched by the pattern.
//
// When parsing multiple files with the same name in different directories,
// the last one mentioned will be the one that results.
func ParseGlob(pattern string) (*Template, error) {
return parseGlob(nil, pattern)
}
// ParseGlob parses the template definitions in the files identified by the
// pattern and associates the resulting templates with t. The files are matched
// according to the semantics of filepath.Match, and the pattern must match at
// least one file. ParseGlob is equivalent to calling t.ParseFiles with the
// list of files matched by the pattern.
//
// When parsing multiple files with the same name in different directories,
// the last one mentioned will be the one that results.
//
// ParseGlob returns an error if t or any associated template has already been executed.
func (t *Template) ParseGlob(pattern string) (*Template, error) {
return parseGlob(t, pattern)
}
// parseGlob is the implementation of the function and method ParseGlob.
func parseGlob(t *Template, pattern string) (*Template, error) {
if err := t.checkCanParse(); err != nil {
return nil, err
}
filenames, err := filepath.Glob(pattern)
if err != nil {
return nil, err
}
if len(filenames) == 0 {
return nil, fmt.Errorf("html/template: pattern matches no files: %#q", pattern)
}
return parseFiles(t, readFileOS, filenames...)
}
// IsTrue reports whether the value is 'true', in the sense of not the zero of its type,
// and whether the value has a meaningful truth value. This is the definition of
// truth used by if and other such actions.
func IsTrue(val any) (truth, ok bool) {
return template.IsTrue(val)
}
// ParseFS is like [ParseFiles] or [ParseGlob] but reads from the file system fs
// instead of the host operating system's file system.
// It accepts a list of glob patterns.
// (Note that most file names serve as glob patterns matching only themselves.)
func ParseFS(fs fs.FS, patterns ...string) (*Template, error) {
return parseFS(nil, fs, patterns)
}
// ParseFS is like [Template.ParseFiles] or [Template.ParseGlob] but reads from the file system fs
// instead of the host operating system's file system.
// It accepts a list of glob patterns.
// (Note that most file names serve as glob patterns matching only themselves.)
func (t *Template) ParseFS(fs fs.FS, patterns ...string) (*Template, error) {
return parseFS(t, fs, patterns)
}
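// Illustrative sketch (not part of the original source), assuming a
// hypothetical embedded file system named content holding a templates
// directory:
//
// //go:embed templates/*.html
// var content embed.FS
//
// var pages = Must(ParseFS(content, "templates/*.html"))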
func parseFS(t *Template, fsys fs.FS, patterns []string) (*Template, error) {
var filenames []string
for _, pattern := range patterns {
list, err := fs.Glob(fsys, pattern)
if err != nil {
return nil, err
}
if len(list) == 0 {
return nil, fmt.Errorf("template: pattern matches no files: %#q", pattern)
}
filenames = append(filenames, list...)
}
return parseFiles(t, readFileFS(fsys), filenames...)
}
func readFileOS(file string) (name string, b []byte, err error) {
name = filepath.Base(file)
b, err = os.ReadFile(file)
return
}
func readFileFS(fsys fs.FS) func(string) (string, []byte, error) {
return func(file string) (name string, b []byte, err error) {
name = path.Base(file)
b, err = fs.ReadFile(fsys, file)
return
}
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package template
import (
"bytes"
"strings"
)
// transitionFunc is the array of context transition functions for text nodes.
// A transition function takes a context and template text input, and returns
// the updated context and the number of bytes consumed from the front of the
// input.
var transitionFunc = [...]func(context, []byte) (context, int){
stateText: tText,
stateTag: tTag,
stateAttrName: tAttrName,
stateAfterName: tAfterName,
stateBeforeValue: tBeforeValue,
stateHTMLCmt: tHTMLCmt,
stateRCDATA: tSpecialTagEnd,
stateAttr: tAttr,
stateURL: tURL,
stateSrcset: tURL,
stateJS: tJS,
stateJSDqStr: tJSDelimited,
stateJSSqStr: tJSDelimited,
stateJSRegexp: tJSDelimited,
stateJSTmplLit: tJSTmpl,
stateJSBlockCmt: tBlockCmt,
stateJSLineCmt: tLineCmt,
stateJSHTMLOpenCmt: tLineCmt,
stateJSHTMLCloseCmt: tLineCmt,
stateCSS: tCSS,
stateCSSDqStr: tCSSStr,
stateCSSSqStr: tCSSStr,
stateCSSDqURL: tCSSStr,
stateCSSSqURL: tCSSStr,
stateCSSURL: tCSSStr,
stateCSSBlockCmt: tBlockCmt,
stateCSSLineCmt: tLineCmt,
stateError: tError,
}
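// Illustrative sketch (not part of the original source): callers drive these
// functions in a loop, feeding each one the unconsumed tail of the input and
// threading the context through, along the lines of
//
// c := context{state: stateText}
// for len(s) != 0 {
//     var n int
//     c, n = transitionFunc[c.state](c, s)
//     s = s[n:]
// }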
var commentStart = []byte("<!--")
var commentEnd = []byte("-->")
// tText is the context transition function for the text state.
func tText(c context, s []byte) (context, int) {
k := 0
for {
i := k + bytes.IndexByte(s[k:], '<')
if i < k || i+1 == len(s) {
return c, len(s)
} else if i+4 <= len(s) && bytes.Equal(commentStart, s[i:i+4]) {
return context{state: stateHTMLCmt}, i + 4
}
i++
end := false
if s[i] == '/' {
if i+1 == len(s) {
return c, len(s)
}
end, i = true, i+1
}
j, e := eatTagName(s, i)
if j != i {
if end {
e = elementNone
}
// We've found an HTML tag.
return context{state: stateTag, element: e}, j
}
k = j
}
}
var elementContentType = [...]state{
elementNone: stateText,
elementScript: stateJS,
elementStyle: stateCSS,
elementTextarea: stateRCDATA,
elementTitle: stateRCDATA,
}
// tTag is the context transition function for the tag state.
func tTag(c context, s []byte) (context, int) {
// Find the attribute name.
i := eatWhiteSpace(s, 0)
if i == len(s) {
return c, len(s)
}
if s[i] == '>' {
return context{
state: elementContentType[c.element],
element: c.element,
}, i + 1
}
j, err := eatAttrName(s, i)
if err != nil {
return context{state: stateError, err: err}, len(s)
}
state, attr := stateTag, attrNone
if i == j {
return context{
state: stateError,
err: errorf(ErrBadHTML, nil, 0, "expected space, attr name, or end of tag, but got %q", s[i:]),
}, len(s)
}
attrName := strings.ToLower(string(s[i:j]))
if c.element == elementScript && attrName == "type" {
attr = attrScriptType
} else {
switch attrType(attrName) {
case contentTypeURL:
attr = attrURL
case contentTypeCSS:
attr = attrStyle
case contentTypeJS:
attr = attrScript
case contentTypeSrcset:
attr = attrSrcset
}
}
if j == len(s) {
state = stateAttrName
} else {
state = stateAfterName
}
return context{state: state, element: c.element, attr: attr}, j
}
// tAttrName is the context transition function for stateAttrName.
func tAttrName(c context, s []byte) (context, int) {
i, err := eatAttrName(s, 0)
if err != nil {
return context{state: stateError, err: err}, len(s)
} else if i != len(s) {
c.state = stateAfterName
}
return c, i
}
// tAfterName is the context transition function for stateAfterName.
func tAfterName(c context, s []byte) (context, int) {
// Look for the start of the value.
i := eatWhiteSpace(s, 0)
if i == len(s) {
return c, len(s)
} else if s[i] != '=' {
// Occurs due to tag ending '>', and valueless attribute.
c.state = stateTag
return c, i
}
c.state = stateBeforeValue
// Consume the "=".
return c, i + 1
}
var attrStartStates = [...]state{
attrNone: stateAttr,
attrScript: stateJS,
attrScriptType: stateAttr,
attrStyle: stateCSS,
attrURL: stateURL,
attrSrcset: stateSrcset,
}
// tBeforeValue is the context transition function for stateBeforeValue.
func tBeforeValue(c context, s []byte) (context, int) {
i := eatWhiteSpace(s, 0)
if i == len(s) {
return c, len(s)
}
// Find the attribute delimiter.
delim := delimSpaceOrTagEnd
switch s[i] {
case '\'':
delim, i = delimSingleQuote, i+1
case '"':
delim, i = delimDoubleQuote, i+1
}
c.state, c.delim = attrStartStates[c.attr], delim
return c, i
}
// tHTMLCmt is the context transition function for stateHTMLCmt.
func tHTMLCmt(c context, s []byte) (context, int) {
if i := bytes.Index(s, commentEnd); i != -1 {
return context{}, i + 3
}
return c, len(s)
}
// specialTagEndMarkers maps element types to the character sequence that
// case-insensitively signals the end of the special tag body.
var specialTagEndMarkers = [...][]byte{
elementScript: []byte("script"),
elementStyle: []byte("style"),
elementTextarea: []byte("textarea"),
elementTitle: []byte("title"),
}
var (
specialTagEndPrefix = []byte("</")
tagEndSeparators = []byte("> \t\n\f/")
)
// tSpecialTagEnd is the context transition function for raw text and RCDATA
// element states.
func tSpecialTagEnd(c context, s []byte) (context, int) {
if c.element != elementNone {
// script end tags ("</script") within script literals are ignored, so that
// we can properly escape them.
if c.element == elementScript && (isInScriptLiteral(c.state) || isComment(c.state)) {
return c, len(s)
}
if i := indexTagEnd(s, specialTagEndMarkers[c.element]); i != -1 {
return context{}, i
}
}
return c, len(s)
}
// indexTagEnd finds the index of a special tag end in a case-insensitive way, or returns -1 if there is none.
func indexTagEnd(s []byte, tag []byte) int {
res := 0
plen := len(specialTagEndPrefix)
for len(s) > 0 {
// Try to find the tag end prefix first
i := bytes.Index(s, specialTagEndPrefix)
if i == -1 {
return i
}
s = s[i+plen:]
// Try to match the actual tag if there is still space for it
if len(tag) <= len(s) && bytes.EqualFold(tag, s[:len(tag)]) {
s = s[len(tag):]
// Check the tag is followed by a proper separator
if len(s) > 0 && bytes.IndexByte(tagEndSeparators, s[0]) != -1 {
return res + i
}
res += len(tag)
}
res += i + plen
}
return -1
}
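// Illustrative sketch (not part of the original source):
//
// indexTagEnd([]byte("x = 1 </SCRIPT >"), []byte("script"))   // 6: matched case-insensitively
// indexTagEnd([]byte("alert('</scriptx')"), []byte("script")) // -1: no separator after the name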
// tAttr is the context transition function for the attribute state.
func tAttr(c context, s []byte) (context, int) {
return c, len(s)
}
// tURL is the context transition function for the URL state.
func tURL(c context, s []byte) (context, int) {
if bytes.ContainsAny(s, "#?") {
c.urlPart = urlPartQueryOrFrag
} else if len(s) != eatWhiteSpace(s, 0) && c.urlPart == urlPartNone {
// HTML5 uses "Valid URL potentially surrounded by spaces" for
// attrs: https://www.w3.org/TR/html5/index.html#attributes-1
c.urlPart = urlPartPreQuery
}
return c, len(s)
}
// tJS is the context transition function for the JS state.
func tJS(c context, s []byte) (context, int) {
i := bytes.IndexAny(s, "\"`'/{}<-#")
if i == -1 {
// Entire input is non string, comment, regexp tokens.
c.jsCtx = nextJSCtx(s, c.jsCtx)
return c, len(s)
}
c.jsCtx = nextJSCtx(s[:i], c.jsCtx)
switch s[i] {
case '"':
c.state, c.jsCtx = stateJSDqStr, jsCtxRegexp
case '\'':
c.state, c.jsCtx = stateJSSqStr, jsCtxRegexp
case '`':
c.state, c.jsCtx = stateJSTmplLit, jsCtxRegexp
case '/':
switch {
case i+1 < len(s) && s[i+1] == '/':
c.state, i = stateJSLineCmt, i+1
case i+1 < len(s) && s[i+1] == '*':
c.state, i = stateJSBlockCmt, i+1
case c.jsCtx == jsCtxRegexp:
c.state = stateJSRegexp
case c.jsCtx == jsCtxDivOp:
c.jsCtx = jsCtxRegexp
default:
return context{
state: stateError,
err: errorf(ErrSlashAmbig, nil, 0, "'/' could start a division or regexp: %.32q", s[i:]),
}, len(s)
}
// ECMAScript supports HTML style comments for legacy reasons, see Appendix
// B.1.1 "HTML-like Comments". The handling of these comments is somewhat
// confusing. Multi-line comments are not supported, i.e. anything on lines
// between the opening and closing tokens is not considered a comment, but
// anything following the opening or closing token, on the same line, is
// ignored. As such we simply treat any line prefixed with "<!--" or "-->"
// as if it were actually prefixed with "//" and move on.
case '<':
if i+3 < len(s) && bytes.Equal(commentStart, s[i:i+4]) {
c.state, i = stateJSHTMLOpenCmt, i+3
}
case '-':
if i+2 < len(s) && bytes.Equal(commentEnd, s[i:i+3]) {
c.state, i = stateJSHTMLCloseCmt, i+2
}
// ECMAScript also supports "hashbang" comment lines, see Section 12.5.
case '#':
if i+1 < len(s) && s[i+1] == '!' {
c.state, i = stateJSLineCmt, i+1
}
case '{':
// We only care about tracking brace depth if we are inside of a
// template literal.
if len(c.jsBraceDepth) == 0 {
return c, i + 1
}
c.jsBraceDepth[len(c.jsBraceDepth)-1]++
case '}':
if len(c.jsBraceDepth) == 0 {
return c, i + 1
}
// There are no cases where a brace can be escaped in the JS context
// that are not syntax errors, it seems. Because of this we can just
// count "\}" as "}" and move on, the script is already broken as
// fully fledged parsers will just fail anyway.
c.jsBraceDepth[len(c.jsBraceDepth)-1]--
if c.jsBraceDepth[len(c.jsBraceDepth)-1] >= 0 {
return c, i + 1
}
c.jsBraceDepth = c.jsBraceDepth[:len(c.jsBraceDepth)-1]
c.state = stateJSTmplLit
default:
panic("unreachable")
}
return c, i + 1
}
func tJSTmpl(c context, s []byte) (context, int) {
var k int
for {
i := k + bytes.IndexAny(s[k:], "`\\$")
if i < k {
break
}
switch s[i] {
case '\\':
i++
if i == len(s) {
return context{
state: stateError,
err: errorf(ErrPartialEscape, nil, 0, "unfinished escape sequence in JS string: %q", s),
}, len(s)
}
case '$':
if len(s) >= i+2 && s[i+1] == '{' {
c.jsBraceDepth = append(c.jsBraceDepth, 0)
c.state = stateJS
return c, i + 2
}
case '`':
// end
c.state = stateJS
return c, i + 1
}
k = i + 1
}
return c, len(s)
}
// tJSDelimited is the context transition function for the JS string and regexp
// states.
func tJSDelimited(c context, s []byte) (context, int) {
specials := `\"`
switch c.state {
case stateJSSqStr:
specials = `\'`
case stateJSRegexp:
specials = `\/[]`
}
k, inCharset := 0, false
for {
i := k + bytes.IndexAny(s[k:], specials)
if i < k {
break
}
switch s[i] {
case '\\':
i++
if i == len(s) {
return context{
state: stateError,
err: errorf(ErrPartialEscape, nil, 0, "unfinished escape sequence in JS string: %q", s),
}, len(s)
}
case '[':
inCharset = true
case ']':
inCharset = false
case '/':
// If "</script" appears in a regex literal, the '/' should not
// close the regex literal, and it will later be escaped to
// "\x3C/script" in escapeText.
if i > 0 && i+7 <= len(s) && bytes.Equal(bytes.ToLower(s[i-1:i+7]), []byte("</script")) {
i++
} else if !inCharset {
c.state, c.jsCtx = stateJS, jsCtxDivOp
return c, i + 1
}
default:
// end delimiter
if !inCharset {
c.state, c.jsCtx = stateJS, jsCtxDivOp
return c, i + 1
}
}
k = i + 1
}
if inCharset {
// This can be fixed by making context richer if interpolation
// into charsets is desired.
return context{
state: stateError,
err: errorf(ErrPartialCharset, nil, 0, "unfinished JS regexp charset: %q", s),
}, len(s)
}
return c, len(s)
}
var blockCommentEnd = []byte("*/")
// tBlockCmt is the context transition function for /*comment*/ states.
func tBlockCmt(c context, s []byte) (context, int) {
i := bytes.Index(s, blockCommentEnd)
if i == -1 {
return c, len(s)
}
switch c.state {
case stateJSBlockCmt:
c.state = stateJS
case stateCSSBlockCmt:
c.state = stateCSS
default:
panic(c.state.String())
}
return c, i + 2
}
// tLineCmt is the context transition function for //comment states, and the JS HTML-like comment state.
func tLineCmt(c context, s []byte) (context, int) {
var lineTerminators string
var endState state
switch c.state {
case stateJSLineCmt, stateJSHTMLOpenCmt, stateJSHTMLCloseCmt:
lineTerminators, endState = "\n\r\u2028\u2029", stateJS
case stateCSSLineCmt:
lineTerminators, endState = "\n\f\r", stateCSS
// Line comments are not part of any published CSS standard but
// are supported by the 4 major browsers.
// This defines line comments as
// LINECOMMENT ::= "//" [^\n\f\r]*
// since https://www.w3.org/TR/css3-syntax/#SUBTOK-nl defines
// newlines:
// nl ::= #xA | #xD #xA | #xD | #xC
default:
panic(c.state.String())
}
i := bytes.IndexAny(s, lineTerminators)
if i == -1 {
return c, len(s)
}
c.state = endState
// Per section 7.4 of EcmaScript 5 : https://es5.github.io/#x7.4
// "However, the LineTerminator at the end of the line is not
// considered to be part of the single-line comment; it is
// recognized separately by the lexical grammar and becomes part
// of the stream of input elements for the syntactic grammar."
return c, i
}
// tCSS is the context transition function for the CSS state.
func tCSS(c context, s []byte) (context, int) {
// CSS quoted strings are almost never used except for:
// (1) URLs as in background: "/foo.png"
// (2) Multiword font-names as in font-family: "Times New Roman"
// (3) List separators in content values as in inline-lists:
// <style>
// ul.inlineList { list-style: none; padding:0 }
// ul.inlineList > li { display: inline }
// ul.inlineList > li:before { content: ", " }
// ul.inlineList > li:first-child:before { content: "" }
// </style>
// <ul class=inlineList><li>One<li>Two<li>Three</ul>
// (4) Attribute value selectors as in a[href="http://example.com/"]
//
// We conservatively treat all strings as URLs, but make some
// allowances to avoid confusion.
//
// In (1), our conservative assumption is justified.
// In (2), valid font names do not contain ':', '?', or '#', so our
// conservative assumption is fine since we will never transition past
// urlPartPreQuery.
// In (3), our protocol heuristic should not be tripped, and there
// should not be non-space content after a '?' or '#', so as long as
// we only %-encode RFC 3986 reserved characters we are ok.
// In (4), we should URL escape for URL attributes, and for others we
// have the attribute name available if our conservative assumption
// proves problematic for real code.
k := 0
for {
i := k + bytes.IndexAny(s[k:], `("'/`)
if i < k {
return c, len(s)
}
switch s[i] {
case '(':
// Look for url to the left.
p := bytes.TrimRight(s[:i], "\t\n\f\r ")
if endsWithCSSKeyword(p, "url") {
j := len(s) - len(bytes.TrimLeft(s[i+1:], "\t\n\f\r "))
switch {
case j != len(s) && s[j] == '"':
c.state, j = stateCSSDqURL, j+1
case j != len(s) && s[j] == '\'':
c.state, j = stateCSSSqURL, j+1
default:
c.state = stateCSSURL
}
return c, j
}
case '/':
if i+1 < len(s) {
switch s[i+1] {
case '/':
c.state = stateCSSLineCmt
return c, i + 2
case '*':
c.state = stateCSSBlockCmt
return c, i + 2
}
}
case '"':
c.state = stateCSSDqStr
return c, i + 1
case '\'':
c.state = stateCSSSqStr
return c, i + 1
}
k = i + 1
}
}
// tCSSStr is the context transition function for the CSS string and URL states.
func tCSSStr(c context, s []byte) (context, int) {
var endAndEsc string
switch c.state {
case stateCSSDqStr, stateCSSDqURL:
endAndEsc = `\"`
case stateCSSSqStr, stateCSSSqURL:
endAndEsc = `\'`
case stateCSSURL:
// Unquoted URLs end with a newline or close parenthesis.
// The below includes the wc (whitespace character) and nl.
endAndEsc = "\\\t\n\f\r )"
default:
panic(c.state.String())
}
k := 0
for {
i := k + bytes.IndexAny(s[k:], endAndEsc)
if i < k {
c, nread := tURL(c, decodeCSS(s[k:]))
return c, k + nread
}
if s[i] == '\\' {
i++
if i == len(s) {
return context{
state: stateError,
err: errorf(ErrPartialEscape, nil, 0, "unfinished escape sequence in CSS string: %q", s),
}, len(s)
}
} else {
c.state = stateCSS
return c, i + 1
}
c, _ = tURL(c, decodeCSS(s[:i+1]))
k = i + 1
}
}
// tError is the context transition function for the error state.
func tError(c context, s []byte) (context, int) {
return c, len(s)
}
// eatAttrName returns the largest j such that s[i:j] is an attribute name.
// It returns an error if s[i:] does not look like it begins with an
// attribute name, such as encountering a quote mark without a preceding
// equals sign.
func eatAttrName(s []byte, i int) (int, *Error) {
for j := i; j < len(s); j++ {
switch s[j] {
case ' ', '\t', '\n', '\f', '\r', '=', '>':
return j, nil
case '\'', '"', '<':
// These result in a parse warning in HTML5 and are
// indicative of serious problems if seen in an attr
// name in a template.
return -1, errorf(ErrBadHTML, nil, 0, "%q in attribute name: %.32q", s[j:j+1], s)
default:
// No-op.
}
}
return len(s), nil
}
var elementNameMap = map[string]element{
"script": elementScript,
"style": elementStyle,
"textarea": elementTextarea,
"title": elementTitle,
}
// asciiAlpha reports whether c is an ASCII letter.
func asciiAlpha(c byte) bool {
return 'A' <= c && c <= 'Z' || 'a' <= c && c <= 'z'
}
// asciiAlphaNum reports whether c is an ASCII letter or digit.
func asciiAlphaNum(c byte) bool {
return asciiAlpha(c) || '0' <= c && c <= '9'
}
// eatTagName returns the largest j such that s[i:j] is a tag name and the tag type.
func eatTagName(s []byte, i int) (int, element) {
if i == len(s) || !asciiAlpha(s[i]) {
return i, elementNone
}
j := i + 1
for j < len(s) {
x := s[j]
if asciiAlphaNum(x) {
j++
continue
}
// Allow "x-y" or "x:y" but not "x-", "-y", or "x--y".
if (x == ':' || x == '-') && j+1 < len(s) && asciiAlphaNum(s[j+1]) {
j += 2
continue
}
break
}
return j, elementNameMap[strings.ToLower(string(s[i:j]))]
}
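// Editorial sketch, not part of the original source: eatTagName applied to a
// byte slice positioned just after a '<'. The exampleEatTagName identifier is
// hypothetical. Names such as "x-y" are accepted, and known raw-text elements
// (script, style, textarea, title) are identified via elementNameMap.
func exampleEatTagName() (int, element) {
	return eatTagName([]byte("textarea rows=3>"), 0) // 8, elementTextarea
}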
// eatWhiteSpace returns the largest j such that s[i:j] is white space.
func eatWhiteSpace(s []byte, i int) int {
for j := i; j < len(s); j++ {
switch s[j] {
case ' ', '\t', '\n', '\f', '\r':
// No-op.
default:
return j
}
}
return len(s)
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package template
import (
"fmt"
"strings"
)
// urlFilter returns its input unless it contains an unsafe scheme in which
// case it defangs the entire URL.
//
// Schemes that cause unintended side effects that are irreversible without user
// interaction are considered unsafe. For example, clicking on a "javascript:"
// link can immediately trigger JavaScript code execution.
//
// This filter conservatively assumes that all schemes other than the following
// are unsafe:
// - http: Navigates to a new website, and may open a new window or tab.
// These side effects can be reversed by navigating back to the
// previous website, or closing the window or tab. No irreversible
// changes will take place without further user interaction with
// the new website.
// - https: Same as http.
// - mailto: Opens an email program and starts a new draft. This side effect
// is not irreversible until the user explicitly clicks send; it
// can be undone by closing the email program.
//
// To allow URLs containing other schemes to bypass this filter, developers must
// explicitly indicate that such a URL is expected and safe by encapsulating it
// in a template.URL value.
func urlFilter(args ...any) string {
s, t := stringify(args...)
if t == contentTypeURL {
return s
}
if !isSafeURL(s) {
return "#" + filterFailsafe
}
return s
}
// isSafeURL reports whether s is a relative URL or a URL whose scheme is one
// of http, https, or mailto.
func isSafeURL(s string) bool {
if protocol, _, ok := strings.Cut(s, ":"); ok && !strings.Contains(protocol, "/") {
if !strings.EqualFold(protocol, "http") && !strings.EqualFold(protocol, "https") && !strings.EqualFold(protocol, "mailto") {
return false
}
}
return true
}
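// Editorial sketch, not part of the original source: inputs accepted and
// rejected by isSafeURL. The exampleIsSafeURL identifier is hypothetical.
// Relative URLs and the http, https and mailto schemes pass; anything else,
// such as javascript:, is rejected, and urlFilter then replaces the URL with
// "#" + filterFailsafe.
func exampleIsSafeURL() {
	fmt.Println(isSafeURL("/path?q=1"))               // true: relative URL
	fmt.Println(isSafeURL("https://example.com/"))    // true: allowed scheme
	fmt.Println(isSafeURL("MAILTO:user@example.com")) // true: scheme match is case-insensitive
	fmt.Println(isSafeURL("javascript:alert(1)"))     // false: unsafe scheme
}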
// urlEscaper produces an output that can be embedded in a URL query.
// The output can be embedded in an HTML attribute without further escaping.
func urlEscaper(args ...any) string {
return urlProcessor(false, args...)
}
// urlNormalizer normalizes URL content so it can be embedded in a quote-delimited
// string or parenthesis delimited url(...).
// The normalizer does not encode all HTML specials. Specifically, it does not
// encode '&' so correct embedding in an HTML attribute requires escaping of
// '&' to '&amp;'.
func urlNormalizer(args ...any) string {
return urlProcessor(true, args...)
}
// urlProcessor normalizes (when norm is true) or escapes its input to produce
// a valid hierarchical or opaque URL part.
func urlProcessor(norm bool, args ...any) string {
s, t := stringify(args...)
if t == contentTypeURL {
norm = true
}
var b strings.Builder
if processURLOnto(s, norm, &b) {
return b.String()
}
return s
}
// processURLOnto appends a normalized URL corresponding to its input to b
// and reports whether the appended content differs from s.
func processURLOnto(s string, norm bool, b *strings.Builder) bool {
b.Grow(len(s) + 16)
written := 0
// The byte loop below assumes that all URLs use UTF-8 as the
// content-encoding. This is similar to the URI to IRI encoding scheme
// defined in section 3.1 of RFC 3987, and behaves the same as the
// EcmaScript builtin encodeURIComponent.
// It should not cause any misencoding of URLs in pages with
// Content-type: text/html;charset=UTF-8.
for i, n := 0, len(s); i < n; i++ {
c := s[i]
switch c {
// Single quote and parens are sub-delims in RFC 3986, but we
// escape them so the output can be embedded in single
// quoted attributes and unquoted CSS url(...) constructs.
// Single quotes are reserved in URLs, but are only used in
// the obsolete "mark" rule in an appendix in RFC 3986
// so can be safely encoded.
case '!', '#', '$', '&', '*', '+', ',', '/', ':', ';', '=', '?', '@', '[', ']':
if norm {
continue
}
// Unreserved according to RFC 3986 sec 2.3
// "For consistency, percent-encoded octets in the ranges of
// ALPHA (%41-%5A and %61-%7A), DIGIT (%30-%39), hyphen (%2D),
// period (%2E), underscore (%5F), or tilde (%7E) should not be
// created by URI producers."
case '-', '.', '_', '~':
continue
case '%':
// When normalizing do not re-encode valid escapes.
if norm && i+2 < len(s) && isHex(s[i+1]) && isHex(s[i+2]) {
continue
}
default:
// Unreserved according to RFC 3986 sec 2.3
if 'a' <= c && c <= 'z' {
continue
}
if 'A' <= c && c <= 'Z' {
continue
}
if '0' <= c && c <= '9' {
continue
}
}
b.WriteString(s[written:i])
fmt.Fprintf(b, "%%%02x", c)
written = i + 1
}
b.WriteString(s[written:])
return written != 0
}
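// Editorial sketch, not part of the original source: the practical difference
// between escaping and normalizing. The exampleURLProcessing identifier is
// hypothetical. urlEscaper percent-encodes reserved characters so the result
// can sit inside a query component, while urlNormalizer keeps reserved
// characters and existing %xx escapes intact and only encodes the rest.
func exampleURLProcessing() {
	fmt.Println(urlEscaper("a/b?c=d&e"))    // a%2fb%3fc%3dd%26e
	fmt.Println(urlNormalizer("a/b?c=d&e")) // a/b?c=d&e
	fmt.Println(urlNormalizer("50%25 off")) // 50%25%20off: the space is encoded, the %25 escape is kept
}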
// srcsetFilterAndEscaper filters and normalizes srcset values, which are
// comma-separated URLs followed by metadata.
func srcsetFilterAndEscaper(args ...any) string {
s, t := stringify(args...)
switch t {
case contentTypeSrcset:
return s
case contentTypeURL:
// Normalizing gets rid of all HTML whitespace
// which separate the image URL from its metadata.
var b strings.Builder
if processURLOnto(s, true, &b) {
s = b.String()
}
// Additionally, commas separate one source from another.
return strings.ReplaceAll(s, ",", "%2c")
}
var b strings.Builder
written := 0
for i := 0; i < len(s); i++ {
if s[i] == ',' {
filterSrcsetElement(s, written, i, &b)
b.WriteString(",")
written = i + 1
}
}
filterSrcsetElement(s, written, len(s), &b)
return b.String()
}
// Derived from https://play.golang.org/p/Dhmj7FORT5
const htmlSpaceAndASCIIAlnumBytes = "\x00\x36\x00\x00\x01\x00\xff\x03\xfe\xff\xff\x07\xfe\xff\xff\x07"
// isHTMLSpace is true iff c is a whitespace character per
// https://infra.spec.whatwg.org/#ascii-whitespace
func isHTMLSpace(c byte) bool {
return (c <= 0x20) && 0 != (htmlSpaceAndASCIIAlnumBytes[c>>3]&(1<<uint(c&0x7)))
}
// isHTMLSpaceOrASCIIAlnum reports whether c is an HTML whitespace character
// or an ASCII letter or digit.
func isHTMLSpaceOrASCIIAlnum(c byte) bool {
return (c < 0x80) && 0 != (htmlSpaceAndASCIIAlnumBytes[c>>3]&(1<<uint(c&0x7)))
}
func filterSrcsetElement(s string, left int, right int, b *strings.Builder) {
start := left
for start < right && isHTMLSpace(s[start]) {
start++
}
end := right
for i := start; i < right; i++ {
if isHTMLSpace(s[i]) {
end = i
break
}
}
if url := s[start:end]; isSafeURL(url) {
// If image metadata is only spaces or alnums then
// we don't need to URL normalize it.
metadataOk := true
for i := end; i < right; i++ {
if !isHTMLSpaceOrASCIIAlnum(s[i]) {
metadataOk = false
break
}
}
if metadataOk {
b.WriteString(s[left:start])
processURLOnto(url, true, b)
b.WriteString(s[end:right])
return
}
}
b.WriteString("#")
b.WriteString(filterFailsafe)
}
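// Editorial sketch, not part of the original source: srcsetFilterAndEscaper
// normalizes each comma-separated image candidate independently and defangs
// unsafe ones. The exampleSrcsetFilter identifier is hypothetical.
func exampleSrcsetFilter() {
	fmt.Println(srcsetFilterAndEscaper("/a.png 1x, /b.png 2x"))   // /a.png 1x, /b.png 2x (already safe, unchanged)
	fmt.Println(srcsetFilterAndEscaper("javascript:alert(1) 1x")) // "#" + filterFailsafe: the candidate is defanged
}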
// Code generated by "stringer -type urlPart"; DO NOT EDIT.
package template
import "strconv"
func _() {
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
var x [1]struct{}
_ = x[urlPartNone-0]
_ = x[urlPartPreQuery-1]
_ = x[urlPartQueryOrFrag-2]
_ = x[urlPartUnknown-3]
}
const _urlPart_name = "urlPartNoneurlPartPreQueryurlPartQueryOrFragurlPartUnknown"
var _urlPart_index = [...]uint8{0, 11, 26, 44, 58}
func (i urlPart) String() string {
if i >= urlPart(len(_urlPart_index)-1) {
return "urlPart(" + strconv.FormatInt(int64(i), 10) + ")"
}
return _urlPart_name[_urlPart_index[i]:_urlPart_index[i+1]]
}
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package image
import (
"bufio"
"errors"
"io"
"sync"
"sync/atomic"
)
// ErrFormat indicates that decoding encountered an unknown format.
var ErrFormat = errors.New("image: unknown format")
// A format holds an image format's name, magic header and how to decode it.
type format struct {
name, magic string
decode func(io.Reader) (Image, error)
decodeConfig func(io.Reader) (Config, error)
}
// Formats is the list of registered formats.
var (
formatsMu sync.Mutex
atomicFormats atomic.Value
)
// RegisterFormat registers an image format for use by [Decode].
// Name is the name of the format, like "jpeg" or "png".
// Magic is the magic prefix that identifies the format's encoding. The magic
// string can contain "?" wildcards that each match any one byte.
// [Decode] is the function that decodes the encoded image.
// [DecodeConfig] is the function that decodes just its configuration.
func RegisterFormat(name, magic string, decode func(io.Reader) (Image, error), decodeConfig func(io.Reader) (Config, error)) {
formatsMu.Lock()
formats, _ := atomicFormats.Load().([]format)
atomicFormats.Store(append(formats, format{name, magic, decode, decodeConfig}))
formatsMu.Unlock()
}
// A reader is an io.Reader that can also peek ahead.
type reader interface {
io.Reader
Peek(int) ([]byte, error)
}
// asReader converts an io.Reader to a reader.
func asReader(r io.Reader) reader {
if rr, ok := r.(reader); ok {
return rr
}
return bufio.NewReader(r)
}
// match reports whether magic matches b. Magic may contain "?" wildcards.
func match(magic string, b []byte) bool {
if len(magic) != len(b) {
return false
}
for i, c := range b {
if magic[i] != c && magic[i] != '?' {
return false
}
}
return true
}
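// Editorial sketch, not part of the original source: how the "?" wildcard in a
// registered magic string behaves. The exampleMatch identifier is
// hypothetical. The GIF decoder registers the magic "GIF8?a", which matches
// both GIF87a and GIF89a headers.
func exampleMatch() bool {
	return match("GIF8?a", []byte("GIF87a")) && match("GIF8?a", []byte("GIF89a")) // true
}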
// sniff determines the format of r's data.
func sniff(r reader) format {
formats, _ := atomicFormats.Load().([]format)
for _, f := range formats {
b, err := r.Peek(len(f.magic))
if err == nil && match(f.magic, b) {
return f
}
}
return format{}
}
// Decode decodes an image that has been encoded in a registered format.
// The string returned is the format name used during format registration.
// Format registration is typically done by an init function in the codec-
// specific package.
func Decode(r io.Reader) (Image, string, error) {
rr := asReader(r)
f := sniff(rr)
if f.decode == nil {
return nil, "", ErrFormat
}
m, err := f.decode(rr)
return m, f.name, err
}
// DecodeConfig decodes the color model and dimensions of an image that has
// been encoded in a registered format. The string returned is the format name
// used during format registration. Format registration is typically done by
// an init function in the codec-specific package.
func DecodeConfig(r io.Reader) (Config, string, error) {
rr := asReader(r)
f := sniff(rr)
if f.decodeConfig == nil {
return Config{}, "", ErrFormat
}
c, err := f.decodeConfig(rr)
return c, f.name, err
}
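// Editorial sketch, not part of the original source: the pattern of calling
// DecodeConfig before Decode so a program can reject images whose declared
// dimensions are too large for the available memory. The decodeBounded
// identifier and the 1<<12 limit are hypothetical example choices.
func decodeBounded(r io.ReadSeeker) (Image, error) {
	cfg, _, err := DecodeConfig(r)
	if err != nil {
		return nil, err
	}
	if cfg.Width > 1<<12 || cfg.Height > 1<<12 {
		return nil, errors.New("image: refusing to decode oversized image")
	}
	// DecodeConfig consumed the header, so rewind before the full decode.
	if _, err := r.Seek(0, io.SeekStart); err != nil {
		return nil, err
	}
	m, _, err := Decode(r)
	return m, err
}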
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package image
import (
"image/color"
"math/bits"
"strconv"
)
// A Point is an X, Y coordinate pair. The axes increase right and down.
type Point struct {
X, Y int
}
// String returns a string representation of p like "(3,4)".
func (p Point) String() string {
return "(" + strconv.Itoa(p.X) + "," + strconv.Itoa(p.Y) + ")"
}
// Add returns the vector p+q.
func (p Point) Add(q Point) Point {
return Point{p.X + q.X, p.Y + q.Y}
}
// Sub returns the vector p-q.
func (p Point) Sub(q Point) Point {
return Point{p.X - q.X, p.Y - q.Y}
}
// Mul returns the vector p*k.
func (p Point) Mul(k int) Point {
return Point{p.X * k, p.Y * k}
}
// Div returns the vector p/k.
func (p Point) Div(k int) Point {
return Point{p.X / k, p.Y / k}
}
// In reports whether p is in r.
func (p Point) In(r Rectangle) bool {
return r.Min.X <= p.X && p.X < r.Max.X &&
r.Min.Y <= p.Y && p.Y < r.Max.Y
}
// Mod returns the point q in r such that p.X-q.X is a multiple of r's width
// and p.Y-q.Y is a multiple of r's height.
func (p Point) Mod(r Rectangle) Point {
w, h := r.Dx(), r.Dy()
p = p.Sub(r.Min)
p.X = p.X % w
if p.X < 0 {
p.X += w
}
p.Y = p.Y % h
if p.Y < 0 {
p.Y += h
}
return p.Add(r.Min)
}
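// Editorial sketch, not part of the original source: Mod wraps a point into a
// rectangle, which is useful for tiling. The exampleMod identifier is
// hypothetical. For r = Rect(0, 0, 10, 10), Pt(13, -2).Mod(r) is Pt(3, 8):
// 13-3 is a multiple of the width and -2-8 is a multiple of the height.
func exampleMod() Point {
	return Pt(13, -2).Mod(Rect(0, 0, 10, 10)) // Pt(3, 8)
}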
// Eq reports whether p and q are equal.
func (p Point) Eq(q Point) bool {
return p == q
}
// ZP is the zero [Point].
//
// Deprecated: Use a literal [image.Point] instead.
var ZP Point
// Pt is shorthand for [Point]{X, Y}.
func Pt(X, Y int) Point {
return Point{X, Y}
}
// A Rectangle contains the points with Min.X <= X < Max.X, Min.Y <= Y < Max.Y.
// It is well-formed if Min.X <= Max.X and likewise for Y. Points are always
// well-formed. A rectangle's methods always return well-formed outputs for
// well-formed inputs.
//
// A Rectangle is also an [Image] whose bounds are the rectangle itself. At
// returns color.Opaque for points in the rectangle and color.Transparent
// otherwise.
type Rectangle struct {
Min, Max Point
}
// String returns a string representation of r like "(3,4)-(6,5)".
func (r Rectangle) String() string {
return r.Min.String() + "-" + r.Max.String()
}
// Dx returns r's width.
func (r Rectangle) Dx() int {
return r.Max.X - r.Min.X
}
// Dy returns r's height.
func (r Rectangle) Dy() int {
return r.Max.Y - r.Min.Y
}
// Size returns r's width and height.
func (r Rectangle) Size() Point {
return Point{
r.Max.X - r.Min.X,
r.Max.Y - r.Min.Y,
}
}
// Add returns the rectangle r translated by p.
func (r Rectangle) Add(p Point) Rectangle {
return Rectangle{
Point{r.Min.X + p.X, r.Min.Y + p.Y},
Point{r.Max.X + p.X, r.Max.Y + p.Y},
}
}
// Sub returns the rectangle r translated by -p.
func (r Rectangle) Sub(p Point) Rectangle {
return Rectangle{
Point{r.Min.X - p.X, r.Min.Y - p.Y},
Point{r.Max.X - p.X, r.Max.Y - p.Y},
}
}
// Inset returns the rectangle r inset by n, which may be negative. If either
// of r's dimensions is less than 2*n then an empty rectangle near the center
// of r will be returned.
func (r Rectangle) Inset(n int) Rectangle {
if r.Dx() < 2*n {
r.Min.X = (r.Min.X + r.Max.X) / 2
r.Max.X = r.Min.X
} else {
r.Min.X += n
r.Max.X -= n
}
if r.Dy() < 2*n {
r.Min.Y = (r.Min.Y + r.Max.Y) / 2
r.Max.Y = r.Min.Y
} else {
r.Min.Y += n
r.Max.Y -= n
}
return r
}
// Intersect returns the largest rectangle contained by both r and s. If the
// two rectangles do not overlap then the zero rectangle will be returned.
func (r Rectangle) Intersect(s Rectangle) Rectangle {
if r.Min.X < s.Min.X {
r.Min.X = s.Min.X
}
if r.Min.Y < s.Min.Y {
r.Min.Y = s.Min.Y
}
if r.Max.X > s.Max.X {
r.Max.X = s.Max.X
}
if r.Max.Y > s.Max.Y {
r.Max.Y = s.Max.Y
}
// Letting r0 and s0 be the values of r and s at the time that the method
// is called, this next line is equivalent to:
//
// if max(r0.Min.X, s0.Min.X) >= min(r0.Max.X, s0.Max.X) || likewiseForY { etc }
if r.Empty() {
return Rectangle{}
}
return r
}
// Union returns the smallest rectangle that contains both r and s.
func (r Rectangle) Union(s Rectangle) Rectangle {
if r.Empty() {
return s
}
if s.Empty() {
return r
}
if r.Min.X > s.Min.X {
r.Min.X = s.Min.X
}
if r.Min.Y > s.Min.Y {
r.Min.Y = s.Min.Y
}
if r.Max.X < s.Max.X {
r.Max.X = s.Max.X
}
if r.Max.Y < s.Max.Y {
r.Max.Y = s.Max.Y
}
return r
}
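// Editorial sketch, not part of the original source: Intersect and Union on
// two overlapping rectangles. The example* identifiers are hypothetical.
// Disjoint rectangles intersect to the zero Rectangle.
var (
	exampleIntersection = Rect(0, 0, 4, 4).Intersect(Rect(2, 2, 6, 6)) // Rect(2, 2, 4, 4)
	exampleUnion        = Rect(0, 0, 4, 4).Union(Rect(2, 2, 6, 6))     // Rect(0, 0, 6, 6)
)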
// Empty reports whether the rectangle contains no points.
func (r Rectangle) Empty() bool {
return r.Min.X >= r.Max.X || r.Min.Y >= r.Max.Y
}
// Eq reports whether r and s contain the same set of points. All empty
// rectangles are considered equal.
func (r Rectangle) Eq(s Rectangle) bool {
return r == s || r.Empty() && s.Empty()
}
// Overlaps reports whether r and s have a non-empty intersection.
func (r Rectangle) Overlaps(s Rectangle) bool {
return !r.Empty() && !s.Empty() &&
r.Min.X < s.Max.X && s.Min.X < r.Max.X &&
r.Min.Y < s.Max.Y && s.Min.Y < r.Max.Y
}
// In reports whether every point in r is in s.
func (r Rectangle) In(s Rectangle) bool {
if r.Empty() {
return true
}
// Note that r.Max is an exclusive bound for r, so that r.In(s)
// does not require that r.Max.In(s).
return s.Min.X <= r.Min.X && r.Max.X <= s.Max.X &&
s.Min.Y <= r.Min.Y && r.Max.Y <= s.Max.Y
}
// Canon returns the canonical version of r. The returned rectangle has minimum
// and maximum coordinates swapped if necessary so that it is well-formed.
func (r Rectangle) Canon() Rectangle {
if r.Max.X < r.Min.X {
r.Min.X, r.Max.X = r.Max.X, r.Min.X
}
if r.Max.Y < r.Min.Y {
r.Min.Y, r.Max.Y = r.Max.Y, r.Min.Y
}
return r
}
// At implements the [Image] interface.
func (r Rectangle) At(x, y int) color.Color {
if (Point{x, y}).In(r) {
return color.Opaque
}
return color.Transparent
}
// RGBA64At implements the [RGBA64Image] interface.
func (r Rectangle) RGBA64At(x, y int) color.RGBA64 {
if (Point{x, y}).In(r) {
return color.RGBA64{0xffff, 0xffff, 0xffff, 0xffff}
}
return color.RGBA64{}
}
// Bounds implements the [Image] interface.
func (r Rectangle) Bounds() Rectangle {
return r
}
// ColorModel implements the [Image] interface.
func (r Rectangle) ColorModel() color.Model {
return color.Alpha16Model
}
// ZR is the zero [Rectangle].
//
// Deprecated: Use a literal [image.Rectangle] instead.
var ZR Rectangle
// Rect is shorthand for [Rectangle]{Pt(x0, y0), [Pt](x1, y1)}. The returned
// rectangle has minimum and maximum coordinates swapped if necessary so that
// it is well-formed.
func Rect(x0, y0, x1, y1 int) Rectangle {
if x0 > x1 {
x0, x1 = x1, x0
}
if y0 > y1 {
y0, y1 = y1, y0
}
return Rectangle{Point{x0, y0}, Point{x1, y1}}
}
// mul3NonNeg returns (x * y * z), unless at least one argument is negative or
// if the computation overflows the int type, in which case it returns -1.
func mul3NonNeg(x int, y int, z int) int {
if (x < 0) || (y < 0) || (z < 0) {
return -1
}
hi, lo := bits.Mul64(uint64(x), uint64(y))
if hi != 0 {
return -1
}
hi, lo = bits.Mul64(lo, uint64(z))
if hi != 0 {
return -1
}
a := int(lo)
if (a < 0) || (uint64(a) != lo) {
return -1
}
return a
}
// add2NonNeg returns (x + y), unless at least one argument is negative or if
// the computation overflows the int type, in which case it returns -1.
func add2NonNeg(x int, y int) int {
if (x < 0) || (y < 0) {
return -1
}
a := x + y
if a < 0 {
return -1
}
return a
}
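// Editorial sketch, not part of the original source: mul3NonNeg guards the
// pixel-buffer size computation used by pixelBufferLength. The
// exampleMul3NonNeg identifier is hypothetical. Valid dimensions come back as
// an ordinary product, while negative inputs (and products that would
// overflow int) are reported as -1 so callers can fail cleanly.
func exampleMul3NonNeg() (ok, bad int) {
	ok = mul3NonNeg(4, 640, 480)   // 1228800: bytes needed for a 640x480 RGBA image
	bad = mul3NonNeg(4, -640, 480) // -1: negative dimensions are rejected
	return ok, bad
}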
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package gif implements a GIF image decoder and encoder.
//
// The GIF specification is at https://www.w3.org/Graphics/GIF/spec-gif89a.txt.
package gif
import (
"bufio"
"compress/lzw"
"errors"
"fmt"
"image"
"image/color"
"io"
)
var (
errNotEnough = errors.New("gif: not enough image data")
errTooMuch = errors.New("gif: too much image data")
errBadPixel = errors.New("gif: invalid pixel value")
)
// If the io.Reader does not also have ReadByte, then decode will introduce its own buffering.
type reader interface {
io.Reader
io.ByteReader
}
// Masks etc.
const (
// Fields.
fColorTable = 1 << 7
fInterlace = 1 << 6
fColorTableBitsMask = 7
// Graphic control flags.
gcTransparentColorSet = 1 << 0
gcDisposalMethodMask = 7 << 2
)
// Disposal Methods.
const (
DisposalNone = 0x01
DisposalBackground = 0x02
DisposalPrevious = 0x03
)
// Section indicators.
const (
sExtension = 0x21
sImageDescriptor = 0x2C
sTrailer = 0x3B
)
// Extensions.
const (
eText = 0x01 // Plain Text
eGraphicControl = 0xF9 // Graphic Control
eComment = 0xFE // Comment
eApplication = 0xFF // Application
)
func readFull(r io.Reader, b []byte) error {
_, err := io.ReadFull(r, b)
if err == io.EOF {
err = io.ErrUnexpectedEOF
}
return err
}
func readByte(r io.ByteReader) (byte, error) {
b, err := r.ReadByte()
if err == io.EOF {
err = io.ErrUnexpectedEOF
}
return b, err
}
// decoder is the type used to decode a GIF file.
type decoder struct {
r reader
// From header.
vers string
width int
height int
loopCount int
delayTime int
backgroundIndex byte
disposalMethod byte
// From image descriptor.
imageFields byte
// From graphics control.
transparentIndex byte
hasTransparentIndex bool
// Computed.
globalColorTable color.Palette
// Used when decoding.
delay []int
disposal []byte
image []*image.Paletted
tmp [1024]byte // must be at least 768 so we can read color table
}
// blockReader parses the block structure of GIF image data, which comprises
// (n, (n bytes)) blocks, with 1 <= n <= 255. It is the reader given to the
// LZW decoder, which is thus immune to the blocking. After the LZW decoder
// completes, there will be a 0-byte block remaining (0, ()), which is
// consumed when checking that the blockReader is exhausted.
//
// To avoid the allocation of a bufio.Reader for the lzw Reader, blockReader
// implements io.ByteReader and buffers blocks into the decoder's "tmp" buffer.
type blockReader struct {
d *decoder
i, j uint8 // d.tmp[i:j] contains the buffered bytes
err error
}
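// Editorial sketch, not part of the original source: the sub-block layout that
// blockReader parses. The exampleSubBlocks identifier is hypothetical. Each
// block is a length byte n (1 <= n <= 255) followed by n data bytes; a length
// byte of 0 terminates the sequence. The bytes below would be delivered to the
// LZW decoder as the five-byte stream "hello".
var exampleSubBlocks = []byte{
	3, 'h', 'e', 'l', // a 3-byte data sub-block
	2, 'l', 'o', // a 2-byte data sub-block
	0, // block terminator
}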
func (b *blockReader) fill() {
if b.err != nil {
return
}
b.j, b.err = readByte(b.d.r)
if b.j == 0 && b.err == nil {
b.err = io.EOF
}
if b.err != nil {
return
}
b.i = 0
b.err = readFull(b.d.r, b.d.tmp[:b.j])
if b.err != nil {
b.j = 0
}
}
func (b *blockReader) ReadByte() (byte, error) {
if b.i == b.j {
b.fill()
if b.err != nil {
return 0, b.err
}
}
c := b.d.tmp[b.i]
b.i++
return c, nil
}
// blockReader must implement io.Reader, but its Read shouldn't ever actually
// be called in practice. The compress/lzw package will only call [blockReader.ReadByte].
func (b *blockReader) Read(p []byte) (int, error) {
if len(p) == 0 || b.err != nil {
return 0, b.err
}
if b.i == b.j {
b.fill()
if b.err != nil {
return 0, b.err
}
}
n := copy(p, b.d.tmp[b.i:b.j])
b.i += uint8(n)
return n, nil
}
// close primarily detects whether or not a block terminator was encountered
// after reading a sequence of data sub-blocks. It allows at most one trailing
// sub-block worth of data. I.e., if some number of bytes exist in one sub-block
// following the end of LZW data, the very next sub-block must be the block
// terminator. If the very end of LZW data happened to fill one sub-block, at
// most one more sub-block of length 1 may exist before the block-terminator.
// These accommodations allow us to support GIFs created by less strict encoders.
// See https://golang.org/issue/16146.
func (b *blockReader) close() error {
if b.err == io.EOF {
// A clean block-sequence terminator was encountered while reading.
return nil
} else if b.err != nil {
// Some other error was encountered while reading.
return b.err
}
if b.i == b.j {
// We reached the end of a sub block reading LZW data. We'll allow at
// most one more sub block of data with a length of 1 byte.
b.fill()
if b.err == io.EOF {
return nil
} else if b.err != nil {
return b.err
} else if b.j > 1 {
return errTooMuch
}
}
// Part of a sub-block remains buffered. We expect that the next attempt to
// buffer a sub-block will reach the block terminator.
b.fill()
if b.err == io.EOF {
return nil
} else if b.err != nil {
return b.err
}
return errTooMuch
}
// decode reads a GIF image from r and stores the result in d.
func (d *decoder) decode(r io.Reader, configOnly, keepAllFrames bool) error {
// Add buffering if r does not provide ReadByte.
if rr, ok := r.(reader); ok {
d.r = rr
} else {
d.r = bufio.NewReader(r)
}
d.loopCount = -1
err := d.readHeaderAndScreenDescriptor()
if err != nil {
return err
}
if configOnly {
return nil
}
for {
c, err := readByte(d.r)
if err != nil {
return fmt.Errorf("gif: reading frames: %v", err)
}
switch c {
case sExtension:
if err = d.readExtension(); err != nil {
return err
}
case sImageDescriptor:
if err = d.readImageDescriptor(keepAllFrames); err != nil {
return err
}
if !keepAllFrames && len(d.image) == 1 {
return nil
}
case sTrailer:
if len(d.image) == 0 {
return fmt.Errorf("gif: missing image data")
}
return nil
default:
return fmt.Errorf("gif: unknown block type: 0x%.2x", c)
}
}
}
func (d *decoder) readHeaderAndScreenDescriptor() error {
err := readFull(d.r, d.tmp[:13])
if err != nil {
return fmt.Errorf("gif: reading header: %v", err)
}
d.vers = string(d.tmp[:6])
if d.vers != "GIF87a" && d.vers != "GIF89a" {
return fmt.Errorf("gif: can't recognize format %q", d.vers)
}
d.width = int(d.tmp[6]) + int(d.tmp[7])<<8
d.height = int(d.tmp[8]) + int(d.tmp[9])<<8
if fields := d.tmp[10]; fields&fColorTable != 0 {
d.backgroundIndex = d.tmp[11]
// readColorTable overwrites the contents of d.tmp, but that's OK.
if d.globalColorTable, err = d.readColorTable(fields); err != nil {
return err
}
}
// d.tmp[12] is the Pixel Aspect Ratio, which is ignored.
return nil
}
func (d *decoder) readColorTable(fields byte) (color.Palette, error) {
n := 1 << (1 + uint(fields&fColorTableBitsMask))
err := readFull(d.r, d.tmp[:3*n])
if err != nil {
return nil, fmt.Errorf("gif: reading color table: %s", err)
}
j, p := 0, make(color.Palette, n)
for i := range p {
p[i] = color.RGBA{d.tmp[j+0], d.tmp[j+1], d.tmp[j+2], 0xFF}
j += 3
}
return p, nil
}
func (d *decoder) readExtension() error {
extension, err := readByte(d.r)
if err != nil {
return fmt.Errorf("gif: reading extension: %v", err)
}
size := 0
switch extension {
case eText:
size = 13
case eGraphicControl:
return d.readGraphicControl()
case eComment:
// nothing to do but read the data.
case eApplication:
b, err := readByte(d.r)
if err != nil {
return fmt.Errorf("gif: reading extension: %v", err)
}
// The spec requires size be 11, but Adobe sometimes uses 10.
size = int(b)
default:
return fmt.Errorf("gif: unknown extension 0x%.2x", extension)
}
if size > 0 {
if err := readFull(d.r, d.tmp[:size]); err != nil {
return fmt.Errorf("gif: reading extension: %v", err)
}
}
// An Application Extension identified by the string "NETSCAPE2.0", whose data
// sub-block begins with the byte 1, defines a loop count.
if extension == eApplication && string(d.tmp[:size]) == "NETSCAPE2.0" {
n, err := d.readBlock()
if err != nil {
return fmt.Errorf("gif: reading extension: %v", err)
}
if n == 0 {
return nil
}
if n == 3 && d.tmp[0] == 1 {
d.loopCount = int(d.tmp[1]) | int(d.tmp[2])<<8
}
}
for {
n, err := d.readBlock()
if err != nil {
return fmt.Errorf("gif: reading extension: %v", err)
}
if n == 0 {
return nil
}
}
}
func (d *decoder) readGraphicControl() error {
if err := readFull(d.r, d.tmp[:6]); err != nil {
return fmt.Errorf("gif: can't read graphic control: %s", err)
}
if d.tmp[0] != 4 {
return fmt.Errorf("gif: invalid graphic control extension block size: %d", d.tmp[0])
}
flags := d.tmp[1]
d.disposalMethod = (flags & gcDisposalMethodMask) >> 2
d.delayTime = int(d.tmp[2]) | int(d.tmp[3])<<8
if flags&gcTransparentColorSet != 0 {
d.transparentIndex = d.tmp[4]
d.hasTransparentIndex = true
}
if d.tmp[5] != 0 {
return fmt.Errorf("gif: invalid graphic control extension block terminator: %d", d.tmp[5])
}
return nil
}
func (d *decoder) readImageDescriptor(keepAllFrames bool) error {
m, err := d.newImageFromDescriptor()
if err != nil {
return err
}
useLocalColorTable := d.imageFields&fColorTable != 0
if useLocalColorTable {
m.Palette, err = d.readColorTable(d.imageFields)
if err != nil {
return err
}
} else {
if d.globalColorTable == nil {
return errors.New("gif: no color table")
}
m.Palette = d.globalColorTable
}
if d.hasTransparentIndex {
if !useLocalColorTable {
// Clone the global color table.
m.Palette = append(color.Palette(nil), d.globalColorTable...)
}
if ti := int(d.transparentIndex); ti < len(m.Palette) {
m.Palette[ti] = color.RGBA{}
} else {
// The transparentIndex is out of range, which is an error
// according to the spec, but Firefox and Google Chrome
// seem OK with this, so we enlarge the palette with
// transparent colors. See golang.org/issue/15059.
p := make(color.Palette, ti+1)
copy(p, m.Palette)
for i := len(m.Palette); i < len(p); i++ {
p[i] = color.RGBA{}
}
m.Palette = p
}
}
litWidth, err := readByte(d.r)
if err != nil {
return fmt.Errorf("gif: reading image data: %v", err)
}
if litWidth < 2 || litWidth > 8 {
return fmt.Errorf("gif: pixel size in decode out of range: %d", litWidth)
}
// A wonderfully Go-like piece of magic.
br := &blockReader{d: d}
lzwr := lzw.NewReader(br, lzw.LSB, int(litWidth))
defer lzwr.Close()
if err = readFull(lzwr, m.Pix); err != nil {
if err != io.ErrUnexpectedEOF {
return fmt.Errorf("gif: reading image data: %v", err)
}
return errNotEnough
}
// In theory, both lzwr and br should be exhausted. Reading from them
// should yield (0, io.EOF).
//
// The spec (Appendix F - Compression), says that "An End of
// Information code... must be the last code output by the encoder
// for an image". In practice, though, giflib (a widely used C
// library) does not enforce this, so we also accept lzwr returning
// io.ErrUnexpectedEOF (meaning that the encoded stream hit io.EOF
// before the LZW decoder saw an explicit end code), provided that
// the io.ReadFull call above successfully read len(m.Pix) bytes.
// See https://golang.org/issue/9856 for an example GIF.
if n, err := lzwr.Read(d.tmp[256:257]); n != 0 || (err != io.EOF && err != io.ErrUnexpectedEOF) {
if err != nil {
return fmt.Errorf("gif: reading image data: %v", err)
}
return errTooMuch
}
// In practice, some GIFs have an extra byte in the data sub-block
// stream, which we ignore. See https://golang.org/issue/16146.
if err := br.close(); err == errTooMuch {
return errTooMuch
} else if err != nil {
return fmt.Errorf("gif: reading image data: %v", err)
}
// Check that the color indexes are inside the palette.
if len(m.Palette) < 256 {
for _, pixel := range m.Pix {
if int(pixel) >= len(m.Palette) {
return errBadPixel
}
}
}
// Undo the interlacing if necessary.
if d.imageFields&fInterlace != 0 {
uninterlace(m)
}
if keepAllFrames || len(d.image) == 0 {
d.image = append(d.image, m)
d.delay = append(d.delay, d.delayTime)
d.disposal = append(d.disposal, d.disposalMethod)
}
// The GIF89a spec, Section 23 (Graphic Control Extension) says:
// "The scope of this extension is the first graphic rendering block
// to follow." We therefore reset the GCE fields to zero.
d.delayTime = 0
d.hasTransparentIndex = false
return nil
}
func (d *decoder) newImageFromDescriptor() (*image.Paletted, error) {
if err := readFull(d.r, d.tmp[:9]); err != nil {
return nil, fmt.Errorf("gif: can't read image descriptor: %s", err)
}
left := int(d.tmp[0]) + int(d.tmp[1])<<8
top := int(d.tmp[2]) + int(d.tmp[3])<<8
width := int(d.tmp[4]) + int(d.tmp[5])<<8
height := int(d.tmp[6]) + int(d.tmp[7])<<8
d.imageFields = d.tmp[8]
// The GIF89a spec, Section 20 (Image Descriptor) says: "Each image must
// fit within the boundaries of the Logical Screen, as defined in the
// Logical Screen Descriptor."
//
// This is conceptually similar to testing
// frameBounds := image.Rect(left, top, left+width, top+height)
// imageBounds := image.Rect(0, 0, d.width, d.height)
// if !frameBounds.In(imageBounds) { etc }
// but the semantics of the Go image.Rectangle type is that r.In(s) is true
// whenever r is an empty rectangle, even if r.Min.X > s.Max.X. Here, we
// want something stricter.
//
// Note that, by construction, left >= 0 && top >= 0, so we only have to
// explicitly compare frameBounds.Max (left+width, top+height) against
// imageBounds.Max (d.width, d.height) and not frameBounds.Min (left, top)
// against imageBounds.Min (0, 0).
if left+width > d.width || top+height > d.height {
return nil, errors.New("gif: frame bounds larger than image bounds")
}
return image.NewPaletted(image.Rectangle{
Min: image.Point{left, top},
Max: image.Point{left + width, top + height},
}, nil), nil
}
func (d *decoder) readBlock() (int, error) {
n, err := readByte(d.r)
if n == 0 || err != nil {
return 0, err
}
if err := readFull(d.r, d.tmp[:n]); err != nil {
return 0, err
}
return int(n), nil
}
// interlaceScan defines the ordering for a pass of the interlace algorithm.
type interlaceScan struct {
skip, start int
}
// interlacing represents the set of scans in an interlaced GIF image.
var interlacing = []interlaceScan{
{8, 0}, // Group 1 : Every 8th. row, starting with row 0.
{8, 4}, // Group 2 : Every 8th. row, starting with row 4.
{4, 2}, // Group 3 : Every 4th. row, starting with row 2.
{2, 1}, // Group 4 : Every 2nd. row, starting with row 1.
}
// uninterlace rearranges the pixels in m to account for interlaced input.
func uninterlace(m *image.Paletted) {
var nPix []uint8
dx := m.Bounds().Dx()
dy := m.Bounds().Dy()
nPix = make([]uint8, dx*dy)
offset := 0 // steps through the input by sequential scan lines.
for _, pass := range interlacing {
nOffset := pass.start * dx // steps through the output as defined by pass.
for y := pass.start; y < dy; y += pass.skip {
copy(nPix[nOffset:nOffset+dx], m.Pix[offset:offset+dx])
offset += dx
nOffset += dx * pass.skip
}
}
m.Pix = nPix
}
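// Editorial sketch, not part of the original source: for an 8-row interlaced
// image the rows arrive in the order below, and uninterlace copies them back
// into sequential order. The exampleInterlaceRowOrder identifier is
// hypothetical.
var exampleInterlaceRowOrder = []int{0, 4, 2, 6, 1, 3, 5, 7}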
// Decode reads a GIF image from r and returns the first embedded
// image as an [image.Image].
func Decode(r io.Reader) (image.Image, error) {
var d decoder
if err := d.decode(r, false, false); err != nil {
return nil, err
}
return d.image[0], nil
}
// GIF represents the possibly multiple images stored in a GIF file.
type GIF struct {
Image []*image.Paletted // The successive images.
Delay []int // The successive delay times, one per frame, in 100ths of a second.
// LoopCount controls the number of times an animation will be
// restarted during display.
// A LoopCount of 0 means to loop forever.
// A LoopCount of -1 means to show each frame only once.
// Otherwise, the animation is looped LoopCount+1 times.
LoopCount int
// Disposal is the successive disposal methods, one per frame. For
// backwards compatibility, a nil Disposal is valid to pass to EncodeAll,
// and implies that each frame's disposal method is 0 (no disposal
// specified).
Disposal []byte
// Config is the global color table (palette), width and height. A nil or
// empty-color.Palette Config.ColorModel means that each frame has its own
// color table and there is no global color table. Each frame's bounds must
// be within the rectangle defined by the two points (0, 0) and
// (Config.Width, Config.Height).
//
// For backwards compatibility, a zero-valued Config is valid to pass to
// EncodeAll, and implies that the overall GIF's width and height equals
// the first frame's bounds' Rectangle.Max point.
Config image.Config
// BackgroundIndex is the background index in the global color table, for
// use with the DisposalBackground disposal method.
BackgroundIndex byte
}
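// Editorial sketch, not part of the original source: LoopCount values as they
// would be set on a GIF passed to EncodeAll. The exampleLoopCounts identifier
// is hypothetical.
var exampleLoopCounts = map[string]int{
	"loop forever":     0,
	"play once":        -1,
	"play three times": 2, // LoopCount n plays the animation n+1 times
}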
// DecodeAll reads a GIF image from r and returns the sequential frames
// and timing information.
func DecodeAll(r io.Reader) (*GIF, error) {
var d decoder
if err := d.decode(r, false, true); err != nil {
return nil, err
}
gif := &GIF{
Image: d.image,
LoopCount: d.loopCount,
Delay: d.delay,
Disposal: d.disposal,
Config: image.Config{
ColorModel: d.globalColorTable,
Width: d.width,
Height: d.height,
},
BackgroundIndex: d.backgroundIndex,
}
return gif, nil
}
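// Editorial sketch, not part of the original source: reading every frame of an
// animated GIF with DecodeAll. The exampleDecodeAll identifier is
// hypothetical; r is assumed to be positioned at the start of a GIF stream.
func exampleDecodeAll(r io.Reader) (frames int, size image.Point, err error) {
	g, err := DecodeAll(r)
	if err != nil {
		return 0, image.Point{}, err
	}
	return len(g.Image), image.Pt(g.Config.Width, g.Config.Height), nil
}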
// DecodeConfig returns the global color model and dimensions of a GIF image
// without decoding the entire image.
func DecodeConfig(r io.Reader) (image.Config, error) {
var d decoder
if err := d.decode(r, true, false); err != nil {
return image.Config{}, err
}
return image.Config{
ColorModel: d.globalColorTable,
Width: d.width,
Height: d.height,
}, nil
}
func init() {
image.RegisterFormat("gif", "GIF8?a", Decode, DecodeConfig)
}
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package gif
import (
"bufio"
"bytes"
"compress/lzw"
"errors"
"image"
"image/color"
"image/color/palette"
"image/draw"
"internal/byteorder"
"io"
)
// Graphic control extension fields.
const (
gcLabel = 0xF9
gcBlockSize = 0x04
)
var log2Lookup = [8]int{2, 4, 8, 16, 32, 64, 128, 256}
// log2 returns the smallest i such that x <= 2^(i+1), which is the value of
// the color table "size" field for a palette of x entries, or -1 if x > 256.
func log2(x int) int {
for i, v := range log2Lookup {
if x <= v {
return i
}
}
return -1
}
// writer is a buffered writer.
type writer interface {
Flush() error
io.Writer
io.ByteWriter
}
// encoder encodes an image to the GIF format.
type encoder struct {
// w is the writer to write to. err is the first error encountered during
// writing. All attempted writes after the first error become no-ops.
w writer
err error
// g is a reference to the data that is being encoded.
g GIF
// globalCT is the size in bytes of the global color table.
globalCT int
// buf is a scratch buffer. It must be at least 256 for the blockWriter.
buf [256]byte
globalColorTable [3 * 256]byte
localColorTable [3 * 256]byte
}
// blockWriter writes the block structure of GIF image data, which
// comprises (n, (n bytes)) blocks, with 1 <= n <= 255. It is the
// writer given to the LZW encoder, which is thus immune to the
// blocking.
type blockWriter struct {
e *encoder
}
func (b blockWriter) setup() {
b.e.buf[0] = 0
}
func (b blockWriter) Flush() error {
return b.e.err
}
func (b blockWriter) WriteByte(c byte) error {
if b.e.err != nil {
return b.e.err
}
// Append c to buffered sub-block.
b.e.buf[0]++
b.e.buf[b.e.buf[0]] = c
if b.e.buf[0] < 255 {
return nil
}
// Flush block
b.e.write(b.e.buf[:256])
b.e.buf[0] = 0
return b.e.err
}
// blockWriter must be an io.Writer for lzw.NewWriter, but this is never
// actually called.
func (b blockWriter) Write(data []byte) (int, error) {
for i, c := range data {
if err := b.WriteByte(c); err != nil {
return i, err
}
}
return len(data), nil
}
func (b blockWriter) close() {
// Write the block terminator (0x00), either by itself, or along with a
// pending sub-block.
if b.e.buf[0] == 0 {
b.e.writeByte(0)
} else {
n := uint(b.e.buf[0])
b.e.buf[n+1] = 0
b.e.write(b.e.buf[:n+2])
}
b.e.flush()
}
func (e *encoder) flush() {
if e.err != nil {
return
}
e.err = e.w.Flush()
}
func (e *encoder) write(p []byte) {
if e.err != nil {
return
}
_, e.err = e.w.Write(p)
}
func (e *encoder) writeByte(b byte) {
if e.err != nil {
return
}
e.err = e.w.WriteByte(b)
}
func (e *encoder) writeHeader() {
if e.err != nil {
return
}
_, e.err = io.WriteString(e.w, "GIF89a")
if e.err != nil {
return
}
// Logical screen width and height.
byteorder.LEPutUint16(e.buf[0:2], uint16(e.g.Config.Width))
byteorder.LEPutUint16(e.buf[2:4], uint16(e.g.Config.Height))
e.write(e.buf[:4])
if p, ok := e.g.Config.ColorModel.(color.Palette); ok && len(p) > 0 {
paddedSize := log2(len(p)) // Size of Global Color Table: 2^(1+n).
e.buf[0] = fColorTable | uint8(paddedSize)
e.buf[1] = e.g.BackgroundIndex
e.buf[2] = 0x00 // Pixel Aspect Ratio.
e.write(e.buf[:3])
var err error
e.globalCT, err = encodeColorTable(e.globalColorTable[:], p, paddedSize)
if err != nil && e.err == nil {
e.err = err
return
}
e.write(e.globalColorTable[:e.globalCT])
} else {
// All frames have a local color table, so a global color table
// is not needed.
e.buf[0] = 0x00
e.buf[1] = 0x00 // Background Color Index.
e.buf[2] = 0x00 // Pixel Aspect Ratio.
e.write(e.buf[:3])
}
// Add animation info if necessary.
if len(e.g.Image) > 1 && e.g.LoopCount >= 0 {
e.buf[0] = 0x21 // Extension Introducer.
e.buf[1] = 0xff // Application Label.
e.buf[2] = 0x0b // Block Size.
e.write(e.buf[:3])
_, err := io.WriteString(e.w, "NETSCAPE2.0") // Application Identifier.
if err != nil && e.err == nil {
e.err = err
return
}
e.buf[0] = 0x03 // Block Size.
e.buf[1] = 0x01 // Sub-block Index.
byteorder.LEPutUint16(e.buf[2:4], uint16(e.g.LoopCount))
e.buf[4] = 0x00 // Block Terminator.
e.write(e.buf[:5])
}
}
func encodeColorTable(dst []byte, p color.Palette, size int) (int, error) {
if uint(size) >= uint(len(log2Lookup)) {
return 0, errors.New("gif: cannot encode color table with more than 256 entries")
}
for i, c := range p {
if c == nil {
return 0, errors.New("gif: cannot encode color table with nil entries")
}
var r, g, b uint8
// It is most likely that the palette is full of color.RGBAs, so they
// get a fast path.
if rgba, ok := c.(color.RGBA); ok {
r, g, b = rgba.R, rgba.G, rgba.B
} else {
rr, gg, bb, _ := c.RGBA()
r, g, b = uint8(rr>>8), uint8(gg>>8), uint8(bb>>8)
}
dst[3*i+0] = r
dst[3*i+1] = g
dst[3*i+2] = b
}
n := log2Lookup[size]
if n > len(p) {
// Pad with black.
clear(dst[3*len(p) : 3*n])
}
return 3 * n, nil
}
func (e *encoder) colorTablesMatch(localLen, transparentIndex int) bool {
localSize := 3 * localLen
if transparentIndex >= 0 {
trOff := 3 * transparentIndex
return bytes.Equal(e.globalColorTable[:trOff], e.localColorTable[:trOff]) &&
bytes.Equal(e.globalColorTable[trOff+3:localSize], e.localColorTable[trOff+3:localSize])
}
return bytes.Equal(e.globalColorTable[:localSize], e.localColorTable[:localSize])
}
func (e *encoder) writeImageBlock(pm *image.Paletted, delay int, disposal byte) {
if e.err != nil {
return
}
if len(pm.Palette) == 0 {
e.err = errors.New("gif: cannot encode image block with empty palette")
return
}
b := pm.Bounds()
if b.Min.X < 0 || b.Max.X >= 1<<16 || b.Min.Y < 0 || b.Max.Y >= 1<<16 {
e.err = errors.New("gif: image block is too large to encode")
return
}
if !b.In(image.Rectangle{Max: image.Point{e.g.Config.Width, e.g.Config.Height}}) {
e.err = errors.New("gif: image block is out of bounds")
return
}
transparentIndex := -1
for i, c := range pm.Palette {
if c == nil {
e.err = errors.New("gif: cannot encode color table with nil entries")
return
}
if _, _, _, a := c.RGBA(); a == 0 {
transparentIndex = i
break
}
}
if delay > 0 || disposal != 0 || transparentIndex != -1 {
e.buf[0] = sExtension // Extension Introducer.
e.buf[1] = gcLabel // Graphic Control Label.
e.buf[2] = gcBlockSize // Block Size.
if transparentIndex != -1 {
e.buf[3] = 0x01 | disposal<<2
} else {
e.buf[3] = 0x00 | disposal<<2
}
byteorder.LEPutUint16(e.buf[4:6], uint16(delay)) // Delay Time (1/100ths of a second)
// Transparent color index.
if transparentIndex != -1 {
e.buf[6] = uint8(transparentIndex)
} else {
e.buf[6] = 0x00
}
e.buf[7] = 0x00 // Block Terminator.
e.write(e.buf[:8])
}
e.buf[0] = sImageDescriptor
byteorder.LEPutUint16(e.buf[1:3], uint16(b.Min.X))
byteorder.LEPutUint16(e.buf[3:5], uint16(b.Min.Y))
byteorder.LEPutUint16(e.buf[5:7], uint16(b.Dx()))
byteorder.LEPutUint16(e.buf[7:9], uint16(b.Dy()))
e.write(e.buf[:9])
// To determine whether or not this frame's palette is the same as the
// global palette, we can check a couple things. First, do they actually
// point to the same []color.Color? If so, they are equal so long as the
// frame's palette is not longer than the global palette...
paddedSize := log2(len(pm.Palette)) // Size of Local Color Table: 2^(1+n).
if gp, ok := e.g.Config.ColorModel.(color.Palette); ok && len(pm.Palette) <= len(gp) && &gp[0] == &pm.Palette[0] {
e.writeByte(0) // Use the global color table.
} else {
ct, err := encodeColorTable(e.localColorTable[:], pm.Palette, paddedSize)
if err != nil {
if e.err == nil {
e.err = err
}
return
}
// This frame's palette is not the very same slice as the global
// palette, but it might be a copy, possibly with one value turned into
// transparency by DecodeAll.
if ct <= e.globalCT && e.colorTablesMatch(len(pm.Palette), transparentIndex) {
e.writeByte(0) // Use the global color table.
} else {
// Use a local color table.
e.writeByte(fColorTable | uint8(paddedSize))
e.write(e.localColorTable[:ct])
}
}
litWidth := paddedSize + 1
if litWidth < 2 {
litWidth = 2
}
e.writeByte(uint8(litWidth)) // LZW Minimum Code Size.
bw := blockWriter{e: e}
bw.setup()
lzww := lzw.NewWriter(bw, lzw.LSB, litWidth)
if dx := b.Dx(); dx == pm.Stride {
_, e.err = lzww.Write(pm.Pix[:dx*b.Dy()])
if e.err != nil {
lzww.Close()
return
}
} else {
for i, y := 0, b.Min.Y; y < b.Max.Y; i, y = i+pm.Stride, y+1 {
_, e.err = lzww.Write(pm.Pix[i : i+dx])
if e.err != nil {
lzww.Close()
return
}
}
}
lzww.Close() // flush to bw
bw.close() // flush to e.w
}
// Options are the encoding parameters.
type Options struct {
// NumColors is the maximum number of colors used in the image.
// It ranges from 1 to 256.
NumColors int
// Quantizer is used to produce a palette with size NumColors.
// palette.Plan9 is used in place of a nil Quantizer.
Quantizer draw.Quantizer
// Drawer is used to convert the source image to the desired palette.
// draw.FloydSteinberg is used in place of a nil Drawer.
Drawer draw.Drawer
}
// EncodeAll writes the images in g to w in GIF format with the
// given loop count and delay between frames.
func EncodeAll(w io.Writer, g *GIF) error {
if len(g.Image) == 0 {
return errors.New("gif: must provide at least one image")
}
if len(g.Image) != len(g.Delay) {
return errors.New("gif: mismatched image and delay lengths")
}
e := encoder{g: *g}
// The GIF.Disposal, GIF.Config and GIF.BackgroundIndex fields were added
// in Go 1.5. Valid Go 1.4 code, such as when the Disposal field is omitted
// in a GIF struct literal, should still produce valid GIFs.
if e.g.Disposal != nil && len(e.g.Image) != len(e.g.Disposal) {
return errors.New("gif: mismatched image and disposal lengths")
}
if e.g.Config == (image.Config{}) {
p := g.Image[0].Bounds().Max
e.g.Config.Width = p.X
e.g.Config.Height = p.Y
} else if e.g.Config.ColorModel != nil {
if _, ok := e.g.Config.ColorModel.(color.Palette); !ok {
return errors.New("gif: GIF color model must be a color.Palette")
}
}
if ww, ok := w.(writer); ok {
e.w = ww
} else {
e.w = bufio.NewWriter(w)
}
e.writeHeader()
for i, pm := range g.Image {
disposal := uint8(0)
if g.Disposal != nil {
disposal = g.Disposal[i]
}
e.writeImageBlock(pm, g.Delay[i], disposal)
}
e.writeByte(sTrailer)
e.flush()
return e.err
}
// Encode writes the Image m to w in GIF format.
func Encode(w io.Writer, m image.Image, o *Options) error {
// Check for bounds and size restrictions.
b := m.Bounds()
if b.Dx() >= 1<<16 || b.Dy() >= 1<<16 {
return errors.New("gif: image is too large to encode")
}
opts := Options{}
if o != nil {
opts = *o
}
if opts.NumColors < 1 || 256 < opts.NumColors {
opts.NumColors = 256
}
if opts.Drawer == nil {
opts.Drawer = draw.FloydSteinberg
}
pm, _ := m.(*image.Paletted)
if pm == nil {
if cp, ok := m.ColorModel().(color.Palette); ok {
pm = image.NewPaletted(b, cp)
for y := b.Min.Y; y < b.Max.Y; y++ {
for x := b.Min.X; x < b.Max.X; x++ {
pm.Set(x, y, cp.Convert(m.At(x, y)))
}
}
}
}
if pm == nil || len(pm.Palette) > opts.NumColors {
// Set pm to be a palettedized copy of m, including its bounds, which
// might not start at (0, 0).
//
// TODO: Pick a better sub-sample of the Plan 9 palette.
pm = image.NewPaletted(b, palette.Plan9[:opts.NumColors])
if opts.Quantizer != nil {
pm.Palette = opts.Quantizer.Quantize(make(color.Palette, 0, opts.NumColors), m)
}
opts.Drawer.Draw(pm, b, m, b.Min)
}
// When calling Encode instead of EncodeAll, the single-frame image is
// translated such that its top-left corner is (0, 0), so that the single
// frame completely fills the overall GIF's bounds.
if pm.Rect.Min != (image.Point{}) {
dup := *pm
dup.Rect = dup.Rect.Sub(dup.Rect.Min)
pm = &dup
}
return EncodeAll(w, &GIF{
Image: []*image.Paletted{pm},
Delay: []int{0},
Config: image.Config{
ColorModel: pm.Palette,
Width: b.Dx(),
Height: b.Dy(),
},
})
}
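// Editorial sketch, not part of the original source: encoding a single
// paletted frame with Encode. The exampleEncode identifier and the
// checkerboard pattern are hypothetical; w is assumed to be an open file or
// buffer.
func exampleEncode(w io.Writer) error {
	m := image.NewPaletted(image.Rect(0, 0, 16, 16), color.Palette{color.Black, color.White})
	for y := 0; y < 16; y++ {
		for x := 0; x < 16; x++ {
			m.SetColorIndex(x, y, uint8((x+y)%2))
		}
	}
	return Encode(w, m, &Options{NumColors: 2})
}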
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package image implements a basic 2-D image library.
//
// The fundamental interface is called [Image]. An [Image] contains colors, which
// are described in the image/color package.
//
// Values of the [Image] interface are created either by calling functions such
// as [NewRGBA] and [NewPaletted], or by calling [Decode] on an [io.Reader] containing
// image data in a format such as GIF, JPEG or PNG. Decoding any particular
// image format requires the prior registration of a decoder function.
// Registration is typically automatic as a side effect of initializing that
// format's package so that, to decode a PNG image, it suffices to have
//
// import _ "image/png"
//
// in a program's main package. The _ means to import a package purely for its
// initialization side effects.
//
// See "The Go image package" for more details:
// https://golang.org/doc/articles/image_package.html
//
// # Security Considerations
//
// The image package can be used to parse arbitrarily large images, which can
// cause resource exhaustion on machines which do not have enough memory to
// store them. When operating on arbitrary images, [DecodeConfig] should be called
// before [Decode], so that the program can decide whether the image, as defined
// in the returned header, can be safely decoded with the available resources. A
// call to [Decode] which produces an extremely large image, as defined in the
// header returned by [DecodeConfig], is not considered a security issue,
// regardless of whether the image is itself malformed or not. A call to
// [DecodeConfig] which returns a header which does not match the image returned
// by [Decode] may be considered a security issue, and should be reported per the
// [Go Security Policy].
//
// [Go Security Policy]: https://go.dev/security/policy
package image
import (
"image/color"
)
// Config holds an image's color model and dimensions.
type Config struct {
ColorModel color.Model
Width, Height int
}
// Image is a finite rectangular grid of [color.Color] values taken from a color
// model.
type Image interface {
// ColorModel returns the Image's color model.
ColorModel() color.Model
// Bounds returns the domain for which At can return non-zero color.
// The bounds do not necessarily contain the point (0, 0).
Bounds() Rectangle
// At returns the color of the pixel at (x, y).
// At(Bounds().Min.X, Bounds().Min.Y) returns the upper-left pixel of the grid.
// At(Bounds().Max.X-1, Bounds().Max.Y-1) returns the lower-right one.
At(x, y int) color.Color
}
// RGBA64Image is an [Image] whose pixels can be converted directly to a
// color.RGBA64.
type RGBA64Image interface {
// RGBA64At returns the RGBA64 color of the pixel at (x, y). It is
// equivalent to calling At(x, y).RGBA() and converting the resulting
// 32-bit return values to a color.RGBA64, but it can avoid allocations
// from converting concrete color types to the color.Color interface type.
RGBA64At(x, y int) color.RGBA64
Image
}
// PalettedImage is an image whose colors may come from a limited palette.
// If m is a PalettedImage and m.ColorModel() returns a [color.Palette] p,
// then m.At(x, y) should be equivalent to p[m.ColorIndexAt(x, y)]. If m's
// color model is not a color.Palette, then ColorIndexAt's behavior is
// undefined.
type PalettedImage interface {
// ColorIndexAt returns the palette index of the pixel at (x, y).
ColorIndexAt(x, y int) uint8
Image
}
// pixelBufferLength returns the length of the []uint8 typed Pix slice field
// for the NewXxx functions. Conceptually, this is just (bpp * width * height),
// but this function panics if at least one of those is negative or if the
// computation would overflow the int type.
//
// This panics instead of returning an error because of backwards
// compatibility. The NewXxx functions do not return an error.
func pixelBufferLength(bytesPerPixel int, r Rectangle, imageTypeName string) int {
totalLength := mul3NonNeg(bytesPerPixel, r.Dx(), r.Dy())
if totalLength < 0 {
panic("image: New" + imageTypeName + " Rectangle has huge or negative dimensions")
}
return totalLength
}
// RGBA is an in-memory image whose At method returns [color.RGBA] values.
type RGBA struct {
// Pix holds the image's pixels, in R, G, B, A order. The pixel at
// (x, y) starts at Pix[(y-Rect.Min.Y)*Stride + (x-Rect.Min.X)*4].
Pix []uint8
// Stride is the Pix stride (in bytes) between vertically adjacent pixels.
Stride int
// Rect is the image's bounds.
Rect Rectangle
}
func (p *RGBA) ColorModel() color.Model { return color.RGBAModel }
func (p *RGBA) Bounds() Rectangle { return p.Rect }
func (p *RGBA) At(x, y int) color.Color {
return p.RGBAAt(x, y)
}
func (p *RGBA) RGBA64At(x, y int) color.RGBA64 {
if !(Point{x, y}.In(p.Rect)) {
return color.RGBA64{}
}
i := p.PixOffset(x, y)
s := p.Pix[i : i+4 : i+4] // Small cap improves performance, see https://golang.org/issue/27857
r := uint16(s[0])
g := uint16(s[1])
b := uint16(s[2])
a := uint16(s[3])
return color.RGBA64{
(r << 8) | r,
(g << 8) | g,
(b << 8) | b,
(a << 8) | a,
}
}
func (p *RGBA) RGBAAt(x, y int) color.RGBA {
if !(Point{x, y}.In(p.Rect)) {
return color.RGBA{}
}
i := p.PixOffset(x, y)
s := p.Pix[i : i+4 : i+4] // Small cap improves performance, see https://golang.org/issue/27857
return color.RGBA{s[0], s[1], s[2], s[3]}
}
// PixOffset returns the index of the first element of Pix that corresponds to
// the pixel at (x, y).
func (p *RGBA) PixOffset(x, y int) int {
return (y-p.Rect.Min.Y)*p.Stride + (x-p.Rect.Min.X)*4
}
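// Editorial sketch, not part of the original source: for an RGBA image with
// Rect = Rect(0, 0, 10, 10), and therefore Stride = 40, the pixel at (3, 2)
// starts at offset 2*40 + 3*4 = 92 and occupies Pix[92:96] in R, G, B, A
// order. The examplePixOffset identifier is hypothetical.
func examplePixOffset() int {
	return NewRGBA(Rect(0, 0, 10, 10)).PixOffset(3, 2) // 92
}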
func (p *RGBA) Set(x, y int, c color.Color) {
if !(Point{x, y}.In(p.Rect)) {
return
}
i := p.PixOffset(x, y)
c1 := color.RGBAModel.Convert(c).(color.RGBA)
s := p.Pix[i : i+4 : i+4] // Small cap improves performance, see https://golang.org/issue/27857
s[0] = c1.R
s[1] = c1.G
s[2] = c1.B
s[3] = c1.A
}
func (p *RGBA) SetRGBA64(x, y int, c color.RGBA64) {
if !(Point{x, y}.In(p.Rect)) {
return
}
i := p.PixOffset(x, y)
s := p.Pix[i : i+4 : i+4] // Small cap improves performance, see https://golang.org/issue/27857
s[0] = uint8(c.R >> 8)
s[1] = uint8(c.G >> 8)
s[2] = uint8(c.B >> 8)
s[3] = uint8(c.A >> 8)
}
func (p *RGBA) SetRGBA(x, y int, c color.RGBA) {
if !(Point{x, y}.In(p.Rect)) {
return
}
i := p.PixOffset(x, y)
s := p.Pix[i : i+4 : i+4] // Small cap improves performance, see https://golang.org/issue/27857
s[0] = c.R
s[1] = c.G
s[2] = c.B
s[3] = c.A
}
// SubImage returns an image representing the portion of the image p visible
// through r. The returned value shares pixels with the original image.
func (p *RGBA) SubImage(r Rectangle) Image {
r = r.Intersect(p.Rect)
// If r1 and r2 are Rectangles, r1.Intersect(r2) is not guaranteed to be inside
// either r1 or r2 if the intersection is empty. Without explicitly checking for
// this, the Pix[i:] expression below can panic.
if r.Empty() {
return &RGBA{}
}
i := p.PixOffset(r.Min.X, r.Min.Y)
return &RGBA{
Pix: p.Pix[i:],
Stride: p.Stride,
Rect: r,
}
}
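// Editorial sketch, not part of the original source: SubImage returns a view
// that shares pixels with the original, so writing through the sub-image also
// mutates the parent. The exampleSubImage identifier is hypothetical.
func exampleSubImage() color.RGBA {
	m := NewRGBA(Rect(0, 0, 4, 4))
	sub := m.SubImage(Rect(1, 1, 3, 3)).(*RGBA)
	sub.SetRGBA(1, 1, color.RGBA{R: 0xff, A: 0xff})
	return m.RGBAAt(1, 1) // {0xff, 0, 0, 0xff}: the parent sees the write
}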
// Opaque scans the entire image and reports whether it is fully opaque.
func (p *RGBA) Opaque() bool {
if p.Rect.Empty() {
return true
}
i0, i1 := 3, p.Rect.Dx()*4
for y := p.Rect.Min.Y; y < p.Rect.Max.Y; y++ {
for i := i0; i < i1; i += 4 {
if p.Pix[i] != 0xff {
return false
}
}
i0 += p.Stride
i1 += p.Stride
}
return true
}
// NewRGBA returns a new [RGBA] image with the given bounds.
func NewRGBA(r Rectangle) *RGBA {
return &RGBA{
Pix: make([]uint8, pixelBufferLength(4, r, "RGBA")),
Stride: 4 * r.Dx(),
Rect: r,
}
}
// RGBA64 is an in-memory image whose At method returns [color.RGBA64] values.
type RGBA64 struct {
// Pix holds the image's pixels, in R, G, B, A order and big-endian format. The pixel at
// (x, y) starts at Pix[(y-Rect.Min.Y)*Stride + (x-Rect.Min.X)*8].
Pix []uint8
// Stride is the Pix stride (in bytes) between vertically adjacent pixels.
Stride int
// Rect is the image's bounds.
Rect Rectangle
}
func (p *RGBA64) ColorModel() color.Model { return color.RGBA64Model }
func (p *RGBA64) Bounds() Rectangle { return p.Rect }
func (p *RGBA64) At(x, y int) color.Color {
return p.RGBA64At(x, y)
}
func (p *RGBA64) RGBA64At(x, y int) color.RGBA64 {
if !(Point{x, y}.In(p.Rect)) {
return color.RGBA64{}
}
i := p.PixOffset(x, y)
s := p.Pix[i : i+8 : i+8] // Small cap improves performance, see https://golang.org/issue/27857
return color.RGBA64{
uint16(s[0])<<8 | uint16(s[1]),
uint16(s[2])<<8 | uint16(s[3]),
uint16(s[4])<<8 | uint16(s[5]),
uint16(s[6])<<8 | uint16(s[7]),
}
}
// PixOffset returns the index of the first element of Pix that corresponds to
// the pixel at (x, y).
func (p *RGBA64) PixOffset(x, y int) int {
return (y-p.Rect.Min.Y)*p.Stride + (x-p.Rect.Min.X)*8
}
func (p *RGBA64) Set(x, y int, c color.Color) {
if !(Point{x, y}.In(p.Rect)) {
return
}
i := p.PixOffset(x, y)
c1 := color.RGBA64Model.Convert(c).(color.RGBA64)
s := p.Pix[i : i+8 : i+8] // Small cap improves performance, see https://golang.org/issue/27857
s[0] = uint8(c1.R >> 8)
s[1] = uint8(c1.R)
s[2] = uint8(c1.G >> 8)
s[3] = uint8(c1.G)
s[4] = uint8(c1.B >> 8)
s[5] = uint8(c1.B)
s[6] = uint8(c1.A >> 8)
s[7] = uint8(c1.A)
}
func (p *RGBA64) SetRGBA64(x, y int, c color.RGBA64) {
if !(Point{x, y}.In(p.Rect)) {
return
}
i := p.PixOffset(x, y)
s := p.Pix[i : i+8 : i+8] // Small cap improves performance, see https://golang.org/issue/27857
s[0] = uint8(c.R >> 8)
s[1] = uint8(c.R)
s[2] = uint8(c.G >> 8)
s[3] = uint8(c.G)
s[4] = uint8(c.B >> 8)
s[5] = uint8(c.B)
s[6] = uint8(c.A >> 8)
s[7] = uint8(c.A)
}
// SubImage returns an image representing the portion of the image p visible
// through r. The returned value shares pixels with the original image.
func (p *RGBA64) SubImage(r Rectangle) Image {
r = r.Intersect(p.Rect)
// If r1 and r2 are Rectangles, r1.Intersect(r2) is not guaranteed to be inside
// either r1 or r2 if the intersection is empty. Without explicitly checking for
// this, the Pix[i:] expression below can panic.
if r.Empty() {
return &RGBA64{}
}
i := p.PixOffset(r.Min.X, r.Min.Y)
return &RGBA64{
Pix: p.Pix[i:],
Stride: p.Stride,
Rect: r,
}
}
// Opaque scans the entire image and reports whether it is fully opaque.
func (p *RGBA64) Opaque() bool {
if p.Rect.Empty() {
return true
}
i0, i1 := 6, p.Rect.Dx()*8
for y := p.Rect.Min.Y; y < p.Rect.Max.Y; y++ {
for i := i0; i < i1; i += 8 {
if p.Pix[i+0] != 0xff || p.Pix[i+1] != 0xff {
return false
}
}
i0 += p.Stride
i1 += p.Stride
}
return true
}
// NewRGBA64 returns a new [RGBA64] image with the given bounds.
func NewRGBA64(r Rectangle) *RGBA64 {
return &RGBA64{
Pix: make([]uint8, pixelBufferLength(8, r, "RGBA64")),
Stride: 8 * r.Dx(),
Rect: r,
}
}
// NRGBA is an in-memory image whose At method returns [color.NRGBA] values.
type NRGBA struct {
// Pix holds the image's pixels, in R, G, B, A order. The pixel at
// (x, y) starts at Pix[(y-Rect.Min.Y)*Stride + (x-Rect.Min.X)*4].
Pix []uint8
// Stride is the Pix stride (in bytes) between vertically adjacent pixels.
Stride int
// Rect is the image's bounds.
Rect Rectangle
}
func (p *NRGBA) ColorModel() color.Model { return color.NRGBAModel }
func (p *NRGBA) Bounds() Rectangle { return p.Rect }
func (p *NRGBA) At(x, y int) color.Color {
return p.NRGBAAt(x, y)
}
func (p *NRGBA) RGBA64At(x, y int) color.RGBA64 {
r, g, b, a := p.NRGBAAt(x, y).RGBA()
return color.RGBA64{uint16(r), uint16(g), uint16(b), uint16(a)}
}
func (p *NRGBA) NRGBAAt(x, y int) color.NRGBA {
if !(Point{x, y}.In(p.Rect)) {
return color.NRGBA{}
}
i := p.PixOffset(x, y)
s := p.Pix[i : i+4 : i+4] // Small cap improves performance, see https://golang.org/issue/27857
return color.NRGBA{s[0], s[1], s[2], s[3]}
}
// PixOffset returns the index of the first element of Pix that corresponds to
// the pixel at (x, y).
func (p *NRGBA) PixOffset(x, y int) int {
return (y-p.Rect.Min.Y)*p.Stride + (x-p.Rect.Min.X)*4
}
func (p *NRGBA) Set(x, y int, c color.Color) {
if !(Point{x, y}.In(p.Rect)) {
return
}
i := p.PixOffset(x, y)
c1 := color.NRGBAModel.Convert(c).(color.NRGBA)
s := p.Pix[i : i+4 : i+4] // Small cap improves performance, see https://golang.org/issue/27857
s[0] = c1.R
s[1] = c1.G
s[2] = c1.B
s[3] = c1.A
}
func (p *NRGBA) SetRGBA64(x, y int, c color.RGBA64) {
if !(Point{x, y}.In(p.Rect)) {
return
}
r, g, b, a := uint32(c.R), uint32(c.G), uint32(c.B), uint32(c.A)
if (a != 0) && (a != 0xffff) {
r = (r * 0xffff) / a
g = (g * 0xffff) / a
b = (b * 0xffff) / a
}
i := p.PixOffset(x, y)
s := p.Pix[i : i+4 : i+4] // Small cap improves performance, see https://golang.org/issue/27857
s[0] = uint8(r >> 8)
s[1] = uint8(g >> 8)
s[2] = uint8(b >> 8)
s[3] = uint8(a >> 8)
}
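// Worked example for the conversion above (an editorial note, not from the
// original file): SetRGBA64 receives alpha-premultiplied values, while NRGBA
// stores non-premultiplied bytes. For c = color.RGBA64{R: 0x4000, A: 0x8000},
// a 50%-alpha red, r becomes 0x4000*0xffff/0x8000 = 0x7fff, which is stored
// as the single byte 0x7f alongside the alpha byte a>>8 = 0x80.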
func (p *NRGBA) SetNRGBA(x, y int, c color.NRGBA) {
if !(Point{x, y}.In(p.Rect)) {
return
}
i := p.PixOffset(x, y)
s := p.Pix[i : i+4 : i+4] // Small cap improves performance, see https://golang.org/issue/27857
s[0] = c.R
s[1] = c.G
s[2] = c.B
s[3] = c.A
}
// SubImage returns an image representing the portion of the image p visible
// through r. The returned value shares pixels with the original image.
func (p *NRGBA) SubImage(r Rectangle) Image {
r = r.Intersect(p.Rect)
// If r1 and r2 are Rectangles, r1.Intersect(r2) is not guaranteed to be inside
// either r1 or r2 if the intersection is empty. Without explicitly checking for
// this, the Pix[i:] expression below can panic.
if r.Empty() {
return &NRGBA{}
}
i := p.PixOffset(r.Min.X, r.Min.Y)
return &NRGBA{
Pix: p.Pix[i:],
Stride: p.Stride,
Rect: r,
}
}
// Opaque scans the entire image and reports whether it is fully opaque.
func (p *NRGBA) Opaque() bool {
if p.Rect.Empty() {
return true
}
i0, i1 := 3, p.Rect.Dx()*4
for y := p.Rect.Min.Y; y < p.Rect.Max.Y; y++ {
for i := i0; i < i1; i += 4 {
if p.Pix[i] != 0xff {
return false
}
}
i0 += p.Stride
i1 += p.Stride
}
return true
}
// NewNRGBA returns a new [NRGBA] image with the given bounds.
func NewNRGBA(r Rectangle) *NRGBA {
return &NRGBA{
Pix: make([]uint8, pixelBufferLength(4, r, "NRGBA")),
Stride: 4 * r.Dx(),
Rect: r,
}
}
// NRGBA64 is an in-memory image whose At method returns [color.NRGBA64] values.
type NRGBA64 struct {
// Pix holds the image's pixels, in R, G, B, A order and big-endian format. The pixel at
// (x, y) starts at Pix[(y-Rect.Min.Y)*Stride + (x-Rect.Min.X)*8].
Pix []uint8
// Stride is the Pix stride (in bytes) between vertically adjacent pixels.
Stride int
// Rect is the image's bounds.
Rect Rectangle
}
func (p *NRGBA64) ColorModel() color.Model { return color.NRGBA64Model }
func (p *NRGBA64) Bounds() Rectangle { return p.Rect }
func (p *NRGBA64) At(x, y int) color.Color {
return p.NRGBA64At(x, y)
}
func (p *NRGBA64) RGBA64At(x, y int) color.RGBA64 {
r, g, b, a := p.NRGBA64At(x, y).RGBA()
return color.RGBA64{uint16(r), uint16(g), uint16(b), uint16(a)}
}
func (p *NRGBA64) NRGBA64At(x, y int) color.NRGBA64 {
if !(Point{x, y}.In(p.Rect)) {
return color.NRGBA64{}
}
i := p.PixOffset(x, y)
s := p.Pix[i : i+8 : i+8] // Small cap improves performance, see https://golang.org/issue/27857
return color.NRGBA64{
uint16(s[0])<<8 | uint16(s[1]),
uint16(s[2])<<8 | uint16(s[3]),
uint16(s[4])<<8 | uint16(s[5]),
uint16(s[6])<<8 | uint16(s[7]),
}
}
// PixOffset returns the index of the first element of Pix that corresponds to
// the pixel at (x, y).
func (p *NRGBA64) PixOffset(x, y int) int {
return (y-p.Rect.Min.Y)*p.Stride + (x-p.Rect.Min.X)*8
}
func (p *NRGBA64) Set(x, y int, c color.Color) {
if !(Point{x, y}.In(p.Rect)) {
return
}
i := p.PixOffset(x, y)
c1 := color.NRGBA64Model.Convert(c).(color.NRGBA64)
s := p.Pix[i : i+8 : i+8] // Small cap improves performance, see https://golang.org/issue/27857
s[0] = uint8(c1.R >> 8)
s[1] = uint8(c1.R)
s[2] = uint8(c1.G >> 8)
s[3] = uint8(c1.G)
s[4] = uint8(c1.B >> 8)
s[5] = uint8(c1.B)
s[6] = uint8(c1.A >> 8)
s[7] = uint8(c1.A)
}
func (p *NRGBA64) SetRGBA64(x, y int, c color.RGBA64) {
if !(Point{x, y}.In(p.Rect)) {
return
}
r, g, b, a := uint32(c.R), uint32(c.G), uint32(c.B), uint32(c.A)
if (a != 0) && (a != 0xffff) {
r = (r * 0xffff) / a
g = (g * 0xffff) / a
b = (b * 0xffff) / a
}
i := p.PixOffset(x, y)
s := p.Pix[i : i+8 : i+8] // Small cap improves performance, see https://golang.org/issue/27857
s[0] = uint8(r >> 8)
s[1] = uint8(r)
s[2] = uint8(g >> 8)
s[3] = uint8(g)
s[4] = uint8(b >> 8)
s[5] = uint8(b)
s[6] = uint8(a >> 8)
s[7] = uint8(a)
}
func (p *NRGBA64) SetNRGBA64(x, y int, c color.NRGBA64) {
if !(Point{x, y}.In(p.Rect)) {
return
}
i := p.PixOffset(x, y)
s := p.Pix[i : i+8 : i+8] // Small cap improves performance, see https://golang.org/issue/27857
s[0] = uint8(c.R >> 8)
s[1] = uint8(c.R)
s[2] = uint8(c.G >> 8)
s[3] = uint8(c.G)
s[4] = uint8(c.B >> 8)
s[5] = uint8(c.B)
s[6] = uint8(c.A >> 8)
s[7] = uint8(c.A)
}
// SubImage returns an image representing the portion of the image p visible
// through r. The returned value shares pixels with the original image.
func (p *NRGBA64) SubImage(r Rectangle) Image {
r = r.Intersect(p.Rect)
// If r1 and r2 are Rectangles, r1.Intersect(r2) is not guaranteed to be inside
// either r1 or r2 if the intersection is empty. Without explicitly checking for
// this, the Pix[i:] expression below can panic.
if r.Empty() {
return &NRGBA64{}
}
i := p.PixOffset(r.Min.X, r.Min.Y)
return &NRGBA64{
Pix: p.Pix[i:],
Stride: p.Stride,
Rect: r,
}
}
// Opaque scans the entire image and reports whether it is fully opaque.
func (p *NRGBA64) Opaque() bool {
if p.Rect.Empty() {
return true
}
i0, i1 := 6, p.Rect.Dx()*8
for y := p.Rect.Min.Y; y < p.Rect.Max.Y; y++ {
for i := i0; i < i1; i += 8 {
if p.Pix[i+0] != 0xff || p.Pix[i+1] != 0xff {
return false
}
}
i0 += p.Stride
i1 += p.Stride
}
return true
}
// NewNRGBA64 returns a new [NRGBA64] image with the given bounds.
func NewNRGBA64(r Rectangle) *NRGBA64 {
return &NRGBA64{
Pix: make([]uint8, pixelBufferLength(8, r, "NRGBA64")),
Stride: 8 * r.Dx(),
Rect: r,
}
}
// Alpha is an in-memory image whose At method returns [color.Alpha] values.
type Alpha struct {
// Pix holds the image's pixels, as alpha values. The pixel at
// (x, y) starts at Pix[(y-Rect.Min.Y)*Stride + (x-Rect.Min.X)*1].
Pix []uint8
// Stride is the Pix stride (in bytes) between vertically adjacent pixels.
Stride int
// Rect is the image's bounds.
Rect Rectangle
}
func (p *Alpha) ColorModel() color.Model { return color.AlphaModel }
func (p *Alpha) Bounds() Rectangle { return p.Rect }
func (p *Alpha) At(x, y int) color.Color {
return p.AlphaAt(x, y)
}
func (p *Alpha) RGBA64At(x, y int) color.RGBA64 {
a := uint16(p.AlphaAt(x, y).A)
a |= a << 8
return color.RGBA64{a, a, a, a}
}
func (p *Alpha) AlphaAt(x, y int) color.Alpha {
if !(Point{x, y}.In(p.Rect)) {
return color.Alpha{}
}
i := p.PixOffset(x, y)
return color.Alpha{p.Pix[i]}
}
// PixOffset returns the index of the first element of Pix that corresponds to
// the pixel at (x, y).
func (p *Alpha) PixOffset(x, y int) int {
return (y-p.Rect.Min.Y)*p.Stride + (x-p.Rect.Min.X)*1
}
func (p *Alpha) Set(x, y int, c color.Color) {
if !(Point{x, y}.In(p.Rect)) {
return
}
i := p.PixOffset(x, y)
p.Pix[i] = color.AlphaModel.Convert(c).(color.Alpha).A
}
func (p *Alpha) SetRGBA64(x, y int, c color.RGBA64) {
if !(Point{x, y}.In(p.Rect)) {
return
}
i := p.PixOffset(x, y)
p.Pix[i] = uint8(c.A >> 8)
}
func (p *Alpha) SetAlpha(x, y int, c color.Alpha) {
if !(Point{x, y}.In(p.Rect)) {
return
}
i := p.PixOffset(x, y)
p.Pix[i] = c.A
}
// SubImage returns an image representing the portion of the image p visible
// through r. The returned value shares pixels with the original image.
func (p *Alpha) SubImage(r Rectangle) Image {
r = r.Intersect(p.Rect)
// If r1 and r2 are Rectangles, r1.Intersect(r2) is not guaranteed to be inside
// either r1 or r2 if the intersection is empty. Without explicitly checking for
// this, the Pix[i:] expression below can panic.
if r.Empty() {
return &Alpha{}
}
i := p.PixOffset(r.Min.X, r.Min.Y)
return &Alpha{
Pix: p.Pix[i:],
Stride: p.Stride,
Rect: r,
}
}
// Opaque scans the entire image and reports whether it is fully opaque.
func (p *Alpha) Opaque() bool {
if p.Rect.Empty() {
return true
}
i0, i1 := 0, p.Rect.Dx()
for y := p.Rect.Min.Y; y < p.Rect.Max.Y; y++ {
for i := i0; i < i1; i++ {
if p.Pix[i] != 0xff {
return false
}
}
i0 += p.Stride
i1 += p.Stride
}
return true
}
// NewAlpha returns a new [Alpha] image with the given bounds.
func NewAlpha(r Rectangle) *Alpha {
return &Alpha{
Pix: make([]uint8, pixelBufferLength(1, r, "Alpha")),
Stride: 1 * r.Dx(),
Rect: r,
}
}
// Alpha16 is an in-memory image whose At method returns [color.Alpha16] values.
type Alpha16 struct {
// Pix holds the image's pixels, as alpha values in big-endian format. The pixel at
// (x, y) starts at Pix[(y-Rect.Min.Y)*Stride + (x-Rect.Min.X)*2].
Pix []uint8
// Stride is the Pix stride (in bytes) between vertically adjacent pixels.
Stride int
// Rect is the image's bounds.
Rect Rectangle
}
func (p *Alpha16) ColorModel() color.Model { return color.Alpha16Model }
func (p *Alpha16) Bounds() Rectangle { return p.Rect }
func (p *Alpha16) At(x, y int) color.Color {
return p.Alpha16At(x, y)
}
func (p *Alpha16) RGBA64At(x, y int) color.RGBA64 {
a := p.Alpha16At(x, y).A
return color.RGBA64{a, a, a, a}
}
func (p *Alpha16) Alpha16At(x, y int) color.Alpha16 {
if !(Point{x, y}.In(p.Rect)) {
return color.Alpha16{}
}
i := p.PixOffset(x, y)
return color.Alpha16{uint16(p.Pix[i+0])<<8 | uint16(p.Pix[i+1])}
}
// PixOffset returns the index of the first element of Pix that corresponds to
// the pixel at (x, y).
func (p *Alpha16) PixOffset(x, y int) int {
return (y-p.Rect.Min.Y)*p.Stride + (x-p.Rect.Min.X)*2
}
func (p *Alpha16) Set(x, y int, c color.Color) {
if !(Point{x, y}.In(p.Rect)) {
return
}
i := p.PixOffset(x, y)
c1 := color.Alpha16Model.Convert(c).(color.Alpha16)
p.Pix[i+0] = uint8(c1.A >> 8)
p.Pix[i+1] = uint8(c1.A)
}
func (p *Alpha16) SetRGBA64(x, y int, c color.RGBA64) {
if !(Point{x, y}.In(p.Rect)) {
return
}
i := p.PixOffset(x, y)
p.Pix[i+0] = uint8(c.A >> 8)
p.Pix[i+1] = uint8(c.A)
}
func (p *Alpha16) SetAlpha16(x, y int, c color.Alpha16) {
if !(Point{x, y}.In(p.Rect)) {
return
}
i := p.PixOffset(x, y)
p.Pix[i+0] = uint8(c.A >> 8)
p.Pix[i+1] = uint8(c.A)
}
// SubImage returns an image representing the portion of the image p visible
// through r. The returned value shares pixels with the original image.
func (p *Alpha16) SubImage(r Rectangle) Image {
r = r.Intersect(p.Rect)
// If r1 and r2 are Rectangles, r1.Intersect(r2) is not guaranteed to be inside
// either r1 or r2 if the intersection is empty. Without explicitly checking for
// this, the Pix[i:] expression below can panic.
if r.Empty() {
return &Alpha16{}
}
i := p.PixOffset(r.Min.X, r.Min.Y)
return &Alpha16{
Pix: p.Pix[i:],
Stride: p.Stride,
Rect: r,
}
}
// Opaque scans the entire image and reports whether it is fully opaque.
func (p *Alpha16) Opaque() bool {
if p.Rect.Empty() {
return true
}
i0, i1 := 0, p.Rect.Dx()*2
for y := p.Rect.Min.Y; y < p.Rect.Max.Y; y++ {
for i := i0; i < i1; i += 2 {
if p.Pix[i+0] != 0xff || p.Pix[i+1] != 0xff {
return false
}
}
i0 += p.Stride
i1 += p.Stride
}
return true
}
// NewAlpha16 returns a new [Alpha16] image with the given bounds.
func NewAlpha16(r Rectangle) *Alpha16 {
return &Alpha16{
Pix: make([]uint8, pixelBufferLength(2, r, "Alpha16")),
Stride: 2 * r.Dx(),
Rect: r,
}
}
// Gray is an in-memory image whose At method returns [color.Gray] values.
type Gray struct {
// Pix holds the image's pixels, as gray values. The pixel at
// (x, y) starts at Pix[(y-Rect.Min.Y)*Stride + (x-Rect.Min.X)*1].
Pix []uint8
// Stride is the Pix stride (in bytes) between vertically adjacent pixels.
Stride int
// Rect is the image's bounds.
Rect Rectangle
}
func (p *Gray) ColorModel() color.Model { return color.GrayModel }
func (p *Gray) Bounds() Rectangle { return p.Rect }
func (p *Gray) At(x, y int) color.Color {
return p.GrayAt(x, y)
}
func (p *Gray) RGBA64At(x, y int) color.RGBA64 {
gray := uint16(p.GrayAt(x, y).Y)
gray |= gray << 8
return color.RGBA64{gray, gray, gray, 0xffff}
}
func (p *Gray) GrayAt(x, y int) color.Gray {
if !(Point{x, y}.In(p.Rect)) {
return color.Gray{}
}
i := p.PixOffset(x, y)
return color.Gray{p.Pix[i]}
}
// PixOffset returns the index of the first element of Pix that corresponds to
// the pixel at (x, y).
func (p *Gray) PixOffset(x, y int) int {
return (y-p.Rect.Min.Y)*p.Stride + (x-p.Rect.Min.X)*1
}
func (p *Gray) Set(x, y int, c color.Color) {
if !(Point{x, y}.In(p.Rect)) {
return
}
i := p.PixOffset(x, y)
p.Pix[i] = color.GrayModel.Convert(c).(color.Gray).Y
}
func (p *Gray) SetRGBA64(x, y int, c color.RGBA64) {
if !(Point{x, y}.In(p.Rect)) {
return
}
// This formula is the same as in color.grayModel.
gray := (19595*uint32(c.R) + 38470*uint32(c.G) + 7471*uint32(c.B) + 1<<15) >> 24
i := p.PixOffset(x, y)
p.Pix[i] = uint8(gray)
}
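// Editorial note (not from the original file): the coefficients above are
// the usual luma weights scaled by 2^16, with 19595 ~ 0.299*65536,
// 38470 ~ 0.587*65536 and 7471 ~ 0.114*65536, summing to exactly 65536.
// The weighted sum of the 16-bit channels is therefore the gray level
// scaled by 2^16, and shifting right by 24 reduces it to the stored
// 8-bit value.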
func (p *Gray) SetGray(x, y int, c color.Gray) {
if !(Point{x, y}.In(p.Rect)) {
return
}
i := p.PixOffset(x, y)
p.Pix[i] = c.Y
}
// SubImage returns an image representing the portion of the image p visible
// through r. The returned value shares pixels with the original image.
func (p *Gray) SubImage(r Rectangle) Image {
r = r.Intersect(p.Rect)
// If r1 and r2 are Rectangles, r1.Intersect(r2) is not guaranteed to be inside
// either r1 or r2 if the intersection is empty. Without explicitly checking for
// this, the Pix[i:] expression below can panic.
if r.Empty() {
return &Gray{}
}
i := p.PixOffset(r.Min.X, r.Min.Y)
return &Gray{
Pix: p.Pix[i:],
Stride: p.Stride,
Rect: r,
}
}
// Opaque scans the entire image and reports whether it is fully opaque.
func (p *Gray) Opaque() bool {
return true
}
// NewGray returns a new [Gray] image with the given bounds.
func NewGray(r Rectangle) *Gray {
return &Gray{
Pix: make([]uint8, pixelBufferLength(1, r, "Gray")),
Stride: 1 * r.Dx(),
Rect: r,
}
}
// Gray16 is an in-memory image whose At method returns [color.Gray16] values.
type Gray16 struct {
// Pix holds the image's pixels, as gray values in big-endian format. The pixel at
// (x, y) starts at Pix[(y-Rect.Min.Y)*Stride + (x-Rect.Min.X)*2].
Pix []uint8
// Stride is the Pix stride (in bytes) between vertically adjacent pixels.
Stride int
// Rect is the image's bounds.
Rect Rectangle
}
func (p *Gray16) ColorModel() color.Model { return color.Gray16Model }
func (p *Gray16) Bounds() Rectangle { return p.Rect }
func (p *Gray16) At(x, y int) color.Color {
return p.Gray16At(x, y)
}
func (p *Gray16) RGBA64At(x, y int) color.RGBA64 {
gray := p.Gray16At(x, y).Y
return color.RGBA64{gray, gray, gray, 0xffff}
}
func (p *Gray16) Gray16At(x, y int) color.Gray16 {
if !(Point{x, y}.In(p.Rect)) {
return color.Gray16{}
}
i := p.PixOffset(x, y)
return color.Gray16{uint16(p.Pix[i+0])<<8 | uint16(p.Pix[i+1])}
}
// PixOffset returns the index of the first element of Pix that corresponds to
// the pixel at (x, y).
func (p *Gray16) PixOffset(x, y int) int {
return (y-p.Rect.Min.Y)*p.Stride + (x-p.Rect.Min.X)*2
}
func (p *Gray16) Set(x, y int, c color.Color) {
if !(Point{x, y}.In(p.Rect)) {
return
}
i := p.PixOffset(x, y)
c1 := color.Gray16Model.Convert(c).(color.Gray16)
p.Pix[i+0] = uint8(c1.Y >> 8)
p.Pix[i+1] = uint8(c1.Y)
}
func (p *Gray16) SetRGBA64(x, y int, c color.RGBA64) {
if !(Point{x, y}.In(p.Rect)) {
return
}
// This formula is the same as in color.gray16Model.
gray := (19595*uint32(c.R) + 38470*uint32(c.G) + 7471*uint32(c.B) + 1<<15) >> 16
i := p.PixOffset(x, y)
p.Pix[i+0] = uint8(gray >> 8)
p.Pix[i+1] = uint8(gray)
}
func (p *Gray16) SetGray16(x, y int, c color.Gray16) {
if !(Point{x, y}.In(p.Rect)) {
return
}
i := p.PixOffset(x, y)
p.Pix[i+0] = uint8(c.Y >> 8)
p.Pix[i+1] = uint8(c.Y)
}
// SubImage returns an image representing the portion of the image p visible
// through r. The returned value shares pixels with the original image.
func (p *Gray16) SubImage(r Rectangle) Image {
r = r.Intersect(p.Rect)
// If r1 and r2 are Rectangles, r1.Intersect(r2) is not guaranteed to be inside
// either r1 or r2 if the intersection is empty. Without explicitly checking for
// this, the Pix[i:] expression below can panic.
if r.Empty() {
return &Gray16{}
}
i := p.PixOffset(r.Min.X, r.Min.Y)
return &Gray16{
Pix: p.Pix[i:],
Stride: p.Stride,
Rect: r,
}
}
// Opaque scans the entire image and reports whether it is fully opaque.
func (p *Gray16) Opaque() bool {
return true
}
// NewGray16 returns a new [Gray16] image with the given bounds.
func NewGray16(r Rectangle) *Gray16 {
return &Gray16{
Pix: make([]uint8, pixelBufferLength(2, r, "Gray16")),
Stride: 2 * r.Dx(),
Rect: r,
}
}
// CMYK is an in-memory image whose At method returns [color.CMYK] values.
type CMYK struct {
// Pix holds the image's pixels, in C, M, Y, K order. The pixel at
// (x, y) starts at Pix[(y-Rect.Min.Y)*Stride + (x-Rect.Min.X)*4].
Pix []uint8
// Stride is the Pix stride (in bytes) between vertically adjacent pixels.
Stride int
// Rect is the image's bounds.
Rect Rectangle
}
func (p *CMYK) ColorModel() color.Model { return color.CMYKModel }
func (p *CMYK) Bounds() Rectangle { return p.Rect }
func (p *CMYK) At(x, y int) color.Color {
return p.CMYKAt(x, y)
}
func (p *CMYK) RGBA64At(x, y int) color.RGBA64 {
r, g, b, a := p.CMYKAt(x, y).RGBA()
return color.RGBA64{uint16(r), uint16(g), uint16(b), uint16(a)}
}
func (p *CMYK) CMYKAt(x, y int) color.CMYK {
if !(Point{x, y}.In(p.Rect)) {
return color.CMYK{}
}
i := p.PixOffset(x, y)
s := p.Pix[i : i+4 : i+4] // Small cap improves performance, see https://golang.org/issue/27857
return color.CMYK{s[0], s[1], s[2], s[3]}
}
// PixOffset returns the index of the first element of Pix that corresponds to
// the pixel at (x, y).
func (p *CMYK) PixOffset(x, y int) int {
return (y-p.Rect.Min.Y)*p.Stride + (x-p.Rect.Min.X)*4
}
func (p *CMYK) Set(x, y int, c color.Color) {
if !(Point{x, y}.In(p.Rect)) {
return
}
i := p.PixOffset(x, y)
c1 := color.CMYKModel.Convert(c).(color.CMYK)
s := p.Pix[i : i+4 : i+4] // Small cap improves performance, see https://golang.org/issue/27857
s[0] = c1.C
s[1] = c1.M
s[2] = c1.Y
s[3] = c1.K
}
func (p *CMYK) SetRGBA64(x, y int, c color.RGBA64) {
if !(Point{x, y}.In(p.Rect)) {
return
}
cc, mm, yy, kk := color.RGBToCMYK(uint8(c.R>>8), uint8(c.G>>8), uint8(c.B>>8))
i := p.PixOffset(x, y)
s := p.Pix[i : i+4 : i+4] // Small cap improves performance, see https://golang.org/issue/27857
s[0] = cc
s[1] = mm
s[2] = yy
s[3] = kk
}
func (p *CMYK) SetCMYK(x, y int, c color.CMYK) {
if !(Point{x, y}.In(p.Rect)) {
return
}
i := p.PixOffset(x, y)
s := p.Pix[i : i+4 : i+4] // Small cap improves performance, see https://golang.org/issue/27857
s[0] = c.C
s[1] = c.M
s[2] = c.Y
s[3] = c.K
}
// SubImage returns an image representing the portion of the image p visible
// through r. The returned value shares pixels with the original image.
func (p *CMYK) SubImage(r Rectangle) Image {
r = r.Intersect(p.Rect)
// If r1 and r2 are Rectangles, r1.Intersect(r2) is not guaranteed to be inside
// either r1 or r2 if the intersection is empty. Without explicitly checking for
// this, the Pix[i:] expression below can panic.
if r.Empty() {
return &CMYK{}
}
i := p.PixOffset(r.Min.X, r.Min.Y)
return &CMYK{
Pix: p.Pix[i:],
Stride: p.Stride,
Rect: r,
}
}
// Opaque scans the entire image and reports whether it is fully opaque.
func (p *CMYK) Opaque() bool {
return true
}
// NewCMYK returns a new [CMYK] image with the given bounds.
func NewCMYK(r Rectangle) *CMYK {
return &CMYK{
Pix: make([]uint8, pixelBufferLength(4, r, "CMYK")),
Stride: 4 * r.Dx(),
Rect: r,
}
}
// Paletted is an in-memory image of uint8 indices into a given palette.
type Paletted struct {
// Pix holds the image's pixels, as palette indices. The pixel at
// (x, y) starts at Pix[(y-Rect.Min.Y)*Stride + (x-Rect.Min.X)*1].
Pix []uint8
// Stride is the Pix stride (in bytes) between vertically adjacent pixels.
Stride int
// Rect is the image's bounds.
Rect Rectangle
// Palette is the image's palette.
Palette color.Palette
}
func (p *Paletted) ColorModel() color.Model { return p.Palette }
func (p *Paletted) Bounds() Rectangle { return p.Rect }
func (p *Paletted) At(x, y int) color.Color {
if len(p.Palette) == 0 {
return nil
}
if !(Point{x, y}.In(p.Rect)) {
return p.Palette[0]
}
i := p.PixOffset(x, y)
return p.Palette[p.Pix[i]]
}
func (p *Paletted) RGBA64At(x, y int) color.RGBA64 {
if len(p.Palette) == 0 {
return color.RGBA64{}
}
c := color.Color(nil)
if !(Point{x, y}.In(p.Rect)) {
c = p.Palette[0]
} else {
i := p.PixOffset(x, y)
c = p.Palette[p.Pix[i]]
}
r, g, b, a := c.RGBA()
return color.RGBA64{
uint16(r),
uint16(g),
uint16(b),
uint16(a),
}
}
// PixOffset returns the index of the first element of Pix that corresponds to
// the pixel at (x, y).
func (p *Paletted) PixOffset(x, y int) int {
return (y-p.Rect.Min.Y)*p.Stride + (x-p.Rect.Min.X)*1
}
func (p *Paletted) Set(x, y int, c color.Color) {
if !(Point{x, y}.In(p.Rect)) {
return
}
i := p.PixOffset(x, y)
p.Pix[i] = uint8(p.Palette.Index(c))
}
func (p *Paletted) SetRGBA64(x, y int, c color.RGBA64) {
if !(Point{x, y}.In(p.Rect)) {
return
}
i := p.PixOffset(x, y)
p.Pix[i] = uint8(p.Palette.Index(c))
}
func (p *Paletted) ColorIndexAt(x, y int) uint8 {
if !(Point{x, y}.In(p.Rect)) {
return 0
}
i := p.PixOffset(x, y)
return p.Pix[i]
}
func (p *Paletted) SetColorIndex(x, y int, index uint8) {
if !(Point{x, y}.In(p.Rect)) {
return
}
i := p.PixOffset(x, y)
p.Pix[i] = index
}
// SubImage returns an image representing the portion of the image p visible
// through r. The returned value shares pixels with the original image.
func (p *Paletted) SubImage(r Rectangle) Image {
r = r.Intersect(p.Rect)
// If r1 and r2 are Rectangles, r1.Intersect(r2) is not guaranteed to be inside
// either r1 or r2 if the intersection is empty. Without explicitly checking for
// this, the Pix[i:] expression below can panic.
if r.Empty() {
return &Paletted{
Palette: p.Palette,
}
}
i := p.PixOffset(r.Min.X, r.Min.Y)
return &Paletted{
Pix: p.Pix[i:],
Stride: p.Stride,
Rect: p.Rect.Intersect(r),
Palette: p.Palette,
}
}
// Opaque scans the entire image and reports whether it is fully opaque.
func (p *Paletted) Opaque() bool {
var present [256]bool
i0, i1 := 0, p.Rect.Dx()
for y := p.Rect.Min.Y; y < p.Rect.Max.Y; y++ {
for _, c := range p.Pix[i0:i1] {
present[c] = true
}
i0 += p.Stride
i1 += p.Stride
}
for i, c := range p.Palette {
if !present[i] {
continue
}
_, _, _, a := c.RGBA()
if a != 0xffff {
return false
}
}
return true
}
// NewPaletted returns a new [Paletted] image with the given width, height and
// palette.
func NewPaletted(r Rectangle, p color.Palette) *Paletted {
return &Paletted{
Pix: make([]uint8, pixelBufferLength(1, r, "Paletted")),
Stride: 1 * r.Dx(),
Rect: r,
Palette: p,
}
}
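// An illustrative sketch (not part of the original file) of using a
// Paletted image, for a caller importing "image" and "image/color":
//
//	pal := color.Palette{color.Black, color.White}
//	p := image.NewPaletted(image.Rect(0, 0, 4, 4), pal)
//	p.SetColorIndex(1, 1, 1)
//	// p.ColorIndexAt(1, 1) == 1 and p.At(1, 1) is p.Palette[1] (color.White).
//	// Calling Set with an arbitrary color.Color instead stores the index of
//	// the nearest palette entry, via p.Palette.Index.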
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package jpeg
// This file implements a Forward Discrete Cosine Transformation.
/*
It is based on the code in jfdctint.c from the Independent JPEG Group,
found at http://www.ijg.org/files/jpegsrc.v8c.tar.gz.
The "LEGAL ISSUES" section of the README in that archive says:
In plain English:
1. We don't promise that this software works. (But if you find any bugs,
please let us know!)
2. You can use this software for whatever you want. You don't have to pay us.
3. You may not pretend that you wrote this software. If you use it in a
program, you must acknowledge somewhere in your documentation that
you've used the IJG code.
In legalese:
The authors make NO WARRANTY or representation, either express or implied,
with respect to this software, its quality, accuracy, merchantability, or
fitness for a particular purpose. This software is provided "AS IS", and you,
its user, assume the entire risk as to its quality and accuracy.
This software is copyright (C) 1991-2011, Thomas G. Lane, Guido Vollbeding.
All Rights Reserved except as specified below.
Permission is hereby granted to use, copy, modify, and distribute this
software (or portions thereof) for any purpose, without fee, subject to these
conditions:
(1) If any part of the source code for this software is distributed, then this
README file must be included, with this copyright and no-warranty notice
unaltered; and any additions, deletions, or changes to the original files
must be clearly indicated in accompanying documentation.
(2) If only executable code is distributed, then the accompanying
documentation must state that "this software is based in part on the work of
the Independent JPEG Group".
(3) Permission for use of this software is granted only if the user accepts
full responsibility for any undesirable consequences; the authors accept
NO LIABILITY for damages of any kind.
These conditions apply to any software derived from or based on the IJG code,
not just to the unmodified library. If you use our work, you ought to
acknowledge us.
Permission is NOT granted for the use of any IJG author's name or company name
in advertising or publicity relating to this software or products derived from
it. This software may be referred to only as "the Independent JPEG Group's
software".
We specifically permit and encourage the use of this software as the basis of
commercial products, provided that all warranty or liability claims are
assumed by the product vendor.
*/
// Trigonometric constants in 13-bit fixed point format.
const (
fix_0_298631336 = 2446
fix_0_390180644 = 3196
fix_0_541196100 = 4433
fix_0_765366865 = 6270
fix_0_899976223 = 7373
fix_1_175875602 = 9633
fix_1_501321110 = 12299
fix_1_847759065 = 15137
fix_1_961570560 = 16069
fix_2_053119869 = 16819
fix_2_562915447 = 20995
fix_3_072711026 = 25172
)
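// Editorial note (not from the original file): each fix_X constant above is
// X scaled by 2^constBits (defined below) and rounded to the nearest
// integer, for example fix_0_541196100 = round(0.541196100 * 8192) = 4433
// and fix_1_847759065 = round(1.847759065 * 8192) = 15137.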
const (
constBits = 13
pass1Bits = 2
centerJSample = 128
)
// fdct performs a forward DCT on an 8x8 block of coefficients, including a
// level shift.
func fdct(b *block) {
// Pass 1: process rows.
for y := 0; y < 8; y++ {
y8 := y * 8
s := b[y8 : y8+8 : y8+8] // Small cap improves performance, see https://golang.org/issue/27857
x0 := s[0]
x1 := s[1]
x2 := s[2]
x3 := s[3]
x4 := s[4]
x5 := s[5]
x6 := s[6]
x7 := s[7]
tmp0 := x0 + x7
tmp1 := x1 + x6
tmp2 := x2 + x5
tmp3 := x3 + x4
tmp10 := tmp0 + tmp3
tmp12 := tmp0 - tmp3
tmp11 := tmp1 + tmp2
tmp13 := tmp1 - tmp2
tmp0 = x0 - x7
tmp1 = x1 - x6
tmp2 = x2 - x5
tmp3 = x3 - x4
s[0] = (tmp10 + tmp11 - 8*centerJSample) << pass1Bits
s[4] = (tmp10 - tmp11) << pass1Bits
z1 := (tmp12 + tmp13) * fix_0_541196100
z1 += 1 << (constBits - pass1Bits - 1)
s[2] = (z1 + tmp12*fix_0_765366865) >> (constBits - pass1Bits)
s[6] = (z1 - tmp13*fix_1_847759065) >> (constBits - pass1Bits)
tmp10 = tmp0 + tmp3
tmp11 = tmp1 + tmp2
tmp12 = tmp0 + tmp2
tmp13 = tmp1 + tmp3
z1 = (tmp12 + tmp13) * fix_1_175875602
z1 += 1 << (constBits - pass1Bits - 1)
tmp0 *= fix_1_501321110
tmp1 *= fix_3_072711026
tmp2 *= fix_2_053119869
tmp3 *= fix_0_298631336
tmp10 *= -fix_0_899976223
tmp11 *= -fix_2_562915447
tmp12 *= -fix_0_390180644
tmp13 *= -fix_1_961570560
tmp12 += z1
tmp13 += z1
s[1] = (tmp0 + tmp10 + tmp12) >> (constBits - pass1Bits)
s[3] = (tmp1 + tmp11 + tmp13) >> (constBits - pass1Bits)
s[5] = (tmp2 + tmp11 + tmp12) >> (constBits - pass1Bits)
s[7] = (tmp3 + tmp10 + tmp13) >> (constBits - pass1Bits)
}
// Pass 2: process columns.
// We remove pass1Bits scaling, but leave results scaled up by an overall factor of 8.
for x := 0; x < 8; x++ {
tmp0 := b[0*8+x] + b[7*8+x]
tmp1 := b[1*8+x] + b[6*8+x]
tmp2 := b[2*8+x] + b[5*8+x]
tmp3 := b[3*8+x] + b[4*8+x]
tmp10 := tmp0 + tmp3 + 1<<(pass1Bits-1)
tmp12 := tmp0 - tmp3
tmp11 := tmp1 + tmp2
tmp13 := tmp1 - tmp2
tmp0 = b[0*8+x] - b[7*8+x]
tmp1 = b[1*8+x] - b[6*8+x]
tmp2 = b[2*8+x] - b[5*8+x]
tmp3 = b[3*8+x] - b[4*8+x]
b[0*8+x] = (tmp10 + tmp11) >> pass1Bits
b[4*8+x] = (tmp10 - tmp11) >> pass1Bits
z1 := (tmp12 + tmp13) * fix_0_541196100
z1 += 1 << (constBits + pass1Bits - 1)
b[2*8+x] = (z1 + tmp12*fix_0_765366865) >> (constBits + pass1Bits)
b[6*8+x] = (z1 - tmp13*fix_1_847759065) >> (constBits + pass1Bits)
tmp10 = tmp0 + tmp3
tmp11 = tmp1 + tmp2
tmp12 = tmp0 + tmp2
tmp13 = tmp1 + tmp3
z1 = (tmp12 + tmp13) * fix_1_175875602
z1 += 1 << (constBits + pass1Bits - 1)
tmp0 *= fix_1_501321110
tmp1 *= fix_3_072711026
tmp2 *= fix_2_053119869
tmp3 *= fix_0_298631336
tmp10 *= -fix_0_899976223
tmp11 *= -fix_2_562915447
tmp12 *= -fix_0_390180644
tmp13 *= -fix_1_961570560
tmp12 += z1
tmp13 += z1
b[1*8+x] = (tmp0 + tmp10 + tmp12) >> (constBits + pass1Bits)
b[3*8+x] = (tmp1 + tmp11 + tmp13) >> (constBits + pass1Bits)
b[5*8+x] = (tmp2 + tmp11 + tmp12) >> (constBits + pass1Bits)
b[7*8+x] = (tmp3 + tmp10 + tmp13) >> (constBits + pass1Bits)
}
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package jpeg
import (
"io"
)
// maxCodeLength is the maximum (inclusive) number of bits in a Huffman code.
const maxCodeLength = 16
// maxNCodes is the maximum (inclusive) number of codes in a Huffman tree.
const maxNCodes = 256
// lutSize is the log-2 size of the Huffman decoder's look-up table.
const lutSize = 8
// huffman is a Huffman decoder, specified in section C.
type huffman struct {
// nCodes is the number of codes in the tree.
nCodes int32
// lut is the look-up table for the next lutSize bits in the bit-stream.
// The high 8 bits of the uint16 are the encoded value. The low 8 bits
// are 1 plus the code length, or 0 if the value is too large to fit in
// lutSize bits.
lut [1 << lutSize]uint16
// vals are the decoded values, sorted by their encoding.
vals [maxNCodes]uint8
// minCodes[i] is the minimum code of length i, or -1 if there are no
// codes of that length.
minCodes [maxCodeLength]int32
// maxCodes[i] is the maximum code of length i, or -1 if there are no
// codes of that length.
maxCodes [maxCodeLength]int32
// valsIndices[i] is the index into vals of minCodes[i].
valsIndices [maxCodeLength]int32
}
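// Illustrative sketch (not part of the original file) of how the lut field
// is consumed by the fast path in decodeHuffman below; peek8 is a
// placeholder name for the next lutSize bits of the stream, MSB-aligned:
//
//	if v := h.lut[peek8]; v != 0 {
//		sym := uint8(v >> 8)   // the decoded value
//		n := int32(v&0xff) - 1 // its code length in bits
//		// ...consume n bits from the bit-stream...
//	}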
// errShortHuffmanData means that an unexpected EOF occurred while decoding
// Huffman data.
var errShortHuffmanData = FormatError("short Huffman data")
// ensureNBits reads bytes from the byte buffer to ensure that d.bits.n is at
// least n. For best performance (avoiding function calls inside hot loops),
// the caller is the one responsible for first checking that d.bits.n < n.
func (d *decoder) ensureNBits(n int32) error {
for {
c, err := d.readByteStuffedByte()
if err != nil {
if err == io.ErrUnexpectedEOF {
return errShortHuffmanData
}
return err
}
d.bits.a = d.bits.a<<8 | uint32(c)
d.bits.n += 8
if d.bits.m == 0 {
d.bits.m = 1 << 7
} else {
d.bits.m <<= 8
}
if d.bits.n >= n {
break
}
}
return nil
}
// receiveExtend is the composition of RECEIVE and EXTEND, specified in section
// F.2.2.1.
func (d *decoder) receiveExtend(t uint8) (int32, error) {
if d.bits.n < int32(t) {
if err := d.ensureNBits(int32(t)); err != nil {
return 0, err
}
}
d.bits.n -= int32(t)
d.bits.m >>= t
s := int32(1) << t
x := int32(d.bits.a>>uint8(d.bits.n)) & (s - 1)
if x < s>>1 {
x += ((-1) << t) + 1
}
return x, nil
}
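// Worked example (editorial, not from the original file): with t = 3 the
// raw bits x range over 0..7 and s>>1 is 4, so x = 2 is sign-extended to
// 2 + ((-1)<<3) + 1 = -5, matching EXTEND in section F.2.2.1, while x = 6
// is returned unchanged.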
// processDHT processes a Define Huffman Table marker, and initializes a huffman
// struct from its contents. Specified in section B.2.4.2.
func (d *decoder) processDHT(n int) error {
for n > 0 {
if n < 17 {
return FormatError("DHT has wrong length")
}
if err := d.readFull(d.tmp[:17]); err != nil {
return err
}
tc := d.tmp[0] >> 4
if tc > maxTc {
return FormatError("bad Tc value")
}
th := d.tmp[0] & 0x0f
// The baseline th <= 1 restriction is specified in table B.5.
if th > maxTh || (d.baseline && th > 1) {
return FormatError("bad Th value")
}
h := &d.huff[tc][th]
// Read nCodes and h.vals (and derive h.nCodes).
// nCodes[i] is the number of codes with code length i.
// h.nCodes is the total number of codes.
h.nCodes = 0
var nCodes [maxCodeLength]int32
for i := range nCodes {
nCodes[i] = int32(d.tmp[i+1])
h.nCodes += nCodes[i]
}
if h.nCodes == 0 {
return FormatError("Huffman table has zero length")
}
if h.nCodes > maxNCodes {
return FormatError("Huffman table has excessive length")
}
n -= int(h.nCodes) + 17
if n < 0 {
return FormatError("DHT has wrong length")
}
if err := d.readFull(h.vals[:h.nCodes]); err != nil {
return err
}
// Derive the look-up table.
clear(h.lut[:])
var x, code uint32
for i := uint32(0); i < lutSize; i++ {
code <<= 1
for j := int32(0); j < nCodes[i]; j++ {
// The codeLength is 1+i, so shift code by 8-(1+i) to
// calculate the high bits for every 8-bit sequence
// whose codeLength's high bits match code.
// The high 8 bits of lutValue are the encoded value.
// The low 8 bits are 1 plus the codeLength.
base := uint8(code << (7 - i))
lutValue := uint16(h.vals[x])<<8 | uint16(2+i)
for k := uint8(0); k < 1<<(7-i); k++ {
h.lut[base|k] = lutValue
}
code++
x++
}
}
// Derive minCodes, maxCodes, and valsIndices.
var c, index int32
for i, n := range nCodes {
if n == 0 {
h.minCodes[i] = -1
h.maxCodes[i] = -1
h.valsIndices[i] = -1
} else {
h.minCodes[i] = c
h.maxCodes[i] = c + n - 1
h.valsIndices[i] = index
c += n
index += n
}
c <<= 1
}
}
return nil
}
// decodeHuffman returns the next Huffman-coded value from the bit-stream,
// decoded according to h.
func (d *decoder) decodeHuffman(h *huffman) (uint8, error) {
if h.nCodes == 0 {
return 0, FormatError("uninitialized Huffman table")
}
if d.bits.n < 8 {
if err := d.ensureNBits(8); err != nil {
if err != errMissingFF00 && err != errShortHuffmanData {
return 0, err
}
// There are no more bytes of data in this segment, but we may still
// be able to read the next symbol out of the previously read bits.
// First, undo the readByte that the ensureNBits call made.
if d.bytes.nUnreadable != 0 {
d.unreadByteStuffedByte()
}
goto slowPath
}
}
if v := h.lut[(d.bits.a>>uint32(d.bits.n-lutSize))&0xff]; v != 0 {
n := (v & 0xff) - 1
d.bits.n -= int32(n)
d.bits.m >>= n
return uint8(v >> 8), nil
}
slowPath:
for i, code := 0, int32(0); i < maxCodeLength; i++ {
if d.bits.n == 0 {
if err := d.ensureNBits(1); err != nil {
return 0, err
}
}
if d.bits.a&d.bits.m != 0 {
code |= 1
}
d.bits.n--
d.bits.m >>= 1
if code <= h.maxCodes[i] {
return h.vals[h.valsIndices[i]+code-h.minCodes[i]], nil
}
code <<= 1
}
return 0, FormatError("bad Huffman code")
}
func (d *decoder) decodeBit() (bool, error) {
if d.bits.n == 0 {
if err := d.ensureNBits(1); err != nil {
return false, err
}
}
ret := d.bits.a&d.bits.m != 0
d.bits.n--
d.bits.m >>= 1
return ret, nil
}
func (d *decoder) decodeBits(n int32) (uint32, error) {
if d.bits.n < n {
if err := d.ensureNBits(n); err != nil {
return 0, err
}
}
ret := d.bits.a >> uint32(d.bits.n-n)
ret &= (1 << uint32(n)) - 1
d.bits.n -= n
d.bits.m >>= uint32(n)
return ret, nil
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package jpeg
// This is a Go translation of idct.c from
//
// http://standards.iso.org/ittf/PubliclyAvailableStandards/ISO_IEC_13818-4_2004_Conformance_Testing/Video/verifier/mpeg2decode_960109.tar.gz
//
// which carries the following notice:
/* Copyright (C) 1996, MPEG Software Simulation Group. All Rights Reserved. */
/*
* Disclaimer of Warranty
*
* These software programs are available to the user without any license fee or
* royalty on an "as is" basis. The MPEG Software Simulation Group disclaims
* any and all warranties, whether express, implied, or statuary, including any
* implied warranties or merchantability or of fitness for a particular
* purpose. In no event shall the copyright-holder be liable for any
* incidental, punitive, or consequential damages of any kind whatsoever
* arising from the use of these programs.
*
* This disclaimer of warranty extends to the user of these programs and user's
* customers, employees, agents, transferees, successors, and assigns.
*
* The MPEG Software Simulation Group does not represent or warrant that the
* programs furnished hereunder are free of infringement of any third-party
* patents.
*
* Commercial implementations of MPEG-1 and MPEG-2 video, including shareware,
* are subject to royalty fees to patent holders. Many of these patents are
* general enough such that they are unavoidable regardless of implementation
* design.
*
*/
const blockSize = 64 // A DCT block is 8x8.
type block [blockSize]int32
const (
w1 = 2841 // 2048*sqrt(2)*cos(1*pi/16)
w2 = 2676 // 2048*sqrt(2)*cos(2*pi/16)
w3 = 2408 // 2048*sqrt(2)*cos(3*pi/16)
w5 = 1609 // 2048*sqrt(2)*cos(5*pi/16)
w6 = 1108 // 2048*sqrt(2)*cos(6*pi/16)
w7 = 565 // 2048*sqrt(2)*cos(7*pi/16)
w1pw7 = w1 + w7
w1mw7 = w1 - w7
w2pw6 = w2 + w6
w2mw6 = w2 - w6
w3pw5 = w3 + w5
w3mw5 = w3 - w5
r2 = 181 // 256/sqrt(2)
)
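// Editorial note (not from the original file): each constant above is the
// stated expression rounded to the nearest integer, for example
// w7 = round(2048*sqrt(2)*cos(7*pi/16)) = 565 and
// r2 = round(256/sqrt(2)) = round(181.02) = 181.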
// idct performs a 2-D Inverse Discrete Cosine Transformation.
//
// The input coefficients should already have been multiplied by the
// appropriate quantization table. We use fixed-point computation, with the
// number of bits for the fractional component varying over the intermediate
// stages.
//
// For more on the actual algorithm, see Z. Wang, "Fast algorithms for the
// discrete W transform and for the discrete Fourier transform", IEEE Trans. on
// ASSP, Vol. ASSP- 32, pp. 803-816, Aug. 1984.
func idct(src *block) {
// Horizontal 1-D IDCT.
for y := 0; y < 8; y++ {
y8 := y * 8
s := src[y8 : y8+8 : y8+8] // Small cap improves performance, see https://golang.org/issue/27857
// If all the AC components are zero, then the IDCT is trivial.
if s[1] == 0 && s[2] == 0 && s[3] == 0 &&
s[4] == 0 && s[5] == 0 && s[6] == 0 && s[7] == 0 {
dc := s[0] << 3
s[0] = dc
s[1] = dc
s[2] = dc
s[3] = dc
s[4] = dc
s[5] = dc
s[6] = dc
s[7] = dc
continue
}
// Prescale.
x0 := (s[0] << 11) + 128
x1 := s[4] << 11
x2 := s[6]
x3 := s[2]
x4 := s[1]
x5 := s[7]
x6 := s[5]
x7 := s[3]
// Stage 1.
x8 := w7 * (x4 + x5)
x4 = x8 + w1mw7*x4
x5 = x8 - w1pw7*x5
x8 = w3 * (x6 + x7)
x6 = x8 - w3mw5*x6
x7 = x8 - w3pw5*x7
// Stage 2.
x8 = x0 + x1
x0 -= x1
x1 = w6 * (x3 + x2)
x2 = x1 - w2pw6*x2
x3 = x1 + w2mw6*x3
x1 = x4 + x6
x4 -= x6
x6 = x5 + x7
x5 -= x7
// Stage 3.
x7 = x8 + x3
x8 -= x3
x3 = x0 + x2
x0 -= x2
x2 = (r2*(x4+x5) + 128) >> 8
x4 = (r2*(x4-x5) + 128) >> 8
// Stage 4.
s[0] = (x7 + x1) >> 8
s[1] = (x3 + x2) >> 8
s[2] = (x0 + x4) >> 8
s[3] = (x8 + x6) >> 8
s[4] = (x8 - x6) >> 8
s[5] = (x0 - x4) >> 8
s[6] = (x3 - x2) >> 8
s[7] = (x7 - x1) >> 8
}
// Vertical 1-D IDCT.
for x := 0; x < 8; x++ {
// Similar to the horizontal 1-D IDCT case, if all the AC components are zero, then the IDCT is trivial.
// However, after performing the horizontal 1-D IDCT, there are typically non-zero AC components, so
// we do not bother to check for the all-zero case.
s := src[x : x+57 : x+57] // Small cap improves performance, see https://golang.org/issue/27857
// Prescale.
y0 := (s[8*0] << 8) + 8192
y1 := s[8*4] << 8
y2 := s[8*6]
y3 := s[8*2]
y4 := s[8*1]
y5 := s[8*7]
y6 := s[8*5]
y7 := s[8*3]
// Stage 1.
y8 := w7*(y4+y5) + 4
y4 = (y8 + w1mw7*y4) >> 3
y5 = (y8 - w1pw7*y5) >> 3
y8 = w3*(y6+y7) + 4
y6 = (y8 - w3mw5*y6) >> 3
y7 = (y8 - w3pw5*y7) >> 3
// Stage 2.
y8 = y0 + y1
y0 -= y1
y1 = w6*(y3+y2) + 4
y2 = (y1 - w2pw6*y2) >> 3
y3 = (y1 + w2mw6*y3) >> 3
y1 = y4 + y6
y4 -= y6
y6 = y5 + y7
y5 -= y7
// Stage 3.
y7 = y8 + y3
y8 -= y3
y3 = y0 + y2
y0 -= y2
y2 = (r2*(y4+y5) + 128) >> 8
y4 = (r2*(y4-y5) + 128) >> 8
// Stage 4.
s[8*0] = (y7 + y1) >> 14
s[8*1] = (y3 + y2) >> 14
s[8*2] = (y0 + y4) >> 14
s[8*3] = (y8 + y6) >> 14
s[8*4] = (y8 - y6) >> 14
s[8*5] = (y0 - y4) >> 14
s[8*6] = (y3 - y2) >> 14
s[8*7] = (y7 - y1) >> 14
}
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package jpeg implements a JPEG image decoder and encoder.
//
// JPEG is defined in ITU-T T.81: https://www.w3.org/Graphics/JPEG/itu-t81.pdf.
package jpeg
import (
"image"
"image/color"
"image/internal/imageutil"
"io"
)
// A FormatError reports that the input is not a valid JPEG.
type FormatError string
func (e FormatError) Error() string { return "invalid JPEG format: " + string(e) }
// An UnsupportedError reports that the input uses a valid but unimplemented JPEG feature.
type UnsupportedError string
func (e UnsupportedError) Error() string { return "unsupported JPEG feature: " + string(e) }
var errUnsupportedSubsamplingRatio = UnsupportedError("luma/chroma subsampling ratio")
// Component specification, specified in section B.2.2.
type component struct {
h int // Horizontal sampling factor.
v int // Vertical sampling factor.
c uint8 // Component identifier.
tq uint8 // Quantization table destination selector.
}
const (
dcTable = 0
acTable = 1
maxTc = 1
maxTh = 3
maxTq = 3
maxComponents = 4
)
const (
sof0Marker = 0xc0 // Start Of Frame (Baseline Sequential).
sof1Marker = 0xc1 // Start Of Frame (Extended Sequential).
sof2Marker = 0xc2 // Start Of Frame (Progressive).
dhtMarker = 0xc4 // Define Huffman Table.
rst0Marker = 0xd0 // ReSTart (0).
rst7Marker = 0xd7 // ReSTart (7).
soiMarker = 0xd8 // Start Of Image.
eoiMarker = 0xd9 // End Of Image.
sosMarker = 0xda // Start Of Scan.
dqtMarker = 0xdb // Define Quantization Table.
driMarker = 0xdd // Define Restart Interval.
comMarker = 0xfe // COMment.
// "APPlication specific" markers aren't part of the JPEG spec per se,
// but in practice, their use is described at
// https://www.sno.phy.queensu.ca/~phil/exiftool/TagNames/JPEG.html
app0Marker = 0xe0
app14Marker = 0xee
app15Marker = 0xef
)
// See https://www.sno.phy.queensu.ca/~phil/exiftool/TagNames/JPEG.html#Adobe
const (
adobeTransformUnknown = 0
adobeTransformYCbCr = 1
adobeTransformYCbCrK = 2
)
// unzig maps from the zig-zag ordering to the natural ordering. For example,
// unzig[3] is the column and row of the fourth element in zig-zag order. The
// value is 16, which means first column (16%8 == 0) and third row (16/8 == 2).
var unzig = [blockSize]int{
0, 1, 8, 16, 9, 2, 3, 10,
17, 24, 32, 25, 18, 11, 4, 5,
12, 19, 26, 33, 40, 48, 41, 34,
27, 20, 13, 6, 7, 14, 21, 28,
35, 42, 49, 56, 57, 50, 43, 36,
29, 22, 15, 23, 30, 37, 44, 51,
58, 59, 52, 45, 38, 31, 39, 46,
53, 60, 61, 54, 47, 55, 62, 63,
}
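// Illustrative sketch (not part of the original file) of how unzig is
// typically used when decoding a scan: coefficients arrive in zig-zag order
// and are stored into a block in natural order. readCoefficient is a
// hypothetical placeholder for the actual entropy decoding.
//
//	for zig := 0; zig < blockSize; zig++ {
//		b[unzig[zig]] = readCoefficient()
//	}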
// Deprecated: Reader is not used by the [image/jpeg] package and should
// not be used by others. It is kept for compatibility.
type Reader interface {
io.ByteReader
io.Reader
}
// bits holds the unprocessed bits that have been taken from the byte-stream.
// The n least significant bits of a form the unread bits, to be read in MSB to
// LSB order.
type bits struct {
a uint32 // accumulator.
m uint32 // mask. m==1<<(n-1) when n>0, with m==0 when n==0.
n int32 // the number of unread bits in a.
}
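// Worked example (editorial, not from the original file): starting from an
// empty state, after ensureNBits loads the single byte 0xa7 the fields are
// a = 0xa7, n = 8 and m = 1<<7. decodeBit then tests a&m (the most
// significant unread bit, here 1), decrements n and halves m.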
type decoder struct {
r io.Reader
bits bits
// bytes is a byte buffer, similar to a bufio.Reader, except that it
// has to be able to unread more than 1 byte, due to byte stuffing.
// Byte stuffing is specified in section F.1.2.3.
bytes struct {
// buf[i:j] are the buffered bytes read from the underlying
// io.Reader that haven't yet been passed further on.
buf [4096]byte
i, j int
// nUnreadable is the number of bytes to back up i after
// overshooting. It can be 0, 1 or 2.
nUnreadable int
}
width, height int
img1 *image.Gray
img3 *image.YCbCr
blackPix []byte
blackStride int
ri int // Restart Interval.
nComp int
// As per section 4.5, there are four modes of operation (selected by the
// SOF? markers): sequential DCT, progressive DCT, lossless and
// hierarchical, although this implementation does not support the latter
// two non-DCT modes. Sequential DCT is further split into baseline and
// extended, as per section 4.11.
baseline bool
progressive bool
jfif bool
adobeTransformValid bool
adobeTransform uint8
eobRun uint16 // End-of-Band run, specified in section G.1.2.2.
comp [maxComponents]component
progCoeffs [maxComponents][]block // Saved state between progressive-mode scans.
huff [maxTc + 1][maxTh + 1]huffman
quant [maxTq + 1]block // Quantization tables, in zig-zag order.
tmp [2 * blockSize]byte
}
// fill fills up the d.bytes.buf buffer from the underlying io.Reader. It
// should only be called when there are no unread bytes in d.bytes.
func (d *decoder) fill() error {
if d.bytes.i != d.bytes.j {
panic("jpeg: fill called when unread bytes exist")
}
// Move the last 2 bytes to the start of the buffer, in case we need
// to call unreadByteStuffedByte.
if d.bytes.j > 2 {
d.bytes.buf[0] = d.bytes.buf[d.bytes.j-2]
d.bytes.buf[1] = d.bytes.buf[d.bytes.j-1]
d.bytes.i, d.bytes.j = 2, 2
}
// Fill in the rest of the buffer.
n, err := d.r.Read(d.bytes.buf[d.bytes.j:])
d.bytes.j += n
if n > 0 {
return nil
}
if err == io.EOF {
err = io.ErrUnexpectedEOF
}
return err
}
// unreadByteStuffedByte undoes the most recent readByteStuffedByte call,
// giving a byte of data back from d.bits to d.bytes. The Huffman look-up table
// requires at least 8 bits for look-up, which means that Huffman decoding can
// sometimes overshoot and read one or two too many bytes. Two-byte overshoot
// can happen when expecting to read a 0xff 0x00 byte-stuffed byte.
func (d *decoder) unreadByteStuffedByte() {
d.bytes.i -= d.bytes.nUnreadable
d.bytes.nUnreadable = 0
if d.bits.n >= 8 {
d.bits.a >>= 8
d.bits.n -= 8
d.bits.m >>= 8
}
}
// readByte returns the next byte, whether buffered or not buffered. It does
// not care about byte stuffing.
func (d *decoder) readByte() (x byte, err error) {
for d.bytes.i == d.bytes.j {
if err = d.fill(); err != nil {
return 0, err
}
}
x = d.bytes.buf[d.bytes.i]
d.bytes.i++
d.bytes.nUnreadable = 0
return x, nil
}
// errMissingFF00 means that readByteStuffedByte encountered an 0xff byte (a
// marker byte) that wasn't the expected byte-stuffed sequence 0xff, 0x00.
var errMissingFF00 = FormatError("missing 0xff00 sequence")
// readByteStuffedByte is like readByte but is for byte-stuffed Huffman data.
func (d *decoder) readByteStuffedByte() (x byte, err error) {
// Take the fast path if d.bytes.buf contains at least two bytes.
if d.bytes.i+2 <= d.bytes.j {
x = d.bytes.buf[d.bytes.i]
d.bytes.i++
d.bytes.nUnreadable = 1
if x != 0xff {
return x, err
}
if d.bytes.buf[d.bytes.i] != 0x00 {
return 0, errMissingFF00
}
d.bytes.i++
d.bytes.nUnreadable = 2
return 0xff, nil
}
d.bytes.nUnreadable = 0
x, err = d.readByte()
if err != nil {
return 0, err
}
d.bytes.nUnreadable = 1
if x != 0xff {
return x, nil
}
x, err = d.readByte()
if err != nil {
return 0, err
}
d.bytes.nUnreadable = 2
if x != 0x00 {
return 0, errMissingFF00
}
return 0xff, nil
}
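// Worked example (editorial, not from the original file): for entropy-coded
// input bytes 0xff 0x00 0x12, the first call returns 0xff, consuming the
// stuffed 0x00 as well, and the next call returns 0x12; a 0xff followed by
// anything other than 0x00 yields errMissingFF00.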
// readFull reads exactly len(p) bytes into p. It does not care about byte
// stuffing.
func (d *decoder) readFull(p []byte) error {
// Unread the overshot bytes, if any.
if d.bytes.nUnreadable != 0 {
if d.bits.n >= 8 {
d.unreadByteStuffedByte()
}
d.bytes.nUnreadable = 0
}
for {
n := copy(p, d.bytes.buf[d.bytes.i:d.bytes.j])
p = p[n:]
d.bytes.i += n
if len(p) == 0 {
break
}
if err := d.fill(); err != nil {
return err
}
}
return nil
}
// ignore ignores the next n bytes.
func (d *decoder) ignore(n int) error {
// Unread the overshot bytes, if any.
if d.bytes.nUnreadable != 0 {
if d.bits.n >= 8 {
d.unreadByteStuffedByte()
}
d.bytes.nUnreadable = 0
}
for {
m := d.bytes.j - d.bytes.i
if m > n {
m = n
}
d.bytes.i += m
n -= m
if n == 0 {
break
}
if err := d.fill(); err != nil {
return err
}
}
return nil
}
// Specified in section B.2.2.
func (d *decoder) processSOF(n int) error {
if d.nComp != 0 {
return FormatError("multiple SOF markers")
}
switch n {
case 6 + 3*1: // Grayscale image.
d.nComp = 1
case 6 + 3*3: // YCbCr or RGB image.
d.nComp = 3
case 6 + 3*4: // YCbCrK or CMYK image.
d.nComp = 4
default:
return UnsupportedError("number of components")
}
if err := d.readFull(d.tmp[:n]); err != nil {
return err
}
// We only support 8-bit precision.
if d.tmp[0] != 8 {
return UnsupportedError("precision")
}
d.height = int(d.tmp[1])<<8 + int(d.tmp[2])
d.width = int(d.tmp[3])<<8 + int(d.tmp[4])
if int(d.tmp[5]) != d.nComp {
return FormatError("SOF has wrong length")
}
for i := 0; i < d.nComp; i++ {
d.comp[i].c = d.tmp[6+3*i]
// Section B.2.2 states that "the value of C_i shall be different from
// the values of C_1 through C_(i-1)".
for j := 0; j < i; j++ {
if d.comp[i].c == d.comp[j].c {
return FormatError("repeated component identifier")
}
}
d.comp[i].tq = d.tmp[8+3*i]
if d.comp[i].tq > maxTq {
return FormatError("bad Tq value")
}
hv := d.tmp[7+3*i]
h, v := int(hv>>4), int(hv&0x0f)
if h < 1 || 4 < h || v < 1 || 4 < v {
return FormatError("luma/chroma subsampling ratio")
}
if h == 3 || v == 3 {
return errUnsupportedSubsamplingRatio
}
switch d.nComp {
case 1:
// If a JPEG image has only one component, section A.2 says "this data
// is non-interleaved by definition" and section A.2.2 says "[in this
// case...] the order of data units within a scan shall be left-to-right
// and top-to-bottom... regardless of the values of H_1 and V_1". Section
// 4.8.2 also says "[for non-interleaved data], the MCU is defined to be
// one data unit". Similarly, section A.1.1 explains that it is the ratio
// of H_i to max_j(H_j) that matters, and similarly for V. For grayscale
// images, H_1 is the maximum H_j for all components j, so that ratio is
// always 1. The component's (h, v) is effectively always (1, 1): even if
// the nominal (h, v) is (2, 1), a 20x5 image is encoded in three 8x8
// MCUs, not two 16x8 MCUs.
h, v = 1, 1
case 3:
// For YCbCr images, we only support 4:4:4, 4:4:0, 4:2:2, 4:2:0,
// 4:1:1 or 4:1:0 chroma subsampling ratios. This implies that the
// (h, v) values for the Y component are either (1, 1), (1, 2),
// (2, 1), (2, 2), (4, 1) or (4, 2), and the Y component's values
// must be a multiple of the Cb and Cr component's values. We also
// assume that the two chroma components have the same subsampling
// ratio.
switch i {
case 0: // Y.
// We have already verified, above, that h and v are both
// either 1, 2 or 4, so invalid (h, v) combinations are those
// with v == 4.
if v == 4 {
return errUnsupportedSubsamplingRatio
}
case 1: // Cb.
if d.comp[0].h%h != 0 || d.comp[0].v%v != 0 {
return errUnsupportedSubsamplingRatio
}
case 2: // Cr.
if d.comp[1].h != h || d.comp[1].v != v {
return errUnsupportedSubsamplingRatio
}
}
case 4:
// For 4-component images (either CMYK or YCbCrK), we only support two
// hv vectors: [0x11 0x11 0x11 0x11] and [0x22 0x11 0x11 0x22].
// Theoretically, 4-component JPEG images could mix and match hv values
// but in practice, those two combinations are the only ones in use,
// and it simplifies the applyBlack code below if we can assume that:
// - for CMYK, the C and K channels have full samples, and if the M
// and Y channels subsample, they subsample both horizontally and
// vertically.
// - for YCbCrK, the Y and K channels have full samples.
switch i {
case 0:
if hv != 0x11 && hv != 0x22 {
return errUnsupportedSubsamplingRatio
}
case 1, 2:
if hv != 0x11 {
return errUnsupportedSubsamplingRatio
}
case 3:
if d.comp[0].h != h || d.comp[0].v != v {
return errUnsupportedSubsamplingRatio
}
}
}
d.comp[i].h = h
d.comp[i].v = v
}
return nil
}
// Specified in section B.2.4.1.
func (d *decoder) processDQT(n int) error {
loop:
for n > 0 {
n--
x, err := d.readByte()
if err != nil {
return err
}
tq := x & 0x0f
if tq > maxTq {
return FormatError("bad Tq value")
}
switch x >> 4 {
default:
return FormatError("bad Pq value")
case 0:
if n < blockSize {
break loop
}
n -= blockSize
if err := d.readFull(d.tmp[:blockSize]); err != nil {
return err
}
for i := range d.quant[tq] {
d.quant[tq][i] = int32(d.tmp[i])
}
case 1:
if n < 2*blockSize {
break loop
}
n -= 2 * blockSize
if err := d.readFull(d.tmp[:2*blockSize]); err != nil {
return err
}
for i := range d.quant[tq] {
d.quant[tq][i] = int32(d.tmp[2*i])<<8 | int32(d.tmp[2*i+1])
}
}
}
if n != 0 {
return FormatError("DQT has wrong length")
}
return nil
}
// Specified in section B.2.4.4.
func (d *decoder) processDRI(n int) error {
if n != 2 {
return FormatError("DRI has wrong length")
}
if err := d.readFull(d.tmp[:2]); err != nil {
return err
}
d.ri = int(d.tmp[0])<<8 + int(d.tmp[1])
return nil
}
func (d *decoder) processApp0Marker(n int) error {
if n < 5 {
return d.ignore(n)
}
if err := d.readFull(d.tmp[:5]); err != nil {
return err
}
n -= 5
d.jfif = d.tmp[0] == 'J' && d.tmp[1] == 'F' && d.tmp[2] == 'I' && d.tmp[3] == 'F' && d.tmp[4] == '\x00'
if n > 0 {
return d.ignore(n)
}
return nil
}
func (d *decoder) processApp14Marker(n int) error {
if n < 12 {
return d.ignore(n)
}
if err := d.readFull(d.tmp[:12]); err != nil {
return err
}
n -= 12
if d.tmp[0] == 'A' && d.tmp[1] == 'd' && d.tmp[2] == 'o' && d.tmp[3] == 'b' && d.tmp[4] == 'e' {
d.adobeTransformValid = true
d.adobeTransform = d.tmp[11]
}
if n > 0 {
return d.ignore(n)
}
return nil
}
// decode reads a JPEG image from r and returns it as an image.Image.
func (d *decoder) decode(r io.Reader, configOnly bool) (image.Image, error) {
d.r = r
// Check for the Start Of Image marker.
if err := d.readFull(d.tmp[:2]); err != nil {
return nil, err
}
if d.tmp[0] != 0xff || d.tmp[1] != soiMarker {
return nil, FormatError("missing SOI marker")
}
// Process the remaining segments until the End Of Image marker.
for {
err := d.readFull(d.tmp[:2])
if err != nil {
return nil, err
}
for d.tmp[0] != 0xff {
// Strictly speaking, this is a format error. However, libjpeg is
// liberal in what it accepts. As of version 9, next_marker in
// jdmarker.c treats this as a warning (JWRN_EXTRANEOUS_DATA) and
// continues to decode the stream. Even before next_marker sees
// extraneous data, jpeg_fill_bit_buffer in jdhuff.c reads as many
// bytes as it can, possibly past the end of a scan's data. It
// effectively puts back any markers that it overscanned (e.g. an
// "\xff\xd9" EOI marker), but it does not put back non-marker data,
// and thus it can silently ignore a small number of extraneous
// non-marker bytes before next_marker has a chance to see them (and
// print a warning).
//
// We are therefore also liberal in what we accept. Extraneous data
// is silently ignored.
//
// This is similar to, but not exactly the same as, the restart
// mechanism within a scan (the RST[0-7] markers).
//
// Note that extraneous 0xff bytes in e.g. SOS data are escaped as
// "\xff\x00", and so are detected a little further down below.
d.tmp[0] = d.tmp[1]
d.tmp[1], err = d.readByte()
if err != nil {
return nil, err
}
}
marker := d.tmp[1]
if marker == 0 {
// Treat "\xff\x00" as extraneous data.
continue
}
for marker == 0xff {
// Section B.1.1.2 says, "Any marker may optionally be preceded by any
// number of fill bytes, which are bytes assigned code X'FF'".
marker, err = d.readByte()
if err != nil {
return nil, err
}
}
if marker == eoiMarker { // End Of Image.
break
}
if rst0Marker <= marker && marker <= rst7Marker {
// Figures B.2 and B.16 of the specification suggest that restart markers should
// only occur between Entropy Coded Segments and not after the final ECS.
// However, some encoders may generate incorrect JPEGs with a final restart
// marker. That restart marker will be seen here instead of inside the processSOS
// method, and is ignored as a harmless error. Restart markers have no extra data,
// so we check for this before we read the 16-bit length of the segment.
continue
}
// Read the 16-bit length of the segment. The value includes the 2 bytes for the
// length itself, so we subtract 2 to get the number of remaining bytes.
if err = d.readFull(d.tmp[:2]); err != nil {
return nil, err
}
n := int(d.tmp[0])<<8 + int(d.tmp[1]) - 2
if n < 0 {
return nil, FormatError("short segment length")
}
switch marker {
case sof0Marker, sof1Marker, sof2Marker:
d.baseline = marker == sof0Marker
d.progressive = marker == sof2Marker
err = d.processSOF(n)
if configOnly && d.jfif {
return nil, err
}
case dhtMarker:
if configOnly {
err = d.ignore(n)
} else {
err = d.processDHT(n)
}
case dqtMarker:
if configOnly {
err = d.ignore(n)
} else {
err = d.processDQT(n)
}
case sosMarker:
if configOnly {
return nil, nil
}
err = d.processSOS(n)
case driMarker:
if configOnly {
err = d.ignore(n)
} else {
err = d.processDRI(n)
}
case app0Marker:
err = d.processApp0Marker(n)
case app14Marker:
err = d.processApp14Marker(n)
default:
if app0Marker <= marker && marker <= app15Marker || marker == comMarker {
err = d.ignore(n)
} else if marker < 0xc0 { // See Table B.1 "Marker code assignments".
err = FormatError("unknown marker")
} else {
err = UnsupportedError("unknown marker")
}
}
if err != nil {
return nil, err
}
}
if d.progressive {
if err := d.reconstructProgressiveImage(); err != nil {
return nil, err
}
}
if d.img1 != nil {
return d.img1, nil
}
if d.img3 != nil {
if d.blackPix != nil {
return d.applyBlack()
} else if d.isRGB() {
return d.convertToRGB()
}
return d.img3, nil
}
return nil, FormatError("missing SOS marker")
}
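// Editorial sketch (not part of the original source): the segment framing that
// the decode loop above relies on. Apart from SOI, EOI and RST0-7, every
// marker is followed by a two-byte big-endian length that counts the length
// field itself, so the payload still to be read is that value minus 2.
// segmentPayloadLen is a hypothetical helper name.
func segmentPayloadLen(hi, lo byte) (int, error) {
n := int(hi)<<8 + int(lo) - 2
if n < 0 {
return 0, FormatError("short segment length")
}
return n, nil
}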
// applyBlack combines d.img3 and d.blackPix into a CMYK image. The formula
// used depends on whether the JPEG image is stored as CMYK or YCbCrK,
// indicated by the APP14 (Adobe) metadata.
//
// Adobe CMYK JPEG images are inverted, where 255 means no ink instead of full
// ink, so we apply "v = 255 - v" at various points. Note that a double
// inversion is a no-op, so inversions might be implicit in the code below.
func (d *decoder) applyBlack() (image.Image, error) {
if !d.adobeTransformValid {
return nil, UnsupportedError("unknown color model: 4-component JPEG doesn't have Adobe APP14 metadata")
}
// If the 4-component JPEG image isn't explicitly marked as "Unknown (RGB
// or CMYK)" as per
// https://www.sno.phy.queensu.ca/~phil/exiftool/TagNames/JPEG.html#Adobe
// we assume that it is YCbCrK. This matches libjpeg's jdapimin.c.
if d.adobeTransform != adobeTransformUnknown {
// Convert the YCbCr part of the YCbCrK to RGB, invert the RGB to get
// CMY, and patch in the original K. The RGB to CMY inversion cancels
// out the 'Adobe inversion' described in the applyBlack doc comment
// above, so in practice, only the fourth channel (black) is inverted.
bounds := d.img3.Bounds()
img := image.NewRGBA(bounds)
imageutil.DrawYCbCr(img, bounds, d.img3, bounds.Min)
for iBase, y := 0, bounds.Min.Y; y < bounds.Max.Y; iBase, y = iBase+img.Stride, y+1 {
for i, x := iBase+3, bounds.Min.X; x < bounds.Max.X; i, x = i+4, x+1 {
img.Pix[i] = 255 - d.blackPix[(y-bounds.Min.Y)*d.blackStride+(x-bounds.Min.X)]
}
}
return &image.CMYK{
Pix: img.Pix,
Stride: img.Stride,
Rect: img.Rect,
}, nil
}
// The first three channels (cyan, magenta, yellow) of the CMYK
// were decoded into d.img3, but each channel was decoded into a separate
// []byte slice, and some channels may be subsampled. We interleave the
// separate channels into an image.CMYK's single []byte slice containing 4
// contiguous bytes per pixel.
bounds := d.img3.Bounds()
img := image.NewCMYK(bounds)
translations := [4]struct {
src []byte
stride int
}{
{d.img3.Y, d.img3.YStride},
{d.img3.Cb, d.img3.CStride},
{d.img3.Cr, d.img3.CStride},
{d.blackPix, d.blackStride},
}
for t, translation := range translations {
subsample := d.comp[t].h != d.comp[0].h || d.comp[t].v != d.comp[0].v
for iBase, y := 0, bounds.Min.Y; y < bounds.Max.Y; iBase, y = iBase+img.Stride, y+1 {
sy := y - bounds.Min.Y
if subsample {
sy /= 2
}
for i, x := iBase+t, bounds.Min.X; x < bounds.Max.X; i, x = i+4, x+1 {
sx := x - bounds.Min.X
if subsample {
sx /= 2
}
img.Pix[i] = 255 - translation.src[sy*translation.stride+sx]
}
}
}
return img, nil
}
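// Editorial sketch (not part of the original source): the per-pixel arithmetic
// behind the YCbCrK branch above. Converting YCbCr to RGB and then inverting
// yields CMY, which cancels the Adobe inversion for those three channels, so
// only the K sample needs an explicit "255 - v". ycbcrkToCMYK is a
// hypothetical helper name.
func ycbcrkToCMYK(y, cb, cr, k uint8) (uint8, uint8, uint8, uint8) {
r, g, b := color.YCbCrToRGB(y, cb, cr)
return 255 - r, 255 - g, 255 - b, 255 - k
}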
func (d *decoder) isRGB() bool {
if d.jfif {
return false
}
if d.adobeTransformValid && d.adobeTransform == adobeTransformUnknown {
// https://www.sno.phy.queensu.ca/~phil/exiftool/TagNames/JPEG.html#Adobe
// says that 0 means Unknown (and in practice RGB) and 1 means YCbCr.
return true
}
return d.comp[0].c == 'R' && d.comp[1].c == 'G' && d.comp[2].c == 'B'
}
func (d *decoder) convertToRGB() (image.Image, error) {
cScale := d.comp[0].h / d.comp[1].h
bounds := d.img3.Bounds()
img := image.NewRGBA(bounds)
for y := bounds.Min.Y; y < bounds.Max.Y; y++ {
po := img.PixOffset(bounds.Min.X, y)
yo := d.img3.YOffset(bounds.Min.X, y)
co := d.img3.COffset(bounds.Min.X, y)
for i, iMax := 0, bounds.Max.X-bounds.Min.X; i < iMax; i++ {
img.Pix[po+4*i+0] = d.img3.Y[yo+i]
img.Pix[po+4*i+1] = d.img3.Cb[co+i/cScale]
img.Pix[po+4*i+2] = d.img3.Cr[co+i/cScale]
img.Pix[po+4*i+3] = 255
}
}
return img, nil
}
// Decode reads a JPEG image from r and returns it as an [image.Image].
func Decode(r io.Reader) (image.Image, error) {
var d decoder
return d.decode(r, false)
}
// DecodeConfig returns the color model and dimensions of a JPEG image without
// decoding the entire image.
func DecodeConfig(r io.Reader) (image.Config, error) {
var d decoder
if _, err := d.decode(r, true); err != nil {
return image.Config{}, err
}
switch d.nComp {
case 1:
return image.Config{
ColorModel: color.GrayModel,
Width: d.width,
Height: d.height,
}, nil
case 3:
cm := color.YCbCrModel
if d.isRGB() {
cm = color.RGBAModel
}
return image.Config{
ColorModel: cm,
Width: d.width,
Height: d.height,
}, nil
case 4:
return image.Config{
ColorModel: color.CMYKModel,
Width: d.width,
Height: d.height,
}, nil
}
return image.Config{}, FormatError("missing SOF marker")
}
func init() {
image.RegisterFormat("jpeg", "\xff\xd8", Decode, DecodeConfig)
}
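// Editorial sketch (not part of the original source): typical use of the
// exported entry points above. decodeFile is a hypothetical helper and assumes
// an "os" import, which this file does not otherwise need. Callers outside
// this package usually reach the same code through image.Decode, which
// dispatches on the "\xff\xd8" magic registered in init.
func decodeFile(path string) (image.Image, error) {
f, err := os.Open(path)
if err != nil {
return nil, err
}
defer f.Close()
return Decode(f)
}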
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package jpeg
import (
"image"
)
// makeImg allocates and initializes the destination image.
func (d *decoder) makeImg(mxx, myy int) {
if d.nComp == 1 {
m := image.NewGray(image.Rect(0, 0, 8*mxx, 8*myy))
d.img1 = m.SubImage(image.Rect(0, 0, d.width, d.height)).(*image.Gray)
return
}
h0 := d.comp[0].h
v0 := d.comp[0].v
hRatio := h0 / d.comp[1].h
vRatio := v0 / d.comp[1].v
var subsampleRatio image.YCbCrSubsampleRatio
switch hRatio<<4 | vRatio {
case 0x11:
subsampleRatio = image.YCbCrSubsampleRatio444
case 0x12:
subsampleRatio = image.YCbCrSubsampleRatio440
case 0x21:
subsampleRatio = image.YCbCrSubsampleRatio422
case 0x22:
subsampleRatio = image.YCbCrSubsampleRatio420
case 0x41:
subsampleRatio = image.YCbCrSubsampleRatio411
case 0x42:
subsampleRatio = image.YCbCrSubsampleRatio410
default:
panic("unreachable")
}
m := image.NewYCbCr(image.Rect(0, 0, 8*h0*mxx, 8*v0*myy), subsampleRatio)
d.img3 = m.SubImage(image.Rect(0, 0, d.width, d.height)).(*image.YCbCr)
if d.nComp == 4 {
h3, v3 := d.comp[3].h, d.comp[3].v
d.blackPix = make([]byte, 8*h3*mxx*8*v3*myy)
d.blackStride = 8 * h3 * mxx
}
}
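// Editorial sketch (not part of the original source): how the mxx and myy
// arguments to makeImg relate to the pixel dimensions. The image is rounded up
// to a whole number of MCUs, each covering 8*h0 x 8*v0 pixels, and then
// clipped back to the declared width and height via SubImage above. For a
// 600x400 4:2:0 image (h0 = v0 = 2), mxx is 38 and myy is 25, so the backing
// YCbCr buffer is 608x400 pixels. mcuGrid is a hypothetical helper name.
func mcuGrid(width, height, h0, v0 int) (mxx, myy int) {
mxx = (width + 8*h0 - 1) / (8 * h0)
myy = (height + 8*v0 - 1) / (8 * v0)
return mxx, myy
}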
// Specified in section B.2.3.
func (d *decoder) processSOS(n int) error {
if d.nComp == 0 {
return FormatError("missing SOF marker")
}
if n < 6 || 4+2*d.nComp < n || n%2 != 0 {
return FormatError("SOS has wrong length")
}
if err := d.readFull(d.tmp[:n]); err != nil {
return err
}
nComp := int(d.tmp[0])
if n != 4+2*nComp {
return FormatError("SOS length inconsistent with number of components")
}
var scan [maxComponents]struct {
compIndex uint8
td uint8 // DC table selector.
ta uint8 // AC table selector.
}
totalHV := 0
for i := 0; i < nComp; i++ {
cs := d.tmp[1+2*i] // Component selector.
compIndex := -1
for j, comp := range d.comp[:d.nComp] {
if cs == comp.c {
compIndex = j
}
}
if compIndex < 0 {
return FormatError("unknown component selector")
}
scan[i].compIndex = uint8(compIndex)
// Section B.2.3 states that "the value of Cs_j shall be different from
// the values of Cs_1 through Cs_(j-1)". Since we have previously
// verified that a frame's component identifiers (C_i values in section
// B.2.2) are unique, it suffices to check that the implicit indexes
// into d.comp are unique.
for j := 0; j < i; j++ {
if scan[i].compIndex == scan[j].compIndex {
return FormatError("repeated component selector")
}
}
totalHV += d.comp[compIndex].h * d.comp[compIndex].v
// The baseline t <= 1 restriction is specified in table B.3.
scan[i].td = d.tmp[2+2*i] >> 4
if t := scan[i].td; t > maxTh || (d.baseline && t > 1) {
return FormatError("bad Td value")
}
scan[i].ta = d.tmp[2+2*i] & 0x0f
if t := scan[i].ta; t > maxTh || (d.baseline && t > 1) {
return FormatError("bad Ta value")
}
}
// Section B.2.3 states that if there is more than one component then the
// total H*V values in a scan must be <= 10.
if d.nComp > 1 && totalHV > 10 {
return FormatError("total sampling factors too large")
}
// zigStart and zigEnd are the spectral selection bounds.
// ah and al are the successive approximation high and low values.
// The spec calls these values Ss, Se, Ah and Al.
//
// For progressive JPEGs, these are the two more-or-less independent
// aspects of progression. Spectral selection progression is when not
// all of a block's 64 DCT coefficients are transmitted in one pass.
// For example, three passes could transmit coefficient 0 (the DC
// component), coefficients 1-5, and coefficients 6-63, in zig-zag
// order. Successive approximation is when not all of the bits of a
// band of coefficients are transmitted in one pass. For example,
// three passes could transmit the 6 most significant bits, followed
// by the second-least significant bit, followed by the least
// significant bit.
//
// For sequential JPEGs, these parameters are hard-coded to 0/63/0/0, as
// per table B.3.
zigStart, zigEnd, ah, al := int32(0), int32(blockSize-1), uint32(0), uint32(0)
if d.progressive {
zigStart = int32(d.tmp[1+2*nComp])
zigEnd = int32(d.tmp[2+2*nComp])
ah = uint32(d.tmp[3+2*nComp] >> 4)
al = uint32(d.tmp[3+2*nComp] & 0x0f)
if (zigStart == 0 && zigEnd != 0) || zigStart > zigEnd || blockSize <= zigEnd {
return FormatError("bad spectral selection bounds")
}
if zigStart != 0 && nComp != 1 {
return FormatError("progressive AC coefficients for more than one component")
}
if ah != 0 && ah != al+1 {
return FormatError("bad successive approximation values")
}
}
// mxx and myy are the number of MCUs (Minimum Coded Units) in the image.
h0, v0 := d.comp[0].h, d.comp[0].v // The h and v values from the Y components.
mxx := (d.width + 8*h0 - 1) / (8 * h0)
myy := (d.height + 8*v0 - 1) / (8 * v0)
if d.img1 == nil && d.img3 == nil {
d.makeImg(mxx, myy)
}
if d.progressive {
for i := 0; i < nComp; i++ {
compIndex := scan[i].compIndex
if d.progCoeffs[compIndex] == nil {
d.progCoeffs[compIndex] = make([]block, mxx*myy*d.comp[compIndex].h*d.comp[compIndex].v)
}
}
}
d.bits = bits{}
mcu, expectedRST := 0, uint8(rst0Marker)
var (
// b is the decoded coefficients, in natural (not zig-zag) order.
b block
dc [maxComponents]int32
// bx and by are the location of the current block, in units of 8x8
// blocks: the third block in the first row has (bx, by) = (2, 0).
bx, by int
blockCount int
)
for my := 0; my < myy; my++ {
for mx := 0; mx < mxx; mx++ {
for i := 0; i < nComp; i++ {
compIndex := scan[i].compIndex
hi := d.comp[compIndex].h
vi := d.comp[compIndex].v
for j := 0; j < hi*vi; j++ {
// The blocks are traversed one MCU at a time. For 4:2:0 chroma
// subsampling, there are four Y 8x8 blocks in every 16x16 MCU.
//
// For a sequential 32x16 pixel image, the Y blocks visiting order is:
// 0 1 4 5
// 2 3 6 7
//
// For progressive images, the interleaved scans (those with nComp > 1)
// are traversed as above, but non-interleaved scans are traversed left
// to right, top to bottom:
// 0 1 2 3
// 4 5 6 7
// Only DC scans (zigStart == 0) can be interleaved. AC scans must have
// only one component.
//
// To further complicate matters, for non-interleaved scans, there is no
// data for any blocks that are inside the image at the MCU level but
// outside the image at the pixel level. For example, a 24x16 pixel 4:2:0
// progressive image consists of two 16x16 MCUs. The interleaved scans
// will process 8 Y blocks:
// 0 1 4 5
// 2 3 6 7
// The non-interleaved scans will process only 6 Y blocks:
// 0 1 2
// 3 4 5
if nComp != 1 {
bx = hi*mx + j%hi
by = vi*my + j/hi
} else {
q := mxx * hi
bx = blockCount % q
by = blockCount / q
blockCount++
if bx*8 >= d.width || by*8 >= d.height {
continue
}
}
// Load the previous partially decoded coefficients, if applicable.
if d.progressive {
b = d.progCoeffs[compIndex][by*mxx*hi+bx]
} else {
b = block{}
}
if ah != 0 {
if err := d.refine(&b, &d.huff[acTable][scan[i].ta], zigStart, zigEnd, 1<<al); err != nil {
return err
}
} else {
zig := zigStart
if zig == 0 {
zig++
// Decode the DC coefficient, as specified in section F.2.2.1.
value, err := d.decodeHuffman(&d.huff[dcTable][scan[i].td])
if err != nil {
return err
}
if value > 16 {
return UnsupportedError("excessive DC component")
}
dcDelta, err := d.receiveExtend(value)
if err != nil {
return err
}
dc[compIndex] += dcDelta
b[0] = dc[compIndex] << al
}
if zig <= zigEnd && d.eobRun > 0 {
d.eobRun--
} else {
// Decode the AC coefficients, as specified in section F.2.2.2.
huff := &d.huff[acTable][scan[i].ta]
for ; zig <= zigEnd; zig++ {
value, err := d.decodeHuffman(huff)
if err != nil {
return err
}
val0 := value >> 4
val1 := value & 0x0f
if val1 != 0 {
zig += int32(val0)
if zig > zigEnd {
break
}
ac, err := d.receiveExtend(val1)
if err != nil {
return err
}
b[unzig[zig]] = ac << al
} else {
if val0 != 0x0f {
d.eobRun = uint16(1 << val0)
if val0 != 0 {
bits, err := d.decodeBits(int32(val0))
if err != nil {
return err
}
d.eobRun |= uint16(bits)
}
d.eobRun--
break
}
zig += 0x0f
}
}
}
}
if d.progressive {
// Save the coefficients.
d.progCoeffs[compIndex][by*mxx*hi+bx] = b
// At this point, we could call reconstructBlock to dequantize and perform the
// inverse DCT, to save early stages of a progressive image to the *image.YCbCr
// buffers (the whole point of progressive encoding), but in Go, the jpeg.Decode
// function does not return until the entire image is decoded, so we "continue"
// here to avoid wasted computation. Instead, reconstructBlock is called on each
// accumulated block by the reconstructProgressiveImage method after all of the
// SOS markers are processed.
continue
}
if err := d.reconstructBlock(&b, bx, by, int(compIndex)); err != nil {
return err
}
} // for j
} // for i
mcu++
if d.ri > 0 && mcu%d.ri == 0 && mcu < mxx*myy {
// For well-formed input, the RST[0-7] restart marker follows
// immediately. For corrupt input, call findRST to try to
// resynchronize.
if err := d.readFull(d.tmp[:2]); err != nil {
return err
} else if d.tmp[0] != 0xff || d.tmp[1] != expectedRST {
if err := d.findRST(expectedRST); err != nil {
return err
}
}
expectedRST++
if expectedRST == rst7Marker+1 {
expectedRST = rst0Marker
}
// Reset the Huffman decoder.
d.bits = bits{}
// Reset the DC components, as per section F.2.1.3.1.
dc = [maxComponents]int32{}
// Reset the progressive decoder state, as per section G.1.2.2.
d.eobRun = 0
}
} // for mx
} // for my
return nil
}
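// Editorial sketch (not part of the original source): the two block-addressing
// schemes used in the processSOS loop above, pulled out for clarity. For
// interleaved scans a component's blocks are visited MCU by MCU; for
// non-interleaved scans they are visited left to right, top to bottom over the
// whole component. blockCoords is a hypothetical helper name whose arguments
// mirror the local variables of processSOS.
func blockCoords(interleaved bool, mx, my, j, hi, vi, mxx, blockCount int) (bx, by int) {
if interleaved {
return hi*mx + j%hi, vi*my + j/hi
}
q := mxx * hi
return blockCount % q, blockCount / q
}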
// refine decodes a successive approximation refinement block, as specified in
// section G.1.2.
func (d *decoder) refine(b *block, h *huffman, zigStart, zigEnd, delta int32) error {
// Refining a DC component is trivial.
if zigStart == 0 {
if zigEnd != 0 {
panic("unreachable")
}
bit, err := d.decodeBit()
if err != nil {
return err
}
if bit {
b[0] |= delta
}
return nil
}
// Refining AC components is more complicated; see sections G.1.2.2 and G.1.2.3.
zig := zigStart
if d.eobRun == 0 {
loop:
for ; zig <= zigEnd; zig++ {
z := int32(0)
value, err := d.decodeHuffman(h)
if err != nil {
return err
}
val0 := value >> 4
val1 := value & 0x0f
switch val1 {
case 0:
if val0 != 0x0f {
d.eobRun = uint16(1 << val0)
if val0 != 0 {
bits, err := d.decodeBits(int32(val0))
if err != nil {
return err
}
d.eobRun |= uint16(bits)
}
break loop
}
case 1:
z = delta
bit, err := d.decodeBit()
if err != nil {
return err
}
if !bit {
z = -z
}
default:
return FormatError("unexpected Huffman code")
}
zig, err = d.refineNonZeroes(b, zig, zigEnd, int32(val0), delta)
if err != nil {
return err
}
if zig > zigEnd {
return FormatError("too many coefficients")
}
if z != 0 {
b[unzig[zig]] = z
}
}
}
if d.eobRun > 0 {
d.eobRun--
if _, err := d.refineNonZeroes(b, zig, zigEnd, -1, delta); err != nil {
return err
}
}
return nil
}
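// Editorial sketch (not part of the original source): how an EOBn symbol from
// the AC Huffman table expands into the end-of-band run length used above. The
// upper nibble val0 selects a band of run lengths and val0 extra bits pick the
// exact value, so val0 == 3 followed by the bits 101 yields 1<<3 | 5 = 13
// blocks that have no further non-zero coefficients in this band.
// eobRunLength is a hypothetical helper name.
func eobRunLength(val0 uint8, extraBits uint16) uint16 {
return uint16(1)<<val0 | extraBits
}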
// refineNonZeroes refines non-zero entries of b in zig-zag order. If nz >= 0,
// the first nz zero entries are skipped over.
func (d *decoder) refineNonZeroes(b *block, zig, zigEnd, nz, delta int32) (int32, error) {
for ; zig <= zigEnd; zig++ {
u := unzig[zig]
if b[u] == 0 {
if nz == 0 {
break
}
nz--
continue
}
bit, err := d.decodeBit()
if err != nil {
return 0, err
}
if !bit {
continue
}
if b[u] >= 0 {
b[u] += delta
} else {
b[u] -= delta
}
}
return zig, nil
}
func (d *decoder) reconstructProgressiveImage() error {
// The h0, mxx, by and bx variables have the same meaning as in the
// processSOS method.
h0 := d.comp[0].h
mxx := (d.width + 8*h0 - 1) / (8 * h0)
for i := 0; i < d.nComp; i++ {
if d.progCoeffs[i] == nil {
continue
}
v := 8 * d.comp[0].v / d.comp[i].v
h := 8 * d.comp[0].h / d.comp[i].h
stride := mxx * d.comp[i].h
for by := 0; by*v < d.height; by++ {
for bx := 0; bx*h < d.width; bx++ {
if err := d.reconstructBlock(&d.progCoeffs[i][by*stride+bx], bx, by, i); err != nil {
return err
}
}
}
}
return nil
}
// reconstructBlock dequantizes, performs the inverse DCT and stores the block
// to the image.
func (d *decoder) reconstructBlock(b *block, bx, by, compIndex int) error {
qt := &d.quant[d.comp[compIndex].tq]
for zig := 0; zig < blockSize; zig++ {
b[unzig[zig]] *= qt[zig]
}
idct(b)
dst, stride := []byte(nil), 0
if d.nComp == 1 {
dst, stride = d.img1.Pix[8*(by*d.img1.Stride+bx):], d.img1.Stride
} else {
switch compIndex {
case 0: