This commit is contained in:
marco
2019-12-04 20:45:07 +01:00
parent c313b27b27
commit 5307db2a3c
117 changed files with 59629 additions and 0 deletions

View File

@@ -0,0 +1,544 @@
import logging
import re
# import OpenSSL
import sys
import ssl
import requests
try:
import copyreg
except ImportError:
import copy_reg as copyreg
from copy import deepcopy
from time import sleep
from collections import OrderedDict
from requests.sessions import Session
from requests.adapters import HTTPAdapter
from .interpreters import JavaScriptInterpreter
from .reCaptcha import reCaptcha
from .user_agent import User_Agent
try:
from requests_toolbelt.utils import dump
except ImportError:
pass
try:
import brotli
except ImportError:
pass
try:
from urlparse import urlparse
except ImportError:
from urllib.parse import urlparse
# ------------------------------------------------------------------------------- #
__version__ = '1.2.15'
# ------------------------------------------------------------------------------- #
class CipherSuiteAdapter(HTTPAdapter):
__attrs__ = [
'ssl_context',
'max_retries',
'config',
'_pool_connections',
'_pool_maxsize',
'_pool_block'
]
def __init__(self, *args, **kwargs):
self.ssl_context = kwargs.pop('ssl_context', None)
self.cipherSuite = kwargs.pop('cipherSuite', None)
if not self.ssl_context:
self.ssl_context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
self.ssl_context.set_ciphers(self.cipherSuite)
self.ssl_context.options |= (ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3 | ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1)
super(CipherSuiteAdapter, self).__init__(**kwargs)
# ------------------------------------------------------------------------------- #
def init_poolmanager(self, *args, **kwargs):
kwargs['ssl_context'] = self.ssl_context
return super(CipherSuiteAdapter, self).init_poolmanager(*args, **kwargs)
# ------------------------------------------------------------------------------- #
def proxy_manager_for(self, *args, **kwargs):
kwargs['ssl_context'] = self.ssl_context
return super(CipherSuiteAdapter, self).proxy_manager_for(*args, **kwargs)
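# ------------------------------------------------------------------------------- #
# Illustrative sketch (not part of the upstream code): the adapter also works on
# a plain requests.Session; the cipher string below is a hypothetical example of
# an OpenSSL cipher list:
#
#   import requests
#   sess = requests.Session()
#   sess.mount('https://', CipherSuiteAdapter(cipherSuite='ECDHE-RSA-AES128-GCM-SHA256'))
#   resp = sess.get('https://example.com')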
# ------------------------------------------------------------------------------- #
class CloudScraper(Session):
def __init__(self, *args, **kwargs):
self.debug = kwargs.pop('debug', False)
self.delay = kwargs.pop('delay', None)
self.cipherSuite = kwargs.pop('cipherSuite', None)
self.interpreter = kwargs.pop('interpreter', 'native')
self.recaptcha = kwargs.pop('recaptcha', {})
self.allow_brotli = kwargs.pop(
'allow_brotli',
'brotli' in sys.modules
)
self.user_agent = User_Agent(
allow_brotli=self.allow_brotli,
browser=kwargs.pop('browser', None)
)
self._solveDepthCnt = 0
self.solveDepth = kwargs.pop('solveDepth', 3)
super(CloudScraper, self).__init__(*args, **kwargs)
# pylint: disable=E0203
if 'requests' in self.headers['User-Agent']:
# ------------------------------------------------------------------------------- #
# Set a random User-Agent if no custom User-Agent has been set
# ------------------------------------------------------------------------------- #
self.headers = self.user_agent.headers
self.mount(
'https://',
CipherSuiteAdapter(
cipherSuite=self.loadCipherSuite() if not self.cipherSuite else self.cipherSuite
)
)
# purely to allow us to pickle dump
copyreg.pickle(ssl.SSLContext, lambda obj: (obj.__class__, (obj.protocol,)))
# ------------------------------------------------------------------------------- #
# Allow us to pickle our session back with all variables
# ------------------------------------------------------------------------------- #
def __getstate__(self):
return self.__dict__
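# illustrative sketch (assumption, not in the original source): together with the
# copyreg.pickle registration in __init__, this lets a scraper round-trip pickle:
#
#   import pickle
#   restored = pickle.loads(pickle.dumps(scraper))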
# ------------------------------------------------------------------------------- #
# debug the request via the response
# ------------------------------------------------------------------------------- #
@staticmethod
def debugRequest(req):
try:
print(dump.dump_all(req).decode('utf-8'))
except ValueError as e:
print("Debug Error: {}".format(getattr(e, 'message', e)))
# ------------------------------------------------------------------------------- #
# Decode Brotli on older versions of urllib3 manually
# ------------------------------------------------------------------------------- #
def decodeBrotli(self, resp):
if requests.packages.urllib3.__version__ < '1.25.1' and resp.headers.get('Content-Encoding') == 'br':
if self.allow_brotli and resp._content:
resp._content = brotli.decompress(resp.content)
else:
logging.warning(
"You're running urllib3 {}; Brotli content was detected and requires "
'manual decompression, but the allow_brotli option is set to False, '
'so the response will not be decompressed.'.format(requests.packages.urllib3.__version__)
)
return resp
# ------------------------------------------------------------------------------- #
# construct a cipher suite of ciphers the system actually supports
# ------------------------------------------------------------------------------- #
def loadCipherSuite(self):
if self.cipherSuite:
return self.cipherSuite
if hasattr(ssl, 'Purpose') and hasattr(ssl.Purpose, 'SERVER_AUTH'):
for cipher in self.user_agent.cipherSuite[:]:
try:
context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
context.set_ciphers(cipher)
except ssl.SSLError:
self.user_agent.cipherSuite.remove(cipher)
if self.user_agent.cipherSuite:
self.cipherSuite = ':'.join(self.user_agent.cipherSuite)
return self.cipherSuite
sys.tracebacklimit = 0
raise RuntimeError("The OpenSSL on this system does not meet the minimum cipher requirements.")
# ------------------------------------------------------------------------------- #
# Our hijacker request function
# ------------------------------------------------------------------------------- #
def request(self, method, url, *args, **kwargs):
# pylint: disable=E0203
if kwargs.get('proxies') and kwargs.get('proxies') != self.proxies:
self.proxies = kwargs.get('proxies')
resp = self.decodeBrotli(
super(CloudScraper, self).request(method, url, *args, **kwargs)
)
# ------------------------------------------------------------------------------- #
# Debug request
# ------------------------------------------------------------------------------- #
if self.debug:
self.debugRequest(resp)
# Check if Cloudflare anti-bot is on
if self.is_Challenge_Request(resp):
# ------------------------------------------------------------------------------- #
# Try to solve the challenge and send it back
# ------------------------------------------------------------------------------- #
if self._solveDepthCnt >= self.solveDepth:
sys.tracebacklimit = 0
solve_attempts = self._solveDepthCnt
self._solveDepthCnt = 0
raise RuntimeError("!!Loop Protection!! We have tried to solve {} time(s) in a row.".format(solve_attempts))
self._solveDepthCnt += 1
resp = self.Challenge_Response(resp, **kwargs)
else:
if resp.status_code not in [302, 429, 503]:
self._solveDepthCnt = 0
return resp
# ------------------------------------------------------------------------------- #
# check if the response contains a valid Cloudflare challenge
# ------------------------------------------------------------------------------- #
@staticmethod
def is_IUAM_Challenge(resp):
try:
return (
resp.headers.get('Server', '').startswith('cloudflare')
and resp.status_code in [429, 503]
and re.search(
r'action="/.*?__cf_chl_jschl_tk__=\S+".*?name="jschl_vc"\svalue=.*?',
resp.text,
re.M | re.DOTALL
)
)
except AttributeError:
pass
return False
# ------------------------------------------------------------------------------- #
# check if the response contains a valid Cloudflare reCaptcha challenge
# ------------------------------------------------------------------------------- #
@staticmethod
def is_reCaptcha_Challenge(resp):
try:
return (
resp.headers.get('Server', '').startswith('cloudflare')
and resp.status_code == 403
and re.search(
r'action="/.*?__cf_chl_captcha_tk__=\S+".*?data\-sitekey=.*?',
resp.text,
re.M | re.DOTALL
)
)
except AttributeError:
pass
return False
# ------------------------------------------------------------------------------- #
# Wrapper for is_reCaptcha_Challenge and is_IUAM_Challenge
# ------------------------------------------------------------------------------- #
def is_Challenge_Request(self, resp):
return self.is_reCaptcha_Challenge(resp) or self.is_IUAM_Challenge(resp)
# ------------------------------------------------------------------------------- #
# Try to solve cloudflare javascript challenge.
# ------------------------------------------------------------------------------- #
@staticmethod
def IUAM_Challenge_Response(body, url, interpreter):
try:
challengeUUID = re.search(
r'id="challenge-form" action="(?P<challengeUUID>\S+)"',
body, re.M | re.DOTALL
).groupdict().get('challengeUUID', '')
payload = OrderedDict(re.findall(r'name="(r|jschl_vc|pass)"\svalue="(.*?)"', body))
except AttributeError:
sys.tracebacklimit = 0
raise RuntimeError(
"Cloudflare IUAM detected, unfortunately we can't extract the parameters correctly."
)
hostParsed = urlparse(url)
try:
payload['jschl_answer'] = JavaScriptInterpreter.dynamicImport(
interpreter
).solveChallenge(body, hostParsed.netloc)
except Exception as e:
raise RuntimeError(
'Unable to parse Cloudflare anti-bots page: {}'.format(
getattr(e, 'message', e)
)
)
return {
'url': '{}://{}{}'.format(
hostParsed.scheme,
hostParsed.netloc,
challengeUUID
),
'data': payload
}
# ------------------------------------------------------------------------------- #
# Try to solve the reCaptcha challenge via 3rd party.
# ------------------------------------------------------------------------------- #
@staticmethod
def reCaptcha_Challenge_Response(provider, provider_params, body, url):
try:
payload = re.search(
r'(name="r"\svalue="(?P<r>\S+)"|).*?challenge-form" action="(?P<challengeUUID>\S+)".*?'
r'data-ray="(?P<data_ray>\S+)".*?data-sitekey="(?P<site_key>\S+)"',
body, re.M | re.DOTALL
).groupdict()
except AttributeError:
sys.tracebacklimit = 0
raise RuntimeError(
"Cloudflare reCaptcha detected, unfortunately we can't extract the parameters correctly."
)
hostParsed = urlparse(url)
return {
'url': '{}://{}{}'.format(
hostParsed.scheme,
hostParsed.netloc,
payload.get('challengeUUID', '')
),
'data': OrderedDict([
('r', payload.get('r', '')),
('id', payload.get('data_ray')),
(
'g-recaptcha-response',
reCaptcha.dynamicImport(
provider.lower()
).solveCaptcha(url, payload.get('site_key'), provider_params)
)
])
}
# ------------------------------------------------------------------------------- #
# Attempt to handle and send the challenge response back to cloudflare
# ------------------------------------------------------------------------------- #
def Challenge_Response(self, resp, **kwargs):
if self.is_reCaptcha_Challenge(resp):
# ------------------------------------------------------------------------------- #
# double down on the request as some websites are only checking
# if the __cfduid cookie is populated before issuing reCaptcha.
# ------------------------------------------------------------------------------- #
resp = self.decodeBrotli(
super(CloudScraper, self).request(resp.request.method, resp.url, **kwargs)
)
if not self.is_reCaptcha_Challenge(resp):
return resp
# ------------------------------------------------------------------------------- #
# if no reCaptcha provider raise a runtime error.
# ------------------------------------------------------------------------------- #
if not self.recaptcha or not isinstance(self.recaptcha, dict) or not self.recaptcha.get('provider'):
sys.tracebacklimit = 0
raise RuntimeError(
"Cloudflare reCaptcha detected, unfortunately you haven't loaded an anti reCaptcha provider "
"correctly via the 'recaptcha' parameter."
)
# ------------------------------------------------------------------------------- #
# if provider is return_response, return the response without doing anything.
# ------------------------------------------------------------------------------- #
if self.recaptcha.get('provider') == 'return_response':
return resp
self.recaptcha['proxies'] = self.proxies
submit_url = self.reCaptcha_Challenge_Response(
self.recaptcha.get('provider'),
self.recaptcha,
resp.text,
resp.url
)
else:
# ------------------------------------------------------------------------------- #
# Cloudflare requires a delay before solving the challenge
# ------------------------------------------------------------------------------- #
if not self.delay:
try:
self.delay = float(
re.search(
r'submit\(\);\r?\n\s*},\s*([0-9]+)',
resp.text
).group(1)
) / float(1000)
except (AttributeError, ValueError):
sys.tracebacklimit = 0
raise RuntimeError("Cloudflare IUAM page is possibly malformed; there was an issue extracting the delay value.")
sleep(self.delay)
# ------------------------------------------------------------------------------- #
submit_url = self.IUAM_Challenge_Response(
resp.text,
resp.url,
self.interpreter
)
# ------------------------------------------------------------------------------- #
# Send the Challenge Response back to Cloudflare
# ------------------------------------------------------------------------------- #
if submit_url:
def updateAttr(obj, name, newValue):
try:
obj[name].update(newValue)
return obj[name]
except (AttributeError, KeyError):
obj[name] = {}
obj[name].update(newValue)
return obj[name]
cloudflare_kwargs = deepcopy(kwargs)
cloudflare_kwargs['allow_redirects'] = False
cloudflare_kwargs['data'] = updateAttr(
cloudflare_kwargs,
'data',
submit_url['data']
)
cloudflare_kwargs['headers'] = updateAttr(
cloudflare_kwargs,
'headers',
{
'Referer': resp.url
}
)
return self.request(
'POST',
submit_url['url'],
**cloudflare_kwargs
)
# ------------------------------------------------------------------------------- #
# We shouldn't be here... re-request the original query and process it again.
# ------------------------------------------------------------------------------- #
return self.request(resp.request.method, resp.url, **kwargs)
# ------------------------------------------------------------------------------- #
@classmethod
def create_scraper(cls, sess=None, **kwargs):
"""
Convenience function for creating a ready-to-go CloudScraper object.
"""
scraper = cls(**kwargs)
if sess:
for attr in ['auth', 'cert', 'cookies', 'headers', 'hooks', 'params', 'proxies', 'data']:
val = getattr(sess, attr, None)
if val:
setattr(scraper, attr, val)
return scraper
# ------------------------------------------------------------------------------- #
# Functions for integrating cloudscraper with other applications and scripts
# ------------------------------------------------------------------------------- #
@classmethod
def get_tokens(cls, url, **kwargs):
scraper = cls.create_scraper(
**{
field: kwargs.pop(field, None) for field in [
'allow_brotli',
'browser',
'debug',
'delay',
'interpreter',
'recaptcha'
] if field in kwargs
}
)
try:
resp = scraper.get(url, **kwargs)
resp.raise_for_status()
except Exception:
logging.error('"{}" returned an error. Could not collect tokens.'.format(url))
raise
domain = urlparse(resp.url).netloc
# noinspection PyUnusedLocal
cookie_domain = None
for d in scraper.cookies.list_domains():
# note: this is a substring check (the parentheses are not a tuple), so a parent
# cookie domain such as '.example.com' also matches 'www.example.com'
if d.startswith('.') and d in ('.{}'.format(domain)):
cookie_domain = d
break
else:
sys.tracebacklimit = 0
raise RuntimeError(
"Unable to find Cloudflare cookies. Does the site actually "
"have Cloudflare IUAM (I'm Under Attack Mode) enabled?"
)
return (
{
'__cfduid': scraper.cookies.get('__cfduid', '', domain=cookie_domain),
'cf_clearance': scraper.cookies.get('cf_clearance', '', domain=cookie_domain)
},
scraper.headers['User-Agent']
)
# ------------------------------------------------------------------------------- #
@classmethod
def get_cookie_string(cls, url, **kwargs):
"""
Convenience function for building a Cookie HTTP header value.
"""
tokens, user_agent = cls.get_tokens(url, **kwargs)
return '; '.join('='.join(pair) for pair in tokens.items()), user_agent
# ------------------------------------------------------------------------------- #
create_scraper = CloudScraper.create_scraper
get_tokens = CloudScraper.get_tokens
get_cookie_string = CloudScraper.get_cookie_string
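# ------------------------------------------------------------------------------- #
# A minimal usage sketch (assuming the package is importable as `cloudscraper`;
# the URL is a placeholder):
#
#   import cloudscraper
#   scraper = cloudscraper.create_scraper(interpreter='native')
#   print(scraper.get('https://example.com').text)
#
#   tokens, user_agent = cloudscraper.get_tokens('https://example.com')
#   cookie_header, user_agent = cloudscraper.get_cookie_string('https://example.com')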

View File

@@ -0,0 +1,54 @@
import sys
import logging
import abc
if sys.version_info >= (3, 4):
ABC = abc.ABC # noqa
else:
ABC = abc.ABCMeta('ABC', (), {})
# ------------------------------------------------------------------------------- #
interpreters = {}
BUG_REPORT = 'Cloudflare may have changed their technique, or there may be a bug in the script.'
# ------------------------------------------------------------------------------- #
class JavaScriptInterpreter(ABC):
# ------------------------------------------------------------------------------- #
@abc.abstractmethod
def __init__(self, name):
interpreters[name] = self
# ------------------------------------------------------------------------------- #
@classmethod
def dynamicImport(cls, name):
if name not in interpreters:
try:
__import__('{}.{}'.format(cls.__module__, name))
if not isinstance(interpreters.get(name), JavaScriptInterpreter):
raise ImportError('The interpreter was not initialized.')
except ImportError:
logging.error('Unable to load {} interpreter'.format(name))
raise
return interpreters[name]
# ------------------------------------------------------------------------------- #
@abc.abstractmethod
def eval(self, jsEnv, js):
pass
# ------------------------------------------------------------------------------- #
def solveChallenge(self, body, domain):
try:
return float(self.eval(body, domain))
except Exception:
logging.error('Error executing Cloudflare IUAM Javascript. {}'.format(BUG_REPORT))
raise
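# ------------------------------------------------------------------------------- #
# A minimal sketch of a concrete interpreter (the name 'myinterpreter' is
# hypothetical). Like the bundled ones, it must live as a submodule of this
# package so dynamicImport can __import__ it; instantiating it at import time
# performs the registration:
#
#   class MyInterpreter(JavaScriptInterpreter):
#       def __init__(self):
#           super(MyInterpreter, self).__init__('myinterpreter')
#
#       def eval(self, jsEnv, js):
#           return 0.0  # evaluate the challenge JavaScript here
#
#   MyInterpreter()  # registers itself in interpreters['myinterpreter']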

View File

@@ -0,0 +1,103 @@
from __future__ import absolute_import
import os
import sys
import ctypes.util
from ctypes import c_void_p, c_size_t, byref, create_string_buffer, CDLL
from . import JavaScriptInterpreter
from .encapsulated import template
# ------------------------------------------------------------------------------- #
class ChallengeInterpreter(JavaScriptInterpreter):
# ------------------------------------------------------------------------------- #
def __init__(self):
super(ChallengeInterpreter, self).__init__('chakracore')
# ------------------------------------------------------------------------------- #
def eval(self, body, domain):
chakraCoreLibrary = None
# check current working directory.
for _libraryFile in ['libChakraCore.so', 'libChakraCore.dylib', 'ChakraCore.dll']:
if os.path.isfile(os.path.join(os.getcwd(), _libraryFile)):
chakraCoreLibrary = os.path.join(os.getcwd(), _libraryFile)
break  # stop at the first library found
if not chakraCoreLibrary:
chakraCoreLibrary = ctypes.util.find_library('ChakraCore')
if not chakraCoreLibrary:
sys.tracebacklimit = 0
raise RuntimeError(
'ChakraCore library not found in current path or any of your system library paths, '
'please download from https://www.github.com/VeNoMouS/cloudscraper/tree/ChakraCore/, '
'or https://github.com/Microsoft/ChakraCore/'
)
try:
chakraCore = CDLL(chakraCoreLibrary)
except OSError:
sys.tracebacklimit = 0
raise RuntimeError('There was an error loading the ChakraCore library {}'.format(chakraCoreLibrary))
if sys.platform != 'win32':
chakraCore.DllMain(0, 1, 0)
chakraCore.DllMain(0, 2, 0)
script = create_string_buffer(template(body, domain).encode('utf-16'))
runtime = c_void_p()
chakraCore.JsCreateRuntime(0, 0, byref(runtime))
context = c_void_p()
chakraCore.JsCreateContext(runtime, byref(context))
chakraCore.JsSetCurrentContext(context)
fname = c_void_p()
chakraCore.JsCreateString(
'iuam-challenge.js',
len('iuam-challenge.js'),
byref(fname)
)
scriptSource = c_void_p()
chakraCore.JsCreateExternalArrayBuffer(
script,
len(script),
0,
0,
byref(scriptSource)
)
jsResult = c_void_p()
chakraCore.JsRun(scriptSource, 0, fname, 0x02, byref(jsResult))
resultJSString = c_void_p()
chakraCore.JsConvertValueToString(jsResult, byref(resultJSString))
stringLength = c_size_t()
chakraCore.JsCopyString(resultJSString, 0, 0, byref(stringLength))
resultSTR = create_string_buffer(stringLength.value + 1)
chakraCore.JsCopyString(
resultJSString,
byref(resultSTR),
stringLength.value + 1,
0
)
chakraCore.JsDisposeRuntime(runtime)
return resultSTR.value
# ------------------------------------------------------------------------------- #
ChallengeInterpreter()

View File

@@ -0,0 +1,58 @@
import logging
import re
# ------------------------------------------------------------------------------- #
def template(body, domain):
BUG_REPORT = 'Cloudflare may have changed their technique, or there may be a bug in the script.'
try:
js = re.search(
r'setTimeout\(function\(\){\s+(var s,t,o,p,b,r,e,a,k,i,n,g,f.+?\r?\n[\s\S]+?a\.value =.+?)\r?\n',
body
).group(1)
except Exception:
raise ValueError('Unable to identify Cloudflare IUAM Javascript on website. {}'.format(BUG_REPORT))
js = re.sub(r'\s{2,}', ' ', js, flags=re.MULTILINE | re.DOTALL).replace('\'; 121\'', '')
js += '\na.value;'
jsEnv = '''
String.prototype.italics=function(str) {{return "<i>" + this + "</i>";}};
var document = {{
createElement: function () {{
return {{ firstChild: {{ href: "https://{domain}/" }} }}
}},
getElementById: function () {{
return {{"innerHTML": "{innerHTML}"}};
}}
}};
'''
try:
innerHTML = re.search(
r'<div(?: [^<>]*)? id="([^<>]*?)">([^<>]*?)</div>',
body,
re.MULTILINE | re.DOTALL
)
innerHTML = innerHTML.group(2) if innerHTML else ''
except Exception:
logging.error('Error extracting Cloudflare IUAM Javascript. {}'.format(BUG_REPORT))
raise
return '{}{}'.format(
re.sub(
r'\s{2,}',
' ',
jsEnv.format(
domain=domain,
innerHTML=innerHTML
),
re.MULTILINE | re.DOTALL
),
js
)
# ------------------------------------------------------------------------------- #

View File

@@ -0,0 +1,44 @@
from __future__ import absolute_import
import js2py
import logging
import base64
from . import JavaScriptInterpreter
from .encapsulated import template
from .jsunfuck import jsunfuck
# ------------------------------------------------------------------------------- #
class ChallengeInterpreter(JavaScriptInterpreter):
# ------------------------------------------------------------------------------- #
def __init__(self):
super(ChallengeInterpreter, self).__init__('js2py')
# ------------------------------------------------------------------------------- #
def eval(self, body, domain):
jsPayload = template(body, domain)
if js2py.eval_js('(+(+!+[]+[+!+[]]+(!![]+[])[!+[]+!+[]+!+[]]+[!+[]+!+[]]+[+[]])+[])[+!+[]]') == '1':
logging.warning('WARNING - Please upgrade your js2py https://github.com/PiotrDabkowski/Js2Py, applying work around for the meantime.')
jsPayload = jsunfuck(jsPayload)
def atob(s):
return base64.b64decode('{}'.format(s)).decode('utf-8')
js2py.disable_pyimport()
context = js2py.EvalJs({'atob': atob})
result = context.eval(jsPayload)
return result
# ------------------------------------------------------------------------------- #
ChallengeInterpreter()

View File

@@ -0,0 +1,97 @@
MAPPING = {
'a': '(false+"")[1]',
'b': '([]["entries"]()+"")[2]',
'c': '([]["fill"]+"")[3]',
'd': '(undefined+"")[2]',
'e': '(true+"")[3]',
'f': '(false+"")[0]',
'g': '(false+[0]+String)[20]',
'h': '(+(101))["to"+String["name"]](21)[1]',
'i': '([false]+undefined)[10]',
'j': '([]["entries"]()+"")[3]',
'k': '(+(20))["to"+String["name"]](21)',
'l': '(false+"")[2]',
'm': '(Number+"")[11]',
'n': '(undefined+"")[1]',
'o': '(true+[]["fill"])[10]',
'p': '(+(211))["to"+String["name"]](31)[1]',
'q': '(+(212))["to"+String["name"]](31)[1]',
'r': '(true+"")[1]',
's': '(false+"")[3]',
't': '(true+"")[0]',
'u': '(undefined+"")[0]',
'v': '(+(31))["to"+String["name"]](32)',
'w': '(+(32))["to"+String["name"]](33)',
'x': '(+(101))["to"+String["name"]](34)[1]',
'y': '(NaN+[Infinity])[10]',
'z': '(+(35))["to"+String["name"]](36)',
'A': '(+[]+Array)[10]',
'B': '(+[]+Boolean)[10]',
'C': 'Function("return escape")()(("")["italics"]())[2]',
'D': 'Function("return escape")()([]["fill"])["slice"]("-1")',
'E': '(RegExp+"")[12]',
'F': '(+[]+Function)[10]',
'G': '(false+Function("return Date")()())[30]',
'I': '(Infinity+"")[0]',
'M': '(true+Function("return Date")()())[30]',
'N': '(NaN+"")[0]',
'O': '(NaN+Function("return{}")())[11]',
'R': '(+[]+RegExp)[10]',
'S': '(+[]+String)[10]',
'T': '(NaN+Function("return Date")()())[30]',
'U': '(NaN+Function("return{}")()["to"+String["name"]]["call"]())[11]',
' ': '(NaN+[]["fill"])[11]',
'"': '("")["fontcolor"]()[12]',
'%': 'Function("return escape")()([]["fill"])[21]',
'&': '("")["link"](0+")[10]',
'(': '(undefined+[]["fill"])[22]',
')': '([0]+false+[]["fill"])[20]',
'+': '(+(+!+[]+(!+[]+[])[!+[]+!+[]+!+[]]+[+!+[]]+[+[]]+[+[]])+[])[2]',
',': '([]["slice"]["call"](false+"")+"")[1]',
'-': '(+(.+[0000000001])+"")[2]',
'.': '(+(+!+[]+[+!+[]]+(!![]+[])[!+[]+!+[]+!+[]]+[!+[]+!+[]]+[+[]])+[])[+!+[]]',
'/': '(false+[0])["italics"]()[10]',
':': '(RegExp()+"")[3]',
';': '("")["link"](")[14]',
'<': '("")["italics"]()[0]',
'=': '("")["fontcolor"]()[11]',
'>': '("")["italics"]()[2]',
'?': '(RegExp()+"")[2]',
'[': '([]["entries"]()+"")[0]',
']': '([]["entries"]()+"")[22]',
'{': '(true+[]["fill"])[20]',
'}': '([]["fill"]+"")["slice"]("-1")'
}
SIMPLE = {
'false': '![]',
'true': '!![]',
'undefined': '[][[]]',
'NaN': '+[![]]',
'Infinity': '+(+!+[]+(!+[]+[])[!+[]+!+[]+!+[]]+[+!+[]]+[+[]]+[+[]]+[+[]])' # +"1e1000"
}
CONSTRUCTORS = {
'Array': '[]',
'Number': '(+[])',
'String': '([]+[])',
'Boolean': '(![])',
'Function': '[]["fill"]',
'RegExp': 'Function("return/"+false+"/")()'
}
def jsunfuck(jsfuckString):
for key in sorted(MAPPING, key=lambda k: len(MAPPING[k]), reverse=True):
if MAPPING.get(key) in jsfuckString:
jsfuckString = jsfuckString.replace(MAPPING.get(key), '"{}"'.format(key))
for key in sorted(SIMPLE, key=lambda k: len(SIMPLE[k]), reverse=True):
if SIMPLE.get(key) in jsfuckString:
jsfuckString = jsfuckString.replace(SIMPLE.get(key), '{}'.format(key))
# for key in sorted(CONSTRUCTORS, key=lambda k: len(CONSTRUCTORS[k]), reverse=True):
# if CONSTRUCTORS.get(key) in jsfuckString:
# jsfuckString = jsfuckString.replace(CONSTRUCTORS.get(key), '{}'.format(key))
return jsfuckString
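# ------------------------------------------------------------------------------- #
# Tiny self-check of the substitution tables above (illustrative; runs only when
# this file is executed directly):
if __name__ == '__main__':
    assert jsunfuck('(false+"")[1]') == '"a"'  # the MAPPING entry for the letter 'a'
    assert jsunfuck('![]') == 'false'          # the SIMPLE entry for the literal false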

View File

@@ -0,0 +1,120 @@
from __future__ import absolute_import
import re
import operator as op
from . import JavaScriptInterpreter
# ------------------------------------------------------------------------------- #
class ChallengeInterpreter(JavaScriptInterpreter):
def __init__(self):
super(ChallengeInterpreter, self).__init__('native')
def eval(self, body, domain):
# ------------------------------------------------------------------------------- #
operators = {
'+': op.add,
'-': op.sub,
'*': op.mul,
'/': op.truediv
}
# ------------------------------------------------------------------------------- #
def jsfuckToNumber(jsFuck):
t = ''
split_numbers = re.compile(r'-?\d+').findall
for i in re.findall(
r'\((?:\d|\+|\-)*\)',
jsFuck.replace('!+[]', '1').replace('!![]', '1').replace('[]', '0').lstrip('+').replace('(+', '(')
):
t = '{}{}'.format(t, sum(int(x) for x in split_numbers(i)))
return int(t)
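# worked example (illustrative): '(+!+[]+!+[]+!+[])' becomes '(1+1+1)' after the
# textual substitutions above, the regex extracts '(1+1+1)', and the digits sum
# to 3, so jsfuckToNumber('(+!+[]+!+[]+!+[])') == 3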
# ------------------------------------------------------------------------------- #
def divisorMath(payload, needle, domain):
jsfuckMath = payload.split('/')
if needle in jsfuckMath[1]:
expression = re.findall(r"^(.*?)(.)\(function", jsfuckMath[1])[0]
expression_value = operators[expression[1]](
float(jsfuckToNumber(expression[0])),
float(ord(domain[jsfuckToNumber(jsfuckMath[1][
jsfuckMath[1].find('"("+p+")")}') + len('"("+p+")")}'):-2
])]))
)
else:
expression_value = jsfuckToNumber(jsfuckMath[1])
expression_value = jsfuckToNumber(jsfuckMath[0]) / float(expression_value)
return expression_value
# ------------------------------------------------------------------------------- #
def challengeSolve(body, domain):
jschl_answer = 0
jsfuckChallenge = re.search(
r"setTimeout\(function\(\){\s+var.*?f,\s*(?P<variable>\w+).*?:(?P<init>\S+)};"
r".*?\('challenge-form'\);\s+;(?P<challenge>.*?a\.value)"
r"(?:.*id=\"cf-dn-.*?>(?P<k>\S+)<)?",
body,
re.DOTALL | re.MULTILINE
).groupdict()
jsfuckChallenge['challenge'] = re.finditer(
r'{}.*?([+\-*/])=(.*?);(?=a\.value|{})'.format(
jsfuckChallenge['variable'],
jsfuckChallenge['variable']
),
jsfuckChallenge['challenge']
)
# ------------------------------------------------------------------------------- #
if '/' in jsfuckChallenge['init']:
val = jsfuckChallenge['init'].split('/')
jschl_answer = jsfuckToNumber(val[0]) / float(jsfuckToNumber(val[1]))
else:
jschl_answer = jsfuckToNumber(jsfuckChallenge['init'])
# ------------------------------------------------------------------------------- #
for expressionMatch in jsfuckChallenge['challenge']:
oper, expression = expressionMatch.groups()
if '/' in expression:
expression_value = divisorMath(expression, 'function(p)', domain)
else:
if 'Element' in expression:
expression_value = divisorMath(jsfuckChallenge['k'], '"("+p+")")}', domain)
else:
expression_value = jsfuckToNumber(expression)
jschl_answer = operators[oper](jschl_answer, expression_value)
# ------------------------------------------------------------------------------- #
if not jsfuckChallenge['k'] and '+ t.length' in body:
jschl_answer += len(domain)
# ------------------------------------------------------------------------------- #
return '{0:.10f}'.format(jschl_answer)
# ------------------------------------------------------------------------------- #
return challengeSolve(body, domain)
# ------------------------------------------------------------------------------- #
ChallengeInterpreter()

View File

@@ -0,0 +1,47 @@
import base64
import subprocess
import sys
from . import JavaScriptInterpreter
from .encapsulated import template
# ------------------------------------------------------------------------------- #
class ChallengeInterpreter(JavaScriptInterpreter):
# ------------------------------------------------------------------------------- #
def __init__(self):
super(ChallengeInterpreter, self).__init__('nodejs')
# ------------------------------------------------------------------------------- #
def eval(self, body, domain):
try:
js = 'var atob = function(str) {return Buffer.from(str, "base64").toString("binary");};' \
'var challenge = atob("%s");' \
'var context = {atob: atob};' \
'var options = {filename: "iuam-challenge.js", timeout: 4000};' \
'var answer = require("vm").runInNewContext(challenge, context, options);' \
'process.stdout.write(String(answer));' \
% base64.b64encode(template(body, domain).encode('UTF-8')).decode('ascii')
return subprocess.check_output(['node', '-e', js])
except OSError as e:
if e.errno == 2:
raise EnvironmentError(
'Missing Node.js runtime. Node is required and must be in the PATH (check with `node -v`). '
'Your Node binary may be called `nodejs` rather than `node`, in which case you may need to '
'run `apt-get install nodejs-legacy` on some Debian-based systems. '
"(Please read the cloudscraper README's Dependencies section: "
'https://github.com/VeNoMouS/cloudscraper#dependencies.)'
)
raise
except Exception:
sys.tracebacklimit = 0
raise RuntimeError('Error executing Cloudflare IUAM Javascript in nodejs')
# ------------------------------------------------------------------------------- #
ChallengeInterpreter()

View File

@@ -0,0 +1,33 @@
from __future__ import absolute_import
import sys
try:
import v8eval
except ImportError:
sys.tracebacklimit = 0
raise RuntimeError('Please install the python module v8eval either via pip or download it from https://github.com/sony/v8eval')
from . import JavaScriptInterpreter
from .encapsulated import template
# ------------------------------------------------------------------------------- #
class ChallengeInterpreter(JavaScriptInterpreter):
def __init__(self):
super(ChallengeInterpreter, self).__init__('v8')
# ------------------------------------------------------------------------------- #
def eval(self, body, domain):
try:
return v8eval.V8().eval(template(body, domain))
except (TypeError, v8eval.V8Error):
raise RuntimeError('We encountered an error running the V8 Engine.')
# ------------------------------------------------------------------------------- #
ChallengeInterpreter()

View File

@@ -0,0 +1,206 @@
from __future__ import absolute_import
import requests
try:
import polling
except ImportError:
import sys
sys.tracebacklimit = 0
raise RuntimeError("Please install the python module 'polling' via pip or download it from https://github.com/justiniso/polling/")
from . import reCaptcha
class captchaSolver(reCaptcha):
def __init__(self):
super(captchaSolver, self).__init__('2captcha')
self.host = 'https://2captcha.com'
self.session = requests.Session()
# ------------------------------------------------------------------------------- #
@staticmethod
def checkErrorStatus(response, request_type):
if response.status_code in [500, 502]:
raise RuntimeError('2Captcha: Server Side Error {}'.format(response.status_code))
errors = {
'in.php': {
"ERROR_WRONG_USER_KEY": "The api_key parameter value you've provided is in an incorrect format; it should contain 32 symbols.",
"ERROR_KEY_DOES_NOT_EXIST": "The api_key you've provided does not exist.",
"ERROR_ZERO_BALANCE": "You don't have sufficient funds on your account.",
"ERROR_PAGEURL": "pageurl parameter is missing in your request.",
"ERROR_NO_SLOT_AVAILABLE":
"No Slots Available.\nYou can receive this error in two cases:\n"
"1. If you solve ReCaptcha: the queue of your captchas that are not distributed to workers is too long. "
"The queue limit changes dynamically and depends on the total amount of captchas awaiting solution; usually it's between 50 and 100 captchas.\n"
"2. If you solve Normal Captcha: your maximum rate for normal captchas is lower than the current rate on the server. "
"You can change your maximum rate in your account's settings.",
"ERROR_IP_NOT_ALLOWED": "The request is sent from the IP that is not on the list of your allowed IPs.",
"IP_BANNED": "Your IP address is banned due to many frequent attempts to access the server using wrong authorization keys.",
"ERROR_BAD_TOKEN_OR_PAGEURL":
"You can get this error code when sending ReCaptcha V2. "
"That happens if your request contains an invalid pair of googlekey and pageurl. "
"The common reason for that is that ReCaptcha is loaded inside an iframe hosted on another domain/subdomain.",
"ERROR_GOOGLEKEY":
"You can get this error code when sending ReCaptcha V2. "
"That means the sitekey value provided in your request is incorrect: it's blank or malformed.",
"MAX_USER_TURN": "You made more than 60 requests within 3 seconds. Your account is banned for 10 seconds. Ban will be lifted automatically."
},
'res.php': {
"ERROR_CAPTCHA_UNSOLVABLE":
"We are unable to solve your captcha - three of our workers were unable to solve it "
"or we didn't get an answer within 90 seconds (300 seconds for ReCaptcha V2). "
"We will not charge you for that request.",
"ERROR_WRONG_USER_KEY": "The api_key parameter value you've provided is in an incorrect format; it should contain 32 symbols.",
"ERROR_KEY_DOES_NOT_EXIST": "The api_key you've provided does not exist.",
"ERROR_WRONG_ID_FORMAT": "You've provided the captcha ID in a wrong format. The ID can contain numbers only.",
"ERROR_WRONG_CAPTCHA_ID": "You've provided an incorrect captcha ID.",
"ERROR_BAD_DUPLICATES":
"Error is returned when the 100% accuracy feature is enabled. "
"The error means that the max number of tries was reached but the min number of matches was not found.",
"REPORT_NOT_RECORDED": "Error is returned to your complaint request if you have already complained about lots of correctly solved captchas.",
"ERROR_IP_ADDRES":
"You can receive this error code when registering a pingback (callback) IP or domain. "
"That happens if your request is coming from an IP address that doesn't match the IP address of your pingback IP or domain.",
"ERROR_TOKEN_EXPIRED": "You can receive this error code when sending GeeTest. That error means the challenge value you provided is expired.",
"ERROR_EMPTY_ACTION": "Action parameter is missing or no value is provided for action parameter."
}
}
if response.json().get('status') is False and response.json().get('request') in errors.get(request_type):
raise RuntimeError('{} {}'.format(response.json().get('request'), errors.get(request_type).get(response.json().get('request'))))
# ------------------------------------------------------------------------------- #
def reportJob(self, jobID):
if not jobID:
raise RuntimeError("2Captcha: Error - invalid job id for reporting a bad reCaptcha solve.")
def _checkRequest(response):
if response.status_code in [200, 303] and response.json().get('status') == 1:
return response
self.checkErrorStatus(response, 'res.php')
return None
response = polling.poll(
lambda: self.session.get(
'{}/res.php'.format(self.host),
params={
'key': self.api_key,
'action': 'reportbad',
'id': jobID,
'json': '1'
}
),
check_success=_checkRequest,
step=5,
timeout=180
)
if response:
return True
else:
raise RuntimeError("2Captcha: Error - Failed to report bad reCaptcha solve.")
# ------------------------------------------------------------------------------- #
def requestJob(self, jobID):
if not jobID:
raise RuntimeError("2Captcha: Error - invalid job id for requesting the reCaptcha result.")
def _checkRequest(response):
if response.status_code in [200, 303] and response.json().get('status') == 1:
return response
self.checkErrorStatus(response, 'res.php')
return None
response = polling.poll(
lambda: self.session.get(
'{}/res.php'.format(self.host),
params={
'key': self.api_key,
'action': 'get',
'id': jobID,
'json': '1'
}
),
check_success=_checkRequest,
step=5,
timeout=180
)
if response:
return response.json().get('request')
else:
raise RuntimeError("2Captcha: Error failed to solve reCaptcha.")
# ------------------------------------------------------------------------------- #
def requestSolve(self, site_url, site_key):
def _checkRequest(response):
if response.status_code in [200, 303] and response.json().get("status") == 1 and response.json().get('request'):
return response
self.checkErrorStatus(response, 'in.php')
return None
response = polling.poll(
lambda: self.session.post(
'{}/in.php'.format(self.host),
data={
'key': self.api_key,
'method': 'userrecaptcha',
'googlekey': site_key,
'pageurl': site_url,
'json': '1',
'soft_id': '5507698'
},
allow_redirects=False
),
check_success=_checkRequest,
step=5,
timeout=180
)
if response:
return response.json().get('request')
else:
raise RuntimeError('2Captcha: Error no job id was returned.')
# ------------------------------------------------------------------------------- #
def getCaptchaAnswer(self, site_url, site_key, reCaptchaParams):
jobID = None
if not reCaptchaParams.get('api_key'):
raise ValueError("2Captcha: Missing api_key parameter.")
self.api_key = reCaptchaParams.get('api_key')
if reCaptchaParams.get('proxies'):
self.session.proxies = reCaptchaParams.get('proxies')
try:
jobID = self.requestSolve(site_url, site_key)
return self.requestJob(jobID)
except polling.TimeoutException:
try:
if jobID:
self.reportJob(jobID)
except polling.TimeoutException:
raise RuntimeError("2Captcha: reCaptcha solve took too long and reporting the job also failed.")
raise RuntimeError("2Captcha: reCaptcha solve took too long to execute, aborting.")
# ------------------------------------------------------------------------------- #
captchaSolver()
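# ------------------------------------------------------------------------------- #
# Wiring sketch (the api_key value is a placeholder): this provider is selected
# through the scraper's 'recaptcha' parameter, e.g.
#
#   import cloudscraper
#   scraper = cloudscraper.create_scraper(
#       recaptcha={'provider': '2captcha', 'api_key': 'your 2captcha api key'}
#   )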

View File

@@ -0,0 +1,46 @@
import sys
import logging
import abc
if sys.version_info >= (3, 4):
ABC = abc.ABC # noqa
else:
ABC = abc.ABCMeta('ABC', (), {})
# ------------------------------------------------------------------------------- #
captchaSolvers = {}
# ------------------------------------------------------------------------------- #
class reCaptcha(ABC):
@abc.abstractmethod
def __init__(self, name):
captchaSolvers[name] = self
# ------------------------------------------------------------------------------- #
@classmethod
def dynamicImport(cls, name):
if name not in captchaSolvers:
try:
__import__('{}.{}'.format(cls.__module__, name))
if not isinstance(captchaSolvers.get(name), reCaptcha):
raise ImportError('The anti reCaptcha provider was not initialized.')
except ImportError:
logging.error("Unable to load {} anti reCaptcha provider".format(name))
raise
return captchaSolvers[name]
# ------------------------------------------------------------------------------- #
@abc.abstractmethod
def getCaptchaAnswer(self, site_url, site_key, reCaptchaParams):
pass
# ------------------------------------------------------------------------------- #
def solveCaptcha(self, site_url, site_key, reCaptchaParams):
return self.getCaptchaAnswer(site_url, site_key, reCaptchaParams)
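# ------------------------------------------------------------------------------- #
# A minimal sketch of a concrete provider (the name 'myprovider' is hypothetical).
# Like the bundled providers it must live as a submodule of this package so
# dynamicImport can find it; instantiating it at import time registers it:
#
#   class captchaSolver(reCaptcha):
#       def __init__(self):
#           super(captchaSolver, self).__init__('myprovider')
#
#       def getCaptchaAnswer(self, site_url, site_key, reCaptchaParams):
#           return 'the g-recaptcha-response token goes here'
#
#   captchaSolver()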

View File

@@ -0,0 +1,38 @@
from __future__ import absolute_import
import sys
try:
from python_anticaptcha import AnticaptchaClient, NoCaptchaTaskProxylessTask
except ImportError:
sys.tracebacklimit = 0
raise RuntimeError("Please install the python module 'python_anticaptcha' via pip or download it from https://github.com/ad-m/python-anticaptcha")
from . import reCaptcha
class captchaSolver(reCaptcha):
def __init__(self):
super(captchaSolver, self).__init__('anticaptcha')
def getCaptchaAnswer(self, site_url, site_key, reCaptchaParams):
if not reCaptchaParams.get('api_key'):
raise ValueError("reCaptcha provider 'anticaptcha' was not provided an 'api_key' parameter.")
client = AnticaptchaClient(reCaptchaParams.get('api_key'))
if reCaptchaParams.get('proxies'):
client.session.proxies = reCaptchaParams.get('proxies')
task = NoCaptchaTaskProxylessTask(site_url, site_key)
if not hasattr(client, 'createTaskSmee'):
sys.tracebacklimit = 0
raise RuntimeError("Please upgrade 'python_anticaptcha' via pip or download it from https://github.com/ad-m/python-anticaptcha")
job = client.createTaskSmee(task)
return job.get_solution_response()
captchaSolver()

View File

@@ -0,0 +1,201 @@
from __future__ import absolute_import
import json
import requests
try:
import polling
except ImportError:
import sys
sys.tracebacklimit = 0
raise RuntimeError("Please install the python module 'polling' via pip or download it from https://github.com/justiniso/polling/")
from . import reCaptcha
class captchaSolver(reCaptcha):
def __init__(self):
super(captchaSolver, self).__init__('deathbycaptcha')
self.host = 'http://api.dbcapi.me/api'
self.session = requests.Session()
# ------------------------------------------------------------------------------- #
@staticmethod
def checkErrorStatus(response):
errors = dict(
[
(400, "DeathByCaptcha: 400 Bad Request"),
(403, "DeathByCaptcha: 403 Forbidden - Invalid credentials or insufficient credits."),
# (500, "DeathByCaptcha: 500 Internal Server Error."),
(503, "DeathByCaptcha: 503 Service Temporarily Unavailable.")
]
)
if response.status_code in errors:
raise RuntimeError(errors.get(response.status_code))
# ------------------------------------------------------------------------------- #
def login(self, username, password):
self.username = username
self.password = password
def _checkRequest(response):
if response.status_code == 200:
if response.json().get('is_banned'):
raise RuntimeError('DeathByCaptcha: Your account is banned.')
if response.json().get('balance') == 0:
raise RuntimeError('DeathByCaptcha: insufficient credits.')
return response
self.checkErrorStatus(response)
return None
response = polling.poll(
lambda: self.session.post(
'{}/user'.format(self.host),
headers={'Accept': 'application/json'},
data={
'username': self.username,
'password': self.password
}
),
check_success=_checkRequest,
step=10,
timeout=120
)
self.debugRequest(response)
# ------------------------------------------------------------------------------- #
def reportJob(self, jobID):
if not jobID:
raise RuntimeError("DeathByCaptcha: Error - invalid job id for reporting a failed reCaptcha solve.")
def _checkRequest(response):
if response.status_code == 200:
return response
self.checkErrorStatus(response)
return None
response = polling.poll(
lambda: self.session.post(
'{}/captcha/{}/report'.format(self.host, jobID),
headers={'Accept': 'application/json'},
data={
'username': self.username,
'password': self.password
}
),
check_success=_checkRequest,
step=10,
timeout=180
)
if response:
return True
else:
raise RuntimeError("DeathByCaptcha: Error - failed to report the bad reCaptcha solve.")
# ------------------------------------------------------------------------------- #
def requestJob(self, jobID):
if not jobID:
raise RuntimeError("DeathByCaptcha: Error - invalid job id for requesting the reCaptcha result.")
def _checkRequest(response):
if response.status_code in [200, 303] and response.json().get('text'):
return response
self.checkErrorStatus(response)
return None
response = polling.poll(
lambda: self.session.get(
'{}/captcha/{}'.format(self.host, jobID),
headers={'Accept': 'application/json'}
),
check_success=_checkRequest,
step=10,
timeout=180
)
if response:
return response.json().get('text')
else:
raise RuntimeError("DeathByCaptcha: Error failed to solve reCaptcha.")
# ------------------------------------------------------------------------------- #
def requestSolve(self, site_url, site_key):
def _checkRequest(response):
if response.status_code in [200, 303] and response.json().get("is_correct") and response.json().get('captcha'):
return response
self.checkErrorStatus(response)
return None
response = polling.poll(
lambda: self.session.post(
'{}/captcha'.format(self.host),
headers={'Accept': 'application/json'},
data={
'username': self.username,
'password': self.password,
'type': '4',
'token_params': json.dumps({
'googlekey': site_key,
'pageurl': site_url
})
},
allow_redirects=False
),
check_success=_checkRequest,
step=10,
timeout=180
)
if response:
return response.json().get('captcha')
else:
raise RuntimeError('DeathByCaptcha: Error no job id was returned.')
# ------------------------------------------------------------------------------- #
def getCaptchaAnswer(self, site_url, site_key, reCaptchaParams):
jobID = None
for param in ['username', 'password']:
if not reCaptchaParams.get(param):
raise ValueError("DeathByCaptcha: Missing '{}' parameter.".format(param))
setattr(self, param, reCaptchaParams.get(param))
if reCaptchaParams.get('proxies'):
self.session.proxies = reCaptchaParams.get('proxies')
try:
jobID = self.requestSolve(site_url, site_key)
return self.requestJob(jobID)
except polling.TimeoutException:
try:
if jobID:
self.reportJob(jobID)
except polling.TimeoutException:
raise RuntimeError("DeathByCaptcha: reCaptcha solve took too long and reporting the job also failed.")
raise RuntimeError("DeathByCaptcha: reCaptcha solve took too long to execute, aborting.")
# ------------------------------------------------------------------------------- #
captchaSolver()
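# ------------------------------------------------------------------------------- #
# Wiring sketch (credentials are placeholders): unlike 2captcha, this provider
# authenticates with a username/password pair, e.g.
#
#   import cloudscraper
#   scraper = cloudscraper.create_scraper(
#       recaptcha={
#           'provider': 'deathbycaptcha',
#           'username': 'your DBC username',
#           'password': 'your DBC password'
#       }
#   )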

View File

@@ -0,0 +1,111 @@
import json
import os
import random
import sys
import ssl
import re
from collections import OrderedDict
# ------------------------------------------------------------------------------- #
class User_Agent():
# ------------------------------------------------------------------------------- #
def __init__(self, *args, **kwargs):
self.headers = None
self.cipherSuite = []
self.loadUserAgent(*args, **kwargs)
# ------------------------------------------------------------------------------- #
def loadHeaders(self, user_agents, user_agent_version):
if user_agents.get(self.browser).get('releases').get(user_agent_version).get('headers'):
self.headers = user_agents.get(self.browser).get('releases').get(user_agent_version).get('headers')
else:
self.headers = user_agents.get(self.browser).get('default_headers')
# ------------------------------------------------------------------------------- #
def filterAgents(self, releases):
filtered = {}
for release in releases:
if self.mobile and releases[release]['User-Agent']['mobile']:
filtered[release] = filtered.get(release, []) + releases[release]['User-Agent']['mobile']
if self.desktop and releases[release]['User-Agent']['desktop']:
filtered[release] = filtered.get(release, []) + releases[release]['User-Agent']['desktop']
return filtered
# ------------------------------------------------------------------------------- #
def tryMatchCustom(self, user_agents):
for browser in user_agents:
for release in user_agents[browser]['releases']:
for platform in ['mobile', 'desktop']:
if re.search(self.custom, ' '.join(user_agents[browser]['releases'][release]['User-Agent'][platform])):
self.browser = browser
self.loadHeaders(user_agents, release)
self.headers['User-Agent'] = self.custom
self.cipherSuite = user_agents[self.browser].get('cipherSuite', [])
return True
return None
# ------------------------------------------------------------------------------- #
def loadUserAgent(self, *args, **kwargs):
self.browser = kwargs.pop('browser', None)
if isinstance(self.browser, dict):
self.custom = self.browser.get('custom', None)
self.desktop = self.browser.get('desktop', True)
self.mobile = self.browser.get('mobile', True)
self.browser = self.browser.get('browser', None)
else:
self.custom = kwargs.pop('custom', None)
self.desktop = kwargs.pop('desktop', True)
self.mobile = kwargs.pop('mobile', True)
if not self.desktop and not self.mobile:
sys.tracebacklimit = 0
raise RuntimeError("Sorry, you can't have both mobile and desktop disabled at the same time.")
with open(os.path.join(os.path.dirname(__file__), 'browsers.json'), 'r') as fp:
user_agents = json.load(fp, object_pairs_hook=OrderedDict)
if self.custom:
if not self.tryMatchCustom(user_agents):
self.cipherSuite = '{}:!ECDHE+SHA:!AES128-SHA'.format(ssl._DEFAULT_CIPHERS).split(':')
self.headers = OrderedDict([
('User-Agent', self.custom),
('Accept', 'text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,image/apng,*/*;q=0.8'),
('Accept-Language', 'en-US,en;q=0.9'),
('Accept-Encoding', 'gzip, deflate, br')
])
else:
if self.browser and not user_agents.get(self.browser):
sys.tracebacklimit = 0
raise RuntimeError('Sorry, no User-Agent entry was found for the "{}" browser.'.format(self.browser))
if not self.browser:
self.browser = random.SystemRandom().choice(list(user_agents))
self.cipherSuite = user_agents.get(self.browser).get('cipherSuite', [])
filteredAgents = self.filterAgents(user_agents.get(self.browser).get('releases'))
user_agent_version = random.SystemRandom().choice(list(filteredAgents))
self.loadHeaders(user_agents, user_agent_version)
self.headers['User-Agent'] = random.SystemRandom().choice(filteredAgents[user_agent_version])
if not kwargs.get('allow_brotli', False):
if 'br' in self.headers['Accept-Encoding']:
self.headers['Accept-Encoding'] = ','.join([encoding for encoding in self.headers['Accept-Encoding'].split(',') if encoding.strip() != 'br']).strip()
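# ------------------------------------------------------------------------------- #
# Quick local demo (no network; assumes browsers.json ships a 'chrome' entry;
# runs only when this file is executed directly):
if __name__ == '__main__':
    ua = User_Agent(browser={'browser': 'chrome', 'mobile': False})
    print(ua.headers['User-Agent'])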

File diff suppressed because it is too large

lib/tornado/__init__.py Executable file
View File

@@ -0,0 +1,28 @@
#
# Copyright 2009 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""The Tornado web server and tools."""
from __future__ import absolute_import, division, print_function
# version is a human-readable version number.
# version_info is a four-tuple for programmatic comparison. The first
# three numbers are the components of the version number. The fourth
# is zero for an official release, positive for a development branch,
# or negative for a release candidate or beta (after the base version
# number has been incremented)
version = "5.1.1"
version_info = (5, 1, 1, 0)
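# e.g. feature-gating in application code (illustrative):
#   if tornado.version_info >= (5, 0):
#       ...  # rely on APIs introduced in Tornado 5.x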

lib/tornado/_locale_data.py Executable file
View File

@@ -0,0 +1,84 @@
# -*- coding: utf-8 -*-
#
# Copyright 2012 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""Data used by the tornado.locale module."""
from __future__ import absolute_import, division, print_function
LOCALE_NAMES = {
"af_ZA": {"name_en": u"Afrikaans", "name": u"Afrikaans"},
"am_ET": {"name_en": u"Amharic", "name": u"አማርኛ"},
"ar_AR": {"name_en": u"Arabic", "name": u"العربية"},
"bg_BG": {"name_en": u"Bulgarian", "name": u"Български"},
"bn_IN": {"name_en": u"Bengali", "name": u"বাংলা"},
"bs_BA": {"name_en": u"Bosnian", "name": u"Bosanski"},
"ca_ES": {"name_en": u"Catalan", "name": u"Català"},
"cs_CZ": {"name_en": u"Czech", "name": u"Čeština"},
"cy_GB": {"name_en": u"Welsh", "name": u"Cymraeg"},
"da_DK": {"name_en": u"Danish", "name": u"Dansk"},
"de_DE": {"name_en": u"German", "name": u"Deutsch"},
"el_GR": {"name_en": u"Greek", "name": u"Ελληνικά"},
"en_GB": {"name_en": u"English (UK)", "name": u"English (UK)"},
"en_US": {"name_en": u"English (US)", "name": u"English (US)"},
"es_ES": {"name_en": u"Spanish (Spain)", "name": u"Español (España)"},
"es_LA": {"name_en": u"Spanish", "name": u"Español"},
"et_EE": {"name_en": u"Estonian", "name": u"Eesti"},
"eu_ES": {"name_en": u"Basque", "name": u"Euskara"},
"fa_IR": {"name_en": u"Persian", "name": u"فارسی"},
"fi_FI": {"name_en": u"Finnish", "name": u"Suomi"},
"fr_CA": {"name_en": u"French (Canada)", "name": u"Français (Canada)"},
"fr_FR": {"name_en": u"French", "name": u"Français"},
"ga_IE": {"name_en": u"Irish", "name": u"Gaeilge"},
"gl_ES": {"name_en": u"Galician", "name": u"Galego"},
"he_IL": {"name_en": u"Hebrew", "name": u"עברית"},
"hi_IN": {"name_en": u"Hindi", "name": u"हिन्दी"},
"hr_HR": {"name_en": u"Croatian", "name": u"Hrvatski"},
"hu_HU": {"name_en": u"Hungarian", "name": u"Magyar"},
"id_ID": {"name_en": u"Indonesian", "name": u"Bahasa Indonesia"},
"is_IS": {"name_en": u"Icelandic", "name": u"Íslenska"},
"it_IT": {"name_en": u"Italian", "name": u"Italiano"},
"ja_JP": {"name_en": u"Japanese", "name": u"日本語"},
"ko_KR": {"name_en": u"Korean", "name": u"한국어"},
"lt_LT": {"name_en": u"Lithuanian", "name": u"Lietuvių"},
"lv_LV": {"name_en": u"Latvian", "name": u"Latviešu"},
"mk_MK": {"name_en": u"Macedonian", "name": u"Македонски"},
"ml_IN": {"name_en": u"Malayalam", "name": u"മലയാളം"},
"ms_MY": {"name_en": u"Malay", "name": u"Bahasa Melayu"},
"nb_NO": {"name_en": u"Norwegian (bokmal)", "name": u"Norsk (bokmål)"},
"nl_NL": {"name_en": u"Dutch", "name": u"Nederlands"},
"nn_NO": {"name_en": u"Norwegian (nynorsk)", "name": u"Norsk (nynorsk)"},
"pa_IN": {"name_en": u"Punjabi", "name": u"ਪੰਜਾਬੀ"},
"pl_PL": {"name_en": u"Polish", "name": u"Polski"},
"pt_BR": {"name_en": u"Portuguese (Brazil)", "name": u"Português (Brasil)"},
"pt_PT": {"name_en": u"Portuguese (Portugal)", "name": u"Português (Portugal)"},
"ro_RO": {"name_en": u"Romanian", "name": u"Română"},
"ru_RU": {"name_en": u"Russian", "name": u"Русский"},
"sk_SK": {"name_en": u"Slovak", "name": u"Slovenčina"},
"sl_SI": {"name_en": u"Slovenian", "name": u"Slovenščina"},
"sq_AL": {"name_en": u"Albanian", "name": u"Shqip"},
"sr_RS": {"name_en": u"Serbian", "name": u"Српски"},
"sv_SE": {"name_en": u"Swedish", "name": u"Svenska"},
"sw_KE": {"name_en": u"Swahili", "name": u"Kiswahili"},
"ta_IN": {"name_en": u"Tamil", "name": u"தமிழ்"},
"te_IN": {"name_en": u"Telugu", "name": u"తెలుగు"},
"th_TH": {"name_en": u"Thai", "name": u"ภาษาไทย"},
"tl_PH": {"name_en": u"Filipino", "name": u"Filipino"},
"tr_TR": {"name_en": u"Turkish", "name": u"Türkçe"},
"uk_UA": {"name_en": u"Ukraini ", "name": u"Українська"},
"vi_VN": {"name_en": u"Vietnamese", "name": u"Tiếng Việt"},
"zh_CN": {"name_en": u"Chinese (Simplified)", "name": u"中文(简体)"},
"zh_TW": {"name_en": u"Chinese (Traditional)", "name": u"中文(繁體)"},
}

lib/tornado/auth.py Executable file

File diff suppressed because it is too large

lib/tornado/autoreload.py Executable file
View File

@@ -0,0 +1,356 @@
#
# Copyright 2009 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""Automatically restart the server when a source file is modified.
Most applications should not access this module directly. Instead,
pass the keyword argument ``autoreload=True`` to the
`tornado.web.Application` constructor (or ``debug=True``, which
enables this setting and several others). This will enable autoreload
mode as well as checking for changes to templates and static
resources. Note that restarting is a destructive operation and any
requests in progress will be aborted when the process restarts. (If
you want to disable autoreload while using other debug-mode features,
pass both ``debug=True`` and ``autoreload=False``).
This module can also be used as a command-line wrapper around scripts
such as unit test runners. See the `main` method for details.
The command-line wrapper and Application debug modes can be used together.
This combination is encouraged as the wrapper catches syntax errors and
other import-time failures, while debug mode catches changes once
the server has started.
This module depends on `.IOLoop`, so it will not work in WSGI applications
and Google App Engine. It also will not work correctly when `.HTTPServer`'s
multi-process mode is used.
Reloading loses any Python interpreter command-line arguments (e.g. ``-u``)
because it re-executes Python using ``sys.executable`` and ``sys.argv``.
Additionally, modifying these variables will cause reloading to behave
incorrectly.
"""
from __future__ import absolute_import, division, print_function
import os
import sys
# sys.path handling
# -----------------
#
# If a module is run with "python -m", the current directory (i.e. "")
# is automatically prepended to sys.path, but not if it is run as
# "path/to/file.py". The processing for "-m" rewrites the former to
# the latter, so subsequent executions won't have the same path as the
# original.
#
# Conversely, when run as path/to/file.py, the directory containing
# file.py gets added to the path, which can cause confusion as imports
# may become relative in spite of the future import.
#
# We address the former problem by reconstructing the original command
# line (Python >= 3.4) or by setting the $PYTHONPATH environment
# variable (Python < 3.4) before re-execution so the new process will
# see the correct path. We attempt to address the latter problem when
# tornado.autoreload is run as __main__.
if __name__ == "__main__":
# This sys.path manipulation must come before our imports (as much
# as possible - if we introduced a tornado.sys or tornado.os
# module we'd be in trouble), or else our imports would become
# relative again despite the future import.
#
# There is a separate __main__ block at the end of the file to call main().
if sys.path[0] == os.path.dirname(__file__):
del sys.path[0]
import functools
import logging
import os
import pkgutil # type: ignore
import sys
import traceback
import types
import subprocess
import weakref
from tornado import ioloop
from tornado.log import gen_log
from tornado import process
from tornado.util import exec_in
try:
import signal
except ImportError:
signal = None
# os.execv is broken on Windows and can't properly parse command line
# arguments and executable name if they contain whitespaces. subprocess
# fixes that behavior.
_has_execv = sys.platform != 'win32'
_watched_files = set()
_reload_hooks = []
_reload_attempted = False
_io_loops = weakref.WeakKeyDictionary() # type: ignore
_autoreload_is_main = False
_original_argv = None
_original_spec = None
def start(check_time=500):
"""Begins watching source files for changes.
.. versionchanged:: 5.0
The ``io_loop`` argument (deprecated since version 4.1) has been removed.
"""
io_loop = ioloop.IOLoop.current()
if io_loop in _io_loops:
return
_io_loops[io_loop] = True
if len(_io_loops) > 1:
gen_log.warning("tornado.autoreload started more than once in the same process")
modify_times = {}
callback = functools.partial(_reload_on_update, modify_times)
scheduler = ioloop.PeriodicCallback(callback, check_time)
scheduler.start()
def wait():
"""Wait for a watched file to change, then restart the process.
Intended to be used at the end of scripts like unit test runners,
to run the tests again after any source file changes (but see also
the command-line interface in `main`).
"""
io_loop = ioloop.IOLoop()
io_loop.add_callback(start)
io_loop.start()
def watch(filename):
"""Add a file to the watch list.
All imported modules are watched by default.
"""
_watched_files.add(filename)
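# For example (a sketch; the path is hypothetical), a script can also be
# restarted when a non-Python file changes:
#
#     tornado.autoreload.watch("conf/settings.yaml")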
def add_reload_hook(fn):
"""Add a function to be called before reloading the process.
Note that for open file and socket handles it is generally
preferable to set the ``FD_CLOEXEC`` flag (using `fcntl` or
``tornado.platform.auto.set_close_exec``) instead
of using a reload hook to close them.
"""
_reload_hooks.append(fn)
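# A minimal sketch of the FD_CLOEXEC approach recommended above (POSIX
# only; ``sock`` stands in for a hypothetical open socket):
#
#     import fcntl
#     flags = fcntl.fcntl(sock.fileno(), fcntl.F_GETFD)
#     fcntl.fcntl(sock.fileno(), fcntl.F_SETFD, flags | fcntl.FD_CLOEXEC)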
def _reload_on_update(modify_times):
if _reload_attempted:
# We already tried to reload and it didn't work, so don't try again.
return
if process.task_id() is not None:
# We're in a child process created by fork_processes. If child
# processes restarted themselves, they'd all restart and then
# all call fork_processes again.
return
for module in list(sys.modules.values()):
# Some modules play games with sys.modules (e.g. email/__init__.py
# in the standard library), and occasionally this can cause strange
# failures in getattr. Just ignore anything that's not an ordinary
# module.
if not isinstance(module, types.ModuleType):
continue
path = getattr(module, "__file__", None)
if not path:
continue
if path.endswith(".pyc") or path.endswith(".pyo"):
path = path[:-1]
_check_file(modify_times, path)
for path in _watched_files:
_check_file(modify_times, path)
def _check_file(modify_times, path):
try:
modified = os.stat(path).st_mtime
except Exception:
return
if path not in modify_times:
modify_times[path] = modified
return
if modify_times[path] != modified:
gen_log.info("%s modified; restarting server", path)
_reload()
def _reload():
global _reload_attempted
_reload_attempted = True
for fn in _reload_hooks:
fn()
if hasattr(signal, "setitimer"):
# Clear the alarm signal set by
# ioloop.set_blocking_log_threshold so it doesn't fire
# after the exec.
signal.setitimer(signal.ITIMER_REAL, 0, 0)
# sys.path fixes: see comments at top of file. If __main__.__spec__
# exists, we were invoked with -m and the effective path is about to
# change on re-exec. Reconstruct the original command line to
# ensure that the new process sees the same path we did. If
# __spec__ is not available (Python < 3.4), check instead if
# sys.path[0] is an empty string and add the current directory to
# $PYTHONPATH.
if _autoreload_is_main:
spec = _original_spec
argv = _original_argv
else:
spec = getattr(sys.modules['__main__'], '__spec__', None)
argv = sys.argv
if spec:
argv = ['-m', spec.name] + argv[1:]
else:
path_prefix = '.' + os.pathsep
if (sys.path[0] == '' and
not os.environ.get("PYTHONPATH", "").startswith(path_prefix)):
os.environ["PYTHONPATH"] = (path_prefix +
os.environ.get("PYTHONPATH", ""))
if not _has_execv:
subprocess.Popen([sys.executable] + argv)
os._exit(0)
else:
try:
os.execv(sys.executable, [sys.executable] + argv)
except OSError:
# Mac OS X versions prior to 10.6 do not support execv in
# a process that contains multiple threads. Instead of
# re-executing in the current process, start a new one
# and cause the current process to exit. This isn't
# ideal since the new process is detached from the parent
# terminal and thus cannot easily be killed with ctrl-C,
# but it's better than not being able to autoreload at
# all.
# Unfortunately the errno returned in this case does not
# appear to be consistent, so we can't easily check for
# this error specifically.
os.spawnv(os.P_NOWAIT, sys.executable, [sys.executable] + argv)
# At this point the IOLoop has been closed and finally
# blocks will experience errors if we allow the stack to
# unwind, so just exit uncleanly.
os._exit(0)
_USAGE = """\
Usage:
python -m tornado.autoreload -m module.to.run [args...]
python -m tornado.autoreload path/to/script.py [args...]
"""
def main():
"""Command-line wrapper to re-run a script whenever its source changes.
Scripts may be specified by filename or module name::
python -m tornado.autoreload -m tornado.test.runtests
python -m tornado.autoreload tornado/test/runtests.py
Running a script with this wrapper is similar to calling
`tornado.autoreload.wait` at the end of the script, but this wrapper
can catch import-time problems like syntax errors that would otherwise
prevent the script from reaching its call to `wait`.
"""
# Remember that we were launched with autoreload as main.
# The main module can be tricky; set the variables both in our globals
# (which may be __main__) and the real importable version.
import tornado.autoreload
global _autoreload_is_main
global _original_argv, _original_spec
tornado.autoreload._autoreload_is_main = _autoreload_is_main = True
original_argv = sys.argv
tornado.autoreload._original_argv = _original_argv = original_argv
original_spec = getattr(sys.modules['__main__'], '__spec__', None)
tornado.autoreload._original_spec = _original_spec = original_spec
sys.argv = sys.argv[:]
if len(sys.argv) >= 3 and sys.argv[1] == "-m":
mode = "module"
module = sys.argv[2]
del sys.argv[1:3]
elif len(sys.argv) >= 2:
mode = "script"
script = sys.argv[1]
sys.argv = sys.argv[1:]
else:
print(_USAGE, file=sys.stderr)
sys.exit(1)
try:
if mode == "module":
import runpy
runpy.run_module(module, run_name="__main__", alter_sys=True)
elif mode == "script":
with open(script) as f:
# Execute the script in our namespace instead of creating
# a new one so that something that tries to import __main__
# (e.g. the unittest module) will see names defined in the
# script instead of just those defined in this module.
global __file__
__file__ = script
# If __package__ is defined, imports may be incorrectly
# interpreted as relative to this module.
global __package__
del __package__
exec_in(f.read(), globals(), globals())
except SystemExit as e:
logging.basicConfig()
gen_log.info("Script exited with status %s", e.code)
except Exception as e:
logging.basicConfig()
gen_log.warning("Script exited with uncaught exception", exc_info=True)
# If an exception occurred at import time, the file with the error
# never made it into sys.modules and so we won't know to watch it.
# Just to make sure we've covered everything, walk the stack trace
# from the exception and watch every file.
for (filename, lineno, name, line) in traceback.extract_tb(sys.exc_info()[2]):
watch(filename)
if isinstance(e, SyntaxError):
# SyntaxErrors are special: their innermost stack frame is fake
# so extract_tb won't see it and we have to get the filename
# from the exception object.
watch(e.filename)
else:
logging.basicConfig()
gen_log.info("Script exited normally")
# restore sys.argv so subsequent executions will include autoreload
sys.argv = original_argv
if mode == 'module':
# runpy did a fake import of the module as __main__, but now it's
# no longer in sys.modules. Figure out where it is and watch it.
loader = pkgutil.get_loader(module)
if loader is not None:
watch(loader.get_filename())
wait()
if __name__ == "__main__":
# See also the other __main__ block at the top of the file, which modifies
# sys.path before our imports
main()

660
lib/tornado/concurrent.py Executable file

@@ -0,0 +1,660 @@
#
# Copyright 2012 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""Utilities for working with ``Future`` objects.
``Futures`` are a pattern for concurrent programming introduced in
Python 3.2 in the `concurrent.futures` package, and also adopted (in a
slightly different form) in Python 3.4's `asyncio` package. This
package defines a ``Future`` class that is an alias for `asyncio.Future`
when available, and a compatible implementation for older versions of
Python. It also includes some utility functions for interacting with
``Future`` objects.
While this package is an important part of Tornado's internal
implementation, applications rarely need to interact with it
directly.
"""
from __future__ import absolute_import, division, print_function
import functools
import platform
import textwrap
import traceback
import sys
import warnings
from tornado.log import app_log
from tornado.stack_context import ExceptionStackContext, wrap
from tornado.util import raise_exc_info, ArgReplacer, is_finalizing
try:
from concurrent import futures
except ImportError:
futures = None
try:
import asyncio
except ImportError:
asyncio = None
try:
import typing
except ImportError:
typing = None
# Can the garbage collector handle cycles that include __del__ methods?
# This is true in cpython beginning with version 3.4 (PEP 442).
_GC_CYCLE_FINALIZERS = (platform.python_implementation() == 'CPython' and
sys.version_info >= (3, 4))
class ReturnValueIgnoredError(Exception):
pass
# This class and associated code in the future object is derived
# from the Trollius project, a backport of asyncio to Python 2.x - 3.x
class _TracebackLogger(object):
"""Helper to log a traceback upon destruction if not cleared.
This solves a nasty problem with Futures and Tasks that have an
exception set: if nobody asks for the exception, the exception is
never logged. This violates the Zen of Python: 'Errors should
never pass silently. Unless explicitly silenced.'
However, we don't want to log the exception as soon as
set_exception() is called: if the calling code is written
properly, it will get the exception and handle it properly. But
we *do* want to log it if result() or exception() was never called
-- otherwise developers waste a lot of time wondering why their
buggy code fails silently.
An earlier attempt added a __del__() method to the Future class
itself, but this backfired because the presence of __del__()
prevents garbage collection from breaking cycles. A way out of
this catch-22 is to avoid having a __del__() method on the Future
class itself, but instead to have a reference to a helper object
with a __del__() method that logs the traceback, where we ensure
that the helper object doesn't participate in cycles, and only the
Future has a reference to it.
The helper object is added when set_exception() is called. When
the Future is collected, and the helper is present, the helper
object is also collected, and its __del__() method will log the
traceback. When the Future's result() or exception() method is
called (and a helper object is present), it removes the helper
object, after calling its clear() method to prevent it from
logging.
One downside is that we do a fair amount of work to extract the
traceback from the exception, even when it is never logged. It
would seem cheaper to just store the exception object, but that
references the traceback, which references stack frames, which may
reference the Future, which references the _TracebackLogger, and
then the _TracebackLogger would be included in a cycle, which is
what we're trying to avoid! As an optimization, we don't
immediately format the exception; we only do the work when
activate() is called, which call is delayed until after all the
Future's callbacks have run. Since a Future usually has at least
one callback (typically set by 'yield From') and that callback
usually extracts the exception, the traceback rarely needs to be
formatted at all.
PS. I don't claim credit for this solution. I first heard of it
in a discussion about closing files when they are collected.
"""
__slots__ = ('exc_info', 'formatted_tb')
def __init__(self, exc_info):
self.exc_info = exc_info
self.formatted_tb = None
def activate(self):
exc_info = self.exc_info
if exc_info is not None:
self.exc_info = None
self.formatted_tb = traceback.format_exception(*exc_info)
def clear(self):
self.exc_info = None
self.formatted_tb = None
def __del__(self, is_finalizing=is_finalizing):
if not is_finalizing() and self.formatted_tb:
app_log.error('Future exception was never retrieved: %s',
''.join(self.formatted_tb).rstrip())
class Future(object):
"""Placeholder for an asynchronous result.
A ``Future`` encapsulates the result of an asynchronous
operation. In synchronous applications ``Futures`` are used
to wait for the result from a thread or process pool; in
Tornado they are normally used with `.IOLoop.add_future` or by
yielding them in a `.gen.coroutine`.
`tornado.concurrent.Future` is an alias for `asyncio.Future` when
that package is available (Python 3.4+). Unlike
`concurrent.futures.Future`, the ``Futures`` used by Tornado and
`asyncio` are not thread-safe (and therefore faster for use with
single-threaded event loops).
In addition to ``exception`` and ``set_exception``, Tornado's
``Future`` implementation supports storing an ``exc_info`` triple
to support better tracebacks on Python 2. To set an ``exc_info``
triple, use `future_set_exc_info`, and to retrieve one, call
`result()` (which will raise it).
.. versionchanged:: 4.0
`tornado.concurrent.Future` is always a thread-unsafe ``Future``
with support for the ``exc_info`` methods. Previously it would
be an alias for the thread-safe `concurrent.futures.Future`
if that package was available and fall back to the thread-unsafe
implementation if it was not.
.. versionchanged:: 4.1
If a `.Future` contains an error but that error is never observed
(by calling ``result()``, ``exception()``, or ``exc_info()``),
a stack trace will be logged when the `.Future` is garbage collected.
This normally indicates an error in the application, but in cases
where it results in undesired logging it may be necessary to
suppress the logging by ensuring that the exception is observed:
``f.add_done_callback(lambda f: f.exception())``.
.. versionchanged:: 5.0
This class was previously available under the name
``TracebackFuture``. This name, which was deprecated since
version 4.0, has been removed. When `asyncio` is available
``tornado.concurrent.Future`` is now an alias for
`asyncio.Future`. Like `asyncio.Future`, callbacks are now
always scheduled on the `.IOLoop` and are never run
synchronously.
"""
def __init__(self):
self._done = False
self._result = None
self._exc_info = None
self._log_traceback = False # Used for Python >= 3.4
self._tb_logger = None # Used for Python <= 3.3
self._callbacks = []
# Implement the Python 3.5 Awaitable protocol if possible
# (we can't use return and yield together until py33).
if sys.version_info >= (3, 3):
exec(textwrap.dedent("""
def __await__(self):
return (yield self)
"""))
else:
# Py2-compatible version for use with cython.
def __await__(self):
result = yield self
# StopIteration doesn't take args before py33,
# but Cython recognizes the args tuple.
e = StopIteration()
e.args = (result,)
raise e
def cancel(self):
"""Cancel the operation, if possible.
Tornado ``Futures`` do not support cancellation, so this method always
returns False.
"""
return False
def cancelled(self):
"""Returns True if the operation has been cancelled.
Tornado ``Futures`` do not support cancellation, so this method
always returns False.
"""
return False
def running(self):
"""Returns True if this operation is currently running."""
return not self._done
def done(self):
"""Returns True if the future has finished running."""
return self._done
def _clear_tb_log(self):
self._log_traceback = False
if self._tb_logger is not None:
self._tb_logger.clear()
self._tb_logger = None
def result(self, timeout=None):
"""If the operation succeeded, return its result. If it failed,
re-raise its exception.
This method takes a ``timeout`` argument for compatibility with
`concurrent.futures.Future` but it is an error to call it
before the `Future` is done, so the ``timeout`` is never used.
"""
self._clear_tb_log()
if self._result is not None:
return self._result
if self._exc_info is not None:
try:
raise_exc_info(self._exc_info)
finally:
self = None
self._check_done()
return self._result
def exception(self, timeout=None):
"""If the operation raised an exception, return the `Exception`
object. Otherwise returns None.
This method takes a ``timeout`` argument for compatibility with
`concurrent.futures.Future` but it is an error to call it
before the `Future` is done, so the ``timeout`` is never used.
"""
self._clear_tb_log()
if self._exc_info is not None:
return self._exc_info[1]
else:
self._check_done()
return None
def add_done_callback(self, fn):
"""Attaches the given callback to the `Future`.
It will be invoked with the `Future` as its argument when the Future
has finished running and its result is available. In Tornado
consider using `.IOLoop.add_future` instead of calling
`add_done_callback` directly.
"""
if self._done:
from tornado.ioloop import IOLoop
IOLoop.current().add_callback(fn, self)
else:
self._callbacks.append(fn)
def set_result(self, result):
"""Sets the result of a ``Future``.
It is undefined to call any of the ``set`` methods more than once
on the same object.
"""
self._result = result
self._set_done()
def set_exception(self, exception):
"""Sets the exception of a ``Future.``"""
self.set_exc_info(
(exception.__class__,
exception,
getattr(exception, '__traceback__', None)))
def exc_info(self):
"""Returns a tuple in the same format as `sys.exc_info` or None.
.. versionadded:: 4.0
"""
self._clear_tb_log()
return self._exc_info
def set_exc_info(self, exc_info):
"""Sets the exception information of a ``Future.``
Preserves tracebacks on Python 2.
.. versionadded:: 4.0
"""
self._exc_info = exc_info
self._log_traceback = True
if not _GC_CYCLE_FINALIZERS:
self._tb_logger = _TracebackLogger(exc_info)
try:
self._set_done()
finally:
# Activate the logger after all callbacks have had a
# chance to call result() or exception().
if self._log_traceback and self._tb_logger is not None:
self._tb_logger.activate()
self._exc_info = exc_info
def _check_done(self):
if not self._done:
raise Exception("DummyFuture does not support blocking for results")
def _set_done(self):
self._done = True
if self._callbacks:
from tornado.ioloop import IOLoop
loop = IOLoop.current()
for cb in self._callbacks:
loop.add_callback(cb, self)
self._callbacks = None
# On Python 3.3 or older, objects with a destructor part of a reference
# cycle are never destroyed. It's no longer the case on Python 3.4 thanks to
# the PEP 442.
if _GC_CYCLE_FINALIZERS:
def __del__(self, is_finalizing=is_finalizing):
if is_finalizing() or not self._log_traceback:
# set_exception() was not called, or result() or exception()
# has consumed the exception
return
tb = traceback.format_exception(*self._exc_info)
app_log.error('Future %r exception was never retrieved: %s',
self, ''.join(tb).rstrip())
if asyncio is not None:
Future = asyncio.Future # noqa
if futures is None:
FUTURES = Future # type: typing.Union[type, typing.Tuple[type, ...]]
else:
FUTURES = (futures.Future, Future)
def is_future(x):
return isinstance(x, FUTURES)
class DummyExecutor(object):
def submit(self, fn, *args, **kwargs):
future = Future()
try:
future_set_result_unless_cancelled(future, fn(*args, **kwargs))
except Exception:
future_set_exc_info(future, sys.exc_info())
return future
def shutdown(self, wait=True):
pass
dummy_executor = DummyExecutor()
def run_on_executor(*args, **kwargs):
"""Decorator to run a synchronous method asynchronously on an executor.
The decorated method may be called with a ``callback`` keyword
argument and returns a future.
The executor to be used is determined by the ``executor``
attribute of ``self``. To use a different attribute name, pass a
keyword argument to the decorator::
@run_on_executor(executor='_thread_pool')
def foo(self):
pass
This decorator should not be confused with the similarly-named
`.IOLoop.run_in_executor`. In general, using ``run_in_executor``
when *calling* a blocking method is recommended instead of using
this decorator when *defining* a method. If compatibility with older
versions of Tornado is required, consider defining an executor
and using ``executor.submit()`` at the call site.
.. versionchanged:: 4.2
Added keyword arguments to use alternative attributes.
.. versionchanged:: 5.0
Always uses the current IOLoop instead of ``self.io_loop``.
.. versionchanged:: 5.1
Returns a `.Future` compatible with ``await`` instead of a
`concurrent.futures.Future`.
.. deprecated:: 5.1
The ``callback`` argument is deprecated and will be removed in
6.0. The decorator itself is discouraged in new code but will
not be removed in 6.0.
"""
def run_on_executor_decorator(fn):
executor = kwargs.get("executor", "executor")
@functools.wraps(fn)
def wrapper(self, *args, **kwargs):
callback = kwargs.pop("callback", None)
async_future = Future()
conc_future = getattr(self, executor).submit(fn, self, *args, **kwargs)
chain_future(conc_future, async_future)
if callback:
warnings.warn("callback arguments are deprecated, use the returned Future instead",
DeprecationWarning)
from tornado.ioloop import IOLoop
IOLoop.current().add_future(
async_future, lambda future: callback(future.result()))
return async_future
return wrapper
if args and kwargs:
raise ValueError("cannot combine positional and keyword args")
if len(args) == 1:
return run_on_executor_decorator(args[0])
elif len(args) != 0:
raise ValueError("expected 1 argument, got %d", len(args))
return run_on_executor_decorator
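# A short sketch (not part of the original) of the decorator in use,
# assuming a concurrent.futures thread pool stored on the instance:
#
#     from concurrent.futures import ThreadPoolExecutor
#
#     class Worker(object):
#         executor = ThreadPoolExecutor(max_workers=4)
#
#         @run_on_executor
#         def slow_add(self, a, b):
#             return a + b  # runs on the pool; the caller gets a Future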
_NO_RESULT = object()
def return_future(f):
"""Decorator to make a function that returns via callback return a
`Future`.
This decorator was provided to ease the transition from
callback-oriented code to coroutines. It is not recommended for
new code.
The wrapped function should take a ``callback`` keyword argument
and invoke it with one argument when it has finished. To signal failure,
the function can simply raise an exception (which will be
captured by the `.StackContext` and passed along to the ``Future``).
From the caller's perspective, the callback argument is optional.
If one is given, it will be invoked when the function is complete
with ``Future.result()`` as an argument. If the function fails, the
callback will not be run and an exception will be raised into the
surrounding `.StackContext`.
If no callback is given, the caller should use the ``Future`` to
wait for the function to complete (perhaps by yielding it in a
coroutine, or passing it to `.IOLoop.add_future`).
Usage:
.. testcode::
@return_future
def future_func(arg1, arg2, callback):
# Do stuff (possibly asynchronous)
callback(result)
async def caller():
await future_func(arg1, arg2)
..
Note that ``@return_future`` and ``@gen.engine`` can be applied to the
same function, provided ``@return_future`` appears first. However,
consider using ``@gen.coroutine`` instead of this combination.
.. versionchanged:: 5.1
Now raises a `.DeprecationWarning` if a callback argument is passed to
the decorated function and deprecation warnings are enabled.
.. deprecated:: 5.1
This decorator will be removed in Tornado 6.0. New code should
use coroutines directly instead of wrapping callback-based code
with this decorator. Interactions with non-Tornado
callback-based code should be managed explicitly to avoid
relying on the `.ExceptionStackContext` built into this
decorator.
"""
warnings.warn("@return_future is deprecated, use coroutines instead",
DeprecationWarning)
return _non_deprecated_return_future(f, warn=True)
def _non_deprecated_return_future(f, warn=False):
# Allow auth.py to use this decorator without triggering
# deprecation warnings. This will go away once auth.py has removed
# its legacy interfaces in 6.0.
replacer = ArgReplacer(f, 'callback')
@functools.wraps(f)
def wrapper(*args, **kwargs):
future = Future()
callback, args, kwargs = replacer.replace(
lambda value=_NO_RESULT: future_set_result_unless_cancelled(future, value),
args, kwargs)
def handle_error(typ, value, tb):
future_set_exc_info(future, (typ, value, tb))
return True
exc_info = None
esc = ExceptionStackContext(handle_error, delay_warning=True)
with esc:
if not warn:
# HACK: In non-deprecated mode (only used in auth.py),
# suppress the warning entirely. Since this is added
# in a 5.1 patch release and already removed in 6.0
# I'm prioritizing a minimal change instead of a
# clean solution.
esc.delay_warning = False
try:
result = f(*args, **kwargs)
if result is not None:
raise ReturnValueIgnoredError(
"@return_future should not be used with functions "
"that return values")
except:
exc_info = sys.exc_info()
raise
if exc_info is not None:
# If the initial synchronous part of f() raised an exception,
# go ahead and raise it to the caller directly without waiting
# for them to inspect the Future.
future.result()
# If the caller passed in a callback, schedule it to be called
# when the future resolves. It is important that this happens
# just before we return the future, or else we risk confusing
# stack contexts with multiple exceptions (one here with the
# immediate exception, and again when the future resolves and
# the callback triggers its exception by calling future.result()).
if callback is not None:
warnings.warn("callback arguments are deprecated, use the returned Future instead",
DeprecationWarning)
def run_callback(future):
result = future.result()
if result is _NO_RESULT:
callback()
else:
callback(future.result())
future_add_done_callback(future, wrap(run_callback))
return future
return wrapper
def chain_future(a, b):
"""Chain two futures together so that when one completes, so does the other.
The result (success or failure) of ``a`` will be copied to ``b``, unless
``b`` has already been completed or cancelled by the time ``a`` finishes.
.. versionchanged:: 5.0
Now accepts both Tornado/asyncio `Future` objects and
`concurrent.futures.Future`.
"""
def copy(future):
assert future is a
if b.done():
return
if (hasattr(a, 'exc_info') and
a.exc_info() is not None):
future_set_exc_info(b, a.exc_info())
elif a.exception() is not None:
b.set_exception(a.exception())
else:
b.set_result(a.result())
if isinstance(a, Future):
future_add_done_callback(a, copy)
else:
# concurrent.futures.Future
from tornado.ioloop import IOLoop
IOLoop.current().add_future(a, copy)
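# Illustration (a sketch; ``pool`` and ``compute`` are hypothetical): copy
# the result of a thread-pool future onto a Future a coroutine can await:
#
#     conc = pool.submit(compute)  # concurrent.futures.Future
#     fut = Future()
#     chain_future(conc, fut)      # fut resolves on the IOLoop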
def future_set_result_unless_cancelled(future, value):
"""Set the given ``value`` as the `Future`'s result, if not cancelled.
Avoids asyncio.InvalidStateError when calling set_result() on
a cancelled `asyncio.Future`.
.. versionadded:: 5.0
"""
if not future.cancelled():
future.set_result(value)
def future_set_exc_info(future, exc_info):
"""Set the given ``exc_info`` as the `Future`'s exception.
Understands both `asyncio.Future` and Tornado's extensions to
enable better tracebacks on Python 2.
.. versionadded:: 5.0
"""
if hasattr(future, 'set_exc_info'):
# Tornado's Future
future.set_exc_info(exc_info)
else:
# asyncio.Future
future.set_exception(exc_info[1])
def future_add_done_callback(future, callback):
"""Arrange to call ``callback`` when ``future`` is complete.
``callback`` is invoked with one argument, the ``future``.
If ``future`` is already done, ``callback`` is invoked immediately.
This may differ from the behavior of ``Future.add_done_callback``,
which makes no such guarantee.
.. versionadded:: 5.0
"""
if future.done():
callback(future)
else:
future.add_done_callback(callback)

514
lib/tornado/curl_httpclient.py Executable file

@@ -0,0 +1,514 @@
#
# Copyright 2009 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""Non-blocking HTTP client implementation using pycurl."""
from __future__ import absolute_import, division, print_function
import collections
import functools
import logging
import pycurl # type: ignore
import threading
import time
from io import BytesIO
from tornado import httputil
from tornado import ioloop
from tornado import stack_context
from tornado.escape import utf8, native_str
from tornado.httpclient import HTTPResponse, HTTPError, AsyncHTTPClient, main
curl_log = logging.getLogger('tornado.curl_httpclient')
class CurlAsyncHTTPClient(AsyncHTTPClient):
def initialize(self, max_clients=10, defaults=None):
super(CurlAsyncHTTPClient, self).initialize(defaults=defaults)
self._multi = pycurl.CurlMulti()
self._multi.setopt(pycurl.M_TIMERFUNCTION, self._set_timeout)
self._multi.setopt(pycurl.M_SOCKETFUNCTION, self._handle_socket)
self._curls = [self._curl_create() for i in range(max_clients)]
self._free_list = self._curls[:]
self._requests = collections.deque()
self._fds = {}
self._timeout = None
# libcurl has bugs that sometimes cause it to not report all
# relevant file descriptors and timeouts to TIMERFUNCTION/
# SOCKETFUNCTION. Mitigate the effects of such bugs by
# forcing a periodic scan of all active requests.
self._force_timeout_callback = ioloop.PeriodicCallback(
self._handle_force_timeout, 1000)
self._force_timeout_callback.start()
# Work around a bug in libcurl 7.29.0: Some fields in the curl
# multi object are initialized lazily, and its destructor will
# segfault if it is destroyed without having been used. Add
# and remove a dummy handle to make sure everything is
# initialized.
dummy_curl_handle = pycurl.Curl()
self._multi.add_handle(dummy_curl_handle)
self._multi.remove_handle(dummy_curl_handle)
def close(self):
self._force_timeout_callback.stop()
if self._timeout is not None:
self.io_loop.remove_timeout(self._timeout)
for curl in self._curls:
curl.close()
self._multi.close()
super(CurlAsyncHTTPClient, self).close()
# Set the properties below to None to reduce the reference count of the
# current instance, because they hold methods of this instance and
# would otherwise create a circular reference.
self._force_timeout_callback = None
self._multi = None
def fetch_impl(self, request, callback):
self._requests.append((request, callback, self.io_loop.time()))
self._process_queue()
self._set_timeout(0)
def _handle_socket(self, event, fd, multi, data):
"""Called by libcurl when it wants to change the file descriptors
it cares about.
"""
event_map = {
pycurl.POLL_NONE: ioloop.IOLoop.NONE,
pycurl.POLL_IN: ioloop.IOLoop.READ,
pycurl.POLL_OUT: ioloop.IOLoop.WRITE,
pycurl.POLL_INOUT: ioloop.IOLoop.READ | ioloop.IOLoop.WRITE
}
if event == pycurl.POLL_REMOVE:
if fd in self._fds:
self.io_loop.remove_handler(fd)
del self._fds[fd]
else:
ioloop_event = event_map[event]
# libcurl sometimes closes a socket and then opens a new
# one using the same FD without giving us a POLL_NONE in
# between. This is a problem with the epoll IOLoop,
# because the kernel can tell when a socket is closed and
# removes it from the epoll automatically, causing future
# update_handler calls to fail. Since we can't tell when
# this has happened, always use remove and re-add
# instead of update.
if fd in self._fds:
self.io_loop.remove_handler(fd)
self.io_loop.add_handler(fd, self._handle_events,
ioloop_event)
self._fds[fd] = ioloop_event
def _set_timeout(self, msecs):
"""Called by libcurl to schedule a timeout."""
if self._timeout is not None:
self.io_loop.remove_timeout(self._timeout)
self._timeout = self.io_loop.add_timeout(
self.io_loop.time() + msecs / 1000.0, self._handle_timeout)
def _handle_events(self, fd, events):
"""Called by IOLoop when there is activity on one of our
file descriptors.
"""
action = 0
if events & ioloop.IOLoop.READ:
action |= pycurl.CSELECT_IN
if events & ioloop.IOLoop.WRITE:
action |= pycurl.CSELECT_OUT
while True:
try:
ret, num_handles = self._multi.socket_action(fd, action)
except pycurl.error as e:
ret = e.args[0]
if ret != pycurl.E_CALL_MULTI_PERFORM:
break
self._finish_pending_requests()
def _handle_timeout(self):
"""Called by IOLoop when the requested timeout has passed."""
with stack_context.NullContext():
self._timeout = None
while True:
try:
ret, num_handles = self._multi.socket_action(
pycurl.SOCKET_TIMEOUT, 0)
except pycurl.error as e:
ret = e.args[0]
if ret != pycurl.E_CALL_MULTI_PERFORM:
break
self._finish_pending_requests()
# In theory, we shouldn't have to do this because curl will
# call _set_timeout whenever the timeout changes. However,
# sometimes after _handle_timeout we will need to reschedule
# immediately even though nothing has changed from curl's
# perspective. This is because when socket_action is
# called with SOCKET_TIMEOUT, libcurl decides internally which
# timeouts need to be processed by using a monotonic clock
# (where available) while tornado uses python's time.time()
# to decide when timeouts have occurred. When those clocks
# disagree on elapsed time (as they will whenever there is an
# NTP adjustment), tornado might call _handle_timeout before
# libcurl is ready. After each timeout, resync the scheduled
# timeout with libcurl's current state.
new_timeout = self._multi.timeout()
if new_timeout >= 0:
self._set_timeout(new_timeout)
def _handle_force_timeout(self):
"""Called by IOLoop periodically to ask libcurl to process any
events it may have forgotten about.
"""
with stack_context.NullContext():
while True:
try:
ret, num_handles = self._multi.socket_all()
except pycurl.error as e:
ret = e.args[0]
if ret != pycurl.E_CALL_MULTI_PERFORM:
break
self._finish_pending_requests()
def _finish_pending_requests(self):
"""Process any requests that were completed by the last
call to multi.socket_action.
"""
while True:
num_q, ok_list, err_list = self._multi.info_read()
for curl in ok_list:
self._finish(curl)
for curl, errnum, errmsg in err_list:
self._finish(curl, errnum, errmsg)
if num_q == 0:
break
self._process_queue()
def _process_queue(self):
with stack_context.NullContext():
while True:
started = 0
while self._free_list and self._requests:
started += 1
curl = self._free_list.pop()
(request, callback, queue_start_time) = self._requests.popleft()
curl.info = {
"headers": httputil.HTTPHeaders(),
"buffer": BytesIO(),
"request": request,
"callback": callback,
"queue_start_time": queue_start_time,
"curl_start_time": time.time(),
"curl_start_ioloop_time": self.io_loop.current().time(),
}
try:
self._curl_setup_request(
curl, request, curl.info["buffer"],
curl.info["headers"])
except Exception as e:
# If there was an error in setup, pass it on
# to the callback. Note that allowing the
# error to escape here will appear to work
# most of the time since we are still in the
# caller's original stack frame, but when
# _process_queue() is called from
# _finish_pending_requests the exceptions have
# nowhere to go.
self._free_list.append(curl)
callback(HTTPResponse(
request=request,
code=599,
error=e))
else:
self._multi.add_handle(curl)
if not started:
break
def _finish(self, curl, curl_error=None, curl_message=None):
info = curl.info
curl.info = None
self._multi.remove_handle(curl)
self._free_list.append(curl)
buffer = info["buffer"]
if curl_error:
error = CurlError(curl_error, curl_message)
code = error.code
effective_url = None
buffer.close()
buffer = None
else:
error = None
code = curl.getinfo(pycurl.HTTP_CODE)
effective_url = curl.getinfo(pycurl.EFFECTIVE_URL)
buffer.seek(0)
# the various curl timings are documented at
# http://curl.haxx.se/libcurl/c/curl_easy_getinfo.html
time_info = dict(
queue=info["curl_start_ioloop_time"] - info["queue_start_time"],
namelookup=curl.getinfo(pycurl.NAMELOOKUP_TIME),
connect=curl.getinfo(pycurl.CONNECT_TIME),
appconnect=curl.getinfo(pycurl.APPCONNECT_TIME),
pretransfer=curl.getinfo(pycurl.PRETRANSFER_TIME),
starttransfer=curl.getinfo(pycurl.STARTTRANSFER_TIME),
total=curl.getinfo(pycurl.TOTAL_TIME),
redirect=curl.getinfo(pycurl.REDIRECT_TIME),
)
try:
info["callback"](HTTPResponse(
request=info["request"], code=code, headers=info["headers"],
buffer=buffer, effective_url=effective_url, error=error,
reason=info['headers'].get("X-Http-Reason", None),
request_time=self.io_loop.time() - info["curl_start_ioloop_time"],
start_time=info["curl_start_time"],
time_info=time_info))
except Exception:
self.handle_callback_exception(info["callback"])
def handle_callback_exception(self, callback):
self.io_loop.handle_callback_exception(callback)
def _curl_create(self):
curl = pycurl.Curl()
if curl_log.isEnabledFor(logging.DEBUG):
curl.setopt(pycurl.VERBOSE, 1)
curl.setopt(pycurl.DEBUGFUNCTION, self._curl_debug)
if hasattr(pycurl, 'PROTOCOLS'): # PROTOCOLS first appeared in pycurl 7.19.5 (2014-07-12)
curl.setopt(pycurl.PROTOCOLS, pycurl.PROTO_HTTP | pycurl.PROTO_HTTPS)
curl.setopt(pycurl.REDIR_PROTOCOLS, pycurl.PROTO_HTTP | pycurl.PROTO_HTTPS)
return curl
def _curl_setup_request(self, curl, request, buffer, headers):
curl.setopt(pycurl.URL, native_str(request.url))
# libcurl's magic "Expect: 100-continue" behavior causes delays
# with servers that don't support it (which include, among others,
# Google's OpenID endpoint). Additionally, this behavior has
# a bug in conjunction with the curl_multi_socket_action API
# (https://sourceforge.net/tracker/?func=detail&atid=100976&aid=3039744&group_id=976),
# which increases the delays. It's more trouble than it's worth,
# so just turn off the feature (yes, setting Expect: to an empty
# value is the official way to disable this)
if "Expect" not in request.headers:
request.headers["Expect"] = ""
# libcurl adds Pragma: no-cache by default; disable that too
if "Pragma" not in request.headers:
request.headers["Pragma"] = ""
curl.setopt(pycurl.HTTPHEADER,
["%s: %s" % (native_str(k), native_str(v))
for k, v in request.headers.get_all()])
curl.setopt(pycurl.HEADERFUNCTION,
functools.partial(self._curl_header_callback,
headers, request.header_callback))
if request.streaming_callback:
def write_function(chunk):
self.io_loop.add_callback(request.streaming_callback, chunk)
else:
write_function = buffer.write
curl.setopt(pycurl.WRITEFUNCTION, write_function)
curl.setopt(pycurl.FOLLOWLOCATION, request.follow_redirects)
curl.setopt(pycurl.MAXREDIRS, request.max_redirects)
curl.setopt(pycurl.CONNECTTIMEOUT_MS, int(1000 * request.connect_timeout))
curl.setopt(pycurl.TIMEOUT_MS, int(1000 * request.request_timeout))
if request.user_agent:
curl.setopt(pycurl.USERAGENT, native_str(request.user_agent))
else:
curl.setopt(pycurl.USERAGENT, "Mozilla/5.0 (compatible; pycurl)")
if request.network_interface:
curl.setopt(pycurl.INTERFACE, request.network_interface)
if request.decompress_response:
curl.setopt(pycurl.ENCODING, "gzip,deflate")
else:
curl.setopt(pycurl.ENCODING, "none")
if request.proxy_host and request.proxy_port:
curl.setopt(pycurl.PROXY, request.proxy_host)
curl.setopt(pycurl.PROXYPORT, request.proxy_port)
if request.proxy_username:
credentials = httputil.encode_username_password(request.proxy_username,
request.proxy_password)
curl.setopt(pycurl.PROXYUSERPWD, credentials)
if (request.proxy_auth_mode is None or
request.proxy_auth_mode == "basic"):
curl.setopt(pycurl.PROXYAUTH, pycurl.HTTPAUTH_BASIC)
elif request.proxy_auth_mode == "digest":
curl.setopt(pycurl.PROXYAUTH, pycurl.HTTPAUTH_DIGEST)
else:
raise ValueError(
"Unsupported proxy_auth_mode %s" % request.proxy_auth_mode)
else:
curl.setopt(pycurl.PROXY, '')
curl.unsetopt(pycurl.PROXYUSERPWD)
if request.validate_cert:
curl.setopt(pycurl.SSL_VERIFYPEER, 1)
curl.setopt(pycurl.SSL_VERIFYHOST, 2)
else:
curl.setopt(pycurl.SSL_VERIFYPEER, 0)
curl.setopt(pycurl.SSL_VERIFYHOST, 0)
if request.ca_certs is not None:
curl.setopt(pycurl.CAINFO, request.ca_certs)
else:
# There is no way to restore pycurl.CAINFO to its default value
# (Using unsetopt makes it reject all certificates).
# I don't see any way to read the default value from python so it
# can be restored later. We'll have to just leave CAINFO untouched
# if no ca_certs file was specified, and require that if any
# request uses a custom ca_certs file, they all must.
pass
if request.allow_ipv6 is False:
# Curl behaves reasonably when DNS resolution gives an ipv6 address
# that we can't reach, so allow ipv6 unless the user asks to disable.
curl.setopt(pycurl.IPRESOLVE, pycurl.IPRESOLVE_V4)
else:
curl.setopt(pycurl.IPRESOLVE, pycurl.IPRESOLVE_WHATEVER)
# Set the request method through curl's irritating interface which makes
# up names for almost every single method
curl_options = {
"GET": pycurl.HTTPGET,
"POST": pycurl.POST,
"PUT": pycurl.UPLOAD,
"HEAD": pycurl.NOBODY,
}
custom_methods = set(["DELETE", "OPTIONS", "PATCH"])
for o in curl_options.values():
curl.setopt(o, False)
if request.method in curl_options:
curl.unsetopt(pycurl.CUSTOMREQUEST)
curl.setopt(curl_options[request.method], True)
elif request.allow_nonstandard_methods or request.method in custom_methods:
curl.setopt(pycurl.CUSTOMREQUEST, request.method)
else:
raise KeyError('unknown method ' + request.method)
body_expected = request.method in ("POST", "PATCH", "PUT")
body_present = request.body is not None
if not request.allow_nonstandard_methods:
# Some HTTP methods nearly always have bodies while others
# almost never do. Fail in this case unless the user has
# opted out of sanity checks with allow_nonstandard_methods.
if ((body_expected and not body_present) or
(body_present and not body_expected)):
raise ValueError(
'Body must %sbe None for method %s (unless '
'allow_nonstandard_methods is true)' %
('not ' if body_expected else '', request.method))
if body_expected or body_present:
if request.method == "GET":
# Even with `allow_nonstandard_methods` we disallow
# GET with a body (because libcurl doesn't allow it
# unless we use CUSTOMREQUEST). While the spec doesn't
# forbid clients from sending a body, it arguably
# disallows the server from doing anything with them.
raise ValueError('Body must be None for GET request')
request_buffer = BytesIO(utf8(request.body or ''))
def ioctl(cmd):
if cmd == curl.IOCMD_RESTARTREAD:
request_buffer.seek(0)
curl.setopt(pycurl.READFUNCTION, request_buffer.read)
curl.setopt(pycurl.IOCTLFUNCTION, ioctl)
if request.method == "POST":
curl.setopt(pycurl.POSTFIELDSIZE, len(request.body or ''))
else:
curl.setopt(pycurl.UPLOAD, True)
curl.setopt(pycurl.INFILESIZE, len(request.body or ''))
if request.auth_username is not None:
if request.auth_mode is None or request.auth_mode == "basic":
curl.setopt(pycurl.HTTPAUTH, pycurl.HTTPAUTH_BASIC)
elif request.auth_mode == "digest":
curl.setopt(pycurl.HTTPAUTH, pycurl.HTTPAUTH_DIGEST)
else:
raise ValueError("Unsupported auth_mode %s" % request.auth_mode)
userpwd = httputil.encode_username_password(request.auth_username,
request.auth_password)
curl.setopt(pycurl.USERPWD, userpwd)
curl_log.debug("%s %s (username: %r)", request.method, request.url,
request.auth_username)
else:
curl.unsetopt(pycurl.USERPWD)
curl_log.debug("%s %s", request.method, request.url)
if request.client_cert is not None:
curl.setopt(pycurl.SSLCERT, request.client_cert)
if request.client_key is not None:
curl.setopt(pycurl.SSLKEY, request.client_key)
if request.ssl_options is not None:
raise ValueError("ssl_options not supported in curl_httpclient")
if threading.activeCount() > 1:
# libcurl/pycurl is not thread-safe by default. When multiple threads
# are used, signals should be disabled. This has the side effect
# of disabling DNS timeouts in some environments (when libcurl is
# not linked against ares), so we don't do it when there is only one
# thread. Applications that use many short-lived threads may need
# to set NOSIGNAL manually in a prepare_curl_callback since
# there may not be any other threads running at the time we call
# threading.activeCount.
curl.setopt(pycurl.NOSIGNAL, 1)
if request.prepare_curl_callback is not None:
request.prepare_curl_callback(curl)
def _curl_header_callback(self, headers, header_callback, header_line):
header_line = native_str(header_line.decode('latin1'))
if header_callback is not None:
self.io_loop.add_callback(header_callback, header_line)
# header_line as returned by curl includes the end-of-line characters.
# whitespace at the start should be preserved to allow multi-line headers
header_line = header_line.rstrip()
if header_line.startswith("HTTP/"):
headers.clear()
try:
(__, __, reason) = httputil.parse_response_start_line(header_line)
header_line = "X-Http-Reason: %s" % reason
except httputil.HTTPInputError:
return
if not header_line:
return
headers.parse_line(header_line)
def _curl_debug(self, debug_type, debug_msg):
debug_types = ('I', '<', '>', '<', '>')
if debug_type == 0:
debug_msg = native_str(debug_msg)
curl_log.debug('%s', debug_msg.strip())
elif debug_type in (1, 2):
debug_msg = native_str(debug_msg)
for line in debug_msg.splitlines():
curl_log.debug('%s %s', debug_types[debug_type], line)
elif debug_type == 4:
curl_log.debug('%s %r', debug_types[debug_type], debug_msg)
class CurlError(HTTPError):
def __init__(self, errno, message):
HTTPError.__init__(self, 599, message)
self.errno = errno
if __name__ == "__main__":
AsyncHTTPClient.configure(CurlAsyncHTTPClient)
main()
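# Typical configuration (a sketch): select the curl-based client before
# any AsyncHTTPClient instances are created:
#
#     from tornado.httpclient import AsyncHTTPClient
#     AsyncHTTPClient.configure("tornado.curl_httpclient.CurlAsyncHTTPClient")
#     client = AsyncHTTPClient()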

399
lib/tornado/escape.py Executable file

@@ -0,0 +1,399 @@
#
# Copyright 2009 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""Escaping/unescaping methods for HTML, JSON, URLs, and others.
Also includes a few other miscellaneous string manipulation functions that
have crept in over time.
"""
from __future__ import absolute_import, division, print_function
import json
import re
from tornado.util import PY3, unicode_type, basestring_type
if PY3:
from urllib.parse import parse_qs as _parse_qs
import html.entities as htmlentitydefs
import urllib.parse as urllib_parse
unichr = chr
else:
from urlparse import parse_qs as _parse_qs
import htmlentitydefs
import urllib as urllib_parse
try:
import typing # noqa
except ImportError:
pass
_XHTML_ESCAPE_RE = re.compile('[&<>"\']')
_XHTML_ESCAPE_DICT = {'&': '&amp;', '<': '&lt;', '>': '&gt;', '"': '&quot;',
'\'': '&#39;'}
def xhtml_escape(value):
"""Escapes a string so it is valid within HTML or XML.
Escapes the characters ``<``, ``>``, ``"``, ``'``, and ``&``.
When used in attribute values the escaped strings must be enclosed
in quotes.
.. versionchanged:: 3.2
Added the single quote to the list of escaped characters.
"""
return _XHTML_ESCAPE_RE.sub(lambda match: _XHTML_ESCAPE_DICT[match.group(0)],
to_basestring(value))
def xhtml_unescape(value):
"""Un-escapes an XML-escaped string."""
return re.sub(r"&(#?)(\w+?);", _convert_entity, _unicode(value))
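# For example (a sketch):
#
#     xhtml_escape('a<b & "c"')  # -> 'a&lt;b &amp; &quot;c&quot;'
#     xhtml_unescape('a&lt;b')   # -> 'a<b'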
# The fact that json_encode wraps json.dumps is an implementation detail.
# Please see https://github.com/tornadoweb/tornado/pull/706
# before sending a pull request that adds **kwargs to this function.
def json_encode(value):
"""JSON-encodes the given Python object."""
# JSON permits but does not require forward slashes to be escaped.
# This is useful when json data is emitted in a <script> tag
# in HTML, as it prevents </script> tags from prematurely terminating
# the javascript. Some json libraries do this escaping by default,
# although python's standard library does not, so we do it here.
# http://stackoverflow.com/questions/1580647/json-why-are-forward-slashes-escaped
return json.dumps(value).replace("</", "<\\/")
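# For example (a sketch), note the escaped forward slash in "</script>":
#
#     json_encode({"html": "</script>"})  # -> '{"html": "<\\/script>"}'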
def json_decode(value):
"""Returns Python objects for the given JSON string."""
return json.loads(to_basestring(value))
def squeeze(value):
"""Replace all sequences of whitespace chars with a single space."""
return re.sub(r"[\x00-\x20]+", " ", value).strip()
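# For example (a sketch):
#
#     squeeze("  a \t\n b  ")  # -> 'a b'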
def url_escape(value, plus=True):
"""Returns a URL-encoded version of the given value.
If ``plus`` is true (the default), spaces will be represented
as "+" instead of "%20". This is appropriate for query strings
but not for the path component of a URL. Note that this default
is the reverse of Python's urllib module.
.. versionadded:: 3.1
The ``plus`` argument
"""
quote = urllib_parse.quote_plus if plus else urllib_parse.quote
return quote(utf8(value))
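# For example (a sketch):
#
#     url_escape("a b&c")              # -> 'a+b%26c'
#     url_escape("a b&c", plus=False)  # -> 'a%20b%26c'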
# python 3 changed things around enough that we need two separate
# implementations of url_unescape. We also need our own implementation
# of parse_qs since python 3's version insists on decoding everything.
if not PY3:
def url_unescape(value, encoding='utf-8', plus=True):
"""Decodes the given value from a URL.
The argument may be either a byte or unicode string.
If encoding is None, the result will be a byte string. Otherwise,
the result is a unicode string in the specified encoding.
If ``plus`` is true (the default), plus signs will be interpreted
as spaces (literal plus signs must be represented as "%2B"). This
is appropriate for query strings and form-encoded values but not
for the path component of a URL. Note that this default is the
reverse of Python's urllib module.
.. versionadded:: 3.1
The ``plus`` argument
"""
unquote = (urllib_parse.unquote_plus if plus else urllib_parse.unquote)
if encoding is None:
return unquote(utf8(value))
else:
return unicode_type(unquote(utf8(value)), encoding)
parse_qs_bytes = _parse_qs
else:
def url_unescape(value, encoding='utf-8', plus=True):
"""Decodes the given value from a URL.
The argument may be either a byte or unicode string.
If encoding is None, the result will be a byte string. Otherwise,
the result is a unicode string in the specified encoding.
If ``plus`` is true (the default), plus signs will be interpreted
as spaces (literal plus signs must be represented as "%2B"). This
is appropriate for query strings and form-encoded values but not
for the path component of a URL. Note that this default is the
reverse of Python's urllib module.
.. versionadded:: 3.1
The ``plus`` argument
"""
if encoding is None:
if plus:
# unquote_to_bytes doesn't have a _plus variant
value = to_basestring(value).replace('+', ' ')
return urllib_parse.unquote_to_bytes(value)
else:
unquote = (urllib_parse.unquote_plus if plus
else urllib_parse.unquote)
return unquote(to_basestring(value), encoding=encoding)
def parse_qs_bytes(qs, keep_blank_values=False, strict_parsing=False):
"""Parses a query string like urlparse.parse_qs, but returns the
values as byte strings.
Keys still become type str (interpreted as latin1 in python3!)
because it's too painful to keep them as byte strings in
python3 and in practice they're nearly always ascii anyway.
"""
# This is gross, but python3 doesn't give us another way.
# Latin1 is the universal donor of character encodings.
result = _parse_qs(qs, keep_blank_values, strict_parsing,
encoding='latin1', errors='strict')
encoded = {}
for k, v in result.items():
encoded[k] = [i.encode('latin1') for i in v]
return encoded
_UTF8_TYPES = (bytes, type(None))
def utf8(value):
# type: (typing.Union[bytes,unicode_type,None])->typing.Union[bytes,None]
"""Converts a string argument to a byte string.
If the argument is already a byte string or None, it is returned unchanged.
Otherwise it must be a unicode string and is encoded as utf8.
"""
if isinstance(value, _UTF8_TYPES):
return value
if not isinstance(value, unicode_type):
raise TypeError(
"Expected bytes, unicode, or None; got %r" % type(value)
)
return value.encode("utf-8")
_TO_UNICODE_TYPES = (unicode_type, type(None))
def to_unicode(value):
"""Converts a string argument to a unicode string.
If the argument is already a unicode string or None, it is returned
unchanged. Otherwise it must be a byte string and is decoded as utf8.
"""
if isinstance(value, _TO_UNICODE_TYPES):
return value
if not isinstance(value, bytes):
raise TypeError(
"Expected bytes, unicode, or None; got %r" % type(value)
)
return value.decode("utf-8")
# to_unicode was previously named _unicode not because it was private,
# but to avoid conflicts with the built-in unicode() function/type
_unicode = to_unicode
# When dealing with the standard library across python 2 and 3 it is
# sometimes useful to have a direct conversion to the native string type
if str is unicode_type:
native_str = to_unicode
else:
native_str = utf8
_BASESTRING_TYPES = (basestring_type, type(None))
def to_basestring(value):
"""Converts a string argument to a subclass of basestring.
In python2, byte and unicode strings are mostly interchangeable,
so functions that deal with a user-supplied argument in combination
with ascii string constants can use either and should return the type
the user supplied. In python3, the two types are not interchangeable,
so this method is needed to convert byte strings to unicode.
"""
if isinstance(value, _BASESTRING_TYPES):
return value
if not isinstance(value, bytes):
raise TypeError(
"Expected bytes, unicode, or None; got %r" % type(value)
)
return value.decode("utf-8")
def recursive_unicode(obj):
"""Walks a simple data structure, converting byte strings to unicode.
Supports lists, tuples, and dictionaries.
"""
if isinstance(obj, dict):
return dict((recursive_unicode(k), recursive_unicode(v)) for (k, v) in obj.items())
elif isinstance(obj, list):
return list(recursive_unicode(i) for i in obj)
elif isinstance(obj, tuple):
return tuple(recursive_unicode(i) for i in obj)
elif isinstance(obj, bytes):
return to_unicode(obj)
else:
return obj
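# Illustrative sketch: byte strings are decoded recursively, while other
# leaf values (here the int) are left alone.
def _example_recursive_unicode():
    data = {b"key": [b"v1", (b"v2", 3)]}
    assert recursive_unicode(data) == {u"key": [u"v1", (u"v2", 3)]}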
# I originally used the regex from
# http://daringfireball.net/2010/07/improved_regex_for_matching_urls
# but it gets all exponential on certain patterns (such as too many trailing
# dots), causing the regex matcher to never return.
# This regex should avoid those problems.
# Use to_unicode instead of tornado.util.u - we don't want backslashes getting
# processed as escapes.
_URL_RE = re.compile(to_unicode(
r"""\b((?:([\w-]+):(/{1,3})|www[.])(?:(?:(?:[^\s&()]|&amp;|&quot;)*(?:[^!"#$%&'()*+,.:;<=>?@\[\]^`{|}~\s]))|(?:\((?:[^\s&()]|&amp;|&quot;)*\)))+)""" # noqa: E501
))
def linkify(text, shorten=False, extra_params="",
require_protocol=False, permitted_protocols=["http", "https"]):
"""Converts plain text into HTML with links.
For example: ``linkify("Hello http://tornadoweb.org!")`` would return
``Hello <a href="http://tornadoweb.org">http://tornadoweb.org</a>!``
Parameters:
* ``shorten``: Long urls will be shortened for display.
* ``extra_params``: Extra text to include in the link tag, or a callable
taking the link as an argument and returning the extra text
e.g. ``linkify(text, extra_params='rel="nofollow" class="external"')``,
or::
def extra_params_cb(url):
if url.startswith("http://example.com"):
return 'class="internal"'
else:
return 'class="external" rel="nofollow"'
linkify(text, extra_params=extra_params_cb)
* ``require_protocol``: Only linkify urls which include a protocol. If
this is False, urls such as www.facebook.com will also be linkified.
* ``permitted_protocols``: List (or set) of protocols which should be
linkified, e.g. ``linkify(text, permitted_protocols=["http", "ftp",
"mailto"])``. It is very unsafe to include protocols such as
``javascript``.
"""
if extra_params and not callable(extra_params):
extra_params = " " + extra_params.strip()
def make_link(m):
url = m.group(1)
proto = m.group(2)
if require_protocol and not proto:
return url  # no protocol, no linkify
if proto and proto not in permitted_protocols:
return url # bad protocol, no linkify
href = m.group(1)
if not proto:
href = "http://" + href # no proto specified, use http
if callable(extra_params):
params = " " + extra_params(href).strip()
else:
params = extra_params
# clip long urls. max_len is just an approximation
max_len = 30
if shorten and len(url) > max_len:
before_clip = url
if proto:
proto_len = len(proto) + 1 + len(m.group(3) or "") # +1 for :
else:
proto_len = 0
parts = url[proto_len:].split("/")
if len(parts) > 1:
# Grab the whole host part plus the first bit of the path
# The path is usually not that interesting once shortened
# (no more slug, etc), so it really just provides a little
# extra indication of shortening.
url = url[:proto_len] + parts[0] + "/" + \
parts[1][:8].split('?')[0].split('.')[0]
if len(url) > max_len * 1.5: # still too long
url = url[:max_len]
if url != before_clip:
amp = url.rfind('&')
# avoid splitting html char entities
if amp > max_len - 5:
url = url[:amp]
url += "..."
if len(url) >= len(before_clip):
url = before_clip
else:
# full url is visible on mouse-over (for those who don't
# have a status bar, such as Safari by default)
params += ' title="%s"' % href
return u'<a href="%s"%s>%s</a>' % (href, params, url)
# First HTML-escape so that our strings are all safe.
# The regex is modified to avoid character entities other than &amp; so
# that we won't pick up &quot;, etc.
text = _unicode(xhtml_escape(text))
return _URL_RE.sub(make_link, text)
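# Illustrative sketch of the callable extra_params form from the docstring
# (the URLs and class names are made up):
def _example_linkify():
    def extra(url):
        if url.startswith("http://example.com"):
            return 'class="internal"'
        return 'class="external" rel="nofollow"'
    return linkify("See http://example.com/docs and http://tornadoweb.org",
                   extra_params=extra)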
def _convert_entity(m):
if m.group(1) == "#":
try:
if m.group(2)[:1].lower() == 'x':
return unichr(int(m.group(2)[1:], 16))
else:
return unichr(int(m.group(2)))
except ValueError:
return "&#%s;" % m.group(2)
try:
return _HTML_UNICODE_MAP[m.group(2)]
except KeyError:
return "&%s;" % m.group(2)
def _build_unicode_map():
unicode_map = {}
for name, value in htmlentitydefs.name2codepoint.items():
unicode_map[name] = unichr(value)
return unicode_map
_HTML_UNICODE_MAP = _build_unicode_map()

1367
lib/tornado/gen.py Executable file

File diff suppressed because it is too large

751
lib/tornado/http1connection.py Executable file

@@ -0,0 +1,751 @@
#
# Copyright 2014 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""Client and server implementations of HTTP/1.x.
.. versionadded:: 4.0
"""
from __future__ import absolute_import, division, print_function
import re
import warnings
from tornado.concurrent import (Future, future_add_done_callback,
future_set_result_unless_cancelled)
from tornado.escape import native_str, utf8
from tornado import gen
from tornado import httputil
from tornado import iostream
from tornado.log import gen_log, app_log
from tornado import stack_context
from tornado.util import GzipDecompressor, PY3
class _QuietException(Exception):
def __init__(self):
pass
class _ExceptionLoggingContext(object):
"""Used with the ``with`` statement when calling delegate methods to
log any exceptions with the given logger. Any exceptions caught are
converted to _QuietException.
"""
def __init__(self, logger):
self.logger = logger
def __enter__(self):
pass
def __exit__(self, typ, value, tb):
if value is not None:
self.logger.error("Uncaught exception", exc_info=(typ, value, tb))
raise _QuietException
class HTTP1ConnectionParameters(object):
"""Parameters for `.HTTP1Connection` and `.HTTP1ServerConnection`.
"""
def __init__(self, no_keep_alive=False, chunk_size=None,
max_header_size=None, header_timeout=None, max_body_size=None,
body_timeout=None, decompress=False):
"""
:arg bool no_keep_alive: If true, always close the connection after
one request.
:arg int chunk_size: how much data to read into memory at once
:arg int max_header_size: maximum amount of data for HTTP headers
:arg float header_timeout: how long to wait for all headers (seconds)
:arg int max_body_size: maximum amount of data for body
:arg float body_timeout: how long to wait while reading body (seconds)
:arg bool decompress: if true, decode incoming
``Content-Encoding: gzip``
"""
self.no_keep_alive = no_keep_alive
self.chunk_size = chunk_size or 65536
self.max_header_size = max_header_size or 65536
self.header_timeout = header_timeout
self.max_body_size = max_body_size
self.body_timeout = body_timeout
self.decompress = decompress
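# Illustrative parameter set (the values are arbitrary): cap headers at
# 16 KiB and bodies at 1 MiB, and transparently gunzip incoming bodies.
def _example_connection_params():
    return HTTP1ConnectionParameters(
        chunk_size=16 * 1024,
        max_header_size=16 * 1024,
        max_body_size=1024 * 1024,
        body_timeout=60.0,
        decompress=True)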
class HTTP1Connection(httputil.HTTPConnection):
"""Implements the HTTP/1.x protocol.
This class can be used on its own for clients, or via `HTTP1ServerConnection`
for servers.
"""
def __init__(self, stream, is_client, params=None, context=None):
"""
:arg stream: an `.IOStream`
:arg bool is_client: client or server
:arg params: a `.HTTP1ConnectionParameters` instance or ``None``
:arg context: an opaque application-defined object that can be accessed
as ``connection.context``.
"""
self.is_client = is_client
self.stream = stream
if params is None:
params = HTTP1ConnectionParameters()
self.params = params
self.context = context
self.no_keep_alive = params.no_keep_alive
# The body limits can be altered by the delegate, so save them
# here instead of just referencing self.params later.
self._max_body_size = (self.params.max_body_size or
self.stream.max_buffer_size)
self._body_timeout = self.params.body_timeout
# _write_finished is set to True when finish() has been called,
# i.e. there will be no more data sent. Data may still be in the
# stream's write buffer.
self._write_finished = False
# True when we have read the entire incoming body.
self._read_finished = False
# _finish_future resolves when all data has been written and flushed
# to the IOStream.
self._finish_future = Future()
# If true, the connection should be closed after this request
# (after the response has been written in the server side,
# and after it has been read in the client)
self._disconnect_on_finish = False
self._clear_callbacks()
# Save the start lines after we read or write them; they
# affect later processing (e.g. 304 responses and HEAD methods
# have content-length but no bodies)
self._request_start_line = None
self._response_start_line = None
self._request_headers = None
# True if we are writing output with chunked encoding.
self._chunking_output = None
# While reading a body with a content-length, this is the
# amount left to read.
self._expected_content_remaining = None
# A Future for our outgoing writes, returned by IOStream.write.
self._pending_write = None
def read_response(self, delegate):
"""Read a single HTTP response.
Typical client-mode usage is to write a request using `write_headers`,
`write`, and `finish`, and then call ``read_response``.
:arg delegate: a `.HTTPMessageDelegate`
Returns a `.Future` that resolves to None after the full response has
been read.
"""
if self.params.decompress:
delegate = _GzipMessageDelegate(delegate, self.params.chunk_size)
return self._read_message(delegate)
@gen.coroutine
def _read_message(self, delegate):
need_delegate_close = False
try:
header_future = self.stream.read_until_regex(
b"\r?\n\r?\n",
max_bytes=self.params.max_header_size)
if self.params.header_timeout is None:
header_data = yield header_future
else:
try:
header_data = yield gen.with_timeout(
self.stream.io_loop.time() + self.params.header_timeout,
header_future,
quiet_exceptions=iostream.StreamClosedError)
except gen.TimeoutError:
self.close()
raise gen.Return(False)
start_line, headers = self._parse_headers(header_data)
if self.is_client:
start_line = httputil.parse_response_start_line(start_line)
self._response_start_line = start_line
else:
start_line = httputil.parse_request_start_line(start_line)
self._request_start_line = start_line
self._request_headers = headers
self._disconnect_on_finish = not self._can_keep_alive(
start_line, headers)
need_delegate_close = True
with _ExceptionLoggingContext(app_log):
header_future = delegate.headers_received(start_line, headers)
if header_future is not None:
yield header_future
if self.stream is None:
# We've been detached.
need_delegate_close = False
raise gen.Return(False)
skip_body = False
if self.is_client:
if (self._request_start_line is not None and
self._request_start_line.method == 'HEAD'):
skip_body = True
code = start_line.code
if code == 304:
# 304 responses may include the content-length header
# but do not actually have a body.
# http://tools.ietf.org/html/rfc7230#section-3.3
skip_body = True
if code >= 100 and code < 200:
# 1xx responses should never indicate the presence of
# a body.
if ('Content-Length' in headers or
'Transfer-Encoding' in headers):
raise httputil.HTTPInputError(
"Response code %d cannot have body" % code)
# TODO: client delegates will get headers_received twice
# in the case of a 100-continue. Document or change?
yield self._read_message(delegate)
else:
if (headers.get("Expect") == "100-continue" and
not self._write_finished):
self.stream.write(b"HTTP/1.1 100 (Continue)\r\n\r\n")
if not skip_body:
body_future = self._read_body(
start_line.code if self.is_client else 0, headers, delegate)
if body_future is not None:
if self._body_timeout is None:
yield body_future
else:
try:
yield gen.with_timeout(
self.stream.io_loop.time() + self._body_timeout,
body_future,
quiet_exceptions=iostream.StreamClosedError)
except gen.TimeoutError:
gen_log.info("Timeout reading body from %s",
self.context)
self.stream.close()
raise gen.Return(False)
self._read_finished = True
if not self._write_finished or self.is_client:
need_delegate_close = False
with _ExceptionLoggingContext(app_log):
delegate.finish()
# If we're waiting for the application to produce an asynchronous
# response, and we're not detached, register a close callback
# on the stream (we didn't need one while we were reading)
if (not self._finish_future.done() and
self.stream is not None and
not self.stream.closed()):
self.stream.set_close_callback(self._on_connection_close)
yield self._finish_future
if self.is_client and self._disconnect_on_finish:
self.close()
if self.stream is None:
raise gen.Return(False)
except httputil.HTTPInputError as e:
gen_log.info("Malformed HTTP message from %s: %s",
self.context, e)
if not self.is_client:
yield self.stream.write(b'HTTP/1.1 400 Bad Request\r\n\r\n')
self.close()
raise gen.Return(False)
finally:
if need_delegate_close:
with _ExceptionLoggingContext(app_log):
delegate.on_connection_close()
header_future = None
self._clear_callbacks()
raise gen.Return(True)
def _clear_callbacks(self):
"""Clears the callback attributes.
This allows the request handler to be garbage collected more
quickly in CPython by breaking up reference cycles.
"""
self._write_callback = None
self._write_future = None
self._close_callback = None
if self.stream is not None:
self.stream.set_close_callback(None)
def set_close_callback(self, callback):
"""Sets a callback that will be run when the connection is closed.
Note that this callback is slightly different from
`.HTTPMessageDelegate.on_connection_close`: The
`.HTTPMessageDelegate` method is called when the connection is
closed while receiving a message. This callback is used when
there is not an active delegate (for example, on the server
side this callback is used if the client closes the connection
after sending its request but before receiving the entire
response).
"""
self._close_callback = stack_context.wrap(callback)
def _on_connection_close(self):
# Note that this callback is only registered on the IOStream
# when we have finished reading the request and are waiting for
# the application to produce its response.
if self._close_callback is not None:
callback = self._close_callback
self._close_callback = None
callback()
if not self._finish_future.done():
future_set_result_unless_cancelled(self._finish_future, None)
self._clear_callbacks()
def close(self):
if self.stream is not None:
self.stream.close()
self._clear_callbacks()
if not self._finish_future.done():
future_set_result_unless_cancelled(self._finish_future, None)
def detach(self):
"""Take control of the underlying stream.
Returns the underlying `.IOStream` object and stops all further
HTTP processing. May only be called during
`.HTTPMessageDelegate.headers_received`. Intended for implementing
protocols like websockets that tunnel over an HTTP handshake.
"""
self._clear_callbacks()
stream = self.stream
self.stream = None
if not self._finish_future.done():
future_set_result_unless_cancelled(self._finish_future, None)
return stream
def set_body_timeout(self, timeout):
"""Sets the body timeout for a single request.
Overrides the value from `.HTTP1ConnectionParameters`.
"""
self._body_timeout = timeout
def set_max_body_size(self, max_body_size):
"""Sets the body size limit for a single request.
Overrides the value from `.HTTP1ConnectionParameters`.
"""
self._max_body_size = max_body_size
def write_headers(self, start_line, headers, chunk=None, callback=None):
"""Implements `.HTTPConnection.write_headers`."""
lines = []
if self.is_client:
self._request_start_line = start_line
lines.append(utf8('%s %s HTTP/1.1' % (start_line[0], start_line[1])))
# Client requests with a non-empty body must have either a
# Content-Length or a Transfer-Encoding.
self._chunking_output = (
start_line.method in ('POST', 'PUT', 'PATCH') and
'Content-Length' not in headers and
'Transfer-Encoding' not in headers)
else:
self._response_start_line = start_line
lines.append(utf8('HTTP/1.1 %d %s' % (start_line[1], start_line[2])))
self._chunking_output = (
# TODO: should this use
# self._request_start_line.version or
# start_line.version?
self._request_start_line.version == 'HTTP/1.1' and
# 1xx, 204 and 304 responses have no body (not even a zero-length
# body), and so should not have either Content-Length or
# Transfer-Encoding headers.
start_line.code not in (204, 304) and
(start_line.code < 100 or start_line.code >= 200) and
# No need to chunk the output if a Content-Length is specified.
'Content-Length' not in headers and
# Applications are discouraged from touching Transfer-Encoding,
# but if they do, leave it alone.
'Transfer-Encoding' not in headers)
# If connection to a 1.1 client will be closed, inform client
if (self._request_start_line.version == 'HTTP/1.1' and self._disconnect_on_finish):
headers['Connection'] = 'close'
# If a 1.0 client asked for keep-alive, add the header.
if (self._request_start_line.version == 'HTTP/1.0' and
self._request_headers.get('Connection', '').lower() == 'keep-alive'):
headers['Connection'] = 'Keep-Alive'
if self._chunking_output:
headers['Transfer-Encoding'] = 'chunked'
if (not self.is_client and
(self._request_start_line.method == 'HEAD' or
start_line.code == 304)):
self._expected_content_remaining = 0
elif 'Content-Length' in headers:
self._expected_content_remaining = int(headers['Content-Length'])
else:
self._expected_content_remaining = None
# TODO: headers are supposed to be of type str, but we still have some
# cases that let bytes slip through. Remove these native_str calls when those
# are fixed.
header_lines = (native_str(n) + ": " + native_str(v) for n, v in headers.get_all())
if PY3:
lines.extend(l.encode('latin1') for l in header_lines)
else:
lines.extend(header_lines)
for line in lines:
if b'\n' in line:
raise ValueError('Newline in header: ' + repr(line))
future = None
if self.stream.closed():
future = self._write_future = Future()
future.set_exception(iostream.StreamClosedError())
future.exception()
else:
if callback is not None:
warnings.warn("callback argument is deprecated, use returned Future instead",
DeprecationWarning)
self._write_callback = stack_context.wrap(callback)
else:
future = self._write_future = Future()
data = b"\r\n".join(lines) + b"\r\n\r\n"
if chunk:
data += self._format_chunk(chunk)
self._pending_write = self.stream.write(data)
future_add_done_callback(self._pending_write, self._on_write_complete)
return future
def _format_chunk(self, chunk):
if self._expected_content_remaining is not None:
self._expected_content_remaining -= len(chunk)
if self._expected_content_remaining < 0:
# Close the stream now to stop further framing errors.
self.stream.close()
raise httputil.HTTPOutputError(
"Tried to write more data than Content-Length")
if self._chunking_output and chunk:
# Don't write out empty chunks because that means END-OF-STREAM
# with chunked encoding
return utf8("%x" % len(chunk)) + b"\r\n" + chunk + b"\r\n"
else:
return chunk
def write(self, chunk, callback=None):
"""Implements `.HTTPConnection.write`.
For backwards compatibility it is allowed but deprecated to
skip `write_headers` and instead call `write()` with a
pre-encoded header block.
"""
future = None
if self.stream.closed():
future = self._write_future = Future()
self._write_future.set_exception(iostream.StreamClosedError())
self._write_future.exception()
else:
if callback is not None:
warnings.warn("callback argument is deprecated, use returned Future instead",
DeprecationWarning)
self._write_callback = stack_context.wrap(callback)
else:
future = self._write_future = Future()
self._pending_write = self.stream.write(self._format_chunk(chunk))
self._pending_write.add_done_callback(self._on_write_complete)
return future
def finish(self):
"""Implements `.HTTPConnection.finish`."""
if (self._expected_content_remaining is not None and
self._expected_content_remaining != 0 and
not self.stream.closed()):
self.stream.close()
raise httputil.HTTPOutputError(
"Tried to write %d bytes less than Content-Length" %
self._expected_content_remaining)
if self._chunking_output:
if not self.stream.closed():
self._pending_write = self.stream.write(b"0\r\n\r\n")
self._pending_write.add_done_callback(self._on_write_complete)
self._write_finished = True
# If the app finished the request while we're still reading,
# divert any remaining data away from the delegate and
# close the connection when we're done sending our response.
# Closing the connection is the only way to avoid reading the
# whole input body.
if not self._read_finished:
self._disconnect_on_finish = True
# No more data is coming, so instruct TCP to send any remaining
# data immediately instead of waiting for a full packet or ack.
self.stream.set_nodelay(True)
if self._pending_write is None:
self._finish_request(None)
else:
future_add_done_callback(self._pending_write, self._finish_request)
def _on_write_complete(self, future):
exc = future.exception()
if exc is not None and not isinstance(exc, iostream.StreamClosedError):
future.result()
if self._write_callback is not None:
callback = self._write_callback
self._write_callback = None
self.stream.io_loop.add_callback(callback)
if self._write_future is not None:
future = self._write_future
self._write_future = None
future_set_result_unless_cancelled(future, None)
def _can_keep_alive(self, start_line, headers):
if self.params.no_keep_alive:
return False
connection_header = headers.get("Connection")
if connection_header is not None:
connection_header = connection_header.lower()
if start_line.version == "HTTP/1.1":
return connection_header != "close"
elif ("Content-Length" in headers or
headers.get("Transfer-Encoding", "").lower() == "chunked" or
getattr(start_line, 'method', None) in ("HEAD", "GET")):
# start_line may be a request or response start line; only
# the former has a method attribute.
return connection_header == "keep-alive"
return False
def _finish_request(self, future):
self._clear_callbacks()
if not self.is_client and self._disconnect_on_finish:
self.close()
return
# Turn Nagle's algorithm back on, leaving the stream in its
# default state for the next request.
self.stream.set_nodelay(False)
if not self._finish_future.done():
future_set_result_unless_cancelled(self._finish_future, None)
def _parse_headers(self, data):
# The lstrip removes newlines that some implementations sometimes
# insert between messages of a reused connection. Per RFC 7230,
# we SHOULD ignore at least one empty line before the request.
# http://tools.ietf.org/html/rfc7230#section-3.5
data = native_str(data.decode('latin1')).lstrip("\r\n")
# RFC 7230 allows for both CRLF and bare LF as the line terminator.
eol = data.find("\n")
start_line = data[:eol].rstrip("\r")
headers = httputil.HTTPHeaders.parse(data[eol:])
return start_line, headers
def _read_body(self, code, headers, delegate):
if "Content-Length" in headers:
if "Transfer-Encoding" in headers:
# Response cannot contain both Content-Length and
# Transfer-Encoding headers.
# http://tools.ietf.org/html/rfc7230#section-3.3.3
raise httputil.HTTPInputError(
"Response with both Transfer-Encoding and Content-Length")
if "," in headers["Content-Length"]:
# Proxies sometimes cause Content-Length headers to get
# duplicated. If all the values are identical then we can
# use them but if they differ it's an error.
pieces = re.split(r',\s*', headers["Content-Length"])
if any(i != pieces[0] for i in pieces):
raise httputil.HTTPInputError(
"Multiple unequal Content-Lengths: %r" %
headers["Content-Length"])
headers["Content-Length"] = pieces[0]
try:
content_length = int(headers["Content-Length"])
except ValueError:
# Handles non-integer Content-Length value.
raise httputil.HTTPInputError(
"Only integer Content-Length is allowed: %s" % headers["Content-Length"])
if content_length > self._max_body_size:
raise httputil.HTTPInputError("Content-Length too long")
else:
content_length = None
if code == 204:
# This response code is not allowed to have a non-empty body,
# and has an implicit length of zero instead of read-until-close.
# http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.3
if ("Transfer-Encoding" in headers or
content_length not in (None, 0)):
raise httputil.HTTPInputError(
"Response with code %d should not have body" % code)
content_length = 0
if content_length is not None:
return self._read_fixed_body(content_length, delegate)
if headers.get("Transfer-Encoding", "").lower() == "chunked":
return self._read_chunked_body(delegate)
if self.is_client:
return self._read_body_until_close(delegate)
return None
@gen.coroutine
def _read_fixed_body(self, content_length, delegate):
while content_length > 0:
body = yield self.stream.read_bytes(
min(self.params.chunk_size, content_length), partial=True)
content_length -= len(body)
if not self._write_finished or self.is_client:
with _ExceptionLoggingContext(app_log):
ret = delegate.data_received(body)
if ret is not None:
yield ret
@gen.coroutine
def _read_chunked_body(self, delegate):
# TODO: "chunk extensions" http://tools.ietf.org/html/rfc2616#section-3.6.1
total_size = 0
while True:
chunk_len = yield self.stream.read_until(b"\r\n", max_bytes=64)
chunk_len = int(chunk_len.strip(), 16)
if chunk_len == 0:
crlf = yield self.stream.read_bytes(2)
if crlf != b'\r\n':
raise httputil.HTTPInputError("improperly terminated chunked request")
return
total_size += chunk_len
if total_size > self._max_body_size:
raise httputil.HTTPInputError("chunked body too large")
bytes_to_read = chunk_len
while bytes_to_read:
chunk = yield self.stream.read_bytes(
min(bytes_to_read, self.params.chunk_size), partial=True)
bytes_to_read -= len(chunk)
if not self._write_finished or self.is_client:
with _ExceptionLoggingContext(app_log):
ret = delegate.data_received(chunk)
if ret is not None:
yield ret
# chunk ends with \r\n
crlf = yield self.stream.read_bytes(2)
assert crlf == b"\r\n"
@gen.coroutine
def _read_body_until_close(self, delegate):
body = yield self.stream.read_until_close()
if not self._write_finished or self.is_client:
with _ExceptionLoggingContext(app_log):
delegate.data_received(body)
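# Client-mode sketch of the write_headers/finish/read_response sequence
# described in read_response's docstring. Illustrative only: `stream` must
# be an already-connected IOStream and `delegate` an
# httputil.HTTPMessageDelegate implementation.
@gen.coroutine
def _example_client_request(stream, delegate):
    conn = HTTP1Connection(stream, is_client=True)
    start_line = httputil.RequestStartLine("GET", "/", "HTTP/1.1")
    headers = httputil.HTTPHeaders({"Host": "example.com"})
    yield conn.write_headers(start_line, headers)
    conn.finish()  # no request body to send
    yield conn.read_response(delegate)  # resolves after the full response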
class _GzipMessageDelegate(httputil.HTTPMessageDelegate):
"""Wraps an `HTTPMessageDelegate` to decode ``Content-Encoding: gzip``.
"""
def __init__(self, delegate, chunk_size):
self._delegate = delegate
self._chunk_size = chunk_size
self._decompressor = None
def headers_received(self, start_line, headers):
if headers.get("Content-Encoding") == "gzip":
self._decompressor = GzipDecompressor()
# Downstream delegates will only see uncompressed data,
# so rename the content-encoding header.
# (but note that curl_httpclient doesn't do this).
headers.add("X-Consumed-Content-Encoding",
headers["Content-Encoding"])
del headers["Content-Encoding"]
return self._delegate.headers_received(start_line, headers)
@gen.coroutine
def data_received(self, chunk):
if self._decompressor:
compressed_data = chunk
while compressed_data:
decompressed = self._decompressor.decompress(
compressed_data, self._chunk_size)
if decompressed:
ret = self._delegate.data_received(decompressed)
if ret is not None:
yield ret
compressed_data = self._decompressor.unconsumed_tail
else:
ret = self._delegate.data_received(chunk)
if ret is not None:
yield ret
def finish(self):
if self._decompressor is not None:
tail = self._decompressor.flush()
if tail:
# I believe the tail will always be empty (i.e.
# decompress will return all it can). The purpose
# of the flush call is to detect errors such
# as truncated input. But in case it ever returns
# anything, treat it as an extra chunk
self._delegate.data_received(tail)
return self._delegate.finish()
def on_connection_close(self):
return self._delegate.on_connection_close()
class HTTP1ServerConnection(object):
"""An HTTP/1.x server."""
def __init__(self, stream, params=None, context=None):
"""
:arg stream: an `.IOStream`
:arg params: a `.HTTP1ConnectionParameters` or None
:arg context: an opaque application-defined object that is accessible
as ``connection.context``
"""
self.stream = stream
if params is None:
params = HTTP1ConnectionParameters()
self.params = params
self.context = context
self._serving_future = None
@gen.coroutine
def close(self):
"""Closes the connection.
Returns a `.Future` that resolves after the serving loop has exited.
"""
self.stream.close()
# Block until the serving loop is done, but ignore any exceptions
# (start_serving is already responsible for logging them).
try:
yield self._serving_future
except Exception:
pass
def start_serving(self, delegate):
"""Starts serving requests on this connection.
:arg delegate: a `.HTTPServerConnectionDelegate`
"""
assert isinstance(delegate, httputil.HTTPServerConnectionDelegate)
self._serving_future = self._server_request_loop(delegate)
# Register the future on the IOLoop so its errors get logged.
self.stream.io_loop.add_future(self._serving_future,
lambda f: f.result())
@gen.coroutine
def _server_request_loop(self, delegate):
try:
while True:
conn = HTTP1Connection(self.stream, False,
self.params, self.context)
request_delegate = delegate.start_request(self, conn)
try:
ret = yield conn.read_response(request_delegate)
except (iostream.StreamClosedError,
iostream.UnsatisfiableReadError):
return
except _QuietException:
# This exception was already logged.
conn.close()
return
except Exception:
gen_log.error("Uncaught exception", exc_info=True)
conn.close()
return
if not ret:
return
yield gen.moment
finally:
delegate.on_close(self)
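# Server-side sketch (illustrative): hand an accepted IOStream to an
# HTTP1ServerConnection. `app_delegate` stands in for an application's
# httputil.HTTPServerConnectionDelegate (e.g. a tornado.web.Application).
def _example_serve_connection(stream, app_delegate):
    params = HTTP1ConnectionParameters(decompress=True)
    server_conn = HTTP1ServerConnection(stream, params)
    server_conn.start_serving(app_delegate)
    return server_conn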

748
lib/tornado/httpclient.py Executable file

@@ -0,0 +1,748 @@
"""Blocking and non-blocking HTTP client interfaces.
This module defines a common interface shared by two implementations,
``simple_httpclient`` and ``curl_httpclient``. Applications may either
instantiate their chosen implementation class directly or use the
`AsyncHTTPClient` class from this module, which selects an implementation
that can be overridden with the `AsyncHTTPClient.configure` method.
The default implementation is ``simple_httpclient``, and this is expected
to be suitable for most users' needs. However, some applications may wish
to switch to ``curl_httpclient`` for reasons such as the following:
* ``curl_httpclient`` has some features not found in ``simple_httpclient``,
including support for HTTP proxies and the ability to use a specified
network interface.
* ``curl_httpclient`` is more likely to be compatible with sites that are
not-quite-compliant with the HTTP spec, or sites that use little-exercised
features of HTTP.
* ``curl_httpclient`` is faster.
* ``curl_httpclient`` was the default prior to Tornado 2.0.
Note that if you are using ``curl_httpclient``, it is highly
recommended that you use a recent version of ``libcurl`` and
``pycurl``. Currently the minimum supported version of libcurl is
7.22.0, and the minimum version of pycurl is 7.18.2. It is highly
recommended that your ``libcurl`` installation is built with
asynchronous DNS resolver (threaded or c-ares), otherwise you may
encounter various problems with request timeouts (for more
information, see
http://curl.haxx.se/libcurl/c/curl_easy_setopt.html#CURLOPTCONNECTTIMEOUTMS
and comments in curl_httpclient.py).
To select ``curl_httpclient``, call `AsyncHTTPClient.configure` at startup::
AsyncHTTPClient.configure("tornado.curl_httpclient.CurlAsyncHTTPClient")
"""
from __future__ import absolute_import, division, print_function
import functools
import time
import warnings
import weakref
from tornado.concurrent import Future, future_set_result_unless_cancelled
from tornado.escape import utf8, native_str
from tornado import gen, httputil, stack_context
from tornado.ioloop import IOLoop
from tornado.util import Configurable
class HTTPClient(object):
"""A blocking HTTP client.
This interface is provided to make it easier to share code between
synchronous and asynchronous applications. Applications that are
running an `.IOLoop` must use `AsyncHTTPClient` instead.
Typical usage looks like this::
http_client = httpclient.HTTPClient()
try:
response = http_client.fetch("http://www.google.com/")
print(response.body)
except httpclient.HTTPError as e:
# HTTPError is raised for non-200 responses; the response
# can be found in e.response.
print("Error: " + str(e))
except Exception as e:
# Other errors are possible, such as IOError.
print("Error: " + str(e))
http_client.close()
.. versionchanged:: 5.0
Due to limitations in `asyncio`, it is no longer possible to
use the synchronous ``HTTPClient`` while an `.IOLoop` is running.
Use `AsyncHTTPClient` instead.
"""
def __init__(self, async_client_class=None, **kwargs):
# Initialize self._closed at the beginning of the constructor
# so that an exception raised here doesn't lead to confusing
# failures in __del__.
self._closed = True
self._io_loop = IOLoop(make_current=False)
if async_client_class is None:
async_client_class = AsyncHTTPClient
# Create the client while our IOLoop is "current", without
# clobbering the thread's real current IOLoop (if any).
self._async_client = self._io_loop.run_sync(
gen.coroutine(lambda: async_client_class(**kwargs)))
self._closed = False
def __del__(self):
self.close()
def close(self):
"""Closes the HTTPClient, freeing any resources used."""
if not self._closed:
self._async_client.close()
self._io_loop.close()
self._closed = True
def fetch(self, request, **kwargs):
"""Executes a request, returning an `HTTPResponse`.
The request may be either a string URL or an `HTTPRequest` object.
If it is a string, we construct an `HTTPRequest` using any additional
kwargs: ``HTTPRequest(request, **kwargs)``
If an error occurs during the fetch, we raise an `HTTPError` unless
the ``raise_error`` keyword argument is set to False.
"""
response = self._io_loop.run_sync(functools.partial(
self._async_client.fetch, request, **kwargs))
return response
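# Blocking-fetch sketch (the URL is illustrative): a string URL plus extra
# keyword arguments is shorthand for constructing an HTTPRequest.
def _example_blocking_fetch():
    client = HTTPClient()
    try:
        response = client.fetch("http://example.com/", request_timeout=10.0)
        return response.body
    finally:
        client.close()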
class AsyncHTTPClient(Configurable):
"""An non-blocking HTTP client.
Example usage::
async def f():
http_client = AsyncHTTPClient()
try:
response = await http_client.fetch("http://www.google.com")
except Exception as e:
print("Error: %s" % e)
else:
print(response.body)
The constructor for this class is magic in several respects: It
actually creates an instance of an implementation-specific
subclass, and instances are reused as a kind of pseudo-singleton
(one per `.IOLoop`). The keyword argument ``force_instance=True``
can be used to suppress this singleton behavior. Unless
``force_instance=True`` is used, no arguments should be passed to
the `AsyncHTTPClient` constructor. The implementation subclass as
well as arguments to its constructor can be set with the static
method `configure()`
All `AsyncHTTPClient` implementations support a ``defaults``
keyword argument, which can be used to set default values for
`HTTPRequest` attributes. For example::
AsyncHTTPClient.configure(
None, defaults=dict(user_agent="MyUserAgent"))
# or with force_instance:
client = AsyncHTTPClient(force_instance=True,
defaults=dict(user_agent="MyUserAgent"))
.. versionchanged:: 5.0
The ``io_loop`` argument (deprecated since version 4.1) has been removed.
"""
@classmethod
def configurable_base(cls):
return AsyncHTTPClient
@classmethod
def configurable_default(cls):
from tornado.simple_httpclient import SimpleAsyncHTTPClient
return SimpleAsyncHTTPClient
@classmethod
def _async_clients(cls):
attr_name = '_async_client_dict_' + cls.__name__
if not hasattr(cls, attr_name):
setattr(cls, attr_name, weakref.WeakKeyDictionary())
return getattr(cls, attr_name)
def __new__(cls, force_instance=False, **kwargs):
io_loop = IOLoop.current()
if force_instance:
instance_cache = None
else:
instance_cache = cls._async_clients()
if instance_cache is not None and io_loop in instance_cache:
return instance_cache[io_loop]
instance = super(AsyncHTTPClient, cls).__new__(cls, **kwargs)
# Make sure the instance knows which cache to remove itself from.
# It can't simply call _async_clients() because we may be in
# __new__(AsyncHTTPClient) but instance.__class__ may be
# SimpleAsyncHTTPClient.
instance._instance_cache = instance_cache
if instance_cache is not None:
instance_cache[instance.io_loop] = instance
return instance
def initialize(self, defaults=None):
self.io_loop = IOLoop.current()
self.defaults = dict(HTTPRequest._DEFAULTS)
if defaults is not None:
self.defaults.update(defaults)
self._closed = False
def close(self):
"""Destroys this HTTP client, freeing any file descriptors used.
This method is **not needed in normal use** due to the way
that `AsyncHTTPClient` objects are transparently reused.
``close()`` is generally only necessary when either the
`.IOLoop` is also being closed, or the ``force_instance=True``
argument was used when creating the `AsyncHTTPClient`.
No other methods may be called on the `AsyncHTTPClient` after
``close()``.
"""
if self._closed:
return
self._closed = True
if self._instance_cache is not None:
if self._instance_cache.get(self.io_loop) is not self:
raise RuntimeError("inconsistent AsyncHTTPClient cache")
del self._instance_cache[self.io_loop]
def fetch(self, request, callback=None, raise_error=True, **kwargs):
"""Executes a request, asynchronously returning an `HTTPResponse`.
The request may be either a string URL or an `HTTPRequest` object.
If it is a string, we construct an `HTTPRequest` using any additional
kwargs: ``HTTPRequest(request, **kwargs)``
This method returns a `.Future` whose result is an
`HTTPResponse`. By default, the ``Future`` will raise an
`HTTPError` if the request returned a non-200 response code
(other errors may also be raised if the server could not be
contacted). Instead, if ``raise_error`` is set to False, the
response will always be returned regardless of the response
code.
If a ``callback`` is given, it will be invoked with the `HTTPResponse`.
In the callback interface, `HTTPError` is not automatically raised.
Instead, you must check the response's ``error`` attribute or
call its `~HTTPResponse.rethrow` method.
.. deprecated:: 5.1
The ``callback`` argument is deprecated and will be removed
in 6.0. Use the returned `.Future` instead.
The ``raise_error=False`` argument currently suppresses
*all* errors, encapsulating them in `HTTPResponse` objects
with a 599 response code. This will change in Tornado 6.0:
``raise_error=False`` will only affect the `HTTPError`
raised when a non-200 response code is used.
"""
if self._closed:
raise RuntimeError("fetch() called on closed AsyncHTTPClient")
if not isinstance(request, HTTPRequest):
request = HTTPRequest(url=request, **kwargs)
else:
if kwargs:
raise ValueError("kwargs can't be used if request is an HTTPRequest object")
# We may modify this (to add Host, Accept-Encoding, etc),
# so make sure we don't modify the caller's object. This is also
# where normal dicts get converted to HTTPHeaders objects.
request.headers = httputil.HTTPHeaders(request.headers)
request = _RequestProxy(request, self.defaults)
future = Future()
if callback is not None:
warnings.warn("callback arguments are deprecated, use the returned Future instead",
DeprecationWarning)
callback = stack_context.wrap(callback)
def handle_future(future):
exc = future.exception()
if isinstance(exc, HTTPError) and exc.response is not None:
response = exc.response
elif exc is not None:
response = HTTPResponse(
request, 599, error=exc,
request_time=time.time() - request.start_time)
else:
response = future.result()
self.io_loop.add_callback(callback, response)
future.add_done_callback(handle_future)
def handle_response(response):
if raise_error and response.error:
if isinstance(response.error, HTTPError):
response.error.response = response
future.set_exception(response.error)
else:
if response.error and not response._error_is_response_code:
warnings.warn("raise_error=False will allow '%s' to be raised in the future" %
response.error, DeprecationWarning)
future_set_result_unless_cancelled(future, response)
self.fetch_impl(request, handle_response)
return future
def fetch_impl(self, request, callback):
raise NotImplementedError()
@classmethod
def configure(cls, impl, **kwargs):
"""Configures the `AsyncHTTPClient` subclass to use.
``AsyncHTTPClient()`` actually creates an instance of a subclass.
This method may be called with either a class object or the
fully-qualified name of such a class (or ``None`` to use the default,
``SimpleAsyncHTTPClient``)
If additional keyword arguments are given, they will be passed
to the constructor of each subclass instance created. The
keyword argument ``max_clients`` determines the maximum number
of simultaneous `~AsyncHTTPClient.fetch()` operations that can
execute in parallel on each `.IOLoop`. Additional arguments
may be supported depending on the implementation class in use.
Example::
AsyncHTTPClient.configure("tornado.curl_httpclient.CurlAsyncHTTPClient")
"""
super(AsyncHTTPClient, cls).configure(impl, **kwargs)
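# Coroutine sketch of the Future-based interface described in fetch()
# (illustrative; written with gen.coroutine so it also runs on Python 2):
@gen.coroutine
def _example_async_fetch(url):
    client = AsyncHTTPClient()
    response = yield client.fetch(url, raise_error=False)
    raise gen.Return((response.code, response.body))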
class HTTPRequest(object):
"""HTTP client request object."""
# Default values for HTTPRequest parameters.
# Merged with the values on the request object by AsyncHTTPClient
# implementations.
_DEFAULTS = dict(
connect_timeout=20.0,
request_timeout=20.0,
follow_redirects=True,
max_redirects=5,
decompress_response=True,
proxy_password='',
allow_nonstandard_methods=False,
validate_cert=True)
def __init__(self, url, method="GET", headers=None, body=None,
auth_username=None, auth_password=None, auth_mode=None,
connect_timeout=None, request_timeout=None,
if_modified_since=None, follow_redirects=None,
max_redirects=None, user_agent=None, use_gzip=None,
network_interface=None, streaming_callback=None,
header_callback=None, prepare_curl_callback=None,
proxy_host=None, proxy_port=None, proxy_username=None,
proxy_password=None, proxy_auth_mode=None,
allow_nonstandard_methods=None, validate_cert=None,
ca_certs=None, allow_ipv6=None, client_key=None,
client_cert=None, body_producer=None,
expect_100_continue=False, decompress_response=None,
ssl_options=None):
r"""All parameters except ``url`` are optional.
:arg str url: URL to fetch
:arg str method: HTTP method, e.g. "GET" or "POST"
:arg headers: Additional HTTP headers to pass on the request
:type headers: `~tornado.httputil.HTTPHeaders` or `dict`
:arg body: HTTP request body as a string (byte or unicode; if unicode
the utf-8 encoding will be used)
:arg body_producer: Callable used for lazy/asynchronous request bodies.
It is called with one argument, a ``write`` function, and should
return a `.Future`. It should call the write function with new
data as it becomes available. The write function returns a
`.Future` which can be used for flow control.
Only one of ``body`` and ``body_producer`` may
be specified. ``body_producer`` is not supported on
``curl_httpclient``. When using ``body_producer`` it is recommended
to pass a ``Content-Length`` in the headers as otherwise chunked
encoding will be used, and many servers do not support chunked
encoding on requests. New in Tornado 4.0
:arg str auth_username: Username for HTTP authentication
:arg str auth_password: Password for HTTP authentication
:arg str auth_mode: Authentication mode; default is "basic".
Allowed values are implementation-defined; ``curl_httpclient``
supports "basic" and "digest"; ``simple_httpclient`` only supports
"basic"
:arg float connect_timeout: Timeout for initial connection in seconds,
default 20 seconds
:arg float request_timeout: Timeout for entire request in seconds,
default 20 seconds
:arg if_modified_since: Timestamp for ``If-Modified-Since`` header
:type if_modified_since: `datetime` or `float`
:arg bool follow_redirects: Should redirects be followed automatically
or return the 3xx response? Default True.
:arg int max_redirects: Limit for ``follow_redirects``, default 5.
:arg str user_agent: String to send as ``User-Agent`` header
:arg bool decompress_response: Request a compressed response from
the server and decompress it after downloading. Default is True.
New in Tornado 4.0.
:arg bool use_gzip: Deprecated alias for ``decompress_response``
since Tornado 4.0.
:arg str network_interface: Network interface to use for request.
``curl_httpclient`` only; see note below.
:arg collections.abc.Callable streaming_callback: If set, ``streaming_callback`` will
be run with each chunk of data as it is received, and
``HTTPResponse.body`` and ``HTTPResponse.buffer`` will be empty in
the final response.
:arg collections.abc.Callable header_callback: If set, ``header_callback`` will
be run with each header line as it is received (including the
first line, e.g. ``HTTP/1.0 200 OK\r\n``, and a final line
containing only ``\r\n``. All lines include the trailing newline
characters). ``HTTPResponse.headers`` will be empty in the final
response. This is most useful in conjunction with
``streaming_callback``, because it's the only way to get access to
header data while the request is in progress.
:arg collections.abc.Callable prepare_curl_callback: If set, will be called with
a ``pycurl.Curl`` object to allow the application to make additional
``setopt`` calls.
:arg str proxy_host: HTTP proxy hostname. To use proxies,
``proxy_host`` and ``proxy_port`` must be set; ``proxy_username``,
``proxy_password`` and ``proxy_auth_mode`` are optional. Proxies are
currently only supported with ``curl_httpclient``.
:arg int proxy_port: HTTP proxy port
:arg str proxy_username: HTTP proxy username
:arg str proxy_password: HTTP proxy password
:arg str proxy_auth_mode: HTTP proxy Authentication mode;
default is "basic". supports "basic" and "digest"
:arg bool allow_nonstandard_methods: Allow unknown values for ``method``
argument? Default is False.
:arg bool validate_cert: For HTTPS requests, validate the server's
certificate? Default is True.
:arg str ca_certs: filename of CA certificates in PEM format,
or None to use defaults. See note below when used with
``curl_httpclient``.
:arg str client_key: Filename for client SSL key, if any. See
note below when used with ``curl_httpclient``.
:arg str client_cert: Filename for client SSL certificate, if any.
See note below when used with ``curl_httpclient``.
:arg ssl.SSLContext ssl_options: `ssl.SSLContext` object for use in
``simple_httpclient`` (unsupported by ``curl_httpclient``).
Overrides ``validate_cert``, ``ca_certs``, ``client_key``,
and ``client_cert``.
:arg bool allow_ipv6: Use IPv6 when available? Default is true.
:arg bool expect_100_continue: If true, send the
``Expect: 100-continue`` header and wait for a continue response
before sending the request body. Only supported with
simple_httpclient.
.. note::
When using ``curl_httpclient`` certain options may be
inherited by subsequent fetches because ``pycurl`` does
not allow them to be cleanly reset. This applies to the
``ca_certs``, ``client_key``, ``client_cert``, and
``network_interface`` arguments. If you use these
options, you should pass them on every request (you don't
have to always use the same values, but it's not possible
to mix requests that specify these options with ones that
use the defaults).
.. versionadded:: 3.1
The ``auth_mode`` argument.
.. versionadded:: 4.0
The ``body_producer`` and ``expect_100_continue`` arguments.
.. versionadded:: 4.2
The ``ssl_options`` argument.
.. versionadded:: 4.5
The ``proxy_auth_mode`` argument.
"""
# Note that some of these attributes go through property setters
# defined below.
self.headers = headers
if if_modified_since:
self.headers["If-Modified-Since"] = httputil.format_timestamp(
if_modified_since)
self.proxy_host = proxy_host
self.proxy_port = proxy_port
self.proxy_username = proxy_username
self.proxy_password = proxy_password
self.proxy_auth_mode = proxy_auth_mode
self.url = url
self.method = method
self.body = body
self.body_producer = body_producer
self.auth_username = auth_username
self.auth_password = auth_password
self.auth_mode = auth_mode
self.connect_timeout = connect_timeout
self.request_timeout = request_timeout
self.follow_redirects = follow_redirects
self.max_redirects = max_redirects
self.user_agent = user_agent
if decompress_response is not None:
self.decompress_response = decompress_response
else:
self.decompress_response = use_gzip
self.network_interface = network_interface
self.streaming_callback = streaming_callback
self.header_callback = header_callback
self.prepare_curl_callback = prepare_curl_callback
self.allow_nonstandard_methods = allow_nonstandard_methods
self.validate_cert = validate_cert
self.ca_certs = ca_certs
self.allow_ipv6 = allow_ipv6
self.client_key = client_key
self.client_cert = client_cert
self.ssl_options = ssl_options
self.expect_100_continue = expect_100_continue
self.start_time = time.time()
@property
def headers(self):
return self._headers
@headers.setter
def headers(self, value):
if value is None:
self._headers = httputil.HTTPHeaders()
else:
self._headers = value
@property
def body(self):
return self._body
@body.setter
def body(self, value):
self._body = utf8(value)
@property
def body_producer(self):
return self._body_producer
@body_producer.setter
def body_producer(self, value):
self._body_producer = stack_context.wrap(value)
@property
def streaming_callback(self):
return self._streaming_callback
@streaming_callback.setter
def streaming_callback(self, value):
self._streaming_callback = stack_context.wrap(value)
@property
def header_callback(self):
return self._header_callback
@header_callback.setter
def header_callback(self, value):
self._header_callback = stack_context.wrap(value)
@property
def prepare_curl_callback(self):
return self._prepare_curl_callback
@prepare_curl_callback.setter
def prepare_curl_callback(self, value):
self._prepare_curl_callback = stack_context.wrap(value)
class HTTPResponse(object):
"""HTTP Response object.
Attributes:
* request: HTTPRequest object
* code: numeric HTTP status code, e.g. 200 or 404
* reason: human-readable reason phrase describing the status code
* headers: `tornado.httputil.HTTPHeaders` object
* effective_url: final location of the resource after following any
redirects
* buffer: ``cStringIO`` object for response body
* body: response body as bytes (created on demand from ``self.buffer``)
* error: Exception object, if any
* request_time: seconds from request start to finish. Includes all network
operations from DNS resolution to receiving the last byte of data.
Does not include time spent in the queue (due to the ``max_clients`` option).
If redirects were followed, only includes the final request.
* start_time: Time at which the HTTP operation started, based on `time.time`
(not the monotonic clock used by `.IOLoop.time`). May be ``None`` if the request
timed out while in the queue.
* time_info: dictionary of diagnostic timing information from the request.
Available data are subject to change, but currently uses timings
available from http://curl.haxx.se/libcurl/c/curl_easy_getinfo.html,
plus ``queue``, which is the delay (if any) introduced by waiting for
a slot under `AsyncHTTPClient`'s ``max_clients`` setting.
.. versionadded:: 5.1
Added the ``start_time`` attribute.
.. versionchanged:: 5.1
The ``request_time`` attribute previously included time spent in the queue
for ``simple_httpclient``, but not in ``curl_httpclient``. Now queueing time
is excluded in both implementations. ``request_time`` is now more accurate for
``curl_httpclient`` because it uses a monotonic clock when available.
"""
def __init__(self, request, code, headers=None, buffer=None,
effective_url=None, error=None, request_time=None,
time_info=None, reason=None, start_time=None):
if isinstance(request, _RequestProxy):
self.request = request.request
else:
self.request = request
self.code = code
self.reason = reason or httputil.responses.get(code, "Unknown")
if headers is not None:
self.headers = headers
else:
self.headers = httputil.HTTPHeaders()
self.buffer = buffer
self._body = None
if effective_url is None:
self.effective_url = request.url
else:
self.effective_url = effective_url
self._error_is_response_code = False
if error is None:
if self.code < 200 or self.code >= 300:
self._error_is_response_code = True
self.error = HTTPError(self.code, message=self.reason,
response=self)
else:
self.error = None
else:
self.error = error
self.start_time = start_time
self.request_time = request_time
self.time_info = time_info or {}
@property
def body(self):
if self.buffer is None:
return None
elif self._body is None:
self._body = self.buffer.getvalue()
return self._body
def rethrow(self):
"""If there was an error on the request, raise an `HTTPError`."""
if self.error:
raise self.error
def __repr__(self):
args = ",".join("%s=%r" % i for i in sorted(self.__dict__.items()))
return "%s(%s)" % (self.__class__.__name__, args)
class HTTPClientError(Exception):
"""Exception thrown for an unsuccessful HTTP request.
Attributes:
* ``code`` - HTTP error integer error code, e.g. 404. Error code 599 is
used when no HTTP response was received, e.g. for a timeout.
* ``response`` - `HTTPResponse` object, if any.
Note that if ``follow_redirects`` is False, redirects become HTTPErrors,
and you can look at ``error.response.headers['Location']`` to see the
destination of the redirect.
.. versionchanged:: 5.1
Renamed from ``HTTPError`` to ``HTTPClientError`` to avoid collisions with
`tornado.web.HTTPError`. The name ``tornado.httpclient.HTTPError`` remains
as an alias.
"""
def __init__(self, code, message=None, response=None):
self.code = code
self.message = message or httputil.responses.get(code, "Unknown")
self.response = response
super(HTTPClientError, self).__init__(code, message, response)
def __str__(self):
return "HTTP %d: %s" % (self.code, self.message)
# There is a cyclic reference between self and self.response,
# which breaks the default __repr__ implementation.
# (especially on pypy, which doesn't have the same recursion
# detection as cpython).
__repr__ = __str__
HTTPError = HTTPClientError
class _RequestProxy(object):
"""Combines an object with a dictionary of defaults.
Used internally by AsyncHTTPClient implementations.
"""
def __init__(self, request, defaults):
self.request = request
self.defaults = defaults
def __getattr__(self, name):
request_attr = getattr(self.request, name)
if request_attr is not None:
return request_attr
elif self.defaults is not None:
return self.defaults.get(name, None)
else:
return None
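# Illustrative sketch of the fallback: attributes the request left as None
# are looked up in the defaults dictionary.
def _example_request_proxy():
    request = HTTPRequest("http://example.com/")  # no explicit timeouts
    proxy = _RequestProxy(request, HTTPRequest._DEFAULTS)
    assert proxy.connect_timeout == 20.0          # from _DEFAULTS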
def main():
from tornado.options import define, options, parse_command_line
define("print_headers", type=bool, default=False)
define("print_body", type=bool, default=True)
define("follow_redirects", type=bool, default=True)
define("validate_cert", type=bool, default=True)
define("proxy_host", type=str)
define("proxy_port", type=int)
args = parse_command_line()
client = HTTPClient()
for arg in args:
try:
response = client.fetch(arg,
follow_redirects=options.follow_redirects,
validate_cert=options.validate_cert,
proxy_host=options.proxy_host,
proxy_port=options.proxy_port,
)
except HTTPError as e:
if e.response is not None:
response = e.response
else:
raise
if options.print_headers:
print(response.headers)
if options.print_body:
print(native_str(response.body))
client.close()
if __name__ == "__main__":
main()

330
lib/tornado/httpserver.py Executable file

@@ -0,0 +1,330 @@
#
# Copyright 2009 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""A non-blocking, single-threaded HTTP server.
Typical applications have little direct interaction with the `HTTPServer`
class except to start a server at the beginning of the process
(and even that is often done indirectly via `tornado.web.Application.listen`).
.. versionchanged:: 4.0
The ``HTTPRequest`` class that used to live in this module has been moved
to `tornado.httputil.HTTPServerRequest`. The old name remains as an alias.
"""
from __future__ import absolute_import, division, print_function
import socket
from tornado.escape import native_str
from tornado.http1connection import HTTP1ServerConnection, HTTP1ConnectionParameters
from tornado import gen
from tornado import httputil
from tornado import iostream
from tornado import netutil
from tornado.tcpserver import TCPServer
from tornado.util import Configurable
class HTTPServer(TCPServer, Configurable,
httputil.HTTPServerConnectionDelegate):
r"""A non-blocking, single-threaded HTTP server.
A server is defined by a subclass of `.HTTPServerConnectionDelegate`,
or, for backwards compatibility, a callback that takes an
`.HTTPServerRequest` as an argument. The delegate is usually a
`tornado.web.Application`.
`HTTPServer` supports keep-alive connections by default
(automatically for HTTP/1.1, or for HTTP/1.0 when the client
requests ``Connection: keep-alive``).
If ``xheaders`` is ``True``, we support the
``X-Real-Ip``/``X-Forwarded-For`` and
``X-Scheme``/``X-Forwarded-Proto`` headers, which override the
remote IP and URI scheme/protocol for all requests. These headers
are useful when running Tornado behind a reverse proxy or load
balancer. The ``protocol`` argument can also be set to ``https``
if Tornado is run behind an SSL-decoding proxy that does not set one of
the supported ``xheaders``.
By default, when parsing the ``X-Forwarded-For`` header, Tornado will
select the last (i.e., the closest) address on the list of hosts as the
remote host IP address. To select the next server in the chain, a list of
trusted downstream hosts may be passed as the ``trusted_downstream``
argument. These hosts will be skipped when parsing the ``X-Forwarded-For``
header.
To make this server serve SSL traffic, send the ``ssl_options`` keyword
argument with an `ssl.SSLContext` object. For compatibility with older
versions of Python ``ssl_options`` may also be a dictionary of keyword
arguments for the `ssl.wrap_socket` method.::
ssl_ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
ssl_ctx.load_cert_chain(os.path.join(data_dir, "mydomain.crt"),
os.path.join(data_dir, "mydomain.key"))
HTTPServer(application, ssl_options=ssl_ctx)
`HTTPServer` initialization follows one of three patterns (the
initialization methods are defined on `tornado.tcpserver.TCPServer`):
1. `~tornado.tcpserver.TCPServer.listen`: simple single-process::
server = HTTPServer(app)
server.listen(8888)
IOLoop.current().start()
In many cases, `tornado.web.Application.listen` can be used to avoid
the need to explicitly create the `HTTPServer`.
2. `~tornado.tcpserver.TCPServer.bind`/`~tornado.tcpserver.TCPServer.start`:
simple multi-process::
server = HTTPServer(app)
server.bind(8888)
server.start(0) # Forks multiple sub-processes
IOLoop.current().start()
When using this interface, an `.IOLoop` must *not* be passed
to the `HTTPServer` constructor. `~.TCPServer.start` will always start
the server on the default singleton `.IOLoop`.
3. `~tornado.tcpserver.TCPServer.add_sockets`: advanced multi-process::
sockets = tornado.netutil.bind_sockets(8888)
tornado.process.fork_processes(0)
server = HTTPServer(app)
server.add_sockets(sockets)
IOLoop.current().start()
The `~.TCPServer.add_sockets` interface is more complicated,
but it can be used with `tornado.process.fork_processes` to
give you more flexibility in when the fork happens.
`~.TCPServer.add_sockets` can also be used in single-process
servers if you want to create your listening sockets in some
way other than `tornado.netutil.bind_sockets`.
.. versionchanged:: 4.0
Added ``decompress_request``, ``chunk_size``, ``max_header_size``,
``idle_connection_timeout``, ``body_timeout``, ``max_body_size``
arguments. Added support for `.HTTPServerConnectionDelegate`
instances as ``request_callback``.
.. versionchanged:: 4.1
`.HTTPServerConnectionDelegate.start_request` is now called with
two arguments ``(server_conn, request_conn)`` (in accordance with the
documentation) instead of one ``(request_conn)``.
.. versionchanged:: 4.2
`HTTPServer` is now a subclass of `tornado.util.Configurable`.
.. versionchanged:: 4.5
Added the ``trusted_downstream`` argument.
.. versionchanged:: 5.0
The ``io_loop`` argument has been removed.
"""
def __init__(self, *args, **kwargs):
# Ignore args to __init__; real initialization belongs in
# initialize since we're Configurable. (there's something
# weird in initialization order between this class,
# Configurable, and TCPServer so we can't leave __init__ out
# completely)
pass
def initialize(self, request_callback, no_keep_alive=False,
xheaders=False, ssl_options=None, protocol=None,
decompress_request=False,
chunk_size=None, max_header_size=None,
idle_connection_timeout=None, body_timeout=None,
max_body_size=None, max_buffer_size=None,
trusted_downstream=None):
self.request_callback = request_callback
self.xheaders = xheaders
self.protocol = protocol
self.conn_params = HTTP1ConnectionParameters(
decompress=decompress_request,
chunk_size=chunk_size,
max_header_size=max_header_size,
header_timeout=idle_connection_timeout or 3600,
max_body_size=max_body_size,
body_timeout=body_timeout,
no_keep_alive=no_keep_alive)
TCPServer.__init__(self, ssl_options=ssl_options,
max_buffer_size=max_buffer_size,
read_chunk_size=chunk_size)
self._connections = set()
self.trusted_downstream = trusted_downstream
@classmethod
def configurable_base(cls):
return HTTPServer
@classmethod
def configurable_default(cls):
return HTTPServer
@gen.coroutine
def close_all_connections(self):
while self._connections:
# Peek at an arbitrary element of the set
conn = next(iter(self._connections))
yield conn.close()
def handle_stream(self, stream, address):
context = _HTTPRequestContext(stream, address,
self.protocol,
self.trusted_downstream)
conn = HTTP1ServerConnection(
stream, self.conn_params, context)
self._connections.add(conn)
conn.start_serving(self)
def start_request(self, server_conn, request_conn):
if isinstance(self.request_callback, httputil.HTTPServerConnectionDelegate):
delegate = self.request_callback.start_request(server_conn, request_conn)
else:
delegate = _CallableAdapter(self.request_callback, request_conn)
if self.xheaders:
delegate = _ProxyAdapter(delegate, request_conn)
return delegate
def on_close(self, server_conn):
self._connections.remove(server_conn)
class _CallableAdapter(httputil.HTTPMessageDelegate):
def __init__(self, request_callback, request_conn):
self.connection = request_conn
self.request_callback = request_callback
self.request = None
self.delegate = None
self._chunks = []
def headers_received(self, start_line, headers):
self.request = httputil.HTTPServerRequest(
connection=self.connection, start_line=start_line,
headers=headers)
def data_received(self, chunk):
self._chunks.append(chunk)
def finish(self):
self.request.body = b''.join(self._chunks)
self.request._parse_body()
self.request_callback(self.request)
def on_connection_close(self):
self._chunks = None
class _HTTPRequestContext(object):
def __init__(self, stream, address, protocol, trusted_downstream=None):
self.address = address
# Save the socket's address family now so we know how to
# interpret self.address even after the stream is closed
# and its socket attribute replaced with None.
if stream.socket is not None:
self.address_family = stream.socket.family
else:
self.address_family = None
# In HTTPServerRequest we want an IP, not a full socket address.
if (self.address_family in (socket.AF_INET, socket.AF_INET6) and
address is not None):
self.remote_ip = address[0]
else:
# Unix (or other) socket; fake the remote address.
self.remote_ip = '0.0.0.0'
if protocol:
self.protocol = protocol
elif isinstance(stream, iostream.SSLIOStream):
self.protocol = "https"
else:
self.protocol = "http"
self._orig_remote_ip = self.remote_ip
self._orig_protocol = self.protocol
self.trusted_downstream = set(trusted_downstream or [])
def __str__(self):
if self.address_family in (socket.AF_INET, socket.AF_INET6):
return self.remote_ip
elif isinstance(self.address, bytes):
# Python 3 with the -bb option warns about str(bytes),
# so convert it explicitly.
# Unix socket addresses are str on mac but bytes on linux.
return native_str(self.address)
else:
return str(self.address)
def _apply_xheaders(self, headers):
"""Rewrite the ``remote_ip`` and ``protocol`` fields."""
# Squid uses X-Forwarded-For, others use X-Real-Ip
ip = headers.get("X-Forwarded-For", self.remote_ip)
# Skip trusted downstream hosts in X-Forwarded-For list
for ip in (cand.strip() for cand in reversed(ip.split(','))):
if ip not in self.trusted_downstream:
break
ip = headers.get("X-Real-Ip", ip)
if netutil.is_valid_ip(ip):
self.remote_ip = ip
# AWS uses X-Forwarded-Proto
proto_header = headers.get(
"X-Scheme", headers.get("X-Forwarded-Proto",
self.protocol))
if proto_header:
# use only the last proto entry if there is more than one
# TODO: support trusting multiple layers of proxied protocol
proto_header = proto_header.split(',')[-1].strip()
if proto_header in ("http", "https"):
self.protocol = proto_header
def _unapply_xheaders(self):
"""Undo changes from `_apply_xheaders`.
Xheaders are per-request so they should not leak to the next
request on the same connection.
"""
self.remote_ip = self._orig_remote_ip
self.protocol = self._orig_protocol
class _ProxyAdapter(httputil.HTTPMessageDelegate):
def __init__(self, delegate, request_conn):
self.connection = request_conn
self.delegate = delegate
def headers_received(self, start_line, headers):
self.connection.context._apply_xheaders(headers)
return self.delegate.headers_received(start_line, headers)
def data_received(self, chunk):
return self.delegate.data_received(chunk)
def finish(self):
self.delegate.finish()
self._cleanup()
def on_connection_close(self):
self.delegate.on_connection_close()
self._cleanup()
def _cleanup(self):
self.connection.context._unapply_xheaders()
HTTPRequest = httputil.HTTPServerRequest
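def _xheaders_example(app):
    # Illustrative sketch, not upstream Tornado code. Assumes ``app`` is a
    # tornado.web.Application and 10.0.0.1 is a reverse proxy we control.
    # With xheaders=True, the X-Forwarded-For/X-Forwarded-Proto headers
    # rewrite request.remote_ip and request.protocol per request; addresses
    # in trusted_downstream are skipped when picking the client IP from the
    # X-Forwarded-For chain.
    server = HTTPServer(app, xheaders=True, trusted_downstream=["10.0.0.1"])
    server.listen(8888)
    return server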

1095
lib/tornado/httputil.py Executable file

File diff suppressed because it is too large

1267
lib/tornado/ioloop.py Executable file

File diff suppressed because it is too large

1757
lib/tornado/iostream.py Executable file

File diff suppressed because it is too large

521
lib/tornado/locale.py Executable file

@@ -0,0 +1,521 @@
# -*- coding: utf-8 -*-
#
# Copyright 2009 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""Translation methods for generating localized strings.
To load a locale and generate a translated string::
user_locale = tornado.locale.get("es_LA")
print(user_locale.translate("Sign out"))
`tornado.locale.get()` returns the closest matching locale, not necessarily the
specific locale you requested. You can support pluralization with
additional arguments to `~Locale.translate()`, e.g.::
people = [...]
message = user_locale.translate(
"%(list)s is online", "%(list)s are online", len(people))
print(message % {"list": user_locale.list(people)})
The first string is chosen if ``len(people) == 1``, otherwise the second
string is chosen.
Applications should call one of `load_translations` (which uses a simple
CSV format) or `load_gettext_translations` (which uses the ``.mo`` format
supported by `gettext` and related tools). If neither method is called,
the `Locale.translate` method will simply return the original string.
"""
from __future__ import absolute_import, division, print_function
import codecs
import csv
import datetime
from io import BytesIO
import numbers
import os
import re
from tornado import escape
from tornado.log import gen_log
from tornado.util import PY3
from tornado._locale_data import LOCALE_NAMES
_default_locale = "en_US"
_translations = {} # type: dict
_supported_locales = frozenset([_default_locale])
_use_gettext = False
CONTEXT_SEPARATOR = "\x04"
def get(*locale_codes):
"""Returns the closest match for the given locale codes.
We iterate over all given locale codes in order. If we have a tight
or a loose match for the code (e.g., "en" for "en_US"), we return
the locale. Otherwise we move to the next code in the list.
By default we return ``en_US`` if no translations are found for any of
the specified locales. You can change the default locale with
`set_default_locale()`.
"""
return Locale.get_closest(*locale_codes)
def set_default_locale(code):
"""Sets the default locale.
The default locale is assumed to be the language used for all strings
in the system. The translations loaded from disk are mappings from
the default locale to the destination locale. Consequently, you don't
need to create a translation file for the default locale.
"""
global _default_locale
global _supported_locales
_default_locale = code
_supported_locales = frozenset(list(_translations.keys()) + [_default_locale])
def load_translations(directory, encoding=None):
"""Loads translations from CSV files in a directory.
Translations are strings with optional Python-style named placeholders
(e.g., ``My name is %(name)s``) and their associated translations.
The directory should have translation files of the form ``LOCALE.csv``,
e.g. ``es_GT.csv``. The CSV files should have two or three columns: string,
translation, and an optional plural indicator. Plural indicators should
be one of "plural" or "singular". A given string can have both singular
and plural forms. For example ``%(name)s liked this`` may have a
different verb conjugation depending on whether %(name)s is one
name or a list of names. There should be two rows in the CSV file for
that string, one with plural indicator "singular", and one "plural".
For strings with no verbs that would change on translation, simply
use "unknown" or the empty string (or don't include the column at all).
The file is read using the `csv` module in the default "excel" dialect.
In this format there should not be spaces after the commas.
If no ``encoding`` parameter is given, the encoding will be
detected automatically (among UTF-8 and UTF-16) if the file
contains a byte-order marker (BOM), defaulting to UTF-8 if no BOM
is present.
Example translation ``es_LA.csv``::
"I love you","Te amo"
"%(name)s liked this","A %(name)s les gustó esto","plural"
"%(name)s liked this","A %(name)s le gustó esto","singular"
.. versionchanged:: 4.3
Added ``encoding`` parameter. Added support for BOM-based encoding
detection, UTF-16, and UTF-8-with-BOM.
"""
global _translations
global _supported_locales
_translations = {}
for path in os.listdir(directory):
if not path.endswith(".csv"):
continue
locale, extension = path.split(".")
if not re.match("[a-z]+(_[A-Z]+)?$", locale):
gen_log.error("Unrecognized locale %r (path: %s)", locale,
os.path.join(directory, path))
continue
full_path = os.path.join(directory, path)
if encoding is None:
# Try to autodetect encoding based on the BOM.
with open(full_path, 'rb') as f:
data = f.read(len(codecs.BOM_UTF16_LE))
if data in (codecs.BOM_UTF16_LE, codecs.BOM_UTF16_BE):
encoding = 'utf-16'
else:
# utf-8-sig is "utf-8 with optional BOM". It's discouraged
# in most cases but is common with CSV files because Excel
# cannot read utf-8 files without a BOM.
encoding = 'utf-8-sig'
if PY3:
# python 3: csv.reader requires a file open in text mode.
# Force utf8 to avoid dependence on $LANG environment variable.
f = open(full_path, "r", encoding=encoding)
else:
# python 2: csv can only handle byte strings (in ascii-compatible
# encodings), which we decode below. Transcode everything into
# utf8 before passing it to csv.reader.
f = BytesIO()
with codecs.open(full_path, "r", encoding=encoding) as infile:
f.write(escape.utf8(infile.read()))
f.seek(0)
_translations[locale] = {}
for i, row in enumerate(csv.reader(f)):
if not row or len(row) < 2:
continue
row = [escape.to_unicode(c).strip() for c in row]
english, translation = row[:2]
if len(row) > 2:
plural = row[2] or "unknown"
else:
plural = "unknown"
if plural not in ("plural", "singular", "unknown"):
gen_log.error("Unrecognized plural indicator %r in %s line %d",
plural, path, i + 1)
continue
_translations[locale].setdefault(plural, {})[english] = translation
f.close()
_supported_locales = frozenset(list(_translations.keys()) + [_default_locale])
gen_log.debug("Supported locales: %s", sorted(_supported_locales))
def load_gettext_translations(directory, domain):
"""Loads translations from `gettext`'s locale tree
Locale tree is similar to system's ``/usr/share/locale``, like::
{directory}/{lang}/LC_MESSAGES/{domain}.mo
Three steps are required to have your app translated:
1. Generate POT translation file::
xgettext --language=Python --keyword=_:1,2 -d mydomain file1.py file2.html etc
2. Merge against existing POT file::
msgmerge old.po mydomain.po > new.po
3. Compile::
msgfmt mydomain.po -o {directory}/pt_BR/LC_MESSAGES/mydomain.mo
"""
import gettext
global _translations
global _supported_locales
global _use_gettext
_translations = {}
for lang in os.listdir(directory):
if lang.startswith('.'):
continue # skip .svn, etc
if os.path.isfile(os.path.join(directory, lang)):
continue
try:
os.stat(os.path.join(directory, lang, "LC_MESSAGES", domain + ".mo"))
_translations[lang] = gettext.translation(domain, directory,
languages=[lang])
except Exception as e:
gen_log.error("Cannot load translation for '%s': %s", lang, str(e))
continue
_supported_locales = frozenset(list(_translations.keys()) + [_default_locale])
_use_gettext = True
gen_log.debug("Supported locales: %s", sorted(_supported_locales))
def get_supported_locales():
"""Returns a list of all the supported locale codes."""
return _supported_locales
class Locale(object):
"""Object representing a locale.
After calling one of `load_translations` or `load_gettext_translations`,
call `get` or `get_closest` to get a Locale object.
"""
@classmethod
def get_closest(cls, *locale_codes):
"""Returns the closest match for the given locale code."""
for code in locale_codes:
if not code:
continue
code = code.replace("-", "_")
parts = code.split("_")
if len(parts) > 2:
continue
elif len(parts) == 2:
code = parts[0].lower() + "_" + parts[1].upper()
if code in _supported_locales:
return cls.get(code)
if parts[0].lower() in _supported_locales:
return cls.get(parts[0].lower())
return cls.get(_default_locale)
@classmethod
def get(cls, code):
"""Returns the Locale for the given locale code.
If it is not supported, we raise an exception.
"""
if not hasattr(cls, "_cache"):
cls._cache = {}
if code not in cls._cache:
assert code in _supported_locales
translations = _translations.get(code, None)
if translations is None:
locale = CSVLocale(code, {})
elif _use_gettext:
locale = GettextLocale(code, translations)
else:
locale = CSVLocale(code, translations)
cls._cache[code] = locale
return cls._cache[code]
def __init__(self, code, translations):
self.code = code
self.name = LOCALE_NAMES.get(code, {}).get("name", u"Unknown")
self.rtl = False
for prefix in ["fa", "ar", "he"]:
if self.code.startswith(prefix):
self.rtl = True
break
self.translations = translations
# Initialize strings for date formatting
_ = self.translate
self._months = [
_("January"), _("February"), _("March"), _("April"),
_("May"), _("June"), _("July"), _("August"),
_("September"), _("October"), _("November"), _("December")]
self._weekdays = [
_("Monday"), _("Tuesday"), _("Wednesday"), _("Thursday"),
_("Friday"), _("Saturday"), _("Sunday")]
def translate(self, message, plural_message=None, count=None):
"""Returns the translation for the given message for this locale.
If ``plural_message`` is given, you must also provide
``count``. We return ``plural_message`` when ``count != 1``,
and we return the singular form for the given message when
``count == 1``.
"""
raise NotImplementedError()
def pgettext(self, context, message, plural_message=None, count=None):
raise NotImplementedError()
def format_date(self, date, gmt_offset=0, relative=True, shorter=False,
full_format=False):
"""Formats the given date (which should be GMT).
By default, we return a relative time (e.g., "2 minutes ago"). You
can return an absolute date string with ``relative=False``.
You can force a full format date ("July 10, 1980") with
``full_format=True``.
This method is primarily intended for dates in the past.
For dates in the future, we fall back to full format.
"""
if isinstance(date, numbers.Real):
date = datetime.datetime.utcfromtimestamp(date)
now = datetime.datetime.utcnow()
if date > now:
if relative and (date - now).seconds < 60:
# Due to clock skew, some things appear slightly
# in the future. Round timestamps in the immediate
# future down to now in relative mode.
date = now
else:
# Otherwise, future dates always use the full format.
full_format = True
local_date = date - datetime.timedelta(minutes=gmt_offset)
local_now = now - datetime.timedelta(minutes=gmt_offset)
local_yesterday = local_now - datetime.timedelta(hours=24)
difference = now - date
seconds = difference.seconds
days = difference.days
_ = self.translate
format = None
if not full_format:
if relative and days == 0:
if seconds < 50:
return _("1 second ago", "%(seconds)d seconds ago",
seconds) % {"seconds": seconds}
if seconds < 50 * 60:
minutes = round(seconds / 60.0)
return _("1 minute ago", "%(minutes)d minutes ago",
minutes) % {"minutes": minutes}
hours = round(seconds / (60.0 * 60))
return _("1 hour ago", "%(hours)d hours ago",
hours) % {"hours": hours}
if days == 0:
format = _("%(time)s")
elif days == 1 and local_date.day == local_yesterday.day and \
relative:
format = _("yesterday") if shorter else \
_("yesterday at %(time)s")
elif days < 5:
format = _("%(weekday)s") if shorter else \
_("%(weekday)s at %(time)s")
elif days < 334: # ~11 months; avoids confusion with the same month last year
format = _("%(month_name)s %(day)s") if shorter else \
_("%(month_name)s %(day)s at %(time)s")
if format is None:
format = _("%(month_name)s %(day)s, %(year)s") if shorter else \
_("%(month_name)s %(day)s, %(year)s at %(time)s")
tfhour_clock = self.code not in ("en", "en_US", "zh_CN")
if tfhour_clock:
str_time = "%d:%02d" % (local_date.hour, local_date.minute)
elif self.code == "zh_CN":
str_time = "%s%d:%02d" % (
(u'\u4e0a\u5348', u'\u4e0b\u5348')[local_date.hour >= 12],
local_date.hour % 12 or 12, local_date.minute)
else:
str_time = "%d:%02d %s" % (
local_date.hour % 12 or 12, local_date.minute,
("am", "pm")[local_date.hour >= 12])
return format % {
"month_name": self._months[local_date.month - 1],
"weekday": self._weekdays[local_date.weekday()],
"day": str(local_date.day),
"year": str(local_date.year),
"time": str_time
}
def format_day(self, date, gmt_offset=0, dow=True):
"""Formats the given date as a day of week.
Example: "Monday, January 22". You can remove the day of week with
``dow=False``.
"""
local_date = date - datetime.timedelta(minutes=gmt_offset)
_ = self.translate
if dow:
return _("%(weekday)s, %(month_name)s %(day)s") % {
"month_name": self._months[local_date.month - 1],
"weekday": self._weekdays[local_date.weekday()],
"day": str(local_date.day),
}
else:
return _("%(month_name)s %(day)s") % {
"month_name": self._months[local_date.month - 1],
"day": str(local_date.day),
}
def list(self, parts):
"""Returns a comma-separated list for the given list of parts.
The format is, e.g., "A, B and C", "A and B" or just "A" for lists
of size 1.
"""
_ = self.translate
if len(parts) == 0:
return ""
if len(parts) == 1:
return parts[0]
comma = u' \u0648 ' if self.code.startswith("fa") else u", "
return _("%(commas)s and %(last)s") % {
"commas": comma.join(parts[:-1]),
"last": parts[len(parts) - 1],
}
def friendly_number(self, value):
"""Returns a comma-separated number for the given integer."""
if self.code not in ("en", "en_US"):
return str(value)
value = str(value)
parts = []
while value:
parts.append(value[-3:])
value = value[:-3]
return ",".join(reversed(parts))
class CSVLocale(Locale):
"""Locale implementation using tornado's CSV translation format."""
def translate(self, message, plural_message=None, count=None):
if plural_message is not None:
assert count is not None
if count != 1:
message = plural_message
message_dict = self.translations.get("plural", {})
else:
message_dict = self.translations.get("singular", {})
else:
message_dict = self.translations.get("unknown", {})
return message_dict.get(message, message)
def pgettext(self, context, message, plural_message=None, count=None):
if self.translations:
gen_log.warning('pgettext is not supported by CSVLocale')
return self.translate(message, plural_message, count)
class GettextLocale(Locale):
"""Locale implementation using the `gettext` module."""
def __init__(self, code, translations):
try:
# python 2
self.ngettext = translations.ungettext
self.gettext = translations.ugettext
except AttributeError:
# python 3
self.ngettext = translations.ngettext
self.gettext = translations.gettext
# self.gettext must exist before __init__ is called, since it
# calls into self.translate
super(GettextLocale, self).__init__(code, translations)
def translate(self, message, plural_message=None, count=None):
if plural_message is not None:
assert count is not None
return self.ngettext(message, plural_message, count)
else:
return self.gettext(message)
def pgettext(self, context, message, plural_message=None, count=None):
"""Allows to set context for translation, accepts plural forms.
Usage example::
pgettext("law", "right")
pgettext("good", "right")
Plural message example::
pgettext("organization", "club", "clubs", len(clubs))
pgettext("stick", "club", "clubs", len(clubs))
To generate a POT file with context, add the following options to step 1
of the `load_gettext_translations` sequence::
xgettext [basic options] --keyword=pgettext:1c,2 --keyword=pgettext:1c,2,3
.. versionadded:: 4.2
"""
if plural_message is not None:
assert count is not None
msgs_with_ctxt = ("%s%s%s" % (context, CONTEXT_SEPARATOR, message),
"%s%s%s" % (context, CONTEXT_SEPARATOR, plural_message),
count)
result = self.ngettext(*msgs_with_ctxt)
if CONTEXT_SEPARATOR in result:
# Translation not found
result = self.ngettext(message, plural_message, count)
return result
else:
msg_with_ctxt = "%s%s%s" % (context, CONTEXT_SEPARATOR, message)
result = self.gettext(msg_with_ctxt)
if CONTEXT_SEPARATOR in result:
# Translation not found
result = message
return result
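def _locale_demo(translations_dir):
    # Illustrative sketch, not upstream Tornado code. Assumes
    # ``translations_dir`` contains CSV files such as the es_LA.csv shown
    # in the load_translations docstring. Demonstrates lookup,
    # pluralization, and relative date formatting.
    import datetime
    load_translations(translations_dir)
    user_locale = get("es_LA")
    print(user_locale.translate("I love you"))
    names = ["Ana", "Luis"]
    message = user_locale.translate(
        "%(name)s liked this", "%(name)s liked this", len(names))
    print(message % {"name": user_locale.list(names)})
    print(user_locale.format_date(datetime.datetime.utcnow()))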

526
lib/tornado/locks.py Executable file

@@ -0,0 +1,526 @@
# Copyright 2015 The Tornado Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from __future__ import absolute_import, division, print_function
import collections
from concurrent.futures import CancelledError
from tornado import gen, ioloop
from tornado.concurrent import Future, future_set_result_unless_cancelled
__all__ = ['Condition', 'Event', 'Semaphore', 'BoundedSemaphore', 'Lock']
class _TimeoutGarbageCollector(object):
"""Base class for objects that periodically clean up timed-out waiters.
Avoids memory leak in a common pattern like:
while True:
yield condition.wait(short_timeout)
print('looping....')
"""
def __init__(self):
self._waiters = collections.deque() # Futures.
self._timeouts = 0
def _garbage_collect(self):
# Occasionally clear timed-out waiters.
self._timeouts += 1
if self._timeouts > 100:
self._timeouts = 0
self._waiters = collections.deque(
w for w in self._waiters if not w.done())
class Condition(_TimeoutGarbageCollector):
"""A condition allows one or more coroutines to wait until notified.
Like a standard `threading.Condition`, but does not need an underlying lock
that is acquired and released.
With a `Condition`, coroutines can wait to be notified by other coroutines:
.. testcode::
from tornado import gen
from tornado.ioloop import IOLoop
from tornado.locks import Condition
condition = Condition()
async def waiter():
print("I'll wait right here")
await condition.wait()
print("I'm done waiting")
async def notifier():
print("About to notify")
condition.notify()
print("Done notifying")
async def runner():
# Wait for waiter() and notifier() in parallel
await gen.multi([waiter(), notifier()])
IOLoop.current().run_sync(runner)
.. testoutput::
I'll wait right here
About to notify
Done notifying
I'm done waiting
`wait` takes an optional ``timeout`` argument, which is either an absolute
timestamp::
io_loop = IOLoop.current()
# Wait up to 1 second for a notification.
await condition.wait(timeout=io_loop.time() + 1)
...or a `datetime.timedelta` for a timeout relative to the current time::
# Wait up to 1 second.
await condition.wait(timeout=datetime.timedelta(seconds=1))
The method returns False if there's no notification before the deadline.
.. versionchanged:: 5.0
Previously, waiters could be notified synchronously from within
`notify`. Now, the notification will always be received on the
next iteration of the `.IOLoop`.
"""
def __init__(self):
super(Condition, self).__init__()
self.io_loop = ioloop.IOLoop.current()
def __repr__(self):
result = '<%s' % (self.__class__.__name__, )
if self._waiters:
result += ' waiters[%s]' % len(self._waiters)
return result + '>'
def wait(self, timeout=None):
"""Wait for `.notify`.
Returns a `.Future` that resolves ``True`` if the condition is notified,
or ``False`` after a timeout.
"""
waiter = Future()
self._waiters.append(waiter)
if timeout:
def on_timeout():
if not waiter.done():
future_set_result_unless_cancelled(waiter, False)
self._garbage_collect()
io_loop = ioloop.IOLoop.current()
timeout_handle = io_loop.add_timeout(timeout, on_timeout)
waiter.add_done_callback(
lambda _: io_loop.remove_timeout(timeout_handle))
return waiter
def notify(self, n=1):
"""Wake ``n`` waiters."""
waiters = [] # Waiters we plan to run right now.
while n and self._waiters:
waiter = self._waiters.popleft()
if not waiter.done(): # Might have timed out.
n -= 1
waiters.append(waiter)
for waiter in waiters:
future_set_result_unless_cancelled(waiter, True)
def notify_all(self):
"""Wake all waiters."""
self.notify(len(self._waiters))
class Event(object):
"""An event blocks coroutines until its internal flag is set to True.
Similar to `threading.Event`.
A coroutine can wait for an event to be set. Once it is set, calls to
``yield event.wait()`` will not block unless the event has been cleared:
.. testcode::
from tornado import gen
from tornado.ioloop import IOLoop
from tornado.locks import Event
event = Event()
async def waiter():
print("Waiting for event")
await event.wait()
print("Not waiting this time")
await event.wait()
print("Done")
async def setter():
print("About to set the event")
event.set()
async def runner():
await gen.multi([waiter(), setter()])
IOLoop.current().run_sync(runner)
.. testoutput::
Waiting for event
About to set the event
Not waiting this time
Done
"""
def __init__(self):
self._value = False
self._waiters = set()
def __repr__(self):
return '<%s %s>' % (
self.__class__.__name__, 'set' if self.is_set() else 'clear')
def is_set(self):
"""Return ``True`` if the internal flag is true."""
return self._value
def set(self):
"""Set the internal flag to ``True``. All waiters are awakened.
Calling `.wait` once the flag is set will not block.
"""
if not self._value:
self._value = True
for fut in self._waiters:
if not fut.done():
fut.set_result(None)
def clear(self):
"""Reset the internal flag to ``False``.
Calls to `.wait` will block until `.set` is called.
"""
self._value = False
def wait(self, timeout=None):
"""Block until the internal flag is true.
Returns a Future, which raises `tornado.util.TimeoutError` after a
timeout.
"""
fut = Future()
if self._value:
fut.set_result(None)
return fut
self._waiters.add(fut)
fut.add_done_callback(lambda fut: self._waiters.remove(fut))
if timeout is None:
return fut
else:
timeout_fut = gen.with_timeout(timeout, fut, quiet_exceptions=(CancelledError,))
# This is a slightly clumsy workaround for the fact that
# gen.with_timeout doesn't cancel its futures. Cancelling
# fut will remove it from the waiters list.
timeout_fut.add_done_callback(lambda tf: fut.cancel() if not fut.done() else None)
return timeout_fut
class _ReleasingContextManager(object):
"""Releases a Lock or Semaphore at the end of a "with" statement.
with (yield semaphore.acquire()):
pass
# Now semaphore.release() has been called.
"""
def __init__(self, obj):
self._obj = obj
def __enter__(self):
pass
def __exit__(self, exc_type, exc_val, exc_tb):
self._obj.release()
class Semaphore(_TimeoutGarbageCollector):
"""A lock that can be acquired a fixed number of times before blocking.
A Semaphore manages a counter representing the number of `.release` calls
minus the number of `.acquire` calls, plus an initial value. The `.acquire`
method blocks if necessary until it can return without making the counter
negative.
Semaphores limit access to a shared resource. To allow access for two
workers at a time:
.. testsetup:: semaphore
from collections import deque
from tornado import gen
from tornado.ioloop import IOLoop
from tornado.concurrent import Future
# Ensure reliable doctest output: resolve Futures one at a time.
futures_q = deque([Future() for _ in range(3)])
async def simulator(futures):
for f in futures:
# simulate the asynchronous passage of time
await gen.sleep(0)
await gen.sleep(0)
f.set_result(None)
IOLoop.current().add_callback(simulator, list(futures_q))
def use_some_resource():
return futures_q.popleft()
.. testcode:: semaphore
from tornado import gen
from tornado.ioloop import IOLoop
from tornado.locks import Semaphore
sem = Semaphore(2)
async def worker(worker_id):
await sem.acquire()
try:
print("Worker %d is working" % worker_id)
await use_some_resource()
finally:
print("Worker %d is done" % worker_id)
sem.release()
async def runner():
# Join all workers.
await gen.multi([worker(i) for i in range(3)])
IOLoop.current().run_sync(runner)
.. testoutput:: semaphore
Worker 0 is working
Worker 1 is working
Worker 0 is done
Worker 2 is working
Worker 1 is done
Worker 2 is done
Workers 0 and 1 are allowed to run concurrently, but worker 2 waits until
the semaphore has been released once, by worker 0.
The semaphore can be used as an async context manager::
async def worker(worker_id):
async with sem:
print("Worker %d is working" % worker_id)
await use_some_resource()
# Now the semaphore has been released.
print("Worker %d is done" % worker_id)
For compatibility with older versions of Python, `.acquire` is a
context manager, so ``worker`` could also be written as::
@gen.coroutine
def worker(worker_id):
with (yield sem.acquire()):
print("Worker %d is working" % worker_id)
yield use_some_resource()
# Now the semaphore has been released.
print("Worker %d is done" % worker_id)
.. versionchanged:: 4.3
Added ``async with`` support in Python 3.5.
"""
def __init__(self, value=1):
super(Semaphore, self).__init__()
if value < 0:
raise ValueError('semaphore initial value must be >= 0')
self._value = value
def __repr__(self):
res = super(Semaphore, self).__repr__()
extra = 'locked' if self._value == 0 else 'unlocked,value:{0}'.format(
self._value)
if self._waiters:
extra = '{0},waiters:{1}'.format(extra, len(self._waiters))
return '<{0} [{1}]>'.format(res[1:-1], extra)
def release(self):
"""Increment the counter and wake one waiter."""
self._value += 1
while self._waiters:
waiter = self._waiters.popleft()
if not waiter.done():
self._value -= 1
# If the waiter is a coroutine paused at
#
# with (yield semaphore.acquire()):
#
# then the context manager's __exit__ calls release() at the end
# of the "with" block.
waiter.set_result(_ReleasingContextManager(self))
break
def acquire(self, timeout=None):
"""Decrement the counter. Returns a Future.
Block if the counter is zero and wait for a `.release`. The Future
raises `.TimeoutError` after the deadline.
"""
waiter = Future()
if self._value > 0:
self._value -= 1
waiter.set_result(_ReleasingContextManager(self))
else:
self._waiters.append(waiter)
if timeout:
def on_timeout():
if not waiter.done():
waiter.set_exception(gen.TimeoutError())
self._garbage_collect()
io_loop = ioloop.IOLoop.current()
timeout_handle = io_loop.add_timeout(timeout, on_timeout)
waiter.add_done_callback(
lambda _: io_loop.remove_timeout(timeout_handle))
return waiter
def __enter__(self):
raise RuntimeError(
"Use Semaphore like 'with (yield semaphore.acquire())', not like"
" 'with semaphore'")
__exit__ = __enter__
@gen.coroutine
def __aenter__(self):
yield self.acquire()
@gen.coroutine
def __aexit__(self, typ, value, tb):
self.release()
class BoundedSemaphore(Semaphore):
"""A semaphore that prevents release() being called too many times.
If `.release` would increment the semaphore's value past the initial
value, it raises `ValueError`. Semaphores are mostly used to guard
resources with limited capacity, so a semaphore released too many times
is a sign of a bug.
"""
def __init__(self, value=1):
super(BoundedSemaphore, self).__init__(value=value)
self._initial_value = value
def release(self):
"""Increment the counter and wake one waiter."""
if self._value >= self._initial_value:
raise ValueError("Semaphore released too many times")
super(BoundedSemaphore, self).release()
class Lock(object):
"""A lock for coroutines.
A Lock begins unlocked, and `acquire` locks it immediately. While it is
locked, a coroutine that yields `acquire` waits until another coroutine
calls `release`.
Releasing an unlocked lock raises `RuntimeError`.
A Lock can be used as an async context manager with the ``async
with`` statement:
>>> from tornado import locks
>>> lock = locks.Lock()
>>>
>>> async def f():
... async with lock:
... # Do something holding the lock.
... pass
...
... # Now the lock is released.
For compatibility with older versions of Python, the `.acquire`
method asynchronously returns a regular context manager:
>>> async def f2():
... with (yield lock.acquire()):
... # Do something holding the lock.
... pass
...
... # Now the lock is released.
.. versionchanged:: 4.3
Added ``async with`` support in Python 3.5.
"""
def __init__(self):
self._block = BoundedSemaphore(value=1)
def __repr__(self):
return "<%s _block=%s>" % (
self.__class__.__name__,
self._block)
def acquire(self, timeout=None):
"""Attempt to lock. Returns a Future.
Returns a Future, which raises `tornado.util.TimeoutError` after a
timeout.
"""
return self._block.acquire(timeout)
def release(self):
"""Unlock.
The first coroutine in line waiting for `acquire` gets the lock.
If not locked, raise a `RuntimeError`.
"""
try:
self._block.release()
except ValueError:
raise RuntimeError('release unlocked lock')
def __enter__(self):
raise RuntimeError(
"Use Lock like 'with (yield lock)', not like 'with lock'")
__exit__ = __enter__
@gen.coroutine
def __aenter__(self):
yield self.acquire()
@gen.coroutine
def __aexit__(self, typ, value, tb):
self.release()
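def _lock_timeout_demo():
    # Illustrative sketch, not upstream Tornado code: acquiring a Lock with
    # a relative deadline. acquire() returns a Future that raises
    # gen.TimeoutError if the lock is still unavailable when the timedelta
    # expires.
    import datetime
    lock = Lock()

    @gen.coroutine
    def worker():
        try:
            with (yield lock.acquire(timeout=datetime.timedelta(seconds=1))):
                pass  # critical section; the lock is released on exit
        except gen.TimeoutError:
            print("gave up waiting for the lock after 1s")

    ioloop.IOLoop.current().run_sync(worker)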

290
lib/tornado/log.py Executable file

@@ -0,0 +1,290 @@
#
# Copyright 2012 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""Logging support for Tornado.
Tornado uses three logger streams:
* ``tornado.access``: Per-request logging for Tornado's HTTP servers (and
potentially other servers in the future)
* ``tornado.application``: Logging of errors from application code (i.e.
uncaught exceptions from callbacks)
* ``tornado.general``: General-purpose logging, including any errors
or warnings from Tornado itself.
These streams may be configured independently using the standard library's
`logging` module. For example, you may wish to send ``tornado.access`` logs
to a separate file for analysis.
"""
from __future__ import absolute_import, division, print_function
import logging
import logging.handlers
import sys
from tornado.escape import _unicode
from tornado.util import unicode_type, basestring_type
try:
import colorama
except ImportError:
colorama = None
try:
import curses # type: ignore
except ImportError:
curses = None
# Logger objects for internal tornado use
access_log = logging.getLogger("tornado.access")
app_log = logging.getLogger("tornado.application")
gen_log = logging.getLogger("tornado.general")
def _stderr_supports_color():
try:
if hasattr(sys.stderr, 'isatty') and sys.stderr.isatty():
if curses:
curses.setupterm()
if curses.tigetnum("colors") > 0:
return True
elif colorama:
if sys.stderr is getattr(colorama.initialise, 'wrapped_stderr',
object()):
return True
except Exception:
# Very broad exception handling because it's always better to
# fall back to non-colored logs than to break at startup.
pass
return False
def _safe_unicode(s):
try:
return _unicode(s)
except UnicodeDecodeError:
return repr(s)
class LogFormatter(logging.Formatter):
"""Log formatter used in Tornado.
Key features of this formatter are:
* Color support when logging to a terminal that supports it.
* Timestamps on every log line.
* Robust against str/bytes encoding problems.
This formatter is enabled automatically by
`tornado.options.parse_command_line` or `tornado.options.parse_config_file`
(unless ``--logging=none`` is used).
Color support on Windows versions that do not support ANSI color codes is
enabled by use of the colorama__ library. Applications that wish to use
this must first initialize colorama with a call to ``colorama.init``.
See the colorama documentation for details.
__ https://pypi.python.org/pypi/colorama
.. versionchanged:: 4.5
Added support for ``colorama``. Changed the constructor
signature to be compatible with `logging.config.dictConfig`.
"""
DEFAULT_FORMAT = \
'%(color)s[%(levelname)1.1s %(asctime)s %(module)s:%(lineno)d]%(end_color)s %(message)s'
DEFAULT_DATE_FORMAT = '%y%m%d %H:%M:%S'
DEFAULT_COLORS = {
logging.DEBUG: 4, # Blue
logging.INFO: 2, # Green
logging.WARNING: 3, # Yellow
logging.ERROR: 1, # Red
}
def __init__(self, fmt=DEFAULT_FORMAT, datefmt=DEFAULT_DATE_FORMAT,
style='%', color=True, colors=DEFAULT_COLORS):
r"""
:arg bool color: Enables color support.
:arg str fmt: Log message format.
It will be applied to the attributes dict of log records. The
text between ``%(color)s`` and ``%(end_color)s`` will be colored
depending on the level if color support is on.
:arg dict colors: color mappings from logging level to terminal color
code
:arg str datefmt: Datetime format.
Used for formatting the ``%(asctime)s`` placeholder in ``fmt``.
.. versionchanged:: 3.2
Added ``fmt`` and ``datefmt`` arguments.
"""
logging.Formatter.__init__(self, datefmt=datefmt)
self._fmt = fmt
self._colors = {}
if color and _stderr_supports_color():
if curses is not None:
# The curses module has some str/bytes confusion in
# python3. Until version 3.2.3, most methods return
# bytes, but only accept strings. In addition, we want to
# output these strings with the logging module, which
# works with unicode strings. The explicit calls to
# unicode() below are harmless in python2 but will do the
# right conversion in python 3.
fg_color = (curses.tigetstr("setaf") or
curses.tigetstr("setf") or "")
if (3, 0) < sys.version_info < (3, 2, 3):
fg_color = unicode_type(fg_color, "ascii")
for levelno, code in colors.items():
self._colors[levelno] = unicode_type(curses.tparm(fg_color, code), "ascii")
self._normal = unicode_type(curses.tigetstr("sgr0"), "ascii")
else:
# If curses is not present (currently we'll only get here for
# colorama on windows), assume hard-coded ANSI color codes.
for levelno, code in colors.items():
self._colors[levelno] = '\033[2;3%dm' % code
self._normal = '\033[0m'
else:
self._normal = ''
def format(self, record):
try:
message = record.getMessage()
assert isinstance(message, basestring_type) # guaranteed by logging
# Encoding notes: The logging module prefers to work with character
# strings, but only enforces that log messages are instances of
# basestring. In python 2, non-ascii bytestrings will make
# their way through the logging framework until they blow up with
# an unhelpful decoding error (with this formatter it happens
# when we attach the prefix, but there are other opportunities for
# exceptions further along in the framework).
#
# If a byte string makes it this far, convert it to unicode to
# ensure it will make it out to the logs. Use repr() as a fallback
# to ensure that all byte strings can be converted successfully,
# but don't do it by default so we don't add extra quotes to ascii
# bytestrings. This is a bit of a hacky place to do this, but
# it's worth it since the encoding errors that would otherwise
# result are so useless (and tornado is fond of using utf8-encoded
# byte strings wherever possible).
record.message = _safe_unicode(message)
except Exception as e:
record.message = "Bad message (%r): %r" % (e, record.__dict__)
record.asctime = self.formatTime(record, self.datefmt)
if record.levelno in self._colors:
record.color = self._colors[record.levelno]
record.end_color = self._normal
else:
record.color = record.end_color = ''
formatted = self._fmt % record.__dict__
if record.exc_info:
if not record.exc_text:
record.exc_text = self.formatException(record.exc_info)
if record.exc_text:
# exc_text contains multiple lines. We need to _safe_unicode
# each line separately so that non-utf8 bytes don't cause
# all the newlines to turn into '\n'.
lines = [formatted.rstrip()]
lines.extend(_safe_unicode(ln) for ln in record.exc_text.split('\n'))
formatted = '\n'.join(lines)
return formatted.replace("\n", "\n ")
def enable_pretty_logging(options=None, logger=None):
"""Turns on formatted logging output as configured.
This is called automatically by `tornado.options.parse_command_line`
and `tornado.options.parse_config_file`.
"""
if options is None:
import tornado.options
options = tornado.options.options
if options.logging is None or options.logging.lower() == 'none':
return
if logger is None:
logger = logging.getLogger()
logger.setLevel(getattr(logging, options.logging.upper()))
if options.log_file_prefix:
rotate_mode = options.log_rotate_mode
if rotate_mode == 'size':
channel = logging.handlers.RotatingFileHandler(
filename=options.log_file_prefix,
maxBytes=options.log_file_max_size,
backupCount=options.log_file_num_backups)
elif rotate_mode == 'time':
channel = logging.handlers.TimedRotatingFileHandler(
filename=options.log_file_prefix,
when=options.log_rotate_when,
interval=options.log_rotate_interval,
backupCount=options.log_file_num_backups)
else:
error_message = 'The value of log_rotate_mode option should be ' +\
'"size" or "time", not "%s".' % rotate_mode
raise ValueError(error_message)
channel.setFormatter(LogFormatter(color=False))
logger.addHandler(channel)
if (options.log_to_stderr or
(options.log_to_stderr is None and not logger.handlers)):
# Set up color if we are in a tty and curses is installed
channel = logging.StreamHandler()
channel.setFormatter(LogFormatter())
logger.addHandler(channel)
def define_logging_options(options=None):
"""Add logging-related flags to ``options``.
These options are present automatically on the default options instance;
this method is only necessary if you have created your own `.OptionParser`.
.. versionadded:: 4.2
This function existed in prior versions but was broken and undocumented until 4.2.
"""
if options is None:
# late import to prevent cycle
import tornado.options
options = tornado.options.options
options.define("logging", default="info",
help=("Set the Python log level. If 'none', tornado won't touch the "
"logging configuration."),
metavar="debug|info|warning|error|none")
options.define("log_to_stderr", type=bool, default=None,
help=("Send log output to stderr (colorized if possible). "
"By default use stderr if --log_file_prefix is not set and "
"no other logging is configured."))
options.define("log_file_prefix", type=str, default=None, metavar="PATH",
help=("Path prefix for log files. "
"Note that if you are running multiple tornado processes, "
"log_file_prefix must be different for each of them (e.g. "
"include the port number)"))
options.define("log_file_max_size", type=int, default=100 * 1000 * 1000,
help="max size of log files before rollover")
options.define("log_file_num_backups", type=int, default=10,
help="number of log files to keep")
options.define("log_rotate_when", type=str, default='midnight',
help=("specify the type of TimedRotatingFileHandler interval "
"other options:('S', 'M', 'H', 'D', 'W0'-'W6')"))
options.define("log_rotate_interval", type=int, default=1,
help="The interval value of timed rotating")
options.define("log_rotate_mode", type=str, default='size',
help="The mode of rotating files(time or size)")
options.add_parse_callback(lambda: enable_pretty_logging(options))
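def _pretty_logging_demo():
    # Illustrative sketch, not upstream Tornado code: wire LogFormatter up
    # by hand instead of going through tornado.options.parse_command_line().
    # Colors are applied automatically when stderr is a tty that supports
    # them (see _stderr_supports_color above).
    handler = logging.StreamHandler()
    handler.setFormatter(LogFormatter())
    root = logging.getLogger()
    root.setLevel(logging.INFO)
    root.addHandler(handler)
    app_log.info("hello from tornado.application")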

575
lib/tornado/netutil.py Executable file

@@ -0,0 +1,575 @@
#
# Copyright 2011 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""Miscellaneous network utility code."""
from __future__ import absolute_import, division, print_function
import errno
import os
import sys
import socket
import stat
from tornado.concurrent import dummy_executor, run_on_executor
from tornado import gen
from tornado.ioloop import IOLoop
from tornado.platform.auto import set_close_exec
from tornado.util import PY3, Configurable, errno_from_exception
try:
import ssl
except ImportError:
# ssl is not available on Google App Engine
ssl = None
if PY3:
xrange = range
if ssl is not None:
# Note that the naming of ssl.Purpose is confusing; the purpose
# of a context is to authenticate the opposite side of the connection.
_client_ssl_defaults = ssl.create_default_context(
ssl.Purpose.SERVER_AUTH)
_server_ssl_defaults = ssl.create_default_context(
ssl.Purpose.CLIENT_AUTH)
if hasattr(ssl, 'OP_NO_COMPRESSION'):
# See netutil.ssl_options_to_context
_client_ssl_defaults.options |= ssl.OP_NO_COMPRESSION
_server_ssl_defaults.options |= ssl.OP_NO_COMPRESSION
else:
# Google App Engine
_client_ssl_defaults = dict(cert_reqs=None,
ca_certs=None)
_server_ssl_defaults = {}
# ThreadedResolver runs getaddrinfo on a thread. If the hostname is unicode,
# getaddrinfo attempts to import encodings.idna. If this is done at
# module-import time, the import lock is already held by the main thread,
# leading to deadlock. Avoid it by caching the idna encoder on the main
# thread now.
u'foo'.encode('idna')
# For undiagnosed reasons, 'latin1' codec may also need to be preloaded.
u'foo'.encode('latin1')
# These errnos indicate that a non-blocking operation must be retried
# at a later time. On most platforms they're the same value, but on
# some they differ.
_ERRNO_WOULDBLOCK = (errno.EWOULDBLOCK, errno.EAGAIN)
if hasattr(errno, "WSAEWOULDBLOCK"):
_ERRNO_WOULDBLOCK += (errno.WSAEWOULDBLOCK,) # type: ignore
# Default backlog used when calling sock.listen()
_DEFAULT_BACKLOG = 128
def bind_sockets(port, address=None, family=socket.AF_UNSPEC,
backlog=_DEFAULT_BACKLOG, flags=None, reuse_port=False):
"""Creates listening sockets bound to the given port and address.
Returns a list of socket objects (multiple sockets are returned if
the given address maps to multiple IP addresses, which is most common
for mixed IPv4 and IPv6 use).
Address may be either an IP address or hostname. If it's a hostname,
the server will listen on all IP addresses associated with the
name. Address may be an empty string or None to listen on all
available interfaces. Family may be set to either `socket.AF_INET`
or `socket.AF_INET6` to restrict to IPv4 or IPv6 addresses, otherwise
both will be used if available.
The ``backlog`` argument has the same meaning as for
`socket.listen() <socket.socket.listen>`.
``flags`` is a bitmask of AI_* flags to `~socket.getaddrinfo`, like
``socket.AI_PASSIVE | socket.AI_NUMERICHOST``.
``reuse_port`` option sets ``SO_REUSEPORT`` option for every socket
in the list. If your platform doesn't support this option ValueError will
be raised.
"""
if reuse_port and not hasattr(socket, "SO_REUSEPORT"):
raise ValueError("the platform doesn't support SO_REUSEPORT")
sockets = []
if address == "":
address = None
if not socket.has_ipv6 and family == socket.AF_UNSPEC:
# Python can be compiled with --disable-ipv6, which causes
# operations on AF_INET6 sockets to fail, but does not
# automatically exclude those results from getaddrinfo
# results.
# http://bugs.python.org/issue16208
family = socket.AF_INET
if flags is None:
flags = socket.AI_PASSIVE
bound_port = None
for res in set(socket.getaddrinfo(address, port, family, socket.SOCK_STREAM,
0, flags)):
af, socktype, proto, canonname, sockaddr = res
if (sys.platform == 'darwin' and address == 'localhost' and
af == socket.AF_INET6 and sockaddr[3] != 0):
# Mac OS X includes a link-local address fe80::1%lo0 in the
# getaddrinfo results for 'localhost'. However, the firewall
# doesn't understand that this is a local address and will
# prompt for access (often repeatedly, due to an apparent
# bug in its ability to remember granting access to an
# application). Skip these addresses.
continue
try:
sock = socket.socket(af, socktype, proto)
except socket.error as e:
if errno_from_exception(e) == errno.EAFNOSUPPORT:
continue
raise
set_close_exec(sock.fileno())
if os.name != 'nt':
try:
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
except socket.error as e:
if errno_from_exception(e) != errno.ENOPROTOOPT:
# Hurd doesn't support SO_REUSEADDR.
raise
if reuse_port:
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
if af == socket.AF_INET6:
# On linux, ipv6 sockets accept ipv4 too by default,
# but this makes it impossible to bind to both
# 0.0.0.0 in ipv4 and :: in ipv6. On other systems,
# separate sockets *must* be used to listen for both ipv4
# and ipv6. For consistency, always disable ipv4 on our
# ipv6 sockets and use a separate ipv4 socket when needed.
#
# Python 2.x on windows doesn't have IPPROTO_IPV6.
if hasattr(socket, "IPPROTO_IPV6"):
sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 1)
# automatic port allocation with port=None
# should bind on the same port on IPv4 and IPv6
host, requested_port = sockaddr[:2]
if requested_port == 0 and bound_port is not None:
sockaddr = tuple([host, bound_port] + list(sockaddr[2:]))
sock.setblocking(0)
sock.bind(sockaddr)
bound_port = sock.getsockname()[1]
sock.listen(backlog)
sockets.append(sock)
return sockets
if hasattr(socket, 'AF_UNIX'):
def bind_unix_socket(file, mode=0o600, backlog=_DEFAULT_BACKLOG):
"""Creates a listening unix socket.
If a socket with the given name already exists, it will be deleted.
If any other file with that name exists, an exception will be
raised.
Returns a socket object (not a list of socket objects like
`bind_sockets`)
"""
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
set_close_exec(sock.fileno())
try:
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
except socket.error as e:
if errno_from_exception(e) != errno.ENOPROTOOPT:
# Hurd doesn't support SO_REUSEADDR
raise
sock.setblocking(0)
try:
st = os.stat(file)
except OSError as err:
if errno_from_exception(err) != errno.ENOENT:
raise
else:
if stat.S_ISSOCK(st.st_mode):
os.remove(file)
else:
raise ValueError("File %s exists and is not a socket", file)
sock.bind(file)
os.chmod(file, mode)
sock.listen(backlog)
return sock
def add_accept_handler(sock, callback):
"""Adds an `.IOLoop` event handler to accept new connections on ``sock``.
When a connection is accepted, ``callback(connection, address)`` will
be run (``connection`` is a socket object, and ``address`` is the
address of the other end of the connection). Note that this signature
is different from the ``callback(fd, events)`` signature used for
`.IOLoop` handlers.
A callable is returned which, when called, will remove the `.IOLoop`
event handler and stop processing further incoming connections.
.. versionchanged:: 5.0
The ``io_loop`` argument (deprecated since version 4.1) has been removed.
.. versionchanged:: 5.0
A callable is returned (``None`` was returned before).
"""
io_loop = IOLoop.current()
removed = [False]
def accept_handler(fd, events):
# More connections may come in while we're handling callbacks;
# to prevent starvation of other tasks we must limit the number
# of connections we accept at a time. Ideally we would accept
# up to the number of connections that were waiting when we
# entered this method, but this information is not available
# (and rearranging this method to call accept() as many times
# as possible before running any callbacks would have adverse
# effects on load balancing in multiprocess configurations).
# Instead, we use the (default) listen backlog as a rough
# heuristic for the number of connections we can reasonably
# accept at once.
for i in xrange(_DEFAULT_BACKLOG):
if removed[0]:
# The socket was probably closed
return
try:
connection, address = sock.accept()
except socket.error as e:
                # _ERRNO_WOULDBLOCK indicates that we have accepted every
# connection that is available.
if errno_from_exception(e) in _ERRNO_WOULDBLOCK:
return
# ECONNABORTED indicates that there was a connection
# but it was closed while still in the accept queue.
# (observed on FreeBSD).
if errno_from_exception(e) == errno.ECONNABORTED:
continue
raise
set_close_exec(connection.fileno())
callback(connection, address)
def remove_handler():
io_loop.remove_handler(sock)
removed[0] = True
io_loop.add_handler(sock, accept_handler, IOLoop.READ)
return remove_handler
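# A minimal usage sketch for bind_sockets and add_accept_handler (the
# port and the one-shot reply are invented for illustration; this is
# not part of the original module):
#
#     sockets = bind_sockets(8888, address="127.0.0.1")
#
#     def handle_connection(connection, address):
#         connection.sendall(b"hello\n")  # `connection` is a plain socket
#         connection.close()
#
#     for sock in sockets:
#         add_accept_handler(sock, handle_connection)
#     IOLoop.current().start()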
def is_valid_ip(ip):
"""Returns true if the given string is a well-formed IP address.
Supports IPv4 and IPv6.
"""
if not ip or '\x00' in ip:
# getaddrinfo resolves empty strings to localhost, and truncates
# on zero bytes.
return False
try:
res = socket.getaddrinfo(ip, 0, socket.AF_UNSPEC,
socket.SOCK_STREAM,
0, socket.AI_NUMERICHOST)
return bool(res)
except socket.gaierror as e:
if e.args[0] == socket.EAI_NONAME:
return False
raise
return True
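# Illustrative results for the check above (inputs assumed for the sketch):
#   is_valid_ip("127.0.0.1") -> True
#   is_valid_ip("::1")       -> True
#   is_valid_ip("localhost") -> False  (hostnames are not IP literals)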
class Resolver(Configurable):
"""Configurable asynchronous DNS resolver interface.
By default, a blocking implementation is used (which simply calls
`socket.getaddrinfo`). An alternative implementation can be
chosen with the `Resolver.configure <.Configurable.configure>`
class method::
Resolver.configure('tornado.netutil.ThreadedResolver')
The implementations of this interface included with Tornado are
* `tornado.netutil.DefaultExecutorResolver`
* `tornado.netutil.BlockingResolver` (deprecated)
* `tornado.netutil.ThreadedResolver` (deprecated)
* `tornado.netutil.OverrideResolver`
* `tornado.platform.twisted.TwistedResolver`
* `tornado.platform.caresresolver.CaresResolver`
.. versionchanged:: 5.0
The default implementation has changed from `BlockingResolver` to
`DefaultExecutorResolver`.
"""
@classmethod
def configurable_base(cls):
return Resolver
@classmethod
def configurable_default(cls):
return DefaultExecutorResolver
def resolve(self, host, port, family=socket.AF_UNSPEC, callback=None):
"""Resolves an address.
The ``host`` argument is a string which may be a hostname or a
literal IP address.
Returns a `.Future` whose result is a list of (family,
address) pairs, where address is a tuple suitable to pass to
`socket.connect <socket.socket.connect>` (i.e. a ``(host,
port)`` pair for IPv4; additional fields may be present for
IPv6). If a ``callback`` is passed, it will be run with the
result as an argument when it is complete.
:raises IOError: if the address cannot be resolved.
.. versionchanged:: 4.4
Standardized all implementations to raise `IOError`.
.. deprecated:: 5.1
The ``callback`` argument is deprecated and will be removed in 6.0.
Use the returned awaitable object instead.
"""
raise NotImplementedError()
def close(self):
"""Closes the `Resolver`, freeing any resources used.
.. versionadded:: 3.1
"""
pass
def _resolve_addr(host, port, family=socket.AF_UNSPEC):
# On Solaris, getaddrinfo fails if the given port is not found
# in /etc/services and no socket type is given, so we must pass
# one here. The socket type used here doesn't seem to actually
# matter (we discard the one we get back in the results),
# so the addresses we return should still be usable with SOCK_DGRAM.
addrinfo = socket.getaddrinfo(host, port, family, socket.SOCK_STREAM)
results = []
for family, socktype, proto, canonname, address in addrinfo:
results.append((family, address))
return results
class DefaultExecutorResolver(Resolver):
"""Resolver implementation using `.IOLoop.run_in_executor`.
.. versionadded:: 5.0
"""
@gen.coroutine
def resolve(self, host, port, family=socket.AF_UNSPEC):
result = yield IOLoop.current().run_in_executor(
None, _resolve_addr, host, port, family)
raise gen.Return(result)
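# A hedged usage sketch for the Resolver interface (host and port are
# placeholders); ``run_sync`` drives the returned Future to completion:
#
#     import socket
#     from tornado.ioloop import IOLoop
#     addrinfo = IOLoop.current().run_sync(
#         lambda: Resolver().resolve("example.com", 80, socket.AF_INET))
#     # e.g. [(socket.AF_INET, ("93.184.216.34", 80))]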
class ExecutorResolver(Resolver):
"""Resolver implementation using a `concurrent.futures.Executor`.
Use this instead of `ThreadedResolver` when you require additional
control over the executor being used.
The executor will be shut down when the resolver is closed unless
    ``close_executor=False``; use this if you want to reuse the same
executor elsewhere.
.. versionchanged:: 5.0
The ``io_loop`` argument (deprecated since version 4.1) has been removed.
.. deprecated:: 5.0
The default `Resolver` now uses `.IOLoop.run_in_executor`; use that instead
of this class.
"""
def initialize(self, executor=None, close_executor=True):
self.io_loop = IOLoop.current()
if executor is not None:
self.executor = executor
self.close_executor = close_executor
else:
self.executor = dummy_executor
self.close_executor = False
def close(self):
if self.close_executor:
self.executor.shutdown()
self.executor = None
@run_on_executor
def resolve(self, host, port, family=socket.AF_UNSPEC):
return _resolve_addr(host, port, family)
class BlockingResolver(ExecutorResolver):
"""Default `Resolver` implementation, using `socket.getaddrinfo`.
The `.IOLoop` will be blocked during the resolution, although the
callback will not be run until the next `.IOLoop` iteration.
.. deprecated:: 5.0
The default `Resolver` now uses `.IOLoop.run_in_executor`; use that instead
of this class.
"""
def initialize(self):
super(BlockingResolver, self).initialize()
class ThreadedResolver(ExecutorResolver):
"""Multithreaded non-blocking `Resolver` implementation.
Requires the `concurrent.futures` package to be installed
(available in the standard library since Python 3.2,
installable with ``pip install futures`` in older versions).
The thread pool size can be configured with::
Resolver.configure('tornado.netutil.ThreadedResolver',
num_threads=10)
.. versionchanged:: 3.1
All ``ThreadedResolvers`` share a single thread pool, whose
size is set by the first one to be created.
.. deprecated:: 5.0
The default `Resolver` now uses `.IOLoop.run_in_executor`; use that instead
of this class.
"""
_threadpool = None # type: ignore
_threadpool_pid = None # type: int
def initialize(self, num_threads=10):
threadpool = ThreadedResolver._create_threadpool(num_threads)
super(ThreadedResolver, self).initialize(
executor=threadpool, close_executor=False)
@classmethod
def _create_threadpool(cls, num_threads):
pid = os.getpid()
if cls._threadpool_pid != pid:
# Threads cannot survive after a fork, so if our pid isn't what it
# was when we created the pool then delete it.
cls._threadpool = None
if cls._threadpool is None:
from concurrent.futures import ThreadPoolExecutor
cls._threadpool = ThreadPoolExecutor(num_threads)
cls._threadpool_pid = pid
return cls._threadpool
class OverrideResolver(Resolver):
"""Wraps a resolver with a mapping of overrides.
This can be used to make local DNS changes (e.g. for testing)
without modifying system-wide settings.
The mapping can be in three formats::
{
# Hostname to host or ip
"example.com": "127.0.1.1",
# Host+port to host+port
("login.example.com", 443): ("localhost", 1443),
# Host+port+address family to host+port
("login.example.com", 443, socket.AF_INET6): ("::1", 1443),
}
.. versionchanged:: 5.0
Added support for host-port-family triplets.
"""
def initialize(self, resolver, mapping):
self.resolver = resolver
self.mapping = mapping
def close(self):
self.resolver.close()
def resolve(self, host, port, family=socket.AF_UNSPEC, *args, **kwargs):
if (host, port, family) in self.mapping:
host, port = self.mapping[(host, port, family)]
elif (host, port) in self.mapping:
host, port = self.mapping[(host, port)]
elif host in self.mapping:
host = self.mapping[host]
return self.resolver.resolve(host, port, family, *args, **kwargs)
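# A sketch showing all three mapping forms together (the overrides are
# invented test values, mirroring the class docstring above):
#
#     resolver = OverrideResolver(resolver=Resolver(), mapping={
#         "example.com": "127.0.1.1",
#         ("login.example.com", 443): ("localhost", 1443),
#         ("login.example.com", 443, socket.AF_INET6): ("::1", 1443),
#     })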
# These are the keyword arguments to ssl.wrap_socket that must be translated
# to their SSLContext equivalents (the other arguments are still passed
# to SSLContext.wrap_socket).
_SSL_CONTEXT_KEYWORDS = frozenset(['ssl_version', 'certfile', 'keyfile',
'cert_reqs', 'ca_certs', 'ciphers'])
def ssl_options_to_context(ssl_options):
"""Try to convert an ``ssl_options`` dictionary to an
`~ssl.SSLContext` object.
The ``ssl_options`` dictionary contains keywords to be passed to
`ssl.wrap_socket`. In Python 2.7.9+, `ssl.SSLContext` objects can
be used instead. This function converts the dict form to its
`~ssl.SSLContext` equivalent, and may be used when a component which
accepts both forms needs to upgrade to the `~ssl.SSLContext` version
to use features like SNI or NPN.
"""
if isinstance(ssl_options, ssl.SSLContext):
return ssl_options
assert isinstance(ssl_options, dict)
assert all(k in _SSL_CONTEXT_KEYWORDS for k in ssl_options), ssl_options
# Can't use create_default_context since this interface doesn't
# tell us client vs server.
context = ssl.SSLContext(
ssl_options.get('ssl_version', ssl.PROTOCOL_SSLv23))
if 'certfile' in ssl_options:
context.load_cert_chain(ssl_options['certfile'], ssl_options.get('keyfile', None))
if 'cert_reqs' in ssl_options:
context.verify_mode = ssl_options['cert_reqs']
if 'ca_certs' in ssl_options:
context.load_verify_locations(ssl_options['ca_certs'])
if 'ciphers' in ssl_options:
context.set_ciphers(ssl_options['ciphers'])
if hasattr(ssl, 'OP_NO_COMPRESSION'):
# Disable TLS compression to avoid CRIME and related attacks.
# This constant depends on openssl version 1.0.
# TODO: Do we need to do this ourselves or can we trust
# the defaults?
context.options |= ssl.OP_NO_COMPRESSION
return context
def ssl_wrap_socket(socket, ssl_options, server_hostname=None, **kwargs):
"""Returns an ``ssl.SSLSocket`` wrapping the given socket.
``ssl_options`` may be either an `ssl.SSLContext` object or a
dictionary (as accepted by `ssl_options_to_context`). Additional
keyword arguments are passed to ``wrap_socket`` (either the
`~ssl.SSLContext` method or the `ssl` module function as
appropriate).
"""
context = ssl_options_to_context(ssl_options)
if ssl.HAS_SNI:
# In python 3.4, wrap_socket only accepts the server_hostname
# argument if HAS_SNI is true.
# TODO: add a unittest (python added server-side SNI support in 3.4)
# In the meantime it can be manually tested with
# python3 -m tornado.httpclient https://sni.velox.ch
return context.wrap_socket(socket, server_hostname=server_hostname,
**kwargs)
else:
return context.wrap_socket(socket, **kwargs)
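if __name__ == "__main__":
    # A minimal sketch, not part of the original module: convert a legacy
    # ``ssl_options`` dict to an SSLContext. Only file-less keys are used
    # so it runs as-is; the cipher string is an arbitrary example.
    demo_context = ssl_options_to_context({
        'ssl_version': ssl.PROTOCOL_SSLv23,
        'ciphers': 'ECDHE+AESGCM',
    })
    assert isinstance(demo_context, ssl.SSLContext)
    print('converted:', demo_context)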

654
lib/tornado/options.py Executable file
View File

@@ -0,0 +1,654 @@
#
# Copyright 2009 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""A command line parsing module that lets modules define their own options.
This module is inspired by Google's `gflags
<https://github.com/google/python-gflags>`_. The primary difference
with libraries such as `argparse` is that a global registry is used so
that options may be defined in any module (it also enables
`tornado.log` by default). The rest of Tornado does not depend on this
module, so feel free to use `argparse` or other configuration
libraries if you prefer them.
Options must be defined with `tornado.options.define` before use,
generally at the top level of a module. The options are then
accessible as attributes of `tornado.options.options`::
# myapp/db.py
from tornado.options import define, options
define("mysql_host", default="127.0.0.1:3306", help="Main user DB")
define("memcache_hosts", default="127.0.0.1:11011", multiple=True,
help="Main user memcache servers")
def connect():
db = database.Connection(options.mysql_host)
...
# myapp/server.py
from tornado.options import define, options
define("port", default=8080, help="port to listen on")
def start_server():
app = make_app()
app.listen(options.port)
The ``main()`` method of your application does not need to be aware of all of
the options used throughout your program; they are all automatically loaded
when the modules are loaded. However, all modules that define options
must have been imported before the command line is parsed.
Your ``main()`` method can parse the command line or parse a config file with
either `parse_command_line` or `parse_config_file`::
import myapp.db, myapp.server
import tornado.options
if __name__ == '__main__':
tornado.options.parse_command_line()
# or
tornado.options.parse_config_file("/etc/server.conf")
.. note::
When using multiple ``parse_*`` functions, pass ``final=False`` to all
but the last one, or side effects may occur twice (in particular,
this can result in log messages being doubled).
`tornado.options.options` is a singleton instance of `OptionParser`, and
the top-level functions in this module (`define`, `parse_command_line`, etc)
simply call methods on it. You may create additional `OptionParser`
instances to define isolated sets of options, such as for subcommands.
.. note::
By default, several options are defined that will configure the
standard `logging` module when `parse_command_line` or `parse_config_file`
are called. If you want Tornado to leave the logging configuration
alone so you can manage it yourself, either pass ``--logging=none``
on the command line or do the following to disable it in code::
from tornado.options import options, parse_command_line
options.logging = None
parse_command_line()
.. versionchanged:: 4.3
Dashes and underscores are fully interchangeable in option names;
options can be defined, set, and read with any mix of the two.
Dashes are typical for command-line usage while config files require
underscores.
"""
from __future__ import absolute_import, division, print_function
import datetime
import numbers
import re
import sys
import os
import textwrap
from tornado.escape import _unicode, native_str
from tornado.log import define_logging_options
from tornado import stack_context
from tornado.util import basestring_type, exec_in
class Error(Exception):
"""Exception raised by errors in the options module."""
pass
class OptionParser(object):
"""A collection of options, a dictionary with object-like access.
Normally accessed via static functions in the `tornado.options` module,
which reference a global instance.
"""
def __init__(self):
# we have to use self.__dict__ because we override setattr.
self.__dict__['_options'] = {}
self.__dict__['_parse_callbacks'] = []
self.define("help", type=bool, help="show this help information",
callback=self._help_callback)
def _normalize_name(self, name):
return name.replace('_', '-')
def __getattr__(self, name):
name = self._normalize_name(name)
if isinstance(self._options.get(name), _Option):
return self._options[name].value()
raise AttributeError("Unrecognized option %r" % name)
def __setattr__(self, name, value):
name = self._normalize_name(name)
if isinstance(self._options.get(name), _Option):
return self._options[name].set(value)
raise AttributeError("Unrecognized option %r" % name)
def __iter__(self):
return (opt.name for opt in self._options.values())
def __contains__(self, name):
name = self._normalize_name(name)
return name in self._options
def __getitem__(self, name):
return self.__getattr__(name)
def __setitem__(self, name, value):
return self.__setattr__(name, value)
def items(self):
"""A sequence of (name, value) pairs.
.. versionadded:: 3.1
"""
return [(opt.name, opt.value()) for name, opt in self._options.items()]
def groups(self):
"""The set of option-groups created by ``define``.
.. versionadded:: 3.1
"""
return set(opt.group_name for opt in self._options.values())
def group_dict(self, group):
"""The names and values of options in a group.
Useful for copying options into Application settings::
from tornado.options import define, parse_command_line, options
define('template_path', group='application')
define('static_path', group='application')
parse_command_line()
application = Application(
handlers, **options.group_dict('application'))
.. versionadded:: 3.1
"""
return dict(
(opt.name, opt.value()) for name, opt in self._options.items()
if not group or group == opt.group_name)
def as_dict(self):
"""The names and values of all options.
.. versionadded:: 3.1
"""
return dict(
(opt.name, opt.value()) for name, opt in self._options.items())
def define(self, name, default=None, type=None, help=None, metavar=None,
multiple=False, group=None, callback=None):
"""Defines a new command line option.
``type`` can be any of `str`, `int`, `float`, `bool`,
`~datetime.datetime`, or `~datetime.timedelta`. If no ``type``
is given but a ``default`` is, ``type`` is the type of
``default``. Otherwise, ``type`` defaults to `str`.
If ``multiple`` is True, the option value is a list of ``type``
instead of an instance of ``type``.
``help`` and ``metavar`` are used to construct the
automatically generated command line help string. The help
message is formatted like::
--name=METAVAR help string
``group`` is used to group the defined options in logical
groups. By default, command line options are grouped by the
file in which they are defined.
Command line option names must be unique globally.
If a ``callback`` is given, it will be run with the new value whenever
the option is changed. This can be used to combine command-line
and file-based options::
define("config", type=str, help="path to config file",
callback=lambda path: parse_config_file(path, final=False))
With this definition, options in the file specified by ``--config`` will
override options set earlier on the command line, but can be overridden
by later flags.
"""
normalized = self._normalize_name(name)
if normalized in self._options:
raise Error("Option %r already defined in %s" %
(normalized, self._options[normalized].file_name))
frame = sys._getframe(0)
options_file = frame.f_code.co_filename
# Can be called directly, or through top level define() fn, in which
# case, step up above that frame to look for real caller.
if (frame.f_back.f_code.co_filename == options_file and
frame.f_back.f_code.co_name == 'define'):
frame = frame.f_back
file_name = frame.f_back.f_code.co_filename
if file_name == options_file:
file_name = ""
if type is None:
if not multiple and default is not None:
type = default.__class__
else:
type = str
if group:
group_name = group
else:
group_name = file_name
option = _Option(name, file_name=file_name,
default=default, type=type, help=help,
metavar=metavar, multiple=multiple,
group_name=group_name,
callback=callback)
self._options[normalized] = option
def parse_command_line(self, args=None, final=True):
"""Parses all options given on the command line (defaults to
`sys.argv`).
Options look like ``--option=value`` and are parsed according
to their ``type``. For boolean options, ``--option`` is
        equivalent to ``--option=true``.
        If the option has ``multiple=True``, comma-separated values
        are accepted. For multi-value integer options, the syntax
        ``x:y`` is also accepted; it is inclusive at both ends, i.e.
        equivalent to ``range(x, y + 1)``.
Note that ``args[0]`` is ignored since it is the program name
in `sys.argv`.
We return a list of all arguments that are not parsed as options.
If ``final`` is ``False``, parse callbacks will not be run.
This is useful for applications that wish to combine configurations
from multiple sources.
"""
if args is None:
args = sys.argv
remaining = []
for i in range(1, len(args)):
# All things after the last option are command line arguments
if not args[i].startswith("-"):
remaining = args[i:]
break
if args[i] == "--":
remaining = args[i + 1:]
break
arg = args[i].lstrip("-")
name, equals, value = arg.partition("=")
name = self._normalize_name(name)
if name not in self._options:
self.print_help()
raise Error('Unrecognized command line option: %r' % name)
option = self._options[name]
if not equals:
if option.type == bool:
value = "true"
else:
raise Error('Option %r requires a value' % name)
option.parse(value)
if final:
self.run_parse_callbacks()
return remaining
def parse_config_file(self, path, final=True):
"""Parses and loads the config file at the given path.
The config file contains Python code that will be executed (so
it is **not safe** to use untrusted config files). Anything in
the global namespace that matches a defined option will be
used to set that option's value.
Options may either be the specified type for the option or
strings (in which case they will be parsed the same way as in
`.parse_command_line`)
Example (using the options defined in the top-level docs of
this module)::
port = 80
mysql_host = 'mydb.example.com:3306'
# Both lists and comma-separated strings are allowed for
# multiple=True.
memcache_hosts = ['cache1.example.com:11011',
'cache2.example.com:11011']
memcache_hosts = 'cache1.example.com:11011,cache2.example.com:11011'
If ``final`` is ``False``, parse callbacks will not be run.
This is useful for applications that wish to combine configurations
from multiple sources.
.. note::
`tornado.options` is primarily a command-line library.
Config file support is provided for applications that wish
to use it, but applications that prefer config files may
wish to look at other libraries instead.
.. versionchanged:: 4.1
Config files are now always interpreted as utf-8 instead of
the system default encoding.
.. versionchanged:: 4.4
The special variable ``__file__`` is available inside config
files, specifying the absolute path to the config file itself.
.. versionchanged:: 5.1
Added the ability to set options via strings in config files.
"""
config = {'__file__': os.path.abspath(path)}
with open(path, 'rb') as f:
exec_in(native_str(f.read()), config, config)
for name in config:
normalized = self._normalize_name(name)
if normalized in self._options:
option = self._options[normalized]
if option.multiple:
if not isinstance(config[name], (list, str)):
raise Error("Option %r is required to be a list of %s "
"or a comma-separated string" %
(option.name, option.type.__name__))
if type(config[name]) == str and option.type != str:
option.parse(config[name])
else:
option.set(config[name])
if final:
self.run_parse_callbacks()
def print_help(self, file=None):
"""Prints all the command line options to stderr (or another file)."""
if file is None:
file = sys.stderr
print("Usage: %s [OPTIONS]" % sys.argv[0], file=file)
print("\nOptions:\n", file=file)
by_group = {}
for option in self._options.values():
by_group.setdefault(option.group_name, []).append(option)
for filename, o in sorted(by_group.items()):
if filename:
print("\n%s options:\n" % os.path.normpath(filename), file=file)
o.sort(key=lambda option: option.name)
for option in o:
# Always print names with dashes in a CLI context.
prefix = self._normalize_name(option.name)
if option.metavar:
prefix += "=" + option.metavar
description = option.help or ""
if option.default is not None and option.default != '':
description += " (default %s)" % option.default
lines = textwrap.wrap(description, 79 - 35)
if len(prefix) > 30 or len(lines) == 0:
lines.insert(0, '')
print(" --%-30s %s" % (prefix, lines[0]), file=file)
for line in lines[1:]:
print("%-34s %s" % (' ', line), file=file)
print(file=file)
def _help_callback(self, value):
if value:
self.print_help()
sys.exit(0)
def add_parse_callback(self, callback):
"""Adds a parse callback, to be invoked when option parsing is done."""
self._parse_callbacks.append(stack_context.wrap(callback))
def run_parse_callbacks(self):
for callback in self._parse_callbacks:
callback()
def mockable(self):
"""Returns a wrapper around self that is compatible with
`mock.patch <unittest.mock.patch>`.
The `mock.patch <unittest.mock.patch>` function (included in
the standard library `unittest.mock` package since Python 3.3,
or in the third-party ``mock`` package for older versions of
Python) is incompatible with objects like ``options`` that
override ``__getattr__`` and ``__setattr__``. This function
returns an object that can be used with `mock.patch.object
<unittest.mock.patch.object>` to modify option values::
with mock.patch.object(options.mockable(), 'name', value):
assert options.name == value
"""
return _Mockable(self)
class _Mockable(object):
"""`mock.patch` compatible wrapper for `OptionParser`.
As of ``mock`` version 1.0.1, when an object uses ``__getattr__``
hooks instead of ``__dict__``, ``patch.__exit__`` tries to delete
the attribute it set instead of setting a new one (assuming that
    the object does not capture ``__setattr__``, so the patch
created a new attribute in ``__dict__``).
_Mockable's getattr and setattr pass through to the underlying
OptionParser, and delattr undoes the effect of a previous setattr.
"""
def __init__(self, options):
# Modify __dict__ directly to bypass __setattr__
self.__dict__['_options'] = options
self.__dict__['_originals'] = {}
def __getattr__(self, name):
return getattr(self._options, name)
def __setattr__(self, name, value):
assert name not in self._originals, "don't reuse mockable objects"
self._originals[name] = getattr(self._options, name)
setattr(self._options, name, value)
def __delattr__(self, name):
setattr(self._options, name, self._originals.pop(name))
class _Option(object):
UNSET = object()
def __init__(self, name, default=None, type=basestring_type, help=None,
metavar=None, multiple=False, file_name=None, group_name=None,
callback=None):
if default is None and multiple:
default = []
self.name = name
self.type = type
self.help = help
self.metavar = metavar
self.multiple = multiple
self.file_name = file_name
self.group_name = group_name
self.callback = callback
self.default = default
self._value = _Option.UNSET
def value(self):
return self.default if self._value is _Option.UNSET else self._value
def parse(self, value):
_parse = {
datetime.datetime: self._parse_datetime,
datetime.timedelta: self._parse_timedelta,
bool: self._parse_bool,
basestring_type: self._parse_string,
}.get(self.type, self.type)
if self.multiple:
self._value = []
for part in value.split(","):
if issubclass(self.type, numbers.Integral):
# allow ranges of the form X:Y (inclusive at both ends)
lo, _, hi = part.partition(":")
lo = _parse(lo)
hi = _parse(hi) if hi else lo
self._value.extend(range(lo, hi + 1))
else:
self._value.append(_parse(part))
else:
self._value = _parse(value)
if self.callback is not None:
self.callback(self._value)
return self.value()
def set(self, value):
if self.multiple:
if not isinstance(value, list):
raise Error("Option %r is required to be a list of %s" %
(self.name, self.type.__name__))
for item in value:
if item is not None and not isinstance(item, self.type):
raise Error("Option %r is required to be a list of %s" %
(self.name, self.type.__name__))
else:
if value is not None and not isinstance(value, self.type):
raise Error("Option %r is required to be a %s (%s given)" %
(self.name, self.type.__name__, type(value)))
self._value = value
if self.callback is not None:
self.callback(self._value)
# Supported date/time formats in our options
_DATETIME_FORMATS = [
"%a %b %d %H:%M:%S %Y",
"%Y-%m-%d %H:%M:%S",
"%Y-%m-%d %H:%M",
"%Y-%m-%dT%H:%M",
"%Y%m%d %H:%M:%S",
"%Y%m%d %H:%M",
"%Y-%m-%d",
"%Y%m%d",
"%H:%M:%S",
"%H:%M",
]
def _parse_datetime(self, value):
for format in self._DATETIME_FORMATS:
try:
return datetime.datetime.strptime(value, format)
except ValueError:
pass
raise Error('Unrecognized date/time format: %r' % value)
_TIMEDELTA_ABBREV_DICT = {
'h': 'hours',
'm': 'minutes',
'min': 'minutes',
's': 'seconds',
'sec': 'seconds',
'ms': 'milliseconds',
'us': 'microseconds',
'd': 'days',
'w': 'weeks',
}
_FLOAT_PATTERN = r'[-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?'
_TIMEDELTA_PATTERN = re.compile(
r'\s*(%s)\s*(\w*)\s*' % _FLOAT_PATTERN, re.IGNORECASE)
def _parse_timedelta(self, value):
try:
sum = datetime.timedelta()
start = 0
while start < len(value):
m = self._TIMEDELTA_PATTERN.match(value, start)
if not m:
raise Exception()
num = float(m.group(1))
units = m.group(2) or 'seconds'
units = self._TIMEDELTA_ABBREV_DICT.get(units, units)
sum += datetime.timedelta(**{units: num})
start = m.end()
return sum
except Exception:
raise
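    # Illustrative inputs for _parse_timedelta (values assumed, not taken
    # from the docs): "2h 30m" -> timedelta(hours=2, minutes=30);
    # "1.5s" -> timedelta(seconds=1.5); "45" -> timedelta(seconds=45.0).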
def _parse_bool(self, value):
return value.lower() not in ("false", "0", "f")
def _parse_string(self, value):
return _unicode(value)
options = OptionParser()
"""Global options object.
All defined options are available as attributes on this object.
"""
def define(name, default=None, type=None, help=None, metavar=None,
multiple=False, group=None, callback=None):
"""Defines an option in the global namespace.
See `OptionParser.define`.
"""
return options.define(name, default=default, type=type, help=help,
metavar=metavar, multiple=multiple, group=group,
callback=callback)
def parse_command_line(args=None, final=True):
"""Parses global options from the command line.
See `OptionParser.parse_command_line`.
"""
return options.parse_command_line(args, final=final)
def parse_config_file(path, final=True):
"""Parses global options from a config file.
See `OptionParser.parse_config_file`.
"""
return options.parse_config_file(path, final=final)
def print_help(file=None):
"""Prints all the command line options to stderr (or another file).
See `OptionParser.print_help`.
"""
return options.print_help(file)
def add_parse_callback(callback):
"""Adds a parse callback, to be invoked when option parsing is done.
See `OptionParser.add_parse_callback`
"""
options.add_parse_callback(callback)
# Default options
define_logging_options(options)
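if __name__ == "__main__":
    # A self-contained sketch of the API above; the option names and the
    # argv values are invented for illustration.
    define("port", default=8080, help="port to listen on")
    define("workers", default=[1], type=int, multiple=True,
           help="worker ids; accepts 1,2,3 or inclusive ranges like 1:4")
    parse_command_line(["prog", "--port=9000", "--workers=1:3"])
    print(options.port)     # 9000
    print(options.workers)  # [1, 2, 3]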

299
lib/tornado/platform/asyncio.py Executable file
View File

@@ -0,0 +1,299 @@
"""Bridges between the `asyncio` module and Tornado IOLoop.
.. versionadded:: 3.2
This module integrates Tornado with the ``asyncio`` module introduced
in Python 3.4. This makes it possible to combine the two libraries on
the same event loop.
.. deprecated:: 5.0
While the code in this module is still used, it is now enabled
automatically when `asyncio` is available, so applications should
no longer need to refer to this module directly.
.. note::
Tornado requires the `~asyncio.AbstractEventLoop.add_reader` family of
methods, so it is not compatible with the `~asyncio.ProactorEventLoop` on
Windows. Use the `~asyncio.SelectorEventLoop` instead.
"""
from __future__ import absolute_import, division, print_function
import functools
from tornado.gen import convert_yielded
from tornado.ioloop import IOLoop
from tornado import stack_context
import asyncio
class BaseAsyncIOLoop(IOLoop):
def initialize(self, asyncio_loop, **kwargs):
self.asyncio_loop = asyncio_loop
# Maps fd to (fileobj, handler function) pair (as in IOLoop.add_handler)
self.handlers = {}
# Set of fds listening for reads/writes
self.readers = set()
self.writers = set()
self.closing = False
# If an asyncio loop was closed through an asyncio interface
# instead of IOLoop.close(), we'd never hear about it and may
# have left a dangling reference in our map. In case an
# application (or, more likely, a test suite) creates and
# destroys a lot of event loops in this way, check here to
# ensure that we don't have a lot of dead loops building up in
# the map.
#
# TODO(bdarnell): consider making self.asyncio_loop a weakref
# for AsyncIOMainLoop and make _ioloop_for_asyncio a
# WeakKeyDictionary.
for loop in list(IOLoop._ioloop_for_asyncio):
if loop.is_closed():
del IOLoop._ioloop_for_asyncio[loop]
IOLoop._ioloop_for_asyncio[asyncio_loop] = self
super(BaseAsyncIOLoop, self).initialize(**kwargs)
def close(self, all_fds=False):
self.closing = True
for fd in list(self.handlers):
fileobj, handler_func = self.handlers[fd]
self.remove_handler(fd)
if all_fds:
self.close_fd(fileobj)
# Remove the mapping before closing the asyncio loop. If this
# happened in the other order, we could race against another
# initialize() call which would see the closed asyncio loop,
# assume it was closed from the asyncio side, and do this
# cleanup for us, leading to a KeyError.
del IOLoop._ioloop_for_asyncio[self.asyncio_loop]
self.asyncio_loop.close()
def add_handler(self, fd, handler, events):
fd, fileobj = self.split_fd(fd)
if fd in self.handlers:
raise ValueError("fd %s added twice" % fd)
self.handlers[fd] = (fileobj, stack_context.wrap(handler))
if events & IOLoop.READ:
self.asyncio_loop.add_reader(
fd, self._handle_events, fd, IOLoop.READ)
self.readers.add(fd)
if events & IOLoop.WRITE:
self.asyncio_loop.add_writer(
fd, self._handle_events, fd, IOLoop.WRITE)
self.writers.add(fd)
def update_handler(self, fd, events):
fd, fileobj = self.split_fd(fd)
if events & IOLoop.READ:
if fd not in self.readers:
self.asyncio_loop.add_reader(
fd, self._handle_events, fd, IOLoop.READ)
self.readers.add(fd)
else:
if fd in self.readers:
self.asyncio_loop.remove_reader(fd)
self.readers.remove(fd)
if events & IOLoop.WRITE:
if fd not in self.writers:
self.asyncio_loop.add_writer(
fd, self._handle_events, fd, IOLoop.WRITE)
self.writers.add(fd)
else:
if fd in self.writers:
self.asyncio_loop.remove_writer(fd)
self.writers.remove(fd)
def remove_handler(self, fd):
fd, fileobj = self.split_fd(fd)
if fd not in self.handlers:
return
if fd in self.readers:
self.asyncio_loop.remove_reader(fd)
self.readers.remove(fd)
if fd in self.writers:
self.asyncio_loop.remove_writer(fd)
self.writers.remove(fd)
del self.handlers[fd]
def _handle_events(self, fd, events):
fileobj, handler_func = self.handlers[fd]
handler_func(fileobj, events)
def start(self):
try:
old_loop = asyncio.get_event_loop()
except (RuntimeError, AssertionError):
old_loop = None
try:
self._setup_logging()
asyncio.set_event_loop(self.asyncio_loop)
self.asyncio_loop.run_forever()
finally:
asyncio.set_event_loop(old_loop)
def stop(self):
self.asyncio_loop.stop()
def call_at(self, when, callback, *args, **kwargs):
# asyncio.call_at supports *args but not **kwargs, so bind them here.
# We do not synchronize self.time and asyncio_loop.time, so
# convert from absolute to relative.
return self.asyncio_loop.call_later(
max(0, when - self.time()), self._run_callback,
functools.partial(stack_context.wrap(callback), *args, **kwargs))
def remove_timeout(self, timeout):
timeout.cancel()
def add_callback(self, callback, *args, **kwargs):
try:
self.asyncio_loop.call_soon_threadsafe(
self._run_callback,
functools.partial(stack_context.wrap(callback), *args, **kwargs))
except RuntimeError:
# "Event loop is closed". Swallow the exception for
# consistency with PollIOLoop (and logical consistency
# with the fact that we can't guarantee that an
# add_callback that completes without error will
# eventually execute).
pass
add_callback_from_signal = add_callback
def run_in_executor(self, executor, func, *args):
return self.asyncio_loop.run_in_executor(executor, func, *args)
def set_default_executor(self, executor):
return self.asyncio_loop.set_default_executor(executor)
class AsyncIOMainLoop(BaseAsyncIOLoop):
"""``AsyncIOMainLoop`` creates an `.IOLoop` that corresponds to the
current ``asyncio`` event loop (i.e. the one returned by
``asyncio.get_event_loop()``).
.. deprecated:: 5.0
Now used automatically when appropriate; it is no longer necessary
to refer to this class directly.
.. versionchanged:: 5.0
Closing an `AsyncIOMainLoop` now closes the underlying asyncio loop.
"""
def initialize(self, **kwargs):
super(AsyncIOMainLoop, self).initialize(asyncio.get_event_loop(), **kwargs)
def make_current(self):
# AsyncIOMainLoop already refers to the current asyncio loop so
# nothing to do here.
pass
class AsyncIOLoop(BaseAsyncIOLoop):
"""``AsyncIOLoop`` is an `.IOLoop` that runs on an ``asyncio`` event loop.
This class follows the usual Tornado semantics for creating new
``IOLoops``; these loops are not necessarily related to the
``asyncio`` default event loop.
Each ``AsyncIOLoop`` creates a new ``asyncio.EventLoop``; this object
can be accessed with the ``asyncio_loop`` attribute.
.. versionchanged:: 5.0
When an ``AsyncIOLoop`` becomes the current `.IOLoop`, it also sets
the current `asyncio` event loop.
.. deprecated:: 5.0
Now used automatically when appropriate; it is no longer necessary
to refer to this class directly.
"""
def initialize(self, **kwargs):
self.is_current = False
loop = asyncio.new_event_loop()
try:
super(AsyncIOLoop, self).initialize(loop, **kwargs)
except Exception:
# If initialize() does not succeed (taking ownership of the loop),
# we have to close it.
loop.close()
raise
def close(self, all_fds=False):
if self.is_current:
self.clear_current()
super(AsyncIOLoop, self).close(all_fds=all_fds)
def make_current(self):
if not self.is_current:
try:
self.old_asyncio = asyncio.get_event_loop()
except (RuntimeError, AssertionError):
self.old_asyncio = None
self.is_current = True
asyncio.set_event_loop(self.asyncio_loop)
def _clear_current_hook(self):
if self.is_current:
asyncio.set_event_loop(self.old_asyncio)
self.is_current = False
def to_tornado_future(asyncio_future):
"""Convert an `asyncio.Future` to a `tornado.concurrent.Future`.
.. versionadded:: 4.1
.. deprecated:: 5.0
Tornado ``Futures`` have been merged with `asyncio.Future`,
so this method is now a no-op.
"""
return asyncio_future
def to_asyncio_future(tornado_future):
"""Convert a Tornado yieldable object to an `asyncio.Future`.
.. versionadded:: 4.1
.. versionchanged:: 4.3
Now accepts any yieldable object, not just
`tornado.concurrent.Future`.
.. deprecated:: 5.0
Tornado ``Futures`` have been merged with `asyncio.Future`,
so this method is now equivalent to `tornado.gen.convert_yielded`.
"""
return convert_yielded(tornado_future)
class AnyThreadEventLoopPolicy(asyncio.DefaultEventLoopPolicy):
"""Event loop policy that allows loop creation on any thread.
The default `asyncio` event loop policy only automatically creates
event loops in the main threads. Other threads must create event
loops explicitly or `asyncio.get_event_loop` (and therefore
`.IOLoop.current`) will fail. Installing this policy allows event
loops to be created automatically on any thread, matching the
behavior of Tornado versions prior to 5.0 (or 5.0 on Python 2).
Usage::
asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
.. versionadded:: 5.0
"""
def get_event_loop(self):
try:
return super().get_event_loop()
except (RuntimeError, AssertionError):
# This was an AssertionError in python 3.4.2 (which ships with debian jessie)
# and changed to a RuntimeError in 3.4.3.
# "There is no current event loop in thread %r"
loop = self.new_event_loop()
self.set_event_loop(loop)
return loop
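if __name__ == "__main__":
    # A hedged bridging sketch (not part of the module): await a Tornado
    # coroutine from asyncio code, relying on Tornado 5.x Futures being
    # asyncio-compatible. The sleep duration and return value are arbitrary.
    from tornado import gen

    @gen.coroutine
    def tornado_task():
        yield gen.sleep(0.1)
        raise gen.Return(42)

    async def main():
        return await to_asyncio_future(tornado_task())

    print(asyncio.get_event_loop().run_until_complete(main()))  # 42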

58
lib/tornado/platform/auto.py Executable file
View File

@@ -0,0 +1,58 @@
#
# Copyright 2011 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""Implementation of platform-specific functionality.
For each function or class described in `tornado.platform.interface`,
the appropriate platform-specific implementation exists in this module.
Most code that needs access to this functionality should do e.g.::
from tornado.platform.auto import set_close_exec
"""
from __future__ import absolute_import, division, print_function
import os
if 'APPENGINE_RUNTIME' in os.environ:
from tornado.platform.common import Waker
def set_close_exec(fd):
pass
elif os.name == 'nt':
from tornado.platform.common import Waker
from tornado.platform.windows import set_close_exec
else:
from tornado.platform.posix import set_close_exec, Waker
try:
# monotime monkey-patches the time module to have a monotonic function
# in versions of python before 3.3.
import monotime
# Silence pyflakes warning about this unused import
monotime
except ImportError:
pass
try:
# monotonic can provide a monotonic function in versions of python before
# 3.3, too.
from monotonic import monotonic as monotonic_time
except ImportError:
try:
from time import monotonic as monotonic_time
except ImportError:
monotonic_time = None
__all__ = ['Waker', 'set_close_exec', 'monotonic_time']

4
lib/tornado/platform/auto.pyi Executable file
View File

@@ -0,0 +1,4 @@
# auto.py is full of patterns mypy doesn't like, so for type checking
# purposes we replace it with interface.py.
from .interface import *

79
lib/tornado/platform/caresresolver.py Executable file
View File

@@ -0,0 +1,79 @@
from __future__ import absolute_import, division, print_function
import pycares # type: ignore
import socket
from tornado.concurrent import Future
from tornado import gen
from tornado.ioloop import IOLoop
from tornado.netutil import Resolver, is_valid_ip
class CaresResolver(Resolver):
"""Name resolver based on the c-ares library.
This is a non-blocking and non-threaded resolver. It may not produce
the same results as the system resolver, but can be used for non-blocking
resolution when threads cannot be used.
c-ares fails to resolve some names when ``family`` is ``AF_UNSPEC``,
so it is only recommended for use in ``AF_INET`` (i.e. IPv4). This is
the default for ``tornado.simple_httpclient``, but other libraries
may default to ``AF_UNSPEC``.
.. versionchanged:: 5.0
The ``io_loop`` argument (deprecated since version 4.1) has been removed.
"""
def initialize(self):
self.io_loop = IOLoop.current()
self.channel = pycares.Channel(sock_state_cb=self._sock_state_cb)
self.fds = {}
def _sock_state_cb(self, fd, readable, writable):
state = ((IOLoop.READ if readable else 0) |
(IOLoop.WRITE if writable else 0))
if not state:
self.io_loop.remove_handler(fd)
del self.fds[fd]
elif fd in self.fds:
self.io_loop.update_handler(fd, state)
self.fds[fd] = state
else:
self.io_loop.add_handler(fd, self._handle_events, state)
self.fds[fd] = state
def _handle_events(self, fd, events):
read_fd = pycares.ARES_SOCKET_BAD
write_fd = pycares.ARES_SOCKET_BAD
if events & IOLoop.READ:
read_fd = fd
if events & IOLoop.WRITE:
write_fd = fd
self.channel.process_fd(read_fd, write_fd)
@gen.coroutine
def resolve(self, host, port, family=0):
if is_valid_ip(host):
addresses = [host]
else:
# gethostbyname doesn't take callback as a kwarg
fut = Future()
self.channel.gethostbyname(host, family,
lambda result, error: fut.set_result((result, error)))
result, error = yield fut
if error:
raise IOError('C-Ares returned error %s: %s while resolving %s' %
(error, pycares.errno.strerror(error), host))
addresses = result.addresses
addrinfo = []
for address in addresses:
if '.' in address:
address_family = socket.AF_INET
elif ':' in address:
address_family = socket.AF_INET6
else:
address_family = socket.AF_UNSPEC
if family != socket.AF_UNSPEC and family != address_family:
raise IOError('Requested socket family %d but got %d' %
(family, address_family))
addrinfo.append((address_family, (address, port)))
raise gen.Return(addrinfo)
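if __name__ == "__main__":
    # A hedged sketch (requires pycares): install CaresResolver as the
    # global Resolver implementation, then resolve a placeholder IPv4 name.
    Resolver.configure('tornado.platform.caresresolver.CaresResolver')
    print(IOLoop.current().run_sync(
        lambda: Resolver().resolve('example.com', 80, socket.AF_INET)))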

113
lib/tornado/platform/common.py Executable file
View File

@@ -0,0 +1,113 @@
"""Lowest-common-denominator implementations of platform functionality."""
from __future__ import absolute_import, division, print_function
import errno
import socket
import time
from tornado.platform import interface
from tornado.util import errno_from_exception
def try_close(f):
# Avoid issue #875 (race condition when using the file in another
# thread).
for i in range(10):
try:
f.close()
except IOError:
# Yield to another thread
time.sleep(1e-3)
else:
break
# Try a last time and let raise
f.close()
class Waker(interface.Waker):
"""Create an OS independent asynchronous pipe.
For use on platforms that don't have os.pipe() (or where pipes cannot
be passed to select()), but do have sockets. This includes Windows
and Jython.
"""
def __init__(self):
from .auto import set_close_exec
# Based on Zope select_trigger.py:
# https://github.com/zopefoundation/Zope/blob/master/src/ZServer/medusa/thread/select_trigger.py
self.writer = socket.socket()
set_close_exec(self.writer.fileno())
# Disable buffering -- pulling the trigger sends 1 byte,
# and we want that sent immediately, to wake up ASAP.
self.writer.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
count = 0
while 1:
count += 1
# Bind to a local port; for efficiency, let the OS pick
# a free port for us.
# Unfortunately, stress tests showed that we may not
# be able to connect to that port ("Address already in
# use") despite that the OS picked it. This appears
# to be a race bug in the Windows socket implementation.
# So we loop until a connect() succeeds (almost always
# on the first try). See the long thread at
# http://mail.zope.org/pipermail/zope/2005-July/160433.html
# for hideous details.
a = socket.socket()
set_close_exec(a.fileno())
a.bind(("127.0.0.1", 0))
a.listen(1)
connect_address = a.getsockname() # assigned (host, port) pair
try:
self.writer.connect(connect_address)
break # success
except socket.error as detail:
if (not hasattr(errno, 'WSAEADDRINUSE') or
errno_from_exception(detail) != errno.WSAEADDRINUSE):
# "Address already in use" is the only error
# I've seen on two WinXP Pro SP2 boxes, under
# Pythons 2.3.5 and 2.4.1.
raise
# (10048, 'Address already in use')
# assert count <= 2 # never triggered in Tim's tests
if count >= 10: # I've never seen it go above 2
a.close()
self.writer.close()
raise socket.error("Cannot bind trigger!")
# Close `a` and try again. Note: I originally put a short
# sleep() here, but it didn't appear to help or hurt.
a.close()
self.reader, addr = a.accept()
set_close_exec(self.reader.fileno())
self.reader.setblocking(0)
self.writer.setblocking(0)
a.close()
self.reader_fd = self.reader.fileno()
def fileno(self):
return self.reader.fileno()
def write_fileno(self):
return self.writer.fileno()
def wake(self):
try:
self.writer.send(b"x")
except (IOError, socket.error, ValueError):
pass
def consume(self):
try:
while True:
result = self.reader.recv(1024)
if not result:
break
except (IOError, socket.error):
pass
def close(self):
self.reader.close()
try_close(self.writer)
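if __name__ == "__main__":
    # A small self-test sketch (assumed usage; in practice the IOLoop
    # drives the Waker): wake() makes the read fd readable and consume()
    # drains it again.
    import select

    waker = Waker()
    waker.wake()
    readable, _, _ = select.select([waker.fileno()], [], [], 1.0)
    assert readable, "waker should be readable after wake()"
    waker.consume()
    waker.close()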

25
lib/tornado/platform/epoll.py Executable file
View File

@@ -0,0 +1,25 @@
#
# Copyright 2012 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""EPoll-based IOLoop implementation for Linux systems."""
from __future__ import absolute_import, division, print_function
import select
from tornado.ioloop import PollIOLoop
class EPollIOLoop(PollIOLoop):
def initialize(self, **kwargs):
super(EPollIOLoop, self).initialize(impl=select.epoll(), **kwargs)

66
lib/tornado/platform/interface.py Executable file
View File

@@ -0,0 +1,66 @@
#
# Copyright 2011 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""Interfaces for platform-specific functionality.
This module exists primarily for documentation purposes and as base classes
for other tornado.platform modules. Most code should import the appropriate
implementation from `tornado.platform.auto`.
"""
from __future__ import absolute_import, division, print_function
def set_close_exec(fd):
"""Sets the close-on-exec bit (``FD_CLOEXEC``)for a file descriptor."""
raise NotImplementedError()
class Waker(object):
"""A socket-like object that can wake another thread from ``select()``.
The `~tornado.ioloop.IOLoop` will add the Waker's `fileno()` to
its ``select`` (or ``epoll`` or ``kqueue``) calls. When another
thread wants to wake up the loop, it calls `wake`. Once it has woken
up, it will call `consume` to do any necessary per-wake cleanup. When
the ``IOLoop`` is closed, it closes its waker too.
"""
def fileno(self):
"""Returns the read file descriptor for this waker.
Must be suitable for use with ``select()`` or equivalent on the
local platform.
"""
raise NotImplementedError()
def write_fileno(self):
"""Returns the write file descriptor for this waker."""
raise NotImplementedError()
def wake(self):
"""Triggers activity on the waker's file descriptor."""
raise NotImplementedError()
def consume(self):
"""Called after the listen has woken up to do any necessary cleanup."""
raise NotImplementedError()
def close(self):
"""Closes the waker's file descriptor(s)."""
raise NotImplementedError()
def monotonic_time():
raise NotImplementedError()

90
lib/tornado/platform/kqueue.py Executable file
View File

@@ -0,0 +1,90 @@
#
# Copyright 2012 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""KQueue-based IOLoop implementation for BSD/Mac systems."""
from __future__ import absolute_import, division, print_function
import select
from tornado.ioloop import IOLoop, PollIOLoop
assert hasattr(select, 'kqueue'), 'kqueue not supported'
class _KQueue(object):
"""A kqueue-based event loop for BSD/Mac systems."""
def __init__(self):
self._kqueue = select.kqueue()
self._active = {}
def fileno(self):
return self._kqueue.fileno()
def close(self):
self._kqueue.close()
def register(self, fd, events):
if fd in self._active:
raise IOError("fd %s already registered" % fd)
self._control(fd, events, select.KQ_EV_ADD)
self._active[fd] = events
def modify(self, fd, events):
self.unregister(fd)
self.register(fd, events)
def unregister(self, fd):
events = self._active.pop(fd)
self._control(fd, events, select.KQ_EV_DELETE)
def _control(self, fd, events, flags):
kevents = []
if events & IOLoop.WRITE:
kevents.append(select.kevent(
fd, filter=select.KQ_FILTER_WRITE, flags=flags))
if events & IOLoop.READ:
kevents.append(select.kevent(
fd, filter=select.KQ_FILTER_READ, flags=flags))
# Even though control() takes a list, it seems to return EINVAL
# on Mac OS X (10.6) when there is more than one event in the list.
for kevent in kevents:
self._kqueue.control([kevent], 0)
def poll(self, timeout):
kevents = self._kqueue.control(None, 1000, timeout)
events = {}
for kevent in kevents:
fd = kevent.ident
if kevent.filter == select.KQ_FILTER_READ:
events[fd] = events.get(fd, 0) | IOLoop.READ
if kevent.filter == select.KQ_FILTER_WRITE:
if kevent.flags & select.KQ_EV_EOF:
# If an asynchronous connection is refused, kqueue
# returns a write event with the EOF flag set.
# Turn this into an error for consistency with the
# other IOLoop implementations.
# Note that for read events, EOF may be returned before
# all data has been consumed from the socket buffer,
# so we only check for EOF on write events.
events[fd] = IOLoop.ERROR
else:
events[fd] = events.get(fd, 0) | IOLoop.WRITE
if kevent.flags & select.KQ_EV_ERROR:
events[fd] = events.get(fd, 0) | IOLoop.ERROR
return events.items()
class KQueueIOLoop(PollIOLoop):
def initialize(self, **kwargs):
super(KQueueIOLoop, self).initialize(impl=_KQueue(), **kwargs)

69
lib/tornado/platform/posix.py Executable file
View File

@@ -0,0 +1,69 @@
#
# Copyright 2011 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""Posix implementations of platform-specific functionality."""
from __future__ import absolute_import, division, print_function
import fcntl
import os
from tornado.platform import common, interface
def set_close_exec(fd):
flags = fcntl.fcntl(fd, fcntl.F_GETFD)
fcntl.fcntl(fd, fcntl.F_SETFD, flags | fcntl.FD_CLOEXEC)
def _set_nonblocking(fd):
flags = fcntl.fcntl(fd, fcntl.F_GETFL)
fcntl.fcntl(fd, fcntl.F_SETFL, flags | os.O_NONBLOCK)
class Waker(interface.Waker):
def __init__(self):
r, w = os.pipe()
_set_nonblocking(r)
_set_nonblocking(w)
set_close_exec(r)
set_close_exec(w)
self.reader = os.fdopen(r, "rb", 0)
self.writer = os.fdopen(w, "wb", 0)
def fileno(self):
return self.reader.fileno()
def write_fileno(self):
return self.writer.fileno()
def wake(self):
try:
self.writer.write(b"x")
except (IOError, ValueError):
pass
def consume(self):
try:
while True:
result = self.reader.read()
if not result:
break
except IOError:
pass
def close(self):
self.reader.close()
common.try_close(self.writer)

75
lib/tornado/platform/select.py Executable file
View File

@@ -0,0 +1,75 @@
#
# Copyright 2012 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""Select-based IOLoop implementation.
Used as a fallback for systems that don't support epoll or kqueue.
"""
from __future__ import absolute_import, division, print_function
import select
from tornado.ioloop import IOLoop, PollIOLoop
class _Select(object):
"""A simple, select()-based IOLoop implementation for non-Linux systems"""
def __init__(self):
self.read_fds = set()
self.write_fds = set()
self.error_fds = set()
self.fd_sets = (self.read_fds, self.write_fds, self.error_fds)
def close(self):
pass
def register(self, fd, events):
if fd in self.read_fds or fd in self.write_fds or fd in self.error_fds:
raise IOError("fd %s already registered" % fd)
if events & IOLoop.READ:
self.read_fds.add(fd)
if events & IOLoop.WRITE:
self.write_fds.add(fd)
if events & IOLoop.ERROR:
self.error_fds.add(fd)
# Closed connections are reported as errors by epoll and kqueue,
# but as zero-byte reads by select, so when errors are requested
# we need to listen for both read and error.
# self.read_fds.add(fd)
def modify(self, fd, events):
self.unregister(fd)
self.register(fd, events)
def unregister(self, fd):
self.read_fds.discard(fd)
self.write_fds.discard(fd)
self.error_fds.discard(fd)
def poll(self, timeout):
readable, writeable, errors = select.select(
self.read_fds, self.write_fds, self.error_fds, timeout)
events = {}
for fd in readable:
events[fd] = events.get(fd, 0) | IOLoop.READ
for fd in writeable:
events[fd] = events.get(fd, 0) | IOLoop.WRITE
for fd in errors:
events[fd] = events.get(fd, 0) | IOLoop.ERROR
return events.items()
class SelectIOLoop(PollIOLoop):
def initialize(self, **kwargs):
super(SelectIOLoop, self).initialize(impl=_Select(), **kwargs)
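if __name__ == "__main__":
    # A sketch of the poll-impl contract PollIOLoop relies on (the
    # socketpair and timeout are illustrative):
    import socket

    a, b = socket.socketpair()
    impl = _Select()
    impl.register(a.fileno(), IOLoop.READ)
    b.send(b"x")
    events = dict(impl.poll(0.1))
    assert events[a.fileno()] & IOLoop.READ
    print("read event delivered for fd", a.fileno())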

609
lib/tornado/platform/twisted.py Executable file
View File

@@ -0,0 +1,609 @@
# Author: Ovidiu Predescu
# Date: July 2011
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""Bridges between the Twisted reactor and Tornado IOLoop.
This module lets you run applications and libraries written for
Twisted in a Tornado application. It can be used in two modes,
depending on which library's underlying event loop you want to use.
This module has been tested with Twisted versions 11.0.0 and newer.
"""
from __future__ import absolute_import, division, print_function
import datetime
import functools
import numbers
import socket
import sys
import twisted.internet.abstract # type: ignore
from twisted.internet.defer import Deferred # type: ignore
from twisted.internet.posixbase import PosixReactorBase # type: ignore
from twisted.internet.interfaces import IReactorFDSet, IDelayedCall, IReactorTime, IReadDescriptor, IWriteDescriptor # type: ignore # noqa: E501
from twisted.python import failure, log # type: ignore
from twisted.internet import error # type: ignore
import twisted.names.cache # type: ignore
import twisted.names.client # type: ignore
import twisted.names.hosts # type: ignore
import twisted.names.resolve # type: ignore
from zope.interface import implementer # type: ignore
from tornado.concurrent import Future, future_set_exc_info
from tornado.escape import utf8
from tornado import gen
import tornado.ioloop
from tornado.log import app_log
from tornado.netutil import Resolver
from tornado.stack_context import NullContext, wrap
from tornado.ioloop import IOLoop
from tornado.util import timedelta_to_seconds
@implementer(IDelayedCall)
class TornadoDelayedCall(object):
"""DelayedCall object for Tornado."""
def __init__(self, reactor, seconds, f, *args, **kw):
self._reactor = reactor
self._func = functools.partial(f, *args, **kw)
self._time = self._reactor.seconds() + seconds
self._timeout = self._reactor._io_loop.add_timeout(self._time,
self._called)
self._active = True
def _called(self):
self._active = False
self._reactor._removeDelayedCall(self)
try:
self._func()
except:
app_log.error("_called caught exception", exc_info=True)
def getTime(self):
return self._time
def cancel(self):
self._active = False
self._reactor._io_loop.remove_timeout(self._timeout)
self._reactor._removeDelayedCall(self)
def delay(self, seconds):
self._reactor._io_loop.remove_timeout(self._timeout)
self._time += seconds
self._timeout = self._reactor._io_loop.add_timeout(self._time,
self._called)
def reset(self, seconds):
self._reactor._io_loop.remove_timeout(self._timeout)
self._time = self._reactor.seconds() + seconds
self._timeout = self._reactor._io_loop.add_timeout(self._time,
self._called)
def active(self):
return self._active
@implementer(IReactorTime, IReactorFDSet)
class TornadoReactor(PosixReactorBase):
"""Twisted reactor built on the Tornado IOLoop.
`TornadoReactor` implements the Twisted reactor interface on top of
the Tornado IOLoop. To use it, simply call `install` at the beginning
of the application::
import tornado.platform.twisted
tornado.platform.twisted.install()
from twisted.internet import reactor
When the app is ready to start, call ``IOLoop.current().start()``
instead of ``reactor.run()``.
It is also possible to create a non-global reactor by calling
``tornado.platform.twisted.TornadoReactor()``. However, if
the `.IOLoop` and reactor are to be short-lived (such as those used in
unit tests), additional cleanup may be required. Specifically, it is
recommended to call::
reactor.fireSystemEvent('shutdown')
reactor.disconnectAll()
before closing the `.IOLoop`.
.. versionchanged:: 5.0
The ``io_loop`` argument (deprecated since version 4.1) has been removed.
.. deprecated:: 5.1
This class will be removed in Tornado 6.0. Use
``twisted.internet.asyncioreactor.AsyncioSelectorReactor``
instead.
"""
def __init__(self):
self._io_loop = tornado.ioloop.IOLoop.current()
self._readers = {} # map of reader objects to fd
self._writers = {} # map of writer objects to fd
self._fds = {} # a map of fd to a (reader, writer) tuple
self._delayedCalls = {}
PosixReactorBase.__init__(self)
self.addSystemEventTrigger('during', 'shutdown', self.crash)
# IOLoop.start() bypasses some of the reactor initialization.
# Fire off the necessary events if they weren't already triggered
# by reactor.run().
def start_if_necessary():
if not self._started:
self.fireSystemEvent('startup')
self._io_loop.add_callback(start_if_necessary)
# IReactorTime
def seconds(self):
return self._io_loop.time()
def callLater(self, seconds, f, *args, **kw):
dc = TornadoDelayedCall(self, seconds, f, *args, **kw)
self._delayedCalls[dc] = True
return dc
def getDelayedCalls(self):
return [x for x in self._delayedCalls if x._active]
def _removeDelayedCall(self, dc):
if dc in self._delayedCalls:
del self._delayedCalls[dc]
# IReactorThreads
def callFromThread(self, f, *args, **kw):
assert callable(f), "%s is not callable" % f
with NullContext():
# This NullContext is mainly for an edge case when running
# TwistedIOLoop on top of a TornadoReactor.
# TwistedIOLoop.add_callback uses reactor.callFromThread and
# should not pick up additional StackContexts along the way.
self._io_loop.add_callback(f, *args, **kw)
# We don't need the waker code from the super class, Tornado uses
# its own waker.
def installWaker(self):
pass
def wakeUp(self):
pass
# IReactorFDSet
def _invoke_callback(self, fd, events):
if fd not in self._fds:
return
(reader, writer) = self._fds[fd]
if reader:
err = None
if reader.fileno() == -1:
err = error.ConnectionLost()
elif events & IOLoop.READ:
err = log.callWithLogger(reader, reader.doRead)
if err is None and events & IOLoop.ERROR:
err = error.ConnectionLost()
if err is not None:
self.removeReader(reader)
reader.readConnectionLost(failure.Failure(err))
if writer:
err = None
if writer.fileno() == -1:
err = error.ConnectionLost()
elif events & IOLoop.WRITE:
err = log.callWithLogger(writer, writer.doWrite)
if err is None and events & IOLoop.ERROR:
err = error.ConnectionLost()
if err is not None:
self.removeWriter(writer)
writer.writeConnectionLost(failure.Failure(err))
def addReader(self, reader):
if reader in self._readers:
# Don't add the reader if it's already there
return
fd = reader.fileno()
self._readers[reader] = fd
if fd in self._fds:
(_, writer) = self._fds[fd]
self._fds[fd] = (reader, writer)
if writer:
# We already registered this fd for write events,
# update it for read events as well.
self._io_loop.update_handler(fd, IOLoop.READ | IOLoop.WRITE)
else:
with NullContext():
self._fds[fd] = (reader, None)
self._io_loop.add_handler(fd, self._invoke_callback,
IOLoop.READ)
def addWriter(self, writer):
if writer in self._writers:
return
fd = writer.fileno()
self._writers[writer] = fd
if fd in self._fds:
(reader, _) = self._fds[fd]
self._fds[fd] = (reader, writer)
if reader:
# We already registered this fd for read events,
# update it for write events as well.
self._io_loop.update_handler(fd, IOLoop.READ | IOLoop.WRITE)
else:
with NullContext():
self._fds[fd] = (None, writer)
self._io_loop.add_handler(fd, self._invoke_callback,
IOLoop.WRITE)
def removeReader(self, reader):
if reader in self._readers:
fd = self._readers.pop(reader)
(_, writer) = self._fds[fd]
if writer:
# We have a writer so we need to update the IOLoop for
# write events only.
self._fds[fd] = (None, writer)
self._io_loop.update_handler(fd, IOLoop.WRITE)
else:
# Since we have no writer registered, we remove the
# entry from _fds and unregister the handler from the
# IOLoop
del self._fds[fd]
self._io_loop.remove_handler(fd)
def removeWriter(self, writer):
if writer in self._writers:
fd = self._writers.pop(writer)
(reader, _) = self._fds[fd]
if reader:
# We have a reader so we need to update the IOLoop for
# read events only.
self._fds[fd] = (reader, None)
self._io_loop.update_handler(fd, IOLoop.READ)
else:
# Since we have no reader registered, we remove the
# entry from the _fds and unregister the handler from
# the IOLoop.
del self._fds[fd]
self._io_loop.remove_handler(fd)
def removeAll(self):
return self._removeAll(self._readers, self._writers)
def getReaders(self):
return self._readers.keys()
def getWriters(self):
return self._writers.keys()
# The following functions are mainly used in twisted-style test cases;
# it is expected that most users of the TornadoReactor will call
# IOLoop.start() instead of Reactor.run().
def stop(self):
PosixReactorBase.stop(self)
fire_shutdown = functools.partial(self.fireSystemEvent, "shutdown")
self._io_loop.add_callback(fire_shutdown)
def crash(self):
PosixReactorBase.crash(self)
self._io_loop.stop()
def doIteration(self, delay):
raise NotImplementedError("doIteration")
def mainLoop(self):
# Since this class is intended to be used in applications
# where the top-level event loop is ``io_loop.start()`` rather
# than ``reactor.run()``, it is implemented a little
# differently than other Twisted reactors. We override
# ``mainLoop`` instead of ``doIteration`` and must implement
# timed call functionality on top of `.IOLoop.add_timeout`
# rather than using the implementation in
# ``PosixReactorBase``.
self._io_loop.start()
class _TestReactor(TornadoReactor):
"""Subclass of TornadoReactor for use in unittests.
This can't go in the test.py file because of import-order dependencies
with the Twisted reactor test builder.
"""
def __init__(self):
# always use a new ioloop
IOLoop.clear_current()
IOLoop(make_current=True)
super(_TestReactor, self).__init__()
IOLoop.clear_current()
def listenTCP(self, port, factory, backlog=50, interface=''):
# default to localhost to avoid firewall prompts on the mac
if not interface:
interface = '127.0.0.1'
return super(_TestReactor, self).listenTCP(
port, factory, backlog=backlog, interface=interface)
def listenUDP(self, port, protocol, interface='', maxPacketSize=8192):
if not interface:
interface = '127.0.0.1'
return super(_TestReactor, self).listenUDP(
port, protocol, interface=interface, maxPacketSize=maxPacketSize)
def install():
"""Install this package as the default Twisted reactor.
``install()`` must be called very early in the startup process,
before most other twisted-related imports. Conversely, because it
initializes the `.IOLoop`, it cannot be called before
`.fork_processes` or multi-process `~.TCPServer.start`. These
conflicting requirements make it difficult to use `.TornadoReactor`
in multi-process mode, and an external process manager such as
``supervisord`` is recommended instead.
.. versionchanged:: 5.0
The ``io_loop`` argument (deprecated since version 4.1) has been removed.
.. deprecated:: 5.1
    This function will be removed in Tornado 6.0. Use
``twisted.internet.asyncioreactor.install`` instead.
"""
reactor = TornadoReactor()
from twisted.internet.main import installReactor # type: ignore
installReactor(reactor)
return reactor
@implementer(IReadDescriptor, IWriteDescriptor)
class _FD(object):
def __init__(self, fd, fileobj, handler):
self.fd = fd
self.fileobj = fileobj
self.handler = handler
self.reading = False
self.writing = False
self.lost = False
def fileno(self):
return self.fd
def doRead(self):
if not self.lost:
self.handler(self.fileobj, tornado.ioloop.IOLoop.READ)
def doWrite(self):
if not self.lost:
self.handler(self.fileobj, tornado.ioloop.IOLoop.WRITE)
def connectionLost(self, reason):
if not self.lost:
self.handler(self.fileobj, tornado.ioloop.IOLoop.ERROR)
self.lost = True
writeConnectionLost = readConnectionLost = connectionLost
def logPrefix(self):
return ''
class TwistedIOLoop(tornado.ioloop.IOLoop):
"""IOLoop implementation that runs on Twisted.
`TwistedIOLoop` implements the Tornado IOLoop interface on top of
the Twisted reactor. Recommended usage::
from tornado.platform.twisted import TwistedIOLoop
from twisted.internet import reactor
TwistedIOLoop().install()
# Set up your tornado application as usual using `IOLoop.instance`
reactor.run()
Uses the global Twisted reactor by default. To create multiple
``TwistedIOLoops`` in the same process, you must pass a unique reactor
when constructing each one.
Not compatible with `tornado.process.Subprocess.set_exit_callback`
because the ``SIGCHLD`` handlers used by Tornado and Twisted conflict
with each other.
See also :meth:`tornado.ioloop.IOLoop.install` for general notes on
installing alternative IOLoops.
.. deprecated:: 5.1
The `asyncio` event loop will be the only available implementation in
Tornado 6.0.
"""
def initialize(self, reactor=None, **kwargs):
super(TwistedIOLoop, self).initialize(**kwargs)
if reactor is None:
import twisted.internet.reactor # type: ignore
reactor = twisted.internet.reactor
self.reactor = reactor
self.fds = {}
def close(self, all_fds=False):
fds = self.fds
self.reactor.removeAll()
for c in self.reactor.getDelayedCalls():
c.cancel()
if all_fds:
for fd in fds.values():
self.close_fd(fd.fileobj)
def add_handler(self, fd, handler, events):
if fd in self.fds:
raise ValueError('fd %s added twice' % fd)
fd, fileobj = self.split_fd(fd)
self.fds[fd] = _FD(fd, fileobj, wrap(handler))
if events & tornado.ioloop.IOLoop.READ:
self.fds[fd].reading = True
self.reactor.addReader(self.fds[fd])
if events & tornado.ioloop.IOLoop.WRITE:
self.fds[fd].writing = True
self.reactor.addWriter(self.fds[fd])
def update_handler(self, fd, events):
fd, fileobj = self.split_fd(fd)
if events & tornado.ioloop.IOLoop.READ:
if not self.fds[fd].reading:
self.fds[fd].reading = True
self.reactor.addReader(self.fds[fd])
else:
if self.fds[fd].reading:
self.fds[fd].reading = False
self.reactor.removeReader(self.fds[fd])
if events & tornado.ioloop.IOLoop.WRITE:
if not self.fds[fd].writing:
self.fds[fd].writing = True
self.reactor.addWriter(self.fds[fd])
else:
if self.fds[fd].writing:
self.fds[fd].writing = False
self.reactor.removeWriter(self.fds[fd])
def remove_handler(self, fd):
fd, fileobj = self.split_fd(fd)
if fd not in self.fds:
return
self.fds[fd].lost = True
if self.fds[fd].reading:
self.reactor.removeReader(self.fds[fd])
if self.fds[fd].writing:
self.reactor.removeWriter(self.fds[fd])
del self.fds[fd]
def start(self):
old_current = IOLoop.current(instance=False)
try:
self._setup_logging()
self.make_current()
self.reactor.run()
finally:
if old_current is None:
IOLoop.clear_current()
else:
old_current.make_current()
def stop(self):
self.reactor.crash()
def add_timeout(self, deadline, callback, *args, **kwargs):
# This method could be simplified (since tornado 4.0) by
# overriding call_at instead of add_timeout, but we leave it
# for now as a test of backwards-compatibility.
if isinstance(deadline, numbers.Real):
delay = max(deadline - self.time(), 0)
elif isinstance(deadline, datetime.timedelta):
delay = timedelta_to_seconds(deadline)
else:
            raise TypeError("Unsupported deadline %r" % deadline)
return self.reactor.callLater(
delay, self._run_callback,
functools.partial(wrap(callback), *args, **kwargs))
def remove_timeout(self, timeout):
if timeout.active():
timeout.cancel()
def add_callback(self, callback, *args, **kwargs):
self.reactor.callFromThread(
self._run_callback,
functools.partial(wrap(callback), *args, **kwargs))
def add_callback_from_signal(self, callback, *args, **kwargs):
self.add_callback(callback, *args, **kwargs)
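# Example (editor's sketch): running a Tornado app on the Twisted reactor,
# as described in the TwistedIOLoop docstring above. ``make_app`` is a
# hypothetical factory returning a tornado.web.Application::
#
#     from tornado.platform.twisted import TwistedIOLoop
#     from twisted.internet import reactor
#     TwistedIOLoop().install()
#     make_app().listen(8888)
#     reactor.run()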
class TwistedResolver(Resolver):
"""Twisted-based asynchronous resolver.
This is a non-blocking and non-threaded resolver. It is
recommended only when threads cannot be used, since it has
limitations compared to the standard ``getaddrinfo``-based
`~tornado.netutil.Resolver` and
`~tornado.netutil.DefaultExecutorResolver`. Specifically, it returns at
most one result, and arguments other than ``host`` and ``family``
are ignored. It may fail to resolve when ``family`` is not
``socket.AF_UNSPEC``.
Requires Twisted 12.1 or newer.
.. versionchanged:: 5.0
The ``io_loop`` argument (deprecated since version 4.1) has been removed.
"""
def initialize(self):
# partial copy of twisted.names.client.createResolver, which doesn't
# allow for a reactor to be passed in.
self.reactor = tornado.platform.twisted.TornadoReactor()
host_resolver = twisted.names.hosts.Resolver('/etc/hosts')
cache_resolver = twisted.names.cache.CacheResolver(reactor=self.reactor)
real_resolver = twisted.names.client.Resolver('/etc/resolv.conf',
reactor=self.reactor)
self.resolver = twisted.names.resolve.ResolverChain(
[host_resolver, cache_resolver, real_resolver])
@gen.coroutine
def resolve(self, host, port, family=0):
# getHostByName doesn't accept IP addresses, so if the input
# looks like an IP address just return it immediately.
if twisted.internet.abstract.isIPAddress(host):
resolved = host
resolved_family = socket.AF_INET
elif twisted.internet.abstract.isIPv6Address(host):
resolved = host
resolved_family = socket.AF_INET6
else:
deferred = self.resolver.getHostByName(utf8(host))
fut = Future()
deferred.addBoth(fut.set_result)
resolved = yield fut
if isinstance(resolved, failure.Failure):
try:
resolved.raiseException()
except twisted.names.error.DomainError as e:
raise IOError(e)
elif twisted.internet.abstract.isIPAddress(resolved):
resolved_family = socket.AF_INET
elif twisted.internet.abstract.isIPv6Address(resolved):
resolved_family = socket.AF_INET6
else:
resolved_family = socket.AF_UNSPEC
if family != socket.AF_UNSPEC and family != resolved_family:
raise Exception('Requested socket family %d but got %d' %
(family, resolved_family))
result = [
(resolved_family, (resolved, port)),
]
raise gen.Return(result)
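# Example (editor's sketch): TwistedResolver is normally selected through
# the Resolver configurable interface rather than instantiated directly::
#
#     from tornado.netutil import Resolver
#     Resolver.configure('tornado.platform.twisted.TwistedResolver')
#     resolver = Resolver()  # subsequent instances are TwistedResolvers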
if hasattr(gen.convert_yielded, 'register'):
@gen.convert_yielded.register(Deferred) # type: ignore
def _(d):
f = Future()
def errback(failure):
try:
failure.raiseException()
# Should never happen, but just in case
raise Exception("errback called without error")
except:
future_set_exc_info(f, sys.exc_info())
d.addCallbacks(f.set_result, errback)
return f
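# Example (editor's sketch): with the registration above, a Twisted
# Deferred can be yielded directly inside a Tornado coroutine.
# ``some_twisted_call`` is a hypothetical function returning a Deferred::
#
#     @gen.coroutine
#     def fetch():
#         result = yield some_twisted_call()
#         raise gen.Return(result)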

lib/tornado/platform/windows.py

@@ -0,0 +1,20 @@
# NOTE: win32 support is currently experimental, and not recommended
# for production use.
from __future__ import absolute_import, division, print_function
import ctypes # type: ignore
import ctypes.wintypes # type: ignore
# See: http://msdn.microsoft.com/en-us/library/ms724935(VS.85).aspx
SetHandleInformation = ctypes.windll.kernel32.SetHandleInformation
SetHandleInformation.argtypes = (ctypes.wintypes.HANDLE, ctypes.wintypes.DWORD, ctypes.wintypes.DWORD) # noqa: E501
SetHandleInformation.restype = ctypes.wintypes.BOOL
HANDLE_FLAG_INHERIT = 0x00000001
def set_close_exec(fd):
success = SetHandleInformation(fd, HANDLE_FLAG_INHERIT, 0)
if not success:
raise ctypes.WinError()

lib/tornado/process.py

@@ -0,0 +1,361 @@
#
# Copyright 2011 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""Utilities for working with multiple processes, including both forking
the server into multiple processes and managing subprocesses.
"""
from __future__ import absolute_import, division, print_function
import errno
import os
import signal
import subprocess
import sys
import time
from binascii import hexlify
from tornado.concurrent import Future, future_set_result_unless_cancelled
from tornado import ioloop
from tornado.iostream import PipeIOStream
from tornado.log import gen_log
from tornado.platform.auto import set_close_exec
from tornado import stack_context
from tornado.util import errno_from_exception, PY3
try:
import multiprocessing
except ImportError:
# Multiprocessing is not available on Google App Engine.
multiprocessing = None
if PY3:
long = int
# Re-export this exception for convenience.
try:
CalledProcessError = subprocess.CalledProcessError
except AttributeError:
# The subprocess module exists in Google App Engine, but is empty.
# This module isn't very useful in that case, but it should
# at least be importable.
if 'APPENGINE_RUNTIME' not in os.environ:
raise
def cpu_count():
"""Returns the number of processors on this machine."""
if multiprocessing is None:
return 1
try:
return multiprocessing.cpu_count()
except NotImplementedError:
pass
try:
return os.sysconf("SC_NPROCESSORS_CONF")
except (AttributeError, ValueError):
pass
gen_log.error("Could not detect number of processors; assuming 1")
return 1
def _reseed_random():
if 'random' not in sys.modules:
return
import random
# If os.urandom is available, this method does the same thing as
# random.seed (at least as of python 2.6). If os.urandom is not
# available, we mix in the pid in addition to a timestamp.
try:
seed = long(hexlify(os.urandom(16)), 16)
except NotImplementedError:
seed = int(time.time() * 1000) ^ os.getpid()
random.seed(seed)
def _pipe_cloexec():
r, w = os.pipe()
set_close_exec(r)
set_close_exec(w)
return r, w
_task_id = None
def fork_processes(num_processes, max_restarts=100):
"""Starts multiple worker processes.
If ``num_processes`` is None or <= 0, we detect the number of cores
available on this machine and fork that number of child
processes. If ``num_processes`` is given and > 0, we fork that
specific number of sub-processes.
Since we use processes and not threads, there is no shared memory
between any server code.
Note that multiple processes are not compatible with the autoreload
module (or the ``autoreload=True`` option to `tornado.web.Application`
which defaults to True when ``debug=True``).
When using multiple processes, no IOLoops can be created or
referenced until after the call to ``fork_processes``.
In each child process, ``fork_processes`` returns its *task id*, a
number between 0 and ``num_processes``. Processes that exit
abnormally (due to a signal or non-zero exit status) are restarted
with the same id (up to ``max_restarts`` times). In the parent
process, ``fork_processes`` returns None if all child processes
have exited normally, but will otherwise only exit by throwing an
exception.
"""
global _task_id
assert _task_id is None
if num_processes is None or num_processes <= 0:
num_processes = cpu_count()
gen_log.info("Starting %d processes", num_processes)
children = {}
def start_child(i):
pid = os.fork()
if pid == 0:
# child process
_reseed_random()
global _task_id
_task_id = i
return i
else:
children[pid] = i
return None
for i in range(num_processes):
id = start_child(i)
if id is not None:
return id
num_restarts = 0
while children:
try:
pid, status = os.wait()
except OSError as e:
if errno_from_exception(e) == errno.EINTR:
continue
raise
if pid not in children:
continue
id = children.pop(pid)
if os.WIFSIGNALED(status):
gen_log.warning("child %d (pid %d) killed by signal %d, restarting",
id, pid, os.WTERMSIG(status))
elif os.WEXITSTATUS(status) != 0:
gen_log.warning("child %d (pid %d) exited with status %d, restarting",
id, pid, os.WEXITSTATUS(status))
else:
gen_log.info("child %d (pid %d) exited normally", id, pid)
continue
num_restarts += 1
if num_restarts > max_restarts:
raise RuntimeError("Too many child restarts, giving up")
new_id = start_child(id)
if new_id is not None:
return new_id
# All child processes exited cleanly, so exit the master process
# instead of just returning to right after the call to
# fork_processes (which will probably just start up another IOLoop
# unless the caller checks the return value).
sys.exit(0)
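# Example (editor's sketch): the usual pre-fork server pattern built on
# fork_processes; ``make_app`` is a hypothetical application factory::
#
#     import tornado.httpserver
#     import tornado.netutil
#     import tornado.process
#
#     sockets = tornado.netutil.bind_sockets(8888)
#     tornado.process.fork_processes(0)  # 0 -> one child per CPU core
#     server = tornado.httpserver.HTTPServer(make_app())
#     server.add_sockets(sockets)
#     ioloop.IOLoop.current().start()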
def task_id():
"""Returns the current task id, if any.
Returns None if this process was not created by `fork_processes`.
"""
global _task_id
return _task_id
class Subprocess(object):
"""Wraps ``subprocess.Popen`` with IOStream support.
The constructor is the same as ``subprocess.Popen`` with the following
additions:
* ``stdin``, ``stdout``, and ``stderr`` may have the value
``tornado.process.Subprocess.STREAM``, which will make the corresponding
attribute of the resulting Subprocess a `.PipeIOStream`. If this option
is used, the caller is responsible for closing the streams when done
with them.
The ``Subprocess.STREAM`` option and the ``set_exit_callback`` and
``wait_for_exit`` methods do not work on Windows. There is
therefore no reason to use this class instead of
``subprocess.Popen`` on that platform.
.. versionchanged:: 5.0
The ``io_loop`` argument (deprecated since version 4.1) has been removed.
"""
STREAM = object()
_initialized = False
_waiting = {} # type: ignore
def __init__(self, *args, **kwargs):
self.io_loop = ioloop.IOLoop.current()
# All FDs we create should be closed on error; those in to_close
# should be closed in the parent process on success.
pipe_fds = []
to_close = []
if kwargs.get('stdin') is Subprocess.STREAM:
in_r, in_w = _pipe_cloexec()
kwargs['stdin'] = in_r
pipe_fds.extend((in_r, in_w))
to_close.append(in_r)
self.stdin = PipeIOStream(in_w)
if kwargs.get('stdout') is Subprocess.STREAM:
out_r, out_w = _pipe_cloexec()
kwargs['stdout'] = out_w
pipe_fds.extend((out_r, out_w))
to_close.append(out_w)
self.stdout = PipeIOStream(out_r)
if kwargs.get('stderr') is Subprocess.STREAM:
err_r, err_w = _pipe_cloexec()
kwargs['stderr'] = err_w
pipe_fds.extend((err_r, err_w))
to_close.append(err_w)
self.stderr = PipeIOStream(err_r)
try:
self.proc = subprocess.Popen(*args, **kwargs)
except:
for fd in pipe_fds:
os.close(fd)
raise
for fd in to_close:
os.close(fd)
for attr in ['stdin', 'stdout', 'stderr', 'pid']:
if not hasattr(self, attr): # don't clobber streams set above
setattr(self, attr, getattr(self.proc, attr))
self._exit_callback = None
self.returncode = None
def set_exit_callback(self, callback):
"""Runs ``callback`` when this process exits.
The callback takes one argument, the return code of the process.
This method uses a ``SIGCHLD`` handler, which is a global setting
and may conflict if you have other libraries trying to handle the
same signal. If you are using more than one ``IOLoop`` it may
be necessary to call `Subprocess.initialize` first to designate
one ``IOLoop`` to run the signal handlers.
In many cases a close callback on the stdout or stderr streams
can be used as an alternative to an exit callback if the
signal handler is causing a problem.
"""
self._exit_callback = stack_context.wrap(callback)
Subprocess.initialize()
Subprocess._waiting[self.pid] = self
Subprocess._try_cleanup_process(self.pid)
def wait_for_exit(self, raise_error=True):
"""Returns a `.Future` which resolves when the process exits.
Usage::
ret = yield proc.wait_for_exit()
This is a coroutine-friendly alternative to `set_exit_callback`
(and a replacement for the blocking `subprocess.Popen.wait`).
By default, raises `subprocess.CalledProcessError` if the process
has a non-zero exit status. Use ``wait_for_exit(raise_error=False)``
to suppress this behavior and return the exit status without raising.
.. versionadded:: 4.2
"""
future = Future()
def callback(ret):
if ret != 0 and raise_error:
# Unfortunately we don't have the original args any more.
future.set_exception(CalledProcessError(ret, None))
else:
future_set_result_unless_cancelled(future, ret)
self.set_exit_callback(callback)
return future
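    # Example (editor's sketch): capturing a child's stdout with STREAM
    # and waiting for it to exit inside a coroutine::
    #
    #     @gen.coroutine
    #     def run():
    #         proc = Subprocess(["echo", "hello"],
    #                           stdout=Subprocess.STREAM)
    #         out = yield proc.stdout.read_until_close()
    #         yield proc.wait_for_exit()
    #         raise gen.Return(out)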
@classmethod
def initialize(cls):
"""Initializes the ``SIGCHLD`` handler.
The signal handler is run on an `.IOLoop` to avoid locking issues.
Note that the `.IOLoop` used for signal handling need not be the
same one used by individual Subprocess objects (as long as the
``IOLoops`` are each running in separate threads).
.. versionchanged:: 5.0
The ``io_loop`` argument (deprecated since version 4.1) has been
removed.
"""
if cls._initialized:
return
io_loop = ioloop.IOLoop.current()
cls._old_sigchld = signal.signal(
signal.SIGCHLD,
lambda sig, frame: io_loop.add_callback_from_signal(cls._cleanup))
cls._initialized = True
@classmethod
def uninitialize(cls):
"""Removes the ``SIGCHLD`` handler."""
if not cls._initialized:
return
signal.signal(signal.SIGCHLD, cls._old_sigchld)
cls._initialized = False
@classmethod
def _cleanup(cls):
for pid in list(cls._waiting.keys()): # make a copy
cls._try_cleanup_process(pid)
@classmethod
def _try_cleanup_process(cls, pid):
try:
ret_pid, status = os.waitpid(pid, os.WNOHANG)
        except OSError as e:
            if errno_from_exception(e) == errno.ECHILD:
                return
            # Re-raise unexpected errors; falling through would leave
            # ret_pid unbound below.
            raise
if ret_pid == 0:
return
assert ret_pid == pid
subproc = cls._waiting.pop(pid)
subproc.io_loop.add_callback_from_signal(
subproc._set_returncode, status)
def _set_returncode(self, status):
if os.WIFSIGNALED(status):
self.returncode = -os.WTERMSIG(status)
else:
assert os.WIFEXITED(status)
self.returncode = os.WEXITSTATUS(status)
# We've taken over wait() duty from the subprocess.Popen
# object. If we don't inform it of the process's return code,
# it will log a warning at destruction in python 3.6+.
self.proc.returncode = self.returncode
if self._exit_callback:
callback = self._exit_callback
self._exit_callback = None
callback(self.returncode)

lib/tornado/queues.py

@@ -0,0 +1,379 @@
# Copyright 2015 The Tornado Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""Asynchronous queues for coroutines. These classes are very similar
to those provided in the standard library's `asyncio package
<https://docs.python.org/3/library/asyncio-queue.html>`_.
.. warning::
Unlike the standard library's `queue` module, the classes defined here
are *not* thread-safe. To use these queues from another thread,
use `.IOLoop.add_callback` to transfer control to the `.IOLoop` thread
before calling any queue methods.
"""
from __future__ import absolute_import, division, print_function
import collections
import heapq
from tornado import gen, ioloop
from tornado.concurrent import Future, future_set_result_unless_cancelled
from tornado.locks import Event
__all__ = ['Queue', 'PriorityQueue', 'LifoQueue', 'QueueFull', 'QueueEmpty']
class QueueEmpty(Exception):
"""Raised by `.Queue.get_nowait` when the queue has no items."""
pass
class QueueFull(Exception):
"""Raised by `.Queue.put_nowait` when a queue is at its maximum size."""
pass
def _set_timeout(future, timeout):
if timeout:
def on_timeout():
if not future.done():
future.set_exception(gen.TimeoutError())
io_loop = ioloop.IOLoop.current()
timeout_handle = io_loop.add_timeout(timeout, on_timeout)
future.add_done_callback(
lambda _: io_loop.remove_timeout(timeout_handle))
class _QueueIterator(object):
def __init__(self, q):
self.q = q
def __anext__(self):
return self.q.get()
class Queue(object):
"""Coordinate producer and consumer coroutines.
If maxsize is 0 (the default) the queue size is unbounded.
.. testcode::
from tornado import gen
from tornado.ioloop import IOLoop
from tornado.queues import Queue
q = Queue(maxsize=2)
async def consumer():
async for item in q:
try:
print('Doing work on %s' % item)
await gen.sleep(0.01)
finally:
q.task_done()
async def producer():
for item in range(5):
await q.put(item)
print('Put %s' % item)
async def main():
# Start consumer without waiting (since it never finishes).
IOLoop.current().spawn_callback(consumer)
await producer() # Wait for producer to put all tasks.
await q.join() # Wait for consumer to finish all tasks.
print('Done')
IOLoop.current().run_sync(main)
.. testoutput::
Put 0
Put 1
Doing work on 0
Put 2
Doing work on 1
Put 3
Doing work on 2
Put 4
Doing work on 3
Doing work on 4
Done
In versions of Python without native coroutines (before 3.5),
``consumer()`` could be written as::
@gen.coroutine
def consumer():
while True:
item = yield q.get()
try:
print('Doing work on %s' % item)
yield gen.sleep(0.01)
finally:
q.task_done()
.. versionchanged:: 4.3
Added ``async for`` support in Python 3.5.
"""
def __init__(self, maxsize=0):
if maxsize is None:
raise TypeError("maxsize can't be None")
if maxsize < 0:
raise ValueError("maxsize can't be negative")
self._maxsize = maxsize
self._init()
self._getters = collections.deque([]) # Futures.
self._putters = collections.deque([]) # Pairs of (item, Future).
self._unfinished_tasks = 0
self._finished = Event()
self._finished.set()
@property
def maxsize(self):
"""Number of items allowed in the queue."""
return self._maxsize
def qsize(self):
"""Number of items in the queue."""
return len(self._queue)
def empty(self):
return not self._queue
def full(self):
if self.maxsize == 0:
return False
else:
return self.qsize() >= self.maxsize
def put(self, item, timeout=None):
"""Put an item into the queue, perhaps waiting until there is room.
Returns a Future, which raises `tornado.util.TimeoutError` after a
timeout.
``timeout`` may be a number denoting a time (on the same
scale as `tornado.ioloop.IOLoop.time`, normally `time.time`), or a
`datetime.timedelta` object for a deadline relative to the
current time.
"""
future = Future()
try:
self.put_nowait(item)
except QueueFull:
self._putters.append((item, future))
_set_timeout(future, timeout)
else:
future.set_result(None)
return future
def put_nowait(self, item):
"""Put an item into the queue without blocking.
If no free slot is immediately available, raise `QueueFull`.
"""
self._consume_expired()
if self._getters:
assert self.empty(), "queue non-empty, why are getters waiting?"
getter = self._getters.popleft()
self.__put_internal(item)
future_set_result_unless_cancelled(getter, self._get())
elif self.full():
raise QueueFull
else:
self.__put_internal(item)
def get(self, timeout=None):
"""Remove and return an item from the queue.
Returns a Future which resolves once an item is available, or raises
`tornado.util.TimeoutError` after a timeout.
``timeout`` may be a number denoting a time (on the same
scale as `tornado.ioloop.IOLoop.time`, normally `time.time`), or a
`datetime.timedelta` object for a deadline relative to the
current time.
"""
future = Future()
try:
future.set_result(self.get_nowait())
except QueueEmpty:
self._getters.append(future)
_set_timeout(future, timeout)
return future
def get_nowait(self):
"""Remove and return an item from the queue without blocking.
Return an item if one is immediately available, else raise
`QueueEmpty`.
"""
self._consume_expired()
if self._putters:
assert self.full(), "queue not full, why are putters waiting?"
item, putter = self._putters.popleft()
self.__put_internal(item)
future_set_result_unless_cancelled(putter, None)
return self._get()
elif self.qsize():
return self._get()
else:
raise QueueEmpty
def task_done(self):
"""Indicate that a formerly enqueued task is complete.
Used by queue consumers. For each `.get` used to fetch a task, a
subsequent call to `.task_done` tells the queue that the processing
on the task is complete.
If a `.join` is blocking, it resumes when all items have been
processed; that is, when every `.put` is matched by a `.task_done`.
Raises `ValueError` if called more times than `.put`.
"""
if self._unfinished_tasks <= 0:
raise ValueError('task_done() called too many times')
self._unfinished_tasks -= 1
if self._unfinished_tasks == 0:
self._finished.set()
def join(self, timeout=None):
"""Block until all items in the queue are processed.
Returns a Future, which raises `tornado.util.TimeoutError` after a
timeout.
"""
return self._finished.wait(timeout)
def __aiter__(self):
return _QueueIterator(self)
# These three are overridable in subclasses.
def _init(self):
self._queue = collections.deque()
def _get(self):
return self._queue.popleft()
def _put(self, item):
self._queue.append(item)
# End of the overridable methods.
def __put_internal(self, item):
self._unfinished_tasks += 1
self._finished.clear()
self._put(item)
def _consume_expired(self):
# Remove timed-out waiters.
while self._putters and self._putters[0][1].done():
self._putters.popleft()
while self._getters and self._getters[0].done():
self._getters.popleft()
def __repr__(self):
return '<%s at %s %s>' % (
type(self).__name__, hex(id(self)), self._format())
def __str__(self):
return '<%s %s>' % (type(self).__name__, self._format())
def _format(self):
result = 'maxsize=%r' % (self.maxsize, )
if getattr(self, '_queue', None):
result += ' queue=%r' % self._queue
if self._getters:
result += ' getters[%s]' % len(self._getters)
if self._putters:
result += ' putters[%s]' % len(self._putters)
if self._unfinished_tasks:
result += ' tasks=%s' % self._unfinished_tasks
return result
class PriorityQueue(Queue):
"""A `.Queue` that retrieves entries in priority order, lowest first.
Entries are typically tuples like ``(priority number, data)``.
.. testcode::
from tornado.queues import PriorityQueue
q = PriorityQueue()
q.put((1, 'medium-priority item'))
q.put((0, 'high-priority item'))
q.put((10, 'low-priority item'))
print(q.get_nowait())
print(q.get_nowait())
print(q.get_nowait())
.. testoutput::
(0, 'high-priority item')
(1, 'medium-priority item')
(10, 'low-priority item')
"""
def _init(self):
self._queue = []
def _put(self, item):
heapq.heappush(self._queue, item)
def _get(self):
return heapq.heappop(self._queue)
class LifoQueue(Queue):
"""A `.Queue` that retrieves the most recently put items first.
.. testcode::
from tornado.queues import LifoQueue
q = LifoQueue()
q.put(3)
q.put(2)
q.put(1)
print(q.get_nowait())
print(q.get_nowait())
print(q.get_nowait())
.. testoutput::
1
2
3
"""
def _init(self):
self._queue = []
def _put(self, item):
self._queue.append(item)
def _get(self):
return self._queue.pop()

lib/tornado/routing.py

@@ -0,0 +1,641 @@
# Copyright 2015 The Tornado Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""Flexible routing implementation.
Tornado routes HTTP requests to appropriate handlers using `Router`
class implementations. The `tornado.web.Application` class is a
`Router` implementation and may be used directly, or the classes in
this module may be used for additional flexibility. The `RuleRouter`
class can match on more criteria than `.Application`, or the `Router`
interface can be subclassed for maximum customization.
The `Router` interface extends `~.httputil.HTTPServerConnectionDelegate`
to provide additional routing capabilities. This also means that any
`Router` implementation can be used directly as a ``request_callback``
for `~.httpserver.HTTPServer` constructor.
A `Router` subclass must implement a ``find_handler`` method to provide
a suitable `~.httputil.HTTPMessageDelegate` instance to handle the
request:
.. code-block:: python
class CustomRouter(Router):
def find_handler(self, request, **kwargs):
# some routing logic providing a suitable HTTPMessageDelegate instance
return MessageDelegate(request.connection)
class MessageDelegate(HTTPMessageDelegate):
def __init__(self, connection):
self.connection = connection
def finish(self):
self.connection.write_headers(
ResponseStartLine("HTTP/1.1", 200, "OK"),
HTTPHeaders({"Content-Length": "2"}),
b"OK")
self.connection.finish()
router = CustomRouter()
server = HTTPServer(router)
The main responsibility of a `Router` implementation is to provide a
mapping from a request to `~.httputil.HTTPMessageDelegate` instance
that will handle this request. In the example above we can see that
routing is possible even without instantiating an `~.web.Application`.
For routing to `~.web.RequestHandler` implementations we need an
`~.web.Application` instance. `~.web.Application.get_handler_delegate`
provides a convenient way to create `~.httputil.HTTPMessageDelegate`
for a given request and `~.web.RequestHandler`.
Here is a simple example of how we can route to
`~.web.RequestHandler` subclasses by HTTP method:
.. code-block:: python
resources = {}
class GetResource(RequestHandler):
def get(self, path):
if path not in resources:
raise HTTPError(404)
self.finish(resources[path])
class PostResource(RequestHandler):
def post(self, path):
resources[path] = self.request.body
class HTTPMethodRouter(Router):
def __init__(self, app):
self.app = app
def find_handler(self, request, **kwargs):
handler = GetResource if request.method == "GET" else PostResource
return self.app.get_handler_delegate(request, handler, path_args=[request.path])
router = HTTPMethodRouter(Application())
server = HTTPServer(router)
`ReversibleRouter` interface adds the ability to distinguish between
the routes and reverse them to the original urls using route's name
and additional arguments. `~.web.Application` is itself an
implementation of `ReversibleRouter` class.
`RuleRouter` and `ReversibleRuleRouter` are implementations of
`Router` and `ReversibleRouter` interfaces and can be used for
creating rule-based routing configurations.
Rules are instances of `Rule` class. They contain a `Matcher`, which
provides the logic for determining whether the rule is a match for a
particular request and a target, which can be one of the following.
1) An instance of `~.httputil.HTTPServerConnectionDelegate`:
.. code-block:: python
router = RuleRouter([
Rule(PathMatches("/handler"), ConnectionDelegate()),
# ... more rules
])
class ConnectionDelegate(HTTPServerConnectionDelegate):
def start_request(self, server_conn, request_conn):
return MessageDelegate(request_conn)
2) A callable accepting a single argument of `~.httputil.HTTPServerRequest` type:
.. code-block:: python
router = RuleRouter([
Rule(PathMatches("/callable"), request_callable)
])
def request_callable(request):
request.write(b"HTTP/1.1 200 OK\\r\\nContent-Length: 2\\r\\n\\r\\nOK")
request.finish()
3) Another `Router` instance:
.. code-block:: python
router = RuleRouter([
Rule(PathMatches("/router.*"), CustomRouter())
])
Of course a nested `RuleRouter` or a `~.web.Application` is allowed:
.. code-block:: python
router = RuleRouter([
Rule(HostMatches("example.com"), RuleRouter([
Rule(PathMatches("/app1/.*"), Application([(r"/app1/handler", Handler)]))),
]))
])
server = HTTPServer(router)
In the example below `RuleRouter` is used to route between applications:
.. code-block:: python
app1 = Application([
(r"/app1/handler", Handler1),
# other handlers ...
])
app2 = Application([
(r"/app2/handler", Handler2),
# other handlers ...
])
router = RuleRouter([
Rule(PathMatches("/app1.*"), app1),
Rule(PathMatches("/app2.*"), app2)
])
server = HTTPServer(router)
For more information on application-level routing see docs for `~.web.Application`.
.. versionadded:: 4.5
"""
from __future__ import absolute_import, division, print_function
import re
from functools import partial
from tornado import httputil
from tornado.httpserver import _CallableAdapter
from tornado.escape import url_escape, url_unescape, utf8
from tornado.log import app_log
from tornado.util import basestring_type, import_object, re_unescape, unicode_type
try:
import typing # noqa
except ImportError:
pass
class Router(httputil.HTTPServerConnectionDelegate):
"""Abstract router interface."""
def find_handler(self, request, **kwargs):
# type: (httputil.HTTPServerRequest, typing.Any)->httputil.HTTPMessageDelegate
"""Must be implemented to return an appropriate instance of `~.httputil.HTTPMessageDelegate`
that can serve the request.
Routing implementations may pass additional kwargs to extend the routing logic.
:arg httputil.HTTPServerRequest request: current HTTP request.
:arg kwargs: additional keyword arguments passed by routing implementation.
:returns: an instance of `~.httputil.HTTPMessageDelegate` that will be used to
process the request.
"""
raise NotImplementedError()
def start_request(self, server_conn, request_conn):
return _RoutingDelegate(self, server_conn, request_conn)
class ReversibleRouter(Router):
"""Abstract router interface for routers that can handle named routes
and support reversing them to original urls.
"""
def reverse_url(self, name, *args):
"""Returns url string for a given route name and arguments
or ``None`` if no match is found.
:arg str name: route name.
:arg args: url parameters.
:returns: parametrized url string for a given route name (or ``None``).
"""
raise NotImplementedError()
class _RoutingDelegate(httputil.HTTPMessageDelegate):
def __init__(self, router, server_conn, request_conn):
self.server_conn = server_conn
self.request_conn = request_conn
self.delegate = None
self.router = router # type: Router
def headers_received(self, start_line, headers):
request = httputil.HTTPServerRequest(
connection=self.request_conn,
server_connection=self.server_conn,
start_line=start_line, headers=headers)
self.delegate = self.router.find_handler(request)
if self.delegate is None:
app_log.debug("Delegate for %s %s request not found",
start_line.method, start_line.path)
self.delegate = _DefaultMessageDelegate(self.request_conn)
return self.delegate.headers_received(start_line, headers)
def data_received(self, chunk):
return self.delegate.data_received(chunk)
def finish(self):
self.delegate.finish()
def on_connection_close(self):
self.delegate.on_connection_close()
class _DefaultMessageDelegate(httputil.HTTPMessageDelegate):
def __init__(self, connection):
self.connection = connection
def finish(self):
self.connection.write_headers(
httputil.ResponseStartLine("HTTP/1.1", 404, "Not Found"), httputil.HTTPHeaders())
self.connection.finish()
class RuleRouter(Router):
"""Rule-based router implementation."""
def __init__(self, rules=None):
"""Constructs a router from an ordered list of rules::
RuleRouter([
Rule(PathMatches("/handler"), Target),
# ... more rules
])
You can also omit explicit `Rule` constructor and use tuples of arguments::
RuleRouter([
(PathMatches("/handler"), Target),
])
`PathMatches` is a default matcher, so the example above can be simplified::
RuleRouter([
("/handler", Target),
])
In the examples above, ``Target`` can be a nested `Router` instance, an instance of
`~.httputil.HTTPServerConnectionDelegate` or an old-style callable,
accepting a request argument.
:arg rules: a list of `Rule` instances or tuples of `Rule`
constructor arguments.
"""
self.rules = [] # type: typing.List[Rule]
if rules:
self.add_rules(rules)
def add_rules(self, rules):
"""Appends new rules to the router.
:arg rules: a list of Rule instances (or tuples of arguments, which are
passed to Rule constructor).
"""
for rule in rules:
if isinstance(rule, (tuple, list)):
assert len(rule) in (2, 3, 4)
if isinstance(rule[0], basestring_type):
rule = Rule(PathMatches(rule[0]), *rule[1:])
else:
rule = Rule(*rule)
self.rules.append(self.process_rule(rule))
def process_rule(self, rule):
"""Override this method for additional preprocessing of each rule.
:arg Rule rule: a rule to be processed.
:returns: the same or modified Rule instance.
"""
return rule
def find_handler(self, request, **kwargs):
for rule in self.rules:
target_params = rule.matcher.match(request)
if target_params is not None:
if rule.target_kwargs:
target_params['target_kwargs'] = rule.target_kwargs
delegate = self.get_target_delegate(
rule.target, request, **target_params)
if delegate is not None:
return delegate
return None
def get_target_delegate(self, target, request, **target_params):
"""Returns an instance of `~.httputil.HTTPMessageDelegate` for a
Rule's target. This method is called by `~.find_handler` and can be
extended to provide additional target types.
:arg target: a Rule's target.
:arg httputil.HTTPServerRequest request: current request.
:arg target_params: additional parameters that can be useful
for `~.httputil.HTTPMessageDelegate` creation.
"""
if isinstance(target, Router):
return target.find_handler(request, **target_params)
elif isinstance(target, httputil.HTTPServerConnectionDelegate):
return target.start_request(request.server_connection, request.connection)
elif callable(target):
return _CallableAdapter(
partial(target, **target_params), request.connection
)
return None
class ReversibleRuleRouter(ReversibleRouter, RuleRouter):
"""A rule-based router that implements ``reverse_url`` method.
Each rule added to this router may have a ``name`` attribute that can be
used to reconstruct an original uri. The actual reconstruction takes place
in a rule's matcher (see `Matcher.reverse`).
"""
def __init__(self, rules=None):
        self.named_rules = {}  # type: typing.Dict[str, Rule]
super(ReversibleRuleRouter, self).__init__(rules)
def process_rule(self, rule):
rule = super(ReversibleRuleRouter, self).process_rule(rule)
if rule.name:
if rule.name in self.named_rules:
app_log.warning(
"Multiple handlers named %s; replacing previous value",
rule.name)
self.named_rules[rule.name] = rule
return rule
def reverse_url(self, name, *args):
if name in self.named_rules:
return self.named_rules[name].matcher.reverse(*args)
for rule in self.rules:
if isinstance(rule.target, ReversibleRouter):
reversed_url = rule.target.reverse_url(name, *args)
if reversed_url is not None:
return reversed_url
return None
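# Example (editor's sketch): named rules are what make ``reverse_url``
# work; ``PostHandler`` is a hypothetical rule target::
#
#     router = ReversibleRuleRouter([
#         Rule(PathMatches(r"/posts/([0-9]+)"), PostHandler, name="post"),
#     ])
#     router.reverse_url("post", 42)  # -> "/posts/42"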
class Rule(object):
"""A routing rule."""
def __init__(self, matcher, target, target_kwargs=None, name=None):
"""Constructs a Rule instance.
:arg Matcher matcher: a `Matcher` instance used for determining
whether the rule should be considered a match for a specific
request.
:arg target: a Rule's target (typically a ``RequestHandler`` or
`~.httputil.HTTPServerConnectionDelegate` subclass or even a nested `Router`,
depending on routing implementation).
:arg dict target_kwargs: a dict of parameters that can be useful
at the moment of target instantiation (for example, ``status_code``
for a ``RequestHandler`` subclass). They end up in
``target_params['target_kwargs']`` of `RuleRouter.get_target_delegate`
method.
:arg str name: the name of the rule that can be used to find it
in `ReversibleRouter.reverse_url` implementation.
"""
if isinstance(target, str):
# import the Module and instantiate the class
# Must be a fully qualified name (module.ClassName)
target = import_object(target)
self.matcher = matcher # type: Matcher
self.target = target
self.target_kwargs = target_kwargs if target_kwargs else {}
self.name = name
def reverse(self, *args):
return self.matcher.reverse(*args)
def __repr__(self):
return '%s(%r, %s, kwargs=%r, name=%r)' % \
(self.__class__.__name__, self.matcher,
self.target, self.target_kwargs, self.name)
class Matcher(object):
"""Represents a matcher for request features."""
def match(self, request):
"""Matches current instance against the request.
:arg httputil.HTTPServerRequest request: current HTTP request
:returns: a dict of parameters to be passed to the target handler
(for example, ``handler_kwargs``, ``path_args``, ``path_kwargs``
can be passed for proper `~.web.RequestHandler` instantiation).
An empty dict is a valid (and common) return value to indicate a match
when the argument-passing features are not used.
``None`` must be returned to indicate that there is no match."""
raise NotImplementedError()
def reverse(self, *args):
"""Reconstructs full url from matcher instance and additional arguments."""
return None
class AnyMatches(Matcher):
"""Matches any request."""
def match(self, request):
return {}
class HostMatches(Matcher):
"""Matches requests from hosts specified by ``host_pattern`` regex."""
def __init__(self, host_pattern):
if isinstance(host_pattern, basestring_type):
if not host_pattern.endswith("$"):
host_pattern += "$"
self.host_pattern = re.compile(host_pattern)
else:
self.host_pattern = host_pattern
def match(self, request):
if self.host_pattern.match(request.host_name):
return {}
return None
class DefaultHostMatches(Matcher):
"""Matches requests from host that is equal to application's default_host.
Always returns no match if ``X-Real-Ip`` header is present.
"""
def __init__(self, application, host_pattern):
self.application = application
self.host_pattern = host_pattern
def match(self, request):
# Look for default host if not behind load balancer (for debugging)
if "X-Real-Ip" not in request.headers:
if self.host_pattern.match(self.application.default_host):
return {}
return None
class PathMatches(Matcher):
"""Matches requests with paths specified by ``path_pattern`` regex."""
def __init__(self, path_pattern):
if isinstance(path_pattern, basestring_type):
if not path_pattern.endswith('$'):
path_pattern += '$'
self.regex = re.compile(path_pattern)
else:
self.regex = path_pattern
assert len(self.regex.groupindex) in (0, self.regex.groups), \
("groups in url regexes must either be all named or all "
"positional: %r" % self.regex.pattern)
self._path, self._group_count = self._find_groups()
def match(self, request):
match = self.regex.match(request.path)
if match is None:
return None
if not self.regex.groups:
return {}
path_args, path_kwargs = [], {}
# Pass matched groups to the handler. Since
# match.groups() includes both named and
# unnamed groups, we want to use either groups
# or groupdict but not both.
if self.regex.groupindex:
path_kwargs = dict(
(str(k), _unquote_or_none(v))
for (k, v) in match.groupdict().items())
else:
path_args = [_unquote_or_none(s) for s in match.groups()]
return dict(path_args=path_args, path_kwargs=path_kwargs)
def reverse(self, *args):
if self._path is None:
raise ValueError("Cannot reverse url regex " + self.regex.pattern)
assert len(args) == self._group_count, "required number of arguments " \
"not found"
if not len(args):
return self._path
converted_args = []
for a in args:
if not isinstance(a, (unicode_type, bytes)):
a = str(a)
converted_args.append(url_escape(utf8(a), plus=False))
return self._path % tuple(converted_args)
def _find_groups(self):
"""Returns a tuple (reverse string, group count) for a url.
For example: Given the url pattern /([0-9]{4})/([a-z-]+)/, this method
would return ('/%s/%s/', 2).
"""
pattern = self.regex.pattern
if pattern.startswith('^'):
pattern = pattern[1:]
if pattern.endswith('$'):
pattern = pattern[:-1]
if self.regex.groups != pattern.count('('):
# The pattern is too complicated for our simplistic matching,
# so we can't support reversing it.
return None, None
pieces = []
for fragment in pattern.split('('):
if ')' in fragment:
paren_loc = fragment.index(')')
if paren_loc >= 0:
pieces.append('%s' + fragment[paren_loc + 1:])
else:
try:
unescaped_fragment = re_unescape(fragment)
except ValueError:
# If we can't unescape part of it, we can't
# reverse this url.
return (None, None)
pieces.append(unescaped_fragment)
return ''.join(pieces), self.regex.groups
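# Example (editor's sketch): reversing a PathMatches pattern, per the
# _find_groups docstring above::
#
#     m = PathMatches(r"/([0-9]{4})/([a-z-]+)/")
#     m._find_groups()         # -> ('/%s/%s/', 2)
#     m.reverse(2018, "docs")  # -> "/2018/docs/"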
class URLSpec(Rule):
"""Specifies mappings between URLs and handlers.
    .. versionchanged:: 4.5
`URLSpec` is now a subclass of a `Rule` with `PathMatches` matcher and is preserved for
backwards compatibility.
"""
def __init__(self, pattern, handler, kwargs=None, name=None):
"""Parameters:
* ``pattern``: Regular expression to be matched. Any capturing
groups in the regex will be passed in to the handler's
get/post/etc methods as arguments (by keyword if named, by
position if unnamed. Named and unnamed capturing groups
may not be mixed in the same rule).
* ``handler``: `~.web.RequestHandler` subclass to be invoked.
* ``kwargs`` (optional): A dictionary of additional arguments
to be passed to the handler's constructor.
* ``name`` (optional): A name for this handler. Used by
`~.web.Application.reverse_url`.
"""
super(URLSpec, self).__init__(PathMatches(pattern), handler, kwargs, name)
self.regex = self.matcher.regex
self.handler_class = self.target
self.kwargs = kwargs
def __repr__(self):
return '%s(%r, %s, kwargs=%r, name=%r)' % \
(self.__class__.__name__, self.regex.pattern,
self.handler_class, self.kwargs, self.name)
def _unquote_or_none(s):
"""None-safe wrapper around url_unescape to handle unmatched optional
groups correctly.
Note that args are passed as bytes so the handler can decide what
encoding to use.
"""
if s is None:
return s
return url_unescape(s, encoding=None, plus=False)

lib/tornado/simple_httpclient.py

@@ -0,0 +1,566 @@
from __future__ import absolute_import, division, print_function
from tornado.escape import _unicode
from tornado import gen
from tornado.httpclient import HTTPResponse, HTTPError, AsyncHTTPClient, main, _RequestProxy
from tornado import httputil
from tornado.http1connection import HTTP1Connection, HTTP1ConnectionParameters
from tornado.ioloop import IOLoop
from tornado.iostream import StreamClosedError
from tornado.netutil import Resolver, OverrideResolver, _client_ssl_defaults
from tornado.log import gen_log
from tornado import stack_context
from tornado.tcpclient import TCPClient
from tornado.util import PY3
import base64
import collections
import copy
import functools
import re
import socket
import sys
import time
from io import BytesIO
if PY3:
import urllib.parse as urlparse
else:
import urlparse
try:
import ssl
except ImportError:
# ssl is not available on Google App Engine.
ssl = None
class HTTPTimeoutError(HTTPError):
"""Error raised by SimpleAsyncHTTPClient on timeout.
For historical reasons, this is a subclass of `.HTTPClientError`
which simulates a response code of 599.
.. versionadded:: 5.1
"""
def __init__(self, message):
super(HTTPTimeoutError, self).__init__(599, message=message)
def __str__(self):
return self.message
class HTTPStreamClosedError(HTTPError):
"""Error raised by SimpleAsyncHTTPClient when the underlying stream is closed.
When a more specific exception is available (such as `ConnectionResetError`),
it may be raised instead of this one.
For historical reasons, this is a subclass of `.HTTPClientError`
which simulates a response code of 599.
.. versionadded:: 5.1
"""
def __init__(self, message):
super(HTTPStreamClosedError, self).__init__(599, message=message)
def __str__(self):
return self.message
class SimpleAsyncHTTPClient(AsyncHTTPClient):
"""Non-blocking HTTP client with no external dependencies.
This class implements an HTTP 1.1 client on top of Tornado's IOStreams.
Some features found in the curl-based AsyncHTTPClient are not yet
supported. In particular, proxies are not supported, connections
are not reused, and callers cannot select the network interface to be
used.
"""
def initialize(self, max_clients=10,
hostname_mapping=None, max_buffer_size=104857600,
resolver=None, defaults=None, max_header_size=None,
max_body_size=None):
"""Creates a AsyncHTTPClient.
Only a single AsyncHTTPClient instance exists per IOLoop
in order to provide limitations on the number of pending connections.
``force_instance=True`` may be used to suppress this behavior.
Note that because of this implicit reuse, unless ``force_instance``
is used, only the first call to the constructor actually uses
its arguments. It is recommended to use the ``configure`` method
instead of the constructor to ensure that arguments take effect.
``max_clients`` is the number of concurrent requests that can be
in progress; when this limit is reached additional requests will be
queued. Note that time spent waiting in this queue still counts
against the ``request_timeout``.
``hostname_mapping`` is a dictionary mapping hostnames to IP addresses.
It can be used to make local DNS changes when modifying system-wide
settings like ``/etc/hosts`` is not possible or desirable (e.g. in
unittests).
``max_buffer_size`` (default 100MB) is the number of bytes
that can be read into memory at once. ``max_body_size``
(defaults to ``max_buffer_size``) is the largest response body
that the client will accept. Without a
``streaming_callback``, the smaller of these two limits
applies; with a ``streaming_callback`` only ``max_body_size``
does.
.. versionchanged:: 4.2
Added the ``max_body_size`` argument.
"""
super(SimpleAsyncHTTPClient, self).initialize(defaults=defaults)
self.max_clients = max_clients
self.queue = collections.deque()
self.active = {}
self.waiting = {}
self.max_buffer_size = max_buffer_size
self.max_header_size = max_header_size
self.max_body_size = max_body_size
# TCPClient could create a Resolver for us, but we have to do it
# ourselves to support hostname_mapping.
if resolver:
self.resolver = resolver
self.own_resolver = False
else:
self.resolver = Resolver()
self.own_resolver = True
if hostname_mapping is not None:
self.resolver = OverrideResolver(resolver=self.resolver,
mapping=hostname_mapping)
self.tcp_client = TCPClient(resolver=self.resolver)
def close(self):
super(SimpleAsyncHTTPClient, self).close()
if self.own_resolver:
self.resolver.close()
self.tcp_client.close()
def fetch_impl(self, request, callback):
key = object()
self.queue.append((key, request, callback))
if not len(self.active) < self.max_clients:
timeout_handle = self.io_loop.add_timeout(
self.io_loop.time() + min(request.connect_timeout,
request.request_timeout),
functools.partial(self._on_timeout, key, "in request queue"))
else:
timeout_handle = None
self.waiting[key] = (request, callback, timeout_handle)
self._process_queue()
if self.queue:
gen_log.debug("max_clients limit reached, request queued. "
"%d active, %d queued requests." % (
len(self.active), len(self.queue)))
def _process_queue(self):
with stack_context.NullContext():
while self.queue and len(self.active) < self.max_clients:
key, request, callback = self.queue.popleft()
if key not in self.waiting:
continue
self._remove_timeout(key)
self.active[key] = (request, callback)
release_callback = functools.partial(self._release_fetch, key)
self._handle_request(request, release_callback, callback)
def _connection_class(self):
return _HTTPConnection
def _handle_request(self, request, release_callback, final_callback):
self._connection_class()(
self, request, release_callback,
final_callback, self.max_buffer_size, self.tcp_client,
self.max_header_size, self.max_body_size)
def _release_fetch(self, key):
del self.active[key]
self._process_queue()
def _remove_timeout(self, key):
if key in self.waiting:
request, callback, timeout_handle = self.waiting[key]
if timeout_handle is not None:
self.io_loop.remove_timeout(timeout_handle)
del self.waiting[key]
def _on_timeout(self, key, info=None):
"""Timeout callback of request.
Construct a timeout HTTPResponse when a timeout occurs.
:arg object key: A simple object to mark the request.
:arg string info: More detailed timeout information.
"""
request, callback, timeout_handle = self.waiting[key]
self.queue.remove((key, request, callback))
error_message = "Timeout {0}".format(info) if info else "Timeout"
timeout_response = HTTPResponse(
request, 599, error=HTTPTimeoutError(error_message),
request_time=self.io_loop.time() - request.start_time)
self.io_loop.add_callback(callback, timeout_response)
del self.waiting[key]
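# Minimal usage sketch (Tornado 5.x coroutine style; the URL is
# hypothetical). SimpleAsyncHTTPClient is normally selected through
# AsyncHTTPClient.configure rather than instantiated directly:
#
#   from tornado import gen
#   from tornado.ioloop import IOLoop
#   from tornado.httpclient import AsyncHTTPClient
#
#   AsyncHTTPClient.configure(SimpleAsyncHTTPClient, max_clients=20)
#
#   @gen.coroutine
#   def main():
#       response = yield AsyncHTTPClient().fetch('http://example.com/')
#       raise gen.Return(response.code)
#
#   print(IOLoop.current().run_sync(main))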
class _HTTPConnection(httputil.HTTPMessageDelegate):
_SUPPORTED_METHODS = set(["GET", "HEAD", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"])
def __init__(self, client, request, release_callback,
final_callback, max_buffer_size, tcp_client,
max_header_size, max_body_size):
self.io_loop = IOLoop.current()
self.start_time = self.io_loop.time()
self.start_wall_time = time.time()
self.client = client
self.request = request
self.release_callback = release_callback
self.final_callback = final_callback
self.max_buffer_size = max_buffer_size
self.tcp_client = tcp_client
self.max_header_size = max_header_size
self.max_body_size = max_body_size
self.code = None
self.headers = None
self.chunks = []
self._decompressor = None
# Timeout handle returned by IOLoop.add_timeout
self._timeout = None
self._sockaddr = None
IOLoop.current().add_callback(self.run)
@gen.coroutine
def run(self):
try:
self.parsed = urlparse.urlsplit(_unicode(self.request.url))
if self.parsed.scheme not in ("http", "https"):
raise ValueError("Unsupported url scheme: %s" %
self.request.url)
# urlsplit results have hostname and port results, but they
# didn't support ipv6 literals until python 2.7.
netloc = self.parsed.netloc
if "@" in netloc:
userpass, _, netloc = netloc.rpartition("@")
host, port = httputil.split_host_and_port(netloc)
if port is None:
port = 443 if self.parsed.scheme == "https" else 80
if re.match(r'^\[.*\]$', host):
# raw ipv6 addresses in urls are enclosed in brackets
host = host[1:-1]
self.parsed_hostname = host # save final host for _on_connect
if self.request.allow_ipv6 is False:
af = socket.AF_INET
else:
af = socket.AF_UNSPEC
ssl_options = self._get_ssl_options(self.parsed.scheme)
timeout = min(self.request.connect_timeout, self.request.request_timeout)
if timeout:
self._timeout = self.io_loop.add_timeout(
self.start_time + timeout,
stack_context.wrap(functools.partial(self._on_timeout, "while connecting")))
stream = yield self.tcp_client.connect(
host, port, af=af,
ssl_options=ssl_options,
max_buffer_size=self.max_buffer_size)
if self.final_callback is None:
# final_callback is cleared if we've hit our timeout.
stream.close()
return
self.stream = stream
self.stream.set_close_callback(self.on_connection_close)
self._remove_timeout()
if self.final_callback is None:
return
if self.request.request_timeout:
self._timeout = self.io_loop.add_timeout(
self.start_time + self.request.request_timeout,
stack_context.wrap(functools.partial(self._on_timeout, "during request")))
if (self.request.method not in self._SUPPORTED_METHODS and
not self.request.allow_nonstandard_methods):
raise KeyError("unknown method %s" % self.request.method)
for key in ('network_interface',
'proxy_host', 'proxy_port',
'proxy_username', 'proxy_password',
'proxy_auth_mode'):
if getattr(self.request, key, None):
raise NotImplementedError('%s not supported' % key)
if "Connection" not in self.request.headers:
self.request.headers["Connection"] = "close"
if "Host" not in self.request.headers:
if '@' in self.parsed.netloc:
self.request.headers["Host"] = self.parsed.netloc.rpartition('@')[-1]
else:
self.request.headers["Host"] = self.parsed.netloc
username, password = None, None
if self.parsed.username is not None:
username, password = self.parsed.username, self.parsed.password
elif self.request.auth_username is not None:
username = self.request.auth_username
password = self.request.auth_password or ''
if username is not None:
if self.request.auth_mode not in (None, "basic"):
raise ValueError("unsupported auth_mode %s",
self.request.auth_mode)
self.request.headers["Authorization"] = (
b"Basic " + base64.b64encode(
httputil.encode_username_password(username, password)))
if self.request.user_agent:
self.request.headers["User-Agent"] = self.request.user_agent
if not self.request.allow_nonstandard_methods:
# Some HTTP methods nearly always have bodies while others
# almost never do. Fail in this case unless the user has
# opted out of sanity checks with allow_nonstandard_methods.
body_expected = self.request.method in ("POST", "PATCH", "PUT")
body_present = (self.request.body is not None or
self.request.body_producer is not None)
if ((body_expected and not body_present) or
(body_present and not body_expected)):
raise ValueError(
'Body must %sbe None for method %s (unless '
'allow_nonstandard_methods is true)' %
('not ' if body_expected else '', self.request.method))
if self.request.expect_100_continue:
self.request.headers["Expect"] = "100-continue"
if self.request.body is not None:
# When body_producer is used the caller is responsible for
# setting Content-Length (or else chunked encoding will be used).
self.request.headers["Content-Length"] = str(len(
self.request.body))
if (self.request.method == "POST" and
"Content-Type" not in self.request.headers):
self.request.headers["Content-Type"] = "application/x-www-form-urlencoded"
if self.request.decompress_response:
self.request.headers["Accept-Encoding"] = "gzip"
req_path = ((self.parsed.path or '/') +
(('?' + self.parsed.query) if self.parsed.query else ''))
self.connection = self._create_connection(stream)
start_line = httputil.RequestStartLine(self.request.method,
req_path, '')
self.connection.write_headers(start_line, self.request.headers)
if self.request.expect_100_continue:
yield self.connection.read_response(self)
else:
yield self._write_body(True)
except Exception:
if not self._handle_exception(*sys.exc_info()):
raise
def _get_ssl_options(self, scheme):
if scheme == "https":
if self.request.ssl_options is not None:
return self.request.ssl_options
# If we are using the defaults, don't construct a
# new SSLContext.
if (self.request.validate_cert and
self.request.ca_certs is None and
self.request.client_cert is None and
self.request.client_key is None):
return _client_ssl_defaults
ssl_ctx = ssl.create_default_context(
ssl.Purpose.SERVER_AUTH,
cafile=self.request.ca_certs)
if not self.request.validate_cert:
ssl_ctx.check_hostname = False
ssl_ctx.verify_mode = ssl.CERT_NONE
if self.request.client_cert is not None:
ssl_ctx.load_cert_chain(self.request.client_cert,
self.request.client_key)
if hasattr(ssl, 'OP_NO_COMPRESSION'):
# See netutil.ssl_options_to_context
ssl_ctx.options |= ssl.OP_NO_COMPRESSION
return ssl_ctx
return None
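# How per-request TLS settings map onto this method (a sketch; the
# HTTPRequest attributes are real, the values hypothetical):
#
#   HTTPRequest(url)                           -> _client_ssl_defaults
#   HTTPRequest(url, ca_certs='/path/ca.pem')  -> fresh SSLContext loading
#                                                 that CA bundle
#   HTTPRequest(url, validate_cert=False)      -> CERT_NONE, no hostname
#                                                 check
#   HTTPRequest(url, ssl_options=my_ctx)       -> my_ctx used verbatim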
def _on_timeout(self, info=None):
"""Timeout callback of _HTTPConnection instance.
Raise an `HTTPTimeoutError` when a timeout occurs.
:arg string info: More detailed timeout information.
"""
self._timeout = None
error_message = "Timeout {0}".format(info) if info else "Timeout"
if self.final_callback is not None:
self._handle_exception(HTTPTimeoutError, HTTPTimeoutError(error_message),
None)
def _remove_timeout(self):
if self._timeout is not None:
self.io_loop.remove_timeout(self._timeout)
self._timeout = None
def _create_connection(self, stream):
stream.set_nodelay(True)
connection = HTTP1Connection(
stream, True,
HTTP1ConnectionParameters(
no_keep_alive=True,
max_header_size=self.max_header_size,
max_body_size=self.max_body_size,
decompress=self.request.decompress_response),
self._sockaddr)
return connection
@gen.coroutine
def _write_body(self, start_read):
if self.request.body is not None:
self.connection.write(self.request.body)
elif self.request.body_producer is not None:
fut = self.request.body_producer(self.connection.write)
if fut is not None:
yield fut
self.connection.finish()
if start_read:
try:
yield self.connection.read_response(self)
except StreamClosedError:
if not self._handle_exception(*sys.exc_info()):
raise
def _release(self):
if self.release_callback is not None:
release_callback = self.release_callback
self.release_callback = None
release_callback()
def _run_callback(self, response):
self._release()
if self.final_callback is not None:
final_callback = self.final_callback
self.final_callback = None
self.io_loop.add_callback(final_callback, response)
def _handle_exception(self, typ, value, tb):
if self.final_callback:
self._remove_timeout()
if isinstance(value, StreamClosedError):
if value.real_error is None:
value = HTTPStreamClosedError("Stream closed")
else:
value = value.real_error
self._run_callback(HTTPResponse(self.request, 599, error=value,
request_time=self.io_loop.time() - self.start_time,
start_time=self.start_wall_time,
))
if hasattr(self, "stream"):
# TODO: this may cause a StreamClosedError to be raised
# by the connection's Future. Should we cancel the
# connection more gracefully?
self.stream.close()
return True
else:
# If our callback has already been called, we are probably
# catching an exception that is not caused by us but rather
# some child of our callback. Rather than drop it on the floor,
# pass it along, unless it's just the stream being closed.
return isinstance(value, StreamClosedError)
def on_connection_close(self):
if self.final_callback is not None:
message = "Connection closed"
if self.stream.error:
raise self.stream.error
try:
raise HTTPStreamClosedError(message)
except HTTPStreamClosedError:
self._handle_exception(*sys.exc_info())
def headers_received(self, first_line, headers):
if self.request.expect_100_continue and first_line.code == 100:
self._write_body(False)
return
self.code = first_line.code
self.reason = first_line.reason
self.headers = headers
if self._should_follow_redirect():
return
if self.request.header_callback is not None:
# Reassemble the start line.
self.request.header_callback('%s %s %s\r\n' % first_line)
for k, v in self.headers.get_all():
self.request.header_callback("%s: %s\r\n" % (k, v))
self.request.header_callback('\r\n')
def _should_follow_redirect(self):
return (self.request.follow_redirects and
self.request.max_redirects > 0 and
self.code in (301, 302, 303, 307, 308))
def finish(self):
data = b''.join(self.chunks)
self._remove_timeout()
original_request = getattr(self.request, "original_request",
self.request)
if self._should_follow_redirect():
assert isinstance(self.request, _RequestProxy)
new_request = copy.copy(self.request.request)
new_request.url = urlparse.urljoin(self.request.url,
self.headers["Location"])
new_request.max_redirects = self.request.max_redirects - 1
del new_request.headers["Host"]
# http://www.w3.org/Protocols/rfc2616/rfc2616-sec10.html#sec10.3.4
# Client SHOULD make a GET request after a 303.
# According to the spec, 302 should be followed by the same
# method as the original request, but in practice browsers
# treat 302 the same as 303, and many servers use 302 for
# compatibility with pre-HTTP/1.1 user agents which don't
# understand the 303 status.
if self.code in (302, 303):
new_request.method = "GET"
new_request.body = None
for h in ["Content-Length", "Content-Type",
"Content-Encoding", "Transfer-Encoding"]:
try:
del self.request.headers[h]
except KeyError:
pass
new_request.original_request = original_request
final_callback = self.final_callback
self.final_callback = None
self._release()
fut = self.client.fetch(new_request, raise_error=False)
fut.add_done_callback(lambda f: final_callback(f.result()))
self._on_end_request()
return
if self.request.streaming_callback:
buffer = BytesIO()
else:
buffer = BytesIO(data) # TODO: don't require one big string?
response = HTTPResponse(original_request,
self.code, reason=getattr(self, 'reason', None),
headers=self.headers,
request_time=self.io_loop.time() - self.start_time,
start_time=self.start_wall_time,
buffer=buffer,
effective_url=self.request.url)
self._run_callback(response)
self._on_end_request()
def _on_end_request(self):
self.stream.close()
def data_received(self, chunk):
if self._should_follow_redirect():
# We're going to follow a redirect so just discard the body.
return
if self.request.streaming_callback is not None:
self.request.streaming_callback(chunk)
else:
self.chunks.append(chunk)
if __name__ == "__main__":
AsyncHTTPClient.configure(SimpleAsyncHTTPClient)
main()

lib/tornado/speedups.c Executable file
@@ -0,0 +1,77 @@
#define PY_SSIZE_T_CLEAN
#include <Python.h>
#include <stdint.h>
static PyObject* websocket_mask(PyObject* self, PyObject* args) {
const char* mask;
Py_ssize_t mask_len;
uint32_t uint32_mask;
uint64_t uint64_mask;
const char* data;
Py_ssize_t data_len;
Py_ssize_t i;
PyObject* result;
char* buf;
if (!PyArg_ParseTuple(args, "s#s#", &mask, &mask_len, &data, &data_len)) {
return NULL;
}
uint32_mask = ((uint32_t*)mask)[0];
result = PyBytes_FromStringAndSize(NULL, data_len);
if (!result) {
return NULL;
}
buf = PyBytes_AsString(result);
if (sizeof(size_t) >= 8) {
uint64_mask = uint32_mask;
uint64_mask = (uint64_mask << 32) | uint32_mask;
while (data_len >= 8) {
((uint64_t*)buf)[0] = ((uint64_t*)data)[0] ^ uint64_mask;
data += 8;
buf += 8;
data_len -= 8;
}
}
while (data_len >= 4) {
((uint32_t*)buf)[0] = ((uint32_t*)data)[0] ^ uint32_mask;
data += 4;
buf += 4;
data_len -= 4;
}
for (i = 0; i < data_len; i++) {
buf[i] = data[i] ^ mask[i];
}
return result;
}
static PyMethodDef methods[] = {
{"websocket_mask", websocket_mask, METH_VARARGS, ""},
{NULL, NULL, 0, NULL}
};
#if PY_MAJOR_VERSION >= 3
static struct PyModuleDef speedupsmodule = {
PyModuleDef_HEAD_INIT,
"speedups",
NULL,
-1,
methods
};
PyMODINIT_FUNC
PyInit_speedups(void) {
return PyModule_Create(&speedupsmodule);
}
#else // Python 2.x
PyMODINIT_FUNC
initspeedups(void) {
Py_InitModule("tornado.speedups", methods);
}
#endif

lib/tornado/speedups.pyi Executable file
@@ -0,0 +1 @@
def websocket_mask(mask: bytes, data: bytes) -> bytes: ...
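# A pure-Python reference for the C speedup above (a sketch for clarity;
# Tornado's actual fallback lives in tornado.util):
#
#   def _websocket_mask_py(mask, data):
#       # XOR each payload byte with the repeating 4-byte mask.
#       out = bytearray(data)
#       for i in range(len(out)):
#           out[i] ^= mask[i % 4]  # mask: bytes-like of length 4
#       return bytes(out)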

lib/tornado/stack_context.py Executable file
@@ -0,0 +1,413 @@
#
# Copyright 2010 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""`StackContext` allows applications to maintain threadlocal-like state
that follows execution as it moves to other execution contexts.
The motivating examples are to eliminate the need for explicit
``async_callback`` wrappers (as in `tornado.web.RequestHandler`), and to
allow some additional context to be kept for logging.
This is slightly magic, but it's an extension of the idea that an
exception handler is a kind of stack-local state and when that stack
is suspended and resumed in a new context that state needs to be
preserved. `StackContext` shifts the burden of restoring that state
from each call site (e.g. wrapping each `.AsyncHTTPClient` callback
in ``async_callback``) to the mechanisms that transfer control from
one context to another (e.g. `.AsyncHTTPClient` itself, `.IOLoop`,
thread pools, etc).
Example usage::
@contextlib.contextmanager
def die_on_error():
try:
yield
except Exception:
logging.error("exception in asynchronous operation",exc_info=True)
sys.exit(1)
with StackContext(die_on_error):
# Any exception thrown here *or in callback and its descendants*
# will cause the process to exit instead of spinning endlessly
# in the ioloop.
http_client.fetch(url, callback)
ioloop.start()
Most applications shouldn't have to work with `StackContext` directly.
Here are a few rules of thumb for when it's necessary:
* If you're writing an asynchronous library that doesn't rely on a
stack_context-aware library like `tornado.ioloop` or `tornado.iostream`
(for example, if you're writing a thread pool), use
`.stack_context.wrap()` before any asynchronous operations to capture the
stack context from where the operation was started.
* If you're writing an asynchronous library that has some shared
resources (such as a connection pool), create those shared resources
within a ``with stack_context.NullContext():`` block. This will prevent
``StackContexts`` from leaking from one request to another.
* If you want to write something like an exception handler that will
persist across asynchronous calls, create a new `StackContext` (or
`ExceptionStackContext`), and make your asynchronous calls in a ``with``
block that references your `StackContext`.
.. deprecated:: 5.1
The ``stack_context`` package is deprecated and will be removed in
Tornado 6.0.
"""
from __future__ import absolute_import, division, print_function
import sys
import threading
import warnings
from tornado.util import raise_exc_info
class StackContextInconsistentError(Exception):
pass
class _State(threading.local):
def __init__(self):
self.contexts = (tuple(), None)
_state = _State()
class StackContext(object):
"""Establishes the given context as a StackContext that will be transferred.
Note that the parameter is a callable that returns a context
manager, not the context itself. That is, where for a
non-transferable context manager you would say::
with my_context():
StackContext takes the function itself rather than its result::
with StackContext(my_context):
The result of ``with StackContext() as cb:`` is a deactivation
callback. Run this callback when the StackContext is no longer
needed to ensure that it is not propagated any further (note that
deactivating a context does not affect any instances of that
context that are currently pending). This is an advanced feature
and not necessary in most applications.
"""
def __init__(self, context_factory):
warnings.warn("StackContext is deprecated and will be removed in Tornado 6.0",
DeprecationWarning)
self.context_factory = context_factory
self.contexts = []
self.active = True
def _deactivate(self):
self.active = False
# StackContext protocol
def enter(self):
context = self.context_factory()
self.contexts.append(context)
context.__enter__()
def exit(self, type, value, traceback):
context = self.contexts.pop()
context.__exit__(type, value, traceback)
# Note that some of this code is duplicated in ExceptionStackContext
# below. ExceptionStackContext is more common and doesn't need
# the full generality of this class.
def __enter__(self):
self.old_contexts = _state.contexts
self.new_contexts = (self.old_contexts[0] + (self,), self)
_state.contexts = self.new_contexts
try:
self.enter()
except:
_state.contexts = self.old_contexts
raise
return self._deactivate
def __exit__(self, type, value, traceback):
try:
self.exit(type, value, traceback)
finally:
final_contexts = _state.contexts
_state.contexts = self.old_contexts
# Generator coroutines and with-statements with non-local
# effects interact badly. Check here for signs of
# the stack getting out of sync.
# Note that this check comes after restoring _state.context
# so that if it fails things are left in a (relatively)
# consistent state.
if final_contexts is not self.new_contexts:
raise StackContextInconsistentError(
'stack_context inconsistency (may be caused by yield '
'within a "with StackContext" block)')
# Break up a reference to itself to allow for faster GC on CPython.
self.new_contexts = None
class ExceptionStackContext(object):
"""Specialization of StackContext for exception handling.
The supplied ``exception_handler`` function will be called in the
event of an uncaught exception in this context. The semantics are
similar to a try/finally clause, and intended use cases are to log
an error, close a socket, or similar cleanup actions. The
``exc_info`` triple ``(type, value, traceback)`` will be passed to the
exception_handler function.
If the exception handler returns true, the exception will be
consumed and will not be propagated to other exception handlers.
.. versionadded:: 5.1
The ``delay_warning`` argument can be used to delay the emission
of DeprecationWarnings until an exception is caught by the
``ExceptionStackContext``, which facilitates certain transitional
use cases.
"""
def __init__(self, exception_handler, delay_warning=False):
self.delay_warning = delay_warning
if not self.delay_warning:
warnings.warn(
"StackContext is deprecated and will be removed in Tornado 6.0",
DeprecationWarning)
self.exception_handler = exception_handler
self.active = True
def _deactivate(self):
self.active = False
def exit(self, type, value, traceback):
if type is not None:
if self.delay_warning:
warnings.warn(
"StackContext is deprecated and will be removed in Tornado 6.0",
DeprecationWarning)
return self.exception_handler(type, value, traceback)
def __enter__(self):
self.old_contexts = _state.contexts
self.new_contexts = (self.old_contexts[0], self)
_state.contexts = self.new_contexts
return self._deactivate
def __exit__(self, type, value, traceback):
try:
if type is not None:
return self.exception_handler(type, value, traceback)
finally:
final_contexts = _state.contexts
_state.contexts = self.old_contexts
if final_contexts is not self.new_contexts:
raise StackContextInconsistentError(
'stack_context inconsistency (may be caused by yield '
'within a "with StackContext" block)')
# Break up a reference to itself to allow for faster GC on CPython.
self.new_contexts = None
class NullContext(object):
"""Resets the `StackContext`.
Useful when creating a shared resource on demand (e.g. an
`.AsyncHTTPClient`) where the stack that caused the creation is
not relevant to future operations.
"""
def __enter__(self):
self.old_contexts = _state.contexts
_state.contexts = (tuple(), None)
def __exit__(self, type, value, traceback):
_state.contexts = self.old_contexts
def _remove_deactivated(contexts):
"""Remove deactivated handlers from the chain"""
# Clean ctx handlers
stack_contexts = tuple([h for h in contexts[0] if h.active])
# Find new head
head = contexts[1]
while head is not None and not head.active:
head = head.old_contexts[1]
# Process chain
ctx = head
while ctx is not None:
parent = ctx.old_contexts[1]
while parent is not None:
if parent.active:
break
ctx.old_contexts = parent.old_contexts
parent = parent.old_contexts[1]
ctx = parent
return (stack_contexts, head)
def wrap(fn):
"""Returns a callable object that will restore the current `StackContext`
when executed.
Use this whenever saving a callback to be executed later in a
different execution context (either in a different thread or
asynchronously in the same thread).
"""
# Check if function is already wrapped
if fn is None or hasattr(fn, '_wrapped'):
return fn
# Capture current stack head
# TODO: Any other better way to store contexts and update them in wrapped function?
cap_contexts = [_state.contexts]
if not cap_contexts[0][0] and not cap_contexts[0][1]:
# Fast path when there are no active contexts.
def null_wrapper(*args, **kwargs):
try:
current_state = _state.contexts
_state.contexts = cap_contexts[0]
return fn(*args, **kwargs)
finally:
_state.contexts = current_state
null_wrapper._wrapped = True
return null_wrapper
def wrapped(*args, **kwargs):
ret = None
try:
# Capture old state
current_state = _state.contexts
# Remove deactivated items
cap_contexts[0] = contexts = _remove_deactivated(cap_contexts[0])
# Force new state
_state.contexts = contexts
# Current exception
exc = (None, None, None)
top = None
# Apply stack contexts
last_ctx = 0
stack = contexts[0]
# Apply state
for n in stack:
try:
n.enter()
last_ctx += 1
except:
# Exception happened. Record exception info and store top-most handler
exc = sys.exc_info()
top = n.old_contexts[1]
# Execute callback if no exception happened while restoring state
if top is None:
try:
ret = fn(*args, **kwargs)
except:
exc = sys.exc_info()
top = contexts[1]
# If there was exception, try to handle it by going through the exception chain
if top is not None:
exc = _handle_exception(top, exc)
else:
# Otherwise take shorter path and run stack contexts in reverse order
while last_ctx > 0:
last_ctx -= 1
c = stack[last_ctx]
try:
c.exit(*exc)
except:
exc = sys.exc_info()
top = c.old_contexts[1]
break
else:
top = None
# If an exception happened while unrolling, take the longer exception handler path
if top is not None:
exc = _handle_exception(top, exc)
# If exception was not handled, raise it
if exc != (None, None, None):
raise_exc_info(exc)
finally:
_state.contexts = current_state
return ret
wrapped._wrapped = True
return wrapped
def _handle_exception(tail, exc):
while tail is not None:
try:
if tail.exit(*exc):
exc = (None, None, None)
except:
exc = sys.exc_info()
tail = tail.old_contexts[1]
return exc
def run_with_stack_context(context, func):
"""Run a coroutine ``func`` in the given `StackContext`.
It is not safe to have a ``yield`` statement within a ``with StackContext``
block, so it is difficult to use stack context with `.gen.coroutine`.
This helper function runs the function in the correct context while
keeping the ``yield`` and ``with`` statements syntactically separate.
Example::
@gen.coroutine
def incorrect():
with StackContext(ctx):
# ERROR: this will raise StackContextInconsistentError
yield other_coroutine()
@gen.coroutine
def correct():
yield run_with_stack_context(StackContext(ctx), other_coroutine)
.. versionadded:: 3.1
"""
with context:
return func()
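# Minimal ExceptionStackContext sketch (the handler and callback names are
# hypothetical; note the whole module is deprecated as of Tornado 5.1):
#
#   def on_error(typ, value, tb):
#       logging.error('uncaught in callback', exc_info=(typ, value, tb))
#       return True  # consume the exception
#
#   with ExceptionStackContext(on_error):
#       IOLoop.current().add_callback(might_fail)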

lib/tornado/tcpclient.py Executable file
@@ -0,0 +1,276 @@
#
# Copyright 2014 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""A non-blocking TCP connection factory.
"""
from __future__ import absolute_import, division, print_function
import functools
import socket
import numbers
import datetime
from tornado.concurrent import Future, future_add_done_callback
from tornado.ioloop import IOLoop
from tornado.iostream import IOStream
from tornado import gen
from tornado.netutil import Resolver
from tornado.platform.auto import set_close_exec
from tornado.gen import TimeoutError
from tornado.util import timedelta_to_seconds
_INITIAL_CONNECT_TIMEOUT = 0.3
class _Connector(object):
"""A stateless implementation of the "Happy Eyeballs" algorithm.
"Happy Eyeballs" is documented in RFC6555 as the recommended practice
for when both IPv4 and IPv6 addresses are available.
In this implementation, we partition the addresses by family, and
make the first connection attempt to whichever address was
returned first by ``getaddrinfo``. If that connection fails or
times out, we begin a connection in parallel to the first address
of the other family. If there are additional failures we retry
with other addresses, keeping one connection attempt per family
in flight at a time.
http://tools.ietf.org/html/rfc6555
"""
def __init__(self, addrinfo, connect):
self.io_loop = IOLoop.current()
self.connect = connect
self.future = Future()
self.timeout = None
self.connect_timeout = None
self.last_error = None
self.remaining = len(addrinfo)
self.primary_addrs, self.secondary_addrs = self.split(addrinfo)
self.streams = set()
@staticmethod
def split(addrinfo):
"""Partition the ``addrinfo`` list by address family.
Returns two lists. The first list contains the first entry from
``addrinfo`` and all others with the same family, and the
second list contains all other addresses (normally one list will
be AF_INET and the other AF_INET6, although non-standard resolvers
may return additional families).
"""
primary = []
secondary = []
primary_af = addrinfo[0][0]
for af, addr in addrinfo:
if af == primary_af:
primary.append((af, addr))
else:
secondary.append((af, addr))
return primary, secondary
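# For example (a sketch with hypothetical getaddrinfo-style pairs):
#
#   split([(AF_INET6, ('::1', 443)),
#          (AF_INET, ('127.0.0.1', 443)),
#          (AF_INET6, ('2001:db8::1', 443))])
#   # -> primary:   both AF_INET6 entries (family of the first result)
#   #    secondary: the AF_INET entry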
def start(self, timeout=_INITIAL_CONNECT_TIMEOUT, connect_timeout=None):
self.try_connect(iter(self.primary_addrs))
self.set_timeout(timeout)
if connect_timeout is not None:
self.set_connect_timeout(connect_timeout)
return self.future
def try_connect(self, addrs):
try:
af, addr = next(addrs)
except StopIteration:
# We've reached the end of our queue, but the other queue
# might still be working. Send a final error on the future
# only when both queues are finished.
if self.remaining == 0 and not self.future.done():
self.future.set_exception(self.last_error or
IOError("connection failed"))
return
stream, future = self.connect(af, addr)
self.streams.add(stream)
future_add_done_callback(
future, functools.partial(self.on_connect_done, addrs, af, addr))
def on_connect_done(self, addrs, af, addr, future):
self.remaining -= 1
try:
stream = future.result()
except Exception as e:
if self.future.done():
return
# Error: try again (but remember what happened so we have an
# error to raise in the end)
self.last_error = e
self.try_connect(addrs)
if self.timeout is not None:
# If the first attempt failed, don't wait for the
# timeout to try an address from the secondary queue.
self.io_loop.remove_timeout(self.timeout)
self.on_timeout()
return
self.clear_timeouts()
if self.future.done():
# This is a late arrival; just drop it.
stream.close()
else:
self.streams.discard(stream)
self.future.set_result((af, addr, stream))
self.close_streams()
def set_timeout(self, timeout):
self.timeout = self.io_loop.add_timeout(self.io_loop.time() + timeout,
self.on_timeout)
def on_timeout(self):
self.timeout = None
if not self.future.done():
self.try_connect(iter(self.secondary_addrs))
def clear_timeout(self):
if self.timeout is not None:
self.io_loop.remove_timeout(self.timeout)
def set_connect_timeout(self, connect_timeout):
self.connect_timeout = self.io_loop.add_timeout(
connect_timeout, self.on_connect_timeout)
def on_connect_timeout(self):
if not self.future.done():
self.future.set_exception(TimeoutError())
self.close_streams()
def clear_timeouts(self):
if self.timeout is not None:
self.io_loop.remove_timeout(self.timeout)
if self.connect_timeout is not None:
self.io_loop.remove_timeout(self.connect_timeout)
def close_streams(self):
for stream in self.streams:
stream.close()
class TCPClient(object):
"""A non-blocking TCP connection factory.
.. versionchanged:: 5.0
The ``io_loop`` argument (deprecated since version 4.1) has been removed.
"""
def __init__(self, resolver=None):
if resolver is not None:
self.resolver = resolver
self._own_resolver = False
else:
self.resolver = Resolver()
self._own_resolver = True
def close(self):
if self._own_resolver:
self.resolver.close()
@gen.coroutine
def connect(self, host, port, af=socket.AF_UNSPEC, ssl_options=None,
max_buffer_size=None, source_ip=None, source_port=None,
timeout=None):
"""Connect to the given host and port.
Asynchronously returns an `.IOStream` (or `.SSLIOStream` if
``ssl_options`` is not None).
Using the ``source_ip`` kwarg, one can specify the source
IP address to use when establishing the connection.
If the user needs to resolve and use a specific interface, this
must be handled outside of Tornado, as it depends very much on
the platform.
Raises `TimeoutError` if the input future does not complete before
``timeout``, which may be specified in any form allowed by
`.IOLoop.add_timeout` (i.e. a `datetime.timedelta` or an absolute time
relative to `.IOLoop.time`)
Similarly, when the user requires a certain source port, it can
be specified using the ``source_port`` arg.
.. versionchanged:: 4.5
Added the ``source_ip`` and ``source_port`` arguments.
.. versionchanged:: 5.0
Added the ``timeout`` argument.
"""
if timeout is not None:
if isinstance(timeout, numbers.Real):
timeout = IOLoop.current().time() + timeout
elif isinstance(timeout, datetime.timedelta):
timeout = IOLoop.current().time() + timedelta_to_seconds(timeout)
else:
raise TypeError("Unsupported timeout %r" % timeout)
if timeout is not None:
addrinfo = yield gen.with_timeout(
timeout, self.resolver.resolve(host, port, af))
else:
addrinfo = yield self.resolver.resolve(host, port, af)
connector = _Connector(
addrinfo,
functools.partial(self._create_stream, max_buffer_size,
source_ip=source_ip, source_port=source_port)
)
af, addr, stream = yield connector.start(connect_timeout=timeout)
# TODO: For better performance we could cache the (af, addr)
# information here and re-use it on subsequent connections to
# the same host. (http://tools.ietf.org/html/rfc6555#section-4.2)
if ssl_options is not None:
if timeout is not None:
stream = yield gen.with_timeout(timeout, stream.start_tls(
False, ssl_options=ssl_options, server_hostname=host))
else:
stream = yield stream.start_tls(False, ssl_options=ssl_options,
server_hostname=host)
raise gen.Return(stream)
def _create_stream(self, max_buffer_size, af, addr, source_ip=None,
source_port=None):
# Always connect in plaintext; we'll convert to ssl if necessary
# after one connection has completed.
source_port_bind = source_port if isinstance(source_port, int) else 0
source_ip_bind = source_ip
if source_port_bind and not source_ip:
# The user requested a specific port but no source IP, so bind
# to the default loopback address.
source_ip_bind = '::1' if af == socket.AF_INET6 else '127.0.0.1'
# Trying to use the same address family as the requested af socket:
# - 127.0.0.1 for IPv4
# - ::1 for IPv6
socket_obj = socket.socket(af)
set_close_exec(socket_obj.fileno())
if source_port_bind or source_ip_bind:
# If the user requires binding also to a specific IP/port.
try:
socket_obj.bind((source_ip_bind, source_port_bind))
except socket.error:
socket_obj.close()
# Fail loudly if unable to use the IP/port.
raise
try:
stream = IOStream(socket_obj,
max_buffer_size=max_buffer_size)
except socket.error as e:
fu = Future()
fu.set_exception(e)
return fu
else:
return stream, stream.connect(addr)
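# Minimal usage sketch (host and payload are hypothetical; Tornado 5.x
# coroutine style):
#
#   @gen.coroutine
#   def talk():
#       stream = yield TCPClient().connect('example.com', 80)
#       yield stream.write(b'GET / HTTP/1.0\r\nHost: example.com\r\n\r\n')
#       reply = yield stream.read_until_close()
#       raise gen.Return(reply)
#
#   IOLoop.current().run_sync(talk)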

lib/tornado/tcpserver.py Executable file
@@ -0,0 +1,299 @@
#
# Copyright 2011 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""A non-blocking, single-threaded TCP server."""
from __future__ import absolute_import, division, print_function
import errno
import os
import socket
from tornado import gen
from tornado.log import app_log
from tornado.ioloop import IOLoop
from tornado.iostream import IOStream, SSLIOStream
from tornado.netutil import bind_sockets, add_accept_handler, ssl_wrap_socket
from tornado import process
from tornado.util import errno_from_exception
try:
import ssl
except ImportError:
# ssl is not available on Google App Engine.
ssl = None
class TCPServer(object):
r"""A non-blocking, single-threaded TCP server.
To use `TCPServer`, define a subclass which overrides the `handle_stream`
method. For example, a simple echo server could be defined like this::
from tornado.tcpserver import TCPServer
from tornado.iostream import StreamClosedError
from tornado import gen
class EchoServer(TCPServer):
async def handle_stream(self, stream, address):
while True:
try:
data = await stream.read_until(b"\n")
await stream.write(data)
except StreamClosedError:
break
To make this server serve SSL traffic, send the ``ssl_options`` keyword
argument with an `ssl.SSLContext` object. For compatibility with older
versions of Python ``ssl_options`` may also be a dictionary of keyword
arguments for the `ssl.wrap_socket` method::
ssl_ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
ssl_ctx.load_cert_chain(os.path.join(data_dir, "mydomain.crt"),
os.path.join(data_dir, "mydomain.key"))
TCPServer(ssl_options=ssl_ctx)
`TCPServer` initialization follows one of three patterns:
1. `listen`: simple single-process::
server = TCPServer()
server.listen(8888)
IOLoop.current().start()
2. `bind`/`start`: simple multi-process::
server = TCPServer()
server.bind(8888)
server.start(0) # Forks multiple sub-processes
IOLoop.current().start()
When using this interface, an `.IOLoop` must *not* be passed
to the `TCPServer` constructor. `start` will always start
the server on the default singleton `.IOLoop`.
3. `add_sockets`: advanced multi-process::
sockets = bind_sockets(8888)
tornado.process.fork_processes(0)
server = TCPServer()
server.add_sockets(sockets)
IOLoop.current().start()
The `add_sockets` interface is more complicated, but it can be
used with `tornado.process.fork_processes` to give you more
flexibility in when the fork happens. `add_sockets` can
also be used in single-process servers if you want to create
your listening sockets in some way other than
`~tornado.netutil.bind_sockets`.
.. versionadded:: 3.1
The ``max_buffer_size`` argument.
.. versionchanged:: 5.0
The ``io_loop`` argument has been removed.
"""
def __init__(self, ssl_options=None, max_buffer_size=None,
read_chunk_size=None):
self.ssl_options = ssl_options
self._sockets = {} # fd -> socket object
self._handlers = {} # fd -> remove_handler callable
self._pending_sockets = []
self._started = False
self._stopped = False
self.max_buffer_size = max_buffer_size
self.read_chunk_size = read_chunk_size
# Verify the SSL options. Otherwise we don't get errors until clients
# connect. This doesn't verify that the keys are legitimate, but
# the SSL module doesn't do that until there is a connected socket,
# which seems like too much work.
if self.ssl_options is not None and isinstance(self.ssl_options, dict):
# Only certfile is required: it can contain both keys
if 'certfile' not in self.ssl_options:
raise KeyError('missing key "certfile" in ssl_options')
if not os.path.exists(self.ssl_options['certfile']):
raise ValueError('certfile "%s" does not exist' %
self.ssl_options['certfile'])
if ('keyfile' in self.ssl_options and
not os.path.exists(self.ssl_options['keyfile'])):
raise ValueError('keyfile "%s" does not exist' %
self.ssl_options['keyfile'])
def listen(self, port, address=""):
"""Starts accepting connections on the given port.
This method may be called more than once to listen on multiple ports.
`listen` takes effect immediately; it is not necessary to call
`TCPServer.start` afterwards. It is, however, necessary to start
the `.IOLoop`.
"""
sockets = bind_sockets(port, address=address)
self.add_sockets(sockets)
def add_sockets(self, sockets):
"""Makes this server start accepting connections on the given sockets.
The ``sockets`` parameter is a list of socket objects such as
those returned by `~tornado.netutil.bind_sockets`.
`add_sockets` is typically used in combination with that
method and `tornado.process.fork_processes` to provide greater
control over the initialization of a multi-process server.
"""
for sock in sockets:
self._sockets[sock.fileno()] = sock
self._handlers[sock.fileno()] = add_accept_handler(
sock, self._handle_connection)
def add_socket(self, socket):
"""Singular version of `add_sockets`. Takes a single socket object."""
self.add_sockets([socket])
def bind(self, port, address=None, family=socket.AF_UNSPEC, backlog=128,
reuse_port=False):
"""Binds this server to the given port on the given address.
To start the server, call `start`. If you want to run this server
in a single process, you can call `listen` as a shortcut to the
sequence of `bind` and `start` calls.
Address may be either an IP address or hostname. If it's a hostname,
the server will listen on all IP addresses associated with the
name. Address may be an empty string or None to listen on all
available interfaces. Family may be set to either `socket.AF_INET`
or `socket.AF_INET6` to restrict to IPv4 or IPv6 addresses, otherwise
both will be used if available.
The ``backlog`` argument has the same meaning as for
`socket.listen <socket.socket.listen>`. The ``reuse_port`` argument
has the same meaning as for `.bind_sockets`.
This method may be called multiple times prior to `start` to listen
on multiple ports or interfaces.
.. versionchanged:: 4.4
Added the ``reuse_port`` argument.
"""
sockets = bind_sockets(port, address=address, family=family,
backlog=backlog, reuse_port=reuse_port)
if self._started:
self.add_sockets(sockets)
else:
self._pending_sockets.extend(sockets)
def start(self, num_processes=1):
"""Starts this server in the `.IOLoop`.
By default, we run the server in this process and do not fork any
additional child process.
If num_processes is ``None`` or <= 0, we detect the number of cores
available on this machine and fork that number of child
processes. If num_processes is given and > 1, we fork that
specific number of sub-processes.
Since we use processes and not threads, there is no shared memory
between any server code.
Note that multiple processes are not compatible with the autoreload
module (or the ``autoreload=True`` option to `tornado.web.Application`
which defaults to True when ``debug=True``).
When using multiple processes, no IOLoops can be created or
referenced until after the call to ``TCPServer.start(n)``.
"""
assert not self._started
self._started = True
if num_processes != 1:
process.fork_processes(num_processes)
sockets = self._pending_sockets
self._pending_sockets = []
self.add_sockets(sockets)
def stop(self):
"""Stops listening for new connections.
Requests currently in progress may still continue after the
server is stopped.
"""
if self._stopped:
return
self._stopped = True
for fd, sock in self._sockets.items():
assert sock.fileno() == fd
# Unregister socket from IOLoop
self._handlers.pop(fd)()
sock.close()
def handle_stream(self, stream, address):
"""Override to handle a new `.IOStream` from an incoming connection.
This method may be a coroutine; if so any exceptions it raises
asynchronously will be logged. Accepting of incoming connections
will not be blocked by this coroutine.
If this `TCPServer` is configured for SSL, ``handle_stream``
may be called before the SSL handshake has completed. Use
`.SSLIOStream.wait_for_handshake` if you need to verify the client's
certificate or use NPN/ALPN.
.. versionchanged:: 4.2
Added the option for this method to be a coroutine.
"""
raise NotImplementedError()
def _handle_connection(self, connection, address):
if self.ssl_options is not None:
assert ssl, "Python 2.6+ and OpenSSL required for SSL"
try:
connection = ssl_wrap_socket(connection,
self.ssl_options,
server_side=True,
do_handshake_on_connect=False)
except ssl.SSLError as err:
if err.args[0] == ssl.SSL_ERROR_EOF:
return connection.close()
else:
raise
except socket.error as err:
# If the connection is closed immediately after it is created
# (as in a port scan), we can get one of several errors.
# wrap_socket makes an internal call to getpeername,
# which may return either EINVAL (Mac OS X) or ENOTCONN
# (Linux). If it returns ENOTCONN, this error is
# silently swallowed by the ssl module, so we need to
# catch another error later on (AttributeError in
# SSLIOStream._do_ssl_handshake).
# To test this behavior, try nmap with the -sT flag.
# https://github.com/tornadoweb/tornado/pull/750
if errno_from_exception(err) in (errno.ECONNABORTED, errno.EINVAL):
return connection.close()
else:
raise
try:
if self.ssl_options is not None:
stream = SSLIOStream(connection,
max_buffer_size=self.max_buffer_size,
read_chunk_size=self.read_chunk_size)
else:
stream = IOStream(connection,
max_buffer_size=self.max_buffer_size,
read_chunk_size=self.read_chunk_size)
future = self.handle_stream(stream, address)
if future is not None:
IOLoop.current().add_future(gen.convert_yielded(future),
lambda f: f.result())
except Exception:
app_log.error("Error in connection callback", exc_info=True)

lib/tornado/template.py Executable file
@@ -0,0 +1,976 @@
#
# Copyright 2009 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""A simple template system that compiles templates to Python code.
Basic usage looks like::
t = template.Template("<html>{{ myvalue }}</html>")
print(t.generate(myvalue="XXX"))
`Loader` is a class that loads templates from a root directory and caches
the compiled templates::
loader = template.Loader("/home/btaylor")
print(loader.load("test.html").generate(myvalue="XXX"))
We compile all templates to raw Python. Error-reporting is currently... uh,
interesting. Syntax for the templates::
### base.html
<html>
<head>
<title>{% block title %}Default title{% end %}</title>
</head>
<body>
<ul>
{% for student in students %}
{% block student %}
<li>{{ escape(student.name) }}</li>
{% end %}
{% end %}
</ul>
</body>
</html>
### bold.html
{% extends "base.html" %}
{% block title %}A bolder title{% end %}
{% block student %}
<li><span style="bold">{{ escape(student.name) }}</span></li>
{% end %}
Unlike most other template systems, we do not put any restrictions on the
expressions you can include in your statements. ``if`` and ``for`` blocks get
translated exactly into Python, so you can do complex expressions like::
{% for student in [p for p in people if p.student and p.age > 23] %}
<li>{{ escape(student.name) }}</li>
{% end %}
Translating directly to Python means you can apply functions to expressions
easily, like the ``escape()`` function in the examples above. You can pass
functions in to your template just like any other variable
(In a `.RequestHandler`, override `.RequestHandler.get_template_namespace`)::
### Python code
def add(x, y):
return x + y
template.execute(add=add)
### The template
{{ add(1, 2) }}
We provide the functions `escape() <.xhtml_escape>`, `.url_escape()`,
`.json_encode()`, and `.squeeze()` to all templates by default.
Typical applications do not create `Template` or `Loader` instances by
hand, but instead use the `~.RequestHandler.render` and
`~.RequestHandler.render_string` methods of
`tornado.web.RequestHandler`, which load templates automatically based
on the ``template_path`` `.Application` setting.
Variable names beginning with ``_tt_`` are reserved by the template
system and should not be used by application code.
Syntax Reference
----------------
Template expressions are surrounded by double curly braces: ``{{ ... }}``.
The contents may be any python expression, which will be escaped according
to the current autoescape setting and inserted into the output. Other
template directives use ``{% %}``.
To comment out a section so that it is omitted from the output, surround it
with ``{# ... #}``.
These tags may be escaped as ``{{!``, ``{%!``, and ``{#!``
if you need to include a literal ``{{``, ``{%``, or ``{#`` in the output.
``{% apply *function* %}...{% end %}``
Applies a function to the output of all template code between ``apply``
and ``end``::
{% apply linkify %}{{name}} said: {{message}}{% end %}
Note that as an implementation detail apply blocks are implemented
as nested functions and thus may interact strangely with variables
set via ``{% set %}``, or the use of ``{% break %}`` or ``{% continue %}``
within loops.
``{% autoescape *function* %}``
Sets the autoescape mode for the current file. This does not affect
other files, even those referenced by ``{% include %}``. Note that
autoescaping can also be configured globally, at the `.Application`
or `Loader`::
{% autoescape xhtml_escape %}
{% autoescape None %}
``{% block *name* %}...{% end %}``
Indicates a named, replaceable block for use with ``{% extends %}``.
Blocks in the parent template will be replaced with the contents of
the same-named block in a child template.::
<!-- base.html -->
<title>{% block title %}Default title{% end %}</title>
<!-- mypage.html -->
{% extends "base.html" %}
{% block title %}My page title{% end %}
``{% comment ... %}``
A comment which will be removed from the template output. Note that
there is no ``{% end %}`` tag; the comment goes from the word ``comment``
to the closing ``%}`` tag.
``{% extends *filename* %}``
Inherit from another template. Templates that use ``extends`` should
contain one or more ``block`` tags to replace content from the parent
template. Anything in the child template not contained in a ``block``
tag will be ignored. For an example, see the ``{% block %}`` tag.
``{% for *var* in *expr* %}...{% end %}``
Same as the python ``for`` statement. ``{% break %}`` and
``{% continue %}`` may be used inside the loop.
``{% from *x* import *y* %}``
Same as the python ``import`` statement.
``{% if *condition* %}...{% elif *condition* %}...{% else %}...{% end %}``
Conditional statement - outputs the first section whose condition is
true. (The ``elif`` and ``else`` sections are optional)
``{% import *module* %}``
Same as the python ``import`` statement.
``{% include *filename* %}``
Includes another template file. The included file can see all the local
variables as if it were copied directly to the point of the ``include``
directive (the ``{% autoescape %}`` directive is an exception).
Alternately, ``{% module Template(filename, **kwargs) %}`` may be used
to include another template with an isolated namespace.
``{% module *expr* %}``
Renders a `~tornado.web.UIModule`. The output of the ``UIModule`` is
not escaped::
{% module Template("foo.html", arg=42) %}
``UIModules`` are a feature of the `tornado.web.RequestHandler`
class (and specifically its ``render`` method) and will not work
when the template system is used on its own in other contexts.
``{% raw *expr* %}``
Outputs the result of the given expression without autoescaping.
``{% set *x* = *y* %}``
Sets a local variable.
``{% try %}...{% except %}...{% else %}...{% finally %}...{% end %}``
Same as the python ``try`` statement.
``{% while *condition* %}... {% end %}``
Same as the python ``while`` statement. ``{% break %}`` and
``{% continue %}`` may be used inside the loop.
``{% whitespace *mode* %}``
Sets the whitespace mode for the remainder of the current file
(or until the next ``{% whitespace %}`` directive). See
`filter_whitespace` for available options. New in Tornado 4.3.
"""
from __future__ import absolute_import, division, print_function
import datetime
import linecache
import os.path
import posixpath
import re
import threading
from tornado import escape
from tornado.log import app_log
from tornado.util import ObjectDict, exec_in, unicode_type, PY3
if PY3:
from io import StringIO
else:
from cStringIO import StringIO
_DEFAULT_AUTOESCAPE = "xhtml_escape"
_UNSET = object()
def filter_whitespace(mode, text):
"""Transform whitespace in ``text`` according to ``mode``.
Available modes are:
* ``all``: Return all whitespace unmodified.
* ``single``: Collapse consecutive whitespace with a single whitespace
character, preserving newlines.
* ``oneline``: Collapse all runs of whitespace into a single space
character, removing all newlines in the process.
.. versionadded:: 4.3
"""
if mode == 'all':
return text
elif mode == 'single':
text = re.sub(r"([\t ]+)", " ", text)
text = re.sub(r"(\s*\n\s*)", "\n", text)
return text
elif mode == 'oneline':
return re.sub(r"(\s+)", " ", text)
else:
raise Exception("invalid whitespace mode %s" % mode)
class Template(object):
"""A compiled template.
We compile into Python from the given template_string. You can generate
the template from variables with generate().
"""
# note that the constructor's signature is not extracted with
# autodoc because _UNSET looks like garbage. When changing
# this signature update website/sphinx/template.rst too.
def __init__(self, template_string, name="<string>", loader=None,
compress_whitespace=_UNSET, autoescape=_UNSET,
whitespace=None):
"""Construct a Template.
:arg str template_string: the contents of the template file.
:arg str name: the filename from which the template was loaded
(used for error message).
:arg tornado.template.BaseLoader loader: the `~tornado.template.BaseLoader` responsible
            for this template, used to resolve ``{% include %}`` and ``{% extends %}`` directives.
:arg bool compress_whitespace: Deprecated since Tornado 4.3.
Equivalent to ``whitespace="single"`` if true and
``whitespace="all"`` if false.
:arg str autoescape: The name of a function in the template
namespace, or ``None`` to disable escaping by default.
:arg str whitespace: A string specifying treatment of whitespace;
see `filter_whitespace` for options.
.. versionchanged:: 4.3
Added ``whitespace`` parameter; deprecated ``compress_whitespace``.
"""
self.name = escape.native_str(name)
if compress_whitespace is not _UNSET:
# Convert deprecated compress_whitespace (bool) to whitespace (str).
if whitespace is not None:
raise Exception("cannot set both whitespace and compress_whitespace")
whitespace = "single" if compress_whitespace else "all"
if whitespace is None:
if loader and loader.whitespace:
whitespace = loader.whitespace
else:
# Whitespace defaults by filename.
if name.endswith(".html") or name.endswith(".js"):
whitespace = "single"
else:
whitespace = "all"
# Validate the whitespace setting.
filter_whitespace(whitespace, '')
if autoescape is not _UNSET:
self.autoescape = autoescape
elif loader:
self.autoescape = loader.autoescape
else:
self.autoescape = _DEFAULT_AUTOESCAPE
self.namespace = loader.namespace if loader else {}
reader = _TemplateReader(name, escape.native_str(template_string),
whitespace)
self.file = _File(self, _parse(reader, self))
self.code = self._generate_python(loader)
self.loader = loader
try:
# Under python2.5, the fake filename used here must match
# the module name used in __name__ below.
# The dont_inherit flag prevents template.py's future imports
# from being applied to the generated code.
self.compiled = compile(
escape.to_unicode(self.code),
"%s.generated.py" % self.name.replace('.', '_'),
"exec", dont_inherit=True)
except Exception:
formatted_code = _format_code(self.code).rstrip()
app_log.error("%s code:\n%s", self.name, formatted_code)
raise
def generate(self, **kwargs):
"""Generate this template with the given arguments."""
namespace = {
"escape": escape.xhtml_escape,
"xhtml_escape": escape.xhtml_escape,
"url_escape": escape.url_escape,
"json_encode": escape.json_encode,
"squeeze": escape.squeeze,
"linkify": escape.linkify,
"datetime": datetime,
"_tt_utf8": escape.utf8, # for internal use
"_tt_string_types": (unicode_type, bytes),
# __name__ and __loader__ allow the traceback mechanism to find
# the generated source code.
"__name__": self.name.replace('.', '_'),
"__loader__": ObjectDict(get_source=lambda name: self.code),
}
namespace.update(self.namespace)
namespace.update(kwargs)
exec_in(self.compiled, namespace)
execute = namespace["_tt_execute"]
# Clear the traceback module's cache of source data now that
# we've generated a new template (mainly for this module's
# unittests, where different tests reuse the same name).
linecache.clearcache()
return execute()
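    # Example (editor's sketch) of direct use; autoescaping applies to the
    # interpolated values and the result is a byte string:
    #   t = Template("{{ name }} is {{ age }}")
    #   t.generate(name="Ben", age=7)  -> b'Ben is 7'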
def _generate_python(self, loader):
buffer = StringIO()
try:
# named_blocks maps from names to _NamedBlock objects
named_blocks = {}
ancestors = self._get_ancestors(loader)
ancestors.reverse()
for ancestor in ancestors:
ancestor.find_named_blocks(loader, named_blocks)
writer = _CodeWriter(buffer, named_blocks, loader,
ancestors[0].template)
ancestors[0].generate(writer)
return buffer.getvalue()
finally:
buffer.close()
def _get_ancestors(self, loader):
ancestors = [self.file]
for chunk in self.file.body.chunks:
if isinstance(chunk, _ExtendsBlock):
if not loader:
raise ParseError("{% extends %} block found, but no "
"template loader")
template = loader.load(chunk.name, self.name)
ancestors.extend(template._get_ancestors(loader))
return ancestors
class BaseLoader(object):
"""Base class for template loaders.
You must use a template loader to use template constructs like
``{% extends %}`` and ``{% include %}``. The loader caches all
templates after they are loaded the first time.
"""
def __init__(self, autoescape=_DEFAULT_AUTOESCAPE, namespace=None,
whitespace=None):
"""Construct a template loader.
:arg str autoescape: The name of a function in the template
namespace, such as "xhtml_escape", or ``None`` to disable
autoescaping by default.
:arg dict namespace: A dictionary to be added to the default template
namespace, or ``None``.
:arg str whitespace: A string specifying default behavior for
whitespace in templates; see `filter_whitespace` for options.
Default is "single" for files ending in ".html" and ".js" and
"all" for other files.
.. versionchanged:: 4.3
Added ``whitespace`` parameter.
"""
self.autoescape = autoescape
self.namespace = namespace or {}
self.whitespace = whitespace
self.templates = {}
# self.lock protects self.templates. It's a reentrant lock
# because templates may load other templates via `include` or
# `extends`. Note that thanks to the GIL this code would be safe
# even without the lock, but could lead to wasted work as multiple
# threads tried to compile the same template simultaneously.
self.lock = threading.RLock()
def reset(self):
"""Resets the cache of compiled templates."""
with self.lock:
self.templates = {}
def resolve_path(self, name, parent_path=None):
"""Converts a possibly-relative path to absolute (used internally)."""
raise NotImplementedError()
def load(self, name, parent_path=None):
"""Loads a template."""
name = self.resolve_path(name, parent_path=parent_path)
with self.lock:
if name not in self.templates:
self.templates[name] = self._create_template(name)
return self.templates[name]
def _create_template(self, name):
raise NotImplementedError()
class Loader(BaseLoader):
"""A template loader that loads from a single root directory.
"""
def __init__(self, root_directory, **kwargs):
super(Loader, self).__init__(**kwargs)
self.root = os.path.abspath(root_directory)
def resolve_path(self, name, parent_path=None):
if parent_path and not parent_path.startswith("<") and \
not parent_path.startswith("/") and \
not name.startswith("/"):
current_path = os.path.join(self.root, parent_path)
file_dir = os.path.dirname(os.path.abspath(current_path))
relative_path = os.path.abspath(os.path.join(file_dir, name))
if relative_path.startswith(self.root):
name = relative_path[len(self.root) + 1:]
return name
def _create_template(self, name):
path = os.path.join(self.root, name)
with open(path, "rb") as f:
template = Template(f.read(), name=name, loader=self)
return template
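# Example (editor's sketch; the directory path is hypothetical):
#   loader = Loader("/var/www/templates")
#   t = loader.load("mypage.html")  # compiled once, then served from cache
#   html = t.generate(name="Ben")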
class DictLoader(BaseLoader):
"""A template loader that loads from a dictionary."""
def __init__(self, dict, **kwargs):
super(DictLoader, self).__init__(**kwargs)
self.dict = dict
def resolve_path(self, name, parent_path=None):
if parent_path and not parent_path.startswith("<") and \
not parent_path.startswith("/") and \
not name.startswith("/"):
file_dir = posixpath.dirname(parent_path)
name = posixpath.normpath(posixpath.join(file_dir, name))
return name
def _create_template(self, name):
return Template(self.dict[name], name=name, loader=self)
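# Example (editor's sketch): relative names resolve against the including
# template, so "partials/a.html" including "b.html" loads "partials/b.html":
#   loader = DictLoader({"partials/a.html": '{% include "b.html" %}',
#                        "partials/b.html": "hi"})
#   loader.load("partials/a.html").generate()  -> b'hi'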
class _Node(object):
def each_child(self):
return ()
def generate(self, writer):
raise NotImplementedError()
def find_named_blocks(self, loader, named_blocks):
for child in self.each_child():
child.find_named_blocks(loader, named_blocks)
class _File(_Node):
def __init__(self, template, body):
self.template = template
self.body = body
self.line = 0
def generate(self, writer):
writer.write_line("def _tt_execute():", self.line)
with writer.indent():
writer.write_line("_tt_buffer = []", self.line)
writer.write_line("_tt_append = _tt_buffer.append", self.line)
self.body.generate(writer)
writer.write_line("return _tt_utf8('').join(_tt_buffer)", self.line)
def each_child(self):
return (self.body,)
class _ChunkList(_Node):
def __init__(self, chunks):
self.chunks = chunks
def generate(self, writer):
for chunk in self.chunks:
chunk.generate(writer)
def each_child(self):
return self.chunks
class _NamedBlock(_Node):
def __init__(self, name, body, template, line):
self.name = name
self.body = body
self.template = template
self.line = line
def each_child(self):
return (self.body,)
def generate(self, writer):
block = writer.named_blocks[self.name]
with writer.include(block.template, self.line):
block.body.generate(writer)
def find_named_blocks(self, loader, named_blocks):
named_blocks[self.name] = self
_Node.find_named_blocks(self, loader, named_blocks)
class _ExtendsBlock(_Node):
def __init__(self, name):
self.name = name
class _IncludeBlock(_Node):
def __init__(self, name, reader, line):
self.name = name
self.template_name = reader.name
self.line = line
def find_named_blocks(self, loader, named_blocks):
included = loader.load(self.name, self.template_name)
included.file.find_named_blocks(loader, named_blocks)
def generate(self, writer):
included = writer.loader.load(self.name, self.template_name)
with writer.include(included, self.line):
included.file.body.generate(writer)
class _ApplyBlock(_Node):
def __init__(self, method, line, body=None):
self.method = method
self.line = line
self.body = body
def each_child(self):
return (self.body,)
def generate(self, writer):
method_name = "_tt_apply%d" % writer.apply_counter
writer.apply_counter += 1
writer.write_line("def %s():" % method_name, self.line)
with writer.indent():
writer.write_line("_tt_buffer = []", self.line)
writer.write_line("_tt_append = _tt_buffer.append", self.line)
self.body.generate(writer)
writer.write_line("return _tt_utf8('').join(_tt_buffer)", self.line)
writer.write_line("_tt_append(_tt_utf8(%s(%s())))" % (
self.method, method_name), self.line)
class _ControlBlock(_Node):
def __init__(self, statement, line, body=None):
self.statement = statement
self.line = line
self.body = body
def each_child(self):
return (self.body,)
def generate(self, writer):
writer.write_line("%s:" % self.statement, self.line)
with writer.indent():
self.body.generate(writer)
# Just in case the body was empty
writer.write_line("pass", self.line)
class _IntermediateControlBlock(_Node):
def __init__(self, statement, line):
self.statement = statement
self.line = line
def generate(self, writer):
# In case the previous block was empty
writer.write_line("pass", self.line)
writer.write_line("%s:" % self.statement, self.line, writer.indent_size() - 1)
class _Statement(_Node):
def __init__(self, statement, line):
self.statement = statement
self.line = line
def generate(self, writer):
writer.write_line(self.statement, self.line)
class _Expression(_Node):
def __init__(self, expression, line, raw=False):
self.expression = expression
self.line = line
self.raw = raw
def generate(self, writer):
writer.write_line("_tt_tmp = %s" % self.expression, self.line)
writer.write_line("if isinstance(_tt_tmp, _tt_string_types):"
" _tt_tmp = _tt_utf8(_tt_tmp)", self.line)
writer.write_line("else: _tt_tmp = _tt_utf8(str(_tt_tmp))", self.line)
if not self.raw and writer.current_template.autoescape is not None:
# In python3 functions like xhtml_escape return unicode,
# so we have to convert to utf8 again.
writer.write_line("_tt_tmp = _tt_utf8(%s(_tt_tmp))" %
writer.current_template.autoescape, self.line)
writer.write_line("_tt_append(_tt_tmp)", self.line)
class _Module(_Expression):
def __init__(self, expression, line):
super(_Module, self).__init__("_tt_modules." + expression, line,
raw=True)
class _Text(_Node):
def __init__(self, value, line, whitespace):
self.value = value
self.line = line
self.whitespace = whitespace
def generate(self, writer):
value = self.value
# Compress whitespace if requested, with a crude heuristic to avoid
# altering preformatted whitespace.
if "<pre>" not in value:
value = filter_whitespace(self.whitespace, value)
if value:
writer.write_line('_tt_append(%r)' % escape.utf8(value), self.line)
class ParseError(Exception):
"""Raised for template syntax errors.
``ParseError`` instances have ``filename`` and ``lineno`` attributes
indicating the position of the error.
.. versionchanged:: 4.3
Added ``filename`` and ``lineno`` attributes.
"""
def __init__(self, message, filename=None, lineno=0):
self.message = message
# The names "filename" and "lineno" are chosen for consistency
# with python SyntaxError.
self.filename = filename
self.lineno = lineno
def __str__(self):
return '%s at %s:%d' % (self.message, self.filename, self.lineno)
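# Example (editor's sketch): an unterminated block reports its position;
# Template("{% if x %}", name="t.html") raises a ParseError whose str() is
# 'Missing {% end %} block for if at t.html:1'.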
class _CodeWriter(object):
def __init__(self, file, named_blocks, loader, current_template):
self.file = file
self.named_blocks = named_blocks
self.loader = loader
self.current_template = current_template
self.apply_counter = 0
self.include_stack = []
self._indent = 0
def indent_size(self):
return self._indent
def indent(self):
class Indenter(object):
def __enter__(_):
self._indent += 1
return self
def __exit__(_, *args):
assert self._indent > 0
self._indent -= 1
return Indenter()
def include(self, template, line):
self.include_stack.append((self.current_template, line))
self.current_template = template
class IncludeTemplate(object):
def __enter__(_):
return self
def __exit__(_, *args):
self.current_template = self.include_stack.pop()[0]
return IncludeTemplate()
def write_line(self, line, line_number, indent=None):
if indent is None:
indent = self._indent
line_comment = ' # %s:%d' % (self.current_template.name, line_number)
if self.include_stack:
ancestors = ["%s:%d" % (tmpl.name, lineno)
for (tmpl, lineno) in self.include_stack]
line_comment += ' (via %s)' % ', '.join(reversed(ancestors))
print(" " * indent + line + line_comment, file=self.file)
class _TemplateReader(object):
def __init__(self, name, text, whitespace):
self.name = name
self.text = text
self.whitespace = whitespace
self.line = 1
self.pos = 0
def find(self, needle, start=0, end=None):
assert start >= 0, start
pos = self.pos
start += pos
if end is None:
index = self.text.find(needle, start)
else:
end += pos
assert end >= start
index = self.text.find(needle, start, end)
if index != -1:
index -= pos
return index
def consume(self, count=None):
if count is None:
count = len(self.text) - self.pos
newpos = self.pos + count
self.line += self.text.count("\n", self.pos, newpos)
s = self.text[self.pos:newpos]
self.pos = newpos
return s
def remaining(self):
return len(self.text) - self.pos
def __len__(self):
return self.remaining()
def __getitem__(self, key):
if type(key) is slice:
size = len(self)
start, stop, step = key.indices(size)
if start is None:
start = self.pos
else:
start += self.pos
if stop is not None:
stop += self.pos
return self.text[slice(start, stop, step)]
elif key < 0:
return self.text[key]
else:
return self.text[self.pos + key]
def __str__(self):
return self.text[self.pos:]
def raise_parse_error(self, msg):
raise ParseError(msg, self.name, self.line)
def _format_code(code):
lines = code.splitlines()
format = "%%%dd %%s\n" % len(repr(len(lines) + 1))
return "".join([format % (i + 1, line) for (i, line) in enumerate(lines)])
def _parse(reader, template, in_block=None, in_loop=None):
body = _ChunkList([])
while True:
# Find next template directive
curly = 0
while True:
curly = reader.find("{", curly)
if curly == -1 or curly + 1 == reader.remaining():
# EOF
if in_block:
reader.raise_parse_error(
"Missing {%% end %%} block for %s" % in_block)
body.chunks.append(_Text(reader.consume(), reader.line,
reader.whitespace))
return body
# If the first curly brace is not the start of a special token,
# start searching from the character after it
if reader[curly + 1] not in ("{", "%", "#"):
curly += 1
continue
# When there are more than 2 curlies in a row, use the
# innermost ones. This is useful when generating languages
# like latex where curlies are also meaningful
if (curly + 2 < reader.remaining() and
reader[curly + 1] == '{' and reader[curly + 2] == '{'):
curly += 1
continue
break
# Append any text before the special token
if curly > 0:
cons = reader.consume(curly)
body.chunks.append(_Text(cons, reader.line,
reader.whitespace))
start_brace = reader.consume(2)
line = reader.line
# Template directives may be escaped as "{{!" or "{%!".
# In this case output the braces and consume the "!".
# This is especially useful in conjunction with jquery templates,
# which also use double braces.
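        # Example (editor's sketch): the input "{{! x }}" is emitted as the
        # literal text "{{ x }}" instead of evaluating x.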
if reader.remaining() and reader[0] == "!":
reader.consume(1)
body.chunks.append(_Text(start_brace, line,
reader.whitespace))
continue
# Comment
if start_brace == "{#":
end = reader.find("#}")
if end == -1:
reader.raise_parse_error("Missing end comment #}")
contents = reader.consume(end).strip()
reader.consume(2)
continue
# Expression
if start_brace == "{{":
end = reader.find("}}")
if end == -1:
reader.raise_parse_error("Missing end expression }}")
contents = reader.consume(end).strip()
reader.consume(2)
if not contents:
reader.raise_parse_error("Empty expression")
body.chunks.append(_Expression(contents, line))
continue
# Block
assert start_brace == "{%", start_brace
end = reader.find("%}")
if end == -1:
reader.raise_parse_error("Missing end block %}")
contents = reader.consume(end).strip()
reader.consume(2)
if not contents:
reader.raise_parse_error("Empty block tag ({% %})")
operator, space, suffix = contents.partition(" ")
suffix = suffix.strip()
# Intermediate ("else", "elif", etc) blocks
intermediate_blocks = {
"else": set(["if", "for", "while", "try"]),
"elif": set(["if"]),
"except": set(["try"]),
"finally": set(["try"]),
}
allowed_parents = intermediate_blocks.get(operator)
if allowed_parents is not None:
if not in_block:
reader.raise_parse_error("%s outside %s block" %
(operator, allowed_parents))
if in_block not in allowed_parents:
reader.raise_parse_error(
"%s block cannot be attached to %s block" %
(operator, in_block))
body.chunks.append(_IntermediateControlBlock(contents, line))
continue
# End tag
elif operator == "end":
if not in_block:
reader.raise_parse_error("Extra {% end %} block")
return body
elif operator in ("extends", "include", "set", "import", "from",
"comment", "autoescape", "whitespace", "raw",
"module"):
if operator == "comment":
continue
if operator == "extends":
suffix = suffix.strip('"').strip("'")
if not suffix:
reader.raise_parse_error("extends missing file path")
block = _ExtendsBlock(suffix)
elif operator in ("import", "from"):
if not suffix:
reader.raise_parse_error("import missing statement")
block = _Statement(contents, line)
elif operator == "include":
suffix = suffix.strip('"').strip("'")
if not suffix:
reader.raise_parse_error("include missing file path")
block = _IncludeBlock(suffix, reader, line)
elif operator == "set":
if not suffix:
reader.raise_parse_error("set missing statement")
block = _Statement(suffix, line)
elif operator == "autoescape":
fn = suffix.strip()
if fn == "None":
fn = None
template.autoescape = fn
continue
elif operator == "whitespace":
mode = suffix.strip()
# Validate the selected mode
filter_whitespace(mode, '')
reader.whitespace = mode
continue
elif operator == "raw":
block = _Expression(suffix, line, raw=True)
elif operator == "module":
block = _Module(suffix, line)
body.chunks.append(block)
continue
elif operator in ("apply", "block", "try", "if", "for", "while"):
# parse inner body recursively
if operator in ("for", "while"):
block_body = _parse(reader, template, operator, operator)
elif operator == "apply":
# apply creates a nested function so syntactically it's not
# in the loop.
block_body = _parse(reader, template, operator, None)
else:
block_body = _parse(reader, template, operator, in_loop)
if operator == "apply":
if not suffix:
reader.raise_parse_error("apply missing method name")
block = _ApplyBlock(suffix, line, block_body)
elif operator == "block":
if not suffix:
reader.raise_parse_error("block missing name")
block = _NamedBlock(suffix, block_body, template, line)
else:
block = _ControlBlock(contents, line, block_body)
body.chunks.append(block)
continue
elif operator in ("break", "continue"):
if not in_loop:
reader.raise_parse_error("%s outside %s block" %
(operator, set(["for", "while"])))
body.chunks.append(_Statement(contents, line))
continue
else:
reader.raise_parse_error("unknown operator: %r" % operator)

lib/tornado/test/__init__.py Executable file

lib/tornado/test/__main__.py Executable file
@@ -0,0 +1,14 @@
"""Shim to allow python -m tornado.test.
This only works in python 2.7+.
"""
from __future__ import absolute_import, division, print_function
from tornado.test.runtests import all, main
# tornado.testing.main autodiscovery relies on 'all' being present in
# the main module, so import it here even though it is not used directly.
# The following line prevents a pyflakes warning.
all = all
main()

lib/tornado/test/asyncio_test.py Executable file
@@ -0,0 +1,206 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from __future__ import absolute_import, division, print_function
from concurrent.futures import ThreadPoolExecutor
from tornado import gen
from tornado.ioloop import IOLoop
from tornado.testing import AsyncTestCase, gen_test
from tornado.test.util import unittest, skipBefore33, skipBefore35, exec_test
try:
from tornado.platform.asyncio import asyncio
except ImportError:
asyncio = None
else:
from tornado.platform.asyncio import AsyncIOLoop, to_asyncio_future, AnyThreadEventLoopPolicy
# This is used in dynamically-evaluated code, so silence pyflakes.
to_asyncio_future
@unittest.skipIf(asyncio is None, "asyncio module not present")
class AsyncIOLoopTest(AsyncTestCase):
def get_new_ioloop(self):
io_loop = AsyncIOLoop()
return io_loop
def test_asyncio_callback(self):
# Basic test that the asyncio loop is set up correctly.
asyncio.get_event_loop().call_soon(self.stop)
self.wait()
@gen_test
def test_asyncio_future(self):
# Test that we can yield an asyncio future from a tornado coroutine.
# Without 'yield from', we must wrap coroutines in ensure_future,
# which was introduced during Python 3.4, deprecating the prior "async".
if hasattr(asyncio, 'ensure_future'):
ensure_future = asyncio.ensure_future
else:
# async is a reserved word in Python 3.7
ensure_future = getattr(asyncio, 'async')
x = yield ensure_future(
asyncio.get_event_loop().run_in_executor(None, lambda: 42))
self.assertEqual(x, 42)
@skipBefore33
@gen_test
def test_asyncio_yield_from(self):
# Test that we can use asyncio coroutines with 'yield from'
# instead of asyncio.async(). This requires python 3.3 syntax.
namespace = exec_test(globals(), locals(), """
@gen.coroutine
def f():
event_loop = asyncio.get_event_loop()
x = yield from event_loop.run_in_executor(None, lambda: 42)
return x
""")
result = yield namespace['f']()
self.assertEqual(result, 42)
@skipBefore35
def test_asyncio_adapter(self):
# This test demonstrates that when using the asyncio coroutine
# runner (i.e. run_until_complete), the to_asyncio_future
# adapter is needed. No adapter is needed in the other direction,
# as demonstrated by other tests in the package.
@gen.coroutine
def tornado_coroutine():
yield gen.moment
raise gen.Return(42)
native_coroutine_without_adapter = exec_test(globals(), locals(), """
async def native_coroutine_without_adapter():
return await tornado_coroutine()
""")["native_coroutine_without_adapter"]
native_coroutine_with_adapter = exec_test(globals(), locals(), """
async def native_coroutine_with_adapter():
return await to_asyncio_future(tornado_coroutine())
""")["native_coroutine_with_adapter"]
# Use the adapter, but two degrees from the tornado coroutine.
native_coroutine_with_adapter2 = exec_test(globals(), locals(), """
async def native_coroutine_with_adapter2():
return await to_asyncio_future(native_coroutine_without_adapter())
""")["native_coroutine_with_adapter2"]
# Tornado supports native coroutines both with and without adapters
self.assertEqual(
self.io_loop.run_sync(native_coroutine_without_adapter),
42)
self.assertEqual(
self.io_loop.run_sync(native_coroutine_with_adapter),
42)
self.assertEqual(
self.io_loop.run_sync(native_coroutine_with_adapter2),
42)
# Asyncio only supports coroutines that yield asyncio-compatible
# Futures (which our Future is since 5.0).
self.assertEqual(
asyncio.get_event_loop().run_until_complete(
native_coroutine_without_adapter()),
42)
self.assertEqual(
asyncio.get_event_loop().run_until_complete(
native_coroutine_with_adapter()),
42)
self.assertEqual(
asyncio.get_event_loop().run_until_complete(
native_coroutine_with_adapter2()),
42)
@unittest.skipIf(asyncio is None, "asyncio module not present")
class LeakTest(unittest.TestCase):
def setUp(self):
# Trigger a cleanup of the mapping so we start with a clean slate.
AsyncIOLoop().close()
# If we don't clean up after ourselves other tests may fail on
# py34.
self.orig_policy = asyncio.get_event_loop_policy()
asyncio.set_event_loop_policy(asyncio.DefaultEventLoopPolicy())
def tearDown(self):
asyncio.get_event_loop().close()
asyncio.set_event_loop_policy(self.orig_policy)
def test_ioloop_close_leak(self):
orig_count = len(IOLoop._ioloop_for_asyncio)
for i in range(10):
# Create and close an AsyncIOLoop using Tornado interfaces.
loop = AsyncIOLoop()
loop.close()
new_count = len(IOLoop._ioloop_for_asyncio) - orig_count
self.assertEqual(new_count, 0)
def test_asyncio_close_leak(self):
orig_count = len(IOLoop._ioloop_for_asyncio)
for i in range(10):
# Create and close an AsyncIOMainLoop using asyncio interfaces.
loop = asyncio.new_event_loop()
loop.call_soon(IOLoop.current)
loop.call_soon(loop.stop)
loop.run_forever()
loop.close()
new_count = len(IOLoop._ioloop_for_asyncio) - orig_count
# Because the cleanup is run on new loop creation, we have one
# dangling entry in the map (but only one).
self.assertEqual(new_count, 1)
@unittest.skipIf(asyncio is None, "asyncio module not present")
class AnyThreadEventLoopPolicyTest(unittest.TestCase):
def setUp(self):
self.orig_policy = asyncio.get_event_loop_policy()
self.executor = ThreadPoolExecutor(1)
def tearDown(self):
asyncio.set_event_loop_policy(self.orig_policy)
self.executor.shutdown()
def get_event_loop_on_thread(self):
def get_and_close_event_loop():
"""Get the event loop. Close it if one is returned.
Returns the (closed) event loop. This is a silly thing
to do and leaves the thread in a broken state, but it's
enough for this test. Closing the loop avoids resource
leak warnings.
"""
loop = asyncio.get_event_loop()
loop.close()
return loop
future = self.executor.submit(get_and_close_event_loop)
return future.result()
def run_policy_test(self, accessor, expected_type):
# With the default policy, non-main threads don't get an event
# loop.
self.assertRaises((RuntimeError, AssertionError),
self.executor.submit(accessor).result)
# Set the policy and we can get a loop.
asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
self.assertIsInstance(
self.executor.submit(accessor).result(),
expected_type)
# Clean up to silence leak warnings. Always use asyncio since
# IOLoop doesn't (currently) close the underlying loop.
self.executor.submit(lambda: asyncio.get_event_loop().close()).result()
def test_asyncio_accessor(self):
self.run_policy_test(asyncio.get_event_loop, asyncio.AbstractEventLoop)
def test_tornado_accessor(self):
self.run_policy_test(IOLoop.current, IOLoop)

lib/tornado/test/auth_test.py Executable file
@@ -0,0 +1,735 @@
# These tests do not currently do much to verify the correct implementation
# of the openid/oauth protocols, they just exercise the major code paths
# and ensure that it doesn't blow up (e.g. with unicode/bytes issues in
# python 3)
from __future__ import absolute_import, division, print_function
import unittest
import warnings
from tornado.auth import (
AuthError, OpenIdMixin, OAuthMixin, OAuth2Mixin,
GoogleOAuth2Mixin, FacebookGraphMixin, TwitterMixin,
)
from tornado.concurrent import Future
from tornado.escape import json_decode
from tornado import gen
from tornado.httputil import url_concat
from tornado.log import gen_log, app_log
from tornado.testing import AsyncHTTPTestCase, ExpectLog
from tornado.test.util import ignore_deprecation
from tornado.web import RequestHandler, Application, asynchronous, HTTPError
try:
from unittest import mock
except ImportError:
mock = None
class OpenIdClientLoginHandlerLegacy(RequestHandler, OpenIdMixin):
def initialize(self, test):
self._OPENID_ENDPOINT = test.get_url('/openid/server/authenticate')
with ignore_deprecation():
@asynchronous
def get(self):
if self.get_argument('openid.mode', None):
with warnings.catch_warnings():
warnings.simplefilter('ignore', DeprecationWarning)
self.get_authenticated_user(
self.on_user, http_client=self.settings['http_client'])
return
res = self.authenticate_redirect()
assert isinstance(res, Future)
assert res.done()
def on_user(self, user):
if user is None:
raise Exception("user is None")
self.finish(user)
class OpenIdClientLoginHandler(RequestHandler, OpenIdMixin):
def initialize(self, test):
self._OPENID_ENDPOINT = test.get_url('/openid/server/authenticate')
@gen.coroutine
def get(self):
if self.get_argument('openid.mode', None):
user = yield self.get_authenticated_user(http_client=self.settings['http_client'])
if user is None:
raise Exception("user is None")
self.finish(user)
return
res = self.authenticate_redirect()
assert isinstance(res, Future)
assert res.done()
class OpenIdServerAuthenticateHandler(RequestHandler):
def post(self):
if self.get_argument('openid.mode') != 'check_authentication':
            raise Exception("incorrect openid.mode %r" % self.get_argument('openid.mode'))
self.write('is_valid:true')
class OAuth1ClientLoginHandlerLegacy(RequestHandler, OAuthMixin):
def initialize(self, test, version):
self._OAUTH_VERSION = version
self._OAUTH_REQUEST_TOKEN_URL = test.get_url('/oauth1/server/request_token')
self._OAUTH_AUTHORIZE_URL = test.get_url('/oauth1/server/authorize')
self._OAUTH_ACCESS_TOKEN_URL = test.get_url('/oauth1/server/access_token')
def _oauth_consumer_token(self):
return dict(key='asdf', secret='qwer')
with ignore_deprecation():
@asynchronous
def get(self):
if self.get_argument('oauth_token', None):
with warnings.catch_warnings():
warnings.simplefilter('ignore', DeprecationWarning)
self.get_authenticated_user(
self.on_user, http_client=self.settings['http_client'])
return
res = self.authorize_redirect(http_client=self.settings['http_client'])
assert isinstance(res, Future)
def on_user(self, user):
if user is None:
raise Exception("user is None")
self.finish(user)
def _oauth_get_user(self, access_token, callback):
if self.get_argument('fail_in_get_user', None):
raise Exception("failing in get_user")
if access_token != dict(key='uiop', secret='5678'):
raise Exception("incorrect access token %r" % access_token)
callback(dict(email='foo@example.com'))
class OAuth1ClientLoginHandler(RequestHandler, OAuthMixin):
def initialize(self, test, version):
self._OAUTH_VERSION = version
self._OAUTH_REQUEST_TOKEN_URL = test.get_url('/oauth1/server/request_token')
self._OAUTH_AUTHORIZE_URL = test.get_url('/oauth1/server/authorize')
self._OAUTH_ACCESS_TOKEN_URL = test.get_url('/oauth1/server/access_token')
def _oauth_consumer_token(self):
return dict(key='asdf', secret='qwer')
@gen.coroutine
def get(self):
if self.get_argument('oauth_token', None):
user = yield self.get_authenticated_user(http_client=self.settings['http_client'])
if user is None:
raise Exception("user is None")
self.finish(user)
return
yield self.authorize_redirect(http_client=self.settings['http_client'])
@gen.coroutine
def _oauth_get_user_future(self, access_token):
if self.get_argument('fail_in_get_user', None):
raise Exception("failing in get_user")
if access_token != dict(key='uiop', secret='5678'):
raise Exception("incorrect access token %r" % access_token)
return dict(email='foo@example.com')
class OAuth1ClientLoginCoroutineHandler(OAuth1ClientLoginHandler):
"""Replaces OAuth1ClientLoginCoroutineHandler's get() with a coroutine."""
@gen.coroutine
def get(self):
if self.get_argument('oauth_token', None):
# Ensure that any exceptions are set on the returned Future,
# not simply thrown into the surrounding StackContext.
try:
yield self.get_authenticated_user()
except Exception as e:
self.set_status(503)
self.write("got exception: %s" % e)
else:
yield self.authorize_redirect()
class OAuth1ClientRequestParametersHandler(RequestHandler, OAuthMixin):
def initialize(self, version):
self._OAUTH_VERSION = version
def _oauth_consumer_token(self):
return dict(key='asdf', secret='qwer')
def get(self):
params = self._oauth_request_parameters(
'http://www.example.com/api/asdf',
dict(key='uiop', secret='5678'),
parameters=dict(foo='bar'))
self.write(params)
class OAuth1ServerRequestTokenHandler(RequestHandler):
def get(self):
self.write('oauth_token=zxcv&oauth_token_secret=1234')
class OAuth1ServerAccessTokenHandler(RequestHandler):
def get(self):
self.write('oauth_token=uiop&oauth_token_secret=5678')
class OAuth2ClientLoginHandler(RequestHandler, OAuth2Mixin):
def initialize(self, test):
self._OAUTH_AUTHORIZE_URL = test.get_url('/oauth2/server/authorize')
def get(self):
res = self.authorize_redirect()
assert isinstance(res, Future)
assert res.done()
class FacebookClientLoginHandler(RequestHandler, FacebookGraphMixin):
def initialize(self, test):
self._OAUTH_AUTHORIZE_URL = test.get_url('/facebook/server/authorize')
self._OAUTH_ACCESS_TOKEN_URL = test.get_url('/facebook/server/access_token')
self._FACEBOOK_BASE_URL = test.get_url('/facebook/server')
@gen.coroutine
def get(self):
if self.get_argument("code", None):
user = yield self.get_authenticated_user(
redirect_uri=self.request.full_url(),
client_id=self.settings["facebook_api_key"],
client_secret=self.settings["facebook_secret"],
code=self.get_argument("code"))
self.write(user)
else:
yield self.authorize_redirect(
redirect_uri=self.request.full_url(),
client_id=self.settings["facebook_api_key"],
extra_params={"scope": "read_stream,offline_access"})
class FacebookServerAccessTokenHandler(RequestHandler):
def get(self):
self.write(dict(access_token="asdf", expires_in=3600))
class FacebookServerMeHandler(RequestHandler):
def get(self):
self.write('{}')
class TwitterClientHandler(RequestHandler, TwitterMixin):
def initialize(self, test):
self._OAUTH_REQUEST_TOKEN_URL = test.get_url('/oauth1/server/request_token')
self._OAUTH_ACCESS_TOKEN_URL = test.get_url('/twitter/server/access_token')
self._OAUTH_AUTHORIZE_URL = test.get_url('/oauth1/server/authorize')
self._OAUTH_AUTHENTICATE_URL = test.get_url('/twitter/server/authenticate')
self._TWITTER_BASE_URL = test.get_url('/twitter/api')
def get_auth_http_client(self):
return self.settings['http_client']
class TwitterClientLoginHandlerLegacy(TwitterClientHandler):
with ignore_deprecation():
@asynchronous
def get(self):
if self.get_argument("oauth_token", None):
self.get_authenticated_user(self.on_user)
return
self.authorize_redirect()
def on_user(self, user):
if user is None:
raise Exception("user is None")
self.finish(user)
class TwitterClientLoginHandler(TwitterClientHandler):
@gen.coroutine
def get(self):
if self.get_argument("oauth_token", None):
user = yield self.get_authenticated_user()
if user is None:
raise Exception("user is None")
self.finish(user)
return
yield self.authorize_redirect()
class TwitterClientAuthenticateHandler(TwitterClientHandler):
# Like TwitterClientLoginHandler, but uses authenticate_redirect
# instead of authorize_redirect.
@gen.coroutine
def get(self):
if self.get_argument("oauth_token", None):
user = yield self.get_authenticated_user()
if user is None:
raise Exception("user is None")
self.finish(user)
return
yield self.authenticate_redirect()
class TwitterClientLoginGenEngineHandler(TwitterClientHandler):
with ignore_deprecation():
@asynchronous
@gen.engine
def get(self):
if self.get_argument("oauth_token", None):
user = yield self.get_authenticated_user()
self.finish(user)
else:
# Old style: with @gen.engine we can ignore the Future from
# authorize_redirect.
self.authorize_redirect()
class TwitterClientLoginGenCoroutineHandler(TwitterClientHandler):
@gen.coroutine
def get(self):
if self.get_argument("oauth_token", None):
user = yield self.get_authenticated_user()
self.finish(user)
else:
# New style: with @gen.coroutine the result must be yielded
# or else the request will be auto-finished too soon.
yield self.authorize_redirect()
class TwitterClientShowUserHandlerLegacy(TwitterClientHandler):
with ignore_deprecation():
@asynchronous
@gen.engine
def get(self):
# TODO: would be nice to go through the login flow instead of
# cheating with a hard-coded access token.
with warnings.catch_warnings():
warnings.simplefilter('ignore', DeprecationWarning)
response = yield gen.Task(self.twitter_request,
'/users/show/%s' % self.get_argument('name'),
access_token=dict(key='hjkl', secret='vbnm'))
if response is None:
self.set_status(500)
self.finish('error from twitter request')
else:
self.finish(response)
class TwitterClientShowUserHandler(TwitterClientHandler):
@gen.coroutine
def get(self):
# TODO: would be nice to go through the login flow instead of
# cheating with a hard-coded access token.
try:
response = yield self.twitter_request(
'/users/show/%s' % self.get_argument('name'),
access_token=dict(key='hjkl', secret='vbnm'))
except AuthError:
self.set_status(500)
self.finish('error from twitter request')
else:
self.finish(response)
class TwitterServerAccessTokenHandler(RequestHandler):
def get(self):
self.write('oauth_token=hjkl&oauth_token_secret=vbnm&screen_name=foo')
class TwitterServerShowUserHandler(RequestHandler):
def get(self, screen_name):
if screen_name == 'error':
raise HTTPError(500)
assert 'oauth_nonce' in self.request.arguments
assert 'oauth_timestamp' in self.request.arguments
assert 'oauth_signature' in self.request.arguments
assert self.get_argument('oauth_consumer_key') == 'test_twitter_consumer_key'
assert self.get_argument('oauth_signature_method') == 'HMAC-SHA1'
assert self.get_argument('oauth_version') == '1.0'
assert self.get_argument('oauth_token') == 'hjkl'
self.write(dict(screen_name=screen_name, name=screen_name.capitalize()))
class TwitterServerVerifyCredentialsHandler(RequestHandler):
def get(self):
assert 'oauth_nonce' in self.request.arguments
assert 'oauth_timestamp' in self.request.arguments
assert 'oauth_signature' in self.request.arguments
assert self.get_argument('oauth_consumer_key') == 'test_twitter_consumer_key'
assert self.get_argument('oauth_signature_method') == 'HMAC-SHA1'
assert self.get_argument('oauth_version') == '1.0'
assert self.get_argument('oauth_token') == 'hjkl'
self.write(dict(screen_name='foo', name='Foo'))
class AuthTest(AsyncHTTPTestCase):
def get_app(self):
return Application(
[
# test endpoints
('/legacy/openid/client/login', OpenIdClientLoginHandlerLegacy, dict(test=self)),
('/openid/client/login', OpenIdClientLoginHandler, dict(test=self)),
('/legacy/oauth10/client/login', OAuth1ClientLoginHandlerLegacy,
dict(test=self, version='1.0')),
('/oauth10/client/login', OAuth1ClientLoginHandler,
dict(test=self, version='1.0')),
('/oauth10/client/request_params',
OAuth1ClientRequestParametersHandler,
dict(version='1.0')),
('/legacy/oauth10a/client/login', OAuth1ClientLoginHandlerLegacy,
dict(test=self, version='1.0a')),
('/oauth10a/client/login', OAuth1ClientLoginHandler,
dict(test=self, version='1.0a')),
('/oauth10a/client/login_coroutine',
OAuth1ClientLoginCoroutineHandler,
dict(test=self, version='1.0a')),
('/oauth10a/client/request_params',
OAuth1ClientRequestParametersHandler,
dict(version='1.0a')),
('/oauth2/client/login', OAuth2ClientLoginHandler, dict(test=self)),
('/facebook/client/login', FacebookClientLoginHandler, dict(test=self)),
('/legacy/twitter/client/login', TwitterClientLoginHandlerLegacy, dict(test=self)),
('/twitter/client/login', TwitterClientLoginHandler, dict(test=self)),
('/twitter/client/authenticate', TwitterClientAuthenticateHandler, dict(test=self)),
('/twitter/client/login_gen_engine',
TwitterClientLoginGenEngineHandler, dict(test=self)),
('/twitter/client/login_gen_coroutine',
TwitterClientLoginGenCoroutineHandler, dict(test=self)),
('/legacy/twitter/client/show_user',
TwitterClientShowUserHandlerLegacy, dict(test=self)),
('/twitter/client/show_user',
TwitterClientShowUserHandler, dict(test=self)),
# simulated servers
('/openid/server/authenticate', OpenIdServerAuthenticateHandler),
('/oauth1/server/request_token', OAuth1ServerRequestTokenHandler),
('/oauth1/server/access_token', OAuth1ServerAccessTokenHandler),
('/facebook/server/access_token', FacebookServerAccessTokenHandler),
('/facebook/server/me', FacebookServerMeHandler),
('/twitter/server/access_token', TwitterServerAccessTokenHandler),
(r'/twitter/api/users/show/(.*)\.json', TwitterServerShowUserHandler),
(r'/twitter/api/account/verify_credentials\.json',
TwitterServerVerifyCredentialsHandler),
],
http_client=self.http_client,
twitter_consumer_key='test_twitter_consumer_key',
twitter_consumer_secret='test_twitter_consumer_secret',
facebook_api_key='test_facebook_api_key',
facebook_secret='test_facebook_secret')
def test_openid_redirect_legacy(self):
response = self.fetch('/legacy/openid/client/login', follow_redirects=False)
self.assertEqual(response.code, 302)
self.assertTrue(
'/openid/server/authenticate?' in response.headers['Location'])
def test_openid_get_user_legacy(self):
response = self.fetch('/legacy/openid/client/login?openid.mode=blah'
'&openid.ns.ax=http://openid.net/srv/ax/1.0'
'&openid.ax.type.email=http://axschema.org/contact/email'
'&openid.ax.value.email=foo@example.com')
response.rethrow()
parsed = json_decode(response.body)
self.assertEqual(parsed["email"], "foo@example.com")
def test_openid_redirect(self):
response = self.fetch('/openid/client/login', follow_redirects=False)
self.assertEqual(response.code, 302)
self.assertTrue(
'/openid/server/authenticate?' in response.headers['Location'])
def test_openid_get_user(self):
response = self.fetch('/openid/client/login?openid.mode=blah'
'&openid.ns.ax=http://openid.net/srv/ax/1.0'
'&openid.ax.type.email=http://axschema.org/contact/email'
'&openid.ax.value.email=foo@example.com')
response.rethrow()
parsed = json_decode(response.body)
self.assertEqual(parsed["email"], "foo@example.com")
def test_oauth10_redirect_legacy(self):
response = self.fetch('/legacy/oauth10/client/login', follow_redirects=False)
self.assertEqual(response.code, 302)
self.assertTrue(response.headers['Location'].endswith(
'/oauth1/server/authorize?oauth_token=zxcv'))
# the cookie is base64('zxcv')|base64('1234')
self.assertTrue(
'_oauth_request_token="enhjdg==|MTIzNA=="' in response.headers['Set-Cookie'],
response.headers['Set-Cookie'])
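        # Editor's sketch verifying the encoding by hand:
        #   base64.b64encode(b'zxcv') == b'enhjdg=='
        #   base64.b64encode(b'1234') == b'MTIzNA=='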
def test_oauth10_redirect(self):
response = self.fetch('/oauth10/client/login', follow_redirects=False)
self.assertEqual(response.code, 302)
self.assertTrue(response.headers['Location'].endswith(
'/oauth1/server/authorize?oauth_token=zxcv'))
# the cookie is base64('zxcv')|base64('1234')
self.assertTrue(
'_oauth_request_token="enhjdg==|MTIzNA=="' in response.headers['Set-Cookie'],
response.headers['Set-Cookie'])
def test_oauth10_get_user_legacy(self):
with ignore_deprecation():
response = self.fetch(
'/legacy/oauth10/client/login?oauth_token=zxcv',
headers={'Cookie': '_oauth_request_token=enhjdg==|MTIzNA=='})
response.rethrow()
parsed = json_decode(response.body)
self.assertEqual(parsed['email'], 'foo@example.com')
self.assertEqual(parsed['access_token'], dict(key='uiop', secret='5678'))
def test_oauth10_get_user(self):
response = self.fetch(
'/oauth10/client/login?oauth_token=zxcv',
headers={'Cookie': '_oauth_request_token=enhjdg==|MTIzNA=='})
response.rethrow()
parsed = json_decode(response.body)
self.assertEqual(parsed['email'], 'foo@example.com')
self.assertEqual(parsed['access_token'], dict(key='uiop', secret='5678'))
def test_oauth10_request_parameters(self):
response = self.fetch('/oauth10/client/request_params')
response.rethrow()
parsed = json_decode(response.body)
self.assertEqual(parsed['oauth_consumer_key'], 'asdf')
self.assertEqual(parsed['oauth_token'], 'uiop')
self.assertTrue('oauth_nonce' in parsed)
self.assertTrue('oauth_signature' in parsed)
def test_oauth10a_redirect_legacy(self):
response = self.fetch('/legacy/oauth10a/client/login', follow_redirects=False)
self.assertEqual(response.code, 302)
self.assertTrue(response.headers['Location'].endswith(
'/oauth1/server/authorize?oauth_token=zxcv'))
# the cookie is base64('zxcv')|base64('1234')
self.assertTrue(
'_oauth_request_token="enhjdg==|MTIzNA=="' in response.headers['Set-Cookie'],
response.headers['Set-Cookie'])
def test_oauth10a_get_user_legacy(self):
with ignore_deprecation():
response = self.fetch(
'/legacy/oauth10a/client/login?oauth_token=zxcv',
headers={'Cookie': '_oauth_request_token=enhjdg==|MTIzNA=='})
response.rethrow()
parsed = json_decode(response.body)
self.assertEqual(parsed['email'], 'foo@example.com')
self.assertEqual(parsed['access_token'], dict(key='uiop', secret='5678'))
def test_oauth10a_redirect(self):
response = self.fetch('/oauth10a/client/login', follow_redirects=False)
self.assertEqual(response.code, 302)
self.assertTrue(response.headers['Location'].endswith(
'/oauth1/server/authorize?oauth_token=zxcv'))
# the cookie is base64('zxcv')|base64('1234')
self.assertTrue(
'_oauth_request_token="enhjdg==|MTIzNA=="' in response.headers['Set-Cookie'],
response.headers['Set-Cookie'])
@unittest.skipIf(mock is None, 'mock package not present')
def test_oauth10a_redirect_error(self):
with mock.patch.object(OAuth1ServerRequestTokenHandler, 'get') as get:
get.side_effect = Exception("boom")
with ExpectLog(app_log, "Uncaught exception"):
response = self.fetch('/oauth10a/client/login', follow_redirects=False)
self.assertEqual(response.code, 500)
def test_oauth10a_get_user(self):
response = self.fetch(
'/oauth10a/client/login?oauth_token=zxcv',
headers={'Cookie': '_oauth_request_token=enhjdg==|MTIzNA=='})
response.rethrow()
parsed = json_decode(response.body)
self.assertEqual(parsed['email'], 'foo@example.com')
self.assertEqual(parsed['access_token'], dict(key='uiop', secret='5678'))
def test_oauth10a_request_parameters(self):
response = self.fetch('/oauth10a/client/request_params')
response.rethrow()
parsed = json_decode(response.body)
self.assertEqual(parsed['oauth_consumer_key'], 'asdf')
self.assertEqual(parsed['oauth_token'], 'uiop')
self.assertTrue('oauth_nonce' in parsed)
self.assertTrue('oauth_signature' in parsed)
def test_oauth10a_get_user_coroutine_exception(self):
response = self.fetch(
'/oauth10a/client/login_coroutine?oauth_token=zxcv&fail_in_get_user=true',
headers={'Cookie': '_oauth_request_token=enhjdg==|MTIzNA=='})
self.assertEqual(response.code, 503)
def test_oauth2_redirect(self):
response = self.fetch('/oauth2/client/login', follow_redirects=False)
self.assertEqual(response.code, 302)
self.assertTrue('/oauth2/server/authorize?' in response.headers['Location'])
def test_facebook_login(self):
response = self.fetch('/facebook/client/login', follow_redirects=False)
self.assertEqual(response.code, 302)
self.assertTrue('/facebook/server/authorize?' in response.headers['Location'])
response = self.fetch('/facebook/client/login?code=1234', follow_redirects=False)
self.assertEqual(response.code, 200)
user = json_decode(response.body)
self.assertEqual(user['access_token'], 'asdf')
self.assertEqual(user['session_expires'], '3600')
def base_twitter_redirect(self, url):
# Same as test_oauth10a_redirect
response = self.fetch(url, follow_redirects=False)
self.assertEqual(response.code, 302)
self.assertTrue(response.headers['Location'].endswith(
'/oauth1/server/authorize?oauth_token=zxcv'))
# the cookie is base64('zxcv')|base64('1234')
self.assertTrue(
'_oauth_request_token="enhjdg==|MTIzNA=="' in response.headers['Set-Cookie'],
response.headers['Set-Cookie'])
def test_twitter_redirect_legacy(self):
self.base_twitter_redirect('/legacy/twitter/client/login')
def test_twitter_redirect(self):
self.base_twitter_redirect('/twitter/client/login')
def test_twitter_redirect_gen_engine(self):
self.base_twitter_redirect('/twitter/client/login_gen_engine')
def test_twitter_redirect_gen_coroutine(self):
self.base_twitter_redirect('/twitter/client/login_gen_coroutine')
def test_twitter_authenticate_redirect(self):
response = self.fetch('/twitter/client/authenticate', follow_redirects=False)
self.assertEqual(response.code, 302)
self.assertTrue(response.headers['Location'].endswith(
'/twitter/server/authenticate?oauth_token=zxcv'), response.headers['Location'])
# the cookie is base64('zxcv')|base64('1234')
self.assertTrue(
'_oauth_request_token="enhjdg==|MTIzNA=="' in response.headers['Set-Cookie'],
response.headers['Set-Cookie'])
def test_twitter_get_user(self):
response = self.fetch(
'/twitter/client/login?oauth_token=zxcv',
headers={'Cookie': '_oauth_request_token=enhjdg==|MTIzNA=='})
response.rethrow()
parsed = json_decode(response.body)
self.assertEqual(parsed,
{u'access_token': {u'key': u'hjkl',
u'screen_name': u'foo',
u'secret': u'vbnm'},
u'name': u'Foo',
u'screen_name': u'foo',
u'username': u'foo'})
def test_twitter_show_user_legacy(self):
response = self.fetch('/legacy/twitter/client/show_user?name=somebody')
response.rethrow()
self.assertEqual(json_decode(response.body),
{'name': 'Somebody', 'screen_name': 'somebody'})
def test_twitter_show_user_error_legacy(self):
with ExpectLog(gen_log, 'Error response HTTP 500'):
response = self.fetch('/legacy/twitter/client/show_user?name=error')
self.assertEqual(response.code, 500)
self.assertEqual(response.body, b'error from twitter request')
def test_twitter_show_user(self):
response = self.fetch('/twitter/client/show_user?name=somebody')
response.rethrow()
self.assertEqual(json_decode(response.body),
{'name': 'Somebody', 'screen_name': 'somebody'})
def test_twitter_show_user_error(self):
response = self.fetch('/twitter/client/show_user?name=error')
self.assertEqual(response.code, 500)
self.assertEqual(response.body, b'error from twitter request')
class GoogleLoginHandler(RequestHandler, GoogleOAuth2Mixin):
def initialize(self, test):
self.test = test
self._OAUTH_REDIRECT_URI = test.get_url('/client/login')
self._OAUTH_AUTHORIZE_URL = test.get_url('/google/oauth2/authorize')
self._OAUTH_ACCESS_TOKEN_URL = test.get_url('/google/oauth2/token')
@gen.coroutine
def get(self):
code = self.get_argument('code', None)
if code is not None:
            # retrieve the authenticated Google user
access = yield self.get_authenticated_user(self._OAUTH_REDIRECT_URI,
code)
user = yield self.oauth2_request(
self.test.get_url("/google/oauth2/userinfo"),
access_token=access["access_token"])
# return the user and access token as json
user["access_token"] = access["access_token"]
self.write(user)
else:
yield self.authorize_redirect(
redirect_uri=self._OAUTH_REDIRECT_URI,
client_id=self.settings['google_oauth']['key'],
client_secret=self.settings['google_oauth']['secret'],
scope=['profile', 'email'],
response_type='code',
extra_params={'prompt': 'select_account'})
class GoogleOAuth2AuthorizeHandler(RequestHandler):
def get(self):
# issue a fake auth code and redirect to redirect_uri
code = 'fake-authorization-code'
self.redirect(url_concat(self.get_argument('redirect_uri'),
dict(code=code)))
class GoogleOAuth2TokenHandler(RequestHandler):
def post(self):
assert self.get_argument('code') == 'fake-authorization-code'
# issue a fake token
self.finish({
'access_token': 'fake-access-token',
'expires_in': 'never-expires'
})
class GoogleOAuth2UserinfoHandler(RequestHandler):
def get(self):
assert self.get_argument('access_token') == 'fake-access-token'
# return a fake user
self.finish({
'name': 'Foo',
'email': 'foo@example.com'
})
class GoogleOAuth2Test(AsyncHTTPTestCase):
def get_app(self):
return Application(
[
# test endpoints
('/client/login', GoogleLoginHandler, dict(test=self)),
# simulated google authorization server endpoints
('/google/oauth2/authorize', GoogleOAuth2AuthorizeHandler),
('/google/oauth2/token', GoogleOAuth2TokenHandler),
('/google/oauth2/userinfo', GoogleOAuth2UserinfoHandler),
],
google_oauth={
"key": 'fake_google_client_id',
"secret": 'fake_google_client_secret'
})
def test_google_login(self):
response = self.fetch('/client/login')
self.assertDictEqual({
u'name': u'Foo',
u'email': u'foo@example.com',
u'access_token': u'fake-access-token',
}, json_decode(response.body))

lib/tornado/test/autoreload_test.py Executable file
@@ -0,0 +1,114 @@
from __future__ import absolute_import, division, print_function
import os
import shutil
import subprocess
from subprocess import Popen
import sys
from tempfile import mkdtemp
import time
from tornado.test.util import unittest
class AutoreloadTest(unittest.TestCase):
def test_reload_module(self):
main = """\
import os
import sys
from tornado import autoreload
# This import will fail if path is not set up correctly
import testapp
print('Starting')
if 'TESTAPP_STARTED' not in os.environ:
os.environ['TESTAPP_STARTED'] = '1'
sys.stdout.flush()
autoreload._reload()
"""
# Create temporary test application
path = mkdtemp()
self.addCleanup(shutil.rmtree, path)
os.mkdir(os.path.join(path, 'testapp'))
open(os.path.join(path, 'testapp/__init__.py'), 'w').close()
with open(os.path.join(path, 'testapp/__main__.py'), 'w') as f:
f.write(main)
# Make sure the tornado module under test is available to the test
# application
pythonpath = os.getcwd()
if 'PYTHONPATH' in os.environ:
pythonpath += os.pathsep + os.environ['PYTHONPATH']
p = Popen(
[sys.executable, '-m', 'testapp'], stdout=subprocess.PIPE,
cwd=path, env=dict(os.environ, PYTHONPATH=pythonpath),
universal_newlines=True)
out = p.communicate()[0]
self.assertEqual(out, 'Starting\nStarting\n')
def test_reload_wrapper_preservation(self):
# This test verifies that when `python -m tornado.autoreload`
# is used on an application that also has an internal
# autoreload, the reload wrapper is preserved on restart.
main = """\
import os
import sys
# This import will fail if path is not set up correctly
import testapp
if 'tornado.autoreload' not in sys.modules:
raise Exception('started without autoreload wrapper')
import tornado.autoreload
print('Starting')
sys.stdout.flush()
if 'TESTAPP_STARTED' not in os.environ:
os.environ['TESTAPP_STARTED'] = '1'
# Simulate an internal autoreload (one not caused
# by the wrapper).
tornado.autoreload._reload()
else:
# Exit directly so autoreload doesn't catch it.
os._exit(0)
"""
# Create temporary test application
path = mkdtemp()
os.mkdir(os.path.join(path, 'testapp'))
self.addCleanup(shutil.rmtree, path)
init_file = os.path.join(path, 'testapp', '__init__.py')
open(init_file, 'w').close()
main_file = os.path.join(path, 'testapp', '__main__.py')
with open(main_file, 'w') as f:
f.write(main)
# Make sure the tornado module under test is available to the test
# application
pythonpath = os.getcwd()
if 'PYTHONPATH' in os.environ:
pythonpath += os.pathsep + os.environ['PYTHONPATH']
autoreload_proc = Popen(
[sys.executable, '-m', 'tornado.autoreload', '-m', 'testapp'],
stdout=subprocess.PIPE, cwd=path,
env=dict(os.environ, PYTHONPATH=pythonpath),
universal_newlines=True)
# This timeout needs to be fairly generous for pypy due to jit
# warmup costs.
for i in range(40):
if autoreload_proc.poll() is not None:
break
time.sleep(0.1)
else:
autoreload_proc.kill()
raise Exception("subprocess failed to terminate")
out = autoreload_proc.communicate()[0]
self.assertEqual(out, 'Starting\n' * 2)

lib/tornado/test/concurrent_test.py Executable file
@@ -0,0 +1,496 @@
#
# Copyright 2012 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from __future__ import absolute_import, division, print_function
import gc
import logging
import re
import socket
import sys
import traceback
import warnings
from tornado.concurrent import (Future, return_future, ReturnValueIgnoredError,
run_on_executor, future_set_result_unless_cancelled)
from tornado.escape import utf8, to_unicode
from tornado import gen
from tornado.ioloop import IOLoop
from tornado.iostream import IOStream
from tornado.log import app_log
from tornado import stack_context
from tornado.tcpserver import TCPServer
from tornado.testing import AsyncTestCase, ExpectLog, bind_unused_port, gen_test
from tornado.test.util import unittest, skipBefore35, exec_test, ignore_deprecation
try:
from concurrent import futures
except ImportError:
futures = None
class MiscFutureTest(AsyncTestCase):
def test_future_set_result_unless_cancelled(self):
fut = Future()
future_set_result_unless_cancelled(fut, 42)
self.assertEqual(fut.result(), 42)
self.assertFalse(fut.cancelled())
fut = Future()
fut.cancel()
is_cancelled = fut.cancelled()
future_set_result_unless_cancelled(fut, 42)
self.assertEqual(fut.cancelled(), is_cancelled)
if not is_cancelled:
self.assertEqual(fut.result(), 42)
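# A minimal sketch (not part of the test above; the helper name
# _set_result_example is ours) of why the helper exists: under asyncio's
# Future semantics, set_result on an already-cancelled Future raises
# InvalidStateError, so the helper checks cancellation first.
def _set_result_example():
    fut = Future()
    fut.cancel()
    future_set_result_unless_cancelled(fut, 42)  # no-op when the cancel succeeded
    return fut.cancelled()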
class ReturnFutureTest(AsyncTestCase):
with ignore_deprecation():
@return_future
def sync_future(self, callback):
callback(42)
@return_future
def async_future(self, callback):
self.io_loop.add_callback(callback, 42)
@return_future
def immediate_failure(self, callback):
1 / 0
@return_future
def delayed_failure(self, callback):
self.io_loop.add_callback(lambda: 1 / 0)
@return_future
def return_value(self, callback):
# Note that the result of both running the callback and returning
# a value (or raising an exception) is unspecified; with current
# implementations the last event prior to callback resolution wins.
return 42
@return_future
def no_result_future(self, callback):
callback()
def test_immediate_failure(self):
with self.assertRaises(ZeroDivisionError):
# The caller sees the error just like a normal function.
self.immediate_failure(callback=self.stop)
# The callback is not run because the function failed synchronously.
self.io_loop.add_timeout(self.io_loop.time() + 0.05, self.stop)
result = self.wait()
self.assertIs(result, None)
def test_return_value(self):
with self.assertRaises(ReturnValueIgnoredError):
self.return_value(callback=self.stop)
def test_callback_kw(self):
with ignore_deprecation():
future = self.sync_future(callback=self.stop)
result = self.wait()
self.assertEqual(result, 42)
self.assertEqual(future.result(), 42)
def test_callback_positional(self):
# When the callback is passed in positionally, return_future shouldn't
# add another callback in the kwargs.
with ignore_deprecation():
future = self.sync_future(self.stop)
result = self.wait()
self.assertEqual(result, 42)
self.assertEqual(future.result(), 42)
def test_no_callback(self):
future = self.sync_future()
self.assertEqual(future.result(), 42)
def test_none_callback_kw(self):
# explicitly pass None as callback
future = self.sync_future(callback=None)
self.assertEqual(future.result(), 42)
def test_none_callback_pos(self):
future = self.sync_future(None)
self.assertEqual(future.result(), 42)
def test_async_future(self):
future = self.async_future()
self.assertFalse(future.done())
self.io_loop.add_future(future, self.stop)
future2 = self.wait()
self.assertIs(future, future2)
self.assertEqual(future.result(), 42)
@gen_test
def test_async_future_gen(self):
result = yield self.async_future()
self.assertEqual(result, 42)
def test_delayed_failure(self):
future = self.delayed_failure()
with ignore_deprecation():
self.io_loop.add_future(future, self.stop)
future2 = self.wait()
self.assertIs(future, future2)
with self.assertRaises(ZeroDivisionError):
future.result()
def test_kw_only_callback(self):
with ignore_deprecation():
@return_future
def f(**kwargs):
kwargs['callback'](42)
future = f()
self.assertEqual(future.result(), 42)
def test_error_in_callback(self):
with ignore_deprecation():
self.sync_future(callback=lambda future: 1 / 0)
# The exception gets caught by our StackContext and will be re-raised
# when we wait.
self.assertRaises(ZeroDivisionError, self.wait)
def test_no_result_future(self):
with ignore_deprecation():
future = self.no_result_future(self.stop)
result = self.wait()
self.assertIs(result, None)
# result of this future is undefined, but not an error
future.result()
def test_no_result_future_callback(self):
with ignore_deprecation():
future = self.no_result_future(callback=lambda: self.stop())
result = self.wait()
self.assertIs(result, None)
future.result()
@gen_test
def test_future_traceback_legacy(self):
with ignore_deprecation():
@return_future
@gen.engine
def f(callback):
yield gen.Task(self.io_loop.add_callback)
try:
1 / 0
except ZeroDivisionError:
self.expected_frame = traceback.extract_tb(
sys.exc_info()[2], limit=1)[0]
raise
try:
yield f()
self.fail("didn't get expected exception")
except ZeroDivisionError:
tb = traceback.extract_tb(sys.exc_info()[2])
self.assertIn(self.expected_frame, tb)
@gen_test
def test_future_traceback(self):
@gen.coroutine
def f():
yield gen.moment
try:
1 / 0
except ZeroDivisionError:
self.expected_frame = traceback.extract_tb(
sys.exc_info()[2], limit=1)[0]
raise
try:
yield f()
self.fail("didn't get expected exception")
except ZeroDivisionError:
tb = traceback.extract_tb(sys.exc_info()[2])
self.assertIn(self.expected_frame, tb)
@gen_test
def test_uncaught_exception_log(self):
if IOLoop.configured_class().__name__.endswith('AsyncIOLoop'):
# Install an exception handler that mirrors our
# non-asyncio logging behavior.
def exc_handler(loop, context):
app_log.error('%s: %s', context['message'],
type(context.get('exception')))
self.io_loop.asyncio_loop.set_exception_handler(exc_handler)
@gen.coroutine
def f():
yield gen.moment
1 / 0
g = f()
with ExpectLog(app_log,
"(?s)Future.* exception was never retrieved:"
".*ZeroDivisionError"):
yield gen.moment
yield gen.moment
# For some reason, TwistedIOLoop and pypy3 need a third iteration
# in order to drain references to the future
yield gen.moment
del g
gc.collect() # for PyPy
# The following series of classes demonstrate and test various styles
# of use, with and without generators and futures.
class CapServer(TCPServer):
@gen.coroutine
def handle_stream(self, stream, address):
data = yield stream.read_until(b"\n")
data = to_unicode(data)
if data == data.upper():
stream.write(b"error\talready capitalized\n")
else:
# data already has \n
stream.write(utf8("ok\t%s" % data.upper()))
stream.close()
class CapError(Exception):
pass
class BaseCapClient(object):
def __init__(self, port):
self.port = port
def process_response(self, data):
status, message = re.match('(.*)\t(.*)\n', to_unicode(data)).groups()
if status == 'ok':
return message
else:
raise CapError(message)
class ManualCapClient(BaseCapClient):
def capitalize(self, request_data, callback=None):
logging.debug("capitalize")
self.request_data = request_data
self.stream = IOStream(socket.socket())
self.stream.connect(('127.0.0.1', self.port),
callback=self.handle_connect)
self.future = Future()
if callback is not None:
self.future.add_done_callback(
stack_context.wrap(lambda future: callback(future.result())))
return self.future
def handle_connect(self):
logging.debug("handle_connect")
self.stream.write(utf8(self.request_data + "\n"))
self.stream.read_until(b'\n', callback=self.handle_read)
def handle_read(self, data):
logging.debug("handle_read")
self.stream.close()
try:
self.future.set_result(self.process_response(data))
except CapError as e:
self.future.set_exception(e)
class DecoratorCapClient(BaseCapClient):
with ignore_deprecation():
@return_future
def capitalize(self, request_data, callback):
logging.debug("capitalize")
self.request_data = request_data
self.stream = IOStream(socket.socket())
self.stream.connect(('127.0.0.1', self.port),
callback=self.handle_connect)
self.callback = callback
def handle_connect(self):
logging.debug("handle_connect")
self.stream.write(utf8(self.request_data + "\n"))
self.stream.read_until(b'\n', callback=self.handle_read)
def handle_read(self, data):
logging.debug("handle_read")
self.stream.close()
self.callback(self.process_response(data))
class GeneratorCapClient(BaseCapClient):
@gen.coroutine
def capitalize(self, request_data):
logging.debug('capitalize')
stream = IOStream(socket.socket())
logging.debug('connecting')
yield stream.connect(('127.0.0.1', self.port))
stream.write(utf8(request_data + '\n'))
logging.debug('reading')
data = yield stream.read_until(b'\n')
logging.debug('returning')
stream.close()
raise gen.Return(self.process_response(data))
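# A minimal sketch (defined here but never invoked by the tests) of driving
# GeneratorCapClient against a CapServer outside the test harness:
def _cap_client_example():
    server = CapServer()
    sock, port = bind_unused_port()
    server.add_sockets([sock])
    client = GeneratorCapClient(port=port)
    # The wire exchange is line-oriented: b"hello\n" -> b"ok\tHELLO\n",
    # which process_response resolves to "HELLO" (or raises CapError).
    result = IOLoop.current().run_sync(lambda: client.capitalize("hello"))
    assert result == "HELLO"
    server.stop()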
class ClientTestMixin(object):
def setUp(self):
super(ClientTestMixin, self).setUp() # type: ignore
self.server = CapServer()
sock, port = bind_unused_port()
self.server.add_sockets([sock])
self.client = self.client_class(port=port)
def tearDown(self):
self.server.stop()
super(ClientTestMixin, self).tearDown() # type: ignore
def test_callback(self):
with ignore_deprecation():
self.client.capitalize("hello", callback=self.stop)
result = self.wait()
self.assertEqual(result, "HELLO")
def test_callback_error(self):
with ignore_deprecation():
self.client.capitalize("HELLO", callback=self.stop)
self.assertRaisesRegexp(CapError, "already capitalized", self.wait)
def test_future(self):
future = self.client.capitalize("hello")
self.io_loop.add_future(future, self.stop)
self.wait()
self.assertEqual(future.result(), "HELLO")
def test_future_error(self):
future = self.client.capitalize("HELLO")
self.io_loop.add_future(future, self.stop)
self.wait()
self.assertRaisesRegexp(CapError, "already capitalized", future.result)
def test_generator(self):
@gen.coroutine
def f():
result = yield self.client.capitalize("hello")
self.assertEqual(result, "HELLO")
self.io_loop.run_sync(f)
def test_generator_error(self):
@gen.coroutine
def f():
with self.assertRaisesRegexp(CapError, "already capitalized"):
yield self.client.capitalize("HELLO")
self.io_loop.run_sync(f)
class ManualClientTest(ClientTestMixin, AsyncTestCase):
client_class = ManualCapClient
def setUp(self):
self.warning_catcher = warnings.catch_warnings()
self.warning_catcher.__enter__()
warnings.simplefilter('ignore', DeprecationWarning)
super(ManualClientTest, self).setUp()
def tearDown(self):
super(ManualClientTest, self).tearDown()
self.warning_catcher.__exit__(None, None, None)
class DecoratorClientTest(ClientTestMixin, AsyncTestCase):
client_class = DecoratorCapClient
def setUp(self):
self.warning_catcher = warnings.catch_warnings()
self.warning_catcher.__enter__()
warnings.simplefilter('ignore', DeprecationWarning)
super(DecoratorClientTest, self).setUp()
def tearDown(self):
super(DecoratorClientTest, self).tearDown()
self.warning_catcher.__exit__(None, None, None)
class GeneratorClientTest(ClientTestMixin, AsyncTestCase):
client_class = GeneratorCapClient
@unittest.skipIf(futures is None, "concurrent.futures module not present")
class RunOnExecutorTest(AsyncTestCase):
@gen_test
def test_no_calling(self):
class Object(object):
def __init__(self):
self.executor = futures.thread.ThreadPoolExecutor(1)
@run_on_executor
def f(self):
return 42
o = Object()
answer = yield o.f()
self.assertEqual(answer, 42)
@gen_test
def test_call_with_no_args(self):
class Object(object):
def __init__(self):
self.executor = futures.thread.ThreadPoolExecutor(1)
@run_on_executor()
def f(self):
return 42
o = Object()
answer = yield o.f()
self.assertEqual(answer, 42)
@gen_test
def test_call_with_executor(self):
class Object(object):
def __init__(self):
self.__executor = futures.thread.ThreadPoolExecutor(1)
@run_on_executor(executor='_Object__executor')
def f(self):
return 42
o = Object()
answer = yield o.f()
self.assertEqual(answer, 42)
@skipBefore35
@gen_test
def test_async_await(self):
class Object(object):
def __init__(self):
self.executor = futures.thread.ThreadPoolExecutor(1)
@run_on_executor()
def f(self):
return 42
o = Object()
namespace = exec_test(globals(), locals(), """
async def f():
answer = await o.f()
return answer
""")
result = yield namespace['f']()
self.assertEqual(result, 42)
if __name__ == '__main__':
unittest.main()

lib/tornado/test/csv_translations/fr_FR.csv
View File

@@ -0,0 +1 @@
"school","école"

lib/tornado/test/curl_httpclient_test.py
View File

@@ -0,0 +1,153 @@
# coding: utf-8
from __future__ import absolute_import, division, print_function
from hashlib import md5
from tornado.escape import utf8
from tornado.httpclient import HTTPRequest, HTTPClientError
from tornado.locks import Event
from tornado.stack_context import ExceptionStackContext
from tornado.testing import AsyncHTTPTestCase, gen_test
from tornado.test import httpclient_test
from tornado.test.util import unittest, ignore_deprecation
from tornado.web import Application, RequestHandler
try:
import pycurl # type: ignore
except ImportError:
pycurl = None
if pycurl is not None:
from tornado.curl_httpclient import CurlAsyncHTTPClient
@unittest.skipIf(pycurl is None, "pycurl module not present")
class CurlHTTPClientCommonTestCase(httpclient_test.HTTPClientCommonTestCase):
def get_http_client(self):
client = CurlAsyncHTTPClient(defaults=dict(allow_ipv6=False))
# make sure AsyncHTTPClient magic doesn't give us the wrong class
self.assertTrue(isinstance(client, CurlAsyncHTTPClient))
return client
class DigestAuthHandler(RequestHandler):
def initialize(self, username, password):
self.username = username
self.password = password
def get(self):
realm = 'test'
opaque = 'asdf'
# Real implementations would use a random nonce.
nonce = "1234"
auth_header = self.request.headers.get('Authorization', None)
if auth_header is not None:
auth_mode, params = auth_header.split(' ', 1)
assert auth_mode == 'Digest'
param_dict = {}
for pair in params.split(','):
k, v = pair.strip().split('=', 1)
if v[0] == '"' and v[-1] == '"':
v = v[1:-1]
param_dict[k] = v
assert param_dict['realm'] == realm
assert param_dict['opaque'] == opaque
assert param_dict['nonce'] == nonce
assert param_dict['username'] == self.username
assert param_dict['uri'] == self.request.path
h1 = md5(utf8('%s:%s:%s' % (self.username, realm, self.password))).hexdigest()
h2 = md5(utf8('%s:%s' % (self.request.method,
self.request.path))).hexdigest()
digest = md5(utf8('%s:%s:%s' % (h1, nonce, h2))).hexdigest()
if digest == param_dict['response']:
self.write('ok')
else:
self.write('fail')
else:
self.set_status(401)
self.set_header('WWW-Authenticate',
'Digest realm="%s", nonce="%s", opaque="%s"' %
(realm, nonce, opaque))
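# A minimal client-side sketch (illustrative; make_digest_header is a
# hypothetical helper, not part of this module) of how the matching
# Authorization header for the handler above is computed, per RFC 2617
# without qop: response = MD5(HA1:nonce:HA2).
def make_digest_header(username, password, method, uri,
                       realm='test', nonce='1234', opaque='asdf'):
    ha1 = md5(utf8('%s:%s:%s' % (username, realm, password))).hexdigest()
    ha2 = md5(utf8('%s:%s' % (method, uri))).hexdigest()
    response = md5(utf8('%s:%s:%s' % (ha1, nonce, ha2))).hexdigest()
    return ('Digest username="%s", realm="%s", nonce="%s", uri="%s", '
            'opaque="%s", response="%s"' %
            (username, realm, nonce, uri, opaque, response))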
class CustomReasonHandler(RequestHandler):
def get(self):
self.set_status(200, "Custom reason")
class CustomFailReasonHandler(RequestHandler):
def get(self):
self.set_status(400, "Custom reason")
@unittest.skipIf(pycurl is None, "pycurl module not present")
class CurlHTTPClientTestCase(AsyncHTTPTestCase):
def setUp(self):
super(CurlHTTPClientTestCase, self).setUp()
self.http_client = self.create_client()
def get_app(self):
return Application([
('/digest', DigestAuthHandler, {'username': 'foo', 'password': 'bar'}),
('/digest_non_ascii', DigestAuthHandler, {'username': 'foo', 'password': 'barユ£'}),
('/custom_reason', CustomReasonHandler),
('/custom_fail_reason', CustomFailReasonHandler),
])
def create_client(self, **kwargs):
return CurlAsyncHTTPClient(force_instance=True,
defaults=dict(allow_ipv6=False),
**kwargs)
@gen_test
def test_prepare_curl_callback_stack_context(self):
exc_info = []
error_event = Event()
def error_handler(typ, value, tb):
exc_info.append((typ, value, tb))
error_event.set()
return True
with ignore_deprecation():
with ExceptionStackContext(error_handler):
request = HTTPRequest(self.get_url('/custom_reason'),
prepare_curl_callback=lambda curl: 1 / 0)
yield [error_event.wait(), self.http_client.fetch(request)]
self.assertEqual(1, len(exc_info))
self.assertIs(exc_info[0][0], ZeroDivisionError)
def test_digest_auth(self):
response = self.fetch('/digest', auth_mode='digest',
auth_username='foo', auth_password='bar')
self.assertEqual(response.body, b'ok')
def test_custom_reason(self):
response = self.fetch('/custom_reason')
self.assertEqual(response.reason, "Custom reason")
def test_fail_custom_reason(self):
response = self.fetch('/custom_fail_reason')
self.assertEqual(str(response.error), "HTTP 400: Custom reason")
def test_failed_setup(self):
self.http_client = self.create_client(max_clients=1)
for i in range(5):
with ignore_deprecation():
response = self.fetch(u'/ユニコード')
self.assertIsNot(response.error, None)
with self.assertRaises((UnicodeEncodeError, HTTPClientError)):
# This raises UnicodeEncodeError on py3 and
# HTTPClientError(404) on py2. The main motivation of
# this test is to ensure that the UnicodeEncodeError
# during the setup phase doesn't lead the request to
# be dropped on the floor.
response = self.fetch(u'/ユニコード', raise_error=True)
def test_digest_auth_non_ascii(self):
response = self.fetch('/digest_non_ascii', auth_mode='digest',
auth_username='foo', auth_password='barユ£')
self.assertEqual(response.body, b'ok')

lib/tornado/test/escape_test.py Executable file
View File

@@ -0,0 +1,250 @@
from __future__ import absolute_import, division, print_function
import tornado.escape
from tornado.escape import (
utf8, xhtml_escape, xhtml_unescape, url_escape, url_unescape,
to_unicode, json_decode, json_encode, squeeze, recursive_unicode,
)
from tornado.util import unicode_type
from tornado.test.util import unittest
linkify_tests = [
# (input, linkify_kwargs, expected_output)
("hello http://world.com/!", {},
u'hello <a href="http://world.com/">http://world.com/</a>!'),
("hello http://world.com/with?param=true&stuff=yes", {},
u'hello <a href="http://world.com/with?param=true&amp;stuff=yes">http://world.com/with?param=true&amp;stuff=yes</a>'), # noqa: E501
# an opened paren followed by many chars killed Gruber's regex
("http://url.com/w(aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", {},
u'<a href="http://url.com/w">http://url.com/w</a>(aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa'), # noqa: E501
# as did too many dots at the end
("http://url.com/withmany.......................................", {},
u'<a href="http://url.com/withmany">http://url.com/withmany</a>.......................................'), # noqa: E501
("http://url.com/withmany((((((((((((((((((((((((((((((((((a)", {},
u'<a href="http://url.com/withmany">http://url.com/withmany</a>((((((((((((((((((((((((((((((((((a)'), # noqa: E501
# some examples from http://daringfireball.net/2009/11/liberal_regex_for_matching_urls
# plus a few extras (such as multiple parentheses).
("http://foo.com/blah_blah", {},
u'<a href="http://foo.com/blah_blah">http://foo.com/blah_blah</a>'),
("http://foo.com/blah_blah/", {},
u'<a href="http://foo.com/blah_blah/">http://foo.com/blah_blah/</a>'),
("(Something like http://foo.com/blah_blah)", {},
u'(Something like <a href="http://foo.com/blah_blah">http://foo.com/blah_blah</a>)'),
("http://foo.com/blah_blah_(wikipedia)", {},
u'<a href="http://foo.com/blah_blah_(wikipedia)">http://foo.com/blah_blah_(wikipedia)</a>'),
("http://foo.com/blah_(blah)_(wikipedia)_blah", {},
u'<a href="http://foo.com/blah_(blah)_(wikipedia)_blah">http://foo.com/blah_(blah)_(wikipedia)_blah</a>'), # noqa: E501
("(Something like http://foo.com/blah_blah_(wikipedia))", {},
u'(Something like <a href="http://foo.com/blah_blah_(wikipedia)">http://foo.com/blah_blah_(wikipedia)</a>)'), # noqa: E501
("http://foo.com/blah_blah.", {},
u'<a href="http://foo.com/blah_blah">http://foo.com/blah_blah</a>.'),
("http://foo.com/blah_blah/.", {},
u'<a href="http://foo.com/blah_blah/">http://foo.com/blah_blah/</a>.'),
("<http://foo.com/blah_blah>", {},
u'&lt;<a href="http://foo.com/blah_blah">http://foo.com/blah_blah</a>&gt;'),
("<http://foo.com/blah_blah/>", {},
u'&lt;<a href="http://foo.com/blah_blah/">http://foo.com/blah_blah/</a>&gt;'),
("http://foo.com/blah_blah,", {},
u'<a href="http://foo.com/blah_blah">http://foo.com/blah_blah</a>,'),
("http://www.example.com/wpstyle/?p=364.", {},
u'<a href="http://www.example.com/wpstyle/?p=364">http://www.example.com/wpstyle/?p=364</a>.'),
("rdar://1234",
{"permitted_protocols": ["http", "rdar"]},
u'<a href="rdar://1234">rdar://1234</a>'),
("rdar:/1234",
{"permitted_protocols": ["rdar"]},
u'<a href="rdar:/1234">rdar:/1234</a>'),
("http://userid:password@example.com:8080", {},
u'<a href="http://userid:password@example.com:8080">http://userid:password@example.com:8080</a>'), # noqa: E501
("http://userid@example.com", {},
u'<a href="http://userid@example.com">http://userid@example.com</a>'),
("http://userid@example.com:8080", {},
u'<a href="http://userid@example.com:8080">http://userid@example.com:8080</a>'),
("http://userid:password@example.com", {},
u'<a href="http://userid:password@example.com">http://userid:password@example.com</a>'),
("message://%3c330e7f8409726r6a4ba78dkf1fd71420c1bf6ff@mail.gmail.com%3e",
{"permitted_protocols": ["http", "message"]},
u'<a href="message://%3c330e7f8409726r6a4ba78dkf1fd71420c1bf6ff@mail.gmail.com%3e">'
u'message://%3c330e7f8409726r6a4ba78dkf1fd71420c1bf6ff@mail.gmail.com%3e</a>'),
(u"http://\u27a1.ws/\u4a39", {},
u'<a href="http://\u27a1.ws/\u4a39">http://\u27a1.ws/\u4a39</a>'),
("<tag>http://example.com</tag>", {},
u'&lt;tag&gt;<a href="http://example.com">http://example.com</a>&lt;/tag&gt;'),
("Just a www.example.com link.", {},
u'Just a <a href="http://www.example.com">www.example.com</a> link.'),
("Just a www.example.com link.",
{"require_protocol": True},
u'Just a www.example.com link.'),
("A http://reallylong.com/link/that/exceedsthelenglimit.html",
{"require_protocol": True, "shorten": True},
u'A <a href="http://reallylong.com/link/that/exceedsthelenglimit.html"'
u' title="http://reallylong.com/link/that/exceedsthelenglimit.html">http://reallylong.com/link...</a>'), # noqa: E501
("A http://reallylongdomainnamethatwillbetoolong.com/hi!",
{"shorten": True},
u'A <a href="http://reallylongdomainnamethatwillbetoolong.com/hi"'
u' title="http://reallylongdomainnamethatwillbetoolong.com/hi">http://reallylongdomainnametha...</a>!'), # noqa: E501
("A file:///passwords.txt and http://web.com link", {},
u'A file:///passwords.txt and <a href="http://web.com">http://web.com</a> link'),
("A file:///passwords.txt and http://web.com link",
{"permitted_protocols": ["file"]},
u'A <a href="file:///passwords.txt">file:///passwords.txt</a> and http://web.com link'),
("www.external-link.com",
{"extra_params": 'rel="nofollow" class="external"'},
u'<a href="http://www.external-link.com" rel="nofollow" class="external">www.external-link.com</a>'), # noqa: E501
("www.external-link.com and www.internal-link.com/blogs extra",
{"extra_params": lambda href: 'class="internal"' if href.startswith("http://www.internal-link.com") else 'rel="nofollow" class="external"'}, # noqa: E501
u'<a href="http://www.external-link.com" rel="nofollow" class="external">www.external-link.com</a>' # noqa: E501
u' and <a href="http://www.internal-link.com/blogs" class="internal">www.internal-link.com/blogs</a> extra'), # noqa: E501
("www.external-link.com",
{"extra_params": lambda href: ' rel="nofollow" class="external" '},
u'<a href="http://www.external-link.com" rel="nofollow" class="external">www.external-link.com</a>'), # noqa: E501
]
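# For orientation, a sketch (defined but not called; the function name is
# ours) mirroring two rows of the table above. Each kwargs dict in the
# table is splatted into tornado.escape.linkify.
def _linkify_example():
    linked = tornado.escape.linkify("Just a www.example.com link.")
    bare = tornado.escape.linkify("Just a www.example.com link.",
                                  require_protocol=True)
    # Without require_protocol the bare domain is wrapped in an anchor
    # with an http:// href; with it, the text comes back unchanged.
    return linked, bare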
class EscapeTestCase(unittest.TestCase):
def test_linkify(self):
for text, kwargs, html in linkify_tests:
linked = tornado.escape.linkify(text, **kwargs)
self.assertEqual(linked, html)
def test_xhtml_escape(self):
tests = [
("<foo>", "&lt;foo&gt;"),
(u"<foo>", u"&lt;foo&gt;"),
(b"<foo>", b"&lt;foo&gt;"),
("<>&\"'", "&lt;&gt;&amp;&quot;&#39;"),
("&amp;", "&amp;amp;"),
(u"<\u00e9>", u"&lt;\u00e9&gt;"),
(b"<\xc3\xa9>", b"&lt;\xc3\xa9&gt;"),
]
for unescaped, escaped in tests:
self.assertEqual(utf8(xhtml_escape(unescaped)), utf8(escaped))
self.assertEqual(utf8(unescaped), utf8(xhtml_unescape(escaped)))
def test_xhtml_unescape_numeric(self):
tests = [
('foo&#32;bar', 'foo bar'),
('foo&#x20;bar', 'foo bar'),
('foo&#X20;bar', 'foo bar'),
('foo&#xabc;bar', u'foo\u0abcbar'),
('foo&#xyz;bar', 'foo&#xyz;bar'), # invalid encoding
('foo&#;bar', 'foo&#;bar'), # invalid encoding
('foo&#x;bar', 'foo&#x;bar'), # invalid encoding
]
for escaped, unescaped in tests:
self.assertEqual(unescaped, xhtml_unescape(escaped))
def test_url_escape_unicode(self):
tests = [
# byte strings are passed through as-is
(u'\u00e9'.encode('utf8'), '%C3%A9'),
(u'\u00e9'.encode('latin1'), '%E9'),
# unicode strings become utf8
(u'\u00e9', '%C3%A9'),
]
for unescaped, escaped in tests:
self.assertEqual(url_escape(unescaped), escaped)
def test_url_unescape_unicode(self):
tests = [
('%C3%A9', u'\u00e9', 'utf8'),
('%C3%A9', u'\u00c3\u00a9', 'latin1'),
('%C3%A9', utf8(u'\u00e9'), None),
]
for escaped, unescaped, encoding in tests:
# input strings to url_unescape should only contain ascii
# characters, but make sure the function accepts both byte
# and unicode strings.
self.assertEqual(url_unescape(to_unicode(escaped), encoding), unescaped)
self.assertEqual(url_unescape(utf8(escaped), encoding), unescaped)
def test_url_escape_quote_plus(self):
unescaped = '+ #%'
plus_escaped = '%2B+%23%25'
escaped = '%2B%20%23%25'
self.assertEqual(url_escape(unescaped), plus_escaped)
self.assertEqual(url_escape(unescaped, plus=False), escaped)
self.assertEqual(url_unescape(plus_escaped), unescaped)
self.assertEqual(url_unescape(escaped, plus=False), unescaped)
self.assertEqual(url_unescape(plus_escaped, encoding=None),
utf8(unescaped))
self.assertEqual(url_unescape(escaped, encoding=None, plus=False),
utf8(unescaped))
def test_escape_return_types(self):
# On python2 the escape methods should generally return the same
# type as their argument
self.assertEqual(type(xhtml_escape("foo")), str)
self.assertEqual(type(xhtml_escape(u"foo")), unicode_type)
def test_json_decode(self):
# json_decode accepts both bytes and unicode, but strings it returns
# are always unicode.
self.assertEqual(json_decode(b'"foo"'), u"foo")
self.assertEqual(json_decode(u'"foo"'), u"foo")
# Non-ascii bytes are interpreted as utf8
self.assertEqual(json_decode(utf8(u'"\u00e9"')), u"\u00e9")
def test_json_encode(self):
# json deals with strings, not bytes. On python 2 byte strings will
# convert automatically if they are utf8; on python 3 byte strings
# are not allowed.
self.assertEqual(json_decode(json_encode(u"\u00e9")), u"\u00e9")
if bytes is str:
self.assertEqual(json_decode(json_encode(utf8(u"\u00e9"))), u"\u00e9")
self.assertRaises(UnicodeDecodeError, json_encode, b"\xe9")
def test_squeeze(self):
self.assertEqual(squeeze(u'sequences of whitespace chars'),
u'sequences of whitespace chars')
def test_recursive_unicode(self):
tests = {
'dict': {b"foo": b"bar"},
'list': [b"foo", b"bar"],
'tuple': (b"foo", b"bar"),
'bytes': b"foo"
}
self.assertEqual(recursive_unicode(tests['dict']), {u"foo": u"bar"})
self.assertEqual(recursive_unicode(tests['list']), [u"foo", u"bar"])
self.assertEqual(recursive_unicode(tests['tuple']), (u"foo", u"bar"))
self.assertEqual(recursive_unicode(tests['bytes']), u"foo")

lib/tornado/test/gen_test.py Executable file

File diff suppressed because it is too large

lib/tornado/test/gettext_translations/extract_me.py
View File

@@ -0,0 +1,16 @@
# flake8: noqa
# Dummy source file to allow creation of the initial .po file in the
# same way as a real project. I'm not entirely sure about the real
# workflow here, but this seems to work.
#
# 1) xgettext --language=Python --keyword=_:1,2 --keyword=pgettext:1c,2 --keyword=pgettext:1c,2,3 extract_me.py -o tornado_test.po
# 2) Edit tornado_test.po, setting CHARSET, Plural-Forms and setting msgstr
# 3) msgfmt tornado_test.po -o tornado_test.mo
# 4) Put the file in the proper location: $LANG/LC_MESSAGES
from __future__ import absolute_import, division, print_function
_("school")
pgettext("law", "right")
pgettext("good", "right")
pgettext("organization", "club", "clubs", 1)
pgettext("stick", "club", "clubs", 1)

lib/tornado/test/gettext_translations/fr_FR/LC_MESSAGES/tornado_test.po
View File

@@ -0,0 +1,47 @@
# SOME DESCRIPTIVE TITLE.
# Copyright (C) YEAR THE PACKAGE'S COPYRIGHT HOLDER
# This file is distributed under the same license as the PACKAGE package.
# FIRST AUTHOR <EMAIL@ADDRESS>, YEAR.
#
#, fuzzy
msgid ""
msgstr ""
"Project-Id-Version: PACKAGE VERSION\n"
"Report-Msgid-Bugs-To: \n"
"POT-Creation-Date: 2015-01-27 11:05+0300\n"
"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
"Language-Team: LANGUAGE <LL@li.org>\n"
"Language: \n"
"MIME-Version: 1.0\n"
"Content-Type: text/plain; charset=utf-8\n"
"Content-Transfer-Encoding: 8bit\n"
"Plural-Forms: nplurals=2; plural=(n > 1);\n"
#: extract_me.py:11
msgid "school"
msgstr "école"
#: extract_me.py:12
msgctxt "law"
msgid "right"
msgstr "le droit"
#: extract_me.py:13
msgctxt "good"
msgid "right"
msgstr "le bien"
#: extract_me.py:14
msgctxt "organization"
msgid "club"
msgid_plural "clubs"
msgstr[0] "le club"
msgstr[1] "les clubs"
#: extract_me.py:15
msgctxt "stick"
msgid "club"
msgid_plural "clubs"
msgstr[0] "le bâton"
msgstr[1] "les bâtons"

lib/tornado/test/http1connection_test.py
View File

@@ -0,0 +1,61 @@
from __future__ import absolute_import, division, print_function
import socket
from tornado.http1connection import HTTP1Connection
from tornado.httputil import HTTPMessageDelegate
from tornado.iostream import IOStream
from tornado.locks import Event
from tornado.netutil import add_accept_handler
from tornado.testing import AsyncTestCase, bind_unused_port, gen_test
class HTTP1ConnectionTest(AsyncTestCase):
def setUp(self):
super(HTTP1ConnectionTest, self).setUp()
self.asyncSetUp()
@gen_test
def asyncSetUp(self):
listener, port = bind_unused_port()
event = Event()
def accept_callback(conn, addr):
self.server_stream = IOStream(conn)
self.addCleanup(self.server_stream.close)
event.set()
add_accept_handler(listener, accept_callback)
self.client_stream = IOStream(socket.socket())
self.addCleanup(self.client_stream.close)
yield [self.client_stream.connect(('127.0.0.1', port)),
event.wait()]
self.io_loop.remove_handler(listener)
listener.close()
@gen_test
def test_http10_no_content_length(self):
# Regression test for a bug in which can_keep_alive would crash
# for an HTTP/1.0 (not 1.1) response with no content-length.
conn = HTTP1Connection(self.client_stream, True)
self.server_stream.write(b"HTTP/1.0 200 Not Modified\r\n\r\nhello")
self.server_stream.close()
event = Event()
test = self
body = []
class Delegate(HTTPMessageDelegate):
def headers_received(self, start_line, headers):
test.code = start_line.code
def data_received(self, data):
body.append(data)
def finish(self):
event.set()
yield conn.read_response(Delegate())
yield event.wait()
self.assertEqual(self.code, 200)
self.assertEqual(b''.join(body), b'hello')

lib/tornado/test/httpclient_test.py
View File

@@ -0,0 +1,718 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function
import base64
import binascii
from contextlib import closing
import copy
import sys
import threading
import datetime
from io import BytesIO
import time
import unicodedata
from tornado.escape import utf8, native_str
from tornado import gen
from tornado.httpclient import HTTPRequest, HTTPResponse, _RequestProxy, HTTPError, HTTPClient
from tornado.httpserver import HTTPServer
from tornado.ioloop import IOLoop
from tornado.iostream import IOStream
from tornado.log import gen_log
from tornado import netutil
from tornado.stack_context import ExceptionStackContext, NullContext
from tornado.testing import AsyncHTTPTestCase, bind_unused_port, gen_test, ExpectLog
from tornado.test.util import unittest, skipOnTravis, ignore_deprecation
from tornado.web import Application, RequestHandler, url
from tornado.httputil import format_timestamp, HTTPHeaders
class HelloWorldHandler(RequestHandler):
def get(self):
name = self.get_argument("name", "world")
self.set_header("Content-Type", "text/plain")
self.finish("Hello %s!" % name)
class PostHandler(RequestHandler):
def post(self):
self.finish("Post arg1: %s, arg2: %s" % (
self.get_argument("arg1"), self.get_argument("arg2")))
class PutHandler(RequestHandler):
def put(self):
self.write("Put body: ")
self.write(self.request.body)
class RedirectHandler(RequestHandler):
def prepare(self):
self.write('redirects can have bodies too')
self.redirect(self.get_argument("url"),
status=int(self.get_argument("status", "302")))
class ChunkHandler(RequestHandler):
@gen.coroutine
def get(self):
self.write("asdf")
self.flush()
# Wait a bit to ensure the chunks are sent and received separately.
yield gen.sleep(0.01)
self.write("qwer")
class AuthHandler(RequestHandler):
def get(self):
self.finish(self.request.headers["Authorization"])
class CountdownHandler(RequestHandler):
def get(self, count):
count = int(count)
if count > 0:
self.redirect(self.reverse_url("countdown", count - 1))
else:
self.write("Zero")
class EchoPostHandler(RequestHandler):
def post(self):
self.write(self.request.body)
class UserAgentHandler(RequestHandler):
def get(self):
self.write(self.request.headers.get('User-Agent', 'User agent not set'))
class ContentLength304Handler(RequestHandler):
def get(self):
self.set_status(304)
self.set_header('Content-Length', 42)
def _clear_headers_for_304(self):
# Tornado strips content-length from 304 responses, but here we
# want to simulate servers that include the headers anyway.
pass
class PatchHandler(RequestHandler):
def patch(self):
"Return the request payload - so we can check it is being kept"
self.write(self.request.body)
class AllMethodsHandler(RequestHandler):
SUPPORTED_METHODS = RequestHandler.SUPPORTED_METHODS + ('OTHER',)
def method(self):
self.write(self.request.method)
get = post = put = delete = options = patch = other = method
class SetHeaderHandler(RequestHandler):
def get(self):
# Use get_arguments for keys to get strings, but
# request.arguments for values to get bytes.
for k, v in zip(self.get_arguments('k'),
self.request.arguments['v']):
self.set_header(k, v)
# These tests end up getting run redundantly: once here with the default
# HTTPClient implementation, and then again in each implementation's own
# test suite.
class HTTPClientCommonTestCase(AsyncHTTPTestCase):
def get_app(self):
return Application([
url("/hello", HelloWorldHandler),
url("/post", PostHandler),
url("/put", PutHandler),
url("/redirect", RedirectHandler),
url("/chunk", ChunkHandler),
url("/auth", AuthHandler),
url("/countdown/([0-9]+)", CountdownHandler, name="countdown"),
url("/echopost", EchoPostHandler),
url("/user_agent", UserAgentHandler),
url("/304_with_content_length", ContentLength304Handler),
url("/all_methods", AllMethodsHandler),
url('/patch', PatchHandler),
url('/set_header', SetHeaderHandler),
], gzip=True)
def test_patch_receives_payload(self):
body = b"some patch data"
response = self.fetch("/patch", method='PATCH', body=body)
self.assertEqual(response.code, 200)
self.assertEqual(response.body, body)
@skipOnTravis
def test_hello_world(self):
response = self.fetch("/hello")
self.assertEqual(response.code, 200)
self.assertEqual(response.headers["Content-Type"], "text/plain")
self.assertEqual(response.body, b"Hello world!")
self.assertEqual(int(response.request_time), 0)
response = self.fetch("/hello?name=Ben")
self.assertEqual(response.body, b"Hello Ben!")
def test_streaming_callback(self):
# streaming_callback is also tested in test_chunked
chunks = []
response = self.fetch("/hello",
streaming_callback=chunks.append)
# with streaming_callback, data goes to the callback and not response.body
self.assertEqual(chunks, [b"Hello world!"])
self.assertFalse(response.body)
def test_post(self):
response = self.fetch("/post", method="POST",
body="arg1=foo&arg2=bar")
self.assertEqual(response.code, 200)
self.assertEqual(response.body, b"Post arg1: foo, arg2: bar")
def test_chunked(self):
response = self.fetch("/chunk")
self.assertEqual(response.body, b"asdfqwer")
chunks = []
response = self.fetch("/chunk",
streaming_callback=chunks.append)
self.assertEqual(chunks, [b"asdf", b"qwer"])
self.assertFalse(response.body)
def test_chunked_close(self):
# test case in which chunks spread read-callback processing
# over several ioloop iterations, but the connection is already closed.
sock, port = bind_unused_port()
with closing(sock):
@gen.coroutine
def accept_callback(conn, address):
# fake an HTTP server using chunked encoding where the final chunks
# and connection close all happen at once
stream = IOStream(conn)
request_data = yield stream.read_until(b"\r\n\r\n")
if b"HTTP/1." not in request_data:
self.skipTest("requires HTTP/1.x")
yield stream.write(b"""\
HTTP/1.1 200 OK
Transfer-Encoding: chunked
1
1
1
2
0
""".replace(b"\n", b"\r\n"))
stream.close()
netutil.add_accept_handler(sock, accept_callback)
resp = self.fetch("http://127.0.0.1:%d/" % port)
resp.rethrow()
self.assertEqual(resp.body, b"12")
self.io_loop.remove_handler(sock.fileno())
def test_streaming_stack_context(self):
chunks = []
exc_info = []
def error_handler(typ, value, tb):
exc_info.append((typ, value, tb))
return True
def streaming_cb(chunk):
chunks.append(chunk)
if chunk == b'qwer':
1 / 0
with ignore_deprecation():
with ExceptionStackContext(error_handler):
self.fetch('/chunk', streaming_callback=streaming_cb)
self.assertEqual(chunks, [b'asdf', b'qwer'])
self.assertEqual(1, len(exc_info))
self.assertIs(exc_info[0][0], ZeroDivisionError)
def test_basic_auth(self):
# This test data appears in section 2 of RFC 7617.
self.assertEqual(self.fetch("/auth", auth_username="Aladdin",
auth_password="open sesame").body,
b"Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==")
def test_basic_auth_explicit_mode(self):
self.assertEqual(self.fetch("/auth", auth_username="Aladdin",
auth_password="open sesame",
auth_mode="basic").body,
b"Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==")
def test_basic_auth_unicode(self):
# This test data appears in section 2.1 of RFC 7617.
self.assertEqual(self.fetch("/auth", auth_username="test",
auth_password="123£").body,
b"Basic dGVzdDoxMjPCow==")
# The standard mandates NFC. Give it a decomposed username
# and ensure it is normalized to composed form.
username = unicodedata.normalize("NFD", u"josé")
self.assertEqual(self.fetch("/auth",
auth_username=username,
auth_password="səcrət").body,
b"Basic am9zw6k6c8mZY3LJmXQ=")
def test_unsupported_auth_mode(self):
# curl and simple clients handle errors a bit differently; the
# important thing is that they don't fall back to basic auth
# on an unknown mode.
with ExpectLog(gen_log, "uncaught exception", required=False):
with self.assertRaises((ValueError, HTTPError)):
self.fetch("/auth", auth_username="Aladdin",
auth_password="open sesame",
auth_mode="asdf",
raise_error=True)
def test_follow_redirect(self):
response = self.fetch("/countdown/2", follow_redirects=False)
self.assertEqual(302, response.code)
self.assertTrue(response.headers["Location"].endswith("/countdown/1"))
response = self.fetch("/countdown/2")
self.assertEqual(200, response.code)
self.assertTrue(response.effective_url.endswith("/countdown/0"))
self.assertEqual(b"Zero", response.body)
def test_credentials_in_url(self):
url = self.get_url("/auth").replace("http://", "http://me:secret@")
response = self.fetch(url)
self.assertEqual(b"Basic " + base64.b64encode(b"me:secret"),
response.body)
def test_body_encoding(self):
unicode_body = u"\xe9"
byte_body = binascii.a2b_hex(b"e9")
# unicode string in body gets converted to utf8
response = self.fetch("/echopost", method="POST", body=unicode_body,
headers={"Content-Type": "application/blah"})
self.assertEqual(response.headers["Content-Length"], "2")
self.assertEqual(response.body, utf8(unicode_body))
# byte strings pass through directly
response = self.fetch("/echopost", method="POST",
body=byte_body,
headers={"Content-Type": "application/blah"})
self.assertEqual(response.headers["Content-Length"], "1")
self.assertEqual(response.body, byte_body)
# Mixing unicode in headers and byte string bodies shouldn't
# break anything
response = self.fetch("/echopost", method="POST", body=byte_body,
headers={"Content-Type": "application/blah"},
user_agent=u"foo")
self.assertEqual(response.headers["Content-Length"], "1")
self.assertEqual(response.body, byte_body)
def test_types(self):
response = self.fetch("/hello")
self.assertEqual(type(response.body), bytes)
self.assertEqual(type(response.headers["Content-Type"]), str)
self.assertEqual(type(response.code), int)
self.assertEqual(type(response.effective_url), str)
def test_header_callback(self):
first_line = []
headers = {}
chunks = []
def header_callback(header_line):
if header_line.startswith('HTTP/1.1 101'):
# Upgrading to HTTP/2
pass
elif header_line.startswith('HTTP/'):
first_line.append(header_line)
elif header_line != '\r\n':
k, v = header_line.split(':', 1)
headers[k.lower()] = v.strip()
def streaming_callback(chunk):
# All header callbacks are run before any streaming callbacks,
# so the header data is available to process the data as it
# comes in.
self.assertEqual(headers['content-type'], 'text/html; charset=UTF-8')
chunks.append(chunk)
self.fetch('/chunk', header_callback=header_callback,
streaming_callback=streaming_callback)
self.assertEqual(len(first_line), 1, first_line)
self.assertRegexpMatches(first_line[0], 'HTTP/[0-9]\\.[0-9] 200.*\r\n')
self.assertEqual(chunks, [b'asdf', b'qwer'])
def test_header_callback_stack_context(self):
exc_info = []
def error_handler(typ, value, tb):
exc_info.append((typ, value, tb))
return True
def header_callback(header_line):
if header_line.lower().startswith('content-type:'):
1 / 0
with ignore_deprecation():
with ExceptionStackContext(error_handler):
self.fetch('/chunk', header_callback=header_callback)
self.assertEqual(len(exc_info), 1)
self.assertIs(exc_info[0][0], ZeroDivisionError)
@gen_test
def test_configure_defaults(self):
defaults = dict(user_agent='TestDefaultUserAgent', allow_ipv6=False)
# Construct a new instance of the configured client class
client = self.http_client.__class__(force_instance=True,
defaults=defaults)
try:
response = yield client.fetch(self.get_url('/user_agent'))
self.assertEqual(response.body, b'TestDefaultUserAgent')
finally:
client.close()
def test_header_types(self):
# Header values may be passed as character or utf8 byte strings,
# in a plain dictionary or an HTTPHeaders object.
# Keys must always be the native str type.
# All combinations should have the same results on the wire.
for value in [u"MyUserAgent", b"MyUserAgent"]:
for container in [dict, HTTPHeaders]:
headers = container()
headers['User-Agent'] = value
resp = self.fetch('/user_agent', headers=headers)
self.assertEqual(
resp.body, b"MyUserAgent",
"response=%r, value=%r, container=%r" %
(resp.body, value, container))
def test_multi_line_headers(self):
# Multi-line http headers are rare but rfc-allowed
# http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.2
sock, port = bind_unused_port()
with closing(sock):
@gen.coroutine
def accept_callback(conn, address):
stream = IOStream(conn)
request_data = yield stream.read_until(b"\r\n\r\n")
if b"HTTP/1." not in request_data:
self.skipTest("requires HTTP/1.x")
yield stream.write(b"""\
HTTP/1.1 200 OK
X-XSS-Protection: 1;
\tmode=block
""".replace(b"\n", b"\r\n"))
stream.close()
netutil.add_accept_handler(sock, accept_callback)
resp = self.fetch("http://127.0.0.1:%d/" % port)
resp.rethrow()
self.assertEqual(resp.headers['X-XSS-Protection'], "1; mode=block")
self.io_loop.remove_handler(sock.fileno())
def test_304_with_content_length(self):
# According to the spec 304 responses SHOULD NOT include
# Content-Length or other entity headers, but some servers do it
# anyway.
# http://www.w3.org/Protocols/rfc2616/rfc2616-sec10.html#sec10.3.5
response = self.fetch('/304_with_content_length')
self.assertEqual(response.code, 304)
self.assertEqual(response.headers['Content-Length'], '42')
def test_final_callback_stack_context(self):
# The final callback should be run outside of the httpclient's
# stack_context. We want to ensure that there is no stack_context
# between the user's callback and the IOLoop, so monkey-patch
# IOLoop.handle_callback_exception and disable the test harness's
# context with a NullContext.
# Note that this does not apply to secondary callbacks (header
# and streaming_callback), as errors there must be seen as errors
# by the http client so it can clean up the connection.
exc_info = []
def handle_callback_exception(callback):
exc_info.append(sys.exc_info())
self.stop()
self.io_loop.handle_callback_exception = handle_callback_exception
with NullContext():
with ignore_deprecation():
self.http_client.fetch(self.get_url('/hello'),
lambda response: 1 / 0)
self.wait()
self.assertEqual(exc_info[0][0], ZeroDivisionError)
@gen_test
def test_future_interface(self):
response = yield self.http_client.fetch(self.get_url('/hello'))
self.assertEqual(response.body, b'Hello world!')
@gen_test
def test_future_http_error(self):
with self.assertRaises(HTTPError) as context:
yield self.http_client.fetch(self.get_url('/notfound'))
self.assertEqual(context.exception.code, 404)
self.assertEqual(context.exception.response.code, 404)
@gen_test
def test_future_http_error_no_raise(self):
response = yield self.http_client.fetch(self.get_url('/notfound'), raise_error=False)
self.assertEqual(response.code, 404)
@gen_test
def test_reuse_request_from_response(self):
# The response.request attribute should be an HTTPRequest, not
# a _RequestProxy.
# This test uses self.http_client.fetch because self.fetch calls
# self.get_url on the input unconditionally.
url = self.get_url('/hello')
response = yield self.http_client.fetch(url)
self.assertEqual(response.request.url, url)
self.assertTrue(isinstance(response.request, HTTPRequest))
response2 = yield self.http_client.fetch(response.request)
self.assertEqual(response2.body, b'Hello world!')
def test_all_methods(self):
for method in ['GET', 'DELETE', 'OPTIONS']:
response = self.fetch('/all_methods', method=method)
self.assertEqual(response.body, utf8(method))
for method in ['POST', 'PUT', 'PATCH']:
response = self.fetch('/all_methods', method=method, body=b'')
self.assertEqual(response.body, utf8(method))
response = self.fetch('/all_methods', method='HEAD')
self.assertEqual(response.body, b'')
response = self.fetch('/all_methods', method='OTHER',
allow_nonstandard_methods=True)
self.assertEqual(response.body, b'OTHER')
def test_body_sanity_checks(self):
# These methods require a body.
for method in ('POST', 'PUT', 'PATCH'):
with self.assertRaises(ValueError) as context:
self.fetch('/all_methods', method=method, raise_error=True)
self.assertIn('must not be None', str(context.exception))
resp = self.fetch('/all_methods', method=method,
allow_nonstandard_methods=True)
self.assertEqual(resp.code, 200)
# These methods don't allow a body.
for method in ('GET', 'DELETE', 'OPTIONS'):
with self.assertRaises(ValueError) as context:
self.fetch('/all_methods', method=method, body=b'asdf', raise_error=True)
self.assertIn('must be None', str(context.exception))
# In most cases this can be overridden, but curl_httpclient
# does not allow body with a GET at all.
if method != 'GET':
resp = self.fetch('/all_methods', method=method, body=b'asdf',
allow_nonstandard_methods=True, raise_error=True)
self.assertEqual(resp.code, 200)
# This test causes odd failures with the combination of
# curl_httpclient (at least with the version of libcurl available
# on ubuntu 12.04), TwistedIOLoop, and epoll. For POST (but not PUT),
# curl decides the response came back too soon and closes the connection
# to start again. It does this *before* telling the socket callback to
# unregister the FD. Some IOLoop implementations have special kernel
# integration to discover this immediately. Tornado's IOLoops
# ignore errors on remove_handler to accommodate this behavior, but
# Twisted's reactor does not. The removeReader call fails and so
# do all future removeAll calls (which our tests do at cleanup).
#
# def test_post_307(self):
# response = self.fetch("/redirect?status=307&url=/post",
# method="POST", body=b"arg1=foo&arg2=bar")
# self.assertEqual(response.body, b"Post arg1: foo, arg2: bar")
def test_put_307(self):
response = self.fetch("/redirect?status=307&url=/put",
method="PUT", body=b"hello")
response.rethrow()
self.assertEqual(response.body, b"Put body: hello")
def test_non_ascii_header(self):
# Non-ascii headers are sent as latin1.
response = self.fetch("/set_header?k=foo&v=%E9")
response.rethrow()
self.assertEqual(response.headers["Foo"], native_str(u"\u00e9"))
def test_response_times(self):
# A few simple sanity checks of the response time fields to
# make sure they're using the right basis (between the
# wall-time and monotonic clocks).
start_time = time.time()
response = self.fetch("/hello")
response.rethrow()
self.assertGreaterEqual(response.request_time, 0)
self.assertLess(response.request_time, 1.0)
# A very crude check to make sure that start_time is based on
# wall time and not the monotonic clock.
self.assertLess(abs(response.start_time - start_time), 1.0)
for k, v in response.time_info.items():
self.assertTrue(0 <= v < 1.0, "time_info[%s] out of bounds: %s" % (k, v))
class RequestProxyTest(unittest.TestCase):
def test_request_set(self):
proxy = _RequestProxy(HTTPRequest('http://example.com/',
user_agent='foo'),
dict())
self.assertEqual(proxy.user_agent, 'foo')
def test_default_set(self):
proxy = _RequestProxy(HTTPRequest('http://example.com/'),
dict(network_interface='foo'))
self.assertEqual(proxy.network_interface, 'foo')
def test_both_set(self):
proxy = _RequestProxy(HTTPRequest('http://example.com/',
proxy_host='foo'),
dict(proxy_host='bar'))
self.assertEqual(proxy.proxy_host, 'foo')
def test_neither_set(self):
proxy = _RequestProxy(HTTPRequest('http://example.com/'),
dict())
self.assertIs(proxy.auth_username, None)
def test_bad_attribute(self):
proxy = _RequestProxy(HTTPRequest('http://example.com/'),
dict())
with self.assertRaises(AttributeError):
proxy.foo
def test_defaults_none(self):
proxy = _RequestProxy(HTTPRequest('http://example.com/'), None)
self.assertIs(proxy.auth_username, None)
class HTTPResponseTestCase(unittest.TestCase):
def test_str(self):
response = HTTPResponse(HTTPRequest('http://example.com'),
200, headers={}, buffer=BytesIO())
s = str(response)
self.assertTrue(s.startswith('HTTPResponse('))
self.assertIn('code=200', s)
class SyncHTTPClientTest(unittest.TestCase):
def setUp(self):
if IOLoop.configured_class().__name__ == 'TwistedIOLoop':
# TwistedIOLoop only supports the global reactor, so we can't have
# separate IOLoops for client and server threads.
raise unittest.SkipTest(
'Sync HTTPClient not compatible with TwistedIOLoop')
self.server_ioloop = IOLoop()
@gen.coroutine
def init_server():
sock, self.port = bind_unused_port()
app = Application([('/', HelloWorldHandler)])
self.server = HTTPServer(app)
self.server.add_socket(sock)
self.server_ioloop.run_sync(init_server)
self.server_thread = threading.Thread(target=self.server_ioloop.start)
self.server_thread.start()
self.http_client = HTTPClient()
def tearDown(self):
def stop_server():
self.server.stop()
# Delay the shutdown of the IOLoop by several iterations because
# the server may still have some cleanup work left when
# the client finishes with the response (this is noticeable
# with http/2, which leaves a Future with an unexamined
# StreamClosedError on the loop).
@gen.coroutine
def slow_stop():
# The number of iterations is difficult to predict. Typically,
# one is sufficient, although sometimes it needs more.
for i in range(5):
yield
self.server_ioloop.stop()
self.server_ioloop.add_callback(slow_stop)
self.server_ioloop.add_callback(stop_server)
self.server_thread.join()
self.http_client.close()
self.server_ioloop.close(all_fds=True)
def get_url(self, path):
return 'http://127.0.0.1:%d%s' % (self.port, path)
def test_sync_client(self):
response = self.http_client.fetch(self.get_url('/'))
self.assertEqual(b'Hello world!', response.body)
def test_sync_client_error(self):
# Synchronous HTTPClient raises errors directly; no need for
# response.rethrow()
with self.assertRaises(HTTPError) as assertion:
self.http_client.fetch(self.get_url('/notfound'))
self.assertEqual(assertion.exception.code, 404)
class HTTPRequestTestCase(unittest.TestCase):
def test_headers(self):
request = HTTPRequest('http://example.com', headers={'foo': 'bar'})
self.assertEqual(request.headers, {'foo': 'bar'})
def test_headers_setter(self):
request = HTTPRequest('http://example.com')
request.headers = {'bar': 'baz'}
self.assertEqual(request.headers, {'bar': 'baz'})
def test_null_headers_setter(self):
request = HTTPRequest('http://example.com')
request.headers = None
self.assertEqual(request.headers, {})
def test_body(self):
request = HTTPRequest('http://example.com', body='foo')
self.assertEqual(request.body, utf8('foo'))
def test_body_setter(self):
request = HTTPRequest('http://example.com')
request.body = 'foo'
self.assertEqual(request.body, utf8('foo'))
def test_if_modified_since(self):
http_date = datetime.datetime.utcnow()
request = HTTPRequest('http://example.com', if_modified_since=http_date)
self.assertEqual(request.headers,
{'If-Modified-Since': format_timestamp(http_date)})
class HTTPErrorTestCase(unittest.TestCase):
def test_copy(self):
e = HTTPError(403)
e2 = copy.copy(e)
self.assertIsNot(e, e2)
self.assertEqual(e.code, e2.code)
def test_plain_error(self):
e = HTTPError(403)
self.assertEqual(str(e), "HTTP 403: Forbidden")
self.assertEqual(repr(e), "HTTP 403: Forbidden")
def test_error_with_response(self):
resp = HTTPResponse(HTTPRequest('http://example.com/'), 403)
with self.assertRaises(HTTPError) as cm:
resp.rethrow()
e = cm.exception
self.assertEqual(str(e), "HTTP 403: Forbidden")
self.assertEqual(repr(e), "HTTP 403: Forbidden")

File diff suppressed because it is too large

lib/tornado/test/httputil_test.py Executable file
View File

@@ -0,0 +1,516 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function
from tornado.httputil import (
url_concat, parse_multipart_form_data, HTTPHeaders, format_timestamp,
HTTPServerRequest, parse_request_start_line, parse_cookie, qs_to_qsl,
HTTPInputError,
)
from tornado.escape import utf8, native_str
from tornado.util import PY3
from tornado.log import gen_log
from tornado.testing import ExpectLog
from tornado.test.util import unittest
import copy
import datetime
import logging
import pickle
import time
if PY3:
import urllib.parse as urllib_parse
else:
import urlparse as urllib_parse
class TestUrlConcat(unittest.TestCase):
def test_url_concat_no_query_params(self):
url = url_concat(
"https://localhost/path",
[('y', 'y'), ('z', 'z')],
)
self.assertEqual(url, "https://localhost/path?y=y&z=z")
def test_url_concat_encode_args(self):
url = url_concat(
"https://localhost/path",
[('y', '/y'), ('z', 'z')],
)
self.assertEqual(url, "https://localhost/path?y=%2Fy&z=z")
def test_url_concat_trailing_q(self):
url = url_concat(
"https://localhost/path?",
[('y', 'y'), ('z', 'z')],
)
self.assertEqual(url, "https://localhost/path?y=y&z=z")
def test_url_concat_q_with_no_trailing_amp(self):
url = url_concat(
"https://localhost/path?x",
[('y', 'y'), ('z', 'z')],
)
self.assertEqual(url, "https://localhost/path?x=&y=y&z=z")
def test_url_concat_trailing_amp(self):
url = url_concat(
"https://localhost/path?x&",
[('y', 'y'), ('z', 'z')],
)
self.assertEqual(url, "https://localhost/path?x=&y=y&z=z")
def test_url_concat_mult_params(self):
url = url_concat(
"https://localhost/path?a=1&b=2",
[('y', 'y'), ('z', 'z')],
)
self.assertEqual(url, "https://localhost/path?a=1&b=2&y=y&z=z")
def test_url_concat_no_params(self):
url = url_concat(
"https://localhost/path?r=1&t=2",
[],
)
self.assertEqual(url, "https://localhost/path?r=1&t=2")
def test_url_concat_none_params(self):
url = url_concat(
"https://localhost/path?r=1&t=2",
None,
)
self.assertEqual(url, "https://localhost/path?r=1&t=2")
def test_url_concat_with_frag(self):
url = url_concat(
"https://localhost/path#tab",
[('y', 'y')],
)
self.assertEqual(url, "https://localhost/path?y=y#tab")
def test_url_concat_multi_same_params(self):
url = url_concat(
"https://localhost/path",
[('y', 'y1'), ('y', 'y2')],
)
self.assertEqual(url, "https://localhost/path?y=y1&y=y2")
def test_url_concat_multi_same_query_params(self):
url = url_concat(
"https://localhost/path?r=1&r=2",
[('y', 'y')],
)
self.assertEqual(url, "https://localhost/path?r=1&r=2&y=y")
def test_url_concat_dict_params(self):
url = url_concat(
"https://localhost/path",
dict(y='y'),
)
self.assertEqual(url, "https://localhost/path?y=y")
class QsParseTest(unittest.TestCase):
def test_parsing(self):
qsstring = "a=1&b=2&a=3"
qs = urllib_parse.parse_qs(qsstring)
qsl = list(qs_to_qsl(qs))
self.assertIn(('a', '1'), qsl)
self.assertIn(('a', '3'), qsl)
self.assertIn(('b', '2'), qsl)
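# A small sketch of the flattening qs_to_qsl performs: parse_qs groups
# repeated keys into lists, and qs_to_qsl re-expands them into pairs
# (pair order across keys is not guaranteed, hence assertIn above).
def _qsl_example():
    qs = urllib_parse.parse_qs('a=1&b=2&a=3')  # {'a': ['1', '3'], 'b': ['2']}
    return sorted(qs_to_qsl(qs))  # [('a', '1'), ('a', '3'), ('b', '2')]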
class MultipartFormDataTest(unittest.TestCase):
def test_file_upload(self):
data = b"""\
--1234
Content-Disposition: form-data; name="files"; filename="ab.txt"
Foo
--1234--""".replace(b"\n", b"\r\n")
args = {}
files = {}
parse_multipart_form_data(b"1234", data, args, files)
file = files["files"][0]
self.assertEqual(file["filename"], "ab.txt")
self.assertEqual(file["body"], b"Foo")
def test_unquoted_names(self):
# quotes are optional unless special characters are present
data = b"""\
--1234
Content-Disposition: form-data; name=files; filename=ab.txt
Foo
--1234--""".replace(b"\n", b"\r\n")
args = {}
files = {}
parse_multipart_form_data(b"1234", data, args, files)
file = files["files"][0]
self.assertEqual(file["filename"], "ab.txt")
self.assertEqual(file["body"], b"Foo")
def test_special_filenames(self):
filenames = ['a;b.txt',
'a"b.txt',
'a";b.txt',
'a;"b.txt',
'a";";.txt',
'a\\"b.txt',
'a\\b.txt',
]
for filename in filenames:
logging.debug("trying filename %r", filename)
data = """\
--1234
Content-Disposition: form-data; name="files"; filename="%s"
Foo
--1234--""" % filename.replace('\\', '\\\\').replace('"', '\\"')
data = utf8(data.replace("\n", "\r\n"))
args = {}
files = {}
parse_multipart_form_data(b"1234", data, args, files)
file = files["files"][0]
self.assertEqual(file["filename"], filename)
self.assertEqual(file["body"], b"Foo")
def test_non_ascii_filename(self):
data = b"""\
--1234
Content-Disposition: form-data; name="files"; filename="ab.txt"; filename*=UTF-8''%C3%A1b.txt
Foo
--1234--""".replace(b"\n", b"\r\n")
args = {}
files = {}
parse_multipart_form_data(b"1234", data, args, files)
file = files["files"][0]
self.assertEqual(file["filename"], u"áb.txt")
self.assertEqual(file["body"], b"Foo")
def test_boundary_starts_and_ends_with_quotes(self):
        data = b'''\
--1234
Content-Disposition: form-data; name="files"; filename="ab.txt"

Foo
--1234--'''.replace(b"\n", b"\r\n")
args = {}
files = {}
parse_multipart_form_data(b'"1234"', data, args, files)
file = files["files"][0]
self.assertEqual(file["filename"], "ab.txt")
self.assertEqual(file["body"], b"Foo")
def test_missing_headers(self):
        data = b'''\
--1234

Foo
--1234--'''.replace(b"\n", b"\r\n")
args = {}
files = {}
with ExpectLog(gen_log, "multipart/form-data missing headers"):
parse_multipart_form_data(b"1234", data, args, files)
self.assertEqual(files, {})
def test_invalid_content_disposition(self):
        data = b'''\
--1234
Content-Disposition: invalid; name="files"; filename="ab.txt"

Foo
--1234--'''.replace(b"\n", b"\r\n")
args = {}
files = {}
with ExpectLog(gen_log, "Invalid multipart/form-data"):
parse_multipart_form_data(b"1234", data, args, files)
self.assertEqual(files, {})
def test_line_does_not_end_with_correct_line_break(self):
        data = b'''\
--1234
Content-Disposition: form-data; name="files"; filename="ab.txt"

Foo--1234--'''.replace(b"\n", b"\r\n")
args = {}
files = {}
with ExpectLog(gen_log, "Invalid multipart/form-data"):
parse_multipart_form_data(b"1234", data, args, files)
self.assertEqual(files, {})
def test_content_disposition_header_without_name_parameter(self):
data = b"""\
--1234
Content-Disposition: form-data; filename="ab.txt"
Foo
--1234--""".replace(b"\n", b"\r\n")
args = {}
files = {}
with ExpectLog(gen_log, "multipart/form-data value missing name"):
parse_multipart_form_data(b"1234", data, args, files)
self.assertEqual(files, {})
def test_data_after_final_boundary(self):
# The spec requires that data after the final boundary be ignored.
# http://www.w3.org/Protocols/rfc1341/7_2_Multipart.html
# In practice, some libraries include an extra CRLF after the boundary.
data = b"""\
--1234
Content-Disposition: form-data; name="files"; filename="ab.txt"
Foo
--1234--
""".replace(b"\n", b"\r\n")
args = {}
files = {}
parse_multipart_form_data(b"1234", data, args, files)
file = files["files"][0]
self.assertEqual(file["filename"], "ab.txt")
self.assertEqual(file["body"], b"Foo")
class HTTPHeadersTest(unittest.TestCase):
def test_multi_line(self):
# Lines beginning with whitespace are appended to the previous line
# with any leading whitespace replaced by a single space.
# Note that while multi-line headers are a part of the HTTP spec,
# their use is strongly discouraged.
data = """\
Foo: bar
baz
Asdf: qwer
\tzxcv
Foo: even
more
lines
""".replace("\n", "\r\n")
headers = HTTPHeaders.parse(data)
self.assertEqual(headers["asdf"], "qwer zxcv")
self.assertEqual(headers.get_list("asdf"), ["qwer zxcv"])
self.assertEqual(headers["Foo"], "bar baz,even more lines")
self.assertEqual(headers.get_list("foo"), ["bar baz", "even more lines"])
self.assertEqual(sorted(list(headers.get_all())),
[("Asdf", "qwer zxcv"),
("Foo", "bar baz"),
("Foo", "even more lines")])
def test_malformed_continuation(self):
# If the first line starts with whitespace, it's a
# continuation line with nothing to continue, so reject it
# (with a proper error).
data = " Foo: bar"
self.assertRaises(HTTPInputError, HTTPHeaders.parse, data)
def test_unicode_newlines(self):
# Ensure that only \r\n is recognized as a header separator, and not
# the other newline-like unicode characters.
# Characters that are likely to be problematic can be found in
# http://unicode.org/standard/reports/tr13/tr13-5.html
# and cpython's unicodeobject.c (which defines the implementation
# of unicode_type.splitlines(), and uses a different list than TR13).
newlines = [
u'\u001b', # VERTICAL TAB
u'\u001c', # FILE SEPARATOR
u'\u001d', # GROUP SEPARATOR
u'\u001e', # RECORD SEPARATOR
u'\u0085', # NEXT LINE
u'\u2028', # LINE SEPARATOR
u'\u2029', # PARAGRAPH SEPARATOR
]
for newline in newlines:
# Try the utf8 and latin1 representations of each newline
for encoding in ['utf8', 'latin1']:
try:
try:
encoded = newline.encode(encoding)
except UnicodeEncodeError:
# Some chars cannot be represented in latin1
continue
data = b'Cookie: foo=' + encoded + b'bar'
# parse() wants a native_str, so decode through latin1
# in the same way the real parser does.
headers = HTTPHeaders.parse(
native_str(data.decode('latin1')))
expected = [('Cookie', 'foo=' +
native_str(encoded.decode('latin1')) + 'bar')]
self.assertEqual(
expected, list(headers.get_all()))
except Exception:
gen_log.warning("failed while trying %r in %s",
newline, encoding)
raise
def test_optional_cr(self):
# Both CRLF and LF should be accepted as separators. CR should not be
# part of the data when followed by LF, but it is a normal char
# otherwise (or should bare CR be an error?)
headers = HTTPHeaders.parse(
'CRLF: crlf\r\nLF: lf\nCR: cr\rMore: more\r\n')
self.assertEqual(sorted(headers.get_all()),
[('Cr', 'cr\rMore: more'),
('Crlf', 'crlf'),
('Lf', 'lf'),
])
def test_copy(self):
all_pairs = [('A', '1'), ('A', '2'), ('B', 'c')]
h1 = HTTPHeaders()
for k, v in all_pairs:
h1.add(k, v)
h2 = h1.copy()
h3 = copy.copy(h1)
h4 = copy.deepcopy(h1)
for headers in [h1, h2, h3, h4]:
# All the copies are identical, no matter how they were
# constructed.
self.assertEqual(list(sorted(headers.get_all())), all_pairs)
for headers in [h2, h3, h4]:
            # Neither the dict nor its member lists are reused.
self.assertIsNot(headers, h1)
self.assertIsNot(headers.get_list('A'), h1.get_list('A'))
def test_pickle_roundtrip(self):
headers = HTTPHeaders()
headers.add('Set-Cookie', 'a=b')
headers.add('Set-Cookie', 'c=d')
headers.add('Content-Type', 'text/html')
pickled = pickle.dumps(headers)
unpickled = pickle.loads(pickled)
self.assertEqual(sorted(headers.get_all()), sorted(unpickled.get_all()))
self.assertEqual(sorted(headers.items()), sorted(unpickled.items()))
def test_setdefault(self):
headers = HTTPHeaders()
headers['foo'] = 'bar'
# If a value is present, setdefault returns it without changes.
self.assertEqual(headers.setdefault('foo', 'baz'), 'bar')
self.assertEqual(headers['foo'], 'bar')
# If a value is not present, setdefault sets it for future use.
self.assertEqual(headers.setdefault('quux', 'xyzzy'), 'xyzzy')
self.assertEqual(headers['quux'], 'xyzzy')
self.assertEqual(sorted(headers.get_all()), [('Foo', 'bar'), ('Quux', 'xyzzy')])
def test_string(self):
headers = HTTPHeaders()
headers.add("Foo", "1")
headers.add("Foo", "2")
headers.add("Foo", "3")
headers2 = HTTPHeaders.parse(str(headers))
        self.assertEqual(headers, headers2)
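# A short usage sketch (hypothetical helper, not part of the suite):
# HTTPHeaders is a case-insensitive multidict; __getitem__ joins repeated
# values with commas while get_list() preserves them individually.
def _headers_demo():
    h = HTTPHeaders()
    h.add("Accept", "text/html")
    h.add("accept", "application/json")
    assert h["ACCEPT"] == "text/html,application/json"
    assert h.get_list("Accept") == ["text/html", "application/json"]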
class FormatTimestampTest(unittest.TestCase):
# Make sure that all the input types are supported.
TIMESTAMP = 1359312200.503611
EXPECTED = 'Sun, 27 Jan 2013 18:43:20 GMT'
def check(self, value):
self.assertEqual(format_timestamp(value), self.EXPECTED)
def test_unix_time_float(self):
self.check(self.TIMESTAMP)
def test_unix_time_int(self):
self.check(int(self.TIMESTAMP))
def test_struct_time(self):
self.check(time.gmtime(self.TIMESTAMP))
def test_time_tuple(self):
tup = tuple(time.gmtime(self.TIMESTAMP))
self.assertEqual(9, len(tup))
self.check(tup)
def test_datetime(self):
self.check(datetime.datetime.utcfromtimestamp(self.TIMESTAMP))
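# A short sketch (hypothetical helper, not part of the suite): every input
# type accepted by format_timestamp renders as the same HTTP date string.
def _format_timestamp_demo():
    ts = 1359312200.503611
    expected = 'Sun, 27 Jan 2013 18:43:20 GMT'
    assert format_timestamp(ts) == expected
    assert format_timestamp(time.gmtime(ts)) == expected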
# HTTPServerRequest is mainly tested incidentally to the server itself,
# but this tests the parts of the class that can be tested in isolation.
class HTTPServerRequestTest(unittest.TestCase):
def test_default_constructor(self):
# All parameters are formally optional, but uri is required
# (and has been for some time). This test ensures that no
# more required parameters slip in.
HTTPServerRequest(uri='/')
def test_body_is_a_byte_string(self):
        request = HTTPServerRequest(uri='/')
        self.assertIsInstance(request.body, bytes)
def test_repr_does_not_contain_headers(self):
request = HTTPServerRequest(uri='/', headers={'Canary': 'Coal Mine'})
self.assertTrue('Canary' not in repr(request))
class ParseRequestStartLineTest(unittest.TestCase):
METHOD = "GET"
PATH = "/foo"
VERSION = "HTTP/1.1"
def test_parse_request_start_line(self):
start_line = " ".join([self.METHOD, self.PATH, self.VERSION])
parsed_start_line = parse_request_start_line(start_line)
self.assertEqual(parsed_start_line.method, self.METHOD)
self.assertEqual(parsed_start_line.path, self.PATH)
self.assertEqual(parsed_start_line.version, self.VERSION)
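# A short sketch (hypothetical helper, not part of the suite): the parsed
# start line is a namedtuple, so it also unpacks positionally.
def _start_line_demo():
    method, path, version = parse_request_start_line("POST /submit HTTP/1.0")
    assert (method, path, version) == ("POST", "/submit", "HTTP/1.0")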
class ParseCookieTest(unittest.TestCase):
# These tests copied from Django:
# https://github.com/django/django/pull/6277/commits/da810901ada1cae9fc1f018f879f11a7fb467b28
def test_python_cookies(self):
"""
Test cases copied from Python's Lib/test/test_http_cookies.py
"""
self.assertEqual(parse_cookie('chips=ahoy; vienna=finger'),
{'chips': 'ahoy', 'vienna': 'finger'})
# Here parse_cookie() differs from Python's cookie parsing in that it
# treats all semicolons as delimiters, even within quotes.
self.assertEqual(
parse_cookie('keebler="E=mc2; L=\\"Loves\\"; fudge=\\012;"'),
{'keebler': '"E=mc2', 'L': '\\"Loves\\"', 'fudge': '\\012', '': '"'}
)
# Illegal cookies that have an '=' char in an unquoted value.
self.assertEqual(parse_cookie('keebler=E=mc2'), {'keebler': 'E=mc2'})
# Cookies with ':' character in their name.
self.assertEqual(parse_cookie('key:term=value:term'), {'key:term': 'value:term'})
# Cookies with '[' and ']'.
self.assertEqual(parse_cookie('a=b; c=[; d=r; f=h'),
{'a': 'b', 'c': '[', 'd': 'r', 'f': 'h'})
def test_cookie_edgecases(self):
# Cookies that RFC6265 allows.
self.assertEqual(parse_cookie('a=b; Domain=example.com'),
{'a': 'b', 'Domain': 'example.com'})
# parse_cookie() has historically kept only the last cookie with the
# same name.
self.assertEqual(parse_cookie('a=b; h=i; a=c'), {'a': 'c', 'h': 'i'})
def test_invalid_cookies(self):
"""
Cookie strings that go against RFC6265 but browsers will send if set
via document.cookie.
"""
# Chunks without an equals sign appear as unnamed values per
# https://bugzilla.mozilla.org/show_bug.cgi?id=169091
self.assertIn('django_language',
parse_cookie('abc=def; unnamed; django_language=en').keys())
        # Even a double quote may be an unnamed value.
self.assertEqual(parse_cookie('a=b; "; c=d'), {'a': 'b', '': '"', 'c': 'd'})
# Spaces in names and values, and an equals sign in values.
self.assertEqual(parse_cookie('a b c=d e = f; gh=i'), {'a b c': 'd e = f', 'gh': 'i'})
# More characters the spec forbids.
self.assertEqual(parse_cookie('a b,c<>@:/[]?{}=d " =e,f g'),
{'a b,c<>@:/[]?{}': 'd " =e,f g'})
# Unicode characters. The spec only allows ASCII.
self.assertEqual(parse_cookie('saint=André Bessette'),
{'saint': native_str('André Bessette')})
# Browsers don't send extra whitespace or semicolons in Cookie headers,
# but parse_cookie() should parse whitespace the same way
# document.cookie parses whitespace.
self.assertEqual(parse_cookie(' = b ; ; = ; c = ; '), {'': 'b', 'c': ''})
73
lib/tornado/test/import_test.py Executable file
@@ -0,0 +1,73 @@
# flake8: noqa
from __future__ import absolute_import, division, print_function
import subprocess
import sys
from tornado.test.util import unittest
_import_everything = b"""
# The event loop is not fork-safe, and it's easy to initialize an asyncio.Future
# at startup, which in turn creates the default event loop and prevents forking.
# Explicitly disallow the default event loop so that an error will be raised
# if something tries to touch it.
try:
import asyncio
except ImportError:
pass
else:
asyncio.set_event_loop(None)
import tornado.auth
import tornado.autoreload
import tornado.concurrent
import tornado.escape
import tornado.gen
import tornado.http1connection
import tornado.httpclient
import tornado.httpserver
import tornado.httputil
import tornado.ioloop
import tornado.iostream
import tornado.locale
import tornado.log
import tornado.netutil
import tornado.options
import tornado.process
import tornado.simple_httpclient
import tornado.stack_context
import tornado.tcpserver
import tornado.tcpclient
import tornado.template
import tornado.testing
import tornado.util
import tornado.web
import tornado.websocket
import tornado.wsgi
try:
import pycurl
except ImportError:
pass
else:
import tornado.curl_httpclient
"""
class ImportTest(unittest.TestCase):
def test_import_everything(self):
# Test that all Tornado modules can be imported without side effects,
# specifically without initializing the default asyncio event loop.
        # Since we can't tell which modules may have already been imported
# in our process, do it in a subprocess for a clean slate.
proc = subprocess.Popen([sys.executable], stdin=subprocess.PIPE)
proc.communicate(_import_everything)
self.assertEqual(proc.returncode, 0)
def test_import_aliases(self):
# Ensure we don't delete formerly-documented aliases accidentally.
import tornado.ioloop
import tornado.gen
import tornado.util
self.assertIs(tornado.ioloop.TimeoutError, tornado.util.TimeoutError)
self.assertIs(tornado.gen.TimeoutError, tornado.util.TimeoutError)
942
lib/tornado/test/ioloop_test.py Executable file
@@ -0,0 +1,942 @@
from __future__ import absolute_import, division, print_function
from concurrent.futures import ThreadPoolExecutor
import contextlib
import datetime
import functools
import socket
import subprocess
import sys
import threading
import time
import types
try:
from unittest import mock # type: ignore
except ImportError:
try:
import mock # type: ignore
except ImportError:
mock = None
from tornado.escape import native_str
from tornado import gen
from tornado.ioloop import IOLoop, TimeoutError, PollIOLoop, PeriodicCallback
from tornado.log import app_log
from tornado.platform.select import _Select
from tornado.stack_context import ExceptionStackContext, StackContext, wrap, NullContext
from tornado.testing import AsyncTestCase, bind_unused_port, ExpectLog, gen_test
from tornado.test.util import (unittest, skipIfNonUnix, skipOnTravis,
skipBefore35, exec_test, ignore_deprecation)
try:
from concurrent import futures
except ImportError:
futures = None
try:
import asyncio
except ImportError:
asyncio = None
try:
import twisted
except ImportError:
twisted = None
class FakeTimeSelect(_Select):
def __init__(self):
self._time = 1000
super(FakeTimeSelect, self).__init__()
def time(self):
return self._time
def sleep(self, t):
self._time += t
def poll(self, timeout):
events = super(FakeTimeSelect, self).poll(0)
if events:
return events
self._time += timeout
return []
class FakeTimeIOLoop(PollIOLoop):
"""IOLoop implementation with a fake and deterministic clock.
The clock advances as needed to trigger timeouts immediately.
For use when testing code that involves the passage of time
and no external dependencies.
"""
def initialize(self):
self.fts = FakeTimeSelect()
super(FakeTimeIOLoop, self).initialize(impl=self.fts,
time_func=self.fts.time)
def sleep(self, t):
"""Simulate a blocking sleep by advancing the clock."""
self.fts.sleep(t)
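# A minimal sketch (hypothetical helper, not part of the suite; mirrors how
# TestPeriodicCallback below drives FakeTimeIOLoop, in configurations where
# that loop is available): because poll() advances the fake clock, even an
# hour-long timeout fires without real waiting.
def _fake_time_demo():
    loop = FakeTimeIOLoop()
    loop.make_current()
    try:
        fired = []
        loop.call_later(3600, lambda: (fired.append(loop.time()), loop.stop()))
        loop.start()  # returns almost immediately in real time
        assert fired and fired[0] >= 1000 + 3600
    finally:
        loop.close()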
class TestIOLoop(AsyncTestCase):
def test_add_callback_return_sequence(self):
# A callback returning {} or [] shouldn't spin the CPU, see Issue #1803.
self.calls = 0
loop = self.io_loop
test = self
old_add_callback = loop.add_callback
def add_callback(self, callback, *args, **kwargs):
test.calls += 1
old_add_callback(callback, *args, **kwargs)
loop.add_callback = types.MethodType(add_callback, loop)
loop.add_callback(lambda: {})
loop.add_callback(lambda: [])
loop.add_timeout(datetime.timedelta(milliseconds=50), loop.stop)
loop.start()
self.assertLess(self.calls, 10)
@skipOnTravis
def test_add_callback_wakeup(self):
# Make sure that add_callback from inside a running IOLoop
# wakes up the IOLoop immediately instead of waiting for a timeout.
def callback():
self.called = True
self.stop()
def schedule_callback():
self.called = False
self.io_loop.add_callback(callback)
# Store away the time so we can check if we woke up immediately
self.start_time = time.time()
self.io_loop.add_timeout(self.io_loop.time(), schedule_callback)
self.wait()
self.assertAlmostEqual(time.time(), self.start_time, places=2)
self.assertTrue(self.called)
@skipOnTravis
def test_add_callback_wakeup_other_thread(self):
def target():
# sleep a bit to let the ioloop go into its poll loop
time.sleep(0.01)
self.stop_time = time.time()
self.io_loop.add_callback(self.stop)
thread = threading.Thread(target=target)
self.io_loop.add_callback(thread.start)
self.wait()
delta = time.time() - self.stop_time
self.assertLess(delta, 0.1)
thread.join()
def test_add_timeout_timedelta(self):
self.io_loop.add_timeout(datetime.timedelta(microseconds=1), self.stop)
self.wait()
def test_multiple_add(self):
sock, port = bind_unused_port()
try:
self.io_loop.add_handler(sock.fileno(), lambda fd, events: None,
IOLoop.READ)
# Attempting to add the same handler twice fails
# (with a platform-dependent exception)
self.assertRaises(Exception, self.io_loop.add_handler,
sock.fileno(), lambda fd, events: None,
IOLoop.READ)
finally:
self.io_loop.remove_handler(sock.fileno())
sock.close()
def test_remove_without_add(self):
# remove_handler should not throw an exception if called on an fd
        # that was never added.
sock, port = bind_unused_port()
try:
self.io_loop.remove_handler(sock.fileno())
finally:
sock.close()
def test_add_callback_from_signal(self):
# cheat a little bit and just run this normally, since we can't
# easily simulate the races that happen with real signal handlers
self.io_loop.add_callback_from_signal(self.stop)
self.wait()
def test_add_callback_from_signal_other_thread(self):
# Very crude test, just to make sure that we cover this case.
# This also happens to be the first test where we run an IOLoop in
# a non-main thread.
other_ioloop = IOLoop()
thread = threading.Thread(target=other_ioloop.start)
thread.start()
other_ioloop.add_callback_from_signal(other_ioloop.stop)
thread.join()
other_ioloop.close()
def test_add_callback_while_closing(self):
# add_callback should not fail if it races with another thread
# closing the IOLoop. The callbacks are dropped silently
# without executing.
closing = threading.Event()
def target():
other_ioloop.add_callback(other_ioloop.stop)
other_ioloop.start()
closing.set()
other_ioloop.close(all_fds=True)
other_ioloop = IOLoop()
thread = threading.Thread(target=target)
thread.start()
closing.wait()
for i in range(1000):
other_ioloop.add_callback(lambda: None)
def test_handle_callback_exception(self):
# IOLoop.handle_callback_exception can be overridden to catch
# exceptions in callbacks.
def handle_callback_exception(callback):
self.assertIs(sys.exc_info()[0], ZeroDivisionError)
self.stop()
self.io_loop.handle_callback_exception = handle_callback_exception
with NullContext():
# remove the test StackContext that would see this uncaught
# exception as a test failure.
self.io_loop.add_callback(lambda: 1 / 0)
self.wait()
@skipIfNonUnix # just because socketpair is so convenient
def test_read_while_writeable(self):
# Ensure that write events don't come in while we're waiting for
# a read and haven't asked for writeability. (the reverse is
# difficult to test for)
client, server = socket.socketpair()
try:
def handler(fd, events):
self.assertEqual(events, IOLoop.READ)
self.stop()
self.io_loop.add_handler(client.fileno(), handler, IOLoop.READ)
self.io_loop.add_timeout(self.io_loop.time() + 0.01,
functools.partial(server.send, b'asdf'))
self.wait()
self.io_loop.remove_handler(client.fileno())
finally:
client.close()
server.close()
def test_remove_timeout_after_fire(self):
# It is not an error to call remove_timeout after it has run.
handle = self.io_loop.add_timeout(self.io_loop.time(), self.stop)
self.wait()
self.io_loop.remove_timeout(handle)
def test_remove_timeout_cleanup(self):
# Add and remove enough callbacks to trigger cleanup.
# Not a very thorough test, but it ensures that the cleanup code
# gets executed and doesn't blow up. This test is only really useful
# on PollIOLoop subclasses, but it should run silently on any
# implementation.
for i in range(2000):
timeout = self.io_loop.add_timeout(self.io_loop.time() + 3600,
lambda: None)
self.io_loop.remove_timeout(timeout)
# HACK: wait two IOLoop iterations for the GC to happen.
self.io_loop.add_callback(lambda: self.io_loop.add_callback(self.stop))
self.wait()
def test_remove_timeout_from_timeout(self):
calls = [False, False]
# Schedule several callbacks and wait for them all to come due at once.
# t2 should be cancelled by t1, even though it is already scheduled to
# be run before the ioloop even looks at it.
now = self.io_loop.time()
def t1():
calls[0] = True
self.io_loop.remove_timeout(t2_handle)
self.io_loop.add_timeout(now + 0.01, t1)
def t2():
calls[1] = True
t2_handle = self.io_loop.add_timeout(now + 0.02, t2)
self.io_loop.add_timeout(now + 0.03, self.stop)
time.sleep(0.03)
self.wait()
self.assertEqual(calls, [True, False])
def test_timeout_with_arguments(self):
# This tests that all the timeout methods pass through *args correctly.
results = []
self.io_loop.add_timeout(self.io_loop.time(), results.append, 1)
self.io_loop.add_timeout(datetime.timedelta(seconds=0),
results.append, 2)
self.io_loop.call_at(self.io_loop.time(), results.append, 3)
self.io_loop.call_later(0, results.append, 4)
self.io_loop.call_later(0, self.stop)
self.wait()
# The asyncio event loop does not guarantee the order of these
# callbacks, but PollIOLoop does.
self.assertEqual(sorted(results), [1, 2, 3, 4])
def test_add_timeout_return(self):
# All the timeout methods return non-None handles that can be
# passed to remove_timeout.
handle = self.io_loop.add_timeout(self.io_loop.time(), lambda: None)
self.assertFalse(handle is None)
self.io_loop.remove_timeout(handle)
def test_call_at_return(self):
handle = self.io_loop.call_at(self.io_loop.time(), lambda: None)
self.assertFalse(handle is None)
self.io_loop.remove_timeout(handle)
def test_call_later_return(self):
handle = self.io_loop.call_later(0, lambda: None)
self.assertFalse(handle is None)
self.io_loop.remove_timeout(handle)
def test_close_file_object(self):
"""When a file object is used instead of a numeric file descriptor,
        the object should be closed (by IOLoop.close(all_fds=True)),
not just the fd.
"""
# Use a socket since they are supported by IOLoop on all platforms.
# Unfortunately, sockets don't support the .closed attribute for
# inspecting their close status, so we must use a wrapper.
class SocketWrapper(object):
def __init__(self, sockobj):
self.sockobj = sockobj
self.closed = False
def fileno(self):
return self.sockobj.fileno()
def close(self):
self.closed = True
self.sockobj.close()
sockobj, port = bind_unused_port()
socket_wrapper = SocketWrapper(sockobj)
io_loop = IOLoop()
io_loop.add_handler(socket_wrapper, lambda fd, events: None,
IOLoop.READ)
io_loop.close(all_fds=True)
self.assertTrue(socket_wrapper.closed)
def test_handler_callback_file_object(self):
"""The handler callback receives the same fd object it passed in."""
server_sock, port = bind_unused_port()
fds = []
def handle_connection(fd, events):
fds.append(fd)
conn, addr = server_sock.accept()
conn.close()
self.stop()
self.io_loop.add_handler(server_sock, handle_connection, IOLoop.READ)
with contextlib.closing(socket.socket()) as client_sock:
client_sock.connect(('127.0.0.1', port))
self.wait()
self.io_loop.remove_handler(server_sock)
self.io_loop.add_handler(server_sock.fileno(), handle_connection,
IOLoop.READ)
with contextlib.closing(socket.socket()) as client_sock:
client_sock.connect(('127.0.0.1', port))
self.wait()
self.assertIs(fds[0], server_sock)
self.assertEqual(fds[1], server_sock.fileno())
self.io_loop.remove_handler(server_sock.fileno())
server_sock.close()
def test_mixed_fd_fileobj(self):
server_sock, port = bind_unused_port()
def f(fd, events):
pass
self.io_loop.add_handler(server_sock, f, IOLoop.READ)
with self.assertRaises(Exception):
# The exact error is unspecified - some implementations use
# IOError, others use ValueError.
self.io_loop.add_handler(server_sock.fileno(), f, IOLoop.READ)
self.io_loop.remove_handler(server_sock.fileno())
server_sock.close()
def test_reentrant(self):
"""Calling start() twice should raise an error, not deadlock."""
returned_from_start = [False]
got_exception = [False]
def callback():
try:
self.io_loop.start()
returned_from_start[0] = True
except Exception:
got_exception[0] = True
self.stop()
self.io_loop.add_callback(callback)
self.wait()
self.assertTrue(got_exception[0])
self.assertFalse(returned_from_start[0])
def test_exception_logging(self):
"""Uncaught exceptions get logged by the IOLoop."""
# Use a NullContext to keep the exception from being caught by
# AsyncTestCase.
with NullContext():
self.io_loop.add_callback(lambda: 1 / 0)
self.io_loop.add_callback(self.stop)
with ExpectLog(app_log, "Exception in callback"):
self.wait()
def test_exception_logging_future(self):
"""The IOLoop examines exceptions from Futures and logs them."""
with NullContext():
@gen.coroutine
def callback():
self.io_loop.add_callback(self.stop)
1 / 0
self.io_loop.add_callback(callback)
with ExpectLog(app_log, "Exception in callback"):
self.wait()
@skipBefore35
def test_exception_logging_native_coro(self):
"""The IOLoop examines exceptions from awaitables and logs them."""
namespace = exec_test(globals(), locals(), """
async def callback():
# Stop the IOLoop two iterations after raising an exception
# to give the exception time to be logged.
self.io_loop.add_callback(self.io_loop.add_callback, self.stop)
1 / 0
""")
with NullContext():
self.io_loop.add_callback(namespace["callback"])
with ExpectLog(app_log, "Exception in callback"):
self.wait()
def test_spawn_callback(self):
with ignore_deprecation():
# An added callback runs in the test's stack_context, so will be
# re-raised in wait().
self.io_loop.add_callback(lambda: 1 / 0)
with self.assertRaises(ZeroDivisionError):
self.wait()
# A spawned callback is run directly on the IOLoop, so it will be
# logged without stopping the test.
self.io_loop.spawn_callback(lambda: 1 / 0)
self.io_loop.add_callback(self.stop)
with ExpectLog(app_log, "Exception in callback"):
self.wait()
@skipIfNonUnix
def test_remove_handler_from_handler(self):
# Create two sockets with simultaneous read events.
client, server = socket.socketpair()
try:
client.send(b'abc')
server.send(b'abc')
# After reading from one fd, remove the other from the IOLoop.
chunks = []
def handle_read(fd, events):
chunks.append(fd.recv(1024))
if fd is client:
self.io_loop.remove_handler(server)
else:
self.io_loop.remove_handler(client)
self.io_loop.add_handler(client, handle_read, self.io_loop.READ)
self.io_loop.add_handler(server, handle_read, self.io_loop.READ)
self.io_loop.call_later(0.1, self.stop)
self.wait()
# Only one fd was read; the other was cleanly removed.
self.assertEqual(chunks, [b'abc'])
finally:
client.close()
server.close()
@gen_test
def test_init_close_race(self):
# Regression test for #2367
def f():
for i in range(10):
loop = IOLoop()
loop.close()
yield gen.multi([self.io_loop.run_in_executor(None, f) for i in range(2)])
# Deliberately not a subclass of AsyncTestCase so the IOLoop isn't
# automatically set as current.
class TestIOLoopCurrent(unittest.TestCase):
def setUp(self):
self.io_loop = None
IOLoop.clear_current()
def tearDown(self):
if self.io_loop is not None:
self.io_loop.close()
def test_default_current(self):
self.io_loop = IOLoop()
# The first IOLoop with default arguments is made current.
self.assertIs(self.io_loop, IOLoop.current())
# A second IOLoop can be created but is not made current.
io_loop2 = IOLoop()
self.assertIs(self.io_loop, IOLoop.current())
io_loop2.close()
def test_non_current(self):
self.io_loop = IOLoop(make_current=False)
# The new IOLoop is not initially made current.
self.assertIsNone(IOLoop.current(instance=False))
# Starting the IOLoop makes it current, and stopping the loop
# makes it non-current. This process is repeatable.
for i in range(3):
def f():
self.current_io_loop = IOLoop.current()
self.io_loop.stop()
self.io_loop.add_callback(f)
self.io_loop.start()
self.assertIs(self.current_io_loop, self.io_loop)
# Now that the loop is stopped, it is no longer current.
self.assertIsNone(IOLoop.current(instance=False))
def test_force_current(self):
self.io_loop = IOLoop(make_current=True)
self.assertIs(self.io_loop, IOLoop.current())
with self.assertRaises(RuntimeError):
# A second make_current=True construction cannot succeed.
IOLoop(make_current=True)
# current() was not affected by the failed construction.
self.assertIs(self.io_loop, IOLoop.current())
class TestIOLoopCurrentAsync(AsyncTestCase):
@gen_test
def test_clear_without_current(self):
# If there is no current IOLoop, clear_current is a no-op (but
# should not fail). Use a thread so we see the threading.Local
# in a pristine state.
with ThreadPoolExecutor(1) as e:
yield e.submit(IOLoop.clear_current)
class TestIOLoopAddCallback(AsyncTestCase):
def setUp(self):
super(TestIOLoopAddCallback, self).setUp()
self.active_contexts = []
def add_callback(self, callback, *args, **kwargs):
self.io_loop.add_callback(callback, *args, **kwargs)
@contextlib.contextmanager
def context(self, name):
self.active_contexts.append(name)
yield
self.assertEqual(self.active_contexts.pop(), name)
def test_pre_wrap(self):
# A pre-wrapped callback is run in the context in which it was
# wrapped, not when it was added to the IOLoop.
def f1():
self.assertIn('c1', self.active_contexts)
self.assertNotIn('c2', self.active_contexts)
self.stop()
with ignore_deprecation():
with StackContext(functools.partial(self.context, 'c1')):
wrapped = wrap(f1)
with StackContext(functools.partial(self.context, 'c2')):
self.add_callback(wrapped)
self.wait()
def test_pre_wrap_with_args(self):
# Same as test_pre_wrap, but the function takes arguments.
# Implementation note: The function must not be wrapped in a
# functools.partial until after it has been passed through
# stack_context.wrap
def f1(foo, bar):
self.assertIn('c1', self.active_contexts)
self.assertNotIn('c2', self.active_contexts)
self.stop((foo, bar))
with ignore_deprecation():
with StackContext(functools.partial(self.context, 'c1')):
wrapped = wrap(f1)
with StackContext(functools.partial(self.context, 'c2')):
self.add_callback(wrapped, 1, bar=2)
result = self.wait()
self.assertEqual(result, (1, 2))
class TestIOLoopAddCallbackFromSignal(TestIOLoopAddCallback):
# Repeat the add_callback tests using add_callback_from_signal
def add_callback(self, callback, *args, **kwargs):
self.io_loop.add_callback_from_signal(callback, *args, **kwargs)
@unittest.skipIf(futures is None, "futures module not present")
class TestIOLoopFutures(AsyncTestCase):
def test_add_future_threads(self):
with futures.ThreadPoolExecutor(1) as pool:
self.io_loop.add_future(pool.submit(lambda: None),
lambda future: self.stop(future))
future = self.wait()
self.assertTrue(future.done())
self.assertTrue(future.result() is None)
def test_add_future_stack_context(self):
ready = threading.Event()
def task():
# we must wait for the ioloop callback to be scheduled before
# the task completes to ensure that add_future adds the callback
# asynchronously (which is the scenario in which capturing
# the stack_context matters)
ready.wait(1)
assert ready.isSet(), "timed out"
raise Exception("worker")
def callback(future):
self.future = future
raise Exception("callback")
def handle_exception(typ, value, traceback):
self.exception = value
self.stop()
return True
# stack_context propagates to the ioloop callback, but the worker
# task just has its exceptions caught and saved in the Future.
with ignore_deprecation():
with futures.ThreadPoolExecutor(1) as pool:
with ExceptionStackContext(handle_exception):
self.io_loop.add_future(pool.submit(task), callback)
ready.set()
self.wait()
self.assertEqual(self.exception.args[0], "callback")
self.assertEqual(self.future.exception().args[0], "worker")
@gen_test
def test_run_in_executor_gen(self):
event1 = threading.Event()
event2 = threading.Event()
def sync_func(self_event, other_event):
self_event.set()
other_event.wait()
# Note that return value doesn't actually do anything,
# it is just passed through to our final assertion to
# make sure it is passed through properly.
return self_event
# Run two synchronous functions, which would deadlock if not
# run in parallel.
res = yield [
IOLoop.current().run_in_executor(None, sync_func, event1, event2),
IOLoop.current().run_in_executor(None, sync_func, event2, event1)
]
self.assertEqual([event1, event2], res)
@skipBefore35
@gen_test
def test_run_in_executor_native(self):
event1 = threading.Event()
event2 = threading.Event()
def sync_func(self_event, other_event):
self_event.set()
other_event.wait()
return self_event
# Go through an async wrapper to ensure that the result of
# run_in_executor works with await and not just gen.coroutine
        # (simply passing the underlying concurrent future would do that).
namespace = exec_test(globals(), locals(), """
async def async_wrapper(self_event, other_event):
return await IOLoop.current().run_in_executor(
None, sync_func, self_event, other_event)
""")
res = yield [
namespace["async_wrapper"](event1, event2),
namespace["async_wrapper"](event2, event1)
]
self.assertEqual([event1, event2], res)
@gen_test
def test_set_default_executor(self):
count = [0]
class MyExecutor(futures.ThreadPoolExecutor):
def submit(self, func, *args):
count[0] += 1
return super(MyExecutor, self).submit(func, *args)
event = threading.Event()
def sync_func():
event.set()
executor = MyExecutor(1)
loop = IOLoop.current()
loop.set_default_executor(executor)
yield loop.run_in_executor(None, sync_func)
self.assertEqual(1, count[0])
self.assertTrue(event.is_set())
class TestIOLoopRunSync(unittest.TestCase):
def setUp(self):
self.io_loop = IOLoop()
def tearDown(self):
self.io_loop.close()
def test_sync_result(self):
with self.assertRaises(gen.BadYieldError):
self.io_loop.run_sync(lambda: 42)
def test_sync_exception(self):
with self.assertRaises(ZeroDivisionError):
self.io_loop.run_sync(lambda: 1 / 0)
def test_async_result(self):
@gen.coroutine
def f():
yield gen.moment
raise gen.Return(42)
self.assertEqual(self.io_loop.run_sync(f), 42)
def test_async_exception(self):
@gen.coroutine
def f():
yield gen.moment
1 / 0
with self.assertRaises(ZeroDivisionError):
self.io_loop.run_sync(f)
def test_current(self):
def f():
self.assertIs(IOLoop.current(), self.io_loop)
self.io_loop.run_sync(f)
def test_timeout(self):
@gen.coroutine
def f():
yield gen.sleep(1)
self.assertRaises(TimeoutError, self.io_loop.run_sync, f, timeout=0.01)
@skipBefore35
def test_native_coroutine(self):
@gen.coroutine
def f1():
yield gen.moment
namespace = exec_test(globals(), locals(), """
async def f2():
await f1()
""")
self.io_loop.run_sync(namespace['f2'])
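# A minimal sketch (hypothetical helper, not part of the suite) of the
# run_sync entry-point pattern exercised above: run a coroutine to
# completion from synchronous code on a private loop.
def _run_sync_demo():
    @gen.coroutine
    def main():
        yield gen.moment
        raise gen.Return('done')
    loop = IOLoop()
    try:
        assert loop.run_sync(main) == 'done'
    finally:
        loop.close()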
@unittest.skipIf(asyncio is not None,
'IOLoop configuration not available')
class TestPeriodicCallback(unittest.TestCase):
def setUp(self):
self.io_loop = FakeTimeIOLoop()
self.io_loop.make_current()
def tearDown(self):
self.io_loop.close()
def test_basic(self):
calls = []
def cb():
calls.append(self.io_loop.time())
pc = PeriodicCallback(cb, 10000)
pc.start()
self.io_loop.call_later(50, self.io_loop.stop)
self.io_loop.start()
self.assertEqual(calls, [1010, 1020, 1030, 1040, 1050])
def test_overrun(self):
sleep_durations = [9, 9, 10, 11, 20, 20, 35, 35, 0, 0]
expected = [
1010, 1020, 1030, # first 3 calls on schedule
1050, 1070, # next 2 delayed one cycle
1100, 1130, # next 2 delayed 2 cycles
1170, 1210, # next 2 delayed 3 cycles
1220, 1230, # then back on schedule.
]
calls = []
def cb():
calls.append(self.io_loop.time())
if not sleep_durations:
self.io_loop.stop()
return
self.io_loop.sleep(sleep_durations.pop(0))
pc = PeriodicCallback(cb, 10000)
pc.start()
self.io_loop.start()
self.assertEqual(calls, expected)
def test_io_loop_set_at_start(self):
# Check PeriodicCallback uses the current IOLoop at start() time,
# not at instantiation time.
calls = []
io_loop = FakeTimeIOLoop()
def cb():
calls.append(io_loop.time())
pc = PeriodicCallback(cb, 10000)
io_loop.make_current()
pc.start()
io_loop.call_later(50, io_loop.stop)
io_loop.start()
self.assertEqual(calls, [1010, 1020, 1030, 1040, 1050])
io_loop.close()
class TestPeriodicCallbackMath(unittest.TestCase):
def simulate_calls(self, pc, durations):
"""Simulate a series of calls to the PeriodicCallback.
Pass a list of call durations in seconds (negative values
work to simulate clock adjustments during the call, or more or
less equivalently, between calls). This method returns the
times at which each call would be made.
"""
calls = []
now = 1000
pc._next_timeout = now
for d in durations:
pc._update_next(now)
calls.append(pc._next_timeout)
now = pc._next_timeout + d
return calls
def test_basic(self):
pc = PeriodicCallback(None, 10000)
self.assertEqual(self.simulate_calls(pc, [0] * 5),
[1010, 1020, 1030, 1040, 1050])
def test_overrun(self):
# If a call runs for too long, we skip entire cycles to get
# back on schedule.
call_durations = [9, 9, 10, 11, 20, 20, 35, 35, 0, 0, 0]
expected = [
1010, 1020, 1030, # first 3 calls on schedule
1050, 1070, # next 2 delayed one cycle
1100, 1130, # next 2 delayed 2 cycles
1170, 1210, # next 2 delayed 3 cycles
1220, 1230, # then back on schedule.
]
pc = PeriodicCallback(None, 10000)
self.assertEqual(self.simulate_calls(pc, call_durations),
expected)
def test_clock_backwards(self):
pc = PeriodicCallback(None, 10000)
# Backwards jumps are ignored, potentially resulting in a
# slightly slow schedule (although we assume that when
# time.time() and time.monotonic() are different, time.time()
# is getting adjusted by NTP and is therefore more accurate)
self.assertEqual(self.simulate_calls(pc, [-2, -1, -3, -2, 0]),
[1010, 1020, 1030, 1040, 1050])
# For big jumps, we should perhaps alter the schedule, but we
# don't currently. This trace shows that we run callbacks
# every 10s of time.time(), but the first and second calls are
# 110s of real time apart because the backwards jump is
# ignored.
self.assertEqual(self.simulate_calls(pc, [-100, 0, 0]),
[1010, 1020, 1030])
@unittest.skipIf(mock is None, 'mock package not present')
def test_jitter(self):
random_times = [0.5, 1, 0, 0.75]
expected = [1010, 1022.5, 1030, 1041.25]
call_durations = [0] * len(random_times)
pc = PeriodicCallback(None, 10000, jitter=0.5)
def mock_random():
return random_times.pop(0)
with mock.patch('random.random', mock_random):
self.assertEqual(self.simulate_calls(pc, call_durations),
expected)
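# A worked sketch (hypothetical helper, not part of the suite) of the
# catch-up rule the overrun tests above encode, mirroring the intent of the
# internal PeriodicCallback._update_next: after an overrun, the next run
# lands on the next grid point strictly after "now", skipping missed cycles.
import math
def _next_grid_point(next_timeout, now, period=10):
    if next_timeout <= now:
        # e.g. due at 1030 but finished at 1041: floor(11/10) + 1 = 2 cycles,
        # so the next run is 1030 + 20 = 1050 and the 1040 cycle is skipped.
        next_timeout += (math.floor((now - next_timeout) / period) + 1) * period
    else:
        next_timeout += period
    return next_timeout
assert _next_grid_point(1030, 1041) == 1050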
class TestIOLoopConfiguration(unittest.TestCase):
def run_python(self, *statements):
statements = [
'from tornado.ioloop import IOLoop, PollIOLoop',
'classname = lambda x: x.__class__.__name__',
] + list(statements)
args = [sys.executable, '-c', '; '.join(statements)]
return native_str(subprocess.check_output(args)).strip()
def test_default(self):
if asyncio is not None:
# When asyncio is available, it is used by default.
cls = self.run_python('print(classname(IOLoop.current()))')
self.assertEqual(cls, 'AsyncIOMainLoop')
cls = self.run_python('print(classname(IOLoop()))')
self.assertEqual(cls, 'AsyncIOLoop')
else:
# Otherwise, the default is a subclass of PollIOLoop
is_poll = self.run_python(
'print(isinstance(IOLoop.current(), PollIOLoop))')
self.assertEqual(is_poll, 'True')
@unittest.skipIf(asyncio is not None,
"IOLoop configuration not available")
def test_explicit_select(self):
# SelectIOLoop can always be configured explicitly.
default_class = self.run_python(
'IOLoop.configure("tornado.platform.select.SelectIOLoop")',
'print(classname(IOLoop.current()))')
self.assertEqual(default_class, 'SelectIOLoop')
@unittest.skipIf(asyncio is None, "asyncio module not present")
def test_asyncio(self):
cls = self.run_python(
'IOLoop.configure("tornado.platform.asyncio.AsyncIOLoop")',
'print(classname(IOLoop.current()))')
self.assertEqual(cls, 'AsyncIOMainLoop')
@unittest.skipIf(asyncio is None, "asyncio module not present")
def test_asyncio_main(self):
cls = self.run_python(
'from tornado.platform.asyncio import AsyncIOMainLoop',
'AsyncIOMainLoop().install()',
'print(classname(IOLoop.current()))')
self.assertEqual(cls, 'AsyncIOMainLoop')
@unittest.skipIf(twisted is None, "twisted module not present")
@unittest.skipIf(asyncio is not None,
"IOLoop configuration not available")
def test_twisted(self):
cls = self.run_python(
'from tornado.platform.twisted import TwistedIOLoop',
'TwistedIOLoop().install()',
'print(classname(IOLoop.current()))')
self.assertEqual(cls, 'TwistedIOLoop')
if __name__ == "__main__":
unittest.main()
1454
lib/tornado/test/iostream_test.py Executable file
File diff suppressed because it is too large
131
lib/tornado/test/locale_test.py Executable file
@@ -0,0 +1,131 @@
from __future__ import absolute_import, division, print_function
import datetime
import os
import shutil
import tempfile
import tornado.locale
from tornado.escape import utf8, to_unicode
from tornado.test.util import unittest, skipOnAppEngine
from tornado.util import unicode_type
class TranslationLoaderTest(unittest.TestCase):
# TODO: less hacky way to get isolated tests
SAVE_VARS = ['_translations', '_supported_locales', '_use_gettext']
def clear_locale_cache(self):
if hasattr(tornado.locale.Locale, '_cache'):
del tornado.locale.Locale._cache
def setUp(self):
self.saved = {}
for var in TranslationLoaderTest.SAVE_VARS:
self.saved[var] = getattr(tornado.locale, var)
self.clear_locale_cache()
def tearDown(self):
for k, v in self.saved.items():
setattr(tornado.locale, k, v)
self.clear_locale_cache()
def test_csv(self):
tornado.locale.load_translations(
os.path.join(os.path.dirname(__file__), 'csv_translations'))
locale = tornado.locale.get("fr_FR")
self.assertTrue(isinstance(locale, tornado.locale.CSVLocale))
self.assertEqual(locale.translate("school"), u"\u00e9cole")
# tempfile.mkdtemp is not available on app engine.
@skipOnAppEngine
def test_csv_bom(self):
with open(os.path.join(os.path.dirname(__file__), 'csv_translations',
'fr_FR.csv'), 'rb') as f:
char_data = to_unicode(f.read())
# Re-encode our input data (which is utf-8 without BOM) in
# encodings that use the BOM and ensure that we can still load
# it. Note that utf-16-le and utf-16-be do not write a BOM,
        # so we only test whichever variant is native to our platform.
for encoding in ['utf-8-sig', 'utf-16']:
tmpdir = tempfile.mkdtemp()
try:
with open(os.path.join(tmpdir, 'fr_FR.csv'), 'wb') as f:
f.write(char_data.encode(encoding))
tornado.locale.load_translations(tmpdir)
locale = tornado.locale.get('fr_FR')
self.assertIsInstance(locale, tornado.locale.CSVLocale)
self.assertEqual(locale.translate("school"), u"\u00e9cole")
finally:
shutil.rmtree(tmpdir)
def test_gettext(self):
tornado.locale.load_gettext_translations(
os.path.join(os.path.dirname(__file__), 'gettext_translations'),
"tornado_test")
locale = tornado.locale.get("fr_FR")
self.assertTrue(isinstance(locale, tornado.locale.GettextLocale))
self.assertEqual(locale.translate("school"), u"\u00e9cole")
self.assertEqual(locale.pgettext("law", "right"), u"le droit")
self.assertEqual(locale.pgettext("good", "right"), u"le bien")
self.assertEqual(locale.pgettext("organization", "club", "clubs", 1), u"le club")
self.assertEqual(locale.pgettext("organization", "club", "clubs", 2), u"les clubs")
self.assertEqual(locale.pgettext("stick", "club", "clubs", 1), u"le b\xe2ton")
self.assertEqual(locale.pgettext("stick", "club", "clubs", 2), u"les b\xe2tons")
class LocaleDataTest(unittest.TestCase):
def test_non_ascii_name(self):
name = tornado.locale.LOCALE_NAMES['es_LA']['name']
self.assertTrue(isinstance(name, unicode_type))
self.assertEqual(name, u'Espa\u00f1ol')
self.assertEqual(utf8(name), b'Espa\xc3\xb1ol')
class EnglishTest(unittest.TestCase):
def test_format_date(self):
locale = tornado.locale.get('en_US')
date = datetime.datetime(2013, 4, 28, 18, 35)
self.assertEqual(locale.format_date(date, full_format=True),
'April 28, 2013 at 6:35 pm')
now = datetime.datetime.utcnow()
self.assertEqual(locale.format_date(now - datetime.timedelta(seconds=2), full_format=False),
'2 seconds ago')
self.assertEqual(locale.format_date(now - datetime.timedelta(minutes=2), full_format=False),
'2 minutes ago')
self.assertEqual(locale.format_date(now - datetime.timedelta(hours=2), full_format=False),
'2 hours ago')
self.assertEqual(locale.format_date(now - datetime.timedelta(days=1),
full_format=False, shorter=True), 'yesterday')
date = now - datetime.timedelta(days=2)
self.assertEqual(locale.format_date(date, full_format=False, shorter=True),
locale._weekdays[date.weekday()])
date = now - datetime.timedelta(days=300)
self.assertEqual(locale.format_date(date, full_format=False, shorter=True),
'%s %d' % (locale._months[date.month - 1], date.day))
date = now - datetime.timedelta(days=500)
self.assertEqual(locale.format_date(date, full_format=False, shorter=True),
'%s %d, %d' % (locale._months[date.month - 1], date.day, date.year))
def test_friendly_number(self):
locale = tornado.locale.get('en_US')
self.assertEqual(locale.friendly_number(1000000), '1,000,000')
def test_list(self):
locale = tornado.locale.get('en_US')
self.assertEqual(locale.list([]), '')
self.assertEqual(locale.list(['A']), 'A')
self.assertEqual(locale.list(['A', 'B']), 'A and B')
self.assertEqual(locale.list(['A', 'B', 'C']), 'A, B and C')
def test_format_day(self):
locale = tornado.locale.get('en_US')
date = datetime.datetime(2013, 4, 28, 18, 35)
self.assertEqual(locale.format_day(date=date, dow=True), 'Sunday, April 28')
self.assertEqual(locale.format_day(date=date, dow=False), 'April 28')
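# A short usage sketch (hypothetical helper, not part of the suite): the
# high-level locale API is get() -> Locale, then the formatting helpers
# exercised above.
def _locale_demo():
    locale = tornado.locale.get('en_US')
    assert locale.friendly_number(1234567) == '1,234,567'
    assert locale.list(['x', 'y']) == 'x and y'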
537
lib/tornado/test/locks_test.py Executable file
@@ -0,0 +1,537 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from __future__ import absolute_import, division, print_function
from datetime import timedelta
from tornado import gen, locks
from tornado.gen import TimeoutError
from tornado.testing import gen_test, AsyncTestCase
from tornado.test.util import unittest, skipBefore35, exec_test
class ConditionTest(AsyncTestCase):
def setUp(self):
super(ConditionTest, self).setUp()
self.history = []
def record_done(self, future, key):
"""Record the resolution of a Future returned by Condition.wait."""
def callback(_):
if not future.result():
# wait() resolved to False, meaning it timed out.
self.history.append('timeout')
else:
self.history.append(key)
future.add_done_callback(callback)
def loop_briefly(self):
"""Run all queued callbacks on the IOLoop.
In these tests, this method is used after calling notify() to
preserve the pre-5.0 behavior in which callbacks ran
synchronously.
"""
self.io_loop.add_callback(self.stop)
self.wait()
def test_repr(self):
c = locks.Condition()
self.assertIn('Condition', repr(c))
self.assertNotIn('waiters', repr(c))
c.wait()
self.assertIn('waiters', repr(c))
@gen_test
def test_notify(self):
c = locks.Condition()
self.io_loop.call_later(0.01, c.notify)
yield c.wait()
def test_notify_1(self):
c = locks.Condition()
self.record_done(c.wait(), 'wait1')
self.record_done(c.wait(), 'wait2')
c.notify(1)
self.loop_briefly()
self.history.append('notify1')
c.notify(1)
self.loop_briefly()
self.history.append('notify2')
self.assertEqual(['wait1', 'notify1', 'wait2', 'notify2'],
self.history)
def test_notify_n(self):
c = locks.Condition()
for i in range(6):
self.record_done(c.wait(), i)
c.notify(3)
self.loop_briefly()
# Callbacks execute in the order they were registered.
self.assertEqual(list(range(3)), self.history)
c.notify(1)
self.loop_briefly()
self.assertEqual(list(range(4)), self.history)
c.notify(2)
self.loop_briefly()
self.assertEqual(list(range(6)), self.history)
def test_notify_all(self):
c = locks.Condition()
for i in range(4):
self.record_done(c.wait(), i)
c.notify_all()
self.loop_briefly()
self.history.append('notify_all')
# Callbacks execute in the order they were registered.
self.assertEqual(
list(range(4)) + ['notify_all'],
self.history)
@gen_test
def test_wait_timeout(self):
c = locks.Condition()
wait = c.wait(timedelta(seconds=0.01))
self.io_loop.call_later(0.02, c.notify) # Too late.
yield gen.sleep(0.03)
self.assertFalse((yield wait))
@gen_test
def test_wait_timeout_preempted(self):
c = locks.Condition()
# This fires before the wait times out.
self.io_loop.call_later(0.01, c.notify)
wait = c.wait(timedelta(seconds=0.02))
yield gen.sleep(0.03)
yield wait # No TimeoutError.
@gen_test
def test_notify_n_with_timeout(self):
# Register callbacks 0, 1, 2, and 3. Callback 1 has a timeout.
# Wait for that timeout to expire, then do notify(2) and make
# sure everyone runs. Verifies that a timed-out callback does
# not count against the 'n' argument to notify().
c = locks.Condition()
self.record_done(c.wait(), 0)
self.record_done(c.wait(timedelta(seconds=0.01)), 1)
self.record_done(c.wait(), 2)
self.record_done(c.wait(), 3)
# Wait for callback 1 to time out.
yield gen.sleep(0.02)
self.assertEqual(['timeout'], self.history)
c.notify(2)
yield gen.sleep(0.01)
        self.assertEqual(['timeout', 0, 2], self.history)
c.notify()
yield
self.assertEqual(['timeout', 0, 2, 3], self.history)
@gen_test
def test_notify_all_with_timeout(self):
c = locks.Condition()
self.record_done(c.wait(), 0)
self.record_done(c.wait(timedelta(seconds=0.01)), 1)
self.record_done(c.wait(), 2)
# Wait for callback 1 to time out.
yield gen.sleep(0.02)
self.assertEqual(['timeout'], self.history)
c.notify_all()
yield
self.assertEqual(['timeout', 0, 2], self.history)
@gen_test
def test_nested_notify(self):
# Ensure no notifications lost, even if notify() is reentered by a
# waiter calling notify().
c = locks.Condition()
# Three waiters.
futures = [c.wait() for _ in range(3)]
# First and second futures resolved. Second future reenters notify(),
# resolving third future.
futures[1].add_done_callback(lambda _: c.notify())
c.notify(2)
yield
self.assertTrue(all(f.done() for f in futures))
@gen_test
def test_garbage_collection(self):
# Test that timed-out waiters are occasionally cleaned from the queue.
c = locks.Condition()
for _ in range(101):
c.wait(timedelta(seconds=0.01))
future = c.wait()
self.assertEqual(102, len(c._waiters))
# Let first 101 waiters time out, triggering a collection.
yield gen.sleep(0.02)
self.assertEqual(1, len(c._waiters))
# Final waiter is still active.
self.assertFalse(future.done())
c.notify()
self.assertTrue(future.done())
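# A minimal sketch (hypothetical helper, not part of the suite) of the
# canonical wait/notify flow the tests above exercise, written as a
# coroutine so wait() can be yielded.
@gen.coroutine
def _condition_demo():
    c = locks.Condition()
    done = []
    @gen.coroutine
    def waiter():
        yield c.wait()
        done.append(True)
    f = waiter()
    c.notify()
    yield f
    assert done == [True]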
class EventTest(AsyncTestCase):
def test_repr(self):
event = locks.Event()
self.assertTrue('clear' in str(event))
self.assertFalse('set' in str(event))
event.set()
self.assertFalse('clear' in str(event))
self.assertTrue('set' in str(event))
def test_event(self):
e = locks.Event()
future_0 = e.wait()
e.set()
future_1 = e.wait()
e.clear()
future_2 = e.wait()
self.assertTrue(future_0.done())
self.assertTrue(future_1.done())
self.assertFalse(future_2.done())
@gen_test
def test_event_timeout(self):
e = locks.Event()
with self.assertRaises(TimeoutError):
yield e.wait(timedelta(seconds=0.01))
# After a timed-out waiter, normal operation works.
self.io_loop.add_timeout(timedelta(seconds=0.01), e.set)
yield e.wait(timedelta(seconds=1))
def test_event_set_multiple(self):
e = locks.Event()
e.set()
e.set()
self.assertTrue(e.is_set())
def test_event_wait_clear(self):
e = locks.Event()
f0 = e.wait()
e.clear()
f1 = e.wait()
e.set()
self.assertTrue(f0.done())
self.assertTrue(f1.done())
class SemaphoreTest(AsyncTestCase):
def test_negative_value(self):
self.assertRaises(ValueError, locks.Semaphore, value=-1)
def test_repr(self):
sem = locks.Semaphore()
self.assertIn('Semaphore', repr(sem))
self.assertIn('unlocked,value:1', repr(sem))
sem.acquire()
self.assertIn('locked', repr(sem))
self.assertNotIn('waiters', repr(sem))
sem.acquire()
self.assertIn('waiters', repr(sem))
def test_acquire(self):
sem = locks.Semaphore()
f0 = sem.acquire()
self.assertTrue(f0.done())
# Wait for release().
f1 = sem.acquire()
self.assertFalse(f1.done())
f2 = sem.acquire()
sem.release()
self.assertTrue(f1.done())
self.assertFalse(f2.done())
sem.release()
self.assertTrue(f2.done())
sem.release()
# Now acquire() is instant.
self.assertTrue(sem.acquire().done())
self.assertEqual(0, len(sem._waiters))
@gen_test
def test_acquire_timeout(self):
sem = locks.Semaphore(2)
yield sem.acquire()
yield sem.acquire()
acquire = sem.acquire(timedelta(seconds=0.01))
self.io_loop.call_later(0.02, sem.release) # Too late.
yield gen.sleep(0.3)
with self.assertRaises(gen.TimeoutError):
yield acquire
sem.acquire()
f = sem.acquire()
self.assertFalse(f.done())
sem.release()
self.assertTrue(f.done())
@gen_test
def test_acquire_timeout_preempted(self):
sem = locks.Semaphore(1)
yield sem.acquire()
# This fires before the wait times out.
self.io_loop.call_later(0.01, sem.release)
acquire = sem.acquire(timedelta(seconds=0.02))
yield gen.sleep(0.03)
yield acquire # No TimeoutError.
def test_release_unacquired(self):
# Unbounded releases are allowed, and increment the semaphore's value.
sem = locks.Semaphore()
sem.release()
sem.release()
# Now the counter is 3. We can acquire three times before blocking.
self.assertTrue(sem.acquire().done())
self.assertTrue(sem.acquire().done())
self.assertTrue(sem.acquire().done())
self.assertFalse(sem.acquire().done())
@gen_test
def test_garbage_collection(self):
# Test that timed-out waiters are occasionally cleaned from the queue.
sem = locks.Semaphore(value=0)
futures = [sem.acquire(timedelta(seconds=0.01)) for _ in range(101)]
future = sem.acquire()
self.assertEqual(102, len(sem._waiters))
# Let first 101 waiters time out, triggering a collection.
yield gen.sleep(0.02)
self.assertEqual(1, len(sem._waiters))
# Final waiter is still active.
self.assertFalse(future.done())
sem.release()
self.assertTrue(future.done())
# Prevent "Future exception was never retrieved" messages.
for future in futures:
self.assertRaises(TimeoutError, future.result)
class SemaphoreContextManagerTest(AsyncTestCase):
@gen_test
def test_context_manager(self):
sem = locks.Semaphore()
with (yield sem.acquire()) as yielded:
self.assertTrue(yielded is None)
# Semaphore was released and can be acquired again.
self.assertTrue(sem.acquire().done())
@skipBefore35
@gen_test
def test_context_manager_async_await(self):
# Repeat the above test using 'async with'.
sem = locks.Semaphore()
namespace = exec_test(globals(), locals(), """
async def f():
async with sem as yielded:
self.assertTrue(yielded is None)
""")
yield namespace['f']()
# Semaphore was released and can be acquired again.
self.assertTrue(sem.acquire().done())
@gen_test
def test_context_manager_exception(self):
sem = locks.Semaphore()
with self.assertRaises(ZeroDivisionError):
with (yield sem.acquire()):
1 / 0
# Semaphore was released and can be acquired again.
self.assertTrue(sem.acquire().done())
@gen_test
def test_context_manager_timeout(self):
sem = locks.Semaphore()
with (yield sem.acquire(timedelta(seconds=0.01))):
pass
# Semaphore was released and can be acquired again.
self.assertTrue(sem.acquire().done())
@gen_test
def test_context_manager_timeout_error(self):
sem = locks.Semaphore(value=0)
with self.assertRaises(gen.TimeoutError):
with (yield sem.acquire(timedelta(seconds=0.01))):
pass
# Counter is still 0.
self.assertFalse(sem.acquire().done())
@gen_test
def test_context_manager_contended(self):
sem = locks.Semaphore()
history = []
@gen.coroutine
def f(index):
with (yield sem.acquire()):
history.append('acquired %d' % index)
yield gen.sleep(0.01)
history.append('release %d' % index)
yield [f(i) for i in range(2)]
expected_history = []
for i in range(2):
expected_history.extend(['acquired %d' % i, 'release %d' % i])
self.assertEqual(expected_history, history)
@gen_test
def test_yield_sem(self):
# Ensure we catch a "with (yield sem)", which should be
# "with (yield sem.acquire())".
with self.assertRaises(gen.BadYieldError):
with (yield locks.Semaphore()):
pass
def test_context_manager_misuse(self):
# Ensure we catch a "with sem", which should be
# "with (yield sem.acquire())".
with self.assertRaises(RuntimeError):
with locks.Semaphore():
pass
class BoundedSemaphoreTest(AsyncTestCase):
def test_release_unacquired(self):
sem = locks.BoundedSemaphore()
self.assertRaises(ValueError, sem.release)
# Value is 0.
sem.acquire()
# Block on acquire().
future = sem.acquire()
self.assertFalse(future.done())
sem.release()
self.assertTrue(future.done())
# Value is 1.
sem.release()
self.assertRaises(ValueError, sem.release)
class LockTests(AsyncTestCase):
def test_repr(self):
lock = locks.Lock()
# No errors.
repr(lock)
lock.acquire()
repr(lock)
def test_acquire_release(self):
lock = locks.Lock()
self.assertTrue(lock.acquire().done())
future = lock.acquire()
self.assertFalse(future.done())
lock.release()
self.assertTrue(future.done())
@gen_test
def test_acquire_fifo(self):
lock = locks.Lock()
self.assertTrue(lock.acquire().done())
N = 5
history = []
@gen.coroutine
def f(idx):
with (yield lock.acquire()):
history.append(idx)
futures = [f(i) for i in range(N)]
self.assertFalse(any(future.done() for future in futures))
lock.release()
yield futures
self.assertEqual(list(range(N)), history)
@skipBefore35
@gen_test
def test_acquire_fifo_async_with(self):
# Repeat the above test using `async with lock:`
# instead of `with (yield lock.acquire()):`.
lock = locks.Lock()
self.assertTrue(lock.acquire().done())
N = 5
history = []
namespace = exec_test(globals(), locals(), """
async def f(idx):
async with lock:
history.append(idx)
""")
futures = [namespace['f'](i) for i in range(N)]
lock.release()
yield futures
self.assertEqual(list(range(N)), history)
@gen_test
def test_acquire_timeout(self):
lock = locks.Lock()
lock.acquire()
with self.assertRaises(gen.TimeoutError):
yield lock.acquire(timeout=timedelta(seconds=0.01))
# Still locked.
self.assertFalse(lock.acquire().done())
def test_multi_release(self):
lock = locks.Lock()
self.assertRaises(RuntimeError, lock.release)
lock.acquire()
lock.release()
self.assertRaises(RuntimeError, lock.release)
@gen_test
def test_yield_lock(self):
# Ensure we catch a "with (yield lock)", which should be
# "with (yield lock.acquire())".
with self.assertRaises(gen.BadYieldError):
with (yield locks.Lock()):
pass
def test_context_manager_misuse(self):
# Ensure we catch a "with lock", which should be
# "with (yield lock.acquire())".
with self.assertRaises(RuntimeError):
with locks.Lock():
pass
if __name__ == '__main__':
unittest.main()
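The tests above pin down the intended idiom: yield sem.acquire() and use the
result as a context manager so release() always runs; a bare "with sem" or
"with (yield sem)" is rejected. A minimal usage sketch against the same
tornado.locks API (the worker count and sleep interval are arbitrary):

from tornado import gen, ioloop, locks

sem = locks.Semaphore(2)  # at most two workers inside the critical section

@gen.coroutine
def worker(idx):
    # Correct idiom: yield acquire(), then enter the returned context
    # manager so the semaphore is released even if the body raises.
    with (yield sem.acquire()):
        print('worker %d acquired' % idx)
        yield gen.sleep(0.1)

ioloop.IOLoop.current().run_sync(
    lambda: gen.multi([worker(i) for i in range(5)]))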

241
lib/tornado/test/log_test.py Executable file

@@ -0,0 +1,241 @@
#
# Copyright 2012 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from __future__ import absolute_import, division, print_function
import contextlib
import glob
import logging
import os
import re
import subprocess
import sys
import tempfile
import warnings
from tornado.escape import utf8
from tornado.log import LogFormatter, define_logging_options, enable_pretty_logging
from tornado.options import OptionParser
from tornado.test.util import unittest
from tornado.util import basestring_type
@contextlib.contextmanager
def ignore_bytes_warning():
with warnings.catch_warnings():
warnings.simplefilter('ignore', category=BytesWarning)
yield
class LogFormatterTest(unittest.TestCase):
# Matches the output of a single logging call (which may be multiple lines
# if a traceback was included, so we use the DOTALL option)
LINE_RE = re.compile(
b"(?s)\x01\\[E [0-9]{6} [0-9]{2}:[0-9]{2}:[0-9]{2} log_test:[0-9]+\\]\x02 (.*)")
def setUp(self):
self.formatter = LogFormatter(color=False)
# Fake color support. We can't guarantee anything about the $TERM
# variable when the tests are run, so just patch in some values
# for testing. (testing with color off fails to expose some potential
# encoding issues from the control characters)
self.formatter._colors = {
logging.ERROR: u"\u0001",
}
self.formatter._normal = u"\u0002"
# construct a Logger directly to bypass getLogger's caching
self.logger = logging.Logger('LogFormatterTest')
self.logger.propagate = False
self.tempdir = tempfile.mkdtemp()
self.filename = os.path.join(self.tempdir, 'log.out')
self.handler = self.make_handler(self.filename)
self.handler.setFormatter(self.formatter)
self.logger.addHandler(self.handler)
def tearDown(self):
self.handler.close()
os.unlink(self.filename)
os.rmdir(self.tempdir)
def make_handler(self, filename):
# Base case: default setup without explicit encoding.
# In python 2, supports arbitrary byte strings and unicode objects
# that contain only ascii. In python 3, supports ascii-only unicode
# strings (but byte strings will be repr'd automatically).
return logging.FileHandler(filename)
def get_output(self):
with open(self.filename, "rb") as f:
line = f.read().strip()
m = LogFormatterTest.LINE_RE.match(line)
if m:
return m.group(1)
else:
raise Exception("output didn't match regex: %r" % line)
def test_basic_logging(self):
self.logger.error("foo")
self.assertEqual(self.get_output(), b"foo")
def test_bytes_logging(self):
with ignore_bytes_warning():
# This will be "\xe9" on python 2 or "b'\xe9'" on python 3
self.logger.error(b"\xe9")
self.assertEqual(self.get_output(), utf8(repr(b"\xe9")))
def test_utf8_logging(self):
with ignore_bytes_warning():
self.logger.error(u"\u00e9".encode("utf8"))
if issubclass(bytes, basestring_type):
# on python 2, utf8 byte strings (and by extension ascii byte
# strings) are passed through as-is.
self.assertEqual(self.get_output(), utf8(u"\u00e9"))
else:
# on python 3, byte strings always get repr'd even if
# they're ascii-only, so this degenerates into another
# copy of test_bytes_logging.
self.assertEqual(self.get_output(), utf8(repr(utf8(u"\u00e9"))))
def test_bytes_exception_logging(self):
try:
raise Exception(b'\xe9')
except Exception:
self.logger.exception('caught exception')
# This will be "Exception: \xe9" on python 2 or
# "Exception: b'\xe9'" on python 3.
output = self.get_output()
self.assertRegexpMatches(output, br'Exception.*\\xe9')
# The traceback contains newlines, which should not have been escaped.
self.assertNotIn(br'\n', output)
class UnicodeLogFormatterTest(LogFormatterTest):
def make_handler(self, filename):
# Adding an explicit encoding configuration allows non-ascii unicode
# strings in both python 2 and 3, without changing the behavior
# for byte strings.
return logging.FileHandler(filename, encoding="utf8")
def test_unicode_logging(self):
self.logger.error(u"\u00e9")
self.assertEqual(self.get_output(), utf8(u"\u00e9"))
class EnablePrettyLoggingTest(unittest.TestCase):
def setUp(self):
super(EnablePrettyLoggingTest, self).setUp()
self.options = OptionParser()
define_logging_options(self.options)
self.logger = logging.Logger('tornado.test.log_test.EnablePrettyLoggingTest')
self.logger.propagate = False
def test_log_file(self):
tmpdir = tempfile.mkdtemp()
try:
self.options.log_file_prefix = tmpdir + '/test_log'
enable_pretty_logging(options=self.options, logger=self.logger)
self.assertEqual(1, len(self.logger.handlers))
self.logger.error('hello')
self.logger.handlers[0].flush()
filenames = glob.glob(tmpdir + '/test_log*')
self.assertEqual(1, len(filenames))
with open(filenames[0]) as f:
self.assertRegexpMatches(f.read(), r'^\[E [^]]*\] hello$')
finally:
for handler in self.logger.handlers:
handler.flush()
handler.close()
for filename in glob.glob(tmpdir + '/test_log*'):
os.unlink(filename)
os.rmdir(tmpdir)
def test_log_file_with_timed_rotating(self):
tmpdir = tempfile.mkdtemp()
try:
self.options.log_file_prefix = tmpdir + '/test_log'
self.options.log_rotate_mode = 'time'
enable_pretty_logging(options=self.options, logger=self.logger)
self.logger.error('hello')
self.logger.handlers[0].flush()
filenames = glob.glob(tmpdir + '/test_log*')
self.assertEqual(1, len(filenames))
with open(filenames[0]) as f:
self.assertRegexpMatches(
f.read(),
r'^\[E [^]]*\] hello$')
finally:
for handler in self.logger.handlers:
handler.flush()
handler.close()
for filename in glob.glob(tmpdir + '/test_log*'):
os.unlink(filename)
os.rmdir(tmpdir)
def test_wrong_rotate_mode_value(self):
try:
self.options.log_file_prefix = 'some_path'
self.options.log_rotate_mode = 'wrong_mode'
self.assertRaises(ValueError, enable_pretty_logging,
options=self.options, logger=self.logger)
finally:
for handler in self.logger.handlers:
handler.flush()
handler.close()
class LoggingOptionTest(unittest.TestCase):
"""Test the ability to enable and disable Tornado's logging hooks."""
def logs_present(self, statement, args=None):
# Each test may manipulate and/or parse the options and then log
# a line at the 'info' level. This level is ignored in the
# logging module by default, but Tornado turns it on by default
# so it is the easiest way to tell whether tornado's logging hooks
# ran.
IMPORT = 'from tornado.options import options, parse_command_line'
LOG_INFO = 'import logging; logging.info("hello")'
program = ';'.join([IMPORT, statement, LOG_INFO])
proc = subprocess.Popen(
[sys.executable, '-c', program] + (args or []),
stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
stdout, stderr = proc.communicate()
self.assertEqual(proc.returncode, 0, 'process failed: %r' % stdout)
return b'hello' in stdout
def test_default(self):
self.assertFalse(self.logs_present('pass'))
def test_tornado_default(self):
self.assertTrue(self.logs_present('parse_command_line()'))
def test_disable_command_line(self):
self.assertFalse(self.logs_present('parse_command_line()',
['--logging=none']))
def test_disable_command_line_case_insensitive(self):
self.assertFalse(self.logs_present('parse_command_line()',
['--logging=None']))
def test_disable_code_string(self):
self.assertFalse(self.logs_present(
'options.logging = "none"; parse_command_line()'))
def test_disable_code_none(self):
self.assertFalse(self.logs_present(
'options.logging = None; parse_command_line()'))
def test_disable_override(self):
# command line trumps code defaults
self.assertTrue(self.logs_present(
'options.logging = None; parse_command_line()',
['--logging=info']))
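For reference, the option-driven setup these tests exercise boils down to a
few lines; the prefix path and rotate mode below are illustrative values,
not defaults:

import logging
import tempfile
from tornado.log import define_logging_options, enable_pretty_logging
from tornado.options import OptionParser

options = OptionParser()
define_logging_options(options)
options.log_file_prefix = tempfile.mkdtemp() + '/app_log'  # illustrative
options.log_rotate_mode = 'time'  # rotate on a timer rather than by size

logger = logging.Logger('demo')  # bypass getLogger's caching, as above
enable_pretty_logging(options=options, logger=logger)
logger.error('hello')  # written to app_log* with the "[E ...]" prefix
logger.handlers[0].flush()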

242
lib/tornado/test/netutil_test.py Executable file

@@ -0,0 +1,242 @@
from __future__ import absolute_import, division, print_function
import errno
import os
import signal
import socket
from subprocess import Popen
import sys
import time
from tornado.netutil import (
BlockingResolver, OverrideResolver, ThreadedResolver, is_valid_ip, bind_sockets
)
from tornado.stack_context import ExceptionStackContext
from tornado.testing import AsyncTestCase, gen_test, bind_unused_port
from tornado.test.util import unittest, skipIfNoNetwork, ignore_deprecation
try:
from concurrent import futures
except ImportError:
futures = None
try:
import pycares # type: ignore
except ImportError:
pycares = None
else:
from tornado.platform.caresresolver import CaresResolver
try:
import twisted # type: ignore
import twisted.names # type: ignore
except ImportError:
twisted = None
else:
from tornado.platform.twisted import TwistedResolver
class _ResolverTestMixin(object):
def test_localhost(self):
with ignore_deprecation():
self.resolver.resolve('localhost', 80, callback=self.stop)
result = self.wait()
self.assertIn((socket.AF_INET, ('127.0.0.1', 80)), result)
@gen_test
def test_future_interface(self):
addrinfo = yield self.resolver.resolve('localhost', 80,
socket.AF_UNSPEC)
self.assertIn((socket.AF_INET, ('127.0.0.1', 80)),
addrinfo)
# It is impossible to quickly and consistently generate an error in name
# resolution, so test this case separately, using mocks as needed.
class _ResolverErrorTestMixin(object):
def test_bad_host(self):
def handler(exc_typ, exc_val, exc_tb):
self.stop(exc_val)
return True # Halt propagation.
with ignore_deprecation():
with ExceptionStackContext(handler):
self.resolver.resolve('an invalid domain', 80, callback=self.stop)
result = self.wait()
self.assertIsInstance(result, Exception)
@gen_test
def test_future_interface_bad_host(self):
with self.assertRaises(IOError):
yield self.resolver.resolve('an invalid domain', 80,
socket.AF_UNSPEC)
def _failing_getaddrinfo(*args):
"""Dummy implementation of getaddrinfo for use in mocks"""
raise socket.gaierror(errno.EIO, "mock: lookup failed")
@skipIfNoNetwork
class BlockingResolverTest(AsyncTestCase, _ResolverTestMixin):
def setUp(self):
super(BlockingResolverTest, self).setUp()
self.resolver = BlockingResolver()
# getaddrinfo-based tests need mocking to reliably generate errors;
# some configurations are slow to produce errors and take longer than
# our default timeout.
class BlockingResolverErrorTest(AsyncTestCase, _ResolverErrorTestMixin):
def setUp(self):
super(BlockingResolverErrorTest, self).setUp()
self.resolver = BlockingResolver()
self.real_getaddrinfo = socket.getaddrinfo
socket.getaddrinfo = _failing_getaddrinfo
def tearDown(self):
socket.getaddrinfo = self.real_getaddrinfo
super(BlockingResolverErrorTest, self).tearDown()
class OverrideResolverTest(AsyncTestCase, _ResolverTestMixin):
def setUp(self):
super(OverrideResolverTest, self).setUp()
mapping = {
('google.com', 80): ('1.2.3.4', 80),
('google.com', 80, socket.AF_INET): ('1.2.3.4', 80),
('google.com', 80, socket.AF_INET6): ('2a02:6b8:7c:40c:c51e:495f:e23a:3', 80)
}
self.resolver = OverrideResolver(BlockingResolver(), mapping)
@gen_test
def test_resolve_multiaddr(self):
result = yield self.resolver.resolve('google.com', 80, socket.AF_INET)
self.assertIn((socket.AF_INET, ('1.2.3.4', 80)), result)
result = yield self.resolver.resolve('google.com', 80, socket.AF_INET6)
self.assertIn((socket.AF_INET6, ('2a02:6b8:7c:40c:c51e:495f:e23a:3', 80, 0, 0)), result)
@skipIfNoNetwork
@unittest.skipIf(futures is None, "futures module not present")
class ThreadedResolverTest(AsyncTestCase, _ResolverTestMixin):
def setUp(self):
super(ThreadedResolverTest, self).setUp()
self.resolver = ThreadedResolver()
def tearDown(self):
self.resolver.close()
super(ThreadedResolverTest, self).tearDown()
class ThreadedResolverErrorTest(AsyncTestCase, _ResolverErrorTestMixin):
def setUp(self):
super(ThreadedResolverErrorTest, self).setUp()
self.resolver = BlockingResolver()
self.real_getaddrinfo = socket.getaddrinfo
socket.getaddrinfo = _failing_getaddrinfo
def tearDown(self):
socket.getaddrinfo = self.real_getaddrinfo
super(ThreadedResolverErrorTest, self).tearDown()
@skipIfNoNetwork
@unittest.skipIf(futures is None, "futures module not present")
@unittest.skipIf(sys.platform == 'win32', "preexec_fn not available on win32")
class ThreadedResolverImportTest(unittest.TestCase):
def test_import(self):
TIMEOUT = 5
# Test for a deadlock when importing a module that runs the
# ThreadedResolver at import-time. See resolve_test.py for
# full explanation.
command = [
sys.executable,
'-c',
'import tornado.test.resolve_test_helper']
start = time.time()
popen = Popen(command, preexec_fn=lambda: signal.alarm(TIMEOUT))
while time.time() - start < TIMEOUT:
return_code = popen.poll()
if return_code is not None:
self.assertEqual(0, return_code)
return # Success.
time.sleep(0.05)
self.fail("import timed out")
# We do not test errors with CaresResolver:
# Some DNS-hijacking ISPs (e.g. Time Warner) return non-empty results
# with an NXDOMAIN status code. Most resolvers treat this as an error;
# C-ares returns the results, making the "bad_host" tests unreliable.
# C-ares will try to resolve even malformed names, such as the
# name with spaces used in this test.
@skipIfNoNetwork
@unittest.skipIf(pycares is None, "pycares module not present")
class CaresResolverTest(AsyncTestCase, _ResolverTestMixin):
def setUp(self):
super(CaresResolverTest, self).setUp()
self.resolver = CaresResolver()
# TwistedResolver produces consistent errors in our test cases so we
# could test the regular and error cases in the same class. However,
# in the error cases it appears that cleanup of socket objects is
# handled asynchronously and occasionally results in "unclosed socket"
# warnings if not given time to shut down (and there is no way to
# explicitly shut it down). This makes the test flaky, so we do not
# test error cases here.
@skipIfNoNetwork
@unittest.skipIf(twisted is None, "twisted module not present")
@unittest.skipIf(getattr(twisted, '__version__', '0.0') < "12.1", "old version of twisted")
class TwistedResolverTest(AsyncTestCase, _ResolverTestMixin):
def setUp(self):
super(TwistedResolverTest, self).setUp()
self.resolver = TwistedResolver()
class IsValidIPTest(unittest.TestCase):
def test_is_valid_ip(self):
self.assertTrue(is_valid_ip('127.0.0.1'))
self.assertTrue(is_valid_ip('4.4.4.4'))
self.assertTrue(is_valid_ip('::1'))
self.assertTrue(is_valid_ip('2620:0:1cfe:face:b00c::3'))
self.assertTrue(not is_valid_ip('www.google.com'))
self.assertTrue(not is_valid_ip('localhost'))
self.assertTrue(not is_valid_ip('4.4.4.4<'))
self.assertTrue(not is_valid_ip(' 127.0.0.1'))
self.assertTrue(not is_valid_ip(''))
self.assertTrue(not is_valid_ip(' '))
self.assertTrue(not is_valid_ip('\n'))
self.assertTrue(not is_valid_ip('\x00'))
class TestPortAllocation(unittest.TestCase):
def test_same_port_allocation(self):
if 'TRAVIS' in os.environ:
self.skipTest("dual-stack servers often have port conflicts on travis")
sockets = bind_sockets(None, 'localhost')
try:
port = sockets[0].getsockname()[1]
self.assertTrue(all(s.getsockname()[1] == port
for s in sockets[1:]))
finally:
for sock in sockets:
sock.close()
@unittest.skipIf(not hasattr(socket, "SO_REUSEPORT"), "SO_REUSEPORT is not supported")
def test_reuse_port(self):
sockets = []
socket, port = bind_unused_port(reuse_port=True)
try:
sockets = bind_sockets(port, '127.0.0.1', reuse_port=True)
self.assertTrue(all(s.getsockname()[1] == port for s in sockets))
finally:
socket.close()
for sock in sockets:
sock.close()
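The OverrideResolver tests above rely on mapping keys of (host, port) or
(host, port, family). A minimal sketch with illustrative addresses:

import socket
from tornado.ioloop import IOLoop
from tornado.netutil import BlockingResolver, OverrideResolver

# Values are the (host, port) actually resolved in place of the key.
mapping = {
    ('example.com', 80): ('127.0.0.1', 8080),
    ('example.com', 80, socket.AF_INET6): ('::1', 8080),
}
resolver = OverrideResolver(BlockingResolver(), mapping)

addrinfo = IOLoop.current().run_sync(
    lambda: resolver.resolve('example.com', 80, socket.AF_INET))
print(addrinfo)  # expect [(socket.AF_INET, ('127.0.0.1', 8080))]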

7
lib/tornado/test/options_test.cfg

@@ -0,0 +1,7 @@
port=443
port=443
username='李康'
foo_bar='a'
my_path = __file__

327
lib/tornado/test/options_test.py Executable file

@@ -0,0 +1,327 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function
import datetime
import os
import sys
from tornado.options import OptionParser, Error
from tornado.util import basestring_type, PY3
from tornado.test.util import unittest, subTest
if PY3:
from io import StringIO
else:
from cStringIO import StringIO
try:
# py33+
from unittest import mock # type: ignore
except ImportError:
try:
import mock # type: ignore
except ImportError:
mock = None
class Email(object):
def __init__(self, value):
if isinstance(value, str) and '@' in value:
self._value = value
else:
raise ValueError()
@property
def value(self):
return self._value
class OptionsTest(unittest.TestCase):
def test_parse_command_line(self):
options = OptionParser()
options.define("port", default=80)
options.parse_command_line(["main.py", "--port=443"])
self.assertEqual(options.port, 443)
def test_parse_config_file(self):
options = OptionParser()
options.define("port", default=80)
options.define("username", default='foo')
options.define("my_path")
config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
"options_test.cfg")
options.parse_config_file(config_path)
self.assertEqual(options.port, 443)
self.assertEqual(options.username, "李康")
self.assertEqual(options.my_path, config_path)
def test_parse_callbacks(self):
options = OptionParser()
self.called = False
def callback():
self.called = True
options.add_parse_callback(callback)
# non-final parse doesn't run callbacks
options.parse_command_line(["main.py"], final=False)
self.assertFalse(self.called)
# final parse does
options.parse_command_line(["main.py"])
self.assertTrue(self.called)
# callbacks can be run more than once on the same options
# object if there are multiple final parses
self.called = False
options.parse_command_line(["main.py"])
self.assertTrue(self.called)
def test_help(self):
options = OptionParser()
try:
orig_stderr = sys.stderr
sys.stderr = StringIO()
with self.assertRaises(SystemExit):
options.parse_command_line(["main.py", "--help"])
usage = sys.stderr.getvalue()
finally:
sys.stderr = orig_stderr
self.assertIn("Usage:", usage)
def test_subcommand(self):
base_options = OptionParser()
base_options.define("verbose", default=False)
sub_options = OptionParser()
sub_options.define("foo", type=str)
rest = base_options.parse_command_line(
["main.py", "--verbose", "subcommand", "--foo=bar"])
self.assertEqual(rest, ["subcommand", "--foo=bar"])
self.assertTrue(base_options.verbose)
rest2 = sub_options.parse_command_line(rest)
self.assertEqual(rest2, [])
self.assertEqual(sub_options.foo, "bar")
# the two option sets are distinct
try:
orig_stderr = sys.stderr
sys.stderr = StringIO()
with self.assertRaises(Error):
sub_options.parse_command_line(["subcommand", "--verbose"])
finally:
sys.stderr = orig_stderr
def test_setattr(self):
options = OptionParser()
options.define('foo', default=1, type=int)
options.foo = 2
self.assertEqual(options.foo, 2)
def test_setattr_type_check(self):
# setattr requires that options be the right type and doesn't
# parse from string formats.
options = OptionParser()
options.define('foo', default=1, type=int)
with self.assertRaises(Error):
options.foo = '2'
def test_setattr_with_callback(self):
values = []
options = OptionParser()
options.define('foo', default=1, type=int, callback=values.append)
options.foo = 2
self.assertEqual(values, [2])
def _sample_options(self):
options = OptionParser()
options.define('a', default=1)
options.define('b', default=2)
return options
def test_iter(self):
options = self._sample_options()
# OptionParsers always define 'help'.
self.assertEqual(set(['a', 'b', 'help']), set(iter(options)))
def test_getitem(self):
options = self._sample_options()
self.assertEqual(1, options['a'])
def test_setitem(self):
options = OptionParser()
options.define('foo', default=1, type=int)
options['foo'] = 2
self.assertEqual(options['foo'], 2)
def test_items(self):
options = self._sample_options()
# OptionParsers always define 'help'.
expected = [('a', 1), ('b', 2), ('help', options.help)]
actual = sorted(options.items())
self.assertEqual(expected, actual)
def test_as_dict(self):
options = self._sample_options()
expected = {'a': 1, 'b': 2, 'help': options.help}
self.assertEqual(expected, options.as_dict())
def test_group_dict(self):
options = OptionParser()
options.define('a', default=1)
options.define('b', group='b_group', default=2)
frame = sys._getframe(0)
this_file = frame.f_code.co_filename
self.assertEqual(set(['b_group', '', this_file]), options.groups())
b_group_dict = options.group_dict('b_group')
self.assertEqual({'b': 2}, b_group_dict)
self.assertEqual({}, options.group_dict('nonexistent'))
@unittest.skipIf(mock is None, 'mock package not present')
def test_mock_patch(self):
# ensure that our setattr hooks don't interfere with mock.patch
options = OptionParser()
options.define('foo', default=1)
options.parse_command_line(['main.py', '--foo=2'])
self.assertEqual(options.foo, 2)
with mock.patch.object(options.mockable(), 'foo', 3):
self.assertEqual(options.foo, 3)
self.assertEqual(options.foo, 2)
# Try nested patches mixed with explicit sets
with mock.patch.object(options.mockable(), 'foo', 4):
self.assertEqual(options.foo, 4)
options.foo = 5
self.assertEqual(options.foo, 5)
with mock.patch.object(options.mockable(), 'foo', 6):
self.assertEqual(options.foo, 6)
self.assertEqual(options.foo, 5)
self.assertEqual(options.foo, 2)
def _define_options(self):
options = OptionParser()
options.define('str', type=str)
options.define('basestring', type=basestring_type)
options.define('int', type=int)
options.define('float', type=float)
options.define('datetime', type=datetime.datetime)
options.define('timedelta', type=datetime.timedelta)
options.define('email', type=Email)
options.define('list-of-int', type=int, multiple=True)
return options
def _check_options_values(self, options):
self.assertEqual(options.str, 'asdf')
self.assertEqual(options.basestring, 'qwer')
self.assertEqual(options.int, 42)
self.assertEqual(options.float, 1.5)
self.assertEqual(options.datetime,
datetime.datetime(2013, 4, 28, 5, 16))
self.assertEqual(options.timedelta, datetime.timedelta(seconds=45))
self.assertEqual(options.email.value, 'tornado@web.com')
self.assertTrue(isinstance(options.email, Email))
self.assertEqual(options.list_of_int, [1, 2, 3])
def test_types(self):
options = self._define_options()
options.parse_command_line(['main.py',
'--str=asdf',
'--basestring=qwer',
'--int=42',
'--float=1.5',
'--datetime=2013-04-28 05:16',
'--timedelta=45s',
'--email=tornado@web.com',
'--list-of-int=1,2,3'])
self._check_options_values(options)
def test_types_with_conf_file(self):
for config_file_name in ("options_test_types.cfg",
"options_test_types_str.cfg"):
options = self._define_options()
options.parse_config_file(os.path.join(os.path.dirname(__file__),
config_file_name))
self._check_options_values(options)
def test_multiple_string(self):
options = OptionParser()
options.define('foo', type=str, multiple=True)
options.parse_command_line(['main.py', '--foo=a,b,c'])
self.assertEqual(options.foo, ['a', 'b', 'c'])
def test_multiple_int(self):
options = OptionParser()
options.define('foo', type=int, multiple=True)
options.parse_command_line(['main.py', '--foo=1,3,5:7'])
self.assertEqual(options.foo, [1, 3, 5, 6, 7])
def test_error_redefine(self):
options = OptionParser()
options.define('foo')
with self.assertRaises(Error) as cm:
options.define('foo')
self.assertRegexpMatches(str(cm.exception),
'Option.*foo.*already defined')
def test_error_redefine_underscore(self):
# Ensure that the dash/underscore normalization doesn't
# interfere with the redefinition error.
tests = [
('foo-bar', 'foo-bar'),
('foo_bar', 'foo_bar'),
('foo-bar', 'foo_bar'),
('foo_bar', 'foo-bar'),
]
for a, b in tests:
with subTest(self, a=a, b=b):
options = OptionParser()
options.define(a)
with self.assertRaises(Error) as cm:
options.define(b)
self.assertRegexpMatches(str(cm.exception),
'Option.*foo.bar.*already defined')
def test_dash_underscore_cli(self):
# Dashes and underscores should be interchangeable.
for defined_name in ['foo-bar', 'foo_bar']:
for flag in ['--foo-bar=a', '--foo_bar=a']:
options = OptionParser()
options.define(defined_name)
options.parse_command_line(['main.py', flag])
# Attr-style access always uses underscores.
self.assertEqual(options.foo_bar, 'a')
# Dict-style access allows both.
self.assertEqual(options['foo-bar'], 'a')
self.assertEqual(options['foo_bar'], 'a')
def test_dash_underscore_file(self):
# No matter how an option was defined, it can be set with underscores
# in a config file.
for defined_name in ['foo-bar', 'foo_bar']:
options = OptionParser()
options.define(defined_name)
options.parse_config_file(os.path.join(os.path.dirname(__file__),
"options_test.cfg"))
self.assertEqual(options.foo_bar, 'a')
def test_dash_underscore_introspection(self):
# Original names are preserved in introspection APIs.
options = OptionParser()
options.define('with-dash', group='g')
options.define('with_underscore', group='g')
all_options = ['help', 'with-dash', 'with_underscore']
self.assertEqual(sorted(options), all_options)
self.assertEqual(sorted(k for (k, v) in options.items()), all_options)
self.assertEqual(sorted(options.as_dict().keys()), all_options)
self.assertEqual(sorted(options.group_dict('g')),
['with-dash', 'with_underscore'])
# --help shows CLI-style names with dashes.
buf = StringIO()
options.print_help(buf)
self.assertIn('--with-dash', buf.getvalue())
self.assertIn('--with-underscore', buf.getvalue())
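Taken together, these tests document the OptionParser contract: types are
enforced on assignment, "multiple" options split on commas, dashes and
underscores are interchangeable on the command line, and parsing stops at
the first positional argument. A short sketch (flag names and values are
arbitrary):

from tornado.options import OptionParser

options = OptionParser()
options.define('port', default=80, type=int)
options.define('allowed-hosts', type=str, multiple=True)

# Dash/underscore spellings are equivalent on the command line;
# attribute access always uses underscores.
rest = options.parse_command_line(
    ['main.py', '--port=443', '--allowed_hosts=a.com,b.com', 'subcommand'])
print(options.port)           # 443
print(options.allowed_hosts)  # ['a.com', 'b.com']
print(rest)                   # ['subcommand'] -- left for a sub-parser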

11
lib/tornado/test/options_test_types.cfg

@@ -0,0 +1,11 @@
from datetime import datetime, timedelta
from tornado.test.options_test import Email
str = 'asdf'
basestring = 'qwer'
int = 42
float = 1.5
datetime = datetime(2013, 4, 28, 5, 16)
timedelta = timedelta(0, 45)
email = Email('tornado@web.com')
list_of_int = [1, 2, 3]

8
lib/tornado/test/options_test_types_str.cfg

@@ -0,0 +1,8 @@
str = 'asdf'
basestring = 'qwer'
int = 42
float = 1.5
datetime = '2013-04-28 05:16'
timedelta = '45s'
email = 'tornado@web.com'
list_of_int = '1,2,3'

266
lib/tornado/test/process_test.py Executable file

@@ -0,0 +1,266 @@
from __future__ import absolute_import, division, print_function
import logging
import os
import signal
import subprocess
import sys
from tornado.httpclient import HTTPClient, HTTPError
from tornado.httpserver import HTTPServer
from tornado.ioloop import IOLoop
from tornado.log import gen_log
from tornado.process import fork_processes, task_id, Subprocess
from tornado.simple_httpclient import SimpleAsyncHTTPClient
from tornado.testing import bind_unused_port, ExpectLog, AsyncTestCase, gen_test
from tornado.test.util import unittest, skipIfNonUnix
from tornado.web import RequestHandler, Application
try:
import asyncio
except ImportError:
asyncio = None
def skip_if_twisted():
if IOLoop.configured_class().__name__.endswith('TwistedIOLoop'):
raise unittest.SkipTest("Process tests not compatible with TwistedIOLoop")
# Not using AsyncHTTPTestCase because we need control over the IOLoop.
@skipIfNonUnix
class ProcessTest(unittest.TestCase):
def get_app(self):
class ProcessHandler(RequestHandler):
def get(self):
if self.get_argument("exit", None):
# must use os._exit instead of sys.exit so unittest's
# exception handler doesn't catch it
os._exit(int(self.get_argument("exit")))
if self.get_argument("signal", None):
os.kill(os.getpid(),
int(self.get_argument("signal")))
self.write(str(os.getpid()))
return Application([("/", ProcessHandler)])
def tearDown(self):
if task_id() is not None:
# We're in a child process, and probably got to this point
# via an uncaught exception. If we return now, both
# processes will continue with the rest of the test suite.
# Exit now so the parent process will restart the child
# (since we don't have a clean way to signal failure to
# the parent that won't restart)
logging.error("aborting child process from tearDown")
logging.shutdown()
os._exit(1)
# In the surviving process, clear the alarm we set earlier
signal.alarm(0)
super(ProcessTest, self).tearDown()
def test_multi_process(self):
# This test doesn't work on twisted because we use the global
# reactor and don't restore it to a sane state after the fork
# (asyncio has the same issue, but we have a special case in
# place for it).
skip_if_twisted()
with ExpectLog(gen_log, "(Starting .* processes|child .* exited|uncaught exception)"):
sock, port = bind_unused_port()
def get_url(path):
return "http://127.0.0.1:%d%s" % (port, path)
# ensure that none of these processes live too long
signal.alarm(5) # master process
try:
id = fork_processes(3, max_restarts=3)
self.assertTrue(id is not None)
signal.alarm(5) # child processes
except SystemExit as e:
# if we exit cleanly from fork_processes, all the child processes
# finished with status 0
self.assertEqual(e.code, 0)
self.assertTrue(task_id() is None)
sock.close()
return
try:
if asyncio is not None:
# Reset the global asyncio event loop, which was put into
# a broken state by the fork.
asyncio.set_event_loop(asyncio.new_event_loop())
if id in (0, 1):
self.assertEqual(id, task_id())
server = HTTPServer(self.get_app())
server.add_sockets([sock])
IOLoop.current().start()
elif id == 2:
self.assertEqual(id, task_id())
sock.close()
# Always use SimpleAsyncHTTPClient here; the curl
# version appears to get confused sometimes if the
# connection gets closed before it's had a chance to
# switch from writing mode to reading mode.
client = HTTPClient(SimpleAsyncHTTPClient)
def fetch(url, fail_ok=False):
try:
return client.fetch(get_url(url))
except HTTPError as e:
if not (fail_ok and e.code == 599):
raise
# Make two processes exit abnormally
fetch("/?exit=2", fail_ok=True)
fetch("/?exit=3", fail_ok=True)
# They've been restarted, so a new fetch will work
int(fetch("/").body)
# Now the same with signals
# Disabled because on the mac a process dying with a signal
# can trigger an "Application exited abnormally; send error
# report to Apple?" prompt.
# fetch("/?signal=%d" % signal.SIGTERM, fail_ok=True)
# fetch("/?signal=%d" % signal.SIGABRT, fail_ok=True)
# int(fetch("/").body)
# Now kill them normally so they won't be restarted
fetch("/?exit=0", fail_ok=True)
# One process left; watch its pid change
pid = int(fetch("/").body)
fetch("/?exit=4", fail_ok=True)
pid2 = int(fetch("/").body)
self.assertNotEqual(pid, pid2)
# Kill the last one so we shut down cleanly
fetch("/?exit=0", fail_ok=True)
os._exit(0)
except Exception:
logging.error("exception in child process %d", id, exc_info=True)
raise
@skipIfNonUnix
class SubprocessTest(AsyncTestCase):
@gen_test
def test_subprocess(self):
if IOLoop.configured_class().__name__.endswith('LayeredTwistedIOLoop'):
# This test fails non-deterministically with LayeredTwistedIOLoop.
# (the read_until('\n') returns '\n' instead of 'hello\n')
# This probably indicates a problem with either TornadoReactor
# or TwistedIOLoop, but I haven't been able to track it down
# and for now this is just causing spurious travis-ci failures.
raise unittest.SkipTest("Subprocess tests not compatible with "
"LayeredTwistedIOLoop")
subproc = Subprocess([sys.executable, '-u', '-i'],
stdin=Subprocess.STREAM,
stdout=Subprocess.STREAM, stderr=subprocess.STDOUT)
self.addCleanup(lambda: (subproc.proc.terminate(), subproc.proc.wait()))
self.addCleanup(subproc.stdout.close)
self.addCleanup(subproc.stdin.close)
yield subproc.stdout.read_until(b'>>> ')
subproc.stdin.write(b"print('hello')\n")
data = yield subproc.stdout.read_until(b'\n')
self.assertEqual(data, b"hello\n")
yield subproc.stdout.read_until(b">>> ")
subproc.stdin.write(b"raise SystemExit\n")
data = yield subproc.stdout.read_until_close()
self.assertEqual(data, b"")
@gen_test
def test_close_stdin(self):
# Close the parent's stdin handle and see that the child recognizes it.
subproc = Subprocess([sys.executable, '-u', '-i'],
stdin=Subprocess.STREAM,
stdout=Subprocess.STREAM, stderr=subprocess.STDOUT)
self.addCleanup(lambda: (subproc.proc.terminate(), subproc.proc.wait()))
yield subproc.stdout.read_until(b'>>> ')
subproc.stdin.close()
data = yield subproc.stdout.read_until_close()
self.assertEqual(data, b"\n")
@gen_test
def test_stderr(self):
# This test is mysteriously flaky on twisted: it succeeds, but logs
# an error of EBADF on closing a file descriptor.
skip_if_twisted()
subproc = Subprocess([sys.executable, '-u', '-c',
r"import sys; sys.stderr.write('hello\n')"],
stderr=Subprocess.STREAM)
self.addCleanup(lambda: (subproc.proc.terminate(), subproc.proc.wait()))
data = yield subproc.stderr.read_until(b'\n')
self.assertEqual(data, b'hello\n')
# More mysterious EBADF: This fails if done with self.addCleanup instead of here.
subproc.stderr.close()
def test_sigchild(self):
# Twisted's SIGCHLD handler and Subprocess's handler conflict with each other.
skip_if_twisted()
Subprocess.initialize()
self.addCleanup(Subprocess.uninitialize)
subproc = Subprocess([sys.executable, '-c', 'pass'])
subproc.set_exit_callback(self.stop)
ret = self.wait()
self.assertEqual(ret, 0)
self.assertEqual(subproc.returncode, ret)
@gen_test
def test_sigchild_future(self):
skip_if_twisted()
Subprocess.initialize()
self.addCleanup(Subprocess.uninitialize)
subproc = Subprocess([sys.executable, '-c', 'pass'])
ret = yield subproc.wait_for_exit()
self.assertEqual(ret, 0)
self.assertEqual(subproc.returncode, ret)
def test_sigchild_signal(self):
skip_if_twisted()
Subprocess.initialize()
self.addCleanup(Subprocess.uninitialize)
subproc = Subprocess([sys.executable, '-c',
'import time; time.sleep(30)'],
stdout=Subprocess.STREAM)
self.addCleanup(subproc.stdout.close)
subproc.set_exit_callback(self.stop)
os.kill(subproc.pid, signal.SIGTERM)
try:
ret = self.wait(timeout=1.0)
except AssertionError:
# We failed to get the termination signal. This test is
# occasionally flaky on pypy, so try to get a little more
# information: did the process close its stdout
# (indicating that the problem is in the parent process's
# signal handling) or did the child process somehow fail
# to terminate?
subproc.stdout.read_until_close(callback=self.stop)
try:
self.wait(timeout=1.0)
except AssertionError:
raise AssertionError("subprocess failed to terminate")
else:
raise AssertionError("subprocess closed stdout but failed to "
"get termination signal")
self.assertEqual(subproc.returncode, ret)
self.assertEqual(ret, -signal.SIGTERM)
@gen_test
def test_wait_for_exit_raise(self):
skip_if_twisted()
Subprocess.initialize()
self.addCleanup(Subprocess.uninitialize)
subproc = Subprocess([sys.executable, '-c', 'import sys; sys.exit(1)'])
with self.assertRaises(subprocess.CalledProcessError) as cm:
yield subproc.wait_for_exit()
self.assertEqual(cm.exception.returncode, 1)
@gen_test
def test_wait_for_exit_raise_disabled(self):
skip_if_twisted()
Subprocess.initialize()
self.addCleanup(Subprocess.uninitialize)
subproc = Subprocess([sys.executable, '-c', 'import sys; sys.exit(1)'])
ret = yield subproc.wait_for_exit(raise_error=False)
self.assertEqual(ret, 1)
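The Subprocess tests above reduce to a small pattern: wrap the pipes you
care about in Subprocess.STREAM, read them asynchronously, then wait for
the exit status. A minimal Unix-only sketch (the child command is
arbitrary):

import subprocess
import sys
from tornado import gen, ioloop
from tornado.process import Subprocess

@gen.coroutine
def run():
    proc = Subprocess([sys.executable, '-c', "print('hi')"],
                      stdout=Subprocess.STREAM, stderr=subprocess.STDOUT)
    out = yield proc.stdout.read_until_close()
    # wait_for_exit() raises CalledProcessError on a non-zero status
    # unless raise_error=False is passed.
    ret = yield proc.wait_for_exit()
    raise gen.Return((ret, out))

print(ioloop.IOLoop.current().run_sync(run))  # (0, b'hi\n')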

423
lib/tornado/test/queues_test.py Executable file

@@ -0,0 +1,423 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from __future__ import absolute_import, division, print_function
from datetime import timedelta
from random import random
from tornado import gen, queues
from tornado.gen import TimeoutError
from tornado.testing import gen_test, AsyncTestCase
from tornado.test.util import unittest, skipBefore35, exec_test
class QueueBasicTest(AsyncTestCase):
def test_repr_and_str(self):
q = queues.Queue(maxsize=1)
self.assertIn(hex(id(q)), repr(q))
self.assertNotIn(hex(id(q)), str(q))
q.get()
for q_str in repr(q), str(q):
self.assertTrue(q_str.startswith('<Queue'))
self.assertIn('maxsize=1', q_str)
self.assertIn('getters[1]', q_str)
self.assertNotIn('putters', q_str)
self.assertNotIn('tasks', q_str)
q.put(None)
q.put(None)
# Now the queue is full, this putter blocks.
q.put(None)
for q_str in repr(q), str(q):
self.assertNotIn('getters', q_str)
self.assertIn('putters[1]', q_str)
self.assertIn('tasks=2', q_str)
def test_order(self):
q = queues.Queue()
for i in [1, 3, 2]:
q.put_nowait(i)
items = [q.get_nowait() for _ in range(3)]
self.assertEqual([1, 3, 2], items)
@gen_test
def test_maxsize(self):
self.assertRaises(TypeError, queues.Queue, maxsize=None)
self.assertRaises(ValueError, queues.Queue, maxsize=-1)
q = queues.Queue(maxsize=2)
self.assertTrue(q.empty())
self.assertFalse(q.full())
self.assertEqual(2, q.maxsize)
self.assertTrue(q.put(0).done())
self.assertTrue(q.put(1).done())
self.assertFalse(q.empty())
self.assertTrue(q.full())
put2 = q.put(2)
self.assertFalse(put2.done())
self.assertEqual(0, (yield q.get())) # Make room.
self.assertTrue(put2.done())
self.assertFalse(q.empty())
self.assertTrue(q.full())
class QueueGetTest(AsyncTestCase):
@gen_test
def test_blocking_get(self):
q = queues.Queue()
q.put_nowait(0)
self.assertEqual(0, (yield q.get()))
def test_nonblocking_get(self):
q = queues.Queue()
q.put_nowait(0)
self.assertEqual(0, q.get_nowait())
def test_nonblocking_get_exception(self):
q = queues.Queue()
self.assertRaises(queues.QueueEmpty, q.get_nowait)
@gen_test
def test_get_with_putters(self):
q = queues.Queue(1)
q.put_nowait(0)
put = q.put(1)
self.assertEqual(0, (yield q.get()))
self.assertIsNone((yield put))
@gen_test
def test_blocking_get_wait(self):
q = queues.Queue()
q.put(0)
self.io_loop.call_later(0.01, q.put, 1)
self.io_loop.call_later(0.02, q.put, 2)
self.assertEqual(0, (yield q.get(timeout=timedelta(seconds=1))))
self.assertEqual(1, (yield q.get(timeout=timedelta(seconds=1))))
@gen_test
def test_get_timeout(self):
q = queues.Queue()
get_timeout = q.get(timeout=timedelta(seconds=0.01))
get = q.get()
with self.assertRaises(TimeoutError):
yield get_timeout
q.put_nowait(0)
self.assertEqual(0, (yield get))
@gen_test
def test_get_timeout_preempted(self):
q = queues.Queue()
get = q.get(timeout=timedelta(seconds=0.01))
q.put(0)
yield gen.sleep(0.02)
self.assertEqual(0, (yield get))
@gen_test
def test_get_clears_timed_out_putters(self):
q = queues.Queue(1)
# First putter succeeds, remainder block.
putters = [q.put(i, timedelta(seconds=0.01)) for i in range(10)]
put = q.put(10)
self.assertEqual(10, len(q._putters))
yield gen.sleep(0.02)
self.assertEqual(10, len(q._putters))
self.assertFalse(put.done()) # Final waiter is still active.
q.put(11)
self.assertEqual(0, (yield q.get())) # get() clears the waiters.
self.assertEqual(1, len(q._putters))
for putter in putters[1:]:
self.assertRaises(TimeoutError, putter.result)
@gen_test
def test_get_clears_timed_out_getters(self):
q = queues.Queue()
getters = [q.get(timedelta(seconds=0.01)) for _ in range(10)]
get = q.get()
self.assertEqual(11, len(q._getters))
yield gen.sleep(0.02)
self.assertEqual(11, len(q._getters))
self.assertFalse(get.done()) # Final waiter is still active.
q.get() # get() clears the waiters.
self.assertEqual(2, len(q._getters))
for getter in getters:
self.assertRaises(TimeoutError, getter.result)
@skipBefore35
@gen_test
def test_async_for(self):
q = queues.Queue()
for i in range(5):
q.put(i)
namespace = exec_test(globals(), locals(), """
async def f():
results = []
async for i in q:
results.append(i)
if i == 4:
return results
""")
results = yield namespace['f']()
self.assertEqual(results, list(range(5)))
class QueuePutTest(AsyncTestCase):
@gen_test
def test_blocking_put(self):
q = queues.Queue()
q.put(0)
self.assertEqual(0, q.get_nowait())
def test_nonblocking_put_exception(self):
q = queues.Queue(1)
q.put(0)
self.assertRaises(queues.QueueFull, q.put_nowait, 1)
@gen_test
def test_put_with_getters(self):
q = queues.Queue()
get0 = q.get()
get1 = q.get()
yield q.put(0)
self.assertEqual(0, (yield get0))
yield q.put(1)
self.assertEqual(1, (yield get1))
@gen_test
def test_nonblocking_put_with_getters(self):
q = queues.Queue()
get0 = q.get()
get1 = q.get()
q.put_nowait(0)
# put_nowait does *not* immediately unblock getters.
yield gen.moment
self.assertEqual(0, (yield get0))
q.put_nowait(1)
yield gen.moment
self.assertEqual(1, (yield get1))
@gen_test
def test_blocking_put_wait(self):
q = queues.Queue(1)
q.put_nowait(0)
self.io_loop.call_later(0.01, q.get)
self.io_loop.call_later(0.02, q.get)
futures = [q.put(0), q.put(1)]
self.assertFalse(any(f.done() for f in futures))
yield futures
@gen_test
def test_put_timeout(self):
q = queues.Queue(1)
q.put_nowait(0) # Now it's full.
put_timeout = q.put(1, timeout=timedelta(seconds=0.01))
put = q.put(2)
with self.assertRaises(TimeoutError):
yield put_timeout
self.assertEqual(0, q.get_nowait())
# 1 was never put in the queue.
self.assertEqual(2, (yield q.get()))
# Final get() unblocked this putter.
yield put
@gen_test
def test_put_timeout_preempted(self):
q = queues.Queue(1)
q.put_nowait(0)
put = q.put(1, timeout=timedelta(seconds=0.01))
q.get()
yield gen.sleep(0.02)
yield put # No TimeoutError.
@gen_test
def test_put_clears_timed_out_putters(self):
q = queues.Queue(1)
# First putter succeeds, remainder block.
putters = [q.put(i, timedelta(seconds=0.01)) for i in range(10)]
put = q.put(10)
self.assertEqual(10, len(q._putters))
yield gen.sleep(0.02)
self.assertEqual(10, len(q._putters))
self.assertFalse(put.done()) # Final waiter is still active.
q.put(11) # put() clears the waiters.
self.assertEqual(2, len(q._putters))
for putter in putters[1:]:
self.assertRaises(TimeoutError, putter.result)
@gen_test
def test_put_clears_timed_out_getters(self):
q = queues.Queue()
getters = [q.get(timedelta(seconds=0.01)) for _ in range(10)]
get = q.get()
q.get()
self.assertEqual(12, len(q._getters))
yield gen.sleep(0.02)
self.assertEqual(12, len(q._getters))
self.assertFalse(get.done()) # Final waiters still active.
q.put(0) # put() clears the waiters.
self.assertEqual(1, len(q._getters))
self.assertEqual(0, (yield get))
for getter in getters:
self.assertRaises(TimeoutError, getter.result)
@gen_test
def test_float_maxsize(self):
# Non-int maxsize must round down: http://bugs.python.org/issue21723
q = queues.Queue(maxsize=1.3)
self.assertTrue(q.empty())
self.assertFalse(q.full())
q.put_nowait(0)
q.put_nowait(1)
self.assertFalse(q.empty())
self.assertTrue(q.full())
self.assertRaises(queues.QueueFull, q.put_nowait, 2)
self.assertEqual(0, q.get_nowait())
self.assertFalse(q.empty())
self.assertFalse(q.full())
yield q.put(2)
put = q.put(3)
self.assertFalse(put.done())
self.assertEqual(1, (yield q.get()))
yield put
self.assertTrue(q.full())
class QueueJoinTest(AsyncTestCase):
queue_class = queues.Queue
def test_task_done_underflow(self):
q = self.queue_class()
self.assertRaises(ValueError, q.task_done)
@gen_test
def test_task_done(self):
q = self.queue_class()
for i in range(100):
q.put_nowait(i)
self.accumulator = 0
@gen.coroutine
def worker():
while True:
item = yield q.get()
self.accumulator += item
q.task_done()
yield gen.sleep(random() * 0.01)
# Two coroutines share work.
worker()
worker()
yield q.join()
self.assertEqual(sum(range(100)), self.accumulator)
@gen_test
def test_task_done_delay(self):
# Verify it is task_done(), not get(), that unblocks join().
q = self.queue_class()
q.put_nowait(0)
join = q.join()
self.assertFalse(join.done())
yield q.get()
self.assertFalse(join.done())
yield gen.moment
self.assertFalse(join.done())
q.task_done()
self.assertTrue(join.done())
@gen_test
def test_join_empty_queue(self):
q = self.queue_class()
yield q.join()
yield q.join()
@gen_test
def test_join_timeout(self):
q = self.queue_class()
q.put(0)
with self.assertRaises(TimeoutError):
yield q.join(timeout=timedelta(seconds=0.01))
class PriorityQueueJoinTest(QueueJoinTest):
queue_class = queues.PriorityQueue
@gen_test
def test_order(self):
q = self.queue_class(maxsize=2)
q.put_nowait((1, 'a'))
q.put_nowait((0, 'b'))
self.assertTrue(q.full())
q.put((3, 'c'))
q.put((2, 'd'))
self.assertEqual((0, 'b'), q.get_nowait())
self.assertEqual((1, 'a'), (yield q.get()))
self.assertEqual((2, 'd'), q.get_nowait())
self.assertEqual((3, 'c'), (yield q.get()))
self.assertTrue(q.empty())
class LifoQueueJoinTest(QueueJoinTest):
queue_class = queues.LifoQueue
@gen_test
def test_order(self):
q = self.queue_class(maxsize=2)
q.put_nowait(1)
q.put_nowait(0)
self.assertTrue(q.full())
q.put(3)
q.put(2)
self.assertEqual(3, q.get_nowait())
self.assertEqual(2, (yield q.get()))
self.assertEqual(0, q.get_nowait())
self.assertEqual(1, (yield q.get()))
self.assertTrue(q.empty())
class ProducerConsumerTest(AsyncTestCase):
@gen_test
def test_producer_consumer(self):
q = queues.Queue(maxsize=3)
history = []
# We don't yield between get() and task_done(), so get() must wait for
# the next tick. Otherwise we'd immediately call task_done and unblock
# join() before q.put() resumes, and we'd only process the first four
# items.
@gen.coroutine
def consumer():
while True:
history.append((yield q.get()))
q.task_done()
@gen.coroutine
def producer():
for item in range(10):
yield q.put(item)
consumer()
yield producer()
yield q.join()
self.assertEqual(list(range(10)), history)
if __name__ == '__main__':
unittest.main()
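The join()/task_done() protocol the last test depends on is easiest to see
in isolation; a minimal sketch (queue size and item count are arbitrary):

from tornado import gen, ioloop, queues

q = queues.Queue(maxsize=2)

@gen.coroutine
def producer():
    for item in range(5):
        yield q.put(item)  # blocks while the queue is full

@gen.coroutine
def consumer():
    while True:
        item = yield q.get()
        print('got', item)
        q.task_done()  # task_done(), not get(), is what unblocks join()

@gen.coroutine
def main():
    consumer()      # runs in the background
    yield producer()
    yield q.join()  # resolves once every item has been marked done

ioloop.IOLoop.current().run_sync(main)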

11
lib/tornado/test/resolve_test_helper.py

@@ -0,0 +1,11 @@
from __future__ import absolute_import, division, print_function
from tornado.ioloop import IOLoop
from tornado.netutil import ThreadedResolver
# When this module is imported, it runs getaddrinfo on a thread. Since
# the hostname is unicode, getaddrinfo attempts to import encodings.idna
# but blocks on the import lock. Verify that ThreadedResolver avoids
# this deadlock.
resolver = ThreadedResolver()
IOLoop.current().run_sync(lambda: resolver.resolve(u'localhost', 80))

247
lib/tornado/test/routing_test.py Executable file

@@ -0,0 +1,247 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from __future__ import absolute_import, division, print_function
from tornado.httputil import HTTPHeaders, HTTPMessageDelegate, HTTPServerConnectionDelegate, ResponseStartLine # noqa: E501
from tornado.routing import HostMatches, PathMatches, ReversibleRouter, Router, Rule, RuleRouter
from tornado.testing import AsyncHTTPTestCase
from tornado.web import Application, HTTPError, RequestHandler
from tornado.wsgi import WSGIContainer
class BasicRouter(Router):
def find_handler(self, request, **kwargs):
class MessageDelegate(HTTPMessageDelegate):
def __init__(self, connection):
self.connection = connection
def finish(self):
self.connection.write_headers(
ResponseStartLine("HTTP/1.1", 200, "OK"),
HTTPHeaders({"Content-Length": "2"}),
b"OK"
)
self.connection.finish()
return MessageDelegate(request.connection)
class BasicRouterTestCase(AsyncHTTPTestCase):
def get_app(self):
return BasicRouter()
def test_basic_router(self):
response = self.fetch("/any_request")
self.assertEqual(response.body, b"OK")
resources = {}
class GetResource(RequestHandler):
def get(self, path):
if path not in resources:
raise HTTPError(404)
self.finish(resources[path])
class PostResource(RequestHandler):
def post(self, path):
resources[path] = self.request.body
class HTTPMethodRouter(Router):
def __init__(self, app):
self.app = app
def find_handler(self, request, **kwargs):
handler = GetResource if request.method == "GET" else PostResource
return self.app.get_handler_delegate(request, handler, path_args=[request.path])
class HTTPMethodRouterTestCase(AsyncHTTPTestCase):
def get_app(self):
return HTTPMethodRouter(Application())
def test_http_method_router(self):
response = self.fetch("/post_resource", method="POST", body="data")
self.assertEqual(response.code, 200)
response = self.fetch("/get_resource")
self.assertEqual(response.code, 404)
response = self.fetch("/post_resource")
self.assertEqual(response.code, 200)
self.assertEqual(response.body, b"data")
def _get_named_handler(handler_name):
class Handler(RequestHandler):
def get(self, *args, **kwargs):
if self.application.settings.get("app_name") is not None:
self.write(self.application.settings["app_name"] + ": ")
self.finish(handler_name + ": " + self.reverse_url(handler_name))
return Handler
FirstHandler = _get_named_handler("first_handler")
SecondHandler = _get_named_handler("second_handler")
class CustomRouter(ReversibleRouter):
def __init__(self):
super(CustomRouter, self).__init__()
self.routes = {}
def add_routes(self, routes):
self.routes.update(routes)
def find_handler(self, request, **kwargs):
if request.path in self.routes:
app, handler = self.routes[request.path]
return app.get_handler_delegate(request, handler)
def reverse_url(self, name, *args):
handler_path = '/' + name
return handler_path if handler_path in self.routes else None
class CustomRouterTestCase(AsyncHTTPTestCase):
def get_app(self):
class CustomApplication(Application):
def reverse_url(self, name, *args):
return router.reverse_url(name, *args)
router = CustomRouter()
app1 = CustomApplication(app_name="app1")
app2 = CustomApplication(app_name="app2")
router.add_routes({
"/first_handler": (app1, FirstHandler),
"/second_handler": (app2, SecondHandler),
"/first_handler_second_app": (app2, FirstHandler),
})
return router
def test_custom_router(self):
response = self.fetch("/first_handler")
self.assertEqual(response.body, b"app1: first_handler: /first_handler")
response = self.fetch("/second_handler")
self.assertEqual(response.body, b"app2: second_handler: /second_handler")
response = self.fetch("/first_handler_second_app")
self.assertEqual(response.body, b"app2: first_handler: /first_handler")
class ConnectionDelegate(HTTPServerConnectionDelegate):
def start_request(self, server_conn, request_conn):
class MessageDelegate(HTTPMessageDelegate):
def __init__(self, connection):
self.connection = connection
def finish(self):
response_body = b"OK"
self.connection.write_headers(
ResponseStartLine("HTTP/1.1", 200, "OK"),
HTTPHeaders({"Content-Length": str(len(response_body))}))
self.connection.write(response_body)
self.connection.finish()
return MessageDelegate(request_conn)
class RuleRouterTest(AsyncHTTPTestCase):
def get_app(self):
app = Application()
def request_callable(request):
request.connection.write_headers(
ResponseStartLine("HTTP/1.1", 200, "OK"),
HTTPHeaders({"Content-Length": "2"}))
request.connection.write(b"OK")
request.connection.finish()
router = CustomRouter()
router.add_routes({
"/nested_handler": (app, _get_named_handler("nested_handler"))
})
app.add_handlers(".*", [
(HostMatches("www.example.com"), [
(PathMatches("/first_handler"),
"tornado.test.routing_test.SecondHandler", {}, "second_handler")
]),
Rule(PathMatches("/.*handler"), router),
Rule(PathMatches("/first_handler"), FirstHandler, name="first_handler"),
Rule(PathMatches("/request_callable"), request_callable),
("/connection_delegate", ConnectionDelegate())
])
return app
def test_rule_based_router(self):
response = self.fetch("/first_handler")
self.assertEqual(response.body, b"first_handler: /first_handler")
response = self.fetch("/first_handler", headers={'Host': 'www.example.com'})
self.assertEqual(response.body, b"second_handler: /first_handler")
response = self.fetch("/nested_handler")
self.assertEqual(response.body, b"nested_handler: /nested_handler")
response = self.fetch("/nested_not_found_handler")
self.assertEqual(response.code, 404)
response = self.fetch("/connection_delegate")
self.assertEqual(response.body, b"OK")
response = self.fetch("/request_callable")
self.assertEqual(response.body, b"OK")
response = self.fetch("/404")
self.assertEqual(response.code, 404)
class WSGIContainerTestCase(AsyncHTTPTestCase):
def get_app(self):
wsgi_app = WSGIContainer(self.wsgi_app)
class Handler(RequestHandler):
def get(self, *args, **kwargs):
self.finish(self.reverse_url("tornado"))
return RuleRouter([
(PathMatches("/tornado.*"), Application([(r"/tornado/test", Handler, {}, "tornado")])),
(PathMatches("/wsgi"), wsgi_app),
])
def wsgi_app(self, environ, start_response):
start_response("200 OK", [])
return [b"WSGI"]
def test_wsgi_container(self):
response = self.fetch("/tornado/test")
self.assertEqual(response.body, b"/tornado/test")
response = self.fetch("/wsgi")
self.assertEqual(response.body, b"WSGI")
def test_delegate_not_found(self):
response = self.fetch("/404")
self.assertEqual(response.code, 404)
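The core pattern these tests exercise: a RuleRouter dispatches whole
requests, and each rule's target may be an Application, another router, a
connection delegate, or a plain callable. A minimal sketch (paths and port
are illustrative):

from tornado.routing import PathMatches, Rule, RuleRouter
from tornado.web import Application, RequestHandler

class Hello(RequestHandler):
    def get(self):
        self.write('hello')

router = RuleRouter([
    Rule(PathMatches('/app.*'),
         Application([('/app/hello', Hello)])),
])

# To serve it (sketch):
#   from tornado.httpserver import HTTPServer
#   from tornado.ioloop import IOLoop
#   HTTPServer(router).listen(8888)
#   IOLoop.current().start()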

219
lib/tornado/test/runtests.py Executable file

@@ -0,0 +1,219 @@
from __future__ import absolute_import, division, print_function
import gc
import io
import locale # system locale module, not tornado.locale
import logging
import operator
import textwrap
import sys
from tornado.httpclient import AsyncHTTPClient
from tornado.httpserver import HTTPServer
from tornado.ioloop import IOLoop
from tornado.netutil import Resolver
from tornado.options import define, options, add_parse_callback
from tornado.test.util import unittest
try:
reduce # py2
except NameError:
from functools import reduce # py3
TEST_MODULES = [
'tornado.httputil.doctests',
'tornado.iostream.doctests',
'tornado.util.doctests',
'tornado.test.asyncio_test',
'tornado.test.auth_test',
'tornado.test.autoreload_test',
'tornado.test.concurrent_test',
'tornado.test.curl_httpclient_test',
'tornado.test.escape_test',
'tornado.test.gen_test',
'tornado.test.http1connection_test',
'tornado.test.httpclient_test',
'tornado.test.httpserver_test',
'tornado.test.httputil_test',
'tornado.test.import_test',
'tornado.test.ioloop_test',
'tornado.test.iostream_test',
'tornado.test.locale_test',
'tornado.test.locks_test',
'tornado.test.netutil_test',
'tornado.test.log_test',
'tornado.test.options_test',
'tornado.test.process_test',
'tornado.test.queues_test',
'tornado.test.routing_test',
'tornado.test.simple_httpclient_test',
'tornado.test.stack_context_test',
'tornado.test.tcpclient_test',
'tornado.test.tcpserver_test',
'tornado.test.template_test',
'tornado.test.testing_test',
'tornado.test.twisted_test',
'tornado.test.util_test',
'tornado.test.web_test',
'tornado.test.websocket_test',
'tornado.test.windows_test',
'tornado.test.wsgi_test',
]
def all():
return unittest.defaultTestLoader.loadTestsFromNames(TEST_MODULES)
def test_runner_factory(stderr):
class TornadoTextTestRunner(unittest.TextTestRunner):
def __init__(self, *args, **kwargs):
super(TornadoTextTestRunner, self).__init__(*args, stream=stderr, **kwargs)
def run(self, test):
result = super(TornadoTextTestRunner, self).run(test)
if result.skipped:
skip_reasons = set(reason for (test, reason) in result.skipped)
self.stream.write(textwrap.fill(
"Some tests were skipped because: %s" %
", ".join(sorted(skip_reasons))))
self.stream.write("\n")
return result
return TornadoTextTestRunner
class LogCounter(logging.Filter):
"""Counts the number of WARNING or higher log records."""
def __init__(self, *args, **kwargs):
# Can't use super() because logging.Filter is an old-style class in py26
logging.Filter.__init__(self, *args, **kwargs)
self.info_count = self.warning_count = self.error_count = 0
def filter(self, record):
if record.levelno >= logging.ERROR:
self.error_count += 1
elif record.levelno >= logging.WARNING:
self.warning_count += 1
elif record.levelno >= logging.INFO:
self.info_count += 1
return True
class CountingStderr(io.IOBase):
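    # Proxy for the real stderr that counts every byte written; main()
    # treats any direct stderr output (e.g. from destructors) as a failure.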
def __init__(self, real):
self.real = real
self.byte_count = 0
def write(self, data):
self.byte_count += len(data)
return self.real.write(data)
def flush(self):
return self.real.flush()
def main():
# The -W command-line option does not work in a virtualenv with
# python 3 (as of virtualenv 1.7), so configure warnings
# programmatically instead.
import warnings
# Be strict about most warnings. This also turns on warnings that are
# ignored by default, including DeprecationWarnings and
# python 3.2's ResourceWarnings.
warnings.filterwarnings("error")
# setuptools sometimes gives ImportWarnings about things that are on
# sys.path even if they're not being used.
warnings.filterwarnings("ignore", category=ImportWarning)
# Tornado generally shouldn't use anything deprecated, but some of
# our dependencies do (last match wins).
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("error", category=DeprecationWarning,
module=r"tornado\..*")
warnings.filterwarnings("ignore", category=PendingDeprecationWarning)
warnings.filterwarnings("error", category=PendingDeprecationWarning,
module=r"tornado\..*")
# The unittest module is aggressive about deprecating redundant methods,
# leaving some without non-deprecated spellings that work on both
# 2.7 and 3.2
warnings.filterwarnings("ignore", category=DeprecationWarning,
message="Please use assert.* instead")
warnings.filterwarnings("ignore", category=PendingDeprecationWarning,
message="Please use assert.* instead")
# Twisted 15.0.0 triggers some warnings on py3 with -bb.
warnings.filterwarnings("ignore", category=BytesWarning,
module=r"twisted\..*")
if (3,) < sys.version_info < (3, 6):
# Prior to 3.6, async ResourceWarnings were rather noisy
# and even
# `python3.4 -W error -c 'import asyncio; asyncio.get_event_loop()'`
# would generate a warning.
warnings.filterwarnings("ignore", category=ResourceWarning, # noqa: F821
module=r"asyncio\..*")
logging.getLogger("tornado.access").setLevel(logging.CRITICAL)
define('httpclient', type=str, default=None,
callback=lambda s: AsyncHTTPClient.configure(
s, defaults=dict(allow_ipv6=False)))
define('httpserver', type=str, default=None,
callback=HTTPServer.configure)
define('ioloop', type=str, default=None)
define('ioloop_time_monotonic', default=False)
define('resolver', type=str, default=None,
callback=Resolver.configure)
define('debug_gc', type=str, multiple=True,
help="A comma-separated list of gc module debug constants, "
"e.g. DEBUG_STATS or DEBUG_COLLECTABLE,DEBUG_OBJECTS",
callback=lambda values: gc.set_debug(
reduce(operator.or_, (getattr(gc, v) for v in values))))
define('locale', type=str, default=None,
callback=lambda x: locale.setlocale(locale.LC_ALL, x))
def configure_ioloop():
kwargs = {}
if options.ioloop_time_monotonic:
from tornado.platform.auto import monotonic_time
if monotonic_time is None:
raise RuntimeError("monotonic clock not found")
kwargs['time_func'] = monotonic_time
if options.ioloop or kwargs:
IOLoop.configure(options.ioloop, **kwargs)
add_parse_callback(configure_ioloop)
log_counter = LogCounter()
add_parse_callback(
lambda: logging.getLogger().handlers[0].addFilter(log_counter))
# Certain errors (especially "unclosed resource" errors raised in
# destructors) go directly to stderr instead of logging. Count
# anything written by anything but the test runner as an error.
orig_stderr = sys.stderr
sys.stderr = CountingStderr(orig_stderr)
import tornado.testing
kwargs = {}
if sys.version_info >= (3, 2):
# HACK: unittest.main will make its own changes to the warning
# configuration, which may conflict with the settings above
# or command-line flags like -bb. Passing warnings=False
# suppresses this behavior, although this looks like an implementation
# detail. http://bugs.python.org/issue15626
kwargs['warnings'] = False
kwargs['testRunner'] = test_runner_factory(orig_stderr)
try:
tornado.testing.main(**kwargs)
finally:
# The tests should run clean; consider it a failure if they
# logged anything at info level or above.
if (log_counter.info_count > 0 or
log_counter.warning_count > 0 or
log_counter.error_count > 0 or
sys.stderr.byte_count > 0):
logging.error("logged %d infos, %d warnings, %d errors, and %d bytes to stderr",
log_counter.info_count, log_counter.warning_count,
log_counter.error_count, sys.stderr.byte_count)
sys.exit(1)
if __name__ == '__main__':
main()

lib/tornado/test/simple_httpclient_test.py
View File

@@ -0,0 +1,767 @@
from __future__ import absolute_import, division, print_function
import collections
from contextlib import closing
import errno
import gzip
import logging
import os
import re
import socket
import ssl
import sys
from tornado.escape import to_unicode, utf8
from tornado import gen
from tornado.httpclient import AsyncHTTPClient
from tornado.httputil import HTTPHeaders, ResponseStartLine
from tornado.ioloop import IOLoop
from tornado.iostream import UnsatisfiableReadError
from tornado.locks import Event
from tornado.log import gen_log
from tornado.concurrent import Future
from tornado.netutil import Resolver, bind_sockets
from tornado.simple_httpclient import SimpleAsyncHTTPClient, HTTPStreamClosedError, HTTPTimeoutError
from tornado.test.httpclient_test import ChunkHandler, CountdownHandler, HelloWorldHandler, RedirectHandler # noqa: E501
from tornado.test import httpclient_test
from tornado.testing import (AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase,
ExpectLog, gen_test)
from tornado.test.util import skipOnTravis, skipIfNoIPv6, refusing_port, skipBefore35, exec_test
from tornado.web import RequestHandler, Application, url, stream_request_body
class SimpleHTTPClientCommonTestCase(httpclient_test.HTTPClientCommonTestCase):
def get_http_client(self):
client = SimpleAsyncHTTPClient(force_instance=True)
self.assertTrue(isinstance(client, SimpleAsyncHTTPClient))
return client
class TriggerHandler(RequestHandler):
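    # Parks each request on a shared queue without finishing it; tests
    # later pop and call the stored finish callbacks to release requests
    # one at a time.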
def initialize(self, queue, wake_callback):
self.queue = queue
self.wake_callback = wake_callback
@gen.coroutine
def get(self):
logging.debug("queuing trigger")
self.queue.append(self.finish)
if self.get_argument("wake", "true") == "true":
self.wake_callback()
never_finish = Event()
yield never_finish.wait()
class HangHandler(RequestHandler):
@gen.coroutine
def get(self):
never_finish = Event()
yield never_finish.wait()
class ContentLengthHandler(RequestHandler):
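    # Detaches the connection and writes a raw HTTP/1.0 response so tests
    # can supply arbitrary (even multiple or conflicting) Content-Length
    # values.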
def get(self):
self.stream = self.detach()
IOLoop.current().spawn_callback(self.write_response)
@gen.coroutine
def write_response(self):
yield self.stream.write(utf8("HTTP/1.0 200 OK\r\nContent-Length: %s\r\n\r\nok" %
self.get_argument("value")))
self.stream.close()
class HeadHandler(RequestHandler):
def head(self):
self.set_header("Content-Length", "7")
class OptionsHandler(RequestHandler):
def options(self):
self.set_header("Access-Control-Allow-Origin", "*")
self.write("ok")
class NoContentHandler(RequestHandler):
def get(self):
self.set_status(204)
self.finish()
class SeeOtherPostHandler(RequestHandler):
def post(self):
redirect_code = int(self.request.body)
assert redirect_code in (302, 303), "unexpected body %r" % self.request.body
self.set_header("Location", "/see_other_get")
self.set_status(redirect_code)
class SeeOtherGetHandler(RequestHandler):
def get(self):
if self.request.body:
raise Exception("unexpected body %r" % self.request.body)
self.write("ok")
class HostEchoHandler(RequestHandler):
def get(self):
self.write(self.request.headers["Host"])
class NoContentLengthHandler(RequestHandler):
def get(self):
if self.request.version.startswith('HTTP/1'):
# Emulate the old HTTP/1.0 behavior of returning a body with no
# content-length. Tornado handles content-length at the framework
# level so we have to go around it.
stream = self.detach()
stream.write(b"HTTP/1.0 200 OK\r\n\r\n"
b"hello")
stream.close()
else:
self.finish('HTTP/1 required')
class EchoPostHandler(RequestHandler):
def post(self):
self.write(self.request.body)
@stream_request_body
class RespondInPrepareHandler(RequestHandler):
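    # Responds with 403 from prepare(), before any request body is read;
    # used to check that expect_100_continue clients never run their
    # body producer when the server responds early.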
def prepare(self):
self.set_status(403)
self.finish("forbidden")
class SimpleHTTPClientTestMixin(object):
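    # Shared test body for SimpleAsyncHTTPClient; the concrete HTTP and
    # HTTPS test cases later in the file mix this in.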
def get_app(self):
# callable objects to finish pending /trigger requests
self.triggers = collections.deque()
return Application([
url("/trigger", TriggerHandler, dict(queue=self.triggers,
wake_callback=self.stop)),
url("/chunk", ChunkHandler),
url("/countdown/([0-9]+)", CountdownHandler, name="countdown"),
url("/hang", HangHandler),
url("/hello", HelloWorldHandler),
url("/content_length", ContentLengthHandler),
url("/head", HeadHandler),
url("/options", OptionsHandler),
url("/no_content", NoContentHandler),
url("/see_other_post", SeeOtherPostHandler),
url("/see_other_get", SeeOtherGetHandler),
url("/host_echo", HostEchoHandler),
url("/no_content_length", NoContentLengthHandler),
url("/echo_post", EchoPostHandler),
url("/respond_in_prepare", RespondInPrepareHandler),
url("/redirect", RedirectHandler),
], gzip=True)
def test_singleton(self):
# Class "constructor" reuses objects on the same IOLoop
self.assertTrue(SimpleAsyncHTTPClient() is
SimpleAsyncHTTPClient())
# unless force_instance is used
self.assertTrue(SimpleAsyncHTTPClient() is not
SimpleAsyncHTTPClient(force_instance=True))
# different IOLoops use different objects
with closing(IOLoop()) as io_loop2:
client1 = self.io_loop.run_sync(gen.coroutine(SimpleAsyncHTTPClient))
client2 = io_loop2.run_sync(gen.coroutine(SimpleAsyncHTTPClient))
self.assertTrue(client1 is not client2)
def test_connection_limit(self):
with closing(self.create_client(max_clients=2)) as client:
self.assertEqual(client.max_clients, 2)
seen = []
# Send 4 requests. Two can be sent immediately, while the others
# will be queued
for i in range(4):
client.fetch(self.get_url("/trigger")).add_done_callback(
lambda fut, i=i: (seen.append(i), self.stop()))
self.wait(condition=lambda: len(self.triggers) == 2)
self.assertEqual(len(client.queue), 2)
# Finish the first two requests and let the next two through
self.triggers.popleft()()
self.triggers.popleft()()
self.wait(condition=lambda: (len(self.triggers) == 2 and
len(seen) == 2))
self.assertEqual(set(seen), set([0, 1]))
self.assertEqual(len(client.queue), 0)
# Finish all the pending requests
self.triggers.popleft()()
self.triggers.popleft()()
self.wait(condition=lambda: len(seen) == 4)
self.assertEqual(set(seen), set([0, 1, 2, 3]))
self.assertEqual(len(self.triggers), 0)
@gen_test
def test_redirect_connection_limit(self):
# following redirects should not consume additional connections
with closing(self.create_client(max_clients=1)) as client:
response = yield client.fetch(self.get_url('/countdown/3'),
max_redirects=3)
response.rethrow()
def test_gzip(self):
# All the tests in this file should be using gzip, but this test
# ensures that it is in fact getting compressed.
# Setting Accept-Encoding manually bypasses the client's
# decompression so we can see the raw data.
response = self.fetch("/chunk", use_gzip=False,
headers={"Accept-Encoding": "gzip"})
self.assertEqual(response.headers["Content-Encoding"], "gzip")
self.assertNotEqual(response.body, b"asdfqwer")
# Our test data gets bigger when gzipped. Oops. :)
# Chunked encoding bypasses the MIN_LENGTH check.
self.assertEqual(len(response.body), 34)
f = gzip.GzipFile(mode="r", fileobj=response.buffer)
self.assertEqual(f.read(), b"asdfqwer")
def test_max_redirects(self):
response = self.fetch("/countdown/5", max_redirects=3)
self.assertEqual(302, response.code)
# We requested 5, followed three redirects for 4, 3, 2, then the last
# unfollowed redirect is to 1.
self.assertTrue(response.request.url.endswith("/countdown/5"))
self.assertTrue(response.effective_url.endswith("/countdown/2"))
self.assertTrue(response.headers["Location"].endswith("/countdown/1"))
def test_header_reuse(self):
# Apps may reuse a headers object if they are only passing in constant
# headers like user-agent. The header object should not be modified.
headers = HTTPHeaders({'User-Agent': 'Foo'})
self.fetch("/hello", headers=headers)
self.assertEqual(list(headers.get_all()), [('User-Agent', 'Foo')])
def test_see_other_redirect(self):
for code in (302, 303):
response = self.fetch("/see_other_post", method="POST", body="%d" % code)
self.assertEqual(200, response.code)
self.assertTrue(response.request.url.endswith("/see_other_post"))
self.assertTrue(response.effective_url.endswith("/see_other_get"))
            # response.request is the original request, which is still a POST
self.assertEqual("POST", response.request.method)
@skipOnTravis
@gen_test
def test_connect_timeout(self):
timeout = 0.1
class TimeoutResolver(Resolver):
def resolve(self, *args, **kwargs):
return Future() # never completes
with closing(self.create_client(resolver=TimeoutResolver())) as client:
with self.assertRaises(HTTPTimeoutError):
yield client.fetch(self.get_url('/hello'),
connect_timeout=timeout,
request_timeout=3600,
raise_error=True)
@skipOnTravis
def test_request_timeout(self):
timeout = 0.1
if os.name == 'nt':
timeout = 0.5
with self.assertRaises(HTTPTimeoutError):
self.fetch('/trigger?wake=false', request_timeout=timeout, raise_error=True)
# trigger the hanging request to let it clean up after itself
self.triggers.popleft()()
@skipIfNoIPv6
def test_ipv6(self):
[sock] = bind_sockets(None, '::1', family=socket.AF_INET6)
port = sock.getsockname()[1]
self.http_server.add_socket(sock)
url = '%s://[::1]:%d/hello' % (self.get_protocol(), port)
# ipv6 is currently enabled by default but can be disabled
with self.assertRaises(Exception):
self.fetch(url, allow_ipv6=False, raise_error=True)
response = self.fetch(url)
self.assertEqual(response.body, b"Hello world!")
def test_multiple_content_length_accepted(self):
response = self.fetch("/content_length?value=2,2")
self.assertEqual(response.body, b"ok")
response = self.fetch("/content_length?value=2,%202,2")
self.assertEqual(response.body, b"ok")
with ExpectLog(gen_log, ".*Multiple unequal Content-Lengths"):
with self.assertRaises(HTTPStreamClosedError):
self.fetch("/content_length?value=2,4", raise_error=True)
with self.assertRaises(HTTPStreamClosedError):
self.fetch("/content_length?value=2,%202,3", raise_error=True)
def test_head_request(self):
response = self.fetch("/head", method="HEAD")
self.assertEqual(response.code, 200)
self.assertEqual(response.headers["content-length"], "7")
self.assertFalse(response.body)
def test_options_request(self):
response = self.fetch("/options", method="OPTIONS")
self.assertEqual(response.code, 200)
self.assertEqual(response.headers["content-length"], "2")
self.assertEqual(response.headers["access-control-allow-origin"], "*")
self.assertEqual(response.body, b"ok")
def test_no_content(self):
response = self.fetch("/no_content")
self.assertEqual(response.code, 204)
# 204 status shouldn't have a content-length
#
# Tests with a content-length header are included below
# in HTTP204NoContentTestCase.
self.assertNotIn("Content-Length", response.headers)
def test_host_header(self):
host_re = re.compile(b"^127.0.0.1:[0-9]+$")
response = self.fetch("/host_echo")
self.assertTrue(host_re.match(response.body))
url = self.get_url("/host_echo").replace("http://", "http://me:secret@")
response = self.fetch(url)
self.assertTrue(host_re.match(response.body), response.body)
def test_connection_refused(self):
cleanup_func, port = refusing_port()
self.addCleanup(cleanup_func)
with ExpectLog(gen_log, ".*", required=False):
with self.assertRaises(socket.error) as cm:
self.fetch("http://127.0.0.1:%d/" % port, raise_error=True)
if sys.platform != 'cygwin':
# cygwin returns EPERM instead of ECONNREFUSED here
contains_errno = str(errno.ECONNREFUSED) in str(cm.exception)
if not contains_errno and hasattr(errno, "WSAECONNREFUSED"):
contains_errno = str(errno.WSAECONNREFUSED) in str(cm.exception)
self.assertTrue(contains_errno, cm.exception)
# This is usually "Connection refused".
# On windows, strerror is broken and returns "Unknown error".
expected_message = os.strerror(errno.ECONNREFUSED)
self.assertTrue(expected_message in str(cm.exception),
cm.exception)
def test_queue_timeout(self):
with closing(self.create_client(max_clients=1)) as client:
# Wait for the trigger request to block, not complete.
fut1 = client.fetch(self.get_url('/trigger'), request_timeout=10)
self.wait()
with self.assertRaises(HTTPTimeoutError) as cm:
self.io_loop.run_sync(lambda: client.fetch(
self.get_url('/hello'), connect_timeout=0.1, raise_error=True))
self.assertEqual(str(cm.exception), "Timeout in request queue")
self.triggers.popleft()()
self.io_loop.run_sync(lambda: fut1)
def test_no_content_length(self):
response = self.fetch("/no_content_length")
if response.body == b"HTTP/1 required":
self.skipTest("requires HTTP/1.x")
else:
            self.assertEqual(b"hello", response.body)
def sync_body_producer(self, write):
write(b'1234')
write(b'5678')
@gen.coroutine
def async_body_producer(self, write):
yield write(b'1234')
yield gen.moment
yield write(b'5678')
def test_sync_body_producer_chunked(self):
response = self.fetch("/echo_post", method="POST",
body_producer=self.sync_body_producer)
response.rethrow()
self.assertEqual(response.body, b"12345678")
def test_sync_body_producer_content_length(self):
response = self.fetch("/echo_post", method="POST",
body_producer=self.sync_body_producer,
headers={'Content-Length': '8'})
response.rethrow()
self.assertEqual(response.body, b"12345678")
def test_async_body_producer_chunked(self):
response = self.fetch("/echo_post", method="POST",
body_producer=self.async_body_producer)
response.rethrow()
self.assertEqual(response.body, b"12345678")
def test_async_body_producer_content_length(self):
response = self.fetch("/echo_post", method="POST",
body_producer=self.async_body_producer,
headers={'Content-Length': '8'})
response.rethrow()
self.assertEqual(response.body, b"12345678")
@skipBefore35
def test_native_body_producer_chunked(self):
namespace = exec_test(globals(), locals(), """
async def body_producer(write):
await write(b'1234')
import asyncio
await asyncio.sleep(0)
await write(b'5678')
""")
response = self.fetch("/echo_post", method="POST",
body_producer=namespace["body_producer"])
response.rethrow()
self.assertEqual(response.body, b"12345678")
@skipBefore35
def test_native_body_producer_content_length(self):
namespace = exec_test(globals(), locals(), """
async def body_producer(write):
await write(b'1234')
import asyncio
await asyncio.sleep(0)
await write(b'5678')
""")
response = self.fetch("/echo_post", method="POST",
body_producer=namespace["body_producer"],
headers={'Content-Length': '8'})
response.rethrow()
self.assertEqual(response.body, b"12345678")
def test_100_continue(self):
response = self.fetch("/echo_post", method="POST",
body=b"1234",
expect_100_continue=True)
self.assertEqual(response.body, b"1234")
def test_100_continue_early_response(self):
def body_producer(write):
raise Exception("should not be called")
response = self.fetch("/respond_in_prepare", method="POST",
body_producer=body_producer,
expect_100_continue=True)
self.assertEqual(response.code, 403)
def test_streaming_follow_redirects(self):
# When following redirects, header and streaming callbacks
# should only be called for the final result.
# TODO(bdarnell): this test belongs in httpclient_test instead of
# simple_httpclient_test, but it fails with the version of libcurl
# available on travis-ci. Move it when that has been upgraded
# or we have a better framework to skip tests based on curl version.
headers = []
chunks = []
self.fetch("/redirect?url=/hello",
header_callback=headers.append,
streaming_callback=chunks.append)
chunks = list(map(to_unicode, chunks))
self.assertEqual(chunks, ['Hello world!'])
# Make sure we only got one set of headers.
num_start_lines = len([h for h in headers if h.startswith("HTTP/")])
self.assertEqual(num_start_lines, 1)
class SimpleHTTPClientTestCase(SimpleHTTPClientTestMixin, AsyncHTTPTestCase):
def setUp(self):
super(SimpleHTTPClientTestCase, self).setUp()
self.http_client = self.create_client()
def create_client(self, **kwargs):
return SimpleAsyncHTTPClient(force_instance=True, **kwargs)
class SimpleHTTPSClientTestCase(SimpleHTTPClientTestMixin, AsyncHTTPSTestCase):
def setUp(self):
super(SimpleHTTPSClientTestCase, self).setUp()
self.http_client = self.create_client()
def create_client(self, **kwargs):
return SimpleAsyncHTTPClient(force_instance=True,
defaults=dict(validate_cert=False),
**kwargs)
def test_ssl_options(self):
resp = self.fetch("/hello", ssl_options={})
self.assertEqual(resp.body, b"Hello world!")
def test_ssl_context(self):
resp = self.fetch("/hello",
ssl_options=ssl.SSLContext(ssl.PROTOCOL_SSLv23))
self.assertEqual(resp.body, b"Hello world!")
def test_ssl_options_handshake_fail(self):
with ExpectLog(gen_log, "SSL Error|Uncaught exception",
required=False):
with self.assertRaises(ssl.SSLError):
self.fetch(
"/hello", ssl_options=dict(cert_reqs=ssl.CERT_REQUIRED),
raise_error=True)
def test_ssl_context_handshake_fail(self):
with ExpectLog(gen_log, "SSL Error|Uncaught exception"):
ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
ctx.verify_mode = ssl.CERT_REQUIRED
with self.assertRaises(ssl.SSLError):
self.fetch("/hello", ssl_options=ctx, raise_error=True)
def test_error_logging(self):
# No stack traces are logged for SSL errors (in this case,
# failure to validate the testing self-signed cert).
# The SSLError is exposed through ssl.SSLError.
with ExpectLog(gen_log, '.*') as expect_log:
with self.assertRaises(ssl.SSLError):
self.fetch("/", validate_cert=True, raise_error=True)
self.assertFalse(expect_log.logged_stack)
class CreateAsyncHTTPClientTestCase(AsyncTestCase):
def setUp(self):
super(CreateAsyncHTTPClientTestCase, self).setUp()
self.saved = AsyncHTTPClient._save_configuration()
def tearDown(self):
AsyncHTTPClient._restore_configuration(self.saved)
super(CreateAsyncHTTPClientTestCase, self).tearDown()
def test_max_clients(self):
AsyncHTTPClient.configure(SimpleAsyncHTTPClient)
with closing(AsyncHTTPClient(force_instance=True)) as client:
self.assertEqual(client.max_clients, 10)
with closing(AsyncHTTPClient(
max_clients=11, force_instance=True)) as client:
self.assertEqual(client.max_clients, 11)
# Now configure max_clients statically and try overriding it
# with each way max_clients can be passed
AsyncHTTPClient.configure(SimpleAsyncHTTPClient, max_clients=12)
with closing(AsyncHTTPClient(force_instance=True)) as client:
self.assertEqual(client.max_clients, 12)
with closing(AsyncHTTPClient(
max_clients=13, force_instance=True)) as client:
self.assertEqual(client.max_clients, 13)
with closing(AsyncHTTPClient(
max_clients=14, force_instance=True)) as client:
self.assertEqual(client.max_clients, 14)
class HTTP100ContinueTestCase(AsyncHTTPTestCase):
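    # Uses a bare HTTPServer callback instead of an Application so it can
    # hand-write a "100 CONTINUE" line followed by the real 200 response
    # directly on the stream.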
def respond_100(self, request):
self.http1 = request.version.startswith('HTTP/1.')
if not self.http1:
request.connection.write_headers(ResponseStartLine('', 200, 'OK'),
HTTPHeaders())
request.connection.finish()
return
self.request = request
fut = self.request.connection.stream.write(
b"HTTP/1.1 100 CONTINUE\r\n\r\n")
fut.add_done_callback(self.respond_200)
def respond_200(self, fut):
fut.result()
fut = self.request.connection.stream.write(
b"HTTP/1.1 200 OK\r\nContent-Length: 1\r\n\r\nA")
fut.add_done_callback(lambda f: self.request.connection.stream.close())
def get_app(self):
# Not a full Application, but works as an HTTPServer callback
return self.respond_100
def test_100_continue(self):
res = self.fetch('/')
if not self.http1:
self.skipTest("requires HTTP/1.x")
self.assertEqual(res.body, b'A')
class HTTP204NoContentTestCase(AsyncHTTPTestCase):
def respond_204(self, request):
self.http1 = request.version.startswith('HTTP/1.')
if not self.http1:
# Close the request cleanly in HTTP/2; it will be skipped anyway.
request.connection.write_headers(ResponseStartLine('', 200, 'OK'),
HTTPHeaders())
request.connection.finish()
return
        # A 204 response never has a body, even if it doesn't have a
        # content-length (which would otherwise mean read-until-close).
        # Here we simulate a server that sends an explicit Content-Length
        # and does not close the connection.
#
# Tests of a 204 response with no Content-Length header are included
# in SimpleHTTPClientTestMixin.
stream = request.connection.detach()
stream.write(b"HTTP/1.1 204 No content\r\n")
if request.arguments.get("error", [False])[-1]:
stream.write(b"Content-Length: 5\r\n")
else:
stream.write(b"Content-Length: 0\r\n")
stream.write(b"\r\n")
stream.close()
def get_app(self):
return self.respond_204
def test_204_no_content(self):
resp = self.fetch('/')
if not self.http1:
self.skipTest("requires HTTP/1.x")
self.assertEqual(resp.code, 204)
self.assertEqual(resp.body, b'')
def test_204_invalid_content_length(self):
# 204 status with non-zero content length is malformed
with ExpectLog(gen_log, ".*Response with code 204 should not have body"):
with self.assertRaises(HTTPStreamClosedError):
self.fetch("/?error=1", raise_error=True)
if not self.http1:
self.skipTest("requires HTTP/1.x")
if self.http_client.configured_class != SimpleAsyncHTTPClient:
self.skipTest("curl client accepts invalid headers")
class HostnameMappingTestCase(AsyncHTTPTestCase):
def setUp(self):
super(HostnameMappingTestCase, self).setUp()
self.http_client = SimpleAsyncHTTPClient(
hostname_mapping={
'www.example.com': '127.0.0.1',
('foo.example.com', 8000): ('127.0.0.1', self.get_http_port()),
})
def get_app(self):
return Application([url("/hello", HelloWorldHandler), ])
def test_hostname_mapping(self):
response = self.fetch(
'http://www.example.com:%d/hello' % self.get_http_port())
response.rethrow()
self.assertEqual(response.body, b'Hello world!')
def test_port_mapping(self):
response = self.fetch('http://foo.example.com:8000/hello')
response.rethrow()
self.assertEqual(response.body, b'Hello world!')
class ResolveTimeoutTestCase(AsyncHTTPTestCase):
def setUp(self):
# Dummy Resolver subclass that never finishes.
class BadResolver(Resolver):
@gen.coroutine
def resolve(self, *args, **kwargs):
yield Event().wait()
super(ResolveTimeoutTestCase, self).setUp()
self.http_client = SimpleAsyncHTTPClient(
resolver=BadResolver())
def get_app(self):
return Application([url("/hello", HelloWorldHandler), ])
def test_resolve_timeout(self):
with self.assertRaises(HTTPTimeoutError):
self.fetch('/hello', connect_timeout=0.1, raise_error=True)
class MaxHeaderSizeTest(AsyncHTTPTestCase):
def get_app(self):
class SmallHeaders(RequestHandler):
def get(self):
self.set_header("X-Filler", "a" * 100)
self.write("ok")
class LargeHeaders(RequestHandler):
def get(self):
self.set_header("X-Filler", "a" * 1000)
self.write("ok")
return Application([('/small', SmallHeaders),
('/large', LargeHeaders)])
def get_http_client(self):
return SimpleAsyncHTTPClient(max_header_size=1024)
def test_small_headers(self):
response = self.fetch('/small')
response.rethrow()
self.assertEqual(response.body, b'ok')
def test_large_headers(self):
with ExpectLog(gen_log, "Unsatisfiable read"):
with self.assertRaises(UnsatisfiableReadError):
self.fetch('/large', raise_error=True)
class MaxBodySizeTest(AsyncHTTPTestCase):
def get_app(self):
class SmallBody(RequestHandler):
def get(self):
self.write("a" * 1024 * 64)
class LargeBody(RequestHandler):
def get(self):
self.write("a" * 1024 * 100)
return Application([('/small', SmallBody),
('/large', LargeBody)])
def get_http_client(self):
return SimpleAsyncHTTPClient(max_body_size=1024 * 64)
def test_small_body(self):
response = self.fetch('/small')
response.rethrow()
self.assertEqual(response.body, b'a' * 1024 * 64)
def test_large_body(self):
with ExpectLog(gen_log, "Malformed HTTP message from None: Content-Length too long"):
with self.assertRaises(HTTPStreamClosedError):
self.fetch('/large', raise_error=True)
class MaxBufferSizeTest(AsyncHTTPTestCase):
def get_app(self):
class LargeBody(RequestHandler):
def get(self):
self.write("a" * 1024 * 100)
return Application([('/large', LargeBody)])
def get_http_client(self):
# 100KB body with 64KB buffer
return SimpleAsyncHTTPClient(max_body_size=1024 * 100, max_buffer_size=1024 * 64)
def test_large_body(self):
response = self.fetch('/large')
response.rethrow()
self.assertEqual(response.body, b'a' * 1024 * 100)
class ChunkedWithContentLengthTest(AsyncHTTPTestCase):
def get_app(self):
class ChunkedWithContentLength(RequestHandler):
def get(self):
# Add an invalid Transfer-Encoding to the response
self.set_header('Transfer-Encoding', 'chunked')
self.write("Hello world")
return Application([('/chunkwithcl', ChunkedWithContentLength)])
def get_http_client(self):
return SimpleAsyncHTTPClient()
def test_chunked_with_content_length(self):
# Make sure the invalid headers are detected
with ExpectLog(gen_log, ("Malformed HTTP message from None: Response "
"with both Transfer-Encoding and Content-Length")):
with self.assertRaises(HTTPStreamClosedError):
self.fetch('/chunkwithcl', raise_error=True)

lib/tornado/test/stack_context_test.py
View File

@@ -0,0 +1,297 @@
from __future__ import absolute_import, division, print_function
from tornado import gen
from tornado.ioloop import IOLoop
from tornado.log import app_log
from tornado.stack_context import (StackContext, wrap, NullContext, StackContextInconsistentError,
ExceptionStackContext, run_with_stack_context, _state)
from tornado.testing import AsyncHTTPTestCase, AsyncTestCase, ExpectLog, gen_test
from tornado.test.util import unittest, ignore_deprecation
from tornado.web import asynchronous, Application, RequestHandler
import contextlib
import functools
import logging
import warnings
class TestRequestHandler(RequestHandler):
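    # Exercises stack contexts end to end: get() schedules part2, which
    # schedules part3, which raises; the exception must travel back
    # through the restored contexts to write_error().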
def __init__(self, app, request):
super(TestRequestHandler, self).__init__(app, request)
with ignore_deprecation():
@asynchronous
def get(self):
logging.debug('in get()')
# call self.part2 without a self.async_callback wrapper. Its
# exception should still get thrown
IOLoop.current().add_callback(self.part2)
def part2(self):
logging.debug('in part2()')
# Go through a third layer to make sure that contexts once restored
# are again passed on to future callbacks
IOLoop.current().add_callback(self.part3)
def part3(self):
logging.debug('in part3()')
raise Exception('test exception')
def write_error(self, status_code, **kwargs):
if 'exc_info' in kwargs and str(kwargs['exc_info'][1]) == 'test exception':
self.write('got expected exception')
else:
self.write('unexpected failure')
class HTTPStackContextTest(AsyncHTTPTestCase):
def get_app(self):
return Application([('/', TestRequestHandler)])
def test_stack_context(self):
with ExpectLog(app_log, "Uncaught exception GET /"):
with ignore_deprecation():
self.http_client.fetch(self.get_url('/'), self.handle_response)
self.wait()
self.assertEqual(self.response.code, 500)
self.assertTrue(b'got expected exception' in self.response.body)
def handle_response(self, response):
self.response = response
self.stop()
class StackContextTest(AsyncTestCase):
def setUp(self):
super(StackContextTest, self).setUp()
self.active_contexts = []
self.warning_catcher = warnings.catch_warnings()
self.warning_catcher.__enter__()
warnings.simplefilter('ignore', DeprecationWarning)
def tearDown(self):
self.warning_catcher.__exit__(None, None, None)
super(StackContextTest, self).tearDown()
@contextlib.contextmanager
def context(self, name):
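        # Context manager that records entry and exit in
        # self.active_contexts, letting tests assert exactly which
        # contexts are live at any point.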
self.active_contexts.append(name)
yield
self.assertEqual(self.active_contexts.pop(), name)
# Simulates the effect of an asynchronous library that uses its own
# StackContext internally and then returns control to the application.
def test_exit_library_context(self):
def library_function(callback):
# capture the caller's context before introducing our own
callback = wrap(callback)
with StackContext(functools.partial(self.context, 'library')):
self.io_loop.add_callback(
functools.partial(library_inner_callback, callback))
def library_inner_callback(callback):
self.assertEqual(self.active_contexts[-2:],
['application', 'library'])
callback()
def final_callback():
# implementation detail: the full context stack at this point
# is ['application', 'library', 'application']. The 'library'
# context was not removed, but is no longer innermost so
# the application context takes precedence.
self.assertEqual(self.active_contexts[-1], 'application')
self.stop()
with StackContext(functools.partial(self.context, 'application')):
library_function(final_callback)
self.wait()
def test_deactivate(self):
deactivate_callbacks = []
def f1():
with StackContext(functools.partial(self.context, 'c1')) as c1:
deactivate_callbacks.append(c1)
self.io_loop.add_callback(f2)
def f2():
with StackContext(functools.partial(self.context, 'c2')) as c2:
deactivate_callbacks.append(c2)
self.io_loop.add_callback(f3)
def f3():
with StackContext(functools.partial(self.context, 'c3')) as c3:
deactivate_callbacks.append(c3)
self.io_loop.add_callback(f4)
def f4():
self.assertEqual(self.active_contexts, ['c1', 'c2', 'c3'])
deactivate_callbacks[1]()
# deactivating a context doesn't remove it immediately,
# but it will be missing from the next iteration
self.assertEqual(self.active_contexts, ['c1', 'c2', 'c3'])
self.io_loop.add_callback(f5)
def f5():
self.assertEqual(self.active_contexts, ['c1', 'c3'])
self.stop()
self.io_loop.add_callback(f1)
self.wait()
def test_deactivate_order(self):
# Stack context deactivation has separate logic for deactivation at
# the head and tail of the stack, so make sure it works in any order.
def check_contexts():
# Make sure that the full-context array and the exception-context
# linked lists are consistent with each other.
full_contexts, chain = _state.contexts
exception_contexts = []
while chain is not None:
exception_contexts.append(chain)
chain = chain.old_contexts[1]
self.assertEqual(list(reversed(full_contexts)), exception_contexts)
return list(self.active_contexts)
def make_wrapped_function():
"""Wraps a function in three stack contexts, and returns
the function along with the deactivation functions.
"""
# Remove the test's stack context to make sure we can cover
# the case where the last context is deactivated.
with NullContext():
partial = functools.partial
with StackContext(partial(self.context, 'c0')) as c0:
with StackContext(partial(self.context, 'c1')) as c1:
with StackContext(partial(self.context, 'c2')) as c2:
return (wrap(check_contexts), [c0, c1, c2])
# First make sure the test mechanism works without any deactivations
func, deactivate_callbacks = make_wrapped_function()
self.assertEqual(func(), ['c0', 'c1', 'c2'])
# Deactivate the tail
func, deactivate_callbacks = make_wrapped_function()
deactivate_callbacks[0]()
self.assertEqual(func(), ['c1', 'c2'])
# Deactivate the middle
func, deactivate_callbacks = make_wrapped_function()
deactivate_callbacks[1]()
self.assertEqual(func(), ['c0', 'c2'])
# Deactivate the head
func, deactivate_callbacks = make_wrapped_function()
deactivate_callbacks[2]()
self.assertEqual(func(), ['c0', 'c1'])
def test_isolation_nonempty(self):
# f2 and f3 are a chain of operations started in context c1.
# f2 is incidentally run under context c2, but that context should
# not be passed along to f3.
def f1():
with StackContext(functools.partial(self.context, 'c1')):
wrapped = wrap(f2)
with StackContext(functools.partial(self.context, 'c2')):
wrapped()
def f2():
self.assertIn('c1', self.active_contexts)
self.io_loop.add_callback(f3)
def f3():
self.assertIn('c1', self.active_contexts)
self.assertNotIn('c2', self.active_contexts)
self.stop()
self.io_loop.add_callback(f1)
self.wait()
def test_isolation_empty(self):
# Similar to test_isolation_nonempty, but here the f2/f3 chain
# is started without any context. Behavior should be equivalent
# to the nonempty case (although historically it was not)
def f1():
with NullContext():
wrapped = wrap(f2)
with StackContext(functools.partial(self.context, 'c2')):
wrapped()
def f2():
self.io_loop.add_callback(f3)
def f3():
self.assertNotIn('c2', self.active_contexts)
self.stop()
self.io_loop.add_callback(f1)
self.wait()
def test_yield_in_with(self):
@gen.engine
def f():
self.callback = yield gen.Callback('a')
with StackContext(functools.partial(self.context, 'c1')):
# This yield is a problem: the generator will be suspended
# and the StackContext's __exit__ is not called yet, so
# the context will be left on _state.contexts for anything
# that runs before the yield resolves.
yield gen.Wait('a')
with self.assertRaises(StackContextInconsistentError):
f()
self.wait()
# Cleanup: to avoid GC warnings (which for some reason only seem
# to show up on py33-asyncio), invoke the callback (which will do
# nothing since the gen.Runner is already finished) and delete it.
self.callback()
del self.callback
@gen_test
def test_yield_outside_with(self):
# This pattern avoids the problem in the previous test.
cb = yield gen.Callback('k1')
with StackContext(functools.partial(self.context, 'c1')):
self.io_loop.add_callback(cb)
yield gen.Wait('k1')
def test_yield_in_with_exception_stack_context(self):
# As above, but with ExceptionStackContext instead of StackContext.
@gen.engine
def f():
with ExceptionStackContext(lambda t, v, tb: False):
yield gen.Task(self.io_loop.add_callback)
with self.assertRaises(StackContextInconsistentError):
f()
self.wait()
@gen_test
def test_yield_outside_with_exception_stack_context(self):
cb = yield gen.Callback('k1')
with ExceptionStackContext(lambda t, v, tb: False):
self.io_loop.add_callback(cb)
yield gen.Wait('k1')
@gen_test
def test_run_with_stack_context(self):
@gen.coroutine
def f1():
self.assertEqual(self.active_contexts, ['c1'])
yield run_with_stack_context(
StackContext(functools.partial(self.context, 'c2')),
f2)
self.assertEqual(self.active_contexts, ['c1'])
@gen.coroutine
def f2():
self.assertEqual(self.active_contexts, ['c1', 'c2'])
yield gen.Task(self.io_loop.add_callback)
self.assertEqual(self.active_contexts, ['c1', 'c2'])
self.assertEqual(self.active_contexts, [])
yield run_with_stack_context(
StackContext(functools.partial(self.context, 'c1')),
f1)
self.assertEqual(self.active_contexts, [])
if __name__ == '__main__':
unittest.main()

View File

@@ -0,0 +1 @@
this is the index

View File

@@ -0,0 +1,2 @@
User-agent: *
Disallow: /

View File

@@ -0,0 +1,23 @@
<?xml version="1.0"?>
<data>
<country name="Liechtenstein">
<rank>1</rank>
<year>2008</year>
<gdppc>141100</gdppc>
<neighbor name="Austria" direction="E"/>
<neighbor name="Switzerland" direction="W"/>
</country>
<country name="Singapore">
<rank>4</rank>
<year>2011</year>
<gdppc>59900</gdppc>
<neighbor name="Malaysia" direction="N"/>
</country>
<country name="Panama">
<rank>68</rank>
<year>2011</year>
<gdppc>13600</gdppc>
<neighbor name="Costa Rica" direction="W"/>
<neighbor name="Colombia" direction="E"/>
</country>
</data>

Binary file not shown.

Binary file not shown.

View File

@@ -0,0 +1,2 @@
This file should not be served by StaticFileHandler even though
its name starts with "static".

lib/tornado/test/tcpclient_test.py
View File

@@ -0,0 +1,430 @@
#
# Copyright 2014 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from __future__ import absolute_import, division, print_function
from contextlib import closing
import os
import socket
from tornado.concurrent import Future
from tornado.netutil import bind_sockets, Resolver
from tornado.queues import Queue
from tornado.tcpclient import TCPClient, _Connector
from tornado.tcpserver import TCPServer
from tornado.testing import AsyncTestCase, gen_test
from tornado.test.util import skipIfNoIPv6, unittest, refusing_port, skipIfNonUnix
from tornado.gen import TimeoutError
# Fake address families for testing. Used in place of AF_INET
# and AF_INET6 because some installations do not have AF_INET6.
AF1, AF2 = 1, 2
class TestTCPServer(TCPServer):
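    # Listens on a fresh localhost port and publishes each accepted stream
    # on a queue so tests can await the server side of a connection.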
def __init__(self, family):
super(TestTCPServer, self).__init__()
self.streams = []
self.queue = Queue()
sockets = bind_sockets(None, 'localhost', family)
self.add_sockets(sockets)
self.port = sockets[0].getsockname()[1]
def handle_stream(self, stream, address):
self.streams.append(stream)
self.queue.put(stream)
def stop(self):
super(TestTCPServer, self).stop()
for stream in self.streams:
stream.close()
class TCPClientTest(AsyncTestCase):
def setUp(self):
super(TCPClientTest, self).setUp()
self.server = None
self.client = TCPClient()
def start_server(self, family):
if family == socket.AF_UNSPEC and 'TRAVIS' in os.environ:
self.skipTest("dual-stack servers often have port conflicts on travis")
self.server = TestTCPServer(family)
return self.server.port
def stop_server(self):
if self.server is not None:
self.server.stop()
self.server = None
def tearDown(self):
self.client.close()
self.stop_server()
super(TCPClientTest, self).tearDown()
def skipIfLocalhostV4(self):
# The port used here doesn't matter, but some systems require it
# to be non-zero if we do not also pass AI_PASSIVE.
addrinfo = self.io_loop.run_sync(lambda: Resolver().resolve('localhost', 80))
families = set(addr[0] for addr in addrinfo)
if socket.AF_INET6 not in families:
self.skipTest("localhost does not resolve to ipv6")
@gen_test
def do_test_connect(self, family, host, source_ip=None, source_port=None):
port = self.start_server(family)
stream = yield self.client.connect(host, port,
source_ip=source_ip,
source_port=source_port)
server_stream = yield self.server.queue.get()
with closing(stream):
stream.write(b"hello")
data = yield server_stream.read_bytes(5)
self.assertEqual(data, b"hello")
def test_connect_ipv4_ipv4(self):
self.do_test_connect(socket.AF_INET, '127.0.0.1')
def test_connect_ipv4_dual(self):
self.do_test_connect(socket.AF_INET, 'localhost')
@skipIfNoIPv6
def test_connect_ipv6_ipv6(self):
self.skipIfLocalhostV4()
self.do_test_connect(socket.AF_INET6, '::1')
@skipIfNoIPv6
def test_connect_ipv6_dual(self):
self.skipIfLocalhostV4()
if Resolver.configured_class().__name__.endswith('TwistedResolver'):
self.skipTest('TwistedResolver does not support multiple addresses')
self.do_test_connect(socket.AF_INET6, 'localhost')
def test_connect_unspec_ipv4(self):
self.do_test_connect(socket.AF_UNSPEC, '127.0.0.1')
@skipIfNoIPv6
def test_connect_unspec_ipv6(self):
self.skipIfLocalhostV4()
self.do_test_connect(socket.AF_UNSPEC, '::1')
def test_connect_unspec_dual(self):
self.do_test_connect(socket.AF_UNSPEC, 'localhost')
@gen_test
def test_refused_ipv4(self):
cleanup_func, port = refusing_port()
self.addCleanup(cleanup_func)
with self.assertRaises(IOError):
yield self.client.connect('127.0.0.1', port)
def test_source_ip_fail(self):
'''
Fail when trying to use the source IP Address '8.8.8.8'.
'''
self.assertRaises(socket.error,
self.do_test_connect,
socket.AF_INET,
'127.0.0.1',
source_ip='8.8.8.8')
def test_source_ip_success(self):
'''
Success when trying to use the source IP Address '127.0.0.1'
'''
self.do_test_connect(socket.AF_INET, '127.0.0.1', source_ip='127.0.0.1')
@skipIfNonUnix
def test_source_port_fail(self):
'''
Fail when trying to use source port 1.
'''
self.assertRaises(socket.error,
self.do_test_connect,
socket.AF_INET,
'127.0.0.1',
source_port=1)
@gen_test
def test_connect_timeout(self):
timeout = 0.05
class TimeoutResolver(Resolver):
def resolve(self, *args, **kwargs):
return Future() # never completes
with self.assertRaises(TimeoutError):
yield TCPClient(resolver=TimeoutResolver()).connect(
'1.2.3.4', 12345, timeout=timeout)
class TestConnectorSplit(unittest.TestCase):
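    # _Connector.split partitions an addrinfo list into a primary list
    # (entries matching the first address family seen) and a secondary
    # list (everything else), preserving order within each group.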
def test_one_family(self):
# These addresses aren't in the right format, but split doesn't care.
primary, secondary = _Connector.split(
[(AF1, 'a'),
(AF1, 'b')])
self.assertEqual(primary, [(AF1, 'a'),
(AF1, 'b')])
self.assertEqual(secondary, [])
def test_mixed(self):
primary, secondary = _Connector.split(
[(AF1, 'a'),
(AF2, 'b'),
(AF1, 'c'),
(AF2, 'd')])
self.assertEqual(primary, [(AF1, 'a'), (AF1, 'c')])
self.assertEqual(secondary, [(AF2, 'b'), (AF2, 'd')])
class ConnectorTest(AsyncTestCase):
class FakeStream(object):
def __init__(self):
self.closed = False
def close(self):
self.closed = True
def setUp(self):
super(ConnectorTest, self).setUp()
self.connect_futures = {}
self.streams = {}
self.addrinfo = [(AF1, 'a'), (AF1, 'b'),
(AF2, 'c'), (AF2, 'd')]
def tearDown(self):
# Unless explicitly checked (and popped) in the test, we shouldn't
# be closing any streams
for stream in self.streams.values():
self.assertFalse(stream.closed)
super(ConnectorTest, self).tearDown()
def create_stream(self, af, addr):
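        # Stand-in for TCPClient's stream factory: hands back a FakeStream
        # plus an unresolved Future; tests complete it via resolve_connect
        # to make connection attempts succeed or fail in any order.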
stream = ConnectorTest.FakeStream()
self.streams[addr] = stream
future = Future()
self.connect_futures[(af, addr)] = future
return stream, future
def assert_pending(self, *keys):
self.assertEqual(sorted(self.connect_futures.keys()), sorted(keys))
def resolve_connect(self, af, addr, success):
future = self.connect_futures.pop((af, addr))
if success:
future.set_result(self.streams[addr])
else:
self.streams.pop(addr)
future.set_exception(IOError())
# Run the loop to allow callbacks to be run.
self.io_loop.add_callback(self.stop)
self.wait()
def assert_connector_streams_closed(self, conn):
for stream in conn.streams:
self.assertTrue(stream.closed)
def start_connect(self, addrinfo):
conn = _Connector(addrinfo, self.create_stream)
# Give it a huge timeout; we'll trigger timeouts manually.
future = conn.start(3600, connect_timeout=self.io_loop.time() + 3600)
return conn, future
def test_immediate_success(self):
conn, future = self.start_connect(self.addrinfo)
self.assertEqual(list(self.connect_futures.keys()),
[(AF1, 'a')])
self.resolve_connect(AF1, 'a', True)
self.assertEqual(future.result(), (AF1, 'a', self.streams['a']))
def test_immediate_failure(self):
# Fail with just one address.
conn, future = self.start_connect([(AF1, 'a')])
self.assert_pending((AF1, 'a'))
self.resolve_connect(AF1, 'a', False)
self.assertRaises(IOError, future.result)
def test_one_family_second_try(self):
conn, future = self.start_connect([(AF1, 'a'), (AF1, 'b')])
self.assert_pending((AF1, 'a'))
self.resolve_connect(AF1, 'a', False)
self.assert_pending((AF1, 'b'))
self.resolve_connect(AF1, 'b', True)
self.assertEqual(future.result(), (AF1, 'b', self.streams['b']))
def test_one_family_second_try_failure(self):
conn, future = self.start_connect([(AF1, 'a'), (AF1, 'b')])
self.assert_pending((AF1, 'a'))
self.resolve_connect(AF1, 'a', False)
self.assert_pending((AF1, 'b'))
self.resolve_connect(AF1, 'b', False)
self.assertRaises(IOError, future.result)
def test_one_family_second_try_timeout(self):
conn, future = self.start_connect([(AF1, 'a'), (AF1, 'b')])
self.assert_pending((AF1, 'a'))
# trigger the timeout while the first lookup is pending;
# nothing happens.
conn.on_timeout()
self.assert_pending((AF1, 'a'))
self.resolve_connect(AF1, 'a', False)
self.assert_pending((AF1, 'b'))
self.resolve_connect(AF1, 'b', True)
self.assertEqual(future.result(), (AF1, 'b', self.streams['b']))
def test_two_families_immediate_failure(self):
conn, future = self.start_connect(self.addrinfo)
self.assert_pending((AF1, 'a'))
self.resolve_connect(AF1, 'a', False)
self.assert_pending((AF1, 'b'), (AF2, 'c'))
self.resolve_connect(AF1, 'b', False)
self.resolve_connect(AF2, 'c', True)
self.assertEqual(future.result(), (AF2, 'c', self.streams['c']))
def test_two_families_timeout(self):
conn, future = self.start_connect(self.addrinfo)
self.assert_pending((AF1, 'a'))
conn.on_timeout()
self.assert_pending((AF1, 'a'), (AF2, 'c'))
self.resolve_connect(AF2, 'c', True)
self.assertEqual(future.result(), (AF2, 'c', self.streams['c']))
# resolving 'a' after the connection has completed doesn't start 'b'
self.resolve_connect(AF1, 'a', False)
self.assert_pending()
def test_success_after_timeout(self):
conn, future = self.start_connect(self.addrinfo)
self.assert_pending((AF1, 'a'))
conn.on_timeout()
self.assert_pending((AF1, 'a'), (AF2, 'c'))
self.resolve_connect(AF1, 'a', True)
self.assertEqual(future.result(), (AF1, 'a', self.streams['a']))
# resolving 'c' after completion closes the connection.
self.resolve_connect(AF2, 'c', True)
self.assertTrue(self.streams.pop('c').closed)
def test_all_fail(self):
conn, future = self.start_connect(self.addrinfo)
self.assert_pending((AF1, 'a'))
conn.on_timeout()
self.assert_pending((AF1, 'a'), (AF2, 'c'))
self.resolve_connect(AF2, 'c', False)
self.assert_pending((AF1, 'a'), (AF2, 'd'))
self.resolve_connect(AF2, 'd', False)
# one queue is now empty
self.assert_pending((AF1, 'a'))
self.resolve_connect(AF1, 'a', False)
self.assert_pending((AF1, 'b'))
self.assertFalse(future.done())
self.resolve_connect(AF1, 'b', False)
self.assertRaises(IOError, future.result)
def test_one_family_timeout_after_connect_timeout(self):
conn, future = self.start_connect([(AF1, 'a'), (AF1, 'b')])
self.assert_pending((AF1, 'a'))
conn.on_connect_timeout()
        # the connector closes all streams on connect timeout, so we
        # must explicitly pop the connect future ourselves.
self.connect_futures.pop((AF1, 'a'))
self.assertTrue(self.streams.pop('a').closed)
conn.on_timeout()
        # if the future is set with TimeoutError, we will not try the
        # next possible address.
self.assert_pending()
self.assertEqual(len(conn.streams), 1)
self.assert_connector_streams_closed(conn)
self.assertRaises(TimeoutError, future.result)
def test_one_family_success_before_connect_timeout(self):
conn, future = self.start_connect([(AF1, 'a'), (AF1, 'b')])
self.assert_pending((AF1, 'a'))
self.resolve_connect(AF1, 'a', True)
conn.on_connect_timeout()
self.assert_pending()
self.assertEqual(self.streams['a'].closed, False)
        # the winning stream has been popped from conn.streams, so the
        # connect timeout does not close it
self.assertEqual(len(conn.streams), 0)
# streams in connector should be closed after connect timeout
self.assert_connector_streams_closed(conn)
self.assertEqual(future.result(), (AF1, 'a', self.streams['a']))
def test_one_family_second_try_after_connect_timeout(self):
conn, future = self.start_connect([(AF1, 'a'), (AF1, 'b')])
self.assert_pending((AF1, 'a'))
self.resolve_connect(AF1, 'a', False)
self.assert_pending((AF1, 'b'))
conn.on_connect_timeout()
self.connect_futures.pop((AF1, 'b'))
self.assertTrue(self.streams.pop('b').closed)
self.assert_pending()
self.assertEqual(len(conn.streams), 2)
self.assert_connector_streams_closed(conn)
self.assertRaises(TimeoutError, future.result)
def test_one_family_second_try_failure_before_connect_timeout(self):
conn, future = self.start_connect([(AF1, 'a'), (AF1, 'b')])
self.assert_pending((AF1, 'a'))
self.resolve_connect(AF1, 'a', False)
self.assert_pending((AF1, 'b'))
self.resolve_connect(AF1, 'b', False)
conn.on_connect_timeout()
self.assert_pending()
self.assertEqual(len(conn.streams), 2)
self.assert_connector_streams_closed(conn)
self.assertRaises(IOError, future.result)
def test_two_family_timeout_before_connect_timeout(self):
conn, future = self.start_connect(self.addrinfo)
self.assert_pending((AF1, 'a'))
conn.on_timeout()
self.assert_pending((AF1, 'a'), (AF2, 'c'))
conn.on_connect_timeout()
self.connect_futures.pop((AF1, 'a'))
self.assertTrue(self.streams.pop('a').closed)
self.connect_futures.pop((AF2, 'c'))
self.assertTrue(self.streams.pop('c').closed)
self.assert_pending()
self.assertEqual(len(conn.streams), 2)
self.assert_connector_streams_closed(conn)
self.assertRaises(TimeoutError, future.result)
def test_two_family_success_after_timeout(self):
conn, future = self.start_connect(self.addrinfo)
self.assert_pending((AF1, 'a'))
conn.on_timeout()
self.assert_pending((AF1, 'a'), (AF2, 'c'))
self.resolve_connect(AF1, 'a', True)
        # once one of the streams succeeds, the connector closes all
        # other pending streams
self.connect_futures.pop((AF2, 'c'))
self.assertTrue(self.streams.pop('c').closed)
self.assert_pending()
self.assertEqual(len(conn.streams), 1)
self.assert_connector_streams_closed(conn)
self.assertEqual(future.result(), (AF1, 'a', self.streams['a']))
def test_two_family_timeout_after_connect_timeout(self):
conn, future = self.start_connect(self.addrinfo)
self.assert_pending((AF1, 'a'))
conn.on_connect_timeout()
self.connect_futures.pop((AF1, 'a'))
self.assertTrue(self.streams.pop('a').closed)
self.assert_pending()
conn.on_timeout()
        # if the future is already set with TimeoutError, the connector
        # will not try the secondary address family.
self.assert_pending()
self.assertEqual(len(conn.streams), 1)
self.assert_connector_streams_closed(conn)
self.assertRaises(TimeoutError, future.result)

lib/tornado/test/tcpserver_test.py
View File

@@ -0,0 +1,193 @@
from __future__ import absolute_import, division, print_function
import socket
import subprocess
import sys
import textwrap
from tornado.escape import utf8, to_unicode
from tornado import gen
from tornado.iostream import IOStream
from tornado.log import app_log
from tornado.stack_context import NullContext
from tornado.tcpserver import TCPServer
from tornado.test.util import skipBefore35, skipIfNonUnix, exec_test, unittest
from tornado.testing import AsyncTestCase, ExpectLog, bind_unused_port, gen_test
class TCPServerTest(AsyncTestCase):
@gen_test
def test_handle_stream_coroutine_logging(self):
# handle_stream may be a coroutine and any exception in its
# Future will be logged.
class TestServer(TCPServer):
@gen.coroutine
def handle_stream(self, stream, address):
yield stream.read_bytes(len(b'hello'))
stream.close()
1 / 0
server = client = None
try:
sock, port = bind_unused_port()
with NullContext():
server = TestServer()
server.add_socket(sock)
client = IOStream(socket.socket())
with ExpectLog(app_log, "Exception in callback"):
yield client.connect(('localhost', port))
yield client.write(b'hello')
yield client.read_until_close()
yield gen.moment
finally:
if server is not None:
server.stop()
if client is not None:
client.close()
@skipBefore35
@gen_test
def test_handle_stream_native_coroutine(self):
# handle_stream may be a native coroutine.
namespace = exec_test(globals(), locals(), """
class TestServer(TCPServer):
async def handle_stream(self, stream, address):
stream.write(b'data')
stream.close()
""")
sock, port = bind_unused_port()
server = namespace['TestServer']()
server.add_socket(sock)
client = IOStream(socket.socket())
yield client.connect(('localhost', port))
result = yield client.read_until_close()
self.assertEqual(result, b'data')
server.stop()
client.close()
def test_stop_twice(self):
sock, port = bind_unused_port()
server = TCPServer()
server.add_socket(sock)
server.stop()
server.stop()
@gen_test
def test_stop_in_callback(self):
# Issue #2069: calling server.stop() in a loop callback should not
# raise EBADF when the loop handles other server connection
# requests in the same loop iteration
class TestServer(TCPServer):
@gen.coroutine
def handle_stream(self, stream, address):
server.stop()
yield stream.read_until_close()
sock, port = bind_unused_port()
server = TestServer()
server.add_socket(sock)
server_addr = ('localhost', port)
N = 40
clients = [IOStream(socket.socket()) for i in range(N)]
connected_clients = []
@gen.coroutine
def connect(c):
try:
yield c.connect(server_addr)
except EnvironmentError:
pass
else:
connected_clients.append(c)
yield [connect(c) for c in clients]
self.assertGreater(len(connected_clients), 0,
"all clients failed connecting")
try:
if len(connected_clients) == N:
# Ideally we'd make the test deterministic, but we're testing
# for a race condition in combination with the system's TCP stack...
self.skipTest("at least one client should fail connecting "
"for the test to be meaningful")
finally:
for c in connected_clients:
c.close()
# Here tearDown() would re-raise the EBADF encountered in the IO loop
@skipIfNonUnix
class TestMultiprocess(unittest.TestCase):
# These tests verify that the two multiprocess examples from the
# TCPServer docs work. Both tests start a server with three worker
# processes, each of which prints its task id to stdout (a single
# byte, so we don't have to worry about atomicity of the shared
# stdout stream) and then exits.
def run_subproc(self, code):
proc = subprocess.Popen(sys.executable,
stdin=subprocess.PIPE,
stdout=subprocess.PIPE)
proc.stdin.write(utf8(code))
proc.stdin.close()
proc.wait()
stdout = proc.stdout.read()
proc.stdout.close()
if proc.returncode != 0:
raise RuntimeError("Process returned %d. stdout=%r" % (
proc.returncode, stdout))
return to_unicode(stdout)
def test_single(self):
# As a sanity check, run the single-process version through this test
# harness too.
code = textwrap.dedent("""
from __future__ import print_function
from tornado.ioloop import IOLoop
from tornado.tcpserver import TCPServer
server = TCPServer()
server.listen(0, address='127.0.0.1')
IOLoop.current().run_sync(lambda: None)
print('012', end='')
""")
out = self.run_subproc(code)
self.assertEqual(''.join(sorted(out)), "012")
def test_simple(self):
code = textwrap.dedent("""
from __future__ import print_function
from tornado.ioloop import IOLoop
from tornado.process import task_id
from tornado.tcpserver import TCPServer
server = TCPServer()
server.bind(0, address='127.0.0.1')
server.start(3)
IOLoop.current().run_sync(lambda: None)
print(task_id(), end='')
""")
out = self.run_subproc(code)
self.assertEqual(''.join(sorted(out)), "012")
def test_advanced(self):
code = textwrap.dedent("""
from __future__ import print_function
from tornado.ioloop import IOLoop
from tornado.netutil import bind_sockets
from tornado.process import fork_processes, task_id
from tornado.ioloop import IOLoop
from tornado.tcpserver import TCPServer
sockets = bind_sockets(0, address='127.0.0.1')
fork_processes(3)
server = TCPServer()
server.add_sockets(sockets)
IOLoop.current().run_sync(lambda: None)
print(task_id(), end='')
""")
out = self.run_subproc(code)
self.assertEqual(''.join(sorted(out)), "012")

Some files were not shown because too many files have changed in this diff.