Merge branch 'brightcove_in_page_embed' of https://github.com/remitamine/youtube...
[youtube-dl] / youtube_dl / utils.py
index ae813099dded05c15ed57d27e51be7704e778520..65556d056a7edfe12b94e82d0f55baa93360aa70 100644 (file)
@@ -3,6 +3,7 @@
 
 from __future__ import unicode_literals
 
+import base64
 import calendar
 import codecs
 import contextlib
@@ -35,6 +36,7 @@ import zlib
 from .compat import (
     compat_basestring,
     compat_chr,
+    compat_etree_fromstring,
     compat_html_entities,
     compat_http_client,
     compat_kwargs,
@@ -139,21 +141,24 @@ def write_json_file(obj, fn):
 
 
 if sys.version_info >= (2, 7):
-    def find_xpath_attr(node, xpath, key, val):
+    def find_xpath_attr(node, xpath, key, val=None):
         """ Find the xpath xpath[@key=val] """
-        assert re.match(r'^[a-zA-Z-]+$', key)
-        assert re.match(r'^[a-zA-Z0-9@\s:._-]*$', val)
-        expr = xpath + "[@%s='%s']" % (key, val)
+        assert re.match(r'^[a-zA-Z_-]+$', key)
+        if val:
+            assert re.match(r'^[a-zA-Z0-9@\s:._-]*$', val)
+        expr = xpath + ('[@%s]' % key if val is None else "[@%s='%s']" % (key, val))
         return node.find(expr)
 else:
-    def find_xpath_attr(node, xpath, key, val):
+    def find_xpath_attr(node, xpath, key, val=None):
         # Here comes the crazy part: In 2.6, if the xpath is a unicode,
         # .//node does not match if a node is a direct child of . !
         if isinstance(xpath, compat_str):
             xpath = xpath.encode('ascii')
 
         for f in node.findall(xpath):
-            if f.attrib.get(key) == val:
+            if key not in f.attrib:
+                continue
+            if val is None or f.attrib.get(key) == val:
                 return f
         return None
 
@@ -173,12 +178,21 @@ def xpath_with_ns(path, ns_map):
     return '/'.join(replaced)
 
 
-def xpath_text(node, xpath, name=None, fatal=False, default=NO_DEFAULT):
-    if sys.version_info < (2, 7):  # Crazy 2.6
-        xpath = xpath.encode('ascii')
+def xpath_element(node, xpath, name=None, fatal=False, default=NO_DEFAULT):
+    def _find_xpath(xpath):
+        if sys.version_info < (2, 7):  # Crazy 2.6
+            xpath = xpath.encode('ascii')
+        return node.find(xpath)
 
-    n = node.find(xpath)
-    if n is None or n.text is None:
+    if isinstance(xpath, (str, compat_str)):
+        n = _find_xpath(xpath)
+    else:
+        for xp in xpath:
+            n = _find_xpath(xp)
+            if n is not None:
+                break
+
+    if n is None:
         if default is not NO_DEFAULT:
             return default
         elif fatal:
@@ -186,9 +200,37 @@ def xpath_text(node, xpath, name=None, fatal=False, default=NO_DEFAULT):
             raise ExtractorError('Could not find XML element %s' % name)
         else:
             return None
+    return n
+
+
+def xpath_text(node, xpath, name=None, fatal=False, default=NO_DEFAULT):
+    n = xpath_element(node, xpath, name, fatal=fatal, default=default)
+    if n is None or n == default:
+        return n
+    if n.text is None:
+        if default is not NO_DEFAULT:
+            return default
+        elif fatal:
+            name = xpath if name is None else name
+            raise ExtractorError('Could not find XML element\'s text %s' % name)
+        else:
+            return None
     return n.text
 
 
+def xpath_attr(node, xpath, key, name=None, fatal=False, default=NO_DEFAULT):
+    n = find_xpath_attr(node, xpath, key)
+    if n is None:
+        if default is not NO_DEFAULT:
+            return default
+        elif fatal:
+            name = '%s[@%s]' % (xpath, key) if name is None else name
+            raise ExtractorError('Could not find XML attribute %s' % name)
+        else:
+            return None
+    return n.attrib[key]
+
+
 def get_element_by_id(id, html):
     """Return the content of the tag with the specified ID in the passed HTML document"""
     return get_element_by_attribute("id", id, html)
@@ -217,6 +259,15 @@ def get_element_by_attribute(attribute, value, html):
     return unescapeHTML(res)
 
 
+def extract_attributes(attributes_str, attributes_regex=r'(?s)\s*([^\s=]+)\s*=\s*["\']([^"\']+)["\']'):
+    attributes = re.findall(attributes_regex, attributes_str)
+    attributes_dict = {}
+    if attributes:
+        for (attribute_name, attribute_value) in attributes:
+            attributes_dict[attribute_name] = attribute_value
+    return attributes_dict
+
+
 def clean_html(html):
     """Clean an HTML snippet into a readable string"""
 
@@ -324,7 +375,7 @@ def sanitize_path(s):
     if drive_or_unc:
         norm_path.pop(0)
     sanitized_path = [
-        path_part if path_part in ['.', '..'] else re.sub('(?:[/<>:"\\|\\\\?\\*]|\.$)', '#', path_part)
+        path_part if path_part in ['.', '..'] else re.sub('(?:[/<>:"\\|\\\\?\\*]|[\s.]$)', '#', path_part)
         for path_part in norm_path]
     if drive_or_unc:
         sanitized_path.insert(0, drive_or_unc + os.path.sep)
@@ -576,16 +627,19 @@ class ContentTooShortError(Exception):
     download is too small for what the server announced first, indicating
     the connection was probably interrupted.
     """
-    # Both in bytes
-    downloaded = None
-    expected = None
 
     def __init__(self, downloaded, expected):
+        # Both in bytes
         self.downloaded = downloaded
         self.expected = expected
 
 
 def _create_http_connection(ydl_handler, http_class, is_https, *args, **kwargs):
+    # Working around python 2 bug (see http://bugs.python.org/issue17849) by limiting
+    # expected HTTP responses to meet HTTP/1.0 or later (see also
+    # https://github.com/rg3/youtube-dl/issues/6727)
+    if sys.version_info < (3, 0):
+        kwargs[b'strict'] = True
     hc = http_class(*args, **kwargs)
     source_address = ydl_handler._params.get('source_address')
     if source_address is not None:
@@ -650,6 +704,26 @@ class YoutubeDLHandler(compat_urllib_request.HTTPHandler):
         return ret
 
     def http_request(self, req):
+        # According to RFC 3986, URLs can not contain non-ASCII characters, however this is not
+        # always respected by websites, some tend to give out URLs with non percent-encoded
+        # non-ASCII characters (see telemb.py, ard.py [#3412])
+        # urllib chokes on URLs with non-ASCII characters (see http://bugs.python.org/issue3991)
+        # To work around aforementioned issue we will replace request's original URL with
+        # percent-encoded one
+        # Since redirects are also affected (e.g. http://www.southpark.de/alle-episoden/s18e09)
+        # the code of this workaround has been moved here from YoutubeDL.urlopen()
+        url = req.get_full_url()
+        url_escaped = escape_url(url)
+
+        # Substitute URL if any change after escaping
+        if url != url_escaped:
+            req_type = HEADRequest if req.get_method() == 'HEAD' else compat_urllib_request.Request
+            new_req = req_type(
+                url_escaped, data=req.data, headers=req.headers,
+                origin_req_host=req.origin_req_host, unverifiable=req.unverifiable)
+            new_req.timeout = req.timeout
+            req = new_req
+
         for h, v in std_headers.items():
             # Capitalize is needed because of Python bug 2275: http://bugs.python.org/issue2275
             # The dict keys are capitalized because of this bug by urllib
@@ -694,6 +768,18 @@ class YoutubeDLHandler(compat_urllib_request.HTTPHandler):
             gz = io.BytesIO(self.deflate(resp.read()))
             resp = self.addinfourl_wrapper(gz, old_resp.headers, old_resp.url, old_resp.code)
             resp.msg = old_resp.msg
+        # Percent-encode redirect URL of Location HTTP header to satisfy RFC 3986 (see
+        # https://github.com/rg3/youtube-dl/issues/6457).
+        if 300 <= resp.code < 400:
+            location = resp.headers.get('Location')
+            if location:
+                # As of RFC 2616 default charset is iso-8859-1 that is respected by python 3
+                if sys.version_info >= (3, 0):
+                    location = location.encode('iso-8859-1').decode('utf-8')
+                location_escaped = escape_url(location)
+                if location != location_escaped:
+                    del resp.headers['Location']
+                    resp.headers['Location'] = location_escaped
         return resp
 
     https_request = http_request
@@ -717,15 +803,41 @@ class YoutubeDLHTTPSHandler(compat_urllib_request.HTTPSHandler):
             req, **kwargs)
 
 
+class YoutubeDLCookieProcessor(compat_urllib_request.HTTPCookieProcessor):
+    def __init__(self, cookiejar=None):
+        compat_urllib_request.HTTPCookieProcessor.__init__(self, cookiejar)
+
+    def http_response(self, request, response):
+        # Python 2 will choke on next HTTP request in row if there are non-ASCII
+        # characters in Set-Cookie HTTP header of last response (see
+        # https://github.com/rg3/youtube-dl/issues/6769).
+        # In order to at least prevent crashing we will percent encode Set-Cookie
+        # header before HTTPCookieProcessor starts processing it.
+        # if sys.version_info < (3, 0) and response.headers:
+        #     for set_cookie_header in ('Set-Cookie', 'Set-Cookie2'):
+        #         set_cookie = response.headers.get(set_cookie_header)
+        #         if set_cookie:
+        #             set_cookie_escaped = compat_urllib_parse.quote(set_cookie, b"%/;:@&=+$,!~*'()?#[] ")
+        #             if set_cookie != set_cookie_escaped:
+        #                 del response.headers[set_cookie_header]
+        #                 response.headers[set_cookie_header] = set_cookie_escaped
+        return compat_urllib_request.HTTPCookieProcessor.http_response(self, request, response)
+
+    https_request = compat_urllib_request.HTTPCookieProcessor.http_request
+    https_response = http_response
+
+
 def parse_iso8601(date_str, delimiter='T', timezone=None):
     """ Return a UNIX timestamp from the given date """
 
     if date_str is None:
         return None
 
+    date_str = re.sub(r'\.[0-9]+', '', date_str)
+
     if timezone is None:
         m = re.search(
-            r'(\.[0-9]+)?(?:Z$| ?(?P<sign>\+|-)(?P<hours>[0-9]{2}):?(?P<minutes>[0-9]{2})$)',
+            r'(?:Z$| ?(?P<sign>\+|-)(?P<hours>[0-9]{2}):?(?P<minutes>[0-9]{2})$)',
             date_str)
         if not m:
             timezone = datetime.timedelta()
@@ -738,9 +850,12 @@ def parse_iso8601(date_str, delimiter='T', timezone=None):
                 timezone = datetime.timedelta(
                     hours=sign * int(m.group('hours')),
                     minutes=sign * int(m.group('minutes')))
-    date_format = '%Y-%m-%d{0}%H:%M:%S'.format(delimiter)
-    dt = datetime.datetime.strptime(date_str, date_format) - timezone
-    return calendar.timegm(dt.timetuple())
+    try:
+        date_format = '%Y-%m-%d{0}%H:%M:%S'.format(delimiter)
+        dt = datetime.datetime.strptime(date_str, date_format) - timezone
+        return calendar.timegm(dt.timetuple())
+    except ValueError:
+        pass
 
 
 def unified_strdate(date_str, day_first=True):
@@ -805,7 +920,8 @@ def unified_strdate(date_str, day_first=True):
         timetuple = email.utils.parsedate_tz(date_str)
         if timetuple:
             upload_date = datetime.datetime(*timetuple[:6]).strftime('%Y%m%d')
-    return upload_date
+    if upload_date is not None:
+        return compat_str(upload_date)
 
 
 def determine_ext(url, default_ext='unknown_video'):
@@ -1281,7 +1397,12 @@ def int_or_none(v, scale=1, default=None, get_attr=None, invscale=1):
             v = getattr(v, get_attr, None)
     if v == '':
         v = None
-    return default if v is None else (int(v) * invscale // scale)
+    if v is None:
+        return default
+    try:
+        return int(v) * invscale // scale
+    except ValueError:
+        return default
 
 
 def str_or_none(v, default=None):
@@ -1297,7 +1418,12 @@ def str_to_int(int_str):
 
 
 def float_or_none(v, scale=1, invscale=1, default=None):
-    return default if v is None else (float(v) * invscale / scale)
+    if v is None:
+        return default
+    try:
+        return float(v) * invscale / scale
+    except ValueError:
+        return default
 
 
 def parse_duration(s):
@@ -1546,27 +1672,8 @@ def urlencode_postdata(*args, **kargs):
     return compat_urllib_parse.urlencode(*args, **kargs).encode('ascii')
 
 
-try:
-    etree_iter = xml.etree.ElementTree.Element.iter
-except AttributeError:  # Python <=2.6
-    etree_iter = lambda n: n.findall('.//*')
-
-
-def parse_xml(s):
-    class TreeBuilder(xml.etree.ElementTree.TreeBuilder):
-        def doctype(self, name, pubid, system):
-            pass  # Ignore doctypes
-
-    parser = xml.etree.ElementTree.XMLParser(target=TreeBuilder())
-    kwargs = {'parser': parser} if sys.version_info >= (2, 7) else {}
-    tree = xml.etree.ElementTree.XML(s.encode('utf-8'), **kwargs)
-    # Fix up XML parser in Python 2.x
-    if sys.version_info < (3, 0):
-        for n in etree_iter(tree):
-            if n.text is not None:
-                if not isinstance(n.text, compat_str):
-                    n.text = n.text.decode('utf-8')
-    return tree
+def encode_dict(d, encoding='utf-8'):
+    return dict((k.encode(encoding), v.encode(encoding)) for k, v in d.items())
 
 
 US_RATINGS = {
@@ -1596,8 +1703,8 @@ def js_to_json(code):
         if v in ('true', 'false', 'null'):
             return v
         if v.startswith('"'):
-            return v
-        if v.startswith("'"):
+            v = re.sub(r"\\'", "'", v[1:-1])
+        elif v.startswith("'"):
             v = v[1:-1]
             v = re.sub(r"\\\\|\\'|\"", lambda m: {
                 '\\\\': '\\\\',
@@ -1691,6 +1798,10 @@ def urlhandle_detect_ext(url_handle):
     return mimetype2ext(getheader('Content-Type'))
 
 
+def encode_data_uri(data, mime_type):
+    return 'data:%s;base64,%s' % (mime_type, base64.b64encode(data).decode('ascii'))
+
+
 def age_restricted(content_limit, age_limit):
     """ Returns True iff the content should be blocked """
 
@@ -1865,7 +1976,7 @@ def dfxp2srt(dfxp_data):
 
         return out
 
-    dfxp = xml.etree.ElementTree.fromstring(dfxp_data.encode('utf-8'))
+    dfxp = compat_etree_fromstring(dfxp_data.encode('utf-8'))
     out = []
     paras = dfxp.findall(_x('.//ttml:p')) or dfxp.findall(_x('.//ttaf1:p')) or dfxp.findall('.//p')
 
@@ -1886,6 +1997,32 @@ def dfxp2srt(dfxp_data):
     return ''.join(out)
 
 
+def cli_option(params, command_option, param):
+    param = params.get(param)
+    return [command_option, param] if param is not None else []
+
+
+def cli_bool_option(params, command_option, param, true_value='true', false_value='false', separator=None):
+    param = params.get(param)
+    assert isinstance(param, bool)
+    if separator:
+        return [command_option + separator + (true_value if param else false_value)]
+    return [command_option, true_value if param else false_value]
+
+
+def cli_valueless_option(params, command_option, param, expected_value=True):
+    param = params.get(param)
+    return [command_option] if param == expected_value else []
+
+
+def cli_configuration_args(params, param, default=[]):
+    ex_args = params.get(param)
+    if ex_args is None:
+        return default
+    assert isinstance(ex_args, list)
+    return ex_args
+
+
 class ISO639Utils(object):
     # See http://www.loc.gov/standards/iso639-2/ISO-639-2_utf-8.txt
     _lang_map = {