Merge branch 'pr-crashfix_compat_urllib_unquote' of https://github.com/atomicdryad...
[youtube-dl] / youtube_dl / compat.py
index f9529210dd955932eca837aa7022696470c557ed..9e506352fe6238e124d17c4bfeb235b15396cb28 100644 (file)
@@ -9,6 +9,7 @@ import shutil
 import socket
 import subprocess
 import sys
+import itertools
 
 
 try:
@@ -76,40 +77,73 @@ except ImportError:
 try:
     from urllib.parse import unquote as compat_urllib_parse_unquote
 except ImportError:
-    def compat_urllib_parse_unquote(string, encoding='utf-8', errors='replace'):
-        if string == '':
+    def compat_urllib_parse_unquote_to_bytes(string):
+        """unquote_to_bytes('abc%20def') -> b'abc def'."""
+        # Note: strings are encoded as UTF-8. This is only an issue if it contains
+        # unescaped non-ASCII characters, which URIs should not.
+        if not string:
+            # Is it a string-like object?
+            string.split
+            return b''
+        if isinstance(string, str):
+            string = string.encode('utf-8')
+            # string = encode('utf-8')
+
+        # python3 -> 2: must implicitly convert to bits
+        bits = bytes(string).split(b'%')
+
+        if len(bits) == 1:
             return string
-        res = string.split('%')
-        if len(res) == 1:
+        res = [bits[0]]
+        append = res.append
+
+        for item in bits[1:]:
+            if item == '':
+                append(b'%')
+                continue
+            try:
+                append(item[:2].decode('hex'))
+                append(item[2:])
+            except:
+                append(b'%')
+                append(item)
+        return b''.join(res)
+
+    compat_urllib_parse_asciire = re.compile('([\x00-\x7f]+)')
+
+    def compat_urllib_parse_unquote(string, encoding='utf-8', errors='replace'):
+        """Replace %xx escapes by their single-character equivalent. The optional
+        encoding and errors parameters specify how to decode percent-encoded
+        sequences into Unicode characters, as accepted by the bytes.decode()
+        method.
+        By default, percent-encoded sequences are decoded with UTF-8, and invalid
+        sequences are replaced by a placeholder character.
+
+        unquote('abc%20def') -> 'abc def'.
+        """
+
+        if '%' not in string:
+            string.split
             return string
         if encoding is None:
             encoding = 'utf-8'
         if errors is None:
             errors = 'replace'
-        # pct_sequence: contiguous sequence of percent-encoded bytes, decoded
-        pct_sequence = b''
-        string = res[0]
-        for item in res[1:]:
-            try:
-                if not item:
-                    raise ValueError
-                pct_sequence += item[:2].decode('hex')
-                rest = item[2:]
-                if not rest:
-                    # This segment was just a single percent-encoded character.
-                    # May be part of a sequence of code units, so delay decoding.
-                    # (Stored in pct_sequence).
-                    continue
-            except ValueError:
-                rest = '%' + item
-            # Encountered non-percent-encoded characters. Flush the current
-            # pct_sequence.
-            string += pct_sequence.decode(encoding, errors) + rest
-            pct_sequence = b''
-        if pct_sequence:
-            # Flush the final pct_sequence
-            string += pct_sequence.decode(encoding, errors)
-        return string
+
+        bits = compat_urllib_parse_asciire.split(string)
+        res = [bits[0]]
+        append = res.append
+        for i in range(1, len(bits), 2):
+            foo = compat_urllib_parse_unquote_to_bytes(bits[i])
+            foo = foo.decode(encoding, errors)
+            append(foo)
+
+            if bits[i + 1]:
+                bar = bits[i + 1]
+                if not isinstance(bar, unicode):
+                    bar = bar.decode('utf-8')
+                append(bar)
+        return ''.join(res)
 
 try:
     compat_str = unicode  # Python 2
@@ -388,6 +422,15 @@ else:
             pass
         return _terminal_size(columns, lines)
 
+try:
+    itertools.count(start=0, step=1)
+    compat_itertools_count = itertools.count
+except TypeError:  # Python 2.6
+    def compat_itertools_count(start=0, step=1):
+        n = start
+        while True:
+            yield n
+            n += step
 
 __all__ = [
     'compat_HTTPError',
@@ -401,6 +444,7 @@ __all__ = [
     'compat_html_entities',
     'compat_http_client',
     'compat_http_server',
+    'compat_itertools_count',
     'compat_kwargs',
     'compat_ord',
     'compat_parse_qs',
@@ -410,7 +454,9 @@ __all__ = [
     'compat_subprocess_get_DEVNULL',
     'compat_urllib_error',
     'compat_urllib_parse',
+    'compat_urllib_parse_asciire',
     'compat_urllib_parse_unquote',
+    'compat_urllib_parse_unquote_to_bytes',
     'compat_urllib_parse_urlparse',
     'compat_urllib_request',
     'compat_urlparse',