[utils] Skip remote IP addresses non matching to source address' IP version (closes...
authorAndrew Udvare <audvare@gmail.com>
Sat, 17 Mar 2018 00:11:47 +0000 (20:11 -0400)
committerSergey M․ <dstftw@gmail.com>
Tue, 28 Aug 2018 18:17:53 +0000 (01:17 +0700)
youtube_dl/utils.py

index 0c830ba71fbde4f2e519e1f10857cfe5e9da6c56..2be8c95cd4acc8698bed3aada9aef97e3152e34b 100644 (file)
@@ -882,7 +882,40 @@ def _create_http_connection(ydl_handler, http_class, is_https, *args, **kwargs):
         kwargs['strict'] = True
     hc = http_class(*args, **compat_kwargs(kwargs))
     source_address = ydl_handler._params.get('source_address')
+
     if source_address is not None:
+        filter_for = socket.AF_INET if '.' in source_address else socket.AF_INET6
+        # This is to workaround _create_connection() from socket where it will try all
+        # address data from getaddrinfo() including IPv6. This filters the result from
+        # getaddrinfo() based on the source_address value.
+        # This is based on the cpython socket.create_connection() function.
+        # https://github.com/python/cpython/blob/master/Lib/socket.py#L691
+        def _create_connection(address, timeout=socket._GLOBAL_DEFAULT_TIMEOUT, source_address=None):
+            host, port = address
+            err = None
+            addrs = socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM)
+            ip_addrs = [addr for addr in addrs if addr[0] == filter_for]
+            for res in ip_addrs:
+                af, socktype, proto, canonname, sa = res
+                sock = None
+                try:
+                    sock = socket.socket(af, socktype, proto)
+                    if timeout is not socket._GLOBAL_DEFAULT_TIMEOUT:
+                        sock.settimeout(timeout)
+                    sock.bind(source_address)
+                    sock.connect(sa)
+                    err = None  # Explicitly break reference cycle
+                    return sock
+                except socket.error as _:
+                    err = _
+                    if sock is not None:
+                        sock.close()
+            if err is not None:
+                raise err
+            else:
+                raise socket.error('Unknown error occurred')
+        hc._create_connection = _create_connection
+
         sa = (source_address, 0)
         if hasattr(hc, 'source_address'):  # Python 2.7+
             hc.source_address = sa