Merge pull request #14225 from Tithen-Firion/openload-phantomjs-method
[youtube-dl] / youtube_dl / utils.py
index 94e1b07a627983639553938977c5bd7005649445..9e4492d402c225d53071d1424005ff5be0577681 100644 (file)
@@ -11,6 +11,7 @@ import contextlib
 import ctypes
 import datetime
 import email.utils
+import email.header
 import errno
 import functools
 import gzip
@@ -21,7 +22,6 @@ import locale
 import math
 import operator
 import os
-import pipes
 import platform
 import random
 import re
@@ -35,6 +35,7 @@ import xml.etree.ElementTree
 import zlib
 
 from .compat import (
+    compat_HTMLParseError,
     compat_HTMLParser,
     compat_basestring,
     compat_chr,
@@ -364,9 +365,9 @@ def get_elements_by_attribute(attribute, value, html, escape_value=True):
     retlist = []
     for m in re.finditer(r'''(?xs)
         <([a-zA-Z0-9:._-]+)
-         (?:\s+[a-zA-Z0-9:._-]+(?:=[a-zA-Z0-9:._-]*|="[^"]*"|='[^']*'))*?
+         (?:\s+[a-zA-Z0-9:._-]+(?:=[a-zA-Z0-9:._-]*|="[^"]*"|='[^']*'|))*?
          \s+%s=['"]?%s['"]?
-         (?:\s+[a-zA-Z0-9:._-]+(?:=[a-zA-Z0-9:._-]*|="[^"]*"|='[^']*'))*?
+         (?:\s+[a-zA-Z0-9:._-]+(?:=[a-zA-Z0-9:._-]*|="[^"]*"|='[^']*'|))*?
         \s*>
         (?P<content>.*?)
         </\1>
@@ -408,8 +409,12 @@ def extract_attributes(html_element):
     but the cases in the unit test will work for all of 2.6, 2.7, 3.2-3.5.
     """
     parser = HTMLAttributeParser()
-    parser.feed(html_element)
-    parser.close()
+    try:
+        parser.feed(html_element)
+        parser.close()
+    # Older Python may throw HTMLParseError in case of malformed HTML
+    except compat_HTMLParseError:
+        pass
     return parser.attrs
 
 
@@ -421,8 +426,8 @@ def clean_html(html):
 
     # Newline vs <br />
     html = html.replace('\n', ' ')
-    html = re.sub(r'\s*<\s*br\s*/?\s*>\s*', '\n', html)
-    html = re.sub(r'<\s*/\s*p\s*>\s*<\s*p[^>]*>', '\n', html)
+    html = re.sub(r'(?u)\s*<\s*br\s*/?\s*>\s*', '\n', html)
+    html = re.sub(r'(?u)<\s*/\s*p\s*>\s*<\s*p[^>]*>', '\n', html)
     # Strip html tags
     html = re.sub('<.*?>', '', html)
     # Replace html entities
@@ -591,7 +596,7 @@ def unescapeHTML(s):
     assert type(s) == compat_str
 
     return re.sub(
-        r'&([^;]+;)', lambda m: _htmlentity_transform(m.group(1)), s)
+        r'&([^&;]+;)', lambda m: _htmlentity_transform(m.group(1)), s)
 
 
 def get_subprocess_encoding():
@@ -931,14 +936,6 @@ class YoutubeDLHandler(compat_urllib_request.HTTPHandler):
         except zlib.error:
             return zlib.decompress(data)
 
-    @staticmethod
-    def addinfourl_wrapper(stream, headers, url, code):
-        if hasattr(compat_urllib_request.addinfourl, 'getcode'):
-            return compat_urllib_request.addinfourl(stream, headers, url, code)
-        ret = compat_urllib_request.addinfourl(stream, headers, url)
-        ret.code = code
-        return ret
-
     def http_request(self, req):
         # According to RFC 3986, URLs can not contain non-ASCII characters, however this is not
         # always respected by websites, some tend to give out URLs with non percent-encoded
@@ -990,13 +987,13 @@ class YoutubeDLHandler(compat_urllib_request.HTTPHandler):
                     break
                 else:
                     raise original_ioerror
-            resp = self.addinfourl_wrapper(uncompressed, old_resp.headers, old_resp.url, old_resp.code)
+            resp = compat_urllib_request.addinfourl(uncompressed, old_resp.headers, old_resp.url, old_resp.code)
             resp.msg = old_resp.msg
             del resp.headers['Content-encoding']
         # deflate
         if resp.headers.get('Content-encoding', '') == 'deflate':
             gz = io.BytesIO(self.deflate(resp.read()))
-            resp = self.addinfourl_wrapper(gz, old_resp.headers, old_resp.url, old_resp.code)
+            resp = compat_urllib_request.addinfourl(gz, old_resp.headers, old_resp.url, old_resp.code)
             resp.msg = old_resp.msg
             del resp.headers['Content-encoding']
         # Percent-encode redirect URL of Location HTTP header to satisfy RFC 3986 (see
@@ -1186,7 +1183,7 @@ def unified_timestamp(date_str, day_first=True):
     if date_str is None:
         return None
 
-    date_str = date_str.replace(',', ' ')
+    date_str = re.sub(r'[,|]', '', date_str)
 
     pm_delta = 12 if re.search(r'(?i)PM', date_str) else 0
     timezone, date_str = extract_timezone(date_str)
@@ -1194,6 +1191,11 @@ def unified_timestamp(date_str, day_first=True):
     # Remove AM/PM + timezone
     date_str = re.sub(r'(?i)\s*(?:AM|PM)(?:\s+[A-Z]+)?', '', date_str)
 
+    # Remove unrecognized timezones from ISO 8601 alike timestamps
+    m = re.search(r'\d{1,2}:\d{1,2}(?:\.\d+)?(?P<tz>\s*[A-Z]+)$', date_str)
+    if m:
+        date_str = date_str[:-len(m.group('tz'))]
+
     for expression in date_formats(day_first):
         try:
             dt = datetime.datetime.strptime(date_str, expression) - timezone + datetime.timedelta(hours=pm_delta)
@@ -1532,7 +1534,7 @@ def shell_quote(args):
         if isinstance(a, bytes):
             # We may get a filename encoded with 'encodeFilename'
             a = a.decode(encoding)
-        quoted_args.append(pipes.quote(a))
+        quoted_args.append(compat_shlex_quote(a))
     return ' '.join(quoted_args)
 
 
@@ -1813,6 +1815,10 @@ def float_or_none(v, scale=1, invscale=1, default=None):
         return default
 
 
+def bool_or_none(v, default=None):
+    return v if isinstance(v, bool) else default
+
+
 def strip_or_none(v):
     return None if v is None else v.strip()
 
@@ -2092,6 +2098,58 @@ def update_Request(req, url=None, data=None, headers={}, query={}):
     return new_req
 
 
+def _multipart_encode_impl(data, boundary):
+    content_type = 'multipart/form-data; boundary=%s' % boundary
+
+    out = b''
+    for k, v in data.items():
+        out += b'--' + boundary.encode('ascii') + b'\r\n'
+        if isinstance(k, compat_str):
+            k = k.encode('utf-8')
+        if isinstance(v, compat_str):
+            v = v.encode('utf-8')
+        # RFC 2047 requires non-ASCII field names to be encoded, while RFC 7578
+        # suggests sending UTF-8 directly. Firefox sends UTF-8, too
+        content = b'Content-Disposition: form-data; name="' + k + b'"\r\n\r\n' + v + b'\r\n'
+        if boundary.encode('ascii') in content:
+            raise ValueError('Boundary overlaps with data')
+        out += content
+
+    out += b'--' + boundary.encode('ascii') + b'--\r\n'
+
+    return out, content_type
+
+
+def multipart_encode(data, boundary=None):
+    '''
+    Encode a dict to RFC 7578-compliant form-data
+
+    data:
+        A dict where keys and values can be either Unicode or bytes-like
+        objects.
+    boundary:
+        If specified a Unicode object, it's used as the boundary. Otherwise
+        a random boundary is generated.
+
+    Reference: https://tools.ietf.org/html/rfc7578
+    '''
+    has_specified_boundary = boundary is not None
+
+    while True:
+        if boundary is None:
+            boundary = '---------------' + str(random.randrange(0x0fffffff, 0xffffffff))
+
+        try:
+            out, content_type = _multipart_encode_impl(data, boundary)
+            break
+        except ValueError:
+            if has_specified_boundary:
+                raise
+            boundary = None
+
+    return out, content_type
+
+
 def dict_get(d, key_or_keys, default=None, skip_false_values=True):
     if isinstance(key_or_keys, (list, tuple)):
         for key in key_or_keys:
@@ -2103,13 +2161,16 @@ def dict_get(d, key_or_keys, default=None, skip_false_values=True):
 
 
 def try_get(src, getter, expected_type=None):
-    try:
-        v = getter(src)
-    except (AttributeError, KeyError, TypeError, IndexError):
-        pass
-    else:
-        if expected_type is None or isinstance(v, expected_type):
-            return v
+    if not isinstance(getter, (list, tuple)):
+        getter = [getter]
+    for get in getter:
+        try:
+            v = get(src)
+        except (AttributeError, KeyError, TypeError, IndexError):
+            pass
+        else:
+            if expected_type is None or isinstance(v, expected_type):
+                return v
 
 
 def encode_compat_str(string, encoding=preferredencoding(), errors='strict'):
@@ -2150,7 +2211,12 @@ def parse_age_limit(s):
 
 def strip_jsonp(code):
     return re.sub(
-        r'(?s)^[a-zA-Z0-9_.$]+\s*\(\s*(.*)\);?\s*?(?://[^\n]*)*$', r'\1', code)
+        r'''(?sx)^
+            (?:window\.)?(?P<func_name>[a-zA-Z0-9_.$]+)
+            (?:\s*&&\s*(?P=func_name))?
+            \s*\(\s*(?P<callback_data>.*)\);?
+            \s*?(?://[^\n]*)*$''',
+        r'\g<callback_data>', code)
 
 
 def js_to_json(code):
@@ -2270,10 +2336,8 @@ def mimetype2ext(mt):
     return {
         '3gpp': '3gp',
         'smptett+xml': 'tt',
-        'srt': 'srt',
         'ttaf+xml': 'dfxp',
         'ttml+xml': 'ttml',
-        'vtt': 'vtt',
         'x-flv': 'flv',
         'x-mp4-fragmented': 'mp4',
         'x-ms-wmv': 'wmv',
@@ -2281,11 +2345,11 @@ def mimetype2ext(mt):
         'x-mpegurl': 'm3u8',
         'vnd.apple.mpegurl': 'm3u8',
         'dash+xml': 'mpd',
-        'f4m': 'f4m',
         'f4m+xml': 'f4m',
         'hds+xml': 'f4m',
         'vnd.ms-sstr+xml': 'ism',
         'quicktime': 'mov',
+        'mp2t': 'ts',
     }.get(res, res)
 
 
@@ -2301,11 +2365,11 @@ def parse_codecs(codecs_str):
         if codec in ('avc1', 'avc2', 'avc3', 'avc4', 'vp9', 'vp8', 'hev1', 'hev2', 'h263', 'h264', 'mp4v'):
             if not vcodec:
                 vcodec = full_codec
-        elif codec in ('mp4a', 'opus', 'vorbis', 'mp3', 'aac', 'ac-3'):
+        elif codec in ('mp4a', 'opus', 'vorbis', 'mp3', 'aac', 'ac-3', 'ec-3', 'eac3', 'dtsc', 'dtse', 'dtsh', 'dtsl'):
             if not acodec:
                 acodec = full_codec
         else:
-            write_string('WARNING: Unknown codec %s' % full_codec, sys.stderr)
+            write_string('WARNING: Unknown codec %s\n' % full_codec, sys.stderr)
     if not vcodec and not acodec:
         if len(splited_codecs) == 2:
             return {
@@ -2508,27 +2572,97 @@ def srt_subtitles_timecode(seconds):
 
 
 def dfxp2srt(dfxp_data):
+    LEGACY_NAMESPACES = (
+        ('http://www.w3.org/ns/ttml', [
+            'http://www.w3.org/2004/11/ttaf1',
+            'http://www.w3.org/2006/04/ttaf1',
+            'http://www.w3.org/2006/10/ttaf1',
+        ]),
+        ('http://www.w3.org/ns/ttml#styling', [
+            'http://www.w3.org/ns/ttml#style',
+        ]),
+    )
+
+    SUPPORTED_STYLING = [
+        'color',
+        'fontFamily',
+        'fontSize',
+        'fontStyle',
+        'fontWeight',
+        'textDecoration'
+    ]
+
     _x = functools.partial(xpath_with_ns, ns_map={
         'ttml': 'http://www.w3.org/ns/ttml',
-        'ttaf1': 'http://www.w3.org/2006/10/ttaf1',
-        'ttaf1_0604': 'http://www.w3.org/2006/04/ttaf1',
+        'tts': 'http://www.w3.org/ns/ttml#styling',
     })
 
+    styles = {}
+    default_style = {}
+
     class TTMLPElementParser(object):
-        out = ''
+        _out = ''
+        _unclosed_elements = []
+        _applied_styles = []
 
         def start(self, tag, attrib):
-            if tag in (_x('ttml:br'), _x('ttaf1:br'), 'br'):
-                self.out += '\n'
+            if tag in (_x('ttml:br'), 'br'):
+                self._out += '\n'
+            else:
+                unclosed_elements = []
+                style = {}
+                element_style_id = attrib.get('style')
+                if default_style:
+                    style.update(default_style)
+                if element_style_id:
+                    style.update(styles.get(element_style_id, {}))
+                for prop in SUPPORTED_STYLING:
+                    prop_val = attrib.get(_x('tts:' + prop))
+                    if prop_val:
+                        style[prop] = prop_val
+                if style:
+                    font = ''
+                    for k, v in sorted(style.items()):
+                        if self._applied_styles and self._applied_styles[-1].get(k) == v:
+                            continue
+                        if k == 'color':
+                            font += ' color="%s"' % v
+                        elif k == 'fontSize':
+                            font += ' size="%s"' % v
+                        elif k == 'fontFamily':
+                            font += ' face="%s"' % v
+                        elif k == 'fontWeight' and v == 'bold':
+                            self._out += '<b>'
+                            unclosed_elements.append('b')
+                        elif k == 'fontStyle' and v == 'italic':
+                            self._out += '<i>'
+                            unclosed_elements.append('i')
+                        elif k == 'textDecoration' and v == 'underline':
+                            self._out += '<u>'
+                            unclosed_elements.append('u')
+                    if font:
+                        self._out += '<font' + font + '>'
+                        unclosed_elements.append('font')
+                    applied_style = {}
+                    if self._applied_styles:
+                        applied_style.update(self._applied_styles[-1])
+                    applied_style.update(style)
+                    self._applied_styles.append(applied_style)
+                self._unclosed_elements.append(unclosed_elements)
 
         def end(self, tag):
-            pass
+            if tag not in (_x('ttml:br'), 'br'):
+                unclosed_elements = self._unclosed_elements.pop()
+                for element in reversed(unclosed_elements):
+                    self._out += '</%s>' % element
+                if unclosed_elements and self._applied_styles:
+                    self._applied_styles.pop()
 
         def data(self, data):
-            self.out += data
+            self._out += data
 
         def close(self):
-            return self.out.strip()
+            return self._out.strip()
 
     def parse_node(node):
         target = TTMLPElementParser()
@@ -2536,13 +2670,45 @@ def dfxp2srt(dfxp_data):
         parser.feed(xml.etree.ElementTree.tostring(node))
         return parser.close()
 
+    for k, v in LEGACY_NAMESPACES:
+        for ns in v:
+            dfxp_data = dfxp_data.replace(ns, k)
+
     dfxp = compat_etree_fromstring(dfxp_data.encode('utf-8'))
     out = []
-    paras = dfxp.findall(_x('.//ttml:p')) or dfxp.findall(_x('.//ttaf1:p')) or dfxp.findall(_x('.//ttaf1_0604:p')) or dfxp.findall('.//p')
+    paras = dfxp.findall(_x('.//ttml:p')) or dfxp.findall('.//p')
 
     if not paras:
         raise ValueError('Invalid dfxp/TTML subtitle')
 
+    repeat = False
+    while True:
+        for style in dfxp.findall(_x('.//ttml:style')):
+            style_id = style.get('id')
+            parent_style_id = style.get('style')
+            if parent_style_id:
+                if parent_style_id not in styles:
+                    repeat = True
+                    continue
+                styles[style_id] = styles[parent_style_id].copy()
+            for prop in SUPPORTED_STYLING:
+                prop_val = style.get(_x('tts:' + prop))
+                if prop_val:
+                    styles.setdefault(style_id, {})[prop] = prop_val
+        if repeat:
+            repeat = False
+        else:
+            break
+
+    for p in ('body', 'div'):
+        ele = xpath_element(dfxp, [_x('.//ttml:' + p), './/' + p])
+        if ele is None:
+            continue
+        style = styles.get(ele.get('style'))
+        if not style:
+            continue
+        default_style.update(style)
+
     for para, index in zip(paras, itertools.count(1)):
         begin_time = parse_dfxp_time_expr(para.attrib.get('begin'))
         end_time = parse_dfxp_time_expr(para.attrib.get('end'))
@@ -2571,6 +2737,8 @@ def cli_option(params, command_option, param):
 
 def cli_bool_option(params, command_option, param, true_value='true', false_value='false', separator=None):
     param = params.get(param)
+    if param is None:
+        return []
     assert isinstance(param, bool)
     if separator:
         return [command_option + separator + (true_value if param else false_value)]
@@ -3654,6 +3822,37 @@ def write_xattr(path, key, value):
                         "or the 'xattr' binary.")
 
 
+def cookie_to_dict(cookie):
+    cookie_dict = {
+        'name': cookie.name,
+        'value': cookie.value,
+    };
+    if cookie.port_specified:
+        cookie_dict['port'] = cookie.port
+    if cookie.domain_specified:
+        cookie_dict['domain'] = cookie.domain
+    if cookie.path_specified:
+        cookie_dict['path'] = cookie.path
+    if not cookie.expires is None:
+        cookie_dict['expires'] = cookie.expires
+    if not cookie.secure is None:
+        cookie_dict['secure'] = cookie.secure
+    if not cookie.discard is None:
+        cookie_dict['discard'] = cookie.discard
+    try:
+        if (cookie.has_nonstandard_attr('httpOnly') or
+            cookie.has_nonstandard_attr('httponly') or
+            cookie.has_nonstandard_attr('HttpOnly')):
+            cookie_dict['httponly'] = True
+    except TypeError:
+        pass
+    return cookie_dict
+
+
+def cookie_jar_to_list(cookie_jar):
+    return [cookie_to_dict(cookie) for cookie in cookie_jar]
+
+
 class PhantomJSwrapper(object):
     """PhantomJS wrapper class"""
 
@@ -3674,6 +3873,9 @@ class PhantomJSwrapper(object):
         var fs = require('fs');
         var read = {{ mode: 'r', charset: 'utf-8' }};
         var write = {{ mode: 'w', charset: 'utf-8' }};
+        JSON.parse(fs.read("{cookies}", read)).forEach(function(x) {{
+          phantom.addCookie(x);
+        }});
         page.settings.resourceTimeout = {timeout};
         page.settings.userAgent = "{ua}";
         page.onLoadStarted = function() {{
@@ -3684,6 +3886,7 @@ class PhantomJSwrapper(object):
         }};
         var saveAndExit = function() {{
           fs.write("{html}", page.content, write);
+          fs.write("{cookies}", JSON.stringify(phantom.cookies), write);
           phantom.exit();
         }};
         page.onLoadFinished = function(status) {{
@@ -3697,15 +3900,28 @@ class PhantomJSwrapper(object):
         page.open("");
     '''
 
-    _TMP_FILE_NAMES = ['script', 'html']
+    _TMP_FILE_NAMES = ['script', 'html', 'cookies']
 
-    def __init__(self, extractor, timeout=10000):
+    @staticmethod
+    def _version():
+        return get_exe_version('phantomjs', version_re=r'([0-9.]+)')
+
+    def __init__(self, extractor, required_version=None, timeout=10000):
         self.exe = check_executable('phantomjs', ['-v'])
         if not self.exe:
             raise ExtractorError('PhantomJS executable not found in PATH, '
                                  'download it from http://phantomjs.org',
                                  expected=True)
+
         self.extractor = extractor
+
+        if required_version:
+            version = self._version()
+            if is_outdated_version(version, required_version):
+                self.extractor._downloader.report_warning(
+                    'Your copy of PhantomJS is outdated, update it to version '
+                    '%s or newer if you encounter any errors.' % required_version)
+
         self.options = {
             'timeout': timeout,
         }
@@ -3722,6 +3938,26 @@ class PhantomJSwrapper(object):
             except:
                 pass
 
+    def _save_cookies(self, url):
+        cookies = cookie_jar_to_list(self.extractor._downloader.cookiejar)
+        for cookie in cookies:
+            if 'path' not in cookie:
+                cookie['path'] = '/'
+            if 'domain' not in cookie:
+                cookie['domain'] = compat_urlparse.urlparse(url).netloc
+        with open(self._TMP_FILES['cookies'].name, 'wb') as f:
+            f.write(json.dumps(cookies).encode('utf-8'))
+
+    def _load_cookies(self):
+        with open(self._TMP_FILES['cookies'].name, 'rb') as f:
+            cookies = json.loads(f.read().decode('utf-8'))
+        for cookie in cookies:
+            if cookie['httponly'] is True:
+                cookie['rest'] = { 'httpOnly': None }
+            if 'expiry' in cookie:
+                cookie['expire_time'] = cookie['expiry']
+            self.extractor._set_cookie(**cookie)
+
     def get(self, url, html=None, video_id=None, note=None, note2='Executing JS on webpage', headers={}, jscode='saveAndExit();'):
         """
         Downloads webpage (if needed) and executes JS
@@ -3765,6 +4001,8 @@ class PhantomJSwrapper(object):
         with open(self._TMP_FILES['html'].name, 'wb') as f:
             f.write(html.encode('utf-8'))
 
+        self._save_cookies(url)
+
         replaces = self.options
         replaces['url'] = url
         user_agent = headers.get('User-Agent') or std_headers['User-Agent']
@@ -3791,5 +4029,15 @@ class PhantomJSwrapper(object):
                                  + encodeArgument(err))
         with open(self._TMP_FILES['html'].name, 'rb') as f:
             html = f.read().decode('utf-8')
+
+        self._load_cookies()
+
         return (html, encodeArgument(out))
 
+
+def random_birthday(year_field, month_field, day_field):
+    return {
+        year_field: str(random.randint(1950, 1995)),
+        month_field: str(random.randint(1, 12)),
+        day_field: str(random.randint(1, 31)),
+    }