Fix --extract-audio on Python 3
[youtube-dl] / youtube_dl / utils.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3
4 import gzip
5 import io
6 import locale
7 import os
8 import re
9 import sys
10 import zlib
11 import email.utils
12 import json
13
14 try:
15     import urllib.request as compat_urllib_request
16 except ImportError: # Python 2
17     import urllib2 as compat_urllib_request
18
19 try:
20     import urllib.error as compat_urllib_error
21 except ImportError: # Python 2
22     import urllib2 as compat_urllib_error
23
24 try:
25     import urllib.parse as compat_urllib_parse
26 except ImportError: # Python 2
27     import urllib as compat_urllib_parse
28
29 try:
30     from urllib.parse import urlparse as compat_urllib_parse_urlparse
31 except ImportError: # Python 2
32     from urlparse import urlparse as compat_urllib_parse_urlparse
33
34 try:
35     import http.cookiejar as compat_cookiejar
36 except ImportError: # Python 2
37     import cookielib as compat_cookiejar
38
39 try:
40     import html.entities as compat_html_entities
41 except ImportError: # Python 2
42     import htmlentitydefs as compat_html_entities
43
44 try:
45     import html.parser as compat_html_parser
46 except ImportError: # Python 2
47     import HTMLParser as compat_html_parser
48
49 try:
50     import http.client as compat_http_client
51 except ImportError: # Python 2
52     import httplib as compat_http_client
53
54 try:
55     from subprocess import DEVNULL
56     compat_subprocess_get_DEVNULL = lambda: DEVNULL
57 except ImportError:
58     compat_subprocess_get_DEVNULL = lambda: open(os.path.devnull, 'w')
59
60 try:
61     from urllib.parse import parse_qs as compat_parse_qs
62 except ImportError: # Python 2
63     # HACK: The following is the correct parse_qs implementation from cpython 3's stdlib.
64     # Python 2's version is apparently totally broken
65     def _unquote(string, encoding='utf-8', errors='replace'):
66         if string == '':
67             return string
68         res = string.split('%')
69         if len(res) == 1:
70             return string
71         if encoding is None:
72             encoding = 'utf-8'
73         if errors is None:
74             errors = 'replace'
75         # pct_sequence: contiguous sequence of percent-encoded bytes, decoded
76         pct_sequence = b''
77         string = res[0]
78         for item in res[1:]:
79             try:
80                 if not item:
81                     raise ValueError
82                 pct_sequence += item[:2].decode('hex')
83                 rest = item[2:]
84                 if not rest:
85                     # This segment was just a single percent-encoded character.
86                     # May be part of a sequence of code units, so delay decoding.
87                     # (Stored in pct_sequence).
88                     continue
89             except ValueError:
90                 rest = '%' + item
91             # Encountered non-percent-encoded characters. Flush the current
92             # pct_sequence.
93             string += pct_sequence.decode(encoding, errors) + rest
94             pct_sequence = b''
95         if pct_sequence:
96             # Flush the final pct_sequence
97             string += pct_sequence.decode(encoding, errors)
98         return string
99
100     def _parse_qsl(qs, keep_blank_values=False, strict_parsing=False,
101                 encoding='utf-8', errors='replace'):
102         qs, _coerce_result = qs, unicode
103         pairs = [s2 for s1 in qs.split('&') for s2 in s1.split(';')]
104         r = []
105         for name_value in pairs:
106             if not name_value and not strict_parsing:
107                 continue
108             nv = name_value.split('=', 1)
109             if len(nv) != 2:
110                 if strict_parsing:
111                     raise ValueError("bad query field: %r" % (name_value,))
112                 # Handle case of a control-name with no equal sign
113                 if keep_blank_values:
114                     nv.append('')
115                 else:
116                     continue
117             if len(nv[1]) or keep_blank_values:
118                 name = nv[0].replace('+', ' ')
119                 name = _unquote(name, encoding=encoding, errors=errors)
120                 name = _coerce_result(name)
121                 value = nv[1].replace('+', ' ')
122                 value = _unquote(value, encoding=encoding, errors=errors)
123                 value = _coerce_result(value)
124                 r.append((name, value))
125         return r
126
127     def compat_parse_qs(qs, keep_blank_values=False, strict_parsing=False,
128                 encoding='utf-8', errors='replace'):
129         parsed_result = {}
130         pairs = _parse_qsl(qs, keep_blank_values, strict_parsing,
131                         encoding=encoding, errors=errors)
132         for name, value in pairs:
133             if name in parsed_result:
134                 parsed_result[name].append(value)
135             else:
136                 parsed_result[name] = [value]
137         return parsed_result
138
139 try:
140     compat_str = unicode # Python 2
141 except NameError:
142     compat_str = str
143
144 try:
145     compat_chr = unichr # Python 2
146 except NameError:
147     compat_chr = chr
148
149 std_headers = {
150     'User-Agent': 'Mozilla/5.0 (X11; Linux x86_64; rv:10.0) Gecko/20100101 Firefox/10.0',
151     'Accept-Charset': 'ISO-8859-1,utf-8;q=0.7,*;q=0.7',
152     'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8',
153     'Accept-Encoding': 'gzip, deflate',
154     'Accept-Language': 'en-us,en;q=0.5',
155 }
156 def preferredencoding():
157     """Get preferred encoding.
158
159     Returns the best encoding scheme for the system, based on
160     locale.getpreferredencoding() and some further tweaks.
161     """
162     try:
163         pref = locale.getpreferredencoding()
164         u'TEST'.encode(pref)
165     except:
166         pref = 'UTF-8'
167
168     return pref
169
170 if sys.version_info < (3,0):
171     def compat_print(s):
172         print(s.encode(preferredencoding(), 'xmlcharrefreplace'))
173 else:
174     def compat_print(s):
175         assert type(s) == type(u'')
176         print(s)
177
178 def htmlentity_transform(matchobj):
179     """Transforms an HTML entity to a character.
180
181     This function receives a match object and is intended to be used with
182     the re.sub() function.
183     """
184     entity = matchobj.group(1)
185
186     # Known non-numeric HTML entity
187     if entity in compat_html_entities.name2codepoint:
188         return compat_chr(compat_html_entities.name2codepoint[entity])
189
190     mobj = re.match(u'(?u)#(x?\\d+)', entity)
191     if mobj is not None:
192         numstr = mobj.group(1)
193         if numstr.startswith(u'x'):
194             base = 16
195             numstr = u'0%s' % numstr
196         else:
197             base = 10
198         return compat_chr(int(numstr, base))
199
200     # Unknown entity in name, return its literal representation
201     return (u'&%s;' % entity)
202
203 compat_html_parser.locatestarttagend = re.compile(r"""<[a-zA-Z][-.a-zA-Z0-9:_]*(?:\s+(?:(?<=['"\s])[^\s/>][^\s/=>]*(?:\s*=+\s*(?:'[^']*'|"[^"]*"|(?!['"])[^>\s]*))?\s*)*)?\s*""", re.VERBOSE) # backport bugfix
204 class IDParser(compat_html_parser.HTMLParser):
205     """Modified HTMLParser that isolates a tag with the specified id"""
206     def __init__(self, id):
207         self.id = id
208         self.result = None
209         self.started = False
210         self.depth = {}
211         self.html = None
212         self.watch_startpos = False
213         self.error_count = 0
214         compat_html_parser.HTMLParser.__init__(self)
215
216     def error(self, message):
217         if self.error_count > 10 or self.started:
218             raise compat_html_parser.HTMLParseError(message, self.getpos())
219         self.rawdata = '\n'.join(self.html.split('\n')[self.getpos()[0]:]) # skip one line
220         self.error_count += 1
221         self.goahead(1)
222
223     def loads(self, html):
224         self.html = html
225         self.feed(html)
226         self.close()
227
228     def handle_starttag(self, tag, attrs):
229         attrs = dict(attrs)
230         if self.started:
231             self.find_startpos(None)
232         if 'id' in attrs and attrs['id'] == self.id:
233             self.result = [tag]
234             self.started = True
235             self.watch_startpos = True
236         if self.started:
237             if not tag in self.depth: self.depth[tag] = 0
238             self.depth[tag] += 1
239
240     def handle_endtag(self, tag):
241         if self.started:
242             if tag in self.depth: self.depth[tag] -= 1
243             if self.depth[self.result[0]] == 0:
244                 self.started = False
245                 self.result.append(self.getpos())
246
247     def find_startpos(self, x):
248         """Needed to put the start position of the result (self.result[1])
249         after the opening tag with the requested id"""
250         if self.watch_startpos:
251             self.watch_startpos = False
252             self.result.append(self.getpos())
253     handle_entityref = handle_charref = handle_data = handle_comment = \
254     handle_decl = handle_pi = unknown_decl = find_startpos
255
256     def get_result(self):
257         if self.result is None:
258             return None
259         if len(self.result) != 3:
260             return None
261         lines = self.html.split('\n')
262         lines = lines[self.result[1][0]-1:self.result[2][0]]
263         lines[0] = lines[0][self.result[1][1]:]
264         if len(lines) == 1:
265             lines[-1] = lines[-1][:self.result[2][1]-self.result[1][1]]
266         lines[-1] = lines[-1][:self.result[2][1]]
267         return '\n'.join(lines).strip()
268
269 def get_element_by_id(id, html):
270     """Return the content of the tag with the specified id in the passed HTML document"""
271     parser = IDParser(id)
272     try:
273         parser.loads(html)
274     except compat_html_parser.HTMLParseError:
275         pass
276     return parser.get_result()
277
278
279 def clean_html(html):
280     """Clean an HTML snippet into a readable string"""
281     # Newline vs <br />
282     html = html.replace('\n', ' ')
283     html = re.sub('\s*<\s*br\s*/?\s*>\s*', '\n', html)
284     # Strip html tags
285     html = re.sub('<.*?>', '', html)
286     # Replace html entities
287     html = unescapeHTML(html)
288     return html
289
290
291 def sanitize_open(filename, open_mode):
292     """Try to open the given filename, and slightly tweak it if this fails.
293
294     Attempts to open the given filename. If this fails, it tries to change
295     the filename slightly, step by step, until it's either able to open it
296     or it fails and raises a final exception, like the standard open()
297     function.
298
299     It returns the tuple (stream, definitive_file_name).
300     """
301     try:
302         if filename == u'-':
303             if sys.platform == 'win32':
304                 import msvcrt
305                 msvcrt.setmode(sys.stdout.fileno(), os.O_BINARY)
306             return (sys.stdout, filename)
307         stream = open(encodeFilename(filename), open_mode)
308         return (stream, filename)
309     except (IOError, OSError) as err:
310         # In case of error, try to remove win32 forbidden chars
311         filename = re.sub(u'[/<>:"\\|\\\\?\\*]', u'#', filename)
312
313         # An exception here should be caught in the caller
314         stream = open(encodeFilename(filename), open_mode)
315         return (stream, filename)
316
317
318 def timeconvert(timestr):
319     """Convert RFC 2822 defined time string into system timestamp"""
320     timestamp = None
321     timetuple = email.utils.parsedate_tz(timestr)
322     if timetuple is not None:
323         timestamp = email.utils.mktime_tz(timetuple)
324     return timestamp
325
326 def sanitize_filename(s, restricted=False, is_id=False):
327     """Sanitizes a string so it could be used as part of a filename.
328     If restricted is set, use a stricter subset of allowed characters.
329     Set is_id if this is not an arbitrary string, but an ID that should be kept if possible
330     """
331     def replace_insane(char):
332         if char == '?' or ord(char) < 32 or ord(char) == 127:
333             return ''
334         elif char == '"':
335             return '' if restricted else '\''
336         elif char == ':':
337             return '_-' if restricted else ' -'
338         elif char in '\\/|*<>':
339             return '_'
340         if restricted and (char in '!&\'()[]{}$;`^,#' or char.isspace()):
341             return '_'
342         if restricted and ord(char) > 127:
343             return '_'
344         return char
345
346     result = u''.join(map(replace_insane, s))
347     if not is_id:
348         while '__' in result:
349             result = result.replace('__', '_')
350         result = result.strip('_')
351         # Common case of "Foreign band name - English song title"
352         if restricted and result.startswith('-_'):
353             result = result[2:]
354         if not result:
355             result = '_'
356     return result
357
358 def orderedSet(iterable):
359     """ Remove all duplicates from the input iterable """
360     res = []
361     for el in iterable:
362         if el not in res:
363             res.append(el)
364     return res
365
366 def unescapeHTML(s):
367     """
368     @param s a string
369     """
370     assert type(s) == type(u'')
371
372     result = re.sub(u'(?u)&(.+?);', htmlentity_transform, s)
373     return result
374
375 def encodeFilename(s):
376     """
377     @param s The name of the file
378     """
379
380     assert type(s) == type(u'')
381
382     # Python 3 has a Unicode API
383     if sys.version_info >= (3, 0):
384         return s
385
386     if sys.platform == 'win32' and sys.getwindowsversion()[0] >= 5:
387         # Pass u'' directly to use Unicode APIs on Windows 2000 and up
388         # (Detecting Windows NT 4 is tricky because 'major >= 4' would
389         # match Windows 9x series as well. Besides, NT 4 is obsolete.)
390         return s
391     else:
392         return s.encode(sys.getfilesystemencoding(), 'ignore')
393
394 class DownloadError(Exception):
395     """Download Error exception.
396
397     This exception may be thrown by FileDownloader objects if they are not
398     configured to continue on errors. They will contain the appropriate
399     error message.
400     """
401     pass
402
403
404 class SameFileError(Exception):
405     """Same File exception.
406
407     This exception will be thrown by FileDownloader objects if they detect
408     multiple files would have to be downloaded to the same file on disk.
409     """
410     pass
411
412
413 class PostProcessingError(Exception):
414     """Post Processing exception.
415
416     This exception may be raised by PostProcessor's .run() method to
417     indicate an error in the postprocessing task.
418     """
419     pass
420
421 class MaxDownloadsReached(Exception):
422     """ --max-downloads limit has been reached. """
423     pass
424
425
426 class UnavailableVideoError(Exception):
427     """Unavailable Format exception.
428
429     This exception will be thrown when a video is requested
430     in a format that is not available for that video.
431     """
432     pass
433
434
435 class ContentTooShortError(Exception):
436     """Content Too Short exception.
437
438     This exception may be raised by FileDownloader objects when a file they
439     download is too small for what the server announced first, indicating
440     the connection was probably interrupted.
441     """
442     # Both in bytes
443     downloaded = None
444     expected = None
445
446     def __init__(self, downloaded, expected):
447         self.downloaded = downloaded
448         self.expected = expected
449
450
451 class Trouble(Exception):
452     """Trouble helper exception
453
454     This is an exception to be handled with
455     FileDownloader.trouble
456     """
457
458 class YoutubeDLHandler(compat_urllib_request.HTTPHandler):
459     """Handler for HTTP requests and responses.
460
461     This class, when installed with an OpenerDirector, automatically adds
462     the standard headers to every HTTP request and handles gzipped and
463     deflated responses from web servers. If compression is to be avoided in
464     a particular request, the original request in the program code only has
465     to include the HTTP header "Youtubedl-No-Compression", which will be
466     removed before making the real request.
467
468     Part of this code was copied from:
469
470     http://techknack.net/python-urllib2-handlers/
471
472     Andrew Rowls, the author of that code, agreed to release it to the
473     public domain.
474     """
475
476     @staticmethod
477     def deflate(data):
478         try:
479             return zlib.decompress(data, -zlib.MAX_WBITS)
480         except zlib.error:
481             return zlib.decompress(data)
482
483     @staticmethod
484     def addinfourl_wrapper(stream, headers, url, code):
485         if hasattr(compat_urllib_request.addinfourl, 'getcode'):
486             return compat_urllib_request.addinfourl(stream, headers, url, code)
487         ret = compat_urllib_request.addinfourl(stream, headers, url)
488         ret.code = code
489         return ret
490
491     def http_request(self, req):
492         for h in std_headers:
493             if h in req.headers:
494                 del req.headers[h]
495             req.add_header(h, std_headers[h])
496         if 'Youtubedl-no-compression' in req.headers:
497             if 'Accept-encoding' in req.headers:
498                 del req.headers['Accept-encoding']
499             del req.headers['Youtubedl-no-compression']
500         return req
501
502     def http_response(self, req, resp):
503         old_resp = resp
504         # gzip
505         if resp.headers.get('Content-encoding', '') == 'gzip':
506             gz = gzip.GzipFile(fileobj=io.BytesIO(resp.read()), mode='r')
507             resp = self.addinfourl_wrapper(gz, old_resp.headers, old_resp.url, old_resp.code)
508             resp.msg = old_resp.msg
509         # deflate
510         if resp.headers.get('Content-encoding', '') == 'deflate':
511             gz = io.BytesIO(self.deflate(resp.read()))
512             resp = self.addinfourl_wrapper(gz, old_resp.headers, old_resp.url, old_resp.code)
513             resp.msg = old_resp.msg
514         return resp
515
516     https_request = http_request
517     https_response = http_response