X-Git-Url: http://git.bitcoin.ninja/index.cgi?a=blobdiff_plain;f=youtube_dl%2Futils.py;h=1013f7c1879e0afac035327e1e60c13ee295ee21;hb=406224be5231e602b543579706ad6056b75fbe68;hp=7de7742e3c2357b8ed15e68656c0c7ea5ca389d5;hpb=aa49acd15a92faa5cfc1d2876821743f86440c13;p=youtube-dl diff --git a/youtube_dl/utils.py b/youtube_dl/utils.py index 7de7742e3..1013f7c18 100644 --- a/youtube_dl/utils.py +++ b/youtube_dl/utils.py @@ -37,6 +37,7 @@ from .compat import ( compat_chr, compat_html_entities, compat_http_client, + compat_kwargs, compat_parse_qs, compat_socket_create_connection, compat_str, @@ -114,7 +115,7 @@ def write_json_file(obj, fn): 'encoding': 'utf-8', }) - tf = tempfile.NamedTemporaryFile(**args) + tf = tempfile.NamedTemporaryFile(**compat_kwargs(args)) try: with tf: @@ -1128,15 +1129,6 @@ def shell_quote(args): return ' '.join(quoted_args) -def takewhile_inclusive(pred, seq): - """ Like itertools.takewhile, but include the latest evaluated element - (the first element so that Not pred(e)) """ - for e in seq: - yield e - if not pred(e): - return - - def smuggle_url(url, data): """ Pass additional data in a URL for internal use. """ @@ -1357,9 +1349,19 @@ def parse_duration(s): return res -def prepend_extension(filename, ext): +def prepend_extension(filename, ext, expected_real_ext=None): + name, real_ext = os.path.splitext(filename) + return ( + '{0}.{1}{2}'.format(name, ext, real_ext) + if not expected_real_ext or real_ext[1:] == expected_real_ext + else '{0}.{1}'.format(filename, ext)) + + +def replace_extension(filename, ext, expected_real_ext=None): name, real_ext = os.path.splitext(filename) - return '{0}.{1}{2}'.format(name, ext, real_ext) + return '{0}.{1}'.format( + name if not expected_real_ext or real_ext[1:] == expected_real_ext else filename, + ext) def check_executable(exe, args=[]): @@ -1484,6 +1486,14 @@ def uppercase_escape(s): s) +def lowercase_escape(s): + unicode_escape = codecs.getdecoder('unicode_escape') + return re.sub( + r'\\u[0-9a-fA-F]{4}', + lambda m: unicode_escape(m.group(0))[0], + s) + + def escape_rfc3986(s): """Escape non-ASCII characters as suggested by RFC 3986""" if sys.version_info < (3, 0) and isinstance(s, compat_str):