[extractor/generic] Fix following incomplete redirects (#5640)
[youtube-dl] / youtube_dl / utils.py
index 7de7742e3c2357b8ed15e68656c0c7ea5ca389d5..1013f7c1879e0afac035327e1e60c13ee295ee21 100644 (file)
@@ -37,6 +37,7 @@ from .compat import (
     compat_chr,
     compat_html_entities,
     compat_http_client,
+    compat_kwargs,
     compat_parse_qs,
     compat_socket_create_connection,
     compat_str,
@@ -114,7 +115,7 @@ def write_json_file(obj, fn):
             'encoding': 'utf-8',
         })
 
-    tf = tempfile.NamedTemporaryFile(**args)
+    tf = tempfile.NamedTemporaryFile(**compat_kwargs(args))
 
     try:
         with tf:
@@ -1128,15 +1129,6 @@ def shell_quote(args):
     return ' '.join(quoted_args)
 
 
-def takewhile_inclusive(pred, seq):
-    """ Like itertools.takewhile, but include the latest evaluated element
-        (the first element so that Not pred(e)) """
-    for e in seq:
-        yield e
-        if not pred(e):
-            return
-
-
 def smuggle_url(url, data):
     """ Pass additional data in a URL for internal use. """
 
@@ -1357,9 +1349,19 @@ def parse_duration(s):
     return res
 
 
-def prepend_extension(filename, ext):
+def prepend_extension(filename, ext, expected_real_ext=None):
+    name, real_ext = os.path.splitext(filename)
+    return (
+        '{0}.{1}{2}'.format(name, ext, real_ext)
+        if not expected_real_ext or real_ext[1:] == expected_real_ext
+        else '{0}.{1}'.format(filename, ext))
+
+
+def replace_extension(filename, ext, expected_real_ext=None):
     name, real_ext = os.path.splitext(filename)
-    return '{0}.{1}{2}'.format(name, ext, real_ext)
+    return '{0}.{1}'.format(
+        name if not expected_real_ext or real_ext[1:] == expected_real_ext else filename,
+        ext)
 
 
 def check_executable(exe, args=[]):
@@ -1484,6 +1486,14 @@ def uppercase_escape(s):
         s)
 
 
+def lowercase_escape(s):
+    unicode_escape = codecs.getdecoder('unicode_escape')
+    return re.sub(
+        r'\\u[0-9a-fA-F]{4}',
+        lambda m: unicode_escape(m.group(0))[0],
+        s)
+
+
 def escape_rfc3986(s):
     """Escape non-ASCII characters as suggested by RFC 3986"""
     if sys.version_info < (3, 0) and isinstance(s, compat_str):