[empflix] Adapt to malformed config XML
[youtube-dl] / youtube_dl / utils.py
index 64a9618ca62493f893af16b31b3fbd331bbdc1e7..4ebdf6a784eaee0f02dd81268991c00a015b881a 100644 (file)
@@ -24,6 +24,7 @@ import socket
 import struct
 import subprocess
 import sys
+import tempfile
 import traceback
 import xml.etree.ElementTree
 import zlib
@@ -91,11 +92,9 @@ except ImportError:
     compat_subprocess_get_DEVNULL = lambda: open(os.path.devnull, 'w')
 
 try:
-    from urllib.parse import parse_qs as compat_parse_qs
-except ImportError: # Python 2
-    # HACK: The following is the correct parse_qs implementation from cpython 3's stdlib.
-    # Python 2's version is apparently totally broken
-    def _unquote(string, encoding='utf-8', errors='replace'):
+    from urllib.parse import unquote as compat_urllib_parse_unquote
+except ImportError:
+    def compat_urllib_parse_unquote(string, encoding='utf-8', errors='replace'):
         if string == '':
             return string
         res = string.split('%')
@@ -130,6 +129,13 @@ except ImportError: # Python 2
             string += pct_sequence.decode(encoding, errors)
         return string
 
+
+try:
+    from urllib.parse import parse_qs as compat_parse_qs
+except ImportError: # Python 2
+    # HACK: The following is the correct parse_qs implementation from cpython 3's stdlib.
+    # Python 2's version is apparently totally broken
+
     def _parse_qsl(qs, keep_blank_values=False, strict_parsing=False,
                 encoding='utf-8', errors='replace'):
         qs, _coerce_result = qs, unicode
@@ -149,10 +155,12 @@ except ImportError: # Python 2
                     continue
             if len(nv[1]) or keep_blank_values:
                 name = nv[0].replace('+', ' ')
-                name = _unquote(name, encoding=encoding, errors=errors)
+                name = compat_urllib_parse_unquote(
+                    name, encoding=encoding, errors=errors)
                 name = _coerce_result(name)
                 value = nv[1].replace('+', ' ')
-                value = _unquote(value, encoding=encoding, errors=errors)
+                value = compat_urllib_parse_unquote(
+                    value, encoding=encoding, errors=errors)
                 value = _coerce_result(value)
                 r.append((name, value))
         return r
@@ -184,6 +192,13 @@ try:
 except ImportError:  # Python 2.6
     from xml.parsers.expat import ExpatError as compat_xml_parse_error
 
+try:
+    from shlex import quote as shlex_quote
+except ImportError:  # Python < 3.3
+    def shlex_quote(s):
+        return "'" + s.replace("'", "'\"'\"'") + "'"
+
+
 def compat_ord(c):
     if type(c) is int: return c
     else: return ord(c)
@@ -221,22 +236,46 @@ else:
         assert type(s) == type(u'')
         print(s)
 
-# In Python 2.x, json.dump expects a bytestream.
-# In Python 3.x, it writes to a character stream
-if sys.version_info < (3,0):
-    def write_json_file(obj, fn):
-        with open(fn, 'wb') as f:
-            json.dump(obj, f)
-else:
-    def write_json_file(obj, fn):
-        with open(fn, 'w', encoding='utf-8') as f:
-            json.dump(obj, f)
 
-if sys.version_info >= (2,7):
+def write_json_file(obj, fn):
+    """ Encode obj as JSON and write it to fn, atomically """
+
+    args = {
+        'suffix': '.tmp',
+        'prefix': os.path.basename(fn) + '.',
+        'dir': os.path.dirname(fn),
+        'delete': False,
+    }
+
+    # In Python 2.x, json.dump expects a bytestream.
+    # In Python 3.x, it writes to a character stream
+    if sys.version_info < (3, 0):
+        args['mode'] = 'wb'
+    else:
+        args.update({
+            'mode': 'w',
+            'encoding': 'utf-8',
+        })
+
+    tf = tempfile.NamedTemporaryFile(**args)
+
+    try:
+        with tf:
+            json.dump(obj, tf)
+        os.rename(tf.name, fn)
+    except:
+        try:
+            os.remove(tf.name)
+        except OSError:
+            pass
+        raise
+
+
+if sys.version_info >= (2, 7):
     def find_xpath_attr(node, xpath, key, val):
         """ Find the xpath xpath[@key=val] """
-        assert re.match(r'^[a-zA-Z]+$', key)
-        assert re.match(r'^[a-zA-Z0-9@\s:._]*$', val)
+        assert re.match(r'^[a-zA-Z-]+$', key)
+        assert re.match(r'^[a-zA-Z0-9@\s:._-]*$', val)
         expr = xpath + u"[@%s='%s']" % (key, val)
         return node.find(expr)
 else:
@@ -727,10 +766,9 @@ class YoutubeDLHandler(compat_urllib_request.HTTPHandler):
         return ret
 
     def http_request(self, req):
-        for h,v in std_headers.items():
-            if h in req.headers:
-                del req.headers[h]
-            req.add_header(h, v)
+        for h, v in std_headers.items():
+            if h not in req.headers:
+                req.add_header(h, v)
         if 'Youtubedl-no-compression' in req.headers:
             if 'Accept-encoding' in req.headers:
                 del req.headers['Accept-encoding']
@@ -820,8 +858,10 @@ def unified_strdate(date_str):
         '%b %dnd %Y %I:%M%p',
         '%b %dth %Y %I:%M%p',
         '%Y-%m-%d',
+        '%Y/%m/%d',
         '%d.%m.%Y',
         '%d/%m/%Y',
+        '%d/%m/%y',
         '%Y/%m/%d %H:%M:%S',
         '%Y-%m-%d %H:%M:%S',
         '%d.%m.%Y %H:%M',
@@ -845,6 +885,8 @@ def unified_strdate(date_str):
     return upload_date
 
 def determine_ext(url, default_ext=u'unknown_video'):
+    if url is None:
+        return default_ext
     guess = url.partition(u'?')[0].rpartition(u'.')[2]
     if re.match(r'^[A-Za-z0-9]+$', guess):
         return guess
@@ -1193,11 +1235,6 @@ def format_bytes(bytes):
     return u'%.2f%s' % (converted, suffix)
 
 
-def str_to_int(int_str):
-    int_str = re.sub(r'[,\.]', u'', int_str)
-    return int(int_str)
-
-
 def get_term_width():
     columns = os.environ.get('COLUMNS', None)
     if columns:
@@ -1255,6 +1292,12 @@ def remove_start(s, start):
     return s
 
 
+def remove_end(s, end):
+    if s.endswith(end):
+        return s[:-len(end)]
+    return s
+
+
 def url_basename(url):
     path = compat_urlparse.urlparse(url).path
     return path.strip(u'/').split(u'/')[-1]
@@ -1265,15 +1308,28 @@ class HEADRequest(compat_urllib_request.Request):
         return "HEAD"
 
 
-def int_or_none(v, scale=1, default=None, get_attr=None):
+def int_or_none(v, scale=1, default=None, get_attr=None, invscale=1):
     if get_attr:
         if v is not None:
             v = getattr(v, get_attr, None)
-    return default if v is None else (int(v) // scale)
+    if v == '':
+        v = None
+    return default if v is None else (int(v) * invscale // scale)
+
 
+def str_or_none(v, default=None):
+    return default if v is None else compat_str(v)
 
-def float_or_none(v, scale=1, default=None):
-    return default if v is None else (float(v) / scale)
+
+def str_to_int(int_str):
+    if int_str is None:
+        return None
+    int_str = re.sub(r'[,\.]', u'', int_str)
+    return int(int_str)
+
+
+def float_or_none(v, scale=1, invscale=1, default=None):
+    return default if v is None else (float(v) * invscale / scale)
 
 
 def parse_duration(s):
@@ -1281,7 +1337,7 @@ def parse_duration(s):
         return None
 
     m = re.match(
-        r'(?:(?:(?P<hours>[0-9]+)[:h])?(?P<mins>[0-9]+)[:m])?(?P<secs>[0-9]+)s?(?::[0-9]+)?$', s)
+        r'(?:(?:(?P<hours>[0-9]+)[:h])?(?P<mins>[0-9]+)[:m])?(?P<secs>[0-9]+)s?(?::[0-9]+)?(?P<ms>\.[0-9]+)?$', s)
     if not m:
         return None
     res = int(m.group('secs'))
@@ -1289,6 +1345,8 @@ def parse_duration(s):
         res += int(m.group('mins')) * 60
         if m.group('hours'):
             res += int(m.group('hours')) * 60 * 60
+    if m.group('ms'):
+        res += float(m.group('ms'))
     return res
 
 
@@ -1399,6 +1457,12 @@ def urlencode_postdata(*args, **kargs):
     return compat_urllib_parse.urlencode(*args, **kargs).encode('ascii')
 
 
+try:
+    etree_iter = xml.etree.ElementTree.Element.iter
+except AttributeError:  # Python <=2.6
+    etree_iter = lambda n: n.findall('.//*')
+
+
 def parse_xml(s):
     class TreeBuilder(xml.etree.ElementTree.TreeBuilder):
         def doctype(self, name, pubid, system):
@@ -1406,7 +1470,14 @@ def parse_xml(s):
 
     parser = xml.etree.ElementTree.XMLParser(target=TreeBuilder())
     kwargs = {'parser': parser} if sys.version_info >= (2, 7) else {}
-    return xml.etree.ElementTree.XML(s.encode('utf-8'), **kwargs)
+    tree = xml.etree.ElementTree.XML(s.encode('utf-8'), **kwargs)
+    # Fix up XML parser in Python 2.x
+    if sys.version_info < (3, 0):
+        for n in etree_iter(tree):
+            if n.text is not None:
+                if not isinstance(n.text, compat_str):
+                    n.text = n.text.decode('utf-8')
+    return tree
 
 
 if sys.version_info < (3, 0) and sys.platform == 'win32':
@@ -1431,6 +1502,34 @@ def strip_jsonp(code):
     return re.sub(r'(?s)^[a-zA-Z0-9_]+\s*\(\s*(.*)\);?\s*?\s*$', r'\1', code)
 
 
+def js_to_json(code):
+    def fix_kv(m):
+        key = m.group(2)
+        if key.startswith("'"):
+            assert key.endswith("'")
+            assert '"' not in key
+            key = '"%s"' % key[1:-1]
+        elif not key.startswith('"'):
+            key = '"%s"' % key
+
+        value = m.group(4)
+        if value.startswith("'"):
+            assert value.endswith("'")
+            assert '"' not in value
+            value = '"%s"' % value[1:-1]
+
+        return m.group(1) + key + m.group(3) + value
+
+    res = re.sub(r'''(?x)
+            ([{,]\s*)
+            ("[^"]*"|\'[^\']*\'|[a-z0-9A-Z]+)
+            (:\s*)
+            ([0-9.]+|true|false|"[^"]*"|\'[^\']*\'|\[|\{)
+        ''', fix_kv, code)
+    res = re.sub(r',(\s*\])', lambda m: m.group(1), res)
+    return res
+
+
 def qualities(quality_ids):
     """ Get a numeric quality value out of a list of possible values """
     def q(qid):