hack for apparently broken parse_qs in python2
[youtube-dl] / youtube_dl / utils.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3
4 import gzip
5 import io
6 import locale
7 import os
8 import re
9 import sys
10 import zlib
11 import email.utils
12 import json
13
14 try:
15         import urllib.request as compat_urllib_request
16 except ImportError: # Python 2
17         import urllib2 as compat_urllib_request
18
19 try:
20         import urllib.error as compat_urllib_error
21 except ImportError: # Python 2
22         import urllib2 as compat_urllib_error
23
24 try:
25         import urllib.parse as compat_urllib_parse
26 except ImportError: # Python 2
27         import urllib as compat_urllib_parse
28
29 try:
30         import http.cookiejar as compat_cookiejar
31 except ImportError: # Python 2
32         import cookielib as compat_cookiejar
33
34 try:
35         import html.entities as compat_html_entities
36 except ImportError: # Python 2
37         import htmlentitydefs as compat_html_entities
38
39 try:
40         import html.parser as compat_html_parser
41 except ImportError: # Python 2
42         import HTMLParser as compat_html_parser
43
44 try:
45         import http.client as compat_http_client
46 except ImportError: # Python 2
47         import httplib as compat_http_client
48
try:
    from urllib.parse import parse_qs as compat_parse_qs
except ImportError: # Python 2
    # HACK: The following is the correct parse_qs implementation from cpython 3's stdlib.
    # Python 2's version is apparently totally broken
    def _unquote(string, encoding='utf-8', errors='replace'):
        """Replace %xx escapes in `string` by the characters they encode.

        Runs of consecutive escapes are collected as bytes and decoded
        together with `encoding`/`errors`, so multi-byte sequences survive.
        Malformed escapes are left in the output untouched.
        """
        if string == '':
            return string
        res = string.split('%')
        if len(res) == 1:
            # No percent sign at all: nothing to decode.
            return string
        if encoding is None:
            encoding = 'utf-8'
        if errors is None:
            errors = 'replace'
        # pct_sequence: contiguous sequence of percent-encoded bytes, decoded
        pct_sequence = b''
        string = res[0]
        for item in res[1:]:
            try:
                if not item:
                    raise ValueError
                pct_sequence += item[:2].decode('hex')
                rest = item[2:]
                if not rest:
                    # This segment was just a single percent-encoded character.
                    # May be part of a sequence of code units, so delay decoding.
                    # (Stored in pct_sequence).
                    continue
            except (TypeError, ValueError):
                # Python 2's hex codec signals malformed input ("%zz", "%4")
                # with TypeError rather than ValueError; keep the text as-is.
                rest = '%' + item
            # Encountered non-percent-encoded characters. Flush the current
            # pct_sequence.
            string += pct_sequence.decode(encoding, errors) + rest
            pct_sequence = b''
        if pct_sequence:
            # Flush the final pct_sequence
            string += pct_sequence.decode(encoding, errors)
        return string

    def _parse_qsl(qs, keep_blank_values=False, strict_parsing=False,
                   encoding='utf-8', errors='replace'):
        """Parse a query string into a list of (name, value) unicode pairs.

        Fields are separated by '&' or ';'. Fields with an empty value are
        dropped unless `keep_blank_values` is true. With `strict_parsing`,
        a field without '=' raises ValueError instead of being skipped.
        """
        _coerce_result = unicode
        pairs = [s2 for s1 in qs.split('&') for s2 in s1.split(';')]
        r = []
        for name_value in pairs:
            if not name_value and not strict_parsing:
                continue
            nv = name_value.split('=', 1)
            if len(nv) != 2:
                if strict_parsing:
                    raise ValueError("bad query field: %r" % (name_value,))
                # Handle case of a control-name with no equal sign
                if keep_blank_values:
                    nv.append('')
                else:
                    continue
            if len(nv[1]) or keep_blank_values:
                name = nv[0].replace('+', ' ')
                name = _unquote(name, encoding=encoding, errors=errors)
                name = _coerce_result(name)
                value = nv[1].replace('+', ' ')
                value = _unquote(value, encoding=encoding, errors=errors)
                value = _coerce_result(value)
                r.append((name, value))
        return r

    def compat_parse_qs(qs, keep_blank_values=False, strict_parsing=False,
                        encoding='utf-8', errors='replace'):
        """Parse a query string into a dict mapping each name to a list of values.

        Values of repeated names are collected in order of appearance.
        """
        parsed_result = {}
        pairs = _parse_qsl(qs, keep_blank_values, strict_parsing,
                           encoding=encoding, errors=errors)
        for name, value in pairs:
            if name in parsed_result:
                parsed_result[name].append(value)
            else:
                parsed_result[name] = [value]
        return parsed_result
127
# Text type and code-point-to-character function for the running interpreter.
# Both Python 2 names exist together (or not at all), so one probe suffices.
try:
    compat_str, compat_chr = unicode, unichr # Python 2
except NameError:
    compat_str, compat_chr = str, chr

# Default headers that YoutubeDLHandler puts on every outgoing HTTP request.
std_headers = {
    'User-Agent': 'Mozilla/5.0 (X11; Linux x86_64; rv:10.0) Gecko/20100101 Firefox/10.0',
    'Accept-Charset': 'ISO-8859-1,utf-8;q=0.7,*;q=0.7',
    'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8',
    'Accept-Encoding': 'gzip, deflate',
    'Accept-Language': 'en-us,en;q=0.5',
}
def preferredencoding():
    """Get preferred encoding.

    Returns the encoding reported by locale.getpreferredencoding() if it
    can actually be used to encode text, and 'UTF-8' otherwise.
    """
    try:
        pref = locale.getpreferredencoding()
        # Probe the codec: the locale may name an encoding Python lacks.
        u'TEST'.encode(pref)
    except Exception:
        # Any lookup/encoding failure falls back to UTF-8. Unlike a bare
        # `except:`, this lets KeyboardInterrupt and SystemExit propagate.
        pref = 'UTF-8'

    return pref
158
if sys.version_info >= (3, 0):
    def compat_print(s):
        """Print the text string `s`; anything but a text string fails the assertion."""
        assert type(s) == type(u'')
        print(s)
else:
    def compat_print(s):
        """Print `s` encoded with the preferred encoding.

        Characters the encoding cannot represent are written as XML
        character references.
        """
        print(s.encode(preferredencoding(), 'xmlcharrefreplace'))
166
def htmlentity_transform(matchobj):
    """Transforms an HTML entity to a character.

    This function receives a match object and is intended to be used with
    the re.sub() function. Group 1 of the match must hold the entity text
    without the surrounding '&' and ';'.

    Named entities and decimal/hexadecimal numeric references are decoded;
    anything else (including numeric references outside the range the
    platform can represent) is returned in its literal '&...;' form.
    """
    entity = matchobj.group(1)

    # Known non-numeric HTML entity
    if entity in compat_html_entities.name2codepoint:
        return compat_chr(compat_html_entities.name2codepoint[entity])

    # Numeric reference: "#<decimal>" or "#x<hex>". The hex form must accept
    # the digits a-f, and the whole entity has to match so that trailing
    # garbage is not silently dropped.
    mobj = re.match(u'#(?:[xX]([0-9a-fA-F]+)|([0-9]+))\\Z', entity)
    if mobj is not None:
        hex_digits, dec_digits = mobj.groups()
        try:
            if hex_digits is not None:
                return compat_chr(int(hex_digits, 16))
            return compat_chr(int(dec_digits, 10))
        except (ValueError, OverflowError):
            # Code point out of range: fall through to the literal form.
            pass

    # Unknown entity in name, return its literal representation
    return (u'&%s;' % entity)
191
# Overwrite the `locatestarttagend` attribute of the HTML parser module with the
# compiled pattern below, which matches the opening part of a start tag (tag name
# followed by optional attributes). The original author marks this as a backported bugfix.
# NOTE(review): presumably copied from a newer CPython release -- confirm which upstream fix.
compat_html_parser.locatestarttagend = re.compile(r"""<[a-zA-Z][-.a-zA-Z0-9:_]*(?:\s+(?:(?<=['"\s])[^\s/>][^\s/=>]*(?:\s*=+\s*(?:'[^']*'|"[^"]*"|(?!['"])[^>\s]*))?\s*)*)?\s*""", re.VERBOSE) # backport bugfix
class IDParser(compat_html_parser.HTMLParser):
    """HTMLParser subclass that locates the content of the tag with a given id.

    While parsing, `result` grows into [tag, start_pos, end_pos], where the
    positions are (line, column) pairs from getpos() delimiting the content
    between the opening tag carrying the id and its closing tag.
    """

    def __init__(self, id):
        self.id = id
        self.html = None
        self.result = None
        self.depth = {}
        self.started = False
        self.watch_startpos = False
        self.error_count = 0
        compat_html_parser.HTMLParser.__init__(self)

    def error(self, message):
        """Recover from a parse error by skipping a line and resuming.

        Gives up (raises HTMLParseError) after more than 10 errors, or as
        soon as an error occurs inside the element being extracted.
        """
        too_many_errors = self.error_count > 10
        if too_many_errors or self.started:
            raise compat_html_parser.HTMLParseError(message, self.getpos())
        current_line = self.getpos()[0]
        remaining_lines = self.html.split('\n')[current_line:]
        self.rawdata = '\n'.join(remaining_lines) # skip one line
        self.error_count += 1
        self.goahead(1)

    def loads(self, html):
        """Remember the document text and run the parser over all of it."""
        self.html = html
        self.feed(html)
        self.close()

    def handle_starttag(self, tag, attrs):
        attributes = dict(attrs)
        if self.started:
            # A child tag directly after the opening tag marks the start.
            self.find_startpos(None)
        try:
            is_target = attributes['id'] == self.id
        except KeyError:
            is_target = False
        if is_target:
            self.result = [tag]
            self.started = True
            self.watch_startpos = True
        if self.started:
            # Count open tags per name so the matching end tag can be found.
            self.depth[tag] = self.depth.get(tag, 0) + 1

    def handle_endtag(self, tag):
        if not self.started:
            return
        if tag in self.depth:
            self.depth[tag] -= 1
        target_tag = self.result[0]
        if self.depth[target_tag] == 0:
            # The element with the requested id has just been closed.
            self.started = False
            self.result.append(self.getpos())

    def find_startpos(self, x):
        """Needed to put the start position of the result (self.result[1])
        after the opening tag with the requested id"""
        if not self.watch_startpos:
            return
        self.watch_startpos = False
        self.result.append(self.getpos())

    # Every kind of content following the opening tag records the start.
    handle_entityref = handle_charref = handle_data = handle_comment = \
        handle_decl = handle_pi = unknown_decl = find_startpos

    def get_result(self):
        """Return the stripped text between the recorded positions, or None
        when no complete [tag, start, end] result was collected."""
        if self.result is None or len(self.result) != 3:
            return None
        start_line, start_col = self.result[1]
        end_line, end_col = self.result[2]
        lines = self.html.split('\n')[start_line - 1:end_line]
        lines[0] = lines[0][start_col:]
        if len(lines) == 1:
            # Start and end share a line: the end column must be shifted
            # by what was just cut off the front.
            lines[-1] = lines[-1][:end_col - start_col]
        lines[-1] = lines[-1][:end_col]
        return '\n'.join(lines).strip()
257
def get_element_by_id(id, html):
    """Return the content of the tag with the specified id in the passed HTML document

    The result is None when IDParser collected no complete match.
    """
    id_parser = IDParser(id)
    try:
        id_parser.loads(html)
    except compat_html_parser.HTMLParseError:
        # Best effort: use whatever was collected before the parser gave up.
        pass
    return id_parser.get_result()
266
267
def clean_html(html):
    """Clean an HTML snippet into a readable string

    Literal newlines become spaces, <br> tags (with surrounding whitespace)
    become newlines, all remaining tags are removed and HTML entities are
    unescaped.
    """
    # Newline vs <br />
    html = html.replace('\n', ' ')
    # Raw strings: '\s' in a normal literal is an invalid escape sequence.
    html = re.sub(r'\s*<\s*br\s*/?\s*>\s*', '\n', html)
    # Strip html tags
    html = re.sub(r'<.*?>', '', html)
    # Replace html entities
    html = unescapeHTML(html)
    return html
278
279
def sanitize_open(filename, open_mode):
    """Try to open the given filename, and slightly tweak it if this fails.

    Attempts to open the given filename. If this fails, it tries to change
    the filename slightly, step by step, until it's either able to open it
    or it fails and raises a final exception, like the standard open()
    function.

    The special filename u'-' selects standard output. For a binary
    open_mode the underlying byte stream (sys.stdout.buffer) is returned
    when the interpreter provides one, since the text-mode sys.stdout of
    Python 3 does not accept bytes.

    It returns the tuple (stream, definitive_file_name).
    """
    try:
        if filename == u'-':
            if sys.platform == 'win32':
                # Stop Windows from translating line endings on stdout.
                import msvcrt
                msvcrt.setmode(sys.stdout.fileno(), os.O_BINARY)
            if 'b' in open_mode and hasattr(sys.stdout, 'buffer'):
                return (sys.stdout.buffer, filename)
            return (sys.stdout, filename)
        stream = open(encodeFilename(filename), open_mode)
        return (stream, filename)
    except (IOError, OSError):
        # In case of error, try to remove win32 forbidden chars
        filename = re.sub(u'[/<>:"\\|\\\\?\\*]', u'#', filename)

        # An exception here should be caught in the caller
        stream = open(encodeFilename(filename), open_mode)
        return (stream, filename)
305
306
def timeconvert(timestr):
    """Convert RFC 2822 defined time string into system timestamp

    Returns None when the string cannot be parsed as a date.
    """
    parsed = email.utils.parsedate_tz(timestr)
    if parsed is None:
        return None
    return email.utils.mktime_tz(parsed)
314
def sanitize_filename(s, restricted=False):
    """Sanitizes a string so it could be used as part of a filename.
    If restricted is set, use a stricter subset of allowed characters.
    """
    def replace_insane(char):
        # Map one character to its replacement (possibly the empty string).
        if char == '?':
            return ''
        code = ord(char)
        if code < 32 or code == 127:
            return ''
        if char == '"':
            return '' if restricted else '\''
        if char == ':':
            return '_-' if restricted else ' -'
        if char in '\\/|*<>':
            return '_'
        if restricted and (char in '!&\'' or char.isspace() or code > 127):
            return '_'
        return char

    result = u''.join(replace_insane(char) for char in s)
    # Collapse every run of underscores into a single one.
    result = re.sub(u'_{2,}', u'_', result)
    result = result.strip('_')
    # Common case of "Foreign band name - English song title"
    if restricted and result.startswith('-_'):
        result = result[2:]
    return result if result else '_'
344
def orderedSet(iterable):
    """ Return a list of the iterable's elements without duplicates, in first-seen order """
    unique = []
    for item in iterable:
        if item in unique:
            continue
        unique.append(item)
    return unique
352
def unescapeHTML(s):
    """Replace every '&...;' sequence in s via htmlentity_transform.

    @param s a string
    """
    assert type(s) == type(u'')

    return re.sub(u'(?u)&(.+?);', htmlentity_transform, s)
361
def encodeFilename(s):
    """Return the form of a filename that can be passed to file APIs.

    On Python 3, and on Python 2 under Windows 2000 and up, the text
    string is returned unchanged. Otherwise it is encoded with the
    filesystem encoding, dropping characters that cannot be encoded.

    @param s The name of the file
    """

    assert type(s) == type(u'')

    # Python 3 has a Unicode API
    if sys.version_info >= (3, 0):
        return s

    if sys.platform == 'win32' and sys.getwindowsversion()[0] >= 5:
        # Pass u'' directly to use Unicode APIs on Windows 2000 and up
        # (Detecting Windows NT 4 is tricky because 'major >= 4' would
        # match Windows 9x series as well. Besides, NT 4 is obsolete.)
        return s

    encoding = sys.getfilesystemencoding()
    if encoding is None:
        # Python 2 may report no filesystem encoding; encode() needs a name.
        encoding = 'utf-8'
    return s.encode(encoding, 'ignore')
380
class DownloadError(Exception):
    """Download Error exception.

    FileDownloader objects may throw this exception when they are not
    configured to continue on errors. It carries the appropriate error
    message.
    """
389
390
class SameFileError(Exception):
    """Same File exception.

    FileDownloader objects throw this exception when they detect that
    multiple files would have to be downloaded to the same file on disk.
    """
398
399
class PostProcessingError(Exception):
    """Post Processing exception.

    A PostProcessor's .run() method may raise this exception to signal
    an error in the postprocessing task.
    """
407
class MaxDownloadsReached(Exception):
    """ The --max-downloads limit has been reached. """
411
412
class UnavailableVideoError(Exception):
    """Unavailable Format exception.

    Thrown when a video is requested in a format that is not available
    for that video.
    """
420
421
class ContentTooShortError(Exception):
    """Content Too Short exception.

    This exception may be raised by FileDownloader objects when a file they
    download is too small for what the server announced first, indicating
    the connection was probably interrupted.

    The `downloaded` and `expected` attributes hold the two sizes, and the
    exception message states both.
    """
    # Both in bytes
    downloaded = None
    expected = None

    def __init__(self, downloaded, expected):
        # Initialise the base class so str(err) and err.args are informative
        # instead of empty.
        Exception.__init__(
            self,
            'Downloaded %s bytes, expected %s bytes' % (downloaded, expected))
        self.downloaded = downloaded
        self.expected = expected
436
437
class Trouble(Exception):
    """Trouble helper exception

    An exception meant to be handled with FileDownloader.trouble
    """
444
class YoutubeDLHandler(compat_urllib_request.HTTPHandler):
    """Handler for HTTP requests and responses.

    Once installed with an OpenerDirector, this class adds the standard
    headers to every HTTP request and transparently decodes gzipped and
    deflated responses from web servers. To avoid compression for a
    particular request, the program code only has to put the HTTP header
    "Youtubedl-No-Compression" on the original request; it is removed
    before the real request is made.

    Part of this code was copied from:

    http://techknack.net/python-urllib2-handlers/

    Andrew Rowls, the author of that code, agreed to release it to the
    public domain.
    """

    @staticmethod
    def deflate(data):
        """Decompress deflate data, first as a raw stream, then with a zlib header."""
        try:
            inflated = zlib.decompress(data, -zlib.MAX_WBITS)
        except zlib.error:
            inflated = zlib.decompress(data)
        return inflated

    @staticmethod
    def addinfourl_wrapper(stream, headers, url, code):
        """Build an addinfourl response that carries the status code.

        When addinfourl has no getcode(), the code is not passed to the
        constructor and is set as an attribute afterwards.
        """
        addinfourl = compat_urllib_request.addinfourl
        if not hasattr(addinfourl, 'getcode'):
            wrapped = addinfourl(stream, headers, url)
            wrapped.code = code
            return wrapped
        return addinfourl(stream, headers, url, code)

    def http_request(self, req):
        """Force the standard headers onto req and honour the no-compression marker."""
        for name, value in std_headers.items():
            if name in req.headers:
                del req.headers[name]
            req.add_header(name, value)
        if 'Youtubedl-no-compression' in req.headers:
            # Internal marker: drop it together with the Accept-Encoding header.
            req.headers.pop('Accept-encoding', None)
            del req.headers['Youtubedl-no-compression']
        return req

    def http_response(self, req, resp):
        """Return resp, re-wrapped around the decoded body if it was gzip or deflate encoded."""
        content_encoding = resp.headers.get('Content-encoding', '')
        if content_encoding == 'gzip':
            body = gzip.GzipFile(fileobj=io.BytesIO(resp.read()), mode='r')
        elif content_encoding == 'deflate':
            body = io.BytesIO(self.deflate(resp.read()))
        else:
            return resp
        decoded = self.addinfourl_wrapper(body, resp.headers, resp.url, resp.code)
        decoded.msg = resp.msg
        return decoded