Merge pull request #8092 from bpfoley/twitter-thumbnail
[youtube-dl] / youtube_dl / utils.py
index 7ce661b09ea428ea05954b355e09ea298e04fb73..ec186918cd8672ada2da2d5521e0ba8b22eb273d 100644 (file)
@@ -35,6 +35,7 @@ import xml.etree.ElementTree
 import zlib
 
 from .compat import (
+    compat_HTMLParser,
     compat_basestring,
     compat_chr,
     compat_etree_fromstring,
@@ -160,8 +161,6 @@ if sys.version_info >= (2, 7):
     def find_xpath_attr(node, xpath, key, val=None):
         """ Find the xpath xpath[@key=val] """
         assert re.match(r'^[a-zA-Z_-]+$', key)
-        if val:
-            assert re.match(r'^[a-zA-Z0-9@\s:._-]*$', val)
         expr = xpath + ('[@%s]' % key if val is None else "[@%s='%s']" % (key, val))
         return node.find(expr)
 else:
@@ -274,6 +273,35 @@ def get_element_by_attribute(attribute, value, html):
 
     return unescapeHTML(res)
 
+class HTMLAttributeParser(compat_HTMLParser):
+    """Trivial HTML parser to gather the attributes for a single element"""
+    def __init__(self):
+        self.attrs = { }
+        compat_HTMLParser.__init__(self)
+
+    def handle_starttag(self, tag, attrs):
+        self.attrs = dict(attrs)
+
+def extract_attributes(html_element):
+    """Given a string for an HTML element such as
+    <el
+         a="foo" B="bar" c="&98;az" d=boz
+         empty= noval entity="&amp;"
+         sq='"' dq="'"
+    >
+    Decode and return a dictionary of attributes.
+    {
+        'a': 'foo', 'b': 'bar', c: 'baz', d: 'boz',
+        'empty': '', 'noval': None, 'entity': '&',
+        'sq': '"', 'dq': '\''
+    }.
+    NB HTMLParser is stricter in Python 2.6 & 3.2 than in later versions,
+    but the cases in the unit test will work for all of 2.6, 2.7, 3.2-3.5.
+    """
+    parser = HTMLAttributeParser()
+    parser.feed(html_element)
+    parser.close()
+    return parser.attrs
 
 def clean_html(html):
     """Clean an HTML snippet into a readable string"""
@@ -467,6 +495,10 @@ def encodeFilename(s, for_subprocess=False):
     if not for_subprocess and sys.platform == 'win32' and sys.getwindowsversion()[0] >= 5:
         return s
 
+    # Jython assumes filenames are Unicode strings though reported as Python 2.x compatible
+    if sys.platform.startswith('java'):
+        return s
+
     return s.encode(get_subprocess_encoding(), 'ignore')
 
 
@@ -905,9 +937,9 @@ def unified_strdate(date_str, day_first=True):
         '%d %b %Y',
         '%B %d %Y',
         '%b %d %Y',
-        '%b %dst %Y %I:%M%p',
-        '%b %dnd %Y %I:%M%p',
-        '%b %dth %Y %I:%M%p',
+        '%b %dst %Y %I:%M',
+        '%b %dnd %Y %I:%M',
+        '%b %dth %Y %I:%M',
         '%Y %m %d',
         '%Y-%m-%d',
         '%Y/%m/%d',
@@ -1217,13 +1249,23 @@ if sys.platform == 'win32':
             raise OSError('Unlocking file failed: %r' % ctypes.FormatError())
 
 else:
-    import fcntl
+    # Some platforms, such as Jython, is missing fcntl
+    try:
+        import fcntl
 
-    def _lock_file(f, exclusive):
-        fcntl.flock(f, fcntl.LOCK_EX if exclusive else fcntl.LOCK_SH)
+        def _lock_file(f, exclusive):
+            fcntl.flock(f, fcntl.LOCK_EX if exclusive else fcntl.LOCK_SH)
 
-    def _unlock_file(f):
-        fcntl.flock(f, fcntl.LOCK_UN)
+        def _unlock_file(f):
+            fcntl.flock(f, fcntl.LOCK_UN)
+    except ImportError:
+        UNSUPPORTED_MSG = 'file locking is not supported on this platform'
+
+        def _lock_file(f, exclusive):
+            raise IOError(UNSUPPORTED_MSG)
+
+        def _unlock_file(f):
+            raise IOError(UNSUPPORTED_MSG)
 
 
 class locked_file(object):
@@ -1304,6 +1346,17 @@ def format_bytes(bytes):
     return '%.2f%s' % (converted, suffix)
 
 
+def lookup_unit_table(unit_table, s):
+    units_re = '|'.join(re.escape(u) for u in unit_table)
+    m = re.match(
+        r'(?P<num>[0-9]+(?:[,.][0-9]*)?)\s*(?P<unit>%s)' % units_re, s)
+    if not m:
+        return None
+    num_str = m.group('num').replace(',', '.')
+    mult = unit_table[m.group('unit')]
+    return int(float(num_str) * mult)
+
+
 def parse_filesize(s):
     if s is None:
         return None
@@ -1347,15 +1400,28 @@ def parse_filesize(s):
         'Yb': 1000 ** 8,
     }
 
-    units_re = '|'.join(re.escape(u) for u in _UNIT_TABLE)
-    m = re.match(
-        r'(?P<num>[0-9]+(?:[,.][0-9]*)?)\s*(?P<unit>%s)' % units_re, s)
-    if not m:
+    return lookup_unit_table(_UNIT_TABLE, s)
+
+
+def parse_count(s):
+    if s is None:
         return None
 
-    num_str = m.group('num').replace(',', '.')
-    mult = _UNIT_TABLE[m.group('unit')]
-    return int(float(num_str) * mult)
+    s = s.strip()
+
+    if re.match(r'^[\d,.]+$', s):
+        return str_to_int(s)
+
+    _UNIT_TABLE = {
+        'k': 1000,
+        'K': 1000,
+        'm': 1000 ** 2,
+        'M': 1000 ** 2,
+        'kk': 1000 ** 2,
+        'KK': 1000 ** 2,
+    }
+
+    return lookup_unit_table(_UNIT_TABLE, s)
 
 
 def month_by_name(name):
@@ -1387,6 +1453,12 @@ def fix_xml_ampersands(xml_str):
 
 def setproctitle(title):
     assert isinstance(title, compat_str)
+
+    # ctypes in Jython is not complete
+    # http://bugs.jython.org/issue2148
+    if sys.platform.startswith('java'):
+        return
+
     try:
         libc = ctypes.cdll.LoadLibrary('libc.so.6')
     except OSError:
@@ -1570,9 +1642,12 @@ class PagedList(object):
 
 
 class OnDemandPagedList(PagedList):
-    def __init__(self, pagefunc, pagesize):
+    def __init__(self, pagefunc, pagesize, use_cache=False):
         self._pagefunc = pagefunc
         self._pagesize = pagesize
+        self._use_cache = use_cache
+        if use_cache:
+            self._cache = {}
 
     def getslice(self, start=0, end=None):
         res = []
@@ -1582,7 +1657,13 @@ class OnDemandPagedList(PagedList):
             if start >= nextfirstid:
                 continue
 
-            page_results = list(self._pagefunc(pagenum))
+            page_results = None
+            if self._use_cache:
+                page_results = self._cache.get(pagenum)
+            if page_results is None:
+                page_results = list(self._pagefunc(pagenum))
+            if self._use_cache:
+                self._cache[pagenum] = page_results
 
             startv = (
                 start % self._pagesize
@@ -1712,6 +1793,15 @@ def urlencode_postdata(*args, **kargs):
     return compat_urllib_parse.urlencode(*args, **kargs).encode('ascii')
 
 
+def update_url_query(url, query):
+    parsed_url = compat_urlparse.urlparse(url)
+    qs = compat_parse_qs(parsed_url.query)
+    qs.update(query)
+    qs = encode_dict(qs)
+    return compat_urlparse.urlunparse(parsed_url._replace(
+        query=compat_urllib_parse.urlencode(qs, True)))
+
+
 def encode_dict(d, encoding='utf-8'):
     def encode(v):
         return v.encode(encoding) if isinstance(v, compat_basestring) else v
@@ -1836,11 +1926,21 @@ def error_to_compat_str(err):
 
 
 def mimetype2ext(mt):
+    ext = {
+        'audio/mp4': 'm4a',
+    }.get(mt)
+    if ext is not None:
+        return ext
+
     _, _, res = mt.rpartition('/')
 
     return {
         '3gpp': '3gp',
+        'smptett+xml': 'tt',
+        'srt': 'srt',
+        'ttaf+xml': 'dfxp',
         'ttml+xml': 'ttml',
+        'vtt': 'vtt',
         'x-flv': 'flv',
         'x-mp4-fragmented': 'mp4',
         'x-ms-wmv': 'wmv',
@@ -2600,3 +2700,41 @@ def ohdave_rsa_encrypt(data, exponent, modulus):
     payload = int(binascii.hexlify(data[::-1]), 16)
     encrypted = pow(payload, exponent, modulus)
     return '%x' % encrypted
+
+
+def encode_base_n(num, n, table=None):
+    FULL_TABLE = '0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
+    if not table:
+        table = FULL_TABLE[:n]
+
+    if n > len(table):
+        raise ValueError('base %d exceeds table length %d' % (n, len(table)))
+
+    if num == 0:
+        return table[0]
+
+    ret = ''
+    while num:
+        ret = table[num % n] + ret
+        num = num // n
+    return ret
+
+
+def decode_packed_codes(code):
+    mobj = re.search(
+        r"}\('(.+)',(\d+),(\d+),'([^']+)'\.split\('\|'\)",
+        code)
+    obfucasted_code, base, count, symbols = mobj.groups()
+    base = int(base)
+    count = int(count)
+    symbols = symbols.split('|')
+    symbol_table = {}
+
+    while count:
+        count -= 1
+        base_n_count = encode_base_n(count, base)
+        symbol_table[base_n_count] = symbols[count] or base_n_count
+
+    return re.sub(
+        r'\b(\w+)\b', lambda mobj: symbol_table[mobj.group(0)],
+        obfucasted_code)