1# defusedxml
2#
3# Copyright (c) 2013-2020 by Christian Heimes <christian@python.org>
4# Licensed to PSF under a Contributor Agreement.
5# See https://www.python.org/psf/license for licensing details.
6"""Common constants, exceptions and helper functions
7"""
8import sys
9import xml.parsers.expat
10
# Hard-coded compatibility flag: this module uses Python 3-only syntax
# (zero-argument ``super()``), so it is always True.
# NOTE(review): presumably kept only for external code that still imports it.
PY3 = True

# Fail early when pyexpat is not installed correctly: if the expat binding
# does not expose ``ParserCreate``, raise ImportError while this module is
# being imported instead of failing later when a parser is created.
if not hasattr(xml.parsers.expat, "ParserCreate"):
    raise ImportError("pyexpat")  # pragma: no cover
16
17
class DefusedXmlException(ValueError):
    """Base exception for all errors raised by defusedxml.

    Subclasses ``ValueError``, so callers catching ``ValueError`` also catch
    these errors.
    """

    def __repr__(self):
        # repr() deliberately shows the same human-readable text as str().
        return str(self)


class DTDForbidden(DefusedXmlException):
    """Document type definition is forbidden

    Attributes:
        name: document type name
        sysid: system identifier (may be None)
        pubid: public identifier (may be None)
    """

    def __init__(self, name, sysid, pubid):
        # Forward the constructor arguments so ``self.args`` mirrors this
        # signature: BaseException.__reduce__ re-creates the instance as
        # ``cls(*args)``, so an empty ``args`` made unpickling fail.
        super().__init__(name, sysid, pubid)
        self.name = name
        self.sysid = sysid
        self.pubid = pubid

    def __str__(self):
        tpl = "DTDForbidden(name='{}', system_id={!r}, public_id={!r})"
        return tpl.format(self.name, self.sysid, self.pubid)


class EntitiesForbidden(DefusedXmlException):
    """Entity definition is forbidden

    Attributes:
        name: entity name
        value: entity value
        base: base URI
        sysid: system identifier (may be None)
        pubid: public identifier (may be None)
        notation_name: notation name

    Only ``name``, ``sysid`` and ``pubid`` appear in ``str()``.
    """

    def __init__(self, name, value, base, sysid, pubid, notation_name):
        # See DTDForbidden.__init__: populate ``args`` so pickling works.
        super().__init__(name, value, base, sysid, pubid, notation_name)
        self.name = name
        self.value = value
        self.base = base
        self.sysid = sysid
        self.pubid = pubid
        self.notation_name = notation_name

    def __str__(self):
        tpl = "EntitiesForbidden(name='{}', system_id={!r}, public_id={!r})"
        return tpl.format(self.name, self.sysid, self.pubid)


class ExternalReferenceForbidden(DefusedXmlException):
    """Resolving an external reference is forbidden

    Attributes:
        context: context object supplied when the reference was rejected
        base: base URI
        sysid: system identifier (may be None)
        pubid: public identifier (may be None)

    Only ``sysid`` and ``pubid`` appear in ``str()``.
    """

    def __init__(self, context, base, sysid, pubid):
        # See DTDForbidden.__init__: populate ``args`` so pickling works.
        super().__init__(context, base, sysid, pubid)
        self.context = context
        self.base = base
        self.sysid = sysid
        self.pubid = pubid

    def __str__(self):
        # Use !r for both ids, like the sibling exceptions, so None is shown
        # as None (not 'None') and string ids are unambiguously quoted.
        tpl = "ExternalReferenceForbidden(system_id={!r}, public_id={!r})"
        return tpl.format(self.sysid, self.pubid)


class NotSupportedError(DefusedXmlException):
    """The operation is not supported"""
73
74
75def _apply_defusing(defused_mod):
76 assert defused_mod is sys.modules[defused_mod.__name__]
77 stdlib_name = defused_mod.__origin__
78 __import__(stdlib_name, {}, {}, ["*"])
79 stdlib_mod = sys.modules[stdlib_name]
80 stdlib_names = set(dir(stdlib_mod))
81 for name, obj in vars(defused_mod).items():
82 if name.startswith("_") or name not in stdlib_names:
83 continue
84 setattr(stdlib_mod, name, obj)
85 return stdlib_mod