some fixes, pulled the codename from the code
[youtube-dl] / youtube_dl / utils.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3
4 import gzip
5 import io
6 import json
7 import locale
8 import os
9 import re
10 import sys
11 import zlib
12 import email.utils
13 import json
14
15 try:
16     import urllib.request as compat_urllib_request
17 except ImportError: # Python 2
18     import urllib2 as compat_urllib_request
19
20 try:
21     import urllib.error as compat_urllib_error
22 except ImportError: # Python 2
23     import urllib2 as compat_urllib_error
24
25 try:
26     import urllib.parse as compat_urllib_parse
27 except ImportError: # Python 2
28     import urllib as compat_urllib_parse
29
30 try:
31     from urllib.parse import urlparse as compat_urllib_parse_urlparse
32 except ImportError: # Python 2
33     from urlparse import urlparse as compat_urllib_parse_urlparse
34
35 try:
36     import http.cookiejar as compat_cookiejar
37 except ImportError: # Python 2
38     import cookielib as compat_cookiejar
39
40 try:
41     import html.entities as compat_html_entities
42 except ImportError: # Python 2
43     import htmlentitydefs as compat_html_entities
44
45 try:
46     import html.parser as compat_html_parser
47 except ImportError: # Python 2
48     import HTMLParser as compat_html_parser
49
50 try:
51     import http.client as compat_http_client
52 except ImportError: # Python 2
53     import httplib as compat_http_client
54
55 try:
56     from subprocess import DEVNULL
57     compat_subprocess_get_DEVNULL = lambda: DEVNULL
58 except ImportError:
59     compat_subprocess_get_DEVNULL = lambda: open(os.path.devnull, 'w')
60
61 try:
62     from urllib.parse import parse_qs as compat_parse_qs
63 except ImportError: # Python 2
64     # HACK: The following is the correct parse_qs implementation from cpython 3's stdlib.
65     # Python 2's version is apparently totally broken
66     def _unquote(string, encoding='utf-8', errors='replace'):
67         if string == '':
68             return string
69         res = string.split('%')
70         if len(res) == 1:
71             return string
72         if encoding is None:
73             encoding = 'utf-8'
74         if errors is None:
75             errors = 'replace'
76         # pct_sequence: contiguous sequence of percent-encoded bytes, decoded
77         pct_sequence = b''
78         string = res[0]
79         for item in res[1:]:
80             try:
81                 if not item:
82                     raise ValueError
83                 pct_sequence += item[:2].decode('hex')
84                 rest = item[2:]
85                 if not rest:
86                     # This segment was just a single percent-encoded character.
87                     # May be part of a sequence of code units, so delay decoding.
88                     # (Stored in pct_sequence).
89                     continue
90             except ValueError:
91                 rest = '%' + item
92             # Encountered non-percent-encoded characters. Flush the current
93             # pct_sequence.
94             string += pct_sequence.decode(encoding, errors) + rest
95             pct_sequence = b''
96         if pct_sequence:
97             # Flush the final pct_sequence
98             string += pct_sequence.decode(encoding, errors)
99         return string
100
101     def _parse_qsl(qs, keep_blank_values=False, strict_parsing=False,
102                 encoding='utf-8', errors='replace'):
103         qs, _coerce_result = qs, unicode
104         pairs = [s2 for s1 in qs.split('&') for s2 in s1.split(';')]
105         r = []
106         for name_value in pairs:
107             if not name_value and not strict_parsing:
108                 continue
109             nv = name_value.split('=', 1)
110             if len(nv) != 2:
111                 if strict_parsing:
112                     raise ValueError("bad query field: %r" % (name_value,))
113                 # Handle case of a control-name with no equal sign
114                 if keep_blank_values:
115                     nv.append('')
116                 else:
117                     continue
118             if len(nv[1]) or keep_blank_values:
119                 name = nv[0].replace('+', ' ')
120                 name = _unquote(name, encoding=encoding, errors=errors)
121                 name = _coerce_result(name)
122                 value = nv[1].replace('+', ' ')
123                 value = _unquote(value, encoding=encoding, errors=errors)
124                 value = _coerce_result(value)
125                 r.append((name, value))
126         return r
127
128     def compat_parse_qs(qs, keep_blank_values=False, strict_parsing=False,
129                 encoding='utf-8', errors='replace'):
130         parsed_result = {}
131         pairs = _parse_qsl(qs, keep_blank_values, strict_parsing,
132                         encoding=encoding, errors=errors)
133         for name, value in pairs:
134             if name in parsed_result:
135                 parsed_result[name].append(value)
136             else:
137                 parsed_result[name] = [value]
138         return parsed_result
139
140 try:
141     compat_str = unicode # Python 2
142 except NameError:
143     compat_str = str
144
145 try:
146     compat_chr = unichr # Python 2
147 except NameError:
148     compat_chr = chr
149
150 std_headers = {
151     'User-Agent': 'Mozilla/5.0 (X11; Linux x86_64; rv:10.0) Gecko/20100101 Firefox/10.0',
152     'Accept-Charset': 'ISO-8859-1,utf-8;q=0.7,*;q=0.7',
153     'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8',
154     'Accept-Encoding': 'gzip, deflate',
155     'Accept-Language': 'en-us,en;q=0.5',
156 }
157
158 def preferredencoding():
159     """Get preferred encoding.
160
161     Returns the best encoding scheme for the system, based on
162     locale.getpreferredencoding() and some further tweaks.
163     """
164     try:
165         pref = locale.getpreferredencoding()
166         u'TEST'.encode(pref)
167     except:
168         pref = 'UTF-8'
169
170     return pref
171
172 if sys.version_info < (3,0):
173     def compat_print(s):
174         print(s.encode(preferredencoding(), 'xmlcharrefreplace'))
175 else:
176     def compat_print(s):
177         assert type(s) == type(u'')
178         print(s)
179
180 # In Python 2.x, json.dump expects a bytestream.
181 # In Python 3.x, it writes to a character stream
182 if sys.version_info < (3,0):
183     def write_json_file(obj, fn):
184         with open(fn, 'wb') as f:
185             json.dump(obj, f)
186 else:
187     def write_json_file(obj, fn):
188         with open(fn, 'w', encoding='utf-8') as f:
189             json.dump(obj, f)
190
191 # Some library functions return bytestring on 2.X and unicode on 3.X
192 def enforce_unicode(s, encoding='utf-8'):
193     if type(s) != type(u''):
194         return s.decode(encoding)
195     return s
196
197 def htmlentity_transform(matchobj):
198     """Transforms an HTML entity to a character.
199
200     This function receives a match object and is intended to be used with
201     the re.sub() function.
202     """
203     entity = matchobj.group(1)
204
205     # Known non-numeric HTML entity
206     if entity in compat_html_entities.name2codepoint:
207         return compat_chr(compat_html_entities.name2codepoint[entity])
208
209     mobj = re.match(u'(?u)#(x?\\d+)', entity)
210     if mobj is not None:
211         numstr = mobj.group(1)
212         if numstr.startswith(u'x'):
213             base = 16
214             numstr = u'0%s' % numstr
215         else:
216             base = 10
217         return compat_chr(int(numstr, base))
218
219     # Unknown entity in name, return its literal representation
220     return (u'&%s;' % entity)
221
222 compat_html_parser.locatestarttagend = re.compile(r"""<[a-zA-Z][-.a-zA-Z0-9:_]*(?:\s+(?:(?<=['"\s])[^\s/>][^\s/=>]*(?:\s*=+\s*(?:'[^']*'|"[^"]*"|(?!['"])[^>\s]*))?\s*)*)?\s*""", re.VERBOSE) # backport bugfix
223 class AttrParser(compat_html_parser.HTMLParser):
224     """Modified HTMLParser that isolates a tag with the specified attribute"""
225     def __init__(self, attribute, value):
226         self.attribute = attribute
227         self.value = value
228         self.result = None
229         self.started = False
230         self.depth = {}
231         self.html = None
232         self.watch_startpos = False
233         self.error_count = 0
234         compat_html_parser.HTMLParser.__init__(self)
235
236     def error(self, message):
237         if self.error_count > 10 or self.started:
238             raise compat_html_parser.HTMLParseError(message, self.getpos())
239         self.rawdata = '\n'.join(self.html.split('\n')[self.getpos()[0]:]) # skip one line
240         self.error_count += 1
241         self.goahead(1)
242
243     def loads(self, html):
244         self.html = html
245         self.feed(html)
246         self.close()
247
248     def handle_starttag(self, tag, attrs):
249         attrs = dict(attrs)
250         if self.started:
251             self.find_startpos(None)
252         if self.attribute in attrs and attrs[self.attribute] == self.value:
253             self.result = [tag]
254             self.started = True
255             self.watch_startpos = True
256         if self.started:
257             if not tag in self.depth: self.depth[tag] = 0
258             self.depth[tag] += 1
259
260     def handle_endtag(self, tag):
261         if self.started:
262             if tag in self.depth: self.depth[tag] -= 1
263             if self.depth[self.result[0]] == 0:
264                 self.started = False
265                 self.result.append(self.getpos())
266
267     def find_startpos(self, x):
268         """Needed to put the start position of the result (self.result[1])
269         after the opening tag with the requested id"""
270         if self.watch_startpos:
271             self.watch_startpos = False
272             self.result.append(self.getpos())
273     handle_entityref = handle_charref = handle_data = handle_comment = \
274     handle_decl = handle_pi = unknown_decl = find_startpos
275
276     def get_result(self):
277         if self.result is None:
278             return None
279         if len(self.result) != 3:
280             return None
281         lines = self.html.split('\n')
282         lines = lines[self.result[1][0]-1:self.result[2][0]]
283         lines[0] = lines[0][self.result[1][1]:]
284         if len(lines) == 1:
285             lines[-1] = lines[-1][:self.result[2][1]-self.result[1][1]]
286         lines[-1] = lines[-1][:self.result[2][1]]
287         return '\n'.join(lines).strip()
288
289 def get_element_by_id(id, html):
290     """Return the content of the tag with the specified ID in the passed HTML document"""
291     return get_element_by_attribute("id", id, html)
292
293 def get_element_by_attribute(attribute, value, html):
294     """Return the content of the tag with the specified attribute in the passed HTML document"""
295     parser = AttrParser(attribute, value)
296     try:
297         parser.loads(html)
298     except compat_html_parser.HTMLParseError:
299         pass
300     return parser.get_result()
301
302
303 def clean_html(html):
304     """Clean an HTML snippet into a readable string"""
305     # Newline vs <br />
306     html = html.replace('\n', ' ')
307     html = re.sub(r'\s*<\s*br\s*/?\s*>\s*', '\n', html)
308     html = re.sub(r'<\s*/\s*p\s*>\s*<\s*p[^>]*>', '\n', html)
309     # Strip html tags
310     html = re.sub('<.*?>', '', html)
311     # Replace html entities
312     html = unescapeHTML(html)
313     return html
314
315
316 def sanitize_open(filename, open_mode):
317     """Try to open the given filename, and slightly tweak it if this fails.
318
319     Attempts to open the given filename. If this fails, it tries to change
320     the filename slightly, step by step, until it's either able to open it
321     or it fails and raises a final exception, like the standard open()
322     function.
323
324     It returns the tuple (stream, definitive_file_name).
325     """
326     try:
327         if filename == u'-':
328             if sys.platform == 'win32':
329                 import msvcrt
330                 msvcrt.setmode(sys.stdout.fileno(), os.O_BINARY)
331             return (sys.stdout, filename)
332         stream = open(encodeFilename(filename), open_mode)
333         return (stream, filename)
334     except (IOError, OSError) as err:
335         # In case of error, try to remove win32 forbidden chars
336         filename = re.sub(u'[/<>:"\\|\\\\?\\*]', u'#', filename)
337
338         # An exception here should be caught in the caller
339         stream = open(encodeFilename(filename), open_mode)
340         return (stream, filename)
341
342
343 def timeconvert(timestr):
344     """Convert RFC 2822 defined time string into system timestamp"""
345     timestamp = None
346     timetuple = email.utils.parsedate_tz(timestr)
347     if timetuple is not None:
348         timestamp = email.utils.mktime_tz(timetuple)
349     return timestamp
350
351 def sanitize_filename(s, restricted=False, is_id=False):
352     """Sanitizes a string so it could be used as part of a filename.
353     If restricted is set, use a stricter subset of allowed characters.
354     Set is_id if this is not an arbitrary string, but an ID that should be kept if possible
355     """
356     def replace_insane(char):
357         if char == '?' or ord(char) < 32 or ord(char) == 127:
358             return ''
359         elif char == '"':
360             return '' if restricted else '\''
361         elif char == ':':
362             return '_-' if restricted else ' -'
363         elif char in '\\/|*<>':
364             return '_'
365         if restricted and (char in '!&\'()[]{}$;`^,#' or char.isspace()):
366             return '_'
367         if restricted and ord(char) > 127:
368             return '_'
369         return char
370
371     result = u''.join(map(replace_insane, s))
372     if not is_id:
373         while '__' in result:
374             result = result.replace('__', '_')
375         result = result.strip('_')
376         # Common case of "Foreign band name - English song title"
377         if restricted and result.startswith('-_'):
378             result = result[2:]
379         if not result:
380             result = '_'
381     return result
382
383 def orderedSet(iterable):
384     """ Remove all duplicates from the input iterable """
385     res = []
386     for el in iterable:
387         if el not in res:
388             res.append(el)
389     return res
390
391 def unescapeHTML(s):
392     """
393     @param s a string
394     """
395     assert type(s) == type(u'')
396
397     result = re.sub(u'(?u)&(.+?);', htmlentity_transform, s)
398     return result
399
400 def encodeFilename(s):
401     """
402     @param s The name of the file
403     """
404
405     assert type(s) == type(u'')
406
407     # Python 3 has a Unicode API
408     if sys.version_info >= (3, 0):
409         return s
410
411     if sys.platform == 'win32' and sys.getwindowsversion()[0] >= 5:
412         # Pass u'' directly to use Unicode APIs on Windows 2000 and up
413         # (Detecting Windows NT 4 is tricky because 'major >= 4' would
414         # match Windows 9x series as well. Besides, NT 4 is obsolete.)
415         return s
416     else:
417         return s.encode(sys.getfilesystemencoding(), 'ignore')
418
419 def rsa_verify(message, signature, key):
420     from struct import pack
421     from hashlib import sha256
422     from sys import version_info
423     def b(x):
424         if version_info[0] == 2: return x
425         else: return x.encode('latin1')
426     assert(type(message) == type(b('')))
427     block_size = 0
428     n = key[0]
429     while n:
430         block_size += 1
431         n >>= 8
432     signature = pow(int(signature, 16), key[1], key[0])
433     raw_bytes = []
434     while signature:
435         raw_bytes.insert(0, pack("B", signature & 0xFF))
436         signature >>= 8
437     signature = (block_size - len(raw_bytes)) * b('\x00') + b('').join(raw_bytes)
438     if signature[0:2] != b('\x00\x01'): return False
439     signature = signature[2:]
440     if not b('\x00') in signature: return False
441     signature = signature[signature.index(b('\x00'))+1:]
442     if not signature.startswith(b('\x30\x31\x30\x0D\x06\x09\x60\x86\x48\x01\x65\x03\x04\x02\x01\x05\x00\x04\x20')): return False
443     signature = signature[19:]
444     if signature != sha256(message).digest(): return False
445     return True
446
447 class DownloadError(Exception):
448     """Download Error exception.
449
450     This exception may be thrown by FileDownloader objects if they are not
451     configured to continue on errors. They will contain the appropriate
452     error message.
453     """
454     pass
455
456
457 class SameFileError(Exception):
458     """Same File exception.
459
460     This exception will be thrown by FileDownloader objects if they detect
461     multiple files would have to be downloaded to the same file on disk.
462     """
463     pass
464
465
466 class PostProcessingError(Exception):
467     """Post Processing exception.
468
469     This exception may be raised by PostProcessor's .run() method to
470     indicate an error in the postprocessing task.
471     """
472     pass
473
474 class MaxDownloadsReached(Exception):
475     """ --max-downloads limit has been reached. """
476     pass
477
478
479 class UnavailableVideoError(Exception):
480     """Unavailable Format exception.
481
482     This exception will be thrown when a video is requested
483     in a format that is not available for that video.
484     """
485     pass
486
487
488 class ContentTooShortError(Exception):
489     """Content Too Short exception.
490
491     This exception may be raised by FileDownloader objects when a file they
492     download is too small for what the server announced first, indicating
493     the connection was probably interrupted.
494     """
495     # Both in bytes
496     downloaded = None
497     expected = None
498
499     def __init__(self, downloaded, expected):
500         self.downloaded = downloaded
501         self.expected = expected
502
503 class YoutubeDLHandler(compat_urllib_request.HTTPHandler):
504     """Handler for HTTP requests and responses.
505
506     This class, when installed with an OpenerDirector, automatically adds
507     the standard headers to every HTTP request and handles gzipped and
508     deflated responses from web servers. If compression is to be avoided in
509     a particular request, the original request in the program code only has
510     to include the HTTP header "Youtubedl-No-Compression", which will be
511     removed before making the real request.
512
513     Part of this code was copied from:
514
515     http://techknack.net/python-urllib2-handlers/
516
517     Andrew Rowls, the author of that code, agreed to release it to the
518     public domain.
519     """
520
521     @staticmethod
522     def deflate(data):
523         try:
524             return zlib.decompress(data, -zlib.MAX_WBITS)
525         except zlib.error:
526             return zlib.decompress(data)
527
528     @staticmethod
529     def addinfourl_wrapper(stream, headers, url, code):
530         if hasattr(compat_urllib_request.addinfourl, 'getcode'):
531             return compat_urllib_request.addinfourl(stream, headers, url, code)
532         ret = compat_urllib_request.addinfourl(stream, headers, url)
533         ret.code = code
534         return ret
535
536     def http_request(self, req):
537         for h in std_headers:
538             if h in req.headers:
539                 del req.headers[h]
540             req.add_header(h, std_headers[h])
541         if 'Youtubedl-no-compression' in req.headers:
542             if 'Accept-encoding' in req.headers:
543                 del req.headers['Accept-encoding']
544             del req.headers['Youtubedl-no-compression']
545         return req
546
547     def http_response(self, req, resp):
548         old_resp = resp
549         # gzip
550         if resp.headers.get('Content-encoding', '') == 'gzip':
551             gz = gzip.GzipFile(fileobj=io.BytesIO(resp.read()), mode='r')
552             resp = self.addinfourl_wrapper(gz, old_resp.headers, old_resp.url, old_resp.code)
553             resp.msg = old_resp.msg
554         # deflate
555         if resp.headers.get('Content-encoding', '') == 'deflate':
556             gz = io.BytesIO(self.deflate(resp.read()))
557             resp = self.addinfourl_wrapper(gz, old_resp.headers, old_resp.url, old_resp.code)
558             resp.msg = old_resp.msg
559         return resp
560
561     https_request = http_request
562     https_response = http_response