projects
/
youtube-dl
/ blobdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
|
commitdiff
|
tree
raw
|
inline
| side by side
[YoutubeDLCookieJar] Add test for keeping session cookies
[youtube-dl]
/
test
/
helper.py
diff --git
a/test/helper.py
b/test/helper.py
index 288ed237d74080be5688f8d9ee1ba4c59aadb62d..aa9a1c9b2aadcd3a9eaeb1170c2e8d90afabb0b8 100644
(file)
--- a/
test/helper.py
+++ b/
test/helper.py
@@
-7,12
+7,16
@@
import json
import os.path
import re
import types
import os.path
import re
import types
+import ssl
import sys
import youtube_dl.extractor
from youtube_dl import YoutubeDL
import sys
import youtube_dl.extractor
from youtube_dl import YoutubeDL
-from youtube_dl.utils import (
+from youtube_dl.compat import (
+ compat_os_name,
compat_str,
compat_str,
+)
+from youtube_dl.utils import (
preferredencoding,
write_string,
)
preferredencoding,
write_string,
)
@@
-21,8
+25,13
@@
from youtube_dl.utils import (
def get_params(override=None):
PARAMETERS_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)),
"parameters.json")
def get_params(override=None):
PARAMETERS_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)),
"parameters.json")
+ LOCAL_PARAMETERS_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)),
+ "local_parameters.json")
with io.open(PARAMETERS_FILE, encoding='utf-8') as pf:
parameters = json.load(pf)
with io.open(PARAMETERS_FILE, encoding='utf-8') as pf:
parameters = json.load(pf)
+ if os.path.exists(LOCAL_PARAMETERS_FILE):
+ with io.open(LOCAL_PARAMETERS_FILE, encoding='utf-8') as pf:
+ parameters.update(json.load(pf))
if override:
parameters.update(override)
return parameters
if override:
parameters.update(override)
return parameters
@@
-42,7
+51,7
@@
def report_warning(message):
Print the message to stderr, it will be prefixed with 'WARNING:'
If stderr is a tty file the 'WARNING:' will be colored
'''
Print the message to stderr, it will be prefixed with 'WARNING:'
If stderr is a tty file the 'WARNING:' will be colored
'''
- if sys.stderr.isatty() and
os.
name != 'nt':
+ if sys.stderr.isatty() and
compat_os_
name != 'nt':
_msg_header = '\033[0;33mWARNING:\033[0m'
else:
_msg_header = 'WARNING:'
_msg_header = '\033[0;33mWARNING:\033[0m'
else:
_msg_header = 'WARNING:'
@@
-128,17
+137,21
@@
def expect_value(self, got, expected, field):
elif isinstance(expected, list) and isinstance(got, list):
self.assertEqual(
len(expected), len(got),
elif isinstance(expected, list) and isinstance(got, list):
self.assertEqual(
len(expected), len(got),
- 'Expect a list of length %d, but got a list of length %d' % (len(expected), len(got)))
+ 'Expect a list of length %d, but got a list of length %d for field %s' % (
+ len(expected), len(got), field))
for index, (item_got, item_expected) in enumerate(zip(got, expected)):
type_got = type(item_got)
type_expected = type(item_expected)
self.assertEqual(
type_expected, type_got,
for index, (item_got, item_expected) in enumerate(zip(got, expected)):
type_got = type(item_got)
type_expected = type(item_expected)
self.assertEqual(
type_expected, type_got,
- 'Type
doesn\'t match at element %d of the list in field %s, expect %s, got %s
' % (
- index, field, type_expected, type_got))
+ 'Type
mismatch for list item at index %d for field %s, expected %r, got %r
' % (
+
index, field, type_expected, type_got))
expect_value(self, item_got, item_expected, field)
else:
if isinstance(expected, compat_str) and expected.startswith('md5:'):
expect_value(self, item_got, item_expected, field)
else:
if isinstance(expected, compat_str) and expected.startswith('md5:'):
+ self.assertTrue(
+ isinstance(got, compat_str),
+ 'Expected field %s to be a unicode object, but got value %r of type %r' % (field, got, type(got)))
got = 'md5:' + md5(got)
elif isinstance(expected, compat_str) and expected.startswith('mincount:'):
self.assertTrue(
got = 'md5:' + md5(got)
elif isinstance(expected, compat_str) and expected.startswith('mincount:'):
self.assertTrue(
@@
-152,7
+165,7
@@
def expect_value(self, got, expected, field):
return
self.assertEqual(
expected, got,
return
self.assertEqual(
expected, got,
- '
i
nvalid value for field %s, expected %r, got %r' % (field, expected, got))
+ '
I
nvalid value for field %s, expected %r, got %r' % (field, expected, got))
def expect_dict(self, got_dict, expected_dict):
def expect_dict(self, got_dict, expected_dict):
@@
-232,3
+245,12
@@
def expect_warnings(ydl, warnings_re):
real_warning(w)
ydl.report_warning = _report_warning
real_warning(w)
ydl.report_warning = _report_warning
+
+
+def http_server_port(httpd):
+ if os.name == 'java' and isinstance(httpd.socket, ssl.SSLSocket):
+ # In Jython SSLSocket is not a subclass of socket.socket
+ sock = httpd.socket.sock
+ else:
+ sock = httpd.socket
+ return sock.getsockname()[1]