hack for apparently broken parse_qs in python2
[youtube-dl] / youtube_dl / utils.py
index 922e17eccfac611a1d90bf83e913383c9afce30d..cf78e9dc843d7bfba008c92a1954add3dca633b4 100644 (file)
 # -*- coding: utf-8 -*-
 
 import gzip
-import htmlentitydefs
-import HTMLParser
+import io
 import locale
 import os
 import re
 import sys
 import zlib
-import urllib2
 import email.utils
 import json
 
 try:
-       import cStringIO as StringIO
-except ImportError:
-       import StringIO
+       import urllib.request as compat_urllib_request
+except ImportError: # Python 2
+       import urllib2 as compat_urllib_request
+
+try:
+       import urllib.error as compat_urllib_error
+except ImportError: # Python 2
+       import urllib2 as compat_urllib_error
+
+try:
+       import urllib.parse as compat_urllib_parse
+except ImportError: # Python 2
+       import urllib as compat_urllib_parse
+
+try:
+       import http.cookiejar as compat_cookiejar
+except ImportError: # Python 2
+       import cookielib as compat_cookiejar
+
+try:
+       import html.entities as compat_html_entities
+except ImportError: # Python 2
+       import htmlentitydefs as compat_html_entities
+
+try:
+       import html.parser as compat_html_parser
+except ImportError: # Python 2
+       import HTMLParser as compat_html_parser
+
+try:
+       import http.client as compat_http_client
+except ImportError: # Python 2
+       import httplib as compat_http_client
+
+try:
+       from urllib.parse import parse_qs as compat_parse_qs
+except ImportError: # Python 2
+       # HACK: The following is the correct parse_qs implementation from cpython 3's stdlib.
+       # Python 2's version is apparently totally broken
+       def _unquote(string, encoding='utf-8', errors='replace'):
+               if string == '':
+                       return string
+               res = string.split('%')
+               if len(res) == 1:
+                       return string
+               if encoding is None:
+                       encoding = 'utf-8'
+               if errors is None:
+                       errors = 'replace'
+               # pct_sequence: contiguous sequence of percent-encoded bytes, decoded
+               pct_sequence = b''
+               string = res[0]
+               for item in res[1:]:
+                       try:
+                               if not item:
+                                       raise ValueError
+                               pct_sequence += item[:2].decode('hex')
+                               rest = item[2:]
+                               if not rest:
+                                       # This segment was just a single percent-encoded character.
+                                       # May be part of a sequence of code units, so delay decoding.
+                                       # (Stored in pct_sequence).
+                                       continue
+                       except ValueError:
+                               rest = '%' + item
+                       # Encountered non-percent-encoded characters. Flush the current
+                       # pct_sequence.
+                       string += pct_sequence.decode(encoding, errors) + rest
+                       pct_sequence = b''
+               if pct_sequence:
+                       # Flush the final pct_sequence
+                       string += pct_sequence.decode(encoding, errors)
+               return string
+
+       def _parse_qsl(qs, keep_blank_values=False, strict_parsing=False,
+                               encoding='utf-8', errors='replace'):
+               qs, _coerce_result = qs, unicode
+               pairs = [s2 for s1 in qs.split('&') for s2 in s1.split(';')]
+               r = []
+               for name_value in pairs:
+                       if not name_value and not strict_parsing:
+                               continue
+                       nv = name_value.split('=', 1)
+                       if len(nv) != 2:
+                               if strict_parsing:
+                                       raise ValueError("bad query field: %r" % (name_value,))
+                               # Handle case of a control-name with no equal sign
+                               if keep_blank_values:
+                                       nv.append('')
+                               else:
+                                       continue
+                       if len(nv[1]) or keep_blank_values:
+                               name = nv[0].replace('+', ' ')
+                               name = _unquote(name, encoding=encoding, errors=errors)
+                               name = _coerce_result(name)
+                               value = nv[1].replace('+', ' ')
+                               value = _unquote(value, encoding=encoding, errors=errors)
+                               value = _coerce_result(value)
+                               r.append((name, value))
+               return r
+
+       def compat_parse_qs(qs, keep_blank_values=False, strict_parsing=False,
+                               encoding='utf-8', errors='replace'):
+               parsed_result = {}
+               pairs = _parse_qsl(qs, keep_blank_values, strict_parsing,
+                                               encoding=encoding, errors=errors)
+               for name, value in pairs:
+                       if name in parsed_result:
+                               parsed_result[name].append(value)
+                       else:
+                               parsed_result[name] = [value]
+               return parsed_result
+
+try:
+       compat_str = unicode # Python 2
+except NameError:
+       compat_str = str
+
+try:
+       compat_chr = unichr # Python 2
+except NameError:
+       compat_chr = chr
 
 std_headers = {
-       'User-Agent': 'Mozilla/5.0 (X11; Linux x86_64; rv:5.0.1) Gecko/20100101 Firefox/5.0.1',
+       'User-Agent': 'Mozilla/5.0 (X11; Linux x86_64; rv:10.0) Gecko/20100101 Firefox/10.0',
        'Accept-Charset': 'ISO-8859-1,utf-8;q=0.7,*;q=0.7',
        'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8',
        'Accept-Encoding': 'gzip, deflate',
        'Accept-Language': 'en-us,en;q=0.5',
 }
-
 def preferredencoding():
        """Get preferred encoding.
 
        Returns the best encoding scheme for the system, based on
        locale.getpreferredencoding() and some further tweaks.
        """
-       def yield_preferredencoding():
-               try:
-                       pref = locale.getpreferredencoding()
-                       u'TEST'.encode(pref)
-               except:
-                       pref = 'UTF-8'
-               while True:
-                       yield pref
-       return yield_preferredencoding().next()
+       try:
+               pref = locale.getpreferredencoding()
+               u'TEST'.encode(pref)
+       except:
+               pref = 'UTF-8'
+
+       return pref
 
+if sys.version_info < (3,0):
+       def compat_print(s):
+               print(s.encode(preferredencoding(), 'xmlcharrefreplace'))
+else:
+       def compat_print(s):
+               assert type(s) == type(u'')
+               print(s)
 
 def htmlentity_transform(matchobj):
-       """Transforms an HTML entity to a Unicode character.
+       """Transforms an HTML entity to a character.
 
        This function receives a match object and is intended to be used with
        the re.sub() function.
@@ -52,11 +173,10 @@ def htmlentity_transform(matchobj):
        entity = matchobj.group(1)
 
        # Known non-numeric HTML entity
-       if entity in htmlentitydefs.name2codepoint:
-               return unichr(htmlentitydefs.name2codepoint[entity])
+       if entity in compat_html_entities.name2codepoint:
+               return compat_chr(compat_html_entities.name2codepoint[entity])
 
-       # Unicode character
-       mobj = re.match(ur'(?u)#(x?\d+)', entity)
+       mobj = re.match(u'(?u)#(x?\\d+)', entity)
        if mobj is not None:
                numstr = mobj.group(1)
                if numstr.startswith(u'x'):
@@ -64,13 +184,13 @@ def htmlentity_transform(matchobj):
                        numstr = u'0%s' % numstr
                else:
                        base = 10
-               return unichr(long(numstr, base))
+               return compat_chr(int(numstr, base))
 
        # Unknown entity in name, return its literal representation
        return (u'&%s;' % entity)
 
-HTMLParser.locatestarttagend = re.compile(r"""<[a-zA-Z][-.a-zA-Z0-9:_]*(?:\s+(?:(?<=['"\s])[^\s/>][^\s/=>]*(?:\s*=+\s*(?:'[^']*'|"[^"]*"|(?!['"])[^>\s]*))?\s*)*)?\s*""", re.VERBOSE) # backport bugfix
-class IDParser(HTMLParser.HTMLParser):
+compat_html_parser.locatestarttagend = re.compile(r"""<[a-zA-Z][-.a-zA-Z0-9:_]*(?:\s+(?:(?<=['"\s])[^\s/>][^\s/=>]*(?:\s*=+\s*(?:'[^']*'|"[^"]*"|(?!['"])[^>\s]*))?\s*)*)?\s*""", re.VERBOSE) # backport bugfix
+class IDParser(compat_html_parser.HTMLParser):
        """Modified HTMLParser that isolates a tag with the specified id"""
        def __init__(self, id):
                self.id = id
@@ -80,12 +200,11 @@ class IDParser(HTMLParser.HTMLParser):
                self.html = None
                self.watch_startpos = False
                self.error_count = 0
-               HTMLParser.HTMLParser.__init__(self)
+               compat_html_parser.HTMLParser.__init__(self)
 
        def error(self, message):
-               print >> sys.stderr, self.getpos()
                if self.error_count > 10 or self.started:
-                       raise HTMLParser.HTMLParseError(message, self.getpos())
+                       raise compat_html_parser.HTMLParseError(message, self.getpos())
                self.rawdata = '\n'.join(self.html.split('\n')[self.getpos()[0]:]) # skip one line
                self.error_count += 1
                self.goahead(1)
@@ -124,8 +243,10 @@ class IDParser(HTMLParser.HTMLParser):
        handle_decl = handle_pi = unknown_decl = find_startpos
 
        def get_result(self):
-               if self.result == None: return None
-               if len(self.result) != 3: return None
+               if self.result is None:
+                       return None
+               if len(self.result) != 3:
+                       return None
                lines = self.html.split('\n')
                lines = lines[self.result[1][0]-1:self.result[2][0]]
                lines[0] = lines[0][self.result[1][1]:]
@@ -139,7 +260,7 @@ def get_element_by_id(id, html):
        parser = IDParser(id)
        try:
                parser.loads(html)
-       except HTMLParser.HTMLParseError:
+       except compat_html_parser.HTMLParseError:
                pass
        return parser.get_result()
 
@@ -174,9 +295,9 @@ def sanitize_open(filename, open_mode):
                        return (sys.stdout, filename)
                stream = open(encodeFilename(filename), open_mode)
                return (stream, filename)
-       except (IOError, OSError), err:
+       except (IOError, OSError) as err:
                # In case of error, try to remove win32 forbidden chars
-               filename = re.sub(ur'[/<>:"\|\?\*]', u'#', filename)
+               filename = re.sub(u'[/<>:"\\|\\\\?\\*]', u'#', filename)
 
                # An exception here should be caught in the caller
                stream = open(encodeFilename(filename), open_mode)
@@ -190,14 +311,36 @@ def timeconvert(timestr):
        if timetuple is not None:
                timestamp = email.utils.mktime_tz(timetuple)
        return timestamp
-       
-def sanitize_filename(s):
-       """Sanitizes a string so it could be used as part of a filename."""
+
+def sanitize_filename(s, restricted=False):
+       """Sanitizes a string so it could be used as part of a filename.
+       If restricted is set, use a stricter subset of allowed characters.
+       """
        def replace_insane(char):
-               if char in u' .\\/|?*<>:"' or ord(char) < 32:
+               if char == '?' or ord(char) < 32 or ord(char) == 127:
+                       return ''
+               elif char == '"':
+                       return '' if restricted else '\''
+               elif char == ':':
+                       return '_-' if restricted else ' -'
+               elif char in '\\/|*<>':
+                       return '_'
+               if restricted and (char in '!&\'' or char.isspace()):
+                       return '_'
+               if restricted and ord(char) > 127:
                        return '_'
                return char
-       return u''.join(map(replace_insane, s)).strip('_')
+
+       result = u''.join(map(replace_insane, s))
+       while '__' in result:
+               result = result.replace('__', '_')
+       result = result.strip('_')
+       # Common case of "Foreign band name - English song title"
+       if restricted and result.startswith('-_'):
+               result = result[2:]
+       if not result:
+               result = '_'
+       return result
 
 def orderedSet(iterable):
        """ Remove all duplicates from the input iterable """
@@ -209,20 +352,24 @@ def orderedSet(iterable):
 
 def unescapeHTML(s):
        """
-       @param s a string (of type unicode)
+       @param s a string
        """
        assert type(s) == type(u'')
 
-       result = re.sub(ur'(?u)&(.+?);', htmlentity_transform, s)
+       result = re.sub(u'(?u)&(.+?);', htmlentity_transform, s)
        return result
 
 def encodeFilename(s):
        """
-       @param s The name of the file (of type unicode)
+       @param s The name of the file
        """
 
        assert type(s) == type(u'')
 
+       # Python 3 has a Unicode API
+       if sys.version_info >= (3, 0):
+               return s
+
        if sys.platform == 'win32' and sys.getwindowsversion()[0] >= 5:
                # Pass u'' directly to use Unicode APIs on Windows 2000 and up
                # (Detecting Windows NT 4 is tricky because 'major >= 4' would
@@ -290,12 +437,12 @@ class ContentTooShortError(Exception):
 
 class Trouble(Exception):
        """Trouble helper exception
-       
+
        This is an exception to be handled with
        FileDownloader.trouble
        """
 
-class YoutubeDLHandler(urllib2.HTTPHandler):
+class YoutubeDLHandler(compat_urllib_request.HTTPHandler):
        """Handler for HTTP requests and responses.
 
        This class, when installed with an OpenerDirector, automatically adds
@@ -322,9 +469,9 @@ class YoutubeDLHandler(urllib2.HTTPHandler):
 
        @staticmethod
        def addinfourl_wrapper(stream, headers, url, code):
-               if hasattr(urllib2.addinfourl, 'getcode'):
-                       return urllib2.addinfourl(stream, headers, url, code)
-               ret = urllib2.addinfourl(stream, headers, url)
+               if hasattr(compat_urllib_request.addinfourl, 'getcode'):
+                       return compat_urllib_request.addinfourl(stream, headers, url, code)
+               ret = compat_urllib_request.addinfourl(stream, headers, url)
                ret.code = code
                return ret
 
@@ -343,12 +490,12 @@ class YoutubeDLHandler(urllib2.HTTPHandler):
                old_resp = resp
                # gzip
                if resp.headers.get('Content-encoding', '') == 'gzip':
-                       gz = gzip.GzipFile(fileobj=StringIO.StringIO(resp.read()), mode='r')
+                       gz = gzip.GzipFile(fileobj=io.BytesIO(resp.read()), mode='r')
                        resp = self.addinfourl_wrapper(gz, old_resp.headers, old_resp.url, old_resp.code)
                        resp.msg = old_resp.msg
                # deflate
                if resp.headers.get('Content-encoding', '') == 'deflate':
-                       gz = StringIO.StringIO(self.deflate(resp.read()))
+                       gz = io.BytesIO(self.deflate(resp.read()))
                        resp = self.addinfourl_wrapper(gz, old_resp.headers, old_resp.url, old_resp.code)
                        resp.msg = old_resp.msg
                return resp