ops
This commit is contained in:
Executable
+28
@@ -0,0 +1,28 @@
|
||||
#
|
||||
# Copyright 2009 Facebook
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""The Tornado web server and tools."""
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
# version is a human-readable version number.
|
||||
|
||||
# version_info is a four-tuple for programmatic comparison. The first
|
||||
# three numbers are the components of the version number. The fourth
|
||||
# is zero for an official release, positive for a development branch,
|
||||
# or negative for a release candidate or beta (after the base version
|
||||
# number has been incremented)
|
||||
version = "5.1.1"
|
||||
version_info = (5, 1, 1, 0)
|
||||
Executable
+84
@@ -0,0 +1,84 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Copyright 2012 Facebook
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""Data used by the tornado.locale module."""
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
LOCALE_NAMES = {
|
||||
"af_ZA": {"name_en": u"Afrikaans", "name": u"Afrikaans"},
|
||||
"am_ET": {"name_en": u"Amharic", "name": u"አማርኛ"},
|
||||
"ar_AR": {"name_en": u"Arabic", "name": u"العربية"},
|
||||
"bg_BG": {"name_en": u"Bulgarian", "name": u"Български"},
|
||||
"bn_IN": {"name_en": u"Bengali", "name": u"বাংলা"},
|
||||
"bs_BA": {"name_en": u"Bosnian", "name": u"Bosanski"},
|
||||
"ca_ES": {"name_en": u"Catalan", "name": u"Català"},
|
||||
"cs_CZ": {"name_en": u"Czech", "name": u"Čeština"},
|
||||
"cy_GB": {"name_en": u"Welsh", "name": u"Cymraeg"},
|
||||
"da_DK": {"name_en": u"Danish", "name": u"Dansk"},
|
||||
"de_DE": {"name_en": u"German", "name": u"Deutsch"},
|
||||
"el_GR": {"name_en": u"Greek", "name": u"Ελληνικά"},
|
||||
"en_GB": {"name_en": u"English (UK)", "name": u"English (UK)"},
|
||||
"en_US": {"name_en": u"English (US)", "name": u"English (US)"},
|
||||
"es_ES": {"name_en": u"Spanish (Spain)", "name": u"Español (España)"},
|
||||
"es_LA": {"name_en": u"Spanish", "name": u"Español"},
|
||||
"et_EE": {"name_en": u"Estonian", "name": u"Eesti"},
|
||||
"eu_ES": {"name_en": u"Basque", "name": u"Euskara"},
|
||||
"fa_IR": {"name_en": u"Persian", "name": u"فارسی"},
|
||||
"fi_FI": {"name_en": u"Finnish", "name": u"Suomi"},
|
||||
"fr_CA": {"name_en": u"French (Canada)", "name": u"Français (Canada)"},
|
||||
"fr_FR": {"name_en": u"French", "name": u"Français"},
|
||||
"ga_IE": {"name_en": u"Irish", "name": u"Gaeilge"},
|
||||
"gl_ES": {"name_en": u"Galician", "name": u"Galego"},
|
||||
"he_IL": {"name_en": u"Hebrew", "name": u"עברית"},
|
||||
"hi_IN": {"name_en": u"Hindi", "name": u"हिन्दी"},
|
||||
"hr_HR": {"name_en": u"Croatian", "name": u"Hrvatski"},
|
||||
"hu_HU": {"name_en": u"Hungarian", "name": u"Magyar"},
|
||||
"id_ID": {"name_en": u"Indonesian", "name": u"Bahasa Indonesia"},
|
||||
"is_IS": {"name_en": u"Icelandic", "name": u"Íslenska"},
|
||||
"it_IT": {"name_en": u"Italian", "name": u"Italiano"},
|
||||
"ja_JP": {"name_en": u"Japanese", "name": u"日本語"},
|
||||
"ko_KR": {"name_en": u"Korean", "name": u"한국어"},
|
||||
"lt_LT": {"name_en": u"Lithuanian", "name": u"Lietuvių"},
|
||||
"lv_LV": {"name_en": u"Latvian", "name": u"Latviešu"},
|
||||
"mk_MK": {"name_en": u"Macedonian", "name": u"Македонски"},
|
||||
"ml_IN": {"name_en": u"Malayalam", "name": u"മലയാളം"},
|
||||
"ms_MY": {"name_en": u"Malay", "name": u"Bahasa Melayu"},
|
||||
"nb_NO": {"name_en": u"Norwegian (bokmal)", "name": u"Norsk (bokmål)"},
|
||||
"nl_NL": {"name_en": u"Dutch", "name": u"Nederlands"},
|
||||
"nn_NO": {"name_en": u"Norwegian (nynorsk)", "name": u"Norsk (nynorsk)"},
|
||||
"pa_IN": {"name_en": u"Punjabi", "name": u"ਪੰਜਾਬੀ"},
|
||||
"pl_PL": {"name_en": u"Polish", "name": u"Polski"},
|
||||
"pt_BR": {"name_en": u"Portuguese (Brazil)", "name": u"Português (Brasil)"},
|
||||
"pt_PT": {"name_en": u"Portuguese (Portugal)", "name": u"Português (Portugal)"},
|
||||
"ro_RO": {"name_en": u"Romanian", "name": u"Română"},
|
||||
"ru_RU": {"name_en": u"Russian", "name": u"Русский"},
|
||||
"sk_SK": {"name_en": u"Slovak", "name": u"Slovenčina"},
|
||||
"sl_SI": {"name_en": u"Slovenian", "name": u"Slovenščina"},
|
||||
"sq_AL": {"name_en": u"Albanian", "name": u"Shqip"},
|
||||
"sr_RS": {"name_en": u"Serbian", "name": u"Српски"},
|
||||
"sv_SE": {"name_en": u"Swedish", "name": u"Svenska"},
|
||||
"sw_KE": {"name_en": u"Swahili", "name": u"Kiswahili"},
|
||||
"ta_IN": {"name_en": u"Tamil", "name": u"தமிழ்"},
|
||||
"te_IN": {"name_en": u"Telugu", "name": u"తెలుగు"},
|
||||
"th_TH": {"name_en": u"Thai", "name": u"ภาษาไทย"},
|
||||
"tl_PH": {"name_en": u"Filipino", "name": u"Filipino"},
|
||||
"tr_TR": {"name_en": u"Turkish", "name": u"Türkçe"},
|
||||
"uk_UA": {"name_en": u"Ukraini ", "name": u"Українська"},
|
||||
"vi_VN": {"name_en": u"Vietnamese", "name": u"Tiếng Việt"},
|
||||
"zh_CN": {"name_en": u"Chinese (Simplified)", "name": u"中文(简体)"},
|
||||
"zh_TW": {"name_en": u"Chinese (Traditional)", "name": u"中文(繁體)"},
|
||||
}
|
||||
Executable
+1236
File diff suppressed because it is too large
Load Diff
Executable
+356
@@ -0,0 +1,356 @@
|
||||
#
|
||||
# Copyright 2009 Facebook
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""Automatically restart the server when a source file is modified.
|
||||
|
||||
Most applications should not access this module directly. Instead,
|
||||
pass the keyword argument ``autoreload=True`` to the
|
||||
`tornado.web.Application` constructor (or ``debug=True``, which
|
||||
enables this setting and several others). This will enable autoreload
|
||||
mode as well as checking for changes to templates and static
|
||||
resources. Note that restarting is a destructive operation and any
|
||||
requests in progress will be aborted when the process restarts. (If
|
||||
you want to disable autoreload while using other debug-mode features,
|
||||
pass both ``debug=True`` and ``autoreload=False``).
|
||||
|
||||
This module can also be used as a command-line wrapper around scripts
|
||||
such as unit test runners. See the `main` method for details.
|
||||
|
||||
The command-line wrapper and Application debug modes can be used together.
|
||||
This combination is encouraged as the wrapper catches syntax errors and
|
||||
other import-time failures, while debug mode catches changes once
|
||||
the server has started.
|
||||
|
||||
This module depends on `.IOLoop`, so it will not work in WSGI applications
|
||||
and Google App Engine. It also will not work correctly when `.HTTPServer`'s
|
||||
multi-process mode is used.
|
||||
|
||||
Reloading loses any Python interpreter command-line arguments (e.g. ``-u``)
|
||||
because it re-executes Python using ``sys.executable`` and ``sys.argv``.
|
||||
Additionally, modifying these variables will cause reloading to behave
|
||||
incorrectly.
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
# sys.path handling
|
||||
# -----------------
|
||||
#
|
||||
# If a module is run with "python -m", the current directory (i.e. "")
|
||||
# is automatically prepended to sys.path, but not if it is run as
|
||||
# "path/to/file.py". The processing for "-m" rewrites the former to
|
||||
# the latter, so subsequent executions won't have the same path as the
|
||||
# original.
|
||||
#
|
||||
# Conversely, when run as path/to/file.py, the directory containing
|
||||
# file.py gets added to the path, which can cause confusion as imports
|
||||
# may become relative in spite of the future import.
|
||||
#
|
||||
# We address the former problem by reconstructing the original command
|
||||
# line (Python >= 3.4) or by setting the $PYTHONPATH environment
|
||||
# variable (Python < 3.4) before re-execution so the new process will
|
||||
# see the correct path. We attempt to address the latter problem when
|
||||
# tornado.autoreload is run as __main__.
|
||||
|
||||
if __name__ == "__main__":
|
||||
# This sys.path manipulation must come before our imports (as much
|
||||
# as possible - if we introduced a tornado.sys or tornado.os
|
||||
# module we'd be in trouble), or else our imports would become
|
||||
# relative again despite the future import.
|
||||
#
|
||||
# There is a separate __main__ block at the end of the file to call main().
|
||||
if sys.path[0] == os.path.dirname(__file__):
|
||||
del sys.path[0]
|
||||
|
||||
import functools
|
||||
import logging
|
||||
import os
|
||||
import pkgutil # type: ignore
|
||||
import sys
|
||||
import traceback
|
||||
import types
|
||||
import subprocess
|
||||
import weakref
|
||||
|
||||
from tornado import ioloop
|
||||
from tornado.log import gen_log
|
||||
from tornado import process
|
||||
from tornado.util import exec_in
|
||||
|
||||
try:
|
||||
import signal
|
||||
except ImportError:
|
||||
signal = None
|
||||
|
||||
# os.execv is broken on Windows and can't properly parse command line
|
||||
# arguments and executable name if they contain whitespaces. subprocess
|
||||
# fixes that behavior.
|
||||
_has_execv = sys.platform != 'win32'
|
||||
|
||||
_watched_files = set()
|
||||
_reload_hooks = []
|
||||
_reload_attempted = False
|
||||
_io_loops = weakref.WeakKeyDictionary() # type: ignore
|
||||
_autoreload_is_main = False
|
||||
_original_argv = None
|
||||
_original_spec = None
|
||||
|
||||
|
||||
def start(check_time=500):
|
||||
"""Begins watching source files for changes.
|
||||
|
||||
.. versionchanged:: 5.0
|
||||
The ``io_loop`` argument (deprecated since version 4.1) has been removed.
|
||||
"""
|
||||
io_loop = ioloop.IOLoop.current()
|
||||
if io_loop in _io_loops:
|
||||
return
|
||||
_io_loops[io_loop] = True
|
||||
if len(_io_loops) > 1:
|
||||
gen_log.warning("tornado.autoreload started more than once in the same process")
|
||||
modify_times = {}
|
||||
callback = functools.partial(_reload_on_update, modify_times)
|
||||
scheduler = ioloop.PeriodicCallback(callback, check_time)
|
||||
scheduler.start()
|
||||
|
||||
|
||||
def wait():
|
||||
"""Wait for a watched file to change, then restart the process.
|
||||
|
||||
Intended to be used at the end of scripts like unit test runners,
|
||||
to run the tests again after any source file changes (but see also
|
||||
the command-line interface in `main`)
|
||||
"""
|
||||
io_loop = ioloop.IOLoop()
|
||||
io_loop.add_callback(start)
|
||||
io_loop.start()
|
||||
|
||||
|
||||
def watch(filename):
|
||||
"""Add a file to the watch list.
|
||||
|
||||
All imported modules are watched by default.
|
||||
"""
|
||||
_watched_files.add(filename)
|
||||
|
||||
|
||||
def add_reload_hook(fn):
|
||||
"""Add a function to be called before reloading the process.
|
||||
|
||||
Note that for open file and socket handles it is generally
|
||||
preferable to set the ``FD_CLOEXEC`` flag (using `fcntl` or
|
||||
``tornado.platform.auto.set_close_exec``) instead
|
||||
of using a reload hook to close them.
|
||||
"""
|
||||
_reload_hooks.append(fn)
|
||||
|
||||
|
||||
def _reload_on_update(modify_times):
|
||||
if _reload_attempted:
|
||||
# We already tried to reload and it didn't work, so don't try again.
|
||||
return
|
||||
if process.task_id() is not None:
|
||||
# We're in a child process created by fork_processes. If child
|
||||
# processes restarted themselves, they'd all restart and then
|
||||
# all call fork_processes again.
|
||||
return
|
||||
for module in list(sys.modules.values()):
|
||||
# Some modules play games with sys.modules (e.g. email/__init__.py
|
||||
# in the standard library), and occasionally this can cause strange
|
||||
# failures in getattr. Just ignore anything that's not an ordinary
|
||||
# module.
|
||||
if not isinstance(module, types.ModuleType):
|
||||
continue
|
||||
path = getattr(module, "__file__", None)
|
||||
if not path:
|
||||
continue
|
||||
if path.endswith(".pyc") or path.endswith(".pyo"):
|
||||
path = path[:-1]
|
||||
_check_file(modify_times, path)
|
||||
for path in _watched_files:
|
||||
_check_file(modify_times, path)
|
||||
|
||||
|
||||
def _check_file(modify_times, path):
|
||||
try:
|
||||
modified = os.stat(path).st_mtime
|
||||
except Exception:
|
||||
return
|
||||
if path not in modify_times:
|
||||
modify_times[path] = modified
|
||||
return
|
||||
if modify_times[path] != modified:
|
||||
gen_log.info("%s modified; restarting server", path)
|
||||
_reload()
|
||||
|
||||
|
||||
def _reload():
|
||||
global _reload_attempted
|
||||
_reload_attempted = True
|
||||
for fn in _reload_hooks:
|
||||
fn()
|
||||
if hasattr(signal, "setitimer"):
|
||||
# Clear the alarm signal set by
|
||||
# ioloop.set_blocking_log_threshold so it doesn't fire
|
||||
# after the exec.
|
||||
signal.setitimer(signal.ITIMER_REAL, 0, 0)
|
||||
# sys.path fixes: see comments at top of file. If __main__.__spec__
|
||||
# exists, we were invoked with -m and the effective path is about to
|
||||
# change on re-exec. Reconstruct the original command line to
|
||||
# ensure that the new process sees the same path we did. If
|
||||
# __spec__ is not available (Python < 3.4), check instead if
|
||||
# sys.path[0] is an empty string and add the current directory to
|
||||
# $PYTHONPATH.
|
||||
if _autoreload_is_main:
|
||||
spec = _original_spec
|
||||
argv = _original_argv
|
||||
else:
|
||||
spec = getattr(sys.modules['__main__'], '__spec__', None)
|
||||
argv = sys.argv
|
||||
if spec:
|
||||
argv = ['-m', spec.name] + argv[1:]
|
||||
else:
|
||||
path_prefix = '.' + os.pathsep
|
||||
if (sys.path[0] == '' and
|
||||
not os.environ.get("PYTHONPATH", "").startswith(path_prefix)):
|
||||
os.environ["PYTHONPATH"] = (path_prefix +
|
||||
os.environ.get("PYTHONPATH", ""))
|
||||
if not _has_execv:
|
||||
subprocess.Popen([sys.executable] + argv)
|
||||
os._exit(0)
|
||||
else:
|
||||
try:
|
||||
os.execv(sys.executable, [sys.executable] + argv)
|
||||
except OSError:
|
||||
# Mac OS X versions prior to 10.6 do not support execv in
|
||||
# a process that contains multiple threads. Instead of
|
||||
# re-executing in the current process, start a new one
|
||||
# and cause the current process to exit. This isn't
|
||||
# ideal since the new process is detached from the parent
|
||||
# terminal and thus cannot easily be killed with ctrl-C,
|
||||
# but it's better than not being able to autoreload at
|
||||
# all.
|
||||
# Unfortunately the errno returned in this case does not
|
||||
# appear to be consistent, so we can't easily check for
|
||||
# this error specifically.
|
||||
os.spawnv(os.P_NOWAIT, sys.executable, [sys.executable] + argv)
|
||||
# At this point the IOLoop has been closed and finally
|
||||
# blocks will experience errors if we allow the stack to
|
||||
# unwind, so just exit uncleanly.
|
||||
os._exit(0)
|
||||
|
||||
|
||||
_USAGE = """\
|
||||
Usage:
|
||||
python -m tornado.autoreload -m module.to.run [args...]
|
||||
python -m tornado.autoreload path/to/script.py [args...]
|
||||
"""
|
||||
|
||||
|
||||
def main():
|
||||
"""Command-line wrapper to re-run a script whenever its source changes.
|
||||
|
||||
Scripts may be specified by filename or module name::
|
||||
|
||||
python -m tornado.autoreload -m tornado.test.runtests
|
||||
python -m tornado.autoreload tornado/test/runtests.py
|
||||
|
||||
Running a script with this wrapper is similar to calling
|
||||
`tornado.autoreload.wait` at the end of the script, but this wrapper
|
||||
can catch import-time problems like syntax errors that would otherwise
|
||||
prevent the script from reaching its call to `wait`.
|
||||
"""
|
||||
# Remember that we were launched with autoreload as main.
|
||||
# The main module can be tricky; set the variables both in our globals
|
||||
# (which may be __main__) and the real importable version.
|
||||
import tornado.autoreload
|
||||
global _autoreload_is_main
|
||||
global _original_argv, _original_spec
|
||||
tornado.autoreload._autoreload_is_main = _autoreload_is_main = True
|
||||
original_argv = sys.argv
|
||||
tornado.autoreload._original_argv = _original_argv = original_argv
|
||||
original_spec = getattr(sys.modules['__main__'], '__spec__', None)
|
||||
tornado.autoreload._original_spec = _original_spec = original_spec
|
||||
sys.argv = sys.argv[:]
|
||||
if len(sys.argv) >= 3 and sys.argv[1] == "-m":
|
||||
mode = "module"
|
||||
module = sys.argv[2]
|
||||
del sys.argv[1:3]
|
||||
elif len(sys.argv) >= 2:
|
||||
mode = "script"
|
||||
script = sys.argv[1]
|
||||
sys.argv = sys.argv[1:]
|
||||
else:
|
||||
print(_USAGE, file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
try:
|
||||
if mode == "module":
|
||||
import runpy
|
||||
runpy.run_module(module, run_name="__main__", alter_sys=True)
|
||||
elif mode == "script":
|
||||
with open(script) as f:
|
||||
# Execute the script in our namespace instead of creating
|
||||
# a new one so that something that tries to import __main__
|
||||
# (e.g. the unittest module) will see names defined in the
|
||||
# script instead of just those defined in this module.
|
||||
global __file__
|
||||
__file__ = script
|
||||
# If __package__ is defined, imports may be incorrectly
|
||||
# interpreted as relative to this module.
|
||||
global __package__
|
||||
del __package__
|
||||
exec_in(f.read(), globals(), globals())
|
||||
except SystemExit as e:
|
||||
logging.basicConfig()
|
||||
gen_log.info("Script exited with status %s", e.code)
|
||||
except Exception as e:
|
||||
logging.basicConfig()
|
||||
gen_log.warning("Script exited with uncaught exception", exc_info=True)
|
||||
# If an exception occurred at import time, the file with the error
|
||||
# never made it into sys.modules and so we won't know to watch it.
|
||||
# Just to make sure we've covered everything, walk the stack trace
|
||||
# from the exception and watch every file.
|
||||
for (filename, lineno, name, line) in traceback.extract_tb(sys.exc_info()[2]):
|
||||
watch(filename)
|
||||
if isinstance(e, SyntaxError):
|
||||
# SyntaxErrors are special: their innermost stack frame is fake
|
||||
# so extract_tb won't see it and we have to get the filename
|
||||
# from the exception object.
|
||||
watch(e.filename)
|
||||
else:
|
||||
logging.basicConfig()
|
||||
gen_log.info("Script exited normally")
|
||||
# restore sys.argv so subsequent executions will include autoreload
|
||||
sys.argv = original_argv
|
||||
|
||||
if mode == 'module':
|
||||
# runpy did a fake import of the module as __main__, but now it's
|
||||
# no longer in sys.modules. Figure out where it is and watch it.
|
||||
loader = pkgutil.get_loader(module)
|
||||
if loader is not None:
|
||||
watch(loader.get_filename())
|
||||
|
||||
wait()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# See also the other __main__ block at the top of the file, which modifies
|
||||
# sys.path before our imports
|
||||
main()
|
||||
Executable
+660
@@ -0,0 +1,660 @@
|
||||
#
|
||||
# Copyright 2012 Facebook
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""Utilities for working with ``Future`` objects.
|
||||
|
||||
``Futures`` are a pattern for concurrent programming introduced in
|
||||
Python 3.2 in the `concurrent.futures` package, and also adopted (in a
|
||||
slightly different form) in Python 3.4's `asyncio` package. This
|
||||
package defines a ``Future`` class that is an alias for `asyncio.Future`
|
||||
when available, and a compatible implementation for older versions of
|
||||
Python. It also includes some utility functions for interacting with
|
||||
``Future`` objects.
|
||||
|
||||
While this package is an important part of Tornado's internal
|
||||
implementation, applications rarely need to interact with it
|
||||
directly.
|
||||
"""
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import functools
|
||||
import platform
|
||||
import textwrap
|
||||
import traceback
|
||||
import sys
|
||||
import warnings
|
||||
|
||||
from tornado.log import app_log
|
||||
from tornado.stack_context import ExceptionStackContext, wrap
|
||||
from tornado.util import raise_exc_info, ArgReplacer, is_finalizing
|
||||
|
||||
try:
|
||||
from concurrent import futures
|
||||
except ImportError:
|
||||
futures = None
|
||||
|
||||
try:
|
||||
import asyncio
|
||||
except ImportError:
|
||||
asyncio = None
|
||||
|
||||
try:
|
||||
import typing
|
||||
except ImportError:
|
||||
typing = None
|
||||
|
||||
|
||||
# Can the garbage collector handle cycles that include __del__ methods?
|
||||
# This is true in cpython beginning with version 3.4 (PEP 442).
|
||||
_GC_CYCLE_FINALIZERS = (platform.python_implementation() == 'CPython' and
|
||||
sys.version_info >= (3, 4))
|
||||
|
||||
|
||||
class ReturnValueIgnoredError(Exception):
|
||||
pass
|
||||
|
||||
# This class and associated code in the future object is derived
|
||||
# from the Trollius project, a backport of asyncio to Python 2.x - 3.x
|
||||
|
||||
|
||||
class _TracebackLogger(object):
|
||||
"""Helper to log a traceback upon destruction if not cleared.
|
||||
|
||||
This solves a nasty problem with Futures and Tasks that have an
|
||||
exception set: if nobody asks for the exception, the exception is
|
||||
never logged. This violates the Zen of Python: 'Errors should
|
||||
never pass silently. Unless explicitly silenced.'
|
||||
|
||||
However, we don't want to log the exception as soon as
|
||||
set_exception() is called: if the calling code is written
|
||||
properly, it will get the exception and handle it properly. But
|
||||
we *do* want to log it if result() or exception() was never called
|
||||
-- otherwise developers waste a lot of time wondering why their
|
||||
buggy code fails silently.
|
||||
|
||||
An earlier attempt added a __del__() method to the Future class
|
||||
itself, but this backfired because the presence of __del__()
|
||||
prevents garbage collection from breaking cycles. A way out of
|
||||
this catch-22 is to avoid having a __del__() method on the Future
|
||||
class itself, but instead to have a reference to a helper object
|
||||
with a __del__() method that logs the traceback, where we ensure
|
||||
that the helper object doesn't participate in cycles, and only the
|
||||
Future has a reference to it.
|
||||
|
||||
The helper object is added when set_exception() is called. When
|
||||
the Future is collected, and the helper is present, the helper
|
||||
object is also collected, and its __del__() method will log the
|
||||
traceback. When the Future's result() or exception() method is
|
||||
called (and a helper object is present), it removes the the helper
|
||||
object, after calling its clear() method to prevent it from
|
||||
logging.
|
||||
|
||||
One downside is that we do a fair amount of work to extract the
|
||||
traceback from the exception, even when it is never logged. It
|
||||
would seem cheaper to just store the exception object, but that
|
||||
references the traceback, which references stack frames, which may
|
||||
reference the Future, which references the _TracebackLogger, and
|
||||
then the _TracebackLogger would be included in a cycle, which is
|
||||
what we're trying to avoid! As an optimization, we don't
|
||||
immediately format the exception; we only do the work when
|
||||
activate() is called, which call is delayed until after all the
|
||||
Future's callbacks have run. Since usually a Future has at least
|
||||
one callback (typically set by 'yield From') and usually that
|
||||
callback extracts the callback, thereby removing the need to
|
||||
format the exception.
|
||||
|
||||
PS. I don't claim credit for this solution. I first heard of it
|
||||
in a discussion about closing files when they are collected.
|
||||
"""
|
||||
|
||||
__slots__ = ('exc_info', 'formatted_tb')
|
||||
|
||||
def __init__(self, exc_info):
|
||||
self.exc_info = exc_info
|
||||
self.formatted_tb = None
|
||||
|
||||
def activate(self):
|
||||
exc_info = self.exc_info
|
||||
if exc_info is not None:
|
||||
self.exc_info = None
|
||||
self.formatted_tb = traceback.format_exception(*exc_info)
|
||||
|
||||
def clear(self):
|
||||
self.exc_info = None
|
||||
self.formatted_tb = None
|
||||
|
||||
def __del__(self, is_finalizing=is_finalizing):
|
||||
if not is_finalizing() and self.formatted_tb:
|
||||
app_log.error('Future exception was never retrieved: %s',
|
||||
''.join(self.formatted_tb).rstrip())
|
||||
|
||||
|
||||
class Future(object):
|
||||
"""Placeholder for an asynchronous result.
|
||||
|
||||
A ``Future`` encapsulates the result of an asynchronous
|
||||
operation. In synchronous applications ``Futures`` are used
|
||||
to wait for the result from a thread or process pool; in
|
||||
Tornado they are normally used with `.IOLoop.add_future` or by
|
||||
yielding them in a `.gen.coroutine`.
|
||||
|
||||
`tornado.concurrent.Future` is an alias for `asyncio.Future` when
|
||||
that package is available (Python 3.4+). Unlike
|
||||
`concurrent.futures.Future`, the ``Futures`` used by Tornado and
|
||||
`asyncio` are not thread-safe (and therefore faster for use with
|
||||
single-threaded event loops).
|
||||
|
||||
In addition to ``exception`` and ``set_exception``, Tornado's
|
||||
``Future`` implementation supports storing an ``exc_info`` triple
|
||||
to support better tracebacks on Python 2. To set an ``exc_info``
|
||||
triple, use `future_set_exc_info`, and to retrieve one, call
|
||||
`result()` (which will raise it).
|
||||
|
||||
.. versionchanged:: 4.0
|
||||
`tornado.concurrent.Future` is always a thread-unsafe ``Future``
|
||||
with support for the ``exc_info`` methods. Previously it would
|
||||
be an alias for the thread-safe `concurrent.futures.Future`
|
||||
if that package was available and fall back to the thread-unsafe
|
||||
implementation if it was not.
|
||||
|
||||
.. versionchanged:: 4.1
|
||||
If a `.Future` contains an error but that error is never observed
|
||||
(by calling ``result()``, ``exception()``, or ``exc_info()``),
|
||||
a stack trace will be logged when the `.Future` is garbage collected.
|
||||
This normally indicates an error in the application, but in cases
|
||||
where it results in undesired logging it may be necessary to
|
||||
suppress the logging by ensuring that the exception is observed:
|
||||
``f.add_done_callback(lambda f: f.exception())``.
|
||||
|
||||
.. versionchanged:: 5.0
|
||||
|
||||
This class was previoiusly available under the name
|
||||
``TracebackFuture``. This name, which was deprecated since
|
||||
version 4.0, has been removed. When `asyncio` is available
|
||||
``tornado.concurrent.Future`` is now an alias for
|
||||
`asyncio.Future`. Like `asyncio.Future`, callbacks are now
|
||||
always scheduled on the `.IOLoop` and are never run
|
||||
synchronously.
|
||||
|
||||
"""
|
||||
def __init__(self):
|
||||
self._done = False
|
||||
self._result = None
|
||||
self._exc_info = None
|
||||
|
||||
self._log_traceback = False # Used for Python >= 3.4
|
||||
self._tb_logger = None # Used for Python <= 3.3
|
||||
|
||||
self._callbacks = []
|
||||
|
||||
# Implement the Python 3.5 Awaitable protocol if possible
|
||||
# (we can't use return and yield together until py33).
|
||||
if sys.version_info >= (3, 3):
|
||||
exec(textwrap.dedent("""
|
||||
def __await__(self):
|
||||
return (yield self)
|
||||
"""))
|
||||
else:
|
||||
# Py2-compatible version for use with cython.
|
||||
def __await__(self):
|
||||
result = yield self
|
||||
# StopIteration doesn't take args before py33,
|
||||
# but Cython recognizes the args tuple.
|
||||
e = StopIteration()
|
||||
e.args = (result,)
|
||||
raise e
|
||||
|
||||
def cancel(self):
|
||||
"""Cancel the operation, if possible.
|
||||
|
||||
Tornado ``Futures`` do not support cancellation, so this method always
|
||||
returns False.
|
||||
"""
|
||||
return False
|
||||
|
||||
def cancelled(self):
|
||||
"""Returns True if the operation has been cancelled.
|
||||
|
||||
Tornado ``Futures`` do not support cancellation, so this method
|
||||
always returns False.
|
||||
"""
|
||||
return False
|
||||
|
||||
def running(self):
|
||||
"""Returns True if this operation is currently running."""
|
||||
return not self._done
|
||||
|
||||
def done(self):
|
||||
"""Returns True if the future has finished running."""
|
||||
return self._done
|
||||
|
||||
def _clear_tb_log(self):
|
||||
self._log_traceback = False
|
||||
if self._tb_logger is not None:
|
||||
self._tb_logger.clear()
|
||||
self._tb_logger = None
|
||||
|
||||
def result(self, timeout=None):
|
||||
"""If the operation succeeded, return its result. If it failed,
|
||||
re-raise its exception.
|
||||
|
||||
This method takes a ``timeout`` argument for compatibility with
|
||||
`concurrent.futures.Future` but it is an error to call it
|
||||
before the `Future` is done, so the ``timeout`` is never used.
|
||||
"""
|
||||
self._clear_tb_log()
|
||||
if self._result is not None:
|
||||
return self._result
|
||||
if self._exc_info is not None:
|
||||
try:
|
||||
raise_exc_info(self._exc_info)
|
||||
finally:
|
||||
self = None
|
||||
self._check_done()
|
||||
return self._result
|
||||
|
||||
def exception(self, timeout=None):
|
||||
"""If the operation raised an exception, return the `Exception`
|
||||
object. Otherwise returns None.
|
||||
|
||||
This method takes a ``timeout`` argument for compatibility with
|
||||
`concurrent.futures.Future` but it is an error to call it
|
||||
before the `Future` is done, so the ``timeout`` is never used.
|
||||
"""
|
||||
self._clear_tb_log()
|
||||
if self._exc_info is not None:
|
||||
return self._exc_info[1]
|
||||
else:
|
||||
self._check_done()
|
||||
return None
|
||||
|
||||
def add_done_callback(self, fn):
|
||||
"""Attaches the given callback to the `Future`.
|
||||
|
||||
It will be invoked with the `Future` as its argument when the Future
|
||||
has finished running and its result is available. In Tornado
|
||||
consider using `.IOLoop.add_future` instead of calling
|
||||
`add_done_callback` directly.
|
||||
"""
|
||||
if self._done:
|
||||
from tornado.ioloop import IOLoop
|
||||
IOLoop.current().add_callback(fn, self)
|
||||
else:
|
||||
self._callbacks.append(fn)
|
||||
|
||||
def set_result(self, result):
|
||||
"""Sets the result of a ``Future``.
|
||||
|
||||
It is undefined to call any of the ``set`` methods more than once
|
||||
on the same object.
|
||||
"""
|
||||
self._result = result
|
||||
self._set_done()
|
||||
|
||||
def set_exception(self, exception):
|
||||
"""Sets the exception of a ``Future.``"""
|
||||
self.set_exc_info(
|
||||
(exception.__class__,
|
||||
exception,
|
||||
getattr(exception, '__traceback__', None)))
|
||||
|
||||
def exc_info(self):
|
||||
"""Returns a tuple in the same format as `sys.exc_info` or None.
|
||||
|
||||
.. versionadded:: 4.0
|
||||
"""
|
||||
self._clear_tb_log()
|
||||
return self._exc_info
|
||||
|
||||
def set_exc_info(self, exc_info):
|
||||
"""Sets the exception information of a ``Future.``
|
||||
|
||||
Preserves tracebacks on Python 2.
|
||||
|
||||
.. versionadded:: 4.0
|
||||
"""
|
||||
self._exc_info = exc_info
|
||||
self._log_traceback = True
|
||||
if not _GC_CYCLE_FINALIZERS:
|
||||
self._tb_logger = _TracebackLogger(exc_info)
|
||||
|
||||
try:
|
||||
self._set_done()
|
||||
finally:
|
||||
# Activate the logger after all callbacks have had a
|
||||
# chance to call result() or exception().
|
||||
if self._log_traceback and self._tb_logger is not None:
|
||||
self._tb_logger.activate()
|
||||
self._exc_info = exc_info
|
||||
|
||||
def _check_done(self):
|
||||
if not self._done:
|
||||
raise Exception("DummyFuture does not support blocking for results")
|
||||
|
||||
def _set_done(self):
|
||||
self._done = True
|
||||
if self._callbacks:
|
||||
from tornado.ioloop import IOLoop
|
||||
loop = IOLoop.current()
|
||||
for cb in self._callbacks:
|
||||
loop.add_callback(cb, self)
|
||||
self._callbacks = None
|
||||
|
||||
# On Python 3.3 or older, objects with a destructor part of a reference
|
||||
# cycle are never destroyed. It's no longer the case on Python 3.4 thanks to
|
||||
# the PEP 442.
|
||||
if _GC_CYCLE_FINALIZERS:
|
||||
def __del__(self, is_finalizing=is_finalizing):
|
||||
if is_finalizing() or not self._log_traceback:
|
||||
# set_exception() was not called, or result() or exception()
|
||||
# has consumed the exception
|
||||
return
|
||||
|
||||
tb = traceback.format_exception(*self._exc_info)
|
||||
|
||||
app_log.error('Future %r exception was never retrieved: %s',
|
||||
self, ''.join(tb).rstrip())
|
||||
|
||||
|
||||
if asyncio is not None:
|
||||
Future = asyncio.Future # noqa
|
||||
|
||||
if futures is None:
|
||||
FUTURES = Future # type: typing.Union[type, typing.Tuple[type, ...]]
|
||||
else:
|
||||
FUTURES = (futures.Future, Future)
|
||||
|
||||
|
||||
def is_future(x):
|
||||
return isinstance(x, FUTURES)
|
||||
|
||||
|
||||
class DummyExecutor(object):
|
||||
def submit(self, fn, *args, **kwargs):
|
||||
future = Future()
|
||||
try:
|
||||
future_set_result_unless_cancelled(future, fn(*args, **kwargs))
|
||||
except Exception:
|
||||
future_set_exc_info(future, sys.exc_info())
|
||||
return future
|
||||
|
||||
def shutdown(self, wait=True):
|
||||
pass
|
||||
|
||||
|
||||
dummy_executor = DummyExecutor()
|
||||
|
||||
|
||||
def run_on_executor(*args, **kwargs):
|
||||
"""Decorator to run a synchronous method asynchronously on an executor.
|
||||
|
||||
The decorated method may be called with a ``callback`` keyword
|
||||
argument and returns a future.
|
||||
|
||||
The executor to be used is determined by the ``executor``
|
||||
attributes of ``self``. To use a different attribute name, pass a
|
||||
keyword argument to the decorator::
|
||||
|
||||
@run_on_executor(executor='_thread_pool')
|
||||
def foo(self):
|
||||
pass
|
||||
|
||||
This decorator should not be confused with the similarly-named
|
||||
`.IOLoop.run_in_executor`. In general, using ``run_in_executor``
|
||||
when *calling* a blocking method is recommended instead of using
|
||||
this decorator when *defining* a method. If compatibility with older
|
||||
versions of Tornado is required, consider defining an executor
|
||||
and using ``executor.submit()`` at the call site.
|
||||
|
||||
.. versionchanged:: 4.2
|
||||
Added keyword arguments to use alternative attributes.
|
||||
|
||||
.. versionchanged:: 5.0
|
||||
Always uses the current IOLoop instead of ``self.io_loop``.
|
||||
|
||||
.. versionchanged:: 5.1
|
||||
Returns a `.Future` compatible with ``await`` instead of a
|
||||
`concurrent.futures.Future`.
|
||||
|
||||
.. deprecated:: 5.1
|
||||
|
||||
The ``callback`` argument is deprecated and will be removed in
|
||||
6.0. The decorator itself is discouraged in new code but will
|
||||
not be removed in 6.0.
|
||||
"""
|
||||
def run_on_executor_decorator(fn):
|
||||
executor = kwargs.get("executor", "executor")
|
||||
|
||||
@functools.wraps(fn)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
callback = kwargs.pop("callback", None)
|
||||
async_future = Future()
|
||||
conc_future = getattr(self, executor).submit(fn, self, *args, **kwargs)
|
||||
chain_future(conc_future, async_future)
|
||||
if callback:
|
||||
warnings.warn("callback arguments are deprecated, use the returned Future instead",
|
||||
DeprecationWarning)
|
||||
from tornado.ioloop import IOLoop
|
||||
IOLoop.current().add_future(
|
||||
async_future, lambda future: callback(future.result()))
|
||||
return async_future
|
||||
return wrapper
|
||||
if args and kwargs:
|
||||
raise ValueError("cannot combine positional and keyword args")
|
||||
if len(args) == 1:
|
||||
return run_on_executor_decorator(args[0])
|
||||
elif len(args) != 0:
|
||||
raise ValueError("expected 1 argument, got %d", len(args))
|
||||
return run_on_executor_decorator
|
||||
|
||||
|
||||
_NO_RESULT = object()
|
||||
|
||||
|
||||
def return_future(f):
|
||||
"""Decorator to make a function that returns via callback return a
|
||||
`Future`.
|
||||
|
||||
This decorator was provided to ease the transition from
|
||||
callback-oriented code to coroutines. It is not recommended for
|
||||
new code.
|
||||
|
||||
The wrapped function should take a ``callback`` keyword argument
|
||||
and invoke it with one argument when it has finished. To signal failure,
|
||||
the function can simply raise an exception (which will be
|
||||
captured by the `.StackContext` and passed along to the ``Future``).
|
||||
|
||||
From the caller's perspective, the callback argument is optional.
|
||||
If one is given, it will be invoked when the function is complete
|
||||
with ``Future.result()`` as an argument. If the function fails, the
|
||||
callback will not be run and an exception will be raised into the
|
||||
surrounding `.StackContext`.
|
||||
|
||||
If no callback is given, the caller should use the ``Future`` to
|
||||
wait for the function to complete (perhaps by yielding it in a
|
||||
coroutine, or passing it to `.IOLoop.add_future`).
|
||||
|
||||
Usage:
|
||||
|
||||
.. testcode::
|
||||
|
||||
@return_future
|
||||
def future_func(arg1, arg2, callback):
|
||||
# Do stuff (possibly asynchronous)
|
||||
callback(result)
|
||||
|
||||
async def caller():
|
||||
await future_func(arg1, arg2)
|
||||
|
||||
..
|
||||
|
||||
Note that ``@return_future`` and ``@gen.engine`` can be applied to the
|
||||
same function, provided ``@return_future`` appears first. However,
|
||||
consider using ``@gen.coroutine`` instead of this combination.
|
||||
|
||||
.. versionchanged:: 5.1
|
||||
|
||||
Now raises a `.DeprecationWarning` if a callback argument is passed to
|
||||
the decorated function and deprecation warnings are enabled.
|
||||
|
||||
.. deprecated:: 5.1
|
||||
|
||||
This decorator will be removed in Tornado 6.0. New code should
|
||||
use coroutines directly instead of wrapping callback-based code
|
||||
with this decorator. Interactions with non-Tornado
|
||||
callback-based code should be managed explicitly to avoid
|
||||
relying on the `.ExceptionStackContext` built into this
|
||||
decorator.
|
||||
"""
|
||||
warnings.warn("@return_future is deprecated, use coroutines instead",
|
||||
DeprecationWarning)
|
||||
return _non_deprecated_return_future(f, warn=True)
|
||||
|
||||
|
||||
def _non_deprecated_return_future(f, warn=False):
|
||||
# Allow auth.py to use this decorator without triggering
|
||||
# deprecation warnings. This will go away once auth.py has removed
|
||||
# its legacy interfaces in 6.0.
|
||||
replacer = ArgReplacer(f, 'callback')
|
||||
|
||||
@functools.wraps(f)
|
||||
def wrapper(*args, **kwargs):
|
||||
future = Future()
|
||||
callback, args, kwargs = replacer.replace(
|
||||
lambda value=_NO_RESULT: future_set_result_unless_cancelled(future, value),
|
||||
args, kwargs)
|
||||
|
||||
def handle_error(typ, value, tb):
|
||||
future_set_exc_info(future, (typ, value, tb))
|
||||
return True
|
||||
exc_info = None
|
||||
esc = ExceptionStackContext(handle_error, delay_warning=True)
|
||||
with esc:
|
||||
if not warn:
|
||||
# HACK: In non-deprecated mode (only used in auth.py),
|
||||
# suppress the warning entirely. Since this is added
|
||||
# in a 5.1 patch release and already removed in 6.0
|
||||
# I'm prioritizing a minimial change instead of a
|
||||
# clean solution.
|
||||
esc.delay_warning = False
|
||||
try:
|
||||
result = f(*args, **kwargs)
|
||||
if result is not None:
|
||||
raise ReturnValueIgnoredError(
|
||||
"@return_future should not be used with functions "
|
||||
"that return values")
|
||||
except:
|
||||
exc_info = sys.exc_info()
|
||||
raise
|
||||
if exc_info is not None:
|
||||
# If the initial synchronous part of f() raised an exception,
|
||||
# go ahead and raise it to the caller directly without waiting
|
||||
# for them to inspect the Future.
|
||||
future.result()
|
||||
|
||||
# If the caller passed in a callback, schedule it to be called
|
||||
# when the future resolves. It is important that this happens
|
||||
# just before we return the future, or else we risk confusing
|
||||
# stack contexts with multiple exceptions (one here with the
|
||||
# immediate exception, and again when the future resolves and
|
||||
# the callback triggers its exception by calling future.result()).
|
||||
if callback is not None:
|
||||
warnings.warn("callback arguments are deprecated, use the returned Future instead",
|
||||
DeprecationWarning)
|
||||
|
||||
def run_callback(future):
|
||||
result = future.result()
|
||||
if result is _NO_RESULT:
|
||||
callback()
|
||||
else:
|
||||
callback(future.result())
|
||||
future_add_done_callback(future, wrap(run_callback))
|
||||
return future
|
||||
return wrapper
|
||||
|
||||
|
||||
def chain_future(a, b):
|
||||
"""Chain two futures together so that when one completes, so does the other.
|
||||
|
||||
The result (success or failure) of ``a`` will be copied to ``b``, unless
|
||||
``b`` has already been completed or cancelled by the time ``a`` finishes.
|
||||
|
||||
.. versionchanged:: 5.0
|
||||
|
||||
Now accepts both Tornado/asyncio `Future` objects and
|
||||
`concurrent.futures.Future`.
|
||||
|
||||
"""
|
||||
def copy(future):
|
||||
assert future is a
|
||||
if b.done():
|
||||
return
|
||||
if (hasattr(a, 'exc_info') and
|
||||
a.exc_info() is not None):
|
||||
future_set_exc_info(b, a.exc_info())
|
||||
elif a.exception() is not None:
|
||||
b.set_exception(a.exception())
|
||||
else:
|
||||
b.set_result(a.result())
|
||||
if isinstance(a, Future):
|
||||
future_add_done_callback(a, copy)
|
||||
else:
|
||||
# concurrent.futures.Future
|
||||
from tornado.ioloop import IOLoop
|
||||
IOLoop.current().add_future(a, copy)
|
||||
|
||||
|
||||
def future_set_result_unless_cancelled(future, value):
|
||||
"""Set the given ``value`` as the `Future`'s result, if not cancelled.
|
||||
|
||||
Avoids asyncio.InvalidStateError when calling set_result() on
|
||||
a cancelled `asyncio.Future`.
|
||||
|
||||
.. versionadded:: 5.0
|
||||
"""
|
||||
if not future.cancelled():
|
||||
future.set_result(value)
|
||||
|
||||
|
||||
def future_set_exc_info(future, exc_info):
|
||||
"""Set the given ``exc_info`` as the `Future`'s exception.
|
||||
|
||||
Understands both `asyncio.Future` and Tornado's extensions to
|
||||
enable better tracebacks on Python 2.
|
||||
|
||||
.. versionadded:: 5.0
|
||||
"""
|
||||
if hasattr(future, 'set_exc_info'):
|
||||
# Tornado's Future
|
||||
future.set_exc_info(exc_info)
|
||||
else:
|
||||
# asyncio.Future
|
||||
future.set_exception(exc_info[1])
|
||||
|
||||
|
||||
def future_add_done_callback(future, callback):
|
||||
"""Arrange to call ``callback`` when ``future`` is complete.
|
||||
|
||||
``callback`` is invoked with one argument, the ``future``.
|
||||
|
||||
If ``future`` is already done, ``callback`` is invoked immediately.
|
||||
This may differ from the behavior of ``Future.add_done_callback``,
|
||||
which makes no such guarantee.
|
||||
|
||||
.. versionadded:: 5.0
|
||||
"""
|
||||
if future.done():
|
||||
callback(future)
|
||||
else:
|
||||
future.add_done_callback(callback)
|
||||
Executable
+514
@@ -0,0 +1,514 @@
|
||||
#
|
||||
# Copyright 2009 Facebook
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""Non-blocking HTTP client implementation using pycurl."""
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import collections
|
||||
import functools
|
||||
import logging
|
||||
import pycurl # type: ignore
|
||||
import threading
|
||||
import time
|
||||
from io import BytesIO
|
||||
|
||||
from tornado import httputil
|
||||
from tornado import ioloop
|
||||
from tornado import stack_context
|
||||
|
||||
from tornado.escape import utf8, native_str
|
||||
from tornado.httpclient import HTTPResponse, HTTPError, AsyncHTTPClient, main
|
||||
|
||||
curl_log = logging.getLogger('tornado.curl_httpclient')
|
||||
|
||||
|
||||
class CurlAsyncHTTPClient(AsyncHTTPClient):
|
||||
def initialize(self, max_clients=10, defaults=None):
|
||||
super(CurlAsyncHTTPClient, self).initialize(defaults=defaults)
|
||||
self._multi = pycurl.CurlMulti()
|
||||
self._multi.setopt(pycurl.M_TIMERFUNCTION, self._set_timeout)
|
||||
self._multi.setopt(pycurl.M_SOCKETFUNCTION, self._handle_socket)
|
||||
self._curls = [self._curl_create() for i in range(max_clients)]
|
||||
self._free_list = self._curls[:]
|
||||
self._requests = collections.deque()
|
||||
self._fds = {}
|
||||
self._timeout = None
|
||||
|
||||
# libcurl has bugs that sometimes cause it to not report all
|
||||
# relevant file descriptors and timeouts to TIMERFUNCTION/
|
||||
# SOCKETFUNCTION. Mitigate the effects of such bugs by
|
||||
# forcing a periodic scan of all active requests.
|
||||
self._force_timeout_callback = ioloop.PeriodicCallback(
|
||||
self._handle_force_timeout, 1000)
|
||||
self._force_timeout_callback.start()
|
||||
|
||||
# Work around a bug in libcurl 7.29.0: Some fields in the curl
|
||||
# multi object are initialized lazily, and its destructor will
|
||||
# segfault if it is destroyed without having been used. Add
|
||||
# and remove a dummy handle to make sure everything is
|
||||
# initialized.
|
||||
dummy_curl_handle = pycurl.Curl()
|
||||
self._multi.add_handle(dummy_curl_handle)
|
||||
self._multi.remove_handle(dummy_curl_handle)
|
||||
|
||||
def close(self):
|
||||
self._force_timeout_callback.stop()
|
||||
if self._timeout is not None:
|
||||
self.io_loop.remove_timeout(self._timeout)
|
||||
for curl in self._curls:
|
||||
curl.close()
|
||||
self._multi.close()
|
||||
super(CurlAsyncHTTPClient, self).close()
|
||||
|
||||
# Set below properties to None to reduce the reference count of current
|
||||
# instance, because those properties hold some methods of current
|
||||
# instance that will case circular reference.
|
||||
self._force_timeout_callback = None
|
||||
self._multi = None
|
||||
|
||||
def fetch_impl(self, request, callback):
|
||||
self._requests.append((request, callback, self.io_loop.time()))
|
||||
self._process_queue()
|
||||
self._set_timeout(0)
|
||||
|
||||
def _handle_socket(self, event, fd, multi, data):
|
||||
"""Called by libcurl when it wants to change the file descriptors
|
||||
it cares about.
|
||||
"""
|
||||
event_map = {
|
||||
pycurl.POLL_NONE: ioloop.IOLoop.NONE,
|
||||
pycurl.POLL_IN: ioloop.IOLoop.READ,
|
||||
pycurl.POLL_OUT: ioloop.IOLoop.WRITE,
|
||||
pycurl.POLL_INOUT: ioloop.IOLoop.READ | ioloop.IOLoop.WRITE
|
||||
}
|
||||
if event == pycurl.POLL_REMOVE:
|
||||
if fd in self._fds:
|
||||
self.io_loop.remove_handler(fd)
|
||||
del self._fds[fd]
|
||||
else:
|
||||
ioloop_event = event_map[event]
|
||||
# libcurl sometimes closes a socket and then opens a new
|
||||
# one using the same FD without giving us a POLL_NONE in
|
||||
# between. This is a problem with the epoll IOLoop,
|
||||
# because the kernel can tell when a socket is closed and
|
||||
# removes it from the epoll automatically, causing future
|
||||
# update_handler calls to fail. Since we can't tell when
|
||||
# this has happened, always use remove and re-add
|
||||
# instead of update.
|
||||
if fd in self._fds:
|
||||
self.io_loop.remove_handler(fd)
|
||||
self.io_loop.add_handler(fd, self._handle_events,
|
||||
ioloop_event)
|
||||
self._fds[fd] = ioloop_event
|
||||
|
||||
def _set_timeout(self, msecs):
|
||||
"""Called by libcurl to schedule a timeout."""
|
||||
if self._timeout is not None:
|
||||
self.io_loop.remove_timeout(self._timeout)
|
||||
self._timeout = self.io_loop.add_timeout(
|
||||
self.io_loop.time() + msecs / 1000.0, self._handle_timeout)
|
||||
|
||||
def _handle_events(self, fd, events):
|
||||
"""Called by IOLoop when there is activity on one of our
|
||||
file descriptors.
|
||||
"""
|
||||
action = 0
|
||||
if events & ioloop.IOLoop.READ:
|
||||
action |= pycurl.CSELECT_IN
|
||||
if events & ioloop.IOLoop.WRITE:
|
||||
action |= pycurl.CSELECT_OUT
|
||||
while True:
|
||||
try:
|
||||
ret, num_handles = self._multi.socket_action(fd, action)
|
||||
except pycurl.error as e:
|
||||
ret = e.args[0]
|
||||
if ret != pycurl.E_CALL_MULTI_PERFORM:
|
||||
break
|
||||
self._finish_pending_requests()
|
||||
|
||||
def _handle_timeout(self):
|
||||
"""Called by IOLoop when the requested timeout has passed."""
|
||||
with stack_context.NullContext():
|
||||
self._timeout = None
|
||||
while True:
|
||||
try:
|
||||
ret, num_handles = self._multi.socket_action(
|
||||
pycurl.SOCKET_TIMEOUT, 0)
|
||||
except pycurl.error as e:
|
||||
ret = e.args[0]
|
||||
if ret != pycurl.E_CALL_MULTI_PERFORM:
|
||||
break
|
||||
self._finish_pending_requests()
|
||||
|
||||
# In theory, we shouldn't have to do this because curl will
|
||||
# call _set_timeout whenever the timeout changes. However,
|
||||
# sometimes after _handle_timeout we will need to reschedule
|
||||
# immediately even though nothing has changed from curl's
|
||||
# perspective. This is because when socket_action is
|
||||
# called with SOCKET_TIMEOUT, libcurl decides internally which
|
||||
# timeouts need to be processed by using a monotonic clock
|
||||
# (where available) while tornado uses python's time.time()
|
||||
# to decide when timeouts have occurred. When those clocks
|
||||
# disagree on elapsed time (as they will whenever there is an
|
||||
# NTP adjustment), tornado might call _handle_timeout before
|
||||
# libcurl is ready. After each timeout, resync the scheduled
|
||||
# timeout with libcurl's current state.
|
||||
new_timeout = self._multi.timeout()
|
||||
if new_timeout >= 0:
|
||||
self._set_timeout(new_timeout)
|
||||
|
||||
def _handle_force_timeout(self):
|
||||
"""Called by IOLoop periodically to ask libcurl to process any
|
||||
events it may have forgotten about.
|
||||
"""
|
||||
with stack_context.NullContext():
|
||||
while True:
|
||||
try:
|
||||
ret, num_handles = self._multi.socket_all()
|
||||
except pycurl.error as e:
|
||||
ret = e.args[0]
|
||||
if ret != pycurl.E_CALL_MULTI_PERFORM:
|
||||
break
|
||||
self._finish_pending_requests()
|
||||
|
||||
def _finish_pending_requests(self):
|
||||
"""Process any requests that were completed by the last
|
||||
call to multi.socket_action.
|
||||
"""
|
||||
while True:
|
||||
num_q, ok_list, err_list = self._multi.info_read()
|
||||
for curl in ok_list:
|
||||
self._finish(curl)
|
||||
for curl, errnum, errmsg in err_list:
|
||||
self._finish(curl, errnum, errmsg)
|
||||
if num_q == 0:
|
||||
break
|
||||
self._process_queue()
|
||||
|
||||
def _process_queue(self):
|
||||
with stack_context.NullContext():
|
||||
while True:
|
||||
started = 0
|
||||
while self._free_list and self._requests:
|
||||
started += 1
|
||||
curl = self._free_list.pop()
|
||||
(request, callback, queue_start_time) = self._requests.popleft()
|
||||
curl.info = {
|
||||
"headers": httputil.HTTPHeaders(),
|
||||
"buffer": BytesIO(),
|
||||
"request": request,
|
||||
"callback": callback,
|
||||
"queue_start_time": queue_start_time,
|
||||
"curl_start_time": time.time(),
|
||||
"curl_start_ioloop_time": self.io_loop.current().time(),
|
||||
}
|
||||
try:
|
||||
self._curl_setup_request(
|
||||
curl, request, curl.info["buffer"],
|
||||
curl.info["headers"])
|
||||
except Exception as e:
|
||||
# If there was an error in setup, pass it on
|
||||
# to the callback. Note that allowing the
|
||||
# error to escape here will appear to work
|
||||
# most of the time since we are still in the
|
||||
# caller's original stack frame, but when
|
||||
# _process_queue() is called from
|
||||
# _finish_pending_requests the exceptions have
|
||||
# nowhere to go.
|
||||
self._free_list.append(curl)
|
||||
callback(HTTPResponse(
|
||||
request=request,
|
||||
code=599,
|
||||
error=e))
|
||||
else:
|
||||
self._multi.add_handle(curl)
|
||||
|
||||
if not started:
|
||||
break
|
||||
|
||||
def _finish(self, curl, curl_error=None, curl_message=None):
|
||||
info = curl.info
|
||||
curl.info = None
|
||||
self._multi.remove_handle(curl)
|
||||
self._free_list.append(curl)
|
||||
buffer = info["buffer"]
|
||||
if curl_error:
|
||||
error = CurlError(curl_error, curl_message)
|
||||
code = error.code
|
||||
effective_url = None
|
||||
buffer.close()
|
||||
buffer = None
|
||||
else:
|
||||
error = None
|
||||
code = curl.getinfo(pycurl.HTTP_CODE)
|
||||
effective_url = curl.getinfo(pycurl.EFFECTIVE_URL)
|
||||
buffer.seek(0)
|
||||
# the various curl timings are documented at
|
||||
# http://curl.haxx.se/libcurl/c/curl_easy_getinfo.html
|
||||
time_info = dict(
|
||||
queue=info["curl_start_ioloop_time"] - info["queue_start_time"],
|
||||
namelookup=curl.getinfo(pycurl.NAMELOOKUP_TIME),
|
||||
connect=curl.getinfo(pycurl.CONNECT_TIME),
|
||||
appconnect=curl.getinfo(pycurl.APPCONNECT_TIME),
|
||||
pretransfer=curl.getinfo(pycurl.PRETRANSFER_TIME),
|
||||
starttransfer=curl.getinfo(pycurl.STARTTRANSFER_TIME),
|
||||
total=curl.getinfo(pycurl.TOTAL_TIME),
|
||||
redirect=curl.getinfo(pycurl.REDIRECT_TIME),
|
||||
)
|
||||
try:
|
||||
info["callback"](HTTPResponse(
|
||||
request=info["request"], code=code, headers=info["headers"],
|
||||
buffer=buffer, effective_url=effective_url, error=error,
|
||||
reason=info['headers'].get("X-Http-Reason", None),
|
||||
request_time=self.io_loop.time() - info["curl_start_ioloop_time"],
|
||||
start_time=info["curl_start_time"],
|
||||
time_info=time_info))
|
||||
except Exception:
|
||||
self.handle_callback_exception(info["callback"])
|
||||
|
||||
def handle_callback_exception(self, callback):
|
||||
self.io_loop.handle_callback_exception(callback)
|
||||
|
||||
def _curl_create(self):
|
||||
curl = pycurl.Curl()
|
||||
if curl_log.isEnabledFor(logging.DEBUG):
|
||||
curl.setopt(pycurl.VERBOSE, 1)
|
||||
curl.setopt(pycurl.DEBUGFUNCTION, self._curl_debug)
|
||||
if hasattr(pycurl, 'PROTOCOLS'): # PROTOCOLS first appeared in pycurl 7.19.5 (2014-07-12)
|
||||
curl.setopt(pycurl.PROTOCOLS, pycurl.PROTO_HTTP | pycurl.PROTO_HTTPS)
|
||||
curl.setopt(pycurl.REDIR_PROTOCOLS, pycurl.PROTO_HTTP | pycurl.PROTO_HTTPS)
|
||||
return curl
|
||||
|
||||
def _curl_setup_request(self, curl, request, buffer, headers):
|
||||
curl.setopt(pycurl.URL, native_str(request.url))
|
||||
|
||||
# libcurl's magic "Expect: 100-continue" behavior causes delays
|
||||
# with servers that don't support it (which include, among others,
|
||||
# Google's OpenID endpoint). Additionally, this behavior has
|
||||
# a bug in conjunction with the curl_multi_socket_action API
|
||||
# (https://sourceforge.net/tracker/?func=detail&atid=100976&aid=3039744&group_id=976),
|
||||
# which increases the delays. It's more trouble than it's worth,
|
||||
# so just turn off the feature (yes, setting Expect: to an empty
|
||||
# value is the official way to disable this)
|
||||
if "Expect" not in request.headers:
|
||||
request.headers["Expect"] = ""
|
||||
|
||||
# libcurl adds Pragma: no-cache by default; disable that too
|
||||
if "Pragma" not in request.headers:
|
||||
request.headers["Pragma"] = ""
|
||||
|
||||
curl.setopt(pycurl.HTTPHEADER,
|
||||
["%s: %s" % (native_str(k), native_str(v))
|
||||
for k, v in request.headers.get_all()])
|
||||
|
||||
curl.setopt(pycurl.HEADERFUNCTION,
|
||||
functools.partial(self._curl_header_callback,
|
||||
headers, request.header_callback))
|
||||
if request.streaming_callback:
|
||||
def write_function(chunk):
|
||||
self.io_loop.add_callback(request.streaming_callback, chunk)
|
||||
else:
|
||||
write_function = buffer.write
|
||||
curl.setopt(pycurl.WRITEFUNCTION, write_function)
|
||||
curl.setopt(pycurl.FOLLOWLOCATION, request.follow_redirects)
|
||||
curl.setopt(pycurl.MAXREDIRS, request.max_redirects)
|
||||
curl.setopt(pycurl.CONNECTTIMEOUT_MS, int(1000 * request.connect_timeout))
|
||||
curl.setopt(pycurl.TIMEOUT_MS, int(1000 * request.request_timeout))
|
||||
if request.user_agent:
|
||||
curl.setopt(pycurl.USERAGENT, native_str(request.user_agent))
|
||||
else:
|
||||
curl.setopt(pycurl.USERAGENT, "Mozilla/5.0 (compatible; pycurl)")
|
||||
if request.network_interface:
|
||||
curl.setopt(pycurl.INTERFACE, request.network_interface)
|
||||
if request.decompress_response:
|
||||
curl.setopt(pycurl.ENCODING, "gzip,deflate")
|
||||
else:
|
||||
curl.setopt(pycurl.ENCODING, "none")
|
||||
if request.proxy_host and request.proxy_port:
|
||||
curl.setopt(pycurl.PROXY, request.proxy_host)
|
||||
curl.setopt(pycurl.PROXYPORT, request.proxy_port)
|
||||
if request.proxy_username:
|
||||
credentials = httputil.encode_username_password(request.proxy_username,
|
||||
request.proxy_password)
|
||||
curl.setopt(pycurl.PROXYUSERPWD, credentials)
|
||||
|
||||
if (request.proxy_auth_mode is None or
|
||||
request.proxy_auth_mode == "basic"):
|
||||
curl.setopt(pycurl.PROXYAUTH, pycurl.HTTPAUTH_BASIC)
|
||||
elif request.proxy_auth_mode == "digest":
|
||||
curl.setopt(pycurl.PROXYAUTH, pycurl.HTTPAUTH_DIGEST)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Unsupported proxy_auth_mode %s" % request.proxy_auth_mode)
|
||||
else:
|
||||
curl.setopt(pycurl.PROXY, '')
|
||||
curl.unsetopt(pycurl.PROXYUSERPWD)
|
||||
if request.validate_cert:
|
||||
curl.setopt(pycurl.SSL_VERIFYPEER, 1)
|
||||
curl.setopt(pycurl.SSL_VERIFYHOST, 2)
|
||||
else:
|
||||
curl.setopt(pycurl.SSL_VERIFYPEER, 0)
|
||||
curl.setopt(pycurl.SSL_VERIFYHOST, 0)
|
||||
if request.ca_certs is not None:
|
||||
curl.setopt(pycurl.CAINFO, request.ca_certs)
|
||||
else:
|
||||
# There is no way to restore pycurl.CAINFO to its default value
|
||||
# (Using unsetopt makes it reject all certificates).
|
||||
# I don't see any way to read the default value from python so it
|
||||
# can be restored later. We'll have to just leave CAINFO untouched
|
||||
# if no ca_certs file was specified, and require that if any
|
||||
# request uses a custom ca_certs file, they all must.
|
||||
pass
|
||||
|
||||
if request.allow_ipv6 is False:
|
||||
# Curl behaves reasonably when DNS resolution gives an ipv6 address
|
||||
# that we can't reach, so allow ipv6 unless the user asks to disable.
|
||||
curl.setopt(pycurl.IPRESOLVE, pycurl.IPRESOLVE_V4)
|
||||
else:
|
||||
curl.setopt(pycurl.IPRESOLVE, pycurl.IPRESOLVE_WHATEVER)
|
||||
|
||||
# Set the request method through curl's irritating interface which makes
|
||||
# up names for almost every single method
|
||||
curl_options = {
|
||||
"GET": pycurl.HTTPGET,
|
||||
"POST": pycurl.POST,
|
||||
"PUT": pycurl.UPLOAD,
|
||||
"HEAD": pycurl.NOBODY,
|
||||
}
|
||||
custom_methods = set(["DELETE", "OPTIONS", "PATCH"])
|
||||
for o in curl_options.values():
|
||||
curl.setopt(o, False)
|
||||
if request.method in curl_options:
|
||||
curl.unsetopt(pycurl.CUSTOMREQUEST)
|
||||
curl.setopt(curl_options[request.method], True)
|
||||
elif request.allow_nonstandard_methods or request.method in custom_methods:
|
||||
curl.setopt(pycurl.CUSTOMREQUEST, request.method)
|
||||
else:
|
||||
raise KeyError('unknown method ' + request.method)
|
||||
|
||||
body_expected = request.method in ("POST", "PATCH", "PUT")
|
||||
body_present = request.body is not None
|
||||
if not request.allow_nonstandard_methods:
|
||||
# Some HTTP methods nearly always have bodies while others
|
||||
# almost never do. Fail in this case unless the user has
|
||||
# opted out of sanity checks with allow_nonstandard_methods.
|
||||
if ((body_expected and not body_present) or
|
||||
(body_present and not body_expected)):
|
||||
raise ValueError(
|
||||
'Body must %sbe None for method %s (unless '
|
||||
'allow_nonstandard_methods is true)' %
|
||||
('not ' if body_expected else '', request.method))
|
||||
|
||||
if body_expected or body_present:
|
||||
if request.method == "GET":
|
||||
# Even with `allow_nonstandard_methods` we disallow
|
||||
# GET with a body (because libcurl doesn't allow it
|
||||
# unless we use CUSTOMREQUEST). While the spec doesn't
|
||||
# forbid clients from sending a body, it arguably
|
||||
# disallows the server from doing anything with them.
|
||||
raise ValueError('Body must be None for GET request')
|
||||
request_buffer = BytesIO(utf8(request.body or ''))
|
||||
|
||||
def ioctl(cmd):
|
||||
if cmd == curl.IOCMD_RESTARTREAD:
|
||||
request_buffer.seek(0)
|
||||
curl.setopt(pycurl.READFUNCTION, request_buffer.read)
|
||||
curl.setopt(pycurl.IOCTLFUNCTION, ioctl)
|
||||
if request.method == "POST":
|
||||
curl.setopt(pycurl.POSTFIELDSIZE, len(request.body or ''))
|
||||
else:
|
||||
curl.setopt(pycurl.UPLOAD, True)
|
||||
curl.setopt(pycurl.INFILESIZE, len(request.body or ''))
|
||||
|
||||
if request.auth_username is not None:
|
||||
if request.auth_mode is None or request.auth_mode == "basic":
|
||||
curl.setopt(pycurl.HTTPAUTH, pycurl.HTTPAUTH_BASIC)
|
||||
elif request.auth_mode == "digest":
|
||||
curl.setopt(pycurl.HTTPAUTH, pycurl.HTTPAUTH_DIGEST)
|
||||
else:
|
||||
raise ValueError("Unsupported auth_mode %s" % request.auth_mode)
|
||||
|
||||
userpwd = httputil.encode_username_password(request.auth_username,
|
||||
request.auth_password)
|
||||
curl.setopt(pycurl.USERPWD, userpwd)
|
||||
curl_log.debug("%s %s (username: %r)", request.method, request.url,
|
||||
request.auth_username)
|
||||
else:
|
||||
curl.unsetopt(pycurl.USERPWD)
|
||||
curl_log.debug("%s %s", request.method, request.url)
|
||||
|
||||
if request.client_cert is not None:
|
||||
curl.setopt(pycurl.SSLCERT, request.client_cert)
|
||||
|
||||
if request.client_key is not None:
|
||||
curl.setopt(pycurl.SSLKEY, request.client_key)
|
||||
|
||||
if request.ssl_options is not None:
|
||||
raise ValueError("ssl_options not supported in curl_httpclient")
|
||||
|
||||
if threading.activeCount() > 1:
|
||||
# libcurl/pycurl is not thread-safe by default. When multiple threads
|
||||
# are used, signals should be disabled. This has the side effect
|
||||
# of disabling DNS timeouts in some environments (when libcurl is
|
||||
# not linked against ares), so we don't do it when there is only one
|
||||
# thread. Applications that use many short-lived threads may need
|
||||
# to set NOSIGNAL manually in a prepare_curl_callback since
|
||||
# there may not be any other threads running at the time we call
|
||||
# threading.activeCount.
|
||||
curl.setopt(pycurl.NOSIGNAL, 1)
|
||||
if request.prepare_curl_callback is not None:
|
||||
request.prepare_curl_callback(curl)
|
||||
|
||||
def _curl_header_callback(self, headers, header_callback, header_line):
|
||||
header_line = native_str(header_line.decode('latin1'))
|
||||
if header_callback is not None:
|
||||
self.io_loop.add_callback(header_callback, header_line)
|
||||
# header_line as returned by curl includes the end-of-line characters.
|
||||
# whitespace at the start should be preserved to allow multi-line headers
|
||||
header_line = header_line.rstrip()
|
||||
if header_line.startswith("HTTP/"):
|
||||
headers.clear()
|
||||
try:
|
||||
(__, __, reason) = httputil.parse_response_start_line(header_line)
|
||||
header_line = "X-Http-Reason: %s" % reason
|
||||
except httputil.HTTPInputError:
|
||||
return
|
||||
if not header_line:
|
||||
return
|
||||
headers.parse_line(header_line)
|
||||
|
||||
def _curl_debug(self, debug_type, debug_msg):
|
||||
debug_types = ('I', '<', '>', '<', '>')
|
||||
if debug_type == 0:
|
||||
debug_msg = native_str(debug_msg)
|
||||
curl_log.debug('%s', debug_msg.strip())
|
||||
elif debug_type in (1, 2):
|
||||
debug_msg = native_str(debug_msg)
|
||||
for line in debug_msg.splitlines():
|
||||
curl_log.debug('%s %s', debug_types[debug_type], line)
|
||||
elif debug_type == 4:
|
||||
curl_log.debug('%s %r', debug_types[debug_type], debug_msg)
|
||||
|
||||
|
||||
class CurlError(HTTPError):
|
||||
def __init__(self, errno, message):
|
||||
HTTPError.__init__(self, 599, message)
|
||||
self.errno = errno
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
AsyncHTTPClient.configure(CurlAsyncHTTPClient)
|
||||
main()
|
||||
Executable
+399
@@ -0,0 +1,399 @@
|
||||
#
|
||||
# Copyright 2009 Facebook
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""Escaping/unescaping methods for HTML, JSON, URLs, and others.
|
||||
|
||||
Also includes a few other miscellaneous string manipulation functions that
|
||||
have crept in over time.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import json
|
||||
import re
|
||||
|
||||
from tornado.util import PY3, unicode_type, basestring_type
|
||||
|
||||
if PY3:
|
||||
from urllib.parse import parse_qs as _parse_qs
|
||||
import html.entities as htmlentitydefs
|
||||
import urllib.parse as urllib_parse
|
||||
unichr = chr
|
||||
else:
|
||||
from urlparse import parse_qs as _parse_qs
|
||||
import htmlentitydefs
|
||||
import urllib as urllib_parse
|
||||
|
||||
try:
|
||||
import typing # noqa
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
_XHTML_ESCAPE_RE = re.compile('[&<>"\']')
|
||||
_XHTML_ESCAPE_DICT = {'&': '&', '<': '<', '>': '>', '"': '"',
|
||||
'\'': '''}
|
||||
|
||||
|
||||
def xhtml_escape(value):
|
||||
"""Escapes a string so it is valid within HTML or XML.
|
||||
|
||||
Escapes the characters ``<``, ``>``, ``"``, ``'``, and ``&``.
|
||||
When used in attribute values the escaped strings must be enclosed
|
||||
in quotes.
|
||||
|
||||
.. versionchanged:: 3.2
|
||||
|
||||
Added the single quote to the list of escaped characters.
|
||||
"""
|
||||
return _XHTML_ESCAPE_RE.sub(lambda match: _XHTML_ESCAPE_DICT[match.group(0)],
|
||||
to_basestring(value))
|
||||
|
||||
|
||||
def xhtml_unescape(value):
|
||||
"""Un-escapes an XML-escaped string."""
|
||||
return re.sub(r"&(#?)(\w+?);", _convert_entity, _unicode(value))
|
||||
|
||||
|
||||
# The fact that json_encode wraps json.dumps is an implementation detail.
|
||||
# Please see https://github.com/tornadoweb/tornado/pull/706
|
||||
# before sending a pull request that adds **kwargs to this function.
|
||||
def json_encode(value):
|
||||
"""JSON-encodes the given Python object."""
|
||||
# JSON permits but does not require forward slashes to be escaped.
|
||||
# This is useful when json data is emitted in a <script> tag
|
||||
# in HTML, as it prevents </script> tags from prematurely terminating
|
||||
# the javascript. Some json libraries do this escaping by default,
|
||||
# although python's standard library does not, so we do it here.
|
||||
# http://stackoverflow.com/questions/1580647/json-why-are-forward-slashes-escaped
|
||||
return json.dumps(value).replace("</", "<\\/")
|
||||
|
||||
|
||||
def json_decode(value):
|
||||
"""Returns Python objects for the given JSON string."""
|
||||
return json.loads(to_basestring(value))
|
||||
|
||||
|
||||
def squeeze(value):
|
||||
"""Replace all sequences of whitespace chars with a single space."""
|
||||
return re.sub(r"[\x00-\x20]+", " ", value).strip()
|
||||
|
||||
|
||||
def url_escape(value, plus=True):
|
||||
"""Returns a URL-encoded version of the given value.
|
||||
|
||||
If ``plus`` is true (the default), spaces will be represented
|
||||
as "+" instead of "%20". This is appropriate for query strings
|
||||
but not for the path component of a URL. Note that this default
|
||||
is the reverse of Python's urllib module.
|
||||
|
||||
.. versionadded:: 3.1
|
||||
The ``plus`` argument
|
||||
"""
|
||||
quote = urllib_parse.quote_plus if plus else urllib_parse.quote
|
||||
return quote(utf8(value))
|
||||
|
||||
|
||||
# python 3 changed things around enough that we need two separate
|
||||
# implementations of url_unescape. We also need our own implementation
|
||||
# of parse_qs since python 3's version insists on decoding everything.
|
||||
if not PY3:
|
||||
def url_unescape(value, encoding='utf-8', plus=True):
|
||||
"""Decodes the given value from a URL.
|
||||
|
||||
The argument may be either a byte or unicode string.
|
||||
|
||||
If encoding is None, the result will be a byte string. Otherwise,
|
||||
the result is a unicode string in the specified encoding.
|
||||
|
||||
If ``plus`` is true (the default), plus signs will be interpreted
|
||||
as spaces (literal plus signs must be represented as "%2B"). This
|
||||
is appropriate for query strings and form-encoded values but not
|
||||
for the path component of a URL. Note that this default is the
|
||||
reverse of Python's urllib module.
|
||||
|
||||
.. versionadded:: 3.1
|
||||
The ``plus`` argument
|
||||
"""
|
||||
unquote = (urllib_parse.unquote_plus if plus else urllib_parse.unquote)
|
||||
if encoding is None:
|
||||
return unquote(utf8(value))
|
||||
else:
|
||||
return unicode_type(unquote(utf8(value)), encoding)
|
||||
|
||||
parse_qs_bytes = _parse_qs
|
||||
else:
|
||||
def url_unescape(value, encoding='utf-8', plus=True):
|
||||
"""Decodes the given value from a URL.
|
||||
|
||||
The argument may be either a byte or unicode string.
|
||||
|
||||
If encoding is None, the result will be a byte string. Otherwise,
|
||||
the result is a unicode string in the specified encoding.
|
||||
|
||||
If ``plus`` is true (the default), plus signs will be interpreted
|
||||
as spaces (literal plus signs must be represented as "%2B"). This
|
||||
is appropriate for query strings and form-encoded values but not
|
||||
for the path component of a URL. Note that this default is the
|
||||
reverse of Python's urllib module.
|
||||
|
||||
.. versionadded:: 3.1
|
||||
The ``plus`` argument
|
||||
"""
|
||||
if encoding is None:
|
||||
if plus:
|
||||
# unquote_to_bytes doesn't have a _plus variant
|
||||
value = to_basestring(value).replace('+', ' ')
|
||||
return urllib_parse.unquote_to_bytes(value)
|
||||
else:
|
||||
unquote = (urllib_parse.unquote_plus if plus
|
||||
else urllib_parse.unquote)
|
||||
return unquote(to_basestring(value), encoding=encoding)
|
||||
|
||||
def parse_qs_bytes(qs, keep_blank_values=False, strict_parsing=False):
|
||||
"""Parses a query string like urlparse.parse_qs, but returns the
|
||||
values as byte strings.
|
||||
|
||||
Keys still become type str (interpreted as latin1 in python3!)
|
||||
because it's too painful to keep them as byte strings in
|
||||
python3 and in practice they're nearly always ascii anyway.
|
||||
"""
|
||||
# This is gross, but python3 doesn't give us another way.
|
||||
# Latin1 is the universal donor of character encodings.
|
||||
result = _parse_qs(qs, keep_blank_values, strict_parsing,
|
||||
encoding='latin1', errors='strict')
|
||||
encoded = {}
|
||||
for k, v in result.items():
|
||||
encoded[k] = [i.encode('latin1') for i in v]
|
||||
return encoded
|
||||
|
||||
|
||||
_UTF8_TYPES = (bytes, type(None))
|
||||
|
||||
|
||||
def utf8(value):
|
||||
# type: (typing.Union[bytes,unicode_type,None])->typing.Union[bytes,None]
|
||||
"""Converts a string argument to a byte string.
|
||||
|
||||
If the argument is already a byte string or None, it is returned unchanged.
|
||||
Otherwise it must be a unicode string and is encoded as utf8.
|
||||
"""
|
||||
if isinstance(value, _UTF8_TYPES):
|
||||
return value
|
||||
if not isinstance(value, unicode_type):
|
||||
raise TypeError(
|
||||
"Expected bytes, unicode, or None; got %r" % type(value)
|
||||
)
|
||||
return value.encode("utf-8")
|
||||
|
||||
|
||||
_TO_UNICODE_TYPES = (unicode_type, type(None))
|
||||
|
||||
|
||||
def to_unicode(value):
|
||||
"""Converts a string argument to a unicode string.
|
||||
|
||||
If the argument is already a unicode string or None, it is returned
|
||||
unchanged. Otherwise it must be a byte string and is decoded as utf8.
|
||||
"""
|
||||
if isinstance(value, _TO_UNICODE_TYPES):
|
||||
return value
|
||||
if not isinstance(value, bytes):
|
||||
raise TypeError(
|
||||
"Expected bytes, unicode, or None; got %r" % type(value)
|
||||
)
|
||||
return value.decode("utf-8")
|
||||
|
||||
|
||||
# to_unicode was previously named _unicode not because it was private,
|
||||
# but to avoid conflicts with the built-in unicode() function/type
|
||||
_unicode = to_unicode
|
||||
|
||||
# When dealing with the standard library across python 2 and 3 it is
|
||||
# sometimes useful to have a direct conversion to the native string type
|
||||
if str is unicode_type:
|
||||
native_str = to_unicode
|
||||
else:
|
||||
native_str = utf8
|
||||
|
||||
_BASESTRING_TYPES = (basestring_type, type(None))
|
||||
|
||||
|
||||
def to_basestring(value):
|
||||
"""Converts a string argument to a subclass of basestring.
|
||||
|
||||
In python2, byte and unicode strings are mostly interchangeable,
|
||||
so functions that deal with a user-supplied argument in combination
|
||||
with ascii string constants can use either and should return the type
|
||||
the user supplied. In python3, the two types are not interchangeable,
|
||||
so this method is needed to convert byte strings to unicode.
|
||||
"""
|
||||
if isinstance(value, _BASESTRING_TYPES):
|
||||
return value
|
||||
if not isinstance(value, bytes):
|
||||
raise TypeError(
|
||||
"Expected bytes, unicode, or None; got %r" % type(value)
|
||||
)
|
||||
return value.decode("utf-8")
|
||||
|
||||
|
||||
def recursive_unicode(obj):
|
||||
"""Walks a simple data structure, converting byte strings to unicode.
|
||||
|
||||
Supports lists, tuples, and dictionaries.
|
||||
"""
|
||||
if isinstance(obj, dict):
|
||||
return dict((recursive_unicode(k), recursive_unicode(v)) for (k, v) in obj.items())
|
||||
elif isinstance(obj, list):
|
||||
return list(recursive_unicode(i) for i in obj)
|
||||
elif isinstance(obj, tuple):
|
||||
return tuple(recursive_unicode(i) for i in obj)
|
||||
elif isinstance(obj, bytes):
|
||||
return to_unicode(obj)
|
||||
else:
|
||||
return obj
|
||||
|
||||
|
||||
# I originally used the regex from
|
||||
# http://daringfireball.net/2010/07/improved_regex_for_matching_urls
|
||||
# but it gets all exponential on certain patterns (such as too many trailing
|
||||
# dots), causing the regex matcher to never return.
|
||||
# This regex should avoid those problems.
|
||||
# Use to_unicode instead of tornado.util.u - we don't want backslashes getting
|
||||
# processed as escapes.
|
||||
_URL_RE = re.compile(to_unicode(
|
||||
r"""\b((?:([\w-]+):(/{1,3})|www[.])(?:(?:(?:[^\s&()]|&|")*(?:[^!"#$%&'()*+,.:;<=>?@\[\]^`{|}~\s]))|(?:\((?:[^\s&()]|&|")*\)))+)""" # noqa: E501
|
||||
))
|
||||
|
||||
|
||||
def linkify(text, shorten=False, extra_params="",
|
||||
require_protocol=False, permitted_protocols=["http", "https"]):
|
||||
"""Converts plain text into HTML with links.
|
||||
|
||||
For example: ``linkify("Hello http://tornadoweb.org!")`` would return
|
||||
``Hello <a href="http://tornadoweb.org">http://tornadoweb.org</a>!``
|
||||
|
||||
Parameters:
|
||||
|
||||
* ``shorten``: Long urls will be shortened for display.
|
||||
|
||||
* ``extra_params``: Extra text to include in the link tag, or a callable
|
||||
taking the link as an argument and returning the extra text
|
||||
e.g. ``linkify(text, extra_params='rel="nofollow" class="external"')``,
|
||||
or::
|
||||
|
||||
def extra_params_cb(url):
|
||||
if url.startswith("http://example.com"):
|
||||
return 'class="internal"'
|
||||
else:
|
||||
return 'class="external" rel="nofollow"'
|
||||
linkify(text, extra_params=extra_params_cb)
|
||||
|
||||
* ``require_protocol``: Only linkify urls which include a protocol. If
|
||||
this is False, urls such as www.facebook.com will also be linkified.
|
||||
|
||||
* ``permitted_protocols``: List (or set) of protocols which should be
|
||||
linkified, e.g. ``linkify(text, permitted_protocols=["http", "ftp",
|
||||
"mailto"])``. It is very unsafe to include protocols such as
|
||||
``javascript``.
|
||||
"""
|
||||
if extra_params and not callable(extra_params):
|
||||
extra_params = " " + extra_params.strip()
|
||||
|
||||
def make_link(m):
|
||||
url = m.group(1)
|
||||
proto = m.group(2)
|
||||
if require_protocol and not proto:
|
||||
return url # not protocol, no linkify
|
||||
|
||||
if proto and proto not in permitted_protocols:
|
||||
return url # bad protocol, no linkify
|
||||
|
||||
href = m.group(1)
|
||||
if not proto:
|
||||
href = "http://" + href # no proto specified, use http
|
||||
|
||||
if callable(extra_params):
|
||||
params = " " + extra_params(href).strip()
|
||||
else:
|
||||
params = extra_params
|
||||
|
||||
# clip long urls. max_len is just an approximation
|
||||
max_len = 30
|
||||
if shorten and len(url) > max_len:
|
||||
before_clip = url
|
||||
if proto:
|
||||
proto_len = len(proto) + 1 + len(m.group(3) or "") # +1 for :
|
||||
else:
|
||||
proto_len = 0
|
||||
|
||||
parts = url[proto_len:].split("/")
|
||||
if len(parts) > 1:
|
||||
# Grab the whole host part plus the first bit of the path
|
||||
# The path is usually not that interesting once shortened
|
||||
# (no more slug, etc), so it really just provides a little
|
||||
# extra indication of shortening.
|
||||
url = url[:proto_len] + parts[0] + "/" + \
|
||||
parts[1][:8].split('?')[0].split('.')[0]
|
||||
|
||||
if len(url) > max_len * 1.5: # still too long
|
||||
url = url[:max_len]
|
||||
|
||||
if url != before_clip:
|
||||
amp = url.rfind('&')
|
||||
# avoid splitting html char entities
|
||||
if amp > max_len - 5:
|
||||
url = url[:amp]
|
||||
url += "..."
|
||||
|
||||
if len(url) >= len(before_clip):
|
||||
url = before_clip
|
||||
else:
|
||||
# full url is visible on mouse-over (for those who don't
|
||||
# have a status bar, such as Safari by default)
|
||||
params += ' title="%s"' % href
|
||||
|
||||
return u'<a href="%s"%s>%s</a>' % (href, params, url)
|
||||
|
||||
# First HTML-escape so that our strings are all safe.
|
||||
# The regex is modified to avoid character entites other than & so
|
||||
# that we won't pick up ", etc.
|
||||
text = _unicode(xhtml_escape(text))
|
||||
return _URL_RE.sub(make_link, text)
|
||||
|
||||
|
||||
def _convert_entity(m):
|
||||
if m.group(1) == "#":
|
||||
try:
|
||||
if m.group(2)[:1].lower() == 'x':
|
||||
return unichr(int(m.group(2)[1:], 16))
|
||||
else:
|
||||
return unichr(int(m.group(2)))
|
||||
except ValueError:
|
||||
return "&#%s;" % m.group(2)
|
||||
try:
|
||||
return _HTML_UNICODE_MAP[m.group(2)]
|
||||
except KeyError:
|
||||
return "&%s;" % m.group(2)
|
||||
|
||||
|
||||
def _build_unicode_map():
|
||||
unicode_map = {}
|
||||
for name, value in htmlentitydefs.name2codepoint.items():
|
||||
unicode_map[name] = unichr(value)
|
||||
return unicode_map
|
||||
|
||||
|
||||
_HTML_UNICODE_MAP = _build_unicode_map()
|
||||
Executable
+1367
File diff suppressed because it is too large
Load Diff
Executable
+751
@@ -0,0 +1,751 @@
|
||||
#
|
||||
# Copyright 2014 Facebook
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""Client and server implementations of HTTP/1.x.
|
||||
|
||||
.. versionadded:: 4.0
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import re
|
||||
import warnings
|
||||
|
||||
from tornado.concurrent import (Future, future_add_done_callback,
|
||||
future_set_result_unless_cancelled)
|
||||
from tornado.escape import native_str, utf8
|
||||
from tornado import gen
|
||||
from tornado import httputil
|
||||
from tornado import iostream
|
||||
from tornado.log import gen_log, app_log
|
||||
from tornado import stack_context
|
||||
from tornado.util import GzipDecompressor, PY3
|
||||
|
||||
|
||||
class _QuietException(Exception):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
class _ExceptionLoggingContext(object):
|
||||
"""Used with the ``with`` statement when calling delegate methods to
|
||||
log any exceptions with the given logger. Any exceptions caught are
|
||||
converted to _QuietException
|
||||
"""
|
||||
def __init__(self, logger):
|
||||
self.logger = logger
|
||||
|
||||
def __enter__(self):
|
||||
pass
|
||||
|
||||
def __exit__(self, typ, value, tb):
|
||||
if value is not None:
|
||||
self.logger.error("Uncaught exception", exc_info=(typ, value, tb))
|
||||
raise _QuietException
|
||||
|
||||
|
||||
class HTTP1ConnectionParameters(object):
|
||||
"""Parameters for `.HTTP1Connection` and `.HTTP1ServerConnection`.
|
||||
"""
|
||||
def __init__(self, no_keep_alive=False, chunk_size=None,
|
||||
max_header_size=None, header_timeout=None, max_body_size=None,
|
||||
body_timeout=None, decompress=False):
|
||||
"""
|
||||
:arg bool no_keep_alive: If true, always close the connection after
|
||||
one request.
|
||||
:arg int chunk_size: how much data to read into memory at once
|
||||
:arg int max_header_size: maximum amount of data for HTTP headers
|
||||
:arg float header_timeout: how long to wait for all headers (seconds)
|
||||
:arg int max_body_size: maximum amount of data for body
|
||||
:arg float body_timeout: how long to wait while reading body (seconds)
|
||||
:arg bool decompress: if true, decode incoming
|
||||
``Content-Encoding: gzip``
|
||||
"""
|
||||
self.no_keep_alive = no_keep_alive
|
||||
self.chunk_size = chunk_size or 65536
|
||||
self.max_header_size = max_header_size or 65536
|
||||
self.header_timeout = header_timeout
|
||||
self.max_body_size = max_body_size
|
||||
self.body_timeout = body_timeout
|
||||
self.decompress = decompress
|
||||
|
||||
|
||||
class HTTP1Connection(httputil.HTTPConnection):
|
||||
"""Implements the HTTP/1.x protocol.
|
||||
|
||||
This class can be on its own for clients, or via `HTTP1ServerConnection`
|
||||
for servers.
|
||||
"""
|
||||
def __init__(self, stream, is_client, params=None, context=None):
|
||||
"""
|
||||
:arg stream: an `.IOStream`
|
||||
:arg bool is_client: client or server
|
||||
:arg params: a `.HTTP1ConnectionParameters` instance or ``None``
|
||||
:arg context: an opaque application-defined object that can be accessed
|
||||
as ``connection.context``.
|
||||
"""
|
||||
self.is_client = is_client
|
||||
self.stream = stream
|
||||
if params is None:
|
||||
params = HTTP1ConnectionParameters()
|
||||
self.params = params
|
||||
self.context = context
|
||||
self.no_keep_alive = params.no_keep_alive
|
||||
# The body limits can be altered by the delegate, so save them
|
||||
# here instead of just referencing self.params later.
|
||||
self._max_body_size = (self.params.max_body_size or
|
||||
self.stream.max_buffer_size)
|
||||
self._body_timeout = self.params.body_timeout
|
||||
# _write_finished is set to True when finish() has been called,
|
||||
# i.e. there will be no more data sent. Data may still be in the
|
||||
# stream's write buffer.
|
||||
self._write_finished = False
|
||||
# True when we have read the entire incoming body.
|
||||
self._read_finished = False
|
||||
# _finish_future resolves when all data has been written and flushed
|
||||
# to the IOStream.
|
||||
self._finish_future = Future()
|
||||
# If true, the connection should be closed after this request
|
||||
# (after the response has been written in the server side,
|
||||
# and after it has been read in the client)
|
||||
self._disconnect_on_finish = False
|
||||
self._clear_callbacks()
|
||||
# Save the start lines after we read or write them; they
|
||||
# affect later processing (e.g. 304 responses and HEAD methods
|
||||
# have content-length but no bodies)
|
||||
self._request_start_line = None
|
||||
self._response_start_line = None
|
||||
self._request_headers = None
|
||||
# True if we are writing output with chunked encoding.
|
||||
self._chunking_output = None
|
||||
# While reading a body with a content-length, this is the
|
||||
# amount left to read.
|
||||
self._expected_content_remaining = None
|
||||
# A Future for our outgoing writes, returned by IOStream.write.
|
||||
self._pending_write = None
|
||||
|
||||
def read_response(self, delegate):
|
||||
"""Read a single HTTP response.
|
||||
|
||||
Typical client-mode usage is to write a request using `write_headers`,
|
||||
`write`, and `finish`, and then call ``read_response``.
|
||||
|
||||
:arg delegate: a `.HTTPMessageDelegate`
|
||||
|
||||
Returns a `.Future` that resolves to None after the full response has
|
||||
been read.
|
||||
"""
|
||||
if self.params.decompress:
|
||||
delegate = _GzipMessageDelegate(delegate, self.params.chunk_size)
|
||||
return self._read_message(delegate)
|
||||
|
||||
@gen.coroutine
|
||||
def _read_message(self, delegate):
|
||||
need_delegate_close = False
|
||||
try:
|
||||
header_future = self.stream.read_until_regex(
|
||||
b"\r?\n\r?\n",
|
||||
max_bytes=self.params.max_header_size)
|
||||
if self.params.header_timeout is None:
|
||||
header_data = yield header_future
|
||||
else:
|
||||
try:
|
||||
header_data = yield gen.with_timeout(
|
||||
self.stream.io_loop.time() + self.params.header_timeout,
|
||||
header_future,
|
||||
quiet_exceptions=iostream.StreamClosedError)
|
||||
except gen.TimeoutError:
|
||||
self.close()
|
||||
raise gen.Return(False)
|
||||
start_line, headers = self._parse_headers(header_data)
|
||||
if self.is_client:
|
||||
start_line = httputil.parse_response_start_line(start_line)
|
||||
self._response_start_line = start_line
|
||||
else:
|
||||
start_line = httputil.parse_request_start_line(start_line)
|
||||
self._request_start_line = start_line
|
||||
self._request_headers = headers
|
||||
|
||||
self._disconnect_on_finish = not self._can_keep_alive(
|
||||
start_line, headers)
|
||||
need_delegate_close = True
|
||||
with _ExceptionLoggingContext(app_log):
|
||||
header_future = delegate.headers_received(start_line, headers)
|
||||
if header_future is not None:
|
||||
yield header_future
|
||||
if self.stream is None:
|
||||
# We've been detached.
|
||||
need_delegate_close = False
|
||||
raise gen.Return(False)
|
||||
skip_body = False
|
||||
if self.is_client:
|
||||
if (self._request_start_line is not None and
|
||||
self._request_start_line.method == 'HEAD'):
|
||||
skip_body = True
|
||||
code = start_line.code
|
||||
if code == 304:
|
||||
# 304 responses may include the content-length header
|
||||
# but do not actually have a body.
|
||||
# http://tools.ietf.org/html/rfc7230#section-3.3
|
||||
skip_body = True
|
||||
if code >= 100 and code < 200:
|
||||
# 1xx responses should never indicate the presence of
|
||||
# a body.
|
||||
if ('Content-Length' in headers or
|
||||
'Transfer-Encoding' in headers):
|
||||
raise httputil.HTTPInputError(
|
||||
"Response code %d cannot have body" % code)
|
||||
# TODO: client delegates will get headers_received twice
|
||||
# in the case of a 100-continue. Document or change?
|
||||
yield self._read_message(delegate)
|
||||
else:
|
||||
if (headers.get("Expect") == "100-continue" and
|
||||
not self._write_finished):
|
||||
self.stream.write(b"HTTP/1.1 100 (Continue)\r\n\r\n")
|
||||
if not skip_body:
|
||||
body_future = self._read_body(
|
||||
start_line.code if self.is_client else 0, headers, delegate)
|
||||
if body_future is not None:
|
||||
if self._body_timeout is None:
|
||||
yield body_future
|
||||
else:
|
||||
try:
|
||||
yield gen.with_timeout(
|
||||
self.stream.io_loop.time() + self._body_timeout,
|
||||
body_future,
|
||||
quiet_exceptions=iostream.StreamClosedError)
|
||||
except gen.TimeoutError:
|
||||
gen_log.info("Timeout reading body from %s",
|
||||
self.context)
|
||||
self.stream.close()
|
||||
raise gen.Return(False)
|
||||
self._read_finished = True
|
||||
if not self._write_finished or self.is_client:
|
||||
need_delegate_close = False
|
||||
with _ExceptionLoggingContext(app_log):
|
||||
delegate.finish()
|
||||
# If we're waiting for the application to produce an asynchronous
|
||||
# response, and we're not detached, register a close callback
|
||||
# on the stream (we didn't need one while we were reading)
|
||||
if (not self._finish_future.done() and
|
||||
self.stream is not None and
|
||||
not self.stream.closed()):
|
||||
self.stream.set_close_callback(self._on_connection_close)
|
||||
yield self._finish_future
|
||||
if self.is_client and self._disconnect_on_finish:
|
||||
self.close()
|
||||
if self.stream is None:
|
||||
raise gen.Return(False)
|
||||
except httputil.HTTPInputError as e:
|
||||
gen_log.info("Malformed HTTP message from %s: %s",
|
||||
self.context, e)
|
||||
if not self.is_client:
|
||||
yield self.stream.write(b'HTTP/1.1 400 Bad Request\r\n\r\n')
|
||||
self.close()
|
||||
raise gen.Return(False)
|
||||
finally:
|
||||
if need_delegate_close:
|
||||
with _ExceptionLoggingContext(app_log):
|
||||
delegate.on_connection_close()
|
||||
header_future = None
|
||||
self._clear_callbacks()
|
||||
raise gen.Return(True)
|
||||
|
||||
def _clear_callbacks(self):
|
||||
"""Clears the callback attributes.
|
||||
|
||||
This allows the request handler to be garbage collected more
|
||||
quickly in CPython by breaking up reference cycles.
|
||||
"""
|
||||
self._write_callback = None
|
||||
self._write_future = None
|
||||
self._close_callback = None
|
||||
if self.stream is not None:
|
||||
self.stream.set_close_callback(None)
|
||||
|
||||
def set_close_callback(self, callback):
|
||||
"""Sets a callback that will be run when the connection is closed.
|
||||
|
||||
Note that this callback is slightly different from
|
||||
`.HTTPMessageDelegate.on_connection_close`: The
|
||||
`.HTTPMessageDelegate` method is called when the connection is
|
||||
closed while recieving a message. This callback is used when
|
||||
there is not an active delegate (for example, on the server
|
||||
side this callback is used if the client closes the connection
|
||||
after sending its request but before receiving all the
|
||||
response.
|
||||
"""
|
||||
self._close_callback = stack_context.wrap(callback)
|
||||
|
||||
def _on_connection_close(self):
|
||||
# Note that this callback is only registered on the IOStream
|
||||
# when we have finished reading the request and are waiting for
|
||||
# the application to produce its response.
|
||||
if self._close_callback is not None:
|
||||
callback = self._close_callback
|
||||
self._close_callback = None
|
||||
callback()
|
||||
if not self._finish_future.done():
|
||||
future_set_result_unless_cancelled(self._finish_future, None)
|
||||
self._clear_callbacks()
|
||||
|
||||
def close(self):
|
||||
if self.stream is not None:
|
||||
self.stream.close()
|
||||
self._clear_callbacks()
|
||||
if not self._finish_future.done():
|
||||
future_set_result_unless_cancelled(self._finish_future, None)
|
||||
|
||||
def detach(self):
|
||||
"""Take control of the underlying stream.
|
||||
|
||||
Returns the underlying `.IOStream` object and stops all further
|
||||
HTTP processing. May only be called during
|
||||
`.HTTPMessageDelegate.headers_received`. Intended for implementing
|
||||
protocols like websockets that tunnel over an HTTP handshake.
|
||||
"""
|
||||
self._clear_callbacks()
|
||||
stream = self.stream
|
||||
self.stream = None
|
||||
if not self._finish_future.done():
|
||||
future_set_result_unless_cancelled(self._finish_future, None)
|
||||
return stream
|
||||
|
||||
def set_body_timeout(self, timeout):
|
||||
"""Sets the body timeout for a single request.
|
||||
|
||||
Overrides the value from `.HTTP1ConnectionParameters`.
|
||||
"""
|
||||
self._body_timeout = timeout
|
||||
|
||||
def set_max_body_size(self, max_body_size):
|
||||
"""Sets the body size limit for a single request.
|
||||
|
||||
Overrides the value from `.HTTP1ConnectionParameters`.
|
||||
"""
|
||||
self._max_body_size = max_body_size
|
||||
|
||||
def write_headers(self, start_line, headers, chunk=None, callback=None):
|
||||
"""Implements `.HTTPConnection.write_headers`."""
|
||||
lines = []
|
||||
if self.is_client:
|
||||
self._request_start_line = start_line
|
||||
lines.append(utf8('%s %s HTTP/1.1' % (start_line[0], start_line[1])))
|
||||
# Client requests with a non-empty body must have either a
|
||||
# Content-Length or a Transfer-Encoding.
|
||||
self._chunking_output = (
|
||||
start_line.method in ('POST', 'PUT', 'PATCH') and
|
||||
'Content-Length' not in headers and
|
||||
'Transfer-Encoding' not in headers)
|
||||
else:
|
||||
self._response_start_line = start_line
|
||||
lines.append(utf8('HTTP/1.1 %d %s' % (start_line[1], start_line[2])))
|
||||
self._chunking_output = (
|
||||
# TODO: should this use
|
||||
# self._request_start_line.version or
|
||||
# start_line.version?
|
||||
self._request_start_line.version == 'HTTP/1.1' and
|
||||
# 1xx, 204 and 304 responses have no body (not even a zero-length
|
||||
# body), and so should not have either Content-Length or
|
||||
# Transfer-Encoding headers.
|
||||
start_line.code not in (204, 304) and
|
||||
(start_line.code < 100 or start_line.code >= 200) and
|
||||
# No need to chunk the output if a Content-Length is specified.
|
||||
'Content-Length' not in headers and
|
||||
# Applications are discouraged from touching Transfer-Encoding,
|
||||
# but if they do, leave it alone.
|
||||
'Transfer-Encoding' not in headers)
|
||||
# If connection to a 1.1 client will be closed, inform client
|
||||
if (self._request_start_line.version == 'HTTP/1.1' and self._disconnect_on_finish):
|
||||
headers['Connection'] = 'close'
|
||||
# If a 1.0 client asked for keep-alive, add the header.
|
||||
if (self._request_start_line.version == 'HTTP/1.0' and
|
||||
self._request_headers.get('Connection', '').lower() == 'keep-alive'):
|
||||
headers['Connection'] = 'Keep-Alive'
|
||||
if self._chunking_output:
|
||||
headers['Transfer-Encoding'] = 'chunked'
|
||||
if (not self.is_client and
|
||||
(self._request_start_line.method == 'HEAD' or
|
||||
start_line.code == 304)):
|
||||
self._expected_content_remaining = 0
|
||||
elif 'Content-Length' in headers:
|
||||
self._expected_content_remaining = int(headers['Content-Length'])
|
||||
else:
|
||||
self._expected_content_remaining = None
|
||||
# TODO: headers are supposed to be of type str, but we still have some
|
||||
# cases that let bytes slip through. Remove these native_str calls when those
|
||||
# are fixed.
|
||||
header_lines = (native_str(n) + ": " + native_str(v) for n, v in headers.get_all())
|
||||
if PY3:
|
||||
lines.extend(l.encode('latin1') for l in header_lines)
|
||||
else:
|
||||
lines.extend(header_lines)
|
||||
for line in lines:
|
||||
if b'\n' in line:
|
||||
raise ValueError('Newline in header: ' + repr(line))
|
||||
future = None
|
||||
if self.stream.closed():
|
||||
future = self._write_future = Future()
|
||||
future.set_exception(iostream.StreamClosedError())
|
||||
future.exception()
|
||||
else:
|
||||
if callback is not None:
|
||||
warnings.warn("callback argument is deprecated, use returned Future instead",
|
||||
DeprecationWarning)
|
||||
self._write_callback = stack_context.wrap(callback)
|
||||
else:
|
||||
future = self._write_future = Future()
|
||||
data = b"\r\n".join(lines) + b"\r\n\r\n"
|
||||
if chunk:
|
||||
data += self._format_chunk(chunk)
|
||||
self._pending_write = self.stream.write(data)
|
||||
future_add_done_callback(self._pending_write, self._on_write_complete)
|
||||
return future
|
||||
|
||||
def _format_chunk(self, chunk):
|
||||
if self._expected_content_remaining is not None:
|
||||
self._expected_content_remaining -= len(chunk)
|
||||
if self._expected_content_remaining < 0:
|
||||
# Close the stream now to stop further framing errors.
|
||||
self.stream.close()
|
||||
raise httputil.HTTPOutputError(
|
||||
"Tried to write more data than Content-Length")
|
||||
if self._chunking_output and chunk:
|
||||
# Don't write out empty chunks because that means END-OF-STREAM
|
||||
# with chunked encoding
|
||||
return utf8("%x" % len(chunk)) + b"\r\n" + chunk + b"\r\n"
|
||||
else:
|
||||
return chunk
|
||||
|
||||
def write(self, chunk, callback=None):
|
||||
"""Implements `.HTTPConnection.write`.
|
||||
|
||||
For backwards compatibility it is allowed but deprecated to
|
||||
skip `write_headers` and instead call `write()` with a
|
||||
pre-encoded header block.
|
||||
"""
|
||||
future = None
|
||||
if self.stream.closed():
|
||||
future = self._write_future = Future()
|
||||
self._write_future.set_exception(iostream.StreamClosedError())
|
||||
self._write_future.exception()
|
||||
else:
|
||||
if callback is not None:
|
||||
warnings.warn("callback argument is deprecated, use returned Future instead",
|
||||
DeprecationWarning)
|
||||
self._write_callback = stack_context.wrap(callback)
|
||||
else:
|
||||
future = self._write_future = Future()
|
||||
self._pending_write = self.stream.write(self._format_chunk(chunk))
|
||||
self._pending_write.add_done_callback(self._on_write_complete)
|
||||
return future
|
||||
|
||||
def finish(self):
|
||||
"""Implements `.HTTPConnection.finish`."""
|
||||
if (self._expected_content_remaining is not None and
|
||||
self._expected_content_remaining != 0 and
|
||||
not self.stream.closed()):
|
||||
self.stream.close()
|
||||
raise httputil.HTTPOutputError(
|
||||
"Tried to write %d bytes less than Content-Length" %
|
||||
self._expected_content_remaining)
|
||||
if self._chunking_output:
|
||||
if not self.stream.closed():
|
||||
self._pending_write = self.stream.write(b"0\r\n\r\n")
|
||||
self._pending_write.add_done_callback(self._on_write_complete)
|
||||
self._write_finished = True
|
||||
# If the app finished the request while we're still reading,
|
||||
# divert any remaining data away from the delegate and
|
||||
# close the connection when we're done sending our response.
|
||||
# Closing the connection is the only way to avoid reading the
|
||||
# whole input body.
|
||||
if not self._read_finished:
|
||||
self._disconnect_on_finish = True
|
||||
# No more data is coming, so instruct TCP to send any remaining
|
||||
# data immediately instead of waiting for a full packet or ack.
|
||||
self.stream.set_nodelay(True)
|
||||
if self._pending_write is None:
|
||||
self._finish_request(None)
|
||||
else:
|
||||
future_add_done_callback(self._pending_write, self._finish_request)
|
||||
|
||||
def _on_write_complete(self, future):
|
||||
exc = future.exception()
|
||||
if exc is not None and not isinstance(exc, iostream.StreamClosedError):
|
||||
future.result()
|
||||
if self._write_callback is not None:
|
||||
callback = self._write_callback
|
||||
self._write_callback = None
|
||||
self.stream.io_loop.add_callback(callback)
|
||||
if self._write_future is not None:
|
||||
future = self._write_future
|
||||
self._write_future = None
|
||||
future_set_result_unless_cancelled(future, None)
|
||||
|
||||
def _can_keep_alive(self, start_line, headers):
|
||||
if self.params.no_keep_alive:
|
||||
return False
|
||||
connection_header = headers.get("Connection")
|
||||
if connection_header is not None:
|
||||
connection_header = connection_header.lower()
|
||||
if start_line.version == "HTTP/1.1":
|
||||
return connection_header != "close"
|
||||
elif ("Content-Length" in headers or
|
||||
headers.get("Transfer-Encoding", "").lower() == "chunked" or
|
||||
getattr(start_line, 'method', None) in ("HEAD", "GET")):
|
||||
# start_line may be a request or response start line; only
|
||||
# the former has a method attribute.
|
||||
return connection_header == "keep-alive"
|
||||
return False
|
||||
|
||||
def _finish_request(self, future):
|
||||
self._clear_callbacks()
|
||||
if not self.is_client and self._disconnect_on_finish:
|
||||
self.close()
|
||||
return
|
||||
# Turn Nagle's algorithm back on, leaving the stream in its
|
||||
# default state for the next request.
|
||||
self.stream.set_nodelay(False)
|
||||
if not self._finish_future.done():
|
||||
future_set_result_unless_cancelled(self._finish_future, None)
|
||||
|
||||
def _parse_headers(self, data):
|
||||
# The lstrip removes newlines that some implementations sometimes
|
||||
# insert between messages of a reused connection. Per RFC 7230,
|
||||
# we SHOULD ignore at least one empty line before the request.
|
||||
# http://tools.ietf.org/html/rfc7230#section-3.5
|
||||
data = native_str(data.decode('latin1')).lstrip("\r\n")
|
||||
# RFC 7230 section allows for both CRLF and bare LF.
|
||||
eol = data.find("\n")
|
||||
start_line = data[:eol].rstrip("\r")
|
||||
headers = httputil.HTTPHeaders.parse(data[eol:])
|
||||
return start_line, headers
|
||||
|
||||
def _read_body(self, code, headers, delegate):
|
||||
if "Content-Length" in headers:
|
||||
if "Transfer-Encoding" in headers:
|
||||
# Response cannot contain both Content-Length and
|
||||
# Transfer-Encoding headers.
|
||||
# http://tools.ietf.org/html/rfc7230#section-3.3.3
|
||||
raise httputil.HTTPInputError(
|
||||
"Response with both Transfer-Encoding and Content-Length")
|
||||
if "," in headers["Content-Length"]:
|
||||
# Proxies sometimes cause Content-Length headers to get
|
||||
# duplicated. If all the values are identical then we can
|
||||
# use them but if they differ it's an error.
|
||||
pieces = re.split(r',\s*', headers["Content-Length"])
|
||||
if any(i != pieces[0] for i in pieces):
|
||||
raise httputil.HTTPInputError(
|
||||
"Multiple unequal Content-Lengths: %r" %
|
||||
headers["Content-Length"])
|
||||
headers["Content-Length"] = pieces[0]
|
||||
|
||||
try:
|
||||
content_length = int(headers["Content-Length"])
|
||||
except ValueError:
|
||||
# Handles non-integer Content-Length value.
|
||||
raise httputil.HTTPInputError(
|
||||
"Only integer Content-Length is allowed: %s" % headers["Content-Length"])
|
||||
|
||||
if content_length > self._max_body_size:
|
||||
raise httputil.HTTPInputError("Content-Length too long")
|
||||
else:
|
||||
content_length = None
|
||||
|
||||
if code == 204:
|
||||
# This response code is not allowed to have a non-empty body,
|
||||
# and has an implicit length of zero instead of read-until-close.
|
||||
# http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.3
|
||||
if ("Transfer-Encoding" in headers or
|
||||
content_length not in (None, 0)):
|
||||
raise httputil.HTTPInputError(
|
||||
"Response with code %d should not have body" % code)
|
||||
content_length = 0
|
||||
|
||||
if content_length is not None:
|
||||
return self._read_fixed_body(content_length, delegate)
|
||||
if headers.get("Transfer-Encoding", "").lower() == "chunked":
|
||||
return self._read_chunked_body(delegate)
|
||||
if self.is_client:
|
||||
return self._read_body_until_close(delegate)
|
||||
return None
|
||||
|
||||
@gen.coroutine
|
||||
def _read_fixed_body(self, content_length, delegate):
|
||||
while content_length > 0:
|
||||
body = yield self.stream.read_bytes(
|
||||
min(self.params.chunk_size, content_length), partial=True)
|
||||
content_length -= len(body)
|
||||
if not self._write_finished or self.is_client:
|
||||
with _ExceptionLoggingContext(app_log):
|
||||
ret = delegate.data_received(body)
|
||||
if ret is not None:
|
||||
yield ret
|
||||
|
||||
@gen.coroutine
|
||||
def _read_chunked_body(self, delegate):
|
||||
# TODO: "chunk extensions" http://tools.ietf.org/html/rfc2616#section-3.6.1
|
||||
total_size = 0
|
||||
while True:
|
||||
chunk_len = yield self.stream.read_until(b"\r\n", max_bytes=64)
|
||||
chunk_len = int(chunk_len.strip(), 16)
|
||||
if chunk_len == 0:
|
||||
crlf = yield self.stream.read_bytes(2)
|
||||
if crlf != b'\r\n':
|
||||
raise httputil.HTTPInputError("improperly terminated chunked request")
|
||||
return
|
||||
total_size += chunk_len
|
||||
if total_size > self._max_body_size:
|
||||
raise httputil.HTTPInputError("chunked body too large")
|
||||
bytes_to_read = chunk_len
|
||||
while bytes_to_read:
|
||||
chunk = yield self.stream.read_bytes(
|
||||
min(bytes_to_read, self.params.chunk_size), partial=True)
|
||||
bytes_to_read -= len(chunk)
|
||||
if not self._write_finished or self.is_client:
|
||||
with _ExceptionLoggingContext(app_log):
|
||||
ret = delegate.data_received(chunk)
|
||||
if ret is not None:
|
||||
yield ret
|
||||
# chunk ends with \r\n
|
||||
crlf = yield self.stream.read_bytes(2)
|
||||
assert crlf == b"\r\n"
|
||||
|
||||
@gen.coroutine
|
||||
def _read_body_until_close(self, delegate):
|
||||
body = yield self.stream.read_until_close()
|
||||
if not self._write_finished or self.is_client:
|
||||
with _ExceptionLoggingContext(app_log):
|
||||
delegate.data_received(body)
|
||||
|
||||
|
||||
class _GzipMessageDelegate(httputil.HTTPMessageDelegate):
|
||||
"""Wraps an `HTTPMessageDelegate` to decode ``Content-Encoding: gzip``.
|
||||
"""
|
||||
def __init__(self, delegate, chunk_size):
|
||||
self._delegate = delegate
|
||||
self._chunk_size = chunk_size
|
||||
self._decompressor = None
|
||||
|
||||
def headers_received(self, start_line, headers):
|
||||
if headers.get("Content-Encoding") == "gzip":
|
||||
self._decompressor = GzipDecompressor()
|
||||
# Downstream delegates will only see uncompressed data,
|
||||
# so rename the content-encoding header.
|
||||
# (but note that curl_httpclient doesn't do this).
|
||||
headers.add("X-Consumed-Content-Encoding",
|
||||
headers["Content-Encoding"])
|
||||
del headers["Content-Encoding"]
|
||||
return self._delegate.headers_received(start_line, headers)
|
||||
|
||||
@gen.coroutine
|
||||
def data_received(self, chunk):
|
||||
if self._decompressor:
|
||||
compressed_data = chunk
|
||||
while compressed_data:
|
||||
decompressed = self._decompressor.decompress(
|
||||
compressed_data, self._chunk_size)
|
||||
if decompressed:
|
||||
ret = self._delegate.data_received(decompressed)
|
||||
if ret is not None:
|
||||
yield ret
|
||||
compressed_data = self._decompressor.unconsumed_tail
|
||||
else:
|
||||
ret = self._delegate.data_received(chunk)
|
||||
if ret is not None:
|
||||
yield ret
|
||||
|
||||
def finish(self):
|
||||
if self._decompressor is not None:
|
||||
tail = self._decompressor.flush()
|
||||
if tail:
|
||||
# I believe the tail will always be empty (i.e.
|
||||
# decompress will return all it can). The purpose
|
||||
# of the flush call is to detect errors such
|
||||
# as truncated input. But in case it ever returns
|
||||
# anything, treat it as an extra chunk
|
||||
self._delegate.data_received(tail)
|
||||
return self._delegate.finish()
|
||||
|
||||
def on_connection_close(self):
|
||||
return self._delegate.on_connection_close()
|
||||
|
||||
|
||||
class HTTP1ServerConnection(object):
|
||||
"""An HTTP/1.x server."""
|
||||
def __init__(self, stream, params=None, context=None):
|
||||
"""
|
||||
:arg stream: an `.IOStream`
|
||||
:arg params: a `.HTTP1ConnectionParameters` or None
|
||||
:arg context: an opaque application-defined object that is accessible
|
||||
as ``connection.context``
|
||||
"""
|
||||
self.stream = stream
|
||||
if params is None:
|
||||
params = HTTP1ConnectionParameters()
|
||||
self.params = params
|
||||
self.context = context
|
||||
self._serving_future = None
|
||||
|
||||
@gen.coroutine
|
||||
def close(self):
|
||||
"""Closes the connection.
|
||||
|
||||
Returns a `.Future` that resolves after the serving loop has exited.
|
||||
"""
|
||||
self.stream.close()
|
||||
# Block until the serving loop is done, but ignore any exceptions
|
||||
# (start_serving is already responsible for logging them).
|
||||
try:
|
||||
yield self._serving_future
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def start_serving(self, delegate):
|
||||
"""Starts serving requests on this connection.
|
||||
|
||||
:arg delegate: a `.HTTPServerConnectionDelegate`
|
||||
"""
|
||||
assert isinstance(delegate, httputil.HTTPServerConnectionDelegate)
|
||||
self._serving_future = self._server_request_loop(delegate)
|
||||
# Register the future on the IOLoop so its errors get logged.
|
||||
self.stream.io_loop.add_future(self._serving_future,
|
||||
lambda f: f.result())
|
||||
|
||||
@gen.coroutine
|
||||
def _server_request_loop(self, delegate):
|
||||
try:
|
||||
while True:
|
||||
conn = HTTP1Connection(self.stream, False,
|
||||
self.params, self.context)
|
||||
request_delegate = delegate.start_request(self, conn)
|
||||
try:
|
||||
ret = yield conn.read_response(request_delegate)
|
||||
except (iostream.StreamClosedError,
|
||||
iostream.UnsatisfiableReadError):
|
||||
return
|
||||
except _QuietException:
|
||||
# This exception was already logged.
|
||||
conn.close()
|
||||
return
|
||||
except Exception:
|
||||
gen_log.error("Uncaught exception", exc_info=True)
|
||||
conn.close()
|
||||
return
|
||||
if not ret:
|
||||
return
|
||||
yield gen.moment
|
||||
finally:
|
||||
delegate.on_close(self)
|
||||
Executable
+748
@@ -0,0 +1,748 @@
|
||||
"""Blocking and non-blocking HTTP client interfaces.
|
||||
|
||||
This module defines a common interface shared by two implementations,
|
||||
``simple_httpclient`` and ``curl_httpclient``. Applications may either
|
||||
instantiate their chosen implementation class directly or use the
|
||||
`AsyncHTTPClient` class from this module, which selects an implementation
|
||||
that can be overridden with the `AsyncHTTPClient.configure` method.
|
||||
|
||||
The default implementation is ``simple_httpclient``, and this is expected
|
||||
to be suitable for most users' needs. However, some applications may wish
|
||||
to switch to ``curl_httpclient`` for reasons such as the following:
|
||||
|
||||
* ``curl_httpclient`` has some features not found in ``simple_httpclient``,
|
||||
including support for HTTP proxies and the ability to use a specified
|
||||
network interface.
|
||||
|
||||
* ``curl_httpclient`` is more likely to be compatible with sites that are
|
||||
not-quite-compliant with the HTTP spec, or sites that use little-exercised
|
||||
features of HTTP.
|
||||
|
||||
* ``curl_httpclient`` is faster.
|
||||
|
||||
* ``curl_httpclient`` was the default prior to Tornado 2.0.
|
||||
|
||||
Note that if you are using ``curl_httpclient``, it is highly
|
||||
recommended that you use a recent version of ``libcurl`` and
|
||||
``pycurl``. Currently the minimum supported version of libcurl is
|
||||
7.22.0, and the minimum version of pycurl is 7.18.2. It is highly
|
||||
recommended that your ``libcurl`` installation is built with
|
||||
asynchronous DNS resolver (threaded or c-ares), otherwise you may
|
||||
encounter various problems with request timeouts (for more
|
||||
information, see
|
||||
http://curl.haxx.se/libcurl/c/curl_easy_setopt.html#CURLOPTCONNECTTIMEOUTMS
|
||||
and comments in curl_httpclient.py).
|
||||
|
||||
To select ``curl_httpclient``, call `AsyncHTTPClient.configure` at startup::
|
||||
|
||||
AsyncHTTPClient.configure("tornado.curl_httpclient.CurlAsyncHTTPClient")
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import functools
|
||||
import time
|
||||
import warnings
|
||||
import weakref
|
||||
|
||||
from tornado.concurrent import Future, future_set_result_unless_cancelled
|
||||
from tornado.escape import utf8, native_str
|
||||
from tornado import gen, httputil, stack_context
|
||||
from tornado.ioloop import IOLoop
|
||||
from tornado.util import Configurable
|
||||
|
||||
|
||||
class HTTPClient(object):
|
||||
"""A blocking HTTP client.
|
||||
|
||||
This interface is provided to make it easier to share code between
|
||||
synchronous and asynchronous applications. Applications that are
|
||||
running an `.IOLoop` must use `AsyncHTTPClient` instead.
|
||||
|
||||
Typical usage looks like this::
|
||||
|
||||
http_client = httpclient.HTTPClient()
|
||||
try:
|
||||
response = http_client.fetch("http://www.google.com/")
|
||||
print(response.body)
|
||||
except httpclient.HTTPError as e:
|
||||
# HTTPError is raised for non-200 responses; the response
|
||||
# can be found in e.response.
|
||||
print("Error: " + str(e))
|
||||
except Exception as e:
|
||||
# Other errors are possible, such as IOError.
|
||||
print("Error: " + str(e))
|
||||
http_client.close()
|
||||
|
||||
.. versionchanged:: 5.0
|
||||
|
||||
Due to limitations in `asyncio`, it is no longer possible to
|
||||
use the synchronous ``HTTPClient`` while an `.IOLoop` is running.
|
||||
Use `AsyncHTTPClient` instead.
|
||||
|
||||
"""
|
||||
def __init__(self, async_client_class=None, **kwargs):
|
||||
# Initialize self._closed at the beginning of the constructor
|
||||
# so that an exception raised here doesn't lead to confusing
|
||||
# failures in __del__.
|
||||
self._closed = True
|
||||
self._io_loop = IOLoop(make_current=False)
|
||||
if async_client_class is None:
|
||||
async_client_class = AsyncHTTPClient
|
||||
# Create the client while our IOLoop is "current", without
|
||||
# clobbering the thread's real current IOLoop (if any).
|
||||
self._async_client = self._io_loop.run_sync(
|
||||
gen.coroutine(lambda: async_client_class(**kwargs)))
|
||||
self._closed = False
|
||||
|
||||
def __del__(self):
|
||||
self.close()
|
||||
|
||||
def close(self):
|
||||
"""Closes the HTTPClient, freeing any resources used."""
|
||||
if not self._closed:
|
||||
self._async_client.close()
|
||||
self._io_loop.close()
|
||||
self._closed = True
|
||||
|
||||
def fetch(self, request, **kwargs):
|
||||
"""Executes a request, returning an `HTTPResponse`.
|
||||
|
||||
The request may be either a string URL or an `HTTPRequest` object.
|
||||
If it is a string, we construct an `HTTPRequest` using any additional
|
||||
kwargs: ``HTTPRequest(request, **kwargs)``
|
||||
|
||||
If an error occurs during the fetch, we raise an `HTTPError` unless
|
||||
the ``raise_error`` keyword argument is set to False.
|
||||
"""
|
||||
response = self._io_loop.run_sync(functools.partial(
|
||||
self._async_client.fetch, request, **kwargs))
|
||||
return response
|
||||
|
||||
|
||||
class AsyncHTTPClient(Configurable):
|
||||
"""An non-blocking HTTP client.
|
||||
|
||||
Example usage::
|
||||
|
||||
async def f():
|
||||
http_client = AsyncHTTPClient()
|
||||
try:
|
||||
response = await http_client.fetch("http://www.google.com")
|
||||
except Exception as e:
|
||||
print("Error: %s" % e)
|
||||
else:
|
||||
print(response.body)
|
||||
|
||||
The constructor for this class is magic in several respects: It
|
||||
actually creates an instance of an implementation-specific
|
||||
subclass, and instances are reused as a kind of pseudo-singleton
|
||||
(one per `.IOLoop`). The keyword argument ``force_instance=True``
|
||||
can be used to suppress this singleton behavior. Unless
|
||||
``force_instance=True`` is used, no arguments should be passed to
|
||||
the `AsyncHTTPClient` constructor. The implementation subclass as
|
||||
well as arguments to its constructor can be set with the static
|
||||
method `configure()`
|
||||
|
||||
All `AsyncHTTPClient` implementations support a ``defaults``
|
||||
keyword argument, which can be used to set default values for
|
||||
`HTTPRequest` attributes. For example::
|
||||
|
||||
AsyncHTTPClient.configure(
|
||||
None, defaults=dict(user_agent="MyUserAgent"))
|
||||
# or with force_instance:
|
||||
client = AsyncHTTPClient(force_instance=True,
|
||||
defaults=dict(user_agent="MyUserAgent"))
|
||||
|
||||
.. versionchanged:: 5.0
|
||||
The ``io_loop`` argument (deprecated since version 4.1) has been removed.
|
||||
|
||||
"""
|
||||
@classmethod
|
||||
def configurable_base(cls):
|
||||
return AsyncHTTPClient
|
||||
|
||||
@classmethod
|
||||
def configurable_default(cls):
|
||||
from tornado.simple_httpclient import SimpleAsyncHTTPClient
|
||||
return SimpleAsyncHTTPClient
|
||||
|
||||
@classmethod
|
||||
def _async_clients(cls):
|
||||
attr_name = '_async_client_dict_' + cls.__name__
|
||||
if not hasattr(cls, attr_name):
|
||||
setattr(cls, attr_name, weakref.WeakKeyDictionary())
|
||||
return getattr(cls, attr_name)
|
||||
|
||||
def __new__(cls, force_instance=False, **kwargs):
|
||||
io_loop = IOLoop.current()
|
||||
if force_instance:
|
||||
instance_cache = None
|
||||
else:
|
||||
instance_cache = cls._async_clients()
|
||||
if instance_cache is not None and io_loop in instance_cache:
|
||||
return instance_cache[io_loop]
|
||||
instance = super(AsyncHTTPClient, cls).__new__(cls, **kwargs)
|
||||
# Make sure the instance knows which cache to remove itself from.
|
||||
# It can't simply call _async_clients() because we may be in
|
||||
# __new__(AsyncHTTPClient) but instance.__class__ may be
|
||||
# SimpleAsyncHTTPClient.
|
||||
instance._instance_cache = instance_cache
|
||||
if instance_cache is not None:
|
||||
instance_cache[instance.io_loop] = instance
|
||||
return instance
|
||||
|
||||
def initialize(self, defaults=None):
|
||||
self.io_loop = IOLoop.current()
|
||||
self.defaults = dict(HTTPRequest._DEFAULTS)
|
||||
if defaults is not None:
|
||||
self.defaults.update(defaults)
|
||||
self._closed = False
|
||||
|
||||
def close(self):
|
||||
"""Destroys this HTTP client, freeing any file descriptors used.
|
||||
|
||||
This method is **not needed in normal use** due to the way
|
||||
that `AsyncHTTPClient` objects are transparently reused.
|
||||
``close()`` is generally only necessary when either the
|
||||
`.IOLoop` is also being closed, or the ``force_instance=True``
|
||||
argument was used when creating the `AsyncHTTPClient`.
|
||||
|
||||
No other methods may be called on the `AsyncHTTPClient` after
|
||||
``close()``.
|
||||
|
||||
"""
|
||||
if self._closed:
|
||||
return
|
||||
self._closed = True
|
||||
if self._instance_cache is not None:
|
||||
if self._instance_cache.get(self.io_loop) is not self:
|
||||
raise RuntimeError("inconsistent AsyncHTTPClient cache")
|
||||
del self._instance_cache[self.io_loop]
|
||||
|
||||
def fetch(self, request, callback=None, raise_error=True, **kwargs):
|
||||
"""Executes a request, asynchronously returning an `HTTPResponse`.
|
||||
|
||||
The request may be either a string URL or an `HTTPRequest` object.
|
||||
If it is a string, we construct an `HTTPRequest` using any additional
|
||||
kwargs: ``HTTPRequest(request, **kwargs)``
|
||||
|
||||
This method returns a `.Future` whose result is an
|
||||
`HTTPResponse`. By default, the ``Future`` will raise an
|
||||
`HTTPError` if the request returned a non-200 response code
|
||||
(other errors may also be raised if the server could not be
|
||||
contacted). Instead, if ``raise_error`` is set to False, the
|
||||
response will always be returned regardless of the response
|
||||
code.
|
||||
|
||||
If a ``callback`` is given, it will be invoked with the `HTTPResponse`.
|
||||
In the callback interface, `HTTPError` is not automatically raised.
|
||||
Instead, you must check the response's ``error`` attribute or
|
||||
call its `~HTTPResponse.rethrow` method.
|
||||
|
||||
.. deprecated:: 5.1
|
||||
|
||||
The ``callback`` argument is deprecated and will be removed
|
||||
in 6.0. Use the returned `.Future` instead.
|
||||
|
||||
The ``raise_error=False`` argument currently suppresses
|
||||
*all* errors, encapsulating them in `HTTPResponse` objects
|
||||
with a 599 response code. This will change in Tornado 6.0:
|
||||
``raise_error=False`` will only affect the `HTTPError`
|
||||
raised when a non-200 response code is used.
|
||||
|
||||
"""
|
||||
if self._closed:
|
||||
raise RuntimeError("fetch() called on closed AsyncHTTPClient")
|
||||
if not isinstance(request, HTTPRequest):
|
||||
request = HTTPRequest(url=request, **kwargs)
|
||||
else:
|
||||
if kwargs:
|
||||
raise ValueError("kwargs can't be used if request is an HTTPRequest object")
|
||||
# We may modify this (to add Host, Accept-Encoding, etc),
|
||||
# so make sure we don't modify the caller's object. This is also
|
||||
# where normal dicts get converted to HTTPHeaders objects.
|
||||
request.headers = httputil.HTTPHeaders(request.headers)
|
||||
request = _RequestProxy(request, self.defaults)
|
||||
future = Future()
|
||||
if callback is not None:
|
||||
warnings.warn("callback arguments are deprecated, use the returned Future instead",
|
||||
DeprecationWarning)
|
||||
callback = stack_context.wrap(callback)
|
||||
|
||||
def handle_future(future):
|
||||
exc = future.exception()
|
||||
if isinstance(exc, HTTPError) and exc.response is not None:
|
||||
response = exc.response
|
||||
elif exc is not None:
|
||||
response = HTTPResponse(
|
||||
request, 599, error=exc,
|
||||
request_time=time.time() - request.start_time)
|
||||
else:
|
||||
response = future.result()
|
||||
self.io_loop.add_callback(callback, response)
|
||||
future.add_done_callback(handle_future)
|
||||
|
||||
def handle_response(response):
|
||||
if raise_error and response.error:
|
||||
if isinstance(response.error, HTTPError):
|
||||
response.error.response = response
|
||||
future.set_exception(response.error)
|
||||
else:
|
||||
if response.error and not response._error_is_response_code:
|
||||
warnings.warn("raise_error=False will allow '%s' to be raised in the future" %
|
||||
response.error, DeprecationWarning)
|
||||
future_set_result_unless_cancelled(future, response)
|
||||
self.fetch_impl(request, handle_response)
|
||||
return future
|
||||
|
||||
def fetch_impl(self, request, callback):
|
||||
raise NotImplementedError()
|
||||
|
||||
@classmethod
|
||||
def configure(cls, impl, **kwargs):
|
||||
"""Configures the `AsyncHTTPClient` subclass to use.
|
||||
|
||||
``AsyncHTTPClient()`` actually creates an instance of a subclass.
|
||||
This method may be called with either a class object or the
|
||||
fully-qualified name of such a class (or ``None`` to use the default,
|
||||
``SimpleAsyncHTTPClient``)
|
||||
|
||||
If additional keyword arguments are given, they will be passed
|
||||
to the constructor of each subclass instance created. The
|
||||
keyword argument ``max_clients`` determines the maximum number
|
||||
of simultaneous `~AsyncHTTPClient.fetch()` operations that can
|
||||
execute in parallel on each `.IOLoop`. Additional arguments
|
||||
may be supported depending on the implementation class in use.
|
||||
|
||||
Example::
|
||||
|
||||
AsyncHTTPClient.configure("tornado.curl_httpclient.CurlAsyncHTTPClient")
|
||||
"""
|
||||
super(AsyncHTTPClient, cls).configure(impl, **kwargs)
|
||||
|
||||
|
||||
class HTTPRequest(object):
|
||||
"""HTTP client request object."""
|
||||
|
||||
# Default values for HTTPRequest parameters.
|
||||
# Merged with the values on the request object by AsyncHTTPClient
|
||||
# implementations.
|
||||
_DEFAULTS = dict(
|
||||
connect_timeout=20.0,
|
||||
request_timeout=20.0,
|
||||
follow_redirects=True,
|
||||
max_redirects=5,
|
||||
decompress_response=True,
|
||||
proxy_password='',
|
||||
allow_nonstandard_methods=False,
|
||||
validate_cert=True)
|
||||
|
||||
def __init__(self, url, method="GET", headers=None, body=None,
|
||||
auth_username=None, auth_password=None, auth_mode=None,
|
||||
connect_timeout=None, request_timeout=None,
|
||||
if_modified_since=None, follow_redirects=None,
|
||||
max_redirects=None, user_agent=None, use_gzip=None,
|
||||
network_interface=None, streaming_callback=None,
|
||||
header_callback=None, prepare_curl_callback=None,
|
||||
proxy_host=None, proxy_port=None, proxy_username=None,
|
||||
proxy_password=None, proxy_auth_mode=None,
|
||||
allow_nonstandard_methods=None, validate_cert=None,
|
||||
ca_certs=None, allow_ipv6=None, client_key=None,
|
||||
client_cert=None, body_producer=None,
|
||||
expect_100_continue=False, decompress_response=None,
|
||||
ssl_options=None):
|
||||
r"""All parameters except ``url`` are optional.
|
||||
|
||||
:arg str url: URL to fetch
|
||||
:arg str method: HTTP method, e.g. "GET" or "POST"
|
||||
:arg headers: Additional HTTP headers to pass on the request
|
||||
:type headers: `~tornado.httputil.HTTPHeaders` or `dict`
|
||||
:arg body: HTTP request body as a string (byte or unicode; if unicode
|
||||
the utf-8 encoding will be used)
|
||||
:arg body_producer: Callable used for lazy/asynchronous request bodies.
|
||||
It is called with one argument, a ``write`` function, and should
|
||||
return a `.Future`. It should call the write function with new
|
||||
data as it becomes available. The write function returns a
|
||||
`.Future` which can be used for flow control.
|
||||
Only one of ``body`` and ``body_producer`` may
|
||||
be specified. ``body_producer`` is not supported on
|
||||
``curl_httpclient``. When using ``body_producer`` it is recommended
|
||||
to pass a ``Content-Length`` in the headers as otherwise chunked
|
||||
encoding will be used, and many servers do not support chunked
|
||||
encoding on requests. New in Tornado 4.0
|
||||
:arg str auth_username: Username for HTTP authentication
|
||||
:arg str auth_password: Password for HTTP authentication
|
||||
:arg str auth_mode: Authentication mode; default is "basic".
|
||||
Allowed values are implementation-defined; ``curl_httpclient``
|
||||
supports "basic" and "digest"; ``simple_httpclient`` only supports
|
||||
"basic"
|
||||
:arg float connect_timeout: Timeout for initial connection in seconds,
|
||||
default 20 seconds
|
||||
:arg float request_timeout: Timeout for entire request in seconds,
|
||||
default 20 seconds
|
||||
:arg if_modified_since: Timestamp for ``If-Modified-Since`` header
|
||||
:type if_modified_since: `datetime` or `float`
|
||||
:arg bool follow_redirects: Should redirects be followed automatically
|
||||
or return the 3xx response? Default True.
|
||||
:arg int max_redirects: Limit for ``follow_redirects``, default 5.
|
||||
:arg str user_agent: String to send as ``User-Agent`` header
|
||||
:arg bool decompress_response: Request a compressed response from
|
||||
the server and decompress it after downloading. Default is True.
|
||||
New in Tornado 4.0.
|
||||
:arg bool use_gzip: Deprecated alias for ``decompress_response``
|
||||
since Tornado 4.0.
|
||||
:arg str network_interface: Network interface to use for request.
|
||||
``curl_httpclient`` only; see note below.
|
||||
:arg collections.abc.Callable streaming_callback: If set, ``streaming_callback`` will
|
||||
be run with each chunk of data as it is received, and
|
||||
``HTTPResponse.body`` and ``HTTPResponse.buffer`` will be empty in
|
||||
the final response.
|
||||
:arg collections.abc.Callable header_callback: If set, ``header_callback`` will
|
||||
be run with each header line as it is received (including the
|
||||
first line, e.g. ``HTTP/1.0 200 OK\r\n``, and a final line
|
||||
containing only ``\r\n``. All lines include the trailing newline
|
||||
characters). ``HTTPResponse.headers`` will be empty in the final
|
||||
response. This is most useful in conjunction with
|
||||
``streaming_callback``, because it's the only way to get access to
|
||||
header data while the request is in progress.
|
||||
:arg collections.abc.Callable prepare_curl_callback: If set, will be called with
|
||||
a ``pycurl.Curl`` object to allow the application to make additional
|
||||
``setopt`` calls.
|
||||
:arg str proxy_host: HTTP proxy hostname. To use proxies,
|
||||
``proxy_host`` and ``proxy_port`` must be set; ``proxy_username``,
|
||||
``proxy_pass`` and ``proxy_auth_mode`` are optional. Proxies are
|
||||
currently only supported with ``curl_httpclient``.
|
||||
:arg int proxy_port: HTTP proxy port
|
||||
:arg str proxy_username: HTTP proxy username
|
||||
:arg str proxy_password: HTTP proxy password
|
||||
:arg str proxy_auth_mode: HTTP proxy Authentication mode;
|
||||
default is "basic". supports "basic" and "digest"
|
||||
:arg bool allow_nonstandard_methods: Allow unknown values for ``method``
|
||||
argument? Default is False.
|
||||
:arg bool validate_cert: For HTTPS requests, validate the server's
|
||||
certificate? Default is True.
|
||||
:arg str ca_certs: filename of CA certificates in PEM format,
|
||||
or None to use defaults. See note below when used with
|
||||
``curl_httpclient``.
|
||||
:arg str client_key: Filename for client SSL key, if any. See
|
||||
note below when used with ``curl_httpclient``.
|
||||
:arg str client_cert: Filename for client SSL certificate, if any.
|
||||
See note below when used with ``curl_httpclient``.
|
||||
:arg ssl.SSLContext ssl_options: `ssl.SSLContext` object for use in
|
||||
``simple_httpclient`` (unsupported by ``curl_httpclient``).
|
||||
Overrides ``validate_cert``, ``ca_certs``, ``client_key``,
|
||||
and ``client_cert``.
|
||||
:arg bool allow_ipv6: Use IPv6 when available? Default is true.
|
||||
:arg bool expect_100_continue: If true, send the
|
||||
``Expect: 100-continue`` header and wait for a continue response
|
||||
before sending the request body. Only supported with
|
||||
simple_httpclient.
|
||||
|
||||
.. note::
|
||||
|
||||
When using ``curl_httpclient`` certain options may be
|
||||
inherited by subsequent fetches because ``pycurl`` does
|
||||
not allow them to be cleanly reset. This applies to the
|
||||
``ca_certs``, ``client_key``, ``client_cert``, and
|
||||
``network_interface`` arguments. If you use these
|
||||
options, you should pass them on every request (you don't
|
||||
have to always use the same values, but it's not possible
|
||||
to mix requests that specify these options with ones that
|
||||
use the defaults).
|
||||
|
||||
.. versionadded:: 3.1
|
||||
The ``auth_mode`` argument.
|
||||
|
||||
.. versionadded:: 4.0
|
||||
The ``body_producer`` and ``expect_100_continue`` arguments.
|
||||
|
||||
.. versionadded:: 4.2
|
||||
The ``ssl_options`` argument.
|
||||
|
||||
.. versionadded:: 4.5
|
||||
The ``proxy_auth_mode`` argument.
|
||||
"""
|
||||
# Note that some of these attributes go through property setters
|
||||
# defined below.
|
||||
self.headers = headers
|
||||
if if_modified_since:
|
||||
self.headers["If-Modified-Since"] = httputil.format_timestamp(
|
||||
if_modified_since)
|
||||
self.proxy_host = proxy_host
|
||||
self.proxy_port = proxy_port
|
||||
self.proxy_username = proxy_username
|
||||
self.proxy_password = proxy_password
|
||||
self.proxy_auth_mode = proxy_auth_mode
|
||||
self.url = url
|
||||
self.method = method
|
||||
self.body = body
|
||||
self.body_producer = body_producer
|
||||
self.auth_username = auth_username
|
||||
self.auth_password = auth_password
|
||||
self.auth_mode = auth_mode
|
||||
self.connect_timeout = connect_timeout
|
||||
self.request_timeout = request_timeout
|
||||
self.follow_redirects = follow_redirects
|
||||
self.max_redirects = max_redirects
|
||||
self.user_agent = user_agent
|
||||
if decompress_response is not None:
|
||||
self.decompress_response = decompress_response
|
||||
else:
|
||||
self.decompress_response = use_gzip
|
||||
self.network_interface = network_interface
|
||||
self.streaming_callback = streaming_callback
|
||||
self.header_callback = header_callback
|
||||
self.prepare_curl_callback = prepare_curl_callback
|
||||
self.allow_nonstandard_methods = allow_nonstandard_methods
|
||||
self.validate_cert = validate_cert
|
||||
self.ca_certs = ca_certs
|
||||
self.allow_ipv6 = allow_ipv6
|
||||
self.client_key = client_key
|
||||
self.client_cert = client_cert
|
||||
self.ssl_options = ssl_options
|
||||
self.expect_100_continue = expect_100_continue
|
||||
self.start_time = time.time()
|
||||
|
||||
@property
|
||||
def headers(self):
|
||||
return self._headers
|
||||
|
||||
@headers.setter
|
||||
def headers(self, value):
|
||||
if value is None:
|
||||
self._headers = httputil.HTTPHeaders()
|
||||
else:
|
||||
self._headers = value
|
||||
|
||||
@property
|
||||
def body(self):
|
||||
return self._body
|
||||
|
||||
@body.setter
|
||||
def body(self, value):
|
||||
self._body = utf8(value)
|
||||
|
||||
@property
|
||||
def body_producer(self):
|
||||
return self._body_producer
|
||||
|
||||
@body_producer.setter
|
||||
def body_producer(self, value):
|
||||
self._body_producer = stack_context.wrap(value)
|
||||
|
||||
@property
|
||||
def streaming_callback(self):
|
||||
return self._streaming_callback
|
||||
|
||||
@streaming_callback.setter
|
||||
def streaming_callback(self, value):
|
||||
self._streaming_callback = stack_context.wrap(value)
|
||||
|
||||
@property
|
||||
def header_callback(self):
|
||||
return self._header_callback
|
||||
|
||||
@header_callback.setter
|
||||
def header_callback(self, value):
|
||||
self._header_callback = stack_context.wrap(value)
|
||||
|
||||
@property
|
||||
def prepare_curl_callback(self):
|
||||
return self._prepare_curl_callback
|
||||
|
||||
@prepare_curl_callback.setter
|
||||
def prepare_curl_callback(self, value):
|
||||
self._prepare_curl_callback = stack_context.wrap(value)
|
||||
|
||||
|
||||
class HTTPResponse(object):
|
||||
"""HTTP Response object.
|
||||
|
||||
Attributes:
|
||||
|
||||
* request: HTTPRequest object
|
||||
|
||||
* code: numeric HTTP status code, e.g. 200 or 404
|
||||
|
||||
* reason: human-readable reason phrase describing the status code
|
||||
|
||||
* headers: `tornado.httputil.HTTPHeaders` object
|
||||
|
||||
* effective_url: final location of the resource after following any
|
||||
redirects
|
||||
|
||||
* buffer: ``cStringIO`` object for response body
|
||||
|
||||
* body: response body as bytes (created on demand from ``self.buffer``)
|
||||
|
||||
* error: Exception object, if any
|
||||
|
||||
* request_time: seconds from request start to finish. Includes all network
|
||||
operations from DNS resolution to receiving the last byte of data.
|
||||
Does not include time spent in the queue (due to the ``max_clients`` option).
|
||||
If redirects were followed, only includes the final request.
|
||||
|
||||
* start_time: Time at which the HTTP operation started, based on `time.time`
|
||||
(not the monotonic clock used by `.IOLoop.time`). May be ``None`` if the request
|
||||
timed out while in the queue.
|
||||
|
||||
* time_info: dictionary of diagnostic timing information from the request.
|
||||
Available data are subject to change, but currently uses timings
|
||||
available from http://curl.haxx.se/libcurl/c/curl_easy_getinfo.html,
|
||||
plus ``queue``, which is the delay (if any) introduced by waiting for
|
||||
a slot under `AsyncHTTPClient`'s ``max_clients`` setting.
|
||||
|
||||
.. versionadded:: 5.1
|
||||
|
||||
Added the ``start_time`` attribute.
|
||||
|
||||
.. versionchanged:: 5.1
|
||||
|
||||
The ``request_time`` attribute previously included time spent in the queue
|
||||
for ``simple_httpclient``, but not in ``curl_httpclient``. Now queueing time
|
||||
is excluded in both implementations. ``request_time`` is now more accurate for
|
||||
``curl_httpclient`` because it uses a monotonic clock when available.
|
||||
"""
|
||||
def __init__(self, request, code, headers=None, buffer=None,
|
||||
effective_url=None, error=None, request_time=None,
|
||||
time_info=None, reason=None, start_time=None):
|
||||
if isinstance(request, _RequestProxy):
|
||||
self.request = request.request
|
||||
else:
|
||||
self.request = request
|
||||
self.code = code
|
||||
self.reason = reason or httputil.responses.get(code, "Unknown")
|
||||
if headers is not None:
|
||||
self.headers = headers
|
||||
else:
|
||||
self.headers = httputil.HTTPHeaders()
|
||||
self.buffer = buffer
|
||||
self._body = None
|
||||
if effective_url is None:
|
||||
self.effective_url = request.url
|
||||
else:
|
||||
self.effective_url = effective_url
|
||||
self._error_is_response_code = False
|
||||
if error is None:
|
||||
if self.code < 200 or self.code >= 300:
|
||||
self._error_is_response_code = True
|
||||
self.error = HTTPError(self.code, message=self.reason,
|
||||
response=self)
|
||||
else:
|
||||
self.error = None
|
||||
else:
|
||||
self.error = error
|
||||
self.start_time = start_time
|
||||
self.request_time = request_time
|
||||
self.time_info = time_info or {}
|
||||
|
||||
@property
|
||||
def body(self):
|
||||
if self.buffer is None:
|
||||
return None
|
||||
elif self._body is None:
|
||||
self._body = self.buffer.getvalue()
|
||||
|
||||
return self._body
|
||||
|
||||
def rethrow(self):
|
||||
"""If there was an error on the request, raise an `HTTPError`."""
|
||||
if self.error:
|
||||
raise self.error
|
||||
|
||||
def __repr__(self):
|
||||
args = ",".join("%s=%r" % i for i in sorted(self.__dict__.items()))
|
||||
return "%s(%s)" % (self.__class__.__name__, args)
|
||||
|
||||
|
||||
class HTTPClientError(Exception):
|
||||
"""Exception thrown for an unsuccessful HTTP request.
|
||||
|
||||
Attributes:
|
||||
|
||||
* ``code`` - HTTP error integer error code, e.g. 404. Error code 599 is
|
||||
used when no HTTP response was received, e.g. for a timeout.
|
||||
|
||||
* ``response`` - `HTTPResponse` object, if any.
|
||||
|
||||
Note that if ``follow_redirects`` is False, redirects become HTTPErrors,
|
||||
and you can look at ``error.response.headers['Location']`` to see the
|
||||
destination of the redirect.
|
||||
|
||||
.. versionchanged:: 5.1
|
||||
|
||||
Renamed from ``HTTPError`` to ``HTTPClientError`` to avoid collisions with
|
||||
`tornado.web.HTTPError`. The name ``tornado.httpclient.HTTPError`` remains
|
||||
as an alias.
|
||||
"""
|
||||
def __init__(self, code, message=None, response=None):
|
||||
self.code = code
|
||||
self.message = message or httputil.responses.get(code, "Unknown")
|
||||
self.response = response
|
||||
super(HTTPClientError, self).__init__(code, message, response)
|
||||
|
||||
def __str__(self):
|
||||
return "HTTP %d: %s" % (self.code, self.message)
|
||||
|
||||
# There is a cyclic reference between self and self.response,
|
||||
# which breaks the default __repr__ implementation.
|
||||
# (especially on pypy, which doesn't have the same recursion
|
||||
# detection as cpython).
|
||||
__repr__ = __str__
|
||||
|
||||
|
||||
HTTPError = HTTPClientError
|
||||
|
||||
|
||||
class _RequestProxy(object):
|
||||
"""Combines an object with a dictionary of defaults.
|
||||
|
||||
Used internally by AsyncHTTPClient implementations.
|
||||
"""
|
||||
def __init__(self, request, defaults):
|
||||
self.request = request
|
||||
self.defaults = defaults
|
||||
|
||||
def __getattr__(self, name):
|
||||
request_attr = getattr(self.request, name)
|
||||
if request_attr is not None:
|
||||
return request_attr
|
||||
elif self.defaults is not None:
|
||||
return self.defaults.get(name, None)
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def main():
|
||||
from tornado.options import define, options, parse_command_line
|
||||
define("print_headers", type=bool, default=False)
|
||||
define("print_body", type=bool, default=True)
|
||||
define("follow_redirects", type=bool, default=True)
|
||||
define("validate_cert", type=bool, default=True)
|
||||
define("proxy_host", type=str)
|
||||
define("proxy_port", type=int)
|
||||
args = parse_command_line()
|
||||
client = HTTPClient()
|
||||
for arg in args:
|
||||
try:
|
||||
response = client.fetch(arg,
|
||||
follow_redirects=options.follow_redirects,
|
||||
validate_cert=options.validate_cert,
|
||||
proxy_host=options.proxy_host,
|
||||
proxy_port=options.proxy_port,
|
||||
)
|
||||
except HTTPError as e:
|
||||
if e.response is not None:
|
||||
response = e.response
|
||||
else:
|
||||
raise
|
||||
if options.print_headers:
|
||||
print(response.headers)
|
||||
if options.print_body:
|
||||
print(native_str(response.body))
|
||||
client.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Executable
+330
@@ -0,0 +1,330 @@
|
||||
#
|
||||
# Copyright 2009 Facebook
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""A non-blocking, single-threaded HTTP server.
|
||||
|
||||
Typical applications have little direct interaction with the `HTTPServer`
|
||||
class except to start a server at the beginning of the process
|
||||
(and even that is often done indirectly via `tornado.web.Application.listen`).
|
||||
|
||||
.. versionchanged:: 4.0
|
||||
|
||||
The ``HTTPRequest`` class that used to live in this module has been moved
|
||||
to `tornado.httputil.HTTPServerRequest`. The old name remains as an alias.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import socket
|
||||
|
||||
from tornado.escape import native_str
|
||||
from tornado.http1connection import HTTP1ServerConnection, HTTP1ConnectionParameters
|
||||
from tornado import gen
|
||||
from tornado import httputil
|
||||
from tornado import iostream
|
||||
from tornado import netutil
|
||||
from tornado.tcpserver import TCPServer
|
||||
from tornado.util import Configurable
|
||||
|
||||
|
||||
class HTTPServer(TCPServer, Configurable,
|
||||
httputil.HTTPServerConnectionDelegate):
|
||||
r"""A non-blocking, single-threaded HTTP server.
|
||||
|
||||
A server is defined by a subclass of `.HTTPServerConnectionDelegate`,
|
||||
or, for backwards compatibility, a callback that takes an
|
||||
`.HTTPServerRequest` as an argument. The delegate is usually a
|
||||
`tornado.web.Application`.
|
||||
|
||||
`HTTPServer` supports keep-alive connections by default
|
||||
(automatically for HTTP/1.1, or for HTTP/1.0 when the client
|
||||
requests ``Connection: keep-alive``).
|
||||
|
||||
If ``xheaders`` is ``True``, we support the
|
||||
``X-Real-Ip``/``X-Forwarded-For`` and
|
||||
``X-Scheme``/``X-Forwarded-Proto`` headers, which override the
|
||||
remote IP and URI scheme/protocol for all requests. These headers
|
||||
are useful when running Tornado behind a reverse proxy or load
|
||||
balancer. The ``protocol`` argument can also be set to ``https``
|
||||
if Tornado is run behind an SSL-decoding proxy that does not set one of
|
||||
the supported ``xheaders``.
|
||||
|
||||
By default, when parsing the ``X-Forwarded-For`` header, Tornado will
|
||||
select the last (i.e., the closest) address on the list of hosts as the
|
||||
remote host IP address. To select the next server in the chain, a list of
|
||||
trusted downstream hosts may be passed as the ``trusted_downstream``
|
||||
argument. These hosts will be skipped when parsing the ``X-Forwarded-For``
|
||||
header.
|
||||
|
||||
To make this server serve SSL traffic, send the ``ssl_options`` keyword
|
||||
argument with an `ssl.SSLContext` object. For compatibility with older
|
||||
versions of Python ``ssl_options`` may also be a dictionary of keyword
|
||||
arguments for the `ssl.wrap_socket` method.::
|
||||
|
||||
ssl_ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
|
||||
ssl_ctx.load_cert_chain(os.path.join(data_dir, "mydomain.crt"),
|
||||
os.path.join(data_dir, "mydomain.key"))
|
||||
HTTPServer(application, ssl_options=ssl_ctx)
|
||||
|
||||
`HTTPServer` initialization follows one of three patterns (the
|
||||
initialization methods are defined on `tornado.tcpserver.TCPServer`):
|
||||
|
||||
1. `~tornado.tcpserver.TCPServer.listen`: simple single-process::
|
||||
|
||||
server = HTTPServer(app)
|
||||
server.listen(8888)
|
||||
IOLoop.current().start()
|
||||
|
||||
In many cases, `tornado.web.Application.listen` can be used to avoid
|
||||
the need to explicitly create the `HTTPServer`.
|
||||
|
||||
2. `~tornado.tcpserver.TCPServer.bind`/`~tornado.tcpserver.TCPServer.start`:
|
||||
simple multi-process::
|
||||
|
||||
server = HTTPServer(app)
|
||||
server.bind(8888)
|
||||
server.start(0) # Forks multiple sub-processes
|
||||
IOLoop.current().start()
|
||||
|
||||
When using this interface, an `.IOLoop` must *not* be passed
|
||||
to the `HTTPServer` constructor. `~.TCPServer.start` will always start
|
||||
the server on the default singleton `.IOLoop`.
|
||||
|
||||
3. `~tornado.tcpserver.TCPServer.add_sockets`: advanced multi-process::
|
||||
|
||||
sockets = tornado.netutil.bind_sockets(8888)
|
||||
tornado.process.fork_processes(0)
|
||||
server = HTTPServer(app)
|
||||
server.add_sockets(sockets)
|
||||
IOLoop.current().start()
|
||||
|
||||
The `~.TCPServer.add_sockets` interface is more complicated,
|
||||
but it can be used with `tornado.process.fork_processes` to
|
||||
give you more flexibility in when the fork happens.
|
||||
`~.TCPServer.add_sockets` can also be used in single-process
|
||||
servers if you want to create your listening sockets in some
|
||||
way other than `tornado.netutil.bind_sockets`.
|
||||
|
||||
.. versionchanged:: 4.0
|
||||
Added ``decompress_request``, ``chunk_size``, ``max_header_size``,
|
||||
``idle_connection_timeout``, ``body_timeout``, ``max_body_size``
|
||||
arguments. Added support for `.HTTPServerConnectionDelegate`
|
||||
instances as ``request_callback``.
|
||||
|
||||
.. versionchanged:: 4.1
|
||||
`.HTTPServerConnectionDelegate.start_request` is now called with
|
||||
two arguments ``(server_conn, request_conn)`` (in accordance with the
|
||||
documentation) instead of one ``(request_conn)``.
|
||||
|
||||
.. versionchanged:: 4.2
|
||||
`HTTPServer` is now a subclass of `tornado.util.Configurable`.
|
||||
|
||||
.. versionchanged:: 4.5
|
||||
Added the ``trusted_downstream`` argument.
|
||||
|
||||
.. versionchanged:: 5.0
|
||||
The ``io_loop`` argument has been removed.
|
||||
"""
|
||||
def __init__(self, *args, **kwargs):
|
||||
# Ignore args to __init__; real initialization belongs in
|
||||
# initialize since we're Configurable. (there's something
|
||||
# weird in initialization order between this class,
|
||||
# Configurable, and TCPServer so we can't leave __init__ out
|
||||
# completely)
|
||||
pass
|
||||
|
||||
def initialize(self, request_callback, no_keep_alive=False,
|
||||
xheaders=False, ssl_options=None, protocol=None,
|
||||
decompress_request=False,
|
||||
chunk_size=None, max_header_size=None,
|
||||
idle_connection_timeout=None, body_timeout=None,
|
||||
max_body_size=None, max_buffer_size=None,
|
||||
trusted_downstream=None):
|
||||
self.request_callback = request_callback
|
||||
self.xheaders = xheaders
|
||||
self.protocol = protocol
|
||||
self.conn_params = HTTP1ConnectionParameters(
|
||||
decompress=decompress_request,
|
||||
chunk_size=chunk_size,
|
||||
max_header_size=max_header_size,
|
||||
header_timeout=idle_connection_timeout or 3600,
|
||||
max_body_size=max_body_size,
|
||||
body_timeout=body_timeout,
|
||||
no_keep_alive=no_keep_alive)
|
||||
TCPServer.__init__(self, ssl_options=ssl_options,
|
||||
max_buffer_size=max_buffer_size,
|
||||
read_chunk_size=chunk_size)
|
||||
self._connections = set()
|
||||
self.trusted_downstream = trusted_downstream
|
||||
|
||||
@classmethod
|
||||
def configurable_base(cls):
|
||||
return HTTPServer
|
||||
|
||||
@classmethod
|
||||
def configurable_default(cls):
|
||||
return HTTPServer
|
||||
|
||||
@gen.coroutine
|
||||
def close_all_connections(self):
|
||||
while self._connections:
|
||||
# Peek at an arbitrary element of the set
|
||||
conn = next(iter(self._connections))
|
||||
yield conn.close()
|
||||
|
||||
def handle_stream(self, stream, address):
|
||||
context = _HTTPRequestContext(stream, address,
|
||||
self.protocol,
|
||||
self.trusted_downstream)
|
||||
conn = HTTP1ServerConnection(
|
||||
stream, self.conn_params, context)
|
||||
self._connections.add(conn)
|
||||
conn.start_serving(self)
|
||||
|
||||
def start_request(self, server_conn, request_conn):
|
||||
if isinstance(self.request_callback, httputil.HTTPServerConnectionDelegate):
|
||||
delegate = self.request_callback.start_request(server_conn, request_conn)
|
||||
else:
|
||||
delegate = _CallableAdapter(self.request_callback, request_conn)
|
||||
|
||||
if self.xheaders:
|
||||
delegate = _ProxyAdapter(delegate, request_conn)
|
||||
|
||||
return delegate
|
||||
|
||||
def on_close(self, server_conn):
|
||||
self._connections.remove(server_conn)
|
||||
|
||||
|
||||
class _CallableAdapter(httputil.HTTPMessageDelegate):
|
||||
def __init__(self, request_callback, request_conn):
|
||||
self.connection = request_conn
|
||||
self.request_callback = request_callback
|
||||
self.request = None
|
||||
self.delegate = None
|
||||
self._chunks = []
|
||||
|
||||
def headers_received(self, start_line, headers):
|
||||
self.request = httputil.HTTPServerRequest(
|
||||
connection=self.connection, start_line=start_line,
|
||||
headers=headers)
|
||||
|
||||
def data_received(self, chunk):
|
||||
self._chunks.append(chunk)
|
||||
|
||||
def finish(self):
|
||||
self.request.body = b''.join(self._chunks)
|
||||
self.request._parse_body()
|
||||
self.request_callback(self.request)
|
||||
|
||||
def on_connection_close(self):
|
||||
self._chunks = None
|
||||
|
||||
|
||||
class _HTTPRequestContext(object):
|
||||
def __init__(self, stream, address, protocol, trusted_downstream=None):
|
||||
self.address = address
|
||||
# Save the socket's address family now so we know how to
|
||||
# interpret self.address even after the stream is closed
|
||||
# and its socket attribute replaced with None.
|
||||
if stream.socket is not None:
|
||||
self.address_family = stream.socket.family
|
||||
else:
|
||||
self.address_family = None
|
||||
# In HTTPServerRequest we want an IP, not a full socket address.
|
||||
if (self.address_family in (socket.AF_INET, socket.AF_INET6) and
|
||||
address is not None):
|
||||
self.remote_ip = address[0]
|
||||
else:
|
||||
# Unix (or other) socket; fake the remote address.
|
||||
self.remote_ip = '0.0.0.0'
|
||||
if protocol:
|
||||
self.protocol = protocol
|
||||
elif isinstance(stream, iostream.SSLIOStream):
|
||||
self.protocol = "https"
|
||||
else:
|
||||
self.protocol = "http"
|
||||
self._orig_remote_ip = self.remote_ip
|
||||
self._orig_protocol = self.protocol
|
||||
self.trusted_downstream = set(trusted_downstream or [])
|
||||
|
||||
def __str__(self):
|
||||
if self.address_family in (socket.AF_INET, socket.AF_INET6):
|
||||
return self.remote_ip
|
||||
elif isinstance(self.address, bytes):
|
||||
# Python 3 with the -bb option warns about str(bytes),
|
||||
# so convert it explicitly.
|
||||
# Unix socket addresses are str on mac but bytes on linux.
|
||||
return native_str(self.address)
|
||||
else:
|
||||
return str(self.address)
|
||||
|
||||
def _apply_xheaders(self, headers):
|
||||
"""Rewrite the ``remote_ip`` and ``protocol`` fields."""
|
||||
# Squid uses X-Forwarded-For, others use X-Real-Ip
|
||||
ip = headers.get("X-Forwarded-For", self.remote_ip)
|
||||
# Skip trusted downstream hosts in X-Forwarded-For list
|
||||
for ip in (cand.strip() for cand in reversed(ip.split(','))):
|
||||
if ip not in self.trusted_downstream:
|
||||
break
|
||||
ip = headers.get("X-Real-Ip", ip)
|
||||
if netutil.is_valid_ip(ip):
|
||||
self.remote_ip = ip
|
||||
# AWS uses X-Forwarded-Proto
|
||||
proto_header = headers.get(
|
||||
"X-Scheme", headers.get("X-Forwarded-Proto",
|
||||
self.protocol))
|
||||
if proto_header:
|
||||
# use only the last proto entry if there is more than one
|
||||
# TODO: support trusting mutiple layers of proxied protocol
|
||||
proto_header = proto_header.split(',')[-1].strip()
|
||||
if proto_header in ("http", "https"):
|
||||
self.protocol = proto_header
|
||||
|
||||
def _unapply_xheaders(self):
|
||||
"""Undo changes from `_apply_xheaders`.
|
||||
|
||||
Xheaders are per-request so they should not leak to the next
|
||||
request on the same connection.
|
||||
"""
|
||||
self.remote_ip = self._orig_remote_ip
|
||||
self.protocol = self._orig_protocol
|
||||
|
||||
|
||||
class _ProxyAdapter(httputil.HTTPMessageDelegate):
|
||||
def __init__(self, delegate, request_conn):
|
||||
self.connection = request_conn
|
||||
self.delegate = delegate
|
||||
|
||||
def headers_received(self, start_line, headers):
|
||||
self.connection.context._apply_xheaders(headers)
|
||||
return self.delegate.headers_received(start_line, headers)
|
||||
|
||||
def data_received(self, chunk):
|
||||
return self.delegate.data_received(chunk)
|
||||
|
||||
def finish(self):
|
||||
self.delegate.finish()
|
||||
self._cleanup()
|
||||
|
||||
def on_connection_close(self):
|
||||
self.delegate.on_connection_close()
|
||||
self._cleanup()
|
||||
|
||||
def _cleanup(self):
|
||||
self.connection.context._unapply_xheaders()
|
||||
|
||||
|
||||
HTTPRequest = httputil.HTTPServerRequest
|
||||
Executable
+1095
File diff suppressed because it is too large
Load Diff
Executable
+1267
File diff suppressed because it is too large
Load Diff
Executable
+1757
File diff suppressed because it is too large
Load Diff
Executable
+521
@@ -0,0 +1,521 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Copyright 2009 Facebook
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""Translation methods for generating localized strings.
|
||||
|
||||
To load a locale and generate a translated string::
|
||||
|
||||
user_locale = tornado.locale.get("es_LA")
|
||||
print(user_locale.translate("Sign out"))
|
||||
|
||||
`tornado.locale.get()` returns the closest matching locale, not necessarily the
|
||||
specific locale you requested. You can support pluralization with
|
||||
additional arguments to `~Locale.translate()`, e.g.::
|
||||
|
||||
people = [...]
|
||||
message = user_locale.translate(
|
||||
"%(list)s is online", "%(list)s are online", len(people))
|
||||
print(message % {"list": user_locale.list(people)})
|
||||
|
||||
The first string is chosen if ``len(people) == 1``, otherwise the second
|
||||
string is chosen.
|
||||
|
||||
Applications should call one of `load_translations` (which uses a simple
|
||||
CSV format) or `load_gettext_translations` (which uses the ``.mo`` format
|
||||
supported by `gettext` and related tools). If neither method is called,
|
||||
the `Locale.translate` method will simply return the original string.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import codecs
|
||||
import csv
|
||||
import datetime
|
||||
from io import BytesIO
|
||||
import numbers
|
||||
import os
|
||||
import re
|
||||
|
||||
from tornado import escape
|
||||
from tornado.log import gen_log
|
||||
from tornado.util import PY3
|
||||
|
||||
from tornado._locale_data import LOCALE_NAMES
|
||||
|
||||
_default_locale = "en_US"
|
||||
_translations = {} # type: dict
|
||||
_supported_locales = frozenset([_default_locale])
|
||||
_use_gettext = False
|
||||
CONTEXT_SEPARATOR = "\x04"
|
||||
|
||||
|
||||
def get(*locale_codes):
|
||||
"""Returns the closest match for the given locale codes.
|
||||
|
||||
We iterate over all given locale codes in order. If we have a tight
|
||||
or a loose match for the code (e.g., "en" for "en_US"), we return
|
||||
the locale. Otherwise we move to the next code in the list.
|
||||
|
||||
By default we return ``en_US`` if no translations are found for any of
|
||||
the specified locales. You can change the default locale with
|
||||
`set_default_locale()`.
|
||||
"""
|
||||
return Locale.get_closest(*locale_codes)
|
||||
|
||||
|
||||
def set_default_locale(code):
|
||||
"""Sets the default locale.
|
||||
|
||||
The default locale is assumed to be the language used for all strings
|
||||
in the system. The translations loaded from disk are mappings from
|
||||
the default locale to the destination locale. Consequently, you don't
|
||||
need to create a translation file for the default locale.
|
||||
"""
|
||||
global _default_locale
|
||||
global _supported_locales
|
||||
_default_locale = code
|
||||
_supported_locales = frozenset(list(_translations.keys()) + [_default_locale])
|
||||
|
||||
|
||||
def load_translations(directory, encoding=None):
|
||||
"""Loads translations from CSV files in a directory.
|
||||
|
||||
Translations are strings with optional Python-style named placeholders
|
||||
(e.g., ``My name is %(name)s``) and their associated translations.
|
||||
|
||||
The directory should have translation files of the form ``LOCALE.csv``,
|
||||
e.g. ``es_GT.csv``. The CSV files should have two or three columns: string,
|
||||
translation, and an optional plural indicator. Plural indicators should
|
||||
be one of "plural" or "singular". A given string can have both singular
|
||||
and plural forms. For example ``%(name)s liked this`` may have a
|
||||
different verb conjugation depending on whether %(name)s is one
|
||||
name or a list of names. There should be two rows in the CSV file for
|
||||
that string, one with plural indicator "singular", and one "plural".
|
||||
For strings with no verbs that would change on translation, simply
|
||||
use "unknown" or the empty string (or don't include the column at all).
|
||||
|
||||
The file is read using the `csv` module in the default "excel" dialect.
|
||||
In this format there should not be spaces after the commas.
|
||||
|
||||
If no ``encoding`` parameter is given, the encoding will be
|
||||
detected automatically (among UTF-8 and UTF-16) if the file
|
||||
contains a byte-order marker (BOM), defaulting to UTF-8 if no BOM
|
||||
is present.
|
||||
|
||||
Example translation ``es_LA.csv``::
|
||||
|
||||
"I love you","Te amo"
|
||||
"%(name)s liked this","A %(name)s les gustó esto","plural"
|
||||
"%(name)s liked this","A %(name)s le gustó esto","singular"
|
||||
|
||||
.. versionchanged:: 4.3
|
||||
Added ``encoding`` parameter. Added support for BOM-based encoding
|
||||
detection, UTF-16, and UTF-8-with-BOM.
|
||||
"""
|
||||
global _translations
|
||||
global _supported_locales
|
||||
_translations = {}
|
||||
for path in os.listdir(directory):
|
||||
if not path.endswith(".csv"):
|
||||
continue
|
||||
locale, extension = path.split(".")
|
||||
if not re.match("[a-z]+(_[A-Z]+)?$", locale):
|
||||
gen_log.error("Unrecognized locale %r (path: %s)", locale,
|
||||
os.path.join(directory, path))
|
||||
continue
|
||||
full_path = os.path.join(directory, path)
|
||||
if encoding is None:
|
||||
# Try to autodetect encoding based on the BOM.
|
||||
with open(full_path, 'rb') as f:
|
||||
data = f.read(len(codecs.BOM_UTF16_LE))
|
||||
if data in (codecs.BOM_UTF16_LE, codecs.BOM_UTF16_BE):
|
||||
encoding = 'utf-16'
|
||||
else:
|
||||
# utf-8-sig is "utf-8 with optional BOM". It's discouraged
|
||||
# in most cases but is common with CSV files because Excel
|
||||
# cannot read utf-8 files without a BOM.
|
||||
encoding = 'utf-8-sig'
|
||||
if PY3:
|
||||
# python 3: csv.reader requires a file open in text mode.
|
||||
# Force utf8 to avoid dependence on $LANG environment variable.
|
||||
f = open(full_path, "r", encoding=encoding)
|
||||
else:
|
||||
# python 2: csv can only handle byte strings (in ascii-compatible
|
||||
# encodings), which we decode below. Transcode everything into
|
||||
# utf8 before passing it to csv.reader.
|
||||
f = BytesIO()
|
||||
with codecs.open(full_path, "r", encoding=encoding) as infile:
|
||||
f.write(escape.utf8(infile.read()))
|
||||
f.seek(0)
|
||||
_translations[locale] = {}
|
||||
for i, row in enumerate(csv.reader(f)):
|
||||
if not row or len(row) < 2:
|
||||
continue
|
||||
row = [escape.to_unicode(c).strip() for c in row]
|
||||
english, translation = row[:2]
|
||||
if len(row) > 2:
|
||||
plural = row[2] or "unknown"
|
||||
else:
|
||||
plural = "unknown"
|
||||
if plural not in ("plural", "singular", "unknown"):
|
||||
gen_log.error("Unrecognized plural indicator %r in %s line %d",
|
||||
plural, path, i + 1)
|
||||
continue
|
||||
_translations[locale].setdefault(plural, {})[english] = translation
|
||||
f.close()
|
||||
_supported_locales = frozenset(list(_translations.keys()) + [_default_locale])
|
||||
gen_log.debug("Supported locales: %s", sorted(_supported_locales))
|
||||
|
||||
|
||||
def load_gettext_translations(directory, domain):
|
||||
"""Loads translations from `gettext`'s locale tree
|
||||
|
||||
Locale tree is similar to system's ``/usr/share/locale``, like::
|
||||
|
||||
{directory}/{lang}/LC_MESSAGES/{domain}.mo
|
||||
|
||||
Three steps are required to have your app translated:
|
||||
|
||||
1. Generate POT translation file::
|
||||
|
||||
xgettext --language=Python --keyword=_:1,2 -d mydomain file1.py file2.html etc
|
||||
|
||||
2. Merge against existing POT file::
|
||||
|
||||
msgmerge old.po mydomain.po > new.po
|
||||
|
||||
3. Compile::
|
||||
|
||||
msgfmt mydomain.po -o {directory}/pt_BR/LC_MESSAGES/mydomain.mo
|
||||
"""
|
||||
import gettext
|
||||
global _translations
|
||||
global _supported_locales
|
||||
global _use_gettext
|
||||
_translations = {}
|
||||
for lang in os.listdir(directory):
|
||||
if lang.startswith('.'):
|
||||
continue # skip .svn, etc
|
||||
if os.path.isfile(os.path.join(directory, lang)):
|
||||
continue
|
||||
try:
|
||||
os.stat(os.path.join(directory, lang, "LC_MESSAGES", domain + ".mo"))
|
||||
_translations[lang] = gettext.translation(domain, directory,
|
||||
languages=[lang])
|
||||
except Exception as e:
|
||||
gen_log.error("Cannot load translation for '%s': %s", lang, str(e))
|
||||
continue
|
||||
_supported_locales = frozenset(list(_translations.keys()) + [_default_locale])
|
||||
_use_gettext = True
|
||||
gen_log.debug("Supported locales: %s", sorted(_supported_locales))
|
||||
|
||||
|
||||
def get_supported_locales():
|
||||
"""Returns a list of all the supported locale codes."""
|
||||
return _supported_locales
|
||||
|
||||
|
||||
class Locale(object):
|
||||
"""Object representing a locale.
|
||||
|
||||
After calling one of `load_translations` or `load_gettext_translations`,
|
||||
call `get` or `get_closest` to get a Locale object.
|
||||
"""
|
||||
@classmethod
|
||||
def get_closest(cls, *locale_codes):
|
||||
"""Returns the closest match for the given locale code."""
|
||||
for code in locale_codes:
|
||||
if not code:
|
||||
continue
|
||||
code = code.replace("-", "_")
|
||||
parts = code.split("_")
|
||||
if len(parts) > 2:
|
||||
continue
|
||||
elif len(parts) == 2:
|
||||
code = parts[0].lower() + "_" + parts[1].upper()
|
||||
if code in _supported_locales:
|
||||
return cls.get(code)
|
||||
if parts[0].lower() in _supported_locales:
|
||||
return cls.get(parts[0].lower())
|
||||
return cls.get(_default_locale)
|
||||
|
||||
@classmethod
|
||||
def get(cls, code):
|
||||
"""Returns the Locale for the given locale code.
|
||||
|
||||
If it is not supported, we raise an exception.
|
||||
"""
|
||||
if not hasattr(cls, "_cache"):
|
||||
cls._cache = {}
|
||||
if code not in cls._cache:
|
||||
assert code in _supported_locales
|
||||
translations = _translations.get(code, None)
|
||||
if translations is None:
|
||||
locale = CSVLocale(code, {})
|
||||
elif _use_gettext:
|
||||
locale = GettextLocale(code, translations)
|
||||
else:
|
||||
locale = CSVLocale(code, translations)
|
||||
cls._cache[code] = locale
|
||||
return cls._cache[code]
|
||||
|
||||
def __init__(self, code, translations):
|
||||
self.code = code
|
||||
self.name = LOCALE_NAMES.get(code, {}).get("name", u"Unknown")
|
||||
self.rtl = False
|
||||
for prefix in ["fa", "ar", "he"]:
|
||||
if self.code.startswith(prefix):
|
||||
self.rtl = True
|
||||
break
|
||||
self.translations = translations
|
||||
|
||||
# Initialize strings for date formatting
|
||||
_ = self.translate
|
||||
self._months = [
|
||||
_("January"), _("February"), _("March"), _("April"),
|
||||
_("May"), _("June"), _("July"), _("August"),
|
||||
_("September"), _("October"), _("November"), _("December")]
|
||||
self._weekdays = [
|
||||
_("Monday"), _("Tuesday"), _("Wednesday"), _("Thursday"),
|
||||
_("Friday"), _("Saturday"), _("Sunday")]
|
||||
|
||||
def translate(self, message, plural_message=None, count=None):
|
||||
"""Returns the translation for the given message for this locale.
|
||||
|
||||
If ``plural_message`` is given, you must also provide
|
||||
``count``. We return ``plural_message`` when ``count != 1``,
|
||||
and we return the singular form for the given message when
|
||||
``count == 1``.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def pgettext(self, context, message, plural_message=None, count=None):
|
||||
raise NotImplementedError()
|
||||
|
||||
def format_date(self, date, gmt_offset=0, relative=True, shorter=False,
|
||||
full_format=False):
|
||||
"""Formats the given date (which should be GMT).
|
||||
|
||||
By default, we return a relative time (e.g., "2 minutes ago"). You
|
||||
can return an absolute date string with ``relative=False``.
|
||||
|
||||
You can force a full format date ("July 10, 1980") with
|
||||
``full_format=True``.
|
||||
|
||||
This method is primarily intended for dates in the past.
|
||||
For dates in the future, we fall back to full format.
|
||||
"""
|
||||
if isinstance(date, numbers.Real):
|
||||
date = datetime.datetime.utcfromtimestamp(date)
|
||||
now = datetime.datetime.utcnow()
|
||||
if date > now:
|
||||
if relative and (date - now).seconds < 60:
|
||||
# Due to click skew, things are some things slightly
|
||||
# in the future. Round timestamps in the immediate
|
||||
# future down to now in relative mode.
|
||||
date = now
|
||||
else:
|
||||
# Otherwise, future dates always use the full format.
|
||||
full_format = True
|
||||
local_date = date - datetime.timedelta(minutes=gmt_offset)
|
||||
local_now = now - datetime.timedelta(minutes=gmt_offset)
|
||||
local_yesterday = local_now - datetime.timedelta(hours=24)
|
||||
difference = now - date
|
||||
seconds = difference.seconds
|
||||
days = difference.days
|
||||
|
||||
_ = self.translate
|
||||
format = None
|
||||
if not full_format:
|
||||
if relative and days == 0:
|
||||
if seconds < 50:
|
||||
return _("1 second ago", "%(seconds)d seconds ago",
|
||||
seconds) % {"seconds": seconds}
|
||||
|
||||
if seconds < 50 * 60:
|
||||
minutes = round(seconds / 60.0)
|
||||
return _("1 minute ago", "%(minutes)d minutes ago",
|
||||
minutes) % {"minutes": minutes}
|
||||
|
||||
hours = round(seconds / (60.0 * 60))
|
||||
return _("1 hour ago", "%(hours)d hours ago",
|
||||
hours) % {"hours": hours}
|
||||
|
||||
if days == 0:
|
||||
format = _("%(time)s")
|
||||
elif days == 1 and local_date.day == local_yesterday.day and \
|
||||
relative:
|
||||
format = _("yesterday") if shorter else \
|
||||
_("yesterday at %(time)s")
|
||||
elif days < 5:
|
||||
format = _("%(weekday)s") if shorter else \
|
||||
_("%(weekday)s at %(time)s")
|
||||
elif days < 334: # 11mo, since confusing for same month last year
|
||||
format = _("%(month_name)s %(day)s") if shorter else \
|
||||
_("%(month_name)s %(day)s at %(time)s")
|
||||
|
||||
if format is None:
|
||||
format = _("%(month_name)s %(day)s, %(year)s") if shorter else \
|
||||
_("%(month_name)s %(day)s, %(year)s at %(time)s")
|
||||
|
||||
tfhour_clock = self.code not in ("en", "en_US", "zh_CN")
|
||||
if tfhour_clock:
|
||||
str_time = "%d:%02d" % (local_date.hour, local_date.minute)
|
||||
elif self.code == "zh_CN":
|
||||
str_time = "%s%d:%02d" % (
|
||||
(u'\u4e0a\u5348', u'\u4e0b\u5348')[local_date.hour >= 12],
|
||||
local_date.hour % 12 or 12, local_date.minute)
|
||||
else:
|
||||
str_time = "%d:%02d %s" % (
|
||||
local_date.hour % 12 or 12, local_date.minute,
|
||||
("am", "pm")[local_date.hour >= 12])
|
||||
|
||||
return format % {
|
||||
"month_name": self._months[local_date.month - 1],
|
||||
"weekday": self._weekdays[local_date.weekday()],
|
||||
"day": str(local_date.day),
|
||||
"year": str(local_date.year),
|
||||
"time": str_time
|
||||
}
|
||||
|
||||
def format_day(self, date, gmt_offset=0, dow=True):
|
||||
"""Formats the given date as a day of week.
|
||||
|
||||
Example: "Monday, January 22". You can remove the day of week with
|
||||
``dow=False``.
|
||||
"""
|
||||
local_date = date - datetime.timedelta(minutes=gmt_offset)
|
||||
_ = self.translate
|
||||
if dow:
|
||||
return _("%(weekday)s, %(month_name)s %(day)s") % {
|
||||
"month_name": self._months[local_date.month - 1],
|
||||
"weekday": self._weekdays[local_date.weekday()],
|
||||
"day": str(local_date.day),
|
||||
}
|
||||
else:
|
||||
return _("%(month_name)s %(day)s") % {
|
||||
"month_name": self._months[local_date.month - 1],
|
||||
"day": str(local_date.day),
|
||||
}
|
||||
|
||||
def list(self, parts):
|
||||
"""Returns a comma-separated list for the given list of parts.
|
||||
|
||||
The format is, e.g., "A, B and C", "A and B" or just "A" for lists
|
||||
of size 1.
|
||||
"""
|
||||
_ = self.translate
|
||||
if len(parts) == 0:
|
||||
return ""
|
||||
if len(parts) == 1:
|
||||
return parts[0]
|
||||
comma = u' \u0648 ' if self.code.startswith("fa") else u", "
|
||||
return _("%(commas)s and %(last)s") % {
|
||||
"commas": comma.join(parts[:-1]),
|
||||
"last": parts[len(parts) - 1],
|
||||
}
|
||||
|
||||
def friendly_number(self, value):
|
||||
"""Returns a comma-separated number for the given integer."""
|
||||
if self.code not in ("en", "en_US"):
|
||||
return str(value)
|
||||
value = str(value)
|
||||
parts = []
|
||||
while value:
|
||||
parts.append(value[-3:])
|
||||
value = value[:-3]
|
||||
return ",".join(reversed(parts))
|
||||
|
||||
|
||||
class CSVLocale(Locale):
|
||||
"""Locale implementation using tornado's CSV translation format."""
|
||||
def translate(self, message, plural_message=None, count=None):
|
||||
if plural_message is not None:
|
||||
assert count is not None
|
||||
if count != 1:
|
||||
message = plural_message
|
||||
message_dict = self.translations.get("plural", {})
|
||||
else:
|
||||
message_dict = self.translations.get("singular", {})
|
||||
else:
|
||||
message_dict = self.translations.get("unknown", {})
|
||||
return message_dict.get(message, message)
|
||||
|
||||
def pgettext(self, context, message, plural_message=None, count=None):
|
||||
if self.translations:
|
||||
gen_log.warning('pgettext is not supported by CSVLocale')
|
||||
return self.translate(message, plural_message, count)
|
||||
|
||||
|
||||
class GettextLocale(Locale):
|
||||
"""Locale implementation using the `gettext` module."""
|
||||
def __init__(self, code, translations):
|
||||
try:
|
||||
# python 2
|
||||
self.ngettext = translations.ungettext
|
||||
self.gettext = translations.ugettext
|
||||
except AttributeError:
|
||||
# python 3
|
||||
self.ngettext = translations.ngettext
|
||||
self.gettext = translations.gettext
|
||||
# self.gettext must exist before __init__ is called, since it
|
||||
# calls into self.translate
|
||||
super(GettextLocale, self).__init__(code, translations)
|
||||
|
||||
def translate(self, message, plural_message=None, count=None):
|
||||
if plural_message is not None:
|
||||
assert count is not None
|
||||
return self.ngettext(message, plural_message, count)
|
||||
else:
|
||||
return self.gettext(message)
|
||||
|
||||
def pgettext(self, context, message, plural_message=None, count=None):
|
||||
"""Allows to set context for translation, accepts plural forms.
|
||||
|
||||
Usage example::
|
||||
|
||||
pgettext("law", "right")
|
||||
pgettext("good", "right")
|
||||
|
||||
Plural message example::
|
||||
|
||||
pgettext("organization", "club", "clubs", len(clubs))
|
||||
pgettext("stick", "club", "clubs", len(clubs))
|
||||
|
||||
To generate POT file with context, add following options to step 1
|
||||
of `load_gettext_translations` sequence::
|
||||
|
||||
xgettext [basic options] --keyword=pgettext:1c,2 --keyword=pgettext:1c,2,3
|
||||
|
||||
.. versionadded:: 4.2
|
||||
"""
|
||||
if plural_message is not None:
|
||||
assert count is not None
|
||||
msgs_with_ctxt = ("%s%s%s" % (context, CONTEXT_SEPARATOR, message),
|
||||
"%s%s%s" % (context, CONTEXT_SEPARATOR, plural_message),
|
||||
count)
|
||||
result = self.ngettext(*msgs_with_ctxt)
|
||||
if CONTEXT_SEPARATOR in result:
|
||||
# Translation not found
|
||||
result = self.ngettext(message, plural_message, count)
|
||||
return result
|
||||
else:
|
||||
msg_with_ctxt = "%s%s%s" % (context, CONTEXT_SEPARATOR, message)
|
||||
result = self.gettext(msg_with_ctxt)
|
||||
if CONTEXT_SEPARATOR in result:
|
||||
# Translation not found
|
||||
result = message
|
||||
return result
|
||||
Executable
+526
@@ -0,0 +1,526 @@
|
||||
# Copyright 2015 The Tornado Authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import collections
|
||||
from concurrent.futures import CancelledError
|
||||
|
||||
from tornado import gen, ioloop
|
||||
from tornado.concurrent import Future, future_set_result_unless_cancelled
|
||||
|
||||
__all__ = ['Condition', 'Event', 'Semaphore', 'BoundedSemaphore', 'Lock']
|
||||
|
||||
|
||||
class _TimeoutGarbageCollector(object):
|
||||
"""Base class for objects that periodically clean up timed-out waiters.
|
||||
|
||||
Avoids memory leak in a common pattern like:
|
||||
|
||||
while True:
|
||||
yield condition.wait(short_timeout)
|
||||
print('looping....')
|
||||
"""
|
||||
def __init__(self):
|
||||
self._waiters = collections.deque() # Futures.
|
||||
self._timeouts = 0
|
||||
|
||||
def _garbage_collect(self):
|
||||
# Occasionally clear timed-out waiters.
|
||||
self._timeouts += 1
|
||||
if self._timeouts > 100:
|
||||
self._timeouts = 0
|
||||
self._waiters = collections.deque(
|
||||
w for w in self._waiters if not w.done())
|
||||
|
||||
|
||||
class Condition(_TimeoutGarbageCollector):
|
||||
"""A condition allows one or more coroutines to wait until notified.
|
||||
|
||||
Like a standard `threading.Condition`, but does not need an underlying lock
|
||||
that is acquired and released.
|
||||
|
||||
With a `Condition`, coroutines can wait to be notified by other coroutines:
|
||||
|
||||
.. testcode::
|
||||
|
||||
from tornado import gen
|
||||
from tornado.ioloop import IOLoop
|
||||
from tornado.locks import Condition
|
||||
|
||||
condition = Condition()
|
||||
|
||||
async def waiter():
|
||||
print("I'll wait right here")
|
||||
await condition.wait()
|
||||
print("I'm done waiting")
|
||||
|
||||
async def notifier():
|
||||
print("About to notify")
|
||||
condition.notify()
|
||||
print("Done notifying")
|
||||
|
||||
async def runner():
|
||||
# Wait for waiter() and notifier() in parallel
|
||||
await gen.multi([waiter(), notifier()])
|
||||
|
||||
IOLoop.current().run_sync(runner)
|
||||
|
||||
.. testoutput::
|
||||
|
||||
I'll wait right here
|
||||
About to notify
|
||||
Done notifying
|
||||
I'm done waiting
|
||||
|
||||
`wait` takes an optional ``timeout`` argument, which is either an absolute
|
||||
timestamp::
|
||||
|
||||
io_loop = IOLoop.current()
|
||||
|
||||
# Wait up to 1 second for a notification.
|
||||
await condition.wait(timeout=io_loop.time() + 1)
|
||||
|
||||
...or a `datetime.timedelta` for a timeout relative to the current time::
|
||||
|
||||
# Wait up to 1 second.
|
||||
await condition.wait(timeout=datetime.timedelta(seconds=1))
|
||||
|
||||
The method returns False if there's no notification before the deadline.
|
||||
|
||||
.. versionchanged:: 5.0
|
||||
Previously, waiters could be notified synchronously from within
|
||||
`notify`. Now, the notification will always be received on the
|
||||
next iteration of the `.IOLoop`.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super(Condition, self).__init__()
|
||||
self.io_loop = ioloop.IOLoop.current()
|
||||
|
||||
def __repr__(self):
|
||||
result = '<%s' % (self.__class__.__name__, )
|
||||
if self._waiters:
|
||||
result += ' waiters[%s]' % len(self._waiters)
|
||||
return result + '>'
|
||||
|
||||
def wait(self, timeout=None):
|
||||
"""Wait for `.notify`.
|
||||
|
||||
Returns a `.Future` that resolves ``True`` if the condition is notified,
|
||||
or ``False`` after a timeout.
|
||||
"""
|
||||
waiter = Future()
|
||||
self._waiters.append(waiter)
|
||||
if timeout:
|
||||
def on_timeout():
|
||||
if not waiter.done():
|
||||
future_set_result_unless_cancelled(waiter, False)
|
||||
self._garbage_collect()
|
||||
io_loop = ioloop.IOLoop.current()
|
||||
timeout_handle = io_loop.add_timeout(timeout, on_timeout)
|
||||
waiter.add_done_callback(
|
||||
lambda _: io_loop.remove_timeout(timeout_handle))
|
||||
return waiter
|
||||
|
||||
def notify(self, n=1):
|
||||
"""Wake ``n`` waiters."""
|
||||
waiters = [] # Waiters we plan to run right now.
|
||||
while n and self._waiters:
|
||||
waiter = self._waiters.popleft()
|
||||
if not waiter.done(): # Might have timed out.
|
||||
n -= 1
|
||||
waiters.append(waiter)
|
||||
|
||||
for waiter in waiters:
|
||||
future_set_result_unless_cancelled(waiter, True)
|
||||
|
||||
def notify_all(self):
|
||||
"""Wake all waiters."""
|
||||
self.notify(len(self._waiters))
|
||||
|
||||
|
||||
class Event(object):
|
||||
"""An event blocks coroutines until its internal flag is set to True.
|
||||
|
||||
Similar to `threading.Event`.
|
||||
|
||||
A coroutine can wait for an event to be set. Once it is set, calls to
|
||||
``yield event.wait()`` will not block unless the event has been cleared:
|
||||
|
||||
.. testcode::
|
||||
|
||||
from tornado import gen
|
||||
from tornado.ioloop import IOLoop
|
||||
from tornado.locks import Event
|
||||
|
||||
event = Event()
|
||||
|
||||
async def waiter():
|
||||
print("Waiting for event")
|
||||
await event.wait()
|
||||
print("Not waiting this time")
|
||||
await event.wait()
|
||||
print("Done")
|
||||
|
||||
async def setter():
|
||||
print("About to set the event")
|
||||
event.set()
|
||||
|
||||
async def runner():
|
||||
await gen.multi([waiter(), setter()])
|
||||
|
||||
IOLoop.current().run_sync(runner)
|
||||
|
||||
.. testoutput::
|
||||
|
||||
Waiting for event
|
||||
About to set the event
|
||||
Not waiting this time
|
||||
Done
|
||||
"""
|
||||
def __init__(self):
|
||||
self._value = False
|
||||
self._waiters = set()
|
||||
|
||||
def __repr__(self):
|
||||
return '<%s %s>' % (
|
||||
self.__class__.__name__, 'set' if self.is_set() else 'clear')
|
||||
|
||||
def is_set(self):
|
||||
"""Return ``True`` if the internal flag is true."""
|
||||
return self._value
|
||||
|
||||
def set(self):
|
||||
"""Set the internal flag to ``True``. All waiters are awakened.
|
||||
|
||||
Calling `.wait` once the flag is set will not block.
|
||||
"""
|
||||
if not self._value:
|
||||
self._value = True
|
||||
|
||||
for fut in self._waiters:
|
||||
if not fut.done():
|
||||
fut.set_result(None)
|
||||
|
||||
def clear(self):
|
||||
"""Reset the internal flag to ``False``.
|
||||
|
||||
Calls to `.wait` will block until `.set` is called.
|
||||
"""
|
||||
self._value = False
|
||||
|
||||
def wait(self, timeout=None):
|
||||
"""Block until the internal flag is true.
|
||||
|
||||
Returns a Future, which raises `tornado.util.TimeoutError` after a
|
||||
timeout.
|
||||
"""
|
||||
fut = Future()
|
||||
if self._value:
|
||||
fut.set_result(None)
|
||||
return fut
|
||||
self._waiters.add(fut)
|
||||
fut.add_done_callback(lambda fut: self._waiters.remove(fut))
|
||||
if timeout is None:
|
||||
return fut
|
||||
else:
|
||||
timeout_fut = gen.with_timeout(timeout, fut, quiet_exceptions=(CancelledError,))
|
||||
# This is a slightly clumsy workaround for the fact that
|
||||
# gen.with_timeout doesn't cancel its futures. Cancelling
|
||||
# fut will remove it from the waiters list.
|
||||
timeout_fut.add_done_callback(lambda tf: fut.cancel() if not fut.done() else None)
|
||||
return timeout_fut
|
||||
|
||||
|
||||
class _ReleasingContextManager(object):
|
||||
"""Releases a Lock or Semaphore at the end of a "with" statement.
|
||||
|
||||
with (yield semaphore.acquire()):
|
||||
pass
|
||||
|
||||
# Now semaphore.release() has been called.
|
||||
"""
|
||||
def __init__(self, obj):
|
||||
self._obj = obj
|
||||
|
||||
def __enter__(self):
|
||||
pass
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self._obj.release()
|
||||
|
||||
|
||||
class Semaphore(_TimeoutGarbageCollector):
|
||||
"""A lock that can be acquired a fixed number of times before blocking.
|
||||
|
||||
A Semaphore manages a counter representing the number of `.release` calls
|
||||
minus the number of `.acquire` calls, plus an initial value. The `.acquire`
|
||||
method blocks if necessary until it can return without making the counter
|
||||
negative.
|
||||
|
||||
Semaphores limit access to a shared resource. To allow access for two
|
||||
workers at a time:
|
||||
|
||||
.. testsetup:: semaphore
|
||||
|
||||
from collections import deque
|
||||
|
||||
from tornado import gen
|
||||
from tornado.ioloop import IOLoop
|
||||
from tornado.concurrent import Future
|
||||
|
||||
# Ensure reliable doctest output: resolve Futures one at a time.
|
||||
futures_q = deque([Future() for _ in range(3)])
|
||||
|
||||
async def simulator(futures):
|
||||
for f in futures:
|
||||
# simulate the asynchronous passage of time
|
||||
await gen.sleep(0)
|
||||
await gen.sleep(0)
|
||||
f.set_result(None)
|
||||
|
||||
IOLoop.current().add_callback(simulator, list(futures_q))
|
||||
|
||||
def use_some_resource():
|
||||
return futures_q.popleft()
|
||||
|
||||
.. testcode:: semaphore
|
||||
|
||||
from tornado import gen
|
||||
from tornado.ioloop import IOLoop
|
||||
from tornado.locks import Semaphore
|
||||
|
||||
sem = Semaphore(2)
|
||||
|
||||
async def worker(worker_id):
|
||||
await sem.acquire()
|
||||
try:
|
||||
print("Worker %d is working" % worker_id)
|
||||
await use_some_resource()
|
||||
finally:
|
||||
print("Worker %d is done" % worker_id)
|
||||
sem.release()
|
||||
|
||||
async def runner():
|
||||
# Join all workers.
|
||||
await gen.multi([worker(i) for i in range(3)])
|
||||
|
||||
IOLoop.current().run_sync(runner)
|
||||
|
||||
.. testoutput:: semaphore
|
||||
|
||||
Worker 0 is working
|
||||
Worker 1 is working
|
||||
Worker 0 is done
|
||||
Worker 2 is working
|
||||
Worker 1 is done
|
||||
Worker 2 is done
|
||||
|
||||
Workers 0 and 1 are allowed to run concurrently, but worker 2 waits until
|
||||
the semaphore has been released once, by worker 0.
|
||||
|
||||
The semaphore can be used as an async context manager::
|
||||
|
||||
async def worker(worker_id):
|
||||
async with sem:
|
||||
print("Worker %d is working" % worker_id)
|
||||
await use_some_resource()
|
||||
|
||||
# Now the semaphore has been released.
|
||||
print("Worker %d is done" % worker_id)
|
||||
|
||||
For compatibility with older versions of Python, `.acquire` is a
|
||||
context manager, so ``worker`` could also be written as::
|
||||
|
||||
@gen.coroutine
|
||||
def worker(worker_id):
|
||||
with (yield sem.acquire()):
|
||||
print("Worker %d is working" % worker_id)
|
||||
yield use_some_resource()
|
||||
|
||||
# Now the semaphore has been released.
|
||||
print("Worker %d is done" % worker_id)
|
||||
|
||||
.. versionchanged:: 4.3
|
||||
Added ``async with`` support in Python 3.5.
|
||||
|
||||
"""
|
||||
def __init__(self, value=1):
|
||||
super(Semaphore, self).__init__()
|
||||
if value < 0:
|
||||
raise ValueError('semaphore initial value must be >= 0')
|
||||
|
||||
self._value = value
|
||||
|
||||
def __repr__(self):
|
||||
res = super(Semaphore, self).__repr__()
|
||||
extra = 'locked' if self._value == 0 else 'unlocked,value:{0}'.format(
|
||||
self._value)
|
||||
if self._waiters:
|
||||
extra = '{0},waiters:{1}'.format(extra, len(self._waiters))
|
||||
return '<{0} [{1}]>'.format(res[1:-1], extra)
|
||||
|
||||
def release(self):
|
||||
"""Increment the counter and wake one waiter."""
|
||||
self._value += 1
|
||||
while self._waiters:
|
||||
waiter = self._waiters.popleft()
|
||||
if not waiter.done():
|
||||
self._value -= 1
|
||||
|
||||
# If the waiter is a coroutine paused at
|
||||
#
|
||||
# with (yield semaphore.acquire()):
|
||||
#
|
||||
# then the context manager's __exit__ calls release() at the end
|
||||
# of the "with" block.
|
||||
waiter.set_result(_ReleasingContextManager(self))
|
||||
break
|
||||
|
||||
def acquire(self, timeout=None):
|
||||
"""Decrement the counter. Returns a Future.
|
||||
|
||||
Block if the counter is zero and wait for a `.release`. The Future
|
||||
raises `.TimeoutError` after the deadline.
|
||||
"""
|
||||
waiter = Future()
|
||||
if self._value > 0:
|
||||
self._value -= 1
|
||||
waiter.set_result(_ReleasingContextManager(self))
|
||||
else:
|
||||
self._waiters.append(waiter)
|
||||
if timeout:
|
||||
def on_timeout():
|
||||
if not waiter.done():
|
||||
waiter.set_exception(gen.TimeoutError())
|
||||
self._garbage_collect()
|
||||
io_loop = ioloop.IOLoop.current()
|
||||
timeout_handle = io_loop.add_timeout(timeout, on_timeout)
|
||||
waiter.add_done_callback(
|
||||
lambda _: io_loop.remove_timeout(timeout_handle))
|
||||
return waiter
|
||||
|
||||
def __enter__(self):
|
||||
raise RuntimeError(
|
||||
"Use Semaphore like 'with (yield semaphore.acquire())', not like"
|
||||
" 'with semaphore'")
|
||||
|
||||
__exit__ = __enter__
|
||||
|
||||
@gen.coroutine
|
||||
def __aenter__(self):
|
||||
yield self.acquire()
|
||||
|
||||
@gen.coroutine
|
||||
def __aexit__(self, typ, value, tb):
|
||||
self.release()
|
||||
|
||||
|
||||
class BoundedSemaphore(Semaphore):
|
||||
"""A semaphore that prevents release() being called too many times.
|
||||
|
||||
If `.release` would increment the semaphore's value past the initial
|
||||
value, it raises `ValueError`. Semaphores are mostly used to guard
|
||||
resources with limited capacity, so a semaphore released too many times
|
||||
is a sign of a bug.
|
||||
"""
|
||||
def __init__(self, value=1):
|
||||
super(BoundedSemaphore, self).__init__(value=value)
|
||||
self._initial_value = value
|
||||
|
||||
def release(self):
|
||||
"""Increment the counter and wake one waiter."""
|
||||
if self._value >= self._initial_value:
|
||||
raise ValueError("Semaphore released too many times")
|
||||
super(BoundedSemaphore, self).release()
|
||||
|
||||
|
||||
class Lock(object):
|
||||
"""A lock for coroutines.
|
||||
|
||||
A Lock begins unlocked, and `acquire` locks it immediately. While it is
|
||||
locked, a coroutine that yields `acquire` waits until another coroutine
|
||||
calls `release`.
|
||||
|
||||
Releasing an unlocked lock raises `RuntimeError`.
|
||||
|
||||
A Lock can be used as an async context manager with the ``async
|
||||
with`` statement:
|
||||
|
||||
>>> from tornado import locks
|
||||
>>> lock = locks.Lock()
|
||||
>>>
|
||||
>>> async def f():
|
||||
... async with lock:
|
||||
... # Do something holding the lock.
|
||||
... pass
|
||||
...
|
||||
... # Now the lock is released.
|
||||
|
||||
For compatibility with older versions of Python, the `.acquire`
|
||||
method asynchronously returns a regular context manager:
|
||||
|
||||
>>> async def f2():
|
||||
... with (yield lock.acquire()):
|
||||
... # Do something holding the lock.
|
||||
... pass
|
||||
...
|
||||
... # Now the lock is released.
|
||||
|
||||
.. versionchanged:: 4.3
|
||||
Added ``async with`` support in Python 3.5.
|
||||
|
||||
"""
|
||||
def __init__(self):
|
||||
self._block = BoundedSemaphore(value=1)
|
||||
|
||||
def __repr__(self):
|
||||
return "<%s _block=%s>" % (
|
||||
self.__class__.__name__,
|
||||
self._block)
|
||||
|
||||
def acquire(self, timeout=None):
|
||||
"""Attempt to lock. Returns a Future.
|
||||
|
||||
Returns a Future, which raises `tornado.util.TimeoutError` after a
|
||||
timeout.
|
||||
"""
|
||||
return self._block.acquire(timeout)
|
||||
|
||||
def release(self):
|
||||
"""Unlock.
|
||||
|
||||
The first coroutine in line waiting for `acquire` gets the lock.
|
||||
|
||||
If not locked, raise a `RuntimeError`.
|
||||
"""
|
||||
try:
|
||||
self._block.release()
|
||||
except ValueError:
|
||||
raise RuntimeError('release unlocked lock')
|
||||
|
||||
def __enter__(self):
|
||||
raise RuntimeError(
|
||||
"Use Lock like 'with (yield lock)', not like 'with lock'")
|
||||
|
||||
__exit__ = __enter__
|
||||
|
||||
@gen.coroutine
|
||||
def __aenter__(self):
|
||||
yield self.acquire()
|
||||
|
||||
@gen.coroutine
|
||||
def __aexit__(self, typ, value, tb):
|
||||
self.release()
|
||||
Executable
+290
@@ -0,0 +1,290 @@
|
||||
#
|
||||
# Copyright 2012 Facebook
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""Logging support for Tornado.
|
||||
|
||||
Tornado uses three logger streams:
|
||||
|
||||
* ``tornado.access``: Per-request logging for Tornado's HTTP servers (and
|
||||
potentially other servers in the future)
|
||||
* ``tornado.application``: Logging of errors from application code (i.e.
|
||||
uncaught exceptions from callbacks)
|
||||
* ``tornado.general``: General-purpose logging, including any errors
|
||||
or warnings from Tornado itself.
|
||||
|
||||
These streams may be configured independently using the standard library's
|
||||
`logging` module. For example, you may wish to send ``tornado.access`` logs
|
||||
to a separate file for analysis.
|
||||
"""
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import logging
|
||||
import logging.handlers
|
||||
import sys
|
||||
|
||||
from tornado.escape import _unicode
|
||||
from tornado.util import unicode_type, basestring_type
|
||||
|
||||
try:
|
||||
import colorama
|
||||
except ImportError:
|
||||
colorama = None
|
||||
|
||||
try:
|
||||
import curses # type: ignore
|
||||
except ImportError:
|
||||
curses = None
|
||||
|
||||
# Logger objects for internal tornado use
|
||||
access_log = logging.getLogger("tornado.access")
|
||||
app_log = logging.getLogger("tornado.application")
|
||||
gen_log = logging.getLogger("tornado.general")
|
||||
|
||||
|
||||
def _stderr_supports_color():
|
||||
try:
|
||||
if hasattr(sys.stderr, 'isatty') and sys.stderr.isatty():
|
||||
if curses:
|
||||
curses.setupterm()
|
||||
if curses.tigetnum("colors") > 0:
|
||||
return True
|
||||
elif colorama:
|
||||
if sys.stderr is getattr(colorama.initialise, 'wrapped_stderr',
|
||||
object()):
|
||||
return True
|
||||
except Exception:
|
||||
# Very broad exception handling because it's always better to
|
||||
# fall back to non-colored logs than to break at startup.
|
||||
pass
|
||||
return False
|
||||
|
||||
|
||||
def _safe_unicode(s):
|
||||
try:
|
||||
return _unicode(s)
|
||||
except UnicodeDecodeError:
|
||||
return repr(s)
|
||||
|
||||
|
||||
class LogFormatter(logging.Formatter):
|
||||
"""Log formatter used in Tornado.
|
||||
|
||||
Key features of this formatter are:
|
||||
|
||||
* Color support when logging to a terminal that supports it.
|
||||
* Timestamps on every log line.
|
||||
* Robust against str/bytes encoding problems.
|
||||
|
||||
This formatter is enabled automatically by
|
||||
`tornado.options.parse_command_line` or `tornado.options.parse_config_file`
|
||||
(unless ``--logging=none`` is used).
|
||||
|
||||
Color support on Windows versions that do not support ANSI color codes is
|
||||
enabled by use of the colorama__ library. Applications that wish to use
|
||||
this must first initialize colorama with a call to ``colorama.init``.
|
||||
See the colorama documentation for details.
|
||||
|
||||
__ https://pypi.python.org/pypi/colorama
|
||||
|
||||
.. versionchanged:: 4.5
|
||||
Added support for ``colorama``. Changed the constructor
|
||||
signature to be compatible with `logging.config.dictConfig`.
|
||||
"""
|
||||
DEFAULT_FORMAT = \
|
||||
'%(color)s[%(levelname)1.1s %(asctime)s %(module)s:%(lineno)d]%(end_color)s %(message)s'
|
||||
DEFAULT_DATE_FORMAT = '%y%m%d %H:%M:%S'
|
||||
DEFAULT_COLORS = {
|
||||
logging.DEBUG: 4, # Blue
|
||||
logging.INFO: 2, # Green
|
||||
logging.WARNING: 3, # Yellow
|
||||
logging.ERROR: 1, # Red
|
||||
}
|
||||
|
||||
def __init__(self, fmt=DEFAULT_FORMAT, datefmt=DEFAULT_DATE_FORMAT,
|
||||
style='%', color=True, colors=DEFAULT_COLORS):
|
||||
r"""
|
||||
:arg bool color: Enables color support.
|
||||
:arg str fmt: Log message format.
|
||||
It will be applied to the attributes dict of log records. The
|
||||
text between ``%(color)s`` and ``%(end_color)s`` will be colored
|
||||
depending on the level if color support is on.
|
||||
:arg dict colors: color mappings from logging level to terminal color
|
||||
code
|
||||
:arg str datefmt: Datetime format.
|
||||
Used for formatting ``(asctime)`` placeholder in ``prefix_fmt``.
|
||||
|
||||
.. versionchanged:: 3.2
|
||||
|
||||
Added ``fmt`` and ``datefmt`` arguments.
|
||||
"""
|
||||
logging.Formatter.__init__(self, datefmt=datefmt)
|
||||
self._fmt = fmt
|
||||
|
||||
self._colors = {}
|
||||
if color and _stderr_supports_color():
|
||||
if curses is not None:
|
||||
# The curses module has some str/bytes confusion in
|
||||
# python3. Until version 3.2.3, most methods return
|
||||
# bytes, but only accept strings. In addition, we want to
|
||||
# output these strings with the logging module, which
|
||||
# works with unicode strings. The explicit calls to
|
||||
# unicode() below are harmless in python2 but will do the
|
||||
# right conversion in python 3.
|
||||
fg_color = (curses.tigetstr("setaf") or
|
||||
curses.tigetstr("setf") or "")
|
||||
if (3, 0) < sys.version_info < (3, 2, 3):
|
||||
fg_color = unicode_type(fg_color, "ascii")
|
||||
|
||||
for levelno, code in colors.items():
|
||||
self._colors[levelno] = unicode_type(curses.tparm(fg_color, code), "ascii")
|
||||
self._normal = unicode_type(curses.tigetstr("sgr0"), "ascii")
|
||||
else:
|
||||
# If curses is not present (currently we'll only get here for
|
||||
# colorama on windows), assume hard-coded ANSI color codes.
|
||||
for levelno, code in colors.items():
|
||||
self._colors[levelno] = '\033[2;3%dm' % code
|
||||
self._normal = '\033[0m'
|
||||
else:
|
||||
self._normal = ''
|
||||
|
||||
def format(self, record):
|
||||
try:
|
||||
message = record.getMessage()
|
||||
assert isinstance(message, basestring_type) # guaranteed by logging
|
||||
# Encoding notes: The logging module prefers to work with character
|
||||
# strings, but only enforces that log messages are instances of
|
||||
# basestring. In python 2, non-ascii bytestrings will make
|
||||
# their way through the logging framework until they blow up with
|
||||
# an unhelpful decoding error (with this formatter it happens
|
||||
# when we attach the prefix, but there are other opportunities for
|
||||
# exceptions further along in the framework).
|
||||
#
|
||||
# If a byte string makes it this far, convert it to unicode to
|
||||
# ensure it will make it out to the logs. Use repr() as a fallback
|
||||
# to ensure that all byte strings can be converted successfully,
|
||||
# but don't do it by default so we don't add extra quotes to ascii
|
||||
# bytestrings. This is a bit of a hacky place to do this, but
|
||||
# it's worth it since the encoding errors that would otherwise
|
||||
# result are so useless (and tornado is fond of using utf8-encoded
|
||||
# byte strings wherever possible).
|
||||
record.message = _safe_unicode(message)
|
||||
except Exception as e:
|
||||
record.message = "Bad message (%r): %r" % (e, record.__dict__)
|
||||
|
||||
record.asctime = self.formatTime(record, self.datefmt)
|
||||
|
||||
if record.levelno in self._colors:
|
||||
record.color = self._colors[record.levelno]
|
||||
record.end_color = self._normal
|
||||
else:
|
||||
record.color = record.end_color = ''
|
||||
|
||||
formatted = self._fmt % record.__dict__
|
||||
|
||||
if record.exc_info:
|
||||
if not record.exc_text:
|
||||
record.exc_text = self.formatException(record.exc_info)
|
||||
if record.exc_text:
|
||||
# exc_text contains multiple lines. We need to _safe_unicode
|
||||
# each line separately so that non-utf8 bytes don't cause
|
||||
# all the newlines to turn into '\n'.
|
||||
lines = [formatted.rstrip()]
|
||||
lines.extend(_safe_unicode(ln) for ln in record.exc_text.split('\n'))
|
||||
formatted = '\n'.join(lines)
|
||||
return formatted.replace("\n", "\n ")
|
||||
|
||||
|
||||
def enable_pretty_logging(options=None, logger=None):
|
||||
"""Turns on formatted logging output as configured.
|
||||
|
||||
This is called automatically by `tornado.options.parse_command_line`
|
||||
and `tornado.options.parse_config_file`.
|
||||
"""
|
||||
if options is None:
|
||||
import tornado.options
|
||||
options = tornado.options.options
|
||||
if options.logging is None or options.logging.lower() == 'none':
|
||||
return
|
||||
if logger is None:
|
||||
logger = logging.getLogger()
|
||||
logger.setLevel(getattr(logging, options.logging.upper()))
|
||||
if options.log_file_prefix:
|
||||
rotate_mode = options.log_rotate_mode
|
||||
if rotate_mode == 'size':
|
||||
channel = logging.handlers.RotatingFileHandler(
|
||||
filename=options.log_file_prefix,
|
||||
maxBytes=options.log_file_max_size,
|
||||
backupCount=options.log_file_num_backups)
|
||||
elif rotate_mode == 'time':
|
||||
channel = logging.handlers.TimedRotatingFileHandler(
|
||||
filename=options.log_file_prefix,
|
||||
when=options.log_rotate_when,
|
||||
interval=options.log_rotate_interval,
|
||||
backupCount=options.log_file_num_backups)
|
||||
else:
|
||||
error_message = 'The value of log_rotate_mode option should be ' +\
|
||||
'"size" or "time", not "%s".' % rotate_mode
|
||||
raise ValueError(error_message)
|
||||
channel.setFormatter(LogFormatter(color=False))
|
||||
logger.addHandler(channel)
|
||||
|
||||
if (options.log_to_stderr or
|
||||
(options.log_to_stderr is None and not logger.handlers)):
|
||||
# Set up color if we are in a tty and curses is installed
|
||||
channel = logging.StreamHandler()
|
||||
channel.setFormatter(LogFormatter())
|
||||
logger.addHandler(channel)
|
||||
|
||||
|
||||
def define_logging_options(options=None):
|
||||
"""Add logging-related flags to ``options``.
|
||||
|
||||
These options are present automatically on the default options instance;
|
||||
this method is only necessary if you have created your own `.OptionParser`.
|
||||
|
||||
.. versionadded:: 4.2
|
||||
This function existed in prior versions but was broken and undocumented until 4.2.
|
||||
"""
|
||||
if options is None:
|
||||
# late import to prevent cycle
|
||||
import tornado.options
|
||||
options = tornado.options.options
|
||||
options.define("logging", default="info",
|
||||
help=("Set the Python log level. If 'none', tornado won't touch the "
|
||||
"logging configuration."),
|
||||
metavar="debug|info|warning|error|none")
|
||||
options.define("log_to_stderr", type=bool, default=None,
|
||||
help=("Send log output to stderr (colorized if possible). "
|
||||
"By default use stderr if --log_file_prefix is not set and "
|
||||
"no other logging is configured."))
|
||||
options.define("log_file_prefix", type=str, default=None, metavar="PATH",
|
||||
help=("Path prefix for log files. "
|
||||
"Note that if you are running multiple tornado processes, "
|
||||
"log_file_prefix must be different for each of them (e.g. "
|
||||
"include the port number)"))
|
||||
options.define("log_file_max_size", type=int, default=100 * 1000 * 1000,
|
||||
help="max size of log files before rollover")
|
||||
options.define("log_file_num_backups", type=int, default=10,
|
||||
help="number of log files to keep")
|
||||
|
||||
options.define("log_rotate_when", type=str, default='midnight',
|
||||
help=("specify the type of TimedRotatingFileHandler interval "
|
||||
"other options:('S', 'M', 'H', 'D', 'W0'-'W6')"))
|
||||
options.define("log_rotate_interval", type=int, default=1,
|
||||
help="The interval value of timed rotating")
|
||||
|
||||
options.define("log_rotate_mode", type=str, default='size',
|
||||
help="The mode of rotating files(time or size)")
|
||||
|
||||
options.add_parse_callback(lambda: enable_pretty_logging(options))
|
||||
Executable
+575
@@ -0,0 +1,575 @@
|
||||
#
|
||||
# Copyright 2011 Facebook
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""Miscellaneous network utility code."""
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import errno
|
||||
import os
|
||||
import sys
|
||||
import socket
|
||||
import stat
|
||||
|
||||
from tornado.concurrent import dummy_executor, run_on_executor
|
||||
from tornado import gen
|
||||
from tornado.ioloop import IOLoop
|
||||
from tornado.platform.auto import set_close_exec
|
||||
from tornado.util import PY3, Configurable, errno_from_exception
|
||||
|
||||
try:
|
||||
import ssl
|
||||
except ImportError:
|
||||
# ssl is not available on Google App Engine
|
||||
ssl = None
|
||||
|
||||
if PY3:
|
||||
xrange = range
|
||||
|
||||
if ssl is not None:
|
||||
# Note that the naming of ssl.Purpose is confusing; the purpose
|
||||
# of a context is to authentiate the opposite side of the connection.
|
||||
_client_ssl_defaults = ssl.create_default_context(
|
||||
ssl.Purpose.SERVER_AUTH)
|
||||
_server_ssl_defaults = ssl.create_default_context(
|
||||
ssl.Purpose.CLIENT_AUTH)
|
||||
if hasattr(ssl, 'OP_NO_COMPRESSION'):
|
||||
# See netutil.ssl_options_to_context
|
||||
_client_ssl_defaults.options |= ssl.OP_NO_COMPRESSION
|
||||
_server_ssl_defaults.options |= ssl.OP_NO_COMPRESSION
|
||||
else:
|
||||
# Google App Engine
|
||||
_client_ssl_defaults = dict(cert_reqs=None,
|
||||
ca_certs=None)
|
||||
_server_ssl_defaults = {}
|
||||
|
||||
# ThreadedResolver runs getaddrinfo on a thread. If the hostname is unicode,
|
||||
# getaddrinfo attempts to import encodings.idna. If this is done at
|
||||
# module-import time, the import lock is already held by the main thread,
|
||||
# leading to deadlock. Avoid it by caching the idna encoder on the main
|
||||
# thread now.
|
||||
u'foo'.encode('idna')
|
||||
|
||||
# For undiagnosed reasons, 'latin1' codec may also need to be preloaded.
|
||||
u'foo'.encode('latin1')
|
||||
|
||||
# These errnos indicate that a non-blocking operation must be retried
|
||||
# at a later time. On most platforms they're the same value, but on
|
||||
# some they differ.
|
||||
_ERRNO_WOULDBLOCK = (errno.EWOULDBLOCK, errno.EAGAIN)
|
||||
|
||||
if hasattr(errno, "WSAEWOULDBLOCK"):
|
||||
_ERRNO_WOULDBLOCK += (errno.WSAEWOULDBLOCK,) # type: ignore
|
||||
|
||||
# Default backlog used when calling sock.listen()
|
||||
_DEFAULT_BACKLOG = 128
|
||||
|
||||
|
||||
def bind_sockets(port, address=None, family=socket.AF_UNSPEC,
|
||||
backlog=_DEFAULT_BACKLOG, flags=None, reuse_port=False):
|
||||
"""Creates listening sockets bound to the given port and address.
|
||||
|
||||
Returns a list of socket objects (multiple sockets are returned if
|
||||
the given address maps to multiple IP addresses, which is most common
|
||||
for mixed IPv4 and IPv6 use).
|
||||
|
||||
Address may be either an IP address or hostname. If it's a hostname,
|
||||
the server will listen on all IP addresses associated with the
|
||||
name. Address may be an empty string or None to listen on all
|
||||
available interfaces. Family may be set to either `socket.AF_INET`
|
||||
or `socket.AF_INET6` to restrict to IPv4 or IPv6 addresses, otherwise
|
||||
both will be used if available.
|
||||
|
||||
The ``backlog`` argument has the same meaning as for
|
||||
`socket.listen() <socket.socket.listen>`.
|
||||
|
||||
``flags`` is a bitmask of AI_* flags to `~socket.getaddrinfo`, like
|
||||
``socket.AI_PASSIVE | socket.AI_NUMERICHOST``.
|
||||
|
||||
``reuse_port`` option sets ``SO_REUSEPORT`` option for every socket
|
||||
in the list. If your platform doesn't support this option ValueError will
|
||||
be raised.
|
||||
"""
|
||||
if reuse_port and not hasattr(socket, "SO_REUSEPORT"):
|
||||
raise ValueError("the platform doesn't support SO_REUSEPORT")
|
||||
|
||||
sockets = []
|
||||
if address == "":
|
||||
address = None
|
||||
if not socket.has_ipv6 and family == socket.AF_UNSPEC:
|
||||
# Python can be compiled with --disable-ipv6, which causes
|
||||
# operations on AF_INET6 sockets to fail, but does not
|
||||
# automatically exclude those results from getaddrinfo
|
||||
# results.
|
||||
# http://bugs.python.org/issue16208
|
||||
family = socket.AF_INET
|
||||
if flags is None:
|
||||
flags = socket.AI_PASSIVE
|
||||
bound_port = None
|
||||
for res in set(socket.getaddrinfo(address, port, family, socket.SOCK_STREAM,
|
||||
0, flags)):
|
||||
af, socktype, proto, canonname, sockaddr = res
|
||||
if (sys.platform == 'darwin' and address == 'localhost' and
|
||||
af == socket.AF_INET6 and sockaddr[3] != 0):
|
||||
# Mac OS X includes a link-local address fe80::1%lo0 in the
|
||||
# getaddrinfo results for 'localhost'. However, the firewall
|
||||
# doesn't understand that this is a local address and will
|
||||
# prompt for access (often repeatedly, due to an apparent
|
||||
# bug in its ability to remember granting access to an
|
||||
# application). Skip these addresses.
|
||||
continue
|
||||
try:
|
||||
sock = socket.socket(af, socktype, proto)
|
||||
except socket.error as e:
|
||||
if errno_from_exception(e) == errno.EAFNOSUPPORT:
|
||||
continue
|
||||
raise
|
||||
set_close_exec(sock.fileno())
|
||||
if os.name != 'nt':
|
||||
try:
|
||||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
except socket.error as e:
|
||||
if errno_from_exception(e) != errno.ENOPROTOOPT:
|
||||
# Hurd doesn't support SO_REUSEADDR.
|
||||
raise
|
||||
if reuse_port:
|
||||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
|
||||
if af == socket.AF_INET6:
|
||||
# On linux, ipv6 sockets accept ipv4 too by default,
|
||||
# but this makes it impossible to bind to both
|
||||
# 0.0.0.0 in ipv4 and :: in ipv6. On other systems,
|
||||
# separate sockets *must* be used to listen for both ipv4
|
||||
# and ipv6. For consistency, always disable ipv4 on our
|
||||
# ipv6 sockets and use a separate ipv4 socket when needed.
|
||||
#
|
||||
# Python 2.x on windows doesn't have IPPROTO_IPV6.
|
||||
if hasattr(socket, "IPPROTO_IPV6"):
|
||||
sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 1)
|
||||
|
||||
# automatic port allocation with port=None
|
||||
# should bind on the same port on IPv4 and IPv6
|
||||
host, requested_port = sockaddr[:2]
|
||||
if requested_port == 0 and bound_port is not None:
|
||||
sockaddr = tuple([host, bound_port] + list(sockaddr[2:]))
|
||||
|
||||
sock.setblocking(0)
|
||||
sock.bind(sockaddr)
|
||||
bound_port = sock.getsockname()[1]
|
||||
sock.listen(backlog)
|
||||
sockets.append(sock)
|
||||
return sockets
|
||||
|
||||
|
||||
if hasattr(socket, 'AF_UNIX'):
|
||||
def bind_unix_socket(file, mode=0o600, backlog=_DEFAULT_BACKLOG):
|
||||
"""Creates a listening unix socket.
|
||||
|
||||
If a socket with the given name already exists, it will be deleted.
|
||||
If any other file with that name exists, an exception will be
|
||||
raised.
|
||||
|
||||
Returns a socket object (not a list of socket objects like
|
||||
`bind_sockets`)
|
||||
"""
|
||||
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
|
||||
set_close_exec(sock.fileno())
|
||||
try:
|
||||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
except socket.error as e:
|
||||
if errno_from_exception(e) != errno.ENOPROTOOPT:
|
||||
# Hurd doesn't support SO_REUSEADDR
|
||||
raise
|
||||
sock.setblocking(0)
|
||||
try:
|
||||
st = os.stat(file)
|
||||
except OSError as err:
|
||||
if errno_from_exception(err) != errno.ENOENT:
|
||||
raise
|
||||
else:
|
||||
if stat.S_ISSOCK(st.st_mode):
|
||||
os.remove(file)
|
||||
else:
|
||||
raise ValueError("File %s exists and is not a socket", file)
|
||||
sock.bind(file)
|
||||
os.chmod(file, mode)
|
||||
sock.listen(backlog)
|
||||
return sock
|
||||
|
||||
|
||||
def add_accept_handler(sock, callback):
|
||||
"""Adds an `.IOLoop` event handler to accept new connections on ``sock``.
|
||||
|
||||
When a connection is accepted, ``callback(connection, address)`` will
|
||||
be run (``connection`` is a socket object, and ``address`` is the
|
||||
address of the other end of the connection). Note that this signature
|
||||
is different from the ``callback(fd, events)`` signature used for
|
||||
`.IOLoop` handlers.
|
||||
|
||||
A callable is returned which, when called, will remove the `.IOLoop`
|
||||
event handler and stop processing further incoming connections.
|
||||
|
||||
.. versionchanged:: 5.0
|
||||
The ``io_loop`` argument (deprecated since version 4.1) has been removed.
|
||||
|
||||
.. versionchanged:: 5.0
|
||||
A callable is returned (``None`` was returned before).
|
||||
"""
|
||||
io_loop = IOLoop.current()
|
||||
removed = [False]
|
||||
|
||||
def accept_handler(fd, events):
|
||||
# More connections may come in while we're handling callbacks;
|
||||
# to prevent starvation of other tasks we must limit the number
|
||||
# of connections we accept at a time. Ideally we would accept
|
||||
# up to the number of connections that were waiting when we
|
||||
# entered this method, but this information is not available
|
||||
# (and rearranging this method to call accept() as many times
|
||||
# as possible before running any callbacks would have adverse
|
||||
# effects on load balancing in multiprocess configurations).
|
||||
# Instead, we use the (default) listen backlog as a rough
|
||||
# heuristic for the number of connections we can reasonably
|
||||
# accept at once.
|
||||
for i in xrange(_DEFAULT_BACKLOG):
|
||||
if removed[0]:
|
||||
# The socket was probably closed
|
||||
return
|
||||
try:
|
||||
connection, address = sock.accept()
|
||||
except socket.error as e:
|
||||
# _ERRNO_WOULDBLOCK indicate we have accepted every
|
||||
# connection that is available.
|
||||
if errno_from_exception(e) in _ERRNO_WOULDBLOCK:
|
||||
return
|
||||
# ECONNABORTED indicates that there was a connection
|
||||
# but it was closed while still in the accept queue.
|
||||
# (observed on FreeBSD).
|
||||
if errno_from_exception(e) == errno.ECONNABORTED:
|
||||
continue
|
||||
raise
|
||||
set_close_exec(connection.fileno())
|
||||
callback(connection, address)
|
||||
|
||||
def remove_handler():
|
||||
io_loop.remove_handler(sock)
|
||||
removed[0] = True
|
||||
|
||||
io_loop.add_handler(sock, accept_handler, IOLoop.READ)
|
||||
return remove_handler
|
||||
|
||||
|
||||
def is_valid_ip(ip):
|
||||
"""Returns true if the given string is a well-formed IP address.
|
||||
|
||||
Supports IPv4 and IPv6.
|
||||
"""
|
||||
if not ip or '\x00' in ip:
|
||||
# getaddrinfo resolves empty strings to localhost, and truncates
|
||||
# on zero bytes.
|
||||
return False
|
||||
try:
|
||||
res = socket.getaddrinfo(ip, 0, socket.AF_UNSPEC,
|
||||
socket.SOCK_STREAM,
|
||||
0, socket.AI_NUMERICHOST)
|
||||
return bool(res)
|
||||
except socket.gaierror as e:
|
||||
if e.args[0] == socket.EAI_NONAME:
|
||||
return False
|
||||
raise
|
||||
return True
|
||||
|
||||
|
||||
class Resolver(Configurable):
|
||||
"""Configurable asynchronous DNS resolver interface.
|
||||
|
||||
By default, a blocking implementation is used (which simply calls
|
||||
`socket.getaddrinfo`). An alternative implementation can be
|
||||
chosen with the `Resolver.configure <.Configurable.configure>`
|
||||
class method::
|
||||
|
||||
Resolver.configure('tornado.netutil.ThreadedResolver')
|
||||
|
||||
The implementations of this interface included with Tornado are
|
||||
|
||||
* `tornado.netutil.DefaultExecutorResolver`
|
||||
* `tornado.netutil.BlockingResolver` (deprecated)
|
||||
* `tornado.netutil.ThreadedResolver` (deprecated)
|
||||
* `tornado.netutil.OverrideResolver`
|
||||
* `tornado.platform.twisted.TwistedResolver`
|
||||
* `tornado.platform.caresresolver.CaresResolver`
|
||||
|
||||
.. versionchanged:: 5.0
|
||||
The default implementation has changed from `BlockingResolver` to
|
||||
`DefaultExecutorResolver`.
|
||||
"""
|
||||
@classmethod
|
||||
def configurable_base(cls):
|
||||
return Resolver
|
||||
|
||||
@classmethod
|
||||
def configurable_default(cls):
|
||||
return DefaultExecutorResolver
|
||||
|
||||
def resolve(self, host, port, family=socket.AF_UNSPEC, callback=None):
|
||||
"""Resolves an address.
|
||||
|
||||
The ``host`` argument is a string which may be a hostname or a
|
||||
literal IP address.
|
||||
|
||||
Returns a `.Future` whose result is a list of (family,
|
||||
address) pairs, where address is a tuple suitable to pass to
|
||||
`socket.connect <socket.socket.connect>` (i.e. a ``(host,
|
||||
port)`` pair for IPv4; additional fields may be present for
|
||||
IPv6). If a ``callback`` is passed, it will be run with the
|
||||
result as an argument when it is complete.
|
||||
|
||||
:raises IOError: if the address cannot be resolved.
|
||||
|
||||
.. versionchanged:: 4.4
|
||||
Standardized all implementations to raise `IOError`.
|
||||
|
||||
.. deprecated:: 5.1
|
||||
The ``callback`` argument is deprecated and will be removed in 6.0.
|
||||
Use the returned awaitable object instead.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def close(self):
|
||||
"""Closes the `Resolver`, freeing any resources used.
|
||||
|
||||
.. versionadded:: 3.1
|
||||
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
def _resolve_addr(host, port, family=socket.AF_UNSPEC):
|
||||
# On Solaris, getaddrinfo fails if the given port is not found
|
||||
# in /etc/services and no socket type is given, so we must pass
|
||||
# one here. The socket type used here doesn't seem to actually
|
||||
# matter (we discard the one we get back in the results),
|
||||
# so the addresses we return should still be usable with SOCK_DGRAM.
|
||||
addrinfo = socket.getaddrinfo(host, port, family, socket.SOCK_STREAM)
|
||||
results = []
|
||||
for family, socktype, proto, canonname, address in addrinfo:
|
||||
results.append((family, address))
|
||||
return results
|
||||
|
||||
|
||||
class DefaultExecutorResolver(Resolver):
|
||||
"""Resolver implementation using `.IOLoop.run_in_executor`.
|
||||
|
||||
.. versionadded:: 5.0
|
||||
"""
|
||||
@gen.coroutine
|
||||
def resolve(self, host, port, family=socket.AF_UNSPEC):
|
||||
result = yield IOLoop.current().run_in_executor(
|
||||
None, _resolve_addr, host, port, family)
|
||||
raise gen.Return(result)
|
||||
|
||||
|
||||
class ExecutorResolver(Resolver):
|
||||
"""Resolver implementation using a `concurrent.futures.Executor`.
|
||||
|
||||
Use this instead of `ThreadedResolver` when you require additional
|
||||
control over the executor being used.
|
||||
|
||||
The executor will be shut down when the resolver is closed unless
|
||||
``close_resolver=False``; use this if you want to reuse the same
|
||||
executor elsewhere.
|
||||
|
||||
.. versionchanged:: 5.0
|
||||
The ``io_loop`` argument (deprecated since version 4.1) has been removed.
|
||||
|
||||
.. deprecated:: 5.0
|
||||
The default `Resolver` now uses `.IOLoop.run_in_executor`; use that instead
|
||||
of this class.
|
||||
"""
|
||||
def initialize(self, executor=None, close_executor=True):
|
||||
self.io_loop = IOLoop.current()
|
||||
if executor is not None:
|
||||
self.executor = executor
|
||||
self.close_executor = close_executor
|
||||
else:
|
||||
self.executor = dummy_executor
|
||||
self.close_executor = False
|
||||
|
||||
def close(self):
|
||||
if self.close_executor:
|
||||
self.executor.shutdown()
|
||||
self.executor = None
|
||||
|
||||
@run_on_executor
|
||||
def resolve(self, host, port, family=socket.AF_UNSPEC):
|
||||
return _resolve_addr(host, port, family)
|
||||
|
||||
|
||||
class BlockingResolver(ExecutorResolver):
|
||||
"""Default `Resolver` implementation, using `socket.getaddrinfo`.
|
||||
|
||||
The `.IOLoop` will be blocked during the resolution, although the
|
||||
callback will not be run until the next `.IOLoop` iteration.
|
||||
|
||||
.. deprecated:: 5.0
|
||||
The default `Resolver` now uses `.IOLoop.run_in_executor`; use that instead
|
||||
of this class.
|
||||
"""
|
||||
def initialize(self):
|
||||
super(BlockingResolver, self).initialize()
|
||||
|
||||
|
||||
class ThreadedResolver(ExecutorResolver):
|
||||
"""Multithreaded non-blocking `Resolver` implementation.
|
||||
|
||||
Requires the `concurrent.futures` package to be installed
|
||||
(available in the standard library since Python 3.2,
|
||||
installable with ``pip install futures`` in older versions).
|
||||
|
||||
The thread pool size can be configured with::
|
||||
|
||||
Resolver.configure('tornado.netutil.ThreadedResolver',
|
||||
num_threads=10)
|
||||
|
||||
.. versionchanged:: 3.1
|
||||
All ``ThreadedResolvers`` share a single thread pool, whose
|
||||
size is set by the first one to be created.
|
||||
|
||||
.. deprecated:: 5.0
|
||||
The default `Resolver` now uses `.IOLoop.run_in_executor`; use that instead
|
||||
of this class.
|
||||
"""
|
||||
_threadpool = None # type: ignore
|
||||
_threadpool_pid = None # type: int
|
||||
|
||||
def initialize(self, num_threads=10):
|
||||
threadpool = ThreadedResolver._create_threadpool(num_threads)
|
||||
super(ThreadedResolver, self).initialize(
|
||||
executor=threadpool, close_executor=False)
|
||||
|
||||
@classmethod
|
||||
def _create_threadpool(cls, num_threads):
|
||||
pid = os.getpid()
|
||||
if cls._threadpool_pid != pid:
|
||||
# Threads cannot survive after a fork, so if our pid isn't what it
|
||||
# was when we created the pool then delete it.
|
||||
cls._threadpool = None
|
||||
if cls._threadpool is None:
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
cls._threadpool = ThreadPoolExecutor(num_threads)
|
||||
cls._threadpool_pid = pid
|
||||
return cls._threadpool
|
||||
|
||||
|
||||
class OverrideResolver(Resolver):
|
||||
"""Wraps a resolver with a mapping of overrides.
|
||||
|
||||
This can be used to make local DNS changes (e.g. for testing)
|
||||
without modifying system-wide settings.
|
||||
|
||||
The mapping can be in three formats::
|
||||
|
||||
{
|
||||
# Hostname to host or ip
|
||||
"example.com": "127.0.1.1",
|
||||
|
||||
# Host+port to host+port
|
||||
("login.example.com", 443): ("localhost", 1443),
|
||||
|
||||
# Host+port+address family to host+port
|
||||
("login.example.com", 443, socket.AF_INET6): ("::1", 1443),
|
||||
}
|
||||
|
||||
.. versionchanged:: 5.0
|
||||
Added support for host-port-family triplets.
|
||||
"""
|
||||
def initialize(self, resolver, mapping):
|
||||
self.resolver = resolver
|
||||
self.mapping = mapping
|
||||
|
||||
def close(self):
|
||||
self.resolver.close()
|
||||
|
||||
def resolve(self, host, port, family=socket.AF_UNSPEC, *args, **kwargs):
|
||||
if (host, port, family) in self.mapping:
|
||||
host, port = self.mapping[(host, port, family)]
|
||||
elif (host, port) in self.mapping:
|
||||
host, port = self.mapping[(host, port)]
|
||||
elif host in self.mapping:
|
||||
host = self.mapping[host]
|
||||
return self.resolver.resolve(host, port, family, *args, **kwargs)
|
||||
|
||||
|
||||
# These are the keyword arguments to ssl.wrap_socket that must be translated
|
||||
# to their SSLContext equivalents (the other arguments are still passed
|
||||
# to SSLContext.wrap_socket).
|
||||
_SSL_CONTEXT_KEYWORDS = frozenset(['ssl_version', 'certfile', 'keyfile',
|
||||
'cert_reqs', 'ca_certs', 'ciphers'])
|
||||
|
||||
|
||||
def ssl_options_to_context(ssl_options):
|
||||
"""Try to convert an ``ssl_options`` dictionary to an
|
||||
`~ssl.SSLContext` object.
|
||||
|
||||
The ``ssl_options`` dictionary contains keywords to be passed to
|
||||
`ssl.wrap_socket`. In Python 2.7.9+, `ssl.SSLContext` objects can
|
||||
be used instead. This function converts the dict form to its
|
||||
`~ssl.SSLContext` equivalent, and may be used when a component which
|
||||
accepts both forms needs to upgrade to the `~ssl.SSLContext` version
|
||||
to use features like SNI or NPN.
|
||||
"""
|
||||
if isinstance(ssl_options, ssl.SSLContext):
|
||||
return ssl_options
|
||||
assert isinstance(ssl_options, dict)
|
||||
assert all(k in _SSL_CONTEXT_KEYWORDS for k in ssl_options), ssl_options
|
||||
# Can't use create_default_context since this interface doesn't
|
||||
# tell us client vs server.
|
||||
context = ssl.SSLContext(
|
||||
ssl_options.get('ssl_version', ssl.PROTOCOL_SSLv23))
|
||||
if 'certfile' in ssl_options:
|
||||
context.load_cert_chain(ssl_options['certfile'], ssl_options.get('keyfile', None))
|
||||
if 'cert_reqs' in ssl_options:
|
||||
context.verify_mode = ssl_options['cert_reqs']
|
||||
if 'ca_certs' in ssl_options:
|
||||
context.load_verify_locations(ssl_options['ca_certs'])
|
||||
if 'ciphers' in ssl_options:
|
||||
context.set_ciphers(ssl_options['ciphers'])
|
||||
if hasattr(ssl, 'OP_NO_COMPRESSION'):
|
||||
# Disable TLS compression to avoid CRIME and related attacks.
|
||||
# This constant depends on openssl version 1.0.
|
||||
# TODO: Do we need to do this ourselves or can we trust
|
||||
# the defaults?
|
||||
context.options |= ssl.OP_NO_COMPRESSION
|
||||
return context
|
||||
|
||||
|
||||
def ssl_wrap_socket(socket, ssl_options, server_hostname=None, **kwargs):
|
||||
"""Returns an ``ssl.SSLSocket`` wrapping the given socket.
|
||||
|
||||
``ssl_options`` may be either an `ssl.SSLContext` object or a
|
||||
dictionary (as accepted by `ssl_options_to_context`). Additional
|
||||
keyword arguments are passed to ``wrap_socket`` (either the
|
||||
`~ssl.SSLContext` method or the `ssl` module function as
|
||||
appropriate).
|
||||
"""
|
||||
context = ssl_options_to_context(ssl_options)
|
||||
if ssl.HAS_SNI:
|
||||
# In python 3.4, wrap_socket only accepts the server_hostname
|
||||
# argument if HAS_SNI is true.
|
||||
# TODO: add a unittest (python added server-side SNI support in 3.4)
|
||||
# In the meantime it can be manually tested with
|
||||
# python3 -m tornado.httpclient https://sni.velox.ch
|
||||
return context.wrap_socket(socket, server_hostname=server_hostname,
|
||||
**kwargs)
|
||||
else:
|
||||
return context.wrap_socket(socket, **kwargs)
|
||||
Executable
+654
@@ -0,0 +1,654 @@
|
||||
#
|
||||
# Copyright 2009 Facebook
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""A command line parsing module that lets modules define their own options.
|
||||
|
||||
This module is inspired by Google's `gflags
|
||||
<https://github.com/google/python-gflags>`_. The primary difference
|
||||
with libraries such as `argparse` is that a global registry is used so
|
||||
that options may be defined in any module (it also enables
|
||||
`tornado.log` by default). The rest of Tornado does not depend on this
|
||||
module, so feel free to use `argparse` or other configuration
|
||||
libraries if you prefer them.
|
||||
|
||||
Options must be defined with `tornado.options.define` before use,
|
||||
generally at the top level of a module. The options are then
|
||||
accessible as attributes of `tornado.options.options`::
|
||||
|
||||
# myapp/db.py
|
||||
from tornado.options import define, options
|
||||
|
||||
define("mysql_host", default="127.0.0.1:3306", help="Main user DB")
|
||||
define("memcache_hosts", default="127.0.0.1:11011", multiple=True,
|
||||
help="Main user memcache servers")
|
||||
|
||||
def connect():
|
||||
db = database.Connection(options.mysql_host)
|
||||
...
|
||||
|
||||
# myapp/server.py
|
||||
from tornado.options import define, options
|
||||
|
||||
define("port", default=8080, help="port to listen on")
|
||||
|
||||
def start_server():
|
||||
app = make_app()
|
||||
app.listen(options.port)
|
||||
|
||||
The ``main()`` method of your application does not need to be aware of all of
|
||||
the options used throughout your program; they are all automatically loaded
|
||||
when the modules are loaded. However, all modules that define options
|
||||
must have been imported before the command line is parsed.
|
||||
|
||||
Your ``main()`` method can parse the command line or parse a config file with
|
||||
either `parse_command_line` or `parse_config_file`::
|
||||
|
||||
import myapp.db, myapp.server
|
||||
import tornado.options
|
||||
|
||||
if __name__ == '__main__':
|
||||
tornado.options.parse_command_line()
|
||||
# or
|
||||
tornado.options.parse_config_file("/etc/server.conf")
|
||||
|
||||
.. note::
|
||||
|
||||
When using multiple ``parse_*`` functions, pass ``final=False`` to all
|
||||
but the last one, or side effects may occur twice (in particular,
|
||||
this can result in log messages being doubled).
|
||||
|
||||
`tornado.options.options` is a singleton instance of `OptionParser`, and
|
||||
the top-level functions in this module (`define`, `parse_command_line`, etc)
|
||||
simply call methods on it. You may create additional `OptionParser`
|
||||
instances to define isolated sets of options, such as for subcommands.
|
||||
|
||||
.. note::
|
||||
|
||||
By default, several options are defined that will configure the
|
||||
standard `logging` module when `parse_command_line` or `parse_config_file`
|
||||
are called. If you want Tornado to leave the logging configuration
|
||||
alone so you can manage it yourself, either pass ``--logging=none``
|
||||
on the command line or do the following to disable it in code::
|
||||
|
||||
from tornado.options import options, parse_command_line
|
||||
options.logging = None
|
||||
parse_command_line()
|
||||
|
||||
.. versionchanged:: 4.3
|
||||
Dashes and underscores are fully interchangeable in option names;
|
||||
options can be defined, set, and read with any mix of the two.
|
||||
Dashes are typical for command-line usage while config files require
|
||||
underscores.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import datetime
|
||||
import numbers
|
||||
import re
|
||||
import sys
|
||||
import os
|
||||
import textwrap
|
||||
|
||||
from tornado.escape import _unicode, native_str
|
||||
from tornado.log import define_logging_options
|
||||
from tornado import stack_context
|
||||
from tornado.util import basestring_type, exec_in
|
||||
|
||||
|
||||
class Error(Exception):
|
||||
"""Exception raised by errors in the options module."""
|
||||
pass
|
||||
|
||||
|
||||
class OptionParser(object):
|
||||
"""A collection of options, a dictionary with object-like access.
|
||||
|
||||
Normally accessed via static functions in the `tornado.options` module,
|
||||
which reference a global instance.
|
||||
"""
|
||||
def __init__(self):
|
||||
# we have to use self.__dict__ because we override setattr.
|
||||
self.__dict__['_options'] = {}
|
||||
self.__dict__['_parse_callbacks'] = []
|
||||
self.define("help", type=bool, help="show this help information",
|
||||
callback=self._help_callback)
|
||||
|
||||
def _normalize_name(self, name):
|
||||
return name.replace('_', '-')
|
||||
|
||||
def __getattr__(self, name):
|
||||
name = self._normalize_name(name)
|
||||
if isinstance(self._options.get(name), _Option):
|
||||
return self._options[name].value()
|
||||
raise AttributeError("Unrecognized option %r" % name)
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
name = self._normalize_name(name)
|
||||
if isinstance(self._options.get(name), _Option):
|
||||
return self._options[name].set(value)
|
||||
raise AttributeError("Unrecognized option %r" % name)
|
||||
|
||||
def __iter__(self):
|
||||
return (opt.name for opt in self._options.values())
|
||||
|
||||
def __contains__(self, name):
|
||||
name = self._normalize_name(name)
|
||||
return name in self._options
|
||||
|
||||
def __getitem__(self, name):
|
||||
return self.__getattr__(name)
|
||||
|
||||
def __setitem__(self, name, value):
|
||||
return self.__setattr__(name, value)
|
||||
|
||||
def items(self):
|
||||
"""A sequence of (name, value) pairs.
|
||||
|
||||
.. versionadded:: 3.1
|
||||
"""
|
||||
return [(opt.name, opt.value()) for name, opt in self._options.items()]
|
||||
|
||||
def groups(self):
|
||||
"""The set of option-groups created by ``define``.
|
||||
|
||||
.. versionadded:: 3.1
|
||||
"""
|
||||
return set(opt.group_name for opt in self._options.values())
|
||||
|
||||
def group_dict(self, group):
|
||||
"""The names and values of options in a group.
|
||||
|
||||
Useful for copying options into Application settings::
|
||||
|
||||
from tornado.options import define, parse_command_line, options
|
||||
|
||||
define('template_path', group='application')
|
||||
define('static_path', group='application')
|
||||
|
||||
parse_command_line()
|
||||
|
||||
application = Application(
|
||||
handlers, **options.group_dict('application'))
|
||||
|
||||
.. versionadded:: 3.1
|
||||
"""
|
||||
return dict(
|
||||
(opt.name, opt.value()) for name, opt in self._options.items()
|
||||
if not group or group == opt.group_name)
|
||||
|
||||
def as_dict(self):
|
||||
"""The names and values of all options.
|
||||
|
||||
.. versionadded:: 3.1
|
||||
"""
|
||||
return dict(
|
||||
(opt.name, opt.value()) for name, opt in self._options.items())
|
||||
|
||||
def define(self, name, default=None, type=None, help=None, metavar=None,
|
||||
multiple=False, group=None, callback=None):
|
||||
"""Defines a new command line option.
|
||||
|
||||
``type`` can be any of `str`, `int`, `float`, `bool`,
|
||||
`~datetime.datetime`, or `~datetime.timedelta`. If no ``type``
|
||||
is given but a ``default`` is, ``type`` is the type of
|
||||
``default``. Otherwise, ``type`` defaults to `str`.
|
||||
|
||||
If ``multiple`` is True, the option value is a list of ``type``
|
||||
instead of an instance of ``type``.
|
||||
|
||||
``help`` and ``metavar`` are used to construct the
|
||||
automatically generated command line help string. The help
|
||||
message is formatted like::
|
||||
|
||||
--name=METAVAR help string
|
||||
|
||||
``group`` is used to group the defined options in logical
|
||||
groups. By default, command line options are grouped by the
|
||||
file in which they are defined.
|
||||
|
||||
Command line option names must be unique globally.
|
||||
|
||||
If a ``callback`` is given, it will be run with the new value whenever
|
||||
the option is changed. This can be used to combine command-line
|
||||
and file-based options::
|
||||
|
||||
define("config", type=str, help="path to config file",
|
||||
callback=lambda path: parse_config_file(path, final=False))
|
||||
|
||||
With this definition, options in the file specified by ``--config`` will
|
||||
override options set earlier on the command line, but can be overridden
|
||||
by later flags.
|
||||
|
||||
"""
|
||||
normalized = self._normalize_name(name)
|
||||
if normalized in self._options:
|
||||
raise Error("Option %r already defined in %s" %
|
||||
(normalized, self._options[normalized].file_name))
|
||||
frame = sys._getframe(0)
|
||||
options_file = frame.f_code.co_filename
|
||||
|
||||
# Can be called directly, or through top level define() fn, in which
|
||||
# case, step up above that frame to look for real caller.
|
||||
if (frame.f_back.f_code.co_filename == options_file and
|
||||
frame.f_back.f_code.co_name == 'define'):
|
||||
frame = frame.f_back
|
||||
|
||||
file_name = frame.f_back.f_code.co_filename
|
||||
if file_name == options_file:
|
||||
file_name = ""
|
||||
if type is None:
|
||||
if not multiple and default is not None:
|
||||
type = default.__class__
|
||||
else:
|
||||
type = str
|
||||
if group:
|
||||
group_name = group
|
||||
else:
|
||||
group_name = file_name
|
||||
option = _Option(name, file_name=file_name,
|
||||
default=default, type=type, help=help,
|
||||
metavar=metavar, multiple=multiple,
|
||||
group_name=group_name,
|
||||
callback=callback)
|
||||
self._options[normalized] = option
|
||||
|
||||
def parse_command_line(self, args=None, final=True):
|
||||
"""Parses all options given on the command line (defaults to
|
||||
`sys.argv`).
|
||||
|
||||
Options look like ``--option=value`` and are parsed according
|
||||
to their ``type``. For boolean options, ``--option`` is
|
||||
equivalent to ``--option=true``
|
||||
|
||||
If the option has ``multiple=True``, comma-separated values
|
||||
are accepted. For multi-value integer options, the syntax
|
||||
``x:y`` is also accepted and equivalent to ``range(x, y)``.
|
||||
|
||||
Note that ``args[0]`` is ignored since it is the program name
|
||||
in `sys.argv`.
|
||||
|
||||
We return a list of all arguments that are not parsed as options.
|
||||
|
||||
If ``final`` is ``False``, parse callbacks will not be run.
|
||||
This is useful for applications that wish to combine configurations
|
||||
from multiple sources.
|
||||
|
||||
"""
|
||||
if args is None:
|
||||
args = sys.argv
|
||||
remaining = []
|
||||
for i in range(1, len(args)):
|
||||
# All things after the last option are command line arguments
|
||||
if not args[i].startswith("-"):
|
||||
remaining = args[i:]
|
||||
break
|
||||
if args[i] == "--":
|
||||
remaining = args[i + 1:]
|
||||
break
|
||||
arg = args[i].lstrip("-")
|
||||
name, equals, value = arg.partition("=")
|
||||
name = self._normalize_name(name)
|
||||
if name not in self._options:
|
||||
self.print_help()
|
||||
raise Error('Unrecognized command line option: %r' % name)
|
||||
option = self._options[name]
|
||||
if not equals:
|
||||
if option.type == bool:
|
||||
value = "true"
|
||||
else:
|
||||
raise Error('Option %r requires a value' % name)
|
||||
option.parse(value)
|
||||
|
||||
if final:
|
||||
self.run_parse_callbacks()
|
||||
|
||||
return remaining
|
||||
|
||||
def parse_config_file(self, path, final=True):
|
||||
"""Parses and loads the config file at the given path.
|
||||
|
||||
The config file contains Python code that will be executed (so
|
||||
it is **not safe** to use untrusted config files). Anything in
|
||||
the global namespace that matches a defined option will be
|
||||
used to set that option's value.
|
||||
|
||||
Options may either be the specified type for the option or
|
||||
strings (in which case they will be parsed the same way as in
|
||||
`.parse_command_line`)
|
||||
|
||||
Example (using the options defined in the top-level docs of
|
||||
this module)::
|
||||
|
||||
port = 80
|
||||
mysql_host = 'mydb.example.com:3306'
|
||||
# Both lists and comma-separated strings are allowed for
|
||||
# multiple=True.
|
||||
memcache_hosts = ['cache1.example.com:11011',
|
||||
'cache2.example.com:11011']
|
||||
memcache_hosts = 'cache1.example.com:11011,cache2.example.com:11011'
|
||||
|
||||
If ``final`` is ``False``, parse callbacks will not be run.
|
||||
This is useful for applications that wish to combine configurations
|
||||
from multiple sources.
|
||||
|
||||
.. note::
|
||||
|
||||
`tornado.options` is primarily a command-line library.
|
||||
Config file support is provided for applications that wish
|
||||
to use it, but applications that prefer config files may
|
||||
wish to look at other libraries instead.
|
||||
|
||||
.. versionchanged:: 4.1
|
||||
Config files are now always interpreted as utf-8 instead of
|
||||
the system default encoding.
|
||||
|
||||
.. versionchanged:: 4.4
|
||||
The special variable ``__file__`` is available inside config
|
||||
files, specifying the absolute path to the config file itself.
|
||||
|
||||
.. versionchanged:: 5.1
|
||||
Added the ability to set options via strings in config files.
|
||||
|
||||
"""
|
||||
config = {'__file__': os.path.abspath(path)}
|
||||
with open(path, 'rb') as f:
|
||||
exec_in(native_str(f.read()), config, config)
|
||||
for name in config:
|
||||
normalized = self._normalize_name(name)
|
||||
if normalized in self._options:
|
||||
option = self._options[normalized]
|
||||
if option.multiple:
|
||||
if not isinstance(config[name], (list, str)):
|
||||
raise Error("Option %r is required to be a list of %s "
|
||||
"or a comma-separated string" %
|
||||
(option.name, option.type.__name__))
|
||||
|
||||
if type(config[name]) == str and option.type != str:
|
||||
option.parse(config[name])
|
||||
else:
|
||||
option.set(config[name])
|
||||
|
||||
if final:
|
||||
self.run_parse_callbacks()
|
||||
|
||||
def print_help(self, file=None):
|
||||
"""Prints all the command line options to stderr (or another file)."""
|
||||
if file is None:
|
||||
file = sys.stderr
|
||||
print("Usage: %s [OPTIONS]" % sys.argv[0], file=file)
|
||||
print("\nOptions:\n", file=file)
|
||||
by_group = {}
|
||||
for option in self._options.values():
|
||||
by_group.setdefault(option.group_name, []).append(option)
|
||||
|
||||
for filename, o in sorted(by_group.items()):
|
||||
if filename:
|
||||
print("\n%s options:\n" % os.path.normpath(filename), file=file)
|
||||
o.sort(key=lambda option: option.name)
|
||||
for option in o:
|
||||
# Always print names with dashes in a CLI context.
|
||||
prefix = self._normalize_name(option.name)
|
||||
if option.metavar:
|
||||
prefix += "=" + option.metavar
|
||||
description = option.help or ""
|
||||
if option.default is not None and option.default != '':
|
||||
description += " (default %s)" % option.default
|
||||
lines = textwrap.wrap(description, 79 - 35)
|
||||
if len(prefix) > 30 or len(lines) == 0:
|
||||
lines.insert(0, '')
|
||||
print(" --%-30s %s" % (prefix, lines[0]), file=file)
|
||||
for line in lines[1:]:
|
||||
print("%-34s %s" % (' ', line), file=file)
|
||||
print(file=file)
|
||||
|
||||
def _help_callback(self, value):
|
||||
if value:
|
||||
self.print_help()
|
||||
sys.exit(0)
|
||||
|
||||
def add_parse_callback(self, callback):
|
||||
"""Adds a parse callback, to be invoked when option parsing is done."""
|
||||
self._parse_callbacks.append(stack_context.wrap(callback))
|
||||
|
||||
def run_parse_callbacks(self):
|
||||
for callback in self._parse_callbacks:
|
||||
callback()
|
||||
|
||||
def mockable(self):
|
||||
"""Returns a wrapper around self that is compatible with
|
||||
`mock.patch <unittest.mock.patch>`.
|
||||
|
||||
The `mock.patch <unittest.mock.patch>` function (included in
|
||||
the standard library `unittest.mock` package since Python 3.3,
|
||||
or in the third-party ``mock`` package for older versions of
|
||||
Python) is incompatible with objects like ``options`` that
|
||||
override ``__getattr__`` and ``__setattr__``. This function
|
||||
returns an object that can be used with `mock.patch.object
|
||||
<unittest.mock.patch.object>` to modify option values::
|
||||
|
||||
with mock.patch.object(options.mockable(), 'name', value):
|
||||
assert options.name == value
|
||||
"""
|
||||
return _Mockable(self)
|
||||
|
||||
|
||||
class _Mockable(object):
|
||||
"""`mock.patch` compatible wrapper for `OptionParser`.
|
||||
|
||||
As of ``mock`` version 1.0.1, when an object uses ``__getattr__``
|
||||
hooks instead of ``__dict__``, ``patch.__exit__`` tries to delete
|
||||
the attribute it set instead of setting a new one (assuming that
|
||||
the object does not catpure ``__setattr__``, so the patch
|
||||
created a new attribute in ``__dict__``).
|
||||
|
||||
_Mockable's getattr and setattr pass through to the underlying
|
||||
OptionParser, and delattr undoes the effect of a previous setattr.
|
||||
"""
|
||||
def __init__(self, options):
|
||||
# Modify __dict__ directly to bypass __setattr__
|
||||
self.__dict__['_options'] = options
|
||||
self.__dict__['_originals'] = {}
|
||||
|
||||
def __getattr__(self, name):
|
||||
return getattr(self._options, name)
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
assert name not in self._originals, "don't reuse mockable objects"
|
||||
self._originals[name] = getattr(self._options, name)
|
||||
setattr(self._options, name, value)
|
||||
|
||||
def __delattr__(self, name):
|
||||
setattr(self._options, name, self._originals.pop(name))
|
||||
|
||||
|
||||
class _Option(object):
|
||||
UNSET = object()
|
||||
|
||||
def __init__(self, name, default=None, type=basestring_type, help=None,
|
||||
metavar=None, multiple=False, file_name=None, group_name=None,
|
||||
callback=None):
|
||||
if default is None and multiple:
|
||||
default = []
|
||||
self.name = name
|
||||
self.type = type
|
||||
self.help = help
|
||||
self.metavar = metavar
|
||||
self.multiple = multiple
|
||||
self.file_name = file_name
|
||||
self.group_name = group_name
|
||||
self.callback = callback
|
||||
self.default = default
|
||||
self._value = _Option.UNSET
|
||||
|
||||
def value(self):
|
||||
return self.default if self._value is _Option.UNSET else self._value
|
||||
|
||||
def parse(self, value):
|
||||
_parse = {
|
||||
datetime.datetime: self._parse_datetime,
|
||||
datetime.timedelta: self._parse_timedelta,
|
||||
bool: self._parse_bool,
|
||||
basestring_type: self._parse_string,
|
||||
}.get(self.type, self.type)
|
||||
if self.multiple:
|
||||
self._value = []
|
||||
for part in value.split(","):
|
||||
if issubclass(self.type, numbers.Integral):
|
||||
# allow ranges of the form X:Y (inclusive at both ends)
|
||||
lo, _, hi = part.partition(":")
|
||||
lo = _parse(lo)
|
||||
hi = _parse(hi) if hi else lo
|
||||
self._value.extend(range(lo, hi + 1))
|
||||
else:
|
||||
self._value.append(_parse(part))
|
||||
else:
|
||||
self._value = _parse(value)
|
||||
if self.callback is not None:
|
||||
self.callback(self._value)
|
||||
return self.value()
|
||||
|
||||
def set(self, value):
|
||||
if self.multiple:
|
||||
if not isinstance(value, list):
|
||||
raise Error("Option %r is required to be a list of %s" %
|
||||
(self.name, self.type.__name__))
|
||||
for item in value:
|
||||
if item is not None and not isinstance(item, self.type):
|
||||
raise Error("Option %r is required to be a list of %s" %
|
||||
(self.name, self.type.__name__))
|
||||
else:
|
||||
if value is not None and not isinstance(value, self.type):
|
||||
raise Error("Option %r is required to be a %s (%s given)" %
|
||||
(self.name, self.type.__name__, type(value)))
|
||||
self._value = value
|
||||
if self.callback is not None:
|
||||
self.callback(self._value)
|
||||
|
||||
# Supported date/time formats in our options
|
||||
_DATETIME_FORMATS = [
|
||||
"%a %b %d %H:%M:%S %Y",
|
||||
"%Y-%m-%d %H:%M:%S",
|
||||
"%Y-%m-%d %H:%M",
|
||||
"%Y-%m-%dT%H:%M",
|
||||
"%Y%m%d %H:%M:%S",
|
||||
"%Y%m%d %H:%M",
|
||||
"%Y-%m-%d",
|
||||
"%Y%m%d",
|
||||
"%H:%M:%S",
|
||||
"%H:%M",
|
||||
]
|
||||
|
||||
def _parse_datetime(self, value):
|
||||
for format in self._DATETIME_FORMATS:
|
||||
try:
|
||||
return datetime.datetime.strptime(value, format)
|
||||
except ValueError:
|
||||
pass
|
||||
raise Error('Unrecognized date/time format: %r' % value)
|
||||
|
||||
_TIMEDELTA_ABBREV_DICT = {
|
||||
'h': 'hours',
|
||||
'm': 'minutes',
|
||||
'min': 'minutes',
|
||||
's': 'seconds',
|
||||
'sec': 'seconds',
|
||||
'ms': 'milliseconds',
|
||||
'us': 'microseconds',
|
||||
'd': 'days',
|
||||
'w': 'weeks',
|
||||
}
|
||||
|
||||
_FLOAT_PATTERN = r'[-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?'
|
||||
|
||||
_TIMEDELTA_PATTERN = re.compile(
|
||||
r'\s*(%s)\s*(\w*)\s*' % _FLOAT_PATTERN, re.IGNORECASE)
|
||||
|
||||
def _parse_timedelta(self, value):
|
||||
try:
|
||||
sum = datetime.timedelta()
|
||||
start = 0
|
||||
while start < len(value):
|
||||
m = self._TIMEDELTA_PATTERN.match(value, start)
|
||||
if not m:
|
||||
raise Exception()
|
||||
num = float(m.group(1))
|
||||
units = m.group(2) or 'seconds'
|
||||
units = self._TIMEDELTA_ABBREV_DICT.get(units, units)
|
||||
sum += datetime.timedelta(**{units: num})
|
||||
start = m.end()
|
||||
return sum
|
||||
except Exception:
|
||||
raise
|
||||
|
||||
def _parse_bool(self, value):
|
||||
return value.lower() not in ("false", "0", "f")
|
||||
|
||||
def _parse_string(self, value):
|
||||
return _unicode(value)
|
||||
|
||||
|
||||
options = OptionParser()
|
||||
"""Global options object.
|
||||
|
||||
All defined options are available as attributes on this object.
|
||||
"""
|
||||
|
||||
|
||||
def define(name, default=None, type=None, help=None, metavar=None,
|
||||
multiple=False, group=None, callback=None):
|
||||
"""Defines an option in the global namespace.
|
||||
|
||||
See `OptionParser.define`.
|
||||
"""
|
||||
return options.define(name, default=default, type=type, help=help,
|
||||
metavar=metavar, multiple=multiple, group=group,
|
||||
callback=callback)
|
||||
|
||||
|
||||
def parse_command_line(args=None, final=True):
|
||||
"""Parses global options from the command line.
|
||||
|
||||
See `OptionParser.parse_command_line`.
|
||||
"""
|
||||
return options.parse_command_line(args, final=final)
|
||||
|
||||
|
||||
def parse_config_file(path, final=True):
|
||||
"""Parses global options from a config file.
|
||||
|
||||
See `OptionParser.parse_config_file`.
|
||||
"""
|
||||
return options.parse_config_file(path, final=final)
|
||||
|
||||
|
||||
def print_help(file=None):
|
||||
"""Prints all the command line options to stderr (or another file).
|
||||
|
||||
See `OptionParser.print_help`.
|
||||
"""
|
||||
return options.print_help(file)
|
||||
|
||||
|
||||
def add_parse_callback(callback):
|
||||
"""Adds a parse callback, to be invoked when option parsing is done.
|
||||
|
||||
See `OptionParser.add_parse_callback`
|
||||
"""
|
||||
options.add_parse_callback(callback)
|
||||
|
||||
|
||||
# Default options
|
||||
define_logging_options(options)
|
||||
Executable
Executable
+299
@@ -0,0 +1,299 @@
|
||||
"""Bridges between the `asyncio` module and Tornado IOLoop.
|
||||
|
||||
.. versionadded:: 3.2
|
||||
|
||||
This module integrates Tornado with the ``asyncio`` module introduced
|
||||
in Python 3.4. This makes it possible to combine the two libraries on
|
||||
the same event loop.
|
||||
|
||||
.. deprecated:: 5.0
|
||||
|
||||
While the code in this module is still used, it is now enabled
|
||||
automatically when `asyncio` is available, so applications should
|
||||
no longer need to refer to this module directly.
|
||||
|
||||
.. note::
|
||||
|
||||
Tornado requires the `~asyncio.AbstractEventLoop.add_reader` family of
|
||||
methods, so it is not compatible with the `~asyncio.ProactorEventLoop` on
|
||||
Windows. Use the `~asyncio.SelectorEventLoop` instead.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
import functools
|
||||
|
||||
from tornado.gen import convert_yielded
|
||||
from tornado.ioloop import IOLoop
|
||||
from tornado import stack_context
|
||||
|
||||
import asyncio
|
||||
|
||||
|
||||
class BaseAsyncIOLoop(IOLoop):
|
||||
def initialize(self, asyncio_loop, **kwargs):
|
||||
self.asyncio_loop = asyncio_loop
|
||||
# Maps fd to (fileobj, handler function) pair (as in IOLoop.add_handler)
|
||||
self.handlers = {}
|
||||
# Set of fds listening for reads/writes
|
||||
self.readers = set()
|
||||
self.writers = set()
|
||||
self.closing = False
|
||||
# If an asyncio loop was closed through an asyncio interface
|
||||
# instead of IOLoop.close(), we'd never hear about it and may
|
||||
# have left a dangling reference in our map. In case an
|
||||
# application (or, more likely, a test suite) creates and
|
||||
# destroys a lot of event loops in this way, check here to
|
||||
# ensure that we don't have a lot of dead loops building up in
|
||||
# the map.
|
||||
#
|
||||
# TODO(bdarnell): consider making self.asyncio_loop a weakref
|
||||
# for AsyncIOMainLoop and make _ioloop_for_asyncio a
|
||||
# WeakKeyDictionary.
|
||||
for loop in list(IOLoop._ioloop_for_asyncio):
|
||||
if loop.is_closed():
|
||||
del IOLoop._ioloop_for_asyncio[loop]
|
||||
IOLoop._ioloop_for_asyncio[asyncio_loop] = self
|
||||
super(BaseAsyncIOLoop, self).initialize(**kwargs)
|
||||
|
||||
def close(self, all_fds=False):
|
||||
self.closing = True
|
||||
for fd in list(self.handlers):
|
||||
fileobj, handler_func = self.handlers[fd]
|
||||
self.remove_handler(fd)
|
||||
if all_fds:
|
||||
self.close_fd(fileobj)
|
||||
# Remove the mapping before closing the asyncio loop. If this
|
||||
# happened in the other order, we could race against another
|
||||
# initialize() call which would see the closed asyncio loop,
|
||||
# assume it was closed from the asyncio side, and do this
|
||||
# cleanup for us, leading to a KeyError.
|
||||
del IOLoop._ioloop_for_asyncio[self.asyncio_loop]
|
||||
self.asyncio_loop.close()
|
||||
|
||||
def add_handler(self, fd, handler, events):
|
||||
fd, fileobj = self.split_fd(fd)
|
||||
if fd in self.handlers:
|
||||
raise ValueError("fd %s added twice" % fd)
|
||||
self.handlers[fd] = (fileobj, stack_context.wrap(handler))
|
||||
if events & IOLoop.READ:
|
||||
self.asyncio_loop.add_reader(
|
||||
fd, self._handle_events, fd, IOLoop.READ)
|
||||
self.readers.add(fd)
|
||||
if events & IOLoop.WRITE:
|
||||
self.asyncio_loop.add_writer(
|
||||
fd, self._handle_events, fd, IOLoop.WRITE)
|
||||
self.writers.add(fd)
|
||||
|
||||
def update_handler(self, fd, events):
|
||||
fd, fileobj = self.split_fd(fd)
|
||||
if events & IOLoop.READ:
|
||||
if fd not in self.readers:
|
||||
self.asyncio_loop.add_reader(
|
||||
fd, self._handle_events, fd, IOLoop.READ)
|
||||
self.readers.add(fd)
|
||||
else:
|
||||
if fd in self.readers:
|
||||
self.asyncio_loop.remove_reader(fd)
|
||||
self.readers.remove(fd)
|
||||
if events & IOLoop.WRITE:
|
||||
if fd not in self.writers:
|
||||
self.asyncio_loop.add_writer(
|
||||
fd, self._handle_events, fd, IOLoop.WRITE)
|
||||
self.writers.add(fd)
|
||||
else:
|
||||
if fd in self.writers:
|
||||
self.asyncio_loop.remove_writer(fd)
|
||||
self.writers.remove(fd)
|
||||
|
||||
def remove_handler(self, fd):
|
||||
fd, fileobj = self.split_fd(fd)
|
||||
if fd not in self.handlers:
|
||||
return
|
||||
if fd in self.readers:
|
||||
self.asyncio_loop.remove_reader(fd)
|
||||
self.readers.remove(fd)
|
||||
if fd in self.writers:
|
||||
self.asyncio_loop.remove_writer(fd)
|
||||
self.writers.remove(fd)
|
||||
del self.handlers[fd]
|
||||
|
||||
def _handle_events(self, fd, events):
|
||||
fileobj, handler_func = self.handlers[fd]
|
||||
handler_func(fileobj, events)
|
||||
|
||||
def start(self):
|
||||
try:
|
||||
old_loop = asyncio.get_event_loop()
|
||||
except (RuntimeError, AssertionError):
|
||||
old_loop = None
|
||||
try:
|
||||
self._setup_logging()
|
||||
asyncio.set_event_loop(self.asyncio_loop)
|
||||
self.asyncio_loop.run_forever()
|
||||
finally:
|
||||
asyncio.set_event_loop(old_loop)
|
||||
|
||||
def stop(self):
|
||||
self.asyncio_loop.stop()
|
||||
|
||||
def call_at(self, when, callback, *args, **kwargs):
|
||||
# asyncio.call_at supports *args but not **kwargs, so bind them here.
|
||||
# We do not synchronize self.time and asyncio_loop.time, so
|
||||
# convert from absolute to relative.
|
||||
return self.asyncio_loop.call_later(
|
||||
max(0, when - self.time()), self._run_callback,
|
||||
functools.partial(stack_context.wrap(callback), *args, **kwargs))
|
||||
|
||||
def remove_timeout(self, timeout):
|
||||
timeout.cancel()
|
||||
|
||||
def add_callback(self, callback, *args, **kwargs):
|
||||
try:
|
||||
self.asyncio_loop.call_soon_threadsafe(
|
||||
self._run_callback,
|
||||
functools.partial(stack_context.wrap(callback), *args, **kwargs))
|
||||
except RuntimeError:
|
||||
# "Event loop is closed". Swallow the exception for
|
||||
# consistency with PollIOLoop (and logical consistency
|
||||
# with the fact that we can't guarantee that an
|
||||
# add_callback that completes without error will
|
||||
# eventually execute).
|
||||
pass
|
||||
|
||||
add_callback_from_signal = add_callback
|
||||
|
||||
def run_in_executor(self, executor, func, *args):
|
||||
return self.asyncio_loop.run_in_executor(executor, func, *args)
|
||||
|
||||
def set_default_executor(self, executor):
|
||||
return self.asyncio_loop.set_default_executor(executor)
|
||||
|
||||
|
||||
class AsyncIOMainLoop(BaseAsyncIOLoop):
|
||||
"""``AsyncIOMainLoop`` creates an `.IOLoop` that corresponds to the
|
||||
current ``asyncio`` event loop (i.e. the one returned by
|
||||
``asyncio.get_event_loop()``).
|
||||
|
||||
.. deprecated:: 5.0
|
||||
|
||||
Now used automatically when appropriate; it is no longer necessary
|
||||
to refer to this class directly.
|
||||
|
||||
.. versionchanged:: 5.0
|
||||
|
||||
Closing an `AsyncIOMainLoop` now closes the underlying asyncio loop.
|
||||
"""
|
||||
def initialize(self, **kwargs):
|
||||
super(AsyncIOMainLoop, self).initialize(asyncio.get_event_loop(), **kwargs)
|
||||
|
||||
def make_current(self):
|
||||
# AsyncIOMainLoop already refers to the current asyncio loop so
|
||||
# nothing to do here.
|
||||
pass
|
||||
|
||||
|
||||
class AsyncIOLoop(BaseAsyncIOLoop):
|
||||
"""``AsyncIOLoop`` is an `.IOLoop` that runs on an ``asyncio`` event loop.
|
||||
This class follows the usual Tornado semantics for creating new
|
||||
``IOLoops``; these loops are not necessarily related to the
|
||||
``asyncio`` default event loop.
|
||||
|
||||
Each ``AsyncIOLoop`` creates a new ``asyncio.EventLoop``; this object
|
||||
can be accessed with the ``asyncio_loop`` attribute.
|
||||
|
||||
.. versionchanged:: 5.0
|
||||
|
||||
When an ``AsyncIOLoop`` becomes the current `.IOLoop`, it also sets
|
||||
the current `asyncio` event loop.
|
||||
|
||||
.. deprecated:: 5.0
|
||||
|
||||
Now used automatically when appropriate; it is no longer necessary
|
||||
to refer to this class directly.
|
||||
"""
|
||||
def initialize(self, **kwargs):
|
||||
self.is_current = False
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
super(AsyncIOLoop, self).initialize(loop, **kwargs)
|
||||
except Exception:
|
||||
# If initialize() does not succeed (taking ownership of the loop),
|
||||
# we have to close it.
|
||||
loop.close()
|
||||
raise
|
||||
|
||||
def close(self, all_fds=False):
|
||||
if self.is_current:
|
||||
self.clear_current()
|
||||
super(AsyncIOLoop, self).close(all_fds=all_fds)
|
||||
|
||||
def make_current(self):
|
||||
if not self.is_current:
|
||||
try:
|
||||
self.old_asyncio = asyncio.get_event_loop()
|
||||
except (RuntimeError, AssertionError):
|
||||
self.old_asyncio = None
|
||||
self.is_current = True
|
||||
asyncio.set_event_loop(self.asyncio_loop)
|
||||
|
||||
def _clear_current_hook(self):
|
||||
if self.is_current:
|
||||
asyncio.set_event_loop(self.old_asyncio)
|
||||
self.is_current = False
|
||||
|
||||
|
||||
def to_tornado_future(asyncio_future):
|
||||
"""Convert an `asyncio.Future` to a `tornado.concurrent.Future`.
|
||||
|
||||
.. versionadded:: 4.1
|
||||
|
||||
.. deprecated:: 5.0
|
||||
Tornado ``Futures`` have been merged with `asyncio.Future`,
|
||||
so this method is now a no-op.
|
||||
"""
|
||||
return asyncio_future
|
||||
|
||||
|
||||
def to_asyncio_future(tornado_future):
|
||||
"""Convert a Tornado yieldable object to an `asyncio.Future`.
|
||||
|
||||
.. versionadded:: 4.1
|
||||
|
||||
.. versionchanged:: 4.3
|
||||
Now accepts any yieldable object, not just
|
||||
`tornado.concurrent.Future`.
|
||||
|
||||
.. deprecated:: 5.0
|
||||
Tornado ``Futures`` have been merged with `asyncio.Future`,
|
||||
so this method is now equivalent to `tornado.gen.convert_yielded`.
|
||||
"""
|
||||
return convert_yielded(tornado_future)
|
||||
|
||||
|
||||
class AnyThreadEventLoopPolicy(asyncio.DefaultEventLoopPolicy):
|
||||
"""Event loop policy that allows loop creation on any thread.
|
||||
|
||||
The default `asyncio` event loop policy only automatically creates
|
||||
event loops in the main threads. Other threads must create event
|
||||
loops explicitly or `asyncio.get_event_loop` (and therefore
|
||||
`.IOLoop.current`) will fail. Installing this policy allows event
|
||||
loops to be created automatically on any thread, matching the
|
||||
behavior of Tornado versions prior to 5.0 (or 5.0 on Python 2).
|
||||
|
||||
Usage::
|
||||
|
||||
asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
|
||||
|
||||
.. versionadded:: 5.0
|
||||
|
||||
"""
|
||||
def get_event_loop(self):
|
||||
try:
|
||||
return super().get_event_loop()
|
||||
except (RuntimeError, AssertionError):
|
||||
# This was an AssertionError in python 3.4.2 (which ships with debian jessie)
|
||||
# and changed to a RuntimeError in 3.4.3.
|
||||
# "There is no current event loop in thread %r"
|
||||
loop = self.new_event_loop()
|
||||
self.set_event_loop(loop)
|
||||
return loop
|
||||
Executable
+58
@@ -0,0 +1,58 @@
|
||||
#
|
||||
# Copyright 2011 Facebook
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""Implementation of platform-specific functionality.
|
||||
|
||||
For each function or class described in `tornado.platform.interface`,
|
||||
the appropriate platform-specific implementation exists in this module.
|
||||
Most code that needs access to this functionality should do e.g.::
|
||||
|
||||
from tornado.platform.auto import set_close_exec
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import os
|
||||
|
||||
if 'APPENGINE_RUNTIME' in os.environ:
|
||||
from tornado.platform.common import Waker
|
||||
|
||||
def set_close_exec(fd):
|
||||
pass
|
||||
elif os.name == 'nt':
|
||||
from tornado.platform.common import Waker
|
||||
from tornado.platform.windows import set_close_exec
|
||||
else:
|
||||
from tornado.platform.posix import set_close_exec, Waker
|
||||
|
||||
try:
|
||||
# monotime monkey-patches the time module to have a monotonic function
|
||||
# in versions of python before 3.3.
|
||||
import monotime
|
||||
# Silence pyflakes warning about this unused import
|
||||
monotime
|
||||
except ImportError:
|
||||
pass
|
||||
try:
|
||||
# monotonic can provide a monotonic function in versions of python before
|
||||
# 3.3, too.
|
||||
from monotonic import monotonic as monotonic_time
|
||||
except ImportError:
|
||||
try:
|
||||
from time import monotonic as monotonic_time
|
||||
except ImportError:
|
||||
monotonic_time = None
|
||||
|
||||
__all__ = ['Waker', 'set_close_exec', 'monotonic_time']
|
||||
Executable
+4
@@ -0,0 +1,4 @@
|
||||
# auto.py is full of patterns mypy doesn't like, so for type checking
|
||||
# purposes we replace it with interface.py.
|
||||
|
||||
from .interface import *
|
||||
Executable
+79
@@ -0,0 +1,79 @@
|
||||
from __future__ import absolute_import, division, print_function
|
||||
import pycares # type: ignore
|
||||
import socket
|
||||
|
||||
from tornado.concurrent import Future
|
||||
from tornado import gen
|
||||
from tornado.ioloop import IOLoop
|
||||
from tornado.netutil import Resolver, is_valid_ip
|
||||
|
||||
|
||||
class CaresResolver(Resolver):
|
||||
"""Name resolver based on the c-ares library.
|
||||
|
||||
This is a non-blocking and non-threaded resolver. It may not produce
|
||||
the same results as the system resolver, but can be used for non-blocking
|
||||
resolution when threads cannot be used.
|
||||
|
||||
c-ares fails to resolve some names when ``family`` is ``AF_UNSPEC``,
|
||||
so it is only recommended for use in ``AF_INET`` (i.e. IPv4). This is
|
||||
the default for ``tornado.simple_httpclient``, but other libraries
|
||||
may default to ``AF_UNSPEC``.
|
||||
|
||||
.. versionchanged:: 5.0
|
||||
The ``io_loop`` argument (deprecated since version 4.1) has been removed.
|
||||
"""
|
||||
def initialize(self):
|
||||
self.io_loop = IOLoop.current()
|
||||
self.channel = pycares.Channel(sock_state_cb=self._sock_state_cb)
|
||||
self.fds = {}
|
||||
|
||||
def _sock_state_cb(self, fd, readable, writable):
|
||||
state = ((IOLoop.READ if readable else 0) |
|
||||
(IOLoop.WRITE if writable else 0))
|
||||
if not state:
|
||||
self.io_loop.remove_handler(fd)
|
||||
del self.fds[fd]
|
||||
elif fd in self.fds:
|
||||
self.io_loop.update_handler(fd, state)
|
||||
self.fds[fd] = state
|
||||
else:
|
||||
self.io_loop.add_handler(fd, self._handle_events, state)
|
||||
self.fds[fd] = state
|
||||
|
||||
def _handle_events(self, fd, events):
|
||||
read_fd = pycares.ARES_SOCKET_BAD
|
||||
write_fd = pycares.ARES_SOCKET_BAD
|
||||
if events & IOLoop.READ:
|
||||
read_fd = fd
|
||||
if events & IOLoop.WRITE:
|
||||
write_fd = fd
|
||||
self.channel.process_fd(read_fd, write_fd)
|
||||
|
||||
@gen.coroutine
|
||||
def resolve(self, host, port, family=0):
|
||||
if is_valid_ip(host):
|
||||
addresses = [host]
|
||||
else:
|
||||
# gethostbyname doesn't take callback as a kwarg
|
||||
fut = Future()
|
||||
self.channel.gethostbyname(host, family,
|
||||
lambda result, error: fut.set_result((result, error)))
|
||||
result, error = yield fut
|
||||
if error:
|
||||
raise IOError('C-Ares returned error %s: %s while resolving %s' %
|
||||
(error, pycares.errno.strerror(error), host))
|
||||
addresses = result.addresses
|
||||
addrinfo = []
|
||||
for address in addresses:
|
||||
if '.' in address:
|
||||
address_family = socket.AF_INET
|
||||
elif ':' in address:
|
||||
address_family = socket.AF_INET6
|
||||
else:
|
||||
address_family = socket.AF_UNSPEC
|
||||
if family != socket.AF_UNSPEC and family != address_family:
|
||||
raise IOError('Requested socket family %d but got %d' %
|
||||
(family, address_family))
|
||||
addrinfo.append((address_family, (address, port)))
|
||||
raise gen.Return(addrinfo)
|
||||
Executable
+113
@@ -0,0 +1,113 @@
|
||||
"""Lowest-common-denominator implementations of platform functionality."""
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import errno
|
||||
import socket
|
||||
import time
|
||||
|
||||
from tornado.platform import interface
|
||||
from tornado.util import errno_from_exception
|
||||
|
||||
|
||||
def try_close(f):
|
||||
# Avoid issue #875 (race condition when using the file in another
|
||||
# thread).
|
||||
for i in range(10):
|
||||
try:
|
||||
f.close()
|
||||
except IOError:
|
||||
# Yield to another thread
|
||||
time.sleep(1e-3)
|
||||
else:
|
||||
break
|
||||
# Try a last time and let raise
|
||||
f.close()
|
||||
|
||||
|
||||
class Waker(interface.Waker):
|
||||
"""Create an OS independent asynchronous pipe.
|
||||
|
||||
For use on platforms that don't have os.pipe() (or where pipes cannot
|
||||
be passed to select()), but do have sockets. This includes Windows
|
||||
and Jython.
|
||||
"""
|
||||
def __init__(self):
|
||||
from .auto import set_close_exec
|
||||
# Based on Zope select_trigger.py:
|
||||
# https://github.com/zopefoundation/Zope/blob/master/src/ZServer/medusa/thread/select_trigger.py
|
||||
|
||||
self.writer = socket.socket()
|
||||
set_close_exec(self.writer.fileno())
|
||||
# Disable buffering -- pulling the trigger sends 1 byte,
|
||||
# and we want that sent immediately, to wake up ASAP.
|
||||
self.writer.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
|
||||
|
||||
count = 0
|
||||
while 1:
|
||||
count += 1
|
||||
# Bind to a local port; for efficiency, let the OS pick
|
||||
# a free port for us.
|
||||
# Unfortunately, stress tests showed that we may not
|
||||
# be able to connect to that port ("Address already in
|
||||
# use") despite that the OS picked it. This appears
|
||||
# to be a race bug in the Windows socket implementation.
|
||||
# So we loop until a connect() succeeds (almost always
|
||||
# on the first try). See the long thread at
|
||||
# http://mail.zope.org/pipermail/zope/2005-July/160433.html
|
||||
# for hideous details.
|
||||
a = socket.socket()
|
||||
set_close_exec(a.fileno())
|
||||
a.bind(("127.0.0.1", 0))
|
||||
a.listen(1)
|
||||
connect_address = a.getsockname() # assigned (host, port) pair
|
||||
try:
|
||||
self.writer.connect(connect_address)
|
||||
break # success
|
||||
except socket.error as detail:
|
||||
if (not hasattr(errno, 'WSAEADDRINUSE') or
|
||||
errno_from_exception(detail) != errno.WSAEADDRINUSE):
|
||||
# "Address already in use" is the only error
|
||||
# I've seen on two WinXP Pro SP2 boxes, under
|
||||
# Pythons 2.3.5 and 2.4.1.
|
||||
raise
|
||||
# (10048, 'Address already in use')
|
||||
# assert count <= 2 # never triggered in Tim's tests
|
||||
if count >= 10: # I've never seen it go above 2
|
||||
a.close()
|
||||
self.writer.close()
|
||||
raise socket.error("Cannot bind trigger!")
|
||||
# Close `a` and try again. Note: I originally put a short
|
||||
# sleep() here, but it didn't appear to help or hurt.
|
||||
a.close()
|
||||
|
||||
self.reader, addr = a.accept()
|
||||
set_close_exec(self.reader.fileno())
|
||||
self.reader.setblocking(0)
|
||||
self.writer.setblocking(0)
|
||||
a.close()
|
||||
self.reader_fd = self.reader.fileno()
|
||||
|
||||
def fileno(self):
|
||||
return self.reader.fileno()
|
||||
|
||||
def write_fileno(self):
|
||||
return self.writer.fileno()
|
||||
|
||||
def wake(self):
|
||||
try:
|
||||
self.writer.send(b"x")
|
||||
except (IOError, socket.error, ValueError):
|
||||
pass
|
||||
|
||||
def consume(self):
|
||||
try:
|
||||
while True:
|
||||
result = self.reader.recv(1024)
|
||||
if not result:
|
||||
break
|
||||
except (IOError, socket.error):
|
||||
pass
|
||||
|
||||
def close(self):
|
||||
self.reader.close()
|
||||
try_close(self.writer)
|
||||
Executable
+25
@@ -0,0 +1,25 @@
|
||||
#
|
||||
# Copyright 2012 Facebook
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""EPoll-based IOLoop implementation for Linux systems."""
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import select
|
||||
|
||||
from tornado.ioloop import PollIOLoop
|
||||
|
||||
|
||||
class EPollIOLoop(PollIOLoop):
|
||||
def initialize(self, **kwargs):
|
||||
super(EPollIOLoop, self).initialize(impl=select.epoll(), **kwargs)
|
||||
Executable
+66
@@ -0,0 +1,66 @@
|
||||
#
|
||||
# Copyright 2011 Facebook
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""Interfaces for platform-specific functionality.
|
||||
|
||||
This module exists primarily for documentation purposes and as base classes
|
||||
for other tornado.platform modules. Most code should import the appropriate
|
||||
implementation from `tornado.platform.auto`.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
|
||||
def set_close_exec(fd):
|
||||
"""Sets the close-on-exec bit (``FD_CLOEXEC``)for a file descriptor."""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class Waker(object):
|
||||
"""A socket-like object that can wake another thread from ``select()``.
|
||||
|
||||
The `~tornado.ioloop.IOLoop` will add the Waker's `fileno()` to
|
||||
its ``select`` (or ``epoll`` or ``kqueue``) calls. When another
|
||||
thread wants to wake up the loop, it calls `wake`. Once it has woken
|
||||
up, it will call `consume` to do any necessary per-wake cleanup. When
|
||||
the ``IOLoop`` is closed, it closes its waker too.
|
||||
"""
|
||||
def fileno(self):
|
||||
"""Returns the read file descriptor for this waker.
|
||||
|
||||
Must be suitable for use with ``select()`` or equivalent on the
|
||||
local platform.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def write_fileno(self):
|
||||
"""Returns the write file descriptor for this waker."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def wake(self):
|
||||
"""Triggers activity on the waker's file descriptor."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def consume(self):
|
||||
"""Called after the listen has woken up to do any necessary cleanup."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def close(self):
|
||||
"""Closes the waker's file descriptor(s)."""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
def monotonic_time():
|
||||
raise NotImplementedError()
|
||||
Executable
+90
@@ -0,0 +1,90 @@
|
||||
#
|
||||
# Copyright 2012 Facebook
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""KQueue-based IOLoop implementation for BSD/Mac systems."""
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import select
|
||||
|
||||
from tornado.ioloop import IOLoop, PollIOLoop
|
||||
|
||||
assert hasattr(select, 'kqueue'), 'kqueue not supported'
|
||||
|
||||
|
||||
class _KQueue(object):
|
||||
"""A kqueue-based event loop for BSD/Mac systems."""
|
||||
def __init__(self):
|
||||
self._kqueue = select.kqueue()
|
||||
self._active = {}
|
||||
|
||||
def fileno(self):
|
||||
return self._kqueue.fileno()
|
||||
|
||||
def close(self):
|
||||
self._kqueue.close()
|
||||
|
||||
def register(self, fd, events):
|
||||
if fd in self._active:
|
||||
raise IOError("fd %s already registered" % fd)
|
||||
self._control(fd, events, select.KQ_EV_ADD)
|
||||
self._active[fd] = events
|
||||
|
||||
def modify(self, fd, events):
|
||||
self.unregister(fd)
|
||||
self.register(fd, events)
|
||||
|
||||
def unregister(self, fd):
|
||||
events = self._active.pop(fd)
|
||||
self._control(fd, events, select.KQ_EV_DELETE)
|
||||
|
||||
def _control(self, fd, events, flags):
|
||||
kevents = []
|
||||
if events & IOLoop.WRITE:
|
||||
kevents.append(select.kevent(
|
||||
fd, filter=select.KQ_FILTER_WRITE, flags=flags))
|
||||
if events & IOLoop.READ:
|
||||
kevents.append(select.kevent(
|
||||
fd, filter=select.KQ_FILTER_READ, flags=flags))
|
||||
# Even though control() takes a list, it seems to return EINVAL
|
||||
# on Mac OS X (10.6) when there is more than one event in the list.
|
||||
for kevent in kevents:
|
||||
self._kqueue.control([kevent], 0)
|
||||
|
||||
def poll(self, timeout):
|
||||
kevents = self._kqueue.control(None, 1000, timeout)
|
||||
events = {}
|
||||
for kevent in kevents:
|
||||
fd = kevent.ident
|
||||
if kevent.filter == select.KQ_FILTER_READ:
|
||||
events[fd] = events.get(fd, 0) | IOLoop.READ
|
||||
if kevent.filter == select.KQ_FILTER_WRITE:
|
||||
if kevent.flags & select.KQ_EV_EOF:
|
||||
# If an asynchronous connection is refused, kqueue
|
||||
# returns a write event with the EOF flag set.
|
||||
# Turn this into an error for consistency with the
|
||||
# other IOLoop implementations.
|
||||
# Note that for read events, EOF may be returned before
|
||||
# all data has been consumed from the socket buffer,
|
||||
# so we only check for EOF on write events.
|
||||
events[fd] = IOLoop.ERROR
|
||||
else:
|
||||
events[fd] = events.get(fd, 0) | IOLoop.WRITE
|
||||
if kevent.flags & select.KQ_EV_ERROR:
|
||||
events[fd] = events.get(fd, 0) | IOLoop.ERROR
|
||||
return events.items()
|
||||
|
||||
|
||||
class KQueueIOLoop(PollIOLoop):
|
||||
def initialize(self, **kwargs):
|
||||
super(KQueueIOLoop, self).initialize(impl=_KQueue(), **kwargs)
|
||||
Executable
+69
@@ -0,0 +1,69 @@
|
||||
#
|
||||
# Copyright 2011 Facebook
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""Posix implementations of platform-specific functionality."""
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import fcntl
|
||||
import os
|
||||
|
||||
from tornado.platform import common, interface
|
||||
|
||||
|
||||
def set_close_exec(fd):
|
||||
flags = fcntl.fcntl(fd, fcntl.F_GETFD)
|
||||
fcntl.fcntl(fd, fcntl.F_SETFD, flags | fcntl.FD_CLOEXEC)
|
||||
|
||||
|
||||
def _set_nonblocking(fd):
|
||||
flags = fcntl.fcntl(fd, fcntl.F_GETFL)
|
||||
fcntl.fcntl(fd, fcntl.F_SETFL, flags | os.O_NONBLOCK)
|
||||
|
||||
|
||||
class Waker(interface.Waker):
|
||||
def __init__(self):
|
||||
r, w = os.pipe()
|
||||
_set_nonblocking(r)
|
||||
_set_nonblocking(w)
|
||||
set_close_exec(r)
|
||||
set_close_exec(w)
|
||||
self.reader = os.fdopen(r, "rb", 0)
|
||||
self.writer = os.fdopen(w, "wb", 0)
|
||||
|
||||
def fileno(self):
|
||||
return self.reader.fileno()
|
||||
|
||||
def write_fileno(self):
|
||||
return self.writer.fileno()
|
||||
|
||||
def wake(self):
|
||||
try:
|
||||
self.writer.write(b"x")
|
||||
except (IOError, ValueError):
|
||||
pass
|
||||
|
||||
def consume(self):
|
||||
try:
|
||||
while True:
|
||||
result = self.reader.read()
|
||||
if not result:
|
||||
break
|
||||
except IOError:
|
||||
pass
|
||||
|
||||
def close(self):
|
||||
self.reader.close()
|
||||
common.try_close(self.writer)
|
||||
Executable
+75
@@ -0,0 +1,75 @@
|
||||
#
|
||||
# Copyright 2012 Facebook
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""Select-based IOLoop implementation.
|
||||
|
||||
Used as a fallback for systems that don't support epoll or kqueue.
|
||||
"""
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import select
|
||||
|
||||
from tornado.ioloop import IOLoop, PollIOLoop
|
||||
|
||||
|
||||
class _Select(object):
|
||||
"""A simple, select()-based IOLoop implementation for non-Linux systems"""
|
||||
def __init__(self):
|
||||
self.read_fds = set()
|
||||
self.write_fds = set()
|
||||
self.error_fds = set()
|
||||
self.fd_sets = (self.read_fds, self.write_fds, self.error_fds)
|
||||
|
||||
def close(self):
|
||||
pass
|
||||
|
||||
def register(self, fd, events):
|
||||
if fd in self.read_fds or fd in self.write_fds or fd in self.error_fds:
|
||||
raise IOError("fd %s already registered" % fd)
|
||||
if events & IOLoop.READ:
|
||||
self.read_fds.add(fd)
|
||||
if events & IOLoop.WRITE:
|
||||
self.write_fds.add(fd)
|
||||
if events & IOLoop.ERROR:
|
||||
self.error_fds.add(fd)
|
||||
# Closed connections are reported as errors by epoll and kqueue,
|
||||
# but as zero-byte reads by select, so when errors are requested
|
||||
# we need to listen for both read and error.
|
||||
# self.read_fds.add(fd)
|
||||
|
||||
def modify(self, fd, events):
|
||||
self.unregister(fd)
|
||||
self.register(fd, events)
|
||||
|
||||
def unregister(self, fd):
|
||||
self.read_fds.discard(fd)
|
||||
self.write_fds.discard(fd)
|
||||
self.error_fds.discard(fd)
|
||||
|
||||
def poll(self, timeout):
|
||||
readable, writeable, errors = select.select(
|
||||
self.read_fds, self.write_fds, self.error_fds, timeout)
|
||||
events = {}
|
||||
for fd in readable:
|
||||
events[fd] = events.get(fd, 0) | IOLoop.READ
|
||||
for fd in writeable:
|
||||
events[fd] = events.get(fd, 0) | IOLoop.WRITE
|
||||
for fd in errors:
|
||||
events[fd] = events.get(fd, 0) | IOLoop.ERROR
|
||||
return events.items()
|
||||
|
||||
|
||||
class SelectIOLoop(PollIOLoop):
|
||||
def initialize(self, **kwargs):
|
||||
super(SelectIOLoop, self).initialize(impl=_Select(), **kwargs)
|
||||
Executable
+609
@@ -0,0 +1,609 @@
|
||||
# Author: Ovidiu Predescu
|
||||
# Date: July 2011
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""Bridges between the Twisted reactor and Tornado IOLoop.
|
||||
|
||||
This module lets you run applications and libraries written for
|
||||
Twisted in a Tornado application. It can be used in two modes,
|
||||
depending on which library's underlying event loop you want to use.
|
||||
|
||||
This module has been tested with Twisted versions 11.0.0 and newer.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import datetime
|
||||
import functools
|
||||
import numbers
|
||||
import socket
|
||||
import sys
|
||||
|
||||
import twisted.internet.abstract # type: ignore
|
||||
from twisted.internet.defer import Deferred # type: ignore
|
||||
from twisted.internet.posixbase import PosixReactorBase # type: ignore
|
||||
from twisted.internet.interfaces import IReactorFDSet, IDelayedCall, IReactorTime, IReadDescriptor, IWriteDescriptor # type: ignore # noqa: E501
|
||||
from twisted.python import failure, log # type: ignore
|
||||
from twisted.internet import error # type: ignore
|
||||
import twisted.names.cache # type: ignore
|
||||
import twisted.names.client # type: ignore
|
||||
import twisted.names.hosts # type: ignore
|
||||
import twisted.names.resolve # type: ignore
|
||||
|
||||
from zope.interface import implementer # type: ignore
|
||||
|
||||
from tornado.concurrent import Future, future_set_exc_info
|
||||
from tornado.escape import utf8
|
||||
from tornado import gen
|
||||
import tornado.ioloop
|
||||
from tornado.log import app_log
|
||||
from tornado.netutil import Resolver
|
||||
from tornado.stack_context import NullContext, wrap
|
||||
from tornado.ioloop import IOLoop
|
||||
from tornado.util import timedelta_to_seconds
|
||||
|
||||
|
||||
@implementer(IDelayedCall)
|
||||
class TornadoDelayedCall(object):
|
||||
"""DelayedCall object for Tornado."""
|
||||
def __init__(self, reactor, seconds, f, *args, **kw):
|
||||
self._reactor = reactor
|
||||
self._func = functools.partial(f, *args, **kw)
|
||||
self._time = self._reactor.seconds() + seconds
|
||||
self._timeout = self._reactor._io_loop.add_timeout(self._time,
|
||||
self._called)
|
||||
self._active = True
|
||||
|
||||
def _called(self):
|
||||
self._active = False
|
||||
self._reactor._removeDelayedCall(self)
|
||||
try:
|
||||
self._func()
|
||||
except:
|
||||
app_log.error("_called caught exception", exc_info=True)
|
||||
|
||||
def getTime(self):
|
||||
return self._time
|
||||
|
||||
def cancel(self):
|
||||
self._active = False
|
||||
self._reactor._io_loop.remove_timeout(self._timeout)
|
||||
self._reactor._removeDelayedCall(self)
|
||||
|
||||
def delay(self, seconds):
|
||||
self._reactor._io_loop.remove_timeout(self._timeout)
|
||||
self._time += seconds
|
||||
self._timeout = self._reactor._io_loop.add_timeout(self._time,
|
||||
self._called)
|
||||
|
||||
def reset(self, seconds):
|
||||
self._reactor._io_loop.remove_timeout(self._timeout)
|
||||
self._time = self._reactor.seconds() + seconds
|
||||
self._timeout = self._reactor._io_loop.add_timeout(self._time,
|
||||
self._called)
|
||||
|
||||
def active(self):
|
||||
return self._active
|
||||
|
||||
|
||||
@implementer(IReactorTime, IReactorFDSet)
|
||||
class TornadoReactor(PosixReactorBase):
|
||||
"""Twisted reactor built on the Tornado IOLoop.
|
||||
|
||||
`TornadoReactor` implements the Twisted reactor interface on top of
|
||||
the Tornado IOLoop. To use it, simply call `install` at the beginning
|
||||
of the application::
|
||||
|
||||
import tornado.platform.twisted
|
||||
tornado.platform.twisted.install()
|
||||
from twisted.internet import reactor
|
||||
|
||||
When the app is ready to start, call ``IOLoop.current().start()``
|
||||
instead of ``reactor.run()``.
|
||||
|
||||
It is also possible to create a non-global reactor by calling
|
||||
``tornado.platform.twisted.TornadoReactor()``. However, if
|
||||
the `.IOLoop` and reactor are to be short-lived (such as those used in
|
||||
unit tests), additional cleanup may be required. Specifically, it is
|
||||
recommended to call::
|
||||
|
||||
reactor.fireSystemEvent('shutdown')
|
||||
reactor.disconnectAll()
|
||||
|
||||
before closing the `.IOLoop`.
|
||||
|
||||
.. versionchanged:: 5.0
|
||||
The ``io_loop`` argument (deprecated since version 4.1) has been removed.
|
||||
|
||||
.. deprecated:: 5.1
|
||||
|
||||
This class will be removed in Tornado 6.0. Use
|
||||
``twisted.internet.asyncioreactor.AsyncioSelectorReactor``
|
||||
instead.
|
||||
|
||||
"""
|
||||
def __init__(self):
|
||||
self._io_loop = tornado.ioloop.IOLoop.current()
|
||||
self._readers = {} # map of reader objects to fd
|
||||
self._writers = {} # map of writer objects to fd
|
||||
self._fds = {} # a map of fd to a (reader, writer) tuple
|
||||
self._delayedCalls = {}
|
||||
PosixReactorBase.__init__(self)
|
||||
self.addSystemEventTrigger('during', 'shutdown', self.crash)
|
||||
|
||||
# IOLoop.start() bypasses some of the reactor initialization.
|
||||
# Fire off the necessary events if they weren't already triggered
|
||||
# by reactor.run().
|
||||
def start_if_necessary():
|
||||
if not self._started:
|
||||
self.fireSystemEvent('startup')
|
||||
self._io_loop.add_callback(start_if_necessary)
|
||||
|
||||
# IReactorTime
|
||||
def seconds(self):
|
||||
return self._io_loop.time()
|
||||
|
||||
def callLater(self, seconds, f, *args, **kw):
|
||||
dc = TornadoDelayedCall(self, seconds, f, *args, **kw)
|
||||
self._delayedCalls[dc] = True
|
||||
return dc
|
||||
|
||||
def getDelayedCalls(self):
|
||||
return [x for x in self._delayedCalls if x._active]
|
||||
|
||||
def _removeDelayedCall(self, dc):
|
||||
if dc in self._delayedCalls:
|
||||
del self._delayedCalls[dc]
|
||||
|
||||
# IReactorThreads
|
||||
def callFromThread(self, f, *args, **kw):
|
||||
assert callable(f), "%s is not callable" % f
|
||||
with NullContext():
|
||||
# This NullContext is mainly for an edge case when running
|
||||
# TwistedIOLoop on top of a TornadoReactor.
|
||||
# TwistedIOLoop.add_callback uses reactor.callFromThread and
|
||||
# should not pick up additional StackContexts along the way.
|
||||
self._io_loop.add_callback(f, *args, **kw)
|
||||
|
||||
# We don't need the waker code from the super class, Tornado uses
|
||||
# its own waker.
|
||||
def installWaker(self):
|
||||
pass
|
||||
|
||||
def wakeUp(self):
|
||||
pass
|
||||
|
||||
# IReactorFDSet
|
||||
def _invoke_callback(self, fd, events):
|
||||
if fd not in self._fds:
|
||||
return
|
||||
(reader, writer) = self._fds[fd]
|
||||
if reader:
|
||||
err = None
|
||||
if reader.fileno() == -1:
|
||||
err = error.ConnectionLost()
|
||||
elif events & IOLoop.READ:
|
||||
err = log.callWithLogger(reader, reader.doRead)
|
||||
if err is None and events & IOLoop.ERROR:
|
||||
err = error.ConnectionLost()
|
||||
if err is not None:
|
||||
self.removeReader(reader)
|
||||
reader.readConnectionLost(failure.Failure(err))
|
||||
if writer:
|
||||
err = None
|
||||
if writer.fileno() == -1:
|
||||
err = error.ConnectionLost()
|
||||
elif events & IOLoop.WRITE:
|
||||
err = log.callWithLogger(writer, writer.doWrite)
|
||||
if err is None and events & IOLoop.ERROR:
|
||||
err = error.ConnectionLost()
|
||||
if err is not None:
|
||||
self.removeWriter(writer)
|
||||
writer.writeConnectionLost(failure.Failure(err))
|
||||
|
||||
def addReader(self, reader):
|
||||
if reader in self._readers:
|
||||
# Don't add the reader if it's already there
|
||||
return
|
||||
fd = reader.fileno()
|
||||
self._readers[reader] = fd
|
||||
if fd in self._fds:
|
||||
(_, writer) = self._fds[fd]
|
||||
self._fds[fd] = (reader, writer)
|
||||
if writer:
|
||||
# We already registered this fd for write events,
|
||||
# update it for read events as well.
|
||||
self._io_loop.update_handler(fd, IOLoop.READ | IOLoop.WRITE)
|
||||
else:
|
||||
with NullContext():
|
||||
self._fds[fd] = (reader, None)
|
||||
self._io_loop.add_handler(fd, self._invoke_callback,
|
||||
IOLoop.READ)
|
||||
|
||||
def addWriter(self, writer):
|
||||
if writer in self._writers:
|
||||
return
|
||||
fd = writer.fileno()
|
||||
self._writers[writer] = fd
|
||||
if fd in self._fds:
|
||||
(reader, _) = self._fds[fd]
|
||||
self._fds[fd] = (reader, writer)
|
||||
if reader:
|
||||
# We already registered this fd for read events,
|
||||
# update it for write events as well.
|
||||
self._io_loop.update_handler(fd, IOLoop.READ | IOLoop.WRITE)
|
||||
else:
|
||||
with NullContext():
|
||||
self._fds[fd] = (None, writer)
|
||||
self._io_loop.add_handler(fd, self._invoke_callback,
|
||||
IOLoop.WRITE)
|
||||
|
||||
def removeReader(self, reader):
|
||||
if reader in self._readers:
|
||||
fd = self._readers.pop(reader)
|
||||
(_, writer) = self._fds[fd]
|
||||
if writer:
|
||||
# We have a writer so we need to update the IOLoop for
|
||||
# write events only.
|
||||
self._fds[fd] = (None, writer)
|
||||
self._io_loop.update_handler(fd, IOLoop.WRITE)
|
||||
else:
|
||||
# Since we have no writer registered, we remove the
|
||||
# entry from _fds and unregister the handler from the
|
||||
# IOLoop
|
||||
del self._fds[fd]
|
||||
self._io_loop.remove_handler(fd)
|
||||
|
||||
def removeWriter(self, writer):
|
||||
if writer in self._writers:
|
||||
fd = self._writers.pop(writer)
|
||||
(reader, _) = self._fds[fd]
|
||||
if reader:
|
||||
# We have a reader so we need to update the IOLoop for
|
||||
# read events only.
|
||||
self._fds[fd] = (reader, None)
|
||||
self._io_loop.update_handler(fd, IOLoop.READ)
|
||||
else:
|
||||
# Since we have no reader registered, we remove the
|
||||
# entry from the _fds and unregister the handler from
|
||||
# the IOLoop.
|
||||
del self._fds[fd]
|
||||
self._io_loop.remove_handler(fd)
|
||||
|
||||
def removeAll(self):
|
||||
return self._removeAll(self._readers, self._writers)
|
||||
|
||||
def getReaders(self):
|
||||
return self._readers.keys()
|
||||
|
||||
def getWriters(self):
|
||||
return self._writers.keys()
|
||||
|
||||
# The following functions are mainly used in twisted-style test cases;
|
||||
# it is expected that most users of the TornadoReactor will call
|
||||
# IOLoop.start() instead of Reactor.run().
|
||||
def stop(self):
|
||||
PosixReactorBase.stop(self)
|
||||
fire_shutdown = functools.partial(self.fireSystemEvent, "shutdown")
|
||||
self._io_loop.add_callback(fire_shutdown)
|
||||
|
||||
def crash(self):
|
||||
PosixReactorBase.crash(self)
|
||||
self._io_loop.stop()
|
||||
|
||||
def doIteration(self, delay):
|
||||
raise NotImplementedError("doIteration")
|
||||
|
||||
def mainLoop(self):
|
||||
# Since this class is intended to be used in applications
|
||||
# where the top-level event loop is ``io_loop.start()`` rather
|
||||
# than ``reactor.run()``, it is implemented a little
|
||||
# differently than other Twisted reactors. We override
|
||||
# ``mainLoop`` instead of ``doIteration`` and must implement
|
||||
# timed call functionality on top of `.IOLoop.add_timeout`
|
||||
# rather than using the implementation in
|
||||
# ``PosixReactorBase``.
|
||||
self._io_loop.start()
|
||||
|
||||
|
||||
class _TestReactor(TornadoReactor):
|
||||
"""Subclass of TornadoReactor for use in unittests.
|
||||
|
||||
This can't go in the test.py file because of import-order dependencies
|
||||
with the Twisted reactor test builder.
|
||||
"""
|
||||
def __init__(self):
|
||||
# always use a new ioloop
|
||||
IOLoop.clear_current()
|
||||
IOLoop(make_current=True)
|
||||
super(_TestReactor, self).__init__()
|
||||
IOLoop.clear_current()
|
||||
|
||||
def listenTCP(self, port, factory, backlog=50, interface=''):
|
||||
# default to localhost to avoid firewall prompts on the mac
|
||||
if not interface:
|
||||
interface = '127.0.0.1'
|
||||
return super(_TestReactor, self).listenTCP(
|
||||
port, factory, backlog=backlog, interface=interface)
|
||||
|
||||
def listenUDP(self, port, protocol, interface='', maxPacketSize=8192):
|
||||
if not interface:
|
||||
interface = '127.0.0.1'
|
||||
return super(_TestReactor, self).listenUDP(
|
||||
port, protocol, interface=interface, maxPacketSize=maxPacketSize)
|
||||
|
||||
|
||||
def install():
|
||||
"""Install this package as the default Twisted reactor.
|
||||
|
||||
``install()`` must be called very early in the startup process,
|
||||
before most other twisted-related imports. Conversely, because it
|
||||
initializes the `.IOLoop`, it cannot be called before
|
||||
`.fork_processes` or multi-process `~.TCPServer.start`. These
|
||||
conflicting requirements make it difficult to use `.TornadoReactor`
|
||||
in multi-process mode, and an external process manager such as
|
||||
``supervisord`` is recommended instead.
|
||||
|
||||
.. versionchanged:: 5.0
|
||||
The ``io_loop`` argument (deprecated since version 4.1) has been removed.
|
||||
|
||||
.. deprecated:: 5.1
|
||||
|
||||
This functio will be removed in Tornado 6.0. Use
|
||||
``twisted.internet.asyncioreactor.install`` instead.
|
||||
"""
|
||||
reactor = TornadoReactor()
|
||||
from twisted.internet.main import installReactor # type: ignore
|
||||
installReactor(reactor)
|
||||
return reactor
|
||||
|
||||
|
||||
@implementer(IReadDescriptor, IWriteDescriptor)
|
||||
class _FD(object):
|
||||
def __init__(self, fd, fileobj, handler):
|
||||
self.fd = fd
|
||||
self.fileobj = fileobj
|
||||
self.handler = handler
|
||||
self.reading = False
|
||||
self.writing = False
|
||||
self.lost = False
|
||||
|
||||
def fileno(self):
|
||||
return self.fd
|
||||
|
||||
def doRead(self):
|
||||
if not self.lost:
|
||||
self.handler(self.fileobj, tornado.ioloop.IOLoop.READ)
|
||||
|
||||
def doWrite(self):
|
||||
if not self.lost:
|
||||
self.handler(self.fileobj, tornado.ioloop.IOLoop.WRITE)
|
||||
|
||||
def connectionLost(self, reason):
|
||||
if not self.lost:
|
||||
self.handler(self.fileobj, tornado.ioloop.IOLoop.ERROR)
|
||||
self.lost = True
|
||||
|
||||
writeConnectionLost = readConnectionLost = connectionLost
|
||||
|
||||
def logPrefix(self):
|
||||
return ''
|
||||
|
||||
|
||||
class TwistedIOLoop(tornado.ioloop.IOLoop):
|
||||
"""IOLoop implementation that runs on Twisted.
|
||||
|
||||
`TwistedIOLoop` implements the Tornado IOLoop interface on top of
|
||||
the Twisted reactor. Recommended usage::
|
||||
|
||||
from tornado.platform.twisted import TwistedIOLoop
|
||||
from twisted.internet import reactor
|
||||
TwistedIOLoop().install()
|
||||
# Set up your tornado application as usual using `IOLoop.instance`
|
||||
reactor.run()
|
||||
|
||||
Uses the global Twisted reactor by default. To create multiple
|
||||
``TwistedIOLoops`` in the same process, you must pass a unique reactor
|
||||
when constructing each one.
|
||||
|
||||
Not compatible with `tornado.process.Subprocess.set_exit_callback`
|
||||
because the ``SIGCHLD`` handlers used by Tornado and Twisted conflict
|
||||
with each other.
|
||||
|
||||
See also :meth:`tornado.ioloop.IOLoop.install` for general notes on
|
||||
installing alternative IOLoops.
|
||||
|
||||
.. deprecated:: 5.1
|
||||
|
||||
The `asyncio` event loop will be the only available implementation in
|
||||
Tornado 6.0.
|
||||
"""
|
||||
def initialize(self, reactor=None, **kwargs):
|
||||
super(TwistedIOLoop, self).initialize(**kwargs)
|
||||
if reactor is None:
|
||||
import twisted.internet.reactor # type: ignore
|
||||
reactor = twisted.internet.reactor
|
||||
self.reactor = reactor
|
||||
self.fds = {}
|
||||
|
||||
def close(self, all_fds=False):
|
||||
fds = self.fds
|
||||
self.reactor.removeAll()
|
||||
for c in self.reactor.getDelayedCalls():
|
||||
c.cancel()
|
||||
if all_fds:
|
||||
for fd in fds.values():
|
||||
self.close_fd(fd.fileobj)
|
||||
|
||||
def add_handler(self, fd, handler, events):
|
||||
if fd in self.fds:
|
||||
raise ValueError('fd %s added twice' % fd)
|
||||
fd, fileobj = self.split_fd(fd)
|
||||
self.fds[fd] = _FD(fd, fileobj, wrap(handler))
|
||||
if events & tornado.ioloop.IOLoop.READ:
|
||||
self.fds[fd].reading = True
|
||||
self.reactor.addReader(self.fds[fd])
|
||||
if events & tornado.ioloop.IOLoop.WRITE:
|
||||
self.fds[fd].writing = True
|
||||
self.reactor.addWriter(self.fds[fd])
|
||||
|
||||
def update_handler(self, fd, events):
|
||||
fd, fileobj = self.split_fd(fd)
|
||||
if events & tornado.ioloop.IOLoop.READ:
|
||||
if not self.fds[fd].reading:
|
||||
self.fds[fd].reading = True
|
||||
self.reactor.addReader(self.fds[fd])
|
||||
else:
|
||||
if self.fds[fd].reading:
|
||||
self.fds[fd].reading = False
|
||||
self.reactor.removeReader(self.fds[fd])
|
||||
if events & tornado.ioloop.IOLoop.WRITE:
|
||||
if not self.fds[fd].writing:
|
||||
self.fds[fd].writing = True
|
||||
self.reactor.addWriter(self.fds[fd])
|
||||
else:
|
||||
if self.fds[fd].writing:
|
||||
self.fds[fd].writing = False
|
||||
self.reactor.removeWriter(self.fds[fd])
|
||||
|
||||
def remove_handler(self, fd):
|
||||
fd, fileobj = self.split_fd(fd)
|
||||
if fd not in self.fds:
|
||||
return
|
||||
self.fds[fd].lost = True
|
||||
if self.fds[fd].reading:
|
||||
self.reactor.removeReader(self.fds[fd])
|
||||
if self.fds[fd].writing:
|
||||
self.reactor.removeWriter(self.fds[fd])
|
||||
del self.fds[fd]
|
||||
|
||||
def start(self):
|
||||
old_current = IOLoop.current(instance=False)
|
||||
try:
|
||||
self._setup_logging()
|
||||
self.make_current()
|
||||
self.reactor.run()
|
||||
finally:
|
||||
if old_current is None:
|
||||
IOLoop.clear_current()
|
||||
else:
|
||||
old_current.make_current()
|
||||
|
||||
def stop(self):
|
||||
self.reactor.crash()
|
||||
|
||||
def add_timeout(self, deadline, callback, *args, **kwargs):
|
||||
# This method could be simplified (since tornado 4.0) by
|
||||
# overriding call_at instead of add_timeout, but we leave it
|
||||
# for now as a test of backwards-compatibility.
|
||||
if isinstance(deadline, numbers.Real):
|
||||
delay = max(deadline - self.time(), 0)
|
||||
elif isinstance(deadline, datetime.timedelta):
|
||||
delay = timedelta_to_seconds(deadline)
|
||||
else:
|
||||
raise TypeError("Unsupported deadline %r")
|
||||
return self.reactor.callLater(
|
||||
delay, self._run_callback,
|
||||
functools.partial(wrap(callback), *args, **kwargs))
|
||||
|
||||
def remove_timeout(self, timeout):
|
||||
if timeout.active():
|
||||
timeout.cancel()
|
||||
|
||||
def add_callback(self, callback, *args, **kwargs):
|
||||
self.reactor.callFromThread(
|
||||
self._run_callback,
|
||||
functools.partial(wrap(callback), *args, **kwargs))
|
||||
|
||||
def add_callback_from_signal(self, callback, *args, **kwargs):
|
||||
self.add_callback(callback, *args, **kwargs)
|
||||
|
||||
|
||||
class TwistedResolver(Resolver):
|
||||
"""Twisted-based asynchronous resolver.
|
||||
|
||||
This is a non-blocking and non-threaded resolver. It is
|
||||
recommended only when threads cannot be used, since it has
|
||||
limitations compared to the standard ``getaddrinfo``-based
|
||||
`~tornado.netutil.Resolver` and
|
||||
`~tornado.netutil.DefaultExecutorResolver`. Specifically, it returns at
|
||||
most one result, and arguments other than ``host`` and ``family``
|
||||
are ignored. It may fail to resolve when ``family`` is not
|
||||
``socket.AF_UNSPEC``.
|
||||
|
||||
Requires Twisted 12.1 or newer.
|
||||
|
||||
.. versionchanged:: 5.0
|
||||
The ``io_loop`` argument (deprecated since version 4.1) has been removed.
|
||||
"""
|
||||
def initialize(self):
|
||||
# partial copy of twisted.names.client.createResolver, which doesn't
|
||||
# allow for a reactor to be passed in.
|
||||
self.reactor = tornado.platform.twisted.TornadoReactor()
|
||||
|
||||
host_resolver = twisted.names.hosts.Resolver('/etc/hosts')
|
||||
cache_resolver = twisted.names.cache.CacheResolver(reactor=self.reactor)
|
||||
real_resolver = twisted.names.client.Resolver('/etc/resolv.conf',
|
||||
reactor=self.reactor)
|
||||
self.resolver = twisted.names.resolve.ResolverChain(
|
||||
[host_resolver, cache_resolver, real_resolver])
|
||||
|
||||
@gen.coroutine
|
||||
def resolve(self, host, port, family=0):
|
||||
# getHostByName doesn't accept IP addresses, so if the input
|
||||
# looks like an IP address just return it immediately.
|
||||
if twisted.internet.abstract.isIPAddress(host):
|
||||
resolved = host
|
||||
resolved_family = socket.AF_INET
|
||||
elif twisted.internet.abstract.isIPv6Address(host):
|
||||
resolved = host
|
||||
resolved_family = socket.AF_INET6
|
||||
else:
|
||||
deferred = self.resolver.getHostByName(utf8(host))
|
||||
fut = Future()
|
||||
deferred.addBoth(fut.set_result)
|
||||
resolved = yield fut
|
||||
if isinstance(resolved, failure.Failure):
|
||||
try:
|
||||
resolved.raiseException()
|
||||
except twisted.names.error.DomainError as e:
|
||||
raise IOError(e)
|
||||
elif twisted.internet.abstract.isIPAddress(resolved):
|
||||
resolved_family = socket.AF_INET
|
||||
elif twisted.internet.abstract.isIPv6Address(resolved):
|
||||
resolved_family = socket.AF_INET6
|
||||
else:
|
||||
resolved_family = socket.AF_UNSPEC
|
||||
if family != socket.AF_UNSPEC and family != resolved_family:
|
||||
raise Exception('Requested socket family %d but got %d' %
|
||||
(family, resolved_family))
|
||||
result = [
|
||||
(resolved_family, (resolved, port)),
|
||||
]
|
||||
raise gen.Return(result)
|
||||
|
||||
|
||||
if hasattr(gen.convert_yielded, 'register'):
|
||||
@gen.convert_yielded.register(Deferred) # type: ignore
|
||||
def _(d):
|
||||
f = Future()
|
||||
|
||||
def errback(failure):
|
||||
try:
|
||||
failure.raiseException()
|
||||
# Should never happen, but just in case
|
||||
raise Exception("errback called without error")
|
||||
except:
|
||||
future_set_exc_info(f, sys.exc_info())
|
||||
d.addCallbacks(f.set_result, errback)
|
||||
return f
|
||||
Executable
+20
@@ -0,0 +1,20 @@
|
||||
# NOTE: win32 support is currently experimental, and not recommended
|
||||
# for production use.
|
||||
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
import ctypes # type: ignore
|
||||
import ctypes.wintypes # type: ignore
|
||||
|
||||
# See: http://msdn.microsoft.com/en-us/library/ms724935(VS.85).aspx
|
||||
SetHandleInformation = ctypes.windll.kernel32.SetHandleInformation
|
||||
SetHandleInformation.argtypes = (ctypes.wintypes.HANDLE, ctypes.wintypes.DWORD, ctypes.wintypes.DWORD) # noqa: E501
|
||||
SetHandleInformation.restype = ctypes.wintypes.BOOL
|
||||
|
||||
HANDLE_FLAG_INHERIT = 0x00000001
|
||||
|
||||
|
||||
def set_close_exec(fd):
|
||||
success = SetHandleInformation(fd, HANDLE_FLAG_INHERIT, 0)
|
||||
if not success:
|
||||
raise ctypes.WinError()
|
||||
Executable
+361
@@ -0,0 +1,361 @@
|
||||
#
|
||||
# Copyright 2011 Facebook
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""Utilities for working with multiple processes, including both forking
|
||||
the server into multiple processes and managing subprocesses.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import errno
|
||||
import os
|
||||
import signal
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
|
||||
from binascii import hexlify
|
||||
|
||||
from tornado.concurrent import Future, future_set_result_unless_cancelled
|
||||
from tornado import ioloop
|
||||
from tornado.iostream import PipeIOStream
|
||||
from tornado.log import gen_log
|
||||
from tornado.platform.auto import set_close_exec
|
||||
from tornado import stack_context
|
||||
from tornado.util import errno_from_exception, PY3
|
||||
|
||||
try:
|
||||
import multiprocessing
|
||||
except ImportError:
|
||||
# Multiprocessing is not available on Google App Engine.
|
||||
multiprocessing = None
|
||||
|
||||
if PY3:
|
||||
long = int
|
||||
|
||||
# Re-export this exception for convenience.
|
||||
try:
|
||||
CalledProcessError = subprocess.CalledProcessError
|
||||
except AttributeError:
|
||||
# The subprocess module exists in Google App Engine, but is empty.
|
||||
# This module isn't very useful in that case, but it should
|
||||
# at least be importable.
|
||||
if 'APPENGINE_RUNTIME' not in os.environ:
|
||||
raise
|
||||
|
||||
|
||||
def cpu_count():
|
||||
"""Returns the number of processors on this machine."""
|
||||
if multiprocessing is None:
|
||||
return 1
|
||||
try:
|
||||
return multiprocessing.cpu_count()
|
||||
except NotImplementedError:
|
||||
pass
|
||||
try:
|
||||
return os.sysconf("SC_NPROCESSORS_CONF")
|
||||
except (AttributeError, ValueError):
|
||||
pass
|
||||
gen_log.error("Could not detect number of processors; assuming 1")
|
||||
return 1
|
||||
|
||||
|
||||
def _reseed_random():
|
||||
if 'random' not in sys.modules:
|
||||
return
|
||||
import random
|
||||
# If os.urandom is available, this method does the same thing as
|
||||
# random.seed (at least as of python 2.6). If os.urandom is not
|
||||
# available, we mix in the pid in addition to a timestamp.
|
||||
try:
|
||||
seed = long(hexlify(os.urandom(16)), 16)
|
||||
except NotImplementedError:
|
||||
seed = int(time.time() * 1000) ^ os.getpid()
|
||||
random.seed(seed)
|
||||
|
||||
|
||||
def _pipe_cloexec():
|
||||
r, w = os.pipe()
|
||||
set_close_exec(r)
|
||||
set_close_exec(w)
|
||||
return r, w
|
||||
|
||||
|
||||
_task_id = None
|
||||
|
||||
|
||||
def fork_processes(num_processes, max_restarts=100):
|
||||
"""Starts multiple worker processes.
|
||||
|
||||
If ``num_processes`` is None or <= 0, we detect the number of cores
|
||||
available on this machine and fork that number of child
|
||||
processes. If ``num_processes`` is given and > 0, we fork that
|
||||
specific number of sub-processes.
|
||||
|
||||
Since we use processes and not threads, there is no shared memory
|
||||
between any server code.
|
||||
|
||||
Note that multiple processes are not compatible with the autoreload
|
||||
module (or the ``autoreload=True`` option to `tornado.web.Application`
|
||||
which defaults to True when ``debug=True``).
|
||||
When using multiple processes, no IOLoops can be created or
|
||||
referenced until after the call to ``fork_processes``.
|
||||
|
||||
In each child process, ``fork_processes`` returns its *task id*, a
|
||||
number between 0 and ``num_processes``. Processes that exit
|
||||
abnormally (due to a signal or non-zero exit status) are restarted
|
||||
with the same id (up to ``max_restarts`` times). In the parent
|
||||
process, ``fork_processes`` returns None if all child processes
|
||||
have exited normally, but will otherwise only exit by throwing an
|
||||
exception.
|
||||
"""
|
||||
global _task_id
|
||||
assert _task_id is None
|
||||
if num_processes is None or num_processes <= 0:
|
||||
num_processes = cpu_count()
|
||||
gen_log.info("Starting %d processes", num_processes)
|
||||
children = {}
|
||||
|
||||
def start_child(i):
|
||||
pid = os.fork()
|
||||
if pid == 0:
|
||||
# child process
|
||||
_reseed_random()
|
||||
global _task_id
|
||||
_task_id = i
|
||||
return i
|
||||
else:
|
||||
children[pid] = i
|
||||
return None
|
||||
|
||||
for i in range(num_processes):
|
||||
id = start_child(i)
|
||||
if id is not None:
|
||||
return id
|
||||
num_restarts = 0
|
||||
while children:
|
||||
try:
|
||||
pid, status = os.wait()
|
||||
except OSError as e:
|
||||
if errno_from_exception(e) == errno.EINTR:
|
||||
continue
|
||||
raise
|
||||
if pid not in children:
|
||||
continue
|
||||
id = children.pop(pid)
|
||||
if os.WIFSIGNALED(status):
|
||||
gen_log.warning("child %d (pid %d) killed by signal %d, restarting",
|
||||
id, pid, os.WTERMSIG(status))
|
||||
elif os.WEXITSTATUS(status) != 0:
|
||||
gen_log.warning("child %d (pid %d) exited with status %d, restarting",
|
||||
id, pid, os.WEXITSTATUS(status))
|
||||
else:
|
||||
gen_log.info("child %d (pid %d) exited normally", id, pid)
|
||||
continue
|
||||
num_restarts += 1
|
||||
if num_restarts > max_restarts:
|
||||
raise RuntimeError("Too many child restarts, giving up")
|
||||
new_id = start_child(id)
|
||||
if new_id is not None:
|
||||
return new_id
|
||||
# All child processes exited cleanly, so exit the master process
|
||||
# instead of just returning to right after the call to
|
||||
# fork_processes (which will probably just start up another IOLoop
|
||||
# unless the caller checks the return value).
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
def task_id():
|
||||
"""Returns the current task id, if any.
|
||||
|
||||
Returns None if this process was not created by `fork_processes`.
|
||||
"""
|
||||
global _task_id
|
||||
return _task_id
|
||||
|
||||
|
||||
class Subprocess(object):
|
||||
"""Wraps ``subprocess.Popen`` with IOStream support.
|
||||
|
||||
The constructor is the same as ``subprocess.Popen`` with the following
|
||||
additions:
|
||||
|
||||
* ``stdin``, ``stdout``, and ``stderr`` may have the value
|
||||
``tornado.process.Subprocess.STREAM``, which will make the corresponding
|
||||
attribute of the resulting Subprocess a `.PipeIOStream`. If this option
|
||||
is used, the caller is responsible for closing the streams when done
|
||||
with them.
|
||||
|
||||
The ``Subprocess.STREAM`` option and the ``set_exit_callback`` and
|
||||
``wait_for_exit`` methods do not work on Windows. There is
|
||||
therefore no reason to use this class instead of
|
||||
``subprocess.Popen`` on that platform.
|
||||
|
||||
.. versionchanged:: 5.0
|
||||
The ``io_loop`` argument (deprecated since version 4.1) has been removed.
|
||||
|
||||
"""
|
||||
STREAM = object()
|
||||
|
||||
_initialized = False
|
||||
_waiting = {} # type: ignore
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.io_loop = ioloop.IOLoop.current()
|
||||
# All FDs we create should be closed on error; those in to_close
|
||||
# should be closed in the parent process on success.
|
||||
pipe_fds = []
|
||||
to_close = []
|
||||
if kwargs.get('stdin') is Subprocess.STREAM:
|
||||
in_r, in_w = _pipe_cloexec()
|
||||
kwargs['stdin'] = in_r
|
||||
pipe_fds.extend((in_r, in_w))
|
||||
to_close.append(in_r)
|
||||
self.stdin = PipeIOStream(in_w)
|
||||
if kwargs.get('stdout') is Subprocess.STREAM:
|
||||
out_r, out_w = _pipe_cloexec()
|
||||
kwargs['stdout'] = out_w
|
||||
pipe_fds.extend((out_r, out_w))
|
||||
to_close.append(out_w)
|
||||
self.stdout = PipeIOStream(out_r)
|
||||
if kwargs.get('stderr') is Subprocess.STREAM:
|
||||
err_r, err_w = _pipe_cloexec()
|
||||
kwargs['stderr'] = err_w
|
||||
pipe_fds.extend((err_r, err_w))
|
||||
to_close.append(err_w)
|
||||
self.stderr = PipeIOStream(err_r)
|
||||
try:
|
||||
self.proc = subprocess.Popen(*args, **kwargs)
|
||||
except:
|
||||
for fd in pipe_fds:
|
||||
os.close(fd)
|
||||
raise
|
||||
for fd in to_close:
|
||||
os.close(fd)
|
||||
for attr in ['stdin', 'stdout', 'stderr', 'pid']:
|
||||
if not hasattr(self, attr): # don't clobber streams set above
|
||||
setattr(self, attr, getattr(self.proc, attr))
|
||||
self._exit_callback = None
|
||||
self.returncode = None
|
||||
|
||||
def set_exit_callback(self, callback):
|
||||
"""Runs ``callback`` when this process exits.
|
||||
|
||||
The callback takes one argument, the return code of the process.
|
||||
|
||||
This method uses a ``SIGCHLD`` handler, which is a global setting
|
||||
and may conflict if you have other libraries trying to handle the
|
||||
same signal. If you are using more than one ``IOLoop`` it may
|
||||
be necessary to call `Subprocess.initialize` first to designate
|
||||
one ``IOLoop`` to run the signal handlers.
|
||||
|
||||
In many cases a close callback on the stdout or stderr streams
|
||||
can be used as an alternative to an exit callback if the
|
||||
signal handler is causing a problem.
|
||||
"""
|
||||
self._exit_callback = stack_context.wrap(callback)
|
||||
Subprocess.initialize()
|
||||
Subprocess._waiting[self.pid] = self
|
||||
Subprocess._try_cleanup_process(self.pid)
|
||||
|
||||
def wait_for_exit(self, raise_error=True):
|
||||
"""Returns a `.Future` which resolves when the process exits.
|
||||
|
||||
Usage::
|
||||
|
||||
ret = yield proc.wait_for_exit()
|
||||
|
||||
This is a coroutine-friendly alternative to `set_exit_callback`
|
||||
(and a replacement for the blocking `subprocess.Popen.wait`).
|
||||
|
||||
By default, raises `subprocess.CalledProcessError` if the process
|
||||
has a non-zero exit status. Use ``wait_for_exit(raise_error=False)``
|
||||
to suppress this behavior and return the exit status without raising.
|
||||
|
||||
.. versionadded:: 4.2
|
||||
"""
|
||||
future = Future()
|
||||
|
||||
def callback(ret):
|
||||
if ret != 0 and raise_error:
|
||||
# Unfortunately we don't have the original args any more.
|
||||
future.set_exception(CalledProcessError(ret, None))
|
||||
else:
|
||||
future_set_result_unless_cancelled(future, ret)
|
||||
self.set_exit_callback(callback)
|
||||
return future
|
||||
|
||||
@classmethod
|
||||
def initialize(cls):
|
||||
"""Initializes the ``SIGCHLD`` handler.
|
||||
|
||||
The signal handler is run on an `.IOLoop` to avoid locking issues.
|
||||
Note that the `.IOLoop` used for signal handling need not be the
|
||||
same one used by individual Subprocess objects (as long as the
|
||||
``IOLoops`` are each running in separate threads).
|
||||
|
||||
.. versionchanged:: 5.0
|
||||
The ``io_loop`` argument (deprecated since version 4.1) has been
|
||||
removed.
|
||||
"""
|
||||
if cls._initialized:
|
||||
return
|
||||
io_loop = ioloop.IOLoop.current()
|
||||
cls._old_sigchld = signal.signal(
|
||||
signal.SIGCHLD,
|
||||
lambda sig, frame: io_loop.add_callback_from_signal(cls._cleanup))
|
||||
cls._initialized = True
|
||||
|
||||
@classmethod
|
||||
def uninitialize(cls):
|
||||
"""Removes the ``SIGCHLD`` handler."""
|
||||
if not cls._initialized:
|
||||
return
|
||||
signal.signal(signal.SIGCHLD, cls._old_sigchld)
|
||||
cls._initialized = False
|
||||
|
||||
@classmethod
|
||||
def _cleanup(cls):
|
||||
for pid in list(cls._waiting.keys()): # make a copy
|
||||
cls._try_cleanup_process(pid)
|
||||
|
||||
@classmethod
|
||||
def _try_cleanup_process(cls, pid):
|
||||
try:
|
||||
ret_pid, status = os.waitpid(pid, os.WNOHANG)
|
||||
except OSError as e:
|
||||
if errno_from_exception(e) == errno.ECHILD:
|
||||
return
|
||||
if ret_pid == 0:
|
||||
return
|
||||
assert ret_pid == pid
|
||||
subproc = cls._waiting.pop(pid)
|
||||
subproc.io_loop.add_callback_from_signal(
|
||||
subproc._set_returncode, status)
|
||||
|
||||
def _set_returncode(self, status):
|
||||
if os.WIFSIGNALED(status):
|
||||
self.returncode = -os.WTERMSIG(status)
|
||||
else:
|
||||
assert os.WIFEXITED(status)
|
||||
self.returncode = os.WEXITSTATUS(status)
|
||||
# We've taken over wait() duty from the subprocess.Popen
|
||||
# object. If we don't inform it of the process's return code,
|
||||
# it will log a warning at destruction in python 3.6+.
|
||||
self.proc.returncode = self.returncode
|
||||
if self._exit_callback:
|
||||
callback = self._exit_callback
|
||||
self._exit_callback = None
|
||||
callback(self.returncode)
|
||||
Executable
+379
@@ -0,0 +1,379 @@
|
||||
# Copyright 2015 The Tornado Authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""Asynchronous queues for coroutines. These classes are very similar
|
||||
to those provided in the standard library's `asyncio package
|
||||
<https://docs.python.org/3/library/asyncio-queue.html>`_.
|
||||
|
||||
.. warning::
|
||||
|
||||
Unlike the standard library's `queue` module, the classes defined here
|
||||
are *not* thread-safe. To use these queues from another thread,
|
||||
use `.IOLoop.add_callback` to transfer control to the `.IOLoop` thread
|
||||
before calling any queue methods.
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import collections
|
||||
import heapq
|
||||
|
||||
from tornado import gen, ioloop
|
||||
from tornado.concurrent import Future, future_set_result_unless_cancelled
|
||||
from tornado.locks import Event
|
||||
|
||||
__all__ = ['Queue', 'PriorityQueue', 'LifoQueue', 'QueueFull', 'QueueEmpty']
|
||||
|
||||
|
||||
class QueueEmpty(Exception):
|
||||
"""Raised by `.Queue.get_nowait` when the queue has no items."""
|
||||
pass
|
||||
|
||||
|
||||
class QueueFull(Exception):
|
||||
"""Raised by `.Queue.put_nowait` when a queue is at its maximum size."""
|
||||
pass
|
||||
|
||||
|
||||
def _set_timeout(future, timeout):
|
||||
if timeout:
|
||||
def on_timeout():
|
||||
if not future.done():
|
||||
future.set_exception(gen.TimeoutError())
|
||||
io_loop = ioloop.IOLoop.current()
|
||||
timeout_handle = io_loop.add_timeout(timeout, on_timeout)
|
||||
future.add_done_callback(
|
||||
lambda _: io_loop.remove_timeout(timeout_handle))
|
||||
|
||||
|
||||
class _QueueIterator(object):
|
||||
def __init__(self, q):
|
||||
self.q = q
|
||||
|
||||
def __anext__(self):
|
||||
return self.q.get()
|
||||
|
||||
|
||||
class Queue(object):
|
||||
"""Coordinate producer and consumer coroutines.
|
||||
|
||||
If maxsize is 0 (the default) the queue size is unbounded.
|
||||
|
||||
.. testcode::
|
||||
|
||||
from tornado import gen
|
||||
from tornado.ioloop import IOLoop
|
||||
from tornado.queues import Queue
|
||||
|
||||
q = Queue(maxsize=2)
|
||||
|
||||
async def consumer():
|
||||
async for item in q:
|
||||
try:
|
||||
print('Doing work on %s' % item)
|
||||
await gen.sleep(0.01)
|
||||
finally:
|
||||
q.task_done()
|
||||
|
||||
async def producer():
|
||||
for item in range(5):
|
||||
await q.put(item)
|
||||
print('Put %s' % item)
|
||||
|
||||
async def main():
|
||||
# Start consumer without waiting (since it never finishes).
|
||||
IOLoop.current().spawn_callback(consumer)
|
||||
await producer() # Wait for producer to put all tasks.
|
||||
await q.join() # Wait for consumer to finish all tasks.
|
||||
print('Done')
|
||||
|
||||
IOLoop.current().run_sync(main)
|
||||
|
||||
.. testoutput::
|
||||
|
||||
Put 0
|
||||
Put 1
|
||||
Doing work on 0
|
||||
Put 2
|
||||
Doing work on 1
|
||||
Put 3
|
||||
Doing work on 2
|
||||
Put 4
|
||||
Doing work on 3
|
||||
Doing work on 4
|
||||
Done
|
||||
|
||||
|
||||
In versions of Python without native coroutines (before 3.5),
|
||||
``consumer()`` could be written as::
|
||||
|
||||
@gen.coroutine
|
||||
def consumer():
|
||||
while True:
|
||||
item = yield q.get()
|
||||
try:
|
||||
print('Doing work on %s' % item)
|
||||
yield gen.sleep(0.01)
|
||||
finally:
|
||||
q.task_done()
|
||||
|
||||
.. versionchanged:: 4.3
|
||||
Added ``async for`` support in Python 3.5.
|
||||
|
||||
"""
|
||||
def __init__(self, maxsize=0):
|
||||
if maxsize is None:
|
||||
raise TypeError("maxsize can't be None")
|
||||
|
||||
if maxsize < 0:
|
||||
raise ValueError("maxsize can't be negative")
|
||||
|
||||
self._maxsize = maxsize
|
||||
self._init()
|
||||
self._getters = collections.deque([]) # Futures.
|
||||
self._putters = collections.deque([]) # Pairs of (item, Future).
|
||||
self._unfinished_tasks = 0
|
||||
self._finished = Event()
|
||||
self._finished.set()
|
||||
|
||||
@property
|
||||
def maxsize(self):
|
||||
"""Number of items allowed in the queue."""
|
||||
return self._maxsize
|
||||
|
||||
def qsize(self):
|
||||
"""Number of items in the queue."""
|
||||
return len(self._queue)
|
||||
|
||||
def empty(self):
|
||||
return not self._queue
|
||||
|
||||
def full(self):
|
||||
if self.maxsize == 0:
|
||||
return False
|
||||
else:
|
||||
return self.qsize() >= self.maxsize
|
||||
|
||||
def put(self, item, timeout=None):
|
||||
"""Put an item into the queue, perhaps waiting until there is room.
|
||||
|
||||
Returns a Future, which raises `tornado.util.TimeoutError` after a
|
||||
timeout.
|
||||
|
||||
``timeout`` may be a number denoting a time (on the same
|
||||
scale as `tornado.ioloop.IOLoop.time`, normally `time.time`), or a
|
||||
`datetime.timedelta` object for a deadline relative to the
|
||||
current time.
|
||||
"""
|
||||
future = Future()
|
||||
try:
|
||||
self.put_nowait(item)
|
||||
except QueueFull:
|
||||
self._putters.append((item, future))
|
||||
_set_timeout(future, timeout)
|
||||
else:
|
||||
future.set_result(None)
|
||||
return future
|
||||
|
||||
def put_nowait(self, item):
|
||||
"""Put an item into the queue without blocking.
|
||||
|
||||
If no free slot is immediately available, raise `QueueFull`.
|
||||
"""
|
||||
self._consume_expired()
|
||||
if self._getters:
|
||||
assert self.empty(), "queue non-empty, why are getters waiting?"
|
||||
getter = self._getters.popleft()
|
||||
self.__put_internal(item)
|
||||
future_set_result_unless_cancelled(getter, self._get())
|
||||
elif self.full():
|
||||
raise QueueFull
|
||||
else:
|
||||
self.__put_internal(item)
|
||||
|
||||
def get(self, timeout=None):
|
||||
"""Remove and return an item from the queue.
|
||||
|
||||
Returns a Future which resolves once an item is available, or raises
|
||||
`tornado.util.TimeoutError` after a timeout.
|
||||
|
||||
``timeout`` may be a number denoting a time (on the same
|
||||
scale as `tornado.ioloop.IOLoop.time`, normally `time.time`), or a
|
||||
`datetime.timedelta` object for a deadline relative to the
|
||||
current time.
|
||||
"""
|
||||
future = Future()
|
||||
try:
|
||||
future.set_result(self.get_nowait())
|
||||
except QueueEmpty:
|
||||
self._getters.append(future)
|
||||
_set_timeout(future, timeout)
|
||||
return future
|
||||
|
||||
def get_nowait(self):
|
||||
"""Remove and return an item from the queue without blocking.
|
||||
|
||||
Return an item if one is immediately available, else raise
|
||||
`QueueEmpty`.
|
||||
"""
|
||||
self._consume_expired()
|
||||
if self._putters:
|
||||
assert self.full(), "queue not full, why are putters waiting?"
|
||||
item, putter = self._putters.popleft()
|
||||
self.__put_internal(item)
|
||||
future_set_result_unless_cancelled(putter, None)
|
||||
return self._get()
|
||||
elif self.qsize():
|
||||
return self._get()
|
||||
else:
|
||||
raise QueueEmpty
|
||||
|
||||
def task_done(self):
|
||||
"""Indicate that a formerly enqueued task is complete.
|
||||
|
||||
Used by queue consumers. For each `.get` used to fetch a task, a
|
||||
subsequent call to `.task_done` tells the queue that the processing
|
||||
on the task is complete.
|
||||
|
||||
If a `.join` is blocking, it resumes when all items have been
|
||||
processed; that is, when every `.put` is matched by a `.task_done`.
|
||||
|
||||
Raises `ValueError` if called more times than `.put`.
|
||||
"""
|
||||
if self._unfinished_tasks <= 0:
|
||||
raise ValueError('task_done() called too many times')
|
||||
self._unfinished_tasks -= 1
|
||||
if self._unfinished_tasks == 0:
|
||||
self._finished.set()
|
||||
|
||||
def join(self, timeout=None):
|
||||
"""Block until all items in the queue are processed.
|
||||
|
||||
Returns a Future, which raises `tornado.util.TimeoutError` after a
|
||||
timeout.
|
||||
"""
|
||||
return self._finished.wait(timeout)
|
||||
|
||||
def __aiter__(self):
|
||||
return _QueueIterator(self)
|
||||
|
||||
# These three are overridable in subclasses.
|
||||
def _init(self):
|
||||
self._queue = collections.deque()
|
||||
|
||||
def _get(self):
|
||||
return self._queue.popleft()
|
||||
|
||||
def _put(self, item):
|
||||
self._queue.append(item)
|
||||
# End of the overridable methods.
|
||||
|
||||
def __put_internal(self, item):
|
||||
self._unfinished_tasks += 1
|
||||
self._finished.clear()
|
||||
self._put(item)
|
||||
|
||||
def _consume_expired(self):
|
||||
# Remove timed-out waiters.
|
||||
while self._putters and self._putters[0][1].done():
|
||||
self._putters.popleft()
|
||||
|
||||
while self._getters and self._getters[0].done():
|
||||
self._getters.popleft()
|
||||
|
||||
def __repr__(self):
|
||||
return '<%s at %s %s>' % (
|
||||
type(self).__name__, hex(id(self)), self._format())
|
||||
|
||||
def __str__(self):
|
||||
return '<%s %s>' % (type(self).__name__, self._format())
|
||||
|
||||
def _format(self):
|
||||
result = 'maxsize=%r' % (self.maxsize, )
|
||||
if getattr(self, '_queue', None):
|
||||
result += ' queue=%r' % self._queue
|
||||
if self._getters:
|
||||
result += ' getters[%s]' % len(self._getters)
|
||||
if self._putters:
|
||||
result += ' putters[%s]' % len(self._putters)
|
||||
if self._unfinished_tasks:
|
||||
result += ' tasks=%s' % self._unfinished_tasks
|
||||
return result
|
||||
|
||||
|
||||
class PriorityQueue(Queue):
|
||||
"""A `.Queue` that retrieves entries in priority order, lowest first.
|
||||
|
||||
Entries are typically tuples like ``(priority number, data)``.
|
||||
|
||||
.. testcode::
|
||||
|
||||
from tornado.queues import PriorityQueue
|
||||
|
||||
q = PriorityQueue()
|
||||
q.put((1, 'medium-priority item'))
|
||||
q.put((0, 'high-priority item'))
|
||||
q.put((10, 'low-priority item'))
|
||||
|
||||
print(q.get_nowait())
|
||||
print(q.get_nowait())
|
||||
print(q.get_nowait())
|
||||
|
||||
.. testoutput::
|
||||
|
||||
(0, 'high-priority item')
|
||||
(1, 'medium-priority item')
|
||||
(10, 'low-priority item')
|
||||
"""
|
||||
def _init(self):
|
||||
self._queue = []
|
||||
|
||||
def _put(self, item):
|
||||
heapq.heappush(self._queue, item)
|
||||
|
||||
def _get(self):
|
||||
return heapq.heappop(self._queue)
|
||||
|
||||
|
||||
class LifoQueue(Queue):
|
||||
"""A `.Queue` that retrieves the most recently put items first.
|
||||
|
||||
.. testcode::
|
||||
|
||||
from tornado.queues import LifoQueue
|
||||
|
||||
q = LifoQueue()
|
||||
q.put(3)
|
||||
q.put(2)
|
||||
q.put(1)
|
||||
|
||||
print(q.get_nowait())
|
||||
print(q.get_nowait())
|
||||
print(q.get_nowait())
|
||||
|
||||
.. testoutput::
|
||||
|
||||
1
|
||||
2
|
||||
3
|
||||
"""
|
||||
def _init(self):
|
||||
self._queue = []
|
||||
|
||||
def _put(self, item):
|
||||
self._queue.append(item)
|
||||
|
||||
def _get(self):
|
||||
return self._queue.pop()
|
||||
Executable
+641
@@ -0,0 +1,641 @@
|
||||
# Copyright 2015 The Tornado Authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""Flexible routing implementation.
|
||||
|
||||
Tornado routes HTTP requests to appropriate handlers using `Router`
|
||||
class implementations. The `tornado.web.Application` class is a
|
||||
`Router` implementation and may be used directly, or the classes in
|
||||
this module may be used for additional flexibility. The `RuleRouter`
|
||||
class can match on more criteria than `.Application`, or the `Router`
|
||||
interface can be subclassed for maximum customization.
|
||||
|
||||
`Router` interface extends `~.httputil.HTTPServerConnectionDelegate`
|
||||
to provide additional routing capabilities. This also means that any
|
||||
`Router` implementation can be used directly as a ``request_callback``
|
||||
for `~.httpserver.HTTPServer` constructor.
|
||||
|
||||
`Router` subclass must implement a ``find_handler`` method to provide
|
||||
a suitable `~.httputil.HTTPMessageDelegate` instance to handle the
|
||||
request:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
class CustomRouter(Router):
|
||||
def find_handler(self, request, **kwargs):
|
||||
# some routing logic providing a suitable HTTPMessageDelegate instance
|
||||
return MessageDelegate(request.connection)
|
||||
|
||||
class MessageDelegate(HTTPMessageDelegate):
|
||||
def __init__(self, connection):
|
||||
self.connection = connection
|
||||
|
||||
def finish(self):
|
||||
self.connection.write_headers(
|
||||
ResponseStartLine("HTTP/1.1", 200, "OK"),
|
||||
HTTPHeaders({"Content-Length": "2"}),
|
||||
b"OK")
|
||||
self.connection.finish()
|
||||
|
||||
router = CustomRouter()
|
||||
server = HTTPServer(router)
|
||||
|
||||
The main responsibility of `Router` implementation is to provide a
|
||||
mapping from a request to `~.httputil.HTTPMessageDelegate` instance
|
||||
that will handle this request. In the example above we can see that
|
||||
routing is possible even without instantiating an `~.web.Application`.
|
||||
|
||||
For routing to `~.web.RequestHandler` implementations we need an
|
||||
`~.web.Application` instance. `~.web.Application.get_handler_delegate`
|
||||
provides a convenient way to create `~.httputil.HTTPMessageDelegate`
|
||||
for a given request and `~.web.RequestHandler`.
|
||||
|
||||
Here is a simple example of how we can we route to
|
||||
`~.web.RequestHandler` subclasses by HTTP method:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
resources = {}
|
||||
|
||||
class GetResource(RequestHandler):
|
||||
def get(self, path):
|
||||
if path not in resources:
|
||||
raise HTTPError(404)
|
||||
|
||||
self.finish(resources[path])
|
||||
|
||||
class PostResource(RequestHandler):
|
||||
def post(self, path):
|
||||
resources[path] = self.request.body
|
||||
|
||||
class HTTPMethodRouter(Router):
|
||||
def __init__(self, app):
|
||||
self.app = app
|
||||
|
||||
def find_handler(self, request, **kwargs):
|
||||
handler = GetResource if request.method == "GET" else PostResource
|
||||
return self.app.get_handler_delegate(request, handler, path_args=[request.path])
|
||||
|
||||
router = HTTPMethodRouter(Application())
|
||||
server = HTTPServer(router)
|
||||
|
||||
`ReversibleRouter` interface adds the ability to distinguish between
|
||||
the routes and reverse them to the original urls using route's name
|
||||
and additional arguments. `~.web.Application` is itself an
|
||||
implementation of `ReversibleRouter` class.
|
||||
|
||||
`RuleRouter` and `ReversibleRuleRouter` are implementations of
|
||||
`Router` and `ReversibleRouter` interfaces and can be used for
|
||||
creating rule-based routing configurations.
|
||||
|
||||
Rules are instances of `Rule` class. They contain a `Matcher`, which
|
||||
provides the logic for determining whether the rule is a match for a
|
||||
particular request and a target, which can be one of the following.
|
||||
|
||||
1) An instance of `~.httputil.HTTPServerConnectionDelegate`:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
router = RuleRouter([
|
||||
Rule(PathMatches("/handler"), ConnectionDelegate()),
|
||||
# ... more rules
|
||||
])
|
||||
|
||||
class ConnectionDelegate(HTTPServerConnectionDelegate):
|
||||
def start_request(self, server_conn, request_conn):
|
||||
return MessageDelegate(request_conn)
|
||||
|
||||
2) A callable accepting a single argument of `~.httputil.HTTPServerRequest` type:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
router = RuleRouter([
|
||||
Rule(PathMatches("/callable"), request_callable)
|
||||
])
|
||||
|
||||
def request_callable(request):
|
||||
request.write(b"HTTP/1.1 200 OK\\r\\nContent-Length: 2\\r\\n\\r\\nOK")
|
||||
request.finish()
|
||||
|
||||
3) Another `Router` instance:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
router = RuleRouter([
|
||||
Rule(PathMatches("/router.*"), CustomRouter())
|
||||
])
|
||||
|
||||
Of course a nested `RuleRouter` or a `~.web.Application` is allowed:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
router = RuleRouter([
|
||||
Rule(HostMatches("example.com"), RuleRouter([
|
||||
Rule(PathMatches("/app1/.*"), Application([(r"/app1/handler", Handler)]))),
|
||||
]))
|
||||
])
|
||||
|
||||
server = HTTPServer(router)
|
||||
|
||||
In the example below `RuleRouter` is used to route between applications:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
app1 = Application([
|
||||
(r"/app1/handler", Handler1),
|
||||
# other handlers ...
|
||||
])
|
||||
|
||||
app2 = Application([
|
||||
(r"/app2/handler", Handler2),
|
||||
# other handlers ...
|
||||
])
|
||||
|
||||
router = RuleRouter([
|
||||
Rule(PathMatches("/app1.*"), app1),
|
||||
Rule(PathMatches("/app2.*"), app2)
|
||||
])
|
||||
|
||||
server = HTTPServer(router)
|
||||
|
||||
For more information on application-level routing see docs for `~.web.Application`.
|
||||
|
||||
.. versionadded:: 4.5
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import re
|
||||
from functools import partial
|
||||
|
||||
from tornado import httputil
|
||||
from tornado.httpserver import _CallableAdapter
|
||||
from tornado.escape import url_escape, url_unescape, utf8
|
||||
from tornado.log import app_log
|
||||
from tornado.util import basestring_type, import_object, re_unescape, unicode_type
|
||||
|
||||
try:
|
||||
import typing # noqa
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
class Router(httputil.HTTPServerConnectionDelegate):
|
||||
"""Abstract router interface."""
|
||||
|
||||
def find_handler(self, request, **kwargs):
|
||||
# type: (httputil.HTTPServerRequest, typing.Any)->httputil.HTTPMessageDelegate
|
||||
"""Must be implemented to return an appropriate instance of `~.httputil.HTTPMessageDelegate`
|
||||
that can serve the request.
|
||||
Routing implementations may pass additional kwargs to extend the routing logic.
|
||||
|
||||
:arg httputil.HTTPServerRequest request: current HTTP request.
|
||||
:arg kwargs: additional keyword arguments passed by routing implementation.
|
||||
:returns: an instance of `~.httputil.HTTPMessageDelegate` that will be used to
|
||||
process the request.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def start_request(self, server_conn, request_conn):
|
||||
return _RoutingDelegate(self, server_conn, request_conn)
|
||||
|
||||
|
||||
class ReversibleRouter(Router):
|
||||
"""Abstract router interface for routers that can handle named routes
|
||||
and support reversing them to original urls.
|
||||
"""
|
||||
|
||||
def reverse_url(self, name, *args):
|
||||
"""Returns url string for a given route name and arguments
|
||||
or ``None`` if no match is found.
|
||||
|
||||
:arg str name: route name.
|
||||
:arg args: url parameters.
|
||||
:returns: parametrized url string for a given route name (or ``None``).
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class _RoutingDelegate(httputil.HTTPMessageDelegate):
|
||||
def __init__(self, router, server_conn, request_conn):
|
||||
self.server_conn = server_conn
|
||||
self.request_conn = request_conn
|
||||
self.delegate = None
|
||||
self.router = router # type: Router
|
||||
|
||||
def headers_received(self, start_line, headers):
|
||||
request = httputil.HTTPServerRequest(
|
||||
connection=self.request_conn,
|
||||
server_connection=self.server_conn,
|
||||
start_line=start_line, headers=headers)
|
||||
|
||||
self.delegate = self.router.find_handler(request)
|
||||
if self.delegate is None:
|
||||
app_log.debug("Delegate for %s %s request not found",
|
||||
start_line.method, start_line.path)
|
||||
self.delegate = _DefaultMessageDelegate(self.request_conn)
|
||||
|
||||
return self.delegate.headers_received(start_line, headers)
|
||||
|
||||
def data_received(self, chunk):
|
||||
return self.delegate.data_received(chunk)
|
||||
|
||||
def finish(self):
|
||||
self.delegate.finish()
|
||||
|
||||
def on_connection_close(self):
|
||||
self.delegate.on_connection_close()
|
||||
|
||||
|
||||
class _DefaultMessageDelegate(httputil.HTTPMessageDelegate):
|
||||
def __init__(self, connection):
|
||||
self.connection = connection
|
||||
|
||||
def finish(self):
|
||||
self.connection.write_headers(
|
||||
httputil.ResponseStartLine("HTTP/1.1", 404, "Not Found"), httputil.HTTPHeaders())
|
||||
self.connection.finish()
|
||||
|
||||
|
||||
class RuleRouter(Router):
|
||||
"""Rule-based router implementation."""
|
||||
|
||||
def __init__(self, rules=None):
|
||||
"""Constructs a router from an ordered list of rules::
|
||||
|
||||
RuleRouter([
|
||||
Rule(PathMatches("/handler"), Target),
|
||||
# ... more rules
|
||||
])
|
||||
|
||||
You can also omit explicit `Rule` constructor and use tuples of arguments::
|
||||
|
||||
RuleRouter([
|
||||
(PathMatches("/handler"), Target),
|
||||
])
|
||||
|
||||
`PathMatches` is a default matcher, so the example above can be simplified::
|
||||
|
||||
RuleRouter([
|
||||
("/handler", Target),
|
||||
])
|
||||
|
||||
In the examples above, ``Target`` can be a nested `Router` instance, an instance of
|
||||
`~.httputil.HTTPServerConnectionDelegate` or an old-style callable,
|
||||
accepting a request argument.
|
||||
|
||||
:arg rules: a list of `Rule` instances or tuples of `Rule`
|
||||
constructor arguments.
|
||||
"""
|
||||
self.rules = [] # type: typing.List[Rule]
|
||||
if rules:
|
||||
self.add_rules(rules)
|
||||
|
||||
def add_rules(self, rules):
|
||||
"""Appends new rules to the router.
|
||||
|
||||
:arg rules: a list of Rule instances (or tuples of arguments, which are
|
||||
passed to Rule constructor).
|
||||
"""
|
||||
for rule in rules:
|
||||
if isinstance(rule, (tuple, list)):
|
||||
assert len(rule) in (2, 3, 4)
|
||||
if isinstance(rule[0], basestring_type):
|
||||
rule = Rule(PathMatches(rule[0]), *rule[1:])
|
||||
else:
|
||||
rule = Rule(*rule)
|
||||
|
||||
self.rules.append(self.process_rule(rule))
|
||||
|
||||
def process_rule(self, rule):
|
||||
"""Override this method for additional preprocessing of each rule.
|
||||
|
||||
:arg Rule rule: a rule to be processed.
|
||||
:returns: the same or modified Rule instance.
|
||||
"""
|
||||
return rule
|
||||
|
||||
def find_handler(self, request, **kwargs):
|
||||
for rule in self.rules:
|
||||
target_params = rule.matcher.match(request)
|
||||
if target_params is not None:
|
||||
if rule.target_kwargs:
|
||||
target_params['target_kwargs'] = rule.target_kwargs
|
||||
|
||||
delegate = self.get_target_delegate(
|
||||
rule.target, request, **target_params)
|
||||
|
||||
if delegate is not None:
|
||||
return delegate
|
||||
|
||||
return None
|
||||
|
||||
def get_target_delegate(self, target, request, **target_params):
|
||||
"""Returns an instance of `~.httputil.HTTPMessageDelegate` for a
|
||||
Rule's target. This method is called by `~.find_handler` and can be
|
||||
extended to provide additional target types.
|
||||
|
||||
:arg target: a Rule's target.
|
||||
:arg httputil.HTTPServerRequest request: current request.
|
||||
:arg target_params: additional parameters that can be useful
|
||||
for `~.httputil.HTTPMessageDelegate` creation.
|
||||
"""
|
||||
if isinstance(target, Router):
|
||||
return target.find_handler(request, **target_params)
|
||||
|
||||
elif isinstance(target, httputil.HTTPServerConnectionDelegate):
|
||||
return target.start_request(request.server_connection, request.connection)
|
||||
|
||||
elif callable(target):
|
||||
return _CallableAdapter(
|
||||
partial(target, **target_params), request.connection
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class ReversibleRuleRouter(ReversibleRouter, RuleRouter):
|
||||
"""A rule-based router that implements ``reverse_url`` method.
|
||||
|
||||
Each rule added to this router may have a ``name`` attribute that can be
|
||||
used to reconstruct an original uri. The actual reconstruction takes place
|
||||
in a rule's matcher (see `Matcher.reverse`).
|
||||
"""
|
||||
|
||||
def __init__(self, rules=None):
|
||||
self.named_rules = {} # type: typing.Dict[str]
|
||||
super(ReversibleRuleRouter, self).__init__(rules)
|
||||
|
||||
def process_rule(self, rule):
|
||||
rule = super(ReversibleRuleRouter, self).process_rule(rule)
|
||||
|
||||
if rule.name:
|
||||
if rule.name in self.named_rules:
|
||||
app_log.warning(
|
||||
"Multiple handlers named %s; replacing previous value",
|
||||
rule.name)
|
||||
self.named_rules[rule.name] = rule
|
||||
|
||||
return rule
|
||||
|
||||
def reverse_url(self, name, *args):
|
||||
if name in self.named_rules:
|
||||
return self.named_rules[name].matcher.reverse(*args)
|
||||
|
||||
for rule in self.rules:
|
||||
if isinstance(rule.target, ReversibleRouter):
|
||||
reversed_url = rule.target.reverse_url(name, *args)
|
||||
if reversed_url is not None:
|
||||
return reversed_url
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class Rule(object):
|
||||
"""A routing rule."""
|
||||
|
||||
def __init__(self, matcher, target, target_kwargs=None, name=None):
|
||||
"""Constructs a Rule instance.
|
||||
|
||||
:arg Matcher matcher: a `Matcher` instance used for determining
|
||||
whether the rule should be considered a match for a specific
|
||||
request.
|
||||
:arg target: a Rule's target (typically a ``RequestHandler`` or
|
||||
`~.httputil.HTTPServerConnectionDelegate` subclass or even a nested `Router`,
|
||||
depending on routing implementation).
|
||||
:arg dict target_kwargs: a dict of parameters that can be useful
|
||||
at the moment of target instantiation (for example, ``status_code``
|
||||
for a ``RequestHandler`` subclass). They end up in
|
||||
``target_params['target_kwargs']`` of `RuleRouter.get_target_delegate`
|
||||
method.
|
||||
:arg str name: the name of the rule that can be used to find it
|
||||
in `ReversibleRouter.reverse_url` implementation.
|
||||
"""
|
||||
if isinstance(target, str):
|
||||
# import the Module and instantiate the class
|
||||
# Must be a fully qualified name (module.ClassName)
|
||||
target = import_object(target)
|
||||
|
||||
self.matcher = matcher # type: Matcher
|
||||
self.target = target
|
||||
self.target_kwargs = target_kwargs if target_kwargs else {}
|
||||
self.name = name
|
||||
|
||||
def reverse(self, *args):
|
||||
return self.matcher.reverse(*args)
|
||||
|
||||
def __repr__(self):
|
||||
return '%s(%r, %s, kwargs=%r, name=%r)' % \
|
||||
(self.__class__.__name__, self.matcher,
|
||||
self.target, self.target_kwargs, self.name)
|
||||
|
||||
|
||||
class Matcher(object):
|
||||
"""Represents a matcher for request features."""
|
||||
|
||||
def match(self, request):
|
||||
"""Matches current instance against the request.
|
||||
|
||||
:arg httputil.HTTPServerRequest request: current HTTP request
|
||||
:returns: a dict of parameters to be passed to the target handler
|
||||
(for example, ``handler_kwargs``, ``path_args``, ``path_kwargs``
|
||||
can be passed for proper `~.web.RequestHandler` instantiation).
|
||||
An empty dict is a valid (and common) return value to indicate a match
|
||||
when the argument-passing features are not used.
|
||||
``None`` must be returned to indicate that there is no match."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def reverse(self, *args):
|
||||
"""Reconstructs full url from matcher instance and additional arguments."""
|
||||
return None
|
||||
|
||||
|
||||
class AnyMatches(Matcher):
|
||||
"""Matches any request."""
|
||||
|
||||
def match(self, request):
|
||||
return {}
|
||||
|
||||
|
||||
class HostMatches(Matcher):
|
||||
"""Matches requests from hosts specified by ``host_pattern`` regex."""
|
||||
|
||||
def __init__(self, host_pattern):
|
||||
if isinstance(host_pattern, basestring_type):
|
||||
if not host_pattern.endswith("$"):
|
||||
host_pattern += "$"
|
||||
self.host_pattern = re.compile(host_pattern)
|
||||
else:
|
||||
self.host_pattern = host_pattern
|
||||
|
||||
def match(self, request):
|
||||
if self.host_pattern.match(request.host_name):
|
||||
return {}
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class DefaultHostMatches(Matcher):
|
||||
"""Matches requests from host that is equal to application's default_host.
|
||||
Always returns no match if ``X-Real-Ip`` header is present.
|
||||
"""
|
||||
|
||||
def __init__(self, application, host_pattern):
|
||||
self.application = application
|
||||
self.host_pattern = host_pattern
|
||||
|
||||
def match(self, request):
|
||||
# Look for default host if not behind load balancer (for debugging)
|
||||
if "X-Real-Ip" not in request.headers:
|
||||
if self.host_pattern.match(self.application.default_host):
|
||||
return {}
|
||||
return None
|
||||
|
||||
|
||||
class PathMatches(Matcher):
|
||||
"""Matches requests with paths specified by ``path_pattern`` regex."""
|
||||
|
||||
def __init__(self, path_pattern):
|
||||
if isinstance(path_pattern, basestring_type):
|
||||
if not path_pattern.endswith('$'):
|
||||
path_pattern += '$'
|
||||
self.regex = re.compile(path_pattern)
|
||||
else:
|
||||
self.regex = path_pattern
|
||||
|
||||
assert len(self.regex.groupindex) in (0, self.regex.groups), \
|
||||
("groups in url regexes must either be all named or all "
|
||||
"positional: %r" % self.regex.pattern)
|
||||
|
||||
self._path, self._group_count = self._find_groups()
|
||||
|
||||
def match(self, request):
|
||||
match = self.regex.match(request.path)
|
||||
if match is None:
|
||||
return None
|
||||
if not self.regex.groups:
|
||||
return {}
|
||||
|
||||
path_args, path_kwargs = [], {}
|
||||
|
||||
# Pass matched groups to the handler. Since
|
||||
# match.groups() includes both named and
|
||||
# unnamed groups, we want to use either groups
|
||||
# or groupdict but not both.
|
||||
if self.regex.groupindex:
|
||||
path_kwargs = dict(
|
||||
(str(k), _unquote_or_none(v))
|
||||
for (k, v) in match.groupdict().items())
|
||||
else:
|
||||
path_args = [_unquote_or_none(s) for s in match.groups()]
|
||||
|
||||
return dict(path_args=path_args, path_kwargs=path_kwargs)
|
||||
|
||||
def reverse(self, *args):
|
||||
if self._path is None:
|
||||
raise ValueError("Cannot reverse url regex " + self.regex.pattern)
|
||||
assert len(args) == self._group_count, "required number of arguments " \
|
||||
"not found"
|
||||
if not len(args):
|
||||
return self._path
|
||||
converted_args = []
|
||||
for a in args:
|
||||
if not isinstance(a, (unicode_type, bytes)):
|
||||
a = str(a)
|
||||
converted_args.append(url_escape(utf8(a), plus=False))
|
||||
return self._path % tuple(converted_args)
|
||||
|
||||
def _find_groups(self):
|
||||
"""Returns a tuple (reverse string, group count) for a url.
|
||||
|
||||
For example: Given the url pattern /([0-9]{4})/([a-z-]+)/, this method
|
||||
would return ('/%s/%s/', 2).
|
||||
"""
|
||||
pattern = self.regex.pattern
|
||||
if pattern.startswith('^'):
|
||||
pattern = pattern[1:]
|
||||
if pattern.endswith('$'):
|
||||
pattern = pattern[:-1]
|
||||
|
||||
if self.regex.groups != pattern.count('('):
|
||||
# The pattern is too complicated for our simplistic matching,
|
||||
# so we can't support reversing it.
|
||||
return None, None
|
||||
|
||||
pieces = []
|
||||
for fragment in pattern.split('('):
|
||||
if ')' in fragment:
|
||||
paren_loc = fragment.index(')')
|
||||
if paren_loc >= 0:
|
||||
pieces.append('%s' + fragment[paren_loc + 1:])
|
||||
else:
|
||||
try:
|
||||
unescaped_fragment = re_unescape(fragment)
|
||||
except ValueError:
|
||||
# If we can't unescape part of it, we can't
|
||||
# reverse this url.
|
||||
return (None, None)
|
||||
pieces.append(unescaped_fragment)
|
||||
|
||||
return ''.join(pieces), self.regex.groups
|
||||
|
||||
|
||||
class URLSpec(Rule):
|
||||
"""Specifies mappings between URLs and handlers.
|
||||
|
||||
.. versionchanged: 4.5
|
||||
`URLSpec` is now a subclass of a `Rule` with `PathMatches` matcher and is preserved for
|
||||
backwards compatibility.
|
||||
"""
|
||||
def __init__(self, pattern, handler, kwargs=None, name=None):
|
||||
"""Parameters:
|
||||
|
||||
* ``pattern``: Regular expression to be matched. Any capturing
|
||||
groups in the regex will be passed in to the handler's
|
||||
get/post/etc methods as arguments (by keyword if named, by
|
||||
position if unnamed. Named and unnamed capturing groups
|
||||
may not be mixed in the same rule).
|
||||
|
||||
* ``handler``: `~.web.RequestHandler` subclass to be invoked.
|
||||
|
||||
* ``kwargs`` (optional): A dictionary of additional arguments
|
||||
to be passed to the handler's constructor.
|
||||
|
||||
* ``name`` (optional): A name for this handler. Used by
|
||||
`~.web.Application.reverse_url`.
|
||||
|
||||
"""
|
||||
super(URLSpec, self).__init__(PathMatches(pattern), handler, kwargs, name)
|
||||
|
||||
self.regex = self.matcher.regex
|
||||
self.handler_class = self.target
|
||||
self.kwargs = kwargs
|
||||
|
||||
def __repr__(self):
|
||||
return '%s(%r, %s, kwargs=%r, name=%r)' % \
|
||||
(self.__class__.__name__, self.regex.pattern,
|
||||
self.handler_class, self.kwargs, self.name)
|
||||
|
||||
|
||||
def _unquote_or_none(s):
|
||||
"""None-safe wrapper around url_unescape to handle unmatched optional
|
||||
groups correctly.
|
||||
|
||||
Note that args are passed as bytes so the handler can decide what
|
||||
encoding to use.
|
||||
"""
|
||||
if s is None:
|
||||
return s
|
||||
return url_unescape(s, encoding=None, plus=False)
|
||||
Executable
+566
@@ -0,0 +1,566 @@
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
from tornado.escape import _unicode
|
||||
from tornado import gen
|
||||
from tornado.httpclient import HTTPResponse, HTTPError, AsyncHTTPClient, main, _RequestProxy
|
||||
from tornado import httputil
|
||||
from tornado.http1connection import HTTP1Connection, HTTP1ConnectionParameters
|
||||
from tornado.ioloop import IOLoop
|
||||
from tornado.iostream import StreamClosedError
|
||||
from tornado.netutil import Resolver, OverrideResolver, _client_ssl_defaults
|
||||
from tornado.log import gen_log
|
||||
from tornado import stack_context
|
||||
from tornado.tcpclient import TCPClient
|
||||
from tornado.util import PY3
|
||||
|
||||
import base64
|
||||
import collections
|
||||
import copy
|
||||
import functools
|
||||
import re
|
||||
import socket
|
||||
import sys
|
||||
import time
|
||||
from io import BytesIO
|
||||
|
||||
|
||||
if PY3:
|
||||
import urllib.parse as urlparse
|
||||
else:
|
||||
import urlparse
|
||||
|
||||
try:
|
||||
import ssl
|
||||
except ImportError:
|
||||
# ssl is not available on Google App Engine.
|
||||
ssl = None
|
||||
|
||||
|
||||
class HTTPTimeoutError(HTTPError):
|
||||
"""Error raised by SimpleAsyncHTTPClient on timeout.
|
||||
|
||||
For historical reasons, this is a subclass of `.HTTPClientError`
|
||||
which simulates a response code of 599.
|
||||
|
||||
.. versionadded:: 5.1
|
||||
"""
|
||||
def __init__(self, message):
|
||||
super(HTTPTimeoutError, self).__init__(599, message=message)
|
||||
|
||||
def __str__(self):
|
||||
return self.message
|
||||
|
||||
|
||||
class HTTPStreamClosedError(HTTPError):
|
||||
"""Error raised by SimpleAsyncHTTPClient when the underlying stream is closed.
|
||||
|
||||
When a more specific exception is available (such as `ConnectionResetError`),
|
||||
it may be raised instead of this one.
|
||||
|
||||
For historical reasons, this is a subclass of `.HTTPClientError`
|
||||
which simulates a response code of 599.
|
||||
|
||||
.. versionadded:: 5.1
|
||||
"""
|
||||
def __init__(self, message):
|
||||
super(HTTPStreamClosedError, self).__init__(599, message=message)
|
||||
|
||||
def __str__(self):
|
||||
return self.message
|
||||
|
||||
|
||||
class SimpleAsyncHTTPClient(AsyncHTTPClient):
|
||||
"""Non-blocking HTTP client with no external dependencies.
|
||||
|
||||
This class implements an HTTP 1.1 client on top of Tornado's IOStreams.
|
||||
Some features found in the curl-based AsyncHTTPClient are not yet
|
||||
supported. In particular, proxies are not supported, connections
|
||||
are not reused, and callers cannot select the network interface to be
|
||||
used.
|
||||
"""
|
||||
def initialize(self, max_clients=10,
|
||||
hostname_mapping=None, max_buffer_size=104857600,
|
||||
resolver=None, defaults=None, max_header_size=None,
|
||||
max_body_size=None):
|
||||
"""Creates a AsyncHTTPClient.
|
||||
|
||||
Only a single AsyncHTTPClient instance exists per IOLoop
|
||||
in order to provide limitations on the number of pending connections.
|
||||
``force_instance=True`` may be used to suppress this behavior.
|
||||
|
||||
Note that because of this implicit reuse, unless ``force_instance``
|
||||
is used, only the first call to the constructor actually uses
|
||||
its arguments. It is recommended to use the ``configure`` method
|
||||
instead of the constructor to ensure that arguments take effect.
|
||||
|
||||
``max_clients`` is the number of concurrent requests that can be
|
||||
in progress; when this limit is reached additional requests will be
|
||||
queued. Note that time spent waiting in this queue still counts
|
||||
against the ``request_timeout``.
|
||||
|
||||
``hostname_mapping`` is a dictionary mapping hostnames to IP addresses.
|
||||
It can be used to make local DNS changes when modifying system-wide
|
||||
settings like ``/etc/hosts`` is not possible or desirable (e.g. in
|
||||
unittests).
|
||||
|
||||
``max_buffer_size`` (default 100MB) is the number of bytes
|
||||
that can be read into memory at once. ``max_body_size``
|
||||
(defaults to ``max_buffer_size``) is the largest response body
|
||||
that the client will accept. Without a
|
||||
``streaming_callback``, the smaller of these two limits
|
||||
applies; with a ``streaming_callback`` only ``max_body_size``
|
||||
does.
|
||||
|
||||
.. versionchanged:: 4.2
|
||||
Added the ``max_body_size`` argument.
|
||||
"""
|
||||
super(SimpleAsyncHTTPClient, self).initialize(defaults=defaults)
|
||||
self.max_clients = max_clients
|
||||
self.queue = collections.deque()
|
||||
self.active = {}
|
||||
self.waiting = {}
|
||||
self.max_buffer_size = max_buffer_size
|
||||
self.max_header_size = max_header_size
|
||||
self.max_body_size = max_body_size
|
||||
# TCPClient could create a Resolver for us, but we have to do it
|
||||
# ourselves to support hostname_mapping.
|
||||
if resolver:
|
||||
self.resolver = resolver
|
||||
self.own_resolver = False
|
||||
else:
|
||||
self.resolver = Resolver()
|
||||
self.own_resolver = True
|
||||
if hostname_mapping is not None:
|
||||
self.resolver = OverrideResolver(resolver=self.resolver,
|
||||
mapping=hostname_mapping)
|
||||
self.tcp_client = TCPClient(resolver=self.resolver)
|
||||
|
||||
def close(self):
|
||||
super(SimpleAsyncHTTPClient, self).close()
|
||||
if self.own_resolver:
|
||||
self.resolver.close()
|
||||
self.tcp_client.close()
|
||||
|
||||
def fetch_impl(self, request, callback):
|
||||
key = object()
|
||||
self.queue.append((key, request, callback))
|
||||
if not len(self.active) < self.max_clients:
|
||||
timeout_handle = self.io_loop.add_timeout(
|
||||
self.io_loop.time() + min(request.connect_timeout,
|
||||
request.request_timeout),
|
||||
functools.partial(self._on_timeout, key, "in request queue"))
|
||||
else:
|
||||
timeout_handle = None
|
||||
self.waiting[key] = (request, callback, timeout_handle)
|
||||
self._process_queue()
|
||||
if self.queue:
|
||||
gen_log.debug("max_clients limit reached, request queued. "
|
||||
"%d active, %d queued requests." % (
|
||||
len(self.active), len(self.queue)))
|
||||
|
||||
def _process_queue(self):
|
||||
with stack_context.NullContext():
|
||||
while self.queue and len(self.active) < self.max_clients:
|
||||
key, request, callback = self.queue.popleft()
|
||||
if key not in self.waiting:
|
||||
continue
|
||||
self._remove_timeout(key)
|
||||
self.active[key] = (request, callback)
|
||||
release_callback = functools.partial(self._release_fetch, key)
|
||||
self._handle_request(request, release_callback, callback)
|
||||
|
||||
def _connection_class(self):
|
||||
return _HTTPConnection
|
||||
|
||||
def _handle_request(self, request, release_callback, final_callback):
|
||||
self._connection_class()(
|
||||
self, request, release_callback,
|
||||
final_callback, self.max_buffer_size, self.tcp_client,
|
||||
self.max_header_size, self.max_body_size)
|
||||
|
||||
def _release_fetch(self, key):
|
||||
del self.active[key]
|
||||
self._process_queue()
|
||||
|
||||
def _remove_timeout(self, key):
|
||||
if key in self.waiting:
|
||||
request, callback, timeout_handle = self.waiting[key]
|
||||
if timeout_handle is not None:
|
||||
self.io_loop.remove_timeout(timeout_handle)
|
||||
del self.waiting[key]
|
||||
|
||||
def _on_timeout(self, key, info=None):
|
||||
"""Timeout callback of request.
|
||||
|
||||
Construct a timeout HTTPResponse when a timeout occurs.
|
||||
|
||||
:arg object key: A simple object to mark the request.
|
||||
:info string key: More detailed timeout information.
|
||||
"""
|
||||
request, callback, timeout_handle = self.waiting[key]
|
||||
self.queue.remove((key, request, callback))
|
||||
|
||||
error_message = "Timeout {0}".format(info) if info else "Timeout"
|
||||
timeout_response = HTTPResponse(
|
||||
request, 599, error=HTTPTimeoutError(error_message),
|
||||
request_time=self.io_loop.time() - request.start_time)
|
||||
self.io_loop.add_callback(callback, timeout_response)
|
||||
del self.waiting[key]
|
||||
|
||||
|
||||
class _HTTPConnection(httputil.HTTPMessageDelegate):
|
||||
_SUPPORTED_METHODS = set(["GET", "HEAD", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"])
|
||||
|
||||
def __init__(self, client, request, release_callback,
|
||||
final_callback, max_buffer_size, tcp_client,
|
||||
max_header_size, max_body_size):
|
||||
self.io_loop = IOLoop.current()
|
||||
self.start_time = self.io_loop.time()
|
||||
self.start_wall_time = time.time()
|
||||
self.client = client
|
||||
self.request = request
|
||||
self.release_callback = release_callback
|
||||
self.final_callback = final_callback
|
||||
self.max_buffer_size = max_buffer_size
|
||||
self.tcp_client = tcp_client
|
||||
self.max_header_size = max_header_size
|
||||
self.max_body_size = max_body_size
|
||||
self.code = None
|
||||
self.headers = None
|
||||
self.chunks = []
|
||||
self._decompressor = None
|
||||
# Timeout handle returned by IOLoop.add_timeout
|
||||
self._timeout = None
|
||||
self._sockaddr = None
|
||||
IOLoop.current().add_callback(self.run)
|
||||
|
||||
@gen.coroutine
|
||||
def run(self):
|
||||
try:
|
||||
self.parsed = urlparse.urlsplit(_unicode(self.request.url))
|
||||
if self.parsed.scheme not in ("http", "https"):
|
||||
raise ValueError("Unsupported url scheme: %s" %
|
||||
self.request.url)
|
||||
# urlsplit results have hostname and port results, but they
|
||||
# didn't support ipv6 literals until python 2.7.
|
||||
netloc = self.parsed.netloc
|
||||
if "@" in netloc:
|
||||
userpass, _, netloc = netloc.rpartition("@")
|
||||
host, port = httputil.split_host_and_port(netloc)
|
||||
if port is None:
|
||||
port = 443 if self.parsed.scheme == "https" else 80
|
||||
if re.match(r'^\[.*\]$', host):
|
||||
# raw ipv6 addresses in urls are enclosed in brackets
|
||||
host = host[1:-1]
|
||||
self.parsed_hostname = host # save final host for _on_connect
|
||||
|
||||
if self.request.allow_ipv6 is False:
|
||||
af = socket.AF_INET
|
||||
else:
|
||||
af = socket.AF_UNSPEC
|
||||
|
||||
ssl_options = self._get_ssl_options(self.parsed.scheme)
|
||||
|
||||
timeout = min(self.request.connect_timeout, self.request.request_timeout)
|
||||
if timeout:
|
||||
self._timeout = self.io_loop.add_timeout(
|
||||
self.start_time + timeout,
|
||||
stack_context.wrap(functools.partial(self._on_timeout, "while connecting")))
|
||||
stream = yield self.tcp_client.connect(
|
||||
host, port, af=af,
|
||||
ssl_options=ssl_options,
|
||||
max_buffer_size=self.max_buffer_size)
|
||||
|
||||
if self.final_callback is None:
|
||||
# final_callback is cleared if we've hit our timeout.
|
||||
stream.close()
|
||||
return
|
||||
self.stream = stream
|
||||
self.stream.set_close_callback(self.on_connection_close)
|
||||
self._remove_timeout()
|
||||
if self.final_callback is None:
|
||||
return
|
||||
if self.request.request_timeout:
|
||||
self._timeout = self.io_loop.add_timeout(
|
||||
self.start_time + self.request.request_timeout,
|
||||
stack_context.wrap(functools.partial(self._on_timeout, "during request")))
|
||||
if (self.request.method not in self._SUPPORTED_METHODS and
|
||||
not self.request.allow_nonstandard_methods):
|
||||
raise KeyError("unknown method %s" % self.request.method)
|
||||
for key in ('network_interface',
|
||||
'proxy_host', 'proxy_port',
|
||||
'proxy_username', 'proxy_password',
|
||||
'proxy_auth_mode'):
|
||||
if getattr(self.request, key, None):
|
||||
raise NotImplementedError('%s not supported' % key)
|
||||
if "Connection" not in self.request.headers:
|
||||
self.request.headers["Connection"] = "close"
|
||||
if "Host" not in self.request.headers:
|
||||
if '@' in self.parsed.netloc:
|
||||
self.request.headers["Host"] = self.parsed.netloc.rpartition('@')[-1]
|
||||
else:
|
||||
self.request.headers["Host"] = self.parsed.netloc
|
||||
username, password = None, None
|
||||
if self.parsed.username is not None:
|
||||
username, password = self.parsed.username, self.parsed.password
|
||||
elif self.request.auth_username is not None:
|
||||
username = self.request.auth_username
|
||||
password = self.request.auth_password or ''
|
||||
if username is not None:
|
||||
if self.request.auth_mode not in (None, "basic"):
|
||||
raise ValueError("unsupported auth_mode %s",
|
||||
self.request.auth_mode)
|
||||
self.request.headers["Authorization"] = (
|
||||
b"Basic " + base64.b64encode(
|
||||
httputil.encode_username_password(username, password)))
|
||||
if self.request.user_agent:
|
||||
self.request.headers["User-Agent"] = self.request.user_agent
|
||||
if not self.request.allow_nonstandard_methods:
|
||||
# Some HTTP methods nearly always have bodies while others
|
||||
# almost never do. Fail in this case unless the user has
|
||||
# opted out of sanity checks with allow_nonstandard_methods.
|
||||
body_expected = self.request.method in ("POST", "PATCH", "PUT")
|
||||
body_present = (self.request.body is not None or
|
||||
self.request.body_producer is not None)
|
||||
if ((body_expected and not body_present) or
|
||||
(body_present and not body_expected)):
|
||||
raise ValueError(
|
||||
'Body must %sbe None for method %s (unless '
|
||||
'allow_nonstandard_methods is true)' %
|
||||
('not ' if body_expected else '', self.request.method))
|
||||
if self.request.expect_100_continue:
|
||||
self.request.headers["Expect"] = "100-continue"
|
||||
if self.request.body is not None:
|
||||
# When body_producer is used the caller is responsible for
|
||||
# setting Content-Length (or else chunked encoding will be used).
|
||||
self.request.headers["Content-Length"] = str(len(
|
||||
self.request.body))
|
||||
if (self.request.method == "POST" and
|
||||
"Content-Type" not in self.request.headers):
|
||||
self.request.headers["Content-Type"] = "application/x-www-form-urlencoded"
|
||||
if self.request.decompress_response:
|
||||
self.request.headers["Accept-Encoding"] = "gzip"
|
||||
req_path = ((self.parsed.path or '/') +
|
||||
(('?' + self.parsed.query) if self.parsed.query else ''))
|
||||
self.connection = self._create_connection(stream)
|
||||
start_line = httputil.RequestStartLine(self.request.method,
|
||||
req_path, '')
|
||||
self.connection.write_headers(start_line, self.request.headers)
|
||||
if self.request.expect_100_continue:
|
||||
yield self.connection.read_response(self)
|
||||
else:
|
||||
yield self._write_body(True)
|
||||
except Exception:
|
||||
if not self._handle_exception(*sys.exc_info()):
|
||||
raise
|
||||
|
||||
def _get_ssl_options(self, scheme):
|
||||
if scheme == "https":
|
||||
if self.request.ssl_options is not None:
|
||||
return self.request.ssl_options
|
||||
# If we are using the defaults, don't construct a
|
||||
# new SSLContext.
|
||||
if (self.request.validate_cert and
|
||||
self.request.ca_certs is None and
|
||||
self.request.client_cert is None and
|
||||
self.request.client_key is None):
|
||||
return _client_ssl_defaults
|
||||
ssl_ctx = ssl.create_default_context(
|
||||
ssl.Purpose.SERVER_AUTH,
|
||||
cafile=self.request.ca_certs)
|
||||
if not self.request.validate_cert:
|
||||
ssl_ctx.check_hostname = False
|
||||
ssl_ctx.verify_mode = ssl.CERT_NONE
|
||||
if self.request.client_cert is not None:
|
||||
ssl_ctx.load_cert_chain(self.request.client_cert,
|
||||
self.request.client_key)
|
||||
if hasattr(ssl, 'OP_NO_COMPRESSION'):
|
||||
# See netutil.ssl_options_to_context
|
||||
ssl_ctx.options |= ssl.OP_NO_COMPRESSION
|
||||
return ssl_ctx
|
||||
return None
|
||||
|
||||
def _on_timeout(self, info=None):
|
||||
"""Timeout callback of _HTTPConnection instance.
|
||||
|
||||
Raise a `HTTPTimeoutError` when a timeout occurs.
|
||||
|
||||
:info string key: More detailed timeout information.
|
||||
"""
|
||||
self._timeout = None
|
||||
error_message = "Timeout {0}".format(info) if info else "Timeout"
|
||||
if self.final_callback is not None:
|
||||
self._handle_exception(HTTPTimeoutError, HTTPTimeoutError(error_message),
|
||||
None)
|
||||
|
||||
def _remove_timeout(self):
|
||||
if self._timeout is not None:
|
||||
self.io_loop.remove_timeout(self._timeout)
|
||||
self._timeout = None
|
||||
|
||||
def _create_connection(self, stream):
|
||||
stream.set_nodelay(True)
|
||||
connection = HTTP1Connection(
|
||||
stream, True,
|
||||
HTTP1ConnectionParameters(
|
||||
no_keep_alive=True,
|
||||
max_header_size=self.max_header_size,
|
||||
max_body_size=self.max_body_size,
|
||||
decompress=self.request.decompress_response),
|
||||
self._sockaddr)
|
||||
return connection
|
||||
|
||||
@gen.coroutine
|
||||
def _write_body(self, start_read):
|
||||
if self.request.body is not None:
|
||||
self.connection.write(self.request.body)
|
||||
elif self.request.body_producer is not None:
|
||||
fut = self.request.body_producer(self.connection.write)
|
||||
if fut is not None:
|
||||
yield fut
|
||||
self.connection.finish()
|
||||
if start_read:
|
||||
try:
|
||||
yield self.connection.read_response(self)
|
||||
except StreamClosedError:
|
||||
if not self._handle_exception(*sys.exc_info()):
|
||||
raise
|
||||
|
||||
def _release(self):
|
||||
if self.release_callback is not None:
|
||||
release_callback = self.release_callback
|
||||
self.release_callback = None
|
||||
release_callback()
|
||||
|
||||
def _run_callback(self, response):
|
||||
self._release()
|
||||
if self.final_callback is not None:
|
||||
final_callback = self.final_callback
|
||||
self.final_callback = None
|
||||
self.io_loop.add_callback(final_callback, response)
|
||||
|
||||
def _handle_exception(self, typ, value, tb):
|
||||
if self.final_callback:
|
||||
self._remove_timeout()
|
||||
if isinstance(value, StreamClosedError):
|
||||
if value.real_error is None:
|
||||
value = HTTPStreamClosedError("Stream closed")
|
||||
else:
|
||||
value = value.real_error
|
||||
self._run_callback(HTTPResponse(self.request, 599, error=value,
|
||||
request_time=self.io_loop.time() - self.start_time,
|
||||
start_time=self.start_wall_time,
|
||||
))
|
||||
|
||||
if hasattr(self, "stream"):
|
||||
# TODO: this may cause a StreamClosedError to be raised
|
||||
# by the connection's Future. Should we cancel the
|
||||
# connection more gracefully?
|
||||
self.stream.close()
|
||||
return True
|
||||
else:
|
||||
# If our callback has already been called, we are probably
|
||||
# catching an exception that is not caused by us but rather
|
||||
# some child of our callback. Rather than drop it on the floor,
|
||||
# pass it along, unless it's just the stream being closed.
|
||||
return isinstance(value, StreamClosedError)
|
||||
|
||||
def on_connection_close(self):
|
||||
if self.final_callback is not None:
|
||||
message = "Connection closed"
|
||||
if self.stream.error:
|
||||
raise self.stream.error
|
||||
try:
|
||||
raise HTTPStreamClosedError(message)
|
||||
except HTTPStreamClosedError:
|
||||
self._handle_exception(*sys.exc_info())
|
||||
|
||||
def headers_received(self, first_line, headers):
|
||||
if self.request.expect_100_continue and first_line.code == 100:
|
||||
self._write_body(False)
|
||||
return
|
||||
self.code = first_line.code
|
||||
self.reason = first_line.reason
|
||||
self.headers = headers
|
||||
|
||||
if self._should_follow_redirect():
|
||||
return
|
||||
|
||||
if self.request.header_callback is not None:
|
||||
# Reassemble the start line.
|
||||
self.request.header_callback('%s %s %s\r\n' % first_line)
|
||||
for k, v in self.headers.get_all():
|
||||
self.request.header_callback("%s: %s\r\n" % (k, v))
|
||||
self.request.header_callback('\r\n')
|
||||
|
||||
def _should_follow_redirect(self):
|
||||
return (self.request.follow_redirects and
|
||||
self.request.max_redirects > 0 and
|
||||
self.code in (301, 302, 303, 307, 308))
|
||||
|
||||
def finish(self):
|
||||
data = b''.join(self.chunks)
|
||||
self._remove_timeout()
|
||||
original_request = getattr(self.request, "original_request",
|
||||
self.request)
|
||||
if self._should_follow_redirect():
|
||||
assert isinstance(self.request, _RequestProxy)
|
||||
new_request = copy.copy(self.request.request)
|
||||
new_request.url = urlparse.urljoin(self.request.url,
|
||||
self.headers["Location"])
|
||||
new_request.max_redirects = self.request.max_redirects - 1
|
||||
del new_request.headers["Host"]
|
||||
# http://www.w3.org/Protocols/rfc2616/rfc2616-sec10.html#sec10.3.4
|
||||
# Client SHOULD make a GET request after a 303.
|
||||
# According to the spec, 302 should be followed by the same
|
||||
# method as the original request, but in practice browsers
|
||||
# treat 302 the same as 303, and many servers use 302 for
|
||||
# compatibility with pre-HTTP/1.1 user agents which don't
|
||||
# understand the 303 status.
|
||||
if self.code in (302, 303):
|
||||
new_request.method = "GET"
|
||||
new_request.body = None
|
||||
for h in ["Content-Length", "Content-Type",
|
||||
"Content-Encoding", "Transfer-Encoding"]:
|
||||
try:
|
||||
del self.request.headers[h]
|
||||
except KeyError:
|
||||
pass
|
||||
new_request.original_request = original_request
|
||||
final_callback = self.final_callback
|
||||
self.final_callback = None
|
||||
self._release()
|
||||
fut = self.client.fetch(new_request, raise_error=False)
|
||||
fut.add_done_callback(lambda f: final_callback(f.result()))
|
||||
self._on_end_request()
|
||||
return
|
||||
if self.request.streaming_callback:
|
||||
buffer = BytesIO()
|
||||
else:
|
||||
buffer = BytesIO(data) # TODO: don't require one big string?
|
||||
response = HTTPResponse(original_request,
|
||||
self.code, reason=getattr(self, 'reason', None),
|
||||
headers=self.headers,
|
||||
request_time=self.io_loop.time() - self.start_time,
|
||||
start_time=self.start_wall_time,
|
||||
buffer=buffer,
|
||||
effective_url=self.request.url)
|
||||
self._run_callback(response)
|
||||
self._on_end_request()
|
||||
|
||||
def _on_end_request(self):
|
||||
self.stream.close()
|
||||
|
||||
def data_received(self, chunk):
|
||||
if self._should_follow_redirect():
|
||||
# We're going to follow a redirect so just discard the body.
|
||||
return
|
||||
if self.request.streaming_callback is not None:
|
||||
self.request.streaming_callback(chunk)
|
||||
else:
|
||||
self.chunks.append(chunk)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
AsyncHTTPClient.configure(SimpleAsyncHTTPClient)
|
||||
main()
|
||||
Executable
+77
@@ -0,0 +1,77 @@
|
||||
#define PY_SSIZE_T_CLEAN
|
||||
#include <Python.h>
|
||||
#include <stdint.h>
|
||||
|
||||
static PyObject* websocket_mask(PyObject* self, PyObject* args) {
|
||||
const char* mask;
|
||||
Py_ssize_t mask_len;
|
||||
uint32_t uint32_mask;
|
||||
uint64_t uint64_mask;
|
||||
const char* data;
|
||||
Py_ssize_t data_len;
|
||||
Py_ssize_t i;
|
||||
PyObject* result;
|
||||
char* buf;
|
||||
|
||||
if (!PyArg_ParseTuple(args, "s#s#", &mask, &mask_len, &data, &data_len)) {
|
||||
return NULL;
|
||||
}
|
||||
|
||||
uint32_mask = ((uint32_t*)mask)[0];
|
||||
|
||||
result = PyBytes_FromStringAndSize(NULL, data_len);
|
||||
if (!result) {
|
||||
return NULL;
|
||||
}
|
||||
buf = PyBytes_AsString(result);
|
||||
|
||||
if (sizeof(size_t) >= 8) {
|
||||
uint64_mask = uint32_mask;
|
||||
uint64_mask = (uint64_mask << 32) | uint32_mask;
|
||||
|
||||
while (data_len >= 8) {
|
||||
((uint64_t*)buf)[0] = ((uint64_t*)data)[0] ^ uint64_mask;
|
||||
data += 8;
|
||||
buf += 8;
|
||||
data_len -= 8;
|
||||
}
|
||||
}
|
||||
|
||||
while (data_len >= 4) {
|
||||
((uint32_t*)buf)[0] = ((uint32_t*)data)[0] ^ uint32_mask;
|
||||
data += 4;
|
||||
buf += 4;
|
||||
data_len -= 4;
|
||||
}
|
||||
|
||||
for (i = 0; i < data_len; i++) {
|
||||
buf[i] = data[i] ^ mask[i];
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
static PyMethodDef methods[] = {
|
||||
{"websocket_mask", websocket_mask, METH_VARARGS, ""},
|
||||
{NULL, NULL, 0, NULL}
|
||||
};
|
||||
|
||||
#if PY_MAJOR_VERSION >= 3
|
||||
static struct PyModuleDef speedupsmodule = {
|
||||
PyModuleDef_HEAD_INIT,
|
||||
"speedups",
|
||||
NULL,
|
||||
-1,
|
||||
methods
|
||||
};
|
||||
|
||||
PyMODINIT_FUNC
|
||||
PyInit_speedups(void) {
|
||||
return PyModule_Create(&speedupsmodule);
|
||||
}
|
||||
#else // Python 2.x
|
||||
PyMODINIT_FUNC
|
||||
initspeedups(void) {
|
||||
Py_InitModule("tornado.speedups", methods);
|
||||
}
|
||||
#endif
|
||||
Executable
+1
@@ -0,0 +1 @@
|
||||
def websocket_mask(mask: bytes, data: bytes) -> bytes: ...
|
||||
Executable
+413
@@ -0,0 +1,413 @@
|
||||
#
|
||||
# Copyright 2010 Facebook
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""`StackContext` allows applications to maintain threadlocal-like state
|
||||
that follows execution as it moves to other execution contexts.
|
||||
|
||||
The motivating examples are to eliminate the need for explicit
|
||||
``async_callback`` wrappers (as in `tornado.web.RequestHandler`), and to
|
||||
allow some additional context to be kept for logging.
|
||||
|
||||
This is slightly magic, but it's an extension of the idea that an
|
||||
exception handler is a kind of stack-local state and when that stack
|
||||
is suspended and resumed in a new context that state needs to be
|
||||
preserved. `StackContext` shifts the burden of restoring that state
|
||||
from each call site (e.g. wrapping each `.AsyncHTTPClient` callback
|
||||
in ``async_callback``) to the mechanisms that transfer control from
|
||||
one context to another (e.g. `.AsyncHTTPClient` itself, `.IOLoop`,
|
||||
thread pools, etc).
|
||||
|
||||
Example usage::
|
||||
|
||||
@contextlib.contextmanager
|
||||
def die_on_error():
|
||||
try:
|
||||
yield
|
||||
except Exception:
|
||||
logging.error("exception in asynchronous operation",exc_info=True)
|
||||
sys.exit(1)
|
||||
|
||||
with StackContext(die_on_error):
|
||||
# Any exception thrown here *or in callback and its descendants*
|
||||
# will cause the process to exit instead of spinning endlessly
|
||||
# in the ioloop.
|
||||
http_client.fetch(url, callback)
|
||||
ioloop.start()
|
||||
|
||||
Most applications shouldn't have to work with `StackContext` directly.
|
||||
Here are a few rules of thumb for when it's necessary:
|
||||
|
||||
* If you're writing an asynchronous library that doesn't rely on a
|
||||
stack_context-aware library like `tornado.ioloop` or `tornado.iostream`
|
||||
(for example, if you're writing a thread pool), use
|
||||
`.stack_context.wrap()` before any asynchronous operations to capture the
|
||||
stack context from where the operation was started.
|
||||
|
||||
* If you're writing an asynchronous library that has some shared
|
||||
resources (such as a connection pool), create those shared resources
|
||||
within a ``with stack_context.NullContext():`` block. This will prevent
|
||||
``StackContexts`` from leaking from one request to another.
|
||||
|
||||
* If you want to write something like an exception handler that will
|
||||
persist across asynchronous calls, create a new `StackContext` (or
|
||||
`ExceptionStackContext`), and make your asynchronous calls in a ``with``
|
||||
block that references your `StackContext`.
|
||||
|
||||
.. deprecated:: 5.1
|
||||
|
||||
The ``stack_context`` package is deprecated and will be removed in
|
||||
Tornado 6.0.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import sys
|
||||
import threading
|
||||
import warnings
|
||||
|
||||
from tornado.util import raise_exc_info
|
||||
|
||||
|
||||
class StackContextInconsistentError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class _State(threading.local):
|
||||
def __init__(self):
|
||||
self.contexts = (tuple(), None)
|
||||
|
||||
|
||||
_state = _State()
|
||||
|
||||
|
||||
class StackContext(object):
|
||||
"""Establishes the given context as a StackContext that will be transferred.
|
||||
|
||||
Note that the parameter is a callable that returns a context
|
||||
manager, not the context itself. That is, where for a
|
||||
non-transferable context manager you would say::
|
||||
|
||||
with my_context():
|
||||
|
||||
StackContext takes the function itself rather than its result::
|
||||
|
||||
with StackContext(my_context):
|
||||
|
||||
The result of ``with StackContext() as cb:`` is a deactivation
|
||||
callback. Run this callback when the StackContext is no longer
|
||||
needed to ensure that it is not propagated any further (note that
|
||||
deactivating a context does not affect any instances of that
|
||||
context that are currently pending). This is an advanced feature
|
||||
and not necessary in most applications.
|
||||
"""
|
||||
def __init__(self, context_factory):
|
||||
warnings.warn("StackContext is deprecated and will be removed in Tornado 6.0",
|
||||
DeprecationWarning)
|
||||
self.context_factory = context_factory
|
||||
self.contexts = []
|
||||
self.active = True
|
||||
|
||||
def _deactivate(self):
|
||||
self.active = False
|
||||
|
||||
# StackContext protocol
|
||||
def enter(self):
|
||||
context = self.context_factory()
|
||||
self.contexts.append(context)
|
||||
context.__enter__()
|
||||
|
||||
def exit(self, type, value, traceback):
|
||||
context = self.contexts.pop()
|
||||
context.__exit__(type, value, traceback)
|
||||
|
||||
# Note that some of this code is duplicated in ExceptionStackContext
|
||||
# below. ExceptionStackContext is more common and doesn't need
|
||||
# the full generality of this class.
|
||||
def __enter__(self):
|
||||
self.old_contexts = _state.contexts
|
||||
self.new_contexts = (self.old_contexts[0] + (self,), self)
|
||||
_state.contexts = self.new_contexts
|
||||
|
||||
try:
|
||||
self.enter()
|
||||
except:
|
||||
_state.contexts = self.old_contexts
|
||||
raise
|
||||
|
||||
return self._deactivate
|
||||
|
||||
def __exit__(self, type, value, traceback):
|
||||
try:
|
||||
self.exit(type, value, traceback)
|
||||
finally:
|
||||
final_contexts = _state.contexts
|
||||
_state.contexts = self.old_contexts
|
||||
|
||||
# Generator coroutines and with-statements with non-local
|
||||
# effects interact badly. Check here for signs of
|
||||
# the stack getting out of sync.
|
||||
# Note that this check comes after restoring _state.context
|
||||
# so that if it fails things are left in a (relatively)
|
||||
# consistent state.
|
||||
if final_contexts is not self.new_contexts:
|
||||
raise StackContextInconsistentError(
|
||||
'stack_context inconsistency (may be caused by yield '
|
||||
'within a "with StackContext" block)')
|
||||
|
||||
# Break up a reference to itself to allow for faster GC on CPython.
|
||||
self.new_contexts = None
|
||||
|
||||
|
||||
class ExceptionStackContext(object):
|
||||
"""Specialization of StackContext for exception handling.
|
||||
|
||||
The supplied ``exception_handler`` function will be called in the
|
||||
event of an uncaught exception in this context. The semantics are
|
||||
similar to a try/finally clause, and intended use cases are to log
|
||||
an error, close a socket, or similar cleanup actions. The
|
||||
``exc_info`` triple ``(type, value, traceback)`` will be passed to the
|
||||
exception_handler function.
|
||||
|
||||
If the exception handler returns true, the exception will be
|
||||
consumed and will not be propagated to other exception handlers.
|
||||
|
||||
.. versionadded:: 5.1
|
||||
|
||||
The ``delay_warning`` argument can be used to delay the emission
|
||||
of DeprecationWarnings until an exception is caught by the
|
||||
``ExceptionStackContext``, which facilitates certain transitional
|
||||
use cases.
|
||||
"""
|
||||
def __init__(self, exception_handler, delay_warning=False):
|
||||
self.delay_warning = delay_warning
|
||||
if not self.delay_warning:
|
||||
warnings.warn(
|
||||
"StackContext is deprecated and will be removed in Tornado 6.0",
|
||||
DeprecationWarning)
|
||||
self.exception_handler = exception_handler
|
||||
self.active = True
|
||||
|
||||
def _deactivate(self):
|
||||
self.active = False
|
||||
|
||||
def exit(self, type, value, traceback):
|
||||
if type is not None:
|
||||
if self.delay_warning:
|
||||
warnings.warn(
|
||||
"StackContext is deprecated and will be removed in Tornado 6.0",
|
||||
DeprecationWarning)
|
||||
return self.exception_handler(type, value, traceback)
|
||||
|
||||
def __enter__(self):
|
||||
self.old_contexts = _state.contexts
|
||||
self.new_contexts = (self.old_contexts[0], self)
|
||||
_state.contexts = self.new_contexts
|
||||
|
||||
return self._deactivate
|
||||
|
||||
def __exit__(self, type, value, traceback):
|
||||
try:
|
||||
if type is not None:
|
||||
return self.exception_handler(type, value, traceback)
|
||||
finally:
|
||||
final_contexts = _state.contexts
|
||||
_state.contexts = self.old_contexts
|
||||
|
||||
if final_contexts is not self.new_contexts:
|
||||
raise StackContextInconsistentError(
|
||||
'stack_context inconsistency (may be caused by yield '
|
||||
'within a "with StackContext" block)')
|
||||
|
||||
# Break up a reference to itself to allow for faster GC on CPython.
|
||||
self.new_contexts = None
|
||||
|
||||
|
||||
class NullContext(object):
|
||||
"""Resets the `StackContext`.
|
||||
|
||||
Useful when creating a shared resource on demand (e.g. an
|
||||
`.AsyncHTTPClient`) where the stack that caused the creating is
|
||||
not relevant to future operations.
|
||||
"""
|
||||
def __enter__(self):
|
||||
self.old_contexts = _state.contexts
|
||||
_state.contexts = (tuple(), None)
|
||||
|
||||
def __exit__(self, type, value, traceback):
|
||||
_state.contexts = self.old_contexts
|
||||
|
||||
|
||||
def _remove_deactivated(contexts):
|
||||
"""Remove deactivated handlers from the chain"""
|
||||
# Clean ctx handlers
|
||||
stack_contexts = tuple([h for h in contexts[0] if h.active])
|
||||
|
||||
# Find new head
|
||||
head = contexts[1]
|
||||
while head is not None and not head.active:
|
||||
head = head.old_contexts[1]
|
||||
|
||||
# Process chain
|
||||
ctx = head
|
||||
while ctx is not None:
|
||||
parent = ctx.old_contexts[1]
|
||||
|
||||
while parent is not None:
|
||||
if parent.active:
|
||||
break
|
||||
ctx.old_contexts = parent.old_contexts
|
||||
parent = parent.old_contexts[1]
|
||||
|
||||
ctx = parent
|
||||
|
||||
return (stack_contexts, head)
|
||||
|
||||
|
||||
def wrap(fn):
|
||||
"""Returns a callable object that will restore the current `StackContext`
|
||||
when executed.
|
||||
|
||||
Use this whenever saving a callback to be executed later in a
|
||||
different execution context (either in a different thread or
|
||||
asynchronously in the same thread).
|
||||
"""
|
||||
# Check if function is already wrapped
|
||||
if fn is None or hasattr(fn, '_wrapped'):
|
||||
return fn
|
||||
|
||||
# Capture current stack head
|
||||
# TODO: Any other better way to store contexts and update them in wrapped function?
|
||||
cap_contexts = [_state.contexts]
|
||||
|
||||
if not cap_contexts[0][0] and not cap_contexts[0][1]:
|
||||
# Fast path when there are no active contexts.
|
||||
def null_wrapper(*args, **kwargs):
|
||||
try:
|
||||
current_state = _state.contexts
|
||||
_state.contexts = cap_contexts[0]
|
||||
return fn(*args, **kwargs)
|
||||
finally:
|
||||
_state.contexts = current_state
|
||||
null_wrapper._wrapped = True
|
||||
return null_wrapper
|
||||
|
||||
def wrapped(*args, **kwargs):
|
||||
ret = None
|
||||
try:
|
||||
# Capture old state
|
||||
current_state = _state.contexts
|
||||
|
||||
# Remove deactivated items
|
||||
cap_contexts[0] = contexts = _remove_deactivated(cap_contexts[0])
|
||||
|
||||
# Force new state
|
||||
_state.contexts = contexts
|
||||
|
||||
# Current exception
|
||||
exc = (None, None, None)
|
||||
top = None
|
||||
|
||||
# Apply stack contexts
|
||||
last_ctx = 0
|
||||
stack = contexts[0]
|
||||
|
||||
# Apply state
|
||||
for n in stack:
|
||||
try:
|
||||
n.enter()
|
||||
last_ctx += 1
|
||||
except:
|
||||
# Exception happened. Record exception info and store top-most handler
|
||||
exc = sys.exc_info()
|
||||
top = n.old_contexts[1]
|
||||
|
||||
# Execute callback if no exception happened while restoring state
|
||||
if top is None:
|
||||
try:
|
||||
ret = fn(*args, **kwargs)
|
||||
except:
|
||||
exc = sys.exc_info()
|
||||
top = contexts[1]
|
||||
|
||||
# If there was exception, try to handle it by going through the exception chain
|
||||
if top is not None:
|
||||
exc = _handle_exception(top, exc)
|
||||
else:
|
||||
# Otherwise take shorter path and run stack contexts in reverse order
|
||||
while last_ctx > 0:
|
||||
last_ctx -= 1
|
||||
c = stack[last_ctx]
|
||||
|
||||
try:
|
||||
c.exit(*exc)
|
||||
except:
|
||||
exc = sys.exc_info()
|
||||
top = c.old_contexts[1]
|
||||
break
|
||||
else:
|
||||
top = None
|
||||
|
||||
# If if exception happened while unrolling, take longer exception handler path
|
||||
if top is not None:
|
||||
exc = _handle_exception(top, exc)
|
||||
|
||||
# If exception was not handled, raise it
|
||||
if exc != (None, None, None):
|
||||
raise_exc_info(exc)
|
||||
finally:
|
||||
_state.contexts = current_state
|
||||
return ret
|
||||
|
||||
wrapped._wrapped = True
|
||||
return wrapped
|
||||
|
||||
|
||||
def _handle_exception(tail, exc):
|
||||
while tail is not None:
|
||||
try:
|
||||
if tail.exit(*exc):
|
||||
exc = (None, None, None)
|
||||
except:
|
||||
exc = sys.exc_info()
|
||||
|
||||
tail = tail.old_contexts[1]
|
||||
|
||||
return exc
|
||||
|
||||
|
||||
def run_with_stack_context(context, func):
|
||||
"""Run a coroutine ``func`` in the given `StackContext`.
|
||||
|
||||
It is not safe to have a ``yield`` statement within a ``with StackContext``
|
||||
block, so it is difficult to use stack context with `.gen.coroutine`.
|
||||
This helper function runs the function in the correct context while
|
||||
keeping the ``yield`` and ``with`` statements syntactically separate.
|
||||
|
||||
Example::
|
||||
|
||||
@gen.coroutine
|
||||
def incorrect():
|
||||
with StackContext(ctx):
|
||||
# ERROR: this will raise StackContextInconsistentError
|
||||
yield other_coroutine()
|
||||
|
||||
@gen.coroutine
|
||||
def correct():
|
||||
yield run_with_stack_context(StackContext(ctx), other_coroutine)
|
||||
|
||||
.. versionadded:: 3.1
|
||||
"""
|
||||
with context:
|
||||
return func()
|
||||
Executable
+276
@@ -0,0 +1,276 @@
|
||||
#
|
||||
# Copyright 2014 Facebook
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""A non-blocking TCP connection factory.
|
||||
"""
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import functools
|
||||
import socket
|
||||
import numbers
|
||||
import datetime
|
||||
|
||||
from tornado.concurrent import Future, future_add_done_callback
|
||||
from tornado.ioloop import IOLoop
|
||||
from tornado.iostream import IOStream
|
||||
from tornado import gen
|
||||
from tornado.netutil import Resolver
|
||||
from tornado.platform.auto import set_close_exec
|
||||
from tornado.gen import TimeoutError
|
||||
from tornado.util import timedelta_to_seconds
|
||||
|
||||
_INITIAL_CONNECT_TIMEOUT = 0.3
|
||||
|
||||
|
||||
class _Connector(object):
|
||||
"""A stateless implementation of the "Happy Eyeballs" algorithm.
|
||||
|
||||
"Happy Eyeballs" is documented in RFC6555 as the recommended practice
|
||||
for when both IPv4 and IPv6 addresses are available.
|
||||
|
||||
In this implementation, we partition the addresses by family, and
|
||||
make the first connection attempt to whichever address was
|
||||
returned first by ``getaddrinfo``. If that connection fails or
|
||||
times out, we begin a connection in parallel to the first address
|
||||
of the other family. If there are additional failures we retry
|
||||
with other addresses, keeping one connection attempt per family
|
||||
in flight at a time.
|
||||
|
||||
http://tools.ietf.org/html/rfc6555
|
||||
|
||||
"""
|
||||
def __init__(self, addrinfo, connect):
|
||||
self.io_loop = IOLoop.current()
|
||||
self.connect = connect
|
||||
|
||||
self.future = Future()
|
||||
self.timeout = None
|
||||
self.connect_timeout = None
|
||||
self.last_error = None
|
||||
self.remaining = len(addrinfo)
|
||||
self.primary_addrs, self.secondary_addrs = self.split(addrinfo)
|
||||
self.streams = set()
|
||||
|
||||
@staticmethod
|
||||
def split(addrinfo):
|
||||
"""Partition the ``addrinfo`` list by address family.
|
||||
|
||||
Returns two lists. The first list contains the first entry from
|
||||
``addrinfo`` and all others with the same family, and the
|
||||
second list contains all other addresses (normally one list will
|
||||
be AF_INET and the other AF_INET6, although non-standard resolvers
|
||||
may return additional families).
|
||||
"""
|
||||
primary = []
|
||||
secondary = []
|
||||
primary_af = addrinfo[0][0]
|
||||
for af, addr in addrinfo:
|
||||
if af == primary_af:
|
||||
primary.append((af, addr))
|
||||
else:
|
||||
secondary.append((af, addr))
|
||||
return primary, secondary
|
||||
|
||||
def start(self, timeout=_INITIAL_CONNECT_TIMEOUT, connect_timeout=None):
|
||||
self.try_connect(iter(self.primary_addrs))
|
||||
self.set_timeout(timeout)
|
||||
if connect_timeout is not None:
|
||||
self.set_connect_timeout(connect_timeout)
|
||||
return self.future
|
||||
|
||||
def try_connect(self, addrs):
|
||||
try:
|
||||
af, addr = next(addrs)
|
||||
except StopIteration:
|
||||
# We've reached the end of our queue, but the other queue
|
||||
# might still be working. Send a final error on the future
|
||||
# only when both queues are finished.
|
||||
if self.remaining == 0 and not self.future.done():
|
||||
self.future.set_exception(self.last_error or
|
||||
IOError("connection failed"))
|
||||
return
|
||||
stream, future = self.connect(af, addr)
|
||||
self.streams.add(stream)
|
||||
future_add_done_callback(
|
||||
future, functools.partial(self.on_connect_done, addrs, af, addr))
|
||||
|
||||
def on_connect_done(self, addrs, af, addr, future):
|
||||
self.remaining -= 1
|
||||
try:
|
||||
stream = future.result()
|
||||
except Exception as e:
|
||||
if self.future.done():
|
||||
return
|
||||
# Error: try again (but remember what happened so we have an
|
||||
# error to raise in the end)
|
||||
self.last_error = e
|
||||
self.try_connect(addrs)
|
||||
if self.timeout is not None:
|
||||
# If the first attempt failed, don't wait for the
|
||||
# timeout to try an address from the secondary queue.
|
||||
self.io_loop.remove_timeout(self.timeout)
|
||||
self.on_timeout()
|
||||
return
|
||||
self.clear_timeouts()
|
||||
if self.future.done():
|
||||
# This is a late arrival; just drop it.
|
||||
stream.close()
|
||||
else:
|
||||
self.streams.discard(stream)
|
||||
self.future.set_result((af, addr, stream))
|
||||
self.close_streams()
|
||||
|
||||
def set_timeout(self, timeout):
|
||||
self.timeout = self.io_loop.add_timeout(self.io_loop.time() + timeout,
|
||||
self.on_timeout)
|
||||
|
||||
def on_timeout(self):
|
||||
self.timeout = None
|
||||
if not self.future.done():
|
||||
self.try_connect(iter(self.secondary_addrs))
|
||||
|
||||
def clear_timeout(self):
|
||||
if self.timeout is not None:
|
||||
self.io_loop.remove_timeout(self.timeout)
|
||||
|
||||
def set_connect_timeout(self, connect_timeout):
|
||||
self.connect_timeout = self.io_loop.add_timeout(
|
||||
connect_timeout, self.on_connect_timeout)
|
||||
|
||||
def on_connect_timeout(self):
|
||||
if not self.future.done():
|
||||
self.future.set_exception(TimeoutError())
|
||||
self.close_streams()
|
||||
|
||||
def clear_timeouts(self):
|
||||
if self.timeout is not None:
|
||||
self.io_loop.remove_timeout(self.timeout)
|
||||
if self.connect_timeout is not None:
|
||||
self.io_loop.remove_timeout(self.connect_timeout)
|
||||
|
||||
def close_streams(self):
|
||||
for stream in self.streams:
|
||||
stream.close()
|
||||
|
||||
|
||||
class TCPClient(object):
|
||||
"""A non-blocking TCP connection factory.
|
||||
|
||||
.. versionchanged:: 5.0
|
||||
The ``io_loop`` argument (deprecated since version 4.1) has been removed.
|
||||
"""
|
||||
def __init__(self, resolver=None):
|
||||
if resolver is not None:
|
||||
self.resolver = resolver
|
||||
self._own_resolver = False
|
||||
else:
|
||||
self.resolver = Resolver()
|
||||
self._own_resolver = True
|
||||
|
||||
def close(self):
|
||||
if self._own_resolver:
|
||||
self.resolver.close()
|
||||
|
||||
@gen.coroutine
|
||||
def connect(self, host, port, af=socket.AF_UNSPEC, ssl_options=None,
|
||||
max_buffer_size=None, source_ip=None, source_port=None,
|
||||
timeout=None):
|
||||
"""Connect to the given host and port.
|
||||
|
||||
Asynchronously returns an `.IOStream` (or `.SSLIOStream` if
|
||||
``ssl_options`` is not None).
|
||||
|
||||
Using the ``source_ip`` kwarg, one can specify the source
|
||||
IP address to use when establishing the connection.
|
||||
In case the user needs to resolve and
|
||||
use a specific interface, it has to be handled outside
|
||||
of Tornado as this depends very much on the platform.
|
||||
|
||||
Raises `TimeoutError` if the input future does not complete before
|
||||
``timeout``, which may be specified in any form allowed by
|
||||
`.IOLoop.add_timeout` (i.e. a `datetime.timedelta` or an absolute time
|
||||
relative to `.IOLoop.time`)
|
||||
|
||||
Similarly, when the user requires a certain source port, it can
|
||||
be specified using the ``source_port`` arg.
|
||||
|
||||
.. versionchanged:: 4.5
|
||||
Added the ``source_ip`` and ``source_port`` arguments.
|
||||
|
||||
.. versionchanged:: 5.0
|
||||
Added the ``timeout`` argument.
|
||||
"""
|
||||
if timeout is not None:
|
||||
if isinstance(timeout, numbers.Real):
|
||||
timeout = IOLoop.current().time() + timeout
|
||||
elif isinstance(timeout, datetime.timedelta):
|
||||
timeout = IOLoop.current().time() + timedelta_to_seconds(timeout)
|
||||
else:
|
||||
raise TypeError("Unsupported timeout %r" % timeout)
|
||||
if timeout is not None:
|
||||
addrinfo = yield gen.with_timeout(
|
||||
timeout, self.resolver.resolve(host, port, af))
|
||||
else:
|
||||
addrinfo = yield self.resolver.resolve(host, port, af)
|
||||
connector = _Connector(
|
||||
addrinfo,
|
||||
functools.partial(self._create_stream, max_buffer_size,
|
||||
source_ip=source_ip, source_port=source_port)
|
||||
)
|
||||
af, addr, stream = yield connector.start(connect_timeout=timeout)
|
||||
# TODO: For better performance we could cache the (af, addr)
|
||||
# information here and re-use it on subsequent connections to
|
||||
# the same host. (http://tools.ietf.org/html/rfc6555#section-4.2)
|
||||
if ssl_options is not None:
|
||||
if timeout is not None:
|
||||
stream = yield gen.with_timeout(timeout, stream.start_tls(
|
||||
False, ssl_options=ssl_options, server_hostname=host))
|
||||
else:
|
||||
stream = yield stream.start_tls(False, ssl_options=ssl_options,
|
||||
server_hostname=host)
|
||||
raise gen.Return(stream)
|
||||
|
||||
def _create_stream(self, max_buffer_size, af, addr, source_ip=None,
|
||||
source_port=None):
|
||||
# Always connect in plaintext; we'll convert to ssl if necessary
|
||||
# after one connection has completed.
|
||||
source_port_bind = source_port if isinstance(source_port, int) else 0
|
||||
source_ip_bind = source_ip
|
||||
if source_port_bind and not source_ip:
|
||||
# User required a specific port, but did not specify
|
||||
# a certain source IP, will bind to the default loopback.
|
||||
source_ip_bind = '::1' if af == socket.AF_INET6 else '127.0.0.1'
|
||||
# Trying to use the same address family as the requested af socket:
|
||||
# - 127.0.0.1 for IPv4
|
||||
# - ::1 for IPv6
|
||||
socket_obj = socket.socket(af)
|
||||
set_close_exec(socket_obj.fileno())
|
||||
if source_port_bind or source_ip_bind:
|
||||
# If the user requires binding also to a specific IP/port.
|
||||
try:
|
||||
socket_obj.bind((source_ip_bind, source_port_bind))
|
||||
except socket.error:
|
||||
socket_obj.close()
|
||||
# Fail loudly if unable to use the IP/port.
|
||||
raise
|
||||
try:
|
||||
stream = IOStream(socket_obj,
|
||||
max_buffer_size=max_buffer_size)
|
||||
except socket.error as e:
|
||||
fu = Future()
|
||||
fu.set_exception(e)
|
||||
return fu
|
||||
else:
|
||||
return stream, stream.connect(addr)
|
||||
Executable
+299
@@ -0,0 +1,299 @@
|
||||
#
|
||||
# Copyright 2011 Facebook
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""A non-blocking, single-threaded TCP server."""
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import errno
|
||||
import os
|
||||
import socket
|
||||
|
||||
from tornado import gen
|
||||
from tornado.log import app_log
|
||||
from tornado.ioloop import IOLoop
|
||||
from tornado.iostream import IOStream, SSLIOStream
|
||||
from tornado.netutil import bind_sockets, add_accept_handler, ssl_wrap_socket
|
||||
from tornado import process
|
||||
from tornado.util import errno_from_exception
|
||||
|
||||
try:
|
||||
import ssl
|
||||
except ImportError:
|
||||
# ssl is not available on Google App Engine.
|
||||
ssl = None
|
||||
|
||||
|
||||
class TCPServer(object):
|
||||
r"""A non-blocking, single-threaded TCP server.
|
||||
|
||||
To use `TCPServer`, define a subclass which overrides the `handle_stream`
|
||||
method. For example, a simple echo server could be defined like this::
|
||||
|
||||
from tornado.tcpserver import TCPServer
|
||||
from tornado.iostream import StreamClosedError
|
||||
from tornado import gen
|
||||
|
||||
class EchoServer(TCPServer):
|
||||
async def handle_stream(self, stream, address):
|
||||
while True:
|
||||
try:
|
||||
data = await stream.read_until(b"\n")
|
||||
await stream.write(data)
|
||||
except StreamClosedError:
|
||||
break
|
||||
|
||||
To make this server serve SSL traffic, send the ``ssl_options`` keyword
|
||||
argument with an `ssl.SSLContext` object. For compatibility with older
|
||||
versions of Python ``ssl_options`` may also be a dictionary of keyword
|
||||
arguments for the `ssl.wrap_socket` method.::
|
||||
|
||||
ssl_ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
|
||||
ssl_ctx.load_cert_chain(os.path.join(data_dir, "mydomain.crt"),
|
||||
os.path.join(data_dir, "mydomain.key"))
|
||||
TCPServer(ssl_options=ssl_ctx)
|
||||
|
||||
`TCPServer` initialization follows one of three patterns:
|
||||
|
||||
1. `listen`: simple single-process::
|
||||
|
||||
server = TCPServer()
|
||||
server.listen(8888)
|
||||
IOLoop.current().start()
|
||||
|
||||
2. `bind`/`start`: simple multi-process::
|
||||
|
||||
server = TCPServer()
|
||||
server.bind(8888)
|
||||
server.start(0) # Forks multiple sub-processes
|
||||
IOLoop.current().start()
|
||||
|
||||
When using this interface, an `.IOLoop` must *not* be passed
|
||||
to the `TCPServer` constructor. `start` will always start
|
||||
the server on the default singleton `.IOLoop`.
|
||||
|
||||
3. `add_sockets`: advanced multi-process::
|
||||
|
||||
sockets = bind_sockets(8888)
|
||||
tornado.process.fork_processes(0)
|
||||
server = TCPServer()
|
||||
server.add_sockets(sockets)
|
||||
IOLoop.current().start()
|
||||
|
||||
The `add_sockets` interface is more complicated, but it can be
|
||||
used with `tornado.process.fork_processes` to give you more
|
||||
flexibility in when the fork happens. `add_sockets` can
|
||||
also be used in single-process servers if you want to create
|
||||
your listening sockets in some way other than
|
||||
`~tornado.netutil.bind_sockets`.
|
||||
|
||||
.. versionadded:: 3.1
|
||||
The ``max_buffer_size`` argument.
|
||||
|
||||
.. versionchanged:: 5.0
|
||||
The ``io_loop`` argument has been removed.
|
||||
"""
|
||||
def __init__(self, ssl_options=None, max_buffer_size=None,
|
||||
read_chunk_size=None):
|
||||
self.ssl_options = ssl_options
|
||||
self._sockets = {} # fd -> socket object
|
||||
self._handlers = {} # fd -> remove_handler callable
|
||||
self._pending_sockets = []
|
||||
self._started = False
|
||||
self._stopped = False
|
||||
self.max_buffer_size = max_buffer_size
|
||||
self.read_chunk_size = read_chunk_size
|
||||
|
||||
# Verify the SSL options. Otherwise we don't get errors until clients
|
||||
# connect. This doesn't verify that the keys are legitimate, but
|
||||
# the SSL module doesn't do that until there is a connected socket
|
||||
# which seems like too much work
|
||||
if self.ssl_options is not None and isinstance(self.ssl_options, dict):
|
||||
# Only certfile is required: it can contain both keys
|
||||
if 'certfile' not in self.ssl_options:
|
||||
raise KeyError('missing key "certfile" in ssl_options')
|
||||
|
||||
if not os.path.exists(self.ssl_options['certfile']):
|
||||
raise ValueError('certfile "%s" does not exist' %
|
||||
self.ssl_options['certfile'])
|
||||
if ('keyfile' in self.ssl_options and
|
||||
not os.path.exists(self.ssl_options['keyfile'])):
|
||||
raise ValueError('keyfile "%s" does not exist' %
|
||||
self.ssl_options['keyfile'])
|
||||
|
||||
def listen(self, port, address=""):
|
||||
"""Starts accepting connections on the given port.
|
||||
|
||||
This method may be called more than once to listen on multiple ports.
|
||||
`listen` takes effect immediately; it is not necessary to call
|
||||
`TCPServer.start` afterwards. It is, however, necessary to start
|
||||
the `.IOLoop`.
|
||||
"""
|
||||
sockets = bind_sockets(port, address=address)
|
||||
self.add_sockets(sockets)
|
||||
|
||||
def add_sockets(self, sockets):
|
||||
"""Makes this server start accepting connections on the given sockets.
|
||||
|
||||
The ``sockets`` parameter is a list of socket objects such as
|
||||
those returned by `~tornado.netutil.bind_sockets`.
|
||||
`add_sockets` is typically used in combination with that
|
||||
method and `tornado.process.fork_processes` to provide greater
|
||||
control over the initialization of a multi-process server.
|
||||
"""
|
||||
for sock in sockets:
|
||||
self._sockets[sock.fileno()] = sock
|
||||
self._handlers[sock.fileno()] = add_accept_handler(
|
||||
sock, self._handle_connection)
|
||||
|
||||
def add_socket(self, socket):
|
||||
"""Singular version of `add_sockets`. Takes a single socket object."""
|
||||
self.add_sockets([socket])
|
||||
|
||||
def bind(self, port, address=None, family=socket.AF_UNSPEC, backlog=128,
|
||||
reuse_port=False):
|
||||
"""Binds this server to the given port on the given address.
|
||||
|
||||
To start the server, call `start`. If you want to run this server
|
||||
in a single process, you can call `listen` as a shortcut to the
|
||||
sequence of `bind` and `start` calls.
|
||||
|
||||
Address may be either an IP address or hostname. If it's a hostname,
|
||||
the server will listen on all IP addresses associated with the
|
||||
name. Address may be an empty string or None to listen on all
|
||||
available interfaces. Family may be set to either `socket.AF_INET`
|
||||
or `socket.AF_INET6` to restrict to IPv4 or IPv6 addresses, otherwise
|
||||
both will be used if available.
|
||||
|
||||
The ``backlog`` argument has the same meaning as for
|
||||
`socket.listen <socket.socket.listen>`. The ``reuse_port`` argument
|
||||
has the same meaning as for `.bind_sockets`.
|
||||
|
||||
This method may be called multiple times prior to `start` to listen
|
||||
on multiple ports or interfaces.
|
||||
|
||||
.. versionchanged:: 4.4
|
||||
Added the ``reuse_port`` argument.
|
||||
"""
|
||||
sockets = bind_sockets(port, address=address, family=family,
|
||||
backlog=backlog, reuse_port=reuse_port)
|
||||
if self._started:
|
||||
self.add_sockets(sockets)
|
||||
else:
|
||||
self._pending_sockets.extend(sockets)
|
||||
|
||||
def start(self, num_processes=1):
|
||||
"""Starts this server in the `.IOLoop`.
|
||||
|
||||
By default, we run the server in this process and do not fork any
|
||||
additional child process.
|
||||
|
||||
If num_processes is ``None`` or <= 0, we detect the number of cores
|
||||
available on this machine and fork that number of child
|
||||
processes. If num_processes is given and > 1, we fork that
|
||||
specific number of sub-processes.
|
||||
|
||||
Since we use processes and not threads, there is no shared memory
|
||||
between any server code.
|
||||
|
||||
Note that multiple processes are not compatible with the autoreload
|
||||
module (or the ``autoreload=True`` option to `tornado.web.Application`
|
||||
which defaults to True when ``debug=True``).
|
||||
When using multiple processes, no IOLoops can be created or
|
||||
referenced until after the call to ``TCPServer.start(n)``.
|
||||
"""
|
||||
assert not self._started
|
||||
self._started = True
|
||||
if num_processes != 1:
|
||||
process.fork_processes(num_processes)
|
||||
sockets = self._pending_sockets
|
||||
self._pending_sockets = []
|
||||
self.add_sockets(sockets)
|
||||
|
||||
def stop(self):
|
||||
"""Stops listening for new connections.
|
||||
|
||||
Requests currently in progress may still continue after the
|
||||
server is stopped.
|
||||
"""
|
||||
if self._stopped:
|
||||
return
|
||||
self._stopped = True
|
||||
for fd, sock in self._sockets.items():
|
||||
assert sock.fileno() == fd
|
||||
# Unregister socket from IOLoop
|
||||
self._handlers.pop(fd)()
|
||||
sock.close()
|
||||
|
||||
def handle_stream(self, stream, address):
|
||||
"""Override to handle a new `.IOStream` from an incoming connection.
|
||||
|
||||
This method may be a coroutine; if so any exceptions it raises
|
||||
asynchronously will be logged. Accepting of incoming connections
|
||||
will not be blocked by this coroutine.
|
||||
|
||||
If this `TCPServer` is configured for SSL, ``handle_stream``
|
||||
may be called before the SSL handshake has completed. Use
|
||||
`.SSLIOStream.wait_for_handshake` if you need to verify the client's
|
||||
certificate or use NPN/ALPN.
|
||||
|
||||
.. versionchanged:: 4.2
|
||||
Added the option for this method to be a coroutine.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def _handle_connection(self, connection, address):
|
||||
if self.ssl_options is not None:
|
||||
assert ssl, "Python 2.6+ and OpenSSL required for SSL"
|
||||
try:
|
||||
connection = ssl_wrap_socket(connection,
|
||||
self.ssl_options,
|
||||
server_side=True,
|
||||
do_handshake_on_connect=False)
|
||||
except ssl.SSLError as err:
|
||||
if err.args[0] == ssl.SSL_ERROR_EOF:
|
||||
return connection.close()
|
||||
else:
|
||||
raise
|
||||
except socket.error as err:
|
||||
# If the connection is closed immediately after it is created
|
||||
# (as in a port scan), we can get one of several errors.
|
||||
# wrap_socket makes an internal call to getpeername,
|
||||
# which may return either EINVAL (Mac OS X) or ENOTCONN
|
||||
# (Linux). If it returns ENOTCONN, this error is
|
||||
# silently swallowed by the ssl module, so we need to
|
||||
# catch another error later on (AttributeError in
|
||||
# SSLIOStream._do_ssl_handshake).
|
||||
# To test this behavior, try nmap with the -sT flag.
|
||||
# https://github.com/tornadoweb/tornado/pull/750
|
||||
if errno_from_exception(err) in (errno.ECONNABORTED, errno.EINVAL):
|
||||
return connection.close()
|
||||
else:
|
||||
raise
|
||||
try:
|
||||
if self.ssl_options is not None:
|
||||
stream = SSLIOStream(connection,
|
||||
max_buffer_size=self.max_buffer_size,
|
||||
read_chunk_size=self.read_chunk_size)
|
||||
else:
|
||||
stream = IOStream(connection,
|
||||
max_buffer_size=self.max_buffer_size,
|
||||
read_chunk_size=self.read_chunk_size)
|
||||
|
||||
future = self.handle_stream(stream, address)
|
||||
if future is not None:
|
||||
IOLoop.current().add_future(gen.convert_yielded(future),
|
||||
lambda f: f.result())
|
||||
except Exception:
|
||||
app_log.error("Error in connection callback", exc_info=True)
|
||||
Executable
+976
@@ -0,0 +1,976 @@
|
||||
#
|
||||
# Copyright 2009 Facebook
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""A simple template system that compiles templates to Python code.
|
||||
|
||||
Basic usage looks like::
|
||||
|
||||
t = template.Template("<html>{{ myvalue }}</html>")
|
||||
print(t.generate(myvalue="XXX"))
|
||||
|
||||
`Loader` is a class that loads templates from a root directory and caches
|
||||
the compiled templates::
|
||||
|
||||
loader = template.Loader("/home/btaylor")
|
||||
print(loader.load("test.html").generate(myvalue="XXX"))
|
||||
|
||||
We compile all templates to raw Python. Error-reporting is currently... uh,
|
||||
interesting. Syntax for the templates::
|
||||
|
||||
### base.html
|
||||
<html>
|
||||
<head>
|
||||
<title>{% block title %}Default title{% end %}</title>
|
||||
</head>
|
||||
<body>
|
||||
<ul>
|
||||
{% for student in students %}
|
||||
{% block student %}
|
||||
<li>{{ escape(student.name) }}</li>
|
||||
{% end %}
|
||||
{% end %}
|
||||
</ul>
|
||||
</body>
|
||||
</html>
|
||||
|
||||
### bold.html
|
||||
{% extends "base.html" %}
|
||||
|
||||
{% block title %}A bolder title{% end %}
|
||||
|
||||
{% block student %}
|
||||
<li><span style="bold">{{ escape(student.name) }}</span></li>
|
||||
{% end %}
|
||||
|
||||
Unlike most other template systems, we do not put any restrictions on the
|
||||
expressions you can include in your statements. ``if`` and ``for`` blocks get
|
||||
translated exactly into Python, so you can do complex expressions like::
|
||||
|
||||
{% for student in [p for p in people if p.student and p.age > 23] %}
|
||||
<li>{{ escape(student.name) }}</li>
|
||||
{% end %}
|
||||
|
||||
Translating directly to Python means you can apply functions to expressions
|
||||
easily, like the ``escape()`` function in the examples above. You can pass
|
||||
functions in to your template just like any other variable
|
||||
(In a `.RequestHandler`, override `.RequestHandler.get_template_namespace`)::
|
||||
|
||||
### Python code
|
||||
def add(x, y):
|
||||
return x + y
|
||||
template.execute(add=add)
|
||||
|
||||
### The template
|
||||
{{ add(1, 2) }}
|
||||
|
||||
We provide the functions `escape() <.xhtml_escape>`, `.url_escape()`,
|
||||
`.json_encode()`, and `.squeeze()` to all templates by default.
|
||||
|
||||
Typical applications do not create `Template` or `Loader` instances by
|
||||
hand, but instead use the `~.RequestHandler.render` and
|
||||
`~.RequestHandler.render_string` methods of
|
||||
`tornado.web.RequestHandler`, which load templates automatically based
|
||||
on the ``template_path`` `.Application` setting.
|
||||
|
||||
Variable names beginning with ``_tt_`` are reserved by the template
|
||||
system and should not be used by application code.
|
||||
|
||||
Syntax Reference
|
||||
----------------
|
||||
|
||||
Template expressions are surrounded by double curly braces: ``{{ ... }}``.
|
||||
The contents may be any python expression, which will be escaped according
|
||||
to the current autoescape setting and inserted into the output. Other
|
||||
template directives use ``{% %}``.
|
||||
|
||||
To comment out a section so that it is omitted from the output, surround it
|
||||
with ``{# ... #}``.
|
||||
|
||||
These tags may be escaped as ``{{!``, ``{%!``, and ``{#!``
|
||||
if you need to include a literal ``{{``, ``{%``, or ``{#`` in the output.
|
||||
|
||||
|
||||
``{% apply *function* %}...{% end %}``
|
||||
Applies a function to the output of all template code between ``apply``
|
||||
and ``end``::
|
||||
|
||||
{% apply linkify %}{{name}} said: {{message}}{% end %}
|
||||
|
||||
Note that as an implementation detail apply blocks are implemented
|
||||
as nested functions and thus may interact strangely with variables
|
||||
set via ``{% set %}``, or the use of ``{% break %}`` or ``{% continue %}``
|
||||
within loops.
|
||||
|
||||
``{% autoescape *function* %}``
|
||||
Sets the autoescape mode for the current file. This does not affect
|
||||
other files, even those referenced by ``{% include %}``. Note that
|
||||
autoescaping can also be configured globally, at the `.Application`
|
||||
or `Loader`.::
|
||||
|
||||
{% autoescape xhtml_escape %}
|
||||
{% autoescape None %}
|
||||
|
||||
``{% block *name* %}...{% end %}``
|
||||
Indicates a named, replaceable block for use with ``{% extends %}``.
|
||||
Blocks in the parent template will be replaced with the contents of
|
||||
the same-named block in a child template.::
|
||||
|
||||
<!-- base.html -->
|
||||
<title>{% block title %}Default title{% end %}</title>
|
||||
|
||||
<!-- mypage.html -->
|
||||
{% extends "base.html" %}
|
||||
{% block title %}My page title{% end %}
|
||||
|
||||
``{% comment ... %}``
|
||||
A comment which will be removed from the template output. Note that
|
||||
there is no ``{% end %}`` tag; the comment goes from the word ``comment``
|
||||
to the closing ``%}`` tag.
|
||||
|
||||
``{% extends *filename* %}``
|
||||
Inherit from another template. Templates that use ``extends`` should
|
||||
contain one or more ``block`` tags to replace content from the parent
|
||||
template. Anything in the child template not contained in a ``block``
|
||||
tag will be ignored. For an example, see the ``{% block %}`` tag.
|
||||
|
||||
``{% for *var* in *expr* %}...{% end %}``
|
||||
Same as the python ``for`` statement. ``{% break %}`` and
|
||||
``{% continue %}`` may be used inside the loop.
|
||||
|
||||
``{% from *x* import *y* %}``
|
||||
Same as the python ``import`` statement.
|
||||
|
||||
``{% if *condition* %}...{% elif *condition* %}...{% else %}...{% end %}``
|
||||
Conditional statement - outputs the first section whose condition is
|
||||
true. (The ``elif`` and ``else`` sections are optional)
|
||||
|
||||
``{% import *module* %}``
|
||||
Same as the python ``import`` statement.
|
||||
|
||||
``{% include *filename* %}``
|
||||
Includes another template file. The included file can see all the local
|
||||
variables as if it were copied directly to the point of the ``include``
|
||||
directive (the ``{% autoescape %}`` directive is an exception).
|
||||
Alternately, ``{% module Template(filename, **kwargs) %}`` may be used
|
||||
to include another template with an isolated namespace.
|
||||
|
||||
``{% module *expr* %}``
|
||||
Renders a `~tornado.web.UIModule`. The output of the ``UIModule`` is
|
||||
not escaped::
|
||||
|
||||
{% module Template("foo.html", arg=42) %}
|
||||
|
||||
``UIModules`` are a feature of the `tornado.web.RequestHandler`
|
||||
class (and specifically its ``render`` method) and will not work
|
||||
when the template system is used on its own in other contexts.
|
||||
|
||||
``{% raw *expr* %}``
|
||||
Outputs the result of the given expression without autoescaping.
|
||||
|
||||
``{% set *x* = *y* %}``
|
||||
Sets a local variable.
|
||||
|
||||
``{% try %}...{% except %}...{% else %}...{% finally %}...{% end %}``
|
||||
Same as the python ``try`` statement.
|
||||
|
||||
``{% while *condition* %}... {% end %}``
|
||||
Same as the python ``while`` statement. ``{% break %}`` and
|
||||
``{% continue %}`` may be used inside the loop.
|
||||
|
||||
``{% whitespace *mode* %}``
|
||||
Sets the whitespace mode for the remainder of the current file
|
||||
(or until the next ``{% whitespace %}`` directive). See
|
||||
`filter_whitespace` for available options. New in Tornado 4.3.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import datetime
|
||||
import linecache
|
||||
import os.path
|
||||
import posixpath
|
||||
import re
|
||||
import threading
|
||||
|
||||
from tornado import escape
|
||||
from tornado.log import app_log
|
||||
from tornado.util import ObjectDict, exec_in, unicode_type, PY3
|
||||
|
||||
if PY3:
|
||||
from io import StringIO
|
||||
else:
|
||||
from cStringIO import StringIO
|
||||
|
||||
_DEFAULT_AUTOESCAPE = "xhtml_escape"
|
||||
_UNSET = object()
|
||||
|
||||
|
||||
def filter_whitespace(mode, text):
|
||||
"""Transform whitespace in ``text`` according to ``mode``.
|
||||
|
||||
Available modes are:
|
||||
|
||||
* ``all``: Return all whitespace unmodified.
|
||||
* ``single``: Collapse consecutive whitespace with a single whitespace
|
||||
character, preserving newlines.
|
||||
* ``oneline``: Collapse all runs of whitespace into a single space
|
||||
character, removing all newlines in the process.
|
||||
|
||||
.. versionadded:: 4.3
|
||||
"""
|
||||
if mode == 'all':
|
||||
return text
|
||||
elif mode == 'single':
|
||||
text = re.sub(r"([\t ]+)", " ", text)
|
||||
text = re.sub(r"(\s*\n\s*)", "\n", text)
|
||||
return text
|
||||
elif mode == 'oneline':
|
||||
return re.sub(r"(\s+)", " ", text)
|
||||
else:
|
||||
raise Exception("invalid whitespace mode %s" % mode)
|
||||
|
||||
|
||||
class Template(object):
|
||||
"""A compiled template.
|
||||
|
||||
We compile into Python from the given template_string. You can generate
|
||||
the template from variables with generate().
|
||||
"""
|
||||
# note that the constructor's signature is not extracted with
|
||||
# autodoc because _UNSET looks like garbage. When changing
|
||||
# this signature update website/sphinx/template.rst too.
|
||||
def __init__(self, template_string, name="<string>", loader=None,
|
||||
compress_whitespace=_UNSET, autoescape=_UNSET,
|
||||
whitespace=None):
|
||||
"""Construct a Template.
|
||||
|
||||
:arg str template_string: the contents of the template file.
|
||||
:arg str name: the filename from which the template was loaded
|
||||
(used for error message).
|
||||
:arg tornado.template.BaseLoader loader: the `~tornado.template.BaseLoader` responsible
|
||||
for this template, used to resolve ``{% include %}`` and ``{% extend %}`` directives.
|
||||
:arg bool compress_whitespace: Deprecated since Tornado 4.3.
|
||||
Equivalent to ``whitespace="single"`` if true and
|
||||
``whitespace="all"`` if false.
|
||||
:arg str autoescape: The name of a function in the template
|
||||
namespace, or ``None`` to disable escaping by default.
|
||||
:arg str whitespace: A string specifying treatment of whitespace;
|
||||
see `filter_whitespace` for options.
|
||||
|
||||
.. versionchanged:: 4.3
|
||||
Added ``whitespace`` parameter; deprecated ``compress_whitespace``.
|
||||
"""
|
||||
self.name = escape.native_str(name)
|
||||
|
||||
if compress_whitespace is not _UNSET:
|
||||
# Convert deprecated compress_whitespace (bool) to whitespace (str).
|
||||
if whitespace is not None:
|
||||
raise Exception("cannot set both whitespace and compress_whitespace")
|
||||
whitespace = "single" if compress_whitespace else "all"
|
||||
if whitespace is None:
|
||||
if loader and loader.whitespace:
|
||||
whitespace = loader.whitespace
|
||||
else:
|
||||
# Whitespace defaults by filename.
|
||||
if name.endswith(".html") or name.endswith(".js"):
|
||||
whitespace = "single"
|
||||
else:
|
||||
whitespace = "all"
|
||||
# Validate the whitespace setting.
|
||||
filter_whitespace(whitespace, '')
|
||||
|
||||
if autoescape is not _UNSET:
|
||||
self.autoescape = autoescape
|
||||
elif loader:
|
||||
self.autoescape = loader.autoescape
|
||||
else:
|
||||
self.autoescape = _DEFAULT_AUTOESCAPE
|
||||
|
||||
self.namespace = loader.namespace if loader else {}
|
||||
reader = _TemplateReader(name, escape.native_str(template_string),
|
||||
whitespace)
|
||||
self.file = _File(self, _parse(reader, self))
|
||||
self.code = self._generate_python(loader)
|
||||
self.loader = loader
|
||||
try:
|
||||
# Under python2.5, the fake filename used here must match
|
||||
# the module name used in __name__ below.
|
||||
# The dont_inherit flag prevents template.py's future imports
|
||||
# from being applied to the generated code.
|
||||
self.compiled = compile(
|
||||
escape.to_unicode(self.code),
|
||||
"%s.generated.py" % self.name.replace('.', '_'),
|
||||
"exec", dont_inherit=True)
|
||||
except Exception:
|
||||
formatted_code = _format_code(self.code).rstrip()
|
||||
app_log.error("%s code:\n%s", self.name, formatted_code)
|
||||
raise
|
||||
|
||||
def generate(self, **kwargs):
|
||||
"""Generate this template with the given arguments."""
|
||||
namespace = {
|
||||
"escape": escape.xhtml_escape,
|
||||
"xhtml_escape": escape.xhtml_escape,
|
||||
"url_escape": escape.url_escape,
|
||||
"json_encode": escape.json_encode,
|
||||
"squeeze": escape.squeeze,
|
||||
"linkify": escape.linkify,
|
||||
"datetime": datetime,
|
||||
"_tt_utf8": escape.utf8, # for internal use
|
||||
"_tt_string_types": (unicode_type, bytes),
|
||||
# __name__ and __loader__ allow the traceback mechanism to find
|
||||
# the generated source code.
|
||||
"__name__": self.name.replace('.', '_'),
|
||||
"__loader__": ObjectDict(get_source=lambda name: self.code),
|
||||
}
|
||||
namespace.update(self.namespace)
|
||||
namespace.update(kwargs)
|
||||
exec_in(self.compiled, namespace)
|
||||
execute = namespace["_tt_execute"]
|
||||
# Clear the traceback module's cache of source data now that
|
||||
# we've generated a new template (mainly for this module's
|
||||
# unittests, where different tests reuse the same name).
|
||||
linecache.clearcache()
|
||||
return execute()
|
||||
|
||||
def _generate_python(self, loader):
|
||||
buffer = StringIO()
|
||||
try:
|
||||
# named_blocks maps from names to _NamedBlock objects
|
||||
named_blocks = {}
|
||||
ancestors = self._get_ancestors(loader)
|
||||
ancestors.reverse()
|
||||
for ancestor in ancestors:
|
||||
ancestor.find_named_blocks(loader, named_blocks)
|
||||
writer = _CodeWriter(buffer, named_blocks, loader,
|
||||
ancestors[0].template)
|
||||
ancestors[0].generate(writer)
|
||||
return buffer.getvalue()
|
||||
finally:
|
||||
buffer.close()
|
||||
|
||||
def _get_ancestors(self, loader):
|
||||
ancestors = [self.file]
|
||||
for chunk in self.file.body.chunks:
|
||||
if isinstance(chunk, _ExtendsBlock):
|
||||
if not loader:
|
||||
raise ParseError("{% extends %} block found, but no "
|
||||
"template loader")
|
||||
template = loader.load(chunk.name, self.name)
|
||||
ancestors.extend(template._get_ancestors(loader))
|
||||
return ancestors
|
||||
|
||||
|
||||
class BaseLoader(object):
|
||||
"""Base class for template loaders.
|
||||
|
||||
You must use a template loader to use template constructs like
|
||||
``{% extends %}`` and ``{% include %}``. The loader caches all
|
||||
templates after they are loaded the first time.
|
||||
"""
|
||||
def __init__(self, autoescape=_DEFAULT_AUTOESCAPE, namespace=None,
|
||||
whitespace=None):
|
||||
"""Construct a template loader.
|
||||
|
||||
:arg str autoescape: The name of a function in the template
|
||||
namespace, such as "xhtml_escape", or ``None`` to disable
|
||||
autoescaping by default.
|
||||
:arg dict namespace: A dictionary to be added to the default template
|
||||
namespace, or ``None``.
|
||||
:arg str whitespace: A string specifying default behavior for
|
||||
whitespace in templates; see `filter_whitespace` for options.
|
||||
Default is "single" for files ending in ".html" and ".js" and
|
||||
"all" for other files.
|
||||
|
||||
.. versionchanged:: 4.3
|
||||
Added ``whitespace`` parameter.
|
||||
"""
|
||||
self.autoescape = autoescape
|
||||
self.namespace = namespace or {}
|
||||
self.whitespace = whitespace
|
||||
self.templates = {}
|
||||
# self.lock protects self.templates. It's a reentrant lock
|
||||
# because templates may load other templates via `include` or
|
||||
# `extends`. Note that thanks to the GIL this code would be safe
|
||||
# even without the lock, but could lead to wasted work as multiple
|
||||
# threads tried to compile the same template simultaneously.
|
||||
self.lock = threading.RLock()
|
||||
|
||||
def reset(self):
|
||||
"""Resets the cache of compiled templates."""
|
||||
with self.lock:
|
||||
self.templates = {}
|
||||
|
||||
def resolve_path(self, name, parent_path=None):
|
||||
"""Converts a possibly-relative path to absolute (used internally)."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def load(self, name, parent_path=None):
|
||||
"""Loads a template."""
|
||||
name = self.resolve_path(name, parent_path=parent_path)
|
||||
with self.lock:
|
||||
if name not in self.templates:
|
||||
self.templates[name] = self._create_template(name)
|
||||
return self.templates[name]
|
||||
|
||||
def _create_template(self, name):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class Loader(BaseLoader):
|
||||
"""A template loader that loads from a single root directory.
|
||||
"""
|
||||
def __init__(self, root_directory, **kwargs):
|
||||
super(Loader, self).__init__(**kwargs)
|
||||
self.root = os.path.abspath(root_directory)
|
||||
|
||||
def resolve_path(self, name, parent_path=None):
|
||||
if parent_path and not parent_path.startswith("<") and \
|
||||
not parent_path.startswith("/") and \
|
||||
not name.startswith("/"):
|
||||
current_path = os.path.join(self.root, parent_path)
|
||||
file_dir = os.path.dirname(os.path.abspath(current_path))
|
||||
relative_path = os.path.abspath(os.path.join(file_dir, name))
|
||||
if relative_path.startswith(self.root):
|
||||
name = relative_path[len(self.root) + 1:]
|
||||
return name
|
||||
|
||||
def _create_template(self, name):
|
||||
path = os.path.join(self.root, name)
|
||||
with open(path, "rb") as f:
|
||||
template = Template(f.read(), name=name, loader=self)
|
||||
return template
|
||||
|
||||
|
||||
class DictLoader(BaseLoader):
|
||||
"""A template loader that loads from a dictionary."""
|
||||
def __init__(self, dict, **kwargs):
|
||||
super(DictLoader, self).__init__(**kwargs)
|
||||
self.dict = dict
|
||||
|
||||
def resolve_path(self, name, parent_path=None):
|
||||
if parent_path and not parent_path.startswith("<") and \
|
||||
not parent_path.startswith("/") and \
|
||||
not name.startswith("/"):
|
||||
file_dir = posixpath.dirname(parent_path)
|
||||
name = posixpath.normpath(posixpath.join(file_dir, name))
|
||||
return name
|
||||
|
||||
def _create_template(self, name):
|
||||
return Template(self.dict[name], name=name, loader=self)
|
||||
|
||||
|
||||
class _Node(object):
|
||||
def each_child(self):
|
||||
return ()
|
||||
|
||||
def generate(self, writer):
|
||||
raise NotImplementedError()
|
||||
|
||||
def find_named_blocks(self, loader, named_blocks):
|
||||
for child in self.each_child():
|
||||
child.find_named_blocks(loader, named_blocks)
|
||||
|
||||
|
||||
class _File(_Node):
|
||||
def __init__(self, template, body):
|
||||
self.template = template
|
||||
self.body = body
|
||||
self.line = 0
|
||||
|
||||
def generate(self, writer):
|
||||
writer.write_line("def _tt_execute():", self.line)
|
||||
with writer.indent():
|
||||
writer.write_line("_tt_buffer = []", self.line)
|
||||
writer.write_line("_tt_append = _tt_buffer.append", self.line)
|
||||
self.body.generate(writer)
|
||||
writer.write_line("return _tt_utf8('').join(_tt_buffer)", self.line)
|
||||
|
||||
def each_child(self):
|
||||
return (self.body,)
|
||||
|
||||
|
||||
class _ChunkList(_Node):
|
||||
def __init__(self, chunks):
|
||||
self.chunks = chunks
|
||||
|
||||
def generate(self, writer):
|
||||
for chunk in self.chunks:
|
||||
chunk.generate(writer)
|
||||
|
||||
def each_child(self):
|
||||
return self.chunks
|
||||
|
||||
|
||||
class _NamedBlock(_Node):
|
||||
def __init__(self, name, body, template, line):
|
||||
self.name = name
|
||||
self.body = body
|
||||
self.template = template
|
||||
self.line = line
|
||||
|
||||
def each_child(self):
|
||||
return (self.body,)
|
||||
|
||||
def generate(self, writer):
|
||||
block = writer.named_blocks[self.name]
|
||||
with writer.include(block.template, self.line):
|
||||
block.body.generate(writer)
|
||||
|
||||
def find_named_blocks(self, loader, named_blocks):
|
||||
named_blocks[self.name] = self
|
||||
_Node.find_named_blocks(self, loader, named_blocks)
|
||||
|
||||
|
||||
class _ExtendsBlock(_Node):
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
||||
|
||||
class _IncludeBlock(_Node):
|
||||
def __init__(self, name, reader, line):
|
||||
self.name = name
|
||||
self.template_name = reader.name
|
||||
self.line = line
|
||||
|
||||
def find_named_blocks(self, loader, named_blocks):
|
||||
included = loader.load(self.name, self.template_name)
|
||||
included.file.find_named_blocks(loader, named_blocks)
|
||||
|
||||
def generate(self, writer):
|
||||
included = writer.loader.load(self.name, self.template_name)
|
||||
with writer.include(included, self.line):
|
||||
included.file.body.generate(writer)
|
||||
|
||||
|
||||
class _ApplyBlock(_Node):
|
||||
def __init__(self, method, line, body=None):
|
||||
self.method = method
|
||||
self.line = line
|
||||
self.body = body
|
||||
|
||||
def each_child(self):
|
||||
return (self.body,)
|
||||
|
||||
def generate(self, writer):
|
||||
method_name = "_tt_apply%d" % writer.apply_counter
|
||||
writer.apply_counter += 1
|
||||
writer.write_line("def %s():" % method_name, self.line)
|
||||
with writer.indent():
|
||||
writer.write_line("_tt_buffer = []", self.line)
|
||||
writer.write_line("_tt_append = _tt_buffer.append", self.line)
|
||||
self.body.generate(writer)
|
||||
writer.write_line("return _tt_utf8('').join(_tt_buffer)", self.line)
|
||||
writer.write_line("_tt_append(_tt_utf8(%s(%s())))" % (
|
||||
self.method, method_name), self.line)
|
||||
|
||||
|
||||
class _ControlBlock(_Node):
|
||||
def __init__(self, statement, line, body=None):
|
||||
self.statement = statement
|
||||
self.line = line
|
||||
self.body = body
|
||||
|
||||
def each_child(self):
|
||||
return (self.body,)
|
||||
|
||||
def generate(self, writer):
|
||||
writer.write_line("%s:" % self.statement, self.line)
|
||||
with writer.indent():
|
||||
self.body.generate(writer)
|
||||
# Just in case the body was empty
|
||||
writer.write_line("pass", self.line)
|
||||
|
||||
|
||||
class _IntermediateControlBlock(_Node):
|
||||
def __init__(self, statement, line):
|
||||
self.statement = statement
|
||||
self.line = line
|
||||
|
||||
def generate(self, writer):
|
||||
# In case the previous block was empty
|
||||
writer.write_line("pass", self.line)
|
||||
writer.write_line("%s:" % self.statement, self.line, writer.indent_size() - 1)
|
||||
|
||||
|
||||
class _Statement(_Node):
|
||||
def __init__(self, statement, line):
|
||||
self.statement = statement
|
||||
self.line = line
|
||||
|
||||
def generate(self, writer):
|
||||
writer.write_line(self.statement, self.line)
|
||||
|
||||
|
||||
class _Expression(_Node):
|
||||
def __init__(self, expression, line, raw=False):
|
||||
self.expression = expression
|
||||
self.line = line
|
||||
self.raw = raw
|
||||
|
||||
def generate(self, writer):
|
||||
writer.write_line("_tt_tmp = %s" % self.expression, self.line)
|
||||
writer.write_line("if isinstance(_tt_tmp, _tt_string_types):"
|
||||
" _tt_tmp = _tt_utf8(_tt_tmp)", self.line)
|
||||
writer.write_line("else: _tt_tmp = _tt_utf8(str(_tt_tmp))", self.line)
|
||||
if not self.raw and writer.current_template.autoescape is not None:
|
||||
# In python3 functions like xhtml_escape return unicode,
|
||||
# so we have to convert to utf8 again.
|
||||
writer.write_line("_tt_tmp = _tt_utf8(%s(_tt_tmp))" %
|
||||
writer.current_template.autoescape, self.line)
|
||||
writer.write_line("_tt_append(_tt_tmp)", self.line)
|
||||
|
||||
|
||||
class _Module(_Expression):
|
||||
def __init__(self, expression, line):
|
||||
super(_Module, self).__init__("_tt_modules." + expression, line,
|
||||
raw=True)
|
||||
|
||||
|
||||
class _Text(_Node):
|
||||
def __init__(self, value, line, whitespace):
|
||||
self.value = value
|
||||
self.line = line
|
||||
self.whitespace = whitespace
|
||||
|
||||
def generate(self, writer):
|
||||
value = self.value
|
||||
|
||||
# Compress whitespace if requested, with a crude heuristic to avoid
|
||||
# altering preformatted whitespace.
|
||||
if "<pre>" not in value:
|
||||
value = filter_whitespace(self.whitespace, value)
|
||||
|
||||
if value:
|
||||
writer.write_line('_tt_append(%r)' % escape.utf8(value), self.line)
|
||||
|
||||
|
||||
class ParseError(Exception):
|
||||
"""Raised for template syntax errors.
|
||||
|
||||
``ParseError`` instances have ``filename`` and ``lineno`` attributes
|
||||
indicating the position of the error.
|
||||
|
||||
.. versionchanged:: 4.3
|
||||
Added ``filename`` and ``lineno`` attributes.
|
||||
"""
|
||||
def __init__(self, message, filename=None, lineno=0):
|
||||
self.message = message
|
||||
# The names "filename" and "lineno" are chosen for consistency
|
||||
# with python SyntaxError.
|
||||
self.filename = filename
|
||||
self.lineno = lineno
|
||||
|
||||
def __str__(self):
|
||||
return '%s at %s:%d' % (self.message, self.filename, self.lineno)
|
||||
|
||||
|
||||
class _CodeWriter(object):
|
||||
def __init__(self, file, named_blocks, loader, current_template):
|
||||
self.file = file
|
||||
self.named_blocks = named_blocks
|
||||
self.loader = loader
|
||||
self.current_template = current_template
|
||||
self.apply_counter = 0
|
||||
self.include_stack = []
|
||||
self._indent = 0
|
||||
|
||||
def indent_size(self):
|
||||
return self._indent
|
||||
|
||||
def indent(self):
|
||||
class Indenter(object):
|
||||
def __enter__(_):
|
||||
self._indent += 1
|
||||
return self
|
||||
|
||||
def __exit__(_, *args):
|
||||
assert self._indent > 0
|
||||
self._indent -= 1
|
||||
|
||||
return Indenter()
|
||||
|
||||
def include(self, template, line):
|
||||
self.include_stack.append((self.current_template, line))
|
||||
self.current_template = template
|
||||
|
||||
class IncludeTemplate(object):
|
||||
def __enter__(_):
|
||||
return self
|
||||
|
||||
def __exit__(_, *args):
|
||||
self.current_template = self.include_stack.pop()[0]
|
||||
|
||||
return IncludeTemplate()
|
||||
|
||||
def write_line(self, line, line_number, indent=None):
|
||||
if indent is None:
|
||||
indent = self._indent
|
||||
line_comment = ' # %s:%d' % (self.current_template.name, line_number)
|
||||
if self.include_stack:
|
||||
ancestors = ["%s:%d" % (tmpl.name, lineno)
|
||||
for (tmpl, lineno) in self.include_stack]
|
||||
line_comment += ' (via %s)' % ', '.join(reversed(ancestors))
|
||||
print(" " * indent + line + line_comment, file=self.file)
|
||||
|
||||
|
||||
class _TemplateReader(object):
|
||||
def __init__(self, name, text, whitespace):
|
||||
self.name = name
|
||||
self.text = text
|
||||
self.whitespace = whitespace
|
||||
self.line = 1
|
||||
self.pos = 0
|
||||
|
||||
def find(self, needle, start=0, end=None):
|
||||
assert start >= 0, start
|
||||
pos = self.pos
|
||||
start += pos
|
||||
if end is None:
|
||||
index = self.text.find(needle, start)
|
||||
else:
|
||||
end += pos
|
||||
assert end >= start
|
||||
index = self.text.find(needle, start, end)
|
||||
if index != -1:
|
||||
index -= pos
|
||||
return index
|
||||
|
||||
def consume(self, count=None):
|
||||
if count is None:
|
||||
count = len(self.text) - self.pos
|
||||
newpos = self.pos + count
|
||||
self.line += self.text.count("\n", self.pos, newpos)
|
||||
s = self.text[self.pos:newpos]
|
||||
self.pos = newpos
|
||||
return s
|
||||
|
||||
def remaining(self):
|
||||
return len(self.text) - self.pos
|
||||
|
||||
def __len__(self):
|
||||
return self.remaining()
|
||||
|
||||
def __getitem__(self, key):
|
||||
if type(key) is slice:
|
||||
size = len(self)
|
||||
start, stop, step = key.indices(size)
|
||||
if start is None:
|
||||
start = self.pos
|
||||
else:
|
||||
start += self.pos
|
||||
if stop is not None:
|
||||
stop += self.pos
|
||||
return self.text[slice(start, stop, step)]
|
||||
elif key < 0:
|
||||
return self.text[key]
|
||||
else:
|
||||
return self.text[self.pos + key]
|
||||
|
||||
def __str__(self):
|
||||
return self.text[self.pos:]
|
||||
|
||||
def raise_parse_error(self, msg):
|
||||
raise ParseError(msg, self.name, self.line)
|
||||
|
||||
|
||||
def _format_code(code):
|
||||
lines = code.splitlines()
|
||||
format = "%%%dd %%s\n" % len(repr(len(lines) + 1))
|
||||
return "".join([format % (i + 1, line) for (i, line) in enumerate(lines)])
|
||||
|
||||
|
||||
def _parse(reader, template, in_block=None, in_loop=None):
|
||||
body = _ChunkList([])
|
||||
while True:
|
||||
# Find next template directive
|
||||
curly = 0
|
||||
while True:
|
||||
curly = reader.find("{", curly)
|
||||
if curly == -1 or curly + 1 == reader.remaining():
|
||||
# EOF
|
||||
if in_block:
|
||||
reader.raise_parse_error(
|
||||
"Missing {%% end %%} block for %s" % in_block)
|
||||
body.chunks.append(_Text(reader.consume(), reader.line,
|
||||
reader.whitespace))
|
||||
return body
|
||||
# If the first curly brace is not the start of a special token,
|
||||
# start searching from the character after it
|
||||
if reader[curly + 1] not in ("{", "%", "#"):
|
||||
curly += 1
|
||||
continue
|
||||
# When there are more than 2 curlies in a row, use the
|
||||
# innermost ones. This is useful when generating languages
|
||||
# like latex where curlies are also meaningful
|
||||
if (curly + 2 < reader.remaining() and
|
||||
reader[curly + 1] == '{' and reader[curly + 2] == '{'):
|
||||
curly += 1
|
||||
continue
|
||||
break
|
||||
|
||||
# Append any text before the special token
|
||||
if curly > 0:
|
||||
cons = reader.consume(curly)
|
||||
body.chunks.append(_Text(cons, reader.line,
|
||||
reader.whitespace))
|
||||
|
||||
start_brace = reader.consume(2)
|
||||
line = reader.line
|
||||
|
||||
# Template directives may be escaped as "{{!" or "{%!".
|
||||
# In this case output the braces and consume the "!".
|
||||
# This is especially useful in conjunction with jquery templates,
|
||||
# which also use double braces.
|
||||
if reader.remaining() and reader[0] == "!":
|
||||
reader.consume(1)
|
||||
body.chunks.append(_Text(start_brace, line,
|
||||
reader.whitespace))
|
||||
continue
|
||||
|
||||
# Comment
|
||||
if start_brace == "{#":
|
||||
end = reader.find("#}")
|
||||
if end == -1:
|
||||
reader.raise_parse_error("Missing end comment #}")
|
||||
contents = reader.consume(end).strip()
|
||||
reader.consume(2)
|
||||
continue
|
||||
|
||||
# Expression
|
||||
if start_brace == "{{":
|
||||
end = reader.find("}}")
|
||||
if end == -1:
|
||||
reader.raise_parse_error("Missing end expression }}")
|
||||
contents = reader.consume(end).strip()
|
||||
reader.consume(2)
|
||||
if not contents:
|
||||
reader.raise_parse_error("Empty expression")
|
||||
body.chunks.append(_Expression(contents, line))
|
||||
continue
|
||||
|
||||
# Block
|
||||
assert start_brace == "{%", start_brace
|
||||
end = reader.find("%}")
|
||||
if end == -1:
|
||||
reader.raise_parse_error("Missing end block %}")
|
||||
contents = reader.consume(end).strip()
|
||||
reader.consume(2)
|
||||
if not contents:
|
||||
reader.raise_parse_error("Empty block tag ({% %})")
|
||||
|
||||
operator, space, suffix = contents.partition(" ")
|
||||
suffix = suffix.strip()
|
||||
|
||||
# Intermediate ("else", "elif", etc) blocks
|
||||
intermediate_blocks = {
|
||||
"else": set(["if", "for", "while", "try"]),
|
||||
"elif": set(["if"]),
|
||||
"except": set(["try"]),
|
||||
"finally": set(["try"]),
|
||||
}
|
||||
allowed_parents = intermediate_blocks.get(operator)
|
||||
if allowed_parents is not None:
|
||||
if not in_block:
|
||||
reader.raise_parse_error("%s outside %s block" %
|
||||
(operator, allowed_parents))
|
||||
if in_block not in allowed_parents:
|
||||
reader.raise_parse_error(
|
||||
"%s block cannot be attached to %s block" %
|
||||
(operator, in_block))
|
||||
body.chunks.append(_IntermediateControlBlock(contents, line))
|
||||
continue
|
||||
|
||||
# End tag
|
||||
elif operator == "end":
|
||||
if not in_block:
|
||||
reader.raise_parse_error("Extra {% end %} block")
|
||||
return body
|
||||
|
||||
elif operator in ("extends", "include", "set", "import", "from",
|
||||
"comment", "autoescape", "whitespace", "raw",
|
||||
"module"):
|
||||
if operator == "comment":
|
||||
continue
|
||||
if operator == "extends":
|
||||
suffix = suffix.strip('"').strip("'")
|
||||
if not suffix:
|
||||
reader.raise_parse_error("extends missing file path")
|
||||
block = _ExtendsBlock(suffix)
|
||||
elif operator in ("import", "from"):
|
||||
if not suffix:
|
||||
reader.raise_parse_error("import missing statement")
|
||||
block = _Statement(contents, line)
|
||||
elif operator == "include":
|
||||
suffix = suffix.strip('"').strip("'")
|
||||
if not suffix:
|
||||
reader.raise_parse_error("include missing file path")
|
||||
block = _IncludeBlock(suffix, reader, line)
|
||||
elif operator == "set":
|
||||
if not suffix:
|
||||
reader.raise_parse_error("set missing statement")
|
||||
block = _Statement(suffix, line)
|
||||
elif operator == "autoescape":
|
||||
fn = suffix.strip()
|
||||
if fn == "None":
|
||||
fn = None
|
||||
template.autoescape = fn
|
||||
continue
|
||||
elif operator == "whitespace":
|
||||
mode = suffix.strip()
|
||||
# Validate the selected mode
|
||||
filter_whitespace(mode, '')
|
||||
reader.whitespace = mode
|
||||
continue
|
||||
elif operator == "raw":
|
||||
block = _Expression(suffix, line, raw=True)
|
||||
elif operator == "module":
|
||||
block = _Module(suffix, line)
|
||||
body.chunks.append(block)
|
||||
continue
|
||||
|
||||
elif operator in ("apply", "block", "try", "if", "for", "while"):
|
||||
# parse inner body recursively
|
||||
if operator in ("for", "while"):
|
||||
block_body = _parse(reader, template, operator, operator)
|
||||
elif operator == "apply":
|
||||
# apply creates a nested function so syntactically it's not
|
||||
# in the loop.
|
||||
block_body = _parse(reader, template, operator, None)
|
||||
else:
|
||||
block_body = _parse(reader, template, operator, in_loop)
|
||||
|
||||
if operator == "apply":
|
||||
if not suffix:
|
||||
reader.raise_parse_error("apply missing method name")
|
||||
block = _ApplyBlock(suffix, line, block_body)
|
||||
elif operator == "block":
|
||||
if not suffix:
|
||||
reader.raise_parse_error("block missing name")
|
||||
block = _NamedBlock(suffix, block_body, template, line)
|
||||
else:
|
||||
block = _ControlBlock(contents, line, block_body)
|
||||
body.chunks.append(block)
|
||||
continue
|
||||
|
||||
elif operator in ("break", "continue"):
|
||||
if not in_loop:
|
||||
reader.raise_parse_error("%s outside %s block" %
|
||||
(operator, set(["for", "while"])))
|
||||
body.chunks.append(_Statement(contents, line))
|
||||
continue
|
||||
|
||||
else:
|
||||
reader.raise_parse_error("unknown operator: %r" % operator)
|
||||
Executable
Executable
+14
@@ -0,0 +1,14 @@
|
||||
"""Shim to allow python -m tornado.test.
|
||||
|
||||
This only works in python 2.7+.
|
||||
"""
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
from tornado.test.runtests import all, main
|
||||
|
||||
# tornado.testing.main autodiscovery relies on 'all' being present in
|
||||
# the main module, so import it here even though it is not used directly.
|
||||
# The following line prevents a pyflakes warning.
|
||||
all = all
|
||||
|
||||
main()
|
||||
Executable
+206
@@ -0,0 +1,206 @@
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from tornado import gen
|
||||
from tornado.ioloop import IOLoop
|
||||
from tornado.testing import AsyncTestCase, gen_test
|
||||
from tornado.test.util import unittest, skipBefore33, skipBefore35, exec_test
|
||||
|
||||
try:
|
||||
from tornado.platform.asyncio import asyncio
|
||||
except ImportError:
|
||||
asyncio = None
|
||||
else:
|
||||
from tornado.platform.asyncio import AsyncIOLoop, to_asyncio_future, AnyThreadEventLoopPolicy
|
||||
# This is used in dynamically-evaluated code, so silence pyflakes.
|
||||
to_asyncio_future
|
||||
|
||||
|
||||
@unittest.skipIf(asyncio is None, "asyncio module not present")
|
||||
class AsyncIOLoopTest(AsyncTestCase):
|
||||
def get_new_ioloop(self):
|
||||
io_loop = AsyncIOLoop()
|
||||
return io_loop
|
||||
|
||||
def test_asyncio_callback(self):
|
||||
# Basic test that the asyncio loop is set up correctly.
|
||||
asyncio.get_event_loop().call_soon(self.stop)
|
||||
self.wait()
|
||||
|
||||
@gen_test
|
||||
def test_asyncio_future(self):
|
||||
# Test that we can yield an asyncio future from a tornado coroutine.
|
||||
# Without 'yield from', we must wrap coroutines in ensure_future,
|
||||
# which was introduced during Python 3.4, deprecating the prior "async".
|
||||
if hasattr(asyncio, 'ensure_future'):
|
||||
ensure_future = asyncio.ensure_future
|
||||
else:
|
||||
# async is a reserved word in Python 3.7
|
||||
ensure_future = getattr(asyncio, 'async')
|
||||
|
||||
x = yield ensure_future(
|
||||
asyncio.get_event_loop().run_in_executor(None, lambda: 42))
|
||||
self.assertEqual(x, 42)
|
||||
|
||||
@skipBefore33
|
||||
@gen_test
|
||||
def test_asyncio_yield_from(self):
|
||||
# Test that we can use asyncio coroutines with 'yield from'
|
||||
# instead of asyncio.async(). This requires python 3.3 syntax.
|
||||
namespace = exec_test(globals(), locals(), """
|
||||
@gen.coroutine
|
||||
def f():
|
||||
event_loop = asyncio.get_event_loop()
|
||||
x = yield from event_loop.run_in_executor(None, lambda: 42)
|
||||
return x
|
||||
""")
|
||||
result = yield namespace['f']()
|
||||
self.assertEqual(result, 42)
|
||||
|
||||
@skipBefore35
|
||||
def test_asyncio_adapter(self):
|
||||
# This test demonstrates that when using the asyncio coroutine
|
||||
# runner (i.e. run_until_complete), the to_asyncio_future
|
||||
# adapter is needed. No adapter is needed in the other direction,
|
||||
# as demonstrated by other tests in the package.
|
||||
@gen.coroutine
|
||||
def tornado_coroutine():
|
||||
yield gen.moment
|
||||
raise gen.Return(42)
|
||||
native_coroutine_without_adapter = exec_test(globals(), locals(), """
|
||||
async def native_coroutine_without_adapter():
|
||||
return await tornado_coroutine()
|
||||
""")["native_coroutine_without_adapter"]
|
||||
|
||||
native_coroutine_with_adapter = exec_test(globals(), locals(), """
|
||||
async def native_coroutine_with_adapter():
|
||||
return await to_asyncio_future(tornado_coroutine())
|
||||
""")["native_coroutine_with_adapter"]
|
||||
|
||||
# Use the adapter, but two degrees from the tornado coroutine.
|
||||
native_coroutine_with_adapter2 = exec_test(globals(), locals(), """
|
||||
async def native_coroutine_with_adapter2():
|
||||
return await to_asyncio_future(native_coroutine_without_adapter())
|
||||
""")["native_coroutine_with_adapter2"]
|
||||
|
||||
# Tornado supports native coroutines both with and without adapters
|
||||
self.assertEqual(
|
||||
self.io_loop.run_sync(native_coroutine_without_adapter),
|
||||
42)
|
||||
self.assertEqual(
|
||||
self.io_loop.run_sync(native_coroutine_with_adapter),
|
||||
42)
|
||||
self.assertEqual(
|
||||
self.io_loop.run_sync(native_coroutine_with_adapter2),
|
||||
42)
|
||||
|
||||
# Asyncio only supports coroutines that yield asyncio-compatible
|
||||
# Futures (which our Future is since 5.0).
|
||||
self.assertEqual(
|
||||
asyncio.get_event_loop().run_until_complete(
|
||||
native_coroutine_without_adapter()),
|
||||
42)
|
||||
self.assertEqual(
|
||||
asyncio.get_event_loop().run_until_complete(
|
||||
native_coroutine_with_adapter()),
|
||||
42)
|
||||
self.assertEqual(
|
||||
asyncio.get_event_loop().run_until_complete(
|
||||
native_coroutine_with_adapter2()),
|
||||
42)
|
||||
|
||||
|
||||
@unittest.skipIf(asyncio is None, "asyncio module not present")
|
||||
class LeakTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
# Trigger a cleanup of the mapping so we start with a clean slate.
|
||||
AsyncIOLoop().close()
|
||||
# If we don't clean up after ourselves other tests may fail on
|
||||
# py34.
|
||||
self.orig_policy = asyncio.get_event_loop_policy()
|
||||
asyncio.set_event_loop_policy(asyncio.DefaultEventLoopPolicy())
|
||||
|
||||
def tearDown(self):
|
||||
asyncio.get_event_loop().close()
|
||||
asyncio.set_event_loop_policy(self.orig_policy)
|
||||
|
||||
def test_ioloop_close_leak(self):
|
||||
orig_count = len(IOLoop._ioloop_for_asyncio)
|
||||
for i in range(10):
|
||||
# Create and close an AsyncIOLoop using Tornado interfaces.
|
||||
loop = AsyncIOLoop()
|
||||
loop.close()
|
||||
new_count = len(IOLoop._ioloop_for_asyncio) - orig_count
|
||||
self.assertEqual(new_count, 0)
|
||||
|
||||
def test_asyncio_close_leak(self):
|
||||
orig_count = len(IOLoop._ioloop_for_asyncio)
|
||||
for i in range(10):
|
||||
# Create and close an AsyncIOMainLoop using asyncio interfaces.
|
||||
loop = asyncio.new_event_loop()
|
||||
loop.call_soon(IOLoop.current)
|
||||
loop.call_soon(loop.stop)
|
||||
loop.run_forever()
|
||||
loop.close()
|
||||
new_count = len(IOLoop._ioloop_for_asyncio) - orig_count
|
||||
# Because the cleanup is run on new loop creation, we have one
|
||||
# dangling entry in the map (but only one).
|
||||
self.assertEqual(new_count, 1)
|
||||
|
||||
|
||||
@unittest.skipIf(asyncio is None, "asyncio module not present")
|
||||
class AnyThreadEventLoopPolicyTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.orig_policy = asyncio.get_event_loop_policy()
|
||||
self.executor = ThreadPoolExecutor(1)
|
||||
|
||||
def tearDown(self):
|
||||
asyncio.set_event_loop_policy(self.orig_policy)
|
||||
self.executor.shutdown()
|
||||
|
||||
def get_event_loop_on_thread(self):
|
||||
def get_and_close_event_loop():
|
||||
"""Get the event loop. Close it if one is returned.
|
||||
|
||||
Returns the (closed) event loop. This is a silly thing
|
||||
to do and leaves the thread in a broken state, but it's
|
||||
enough for this test. Closing the loop avoids resource
|
||||
leak warnings.
|
||||
"""
|
||||
loop = asyncio.get_event_loop()
|
||||
loop.close()
|
||||
return loop
|
||||
future = self.executor.submit(get_and_close_event_loop)
|
||||
return future.result()
|
||||
|
||||
def run_policy_test(self, accessor, expected_type):
|
||||
# With the default policy, non-main threads don't get an event
|
||||
# loop.
|
||||
self.assertRaises((RuntimeError, AssertionError),
|
||||
self.executor.submit(accessor).result)
|
||||
# Set the policy and we can get a loop.
|
||||
asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
|
||||
self.assertIsInstance(
|
||||
self.executor.submit(accessor).result(),
|
||||
expected_type)
|
||||
# Clean up to silence leak warnings. Always use asyncio since
|
||||
# IOLoop doesn't (currently) close the underlying loop.
|
||||
self.executor.submit(lambda: asyncio.get_event_loop().close()).result()
|
||||
|
||||
def test_asyncio_accessor(self):
|
||||
self.run_policy_test(asyncio.get_event_loop, asyncio.AbstractEventLoop)
|
||||
|
||||
def test_tornado_accessor(self):
|
||||
self.run_policy_test(IOLoop.current, IOLoop)
|
||||
Executable
+735
@@ -0,0 +1,735 @@
|
||||
# These tests do not currently do much to verify the correct implementation
|
||||
# of the openid/oauth protocols, they just exercise the major code paths
|
||||
# and ensure that it doesn't blow up (e.g. with unicode/bytes issues in
|
||||
# python 3)
|
||||
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import unittest
|
||||
import warnings
|
||||
|
||||
from tornado.auth import (
|
||||
AuthError, OpenIdMixin, OAuthMixin, OAuth2Mixin,
|
||||
GoogleOAuth2Mixin, FacebookGraphMixin, TwitterMixin,
|
||||
)
|
||||
from tornado.concurrent import Future
|
||||
from tornado.escape import json_decode
|
||||
from tornado import gen
|
||||
from tornado.httputil import url_concat
|
||||
from tornado.log import gen_log, app_log
|
||||
from tornado.testing import AsyncHTTPTestCase, ExpectLog
|
||||
from tornado.test.util import ignore_deprecation
|
||||
from tornado.web import RequestHandler, Application, asynchronous, HTTPError
|
||||
|
||||
try:
|
||||
from unittest import mock
|
||||
except ImportError:
|
||||
mock = None
|
||||
|
||||
|
||||
class OpenIdClientLoginHandlerLegacy(RequestHandler, OpenIdMixin):
|
||||
def initialize(self, test):
|
||||
self._OPENID_ENDPOINT = test.get_url('/openid/server/authenticate')
|
||||
|
||||
with ignore_deprecation():
|
||||
@asynchronous
|
||||
def get(self):
|
||||
if self.get_argument('openid.mode', None):
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter('ignore', DeprecationWarning)
|
||||
self.get_authenticated_user(
|
||||
self.on_user, http_client=self.settings['http_client'])
|
||||
return
|
||||
res = self.authenticate_redirect()
|
||||
assert isinstance(res, Future)
|
||||
assert res.done()
|
||||
|
||||
def on_user(self, user):
|
||||
if user is None:
|
||||
raise Exception("user is None")
|
||||
self.finish(user)
|
||||
|
||||
|
||||
class OpenIdClientLoginHandler(RequestHandler, OpenIdMixin):
|
||||
def initialize(self, test):
|
||||
self._OPENID_ENDPOINT = test.get_url('/openid/server/authenticate')
|
||||
|
||||
@gen.coroutine
|
||||
def get(self):
|
||||
if self.get_argument('openid.mode', None):
|
||||
user = yield self.get_authenticated_user(http_client=self.settings['http_client'])
|
||||
if user is None:
|
||||
raise Exception("user is None")
|
||||
self.finish(user)
|
||||
return
|
||||
res = self.authenticate_redirect()
|
||||
assert isinstance(res, Future)
|
||||
assert res.done()
|
||||
|
||||
|
||||
class OpenIdServerAuthenticateHandler(RequestHandler):
|
||||
def post(self):
|
||||
if self.get_argument('openid.mode') != 'check_authentication':
|
||||
raise Exception("incorrect openid.mode %r")
|
||||
self.write('is_valid:true')
|
||||
|
||||
|
||||
class OAuth1ClientLoginHandlerLegacy(RequestHandler, OAuthMixin):
|
||||
def initialize(self, test, version):
|
||||
self._OAUTH_VERSION = version
|
||||
self._OAUTH_REQUEST_TOKEN_URL = test.get_url('/oauth1/server/request_token')
|
||||
self._OAUTH_AUTHORIZE_URL = test.get_url('/oauth1/server/authorize')
|
||||
self._OAUTH_ACCESS_TOKEN_URL = test.get_url('/oauth1/server/access_token')
|
||||
|
||||
def _oauth_consumer_token(self):
|
||||
return dict(key='asdf', secret='qwer')
|
||||
|
||||
with ignore_deprecation():
|
||||
@asynchronous
|
||||
def get(self):
|
||||
if self.get_argument('oauth_token', None):
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter('ignore', DeprecationWarning)
|
||||
self.get_authenticated_user(
|
||||
self.on_user, http_client=self.settings['http_client'])
|
||||
return
|
||||
res = self.authorize_redirect(http_client=self.settings['http_client'])
|
||||
assert isinstance(res, Future)
|
||||
|
||||
def on_user(self, user):
|
||||
if user is None:
|
||||
raise Exception("user is None")
|
||||
self.finish(user)
|
||||
|
||||
def _oauth_get_user(self, access_token, callback):
|
||||
if self.get_argument('fail_in_get_user', None):
|
||||
raise Exception("failing in get_user")
|
||||
if access_token != dict(key='uiop', secret='5678'):
|
||||
raise Exception("incorrect access token %r" % access_token)
|
||||
callback(dict(email='foo@example.com'))
|
||||
|
||||
|
||||
class OAuth1ClientLoginHandler(RequestHandler, OAuthMixin):
|
||||
def initialize(self, test, version):
|
||||
self._OAUTH_VERSION = version
|
||||
self._OAUTH_REQUEST_TOKEN_URL = test.get_url('/oauth1/server/request_token')
|
||||
self._OAUTH_AUTHORIZE_URL = test.get_url('/oauth1/server/authorize')
|
||||
self._OAUTH_ACCESS_TOKEN_URL = test.get_url('/oauth1/server/access_token')
|
||||
|
||||
def _oauth_consumer_token(self):
|
||||
return dict(key='asdf', secret='qwer')
|
||||
|
||||
@gen.coroutine
|
||||
def get(self):
|
||||
if self.get_argument('oauth_token', None):
|
||||
user = yield self.get_authenticated_user(http_client=self.settings['http_client'])
|
||||
if user is None:
|
||||
raise Exception("user is None")
|
||||
self.finish(user)
|
||||
return
|
||||
yield self.authorize_redirect(http_client=self.settings['http_client'])
|
||||
|
||||
@gen.coroutine
|
||||
def _oauth_get_user_future(self, access_token):
|
||||
if self.get_argument('fail_in_get_user', None):
|
||||
raise Exception("failing in get_user")
|
||||
if access_token != dict(key='uiop', secret='5678'):
|
||||
raise Exception("incorrect access token %r" % access_token)
|
||||
return dict(email='foo@example.com')
|
||||
|
||||
|
||||
class OAuth1ClientLoginCoroutineHandler(OAuth1ClientLoginHandler):
|
||||
"""Replaces OAuth1ClientLoginCoroutineHandler's get() with a coroutine."""
|
||||
@gen.coroutine
|
||||
def get(self):
|
||||
if self.get_argument('oauth_token', None):
|
||||
# Ensure that any exceptions are set on the returned Future,
|
||||
# not simply thrown into the surrounding StackContext.
|
||||
try:
|
||||
yield self.get_authenticated_user()
|
||||
except Exception as e:
|
||||
self.set_status(503)
|
||||
self.write("got exception: %s" % e)
|
||||
else:
|
||||
yield self.authorize_redirect()
|
||||
|
||||
|
||||
class OAuth1ClientRequestParametersHandler(RequestHandler, OAuthMixin):
|
||||
def initialize(self, version):
|
||||
self._OAUTH_VERSION = version
|
||||
|
||||
def _oauth_consumer_token(self):
|
||||
return dict(key='asdf', secret='qwer')
|
||||
|
||||
def get(self):
|
||||
params = self._oauth_request_parameters(
|
||||
'http://www.example.com/api/asdf',
|
||||
dict(key='uiop', secret='5678'),
|
||||
parameters=dict(foo='bar'))
|
||||
self.write(params)
|
||||
|
||||
|
||||
class OAuth1ServerRequestTokenHandler(RequestHandler):
|
||||
def get(self):
|
||||
self.write('oauth_token=zxcv&oauth_token_secret=1234')
|
||||
|
||||
|
||||
class OAuth1ServerAccessTokenHandler(RequestHandler):
|
||||
def get(self):
|
||||
self.write('oauth_token=uiop&oauth_token_secret=5678')
|
||||
|
||||
|
||||
class OAuth2ClientLoginHandler(RequestHandler, OAuth2Mixin):
|
||||
def initialize(self, test):
|
||||
self._OAUTH_AUTHORIZE_URL = test.get_url('/oauth2/server/authorize')
|
||||
|
||||
def get(self):
|
||||
res = self.authorize_redirect()
|
||||
assert isinstance(res, Future)
|
||||
assert res.done()
|
||||
|
||||
|
||||
class FacebookClientLoginHandler(RequestHandler, FacebookGraphMixin):
|
||||
def initialize(self, test):
|
||||
self._OAUTH_AUTHORIZE_URL = test.get_url('/facebook/server/authorize')
|
||||
self._OAUTH_ACCESS_TOKEN_URL = test.get_url('/facebook/server/access_token')
|
||||
self._FACEBOOK_BASE_URL = test.get_url('/facebook/server')
|
||||
|
||||
@gen.coroutine
|
||||
def get(self):
|
||||
if self.get_argument("code", None):
|
||||
user = yield self.get_authenticated_user(
|
||||
redirect_uri=self.request.full_url(),
|
||||
client_id=self.settings["facebook_api_key"],
|
||||
client_secret=self.settings["facebook_secret"],
|
||||
code=self.get_argument("code"))
|
||||
self.write(user)
|
||||
else:
|
||||
yield self.authorize_redirect(
|
||||
redirect_uri=self.request.full_url(),
|
||||
client_id=self.settings["facebook_api_key"],
|
||||
extra_params={"scope": "read_stream,offline_access"})
|
||||
|
||||
|
||||
class FacebookServerAccessTokenHandler(RequestHandler):
|
||||
def get(self):
|
||||
self.write(dict(access_token="asdf", expires_in=3600))
|
||||
|
||||
|
||||
class FacebookServerMeHandler(RequestHandler):
|
||||
def get(self):
|
||||
self.write('{}')
|
||||
|
||||
|
||||
class TwitterClientHandler(RequestHandler, TwitterMixin):
|
||||
def initialize(self, test):
|
||||
self._OAUTH_REQUEST_TOKEN_URL = test.get_url('/oauth1/server/request_token')
|
||||
self._OAUTH_ACCESS_TOKEN_URL = test.get_url('/twitter/server/access_token')
|
||||
self._OAUTH_AUTHORIZE_URL = test.get_url('/oauth1/server/authorize')
|
||||
self._OAUTH_AUTHENTICATE_URL = test.get_url('/twitter/server/authenticate')
|
||||
self._TWITTER_BASE_URL = test.get_url('/twitter/api')
|
||||
|
||||
def get_auth_http_client(self):
|
||||
return self.settings['http_client']
|
||||
|
||||
|
||||
class TwitterClientLoginHandlerLegacy(TwitterClientHandler):
|
||||
with ignore_deprecation():
|
||||
@asynchronous
|
||||
def get(self):
|
||||
if self.get_argument("oauth_token", None):
|
||||
self.get_authenticated_user(self.on_user)
|
||||
return
|
||||
self.authorize_redirect()
|
||||
|
||||
def on_user(self, user):
|
||||
if user is None:
|
||||
raise Exception("user is None")
|
||||
self.finish(user)
|
||||
|
||||
|
||||
class TwitterClientLoginHandler(TwitterClientHandler):
|
||||
@gen.coroutine
|
||||
def get(self):
|
||||
if self.get_argument("oauth_token", None):
|
||||
user = yield self.get_authenticated_user()
|
||||
if user is None:
|
||||
raise Exception("user is None")
|
||||
self.finish(user)
|
||||
return
|
||||
yield self.authorize_redirect()
|
||||
|
||||
|
||||
class TwitterClientAuthenticateHandler(TwitterClientHandler):
|
||||
# Like TwitterClientLoginHandler, but uses authenticate_redirect
|
||||
# instead of authorize_redirect.
|
||||
@gen.coroutine
|
||||
def get(self):
|
||||
if self.get_argument("oauth_token", None):
|
||||
user = yield self.get_authenticated_user()
|
||||
if user is None:
|
||||
raise Exception("user is None")
|
||||
self.finish(user)
|
||||
return
|
||||
yield self.authenticate_redirect()
|
||||
|
||||
|
||||
class TwitterClientLoginGenEngineHandler(TwitterClientHandler):
|
||||
with ignore_deprecation():
|
||||
@asynchronous
|
||||
@gen.engine
|
||||
def get(self):
|
||||
if self.get_argument("oauth_token", None):
|
||||
user = yield self.get_authenticated_user()
|
||||
self.finish(user)
|
||||
else:
|
||||
# Old style: with @gen.engine we can ignore the Future from
|
||||
# authorize_redirect.
|
||||
self.authorize_redirect()
|
||||
|
||||
|
||||
class TwitterClientLoginGenCoroutineHandler(TwitterClientHandler):
|
||||
@gen.coroutine
|
||||
def get(self):
|
||||
if self.get_argument("oauth_token", None):
|
||||
user = yield self.get_authenticated_user()
|
||||
self.finish(user)
|
||||
else:
|
||||
# New style: with @gen.coroutine the result must be yielded
|
||||
# or else the request will be auto-finished too soon.
|
||||
yield self.authorize_redirect()
|
||||
|
||||
|
||||
class TwitterClientShowUserHandlerLegacy(TwitterClientHandler):
|
||||
with ignore_deprecation():
|
||||
@asynchronous
|
||||
@gen.engine
|
||||
def get(self):
|
||||
# TODO: would be nice to go through the login flow instead of
|
||||
# cheating with a hard-coded access token.
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter('ignore', DeprecationWarning)
|
||||
response = yield gen.Task(self.twitter_request,
|
||||
'/users/show/%s' % self.get_argument('name'),
|
||||
access_token=dict(key='hjkl', secret='vbnm'))
|
||||
if response is None:
|
||||
self.set_status(500)
|
||||
self.finish('error from twitter request')
|
||||
else:
|
||||
self.finish(response)
|
||||
|
||||
|
||||
class TwitterClientShowUserHandler(TwitterClientHandler):
|
||||
@gen.coroutine
|
||||
def get(self):
|
||||
# TODO: would be nice to go through the login flow instead of
|
||||
# cheating with a hard-coded access token.
|
||||
try:
|
||||
response = yield self.twitter_request(
|
||||
'/users/show/%s' % self.get_argument('name'),
|
||||
access_token=dict(key='hjkl', secret='vbnm'))
|
||||
except AuthError:
|
||||
self.set_status(500)
|
||||
self.finish('error from twitter request')
|
||||
else:
|
||||
self.finish(response)
|
||||
|
||||
|
||||
class TwitterServerAccessTokenHandler(RequestHandler):
|
||||
def get(self):
|
||||
self.write('oauth_token=hjkl&oauth_token_secret=vbnm&screen_name=foo')
|
||||
|
||||
|
||||
class TwitterServerShowUserHandler(RequestHandler):
|
||||
def get(self, screen_name):
|
||||
if screen_name == 'error':
|
||||
raise HTTPError(500)
|
||||
assert 'oauth_nonce' in self.request.arguments
|
||||
assert 'oauth_timestamp' in self.request.arguments
|
||||
assert 'oauth_signature' in self.request.arguments
|
||||
assert self.get_argument('oauth_consumer_key') == 'test_twitter_consumer_key'
|
||||
assert self.get_argument('oauth_signature_method') == 'HMAC-SHA1'
|
||||
assert self.get_argument('oauth_version') == '1.0'
|
||||
assert self.get_argument('oauth_token') == 'hjkl'
|
||||
self.write(dict(screen_name=screen_name, name=screen_name.capitalize()))
|
||||
|
||||
|
||||
class TwitterServerVerifyCredentialsHandler(RequestHandler):
|
||||
def get(self):
|
||||
assert 'oauth_nonce' in self.request.arguments
|
||||
assert 'oauth_timestamp' in self.request.arguments
|
||||
assert 'oauth_signature' in self.request.arguments
|
||||
assert self.get_argument('oauth_consumer_key') == 'test_twitter_consumer_key'
|
||||
assert self.get_argument('oauth_signature_method') == 'HMAC-SHA1'
|
||||
assert self.get_argument('oauth_version') == '1.0'
|
||||
assert self.get_argument('oauth_token') == 'hjkl'
|
||||
self.write(dict(screen_name='foo', name='Foo'))
|
||||
|
||||
|
||||
class AuthTest(AsyncHTTPTestCase):
|
||||
def get_app(self):
|
||||
return Application(
|
||||
[
|
||||
# test endpoints
|
||||
('/legacy/openid/client/login', OpenIdClientLoginHandlerLegacy, dict(test=self)),
|
||||
('/openid/client/login', OpenIdClientLoginHandler, dict(test=self)),
|
||||
('/legacy/oauth10/client/login', OAuth1ClientLoginHandlerLegacy,
|
||||
dict(test=self, version='1.0')),
|
||||
('/oauth10/client/login', OAuth1ClientLoginHandler,
|
||||
dict(test=self, version='1.0')),
|
||||
('/oauth10/client/request_params',
|
||||
OAuth1ClientRequestParametersHandler,
|
||||
dict(version='1.0')),
|
||||
('/legacy/oauth10a/client/login', OAuth1ClientLoginHandlerLegacy,
|
||||
dict(test=self, version='1.0a')),
|
||||
('/oauth10a/client/login', OAuth1ClientLoginHandler,
|
||||
dict(test=self, version='1.0a')),
|
||||
('/oauth10a/client/login_coroutine',
|
||||
OAuth1ClientLoginCoroutineHandler,
|
||||
dict(test=self, version='1.0a')),
|
||||
('/oauth10a/client/request_params',
|
||||
OAuth1ClientRequestParametersHandler,
|
||||
dict(version='1.0a')),
|
||||
('/oauth2/client/login', OAuth2ClientLoginHandler, dict(test=self)),
|
||||
|
||||
('/facebook/client/login', FacebookClientLoginHandler, dict(test=self)),
|
||||
|
||||
('/legacy/twitter/client/login', TwitterClientLoginHandlerLegacy, dict(test=self)),
|
||||
('/twitter/client/login', TwitterClientLoginHandler, dict(test=self)),
|
||||
('/twitter/client/authenticate', TwitterClientAuthenticateHandler, dict(test=self)),
|
||||
('/twitter/client/login_gen_engine',
|
||||
TwitterClientLoginGenEngineHandler, dict(test=self)),
|
||||
('/twitter/client/login_gen_coroutine',
|
||||
TwitterClientLoginGenCoroutineHandler, dict(test=self)),
|
||||
('/legacy/twitter/client/show_user',
|
||||
TwitterClientShowUserHandlerLegacy, dict(test=self)),
|
||||
('/twitter/client/show_user',
|
||||
TwitterClientShowUserHandler, dict(test=self)),
|
||||
|
||||
# simulated servers
|
||||
('/openid/server/authenticate', OpenIdServerAuthenticateHandler),
|
||||
('/oauth1/server/request_token', OAuth1ServerRequestTokenHandler),
|
||||
('/oauth1/server/access_token', OAuth1ServerAccessTokenHandler),
|
||||
|
||||
('/facebook/server/access_token', FacebookServerAccessTokenHandler),
|
||||
('/facebook/server/me', FacebookServerMeHandler),
|
||||
('/twitter/server/access_token', TwitterServerAccessTokenHandler),
|
||||
(r'/twitter/api/users/show/(.*)\.json', TwitterServerShowUserHandler),
|
||||
(r'/twitter/api/account/verify_credentials\.json',
|
||||
TwitterServerVerifyCredentialsHandler),
|
||||
],
|
||||
http_client=self.http_client,
|
||||
twitter_consumer_key='test_twitter_consumer_key',
|
||||
twitter_consumer_secret='test_twitter_consumer_secret',
|
||||
facebook_api_key='test_facebook_api_key',
|
||||
facebook_secret='test_facebook_secret')
|
||||
|
||||
def test_openid_redirect_legacy(self):
|
||||
response = self.fetch('/legacy/openid/client/login', follow_redirects=False)
|
||||
self.assertEqual(response.code, 302)
|
||||
self.assertTrue(
|
||||
'/openid/server/authenticate?' in response.headers['Location'])
|
||||
|
||||
def test_openid_get_user_legacy(self):
|
||||
response = self.fetch('/legacy/openid/client/login?openid.mode=blah'
|
||||
'&openid.ns.ax=http://openid.net/srv/ax/1.0'
|
||||
'&openid.ax.type.email=http://axschema.org/contact/email'
|
||||
'&openid.ax.value.email=foo@example.com')
|
||||
response.rethrow()
|
||||
parsed = json_decode(response.body)
|
||||
self.assertEqual(parsed["email"], "foo@example.com")
|
||||
|
||||
def test_openid_redirect(self):
|
||||
response = self.fetch('/openid/client/login', follow_redirects=False)
|
||||
self.assertEqual(response.code, 302)
|
||||
self.assertTrue(
|
||||
'/openid/server/authenticate?' in response.headers['Location'])
|
||||
|
||||
def test_openid_get_user(self):
|
||||
response = self.fetch('/openid/client/login?openid.mode=blah'
|
||||
'&openid.ns.ax=http://openid.net/srv/ax/1.0'
|
||||
'&openid.ax.type.email=http://axschema.org/contact/email'
|
||||
'&openid.ax.value.email=foo@example.com')
|
||||
response.rethrow()
|
||||
parsed = json_decode(response.body)
|
||||
self.assertEqual(parsed["email"], "foo@example.com")
|
||||
|
||||
def test_oauth10_redirect_legacy(self):
|
||||
response = self.fetch('/legacy/oauth10/client/login', follow_redirects=False)
|
||||
self.assertEqual(response.code, 302)
|
||||
self.assertTrue(response.headers['Location'].endswith(
|
||||
'/oauth1/server/authorize?oauth_token=zxcv'))
|
||||
# the cookie is base64('zxcv')|base64('1234')
|
||||
self.assertTrue(
|
||||
'_oauth_request_token="enhjdg==|MTIzNA=="' in response.headers['Set-Cookie'],
|
||||
response.headers['Set-Cookie'])
|
||||
|
||||
def test_oauth10_redirect(self):
|
||||
response = self.fetch('/oauth10/client/login', follow_redirects=False)
|
||||
self.assertEqual(response.code, 302)
|
||||
self.assertTrue(response.headers['Location'].endswith(
|
||||
'/oauth1/server/authorize?oauth_token=zxcv'))
|
||||
# the cookie is base64('zxcv')|base64('1234')
|
||||
self.assertTrue(
|
||||
'_oauth_request_token="enhjdg==|MTIzNA=="' in response.headers['Set-Cookie'],
|
||||
response.headers['Set-Cookie'])
|
||||
|
||||
def test_oauth10_get_user_legacy(self):
|
||||
with ignore_deprecation():
|
||||
response = self.fetch(
|
||||
'/legacy/oauth10/client/login?oauth_token=zxcv',
|
||||
headers={'Cookie': '_oauth_request_token=enhjdg==|MTIzNA=='})
|
||||
response.rethrow()
|
||||
parsed = json_decode(response.body)
|
||||
self.assertEqual(parsed['email'], 'foo@example.com')
|
||||
self.assertEqual(parsed['access_token'], dict(key='uiop', secret='5678'))
|
||||
|
||||
def test_oauth10_get_user(self):
|
||||
response = self.fetch(
|
||||
'/oauth10/client/login?oauth_token=zxcv',
|
||||
headers={'Cookie': '_oauth_request_token=enhjdg==|MTIzNA=='})
|
||||
response.rethrow()
|
||||
parsed = json_decode(response.body)
|
||||
self.assertEqual(parsed['email'], 'foo@example.com')
|
||||
self.assertEqual(parsed['access_token'], dict(key='uiop', secret='5678'))
|
||||
|
||||
def test_oauth10_request_parameters(self):
|
||||
response = self.fetch('/oauth10/client/request_params')
|
||||
response.rethrow()
|
||||
parsed = json_decode(response.body)
|
||||
self.assertEqual(parsed['oauth_consumer_key'], 'asdf')
|
||||
self.assertEqual(parsed['oauth_token'], 'uiop')
|
||||
self.assertTrue('oauth_nonce' in parsed)
|
||||
self.assertTrue('oauth_signature' in parsed)
|
||||
|
||||
def test_oauth10a_redirect_legacy(self):
|
||||
response = self.fetch('/legacy/oauth10a/client/login', follow_redirects=False)
|
||||
self.assertEqual(response.code, 302)
|
||||
self.assertTrue(response.headers['Location'].endswith(
|
||||
'/oauth1/server/authorize?oauth_token=zxcv'))
|
||||
# the cookie is base64('zxcv')|base64('1234')
|
||||
self.assertTrue(
|
||||
'_oauth_request_token="enhjdg==|MTIzNA=="' in response.headers['Set-Cookie'],
|
||||
response.headers['Set-Cookie'])
|
||||
|
||||
def test_oauth10a_get_user_legacy(self):
|
||||
with ignore_deprecation():
|
||||
response = self.fetch(
|
||||
'/legacy/oauth10a/client/login?oauth_token=zxcv',
|
||||
headers={'Cookie': '_oauth_request_token=enhjdg==|MTIzNA=='})
|
||||
response.rethrow()
|
||||
parsed = json_decode(response.body)
|
||||
self.assertEqual(parsed['email'], 'foo@example.com')
|
||||
self.assertEqual(parsed['access_token'], dict(key='uiop', secret='5678'))
|
||||
|
||||
def test_oauth10a_redirect(self):
|
||||
response = self.fetch('/oauth10a/client/login', follow_redirects=False)
|
||||
self.assertEqual(response.code, 302)
|
||||
self.assertTrue(response.headers['Location'].endswith(
|
||||
'/oauth1/server/authorize?oauth_token=zxcv'))
|
||||
# the cookie is base64('zxcv')|base64('1234')
|
||||
self.assertTrue(
|
||||
'_oauth_request_token="enhjdg==|MTIzNA=="' in response.headers['Set-Cookie'],
|
||||
response.headers['Set-Cookie'])
|
||||
|
||||
@unittest.skipIf(mock is None, 'mock package not present')
|
||||
def test_oauth10a_redirect_error(self):
|
||||
with mock.patch.object(OAuth1ServerRequestTokenHandler, 'get') as get:
|
||||
get.side_effect = Exception("boom")
|
||||
with ExpectLog(app_log, "Uncaught exception"):
|
||||
response = self.fetch('/oauth10a/client/login', follow_redirects=False)
|
||||
self.assertEqual(response.code, 500)
|
||||
|
||||
def test_oauth10a_get_user(self):
|
||||
response = self.fetch(
|
||||
'/oauth10a/client/login?oauth_token=zxcv',
|
||||
headers={'Cookie': '_oauth_request_token=enhjdg==|MTIzNA=='})
|
||||
response.rethrow()
|
||||
parsed = json_decode(response.body)
|
||||
self.assertEqual(parsed['email'], 'foo@example.com')
|
||||
self.assertEqual(parsed['access_token'], dict(key='uiop', secret='5678'))
|
||||
|
||||
def test_oauth10a_request_parameters(self):
|
||||
response = self.fetch('/oauth10a/client/request_params')
|
||||
response.rethrow()
|
||||
parsed = json_decode(response.body)
|
||||
self.assertEqual(parsed['oauth_consumer_key'], 'asdf')
|
||||
self.assertEqual(parsed['oauth_token'], 'uiop')
|
||||
self.assertTrue('oauth_nonce' in parsed)
|
||||
self.assertTrue('oauth_signature' in parsed)
|
||||
|
||||
def test_oauth10a_get_user_coroutine_exception(self):
|
||||
response = self.fetch(
|
||||
'/oauth10a/client/login_coroutine?oauth_token=zxcv&fail_in_get_user=true',
|
||||
headers={'Cookie': '_oauth_request_token=enhjdg==|MTIzNA=='})
|
||||
self.assertEqual(response.code, 503)
|
||||
|
||||
def test_oauth2_redirect(self):
|
||||
response = self.fetch('/oauth2/client/login', follow_redirects=False)
|
||||
self.assertEqual(response.code, 302)
|
||||
self.assertTrue('/oauth2/server/authorize?' in response.headers['Location'])
|
||||
|
||||
def test_facebook_login(self):
|
||||
response = self.fetch('/facebook/client/login', follow_redirects=False)
|
||||
self.assertEqual(response.code, 302)
|
||||
self.assertTrue('/facebook/server/authorize?' in response.headers['Location'])
|
||||
response = self.fetch('/facebook/client/login?code=1234', follow_redirects=False)
|
||||
self.assertEqual(response.code, 200)
|
||||
user = json_decode(response.body)
|
||||
self.assertEqual(user['access_token'], 'asdf')
|
||||
self.assertEqual(user['session_expires'], '3600')
|
||||
|
||||
def base_twitter_redirect(self, url):
|
||||
# Same as test_oauth10a_redirect
|
||||
response = self.fetch(url, follow_redirects=False)
|
||||
self.assertEqual(response.code, 302)
|
||||
self.assertTrue(response.headers['Location'].endswith(
|
||||
'/oauth1/server/authorize?oauth_token=zxcv'))
|
||||
# the cookie is base64('zxcv')|base64('1234')
|
||||
self.assertTrue(
|
||||
'_oauth_request_token="enhjdg==|MTIzNA=="' in response.headers['Set-Cookie'],
|
||||
response.headers['Set-Cookie'])
|
||||
|
||||
def test_twitter_redirect_legacy(self):
|
||||
self.base_twitter_redirect('/legacy/twitter/client/login')
|
||||
|
||||
def test_twitter_redirect(self):
|
||||
self.base_twitter_redirect('/twitter/client/login')
|
||||
|
||||
def test_twitter_redirect_gen_engine(self):
|
||||
self.base_twitter_redirect('/twitter/client/login_gen_engine')
|
||||
|
||||
def test_twitter_redirect_gen_coroutine(self):
|
||||
self.base_twitter_redirect('/twitter/client/login_gen_coroutine')
|
||||
|
||||
def test_twitter_authenticate_redirect(self):
|
||||
response = self.fetch('/twitter/client/authenticate', follow_redirects=False)
|
||||
self.assertEqual(response.code, 302)
|
||||
self.assertTrue(response.headers['Location'].endswith(
|
||||
'/twitter/server/authenticate?oauth_token=zxcv'), response.headers['Location'])
|
||||
# the cookie is base64('zxcv')|base64('1234')
|
||||
self.assertTrue(
|
||||
'_oauth_request_token="enhjdg==|MTIzNA=="' in response.headers['Set-Cookie'],
|
||||
response.headers['Set-Cookie'])
|
||||
|
||||
def test_twitter_get_user(self):
|
||||
response = self.fetch(
|
||||
'/twitter/client/login?oauth_token=zxcv',
|
||||
headers={'Cookie': '_oauth_request_token=enhjdg==|MTIzNA=='})
|
||||
response.rethrow()
|
||||
parsed = json_decode(response.body)
|
||||
self.assertEqual(parsed,
|
||||
{u'access_token': {u'key': u'hjkl',
|
||||
u'screen_name': u'foo',
|
||||
u'secret': u'vbnm'},
|
||||
u'name': u'Foo',
|
||||
u'screen_name': u'foo',
|
||||
u'username': u'foo'})
|
||||
|
||||
def test_twitter_show_user_legacy(self):
|
||||
response = self.fetch('/legacy/twitter/client/show_user?name=somebody')
|
||||
response.rethrow()
|
||||
self.assertEqual(json_decode(response.body),
|
||||
{'name': 'Somebody', 'screen_name': 'somebody'})
|
||||
|
||||
def test_twitter_show_user_error_legacy(self):
|
||||
with ExpectLog(gen_log, 'Error response HTTP 500'):
|
||||
response = self.fetch('/legacy/twitter/client/show_user?name=error')
|
||||
self.assertEqual(response.code, 500)
|
||||
self.assertEqual(response.body, b'error from twitter request')
|
||||
|
||||
def test_twitter_show_user(self):
|
||||
response = self.fetch('/twitter/client/show_user?name=somebody')
|
||||
response.rethrow()
|
||||
self.assertEqual(json_decode(response.body),
|
||||
{'name': 'Somebody', 'screen_name': 'somebody'})
|
||||
|
||||
def test_twitter_show_user_error(self):
|
||||
response = self.fetch('/twitter/client/show_user?name=error')
|
||||
self.assertEqual(response.code, 500)
|
||||
self.assertEqual(response.body, b'error from twitter request')
|
||||
|
||||
|
||||
class GoogleLoginHandler(RequestHandler, GoogleOAuth2Mixin):
|
||||
def initialize(self, test):
|
||||
self.test = test
|
||||
self._OAUTH_REDIRECT_URI = test.get_url('/client/login')
|
||||
self._OAUTH_AUTHORIZE_URL = test.get_url('/google/oauth2/authorize')
|
||||
self._OAUTH_ACCESS_TOKEN_URL = test.get_url('/google/oauth2/token')
|
||||
|
||||
@gen.coroutine
|
||||
def get(self):
|
||||
code = self.get_argument('code', None)
|
||||
if code is not None:
|
||||
# retrieve authenticate google user
|
||||
access = yield self.get_authenticated_user(self._OAUTH_REDIRECT_URI,
|
||||
code)
|
||||
user = yield self.oauth2_request(
|
||||
self.test.get_url("/google/oauth2/userinfo"),
|
||||
access_token=access["access_token"])
|
||||
# return the user and access token as json
|
||||
user["access_token"] = access["access_token"]
|
||||
self.write(user)
|
||||
else:
|
||||
yield self.authorize_redirect(
|
||||
redirect_uri=self._OAUTH_REDIRECT_URI,
|
||||
client_id=self.settings['google_oauth']['key'],
|
||||
client_secret=self.settings['google_oauth']['secret'],
|
||||
scope=['profile', 'email'],
|
||||
response_type='code',
|
||||
extra_params={'prompt': 'select_account'})
|
||||
|
||||
|
||||
class GoogleOAuth2AuthorizeHandler(RequestHandler):
|
||||
def get(self):
|
||||
# issue a fake auth code and redirect to redirect_uri
|
||||
code = 'fake-authorization-code'
|
||||
self.redirect(url_concat(self.get_argument('redirect_uri'),
|
||||
dict(code=code)))
|
||||
|
||||
|
||||
class GoogleOAuth2TokenHandler(RequestHandler):
|
||||
def post(self):
|
||||
assert self.get_argument('code') == 'fake-authorization-code'
|
||||
# issue a fake token
|
||||
self.finish({
|
||||
'access_token': 'fake-access-token',
|
||||
'expires_in': 'never-expires'
|
||||
})
|
||||
|
||||
|
||||
class GoogleOAuth2UserinfoHandler(RequestHandler):
|
||||
def get(self):
|
||||
assert self.get_argument('access_token') == 'fake-access-token'
|
||||
# return a fake user
|
||||
self.finish({
|
||||
'name': 'Foo',
|
||||
'email': 'foo@example.com'
|
||||
})
|
||||
|
||||
|
||||
class GoogleOAuth2Test(AsyncHTTPTestCase):
|
||||
def get_app(self):
|
||||
return Application(
|
||||
[
|
||||
# test endpoints
|
||||
('/client/login', GoogleLoginHandler, dict(test=self)),
|
||||
|
||||
# simulated google authorization server endpoints
|
||||
('/google/oauth2/authorize', GoogleOAuth2AuthorizeHandler),
|
||||
('/google/oauth2/token', GoogleOAuth2TokenHandler),
|
||||
('/google/oauth2/userinfo', GoogleOAuth2UserinfoHandler),
|
||||
],
|
||||
google_oauth={
|
||||
"key": 'fake_google_client_id',
|
||||
"secret": 'fake_google_client_secret'
|
||||
})
|
||||
|
||||
def test_google_login(self):
|
||||
response = self.fetch('/client/login')
|
||||
self.assertDictEqual({
|
||||
u'name': u'Foo',
|
||||
u'email': u'foo@example.com',
|
||||
u'access_token': u'fake-access-token',
|
||||
}, json_decode(response.body))
|
||||
Executable
+114
@@ -0,0 +1,114 @@
|
||||
from __future__ import absolute_import, division, print_function
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
from subprocess import Popen
|
||||
import sys
|
||||
from tempfile import mkdtemp
|
||||
import time
|
||||
|
||||
from tornado.test.util import unittest
|
||||
|
||||
|
||||
class AutoreloadTest(unittest.TestCase):
|
||||
|
||||
def test_reload_module(self):
|
||||
main = """\
|
||||
import os
|
||||
import sys
|
||||
|
||||
from tornado import autoreload
|
||||
|
||||
# This import will fail if path is not set up correctly
|
||||
import testapp
|
||||
|
||||
print('Starting')
|
||||
if 'TESTAPP_STARTED' not in os.environ:
|
||||
os.environ['TESTAPP_STARTED'] = '1'
|
||||
sys.stdout.flush()
|
||||
autoreload._reload()
|
||||
"""
|
||||
|
||||
# Create temporary test application
|
||||
path = mkdtemp()
|
||||
self.addCleanup(shutil.rmtree, path)
|
||||
os.mkdir(os.path.join(path, 'testapp'))
|
||||
open(os.path.join(path, 'testapp/__init__.py'), 'w').close()
|
||||
with open(os.path.join(path, 'testapp/__main__.py'), 'w') as f:
|
||||
f.write(main)
|
||||
|
||||
# Make sure the tornado module under test is available to the test
|
||||
# application
|
||||
pythonpath = os.getcwd()
|
||||
if 'PYTHONPATH' in os.environ:
|
||||
pythonpath += os.pathsep + os.environ['PYTHONPATH']
|
||||
|
||||
p = Popen(
|
||||
[sys.executable, '-m', 'testapp'], stdout=subprocess.PIPE,
|
||||
cwd=path, env=dict(os.environ, PYTHONPATH=pythonpath),
|
||||
universal_newlines=True)
|
||||
out = p.communicate()[0]
|
||||
self.assertEqual(out, 'Starting\nStarting\n')
|
||||
|
||||
def test_reload_wrapper_preservation(self):
|
||||
# This test verifies that when `python -m tornado.autoreload`
|
||||
# is used on an application that also has an internal
|
||||
# autoreload, the reload wrapper is preserved on restart.
|
||||
main = """\
|
||||
import os
|
||||
import sys
|
||||
|
||||
# This import will fail if path is not set up correctly
|
||||
import testapp
|
||||
|
||||
if 'tornado.autoreload' not in sys.modules:
|
||||
raise Exception('started without autoreload wrapper')
|
||||
|
||||
import tornado.autoreload
|
||||
|
||||
print('Starting')
|
||||
sys.stdout.flush()
|
||||
if 'TESTAPP_STARTED' not in os.environ:
|
||||
os.environ['TESTAPP_STARTED'] = '1'
|
||||
# Simulate an internal autoreload (one not caused
|
||||
# by the wrapper).
|
||||
tornado.autoreload._reload()
|
||||
else:
|
||||
# Exit directly so autoreload doesn't catch it.
|
||||
os._exit(0)
|
||||
"""
|
||||
|
||||
# Create temporary test application
|
||||
path = mkdtemp()
|
||||
os.mkdir(os.path.join(path, 'testapp'))
|
||||
self.addCleanup(shutil.rmtree, path)
|
||||
init_file = os.path.join(path, 'testapp', '__init__.py')
|
||||
open(init_file, 'w').close()
|
||||
main_file = os.path.join(path, 'testapp', '__main__.py')
|
||||
with open(main_file, 'w') as f:
|
||||
f.write(main)
|
||||
|
||||
# Make sure the tornado module under test is available to the test
|
||||
# application
|
||||
pythonpath = os.getcwd()
|
||||
if 'PYTHONPATH' in os.environ:
|
||||
pythonpath += os.pathsep + os.environ['PYTHONPATH']
|
||||
|
||||
autoreload_proc = Popen(
|
||||
[sys.executable, '-m', 'tornado.autoreload', '-m', 'testapp'],
|
||||
stdout=subprocess.PIPE, cwd=path,
|
||||
env=dict(os.environ, PYTHONPATH=pythonpath),
|
||||
universal_newlines=True)
|
||||
|
||||
# This timeout needs to be fairly generous for pypy due to jit
|
||||
# warmup costs.
|
||||
for i in range(40):
|
||||
if autoreload_proc.poll() is not None:
|
||||
break
|
||||
time.sleep(0.1)
|
||||
else:
|
||||
autoreload_proc.kill()
|
||||
raise Exception("subprocess failed to terminate")
|
||||
|
||||
out = autoreload_proc.communicate()[0]
|
||||
self.assertEqual(out, 'Starting\n' * 2)
|
||||
Executable
+496
@@ -0,0 +1,496 @@
|
||||
#
|
||||
# Copyright 2012 Facebook
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import gc
|
||||
import logging
|
||||
import re
|
||||
import socket
|
||||
import sys
|
||||
import traceback
|
||||
import warnings
|
||||
|
||||
from tornado.concurrent import (Future, return_future, ReturnValueIgnoredError,
|
||||
run_on_executor, future_set_result_unless_cancelled)
|
||||
from tornado.escape import utf8, to_unicode
|
||||
from tornado import gen
|
||||
from tornado.ioloop import IOLoop
|
||||
from tornado.iostream import IOStream
|
||||
from tornado.log import app_log
|
||||
from tornado import stack_context
|
||||
from tornado.tcpserver import TCPServer
|
||||
from tornado.testing import AsyncTestCase, ExpectLog, bind_unused_port, gen_test
|
||||
from tornado.test.util import unittest, skipBefore35, exec_test, ignore_deprecation
|
||||
|
||||
|
||||
try:
|
||||
from concurrent import futures
|
||||
except ImportError:
|
||||
futures = None
|
||||
|
||||
|
||||
class MiscFutureTest(AsyncTestCase):
|
||||
|
||||
def test_future_set_result_unless_cancelled(self):
|
||||
fut = Future()
|
||||
future_set_result_unless_cancelled(fut, 42)
|
||||
self.assertEqual(fut.result(), 42)
|
||||
self.assertFalse(fut.cancelled())
|
||||
|
||||
fut = Future()
|
||||
fut.cancel()
|
||||
is_cancelled = fut.cancelled()
|
||||
future_set_result_unless_cancelled(fut, 42)
|
||||
self.assertEqual(fut.cancelled(), is_cancelled)
|
||||
if not is_cancelled:
|
||||
self.assertEqual(fut.result(), 42)
|
||||
|
||||
|
||||
class ReturnFutureTest(AsyncTestCase):
|
||||
with ignore_deprecation():
|
||||
@return_future
|
||||
def sync_future(self, callback):
|
||||
callback(42)
|
||||
|
||||
@return_future
|
||||
def async_future(self, callback):
|
||||
self.io_loop.add_callback(callback, 42)
|
||||
|
||||
@return_future
|
||||
def immediate_failure(self, callback):
|
||||
1 / 0
|
||||
|
||||
@return_future
|
||||
def delayed_failure(self, callback):
|
||||
self.io_loop.add_callback(lambda: 1 / 0)
|
||||
|
||||
@return_future
|
||||
def return_value(self, callback):
|
||||
# Note that the result of both running the callback and returning
|
||||
# a value (or raising an exception) is unspecified; with current
|
||||
# implementations the last event prior to callback resolution wins.
|
||||
return 42
|
||||
|
||||
@return_future
|
||||
def no_result_future(self, callback):
|
||||
callback()
|
||||
|
||||
def test_immediate_failure(self):
|
||||
with self.assertRaises(ZeroDivisionError):
|
||||
# The caller sees the error just like a normal function.
|
||||
self.immediate_failure(callback=self.stop)
|
||||
# The callback is not run because the function failed synchronously.
|
||||
self.io_loop.add_timeout(self.io_loop.time() + 0.05, self.stop)
|
||||
result = self.wait()
|
||||
self.assertIs(result, None)
|
||||
|
||||
def test_return_value(self):
|
||||
with self.assertRaises(ReturnValueIgnoredError):
|
||||
self.return_value(callback=self.stop)
|
||||
|
||||
def test_callback_kw(self):
|
||||
with ignore_deprecation():
|
||||
future = self.sync_future(callback=self.stop)
|
||||
result = self.wait()
|
||||
self.assertEqual(result, 42)
|
||||
self.assertEqual(future.result(), 42)
|
||||
|
||||
def test_callback_positional(self):
|
||||
# When the callback is passed in positionally, future_wrap shouldn't
|
||||
# add another callback in the kwargs.
|
||||
with ignore_deprecation():
|
||||
future = self.sync_future(self.stop)
|
||||
result = self.wait()
|
||||
self.assertEqual(result, 42)
|
||||
self.assertEqual(future.result(), 42)
|
||||
|
||||
def test_no_callback(self):
|
||||
future = self.sync_future()
|
||||
self.assertEqual(future.result(), 42)
|
||||
|
||||
def test_none_callback_kw(self):
|
||||
# explicitly pass None as callback
|
||||
future = self.sync_future(callback=None)
|
||||
self.assertEqual(future.result(), 42)
|
||||
|
||||
def test_none_callback_pos(self):
|
||||
future = self.sync_future(None)
|
||||
self.assertEqual(future.result(), 42)
|
||||
|
||||
def test_async_future(self):
|
||||
future = self.async_future()
|
||||
self.assertFalse(future.done())
|
||||
self.io_loop.add_future(future, self.stop)
|
||||
future2 = self.wait()
|
||||
self.assertIs(future, future2)
|
||||
self.assertEqual(future.result(), 42)
|
||||
|
||||
@gen_test
|
||||
def test_async_future_gen(self):
|
||||
result = yield self.async_future()
|
||||
self.assertEqual(result, 42)
|
||||
|
||||
def test_delayed_failure(self):
|
||||
future = self.delayed_failure()
|
||||
with ignore_deprecation():
|
||||
self.io_loop.add_future(future, self.stop)
|
||||
future2 = self.wait()
|
||||
self.assertIs(future, future2)
|
||||
with self.assertRaises(ZeroDivisionError):
|
||||
future.result()
|
||||
|
||||
def test_kw_only_callback(self):
|
||||
with ignore_deprecation():
|
||||
@return_future
|
||||
def f(**kwargs):
|
||||
kwargs['callback'](42)
|
||||
future = f()
|
||||
self.assertEqual(future.result(), 42)
|
||||
|
||||
def test_error_in_callback(self):
|
||||
with ignore_deprecation():
|
||||
self.sync_future(callback=lambda future: 1 / 0)
|
||||
# The exception gets caught by our StackContext and will be re-raised
|
||||
# when we wait.
|
||||
self.assertRaises(ZeroDivisionError, self.wait)
|
||||
|
||||
def test_no_result_future(self):
|
||||
with ignore_deprecation():
|
||||
future = self.no_result_future(self.stop)
|
||||
result = self.wait()
|
||||
self.assertIs(result, None)
|
||||
# result of this future is undefined, but not an error
|
||||
future.result()
|
||||
|
||||
def test_no_result_future_callback(self):
|
||||
with ignore_deprecation():
|
||||
future = self.no_result_future(callback=lambda: self.stop())
|
||||
result = self.wait()
|
||||
self.assertIs(result, None)
|
||||
future.result()
|
||||
|
||||
@gen_test
|
||||
def test_future_traceback_legacy(self):
|
||||
with ignore_deprecation():
|
||||
@return_future
|
||||
@gen.engine
|
||||
def f(callback):
|
||||
yield gen.Task(self.io_loop.add_callback)
|
||||
try:
|
||||
1 / 0
|
||||
except ZeroDivisionError:
|
||||
self.expected_frame = traceback.extract_tb(
|
||||
sys.exc_info()[2], limit=1)[0]
|
||||
raise
|
||||
try:
|
||||
yield f()
|
||||
self.fail("didn't get expected exception")
|
||||
except ZeroDivisionError:
|
||||
tb = traceback.extract_tb(sys.exc_info()[2])
|
||||
self.assertIn(self.expected_frame, tb)
|
||||
|
||||
@gen_test
|
||||
def test_future_traceback(self):
|
||||
@gen.coroutine
|
||||
def f():
|
||||
yield gen.moment
|
||||
try:
|
||||
1 / 0
|
||||
except ZeroDivisionError:
|
||||
self.expected_frame = traceback.extract_tb(
|
||||
sys.exc_info()[2], limit=1)[0]
|
||||
raise
|
||||
try:
|
||||
yield f()
|
||||
self.fail("didn't get expected exception")
|
||||
except ZeroDivisionError:
|
||||
tb = traceback.extract_tb(sys.exc_info()[2])
|
||||
self.assertIn(self.expected_frame, tb)
|
||||
|
||||
@gen_test
|
||||
def test_uncaught_exception_log(self):
|
||||
if IOLoop.configured_class().__name__.endswith('AsyncIOLoop'):
|
||||
# Install an exception handler that mirrors our
|
||||
# non-asyncio logging behavior.
|
||||
def exc_handler(loop, context):
|
||||
app_log.error('%s: %s', context['message'],
|
||||
type(context.get('exception')))
|
||||
self.io_loop.asyncio_loop.set_exception_handler(exc_handler)
|
||||
|
||||
@gen.coroutine
|
||||
def f():
|
||||
yield gen.moment
|
||||
1 / 0
|
||||
|
||||
g = f()
|
||||
|
||||
with ExpectLog(app_log,
|
||||
"(?s)Future.* exception was never retrieved:"
|
||||
".*ZeroDivisionError"):
|
||||
yield gen.moment
|
||||
yield gen.moment
|
||||
# For some reason, TwistedIOLoop and pypy3 need a third iteration
|
||||
# in order to drain references to the future
|
||||
yield gen.moment
|
||||
del g
|
||||
gc.collect() # for PyPy
|
||||
|
||||
|
||||
# The following series of classes demonstrate and test various styles
|
||||
# of use, with and without generators and futures.
|
||||
|
||||
|
||||
class CapServer(TCPServer):
|
||||
@gen.coroutine
|
||||
def handle_stream(self, stream, address):
|
||||
data = yield stream.read_until(b"\n")
|
||||
data = to_unicode(data)
|
||||
if data == data.upper():
|
||||
stream.write(b"error\talready capitalized\n")
|
||||
else:
|
||||
# data already has \n
|
||||
stream.write(utf8("ok\t%s" % data.upper()))
|
||||
stream.close()
|
||||
|
||||
|
||||
class CapError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class BaseCapClient(object):
|
||||
def __init__(self, port):
|
||||
self.port = port
|
||||
|
||||
def process_response(self, data):
|
||||
status, message = re.match('(.*)\t(.*)\n', to_unicode(data)).groups()
|
||||
if status == 'ok':
|
||||
return message
|
||||
else:
|
||||
raise CapError(message)
|
||||
|
||||
|
||||
class ManualCapClient(BaseCapClient):
|
||||
def capitalize(self, request_data, callback=None):
|
||||
logging.debug("capitalize")
|
||||
self.request_data = request_data
|
||||
self.stream = IOStream(socket.socket())
|
||||
self.stream.connect(('127.0.0.1', self.port),
|
||||
callback=self.handle_connect)
|
||||
self.future = Future()
|
||||
if callback is not None:
|
||||
self.future.add_done_callback(
|
||||
stack_context.wrap(lambda future: callback(future.result())))
|
||||
return self.future
|
||||
|
||||
def handle_connect(self):
|
||||
logging.debug("handle_connect")
|
||||
self.stream.write(utf8(self.request_data + "\n"))
|
||||
self.stream.read_until(b'\n', callback=self.handle_read)
|
||||
|
||||
def handle_read(self, data):
|
||||
logging.debug("handle_read")
|
||||
self.stream.close()
|
||||
try:
|
||||
self.future.set_result(self.process_response(data))
|
||||
except CapError as e:
|
||||
self.future.set_exception(e)
|
||||
|
||||
|
||||
class DecoratorCapClient(BaseCapClient):
|
||||
with ignore_deprecation():
|
||||
@return_future
|
||||
def capitalize(self, request_data, callback):
|
||||
logging.debug("capitalize")
|
||||
self.request_data = request_data
|
||||
self.stream = IOStream(socket.socket())
|
||||
self.stream.connect(('127.0.0.1', self.port),
|
||||
callback=self.handle_connect)
|
||||
self.callback = callback
|
||||
|
||||
def handle_connect(self):
|
||||
logging.debug("handle_connect")
|
||||
self.stream.write(utf8(self.request_data + "\n"))
|
||||
self.stream.read_until(b'\n', callback=self.handle_read)
|
||||
|
||||
def handle_read(self, data):
|
||||
logging.debug("handle_read")
|
||||
self.stream.close()
|
||||
self.callback(self.process_response(data))
|
||||
|
||||
|
||||
class GeneratorCapClient(BaseCapClient):
|
||||
@gen.coroutine
|
||||
def capitalize(self, request_data):
|
||||
logging.debug('capitalize')
|
||||
stream = IOStream(socket.socket())
|
||||
logging.debug('connecting')
|
||||
yield stream.connect(('127.0.0.1', self.port))
|
||||
stream.write(utf8(request_data + '\n'))
|
||||
logging.debug('reading')
|
||||
data = yield stream.read_until(b'\n')
|
||||
logging.debug('returning')
|
||||
stream.close()
|
||||
raise gen.Return(self.process_response(data))
|
||||
|
||||
|
||||
class ClientTestMixin(object):
|
||||
def setUp(self):
|
||||
super(ClientTestMixin, self).setUp() # type: ignore
|
||||
self.server = CapServer()
|
||||
sock, port = bind_unused_port()
|
||||
self.server.add_sockets([sock])
|
||||
self.client = self.client_class(port=port)
|
||||
|
||||
def tearDown(self):
|
||||
self.server.stop()
|
||||
super(ClientTestMixin, self).tearDown() # type: ignore
|
||||
|
||||
def test_callback(self):
|
||||
with ignore_deprecation():
|
||||
self.client.capitalize("hello", callback=self.stop)
|
||||
result = self.wait()
|
||||
self.assertEqual(result, "HELLO")
|
||||
|
||||
def test_callback_error(self):
|
||||
with ignore_deprecation():
|
||||
self.client.capitalize("HELLO", callback=self.stop)
|
||||
self.assertRaisesRegexp(CapError, "already capitalized", self.wait)
|
||||
|
||||
def test_future(self):
|
||||
future = self.client.capitalize("hello")
|
||||
self.io_loop.add_future(future, self.stop)
|
||||
self.wait()
|
||||
self.assertEqual(future.result(), "HELLO")
|
||||
|
||||
def test_future_error(self):
|
||||
future = self.client.capitalize("HELLO")
|
||||
self.io_loop.add_future(future, self.stop)
|
||||
self.wait()
|
||||
self.assertRaisesRegexp(CapError, "already capitalized", future.result)
|
||||
|
||||
def test_generator(self):
|
||||
@gen.coroutine
|
||||
def f():
|
||||
result = yield self.client.capitalize("hello")
|
||||
self.assertEqual(result, "HELLO")
|
||||
self.io_loop.run_sync(f)
|
||||
|
||||
def test_generator_error(self):
|
||||
@gen.coroutine
|
||||
def f():
|
||||
with self.assertRaisesRegexp(CapError, "already capitalized"):
|
||||
yield self.client.capitalize("HELLO")
|
||||
self.io_loop.run_sync(f)
|
||||
|
||||
|
||||
class ManualClientTest(ClientTestMixin, AsyncTestCase):
|
||||
client_class = ManualCapClient
|
||||
|
||||
def setUp(self):
|
||||
self.warning_catcher = warnings.catch_warnings()
|
||||
self.warning_catcher.__enter__()
|
||||
warnings.simplefilter('ignore', DeprecationWarning)
|
||||
super(ManualClientTest, self).setUp()
|
||||
|
||||
def tearDown(self):
|
||||
super(ManualClientTest, self).tearDown()
|
||||
self.warning_catcher.__exit__(None, None, None)
|
||||
|
||||
|
||||
class DecoratorClientTest(ClientTestMixin, AsyncTestCase):
|
||||
client_class = DecoratorCapClient
|
||||
|
||||
def setUp(self):
|
||||
self.warning_catcher = warnings.catch_warnings()
|
||||
self.warning_catcher.__enter__()
|
||||
warnings.simplefilter('ignore', DeprecationWarning)
|
||||
super(DecoratorClientTest, self).setUp()
|
||||
|
||||
def tearDown(self):
|
||||
super(DecoratorClientTest, self).tearDown()
|
||||
self.warning_catcher.__exit__(None, None, None)
|
||||
|
||||
|
||||
class GeneratorClientTest(ClientTestMixin, AsyncTestCase):
|
||||
client_class = GeneratorCapClient
|
||||
|
||||
|
||||
@unittest.skipIf(futures is None, "concurrent.futures module not present")
|
||||
class RunOnExecutorTest(AsyncTestCase):
|
||||
@gen_test
|
||||
def test_no_calling(self):
|
||||
class Object(object):
|
||||
def __init__(self):
|
||||
self.executor = futures.thread.ThreadPoolExecutor(1)
|
||||
|
||||
@run_on_executor
|
||||
def f(self):
|
||||
return 42
|
||||
|
||||
o = Object()
|
||||
answer = yield o.f()
|
||||
self.assertEqual(answer, 42)
|
||||
|
||||
@gen_test
|
||||
def test_call_with_no_args(self):
|
||||
class Object(object):
|
||||
def __init__(self):
|
||||
self.executor = futures.thread.ThreadPoolExecutor(1)
|
||||
|
||||
@run_on_executor()
|
||||
def f(self):
|
||||
return 42
|
||||
|
||||
o = Object()
|
||||
answer = yield o.f()
|
||||
self.assertEqual(answer, 42)
|
||||
|
||||
@gen_test
|
||||
def test_call_with_executor(self):
|
||||
class Object(object):
|
||||
def __init__(self):
|
||||
self.__executor = futures.thread.ThreadPoolExecutor(1)
|
||||
|
||||
@run_on_executor(executor='_Object__executor')
|
||||
def f(self):
|
||||
return 42
|
||||
|
||||
o = Object()
|
||||
answer = yield o.f()
|
||||
self.assertEqual(answer, 42)
|
||||
|
||||
@skipBefore35
|
||||
@gen_test
|
||||
def test_async_await(self):
|
||||
class Object(object):
|
||||
def __init__(self):
|
||||
self.executor = futures.thread.ThreadPoolExecutor(1)
|
||||
|
||||
@run_on_executor()
|
||||
def f(self):
|
||||
return 42
|
||||
|
||||
o = Object()
|
||||
namespace = exec_test(globals(), locals(), """
|
||||
async def f():
|
||||
answer = await o.f()
|
||||
return answer
|
||||
""")
|
||||
result = yield namespace['f']()
|
||||
self.assertEqual(result, 42)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Executable
+1
@@ -0,0 +1 @@
|
||||
"school","école"
|
||||
|
Executable
+153
@@ -0,0 +1,153 @@
|
||||
# coding: utf-8
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
from hashlib import md5
|
||||
|
||||
from tornado.escape import utf8
|
||||
from tornado.httpclient import HTTPRequest, HTTPClientError
|
||||
from tornado.locks import Event
|
||||
from tornado.stack_context import ExceptionStackContext
|
||||
from tornado.testing import AsyncHTTPTestCase, gen_test
|
||||
from tornado.test import httpclient_test
|
||||
from tornado.test.util import unittest, ignore_deprecation
|
||||
from tornado.web import Application, RequestHandler
|
||||
|
||||
|
||||
try:
|
||||
import pycurl # type: ignore
|
||||
except ImportError:
|
||||
pycurl = None
|
||||
|
||||
if pycurl is not None:
|
||||
from tornado.curl_httpclient import CurlAsyncHTTPClient
|
||||
|
||||
|
||||
@unittest.skipIf(pycurl is None, "pycurl module not present")
|
||||
class CurlHTTPClientCommonTestCase(httpclient_test.HTTPClientCommonTestCase):
|
||||
def get_http_client(self):
|
||||
client = CurlAsyncHTTPClient(defaults=dict(allow_ipv6=False))
|
||||
# make sure AsyncHTTPClient magic doesn't give us the wrong class
|
||||
self.assertTrue(isinstance(client, CurlAsyncHTTPClient))
|
||||
return client
|
||||
|
||||
|
||||
class DigestAuthHandler(RequestHandler):
|
||||
def initialize(self, username, password):
|
||||
self.username = username
|
||||
self.password = password
|
||||
|
||||
def get(self):
|
||||
realm = 'test'
|
||||
opaque = 'asdf'
|
||||
# Real implementations would use a random nonce.
|
||||
nonce = "1234"
|
||||
|
||||
auth_header = self.request.headers.get('Authorization', None)
|
||||
if auth_header is not None:
|
||||
auth_mode, params = auth_header.split(' ', 1)
|
||||
assert auth_mode == 'Digest'
|
||||
param_dict = {}
|
||||
for pair in params.split(','):
|
||||
k, v = pair.strip().split('=', 1)
|
||||
if v[0] == '"' and v[-1] == '"':
|
||||
v = v[1:-1]
|
||||
param_dict[k] = v
|
||||
assert param_dict['realm'] == realm
|
||||
assert param_dict['opaque'] == opaque
|
||||
assert param_dict['nonce'] == nonce
|
||||
assert param_dict['username'] == self.username
|
||||
assert param_dict['uri'] == self.request.path
|
||||
h1 = md5(utf8('%s:%s:%s' % (self.username, realm, self.password))).hexdigest()
|
||||
h2 = md5(utf8('%s:%s' % (self.request.method,
|
||||
self.request.path))).hexdigest()
|
||||
digest = md5(utf8('%s:%s:%s' % (h1, nonce, h2))).hexdigest()
|
||||
if digest == param_dict['response']:
|
||||
self.write('ok')
|
||||
else:
|
||||
self.write('fail')
|
||||
else:
|
||||
self.set_status(401)
|
||||
self.set_header('WWW-Authenticate',
|
||||
'Digest realm="%s", nonce="%s", opaque="%s"' %
|
||||
(realm, nonce, opaque))
|
||||
|
||||
|
||||
class CustomReasonHandler(RequestHandler):
|
||||
def get(self):
|
||||
self.set_status(200, "Custom reason")
|
||||
|
||||
|
||||
class CustomFailReasonHandler(RequestHandler):
|
||||
def get(self):
|
||||
self.set_status(400, "Custom reason")
|
||||
|
||||
|
||||
@unittest.skipIf(pycurl is None, "pycurl module not present")
|
||||
class CurlHTTPClientTestCase(AsyncHTTPTestCase):
|
||||
def setUp(self):
|
||||
super(CurlHTTPClientTestCase, self).setUp()
|
||||
self.http_client = self.create_client()
|
||||
|
||||
def get_app(self):
|
||||
return Application([
|
||||
('/digest', DigestAuthHandler, {'username': 'foo', 'password': 'bar'}),
|
||||
('/digest_non_ascii', DigestAuthHandler, {'username': 'foo', 'password': 'barユ£'}),
|
||||
('/custom_reason', CustomReasonHandler),
|
||||
('/custom_fail_reason', CustomFailReasonHandler),
|
||||
])
|
||||
|
||||
def create_client(self, **kwargs):
|
||||
return CurlAsyncHTTPClient(force_instance=True,
|
||||
defaults=dict(allow_ipv6=False),
|
||||
**kwargs)
|
||||
|
||||
@gen_test
|
||||
def test_prepare_curl_callback_stack_context(self):
|
||||
exc_info = []
|
||||
error_event = Event()
|
||||
|
||||
def error_handler(typ, value, tb):
|
||||
exc_info.append((typ, value, tb))
|
||||
error_event.set()
|
||||
return True
|
||||
|
||||
with ignore_deprecation():
|
||||
with ExceptionStackContext(error_handler):
|
||||
request = HTTPRequest(self.get_url('/custom_reason'),
|
||||
prepare_curl_callback=lambda curl: 1 / 0)
|
||||
yield [error_event.wait(), self.http_client.fetch(request)]
|
||||
self.assertEqual(1, len(exc_info))
|
||||
self.assertIs(exc_info[0][0], ZeroDivisionError)
|
||||
|
||||
def test_digest_auth(self):
|
||||
response = self.fetch('/digest', auth_mode='digest',
|
||||
auth_username='foo', auth_password='bar')
|
||||
self.assertEqual(response.body, b'ok')
|
||||
|
||||
def test_custom_reason(self):
|
||||
response = self.fetch('/custom_reason')
|
||||
self.assertEqual(response.reason, "Custom reason")
|
||||
|
||||
def test_fail_custom_reason(self):
|
||||
response = self.fetch('/custom_fail_reason')
|
||||
self.assertEqual(str(response.error), "HTTP 400: Custom reason")
|
||||
|
||||
def test_failed_setup(self):
|
||||
self.http_client = self.create_client(max_clients=1)
|
||||
for i in range(5):
|
||||
with ignore_deprecation():
|
||||
response = self.fetch(u'/ユニコード')
|
||||
self.assertIsNot(response.error, None)
|
||||
|
||||
with self.assertRaises((UnicodeEncodeError, HTTPClientError)):
|
||||
# This raises UnicodeDecodeError on py3 and
|
||||
# HTTPClientError(404) on py2. The main motivation of
|
||||
# this test is to ensure that the UnicodeEncodeError
|
||||
# during the setup phase doesn't lead the request to
|
||||
# be dropped on the floor.
|
||||
response = self.fetch(u'/ユニコード', raise_error=True)
|
||||
|
||||
def test_digest_auth_non_ascii(self):
|
||||
response = self.fetch('/digest_non_ascii', auth_mode='digest',
|
||||
auth_username='foo', auth_password='barユ£')
|
||||
self.assertEqual(response.body, b'ok')
|
||||
Executable
+250
@@ -0,0 +1,250 @@
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import tornado.escape
|
||||
from tornado.escape import (
|
||||
utf8, xhtml_escape, xhtml_unescape, url_escape, url_unescape,
|
||||
to_unicode, json_decode, json_encode, squeeze, recursive_unicode,
|
||||
)
|
||||
from tornado.util import unicode_type
|
||||
from tornado.test.util import unittest
|
||||
|
||||
linkify_tests = [
|
||||
# (input, linkify_kwargs, expected_output)
|
||||
|
||||
("hello http://world.com/!", {},
|
||||
u'hello <a href="http://world.com/">http://world.com/</a>!'),
|
||||
|
||||
("hello http://world.com/with?param=true&stuff=yes", {},
|
||||
u'hello <a href="http://world.com/with?param=true&stuff=yes">http://world.com/with?param=true&stuff=yes</a>'), # noqa: E501
|
||||
|
||||
# an opened paren followed by many chars killed Gruber's regex
|
||||
("http://url.com/w(aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", {},
|
||||
u'<a href="http://url.com/w">http://url.com/w</a>(aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa'), # noqa: E501
|
||||
|
||||
# as did too many dots at the end
|
||||
("http://url.com/withmany.......................................", {},
|
||||
u'<a href="http://url.com/withmany">http://url.com/withmany</a>.......................................'), # noqa: E501
|
||||
|
||||
("http://url.com/withmany((((((((((((((((((((((((((((((((((a)", {},
|
||||
u'<a href="http://url.com/withmany">http://url.com/withmany</a>((((((((((((((((((((((((((((((((((a)'), # noqa: E501
|
||||
|
||||
# some examples from http://daringfireball.net/2009/11/liberal_regex_for_matching_urls
|
||||
# plus a fex extras (such as multiple parentheses).
|
||||
("http://foo.com/blah_blah", {},
|
||||
u'<a href="http://foo.com/blah_blah">http://foo.com/blah_blah</a>'),
|
||||
|
||||
("http://foo.com/blah_blah/", {},
|
||||
u'<a href="http://foo.com/blah_blah/">http://foo.com/blah_blah/</a>'),
|
||||
|
||||
("(Something like http://foo.com/blah_blah)", {},
|
||||
u'(Something like <a href="http://foo.com/blah_blah">http://foo.com/blah_blah</a>)'),
|
||||
|
||||
("http://foo.com/blah_blah_(wikipedia)", {},
|
||||
u'<a href="http://foo.com/blah_blah_(wikipedia)">http://foo.com/blah_blah_(wikipedia)</a>'),
|
||||
|
||||
("http://foo.com/blah_(blah)_(wikipedia)_blah", {},
|
||||
u'<a href="http://foo.com/blah_(blah)_(wikipedia)_blah">http://foo.com/blah_(blah)_(wikipedia)_blah</a>'), # noqa: E501
|
||||
|
||||
("(Something like http://foo.com/blah_blah_(wikipedia))", {},
|
||||
u'(Something like <a href="http://foo.com/blah_blah_(wikipedia)">http://foo.com/blah_blah_(wikipedia)</a>)'), # noqa: E501
|
||||
|
||||
("http://foo.com/blah_blah.", {},
|
||||
u'<a href="http://foo.com/blah_blah">http://foo.com/blah_blah</a>.'),
|
||||
|
||||
("http://foo.com/blah_blah/.", {},
|
||||
u'<a href="http://foo.com/blah_blah/">http://foo.com/blah_blah/</a>.'),
|
||||
|
||||
("<http://foo.com/blah_blah>", {},
|
||||
u'<<a href="http://foo.com/blah_blah">http://foo.com/blah_blah</a>>'),
|
||||
|
||||
("<http://foo.com/blah_blah/>", {},
|
||||
u'<<a href="http://foo.com/blah_blah/">http://foo.com/blah_blah/</a>>'),
|
||||
|
||||
("http://foo.com/blah_blah,", {},
|
||||
u'<a href="http://foo.com/blah_blah">http://foo.com/blah_blah</a>,'),
|
||||
|
||||
("http://www.example.com/wpstyle/?p=364.", {},
|
||||
u'<a href="http://www.example.com/wpstyle/?p=364">http://www.example.com/wpstyle/?p=364</a>.'),
|
||||
|
||||
("rdar://1234",
|
||||
{"permitted_protocols": ["http", "rdar"]},
|
||||
u'<a href="rdar://1234">rdar://1234</a>'),
|
||||
|
||||
("rdar:/1234",
|
||||
{"permitted_protocols": ["rdar"]},
|
||||
u'<a href="rdar:/1234">rdar:/1234</a>'),
|
||||
|
||||
("http://userid:password@example.com:8080", {},
|
||||
u'<a href="http://userid:password@example.com:8080">http://userid:password@example.com:8080</a>'), # noqa: E501
|
||||
|
||||
("http://userid@example.com", {},
|
||||
u'<a href="http://userid@example.com">http://userid@example.com</a>'),
|
||||
|
||||
("http://userid@example.com:8080", {},
|
||||
u'<a href="http://userid@example.com:8080">http://userid@example.com:8080</a>'),
|
||||
|
||||
("http://userid:password@example.com", {},
|
||||
u'<a href="http://userid:password@example.com">http://userid:password@example.com</a>'),
|
||||
|
||||
("message://%3c330e7f8409726r6a4ba78dkf1fd71420c1bf6ff@mail.gmail.com%3e",
|
||||
{"permitted_protocols": ["http", "message"]},
|
||||
u'<a href="message://%3c330e7f8409726r6a4ba78dkf1fd71420c1bf6ff@mail.gmail.com%3e">'
|
||||
u'message://%3c330e7f8409726r6a4ba78dkf1fd71420c1bf6ff@mail.gmail.com%3e</a>'),
|
||||
|
||||
(u"http://\u27a1.ws/\u4a39", {},
|
||||
u'<a href="http://\u27a1.ws/\u4a39">http://\u27a1.ws/\u4a39</a>'),
|
||||
|
||||
("<tag>http://example.com</tag>", {},
|
||||
u'<tag><a href="http://example.com">http://example.com</a></tag>'),
|
||||
|
||||
("Just a www.example.com link.", {},
|
||||
u'Just a <a href="http://www.example.com">www.example.com</a> link.'),
|
||||
|
||||
("Just a www.example.com link.",
|
||||
{"require_protocol": True},
|
||||
u'Just a www.example.com link.'),
|
||||
|
||||
("A http://reallylong.com/link/that/exceedsthelenglimit.html",
|
||||
{"require_protocol": True, "shorten": True},
|
||||
u'A <a href="http://reallylong.com/link/that/exceedsthelenglimit.html"'
|
||||
u' title="http://reallylong.com/link/that/exceedsthelenglimit.html">http://reallylong.com/link...</a>'), # noqa: E501
|
||||
|
||||
("A http://reallylongdomainnamethatwillbetoolong.com/hi!",
|
||||
{"shorten": True},
|
||||
u'A <a href="http://reallylongdomainnamethatwillbetoolong.com/hi"'
|
||||
u' title="http://reallylongdomainnamethatwillbetoolong.com/hi">http://reallylongdomainnametha...</a>!'), # noqa: E501
|
||||
|
||||
("A file:///passwords.txt and http://web.com link", {},
|
||||
u'A file:///passwords.txt and <a href="http://web.com">http://web.com</a> link'),
|
||||
|
||||
("A file:///passwords.txt and http://web.com link",
|
||||
{"permitted_protocols": ["file"]},
|
||||
u'A <a href="file:///passwords.txt">file:///passwords.txt</a> and http://web.com link'),
|
||||
|
||||
("www.external-link.com",
|
||||
{"extra_params": 'rel="nofollow" class="external"'},
|
||||
u'<a href="http://www.external-link.com" rel="nofollow" class="external">www.external-link.com</a>'), # noqa: E501
|
||||
|
||||
("www.external-link.com and www.internal-link.com/blogs extra",
|
||||
{"extra_params": lambda href: 'class="internal"' if href.startswith("http://www.internal-link.com") else 'rel="nofollow" class="external"'}, # noqa: E501
|
||||
u'<a href="http://www.external-link.com" rel="nofollow" class="external">www.external-link.com</a>' # noqa: E501
|
||||
u' and <a href="http://www.internal-link.com/blogs" class="internal">www.internal-link.com/blogs</a> extra'), # noqa: E501
|
||||
|
||||
("www.external-link.com",
|
||||
{"extra_params": lambda href: ' rel="nofollow" class="external" '},
|
||||
u'<a href="http://www.external-link.com" rel="nofollow" class="external">www.external-link.com</a>'), # noqa: E501
|
||||
]
|
||||
|
||||
|
||||
class EscapeTestCase(unittest.TestCase):
|
||||
def test_linkify(self):
|
||||
for text, kwargs, html in linkify_tests:
|
||||
linked = tornado.escape.linkify(text, **kwargs)
|
||||
self.assertEqual(linked, html)
|
||||
|
||||
def test_xhtml_escape(self):
|
||||
tests = [
|
||||
("<foo>", "<foo>"),
|
||||
(u"<foo>", u"<foo>"),
|
||||
(b"<foo>", b"<foo>"),
|
||||
|
||||
("<>&\"'", "<>&"'"),
|
||||
("&", "&amp;"),
|
||||
|
||||
(u"<\u00e9>", u"<\u00e9>"),
|
||||
(b"<\xc3\xa9>", b"<\xc3\xa9>"),
|
||||
]
|
||||
for unescaped, escaped in tests:
|
||||
self.assertEqual(utf8(xhtml_escape(unescaped)), utf8(escaped))
|
||||
self.assertEqual(utf8(unescaped), utf8(xhtml_unescape(escaped)))
|
||||
|
||||
def test_xhtml_unescape_numeric(self):
|
||||
tests = [
|
||||
('foo bar', 'foo bar'),
|
||||
('foo bar', 'foo bar'),
|
||||
('foo bar', 'foo bar'),
|
||||
('foo઼bar', u'foo\u0abcbar'),
|
||||
('foo&#xyz;bar', 'foo&#xyz;bar'), # invalid encoding
|
||||
('foo&#;bar', 'foo&#;bar'), # invalid encoding
|
||||
('foo&#x;bar', 'foo&#x;bar'), # invalid encoding
|
||||
]
|
||||
for escaped, unescaped in tests:
|
||||
self.assertEqual(unescaped, xhtml_unescape(escaped))
|
||||
|
||||
def test_url_escape_unicode(self):
|
||||
tests = [
|
||||
# byte strings are passed through as-is
|
||||
(u'\u00e9'.encode('utf8'), '%C3%A9'),
|
||||
(u'\u00e9'.encode('latin1'), '%E9'),
|
||||
|
||||
# unicode strings become utf8
|
||||
(u'\u00e9', '%C3%A9'),
|
||||
]
|
||||
for unescaped, escaped in tests:
|
||||
self.assertEqual(url_escape(unescaped), escaped)
|
||||
|
||||
def test_url_unescape_unicode(self):
|
||||
tests = [
|
||||
('%C3%A9', u'\u00e9', 'utf8'),
|
||||
('%C3%A9', u'\u00c3\u00a9', 'latin1'),
|
||||
('%C3%A9', utf8(u'\u00e9'), None),
|
||||
]
|
||||
for escaped, unescaped, encoding in tests:
|
||||
# input strings to url_unescape should only contain ascii
|
||||
# characters, but make sure the function accepts both byte
|
||||
# and unicode strings.
|
||||
self.assertEqual(url_unescape(to_unicode(escaped), encoding), unescaped)
|
||||
self.assertEqual(url_unescape(utf8(escaped), encoding), unescaped)
|
||||
|
||||
def test_url_escape_quote_plus(self):
|
||||
unescaped = '+ #%'
|
||||
plus_escaped = '%2B+%23%25'
|
||||
escaped = '%2B%20%23%25'
|
||||
self.assertEqual(url_escape(unescaped), plus_escaped)
|
||||
self.assertEqual(url_escape(unescaped, plus=False), escaped)
|
||||
self.assertEqual(url_unescape(plus_escaped), unescaped)
|
||||
self.assertEqual(url_unescape(escaped, plus=False), unescaped)
|
||||
self.assertEqual(url_unescape(plus_escaped, encoding=None),
|
||||
utf8(unescaped))
|
||||
self.assertEqual(url_unescape(escaped, encoding=None, plus=False),
|
||||
utf8(unescaped))
|
||||
|
||||
def test_escape_return_types(self):
|
||||
# On python2 the escape methods should generally return the same
|
||||
# type as their argument
|
||||
self.assertEqual(type(xhtml_escape("foo")), str)
|
||||
self.assertEqual(type(xhtml_escape(u"foo")), unicode_type)
|
||||
|
||||
def test_json_decode(self):
|
||||
# json_decode accepts both bytes and unicode, but strings it returns
|
||||
# are always unicode.
|
||||
self.assertEqual(json_decode(b'"foo"'), u"foo")
|
||||
self.assertEqual(json_decode(u'"foo"'), u"foo")
|
||||
|
||||
# Non-ascii bytes are interpreted as utf8
|
||||
self.assertEqual(json_decode(utf8(u'"\u00e9"')), u"\u00e9")
|
||||
|
||||
def test_json_encode(self):
|
||||
# json deals with strings, not bytes. On python 2 byte strings will
|
||||
# convert automatically if they are utf8; on python 3 byte strings
|
||||
# are not allowed.
|
||||
self.assertEqual(json_decode(json_encode(u"\u00e9")), u"\u00e9")
|
||||
if bytes is str:
|
||||
self.assertEqual(json_decode(json_encode(utf8(u"\u00e9"))), u"\u00e9")
|
||||
self.assertRaises(UnicodeDecodeError, json_encode, b"\xe9")
|
||||
|
||||
def test_squeeze(self):
|
||||
self.assertEqual(squeeze(u'sequences of whitespace chars'),
|
||||
u'sequences of whitespace chars')
|
||||
|
||||
def test_recursive_unicode(self):
|
||||
tests = {
|
||||
'dict': {b"foo": b"bar"},
|
||||
'list': [b"foo", b"bar"],
|
||||
'tuple': (b"foo", b"bar"),
|
||||
'bytes': b"foo"
|
||||
}
|
||||
self.assertEqual(recursive_unicode(tests['dict']), {u"foo": u"bar"})
|
||||
self.assertEqual(recursive_unicode(tests['list']), [u"foo", u"bar"])
|
||||
self.assertEqual(recursive_unicode(tests['tuple']), (u"foo", u"bar"))
|
||||
self.assertEqual(recursive_unicode(tests['bytes']), u"foo")
|
||||
Executable
+1862
File diff suppressed because it is too large
Load Diff
+16
@@ -0,0 +1,16 @@
|
||||
# flake8: noqa
|
||||
# Dummy source file to allow creation of the initial .po file in the
|
||||
# same way as a real project. I'm not entirely sure about the real
|
||||
# workflow here, but this seems to work.
|
||||
#
|
||||
# 1) xgettext --language=Python --keyword=_:1,2 --keyword=pgettext:1c,2 --keyword=pgettext:1c,2,3 extract_me.py -o tornado_test.po
|
||||
# 2) Edit tornado_test.po, setting CHARSET, Plural-Forms and setting msgstr
|
||||
# 3) msgfmt tornado_test.po -o tornado_test.mo
|
||||
# 4) Put the file in the proper location: $LANG/LC_MESSAGES
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
_("school")
|
||||
pgettext("law", "right")
|
||||
pgettext("good", "right")
|
||||
pgettext("organization", "club", "clubs", 1)
|
||||
pgettext("stick", "club", "clubs", 1)
|
||||
Binary file not shown.
@@ -0,0 +1,47 @@
|
||||
# SOME DESCRIPTIVE TITLE.
|
||||
# Copyright (C) YEAR THE PACKAGE'S COPYRIGHT HOLDER
|
||||
# This file is distributed under the same license as the PACKAGE package.
|
||||
# FIRST AUTHOR <EMAIL@ADDRESS>, YEAR.
|
||||
#
|
||||
#, fuzzy
|
||||
msgid ""
|
||||
msgstr ""
|
||||
"Project-Id-Version: PACKAGE VERSION\n"
|
||||
"Report-Msgid-Bugs-To: \n"
|
||||
"POT-Creation-Date: 2015-01-27 11:05+0300\n"
|
||||
"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
|
||||
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
|
||||
"Language-Team: LANGUAGE <LL@li.org>\n"
|
||||
"Language: \n"
|
||||
"MIME-Version: 1.0\n"
|
||||
"Content-Type: text/plain; charset=utf-8\n"
|
||||
"Content-Transfer-Encoding: 8bit\n"
|
||||
"Plural-Forms: nplurals=2; plural=(n > 1);\n"
|
||||
|
||||
#: extract_me.py:11
|
||||
msgid "school"
|
||||
msgstr "école"
|
||||
|
||||
#: extract_me.py:12
|
||||
msgctxt "law"
|
||||
msgid "right"
|
||||
msgstr "le droit"
|
||||
|
||||
#: extract_me.py:13
|
||||
msgctxt "good"
|
||||
msgid "right"
|
||||
msgstr "le bien"
|
||||
|
||||
#: extract_me.py:14
|
||||
msgctxt "organization"
|
||||
msgid "club"
|
||||
msgid_plural "clubs"
|
||||
msgstr[0] "le club"
|
||||
msgstr[1] "les clubs"
|
||||
|
||||
#: extract_me.py:15
|
||||
msgctxt "stick"
|
||||
msgid "club"
|
||||
msgid_plural "clubs"
|
||||
msgstr[0] "le bâton"
|
||||
msgstr[1] "les bâtons"
|
||||
Executable
+61
@@ -0,0 +1,61 @@
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import socket
|
||||
|
||||
from tornado.http1connection import HTTP1Connection
|
||||
from tornado.httputil import HTTPMessageDelegate
|
||||
from tornado.iostream import IOStream
|
||||
from tornado.locks import Event
|
||||
from tornado.netutil import add_accept_handler
|
||||
from tornado.testing import AsyncTestCase, bind_unused_port, gen_test
|
||||
|
||||
|
||||
class HTTP1ConnectionTest(AsyncTestCase):
|
||||
def setUp(self):
|
||||
super(HTTP1ConnectionTest, self).setUp()
|
||||
self.asyncSetUp()
|
||||
|
||||
@gen_test
|
||||
def asyncSetUp(self):
|
||||
listener, port = bind_unused_port()
|
||||
event = Event()
|
||||
|
||||
def accept_callback(conn, addr):
|
||||
self.server_stream = IOStream(conn)
|
||||
self.addCleanup(self.server_stream.close)
|
||||
event.set()
|
||||
|
||||
add_accept_handler(listener, accept_callback)
|
||||
self.client_stream = IOStream(socket.socket())
|
||||
self.addCleanup(self.client_stream.close)
|
||||
yield [self.client_stream.connect(('127.0.0.1', port)),
|
||||
event.wait()]
|
||||
self.io_loop.remove_handler(listener)
|
||||
listener.close()
|
||||
|
||||
@gen_test
|
||||
def test_http10_no_content_length(self):
|
||||
# Regression test for a bug in which can_keep_alive would crash
|
||||
# for an HTTP/1.0 (not 1.1) response with no content-length.
|
||||
conn = HTTP1Connection(self.client_stream, True)
|
||||
self.server_stream.write(b"HTTP/1.0 200 Not Modified\r\n\r\nhello")
|
||||
self.server_stream.close()
|
||||
|
||||
event = Event()
|
||||
test = self
|
||||
body = []
|
||||
|
||||
class Delegate(HTTPMessageDelegate):
|
||||
def headers_received(self, start_line, headers):
|
||||
test.code = start_line.code
|
||||
|
||||
def data_received(self, data):
|
||||
body.append(data)
|
||||
|
||||
def finish(self):
|
||||
event.set()
|
||||
|
||||
yield conn.read_response(Delegate())
|
||||
yield event.wait()
|
||||
self.assertEqual(self.code, 200)
|
||||
self.assertEqual(b''.join(body), b'hello')
|
||||
Executable
+718
@@ -0,0 +1,718 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import base64
|
||||
import binascii
|
||||
from contextlib import closing
|
||||
import copy
|
||||
import sys
|
||||
import threading
|
||||
import datetime
|
||||
from io import BytesIO
|
||||
import time
|
||||
import unicodedata
|
||||
|
||||
from tornado.escape import utf8, native_str
|
||||
from tornado import gen
|
||||
from tornado.httpclient import HTTPRequest, HTTPResponse, _RequestProxy, HTTPError, HTTPClient
|
||||
from tornado.httpserver import HTTPServer
|
||||
from tornado.ioloop import IOLoop
|
||||
from tornado.iostream import IOStream
|
||||
from tornado.log import gen_log
|
||||
from tornado import netutil
|
||||
from tornado.stack_context import ExceptionStackContext, NullContext
|
||||
from tornado.testing import AsyncHTTPTestCase, bind_unused_port, gen_test, ExpectLog
|
||||
from tornado.test.util import unittest, skipOnTravis, ignore_deprecation
|
||||
from tornado.web import Application, RequestHandler, url
|
||||
from tornado.httputil import format_timestamp, HTTPHeaders
|
||||
|
||||
|
||||
class HelloWorldHandler(RequestHandler):
|
||||
def get(self):
|
||||
name = self.get_argument("name", "world")
|
||||
self.set_header("Content-Type", "text/plain")
|
||||
self.finish("Hello %s!" % name)
|
||||
|
||||
|
||||
class PostHandler(RequestHandler):
|
||||
def post(self):
|
||||
self.finish("Post arg1: %s, arg2: %s" % (
|
||||
self.get_argument("arg1"), self.get_argument("arg2")))
|
||||
|
||||
|
||||
class PutHandler(RequestHandler):
|
||||
def put(self):
|
||||
self.write("Put body: ")
|
||||
self.write(self.request.body)
|
||||
|
||||
|
||||
class RedirectHandler(RequestHandler):
|
||||
def prepare(self):
|
||||
self.write('redirects can have bodies too')
|
||||
self.redirect(self.get_argument("url"),
|
||||
status=int(self.get_argument("status", "302")))
|
||||
|
||||
|
||||
class ChunkHandler(RequestHandler):
|
||||
@gen.coroutine
|
||||
def get(self):
|
||||
self.write("asdf")
|
||||
self.flush()
|
||||
# Wait a bit to ensure the chunks are sent and received separately.
|
||||
yield gen.sleep(0.01)
|
||||
self.write("qwer")
|
||||
|
||||
|
||||
class AuthHandler(RequestHandler):
|
||||
def get(self):
|
||||
self.finish(self.request.headers["Authorization"])
|
||||
|
||||
|
||||
class CountdownHandler(RequestHandler):
|
||||
def get(self, count):
|
||||
count = int(count)
|
||||
if count > 0:
|
||||
self.redirect(self.reverse_url("countdown", count - 1))
|
||||
else:
|
||||
self.write("Zero")
|
||||
|
||||
|
||||
class EchoPostHandler(RequestHandler):
|
||||
def post(self):
|
||||
self.write(self.request.body)
|
||||
|
||||
|
||||
class UserAgentHandler(RequestHandler):
|
||||
def get(self):
|
||||
self.write(self.request.headers.get('User-Agent', 'User agent not set'))
|
||||
|
||||
|
||||
class ContentLength304Handler(RequestHandler):
|
||||
def get(self):
|
||||
self.set_status(304)
|
||||
self.set_header('Content-Length', 42)
|
||||
|
||||
def _clear_headers_for_304(self):
|
||||
# Tornado strips content-length from 304 responses, but here we
|
||||
# want to simulate servers that include the headers anyway.
|
||||
pass
|
||||
|
||||
|
||||
class PatchHandler(RequestHandler):
|
||||
|
||||
def patch(self):
|
||||
"Return the request payload - so we can check it is being kept"
|
||||
self.write(self.request.body)
|
||||
|
||||
|
||||
class AllMethodsHandler(RequestHandler):
|
||||
SUPPORTED_METHODS = RequestHandler.SUPPORTED_METHODS + ('OTHER',)
|
||||
|
||||
def method(self):
|
||||
self.write(self.request.method)
|
||||
|
||||
get = post = put = delete = options = patch = other = method
|
||||
|
||||
|
||||
class SetHeaderHandler(RequestHandler):
|
||||
def get(self):
|
||||
# Use get_arguments for keys to get strings, but
|
||||
# request.arguments for values to get bytes.
|
||||
for k, v in zip(self.get_arguments('k'),
|
||||
self.request.arguments['v']):
|
||||
self.set_header(k, v)
|
||||
|
||||
# These tests end up getting run redundantly: once here with the default
|
||||
# HTTPClient implementation, and then again in each implementation's own
|
||||
# test suite.
|
||||
|
||||
|
||||
class HTTPClientCommonTestCase(AsyncHTTPTestCase):
|
||||
def get_app(self):
|
||||
return Application([
|
||||
url("/hello", HelloWorldHandler),
|
||||
url("/post", PostHandler),
|
||||
url("/put", PutHandler),
|
||||
url("/redirect", RedirectHandler),
|
||||
url("/chunk", ChunkHandler),
|
||||
url("/auth", AuthHandler),
|
||||
url("/countdown/([0-9]+)", CountdownHandler, name="countdown"),
|
||||
url("/echopost", EchoPostHandler),
|
||||
url("/user_agent", UserAgentHandler),
|
||||
url("/304_with_content_length", ContentLength304Handler),
|
||||
url("/all_methods", AllMethodsHandler),
|
||||
url('/patch', PatchHandler),
|
||||
url('/set_header', SetHeaderHandler),
|
||||
], gzip=True)
|
||||
|
||||
def test_patch_receives_payload(self):
|
||||
body = b"some patch data"
|
||||
response = self.fetch("/patch", method='PATCH', body=body)
|
||||
self.assertEqual(response.code, 200)
|
||||
self.assertEqual(response.body, body)
|
||||
|
||||
@skipOnTravis
|
||||
def test_hello_world(self):
|
||||
response = self.fetch("/hello")
|
||||
self.assertEqual(response.code, 200)
|
||||
self.assertEqual(response.headers["Content-Type"], "text/plain")
|
||||
self.assertEqual(response.body, b"Hello world!")
|
||||
self.assertEqual(int(response.request_time), 0)
|
||||
|
||||
response = self.fetch("/hello?name=Ben")
|
||||
self.assertEqual(response.body, b"Hello Ben!")
|
||||
|
||||
def test_streaming_callback(self):
|
||||
# streaming_callback is also tested in test_chunked
|
||||
chunks = []
|
||||
response = self.fetch("/hello",
|
||||
streaming_callback=chunks.append)
|
||||
# with streaming_callback, data goes to the callback and not response.body
|
||||
self.assertEqual(chunks, [b"Hello world!"])
|
||||
self.assertFalse(response.body)
|
||||
|
||||
def test_post(self):
|
||||
response = self.fetch("/post", method="POST",
|
||||
body="arg1=foo&arg2=bar")
|
||||
self.assertEqual(response.code, 200)
|
||||
self.assertEqual(response.body, b"Post arg1: foo, arg2: bar")
|
||||
|
||||
def test_chunked(self):
|
||||
response = self.fetch("/chunk")
|
||||
self.assertEqual(response.body, b"asdfqwer")
|
||||
|
||||
chunks = []
|
||||
response = self.fetch("/chunk",
|
||||
streaming_callback=chunks.append)
|
||||
self.assertEqual(chunks, [b"asdf", b"qwer"])
|
||||
self.assertFalse(response.body)
|
||||
|
||||
def test_chunked_close(self):
|
||||
# test case in which chunks spread read-callback processing
|
||||
# over several ioloop iterations, but the connection is already closed.
|
||||
sock, port = bind_unused_port()
|
||||
with closing(sock):
|
||||
@gen.coroutine
|
||||
def accept_callback(conn, address):
|
||||
# fake an HTTP server using chunked encoding where the final chunks
|
||||
# and connection close all happen at once
|
||||
stream = IOStream(conn)
|
||||
request_data = yield stream.read_until(b"\r\n\r\n")
|
||||
if b"HTTP/1." not in request_data:
|
||||
self.skipTest("requires HTTP/1.x")
|
||||
yield stream.write(b"""\
|
||||
HTTP/1.1 200 OK
|
||||
Transfer-Encoding: chunked
|
||||
|
||||
1
|
||||
1
|
||||
1
|
||||
2
|
||||
0
|
||||
|
||||
""".replace(b"\n", b"\r\n"))
|
||||
stream.close()
|
||||
netutil.add_accept_handler(sock, accept_callback)
|
||||
resp = self.fetch("http://127.0.0.1:%d/" % port)
|
||||
resp.rethrow()
|
||||
self.assertEqual(resp.body, b"12")
|
||||
self.io_loop.remove_handler(sock.fileno())
|
||||
|
||||
def test_streaming_stack_context(self):
|
||||
chunks = []
|
||||
exc_info = []
|
||||
|
||||
def error_handler(typ, value, tb):
|
||||
exc_info.append((typ, value, tb))
|
||||
return True
|
||||
|
||||
def streaming_cb(chunk):
|
||||
chunks.append(chunk)
|
||||
if chunk == b'qwer':
|
||||
1 / 0
|
||||
|
||||
with ignore_deprecation():
|
||||
with ExceptionStackContext(error_handler):
|
||||
self.fetch('/chunk', streaming_callback=streaming_cb)
|
||||
|
||||
self.assertEqual(chunks, [b'asdf', b'qwer'])
|
||||
self.assertEqual(1, len(exc_info))
|
||||
self.assertIs(exc_info[0][0], ZeroDivisionError)
|
||||
|
||||
def test_basic_auth(self):
|
||||
# This test data appears in section 2 of RFC 7617.
|
||||
self.assertEqual(self.fetch("/auth", auth_username="Aladdin",
|
||||
auth_password="open sesame").body,
|
||||
b"Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==")
|
||||
|
||||
def test_basic_auth_explicit_mode(self):
|
||||
self.assertEqual(self.fetch("/auth", auth_username="Aladdin",
|
||||
auth_password="open sesame",
|
||||
auth_mode="basic").body,
|
||||
b"Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==")
|
||||
|
||||
def test_basic_auth_unicode(self):
|
||||
# This test data appears in section 2.1 of RFC 7617.
|
||||
self.assertEqual(self.fetch("/auth", auth_username="test",
|
||||
auth_password="123£").body,
|
||||
b"Basic dGVzdDoxMjPCow==")
|
||||
|
||||
# The standard mandates NFC. Give it a decomposed username
|
||||
# and ensure it is normalized to composed form.
|
||||
username = unicodedata.normalize("NFD", u"josé")
|
||||
self.assertEqual(self.fetch("/auth",
|
||||
auth_username=username,
|
||||
auth_password="səcrət").body,
|
||||
b"Basic am9zw6k6c8mZY3LJmXQ=")
|
||||
|
||||
def test_unsupported_auth_mode(self):
|
||||
# curl and simple clients handle errors a bit differently; the
|
||||
# important thing is that they don't fall back to basic auth
|
||||
# on an unknown mode.
|
||||
with ExpectLog(gen_log, "uncaught exception", required=False):
|
||||
with self.assertRaises((ValueError, HTTPError)):
|
||||
self.fetch("/auth", auth_username="Aladdin",
|
||||
auth_password="open sesame",
|
||||
auth_mode="asdf",
|
||||
raise_error=True)
|
||||
|
||||
def test_follow_redirect(self):
|
||||
response = self.fetch("/countdown/2", follow_redirects=False)
|
||||
self.assertEqual(302, response.code)
|
||||
self.assertTrue(response.headers["Location"].endswith("/countdown/1"))
|
||||
|
||||
response = self.fetch("/countdown/2")
|
||||
self.assertEqual(200, response.code)
|
||||
self.assertTrue(response.effective_url.endswith("/countdown/0"))
|
||||
self.assertEqual(b"Zero", response.body)
|
||||
|
||||
def test_credentials_in_url(self):
|
||||
url = self.get_url("/auth").replace("http://", "http://me:secret@")
|
||||
response = self.fetch(url)
|
||||
self.assertEqual(b"Basic " + base64.b64encode(b"me:secret"),
|
||||
response.body)
|
||||
|
||||
def test_body_encoding(self):
|
||||
unicode_body = u"\xe9"
|
||||
byte_body = binascii.a2b_hex(b"e9")
|
||||
|
||||
# unicode string in body gets converted to utf8
|
||||
response = self.fetch("/echopost", method="POST", body=unicode_body,
|
||||
headers={"Content-Type": "application/blah"})
|
||||
self.assertEqual(response.headers["Content-Length"], "2")
|
||||
self.assertEqual(response.body, utf8(unicode_body))
|
||||
|
||||
# byte strings pass through directly
|
||||
response = self.fetch("/echopost", method="POST",
|
||||
body=byte_body,
|
||||
headers={"Content-Type": "application/blah"})
|
||||
self.assertEqual(response.headers["Content-Length"], "1")
|
||||
self.assertEqual(response.body, byte_body)
|
||||
|
||||
# Mixing unicode in headers and byte string bodies shouldn't
|
||||
# break anything
|
||||
response = self.fetch("/echopost", method="POST", body=byte_body,
|
||||
headers={"Content-Type": "application/blah"},
|
||||
user_agent=u"foo")
|
||||
self.assertEqual(response.headers["Content-Length"], "1")
|
||||
self.assertEqual(response.body, byte_body)
|
||||
|
||||
def test_types(self):
|
||||
response = self.fetch("/hello")
|
||||
self.assertEqual(type(response.body), bytes)
|
||||
self.assertEqual(type(response.headers["Content-Type"]), str)
|
||||
self.assertEqual(type(response.code), int)
|
||||
self.assertEqual(type(response.effective_url), str)
|
||||
|
||||
def test_header_callback(self):
|
||||
first_line = []
|
||||
headers = {}
|
||||
chunks = []
|
||||
|
||||
def header_callback(header_line):
|
||||
if header_line.startswith('HTTP/1.1 101'):
|
||||
# Upgrading to HTTP/2
|
||||
pass
|
||||
elif header_line.startswith('HTTP/'):
|
||||
first_line.append(header_line)
|
||||
elif header_line != '\r\n':
|
||||
k, v = header_line.split(':', 1)
|
||||
headers[k.lower()] = v.strip()
|
||||
|
||||
def streaming_callback(chunk):
|
||||
# All header callbacks are run before any streaming callbacks,
|
||||
# so the header data is available to process the data as it
|
||||
# comes in.
|
||||
self.assertEqual(headers['content-type'], 'text/html; charset=UTF-8')
|
||||
chunks.append(chunk)
|
||||
|
||||
self.fetch('/chunk', header_callback=header_callback,
|
||||
streaming_callback=streaming_callback)
|
||||
self.assertEqual(len(first_line), 1, first_line)
|
||||
self.assertRegexpMatches(first_line[0], 'HTTP/[0-9]\\.[0-9] 200.*\r\n')
|
||||
self.assertEqual(chunks, [b'asdf', b'qwer'])
|
||||
|
||||
def test_header_callback_stack_context(self):
|
||||
exc_info = []
|
||||
|
||||
def error_handler(typ, value, tb):
|
||||
exc_info.append((typ, value, tb))
|
||||
return True
|
||||
|
||||
def header_callback(header_line):
|
||||
if header_line.lower().startswith('content-type:'):
|
||||
1 / 0
|
||||
|
||||
with ignore_deprecation():
|
||||
with ExceptionStackContext(error_handler):
|
||||
self.fetch('/chunk', header_callback=header_callback)
|
||||
self.assertEqual(len(exc_info), 1)
|
||||
self.assertIs(exc_info[0][0], ZeroDivisionError)
|
||||
|
||||
@gen_test
|
||||
def test_configure_defaults(self):
|
||||
defaults = dict(user_agent='TestDefaultUserAgent', allow_ipv6=False)
|
||||
# Construct a new instance of the configured client class
|
||||
client = self.http_client.__class__(force_instance=True,
|
||||
defaults=defaults)
|
||||
try:
|
||||
response = yield client.fetch(self.get_url('/user_agent'))
|
||||
self.assertEqual(response.body, b'TestDefaultUserAgent')
|
||||
finally:
|
||||
client.close()
|
||||
|
||||
def test_header_types(self):
|
||||
# Header values may be passed as character or utf8 byte strings,
|
||||
# in a plain dictionary or an HTTPHeaders object.
|
||||
# Keys must always be the native str type.
|
||||
# All combinations should have the same results on the wire.
|
||||
for value in [u"MyUserAgent", b"MyUserAgent"]:
|
||||
for container in [dict, HTTPHeaders]:
|
||||
headers = container()
|
||||
headers['User-Agent'] = value
|
||||
resp = self.fetch('/user_agent', headers=headers)
|
||||
self.assertEqual(
|
||||
resp.body, b"MyUserAgent",
|
||||
"response=%r, value=%r, container=%r" %
|
||||
(resp.body, value, container))
|
||||
|
||||
def test_multi_line_headers(self):
|
||||
# Multi-line http headers are rare but rfc-allowed
|
||||
# http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.2
|
||||
sock, port = bind_unused_port()
|
||||
with closing(sock):
|
||||
@gen.coroutine
|
||||
def accept_callback(conn, address):
|
||||
stream = IOStream(conn)
|
||||
request_data = yield stream.read_until(b"\r\n\r\n")
|
||||
if b"HTTP/1." not in request_data:
|
||||
self.skipTest("requires HTTP/1.x")
|
||||
yield stream.write(b"""\
|
||||
HTTP/1.1 200 OK
|
||||
X-XSS-Protection: 1;
|
||||
\tmode=block
|
||||
|
||||
""".replace(b"\n", b"\r\n"))
|
||||
stream.close()
|
||||
|
||||
netutil.add_accept_handler(sock, accept_callback)
|
||||
resp = self.fetch("http://127.0.0.1:%d/" % port)
|
||||
resp.rethrow()
|
||||
self.assertEqual(resp.headers['X-XSS-Protection'], "1; mode=block")
|
||||
self.io_loop.remove_handler(sock.fileno())
|
||||
|
||||
def test_304_with_content_length(self):
|
||||
# According to the spec 304 responses SHOULD NOT include
|
||||
# Content-Length or other entity headers, but some servers do it
|
||||
# anyway.
|
||||
# http://www.w3.org/Protocols/rfc2616/rfc2616-sec10.html#sec10.3.5
|
||||
response = self.fetch('/304_with_content_length')
|
||||
self.assertEqual(response.code, 304)
|
||||
self.assertEqual(response.headers['Content-Length'], '42')
|
||||
|
||||
def test_final_callback_stack_context(self):
|
||||
# The final callback should be run outside of the httpclient's
|
||||
# stack_context. We want to ensure that there is not stack_context
|
||||
# between the user's callback and the IOLoop, so monkey-patch
|
||||
# IOLoop.handle_callback_exception and disable the test harness's
|
||||
# context with a NullContext.
|
||||
# Note that this does not apply to secondary callbacks (header
|
||||
# and streaming_callback), as errors there must be seen as errors
|
||||
# by the http client so it can clean up the connection.
|
||||
exc_info = []
|
||||
|
||||
def handle_callback_exception(callback):
|
||||
exc_info.append(sys.exc_info())
|
||||
self.stop()
|
||||
self.io_loop.handle_callback_exception = handle_callback_exception
|
||||
with NullContext():
|
||||
with ignore_deprecation():
|
||||
self.http_client.fetch(self.get_url('/hello'),
|
||||
lambda response: 1 / 0)
|
||||
self.wait()
|
||||
self.assertEqual(exc_info[0][0], ZeroDivisionError)
|
||||
|
||||
@gen_test
|
||||
def test_future_interface(self):
|
||||
response = yield self.http_client.fetch(self.get_url('/hello'))
|
||||
self.assertEqual(response.body, b'Hello world!')
|
||||
|
||||
@gen_test
|
||||
def test_future_http_error(self):
|
||||
with self.assertRaises(HTTPError) as context:
|
||||
yield self.http_client.fetch(self.get_url('/notfound'))
|
||||
self.assertEqual(context.exception.code, 404)
|
||||
self.assertEqual(context.exception.response.code, 404)
|
||||
|
||||
@gen_test
|
||||
def test_future_http_error_no_raise(self):
|
||||
response = yield self.http_client.fetch(self.get_url('/notfound'), raise_error=False)
|
||||
self.assertEqual(response.code, 404)
|
||||
|
||||
@gen_test
|
||||
def test_reuse_request_from_response(self):
|
||||
# The response.request attribute should be an HTTPRequest, not
|
||||
# a _RequestProxy.
|
||||
# This test uses self.http_client.fetch because self.fetch calls
|
||||
# self.get_url on the input unconditionally.
|
||||
url = self.get_url('/hello')
|
||||
response = yield self.http_client.fetch(url)
|
||||
self.assertEqual(response.request.url, url)
|
||||
self.assertTrue(isinstance(response.request, HTTPRequest))
|
||||
response2 = yield self.http_client.fetch(response.request)
|
||||
self.assertEqual(response2.body, b'Hello world!')
|
||||
|
||||
def test_all_methods(self):
|
||||
for method in ['GET', 'DELETE', 'OPTIONS']:
|
||||
response = self.fetch('/all_methods', method=method)
|
||||
self.assertEqual(response.body, utf8(method))
|
||||
for method in ['POST', 'PUT', 'PATCH']:
|
||||
response = self.fetch('/all_methods', method=method, body=b'')
|
||||
self.assertEqual(response.body, utf8(method))
|
||||
response = self.fetch('/all_methods', method='HEAD')
|
||||
self.assertEqual(response.body, b'')
|
||||
response = self.fetch('/all_methods', method='OTHER',
|
||||
allow_nonstandard_methods=True)
|
||||
self.assertEqual(response.body, b'OTHER')
|
||||
|
||||
def test_body_sanity_checks(self):
|
||||
# These methods require a body.
|
||||
for method in ('POST', 'PUT', 'PATCH'):
|
||||
with self.assertRaises(ValueError) as context:
|
||||
self.fetch('/all_methods', method=method, raise_error=True)
|
||||
self.assertIn('must not be None', str(context.exception))
|
||||
|
||||
resp = self.fetch('/all_methods', method=method,
|
||||
allow_nonstandard_methods=True)
|
||||
self.assertEqual(resp.code, 200)
|
||||
|
||||
# These methods don't allow a body.
|
||||
for method in ('GET', 'DELETE', 'OPTIONS'):
|
||||
with self.assertRaises(ValueError) as context:
|
||||
self.fetch('/all_methods', method=method, body=b'asdf', raise_error=True)
|
||||
self.assertIn('must be None', str(context.exception))
|
||||
|
||||
# In most cases this can be overridden, but curl_httpclient
|
||||
# does not allow body with a GET at all.
|
||||
if method != 'GET':
|
||||
self.fetch('/all_methods', method=method, body=b'asdf',
|
||||
allow_nonstandard_methods=True, raise_error=True)
|
||||
self.assertEqual(resp.code, 200)
|
||||
|
||||
# This test causes odd failures with the combination of
|
||||
# curl_httpclient (at least with the version of libcurl available
|
||||
# on ubuntu 12.04), TwistedIOLoop, and epoll. For POST (but not PUT),
|
||||
# curl decides the response came back too soon and closes the connection
|
||||
# to start again. It does this *before* telling the socket callback to
|
||||
# unregister the FD. Some IOLoop implementations have special kernel
|
||||
# integration to discover this immediately. Tornado's IOLoops
|
||||
# ignore errors on remove_handler to accommodate this behavior, but
|
||||
# Twisted's reactor does not. The removeReader call fails and so
|
||||
# do all future removeAll calls (which our tests do at cleanup).
|
||||
#
|
||||
# def test_post_307(self):
|
||||
# response = self.fetch("/redirect?status=307&url=/post",
|
||||
# method="POST", body=b"arg1=foo&arg2=bar")
|
||||
# self.assertEqual(response.body, b"Post arg1: foo, arg2: bar")
|
||||
|
||||
def test_put_307(self):
|
||||
response = self.fetch("/redirect?status=307&url=/put",
|
||||
method="PUT", body=b"hello")
|
||||
response.rethrow()
|
||||
self.assertEqual(response.body, b"Put body: hello")
|
||||
|
||||
def test_non_ascii_header(self):
|
||||
# Non-ascii headers are sent as latin1.
|
||||
response = self.fetch("/set_header?k=foo&v=%E9")
|
||||
response.rethrow()
|
||||
self.assertEqual(response.headers["Foo"], native_str(u"\u00e9"))
|
||||
|
||||
def test_response_times(self):
|
||||
# A few simple sanity checks of the response time fields to
|
||||
# make sure they're using the right basis (between the
|
||||
# wall-time and monotonic clocks).
|
||||
start_time = time.time()
|
||||
response = self.fetch("/hello")
|
||||
response.rethrow()
|
||||
self.assertGreaterEqual(response.request_time, 0)
|
||||
self.assertLess(response.request_time, 1.0)
|
||||
# A very crude check to make sure that start_time is based on
|
||||
# wall time and not the monotonic clock.
|
||||
self.assertLess(abs(response.start_time - start_time), 1.0)
|
||||
|
||||
for k, v in response.time_info.items():
|
||||
self.assertTrue(0 <= v < 1.0, "time_info[%s] out of bounds: %s" % (k, v))
|
||||
|
||||
|
||||
class RequestProxyTest(unittest.TestCase):
|
||||
def test_request_set(self):
|
||||
proxy = _RequestProxy(HTTPRequest('http://example.com/',
|
||||
user_agent='foo'),
|
||||
dict())
|
||||
self.assertEqual(proxy.user_agent, 'foo')
|
||||
|
||||
def test_default_set(self):
|
||||
proxy = _RequestProxy(HTTPRequest('http://example.com/'),
|
||||
dict(network_interface='foo'))
|
||||
self.assertEqual(proxy.network_interface, 'foo')
|
||||
|
||||
def test_both_set(self):
|
||||
proxy = _RequestProxy(HTTPRequest('http://example.com/',
|
||||
proxy_host='foo'),
|
||||
dict(proxy_host='bar'))
|
||||
self.assertEqual(proxy.proxy_host, 'foo')
|
||||
|
||||
def test_neither_set(self):
|
||||
proxy = _RequestProxy(HTTPRequest('http://example.com/'),
|
||||
dict())
|
||||
self.assertIs(proxy.auth_username, None)
|
||||
|
||||
def test_bad_attribute(self):
|
||||
proxy = _RequestProxy(HTTPRequest('http://example.com/'),
|
||||
dict())
|
||||
with self.assertRaises(AttributeError):
|
||||
proxy.foo
|
||||
|
||||
def test_defaults_none(self):
|
||||
proxy = _RequestProxy(HTTPRequest('http://example.com/'), None)
|
||||
self.assertIs(proxy.auth_username, None)
|
||||
|
||||
|
||||
class HTTPResponseTestCase(unittest.TestCase):
|
||||
def test_str(self):
|
||||
response = HTTPResponse(HTTPRequest('http://example.com'),
|
||||
200, headers={}, buffer=BytesIO())
|
||||
s = str(response)
|
||||
self.assertTrue(s.startswith('HTTPResponse('))
|
||||
self.assertIn('code=200', s)
|
||||
|
||||
|
||||
class SyncHTTPClientTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
if IOLoop.configured_class().__name__ == 'TwistedIOLoop':
|
||||
# TwistedIOLoop only supports the global reactor, so we can't have
|
||||
# separate IOLoops for client and server threads.
|
||||
raise unittest.SkipTest(
|
||||
'Sync HTTPClient not compatible with TwistedIOLoop')
|
||||
self.server_ioloop = IOLoop()
|
||||
|
||||
@gen.coroutine
|
||||
def init_server():
|
||||
sock, self.port = bind_unused_port()
|
||||
app = Application([('/', HelloWorldHandler)])
|
||||
self.server = HTTPServer(app)
|
||||
self.server.add_socket(sock)
|
||||
self.server_ioloop.run_sync(init_server)
|
||||
|
||||
self.server_thread = threading.Thread(target=self.server_ioloop.start)
|
||||
self.server_thread.start()
|
||||
|
||||
self.http_client = HTTPClient()
|
||||
|
||||
def tearDown(self):
|
||||
def stop_server():
|
||||
self.server.stop()
|
||||
# Delay the shutdown of the IOLoop by several iterations because
|
||||
# the server may still have some cleanup work left when
|
||||
# the client finishes with the response (this is noticeable
|
||||
# with http/2, which leaves a Future with an unexamined
|
||||
# StreamClosedError on the loop).
|
||||
|
||||
@gen.coroutine
|
||||
def slow_stop():
|
||||
# The number of iterations is difficult to predict. Typically,
|
||||
# one is sufficient, although sometimes it needs more.
|
||||
for i in range(5):
|
||||
yield
|
||||
self.server_ioloop.stop()
|
||||
self.server_ioloop.add_callback(slow_stop)
|
||||
self.server_ioloop.add_callback(stop_server)
|
||||
self.server_thread.join()
|
||||
self.http_client.close()
|
||||
self.server_ioloop.close(all_fds=True)
|
||||
|
||||
def get_url(self, path):
|
||||
return 'http://127.0.0.1:%d%s' % (self.port, path)
|
||||
|
||||
def test_sync_client(self):
|
||||
response = self.http_client.fetch(self.get_url('/'))
|
||||
self.assertEqual(b'Hello world!', response.body)
|
||||
|
||||
def test_sync_client_error(self):
|
||||
# Synchronous HTTPClient raises errors directly; no need for
|
||||
# response.rethrow()
|
||||
with self.assertRaises(HTTPError) as assertion:
|
||||
self.http_client.fetch(self.get_url('/notfound'))
|
||||
self.assertEqual(assertion.exception.code, 404)
|
||||
|
||||
|
||||
class HTTPRequestTestCase(unittest.TestCase):
|
||||
def test_headers(self):
|
||||
request = HTTPRequest('http://example.com', headers={'foo': 'bar'})
|
||||
self.assertEqual(request.headers, {'foo': 'bar'})
|
||||
|
||||
def test_headers_setter(self):
|
||||
request = HTTPRequest('http://example.com')
|
||||
request.headers = {'bar': 'baz'}
|
||||
self.assertEqual(request.headers, {'bar': 'baz'})
|
||||
|
||||
def test_null_headers_setter(self):
|
||||
request = HTTPRequest('http://example.com')
|
||||
request.headers = None
|
||||
self.assertEqual(request.headers, {})
|
||||
|
||||
def test_body(self):
|
||||
request = HTTPRequest('http://example.com', body='foo')
|
||||
self.assertEqual(request.body, utf8('foo'))
|
||||
|
||||
def test_body_setter(self):
|
||||
request = HTTPRequest('http://example.com')
|
||||
request.body = 'foo'
|
||||
self.assertEqual(request.body, utf8('foo'))
|
||||
|
||||
def test_if_modified_since(self):
|
||||
http_date = datetime.datetime.utcnow()
|
||||
request = HTTPRequest('http://example.com', if_modified_since=http_date)
|
||||
self.assertEqual(request.headers,
|
||||
{'If-Modified-Since': format_timestamp(http_date)})
|
||||
|
||||
|
||||
class HTTPErrorTestCase(unittest.TestCase):
|
||||
def test_copy(self):
|
||||
e = HTTPError(403)
|
||||
e2 = copy.copy(e)
|
||||
self.assertIsNot(e, e2)
|
||||
self.assertEqual(e.code, e2.code)
|
||||
|
||||
def test_plain_error(self):
|
||||
e = HTTPError(403)
|
||||
self.assertEqual(str(e), "HTTP 403: Forbidden")
|
||||
self.assertEqual(repr(e), "HTTP 403: Forbidden")
|
||||
|
||||
def test_error_with_response(self):
|
||||
resp = HTTPResponse(HTTPRequest('http://example.com/'), 403)
|
||||
with self.assertRaises(HTTPError) as cm:
|
||||
resp.rethrow()
|
||||
e = cm.exception
|
||||
self.assertEqual(str(e), "HTTP 403: Forbidden")
|
||||
self.assertEqual(repr(e), "HTTP 403: Forbidden")
|
||||
Executable
+1167
File diff suppressed because it is too large
Load Diff
Executable
+516
@@ -0,0 +1,516 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
from tornado.httputil import (
|
||||
url_concat, parse_multipart_form_data, HTTPHeaders, format_timestamp,
|
||||
HTTPServerRequest, parse_request_start_line, parse_cookie, qs_to_qsl,
|
||||
HTTPInputError,
|
||||
)
|
||||
from tornado.escape import utf8, native_str
|
||||
from tornado.util import PY3
|
||||
from tornado.log import gen_log
|
||||
from tornado.testing import ExpectLog
|
||||
from tornado.test.util import unittest
|
||||
|
||||
import copy
|
||||
import datetime
|
||||
import logging
|
||||
import pickle
|
||||
import time
|
||||
|
||||
if PY3:
|
||||
import urllib.parse as urllib_parse
|
||||
else:
|
||||
import urlparse as urllib_parse
|
||||
|
||||
|
||||
class TestUrlConcat(unittest.TestCase):
|
||||
def test_url_concat_no_query_params(self):
|
||||
url = url_concat(
|
||||
"https://localhost/path",
|
||||
[('y', 'y'), ('z', 'z')],
|
||||
)
|
||||
self.assertEqual(url, "https://localhost/path?y=y&z=z")
|
||||
|
||||
def test_url_concat_encode_args(self):
|
||||
url = url_concat(
|
||||
"https://localhost/path",
|
||||
[('y', '/y'), ('z', 'z')],
|
||||
)
|
||||
self.assertEqual(url, "https://localhost/path?y=%2Fy&z=z")
|
||||
|
||||
def test_url_concat_trailing_q(self):
|
||||
url = url_concat(
|
||||
"https://localhost/path?",
|
||||
[('y', 'y'), ('z', 'z')],
|
||||
)
|
||||
self.assertEqual(url, "https://localhost/path?y=y&z=z")
|
||||
|
||||
def test_url_concat_q_with_no_trailing_amp(self):
|
||||
url = url_concat(
|
||||
"https://localhost/path?x",
|
||||
[('y', 'y'), ('z', 'z')],
|
||||
)
|
||||
self.assertEqual(url, "https://localhost/path?x=&y=y&z=z")
|
||||
|
||||
def test_url_concat_trailing_amp(self):
|
||||
url = url_concat(
|
||||
"https://localhost/path?x&",
|
||||
[('y', 'y'), ('z', 'z')],
|
||||
)
|
||||
self.assertEqual(url, "https://localhost/path?x=&y=y&z=z")
|
||||
|
||||
def test_url_concat_mult_params(self):
|
||||
url = url_concat(
|
||||
"https://localhost/path?a=1&b=2",
|
||||
[('y', 'y'), ('z', 'z')],
|
||||
)
|
||||
self.assertEqual(url, "https://localhost/path?a=1&b=2&y=y&z=z")
|
||||
|
||||
def test_url_concat_no_params(self):
|
||||
url = url_concat(
|
||||
"https://localhost/path?r=1&t=2",
|
||||
[],
|
||||
)
|
||||
self.assertEqual(url, "https://localhost/path?r=1&t=2")
|
||||
|
||||
def test_url_concat_none_params(self):
|
||||
url = url_concat(
|
||||
"https://localhost/path?r=1&t=2",
|
||||
None,
|
||||
)
|
||||
self.assertEqual(url, "https://localhost/path?r=1&t=2")
|
||||
|
||||
def test_url_concat_with_frag(self):
|
||||
url = url_concat(
|
||||
"https://localhost/path#tab",
|
||||
[('y', 'y')],
|
||||
)
|
||||
self.assertEqual(url, "https://localhost/path?y=y#tab")
|
||||
|
||||
def test_url_concat_multi_same_params(self):
|
||||
url = url_concat(
|
||||
"https://localhost/path",
|
||||
[('y', 'y1'), ('y', 'y2')],
|
||||
)
|
||||
self.assertEqual(url, "https://localhost/path?y=y1&y=y2")
|
||||
|
||||
def test_url_concat_multi_same_query_params(self):
|
||||
url = url_concat(
|
||||
"https://localhost/path?r=1&r=2",
|
||||
[('y', 'y')],
|
||||
)
|
||||
self.assertEqual(url, "https://localhost/path?r=1&r=2&y=y")
|
||||
|
||||
def test_url_concat_dict_params(self):
|
||||
url = url_concat(
|
||||
"https://localhost/path",
|
||||
dict(y='y'),
|
||||
)
|
||||
self.assertEqual(url, "https://localhost/path?y=y")
|
||||
|
||||
|
||||
class QsParseTest(unittest.TestCase):
|
||||
|
||||
def test_parsing(self):
|
||||
qsstring = "a=1&b=2&a=3"
|
||||
qs = urllib_parse.parse_qs(qsstring)
|
||||
qsl = list(qs_to_qsl(qs))
|
||||
self.assertIn(('a', '1'), qsl)
|
||||
self.assertIn(('a', '3'), qsl)
|
||||
self.assertIn(('b', '2'), qsl)
|
||||
|
||||
|
||||
class MultipartFormDataTest(unittest.TestCase):
|
||||
def test_file_upload(self):
|
||||
data = b"""\
|
||||
--1234
|
||||
Content-Disposition: form-data; name="files"; filename="ab.txt"
|
||||
|
||||
Foo
|
||||
--1234--""".replace(b"\n", b"\r\n")
|
||||
args = {}
|
||||
files = {}
|
||||
parse_multipart_form_data(b"1234", data, args, files)
|
||||
file = files["files"][0]
|
||||
self.assertEqual(file["filename"], "ab.txt")
|
||||
self.assertEqual(file["body"], b"Foo")
|
||||
|
||||
def test_unquoted_names(self):
|
||||
# quotes are optional unless special characters are present
|
||||
data = b"""\
|
||||
--1234
|
||||
Content-Disposition: form-data; name=files; filename=ab.txt
|
||||
|
||||
Foo
|
||||
--1234--""".replace(b"\n", b"\r\n")
|
||||
args = {}
|
||||
files = {}
|
||||
parse_multipart_form_data(b"1234", data, args, files)
|
||||
file = files["files"][0]
|
||||
self.assertEqual(file["filename"], "ab.txt")
|
||||
self.assertEqual(file["body"], b"Foo")
|
||||
|
||||
def test_special_filenames(self):
|
||||
filenames = ['a;b.txt',
|
||||
'a"b.txt',
|
||||
'a";b.txt',
|
||||
'a;"b.txt',
|
||||
'a";";.txt',
|
||||
'a\\"b.txt',
|
||||
'a\\b.txt',
|
||||
]
|
||||
for filename in filenames:
|
||||
logging.debug("trying filename %r", filename)
|
||||
data = """\
|
||||
--1234
|
||||
Content-Disposition: form-data; name="files"; filename="%s"
|
||||
|
||||
Foo
|
||||
--1234--""" % filename.replace('\\', '\\\\').replace('"', '\\"')
|
||||
data = utf8(data.replace("\n", "\r\n"))
|
||||
args = {}
|
||||
files = {}
|
||||
parse_multipart_form_data(b"1234", data, args, files)
|
||||
file = files["files"][0]
|
||||
self.assertEqual(file["filename"], filename)
|
||||
self.assertEqual(file["body"], b"Foo")
|
||||
|
||||
def test_non_ascii_filename(self):
|
||||
data = b"""\
|
||||
--1234
|
||||
Content-Disposition: form-data; name="files"; filename="ab.txt"; filename*=UTF-8''%C3%A1b.txt
|
||||
|
||||
Foo
|
||||
--1234--""".replace(b"\n", b"\r\n")
|
||||
args = {}
|
||||
files = {}
|
||||
parse_multipart_form_data(b"1234", data, args, files)
|
||||
file = files["files"][0]
|
||||
self.assertEqual(file["filename"], u"áb.txt")
|
||||
self.assertEqual(file["body"], b"Foo")
|
||||
|
||||
def test_boundary_starts_and_ends_with_quotes(self):
|
||||
data = b'''\
|
||||
--1234
|
||||
Content-Disposition: form-data; name="files"; filename="ab.txt"
|
||||
|
||||
Foo
|
||||
--1234--'''.replace(b"\n", b"\r\n")
|
||||
args = {}
|
||||
files = {}
|
||||
parse_multipart_form_data(b'"1234"', data, args, files)
|
||||
file = files["files"][0]
|
||||
self.assertEqual(file["filename"], "ab.txt")
|
||||
self.assertEqual(file["body"], b"Foo")
|
||||
|
||||
def test_missing_headers(self):
|
||||
data = b'''\
|
||||
--1234
|
||||
|
||||
Foo
|
||||
--1234--'''.replace(b"\n", b"\r\n")
|
||||
args = {}
|
||||
files = {}
|
||||
with ExpectLog(gen_log, "multipart/form-data missing headers"):
|
||||
parse_multipart_form_data(b"1234", data, args, files)
|
||||
self.assertEqual(files, {})
|
||||
|
||||
def test_invalid_content_disposition(self):
|
||||
data = b'''\
|
||||
--1234
|
||||
Content-Disposition: invalid; name="files"; filename="ab.txt"
|
||||
|
||||
Foo
|
||||
--1234--'''.replace(b"\n", b"\r\n")
|
||||
args = {}
|
||||
files = {}
|
||||
with ExpectLog(gen_log, "Invalid multipart/form-data"):
|
||||
parse_multipart_form_data(b"1234", data, args, files)
|
||||
self.assertEqual(files, {})
|
||||
|
||||
def test_line_does_not_end_with_correct_line_break(self):
|
||||
data = b'''\
|
||||
--1234
|
||||
Content-Disposition: form-data; name="files"; filename="ab.txt"
|
||||
|
||||
Foo--1234--'''.replace(b"\n", b"\r\n")
|
||||
args = {}
|
||||
files = {}
|
||||
with ExpectLog(gen_log, "Invalid multipart/form-data"):
|
||||
parse_multipart_form_data(b"1234", data, args, files)
|
||||
self.assertEqual(files, {})
|
||||
|
||||
def test_content_disposition_header_without_name_parameter(self):
|
||||
data = b"""\
|
||||
--1234
|
||||
Content-Disposition: form-data; filename="ab.txt"
|
||||
|
||||
Foo
|
||||
--1234--""".replace(b"\n", b"\r\n")
|
||||
args = {}
|
||||
files = {}
|
||||
with ExpectLog(gen_log, "multipart/form-data value missing name"):
|
||||
parse_multipart_form_data(b"1234", data, args, files)
|
||||
self.assertEqual(files, {})
|
||||
|
||||
def test_data_after_final_boundary(self):
|
||||
# The spec requires that data after the final boundary be ignored.
|
||||
# http://www.w3.org/Protocols/rfc1341/7_2_Multipart.html
|
||||
# In practice, some libraries include an extra CRLF after the boundary.
|
||||
data = b"""\
|
||||
--1234
|
||||
Content-Disposition: form-data; name="files"; filename="ab.txt"
|
||||
|
||||
Foo
|
||||
--1234--
|
||||
""".replace(b"\n", b"\r\n")
|
||||
args = {}
|
||||
files = {}
|
||||
parse_multipart_form_data(b"1234", data, args, files)
|
||||
file = files["files"][0]
|
||||
self.assertEqual(file["filename"], "ab.txt")
|
||||
self.assertEqual(file["body"], b"Foo")
|
||||
|
||||
|
||||
class HTTPHeadersTest(unittest.TestCase):
|
||||
def test_multi_line(self):
|
||||
# Lines beginning with whitespace are appended to the previous line
|
||||
# with any leading whitespace replaced by a single space.
|
||||
# Note that while multi-line headers are a part of the HTTP spec,
|
||||
# their use is strongly discouraged.
|
||||
data = """\
|
||||
Foo: bar
|
||||
baz
|
||||
Asdf: qwer
|
||||
\tzxcv
|
||||
Foo: even
|
||||
more
|
||||
lines
|
||||
""".replace("\n", "\r\n")
|
||||
headers = HTTPHeaders.parse(data)
|
||||
self.assertEqual(headers["asdf"], "qwer zxcv")
|
||||
self.assertEqual(headers.get_list("asdf"), ["qwer zxcv"])
|
||||
self.assertEqual(headers["Foo"], "bar baz,even more lines")
|
||||
self.assertEqual(headers.get_list("foo"), ["bar baz", "even more lines"])
|
||||
self.assertEqual(sorted(list(headers.get_all())),
|
||||
[("Asdf", "qwer zxcv"),
|
||||
("Foo", "bar baz"),
|
||||
("Foo", "even more lines")])
|
||||
|
||||
def test_malformed_continuation(self):
|
||||
# If the first line starts with whitespace, it's a
|
||||
# continuation line with nothing to continue, so reject it
|
||||
# (with a proper error).
|
||||
data = " Foo: bar"
|
||||
self.assertRaises(HTTPInputError, HTTPHeaders.parse, data)
|
||||
|
||||
def test_unicode_newlines(self):
|
||||
# Ensure that only \r\n is recognized as a header separator, and not
|
||||
# the other newline-like unicode characters.
|
||||
# Characters that are likely to be problematic can be found in
|
||||
# http://unicode.org/standard/reports/tr13/tr13-5.html
|
||||
# and cpython's unicodeobject.c (which defines the implementation
|
||||
# of unicode_type.splitlines(), and uses a different list than TR13).
|
||||
newlines = [
|
||||
u'\u001b', # VERTICAL TAB
|
||||
u'\u001c', # FILE SEPARATOR
|
||||
u'\u001d', # GROUP SEPARATOR
|
||||
u'\u001e', # RECORD SEPARATOR
|
||||
u'\u0085', # NEXT LINE
|
||||
u'\u2028', # LINE SEPARATOR
|
||||
u'\u2029', # PARAGRAPH SEPARATOR
|
||||
]
|
||||
for newline in newlines:
|
||||
# Try the utf8 and latin1 representations of each newline
|
||||
for encoding in ['utf8', 'latin1']:
|
||||
try:
|
||||
try:
|
||||
encoded = newline.encode(encoding)
|
||||
except UnicodeEncodeError:
|
||||
# Some chars cannot be represented in latin1
|
||||
continue
|
||||
data = b'Cookie: foo=' + encoded + b'bar'
|
||||
# parse() wants a native_str, so decode through latin1
|
||||
# in the same way the real parser does.
|
||||
headers = HTTPHeaders.parse(
|
||||
native_str(data.decode('latin1')))
|
||||
expected = [('Cookie', 'foo=' +
|
||||
native_str(encoded.decode('latin1')) + 'bar')]
|
||||
self.assertEqual(
|
||||
expected, list(headers.get_all()))
|
||||
except Exception:
|
||||
gen_log.warning("failed while trying %r in %s",
|
||||
newline, encoding)
|
||||
raise
|
||||
|
||||
def test_optional_cr(self):
|
||||
# Both CRLF and LF should be accepted as separators. CR should not be
|
||||
# part of the data when followed by LF, but it is a normal char
|
||||
# otherwise (or should bare CR be an error?)
|
||||
headers = HTTPHeaders.parse(
|
||||
'CRLF: crlf\r\nLF: lf\nCR: cr\rMore: more\r\n')
|
||||
self.assertEqual(sorted(headers.get_all()),
|
||||
[('Cr', 'cr\rMore: more'),
|
||||
('Crlf', 'crlf'),
|
||||
('Lf', 'lf'),
|
||||
])
|
||||
|
||||
def test_copy(self):
|
||||
all_pairs = [('A', '1'), ('A', '2'), ('B', 'c')]
|
||||
h1 = HTTPHeaders()
|
||||
for k, v in all_pairs:
|
||||
h1.add(k, v)
|
||||
h2 = h1.copy()
|
||||
h3 = copy.copy(h1)
|
||||
h4 = copy.deepcopy(h1)
|
||||
for headers in [h1, h2, h3, h4]:
|
||||
# All the copies are identical, no matter how they were
|
||||
# constructed.
|
||||
self.assertEqual(list(sorted(headers.get_all())), all_pairs)
|
||||
for headers in [h2, h3, h4]:
|
||||
# Neither the dict or its member lists are reused.
|
||||
self.assertIsNot(headers, h1)
|
||||
self.assertIsNot(headers.get_list('A'), h1.get_list('A'))
|
||||
|
||||
def test_pickle_roundtrip(self):
|
||||
headers = HTTPHeaders()
|
||||
headers.add('Set-Cookie', 'a=b')
|
||||
headers.add('Set-Cookie', 'c=d')
|
||||
headers.add('Content-Type', 'text/html')
|
||||
pickled = pickle.dumps(headers)
|
||||
unpickled = pickle.loads(pickled)
|
||||
self.assertEqual(sorted(headers.get_all()), sorted(unpickled.get_all()))
|
||||
self.assertEqual(sorted(headers.items()), sorted(unpickled.items()))
|
||||
|
||||
def test_setdefault(self):
|
||||
headers = HTTPHeaders()
|
||||
headers['foo'] = 'bar'
|
||||
# If a value is present, setdefault returns it without changes.
|
||||
self.assertEqual(headers.setdefault('foo', 'baz'), 'bar')
|
||||
self.assertEqual(headers['foo'], 'bar')
|
||||
# If a value is not present, setdefault sets it for future use.
|
||||
self.assertEqual(headers.setdefault('quux', 'xyzzy'), 'xyzzy')
|
||||
self.assertEqual(headers['quux'], 'xyzzy')
|
||||
self.assertEqual(sorted(headers.get_all()), [('Foo', 'bar'), ('Quux', 'xyzzy')])
|
||||
|
||||
def test_string(self):
|
||||
headers = HTTPHeaders()
|
||||
headers.add("Foo", "1")
|
||||
headers.add("Foo", "2")
|
||||
headers.add("Foo", "3")
|
||||
headers2 = HTTPHeaders.parse(str(headers))
|
||||
self.assertEquals(headers, headers2)
|
||||
|
||||
|
||||
class FormatTimestampTest(unittest.TestCase):
|
||||
# Make sure that all the input types are supported.
|
||||
TIMESTAMP = 1359312200.503611
|
||||
EXPECTED = 'Sun, 27 Jan 2013 18:43:20 GMT'
|
||||
|
||||
def check(self, value):
|
||||
self.assertEqual(format_timestamp(value), self.EXPECTED)
|
||||
|
||||
def test_unix_time_float(self):
|
||||
self.check(self.TIMESTAMP)
|
||||
|
||||
def test_unix_time_int(self):
|
||||
self.check(int(self.TIMESTAMP))
|
||||
|
||||
def test_struct_time(self):
|
||||
self.check(time.gmtime(self.TIMESTAMP))
|
||||
|
||||
def test_time_tuple(self):
|
||||
tup = tuple(time.gmtime(self.TIMESTAMP))
|
||||
self.assertEqual(9, len(tup))
|
||||
self.check(tup)
|
||||
|
||||
def test_datetime(self):
|
||||
self.check(datetime.datetime.utcfromtimestamp(self.TIMESTAMP))
|
||||
|
||||
|
||||
# HTTPServerRequest is mainly tested incidentally to the server itself,
|
||||
# but this tests the parts of the class that can be tested in isolation.
|
||||
class HTTPServerRequestTest(unittest.TestCase):
|
||||
def test_default_constructor(self):
|
||||
# All parameters are formally optional, but uri is required
|
||||
# (and has been for some time). This test ensures that no
|
||||
# more required parameters slip in.
|
||||
HTTPServerRequest(uri='/')
|
||||
|
||||
def test_body_is_a_byte_string(self):
|
||||
requets = HTTPServerRequest(uri='/')
|
||||
self.assertIsInstance(requets.body, bytes)
|
||||
|
||||
def test_repr_does_not_contain_headers(self):
|
||||
request = HTTPServerRequest(uri='/', headers={'Canary': 'Coal Mine'})
|
||||
self.assertTrue('Canary' not in repr(request))
|
||||
|
||||
|
||||
class ParseRequestStartLineTest(unittest.TestCase):
|
||||
METHOD = "GET"
|
||||
PATH = "/foo"
|
||||
VERSION = "HTTP/1.1"
|
||||
|
||||
def test_parse_request_start_line(self):
|
||||
start_line = " ".join([self.METHOD, self.PATH, self.VERSION])
|
||||
parsed_start_line = parse_request_start_line(start_line)
|
||||
self.assertEqual(parsed_start_line.method, self.METHOD)
|
||||
self.assertEqual(parsed_start_line.path, self.PATH)
|
||||
self.assertEqual(parsed_start_line.version, self.VERSION)
|
||||
|
||||
|
||||
class ParseCookieTest(unittest.TestCase):
|
||||
# These tests copied from Django:
|
||||
# https://github.com/django/django/pull/6277/commits/da810901ada1cae9fc1f018f879f11a7fb467b28
|
||||
def test_python_cookies(self):
|
||||
"""
|
||||
Test cases copied from Python's Lib/test/test_http_cookies.py
|
||||
"""
|
||||
self.assertEqual(parse_cookie('chips=ahoy; vienna=finger'),
|
||||
{'chips': 'ahoy', 'vienna': 'finger'})
|
||||
# Here parse_cookie() differs from Python's cookie parsing in that it
|
||||
# treats all semicolons as delimiters, even within quotes.
|
||||
self.assertEqual(
|
||||
parse_cookie('keebler="E=mc2; L=\\"Loves\\"; fudge=\\012;"'),
|
||||
{'keebler': '"E=mc2', 'L': '\\"Loves\\"', 'fudge': '\\012', '': '"'}
|
||||
)
|
||||
# Illegal cookies that have an '=' char in an unquoted value.
|
||||
self.assertEqual(parse_cookie('keebler=E=mc2'), {'keebler': 'E=mc2'})
|
||||
# Cookies with ':' character in their name.
|
||||
self.assertEqual(parse_cookie('key:term=value:term'), {'key:term': 'value:term'})
|
||||
# Cookies with '[' and ']'.
|
||||
self.assertEqual(parse_cookie('a=b; c=[; d=r; f=h'),
|
||||
{'a': 'b', 'c': '[', 'd': 'r', 'f': 'h'})
|
||||
|
||||
def test_cookie_edgecases(self):
|
||||
# Cookies that RFC6265 allows.
|
||||
self.assertEqual(parse_cookie('a=b; Domain=example.com'),
|
||||
{'a': 'b', 'Domain': 'example.com'})
|
||||
# parse_cookie() has historically kept only the last cookie with the
|
||||
# same name.
|
||||
self.assertEqual(parse_cookie('a=b; h=i; a=c'), {'a': 'c', 'h': 'i'})
|
||||
|
||||
def test_invalid_cookies(self):
|
||||
"""
|
||||
Cookie strings that go against RFC6265 but browsers will send if set
|
||||
via document.cookie.
|
||||
"""
|
||||
# Chunks without an equals sign appear as unnamed values per
|
||||
# https://bugzilla.mozilla.org/show_bug.cgi?id=169091
|
||||
self.assertIn('django_language',
|
||||
parse_cookie('abc=def; unnamed; django_language=en').keys())
|
||||
# Even a double quote may be an unamed value.
|
||||
self.assertEqual(parse_cookie('a=b; "; c=d'), {'a': 'b', '': '"', 'c': 'd'})
|
||||
# Spaces in names and values, and an equals sign in values.
|
||||
self.assertEqual(parse_cookie('a b c=d e = f; gh=i'), {'a b c': 'd e = f', 'gh': 'i'})
|
||||
# More characters the spec forbids.
|
||||
self.assertEqual(parse_cookie('a b,c<>@:/[]?{}=d " =e,f g'),
|
||||
{'a b,c<>@:/[]?{}': 'd " =e,f g'})
|
||||
# Unicode characters. The spec only allows ASCII.
|
||||
self.assertEqual(parse_cookie('saint=André Bessette'),
|
||||
{'saint': native_str('André Bessette')})
|
||||
# Browsers don't send extra whitespace or semicolons in Cookie headers,
|
||||
# but parse_cookie() should parse whitespace the same way
|
||||
# document.cookie parses whitespace.
|
||||
self.assertEqual(parse_cookie(' = b ; ; = ; c = ; '), {'': 'b', 'c': ''})
|
||||
Executable
+73
@@ -0,0 +1,73 @@
|
||||
# flake8: noqa
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
from tornado.test.util import unittest
|
||||
|
||||
_import_everything = b"""
|
||||
# The event loop is not fork-safe, and it's easy to initialize an asyncio.Future
|
||||
# at startup, which in turn creates the default event loop and prevents forking.
|
||||
# Explicitly disallow the default event loop so that an error will be raised
|
||||
# if something tries to touch it.
|
||||
try:
|
||||
import asyncio
|
||||
except ImportError:
|
||||
pass
|
||||
else:
|
||||
asyncio.set_event_loop(None)
|
||||
|
||||
import tornado.auth
|
||||
import tornado.autoreload
|
||||
import tornado.concurrent
|
||||
import tornado.escape
|
||||
import tornado.gen
|
||||
import tornado.http1connection
|
||||
import tornado.httpclient
|
||||
import tornado.httpserver
|
||||
import tornado.httputil
|
||||
import tornado.ioloop
|
||||
import tornado.iostream
|
||||
import tornado.locale
|
||||
import tornado.log
|
||||
import tornado.netutil
|
||||
import tornado.options
|
||||
import tornado.process
|
||||
import tornado.simple_httpclient
|
||||
import tornado.stack_context
|
||||
import tornado.tcpserver
|
||||
import tornado.tcpclient
|
||||
import tornado.template
|
||||
import tornado.testing
|
||||
import tornado.util
|
||||
import tornado.web
|
||||
import tornado.websocket
|
||||
import tornado.wsgi
|
||||
|
||||
try:
|
||||
import pycurl
|
||||
except ImportError:
|
||||
pass
|
||||
else:
|
||||
import tornado.curl_httpclient
|
||||
"""
|
||||
|
||||
|
||||
class ImportTest(unittest.TestCase):
|
||||
def test_import_everything(self):
|
||||
# Test that all Tornado modules can be imported without side effects,
|
||||
# specifically without initializing the default asyncio event loop.
|
||||
# Since we can't tell which modules may have already beein imported
|
||||
# in our process, do it in a subprocess for a clean slate.
|
||||
proc = subprocess.Popen([sys.executable], stdin=subprocess.PIPE)
|
||||
proc.communicate(_import_everything)
|
||||
self.assertEqual(proc.returncode, 0)
|
||||
|
||||
def test_import_aliases(self):
|
||||
# Ensure we don't delete formerly-documented aliases accidentally.
|
||||
import tornado.ioloop
|
||||
import tornado.gen
|
||||
import tornado.util
|
||||
self.assertIs(tornado.ioloop.TimeoutError, tornado.util.TimeoutError)
|
||||
self.assertIs(tornado.gen.TimeoutError, tornado.util.TimeoutError)
|
||||
Executable
+942
@@ -0,0 +1,942 @@
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import contextlib
|
||||
import datetime
|
||||
import functools
|
||||
import socket
|
||||
import subprocess
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import types
|
||||
try:
|
||||
from unittest import mock # type: ignore
|
||||
except ImportError:
|
||||
try:
|
||||
import mock # type: ignore
|
||||
except ImportError:
|
||||
mock = None
|
||||
|
||||
from tornado.escape import native_str
|
||||
from tornado import gen
|
||||
from tornado.ioloop import IOLoop, TimeoutError, PollIOLoop, PeriodicCallback
|
||||
from tornado.log import app_log
|
||||
from tornado.platform.select import _Select
|
||||
from tornado.stack_context import ExceptionStackContext, StackContext, wrap, NullContext
|
||||
from tornado.testing import AsyncTestCase, bind_unused_port, ExpectLog, gen_test
|
||||
from tornado.test.util import (unittest, skipIfNonUnix, skipOnTravis,
|
||||
skipBefore35, exec_test, ignore_deprecation)
|
||||
|
||||
try:
|
||||
from concurrent import futures
|
||||
except ImportError:
|
||||
futures = None
|
||||
|
||||
try:
|
||||
import asyncio
|
||||
except ImportError:
|
||||
asyncio = None
|
||||
|
||||
try:
|
||||
import twisted
|
||||
except ImportError:
|
||||
twisted = None
|
||||
|
||||
|
||||
class FakeTimeSelect(_Select):
|
||||
def __init__(self):
|
||||
self._time = 1000
|
||||
super(FakeTimeSelect, self).__init__()
|
||||
|
||||
def time(self):
|
||||
return self._time
|
||||
|
||||
def sleep(self, t):
|
||||
self._time += t
|
||||
|
||||
def poll(self, timeout):
|
||||
events = super(FakeTimeSelect, self).poll(0)
|
||||
if events:
|
||||
return events
|
||||
self._time += timeout
|
||||
return []
|
||||
|
||||
|
||||
class FakeTimeIOLoop(PollIOLoop):
|
||||
"""IOLoop implementation with a fake and deterministic clock.
|
||||
|
||||
The clock advances as needed to trigger timeouts immediately.
|
||||
For use when testing code that involves the passage of time
|
||||
and no external dependencies.
|
||||
"""
|
||||
def initialize(self):
|
||||
self.fts = FakeTimeSelect()
|
||||
super(FakeTimeIOLoop, self).initialize(impl=self.fts,
|
||||
time_func=self.fts.time)
|
||||
|
||||
def sleep(self, t):
|
||||
"""Simulate a blocking sleep by advancing the clock."""
|
||||
self.fts.sleep(t)
|
||||
|
||||
|
||||
class TestIOLoop(AsyncTestCase):
|
||||
def test_add_callback_return_sequence(self):
|
||||
# A callback returning {} or [] shouldn't spin the CPU, see Issue #1803.
|
||||
self.calls = 0
|
||||
|
||||
loop = self.io_loop
|
||||
test = self
|
||||
old_add_callback = loop.add_callback
|
||||
|
||||
def add_callback(self, callback, *args, **kwargs):
|
||||
test.calls += 1
|
||||
old_add_callback(callback, *args, **kwargs)
|
||||
|
||||
loop.add_callback = types.MethodType(add_callback, loop)
|
||||
loop.add_callback(lambda: {})
|
||||
loop.add_callback(lambda: [])
|
||||
loop.add_timeout(datetime.timedelta(milliseconds=50), loop.stop)
|
||||
loop.start()
|
||||
self.assertLess(self.calls, 10)
|
||||
|
||||
@skipOnTravis
|
||||
def test_add_callback_wakeup(self):
|
||||
# Make sure that add_callback from inside a running IOLoop
|
||||
# wakes up the IOLoop immediately instead of waiting for a timeout.
|
||||
def callback():
|
||||
self.called = True
|
||||
self.stop()
|
||||
|
||||
def schedule_callback():
|
||||
self.called = False
|
||||
self.io_loop.add_callback(callback)
|
||||
# Store away the time so we can check if we woke up immediately
|
||||
self.start_time = time.time()
|
||||
self.io_loop.add_timeout(self.io_loop.time(), schedule_callback)
|
||||
self.wait()
|
||||
self.assertAlmostEqual(time.time(), self.start_time, places=2)
|
||||
self.assertTrue(self.called)
|
||||
|
||||
@skipOnTravis
|
||||
def test_add_callback_wakeup_other_thread(self):
|
||||
def target():
|
||||
# sleep a bit to let the ioloop go into its poll loop
|
||||
time.sleep(0.01)
|
||||
self.stop_time = time.time()
|
||||
self.io_loop.add_callback(self.stop)
|
||||
thread = threading.Thread(target=target)
|
||||
self.io_loop.add_callback(thread.start)
|
||||
self.wait()
|
||||
delta = time.time() - self.stop_time
|
||||
self.assertLess(delta, 0.1)
|
||||
thread.join()
|
||||
|
||||
def test_add_timeout_timedelta(self):
|
||||
self.io_loop.add_timeout(datetime.timedelta(microseconds=1), self.stop)
|
||||
self.wait()
|
||||
|
||||
def test_multiple_add(self):
|
||||
sock, port = bind_unused_port()
|
||||
try:
|
||||
self.io_loop.add_handler(sock.fileno(), lambda fd, events: None,
|
||||
IOLoop.READ)
|
||||
# Attempting to add the same handler twice fails
|
||||
# (with a platform-dependent exception)
|
||||
self.assertRaises(Exception, self.io_loop.add_handler,
|
||||
sock.fileno(), lambda fd, events: None,
|
||||
IOLoop.READ)
|
||||
finally:
|
||||
self.io_loop.remove_handler(sock.fileno())
|
||||
sock.close()
|
||||
|
||||
def test_remove_without_add(self):
|
||||
# remove_handler should not throw an exception if called on an fd
|
||||
# was never added.
|
||||
sock, port = bind_unused_port()
|
||||
try:
|
||||
self.io_loop.remove_handler(sock.fileno())
|
||||
finally:
|
||||
sock.close()
|
||||
|
||||
def test_add_callback_from_signal(self):
|
||||
# cheat a little bit and just run this normally, since we can't
|
||||
# easily simulate the races that happen with real signal handlers
|
||||
self.io_loop.add_callback_from_signal(self.stop)
|
||||
self.wait()
|
||||
|
||||
def test_add_callback_from_signal_other_thread(self):
|
||||
# Very crude test, just to make sure that we cover this case.
|
||||
# This also happens to be the first test where we run an IOLoop in
|
||||
# a non-main thread.
|
||||
other_ioloop = IOLoop()
|
||||
thread = threading.Thread(target=other_ioloop.start)
|
||||
thread.start()
|
||||
other_ioloop.add_callback_from_signal(other_ioloop.stop)
|
||||
thread.join()
|
||||
other_ioloop.close()
|
||||
|
||||
def test_add_callback_while_closing(self):
|
||||
# add_callback should not fail if it races with another thread
|
||||
# closing the IOLoop. The callbacks are dropped silently
|
||||
# without executing.
|
||||
closing = threading.Event()
|
||||
|
||||
def target():
|
||||
other_ioloop.add_callback(other_ioloop.stop)
|
||||
other_ioloop.start()
|
||||
closing.set()
|
||||
other_ioloop.close(all_fds=True)
|
||||
other_ioloop = IOLoop()
|
||||
thread = threading.Thread(target=target)
|
||||
thread.start()
|
||||
closing.wait()
|
||||
for i in range(1000):
|
||||
other_ioloop.add_callback(lambda: None)
|
||||
|
||||
def test_handle_callback_exception(self):
|
||||
# IOLoop.handle_callback_exception can be overridden to catch
|
||||
# exceptions in callbacks.
|
||||
def handle_callback_exception(callback):
|
||||
self.assertIs(sys.exc_info()[0], ZeroDivisionError)
|
||||
self.stop()
|
||||
self.io_loop.handle_callback_exception = handle_callback_exception
|
||||
with NullContext():
|
||||
# remove the test StackContext that would see this uncaught
|
||||
# exception as a test failure.
|
||||
self.io_loop.add_callback(lambda: 1 / 0)
|
||||
self.wait()
|
||||
|
||||
@skipIfNonUnix # just because socketpair is so convenient
|
||||
def test_read_while_writeable(self):
|
||||
# Ensure that write events don't come in while we're waiting for
|
||||
# a read and haven't asked for writeability. (the reverse is
|
||||
# difficult to test for)
|
||||
client, server = socket.socketpair()
|
||||
try:
|
||||
def handler(fd, events):
|
||||
self.assertEqual(events, IOLoop.READ)
|
||||
self.stop()
|
||||
self.io_loop.add_handler(client.fileno(), handler, IOLoop.READ)
|
||||
self.io_loop.add_timeout(self.io_loop.time() + 0.01,
|
||||
functools.partial(server.send, b'asdf'))
|
||||
self.wait()
|
||||
self.io_loop.remove_handler(client.fileno())
|
||||
finally:
|
||||
client.close()
|
||||
server.close()
|
||||
|
||||
def test_remove_timeout_after_fire(self):
|
||||
# It is not an error to call remove_timeout after it has run.
|
||||
handle = self.io_loop.add_timeout(self.io_loop.time(), self.stop)
|
||||
self.wait()
|
||||
self.io_loop.remove_timeout(handle)
|
||||
|
||||
def test_remove_timeout_cleanup(self):
|
||||
# Add and remove enough callbacks to trigger cleanup.
|
||||
# Not a very thorough test, but it ensures that the cleanup code
|
||||
# gets executed and doesn't blow up. This test is only really useful
|
||||
# on PollIOLoop subclasses, but it should run silently on any
|
||||
# implementation.
|
||||
for i in range(2000):
|
||||
timeout = self.io_loop.add_timeout(self.io_loop.time() + 3600,
|
||||
lambda: None)
|
||||
self.io_loop.remove_timeout(timeout)
|
||||
# HACK: wait two IOLoop iterations for the GC to happen.
|
||||
self.io_loop.add_callback(lambda: self.io_loop.add_callback(self.stop))
|
||||
self.wait()
|
||||
|
||||
def test_remove_timeout_from_timeout(self):
|
||||
calls = [False, False]
|
||||
|
||||
# Schedule several callbacks and wait for them all to come due at once.
|
||||
# t2 should be cancelled by t1, even though it is already scheduled to
|
||||
# be run before the ioloop even looks at it.
|
||||
now = self.io_loop.time()
|
||||
|
||||
def t1():
|
||||
calls[0] = True
|
||||
self.io_loop.remove_timeout(t2_handle)
|
||||
self.io_loop.add_timeout(now + 0.01, t1)
|
||||
|
||||
def t2():
|
||||
calls[1] = True
|
||||
t2_handle = self.io_loop.add_timeout(now + 0.02, t2)
|
||||
self.io_loop.add_timeout(now + 0.03, self.stop)
|
||||
time.sleep(0.03)
|
||||
self.wait()
|
||||
self.assertEqual(calls, [True, False])
|
||||
|
||||
def test_timeout_with_arguments(self):
|
||||
# This tests that all the timeout methods pass through *args correctly.
|
||||
results = []
|
||||
self.io_loop.add_timeout(self.io_loop.time(), results.append, 1)
|
||||
self.io_loop.add_timeout(datetime.timedelta(seconds=0),
|
||||
results.append, 2)
|
||||
self.io_loop.call_at(self.io_loop.time(), results.append, 3)
|
||||
self.io_loop.call_later(0, results.append, 4)
|
||||
self.io_loop.call_later(0, self.stop)
|
||||
self.wait()
|
||||
# The asyncio event loop does not guarantee the order of these
|
||||
# callbacks, but PollIOLoop does.
|
||||
self.assertEqual(sorted(results), [1, 2, 3, 4])
|
||||
|
||||
def test_add_timeout_return(self):
|
||||
# All the timeout methods return non-None handles that can be
|
||||
# passed to remove_timeout.
|
||||
handle = self.io_loop.add_timeout(self.io_loop.time(), lambda: None)
|
||||
self.assertFalse(handle is None)
|
||||
self.io_loop.remove_timeout(handle)
|
||||
|
||||
def test_call_at_return(self):
|
||||
handle = self.io_loop.call_at(self.io_loop.time(), lambda: None)
|
||||
self.assertFalse(handle is None)
|
||||
self.io_loop.remove_timeout(handle)
|
||||
|
||||
def test_call_later_return(self):
|
||||
handle = self.io_loop.call_later(0, lambda: None)
|
||||
self.assertFalse(handle is None)
|
||||
self.io_loop.remove_timeout(handle)
|
||||
|
||||
def test_close_file_object(self):
|
||||
"""When a file object is used instead of a numeric file descriptor,
|
||||
the object should be closed (by IOLoop.close(all_fds=True),
|
||||
not just the fd.
|
||||
"""
|
||||
# Use a socket since they are supported by IOLoop on all platforms.
|
||||
# Unfortunately, sockets don't support the .closed attribute for
|
||||
# inspecting their close status, so we must use a wrapper.
|
||||
class SocketWrapper(object):
|
||||
def __init__(self, sockobj):
|
||||
self.sockobj = sockobj
|
||||
self.closed = False
|
||||
|
||||
def fileno(self):
|
||||
return self.sockobj.fileno()
|
||||
|
||||
def close(self):
|
||||
self.closed = True
|
||||
self.sockobj.close()
|
||||
sockobj, port = bind_unused_port()
|
||||
socket_wrapper = SocketWrapper(sockobj)
|
||||
io_loop = IOLoop()
|
||||
io_loop.add_handler(socket_wrapper, lambda fd, events: None,
|
||||
IOLoop.READ)
|
||||
io_loop.close(all_fds=True)
|
||||
self.assertTrue(socket_wrapper.closed)
|
||||
|
||||
def test_handler_callback_file_object(self):
|
||||
"""The handler callback receives the same fd object it passed in."""
|
||||
server_sock, port = bind_unused_port()
|
||||
fds = []
|
||||
|
||||
def handle_connection(fd, events):
|
||||
fds.append(fd)
|
||||
conn, addr = server_sock.accept()
|
||||
conn.close()
|
||||
self.stop()
|
||||
self.io_loop.add_handler(server_sock, handle_connection, IOLoop.READ)
|
||||
with contextlib.closing(socket.socket()) as client_sock:
|
||||
client_sock.connect(('127.0.0.1', port))
|
||||
self.wait()
|
||||
self.io_loop.remove_handler(server_sock)
|
||||
self.io_loop.add_handler(server_sock.fileno(), handle_connection,
|
||||
IOLoop.READ)
|
||||
with contextlib.closing(socket.socket()) as client_sock:
|
||||
client_sock.connect(('127.0.0.1', port))
|
||||
self.wait()
|
||||
self.assertIs(fds[0], server_sock)
|
||||
self.assertEqual(fds[1], server_sock.fileno())
|
||||
self.io_loop.remove_handler(server_sock.fileno())
|
||||
server_sock.close()
|
||||
|
||||
def test_mixed_fd_fileobj(self):
|
||||
server_sock, port = bind_unused_port()
|
||||
|
||||
def f(fd, events):
|
||||
pass
|
||||
self.io_loop.add_handler(server_sock, f, IOLoop.READ)
|
||||
with self.assertRaises(Exception):
|
||||
# The exact error is unspecified - some implementations use
|
||||
# IOError, others use ValueError.
|
||||
self.io_loop.add_handler(server_sock.fileno(), f, IOLoop.READ)
|
||||
self.io_loop.remove_handler(server_sock.fileno())
|
||||
server_sock.close()
|
||||
|
||||
def test_reentrant(self):
|
||||
"""Calling start() twice should raise an error, not deadlock."""
|
||||
returned_from_start = [False]
|
||||
got_exception = [False]
|
||||
|
||||
def callback():
|
||||
try:
|
||||
self.io_loop.start()
|
||||
returned_from_start[0] = True
|
||||
except Exception:
|
||||
got_exception[0] = True
|
||||
self.stop()
|
||||
self.io_loop.add_callback(callback)
|
||||
self.wait()
|
||||
self.assertTrue(got_exception[0])
|
||||
self.assertFalse(returned_from_start[0])
|
||||
|
||||
def test_exception_logging(self):
|
||||
"""Uncaught exceptions get logged by the IOLoop."""
|
||||
# Use a NullContext to keep the exception from being caught by
|
||||
# AsyncTestCase.
|
||||
with NullContext():
|
||||
self.io_loop.add_callback(lambda: 1 / 0)
|
||||
self.io_loop.add_callback(self.stop)
|
||||
with ExpectLog(app_log, "Exception in callback"):
|
||||
self.wait()
|
||||
|
||||
def test_exception_logging_future(self):
|
||||
"""The IOLoop examines exceptions from Futures and logs them."""
|
||||
with NullContext():
|
||||
@gen.coroutine
|
||||
def callback():
|
||||
self.io_loop.add_callback(self.stop)
|
||||
1 / 0
|
||||
self.io_loop.add_callback(callback)
|
||||
with ExpectLog(app_log, "Exception in callback"):
|
||||
self.wait()
|
||||
|
||||
@skipBefore35
|
||||
def test_exception_logging_native_coro(self):
|
||||
"""The IOLoop examines exceptions from awaitables and logs them."""
|
||||
namespace = exec_test(globals(), locals(), """
|
||||
async def callback():
|
||||
# Stop the IOLoop two iterations after raising an exception
|
||||
# to give the exception time to be logged.
|
||||
self.io_loop.add_callback(self.io_loop.add_callback, self.stop)
|
||||
1 / 0
|
||||
""")
|
||||
with NullContext():
|
||||
self.io_loop.add_callback(namespace["callback"])
|
||||
with ExpectLog(app_log, "Exception in callback"):
|
||||
self.wait()
|
||||
|
||||
def test_spawn_callback(self):
|
||||
with ignore_deprecation():
|
||||
# An added callback runs in the test's stack_context, so will be
|
||||
# re-raised in wait().
|
||||
self.io_loop.add_callback(lambda: 1 / 0)
|
||||
with self.assertRaises(ZeroDivisionError):
|
||||
self.wait()
|
||||
# A spawned callback is run directly on the IOLoop, so it will be
|
||||
# logged without stopping the test.
|
||||
self.io_loop.spawn_callback(lambda: 1 / 0)
|
||||
self.io_loop.add_callback(self.stop)
|
||||
with ExpectLog(app_log, "Exception in callback"):
|
||||
self.wait()
|
||||
|
||||
@skipIfNonUnix
|
||||
def test_remove_handler_from_handler(self):
|
||||
# Create two sockets with simultaneous read events.
|
||||
client, server = socket.socketpair()
|
||||
try:
|
||||
client.send(b'abc')
|
||||
server.send(b'abc')
|
||||
|
||||
# After reading from one fd, remove the other from the IOLoop.
|
||||
chunks = []
|
||||
|
||||
def handle_read(fd, events):
|
||||
chunks.append(fd.recv(1024))
|
||||
if fd is client:
|
||||
self.io_loop.remove_handler(server)
|
||||
else:
|
||||
self.io_loop.remove_handler(client)
|
||||
self.io_loop.add_handler(client, handle_read, self.io_loop.READ)
|
||||
self.io_loop.add_handler(server, handle_read, self.io_loop.READ)
|
||||
self.io_loop.call_later(0.1, self.stop)
|
||||
self.wait()
|
||||
|
||||
# Only one fd was read; the other was cleanly removed.
|
||||
self.assertEqual(chunks, [b'abc'])
|
||||
finally:
|
||||
client.close()
|
||||
server.close()
|
||||
|
||||
@gen_test
|
||||
def test_init_close_race(self):
|
||||
# Regression test for #2367
|
||||
def f():
|
||||
for i in range(10):
|
||||
loop = IOLoop()
|
||||
loop.close()
|
||||
|
||||
yield gen.multi([self.io_loop.run_in_executor(None, f) for i in range(2)])
|
||||
|
||||
|
||||
# Deliberately not a subclass of AsyncTestCase so the IOLoop isn't
|
||||
# automatically set as current.
|
||||
class TestIOLoopCurrent(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.io_loop = None
|
||||
IOLoop.clear_current()
|
||||
|
||||
def tearDown(self):
|
||||
if self.io_loop is not None:
|
||||
self.io_loop.close()
|
||||
|
||||
def test_default_current(self):
|
||||
self.io_loop = IOLoop()
|
||||
# The first IOLoop with default arguments is made current.
|
||||
self.assertIs(self.io_loop, IOLoop.current())
|
||||
# A second IOLoop can be created but is not made current.
|
||||
io_loop2 = IOLoop()
|
||||
self.assertIs(self.io_loop, IOLoop.current())
|
||||
io_loop2.close()
|
||||
|
||||
def test_non_current(self):
|
||||
self.io_loop = IOLoop(make_current=False)
|
||||
# The new IOLoop is not initially made current.
|
||||
self.assertIsNone(IOLoop.current(instance=False))
|
||||
# Starting the IOLoop makes it current, and stopping the loop
|
||||
# makes it non-current. This process is repeatable.
|
||||
for i in range(3):
|
||||
def f():
|
||||
self.current_io_loop = IOLoop.current()
|
||||
self.io_loop.stop()
|
||||
self.io_loop.add_callback(f)
|
||||
self.io_loop.start()
|
||||
self.assertIs(self.current_io_loop, self.io_loop)
|
||||
# Now that the loop is stopped, it is no longer current.
|
||||
self.assertIsNone(IOLoop.current(instance=False))
|
||||
|
||||
def test_force_current(self):
|
||||
self.io_loop = IOLoop(make_current=True)
|
||||
self.assertIs(self.io_loop, IOLoop.current())
|
||||
with self.assertRaises(RuntimeError):
|
||||
# A second make_current=True construction cannot succeed.
|
||||
IOLoop(make_current=True)
|
||||
# current() was not affected by the failed construction.
|
||||
self.assertIs(self.io_loop, IOLoop.current())
|
||||
|
||||
|
||||
class TestIOLoopCurrentAsync(AsyncTestCase):
|
||||
@gen_test
|
||||
def test_clear_without_current(self):
|
||||
# If there is no current IOLoop, clear_current is a no-op (but
|
||||
# should not fail). Use a thread so we see the threading.Local
|
||||
# in a pristine state.
|
||||
with ThreadPoolExecutor(1) as e:
|
||||
yield e.submit(IOLoop.clear_current)
|
||||
|
||||
|
||||
class TestIOLoopAddCallback(AsyncTestCase):
|
||||
def setUp(self):
|
||||
super(TestIOLoopAddCallback, self).setUp()
|
||||
self.active_contexts = []
|
||||
|
||||
def add_callback(self, callback, *args, **kwargs):
|
||||
self.io_loop.add_callback(callback, *args, **kwargs)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def context(self, name):
|
||||
self.active_contexts.append(name)
|
||||
yield
|
||||
self.assertEqual(self.active_contexts.pop(), name)
|
||||
|
||||
def test_pre_wrap(self):
|
||||
# A pre-wrapped callback is run in the context in which it was
|
||||
# wrapped, not when it was added to the IOLoop.
|
||||
def f1():
|
||||
self.assertIn('c1', self.active_contexts)
|
||||
self.assertNotIn('c2', self.active_contexts)
|
||||
self.stop()
|
||||
|
||||
with ignore_deprecation():
|
||||
with StackContext(functools.partial(self.context, 'c1')):
|
||||
wrapped = wrap(f1)
|
||||
|
||||
with StackContext(functools.partial(self.context, 'c2')):
|
||||
self.add_callback(wrapped)
|
||||
|
||||
self.wait()
|
||||
|
||||
def test_pre_wrap_with_args(self):
|
||||
# Same as test_pre_wrap, but the function takes arguments.
|
||||
# Implementation note: The function must not be wrapped in a
|
||||
# functools.partial until after it has been passed through
|
||||
# stack_context.wrap
|
||||
def f1(foo, bar):
|
||||
self.assertIn('c1', self.active_contexts)
|
||||
self.assertNotIn('c2', self.active_contexts)
|
||||
self.stop((foo, bar))
|
||||
|
||||
with ignore_deprecation():
|
||||
with StackContext(functools.partial(self.context, 'c1')):
|
||||
wrapped = wrap(f1)
|
||||
|
||||
with StackContext(functools.partial(self.context, 'c2')):
|
||||
self.add_callback(wrapped, 1, bar=2)
|
||||
|
||||
result = self.wait()
|
||||
self.assertEqual(result, (1, 2))
|
||||
|
||||
|
||||
class TestIOLoopAddCallbackFromSignal(TestIOLoopAddCallback):
|
||||
# Repeat the add_callback tests using add_callback_from_signal
|
||||
def add_callback(self, callback, *args, **kwargs):
|
||||
self.io_loop.add_callback_from_signal(callback, *args, **kwargs)
|
||||
|
||||
|
||||
@unittest.skipIf(futures is None, "futures module not present")
|
||||
class TestIOLoopFutures(AsyncTestCase):
|
||||
def test_add_future_threads(self):
|
||||
with futures.ThreadPoolExecutor(1) as pool:
|
||||
self.io_loop.add_future(pool.submit(lambda: None),
|
||||
lambda future: self.stop(future))
|
||||
future = self.wait()
|
||||
self.assertTrue(future.done())
|
||||
self.assertTrue(future.result() is None)
|
||||
|
||||
def test_add_future_stack_context(self):
|
||||
ready = threading.Event()
|
||||
|
||||
def task():
|
||||
# we must wait for the ioloop callback to be scheduled before
|
||||
# the task completes to ensure that add_future adds the callback
|
||||
# asynchronously (which is the scenario in which capturing
|
||||
# the stack_context matters)
|
||||
ready.wait(1)
|
||||
assert ready.isSet(), "timed out"
|
||||
raise Exception("worker")
|
||||
|
||||
def callback(future):
|
||||
self.future = future
|
||||
raise Exception("callback")
|
||||
|
||||
def handle_exception(typ, value, traceback):
|
||||
self.exception = value
|
||||
self.stop()
|
||||
return True
|
||||
|
||||
# stack_context propagates to the ioloop callback, but the worker
|
||||
# task just has its exceptions caught and saved in the Future.
|
||||
with ignore_deprecation():
|
||||
with futures.ThreadPoolExecutor(1) as pool:
|
||||
with ExceptionStackContext(handle_exception):
|
||||
self.io_loop.add_future(pool.submit(task), callback)
|
||||
ready.set()
|
||||
self.wait()
|
||||
|
||||
self.assertEqual(self.exception.args[0], "callback")
|
||||
self.assertEqual(self.future.exception().args[0], "worker")
|
||||
|
||||
@gen_test
|
||||
def test_run_in_executor_gen(self):
|
||||
event1 = threading.Event()
|
||||
event2 = threading.Event()
|
||||
|
||||
def sync_func(self_event, other_event):
|
||||
self_event.set()
|
||||
other_event.wait()
|
||||
# Note that return value doesn't actually do anything,
|
||||
# it is just passed through to our final assertion to
|
||||
# make sure it is passed through properly.
|
||||
return self_event
|
||||
|
||||
# Run two synchronous functions, which would deadlock if not
|
||||
# run in parallel.
|
||||
res = yield [
|
||||
IOLoop.current().run_in_executor(None, sync_func, event1, event2),
|
||||
IOLoop.current().run_in_executor(None, sync_func, event2, event1)
|
||||
]
|
||||
|
||||
self.assertEqual([event1, event2], res)
|
||||
|
||||
@skipBefore35
|
||||
@gen_test
|
||||
def test_run_in_executor_native(self):
|
||||
event1 = threading.Event()
|
||||
event2 = threading.Event()
|
||||
|
||||
def sync_func(self_event, other_event):
|
||||
self_event.set()
|
||||
other_event.wait()
|
||||
return self_event
|
||||
|
||||
# Go through an async wrapper to ensure that the result of
|
||||
# run_in_executor works with await and not just gen.coroutine
|
||||
# (simply passing the underlying concurrrent future would do that).
|
||||
namespace = exec_test(globals(), locals(), """
|
||||
async def async_wrapper(self_event, other_event):
|
||||
return await IOLoop.current().run_in_executor(
|
||||
None, sync_func, self_event, other_event)
|
||||
""")
|
||||
|
||||
res = yield [
|
||||
namespace["async_wrapper"](event1, event2),
|
||||
namespace["async_wrapper"](event2, event1)
|
||||
]
|
||||
|
||||
self.assertEqual([event1, event2], res)
|
||||
|
||||
@gen_test
|
||||
def test_set_default_executor(self):
|
||||
count = [0]
|
||||
|
||||
class MyExecutor(futures.ThreadPoolExecutor):
|
||||
def submit(self, func, *args):
|
||||
count[0] += 1
|
||||
return super(MyExecutor, self).submit(func, *args)
|
||||
|
||||
event = threading.Event()
|
||||
|
||||
def sync_func():
|
||||
event.set()
|
||||
|
||||
executor = MyExecutor(1)
|
||||
loop = IOLoop.current()
|
||||
loop.set_default_executor(executor)
|
||||
yield loop.run_in_executor(None, sync_func)
|
||||
self.assertEqual(1, count[0])
|
||||
self.assertTrue(event.is_set())
|
||||
|
||||
|
||||
class TestIOLoopRunSync(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.io_loop = IOLoop()
|
||||
|
||||
def tearDown(self):
|
||||
self.io_loop.close()
|
||||
|
||||
def test_sync_result(self):
|
||||
with self.assertRaises(gen.BadYieldError):
|
||||
self.io_loop.run_sync(lambda: 42)
|
||||
|
||||
def test_sync_exception(self):
|
||||
with self.assertRaises(ZeroDivisionError):
|
||||
self.io_loop.run_sync(lambda: 1 / 0)
|
||||
|
||||
def test_async_result(self):
|
||||
@gen.coroutine
|
||||
def f():
|
||||
yield gen.moment
|
||||
raise gen.Return(42)
|
||||
self.assertEqual(self.io_loop.run_sync(f), 42)
|
||||
|
||||
def test_async_exception(self):
|
||||
@gen.coroutine
|
||||
def f():
|
||||
yield gen.moment
|
||||
1 / 0
|
||||
with self.assertRaises(ZeroDivisionError):
|
||||
self.io_loop.run_sync(f)
|
||||
|
||||
def test_current(self):
|
||||
def f():
|
||||
self.assertIs(IOLoop.current(), self.io_loop)
|
||||
self.io_loop.run_sync(f)
|
||||
|
||||
def test_timeout(self):
|
||||
@gen.coroutine
|
||||
def f():
|
||||
yield gen.sleep(1)
|
||||
self.assertRaises(TimeoutError, self.io_loop.run_sync, f, timeout=0.01)
|
||||
|
||||
@skipBefore35
|
||||
def test_native_coroutine(self):
|
||||
@gen.coroutine
|
||||
def f1():
|
||||
yield gen.moment
|
||||
|
||||
namespace = exec_test(globals(), locals(), """
|
||||
async def f2():
|
||||
await f1()
|
||||
""")
|
||||
self.io_loop.run_sync(namespace['f2'])
|
||||
|
||||
|
||||
@unittest.skipIf(asyncio is not None,
|
||||
'IOLoop configuration not available')
|
||||
class TestPeriodicCallback(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.io_loop = FakeTimeIOLoop()
|
||||
self.io_loop.make_current()
|
||||
|
||||
def tearDown(self):
|
||||
self.io_loop.close()
|
||||
|
||||
def test_basic(self):
|
||||
calls = []
|
||||
|
||||
def cb():
|
||||
calls.append(self.io_loop.time())
|
||||
pc = PeriodicCallback(cb, 10000)
|
||||
pc.start()
|
||||
self.io_loop.call_later(50, self.io_loop.stop)
|
||||
self.io_loop.start()
|
||||
self.assertEqual(calls, [1010, 1020, 1030, 1040, 1050])
|
||||
|
||||
def test_overrun(self):
|
||||
sleep_durations = [9, 9, 10, 11, 20, 20, 35, 35, 0, 0]
|
||||
expected = [
|
||||
1010, 1020, 1030, # first 3 calls on schedule
|
||||
1050, 1070, # next 2 delayed one cycle
|
||||
1100, 1130, # next 2 delayed 2 cycles
|
||||
1170, 1210, # next 2 delayed 3 cycles
|
||||
1220, 1230, # then back on schedule.
|
||||
]
|
||||
calls = []
|
||||
|
||||
def cb():
|
||||
calls.append(self.io_loop.time())
|
||||
if not sleep_durations:
|
||||
self.io_loop.stop()
|
||||
return
|
||||
self.io_loop.sleep(sleep_durations.pop(0))
|
||||
pc = PeriodicCallback(cb, 10000)
|
||||
pc.start()
|
||||
self.io_loop.start()
|
||||
self.assertEqual(calls, expected)
|
||||
|
||||
def test_io_loop_set_at_start(self):
|
||||
# Check PeriodicCallback uses the current IOLoop at start() time,
|
||||
# not at instantiation time.
|
||||
calls = []
|
||||
io_loop = FakeTimeIOLoop()
|
||||
|
||||
def cb():
|
||||
calls.append(io_loop.time())
|
||||
pc = PeriodicCallback(cb, 10000)
|
||||
io_loop.make_current()
|
||||
pc.start()
|
||||
io_loop.call_later(50, io_loop.stop)
|
||||
io_loop.start()
|
||||
self.assertEqual(calls, [1010, 1020, 1030, 1040, 1050])
|
||||
io_loop.close()
|
||||
|
||||
|
||||
class TestPeriodicCallbackMath(unittest.TestCase):
|
||||
def simulate_calls(self, pc, durations):
|
||||
"""Simulate a series of calls to the PeriodicCallback.
|
||||
|
||||
Pass a list of call durations in seconds (negative values
|
||||
work to simulate clock adjustments during the call, or more or
|
||||
less equivalently, between calls). This method returns the
|
||||
times at which each call would be made.
|
||||
"""
|
||||
calls = []
|
||||
now = 1000
|
||||
pc._next_timeout = now
|
||||
for d in durations:
|
||||
pc._update_next(now)
|
||||
calls.append(pc._next_timeout)
|
||||
now = pc._next_timeout + d
|
||||
return calls
|
||||
|
||||
def test_basic(self):
|
||||
pc = PeriodicCallback(None, 10000)
|
||||
self.assertEqual(self.simulate_calls(pc, [0] * 5),
|
||||
[1010, 1020, 1030, 1040, 1050])
|
||||
|
||||
def test_overrun(self):
|
||||
# If a call runs for too long, we skip entire cycles to get
|
||||
# back on schedule.
|
||||
call_durations = [9, 9, 10, 11, 20, 20, 35, 35, 0, 0, 0]
|
||||
expected = [
|
||||
1010, 1020, 1030, # first 3 calls on schedule
|
||||
1050, 1070, # next 2 delayed one cycle
|
||||
1100, 1130, # next 2 delayed 2 cycles
|
||||
1170, 1210, # next 2 delayed 3 cycles
|
||||
1220, 1230, # then back on schedule.
|
||||
]
|
||||
|
||||
pc = PeriodicCallback(None, 10000)
|
||||
self.assertEqual(self.simulate_calls(pc, call_durations),
|
||||
expected)
|
||||
|
||||
def test_clock_backwards(self):
|
||||
pc = PeriodicCallback(None, 10000)
|
||||
# Backwards jumps are ignored, potentially resulting in a
|
||||
# slightly slow schedule (although we assume that when
|
||||
# time.time() and time.monotonic() are different, time.time()
|
||||
# is getting adjusted by NTP and is therefore more accurate)
|
||||
self.assertEqual(self.simulate_calls(pc, [-2, -1, -3, -2, 0]),
|
||||
[1010, 1020, 1030, 1040, 1050])
|
||||
|
||||
# For big jumps, we should perhaps alter the schedule, but we
|
||||
# don't currently. This trace shows that we run callbacks
|
||||
# every 10s of time.time(), but the first and second calls are
|
||||
# 110s of real time apart because the backwards jump is
|
||||
# ignored.
|
||||
self.assertEqual(self.simulate_calls(pc, [-100, 0, 0]),
|
||||
[1010, 1020, 1030])
|
||||
|
||||
@unittest.skipIf(mock is None, 'mock package not present')
|
||||
def test_jitter(self):
|
||||
random_times = [0.5, 1, 0, 0.75]
|
||||
expected = [1010, 1022.5, 1030, 1041.25]
|
||||
call_durations = [0] * len(random_times)
|
||||
pc = PeriodicCallback(None, 10000, jitter=0.5)
|
||||
|
||||
def mock_random():
|
||||
return random_times.pop(0)
|
||||
with mock.patch('random.random', mock_random):
|
||||
self.assertEqual(self.simulate_calls(pc, call_durations),
|
||||
expected)
|
||||
|
||||
|
||||
class TestIOLoopConfiguration(unittest.TestCase):
|
||||
def run_python(self, *statements):
|
||||
statements = [
|
||||
'from tornado.ioloop import IOLoop, PollIOLoop',
|
||||
'classname = lambda x: x.__class__.__name__',
|
||||
] + list(statements)
|
||||
args = [sys.executable, '-c', '; '.join(statements)]
|
||||
return native_str(subprocess.check_output(args)).strip()
|
||||
|
||||
def test_default(self):
|
||||
if asyncio is not None:
|
||||
# When asyncio is available, it is used by default.
|
||||
cls = self.run_python('print(classname(IOLoop.current()))')
|
||||
self.assertEqual(cls, 'AsyncIOMainLoop')
|
||||
cls = self.run_python('print(classname(IOLoop()))')
|
||||
self.assertEqual(cls, 'AsyncIOLoop')
|
||||
else:
|
||||
# Otherwise, the default is a subclass of PollIOLoop
|
||||
is_poll = self.run_python(
|
||||
'print(isinstance(IOLoop.current(), PollIOLoop))')
|
||||
self.assertEqual(is_poll, 'True')
|
||||
|
||||
@unittest.skipIf(asyncio is not None,
|
||||
"IOLoop configuration not available")
|
||||
def test_explicit_select(self):
|
||||
# SelectIOLoop can always be configured explicitly.
|
||||
default_class = self.run_python(
|
||||
'IOLoop.configure("tornado.platform.select.SelectIOLoop")',
|
||||
'print(classname(IOLoop.current()))')
|
||||
self.assertEqual(default_class, 'SelectIOLoop')
|
||||
|
||||
@unittest.skipIf(asyncio is None, "asyncio module not present")
|
||||
def test_asyncio(self):
|
||||
cls = self.run_python(
|
||||
'IOLoop.configure("tornado.platform.asyncio.AsyncIOLoop")',
|
||||
'print(classname(IOLoop.current()))')
|
||||
self.assertEqual(cls, 'AsyncIOMainLoop')
|
||||
|
||||
@unittest.skipIf(asyncio is None, "asyncio module not present")
|
||||
def test_asyncio_main(self):
|
||||
cls = self.run_python(
|
||||
'from tornado.platform.asyncio import AsyncIOMainLoop',
|
||||
'AsyncIOMainLoop().install()',
|
||||
'print(classname(IOLoop.current()))')
|
||||
self.assertEqual(cls, 'AsyncIOMainLoop')
|
||||
|
||||
@unittest.skipIf(twisted is None, "twisted module not present")
|
||||
@unittest.skipIf(asyncio is not None,
|
||||
"IOLoop configuration not available")
|
||||
def test_twisted(self):
|
||||
cls = self.run_python(
|
||||
'from tornado.platform.twisted import TwistedIOLoop',
|
||||
'TwistedIOLoop().install()',
|
||||
'print(classname(IOLoop.current()))')
|
||||
self.assertEqual(cls, 'TwistedIOLoop')
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Executable
+1454
File diff suppressed because it is too large
Load Diff
Executable
+131
@@ -0,0 +1,131 @@
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import datetime
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
|
||||
import tornado.locale
|
||||
from tornado.escape import utf8, to_unicode
|
||||
from tornado.test.util import unittest, skipOnAppEngine
|
||||
from tornado.util import unicode_type
|
||||
|
||||
|
||||
class TranslationLoaderTest(unittest.TestCase):
|
||||
# TODO: less hacky way to get isolated tests
|
||||
SAVE_VARS = ['_translations', '_supported_locales', '_use_gettext']
|
||||
|
||||
def clear_locale_cache(self):
|
||||
if hasattr(tornado.locale.Locale, '_cache'):
|
||||
del tornado.locale.Locale._cache
|
||||
|
||||
def setUp(self):
|
||||
self.saved = {}
|
||||
for var in TranslationLoaderTest.SAVE_VARS:
|
||||
self.saved[var] = getattr(tornado.locale, var)
|
||||
self.clear_locale_cache()
|
||||
|
||||
def tearDown(self):
|
||||
for k, v in self.saved.items():
|
||||
setattr(tornado.locale, k, v)
|
||||
self.clear_locale_cache()
|
||||
|
||||
def test_csv(self):
|
||||
tornado.locale.load_translations(
|
||||
os.path.join(os.path.dirname(__file__), 'csv_translations'))
|
||||
locale = tornado.locale.get("fr_FR")
|
||||
self.assertTrue(isinstance(locale, tornado.locale.CSVLocale))
|
||||
self.assertEqual(locale.translate("school"), u"\u00e9cole")
|
||||
|
||||
# tempfile.mkdtemp is not available on app engine.
|
||||
@skipOnAppEngine
|
||||
def test_csv_bom(self):
|
||||
with open(os.path.join(os.path.dirname(__file__), 'csv_translations',
|
||||
'fr_FR.csv'), 'rb') as f:
|
||||
char_data = to_unicode(f.read())
|
||||
# Re-encode our input data (which is utf-8 without BOM) in
|
||||
# encodings that use the BOM and ensure that we can still load
|
||||
# it. Note that utf-16-le and utf-16-be do not write a BOM,
|
||||
# so we only test whichver variant is native to our platform.
|
||||
for encoding in ['utf-8-sig', 'utf-16']:
|
||||
tmpdir = tempfile.mkdtemp()
|
||||
try:
|
||||
with open(os.path.join(tmpdir, 'fr_FR.csv'), 'wb') as f:
|
||||
f.write(char_data.encode(encoding))
|
||||
tornado.locale.load_translations(tmpdir)
|
||||
locale = tornado.locale.get('fr_FR')
|
||||
self.assertIsInstance(locale, tornado.locale.CSVLocale)
|
||||
self.assertEqual(locale.translate("school"), u"\u00e9cole")
|
||||
finally:
|
||||
shutil.rmtree(tmpdir)
|
||||
|
||||
def test_gettext(self):
|
||||
tornado.locale.load_gettext_translations(
|
||||
os.path.join(os.path.dirname(__file__), 'gettext_translations'),
|
||||
"tornado_test")
|
||||
locale = tornado.locale.get("fr_FR")
|
||||
self.assertTrue(isinstance(locale, tornado.locale.GettextLocale))
|
||||
self.assertEqual(locale.translate("school"), u"\u00e9cole")
|
||||
self.assertEqual(locale.pgettext("law", "right"), u"le droit")
|
||||
self.assertEqual(locale.pgettext("good", "right"), u"le bien")
|
||||
self.assertEqual(locale.pgettext("organization", "club", "clubs", 1), u"le club")
|
||||
self.assertEqual(locale.pgettext("organization", "club", "clubs", 2), u"les clubs")
|
||||
self.assertEqual(locale.pgettext("stick", "club", "clubs", 1), u"le b\xe2ton")
|
||||
self.assertEqual(locale.pgettext("stick", "club", "clubs", 2), u"les b\xe2tons")
|
||||
|
||||
|
||||
class LocaleDataTest(unittest.TestCase):
|
||||
def test_non_ascii_name(self):
|
||||
name = tornado.locale.LOCALE_NAMES['es_LA']['name']
|
||||
self.assertTrue(isinstance(name, unicode_type))
|
||||
self.assertEqual(name, u'Espa\u00f1ol')
|
||||
self.assertEqual(utf8(name), b'Espa\xc3\xb1ol')
|
||||
|
||||
|
||||
class EnglishTest(unittest.TestCase):
|
||||
def test_format_date(self):
|
||||
locale = tornado.locale.get('en_US')
|
||||
date = datetime.datetime(2013, 4, 28, 18, 35)
|
||||
self.assertEqual(locale.format_date(date, full_format=True),
|
||||
'April 28, 2013 at 6:35 pm')
|
||||
|
||||
now = datetime.datetime.utcnow()
|
||||
|
||||
self.assertEqual(locale.format_date(now - datetime.timedelta(seconds=2), full_format=False),
|
||||
'2 seconds ago')
|
||||
self.assertEqual(locale.format_date(now - datetime.timedelta(minutes=2), full_format=False),
|
||||
'2 minutes ago')
|
||||
self.assertEqual(locale.format_date(now - datetime.timedelta(hours=2), full_format=False),
|
||||
'2 hours ago')
|
||||
|
||||
self.assertEqual(locale.format_date(now - datetime.timedelta(days=1),
|
||||
full_format=False, shorter=True), 'yesterday')
|
||||
|
||||
date = now - datetime.timedelta(days=2)
|
||||
self.assertEqual(locale.format_date(date, full_format=False, shorter=True),
|
||||
locale._weekdays[date.weekday()])
|
||||
|
||||
date = now - datetime.timedelta(days=300)
|
||||
self.assertEqual(locale.format_date(date, full_format=False, shorter=True),
|
||||
'%s %d' % (locale._months[date.month - 1], date.day))
|
||||
|
||||
date = now - datetime.timedelta(days=500)
|
||||
self.assertEqual(locale.format_date(date, full_format=False, shorter=True),
|
||||
'%s %d, %d' % (locale._months[date.month - 1], date.day, date.year))
|
||||
|
||||
def test_friendly_number(self):
|
||||
locale = tornado.locale.get('en_US')
|
||||
self.assertEqual(locale.friendly_number(1000000), '1,000,000')
|
||||
|
||||
def test_list(self):
|
||||
locale = tornado.locale.get('en_US')
|
||||
self.assertEqual(locale.list([]), '')
|
||||
self.assertEqual(locale.list(['A']), 'A')
|
||||
self.assertEqual(locale.list(['A', 'B']), 'A and B')
|
||||
self.assertEqual(locale.list(['A', 'B', 'C']), 'A, B and C')
|
||||
|
||||
def test_format_day(self):
|
||||
locale = tornado.locale.get('en_US')
|
||||
date = datetime.datetime(2013, 4, 28, 18, 35)
|
||||
self.assertEqual(locale.format_day(date=date, dow=True), 'Sunday, April 28')
|
||||
self.assertEqual(locale.format_day(date=date, dow=False), 'April 28')
|
||||
Executable
+537
@@ -0,0 +1,537 @@
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
from datetime import timedelta
|
||||
|
||||
from tornado import gen, locks
|
||||
from tornado.gen import TimeoutError
|
||||
from tornado.testing import gen_test, AsyncTestCase
|
||||
from tornado.test.util import unittest, skipBefore35, exec_test
|
||||
|
||||
|
||||
class ConditionTest(AsyncTestCase):
|
||||
def setUp(self):
|
||||
super(ConditionTest, self).setUp()
|
||||
self.history = []
|
||||
|
||||
def record_done(self, future, key):
|
||||
"""Record the resolution of a Future returned by Condition.wait."""
|
||||
def callback(_):
|
||||
if not future.result():
|
||||
# wait() resolved to False, meaning it timed out.
|
||||
self.history.append('timeout')
|
||||
else:
|
||||
self.history.append(key)
|
||||
future.add_done_callback(callback)
|
||||
|
||||
def loop_briefly(self):
|
||||
"""Run all queued callbacks on the IOLoop.
|
||||
|
||||
In these tests, this method is used after calling notify() to
|
||||
preserve the pre-5.0 behavior in which callbacks ran
|
||||
synchronously.
|
||||
"""
|
||||
self.io_loop.add_callback(self.stop)
|
||||
self.wait()
|
||||
|
||||
def test_repr(self):
|
||||
c = locks.Condition()
|
||||
self.assertIn('Condition', repr(c))
|
||||
self.assertNotIn('waiters', repr(c))
|
||||
c.wait()
|
||||
self.assertIn('waiters', repr(c))
|
||||
|
||||
@gen_test
|
||||
def test_notify(self):
|
||||
c = locks.Condition()
|
||||
self.io_loop.call_later(0.01, c.notify)
|
||||
yield c.wait()
|
||||
|
||||
def test_notify_1(self):
|
||||
c = locks.Condition()
|
||||
self.record_done(c.wait(), 'wait1')
|
||||
self.record_done(c.wait(), 'wait2')
|
||||
c.notify(1)
|
||||
self.loop_briefly()
|
||||
self.history.append('notify1')
|
||||
c.notify(1)
|
||||
self.loop_briefly()
|
||||
self.history.append('notify2')
|
||||
self.assertEqual(['wait1', 'notify1', 'wait2', 'notify2'],
|
||||
self.history)
|
||||
|
||||
def test_notify_n(self):
|
||||
c = locks.Condition()
|
||||
for i in range(6):
|
||||
self.record_done(c.wait(), i)
|
||||
|
||||
c.notify(3)
|
||||
self.loop_briefly()
|
||||
|
||||
# Callbacks execute in the order they were registered.
|
||||
self.assertEqual(list(range(3)), self.history)
|
||||
c.notify(1)
|
||||
self.loop_briefly()
|
||||
self.assertEqual(list(range(4)), self.history)
|
||||
c.notify(2)
|
||||
self.loop_briefly()
|
||||
self.assertEqual(list(range(6)), self.history)
|
||||
|
||||
def test_notify_all(self):
|
||||
c = locks.Condition()
|
||||
for i in range(4):
|
||||
self.record_done(c.wait(), i)
|
||||
|
||||
c.notify_all()
|
||||
self.loop_briefly()
|
||||
self.history.append('notify_all')
|
||||
|
||||
# Callbacks execute in the order they were registered.
|
||||
self.assertEqual(
|
||||
list(range(4)) + ['notify_all'],
|
||||
self.history)
|
||||
|
||||
@gen_test
|
||||
def test_wait_timeout(self):
|
||||
c = locks.Condition()
|
||||
wait = c.wait(timedelta(seconds=0.01))
|
||||
self.io_loop.call_later(0.02, c.notify) # Too late.
|
||||
yield gen.sleep(0.03)
|
||||
self.assertFalse((yield wait))
|
||||
|
||||
@gen_test
|
||||
def test_wait_timeout_preempted(self):
|
||||
c = locks.Condition()
|
||||
|
||||
# This fires before the wait times out.
|
||||
self.io_loop.call_later(0.01, c.notify)
|
||||
wait = c.wait(timedelta(seconds=0.02))
|
||||
yield gen.sleep(0.03)
|
||||
yield wait # No TimeoutError.
|
||||
|
||||
@gen_test
|
||||
def test_notify_n_with_timeout(self):
|
||||
# Register callbacks 0, 1, 2, and 3. Callback 1 has a timeout.
|
||||
# Wait for that timeout to expire, then do notify(2) and make
|
||||
# sure everyone runs. Verifies that a timed-out callback does
|
||||
# not count against the 'n' argument to notify().
|
||||
c = locks.Condition()
|
||||
self.record_done(c.wait(), 0)
|
||||
self.record_done(c.wait(timedelta(seconds=0.01)), 1)
|
||||
self.record_done(c.wait(), 2)
|
||||
self.record_done(c.wait(), 3)
|
||||
|
||||
# Wait for callback 1 to time out.
|
||||
yield gen.sleep(0.02)
|
||||
self.assertEqual(['timeout'], self.history)
|
||||
|
||||
c.notify(2)
|
||||
yield gen.sleep(0.01)
|
||||
self.assertEqual(['timeout', 0, 2], self.history)
|
||||
self.assertEqual(['timeout', 0, 2], self.history)
|
||||
c.notify()
|
||||
yield
|
||||
self.assertEqual(['timeout', 0, 2, 3], self.history)
|
||||
|
||||
@gen_test
|
||||
def test_notify_all_with_timeout(self):
|
||||
c = locks.Condition()
|
||||
self.record_done(c.wait(), 0)
|
||||
self.record_done(c.wait(timedelta(seconds=0.01)), 1)
|
||||
self.record_done(c.wait(), 2)
|
||||
|
||||
# Wait for callback 1 to time out.
|
||||
yield gen.sleep(0.02)
|
||||
self.assertEqual(['timeout'], self.history)
|
||||
|
||||
c.notify_all()
|
||||
yield
|
||||
self.assertEqual(['timeout', 0, 2], self.history)
|
||||
|
||||
@gen_test
|
||||
def test_nested_notify(self):
|
||||
# Ensure no notifications lost, even if notify() is reentered by a
|
||||
# waiter calling notify().
|
||||
c = locks.Condition()
|
||||
|
||||
# Three waiters.
|
||||
futures = [c.wait() for _ in range(3)]
|
||||
|
||||
# First and second futures resolved. Second future reenters notify(),
|
||||
# resolving third future.
|
||||
futures[1].add_done_callback(lambda _: c.notify())
|
||||
c.notify(2)
|
||||
yield
|
||||
self.assertTrue(all(f.done() for f in futures))
|
||||
|
||||
@gen_test
|
||||
def test_garbage_collection(self):
|
||||
# Test that timed-out waiters are occasionally cleaned from the queue.
|
||||
c = locks.Condition()
|
||||
for _ in range(101):
|
||||
c.wait(timedelta(seconds=0.01))
|
||||
|
||||
future = c.wait()
|
||||
self.assertEqual(102, len(c._waiters))
|
||||
|
||||
# Let first 101 waiters time out, triggering a collection.
|
||||
yield gen.sleep(0.02)
|
||||
self.assertEqual(1, len(c._waiters))
|
||||
|
||||
# Final waiter is still active.
|
||||
self.assertFalse(future.done())
|
||||
c.notify()
|
||||
self.assertTrue(future.done())
|
||||
|
||||
|
||||
class EventTest(AsyncTestCase):
|
||||
def test_repr(self):
|
||||
event = locks.Event()
|
||||
self.assertTrue('clear' in str(event))
|
||||
self.assertFalse('set' in str(event))
|
||||
event.set()
|
||||
self.assertFalse('clear' in str(event))
|
||||
self.assertTrue('set' in str(event))
|
||||
|
||||
def test_event(self):
|
||||
e = locks.Event()
|
||||
future_0 = e.wait()
|
||||
e.set()
|
||||
future_1 = e.wait()
|
||||
e.clear()
|
||||
future_2 = e.wait()
|
||||
|
||||
self.assertTrue(future_0.done())
|
||||
self.assertTrue(future_1.done())
|
||||
self.assertFalse(future_2.done())
|
||||
|
||||
@gen_test
|
||||
def test_event_timeout(self):
|
||||
e = locks.Event()
|
||||
with self.assertRaises(TimeoutError):
|
||||
yield e.wait(timedelta(seconds=0.01))
|
||||
|
||||
# After a timed-out waiter, normal operation works.
|
||||
self.io_loop.add_timeout(timedelta(seconds=0.01), e.set)
|
||||
yield e.wait(timedelta(seconds=1))
|
||||
|
||||
def test_event_set_multiple(self):
|
||||
e = locks.Event()
|
||||
e.set()
|
||||
e.set()
|
||||
self.assertTrue(e.is_set())
|
||||
|
||||
def test_event_wait_clear(self):
|
||||
e = locks.Event()
|
||||
f0 = e.wait()
|
||||
e.clear()
|
||||
f1 = e.wait()
|
||||
e.set()
|
||||
self.assertTrue(f0.done())
|
||||
self.assertTrue(f1.done())
|
||||
|
||||
|
||||
class SemaphoreTest(AsyncTestCase):
|
||||
def test_negative_value(self):
|
||||
self.assertRaises(ValueError, locks.Semaphore, value=-1)
|
||||
|
||||
def test_repr(self):
|
||||
sem = locks.Semaphore()
|
||||
self.assertIn('Semaphore', repr(sem))
|
||||
self.assertIn('unlocked,value:1', repr(sem))
|
||||
sem.acquire()
|
||||
self.assertIn('locked', repr(sem))
|
||||
self.assertNotIn('waiters', repr(sem))
|
||||
sem.acquire()
|
||||
self.assertIn('waiters', repr(sem))
|
||||
|
||||
def test_acquire(self):
|
||||
sem = locks.Semaphore()
|
||||
f0 = sem.acquire()
|
||||
self.assertTrue(f0.done())
|
||||
|
||||
# Wait for release().
|
||||
f1 = sem.acquire()
|
||||
self.assertFalse(f1.done())
|
||||
f2 = sem.acquire()
|
||||
sem.release()
|
||||
self.assertTrue(f1.done())
|
||||
self.assertFalse(f2.done())
|
||||
sem.release()
|
||||
self.assertTrue(f2.done())
|
||||
|
||||
sem.release()
|
||||
# Now acquire() is instant.
|
||||
self.assertTrue(sem.acquire().done())
|
||||
self.assertEqual(0, len(sem._waiters))
|
||||
|
||||
@gen_test
|
||||
def test_acquire_timeout(self):
|
||||
sem = locks.Semaphore(2)
|
||||
yield sem.acquire()
|
||||
yield sem.acquire()
|
||||
acquire = sem.acquire(timedelta(seconds=0.01))
|
||||
self.io_loop.call_later(0.02, sem.release) # Too late.
|
||||
yield gen.sleep(0.3)
|
||||
with self.assertRaises(gen.TimeoutError):
|
||||
yield acquire
|
||||
|
||||
sem.acquire()
|
||||
f = sem.acquire()
|
||||
self.assertFalse(f.done())
|
||||
sem.release()
|
||||
self.assertTrue(f.done())
|
||||
|
||||
@gen_test
|
||||
def test_acquire_timeout_preempted(self):
|
||||
sem = locks.Semaphore(1)
|
||||
yield sem.acquire()
|
||||
|
||||
# This fires before the wait times out.
|
||||
self.io_loop.call_later(0.01, sem.release)
|
||||
acquire = sem.acquire(timedelta(seconds=0.02))
|
||||
yield gen.sleep(0.03)
|
||||
yield acquire # No TimeoutError.
|
||||
|
||||
def test_release_unacquired(self):
|
||||
# Unbounded releases are allowed, and increment the semaphore's value.
|
||||
sem = locks.Semaphore()
|
||||
sem.release()
|
||||
sem.release()
|
||||
|
||||
# Now the counter is 3. We can acquire three times before blocking.
|
||||
self.assertTrue(sem.acquire().done())
|
||||
self.assertTrue(sem.acquire().done())
|
||||
self.assertTrue(sem.acquire().done())
|
||||
self.assertFalse(sem.acquire().done())
|
||||
|
||||
@gen_test
|
||||
def test_garbage_collection(self):
|
||||
# Test that timed-out waiters are occasionally cleaned from the queue.
|
||||
sem = locks.Semaphore(value=0)
|
||||
futures = [sem.acquire(timedelta(seconds=0.01)) for _ in range(101)]
|
||||
|
||||
future = sem.acquire()
|
||||
self.assertEqual(102, len(sem._waiters))
|
||||
|
||||
# Let first 101 waiters time out, triggering a collection.
|
||||
yield gen.sleep(0.02)
|
||||
self.assertEqual(1, len(sem._waiters))
|
||||
|
||||
# Final waiter is still active.
|
||||
self.assertFalse(future.done())
|
||||
sem.release()
|
||||
self.assertTrue(future.done())
|
||||
|
||||
# Prevent "Future exception was never retrieved" messages.
|
||||
for future in futures:
|
||||
self.assertRaises(TimeoutError, future.result)
|
||||
|
||||
|
||||
class SemaphoreContextManagerTest(AsyncTestCase):
|
||||
@gen_test
|
||||
def test_context_manager(self):
|
||||
sem = locks.Semaphore()
|
||||
with (yield sem.acquire()) as yielded:
|
||||
self.assertTrue(yielded is None)
|
||||
|
||||
# Semaphore was released and can be acquired again.
|
||||
self.assertTrue(sem.acquire().done())
|
||||
|
||||
@skipBefore35
|
||||
@gen_test
|
||||
def test_context_manager_async_await(self):
|
||||
# Repeat the above test using 'async with'.
|
||||
sem = locks.Semaphore()
|
||||
|
||||
namespace = exec_test(globals(), locals(), """
|
||||
async def f():
|
||||
async with sem as yielded:
|
||||
self.assertTrue(yielded is None)
|
||||
""")
|
||||
yield namespace['f']()
|
||||
|
||||
# Semaphore was released and can be acquired again.
|
||||
self.assertTrue(sem.acquire().done())
|
||||
|
||||
@gen_test
|
||||
def test_context_manager_exception(self):
|
||||
sem = locks.Semaphore()
|
||||
with self.assertRaises(ZeroDivisionError):
|
||||
with (yield sem.acquire()):
|
||||
1 / 0
|
||||
|
||||
# Semaphore was released and can be acquired again.
|
||||
self.assertTrue(sem.acquire().done())
|
||||
|
||||
@gen_test
|
||||
def test_context_manager_timeout(self):
|
||||
sem = locks.Semaphore()
|
||||
with (yield sem.acquire(timedelta(seconds=0.01))):
|
||||
pass
|
||||
|
||||
# Semaphore was released and can be acquired again.
|
||||
self.assertTrue(sem.acquire().done())
|
||||
|
||||
@gen_test
|
||||
def test_context_manager_timeout_error(self):
|
||||
sem = locks.Semaphore(value=0)
|
||||
with self.assertRaises(gen.TimeoutError):
|
||||
with (yield sem.acquire(timedelta(seconds=0.01))):
|
||||
pass
|
||||
|
||||
# Counter is still 0.
|
||||
self.assertFalse(sem.acquire().done())
|
||||
|
||||
@gen_test
|
||||
def test_context_manager_contended(self):
|
||||
sem = locks.Semaphore()
|
||||
history = []
|
||||
|
||||
@gen.coroutine
|
||||
def f(index):
|
||||
with (yield sem.acquire()):
|
||||
history.append('acquired %d' % index)
|
||||
yield gen.sleep(0.01)
|
||||
history.append('release %d' % index)
|
||||
|
||||
yield [f(i) for i in range(2)]
|
||||
|
||||
expected_history = []
|
||||
for i in range(2):
|
||||
expected_history.extend(['acquired %d' % i, 'release %d' % i])
|
||||
|
||||
self.assertEqual(expected_history, history)
|
||||
|
||||
@gen_test
|
||||
def test_yield_sem(self):
|
||||
# Ensure we catch a "with (yield sem)", which should be
|
||||
# "with (yield sem.acquire())".
|
||||
with self.assertRaises(gen.BadYieldError):
|
||||
with (yield locks.Semaphore()):
|
||||
pass
|
||||
|
||||
def test_context_manager_misuse(self):
|
||||
# Ensure we catch a "with sem", which should be
|
||||
# "with (yield sem.acquire())".
|
||||
with self.assertRaises(RuntimeError):
|
||||
with locks.Semaphore():
|
||||
pass
|
||||
|
||||
|
||||
class BoundedSemaphoreTest(AsyncTestCase):
|
||||
def test_release_unacquired(self):
|
||||
sem = locks.BoundedSemaphore()
|
||||
self.assertRaises(ValueError, sem.release)
|
||||
# Value is 0.
|
||||
sem.acquire()
|
||||
# Block on acquire().
|
||||
future = sem.acquire()
|
||||
self.assertFalse(future.done())
|
||||
sem.release()
|
||||
self.assertTrue(future.done())
|
||||
# Value is 1.
|
||||
sem.release()
|
||||
self.assertRaises(ValueError, sem.release)
|
||||
|
||||
|
||||
class LockTests(AsyncTestCase):
|
||||
def test_repr(self):
|
||||
lock = locks.Lock()
|
||||
# No errors.
|
||||
repr(lock)
|
||||
lock.acquire()
|
||||
repr(lock)
|
||||
|
||||
def test_acquire_release(self):
|
||||
lock = locks.Lock()
|
||||
self.assertTrue(lock.acquire().done())
|
||||
future = lock.acquire()
|
||||
self.assertFalse(future.done())
|
||||
lock.release()
|
||||
self.assertTrue(future.done())
|
||||
|
||||
@gen_test
|
||||
def test_acquire_fifo(self):
|
||||
lock = locks.Lock()
|
||||
self.assertTrue(lock.acquire().done())
|
||||
N = 5
|
||||
history = []
|
||||
|
||||
@gen.coroutine
|
||||
def f(idx):
|
||||
with (yield lock.acquire()):
|
||||
history.append(idx)
|
||||
|
||||
futures = [f(i) for i in range(N)]
|
||||
self.assertFalse(any(future.done() for future in futures))
|
||||
lock.release()
|
||||
yield futures
|
||||
self.assertEqual(list(range(N)), history)
|
||||
|
||||
@skipBefore35
|
||||
@gen_test
|
||||
def test_acquire_fifo_async_with(self):
|
||||
# Repeat the above test using `async with lock:`
|
||||
# instead of `with (yield lock.acquire()):`.
|
||||
lock = locks.Lock()
|
||||
self.assertTrue(lock.acquire().done())
|
||||
N = 5
|
||||
history = []
|
||||
|
||||
namespace = exec_test(globals(), locals(), """
|
||||
async def f(idx):
|
||||
async with lock:
|
||||
history.append(idx)
|
||||
""")
|
||||
futures = [namespace['f'](i) for i in range(N)]
|
||||
lock.release()
|
||||
yield futures
|
||||
self.assertEqual(list(range(N)), history)
|
||||
|
||||
@gen_test
|
||||
def test_acquire_timeout(self):
|
||||
lock = locks.Lock()
|
||||
lock.acquire()
|
||||
with self.assertRaises(gen.TimeoutError):
|
||||
yield lock.acquire(timeout=timedelta(seconds=0.01))
|
||||
|
||||
# Still locked.
|
||||
self.assertFalse(lock.acquire().done())
|
||||
|
||||
def test_multi_release(self):
|
||||
lock = locks.Lock()
|
||||
self.assertRaises(RuntimeError, lock.release)
|
||||
lock.acquire()
|
||||
lock.release()
|
||||
self.assertRaises(RuntimeError, lock.release)
|
||||
|
||||
@gen_test
|
||||
def test_yield_lock(self):
|
||||
# Ensure we catch a "with (yield lock)", which should be
|
||||
# "with (yield lock.acquire())".
|
||||
with self.assertRaises(gen.BadYieldError):
|
||||
with (yield locks.Lock()):
|
||||
pass
|
||||
|
||||
def test_context_manager_misuse(self):
|
||||
# Ensure we catch a "with lock", which should be
|
||||
# "with (yield lock.acquire())".
|
||||
with self.assertRaises(RuntimeError):
|
||||
with locks.Lock():
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Executable
+241
@@ -0,0 +1,241 @@
|
||||
#
|
||||
# Copyright 2012 Facebook
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import contextlib
|
||||
import glob
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
import warnings
|
||||
|
||||
from tornado.escape import utf8
|
||||
from tornado.log import LogFormatter, define_logging_options, enable_pretty_logging
|
||||
from tornado.options import OptionParser
|
||||
from tornado.test.util import unittest
|
||||
from tornado.util import basestring_type
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def ignore_bytes_warning():
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter('ignore', category=BytesWarning)
|
||||
yield
|
||||
|
||||
|
||||
class LogFormatterTest(unittest.TestCase):
|
||||
# Matches the output of a single logging call (which may be multiple lines
|
||||
# if a traceback was included, so we use the DOTALL option)
|
||||
LINE_RE = re.compile(
|
||||
b"(?s)\x01\\[E [0-9]{6} [0-9]{2}:[0-9]{2}:[0-9]{2} log_test:[0-9]+\\]\x02 (.*)")
|
||||
|
||||
def setUp(self):
|
||||
self.formatter = LogFormatter(color=False)
|
||||
# Fake color support. We can't guarantee anything about the $TERM
|
||||
# variable when the tests are run, so just patch in some values
|
||||
# for testing. (testing with color off fails to expose some potential
|
||||
# encoding issues from the control characters)
|
||||
self.formatter._colors = {
|
||||
logging.ERROR: u"\u0001",
|
||||
}
|
||||
self.formatter._normal = u"\u0002"
|
||||
# construct a Logger directly to bypass getLogger's caching
|
||||
self.logger = logging.Logger('LogFormatterTest')
|
||||
self.logger.propagate = False
|
||||
self.tempdir = tempfile.mkdtemp()
|
||||
self.filename = os.path.join(self.tempdir, 'log.out')
|
||||
self.handler = self.make_handler(self.filename)
|
||||
self.handler.setFormatter(self.formatter)
|
||||
self.logger.addHandler(self.handler)
|
||||
|
||||
def tearDown(self):
|
||||
self.handler.close()
|
||||
os.unlink(self.filename)
|
||||
os.rmdir(self.tempdir)
|
||||
|
||||
def make_handler(self, filename):
|
||||
# Base case: default setup without explicit encoding.
|
||||
# In python 2, supports arbitrary byte strings and unicode objects
|
||||
# that contain only ascii. In python 3, supports ascii-only unicode
|
||||
# strings (but byte strings will be repr'd automatically).
|
||||
return logging.FileHandler(filename)
|
||||
|
||||
def get_output(self):
|
||||
with open(self.filename, "rb") as f:
|
||||
line = f.read().strip()
|
||||
m = LogFormatterTest.LINE_RE.match(line)
|
||||
if m:
|
||||
return m.group(1)
|
||||
else:
|
||||
raise Exception("output didn't match regex: %r" % line)
|
||||
|
||||
def test_basic_logging(self):
|
||||
self.logger.error("foo")
|
||||
self.assertEqual(self.get_output(), b"foo")
|
||||
|
||||
def test_bytes_logging(self):
|
||||
with ignore_bytes_warning():
|
||||
# This will be "\xe9" on python 2 or "b'\xe9'" on python 3
|
||||
self.logger.error(b"\xe9")
|
||||
self.assertEqual(self.get_output(), utf8(repr(b"\xe9")))
|
||||
|
||||
def test_utf8_logging(self):
|
||||
with ignore_bytes_warning():
|
||||
self.logger.error(u"\u00e9".encode("utf8"))
|
||||
if issubclass(bytes, basestring_type):
|
||||
# on python 2, utf8 byte strings (and by extension ascii byte
|
||||
# strings) are passed through as-is.
|
||||
self.assertEqual(self.get_output(), utf8(u"\u00e9"))
|
||||
else:
|
||||
# on python 3, byte strings always get repr'd even if
|
||||
# they're ascii-only, so this degenerates into another
|
||||
# copy of test_bytes_logging.
|
||||
self.assertEqual(self.get_output(), utf8(repr(utf8(u"\u00e9"))))
|
||||
|
||||
def test_bytes_exception_logging(self):
|
||||
try:
|
||||
raise Exception(b'\xe9')
|
||||
except Exception:
|
||||
self.logger.exception('caught exception')
|
||||
# This will be "Exception: \xe9" on python 2 or
|
||||
# "Exception: b'\xe9'" on python 3.
|
||||
output = self.get_output()
|
||||
self.assertRegexpMatches(output, br'Exception.*\\xe9')
|
||||
# The traceback contains newlines, which should not have been escaped.
|
||||
self.assertNotIn(br'\n', output)
|
||||
|
||||
|
||||
class UnicodeLogFormatterTest(LogFormatterTest):
|
||||
def make_handler(self, filename):
|
||||
# Adding an explicit encoding configuration allows non-ascii unicode
|
||||
# strings in both python 2 and 3, without changing the behavior
|
||||
# for byte strings.
|
||||
return logging.FileHandler(filename, encoding="utf8")
|
||||
|
||||
def test_unicode_logging(self):
|
||||
self.logger.error(u"\u00e9")
|
||||
self.assertEqual(self.get_output(), utf8(u"\u00e9"))
|
||||
|
||||
|
||||
class EnablePrettyLoggingTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
super(EnablePrettyLoggingTest, self).setUp()
|
||||
self.options = OptionParser()
|
||||
define_logging_options(self.options)
|
||||
self.logger = logging.Logger('tornado.test.log_test.EnablePrettyLoggingTest')
|
||||
self.logger.propagate = False
|
||||
|
||||
def test_log_file(self):
|
||||
tmpdir = tempfile.mkdtemp()
|
||||
try:
|
||||
self.options.log_file_prefix = tmpdir + '/test_log'
|
||||
enable_pretty_logging(options=self.options, logger=self.logger)
|
||||
self.assertEqual(1, len(self.logger.handlers))
|
||||
self.logger.error('hello')
|
||||
self.logger.handlers[0].flush()
|
||||
filenames = glob.glob(tmpdir + '/test_log*')
|
||||
self.assertEqual(1, len(filenames))
|
||||
with open(filenames[0]) as f:
|
||||
self.assertRegexpMatches(f.read(), r'^\[E [^]]*\] hello$')
|
||||
finally:
|
||||
for handler in self.logger.handlers:
|
||||
handler.flush()
|
||||
handler.close()
|
||||
for filename in glob.glob(tmpdir + '/test_log*'):
|
||||
os.unlink(filename)
|
||||
os.rmdir(tmpdir)
|
||||
|
||||
def test_log_file_with_timed_rotating(self):
|
||||
tmpdir = tempfile.mkdtemp()
|
||||
try:
|
||||
self.options.log_file_prefix = tmpdir + '/test_log'
|
||||
self.options.log_rotate_mode = 'time'
|
||||
enable_pretty_logging(options=self.options, logger=self.logger)
|
||||
self.logger.error('hello')
|
||||
self.logger.handlers[0].flush()
|
||||
filenames = glob.glob(tmpdir + '/test_log*')
|
||||
self.assertEqual(1, len(filenames))
|
||||
with open(filenames[0]) as f:
|
||||
self.assertRegexpMatches(
|
||||
f.read(),
|
||||
r'^\[E [^]]*\] hello$')
|
||||
finally:
|
||||
for handler in self.logger.handlers:
|
||||
handler.flush()
|
||||
handler.close()
|
||||
for filename in glob.glob(tmpdir + '/test_log*'):
|
||||
os.unlink(filename)
|
||||
os.rmdir(tmpdir)
|
||||
|
||||
def test_wrong_rotate_mode_value(self):
|
||||
try:
|
||||
self.options.log_file_prefix = 'some_path'
|
||||
self.options.log_rotate_mode = 'wrong_mode'
|
||||
self.assertRaises(ValueError, enable_pretty_logging,
|
||||
options=self.options, logger=self.logger)
|
||||
finally:
|
||||
for handler in self.logger.handlers:
|
||||
handler.flush()
|
||||
handler.close()
|
||||
|
||||
|
||||
class LoggingOptionTest(unittest.TestCase):
|
||||
"""Test the ability to enable and disable Tornado's logging hooks."""
|
||||
def logs_present(self, statement, args=None):
|
||||
# Each test may manipulate and/or parse the options and then logs
|
||||
# a line at the 'info' level. This level is ignored in the
|
||||
# logging module by default, but Tornado turns it on by default
|
||||
# so it is the easiest way to tell whether tornado's logging hooks
|
||||
# ran.
|
||||
IMPORT = 'from tornado.options import options, parse_command_line'
|
||||
LOG_INFO = 'import logging; logging.info("hello")'
|
||||
program = ';'.join([IMPORT, statement, LOG_INFO])
|
||||
proc = subprocess.Popen(
|
||||
[sys.executable, '-c', program] + (args or []),
|
||||
stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
|
||||
stdout, stderr = proc.communicate()
|
||||
self.assertEqual(proc.returncode, 0, 'process failed: %r' % stdout)
|
||||
return b'hello' in stdout
|
||||
|
||||
def test_default(self):
|
||||
self.assertFalse(self.logs_present('pass'))
|
||||
|
||||
def test_tornado_default(self):
|
||||
self.assertTrue(self.logs_present('parse_command_line()'))
|
||||
|
||||
def test_disable_command_line(self):
|
||||
self.assertFalse(self.logs_present('parse_command_line()',
|
||||
['--logging=none']))
|
||||
|
||||
def test_disable_command_line_case_insensitive(self):
|
||||
self.assertFalse(self.logs_present('parse_command_line()',
|
||||
['--logging=None']))
|
||||
|
||||
def test_disable_code_string(self):
|
||||
self.assertFalse(self.logs_present(
|
||||
'options.logging = "none"; parse_command_line()'))
|
||||
|
||||
def test_disable_code_none(self):
|
||||
self.assertFalse(self.logs_present(
|
||||
'options.logging = None; parse_command_line()'))
|
||||
|
||||
def test_disable_override(self):
|
||||
# command line trumps code defaults
|
||||
self.assertTrue(self.logs_present(
|
||||
'options.logging = None; parse_command_line()',
|
||||
['--logging=info']))
|
||||
Executable
+242
@@ -0,0 +1,242 @@
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import errno
|
||||
import os
|
||||
import signal
|
||||
import socket
|
||||
from subprocess import Popen
|
||||
import sys
|
||||
import time
|
||||
|
||||
from tornado.netutil import (
|
||||
BlockingResolver, OverrideResolver, ThreadedResolver, is_valid_ip, bind_sockets
|
||||
)
|
||||
from tornado.stack_context import ExceptionStackContext
|
||||
from tornado.testing import AsyncTestCase, gen_test, bind_unused_port
|
||||
from tornado.test.util import unittest, skipIfNoNetwork, ignore_deprecation
|
||||
|
||||
try:
|
||||
from concurrent import futures
|
||||
except ImportError:
|
||||
futures = None
|
||||
|
||||
try:
|
||||
import pycares # type: ignore
|
||||
except ImportError:
|
||||
pycares = None
|
||||
else:
|
||||
from tornado.platform.caresresolver import CaresResolver
|
||||
|
||||
try:
|
||||
import twisted # type: ignore
|
||||
import twisted.names # type: ignore
|
||||
except ImportError:
|
||||
twisted = None
|
||||
else:
|
||||
from tornado.platform.twisted import TwistedResolver
|
||||
|
||||
|
||||
class _ResolverTestMixin(object):
|
||||
def test_localhost(self):
|
||||
with ignore_deprecation():
|
||||
self.resolver.resolve('localhost', 80, callback=self.stop)
|
||||
result = self.wait()
|
||||
self.assertIn((socket.AF_INET, ('127.0.0.1', 80)), result)
|
||||
|
||||
@gen_test
|
||||
def test_future_interface(self):
|
||||
addrinfo = yield self.resolver.resolve('localhost', 80,
|
||||
socket.AF_UNSPEC)
|
||||
self.assertIn((socket.AF_INET, ('127.0.0.1', 80)),
|
||||
addrinfo)
|
||||
|
||||
|
||||
# It is impossible to quickly and consistently generate an error in name
|
||||
# resolution, so test this case separately, using mocks as needed.
|
||||
class _ResolverErrorTestMixin(object):
|
||||
def test_bad_host(self):
|
||||
def handler(exc_typ, exc_val, exc_tb):
|
||||
self.stop(exc_val)
|
||||
return True # Halt propagation.
|
||||
|
||||
with ignore_deprecation():
|
||||
with ExceptionStackContext(handler):
|
||||
self.resolver.resolve('an invalid domain', 80, callback=self.stop)
|
||||
|
||||
result = self.wait()
|
||||
self.assertIsInstance(result, Exception)
|
||||
|
||||
@gen_test
|
||||
def test_future_interface_bad_host(self):
|
||||
with self.assertRaises(IOError):
|
||||
yield self.resolver.resolve('an invalid domain', 80,
|
||||
socket.AF_UNSPEC)
|
||||
|
||||
|
||||
def _failing_getaddrinfo(*args):
|
||||
"""Dummy implementation of getaddrinfo for use in mocks"""
|
||||
raise socket.gaierror(errno.EIO, "mock: lookup failed")
|
||||
|
||||
|
||||
@skipIfNoNetwork
|
||||
class BlockingResolverTest(AsyncTestCase, _ResolverTestMixin):
|
||||
def setUp(self):
|
||||
super(BlockingResolverTest, self).setUp()
|
||||
self.resolver = BlockingResolver()
|
||||
|
||||
|
||||
# getaddrinfo-based tests need mocking to reliably generate errors;
|
||||
# some configurations are slow to produce errors and take longer than
|
||||
# our default timeout.
|
||||
class BlockingResolverErrorTest(AsyncTestCase, _ResolverErrorTestMixin):
|
||||
def setUp(self):
|
||||
super(BlockingResolverErrorTest, self).setUp()
|
||||
self.resolver = BlockingResolver()
|
||||
self.real_getaddrinfo = socket.getaddrinfo
|
||||
socket.getaddrinfo = _failing_getaddrinfo
|
||||
|
||||
def tearDown(self):
|
||||
socket.getaddrinfo = self.real_getaddrinfo
|
||||
super(BlockingResolverErrorTest, self).tearDown()
|
||||
|
||||
|
||||
class OverrideResolverTest(AsyncTestCase, _ResolverTestMixin):
|
||||
def setUp(self):
|
||||
super(OverrideResolverTest, self).setUp()
|
||||
mapping = {
|
||||
('google.com', 80): ('1.2.3.4', 80),
|
||||
('google.com', 80, socket.AF_INET): ('1.2.3.4', 80),
|
||||
('google.com', 80, socket.AF_INET6): ('2a02:6b8:7c:40c:c51e:495f:e23a:3', 80)
|
||||
}
|
||||
self.resolver = OverrideResolver(BlockingResolver(), mapping)
|
||||
|
||||
@gen_test
|
||||
def test_resolve_multiaddr(self):
|
||||
result = yield self.resolver.resolve('google.com', 80, socket.AF_INET)
|
||||
self.assertIn((socket.AF_INET, ('1.2.3.4', 80)), result)
|
||||
|
||||
result = yield self.resolver.resolve('google.com', 80, socket.AF_INET6)
|
||||
self.assertIn((socket.AF_INET6, ('2a02:6b8:7c:40c:c51e:495f:e23a:3', 80, 0, 0)), result)
|
||||
|
||||
|
||||
@skipIfNoNetwork
|
||||
@unittest.skipIf(futures is None, "futures module not present")
|
||||
class ThreadedResolverTest(AsyncTestCase, _ResolverTestMixin):
|
||||
def setUp(self):
|
||||
super(ThreadedResolverTest, self).setUp()
|
||||
self.resolver = ThreadedResolver()
|
||||
|
||||
def tearDown(self):
|
||||
self.resolver.close()
|
||||
super(ThreadedResolverTest, self).tearDown()
|
||||
|
||||
|
||||
class ThreadedResolverErrorTest(AsyncTestCase, _ResolverErrorTestMixin):
|
||||
def setUp(self):
|
||||
super(ThreadedResolverErrorTest, self).setUp()
|
||||
self.resolver = BlockingResolver()
|
||||
self.real_getaddrinfo = socket.getaddrinfo
|
||||
socket.getaddrinfo = _failing_getaddrinfo
|
||||
|
||||
def tearDown(self):
|
||||
socket.getaddrinfo = self.real_getaddrinfo
|
||||
super(ThreadedResolverErrorTest, self).tearDown()
|
||||
|
||||
|
||||
@skipIfNoNetwork
|
||||
@unittest.skipIf(futures is None, "futures module not present")
|
||||
@unittest.skipIf(sys.platform == 'win32', "preexec_fn not available on win32")
|
||||
class ThreadedResolverImportTest(unittest.TestCase):
|
||||
def test_import(self):
|
||||
TIMEOUT = 5
|
||||
|
||||
# Test for a deadlock when importing a module that runs the
|
||||
# ThreadedResolver at import-time. See resolve_test.py for
|
||||
# full explanation.
|
||||
command = [
|
||||
sys.executable,
|
||||
'-c',
|
||||
'import tornado.test.resolve_test_helper']
|
||||
|
||||
start = time.time()
|
||||
popen = Popen(command, preexec_fn=lambda: signal.alarm(TIMEOUT))
|
||||
while time.time() - start < TIMEOUT:
|
||||
return_code = popen.poll()
|
||||
if return_code is not None:
|
||||
self.assertEqual(0, return_code)
|
||||
return # Success.
|
||||
time.sleep(0.05)
|
||||
|
||||
self.fail("import timed out")
|
||||
|
||||
|
||||
# We do not test errors with CaresResolver:
|
||||
# Some DNS-hijacking ISPs (e.g. Time Warner) return non-empty results
|
||||
# with an NXDOMAIN status code. Most resolvers treat this as an error;
|
||||
# C-ares returns the results, making the "bad_host" tests unreliable.
|
||||
# C-ares will try to resolve even malformed names, such as the
|
||||
# name with spaces used in this test.
|
||||
@skipIfNoNetwork
|
||||
@unittest.skipIf(pycares is None, "pycares module not present")
|
||||
class CaresResolverTest(AsyncTestCase, _ResolverTestMixin):
|
||||
def setUp(self):
|
||||
super(CaresResolverTest, self).setUp()
|
||||
self.resolver = CaresResolver()
|
||||
|
||||
|
||||
# TwistedResolver produces consistent errors in our test cases so we
|
||||
# could test the regular and error cases in the same class. However,
|
||||
# in the error cases it appears that cleanup of socket objects is
|
||||
# handled asynchronously and occasionally results in "unclosed socket"
|
||||
# warnings if not given time to shut down (and there is no way to
|
||||
# explicitly shut it down). This makes the test flaky, so we do not
|
||||
# test error cases here.
|
||||
@skipIfNoNetwork
|
||||
@unittest.skipIf(twisted is None, "twisted module not present")
|
||||
@unittest.skipIf(getattr(twisted, '__version__', '0.0') < "12.1", "old version of twisted")
|
||||
class TwistedResolverTest(AsyncTestCase, _ResolverTestMixin):
|
||||
def setUp(self):
|
||||
super(TwistedResolverTest, self).setUp()
|
||||
self.resolver = TwistedResolver()
|
||||
|
||||
|
||||
class IsValidIPTest(unittest.TestCase):
|
||||
def test_is_valid_ip(self):
|
||||
self.assertTrue(is_valid_ip('127.0.0.1'))
|
||||
self.assertTrue(is_valid_ip('4.4.4.4'))
|
||||
self.assertTrue(is_valid_ip('::1'))
|
||||
self.assertTrue(is_valid_ip('2620:0:1cfe:face:b00c::3'))
|
||||
self.assertTrue(not is_valid_ip('www.google.com'))
|
||||
self.assertTrue(not is_valid_ip('localhost'))
|
||||
self.assertTrue(not is_valid_ip('4.4.4.4<'))
|
||||
self.assertTrue(not is_valid_ip(' 127.0.0.1'))
|
||||
self.assertTrue(not is_valid_ip(''))
|
||||
self.assertTrue(not is_valid_ip(' '))
|
||||
self.assertTrue(not is_valid_ip('\n'))
|
||||
self.assertTrue(not is_valid_ip('\x00'))
|
||||
|
||||
|
||||
class TestPortAllocation(unittest.TestCase):
|
||||
def test_same_port_allocation(self):
|
||||
if 'TRAVIS' in os.environ:
|
||||
self.skipTest("dual-stack servers often have port conflicts on travis")
|
||||
sockets = bind_sockets(None, 'localhost')
|
||||
try:
|
||||
port = sockets[0].getsockname()[1]
|
||||
self.assertTrue(all(s.getsockname()[1] == port
|
||||
for s in sockets[1:]))
|
||||
finally:
|
||||
for sock in sockets:
|
||||
sock.close()
|
||||
|
||||
@unittest.skipIf(not hasattr(socket, "SO_REUSEPORT"), "SO_REUSEPORT is not supported")
|
||||
def test_reuse_port(self):
|
||||
sockets = []
|
||||
socket, port = bind_unused_port(reuse_port=True)
|
||||
try:
|
||||
sockets = bind_sockets(port, '127.0.0.1', reuse_port=True)
|
||||
self.assertTrue(all(s.getsockname()[1] == port for s in sockets))
|
||||
finally:
|
||||
socket.close()
|
||||
for sock in sockets:
|
||||
sock.close()
|
||||
Executable
+7
@@ -0,0 +1,7 @@
|
||||
port=443
|
||||
port=443
|
||||
username='李康'
|
||||
|
||||
foo_bar='a'
|
||||
|
||||
my_path = __file__
|
||||
Executable
+327
@@ -0,0 +1,327 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import datetime
|
||||
import os
|
||||
import sys
|
||||
|
||||
from tornado.options import OptionParser, Error
|
||||
from tornado.util import basestring_type, PY3
|
||||
from tornado.test.util import unittest, subTest
|
||||
|
||||
if PY3:
|
||||
from io import StringIO
|
||||
else:
|
||||
from cStringIO import StringIO
|
||||
|
||||
try:
|
||||
# py33+
|
||||
from unittest import mock # type: ignore
|
||||
except ImportError:
|
||||
try:
|
||||
import mock # type: ignore
|
||||
except ImportError:
|
||||
mock = None
|
||||
|
||||
|
||||
class Email(object):
|
||||
def __init__(self, value):
|
||||
if isinstance(value, str) and '@' in value:
|
||||
self._value = value
|
||||
else:
|
||||
raise ValueError()
|
||||
|
||||
@property
|
||||
def value(self):
|
||||
return self._value
|
||||
|
||||
|
||||
class OptionsTest(unittest.TestCase):
|
||||
def test_parse_command_line(self):
|
||||
options = OptionParser()
|
||||
options.define("port", default=80)
|
||||
options.parse_command_line(["main.py", "--port=443"])
|
||||
self.assertEqual(options.port, 443)
|
||||
|
||||
def test_parse_config_file(self):
|
||||
options = OptionParser()
|
||||
options.define("port", default=80)
|
||||
options.define("username", default='foo')
|
||||
options.define("my_path")
|
||||
config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
|
||||
"options_test.cfg")
|
||||
options.parse_config_file(config_path)
|
||||
self.assertEqual(options.port, 443)
|
||||
self.assertEqual(options.username, "李康")
|
||||
self.assertEqual(options.my_path, config_path)
|
||||
|
||||
def test_parse_callbacks(self):
|
||||
options = OptionParser()
|
||||
self.called = False
|
||||
|
||||
def callback():
|
||||
self.called = True
|
||||
options.add_parse_callback(callback)
|
||||
|
||||
# non-final parse doesn't run callbacks
|
||||
options.parse_command_line(["main.py"], final=False)
|
||||
self.assertFalse(self.called)
|
||||
|
||||
# final parse does
|
||||
options.parse_command_line(["main.py"])
|
||||
self.assertTrue(self.called)
|
||||
|
||||
# callbacks can be run more than once on the same options
|
||||
# object if there are multiple final parses
|
||||
self.called = False
|
||||
options.parse_command_line(["main.py"])
|
||||
self.assertTrue(self.called)
|
||||
|
||||
def test_help(self):
|
||||
options = OptionParser()
|
||||
try:
|
||||
orig_stderr = sys.stderr
|
||||
sys.stderr = StringIO()
|
||||
with self.assertRaises(SystemExit):
|
||||
options.parse_command_line(["main.py", "--help"])
|
||||
usage = sys.stderr.getvalue()
|
||||
finally:
|
||||
sys.stderr = orig_stderr
|
||||
self.assertIn("Usage:", usage)
|
||||
|
||||
def test_subcommand(self):
|
||||
base_options = OptionParser()
|
||||
base_options.define("verbose", default=False)
|
||||
sub_options = OptionParser()
|
||||
sub_options.define("foo", type=str)
|
||||
rest = base_options.parse_command_line(
|
||||
["main.py", "--verbose", "subcommand", "--foo=bar"])
|
||||
self.assertEqual(rest, ["subcommand", "--foo=bar"])
|
||||
self.assertTrue(base_options.verbose)
|
||||
rest2 = sub_options.parse_command_line(rest)
|
||||
self.assertEqual(rest2, [])
|
||||
self.assertEqual(sub_options.foo, "bar")
|
||||
|
||||
# the two option sets are distinct
|
||||
try:
|
||||
orig_stderr = sys.stderr
|
||||
sys.stderr = StringIO()
|
||||
with self.assertRaises(Error):
|
||||
sub_options.parse_command_line(["subcommand", "--verbose"])
|
||||
finally:
|
||||
sys.stderr = orig_stderr
|
||||
|
||||
def test_setattr(self):
|
||||
options = OptionParser()
|
||||
options.define('foo', default=1, type=int)
|
||||
options.foo = 2
|
||||
self.assertEqual(options.foo, 2)
|
||||
|
||||
def test_setattr_type_check(self):
|
||||
# setattr requires that options be the right type and doesn't
|
||||
# parse from string formats.
|
||||
options = OptionParser()
|
||||
options.define('foo', default=1, type=int)
|
||||
with self.assertRaises(Error):
|
||||
options.foo = '2'
|
||||
|
||||
def test_setattr_with_callback(self):
|
||||
values = []
|
||||
options = OptionParser()
|
||||
options.define('foo', default=1, type=int, callback=values.append)
|
||||
options.foo = 2
|
||||
self.assertEqual(values, [2])
|
||||
|
||||
def _sample_options(self):
|
||||
options = OptionParser()
|
||||
options.define('a', default=1)
|
||||
options.define('b', default=2)
|
||||
return options
|
||||
|
||||
def test_iter(self):
|
||||
options = self._sample_options()
|
||||
# OptionParsers always define 'help'.
|
||||
self.assertEqual(set(['a', 'b', 'help']), set(iter(options)))
|
||||
|
||||
def test_getitem(self):
|
||||
options = self._sample_options()
|
||||
self.assertEqual(1, options['a'])
|
||||
|
||||
def test_setitem(self):
|
||||
options = OptionParser()
|
||||
options.define('foo', default=1, type=int)
|
||||
options['foo'] = 2
|
||||
self.assertEqual(options['foo'], 2)
|
||||
|
||||
def test_items(self):
|
||||
options = self._sample_options()
|
||||
# OptionParsers always define 'help'.
|
||||
expected = [('a', 1), ('b', 2), ('help', options.help)]
|
||||
actual = sorted(options.items())
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
def test_as_dict(self):
|
||||
options = self._sample_options()
|
||||
expected = {'a': 1, 'b': 2, 'help': options.help}
|
||||
self.assertEqual(expected, options.as_dict())
|
||||
|
||||
def test_group_dict(self):
|
||||
options = OptionParser()
|
||||
options.define('a', default=1)
|
||||
options.define('b', group='b_group', default=2)
|
||||
|
||||
frame = sys._getframe(0)
|
||||
this_file = frame.f_code.co_filename
|
||||
self.assertEqual(set(['b_group', '', this_file]), options.groups())
|
||||
|
||||
b_group_dict = options.group_dict('b_group')
|
||||
self.assertEqual({'b': 2}, b_group_dict)
|
||||
|
||||
self.assertEqual({}, options.group_dict('nonexistent'))
|
||||
|
||||
@unittest.skipIf(mock is None, 'mock package not present')
|
||||
def test_mock_patch(self):
|
||||
# ensure that our setattr hooks don't interfere with mock.patch
|
||||
options = OptionParser()
|
||||
options.define('foo', default=1)
|
||||
options.parse_command_line(['main.py', '--foo=2'])
|
||||
self.assertEqual(options.foo, 2)
|
||||
|
||||
with mock.patch.object(options.mockable(), 'foo', 3):
|
||||
self.assertEqual(options.foo, 3)
|
||||
self.assertEqual(options.foo, 2)
|
||||
|
||||
# Try nested patches mixed with explicit sets
|
||||
with mock.patch.object(options.mockable(), 'foo', 4):
|
||||
self.assertEqual(options.foo, 4)
|
||||
options.foo = 5
|
||||
self.assertEqual(options.foo, 5)
|
||||
with mock.patch.object(options.mockable(), 'foo', 6):
|
||||
self.assertEqual(options.foo, 6)
|
||||
self.assertEqual(options.foo, 5)
|
||||
self.assertEqual(options.foo, 2)
|
||||
|
||||
def _define_options(self):
|
||||
options = OptionParser()
|
||||
options.define('str', type=str)
|
||||
options.define('basestring', type=basestring_type)
|
||||
options.define('int', type=int)
|
||||
options.define('float', type=float)
|
||||
options.define('datetime', type=datetime.datetime)
|
||||
options.define('timedelta', type=datetime.timedelta)
|
||||
options.define('email', type=Email)
|
||||
options.define('list-of-int', type=int, multiple=True)
|
||||
return options
|
||||
|
||||
def _check_options_values(self, options):
|
||||
self.assertEqual(options.str, 'asdf')
|
||||
self.assertEqual(options.basestring, 'qwer')
|
||||
self.assertEqual(options.int, 42)
|
||||
self.assertEqual(options.float, 1.5)
|
||||
self.assertEqual(options.datetime,
|
||||
datetime.datetime(2013, 4, 28, 5, 16))
|
||||
self.assertEqual(options.timedelta, datetime.timedelta(seconds=45))
|
||||
self.assertEqual(options.email.value, 'tornado@web.com')
|
||||
self.assertTrue(isinstance(options.email, Email))
|
||||
self.assertEqual(options.list_of_int, [1, 2, 3])
|
||||
|
||||
def test_types(self):
|
||||
options = self._define_options()
|
||||
options.parse_command_line(['main.py',
|
||||
'--str=asdf',
|
||||
'--basestring=qwer',
|
||||
'--int=42',
|
||||
'--float=1.5',
|
||||
'--datetime=2013-04-28 05:16',
|
||||
'--timedelta=45s',
|
||||
'--email=tornado@web.com',
|
||||
'--list-of-int=1,2,3'])
|
||||
self._check_options_values(options)
|
||||
|
||||
def test_types_with_conf_file(self):
|
||||
for config_file_name in ("options_test_types.cfg",
|
||||
"options_test_types_str.cfg"):
|
||||
options = self._define_options()
|
||||
options.parse_config_file(os.path.join(os.path.dirname(__file__),
|
||||
config_file_name))
|
||||
self._check_options_values(options)
|
||||
|
||||
def test_multiple_string(self):
|
||||
options = OptionParser()
|
||||
options.define('foo', type=str, multiple=True)
|
||||
options.parse_command_line(['main.py', '--foo=a,b,c'])
|
||||
self.assertEqual(options.foo, ['a', 'b', 'c'])
|
||||
|
||||
def test_multiple_int(self):
|
||||
options = OptionParser()
|
||||
options.define('foo', type=int, multiple=True)
|
||||
options.parse_command_line(['main.py', '--foo=1,3,5:7'])
|
||||
self.assertEqual(options.foo, [1, 3, 5, 6, 7])
|
||||
|
||||
def test_error_redefine(self):
|
||||
options = OptionParser()
|
||||
options.define('foo')
|
||||
with self.assertRaises(Error) as cm:
|
||||
options.define('foo')
|
||||
self.assertRegexpMatches(str(cm.exception),
|
||||
'Option.*foo.*already defined')
|
||||
|
||||
def test_error_redefine_underscore(self):
|
||||
# Ensure that the dash/underscore normalization doesn't
|
||||
# interfere with the redefinition error.
|
||||
tests = [
|
||||
('foo-bar', 'foo-bar'),
|
||||
('foo_bar', 'foo_bar'),
|
||||
('foo-bar', 'foo_bar'),
|
||||
('foo_bar', 'foo-bar'),
|
||||
]
|
||||
for a, b in tests:
|
||||
with subTest(self, a=a, b=b):
|
||||
options = OptionParser()
|
||||
options.define(a)
|
||||
with self.assertRaises(Error) as cm:
|
||||
options.define(b)
|
||||
self.assertRegexpMatches(str(cm.exception),
|
||||
'Option.*foo.bar.*already defined')
|
||||
|
||||
def test_dash_underscore_cli(self):
|
||||
# Dashes and underscores should be interchangeable.
|
||||
for defined_name in ['foo-bar', 'foo_bar']:
|
||||
for flag in ['--foo-bar=a', '--foo_bar=a']:
|
||||
options = OptionParser()
|
||||
options.define(defined_name)
|
||||
options.parse_command_line(['main.py', flag])
|
||||
# Attr-style access always uses underscores.
|
||||
self.assertEqual(options.foo_bar, 'a')
|
||||
# Dict-style access allows both.
|
||||
self.assertEqual(options['foo-bar'], 'a')
|
||||
self.assertEqual(options['foo_bar'], 'a')
|
||||
|
||||
def test_dash_underscore_file(self):
|
||||
# No matter how an option was defined, it can be set with underscores
|
||||
# in a config file.
|
||||
for defined_name in ['foo-bar', 'foo_bar']:
|
||||
options = OptionParser()
|
||||
options.define(defined_name)
|
||||
options.parse_config_file(os.path.join(os.path.dirname(__file__),
|
||||
"options_test.cfg"))
|
||||
self.assertEqual(options.foo_bar, 'a')
|
||||
|
||||
def test_dash_underscore_introspection(self):
|
||||
# Original names are preserved in introspection APIs.
|
||||
options = OptionParser()
|
||||
options.define('with-dash', group='g')
|
||||
options.define('with_underscore', group='g')
|
||||
all_options = ['help', 'with-dash', 'with_underscore']
|
||||
self.assertEqual(sorted(options), all_options)
|
||||
self.assertEqual(sorted(k for (k, v) in options.items()), all_options)
|
||||
self.assertEqual(sorted(options.as_dict().keys()), all_options)
|
||||
|
||||
self.assertEqual(sorted(options.group_dict('g')),
|
||||
['with-dash', 'with_underscore'])
|
||||
|
||||
# --help shows CLI-style names with dashes.
|
||||
buf = StringIO()
|
||||
options.print_help(buf)
|
||||
self.assertIn('--with-dash', buf.getvalue())
|
||||
self.assertIn('--with-underscore', buf.getvalue())
|
||||
Executable
+11
@@ -0,0 +1,11 @@
|
||||
from datetime import datetime, timedelta
|
||||
from tornado.test.options_test import Email
|
||||
|
||||
str = 'asdf'
|
||||
basestring = 'qwer'
|
||||
int = 42
|
||||
float = 1.5
|
||||
datetime = datetime(2013, 4, 28, 5, 16)
|
||||
timedelta = timedelta(0, 45)
|
||||
email = Email('tornado@web.com')
|
||||
list_of_int = [1, 2, 3]
|
||||
Executable
+8
@@ -0,0 +1,8 @@
|
||||
str = 'asdf'
|
||||
basestring = 'qwer'
|
||||
int = 42
|
||||
float = 1.5
|
||||
datetime = '2013-04-28 05:16'
|
||||
timedelta = '45s'
|
||||
email = 'tornado@web.com'
|
||||
list_of_int = '1,2,3'
|
||||
Executable
+266
@@ -0,0 +1,266 @@
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import logging
|
||||
import os
|
||||
import signal
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
from tornado.httpclient import HTTPClient, HTTPError
|
||||
from tornado.httpserver import HTTPServer
|
||||
from tornado.ioloop import IOLoop
|
||||
from tornado.log import gen_log
|
||||
from tornado.process import fork_processes, task_id, Subprocess
|
||||
from tornado.simple_httpclient import SimpleAsyncHTTPClient
|
||||
from tornado.testing import bind_unused_port, ExpectLog, AsyncTestCase, gen_test
|
||||
from tornado.test.util import unittest, skipIfNonUnix
|
||||
from tornado.web import RequestHandler, Application
|
||||
|
||||
try:
|
||||
import asyncio
|
||||
except ImportError:
|
||||
asyncio = None
|
||||
|
||||
|
||||
def skip_if_twisted():
|
||||
if IOLoop.configured_class().__name__.endswith('TwistedIOLoop'):
|
||||
raise unittest.SkipTest("Process tests not compatible with TwistedIOLoop")
|
||||
|
||||
# Not using AsyncHTTPTestCase because we need control over the IOLoop.
|
||||
|
||||
|
||||
@skipIfNonUnix
|
||||
class ProcessTest(unittest.TestCase):
|
||||
def get_app(self):
|
||||
class ProcessHandler(RequestHandler):
|
||||
def get(self):
|
||||
if self.get_argument("exit", None):
|
||||
# must use os._exit instead of sys.exit so unittest's
|
||||
# exception handler doesn't catch it
|
||||
os._exit(int(self.get_argument("exit")))
|
||||
if self.get_argument("signal", None):
|
||||
os.kill(os.getpid(),
|
||||
int(self.get_argument("signal")))
|
||||
self.write(str(os.getpid()))
|
||||
return Application([("/", ProcessHandler)])
|
||||
|
||||
def tearDown(self):
|
||||
if task_id() is not None:
|
||||
# We're in a child process, and probably got to this point
|
||||
# via an uncaught exception. If we return now, both
|
||||
# processes will continue with the rest of the test suite.
|
||||
# Exit now so the parent process will restart the child
|
||||
# (since we don't have a clean way to signal failure to
|
||||
# the parent that won't restart)
|
||||
logging.error("aborting child process from tearDown")
|
||||
logging.shutdown()
|
||||
os._exit(1)
|
||||
# In the surviving process, clear the alarm we set earlier
|
||||
signal.alarm(0)
|
||||
super(ProcessTest, self).tearDown()
|
||||
|
||||
def test_multi_process(self):
|
||||
# This test doesn't work on twisted because we use the global
|
||||
# reactor and don't restore it to a sane state after the fork
|
||||
# (asyncio has the same issue, but we have a special case in
|
||||
# place for it).
|
||||
skip_if_twisted()
|
||||
with ExpectLog(gen_log, "(Starting .* processes|child .* exited|uncaught exception)"):
|
||||
sock, port = bind_unused_port()
|
||||
|
||||
def get_url(path):
|
||||
return "http://127.0.0.1:%d%s" % (port, path)
|
||||
# ensure that none of these processes live too long
|
||||
signal.alarm(5) # master process
|
||||
try:
|
||||
id = fork_processes(3, max_restarts=3)
|
||||
self.assertTrue(id is not None)
|
||||
signal.alarm(5) # child processes
|
||||
except SystemExit as e:
|
||||
# if we exit cleanly from fork_processes, all the child processes
|
||||
# finished with status 0
|
||||
self.assertEqual(e.code, 0)
|
||||
self.assertTrue(task_id() is None)
|
||||
sock.close()
|
||||
return
|
||||
try:
|
||||
if asyncio is not None:
|
||||
# Reset the global asyncio event loop, which was put into
|
||||
# a broken state by the fork.
|
||||
asyncio.set_event_loop(asyncio.new_event_loop())
|
||||
if id in (0, 1):
|
||||
self.assertEqual(id, task_id())
|
||||
server = HTTPServer(self.get_app())
|
||||
server.add_sockets([sock])
|
||||
IOLoop.current().start()
|
||||
elif id == 2:
|
||||
self.assertEqual(id, task_id())
|
||||
sock.close()
|
||||
# Always use SimpleAsyncHTTPClient here; the curl
|
||||
# version appears to get confused sometimes if the
|
||||
# connection gets closed before it's had a chance to
|
||||
# switch from writing mode to reading mode.
|
||||
client = HTTPClient(SimpleAsyncHTTPClient)
|
||||
|
||||
def fetch(url, fail_ok=False):
|
||||
try:
|
||||
return client.fetch(get_url(url))
|
||||
except HTTPError as e:
|
||||
if not (fail_ok and e.code == 599):
|
||||
raise
|
||||
|
||||
# Make two processes exit abnormally
|
||||
fetch("/?exit=2", fail_ok=True)
|
||||
fetch("/?exit=3", fail_ok=True)
|
||||
|
||||
# They've been restarted, so a new fetch will work
|
||||
int(fetch("/").body)
|
||||
|
||||
# Now the same with signals
|
||||
# Disabled because on the mac a process dying with a signal
|
||||
# can trigger an "Application exited abnormally; send error
|
||||
# report to Apple?" prompt.
|
||||
# fetch("/?signal=%d" % signal.SIGTERM, fail_ok=True)
|
||||
# fetch("/?signal=%d" % signal.SIGABRT, fail_ok=True)
|
||||
# int(fetch("/").body)
|
||||
|
||||
# Now kill them normally so they won't be restarted
|
||||
fetch("/?exit=0", fail_ok=True)
|
||||
# One process left; watch it's pid change
|
||||
pid = int(fetch("/").body)
|
||||
fetch("/?exit=4", fail_ok=True)
|
||||
pid2 = int(fetch("/").body)
|
||||
self.assertNotEqual(pid, pid2)
|
||||
|
||||
# Kill the last one so we shut down cleanly
|
||||
fetch("/?exit=0", fail_ok=True)
|
||||
|
||||
os._exit(0)
|
||||
except Exception:
|
||||
logging.error("exception in child process %d", id, exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
@skipIfNonUnix
|
||||
class SubprocessTest(AsyncTestCase):
|
||||
@gen_test
|
||||
def test_subprocess(self):
|
||||
if IOLoop.configured_class().__name__.endswith('LayeredTwistedIOLoop'):
|
||||
# This test fails non-deterministically with LayeredTwistedIOLoop.
|
||||
# (the read_until('\n') returns '\n' instead of 'hello\n')
|
||||
# This probably indicates a problem with either TornadoReactor
|
||||
# or TwistedIOLoop, but I haven't been able to track it down
|
||||
# and for now this is just causing spurious travis-ci failures.
|
||||
raise unittest.SkipTest("Subprocess tests not compatible with "
|
||||
"LayeredTwistedIOLoop")
|
||||
subproc = Subprocess([sys.executable, '-u', '-i'],
|
||||
stdin=Subprocess.STREAM,
|
||||
stdout=Subprocess.STREAM, stderr=subprocess.STDOUT)
|
||||
self.addCleanup(lambda: (subproc.proc.terminate(), subproc.proc.wait()))
|
||||
self.addCleanup(subproc.stdout.close)
|
||||
self.addCleanup(subproc.stdin.close)
|
||||
yield subproc.stdout.read_until(b'>>> ')
|
||||
subproc.stdin.write(b"print('hello')\n")
|
||||
data = yield subproc.stdout.read_until(b'\n')
|
||||
self.assertEqual(data, b"hello\n")
|
||||
|
||||
yield subproc.stdout.read_until(b">>> ")
|
||||
subproc.stdin.write(b"raise SystemExit\n")
|
||||
data = yield subproc.stdout.read_until_close()
|
||||
self.assertEqual(data, b"")
|
||||
|
||||
@gen_test
|
||||
def test_close_stdin(self):
|
||||
# Close the parent's stdin handle and see that the child recognizes it.
|
||||
subproc = Subprocess([sys.executable, '-u', '-i'],
|
||||
stdin=Subprocess.STREAM,
|
||||
stdout=Subprocess.STREAM, stderr=subprocess.STDOUT)
|
||||
self.addCleanup(lambda: (subproc.proc.terminate(), subproc.proc.wait()))
|
||||
yield subproc.stdout.read_until(b'>>> ')
|
||||
subproc.stdin.close()
|
||||
data = yield subproc.stdout.read_until_close()
|
||||
self.assertEqual(data, b"\n")
|
||||
|
||||
@gen_test
|
||||
def test_stderr(self):
|
||||
# This test is mysteriously flaky on twisted: it succeeds, but logs
|
||||
# an error of EBADF on closing a file descriptor.
|
||||
skip_if_twisted()
|
||||
subproc = Subprocess([sys.executable, '-u', '-c',
|
||||
r"import sys; sys.stderr.write('hello\n')"],
|
||||
stderr=Subprocess.STREAM)
|
||||
self.addCleanup(lambda: (subproc.proc.terminate(), subproc.proc.wait()))
|
||||
data = yield subproc.stderr.read_until(b'\n')
|
||||
self.assertEqual(data, b'hello\n')
|
||||
# More mysterious EBADF: This fails if done with self.addCleanup instead of here.
|
||||
subproc.stderr.close()
|
||||
|
||||
def test_sigchild(self):
|
||||
# Twisted's SIGCHLD handler and Subprocess's conflict with each other.
|
||||
skip_if_twisted()
|
||||
Subprocess.initialize()
|
||||
self.addCleanup(Subprocess.uninitialize)
|
||||
subproc = Subprocess([sys.executable, '-c', 'pass'])
|
||||
subproc.set_exit_callback(self.stop)
|
||||
ret = self.wait()
|
||||
self.assertEqual(ret, 0)
|
||||
self.assertEqual(subproc.returncode, ret)
|
||||
|
||||
@gen_test
|
||||
def test_sigchild_future(self):
|
||||
skip_if_twisted()
|
||||
Subprocess.initialize()
|
||||
self.addCleanup(Subprocess.uninitialize)
|
||||
subproc = Subprocess([sys.executable, '-c', 'pass'])
|
||||
ret = yield subproc.wait_for_exit()
|
||||
self.assertEqual(ret, 0)
|
||||
self.assertEqual(subproc.returncode, ret)
|
||||
|
||||
def test_sigchild_signal(self):
|
||||
skip_if_twisted()
|
||||
Subprocess.initialize()
|
||||
self.addCleanup(Subprocess.uninitialize)
|
||||
subproc = Subprocess([sys.executable, '-c',
|
||||
'import time; time.sleep(30)'],
|
||||
stdout=Subprocess.STREAM)
|
||||
self.addCleanup(subproc.stdout.close)
|
||||
subproc.set_exit_callback(self.stop)
|
||||
os.kill(subproc.pid, signal.SIGTERM)
|
||||
try:
|
||||
ret = self.wait(timeout=1.0)
|
||||
except AssertionError:
|
||||
# We failed to get the termination signal. This test is
|
||||
# occasionally flaky on pypy, so try to get a little more
|
||||
# information: did the process close its stdout
|
||||
# (indicating that the problem is in the parent process's
|
||||
# signal handling) or did the child process somehow fail
|
||||
# to terminate?
|
||||
subproc.stdout.read_until_close(callback=self.stop)
|
||||
try:
|
||||
self.wait(timeout=1.0)
|
||||
except AssertionError:
|
||||
raise AssertionError("subprocess failed to terminate")
|
||||
else:
|
||||
raise AssertionError("subprocess closed stdout but failed to "
|
||||
"get termination signal")
|
||||
self.assertEqual(subproc.returncode, ret)
|
||||
self.assertEqual(ret, -signal.SIGTERM)
|
||||
|
||||
@gen_test
|
||||
def test_wait_for_exit_raise(self):
|
||||
skip_if_twisted()
|
||||
Subprocess.initialize()
|
||||
self.addCleanup(Subprocess.uninitialize)
|
||||
subproc = Subprocess([sys.executable, '-c', 'import sys; sys.exit(1)'])
|
||||
with self.assertRaises(subprocess.CalledProcessError) as cm:
|
||||
yield subproc.wait_for_exit()
|
||||
self.assertEqual(cm.exception.returncode, 1)
|
||||
|
||||
@gen_test
|
||||
def test_wait_for_exit_raise_disabled(self):
|
||||
skip_if_twisted()
|
||||
Subprocess.initialize()
|
||||
self.addCleanup(Subprocess.uninitialize)
|
||||
subproc = Subprocess([sys.executable, '-c', 'import sys; sys.exit(1)'])
|
||||
ret = yield subproc.wait_for_exit(raise_error=False)
|
||||
self.assertEqual(ret, 1)
|
||||
Executable
+423
@@ -0,0 +1,423 @@
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
from datetime import timedelta
|
||||
from random import random
|
||||
|
||||
from tornado import gen, queues
|
||||
from tornado.gen import TimeoutError
|
||||
from tornado.testing import gen_test, AsyncTestCase
|
||||
from tornado.test.util import unittest, skipBefore35, exec_test
|
||||
|
||||
|
||||
class QueueBasicTest(AsyncTestCase):
|
||||
def test_repr_and_str(self):
|
||||
q = queues.Queue(maxsize=1)
|
||||
self.assertIn(hex(id(q)), repr(q))
|
||||
self.assertNotIn(hex(id(q)), str(q))
|
||||
q.get()
|
||||
|
||||
for q_str in repr(q), str(q):
|
||||
self.assertTrue(q_str.startswith('<Queue'))
|
||||
self.assertIn('maxsize=1', q_str)
|
||||
self.assertIn('getters[1]', q_str)
|
||||
self.assertNotIn('putters', q_str)
|
||||
self.assertNotIn('tasks', q_str)
|
||||
|
||||
q.put(None)
|
||||
q.put(None)
|
||||
# Now the queue is full, this putter blocks.
|
||||
q.put(None)
|
||||
|
||||
for q_str in repr(q), str(q):
|
||||
self.assertNotIn('getters', q_str)
|
||||
self.assertIn('putters[1]', q_str)
|
||||
self.assertIn('tasks=2', q_str)
|
||||
|
||||
def test_order(self):
|
||||
q = queues.Queue()
|
||||
for i in [1, 3, 2]:
|
||||
q.put_nowait(i)
|
||||
|
||||
items = [q.get_nowait() for _ in range(3)]
|
||||
self.assertEqual([1, 3, 2], items)
|
||||
|
||||
@gen_test
|
||||
def test_maxsize(self):
|
||||
self.assertRaises(TypeError, queues.Queue, maxsize=None)
|
||||
self.assertRaises(ValueError, queues.Queue, maxsize=-1)
|
||||
|
||||
q = queues.Queue(maxsize=2)
|
||||
self.assertTrue(q.empty())
|
||||
self.assertFalse(q.full())
|
||||
self.assertEqual(2, q.maxsize)
|
||||
self.assertTrue(q.put(0).done())
|
||||
self.assertTrue(q.put(1).done())
|
||||
self.assertFalse(q.empty())
|
||||
self.assertTrue(q.full())
|
||||
put2 = q.put(2)
|
||||
self.assertFalse(put2.done())
|
||||
self.assertEqual(0, (yield q.get())) # Make room.
|
||||
self.assertTrue(put2.done())
|
||||
self.assertFalse(q.empty())
|
||||
self.assertTrue(q.full())
|
||||
|
||||
|
||||
class QueueGetTest(AsyncTestCase):
|
||||
@gen_test
|
||||
def test_blocking_get(self):
|
||||
q = queues.Queue()
|
||||
q.put_nowait(0)
|
||||
self.assertEqual(0, (yield q.get()))
|
||||
|
||||
def test_nonblocking_get(self):
|
||||
q = queues.Queue()
|
||||
q.put_nowait(0)
|
||||
self.assertEqual(0, q.get_nowait())
|
||||
|
||||
def test_nonblocking_get_exception(self):
|
||||
q = queues.Queue()
|
||||
self.assertRaises(queues.QueueEmpty, q.get_nowait)
|
||||
|
||||
@gen_test
|
||||
def test_get_with_putters(self):
|
||||
q = queues.Queue(1)
|
||||
q.put_nowait(0)
|
||||
put = q.put(1)
|
||||
self.assertEqual(0, (yield q.get()))
|
||||
self.assertIsNone((yield put))
|
||||
|
||||
@gen_test
|
||||
def test_blocking_get_wait(self):
|
||||
q = queues.Queue()
|
||||
q.put(0)
|
||||
self.io_loop.call_later(0.01, q.put, 1)
|
||||
self.io_loop.call_later(0.02, q.put, 2)
|
||||
self.assertEqual(0, (yield q.get(timeout=timedelta(seconds=1))))
|
||||
self.assertEqual(1, (yield q.get(timeout=timedelta(seconds=1))))
|
||||
|
||||
@gen_test
|
||||
def test_get_timeout(self):
|
||||
q = queues.Queue()
|
||||
get_timeout = q.get(timeout=timedelta(seconds=0.01))
|
||||
get = q.get()
|
||||
with self.assertRaises(TimeoutError):
|
||||
yield get_timeout
|
||||
|
||||
q.put_nowait(0)
|
||||
self.assertEqual(0, (yield get))
|
||||
|
||||
@gen_test
|
||||
def test_get_timeout_preempted(self):
|
||||
q = queues.Queue()
|
||||
get = q.get(timeout=timedelta(seconds=0.01))
|
||||
q.put(0)
|
||||
yield gen.sleep(0.02)
|
||||
self.assertEqual(0, (yield get))
|
||||
|
||||
@gen_test
|
||||
def test_get_clears_timed_out_putters(self):
|
||||
q = queues.Queue(1)
|
||||
# First putter succeeds, remainder block.
|
||||
putters = [q.put(i, timedelta(seconds=0.01)) for i in range(10)]
|
||||
put = q.put(10)
|
||||
self.assertEqual(10, len(q._putters))
|
||||
yield gen.sleep(0.02)
|
||||
self.assertEqual(10, len(q._putters))
|
||||
self.assertFalse(put.done()) # Final waiter is still active.
|
||||
q.put(11)
|
||||
self.assertEqual(0, (yield q.get())) # get() clears the waiters.
|
||||
self.assertEqual(1, len(q._putters))
|
||||
for putter in putters[1:]:
|
||||
self.assertRaises(TimeoutError, putter.result)
|
||||
|
||||
@gen_test
|
||||
def test_get_clears_timed_out_getters(self):
|
||||
q = queues.Queue()
|
||||
getters = [q.get(timedelta(seconds=0.01)) for _ in range(10)]
|
||||
get = q.get()
|
||||
self.assertEqual(11, len(q._getters))
|
||||
yield gen.sleep(0.02)
|
||||
self.assertEqual(11, len(q._getters))
|
||||
self.assertFalse(get.done()) # Final waiter is still active.
|
||||
q.get() # get() clears the waiters.
|
||||
self.assertEqual(2, len(q._getters))
|
||||
for getter in getters:
|
||||
self.assertRaises(TimeoutError, getter.result)
|
||||
|
||||
@skipBefore35
|
||||
@gen_test
|
||||
def test_async_for(self):
|
||||
q = queues.Queue()
|
||||
for i in range(5):
|
||||
q.put(i)
|
||||
|
||||
namespace = exec_test(globals(), locals(), """
|
||||
async def f():
|
||||
results = []
|
||||
async for i in q:
|
||||
results.append(i)
|
||||
if i == 4:
|
||||
return results
|
||||
""")
|
||||
results = yield namespace['f']()
|
||||
self.assertEqual(results, list(range(5)))
|
||||
|
||||
|
||||
class QueuePutTest(AsyncTestCase):
|
||||
@gen_test
|
||||
def test_blocking_put(self):
|
||||
q = queues.Queue()
|
||||
q.put(0)
|
||||
self.assertEqual(0, q.get_nowait())
|
||||
|
||||
def test_nonblocking_put_exception(self):
|
||||
q = queues.Queue(1)
|
||||
q.put(0)
|
||||
self.assertRaises(queues.QueueFull, q.put_nowait, 1)
|
||||
|
||||
@gen_test
|
||||
def test_put_with_getters(self):
|
||||
q = queues.Queue()
|
||||
get0 = q.get()
|
||||
get1 = q.get()
|
||||
yield q.put(0)
|
||||
self.assertEqual(0, (yield get0))
|
||||
yield q.put(1)
|
||||
self.assertEqual(1, (yield get1))
|
||||
|
||||
@gen_test
|
||||
def test_nonblocking_put_with_getters(self):
|
||||
q = queues.Queue()
|
||||
get0 = q.get()
|
||||
get1 = q.get()
|
||||
q.put_nowait(0)
|
||||
# put_nowait does *not* immediately unblock getters.
|
||||
yield gen.moment
|
||||
self.assertEqual(0, (yield get0))
|
||||
q.put_nowait(1)
|
||||
yield gen.moment
|
||||
self.assertEqual(1, (yield get1))
|
||||
|
||||
@gen_test
|
||||
def test_blocking_put_wait(self):
|
||||
q = queues.Queue(1)
|
||||
q.put_nowait(0)
|
||||
self.io_loop.call_later(0.01, q.get)
|
||||
self.io_loop.call_later(0.02, q.get)
|
||||
futures = [q.put(0), q.put(1)]
|
||||
self.assertFalse(any(f.done() for f in futures))
|
||||
yield futures
|
||||
|
||||
@gen_test
|
||||
def test_put_timeout(self):
|
||||
q = queues.Queue(1)
|
||||
q.put_nowait(0) # Now it's full.
|
||||
put_timeout = q.put(1, timeout=timedelta(seconds=0.01))
|
||||
put = q.put(2)
|
||||
with self.assertRaises(TimeoutError):
|
||||
yield put_timeout
|
||||
|
||||
self.assertEqual(0, q.get_nowait())
|
||||
# 1 was never put in the queue.
|
||||
self.assertEqual(2, (yield q.get()))
|
||||
|
||||
# Final get() unblocked this putter.
|
||||
yield put
|
||||
|
||||
@gen_test
|
||||
def test_put_timeout_preempted(self):
|
||||
q = queues.Queue(1)
|
||||
q.put_nowait(0)
|
||||
put = q.put(1, timeout=timedelta(seconds=0.01))
|
||||
q.get()
|
||||
yield gen.sleep(0.02)
|
||||
yield put # No TimeoutError.
|
||||
|
||||
@gen_test
|
||||
def test_put_clears_timed_out_putters(self):
|
||||
q = queues.Queue(1)
|
||||
# First putter succeeds, remainder block.
|
||||
putters = [q.put(i, timedelta(seconds=0.01)) for i in range(10)]
|
||||
put = q.put(10)
|
||||
self.assertEqual(10, len(q._putters))
|
||||
yield gen.sleep(0.02)
|
||||
self.assertEqual(10, len(q._putters))
|
||||
self.assertFalse(put.done()) # Final waiter is still active.
|
||||
q.put(11) # put() clears the waiters.
|
||||
self.assertEqual(2, len(q._putters))
|
||||
for putter in putters[1:]:
|
||||
self.assertRaises(TimeoutError, putter.result)
|
||||
|
||||
@gen_test
|
||||
def test_put_clears_timed_out_getters(self):
|
||||
q = queues.Queue()
|
||||
getters = [q.get(timedelta(seconds=0.01)) for _ in range(10)]
|
||||
get = q.get()
|
||||
q.get()
|
||||
self.assertEqual(12, len(q._getters))
|
||||
yield gen.sleep(0.02)
|
||||
self.assertEqual(12, len(q._getters))
|
||||
self.assertFalse(get.done()) # Final waiters still active.
|
||||
q.put(0) # put() clears the waiters.
|
||||
self.assertEqual(1, len(q._getters))
|
||||
self.assertEqual(0, (yield get))
|
||||
for getter in getters:
|
||||
self.assertRaises(TimeoutError, getter.result)
|
||||
|
||||
@gen_test
|
||||
def test_float_maxsize(self):
|
||||
# Non-int maxsize must round down: http://bugs.python.org/issue21723
|
||||
q = queues.Queue(maxsize=1.3)
|
||||
self.assertTrue(q.empty())
|
||||
self.assertFalse(q.full())
|
||||
q.put_nowait(0)
|
||||
q.put_nowait(1)
|
||||
self.assertFalse(q.empty())
|
||||
self.assertTrue(q.full())
|
||||
self.assertRaises(queues.QueueFull, q.put_nowait, 2)
|
||||
self.assertEqual(0, q.get_nowait())
|
||||
self.assertFalse(q.empty())
|
||||
self.assertFalse(q.full())
|
||||
|
||||
yield q.put(2)
|
||||
put = q.put(3)
|
||||
self.assertFalse(put.done())
|
||||
self.assertEqual(1, (yield q.get()))
|
||||
yield put
|
||||
self.assertTrue(q.full())
|
||||
|
||||
|
||||
class QueueJoinTest(AsyncTestCase):
|
||||
queue_class = queues.Queue
|
||||
|
||||
def test_task_done_underflow(self):
|
||||
q = self.queue_class()
|
||||
self.assertRaises(ValueError, q.task_done)
|
||||
|
||||
@gen_test
|
||||
def test_task_done(self):
|
||||
q = self.queue_class()
|
||||
for i in range(100):
|
||||
q.put_nowait(i)
|
||||
|
||||
self.accumulator = 0
|
||||
|
||||
@gen.coroutine
|
||||
def worker():
|
||||
while True:
|
||||
item = yield q.get()
|
||||
self.accumulator += item
|
||||
q.task_done()
|
||||
yield gen.sleep(random() * 0.01)
|
||||
|
||||
# Two coroutines share work.
|
||||
worker()
|
||||
worker()
|
||||
yield q.join()
|
||||
self.assertEqual(sum(range(100)), self.accumulator)
|
||||
|
||||
@gen_test
|
||||
def test_task_done_delay(self):
|
||||
# Verify it is task_done(), not get(), that unblocks join().
|
||||
q = self.queue_class()
|
||||
q.put_nowait(0)
|
||||
join = q.join()
|
||||
self.assertFalse(join.done())
|
||||
yield q.get()
|
||||
self.assertFalse(join.done())
|
||||
yield gen.moment
|
||||
self.assertFalse(join.done())
|
||||
q.task_done()
|
||||
self.assertTrue(join.done())
|
||||
|
||||
@gen_test
|
||||
def test_join_empty_queue(self):
|
||||
q = self.queue_class()
|
||||
yield q.join()
|
||||
yield q.join()
|
||||
|
||||
@gen_test
|
||||
def test_join_timeout(self):
|
||||
q = self.queue_class()
|
||||
q.put(0)
|
||||
with self.assertRaises(TimeoutError):
|
||||
yield q.join(timeout=timedelta(seconds=0.01))
|
||||
|
||||
|
||||
class PriorityQueueJoinTest(QueueJoinTest):
|
||||
queue_class = queues.PriorityQueue
|
||||
|
||||
@gen_test
|
||||
def test_order(self):
|
||||
q = self.queue_class(maxsize=2)
|
||||
q.put_nowait((1, 'a'))
|
||||
q.put_nowait((0, 'b'))
|
||||
self.assertTrue(q.full())
|
||||
q.put((3, 'c'))
|
||||
q.put((2, 'd'))
|
||||
self.assertEqual((0, 'b'), q.get_nowait())
|
||||
self.assertEqual((1, 'a'), (yield q.get()))
|
||||
self.assertEqual((2, 'd'), q.get_nowait())
|
||||
self.assertEqual((3, 'c'), (yield q.get()))
|
||||
self.assertTrue(q.empty())
|
||||
|
||||
|
||||
class LifoQueueJoinTest(QueueJoinTest):
|
||||
queue_class = queues.LifoQueue
|
||||
|
||||
@gen_test
|
||||
def test_order(self):
|
||||
q = self.queue_class(maxsize=2)
|
||||
q.put_nowait(1)
|
||||
q.put_nowait(0)
|
||||
self.assertTrue(q.full())
|
||||
q.put(3)
|
||||
q.put(2)
|
||||
self.assertEqual(3, q.get_nowait())
|
||||
self.assertEqual(2, (yield q.get()))
|
||||
self.assertEqual(0, q.get_nowait())
|
||||
self.assertEqual(1, (yield q.get()))
|
||||
self.assertTrue(q.empty())
|
||||
|
||||
|
||||
class ProducerConsumerTest(AsyncTestCase):
|
||||
@gen_test
|
||||
def test_producer_consumer(self):
|
||||
q = queues.Queue(maxsize=3)
|
||||
history = []
|
||||
|
||||
# We don't yield between get() and task_done(), so get() must wait for
|
||||
# the next tick. Otherwise we'd immediately call task_done and unblock
|
||||
# join() before q.put() resumes, and we'd only process the first four
|
||||
# items.
|
||||
@gen.coroutine
|
||||
def consumer():
|
||||
while True:
|
||||
history.append((yield q.get()))
|
||||
q.task_done()
|
||||
|
||||
@gen.coroutine
|
||||
def producer():
|
||||
for item in range(10):
|
||||
yield q.put(item)
|
||||
|
||||
consumer()
|
||||
yield producer()
|
||||
yield q.join()
|
||||
self.assertEqual(list(range(10)), history)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Executable
+11
@@ -0,0 +1,11 @@
|
||||
from __future__ import absolute_import, division, print_function
|
||||
from tornado.ioloop import IOLoop
|
||||
from tornado.netutil import ThreadedResolver
|
||||
|
||||
# When this module is imported, it runs getaddrinfo on a thread. Since
|
||||
# the hostname is unicode, getaddrinfo attempts to import encodings.idna
|
||||
# but blocks on the import lock. Verify that ThreadedResolver avoids
|
||||
# this deadlock.
|
||||
|
||||
resolver = ThreadedResolver()
|
||||
IOLoop.current().run_sync(lambda: resolver.resolve(u'localhost', 80))
|
||||
Executable
+247
@@ -0,0 +1,247 @@
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
from tornado.httputil import HTTPHeaders, HTTPMessageDelegate, HTTPServerConnectionDelegate, ResponseStartLine # noqa: E501
|
||||
from tornado.routing import HostMatches, PathMatches, ReversibleRouter, Router, Rule, RuleRouter
|
||||
from tornado.testing import AsyncHTTPTestCase
|
||||
from tornado.web import Application, HTTPError, RequestHandler
|
||||
from tornado.wsgi import WSGIContainer
|
||||
|
||||
|
||||
class BasicRouter(Router):
|
||||
def find_handler(self, request, **kwargs):
|
||||
|
||||
class MessageDelegate(HTTPMessageDelegate):
|
||||
def __init__(self, connection):
|
||||
self.connection = connection
|
||||
|
||||
def finish(self):
|
||||
self.connection.write_headers(
|
||||
ResponseStartLine("HTTP/1.1", 200, "OK"),
|
||||
HTTPHeaders({"Content-Length": "2"}),
|
||||
b"OK"
|
||||
)
|
||||
self.connection.finish()
|
||||
|
||||
return MessageDelegate(request.connection)
|
||||
|
||||
|
||||
class BasicRouterTestCase(AsyncHTTPTestCase):
|
||||
def get_app(self):
|
||||
return BasicRouter()
|
||||
|
||||
def test_basic_router(self):
|
||||
response = self.fetch("/any_request")
|
||||
self.assertEqual(response.body, b"OK")
|
||||
|
||||
|
||||
resources = {}
|
||||
|
||||
|
||||
class GetResource(RequestHandler):
|
||||
def get(self, path):
|
||||
if path not in resources:
|
||||
raise HTTPError(404)
|
||||
|
||||
self.finish(resources[path])
|
||||
|
||||
|
||||
class PostResource(RequestHandler):
|
||||
def post(self, path):
|
||||
resources[path] = self.request.body
|
||||
|
||||
|
||||
class HTTPMethodRouter(Router):
|
||||
def __init__(self, app):
|
||||
self.app = app
|
||||
|
||||
def find_handler(self, request, **kwargs):
|
||||
handler = GetResource if request.method == "GET" else PostResource
|
||||
return self.app.get_handler_delegate(request, handler, path_args=[request.path])
|
||||
|
||||
|
||||
class HTTPMethodRouterTestCase(AsyncHTTPTestCase):
|
||||
def get_app(self):
|
||||
return HTTPMethodRouter(Application())
|
||||
|
||||
def test_http_method_router(self):
|
||||
response = self.fetch("/post_resource", method="POST", body="data")
|
||||
self.assertEqual(response.code, 200)
|
||||
|
||||
response = self.fetch("/get_resource")
|
||||
self.assertEqual(response.code, 404)
|
||||
|
||||
response = self.fetch("/post_resource")
|
||||
self.assertEqual(response.code, 200)
|
||||
self.assertEqual(response.body, b"data")
|
||||
|
||||
|
||||
def _get_named_handler(handler_name):
|
||||
class Handler(RequestHandler):
|
||||
def get(self, *args, **kwargs):
|
||||
if self.application.settings.get("app_name") is not None:
|
||||
self.write(self.application.settings["app_name"] + ": ")
|
||||
|
||||
self.finish(handler_name + ": " + self.reverse_url(handler_name))
|
||||
|
||||
return Handler
|
||||
|
||||
|
||||
FirstHandler = _get_named_handler("first_handler")
|
||||
SecondHandler = _get_named_handler("second_handler")
|
||||
|
||||
|
||||
class CustomRouter(ReversibleRouter):
|
||||
def __init__(self):
|
||||
super(CustomRouter, self).__init__()
|
||||
self.routes = {}
|
||||
|
||||
def add_routes(self, routes):
|
||||
self.routes.update(routes)
|
||||
|
||||
def find_handler(self, request, **kwargs):
|
||||
if request.path in self.routes:
|
||||
app, handler = self.routes[request.path]
|
||||
return app.get_handler_delegate(request, handler)
|
||||
|
||||
def reverse_url(self, name, *args):
|
||||
handler_path = '/' + name
|
||||
return handler_path if handler_path in self.routes else None
|
||||
|
||||
|
||||
class CustomRouterTestCase(AsyncHTTPTestCase):
|
||||
def get_app(self):
|
||||
class CustomApplication(Application):
|
||||
def reverse_url(self, name, *args):
|
||||
return router.reverse_url(name, *args)
|
||||
|
||||
router = CustomRouter()
|
||||
app1 = CustomApplication(app_name="app1")
|
||||
app2 = CustomApplication(app_name="app2")
|
||||
|
||||
router.add_routes({
|
||||
"/first_handler": (app1, FirstHandler),
|
||||
"/second_handler": (app2, SecondHandler),
|
||||
"/first_handler_second_app": (app2, FirstHandler),
|
||||
})
|
||||
|
||||
return router
|
||||
|
||||
def test_custom_router(self):
|
||||
response = self.fetch("/first_handler")
|
||||
self.assertEqual(response.body, b"app1: first_handler: /first_handler")
|
||||
response = self.fetch("/second_handler")
|
||||
self.assertEqual(response.body, b"app2: second_handler: /second_handler")
|
||||
response = self.fetch("/first_handler_second_app")
|
||||
self.assertEqual(response.body, b"app2: first_handler: /first_handler")
|
||||
|
||||
|
||||
class ConnectionDelegate(HTTPServerConnectionDelegate):
|
||||
def start_request(self, server_conn, request_conn):
|
||||
|
||||
class MessageDelegate(HTTPMessageDelegate):
|
||||
def __init__(self, connection):
|
||||
self.connection = connection
|
||||
|
||||
def finish(self):
|
||||
response_body = b"OK"
|
||||
self.connection.write_headers(
|
||||
ResponseStartLine("HTTP/1.1", 200, "OK"),
|
||||
HTTPHeaders({"Content-Length": str(len(response_body))}))
|
||||
self.connection.write(response_body)
|
||||
self.connection.finish()
|
||||
|
||||
return MessageDelegate(request_conn)
|
||||
|
||||
|
||||
class RuleRouterTest(AsyncHTTPTestCase):
|
||||
def get_app(self):
|
||||
app = Application()
|
||||
|
||||
def request_callable(request):
|
||||
request.connection.write_headers(
|
||||
ResponseStartLine("HTTP/1.1", 200, "OK"),
|
||||
HTTPHeaders({"Content-Length": "2"}))
|
||||
request.connection.write(b"OK")
|
||||
request.connection.finish()
|
||||
|
||||
router = CustomRouter()
|
||||
router.add_routes({
|
||||
"/nested_handler": (app, _get_named_handler("nested_handler"))
|
||||
})
|
||||
|
||||
app.add_handlers(".*", [
|
||||
(HostMatches("www.example.com"), [
|
||||
(PathMatches("/first_handler"),
|
||||
"tornado.test.routing_test.SecondHandler", {}, "second_handler")
|
||||
]),
|
||||
Rule(PathMatches("/.*handler"), router),
|
||||
Rule(PathMatches("/first_handler"), FirstHandler, name="first_handler"),
|
||||
Rule(PathMatches("/request_callable"), request_callable),
|
||||
("/connection_delegate", ConnectionDelegate())
|
||||
])
|
||||
|
||||
return app
|
||||
|
||||
def test_rule_based_router(self):
|
||||
response = self.fetch("/first_handler")
|
||||
self.assertEqual(response.body, b"first_handler: /first_handler")
|
||||
|
||||
response = self.fetch("/first_handler", headers={'Host': 'www.example.com'})
|
||||
self.assertEqual(response.body, b"second_handler: /first_handler")
|
||||
|
||||
response = self.fetch("/nested_handler")
|
||||
self.assertEqual(response.body, b"nested_handler: /nested_handler")
|
||||
|
||||
response = self.fetch("/nested_not_found_handler")
|
||||
self.assertEqual(response.code, 404)
|
||||
|
||||
response = self.fetch("/connection_delegate")
|
||||
self.assertEqual(response.body, b"OK")
|
||||
|
||||
response = self.fetch("/request_callable")
|
||||
self.assertEqual(response.body, b"OK")
|
||||
|
||||
response = self.fetch("/404")
|
||||
self.assertEqual(response.code, 404)
|
||||
|
||||
|
||||
class WSGIContainerTestCase(AsyncHTTPTestCase):
|
||||
def get_app(self):
|
||||
wsgi_app = WSGIContainer(self.wsgi_app)
|
||||
|
||||
class Handler(RequestHandler):
|
||||
def get(self, *args, **kwargs):
|
||||
self.finish(self.reverse_url("tornado"))
|
||||
|
||||
return RuleRouter([
|
||||
(PathMatches("/tornado.*"), Application([(r"/tornado/test", Handler, {}, "tornado")])),
|
||||
(PathMatches("/wsgi"), wsgi_app),
|
||||
])
|
||||
|
||||
def wsgi_app(self, environ, start_response):
|
||||
start_response("200 OK", [])
|
||||
return [b"WSGI"]
|
||||
|
||||
def test_wsgi_container(self):
|
||||
response = self.fetch("/tornado/test")
|
||||
self.assertEqual(response.body, b"/tornado/test")
|
||||
|
||||
response = self.fetch("/wsgi")
|
||||
self.assertEqual(response.body, b"WSGI")
|
||||
|
||||
def test_delegate_not_found(self):
|
||||
response = self.fetch("/404")
|
||||
self.assertEqual(response.code, 404)
|
||||
Executable
+219
@@ -0,0 +1,219 @@
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import gc
|
||||
import io
|
||||
import locale # system locale module, not tornado.locale
|
||||
import logging
|
||||
import operator
|
||||
import textwrap
|
||||
import sys
|
||||
|
||||
from tornado.httpclient import AsyncHTTPClient
|
||||
from tornado.httpserver import HTTPServer
|
||||
from tornado.ioloop import IOLoop
|
||||
from tornado.netutil import Resolver
|
||||
from tornado.options import define, options, add_parse_callback
|
||||
from tornado.test.util import unittest
|
||||
|
||||
try:
|
||||
reduce # py2
|
||||
except NameError:
|
||||
from functools import reduce # py3
|
||||
|
||||
TEST_MODULES = [
|
||||
'tornado.httputil.doctests',
|
||||
'tornado.iostream.doctests',
|
||||
'tornado.util.doctests',
|
||||
'tornado.test.asyncio_test',
|
||||
'tornado.test.auth_test',
|
||||
'tornado.test.autoreload_test',
|
||||
'tornado.test.concurrent_test',
|
||||
'tornado.test.curl_httpclient_test',
|
||||
'tornado.test.escape_test',
|
||||
'tornado.test.gen_test',
|
||||
'tornado.test.http1connection_test',
|
||||
'tornado.test.httpclient_test',
|
||||
'tornado.test.httpserver_test',
|
||||
'tornado.test.httputil_test',
|
||||
'tornado.test.import_test',
|
||||
'tornado.test.ioloop_test',
|
||||
'tornado.test.iostream_test',
|
||||
'tornado.test.locale_test',
|
||||
'tornado.test.locks_test',
|
||||
'tornado.test.netutil_test',
|
||||
'tornado.test.log_test',
|
||||
'tornado.test.options_test',
|
||||
'tornado.test.process_test',
|
||||
'tornado.test.queues_test',
|
||||
'tornado.test.routing_test',
|
||||
'tornado.test.simple_httpclient_test',
|
||||
'tornado.test.stack_context_test',
|
||||
'tornado.test.tcpclient_test',
|
||||
'tornado.test.tcpserver_test',
|
||||
'tornado.test.template_test',
|
||||
'tornado.test.testing_test',
|
||||
'tornado.test.twisted_test',
|
||||
'tornado.test.util_test',
|
||||
'tornado.test.web_test',
|
||||
'tornado.test.websocket_test',
|
||||
'tornado.test.windows_test',
|
||||
'tornado.test.wsgi_test',
|
||||
]
|
||||
|
||||
|
||||
def all():
|
||||
return unittest.defaultTestLoader.loadTestsFromNames(TEST_MODULES)
|
||||
|
||||
|
||||
def test_runner_factory(stderr):
|
||||
class TornadoTextTestRunner(unittest.TextTestRunner):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(TornadoTextTestRunner, self).__init__(*args, stream=stderr, **kwargs)
|
||||
|
||||
def run(self, test):
|
||||
result = super(TornadoTextTestRunner, self).run(test)
|
||||
if result.skipped:
|
||||
skip_reasons = set(reason for (test, reason) in result.skipped)
|
||||
self.stream.write(textwrap.fill(
|
||||
"Some tests were skipped because: %s" %
|
||||
", ".join(sorted(skip_reasons))))
|
||||
self.stream.write("\n")
|
||||
return result
|
||||
return TornadoTextTestRunner
|
||||
|
||||
|
||||
class LogCounter(logging.Filter):
|
||||
"""Counts the number of WARNING or higher log records."""
|
||||
def __init__(self, *args, **kwargs):
|
||||
# Can't use super() because logging.Filter is an old-style class in py26
|
||||
logging.Filter.__init__(self, *args, **kwargs)
|
||||
self.info_count = self.warning_count = self.error_count = 0
|
||||
|
||||
def filter(self, record):
|
||||
if record.levelno >= logging.ERROR:
|
||||
self.error_count += 1
|
||||
elif record.levelno >= logging.WARNING:
|
||||
self.warning_count += 1
|
||||
elif record.levelno >= logging.INFO:
|
||||
self.info_count += 1
|
||||
return True
|
||||
|
||||
|
||||
class CountingStderr(io.IOBase):
|
||||
def __init__(self, real):
|
||||
self.real = real
|
||||
self.byte_count = 0
|
||||
|
||||
def write(self, data):
|
||||
self.byte_count += len(data)
|
||||
return self.real.write(data)
|
||||
|
||||
def flush(self):
|
||||
return self.real.flush()
|
||||
|
||||
|
||||
def main():
|
||||
# The -W command-line option does not work in a virtualenv with
|
||||
# python 3 (as of virtualenv 1.7), so configure warnings
|
||||
# programmatically instead.
|
||||
import warnings
|
||||
# Be strict about most warnings. This also turns on warnings that are
|
||||
# ignored by default, including DeprecationWarnings and
|
||||
# python 3.2's ResourceWarnings.
|
||||
warnings.filterwarnings("error")
|
||||
# setuptools sometimes gives ImportWarnings about things that are on
|
||||
# sys.path even if they're not being used.
|
||||
warnings.filterwarnings("ignore", category=ImportWarning)
|
||||
# Tornado generally shouldn't use anything deprecated, but some of
|
||||
# our dependencies do (last match wins).
|
||||
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
||||
warnings.filterwarnings("error", category=DeprecationWarning,
|
||||
module=r"tornado\..*")
|
||||
warnings.filterwarnings("ignore", category=PendingDeprecationWarning)
|
||||
warnings.filterwarnings("error", category=PendingDeprecationWarning,
|
||||
module=r"tornado\..*")
|
||||
# The unittest module is aggressive about deprecating redundant methods,
|
||||
# leaving some without non-deprecated spellings that work on both
|
||||
# 2.7 and 3.2
|
||||
warnings.filterwarnings("ignore", category=DeprecationWarning,
|
||||
message="Please use assert.* instead")
|
||||
warnings.filterwarnings("ignore", category=PendingDeprecationWarning,
|
||||
message="Please use assert.* instead")
|
||||
# Twisted 15.0.0 triggers some warnings on py3 with -bb.
|
||||
warnings.filterwarnings("ignore", category=BytesWarning,
|
||||
module=r"twisted\..*")
|
||||
if (3,) < sys.version_info < (3, 6):
|
||||
# Prior to 3.6, async ResourceWarnings were rather noisy
|
||||
# and even
|
||||
# `python3.4 -W error -c 'import asyncio; asyncio.get_event_loop()'`
|
||||
# would generate a warning.
|
||||
warnings.filterwarnings("ignore", category=ResourceWarning, # noqa: F821
|
||||
module=r"asyncio\..*")
|
||||
|
||||
logging.getLogger("tornado.access").setLevel(logging.CRITICAL)
|
||||
|
||||
define('httpclient', type=str, default=None,
|
||||
callback=lambda s: AsyncHTTPClient.configure(
|
||||
s, defaults=dict(allow_ipv6=False)))
|
||||
define('httpserver', type=str, default=None,
|
||||
callback=HTTPServer.configure)
|
||||
define('ioloop', type=str, default=None)
|
||||
define('ioloop_time_monotonic', default=False)
|
||||
define('resolver', type=str, default=None,
|
||||
callback=Resolver.configure)
|
||||
define('debug_gc', type=str, multiple=True,
|
||||
help="A comma-separated list of gc module debug constants, "
|
||||
"e.g. DEBUG_STATS or DEBUG_COLLECTABLE,DEBUG_OBJECTS",
|
||||
callback=lambda values: gc.set_debug(
|
||||
reduce(operator.or_, (getattr(gc, v) for v in values))))
|
||||
define('locale', type=str, default=None,
|
||||
callback=lambda x: locale.setlocale(locale.LC_ALL, x))
|
||||
|
||||
def configure_ioloop():
|
||||
kwargs = {}
|
||||
if options.ioloop_time_monotonic:
|
||||
from tornado.platform.auto import monotonic_time
|
||||
if monotonic_time is None:
|
||||
raise RuntimeError("monotonic clock not found")
|
||||
kwargs['time_func'] = monotonic_time
|
||||
if options.ioloop or kwargs:
|
||||
IOLoop.configure(options.ioloop, **kwargs)
|
||||
add_parse_callback(configure_ioloop)
|
||||
|
||||
log_counter = LogCounter()
|
||||
add_parse_callback(
|
||||
lambda: logging.getLogger().handlers[0].addFilter(log_counter))
|
||||
|
||||
# Certain errors (especially "unclosed resource" errors raised in
|
||||
# destructors) go directly to stderr instead of logging. Count
|
||||
# anything written by anything but the test runner as an error.
|
||||
orig_stderr = sys.stderr
|
||||
sys.stderr = CountingStderr(orig_stderr)
|
||||
|
||||
import tornado.testing
|
||||
kwargs = {}
|
||||
if sys.version_info >= (3, 2):
|
||||
# HACK: unittest.main will make its own changes to the warning
|
||||
# configuration, which may conflict with the settings above
|
||||
# or command-line flags like -bb. Passing warnings=False
|
||||
# suppresses this behavior, although this looks like an implementation
|
||||
# detail. http://bugs.python.org/issue15626
|
||||
kwargs['warnings'] = False
|
||||
kwargs['testRunner'] = test_runner_factory(orig_stderr)
|
||||
try:
|
||||
tornado.testing.main(**kwargs)
|
||||
finally:
|
||||
# The tests should run clean; consider it a failure if they
|
||||
# logged anything at info level or above.
|
||||
if (log_counter.info_count > 0 or
|
||||
log_counter.warning_count > 0 or
|
||||
log_counter.error_count > 0 or
|
||||
sys.stderr.byte_count > 0):
|
||||
logging.error("logged %d infos, %d warnings, %d errors, and %d bytes to stderr",
|
||||
log_counter.info_count, log_counter.warning_count,
|
||||
log_counter.error_count, sys.stderr.byte_count)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
Executable
+767
@@ -0,0 +1,767 @@
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import collections
|
||||
from contextlib import closing
|
||||
import errno
|
||||
import gzip
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import socket
|
||||
import ssl
|
||||
import sys
|
||||
|
||||
from tornado.escape import to_unicode, utf8
|
||||
from tornado import gen
|
||||
from tornado.httpclient import AsyncHTTPClient
|
||||
from tornado.httputil import HTTPHeaders, ResponseStartLine
|
||||
from tornado.ioloop import IOLoop
|
||||
from tornado.iostream import UnsatisfiableReadError
|
||||
from tornado.locks import Event
|
||||
from tornado.log import gen_log
|
||||
from tornado.concurrent import Future
|
||||
from tornado.netutil import Resolver, bind_sockets
|
||||
from tornado.simple_httpclient import SimpleAsyncHTTPClient, HTTPStreamClosedError, HTTPTimeoutError
|
||||
from tornado.test.httpclient_test import ChunkHandler, CountdownHandler, HelloWorldHandler, RedirectHandler # noqa: E501
|
||||
from tornado.test import httpclient_test
|
||||
from tornado.testing import (AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase,
|
||||
ExpectLog, gen_test)
|
||||
from tornado.test.util import skipOnTravis, skipIfNoIPv6, refusing_port, skipBefore35, exec_test
|
||||
from tornado.web import RequestHandler, Application, url, stream_request_body
|
||||
|
||||
|
||||
class SimpleHTTPClientCommonTestCase(httpclient_test.HTTPClientCommonTestCase):
|
||||
def get_http_client(self):
|
||||
client = SimpleAsyncHTTPClient(force_instance=True)
|
||||
self.assertTrue(isinstance(client, SimpleAsyncHTTPClient))
|
||||
return client
|
||||
|
||||
|
||||
class TriggerHandler(RequestHandler):
|
||||
def initialize(self, queue, wake_callback):
|
||||
self.queue = queue
|
||||
self.wake_callback = wake_callback
|
||||
|
||||
@gen.coroutine
|
||||
def get(self):
|
||||
logging.debug("queuing trigger")
|
||||
self.queue.append(self.finish)
|
||||
if self.get_argument("wake", "true") == "true":
|
||||
self.wake_callback()
|
||||
never_finish = Event()
|
||||
yield never_finish.wait()
|
||||
|
||||
|
||||
class HangHandler(RequestHandler):
|
||||
@gen.coroutine
|
||||
def get(self):
|
||||
never_finish = Event()
|
||||
yield never_finish.wait()
|
||||
|
||||
|
||||
class ContentLengthHandler(RequestHandler):
|
||||
def get(self):
|
||||
self.stream = self.detach()
|
||||
IOLoop.current().spawn_callback(self.write_response)
|
||||
|
||||
@gen.coroutine
|
||||
def write_response(self):
|
||||
yield self.stream.write(utf8("HTTP/1.0 200 OK\r\nContent-Length: %s\r\n\r\nok" %
|
||||
self.get_argument("value")))
|
||||
self.stream.close()
|
||||
|
||||
|
||||
class HeadHandler(RequestHandler):
|
||||
def head(self):
|
||||
self.set_header("Content-Length", "7")
|
||||
|
||||
|
||||
class OptionsHandler(RequestHandler):
|
||||
def options(self):
|
||||
self.set_header("Access-Control-Allow-Origin", "*")
|
||||
self.write("ok")
|
||||
|
||||
|
||||
class NoContentHandler(RequestHandler):
|
||||
def get(self):
|
||||
self.set_status(204)
|
||||
self.finish()
|
||||
|
||||
|
||||
class SeeOtherPostHandler(RequestHandler):
|
||||
def post(self):
|
||||
redirect_code = int(self.request.body)
|
||||
assert redirect_code in (302, 303), "unexpected body %r" % self.request.body
|
||||
self.set_header("Location", "/see_other_get")
|
||||
self.set_status(redirect_code)
|
||||
|
||||
|
||||
class SeeOtherGetHandler(RequestHandler):
|
||||
def get(self):
|
||||
if self.request.body:
|
||||
raise Exception("unexpected body %r" % self.request.body)
|
||||
self.write("ok")
|
||||
|
||||
|
||||
class HostEchoHandler(RequestHandler):
|
||||
def get(self):
|
||||
self.write(self.request.headers["Host"])
|
||||
|
||||
|
||||
class NoContentLengthHandler(RequestHandler):
|
||||
def get(self):
|
||||
if self.request.version.startswith('HTTP/1'):
|
||||
# Emulate the old HTTP/1.0 behavior of returning a body with no
|
||||
# content-length. Tornado handles content-length at the framework
|
||||
# level so we have to go around it.
|
||||
stream = self.detach()
|
||||
stream.write(b"HTTP/1.0 200 OK\r\n\r\n"
|
||||
b"hello")
|
||||
stream.close()
|
||||
else:
|
||||
self.finish('HTTP/1 required')
|
||||
|
||||
|
||||
class EchoPostHandler(RequestHandler):
|
||||
def post(self):
|
||||
self.write(self.request.body)
|
||||
|
||||
|
||||
@stream_request_body
|
||||
class RespondInPrepareHandler(RequestHandler):
|
||||
def prepare(self):
|
||||
self.set_status(403)
|
||||
self.finish("forbidden")
|
||||
|
||||
|
||||
class SimpleHTTPClientTestMixin(object):
|
||||
def get_app(self):
|
||||
# callable objects to finish pending /trigger requests
|
||||
self.triggers = collections.deque()
|
||||
return Application([
|
||||
url("/trigger", TriggerHandler, dict(queue=self.triggers,
|
||||
wake_callback=self.stop)),
|
||||
url("/chunk", ChunkHandler),
|
||||
url("/countdown/([0-9]+)", CountdownHandler, name="countdown"),
|
||||
url("/hang", HangHandler),
|
||||
url("/hello", HelloWorldHandler),
|
||||
url("/content_length", ContentLengthHandler),
|
||||
url("/head", HeadHandler),
|
||||
url("/options", OptionsHandler),
|
||||
url("/no_content", NoContentHandler),
|
||||
url("/see_other_post", SeeOtherPostHandler),
|
||||
url("/see_other_get", SeeOtherGetHandler),
|
||||
url("/host_echo", HostEchoHandler),
|
||||
url("/no_content_length", NoContentLengthHandler),
|
||||
url("/echo_post", EchoPostHandler),
|
||||
url("/respond_in_prepare", RespondInPrepareHandler),
|
||||
url("/redirect", RedirectHandler),
|
||||
], gzip=True)
|
||||
|
||||
def test_singleton(self):
|
||||
# Class "constructor" reuses objects on the same IOLoop
|
||||
self.assertTrue(SimpleAsyncHTTPClient() is
|
||||
SimpleAsyncHTTPClient())
|
||||
# unless force_instance is used
|
||||
self.assertTrue(SimpleAsyncHTTPClient() is not
|
||||
SimpleAsyncHTTPClient(force_instance=True))
|
||||
# different IOLoops use different objects
|
||||
with closing(IOLoop()) as io_loop2:
|
||||
client1 = self.io_loop.run_sync(gen.coroutine(SimpleAsyncHTTPClient))
|
||||
client2 = io_loop2.run_sync(gen.coroutine(SimpleAsyncHTTPClient))
|
||||
self.assertTrue(client1 is not client2)
|
||||
|
||||
def test_connection_limit(self):
|
||||
with closing(self.create_client(max_clients=2)) as client:
|
||||
self.assertEqual(client.max_clients, 2)
|
||||
seen = []
|
||||
# Send 4 requests. Two can be sent immediately, while the others
|
||||
# will be queued
|
||||
for i in range(4):
|
||||
client.fetch(self.get_url("/trigger")).add_done_callback(
|
||||
lambda fut, i=i: (seen.append(i), self.stop()))
|
||||
self.wait(condition=lambda: len(self.triggers) == 2)
|
||||
self.assertEqual(len(client.queue), 2)
|
||||
|
||||
# Finish the first two requests and let the next two through
|
||||
self.triggers.popleft()()
|
||||
self.triggers.popleft()()
|
||||
self.wait(condition=lambda: (len(self.triggers) == 2 and
|
||||
len(seen) == 2))
|
||||
self.assertEqual(set(seen), set([0, 1]))
|
||||
self.assertEqual(len(client.queue), 0)
|
||||
|
||||
# Finish all the pending requests
|
||||
self.triggers.popleft()()
|
||||
self.triggers.popleft()()
|
||||
self.wait(condition=lambda: len(seen) == 4)
|
||||
self.assertEqual(set(seen), set([0, 1, 2, 3]))
|
||||
self.assertEqual(len(self.triggers), 0)
|
||||
|
||||
@gen_test
|
||||
def test_redirect_connection_limit(self):
|
||||
# following redirects should not consume additional connections
|
||||
with closing(self.create_client(max_clients=1)) as client:
|
||||
response = yield client.fetch(self.get_url('/countdown/3'),
|
||||
max_redirects=3)
|
||||
response.rethrow()
|
||||
|
||||
def test_gzip(self):
|
||||
# All the tests in this file should be using gzip, but this test
|
||||
# ensures that it is in fact getting compressed.
|
||||
# Setting Accept-Encoding manually bypasses the client's
|
||||
# decompression so we can see the raw data.
|
||||
response = self.fetch("/chunk", use_gzip=False,
|
||||
headers={"Accept-Encoding": "gzip"})
|
||||
self.assertEqual(response.headers["Content-Encoding"], "gzip")
|
||||
self.assertNotEqual(response.body, b"asdfqwer")
|
||||
# Our test data gets bigger when gzipped. Oops. :)
|
||||
# Chunked encoding bypasses the MIN_LENGTH check.
|
||||
self.assertEqual(len(response.body), 34)
|
||||
f = gzip.GzipFile(mode="r", fileobj=response.buffer)
|
||||
self.assertEqual(f.read(), b"asdfqwer")
|
||||
|
||||
def test_max_redirects(self):
|
||||
response = self.fetch("/countdown/5", max_redirects=3)
|
||||
self.assertEqual(302, response.code)
|
||||
# We requested 5, followed three redirects for 4, 3, 2, then the last
|
||||
# unfollowed redirect is to 1.
|
||||
self.assertTrue(response.request.url.endswith("/countdown/5"))
|
||||
self.assertTrue(response.effective_url.endswith("/countdown/2"))
|
||||
self.assertTrue(response.headers["Location"].endswith("/countdown/1"))
|
||||
|
||||
def test_header_reuse(self):
|
||||
# Apps may reuse a headers object if they are only passing in constant
|
||||
# headers like user-agent. The header object should not be modified.
|
||||
headers = HTTPHeaders({'User-Agent': 'Foo'})
|
||||
self.fetch("/hello", headers=headers)
|
||||
self.assertEqual(list(headers.get_all()), [('User-Agent', 'Foo')])
|
||||
|
||||
def test_see_other_redirect(self):
|
||||
for code in (302, 303):
|
||||
response = self.fetch("/see_other_post", method="POST", body="%d" % code)
|
||||
self.assertEqual(200, response.code)
|
||||
self.assertTrue(response.request.url.endswith("/see_other_post"))
|
||||
self.assertTrue(response.effective_url.endswith("/see_other_get"))
|
||||
# request is the original request, is a POST still
|
||||
self.assertEqual("POST", response.request.method)
|
||||
|
||||
@skipOnTravis
|
||||
@gen_test
|
||||
def test_connect_timeout(self):
|
||||
timeout = 0.1
|
||||
|
||||
class TimeoutResolver(Resolver):
|
||||
def resolve(self, *args, **kwargs):
|
||||
return Future() # never completes
|
||||
|
||||
with closing(self.create_client(resolver=TimeoutResolver())) as client:
|
||||
with self.assertRaises(HTTPTimeoutError):
|
||||
yield client.fetch(self.get_url('/hello'),
|
||||
connect_timeout=timeout,
|
||||
request_timeout=3600,
|
||||
raise_error=True)
|
||||
|
||||
@skipOnTravis
|
||||
def test_request_timeout(self):
|
||||
timeout = 0.1
|
||||
if os.name == 'nt':
|
||||
timeout = 0.5
|
||||
|
||||
with self.assertRaises(HTTPTimeoutError):
|
||||
self.fetch('/trigger?wake=false', request_timeout=timeout, raise_error=True)
|
||||
# trigger the hanging request to let it clean up after itself
|
||||
self.triggers.popleft()()
|
||||
|
||||
@skipIfNoIPv6
|
||||
def test_ipv6(self):
|
||||
[sock] = bind_sockets(None, '::1', family=socket.AF_INET6)
|
||||
port = sock.getsockname()[1]
|
||||
self.http_server.add_socket(sock)
|
||||
url = '%s://[::1]:%d/hello' % (self.get_protocol(), port)
|
||||
|
||||
# ipv6 is currently enabled by default but can be disabled
|
||||
with self.assertRaises(Exception):
|
||||
self.fetch(url, allow_ipv6=False, raise_error=True)
|
||||
|
||||
response = self.fetch(url)
|
||||
self.assertEqual(response.body, b"Hello world!")
|
||||
|
||||
def test_multiple_content_length_accepted(self):
|
||||
response = self.fetch("/content_length?value=2,2")
|
||||
self.assertEqual(response.body, b"ok")
|
||||
response = self.fetch("/content_length?value=2,%202,2")
|
||||
self.assertEqual(response.body, b"ok")
|
||||
|
||||
with ExpectLog(gen_log, ".*Multiple unequal Content-Lengths"):
|
||||
with self.assertRaises(HTTPStreamClosedError):
|
||||
self.fetch("/content_length?value=2,4", raise_error=True)
|
||||
with self.assertRaises(HTTPStreamClosedError):
|
||||
self.fetch("/content_length?value=2,%202,3", raise_error=True)
|
||||
|
||||
def test_head_request(self):
|
||||
response = self.fetch("/head", method="HEAD")
|
||||
self.assertEqual(response.code, 200)
|
||||
self.assertEqual(response.headers["content-length"], "7")
|
||||
self.assertFalse(response.body)
|
||||
|
||||
def test_options_request(self):
|
||||
response = self.fetch("/options", method="OPTIONS")
|
||||
self.assertEqual(response.code, 200)
|
||||
self.assertEqual(response.headers["content-length"], "2")
|
||||
self.assertEqual(response.headers["access-control-allow-origin"], "*")
|
||||
self.assertEqual(response.body, b"ok")
|
||||
|
||||
def test_no_content(self):
|
||||
response = self.fetch("/no_content")
|
||||
self.assertEqual(response.code, 204)
|
||||
# 204 status shouldn't have a content-length
|
||||
#
|
||||
# Tests with a content-length header are included below
|
||||
# in HTTP204NoContentTestCase.
|
||||
self.assertNotIn("Content-Length", response.headers)
|
||||
|
||||
def test_host_header(self):
|
||||
host_re = re.compile(b"^127.0.0.1:[0-9]+$")
|
||||
response = self.fetch("/host_echo")
|
||||
self.assertTrue(host_re.match(response.body))
|
||||
|
||||
url = self.get_url("/host_echo").replace("http://", "http://me:secret@")
|
||||
response = self.fetch(url)
|
||||
self.assertTrue(host_re.match(response.body), response.body)
|
||||
|
||||
def test_connection_refused(self):
|
||||
cleanup_func, port = refusing_port()
|
||||
self.addCleanup(cleanup_func)
|
||||
with ExpectLog(gen_log, ".*", required=False):
|
||||
with self.assertRaises(socket.error) as cm:
|
||||
self.fetch("http://127.0.0.1:%d/" % port, raise_error=True)
|
||||
|
||||
if sys.platform != 'cygwin':
|
||||
# cygwin returns EPERM instead of ECONNREFUSED here
|
||||
contains_errno = str(errno.ECONNREFUSED) in str(cm.exception)
|
||||
if not contains_errno and hasattr(errno, "WSAECONNREFUSED"):
|
||||
contains_errno = str(errno.WSAECONNREFUSED) in str(cm.exception)
|
||||
self.assertTrue(contains_errno, cm.exception)
|
||||
# This is usually "Connection refused".
|
||||
# On windows, strerror is broken and returns "Unknown error".
|
||||
expected_message = os.strerror(errno.ECONNREFUSED)
|
||||
self.assertTrue(expected_message in str(cm.exception),
|
||||
cm.exception)
|
||||
|
||||
def test_queue_timeout(self):
|
||||
with closing(self.create_client(max_clients=1)) as client:
|
||||
# Wait for the trigger request to block, not complete.
|
||||
fut1 = client.fetch(self.get_url('/trigger'), request_timeout=10)
|
||||
self.wait()
|
||||
with self.assertRaises(HTTPTimeoutError) as cm:
|
||||
self.io_loop.run_sync(lambda: client.fetch(
|
||||
self.get_url('/hello'), connect_timeout=0.1, raise_error=True))
|
||||
|
||||
self.assertEqual(str(cm.exception), "Timeout in request queue")
|
||||
self.triggers.popleft()()
|
||||
self.io_loop.run_sync(lambda: fut1)
|
||||
|
||||
def test_no_content_length(self):
|
||||
response = self.fetch("/no_content_length")
|
||||
if response.body == b"HTTP/1 required":
|
||||
self.skipTest("requires HTTP/1.x")
|
||||
else:
|
||||
self.assertEquals(b"hello", response.body)
|
||||
|
||||
def sync_body_producer(self, write):
|
||||
write(b'1234')
|
||||
write(b'5678')
|
||||
|
||||
@gen.coroutine
|
||||
def async_body_producer(self, write):
|
||||
yield write(b'1234')
|
||||
yield gen.moment
|
||||
yield write(b'5678')
|
||||
|
||||
def test_sync_body_producer_chunked(self):
|
||||
response = self.fetch("/echo_post", method="POST",
|
||||
body_producer=self.sync_body_producer)
|
||||
response.rethrow()
|
||||
self.assertEqual(response.body, b"12345678")
|
||||
|
||||
def test_sync_body_producer_content_length(self):
|
||||
response = self.fetch("/echo_post", method="POST",
|
||||
body_producer=self.sync_body_producer,
|
||||
headers={'Content-Length': '8'})
|
||||
response.rethrow()
|
||||
self.assertEqual(response.body, b"12345678")
|
||||
|
||||
def test_async_body_producer_chunked(self):
|
||||
response = self.fetch("/echo_post", method="POST",
|
||||
body_producer=self.async_body_producer)
|
||||
response.rethrow()
|
||||
self.assertEqual(response.body, b"12345678")
|
||||
|
||||
def test_async_body_producer_content_length(self):
|
||||
response = self.fetch("/echo_post", method="POST",
|
||||
body_producer=self.async_body_producer,
|
||||
headers={'Content-Length': '8'})
|
||||
response.rethrow()
|
||||
self.assertEqual(response.body, b"12345678")
|
||||
|
||||
@skipBefore35
|
||||
def test_native_body_producer_chunked(self):
|
||||
namespace = exec_test(globals(), locals(), """
|
||||
async def body_producer(write):
|
||||
await write(b'1234')
|
||||
import asyncio
|
||||
await asyncio.sleep(0)
|
||||
await write(b'5678')
|
||||
""")
|
||||
response = self.fetch("/echo_post", method="POST",
|
||||
body_producer=namespace["body_producer"])
|
||||
response.rethrow()
|
||||
self.assertEqual(response.body, b"12345678")
|
||||
|
||||
@skipBefore35
|
||||
def test_native_body_producer_content_length(self):
|
||||
namespace = exec_test(globals(), locals(), """
|
||||
async def body_producer(write):
|
||||
await write(b'1234')
|
||||
import asyncio
|
||||
await asyncio.sleep(0)
|
||||
await write(b'5678')
|
||||
""")
|
||||
response = self.fetch("/echo_post", method="POST",
|
||||
body_producer=namespace["body_producer"],
|
||||
headers={'Content-Length': '8'})
|
||||
response.rethrow()
|
||||
self.assertEqual(response.body, b"12345678")
|
||||
|
||||
def test_100_continue(self):
|
||||
response = self.fetch("/echo_post", method="POST",
|
||||
body=b"1234",
|
||||
expect_100_continue=True)
|
||||
self.assertEqual(response.body, b"1234")
|
||||
|
||||
def test_100_continue_early_response(self):
|
||||
def body_producer(write):
|
||||
raise Exception("should not be called")
|
||||
response = self.fetch("/respond_in_prepare", method="POST",
|
||||
body_producer=body_producer,
|
||||
expect_100_continue=True)
|
||||
self.assertEqual(response.code, 403)
|
||||
|
||||
def test_streaming_follow_redirects(self):
|
||||
# When following redirects, header and streaming callbacks
|
||||
# should only be called for the final result.
|
||||
# TODO(bdarnell): this test belongs in httpclient_test instead of
|
||||
# simple_httpclient_test, but it fails with the version of libcurl
|
||||
# available on travis-ci. Move it when that has been upgraded
|
||||
# or we have a better framework to skip tests based on curl version.
|
||||
headers = []
|
||||
chunks = []
|
||||
self.fetch("/redirect?url=/hello",
|
||||
header_callback=headers.append,
|
||||
streaming_callback=chunks.append)
|
||||
chunks = list(map(to_unicode, chunks))
|
||||
self.assertEqual(chunks, ['Hello world!'])
|
||||
# Make sure we only got one set of headers.
|
||||
num_start_lines = len([h for h in headers if h.startswith("HTTP/")])
|
||||
self.assertEqual(num_start_lines, 1)
|
||||
|
||||
|
||||
class SimpleHTTPClientTestCase(SimpleHTTPClientTestMixin, AsyncHTTPTestCase):
|
||||
def setUp(self):
|
||||
super(SimpleHTTPClientTestCase, self).setUp()
|
||||
self.http_client = self.create_client()
|
||||
|
||||
def create_client(self, **kwargs):
|
||||
return SimpleAsyncHTTPClient(force_instance=True, **kwargs)
|
||||
|
||||
|
||||
class SimpleHTTPSClientTestCase(SimpleHTTPClientTestMixin, AsyncHTTPSTestCase):
|
||||
def setUp(self):
|
||||
super(SimpleHTTPSClientTestCase, self).setUp()
|
||||
self.http_client = self.create_client()
|
||||
|
||||
def create_client(self, **kwargs):
|
||||
return SimpleAsyncHTTPClient(force_instance=True,
|
||||
defaults=dict(validate_cert=False),
|
||||
**kwargs)
|
||||
|
||||
def test_ssl_options(self):
|
||||
resp = self.fetch("/hello", ssl_options={})
|
||||
self.assertEqual(resp.body, b"Hello world!")
|
||||
|
||||
def test_ssl_context(self):
|
||||
resp = self.fetch("/hello",
|
||||
ssl_options=ssl.SSLContext(ssl.PROTOCOL_SSLv23))
|
||||
self.assertEqual(resp.body, b"Hello world!")
|
||||
|
||||
def test_ssl_options_handshake_fail(self):
|
||||
with ExpectLog(gen_log, "SSL Error|Uncaught exception",
|
||||
required=False):
|
||||
with self.assertRaises(ssl.SSLError):
|
||||
self.fetch(
|
||||
"/hello", ssl_options=dict(cert_reqs=ssl.CERT_REQUIRED),
|
||||
raise_error=True)
|
||||
|
||||
def test_ssl_context_handshake_fail(self):
|
||||
with ExpectLog(gen_log, "SSL Error|Uncaught exception"):
|
||||
ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
|
||||
ctx.verify_mode = ssl.CERT_REQUIRED
|
||||
with self.assertRaises(ssl.SSLError):
|
||||
self.fetch("/hello", ssl_options=ctx, raise_error=True)
|
||||
|
||||
def test_error_logging(self):
|
||||
# No stack traces are logged for SSL errors (in this case,
|
||||
# failure to validate the testing self-signed cert).
|
||||
# The SSLError is exposed through ssl.SSLError.
|
||||
with ExpectLog(gen_log, '.*') as expect_log:
|
||||
with self.assertRaises(ssl.SSLError):
|
||||
self.fetch("/", validate_cert=True, raise_error=True)
|
||||
self.assertFalse(expect_log.logged_stack)
|
||||
|
||||
|
||||
class CreateAsyncHTTPClientTestCase(AsyncTestCase):
|
||||
def setUp(self):
|
||||
super(CreateAsyncHTTPClientTestCase, self).setUp()
|
||||
self.saved = AsyncHTTPClient._save_configuration()
|
||||
|
||||
def tearDown(self):
|
||||
AsyncHTTPClient._restore_configuration(self.saved)
|
||||
super(CreateAsyncHTTPClientTestCase, self).tearDown()
|
||||
|
||||
def test_max_clients(self):
|
||||
AsyncHTTPClient.configure(SimpleAsyncHTTPClient)
|
||||
with closing(AsyncHTTPClient(force_instance=True)) as client:
|
||||
self.assertEqual(client.max_clients, 10)
|
||||
with closing(AsyncHTTPClient(
|
||||
max_clients=11, force_instance=True)) as client:
|
||||
self.assertEqual(client.max_clients, 11)
|
||||
|
||||
# Now configure max_clients statically and try overriding it
|
||||
# with each way max_clients can be passed
|
||||
AsyncHTTPClient.configure(SimpleAsyncHTTPClient, max_clients=12)
|
||||
with closing(AsyncHTTPClient(force_instance=True)) as client:
|
||||
self.assertEqual(client.max_clients, 12)
|
||||
with closing(AsyncHTTPClient(
|
||||
max_clients=13, force_instance=True)) as client:
|
||||
self.assertEqual(client.max_clients, 13)
|
||||
with closing(AsyncHTTPClient(
|
||||
max_clients=14, force_instance=True)) as client:
|
||||
self.assertEqual(client.max_clients, 14)
|
||||
|
||||
|
||||
class HTTP100ContinueTestCase(AsyncHTTPTestCase):
|
||||
def respond_100(self, request):
|
||||
self.http1 = request.version.startswith('HTTP/1.')
|
||||
if not self.http1:
|
||||
request.connection.write_headers(ResponseStartLine('', 200, 'OK'),
|
||||
HTTPHeaders())
|
||||
request.connection.finish()
|
||||
return
|
||||
self.request = request
|
||||
fut = self.request.connection.stream.write(
|
||||
b"HTTP/1.1 100 CONTINUE\r\n\r\n")
|
||||
fut.add_done_callback(self.respond_200)
|
||||
|
||||
def respond_200(self, fut):
|
||||
fut.result()
|
||||
fut = self.request.connection.stream.write(
|
||||
b"HTTP/1.1 200 OK\r\nContent-Length: 1\r\n\r\nA")
|
||||
fut.add_done_callback(lambda f: self.request.connection.stream.close())
|
||||
|
||||
def get_app(self):
|
||||
# Not a full Application, but works as an HTTPServer callback
|
||||
return self.respond_100
|
||||
|
||||
def test_100_continue(self):
|
||||
res = self.fetch('/')
|
||||
if not self.http1:
|
||||
self.skipTest("requires HTTP/1.x")
|
||||
self.assertEqual(res.body, b'A')
|
||||
|
||||
|
||||
class HTTP204NoContentTestCase(AsyncHTTPTestCase):
|
||||
def respond_204(self, request):
|
||||
self.http1 = request.version.startswith('HTTP/1.')
|
||||
if not self.http1:
|
||||
# Close the request cleanly in HTTP/2; it will be skipped anyway.
|
||||
request.connection.write_headers(ResponseStartLine('', 200, 'OK'),
|
||||
HTTPHeaders())
|
||||
request.connection.finish()
|
||||
return
|
||||
|
||||
# A 204 response never has a body, even if doesn't have a content-length
|
||||
# (which would otherwise mean read-until-close). We simulate here a
|
||||
# server that sends no content length and does not close the connection.
|
||||
#
|
||||
# Tests of a 204 response with no Content-Length header are included
|
||||
# in SimpleHTTPClientTestMixin.
|
||||
stream = request.connection.detach()
|
||||
stream.write(b"HTTP/1.1 204 No content\r\n")
|
||||
if request.arguments.get("error", [False])[-1]:
|
||||
stream.write(b"Content-Length: 5\r\n")
|
||||
else:
|
||||
stream.write(b"Content-Length: 0\r\n")
|
||||
stream.write(b"\r\n")
|
||||
stream.close()
|
||||
|
||||
def get_app(self):
|
||||
return self.respond_204
|
||||
|
||||
def test_204_no_content(self):
|
||||
resp = self.fetch('/')
|
||||
if not self.http1:
|
||||
self.skipTest("requires HTTP/1.x")
|
||||
self.assertEqual(resp.code, 204)
|
||||
self.assertEqual(resp.body, b'')
|
||||
|
||||
def test_204_invalid_content_length(self):
|
||||
# 204 status with non-zero content length is malformed
|
||||
with ExpectLog(gen_log, ".*Response with code 204 should not have body"):
|
||||
with self.assertRaises(HTTPStreamClosedError):
|
||||
self.fetch("/?error=1", raise_error=True)
|
||||
if not self.http1:
|
||||
self.skipTest("requires HTTP/1.x")
|
||||
if self.http_client.configured_class != SimpleAsyncHTTPClient:
|
||||
self.skipTest("curl client accepts invalid headers")
|
||||
|
||||
|
||||
class HostnameMappingTestCase(AsyncHTTPTestCase):
|
||||
def setUp(self):
|
||||
super(HostnameMappingTestCase, self).setUp()
|
||||
self.http_client = SimpleAsyncHTTPClient(
|
||||
hostname_mapping={
|
||||
'www.example.com': '127.0.0.1',
|
||||
('foo.example.com', 8000): ('127.0.0.1', self.get_http_port()),
|
||||
})
|
||||
|
||||
def get_app(self):
|
||||
return Application([url("/hello", HelloWorldHandler), ])
|
||||
|
||||
def test_hostname_mapping(self):
|
||||
response = self.fetch(
|
||||
'http://www.example.com:%d/hello' % self.get_http_port())
|
||||
response.rethrow()
|
||||
self.assertEqual(response.body, b'Hello world!')
|
||||
|
||||
def test_port_mapping(self):
|
||||
response = self.fetch('http://foo.example.com:8000/hello')
|
||||
response.rethrow()
|
||||
self.assertEqual(response.body, b'Hello world!')
|
||||
|
||||
|
||||
class ResolveTimeoutTestCase(AsyncHTTPTestCase):
|
||||
def setUp(self):
|
||||
# Dummy Resolver subclass that never finishes.
|
||||
class BadResolver(Resolver):
|
||||
@gen.coroutine
|
||||
def resolve(self, *args, **kwargs):
|
||||
yield Event().wait()
|
||||
|
||||
super(ResolveTimeoutTestCase, self).setUp()
|
||||
self.http_client = SimpleAsyncHTTPClient(
|
||||
resolver=BadResolver())
|
||||
|
||||
def get_app(self):
|
||||
return Application([url("/hello", HelloWorldHandler), ])
|
||||
|
||||
def test_resolve_timeout(self):
|
||||
with self.assertRaises(HTTPTimeoutError):
|
||||
self.fetch('/hello', connect_timeout=0.1, raise_error=True)
|
||||
|
||||
|
||||
class MaxHeaderSizeTest(AsyncHTTPTestCase):
|
||||
def get_app(self):
|
||||
class SmallHeaders(RequestHandler):
|
||||
def get(self):
|
||||
self.set_header("X-Filler", "a" * 100)
|
||||
self.write("ok")
|
||||
|
||||
class LargeHeaders(RequestHandler):
|
||||
def get(self):
|
||||
self.set_header("X-Filler", "a" * 1000)
|
||||
self.write("ok")
|
||||
|
||||
return Application([('/small', SmallHeaders),
|
||||
('/large', LargeHeaders)])
|
||||
|
||||
def get_http_client(self):
|
||||
return SimpleAsyncHTTPClient(max_header_size=1024)
|
||||
|
||||
def test_small_headers(self):
|
||||
response = self.fetch('/small')
|
||||
response.rethrow()
|
||||
self.assertEqual(response.body, b'ok')
|
||||
|
||||
def test_large_headers(self):
|
||||
with ExpectLog(gen_log, "Unsatisfiable read"):
|
||||
with self.assertRaises(UnsatisfiableReadError):
|
||||
self.fetch('/large', raise_error=True)
|
||||
|
||||
|
||||
class MaxBodySizeTest(AsyncHTTPTestCase):
|
||||
def get_app(self):
|
||||
class SmallBody(RequestHandler):
|
||||
def get(self):
|
||||
self.write("a" * 1024 * 64)
|
||||
|
||||
class LargeBody(RequestHandler):
|
||||
def get(self):
|
||||
self.write("a" * 1024 * 100)
|
||||
|
||||
return Application([('/small', SmallBody),
|
||||
('/large', LargeBody)])
|
||||
|
||||
def get_http_client(self):
|
||||
return SimpleAsyncHTTPClient(max_body_size=1024 * 64)
|
||||
|
||||
def test_small_body(self):
|
||||
response = self.fetch('/small')
|
||||
response.rethrow()
|
||||
self.assertEqual(response.body, b'a' * 1024 * 64)
|
||||
|
||||
def test_large_body(self):
|
||||
with ExpectLog(gen_log, "Malformed HTTP message from None: Content-Length too long"):
|
||||
with self.assertRaises(HTTPStreamClosedError):
|
||||
self.fetch('/large', raise_error=True)
|
||||
|
||||
|
||||
class MaxBufferSizeTest(AsyncHTTPTestCase):
|
||||
def get_app(self):
|
||||
|
||||
class LargeBody(RequestHandler):
|
||||
def get(self):
|
||||
self.write("a" * 1024 * 100)
|
||||
|
||||
return Application([('/large', LargeBody)])
|
||||
|
||||
def get_http_client(self):
|
||||
# 100KB body with 64KB buffer
|
||||
return SimpleAsyncHTTPClient(max_body_size=1024 * 100, max_buffer_size=1024 * 64)
|
||||
|
||||
def test_large_body(self):
|
||||
response = self.fetch('/large')
|
||||
response.rethrow()
|
||||
self.assertEqual(response.body, b'a' * 1024 * 100)
|
||||
|
||||
|
||||
class ChunkedWithContentLengthTest(AsyncHTTPTestCase):
|
||||
def get_app(self):
|
||||
|
||||
class ChunkedWithContentLength(RequestHandler):
|
||||
def get(self):
|
||||
# Add an invalid Transfer-Encoding to the response
|
||||
self.set_header('Transfer-Encoding', 'chunked')
|
||||
self.write("Hello world")
|
||||
|
||||
return Application([('/chunkwithcl', ChunkedWithContentLength)])
|
||||
|
||||
def get_http_client(self):
|
||||
return SimpleAsyncHTTPClient()
|
||||
|
||||
def test_chunked_with_content_length(self):
|
||||
# Make sure the invalid headers are detected
|
||||
with ExpectLog(gen_log, ("Malformed HTTP message from None: Response "
|
||||
"with both Transfer-Encoding and Content-Length")):
|
||||
with self.assertRaises(HTTPStreamClosedError):
|
||||
self.fetch('/chunkwithcl', raise_error=True)
|
||||
Executable
+297
@@ -0,0 +1,297 @@
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
from tornado import gen
|
||||
from tornado.ioloop import IOLoop
|
||||
from tornado.log import app_log
|
||||
from tornado.stack_context import (StackContext, wrap, NullContext, StackContextInconsistentError,
|
||||
ExceptionStackContext, run_with_stack_context, _state)
|
||||
from tornado.testing import AsyncHTTPTestCase, AsyncTestCase, ExpectLog, gen_test
|
||||
from tornado.test.util import unittest, ignore_deprecation
|
||||
from tornado.web import asynchronous, Application, RequestHandler
|
||||
import contextlib
|
||||
import functools
|
||||
import logging
|
||||
import warnings
|
||||
|
||||
|
||||
class TestRequestHandler(RequestHandler):
|
||||
def __init__(self, app, request):
|
||||
super(TestRequestHandler, self).__init__(app, request)
|
||||
|
||||
with ignore_deprecation():
|
||||
@asynchronous
|
||||
def get(self):
|
||||
logging.debug('in get()')
|
||||
# call self.part2 without a self.async_callback wrapper. Its
|
||||
# exception should still get thrown
|
||||
IOLoop.current().add_callback(self.part2)
|
||||
|
||||
def part2(self):
|
||||
logging.debug('in part2()')
|
||||
# Go through a third layer to make sure that contexts once restored
|
||||
# are again passed on to future callbacks
|
||||
IOLoop.current().add_callback(self.part3)
|
||||
|
||||
def part3(self):
|
||||
logging.debug('in part3()')
|
||||
raise Exception('test exception')
|
||||
|
||||
def write_error(self, status_code, **kwargs):
|
||||
if 'exc_info' in kwargs and str(kwargs['exc_info'][1]) == 'test exception':
|
||||
self.write('got expected exception')
|
||||
else:
|
||||
self.write('unexpected failure')
|
||||
|
||||
|
||||
class HTTPStackContextTest(AsyncHTTPTestCase):
|
||||
def get_app(self):
|
||||
return Application([('/', TestRequestHandler)])
|
||||
|
||||
def test_stack_context(self):
|
||||
with ExpectLog(app_log, "Uncaught exception GET /"):
|
||||
with ignore_deprecation():
|
||||
self.http_client.fetch(self.get_url('/'), self.handle_response)
|
||||
self.wait()
|
||||
self.assertEqual(self.response.code, 500)
|
||||
self.assertTrue(b'got expected exception' in self.response.body)
|
||||
|
||||
def handle_response(self, response):
|
||||
self.response = response
|
||||
self.stop()
|
||||
|
||||
|
||||
class StackContextTest(AsyncTestCase):
|
||||
def setUp(self):
|
||||
super(StackContextTest, self).setUp()
|
||||
self.active_contexts = []
|
||||
self.warning_catcher = warnings.catch_warnings()
|
||||
self.warning_catcher.__enter__()
|
||||
warnings.simplefilter('ignore', DeprecationWarning)
|
||||
|
||||
def tearDown(self):
|
||||
self.warning_catcher.__exit__(None, None, None)
|
||||
super(StackContextTest, self).tearDown()
|
||||
|
||||
@contextlib.contextmanager
|
||||
def context(self, name):
|
||||
self.active_contexts.append(name)
|
||||
yield
|
||||
self.assertEqual(self.active_contexts.pop(), name)
|
||||
|
||||
# Simulates the effect of an asynchronous library that uses its own
|
||||
# StackContext internally and then returns control to the application.
|
||||
def test_exit_library_context(self):
|
||||
def library_function(callback):
|
||||
# capture the caller's context before introducing our own
|
||||
callback = wrap(callback)
|
||||
with StackContext(functools.partial(self.context, 'library')):
|
||||
self.io_loop.add_callback(
|
||||
functools.partial(library_inner_callback, callback))
|
||||
|
||||
def library_inner_callback(callback):
|
||||
self.assertEqual(self.active_contexts[-2:],
|
||||
['application', 'library'])
|
||||
callback()
|
||||
|
||||
def final_callback():
|
||||
# implementation detail: the full context stack at this point
|
||||
# is ['application', 'library', 'application']. The 'library'
|
||||
# context was not removed, but is no longer innermost so
|
||||
# the application context takes precedence.
|
||||
self.assertEqual(self.active_contexts[-1], 'application')
|
||||
self.stop()
|
||||
with StackContext(functools.partial(self.context, 'application')):
|
||||
library_function(final_callback)
|
||||
self.wait()
|
||||
|
||||
def test_deactivate(self):
|
||||
deactivate_callbacks = []
|
||||
|
||||
def f1():
|
||||
with StackContext(functools.partial(self.context, 'c1')) as c1:
|
||||
deactivate_callbacks.append(c1)
|
||||
self.io_loop.add_callback(f2)
|
||||
|
||||
def f2():
|
||||
with StackContext(functools.partial(self.context, 'c2')) as c2:
|
||||
deactivate_callbacks.append(c2)
|
||||
self.io_loop.add_callback(f3)
|
||||
|
||||
def f3():
|
||||
with StackContext(functools.partial(self.context, 'c3')) as c3:
|
||||
deactivate_callbacks.append(c3)
|
||||
self.io_loop.add_callback(f4)
|
||||
|
||||
def f4():
|
||||
self.assertEqual(self.active_contexts, ['c1', 'c2', 'c3'])
|
||||
deactivate_callbacks[1]()
|
||||
# deactivating a context doesn't remove it immediately,
|
||||
# but it will be missing from the next iteration
|
||||
self.assertEqual(self.active_contexts, ['c1', 'c2', 'c3'])
|
||||
self.io_loop.add_callback(f5)
|
||||
|
||||
def f5():
|
||||
self.assertEqual(self.active_contexts, ['c1', 'c3'])
|
||||
self.stop()
|
||||
self.io_loop.add_callback(f1)
|
||||
self.wait()
|
||||
|
||||
def test_deactivate_order(self):
|
||||
# Stack context deactivation has separate logic for deactivation at
|
||||
# the head and tail of the stack, so make sure it works in any order.
|
||||
def check_contexts():
|
||||
# Make sure that the full-context array and the exception-context
|
||||
# linked lists are consistent with each other.
|
||||
full_contexts, chain = _state.contexts
|
||||
exception_contexts = []
|
||||
while chain is not None:
|
||||
exception_contexts.append(chain)
|
||||
chain = chain.old_contexts[1]
|
||||
self.assertEqual(list(reversed(full_contexts)), exception_contexts)
|
||||
return list(self.active_contexts)
|
||||
|
||||
def make_wrapped_function():
|
||||
"""Wraps a function in three stack contexts, and returns
|
||||
the function along with the deactivation functions.
|
||||
"""
|
||||
# Remove the test's stack context to make sure we can cover
|
||||
# the case where the last context is deactivated.
|
||||
with NullContext():
|
||||
partial = functools.partial
|
||||
with StackContext(partial(self.context, 'c0')) as c0:
|
||||
with StackContext(partial(self.context, 'c1')) as c1:
|
||||
with StackContext(partial(self.context, 'c2')) as c2:
|
||||
return (wrap(check_contexts), [c0, c1, c2])
|
||||
|
||||
# First make sure the test mechanism works without any deactivations
|
||||
func, deactivate_callbacks = make_wrapped_function()
|
||||
self.assertEqual(func(), ['c0', 'c1', 'c2'])
|
||||
|
||||
# Deactivate the tail
|
||||
func, deactivate_callbacks = make_wrapped_function()
|
||||
deactivate_callbacks[0]()
|
||||
self.assertEqual(func(), ['c1', 'c2'])
|
||||
|
||||
# Deactivate the middle
|
||||
func, deactivate_callbacks = make_wrapped_function()
|
||||
deactivate_callbacks[1]()
|
||||
self.assertEqual(func(), ['c0', 'c2'])
|
||||
|
||||
# Deactivate the head
|
||||
func, deactivate_callbacks = make_wrapped_function()
|
||||
deactivate_callbacks[2]()
|
||||
self.assertEqual(func(), ['c0', 'c1'])
|
||||
|
||||
def test_isolation_nonempty(self):
|
||||
# f2 and f3 are a chain of operations started in context c1.
|
||||
# f2 is incidentally run under context c2, but that context should
|
||||
# not be passed along to f3.
|
||||
def f1():
|
||||
with StackContext(functools.partial(self.context, 'c1')):
|
||||
wrapped = wrap(f2)
|
||||
with StackContext(functools.partial(self.context, 'c2')):
|
||||
wrapped()
|
||||
|
||||
def f2():
|
||||
self.assertIn('c1', self.active_contexts)
|
||||
self.io_loop.add_callback(f3)
|
||||
|
||||
def f3():
|
||||
self.assertIn('c1', self.active_contexts)
|
||||
self.assertNotIn('c2', self.active_contexts)
|
||||
self.stop()
|
||||
|
||||
self.io_loop.add_callback(f1)
|
||||
self.wait()
|
||||
|
||||
def test_isolation_empty(self):
|
||||
# Similar to test_isolation_nonempty, but here the f2/f3 chain
|
||||
# is started without any context. Behavior should be equivalent
|
||||
# to the nonempty case (although historically it was not)
|
||||
def f1():
|
||||
with NullContext():
|
||||
wrapped = wrap(f2)
|
||||
with StackContext(functools.partial(self.context, 'c2')):
|
||||
wrapped()
|
||||
|
||||
def f2():
|
||||
self.io_loop.add_callback(f3)
|
||||
|
||||
def f3():
|
||||
self.assertNotIn('c2', self.active_contexts)
|
||||
self.stop()
|
||||
|
||||
self.io_loop.add_callback(f1)
|
||||
self.wait()
|
||||
|
||||
def test_yield_in_with(self):
|
||||
@gen.engine
|
||||
def f():
|
||||
self.callback = yield gen.Callback('a')
|
||||
with StackContext(functools.partial(self.context, 'c1')):
|
||||
# This yield is a problem: the generator will be suspended
|
||||
# and the StackContext's __exit__ is not called yet, so
|
||||
# the context will be left on _state.contexts for anything
|
||||
# that runs before the yield resolves.
|
||||
yield gen.Wait('a')
|
||||
|
||||
with self.assertRaises(StackContextInconsistentError):
|
||||
f()
|
||||
self.wait()
|
||||
# Cleanup: to avoid GC warnings (which for some reason only seem
|
||||
# to show up on py33-asyncio), invoke the callback (which will do
|
||||
# nothing since the gen.Runner is already finished) and delete it.
|
||||
self.callback()
|
||||
del self.callback
|
||||
|
||||
@gen_test
|
||||
def test_yield_outside_with(self):
|
||||
# This pattern avoids the problem in the previous test.
|
||||
cb = yield gen.Callback('k1')
|
||||
with StackContext(functools.partial(self.context, 'c1')):
|
||||
self.io_loop.add_callback(cb)
|
||||
yield gen.Wait('k1')
|
||||
|
||||
def test_yield_in_with_exception_stack_context(self):
|
||||
# As above, but with ExceptionStackContext instead of StackContext.
|
||||
@gen.engine
|
||||
def f():
|
||||
with ExceptionStackContext(lambda t, v, tb: False):
|
||||
yield gen.Task(self.io_loop.add_callback)
|
||||
|
||||
with self.assertRaises(StackContextInconsistentError):
|
||||
f()
|
||||
self.wait()
|
||||
|
||||
@gen_test
|
||||
def test_yield_outside_with_exception_stack_context(self):
|
||||
cb = yield gen.Callback('k1')
|
||||
with ExceptionStackContext(lambda t, v, tb: False):
|
||||
self.io_loop.add_callback(cb)
|
||||
yield gen.Wait('k1')
|
||||
|
||||
@gen_test
|
||||
def test_run_with_stack_context(self):
|
||||
@gen.coroutine
|
||||
def f1():
|
||||
self.assertEqual(self.active_contexts, ['c1'])
|
||||
yield run_with_stack_context(
|
||||
StackContext(functools.partial(self.context, 'c2')),
|
||||
f2)
|
||||
self.assertEqual(self.active_contexts, ['c1'])
|
||||
|
||||
@gen.coroutine
|
||||
def f2():
|
||||
self.assertEqual(self.active_contexts, ['c1', 'c2'])
|
||||
yield gen.Task(self.io_loop.add_callback)
|
||||
self.assertEqual(self.active_contexts, ['c1', 'c2'])
|
||||
|
||||
self.assertEqual(self.active_contexts, [])
|
||||
yield run_with_stack_context(
|
||||
StackContext(functools.partial(self.context, 'c1')),
|
||||
f1)
|
||||
self.assertEqual(self.active_contexts, [])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Executable
+1
@@ -0,0 +1 @@
|
||||
this is the index
|
||||
Executable
+2
@@ -0,0 +1,2 @@
|
||||
User-agent: *
|
||||
Disallow: /
|
||||
Executable
+23
@@ -0,0 +1,23 @@
|
||||
<?xml version="1.0"?>
|
||||
<data>
|
||||
<country name="Liechtenstein">
|
||||
<rank>1</rank>
|
||||
<year>2008</year>
|
||||
<gdppc>141100</gdppc>
|
||||
<neighbor name="Austria" direction="E"/>
|
||||
<neighbor name="Switzerland" direction="W"/>
|
||||
</country>
|
||||
<country name="Singapore">
|
||||
<rank>4</rank>
|
||||
<year>2011</year>
|
||||
<gdppc>59900</gdppc>
|
||||
<neighbor name="Malaysia" direction="N"/>
|
||||
</country>
|
||||
<country name="Panama">
|
||||
<rank>68</rank>
|
||||
<year>2011</year>
|
||||
<gdppc>13600</gdppc>
|
||||
<neighbor name="Costa Rica" direction="W"/>
|
||||
<neighbor name="Colombia" direction="E"/>
|
||||
</country>
|
||||
</data>
|
||||
Executable
BIN
Binary file not shown.
Executable
BIN
Binary file not shown.
Executable
+2
@@ -0,0 +1,2 @@
|
||||
This file should not be served by StaticFileHandler even though
|
||||
its name starts with "static".
|
||||
Executable
+430
@@ -0,0 +1,430 @@
|
||||
#
|
||||
# Copyright 2014 Facebook
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
from contextlib import closing
|
||||
import os
|
||||
import socket
|
||||
|
||||
from tornado.concurrent import Future
|
||||
from tornado.netutil import bind_sockets, Resolver
|
||||
from tornado.queues import Queue
|
||||
from tornado.tcpclient import TCPClient, _Connector
|
||||
from tornado.tcpserver import TCPServer
|
||||
from tornado.testing import AsyncTestCase, gen_test
|
||||
from tornado.test.util import skipIfNoIPv6, unittest, refusing_port, skipIfNonUnix
|
||||
from tornado.gen import TimeoutError
|
||||
|
||||
# Fake address families for testing. Used in place of AF_INET
|
||||
# and AF_INET6 because some installations do not have AF_INET6.
|
||||
AF1, AF2 = 1, 2
|
||||
|
||||
|
||||
class TestTCPServer(TCPServer):
|
||||
def __init__(self, family):
|
||||
super(TestTCPServer, self).__init__()
|
||||
self.streams = []
|
||||
self.queue = Queue()
|
||||
sockets = bind_sockets(None, 'localhost', family)
|
||||
self.add_sockets(sockets)
|
||||
self.port = sockets[0].getsockname()[1]
|
||||
|
||||
def handle_stream(self, stream, address):
|
||||
self.streams.append(stream)
|
||||
self.queue.put(stream)
|
||||
|
||||
def stop(self):
|
||||
super(TestTCPServer, self).stop()
|
||||
for stream in self.streams:
|
||||
stream.close()
|
||||
|
||||
|
||||
class TCPClientTest(AsyncTestCase):
|
||||
def setUp(self):
|
||||
super(TCPClientTest, self).setUp()
|
||||
self.server = None
|
||||
self.client = TCPClient()
|
||||
|
||||
def start_server(self, family):
|
||||
if family == socket.AF_UNSPEC and 'TRAVIS' in os.environ:
|
||||
self.skipTest("dual-stack servers often have port conflicts on travis")
|
||||
self.server = TestTCPServer(family)
|
||||
return self.server.port
|
||||
|
||||
def stop_server(self):
|
||||
if self.server is not None:
|
||||
self.server.stop()
|
||||
self.server = None
|
||||
|
||||
def tearDown(self):
|
||||
self.client.close()
|
||||
self.stop_server()
|
||||
super(TCPClientTest, self).tearDown()
|
||||
|
||||
def skipIfLocalhostV4(self):
|
||||
# The port used here doesn't matter, but some systems require it
|
||||
# to be non-zero if we do not also pass AI_PASSIVE.
|
||||
addrinfo = self.io_loop.run_sync(lambda: Resolver().resolve('localhost', 80))
|
||||
families = set(addr[0] for addr in addrinfo)
|
||||
if socket.AF_INET6 not in families:
|
||||
self.skipTest("localhost does not resolve to ipv6")
|
||||
|
||||
@gen_test
|
||||
def do_test_connect(self, family, host, source_ip=None, source_port=None):
|
||||
port = self.start_server(family)
|
||||
stream = yield self.client.connect(host, port,
|
||||
source_ip=source_ip,
|
||||
source_port=source_port)
|
||||
server_stream = yield self.server.queue.get()
|
||||
with closing(stream):
|
||||
stream.write(b"hello")
|
||||
data = yield server_stream.read_bytes(5)
|
||||
self.assertEqual(data, b"hello")
|
||||
|
||||
def test_connect_ipv4_ipv4(self):
|
||||
self.do_test_connect(socket.AF_INET, '127.0.0.1')
|
||||
|
||||
def test_connect_ipv4_dual(self):
|
||||
self.do_test_connect(socket.AF_INET, 'localhost')
|
||||
|
||||
@skipIfNoIPv6
|
||||
def test_connect_ipv6_ipv6(self):
|
||||
self.skipIfLocalhostV4()
|
||||
self.do_test_connect(socket.AF_INET6, '::1')
|
||||
|
||||
@skipIfNoIPv6
|
||||
def test_connect_ipv6_dual(self):
|
||||
self.skipIfLocalhostV4()
|
||||
if Resolver.configured_class().__name__.endswith('TwistedResolver'):
|
||||
self.skipTest('TwistedResolver does not support multiple addresses')
|
||||
self.do_test_connect(socket.AF_INET6, 'localhost')
|
||||
|
||||
def test_connect_unspec_ipv4(self):
|
||||
self.do_test_connect(socket.AF_UNSPEC, '127.0.0.1')
|
||||
|
||||
@skipIfNoIPv6
|
||||
def test_connect_unspec_ipv6(self):
|
||||
self.skipIfLocalhostV4()
|
||||
self.do_test_connect(socket.AF_UNSPEC, '::1')
|
||||
|
||||
def test_connect_unspec_dual(self):
|
||||
self.do_test_connect(socket.AF_UNSPEC, 'localhost')
|
||||
|
||||
@gen_test
|
||||
def test_refused_ipv4(self):
|
||||
cleanup_func, port = refusing_port()
|
||||
self.addCleanup(cleanup_func)
|
||||
with self.assertRaises(IOError):
|
||||
yield self.client.connect('127.0.0.1', port)
|
||||
|
||||
def test_source_ip_fail(self):
|
||||
'''
|
||||
Fail when trying to use the source IP Address '8.8.8.8'.
|
||||
'''
|
||||
self.assertRaises(socket.error,
|
||||
self.do_test_connect,
|
||||
socket.AF_INET,
|
||||
'127.0.0.1',
|
||||
source_ip='8.8.8.8')
|
||||
|
||||
def test_source_ip_success(self):
|
||||
'''
|
||||
Success when trying to use the source IP Address '127.0.0.1'
|
||||
'''
|
||||
self.do_test_connect(socket.AF_INET, '127.0.0.1', source_ip='127.0.0.1')
|
||||
|
||||
@skipIfNonUnix
|
||||
def test_source_port_fail(self):
|
||||
'''
|
||||
Fail when trying to use source port 1.
|
||||
'''
|
||||
self.assertRaises(socket.error,
|
||||
self.do_test_connect,
|
||||
socket.AF_INET,
|
||||
'127.0.0.1',
|
||||
source_port=1)
|
||||
|
||||
@gen_test
|
||||
def test_connect_timeout(self):
|
||||
timeout = 0.05
|
||||
|
||||
class TimeoutResolver(Resolver):
|
||||
def resolve(self, *args, **kwargs):
|
||||
return Future() # never completes
|
||||
with self.assertRaises(TimeoutError):
|
||||
yield TCPClient(resolver=TimeoutResolver()).connect(
|
||||
'1.2.3.4', 12345, timeout=timeout)
|
||||
|
||||
|
||||
class TestConnectorSplit(unittest.TestCase):
|
||||
def test_one_family(self):
|
||||
# These addresses aren't in the right format, but split doesn't care.
|
||||
primary, secondary = _Connector.split(
|
||||
[(AF1, 'a'),
|
||||
(AF1, 'b')])
|
||||
self.assertEqual(primary, [(AF1, 'a'),
|
||||
(AF1, 'b')])
|
||||
self.assertEqual(secondary, [])
|
||||
|
||||
def test_mixed(self):
|
||||
primary, secondary = _Connector.split(
|
||||
[(AF1, 'a'),
|
||||
(AF2, 'b'),
|
||||
(AF1, 'c'),
|
||||
(AF2, 'd')])
|
||||
self.assertEqual(primary, [(AF1, 'a'), (AF1, 'c')])
|
||||
self.assertEqual(secondary, [(AF2, 'b'), (AF2, 'd')])
|
||||
|
||||
|
||||
class ConnectorTest(AsyncTestCase):
|
||||
class FakeStream(object):
|
||||
def __init__(self):
|
||||
self.closed = False
|
||||
|
||||
def close(self):
|
||||
self.closed = True
|
||||
|
||||
def setUp(self):
|
||||
super(ConnectorTest, self).setUp()
|
||||
self.connect_futures = {}
|
||||
self.streams = {}
|
||||
self.addrinfo = [(AF1, 'a'), (AF1, 'b'),
|
||||
(AF2, 'c'), (AF2, 'd')]
|
||||
|
||||
def tearDown(self):
|
||||
# Unless explicitly checked (and popped) in the test, we shouldn't
|
||||
# be closing any streams
|
||||
for stream in self.streams.values():
|
||||
self.assertFalse(stream.closed)
|
||||
super(ConnectorTest, self).tearDown()
|
||||
|
||||
def create_stream(self, af, addr):
|
||||
stream = ConnectorTest.FakeStream()
|
||||
self.streams[addr] = stream
|
||||
future = Future()
|
||||
self.connect_futures[(af, addr)] = future
|
||||
return stream, future
|
||||
|
||||
def assert_pending(self, *keys):
|
||||
self.assertEqual(sorted(self.connect_futures.keys()), sorted(keys))
|
||||
|
||||
def resolve_connect(self, af, addr, success):
|
||||
future = self.connect_futures.pop((af, addr))
|
||||
if success:
|
||||
future.set_result(self.streams[addr])
|
||||
else:
|
||||
self.streams.pop(addr)
|
||||
future.set_exception(IOError())
|
||||
# Run the loop to allow callbacks to be run.
|
||||
self.io_loop.add_callback(self.stop)
|
||||
self.wait()
|
||||
|
||||
def assert_connector_streams_closed(self, conn):
|
||||
for stream in conn.streams:
|
||||
self.assertTrue(stream.closed)
|
||||
|
||||
def start_connect(self, addrinfo):
|
||||
conn = _Connector(addrinfo, self.create_stream)
|
||||
# Give it a huge timeout; we'll trigger timeouts manually.
|
||||
future = conn.start(3600, connect_timeout=self.io_loop.time() + 3600)
|
||||
return conn, future
|
||||
|
||||
def test_immediate_success(self):
|
||||
conn, future = self.start_connect(self.addrinfo)
|
||||
self.assertEqual(list(self.connect_futures.keys()),
|
||||
[(AF1, 'a')])
|
||||
self.resolve_connect(AF1, 'a', True)
|
||||
self.assertEqual(future.result(), (AF1, 'a', self.streams['a']))
|
||||
|
||||
def test_immediate_failure(self):
|
||||
# Fail with just one address.
|
||||
conn, future = self.start_connect([(AF1, 'a')])
|
||||
self.assert_pending((AF1, 'a'))
|
||||
self.resolve_connect(AF1, 'a', False)
|
||||
self.assertRaises(IOError, future.result)
|
||||
|
||||
def test_one_family_second_try(self):
|
||||
conn, future = self.start_connect([(AF1, 'a'), (AF1, 'b')])
|
||||
self.assert_pending((AF1, 'a'))
|
||||
self.resolve_connect(AF1, 'a', False)
|
||||
self.assert_pending((AF1, 'b'))
|
||||
self.resolve_connect(AF1, 'b', True)
|
||||
self.assertEqual(future.result(), (AF1, 'b', self.streams['b']))
|
||||
|
||||
def test_one_family_second_try_failure(self):
|
||||
conn, future = self.start_connect([(AF1, 'a'), (AF1, 'b')])
|
||||
self.assert_pending((AF1, 'a'))
|
||||
self.resolve_connect(AF1, 'a', False)
|
||||
self.assert_pending((AF1, 'b'))
|
||||
self.resolve_connect(AF1, 'b', False)
|
||||
self.assertRaises(IOError, future.result)
|
||||
|
||||
def test_one_family_second_try_timeout(self):
|
||||
conn, future = self.start_connect([(AF1, 'a'), (AF1, 'b')])
|
||||
self.assert_pending((AF1, 'a'))
|
||||
# trigger the timeout while the first lookup is pending;
|
||||
# nothing happens.
|
||||
conn.on_timeout()
|
||||
self.assert_pending((AF1, 'a'))
|
||||
self.resolve_connect(AF1, 'a', False)
|
||||
self.assert_pending((AF1, 'b'))
|
||||
self.resolve_connect(AF1, 'b', True)
|
||||
self.assertEqual(future.result(), (AF1, 'b', self.streams['b']))
|
||||
|
||||
def test_two_families_immediate_failure(self):
|
||||
conn, future = self.start_connect(self.addrinfo)
|
||||
self.assert_pending((AF1, 'a'))
|
||||
self.resolve_connect(AF1, 'a', False)
|
||||
self.assert_pending((AF1, 'b'), (AF2, 'c'))
|
||||
self.resolve_connect(AF1, 'b', False)
|
||||
self.resolve_connect(AF2, 'c', True)
|
||||
self.assertEqual(future.result(), (AF2, 'c', self.streams['c']))
|
||||
|
||||
def test_two_families_timeout(self):
|
||||
conn, future = self.start_connect(self.addrinfo)
|
||||
self.assert_pending((AF1, 'a'))
|
||||
conn.on_timeout()
|
||||
self.assert_pending((AF1, 'a'), (AF2, 'c'))
|
||||
self.resolve_connect(AF2, 'c', True)
|
||||
self.assertEqual(future.result(), (AF2, 'c', self.streams['c']))
|
||||
# resolving 'a' after the connection has completed doesn't start 'b'
|
||||
self.resolve_connect(AF1, 'a', False)
|
||||
self.assert_pending()
|
||||
|
||||
def test_success_after_timeout(self):
|
||||
conn, future = self.start_connect(self.addrinfo)
|
||||
self.assert_pending((AF1, 'a'))
|
||||
conn.on_timeout()
|
||||
self.assert_pending((AF1, 'a'), (AF2, 'c'))
|
||||
self.resolve_connect(AF1, 'a', True)
|
||||
self.assertEqual(future.result(), (AF1, 'a', self.streams['a']))
|
||||
# resolving 'c' after completion closes the connection.
|
||||
self.resolve_connect(AF2, 'c', True)
|
||||
self.assertTrue(self.streams.pop('c').closed)
|
||||
|
||||
def test_all_fail(self):
|
||||
conn, future = self.start_connect(self.addrinfo)
|
||||
self.assert_pending((AF1, 'a'))
|
||||
conn.on_timeout()
|
||||
self.assert_pending((AF1, 'a'), (AF2, 'c'))
|
||||
self.resolve_connect(AF2, 'c', False)
|
||||
self.assert_pending((AF1, 'a'), (AF2, 'd'))
|
||||
self.resolve_connect(AF2, 'd', False)
|
||||
# one queue is now empty
|
||||
self.assert_pending((AF1, 'a'))
|
||||
self.resolve_connect(AF1, 'a', False)
|
||||
self.assert_pending((AF1, 'b'))
|
||||
self.assertFalse(future.done())
|
||||
self.resolve_connect(AF1, 'b', False)
|
||||
self.assertRaises(IOError, future.result)
|
||||
|
||||
def test_one_family_timeout_after_connect_timeout(self):
|
||||
conn, future = self.start_connect([(AF1, 'a'), (AF1, 'b')])
|
||||
self.assert_pending((AF1, 'a'))
|
||||
conn.on_connect_timeout()
|
||||
# the connector will close all streams on connect timeout, we
|
||||
# should explicitly pop the connect_future.
|
||||
self.connect_futures.pop((AF1, 'a'))
|
||||
self.assertTrue(self.streams.pop('a').closed)
|
||||
conn.on_timeout()
|
||||
# if the future is set with TimeoutError, we will not iterate next
|
||||
# possible address.
|
||||
self.assert_pending()
|
||||
self.assertEqual(len(conn.streams), 1)
|
||||
self.assert_connector_streams_closed(conn)
|
||||
self.assertRaises(TimeoutError, future.result)
|
||||
|
||||
def test_one_family_success_before_connect_timeout(self):
|
||||
conn, future = self.start_connect([(AF1, 'a'), (AF1, 'b')])
|
||||
self.assert_pending((AF1, 'a'))
|
||||
self.resolve_connect(AF1, 'a', True)
|
||||
conn.on_connect_timeout()
|
||||
self.assert_pending()
|
||||
self.assertEqual(self.streams['a'].closed, False)
|
||||
# success stream will be pop
|
||||
self.assertEqual(len(conn.streams), 0)
|
||||
# streams in connector should be closed after connect timeout
|
||||
self.assert_connector_streams_closed(conn)
|
||||
self.assertEqual(future.result(), (AF1, 'a', self.streams['a']))
|
||||
|
||||
def test_one_family_second_try_after_connect_timeout(self):
|
||||
conn, future = self.start_connect([(AF1, 'a'), (AF1, 'b')])
|
||||
self.assert_pending((AF1, 'a'))
|
||||
self.resolve_connect(AF1, 'a', False)
|
||||
self.assert_pending((AF1, 'b'))
|
||||
conn.on_connect_timeout()
|
||||
self.connect_futures.pop((AF1, 'b'))
|
||||
self.assertTrue(self.streams.pop('b').closed)
|
||||
self.assert_pending()
|
||||
self.assertEqual(len(conn.streams), 2)
|
||||
self.assert_connector_streams_closed(conn)
|
||||
self.assertRaises(TimeoutError, future.result)
|
||||
|
||||
def test_one_family_second_try_failure_before_connect_timeout(self):
|
||||
conn, future = self.start_connect([(AF1, 'a'), (AF1, 'b')])
|
||||
self.assert_pending((AF1, 'a'))
|
||||
self.resolve_connect(AF1, 'a', False)
|
||||
self.assert_pending((AF1, 'b'))
|
||||
self.resolve_connect(AF1, 'b', False)
|
||||
conn.on_connect_timeout()
|
||||
self.assert_pending()
|
||||
self.assertEqual(len(conn.streams), 2)
|
||||
self.assert_connector_streams_closed(conn)
|
||||
self.assertRaises(IOError, future.result)
|
||||
|
||||
def test_two_family_timeout_before_connect_timeout(self):
|
||||
conn, future = self.start_connect(self.addrinfo)
|
||||
self.assert_pending((AF1, 'a'))
|
||||
conn.on_timeout()
|
||||
self.assert_pending((AF1, 'a'), (AF2, 'c'))
|
||||
conn.on_connect_timeout()
|
||||
self.connect_futures.pop((AF1, 'a'))
|
||||
self.assertTrue(self.streams.pop('a').closed)
|
||||
self.connect_futures.pop((AF2, 'c'))
|
||||
self.assertTrue(self.streams.pop('c').closed)
|
||||
self.assert_pending()
|
||||
self.assertEqual(len(conn.streams), 2)
|
||||
self.assert_connector_streams_closed(conn)
|
||||
self.assertRaises(TimeoutError, future.result)
|
||||
|
||||
def test_two_family_success_after_timeout(self):
|
||||
conn, future = self.start_connect(self.addrinfo)
|
||||
self.assert_pending((AF1, 'a'))
|
||||
conn.on_timeout()
|
||||
self.assert_pending((AF1, 'a'), (AF2, 'c'))
|
||||
self.resolve_connect(AF1, 'a', True)
|
||||
# if one of streams succeed, connector will close all other streams
|
||||
self.connect_futures.pop((AF2, 'c'))
|
||||
self.assertTrue(self.streams.pop('c').closed)
|
||||
self.assert_pending()
|
||||
self.assertEqual(len(conn.streams), 1)
|
||||
self.assert_connector_streams_closed(conn)
|
||||
self.assertEqual(future.result(), (AF1, 'a', self.streams['a']))
|
||||
|
||||
def test_two_family_timeout_after_connect_timeout(self):
|
||||
conn, future = self.start_connect(self.addrinfo)
|
||||
self.assert_pending((AF1, 'a'))
|
||||
conn.on_connect_timeout()
|
||||
self.connect_futures.pop((AF1, 'a'))
|
||||
self.assertTrue(self.streams.pop('a').closed)
|
||||
self.assert_pending()
|
||||
conn.on_timeout()
|
||||
# if the future is set with TimeoutError, connector will not
|
||||
# trigger secondary address.
|
||||
self.assert_pending()
|
||||
self.assertEqual(len(conn.streams), 1)
|
||||
self.assert_connector_streams_closed(conn)
|
||||
self.assertRaises(TimeoutError, future.result)
|
||||
Executable
+193
@@ -0,0 +1,193 @@
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import socket
|
||||
import subprocess
|
||||
import sys
|
||||
import textwrap
|
||||
|
||||
from tornado.escape import utf8, to_unicode
|
||||
from tornado import gen
|
||||
from tornado.iostream import IOStream
|
||||
from tornado.log import app_log
|
||||
from tornado.stack_context import NullContext
|
||||
from tornado.tcpserver import TCPServer
|
||||
from tornado.test.util import skipBefore35, skipIfNonUnix, exec_test, unittest
|
||||
from tornado.testing import AsyncTestCase, ExpectLog, bind_unused_port, gen_test
|
||||
|
||||
|
||||
class TCPServerTest(AsyncTestCase):
|
||||
@gen_test
|
||||
def test_handle_stream_coroutine_logging(self):
|
||||
# handle_stream may be a coroutine and any exception in its
|
||||
# Future will be logged.
|
||||
class TestServer(TCPServer):
|
||||
@gen.coroutine
|
||||
def handle_stream(self, stream, address):
|
||||
yield stream.read_bytes(len(b'hello'))
|
||||
stream.close()
|
||||
1 / 0
|
||||
|
||||
server = client = None
|
||||
try:
|
||||
sock, port = bind_unused_port()
|
||||
with NullContext():
|
||||
server = TestServer()
|
||||
server.add_socket(sock)
|
||||
client = IOStream(socket.socket())
|
||||
with ExpectLog(app_log, "Exception in callback"):
|
||||
yield client.connect(('localhost', port))
|
||||
yield client.write(b'hello')
|
||||
yield client.read_until_close()
|
||||
yield gen.moment
|
||||
finally:
|
||||
if server is not None:
|
||||
server.stop()
|
||||
if client is not None:
|
||||
client.close()
|
||||
|
||||
@skipBefore35
|
||||
@gen_test
|
||||
def test_handle_stream_native_coroutine(self):
|
||||
# handle_stream may be a native coroutine.
|
||||
|
||||
namespace = exec_test(globals(), locals(), """
|
||||
class TestServer(TCPServer):
|
||||
async def handle_stream(self, stream, address):
|
||||
stream.write(b'data')
|
||||
stream.close()
|
||||
""")
|
||||
|
||||
sock, port = bind_unused_port()
|
||||
server = namespace['TestServer']()
|
||||
server.add_socket(sock)
|
||||
client = IOStream(socket.socket())
|
||||
yield client.connect(('localhost', port))
|
||||
result = yield client.read_until_close()
|
||||
self.assertEqual(result, b'data')
|
||||
server.stop()
|
||||
client.close()
|
||||
|
||||
def test_stop_twice(self):
|
||||
sock, port = bind_unused_port()
|
||||
server = TCPServer()
|
||||
server.add_socket(sock)
|
||||
server.stop()
|
||||
server.stop()
|
||||
|
||||
@gen_test
|
||||
def test_stop_in_callback(self):
|
||||
# Issue #2069: calling server.stop() in a loop callback should not
|
||||
# raise EBADF when the loop handles other server connection
|
||||
# requests in the same loop iteration
|
||||
|
||||
class TestServer(TCPServer):
|
||||
@gen.coroutine
|
||||
def handle_stream(self, stream, address):
|
||||
server.stop()
|
||||
yield stream.read_until_close()
|
||||
|
||||
sock, port = bind_unused_port()
|
||||
server = TestServer()
|
||||
server.add_socket(sock)
|
||||
server_addr = ('localhost', port)
|
||||
N = 40
|
||||
clients = [IOStream(socket.socket()) for i in range(N)]
|
||||
connected_clients = []
|
||||
|
||||
@gen.coroutine
|
||||
def connect(c):
|
||||
try:
|
||||
yield c.connect(server_addr)
|
||||
except EnvironmentError:
|
||||
pass
|
||||
else:
|
||||
connected_clients.append(c)
|
||||
|
||||
yield [connect(c) for c in clients]
|
||||
|
||||
self.assertGreater(len(connected_clients), 0,
|
||||
"all clients failed connecting")
|
||||
try:
|
||||
if len(connected_clients) == N:
|
||||
# Ideally we'd make the test deterministic, but we're testing
|
||||
# for a race condition in combination with the system's TCP stack...
|
||||
self.skipTest("at least one client should fail connecting "
|
||||
"for the test to be meaningful")
|
||||
finally:
|
||||
for c in connected_clients:
|
||||
c.close()
|
||||
|
||||
# Here tearDown() would re-raise the EBADF encountered in the IO loop
|
||||
|
||||
|
||||
@skipIfNonUnix
|
||||
class TestMultiprocess(unittest.TestCase):
|
||||
# These tests verify that the two multiprocess examples from the
|
||||
# TCPServer docs work. Both tests start a server with three worker
|
||||
# processes, each of which prints its task id to stdout (a single
|
||||
# byte, so we don't have to worry about atomicity of the shared
|
||||
# stdout stream) and then exits.
|
||||
def run_subproc(self, code):
|
||||
proc = subprocess.Popen(sys.executable,
|
||||
stdin=subprocess.PIPE,
|
||||
stdout=subprocess.PIPE)
|
||||
proc.stdin.write(utf8(code))
|
||||
proc.stdin.close()
|
||||
proc.wait()
|
||||
stdout = proc.stdout.read()
|
||||
proc.stdout.close()
|
||||
if proc.returncode != 0:
|
||||
raise RuntimeError("Process returned %d. stdout=%r" % (
|
||||
proc.returncode, stdout))
|
||||
return to_unicode(stdout)
|
||||
|
||||
def test_single(self):
|
||||
# As a sanity check, run the single-process version through this test
|
||||
# harness too.
|
||||
code = textwrap.dedent("""
|
||||
from __future__ import print_function
|
||||
from tornado.ioloop import IOLoop
|
||||
from tornado.tcpserver import TCPServer
|
||||
|
||||
server = TCPServer()
|
||||
server.listen(0, address='127.0.0.1')
|
||||
IOLoop.current().run_sync(lambda: None)
|
||||
print('012', end='')
|
||||
""")
|
||||
out = self.run_subproc(code)
|
||||
self.assertEqual(''.join(sorted(out)), "012")
|
||||
|
||||
def test_simple(self):
|
||||
code = textwrap.dedent("""
|
||||
from __future__ import print_function
|
||||
from tornado.ioloop import IOLoop
|
||||
from tornado.process import task_id
|
||||
from tornado.tcpserver import TCPServer
|
||||
|
||||
server = TCPServer()
|
||||
server.bind(0, address='127.0.0.1')
|
||||
server.start(3)
|
||||
IOLoop.current().run_sync(lambda: None)
|
||||
print(task_id(), end='')
|
||||
""")
|
||||
out = self.run_subproc(code)
|
||||
self.assertEqual(''.join(sorted(out)), "012")
|
||||
|
||||
def test_advanced(self):
|
||||
code = textwrap.dedent("""
|
||||
from __future__ import print_function
|
||||
from tornado.ioloop import IOLoop
|
||||
from tornado.netutil import bind_sockets
|
||||
from tornado.process import fork_processes, task_id
|
||||
from tornado.ioloop import IOLoop
|
||||
from tornado.tcpserver import TCPServer
|
||||
|
||||
sockets = bind_sockets(0, address='127.0.0.1')
|
||||
fork_processes(3)
|
||||
server = TCPServer()
|
||||
server.add_sockets(sockets)
|
||||
IOLoop.current().run_sync(lambda: None)
|
||||
print(task_id(), end='')
|
||||
""")
|
||||
out = self.run_subproc(code)
|
||||
self.assertEqual(''.join(sorted(out)), "012")
|
||||
Executable
+496
@@ -0,0 +1,496 @@
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
from tornado.escape import utf8, native_str, to_unicode
|
||||
from tornado.template import Template, DictLoader, ParseError, Loader
|
||||
from tornado.test.util import unittest
|
||||
from tornado.util import ObjectDict, unicode_type
|
||||
|
||||
|
||||
class TemplateTest(unittest.TestCase):
|
||||
def test_simple(self):
|
||||
template = Template("Hello {{ name }}!")
|
||||
self.assertEqual(template.generate(name="Ben"),
|
||||
b"Hello Ben!")
|
||||
|
||||
def test_bytes(self):
|
||||
template = Template("Hello {{ name }}!")
|
||||
self.assertEqual(template.generate(name=utf8("Ben")),
|
||||
b"Hello Ben!")
|
||||
|
||||
def test_expressions(self):
|
||||
template = Template("2 + 2 = {{ 2 + 2 }}")
|
||||
self.assertEqual(template.generate(), b"2 + 2 = 4")
|
||||
|
||||
def test_comment(self):
|
||||
template = Template("Hello{# TODO i18n #} {{ name }}!")
|
||||
self.assertEqual(template.generate(name=utf8("Ben")),
|
||||
b"Hello Ben!")
|
||||
|
||||
def test_include(self):
|
||||
loader = DictLoader({
|
||||
"index.html": '{% include "header.html" %}\nbody text',
|
||||
"header.html": "header text",
|
||||
})
|
||||
self.assertEqual(loader.load("index.html").generate(),
|
||||
b"header text\nbody text")
|
||||
|
||||
def test_extends(self):
|
||||
loader = DictLoader({
|
||||
"base.html": """\
|
||||
<title>{% block title %}default title{% end %}</title>
|
||||
<body>{% block body %}default body{% end %}</body>
|
||||
""",
|
||||
"page.html": """\
|
||||
{% extends "base.html" %}
|
||||
{% block title %}page title{% end %}
|
||||
{% block body %}page body{% end %}
|
||||
""",
|
||||
})
|
||||
self.assertEqual(loader.load("page.html").generate(),
|
||||
b"<title>page title</title>\n<body>page body</body>\n")
|
||||
|
||||
def test_relative_load(self):
|
||||
loader = DictLoader({
|
||||
"a/1.html": "{% include '2.html' %}",
|
||||
"a/2.html": "{% include '../b/3.html' %}",
|
||||
"b/3.html": "ok",
|
||||
})
|
||||
self.assertEqual(loader.load("a/1.html").generate(),
|
||||
b"ok")
|
||||
|
||||
def test_escaping(self):
|
||||
self.assertRaises(ParseError, lambda: Template("{{"))
|
||||
self.assertRaises(ParseError, lambda: Template("{%"))
|
||||
self.assertEqual(Template("{{!").generate(), b"{{")
|
||||
self.assertEqual(Template("{%!").generate(), b"{%")
|
||||
self.assertEqual(Template("{#!").generate(), b"{#")
|
||||
self.assertEqual(Template("{{ 'expr' }} {{!jquery expr}}").generate(),
|
||||
b"expr {{jquery expr}}")
|
||||
|
||||
def test_unicode_template(self):
|
||||
template = Template(utf8(u"\u00e9"))
|
||||
self.assertEqual(template.generate(), utf8(u"\u00e9"))
|
||||
|
||||
def test_unicode_literal_expression(self):
|
||||
# Unicode literals should be usable in templates. Note that this
|
||||
# test simulates unicode characters appearing directly in the
|
||||
# template file (with utf8 encoding), i.e. \u escapes would not
|
||||
# be used in the template file itself.
|
||||
if str is unicode_type:
|
||||
# python 3 needs a different version of this test since
|
||||
# 2to3 doesn't run on template internals
|
||||
template = Template(utf8(u'{{ "\u00e9" }}'))
|
||||
else:
|
||||
template = Template(utf8(u'{{ u"\u00e9" }}'))
|
||||
self.assertEqual(template.generate(), utf8(u"\u00e9"))
|
||||
|
||||
def test_custom_namespace(self):
|
||||
loader = DictLoader({"test.html": "{{ inc(5) }}"}, namespace={"inc": lambda x: x + 1})
|
||||
self.assertEqual(loader.load("test.html").generate(), b"6")
|
||||
|
||||
def test_apply(self):
|
||||
def upper(s):
|
||||
return s.upper()
|
||||
template = Template(utf8("{% apply upper %}foo{% end %}"))
|
||||
self.assertEqual(template.generate(upper=upper), b"FOO")
|
||||
|
||||
def test_unicode_apply(self):
|
||||
def upper(s):
|
||||
return to_unicode(s).upper()
|
||||
template = Template(utf8(u"{% apply upper %}foo \u00e9{% end %}"))
|
||||
self.assertEqual(template.generate(upper=upper), utf8(u"FOO \u00c9"))
|
||||
|
||||
def test_bytes_apply(self):
|
||||
def upper(s):
|
||||
return utf8(to_unicode(s).upper())
|
||||
template = Template(utf8(u"{% apply upper %}foo \u00e9{% end %}"))
|
||||
self.assertEqual(template.generate(upper=upper), utf8(u"FOO \u00c9"))
|
||||
|
||||
def test_if(self):
|
||||
template = Template(utf8("{% if x > 4 %}yes{% else %}no{% end %}"))
|
||||
self.assertEqual(template.generate(x=5), b"yes")
|
||||
self.assertEqual(template.generate(x=3), b"no")
|
||||
|
||||
def test_if_empty_body(self):
|
||||
template = Template(utf8("{% if True %}{% else %}{% end %}"))
|
||||
self.assertEqual(template.generate(), b"")
|
||||
|
||||
def test_try(self):
|
||||
template = Template(utf8("""{% try %}
|
||||
try{% set y = 1/x %}
|
||||
{% except %}-except
|
||||
{% else %}-else
|
||||
{% finally %}-finally
|
||||
{% end %}"""))
|
||||
self.assertEqual(template.generate(x=1), b"\ntry\n-else\n-finally\n")
|
||||
self.assertEqual(template.generate(x=0), b"\ntry-except\n-finally\n")
|
||||
|
||||
def test_comment_directive(self):
|
||||
template = Template(utf8("{% comment blah blah %}foo"))
|
||||
self.assertEqual(template.generate(), b"foo")
|
||||
|
||||
def test_break_continue(self):
|
||||
template = Template(utf8("""\
|
||||
{% for i in range(10) %}
|
||||
{% if i == 2 %}
|
||||
{% continue %}
|
||||
{% end %}
|
||||
{{ i }}
|
||||
{% if i == 6 %}
|
||||
{% break %}
|
||||
{% end %}
|
||||
{% end %}"""))
|
||||
result = template.generate()
|
||||
# remove extraneous whitespace
|
||||
result = b''.join(result.split())
|
||||
self.assertEqual(result, b"013456")
|
||||
|
||||
def test_break_outside_loop(self):
|
||||
try:
|
||||
Template(utf8("{% break %}"))
|
||||
raise Exception("Did not get expected exception")
|
||||
except ParseError:
|
||||
pass
|
||||
|
||||
def test_break_in_apply(self):
|
||||
# This test verifies current behavior, although of course it would
|
||||
# be nice if apply didn't cause seemingly unrelated breakage
|
||||
try:
|
||||
Template(utf8("{% for i in [] %}{% apply foo %}{% break %}{% end %}{% end %}"))
|
||||
raise Exception("Did not get expected exception")
|
||||
except ParseError:
|
||||
pass
|
||||
|
||||
@unittest.skipIf(sys.version_info >= division.getMandatoryRelease(),
|
||||
'no testable future imports')
|
||||
def test_no_inherit_future(self):
|
||||
# This file has from __future__ import division...
|
||||
self.assertEqual(1 / 2, 0.5)
|
||||
# ...but the template doesn't
|
||||
template = Template('{{ 1 / 2 }}')
|
||||
self.assertEqual(template.generate(), '0')
|
||||
|
||||
def test_non_ascii_name(self):
|
||||
loader = DictLoader({u"t\u00e9st.html": "hello"})
|
||||
self.assertEqual(loader.load(u"t\u00e9st.html").generate(), b"hello")
|
||||
|
||||
|
||||
class StackTraceTest(unittest.TestCase):
|
||||
def test_error_line_number_expression(self):
|
||||
loader = DictLoader({"test.html": """one
|
||||
two{{1/0}}
|
||||
three
|
||||
"""})
|
||||
try:
|
||||
loader.load("test.html").generate()
|
||||
self.fail("did not get expected exception")
|
||||
except ZeroDivisionError:
|
||||
self.assertTrue("# test.html:2" in traceback.format_exc())
|
||||
|
||||
def test_error_line_number_directive(self):
|
||||
loader = DictLoader({"test.html": """one
|
||||
two{%if 1/0%}
|
||||
three{%end%}
|
||||
"""})
|
||||
try:
|
||||
loader.load("test.html").generate()
|
||||
self.fail("did not get expected exception")
|
||||
except ZeroDivisionError:
|
||||
self.assertTrue("# test.html:2" in traceback.format_exc())
|
||||
|
||||
def test_error_line_number_module(self):
|
||||
loader = None
|
||||
|
||||
def load_generate(path, **kwargs):
|
||||
return loader.load(path).generate(**kwargs)
|
||||
|
||||
loader = DictLoader({
|
||||
"base.html": "{% module Template('sub.html') %}",
|
||||
"sub.html": "{{1/0}}",
|
||||
}, namespace={"_tt_modules": ObjectDict(Template=load_generate)})
|
||||
try:
|
||||
loader.load("base.html").generate()
|
||||
self.fail("did not get expected exception")
|
||||
except ZeroDivisionError:
|
||||
exc_stack = traceback.format_exc()
|
||||
self.assertTrue('# base.html:1' in exc_stack)
|
||||
self.assertTrue('# sub.html:1' in exc_stack)
|
||||
|
||||
def test_error_line_number_include(self):
|
||||
loader = DictLoader({
|
||||
"base.html": "{% include 'sub.html' %}",
|
||||
"sub.html": "{{1/0}}",
|
||||
})
|
||||
try:
|
||||
loader.load("base.html").generate()
|
||||
self.fail("did not get expected exception")
|
||||
except ZeroDivisionError:
|
||||
self.assertTrue("# sub.html:1 (via base.html:1)" in
|
||||
traceback.format_exc())
|
||||
|
||||
def test_error_line_number_extends_base_error(self):
|
||||
loader = DictLoader({
|
||||
"base.html": "{{1/0}}",
|
||||
"sub.html": "{% extends 'base.html' %}",
|
||||
})
|
||||
try:
|
||||
loader.load("sub.html").generate()
|
||||
self.fail("did not get expected exception")
|
||||
except ZeroDivisionError:
|
||||
exc_stack = traceback.format_exc()
|
||||
self.assertTrue("# base.html:1" in exc_stack)
|
||||
|
||||
def test_error_line_number_extends_sub_error(self):
|
||||
loader = DictLoader({
|
||||
"base.html": "{% block 'block' %}{% end %}",
|
||||
"sub.html": """
|
||||
{% extends 'base.html' %}
|
||||
{% block 'block' %}
|
||||
{{1/0}}
|
||||
{% end %}
|
||||
"""})
|
||||
try:
|
||||
loader.load("sub.html").generate()
|
||||
self.fail("did not get expected exception")
|
||||
except ZeroDivisionError:
|
||||
self.assertTrue("# sub.html:4 (via base.html:1)" in
|
||||
traceback.format_exc())
|
||||
|
||||
def test_multi_includes(self):
|
||||
loader = DictLoader({
|
||||
"a.html": "{% include 'b.html' %}",
|
||||
"b.html": "{% include 'c.html' %}",
|
||||
"c.html": "{{1/0}}",
|
||||
})
|
||||
try:
|
||||
loader.load("a.html").generate()
|
||||
self.fail("did not get expected exception")
|
||||
except ZeroDivisionError:
|
||||
self.assertTrue("# c.html:1 (via b.html:1, a.html:1)" in
|
||||
traceback.format_exc())
|
||||
|
||||
|
||||
class ParseErrorDetailTest(unittest.TestCase):
|
||||
def test_details(self):
|
||||
loader = DictLoader({
|
||||
"foo.html": "\n\n{{",
|
||||
})
|
||||
with self.assertRaises(ParseError) as cm:
|
||||
loader.load("foo.html")
|
||||
self.assertEqual("Missing end expression }} at foo.html:3",
|
||||
str(cm.exception))
|
||||
self.assertEqual("foo.html", cm.exception.filename)
|
||||
self.assertEqual(3, cm.exception.lineno)
|
||||
|
||||
def test_custom_parse_error(self):
|
||||
# Make sure that ParseErrors remain compatible with their
|
||||
# pre-4.3 signature.
|
||||
self.assertEqual("asdf at None:0", str(ParseError("asdf")))
|
||||
|
||||
|
||||
class AutoEscapeTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.templates = {
|
||||
"escaped.html": "{% autoescape xhtml_escape %}{{ name }}",
|
||||
"unescaped.html": "{% autoescape None %}{{ name }}",
|
||||
"default.html": "{{ name }}",
|
||||
|
||||
"include.html": """\
|
||||
escaped: {% include 'escaped.html' %}
|
||||
unescaped: {% include 'unescaped.html' %}
|
||||
default: {% include 'default.html' %}
|
||||
""",
|
||||
|
||||
"escaped_block.html": """\
|
||||
{% autoescape xhtml_escape %}\
|
||||
{% block name %}base: {{ name }}{% end %}""",
|
||||
"unescaped_block.html": """\
|
||||
{% autoescape None %}\
|
||||
{% block name %}base: {{ name }}{% end %}""",
|
||||
|
||||
# Extend a base template with different autoescape policy,
|
||||
# with and without overriding the base's blocks
|
||||
"escaped_extends_unescaped.html": """\
|
||||
{% autoescape xhtml_escape %}\
|
||||
{% extends "unescaped_block.html" %}""",
|
||||
"escaped_overrides_unescaped.html": """\
|
||||
{% autoescape xhtml_escape %}\
|
||||
{% extends "unescaped_block.html" %}\
|
||||
{% block name %}extended: {{ name }}{% end %}""",
|
||||
"unescaped_extends_escaped.html": """\
|
||||
{% autoescape None %}\
|
||||
{% extends "escaped_block.html" %}""",
|
||||
"unescaped_overrides_escaped.html": """\
|
||||
{% autoescape None %}\
|
||||
{% extends "escaped_block.html" %}\
|
||||
{% block name %}extended: {{ name }}{% end %}""",
|
||||
|
||||
"raw_expression.html": """\
|
||||
{% autoescape xhtml_escape %}\
|
||||
expr: {{ name }}
|
||||
raw: {% raw name %}""",
|
||||
}
|
||||
|
||||
def test_default_off(self):
|
||||
loader = DictLoader(self.templates, autoescape=None)
|
||||
name = "Bobby <table>s"
|
||||
self.assertEqual(loader.load("escaped.html").generate(name=name),
|
||||
b"Bobby <table>s")
|
||||
self.assertEqual(loader.load("unescaped.html").generate(name=name),
|
||||
b"Bobby <table>s")
|
||||
self.assertEqual(loader.load("default.html").generate(name=name),
|
||||
b"Bobby <table>s")
|
||||
|
||||
self.assertEqual(loader.load("include.html").generate(name=name),
|
||||
b"escaped: Bobby <table>s\n"
|
||||
b"unescaped: Bobby <table>s\n"
|
||||
b"default: Bobby <table>s\n")
|
||||
|
||||
def test_default_on(self):
|
||||
loader = DictLoader(self.templates, autoescape="xhtml_escape")
|
||||
name = "Bobby <table>s"
|
||||
self.assertEqual(loader.load("escaped.html").generate(name=name),
|
||||
b"Bobby <table>s")
|
||||
self.assertEqual(loader.load("unescaped.html").generate(name=name),
|
||||
b"Bobby <table>s")
|
||||
self.assertEqual(loader.load("default.html").generate(name=name),
|
||||
b"Bobby <table>s")
|
||||
|
||||
self.assertEqual(loader.load("include.html").generate(name=name),
|
||||
b"escaped: Bobby <table>s\n"
|
||||
b"unescaped: Bobby <table>s\n"
|
||||
b"default: Bobby <table>s\n")
|
||||
|
||||
def test_unextended_block(self):
|
||||
loader = DictLoader(self.templates)
|
||||
name = "<script>"
|
||||
self.assertEqual(loader.load("escaped_block.html").generate(name=name),
|
||||
b"base: <script>")
|
||||
self.assertEqual(loader.load("unescaped_block.html").generate(name=name),
|
||||
b"base: <script>")
|
||||
|
||||
def test_extended_block(self):
|
||||
loader = DictLoader(self.templates)
|
||||
|
||||
def render(name):
|
||||
return loader.load(name).generate(name="<script>")
|
||||
self.assertEqual(render("escaped_extends_unescaped.html"),
|
||||
b"base: <script>")
|
||||
self.assertEqual(render("escaped_overrides_unescaped.html"),
|
||||
b"extended: <script>")
|
||||
|
||||
self.assertEqual(render("unescaped_extends_escaped.html"),
|
||||
b"base: <script>")
|
||||
self.assertEqual(render("unescaped_overrides_escaped.html"),
|
||||
b"extended: <script>")
|
||||
|
||||
def test_raw_expression(self):
|
||||
loader = DictLoader(self.templates)
|
||||
|
||||
def render(name):
|
||||
return loader.load(name).generate(name='<>&"')
|
||||
self.assertEqual(render("raw_expression.html"),
|
||||
b"expr: <>&"\n"
|
||||
b"raw: <>&\"")
|
||||
|
||||
def test_custom_escape(self):
|
||||
loader = DictLoader({"foo.py":
|
||||
"{% autoescape py_escape %}s = {{ name }}\n"})
|
||||
|
||||
def py_escape(s):
|
||||
self.assertEqual(type(s), bytes)
|
||||
return repr(native_str(s))
|
||||
|
||||
def render(template, name):
|
||||
return loader.load(template).generate(py_escape=py_escape,
|
||||
name=name)
|
||||
self.assertEqual(render("foo.py", "<html>"),
|
||||
b"s = '<html>'\n")
|
||||
self.assertEqual(render("foo.py", "';sys.exit()"),
|
||||
b"""s = "';sys.exit()"\n""")
|
||||
self.assertEqual(render("foo.py", ["not a string"]),
|
||||
b"""s = "['not a string']"\n""")
|
||||
|
||||
def test_manual_minimize_whitespace(self):
|
||||
# Whitespace including newlines is allowed within template tags
|
||||
# and directives, and this is one way to avoid long lines while
|
||||
# keeping extra whitespace out of the rendered output.
|
||||
loader = DictLoader({'foo.txt': """\
|
||||
{% for i in items
|
||||
%}{% if i > 0 %}, {% end %}{#
|
||||
#}{{i
|
||||
}}{% end
|
||||
%}""",
|
||||
})
|
||||
self.assertEqual(loader.load("foo.txt").generate(items=range(5)),
|
||||
b"0, 1, 2, 3, 4")
|
||||
|
||||
def test_whitespace_by_filename(self):
|
||||
# Default whitespace handling depends on the template filename.
|
||||
loader = DictLoader({
|
||||
"foo.html": " \n\t\n asdf\t ",
|
||||
"bar.js": " \n\n\n\t qwer ",
|
||||
"baz.txt": "\t zxcv\n\n",
|
||||
"include.html": " {% include baz.txt %} \n ",
|
||||
"include.txt": "\t\t{% include foo.html %} ",
|
||||
})
|
||||
|
||||
# HTML and JS files have whitespace compressed by default.
|
||||
self.assertEqual(loader.load("foo.html").generate(),
|
||||
b"\nasdf ")
|
||||
self.assertEqual(loader.load("bar.js").generate(),
|
||||
b"\nqwer ")
|
||||
# TXT files do not.
|
||||
self.assertEqual(loader.load("baz.txt").generate(),
|
||||
b"\t zxcv\n\n")
|
||||
|
||||
# Each file maintains its own status even when included in
|
||||
# a file of the other type.
|
||||
self.assertEqual(loader.load("include.html").generate(),
|
||||
b" \t zxcv\n\n\n")
|
||||
self.assertEqual(loader.load("include.txt").generate(),
|
||||
b"\t\t\nasdf ")
|
||||
|
||||
def test_whitespace_by_loader(self):
|
||||
templates = {
|
||||
"foo.html": "\t\tfoo\n\n",
|
||||
"bar.txt": "\t\tbar\n\n",
|
||||
}
|
||||
loader = DictLoader(templates, whitespace='all')
|
||||
self.assertEqual(loader.load("foo.html").generate(), b"\t\tfoo\n\n")
|
||||
self.assertEqual(loader.load("bar.txt").generate(), b"\t\tbar\n\n")
|
||||
|
||||
loader = DictLoader(templates, whitespace='single')
|
||||
self.assertEqual(loader.load("foo.html").generate(), b" foo\n")
|
||||
self.assertEqual(loader.load("bar.txt").generate(), b" bar\n")
|
||||
|
||||
loader = DictLoader(templates, whitespace='oneline')
|
||||
self.assertEqual(loader.load("foo.html").generate(), b" foo ")
|
||||
self.assertEqual(loader.load("bar.txt").generate(), b" bar ")
|
||||
|
||||
def test_whitespace_directive(self):
|
||||
loader = DictLoader({
|
||||
"foo.html": """\
|
||||
{% whitespace oneline %}
|
||||
{% for i in range(3) %}
|
||||
{{ i }}
|
||||
{% end %}
|
||||
{% whitespace all %}
|
||||
pre\tformatted
|
||||
"""})
|
||||
self.assertEqual(loader.load("foo.html").generate(),
|
||||
b" 0 1 2 \n pre\tformatted\n")
|
||||
|
||||
|
||||
class TemplateLoaderTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.loader = Loader(os.path.join(os.path.dirname(__file__), "templates"))
|
||||
|
||||
def test_utf8_in_file(self):
|
||||
tmpl = self.loader.load("utf8.html")
|
||||
result = tmpl.generate()
|
||||
self.assertEqual(to_unicode(result).strip(), u"H\u00e9llo")
|
||||
Executable
+1
@@ -0,0 +1 @@
|
||||
Héllo
|
||||
Executable
+15
@@ -0,0 +1,15 @@
|
||||
-----BEGIN CERTIFICATE-----
|
||||
MIICSDCCAbGgAwIBAgIJAN1oTowzMbkzMA0GCSqGSIb3DQEBBQUAMD0xCzAJBgNV
|
||||
BAYTAlVTMRMwEQYDVQQIDApDYWxpZm9ybmlhMRkwFwYDVQQKDBBUb3JuYWRvIFdl
|
||||
YiBUZXN0MB4XDTEwMDgyNTE4MjQ0NFoXDTIwMDgyMjE4MjQ0NFowPTELMAkGA1UE
|
||||
BhMCVVMxEzARBgNVBAgMCkNhbGlmb3JuaWExGTAXBgNVBAoMEFRvcm5hZG8gV2Vi
|
||||
IFRlc3QwgZ8wDQYJKoZIhvcNAQEBBQADgY0AMIGJAoGBALirW3mX4jbdFse2aZwW
|
||||
zszCJ1IsRDrzALpbvMYLLbIZqo+Z8v5aERKTRQpXFqGaZyY+tdwYy7X7YXcLtKqv
|
||||
jnw/MSeIaqkw5pROKz5aR0nkPLvcTmhJVLVPCLc8dFnIlu8aC9TrDhr90P+PzU39
|
||||
UG7zLweA9zXKBuW3Tjo5dMP3AgMBAAGjUDBOMB0GA1UdDgQWBBRhJjMBYrzddCFr
|
||||
/0vvPyHMeqgo0TAfBgNVHSMEGDAWgBRhJjMBYrzddCFr/0vvPyHMeqgo0TAMBgNV
|
||||
HRMEBTADAQH/MA0GCSqGSIb3DQEBBQUAA4GBAGP6GaxSfb21bikcqaK3ZKCC1sRJ
|
||||
tiCuvJZbBUFUCAzl05dYUfJZim/oWK+GqyUkUB8ciYivUNnn9OtS7DnlTgT2ws2e
|
||||
lNgn5cuFXoAGcHXzVlHG3yoywYBf3y0Dn20uzrlLXUWJAzoSLOt2LTaXvwlgm7hF
|
||||
W1q8SQ6UBshRw2X0
|
||||
-----END CERTIFICATE-----
|
||||
Executable
+16
@@ -0,0 +1,16 @@
|
||||
-----BEGIN PRIVATE KEY-----
|
||||
MIICeAIBADANBgkqhkiG9w0BAQEFAASCAmIwggJeAgEAAoGBALirW3mX4jbdFse2
|
||||
aZwWzszCJ1IsRDrzALpbvMYLLbIZqo+Z8v5aERKTRQpXFqGaZyY+tdwYy7X7YXcL
|
||||
tKqvjnw/MSeIaqkw5pROKz5aR0nkPLvcTmhJVLVPCLc8dFnIlu8aC9TrDhr90P+P
|
||||
zU39UG7zLweA9zXKBuW3Tjo5dMP3AgMBAAECgYEAiygNaWYrf95AcUQi9w00zpUr
|
||||
nj9fNvCwxr2kVbRMvd2balS/CC4EmXPCXdVcZ3B7dBVjYzSIJV0Fh/iZLtnVysD9
|
||||
fcNMZ+Cz71b/T0ItsNYOsJk0qUVyP52uqsqkNppIPJsD19C+ZeMLZj6iEiylZyl8
|
||||
2U16c/kVIjER63mUEGkCQQDayQOTGPJrKHqPAkUqzeJkfvHH2yCf+cySU+w6ezyr
|
||||
j9yxcq8aZoLusCebDVT+kz7RqnD5JePFvB38cMuepYBLAkEA2BTFdZx30f4moPNv
|
||||
JlXlPNJMUTUzsXG7n4vNc+18O5ous0NGQII8jZWrIcTrP8wiP9fF3JwUsKrJhcBn
|
||||
xRs3hQJBAIDUgz1YIE+HW3vgi1gkOh6RPdBAsVpiXtr/fggFz3j60qrO7FswaAMj
|
||||
SX8c/6KUlBYkNjgP3qruFf4zcUNvEzcCQQCaioCPFVE9ByBpjLG6IUTKsz2R9xL5
|
||||
nfYqrbpLZ1aq6iLsYvkjugHE4X57sHLwNfdo4dHJbnf9wqhO2MVe25BhAkBdKYpY
|
||||
7OKc/2mmMbJDhVBgoixz/muN/5VjdfbvVY48naZkJF1p1tmogqPC5F1jPCS4rM+S
|
||||
FfPJIHRNEn2oktw5
|
||||
-----END PRIVATE KEY-----
|
||||
Executable
+350
@@ -0,0 +1,350 @@
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
from tornado import gen, ioloop
|
||||
from tornado.log import app_log
|
||||
from tornado.simple_httpclient import SimpleAsyncHTTPClient, HTTPTimeoutError
|
||||
from tornado.test.util import unittest, skipBefore35, exec_test, ignore_deprecation
|
||||
from tornado.testing import AsyncHTTPTestCase, AsyncTestCase, bind_unused_port, gen_test, ExpectLog
|
||||
from tornado.web import Application
|
||||
import contextlib
|
||||
import os
|
||||
import platform
|
||||
import traceback
|
||||
import warnings
|
||||
|
||||
try:
|
||||
import asyncio
|
||||
except ImportError:
|
||||
asyncio = None
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def set_environ(name, value):
|
||||
old_value = os.environ.get(name)
|
||||
os.environ[name] = value
|
||||
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
if old_value is None:
|
||||
del os.environ[name]
|
||||
else:
|
||||
os.environ[name] = old_value
|
||||
|
||||
|
||||
class AsyncTestCaseTest(AsyncTestCase):
|
||||
def test_exception_in_callback(self):
|
||||
with ignore_deprecation():
|
||||
self.io_loop.add_callback(lambda: 1 / 0)
|
||||
try:
|
||||
self.wait()
|
||||
self.fail("did not get expected exception")
|
||||
except ZeroDivisionError:
|
||||
pass
|
||||
|
||||
def test_wait_timeout(self):
|
||||
time = self.io_loop.time
|
||||
|
||||
# Accept default 5-second timeout, no error
|
||||
self.io_loop.add_timeout(time() + 0.01, self.stop)
|
||||
self.wait()
|
||||
|
||||
# Timeout passed to wait()
|
||||
self.io_loop.add_timeout(time() + 1, self.stop)
|
||||
with self.assertRaises(self.failureException):
|
||||
self.wait(timeout=0.01)
|
||||
|
||||
# Timeout set with environment variable
|
||||
self.io_loop.add_timeout(time() + 1, self.stop)
|
||||
with set_environ('ASYNC_TEST_TIMEOUT', '0.01'):
|
||||
with self.assertRaises(self.failureException):
|
||||
self.wait()
|
||||
|
||||
def test_subsequent_wait_calls(self):
|
||||
"""
|
||||
This test makes sure that a second call to wait()
|
||||
clears the first timeout.
|
||||
"""
|
||||
self.io_loop.add_timeout(self.io_loop.time() + 0.00, self.stop)
|
||||
self.wait(timeout=0.02)
|
||||
self.io_loop.add_timeout(self.io_loop.time() + 0.03, self.stop)
|
||||
self.wait(timeout=0.15)
|
||||
|
||||
def test_multiple_errors(self):
|
||||
with ignore_deprecation():
|
||||
def fail(message):
|
||||
raise Exception(message)
|
||||
self.io_loop.add_callback(lambda: fail("error one"))
|
||||
self.io_loop.add_callback(lambda: fail("error two"))
|
||||
# The first error gets raised; the second gets logged.
|
||||
with ExpectLog(app_log, "multiple unhandled exceptions"):
|
||||
with self.assertRaises(Exception) as cm:
|
||||
self.wait()
|
||||
self.assertEqual(str(cm.exception), "error one")
|
||||
|
||||
|
||||
class AsyncHTTPTestCaseTest(AsyncHTTPTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super(AsyncHTTPTestCaseTest, cls).setUpClass()
|
||||
# An unused port is bound so we can make requests upon it without
|
||||
# impacting a real local web server.
|
||||
cls.external_sock, cls.external_port = bind_unused_port()
|
||||
|
||||
def get_app(self):
|
||||
return Application()
|
||||
|
||||
def test_fetch_segment(self):
|
||||
path = '/path'
|
||||
response = self.fetch(path)
|
||||
self.assertEqual(response.request.url, self.get_url(path))
|
||||
|
||||
@gen_test
|
||||
def test_fetch_full_http_url(self):
|
||||
path = 'http://localhost:%d/path' % self.external_port
|
||||
|
||||
with contextlib.closing(SimpleAsyncHTTPClient(force_instance=True)) as client:
|
||||
with self.assertRaises(HTTPTimeoutError) as cm:
|
||||
yield client.fetch(path, request_timeout=0.1, raise_error=True)
|
||||
self.assertEqual(cm.exception.response.request.url, path)
|
||||
|
||||
@gen_test
|
||||
def test_fetch_full_https_url(self):
|
||||
path = 'https://localhost:%d/path' % self.external_port
|
||||
|
||||
with contextlib.closing(SimpleAsyncHTTPClient(force_instance=True)) as client:
|
||||
with self.assertRaises(HTTPTimeoutError) as cm:
|
||||
yield client.fetch(path, request_timeout=0.1, raise_error=True)
|
||||
self.assertEqual(cm.exception.response.request.url, path)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
cls.external_sock.close()
|
||||
super(AsyncHTTPTestCaseTest, cls).tearDownClass()
|
||||
|
||||
|
||||
class AsyncTestCaseWrapperTest(unittest.TestCase):
|
||||
def test_undecorated_generator(self):
|
||||
class Test(AsyncTestCase):
|
||||
def test_gen(self):
|
||||
yield
|
||||
test = Test('test_gen')
|
||||
result = unittest.TestResult()
|
||||
test.run(result)
|
||||
self.assertEqual(len(result.errors), 1)
|
||||
self.assertIn("should be decorated", result.errors[0][1])
|
||||
|
||||
@skipBefore35
|
||||
@unittest.skipIf(platform.python_implementation() == 'PyPy',
|
||||
'pypy destructor warnings cannot be silenced')
|
||||
def test_undecorated_coroutine(self):
|
||||
namespace = exec_test(globals(), locals(), """
|
||||
class Test(AsyncTestCase):
|
||||
async def test_coro(self):
|
||||
pass
|
||||
""")
|
||||
|
||||
test_class = namespace['Test']
|
||||
test = test_class('test_coro')
|
||||
result = unittest.TestResult()
|
||||
|
||||
# Silence "RuntimeWarning: coroutine 'test_coro' was never awaited".
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter('ignore')
|
||||
test.run(result)
|
||||
|
||||
self.assertEqual(len(result.errors), 1)
|
||||
self.assertIn("should be decorated", result.errors[0][1])
|
||||
|
||||
def test_undecorated_generator_with_skip(self):
|
||||
class Test(AsyncTestCase):
|
||||
@unittest.skip("don't run this")
|
||||
def test_gen(self):
|
||||
yield
|
||||
test = Test('test_gen')
|
||||
result = unittest.TestResult()
|
||||
test.run(result)
|
||||
self.assertEqual(len(result.errors), 0)
|
||||
self.assertEqual(len(result.skipped), 1)
|
||||
|
||||
def test_other_return(self):
|
||||
class Test(AsyncTestCase):
|
||||
def test_other_return(self):
|
||||
return 42
|
||||
test = Test('test_other_return')
|
||||
result = unittest.TestResult()
|
||||
test.run(result)
|
||||
self.assertEqual(len(result.errors), 1)
|
||||
self.assertIn("Return value from test method ignored", result.errors[0][1])
|
||||
|
||||
|
||||
class SetUpTearDownTest(unittest.TestCase):
|
||||
def test_set_up_tear_down(self):
|
||||
"""
|
||||
This test makes sure that AsyncTestCase calls super methods for
|
||||
setUp and tearDown.
|
||||
|
||||
InheritBoth is a subclass of both AsyncTestCase and
|
||||
SetUpTearDown, with the ordering so that the super of
|
||||
AsyncTestCase will be SetUpTearDown.
|
||||
"""
|
||||
events = []
|
||||
result = unittest.TestResult()
|
||||
|
||||
class SetUpTearDown(unittest.TestCase):
|
||||
def setUp(self):
|
||||
events.append('setUp')
|
||||
|
||||
def tearDown(self):
|
||||
events.append('tearDown')
|
||||
|
||||
class InheritBoth(AsyncTestCase, SetUpTearDown):
|
||||
def test(self):
|
||||
events.append('test')
|
||||
|
||||
InheritBoth('test').run(result)
|
||||
expected = ['setUp', 'test', 'tearDown']
|
||||
self.assertEqual(expected, events)
|
||||
|
||||
|
||||
class GenTest(AsyncTestCase):
|
||||
def setUp(self):
|
||||
super(GenTest, self).setUp()
|
||||
self.finished = False
|
||||
|
||||
def tearDown(self):
|
||||
self.assertTrue(self.finished)
|
||||
super(GenTest, self).tearDown()
|
||||
|
||||
@gen_test
|
||||
def test_sync(self):
|
||||
self.finished = True
|
||||
|
||||
@gen_test
|
||||
def test_async(self):
|
||||
yield gen.moment
|
||||
self.finished = True
|
||||
|
||||
def test_timeout(self):
|
||||
# Set a short timeout and exceed it.
|
||||
@gen_test(timeout=0.1)
|
||||
def test(self):
|
||||
yield gen.sleep(1)
|
||||
|
||||
# This can't use assertRaises because we need to inspect the
|
||||
# exc_info triple (and not just the exception object)
|
||||
try:
|
||||
test(self)
|
||||
self.fail("did not get expected exception")
|
||||
except ioloop.TimeoutError:
|
||||
# The stack trace should blame the add_timeout line, not just
|
||||
# unrelated IOLoop/testing internals.
|
||||
self.assertIn(
|
||||
"gen.sleep(1)",
|
||||
traceback.format_exc())
|
||||
|
||||
self.finished = True
|
||||
|
||||
def test_no_timeout(self):
|
||||
# A test that does not exceed its timeout should succeed.
|
||||
@gen_test(timeout=1)
|
||||
def test(self):
|
||||
yield gen.sleep(0.1)
|
||||
|
||||
test(self)
|
||||
self.finished = True
|
||||
|
||||
def test_timeout_environment_variable(self):
|
||||
@gen_test(timeout=0.5)
|
||||
def test_long_timeout(self):
|
||||
yield gen.sleep(0.25)
|
||||
|
||||
# Uses provided timeout of 0.5 seconds, doesn't time out.
|
||||
with set_environ('ASYNC_TEST_TIMEOUT', '0.1'):
|
||||
test_long_timeout(self)
|
||||
|
||||
self.finished = True
|
||||
|
||||
def test_no_timeout_environment_variable(self):
|
||||
@gen_test(timeout=0.01)
|
||||
def test_short_timeout(self):
|
||||
yield gen.sleep(1)
|
||||
|
||||
# Uses environment-variable timeout of 0.1, times out.
|
||||
with set_environ('ASYNC_TEST_TIMEOUT', '0.1'):
|
||||
with self.assertRaises(ioloop.TimeoutError):
|
||||
test_short_timeout(self)
|
||||
|
||||
self.finished = True
|
||||
|
||||
def test_with_method_args(self):
|
||||
@gen_test
|
||||
def test_with_args(self, *args):
|
||||
self.assertEqual(args, ('test',))
|
||||
yield gen.moment
|
||||
|
||||
test_with_args(self, 'test')
|
||||
self.finished = True
|
||||
|
||||
def test_with_method_kwargs(self):
|
||||
@gen_test
|
||||
def test_with_kwargs(self, **kwargs):
|
||||
self.assertDictEqual(kwargs, {'test': 'test'})
|
||||
yield gen.moment
|
||||
|
||||
test_with_kwargs(self, test='test')
|
||||
self.finished = True
|
||||
|
||||
@skipBefore35
|
||||
def test_native_coroutine(self):
|
||||
namespace = exec_test(globals(), locals(), """
|
||||
@gen_test
|
||||
async def test(self):
|
||||
self.finished = True
|
||||
""")
|
||||
|
||||
namespace['test'](self)
|
||||
|
||||
@skipBefore35
|
||||
def test_native_coroutine_timeout(self):
|
||||
# Set a short timeout and exceed it.
|
||||
namespace = exec_test(globals(), locals(), """
|
||||
@gen_test(timeout=0.1)
|
||||
async def test(self):
|
||||
await gen.sleep(1)
|
||||
""")
|
||||
|
||||
try:
|
||||
namespace['test'](self)
|
||||
self.fail("did not get expected exception")
|
||||
except ioloop.TimeoutError:
|
||||
self.finished = True
|
||||
|
||||
|
||||
@unittest.skipIf(asyncio is None, "asyncio module not present")
|
||||
class GetNewIOLoopTest(AsyncTestCase):
|
||||
def get_new_ioloop(self):
|
||||
# Use the current loop instead of creating a new one here.
|
||||
return ioloop.IOLoop.current()
|
||||
|
||||
def setUp(self):
|
||||
# This simulates the effect of an asyncio test harness like
|
||||
# pytest-asyncio.
|
||||
self.orig_loop = asyncio.get_event_loop()
|
||||
self.new_loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(self.new_loop)
|
||||
super(GetNewIOLoopTest, self).setUp()
|
||||
|
||||
def tearDown(self):
|
||||
super(GetNewIOLoopTest, self).tearDown()
|
||||
# AsyncTestCase must not affect the existing asyncio loop.
|
||||
self.assertFalse(asyncio.get_event_loop().is_closed())
|
||||
asyncio.set_event_loop(self.orig_loop)
|
||||
self.new_loop.close()
|
||||
|
||||
def test_loop(self):
|
||||
self.assertIs(self.io_loop.asyncio_loop, self.new_loop)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Executable
+729
@@ -0,0 +1,729 @@
|
||||
# Author: Ovidiu Predescu
|
||||
# Date: July 2011
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
Unittest for the twisted-style reactor.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import signal
|
||||
import sys
|
||||
import tempfile
|
||||
import threading
|
||||
import warnings
|
||||
|
||||
from tornado.escape import utf8
|
||||
from tornado import gen
|
||||
from tornado.httpclient import AsyncHTTPClient
|
||||
from tornado.httpserver import HTTPServer
|
||||
from tornado.ioloop import IOLoop, PollIOLoop
|
||||
from tornado.platform.auto import set_close_exec
|
||||
from tornado.testing import bind_unused_port
|
||||
from tornado.test.util import unittest
|
||||
from tornado.util import import_object, PY3
|
||||
from tornado.web import RequestHandler, Application
|
||||
|
||||
try:
|
||||
import fcntl
|
||||
from twisted.internet.defer import Deferred, inlineCallbacks, returnValue # type: ignore
|
||||
from twisted.internet.interfaces import IReadDescriptor, IWriteDescriptor # type: ignore
|
||||
from twisted.internet.protocol import Protocol # type: ignore
|
||||
from twisted.python import log # type: ignore
|
||||
from tornado.platform.twisted import TornadoReactor, TwistedIOLoop
|
||||
from zope.interface import implementer # type: ignore
|
||||
have_twisted = True
|
||||
except ImportError:
|
||||
have_twisted = False
|
||||
|
||||
# The core of Twisted 12.3.0 is available on python 3, but twisted.web is not
|
||||
# so test for it separately.
|
||||
try:
|
||||
from twisted.web.client import Agent, readBody # type: ignore
|
||||
from twisted.web.resource import Resource # type: ignore
|
||||
from twisted.web.server import Site # type: ignore
|
||||
# As of Twisted 15.0.0, twisted.web is present but fails our
|
||||
# tests due to internal str/bytes errors.
|
||||
have_twisted_web = sys.version_info < (3,)
|
||||
except ImportError:
|
||||
have_twisted_web = False
|
||||
|
||||
if PY3:
|
||||
import _thread as thread
|
||||
else:
|
||||
import thread
|
||||
ResourceWarning = None
|
||||
|
||||
try:
|
||||
import asyncio
|
||||
except ImportError:
|
||||
asyncio = None
|
||||
|
||||
skipIfNoTwisted = unittest.skipUnless(have_twisted,
|
||||
"twisted module not present")
|
||||
|
||||
|
||||
def save_signal_handlers():
|
||||
saved = {}
|
||||
for sig in [signal.SIGINT, signal.SIGTERM, signal.SIGCHLD]:
|
||||
saved[sig] = signal.getsignal(sig)
|
||||
if "twisted" in repr(saved):
|
||||
if not issubclass(IOLoop.configured_class(), TwistedIOLoop):
|
||||
# when the global ioloop is twisted, we expect the signal
|
||||
# handlers to be installed. Otherwise, it means we're not
|
||||
# cleaning up after twisted properly.
|
||||
raise Exception("twisted signal handlers already installed")
|
||||
return saved
|
||||
|
||||
|
||||
def restore_signal_handlers(saved):
|
||||
for sig, handler in saved.items():
|
||||
signal.signal(sig, handler)
|
||||
|
||||
|
||||
class ReactorTestCase(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self._saved_signals = save_signal_handlers()
|
||||
IOLoop.clear_current()
|
||||
self._io_loop = IOLoop(make_current=True)
|
||||
self._reactor = TornadoReactor()
|
||||
IOLoop.clear_current()
|
||||
|
||||
def tearDown(self):
|
||||
self._io_loop.close(all_fds=True)
|
||||
restore_signal_handlers(self._saved_signals)
|
||||
|
||||
|
||||
@skipIfNoTwisted
|
||||
class ReactorWhenRunningTest(ReactorTestCase):
|
||||
def test_whenRunning(self):
|
||||
self._whenRunningCalled = False
|
||||
self._anotherWhenRunningCalled = False
|
||||
self._reactor.callWhenRunning(self.whenRunningCallback)
|
||||
self._reactor.run()
|
||||
self.assertTrue(self._whenRunningCalled)
|
||||
self.assertTrue(self._anotherWhenRunningCalled)
|
||||
|
||||
def whenRunningCallback(self):
|
||||
self._whenRunningCalled = True
|
||||
self._reactor.callWhenRunning(self.anotherWhenRunningCallback)
|
||||
self._reactor.stop()
|
||||
|
||||
def anotherWhenRunningCallback(self):
|
||||
self._anotherWhenRunningCalled = True
|
||||
|
||||
|
||||
@skipIfNoTwisted
|
||||
class ReactorCallLaterTest(ReactorTestCase):
|
||||
def test_callLater(self):
|
||||
self._laterCalled = False
|
||||
self._now = self._reactor.seconds()
|
||||
self._timeout = 0.001
|
||||
dc = self._reactor.callLater(self._timeout, self.callLaterCallback)
|
||||
self.assertEqual(self._reactor.getDelayedCalls(), [dc])
|
||||
self._reactor.run()
|
||||
self.assertTrue(self._laterCalled)
|
||||
self.assertTrue(self._called - self._now > self._timeout)
|
||||
self.assertEqual(self._reactor.getDelayedCalls(), [])
|
||||
|
||||
def callLaterCallback(self):
|
||||
self._laterCalled = True
|
||||
self._called = self._reactor.seconds()
|
||||
self._reactor.stop()
|
||||
|
||||
|
||||
@skipIfNoTwisted
|
||||
class ReactorTwoCallLaterTest(ReactorTestCase):
|
||||
def test_callLater(self):
|
||||
self._later1Called = False
|
||||
self._later2Called = False
|
||||
self._now = self._reactor.seconds()
|
||||
self._timeout1 = 0.0005
|
||||
dc1 = self._reactor.callLater(self._timeout1, self.callLaterCallback1)
|
||||
self._timeout2 = 0.001
|
||||
dc2 = self._reactor.callLater(self._timeout2, self.callLaterCallback2)
|
||||
self.assertTrue(self._reactor.getDelayedCalls() == [dc1, dc2] or
|
||||
self._reactor.getDelayedCalls() == [dc2, dc1])
|
||||
self._reactor.run()
|
||||
self.assertTrue(self._later1Called)
|
||||
self.assertTrue(self._later2Called)
|
||||
self.assertTrue(self._called1 - self._now > self._timeout1)
|
||||
self.assertTrue(self._called2 - self._now > self._timeout2)
|
||||
self.assertEqual(self._reactor.getDelayedCalls(), [])
|
||||
|
||||
def callLaterCallback1(self):
|
||||
self._later1Called = True
|
||||
self._called1 = self._reactor.seconds()
|
||||
|
||||
def callLaterCallback2(self):
|
||||
self._later2Called = True
|
||||
self._called2 = self._reactor.seconds()
|
||||
self._reactor.stop()
|
||||
|
||||
|
||||
@skipIfNoTwisted
|
||||
class ReactorCallFromThreadTest(ReactorTestCase):
|
||||
def setUp(self):
|
||||
super(ReactorCallFromThreadTest, self).setUp()
|
||||
self._mainThread = thread.get_ident()
|
||||
|
||||
def tearDown(self):
|
||||
self._thread.join()
|
||||
super(ReactorCallFromThreadTest, self).tearDown()
|
||||
|
||||
def _newThreadRun(self):
|
||||
self.assertNotEqual(self._mainThread, thread.get_ident())
|
||||
if hasattr(self._thread, 'ident'): # new in python 2.6
|
||||
self.assertEqual(self._thread.ident, thread.get_ident())
|
||||
self._reactor.callFromThread(self._fnCalledFromThread)
|
||||
|
||||
def _fnCalledFromThread(self):
|
||||
self.assertEqual(self._mainThread, thread.get_ident())
|
||||
self._reactor.stop()
|
||||
|
||||
def _whenRunningCallback(self):
|
||||
self._thread = threading.Thread(target=self._newThreadRun)
|
||||
self._thread.start()
|
||||
|
||||
def testCallFromThread(self):
|
||||
self._reactor.callWhenRunning(self._whenRunningCallback)
|
||||
self._reactor.run()
|
||||
|
||||
|
||||
@skipIfNoTwisted
|
||||
class ReactorCallInThread(ReactorTestCase):
|
||||
def setUp(self):
|
||||
super(ReactorCallInThread, self).setUp()
|
||||
self._mainThread = thread.get_ident()
|
||||
|
||||
def _fnCalledInThread(self, *args, **kwargs):
|
||||
self.assertNotEqual(thread.get_ident(), self._mainThread)
|
||||
self._reactor.callFromThread(lambda: self._reactor.stop())
|
||||
|
||||
def _whenRunningCallback(self):
|
||||
self._reactor.callInThread(self._fnCalledInThread)
|
||||
|
||||
def testCallInThread(self):
|
||||
self._reactor.callWhenRunning(self._whenRunningCallback)
|
||||
self._reactor.run()
|
||||
|
||||
|
||||
if have_twisted:
|
||||
@implementer(IReadDescriptor)
|
||||
class Reader(object):
|
||||
def __init__(self, fd, callback):
|
||||
self._fd = fd
|
||||
self._callback = callback
|
||||
|
||||
def logPrefix(self):
|
||||
return "Reader"
|
||||
|
||||
def close(self):
|
||||
self._fd.close()
|
||||
|
||||
def fileno(self):
|
||||
return self._fd.fileno()
|
||||
|
||||
def readConnectionLost(self, reason):
|
||||
self.close()
|
||||
|
||||
def connectionLost(self, reason):
|
||||
self.close()
|
||||
|
||||
def doRead(self):
|
||||
self._callback(self._fd)
|
||||
|
||||
@implementer(IWriteDescriptor)
|
||||
class Writer(object):
|
||||
def __init__(self, fd, callback):
|
||||
self._fd = fd
|
||||
self._callback = callback
|
||||
|
||||
def logPrefix(self):
|
||||
return "Writer"
|
||||
|
||||
def close(self):
|
||||
self._fd.close()
|
||||
|
||||
def fileno(self):
|
||||
return self._fd.fileno()
|
||||
|
||||
def connectionLost(self, reason):
|
||||
self.close()
|
||||
|
||||
def doWrite(self):
|
||||
self._callback(self._fd)
|
||||
|
||||
|
||||
@skipIfNoTwisted
|
||||
class ReactorReaderWriterTest(ReactorTestCase):
|
||||
def _set_nonblocking(self, fd):
|
||||
flags = fcntl.fcntl(fd, fcntl.F_GETFL)
|
||||
fcntl.fcntl(fd, fcntl.F_SETFL, flags | os.O_NONBLOCK)
|
||||
|
||||
def setUp(self):
|
||||
super(ReactorReaderWriterTest, self).setUp()
|
||||
r, w = os.pipe()
|
||||
self._set_nonblocking(r)
|
||||
self._set_nonblocking(w)
|
||||
set_close_exec(r)
|
||||
set_close_exec(w)
|
||||
self._p1 = os.fdopen(r, "rb", 0)
|
||||
self._p2 = os.fdopen(w, "wb", 0)
|
||||
|
||||
def tearDown(self):
|
||||
super(ReactorReaderWriterTest, self).tearDown()
|
||||
self._p1.close()
|
||||
self._p2.close()
|
||||
|
||||
def _testReadWrite(self):
|
||||
"""
|
||||
In this test the writer writes an 'x' to its fd. The reader
|
||||
reads it, check the value and ends the test.
|
||||
"""
|
||||
self.shouldWrite = True
|
||||
|
||||
def checkReadInput(fd):
|
||||
self.assertEquals(fd.read(1), b'x')
|
||||
self._reactor.stop()
|
||||
|
||||
def writeOnce(fd):
|
||||
if self.shouldWrite:
|
||||
self.shouldWrite = False
|
||||
fd.write(b'x')
|
||||
self._reader = Reader(self._p1, checkReadInput)
|
||||
self._writer = Writer(self._p2, writeOnce)
|
||||
|
||||
self._reactor.addWriter(self._writer)
|
||||
|
||||
# Test that adding the reader twice adds it only once to
|
||||
# IOLoop.
|
||||
self._reactor.addReader(self._reader)
|
||||
self._reactor.addReader(self._reader)
|
||||
|
||||
def testReadWrite(self):
|
||||
self._reactor.callWhenRunning(self._testReadWrite)
|
||||
self._reactor.run()
|
||||
|
||||
def _testNoWriter(self):
|
||||
"""
|
||||
In this test we have no writer. Make sure the reader doesn't
|
||||
read anything.
|
||||
"""
|
||||
def checkReadInput(fd):
|
||||
self.fail("Must not be called.")
|
||||
|
||||
def stopTest():
|
||||
# Close the writer here since the IOLoop doesn't know
|
||||
# about it.
|
||||
self._writer.close()
|
||||
self._reactor.stop()
|
||||
self._reader = Reader(self._p1, checkReadInput)
|
||||
|
||||
# We create a writer, but it should never be invoked.
|
||||
self._writer = Writer(self._p2, lambda fd: fd.write('x'))
|
||||
|
||||
# Test that adding and removing the writer leaves us with no writer.
|
||||
self._reactor.addWriter(self._writer)
|
||||
self._reactor.removeWriter(self._writer)
|
||||
|
||||
# Test that adding and removing the reader doesn't cause
|
||||
# unintended effects.
|
||||
self._reactor.addReader(self._reader)
|
||||
|
||||
# Wake up after a moment and stop the test
|
||||
self._reactor.callLater(0.001, stopTest)
|
||||
|
||||
def testNoWriter(self):
|
||||
self._reactor.callWhenRunning(self._testNoWriter)
|
||||
self._reactor.run()
|
||||
|
||||
# Test various combinations of twisted and tornado http servers,
|
||||
# http clients, and event loop interfaces.
|
||||
|
||||
|
||||
@skipIfNoTwisted
|
||||
@unittest.skipIf(not have_twisted_web, 'twisted web not present')
|
||||
class CompatibilityTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.saved_signals = save_signal_handlers()
|
||||
self.io_loop = IOLoop()
|
||||
self.io_loop.make_current()
|
||||
self.reactor = TornadoReactor()
|
||||
|
||||
def tearDown(self):
|
||||
self.reactor.disconnectAll()
|
||||
self.io_loop.clear_current()
|
||||
self.io_loop.close(all_fds=True)
|
||||
restore_signal_handlers(self.saved_signals)
|
||||
|
||||
def start_twisted_server(self):
|
||||
class HelloResource(Resource):
|
||||
isLeaf = True
|
||||
|
||||
def render_GET(self, request):
|
||||
return "Hello from twisted!"
|
||||
site = Site(HelloResource())
|
||||
port = self.reactor.listenTCP(0, site, interface='127.0.0.1')
|
||||
self.twisted_port = port.getHost().port
|
||||
|
||||
def start_tornado_server(self):
|
||||
class HelloHandler(RequestHandler):
|
||||
def get(self):
|
||||
self.write("Hello from tornado!")
|
||||
app = Application([('/', HelloHandler)],
|
||||
log_function=lambda x: None)
|
||||
server = HTTPServer(app)
|
||||
sock, self.tornado_port = bind_unused_port()
|
||||
server.add_sockets([sock])
|
||||
|
||||
def run_ioloop(self):
|
||||
self.stop_loop = self.io_loop.stop
|
||||
self.io_loop.start()
|
||||
self.reactor.fireSystemEvent('shutdown')
|
||||
|
||||
def run_reactor(self):
|
||||
self.stop_loop = self.reactor.stop
|
||||
self.stop = self.reactor.stop
|
||||
self.reactor.run()
|
||||
|
||||
def tornado_fetch(self, url, runner):
|
||||
responses = []
|
||||
client = AsyncHTTPClient()
|
||||
|
||||
def callback(response):
|
||||
responses.append(response)
|
||||
self.stop_loop()
|
||||
client.fetch(url, callback=callback)
|
||||
runner()
|
||||
self.assertEqual(len(responses), 1)
|
||||
responses[0].rethrow()
|
||||
return responses[0]
|
||||
|
||||
def twisted_fetch(self, url, runner):
|
||||
# http://twistedmatrix.com/documents/current/web/howto/client.html
|
||||
chunks = []
|
||||
client = Agent(self.reactor)
|
||||
d = client.request(b'GET', utf8(url))
|
||||
|
||||
class Accumulator(Protocol):
|
||||
def __init__(self, finished):
|
||||
self.finished = finished
|
||||
|
||||
def dataReceived(self, data):
|
||||
chunks.append(data)
|
||||
|
||||
def connectionLost(self, reason):
|
||||
self.finished.callback(None)
|
||||
|
||||
def callback(response):
|
||||
finished = Deferred()
|
||||
response.deliverBody(Accumulator(finished))
|
||||
return finished
|
||||
d.addCallback(callback)
|
||||
|
||||
def shutdown(failure):
|
||||
if hasattr(self, 'stop_loop'):
|
||||
self.stop_loop()
|
||||
elif failure is not None:
|
||||
# loop hasn't been initialized yet; try our best to
|
||||
# get an error message out. (the runner() interaction
|
||||
# should probably be refactored).
|
||||
try:
|
||||
failure.raiseException()
|
||||
except:
|
||||
logging.error('exception before starting loop', exc_info=True)
|
||||
d.addBoth(shutdown)
|
||||
runner()
|
||||
self.assertTrue(chunks)
|
||||
return ''.join(chunks)
|
||||
|
||||
def twisted_coroutine_fetch(self, url, runner):
|
||||
body = [None]
|
||||
|
||||
@gen.coroutine
|
||||
def f():
|
||||
# This is simpler than the non-coroutine version, but it cheats
|
||||
# by reading the body in one blob instead of streaming it with
|
||||
# a Protocol.
|
||||
client = Agent(self.reactor)
|
||||
response = yield client.request(b'GET', utf8(url))
|
||||
with warnings.catch_warnings():
|
||||
# readBody has a buggy DeprecationWarning in Twisted 15.0:
|
||||
# https://twistedmatrix.com/trac/changeset/43379
|
||||
warnings.simplefilter('ignore', category=DeprecationWarning)
|
||||
body[0] = yield readBody(response)
|
||||
self.stop_loop()
|
||||
self.io_loop.add_callback(f)
|
||||
runner()
|
||||
return body[0]
|
||||
|
||||
def testTwistedServerTornadoClientIOLoop(self):
|
||||
self.start_twisted_server()
|
||||
response = self.tornado_fetch(
|
||||
'http://127.0.0.1:%d' % self.twisted_port, self.run_ioloop)
|
||||
self.assertEqual(response.body, 'Hello from twisted!')
|
||||
|
||||
def testTwistedServerTornadoClientReactor(self):
|
||||
self.start_twisted_server()
|
||||
response = self.tornado_fetch(
|
||||
'http://127.0.0.1:%d' % self.twisted_port, self.run_reactor)
|
||||
self.assertEqual(response.body, 'Hello from twisted!')
|
||||
|
||||
def testTornadoServerTwistedClientIOLoop(self):
|
||||
self.start_tornado_server()
|
||||
response = self.twisted_fetch(
|
||||
'http://127.0.0.1:%d' % self.tornado_port, self.run_ioloop)
|
||||
self.assertEqual(response, 'Hello from tornado!')
|
||||
|
||||
def testTornadoServerTwistedClientReactor(self):
|
||||
self.start_tornado_server()
|
||||
response = self.twisted_fetch(
|
||||
'http://127.0.0.1:%d' % self.tornado_port, self.run_reactor)
|
||||
self.assertEqual(response, 'Hello from tornado!')
|
||||
|
||||
def testTornadoServerTwistedCoroutineClientIOLoop(self):
|
||||
self.start_tornado_server()
|
||||
response = self.twisted_coroutine_fetch(
|
||||
'http://127.0.0.1:%d' % self.tornado_port, self.run_ioloop)
|
||||
self.assertEqual(response, 'Hello from tornado!')
|
||||
|
||||
|
||||
@skipIfNoTwisted
|
||||
class ConvertDeferredTest(unittest.TestCase):
|
||||
def test_success(self):
|
||||
@inlineCallbacks
|
||||
def fn():
|
||||
if False:
|
||||
# inlineCallbacks doesn't work with regular functions;
|
||||
# must have a yield even if it's unreachable.
|
||||
yield
|
||||
returnValue(42)
|
||||
f = gen.convert_yielded(fn())
|
||||
self.assertEqual(f.result(), 42)
|
||||
|
||||
def test_failure(self):
|
||||
@inlineCallbacks
|
||||
def fn():
|
||||
if False:
|
||||
yield
|
||||
1 / 0
|
||||
f = gen.convert_yielded(fn())
|
||||
with self.assertRaises(ZeroDivisionError):
|
||||
f.result()
|
||||
|
||||
|
||||
if have_twisted:
|
||||
# Import and run as much of twisted's test suite as possible.
|
||||
# This is unfortunately rather dependent on implementation details,
|
||||
# but there doesn't appear to be a clean all-in-one conformance test
|
||||
# suite for reactors.
|
||||
#
|
||||
# This is a list of all test suites using the ReactorBuilder
|
||||
# available in Twisted 11.0.0 and 11.1.0 (and a blacklist of
|
||||
# specific test methods to be disabled).
|
||||
twisted_tests = {
|
||||
'twisted.internet.test.test_core.ObjectModelIntegrationTest': [],
|
||||
'twisted.internet.test.test_core.SystemEventTestsBuilder': [
|
||||
'test_iterate', # deliberately not supported
|
||||
# Fails on TwistedIOLoop and AsyncIOLoop.
|
||||
'test_runAfterCrash',
|
||||
],
|
||||
'twisted.internet.test.test_fdset.ReactorFDSetTestsBuilder': [
|
||||
"test_lostFileDescriptor", # incompatible with epoll and kqueue
|
||||
],
|
||||
'twisted.internet.test.test_process.ProcessTestsBuilder': [
|
||||
# Only work as root. Twisted's "skip" functionality works
|
||||
# with py27+, but not unittest2 on py26.
|
||||
'test_changeGID',
|
||||
'test_changeUID',
|
||||
# This test sometimes fails with EPIPE on a call to
|
||||
# kqueue.control. Happens consistently for me with
|
||||
# trollius but not asyncio or other IOLoops.
|
||||
'test_childConnectionLost',
|
||||
],
|
||||
# Process tests appear to work on OSX 10.7, but not 10.6
|
||||
# 'twisted.internet.test.test_process.PTYProcessTestsBuilder': [
|
||||
# 'test_systemCallUninterruptedByChildExit',
|
||||
# ],
|
||||
'twisted.internet.test.test_tcp.TCPClientTestsBuilder': [
|
||||
'test_badContext', # ssl-related; see also SSLClientTestsMixin
|
||||
],
|
||||
'twisted.internet.test.test_tcp.TCPPortTestsBuilder': [
|
||||
# These use link-local addresses and cause firewall prompts on mac
|
||||
'test_buildProtocolIPv6AddressScopeID',
|
||||
'test_portGetHostOnIPv6ScopeID',
|
||||
'test_serverGetHostOnIPv6ScopeID',
|
||||
'test_serverGetPeerOnIPv6ScopeID',
|
||||
],
|
||||
'twisted.internet.test.test_tcp.TCPConnectionTestsBuilder': [],
|
||||
'twisted.internet.test.test_tcp.WriteSequenceTests': [],
|
||||
'twisted.internet.test.test_tcp.AbortConnectionTestCase': [],
|
||||
'twisted.internet.test.test_threads.ThreadTestsBuilder': [],
|
||||
'twisted.internet.test.test_time.TimeTestsBuilder': [],
|
||||
# Extra third-party dependencies (pyOpenSSL)
|
||||
# 'twisted.internet.test.test_tls.SSLClientTestsMixin': [],
|
||||
'twisted.internet.test.test_udp.UDPServerTestsBuilder': [],
|
||||
'twisted.internet.test.test_unix.UNIXTestsBuilder': [
|
||||
# Platform-specific. These tests would be skipped automatically
|
||||
# if we were running twisted's own test runner.
|
||||
'test_connectToLinuxAbstractNamespace',
|
||||
'test_listenOnLinuxAbstractNamespace',
|
||||
# These tests use twisted's sendmsg.c extension and sometimes
|
||||
# fail with what looks like uninitialized memory errors
|
||||
# (more common on pypy than cpython, but I've seen it on both)
|
||||
'test_sendFileDescriptor',
|
||||
'test_sendFileDescriptorTriggersPauseProducing',
|
||||
'test_descriptorDeliveredBeforeBytes',
|
||||
'test_avoidLeakingFileDescriptors',
|
||||
],
|
||||
'twisted.internet.test.test_unix.UNIXDatagramTestsBuilder': [
|
||||
'test_listenOnLinuxAbstractNamespace',
|
||||
],
|
||||
'twisted.internet.test.test_unix.UNIXPortTestsBuilder': [],
|
||||
}
|
||||
if sys.version_info >= (3,):
|
||||
# In Twisted 15.2.0 on Python 3.4, the process tests will try to run
|
||||
# but fail, due in part to interactions between Tornado's strict
|
||||
# warnings-as-errors policy and Twisted's own warning handling
|
||||
# (it was not obvious how to configure the warnings module to
|
||||
# reconcile the two), and partly due to what looks like a packaging
|
||||
# error (process_cli.py missing). For now, just skip it.
|
||||
del twisted_tests['twisted.internet.test.test_process.ProcessTestsBuilder']
|
||||
for test_name, blacklist in twisted_tests.items():
|
||||
try:
|
||||
test_class = import_object(test_name)
|
||||
except (ImportError, AttributeError):
|
||||
continue
|
||||
for test_func in blacklist: # type: ignore
|
||||
if hasattr(test_class, test_func):
|
||||
# The test_func may be defined in a mixin, so clobber
|
||||
# it instead of delattr()
|
||||
setattr(test_class, test_func, lambda self: None)
|
||||
|
||||
def make_test_subclass(test_class):
|
||||
class TornadoTest(test_class): # type: ignore
|
||||
_reactors = ["tornado.platform.twisted._TestReactor"]
|
||||
|
||||
def setUp(self):
|
||||
# Twisted's tests expect to be run from a temporary
|
||||
# directory; they create files in their working directory
|
||||
# and don't always clean up after themselves.
|
||||
self.__curdir = os.getcwd()
|
||||
self.__tempdir = tempfile.mkdtemp()
|
||||
os.chdir(self.__tempdir)
|
||||
super(TornadoTest, self).setUp() # type: ignore
|
||||
|
||||
def tearDown(self):
|
||||
super(TornadoTest, self).tearDown() # type: ignore
|
||||
os.chdir(self.__curdir)
|
||||
shutil.rmtree(self.__tempdir)
|
||||
|
||||
def flushWarnings(self, *args, **kwargs):
|
||||
# This is a hack because Twisted and Tornado have
|
||||
# differing approaches to warnings in tests.
|
||||
# Tornado sets up a global set of warnings filters
|
||||
# in runtests.py, while Twisted patches the filter
|
||||
# list in each test. The net effect is that
|
||||
# Twisted's tests run with Tornado's increased
|
||||
# strictness (BytesWarning and ResourceWarning are
|
||||
# enabled) but without our filter rules to ignore those
|
||||
# warnings from Twisted code.
|
||||
filtered = []
|
||||
for w in super(TornadoTest, self).flushWarnings( # type: ignore
|
||||
*args, **kwargs):
|
||||
if w['category'] in (BytesWarning, ResourceWarning):
|
||||
continue
|
||||
filtered.append(w)
|
||||
return filtered
|
||||
|
||||
def buildReactor(self):
|
||||
self.__saved_signals = save_signal_handlers()
|
||||
return test_class.buildReactor(self)
|
||||
|
||||
def unbuildReactor(self, reactor):
|
||||
test_class.unbuildReactor(self, reactor)
|
||||
# Clean up file descriptors (especially epoll/kqueue
|
||||
# objects) eagerly instead of leaving them for the
|
||||
# GC. Unfortunately we can't do this in reactor.stop
|
||||
# since twisted expects to be able to unregister
|
||||
# connections in a post-shutdown hook.
|
||||
reactor._io_loop.close(all_fds=True)
|
||||
restore_signal_handlers(self.__saved_signals)
|
||||
|
||||
TornadoTest.__name__ = test_class.__name__
|
||||
return TornadoTest
|
||||
test_subclass = make_test_subclass(test_class)
|
||||
globals().update(test_subclass.makeTestCaseClasses())
|
||||
|
||||
# Since we're not using twisted's test runner, it's tricky to get
|
||||
# logging set up well. Most of the time it's easiest to just
|
||||
# leave it turned off, but while working on these tests you may want
|
||||
# to uncomment one of the other lines instead.
|
||||
log.defaultObserver.stop()
|
||||
# import sys; log.startLogging(sys.stderr, setStdout=0)
|
||||
# log.startLoggingWithObserver(log.PythonLoggingObserver().emit, setStdout=0)
|
||||
# import logging; logging.getLogger('twisted').setLevel(logging.WARNING)
|
||||
|
||||
# Twisted recently introduced a new logger; disable that one too.
|
||||
try:
|
||||
from twisted.logger import globalLogBeginner # type: ignore
|
||||
except ImportError:
|
||||
pass
|
||||
else:
|
||||
globalLogBeginner.beginLoggingTo([], redirectStandardIO=False)
|
||||
|
||||
if have_twisted:
|
||||
class LayeredTwistedIOLoop(TwistedIOLoop):
|
||||
"""Layers a TwistedIOLoop on top of a TornadoReactor on a PollIOLoop.
|
||||
|
||||
This is of course silly, but is useful for testing purposes to make
|
||||
sure we're implementing both sides of the various interfaces
|
||||
correctly. In some tests another TornadoReactor is layered on top
|
||||
of the whole stack.
|
||||
"""
|
||||
def initialize(self, **kwargs):
|
||||
self.real_io_loop = PollIOLoop(make_current=False) # type: ignore
|
||||
reactor = self.real_io_loop.run_sync(gen.coroutine(TornadoReactor))
|
||||
super(LayeredTwistedIOLoop, self).initialize(reactor=reactor, **kwargs)
|
||||
self.add_callback(self.make_current)
|
||||
|
||||
def close(self, all_fds=False):
|
||||
super(LayeredTwistedIOLoop, self).close(all_fds=all_fds)
|
||||
# HACK: This is the same thing that test_class.unbuildReactor does.
|
||||
for reader in self.reactor._internalReaders:
|
||||
self.reactor.removeReader(reader)
|
||||
reader.connectionLost(None)
|
||||
self.real_io_loop.close(all_fds=all_fds)
|
||||
|
||||
def stop(self):
|
||||
# One of twisted's tests fails if I don't delay crash()
|
||||
# until the reactor has started, but if I move this to
|
||||
# TwistedIOLoop then the tests fail when I'm *not* running
|
||||
# tornado-on-twisted-on-tornado. I'm clearly missing something
|
||||
# about the startup/crash semantics, but since stop and crash
|
||||
# are really only used in tests it doesn't really matter.
|
||||
def f():
|
||||
self.reactor.crash()
|
||||
# Become current again on restart. This is needed to
|
||||
# override real_io_loop's claim to being the current loop.
|
||||
self.add_callback(self.make_current)
|
||||
self.reactor.callWhenRunning(f)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Executable
+118
@@ -0,0 +1,118 @@
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import contextlib
|
||||
import os
|
||||
import platform
|
||||
import socket
|
||||
import sys
|
||||
import textwrap
|
||||
import warnings
|
||||
|
||||
from tornado.testing import bind_unused_port
|
||||
|
||||
# Delegate the choice of unittest or unittest2 to tornado.testing.
|
||||
from tornado.testing import unittest
|
||||
|
||||
skipIfNonUnix = unittest.skipIf(os.name != 'posix' or sys.platform == 'cygwin',
|
||||
"non-unix platform")
|
||||
|
||||
# travis-ci.org runs our tests in an overworked virtual machine, which makes
|
||||
# timing-related tests unreliable.
|
||||
skipOnTravis = unittest.skipIf('TRAVIS' in os.environ,
|
||||
'timing tests unreliable on travis')
|
||||
|
||||
skipOnAppEngine = unittest.skipIf('APPENGINE_RUNTIME' in os.environ,
|
||||
'not available on Google App Engine')
|
||||
|
||||
# Set the environment variable NO_NETWORK=1 to disable any tests that
|
||||
# depend on an external network.
|
||||
skipIfNoNetwork = unittest.skipIf('NO_NETWORK' in os.environ,
|
||||
'network access disabled')
|
||||
|
||||
skipBefore33 = unittest.skipIf(sys.version_info < (3, 3), 'PEP 380 (yield from) not available')
|
||||
skipBefore35 = unittest.skipIf(sys.version_info < (3, 5), 'PEP 492 (async/await) not available')
|
||||
skipNotCPython = unittest.skipIf(platform.python_implementation() != 'CPython',
|
||||
'Not CPython implementation')
|
||||
|
||||
# Used for tests affected by
|
||||
# https://bitbucket.org/pypy/pypy/issues/2616/incomplete-error-handling-in
|
||||
# TODO: remove this after pypy3 5.8 is obsolete.
|
||||
skipPypy3V58 = unittest.skipIf(platform.python_implementation() == 'PyPy' and
|
||||
sys.version_info > (3,) and
|
||||
sys.pypy_version_info < (5, 9),
|
||||
'pypy3 5.8 has buggy ssl module')
|
||||
|
||||
|
||||
def _detect_ipv6():
|
||||
if not socket.has_ipv6:
|
||||
# socket.has_ipv6 check reports whether ipv6 was present at compile
|
||||
# time. It's usually true even when ipv6 doesn't work for other reasons.
|
||||
return False
|
||||
sock = None
|
||||
try:
|
||||
sock = socket.socket(socket.AF_INET6)
|
||||
sock.bind(('::1', 0))
|
||||
except socket.error:
|
||||
return False
|
||||
finally:
|
||||
if sock is not None:
|
||||
sock.close()
|
||||
return True
|
||||
|
||||
|
||||
skipIfNoIPv6 = unittest.skipIf(not _detect_ipv6(), 'ipv6 support not present')
|
||||
|
||||
|
||||
def refusing_port():
|
||||
"""Returns a local port number that will refuse all connections.
|
||||
|
||||
Return value is (cleanup_func, port); the cleanup function
|
||||
must be called to free the port to be reused.
|
||||
"""
|
||||
# On travis-ci, port numbers are reassigned frequently. To avoid
|
||||
# collisions with other tests, we use an open client-side socket's
|
||||
# ephemeral port number to ensure that nothing can listen on that
|
||||
# port.
|
||||
server_socket, port = bind_unused_port()
|
||||
server_socket.setblocking(1)
|
||||
client_socket = socket.socket()
|
||||
client_socket.connect(("127.0.0.1", port))
|
||||
conn, client_addr = server_socket.accept()
|
||||
conn.close()
|
||||
server_socket.close()
|
||||
return (client_socket.close, client_addr[1])
|
||||
|
||||
|
||||
def exec_test(caller_globals, caller_locals, s):
|
||||
"""Execute ``s`` in a given context and return the result namespace.
|
||||
|
||||
Used to define functions for tests in particular python
|
||||
versions that would be syntax errors in older versions.
|
||||
"""
|
||||
# Flatten the real global and local namespace into our fake
|
||||
# globals: it's all global from the perspective of code defined
|
||||
# in s.
|
||||
global_namespace = dict(caller_globals, **caller_locals) # type: ignore
|
||||
local_namespace = {}
|
||||
exec(textwrap.dedent(s), global_namespace, local_namespace)
|
||||
return local_namespace
|
||||
|
||||
|
||||
def subTest(test, *args, **kwargs):
|
||||
"""Compatibility shim for unittest.TestCase.subTest.
|
||||
|
||||
Usage: ``with tornado.test.util.subTest(self, x=x):``
|
||||
"""
|
||||
try:
|
||||
subTest = test.subTest # py34+
|
||||
except AttributeError:
|
||||
subTest = contextlib.contextmanager(lambda *a, **kw: (yield))
|
||||
return subTest(*args, **kwargs)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def ignore_deprecation():
|
||||
"""Context manager to ignore deprecation warnings."""
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter('ignore', DeprecationWarning)
|
||||
yield
|
||||
Executable
+286
@@ -0,0 +1,286 @@
|
||||
# coding: utf-8
|
||||
from __future__ import absolute_import, division, print_function
|
||||
import re
|
||||
import sys
|
||||
import datetime
|
||||
|
||||
import tornado.escape
|
||||
from tornado.escape import utf8
|
||||
from tornado.test.util import unittest
|
||||
from tornado.util import (
|
||||
raise_exc_info, Configurable, exec_in, ArgReplacer,
|
||||
timedelta_to_seconds, import_object, re_unescape, is_finalizing, PY3,
|
||||
)
|
||||
|
||||
if PY3:
|
||||
from io import StringIO
|
||||
else:
|
||||
from cStringIO import StringIO
|
||||
|
||||
|
||||
class RaiseExcInfoTest(unittest.TestCase):
|
||||
def test_two_arg_exception(self):
|
||||
# This test would fail on python 3 if raise_exc_info were simply
|
||||
# a three-argument raise statement, because TwoArgException
|
||||
# doesn't have a "copy constructor"
|
||||
class TwoArgException(Exception):
|
||||
def __init__(self, a, b):
|
||||
super(TwoArgException, self).__init__()
|
||||
self.a, self.b = a, b
|
||||
|
||||
try:
|
||||
raise TwoArgException(1, 2)
|
||||
except TwoArgException:
|
||||
exc_info = sys.exc_info()
|
||||
try:
|
||||
raise_exc_info(exc_info)
|
||||
self.fail("didn't get expected exception")
|
||||
except TwoArgException as e:
|
||||
self.assertIs(e, exc_info[1])
|
||||
|
||||
|
||||
class TestConfigurable(Configurable):
|
||||
@classmethod
|
||||
def configurable_base(cls):
|
||||
return TestConfigurable
|
||||
|
||||
@classmethod
|
||||
def configurable_default(cls):
|
||||
return TestConfig1
|
||||
|
||||
|
||||
class TestConfig1(TestConfigurable):
|
||||
def initialize(self, pos_arg=None, a=None):
|
||||
self.a = a
|
||||
self.pos_arg = pos_arg
|
||||
|
||||
|
||||
class TestConfig2(TestConfigurable):
|
||||
def initialize(self, pos_arg=None, b=None):
|
||||
self.b = b
|
||||
self.pos_arg = pos_arg
|
||||
|
||||
|
||||
class TestConfig3(TestConfigurable):
|
||||
# TestConfig3 is a configuration option that is itself configurable.
|
||||
@classmethod
|
||||
def configurable_base(cls):
|
||||
return TestConfig3
|
||||
|
||||
@classmethod
|
||||
def configurable_default(cls):
|
||||
return TestConfig3A
|
||||
|
||||
|
||||
class TestConfig3A(TestConfig3):
|
||||
def initialize(self, a=None):
|
||||
self.a = a
|
||||
|
||||
|
||||
class TestConfig3B(TestConfig3):
|
||||
def initialize(self, b=None):
|
||||
self.b = b
|
||||
|
||||
|
||||
class ConfigurableTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.saved = TestConfigurable._save_configuration()
|
||||
self.saved3 = TestConfig3._save_configuration()
|
||||
|
||||
def tearDown(self):
|
||||
TestConfigurable._restore_configuration(self.saved)
|
||||
TestConfig3._restore_configuration(self.saved3)
|
||||
|
||||
def checkSubclasses(self):
|
||||
# no matter how the class is configured, it should always be
|
||||
# possible to instantiate the subclasses directly
|
||||
self.assertIsInstance(TestConfig1(), TestConfig1)
|
||||
self.assertIsInstance(TestConfig2(), TestConfig2)
|
||||
|
||||
obj = TestConfig1(a=1)
|
||||
self.assertEqual(obj.a, 1)
|
||||
obj = TestConfig2(b=2)
|
||||
self.assertEqual(obj.b, 2)
|
||||
|
||||
def test_default(self):
|
||||
obj = TestConfigurable()
|
||||
self.assertIsInstance(obj, TestConfig1)
|
||||
self.assertIs(obj.a, None)
|
||||
|
||||
obj = TestConfigurable(a=1)
|
||||
self.assertIsInstance(obj, TestConfig1)
|
||||
self.assertEqual(obj.a, 1)
|
||||
|
||||
self.checkSubclasses()
|
||||
|
||||
def test_config_class(self):
|
||||
TestConfigurable.configure(TestConfig2)
|
||||
obj = TestConfigurable()
|
||||
self.assertIsInstance(obj, TestConfig2)
|
||||
self.assertIs(obj.b, None)
|
||||
|
||||
obj = TestConfigurable(b=2)
|
||||
self.assertIsInstance(obj, TestConfig2)
|
||||
self.assertEqual(obj.b, 2)
|
||||
|
||||
self.checkSubclasses()
|
||||
|
||||
def test_config_args(self):
|
||||
TestConfigurable.configure(None, a=3)
|
||||
obj = TestConfigurable()
|
||||
self.assertIsInstance(obj, TestConfig1)
|
||||
self.assertEqual(obj.a, 3)
|
||||
|
||||
obj = TestConfigurable(42, a=4)
|
||||
self.assertIsInstance(obj, TestConfig1)
|
||||
self.assertEqual(obj.a, 4)
|
||||
self.assertEqual(obj.pos_arg, 42)
|
||||
|
||||
self.checkSubclasses()
|
||||
# args bound in configure don't apply when using the subclass directly
|
||||
obj = TestConfig1()
|
||||
self.assertIs(obj.a, None)
|
||||
|
||||
def test_config_class_args(self):
|
||||
TestConfigurable.configure(TestConfig2, b=5)
|
||||
obj = TestConfigurable()
|
||||
self.assertIsInstance(obj, TestConfig2)
|
||||
self.assertEqual(obj.b, 5)
|
||||
|
||||
obj = TestConfigurable(42, b=6)
|
||||
self.assertIsInstance(obj, TestConfig2)
|
||||
self.assertEqual(obj.b, 6)
|
||||
self.assertEqual(obj.pos_arg, 42)
|
||||
|
||||
self.checkSubclasses()
|
||||
# args bound in configure don't apply when using the subclass directly
|
||||
obj = TestConfig2()
|
||||
self.assertIs(obj.b, None)
|
||||
|
||||
def test_config_multi_level(self):
|
||||
TestConfigurable.configure(TestConfig3, a=1)
|
||||
obj = TestConfigurable()
|
||||
self.assertIsInstance(obj, TestConfig3A)
|
||||
self.assertEqual(obj.a, 1)
|
||||
|
||||
TestConfigurable.configure(TestConfig3)
|
||||
TestConfig3.configure(TestConfig3B, b=2)
|
||||
obj = TestConfigurable()
|
||||
self.assertIsInstance(obj, TestConfig3B)
|
||||
self.assertEqual(obj.b, 2)
|
||||
|
||||
def test_config_inner_level(self):
|
||||
# The inner level can be used even when the outer level
|
||||
# doesn't point to it.
|
||||
obj = TestConfig3()
|
||||
self.assertIsInstance(obj, TestConfig3A)
|
||||
|
||||
TestConfig3.configure(TestConfig3B)
|
||||
obj = TestConfig3()
|
||||
self.assertIsInstance(obj, TestConfig3B)
|
||||
|
||||
# Configuring the base doesn't configure the inner.
|
||||
obj = TestConfigurable()
|
||||
self.assertIsInstance(obj, TestConfig1)
|
||||
TestConfigurable.configure(TestConfig2)
|
||||
|
||||
obj = TestConfigurable()
|
||||
self.assertIsInstance(obj, TestConfig2)
|
||||
|
||||
obj = TestConfig3()
|
||||
self.assertIsInstance(obj, TestConfig3B)
|
||||
|
||||
|
||||
class UnicodeLiteralTest(unittest.TestCase):
|
||||
def test_unicode_escapes(self):
|
||||
self.assertEqual(utf8(u'\u00e9'), b'\xc3\xa9')
|
||||
|
||||
|
||||
class ExecInTest(unittest.TestCase):
|
||||
# This test is python 2 only because there are no new future imports
|
||||
# defined in python 3 yet.
|
||||
@unittest.skipIf(sys.version_info >= print_function.getMandatoryRelease(),
|
||||
'no testable future imports')
|
||||
def test_no_inherit_future(self):
|
||||
# This file has from __future__ import print_function...
|
||||
f = StringIO()
|
||||
print('hello', file=f)
|
||||
# ...but the template doesn't
|
||||
exec_in('print >> f, "world"', dict(f=f))
|
||||
self.assertEqual(f.getvalue(), 'hello\nworld\n')
|
||||
|
||||
|
||||
class ArgReplacerTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
def function(x, y, callback=None, z=None):
|
||||
pass
|
||||
self.replacer = ArgReplacer(function, 'callback')
|
||||
|
||||
def test_omitted(self):
|
||||
args = (1, 2)
|
||||
kwargs = dict()
|
||||
self.assertIs(self.replacer.get_old_value(args, kwargs), None)
|
||||
self.assertEqual(self.replacer.replace('new', args, kwargs),
|
||||
(None, (1, 2), dict(callback='new')))
|
||||
|
||||
def test_position(self):
|
||||
args = (1, 2, 'old', 3)
|
||||
kwargs = dict()
|
||||
self.assertEqual(self.replacer.get_old_value(args, kwargs), 'old')
|
||||
self.assertEqual(self.replacer.replace('new', args, kwargs),
|
||||
('old', [1, 2, 'new', 3], dict()))
|
||||
|
||||
def test_keyword(self):
|
||||
args = (1,)
|
||||
kwargs = dict(y=2, callback='old', z=3)
|
||||
self.assertEqual(self.replacer.get_old_value(args, kwargs), 'old')
|
||||
self.assertEqual(self.replacer.replace('new', args, kwargs),
|
||||
('old', (1,), dict(y=2, callback='new', z=3)))
|
||||
|
||||
|
||||
class TimedeltaToSecondsTest(unittest.TestCase):
|
||||
def test_timedelta_to_seconds(self):
|
||||
time_delta = datetime.timedelta(hours=1)
|
||||
self.assertEqual(timedelta_to_seconds(time_delta), 3600.0)
|
||||
|
||||
|
||||
class ImportObjectTest(unittest.TestCase):
|
||||
def test_import_member(self):
|
||||
self.assertIs(import_object('tornado.escape.utf8'), utf8)
|
||||
|
||||
def test_import_member_unicode(self):
|
||||
self.assertIs(import_object(u'tornado.escape.utf8'), utf8)
|
||||
|
||||
def test_import_module(self):
|
||||
self.assertIs(import_object('tornado.escape'), tornado.escape)
|
||||
|
||||
def test_import_module_unicode(self):
|
||||
# The internal implementation of __import__ differs depending on
|
||||
# whether the thing being imported is a module or not.
|
||||
# This variant requires a byte string in python 2.
|
||||
self.assertIs(import_object(u'tornado.escape'), tornado.escape)
|
||||
|
||||
|
||||
class ReUnescapeTest(unittest.TestCase):
|
||||
def test_re_unescape(self):
|
||||
test_strings = (
|
||||
'/favicon.ico',
|
||||
'index.html',
|
||||
'Hello, World!',
|
||||
'!$@#%;',
|
||||
)
|
||||
for string in test_strings:
|
||||
self.assertEqual(string, re_unescape(re.escape(string)))
|
||||
|
||||
def test_re_unescape_raises_error_on_invalid_input(self):
|
||||
with self.assertRaises(ValueError):
|
||||
re_unescape('\\d')
|
||||
with self.assertRaises(ValueError):
|
||||
re_unescape('\\b')
|
||||
with self.assertRaises(ValueError):
|
||||
re_unescape('\\Z')
|
||||
|
||||
|
||||
class IsFinalizingTest(unittest.TestCase):
|
||||
def test_basic(self):
|
||||
self.assertFalse(is_finalizing())
|
||||
Executable
+2967
File diff suppressed because it is too large
Load Diff
Executable
+805
@@ -0,0 +1,805 @@
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import functools
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
from tornado.concurrent import Future
|
||||
from tornado import gen
|
||||
from tornado.httpclient import HTTPError, HTTPRequest
|
||||
from tornado.locks import Event
|
||||
from tornado.log import gen_log, app_log
|
||||
from tornado.simple_httpclient import SimpleAsyncHTTPClient
|
||||
from tornado.template import DictLoader
|
||||
from tornado.testing import AsyncHTTPTestCase, gen_test, bind_unused_port, ExpectLog
|
||||
from tornado.test.util import unittest, skipBefore35, exec_test
|
||||
from tornado.web import Application, RequestHandler
|
||||
|
||||
try:
|
||||
import tornado.websocket # noqa
|
||||
from tornado.util import _websocket_mask_python
|
||||
except ImportError:
|
||||
# The unittest module presents misleading errors on ImportError
|
||||
# (it acts as if websocket_test could not be found, hiding the underlying
|
||||
# error). If we get an ImportError here (which could happen due to
|
||||
# TORNADO_EXTENSION=1), print some extra information before failing.
|
||||
traceback.print_exc()
|
||||
raise
|
||||
|
||||
from tornado.websocket import (
|
||||
WebSocketHandler, websocket_connect, WebSocketError, WebSocketClosedError,
|
||||
)
|
||||
|
||||
try:
|
||||
from tornado import speedups
|
||||
except ImportError:
|
||||
speedups = None
|
||||
|
||||
|
||||
class TestWebSocketHandler(WebSocketHandler):
|
||||
"""Base class for testing handlers that exposes the on_close event.
|
||||
|
||||
This allows for deterministic cleanup of the associated socket.
|
||||
"""
|
||||
def initialize(self, close_future, compression_options=None):
|
||||
self.close_future = close_future
|
||||
self.compression_options = compression_options
|
||||
|
||||
def get_compression_options(self):
|
||||
return self.compression_options
|
||||
|
||||
def on_close(self):
|
||||
self.close_future.set_result((self.close_code, self.close_reason))
|
||||
|
||||
|
||||
class EchoHandler(TestWebSocketHandler):
|
||||
@gen.coroutine
|
||||
def on_message(self, message):
|
||||
try:
|
||||
yield self.write_message(message, isinstance(message, bytes))
|
||||
except WebSocketClosedError:
|
||||
pass
|
||||
|
||||
|
||||
class ErrorInOnMessageHandler(TestWebSocketHandler):
|
||||
def on_message(self, message):
|
||||
1 / 0
|
||||
|
||||
|
||||
class HeaderHandler(TestWebSocketHandler):
|
||||
def open(self):
|
||||
methods_to_test = [
|
||||
functools.partial(self.write, 'This should not work'),
|
||||
functools.partial(self.redirect, 'http://localhost/elsewhere'),
|
||||
functools.partial(self.set_header, 'X-Test', ''),
|
||||
functools.partial(self.set_cookie, 'Chocolate', 'Chip'),
|
||||
functools.partial(self.set_status, 503),
|
||||
self.flush,
|
||||
self.finish,
|
||||
]
|
||||
for method in methods_to_test:
|
||||
try:
|
||||
# In a websocket context, many RequestHandler methods
|
||||
# raise RuntimeErrors.
|
||||
method()
|
||||
raise Exception("did not get expected exception")
|
||||
except RuntimeError:
|
||||
pass
|
||||
self.write_message(self.request.headers.get('X-Test', ''))
|
||||
|
||||
|
||||
class HeaderEchoHandler(TestWebSocketHandler):
|
||||
def set_default_headers(self):
|
||||
self.set_header("X-Extra-Response-Header", "Extra-Response-Value")
|
||||
|
||||
def prepare(self):
|
||||
for k, v in self.request.headers.get_all():
|
||||
if k.lower().startswith('x-test'):
|
||||
self.set_header(k, v)
|
||||
|
||||
|
||||
class NonWebSocketHandler(RequestHandler):
|
||||
def get(self):
|
||||
self.write('ok')
|
||||
|
||||
|
||||
class CloseReasonHandler(TestWebSocketHandler):
|
||||
def open(self):
|
||||
self.on_close_called = False
|
||||
self.close(1001, "goodbye")
|
||||
|
||||
|
||||
class AsyncPrepareHandler(TestWebSocketHandler):
|
||||
@gen.coroutine
|
||||
def prepare(self):
|
||||
yield gen.moment
|
||||
|
||||
def on_message(self, message):
|
||||
self.write_message(message)
|
||||
|
||||
|
||||
class PathArgsHandler(TestWebSocketHandler):
|
||||
def open(self, arg):
|
||||
self.write_message(arg)
|
||||
|
||||
|
||||
class CoroutineOnMessageHandler(TestWebSocketHandler):
|
||||
def initialize(self, close_future, compression_options=None):
|
||||
super(CoroutineOnMessageHandler, self).initialize(close_future,
|
||||
compression_options)
|
||||
self.sleeping = 0
|
||||
|
||||
@gen.coroutine
|
||||
def on_message(self, message):
|
||||
if self.sleeping > 0:
|
||||
self.write_message('another coroutine is already sleeping')
|
||||
self.sleeping += 1
|
||||
yield gen.sleep(0.01)
|
||||
self.sleeping -= 1
|
||||
self.write_message(message)
|
||||
|
||||
|
||||
class RenderMessageHandler(TestWebSocketHandler):
|
||||
def on_message(self, message):
|
||||
self.write_message(self.render_string('message.html', message=message))
|
||||
|
||||
|
||||
class SubprotocolHandler(TestWebSocketHandler):
|
||||
def initialize(self, **kwargs):
|
||||
super(SubprotocolHandler, self).initialize(**kwargs)
|
||||
self.select_subprotocol_called = False
|
||||
|
||||
def select_subprotocol(self, subprotocols):
|
||||
if self.select_subprotocol_called:
|
||||
raise Exception("select_subprotocol called twice")
|
||||
self.select_subprotocol_called = True
|
||||
if 'goodproto' in subprotocols:
|
||||
return 'goodproto'
|
||||
return None
|
||||
|
||||
def open(self):
|
||||
if not self.select_subprotocol_called:
|
||||
raise Exception("select_subprotocol not called")
|
||||
self.write_message("subprotocol=%s" % self.selected_subprotocol)
|
||||
|
||||
|
||||
class OpenCoroutineHandler(TestWebSocketHandler):
|
||||
def initialize(self, test, **kwargs):
|
||||
super(OpenCoroutineHandler, self).initialize(**kwargs)
|
||||
self.test = test
|
||||
self.open_finished = False
|
||||
|
||||
@gen.coroutine
|
||||
def open(self):
|
||||
yield self.test.message_sent.wait()
|
||||
yield gen.sleep(0.010)
|
||||
self.open_finished = True
|
||||
|
||||
def on_message(self, message):
|
||||
if not self.open_finished:
|
||||
raise Exception('on_message called before open finished')
|
||||
self.write_message('ok')
|
||||
|
||||
|
||||
class ErrorInOpenHandler(TestWebSocketHandler):
|
||||
def open(self):
|
||||
raise Exception("boom")
|
||||
|
||||
|
||||
class ErrorInAsyncOpenHandler(TestWebSocketHandler):
|
||||
@gen.coroutine
|
||||
def open(self):
|
||||
yield gen.sleep(0.01)
|
||||
raise Exception("boom")
|
||||
|
||||
|
||||
class WebSocketBaseTestCase(AsyncHTTPTestCase):
|
||||
@gen.coroutine
|
||||
def ws_connect(self, path, **kwargs):
|
||||
ws = yield websocket_connect(
|
||||
'ws://127.0.0.1:%d%s' % (self.get_http_port(), path),
|
||||
**kwargs)
|
||||
raise gen.Return(ws)
|
||||
|
||||
@gen.coroutine
|
||||
def close(self, ws):
|
||||
"""Close a websocket connection and wait for the server side.
|
||||
|
||||
If we don't wait here, there are sometimes leak warnings in the
|
||||
tests.
|
||||
"""
|
||||
ws.close()
|
||||
yield self.close_future
|
||||
|
||||
|
||||
class WebSocketTest(WebSocketBaseTestCase):
|
||||
def get_app(self):
|
||||
self.close_future = Future()
|
||||
return Application([
|
||||
('/echo', EchoHandler, dict(close_future=self.close_future)),
|
||||
('/non_ws', NonWebSocketHandler),
|
||||
('/header', HeaderHandler, dict(close_future=self.close_future)),
|
||||
('/header_echo', HeaderEchoHandler,
|
||||
dict(close_future=self.close_future)),
|
||||
('/close_reason', CloseReasonHandler,
|
||||
dict(close_future=self.close_future)),
|
||||
('/error_in_on_message', ErrorInOnMessageHandler,
|
||||
dict(close_future=self.close_future)),
|
||||
('/async_prepare', AsyncPrepareHandler,
|
||||
dict(close_future=self.close_future)),
|
||||
('/path_args/(.*)', PathArgsHandler,
|
||||
dict(close_future=self.close_future)),
|
||||
('/coroutine', CoroutineOnMessageHandler,
|
||||
dict(close_future=self.close_future)),
|
||||
('/render', RenderMessageHandler,
|
||||
dict(close_future=self.close_future)),
|
||||
('/subprotocol', SubprotocolHandler,
|
||||
dict(close_future=self.close_future)),
|
||||
('/open_coroutine', OpenCoroutineHandler,
|
||||
dict(close_future=self.close_future, test=self)),
|
||||
("/error_in_open", ErrorInOpenHandler,
|
||||
dict(close_future=self.close_future)),
|
||||
("/error_in_async_open", ErrorInAsyncOpenHandler,
|
||||
dict(close_future=self.close_future)),
|
||||
], template_loader=DictLoader({
|
||||
'message.html': '<b>{{ message }}</b>',
|
||||
}))
|
||||
|
||||
def get_http_client(self):
|
||||
# These tests require HTTP/1; force the use of SimpleAsyncHTTPClient.
|
||||
return SimpleAsyncHTTPClient()
|
||||
|
||||
def tearDown(self):
|
||||
super(WebSocketTest, self).tearDown()
|
||||
RequestHandler._template_loaders.clear()
|
||||
|
||||
def test_http_request(self):
|
||||
# WS server, HTTP client.
|
||||
response = self.fetch('/echo')
|
||||
self.assertEqual(response.code, 400)
|
||||
|
||||
def test_missing_websocket_key(self):
|
||||
response = self.fetch('/echo',
|
||||
headers={'Connection': 'Upgrade',
|
||||
'Upgrade': 'WebSocket',
|
||||
'Sec-WebSocket-Version': '13'})
|
||||
self.assertEqual(response.code, 400)
|
||||
|
||||
def test_bad_websocket_version(self):
|
||||
response = self.fetch('/echo',
|
||||
headers={'Connection': 'Upgrade',
|
||||
'Upgrade': 'WebSocket',
|
||||
'Sec-WebSocket-Version': '12'})
|
||||
self.assertEqual(response.code, 426)
|
||||
|
||||
@gen_test
|
||||
def test_websocket_gen(self):
|
||||
ws = yield self.ws_connect('/echo')
|
||||
yield ws.write_message('hello')
|
||||
response = yield ws.read_message()
|
||||
self.assertEqual(response, 'hello')
|
||||
yield self.close(ws)
|
||||
|
||||
def test_websocket_callbacks(self):
|
||||
websocket_connect(
|
||||
'ws://127.0.0.1:%d/echo' % self.get_http_port(),
|
||||
callback=self.stop)
|
||||
ws = self.wait().result()
|
||||
ws.write_message('hello')
|
||||
ws.read_message(self.stop)
|
||||
response = self.wait().result()
|
||||
self.assertEqual(response, 'hello')
|
||||
self.close_future.add_done_callback(lambda f: self.stop())
|
||||
ws.close()
|
||||
self.wait()
|
||||
|
||||
@gen_test
|
||||
def test_binary_message(self):
|
||||
ws = yield self.ws_connect('/echo')
|
||||
ws.write_message(b'hello \xe9', binary=True)
|
||||
response = yield ws.read_message()
|
||||
self.assertEqual(response, b'hello \xe9')
|
||||
yield self.close(ws)
|
||||
|
||||
@gen_test
|
||||
def test_unicode_message(self):
|
||||
ws = yield self.ws_connect('/echo')
|
||||
ws.write_message(u'hello \u00e9')
|
||||
response = yield ws.read_message()
|
||||
self.assertEqual(response, u'hello \u00e9')
|
||||
yield self.close(ws)
|
||||
|
||||
@gen_test
|
||||
def test_render_message(self):
|
||||
ws = yield self.ws_connect('/render')
|
||||
ws.write_message('hello')
|
||||
response = yield ws.read_message()
|
||||
self.assertEqual(response, '<b>hello</b>')
|
||||
yield self.close(ws)
|
||||
|
||||
@gen_test
|
||||
def test_error_in_on_message(self):
|
||||
ws = yield self.ws_connect('/error_in_on_message')
|
||||
ws.write_message('hello')
|
||||
with ExpectLog(app_log, "Uncaught exception"):
|
||||
response = yield ws.read_message()
|
||||
self.assertIs(response, None)
|
||||
yield self.close(ws)
|
||||
|
||||
@gen_test
|
||||
def test_websocket_http_fail(self):
|
||||
with self.assertRaises(HTTPError) as cm:
|
||||
yield self.ws_connect('/notfound')
|
||||
self.assertEqual(cm.exception.code, 404)
|
||||
|
||||
@gen_test
|
||||
def test_websocket_http_success(self):
|
||||
with self.assertRaises(WebSocketError):
|
||||
yield self.ws_connect('/non_ws')
|
||||
|
||||
@gen_test
|
||||
def test_websocket_network_fail(self):
|
||||
sock, port = bind_unused_port()
|
||||
sock.close()
|
||||
with self.assertRaises(IOError):
|
||||
with ExpectLog(gen_log, ".*"):
|
||||
yield websocket_connect(
|
||||
'ws://127.0.0.1:%d/' % port,
|
||||
connect_timeout=3600)
|
||||
|
||||
@gen_test
|
||||
def test_websocket_close_buffered_data(self):
|
||||
ws = yield websocket_connect(
|
||||
'ws://127.0.0.1:%d/echo' % self.get_http_port())
|
||||
ws.write_message('hello')
|
||||
ws.write_message('world')
|
||||
# Close the underlying stream.
|
||||
ws.stream.close()
|
||||
yield self.close_future
|
||||
|
||||
@gen_test
|
||||
def test_websocket_headers(self):
|
||||
# Ensure that arbitrary headers can be passed through websocket_connect.
|
||||
ws = yield websocket_connect(
|
||||
HTTPRequest('ws://127.0.0.1:%d/header' % self.get_http_port(),
|
||||
headers={'X-Test': 'hello'}))
|
||||
response = yield ws.read_message()
|
||||
self.assertEqual(response, 'hello')
|
||||
yield self.close(ws)
|
||||
|
||||
@gen_test
|
||||
def test_websocket_header_echo(self):
|
||||
# Ensure that headers can be returned in the response.
|
||||
# Specifically, that arbitrary headers passed through websocket_connect
|
||||
# can be returned.
|
||||
ws = yield websocket_connect(
|
||||
HTTPRequest('ws://127.0.0.1:%d/header_echo' % self.get_http_port(),
|
||||
headers={'X-Test-Hello': 'hello'}))
|
||||
self.assertEqual(ws.headers.get('X-Test-Hello'), 'hello')
|
||||
self.assertEqual(ws.headers.get('X-Extra-Response-Header'), 'Extra-Response-Value')
|
||||
yield self.close(ws)
|
||||
|
||||
@gen_test
|
||||
def test_server_close_reason(self):
|
||||
ws = yield self.ws_connect('/close_reason')
|
||||
msg = yield ws.read_message()
|
||||
# A message of None means the other side closed the connection.
|
||||
self.assertIs(msg, None)
|
||||
self.assertEqual(ws.close_code, 1001)
|
||||
self.assertEqual(ws.close_reason, "goodbye")
|
||||
# The on_close callback is called no matter which side closed.
|
||||
code, reason = yield self.close_future
|
||||
# The client echoed the close code it received to the server,
|
||||
# so the server's close code (returned via close_future) is
|
||||
# the same.
|
||||
self.assertEqual(code, 1001)
|
||||
|
||||
@gen_test
|
||||
def test_client_close_reason(self):
|
||||
ws = yield self.ws_connect('/echo')
|
||||
ws.close(1001, 'goodbye')
|
||||
code, reason = yield self.close_future
|
||||
self.assertEqual(code, 1001)
|
||||
self.assertEqual(reason, 'goodbye')
|
||||
|
||||
@gen_test
|
||||
def test_write_after_close(self):
|
||||
ws = yield self.ws_connect('/close_reason')
|
||||
msg = yield ws.read_message()
|
||||
self.assertIs(msg, None)
|
||||
with self.assertRaises(WebSocketClosedError):
|
||||
ws.write_message('hello')
|
||||
|
||||
@gen_test
|
||||
def test_async_prepare(self):
|
||||
# Previously, an async prepare method triggered a bug that would
|
||||
# result in a timeout on test shutdown (and a memory leak).
|
||||
ws = yield self.ws_connect('/async_prepare')
|
||||
ws.write_message('hello')
|
||||
res = yield ws.read_message()
|
||||
self.assertEqual(res, 'hello')
|
||||
|
||||
@gen_test
|
||||
def test_path_args(self):
|
||||
ws = yield self.ws_connect('/path_args/hello')
|
||||
res = yield ws.read_message()
|
||||
self.assertEqual(res, 'hello')
|
||||
|
||||
@gen_test
|
||||
def test_coroutine(self):
|
||||
ws = yield self.ws_connect('/coroutine')
|
||||
# Send both messages immediately, coroutine must process one at a time.
|
||||
yield ws.write_message('hello1')
|
||||
yield ws.write_message('hello2')
|
||||
res = yield ws.read_message()
|
||||
self.assertEqual(res, 'hello1')
|
||||
res = yield ws.read_message()
|
||||
self.assertEqual(res, 'hello2')
|
||||
|
||||
@gen_test
|
||||
def test_check_origin_valid_no_path(self):
|
||||
port = self.get_http_port()
|
||||
|
||||
url = 'ws://127.0.0.1:%d/echo' % port
|
||||
headers = {'Origin': 'http://127.0.0.1:%d' % port}
|
||||
|
||||
ws = yield websocket_connect(HTTPRequest(url, headers=headers))
|
||||
ws.write_message('hello')
|
||||
response = yield ws.read_message()
|
||||
self.assertEqual(response, 'hello')
|
||||
yield self.close(ws)
|
||||
|
||||
@gen_test
|
||||
def test_check_origin_valid_with_path(self):
|
||||
port = self.get_http_port()
|
||||
|
||||
url = 'ws://127.0.0.1:%d/echo' % port
|
||||
headers = {'Origin': 'http://127.0.0.1:%d/something' % port}
|
||||
|
||||
ws = yield websocket_connect(HTTPRequest(url, headers=headers))
|
||||
ws.write_message('hello')
|
||||
response = yield ws.read_message()
|
||||
self.assertEqual(response, 'hello')
|
||||
yield self.close(ws)
|
||||
|
||||
@gen_test
|
||||
def test_check_origin_invalid_partial_url(self):
|
||||
port = self.get_http_port()
|
||||
|
||||
url = 'ws://127.0.0.1:%d/echo' % port
|
||||
headers = {'Origin': '127.0.0.1:%d' % port}
|
||||
|
||||
with self.assertRaises(HTTPError) as cm:
|
||||
yield websocket_connect(HTTPRequest(url, headers=headers))
|
||||
self.assertEqual(cm.exception.code, 403)
|
||||
|
||||
@gen_test
|
||||
def test_check_origin_invalid(self):
|
||||
port = self.get_http_port()
|
||||
|
||||
url = 'ws://127.0.0.1:%d/echo' % port
|
||||
# Host is 127.0.0.1, which should not be accessible from some other
|
||||
# domain
|
||||
headers = {'Origin': 'http://somewhereelse.com'}
|
||||
|
||||
with self.assertRaises(HTTPError) as cm:
|
||||
yield websocket_connect(HTTPRequest(url, headers=headers))
|
||||
|
||||
self.assertEqual(cm.exception.code, 403)
|
||||
|
||||
@gen_test
|
||||
def test_check_origin_invalid_subdomains(self):
|
||||
port = self.get_http_port()
|
||||
|
||||
url = 'ws://localhost:%d/echo' % port
|
||||
# Subdomains should be disallowed by default. If we could pass a
|
||||
# resolver to websocket_connect we could test sibling domains as well.
|
||||
headers = {'Origin': 'http://subtenant.localhost'}
|
||||
|
||||
with self.assertRaises(HTTPError) as cm:
|
||||
yield websocket_connect(HTTPRequest(url, headers=headers))
|
||||
|
||||
self.assertEqual(cm.exception.code, 403)
|
||||
|
||||
@gen_test
|
||||
def test_subprotocols(self):
|
||||
ws = yield self.ws_connect('/subprotocol', subprotocols=['badproto', 'goodproto'])
|
||||
self.assertEqual(ws.selected_subprotocol, 'goodproto')
|
||||
res = yield ws.read_message()
|
||||
self.assertEqual(res, 'subprotocol=goodproto')
|
||||
yield self.close(ws)
|
||||
|
||||
@gen_test
|
||||
def test_subprotocols_not_offered(self):
|
||||
ws = yield self.ws_connect('/subprotocol')
|
||||
self.assertIs(ws.selected_subprotocol, None)
|
||||
res = yield ws.read_message()
|
||||
self.assertEqual(res, 'subprotocol=None')
|
||||
yield self.close(ws)
|
||||
|
||||
@gen_test
|
||||
def test_open_coroutine(self):
|
||||
self.message_sent = Event()
|
||||
ws = yield self.ws_connect('/open_coroutine')
|
||||
yield ws.write_message('hello')
|
||||
self.message_sent.set()
|
||||
res = yield ws.read_message()
|
||||
self.assertEqual(res, 'ok')
|
||||
yield self.close(ws)
|
||||
|
||||
@gen_test
|
||||
def test_error_in_open(self):
|
||||
with ExpectLog(app_log, "Uncaught exception"):
|
||||
ws = yield self.ws_connect("/error_in_open")
|
||||
res = yield ws.read_message()
|
||||
self.assertIsNone(res)
|
||||
|
||||
@gen_test
|
||||
def test_error_in_async_open(self):
|
||||
with ExpectLog(app_log, "Uncaught exception"):
|
||||
ws = yield self.ws_connect("/error_in_async_open")
|
||||
res = yield ws.read_message()
|
||||
self.assertIsNone(res)
|
||||
|
||||
|
||||
if sys.version_info >= (3, 5):
|
||||
NativeCoroutineOnMessageHandler = exec_test(globals(), locals(), """
|
||||
class NativeCoroutineOnMessageHandler(TestWebSocketHandler):
|
||||
def initialize(self, close_future, compression_options=None):
|
||||
super().initialize(close_future, compression_options)
|
||||
self.sleeping = 0
|
||||
|
||||
async def on_message(self, message):
|
||||
if self.sleeping > 0:
|
||||
self.write_message('another coroutine is already sleeping')
|
||||
self.sleeping += 1
|
||||
await gen.sleep(0.01)
|
||||
self.sleeping -= 1
|
||||
self.write_message(message)""")['NativeCoroutineOnMessageHandler']
|
||||
|
||||
|
||||
class WebSocketNativeCoroutineTest(WebSocketBaseTestCase):
|
||||
def get_app(self):
|
||||
self.close_future = Future()
|
||||
return Application([
|
||||
('/native', NativeCoroutineOnMessageHandler,
|
||||
dict(close_future=self.close_future))])
|
||||
|
||||
@skipBefore35
|
||||
@gen_test
|
||||
def test_native_coroutine(self):
|
||||
ws = yield self.ws_connect('/native')
|
||||
# Send both messages immediately, coroutine must process one at a time.
|
||||
yield ws.write_message('hello1')
|
||||
yield ws.write_message('hello2')
|
||||
res = yield ws.read_message()
|
||||
self.assertEqual(res, 'hello1')
|
||||
res = yield ws.read_message()
|
||||
self.assertEqual(res, 'hello2')
|
||||
|
||||
|
||||
class CompressionTestMixin(object):
|
||||
MESSAGE = 'Hello world. Testing 123 123'
|
||||
|
||||
def get_app(self):
|
||||
self.close_future = Future()
|
||||
|
||||
class LimitedHandler(TestWebSocketHandler):
|
||||
@property
|
||||
def max_message_size(self):
|
||||
return 1024
|
||||
|
||||
def on_message(self, message):
|
||||
self.write_message(str(len(message)))
|
||||
|
||||
return Application([
|
||||
('/echo', EchoHandler, dict(
|
||||
close_future=self.close_future,
|
||||
compression_options=self.get_server_compression_options())),
|
||||
('/limited', LimitedHandler, dict(
|
||||
close_future=self.close_future,
|
||||
compression_options=self.get_server_compression_options())),
|
||||
])
|
||||
|
||||
def get_server_compression_options(self):
|
||||
return None
|
||||
|
||||
def get_client_compression_options(self):
|
||||
return None
|
||||
|
||||
@gen_test
|
||||
def test_message_sizes(self):
|
||||
ws = yield self.ws_connect(
|
||||
'/echo',
|
||||
compression_options=self.get_client_compression_options())
|
||||
# Send the same message three times so we can measure the
|
||||
# effect of the context_takeover options.
|
||||
for i in range(3):
|
||||
ws.write_message(self.MESSAGE)
|
||||
response = yield ws.read_message()
|
||||
self.assertEqual(response, self.MESSAGE)
|
||||
self.assertEqual(ws.protocol._message_bytes_out, len(self.MESSAGE) * 3)
|
||||
self.assertEqual(ws.protocol._message_bytes_in, len(self.MESSAGE) * 3)
|
||||
self.verify_wire_bytes(ws.protocol._wire_bytes_in,
|
||||
ws.protocol._wire_bytes_out)
|
||||
yield self.close(ws)
|
||||
|
||||
@gen_test
|
||||
def test_size_limit(self):
|
||||
ws = yield self.ws_connect(
|
||||
'/limited',
|
||||
compression_options=self.get_client_compression_options())
|
||||
# Small messages pass through.
|
||||
ws.write_message('a' * 128)
|
||||
response = yield ws.read_message()
|
||||
self.assertEqual(response, '128')
|
||||
# This message is too big after decompression, but it compresses
|
||||
# down to a size that will pass the initial checks.
|
||||
ws.write_message('a' * 2048)
|
||||
response = yield ws.read_message()
|
||||
self.assertIsNone(response)
|
||||
yield self.close(ws)
|
||||
|
||||
|
||||
class UncompressedTestMixin(CompressionTestMixin):
|
||||
"""Specialization of CompressionTestMixin when we expect no compression."""
|
||||
def verify_wire_bytes(self, bytes_in, bytes_out):
|
||||
# Bytes out includes the 4-byte mask key per message.
|
||||
self.assertEqual(bytes_out, 3 * (len(self.MESSAGE) + 6))
|
||||
self.assertEqual(bytes_in, 3 * (len(self.MESSAGE) + 2))
|
||||
|
||||
|
||||
class NoCompressionTest(UncompressedTestMixin, WebSocketBaseTestCase):
|
||||
pass
|
||||
|
||||
|
||||
# If only one side tries to compress, the extension is not negotiated.
|
||||
class ServerOnlyCompressionTest(UncompressedTestMixin, WebSocketBaseTestCase):
|
||||
def get_server_compression_options(self):
|
||||
return {}
|
||||
|
||||
|
||||
class ClientOnlyCompressionTest(UncompressedTestMixin, WebSocketBaseTestCase):
|
||||
def get_client_compression_options(self):
|
||||
return {}
|
||||
|
||||
|
||||
class DefaultCompressionTest(CompressionTestMixin, WebSocketBaseTestCase):
|
||||
def get_server_compression_options(self):
|
||||
return {}
|
||||
|
||||
def get_client_compression_options(self):
|
||||
return {}
|
||||
|
||||
def verify_wire_bytes(self, bytes_in, bytes_out):
|
||||
self.assertLess(bytes_out, 3 * (len(self.MESSAGE) + 6))
|
||||
self.assertLess(bytes_in, 3 * (len(self.MESSAGE) + 2))
|
||||
# Bytes out includes the 4 bytes mask key per message.
|
||||
self.assertEqual(bytes_out, bytes_in + 12)
|
||||
|
||||
|
||||
class MaskFunctionMixin(object):
|
||||
# Subclasses should define self.mask(mask, data)
|
||||
def test_mask(self):
|
||||
self.assertEqual(self.mask(b'abcd', b''), b'')
|
||||
self.assertEqual(self.mask(b'abcd', b'b'), b'\x03')
|
||||
self.assertEqual(self.mask(b'abcd', b'54321'), b'TVPVP')
|
||||
self.assertEqual(self.mask(b'ZXCV', b'98765432'), b'c`t`olpd')
|
||||
# Include test cases with \x00 bytes (to ensure that the C
|
||||
# extension isn't depending on null-terminated strings) and
|
||||
# bytes with the high bit set (to smoke out signedness issues).
|
||||
self.assertEqual(self.mask(b'\x00\x01\x02\x03',
|
||||
b'\xff\xfb\xfd\xfc\xfe\xfa'),
|
||||
b'\xff\xfa\xff\xff\xfe\xfb')
|
||||
self.assertEqual(self.mask(b'\xff\xfb\xfd\xfc',
|
||||
b'\x00\x01\x02\x03\x04\x05'),
|
||||
b'\xff\xfa\xff\xff\xfb\xfe')
|
||||
|
||||
|
||||
class PythonMaskFunctionTest(MaskFunctionMixin, unittest.TestCase):
|
||||
def mask(self, mask, data):
|
||||
return _websocket_mask_python(mask, data)
|
||||
|
||||
|
||||
@unittest.skipIf(speedups is None, "tornado.speedups module not present")
|
||||
class CythonMaskFunctionTest(MaskFunctionMixin, unittest.TestCase):
|
||||
def mask(self, mask, data):
|
||||
return speedups.websocket_mask(mask, data)
|
||||
|
||||
|
||||
class ServerPeriodicPingTest(WebSocketBaseTestCase):
|
||||
def get_app(self):
|
||||
class PingHandler(TestWebSocketHandler):
|
||||
def on_pong(self, data):
|
||||
self.write_message("got pong")
|
||||
|
||||
self.close_future = Future()
|
||||
return Application([
|
||||
('/', PingHandler, dict(close_future=self.close_future)),
|
||||
], websocket_ping_interval=0.01)
|
||||
|
||||
@gen_test
|
||||
def test_server_ping(self):
|
||||
ws = yield self.ws_connect('/')
|
||||
for i in range(3):
|
||||
response = yield ws.read_message()
|
||||
self.assertEqual(response, "got pong")
|
||||
yield self.close(ws)
|
||||
# TODO: test that the connection gets closed if ping responses stop.
|
||||
|
||||
|
||||
class ClientPeriodicPingTest(WebSocketBaseTestCase):
|
||||
def get_app(self):
|
||||
class PingHandler(TestWebSocketHandler):
|
||||
def on_ping(self, data):
|
||||
self.write_message("got ping")
|
||||
|
||||
self.close_future = Future()
|
||||
return Application([
|
||||
('/', PingHandler, dict(close_future=self.close_future)),
|
||||
])
|
||||
|
||||
@gen_test
|
||||
def test_client_ping(self):
|
||||
ws = yield self.ws_connect('/', ping_interval=0.01)
|
||||
for i in range(3):
|
||||
response = yield ws.read_message()
|
||||
self.assertEqual(response, "got ping")
|
||||
yield self.close(ws)
|
||||
# TODO: test that the connection gets closed if ping responses stop.
|
||||
|
||||
|
||||
class ManualPingTest(WebSocketBaseTestCase):
|
||||
def get_app(self):
|
||||
class PingHandler(TestWebSocketHandler):
|
||||
def on_ping(self, data):
|
||||
self.write_message(data, binary=isinstance(data, bytes))
|
||||
|
||||
self.close_future = Future()
|
||||
return Application([
|
||||
('/', PingHandler, dict(close_future=self.close_future)),
|
||||
])
|
||||
|
||||
@gen_test
|
||||
def test_manual_ping(self):
|
||||
ws = yield self.ws_connect('/')
|
||||
|
||||
self.assertRaises(ValueError, ws.ping, 'a' * 126)
|
||||
|
||||
ws.ping('hello')
|
||||
resp = yield ws.read_message()
|
||||
# on_ping always sees bytes.
|
||||
self.assertEqual(resp, b'hello')
|
||||
|
||||
ws.ping(b'binary hello')
|
||||
resp = yield ws.read_message()
|
||||
self.assertEqual(resp, b'binary hello')
|
||||
yield self.close(ws)
|
||||
|
||||
|
||||
class MaxMessageSizeTest(WebSocketBaseTestCase):
|
||||
def get_app(self):
|
||||
self.close_future = Future()
|
||||
return Application([
|
||||
('/', EchoHandler, dict(close_future=self.close_future)),
|
||||
], websocket_max_message_size=1024)
|
||||
|
||||
@gen_test
|
||||
def test_large_message(self):
|
||||
ws = yield self.ws_connect('/')
|
||||
|
||||
# Write a message that is allowed.
|
||||
msg = 'a' * 1024
|
||||
ws.write_message(msg)
|
||||
resp = yield ws.read_message()
|
||||
self.assertEqual(resp, msg)
|
||||
|
||||
# Write a message that is too large.
|
||||
ws.write_message(msg + 'b')
|
||||
resp = yield ws.read_message()
|
||||
# A message of None means the other side closed the connection.
|
||||
self.assertIs(resp, None)
|
||||
self.assertEqual(ws.close_code, 1009)
|
||||
self.assertEqual(ws.close_reason, "message too big")
|
||||
# TODO: Needs tests of messages split over multiple
|
||||
# continuation frames.
|
||||
Executable
+25
@@ -0,0 +1,25 @@
|
||||
from __future__ import absolute_import, division, print_function
|
||||
import functools
|
||||
import os
|
||||
import socket
|
||||
import unittest
|
||||
|
||||
from tornado.platform.auto import set_close_exec
|
||||
|
||||
skipIfNonWindows = unittest.skipIf(os.name != 'nt', 'non-windows platform')
|
||||
|
||||
|
||||
@skipIfNonWindows
|
||||
class WindowsTest(unittest.TestCase):
|
||||
def test_set_close_exec(self):
|
||||
# set_close_exec works with sockets.
|
||||
s = socket.socket()
|
||||
self.addCleanup(s.close)
|
||||
set_close_exec(s.fileno())
|
||||
|
||||
# But it doesn't work with pipes.
|
||||
r, w = os.pipe()
|
||||
self.addCleanup(functools.partial(os.close, r))
|
||||
self.addCleanup(functools.partial(os.close, w))
|
||||
with self.assertRaises(WindowsError):
|
||||
set_close_exec(r)
|
||||
Executable
+118
@@ -0,0 +1,118 @@
|
||||
from __future__ import absolute_import, division, print_function
|
||||
from wsgiref.validate import validator
|
||||
|
||||
from tornado.escape import json_decode
|
||||
from tornado.test.httpserver_test import TypeCheckHandler
|
||||
from tornado.test.util import ignore_deprecation
|
||||
from tornado.testing import AsyncHTTPTestCase
|
||||
from tornado.web import RequestHandler, Application
|
||||
from tornado.wsgi import WSGIApplication, WSGIContainer, WSGIAdapter
|
||||
|
||||
from tornado.test import httpserver_test
|
||||
from tornado.test import web_test
|
||||
|
||||
|
||||
class WSGIContainerTest(AsyncHTTPTestCase):
|
||||
def wsgi_app(self, environ, start_response):
|
||||
status = "200 OK"
|
||||
response_headers = [("Content-Type", "text/plain")]
|
||||
start_response(status, response_headers)
|
||||
return [b"Hello world!"]
|
||||
|
||||
def get_app(self):
|
||||
return WSGIContainer(validator(self.wsgi_app))
|
||||
|
||||
def test_simple(self):
|
||||
response = self.fetch("/")
|
||||
self.assertEqual(response.body, b"Hello world!")
|
||||
|
||||
|
||||
class WSGIAdapterTest(AsyncHTTPTestCase):
|
||||
def get_app(self):
|
||||
class HelloHandler(RequestHandler):
|
||||
def get(self):
|
||||
self.write("Hello world!")
|
||||
|
||||
class PathQuotingHandler(RequestHandler):
|
||||
def get(self, path):
|
||||
self.write(path)
|
||||
|
||||
# It would be better to run the wsgiref server implementation in
|
||||
# another thread instead of using our own WSGIContainer, but this
|
||||
# fits better in our async testing framework and the wsgiref
|
||||
# validator should keep us honest
|
||||
with ignore_deprecation():
|
||||
return WSGIContainer(validator(WSGIAdapter(
|
||||
Application([
|
||||
("/", HelloHandler),
|
||||
("/path/(.*)", PathQuotingHandler),
|
||||
("/typecheck", TypeCheckHandler),
|
||||
]))))
|
||||
|
||||
def test_simple(self):
|
||||
response = self.fetch("/")
|
||||
self.assertEqual(response.body, b"Hello world!")
|
||||
|
||||
def test_path_quoting(self):
|
||||
response = self.fetch("/path/foo%20bar%C3%A9")
|
||||
self.assertEqual(response.body, u"foo bar\u00e9".encode("utf-8"))
|
||||
|
||||
def test_types(self):
|
||||
headers = {"Cookie": "foo=bar"}
|
||||
response = self.fetch("/typecheck?foo=bar", headers=headers)
|
||||
data = json_decode(response.body)
|
||||
self.assertEqual(data, {})
|
||||
|
||||
response = self.fetch("/typecheck", method="POST", body="foo=bar", headers=headers)
|
||||
data = json_decode(response.body)
|
||||
self.assertEqual(data, {})
|
||||
|
||||
|
||||
# This is kind of hacky, but run some of the HTTPServer and web tests
|
||||
# through WSGIContainer and WSGIApplication to make sure everything
|
||||
# survives repeated disassembly and reassembly.
|
||||
class WSGIConnectionTest(httpserver_test.HTTPConnectionTest):
|
||||
def get_app(self):
|
||||
with ignore_deprecation():
|
||||
return WSGIContainer(validator(WSGIAdapter(Application(self.get_handlers()))))
|
||||
|
||||
|
||||
def wrap_web_tests_application():
|
||||
result = {}
|
||||
for cls in web_test.wsgi_safe_tests:
|
||||
def class_factory():
|
||||
class WSGIApplicationWrappedTest(cls): # type: ignore
|
||||
def setUp(self):
|
||||
self.warning_catcher = ignore_deprecation()
|
||||
self.warning_catcher.__enter__()
|
||||
super(WSGIApplicationWrappedTest, self).setUp()
|
||||
|
||||
def tearDown(self):
|
||||
super(WSGIApplicationWrappedTest, self).tearDown()
|
||||
self.warning_catcher.__exit__(None, None, None)
|
||||
|
||||
def get_app(self):
|
||||
self.app = WSGIApplication(self.get_handlers(),
|
||||
**self.get_app_kwargs())
|
||||
return WSGIContainer(validator(self.app))
|
||||
result["WSGIApplication_" + cls.__name__] = class_factory()
|
||||
return result
|
||||
|
||||
|
||||
globals().update(wrap_web_tests_application())
|
||||
|
||||
|
||||
def wrap_web_tests_adapter():
|
||||
result = {}
|
||||
for cls in web_test.wsgi_safe_tests:
|
||||
class WSGIAdapterWrappedTest(cls): # type: ignore
|
||||
def get_app(self):
|
||||
self.app = Application(self.get_handlers(),
|
||||
**self.get_app_kwargs())
|
||||
with ignore_deprecation():
|
||||
return WSGIContainer(validator(WSGIAdapter(self.app)))
|
||||
result["WSGIAdapter_" + cls.__name__] = WSGIAdapterWrappedTest
|
||||
return result
|
||||
|
||||
|
||||
globals().update(wrap_web_tests_adapter())
|
||||
Executable
+724
@@ -0,0 +1,724 @@
|
||||
"""Support classes for automated testing.
|
||||
|
||||
* `AsyncTestCase` and `AsyncHTTPTestCase`: Subclasses of unittest.TestCase
|
||||
with additional support for testing asynchronous (`.IOLoop`-based) code.
|
||||
|
||||
* `ExpectLog`: Make test logs less spammy.
|
||||
|
||||
* `main()`: A simple test runner (wrapper around unittest.main()) with support
|
||||
for the tornado.autoreload module to rerun the tests when code changes.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
try:
|
||||
from tornado import gen
|
||||
from tornado.httpclient import AsyncHTTPClient
|
||||
from tornado.httpserver import HTTPServer
|
||||
from tornado.simple_httpclient import SimpleAsyncHTTPClient
|
||||
from tornado.ioloop import IOLoop, TimeoutError
|
||||
from tornado import netutil
|
||||
from tornado.process import Subprocess
|
||||
except ImportError:
|
||||
# These modules are not importable on app engine. Parts of this module
|
||||
# won't work, but e.g. main() will.
|
||||
AsyncHTTPClient = None # type: ignore
|
||||
gen = None # type: ignore
|
||||
HTTPServer = None # type: ignore
|
||||
IOLoop = None # type: ignore
|
||||
netutil = None # type: ignore
|
||||
SimpleAsyncHTTPClient = None # type: ignore
|
||||
Subprocess = None # type: ignore
|
||||
from tornado.log import app_log
|
||||
from tornado.stack_context import ExceptionStackContext
|
||||
from tornado.util import raise_exc_info, basestring_type, PY3
|
||||
import functools
|
||||
import inspect
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import signal
|
||||
import socket
|
||||
import sys
|
||||
|
||||
try:
|
||||
import asyncio
|
||||
except ImportError:
|
||||
asyncio = None
|
||||
|
||||
|
||||
try:
|
||||
from collections.abc import Generator as GeneratorType # type: ignore
|
||||
except ImportError:
|
||||
from types import GeneratorType # type: ignore
|
||||
|
||||
if sys.version_info >= (3, 5):
|
||||
iscoroutine = inspect.iscoroutine # type: ignore
|
||||
iscoroutinefunction = inspect.iscoroutinefunction # type: ignore
|
||||
else:
|
||||
iscoroutine = iscoroutinefunction = lambda f: False
|
||||
|
||||
# Tornado's own test suite requires the updated unittest module
|
||||
# (either py27+ or unittest2) so tornado.test.util enforces
|
||||
# this requirement, but for other users of tornado.testing we want
|
||||
# to allow the older version if unitest2 is not available.
|
||||
if PY3:
|
||||
# On python 3, mixing unittest2 and unittest (including doctest)
|
||||
# doesn't seem to work, so always use unittest.
|
||||
import unittest
|
||||
else:
|
||||
# On python 2, prefer unittest2 when available.
|
||||
try:
|
||||
import unittest2 as unittest # type: ignore
|
||||
except ImportError:
|
||||
import unittest # type: ignore
|
||||
|
||||
|
||||
if asyncio is None:
|
||||
_NON_OWNED_IOLOOPS = ()
|
||||
else:
|
||||
import tornado.platform.asyncio
|
||||
_NON_OWNED_IOLOOPS = tornado.platform.asyncio.AsyncIOMainLoop
|
||||
|
||||
|
||||
def bind_unused_port(reuse_port=False):
|
||||
"""Binds a server socket to an available port on localhost.
|
||||
|
||||
Returns a tuple (socket, port).
|
||||
|
||||
.. versionchanged:: 4.4
|
||||
Always binds to ``127.0.0.1`` without resolving the name
|
||||
``localhost``.
|
||||
"""
|
||||
sock = netutil.bind_sockets(None, '127.0.0.1', family=socket.AF_INET,
|
||||
reuse_port=reuse_port)[0]
|
||||
port = sock.getsockname()[1]
|
||||
return sock, port
|
||||
|
||||
|
||||
def get_async_test_timeout():
|
||||
"""Get the global timeout setting for async tests.
|
||||
|
||||
Returns a float, the timeout in seconds.
|
||||
|
||||
.. versionadded:: 3.1
|
||||
"""
|
||||
try:
|
||||
return float(os.environ.get('ASYNC_TEST_TIMEOUT'))
|
||||
except (ValueError, TypeError):
|
||||
return 5
|
||||
|
||||
|
||||
class _TestMethodWrapper(object):
|
||||
"""Wraps a test method to raise an error if it returns a value.
|
||||
|
||||
This is mainly used to detect undecorated generators (if a test
|
||||
method yields it must use a decorator to consume the generator),
|
||||
but will also detect other kinds of return values (these are not
|
||||
necessarily errors, but we alert anyway since there is no good
|
||||
reason to return a value from a test).
|
||||
"""
|
||||
def __init__(self, orig_method):
|
||||
self.orig_method = orig_method
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
result = self.orig_method(*args, **kwargs)
|
||||
if isinstance(result, GeneratorType) or iscoroutine(result):
|
||||
raise TypeError("Generator and coroutine test methods should be"
|
||||
" decorated with tornado.testing.gen_test")
|
||||
elif result is not None:
|
||||
raise ValueError("Return value from test method ignored: %r" %
|
||||
result)
|
||||
|
||||
def __getattr__(self, name):
|
||||
"""Proxy all unknown attributes to the original method.
|
||||
|
||||
This is important for some of the decorators in the `unittest`
|
||||
module, such as `unittest.skipIf`.
|
||||
"""
|
||||
return getattr(self.orig_method, name)
|
||||
|
||||
|
||||
class AsyncTestCase(unittest.TestCase):
|
||||
"""`~unittest.TestCase` subclass for testing `.IOLoop`-based
|
||||
asynchronous code.
|
||||
|
||||
The unittest framework is synchronous, so the test must be
|
||||
complete by the time the test method returns. This means that
|
||||
asynchronous code cannot be used in quite the same way as usual
|
||||
and must be adapted to fit. To write your tests with coroutines,
|
||||
decorate your test methods with `tornado.testing.gen_test` instead
|
||||
of `tornado.gen.coroutine`.
|
||||
|
||||
This class also provides the (deprecated) `stop()` and `wait()`
|
||||
methods for a more manual style of testing. The test method itself
|
||||
must call ``self.wait()``, and asynchronous callbacks should call
|
||||
``self.stop()`` to signal completion.
|
||||
|
||||
By default, a new `.IOLoop` is constructed for each test and is available
|
||||
as ``self.io_loop``. If the code being tested requires a
|
||||
global `.IOLoop`, subclasses should override `get_new_ioloop` to return it.
|
||||
|
||||
The `.IOLoop`'s ``start`` and ``stop`` methods should not be
|
||||
called directly. Instead, use `self.stop <stop>` and `self.wait
|
||||
<wait>`. Arguments passed to ``self.stop`` are returned from
|
||||
``self.wait``. It is possible to have multiple ``wait``/``stop``
|
||||
cycles in the same test.
|
||||
|
||||
Example::
|
||||
|
||||
# This test uses coroutine style.
|
||||
class MyTestCase(AsyncTestCase):
|
||||
@tornado.testing.gen_test
|
||||
def test_http_fetch(self):
|
||||
client = AsyncHTTPClient()
|
||||
response = yield client.fetch("http://www.tornadoweb.org")
|
||||
# Test contents of response
|
||||
self.assertIn("FriendFeed", response.body)
|
||||
|
||||
# This test uses argument passing between self.stop and self.wait.
|
||||
class MyTestCase2(AsyncTestCase):
|
||||
def test_http_fetch(self):
|
||||
client = AsyncHTTPClient()
|
||||
client.fetch("http://www.tornadoweb.org/", self.stop)
|
||||
response = self.wait()
|
||||
# Test contents of response
|
||||
self.assertIn("FriendFeed", response.body)
|
||||
"""
|
||||
def __init__(self, methodName='runTest'):
|
||||
super(AsyncTestCase, self).__init__(methodName)
|
||||
self.__stopped = False
|
||||
self.__running = False
|
||||
self.__failure = None
|
||||
self.__stop_args = None
|
||||
self.__timeout = None
|
||||
|
||||
# It's easy to forget the @gen_test decorator, but if you do
|
||||
# the test will silently be ignored because nothing will consume
|
||||
# the generator. Replace the test method with a wrapper that will
|
||||
# make sure it's not an undecorated generator.
|
||||
setattr(self, methodName, _TestMethodWrapper(getattr(self, methodName)))
|
||||
|
||||
def setUp(self):
|
||||
super(AsyncTestCase, self).setUp()
|
||||
self.io_loop = self.get_new_ioloop()
|
||||
self.io_loop.make_current()
|
||||
|
||||
def tearDown(self):
|
||||
# Clean up Subprocess, so it can be used again with a new ioloop.
|
||||
Subprocess.uninitialize()
|
||||
self.io_loop.clear_current()
|
||||
if not isinstance(self.io_loop, _NON_OWNED_IOLOOPS):
|
||||
# Try to clean up any file descriptors left open in the ioloop.
|
||||
# This avoids leaks, especially when tests are run repeatedly
|
||||
# in the same process with autoreload (because curl does not
|
||||
# set FD_CLOEXEC on its file descriptors)
|
||||
self.io_loop.close(all_fds=True)
|
||||
super(AsyncTestCase, self).tearDown()
|
||||
# In case an exception escaped or the StackContext caught an exception
|
||||
# when there wasn't a wait() to re-raise it, do so here.
|
||||
# This is our last chance to raise an exception in a way that the
|
||||
# unittest machinery understands.
|
||||
self.__rethrow()
|
||||
|
||||
def get_new_ioloop(self):
|
||||
"""Returns the `.IOLoop` to use for this test.
|
||||
|
||||
By default, a new `.IOLoop` is created for each test.
|
||||
Subclasses may override this method to return
|
||||
`.IOLoop.current()` if it is not appropriate to use a new
|
||||
`.IOLoop` in each tests (for example, if there are global
|
||||
singletons using the default `.IOLoop`) or if a per-test event
|
||||
loop is being provided by another system (such as
|
||||
``pytest-asyncio``).
|
||||
"""
|
||||
return IOLoop()
|
||||
|
||||
def _handle_exception(self, typ, value, tb):
|
||||
if self.__failure is None:
|
||||
self.__failure = (typ, value, tb)
|
||||
else:
|
||||
app_log.error("multiple unhandled exceptions in test",
|
||||
exc_info=(typ, value, tb))
|
||||
self.stop()
|
||||
return True
|
||||
|
||||
def __rethrow(self):
|
||||
if self.__failure is not None:
|
||||
failure = self.__failure
|
||||
self.__failure = None
|
||||
raise_exc_info(failure)
|
||||
|
||||
def run(self, result=None):
|
||||
with ExceptionStackContext(self._handle_exception, delay_warning=True):
|
||||
super(AsyncTestCase, self).run(result)
|
||||
# As a last resort, if an exception escaped super.run() and wasn't
|
||||
# re-raised in tearDown, raise it here. This will cause the
|
||||
# unittest run to fail messily, but that's better than silently
|
||||
# ignoring an error.
|
||||
self.__rethrow()
|
||||
|
||||
def stop(self, _arg=None, **kwargs):
|
||||
"""Stops the `.IOLoop`, causing one pending (or future) call to `wait()`
|
||||
to return.
|
||||
|
||||
Keyword arguments or a single positional argument passed to `stop()` are
|
||||
saved and will be returned by `wait()`.
|
||||
|
||||
.. deprecated:: 5.1
|
||||
|
||||
`stop` and `wait` are deprecated; use ``@gen_test`` instead.
|
||||
"""
|
||||
assert _arg is None or not kwargs
|
||||
self.__stop_args = kwargs or _arg
|
||||
if self.__running:
|
||||
self.io_loop.stop()
|
||||
self.__running = False
|
||||
self.__stopped = True
|
||||
|
||||
def wait(self, condition=None, timeout=None):
|
||||
"""Runs the `.IOLoop` until stop is called or timeout has passed.
|
||||
|
||||
In the event of a timeout, an exception will be thrown. The
|
||||
default timeout is 5 seconds; it may be overridden with a
|
||||
``timeout`` keyword argument or globally with the
|
||||
``ASYNC_TEST_TIMEOUT`` environment variable.
|
||||
|
||||
If ``condition`` is not None, the `.IOLoop` will be restarted
|
||||
after `stop()` until ``condition()`` returns true.
|
||||
|
||||
.. versionchanged:: 3.1
|
||||
Added the ``ASYNC_TEST_TIMEOUT`` environment variable.
|
||||
|
||||
.. deprecated:: 5.1
|
||||
|
||||
`stop` and `wait` are deprecated; use ``@gen_test`` instead.
|
||||
"""
|
||||
if timeout is None:
|
||||
timeout = get_async_test_timeout()
|
||||
|
||||
if not self.__stopped:
|
||||
if timeout:
|
||||
def timeout_func():
|
||||
try:
|
||||
raise self.failureException(
|
||||
'Async operation timed out after %s seconds' %
|
||||
timeout)
|
||||
except Exception:
|
||||
self.__failure = sys.exc_info()
|
||||
self.stop()
|
||||
self.__timeout = self.io_loop.add_timeout(self.io_loop.time() + timeout,
|
||||
timeout_func)
|
||||
while True:
|
||||
self.__running = True
|
||||
self.io_loop.start()
|
||||
if (self.__failure is not None or
|
||||
condition is None or condition()):
|
||||
break
|
||||
if self.__timeout is not None:
|
||||
self.io_loop.remove_timeout(self.__timeout)
|
||||
self.__timeout = None
|
||||
assert self.__stopped
|
||||
self.__stopped = False
|
||||
self.__rethrow()
|
||||
result = self.__stop_args
|
||||
self.__stop_args = None
|
||||
return result
|
||||
|
||||
|
||||
class AsyncHTTPTestCase(AsyncTestCase):
|
||||
"""A test case that starts up an HTTP server.
|
||||
|
||||
Subclasses must override `get_app()`, which returns the
|
||||
`tornado.web.Application` (or other `.HTTPServer` callback) to be tested.
|
||||
Tests will typically use the provided ``self.http_client`` to fetch
|
||||
URLs from this server.
|
||||
|
||||
Example, assuming the "Hello, world" example from the user guide is in
|
||||
``hello.py``::
|
||||
|
||||
import hello
|
||||
|
||||
class TestHelloApp(AsyncHTTPTestCase):
|
||||
def get_app(self):
|
||||
return hello.make_app()
|
||||
|
||||
def test_homepage(self):
|
||||
response = self.fetch('/')
|
||||
self.assertEqual(response.code, 200)
|
||||
self.assertEqual(response.body, 'Hello, world')
|
||||
|
||||
That call to ``self.fetch()`` is equivalent to ::
|
||||
|
||||
self.http_client.fetch(self.get_url('/'), self.stop)
|
||||
response = self.wait()
|
||||
|
||||
which illustrates how AsyncTestCase can turn an asynchronous operation,
|
||||
like ``http_client.fetch()``, into a synchronous operation. If you need
|
||||
to do other asynchronous operations in tests, you'll probably need to use
|
||||
``stop()`` and ``wait()`` yourself.
|
||||
"""
|
||||
def setUp(self):
|
||||
super(AsyncHTTPTestCase, self).setUp()
|
||||
sock, port = bind_unused_port()
|
||||
self.__port = port
|
||||
|
||||
self.http_client = self.get_http_client()
|
||||
self._app = self.get_app()
|
||||
self.http_server = self.get_http_server()
|
||||
self.http_server.add_sockets([sock])
|
||||
|
||||
def get_http_client(self):
|
||||
return AsyncHTTPClient()
|
||||
|
||||
def get_http_server(self):
|
||||
return HTTPServer(self._app, **self.get_httpserver_options())
|
||||
|
||||
def get_app(self):
|
||||
"""Should be overridden by subclasses to return a
|
||||
`tornado.web.Application` or other `.HTTPServer` callback.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def fetch(self, path, raise_error=False, **kwargs):
|
||||
"""Convenience method to synchronously fetch a URL.
|
||||
|
||||
The given path will be appended to the local server's host and
|
||||
port. Any additional kwargs will be passed directly to
|
||||
`.AsyncHTTPClient.fetch` (and so could be used to pass
|
||||
``method="POST"``, ``body="..."``, etc).
|
||||
|
||||
If the path begins with http:// or https://, it will be treated as a
|
||||
full URL and will be fetched as-is.
|
||||
|
||||
If ``raise_error`` is True, a `tornado.httpclient.HTTPError` will
|
||||
be raised if the response code is not 200. This is the same behavior
|
||||
as the ``raise_error`` argument to `.AsyncHTTPClient.fetch`, but
|
||||
the default is False here (it's True in `.AsyncHTTPClient`) because
|
||||
tests often need to deal with non-200 response codes.
|
||||
|
||||
.. versionchanged:: 5.0
|
||||
Added support for absolute URLs.
|
||||
|
||||
.. versionchanged:: 5.1
|
||||
|
||||
Added the ``raise_error`` argument.
|
||||
|
||||
.. deprecated:: 5.1
|
||||
|
||||
This method currently turns any exception into an
|
||||
`.HTTPResponse` with status code 599. In Tornado 6.0,
|
||||
errors other than `tornado.httpclient.HTTPError` will be
|
||||
passed through, and ``raise_error=False`` will only
|
||||
suppress errors that would be raised due to non-200
|
||||
response codes.
|
||||
|
||||
"""
|
||||
if path.lower().startswith(('http://', 'https://')):
|
||||
url = path
|
||||
else:
|
||||
url = self.get_url(path)
|
||||
return self.io_loop.run_sync(
|
||||
lambda: self.http_client.fetch(url, raise_error=raise_error, **kwargs),
|
||||
timeout=get_async_test_timeout())
|
||||
|
||||
def get_httpserver_options(self):
|
||||
"""May be overridden by subclasses to return additional
|
||||
keyword arguments for the server.
|
||||
"""
|
||||
return {}
|
||||
|
||||
def get_http_port(self):
|
||||
"""Returns the port used by the server.
|
||||
|
||||
A new port is chosen for each test.
|
||||
"""
|
||||
return self.__port
|
||||
|
||||
def get_protocol(self):
|
||||
return 'http'
|
||||
|
||||
def get_url(self, path):
|
||||
"""Returns an absolute url for the given path on the test server."""
|
||||
return '%s://127.0.0.1:%s%s' % (self.get_protocol(),
|
||||
self.get_http_port(), path)
|
||||
|
||||
def tearDown(self):
|
||||
self.http_server.stop()
|
||||
self.io_loop.run_sync(self.http_server.close_all_connections,
|
||||
timeout=get_async_test_timeout())
|
||||
self.http_client.close()
|
||||
super(AsyncHTTPTestCase, self).tearDown()
|
||||
|
||||
|
||||
class AsyncHTTPSTestCase(AsyncHTTPTestCase):
|
||||
"""A test case that starts an HTTPS server.
|
||||
|
||||
Interface is generally the same as `AsyncHTTPTestCase`.
|
||||
"""
|
||||
def get_http_client(self):
|
||||
return AsyncHTTPClient(force_instance=True,
|
||||
defaults=dict(validate_cert=False))
|
||||
|
||||
def get_httpserver_options(self):
|
||||
return dict(ssl_options=self.get_ssl_options())
|
||||
|
||||
def get_ssl_options(self):
|
||||
"""May be overridden by subclasses to select SSL options.
|
||||
|
||||
By default includes a self-signed testing certificate.
|
||||
"""
|
||||
# Testing keys were generated with:
|
||||
# openssl req -new -keyout tornado/test/test.key \
|
||||
# -out tornado/test/test.crt -nodes -days 3650 -x509
|
||||
module_dir = os.path.dirname(__file__)
|
||||
return dict(
|
||||
certfile=os.path.join(module_dir, 'test', 'test.crt'),
|
||||
keyfile=os.path.join(module_dir, 'test', 'test.key'))
|
||||
|
||||
def get_protocol(self):
|
||||
return 'https'
|
||||
|
||||
|
||||
def gen_test(func=None, timeout=None):
|
||||
"""Testing equivalent of ``@gen.coroutine``, to be applied to test methods.
|
||||
|
||||
``@gen.coroutine`` cannot be used on tests because the `.IOLoop` is not
|
||||
already running. ``@gen_test`` should be applied to test methods
|
||||
on subclasses of `AsyncTestCase`.
|
||||
|
||||
Example::
|
||||
|
||||
class MyTest(AsyncHTTPTestCase):
|
||||
@gen_test
|
||||
def test_something(self):
|
||||
response = yield self.http_client.fetch(self.get_url('/'))
|
||||
|
||||
By default, ``@gen_test`` times out after 5 seconds. The timeout may be
|
||||
overridden globally with the ``ASYNC_TEST_TIMEOUT`` environment variable,
|
||||
or for each test with the ``timeout`` keyword argument::
|
||||
|
||||
class MyTest(AsyncHTTPTestCase):
|
||||
@gen_test(timeout=10)
|
||||
def test_something_slow(self):
|
||||
response = yield self.http_client.fetch(self.get_url('/'))
|
||||
|
||||
Note that ``@gen_test`` is incompatible with `AsyncTestCase.stop`,
|
||||
`AsyncTestCase.wait`, and `AsyncHTTPTestCase.fetch`. Use ``yield
|
||||
self.http_client.fetch(self.get_url())`` as shown above instead.
|
||||
|
||||
.. versionadded:: 3.1
|
||||
The ``timeout`` argument and ``ASYNC_TEST_TIMEOUT`` environment
|
||||
variable.
|
||||
|
||||
.. versionchanged:: 4.0
|
||||
The wrapper now passes along ``*args, **kwargs`` so it can be used
|
||||
on functions with arguments.
|
||||
|
||||
"""
|
||||
if timeout is None:
|
||||
timeout = get_async_test_timeout()
|
||||
|
||||
def wrap(f):
|
||||
# Stack up several decorators to allow us to access the generator
|
||||
# object itself. In the innermost wrapper, we capture the generator
|
||||
# and save it in an attribute of self. Next, we run the wrapped
|
||||
# function through @gen.coroutine. Finally, the coroutine is
|
||||
# wrapped again to make it synchronous with run_sync.
|
||||
#
|
||||
# This is a good case study arguing for either some sort of
|
||||
# extensibility in the gen decorators or cancellation support.
|
||||
@functools.wraps(f)
|
||||
def pre_coroutine(self, *args, **kwargs):
|
||||
result = f(self, *args, **kwargs)
|
||||
if isinstance(result, GeneratorType) or iscoroutine(result):
|
||||
self._test_generator = result
|
||||
else:
|
||||
self._test_generator = None
|
||||
return result
|
||||
|
||||
if iscoroutinefunction(f):
|
||||
coro = pre_coroutine
|
||||
else:
|
||||
coro = gen.coroutine(pre_coroutine)
|
||||
|
||||
@functools.wraps(coro)
|
||||
def post_coroutine(self, *args, **kwargs):
|
||||
try:
|
||||
return self.io_loop.run_sync(
|
||||
functools.partial(coro, self, *args, **kwargs),
|
||||
timeout=timeout)
|
||||
except TimeoutError as e:
|
||||
# run_sync raises an error with an unhelpful traceback.
|
||||
# If the underlying generator is still running, we can throw the
|
||||
# exception back into it so the stack trace is replaced by the
|
||||
# point where the test is stopped. The only reason the generator
|
||||
# would not be running would be if it were cancelled, which means
|
||||
# a native coroutine, so we can rely on the cr_running attribute.
|
||||
if getattr(self._test_generator, 'cr_running', True):
|
||||
self._test_generator.throw(e)
|
||||
# In case the test contains an overly broad except
|
||||
# clause, we may get back here.
|
||||
# Coroutine was stopped or didn't raise a useful stack trace,
|
||||
# so re-raise the original exception which is better than nothing.
|
||||
raise
|
||||
return post_coroutine
|
||||
|
||||
if func is not None:
|
||||
# Used like:
|
||||
# @gen_test
|
||||
# def f(self):
|
||||
# pass
|
||||
return wrap(func)
|
||||
else:
|
||||
# Used like @gen_test(timeout=10)
|
||||
return wrap
|
||||
|
||||
|
||||
# Without this attribute, nosetests will try to run gen_test as a test
|
||||
# anywhere it is imported.
|
||||
gen_test.__test__ = False # type: ignore
|
||||
|
||||
|
||||
class ExpectLog(logging.Filter):
|
||||
"""Context manager to capture and suppress expected log output.
|
||||
|
||||
Useful to make tests of error conditions less noisy, while still
|
||||
leaving unexpected log entries visible. *Not thread safe.*
|
||||
|
||||
The attribute ``logged_stack`` is set to true if any exception
|
||||
stack trace was logged.
|
||||
|
||||
Usage::
|
||||
|
||||
with ExpectLog('tornado.application', "Uncaught exception"):
|
||||
error_response = self.fetch("/some_page")
|
||||
|
||||
.. versionchanged:: 4.3
|
||||
Added the ``logged_stack`` attribute.
|
||||
"""
|
||||
def __init__(self, logger, regex, required=True):
|
||||
"""Constructs an ExpectLog context manager.
|
||||
|
||||
:param logger: Logger object (or name of logger) to watch. Pass
|
||||
an empty string to watch the root logger.
|
||||
:param regex: Regular expression to match. Any log entries on
|
||||
the specified logger that match this regex will be suppressed.
|
||||
:param required: If true, an exception will be raised if the end of
|
||||
the ``with`` statement is reached without matching any log entries.
|
||||
"""
|
||||
if isinstance(logger, basestring_type):
|
||||
logger = logging.getLogger(logger)
|
||||
self.logger = logger
|
||||
self.regex = re.compile(regex)
|
||||
self.required = required
|
||||
self.matched = False
|
||||
self.logged_stack = False
|
||||
|
||||
def filter(self, record):
|
||||
if record.exc_info:
|
||||
self.logged_stack = True
|
||||
message = record.getMessage()
|
||||
if self.regex.match(message):
|
||||
self.matched = True
|
||||
return False
|
||||
return True
|
||||
|
||||
def __enter__(self):
|
||||
self.logger.addFilter(self)
|
||||
return self
|
||||
|
||||
def __exit__(self, typ, value, tb):
|
||||
self.logger.removeFilter(self)
|
||||
if not typ and self.required and not self.matched:
|
||||
raise Exception("did not get expected log message")
|
||||
|
||||
|
||||
def main(**kwargs):
|
||||
"""A simple test runner.
|
||||
|
||||
This test runner is essentially equivalent to `unittest.main` from
|
||||
the standard library, but adds support for tornado-style option
|
||||
parsing and log formatting. It is *not* necessary to use this
|
||||
`main` function to run tests using `AsyncTestCase`; these tests
|
||||
are self-contained and can run with any test runner.
|
||||
|
||||
The easiest way to run a test is via the command line::
|
||||
|
||||
python -m tornado.testing tornado.test.stack_context_test
|
||||
|
||||
See the standard library unittest module for ways in which tests can
|
||||
be specified.
|
||||
|
||||
Projects with many tests may wish to define a test script like
|
||||
``tornado/test/runtests.py``. This script should define a method
|
||||
``all()`` which returns a test suite and then call
|
||||
`tornado.testing.main()`. Note that even when a test script is
|
||||
used, the ``all()`` test suite may be overridden by naming a
|
||||
single test on the command line::
|
||||
|
||||
# Runs all tests
|
||||
python -m tornado.test.runtests
|
||||
# Runs one test
|
||||
python -m tornado.test.runtests tornado.test.stack_context_test
|
||||
|
||||
Additional keyword arguments passed through to ``unittest.main()``.
|
||||
For example, use ``tornado.testing.main(verbosity=2)``
|
||||
to show many test details as they are run.
|
||||
See http://docs.python.org/library/unittest.html#unittest.main
|
||||
for full argument list.
|
||||
|
||||
.. versionchanged:: 5.0
|
||||
|
||||
This function produces no output of its own; only that produced
|
||||
by the `unittest` module (Previously it would add a PASS or FAIL
|
||||
log message).
|
||||
"""
|
||||
from tornado.options import define, options, parse_command_line
|
||||
|
||||
define('exception_on_interrupt', type=bool, default=True,
|
||||
help=("If true (default), ctrl-c raises a KeyboardInterrupt "
|
||||
"exception. This prints a stack trace but cannot interrupt "
|
||||
"certain operations. If false, the process is more reliably "
|
||||
"killed, but does not print a stack trace."))
|
||||
|
||||
# support the same options as unittest's command-line interface
|
||||
define('verbose', type=bool)
|
||||
define('quiet', type=bool)
|
||||
define('failfast', type=bool)
|
||||
define('catch', type=bool)
|
||||
define('buffer', type=bool)
|
||||
|
||||
argv = [sys.argv[0]] + parse_command_line(sys.argv)
|
||||
|
||||
if not options.exception_on_interrupt:
|
||||
signal.signal(signal.SIGINT, signal.SIG_DFL)
|
||||
|
||||
if options.verbose is not None:
|
||||
kwargs['verbosity'] = 2
|
||||
if options.quiet is not None:
|
||||
kwargs['verbosity'] = 0
|
||||
if options.failfast is not None:
|
||||
kwargs['failfast'] = True
|
||||
if options.catch is not None:
|
||||
kwargs['catchbreak'] = True
|
||||
if options.buffer is not None:
|
||||
kwargs['buffer'] = True
|
||||
|
||||
if __name__ == '__main__' and len(argv) == 1:
|
||||
print("No tests specified", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
# In order to be able to run tests by their fully-qualified name
|
||||
# on the command line without importing all tests here,
|
||||
# module must be set to None. Python 3.2's unittest.main ignores
|
||||
# defaultTest if no module is given (it tries to do its own
|
||||
# test discovery, which is incompatible with auto2to3), so don't
|
||||
# set module if we're not asking for a specific test.
|
||||
if len(argv) > 1:
|
||||
unittest.main(module=None, argv=argv, **kwargs)
|
||||
else:
|
||||
unittest.main(defaultTest="all", argv=argv, **kwargs)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
Executable
+497
@@ -0,0 +1,497 @@
|
||||
"""Miscellaneous utility functions and classes.
|
||||
|
||||
This module is used internally by Tornado. It is not necessarily expected
|
||||
that the functions and classes defined here will be useful to other
|
||||
applications, but they are documented here in case they are.
|
||||
|
||||
The one public-facing part of this module is the `Configurable` class
|
||||
and its `~Configurable.configure` method, which becomes a part of the
|
||||
interface of its subclasses, including `.AsyncHTTPClient`, `.IOLoop`,
|
||||
and `.Resolver`.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import array
|
||||
import atexit
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import zlib
|
||||
|
||||
PY3 = sys.version_info >= (3,)
|
||||
|
||||
if PY3:
|
||||
xrange = range
|
||||
|
||||
# inspect.getargspec() raises DeprecationWarnings in Python 3.5.
|
||||
# The two functions have compatible interfaces for the parts we need.
|
||||
if PY3:
|
||||
from inspect import getfullargspec as getargspec
|
||||
else:
|
||||
from inspect import getargspec
|
||||
|
||||
# Aliases for types that are spelled differently in different Python
|
||||
# versions. bytes_type is deprecated and no longer used in Tornado
|
||||
# itself but is left in case anyone outside Tornado is using it.
|
||||
bytes_type = bytes
|
||||
if PY3:
|
||||
unicode_type = str
|
||||
basestring_type = str
|
||||
else:
|
||||
# The names unicode and basestring don't exist in py3 so silence flake8.
|
||||
unicode_type = unicode # noqa
|
||||
basestring_type = basestring # noqa
|
||||
|
||||
|
||||
try:
|
||||
import typing # noqa
|
||||
from typing import cast
|
||||
|
||||
_ObjectDictBase = typing.Dict[str, typing.Any]
|
||||
except ImportError:
|
||||
_ObjectDictBase = dict
|
||||
|
||||
def cast(typ, x):
|
||||
return x
|
||||
else:
|
||||
# More imports that are only needed in type comments.
|
||||
import datetime # noqa
|
||||
import types # noqa
|
||||
from typing import Any, AnyStr, Union, Optional, Dict, Mapping # noqa
|
||||
from typing import List, Tuple, Match, Callable # noqa
|
||||
|
||||
if PY3:
|
||||
_BaseString = str
|
||||
else:
|
||||
_BaseString = Union[bytes, unicode_type]
|
||||
|
||||
|
||||
try:
|
||||
from sys import is_finalizing
|
||||
except ImportError:
|
||||
# Emulate it
|
||||
def _get_emulated_is_finalizing():
|
||||
L = []
|
||||
atexit.register(lambda: L.append(None))
|
||||
|
||||
def is_finalizing():
|
||||
# Not referencing any globals here
|
||||
return L != []
|
||||
|
||||
return is_finalizing
|
||||
|
||||
is_finalizing = _get_emulated_is_finalizing()
|
||||
|
||||
|
||||
class TimeoutError(Exception):
|
||||
"""Exception raised by `.with_timeout` and `.IOLoop.run_sync`.
|
||||
|
||||
.. versionchanged:: 5.0:
|
||||
Unified ``tornado.gen.TimeoutError`` and
|
||||
``tornado.ioloop.TimeoutError`` as ``tornado.util.TimeoutError``.
|
||||
Both former names remain as aliases.
|
||||
"""
|
||||
|
||||
|
||||
class ObjectDict(_ObjectDictBase):
|
||||
"""Makes a dictionary behave like an object, with attribute-style access.
|
||||
"""
|
||||
def __getattr__(self, name):
|
||||
# type: (str) -> Any
|
||||
try:
|
||||
return self[name]
|
||||
except KeyError:
|
||||
raise AttributeError(name)
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
# type: (str, Any) -> None
|
||||
self[name] = value
|
||||
|
||||
|
||||
class GzipDecompressor(object):
|
||||
"""Streaming gzip decompressor.
|
||||
|
||||
The interface is like that of `zlib.decompressobj` (without some of the
|
||||
optional arguments, but it understands gzip headers and checksums.
|
||||
"""
|
||||
def __init__(self):
|
||||
# Magic parameter makes zlib module understand gzip header
|
||||
# http://stackoverflow.com/questions/1838699/how-can-i-decompress-a-gzip-stream-with-zlib
|
||||
# This works on cpython and pypy, but not jython.
|
||||
self.decompressobj = zlib.decompressobj(16 + zlib.MAX_WBITS)
|
||||
|
||||
def decompress(self, value, max_length=None):
|
||||
# type: (bytes, Optional[int]) -> bytes
|
||||
"""Decompress a chunk, returning newly-available data.
|
||||
|
||||
Some data may be buffered for later processing; `flush` must
|
||||
be called when there is no more input data to ensure that
|
||||
all data was processed.
|
||||
|
||||
If ``max_length`` is given, some input data may be left over
|
||||
in ``unconsumed_tail``; you must retrieve this value and pass
|
||||
it back to a future call to `decompress` if it is not empty.
|
||||
"""
|
||||
return self.decompressobj.decompress(value, max_length)
|
||||
|
||||
@property
|
||||
def unconsumed_tail(self):
|
||||
# type: () -> bytes
|
||||
"""Returns the unconsumed portion left over
|
||||
"""
|
||||
return self.decompressobj.unconsumed_tail
|
||||
|
||||
def flush(self):
|
||||
# type: () -> bytes
|
||||
"""Return any remaining buffered data not yet returned by decompress.
|
||||
|
||||
Also checks for errors such as truncated input.
|
||||
No other methods may be called on this object after `flush`.
|
||||
"""
|
||||
return self.decompressobj.flush()
|
||||
|
||||
|
||||
def import_object(name):
|
||||
# type: (_BaseString) -> Any
|
||||
"""Imports an object by name.
|
||||
|
||||
import_object('x') is equivalent to 'import x'.
|
||||
import_object('x.y.z') is equivalent to 'from x.y import z'.
|
||||
|
||||
>>> import tornado.escape
|
||||
>>> import_object('tornado.escape') is tornado.escape
|
||||
True
|
||||
>>> import_object('tornado.escape.utf8') is tornado.escape.utf8
|
||||
True
|
||||
>>> import_object('tornado') is tornado
|
||||
True
|
||||
>>> import_object('tornado.missing_module')
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
ImportError: No module named missing_module
|
||||
"""
|
||||
if not isinstance(name, str):
|
||||
# on python 2 a byte string is required.
|
||||
name = name.encode('utf-8')
|
||||
if name.count('.') == 0:
|
||||
return __import__(name, None, None)
|
||||
|
||||
parts = name.split('.')
|
||||
obj = __import__('.'.join(parts[:-1]), None, None, [parts[-1]], 0)
|
||||
try:
|
||||
return getattr(obj, parts[-1])
|
||||
except AttributeError:
|
||||
raise ImportError("No module named %s" % parts[-1])
|
||||
|
||||
|
||||
# Stubs to make mypy happy (and later for actual type-checking).
|
||||
def raise_exc_info(exc_info):
|
||||
# type: (Tuple[type, BaseException, types.TracebackType]) -> None
|
||||
pass
|
||||
|
||||
|
||||
def exec_in(code, glob, loc=None):
|
||||
# type: (Any, Dict[str, Any], Optional[Mapping[str, Any]]) -> Any
|
||||
if isinstance(code, basestring_type):
|
||||
# exec(string) inherits the caller's future imports; compile
|
||||
# the string first to prevent that.
|
||||
code = compile(code, '<string>', 'exec', dont_inherit=True)
|
||||
exec(code, glob, loc)
|
||||
|
||||
|
||||
if PY3:
|
||||
exec("""
|
||||
def raise_exc_info(exc_info):
|
||||
try:
|
||||
raise exc_info[1].with_traceback(exc_info[2])
|
||||
finally:
|
||||
exc_info = None
|
||||
|
||||
""")
|
||||
else:
|
||||
exec("""
|
||||
def raise_exc_info(exc_info):
|
||||
raise exc_info[0], exc_info[1], exc_info[2]
|
||||
""")
|
||||
|
||||
|
||||
def errno_from_exception(e):
|
||||
# type: (BaseException) -> Optional[int]
|
||||
"""Provides the errno from an Exception object.
|
||||
|
||||
There are cases that the errno attribute was not set so we pull
|
||||
the errno out of the args but if someone instantiates an Exception
|
||||
without any args you will get a tuple error. So this function
|
||||
abstracts all that behavior to give you a safe way to get the
|
||||
errno.
|
||||
"""
|
||||
|
||||
if hasattr(e, 'errno'):
|
||||
return e.errno # type: ignore
|
||||
elif e.args:
|
||||
return e.args[0]
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
_alphanum = frozenset(
|
||||
"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789")
|
||||
|
||||
|
||||
def _re_unescape_replacement(match):
|
||||
# type: (Match[str]) -> str
|
||||
group = match.group(1)
|
||||
if group[0] in _alphanum:
|
||||
raise ValueError("cannot unescape '\\\\%s'" % group[0])
|
||||
return group
|
||||
|
||||
|
||||
_re_unescape_pattern = re.compile(r'\\(.)', re.DOTALL)
|
||||
|
||||
|
||||
def re_unescape(s):
|
||||
# type: (str) -> str
|
||||
r"""Unescape a string escaped by `re.escape`.
|
||||
|
||||
May raise ``ValueError`` for regular expressions which could not
|
||||
have been produced by `re.escape` (for example, strings containing
|
||||
``\d`` cannot be unescaped).
|
||||
|
||||
.. versionadded:: 4.4
|
||||
"""
|
||||
return _re_unescape_pattern.sub(_re_unescape_replacement, s)
|
||||
|
||||
|
||||
class Configurable(object):
|
||||
"""Base class for configurable interfaces.
|
||||
|
||||
A configurable interface is an (abstract) class whose constructor
|
||||
acts as a factory function for one of its implementation subclasses.
|
||||
The implementation subclass as well as optional keyword arguments to
|
||||
its initializer can be set globally at runtime with `configure`.
|
||||
|
||||
By using the constructor as the factory method, the interface
|
||||
looks like a normal class, `isinstance` works as usual, etc. This
|
||||
pattern is most useful when the choice of implementation is likely
|
||||
to be a global decision (e.g. when `~select.epoll` is available,
|
||||
always use it instead of `~select.select`), or when a
|
||||
previously-monolithic class has been split into specialized
|
||||
subclasses.
|
||||
|
||||
Configurable subclasses must define the class methods
|
||||
`configurable_base` and `configurable_default`, and use the instance
|
||||
method `initialize` instead of ``__init__``.
|
||||
|
||||
.. versionchanged:: 5.0
|
||||
|
||||
It is now possible for configuration to be specified at
|
||||
multiple levels of a class hierarchy.
|
||||
|
||||
"""
|
||||
__impl_class = None # type: type
|
||||
__impl_kwargs = None # type: Dict[str, Any]
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
base = cls.configurable_base()
|
||||
init_kwargs = {}
|
||||
if cls is base:
|
||||
impl = cls.configured_class()
|
||||
if base.__impl_kwargs:
|
||||
init_kwargs.update(base.__impl_kwargs)
|
||||
else:
|
||||
impl = cls
|
||||
init_kwargs.update(kwargs)
|
||||
if impl.configurable_base() is not base:
|
||||
# The impl class is itself configurable, so recurse.
|
||||
return impl(*args, **init_kwargs)
|
||||
instance = super(Configurable, cls).__new__(impl)
|
||||
# initialize vs __init__ chosen for compatibility with AsyncHTTPClient
|
||||
# singleton magic. If we get rid of that we can switch to __init__
|
||||
# here too.
|
||||
instance.initialize(*args, **init_kwargs)
|
||||
return instance
|
||||
|
||||
@classmethod
|
||||
def configurable_base(cls):
|
||||
# type: () -> Any
|
||||
# TODO: This class needs https://github.com/python/typing/issues/107
|
||||
# to be fully typeable.
|
||||
"""Returns the base class of a configurable hierarchy.
|
||||
|
||||
This will normally return the class in which it is defined.
|
||||
(which is *not* necessarily the same as the cls classmethod parameter).
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@classmethod
|
||||
def configurable_default(cls):
|
||||
# type: () -> type
|
||||
"""Returns the implementation class to be used if none is configured."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def initialize(self):
|
||||
# type: () -> None
|
||||
"""Initialize a `Configurable` subclass instance.
|
||||
|
||||
Configurable classes should use `initialize` instead of ``__init__``.
|
||||
|
||||
.. versionchanged:: 4.2
|
||||
Now accepts positional arguments in addition to keyword arguments.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def configure(cls, impl, **kwargs):
|
||||
# type: (Any, **Any) -> None
|
||||
"""Sets the class to use when the base class is instantiated.
|
||||
|
||||
Keyword arguments will be saved and added to the arguments passed
|
||||
to the constructor. This can be used to set global defaults for
|
||||
some parameters.
|
||||
"""
|
||||
base = cls.configurable_base()
|
||||
if isinstance(impl, (str, unicode_type)):
|
||||
impl = import_object(impl)
|
||||
if impl is not None and not issubclass(impl, cls):
|
||||
raise ValueError("Invalid subclass of %s" % cls)
|
||||
base.__impl_class = impl
|
||||
base.__impl_kwargs = kwargs
|
||||
|
||||
@classmethod
|
||||
def configured_class(cls):
|
||||
# type: () -> type
|
||||
"""Returns the currently configured class."""
|
||||
base = cls.configurable_base()
|
||||
# Manually mangle the private name to see whether this base
|
||||
# has been configured (and not another base higher in the
|
||||
# hierarchy).
|
||||
if base.__dict__.get('_Configurable__impl_class') is None:
|
||||
base.__impl_class = cls.configurable_default()
|
||||
return base.__impl_class
|
||||
|
||||
@classmethod
|
||||
def _save_configuration(cls):
|
||||
# type: () -> Tuple[type, Dict[str, Any]]
|
||||
base = cls.configurable_base()
|
||||
return (base.__impl_class, base.__impl_kwargs)
|
||||
|
||||
@classmethod
|
||||
def _restore_configuration(cls, saved):
|
||||
# type: (Tuple[type, Dict[str, Any]]) -> None
|
||||
base = cls.configurable_base()
|
||||
base.__impl_class = saved[0]
|
||||
base.__impl_kwargs = saved[1]
|
||||
|
||||
|
||||
class ArgReplacer(object):
|
||||
"""Replaces one value in an ``args, kwargs`` pair.
|
||||
|
||||
Inspects the function signature to find an argument by name
|
||||
whether it is passed by position or keyword. For use in decorators
|
||||
and similar wrappers.
|
||||
"""
|
||||
def __init__(self, func, name):
|
||||
# type: (Callable, str) -> None
|
||||
self.name = name
|
||||
try:
|
||||
self.arg_pos = self._getargnames(func).index(name)
|
||||
except ValueError:
|
||||
# Not a positional parameter
|
||||
self.arg_pos = None
|
||||
|
||||
def _getargnames(self, func):
|
||||
# type: (Callable) -> List[str]
|
||||
try:
|
||||
return getargspec(func).args
|
||||
except TypeError:
|
||||
if hasattr(func, 'func_code'):
|
||||
# Cython-generated code has all the attributes needed
|
||||
# by inspect.getargspec, but the inspect module only
|
||||
# works with ordinary functions. Inline the portion of
|
||||
# getargspec that we need here. Note that for static
|
||||
# functions the @cython.binding(True) decorator must
|
||||
# be used (for methods it works out of the box).
|
||||
code = func.func_code # type: ignore
|
||||
return code.co_varnames[:code.co_argcount]
|
||||
raise
|
||||
|
||||
def get_old_value(self, args, kwargs, default=None):
|
||||
# type: (List[Any], Dict[str, Any], Any) -> Any
|
||||
"""Returns the old value of the named argument without replacing it.
|
||||
|
||||
Returns ``default`` if the argument is not present.
|
||||
"""
|
||||
if self.arg_pos is not None and len(args) > self.arg_pos:
|
||||
return args[self.arg_pos]
|
||||
else:
|
||||
return kwargs.get(self.name, default)
|
||||
|
||||
def replace(self, new_value, args, kwargs):
|
||||
# type: (Any, List[Any], Dict[str, Any]) -> Tuple[Any, List[Any], Dict[str, Any]]
|
||||
"""Replace the named argument in ``args, kwargs`` with ``new_value``.
|
||||
|
||||
Returns ``(old_value, args, kwargs)``. The returned ``args`` and
|
||||
``kwargs`` objects may not be the same as the input objects, or
|
||||
the input objects may be mutated.
|
||||
|
||||
If the named argument was not found, ``new_value`` will be added
|
||||
to ``kwargs`` and None will be returned as ``old_value``.
|
||||
"""
|
||||
if self.arg_pos is not None and len(args) > self.arg_pos:
|
||||
# The arg to replace is passed positionally
|
||||
old_value = args[self.arg_pos]
|
||||
args = list(args) # *args is normally a tuple
|
||||
args[self.arg_pos] = new_value
|
||||
else:
|
||||
# The arg to replace is either omitted or passed by keyword.
|
||||
old_value = kwargs.get(self.name)
|
||||
kwargs[self.name] = new_value
|
||||
return old_value, args, kwargs
|
||||
|
||||
|
||||
def timedelta_to_seconds(td):
|
||||
# type: (datetime.timedelta) -> float
|
||||
"""Equivalent to td.total_seconds() (introduced in python 2.7)."""
|
||||
return (td.microseconds + (td.seconds + td.days * 24 * 3600) * 10 ** 6) / float(10 ** 6)
|
||||
|
||||
|
||||
def _websocket_mask_python(mask, data):
|
||||
# type: (bytes, bytes) -> bytes
|
||||
"""Websocket masking function.
|
||||
|
||||
`mask` is a `bytes` object of length 4; `data` is a `bytes` object of any length.
|
||||
Returns a `bytes` object of the same length as `data` with the mask applied
|
||||
as specified in section 5.3 of RFC 6455.
|
||||
|
||||
This pure-python implementation may be replaced by an optimized version when available.
|
||||
"""
|
||||
mask_arr = array.array("B", mask)
|
||||
unmasked_arr = array.array("B", data)
|
||||
for i in xrange(len(data)):
|
||||
unmasked_arr[i] = unmasked_arr[i] ^ mask_arr[i % 4]
|
||||
if PY3:
|
||||
# tostring was deprecated in py32. It hasn't been removed,
|
||||
# but since we turn on deprecation warnings in our tests
|
||||
# we need to use the right one.
|
||||
return unmasked_arr.tobytes()
|
||||
else:
|
||||
return unmasked_arr.tostring()
|
||||
|
||||
|
||||
if (os.environ.get('TORNADO_NO_EXTENSION') or
|
||||
os.environ.get('TORNADO_EXTENSION') == '0'):
|
||||
# These environment variables exist to make it easier to do performance
|
||||
# comparisons; they are not guaranteed to remain supported in the future.
|
||||
_websocket_mask = _websocket_mask_python
|
||||
else:
|
||||
try:
|
||||
from tornado.speedups import websocket_mask as _websocket_mask
|
||||
except ImportError:
|
||||
if os.environ.get('TORNADO_EXTENSION') == '1':
|
||||
raise
|
||||
_websocket_mask = _websocket_mask_python
|
||||
|
||||
|
||||
def doctests():
|
||||
import doctest
|
||||
return doctest.DocTestSuite()
|
||||
Executable
+3394
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user