Merge remote-tracking branch 'origin/master' into IE_cleanup
[youtube-dl] / youtube_dl / utils.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3
4 import gzip
5 import htmlentitydefs
6 import HTMLParser
7 import locale
8 import os
9 import re
10 import sys
11 import zlib
12 import urllib2
13 import email.utils
14 import json
15
16 try:
17         import cStringIO as StringIO
18 except ImportError:
19         import StringIO
20
21 std_headers = {
22         'User-Agent': 'Mozilla/5.0 (X11; Linux x86_64; rv:10.0) Gecko/20100101 Firefox/10.0',
23         'Accept-Charset': 'ISO-8859-1,utf-8;q=0.7,*;q=0.7',
24         'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8',
25         'Accept-Encoding': 'gzip, deflate',
26         'Accept-Language': 'en-us,en;q=0.5',
27 }
28
29 try:
30         u = unicode # Python 2
31 except NameError:
32         u = str
33
34 def preferredencoding():
35         """Get preferred encoding.
36
37         Returns the best encoding scheme for the system, based on
38         locale.getpreferredencoding() and some further tweaks.
39         """
40         try:
41                 pref = locale.getpreferredencoding()
42                 u'TEST'.encode(pref)
43         except:
44                 pref = 'UTF-8'
45
46         return pref
47
48
49 def htmlentity_transform(matchobj):
50         """Transforms an HTML entity to a character.
51
52         This function receives a match object and is intended to be used with
53         the re.sub() function.
54         """
55         entity = matchobj.group(1)
56
57         # Known non-numeric HTML entity
58         if entity in htmlentitydefs.name2codepoint:
59                 return unichr(htmlentitydefs.name2codepoint[entity])
60
61         mobj = re.match(ur'(?u)#(x?\d+)', entity)
62         if mobj is not None:
63                 numstr = mobj.group(1)
64                 if numstr.startswith(u'x'):
65                         base = 16
66                         numstr = u'0%s' % numstr
67                 else:
68                         base = 10
69                 return unichr(int(numstr, base))
70
71         # Unknown entity in name, return its literal representation
72         return (u'&%s;' % entity)
73
74 HTMLParser.locatestarttagend = re.compile(r"""<[a-zA-Z][-.a-zA-Z0-9:_]*(?:\s+(?:(?<=['"\s])[^\s/>][^\s/=>]*(?:\s*=+\s*(?:'[^']*'|"[^"]*"|(?!['"])[^>\s]*))?\s*)*)?\s*""", re.VERBOSE) # backport bugfix
75 class IDParser(HTMLParser.HTMLParser):
76         """Modified HTMLParser that isolates a tag with the specified id"""
77         def __init__(self, id):
78                 self.id = id
79                 self.result = None
80                 self.started = False
81                 self.depth = {}
82                 self.html = None
83                 self.watch_startpos = False
84                 self.error_count = 0
85                 HTMLParser.HTMLParser.__init__(self)
86
87         def error(self, message):
88                 if self.error_count > 10 or self.started:
89                         raise HTMLParser.HTMLParseError(message, self.getpos())
90                 self.rawdata = '\n'.join(self.html.split('\n')[self.getpos()[0]:]) # skip one line
91                 self.error_count += 1
92                 self.goahead(1)
93
94         def loads(self, html):
95                 self.html = html
96                 self.feed(html)
97                 self.close()
98
99         def handle_starttag(self, tag, attrs):
100                 attrs = dict(attrs)
101                 if self.started:
102                         self.find_startpos(None)
103                 if 'id' in attrs and attrs['id'] == self.id:
104                         self.result = [tag]
105                         self.started = True
106                         self.watch_startpos = True
107                 if self.started:
108                         if not tag in self.depth: self.depth[tag] = 0
109                         self.depth[tag] += 1
110
111         def handle_endtag(self, tag):
112                 if self.started:
113                         if tag in self.depth: self.depth[tag] -= 1
114                         if self.depth[self.result[0]] == 0:
115                                 self.started = False
116                                 self.result.append(self.getpos())
117
118         def find_startpos(self, x):
119                 """Needed to put the start position of the result (self.result[1])
120                 after the opening tag with the requested id"""
121                 if self.watch_startpos:
122                         self.watch_startpos = False
123                         self.result.append(self.getpos())
124         handle_entityref = handle_charref = handle_data = handle_comment = \
125         handle_decl = handle_pi = unknown_decl = find_startpos
126
127         def get_result(self):
128                 if self.result is None:
129                         return None
130                 if len(self.result) != 3:
131                         return None
132                 lines = self.html.split('\n')
133                 lines = lines[self.result[1][0]-1:self.result[2][0]]
134                 lines[0] = lines[0][self.result[1][1]:]
135                 if len(lines) == 1:
136                         lines[-1] = lines[-1][:self.result[2][1]-self.result[1][1]]
137                 lines[-1] = lines[-1][:self.result[2][1]]
138                 return '\n'.join(lines).strip()
139
140 def get_element_by_id(id, html):
141         """Return the content of the tag with the specified id in the passed HTML document"""
142         parser = IDParser(id)
143         try:
144                 parser.loads(html)
145         except HTMLParser.HTMLParseError:
146                 pass
147         return parser.get_result()
148
149
150 def clean_html(html):
151         """Clean an HTML snippet into a readable string"""
152         # Newline vs <br />
153         html = html.replace('\n', ' ')
154         html = re.sub('\s*<\s*br\s*/?\s*>\s*', '\n', html)
155         # Strip html tags
156         html = re.sub('<.*?>', '', html)
157         # Replace html entities
158         html = unescapeHTML(html)
159         return html
160
161
162 def sanitize_open(filename, open_mode):
163         """Try to open the given filename, and slightly tweak it if this fails.
164
165         Attempts to open the given filename. If this fails, it tries to change
166         the filename slightly, step by step, until it's either able to open it
167         or it fails and raises a final exception, like the standard open()
168         function.
169
170         It returns the tuple (stream, definitive_file_name).
171         """
172         try:
173                 if filename == u'-':
174                         if sys.platform == 'win32':
175                                 import msvcrt
176                                 msvcrt.setmode(sys.stdout.fileno(), os.O_BINARY)
177                         return (sys.stdout, filename)
178                 stream = open(encodeFilename(filename), open_mode)
179                 return (stream, filename)
180         except (IOError, OSError), err:
181                 # In case of error, try to remove win32 forbidden chars
182                 filename = re.sub(ur'[/<>:"\|\?\*]', u'#', filename)
183
184                 # An exception here should be caught in the caller
185                 stream = open(encodeFilename(filename), open_mode)
186                 return (stream, filename)
187
188
189 def timeconvert(timestr):
190         """Convert RFC 2822 defined time string into system timestamp"""
191         timestamp = None
192         timetuple = email.utils.parsedate_tz(timestr)
193         if timetuple is not None:
194                 timestamp = email.utils.mktime_tz(timetuple)
195         return timestamp
196
197 def sanitize_filename(s, restricted=False):
198         """Sanitizes a string so it could be used as part of a filename.
199         If restricted is set, use a stricter subset of allowed characters.
200         """
201         def replace_insane(char):
202                 if char == '?' or ord(char) < 32 or ord(char) == 127:
203                         return ''
204                 elif char == '"':
205                         return '' if restricted else '\''
206                 elif char == ':':
207                         return '_-' if restricted else ' -'
208                 elif char in '\\/|*<>':
209                         return '_'
210                 if restricted and (char in '!&\'' or char.isspace()):
211                         return '_'
212                 if restricted and ord(char) > 127:
213                         return '_'
214                 return char
215
216         result = u''.join(map(replace_insane, s))
217         while '__' in result:
218                 result = result.replace('__', '_')
219         result = result.strip('_')
220         # Common case of "Foreign band name - English song title"
221         if restricted and result.startswith('-_'):
222                 result = result[2:]
223         if not result:
224                 result = '_'
225         return result
226
227 def orderedSet(iterable):
228         """ Remove all duplicates from the input iterable """
229         res = []
230         for el in iterable:
231                 if el not in res:
232                         res.append(el)
233         return res
234
235 def unescapeHTML(s):
236         """
237         @param s a string
238         """
239         assert type(s) == type(u'')
240
241         result = re.sub(ur'(?u)&(.+?);', htmlentity_transform, s)
242         return result
243
244 def encodeFilename(s):
245         """
246         @param s The name of the file
247         """
248
249         assert type(s) == type(u'')
250
251         if sys.platform == 'win32' and sys.getwindowsversion()[0] >= 5:
252                 # Pass u'' directly to use Unicode APIs on Windows 2000 and up
253                 # (Detecting Windows NT 4 is tricky because 'major >= 4' would
254                 # match Windows 9x series as well. Besides, NT 4 is obsolete.)
255                 return s
256         else:
257                 return s.encode(sys.getfilesystemencoding(), 'ignore')
258
259 class DownloadError(Exception):
260         """Download Error exception.
261
262         This exception may be thrown by FileDownloader objects if they are not
263         configured to continue on errors. They will contain the appropriate
264         error message.
265         """
266         pass
267
268
269 class SameFileError(Exception):
270         """Same File exception.
271
272         This exception will be thrown by FileDownloader objects if they detect
273         multiple files would have to be downloaded to the same file on disk.
274         """
275         pass
276
277
278 class PostProcessingError(Exception):
279         """Post Processing exception.
280
281         This exception may be raised by PostProcessor's .run() method to
282         indicate an error in the postprocessing task.
283         """
284         pass
285
286 class MaxDownloadsReached(Exception):
287         """ --max-downloads limit has been reached. """
288         pass
289
290
291 class UnavailableVideoError(Exception):
292         """Unavailable Format exception.
293
294         This exception will be thrown when a video is requested
295         in a format that is not available for that video.
296         """
297         pass
298
299
300 class ContentTooShortError(Exception):
301         """Content Too Short exception.
302
303         This exception may be raised by FileDownloader objects when a file they
304         download is too small for what the server announced first, indicating
305         the connection was probably interrupted.
306         """
307         # Both in bytes
308         downloaded = None
309         expected = None
310
311         def __init__(self, downloaded, expected):
312                 self.downloaded = downloaded
313                 self.expected = expected
314
315
316 class Trouble(Exception):
317         """Trouble helper exception
318
319         This is an exception to be handled with
320         FileDownloader.trouble
321         """
322
323 class YoutubeDLHandler(urllib2.HTTPHandler):
324         """Handler for HTTP requests and responses.
325
326         This class, when installed with an OpenerDirector, automatically adds
327         the standard headers to every HTTP request and handles gzipped and
328         deflated responses from web servers. If compression is to be avoided in
329         a particular request, the original request in the program code only has
330         to include the HTTP header "Youtubedl-No-Compression", which will be
331         removed before making the real request.
332
333         Part of this code was copied from:
334
335         http://techknack.net/python-urllib2-handlers/
336
337         Andrew Rowls, the author of that code, agreed to release it to the
338         public domain.
339         """
340
341         @staticmethod
342         def deflate(data):
343                 try:
344                         return zlib.decompress(data, -zlib.MAX_WBITS)
345                 except zlib.error:
346                         return zlib.decompress(data)
347
348         @staticmethod
349         def addinfourl_wrapper(stream, headers, url, code):
350                 if hasattr(urllib2.addinfourl, 'getcode'):
351                         return urllib2.addinfourl(stream, headers, url, code)
352                 ret = urllib2.addinfourl(stream, headers, url)
353                 ret.code = code
354                 return ret
355
356         def http_request(self, req):
357                 for h in std_headers:
358                         if h in req.headers:
359                                 del req.headers[h]
360                         req.add_header(h, std_headers[h])
361                 if 'Youtubedl-no-compression' in req.headers:
362                         if 'Accept-encoding' in req.headers:
363                                 del req.headers['Accept-encoding']
364                         del req.headers['Youtubedl-no-compression']
365                 return req
366
367         def http_response(self, req, resp):
368                 old_resp = resp
369                 # gzip
370                 if resp.headers.get('Content-encoding', '') == 'gzip':
371                         gz = gzip.GzipFile(fileobj=StringIO.StringIO(resp.read()), mode='r')
372                         resp = self.addinfourl_wrapper(gz, old_resp.headers, old_resp.url, old_resp.code)
373                         resp.msg = old_resp.msg
374                 # deflate
375                 if resp.headers.get('Content-encoding', '') == 'deflate':
376                         gz = StringIO.StringIO(self.deflate(resp.read()))
377                         resp = self.addinfourl_wrapper(gz, old_resp.headers, old_resp.url, old_resp.code)
378                         resp.msg = old_resp.msg
379                 return resp