[cache] Fix writing to paths with unicode characters
[youtube-dl] / youtube_dl / utils.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3
4 from __future__ import unicode_literals
5
6 import calendar
7 import codecs
8 import contextlib
9 import ctypes
10 import datetime
11 import email.utils
12 import errno
13 import gzip
14 import itertools
15 import io
16 import json
17 import locale
18 import math
19 import os
20 import pipes
21 import platform
22 import re
23 import ssl
24 import socket
25 import struct
26 import subprocess
27 import sys
28 import tempfile
29 import traceback
30 import xml.etree.ElementTree
31 import zlib
32
33 from .compat import (
34     compat_chr,
35     compat_getenv,
36     compat_html_entities,
37     compat_parse_qs,
38     compat_str,
39     compat_urllib_error,
40     compat_urllib_parse,
41     compat_urllib_parse_urlparse,
42     compat_urllib_request,
43     compat_urlparse,
44 )
45
46
47 # This is not clearly defined otherwise
48 compiled_regex_type = type(re.compile(''))
49
50 std_headers = {
51     'User-Agent': 'Mozilla/5.0 (X11; Linux x86_64; rv:10.0) Gecko/20100101 Firefox/10.0 (Chrome)',
52     'Accept-Charset': 'ISO-8859-1,utf-8;q=0.7,*;q=0.7',
53     'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8',
54     'Accept-Encoding': 'gzip, deflate',
55     'Accept-Language': 'en-us,en;q=0.5',
56 }
57
58 def preferredencoding():
59     """Get preferred encoding.
60
61     Returns the best encoding scheme for the system, based on
62     locale.getpreferredencoding() and some further tweaks.
63     """
64     try:
65         pref = locale.getpreferredencoding()
66         'TEST'.encode(pref)
67     except:
68         pref = 'UTF-8'
69
70     return pref
71
72
73 def write_json_file(obj, fn):
74     """ Encode obj as JSON and write it to fn, atomically """
75
76     fn = encodeFilename(fn)
77     if sys.version_info < (3, 0):
78         encoding = get_filesystem_encoding()
79         # os.path.basename returns a bytes object, but NamedTemporaryFile
80         # will fail if the filename contains non ascii characters unless we
81         # use a unicode object
82         path_basename = lambda f: os.path.basename(fn).decode(encoding)
83         # the same for os.path.dirname
84         path_dirname = lambda f: os.path.dirname(fn).decode(encoding)
85     else:
86         path_basename = os.path.basename
87         path_dirname = os.path.dirname
88
89     args = {
90         'suffix': '.tmp',
91         'prefix': path_basename(fn) + '.',
92         'dir': path_dirname(fn),
93         'delete': False,
94     }
95
96     # In Python 2.x, json.dump expects a bytestream.
97     # In Python 3.x, it writes to a character stream
98     if sys.version_info < (3, 0):
99         args['mode'] = 'wb'
100     else:
101         args.update({
102             'mode': 'w',
103             'encoding': 'utf-8',
104         })
105
106     tf = tempfile.NamedTemporaryFile(**args)
107
108     try:
109         with tf:
110             json.dump(obj, tf)
111         os.rename(tf.name, fn)
112     except:
113         try:
114             os.remove(tf.name)
115         except OSError:
116             pass
117         raise
118
119
120 if sys.version_info >= (2, 7):
121     def find_xpath_attr(node, xpath, key, val):
122         """ Find the xpath xpath[@key=val] """
123         assert re.match(r'^[a-zA-Z-]+$', key)
124         assert re.match(r'^[a-zA-Z0-9@\s:._-]*$', val)
125         expr = xpath + u"[@%s='%s']" % (key, val)
126         return node.find(expr)
127 else:
128     def find_xpath_attr(node, xpath, key, val):
129         # Here comes the crazy part: In 2.6, if the xpath is a unicode,
130         # .//node does not match if a node is a direct child of . !
131         if isinstance(xpath, unicode):
132             xpath = xpath.encode('ascii')
133
134         for f in node.findall(xpath):
135             if f.attrib.get(key) == val:
136                 return f
137         return None
138
139 # On python2.6 the xml.etree.ElementTree.Element methods don't support
140 # the namespace parameter
141 def xpath_with_ns(path, ns_map):
142     components = [c.split(':') for c in path.split('/')]
143     replaced = []
144     for c in components:
145         if len(c) == 1:
146             replaced.append(c[0])
147         else:
148             ns, tag = c
149             replaced.append('{%s}%s' % (ns_map[ns], tag))
150     return '/'.join(replaced)
151
152
153 def xpath_text(node, xpath, name=None, fatal=False):
154     if sys.version_info < (2, 7):  # Crazy 2.6
155         xpath = xpath.encode('ascii')
156
157     n = node.find(xpath)
158     if n is None:
159         if fatal:
160             name = xpath if name is None else name
161             raise ExtractorError('Could not find XML element %s' % name)
162         else:
163             return None
164     return n.text
165
166
167 def get_element_by_id(id, html):
168     """Return the content of the tag with the specified ID in the passed HTML document"""
169     return get_element_by_attribute("id", id, html)
170
171
172 def get_element_by_attribute(attribute, value, html):
173     """Return the content of the tag with the specified attribute in the passed HTML document"""
174
175     m = re.search(r'''(?xs)
176         <([a-zA-Z0-9:._-]+)
177          (?:\s+[a-zA-Z0-9:._-]+(?:=[a-zA-Z0-9:._-]+|="[^"]+"|='[^']+'))*?
178          \s+%s=['"]?%s['"]?
179          (?:\s+[a-zA-Z0-9:._-]+(?:=[a-zA-Z0-9:._-]+|="[^"]+"|='[^']+'))*?
180         \s*>
181         (?P<content>.*?)
182         </\1>
183     ''' % (re.escape(attribute), re.escape(value)), html)
184
185     if not m:
186         return None
187     res = m.group('content')
188
189     if res.startswith('"') or res.startswith("'"):
190         res = res[1:-1]
191
192     return unescapeHTML(res)
193
194
195 def clean_html(html):
196     """Clean an HTML snippet into a readable string"""
197     # Newline vs <br />
198     html = html.replace('\n', ' ')
199     html = re.sub(r'\s*<\s*br\s*/?\s*>\s*', '\n', html)
200     html = re.sub(r'<\s*/\s*p\s*>\s*<\s*p[^>]*>', '\n', html)
201     # Strip html tags
202     html = re.sub('<.*?>', '', html)
203     # Replace html entities
204     html = unescapeHTML(html)
205     return html.strip()
206
207
208 def sanitize_open(filename, open_mode):
209     """Try to open the given filename, and slightly tweak it if this fails.
210
211     Attempts to open the given filename. If this fails, it tries to change
212     the filename slightly, step by step, until it's either able to open it
213     or it fails and raises a final exception, like the standard open()
214     function.
215
216     It returns the tuple (stream, definitive_file_name).
217     """
218     try:
219         if filename == '-':
220             if sys.platform == 'win32':
221                 import msvcrt
222                 msvcrt.setmode(sys.stdout.fileno(), os.O_BINARY)
223             return (sys.stdout.buffer if hasattr(sys.stdout, 'buffer') else sys.stdout, filename)
224         stream = open(encodeFilename(filename), open_mode)
225         return (stream, filename)
226     except (IOError, OSError) as err:
227         if err.errno in (errno.EACCES,):
228             raise
229
230         # In case of error, try to remove win32 forbidden chars
231         alt_filename = os.path.join(
232                         re.sub('[/<>:"\\|\\\\?\\*]', '#', path_part)
233                         for path_part in os.path.split(filename)
234                        )
235         if alt_filename == filename:
236             raise
237         else:
238             # An exception here should be caught in the caller
239             stream = open(encodeFilename(filename), open_mode)
240             return (stream, alt_filename)
241
242
243 def timeconvert(timestr):
244     """Convert RFC 2822 defined time string into system timestamp"""
245     timestamp = None
246     timetuple = email.utils.parsedate_tz(timestr)
247     if timetuple is not None:
248         timestamp = email.utils.mktime_tz(timetuple)
249     return timestamp
250
251 def sanitize_filename(s, restricted=False, is_id=False):
252     """Sanitizes a string so it could be used as part of a filename.
253     If restricted is set, use a stricter subset of allowed characters.
254     Set is_id if this is not an arbitrary string, but an ID that should be kept if possible
255     """
256     def replace_insane(char):
257         if char == '?' or ord(char) < 32 or ord(char) == 127:
258             return ''
259         elif char == '"':
260             return '' if restricted else '\''
261         elif char == ':':
262             return '_-' if restricted else ' -'
263         elif char in '\\/|*<>':
264             return '_'
265         if restricted and (char in '!&\'()[]{}$;`^,#' or char.isspace()):
266             return '_'
267         if restricted and ord(char) > 127:
268             return '_'
269         return char
270
271     result = ''.join(map(replace_insane, s))
272     if not is_id:
273         while '__' in result:
274             result = result.replace('__', '_')
275         result = result.strip('_')
276         # Common case of "Foreign band name - English song title"
277         if restricted and result.startswith('-_'):
278             result = result[2:]
279         if not result:
280             result = '_'
281     return result
282
283 def orderedSet(iterable):
284     """ Remove all duplicates from the input iterable """
285     res = []
286     for el in iterable:
287         if el not in res:
288             res.append(el)
289     return res
290
291
292 def _htmlentity_transform(entity):
293     """Transforms an HTML entity to a character."""
294     # Known non-numeric HTML entity
295     if entity in compat_html_entities.name2codepoint:
296         return compat_chr(compat_html_entities.name2codepoint[entity])
297
298     mobj = re.match(r'#(x?[0-9]+)', entity)
299     if mobj is not None:
300         numstr = mobj.group(1)
301         if numstr.startswith('x'):
302             base = 16
303             numstr = '0%s' % numstr
304         else:
305             base = 10
306         return compat_chr(int(numstr, base))
307
308     # Unknown entity in name, return its literal representation
309     return ('&%s;' % entity)
310
311
312 def unescapeHTML(s):
313     if s is None:
314         return None
315     assert type(s) == compat_str
316
317     return re.sub(
318         r'&([^;]+);', lambda m: _htmlentity_transform(m.group(1)), s)
319
320
321 def encodeFilename(s, for_subprocess=False):
322     """
323     @param s The name of the file
324     """
325
326     assert type(s) == compat_str
327
328     # Python 3 has a Unicode API
329     if sys.version_info >= (3, 0):
330         return s
331
332     if sys.platform == 'win32' and sys.getwindowsversion()[0] >= 5:
333         # Pass '' directly to use Unicode APIs on Windows 2000 and up
334         # (Detecting Windows NT 4 is tricky because 'major >= 4' would
335         # match Windows 9x series as well. Besides, NT 4 is obsolete.)
336         if not for_subprocess:
337             return s
338         else:
339             # For subprocess calls, encode with locale encoding
340             # Refer to http://stackoverflow.com/a/9951851/35070
341             encoding = preferredencoding()
342     else:
343         encoding = sys.getfilesystemencoding()
344     if encoding is None:
345         encoding = 'utf-8'
346     return s.encode(encoding, 'ignore')
347
348
349 def encodeArgument(s):
350     if not isinstance(s, compat_str):
351         # Legacy code that uses byte strings
352         # Uncomment the following line after fixing all post processors
353         #assert False, 'Internal error: %r should be of type %r, is %r' % (s, compat_str, type(s))
354         s = s.decode('ascii')
355     return encodeFilename(s, True)
356
357
358 def decodeOption(optval):
359     if optval is None:
360         return optval
361     if isinstance(optval, bytes):
362         optval = optval.decode(preferredencoding())
363
364     assert isinstance(optval, compat_str)
365     return optval
366
367 def formatSeconds(secs):
368     if secs > 3600:
369         return '%d:%02d:%02d' % (secs // 3600, (secs % 3600) // 60, secs % 60)
370     elif secs > 60:
371         return '%d:%02d' % (secs // 60, secs % 60)
372     else:
373         return '%d' % secs
374
375
376 def make_HTTPS_handler(opts_no_check_certificate, **kwargs):
377     if sys.version_info < (3, 2):
378         import httplib
379
380         class HTTPSConnectionV3(httplib.HTTPSConnection):
381             def __init__(self, *args, **kwargs):
382                 httplib.HTTPSConnection.__init__(self, *args, **kwargs)
383
384             def connect(self):
385                 sock = socket.create_connection((self.host, self.port), self.timeout)
386                 if getattr(self, '_tunnel_host', False):
387                     self.sock = sock
388                     self._tunnel()
389                 try:
390                     self.sock = ssl.wrap_socket(sock, self.key_file, self.cert_file, ssl_version=ssl.PROTOCOL_TLSv1)
391                 except ssl.SSLError:
392                     self.sock = ssl.wrap_socket(sock, self.key_file, self.cert_file, ssl_version=ssl.PROTOCOL_SSLv23)
393
394         class HTTPSHandlerV3(compat_urllib_request.HTTPSHandler):
395             def https_open(self, req):
396                 return self.do_open(HTTPSConnectionV3, req)
397         return HTTPSHandlerV3(**kwargs)
398     elif hasattr(ssl, 'create_default_context'):  # Python >= 3.4
399         context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
400         context.options &= ~ssl.OP_NO_SSLv3  # Allow older, not-as-secure SSLv3
401         if opts_no_check_certificate:
402             context.verify_mode = ssl.CERT_NONE
403         return compat_urllib_request.HTTPSHandler(context=context, **kwargs)
404     else:  # Python < 3.4
405         context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
406         context.verify_mode = (ssl.CERT_NONE
407                                if opts_no_check_certificate
408                                else ssl.CERT_REQUIRED)
409         context.set_default_verify_paths()
410         try:
411             context.load_default_certs()
412         except AttributeError:
413             pass  # Python < 3.4
414         return compat_urllib_request.HTTPSHandler(context=context, **kwargs)
415
416 class ExtractorError(Exception):
417     """Error during info extraction."""
418     def __init__(self, msg, tb=None, expected=False, cause=None, video_id=None):
419         """ tb, if given, is the original traceback (so that it can be printed out).
420         If expected is set, this is a normal error message and most likely not a bug in youtube-dl.
421         """
422
423         if sys.exc_info()[0] in (compat_urllib_error.URLError, socket.timeout, UnavailableVideoError):
424             expected = True
425         if video_id is not None:
426             msg = video_id + ': ' + msg
427         if cause:
428             msg += ' (caused by %r)' % cause
429         if not expected:
430             msg = msg + '; please report this issue on https://yt-dl.org/bug . Be sure to call youtube-dl with the --verbose flag and include its complete output. Make sure you are using the latest version; type  youtube-dl -U  to update.'
431         super(ExtractorError, self).__init__(msg)
432
433         self.traceback = tb
434         self.exc_info = sys.exc_info()  # preserve original exception
435         self.cause = cause
436         self.video_id = video_id
437
438     def format_traceback(self):
439         if self.traceback is None:
440             return None
441         return ''.join(traceback.format_tb(self.traceback))
442
443
444 class RegexNotFoundError(ExtractorError):
445     """Error when a regex didn't match"""
446     pass
447
448
449 class DownloadError(Exception):
450     """Download Error exception.
451
452     This exception may be thrown by FileDownloader objects if they are not
453     configured to continue on errors. They will contain the appropriate
454     error message.
455     """
456     def __init__(self, msg, exc_info=None):
457         """ exc_info, if given, is the original exception that caused the trouble (as returned by sys.exc_info()). """
458         super(DownloadError, self).__init__(msg)
459         self.exc_info = exc_info
460
461
462 class SameFileError(Exception):
463     """Same File exception.
464
465     This exception will be thrown by FileDownloader objects if they detect
466     multiple files would have to be downloaded to the same file on disk.
467     """
468     pass
469
470
471 class PostProcessingError(Exception):
472     """Post Processing exception.
473
474     This exception may be raised by PostProcessor's .run() method to
475     indicate an error in the postprocessing task.
476     """
477     def __init__(self, msg):
478         self.msg = msg
479
480 class MaxDownloadsReached(Exception):
481     """ --max-downloads limit has been reached. """
482     pass
483
484
485 class UnavailableVideoError(Exception):
486     """Unavailable Format exception.
487
488     This exception will be thrown when a video is requested
489     in a format that is not available for that video.
490     """
491     pass
492
493
494 class ContentTooShortError(Exception):
495     """Content Too Short exception.
496
497     This exception may be raised by FileDownloader objects when a file they
498     download is too small for what the server announced first, indicating
499     the connection was probably interrupted.
500     """
501     # Both in bytes
502     downloaded = None
503     expected = None
504
505     def __init__(self, downloaded, expected):
506         self.downloaded = downloaded
507         self.expected = expected
508
509 class YoutubeDLHandler(compat_urllib_request.HTTPHandler):
510     """Handler for HTTP requests and responses.
511
512     This class, when installed with an OpenerDirector, automatically adds
513     the standard headers to every HTTP request and handles gzipped and
514     deflated responses from web servers. If compression is to be avoided in
515     a particular request, the original request in the program code only has
516     to include the HTTP header "Youtubedl-No-Compression", which will be
517     removed before making the real request.
518
519     Part of this code was copied from:
520
521     http://techknack.net/python-urllib2-handlers/
522
523     Andrew Rowls, the author of that code, agreed to release it to the
524     public domain.
525     """
526
527     @staticmethod
528     def deflate(data):
529         try:
530             return zlib.decompress(data, -zlib.MAX_WBITS)
531         except zlib.error:
532             return zlib.decompress(data)
533
534     @staticmethod
535     def addinfourl_wrapper(stream, headers, url, code):
536         if hasattr(compat_urllib_request.addinfourl, 'getcode'):
537             return compat_urllib_request.addinfourl(stream, headers, url, code)
538         ret = compat_urllib_request.addinfourl(stream, headers, url)
539         ret.code = code
540         return ret
541
542     def http_request(self, req):
543         for h, v in std_headers.items():
544             if h not in req.headers:
545                 req.add_header(h, v)
546         if 'Youtubedl-no-compression' in req.headers:
547             if 'Accept-encoding' in req.headers:
548                 del req.headers['Accept-encoding']
549             del req.headers['Youtubedl-no-compression']
550         if 'Youtubedl-user-agent' in req.headers:
551             if 'User-agent' in req.headers:
552                 del req.headers['User-agent']
553             req.headers['User-agent'] = req.headers['Youtubedl-user-agent']
554             del req.headers['Youtubedl-user-agent']
555
556         if sys.version_info < (2, 7) and '#' in req.get_full_url():
557             # Python 2.6 is brain-dead when it comes to fragments
558             req._Request__original = req._Request__original.partition('#')[0]
559             req._Request__r_type = req._Request__r_type.partition('#')[0]
560
561         return req
562
563     def http_response(self, req, resp):
564         old_resp = resp
565         # gzip
566         if resp.headers.get('Content-encoding', '') == 'gzip':
567             content = resp.read()
568             gz = gzip.GzipFile(fileobj=io.BytesIO(content), mode='rb')
569             try:
570                 uncompressed = io.BytesIO(gz.read())
571             except IOError as original_ioerror:
572                 # There may be junk add the end of the file
573                 # See http://stackoverflow.com/q/4928560/35070 for details
574                 for i in range(1, 1024):
575                     try:
576                         gz = gzip.GzipFile(fileobj=io.BytesIO(content[:-i]), mode='rb')
577                         uncompressed = io.BytesIO(gz.read())
578                     except IOError:
579                         continue
580                     break
581                 else:
582                     raise original_ioerror
583             resp = self.addinfourl_wrapper(uncompressed, old_resp.headers, old_resp.url, old_resp.code)
584             resp.msg = old_resp.msg
585         # deflate
586         if resp.headers.get('Content-encoding', '') == 'deflate':
587             gz = io.BytesIO(self.deflate(resp.read()))
588             resp = self.addinfourl_wrapper(gz, old_resp.headers, old_resp.url, old_resp.code)
589             resp.msg = old_resp.msg
590         return resp
591
592     https_request = http_request
593     https_response = http_response
594
595
596 def parse_iso8601(date_str, delimiter='T'):
597     """ Return a UNIX timestamp from the given date """
598
599     if date_str is None:
600         return None
601
602     m = re.search(
603         r'(\.[0-9]+)?(?:Z$| ?(?P<sign>\+|-)(?P<hours>[0-9]{2}):?(?P<minutes>[0-9]{2})$)',
604         date_str)
605     if not m:
606         timezone = datetime.timedelta()
607     else:
608         date_str = date_str[:-len(m.group(0))]
609         if not m.group('sign'):
610             timezone = datetime.timedelta()
611         else:
612             sign = 1 if m.group('sign') == '+' else -1
613             timezone = datetime.timedelta(
614                 hours=sign * int(m.group('hours')),
615                 minutes=sign * int(m.group('minutes')))
616     date_format = '%Y-%m-%d{0}%H:%M:%S'.format(delimiter)
617     dt = datetime.datetime.strptime(date_str, date_format) - timezone
618     return calendar.timegm(dt.timetuple())
619
620
621 def unified_strdate(date_str):
622     """Return a string with the date in the format YYYYMMDD"""
623
624     if date_str is None:
625         return None
626
627     upload_date = None
628     #Replace commas
629     date_str = date_str.replace(',', ' ')
630     # %z (UTC offset) is only supported in python>=3.2
631     date_str = re.sub(r' ?(\+|-)[0-9]{2}:?[0-9]{2}$', '', date_str)
632     format_expressions = [
633         '%d %B %Y',
634         '%d %b %Y',
635         '%B %d %Y',
636         '%b %d %Y',
637         '%b %dst %Y %I:%M%p',
638         '%b %dnd %Y %I:%M%p',
639         '%b %dth %Y %I:%M%p',
640         '%Y-%m-%d',
641         '%Y/%m/%d',
642         '%d.%m.%Y',
643         '%d/%m/%Y',
644         '%d/%m/%y',
645         '%Y/%m/%d %H:%M:%S',
646         '%d/%m/%Y %H:%M:%S',
647         '%Y-%m-%d %H:%M:%S',
648         '%Y-%m-%d %H:%M:%S.%f',
649         '%d.%m.%Y %H:%M',
650         '%d.%m.%Y %H.%M',
651         '%Y-%m-%dT%H:%M:%SZ',
652         '%Y-%m-%dT%H:%M:%S.%fZ',
653         '%Y-%m-%dT%H:%M:%S.%f0Z',
654         '%Y-%m-%dT%H:%M:%S',
655         '%Y-%m-%dT%H:%M:%S.%f',
656         '%Y-%m-%dT%H:%M',
657     ]
658     for expression in format_expressions:
659         try:
660             upload_date = datetime.datetime.strptime(date_str, expression).strftime('%Y%m%d')
661         except ValueError:
662             pass
663     if upload_date is None:
664         timetuple = email.utils.parsedate_tz(date_str)
665         if timetuple:
666             upload_date = datetime.datetime(*timetuple[:6]).strftime('%Y%m%d')
667     return upload_date
668
669 def determine_ext(url, default_ext='unknown_video'):
670     if url is None:
671         return default_ext
672     guess = url.partition('?')[0].rpartition('.')[2]
673     if re.match(r'^[A-Za-z0-9]+$', guess):
674         return guess
675     else:
676         return default_ext
677
678 def subtitles_filename(filename, sub_lang, sub_format):
679     return filename.rsplit('.', 1)[0] + '.' + sub_lang + '.' + sub_format
680
681 def date_from_str(date_str):
682     """
683     Return a datetime object from a string in the format YYYYMMDD or
684     (now|today)[+-][0-9](day|week|month|year)(s)?"""
685     today = datetime.date.today()
686     if date_str == 'now'or date_str == 'today':
687         return today
688     match = re.match('(now|today)(?P<sign>[+-])(?P<time>\d+)(?P<unit>day|week|month|year)(s)?', date_str)
689     if match is not None:
690         sign = match.group('sign')
691         time = int(match.group('time'))
692         if sign == '-':
693             time = -time
694         unit = match.group('unit')
695         #A bad aproximation?
696         if unit == 'month':
697             unit = 'day'
698             time *= 30
699         elif unit == 'year':
700             unit = 'day'
701             time *= 365
702         unit += 's'
703         delta = datetime.timedelta(**{unit: time})
704         return today + delta
705     return datetime.datetime.strptime(date_str, "%Y%m%d").date()
706     
707 def hyphenate_date(date_str):
708     """
709     Convert a date in 'YYYYMMDD' format to 'YYYY-MM-DD' format"""
710     match = re.match(r'^(\d\d\d\d)(\d\d)(\d\d)$', date_str)
711     if match is not None:
712         return '-'.join(match.groups())
713     else:
714         return date_str
715
716 class DateRange(object):
717     """Represents a time interval between two dates"""
718     def __init__(self, start=None, end=None):
719         """start and end must be strings in the format accepted by date"""
720         if start is not None:
721             self.start = date_from_str(start)
722         else:
723             self.start = datetime.datetime.min.date()
724         if end is not None:
725             self.end = date_from_str(end)
726         else:
727             self.end = datetime.datetime.max.date()
728         if self.start > self.end:
729             raise ValueError('Date range: "%s" , the start date must be before the end date' % self)
730     @classmethod
731     def day(cls, day):
732         """Returns a range that only contains the given day"""
733         return cls(day,day)
734     def __contains__(self, date):
735         """Check if the date is in the range"""
736         if not isinstance(date, datetime.date):
737             date = date_from_str(date)
738         return self.start <= date <= self.end
739     def __str__(self):
740         return '%s - %s' % ( self.start.isoformat(), self.end.isoformat())
741
742
743 def platform_name():
744     """ Returns the platform name as a compat_str """
745     res = platform.platform()
746     if isinstance(res, bytes):
747         res = res.decode(preferredencoding())
748
749     assert isinstance(res, compat_str)
750     return res
751
752
753 def _windows_write_string(s, out):
754     """ Returns True if the string was written using special methods,
755     False if it has yet to be written out."""
756     # Adapted from http://stackoverflow.com/a/3259271/35070
757
758     import ctypes
759     import ctypes.wintypes
760
761     WIN_OUTPUT_IDS = {
762         1: -11,
763         2: -12,
764     }
765
766     try:
767         fileno = out.fileno()
768     except AttributeError:
769         # If the output stream doesn't have a fileno, it's virtual
770         return False
771     if fileno not in WIN_OUTPUT_IDS:
772         return False
773
774     GetStdHandle = ctypes.WINFUNCTYPE(
775         ctypes.wintypes.HANDLE, ctypes.wintypes.DWORD)(
776         ("GetStdHandle", ctypes.windll.kernel32))
777     h = GetStdHandle(WIN_OUTPUT_IDS[fileno])
778
779     WriteConsoleW = ctypes.WINFUNCTYPE(
780         ctypes.wintypes.BOOL, ctypes.wintypes.HANDLE, ctypes.wintypes.LPWSTR,
781         ctypes.wintypes.DWORD, ctypes.POINTER(ctypes.wintypes.DWORD),
782         ctypes.wintypes.LPVOID)(("WriteConsoleW", ctypes.windll.kernel32))
783     written = ctypes.wintypes.DWORD(0)
784
785     GetFileType = ctypes.WINFUNCTYPE(ctypes.wintypes.DWORD, ctypes.wintypes.DWORD)(("GetFileType", ctypes.windll.kernel32))
786     FILE_TYPE_CHAR = 0x0002
787     FILE_TYPE_REMOTE = 0x8000
788     GetConsoleMode = ctypes.WINFUNCTYPE(
789         ctypes.wintypes.BOOL, ctypes.wintypes.HANDLE,
790         ctypes.POINTER(ctypes.wintypes.DWORD))(
791         ("GetConsoleMode", ctypes.windll.kernel32))
792     INVALID_HANDLE_VALUE = ctypes.wintypes.DWORD(-1).value
793
794     def not_a_console(handle):
795         if handle == INVALID_HANDLE_VALUE or handle is None:
796             return True
797         return ((GetFileType(handle) & ~FILE_TYPE_REMOTE) != FILE_TYPE_CHAR
798                 or GetConsoleMode(handle, ctypes.byref(ctypes.wintypes.DWORD())) == 0)
799
800     if not_a_console(h):
801         return False
802
803     def next_nonbmp_pos(s):
804         try:
805             return next(i for i, c in enumerate(s) if ord(c) > 0xffff)
806         except StopIteration:
807             return len(s)
808
809     while s:
810         count = min(next_nonbmp_pos(s), 1024)
811
812         ret = WriteConsoleW(
813             h, s, count if count else 2, ctypes.byref(written), None)
814         if ret == 0:
815             raise OSError('Failed to write string')
816         if not count:  # We just wrote a non-BMP character
817             assert written.value == 2
818             s = s[1:]
819         else:
820             assert written.value > 0
821             s = s[written.value:]
822     return True
823
824
825 def write_string(s, out=None, encoding=None):
826     if out is None:
827         out = sys.stderr
828     assert type(s) == compat_str
829
830     if sys.platform == 'win32' and encoding is None and hasattr(out, 'fileno'):
831         if _windows_write_string(s, out):
832             return
833
834     if ('b' in getattr(out, 'mode', '') or
835             sys.version_info[0] < 3):  # Python 2 lies about mode of sys.stderr
836         byt = s.encode(encoding or preferredencoding(), 'ignore')
837         out.write(byt)
838     elif hasattr(out, 'buffer'):
839         enc = encoding or getattr(out, 'encoding', None) or preferredencoding()
840         byt = s.encode(enc, 'ignore')
841         out.buffer.write(byt)
842     else:
843         out.write(s)
844     out.flush()
845
846
847 def bytes_to_intlist(bs):
848     if not bs:
849         return []
850     if isinstance(bs[0], int):  # Python 3
851         return list(bs)
852     else:
853         return [ord(c) for c in bs]
854
855
856 def intlist_to_bytes(xs):
857     if not xs:
858         return b''
859     return struct_pack('%dB' % len(xs), *xs)
860
861
862 # Cross-platform file locking
863 if sys.platform == 'win32':
864     import ctypes.wintypes
865     import msvcrt
866
867     class OVERLAPPED(ctypes.Structure):
868         _fields_ = [
869             ('Internal', ctypes.wintypes.LPVOID),
870             ('InternalHigh', ctypes.wintypes.LPVOID),
871             ('Offset', ctypes.wintypes.DWORD),
872             ('OffsetHigh', ctypes.wintypes.DWORD),
873             ('hEvent', ctypes.wintypes.HANDLE),
874         ]
875
876     kernel32 = ctypes.windll.kernel32
877     LockFileEx = kernel32.LockFileEx
878     LockFileEx.argtypes = [
879         ctypes.wintypes.HANDLE,     # hFile
880         ctypes.wintypes.DWORD,      # dwFlags
881         ctypes.wintypes.DWORD,      # dwReserved
882         ctypes.wintypes.DWORD,      # nNumberOfBytesToLockLow
883         ctypes.wintypes.DWORD,      # nNumberOfBytesToLockHigh
884         ctypes.POINTER(OVERLAPPED)  # Overlapped
885     ]
886     LockFileEx.restype = ctypes.wintypes.BOOL
887     UnlockFileEx = kernel32.UnlockFileEx
888     UnlockFileEx.argtypes = [
889         ctypes.wintypes.HANDLE,     # hFile
890         ctypes.wintypes.DWORD,      # dwReserved
891         ctypes.wintypes.DWORD,      # nNumberOfBytesToLockLow
892         ctypes.wintypes.DWORD,      # nNumberOfBytesToLockHigh
893         ctypes.POINTER(OVERLAPPED)  # Overlapped
894     ]
895     UnlockFileEx.restype = ctypes.wintypes.BOOL
896     whole_low = 0xffffffff
897     whole_high = 0x7fffffff
898
899     def _lock_file(f, exclusive):
900         overlapped = OVERLAPPED()
901         overlapped.Offset = 0
902         overlapped.OffsetHigh = 0
903         overlapped.hEvent = 0
904         f._lock_file_overlapped_p = ctypes.pointer(overlapped)
905         handle = msvcrt.get_osfhandle(f.fileno())
906         if not LockFileEx(handle, 0x2 if exclusive else 0x0, 0,
907                           whole_low, whole_high, f._lock_file_overlapped_p):
908             raise OSError('Locking file failed: %r' % ctypes.FormatError())
909
910     def _unlock_file(f):
911         assert f._lock_file_overlapped_p
912         handle = msvcrt.get_osfhandle(f.fileno())
913         if not UnlockFileEx(handle, 0,
914                             whole_low, whole_high, f._lock_file_overlapped_p):
915             raise OSError('Unlocking file failed: %r' % ctypes.FormatError())
916
917 else:
918     import fcntl
919
920     def _lock_file(f, exclusive):
921         fcntl.flock(f, fcntl.LOCK_EX if exclusive else fcntl.LOCK_SH)
922
923     def _unlock_file(f):
924         fcntl.flock(f, fcntl.LOCK_UN)
925
926
927 class locked_file(object):
928     def __init__(self, filename, mode, encoding=None):
929         assert mode in ['r', 'a', 'w']
930         self.f = io.open(filename, mode, encoding=encoding)
931         self.mode = mode
932
933     def __enter__(self):
934         exclusive = self.mode != 'r'
935         try:
936             _lock_file(self.f, exclusive)
937         except IOError:
938             self.f.close()
939             raise
940         return self
941
942     def __exit__(self, etype, value, traceback):
943         try:
944             _unlock_file(self.f)
945         finally:
946             self.f.close()
947
948     def __iter__(self):
949         return iter(self.f)
950
951     def write(self, *args):
952         return self.f.write(*args)
953
954     def read(self, *args):
955         return self.f.read(*args)
956
957
958 def get_filesystem_encoding():
959     encoding = sys.getfilesystemencoding()
960     return encoding if encoding is not None else 'utf-8'
961
962
963 def shell_quote(args):
964     quoted_args = []
965     encoding = get_filesystem_encoding()
966     for a in args:
967         if isinstance(a, bytes):
968             # We may get a filename encoded with 'encodeFilename'
969             a = a.decode(encoding)
970         quoted_args.append(pipes.quote(a))
971     return ' '.join(quoted_args)
972
973
974 def takewhile_inclusive(pred, seq):
975     """ Like itertools.takewhile, but include the latest evaluated element
976         (the first element so that Not pred(e)) """
977     for e in seq:
978         yield e
979         if not pred(e):
980             return
981
982
983 def smuggle_url(url, data):
984     """ Pass additional data in a URL for internal use. """
985
986     sdata = compat_urllib_parse.urlencode(
987         {'__youtubedl_smuggle': json.dumps(data)})
988     return url + '#' + sdata
989
990
991 def unsmuggle_url(smug_url, default=None):
992     if not '#__youtubedl_smuggle' in smug_url:
993         return smug_url, default
994     url, _, sdata = smug_url.rpartition('#')
995     jsond = compat_parse_qs(sdata)['__youtubedl_smuggle'][0]
996     data = json.loads(jsond)
997     return url, data
998
999
1000 def format_bytes(bytes):
1001     if bytes is None:
1002         return 'N/A'
1003     if type(bytes) is str:
1004         bytes = float(bytes)
1005     if bytes == 0.0:
1006         exponent = 0
1007     else:
1008         exponent = int(math.log(bytes, 1024.0))
1009     suffix = ['B', 'KiB', 'MiB', 'GiB', 'TiB', 'PiB', 'EiB', 'ZiB', 'YiB'][exponent]
1010     converted = float(bytes) / float(1024 ** exponent)
1011     return '%.2f%s' % (converted, suffix)
1012
1013
1014 def get_term_width():
1015     columns = compat_getenv('COLUMNS', None)
1016     if columns:
1017         return int(columns)
1018
1019     try:
1020         sp = subprocess.Popen(
1021             ['stty', 'size'],
1022             stdout=subprocess.PIPE, stderr=subprocess.PIPE)
1023         out, err = sp.communicate()
1024         return int(out.split()[1])
1025     except:
1026         pass
1027     return None
1028
1029
1030 def month_by_name(name):
1031     """ Return the number of a month by (locale-independently) English name """
1032
1033     ENGLISH_NAMES = [
1034         'January', 'February', 'March', 'April', 'May', 'June',
1035         'July', 'August', 'September', 'October', 'November', 'December']
1036     try:
1037         return ENGLISH_NAMES.index(name) + 1
1038     except ValueError:
1039         return None
1040
1041
1042 def fix_xml_ampersands(xml_str):
1043     """Replace all the '&' by '&amp;' in XML"""
1044     return re.sub(
1045         r'&(?!amp;|lt;|gt;|apos;|quot;|#x[0-9a-fA-F]{,4};|#[0-9]{,4};)',
1046         '&amp;',
1047         xml_str)
1048
1049
1050 def setproctitle(title):
1051     assert isinstance(title, compat_str)
1052     try:
1053         libc = ctypes.cdll.LoadLibrary("libc.so.6")
1054     except OSError:
1055         return
1056     title_bytes = title.encode('utf-8')
1057     buf = ctypes.create_string_buffer(len(title_bytes))
1058     buf.value = title_bytes
1059     try:
1060         libc.prctl(15, buf, 0, 0, 0)
1061     except AttributeError:
1062         return  # Strange libc, just skip this
1063
1064
1065 def remove_start(s, start):
1066     if s.startswith(start):
1067         return s[len(start):]
1068     return s
1069
1070
1071 def remove_end(s, end):
1072     if s.endswith(end):
1073         return s[:-len(end)]
1074     return s
1075
1076
1077 def url_basename(url):
1078     path = compat_urlparse.urlparse(url).path
1079     return path.strip('/').split('/')[-1]
1080
1081
1082 class HEADRequest(compat_urllib_request.Request):
1083     def get_method(self):
1084         return "HEAD"
1085
1086
1087 def int_or_none(v, scale=1, default=None, get_attr=None, invscale=1):
1088     if get_attr:
1089         if v is not None:
1090             v = getattr(v, get_attr, None)
1091     if v == '':
1092         v = None
1093     return default if v is None else (int(v) * invscale // scale)
1094
1095
1096 def str_or_none(v, default=None):
1097     return default if v is None else compat_str(v)
1098
1099
1100 def str_to_int(int_str):
1101     """ A more relaxed version of int_or_none """
1102     if int_str is None:
1103         return None
1104     int_str = re.sub(r'[,\.\+]', '', int_str)
1105     return int(int_str)
1106
1107
1108 def float_or_none(v, scale=1, invscale=1, default=None):
1109     return default if v is None else (float(v) * invscale / scale)
1110
1111
1112 def parse_duration(s):
1113     if s is None:
1114         return None
1115
1116     s = s.strip()
1117
1118     m = re.match(
1119         r'''(?ix)T?
1120             (?:
1121                 (?:(?P<hours>[0-9]+)\s*(?:[:h]|hours?)\s*)?
1122                 (?P<mins>[0-9]+)\s*(?:[:m]|mins?|minutes?)\s*
1123             )?
1124             (?P<secs>[0-9]+)(?P<ms>\.[0-9]+)?\s*(?:s|secs?|seconds?)?$''', s)
1125     if not m:
1126         return None
1127     res = int(m.group('secs'))
1128     if m.group('mins'):
1129         res += int(m.group('mins')) * 60
1130         if m.group('hours'):
1131             res += int(m.group('hours')) * 60 * 60
1132     if m.group('ms'):
1133         res += float(m.group('ms'))
1134     return res
1135
1136
1137 def prepend_extension(filename, ext):
1138     name, real_ext = os.path.splitext(filename) 
1139     return '{0}.{1}{2}'.format(name, ext, real_ext)
1140
1141
1142 def check_executable(exe, args=[]):
1143     """ Checks if the given binary is installed somewhere in PATH, and returns its name.
1144     args can be a list of arguments for a short output (like -version) """
1145     try:
1146         subprocess.Popen([exe] + args, stdout=subprocess.PIPE, stderr=subprocess.PIPE).communicate()
1147     except OSError:
1148         return False
1149     return exe
1150
1151
1152 def get_exe_version(exe, args=['--version'],
1153                     version_re=r'version\s+([0-9._-a-zA-Z]+)',
1154                     unrecognized='present'):
1155     """ Returns the version of the specified executable,
1156     or False if the executable is not present """
1157     try:
1158         out, err = subprocess.Popen(
1159             [exe] + args,
1160             stdout=subprocess.PIPE, stderr=subprocess.STDOUT).communicate()
1161     except OSError:
1162         return False
1163     firstline = out.partition(b'\n')[0].decode('ascii', 'ignore')
1164     m = re.search(version_re, firstline)
1165     if m:
1166         return m.group(1)
1167     else:
1168         return unrecognized
1169
1170
1171 class PagedList(object):
1172     def __len__(self):
1173         # This is only useful for tests
1174         return len(self.getslice())
1175
1176
1177 class OnDemandPagedList(PagedList):
1178     def __init__(self, pagefunc, pagesize):
1179         self._pagefunc = pagefunc
1180         self._pagesize = pagesize
1181
1182     def getslice(self, start=0, end=None):
1183         res = []
1184         for pagenum in itertools.count(start // self._pagesize):
1185             firstid = pagenum * self._pagesize
1186             nextfirstid = pagenum * self._pagesize + self._pagesize
1187             if start >= nextfirstid:
1188                 continue
1189
1190             page_results = list(self._pagefunc(pagenum))
1191
1192             startv = (
1193                 start % self._pagesize
1194                 if firstid <= start < nextfirstid
1195                 else 0)
1196
1197             endv = (
1198                 ((end - 1) % self._pagesize) + 1
1199                 if (end is not None and firstid <= end <= nextfirstid)
1200                 else None)
1201
1202             if startv != 0 or endv is not None:
1203                 page_results = page_results[startv:endv]
1204             res.extend(page_results)
1205
1206             # A little optimization - if current page is not "full", ie. does
1207             # not contain page_size videos then we can assume that this page
1208             # is the last one - there are no more ids on further pages -
1209             # i.e. no need to query again.
1210             if len(page_results) + startv < self._pagesize:
1211                 break
1212
1213             # If we got the whole page, but the next page is not interesting,
1214             # break out early as well
1215             if end == nextfirstid:
1216                 break
1217         return res
1218
1219
1220 class InAdvancePagedList(PagedList):
1221     def __init__(self, pagefunc, pagecount, pagesize):
1222         self._pagefunc = pagefunc
1223         self._pagecount = pagecount
1224         self._pagesize = pagesize
1225
1226     def getslice(self, start=0, end=None):
1227         res = []
1228         start_page = start // self._pagesize
1229         end_page = (
1230             self._pagecount if end is None else (end // self._pagesize + 1))
1231         skip_elems = start - start_page * self._pagesize
1232         only_more = None if end is None else end - start
1233         for pagenum in range(start_page, end_page):
1234             page = list(self._pagefunc(pagenum))
1235             if skip_elems:
1236                 page = page[skip_elems:]
1237                 skip_elems = None
1238             if only_more is not None:
1239                 if len(page) < only_more:
1240                     only_more -= len(page)
1241                 else:
1242                     page = page[:only_more]
1243                     res.extend(page)
1244                     break
1245             res.extend(page)
1246         return res
1247
1248
1249 def uppercase_escape(s):
1250     unicode_escape = codecs.getdecoder('unicode_escape')
1251     return re.sub(
1252         r'\\U[0-9a-fA-F]{8}',
1253         lambda m: unicode_escape(m.group(0))[0],
1254         s)
1255
1256
1257 def escape_rfc3986(s):
1258     """Escape non-ASCII characters as suggested by RFC 3986"""
1259     if sys.version_info < (3, 0) and isinstance(s, unicode):
1260         s = s.encode('utf-8')
1261     return compat_urllib_parse.quote(s, b"%/;:@&=+$,!~*'()?#[]")
1262
1263
1264 def escape_url(url):
1265     """Escape URL as suggested by RFC 3986"""
1266     url_parsed = compat_urllib_parse_urlparse(url)
1267     return url_parsed._replace(
1268         path=escape_rfc3986(url_parsed.path),
1269         params=escape_rfc3986(url_parsed.params),
1270         query=escape_rfc3986(url_parsed.query),
1271         fragment=escape_rfc3986(url_parsed.fragment)
1272     ).geturl()
1273
1274 try:
1275     struct.pack('!I', 0)
1276 except TypeError:
1277     # In Python 2.6 (and some 2.7 versions), struct requires a bytes argument
1278     def struct_pack(spec, *args):
1279         if isinstance(spec, compat_str):
1280             spec = spec.encode('ascii')
1281         return struct.pack(spec, *args)
1282
1283     def struct_unpack(spec, *args):
1284         if isinstance(spec, compat_str):
1285             spec = spec.encode('ascii')
1286         return struct.unpack(spec, *args)
1287 else:
1288     struct_pack = struct.pack
1289     struct_unpack = struct.unpack
1290
1291
1292 def read_batch_urls(batch_fd):
1293     def fixup(url):
1294         if not isinstance(url, compat_str):
1295             url = url.decode('utf-8', 'replace')
1296         BOM_UTF8 = '\xef\xbb\xbf'
1297         if url.startswith(BOM_UTF8):
1298             url = url[len(BOM_UTF8):]
1299         url = url.strip()
1300         if url.startswith(('#', ';', ']')):
1301             return False
1302         return url
1303
1304     with contextlib.closing(batch_fd) as fd:
1305         return [url for url in map(fixup, fd) if url]
1306
1307
1308 def urlencode_postdata(*args, **kargs):
1309     return compat_urllib_parse.urlencode(*args, **kargs).encode('ascii')
1310
1311
1312 try:
1313     etree_iter = xml.etree.ElementTree.Element.iter
1314 except AttributeError:  # Python <=2.6
1315     etree_iter = lambda n: n.findall('.//*')
1316
1317
1318 def parse_xml(s):
1319     class TreeBuilder(xml.etree.ElementTree.TreeBuilder):
1320         def doctype(self, name, pubid, system):
1321             pass  # Ignore doctypes
1322
1323     parser = xml.etree.ElementTree.XMLParser(target=TreeBuilder())
1324     kwargs = {'parser': parser} if sys.version_info >= (2, 7) else {}
1325     tree = xml.etree.ElementTree.XML(s.encode('utf-8'), **kwargs)
1326     # Fix up XML parser in Python 2.x
1327     if sys.version_info < (3, 0):
1328         for n in etree_iter(tree):
1329             if n.text is not None:
1330                 if not isinstance(n.text, compat_str):
1331                     n.text = n.text.decode('utf-8')
1332     return tree
1333
1334
1335 US_RATINGS = {
1336     'G': 0,
1337     'PG': 10,
1338     'PG-13': 13,
1339     'R': 16,
1340     'NC': 18,
1341 }
1342
1343
1344 def parse_age_limit(s):
1345     if s is None:
1346         return None
1347     m = re.match(r'^(?P<age>\d{1,2})\+?$', s)
1348     return int(m.group('age')) if m else US_RATINGS.get(s, None)
1349
1350
1351 def strip_jsonp(code):
1352     return re.sub(
1353         r'(?s)^[a-zA-Z0-9_]+\s*\(\s*(.*)\);?\s*?(?://[^\n]*)*$', r'\1', code)
1354
1355
1356 def js_to_json(code):
1357     def fix_kv(m):
1358         v = m.group(0)
1359         if v in ('true', 'false', 'null'):
1360             return v
1361         if v.startswith('"'):
1362             return v
1363         if v.startswith("'"):
1364             v = v[1:-1]
1365             v = re.sub(r"\\\\|\\'|\"", lambda m: {
1366                 '\\\\': '\\\\',
1367                 "\\'": "'",
1368                 '"': '\\"',
1369             }[m.group(0)], v)
1370         return '"%s"' % v
1371
1372     res = re.sub(r'''(?x)
1373         "(?:[^"\\]*(?:\\\\|\\")?)*"|
1374         '(?:[^'\\]*(?:\\\\|\\')?)*'|
1375         [a-zA-Z_][a-zA-Z_0-9]*
1376         ''', fix_kv, code)
1377     res = re.sub(r',(\s*\])', lambda m: m.group(1), res)
1378     return res
1379
1380
1381 def qualities(quality_ids):
1382     """ Get a numeric quality value out of a list of possible values """
1383     def q(qid):
1384         try:
1385             return quality_ids.index(qid)
1386         except ValueError:
1387             return -1
1388     return q
1389
1390
1391 DEFAULT_OUTTMPL = '%(title)s-%(id)s.%(ext)s'
1392
1393
1394 def limit_length(s, length):
1395     """ Add ellipses to overly long strings """
1396     if s is None:
1397         return None
1398     ELLIPSES = '...'
1399     if len(s) > length:
1400         return s[:length - len(ELLIPSES)] + ELLIPSES
1401     return s
1402
1403
1404 def version_tuple(v):
1405     return [int(e) for e in v.split('.')]
1406
1407
1408 def is_outdated_version(version, limit, assume_new=True):
1409     if not version:
1410         return not assume_new
1411     try:
1412         return version_tuple(version) < version_tuple(limit)
1413     except ValueError:
1414         return not assume_new