package fuzz import "github.com/andybalholm/cascadia" // Fuzz is the entrypoint used by the go-fuzz framework func Fuzz(data []byte) int { sel, err := cascadia.Compile(string(data)) if err != nil { if sel != nil { panic("sel != nil on error") } return 0 } return 1 }
// Package cascadia is an implementation of CSS selectors. package cascadia import ( "errors" "fmt" "regexp" "strconv" "strings" ) // a parser for CSS selectors type parser struct { s string // the source text i int // the current position // if `false`, parsing a pseudo-element // returns an error. acceptPseudoElements bool } // parseEscape parses a backslash escape. func (p *parser) parseEscape() (result string, err error) { if len(p.s) < p.i+2 || p.s[p.i] != '\\' { return "", errors.New("invalid escape sequence") } start := p.i + 1 c := p.s[start] switch { case c == '\r' || c == '\n' || c == '\f': return "", errors.New("escaped line ending outside string") case hexDigit(c): // unicode escape (hex) var i int for i = start; i < start+6 && i < len(p.s) && hexDigit(p.s[i]); i++ { // empty } v, _ := strconv.ParseUint(p.s[start:i], 16, 64) if len(p.s) > i { switch p.s[i] { case '\r': i++ if len(p.s) > i && p.s[i] == '\n' { i++ } case ' ', '\t', '\n', '\f': i++ } } p.i = i return string(rune(v)), nil } // Return the literal character after the backslash. result = p.s[start : start+1] p.i += 2 return result, nil } // toLowerASCII returns s with all ASCII capital letters lowercased. func toLowerASCII(s string) string { var b []byte for i := 0; i < len(s); i++ { if c := s[i]; 'A' <= c && c <= 'Z' { if b == nil { b = make([]byte, len(s)) copy(b, s) } b[i] = s[i] + ('a' - 'A') } } if b == nil { return s } return string(b) } func hexDigit(c byte) bool { return '0' <= c && c <= '9' || 'a' <= c && c <= 'f' || 'A' <= c && c <= 'F' } // nameStart returns whether c can be the first character of an identifier // (not counting an initial hyphen, or an escape sequence). func nameStart(c byte) bool { return 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z' || c == '_' || c > 127 } // nameChar returns whether c can be a character within an identifier // (not counting an escape sequence). 
func nameChar(c byte) bool { return 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z' || c == '_' || c > 127 || c == '-' || '0' <= c && c <= '9' } // parseIdentifier parses an identifier. func (p *parser) parseIdentifier() (result string, err error) { const prefix = '-' var numPrefix int for len(p.s) > p.i && p.s[p.i] == prefix { p.i++ numPrefix++ } if len(p.s) <= p.i { return "", errors.New("expected identifier, found EOF instead") } if c := p.s[p.i]; !(nameStart(c) || c == '\\') { return "", fmt.Errorf("expected identifier, found %c instead", c) } result, err = p.parseName() if numPrefix > 0 && err == nil { result = strings.Repeat(string(prefix), numPrefix) + result } return } // parseName parses a name (which is like an identifier, but doesn't have // extra restrictions on the first character). func (p *parser) parseName() (result string, err error) { i := p.i loop: for i < len(p.s) { c := p.s[i] switch { case nameChar(c): start := i for i < len(p.s) && nameChar(p.s[i]) { i++ } result += p.s[start:i] case c == '\\': p.i = i val, err := p.parseEscape() if err != nil { return "", err } i = p.i result += val default: break loop } } if result == "" { return "", errors.New("expected name, found EOF instead") } p.i = i return result, nil } // parseString parses a single- or double-quoted string. 
func (p *parser) parseString() (result string, err error) { i := p.i if len(p.s) < i+2 { return "", errors.New("expected string, found EOF instead") } quote := p.s[i] i++ loop: for i < len(p.s) { switch p.s[i] { case '\\': if len(p.s) > i+1 { switch c := p.s[i+1]; c { case '\r': if len(p.s) > i+2 && p.s[i+2] == '\n' { i += 3 continue loop } fallthrough case '\n', '\f': i += 2 continue loop } } p.i = i val, err := p.parseEscape() if err != nil { return "", err } i = p.i result += val case quote: break loop case '\r', '\n', '\f': return "", errors.New("unexpected end of line in string") default: start := i for i < len(p.s) { if c := p.s[i]; c == quote || c == '\\' || c == '\r' || c == '\n' || c == '\f' { break } i++ } result += p.s[start:i] } } if i >= len(p.s) { return "", errors.New("EOF in string") } // Consume the final quote. i++ p.i = i return result, nil } // parseRegex parses a regular expression; the end is defined by encountering an // unmatched closing ')' or ']' which is not consumed func (p *parser) parseRegex() (rx *regexp.Regexp, err error) { i := p.i if len(p.s) < i+2 { return nil, errors.New("expected regular expression, found EOF instead") } // number of open parens or brackets; // when it becomes negative, finished parsing regex open := 0 loop: for i < len(p.s) { switch p.s[i] { case '(', '[': open++ case ')', ']': open-- if open < 0 { break loop } } i++ } if i >= len(p.s) { return nil, errors.New("EOF in regular expression") } rx, err = regexp.Compile(p.s[p.i:i]) p.i = i return rx, err } // skipWhitespace consumes whitespace characters and comments. // It returns true if there was actually anything to skip. 
func (p *parser) skipWhitespace() bool { i := p.i for i < len(p.s) { switch p.s[i] { case ' ', '\t', '\r', '\n', '\f': i++ continue case '/': if strings.HasPrefix(p.s[i:], "/*") { end := strings.Index(p.s[i+len("/*"):], "*/") if end != -1 { i += end + len("/**/") continue } } } break } if i > p.i { p.i = i return true } return false } // consumeParenthesis consumes an opening parenthesis and any following // whitespace. It returns true if there was actually a parenthesis to skip. func (p *parser) consumeParenthesis() bool { if p.i < len(p.s) && p.s[p.i] == '(' { p.i++ p.skipWhitespace() return true } return false } // consumeClosingParenthesis consumes a closing parenthesis and any preceding // whitespace. It returns true if there was actually a parenthesis to skip. func (p *parser) consumeClosingParenthesis() bool { i := p.i p.skipWhitespace() if p.i < len(p.s) && p.s[p.i] == ')' { p.i++ return true } p.i = i return false } // parseTypeSelector parses a type selector (one that matches by tag name). func (p *parser) parseTypeSelector() (result tagSelector, err error) { tag, err := p.parseIdentifier() if err != nil { return } return tagSelector{tag: toLowerASCII(tag)}, nil } // parseIDSelector parses a selector that matches by id attribute. func (p *parser) parseIDSelector() (idSelector, error) { if p.i >= len(p.s) { return idSelector{}, fmt.Errorf("expected id selector (#id), found EOF instead") } if p.s[p.i] != '#' { return idSelector{}, fmt.Errorf("expected id selector (#id), found '%c' instead", p.s[p.i]) } p.i++ id, err := p.parseName() if err != nil { return idSelector{}, err } return idSelector{id: id}, nil } // parseClassSelector parses a selector that matches by class attribute. func (p *parser) parseClassSelector() (classSelector, error) { if p.i >= len(p.s) { return classSelector{}, fmt.Errorf("expected class selector (.class), found EOF instead") } if p.s[p.i] != '.' 
{ return classSelector{}, fmt.Errorf("expected class selector (.class), found '%c' instead", p.s[p.i]) } p.i++ class, err := p.parseIdentifier() if err != nil { return classSelector{}, err } return classSelector{class: class}, nil } // parseAttributeSelector parses a selector that matches by attribute value. func (p *parser) parseAttributeSelector() (attrSelector, error) { if p.i >= len(p.s) { return attrSelector{}, fmt.Errorf("expected attribute selector ([attribute]), found EOF instead") } if p.s[p.i] != '[' { return attrSelector{}, fmt.Errorf("expected attribute selector ([attribute]), found '%c' instead", p.s[p.i]) } p.i++ p.skipWhitespace() key, err := p.parseIdentifier() if err != nil { return attrSelector{}, err } key = toLowerASCII(key) p.skipWhitespace() if p.i >= len(p.s) { return attrSelector{}, errors.New("unexpected EOF in attribute selector") } if p.s[p.i] == ']' { p.i++ return attrSelector{key: key, operation: ""}, nil } if p.i+2 >= len(p.s) { return attrSelector{}, errors.New("unexpected EOF in attribute selector") } op := p.s[p.i : p.i+2] if op[0] == '=' { op = "=" } else if op[1] != '=' { return attrSelector{}, fmt.Errorf(`expected equality operator, found "%s" instead`, op) } p.i += len(op) p.skipWhitespace() if p.i >= len(p.s) { return attrSelector{}, errors.New("unexpected EOF in attribute selector") } var val string var rx *regexp.Regexp if op == "#=" { rx, err = p.parseRegex() } else { switch p.s[p.i] { case '\'', '"': val, err = p.parseString() default: val, err = p.parseIdentifier() } } if err != nil { return attrSelector{}, err } p.skipWhitespace() if p.i >= len(p.s) { return attrSelector{}, errors.New("unexpected EOF in attribute selector") } // check if the attribute contains an ignore case flag ignoreCase := false if p.s[p.i] == 'i' || p.s[p.i] == 'I' { ignoreCase = true p.i++ } p.skipWhitespace() if p.i >= len(p.s) { return attrSelector{}, errors.New("unexpected EOF in attribute selector") } if p.s[p.i] != ']' { return attrSelector{}, 
fmt.Errorf("expected ']', found '%c' instead", p.s[p.i]) } p.i++ switch op { case "=", "!=", "~=", "|=", "^=", "$=", "*=", "#=": return attrSelector{key: key, val: val, operation: op, regexp: rx, insensitive: ignoreCase}, nil default: return attrSelector{}, fmt.Errorf("attribute operator %q is not supported", op) } } var ( errExpectedParenthesis = errors.New("expected '(' but didn't find it") errExpectedClosingParenthesis = errors.New("expected ')' but didn't find it") errUnmatchedParenthesis = errors.New("unmatched '('") ) // parsePseudoclassSelector parses a pseudoclass selector like :not(p) or a pseudo-element // For backwards compatibility, both ':' and '::' prefix are allowed for pseudo-elements. // https://drafts.csswg.org/selectors-3/#pseudo-elements // Returning a nil `Sel` (and a nil `error`) means we found a pseudo-element. func (p *parser) parsePseudoclassSelector() (out Sel, pseudoElement string, err error) { if p.i >= len(p.s) { return nil, "", fmt.Errorf("expected pseudoclass selector (:pseudoclass), found EOF instead") } if p.s[p.i] != ':' { return nil, "", fmt.Errorf("expected attribute selector (:pseudoclass), found '%c' instead", p.s[p.i]) } p.i++ var mustBePseudoElement bool if p.i >= len(p.s) { return nil, "", fmt.Errorf("got empty pseudoclass (or pseudoelement)") } if p.s[p.i] == ':' { // we found a pseudo-element mustBePseudoElement = true p.i++ } name, err := p.parseIdentifier() if err != nil { return } name = toLowerASCII(name) if mustBePseudoElement && (name != "after" && name != "backdrop" && name != "before" && name != "cue" && name != "first-letter" && name != "first-line" && name != "grammar-error" && name != "marker" && name != "placeholder" && name != "selection" && name != "spelling-error") { return out, "", fmt.Errorf("unknown pseudoelement :%s", name) } switch name { case "not", "has", "haschild": if !p.consumeParenthesis() { return out, "", errExpectedParenthesis } sel, parseErr := p.parseSelectorGroup() if parseErr != nil { 
return out, "", parseErr } if !p.consumeClosingParenthesis() { return out, "", errExpectedClosingParenthesis } out = relativePseudoClassSelector{name: name, match: sel} case "contains", "containsown": if !p.consumeParenthesis() { return out, "", errExpectedParenthesis } if p.i == len(p.s) { return out, "", errUnmatchedParenthesis } var val string switch p.s[p.i] { case '\'', '"': val, err = p.parseString() default: val, err = p.parseIdentifier() } if err != nil { return out, "", err } val = strings.ToLower(val) p.skipWhitespace() if p.i >= len(p.s) { return out, "", errors.New("unexpected EOF in pseudo selector") } if !p.consumeClosingParenthesis() { return out, "", errExpectedClosingParenthesis } out = containsPseudoClassSelector{own: name == "containsown", value: val} case "matches", "matchesown": if !p.consumeParenthesis() { return out, "", errExpectedParenthesis } rx, err := p.parseRegex() if err != nil { return out, "", err } if p.i >= len(p.s) { return out, "", errors.New("unexpected EOF in pseudo selector") } if !p.consumeClosingParenthesis() { return out, "", errExpectedClosingParenthesis } out = regexpPseudoClassSelector{own: name == "matchesown", regexp: rx} case "nth-child", "nth-last-child", "nth-of-type", "nth-last-of-type": if !p.consumeParenthesis() { return out, "", errExpectedParenthesis } a, b, err := p.parseNth() if err != nil { return out, "", err } if !p.consumeClosingParenthesis() { return out, "", errExpectedClosingParenthesis } last := name == "nth-last-child" || name == "nth-last-of-type" ofType := name == "nth-of-type" || name == "nth-last-of-type" out = nthPseudoClassSelector{a: a, b: b, last: last, ofType: ofType} case "first-child": out = nthPseudoClassSelector{a: 0, b: 1, ofType: false, last: false} case "last-child": out = nthPseudoClassSelector{a: 0, b: 1, ofType: false, last: true} case "first-of-type": out = nthPseudoClassSelector{a: 0, b: 1, ofType: true, last: false} case "last-of-type": out = nthPseudoClassSelector{a: 0, b: 1, 
ofType: true, last: true} case "only-child": out = onlyChildPseudoClassSelector{ofType: false} case "only-of-type": out = onlyChildPseudoClassSelector{ofType: true} case "input": out = inputPseudoClassSelector{} case "empty": out = emptyElementPseudoClassSelector{} case "root": out = rootPseudoClassSelector{} case "link": out = linkPseudoClassSelector{} case "lang": if !p.consumeParenthesis() { return out, "", errExpectedParenthesis } if p.i == len(p.s) { return out, "", errUnmatchedParenthesis } val, err := p.parseIdentifier() if err != nil { return out, "", err } val = strings.ToLower(val) p.skipWhitespace() if p.i >= len(p.s) { return out, "", errors.New("unexpected EOF in pseudo selector") } if !p.consumeClosingParenthesis() { return out, "", errExpectedClosingParenthesis } out = langPseudoClassSelector{lang: val} case "enabled": out = enabledPseudoClassSelector{} case "disabled": out = disabledPseudoClassSelector{} case "checked": out = checkedPseudoClassSelector{} case "visited", "hover", "active", "focus", "target": // Not applicable in a static context: never match. out = neverMatchSelector{value: ":" + name} case "after", "backdrop", "before", "cue", "first-letter", "first-line", "grammar-error", "marker", "placeholder", "selection", "spelling-error": return nil, name, nil default: return out, "", fmt.Errorf("unknown pseudoclass or pseudoelement :%s", name) } return } // parseInteger parses a decimal integer. func (p *parser) parseInteger() (int, error) { i := p.i start := i for i < len(p.s) && '0' <= p.s[i] && p.s[i] <= '9' { i++ } if i == start { return 0, errors.New("expected integer, but didn't find it") } p.i = i val, err := strconv.Atoi(p.s[start:i]) if err != nil { return 0, err } return val, nil } // parseNth parses the argument for :nth-child (normally of the form an+b). 
func (p *parser) parseNth() (a, b int, err error) {
	// parseNth is a small goto-driven state machine. It returns the pair
	// (a, b) of the expression an+b and accepts:
	//   - the keywords "odd" (2, 1) and "even" (2, 0), in any case;
	//   - a bare integer with optional sign, returned as b with a == 0;
	//   - an 'n' term with an optional sign and coefficient ("n", "-n",
	//     "+3n", "-2n"), optionally followed by "+ b" or "- b". Whitespace
	//     (and comments, via skipWhitespace) is allowed before and after
	//     that sign.
	// States:
	//   positiveA / negativeA  read the coefficient after an optional sign;
	//   readA                  decides whether the number just read is the
	//                          coefficient (an 'n' follows) or actually b;
	//   readN                  reads the optional "+b" / "-b" after 'n'.
	// Reaching the end of input in any state is reported as EOF, including
	// right after a complete integer or 'n' term; inside :nth-child(...) the
	// closing ')' always follows, so this only rejects unterminated input.

	// initial state
	if p.i >= len(p.s) {
		goto eof
	}
	switch p.s[p.i] {
	case '-':
		p.i++
		goto negativeA
	case '+':
		p.i++
		goto positiveA
	case '0', '1', '2', '3', '4', '5', '6', '7', '8', '9':
		goto positiveA
	case 'n', 'N':
		a = 1
		p.i++
		goto readN
	case 'o', 'O', 'e', 'E':
		// Only the keywords "odd" and "even" (case-insensitive) are valid.
		id, nameErr := p.parseName()
		if nameErr != nil {
			return 0, 0, nameErr
		}
		id = toLowerASCII(id)
		if id == "odd" {
			return 2, 1, nil
		}
		if id == "even" {
			return 2, 0, nil
		}
		return 0, 0, fmt.Errorf("expected 'odd' or 'even', but found '%s' instead", id)
	default:
		goto invalid
	}

	// positiveA: the coefficient has no sign or a '+' sign.
positiveA:
	if p.i >= len(p.s) {
		goto eof
	}
	switch p.s[p.i] {
	case '0', '1', '2', '3', '4', '5', '6', '7', '8', '9':
		a, err = p.parseInteger()
		if err != nil {
			return 0, 0, err
		}
		goto readA
	case 'n', 'N':
		// "+n": implicit coefficient 1.
		a = 1
		p.i++
		goto readN
	default:
		goto invalid
	}

	// negativeA: a '-' sign was consumed.
negativeA:
	if p.i >= len(p.s) {
		goto eof
	}
	switch p.s[p.i] {
	case '0', '1', '2', '3', '4', '5', '6', '7', '8', '9':
		a, err = p.parseInteger()
		if err != nil {
			return 0, 0, err
		}
		a = -a
		goto readA
	case 'n', 'N':
		// "-n": implicit coefficient -1.
		a = -1
		p.i++
		goto readN
	default:
		goto invalid
	}

	// readA: a number has been read into a; an immediately following 'n'
	// makes it the coefficient, anything else makes it b.
readA:
	if p.i >= len(p.s) {
		goto eof
	}
	switch p.s[p.i] {
	case 'n', 'N':
		p.i++
		goto readN
	default:
		// The number we read as a is actually b.
		return 0, a, nil
	}

	// readN: after the 'n', an optional signed offset b may follow.
readN:
	p.skipWhitespace()
	if p.i >= len(p.s) {
		goto eof
	}
	switch p.s[p.i] {
	case '+':
		p.i++
		p.skipWhitespace()
		b, err = p.parseInteger()
		if err != nil {
			return 0, 0, err
		}
		return a, b, nil
	case '-':
		p.i++
		p.skipWhitespace()
		b, err = p.parseInteger()
		if err != nil {
			return 0, 0, err
		}
		return a, -b, nil
	default:
		// No offset: b is 0 and the next character is left unconsumed.
		return a, 0, nil
	}

eof:
	return 0, 0, errors.New("unexpected EOF while attempting to parse expression of form an+b")

invalid:
	return 0, 0, errors.New("unexpected character while attempting to parse expression of form an+b")
}

// parseSimpleSelectorSequence parses a selector sequence that applies to
// a single element.
func (p *parser) parseSimpleSelectorSequence() (Sel, error) { var selectors []Sel if p.i >= len(p.s) { return nil, errors.New("expected selector, found EOF instead") } switch p.s[p.i] { case '*': // It's the universal selector. Just skip over it, since it doesn't affect the meaning. p.i++ if p.i+2 < len(p.s) && p.s[p.i:p.i+2] == "|*" { // other version of universal selector p.i += 2 } case '#', '.', '[', ':': // There's no type selector. Wait to process the other till the main loop. default: r, err := p.parseTypeSelector() if err != nil { return nil, err } selectors = append(selectors, r) } var pseudoElement string loop: for p.i < len(p.s) { var ( ns Sel newPseudoElement string err error ) switch p.s[p.i] { case '#': ns, err = p.parseIDSelector() case '.': ns, err = p.parseClassSelector() case '[': ns, err = p.parseAttributeSelector() case ':': ns, newPseudoElement, err = p.parsePseudoclassSelector() default: break loop } if err != nil { return nil, err } // From https://drafts.csswg.org/selectors-3/#pseudo-elements : // "Only one pseudo-element may appear per selector, and if present // it must appear after the sequence of simple selectors that // represents the subjects of the selector."" if ns == nil { // we found a pseudo-element if pseudoElement != "" { return nil, fmt.Errorf("only one pseudo-element is accepted per selector, got %s and %s", pseudoElement, newPseudoElement) } if !p.acceptPseudoElements { return nil, fmt.Errorf("pseudo-element %s found, but pseudo-elements support is disabled", newPseudoElement) } pseudoElement = newPseudoElement } else { if pseudoElement != "" { return nil, fmt.Errorf("pseudo-element %s must be at the end of selector", pseudoElement) } selectors = append(selectors, ns) } } if len(selectors) == 1 && pseudoElement == "" { // no need wrap the selectors in compoundSelector return selectors[0], nil } return compoundSelector{selectors: selectors, pseudoElement: pseudoElement}, nil } // parseSelector parses a selector that may 
include combinators. func (p *parser) parseSelector() (Sel, error) { p.skipWhitespace() result, err := p.parseSimpleSelectorSequence() if err != nil { return nil, err } for { var ( combinator byte c Sel ) if p.skipWhitespace() { combinator = ' ' } if p.i >= len(p.s) { return result, nil } switch p.s[p.i] { case '+', '>', '~': combinator = p.s[p.i] p.i++ p.skipWhitespace() case ',', ')': // These characters can't begin a selector, but they can legally occur after one. return result, nil } if combinator == 0 { return result, nil } c, err = p.parseSimpleSelectorSequence() if err != nil { return nil, err } result = combinedSelector{first: result, combinator: combinator, second: c} } } // parseSelectorGroup parses a group of selectors, separated by commas. func (p *parser) parseSelectorGroup() (SelectorGroup, error) { current, err := p.parseSelector() if err != nil { return nil, err } result := SelectorGroup{current} for p.i < len(p.s) { if p.s[p.i] != ',' { break } p.i++ c, err := p.parseSelector() if err != nil { return nil, err } result = append(result, c) } return result, nil }
package cascadia import ( "bytes" "fmt" "regexp" "strings" "golang.org/x/net/html" "golang.org/x/net/html/atom" ) // This file implements the pseudo classes selectors, // which share the implementation of PseudoElement() and Specificity() type abstractPseudoClass struct{} func (s abstractPseudoClass) Specificity() Specificity { return Specificity{0, 1, 0} } func (c abstractPseudoClass) PseudoElement() string { return "" } type relativePseudoClassSelector struct { name string // one of "not", "has", "haschild" match SelectorGroup } func (s relativePseudoClassSelector) Match(n *html.Node) bool { if n.Type != html.ElementNode { return false } switch s.name { case "not": // matches elements that do not match a. return !s.match.Match(n) case "has": // matches elements with any descendant that matches a. return hasDescendantMatch(n, s.match) case "haschild": // matches elements with a child that matches a. return hasChildMatch(n, s.match) default: panic(fmt.Sprintf("unsupported relative pseudo class selector : %s", s.name)) } } // hasChildMatch returns whether n has any child that matches a. func hasChildMatch(n *html.Node, a Matcher) bool { for c := n.FirstChild; c != nil; c = c.NextSibling { if a.Match(c) { return true } } return false } // hasDescendantMatch performs a depth-first search of n's descendants, // testing whether any of them match a. It returns true as soon as a match is // found, or false if no match is found. func hasDescendantMatch(n *html.Node, a Matcher) bool { for c := n.FirstChild; c != nil; c = c.NextSibling { if a.Match(c) || (c.Type == html.ElementNode && hasDescendantMatch(c, a)) { return true } } return false } // Specificity returns the specificity of the most specific selectors // in the pseudo-class arguments. 
// See https://www.w3.org/TR/selectors/#specificity-rules func (s relativePseudoClassSelector) Specificity() Specificity { var max Specificity for _, sel := range s.match { newSpe := sel.Specificity() if max.Less(newSpe) { max = newSpe } } return max } func (c relativePseudoClassSelector) PseudoElement() string { return "" } type containsPseudoClassSelector struct { abstractPseudoClass value string own bool } func (s containsPseudoClassSelector) Match(n *html.Node) bool { var text string if s.own { // matches nodes that directly contain the given text text = strings.ToLower(nodeOwnText(n)) } else { // matches nodes that contain the given text. text = strings.ToLower(nodeText(n)) } return strings.Contains(text, s.value) } type regexpPseudoClassSelector struct { abstractPseudoClass regexp *regexp.Regexp own bool } func (s regexpPseudoClassSelector) Match(n *html.Node) bool { var text string if s.own { // matches nodes whose text directly matches the specified regular expression text = nodeOwnText(n) } else { // matches nodes whose text matches the specified regular expression text = nodeText(n) } return s.regexp.MatchString(text) } // writeNodeText writes the text contained in n and its descendants to b. func writeNodeText(n *html.Node, b *bytes.Buffer) { switch n.Type { case html.TextNode: b.WriteString(n.Data) case html.ElementNode: for c := n.FirstChild; c != nil; c = c.NextSibling { writeNodeText(c, b) } } } // nodeText returns the text contained in n and its descendants. func nodeText(n *html.Node) string { var b bytes.Buffer writeNodeText(n, &b) return b.String() } // nodeOwnText returns the contents of the text nodes that are direct // children of n. 
func nodeOwnText(n *html.Node) string { var b bytes.Buffer for c := n.FirstChild; c != nil; c = c.NextSibling { if c.Type == html.TextNode { b.WriteString(c.Data) } } return b.String() } type nthPseudoClassSelector struct { abstractPseudoClass a, b int last, ofType bool } func (s nthPseudoClassSelector) Match(n *html.Node) bool { if s.a == 0 { if s.last { return simpleNthLastChildMatch(s.b, s.ofType, n) } else { return simpleNthChildMatch(s.b, s.ofType, n) } } return nthChildMatch(s.a, s.b, s.last, s.ofType, n) } // nthChildMatch implements :nth-child(an+b). // If last is true, implements :nth-last-child instead. // If ofType is true, implements :nth-of-type instead. func nthChildMatch(a, b int, last, ofType bool, n *html.Node) bool { if n.Type != html.ElementNode { return false } parent := n.Parent if parent == nil { return false } i := -1 count := 0 for c := parent.FirstChild; c != nil; c = c.NextSibling { if (c.Type != html.ElementNode) || (ofType && c.Data != n.Data) { continue } count++ if c == n { i = count if !last { break } } } if i == -1 { // This shouldn't happen, since n should always be one of its parent's children. return false } if last { i = count - i + 1 } i -= b if a == 0 { return i == 0 } return i%a == 0 && i/a >= 0 } // simpleNthChildMatch implements :nth-child(b). // If ofType is true, implements :nth-of-type instead. func simpleNthChildMatch(b int, ofType bool, n *html.Node) bool { if n.Type != html.ElementNode { return false } parent := n.Parent if parent == nil { return false } count := 0 for c := parent.FirstChild; c != nil; c = c.NextSibling { if c.Type != html.ElementNode || (ofType && c.Data != n.Data) { continue } count++ if c == n { return count == b } if count >= b { return false } } return false } // simpleNthLastChildMatch implements :nth-last-child(b). // If ofType is true, implements :nth-last-of-type instead. 
func simpleNthLastChildMatch(b int, ofType bool, n *html.Node) bool { if n.Type != html.ElementNode { return false } parent := n.Parent if parent == nil { return false } count := 0 for c := parent.LastChild; c != nil; c = c.PrevSibling { if c.Type != html.ElementNode || (ofType && c.Data != n.Data) { continue } count++ if c == n { return count == b } if count >= b { return false } } return false } type onlyChildPseudoClassSelector struct { abstractPseudoClass ofType bool } // Match implements :only-child. // If `ofType` is true, it implements :only-of-type instead. func (s onlyChildPseudoClassSelector) Match(n *html.Node) bool { if n.Type != html.ElementNode { return false } parent := n.Parent if parent == nil { return false } count := 0 for c := parent.FirstChild; c != nil; c = c.NextSibling { if (c.Type != html.ElementNode) || (s.ofType && c.Data != n.Data) { continue } count++ if count > 1 { return false } } return count == 1 } type inputPseudoClassSelector struct { abstractPseudoClass } // Matches input, select, textarea and button elements. func (s inputPseudoClassSelector) Match(n *html.Node) bool { return n.Type == html.ElementNode && (n.Data == "input" || n.Data == "select" || n.Data == "textarea" || n.Data == "button") } type emptyElementPseudoClassSelector struct { abstractPseudoClass } // Matches empty elements. 
func (s emptyElementPseudoClassSelector) Match(n *html.Node) bool { if n.Type != html.ElementNode { return false } for c := n.FirstChild; c != nil; c = c.NextSibling { switch c.Type { case html.ElementNode: return false case html.TextNode: if strings.TrimSpace(nodeText(c)) == "" { continue } else { return false } } } return true } type rootPseudoClassSelector struct { abstractPseudoClass } // Match implements :root func (s rootPseudoClassSelector) Match(n *html.Node) bool { if n.Type != html.ElementNode { return false } if n.Parent == nil { return false } return n.Parent.Type == html.DocumentNode } func hasAttr(n *html.Node, attr string) bool { return matchAttribute(n, attr, func(string) bool { return true }) } type linkPseudoClassSelector struct { abstractPseudoClass } // Match implements :link func (s linkPseudoClassSelector) Match(n *html.Node) bool { return (n.DataAtom == atom.A || n.DataAtom == atom.Area || n.DataAtom == atom.Link) && hasAttr(n, "href") } type langPseudoClassSelector struct { abstractPseudoClass lang string } func (s langPseudoClassSelector) Match(n *html.Node) bool { own := matchAttribute(n, "lang", func(val string) bool { return val == s.lang || strings.HasPrefix(val, s.lang+"-") }) if n.Parent == nil { return own } return own || s.Match(n.Parent) } type enabledPseudoClassSelector struct { abstractPseudoClass } func (s enabledPseudoClassSelector) Match(n *html.Node) bool { if n.Type != html.ElementNode { return false } switch n.DataAtom { case atom.A, atom.Area, atom.Link: return hasAttr(n, "href") case atom.Optgroup, atom.Menuitem, atom.Fieldset: return !hasAttr(n, "disabled") case atom.Button, atom.Input, atom.Select, atom.Textarea, atom.Option: return !hasAttr(n, "disabled") && !inDisabledFieldset(n) } return false } type disabledPseudoClassSelector struct { abstractPseudoClass } func (s disabledPseudoClassSelector) Match(n *html.Node) bool { if n.Type != html.ElementNode { return false } switch n.DataAtom { case atom.Optgroup, 
atom.Menuitem, atom.Fieldset: return hasAttr(n, "disabled") case atom.Button, atom.Input, atom.Select, atom.Textarea, atom.Option: return hasAttr(n, "disabled") || inDisabledFieldset(n) } return false } func hasLegendInPreviousSiblings(n *html.Node) bool { for s := n.PrevSibling; s != nil; s = s.PrevSibling { if s.DataAtom == atom.Legend { return true } } return false } func inDisabledFieldset(n *html.Node) bool { if n.Parent == nil { return false } if n.Parent.DataAtom == atom.Fieldset && hasAttr(n.Parent, "disabled") && (n.DataAtom != atom.Legend || hasLegendInPreviousSiblings(n)) { return true } return inDisabledFieldset(n.Parent) } type checkedPseudoClassSelector struct { abstractPseudoClass } func (s checkedPseudoClassSelector) Match(n *html.Node) bool { if n.Type != html.ElementNode { return false } switch n.DataAtom { case atom.Input, atom.Menuitem: return hasAttr(n, "checked") && matchAttribute(n, "type", func(val string) bool { t := toLowerASCII(val) return t == "checkbox" || t == "radio" }) case atom.Option: return hasAttr(n, "selected") } return false }
package cascadia import ( "fmt" "regexp" "strings" "golang.org/x/net/html" ) // Matcher is the interface for basic selector functionality. // Match returns whether a selector matches n. type Matcher interface { Match(n *html.Node) bool } // Sel is the interface for all the functionality provided by selectors. type Sel interface { Matcher Specificity() Specificity // Returns a CSS input compiling to this selector. String() string // Returns a pseudo-element, or an empty string. PseudoElement() string } // Parse parses a selector. Use `ParseWithPseudoElement` // if you need support for pseudo-elements. func Parse(sel string) (Sel, error) { p := &parser{s: sel} compiled, err := p.parseSelector() if err != nil { return nil, err } if p.i < len(sel) { return nil, fmt.Errorf("parsing %q: %d bytes left over", sel, len(sel)-p.i) } return compiled, nil } // ParseWithPseudoElement parses a single selector, // with support for pseudo-element. func ParseWithPseudoElement(sel string) (Sel, error) { p := &parser{s: sel, acceptPseudoElements: true} compiled, err := p.parseSelector() if err != nil { return nil, err } if p.i < len(sel) { return nil, fmt.Errorf("parsing %q: %d bytes left over", sel, len(sel)-p.i) } return compiled, nil } // ParseGroup parses a selector, or a group of selectors separated by commas. // Use `ParseGroupWithPseudoElements` // if you need support for pseudo-elements. func ParseGroup(sel string) (SelectorGroup, error) { p := &parser{s: sel} compiled, err := p.parseSelectorGroup() if err != nil { return nil, err } if p.i < len(sel) { return nil, fmt.Errorf("parsing %q: %d bytes left over", sel, len(sel)-p.i) } return compiled, nil } // ParseGroupWithPseudoElements parses a selector, or a group of selectors separated by commas. // It supports pseudo-elements. 
func ParseGroupWithPseudoElements(sel string) (SelectorGroup, error) { p := &parser{s: sel, acceptPseudoElements: true} compiled, err := p.parseSelectorGroup() if err != nil { return nil, err } if p.i < len(sel) { return nil, fmt.Errorf("parsing %q: %d bytes left over", sel, len(sel)-p.i) } return compiled, nil } // A Selector is a function which tells whether a node matches or not. // // This type is maintained for compatibility; I recommend using the newer and // more idiomatic interfaces Sel and Matcher. type Selector func(*html.Node) bool // Compile parses a selector and returns, if successful, a Selector object // that can be used to match against html.Node objects. func Compile(sel string) (Selector, error) { compiled, err := ParseGroup(sel) if err != nil { return nil, err } return Selector(compiled.Match), nil } // MustCompile is like Compile, but panics instead of returning an error. func MustCompile(sel string) Selector { compiled, err := Compile(sel) if err != nil { panic(err) } return compiled } // MatchAll returns a slice of the nodes that match the selector, // from n and its children. func (s Selector) MatchAll(n *html.Node) []*html.Node { return s.matchAllInto(n, nil) } func (s Selector) matchAllInto(n *html.Node, storage []*html.Node) []*html.Node { if s(n) { storage = append(storage, n) } for child := n.FirstChild; child != nil; child = child.NextSibling { storage = s.matchAllInto(child, storage) } return storage } func queryInto(n *html.Node, m Matcher, storage []*html.Node) []*html.Node { for child := n.FirstChild; child != nil; child = child.NextSibling { if m.Match(child) { storage = append(storage, child) } storage = queryInto(child, m, storage) } return storage } // QueryAll returns a slice of all the nodes that match m, from the descendants // of n. func QueryAll(n *html.Node, m Matcher) []*html.Node { return queryInto(n, m, nil) } // Match returns true if the node matches the selector. 
func (s Selector) Match(n *html.Node) bool {
	return s(n)
}

// MatchFirst returns the first node that matches s, from n and its
// children, searching depth-first with n itself tested first. It returns
// nil when nothing matches.
func (s Selector) MatchFirst(n *html.Node) *html.Node {
	if s.Match(n) {
		return n
	}
	for kid := n.FirstChild; kid != nil; kid = kid.NextSibling {
		if found := s.MatchFirst(kid); found != nil {
			return found
		}
	}
	return nil
}

// Query returns the first node that matches m, from the descendants of n
// (n itself is not tested), searching depth-first.
// If none matches, it returns nil.
func Query(n *html.Node, m Matcher) *html.Node {
	for kid := n.FirstChild; kid != nil; kid = kid.NextSibling {
		if m.Match(kid) {
			return kid
		}
		if found := Query(kid, m); found != nil {
			return found
		}
	}
	return nil
}

// Filter returns the nodes in nodes that match the selector, preserving
// their order. The result is nil when nothing matches.
func (s Selector) Filter(nodes []*html.Node) (result []*html.Node) {
	// A Selector is itself a Matcher, so defer to the package-level Filter.
	return Filter(nodes, s)
}

// Filter returns the nodes that match m, preserving their order. The result
// is nil when nothing matches.
func Filter(nodes []*html.Node, m Matcher) (result []*html.Node) {
	for i := range nodes {
		if node := nodes[i]; m.Match(node) {
			result = append(result, node)
		}
	}
	return result
}

// tagSelector matches elements by tag name.
type tagSelector struct {
	tag string
}

// Match reports whether n is an element node whose Data is exactly t.tag.
func (t tagSelector) Match(n *html.Node) bool {
	if n.Type != html.ElementNode {
		return false
	}
	return n.Data == t.tag
}

// Specificity counts a type selector in the third (C) component.
func (t tagSelector) Specificity() Specificity {
	return Specificity{0, 0, 1}
}

// PseudoElement always returns "": a tag selector carries no pseudo-element.
func (t tagSelector) PseudoElement() string {
	return ""
}

// classSelector matches elements by class name.
type classSelector struct {
	class string
}

// Match reports whether n is an element whose class attribute contains
// c.class as one of its whitespace-separated, case-sensitive tokens.
func (c classSelector) Match(n *html.Node) bool {
	hasClass := func(attr string) bool {
		return matchInclude(c.class, attr, false)
	}
	return matchAttribute(n, "class", hasClass)
}

// Specificity counts a class selector in the second (B) component.
func (c classSelector) Specificity() Specificity {
	return Specificity{0, 1, 0}
}

// PseudoElement always returns "": a class selector carries no pseudo-element.
func (c classSelector) PseudoElement() string {
	return ""
}

// idSelector matches elements by id.
type idSelector struct {
	id string
}

// Matches elements by id attribute.
func (t idSelector) Match(n *html.Node) bool { return matchAttribute(n, "id", func(s string) bool { return s == t.id }) } func (c idSelector) Specificity() Specificity { return Specificity{1, 0, 0} } func (c idSelector) PseudoElement() string { return "" } type attrSelector struct { key, val, operation string regexp *regexp.Regexp insensitive bool } // Matches elements by attribute value. func (t attrSelector) Match(n *html.Node) bool { switch t.operation { case "": return matchAttribute(n, t.key, func(string) bool { return true }) case "=": return matchAttribute(n, t.key, func(s string) bool { return matchInsensitiveValue(s, t.val, t.insensitive) }) case "!=": return attributeNotEqualMatch(t.key, t.val, n, t.insensitive) case "~=": // matches elements where the attribute named key is a whitespace-separated list that includes val. return matchAttribute(n, t.key, func(s string) bool { return matchInclude(t.val, s, t.insensitive) }) case "|=": return attributeDashMatch(t.key, t.val, n, t.insensitive) case "^=": return attributePrefixMatch(t.key, t.val, n, t.insensitive) case "$=": return attributeSuffixMatch(t.key, t.val, n, t.insensitive) case "*=": return attributeSubstringMatch(t.key, t.val, n, t.insensitive) case "#=": return attributeRegexMatch(t.key, t.regexp, n) default: panic(fmt.Sprintf("unsuported operation : %s", t.operation)) } } // matches elements where we ignore (or not) the case of the attribute value // the user attribute is the value set by the user to match elements // the real attribute is the attribute value found in the code parsed func matchInsensitiveValue(userAttr string, realAttr string, ignoreCase bool) bool { if ignoreCase { return strings.EqualFold(userAttr, realAttr) } return userAttr == realAttr } // matches elements where the attribute named key satisifes the function f. 
func matchAttribute(n *html.Node, key string, f func(string) bool) bool { if n.Type != html.ElementNode { return false } for _, a := range n.Attr { if a.Key == key && f(a.Val) { return true } } return false } // attributeNotEqualMatch matches elements where // the attribute named key does not have the value val. func attributeNotEqualMatch(key, val string, n *html.Node, ignoreCase bool) bool { if n.Type != html.ElementNode { return false } for _, a := range n.Attr { if a.Key == key && matchInsensitiveValue(a.Val, val, ignoreCase) { return false } } return true } // returns true if s is a whitespace-separated list that includes val. func matchInclude(val string, s string, ignoreCase bool) bool { for s != "" { i := strings.IndexAny(s, " \t\r\n\f") if i == -1 { return matchInsensitiveValue(s, val, ignoreCase) } if matchInsensitiveValue(s[:i], val, ignoreCase) { return true } s = s[i+1:] } return false } // matches elements where the attribute named key equals val or starts with val plus a hyphen. func attributeDashMatch(key, val string, n *html.Node, ignoreCase bool) bool { return matchAttribute(n, key, func(s string) bool { if matchInsensitiveValue(s, val, ignoreCase) { return true } if len(s) <= len(val) { return false } if matchInsensitiveValue(s[:len(val)], val, ignoreCase) && s[len(val)] == '-' { return true } return false }) } // attributePrefixMatch returns a Selector that matches elements where // the attribute named key starts with val. func attributePrefixMatch(key, val string, n *html.Node, ignoreCase bool) bool { return matchAttribute(n, key, func(s string) bool { if strings.TrimSpace(s) == "" { return false } if ignoreCase { return strings.HasPrefix(strings.ToLower(s), strings.ToLower(val)) } return strings.HasPrefix(s, val) }) } // attributeSuffixMatch matches elements where // the attribute named key ends with val. 
func attributeSuffixMatch(key, val string, n *html.Node, ignoreCase bool) bool { return matchAttribute(n, key, func(s string) bool { if strings.TrimSpace(s) == "" { return false } if ignoreCase { return strings.HasSuffix(strings.ToLower(s), strings.ToLower(val)) } return strings.HasSuffix(s, val) }) } // attributeSubstringMatch matches nodes where // the attribute named key contains val. func attributeSubstringMatch(key, val string, n *html.Node, ignoreCase bool) bool { return matchAttribute(n, key, func(s string) bool { if strings.TrimSpace(s) == "" { return false } if ignoreCase { return strings.Contains(strings.ToLower(s), strings.ToLower(val)) } return strings.Contains(s, val) }) } // attributeRegexMatch matches nodes where // the attribute named key matches the regular expression rx func attributeRegexMatch(key string, rx *regexp.Regexp, n *html.Node) bool { return matchAttribute(n, key, func(s string) bool { return rx.MatchString(s) }) } func (c attrSelector) Specificity() Specificity { return Specificity{0, 1, 0} } func (c attrSelector) PseudoElement() string { return "" } // see pseudo_classes.go for pseudo classes selectors // on a static context, some selectors can't match anything type neverMatchSelector struct { value string } func (s neverMatchSelector) Match(n *html.Node) bool { return false } func (s neverMatchSelector) Specificity() Specificity { return Specificity{0, 0, 0} } func (c neverMatchSelector) PseudoElement() string { return "" } type compoundSelector struct { selectors []Sel pseudoElement string } // Matches elements if each sub-selectors matches. 
func (t compoundSelector) Match(n *html.Node) bool { if len(t.selectors) == 0 { return n.Type == html.ElementNode } for _, sel := range t.selectors { if !sel.Match(n) { return false } } return true } func (s compoundSelector) Specificity() Specificity { var out Specificity for _, sel := range s.selectors { out = out.Add(sel.Specificity()) } if s.pseudoElement != "" { // https://drafts.csswg.org/selectors-3/#specificity out = out.Add(Specificity{0, 0, 1}) } return out } func (c compoundSelector) PseudoElement() string { return c.pseudoElement } type combinedSelector struct { first Sel combinator byte second Sel } func (t combinedSelector) Match(n *html.Node) bool { if t.first == nil { return false // maybe we should panic } switch t.combinator { case 0: return t.first.Match(n) case ' ': return descendantMatch(t.first, t.second, n) case '>': return childMatch(t.first, t.second, n) case '+': return siblingMatch(t.first, t.second, true, n) case '~': return siblingMatch(t.first, t.second, false, n) default: panic("unknown combinator") } } // matches an element if it matches d and has an ancestor that matches a. func descendantMatch(a, d Matcher, n *html.Node) bool { if !d.Match(n) { return false } for p := n.Parent; p != nil; p = p.Parent { if a.Match(p) { return true } } return false } // matches an element if it matches d and its parent matches a. func childMatch(a, d Matcher, n *html.Node) bool { return d.Match(n) && n.Parent != nil && a.Match(n.Parent) } // matches an element if it matches s2 and is preceded by an element that matches s1. // If adjacent is true, the sibling must be immediately before the element. 
func siblingMatch(s1, s2 Matcher, adjacent bool, n *html.Node) bool { if !s2.Match(n) { return false } if adjacent { for n = n.PrevSibling; n != nil; n = n.PrevSibling { if n.Type == html.TextNode || n.Type == html.CommentNode { continue } return s1.Match(n) } return false } // Walk backwards looking for element that matches s1 for c := n.PrevSibling; c != nil; c = c.PrevSibling { if s1.Match(c) { return true } } return false } func (s combinedSelector) Specificity() Specificity { spec := s.first.Specificity() if s.second != nil { spec = spec.Add(s.second.Specificity()) } return spec } // on combinedSelector, a pseudo-element only makes sens on the last // selector, although others increase specificity. func (c combinedSelector) PseudoElement() string { if c.second == nil { return "" } return c.second.PseudoElement() } // A SelectorGroup is a list of selectors, which matches if any of the // individual selectors matches. type SelectorGroup []Sel // Match returns true if the node matches one of the single selectors. func (s SelectorGroup) Match(n *html.Node) bool { for _, sel := range s { if sel.Match(n) { return true } } return false }
package cascadia

import (
	"fmt"
	"strconv"
	"strings"
)

// implements the reverse operation Sel -> string

// specialCharReplacer backslash-escapes each ASCII punctuation character
// (and the space) listed in init.
var specialCharReplacer *strings.Replacer

func init() {
	var pairs []string
	for _, s := range ",!\"#$%&'()*+ -./:;<=>?@[\\]^`{|}~" {
		pairs = append(pairs, string(s), "\\"+string(s))
	}
	specialCharReplacer = strings.NewReplacer(pairs...)
}

// escape escapes special CSS characters in s by prefixing each with a backslash.
func escape(s string) string {
	return specialCharReplacer.Replace(s)
}

// quoteCSSString returns s as a double-quoted CSS string literal that reads
// back as s. Backslashes and double quotes are backslash-escaped. Line
// breaks, which may not appear raw inside a CSS string, are written as
// hexadecimal escapes followed by the terminating space the escape syntax
// consumes. For values without such characters the output is simply `"` + s + `"`.
func quoteCSSString(s string) string {
	var b strings.Builder
	b.Grow(len(s) + 2)
	b.WriteByte('"')
	for i := 0; i < len(s); i++ {
		switch c := s[i]; c {
		case '"', '\\':
			b.WriteByte('\\')
			b.WriteByte(c)
		case '\n':
			b.WriteString(`\a `)
		case '\r':
			b.WriteString(`\d `)
		case '\f':
			b.WriteString(`\c `)
		default:
			b.WriteByte(c)
		}
	}
	b.WriteByte('"')
	return b.String()
}

func (c tagSelector) String() string {
	return c.tag
}

func (c idSelector) String() string {
	return "#" + escape(c.id)
}

func (c classSelector) String() string {
	return "." + escape(c.class)
}

// String renders the attribute selector. Values are emitted as quoted CSS
// strings with proper escaping (previously a `"` in the value produced
// output that did not re-parse); "#=" emits the regular expression source.
func (c attrSelector) String() string {
	val := c.val
	if c.operation == "#=" {
		val = c.regexp.String()
	} else if c.operation != "" {
		val = quoteCSSString(val)
	}
	ignoreCase := ""
	if c.insensitive {
		ignoreCase = " i"
	}
	return fmt.Sprintf(`[%s%s%s%s]`, c.key, c.operation, val, ignoreCase)
}

func (c relativePseudoClassSelector) String() string {
	return fmt.Sprintf(":%s(%s)", c.name, c.match.String())
}

// String renders :contains / :containsOwn with the value as an escaped,
// double-quoted CSS string.
func (c containsPseudoClassSelector) String() string {
	s := "contains"
	if c.own {
		s += "Own"
	}
	return fmt.Sprintf(`:%s(%s)`, s, quoteCSSString(c.value))
}

func (c regexpPseudoClassSelector) String() string {
	s := "matches"
	if c.own {
		s += "Own"
	}
	return fmt.Sprintf(":%s(%s)", s, c.regexp.String())
}

func (c nthPseudoClassSelector) String() string {
	if c.a == 0 && c.b == 1 { // special cases
		s := ":first-"
		if c.last {
			s = ":last-"
		}
		if c.ofType {
			s += "of-type"
		} else {
			s += "child"
		}
		return s
	}
	var name string
	switch [2]bool{c.last, c.ofType} {
	case [2]bool{true, true}:
		name = "nth-last-of-type"
	case [2]bool{true, false}:
		name = "nth-last-child"
	case [2]bool{false, true}:
		name = "nth-of-type"
	case [2]bool{false, false}:
		name = "nth-child"
	}
	s := fmt.Sprintf("+%d", c.b)
	if c.b < 0 { // avoid +-8 invalid syntax
		s = strconv.Itoa(c.b)
	}
	return fmt.Sprintf(":%s(%dn%s)", name, c.a, s)
}

func (c onlyChildPseudoClassSelector) String() string {
	if c.ofType {
		return ":only-of-type"
	}
	return ":only-child"
}

func (c inputPseudoClassSelector) String() string {
	return ":input"
}

func (c emptyElementPseudoClassSelector) String() string {
	return ":empty"
}

func (c rootPseudoClassSelector) String() string {
	return ":root"
}

func (c linkPseudoClassSelector) String() string {
	return ":link"
}

func (c langPseudoClassSelector) String() string {
	return fmt.Sprintf(":lang(%s)", c.lang)
}

func (c neverMatchSelector) String() string {
	return c.value
}

func (c enabledPseudoClassSelector) String() string {
	return ":enabled"
}

func (c disabledPseudoClassSelector) String() string {
	return ":disabled"
}

func (c checkedPseudoClassSelector) String() string {
	return ":checked"
}

// String concatenates the simple selectors (or "*" when there are none and
// no pseudo-element) and appends "::" plus the pseudo-element, if any.
func (c compoundSelector) String() string {
	if len(c.selectors) == 0 && c.pseudoElement == "" {
		return "*"
	}
	chunks := make([]string, len(c.selectors))
	for i, sel := range c.selectors {
		chunks[i] = sel.String()
	}
	s := strings.Join(chunks, "")
	if c.pseudoElement != "" {
		s += "::" + c.pseudoElement
	}
	return s
}

func (c combinedSelector) String() string {
	start := c.first.String()
	if c.second != nil {
		start += fmt.Sprintf(" %s %s", string(c.combinator), c.second.String())
	}
	return start
}

// String joins the group's selectors with ", ".
func (c SelectorGroup) String() string {
	ck := make([]string, len(c))
	for i, s := range c {
		ck[i] = s.String()
	}
	return strings.Join(ck, ", ")
}
package cascadia // Specificity is the CSS specificity as defined in // https://www.w3.org/TR/selectors/#specificity-rules // with the convention Specificity = [A,B,C]. type Specificity [3]int // returns `true` if s < other (strictly), false otherwise func (s Specificity) Less(other Specificity) bool { for i := range s { if s[i] < other[i] { return true } if s[i] > other[i] { return false } } return false } func (s Specificity) Add(other Specificity) Specificity { for i, sp := range other { s[i] += sp } return s }