Merge pull request #8739 from remitamine/update_url_params
[youtube-dl] / youtube_dl / utils.py
index a82a262a03ae268a942fa851eedc4de268fdae8c..d431aa6b726c59b40a9c48b3e3144f1d7a2c8db0 100644 (file)
@@ -4,6 +4,7 @@
 from __future__ import unicode_literals
 
 import base64
+import binascii
 import calendar
 import codecs
 import contextlib
@@ -159,8 +160,6 @@ if sys.version_info >= (2, 7):
     def find_xpath_attr(node, xpath, key, val=None):
         """ Find the xpath xpath[@key=val] """
         assert re.match(r'^[a-zA-Z_-]+$', key)
-        if val:
-            assert re.match(r'^[a-zA-Z0-9@\s:._-]*$', val)
         expr = xpath + ('[@%s]' % key if val is None else "[@%s='%s']" % (key, val))
         return node.find(expr)
 else:
@@ -248,7 +247,7 @@ def xpath_attr(node, xpath, key, name=None, fatal=False, default=NO_DEFAULT):
 
 def get_element_by_id(id, html):
     """Return the content of the tag with the specified ID in the passed HTML document"""
-    return get_element_by_attribute("id", id, html)
+    return get_element_by_attribute('id', id, html)
 
 
 def get_element_by_attribute(attribute, value, html):
@@ -466,6 +465,10 @@ def encodeFilename(s, for_subprocess=False):
     if not for_subprocess and sys.platform == 'win32' and sys.getwindowsversion()[0] >= 5:
         return s
 
+    # Jython assumes filenames are Unicode strings though reported as Python 2.x compatible
+    if sys.platform.startswith('java'):
+        return s
+
     return s.encode(get_subprocess_encoding(), 'ignore')
 
 
@@ -904,9 +907,9 @@ def unified_strdate(date_str, day_first=True):
         '%d %b %Y',
         '%B %d %Y',
         '%b %d %Y',
-        '%b %dst %Y %I:%M%p',
-        '%b %dnd %Y %I:%M%p',
-        '%b %dth %Y %I:%M%p',
+        '%b %dst %Y %I:%M',
+        '%b %dnd %Y %I:%M',
+        '%b %dth %Y %I:%M',
         '%Y %m %d',
         '%Y-%m-%d',
         '%Y/%m/%d',
@@ -994,7 +997,7 @@ def date_from_str(date_str):
         unit += 's'
         delta = datetime.timedelta(**{unit: time})
         return today + delta
-    return datetime.datetime.strptime(date_str, "%Y%m%d").date()
+    return datetime.datetime.strptime(date_str, '%Y%m%d').date()
 
 
 def hyphenate_date(date_str):
@@ -1074,22 +1077,22 @@ def _windows_write_string(s, out):
 
     GetStdHandle = ctypes.WINFUNCTYPE(
         ctypes.wintypes.HANDLE, ctypes.wintypes.DWORD)(
-        (b"GetStdHandle", ctypes.windll.kernel32))
+        (b'GetStdHandle', ctypes.windll.kernel32))
     h = GetStdHandle(WIN_OUTPUT_IDS[fileno])
 
     WriteConsoleW = ctypes.WINFUNCTYPE(
         ctypes.wintypes.BOOL, ctypes.wintypes.HANDLE, ctypes.wintypes.LPWSTR,
         ctypes.wintypes.DWORD, ctypes.POINTER(ctypes.wintypes.DWORD),
-        ctypes.wintypes.LPVOID)((b"WriteConsoleW", ctypes.windll.kernel32))
+        ctypes.wintypes.LPVOID)((b'WriteConsoleW', ctypes.windll.kernel32))
     written = ctypes.wintypes.DWORD(0)
 
-    GetFileType = ctypes.WINFUNCTYPE(ctypes.wintypes.DWORD, ctypes.wintypes.DWORD)((b"GetFileType", ctypes.windll.kernel32))
+    GetFileType = ctypes.WINFUNCTYPE(ctypes.wintypes.DWORD, ctypes.wintypes.DWORD)((b'GetFileType', ctypes.windll.kernel32))
     FILE_TYPE_CHAR = 0x0002
     FILE_TYPE_REMOTE = 0x8000
     GetConsoleMode = ctypes.WINFUNCTYPE(
         ctypes.wintypes.BOOL, ctypes.wintypes.HANDLE,
         ctypes.POINTER(ctypes.wintypes.DWORD))(
-        (b"GetConsoleMode", ctypes.windll.kernel32))
+        (b'GetConsoleMode', ctypes.windll.kernel32))
     INVALID_HANDLE_VALUE = ctypes.wintypes.DWORD(-1).value
 
     def not_a_console(handle):
@@ -1216,13 +1219,23 @@ if sys.platform == 'win32':
             raise OSError('Unlocking file failed: %r' % ctypes.FormatError())
 
 else:
-    import fcntl
+    # Some platforms, such as Jython, is missing fcntl
+    try:
+        import fcntl
 
-    def _lock_file(f, exclusive):
-        fcntl.flock(f, fcntl.LOCK_EX if exclusive else fcntl.LOCK_SH)
+        def _lock_file(f, exclusive):
+            fcntl.flock(f, fcntl.LOCK_EX if exclusive else fcntl.LOCK_SH)
 
-    def _unlock_file(f):
-        fcntl.flock(f, fcntl.LOCK_UN)
+        def _unlock_file(f):
+            fcntl.flock(f, fcntl.LOCK_UN)
+    except ImportError:
+        UNSUPPORTED_MSG = 'file locking is not supported on this platform'
+
+        def _lock_file(f, exclusive):
+            raise IOError(UNSUPPORTED_MSG)
+
+        def _unlock_file(f):
+            raise IOError(UNSUPPORTED_MSG)
 
 
 class locked_file(object):
@@ -1386,8 +1399,14 @@ def fix_xml_ampersands(xml_str):
 
 def setproctitle(title):
     assert isinstance(title, compat_str)
+
+    # ctypes in Jython is not complete
+    # http://bugs.jython.org/issue2148
+    if sys.platform.startswith('java'):
+        return
+
     try:
-        libc = ctypes.cdll.LoadLibrary("libc.so.6")
+        libc = ctypes.cdll.LoadLibrary('libc.so.6')
     except OSError:
         return
     title_bytes = title.encode('utf-8')
@@ -1427,7 +1446,7 @@ def url_basename(url):
 
 class HEADRequest(compat_urllib_request.Request):
     def get_method(self):
-        return "HEAD"
+        return 'HEAD'
 
 
 def int_or_none(v, scale=1, default=None, get_attr=None, invscale=1):
@@ -1569,9 +1588,12 @@ class PagedList(object):
 
 
 class OnDemandPagedList(PagedList):
-    def __init__(self, pagefunc, pagesize):
+    def __init__(self, pagefunc, pagesize, use_cache=False):
         self._pagefunc = pagefunc
         self._pagesize = pagesize
+        self._use_cache = use_cache
+        if use_cache:
+            self._cache = {}
 
     def getslice(self, start=0, end=None):
         res = []
@@ -1581,7 +1603,13 @@ class OnDemandPagedList(PagedList):
             if start >= nextfirstid:
                 continue
 
-            page_results = list(self._pagefunc(pagenum))
+            page_results = None
+            if self._use_cache:
+                page_results = self._cache.get(pagenum)
+            if page_results is None:
+                page_results = list(self._pagefunc(pagenum))
+            if self._use_cache:
+                self._cache[pagenum] = page_results
 
             startv = (
                 start % self._pagesize
@@ -1711,6 +1739,14 @@ def urlencode_postdata(*args, **kargs):
     return compat_urllib_parse.urlencode(*args, **kargs).encode('ascii')
 
 
+def update_url_query(url, query):
+    parsed_url = compat_urlparse.urlparse(url)
+    qs = compat_parse_qs(parsed_url.query)
+    qs.update(query)
+    return compat_urlparse.urlunparse(parsed_url._replace(
+        query=compat_urllib_parse.urlencode(qs, True)))
+
+
 def encode_dict(d, encoding='utf-8'):
     def encode(v):
         return v.encode(encoding) if isinstance(v, compat_basestring) else v
@@ -1744,7 +1780,7 @@ def parse_age_limit(s):
     if s is None:
         return None
     m = re.match(r'^(?P<age>\d{1,2})\+?$', s)
-    return int(m.group('age')) if m else US_RATINGS.get(s, None)
+    return int(m.group('age')) if m else US_RATINGS.get(s)
 
 
 def strip_jsonp(code):
@@ -1835,11 +1871,21 @@ def error_to_compat_str(err):
 
 
 def mimetype2ext(mt):
+    ext = {
+        'audio/mp4': 'm4a',
+    }.get(mt)
+    if ext is not None:
+        return ext
+
     _, _, res = mt.rpartition('/')
 
     return {
         '3gpp': '3gp',
+        'smptett+xml': 'tt',
+        'srt': 'srt',
+        'ttaf+xml': 'dfxp',
         'ttml+xml': 'ttml',
+        'vtt': 'vtt',
         'x-flv': 'flv',
         'x-mp4-fragmented': 'mp4',
         'x-ms-wmv': 'wmv',
@@ -2582,3 +2628,58 @@ class PerRequestProxyHandler(compat_urllib_request.ProxyHandler):
             return None  # No Proxy
         return compat_urllib_request.ProxyHandler.proxy_open(
             self, req, proxy, type)
+
+
+def ohdave_rsa_encrypt(data, exponent, modulus):
+    '''
+    Implement OHDave's RSA algorithm. See http://www.ohdave.com/rsa/
+
+    Input:
+        data: data to encrypt, bytes-like object
+        exponent, modulus: parameter e and N of RSA algorithm, both integer
+    Output: hex string of encrypted data
+
+    Limitation: supports one block encryption only
+    '''
+
+    payload = int(binascii.hexlify(data[::-1]), 16)
+    encrypted = pow(payload, exponent, modulus)
+    return '%x' % encrypted
+
+
+def encode_base_n(num, n, table=None):
+    FULL_TABLE = '0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
+    if not table:
+        table = FULL_TABLE[:n]
+
+    if n > len(table):
+        raise ValueError('base %d exceeds table length %d' % (n, len(table)))
+
+    if num == 0:
+        return table[0]
+
+    ret = ''
+    while num:
+        ret = table[num % n] + ret
+        num = num // n
+    return ret
+
+
+def decode_packed_codes(code):
+    mobj = re.search(
+        r"}\('(.+)',(\d+),(\d+),'([^']+)'\.split\('\|'\)",
+        code)
+    obfucasted_code, base, count, symbols = mobj.groups()
+    base = int(base)
+    count = int(count)
+    symbols = symbols.split('|')
+    symbol_table = {}
+
+    while count:
+        count -= 1
+        base_n_count = encode_base_n(count, base)
+        symbol_table[base_n_count] = symbols[count] or base_n_count
+
+    return re.sub(
+        r'\b(\w+)\b', lambda mobj: symbol_table[mobj.group(0)],
+        obfucasted_code)