Refactor IDParser to search for elements by any attribute not just ID
[youtube-dl] / youtube_dl / utils.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3
4 import gzip
5 import io
6 import json
7 import locale
8 import os
9 import re
10 import sys
11 import zlib
12 import email.utils
13 import json
14
15 try:
16     import urllib.request as compat_urllib_request
17 except ImportError: # Python 2
18     import urllib2 as compat_urllib_request
19
20 try:
21     import urllib.error as compat_urllib_error
22 except ImportError: # Python 2
23     import urllib2 as compat_urllib_error
24
25 try:
26     import urllib.parse as compat_urllib_parse
27 except ImportError: # Python 2
28     import urllib as compat_urllib_parse
29
30 try:
31     from urllib.parse import urlparse as compat_urllib_parse_urlparse
32 except ImportError: # Python 2
33     from urlparse import urlparse as compat_urllib_parse_urlparse
34
35 try:
36     import http.cookiejar as compat_cookiejar
37 except ImportError: # Python 2
38     import cookielib as compat_cookiejar
39
40 try:
41     import html.entities as compat_html_entities
42 except ImportError: # Python 2
43     import htmlentitydefs as compat_html_entities
44
45 try:
46     import html.parser as compat_html_parser
47 except ImportError: # Python 2
48     import HTMLParser as compat_html_parser
49
50 try:
51     import http.client as compat_http_client
52 except ImportError: # Python 2
53     import httplib as compat_http_client
54
55 try:
56     from subprocess import DEVNULL
57     compat_subprocess_get_DEVNULL = lambda: DEVNULL
58 except ImportError:
59     compat_subprocess_get_DEVNULL = lambda: open(os.path.devnull, 'w')
60
61 try:
62     from urllib.parse import parse_qs as compat_parse_qs
63 except ImportError: # Python 2
64     # HACK: The following is the correct parse_qs implementation from cpython 3's stdlib.
65     # Python 2's version is apparently totally broken
66     def _unquote(string, encoding='utf-8', errors='replace'):
67         if string == '':
68             return string
69         res = string.split('%')
70         if len(res) == 1:
71             return string
72         if encoding is None:
73             encoding = 'utf-8'
74         if errors is None:
75             errors = 'replace'
76         # pct_sequence: contiguous sequence of percent-encoded bytes, decoded
77         pct_sequence = b''
78         string = res[0]
79         for item in res[1:]:
80             try:
81                 if not item:
82                     raise ValueError
83                 pct_sequence += item[:2].decode('hex')
84                 rest = item[2:]
85                 if not rest:
86                     # This segment was just a single percent-encoded character.
87                     # May be part of a sequence of code units, so delay decoding.
88                     # (Stored in pct_sequence).
89                     continue
90             except ValueError:
91                 rest = '%' + item
92             # Encountered non-percent-encoded characters. Flush the current
93             # pct_sequence.
94             string += pct_sequence.decode(encoding, errors) + rest
95             pct_sequence = b''
96         if pct_sequence:
97             # Flush the final pct_sequence
98             string += pct_sequence.decode(encoding, errors)
99         return string
100
101     def _parse_qsl(qs, keep_blank_values=False, strict_parsing=False,
102                 encoding='utf-8', errors='replace'):
103         qs, _coerce_result = qs, unicode
104         pairs = [s2 for s1 in qs.split('&') for s2 in s1.split(';')]
105         r = []
106         for name_value in pairs:
107             if not name_value and not strict_parsing:
108                 continue
109             nv = name_value.split('=', 1)
110             if len(nv) != 2:
111                 if strict_parsing:
112                     raise ValueError("bad query field: %r" % (name_value,))
113                 # Handle case of a control-name with no equal sign
114                 if keep_blank_values:
115                     nv.append('')
116                 else:
117                     continue
118             if len(nv[1]) or keep_blank_values:
119                 name = nv[0].replace('+', ' ')
120                 name = _unquote(name, encoding=encoding, errors=errors)
121                 name = _coerce_result(name)
122                 value = nv[1].replace('+', ' ')
123                 value = _unquote(value, encoding=encoding, errors=errors)
124                 value = _coerce_result(value)
125                 r.append((name, value))
126         return r
127
128     def compat_parse_qs(qs, keep_blank_values=False, strict_parsing=False,
129                 encoding='utf-8', errors='replace'):
130         parsed_result = {}
131         pairs = _parse_qsl(qs, keep_blank_values, strict_parsing,
132                         encoding=encoding, errors=errors)
133         for name, value in pairs:
134             if name in parsed_result:
135                 parsed_result[name].append(value)
136             else:
137                 parsed_result[name] = [value]
138         return parsed_result
139
140 try:
141     compat_str = unicode # Python 2
142 except NameError:
143     compat_str = str
144
145 try:
146     compat_chr = unichr # Python 2
147 except NameError:
148     compat_chr = chr
149
150 std_headers = {
151     'User-Agent': 'Mozilla/5.0 (X11; Linux x86_64; rv:10.0) Gecko/20100101 Firefox/10.0',
152     'Accept-Charset': 'ISO-8859-1,utf-8;q=0.7,*;q=0.7',
153     'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8',
154     'Accept-Encoding': 'gzip, deflate',
155     'Accept-Language': 'en-us,en;q=0.5',
156 }
157 def preferredencoding():
158     """Get preferred encoding.
159
160     Returns the best encoding scheme for the system, based on
161     locale.getpreferredencoding() and some further tweaks.
162     """
163     try:
164         pref = locale.getpreferredencoding()
165         u'TEST'.encode(pref)
166     except:
167         pref = 'UTF-8'
168
169     return pref
170
171 if sys.version_info < (3,0):
172     def compat_print(s):
173         print(s.encode(preferredencoding(), 'xmlcharrefreplace'))
174 else:
175     def compat_print(s):
176         assert type(s) == type(u'')
177         print(s)
178
179 # In Python 2.x, json.dump expects a bytestream.
180 # In Python 3.x, it writes to a character stream
181 if sys.version_info < (3,0):
182     def write_json_file(obj, fn):
183         with open(fn, 'wb') as f:
184             json.dump(obj, f)
185 else:
186     def write_json_file(obj, fn):
187         with open(fn, 'w', encoding='utf-8') as f:
188             json.dump(obj, f)
189
190
191 def htmlentity_transform(matchobj):
192     """Transforms an HTML entity to a character.
193
194     This function receives a match object and is intended to be used with
195     the re.sub() function.
196     """
197     entity = matchobj.group(1)
198
199     # Known non-numeric HTML entity
200     if entity in compat_html_entities.name2codepoint:
201         return compat_chr(compat_html_entities.name2codepoint[entity])
202
203     mobj = re.match(u'(?u)#(x?\\d+)', entity)
204     if mobj is not None:
205         numstr = mobj.group(1)
206         if numstr.startswith(u'x'):
207             base = 16
208             numstr = u'0%s' % numstr
209         else:
210             base = 10
211         return compat_chr(int(numstr, base))
212
213     # Unknown entity in name, return its literal representation
214     return (u'&%s;' % entity)
215
216 compat_html_parser.locatestarttagend = re.compile(r"""<[a-zA-Z][-.a-zA-Z0-9:_]*(?:\s+(?:(?<=['"\s])[^\s/>][^\s/=>]*(?:\s*=+\s*(?:'[^']*'|"[^"]*"|(?!['"])[^>\s]*))?\s*)*)?\s*""", re.VERBOSE) # backport bugfix
217 class AttrParser(compat_html_parser.HTMLParser):
218     """Modified HTMLParser that isolates a tag with the specified attribute"""
219     def __init__(self, attribute, value):
220         self.attribute = attribute
221         self.value = value
222         self.result = None
223         self.started = False
224         self.depth = {}
225         self.html = None
226         self.watch_startpos = False
227         self.error_count = 0
228         compat_html_parser.HTMLParser.__init__(self)
229
230     def error(self, message):
231         if self.error_count > 10 or self.started:
232             raise compat_html_parser.HTMLParseError(message, self.getpos())
233         self.rawdata = '\n'.join(self.html.split('\n')[self.getpos()[0]:]) # skip one line
234         self.error_count += 1
235         self.goahead(1)
236
237     def loads(self, html):
238         self.html = html
239         self.feed(html)
240         self.close()
241
242     def handle_starttag(self, tag, attrs):
243         attrs = dict(attrs)
244         if self.started:
245             self.find_startpos(None)
246         if self.attribute in attrs and attrs[self.attribute] == self.value:
247             self.result = [tag]
248             self.started = True
249             self.watch_startpos = True
250         if self.started:
251             if not tag in self.depth: self.depth[tag] = 0
252             self.depth[tag] += 1
253
254     def handle_endtag(self, tag):
255         if self.started:
256             if tag in self.depth: self.depth[tag] -= 1
257             if self.depth[self.result[0]] == 0:
258                 self.started = False
259                 self.result.append(self.getpos())
260
261     def find_startpos(self, x):
262         """Needed to put the start position of the result (self.result[1])
263         after the opening tag with the requested id"""
264         if self.watch_startpos:
265             self.watch_startpos = False
266             self.result.append(self.getpos())
267     handle_entityref = handle_charref = handle_data = handle_comment = \
268     handle_decl = handle_pi = unknown_decl = find_startpos
269
270     def get_result(self):
271         if self.result is None:
272             return None
273         if len(self.result) != 3:
274             return None
275         lines = self.html.split('\n')
276         lines = lines[self.result[1][0]-1:self.result[2][0]]
277         lines[0] = lines[0][self.result[1][1]:]
278         if len(lines) == 1:
279             lines[-1] = lines[-1][:self.result[2][1]-self.result[1][1]]
280         lines[-1] = lines[-1][:self.result[2][1]]
281         return '\n'.join(lines).strip()
282
283 def get_element_by_id(id, html):
284     """Return the content of the tag with the specified ID in the passed HTML document"""
285     return get_element_by_attribute("id", id, html)
286
287 def get_element_by_attribute(attribute, value, html):
288     """Return the content of the tag with the specified attribute in the passed HTML document"""
289     parser = AttrParser(attribute, value)
290     try:
291         parser.loads(html)
292     except compat_html_parser.HTMLParseError:
293         pass
294     return parser.get_result()
295
296
297 def clean_html(html):
298     """Clean an HTML snippet into a readable string"""
299     # Newline vs <br />
300     html = html.replace('\n', ' ')
301     html = re.sub('\s*<\s*br\s*/?\s*>\s*', '\n', html)
302     # Strip html tags
303     html = re.sub('<.*?>', '', html)
304     # Replace html entities
305     html = unescapeHTML(html)
306     return html
307
308
309 def sanitize_open(filename, open_mode):
310     """Try to open the given filename, and slightly tweak it if this fails.
311
312     Attempts to open the given filename. If this fails, it tries to change
313     the filename slightly, step by step, until it's either able to open it
314     or it fails and raises a final exception, like the standard open()
315     function.
316
317     It returns the tuple (stream, definitive_file_name).
318     """
319     try:
320         if filename == u'-':
321             if sys.platform == 'win32':
322                 import msvcrt
323                 msvcrt.setmode(sys.stdout.fileno(), os.O_BINARY)
324             return (sys.stdout, filename)
325         stream = open(encodeFilename(filename), open_mode)
326         return (stream, filename)
327     except (IOError, OSError) as err:
328         # In case of error, try to remove win32 forbidden chars
329         filename = re.sub(u'[/<>:"\\|\\\\?\\*]', u'#', filename)
330
331         # An exception here should be caught in the caller
332         stream = open(encodeFilename(filename), open_mode)
333         return (stream, filename)
334
335
336 def timeconvert(timestr):
337     """Convert RFC 2822 defined time string into system timestamp"""
338     timestamp = None
339     timetuple = email.utils.parsedate_tz(timestr)
340     if timetuple is not None:
341         timestamp = email.utils.mktime_tz(timetuple)
342     return timestamp
343
344 def sanitize_filename(s, restricted=False, is_id=False):
345     """Sanitizes a string so it could be used as part of a filename.
346     If restricted is set, use a stricter subset of allowed characters.
347     Set is_id if this is not an arbitrary string, but an ID that should be kept if possible
348     """
349     def replace_insane(char):
350         if char == '?' or ord(char) < 32 or ord(char) == 127:
351             return ''
352         elif char == '"':
353             return '' if restricted else '\''
354         elif char == ':':
355             return '_-' if restricted else ' -'
356         elif char in '\\/|*<>':
357             return '_'
358         if restricted and (char in '!&\'()[]{}$;`^,#' or char.isspace()):
359             return '_'
360         if restricted and ord(char) > 127:
361             return '_'
362         return char
363
364     result = u''.join(map(replace_insane, s))
365     if not is_id:
366         while '__' in result:
367             result = result.replace('__', '_')
368         result = result.strip('_')
369         # Common case of "Foreign band name - English song title"
370         if restricted and result.startswith('-_'):
371             result = result[2:]
372         if not result:
373             result = '_'
374     return result
375
376 def orderedSet(iterable):
377     """ Remove all duplicates from the input iterable """
378     res = []
379     for el in iterable:
380         if el not in res:
381             res.append(el)
382     return res
383
384 def unescapeHTML(s):
385     """
386     @param s a string
387     """
388     assert type(s) == type(u'')
389
390     result = re.sub(u'(?u)&(.+?);', htmlentity_transform, s)
391     return result
392
393 def encodeFilename(s):
394     """
395     @param s The name of the file
396     """
397
398     assert type(s) == type(u'')
399
400     # Python 3 has a Unicode API
401     if sys.version_info >= (3, 0):
402         return s
403
404     if sys.platform == 'win32' and sys.getwindowsversion()[0] >= 5:
405         # Pass u'' directly to use Unicode APIs on Windows 2000 and up
406         # (Detecting Windows NT 4 is tricky because 'major >= 4' would
407         # match Windows 9x series as well. Besides, NT 4 is obsolete.)
408         return s
409     else:
410         return s.encode(sys.getfilesystemencoding(), 'ignore')
411
412 class DownloadError(Exception):
413     """Download Error exception.
414
415     This exception may be thrown by FileDownloader objects if they are not
416     configured to continue on errors. They will contain the appropriate
417     error message.
418     """
419     pass
420
421
422 class SameFileError(Exception):
423     """Same File exception.
424
425     This exception will be thrown by FileDownloader objects if they detect
426     multiple files would have to be downloaded to the same file on disk.
427     """
428     pass
429
430
431 class PostProcessingError(Exception):
432     """Post Processing exception.
433
434     This exception may be raised by PostProcessor's .run() method to
435     indicate an error in the postprocessing task.
436     """
437     pass
438
439 class MaxDownloadsReached(Exception):
440     """ --max-downloads limit has been reached. """
441     pass
442
443
444 class UnavailableVideoError(Exception):
445     """Unavailable Format exception.
446
447     This exception will be thrown when a video is requested
448     in a format that is not available for that video.
449     """
450     pass
451
452
453 class ContentTooShortError(Exception):
454     """Content Too Short exception.
455
456     This exception may be raised by FileDownloader objects when a file they
457     download is too small for what the server announced first, indicating
458     the connection was probably interrupted.
459     """
460     # Both in bytes
461     downloaded = None
462     expected = None
463
464     def __init__(self, downloaded, expected):
465         self.downloaded = downloaded
466         self.expected = expected
467
468
469 class Trouble(Exception):
470     """Trouble helper exception
471
472     This is an exception to be handled with
473     FileDownloader.trouble
474     """
475
476 class YoutubeDLHandler(compat_urllib_request.HTTPHandler):
477     """Handler for HTTP requests and responses.
478
479     This class, when installed with an OpenerDirector, automatically adds
480     the standard headers to every HTTP request and handles gzipped and
481     deflated responses from web servers. If compression is to be avoided in
482     a particular request, the original request in the program code only has
483     to include the HTTP header "Youtubedl-No-Compression", which will be
484     removed before making the real request.
485
486     Part of this code was copied from:
487
488     http://techknack.net/python-urllib2-handlers/
489
490     Andrew Rowls, the author of that code, agreed to release it to the
491     public domain.
492     """
493
494     @staticmethod
495     def deflate(data):
496         try:
497             return zlib.decompress(data, -zlib.MAX_WBITS)
498         except zlib.error:
499             return zlib.decompress(data)
500
501     @staticmethod
502     def addinfourl_wrapper(stream, headers, url, code):
503         if hasattr(compat_urllib_request.addinfourl, 'getcode'):
504             return compat_urllib_request.addinfourl(stream, headers, url, code)
505         ret = compat_urllib_request.addinfourl(stream, headers, url)
506         ret.code = code
507         return ret
508
509     def http_request(self, req):
510         for h in std_headers:
511             if h in req.headers:
512                 del req.headers[h]
513             req.add_header(h, std_headers[h])
514         if 'Youtubedl-no-compression' in req.headers:
515             if 'Accept-encoding' in req.headers:
516                 del req.headers['Accept-encoding']
517             del req.headers['Youtubedl-no-compression']
518         return req
519
520     def http_response(self, req, resp):
521         old_resp = resp
522         # gzip
523         if resp.headers.get('Content-encoding', '') == 'gzip':
524             gz = gzip.GzipFile(fileobj=io.BytesIO(resp.read()), mode='r')
525             resp = self.addinfourl_wrapper(gz, old_resp.headers, old_resp.url, old_resp.code)
526             resp.msg = old_resp.msg
527         # deflate
528         if resp.headers.get('Content-encoding', '') == 'deflate':
529             gz = io.BytesIO(self.deflate(resp.read()))
530             resp = self.addinfourl_wrapper(gz, old_resp.headers, old_resp.url, old_resp.code)
531             resp.msg = old_resp.msg
532         return resp
533
534     https_request = http_request
535     https_response = http_response