// Copyright (c) 2025 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
// Package credentials implements our struct stored in keychain.
// Store struct is kind of like a database client.
// Credentials struct is kind of like one record from the database.
package credentials
import (
"encoding/base64"
"errors"
"fmt"
"strings"
"github.com/sirupsen/logrus"
)
const (
// sep separates the serialized credential fields inside the decoded secret.
sep = "\x00"
// itemLengthBridge is the number of fields in the current (Bridge) secret format.
itemLengthBridge = 9
itemLengthImportExport = 6 // Old format for Import-Export.
)
var (
// log is the package-level logger, tagged with the package name.
log = logrus.WithField("pkg", "credentials") //nolint:gochecknoglobals
// ErrWrongFormat is returned by Credentials.Unmarshal when the decoded
// secret does not contain a supported number of fields.
ErrWrongFormat = errors.New("malformed credentials")
)
// Credentials is a single user record as kept in the keychain.
// Every field except UserID is part of the serialized secret.
type Credentials struct {
	UserID          string // Do not marshal; used as a key.
	Name            string
	Emails          string
	APIToken        string
	MailboxPassword []byte
	BridgePassword  string
	Version         string
	Timestamp       int64
	IsHidden        bool // Deprecated.

	IsCombinedAddressMode bool
}
// Marshal serializes the credentials (everything except UserID) into the
// nine-field Bridge format: the fields are joined with sep and the result is
// base64-encoded. Boolean flags are written as "1" when set and as an empty
// string otherwise.
func (s *Credentials) Marshal() string {
	flag := func(set bool) string {
		if set {
			return "1"
		}
		return ""
	}

	joined := strings.Join([]string{
		s.Name,                        // 0
		s.Emails,                      // 1
		s.APIToken,                    // 2
		string(s.MailboxPassword),     // 3
		s.BridgePassword,              // 4
		s.Version,                     // 5
		fmt.Sprint(s.Timestamp),       // 6
		flag(s.IsHidden),              // 7
		flag(s.IsCombinedAddressMode), // 8
	}, sep)

	return base64.StdEncoding.EncodeToString([]byte(joined))
}
// Unmarshal fills the receiver from a secret produced by Marshal.
//
// Two layouts are accepted: the nine-field Bridge format and the older
// six-field Import-Export format. It returns the base64 decoding error if the
// secret is not valid base64, and ErrWrongFormat if the field count matches
// neither layout. A timestamp that cannot be parsed is stored as 0.
func (s *Credentials) Unmarshal(secret string) error {
	decoded, err := base64.StdEncoding.DecodeString(secret)
	if err != nil {
		return err
	}

	fields := strings.Split(string(decoded), sep)
	isBridge := len(fields) == itemLengthBridge
	if !isBridge && len(fields) != itemLengthImportExport {
		return ErrWrongFormat
	}

	// The first four fields are shared by both layouts.
	s.Name, s.Emails, s.APIToken = fields[0], fields[1], fields[2]
	s.MailboxPassword = []byte(fields[3])

	var timestampField string
	if isBridge {
		s.BridgePassword = fields[4]
		s.Version = fields[5]
		timestampField = fields[6]
		s.IsHidden = fields[7] == "1"
		s.IsCombinedAddressMode = fields[8] == "1"
	} else {
		s.Version = fields[4]
		timestampField = fields[5]
	}

	if _, scanErr := fmt.Sscan(timestampField, &s.Timestamp); scanErr != nil {
		s.Timestamp = 0
	}

	return nil
}
// EmailList splits the Emails field on ";" and returns the pieces.
func (s *Credentials) EmailList() []string {
	const emailSeparator = ";"

	return strings.Split(s.Emails, emailSeparator)
}
// SplitAPIToken splits APIToken on ":" and returns its two halves.
// It returns an error unless the token consists of exactly two parts.
func (s *Credentials) SplitAPIToken() (string, string, error) {
	parts := strings.Split(s.APIToken, ":")
	if len(parts) == 2 {
		return parts[0], parts[1], nil
	}

	return "", "", errors.New("malformed API token")
}
// Copyright (c) 2025 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package credentials
import (
"errors"
"fmt"
"sort"
"sync"
)
// storeLocker is a package-wide lock around keychain access; Store.List and
// Store.Get hold its read side for the duration of the call.
var storeLocker = sync.RWMutex{} //nolint:gochecknoglobals
// Store is an encrypted credentials store.
type Store struct {
// secrets is the keychain backend the store reads from, keyed by user ID.
secrets Keychain
}
// Keychain is the storage backend required by Store.
type Keychain interface {
// List returns the stored keys; Store treats them as user IDs.
List() ([]string, error)
// Get looks up the entry for the given key. Store uses the second returned
// string as the secret and ignores the first.
Get(string) (string, string, error)
// Put presumably stores a secret under a key; the parameter order is not
// visible here — confirm against the implementation.
Put(string, string) error
// Delete removes the entry stored under the given key.
Delete(string) error
}
// NewStore returns a credentials store backed by the given keychain.
func NewStore(keychain Keychain) *Store {
	store := new(Store)
	store.secrets = keychain

	return store
}
// List returns the user IDs that have usable credentials stored, ordered by
// ascending credentials timestamp.
//
// Entries that cannot be loaded are logged and skipped, and entries whose
// timestamp is 0 are treated as disabled and skipped as well. An error is
// returned only when the keychain itself cannot be listed.
func (s *Store) List() (userIDs []string, err error) {
	storeLocker.RLock()
	defer storeLocker.RUnlock()

	log.Trace("Listing credentials in credentials store")

	storedIDs, listErr := s.secrets.List()
	if listErr != nil {
		log.WithError(listErr).Error("Could not list credentials")
		return nil, listErr
	}

	enabled := make([]*Credentials, 0, len(storedIDs))
	for _, id := range storedIDs {
		creds, getErr := s.get(id)
		switch {
		case getErr != nil:
			log.WithField("userID", id).WithError(getErr).Warn("Failed to get credentials")
		case creds.Timestamp != 0: // A zero timestamp marks disabled credentials.
			enabled = append(enabled, creds)
		}
	}

	sort.Slice(enabled, func(a, b int) bool {
		return enabled[a].Timestamp < enabled[b].Timestamp
	})

	for _, creds := range enabled {
		userIDs = append(userIDs, creds.UserID)
	}

	return userIDs, nil
}
// Get loads the credentials stored for userID while holding the read side of
// the store lock.
func (s *Store) Get(userID string) (creds *Credentials, err error) {
	storeLocker.RLock()
	defer storeLocker.RUnlock()

	creds, err = s.get(userID)

	return creds, err
}
// get reads and unmarshals the secret stored for userID without taking the
// store lock itself.
//
// It fails when the keychain lookup fails or the secret is empty. A secret
// that cannot be unmarshaled is logged, removed from the keychain (best
// effort), and reported through the unmarshal error.
func (s *Store) get(userID string) (*Credentials, error) {
	logger := log.WithField("user", userID)

	_, secret, err := s.secrets.Get(userID)
	switch {
	case err != nil:
		return nil, err
	case secret == "":
		return nil, errors.New("secret is empty")
	}

	creds := &Credentials{UserID: userID}

	unmarshalErr := creds.Unmarshal(secret)
	if unmarshalErr == nil {
		return creds, nil
	}

	logger.WithError(fmt.Errorf("malformed secret: %w", unmarshalErr)).Error("Could not unmarshal secret")

	if deleteErr := s.secrets.Delete(userID); deleteErr != nil {
		logger.WithError(deleteErr).Error("Failed to remove malformed secret")
	}

	return nil, unmarshalErr
}
// Copyright (c) 2025 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmmime
import (
"fmt"
"io"
"mime"
"regexp"
"strings"
"unicode/utf8"
"github.com/ProtonMail/gluon/rfc5322"
"github.com/ProtonMail/gluon/rfc822"
"github.com/ProtonMail/go-proton-api"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
"golang.org/x/net/html/charset"
"golang.org/x/text/encoding"
"golang.org/x/text/encoding/htmlindex"
)
// EmptyContentTypeErr is returned by ParseMediaType when it is given an empty value.
var EmptyContentTypeErr = errors.New("empty content type")
func init() {
rfc822.ParseMediaType = ParseMediaType
proton.CharsetReader = CharsetReader
rfc5322.CharsetReader = CharsetReader
}
// CharsetReader wraps input in a reader that decodes it from charset.
// It returns the error from SelectDecoder when no decoder can be selected,
// and returns input unchanged when SelectDecoder yields a nil decoder.
func CharsetReader(charset string, input io.Reader) (io.Reader, error) {
	decoder, err := SelectDecoder(charset)

	switch {
	case err != nil:
		return nil, err
	case decoder == nil: // utf-8
		return input, nil
	default:
		return decoder.Reader(input), nil
	}
}
// WordDec is the MIME word decoder used by DecodeHeader; it resolves charsets
// through CharsetReader.
var WordDec = &mime.WordDecoder{
CharsetReader: CharsetReader,
}
// Patterns used by getEncoding to recognise common charset aliases. They are
// compiled once at package initialisation instead of on every call.
var (
	koiCharsetRe     = regexp.MustCompile("(cs)?koi[-_ ]?8?[-_ ]?(r|ru|u|uk)?$")
	windowsCharsetRe = regexp.MustCompile("(cp|(cs)?win(dows)?)[-_ ]?([0-9]{3,4})$")
	isoCharsetRe     = regexp.MustCompile("iso[-_ ]?([0-9]{4})[-_ ]?([0-9]+|jp)?[-_ ]?(i|e)?")
	latinCharsetRe   = regexp.MustCompile("^(cs|csiso)?l(atin)?[-_ ]?([0-9]{1,2})$")
)

// getEncoding resolves a charset label to an encoding.
//
// The label is lower-cased and trimmed, then rewritten through a series of
// alias rules (koi8, windows code pages, iso-8859/iso-2022, latinN and a table
// of one-off substitutions) before being looked up in htmlindex. An error is
// returned when the lookup yields no encoding.
func getEncoding(charset string) (enc encoding.Encoding, err error) {
	preparsed := strings.Trim(strings.ToLower(charset), " \t\r\n")

	// koi
	matches := koiCharsetRe.FindAllStringSubmatch(preparsed, -1)
	if len(matches) == 1 && len(matches[0]) == 3 {
		preparsed = "koi8-"
		switch matches[0][2] {
		case "u", "uk":
			preparsed += "u"
		default:
			preparsed += "r"
		}
	}

	// windows-XXXX
	matches = windowsCharsetRe.FindAllStringSubmatch(preparsed, -1)
	if len(matches) == 1 && len(matches[0]) == 5 {
		switch matches[0][4] {
		case "874", "1250", "1251", "1252", "1253", "1254", "1255", "1256", "1257", "1258":
			preparsed = "windows-" + matches[0][4]
		}
	}

	// iso
	matches = isoCharsetRe.FindAllStringSubmatch(preparsed, -1)
	if len(matches) == 1 && len(matches[0]) == 4 {
		if matches[0][1] == "2022" && matches[0][2] == "jp" {
			preparsed = "iso-2022-jp"
		}
		if matches[0][1] == "8859" {
			switch matches[0][2] {
			case "1", "2", "3", "4", "5", "7", "8", "9", "10", "11", "13", "14", "15", "16":
				preparsed = "iso-8859-" + matches[0][2]
				if matches[0][3] == "i" {
					preparsed += "-" + matches[0][3]
				}
			case "":
				preparsed = "iso-8859-1"
			}
		}
	}

	// Latin is tricky.
	matches = latinCharsetRe.FindAllStringSubmatch(preparsed, -1)
	if len(matches) == 1 && len(matches[0]) == 4 {
		switch matches[0][3] {
		case "1":
			preparsed = "windows-1252"
		case "2", "3", "4", "5":
			preparsed = "iso-8859-" + matches[0][3]
		case "6":
			preparsed = "iso-8859-10"
		case "8":
			preparsed = "iso-8859-14"
		case "9":
			preparsed = "iso-8859-15"
		case "10":
			preparsed = "iso-8859-16"
		}
	}

	// Missing substitutions. Every label here must be lower case because the
	// input has already been lower-cased.
	switch preparsed {
	case "csutf8", "iso-utf-8", "utf8mb4":
		preparsed = "utf-8"

	case "cp932", "windows-932", "windows-31j", "ibm-943", "cp943":
		preparsed = "shift_jis"
	case "eucjp", "ibm-eucjp":
		preparsed = "euc-jp"
	case "euckr", "ibm-euckr", "cp949":
		preparsed = "euc-kr"
	case "euccn", "ibm-euccn":
		preparsed = "gbk"
	case "zht16mswin950", "cp950":
		preparsed = "big5"

	case "csascii",
		"ansi_x3.4-1968",
		"ansi_x3.4-1986",
		"ansi_x3.110-1983",
		"cp850",
		"cp858",
		"us",
		"iso646",
		"iso-646",
		"iso646-us",
		"iso_646.irv:1991",
		"cp367",
		"ibm367",
		"ibm-367",
		"iso-ir-6":
		preparsed = "ascii"

	case "ibm852":
		preparsed = "iso-8859-2"
	case "iso-ir-199", "iso-celtic":
		preparsed = "iso-8859-14"
	case "iso-ir-226":
		preparsed = "iso-8859-16"

	case "macroman":
		preparsed = "macintosh"
	}

	// The lookup error is dropped on purpose: a nil encoding is reported with
	// a message that names both the original and the rewritten label.
	enc, _ = htmlindex.Get(preparsed)
	if enc == nil {
		err = fmt.Errorf("can not get encoding for '%s' (or '%s')", charset, preparsed)
	}

	return enc, err
}
// SelectDecoder returns a decoder for the given charset label.
// UTF-7 labels are served by the package's own UTF-7 decoder; every other
// label is resolved through getEncoding, whose error is returned on failure.
func SelectDecoder(charset string) (decoder *encoding.Decoder, err error) {
	normalized := strings.Trim(strings.ToLower(charset), " \t\r\n")

	if normalized == "utf7" || normalized == "utf-7" || normalized == "unicode-1-1-utf-7" {
		return NewUtf7Decoder(), nil
	}

	enc, encErr := getEncoding(normalized)
	if encErr != nil {
		return nil, encErr
	}

	return enc.NewDecoder(), nil
}
// DecodeHeader decodes MIME encoded-words in raw using WordDec.
// When decoding fails the raw value is returned together with the decoding
// error. When the resulting string is not valid UTF-8 the returned error says
// so (and embeds any decoding error).
func DecodeHeader(raw string) (decoded string, err error) {
	decoded, err = WordDec.DecodeHeader(raw)
	if err != nil {
		decoded = raw
	}

	if utf8.ValidString(decoded) {
		return decoded, err
	}

	return decoded, fmt.Errorf("header contains non utf8 chars: %v", err)
}
// EncodeHeader encodes s as a UTF-8 MIME encoded-word using Q-encoding.
// Strings that need no encoding are returned unchanged.
func EncodeHeader(s string) string {
	const headerCharset = "utf-8"

	encoder := mime.QEncoding

	return encoder.Encode(headerCharset, s)
}
// DecodeCharset decodes the original using content type parameters.
// If the charset parameter is missing it checks that the content is valid utf8.
// If it isn't, it checks if it's embedded in the html/xml.
// If it isn't, it falls back to windows-1252.
// It then reencodes it as utf-8.
//
// On failure the original bytes are returned together with the error, except
// when contentType itself cannot be parsed, in which case nil is returned.
func DecodeCharset(original []byte, contentType string) ([]byte, error) {
	// If the contentType itself is specified, use that.
	if contentType != "" {
		_, params, err := ParseMediaType(contentType)
		if err != nil {
			return nil, err
		}

		if name, ok := params["charset"]; ok {
			decoder, err := SelectDecoder(name)
			if err != nil {
				return original, errors.Wrap(err, "unknown charset was specified")
			}

			return decoder.Bytes(original)
		}
	}

	// The charset was not specified. First try utf8.
	if utf8.Valid(original) {
		return original, nil
	}

	// The encoding will be windows-1252 if it can't be determined properly.
	enc, name, certain := charset.DetermineEncoding(original, contentType)

	if !certain {
		logrus.WithField("encoding", name).Warn("Determined encoding but was not certain")
	}

	// Re-encode as UTF-8.
	decoded, err := enc.NewDecoder().Bytes(original)
	if err != nil {
		return original, errors.Wrap(err, "failed to decode as windows-1252")
	}

	// If the decoded string is not valid utf8, it wasn't windows-1252, so give up.
	// err is nil at this point and wrapping a nil error yields nil, so a new
	// error has to be created for the failure to reach the caller.
	if !utf8.Valid(decoded) {
		return original, errors.New("failed to decode as windows-1252")
	}

	return decoded, nil
}
// ParseMediaType parses a media type value. The standard library parser does
// not support RFC2231 for non-ASCII / non-UTF-8 encodings, so the value is
// header-decoded and its parameters are pre-processed before it is handed to
// mime.ParseMediaType. An empty value yields EmptyContentTypeErr.
func ParseMediaType(v string) (string, map[string]string, error) {
	if v == "" {
		return "", nil, EmptyContentTypeErr
	}

	header, decodeErr := DecodeHeader(v)
	if decodeErr != nil {
		logrus.WithField("value", v).WithField("pkg", "mime").WithError(decodeErr).Error("Cannot decode Headers.")
		return "", nil, decodeErr
	}

	normalized, _ := changeEncodingAndKeepLastParamDefinition(header)

	mediaType, params, parseErr := mime.ParseMediaType(normalized)
	if parseErr != nil {
		logrus.WithField("value", normalized).WithField("pkg", "mime").WithError(parseErr).Error("Media Type parsing error.")
		return "", nil, parseErr
	}

	return mediaType, params, nil
}
// Copyright (c) 2025 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmmime
import (
"errors"
"fmt"
"strings"
"unicode"
"github.com/sirupsen/logrus"
)
// changeEncodingAndKeepLastParamDefinition is necessary to modify behaviour
// provided by the golang standard libraries.
//
// It scans the parameters of the media type value v. When no parameter is
// defined twice and every RFC2231 extended parameter uses the utf-8 or
// us-ascii charset, v is returned unchanged. Otherwise the value is rebuilt:
// only the last definition of each parameter is kept and RFC2231
// continuations are merged and converted to utf-8. Parameter order in the
// rebuilt value is unspecified. The returned error is always nil.
func changeEncodingAndKeepLastParamDefinition(v string) (out string, err error) {
	log := logrus.WithField("pkg", "pm-mime")

	out = v // By default don't do anything with that.
	keepOrig := true

	i := strings.Index(v, ";")
	if i == -1 {
		i = len(v)
	}
	mediatype := strings.TrimSpace(strings.ToLower(v[0:i]))

	params := map[string]string{}
	var continuation map[string]map[string]string

	v = v[i:]
	for len(v) > 0 {
		v = strings.TrimLeftFunc(v, unicode.IsSpace)
		if len(v) == 0 {
			break
		}
		key, value, rest := consumeMediaParam(v)
		if key == "" {
			break
		}

		pmap := params
		if idx := strings.Index(key, "*"); idx != -1 {
			baseName := key[:idx]
			if continuation == nil {
				continuation = make(map[string]map[string]string)
			}
			var ok bool
			if pmap, ok = continuation[baseName]; !ok {
				continuation[baseName] = make(map[string]string)
				pmap = continuation[baseName]
			}
			if isFirstContinuation(key) {
				charset, _, charsetErr := get2231Charset(value)
				if charsetErr != nil {
					log.Errorln("Filter params:", charsetErr)
					v = rest
					continue
				}
				if charset != "utf-8" && charset != "us-ascii" {
					keepOrig = false
				}
			}
		}
		if _, exists := pmap[key]; exists {
			keepOrig = false
		}
		pmap[key] = value
		v = rest
	}

	if keepOrig {
		return out, nil
	}

	for paramKey, contMap := range continuation {
		value, mergeErr := mergeContinuations(paramKey, contMap)
		if mergeErr == nil {
			params[paramKey+"*"] = value
			continue
		}

		// Fallback.
		log.Errorln("Merge param", paramKey, ":", mergeErr)
		for ck, cv := range contMap {
			params[ck] = cv
		}
	}

	// Merge ;
	var sb strings.Builder
	sb.WriteString(mediatype)
	for k, val := range params {
		sb.WriteByte(';')
		sb.WriteString(k)
		sb.WriteByte('=')
		// Values were de-quoted while parsing, so they must be quoted again
		// when they are not plain tokens; otherwise the rebuilt value cannot
		// be parsed.
		sb.WriteString(quoteMediaParamValue(val))
	}

	return sb.String(), nil
}

// quoteMediaParamValue returns value unchanged when it is a non-empty token
// and otherwise wraps it in double quotes, escaping backslashes and quotes.
func quoteMediaParamValue(value string) string {
	if value != "" && strings.IndexFunc(value, isNotTokenChar) == -1 {
		return value
	}

	var sb strings.Builder
	sb.Grow(len(value) + 2)
	sb.WriteByte('"')
	for i := 0; i < len(value); i++ {
		if c := value[i]; c == '"' || c == '\\' {
			sb.WriteByte('\\')
		}
		sb.WriteByte(value[i])
	}
	sb.WriteByte('"')

	return sb.String()
}
// isFirstContinuation reports whether key is an RFC2231 extended parameter
// name that carries the charset: either "name*" or "name*0*".
func isFirstContinuation(key string) bool {
	star := strings.Index(key, "*")
	if star < 0 {
		return false
	}

	switch key[star:] {
	case "*", "*0*":
		return true
	default:
		return false
	}
}
// get2231Charset splits an RFC2231 extended value of the form
// charset'language'value (partially taken from `decode2231Enc` in
// mime/mediatype.go). It returns the lower-cased charset and the raw value,
// or an error when the two apostrophe separators are missing.
func get2231Charset(v string) (charset, value string, err error) {
	parts := strings.SplitN(v, "'", 3)
	if len(parts) == 3 {
		return strings.ToLower(parts[0]), parts[2], nil
	}

	return "", "", errors.New("incorrect RFC2231 charset format")
}
// mergeContinuations combines the RFC2231 pieces collected for paramKey into a
// single utf-8 extended value.
//
// contMap maps the full piece names ("name*", "name*0*", "name*1", ...) to
// their raw values. A single "name*" piece is used on its own; otherwise the
// numbered pieces are concatenated starting at 0, and the first piece found
// without a trailing "*" ends the sequence. An error is returned when a piece
// is missing, or when the first piece has a malformed or empty charset.
func mergeContinuations(paramKey string, contMap map[string]string) (string, error) {
	var err error
	var charset, value string

	// Single value.
	if contValue, ok := contMap[paramKey+"*"]; ok {
		if charset, value, err = get2231Charset(contValue); err != nil {
			return "", err
		}
	} else {
		for n := 0; ; n++ {
			contKey := fmt.Sprintf("%s*%d", paramKey, n)
			contValue, isLast := contMap[contKey]
			if !isLast {
				var ok bool
				contValue, ok = contMap[contKey+"*"]
				if !ok {
					return "", errors.New("not valid RFC2231 continuation")
				}
			}
			if n == 0 {
				if charset, value, err = get2231Charset(contValue); err != nil {
					return "", err
				}
				// An empty charset must be reported as an error: returning
				// ("", nil) would make the caller store an empty parameter
				// value instead of taking its fallback path.
				if charset == "" {
					return "", errors.New("missing RFC2231 charset")
				}
			} else {
				value += contValue
			}
			if isLast {
				break
			}
		}
	}

	return convertHexToUTF(charset, value)
}
// convertHexToUTF takes a percent-encoded value in the given charset and
// returns it as an RFC2231 extended value in utf-8 ("utf-8''" followed by the
// percent-encoded bytes). A value that is not valid percent-encoding yields an
// empty string and an error; a charset conversion error is returned alongside
// the encoded result.
func convertHexToUTF(charset, value string) (string, error) {
	unescaped, err := percentHexUnescape(value)
	if err != nil {
		return "", err
	}

	contentType := "text/plain; charset=" + charset
	converted, convertErr := DecodeCharset(unescaped, contentType)

	return "utf-8''" + percentHexEscape(converted), convertErr
}
// consumeMediaParam (adapted from mime/mediatype.go:297) reads one
// ";name=value" parameter from the start of v. It returns the lower-cased
// name, the value (de-quoted if it was a quoted-string) and the unconsumed
// remainder. On failure it returns ("", "", v).
func consumeMediaParam(v string) (param, value, rest string) {
	skipSpace := func(s string) string {
		return strings.TrimLeftFunc(s, unicode.IsSpace)
	}

	cursor := skipSpace(v)
	if !strings.HasPrefix(cursor, ";") {
		return "", "", v
	}
	cursor = skipSpace(cursor[1:]) // Consume semicolon.

	name, afterName := consumeToken(cursor)
	if name == "" {
		return "", "", v
	}

	cursor = skipSpace(afterName)
	if !strings.HasPrefix(cursor, "=") {
		return "", "", v
	}
	cursor = skipSpace(cursor[1:]) // Consume equals sign.

	parsed, remainder := consumeValue(cursor)
	if parsed == "" && remainder == cursor {
		return "", "", v
	}

	return strings.ToLower(name), parsed, remainder
}
// consumeToken (copied from mime/mediatype.go:238) consumes a token, as
// defined by RFC 2045 section 5.1 (referenced from 2183), from the start of v
// and returns the token together with the rest of the string.
// It returns ("", v) when not even one character can be consumed.
func consumeToken(v string) (token, rest string) {
	switch end := strings.IndexFunc(v, isNotTokenChar); end {
	case -1:
		// The whole string is a token.
		return v, ""
	case 0:
		return "", v
	default:
		return v[:end], v[end:]
	}
}
// consumeValue (copied from mime/mediatype.go:253) consumes a "value" per
// RFC 2045: either a token or a quoted-string. On success it returns the value
// (de-quoted and unescaped for a quoted-string) and the rest of the string.
// On failure it returns ("", v).
func consumeValue(v string) (value, rest string) {
	if v == "" {
		return "", ""
	}
	if v[0] != '"' {
		return consumeToken(v)
	}

	// Parse a quoted-string, starting right after the opening quote.
	var unquoted strings.Builder
	for pos := 1; pos < len(v); pos++ {
		switch c := v[pos]; {
		case c == '"':
			return unquoted.String(), v[pos+1:]

		case c == '\\' && pos+1 < len(v) && !isTokenChar(rune(v[pos+1])):
			// A backslash only counts as an escape in front of a non-token
			// character. MSIE sends unescaped file paths such as
			// "C:\dev\go\foo.txt", and no known MIME generator escapes plain
			// letters or digits, so a backslash in front of a token character
			// is kept as a literal backslash (default branch).
			pos++
			unquoted.WriteByte(v[pos])

		case c == '\r' || c == '\n':
			return "", v

		default:
			unquoted.WriteByte(c)
		}
	}

	// Did not find end quote.
	return "", v
}
// isNotTokenChar (copied from mime/mediatype.go:234) is the negation of
// isTokenChar, in a form usable as a strings.IndexFunc predicate.
func isNotTokenChar(r rune) bool {
	if isTokenChar(r) {
		return false
	}

	return true
}
// isTokenChar (copied from mime/grammar.go:19) reports whether r may appear
// in a 'token' as defined by RFC 1521 and RFC 2045:
//
//	token := 1*<any (US-ASCII) CHAR except SPACE, CTLs, or tspecials>
func isTokenChar(r rune) bool {
	if r <= 0x20 || r >= 0x7f {
		// Space, control characters and everything outside US-ASCII.
		return false
	}

	return !isTSpecial(r)
}
// isTSpecial (copied from mime/grammar.go:13) reports whether r is one of the
// 'tspecials' defined by RFC 1521 and RFC 2045.
func isTSpecial(r rune) bool {
	const tspecials = `()<>@,;:\"/[]?=`

	return strings.IndexRune(tspecials, r) >= 0
}
// percentHexEscape percent-encodes every byte of raw as "%xx" using two
// lower-case hex digits. Bytes below 0x10 are zero-padded so the output is
// always valid percent-encoding that percentHexUnescape can read back.
func percentHexEscape(raw []byte) (out string) {
	escaped := make([]byte, 0, 3*len(raw))
	for _, b := range raw {
		escaped = append(escaped, fmt.Sprintf("%%%02x", b)...)
	}

	return string(escaped)
}
// percentHexUnescape (adapted from mime/mediatype.go:325) decodes every "%xx"
// escape in s and returns the resulting bytes. It returns an empty slice and
// an error when a '%' is not followed by two hex digits.
func percentHexUnescape(s string) ([]byte, error) {
	// First pass: count the escapes and check that they're well-formed.
	escapes := 0
	for pos := 0; pos < len(s); pos++ {
		if s[pos] != '%' {
			continue
		}
		if pos+2 >= len(s) || !ishex(s[pos+1]) || !ishex(s[pos+2]) {
			bogus := s[pos:]
			if len(bogus) > 3 {
				bogus = bogus[:3]
			}
			return []byte{}, fmt.Errorf("mime: bogus characters after %%: %q", bogus)
		}
		escapes++
		pos += 2 // Skip the two hex digits.
	}

	if escapes == 0 {
		return []byte(s), nil
	}

	// Second pass: each escape shrinks from three bytes to one.
	decoded := make([]byte, 0, len(s)-2*escapes)
	for pos := 0; pos < len(s); pos++ {
		c := s[pos]
		if c == '%' {
			c = unhex(s[pos+1])<<4 | unhex(s[pos+2])
			pos += 2
		}
		decoded = append(decoded, c)
	}

	return decoded, nil
}
// ishex (copied from mime/mediatype.go:364) reports whether c is an ASCII
// hexadecimal digit in either case.
func ishex(c byte) bool {
	isDigit := '0' <= c && c <= '9'
	isLower := 'a' <= c && c <= 'f'
	isUpper := 'A' <= c && c <= 'F'

	return isDigit || isLower || isUpper
}
// unhex (copied from mime/mediatype.go:376) returns the numeric value of the
// hexadecimal digit c, or 0 when c is not a hexadecimal digit.
func unhex(c byte) byte {
	if c >= '0' && c <= '9' {
		return c - '0'
	}
	if c >= 'a' && c <= 'f' {
		return c - 'a' + 10
	}
	if c >= 'A' && c <= 'F' {
		return c - 'A' + 10
	}

	return 0
}
// Copyright (c) 2025 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmmime
import (
"encoding/base64"
"errors"
"unicode/utf16"
"unicode/utf8"
"golang.org/x/text/encoding"
"golang.org/x/text/transform"
)
// utf7Decoder copied from: https://github.com/cention-sany/utf7/blob/master/utf7.go
// We need `encoding.Decoder` instead of function `UTF7DecodeBytes`.
//
// It is stateless: decoding is implemented by its Transform method, and the
// embedded transform.NopResetter provides the Reset method.
type utf7Decoder struct {
transform.NopResetter
}
// NewUtf7Decoder returns an encoding.Decoder that converts UTF-7 input to UTF-8.
func NewUtf7Decoder() *encoding.Decoder {
	var transformer utf7Decoder

	return &encoding.Decoder{Transformer: transformer}
}
// Limits and sentinel values used by the UTF-7 decoder.
const (
uRepl = '\uFFFD' // Unicode replacement code point
u7min = 0x20 // Minimum self-representing UTF-7 value
u7max = 0x7E // Maximum self-representing UTF-7 value
)
// ErrBadUTF7 is returned to indicate the invalid modified UTF-7 encoding.
var ErrBadUTF7 = errors.New("utf7: bad utf-7 encoding")
// modifiedbase64 is the alphabet accepted inside a '+'-shifted UTF-7 section.
const modifiedbase64 = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"
// u7enc is the base64 codec built on modifiedbase64; utf7dec uses it to decode
// shifted sections.
var u7enc = base64.NewEncoding(modifiedbase64)
// isModifiedBase64 reports whether r belongs to the base64 alphabet used for
// shifted UTF-7 sections: ASCII letters, digits, '+' and '/'.
func isModifiedBase64(r byte) bool {
	switch {
	case 'A' <= r && r <= 'Z',
		'a' <= r && r <= 'z',
		'0' <= r && r <= '9':
		return true
	case r == '+', r == '/':
		return true
	default:
		return false
	}
}
// Transform decodes UTF-7 bytes from src into UTF-8 bytes in dst.
//
// It returns the number of bytes written to dst, the number of bytes consumed
// from src and one of:
//   - nil when all of src was processed,
//   - transform.ErrShortDst when dst has no room for the next output,
//   - transform.ErrShortSrc when src is empty or ends inside a '+' section
//     while atEOF is false,
//   - ErrBadUTF7 when src contains a byte that is illegal in direct (ASCII)
//     mode or a malformed '+' section.
//
// When it stops inside a '+' section the reported source count still points at
// the '+', so the whole section is seen again on the next call.
func (d utf7Decoder) Transform(dst, src []byte, atEOF bool) (nDst, nSrc int, err error) {
var implicit bool
var tmp int
nd, n := len(dst), len(src)
if n == 0 && !atEOF {
return 0, 0, transform.ErrShortSrc
}
for ; nSrc < n; nSrc++ {
if nDst >= nd {
return nDst, nSrc, transform.ErrShortDst
}
// Outside a '+' section only printable ASCII (plus tab, CR and LF) is
// accepted; '~' and '\' are rejected as well.
if c := src[nSrc]; ((c < u7min || c > u7max) &&
c != '\t' && c != '\r' && c != '\n') ||
c == '~' || c == '\\' {
return nDst, nSrc, ErrBadUTF7 // Illegal code point in ASCII mode.
} else if c != '+' {
dst[nDst] = c // Character is self-representing.
nDst++
continue
}
// Found '+'.
start := nSrc + 1
tmp = nSrc // nSrc still points to '+', tmp points to the end of BASE64.
// Find the end of the Base64 or "+-" segment.
implicit = false
for tmp++; tmp < n && src[tmp] != '-'; tmp++ {
if !isModifiedBase64(src[tmp]) {
if tmp == start {
return nDst, tmp, ErrBadUTF7 // '+' next char must modified base64.
}
// Implicit shift back to ASCII, no need for '-' character.
implicit = true
break
}
}
if tmp == start {
if tmp == n {
// Did not find '-' sign and '+' is the last character.
// Total nSrc does not include '+'.
if atEOF {
return nDst, nSrc, ErrBadUTF7 // '+' can not be at the end.
}
// '+' can not be at the end, the source is too short.
return nDst, nSrc, transform.ErrShortSrc
}
dst[nDst] = '+' // Escape sequence "+-".
nDst++
} else if tmp == n && !atEOF {
// No EOF found, the source is too short.
return nDst, nSrc, transform.ErrShortSrc
} else if b := utf7dec(src[start:tmp]); len(b) > 0 {
if len(b)+nDst > nd {
// Need more space in dst for the decoded modified BASE64 unicode.
// Total nSrc does not include '+'.
return nDst, nSrc, transform.ErrShortDst
}
copy(dst[nDst:], b) // Control or non-ASCII code points in Base64.
nDst += len(b)
if implicit {
if nDst >= nd {
return nDst, tmp, transform.ErrShortDst
}
dst[nDst] = src[tmp] // Implicit shift.
nDst++
}
if tmp == n {
return nDst, tmp, nil
}
} else {
return nDst, nSrc, ErrBadUTF7 // Bad encoding.
}
// Jump to the byte that ended the section ('-' or the implicit-shift byte,
// which was already copied above); the loop increment steps past it.
nSrc = tmp
}
return
}
// utf7dec extracts UTF-16-BE bytes from Base64 data and converts them to UTF-8.
// A nil slice is returned if the encoding is invalid.
// b64 must not be empty: its last byte is inspected unconditionally.
func utf7dec(b64 []byte) []byte {
var b []byte
// Allocate a single block of memory large enough to store the Base64 data
// (if padding is required), UTF-16-BE bytes, and decoded UTF-8 bytes.
// Since a 2-byte UTF-16 sequence may expand into a 3-byte UTF-8 sequence,
// double the space allocation for UTF-8.
if n := len(b64); b64[n-1] == '=' {
// Input that already carries padding is rejected.
return nil
} else if n&3 == 0 {
b = make([]byte, u7enc.DecodedLen(n)*3)
} else {
// Round n up to a multiple of four, copy the input to the front of b and
// append up to two '=' padding bytes. The padded copy becomes the new
// b64, and b is re-sliced to the space behind it.
n += 4 - n&3
b = make([]byte, n+u7enc.DecodedLen(n)*3)
copy(b[copy(b, b64):n], []byte("=="))
b64, b = b[:n], b[n:]
}
// Decode Base64 into the first 1/3rd of b.
n, err := u7enc.Decode(b, b64)
// An odd byte count cannot be a sequence of 16-bit code units.
if err != nil || n&1 == 1 {
return nil
}
// Decode UTF-16-BE into the remaining 2/3rds of b.
b, s := b[:n], b[n:]
j := 0
for i := 0; i < n; i += 2 {
r := rune(b[i])<<8 | rune(b[i+1])
if utf16.IsSurrogate(r) {
// A surrogate needs a second code unit; fail if the data ends here.
if i += 2; i == n {
return nil
}
r2 := rune(b[i])<<8 | rune(b[i+1])
if r = utf16.DecodeRune(r, r2); r == uRepl {
return nil
}
}
j += utf8.EncodeRune(s[j:], r)
}
return s[:j]
}