Convert all tabs to 4 spaces (PEP8)
[youtube-dl] / youtube_dl / utils.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3
4 import gzip
5 import io
6 import locale
7 import os
8 import re
9 import sys
10 import zlib
11 import email.utils
12 import json
13
14 try:
15     import urllib.request as compat_urllib_request
16 except ImportError: # Python 2
17     import urllib2 as compat_urllib_request
18
19 try:
20     import urllib.error as compat_urllib_error
21 except ImportError: # Python 2
22     import urllib2 as compat_urllib_error
23
24 try:
25     import urllib.parse as compat_urllib_parse
26 except ImportError: # Python 2
27     import urllib as compat_urllib_parse
28
29 try:
30     import http.cookiejar as compat_cookiejar
31 except ImportError: # Python 2
32     import cookielib as compat_cookiejar
33
34 try:
35     import html.entities as compat_html_entities
36 except ImportError: # Python 2
37     import htmlentitydefs as compat_html_entities
38
39 try:
40     import html.parser as compat_html_parser
41 except ImportError: # Python 2
42     import HTMLParser as compat_html_parser
43
44 try:
45     import http.client as compat_http_client
46 except ImportError: # Python 2
47     import httplib as compat_http_client
48
49 try:
50     from urllib.parse import parse_qs as compat_parse_qs
51 except ImportError: # Python 2
52     # HACK: The following is the correct parse_qs implementation from cpython 3's stdlib.
53     # Python 2's version is apparently totally broken
54     def _unquote(string, encoding='utf-8', errors='replace'):
55         if string == '':
56             return string
57         res = string.split('%')
58         if len(res) == 1:
59             return string
60         if encoding is None:
61             encoding = 'utf-8'
62         if errors is None:
63             errors = 'replace'
64         # pct_sequence: contiguous sequence of percent-encoded bytes, decoded
65         pct_sequence = b''
66         string = res[0]
67         for item in res[1:]:
68             try:
69                 if not item:
70                     raise ValueError
71                 pct_sequence += item[:2].decode('hex')
72                 rest = item[2:]
73                 if not rest:
74                     # This segment was just a single percent-encoded character.
75                     # May be part of a sequence of code units, so delay decoding.
76                     # (Stored in pct_sequence).
77                     continue
78             except ValueError:
79                 rest = '%' + item
80             # Encountered non-percent-encoded characters. Flush the current
81             # pct_sequence.
82             string += pct_sequence.decode(encoding, errors) + rest
83             pct_sequence = b''
84         if pct_sequence:
85             # Flush the final pct_sequence
86             string += pct_sequence.decode(encoding, errors)
87         return string
88
89     def _parse_qsl(qs, keep_blank_values=False, strict_parsing=False,
90                 encoding='utf-8', errors='replace'):
91         qs, _coerce_result = qs, unicode
92         pairs = [s2 for s1 in qs.split('&') for s2 in s1.split(';')]
93         r = []
94         for name_value in pairs:
95             if not name_value and not strict_parsing:
96                 continue
97             nv = name_value.split('=', 1)
98             if len(nv) != 2:
99                 if strict_parsing:
100                     raise ValueError("bad query field: %r" % (name_value,))
101                 # Handle case of a control-name with no equal sign
102                 if keep_blank_values:
103                     nv.append('')
104                 else:
105                     continue
106             if len(nv[1]) or keep_blank_values:
107                 name = nv[0].replace('+', ' ')
108                 name = _unquote(name, encoding=encoding, errors=errors)
109                 name = _coerce_result(name)
110                 value = nv[1].replace('+', ' ')
111                 value = _unquote(value, encoding=encoding, errors=errors)
112                 value = _coerce_result(value)
113                 r.append((name, value))
114         return r
115
116     def compat_parse_qs(qs, keep_blank_values=False, strict_parsing=False,
117                 encoding='utf-8', errors='replace'):
118         parsed_result = {}
119         pairs = _parse_qsl(qs, keep_blank_values, strict_parsing,
120                         encoding=encoding, errors=errors)
121         for name, value in pairs:
122             if name in parsed_result:
123                 parsed_result[name].append(value)
124             else:
125                 parsed_result[name] = [value]
126         return parsed_result
127
128 try:
129     compat_str = unicode # Python 2
130 except NameError:
131     compat_str = str
132
133 try:
134     compat_chr = unichr # Python 2
135 except NameError:
136     compat_chr = chr
137
138 std_headers = {
139     'User-Agent': 'Mozilla/5.0 (X11; Linux x86_64; rv:10.0) Gecko/20100101 Firefox/10.0',
140     'Accept-Charset': 'ISO-8859-1,utf-8;q=0.7,*;q=0.7',
141     'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8',
142     'Accept-Encoding': 'gzip, deflate',
143     'Accept-Language': 'en-us,en;q=0.5',
144 }
145 def preferredencoding():
146     """Get preferred encoding.
147
148     Returns the best encoding scheme for the system, based on
149     locale.getpreferredencoding() and some further tweaks.
150     """
151     try:
152         pref = locale.getpreferredencoding()
153         u'TEST'.encode(pref)
154     except:
155         pref = 'UTF-8'
156
157     return pref
158
159 if sys.version_info < (3,0):
160     def compat_print(s):
161         print(s.encode(preferredencoding(), 'xmlcharrefreplace'))
162 else:
163     def compat_print(s):
164         assert type(s) == type(u'')
165         print(s)
166
167 def htmlentity_transform(matchobj):
168     """Transforms an HTML entity to a character.
169
170     This function receives a match object and is intended to be used with
171     the re.sub() function.
172     """
173     entity = matchobj.group(1)
174
175     # Known non-numeric HTML entity
176     if entity in compat_html_entities.name2codepoint:
177         return compat_chr(compat_html_entities.name2codepoint[entity])
178
179     mobj = re.match(u'(?u)#(x?\\d+)', entity)
180     if mobj is not None:
181         numstr = mobj.group(1)
182         if numstr.startswith(u'x'):
183             base = 16
184             numstr = u'0%s' % numstr
185         else:
186             base = 10
187         return compat_chr(int(numstr, base))
188
189     # Unknown entity in name, return its literal representation
190     return (u'&%s;' % entity)
191
192 compat_html_parser.locatestarttagend = re.compile(r"""<[a-zA-Z][-.a-zA-Z0-9:_]*(?:\s+(?:(?<=['"\s])[^\s/>][^\s/=>]*(?:\s*=+\s*(?:'[^']*'|"[^"]*"|(?!['"])[^>\s]*))?\s*)*)?\s*""", re.VERBOSE) # backport bugfix
193 class IDParser(compat_html_parser.HTMLParser):
194     """Modified HTMLParser that isolates a tag with the specified id"""
195     def __init__(self, id):
196         self.id = id
197         self.result = None
198         self.started = False
199         self.depth = {}
200         self.html = None
201         self.watch_startpos = False
202         self.error_count = 0
203         compat_html_parser.HTMLParser.__init__(self)
204
205     def error(self, message):
206         if self.error_count > 10 or self.started:
207             raise compat_html_parser.HTMLParseError(message, self.getpos())
208         self.rawdata = '\n'.join(self.html.split('\n')[self.getpos()[0]:]) # skip one line
209         self.error_count += 1
210         self.goahead(1)
211
212     def loads(self, html):
213         self.html = html
214         self.feed(html)
215         self.close()
216
217     def handle_starttag(self, tag, attrs):
218         attrs = dict(attrs)
219         if self.started:
220             self.find_startpos(None)
221         if 'id' in attrs and attrs['id'] == self.id:
222             self.result = [tag]
223             self.started = True
224             self.watch_startpos = True
225         if self.started:
226             if not tag in self.depth: self.depth[tag] = 0
227             self.depth[tag] += 1
228
229     def handle_endtag(self, tag):
230         if self.started:
231             if tag in self.depth: self.depth[tag] -= 1
232             if self.depth[self.result[0]] == 0:
233                 self.started = False
234                 self.result.append(self.getpos())
235
236     def find_startpos(self, x):
237         """Needed to put the start position of the result (self.result[1])
238         after the opening tag with the requested id"""
239         if self.watch_startpos:
240             self.watch_startpos = False
241             self.result.append(self.getpos())
242     handle_entityref = handle_charref = handle_data = handle_comment = \
243     handle_decl = handle_pi = unknown_decl = find_startpos
244
245     def get_result(self):
246         if self.result is None:
247             return None
248         if len(self.result) != 3:
249             return None
250         lines = self.html.split('\n')
251         lines = lines[self.result[1][0]-1:self.result[2][0]]
252         lines[0] = lines[0][self.result[1][1]:]
253         if len(lines) == 1:
254             lines[-1] = lines[-1][:self.result[2][1]-self.result[1][1]]
255         lines[-1] = lines[-1][:self.result[2][1]]
256         return '\n'.join(lines).strip()
257
258 def get_element_by_id(id, html):
259     """Return the content of the tag with the specified id in the passed HTML document"""
260     parser = IDParser(id)
261     try:
262         parser.loads(html)
263     except compat_html_parser.HTMLParseError:
264         pass
265     return parser.get_result()
266
267
268 def clean_html(html):
269     """Clean an HTML snippet into a readable string"""
270     # Newline vs <br />
271     html = html.replace('\n', ' ')
272     html = re.sub('\s*<\s*br\s*/?\s*>\s*', '\n', html)
273     # Strip html tags
274     html = re.sub('<.*?>', '', html)
275     # Replace html entities
276     html = unescapeHTML(html)
277     return html
278
279
280 def sanitize_open(filename, open_mode):
281     """Try to open the given filename, and slightly tweak it if this fails.
282
283     Attempts to open the given filename. If this fails, it tries to change
284     the filename slightly, step by step, until it's either able to open it
285     or it fails and raises a final exception, like the standard open()
286     function.
287
288     It returns the tuple (stream, definitive_file_name).
289     """
290     try:
291         if filename == u'-':
292             if sys.platform == 'win32':
293                 import msvcrt
294                 msvcrt.setmode(sys.stdout.fileno(), os.O_BINARY)
295             return (sys.stdout, filename)
296         stream = open(encodeFilename(filename), open_mode)
297         return (stream, filename)
298     except (IOError, OSError) as err:
299         # In case of error, try to remove win32 forbidden chars
300         filename = re.sub(u'[/<>:"\\|\\\\?\\*]', u'#', filename)
301
302         # An exception here should be caught in the caller
303         stream = open(encodeFilename(filename), open_mode)
304         return (stream, filename)
305
306
307 def timeconvert(timestr):
308     """Convert RFC 2822 defined time string into system timestamp"""
309     timestamp = None
310     timetuple = email.utils.parsedate_tz(timestr)
311     if timetuple is not None:
312         timestamp = email.utils.mktime_tz(timetuple)
313     return timestamp
314
315 def sanitize_filename(s, restricted=False):
316     """Sanitizes a string so it could be used as part of a filename.
317     If restricted is set, use a stricter subset of allowed characters.
318     """
319     def replace_insane(char):
320         if char == '?' or ord(char) < 32 or ord(char) == 127:
321             return ''
322         elif char == '"':
323             return '' if restricted else '\''
324         elif char == ':':
325             return '_-' if restricted else ' -'
326         elif char in '\\/|*<>':
327             return '_'
328         if restricted and (char in '!&\'' or char.isspace()):
329             return '_'
330         if restricted and ord(char) > 127:
331             return '_'
332         return char
333
334     result = u''.join(map(replace_insane, s))
335     while '__' in result:
336         result = result.replace('__', '_')
337     result = result.strip('_')
338     # Common case of "Foreign band name - English song title"
339     if restricted and result.startswith('-_'):
340         result = result[2:]
341     if not result:
342         result = '_'
343     return result
344
345 def orderedSet(iterable):
346     """ Remove all duplicates from the input iterable """
347     res = []
348     for el in iterable:
349         if el not in res:
350             res.append(el)
351     return res
352
353 def unescapeHTML(s):
354     """
355     @param s a string
356     """
357     assert type(s) == type(u'')
358
359     result = re.sub(u'(?u)&(.+?);', htmlentity_transform, s)
360     return result
361
362 def encodeFilename(s):
363     """
364     @param s The name of the file
365     """
366
367     assert type(s) == type(u'')
368
369     # Python 3 has a Unicode API
370     if sys.version_info >= (3, 0):
371         return s
372
373     if sys.platform == 'win32' and sys.getwindowsversion()[0] >= 5:
374         # Pass u'' directly to use Unicode APIs on Windows 2000 and up
375         # (Detecting Windows NT 4 is tricky because 'major >= 4' would
376         # match Windows 9x series as well. Besides, NT 4 is obsolete.)
377         return s
378     else:
379         return s.encode(sys.getfilesystemencoding(), 'ignore')
380
381 class DownloadError(Exception):
382     """Download Error exception.
383
384     This exception may be thrown by FileDownloader objects if they are not
385     configured to continue on errors. They will contain the appropriate
386     error message.
387     """
388     pass
389
390
391 class SameFileError(Exception):
392     """Same File exception.
393
394     This exception will be thrown by FileDownloader objects if they detect
395     multiple files would have to be downloaded to the same file on disk.
396     """
397     pass
398
399
400 class PostProcessingError(Exception):
401     """Post Processing exception.
402
403     This exception may be raised by PostProcessor's .run() method to
404     indicate an error in the postprocessing task.
405     """
406     pass
407
408 class MaxDownloadsReached(Exception):
409     """ --max-downloads limit has been reached. """
410     pass
411
412
413 class UnavailableVideoError(Exception):
414     """Unavailable Format exception.
415
416     This exception will be thrown when a video is requested
417     in a format that is not available for that video.
418     """
419     pass
420
421
422 class ContentTooShortError(Exception):
423     """Content Too Short exception.
424
425     This exception may be raised by FileDownloader objects when a file they
426     download is too small for what the server announced first, indicating
427     the connection was probably interrupted.
428     """
429     # Both in bytes
430     downloaded = None
431     expected = None
432
433     def __init__(self, downloaded, expected):
434         self.downloaded = downloaded
435         self.expected = expected
436
437
438 class Trouble(Exception):
439     """Trouble helper exception
440
441     This is an exception to be handled with
442     FileDownloader.trouble
443     """
444
445 class YoutubeDLHandler(compat_urllib_request.HTTPHandler):
446     """Handler for HTTP requests and responses.
447
448     This class, when installed with an OpenerDirector, automatically adds
449     the standard headers to every HTTP request and handles gzipped and
450     deflated responses from web servers. If compression is to be avoided in
451     a particular request, the original request in the program code only has
452     to include the HTTP header "Youtubedl-No-Compression", which will be
453     removed before making the real request.
454
455     Part of this code was copied from:
456
457     http://techknack.net/python-urllib2-handlers/
458
459     Andrew Rowls, the author of that code, agreed to release it to the
460     public domain.
461     """
462
463     @staticmethod
464     def deflate(data):
465         try:
466             return zlib.decompress(data, -zlib.MAX_WBITS)
467         except zlib.error:
468             return zlib.decompress(data)
469
470     @staticmethod
471     def addinfourl_wrapper(stream, headers, url, code):
472         if hasattr(compat_urllib_request.addinfourl, 'getcode'):
473             return compat_urllib_request.addinfourl(stream, headers, url, code)
474         ret = compat_urllib_request.addinfourl(stream, headers, url)
475         ret.code = code
476         return ret
477
478     def http_request(self, req):
479         for h in std_headers:
480             if h in req.headers:
481                 del req.headers[h]
482             req.add_header(h, std_headers[h])
483         if 'Youtubedl-no-compression' in req.headers:
484             if 'Accept-encoding' in req.headers:
485                 del req.headers['Accept-encoding']
486             del req.headers['Youtubedl-no-compression']
487         return req
488
489     def http_response(self, req, resp):
490         old_resp = resp
491         # gzip
492         if resp.headers.get('Content-encoding', '') == 'gzip':
493             gz = gzip.GzipFile(fileobj=io.BytesIO(resp.read()), mode='r')
494             resp = self.addinfourl_wrapper(gz, old_resp.headers, old_resp.url, old_resp.code)
495             resp.msg = old_resp.msg
496         # deflate
497         if resp.headers.get('Content-encoding', '') == 'deflate':
498             gz = io.BytesIO(self.deflate(resp.read()))
499             resp = self.addinfourl_wrapper(gz, old_resp.headers, old_resp.url, old_resp.code)
500             resp.msg = old_resp.msg
501         return resp