[compat] Add compat_input
[youtube-dl] / youtube_dl / compat.py
index 0b6c5ca7a8ba5eb6cb064916d56a5ca8eae32003..fabac9fd2400240d4da0da3b355a6926408cc283 100644 (file)
@@ -11,6 +11,7 @@ import re
 import shlex
 import shutil
 import socket
+import struct
 import subprocess
 import sys
 import itertools
@@ -244,13 +245,20 @@ try:
 except ImportError:  # Python 2.6
     from xml.parsers.expat import ExpatError as compat_xml_parse_error
 
+
+etree = xml.etree.ElementTree
+
+
+class _TreeBuilder(etree.TreeBuilder):
+    def doctype(self, name, pubid, system):
+        pass
+
 if sys.version_info[0] >= 3:
-    compat_etree_fromstring = xml.etree.ElementTree.fromstring
+    def compat_etree_fromstring(text):
+        return etree.XML(text, parser=etree.XMLParser(target=_TreeBuilder()))
 else:
     # python 2.x tries to encode unicode strings with ascii (see the
     # XMLParser._fixtext method)
-    etree = xml.etree.ElementTree
-
     try:
         _etree_iter = etree.Element.iter
     except AttributeError:  # Python <=2.6
@@ -264,7 +272,7 @@ else:
     # 2.7 source
     def _XML(text, parser=None):
         if not parser:
-            parser = etree.XMLParser(target=etree.TreeBuilder())
+            parser = etree.XMLParser(target=_TreeBuilder())
         parser.feed(text)
         return parser.close()
 
@@ -276,7 +284,7 @@ else:
         return el
 
     def compat_etree_fromstring(text):
-        doc = _XML(text, parser=etree.XMLParser(target=etree.TreeBuilder(element_factory=_element_factory)))
+        doc = _XML(text, parser=etree.XMLParser(target=_TreeBuilder(element_factory=_element_factory)))
         for el in _etree_iter(doc):
             if el.text is not None and isinstance(el.text, bytes):
                 el.text = el.text.decode('utf-8')
@@ -340,9 +348,9 @@ except ImportError:  # Python 2
         return parsed_result
 
 try:
-    from shlex import quote as shlex_quote
+    from shlex import quote as compat_shlex_quote
 except ImportError:  # Python < 3.3
-    def shlex_quote(s):
+    def compat_shlex_quote(s):
         if re.match(r'^[-_\w./]+$', s):
             return s
         else:
@@ -373,6 +381,9 @@ compat_os_name = os._name if os.name == 'java' else os.name
 if sys.version_info >= (3, 0):
     compat_getenv = os.getenv
     compat_expanduser = os.path.expanduser
+
+    def compat_setenv(key, value, env=os.environ):
+        env[key] = value
 else:
     # Environment variables should be decoded with filesystem encoding.
     # Otherwise it will fail if any non-ASCII characters present (see #3854 #3217 #2918)
@@ -384,6 +395,12 @@ else:
             env = env.decode(get_filesystem_encoding())
         return env
 
+    def compat_setenv(key, value, env=os.environ):
+        def encode(v):
+            from .utils import get_filesystem_encoding
+            return v.encode(get_filesystem_encoding()) if isinstance(v, compat_str) else v
+        env[encode(key)] = encode(value)
+
     # HACK: The default implementations of os.path.expanduser from cpython do not decode
     # environment variables with filesystem encoding. We will work around this by
     # providing adjusted implementations.
@@ -456,18 +473,6 @@ else:
         print(s)
 
 
-try:
-    subprocess_check_output = subprocess.check_output
-except AttributeError:
-    def subprocess_check_output(*args, **kwargs):
-        assert 'input' not in kwargs
-        p = subprocess.Popen(*args, stdout=subprocess.PIPE, **kwargs)
-        output, _ = p.communicate()
-        ret = p.poll()
-        if ret:
-            raise subprocess.CalledProcessError(ret, p.args, output=output)
-        return output
-
 if sys.version_info < (3, 0) and sys.platform == 'win32':
     def compat_getpass(prompt, *args, **kwargs):
         if isinstance(prompt, compat_str):
@@ -477,6 +482,11 @@ if sys.version_info < (3, 0) and sys.platform == 'win32':
 else:
     compat_getpass = getpass.getpass
 
+try:
+    compat_input = raw_input
+except NameError:  # Python 3
+    compat_input = input
+
 # Python < 2.6.5 require kwargs to be bytes
 try:
     def _testfunc(x):
@@ -583,6 +593,26 @@ if sys.version_info >= (3, 0):
 else:
     from tokenize import generate_tokens as compat_tokenize_tokenize
 
+
+try:
+    struct.pack('!I', 0)
+except TypeError:
+    # In Python 2.6 and 2.7.x < 2.7.7, struct requires a bytes argument
+    # See https://bugs.python.org/issue19099
+    def compat_struct_pack(spec, *args):
+        if isinstance(spec, compat_str):
+            spec = spec.encode('ascii')
+        return struct.pack(spec, *args)
+
+    def compat_struct_unpack(spec, *args):
+        if isinstance(spec, compat_str):
+            spec = spec.encode('ascii')
+        return struct.unpack(spec, *args)
+else:
+    compat_struct_pack = struct.pack
+    compat_struct_unpack = struct.unpack
+
+
 __all__ = [
     'compat_HTMLParser',
     'compat_HTTPError',
@@ -604,9 +634,13 @@ __all__ = [
     'compat_os_name',
     'compat_parse_qs',
     'compat_print',
+    'compat_setenv',
+    'compat_shlex_quote',
     'compat_shlex_split',
     'compat_socket_create_connection',
     'compat_str',
+    'compat_struct_pack',
+    'compat_struct_unpack',
     'compat_subprocess_get_DEVNULL',
     'compat_tokenize_tokenize',
     'compat_urllib_error',
@@ -623,7 +657,5 @@ __all__ = [
     'compat_urlretrieve',
     'compat_xml_parse_error',
     'compat_xpath',
-    'shlex_quote',
-    'subprocess_check_output',
     'workaround_optparse_bug9161',
 ]