KoD 0.6
-Nuova ricerca globale -migliorie prestazionali in generale -fix vari ai server
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
from __future__ import absolute_import
|
||||
|
||||
# For backwards compatibility, provide imports that used to be here.
|
||||
from .connection import is_connection_dropped
|
||||
from .request import make_headers
|
||||
@@ -12,43 +13,34 @@ from .ssl_ import (
|
||||
resolve_cert_reqs,
|
||||
resolve_ssl_version,
|
||||
ssl_wrap_socket,
|
||||
PROTOCOL_TLS,
|
||||
)
|
||||
from .timeout import (
|
||||
current_time,
|
||||
Timeout,
|
||||
)
|
||||
from .timeout import current_time, Timeout
|
||||
|
||||
from .retry import Retry
|
||||
from .url import (
|
||||
get_host,
|
||||
parse_url,
|
||||
split_first,
|
||||
Url,
|
||||
)
|
||||
from .wait import (
|
||||
wait_for_read,
|
||||
wait_for_write
|
||||
)
|
||||
from .url import get_host, parse_url, split_first, Url
|
||||
from .wait import wait_for_read, wait_for_write
|
||||
|
||||
__all__ = (
|
||||
'HAS_SNI',
|
||||
'IS_PYOPENSSL',
|
||||
'IS_SECURETRANSPORT',
|
||||
'SSLContext',
|
||||
'Retry',
|
||||
'Timeout',
|
||||
'Url',
|
||||
'assert_fingerprint',
|
||||
'current_time',
|
||||
'is_connection_dropped',
|
||||
'is_fp_closed',
|
||||
'get_host',
|
||||
'parse_url',
|
||||
'make_headers',
|
||||
'resolve_cert_reqs',
|
||||
'resolve_ssl_version',
|
||||
'split_first',
|
||||
'ssl_wrap_socket',
|
||||
'wait_for_read',
|
||||
'wait_for_write'
|
||||
"HAS_SNI",
|
||||
"IS_PYOPENSSL",
|
||||
"IS_SECURETRANSPORT",
|
||||
"SSLContext",
|
||||
"PROTOCOL_TLS",
|
||||
"Retry",
|
||||
"Timeout",
|
||||
"Url",
|
||||
"assert_fingerprint",
|
||||
"current_time",
|
||||
"is_connection_dropped",
|
||||
"is_fp_closed",
|
||||
"get_host",
|
||||
"parse_url",
|
||||
"make_headers",
|
||||
"resolve_cert_reqs",
|
||||
"resolve_ssl_version",
|
||||
"split_first",
|
||||
"ssl_wrap_socket",
|
||||
"wait_for_read",
|
||||
"wait_for_write",
|
||||
)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from __future__ import absolute_import
|
||||
import socket
|
||||
from .wait import wait_for_read
|
||||
from .selectors import HAS_SELECT, SelectorError
|
||||
from .wait import NoWayToWaitForSocketError, wait_for_read
|
||||
from ..contrib import _appengine_environ
|
||||
|
||||
|
||||
def is_connection_dropped(conn): # Platform-specific
|
||||
@@ -14,27 +14,28 @@ def is_connection_dropped(conn): # Platform-specific
|
||||
Note: For platforms like AppEngine, this will always return ``False`` to
|
||||
let the platform handle connection recycling transparently for us.
|
||||
"""
|
||||
sock = getattr(conn, 'sock', False)
|
||||
sock = getattr(conn, "sock", False)
|
||||
if sock is False: # Platform-specific: AppEngine
|
||||
return False
|
||||
if sock is None: # Connection already closed (such as by httplib).
|
||||
return True
|
||||
|
||||
if not HAS_SELECT:
|
||||
return False
|
||||
|
||||
try:
|
||||
return bool(wait_for_read(sock, timeout=0.0))
|
||||
except SelectorError:
|
||||
return True
|
||||
# Returns True if readable, which here means it's been dropped
|
||||
return wait_for_read(sock, timeout=0.0)
|
||||
except NoWayToWaitForSocketError: # Platform-specific: AppEngine
|
||||
return False
|
||||
|
||||
|
||||
# This function is copied from socket.py in the Python 2.7 standard
|
||||
# library test suite. Added to its signature is only `socket_options`.
|
||||
# One additional modification is that we avoid binding to IPv6 servers
|
||||
# discovered in DNS if the system doesn't have IPv6 functionality.
|
||||
def create_connection(address, timeout=socket._GLOBAL_DEFAULT_TIMEOUT,
|
||||
source_address=None, socket_options=None):
|
||||
def create_connection(
|
||||
address,
|
||||
timeout=socket._GLOBAL_DEFAULT_TIMEOUT,
|
||||
source_address=None,
|
||||
socket_options=None,
|
||||
):
|
||||
"""Connect to *address* and return the socket object.
|
||||
|
||||
Convenience function. Connect to *address* (a 2-tuple ``(host,
|
||||
@@ -48,8 +49,8 @@ def create_connection(address, timeout=socket._GLOBAL_DEFAULT_TIMEOUT,
|
||||
"""
|
||||
|
||||
host, port = address
|
||||
if host.startswith('['):
|
||||
host = host.strip('[]')
|
||||
if host.startswith("["):
|
||||
host = host.strip("[]")
|
||||
err = None
|
||||
|
||||
# Using the value from allowed_gai_family() in the context of getaddrinfo lets
|
||||
@@ -109,6 +110,13 @@ def _has_ipv6(host):
|
||||
sock = None
|
||||
has_ipv6 = False
|
||||
|
||||
# App Engine doesn't support IPV6 sockets and actually has a quota on the
|
||||
# number of sockets that can be used, so just early out here instead of
|
||||
# creating a socket needlessly.
|
||||
# See https://github.com/urllib3/urllib3/issues/1446
|
||||
if _appengine_environ.is_appengine_sandbox():
|
||||
return False
|
||||
|
||||
if socket.has_ipv6:
|
||||
# has_ipv6 returns true if cPython was compiled with IPv6 support.
|
||||
# It does not tell us if the system has IPv6 support enabled. To
|
||||
@@ -127,4 +135,4 @@ def _has_ipv6(host):
|
||||
return has_ipv6
|
||||
|
||||
|
||||
HAS_IPV6 = _has_ipv6('::1')
|
||||
HAS_IPV6 = _has_ipv6("::1")
|
||||
|
||||
@@ -0,0 +1,21 @@
|
||||
import collections
|
||||
from ..packages import six
|
||||
from ..packages.six.moves import queue
|
||||
|
||||
if six.PY2:
|
||||
# Queue is imported for side effects on MS Windows. See issue #229.
|
||||
import Queue as _unused_module_Queue # noqa: F401
|
||||
|
||||
|
||||
class LifoQueue(queue.Queue):
|
||||
def _init(self, _):
|
||||
self.queue = collections.deque()
|
||||
|
||||
def _qsize(self, len=len):
|
||||
return len(self.queue)
|
||||
|
||||
def _put(self, item):
|
||||
self.queue.append(item)
|
||||
|
||||
def _get(self):
|
||||
return self.queue.pop()
|
||||
+37
-20
@@ -4,12 +4,25 @@ from base64 import b64encode
|
||||
from ..packages.six import b, integer_types
|
||||
from ..exceptions import UnrewindableBodyError
|
||||
|
||||
ACCEPT_ENCODING = 'gzip,deflate'
|
||||
ACCEPT_ENCODING = "gzip,deflate"
|
||||
try:
|
||||
import brotli as _unused_module_brotli # noqa: F401
|
||||
except ImportError:
|
||||
pass
|
||||
else:
|
||||
ACCEPT_ENCODING += ",br"
|
||||
|
||||
_FAILEDTELL = object()
|
||||
|
||||
|
||||
def make_headers(keep_alive=None, accept_encoding=None, user_agent=None,
|
||||
basic_auth=None, proxy_basic_auth=None, disable_cache=None):
|
||||
def make_headers(
|
||||
keep_alive=None,
|
||||
accept_encoding=None,
|
||||
user_agent=None,
|
||||
basic_auth=None,
|
||||
proxy_basic_auth=None,
|
||||
disable_cache=None,
|
||||
):
|
||||
"""
|
||||
Shortcuts for generating request headers.
|
||||
|
||||
@@ -49,27 +62,27 @@ def make_headers(keep_alive=None, accept_encoding=None, user_agent=None,
|
||||
if isinstance(accept_encoding, str):
|
||||
pass
|
||||
elif isinstance(accept_encoding, list):
|
||||
accept_encoding = ','.join(accept_encoding)
|
||||
accept_encoding = ",".join(accept_encoding)
|
||||
else:
|
||||
accept_encoding = ACCEPT_ENCODING
|
||||
headers['accept-encoding'] = accept_encoding
|
||||
headers["accept-encoding"] = accept_encoding
|
||||
|
||||
if user_agent:
|
||||
headers['user-agent'] = user_agent
|
||||
headers["user-agent"] = user_agent
|
||||
|
||||
if keep_alive:
|
||||
headers['connection'] = 'keep-alive'
|
||||
headers["connection"] = "keep-alive"
|
||||
|
||||
if basic_auth:
|
||||
headers['authorization'] = 'Basic ' + \
|
||||
b64encode(b(basic_auth)).decode('utf-8')
|
||||
headers["authorization"] = "Basic " + b64encode(b(basic_auth)).decode("utf-8")
|
||||
|
||||
if proxy_basic_auth:
|
||||
headers['proxy-authorization'] = 'Basic ' + \
|
||||
b64encode(b(proxy_basic_auth)).decode('utf-8')
|
||||
headers["proxy-authorization"] = "Basic " + b64encode(
|
||||
b(proxy_basic_auth)
|
||||
).decode("utf-8")
|
||||
|
||||
if disable_cache:
|
||||
headers['cache-control'] = 'no-cache'
|
||||
headers["cache-control"] = "no-cache"
|
||||
|
||||
return headers
|
||||
|
||||
@@ -81,7 +94,7 @@ def set_file_position(body, pos):
|
||||
"""
|
||||
if pos is not None:
|
||||
rewind_body(body, pos)
|
||||
elif getattr(body, 'tell', None) is not None:
|
||||
elif getattr(body, "tell", None) is not None:
|
||||
try:
|
||||
pos = body.tell()
|
||||
except (IOError, OSError):
|
||||
@@ -103,16 +116,20 @@ def rewind_body(body, body_pos):
|
||||
:param int pos:
|
||||
Position to seek to in file.
|
||||
"""
|
||||
body_seek = getattr(body, 'seek', None)
|
||||
body_seek = getattr(body, "seek", None)
|
||||
if body_seek is not None and isinstance(body_pos, integer_types):
|
||||
try:
|
||||
body_seek(body_pos)
|
||||
except (IOError, OSError):
|
||||
raise UnrewindableBodyError("An error occurred when rewinding request "
|
||||
"body for redirect/retry.")
|
||||
raise UnrewindableBodyError(
|
||||
"An error occurred when rewinding request " "body for redirect/retry."
|
||||
)
|
||||
elif body_pos is _FAILEDTELL:
|
||||
raise UnrewindableBodyError("Unable to record file position for rewinding "
|
||||
"request body during a redirect/retry.")
|
||||
raise UnrewindableBodyError(
|
||||
"Unable to record file position for rewinding "
|
||||
"request body during a redirect/retry."
|
||||
)
|
||||
else:
|
||||
raise ValueError("body_pos must be of type integer, "
|
||||
"instead it was %s." % type(body_pos))
|
||||
raise ValueError(
|
||||
"body_pos must be of type integer, " "instead it was %s." % type(body_pos)
|
||||
)
|
||||
|
||||
@@ -52,15 +52,20 @@ def assert_header_parsing(headers):
|
||||
# This will fail silently if we pass in the wrong kind of parameter.
|
||||
# To make debugging easier add an explicit check.
|
||||
if not isinstance(headers, httplib.HTTPMessage):
|
||||
raise TypeError('expected httplib.Message, got {0}.'.format(
|
||||
type(headers)))
|
||||
raise TypeError("expected httplib.Message, got {0}.".format(type(headers)))
|
||||
|
||||
defects = getattr(headers, 'defects', None)
|
||||
get_payload = getattr(headers, 'get_payload', None)
|
||||
defects = getattr(headers, "defects", None)
|
||||
get_payload = getattr(headers, "get_payload", None)
|
||||
|
||||
unparsed_data = None
|
||||
if get_payload: # Platform-specific: Python 3.
|
||||
unparsed_data = get_payload()
|
||||
if get_payload:
|
||||
# get_payload is actually email.message.Message.get_payload;
|
||||
# we're only interested in the result if it's not a multipart message
|
||||
if not headers.is_multipart():
|
||||
payload = get_payload()
|
||||
|
||||
if isinstance(payload, (bytes, str)):
|
||||
unparsed_data = payload
|
||||
|
||||
if defects or unparsed_data:
|
||||
raise HeaderParsingError(defects=defects, unparsed_data=unparsed_data)
|
||||
@@ -78,4 +83,4 @@ def is_response_to_head(response):
|
||||
method = response._method
|
||||
if isinstance(method, int): # Platform-specific: Appengine
|
||||
return method == 3
|
||||
return method.upper() == 'HEAD'
|
||||
return method.upper() == "HEAD"
|
||||
|
||||
+76
-27
@@ -19,9 +19,11 @@ from ..packages import six
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Data structure for representing the metadata of requests that result in a retry.
|
||||
RequestHistory = namedtuple('RequestHistory', ["method", "url", "error",
|
||||
"status", "redirect_location"])
|
||||
RequestHistory = namedtuple(
|
||||
"RequestHistory", ["method", "url", "error", "status", "redirect_location"]
|
||||
)
|
||||
|
||||
|
||||
class Retry(object):
|
||||
@@ -114,7 +116,7 @@ class Retry(object):
|
||||
(most errors are resolved immediately by a second try without a
|
||||
delay). urllib3 will sleep for::
|
||||
|
||||
{backoff factor} * (2 ^ ({number of total retries} - 1))
|
||||
{backoff factor} * (2 ** ({number of total retries} - 1))
|
||||
|
||||
seconds. If the backoff_factor is 0.1, then :func:`.sleep` will sleep
|
||||
for [0.0s, 0.2s, 0.4s, ...] between retries. It will never be longer
|
||||
@@ -139,20 +141,39 @@ class Retry(object):
|
||||
Whether to respect Retry-After header on status codes defined as
|
||||
:attr:`Retry.RETRY_AFTER_STATUS_CODES` or not.
|
||||
|
||||
:param iterable remove_headers_on_redirect:
|
||||
Sequence of headers to remove from the request when a response
|
||||
indicating a redirect is returned before firing off the redirected
|
||||
request.
|
||||
"""
|
||||
|
||||
DEFAULT_METHOD_WHITELIST = frozenset([
|
||||
'HEAD', 'GET', 'PUT', 'DELETE', 'OPTIONS', 'TRACE'])
|
||||
DEFAULT_METHOD_WHITELIST = frozenset(
|
||||
["HEAD", "GET", "PUT", "DELETE", "OPTIONS", "TRACE"]
|
||||
)
|
||||
|
||||
RETRY_AFTER_STATUS_CODES = frozenset([413, 429, 503])
|
||||
|
||||
DEFAULT_REDIRECT_HEADERS_BLACKLIST = frozenset(["Authorization"])
|
||||
|
||||
#: Maximum backoff time.
|
||||
BACKOFF_MAX = 120
|
||||
|
||||
def __init__(self, total=10, connect=None, read=None, redirect=None, status=None,
|
||||
method_whitelist=DEFAULT_METHOD_WHITELIST, status_forcelist=None,
|
||||
backoff_factor=0, raise_on_redirect=True, raise_on_status=True,
|
||||
history=None, respect_retry_after_header=True):
|
||||
def __init__(
|
||||
self,
|
||||
total=10,
|
||||
connect=None,
|
||||
read=None,
|
||||
redirect=None,
|
||||
status=None,
|
||||
method_whitelist=DEFAULT_METHOD_WHITELIST,
|
||||
status_forcelist=None,
|
||||
backoff_factor=0,
|
||||
raise_on_redirect=True,
|
||||
raise_on_status=True,
|
||||
history=None,
|
||||
respect_retry_after_header=True,
|
||||
remove_headers_on_redirect=DEFAULT_REDIRECT_HEADERS_BLACKLIST,
|
||||
):
|
||||
|
||||
self.total = total
|
||||
self.connect = connect
|
||||
@@ -171,17 +192,25 @@ class Retry(object):
|
||||
self.raise_on_status = raise_on_status
|
||||
self.history = history or tuple()
|
||||
self.respect_retry_after_header = respect_retry_after_header
|
||||
self.remove_headers_on_redirect = frozenset(
|
||||
[h.lower() for h in remove_headers_on_redirect]
|
||||
)
|
||||
|
||||
def new(self, **kw):
|
||||
params = dict(
|
||||
total=self.total,
|
||||
connect=self.connect, read=self.read, redirect=self.redirect, status=self.status,
|
||||
connect=self.connect,
|
||||
read=self.read,
|
||||
redirect=self.redirect,
|
||||
status=self.status,
|
||||
method_whitelist=self.method_whitelist,
|
||||
status_forcelist=self.status_forcelist,
|
||||
backoff_factor=self.backoff_factor,
|
||||
raise_on_redirect=self.raise_on_redirect,
|
||||
raise_on_status=self.raise_on_status,
|
||||
history=self.history,
|
||||
remove_headers_on_redirect=self.remove_headers_on_redirect,
|
||||
respect_retry_after_header=self.respect_retry_after_header,
|
||||
)
|
||||
params.update(kw)
|
||||
return type(self)(**params)
|
||||
@@ -206,8 +235,11 @@ class Retry(object):
|
||||
:rtype: float
|
||||
"""
|
||||
# We want to consider only the last consecutive errors sequence (Ignore redirects).
|
||||
consecutive_errors_len = len(list(takewhile(lambda x: x.redirect_location is None,
|
||||
reversed(self.history))))
|
||||
consecutive_errors_len = len(
|
||||
list(
|
||||
takewhile(lambda x: x.redirect_location is None, reversed(self.history))
|
||||
)
|
||||
)
|
||||
if consecutive_errors_len <= 1:
|
||||
return 0
|
||||
|
||||
@@ -263,7 +295,7 @@ class Retry(object):
|
||||
this method will return immediately.
|
||||
"""
|
||||
|
||||
if response:
|
||||
if self.respect_retry_after_header and response:
|
||||
slept = self.sleep_for_retry(response)
|
||||
if slept:
|
||||
return
|
||||
@@ -304,8 +336,12 @@ class Retry(object):
|
||||
if self.status_forcelist and status_code in self.status_forcelist:
|
||||
return True
|
||||
|
||||
return (self.total and self.respect_retry_after_header and
|
||||
has_retry_after and (status_code in self.RETRY_AFTER_STATUS_CODES))
|
||||
return (
|
||||
self.total
|
||||
and self.respect_retry_after_header
|
||||
and has_retry_after
|
||||
and (status_code in self.RETRY_AFTER_STATUS_CODES)
|
||||
)
|
||||
|
||||
def is_exhausted(self):
|
||||
""" Are we out of retries? """
|
||||
@@ -316,8 +352,15 @@ class Retry(object):
|
||||
|
||||
return min(retry_counts) < 0
|
||||
|
||||
def increment(self, method=None, url=None, response=None, error=None,
|
||||
_pool=None, _stacktrace=None):
|
||||
def increment(
|
||||
self,
|
||||
method=None,
|
||||
url=None,
|
||||
response=None,
|
||||
error=None,
|
||||
_pool=None,
|
||||
_stacktrace=None,
|
||||
):
|
||||
""" Return a new Retry object with incremented retry counters.
|
||||
|
||||
:param response: A response object, or None, if the server did not
|
||||
@@ -340,7 +383,7 @@ class Retry(object):
|
||||
read = self.read
|
||||
redirect = self.redirect
|
||||
status_count = self.status
|
||||
cause = 'unknown'
|
||||
cause = "unknown"
|
||||
status = None
|
||||
redirect_location = None
|
||||
|
||||
@@ -362,7 +405,7 @@ class Retry(object):
|
||||
# Redirect retry?
|
||||
if redirect is not None:
|
||||
redirect -= 1
|
||||
cause = 'too many redirects'
|
||||
cause = "too many redirects"
|
||||
redirect_location = response.get_redirect_location()
|
||||
status = response.status
|
||||
|
||||
@@ -373,16 +416,21 @@ class Retry(object):
|
||||
if response and response.status:
|
||||
if status_count is not None:
|
||||
status_count -= 1
|
||||
cause = ResponseError.SPECIFIC_ERROR.format(
|
||||
status_code=response.status)
|
||||
cause = ResponseError.SPECIFIC_ERROR.format(status_code=response.status)
|
||||
status = response.status
|
||||
|
||||
history = self.history + (RequestHistory(method, url, error, status, redirect_location),)
|
||||
history = self.history + (
|
||||
RequestHistory(method, url, error, status, redirect_location),
|
||||
)
|
||||
|
||||
new_retry = self.new(
|
||||
total=total,
|
||||
connect=connect, read=read, redirect=redirect, status=status_count,
|
||||
history=history)
|
||||
connect=connect,
|
||||
read=read,
|
||||
redirect=redirect,
|
||||
status=status_count,
|
||||
history=history,
|
||||
)
|
||||
|
||||
if new_retry.is_exhausted():
|
||||
raise MaxRetryError(_pool, url, error or ResponseError(cause))
|
||||
@@ -392,9 +440,10 @@ class Retry(object):
|
||||
return new_retry
|
||||
|
||||
def __repr__(self):
|
||||
return ('{cls.__name__}(total={self.total}, connect={self.connect}, '
|
||||
'read={self.read}, redirect={self.redirect}, status={self.status})').format(
|
||||
cls=type(self), self=self)
|
||||
return (
|
||||
"{cls.__name__}(total={self.total}, connect={self.connect}, "
|
||||
"read={self.read}, redirect={self.redirect}, status={self.status})"
|
||||
).format(cls=type(self), self=self)
|
||||
|
||||
|
||||
# For backwards compatibility (equivalent to pre-v1.9):
|
||||
|
||||
@@ -1,581 +0,0 @@
|
||||
# Backport of selectors.py from Python 3.5+ to support Python < 3.4
|
||||
# Also has the behavior specified in PEP 475 which is to retry syscalls
|
||||
# in the case of an EINTR error. This module is required because selectors34
|
||||
# does not follow this behavior and instead returns that no dile descriptor
|
||||
# events have occurred rather than retry the syscall. The decision to drop
|
||||
# support for select.devpoll is made to maintain 100% test coverage.
|
||||
|
||||
import errno
|
||||
import math
|
||||
import select
|
||||
import socket
|
||||
import sys
|
||||
import time
|
||||
from collections import namedtuple, Mapping
|
||||
|
||||
try:
|
||||
monotonic = time.monotonic
|
||||
except (AttributeError, ImportError): # Python 3.3<
|
||||
monotonic = time.time
|
||||
|
||||
EVENT_READ = (1 << 0)
|
||||
EVENT_WRITE = (1 << 1)
|
||||
|
||||
HAS_SELECT = True # Variable that shows whether the platform has a selector.
|
||||
_SYSCALL_SENTINEL = object() # Sentinel in case a system call returns None.
|
||||
_DEFAULT_SELECTOR = None
|
||||
|
||||
|
||||
class SelectorError(Exception):
|
||||
def __init__(self, errcode):
|
||||
super(SelectorError, self).__init__()
|
||||
self.errno = errcode
|
||||
|
||||
def __repr__(self):
|
||||
return "<SelectorError errno={0}>".format(self.errno)
|
||||
|
||||
def __str__(self):
|
||||
return self.__repr__()
|
||||
|
||||
|
||||
def _fileobj_to_fd(fileobj):
|
||||
""" Return a file descriptor from a file object. If
|
||||
given an integer will simply return that integer back. """
|
||||
if isinstance(fileobj, int):
|
||||
fd = fileobj
|
||||
else:
|
||||
try:
|
||||
fd = int(fileobj.fileno())
|
||||
except (AttributeError, TypeError, ValueError):
|
||||
raise ValueError("Invalid file object: {0!r}".format(fileobj))
|
||||
if fd < 0:
|
||||
raise ValueError("Invalid file descriptor: {0}".format(fd))
|
||||
return fd
|
||||
|
||||
|
||||
# Determine which function to use to wrap system calls because Python 3.5+
|
||||
# already handles the case when system calls are interrupted.
|
||||
if sys.version_info >= (3, 5):
|
||||
def _syscall_wrapper(func, _, *args, **kwargs):
|
||||
""" This is the short-circuit version of the below logic
|
||||
because in Python 3.5+ all system calls automatically restart
|
||||
and recalculate their timeouts. """
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
except (OSError, IOError, select.error) as e:
|
||||
errcode = None
|
||||
if hasattr(e, "errno"):
|
||||
errcode = e.errno
|
||||
raise SelectorError(errcode)
|
||||
else:
|
||||
def _syscall_wrapper(func, recalc_timeout, *args, **kwargs):
|
||||
""" Wrapper function for syscalls that could fail due to EINTR.
|
||||
All functions should be retried if there is time left in the timeout
|
||||
in accordance with PEP 475. """
|
||||
timeout = kwargs.get("timeout", None)
|
||||
if timeout is None:
|
||||
expires = None
|
||||
recalc_timeout = False
|
||||
else:
|
||||
timeout = float(timeout)
|
||||
if timeout < 0.0: # Timeout less than 0 treated as no timeout.
|
||||
expires = None
|
||||
else:
|
||||
expires = monotonic() + timeout
|
||||
|
||||
args = list(args)
|
||||
if recalc_timeout and "timeout" not in kwargs:
|
||||
raise ValueError(
|
||||
"Timeout must be in args or kwargs to be recalculated")
|
||||
|
||||
result = _SYSCALL_SENTINEL
|
||||
while result is _SYSCALL_SENTINEL:
|
||||
try:
|
||||
result = func(*args, **kwargs)
|
||||
# OSError is thrown by select.select
|
||||
# IOError is thrown by select.epoll.poll
|
||||
# select.error is thrown by select.poll.poll
|
||||
# Aren't we thankful for Python 3.x rework for exceptions?
|
||||
except (OSError, IOError, select.error) as e:
|
||||
# select.error wasn't a subclass of OSError in the past.
|
||||
errcode = None
|
||||
if hasattr(e, "errno"):
|
||||
errcode = e.errno
|
||||
elif hasattr(e, "args"):
|
||||
errcode = e.args[0]
|
||||
|
||||
# Also test for the Windows equivalent of EINTR.
|
||||
is_interrupt = (errcode == errno.EINTR or (hasattr(errno, "WSAEINTR") and
|
||||
errcode == errno.WSAEINTR))
|
||||
|
||||
if is_interrupt:
|
||||
if expires is not None:
|
||||
current_time = monotonic()
|
||||
if current_time > expires:
|
||||
raise OSError(errno=errno.ETIMEDOUT)
|
||||
if recalc_timeout:
|
||||
if "timeout" in kwargs:
|
||||
kwargs["timeout"] = expires - current_time
|
||||
continue
|
||||
if errcode:
|
||||
raise SelectorError(errcode)
|
||||
else:
|
||||
raise
|
||||
return result
|
||||
|
||||
|
||||
SelectorKey = namedtuple('SelectorKey', ['fileobj', 'fd', 'events', 'data'])
|
||||
|
||||
|
||||
class _SelectorMapping(Mapping):
|
||||
""" Mapping of file objects to selector keys """
|
||||
|
||||
def __init__(self, selector):
|
||||
self._selector = selector
|
||||
|
||||
def __len__(self):
|
||||
return len(self._selector._fd_to_key)
|
||||
|
||||
def __getitem__(self, fileobj):
|
||||
try:
|
||||
fd = self._selector._fileobj_lookup(fileobj)
|
||||
return self._selector._fd_to_key[fd]
|
||||
except KeyError:
|
||||
raise KeyError("{0!r} is not registered.".format(fileobj))
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self._selector._fd_to_key)
|
||||
|
||||
|
||||
class BaseSelector(object):
|
||||
""" Abstract Selector class
|
||||
|
||||
A selector supports registering file objects to be monitored
|
||||
for specific I/O events.
|
||||
|
||||
A file object is a file descriptor or any object with a
|
||||
`fileno()` method. An arbitrary object can be attached to the
|
||||
file object which can be used for example to store context info,
|
||||
a callback, etc.
|
||||
|
||||
A selector can use various implementations (select(), poll(), epoll(),
|
||||
and kqueue()) depending on the platform. The 'DefaultSelector' class uses
|
||||
the most efficient implementation for the current platform.
|
||||
"""
|
||||
def __init__(self):
|
||||
# Maps file descriptors to keys.
|
||||
self._fd_to_key = {}
|
||||
|
||||
# Read-only mapping returned by get_map()
|
||||
self._map = _SelectorMapping(self)
|
||||
|
||||
def _fileobj_lookup(self, fileobj):
|
||||
""" Return a file descriptor from a file object.
|
||||
This wraps _fileobj_to_fd() to do an exhaustive
|
||||
search in case the object is invalid but we still
|
||||
have it in our map. Used by unregister() so we can
|
||||
unregister an object that was previously registered
|
||||
even if it is closed. It is also used by _SelectorMapping
|
||||
"""
|
||||
try:
|
||||
return _fileobj_to_fd(fileobj)
|
||||
except ValueError:
|
||||
|
||||
# Search through all our mapped keys.
|
||||
for key in self._fd_to_key.values():
|
||||
if key.fileobj is fileobj:
|
||||
return key.fd
|
||||
|
||||
# Raise ValueError after all.
|
||||
raise
|
||||
|
||||
def register(self, fileobj, events, data=None):
|
||||
""" Register a file object for a set of events to monitor. """
|
||||
if (not events) or (events & ~(EVENT_READ | EVENT_WRITE)):
|
||||
raise ValueError("Invalid events: {0!r}".format(events))
|
||||
|
||||
key = SelectorKey(fileobj, self._fileobj_lookup(fileobj), events, data)
|
||||
|
||||
if key.fd in self._fd_to_key:
|
||||
raise KeyError("{0!r} (FD {1}) is already registered"
|
||||
.format(fileobj, key.fd))
|
||||
|
||||
self._fd_to_key[key.fd] = key
|
||||
return key
|
||||
|
||||
def unregister(self, fileobj):
|
||||
""" Unregister a file object from being monitored. """
|
||||
try:
|
||||
key = self._fd_to_key.pop(self._fileobj_lookup(fileobj))
|
||||
except KeyError:
|
||||
raise KeyError("{0!r} is not registered".format(fileobj))
|
||||
|
||||
# Getting the fileno of a closed socket on Windows errors with EBADF.
|
||||
except socket.error as e: # Platform-specific: Windows.
|
||||
if e.errno != errno.EBADF:
|
||||
raise
|
||||
else:
|
||||
for key in self._fd_to_key.values():
|
||||
if key.fileobj is fileobj:
|
||||
self._fd_to_key.pop(key.fd)
|
||||
break
|
||||
else:
|
||||
raise KeyError("{0!r} is not registered".format(fileobj))
|
||||
return key
|
||||
|
||||
def modify(self, fileobj, events, data=None):
|
||||
""" Change a registered file object monitored events and data. """
|
||||
# NOTE: Some subclasses optimize this operation even further.
|
||||
try:
|
||||
key = self._fd_to_key[self._fileobj_lookup(fileobj)]
|
||||
except KeyError:
|
||||
raise KeyError("{0!r} is not registered".format(fileobj))
|
||||
|
||||
if events != key.events:
|
||||
self.unregister(fileobj)
|
||||
key = self.register(fileobj, events, data)
|
||||
|
||||
elif data != key.data:
|
||||
# Use a shortcut to update the data.
|
||||
key = key._replace(data=data)
|
||||
self._fd_to_key[key.fd] = key
|
||||
|
||||
return key
|
||||
|
||||
def select(self, timeout=None):
|
||||
""" Perform the actual selection until some monitored file objects
|
||||
are ready or the timeout expires. """
|
||||
raise NotImplementedError()
|
||||
|
||||
def close(self):
|
||||
""" Close the selector. This must be called to ensure that all
|
||||
underlying resources are freed. """
|
||||
self._fd_to_key.clear()
|
||||
self._map = None
|
||||
|
||||
def get_key(self, fileobj):
|
||||
""" Return the key associated with a registered file object. """
|
||||
mapping = self.get_map()
|
||||
if mapping is None:
|
||||
raise RuntimeError("Selector is closed")
|
||||
try:
|
||||
return mapping[fileobj]
|
||||
except KeyError:
|
||||
raise KeyError("{0!r} is not registered".format(fileobj))
|
||||
|
||||
def get_map(self):
|
||||
""" Return a mapping of file objects to selector keys """
|
||||
return self._map
|
||||
|
||||
def _key_from_fd(self, fd):
|
||||
""" Return the key associated to a given file descriptor
|
||||
Return None if it is not found. """
|
||||
try:
|
||||
return self._fd_to_key[fd]
|
||||
except KeyError:
|
||||
return None
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, *args):
|
||||
self.close()
|
||||
|
||||
|
||||
# Almost all platforms have select.select()
|
||||
if hasattr(select, "select"):
|
||||
class SelectSelector(BaseSelector):
|
||||
""" Select-based selector. """
|
||||
def __init__(self):
|
||||
super(SelectSelector, self).__init__()
|
||||
self._readers = set()
|
||||
self._writers = set()
|
||||
|
||||
def register(self, fileobj, events, data=None):
|
||||
key = super(SelectSelector, self).register(fileobj, events, data)
|
||||
if events & EVENT_READ:
|
||||
self._readers.add(key.fd)
|
||||
if events & EVENT_WRITE:
|
||||
self._writers.add(key.fd)
|
||||
return key
|
||||
|
||||
def unregister(self, fileobj):
|
||||
key = super(SelectSelector, self).unregister(fileobj)
|
||||
self._readers.discard(key.fd)
|
||||
self._writers.discard(key.fd)
|
||||
return key
|
||||
|
||||
def _select(self, r, w, timeout=None):
|
||||
""" Wrapper for select.select because timeout is a positional arg """
|
||||
return select.select(r, w, [], timeout)
|
||||
|
||||
def select(self, timeout=None):
|
||||
# Selecting on empty lists on Windows errors out.
|
||||
if not len(self._readers) and not len(self._writers):
|
||||
return []
|
||||
|
||||
timeout = None if timeout is None else max(timeout, 0.0)
|
||||
ready = []
|
||||
r, w, _ = _syscall_wrapper(self._select, True, self._readers,
|
||||
self._writers, timeout)
|
||||
r = set(r)
|
||||
w = set(w)
|
||||
for fd in r | w:
|
||||
events = 0
|
||||
if fd in r:
|
||||
events |= EVENT_READ
|
||||
if fd in w:
|
||||
events |= EVENT_WRITE
|
||||
|
||||
key = self._key_from_fd(fd)
|
||||
if key:
|
||||
ready.append((key, events & key.events))
|
||||
return ready
|
||||
|
||||
|
||||
if hasattr(select, "poll"):
|
||||
class PollSelector(BaseSelector):
|
||||
""" Poll-based selector """
|
||||
def __init__(self):
|
||||
super(PollSelector, self).__init__()
|
||||
self._poll = select.poll()
|
||||
|
||||
def register(self, fileobj, events, data=None):
|
||||
key = super(PollSelector, self).register(fileobj, events, data)
|
||||
event_mask = 0
|
||||
if events & EVENT_READ:
|
||||
event_mask |= select.POLLIN
|
||||
if events & EVENT_WRITE:
|
||||
event_mask |= select.POLLOUT
|
||||
self._poll.register(key.fd, event_mask)
|
||||
return key
|
||||
|
||||
def unregister(self, fileobj):
|
||||
key = super(PollSelector, self).unregister(fileobj)
|
||||
self._poll.unregister(key.fd)
|
||||
return key
|
||||
|
||||
def _wrap_poll(self, timeout=None):
|
||||
""" Wrapper function for select.poll.poll() so that
|
||||
_syscall_wrapper can work with only seconds. """
|
||||
if timeout is not None:
|
||||
if timeout <= 0:
|
||||
timeout = 0
|
||||
else:
|
||||
# select.poll.poll() has a resolution of 1 millisecond,
|
||||
# round away from zero to wait *at least* timeout seconds.
|
||||
timeout = math.ceil(timeout * 1e3)
|
||||
|
||||
result = self._poll.poll(timeout)
|
||||
return result
|
||||
|
||||
def select(self, timeout=None):
|
||||
ready = []
|
||||
fd_events = _syscall_wrapper(self._wrap_poll, True, timeout=timeout)
|
||||
for fd, event_mask in fd_events:
|
||||
events = 0
|
||||
if event_mask & ~select.POLLIN:
|
||||
events |= EVENT_WRITE
|
||||
if event_mask & ~select.POLLOUT:
|
||||
events |= EVENT_READ
|
||||
|
||||
key = self._key_from_fd(fd)
|
||||
if key:
|
||||
ready.append((key, events & key.events))
|
||||
|
||||
return ready
|
||||
|
||||
|
||||
if hasattr(select, "epoll"):
|
||||
class EpollSelector(BaseSelector):
|
||||
""" Epoll-based selector """
|
||||
def __init__(self):
|
||||
super(EpollSelector, self).__init__()
|
||||
self._epoll = select.epoll()
|
||||
|
||||
def fileno(self):
|
||||
return self._epoll.fileno()
|
||||
|
||||
def register(self, fileobj, events, data=None):
|
||||
key = super(EpollSelector, self).register(fileobj, events, data)
|
||||
events_mask = 0
|
||||
if events & EVENT_READ:
|
||||
events_mask |= select.EPOLLIN
|
||||
if events & EVENT_WRITE:
|
||||
events_mask |= select.EPOLLOUT
|
||||
_syscall_wrapper(self._epoll.register, False, key.fd, events_mask)
|
||||
return key
|
||||
|
||||
def unregister(self, fileobj):
|
||||
key = super(EpollSelector, self).unregister(fileobj)
|
||||
try:
|
||||
_syscall_wrapper(self._epoll.unregister, False, key.fd)
|
||||
except SelectorError:
|
||||
# This can occur when the fd was closed since registry.
|
||||
pass
|
||||
return key
|
||||
|
||||
def select(self, timeout=None):
|
||||
if timeout is not None:
|
||||
if timeout <= 0:
|
||||
timeout = 0.0
|
||||
else:
|
||||
# select.epoll.poll() has a resolution of 1 millisecond
|
||||
# but luckily takes seconds so we don't need a wrapper
|
||||
# like PollSelector. Just for better rounding.
|
||||
timeout = math.ceil(timeout * 1e3) * 1e-3
|
||||
timeout = float(timeout)
|
||||
else:
|
||||
timeout = -1.0 # epoll.poll() must have a float.
|
||||
|
||||
# We always want at least 1 to ensure that select can be called
|
||||
# with no file descriptors registered. Otherwise will fail.
|
||||
max_events = max(len(self._fd_to_key), 1)
|
||||
|
||||
ready = []
|
||||
fd_events = _syscall_wrapper(self._epoll.poll, True,
|
||||
timeout=timeout,
|
||||
maxevents=max_events)
|
||||
for fd, event_mask in fd_events:
|
||||
events = 0
|
||||
if event_mask & ~select.EPOLLIN:
|
||||
events |= EVENT_WRITE
|
||||
if event_mask & ~select.EPOLLOUT:
|
||||
events |= EVENT_READ
|
||||
|
||||
key = self._key_from_fd(fd)
|
||||
if key:
|
||||
ready.append((key, events & key.events))
|
||||
return ready
|
||||
|
||||
def close(self):
|
||||
self._epoll.close()
|
||||
super(EpollSelector, self).close()
|
||||
|
||||
|
||||
if hasattr(select, "kqueue"):
|
||||
class KqueueSelector(BaseSelector):
|
||||
""" Kqueue / Kevent-based selector """
|
||||
def __init__(self):
|
||||
super(KqueueSelector, self).__init__()
|
||||
self._kqueue = select.kqueue()
|
||||
|
||||
def fileno(self):
|
||||
return self._kqueue.fileno()
|
||||
|
||||
def register(self, fileobj, events, data=None):
|
||||
key = super(KqueueSelector, self).register(fileobj, events, data)
|
||||
if events & EVENT_READ:
|
||||
kevent = select.kevent(key.fd,
|
||||
select.KQ_FILTER_READ,
|
||||
select.KQ_EV_ADD)
|
||||
|
||||
_syscall_wrapper(self._kqueue.control, False, [kevent], 0, 0)
|
||||
|
||||
if events & EVENT_WRITE:
|
||||
kevent = select.kevent(key.fd,
|
||||
select.KQ_FILTER_WRITE,
|
||||
select.KQ_EV_ADD)
|
||||
|
||||
_syscall_wrapper(self._kqueue.control, False, [kevent], 0, 0)
|
||||
|
||||
return key
|
||||
|
||||
def unregister(self, fileobj):
|
||||
key = super(KqueueSelector, self).unregister(fileobj)
|
||||
if key.events & EVENT_READ:
|
||||
kevent = select.kevent(key.fd,
|
||||
select.KQ_FILTER_READ,
|
||||
select.KQ_EV_DELETE)
|
||||
try:
|
||||
_syscall_wrapper(self._kqueue.control, False, [kevent], 0, 0)
|
||||
except SelectorError:
|
||||
pass
|
||||
if key.events & EVENT_WRITE:
|
||||
kevent = select.kevent(key.fd,
|
||||
select.KQ_FILTER_WRITE,
|
||||
select.KQ_EV_DELETE)
|
||||
try:
|
||||
_syscall_wrapper(self._kqueue.control, False, [kevent], 0, 0)
|
||||
except SelectorError:
|
||||
pass
|
||||
|
||||
return key
|
||||
|
||||
def select(self, timeout=None):
|
||||
if timeout is not None:
|
||||
timeout = max(timeout, 0)
|
||||
|
||||
max_events = len(self._fd_to_key) * 2
|
||||
ready_fds = {}
|
||||
|
||||
kevent_list = _syscall_wrapper(self._kqueue.control, True,
|
||||
None, max_events, timeout)
|
||||
|
||||
for kevent in kevent_list:
|
||||
fd = kevent.ident
|
||||
event_mask = kevent.filter
|
||||
events = 0
|
||||
if event_mask == select.KQ_FILTER_READ:
|
||||
events |= EVENT_READ
|
||||
if event_mask == select.KQ_FILTER_WRITE:
|
||||
events |= EVENT_WRITE
|
||||
|
||||
key = self._key_from_fd(fd)
|
||||
if key:
|
||||
if key.fd not in ready_fds:
|
||||
ready_fds[key.fd] = (key, events & key.events)
|
||||
else:
|
||||
old_events = ready_fds[key.fd][1]
|
||||
ready_fds[key.fd] = (key, (events | old_events) & key.events)
|
||||
|
||||
return list(ready_fds.values())
|
||||
|
||||
def close(self):
|
||||
self._kqueue.close()
|
||||
super(KqueueSelector, self).close()
|
||||
|
||||
|
||||
if not hasattr(select, 'select'): # Platform-specific: AppEngine
|
||||
HAS_SELECT = False
|
||||
|
||||
|
||||
def _can_allocate(struct):
|
||||
""" Checks that select structs can be allocated by the underlying
|
||||
operating system, not just advertised by the select module. We don't
|
||||
check select() because we'll be hopeful that most platforms that
|
||||
don't have it available will not advertise it. (ie: GAE) """
|
||||
try:
|
||||
# select.poll() objects won't fail until used.
|
||||
if struct == 'poll':
|
||||
p = select.poll()
|
||||
p.poll(0)
|
||||
|
||||
# All others will fail on allocation.
|
||||
else:
|
||||
getattr(select, struct)().close()
|
||||
return True
|
||||
except (OSError, AttributeError) as e:
|
||||
return False
|
||||
|
||||
|
||||
# Choose the best implementation, roughly:
|
||||
# kqueue == epoll > poll > select. Devpoll not supported. (See above)
|
||||
# select() also can't accept a FD > FD_SETSIZE (usually around 1024)
|
||||
def DefaultSelector():
|
||||
""" This function serves as a first call for DefaultSelector to
|
||||
detect if the select module is being monkey-patched incorrectly
|
||||
by eventlet, greenlet, and preserve proper behavior. """
|
||||
global _DEFAULT_SELECTOR
|
||||
if _DEFAULT_SELECTOR is None:
|
||||
if _can_allocate('kqueue'):
|
||||
_DEFAULT_SELECTOR = KqueueSelector
|
||||
elif _can_allocate('epoll'):
|
||||
_DEFAULT_SELECTOR = EpollSelector
|
||||
elif _can_allocate('poll'):
|
||||
_DEFAULT_SELECTOR = PollSelector
|
||||
elif hasattr(select, 'select'):
|
||||
_DEFAULT_SELECTOR = SelectSelector
|
||||
else: # Platform-specific: AppEngine
|
||||
raise ValueError('Platform does not have a selector')
|
||||
return _DEFAULT_SELECTOR()
|
||||
+163
-97
@@ -2,11 +2,14 @@ from __future__ import absolute_import
|
||||
import errno
|
||||
import warnings
|
||||
import hmac
|
||||
import sys
|
||||
|
||||
from binascii import hexlify, unhexlify
|
||||
from hashlib import md5, sha1, sha256
|
||||
|
||||
from .url import IPV4_RE, BRACELESS_IPV6_ADDRZ_RE
|
||||
from ..exceptions import SSLError, InsecurePlatformWarning, SNIMissingWarning
|
||||
from ..packages import six
|
||||
|
||||
|
||||
SSLContext = None
|
||||
@@ -15,11 +18,7 @@ IS_PYOPENSSL = False
|
||||
IS_SECURETRANSPORT = False
|
||||
|
||||
# Maps the length of a digest to a possible hash function producing this digest
|
||||
HASHFUNC_MAP = {
|
||||
32: md5,
|
||||
40: sha1,
|
||||
64: sha256,
|
||||
}
|
||||
HASHFUNC_MAP = {32: md5, 40: sha1, 64: sha256}
|
||||
|
||||
|
||||
def _const_compare_digest_backport(a, b):
|
||||
@@ -35,17 +34,27 @@ def _const_compare_digest_backport(a, b):
|
||||
return result == 0
|
||||
|
||||
|
||||
_const_compare_digest = getattr(hmac, 'compare_digest',
|
||||
_const_compare_digest_backport)
|
||||
|
||||
_const_compare_digest = getattr(hmac, "compare_digest", _const_compare_digest_backport)
|
||||
|
||||
try: # Test for SSL features
|
||||
import ssl
|
||||
from ssl import wrap_socket, CERT_NONE, PROTOCOL_SSLv23
|
||||
from ssl import wrap_socket, CERT_REQUIRED
|
||||
from ssl import HAS_SNI # Has SNI?
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try: # Platform-specific: Python 3.6
|
||||
from ssl import PROTOCOL_TLS
|
||||
|
||||
PROTOCOL_SSLv23 = PROTOCOL_TLS
|
||||
except ImportError:
|
||||
try:
|
||||
from ssl import PROTOCOL_SSLv23 as PROTOCOL_TLS
|
||||
|
||||
PROTOCOL_SSLv23 = PROTOCOL_TLS
|
||||
except ImportError:
|
||||
PROTOCOL_SSLv23 = PROTOCOL_TLS = 2
|
||||
|
||||
|
||||
try:
|
||||
from ssl import OP_NO_SSLv2, OP_NO_SSLv3, OP_NO_COMPRESSION
|
||||
@@ -53,6 +62,7 @@ except ImportError:
|
||||
OP_NO_SSLv2, OP_NO_SSLv3 = 0x1000000, 0x2000000
|
||||
OP_NO_COMPRESSION = 0x20000
|
||||
|
||||
|
||||
# A secure default.
|
||||
# Sources for more information on TLS ciphers:
|
||||
#
|
||||
@@ -61,41 +71,39 @@ except ImportError:
|
||||
# - https://hynek.me/articles/hardening-your-web-servers-ssl-ciphers/
|
||||
#
|
||||
# The general intent is:
|
||||
# - Prefer TLS 1.3 cipher suites
|
||||
# - prefer cipher suites that offer perfect forward secrecy (DHE/ECDHE),
|
||||
# - prefer ECDHE over DHE for better performance,
|
||||
# - prefer any AES-GCM and ChaCha20 over any AES-CBC for better performance and
|
||||
# security,
|
||||
# - prefer AES-GCM over ChaCha20 because hardware-accelerated AES is common,
|
||||
# - disable NULL authentication, MD5 MACs and DSS for security reasons.
|
||||
DEFAULT_CIPHERS = ':'.join([
|
||||
'TLS13-AES-256-GCM-SHA384',
|
||||
'TLS13-CHACHA20-POLY1305-SHA256',
|
||||
'TLS13-AES-128-GCM-SHA256',
|
||||
'ECDH+AESGCM',
|
||||
'ECDH+CHACHA20',
|
||||
'DH+AESGCM',
|
||||
'DH+CHACHA20',
|
||||
'ECDH+AES256',
|
||||
'DH+AES256',
|
||||
'ECDH+AES128',
|
||||
'DH+AES',
|
||||
'RSA+AESGCM',
|
||||
'RSA+AES',
|
||||
'!aNULL',
|
||||
'!eNULL',
|
||||
'!MD5',
|
||||
])
|
||||
# - disable NULL authentication, MD5 MACs, DSS, and other
|
||||
# insecure ciphers for security reasons.
|
||||
# - NOTE: TLS 1.3 cipher suites are managed through a different interface
|
||||
# not exposed by CPython (yet!) and are enabled by default if they're available.
|
||||
DEFAULT_CIPHERS = ":".join(
|
||||
[
|
||||
"ECDHE+AESGCM",
|
||||
"ECDHE+CHACHA20",
|
||||
"DHE+AESGCM",
|
||||
"DHE+CHACHA20",
|
||||
"ECDH+AESGCM",
|
||||
"DH+AESGCM",
|
||||
"ECDH+AES",
|
||||
"DH+AES",
|
||||
"RSA+AESGCM",
|
||||
"RSA+AES",
|
||||
"!aNULL",
|
||||
"!eNULL",
|
||||
"!MD5",
|
||||
"!DSS",
|
||||
]
|
||||
)
|
||||
|
||||
try:
|
||||
from ssl import SSLContext # Modern SSL?
|
||||
except ImportError:
|
||||
import sys
|
||||
|
||||
class SSLContext(object): # Platform-specific: Python 2 & 3.1
|
||||
supports_set_ciphers = ((2, 7) <= sys.version_info < (3,) or
|
||||
(3, 2) <= sys.version_info)
|
||||
|
||||
class SSLContext(object): # Platform-specific: Python 2
|
||||
def __init__(self, protocol_version):
|
||||
self.protocol = protocol_version
|
||||
# Use default values from a real SSLContext
|
||||
@@ -118,36 +126,27 @@ except ImportError:
|
||||
raise SSLError("CA directories not supported in older Pythons")
|
||||
|
||||
def set_ciphers(self, cipher_suite):
|
||||
if not self.supports_set_ciphers:
|
||||
raise TypeError(
|
||||
'Your version of Python does not support setting '
|
||||
'a custom cipher suite. Please upgrade to Python '
|
||||
'2.7, 3.2, or later if you need this functionality.'
|
||||
)
|
||||
self.ciphers = cipher_suite
|
||||
|
||||
def wrap_socket(self, socket, server_hostname=None, server_side=False):
|
||||
warnings.warn(
|
||||
'A true SSLContext object is not available. This prevents '
|
||||
'urllib3 from configuring SSL appropriately and may cause '
|
||||
'certain SSL connections to fail. You can upgrade to a newer '
|
||||
'version of Python to solve this. For more information, see '
|
||||
'https://urllib3.readthedocs.io/en/latest/advanced-usage.html'
|
||||
'#ssl-warnings',
|
||||
InsecurePlatformWarning
|
||||
"A true SSLContext object is not available. This prevents "
|
||||
"urllib3 from configuring SSL appropriately and may cause "
|
||||
"certain SSL connections to fail. You can upgrade to a newer "
|
||||
"version of Python to solve this. For more information, see "
|
||||
"https://urllib3.readthedocs.io/en/latest/advanced-usage.html"
|
||||
"#ssl-warnings",
|
||||
InsecurePlatformWarning,
|
||||
)
|
||||
kwargs = {
|
||||
'keyfile': self.keyfile,
|
||||
'certfile': self.certfile,
|
||||
'ca_certs': self.ca_certs,
|
||||
'cert_reqs': self.verify_mode,
|
||||
'ssl_version': self.protocol,
|
||||
'server_side': server_side,
|
||||
"keyfile": self.keyfile,
|
||||
"certfile": self.certfile,
|
||||
"ca_certs": self.ca_certs,
|
||||
"cert_reqs": self.verify_mode,
|
||||
"ssl_version": self.protocol,
|
||||
"server_side": server_side,
|
||||
}
|
||||
if self.supports_set_ciphers: # Platform-specific: Python 2.7+
|
||||
return wrap_socket(socket, ciphers=self.ciphers, **kwargs)
|
||||
else: # Platform-specific: Python 2.6
|
||||
return wrap_socket(socket, **kwargs)
|
||||
return wrap_socket(socket, ciphers=self.ciphers, **kwargs)
|
||||
|
||||
|
||||
def assert_fingerprint(cert, fingerprint):
|
||||
@@ -160,12 +159,11 @@ def assert_fingerprint(cert, fingerprint):
|
||||
Fingerprint as string of hexdigits, can be interspersed by colons.
|
||||
"""
|
||||
|
||||
fingerprint = fingerprint.replace(':', '').lower()
|
||||
fingerprint = fingerprint.replace(":", "").lower()
|
||||
digest_length = len(fingerprint)
|
||||
hashfunc = HASHFUNC_MAP.get(digest_length)
|
||||
if not hashfunc:
|
||||
raise SSLError(
|
||||
'Fingerprint of invalid length: {0}'.format(fingerprint))
|
||||
raise SSLError("Fingerprint of invalid length: {0}".format(fingerprint))
|
||||
|
||||
# We need encode() here for py32; works on py2 and p33.
|
||||
fingerprint_bytes = unhexlify(fingerprint.encode())
|
||||
@@ -173,8 +171,11 @@ def assert_fingerprint(cert, fingerprint):
|
||||
cert_digest = hashfunc(cert).digest()
|
||||
|
||||
if not _const_compare_digest(cert_digest, fingerprint_bytes):
|
||||
raise SSLError('Fingerprints did not match. Expected "{0}", got "{1}".'
|
||||
.format(fingerprint, hexlify(cert_digest)))
|
||||
raise SSLError(
|
||||
'Fingerprints did not match. Expected "{0}", got "{1}".'.format(
|
||||
fingerprint, hexlify(cert_digest)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def resolve_cert_reqs(candidate):
|
||||
@@ -183,18 +184,18 @@ def resolve_cert_reqs(candidate):
|
||||
the wrap_socket function/method from the ssl module.
|
||||
Defaults to :data:`ssl.CERT_NONE`.
|
||||
If given a string it is assumed to be the name of the constant in the
|
||||
:mod:`ssl` module or its abbrevation.
|
||||
:mod:`ssl` module or its abbreviation.
|
||||
(So you can specify `REQUIRED` instead of `CERT_REQUIRED`.
|
||||
If it's neither `None` nor a string we assume it is already the numeric
|
||||
constant which can directly be passed to wrap_socket.
|
||||
"""
|
||||
if candidate is None:
|
||||
return CERT_NONE
|
||||
return CERT_REQUIRED
|
||||
|
||||
if isinstance(candidate, str):
|
||||
res = getattr(ssl, candidate, None)
|
||||
if res is None:
|
||||
res = getattr(ssl, 'CERT_' + candidate)
|
||||
res = getattr(ssl, "CERT_" + candidate)
|
||||
return res
|
||||
|
||||
return candidate
|
||||
@@ -205,19 +206,20 @@ def resolve_ssl_version(candidate):
|
||||
like resolve_cert_reqs
|
||||
"""
|
||||
if candidate is None:
|
||||
return PROTOCOL_SSLv23
|
||||
return PROTOCOL_TLS
|
||||
|
||||
if isinstance(candidate, str):
|
||||
res = getattr(ssl, candidate, None)
|
||||
if res is None:
|
||||
res = getattr(ssl, 'PROTOCOL_' + candidate)
|
||||
res = getattr(ssl, "PROTOCOL_" + candidate)
|
||||
return res
|
||||
|
||||
return candidate
|
||||
|
||||
|
||||
def create_urllib3_context(ssl_version=None, cert_reqs=None,
|
||||
options=None, ciphers=None):
|
||||
def create_urllib3_context(
|
||||
ssl_version=None, cert_reqs=None, options=None, ciphers=None
|
||||
):
|
||||
"""All arguments have the same meaning as ``ssl_wrap_socket``.
|
||||
|
||||
By default, this function does a lot of the same work that
|
||||
@@ -251,7 +253,9 @@ def create_urllib3_context(ssl_version=None, cert_reqs=None,
|
||||
Constructed SSLContext object with specified options
|
||||
:rtype: SSLContext
|
||||
"""
|
||||
context = SSLContext(ssl_version or ssl.PROTOCOL_SSLv23)
|
||||
context = SSLContext(ssl_version or PROTOCOL_TLS)
|
||||
|
||||
context.set_ciphers(ciphers or DEFAULT_CIPHERS)
|
||||
|
||||
# Setting the default here, as we may have no ssl module on import
|
||||
cert_reqs = ssl.CERT_REQUIRED if cert_reqs is None else cert_reqs
|
||||
@@ -268,21 +272,40 @@ def create_urllib3_context(ssl_version=None, cert_reqs=None,
|
||||
|
||||
context.options |= options
|
||||
|
||||
if getattr(context, 'supports_set_ciphers', True): # Platform-specific: Python 2.6
|
||||
context.set_ciphers(ciphers or DEFAULT_CIPHERS)
|
||||
# Enable post-handshake authentication for TLS 1.3, see GH #1634. PHA is
|
||||
# necessary for conditional client cert authentication with TLS 1.3.
|
||||
# The attribute is None for OpenSSL <= 1.1.0 or does not exist in older
|
||||
# versions of Python. We only enable on Python 3.7.4+ or if certificate
|
||||
# verification is enabled to work around Python issue #37428
|
||||
# See: https://bugs.python.org/issue37428
|
||||
if (cert_reqs == ssl.CERT_REQUIRED or sys.version_info >= (3, 7, 4)) and getattr(
|
||||
context, "post_handshake_auth", None
|
||||
) is not None:
|
||||
context.post_handshake_auth = True
|
||||
|
||||
context.verify_mode = cert_reqs
|
||||
if getattr(context, 'check_hostname', None) is not None: # Platform-specific: Python 3.2
|
||||
if (
|
||||
getattr(context, "check_hostname", None) is not None
|
||||
): # Platform-specific: Python 3.2
|
||||
# We do our own verification, including fingerprints and alternative
|
||||
# hostnames. So disable it here
|
||||
context.check_hostname = False
|
||||
return context
|
||||
|
||||
|
||||
def ssl_wrap_socket(sock, keyfile=None, certfile=None, cert_reqs=None,
|
||||
ca_certs=None, server_hostname=None,
|
||||
ssl_version=None, ciphers=None, ssl_context=None,
|
||||
ca_cert_dir=None):
|
||||
def ssl_wrap_socket(
|
||||
sock,
|
||||
keyfile=None,
|
||||
certfile=None,
|
||||
cert_reqs=None,
|
||||
ca_certs=None,
|
||||
server_hostname=None,
|
||||
ssl_version=None,
|
||||
ciphers=None,
|
||||
ssl_context=None,
|
||||
ca_cert_dir=None,
|
||||
key_password=None,
|
||||
):
|
||||
"""
|
||||
All arguments except for server_hostname, ssl_context, and ca_cert_dir have
|
||||
the same meaning as they do when using :func:`ssl.wrap_socket`.
|
||||
@@ -293,25 +316,25 @@ def ssl_wrap_socket(sock, keyfile=None, certfile=None, cert_reqs=None,
|
||||
A pre-made :class:`SSLContext` object. If none is provided, one will
|
||||
be created using :func:`create_urllib3_context`.
|
||||
:param ciphers:
|
||||
A string of ciphers we wish the client to support. This is not
|
||||
supported on Python 2.6 as the ssl module does not support it.
|
||||
A string of ciphers we wish the client to support.
|
||||
:param ca_cert_dir:
|
||||
A directory containing CA certificates in multiple separate files, as
|
||||
supported by OpenSSL's -CApath flag or the capath argument to
|
||||
SSLContext.load_verify_locations().
|
||||
:param key_password:
|
||||
Optional password if the keyfile is encrypted.
|
||||
"""
|
||||
context = ssl_context
|
||||
if context is None:
|
||||
# Note: This branch of code and all the variables in it are no longer
|
||||
# used by urllib3 itself. We should consider deprecating and removing
|
||||
# this code.
|
||||
context = create_urllib3_context(ssl_version, cert_reqs,
|
||||
ciphers=ciphers)
|
||||
context = create_urllib3_context(ssl_version, cert_reqs, ciphers=ciphers)
|
||||
|
||||
if ca_certs or ca_cert_dir:
|
||||
try:
|
||||
context.load_verify_locations(ca_certs, ca_cert_dir)
|
||||
except IOError as e: # Platform-specific: Python 2.6, 2.7, 3.2
|
||||
except IOError as e: # Platform-specific: Python 2.7
|
||||
raise SSLError(e)
|
||||
# Py33 raises FileNotFoundError which subclasses OSError
|
||||
# These are not equivalent unless we check the errno attribute
|
||||
@@ -319,23 +342,66 @@ def ssl_wrap_socket(sock, keyfile=None, certfile=None, cert_reqs=None,
|
||||
if e.errno == errno.ENOENT:
|
||||
raise SSLError(e)
|
||||
raise
|
||||
elif getattr(context, 'load_default_certs', None) is not None:
|
||||
|
||||
elif ssl_context is None and hasattr(context, "load_default_certs"):
|
||||
# try to load OS default certs; works well on Windows (require Python3.4+)
|
||||
context.load_default_certs()
|
||||
|
||||
if certfile:
|
||||
context.load_cert_chain(certfile, keyfile)
|
||||
if HAS_SNI: # Platform-specific: OpenSSL with enabled SNI
|
||||
return context.wrap_socket(sock, server_hostname=server_hostname)
|
||||
# Attempt to detect if we get the goofy behavior of the
|
||||
# keyfile being encrypted and OpenSSL asking for the
|
||||
# passphrase via the terminal and instead error out.
|
||||
if keyfile and key_password is None and _is_key_file_encrypted(keyfile):
|
||||
raise SSLError("Client private key is encrypted, password is required")
|
||||
|
||||
if certfile:
|
||||
if key_password is None:
|
||||
context.load_cert_chain(certfile, keyfile)
|
||||
else:
|
||||
context.load_cert_chain(certfile, keyfile, key_password)
|
||||
|
||||
# If we detect server_hostname is an IP address then the SNI
|
||||
# extension should not be used according to RFC3546 Section 3.1
|
||||
# We shouldn't warn the user if SNI isn't available but we would
|
||||
# not be using SNI anyways due to IP address for server_hostname.
|
||||
if (
|
||||
server_hostname is not None and not is_ipaddress(server_hostname)
|
||||
) or IS_SECURETRANSPORT:
|
||||
if HAS_SNI and server_hostname is not None:
|
||||
return context.wrap_socket(sock, server_hostname=server_hostname)
|
||||
|
||||
warnings.warn(
|
||||
"An HTTPS request has been made, but the SNI (Server Name "
|
||||
"Indication) extension to TLS is not available on this platform. "
|
||||
"This may cause the server to present an incorrect TLS "
|
||||
"certificate, which can cause validation failures. You can upgrade to "
|
||||
"a newer version of Python to solve this. For more information, see "
|
||||
"https://urllib3.readthedocs.io/en/latest/advanced-usage.html"
|
||||
"#ssl-warnings",
|
||||
SNIMissingWarning,
|
||||
)
|
||||
|
||||
warnings.warn(
|
||||
'An HTTPS request has been made, but the SNI (Subject Name '
|
||||
'Indication) extension to TLS is not available on this platform. '
|
||||
'This may cause the server to present an incorrect TLS '
|
||||
'certificate, which can cause validation failures. You can upgrade to '
|
||||
'a newer version of Python to solve this. For more information, see '
|
||||
'https://urllib3.readthedocs.io/en/latest/advanced-usage.html'
|
||||
'#ssl-warnings',
|
||||
SNIMissingWarning
|
||||
)
|
||||
return context.wrap_socket(sock)
|
||||
|
||||
|
||||
def is_ipaddress(hostname):
|
||||
"""Detects whether the hostname given is an IPv4 or IPv6 address.
|
||||
Also detects IPv6 addresses with Zone IDs.
|
||||
|
||||
:param str hostname: Hostname to examine.
|
||||
:return: True if the hostname is an IP address, False otherwise.
|
||||
"""
|
||||
if not six.PY2 and isinstance(hostname, bytes):
|
||||
# IDN A-label bytes are ASCII compatible.
|
||||
hostname = hostname.decode("ascii")
|
||||
return bool(IPV4_RE.match(hostname) or BRACELESS_IPV6_ADDRZ_RE.match(hostname))
|
||||
|
||||
|
||||
def _is_key_file_encrypted(key_file):
|
||||
"""Detects if a key file is encrypted or not."""
|
||||
with open(key_file, "r") as f:
|
||||
for line in f:
|
||||
# Look for Proc-Type: 4,ENCRYPTED
|
||||
if "ENCRYPTED" in line:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
+49
-33
@@ -1,4 +1,5 @@
|
||||
from __future__ import absolute_import
|
||||
|
||||
# The default socket timeout, used by httplib to indicate that no timeout was
|
||||
# specified by the user
|
||||
from socket import _GLOBAL_DEFAULT_TIMEOUT
|
||||
@@ -45,19 +46,20 @@ class Timeout(object):
|
||||
:type total: integer, float, or None
|
||||
|
||||
:param connect:
|
||||
The maximum amount of time to wait for a connection attempt to a server
|
||||
to succeed. Omitting the parameter will default the connect timeout to
|
||||
the system default, probably `the global default timeout in socket.py
|
||||
The maximum amount of time (in seconds) to wait for a connection
|
||||
attempt to a server to succeed. Omitting the parameter will default the
|
||||
connect timeout to the system default, probably `the global default
|
||||
timeout in socket.py
|
||||
<http://hg.python.org/cpython/file/603b4d593758/Lib/socket.py#l535>`_.
|
||||
None will set an infinite timeout for connection attempts.
|
||||
|
||||
:type connect: integer, float, or None
|
||||
|
||||
:param read:
|
||||
The maximum amount of time to wait between consecutive
|
||||
read operations for a response from the server. Omitting
|
||||
the parameter will default the read timeout to the system
|
||||
default, probably `the global default timeout in socket.py
|
||||
The maximum amount of time (in seconds) to wait between consecutive
|
||||
read operations for a response from the server. Omitting the parameter
|
||||
will default the read timeout to the system default, probably `the
|
||||
global default timeout in socket.py
|
||||
<http://hg.python.org/cpython/file/603b4d593758/Lib/socket.py#l535>`_.
|
||||
None will set an infinite timeout.
|
||||
|
||||
@@ -91,14 +93,18 @@ class Timeout(object):
|
||||
DEFAULT_TIMEOUT = _GLOBAL_DEFAULT_TIMEOUT
|
||||
|
||||
def __init__(self, total=None, connect=_Default, read=_Default):
|
||||
self._connect = self._validate_timeout(connect, 'connect')
|
||||
self._read = self._validate_timeout(read, 'read')
|
||||
self.total = self._validate_timeout(total, 'total')
|
||||
self._connect = self._validate_timeout(connect, "connect")
|
||||
self._read = self._validate_timeout(read, "read")
|
||||
self.total = self._validate_timeout(total, "total")
|
||||
self._start_connect = None
|
||||
|
||||
def __str__(self):
|
||||
return '%s(connect=%r, read=%r, total=%r)' % (
|
||||
type(self).__name__, self._connect, self._read, self.total)
|
||||
return "%s(connect=%r, read=%r, total=%r)" % (
|
||||
type(self).__name__,
|
||||
self._connect,
|
||||
self._read,
|
||||
self.total,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _validate_timeout(cls, value, name):
|
||||
@@ -118,22 +124,31 @@ class Timeout(object):
|
||||
return value
|
||||
|
||||
if isinstance(value, bool):
|
||||
raise ValueError("Timeout cannot be a boolean value. It must "
|
||||
"be an int, float or None.")
|
||||
raise ValueError(
|
||||
"Timeout cannot be a boolean value. It must "
|
||||
"be an int, float or None."
|
||||
)
|
||||
try:
|
||||
float(value)
|
||||
except (TypeError, ValueError):
|
||||
raise ValueError("Timeout value %s was %s, but it must be an "
|
||||
"int, float or None." % (name, value))
|
||||
raise ValueError(
|
||||
"Timeout value %s was %s, but it must be an "
|
||||
"int, float or None." % (name, value)
|
||||
)
|
||||
|
||||
try:
|
||||
if value <= 0:
|
||||
raise ValueError("Attempted to set %s timeout to %s, but the "
|
||||
"timeout cannot be set to a value less "
|
||||
"than or equal to 0." % (name, value))
|
||||
except TypeError: # Python 3
|
||||
raise ValueError("Timeout value %s was %s, but it must be an "
|
||||
"int, float or None." % (name, value))
|
||||
raise ValueError(
|
||||
"Attempted to set %s timeout to %s, but the "
|
||||
"timeout cannot be set to a value less "
|
||||
"than or equal to 0." % (name, value)
|
||||
)
|
||||
except TypeError:
|
||||
# Python 3
|
||||
raise ValueError(
|
||||
"Timeout value %s was %s, but it must be an "
|
||||
"int, float or None." % (name, value)
|
||||
)
|
||||
|
||||
return value
|
||||
|
||||
@@ -165,8 +180,7 @@ class Timeout(object):
|
||||
# We can't use copy.deepcopy because that will also create a new object
|
||||
# for _GLOBAL_DEFAULT_TIMEOUT, which socket.py uses as a sentinel to
|
||||
# detect the user default.
|
||||
return Timeout(connect=self._connect, read=self._read,
|
||||
total=self.total)
|
||||
return Timeout(connect=self._connect, read=self._read, total=self.total)
|
||||
|
||||
def start_connect(self):
|
||||
""" Start the timeout clock, used during a connect() attempt
|
||||
@@ -182,14 +196,15 @@ class Timeout(object):
|
||||
def get_connect_duration(self):
|
||||
""" Gets the time elapsed since the call to :meth:`start_connect`.
|
||||
|
||||
:return: Elapsed time.
|
||||
:return: Elapsed time in seconds.
|
||||
:rtype: float
|
||||
:raises urllib3.exceptions.TimeoutStateError: if you attempt
|
||||
to get duration for a timer that hasn't been started.
|
||||
"""
|
||||
if self._start_connect is None:
|
||||
raise TimeoutStateError("Can't get connect duration for timer "
|
||||
"that has not started.")
|
||||
raise TimeoutStateError(
|
||||
"Can't get connect duration for timer " "that has not started."
|
||||
)
|
||||
return current_time() - self._start_connect
|
||||
|
||||
@property
|
||||
@@ -227,15 +242,16 @@ class Timeout(object):
|
||||
:raises urllib3.exceptions.TimeoutStateError: If :meth:`start_connect`
|
||||
has not yet been called on this object.
|
||||
"""
|
||||
if (self.total is not None and
|
||||
self.total is not self.DEFAULT_TIMEOUT and
|
||||
self._read is not None and
|
||||
self._read is not self.DEFAULT_TIMEOUT):
|
||||
if (
|
||||
self.total is not None
|
||||
and self.total is not self.DEFAULT_TIMEOUT
|
||||
and self._read is not None
|
||||
and self._read is not self.DEFAULT_TIMEOUT
|
||||
):
|
||||
# In case the connect timeout has not yet been established.
|
||||
if self._start_connect is None:
|
||||
return self._read
|
||||
return max(0, min(self.total - self.get_connect_duration(),
|
||||
self._read))
|
||||
return max(0, min(self.total - self.get_connect_duration(), self._read))
|
||||
elif self.total is not None and self.total is not self.DEFAULT_TIMEOUT:
|
||||
return max(0, self.total - self.get_connect_duration())
|
||||
else:
|
||||
|
||||
+295
-86
@@ -1,34 +1,110 @@
|
||||
from __future__ import absolute_import
|
||||
import re
|
||||
from collections import namedtuple
|
||||
|
||||
from ..exceptions import LocationParseError
|
||||
from ..packages import six
|
||||
|
||||
|
||||
url_attrs = ['scheme', 'auth', 'host', 'port', 'path', 'query', 'fragment']
|
||||
url_attrs = ["scheme", "auth", "host", "port", "path", "query", "fragment"]
|
||||
|
||||
# We only want to normalize urls with an HTTP(S) scheme.
|
||||
# urllib3 infers URLs without a scheme (None) to be http.
|
||||
NORMALIZABLE_SCHEMES = ('http', 'https', None)
|
||||
NORMALIZABLE_SCHEMES = ("http", "https", None)
|
||||
|
||||
# Almost all of these patterns were derived from the
|
||||
# 'rfc3986' module: https://github.com/python-hyper/rfc3986
|
||||
PERCENT_RE = re.compile(r"%[a-fA-F0-9]{2}")
|
||||
SCHEME_RE = re.compile(r"^(?:[a-zA-Z][a-zA-Z0-9+-]*:|/)")
|
||||
URI_RE = re.compile(
|
||||
r"^(?:([a-zA-Z][a-zA-Z0-9+.-]*):)?"
|
||||
r"(?://([^/?#]*))?"
|
||||
r"([^?#]*)"
|
||||
r"(?:\?([^#]*))?"
|
||||
r"(?:#(.*))?$",
|
||||
re.UNICODE | re.DOTALL,
|
||||
)
|
||||
|
||||
IPV4_PAT = r"(?:[0-9]{1,3}\.){3}[0-9]{1,3}"
|
||||
HEX_PAT = "[0-9A-Fa-f]{1,4}"
|
||||
LS32_PAT = "(?:{hex}:{hex}|{ipv4})".format(hex=HEX_PAT, ipv4=IPV4_PAT)
|
||||
_subs = {"hex": HEX_PAT, "ls32": LS32_PAT}
|
||||
_variations = [
|
||||
# 6( h16 ":" ) ls32
|
||||
"(?:%(hex)s:){6}%(ls32)s",
|
||||
# "::" 5( h16 ":" ) ls32
|
||||
"::(?:%(hex)s:){5}%(ls32)s",
|
||||
# [ h16 ] "::" 4( h16 ":" ) ls32
|
||||
"(?:%(hex)s)?::(?:%(hex)s:){4}%(ls32)s",
|
||||
# [ *1( h16 ":" ) h16 ] "::" 3( h16 ":" ) ls32
|
||||
"(?:(?:%(hex)s:)?%(hex)s)?::(?:%(hex)s:){3}%(ls32)s",
|
||||
# [ *2( h16 ":" ) h16 ] "::" 2( h16 ":" ) ls32
|
||||
"(?:(?:%(hex)s:){0,2}%(hex)s)?::(?:%(hex)s:){2}%(ls32)s",
|
||||
# [ *3( h16 ":" ) h16 ] "::" h16 ":" ls32
|
||||
"(?:(?:%(hex)s:){0,3}%(hex)s)?::%(hex)s:%(ls32)s",
|
||||
# [ *4( h16 ":" ) h16 ] "::" ls32
|
||||
"(?:(?:%(hex)s:){0,4}%(hex)s)?::%(ls32)s",
|
||||
# [ *5( h16 ":" ) h16 ] "::" h16
|
||||
"(?:(?:%(hex)s:){0,5}%(hex)s)?::%(hex)s",
|
||||
# [ *6( h16 ":" ) h16 ] "::"
|
||||
"(?:(?:%(hex)s:){0,6}%(hex)s)?::",
|
||||
]
|
||||
|
||||
UNRESERVED_PAT = r"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789._!\-~"
|
||||
IPV6_PAT = "(?:" + "|".join([x % _subs for x in _variations]) + ")"
|
||||
ZONE_ID_PAT = "(?:%25|%)(?:[" + UNRESERVED_PAT + "]|%[a-fA-F0-9]{2})+"
|
||||
IPV6_ADDRZ_PAT = r"\[" + IPV6_PAT + r"(?:" + ZONE_ID_PAT + r")?\]"
|
||||
REG_NAME_PAT = r"(?:[^\[\]%:/?#]|%[a-fA-F0-9]{2})*"
|
||||
TARGET_RE = re.compile(r"^(/[^?]*)(?:\?([^#]+))?(?:#(.*))?$")
|
||||
|
||||
IPV4_RE = re.compile("^" + IPV4_PAT + "$")
|
||||
IPV6_RE = re.compile("^" + IPV6_PAT + "$")
|
||||
IPV6_ADDRZ_RE = re.compile("^" + IPV6_ADDRZ_PAT + "$")
|
||||
BRACELESS_IPV6_ADDRZ_RE = re.compile("^" + IPV6_ADDRZ_PAT[2:-2] + "$")
|
||||
ZONE_ID_RE = re.compile("(" + ZONE_ID_PAT + r")\]$")
|
||||
|
||||
SUBAUTHORITY_PAT = (u"^(?:(.*)@)?(%s|%s|%s)(?::([0-9]{0,5}))?$") % (
|
||||
REG_NAME_PAT,
|
||||
IPV4_PAT,
|
||||
IPV6_ADDRZ_PAT,
|
||||
)
|
||||
SUBAUTHORITY_RE = re.compile(SUBAUTHORITY_PAT, re.UNICODE | re.DOTALL)
|
||||
|
||||
UNRESERVED_CHARS = set(
|
||||
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789._-~"
|
||||
)
|
||||
SUB_DELIM_CHARS = set("!$&'()*+,;=")
|
||||
USERINFO_CHARS = UNRESERVED_CHARS | SUB_DELIM_CHARS | {":"}
|
||||
PATH_CHARS = USERINFO_CHARS | {"@", "/"}
|
||||
QUERY_CHARS = FRAGMENT_CHARS = PATH_CHARS | {"?"}
|
||||
|
||||
|
||||
class Url(namedtuple('Url', url_attrs)):
|
||||
class Url(namedtuple("Url", url_attrs)):
|
||||
"""
|
||||
Datastructure for representing an HTTP URL. Used as a return value for
|
||||
Data structure for representing an HTTP URL. Used as a return value for
|
||||
:func:`parse_url`. Both the scheme and host are normalized as they are
|
||||
both case-insensitive according to RFC 3986.
|
||||
"""
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
def __new__(cls, scheme=None, auth=None, host=None, port=None, path=None,
|
||||
query=None, fragment=None):
|
||||
if path and not path.startswith('/'):
|
||||
path = '/' + path
|
||||
if scheme:
|
||||
def __new__(
|
||||
cls,
|
||||
scheme=None,
|
||||
auth=None,
|
||||
host=None,
|
||||
port=None,
|
||||
path=None,
|
||||
query=None,
|
||||
fragment=None,
|
||||
):
|
||||
if path and not path.startswith("/"):
|
||||
path = "/" + path
|
||||
if scheme is not None:
|
||||
scheme = scheme.lower()
|
||||
if host and scheme in NORMALIZABLE_SCHEMES:
|
||||
host = host.lower()
|
||||
return super(Url, cls).__new__(cls, scheme, auth, host, port, path,
|
||||
query, fragment)
|
||||
return super(Url, cls).__new__(
|
||||
cls, scheme, auth, host, port, path, query, fragment
|
||||
)
|
||||
|
||||
@property
|
||||
def hostname(self):
|
||||
@@ -38,10 +114,10 @@ class Url(namedtuple('Url', url_attrs)):
|
||||
@property
|
||||
def request_uri(self):
|
||||
"""Absolute path including the query string."""
|
||||
uri = self.path or '/'
|
||||
uri = self.path or "/"
|
||||
|
||||
if self.query is not None:
|
||||
uri += '?' + self.query
|
||||
uri += "?" + self.query
|
||||
|
||||
return uri
|
||||
|
||||
@@ -49,7 +125,7 @@ class Url(namedtuple('Url', url_attrs)):
|
||||
def netloc(self):
|
||||
"""Network location including host and port"""
|
||||
if self.port:
|
||||
return '%s:%d' % (self.host, self.port)
|
||||
return "%s:%d" % (self.host, self.port)
|
||||
return self.host
|
||||
|
||||
@property
|
||||
@@ -72,23 +148,23 @@ class Url(namedtuple('Url', url_attrs)):
|
||||
'http://username:password@host.com:80/path?query#fragment'
|
||||
"""
|
||||
scheme, auth, host, port, path, query, fragment = self
|
||||
url = ''
|
||||
url = u""
|
||||
|
||||
# We use "is not None" we want things to happen with empty strings (or 0 port)
|
||||
if scheme is not None:
|
||||
url += scheme + '://'
|
||||
url += scheme + u"://"
|
||||
if auth is not None:
|
||||
url += auth + '@'
|
||||
url += auth + u"@"
|
||||
if host is not None:
|
||||
url += host
|
||||
if port is not None:
|
||||
url += ':' + str(port)
|
||||
url += u":" + str(port)
|
||||
if path is not None:
|
||||
url += path
|
||||
if query is not None:
|
||||
url += '?' + query
|
||||
url += u"?" + query
|
||||
if fragment is not None:
|
||||
url += '#' + fragment
|
||||
url += u"#" + fragment
|
||||
|
||||
return url
|
||||
|
||||
@@ -98,6 +174,8 @@ class Url(namedtuple('Url', url_attrs)):
|
||||
|
||||
def split_first(s, delims):
|
||||
"""
|
||||
.. deprecated:: 1.25
|
||||
|
||||
Given a string and an iterable of delimiters, split on the first found
|
||||
delimiter. Return two split parts and the matched delimiter.
|
||||
|
||||
@@ -124,15 +202,150 @@ def split_first(s, delims):
|
||||
min_delim = d
|
||||
|
||||
if min_idx is None or min_idx < 0:
|
||||
return s, '', None
|
||||
return s, "", None
|
||||
|
||||
return s[:min_idx], s[min_idx + 1:], min_delim
|
||||
return s[:min_idx], s[min_idx + 1 :], min_delim
|
||||
|
||||
|
||||
def _encode_invalid_chars(component, allowed_chars, encoding="utf-8"):
|
||||
"""Percent-encodes a URI component without reapplying
|
||||
onto an already percent-encoded component.
|
||||
"""
|
||||
if component is None:
|
||||
return component
|
||||
|
||||
component = six.ensure_text(component)
|
||||
|
||||
# Try to see if the component we're encoding is already percent-encoded
|
||||
# so we can skip all '%' characters but still encode all others.
|
||||
percent_encodings = PERCENT_RE.findall(component)
|
||||
|
||||
# Normalize existing percent-encoded bytes.
|
||||
for enc in percent_encodings:
|
||||
if not enc.isupper():
|
||||
component = component.replace(enc, enc.upper())
|
||||
|
||||
uri_bytes = component.encode("utf-8", "surrogatepass")
|
||||
is_percent_encoded = len(percent_encodings) == uri_bytes.count(b"%")
|
||||
|
||||
encoded_component = bytearray()
|
||||
|
||||
for i in range(0, len(uri_bytes)):
|
||||
# Will return a single character bytestring on both Python 2 & 3
|
||||
byte = uri_bytes[i : i + 1]
|
||||
byte_ord = ord(byte)
|
||||
if (is_percent_encoded and byte == b"%") or (
|
||||
byte_ord < 128 and byte.decode() in allowed_chars
|
||||
):
|
||||
encoded_component.extend(byte)
|
||||
continue
|
||||
encoded_component.extend(b"%" + (hex(byte_ord)[2:].encode().zfill(2).upper()))
|
||||
|
||||
return encoded_component.decode(encoding)
|
||||
|
||||
|
||||
def _remove_path_dot_segments(path):
|
||||
# See http://tools.ietf.org/html/rfc3986#section-5.2.4 for pseudo-code
|
||||
segments = path.split("/") # Turn the path into a list of segments
|
||||
output = [] # Initialize the variable to use to store output
|
||||
|
||||
for segment in segments:
|
||||
# '.' is the current directory, so ignore it, it is superfluous
|
||||
if segment == ".":
|
||||
continue
|
||||
# Anything other than '..', should be appended to the output
|
||||
elif segment != "..":
|
||||
output.append(segment)
|
||||
# In this case segment == '..', if we can, we should pop the last
|
||||
# element
|
||||
elif output:
|
||||
output.pop()
|
||||
|
||||
# If the path starts with '/' and the output is empty or the first string
|
||||
# is non-empty
|
||||
if path.startswith("/") and (not output or output[0]):
|
||||
output.insert(0, "")
|
||||
|
||||
# If the path starts with '/.' or '/..' ensure we add one more empty
|
||||
# string to add a trailing '/'
|
||||
if path.endswith(("/.", "/..")):
|
||||
output.append("")
|
||||
|
||||
return "/".join(output)
|
||||
|
||||
|
||||
def _normalize_host(host, scheme):
|
||||
if host:
|
||||
if isinstance(host, six.binary_type):
|
||||
host = six.ensure_str(host)
|
||||
|
||||
if scheme in NORMALIZABLE_SCHEMES:
|
||||
is_ipv6 = IPV6_ADDRZ_RE.match(host)
|
||||
if is_ipv6:
|
||||
match = ZONE_ID_RE.search(host)
|
||||
if match:
|
||||
start, end = match.span(1)
|
||||
zone_id = host[start:end]
|
||||
|
||||
if zone_id.startswith("%25") and zone_id != "%25":
|
||||
zone_id = zone_id[3:]
|
||||
else:
|
||||
zone_id = zone_id[1:]
|
||||
zone_id = "%" + _encode_invalid_chars(zone_id, UNRESERVED_CHARS)
|
||||
return host[:start].lower() + zone_id + host[end:]
|
||||
else:
|
||||
return host.lower()
|
||||
elif not IPV4_RE.match(host):
|
||||
return six.ensure_str(
|
||||
b".".join([_idna_encode(label) for label in host.split(".")])
|
||||
)
|
||||
return host
|
||||
|
||||
|
||||
def _idna_encode(name):
|
||||
if name and any([ord(x) > 128 for x in name]):
|
||||
try:
|
||||
import idna
|
||||
except ImportError:
|
||||
six.raise_from(
|
||||
LocationParseError("Unable to parse URL without the 'idna' module"),
|
||||
None,
|
||||
)
|
||||
try:
|
||||
return idna.encode(name.lower(), strict=True, std3_rules=True)
|
||||
except idna.IDNAError:
|
||||
six.raise_from(
|
||||
LocationParseError(u"Name '%s' is not a valid IDNA label" % name), None
|
||||
)
|
||||
return name.lower().encode("ascii")
|
||||
|
||||
|
||||
def _encode_target(target):
|
||||
"""Percent-encodes a request target so that there are no invalid characters"""
|
||||
if not target.startswith("/"):
|
||||
return target
|
||||
|
||||
path, query, fragment = TARGET_RE.match(target).groups()
|
||||
target = _encode_invalid_chars(path, PATH_CHARS)
|
||||
query = _encode_invalid_chars(query, QUERY_CHARS)
|
||||
fragment = _encode_invalid_chars(fragment, FRAGMENT_CHARS)
|
||||
if query is not None:
|
||||
target += "?" + query
|
||||
if fragment is not None:
|
||||
target += "#" + target
|
||||
return target
|
||||
|
||||
|
||||
def parse_url(url):
|
||||
"""
|
||||
Given a url, return a parsed :class:`.Url` namedtuple. Best-effort is
|
||||
performed to parse incomplete urls. Fields not provided will be None.
|
||||
This parser is RFC 3986 compliant.
|
||||
|
||||
The parser logic and helper functions are based heavily on
|
||||
work done in the ``rfc3986`` module.
|
||||
|
||||
:param str url: URL to parse into a :class:`.Url` namedtuple.
|
||||
|
||||
Partly backwards-compatible with :mod:`urlparse`.
|
||||
|
||||
@@ -145,81 +358,77 @@ def parse_url(url):
|
||||
>>> parse_url('/foo?bar')
|
||||
Url(scheme=None, host=None, port=None, path='/foo', query='bar', ...)
|
||||
"""
|
||||
|
||||
# While this code has overlap with stdlib's urlparse, it is much
|
||||
# simplified for our needs and less annoying.
|
||||
# Additionally, this implementations does silly things to be optimal
|
||||
# on CPython.
|
||||
|
||||
if not url:
|
||||
# Empty
|
||||
return Url()
|
||||
|
||||
scheme = None
|
||||
auth = None
|
||||
host = None
|
||||
port = None
|
||||
path = None
|
||||
fragment = None
|
||||
query = None
|
||||
source_url = url
|
||||
if not SCHEME_RE.search(url):
|
||||
url = "//" + url
|
||||
|
||||
# Scheme
|
||||
if '://' in url:
|
||||
scheme, url = url.split('://', 1)
|
||||
try:
|
||||
scheme, authority, path, query, fragment = URI_RE.match(url).groups()
|
||||
normalize_uri = scheme is None or scheme.lower() in NORMALIZABLE_SCHEMES
|
||||
|
||||
# Find the earliest Authority Terminator
|
||||
# (http://tools.ietf.org/html/rfc3986#section-3.2)
|
||||
url, path_, delim = split_first(url, ['/', '?', '#'])
|
||||
if scheme:
|
||||
scheme = scheme.lower()
|
||||
|
||||
if delim:
|
||||
# Reassemble the path
|
||||
path = delim + path_
|
||||
|
||||
# Auth
|
||||
if '@' in url:
|
||||
# Last '@' denotes end of auth part
|
||||
auth, url = url.rsplit('@', 1)
|
||||
|
||||
# IPv6
|
||||
if url and url[0] == '[':
|
||||
host, url = url.split(']', 1)
|
||||
host += ']'
|
||||
|
||||
# Port
|
||||
if ':' in url:
|
||||
_host, port = url.split(':', 1)
|
||||
|
||||
if not host:
|
||||
host = _host
|
||||
|
||||
if port:
|
||||
# If given, ports must be integers. No whitespace, no plus or
|
||||
# minus prefixes, no non-integer digits such as ^2 (superscript).
|
||||
if not port.isdigit():
|
||||
raise LocationParseError(url)
|
||||
try:
|
||||
port = int(port)
|
||||
except ValueError:
|
||||
raise LocationParseError(url)
|
||||
if authority:
|
||||
auth, host, port = SUBAUTHORITY_RE.match(authority).groups()
|
||||
if auth and normalize_uri:
|
||||
auth = _encode_invalid_chars(auth, USERINFO_CHARS)
|
||||
if port == "":
|
||||
port = None
|
||||
else:
|
||||
# Blank ports are cool, too. (rfc3986#section-3.2.3)
|
||||
port = None
|
||||
auth, host, port = None, None, None
|
||||
|
||||
elif not host and url:
|
||||
host = url
|
||||
if port is not None:
|
||||
port = int(port)
|
||||
if not (0 <= port <= 65535):
|
||||
raise LocationParseError(url)
|
||||
|
||||
host = _normalize_host(host, scheme)
|
||||
|
||||
if normalize_uri and path:
|
||||
path = _remove_path_dot_segments(path)
|
||||
path = _encode_invalid_chars(path, PATH_CHARS)
|
||||
if normalize_uri and query:
|
||||
query = _encode_invalid_chars(query, QUERY_CHARS)
|
||||
if normalize_uri and fragment:
|
||||
fragment = _encode_invalid_chars(fragment, FRAGMENT_CHARS)
|
||||
|
||||
except (ValueError, AttributeError):
|
||||
return six.raise_from(LocationParseError(source_url), None)
|
||||
|
||||
# For the sake of backwards compatibility we put empty
|
||||
# string values for path if there are any defined values
|
||||
# beyond the path in the URL.
|
||||
# TODO: Remove this when we break backwards compatibility.
|
||||
if not path:
|
||||
return Url(scheme, auth, host, port, path, query, fragment)
|
||||
if query is not None or fragment is not None:
|
||||
path = ""
|
||||
else:
|
||||
path = None
|
||||
|
||||
# Fragment
|
||||
if '#' in path:
|
||||
path, fragment = path.split('#', 1)
|
||||
# Ensure that each part of the URL is a `str` for
|
||||
# backwards compatibility.
|
||||
if isinstance(url, six.text_type):
|
||||
ensure_func = six.ensure_text
|
||||
else:
|
||||
ensure_func = six.ensure_str
|
||||
|
||||
# Query
|
||||
if '?' in path:
|
||||
path, query = path.split('?', 1)
|
||||
def ensure_type(x):
|
||||
return x if x is None else ensure_func(x)
|
||||
|
||||
return Url(scheme, auth, host, port, path, query, fragment)
|
||||
return Url(
|
||||
scheme=ensure_type(scheme),
|
||||
auth=ensure_type(auth),
|
||||
host=ensure_type(host),
|
||||
port=port,
|
||||
path=ensure_type(path),
|
||||
query=ensure_type(query),
|
||||
fragment=ensure_type(fragment),
|
||||
)
|
||||
|
||||
|
||||
def get_host(url):
|
||||
@@ -227,4 +436,4 @@ def get_host(url):
|
||||
Deprecated. Use :func:`parse_url` instead.
|
||||
"""
|
||||
p = parse_url(url)
|
||||
return p.scheme or 'http', p.hostname, p.port
|
||||
return p.scheme or "http", p.hostname, p.port
|
||||
|
||||
+146
-33
@@ -1,40 +1,153 @@
|
||||
from .selectors import (
|
||||
HAS_SELECT,
|
||||
DefaultSelector,
|
||||
EVENT_READ,
|
||||
EVENT_WRITE
|
||||
)
|
||||
import errno
|
||||
from functools import partial
|
||||
import select
|
||||
import sys
|
||||
|
||||
try:
|
||||
from time import monotonic
|
||||
except ImportError:
|
||||
from time import time as monotonic
|
||||
|
||||
__all__ = ["NoWayToWaitForSocketError", "wait_for_read", "wait_for_write"]
|
||||
|
||||
|
||||
def _wait_for_io_events(socks, events, timeout=None):
|
||||
""" Waits for IO events to be available from a list of sockets
|
||||
or optionally a single socket if passed in. Returns a list of
|
||||
sockets that can be interacted with immediately. """
|
||||
if not HAS_SELECT:
|
||||
raise ValueError('Platform does not have a selector')
|
||||
if not isinstance(socks, list):
|
||||
# Probably just a single socket.
|
||||
if hasattr(socks, "fileno"):
|
||||
socks = [socks]
|
||||
# Otherwise it might be a non-list iterable.
|
||||
class NoWayToWaitForSocketError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
# How should we wait on sockets?
|
||||
#
|
||||
# There are two types of APIs you can use for waiting on sockets: the fancy
|
||||
# modern stateful APIs like epoll/kqueue, and the older stateless APIs like
|
||||
# select/poll. The stateful APIs are more efficient when you have a lots of
|
||||
# sockets to keep track of, because you can set them up once and then use them
|
||||
# lots of times. But we only ever want to wait on a single socket at a time
|
||||
# and don't want to keep track of state, so the stateless APIs are actually
|
||||
# more efficient. So we want to use select() or poll().
|
||||
#
|
||||
# Now, how do we choose between select() and poll()? On traditional Unixes,
|
||||
# select() has a strange calling convention that makes it slow, or fail
|
||||
# altogether, for high-numbered file descriptors. The point of poll() is to fix
|
||||
# that, so on Unixes, we prefer poll().
|
||||
#
|
||||
# On Windows, there is no poll() (or at least Python doesn't provide a wrapper
|
||||
# for it), but that's OK, because on Windows, select() doesn't have this
|
||||
# strange calling convention; plain select() works fine.
|
||||
#
|
||||
# So: on Windows we use select(), and everywhere else we use poll(). We also
|
||||
# fall back to select() in case poll() is somehow broken or missing.
|
||||
|
||||
if sys.version_info >= (3, 5):
|
||||
# Modern Python, that retries syscalls by default
|
||||
def _retry_on_intr(fn, timeout):
|
||||
return fn(timeout)
|
||||
|
||||
|
||||
else:
|
||||
# Old and broken Pythons.
|
||||
def _retry_on_intr(fn, timeout):
|
||||
if timeout is None:
|
||||
deadline = float("inf")
|
||||
else:
|
||||
socks = list(socks)
|
||||
with DefaultSelector() as selector:
|
||||
for sock in socks:
|
||||
selector.register(sock, events)
|
||||
return [key[0].fileobj for key in
|
||||
selector.select(timeout) if key[1] & events]
|
||||
deadline = monotonic() + timeout
|
||||
|
||||
while True:
|
||||
try:
|
||||
return fn(timeout)
|
||||
# OSError for 3 <= pyver < 3.5, select.error for pyver <= 2.7
|
||||
except (OSError, select.error) as e:
|
||||
# 'e.args[0]' incantation works for both OSError and select.error
|
||||
if e.args[0] != errno.EINTR:
|
||||
raise
|
||||
else:
|
||||
timeout = deadline - monotonic()
|
||||
if timeout < 0:
|
||||
timeout = 0
|
||||
if timeout == float("inf"):
|
||||
timeout = None
|
||||
continue
|
||||
|
||||
|
||||
def wait_for_read(socks, timeout=None):
|
||||
""" Waits for reading to be available from a list of sockets
|
||||
or optionally a single socket if passed in. Returns a list of
|
||||
sockets that can be read from immediately. """
|
||||
return _wait_for_io_events(socks, EVENT_READ, timeout)
|
||||
def select_wait_for_socket(sock, read=False, write=False, timeout=None):
|
||||
if not read and not write:
|
||||
raise RuntimeError("must specify at least one of read=True, write=True")
|
||||
rcheck = []
|
||||
wcheck = []
|
||||
if read:
|
||||
rcheck.append(sock)
|
||||
if write:
|
||||
wcheck.append(sock)
|
||||
# When doing a non-blocking connect, most systems signal success by
|
||||
# marking the socket writable. Windows, though, signals success by marked
|
||||
# it as "exceptional". We paper over the difference by checking the write
|
||||
# sockets for both conditions. (The stdlib selectors module does the same
|
||||
# thing.)
|
||||
fn = partial(select.select, rcheck, wcheck, wcheck)
|
||||
rready, wready, xready = _retry_on_intr(fn, timeout)
|
||||
return bool(rready or wready or xready)
|
||||
|
||||
|
||||
def wait_for_write(socks, timeout=None):
|
||||
""" Waits for writing to be available from a list of sockets
|
||||
or optionally a single socket if passed in. Returns a list of
|
||||
sockets that can be written to immediately. """
|
||||
return _wait_for_io_events(socks, EVENT_WRITE, timeout)
|
||||
def poll_wait_for_socket(sock, read=False, write=False, timeout=None):
|
||||
if not read and not write:
|
||||
raise RuntimeError("must specify at least one of read=True, write=True")
|
||||
mask = 0
|
||||
if read:
|
||||
mask |= select.POLLIN
|
||||
if write:
|
||||
mask |= select.POLLOUT
|
||||
poll_obj = select.poll()
|
||||
poll_obj.register(sock, mask)
|
||||
|
||||
# For some reason, poll() takes timeout in milliseconds
|
||||
def do_poll(t):
|
||||
if t is not None:
|
||||
t *= 1000
|
||||
return poll_obj.poll(t)
|
||||
|
||||
return bool(_retry_on_intr(do_poll, timeout))
|
||||
|
||||
|
||||
def null_wait_for_socket(*args, **kwargs):
|
||||
raise NoWayToWaitForSocketError("no select-equivalent available")
|
||||
|
||||
|
||||
def _have_working_poll():
|
||||
# Apparently some systems have a select.poll that fails as soon as you try
|
||||
# to use it, either due to strange configuration or broken monkeypatching
|
||||
# from libraries like eventlet/greenlet.
|
||||
try:
|
||||
poll_obj = select.poll()
|
||||
_retry_on_intr(poll_obj.poll, 0)
|
||||
except (AttributeError, OSError):
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
|
||||
def wait_for_socket(*args, **kwargs):
|
||||
# We delay choosing which implementation to use until the first time we're
|
||||
# called. We could do it at import time, but then we might make the wrong
|
||||
# decision if someone goes wild with monkeypatching select.poll after
|
||||
# we're imported.
|
||||
global wait_for_socket
|
||||
if _have_working_poll():
|
||||
wait_for_socket = poll_wait_for_socket
|
||||
elif hasattr(select, "select"):
|
||||
wait_for_socket = select_wait_for_socket
|
||||
else: # Platform-specific: Appengine.
|
||||
wait_for_socket = null_wait_for_socket
|
||||
return wait_for_socket(*args, **kwargs)
|
||||
|
||||
|
||||
def wait_for_read(sock, timeout=None):
|
||||
""" Waits for reading to be available on a given socket.
|
||||
Returns True if the socket is readable, or False if the timeout expired.
|
||||
"""
|
||||
return wait_for_socket(sock, read=True, timeout=timeout)
|
||||
|
||||
|
||||
def wait_for_write(sock, timeout=None):
|
||||
""" Waits for writing to be available on a given socket.
|
||||
Returns True if the socket is readable, or False if the timeout expired.
|
||||
"""
|
||||
return wait_for_socket(sock, write=True, timeout=timeout)
|
||||
|
||||
Reference in New Issue
Block a user