From d4c10fae59a1c59f5f64a1d3fdb63d1e6aaea5da Mon Sep 17 00:00:00 2001 From: Ashwin Naren Date: Wed, 15 Jan 2025 11:59:49 -0800 Subject: [PATCH 1/8] updated logging to 3.12, added logging tests, and added smtplib and tests --- Lib/logging/__init__.py | 254 +- Lib/logging/config.py | 222 +- Lib/logging/handlers.py | 261 +- Lib/smtplib.py | 1109 ++++++ Lib/test/support/__init__.py | 35 +- Lib/test/support/smtpd.py | 873 +++++ Lib/test/test_logging.py | 7023 ++++++++++++++++++++++++++++++++++ Lib/test/test_smtplib.py | 1560 ++++++++ Lib/test/test_smtpnet.py | 90 + 9 files changed, 11169 insertions(+), 258 deletions(-) create mode 100644 Lib/smtplib.py create mode 100644 Lib/test/support/smtpd.py create mode 100644 Lib/test/test_logging.py create mode 100644 Lib/test/test_smtplib.py create mode 100644 Lib/test/test_smtpnet.py diff --git a/Lib/logging/__init__.py b/Lib/logging/__init__.py index 19bd2bc20b..28df075dcd 100644 --- a/Lib/logging/__init__.py +++ b/Lib/logging/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2001-2019 by Vinay Sajip. All Rights Reserved. +# Copyright 2001-2022 by Vinay Sajip. All Rights Reserved. # # Permission to use, copy, modify, and distribute this software and its # documentation for any purpose and without fee is hereby granted, @@ -18,13 +18,14 @@ Logging package for Python. Based on PEP 282 and comments thereto in comp.lang.python. -Copyright (C) 2001-2019 Vinay Sajip. All Rights Reserved. +Copyright (C) 2001-2022 Vinay Sajip. All Rights Reserved. To use, simply 'import logging' and log away! """ import sys, os, time, io, re, traceback, warnings, weakref, collections.abc +from types import GenericAlias from string import Template from string import Formatter as StrFormatter @@ -37,7 +38,8 @@ 'exception', 'fatal', 'getLevelName', 'getLogger', 'getLoggerClass', 'info', 'log', 'makeLogRecord', 'setLoggerClass', 'shutdown', 'warn', 'warning', 'getLogRecordFactory', 'setLogRecordFactory', - 'lastResort', 'raiseExceptions'] + 'lastResort', 'raiseExceptions', 'getLevelNamesMapping', + 'getHandlerByName', 'getHandlerNames'] import threading @@ -63,20 +65,25 @@ raiseExceptions = True # -# If you don't want threading information in the log, set this to zero +# If you don't want threading information in the log, set this to False # logThreads = True # -# If you don't want multiprocessing information in the log, set this to zero +# If you don't want multiprocessing information in the log, set this to False # logMultiprocessing = True # -# If you don't want process information in the log, set this to zero +# If you don't want process information in the log, set this to False # logProcesses = True +# +# If you don't want asyncio task information in the log, set this to False +# +logAsyncioTasks = True + #--------------------------------------------------------------------------- # Level related stuff #--------------------------------------------------------------------------- @@ -116,6 +123,9 @@ 'NOTSET': NOTSET, } +def getLevelNamesMapping(): + return _nameToLevel.copy() + def getLevelName(level): """ Return the textual or numeric representation of logging level 'level'. @@ -156,15 +166,15 @@ def addLevelName(level, levelName): finally: _releaseLock() -if hasattr(sys, '_getframe'): - currentframe = lambda: sys._getframe(3) +if hasattr(sys, "_getframe"): + currentframe = lambda: sys._getframe(1) else: #pragma: no cover def currentframe(): """Return the frame object for the caller's stack frame.""" try: raise Exception - except Exception: - return sys.exc_info()[2].tb_frame.f_back + except Exception as exc: + return exc.__traceback__.tb_frame.f_back # # _srcfile is used when walking the stack to check when we've got the first @@ -181,13 +191,18 @@ def currentframe(): _srcfile = os.path.normcase(addLevelName.__code__.co_filename) # _srcfile is only used in conjunction with sys._getframe(). -# To provide compatibility with older versions of Python, set _srcfile -# to None if _getframe() is not available; this value will prevent -# findCaller() from being called. You can also do this if you want to avoid -# the overhead of fetching caller information, even when _getframe() is -# available. -#if not hasattr(sys, '_getframe'): -# _srcfile = None +# Setting _srcfile to None will prevent findCaller() from being called. This +# way, you can avoid the overhead of fetching caller information. + +# The following is based on warnings._is_internal_frame. It makes sure that +# frames of the import mechanism are skipped when logging at module level and +# using a stacklevel value greater than one. +def _is_internal_frame(frame): + """Signal whether the frame is a CPython or logging module internal.""" + filename = os.path.normcase(frame.f_code.co_filename) + return filename == _srcfile or ( + "importlib" in filename and "_bootstrap" in filename + ) def _checkLevel(level): @@ -307,7 +322,7 @@ def __init__(self, name, level, pathname, lineno, # Thus, while not removing the isinstance check, it does now look # for collections.abc.Mapping rather than, as before, dict. if (args and len(args) == 1 and isinstance(args[0], collections.abc.Mapping) - and args[0]): + and args[0]): args = args[0] self.args = args self.levelname = getLevelName(level) @@ -325,7 +340,7 @@ def __init__(self, name, level, pathname, lineno, self.lineno = lineno self.funcName = func self.created = ct - self.msecs = (ct - int(ct)) * 1000 + self.msecs = int((ct - int(ct)) * 1000) + 0.0 # see gh-89047 self.relativeCreated = (self.created - _startTime) * 1000 if logThreads: self.thread = threading.get_ident() @@ -352,9 +367,18 @@ def __init__(self, name, level, pathname, lineno, else: self.process = None + self.taskName = None + if logAsyncioTasks: + asyncio = sys.modules.get('asyncio') + if asyncio: + try: + self.taskName = asyncio.current_task().get_name() + except Exception: + pass + def __repr__(self): return ''%(self.name, self.levelno, - self.pathname, self.lineno, self.msg) + self.pathname, self.lineno, self.msg) def getMessage(self): """ @@ -487,7 +511,7 @@ def __init__(self, *args, **kwargs): def usesTime(self): fmt = self._fmt - return fmt.find('$asctime') >= 0 or fmt.find(self.asctime_format) >= 0 + return fmt.find('$asctime') >= 0 or fmt.find(self.asctime_search) >= 0 def validate(self): pattern = Template.pattern @@ -557,6 +581,7 @@ class Formatter(object): (typically at application startup time) %(thread)d Thread ID (if available) %(threadName)s Thread name (if available) + %(taskName)s Task name (if available) %(process)d Process ID (if available) %(message)s The result of record.getMessage(), computed just as the record is emitted @@ -583,7 +608,7 @@ def __init__(self, fmt=None, datefmt=None, style='%', validate=True, *, """ if style not in _STYLES: raise ValueError('Style must be one of: %s' % ','.join( - _STYLES.keys())) + _STYLES.keys())) self._style = _STYLES[style][0](fmt, defaults=defaults) if validate: self._style.validate() @@ -808,23 +833,36 @@ def filter(self, record): Determine if a record is loggable by consulting all the filters. The default is to allow the record to be logged; any filter can veto - this and the record is then dropped. Returns a zero value if a record - is to be dropped, else non-zero. + this by returning a false value. + If a filter attached to a handler returns a log record instance, + then that instance is used in place of the original log record in + any further processing of the event by that handler. + If a filter returns any other true value, the original log record + is used in any further processing of the event by that handler. + + If none of the filters return false values, this method returns + a log record. + If any of the filters return a false value, this method returns + a false value. .. versionchanged:: 3.2 Allow filters to be just callables. + + .. versionchanged:: 3.12 + Allow filters to return a LogRecord instead of + modifying it in place. """ - rv = True for f in self.filters: if hasattr(f, 'filter'): result = f.filter(record) else: result = f(record) # assume callable - will raise if not if not result: - rv = False - break - return rv + return False + if isinstance(result, LogRecord): + record = result + return record #--------------------------------------------------------------------------- # Handler classes and functions @@ -845,8 +883,9 @@ def _removeHandlerRef(wr): if acquire and release and handlers: acquire() try: - if wr in handlers: - handlers.remove(wr) + handlers.remove(wr) + except ValueError: + pass finally: release() @@ -860,6 +899,23 @@ def _addHandlerRef(handler): finally: _releaseLock() + +def getHandlerByName(name): + """ + Get a handler with the specified *name*, or None if there isn't one with + that name. + """ + return _handlers.get(name) + + +def getHandlerNames(): + """ + Return all known handler names as an immutable set. + """ + result = set(_handlers.keys()) + return frozenset(result) + + class Handler(Filterer): """ Handler instances dispatch logging events to specific destinations. @@ -958,10 +1014,14 @@ def handle(self, record): Emission depends on filters which may have been added to the handler. Wrap the actual emission of the record with acquisition/release of - the I/O thread lock. Returns whether the filter passed the record for - emission. + the I/O thread lock. + + Returns an instance of the log record that was emitted + if it passed all filters, otherwise a false value is returned. """ rv = self.filter(record) + if isinstance(rv, LogRecord): + record = rv if rv: self.acquire() try: @@ -1032,7 +1092,7 @@ def handleError(self, record): else: # couldn't find the right stack frame, for some reason sys.stderr.write('Logged from file %s, line %s\n' % ( - record.filename, record.lineno)) + record.filename, record.lineno)) # Issue 18671: output logging message and arguments try: sys.stderr.write('Message: %r\n' @@ -1044,7 +1104,7 @@ def handleError(self, record): sys.stderr.write('Unable to print the message and arguments' ' - possible formatting error.\nUse the' ' traceback above to help find the error.\n' - ) + ) except OSError: #pragma: no cover pass # see issue 5971 finally: @@ -1136,6 +1196,8 @@ def __repr__(self): name += ' ' return '<%s %s(%s)>' % (self.__class__.__name__, name, level) + __class_getitem__ = classmethod(GenericAlias) + class FileHandler(StreamHandler): """ @@ -1459,7 +1521,7 @@ def debug(self, msg, *args, **kwargs): To pass exception information, use the keyword argument exc_info with a true value, e.g. - logger.debug("Houston, we have a %s", "thorny problem", exc_info=1) + logger.debug("Houston, we have a %s", "thorny problem", exc_info=True) """ if self.isEnabledFor(DEBUG): self._log(DEBUG, msg, args, **kwargs) @@ -1471,7 +1533,7 @@ def info(self, msg, *args, **kwargs): To pass exception information, use the keyword argument exc_info with a true value, e.g. - logger.info("Houston, we have a %s", "interesting problem", exc_info=1) + logger.info("Houston, we have a %s", "notable problem", exc_info=True) """ if self.isEnabledFor(INFO): self._log(INFO, msg, args, **kwargs) @@ -1483,14 +1545,14 @@ def warning(self, msg, *args, **kwargs): To pass exception information, use the keyword argument exc_info with a true value, e.g. - logger.warning("Houston, we have a %s", "bit of a problem", exc_info=1) + logger.warning("Houston, we have a %s", "bit of a problem", exc_info=True) """ if self.isEnabledFor(WARNING): self._log(WARNING, msg, args, **kwargs) def warn(self, msg, *args, **kwargs): warnings.warn("The 'warn' method is deprecated, " - "use 'warning' instead", DeprecationWarning, 2) + "use 'warning' instead", DeprecationWarning, 2) self.warning(msg, *args, **kwargs) def error(self, msg, *args, **kwargs): @@ -1500,7 +1562,7 @@ def error(self, msg, *args, **kwargs): To pass exception information, use the keyword argument exc_info with a true value, e.g. - logger.error("Houston, we have a %s", "major problem", exc_info=1) + logger.error("Houston, we have a %s", "major problem", exc_info=True) """ if self.isEnabledFor(ERROR): self._log(ERROR, msg, args, **kwargs) @@ -1518,7 +1580,7 @@ def critical(self, msg, *args, **kwargs): To pass exception information, use the keyword argument exc_info with a true value, e.g. - logger.critical("Houston, we have a %s", "major disaster", exc_info=1) + logger.critical("Houston, we have a %s", "major disaster", exc_info=True) """ if self.isEnabledFor(CRITICAL): self._log(CRITICAL, msg, args, **kwargs) @@ -1536,7 +1598,7 @@ def log(self, level, msg, *args, **kwargs): To pass exception information, use the keyword argument exc_info with a true value, e.g. - logger.log(level, "We have a %s", "mysterious problem", exc_info=1) + logger.log(level, "We have a %s", "mysterious problem", exc_info=True) """ if not isinstance(level, int): if raiseExceptions: @@ -1554,33 +1616,31 @@ def findCaller(self, stack_info=False, stacklevel=1): f = currentframe() #On some versions of IronPython, currentframe() returns None if #IronPython isn't run with -X:Frames. - if f is not None: - f = f.f_back - orig_f = f - while f and stacklevel > 1: - f = f.f_back - stacklevel -= 1 - if not f: - f = orig_f - rv = "(unknown file)", 0, "(unknown function)", None - while hasattr(f, "f_code"): - co = f.f_code - filename = os.path.normcase(co.co_filename) - if filename == _srcfile: - f = f.f_back - continue - sinfo = None - if stack_info: - sio = io.StringIO() - sio.write('Stack (most recent call last):\n') + if f is None: + return "(unknown file)", 0, "(unknown function)", None + while stacklevel > 0: + next_f = f.f_back + if next_f is None: + ## We've got options here. + ## If we want to use the last (deepest) frame: + break + ## If we want to mimic the warnings module: + #return ("sys", 1, "(unknown function)", None) + ## If we want to be pedantic: + #raise ValueError("call stack is not deep enough") + f = next_f + if not _is_internal_frame(f): + stacklevel -= 1 + co = f.f_code + sinfo = None + if stack_info: + with io.StringIO() as sio: + sio.write("Stack (most recent call last):\n") traceback.print_stack(f, file=sio) sinfo = sio.getvalue() if sinfo[-1] == '\n': sinfo = sinfo[:-1] - sio.close() - rv = (co.co_filename, f.f_lineno, co.co_name, sinfo) - break - return rv + return co.co_filename, f.f_lineno, co.co_name, sinfo def makeRecord(self, name, level, fn, lno, msg, args, exc_info, func=None, extra=None, sinfo=None): @@ -1589,7 +1649,7 @@ def makeRecord(self, name, level, fn, lno, msg, args, exc_info, specialized LogRecords. """ rv = _logRecordFactory(name, level, fn, lno, msg, args, exc_info, func, - sinfo) + sinfo) if extra is not None: for key in extra: if (key in ["message", "asctime"]) or (key in rv.__dict__): @@ -1630,8 +1690,14 @@ def handle(self, record): This method is used for unpickled records received from a socket, as well as those created locally. Logger-level filtering is applied. """ - if (not self.disabled) and self.filter(record): - self.callHandlers(record) + if self.disabled: + return + maybe_record = self.filter(record) + if not maybe_record: + return + if isinstance(maybe_record, LogRecord): + record = maybe_record + self.callHandlers(record) def addHandler(self, hdlr): """ @@ -1737,7 +1803,7 @@ def isEnabledFor(self, level): is_enabled = self._cache[level] = False else: is_enabled = self._cache[level] = ( - level >= self.getEffectiveLevel() + level >= self.getEffectiveLevel() ) finally: _releaseLock() @@ -1762,13 +1828,30 @@ def getChild(self, suffix): suffix = '.'.join((self.name, suffix)) return self.manager.getLogger(suffix) + def getChildren(self): + + def _hierlevel(logger): + if logger is logger.manager.root: + return 0 + return 1 + logger.name.count('.') + + d = self.manager.loggerDict + _acquireLock() + try: + # exclude PlaceHolders - the last check is to ensure that lower-level + # descendants aren't returned - if there are placeholders, a logger's + # parent field might point to a grandparent or ancestor thereof. + return set(item for item in d.values() + if isinstance(item, Logger) and item.parent is self and + _hierlevel(item) == 1 + _hierlevel(item.parent)) + finally: + _releaseLock() + def __repr__(self): level = getLevelName(self.getEffectiveLevel()) return '<%s %s (%s)>' % (self.__class__.__name__, self.name, level) def __reduce__(self): - # In general, only the root logger will not be accessible via its name. - # However, the root logger's class has its own __reduce__ method. if getLogger(self.name) is not self: import pickle raise pickle.PicklingError('logger cannot be pickled') @@ -1848,7 +1931,7 @@ def warning(self, msg, *args, **kwargs): def warn(self, msg, *args, **kwargs): warnings.warn("The 'warn' method is deprecated, " - "use 'warning' instead", DeprecationWarning, 2) + "use 'warning' instead", DeprecationWarning, 2) self.warning(msg, *args, **kwargs) def error(self, msg, *args, **kwargs): @@ -1902,18 +1985,11 @@ def hasHandlers(self): """ return self.logger.hasHandlers() - def _log(self, level, msg, args, exc_info=None, extra=None, stack_info=False): + def _log(self, level, msg, args, **kwargs): """ Low-level log implementation, proxied to allow nested logger adapters. """ - return self.logger._log( - level, - msg, - args, - exc_info=exc_info, - extra=extra, - stack_info=stack_info, - ) + return self.logger._log(level, msg, args, **kwargs) @property def manager(self): @@ -1932,6 +2008,8 @@ def __repr__(self): level = getLevelName(logger.getEffectiveLevel()) return '<%s %s (%s)>' % (self.__class__.__name__, logger.name, level) + __class_getitem__ = classmethod(GenericAlias) + root = RootLogger(WARNING) Logger.root = root Logger.manager = Manager(Logger.root) @@ -1971,7 +2049,7 @@ def basicConfig(**kwargs): that this argument is incompatible with 'filename' - if both are present, 'stream' is ignored. handlers If specified, this should be an iterable of already created - handlers, which will be added to the root handler. Any handler + handlers, which will be added to the root logger. Any handler in the list which does not have a formatter assigned will be assigned the formatter created in this function. force If this keyword is specified as true, any existing handlers @@ -2047,7 +2125,7 @@ def basicConfig(**kwargs): style = kwargs.pop("style", '%') if style not in _STYLES: raise ValueError('Style must be one of: %s' % ','.join( - _STYLES.keys())) + _STYLES.keys())) fs = kwargs.pop("format", _STYLES[style][1]) fmt = Formatter(fs, dfs, style) for h in handlers: @@ -2124,7 +2202,7 @@ def warning(msg, *args, **kwargs): def warn(msg, *args, **kwargs): warnings.warn("The 'warn' function is deprecated, " - "use 'warning' instead", DeprecationWarning, 2) + "use 'warning' instead", DeprecationWarning, 2) warning(msg, *args, **kwargs) def info(msg, *args, **kwargs): @@ -2179,7 +2257,11 @@ def shutdown(handlerList=_handlerList): if h: try: h.acquire() - h.flush() + # MemoryHandlers might not want to be flushed on close, + # but circular imports prevent us scoping this to just + # those handlers. hence the default to True. + if getattr(h, 'flushOnClose', True): + h.flush() h.close() except (OSError, ValueError): # Ignore errors which might be caused @@ -2242,7 +2324,9 @@ def _showwarning(message, category, filename, lineno, file=None, line=None): logger = getLogger("py.warnings") if not logger.handlers: logger.addHandler(NullHandler()) - logger.warning("%s", s) + # bpo-46557: Log str(s) as msg instead of logger.warning("%s", s) + # since some log aggregation tools group logs by the msg arg + logger.warning(str(s)) def captureWarnings(capture): """ diff --git a/Lib/logging/config.py b/Lib/logging/config.py index 3bc63b7862..ef04a35168 100644 --- a/Lib/logging/config.py +++ b/Lib/logging/config.py @@ -1,4 +1,4 @@ -# Copyright 2001-2019 by Vinay Sajip. All Rights Reserved. +# Copyright 2001-2023 by Vinay Sajip. All Rights Reserved. # # Permission to use, copy, modify, and distribute this software and its # documentation for any purpose and without fee is hereby granted, @@ -19,18 +19,20 @@ is based on PEP 282 and comments thereto in comp.lang.python, and influenced by Apache's log4j system. -Copyright (C) 2001-2019 Vinay Sajip. All Rights Reserved. +Copyright (C) 2001-2022 Vinay Sajip. All Rights Reserved. To use, simply 'import logging' and log away! """ import errno +import functools import io import logging import logging.handlers +import os +import queue import re import struct -import sys import threading import traceback @@ -59,15 +61,24 @@ def fileConfig(fname, defaults=None, disable_existing_loggers=True, encoding=Non """ import configparser + if isinstance(fname, str): + if not os.path.exists(fname): + raise FileNotFoundError(f"{fname} doesn't exist") + elif not os.path.getsize(fname): + raise RuntimeError(f'{fname} is an empty file') + if isinstance(fname, configparser.RawConfigParser): cp = fname else: - cp = configparser.ConfigParser(defaults) - if hasattr(fname, 'readline'): - cp.read_file(fname) - else: - encoding = io.text_encoding(encoding) - cp.read(fname, encoding=encoding) + try: + cp = configparser.ConfigParser(defaults) + if hasattr(fname, 'readline'): + cp.read_file(fname) + else: + encoding = io.text_encoding(encoding) + cp.read(fname, encoding=encoding) + except configparser.ParsingError as e: + raise RuntimeError(f'{fname} is invalid: {e}') formatters = _create_formatters(cp) @@ -113,11 +124,18 @@ def _create_formatters(cp): fs = cp.get(sectname, "format", raw=True, fallback=None) dfs = cp.get(sectname, "datefmt", raw=True, fallback=None) stl = cp.get(sectname, "style", raw=True, fallback='%') + defaults = cp.get(sectname, "defaults", raw=True, fallback=None) + c = logging.Formatter class_name = cp[sectname].get("class") if class_name: c = _resolve(class_name) - f = c(fs, dfs, stl) + + if defaults is not None: + defaults = eval(defaults, vars(logging)) + f = c(fs, dfs, stl, defaults=defaults) + else: + f = c(fs, dfs, stl) formatters[form] = f return formatters @@ -296,7 +314,7 @@ def convert_with_key(self, key, value, replace=True): if replace: self[key] = result if type(result) in (ConvertingDict, ConvertingList, - ConvertingTuple): + ConvertingTuple): result.parent = self result.key = key return result @@ -305,7 +323,7 @@ def convert(self, value): result = self.configurator.convert(value) if value is not result: if type(result) in (ConvertingDict, ConvertingList, - ConvertingTuple): + ConvertingTuple): result.parent = self return result @@ -392,11 +410,9 @@ def resolve(self, s): self.importer(used) found = getattr(found, frag) return found - except ImportError: - e, tb = sys.exc_info()[1:] + except ImportError as e: v = ValueError('Cannot resolve %r: %s' % (s, e)) - v.__cause__, v.__traceback__ = e, tb - raise v + raise v from e def ext_convert(self, value): """Default converter for the ext:// protocol.""" @@ -448,8 +464,8 @@ def convert(self, value): elif not isinstance(value, ConvertingList) and isinstance(value, list): value = ConvertingList(value) value.configurator = self - elif not isinstance(value, ConvertingTuple) and\ - isinstance(value, tuple) and not hasattr(value, '_fields'): + elif not isinstance(value, ConvertingTuple) and \ + isinstance(value, tuple) and not hasattr(value, '_fields'): value = ConvertingTuple(value) value.configurator = self elif isinstance(value, str): # str for py3k @@ -469,10 +485,10 @@ def configure_custom(self, config): c = config.pop('()') if not callable(c): c = self.resolve(c) - props = config.pop('.', None) # Check for valid identifiers - kwargs = {k: config[k] for k in config if valid_ident(k)} + kwargs = {k: config[k] for k in config if (k != '.' and valid_ident(k))} result = c(**kwargs) + props = config.pop('.', None) if props: for name, value in props.items(): setattr(result, name, value) @@ -484,6 +500,33 @@ def as_tuple(self, value): value = tuple(value) return value +def _is_queue_like_object(obj): + """Check that *obj* implements the Queue API.""" + if isinstance(obj, (queue.Queue, queue.SimpleQueue)): + return True + # defer importing multiprocessing as much as possible + from multiprocessing.queues import Queue as MPQueue + if isinstance(obj, MPQueue): + return True + # Depending on the multiprocessing start context, we cannot create + # a multiprocessing.managers.BaseManager instance 'mm' to get the + # runtime type of mm.Queue() or mm.JoinableQueue() (see gh-119819). + # + # Since we only need an object implementing the Queue API, we only + # do a protocol check, but we do not use typing.runtime_checkable() + # and typing.Protocol to reduce import time (see gh-121723). + # + # Ideally, we would have wanted to simply use strict type checking + # instead of a protocol-based type checking since the latter does + # not check the method signatures. + # + # Note that only 'put_nowait' and 'get' are required by the logging + # queue handler and queue listener (see gh-124653) and that other + # methods are either optional or unused. + minimal_queue_interface = ['put_nowait', 'get'] + return all(callable(getattr(obj, method, None)) + for method in minimal_queue_interface) + class DictConfigurator(BaseConfigurator): """ Configure logging using a dictionary-like object to describe the @@ -542,7 +585,7 @@ def configure(self): for name in formatters: try: formatters[name] = self.configure_formatter( - formatters[name]) + formatters[name]) except Exception as e: raise ValueError('Unable to configure ' 'formatter %r' % name) from e @@ -566,7 +609,7 @@ def configure(self): handler.name = name handlers[name] = handler except Exception as e: - if 'target not configured yet' in str(e.__cause__): + if ' not configured yet' in str(e.__cause__): deferred.append(name) else: raise ValueError('Unable to configure handler ' @@ -669,18 +712,27 @@ def configure_formatter(self, config): dfmt = config.get('datefmt', None) style = config.get('style', '%') cname = config.get('class', None) + defaults = config.get('defaults', None) if not cname: c = logging.Formatter else: c = _resolve(cname) + kwargs = {} + + # Add defaults only if it exists. + # Prevents TypeError in custom formatter callables that do not + # accept it. + if defaults is not None: + kwargs['defaults'] = defaults + # A TypeError would be raised if "validate" key is passed in with a formatter callable # that does not accept "validate" as a parameter if 'validate' in config: # if user hasn't mentioned it, the default will be fine - result = c(fmt, dfmt, style, config['validate']) + result = c(fmt, dfmt, style, config['validate'], **kwargs) else: - result = c(fmt, dfmt, style) + result = c(fmt, dfmt, style, **kwargs) return result @@ -697,10 +749,29 @@ def add_filters(self, filterer, filters): """Add filters to a filterer from a list of names.""" for f in filters: try: - filterer.addFilter(self.config['filters'][f]) + if callable(f) or callable(getattr(f, 'filter', None)): + filter_ = f + else: + filter_ = self.config['filters'][f] + filterer.addFilter(filter_) except Exception as e: raise ValueError('Unable to add filter %r' % f) from e + def _configure_queue_handler(self, klass, **kwargs): + if 'queue' in kwargs: + q = kwargs.pop('queue') + else: + q = queue.Queue() # unbounded + + rhl = kwargs.pop('respect_handler_level', False) + lklass = kwargs.pop('listener', logging.handlers.QueueListener) + handlers = kwargs.pop('handlers', []) + + listener = lklass(q, *handlers, respect_handler_level=rhl) + handler = klass(q, **kwargs) + handler.listener = listener + return handler + def configure_handler(self, config): """Configure a handler from a dictionary.""" config_copy = dict(config) # for restoring in case of error @@ -720,28 +791,87 @@ def configure_handler(self, config): factory = c else: cname = config.pop('class') - klass = self.resolve(cname) - #Special case for handler which refers to another handler - if issubclass(klass, logging.handlers.MemoryHandler) and\ - 'target' in config: - try: - th = self.config['handlers'][config['target']] - if not isinstance(th, logging.Handler): - config.update(config_copy) # restore for deferred cfg - raise TypeError('target not configured yet') - config['target'] = th - except Exception as e: - raise ValueError('Unable to set target handler ' - '%r' % config['target']) from e - elif issubclass(klass, logging.handlers.SMTPHandler) and\ - 'mailhost' in config: + if callable(cname): + klass = cname + else: + klass = self.resolve(cname) + if issubclass(klass, logging.handlers.MemoryHandler): + if 'flushLevel' in config: + config['flushLevel'] = logging._checkLevel(config['flushLevel']) + if 'target' in config: + # Special case for handler which refers to another handler + try: + tn = config['target'] + th = self.config['handlers'][tn] + if not isinstance(th, logging.Handler): + config.update(config_copy) # restore for deferred cfg + raise TypeError('target not configured yet') + config['target'] = th + except Exception as e: + raise ValueError('Unable to set target handler %r' % tn) from e + elif issubclass(klass, logging.handlers.QueueHandler): + # Another special case for handler which refers to other handlers + # if 'handlers' not in config: + # raise ValueError('No handlers specified for a QueueHandler') + if 'queue' in config: + qspec = config['queue'] + + if isinstance(qspec, str): + q = self.resolve(qspec) + if not callable(q): + raise TypeError('Invalid queue specifier %r' % qspec) + config['queue'] = q() + elif isinstance(qspec, dict): + if '()' not in qspec: + raise TypeError('Invalid queue specifier %r' % qspec) + config['queue'] = self.configure_custom(dict(qspec)) + elif not _is_queue_like_object(qspec): + raise TypeError('Invalid queue specifier %r' % qspec) + + if 'listener' in config: + lspec = config['listener'] + if isinstance(lspec, type): + if not issubclass(lspec, logging.handlers.QueueListener): + raise TypeError('Invalid listener specifier %r' % lspec) + else: + if isinstance(lspec, str): + listener = self.resolve(lspec) + if isinstance(listener, type) and \ + not issubclass(listener, logging.handlers.QueueListener): + raise TypeError('Invalid listener specifier %r' % lspec) + elif isinstance(lspec, dict): + if '()' not in lspec: + raise TypeError('Invalid listener specifier %r' % lspec) + listener = self.configure_custom(dict(lspec)) + else: + raise TypeError('Invalid listener specifier %r' % lspec) + if not callable(listener): + raise TypeError('Invalid listener specifier %r' % lspec) + config['listener'] = listener + if 'handlers' in config: + hlist = [] + try: + for hn in config['handlers']: + h = self.config['handlers'][hn] + if not isinstance(h, logging.Handler): + config.update(config_copy) # restore for deferred cfg + raise TypeError('Required handler %r ' + 'is not configured yet' % hn) + hlist.append(h) + except Exception as e: + raise ValueError('Unable to set required handler %r' % hn) from e + config['handlers'] = hlist + elif issubclass(klass, logging.handlers.SMTPHandler) and \ + 'mailhost' in config: config['mailhost'] = self.as_tuple(config['mailhost']) - elif issubclass(klass, logging.handlers.SysLogHandler) and\ - 'address' in config: + elif issubclass(klass, logging.handlers.SysLogHandler) and \ + 'address' in config: config['address'] = self.as_tuple(config['address']) - factory = klass - props = config.pop('.', None) - kwargs = {k: config[k] for k in config if valid_ident(k)} + if issubclass(klass, logging.handlers.QueueHandler): + factory = functools.partial(self._configure_queue_handler, klass) + else: + factory = klass + kwargs = {k: config[k] for k in config if (k != '.' and valid_ident(k))} try: result = factory(**kwargs) except TypeError as te: @@ -759,6 +889,7 @@ def configure_handler(self, config): result.setLevel(logging._checkLevel(level)) if filters: self.add_filters(result, filters) + props = config.pop('.', None) if props: for name, value in props.items(): setattr(result, name, value) @@ -794,6 +925,7 @@ def configure_logger(self, name, config, incremental=False): """Configure a non-root logger from a dictionary.""" logger = logging.getLogger(name) self.common_logger_config(logger, config, incremental) + logger.disabled = False propagate = config.get('propagate', None) if propagate is not None: logger.propagate = propagate diff --git a/Lib/logging/handlers.py b/Lib/logging/handlers.py index 61a39958c0..bf42ea1103 100644 --- a/Lib/logging/handlers.py +++ b/Lib/logging/handlers.py @@ -187,15 +187,18 @@ def shouldRollover(self, record): Basically, see if the supplied record would cause the file to exceed the size limit we have. """ - # See bpo-45401: Never rollover anything other than regular files - if os.path.exists(self.baseFilename) and not os.path.isfile(self.baseFilename): - return False if self.stream is None: # delay was set... self.stream = self._open() if self.maxBytes > 0: # are we rolling over? + pos = self.stream.tell() + if not pos: + # gh-116263: Never rollover an empty file + return False msg = "%s\n" % self.format(record) - self.stream.seek(0, 2) #due to non-posix-compliant Windows feature - if self.stream.tell() + len(msg) >= self.maxBytes: + if pos + len(msg) >= self.maxBytes: + # See bpo-45401: Never rollover anything other than regular files + if os.path.exists(self.baseFilename) and not os.path.isfile(self.baseFilename): + return False return True return False @@ -232,19 +235,19 @@ def __init__(self, filename, when='h', interval=1, backupCount=0, if self.when == 'S': self.interval = 1 # one second self.suffix = "%Y-%m-%d_%H-%M-%S" - self.extMatch = r"^\d{4}-\d{2}-\d{2}_\d{2}-\d{2}-\d{2}(\.\w+)?$" + extMatch = r"(?= self.rolloverAt: + # See #89564: Never rollover anything other than regular files + if os.path.exists(self.baseFilename) and not os.path.isfile(self.baseFilename): + # The file is not a regular file, so do not rollover, but do + # set the next rollover time to avoid repeated checks. + self.rolloverAt = self.computeRollover(t) + return False + return True return False @@ -365,32 +382,28 @@ def getFilesToDelete(self): dirName, baseName = os.path.split(self.baseFilename) fileNames = os.listdir(dirName) result = [] - # See bpo-44753: Don't use the extension when computing the prefix. - n, e = os.path.splitext(baseName) - prefix = n + '.' - plen = len(prefix) - for fileName in fileNames: - if self.namer is None: - # Our files will always start with baseName - if not fileName.startswith(baseName): - continue - else: - # Our files could be just about anything after custom naming, but - # likely candidates are of the form - # foo.log.DATETIME_SUFFIX or foo.DATETIME_SUFFIX.log - if (not fileName.startswith(baseName) and fileName.endswith(e) and - len(fileName) > (plen + 1) and not fileName[plen+1].isdigit()): - continue - - if fileName[:plen] == prefix: - suffix = fileName[plen:] - # See bpo-45628: The date/time suffix could be anywhere in the - # filename - parts = suffix.split('.') - for part in parts: - if self.extMatch.match(part): + if self.namer is None: + prefix = baseName + '.' + plen = len(prefix) + for fileName in fileNames: + if fileName[:plen] == prefix: + suffix = fileName[plen:] + if self.extMatch.fullmatch(suffix): + result.append(os.path.join(dirName, fileName)) + else: + for fileName in fileNames: + # Our files could be just about anything after custom naming, + # but they should contain the datetime suffix. + # Try to find the datetime suffix in the file name and verify + # that the file name can be generated by this handler. + m = self.extMatch.search(fileName) + while m: + dfn = self.namer(self.baseFilename + "." + m[0]) + if os.path.basename(dfn) == fileName: result.append(os.path.join(dirName, fileName)) break + m = self.extMatch.search(fileName, m.start() + 1) + if len(result) < self.backupCount: result = [] else: @@ -406,17 +419,14 @@ def doRollover(self): then we have to get a list of matching filenames, sort them and remove the one with the oldest suffix. """ - if self.stream: - self.stream.close() - self.stream = None # get the time that this sequence started at and make it a TimeTuple currentTime = int(time.time()) - dstNow = time.localtime(currentTime)[-1] t = self.rolloverAt - self.interval if self.utc: timeTuple = time.gmtime(t) else: timeTuple = time.localtime(t) + dstNow = time.localtime(currentTime)[-1] dstThen = timeTuple[-1] if dstNow != dstThen: if dstNow: @@ -427,26 +437,19 @@ def doRollover(self): dfn = self.rotation_filename(self.baseFilename + "." + time.strftime(self.suffix, timeTuple)) if os.path.exists(dfn): - os.remove(dfn) + # Already rolled over. + return + + if self.stream: + self.stream.close() + self.stream = None self.rotate(self.baseFilename, dfn) if self.backupCount > 0: for s in self.getFilesToDelete(): os.remove(s) if not self.delay: self.stream = self._open() - newRolloverAt = self.computeRollover(currentTime) - while newRolloverAt <= currentTime: - newRolloverAt = newRolloverAt + self.interval - #If DST changes and midnight or weekly rollover, adjust for this. - if (self.when == 'MIDNIGHT' or self.when.startswith('W')) and not self.utc: - dstAtRollover = time.localtime(newRolloverAt)[-1] - if dstNow != dstAtRollover: - if not dstNow: # DST kicks in before next rollover, so we need to deduct an hour - addend = -3600 - else: # DST bows out before next rollover, so we need to add an hour - addend = 3600 - newRolloverAt += addend - self.rolloverAt = newRolloverAt + self.rolloverAt = self.computeRollover(currentTime) class WatchedFileHandler(logging.FileHandler): """ @@ -800,7 +803,7 @@ class SysLogHandler(logging.Handler): "panic": LOG_EMERG, # DEPRECATED "warn": LOG_WARNING, # DEPRECATED "warning": LOG_WARNING, - } + } facility_names = { "auth": LOG_AUTH, @@ -827,12 +830,10 @@ class SysLogHandler(logging.Handler): "local5": LOG_LOCAL5, "local6": LOG_LOCAL6, "local7": LOG_LOCAL7, - } + } - #The map below appears to be trivially lowercasing the key. However, - #there's more to it than meets the eye - in some locales, lowercasing - #gives unexpected results. See SF #1524081: in the Turkish locale, - #"INFO".lower() != "info" + # Originally added to work around GH-43683. Unnecessary since GH-50043 but kept + # for backwards compatibility. priority_map = { "DEBUG" : "debug", "INFO" : "info", @@ -859,12 +860,49 @@ def __init__(self, address=('localhost', SYSLOG_UDP_PORT), self.address = address self.facility = facility self.socktype = socktype + self.socket = None + self.createSocket() + + def _connect_unixsocket(self, address): + use_socktype = self.socktype + if use_socktype is None: + use_socktype = socket.SOCK_DGRAM + self.socket = socket.socket(socket.AF_UNIX, use_socktype) + try: + self.socket.connect(address) + # it worked, so set self.socktype to the used type + self.socktype = use_socktype + except OSError: + self.socket.close() + if self.socktype is not None: + # user didn't specify falling back, so fail + raise + use_socktype = socket.SOCK_STREAM + self.socket = socket.socket(socket.AF_UNIX, use_socktype) + try: + self.socket.connect(address) + # it worked, so set self.socktype to the used type + self.socktype = use_socktype + except OSError: + self.socket.close() + raise + + def createSocket(self): + """ + Try to create a socket and, if it's not a datagram socket, connect it + to the other end. This method is called during handler initialization, + but it's not regarded as an error if the other end isn't listening yet + --- the method will be called again when emitting an event, + if there is no socket at that point. + """ + address = self.address + socktype = self.socktype if isinstance(address, str): self.unixsocket = True # Syslog server may be unavailable during handler initialisation. # C's openlog() function also ignores connection errors. - # Moreover, we ignore these errors while logging, so it not worse + # Moreover, we ignore these errors while logging, so it's not worse # to ignore it also here. try: self._connect_unixsocket(address) @@ -895,30 +933,6 @@ def __init__(self, address=('localhost', SYSLOG_UDP_PORT), self.socket = sock self.socktype = socktype - def _connect_unixsocket(self, address): - use_socktype = self.socktype - if use_socktype is None: - use_socktype = socket.SOCK_DGRAM - self.socket = socket.socket(socket.AF_UNIX, use_socktype) - try: - self.socket.connect(address) - # it worked, so set self.socktype to the used type - self.socktype = use_socktype - except OSError: - self.socket.close() - if self.socktype is not None: - # user didn't specify falling back, so fail - raise - use_socktype = socket.SOCK_STREAM - self.socket = socket.socket(socket.AF_UNIX, use_socktype) - try: - self.socket.connect(address) - # it worked, so set self.socktype to the used type - self.socktype = use_socktype - except OSError: - self.socket.close() - raise - def encodePriority(self, facility, priority): """ Encode the facility and priority. You can pass in strings or @@ -938,7 +952,10 @@ def close(self): """ self.acquire() try: - self.socket.close() + sock = self.socket + if sock: + self.socket = None + sock.close() logging.Handler.close(self) finally: self.release() @@ -978,6 +995,10 @@ def emit(self, record): # Message is a string. Convert to bytes as required by RFC 5424 msg = msg.encode('utf-8') msg = prio + msg + + if not self.socket: + self.createSocket() + if self.unixsocket: try: self.socket.send(msg) @@ -1094,7 +1115,16 @@ def __init__(self, appname, dllname=None, logtype="Application"): dllname = os.path.join(dllname[0], r'win32service.pyd') self.dllname = dllname self.logtype = logtype - self._welu.AddSourceToRegistry(appname, dllname, logtype) + # Administrative privileges are required to add a source to the registry. + # This may not be available for a user that just wants to add to an + # existing source - handle this specific case. + try: + self._welu.AddSourceToRegistry(appname, dllname, logtype) + except Exception as e: + # This will probably be a pywintypes.error. Only raise if it's not + # an "access denied" error, else let it pass + if getattr(e, 'winerror', None) != 5: # not access denied + raise self.deftype = win32evtlog.EVENTLOG_ERROR_TYPE self.typemap = { logging.DEBUG : win32evtlog.EVENTLOG_INFORMATION_TYPE, @@ -1102,10 +1132,10 @@ def __init__(self, appname, dllname=None, logtype="Application"): logging.WARNING : win32evtlog.EVENTLOG_WARNING_TYPE, logging.ERROR : win32evtlog.EVENTLOG_ERROR_TYPE, logging.CRITICAL: win32evtlog.EVENTLOG_ERROR_TYPE, - } + } except ImportError: - print("The Python Win32 extensions for NT (service, event "\ - "logging) appear not to be available.") + print("The Python Win32 extensions for NT (service, event " \ + "logging) appear not to be available.") self._welu = None def getMessageID(self, record): @@ -1348,7 +1378,7 @@ def shouldFlush(self, record): Check for buffer full or a record at the flushLevel or higher. """ return (len(self.buffer) >= self.capacity) or \ - (record.levelno >= self.flushLevel) + (record.levelno >= self.flushLevel) def setTarget(self, target): """ @@ -1366,7 +1396,7 @@ def flush(self): records to the target, if there is one. Override if you want different behaviour. - The record buffer is also cleared by this operation. + The record buffer is only cleared if a target has been set. """ self.acquire() try: @@ -1411,6 +1441,7 @@ def __init__(self, queue): """ logging.Handler.__init__(self) self.queue = queue + self.listener = None # will be set to listener if configured via dictConfig() def enqueue(self, record): """ @@ -1424,12 +1455,15 @@ def enqueue(self, record): def prepare(self, record): """ - Prepares a record for queuing. The object returned by this method is + Prepare a record for queuing. The object returned by this method is enqueued. - The base implementation formats the record to merge the message - and arguments, and removes unpickleable items from the record - in-place. + The base implementation formats the record to merge the message and + arguments, and removes unpickleable items from the record in-place. + Specifically, it overwrites the record's `msg` and + `message` attributes with the merged message (obtained by + calling the handler's `format` method), and sets the `args`, + `exc_info` and `exc_text` attributes to None. You might want to override this method if you want to convert the record to a dict or JSON string, or send a modified copy @@ -1439,7 +1473,7 @@ def prepare(self, record): # (if there's exception data), and also returns the formatted # message. We can then use this to replace the original # msg + args, as these might be unpickleable. We also zap the - # exc_info and exc_text attributes, as they are no longer + # exc_info, exc_text and stack_info attributes, as they are no longer # needed and, if not None, will typically not be pickleable. msg = self.format(record) # bpo-35726: make copy of record to avoid affecting other handlers in the chain. @@ -1449,6 +1483,7 @@ def prepare(self, record): record.args = None record.exc_info = None record.exc_text = None + record.stack_info = None return record def emit(self, record): diff --git a/Lib/smtplib.py b/Lib/smtplib.py new file mode 100644 index 0000000000..912233d817 --- /dev/null +++ b/Lib/smtplib.py @@ -0,0 +1,1109 @@ +#! /usr/bin/env python3 + +'''SMTP/ESMTP client class. + +This should follow RFC 821 (SMTP), RFC 1869 (ESMTP), RFC 2554 (SMTP +Authentication) and RFC 2487 (Secure SMTP over TLS). + +Notes: + +Please remember, when doing ESMTP, that the names of the SMTP service +extensions are NOT the same thing as the option keywords for the RCPT +and MAIL commands! + +Example: + + >>> import smtplib + >>> s=smtplib.SMTP("localhost") + >>> print(s.help()) + This is Sendmail version 8.8.4 + Topics: + HELO EHLO MAIL RCPT DATA + RSET NOOP QUIT HELP VRFY + EXPN VERB ETRN DSN + For more info use "HELP ". + To report bugs in the implementation send email to + sendmail-bugs@sendmail.org. + For local information send email to Postmaster at your site. + End of HELP info + >>> s.putcmd("vrfy","someone@here") + >>> s.getreply() + (250, "Somebody OverHere ") + >>> s.quit() +''' + +# Author: The Dragon De Monsyne +# ESMTP support, test code and doc fixes added by +# Eric S. Raymond +# Better RFC 821 compliance (MAIL and RCPT, and CRLF in data) +# by Carey Evans , for picky mail servers. +# RFC 2554 (authentication) support by Gerhard Haering . +# +# This was modified from the Python 1.5 library HTTP lib. + +import socket +import io +import re +import email.utils +import email.message +import email.generator +import base64 +import hmac +import copy +import datetime +import sys +from email.base64mime import body_encode as encode_base64 + +__all__ = ["SMTPException", "SMTPNotSupportedError", "SMTPServerDisconnected", "SMTPResponseException", + "SMTPSenderRefused", "SMTPRecipientsRefused", "SMTPDataError", + "SMTPConnectError", "SMTPHeloError", "SMTPAuthenticationError", + "quoteaddr", "quotedata", "SMTP"] + +SMTP_PORT = 25 +SMTP_SSL_PORT = 465 +CRLF = "\r\n" +bCRLF = b"\r\n" +_MAXLINE = 8192 # more than 8 times larger than RFC 821, 4.5.3 +_MAXCHALLENGE = 5 # Maximum number of AUTH challenges sent + +OLDSTYLE_AUTH = re.compile(r"auth=(.*)", re.I) + +# Exception classes used by this module. +class SMTPException(OSError): + """Base class for all exceptions raised by this module.""" + +class SMTPNotSupportedError(SMTPException): + """The command or option is not supported by the SMTP server. + + This exception is raised when an attempt is made to run a command or a + command with an option which is not supported by the server. + """ + +class SMTPServerDisconnected(SMTPException): + """Not connected to any SMTP server. + + This exception is raised when the server unexpectedly disconnects, + or when an attempt is made to use the SMTP instance before + connecting it to a server. + """ + +class SMTPResponseException(SMTPException): + """Base class for all exceptions that include an SMTP error code. + + These exceptions are generated in some instances when the SMTP + server returns an error code. The error code is stored in the + `smtp_code' attribute of the error, and the `smtp_error' attribute + is set to the error message. + """ + + def __init__(self, code, msg): + self.smtp_code = code + self.smtp_error = msg + self.args = (code, msg) + +class SMTPSenderRefused(SMTPResponseException): + """Sender address refused. + + In addition to the attributes set by on all SMTPResponseException + exceptions, this sets `sender' to the string that the SMTP refused. + """ + + def __init__(self, code, msg, sender): + self.smtp_code = code + self.smtp_error = msg + self.sender = sender + self.args = (code, msg, sender) + +class SMTPRecipientsRefused(SMTPException): + """All recipient addresses refused. + + The errors for each recipient are accessible through the attribute + 'recipients', which is a dictionary of exactly the same sort as + SMTP.sendmail() returns. + """ + + def __init__(self, recipients): + self.recipients = recipients + self.args = (recipients,) + + +class SMTPDataError(SMTPResponseException): + """The SMTP server didn't accept the data.""" + +class SMTPConnectError(SMTPResponseException): + """Error during connection establishment.""" + +class SMTPHeloError(SMTPResponseException): + """The server refused our HELO reply.""" + +class SMTPAuthenticationError(SMTPResponseException): + """Authentication error. + + Most probably the server didn't accept the username/password + combination provided. + """ + +def quoteaddr(addrstring): + """Quote a subset of the email addresses defined by RFC 821. + + Should be able to handle anything email.utils.parseaddr can handle. + """ + displayname, addr = email.utils.parseaddr(addrstring) + if (displayname, addr) == ('', ''): + # parseaddr couldn't parse it, use it as is and hope for the best. + if addrstring.strip().startswith('<'): + return addrstring + return "<%s>" % addrstring + return "<%s>" % addr + +def _addr_only(addrstring): + displayname, addr = email.utils.parseaddr(addrstring) + if (displayname, addr) == ('', ''): + # parseaddr couldn't parse it, so use it as is. + return addrstring + return addr + +# Legacy method kept for backward compatibility. +def quotedata(data): + """Quote data for email. + + Double leading '.', and change Unix newline '\\n', or Mac '\\r' into + internet CRLF end-of-line. + """ + return re.sub(r'(?m)^\.', '..', + re.sub(r'(?:\r\n|\n|\r(?!\n))', CRLF, data)) + +def _quote_periods(bindata): + return re.sub(br'(?m)^\.', b'..', bindata) + +def _fix_eols(data): + return re.sub(r'(?:\r\n|\n|\r(?!\n))', CRLF, data) + +try: + import ssl +except ImportError: + _have_ssl = False +else: + _have_ssl = True + + +class SMTP: + """This class manages a connection to an SMTP or ESMTP server. + SMTP Objects: + SMTP objects have the following attributes: + helo_resp + This is the message given by the server in response to the + most recent HELO command. + + ehlo_resp + This is the message given by the server in response to the + most recent EHLO command. This is usually multiline. + + does_esmtp + This is a True value _after you do an EHLO command_, if the + server supports ESMTP. + + esmtp_features + This is a dictionary, which, if the server supports ESMTP, + will _after you do an EHLO command_, contain the names of the + SMTP service extensions this server supports, and their + parameters (if any). + + Note, all extension names are mapped to lower case in the + dictionary. + + See each method's docstrings for details. In general, there is a + method of the same name to perform each SMTP command. There is also a + method called 'sendmail' that will do an entire mail transaction. + """ + debuglevel = 0 + + sock = None + file = None + helo_resp = None + ehlo_msg = "ehlo" + ehlo_resp = None + does_esmtp = False + default_port = SMTP_PORT + + def __init__(self, host='', port=0, local_hostname=None, + timeout=socket._GLOBAL_DEFAULT_TIMEOUT, + source_address=None): + """Initialize a new instance. + + If specified, `host` is the name of the remote host to which to + connect. If specified, `port` specifies the port to which to connect. + By default, smtplib.SMTP_PORT is used. If a host is specified the + connect method is called, and if it returns anything other than a + success code an SMTPConnectError is raised. If specified, + `local_hostname` is used as the FQDN of the local host in the HELO/EHLO + command. Otherwise, the local hostname is found using + socket.getfqdn(). The `source_address` parameter takes a 2-tuple (host, + port) for the socket to bind to as its source address before + connecting. If the host is '' and port is 0, the OS default behavior + will be used. + + """ + self._host = host + self.timeout = timeout + self.esmtp_features = {} + self.command_encoding = 'ascii' + self.source_address = source_address + self._auth_challenge_count = 0 + + if host: + (code, msg) = self.connect(host, port) + if code != 220: + self.close() + raise SMTPConnectError(code, msg) + if local_hostname is not None: + self.local_hostname = local_hostname + else: + # RFC 2821 says we should use the fqdn in the EHLO/HELO verb, and + # if that can't be calculated, that we should use a domain literal + # instead (essentially an encoded IP address like [A.B.C.D]). + fqdn = socket.getfqdn() + if '.' in fqdn: + self.local_hostname = fqdn + else: + # We can't find an fqdn hostname, so use a domain literal + addr = '127.0.0.1' + try: + addr = socket.gethostbyname(socket.gethostname()) + except socket.gaierror: + pass + self.local_hostname = '[%s]' % addr + + def __enter__(self): + return self + + def __exit__(self, *args): + try: + code, message = self.docmd("QUIT") + if code != 221: + raise SMTPResponseException(code, message) + except SMTPServerDisconnected: + pass + finally: + self.close() + + def set_debuglevel(self, debuglevel): + """Set the debug output level. + + A non-false value results in debug messages for connection and for all + messages sent to and received from the server. + + """ + self.debuglevel = debuglevel + + def _print_debug(self, *args): + if self.debuglevel > 1: + print(datetime.datetime.now().time(), *args, file=sys.stderr) + else: + print(*args, file=sys.stderr) + + def _get_socket(self, host, port, timeout): + # This makes it simpler for SMTP_SSL to use the SMTP connect code + # and just alter the socket connection bit. + if timeout is not None and not timeout: + raise ValueError('Non-blocking socket (timeout=0) is not supported') + if self.debuglevel > 0: + self._print_debug('connect: to', (host, port), self.source_address) + return socket.create_connection((host, port), timeout, + self.source_address) + + def connect(self, host='localhost', port=0, source_address=None): + """Connect to a host on a given port. + + If the hostname ends with a colon (`:') followed by a number, and + there is no port specified, that suffix will be stripped off and the + number interpreted as the port number to use. + + Note: This method is automatically invoked by __init__, if a host is + specified during instantiation. + + """ + + if source_address: + self.source_address = source_address + + if not port and (host.find(':') == host.rfind(':')): + i = host.rfind(':') + if i >= 0: + host, port = host[:i], host[i + 1:] + try: + port = int(port) + except ValueError: + raise OSError("nonnumeric port") + if not port: + port = self.default_port + sys.audit("smtplib.connect", self, host, port) + self.sock = self._get_socket(host, port, self.timeout) + self.file = None + (code, msg) = self.getreply() + if self.debuglevel > 0: + self._print_debug('connect:', repr(msg)) + return (code, msg) + + def send(self, s): + """Send `s' to the server.""" + if self.debuglevel > 0: + self._print_debug('send:', repr(s)) + if self.sock: + if isinstance(s, str): + # send is used by the 'data' command, where command_encoding + # should not be used, but 'data' needs to convert the string to + # binary itself anyway, so that's not a problem. + s = s.encode(self.command_encoding) + sys.audit("smtplib.send", self, s) + try: + self.sock.sendall(s) + except OSError: + self.close() + raise SMTPServerDisconnected('Server not connected') + else: + raise SMTPServerDisconnected('please run connect() first') + + def putcmd(self, cmd, args=""): + """Send a command to the server.""" + if args == "": + s = cmd + else: + s = f'{cmd} {args}' + if '\r' in s or '\n' in s: + s = s.replace('\n', '\\n').replace('\r', '\\r') + raise ValueError( + f'command and arguments contain prohibited newline characters: {s}' + ) + self.send(f'{s}{CRLF}') + + def getreply(self): + """Get a reply from the server. + + Returns a tuple consisting of: + + - server response code (e.g. '250', or such, if all goes well) + Note: returns -1 if it can't read response code. + + - server response string corresponding to response code (multiline + responses are converted to a single, multiline string). + + Raises SMTPServerDisconnected if end-of-file is reached. + """ + resp = [] + if self.file is None: + self.file = self.sock.makefile('rb') + while 1: + try: + line = self.file.readline(_MAXLINE + 1) + except OSError as e: + self.close() + raise SMTPServerDisconnected("Connection unexpectedly closed: " + + str(e)) + if not line: + self.close() + raise SMTPServerDisconnected("Connection unexpectedly closed") + if self.debuglevel > 0: + self._print_debug('reply:', repr(line)) + if len(line) > _MAXLINE: + self.close() + raise SMTPResponseException(500, "Line too long.") + resp.append(line[4:].strip(b' \t\r\n')) + code = line[:3] + # Check that the error code is syntactically correct. + # Don't attempt to read a continuation line if it is broken. + try: + errcode = int(code) + except ValueError: + errcode = -1 + break + # Check if multiline response. + if line[3:4] != b"-": + break + + errmsg = b"\n".join(resp) + if self.debuglevel > 0: + self._print_debug('reply: retcode (%s); Msg: %a' % (errcode, errmsg)) + return errcode, errmsg + + def docmd(self, cmd, args=""): + """Send a command, and return its response code.""" + self.putcmd(cmd, args) + return self.getreply() + + # std smtp commands + def helo(self, name=''): + """SMTP 'helo' command. + Hostname to send for this command defaults to the FQDN of the local + host. + """ + self.putcmd("helo", name or self.local_hostname) + (code, msg) = self.getreply() + self.helo_resp = msg + return (code, msg) + + def ehlo(self, name=''): + """ SMTP 'ehlo' command. + Hostname to send for this command defaults to the FQDN of the local + host. + """ + self.esmtp_features = {} + self.putcmd(self.ehlo_msg, name or self.local_hostname) + (code, msg) = self.getreply() + # According to RFC1869 some (badly written) + # MTA's will disconnect on an ehlo. Toss an exception if + # that happens -ddm + if code == -1 and len(msg) == 0: + self.close() + raise SMTPServerDisconnected("Server not connected") + self.ehlo_resp = msg + if code != 250: + return (code, msg) + self.does_esmtp = True + #parse the ehlo response -ddm + assert isinstance(self.ehlo_resp, bytes), repr(self.ehlo_resp) + resp = self.ehlo_resp.decode("latin-1").split('\n') + del resp[0] + for each in resp: + # To be able to communicate with as many SMTP servers as possible, + # we have to take the old-style auth advertisement into account, + # because: + # 1) Else our SMTP feature parser gets confused. + # 2) There are some servers that only advertise the auth methods we + # support using the old style. + auth_match = OLDSTYLE_AUTH.match(each) + if auth_match: + # This doesn't remove duplicates, but that's no problem + self.esmtp_features["auth"] = self.esmtp_features.get("auth", "") \ + + " " + auth_match.groups(0)[0] + continue + + # RFC 1869 requires a space between ehlo keyword and parameters. + # It's actually stricter, in that only spaces are allowed between + # parameters, but were not going to check for that here. Note + # that the space isn't present if there are no parameters. + m = re.match(r'(?P[A-Za-z0-9][A-Za-z0-9\-]*) ?', each) + if m: + feature = m.group("feature").lower() + params = m.string[m.end("feature"):].strip() + if feature == "auth": + self.esmtp_features[feature] = self.esmtp_features.get(feature, "") \ + + " " + params + else: + self.esmtp_features[feature] = params + return (code, msg) + + def has_extn(self, opt): + """Does the server support a given SMTP service extension?""" + return opt.lower() in self.esmtp_features + + def help(self, args=''): + """SMTP 'help' command. + Returns help text from server.""" + self.putcmd("help", args) + return self.getreply()[1] + + def rset(self): + """SMTP 'rset' command -- resets session.""" + self.command_encoding = 'ascii' + return self.docmd("rset") + + def _rset(self): + """Internal 'rset' command which ignores any SMTPServerDisconnected error. + + Used internally in the library, since the server disconnected error + should appear to the application when the *next* command is issued, if + we are doing an internal "safety" reset. + """ + try: + self.rset() + except SMTPServerDisconnected: + pass + + def noop(self): + """SMTP 'noop' command -- doesn't do anything :>""" + return self.docmd("noop") + + def mail(self, sender, options=()): + """SMTP 'mail' command -- begins mail xfer session. + + This method may raise the following exceptions: + + SMTPNotSupportedError The options parameter includes 'SMTPUTF8' + but the SMTPUTF8 extension is not supported by + the server. + """ + optionlist = '' + if options and self.does_esmtp: + if any(x.lower()=='smtputf8' for x in options): + if self.has_extn('smtputf8'): + self.command_encoding = 'utf-8' + else: + raise SMTPNotSupportedError( + 'SMTPUTF8 not supported by server') + optionlist = ' ' + ' '.join(options) + self.putcmd("mail", "FROM:%s%s" % (quoteaddr(sender), optionlist)) + return self.getreply() + + def rcpt(self, recip, options=()): + """SMTP 'rcpt' command -- indicates 1 recipient for this mail.""" + optionlist = '' + if options and self.does_esmtp: + optionlist = ' ' + ' '.join(options) + self.putcmd("rcpt", "TO:%s%s" % (quoteaddr(recip), optionlist)) + return self.getreply() + + def data(self, msg): + """SMTP 'DATA' command -- sends message data to server. + + Automatically quotes lines beginning with a period per rfc821. + Raises SMTPDataError if there is an unexpected reply to the + DATA command; the return value from this method is the final + response code received when the all data is sent. If msg + is a string, lone '\\r' and '\\n' characters are converted to + '\\r\\n' characters. If msg is bytes, it is transmitted as is. + """ + self.putcmd("data") + (code, repl) = self.getreply() + if self.debuglevel > 0: + self._print_debug('data:', (code, repl)) + if code != 354: + raise SMTPDataError(code, repl) + else: + if isinstance(msg, str): + msg = _fix_eols(msg).encode('ascii') + q = _quote_periods(msg) + if q[-2:] != bCRLF: + q = q + bCRLF + q = q + b"." + bCRLF + self.send(q) + (code, msg) = self.getreply() + if self.debuglevel > 0: + self._print_debug('data:', (code, msg)) + return (code, msg) + + def verify(self, address): + """SMTP 'verify' command -- checks for address validity.""" + self.putcmd("vrfy", _addr_only(address)) + return self.getreply() + # a.k.a. + vrfy = verify + + def expn(self, address): + """SMTP 'expn' command -- expands a mailing list.""" + self.putcmd("expn", _addr_only(address)) + return self.getreply() + + # some useful methods + + def ehlo_or_helo_if_needed(self): + """Call self.ehlo() and/or self.helo() if needed. + + If there has been no previous EHLO or HELO command this session, this + method tries ESMTP EHLO first. + + This method may raise the following exceptions: + + SMTPHeloError The server didn't reply properly to + the helo greeting. + """ + if self.helo_resp is None and self.ehlo_resp is None: + if not (200 <= self.ehlo()[0] <= 299): + (code, resp) = self.helo() + if not (200 <= code <= 299): + raise SMTPHeloError(code, resp) + + def auth(self, mechanism, authobject, *, initial_response_ok=True): + """Authentication command - requires response processing. + + 'mechanism' specifies which authentication mechanism is to + be used - the valid values are those listed in the 'auth' + element of 'esmtp_features'. + + 'authobject' must be a callable object taking a single argument: + + data = authobject(challenge) + + It will be called to process the server's challenge response; the + challenge argument it is passed will be a bytes. It should return + an ASCII string that will be base64 encoded and sent to the server. + + Keyword arguments: + - initial_response_ok: Allow sending the RFC 4954 initial-response + to the AUTH command, if the authentication methods supports it. + """ + # RFC 4954 allows auth methods to provide an initial response. Not all + # methods support it. By definition, if they return something other + # than None when challenge is None, then they do. See issue #15014. + mechanism = mechanism.upper() + initial_response = (authobject() if initial_response_ok else None) + if initial_response is not None: + response = encode_base64(initial_response.encode('ascii'), eol='') + (code, resp) = self.docmd("AUTH", mechanism + " " + response) + self._auth_challenge_count = 1 + else: + (code, resp) = self.docmd("AUTH", mechanism) + self._auth_challenge_count = 0 + # If server responds with a challenge, send the response. + while code == 334: + self._auth_challenge_count += 1 + challenge = base64.decodebytes(resp) + response = encode_base64( + authobject(challenge).encode('ascii'), eol='') + (code, resp) = self.docmd(response) + # If server keeps sending challenges, something is wrong. + if self._auth_challenge_count > _MAXCHALLENGE: + raise SMTPException( + "Server AUTH mechanism infinite loop. Last response: " + + repr((code, resp)) + ) + if code in (235, 503): + return (code, resp) + raise SMTPAuthenticationError(code, resp) + + def auth_cram_md5(self, challenge=None): + """ Authobject to use with CRAM-MD5 authentication. Requires self.user + and self.password to be set.""" + # CRAM-MD5 does not support initial-response. + if challenge is None: + return None + return self.user + " " + hmac.HMAC( + self.password.encode('ascii'), challenge, 'md5').hexdigest() + + def auth_plain(self, challenge=None): + """ Authobject to use with PLAIN authentication. Requires self.user and + self.password to be set.""" + return "\0%s\0%s" % (self.user, self.password) + + def auth_login(self, challenge=None): + """ Authobject to use with LOGIN authentication. Requires self.user and + self.password to be set.""" + if challenge is None or self._auth_challenge_count < 2: + return self.user + else: + return self.password + + def login(self, user, password, *, initial_response_ok=True): + """Log in on an SMTP server that requires authentication. + + The arguments are: + - user: The user name to authenticate with. + - password: The password for the authentication. + + Keyword arguments: + - initial_response_ok: Allow sending the RFC 4954 initial-response + to the AUTH command, if the authentication methods supports it. + + If there has been no previous EHLO or HELO command this session, this + method tries ESMTP EHLO first. + + This method will return normally if the authentication was successful. + + This method may raise the following exceptions: + + SMTPHeloError The server didn't reply properly to + the helo greeting. + SMTPAuthenticationError The server didn't accept the username/ + password combination. + SMTPNotSupportedError The AUTH command is not supported by the + server. + SMTPException No suitable authentication method was + found. + """ + + self.ehlo_or_helo_if_needed() + if not self.has_extn("auth"): + raise SMTPNotSupportedError( + "SMTP AUTH extension not supported by server.") + + # Authentication methods the server claims to support + advertised_authlist = self.esmtp_features["auth"].split() + + # Authentication methods we can handle in our preferred order: + preferred_auths = ['CRAM-MD5', 'PLAIN', 'LOGIN'] + + # We try the supported authentications in our preferred order, if + # the server supports them. + authlist = [auth for auth in preferred_auths + if auth in advertised_authlist] + if not authlist: + raise SMTPException("No suitable authentication method found.") + + # Some servers advertise authentication methods they don't really + # support, so if authentication fails, we continue until we've tried + # all methods. + self.user, self.password = user, password + for authmethod in authlist: + method_name = 'auth_' + authmethod.lower().replace('-', '_') + try: + (code, resp) = self.auth( + authmethod, getattr(self, method_name), + initial_response_ok=initial_response_ok) + # 235 == 'Authentication successful' + # 503 == 'Error: already authenticated' + if code in (235, 503): + return (code, resp) + except SMTPAuthenticationError as e: + last_exception = e + + # We could not login successfully. Return result of last attempt. + raise last_exception + + def starttls(self, *, context=None): + """Puts the connection to the SMTP server into TLS mode. + + If there has been no previous EHLO or HELO command this session, this + method tries ESMTP EHLO first. + + If the server supports TLS, this will encrypt the rest of the SMTP + session. If you provide the context parameter, + the identity of the SMTP server and client can be checked. This, + however, depends on whether the socket module really checks the + certificates. + + This method may raise the following exceptions: + + SMTPHeloError The server didn't reply properly to + the helo greeting. + """ + self.ehlo_or_helo_if_needed() + if not self.has_extn("starttls"): + raise SMTPNotSupportedError( + "STARTTLS extension not supported by server.") + (resp, reply) = self.docmd("STARTTLS") + if resp == 220: + if not _have_ssl: + raise RuntimeError("No SSL support included in this Python") + if context is None: + context = ssl._create_stdlib_context() + self.sock = context.wrap_socket(self.sock, + server_hostname=self._host) + self.file = None + # RFC 3207: + # The client MUST discard any knowledge obtained from + # the server, such as the list of SMTP service extensions, + # which was not obtained from the TLS negotiation itself. + self.helo_resp = None + self.ehlo_resp = None + self.esmtp_features = {} + self.does_esmtp = False + else: + # RFC 3207: + # 501 Syntax error (no parameters allowed) + # 454 TLS not available due to temporary reason + raise SMTPResponseException(resp, reply) + return (resp, reply) + + def sendmail(self, from_addr, to_addrs, msg, mail_options=(), + rcpt_options=()): + """This command performs an entire mail transaction. + + The arguments are: + - from_addr : The address sending this mail. + - to_addrs : A list of addresses to send this mail to. A bare + string will be treated as a list with 1 address. + - msg : The message to send. + - mail_options : List of ESMTP options (such as 8bitmime) for the + mail command. + - rcpt_options : List of ESMTP options (such as DSN commands) for + all the rcpt commands. + + msg may be a string containing characters in the ASCII range, or a byte + string. A string is encoded to bytes using the ascii codec, and lone + \\r and \\n characters are converted to \\r\\n characters. + + If there has been no previous EHLO or HELO command this session, this + method tries ESMTP EHLO first. If the server does ESMTP, message size + and each of the specified options will be passed to it. If EHLO + fails, HELO will be tried and ESMTP options suppressed. + + This method will return normally if the mail is accepted for at least + one recipient. It returns a dictionary, with one entry for each + recipient that was refused. Each entry contains a tuple of the SMTP + error code and the accompanying error message sent by the server. + + This method may raise the following exceptions: + + SMTPHeloError The server didn't reply properly to + the helo greeting. + SMTPRecipientsRefused The server rejected ALL recipients + (no mail was sent). + SMTPSenderRefused The server didn't accept the from_addr. + SMTPDataError The server replied with an unexpected + error code (other than a refusal of + a recipient). + SMTPNotSupportedError The mail_options parameter includes 'SMTPUTF8' + but the SMTPUTF8 extension is not supported by + the server. + + Note: the connection will be open even after an exception is raised. + + Example: + + >>> import smtplib + >>> s=smtplib.SMTP("localhost") + >>> tolist=["one@one.org","two@two.org","three@three.org","four@four.org"] + >>> msg = '''\\ + ... From: Me@my.org + ... Subject: testin'... + ... + ... This is a test ''' + >>> s.sendmail("me@my.org",tolist,msg) + { "three@three.org" : ( 550 ,"User unknown" ) } + >>> s.quit() + + In the above example, the message was accepted for delivery to three + of the four addresses, and one was rejected, with the error code + 550. If all addresses are accepted, then the method will return an + empty dictionary. + + """ + self.ehlo_or_helo_if_needed() + esmtp_opts = [] + if isinstance(msg, str): + msg = _fix_eols(msg).encode('ascii') + if self.does_esmtp: + if self.has_extn('size'): + esmtp_opts.append("size=%d" % len(msg)) + for option in mail_options: + esmtp_opts.append(option) + (code, resp) = self.mail(from_addr, esmtp_opts) + if code != 250: + if code == 421: + self.close() + else: + self._rset() + raise SMTPSenderRefused(code, resp, from_addr) + senderrs = {} + if isinstance(to_addrs, str): + to_addrs = [to_addrs] + for each in to_addrs: + (code, resp) = self.rcpt(each, rcpt_options) + if (code != 250) and (code != 251): + senderrs[each] = (code, resp) + if code == 421: + self.close() + raise SMTPRecipientsRefused(senderrs) + if len(senderrs) == len(to_addrs): + # the server refused all our recipients + self._rset() + raise SMTPRecipientsRefused(senderrs) + (code, resp) = self.data(msg) + if code != 250: + if code == 421: + self.close() + else: + self._rset() + raise SMTPDataError(code, resp) + #if we got here then somebody got our mail + return senderrs + + def send_message(self, msg, from_addr=None, to_addrs=None, + mail_options=(), rcpt_options=()): + """Converts message to a bytestring and passes it to sendmail. + + The arguments are as for sendmail, except that msg is an + email.message.Message object. If from_addr is None or to_addrs is + None, these arguments are taken from the headers of the Message as + described in RFC 2822 (a ValueError is raised if there is more than + one set of 'Resent-' headers). Regardless of the values of from_addr and + to_addr, any Bcc field (or Resent-Bcc field, when the Message is a + resent) of the Message object won't be transmitted. The Message + object is then serialized using email.generator.BytesGenerator and + sendmail is called to transmit the message. If the sender or any of + the recipient addresses contain non-ASCII and the server advertises the + SMTPUTF8 capability, the policy is cloned with utf8 set to True for the + serialization, and SMTPUTF8 and BODY=8BITMIME are asserted on the send. + If the server does not support SMTPUTF8, an SMTPNotSupported error is + raised. Otherwise the generator is called without modifying the + policy. + + """ + # 'Resent-Date' is a mandatory field if the Message is resent (RFC 2822 + # Section 3.6.6). In such a case, we use the 'Resent-*' fields. However, + # if there is more than one 'Resent-' block there's no way to + # unambiguously determine which one is the most recent in all cases, + # so rather than guess we raise a ValueError in that case. + # + # TODO implement heuristics to guess the correct Resent-* block with an + # option allowing the user to enable the heuristics. (It should be + # possible to guess correctly almost all of the time.) + + self.ehlo_or_helo_if_needed() + resent = msg.get_all('Resent-Date') + if resent is None: + header_prefix = '' + elif len(resent) == 1: + header_prefix = 'Resent-' + else: + raise ValueError("message has more than one 'Resent-' header block") + if from_addr is None: + # Prefer the sender field per RFC 2822:3.6.2. + from_addr = (msg[header_prefix + 'Sender'] + if (header_prefix + 'Sender') in msg + else msg[header_prefix + 'From']) + from_addr = email.utils.getaddresses([from_addr])[0][1] + if to_addrs is None: + addr_fields = [f for f in (msg[header_prefix + 'To'], + msg[header_prefix + 'Bcc'], + msg[header_prefix + 'Cc']) + if f is not None] + to_addrs = [a[1] for a in email.utils.getaddresses(addr_fields)] + # Make a local copy so we can delete the bcc headers. + msg_copy = copy.copy(msg) + del msg_copy['Bcc'] + del msg_copy['Resent-Bcc'] + international = False + try: + ''.join([from_addr, *to_addrs]).encode('ascii') + except UnicodeEncodeError: + if not self.has_extn('smtputf8'): + raise SMTPNotSupportedError( + "One or more source or delivery addresses require" + " internationalized email support, but the server" + " does not advertise the required SMTPUTF8 capability") + international = True + with io.BytesIO() as bytesmsg: + if international: + g = email.generator.BytesGenerator( + bytesmsg, policy=msg.policy.clone(utf8=True)) + mail_options = (*mail_options, 'SMTPUTF8', 'BODY=8BITMIME') + else: + g = email.generator.BytesGenerator(bytesmsg) + g.flatten(msg_copy, linesep='\r\n') + flatmsg = bytesmsg.getvalue() + return self.sendmail(from_addr, to_addrs, flatmsg, mail_options, + rcpt_options) + + def close(self): + """Close the connection to the SMTP server.""" + try: + file = self.file + self.file = None + if file: + file.close() + finally: + sock = self.sock + self.sock = None + if sock: + sock.close() + + def quit(self): + """Terminate the SMTP session.""" + res = self.docmd("quit") + # A new EHLO is required after reconnecting with connect() + self.ehlo_resp = self.helo_resp = None + self.esmtp_features = {} + self.does_esmtp = False + self.close() + return res + +if _have_ssl: + + class SMTP_SSL(SMTP): + """ This is a subclass derived from SMTP that connects over an SSL + encrypted socket (to use this class you need a socket module that was + compiled with SSL support). If host is not specified, '' (the local + host) is used. If port is omitted, the standard SMTP-over-SSL port + (465) is used. local_hostname and source_address have the same meaning + as they do in the SMTP class. context also optional, can contain a + SSLContext. + + """ + + default_port = SMTP_SSL_PORT + + def __init__(self, host='', port=0, local_hostname=None, + *, timeout=socket._GLOBAL_DEFAULT_TIMEOUT, + source_address=None, context=None): + if context is None: + context = ssl._create_stdlib_context() + self.context = context + SMTP.__init__(self, host, port, local_hostname, timeout, + source_address) + + def _get_socket(self, host, port, timeout): + if self.debuglevel > 0: + self._print_debug('connect:', (host, port)) + new_socket = super()._get_socket(host, port, timeout) + new_socket = self.context.wrap_socket(new_socket, + server_hostname=self._host) + return new_socket + + __all__.append("SMTP_SSL") + +# +# LMTP extension +# +LMTP_PORT = 2003 + +class LMTP(SMTP): + """LMTP - Local Mail Transfer Protocol + + The LMTP protocol, which is very similar to ESMTP, is heavily based + on the standard SMTP client. It's common to use Unix sockets for + LMTP, so our connect() method must support that as well as a regular + host:port server. local_hostname and source_address have the same + meaning as they do in the SMTP class. To specify a Unix socket, + you must use an absolute path as the host, starting with a '/'. + + Authentication is supported, using the regular SMTP mechanism. When + using a Unix socket, LMTP generally don't support or require any + authentication, but your mileage might vary.""" + + ehlo_msg = "lhlo" + + def __init__(self, host='', port=LMTP_PORT, local_hostname=None, + source_address=None, timeout=socket._GLOBAL_DEFAULT_TIMEOUT): + """Initialize a new instance.""" + super().__init__(host, port, local_hostname=local_hostname, + source_address=source_address, timeout=timeout) + + def connect(self, host='localhost', port=0, source_address=None): + """Connect to the LMTP daemon, on either a Unix or a TCP socket.""" + if host[0] != '/': + return super().connect(host, port, source_address=source_address) + + if self.timeout is not None and not self.timeout: + raise ValueError('Non-blocking socket (timeout=0) is not supported') + + # Handle Unix-domain sockets. + try: + self.sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + if self.timeout is not socket._GLOBAL_DEFAULT_TIMEOUT: + self.sock.settimeout(self.timeout) + self.file = None + self.sock.connect(host) + except OSError: + if self.debuglevel > 0: + self._print_debug('connect fail:', host) + if self.sock: + self.sock.close() + self.sock = None + raise + (code, msg) = self.getreply() + if self.debuglevel > 0: + self._print_debug('connect:', msg) + return (code, msg) + + +# Test the sendmail method, which tests most of the others. +# Note: This always sends to localhost. +if __name__ == '__main__': + def prompt(prompt): + sys.stdout.write(prompt + ": ") + sys.stdout.flush() + return sys.stdin.readline().strip() + + fromaddr = prompt("From") + toaddrs = prompt("To").split(',') + print("Enter message, end with ^D:") + msg = '' + while line := sys.stdin.readline(): + msg = msg + line + print("Message length is %d" % len(msg)) + + server = SMTP('localhost') + server.set_debuglevel(1) + server.sendmail(fromaddr, toaddrs, msg) + server.quit() diff --git a/Lib/test/support/__init__.py b/Lib/test/support/__init__.py index a1a6bd8e73..3490936b73 100644 --- a/Lib/test/support/__init__.py +++ b/Lib/test/support/__init__.py @@ -400,33 +400,37 @@ def skip_if_buildbot(reason=None): isbuildbot = False return unittest.skipIf(isbuildbot, reason) -def check_sanitizer(*, address=False, memory=False, ub=False): +def check_sanitizer(*, address=False, memory=False, ub=False, thread=False): """Returns True if Python is compiled with sanitizer support""" - if not (address or memory or ub): - raise ValueError('At least one of address, memory, or ub must be True') + if not (address or memory or ub or thread): + raise ValueError('At least one of address, memory, ub or thread must be True') - _cflags = sysconfig.get_config_var('CFLAGS') or '' - _config_args = sysconfig.get_config_var('CONFIG_ARGS') or '' + cflags = sysconfig.get_config_var('CFLAGS') or '' + config_args = sysconfig.get_config_var('CONFIG_ARGS') or '' memory_sanitizer = ( - '-fsanitize=memory' in _cflags or - '--with-memory-sanitizer' in _config_args + '-fsanitize=memory' in cflags or + '--with-memory-sanitizer' in config_args ) address_sanitizer = ( - '-fsanitize=address' in _cflags or - '--with-address-sanitizer' in _config_args + '-fsanitize=address' in cflags or + '--with-address-sanitizer' in config_args ) ub_sanitizer = ( - '-fsanitize=undefined' in _cflags or - '--with-undefined-behavior-sanitizer' in _config_args + '-fsanitize=undefined' in cflags or + '--with-undefined-behavior-sanitizer' in config_args + ) + thread_sanitizer = ( + '-fsanitize=thread' in cflags or + '--with-thread-sanitizer' in config_args ) return ( - (memory and memory_sanitizer) or - (address and address_sanitizer) or - (ub and ub_sanitizer) + (memory and memory_sanitizer) or + (address and address_sanitizer) or + (ub and ub_sanitizer) or + (thread and thread_sanitizer) ) - def skip_if_sanitizer(reason=None, *, address=False, memory=False, ub=False, thread=False): """Decorator raising SkipTest if running with a sanitizer active.""" if not reason: @@ -2550,3 +2554,4 @@ def adjust_int_max_str_digits(max_digits): #Windows doesn't have os.uname() but it doesn't support s390x. skip_on_s390x = unittest.skipIf(hasattr(os, 'uname') and os.uname().machine == 's390x', 'skipped on s390x') +HAVE_ASAN_FORK_BUG = check_sanitizer(address=True) diff --git a/Lib/test/support/smtpd.py b/Lib/test/support/smtpd.py new file mode 100644 index 0000000000..ec4e7d2f4c --- /dev/null +++ b/Lib/test/support/smtpd.py @@ -0,0 +1,873 @@ +#! /usr/bin/env python3 +"""An RFC 5321 smtp proxy with optional RFC 1870 and RFC 6531 extensions. + +Usage: %(program)s [options] [localhost:localport [remotehost:remoteport]] + +Options: + + --nosetuid + -n + This program generally tries to setuid `nobody', unless this flag is + set. The setuid call will fail if this program is not run as root (in + which case, use this flag). + + --version + -V + Print the version number and exit. + + --class classname + -c classname + Use `classname' as the concrete SMTP proxy class. Uses `PureProxy' by + default. + + --size limit + -s limit + Restrict the total size of the incoming message to "limit" number of + bytes via the RFC 1870 SIZE extension. Defaults to 33554432 bytes. + + --smtputf8 + -u + Enable the SMTPUTF8 extension and behave as an RFC 6531 smtp proxy. + + --debug + -d + Turn on debugging prints. + + --help + -h + Print this message and exit. + +Version: %(__version__)s + +If localhost is not given then `localhost' is used, and if localport is not +given then 8025 is used. If remotehost is not given then `localhost' is used, +and if remoteport is not given, then 25 is used. +""" + +# Overview: +# +# This file implements the minimal SMTP protocol as defined in RFC 5321. It +# has a hierarchy of classes which implement the backend functionality for the +# smtpd. A number of classes are provided: +# +# SMTPServer - the base class for the backend. Raises NotImplementedError +# if you try to use it. +# +# DebuggingServer - simply prints each message it receives on stdout. +# +# PureProxy - Proxies all messages to a real smtpd which does final +# delivery. One known problem with this class is that it doesn't handle +# SMTP errors from the backend server at all. This should be fixed +# (contributions are welcome!). +# +# +# Author: Barry Warsaw +# +# TODO: +# +# - support mailbox delivery +# - alias files +# - Handle more ESMTP extensions +# - handle error codes from the backend smtpd + +import sys +import os +import errno +import getopt +import time +import socket +import collections +from test.support import asyncore, asynchat +from warnings import warn +from email._header_value_parser import get_addr_spec, get_angle_addr + +__all__ = [ + "SMTPChannel", "SMTPServer", "DebuggingServer", "PureProxy", +] + +program = sys.argv[0] +__version__ = 'Python SMTP proxy version 0.3' + + +class Devnull: + def write(self, msg): pass + def flush(self): pass + + +DEBUGSTREAM = Devnull() +NEWLINE = '\n' +COMMASPACE = ', ' +DATA_SIZE_DEFAULT = 33554432 + + +def usage(code, msg=''): + print(__doc__ % globals(), file=sys.stderr) + if msg: + print(msg, file=sys.stderr) + sys.exit(code) + + +class SMTPChannel(asynchat.async_chat): + COMMAND = 0 + DATA = 1 + + command_size_limit = 512 + command_size_limits = collections.defaultdict(lambda x=command_size_limit: x) + + @property + def max_command_size_limit(self): + try: + return max(self.command_size_limits.values()) + except ValueError: + return self.command_size_limit + + def __init__(self, server, conn, addr, data_size_limit=DATA_SIZE_DEFAULT, + map=None, enable_SMTPUTF8=False, decode_data=False): + asynchat.async_chat.__init__(self, conn, map=map) + self.smtp_server = server + self.conn = conn + self.addr = addr + self.data_size_limit = data_size_limit + self.enable_SMTPUTF8 = enable_SMTPUTF8 + self._decode_data = decode_data + if enable_SMTPUTF8 and decode_data: + raise ValueError("decode_data and enable_SMTPUTF8 cannot" + " be set to True at the same time") + if decode_data: + self._emptystring = '' + self._linesep = '\r\n' + self._dotsep = '.' + self._newline = NEWLINE + else: + self._emptystring = b'' + self._linesep = b'\r\n' + self._dotsep = ord(b'.') + self._newline = b'\n' + self._set_rset_state() + self.seen_greeting = '' + self.extended_smtp = False + self.command_size_limits.clear() + self.fqdn = socket.getfqdn() + try: + self.peer = conn.getpeername() + except OSError as err: + # a race condition may occur if the other end is closing + # before we can get the peername + self.close() + if err.errno != errno.ENOTCONN: + raise + return + print('Peer:', repr(self.peer), file=DEBUGSTREAM) + self.push('220 %s %s' % (self.fqdn, __version__)) + + def _set_post_data_state(self): + """Reset state variables to their post-DATA state.""" + self.smtp_state = self.COMMAND + self.mailfrom = None + self.rcpttos = [] + self.require_SMTPUTF8 = False + self.num_bytes = 0 + self.set_terminator(b'\r\n') + + def _set_rset_state(self): + """Reset all state variables except the greeting.""" + self._set_post_data_state() + self.received_data = '' + self.received_lines = [] + + + # properties for backwards-compatibility + @property + def __server(self): + warn("Access to __server attribute on SMTPChannel is deprecated, " + "use 'smtp_server' instead", DeprecationWarning, 2) + return self.smtp_server + @__server.setter + def __server(self, value): + warn("Setting __server attribute on SMTPChannel is deprecated, " + "set 'smtp_server' instead", DeprecationWarning, 2) + self.smtp_server = value + + @property + def __line(self): + warn("Access to __line attribute on SMTPChannel is deprecated, " + "use 'received_lines' instead", DeprecationWarning, 2) + return self.received_lines + @__line.setter + def __line(self, value): + warn("Setting __line attribute on SMTPChannel is deprecated, " + "set 'received_lines' instead", DeprecationWarning, 2) + self.received_lines = value + + @property + def __state(self): + warn("Access to __state attribute on SMTPChannel is deprecated, " + "use 'smtp_state' instead", DeprecationWarning, 2) + return self.smtp_state + @__state.setter + def __state(self, value): + warn("Setting __state attribute on SMTPChannel is deprecated, " + "set 'smtp_state' instead", DeprecationWarning, 2) + self.smtp_state = value + + @property + def __greeting(self): + warn("Access to __greeting attribute on SMTPChannel is deprecated, " + "use 'seen_greeting' instead", DeprecationWarning, 2) + return self.seen_greeting + @__greeting.setter + def __greeting(self, value): + warn("Setting __greeting attribute on SMTPChannel is deprecated, " + "set 'seen_greeting' instead", DeprecationWarning, 2) + self.seen_greeting = value + + @property + def __mailfrom(self): + warn("Access to __mailfrom attribute on SMTPChannel is deprecated, " + "use 'mailfrom' instead", DeprecationWarning, 2) + return self.mailfrom + @__mailfrom.setter + def __mailfrom(self, value): + warn("Setting __mailfrom attribute on SMTPChannel is deprecated, " + "set 'mailfrom' instead", DeprecationWarning, 2) + self.mailfrom = value + + @property + def __rcpttos(self): + warn("Access to __rcpttos attribute on SMTPChannel is deprecated, " + "use 'rcpttos' instead", DeprecationWarning, 2) + return self.rcpttos + @__rcpttos.setter + def __rcpttos(self, value): + warn("Setting __rcpttos attribute on SMTPChannel is deprecated, " + "set 'rcpttos' instead", DeprecationWarning, 2) + self.rcpttos = value + + @property + def __data(self): + warn("Access to __data attribute on SMTPChannel is deprecated, " + "use 'received_data' instead", DeprecationWarning, 2) + return self.received_data + @__data.setter + def __data(self, value): + warn("Setting __data attribute on SMTPChannel is deprecated, " + "set 'received_data' instead", DeprecationWarning, 2) + self.received_data = value + + @property + def __fqdn(self): + warn("Access to __fqdn attribute on SMTPChannel is deprecated, " + "use 'fqdn' instead", DeprecationWarning, 2) + return self.fqdn + @__fqdn.setter + def __fqdn(self, value): + warn("Setting __fqdn attribute on SMTPChannel is deprecated, " + "set 'fqdn' instead", DeprecationWarning, 2) + self.fqdn = value + + @property + def __peer(self): + warn("Access to __peer attribute on SMTPChannel is deprecated, " + "use 'peer' instead", DeprecationWarning, 2) + return self.peer + @__peer.setter + def __peer(self, value): + warn("Setting __peer attribute on SMTPChannel is deprecated, " + "set 'peer' instead", DeprecationWarning, 2) + self.peer = value + + @property + def __conn(self): + warn("Access to __conn attribute on SMTPChannel is deprecated, " + "use 'conn' instead", DeprecationWarning, 2) + return self.conn + @__conn.setter + def __conn(self, value): + warn("Setting __conn attribute on SMTPChannel is deprecated, " + "set 'conn' instead", DeprecationWarning, 2) + self.conn = value + + @property + def __addr(self): + warn("Access to __addr attribute on SMTPChannel is deprecated, " + "use 'addr' instead", DeprecationWarning, 2) + return self.addr + @__addr.setter + def __addr(self, value): + warn("Setting __addr attribute on SMTPChannel is deprecated, " + "set 'addr' instead", DeprecationWarning, 2) + self.addr = value + + # Overrides base class for convenience. + def push(self, msg): + asynchat.async_chat.push(self, bytes( + msg + '\r\n', 'utf-8' if self.require_SMTPUTF8 else 'ascii')) + + # Implementation of base class abstract method + def collect_incoming_data(self, data): + limit = None + if self.smtp_state == self.COMMAND: + limit = self.max_command_size_limit + elif self.smtp_state == self.DATA: + limit = self.data_size_limit + if limit and self.num_bytes > limit: + return + elif limit: + self.num_bytes += len(data) + if self._decode_data: + self.received_lines.append(str(data, 'utf-8')) + else: + self.received_lines.append(data) + + # Implementation of base class abstract method + def found_terminator(self): + line = self._emptystring.join(self.received_lines) + print('Data:', repr(line), file=DEBUGSTREAM) + self.received_lines = [] + if self.smtp_state == self.COMMAND: + sz, self.num_bytes = self.num_bytes, 0 + if not line: + self.push('500 Error: bad syntax') + return + if not self._decode_data: + line = str(line, 'utf-8') + i = line.find(' ') + if i < 0: + command = line.upper() + arg = None + else: + command = line[:i].upper() + arg = line[i+1:].strip() + max_sz = (self.command_size_limits[command] + if self.extended_smtp else self.command_size_limit) + if sz > max_sz: + self.push('500 Error: line too long') + return + method = getattr(self, 'smtp_' + command, None) + if not method: + self.push('500 Error: command "%s" not recognized' % command) + return + method(arg) + return + else: + if self.smtp_state != self.DATA: + self.push('451 Internal confusion') + self.num_bytes = 0 + return + if self.data_size_limit and self.num_bytes > self.data_size_limit: + self.push('552 Error: Too much mail data') + self.num_bytes = 0 + return + # Remove extraneous carriage returns and de-transparency according + # to RFC 5321, Section 4.5.2. + data = [] + for text in line.split(self._linesep): + if text and text[0] == self._dotsep: + data.append(text[1:]) + else: + data.append(text) + self.received_data = self._newline.join(data) + args = (self.peer, self.mailfrom, self.rcpttos, self.received_data) + kwargs = {} + if not self._decode_data: + kwargs = { + 'mail_options': self.mail_options, + 'rcpt_options': self.rcpt_options, + } + status = self.smtp_server.process_message(*args, **kwargs) + self._set_post_data_state() + if not status: + self.push('250 OK') + else: + self.push(status) + + # SMTP and ESMTP commands + def smtp_HELO(self, arg): + if not arg: + self.push('501 Syntax: HELO hostname') + return + # See issue #21783 for a discussion of this behavior. + if self.seen_greeting: + self.push('503 Duplicate HELO/EHLO') + return + self._set_rset_state() + self.seen_greeting = arg + self.push('250 %s' % self.fqdn) + + def smtp_EHLO(self, arg): + if not arg: + self.push('501 Syntax: EHLO hostname') + return + # See issue #21783 for a discussion of this behavior. + if self.seen_greeting: + self.push('503 Duplicate HELO/EHLO') + return + self._set_rset_state() + self.seen_greeting = arg + self.extended_smtp = True + self.push('250-%s' % self.fqdn) + if self.data_size_limit: + self.push('250-SIZE %s' % self.data_size_limit) + self.command_size_limits['MAIL'] += 26 + if not self._decode_data: + self.push('250-8BITMIME') + if self.enable_SMTPUTF8: + self.push('250-SMTPUTF8') + self.command_size_limits['MAIL'] += 10 + self.push('250 HELP') + + def smtp_NOOP(self, arg): + if arg: + self.push('501 Syntax: NOOP') + else: + self.push('250 OK') + + def smtp_QUIT(self, arg): + # args is ignored + self.push('221 Bye') + self.close_when_done() + + def _strip_command_keyword(self, keyword, arg): + keylen = len(keyword) + if arg[:keylen].upper() == keyword: + return arg[keylen:].strip() + return '' + + def _getaddr(self, arg): + if not arg: + return '', '' + if arg.lstrip().startswith('<'): + address, rest = get_angle_addr(arg) + else: + address, rest = get_addr_spec(arg) + if not address: + return address, rest + return address.addr_spec, rest + + def _getparams(self, params): + # Return params as dictionary. Return None if not all parameters + # appear to be syntactically valid according to RFC 1869. + result = {} + for param in params: + param, eq, value = param.partition('=') + if not param.isalnum() or eq and not value: + return None + result[param] = value if eq else True + return result + + def smtp_HELP(self, arg): + if arg: + extended = ' [SP ]' + lc_arg = arg.upper() + if lc_arg == 'EHLO': + self.push('250 Syntax: EHLO hostname') + elif lc_arg == 'HELO': + self.push('250 Syntax: HELO hostname') + elif lc_arg == 'MAIL': + msg = '250 Syntax: MAIL FROM:
' + if self.extended_smtp: + msg += extended + self.push(msg) + elif lc_arg == 'RCPT': + msg = '250 Syntax: RCPT TO:
' + if self.extended_smtp: + msg += extended + self.push(msg) + elif lc_arg == 'DATA': + self.push('250 Syntax: DATA') + elif lc_arg == 'RSET': + self.push('250 Syntax: RSET') + elif lc_arg == 'NOOP': + self.push('250 Syntax: NOOP') + elif lc_arg == 'QUIT': + self.push('250 Syntax: QUIT') + elif lc_arg == 'VRFY': + self.push('250 Syntax: VRFY
') + else: + self.push('501 Supported commands: EHLO HELO MAIL RCPT ' + 'DATA RSET NOOP QUIT VRFY') + else: + self.push('250 Supported commands: EHLO HELO MAIL RCPT DATA ' + 'RSET NOOP QUIT VRFY') + + def smtp_VRFY(self, arg): + if arg: + address, params = self._getaddr(arg) + if address: + self.push('252 Cannot VRFY user, but will accept message ' + 'and attempt delivery') + else: + self.push('502 Could not VRFY %s' % arg) + else: + self.push('501 Syntax: VRFY
') + + def smtp_MAIL(self, arg): + if not self.seen_greeting: + self.push('503 Error: send HELO first') + return + print('===> MAIL', arg, file=DEBUGSTREAM) + syntaxerr = '501 Syntax: MAIL FROM:
' + if self.extended_smtp: + syntaxerr += ' [SP ]' + if arg is None: + self.push(syntaxerr) + return + arg = self._strip_command_keyword('FROM:', arg) + address, params = self._getaddr(arg) + if not address: + self.push(syntaxerr) + return + if not self.extended_smtp and params: + self.push(syntaxerr) + return + if self.mailfrom: + self.push('503 Error: nested MAIL command') + return + self.mail_options = params.upper().split() + params = self._getparams(self.mail_options) + if params is None: + self.push(syntaxerr) + return + if not self._decode_data: + body = params.pop('BODY', '7BIT') + if body not in ['7BIT', '8BITMIME']: + self.push('501 Error: BODY can only be one of 7BIT, 8BITMIME') + return + if self.enable_SMTPUTF8: + smtputf8 = params.pop('SMTPUTF8', False) + if smtputf8 is True: + self.require_SMTPUTF8 = True + elif smtputf8 is not False: + self.push('501 Error: SMTPUTF8 takes no arguments') + return + size = params.pop('SIZE', None) + if size: + if not size.isdigit(): + self.push(syntaxerr) + return + elif self.data_size_limit and int(size) > self.data_size_limit: + self.push('552 Error: message size exceeds fixed maximum message size') + return + if len(params.keys()) > 0: + self.push('555 MAIL FROM parameters not recognized or not implemented') + return + self.mailfrom = address + print('sender:', self.mailfrom, file=DEBUGSTREAM) + self.push('250 OK') + + def smtp_RCPT(self, arg): + if not self.seen_greeting: + self.push('503 Error: send HELO first'); + return + print('===> RCPT', arg, file=DEBUGSTREAM) + if not self.mailfrom: + self.push('503 Error: need MAIL command') + return + syntaxerr = '501 Syntax: RCPT TO:
' + if self.extended_smtp: + syntaxerr += ' [SP ]' + if arg is None: + self.push(syntaxerr) + return + arg = self._strip_command_keyword('TO:', arg) + address, params = self._getaddr(arg) + if not address: + self.push(syntaxerr) + return + if not self.extended_smtp and params: + self.push(syntaxerr) + return + self.rcpt_options = params.upper().split() + params = self._getparams(self.rcpt_options) + if params is None: + self.push(syntaxerr) + return + # XXX currently there are no options we recognize. + if len(params.keys()) > 0: + self.push('555 RCPT TO parameters not recognized or not implemented') + return + self.rcpttos.append(address) + print('recips:', self.rcpttos, file=DEBUGSTREAM) + self.push('250 OK') + + def smtp_RSET(self, arg): + if arg: + self.push('501 Syntax: RSET') + return + self._set_rset_state() + self.push('250 OK') + + def smtp_DATA(self, arg): + if not self.seen_greeting: + self.push('503 Error: send HELO first'); + return + if not self.rcpttos: + self.push('503 Error: need RCPT command') + return + if arg: + self.push('501 Syntax: DATA') + return + self.smtp_state = self.DATA + self.set_terminator(b'\r\n.\r\n') + self.push('354 End data with .') + + # Commands that have not been implemented + def smtp_EXPN(self, arg): + self.push('502 EXPN not implemented') + + +class SMTPServer(asyncore.dispatcher): + # SMTPChannel class to use for managing client connections + channel_class = SMTPChannel + + def __init__(self, localaddr, remoteaddr, + data_size_limit=DATA_SIZE_DEFAULT, map=None, + enable_SMTPUTF8=False, decode_data=False): + self._localaddr = localaddr + self._remoteaddr = remoteaddr + self.data_size_limit = data_size_limit + self.enable_SMTPUTF8 = enable_SMTPUTF8 + self._decode_data = decode_data + if enable_SMTPUTF8 and decode_data: + raise ValueError("decode_data and enable_SMTPUTF8 cannot" + " be set to True at the same time") + asyncore.dispatcher.__init__(self, map=map) + try: + gai_results = socket.getaddrinfo(*localaddr, + type=socket.SOCK_STREAM) + self.create_socket(gai_results[0][0], gai_results[0][1]) + # try to re-use a server port if possible + self.set_reuse_addr() + self.bind(localaddr) + self.listen(5) + except: + self.close() + raise + else: + print('%s started at %s\n\tLocal addr: %s\n\tRemote addr:%s' % ( + self.__class__.__name__, time.ctime(time.time()), + localaddr, remoteaddr), file=DEBUGSTREAM) + + def handle_accepted(self, conn, addr): + print('Incoming connection from %s' % repr(addr), file=DEBUGSTREAM) + channel = self.channel_class(self, + conn, + addr, + self.data_size_limit, + self._map, + self.enable_SMTPUTF8, + self._decode_data) + + # API for "doing something useful with the message" + def process_message(self, peer, mailfrom, rcpttos, data, **kwargs): + """Override this abstract method to handle messages from the client. + + peer is a tuple containing (ipaddr, port) of the client that made the + socket connection to our smtp port. + + mailfrom is the raw address the client claims the message is coming + from. + + rcpttos is a list of raw addresses the client wishes to deliver the + message to. + + data is a string containing the entire full text of the message, + headers (if supplied) and all. It has been `de-transparencied' + according to RFC 821, Section 4.5.2. In other words, a line + containing a `.' followed by other text has had the leading dot + removed. + + kwargs is a dictionary containing additional information. It is + empty if decode_data=True was given as init parameter, otherwise + it will contain the following keys: + 'mail_options': list of parameters to the mail command. All + elements are uppercase strings. Example: + ['BODY=8BITMIME', 'SMTPUTF8']. + 'rcpt_options': same, for the rcpt command. + + This function should return None for a normal `250 Ok' response; + otherwise, it should return the desired response string in RFC 821 + format. + + """ + raise NotImplementedError + + +class DebuggingServer(SMTPServer): + + def _print_message_content(self, peer, data): + inheaders = 1 + lines = data.splitlines() + for line in lines: + # headers first + if inheaders and not line: + peerheader = 'X-Peer: ' + peer[0] + if not isinstance(data, str): + # decoded_data=false; make header match other binary output + peerheader = repr(peerheader.encode('utf-8')) + print(peerheader) + inheaders = 0 + if not isinstance(data, str): + # Avoid spurious 'str on bytes instance' warning. + line = repr(line) + print(line) + + def process_message(self, peer, mailfrom, rcpttos, data, **kwargs): + print('---------- MESSAGE FOLLOWS ----------') + if kwargs: + if kwargs.get('mail_options'): + print('mail options: %s' % kwargs['mail_options']) + if kwargs.get('rcpt_options'): + print('rcpt options: %s\n' % kwargs['rcpt_options']) + self._print_message_content(peer, data) + print('------------ END MESSAGE ------------') + + +class PureProxy(SMTPServer): + def __init__(self, *args, **kwargs): + if 'enable_SMTPUTF8' in kwargs and kwargs['enable_SMTPUTF8']: + raise ValueError("PureProxy does not support SMTPUTF8.") + super(PureProxy, self).__init__(*args, **kwargs) + + def process_message(self, peer, mailfrom, rcpttos, data): + lines = data.split('\n') + # Look for the last header + i = 0 + for line in lines: + if not line: + break + i += 1 + lines.insert(i, 'X-Peer: %s' % peer[0]) + data = NEWLINE.join(lines) + refused = self._deliver(mailfrom, rcpttos, data) + # TBD: what to do with refused addresses? + print('we got some refusals:', refused, file=DEBUGSTREAM) + + def _deliver(self, mailfrom, rcpttos, data): + import smtplib + refused = {} + try: + s = smtplib.SMTP() + s.connect(self._remoteaddr[0], self._remoteaddr[1]) + try: + refused = s.sendmail(mailfrom, rcpttos, data) + finally: + s.quit() + except smtplib.SMTPRecipientsRefused as e: + print('got SMTPRecipientsRefused', file=DEBUGSTREAM) + refused = e.recipients + except (OSError, smtplib.SMTPException) as e: + print('got', e.__class__, file=DEBUGSTREAM) + # All recipients were refused. If the exception had an associated + # error code, use it. Otherwise,fake it with a non-triggering + # exception code. + errcode = getattr(e, 'smtp_code', -1) + errmsg = getattr(e, 'smtp_error', 'ignore') + for r in rcpttos: + refused[r] = (errcode, errmsg) + return refused + + +class Options: + setuid = True + classname = 'PureProxy' + size_limit = None + enable_SMTPUTF8 = False + + +def parseargs(): + global DEBUGSTREAM + try: + opts, args = getopt.getopt( + sys.argv[1:], 'nVhc:s:du', + ['class=', 'nosetuid', 'version', 'help', 'size=', 'debug', + 'smtputf8']) + except getopt.error as e: + usage(1, e) + + options = Options() + for opt, arg in opts: + if opt in ('-h', '--help'): + usage(0) + elif opt in ('-V', '--version'): + print(__version__) + sys.exit(0) + elif opt in ('-n', '--nosetuid'): + options.setuid = False + elif opt in ('-c', '--class'): + options.classname = arg + elif opt in ('-d', '--debug'): + DEBUGSTREAM = sys.stderr + elif opt in ('-u', '--smtputf8'): + options.enable_SMTPUTF8 = True + elif opt in ('-s', '--size'): + try: + int_size = int(arg) + options.size_limit = int_size + except: + print('Invalid size: ' + arg, file=sys.stderr) + sys.exit(1) + + # parse the rest of the arguments + if len(args) < 1: + localspec = 'localhost:8025' + remotespec = 'localhost:25' + elif len(args) < 2: + localspec = args[0] + remotespec = 'localhost:25' + elif len(args) < 3: + localspec = args[0] + remotespec = args[1] + else: + usage(1, 'Invalid arguments: %s' % COMMASPACE.join(args)) + + # split into host/port pairs + i = localspec.find(':') + if i < 0: + usage(1, 'Bad local spec: %s' % localspec) + options.localhost = localspec[:i] + try: + options.localport = int(localspec[i+1:]) + except ValueError: + usage(1, 'Bad local port: %s' % localspec) + i = remotespec.find(':') + if i < 0: + usage(1, 'Bad remote spec: %s' % remotespec) + options.remotehost = remotespec[:i] + try: + options.remoteport = int(remotespec[i+1:]) + except ValueError: + usage(1, 'Bad remote port: %s' % remotespec) + return options + + +if __name__ == '__main__': + options = parseargs() + # Become nobody + classname = options.classname + if "." in classname: + lastdot = classname.rfind(".") + mod = __import__(classname[:lastdot], globals(), locals(), [""]) + classname = classname[lastdot+1:] + else: + import __main__ as mod + class_ = getattr(mod, classname) + proxy = class_((options.localhost, options.localport), + (options.remotehost, options.remoteport), + options.size_limit, enable_SMTPUTF8=options.enable_SMTPUTF8) + if options.setuid: + try: + import pwd + except ImportError: + print('Cannot import module "pwd"; try running with -n option.', file=sys.stderr) + sys.exit(1) + nobody = pwd.getpwnam('nobody')[2] + try: + os.setuid(nobody) + except PermissionError: + print('Cannot setuid "nobody"; try running with -n option.', file=sys.stderr) + sys.exit(1) + try: + asyncore.loop() + except KeyboardInterrupt: + pass diff --git a/Lib/test/test_logging.py b/Lib/test/test_logging.py new file mode 100644 index 0000000000..8eb03354f8 --- /dev/null +++ b/Lib/test/test_logging.py @@ -0,0 +1,7023 @@ +# Copyright 2001-2022 by Vinay Sajip. All Rights Reserved. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose and without fee is hereby granted, +# provided that the above copyright notice appear in all copies and that +# both that copyright notice and this permission notice appear in +# supporting documentation, and that the name of Vinay Sajip +# not be used in advertising or publicity pertaining to distribution +# of the software without specific, written prior permission. +# VINAY SAJIP DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, INCLUDING +# ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL +# VINAY SAJIP BE LIABLE FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL DAMAGES OR +# ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER +# IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +"""Test harness for the logging module. Run all tests. + +Copyright (C) 2001-2022 Vinay Sajip. All Rights Reserved. +""" +import logging +import logging.handlers +import logging.config + +import codecs +import configparser +import copy +import datetime +import pathlib +import pickle +import io +import itertools +import gc +import json +import os +import queue +import random +import re +import shutil +import socket +import struct +import sys +import tempfile +from test.support.script_helper import assert_python_ok, assert_python_failure +from test import support +from test.support import import_helper +from test.support import os_helper +from test.support import socket_helper +from test.support import threading_helper +from test.support import warnings_helper +from test.support import asyncore +from test.support import smtpd +from test.support.logging_helper import TestHandler +import textwrap +import threading +import asyncio +import time +import unittest +import warnings +import weakref + +from http.server import HTTPServer, BaseHTTPRequestHandler +from unittest.mock import patch +from urllib.parse import urlparse, parse_qs +from socketserver import (ThreadingUDPServer, DatagramRequestHandler, + ThreadingTCPServer, StreamRequestHandler) + +try: + import win32evtlog, win32evtlogutil, pywintypes +except ImportError: + win32evtlog = win32evtlogutil = pywintypes = None + +try: + import zlib +except ImportError: + pass + + +# gh-89363: Skip fork() test if Python is built with Address Sanitizer (ASAN) +# to work around a libasan race condition, dead lock in pthread_create(). +skip_if_asan_fork = unittest.skipIf( + support.HAVE_ASAN_FORK_BUG, + "libasan has a pthread_create() dead lock related to thread+fork") +skip_if_tsan_fork = unittest.skipIf( + support.check_sanitizer(thread=True), + "TSAN doesn't support threads after fork") + + +class BaseTest(unittest.TestCase): + + """Base class for logging tests.""" + + log_format = "%(name)s -> %(levelname)s: %(message)s" + expected_log_pat = r"^([\w.]+) -> (\w+): (\d+)$" + message_num = 0 + + def setUp(self): + """Setup the default logging stream to an internal StringIO instance, + so that we can examine log output as we want.""" + self._threading_key = threading_helper.threading_setup() + + logger_dict = logging.getLogger().manager.loggerDict + logging._acquireLock() + try: + self.saved_handlers = logging._handlers.copy() + self.saved_handler_list = logging._handlerList[:] + self.saved_loggers = saved_loggers = logger_dict.copy() + self.saved_name_to_level = logging._nameToLevel.copy() + self.saved_level_to_name = logging._levelToName.copy() + self.logger_states = logger_states = {} + for name in saved_loggers: + logger_states[name] = getattr(saved_loggers[name], + 'disabled', None) + finally: + logging._releaseLock() + + # Set two unused loggers + self.logger1 = logging.getLogger("\xab\xd7\xbb") + self.logger2 = logging.getLogger("\u013f\u00d6\u0047") + + self.root_logger = logging.getLogger("") + self.original_logging_level = self.root_logger.getEffectiveLevel() + + self.stream = io.StringIO() + self.root_logger.setLevel(logging.DEBUG) + self.root_hdlr = logging.StreamHandler(self.stream) + self.root_formatter = logging.Formatter(self.log_format) + self.root_hdlr.setFormatter(self.root_formatter) + if self.logger1.hasHandlers(): + hlist = self.logger1.handlers + self.root_logger.handlers + raise AssertionError('Unexpected handlers: %s' % hlist) + if self.logger2.hasHandlers(): + hlist = self.logger2.handlers + self.root_logger.handlers + raise AssertionError('Unexpected handlers: %s' % hlist) + self.root_logger.addHandler(self.root_hdlr) + self.assertTrue(self.logger1.hasHandlers()) + self.assertTrue(self.logger2.hasHandlers()) + + def tearDown(self): + """Remove our logging stream, and restore the original logging + level.""" + self.stream.close() + self.root_logger.removeHandler(self.root_hdlr) + while self.root_logger.handlers: + h = self.root_logger.handlers[0] + self.root_logger.removeHandler(h) + h.close() + self.root_logger.setLevel(self.original_logging_level) + logging._acquireLock() + try: + logging._levelToName.clear() + logging._levelToName.update(self.saved_level_to_name) + logging._nameToLevel.clear() + logging._nameToLevel.update(self.saved_name_to_level) + logging._handlers.clear() + logging._handlers.update(self.saved_handlers) + logging._handlerList[:] = self.saved_handler_list + manager = logging.getLogger().manager + manager.disable = 0 + loggerDict = manager.loggerDict + loggerDict.clear() + loggerDict.update(self.saved_loggers) + logger_states = self.logger_states + for name in self.logger_states: + if logger_states[name] is not None: + self.saved_loggers[name].disabled = logger_states[name] + finally: + logging._releaseLock() + + self.doCleanups() + threading_helper.threading_cleanup(*self._threading_key) + + def assert_log_lines(self, expected_values, stream=None, pat=None): + """Match the collected log lines against the regular expression + self.expected_log_pat, and compare the extracted group values to + the expected_values list of tuples.""" + stream = stream or self.stream + pat = re.compile(pat or self.expected_log_pat) + actual_lines = stream.getvalue().splitlines() + self.assertEqual(len(actual_lines), len(expected_values)) + for actual, expected in zip(actual_lines, expected_values): + match = pat.search(actual) + if not match: + self.fail("Log line does not match expected pattern:\n" + + actual) + self.assertEqual(tuple(match.groups()), expected) + s = stream.read() + if s: + self.fail("Remaining output at end of log stream:\n" + s) + + def next_message(self): + """Generate a message consisting solely of an auto-incrementing + integer.""" + self.message_num += 1 + return "%d" % self.message_num + + +class BuiltinLevelsTest(BaseTest): + """Test builtin levels and their inheritance.""" + + def test_flat(self): + # Logging levels in a flat logger namespace. + m = self.next_message + + ERR = logging.getLogger("ERR") + ERR.setLevel(logging.ERROR) + INF = logging.LoggerAdapter(logging.getLogger("INF"), {}) + INF.setLevel(logging.INFO) + DEB = logging.getLogger("DEB") + DEB.setLevel(logging.DEBUG) + + # These should log. + ERR.log(logging.CRITICAL, m()) + ERR.error(m()) + + INF.log(logging.CRITICAL, m()) + INF.error(m()) + INF.warning(m()) + INF.info(m()) + + DEB.log(logging.CRITICAL, m()) + DEB.error(m()) + DEB.warning(m()) + DEB.info(m()) + DEB.debug(m()) + + # These should not log. + ERR.warning(m()) + ERR.info(m()) + ERR.debug(m()) + + INF.debug(m()) + + self.assert_log_lines([ + ('ERR', 'CRITICAL', '1'), + ('ERR', 'ERROR', '2'), + ('INF', 'CRITICAL', '3'), + ('INF', 'ERROR', '4'), + ('INF', 'WARNING', '5'), + ('INF', 'INFO', '6'), + ('DEB', 'CRITICAL', '7'), + ('DEB', 'ERROR', '8'), + ('DEB', 'WARNING', '9'), + ('DEB', 'INFO', '10'), + ('DEB', 'DEBUG', '11'), + ]) + + def test_nested_explicit(self): + # Logging levels in a nested namespace, all explicitly set. + m = self.next_message + + INF = logging.getLogger("INF") + INF.setLevel(logging.INFO) + INF_ERR = logging.getLogger("INF.ERR") + INF_ERR.setLevel(logging.ERROR) + + # These should log. + INF_ERR.log(logging.CRITICAL, m()) + INF_ERR.error(m()) + + # These should not log. + INF_ERR.warning(m()) + INF_ERR.info(m()) + INF_ERR.debug(m()) + + self.assert_log_lines([ + ('INF.ERR', 'CRITICAL', '1'), + ('INF.ERR', 'ERROR', '2'), + ]) + + def test_nested_inherited(self): + # Logging levels in a nested namespace, inherited from parent loggers. + m = self.next_message + + INF = logging.getLogger("INF") + INF.setLevel(logging.INFO) + INF_ERR = logging.getLogger("INF.ERR") + INF_ERR.setLevel(logging.ERROR) + INF_UNDEF = logging.getLogger("INF.UNDEF") + INF_ERR_UNDEF = logging.getLogger("INF.ERR.UNDEF") + UNDEF = logging.getLogger("UNDEF") + + # These should log. + INF_UNDEF.log(logging.CRITICAL, m()) + INF_UNDEF.error(m()) + INF_UNDEF.warning(m()) + INF_UNDEF.info(m()) + INF_ERR_UNDEF.log(logging.CRITICAL, m()) + INF_ERR_UNDEF.error(m()) + + # These should not log. + INF_UNDEF.debug(m()) + INF_ERR_UNDEF.warning(m()) + INF_ERR_UNDEF.info(m()) + INF_ERR_UNDEF.debug(m()) + + self.assert_log_lines([ + ('INF.UNDEF', 'CRITICAL', '1'), + ('INF.UNDEF', 'ERROR', '2'), + ('INF.UNDEF', 'WARNING', '3'), + ('INF.UNDEF', 'INFO', '4'), + ('INF.ERR.UNDEF', 'CRITICAL', '5'), + ('INF.ERR.UNDEF', 'ERROR', '6'), + ]) + + def test_nested_with_virtual_parent(self): + # Logging levels when some parent does not exist yet. + m = self.next_message + + INF = logging.getLogger("INF") + GRANDCHILD = logging.getLogger("INF.BADPARENT.UNDEF") + CHILD = logging.getLogger("INF.BADPARENT") + INF.setLevel(logging.INFO) + + # These should log. + GRANDCHILD.log(logging.FATAL, m()) + GRANDCHILD.info(m()) + CHILD.log(logging.FATAL, m()) + CHILD.info(m()) + + # These should not log. + GRANDCHILD.debug(m()) + CHILD.debug(m()) + + self.assert_log_lines([ + ('INF.BADPARENT.UNDEF', 'CRITICAL', '1'), + ('INF.BADPARENT.UNDEF', 'INFO', '2'), + ('INF.BADPARENT', 'CRITICAL', '3'), + ('INF.BADPARENT', 'INFO', '4'), + ]) + + def test_regression_22386(self): + """See issue #22386 for more information.""" + self.assertEqual(logging.getLevelName('INFO'), logging.INFO) + self.assertEqual(logging.getLevelName(logging.INFO), 'INFO') + + def test_issue27935(self): + fatal = logging.getLevelName('FATAL') + self.assertEqual(fatal, logging.FATAL) + + def test_regression_29220(self): + """See issue #29220 for more information.""" + logging.addLevelName(logging.INFO, '') + self.addCleanup(logging.addLevelName, logging.INFO, 'INFO') + self.assertEqual(logging.getLevelName(logging.INFO), '') + self.assertEqual(logging.getLevelName(logging.NOTSET), 'NOTSET') + self.assertEqual(logging.getLevelName('NOTSET'), logging.NOTSET) + +class BasicFilterTest(BaseTest): + + """Test the bundled Filter class.""" + + def test_filter(self): + # Only messages satisfying the specified criteria pass through the + # filter. + filter_ = logging.Filter("spam.eggs") + handler = self.root_logger.handlers[0] + try: + handler.addFilter(filter_) + spam = logging.getLogger("spam") + spam_eggs = logging.getLogger("spam.eggs") + spam_eggs_fish = logging.getLogger("spam.eggs.fish") + spam_bakedbeans = logging.getLogger("spam.bakedbeans") + + spam.info(self.next_message()) + spam_eggs.info(self.next_message()) # Good. + spam_eggs_fish.info(self.next_message()) # Good. + spam_bakedbeans.info(self.next_message()) + + self.assert_log_lines([ + ('spam.eggs', 'INFO', '2'), + ('spam.eggs.fish', 'INFO', '3'), + ]) + finally: + handler.removeFilter(filter_) + + def test_callable_filter(self): + # Only messages satisfying the specified criteria pass through the + # filter. + + def filterfunc(record): + parts = record.name.split('.') + prefix = '.'.join(parts[:2]) + return prefix == 'spam.eggs' + + handler = self.root_logger.handlers[0] + try: + handler.addFilter(filterfunc) + spam = logging.getLogger("spam") + spam_eggs = logging.getLogger("spam.eggs") + spam_eggs_fish = logging.getLogger("spam.eggs.fish") + spam_bakedbeans = logging.getLogger("spam.bakedbeans") + + spam.info(self.next_message()) + spam_eggs.info(self.next_message()) # Good. + spam_eggs_fish.info(self.next_message()) # Good. + spam_bakedbeans.info(self.next_message()) + + self.assert_log_lines([ + ('spam.eggs', 'INFO', '2'), + ('spam.eggs.fish', 'INFO', '3'), + ]) + finally: + handler.removeFilter(filterfunc) + + def test_empty_filter(self): + f = logging.Filter() + r = logging.makeLogRecord({'name': 'spam.eggs'}) + self.assertTrue(f.filter(r)) + +# +# First, we define our levels. There can be as many as you want - the only +# limitations are that they should be integers, the lowest should be > 0 and +# larger values mean less information being logged. If you need specific +# level values which do not fit into these limitations, you can use a +# mapping dictionary to convert between your application levels and the +# logging system. +# +SILENT = 120 +TACITURN = 119 +TERSE = 118 +EFFUSIVE = 117 +SOCIABLE = 116 +VERBOSE = 115 +TALKATIVE = 114 +GARRULOUS = 113 +CHATTERBOX = 112 +BORING = 111 + +LEVEL_RANGE = range(BORING, SILENT + 1) + +# +# Next, we define names for our levels. You don't need to do this - in which +# case the system will use "Level n" to denote the text for the level. +# +my_logging_levels = { + SILENT : 'Silent', + TACITURN : 'Taciturn', + TERSE : 'Terse', + EFFUSIVE : 'Effusive', + SOCIABLE : 'Sociable', + VERBOSE : 'Verbose', + TALKATIVE : 'Talkative', + GARRULOUS : 'Garrulous', + CHATTERBOX : 'Chatterbox', + BORING : 'Boring', +} + +class GarrulousFilter(logging.Filter): + + """A filter which blocks garrulous messages.""" + + def filter(self, record): + return record.levelno != GARRULOUS + +class VerySpecificFilter(logging.Filter): + + """A filter which blocks sociable and taciturn messages.""" + + def filter(self, record): + return record.levelno not in [SOCIABLE, TACITURN] + + +class CustomLevelsAndFiltersTest(BaseTest): + + """Test various filtering possibilities with custom logging levels.""" + + # Skip the logger name group. + expected_log_pat = r"^[\w.]+ -> (\w+): (\d+)$" + + def setUp(self): + BaseTest.setUp(self) + for k, v in my_logging_levels.items(): + logging.addLevelName(k, v) + + def log_at_all_levels(self, logger): + for lvl in LEVEL_RANGE: + logger.log(lvl, self.next_message()) + + def test_handler_filter_replaces_record(self): + def replace_message(record: logging.LogRecord): + record = copy.copy(record) + record.msg = "new message!" + return record + + # Set up a logging hierarchy such that "child" and it's handler + # (and thus `replace_message()`) always get called before + # propagating up to "parent". + # Then we can confirm that `replace_message()` was able to + # replace the log record without having a side effect on + # other loggers or handlers. + parent = logging.getLogger("parent") + child = logging.getLogger("parent.child") + stream_1 = io.StringIO() + stream_2 = io.StringIO() + handler_1 = logging.StreamHandler(stream_1) + handler_2 = logging.StreamHandler(stream_2) + handler_2.addFilter(replace_message) + parent.addHandler(handler_1) + child.addHandler(handler_2) + + child.info("original message") + handler_1.flush() + handler_2.flush() + self.assertEqual(stream_1.getvalue(), "original message\n") + self.assertEqual(stream_2.getvalue(), "new message!\n") + + def test_logging_filter_replaces_record(self): + records = set() + + class RecordingFilter(logging.Filter): + def filter(self, record: logging.LogRecord): + records.add(id(record)) + return copy.copy(record) + + logger = logging.getLogger("logger") + logger.setLevel(logging.INFO) + logger.addFilter(RecordingFilter()) + logger.addFilter(RecordingFilter()) + + logger.info("msg") + + self.assertEqual(2, len(records)) + + def test_logger_filter(self): + # Filter at logger level. + self.root_logger.setLevel(VERBOSE) + # Levels >= 'Verbose' are good. + self.log_at_all_levels(self.root_logger) + self.assert_log_lines([ + ('Verbose', '5'), + ('Sociable', '6'), + ('Effusive', '7'), + ('Terse', '8'), + ('Taciturn', '9'), + ('Silent', '10'), + ]) + + def test_handler_filter(self): + # Filter at handler level. + self.root_logger.handlers[0].setLevel(SOCIABLE) + try: + # Levels >= 'Sociable' are good. + self.log_at_all_levels(self.root_logger) + self.assert_log_lines([ + ('Sociable', '6'), + ('Effusive', '7'), + ('Terse', '8'), + ('Taciturn', '9'), + ('Silent', '10'), + ]) + finally: + self.root_logger.handlers[0].setLevel(logging.NOTSET) + + def test_specific_filters(self): + # Set a specific filter object on the handler, and then add another + # filter object on the logger itself. + handler = self.root_logger.handlers[0] + specific_filter = None + garr = GarrulousFilter() + handler.addFilter(garr) + try: + self.log_at_all_levels(self.root_logger) + first_lines = [ + # Notice how 'Garrulous' is missing + ('Boring', '1'), + ('Chatterbox', '2'), + ('Talkative', '4'), + ('Verbose', '5'), + ('Sociable', '6'), + ('Effusive', '7'), + ('Terse', '8'), + ('Taciturn', '9'), + ('Silent', '10'), + ] + self.assert_log_lines(first_lines) + + specific_filter = VerySpecificFilter() + self.root_logger.addFilter(specific_filter) + self.log_at_all_levels(self.root_logger) + self.assert_log_lines(first_lines + [ + # Not only 'Garrulous' is still missing, but also 'Sociable' + # and 'Taciturn' + ('Boring', '11'), + ('Chatterbox', '12'), + ('Talkative', '14'), + ('Verbose', '15'), + ('Effusive', '17'), + ('Terse', '18'), + ('Silent', '20'), + ]) + finally: + if specific_filter: + self.root_logger.removeFilter(specific_filter) + handler.removeFilter(garr) + + +def make_temp_file(*args, **kwargs): + fd, fn = tempfile.mkstemp(*args, **kwargs) + os.close(fd) + return fn + + +class HandlerTest(BaseTest): + def test_name(self): + h = logging.Handler() + h.name = 'generic' + self.assertEqual(h.name, 'generic') + h.name = 'anothergeneric' + self.assertEqual(h.name, 'anothergeneric') + self.assertRaises(NotImplementedError, h.emit, None) + + def test_builtin_handlers(self): + # We can't actually *use* too many handlers in the tests, + # but we can try instantiating them with various options + if sys.platform in ('linux', 'darwin'): + for existing in (True, False): + fn = make_temp_file() + if not existing: + os.unlink(fn) + h = logging.handlers.WatchedFileHandler(fn, encoding='utf-8', delay=True) + if existing: + dev, ino = h.dev, h.ino + self.assertEqual(dev, -1) + self.assertEqual(ino, -1) + r = logging.makeLogRecord({'msg': 'Test'}) + h.handle(r) + # Now remove the file. + os.unlink(fn) + self.assertFalse(os.path.exists(fn)) + # The next call should recreate the file. + h.handle(r) + self.assertTrue(os.path.exists(fn)) + else: + self.assertEqual(h.dev, -1) + self.assertEqual(h.ino, -1) + h.close() + if existing: + os.unlink(fn) + if sys.platform == 'darwin': + sockname = '/var/run/syslog' + else: + sockname = '/dev/log' + try: + h = logging.handlers.SysLogHandler(sockname) + self.assertEqual(h.facility, h.LOG_USER) + self.assertTrue(h.unixsocket) + h.close() + except OSError: # syslogd might not be available + pass + for method in ('GET', 'POST', 'PUT'): + if method == 'PUT': + self.assertRaises(ValueError, logging.handlers.HTTPHandler, + 'localhost', '/log', method) + else: + h = logging.handlers.HTTPHandler('localhost', '/log', method) + h.close() + h = logging.handlers.BufferingHandler(0) + r = logging.makeLogRecord({}) + self.assertTrue(h.shouldFlush(r)) + h.close() + h = logging.handlers.BufferingHandler(1) + self.assertFalse(h.shouldFlush(r)) + h.close() + + def test_pathlike_objects(self): + """ + Test that path-like objects are accepted as filename arguments to handlers. + + See Issue #27493. + """ + fn = make_temp_file() + os.unlink(fn) + pfn = os_helper.FakePath(fn) + cases = ( + (logging.FileHandler, (pfn, 'w')), + (logging.handlers.RotatingFileHandler, (pfn, 'a')), + (logging.handlers.TimedRotatingFileHandler, (pfn, 'h')), + ) + if sys.platform in ('linux', 'darwin'): + cases += ((logging.handlers.WatchedFileHandler, (pfn, 'w')),) + for cls, args in cases: + h = cls(*args, encoding="utf-8") + self.assertTrue(os.path.exists(fn)) + h.close() + os.unlink(fn) + + @unittest.skipIf(os.name == 'nt', 'WatchedFileHandler not appropriate for Windows.') + @unittest.skipIf( + support.is_emscripten, "Emscripten cannot fstat unlinked files." + ) + @threading_helper.requires_working_threading() + @support.requires_resource('walltime') + def test_race(self): + # Issue #14632 refers. + def remove_loop(fname, tries): + for _ in range(tries): + try: + os.unlink(fname) + self.deletion_time = time.time() + except OSError: + pass + time.sleep(0.004 * random.randint(0, 4)) + + del_count = 500 + log_count = 500 + + self.handle_time = None + self.deletion_time = None + + for delay in (False, True): + fn = make_temp_file('.log', 'test_logging-3-') + remover = threading.Thread(target=remove_loop, args=(fn, del_count)) + remover.daemon = True + remover.start() + h = logging.handlers.WatchedFileHandler(fn, encoding='utf-8', delay=delay) + f = logging.Formatter('%(asctime)s: %(levelname)s: %(message)s') + h.setFormatter(f) + try: + for _ in range(log_count): + time.sleep(0.005) + r = logging.makeLogRecord({'msg': 'testing' }) + try: + self.handle_time = time.time() + h.handle(r) + except Exception: + print('Deleted at %s, ' + 'opened at %s' % (self.deletion_time, + self.handle_time)) + raise + finally: + remover.join() + h.close() + if os.path.exists(fn): + os.unlink(fn) + + # The implementation relies on os.register_at_fork existing, but we test + # based on os.fork existing because that is what users and this test use. + # This helps ensure that when fork exists (the important concept) that the + # register_at_fork mechanism is also present and used. + @support.requires_fork() + @threading_helper.requires_working_threading() + @skip_if_asan_fork + @skip_if_tsan_fork + def test_post_fork_child_no_deadlock(self): + """Ensure child logging locks are not held; bpo-6721 & bpo-36533.""" + class _OurHandler(logging.Handler): + def __init__(self): + super().__init__() + self.sub_handler = logging.StreamHandler( + stream=open('/dev/null', 'wt', encoding='utf-8')) + + def emit(self, record): + self.sub_handler.acquire() + try: + self.sub_handler.emit(record) + finally: + self.sub_handler.release() + + self.assertEqual(len(logging._handlers), 0) + refed_h = _OurHandler() + self.addCleanup(refed_h.sub_handler.stream.close) + refed_h.name = 'because we need at least one for this test' + self.assertGreater(len(logging._handlers), 0) + self.assertGreater(len(logging._at_fork_reinit_lock_weakset), 1) + test_logger = logging.getLogger('test_post_fork_child_no_deadlock') + test_logger.addHandler(refed_h) + test_logger.setLevel(logging.DEBUG) + + locks_held__ready_to_fork = threading.Event() + fork_happened__release_locks_and_end_thread = threading.Event() + + def lock_holder_thread_fn(): + logging._acquireLock() + try: + refed_h.acquire() + try: + # Tell the main thread to do the fork. + locks_held__ready_to_fork.set() + + # If the deadlock bug exists, the fork will happen + # without dealing with the locks we hold, deadlocking + # the child. + + # Wait for a successful fork or an unreasonable amount of + # time before releasing our locks. To avoid a timing based + # test we'd need communication from os.fork() as to when it + # has actually happened. Given this is a regression test + # for a fixed issue, potentially less reliably detecting + # regression via timing is acceptable for simplicity. + # The test will always take at least this long. :( + fork_happened__release_locks_and_end_thread.wait(0.5) + finally: + refed_h.release() + finally: + logging._releaseLock() + + lock_holder_thread = threading.Thread( + target=lock_holder_thread_fn, + name='test_post_fork_child_no_deadlock lock holder') + lock_holder_thread.start() + + locks_held__ready_to_fork.wait() + pid = os.fork() + if pid == 0: + # Child process + try: + test_logger.info(r'Child process did not deadlock. \o/') + finally: + os._exit(0) + else: + # Parent process + test_logger.info(r'Parent process returned from fork. \o/') + fork_happened__release_locks_and_end_thread.set() + lock_holder_thread.join() + + support.wait_process(pid, exitcode=0) + + +class BadStream(object): + def write(self, data): + raise RuntimeError('deliberate mistake') + +class TestStreamHandler(logging.StreamHandler): + def handleError(self, record): + self.error_record = record + +class StreamWithIntName(object): + level = logging.NOTSET + name = 2 + +class StreamHandlerTest(BaseTest): + def test_error_handling(self): + h = TestStreamHandler(BadStream()) + r = logging.makeLogRecord({}) + old_raise = logging.raiseExceptions + + try: + h.handle(r) + self.assertIs(h.error_record, r) + + h = logging.StreamHandler(BadStream()) + with support.captured_stderr() as stderr: + h.handle(r) + msg = '\nRuntimeError: deliberate mistake\n' + self.assertIn(msg, stderr.getvalue()) + + logging.raiseExceptions = False + with support.captured_stderr() as stderr: + h.handle(r) + self.assertEqual('', stderr.getvalue()) + finally: + logging.raiseExceptions = old_raise + + def test_stream_setting(self): + """ + Test setting the handler's stream + """ + h = logging.StreamHandler() + stream = io.StringIO() + old = h.setStream(stream) + self.assertIs(old, sys.stderr) + actual = h.setStream(old) + self.assertIs(actual, stream) + # test that setting to existing value returns None + actual = h.setStream(old) + self.assertIsNone(actual) + + def test_can_represent_stream_with_int_name(self): + h = logging.StreamHandler(StreamWithIntName()) + self.assertEqual(repr(h), '') + +# -- The following section could be moved into a server_helper.py module +# -- if it proves to be of wider utility than just test_logging + +class TestSMTPServer(smtpd.SMTPServer): + """ + This class implements a test SMTP server. + + :param addr: A (host, port) tuple which the server listens on. + You can specify a port value of zero: the server's + *port* attribute will hold the actual port number + used, which can be used in client connections. + :param handler: A callable which will be called to process + incoming messages. The handler will be passed + the client address tuple, who the message is from, + a list of recipients and the message data. + :param poll_interval: The interval, in seconds, used in the underlying + :func:`select` or :func:`poll` call by + :func:`asyncore.loop`. + :param sockmap: A dictionary which will be used to hold + :class:`asyncore.dispatcher` instances used by + :func:`asyncore.loop`. This avoids changing the + :mod:`asyncore` module's global state. + """ + + def __init__(self, addr, handler, poll_interval, sockmap): + smtpd.SMTPServer.__init__(self, addr, None, map=sockmap, + decode_data=True) + self.port = self.socket.getsockname()[1] + self._handler = handler + self._thread = None + self._quit = False + self.poll_interval = poll_interval + + def process_message(self, peer, mailfrom, rcpttos, data): + """ + Delegates to the handler passed in to the server's constructor. + + Typically, this will be a test case method. + :param peer: The client (host, port) tuple. + :param mailfrom: The address of the sender. + :param rcpttos: The addresses of the recipients. + :param data: The message. + """ + self._handler(peer, mailfrom, rcpttos, data) + + def start(self): + """ + Start the server running on a separate daemon thread. + """ + self._thread = t = threading.Thread(target=self.serve_forever, + args=(self.poll_interval,)) + t.daemon = True + t.start() + + def serve_forever(self, poll_interval): + """ + Run the :mod:`asyncore` loop until normal termination + conditions arise. + :param poll_interval: The interval, in seconds, used in the underlying + :func:`select` or :func:`poll` call by + :func:`asyncore.loop`. + """ + while not self._quit: + asyncore.loop(poll_interval, map=self._map, count=1) + + def stop(self): + """ + Stop the thread by closing the server instance. + Wait for the server thread to terminate. + """ + self._quit = True + threading_helper.join_thread(self._thread) + self._thread = None + self.close() + asyncore.close_all(map=self._map, ignore_all=True) + + +class ControlMixin(object): + """ + This mixin is used to start a server on a separate thread, and + shut it down programmatically. Request handling is simplified - instead + of needing to derive a suitable RequestHandler subclass, you just + provide a callable which will be passed each received request to be + processed. + + :param handler: A handler callable which will be called with a + single parameter - the request - in order to + process the request. This handler is called on the + server thread, effectively meaning that requests are + processed serially. While not quite web scale ;-), + this should be fine for testing applications. + :param poll_interval: The polling interval in seconds. + """ + def __init__(self, handler, poll_interval): + self._thread = None + self.poll_interval = poll_interval + self._handler = handler + self.ready = threading.Event() + + def start(self): + """ + Create a daemon thread to run the server, and start it. + """ + self._thread = t = threading.Thread(target=self.serve_forever, + args=(self.poll_interval,)) + t.daemon = True + t.start() + + def serve_forever(self, poll_interval): + """ + Run the server. Set the ready flag before entering the + service loop. + """ + self.ready.set() + super(ControlMixin, self).serve_forever(poll_interval) + + def stop(self): + """ + Tell the server thread to stop, and wait for it to do so. + """ + self.shutdown() + if self._thread is not None: + threading_helper.join_thread(self._thread) + self._thread = None + self.server_close() + self.ready.clear() + +class TestHTTPServer(ControlMixin, HTTPServer): + """ + An HTTP server which is controllable using :class:`ControlMixin`. + + :param addr: A tuple with the IP address and port to listen on. + :param handler: A handler callable which will be called with a + single parameter - the request - in order to + process the request. + :param poll_interval: The polling interval in seconds. + :param log: Pass ``True`` to enable log messages. + """ + def __init__(self, addr, handler, poll_interval=0.5, + log=False, sslctx=None): + class DelegatingHTTPRequestHandler(BaseHTTPRequestHandler): + def __getattr__(self, name, default=None): + if name.startswith('do_'): + return self.process_request + raise AttributeError(name) + + def process_request(self): + self.server._handler(self) + + def log_message(self, format, *args): + if log: + super(DelegatingHTTPRequestHandler, + self).log_message(format, *args) + HTTPServer.__init__(self, addr, DelegatingHTTPRequestHandler) + ControlMixin.__init__(self, handler, poll_interval) + self.sslctx = sslctx + + def get_request(self): + try: + sock, addr = self.socket.accept() + if self.sslctx: + sock = self.sslctx.wrap_socket(sock, server_side=True) + except OSError as e: + # socket errors are silenced by the caller, print them here + sys.stderr.write("Got an error:\n%s\n" % e) + raise + return sock, addr + +class TestTCPServer(ControlMixin, ThreadingTCPServer): + """ + A TCP server which is controllable using :class:`ControlMixin`. + + :param addr: A tuple with the IP address and port to listen on. + :param handler: A handler callable which will be called with a single + parameter - the request - in order to process the request. + :param poll_interval: The polling interval in seconds. + :bind_and_activate: If True (the default), binds the server and starts it + listening. If False, you need to call + :meth:`server_bind` and :meth:`server_activate` at + some later time before calling :meth:`start`, so that + the server will set up the socket and listen on it. + """ + + allow_reuse_address = True + + def __init__(self, addr, handler, poll_interval=0.5, + bind_and_activate=True): + class DelegatingTCPRequestHandler(StreamRequestHandler): + + def handle(self): + self.server._handler(self) + ThreadingTCPServer.__init__(self, addr, DelegatingTCPRequestHandler, + bind_and_activate) + ControlMixin.__init__(self, handler, poll_interval) + + def server_bind(self): + super(TestTCPServer, self).server_bind() + self.port = self.socket.getsockname()[1] + +class TestUDPServer(ControlMixin, ThreadingUDPServer): + """ + A UDP server which is controllable using :class:`ControlMixin`. + + :param addr: A tuple with the IP address and port to listen on. + :param handler: A handler callable which will be called with a + single parameter - the request - in order to + process the request. + :param poll_interval: The polling interval for shutdown requests, + in seconds. + :bind_and_activate: If True (the default), binds the server and + starts it listening. If False, you need to + call :meth:`server_bind` and + :meth:`server_activate` at some later time + before calling :meth:`start`, so that the server will + set up the socket and listen on it. + """ + def __init__(self, addr, handler, poll_interval=0.5, + bind_and_activate=True): + class DelegatingUDPRequestHandler(DatagramRequestHandler): + + def handle(self): + self.server._handler(self) + + def finish(self): + data = self.wfile.getvalue() + if data: + try: + super(DelegatingUDPRequestHandler, self).finish() + except OSError: + if not self.server._closed: + raise + + ThreadingUDPServer.__init__(self, addr, + DelegatingUDPRequestHandler, + bind_and_activate) + ControlMixin.__init__(self, handler, poll_interval) + self._closed = False + + def server_bind(self): + super(TestUDPServer, self).server_bind() + self.port = self.socket.getsockname()[1] + + def server_close(self): + super(TestUDPServer, self).server_close() + self._closed = True + +if hasattr(socket, "AF_UNIX"): + class TestUnixStreamServer(TestTCPServer): + address_family = socket.AF_UNIX + + class TestUnixDatagramServer(TestUDPServer): + address_family = socket.AF_UNIX + +# - end of server_helper section + +@support.requires_working_socket() +@threading_helper.requires_working_threading() +class SMTPHandlerTest(BaseTest): + # bpo-14314, bpo-19665, bpo-34092: don't wait forever + TIMEOUT = support.LONG_TIMEOUT + + def test_basic(self): + sockmap = {} + server = TestSMTPServer((socket_helper.HOST, 0), self.process_message, 0.001, + sockmap) + server.start() + addr = (socket_helper.HOST, server.port) + h = logging.handlers.SMTPHandler(addr, 'me', 'you', 'Log', + timeout=self.TIMEOUT) + self.assertEqual(h.toaddrs, ['you']) + self.messages = [] + r = logging.makeLogRecord({'msg': 'Hello \u2713'}) + self.handled = threading.Event() + h.handle(r) + self.handled.wait(self.TIMEOUT) + server.stop() + self.assertTrue(self.handled.is_set()) + self.assertEqual(len(self.messages), 1) + peer, mailfrom, rcpttos, data = self.messages[0] + self.assertEqual(mailfrom, 'me') + self.assertEqual(rcpttos, ['you']) + self.assertIn('\nSubject: Log\n', data) + self.assertTrue(data.endswith('\n\nHello \u2713')) + h.close() + + def process_message(self, *args): + self.messages.append(args) + self.handled.set() + +class MemoryHandlerTest(BaseTest): + + """Tests for the MemoryHandler.""" + + # Do not bother with a logger name group. + expected_log_pat = r"^[\w.]+ -> (\w+): (\d+)$" + + def setUp(self): + BaseTest.setUp(self) + self.mem_hdlr = logging.handlers.MemoryHandler(10, logging.WARNING, + self.root_hdlr) + self.mem_logger = logging.getLogger('mem') + self.mem_logger.propagate = 0 + self.mem_logger.addHandler(self.mem_hdlr) + + def tearDown(self): + self.mem_hdlr.close() + BaseTest.tearDown(self) + + def test_flush(self): + # The memory handler flushes to its target handler based on specific + # criteria (message count and message level). + self.mem_logger.debug(self.next_message()) + self.assert_log_lines([]) + self.mem_logger.info(self.next_message()) + self.assert_log_lines([]) + # This will flush because the level is >= logging.WARNING + self.mem_logger.warning(self.next_message()) + lines = [ + ('DEBUG', '1'), + ('INFO', '2'), + ('WARNING', '3'), + ] + self.assert_log_lines(lines) + for n in (4, 14): + for i in range(9): + self.mem_logger.debug(self.next_message()) + self.assert_log_lines(lines) + # This will flush because it's the 10th message since the last + # flush. + self.mem_logger.debug(self.next_message()) + lines = lines + [('DEBUG', str(i)) for i in range(n, n + 10)] + self.assert_log_lines(lines) + + self.mem_logger.debug(self.next_message()) + self.assert_log_lines(lines) + + def test_flush_on_close(self): + """ + Test that the flush-on-close configuration works as expected. + """ + self.mem_logger.debug(self.next_message()) + self.assert_log_lines([]) + self.mem_logger.info(self.next_message()) + self.assert_log_lines([]) + self.mem_logger.removeHandler(self.mem_hdlr) + # Default behaviour is to flush on close. Check that it happens. + self.mem_hdlr.close() + lines = [ + ('DEBUG', '1'), + ('INFO', '2'), + ] + self.assert_log_lines(lines) + # Now configure for flushing not to be done on close. + self.mem_hdlr = logging.handlers.MemoryHandler(10, logging.WARNING, + self.root_hdlr, + False) + self.mem_logger.addHandler(self.mem_hdlr) + self.mem_logger.debug(self.next_message()) + self.assert_log_lines(lines) # no change + self.mem_logger.info(self.next_message()) + self.assert_log_lines(lines) # no change + self.mem_logger.removeHandler(self.mem_hdlr) + self.mem_hdlr.close() + # assert that no new lines have been added + self.assert_log_lines(lines) # no change + + def test_shutdown_flush_on_close(self): + """ + Test that the flush-on-close configuration is respected by the + shutdown method. + """ + self.mem_logger.debug(self.next_message()) + self.assert_log_lines([]) + self.mem_logger.info(self.next_message()) + self.assert_log_lines([]) + # Default behaviour is to flush on close. Check that it happens. + logging.shutdown(handlerList=[logging.weakref.ref(self.mem_hdlr)]) + lines = [ + ('DEBUG', '1'), + ('INFO', '2'), + ] + self.assert_log_lines(lines) + # Now configure for flushing not to be done on close. + self.mem_hdlr = logging.handlers.MemoryHandler(10, logging.WARNING, + self.root_hdlr, + False) + self.mem_logger.addHandler(self.mem_hdlr) + self.mem_logger.debug(self.next_message()) + self.assert_log_lines(lines) # no change + self.mem_logger.info(self.next_message()) + self.assert_log_lines(lines) # no change + # assert that no new lines have been added after shutdown + logging.shutdown(handlerList=[logging.weakref.ref(self.mem_hdlr)]) + self.assert_log_lines(lines) # no change + + @threading_helper.requires_working_threading() + def test_race_between_set_target_and_flush(self): + class MockRaceConditionHandler: + def __init__(self, mem_hdlr): + self.mem_hdlr = mem_hdlr + self.threads = [] + + def removeTarget(self): + self.mem_hdlr.setTarget(None) + + def handle(self, msg): + thread = threading.Thread(target=self.removeTarget) + self.threads.append(thread) + thread.start() + + target = MockRaceConditionHandler(self.mem_hdlr) + try: + self.mem_hdlr.setTarget(target) + + for _ in range(10): + time.sleep(0.005) + self.mem_logger.info("not flushed") + self.mem_logger.warning("flushed") + finally: + for thread in target.threads: + threading_helper.join_thread(thread) + + +class ExceptionFormatter(logging.Formatter): + """A special exception formatter.""" + def formatException(self, ei): + return "Got a [%s]" % ei[0].__name__ + +def closeFileHandler(h, fn): + h.close() + os.remove(fn) + +class ConfigFileTest(BaseTest): + + """Reading logging config from a .ini-style config file.""" + + check_no_resource_warning = warnings_helper.check_no_resource_warning + expected_log_pat = r"^(\w+) \+\+ (\w+)$" + + # config0 is a standard configuration. + config0 = """ + [loggers] + keys=root + + [handlers] + keys=hand1 + + [formatters] + keys=form1 + + [logger_root] + level=WARNING + handlers=hand1 + + [handler_hand1] + class=StreamHandler + level=NOTSET + formatter=form1 + args=(sys.stdout,) + + [formatter_form1] + format=%(levelname)s ++ %(message)s + datefmt= + """ + + # config1 adds a little to the standard configuration. + config1 = """ + [loggers] + keys=root,parser + + [handlers] + keys=hand1 + + [formatters] + keys=form1 + + [logger_root] + level=WARNING + handlers= + + [logger_parser] + level=DEBUG + handlers=hand1 + propagate=1 + qualname=compiler.parser + + [handler_hand1] + class=StreamHandler + level=NOTSET + formatter=form1 + args=(sys.stdout,) + + [formatter_form1] + format=%(levelname)s ++ %(message)s + datefmt= + """ + + # config1a moves the handler to the root. + config1a = """ + [loggers] + keys=root,parser + + [handlers] + keys=hand1 + + [formatters] + keys=form1 + + [logger_root] + level=WARNING + handlers=hand1 + + [logger_parser] + level=DEBUG + handlers= + propagate=1 + qualname=compiler.parser + + [handler_hand1] + class=StreamHandler + level=NOTSET + formatter=form1 + args=(sys.stdout,) + + [formatter_form1] + format=%(levelname)s ++ %(message)s + datefmt= + """ + + # config2 has a subtle configuration error that should be reported + config2 = config1.replace("sys.stdout", "sys.stbout") + + # config3 has a less subtle configuration error + config3 = config1.replace("formatter=form1", "formatter=misspelled_name") + + # config4 specifies a custom formatter class to be loaded + config4 = """ + [loggers] + keys=root + + [handlers] + keys=hand1 + + [formatters] + keys=form1 + + [logger_root] + level=NOTSET + handlers=hand1 + + [handler_hand1] + class=StreamHandler + level=NOTSET + formatter=form1 + args=(sys.stdout,) + + [formatter_form1] + class=""" + __name__ + """.ExceptionFormatter + format=%(levelname)s:%(name)s:%(message)s + datefmt= + """ + + # config5 specifies a custom handler class to be loaded + config5 = config1.replace('class=StreamHandler', 'class=logging.StreamHandler') + + # config6 uses ', ' delimiters in the handlers and formatters sections + config6 = """ + [loggers] + keys=root,parser + + [handlers] + keys=hand1, hand2 + + [formatters] + keys=form1, form2 + + [logger_root] + level=WARNING + handlers= + + [logger_parser] + level=DEBUG + handlers=hand1 + propagate=1 + qualname=compiler.parser + + [handler_hand1] + class=StreamHandler + level=NOTSET + formatter=form1 + args=(sys.stdout,) + + [handler_hand2] + class=StreamHandler + level=NOTSET + formatter=form1 + args=(sys.stderr,) + + [formatter_form1] + format=%(levelname)s ++ %(message)s + datefmt= + + [formatter_form2] + format=%(message)s + datefmt= + """ + + # config7 adds a compiler logger, and uses kwargs instead of args. + config7 = """ + [loggers] + keys=root,parser,compiler + + [handlers] + keys=hand1 + + [formatters] + keys=form1 + + [logger_root] + level=WARNING + handlers=hand1 + + [logger_compiler] + level=DEBUG + handlers= + propagate=1 + qualname=compiler + + [logger_parser] + level=DEBUG + handlers= + propagate=1 + qualname=compiler.parser + + [handler_hand1] + class=StreamHandler + level=NOTSET + formatter=form1 + kwargs={'stream': sys.stdout,} + + [formatter_form1] + format=%(levelname)s ++ %(message)s + datefmt= + """ + + # config 8, check for resource warning + config8 = r""" + [loggers] + keys=root + + [handlers] + keys=file + + [formatters] + keys= + + [logger_root] + level=DEBUG + handlers=file + + [handler_file] + class=FileHandler + level=DEBUG + args=("{tempfile}",) + kwargs={{"encoding": "utf-8"}} + """ + + + config9 = """ + [loggers] + keys=root + + [handlers] + keys=hand1 + + [formatters] + keys=form1 + + [logger_root] + level=WARNING + handlers=hand1 + + [handler_hand1] + class=StreamHandler + level=NOTSET + formatter=form1 + args=(sys.stdout,) + + [formatter_form1] + format=%(message)s ++ %(customfield)s + defaults={"customfield": "defaultvalue"} + """ + + disable_test = """ + [loggers] + keys=root + + [handlers] + keys=screen + + [formatters] + keys= + + [logger_root] + level=DEBUG + handlers=screen + + [handler_screen] + level=DEBUG + class=StreamHandler + args=(sys.stdout,) + formatter= + """ + + def apply_config(self, conf, **kwargs): + file = io.StringIO(textwrap.dedent(conf)) + logging.config.fileConfig(file, encoding="utf-8", **kwargs) + + def test_config0_ok(self): + # A simple config file which overrides the default settings. + with support.captured_stdout() as output: + self.apply_config(self.config0) + logger = logging.getLogger() + # Won't output anything + logger.info(self.next_message()) + # Outputs a message + logger.error(self.next_message()) + self.assert_log_lines([ + ('ERROR', '2'), + ], stream=output) + # Original logger output is empty. + self.assert_log_lines([]) + + def test_config0_using_cp_ok(self): + # A simple config file which overrides the default settings. + with support.captured_stdout() as output: + file = io.StringIO(textwrap.dedent(self.config0)) + cp = configparser.ConfigParser() + cp.read_file(file) + logging.config.fileConfig(cp) + logger = logging.getLogger() + # Won't output anything + logger.info(self.next_message()) + # Outputs a message + logger.error(self.next_message()) + self.assert_log_lines([ + ('ERROR', '2'), + ], stream=output) + # Original logger output is empty. + self.assert_log_lines([]) + + def test_config1_ok(self, config=config1): + # A config file defining a sub-parser as well. + with support.captured_stdout() as output: + self.apply_config(config) + logger = logging.getLogger("compiler.parser") + # Both will output a message + logger.info(self.next_message()) + logger.error(self.next_message()) + self.assert_log_lines([ + ('INFO', '1'), + ('ERROR', '2'), + ], stream=output) + # Original logger output is empty. + self.assert_log_lines([]) + + def test_config2_failure(self): + # A simple config file which overrides the default settings. + self.assertRaises(Exception, self.apply_config, self.config2) + + def test_config3_failure(self): + # A simple config file which overrides the default settings. + self.assertRaises(Exception, self.apply_config, self.config3) + + def test_config4_ok(self): + # A config file specifying a custom formatter class. + with support.captured_stdout() as output: + self.apply_config(self.config4) + logger = logging.getLogger() + try: + raise RuntimeError() + except RuntimeError: + logging.exception("just testing") + sys.stdout.seek(0) + self.assertEqual(output.getvalue(), + "ERROR:root:just testing\nGot a [RuntimeError]\n") + # Original logger output is empty + self.assert_log_lines([]) + + def test_config5_ok(self): + self.test_config1_ok(config=self.config5) + + def test_config6_ok(self): + self.test_config1_ok(config=self.config6) + + def test_config7_ok(self): + with support.captured_stdout() as output: + self.apply_config(self.config1a) + logger = logging.getLogger("compiler.parser") + # See issue #11424. compiler-hyphenated sorts + # between compiler and compiler.xyz and this + # was preventing compiler.xyz from being included + # in the child loggers of compiler because of an + # overzealous loop termination condition. + hyphenated = logging.getLogger('compiler-hyphenated') + # All will output a message + logger.info(self.next_message()) + logger.error(self.next_message()) + hyphenated.critical(self.next_message()) + self.assert_log_lines([ + ('INFO', '1'), + ('ERROR', '2'), + ('CRITICAL', '3'), + ], stream=output) + # Original logger output is empty. + self.assert_log_lines([]) + with support.captured_stdout() as output: + self.apply_config(self.config7) + logger = logging.getLogger("compiler.parser") + self.assertFalse(logger.disabled) + # Both will output a message + logger.info(self.next_message()) + logger.error(self.next_message()) + logger = logging.getLogger("compiler.lexer") + # Both will output a message + logger.info(self.next_message()) + logger.error(self.next_message()) + # Will not appear + hyphenated.critical(self.next_message()) + self.assert_log_lines([ + ('INFO', '4'), + ('ERROR', '5'), + ('INFO', '6'), + ('ERROR', '7'), + ], stream=output) + # Original logger output is empty. + self.assert_log_lines([]) + + def test_config8_ok(self): + + with self.check_no_resource_warning(): + fn = make_temp_file(".log", "test_logging-X-") + + # Replace single backslash with double backslash in windows + # to avoid unicode error during string formatting + if os.name == "nt": + fn = fn.replace("\\", "\\\\") + + config8 = self.config8.format(tempfile=fn) + + self.apply_config(config8) + self.apply_config(config8) + + handler = logging.root.handlers[0] + self.addCleanup(closeFileHandler, handler, fn) + + def test_config9_ok(self): + self.apply_config(self.config9) + formatter = logging.root.handlers[0].formatter + result = formatter.format(logging.makeLogRecord({'msg': 'test'})) + self.assertEqual(result, 'test ++ defaultvalue') + result = formatter.format(logging.makeLogRecord( + {'msg': 'test', 'customfield': "customvalue"})) + self.assertEqual(result, 'test ++ customvalue') + + + def test_logger_disabling(self): + self.apply_config(self.disable_test) + logger = logging.getLogger('some_pristine_logger') + self.assertFalse(logger.disabled) + self.apply_config(self.disable_test) + self.assertTrue(logger.disabled) + self.apply_config(self.disable_test, disable_existing_loggers=False) + self.assertFalse(logger.disabled) + + def test_config_set_handler_names(self): + test_config = """ + [loggers] + keys=root + + [handlers] + keys=hand1 + + [formatters] + keys=form1 + + [logger_root] + handlers=hand1 + + [handler_hand1] + class=StreamHandler + formatter=form1 + + [formatter_form1] + format=%(levelname)s ++ %(message)s + """ + self.apply_config(test_config) + self.assertEqual(logging.getLogger().handlers[0].name, 'hand1') + + def test_exception_if_confg_file_is_invalid(self): + test_config = """ + [loggers] + keys=root + + [handlers] + keys=hand1 + + [formatters] + keys=form1 + + [logger_root] + handlers=hand1 + + [handler_hand1] + class=StreamHandler + formatter=form1 + + [formatter_form1] + format=%(levelname)s ++ %(message)s + + prince + """ + + file = io.StringIO(textwrap.dedent(test_config)) + self.assertRaises(RuntimeError, logging.config.fileConfig, file) + + def test_exception_if_confg_file_is_empty(self): + fd, fn = tempfile.mkstemp(prefix='test_empty_', suffix='.ini') + os.close(fd) + self.assertRaises(RuntimeError, logging.config.fileConfig, fn) + os.remove(fn) + + def test_exception_if_config_file_does_not_exist(self): + self.assertRaises(FileNotFoundError, logging.config.fileConfig, 'filenotfound') + + def test_defaults_do_no_interpolation(self): + """bpo-33802 defaults should not get interpolated""" + ini = textwrap.dedent(""" + [formatters] + keys=default + + [formatter_default] + + [handlers] + keys=console + + [handler_console] + class=logging.StreamHandler + args=tuple() + + [loggers] + keys=root + + [logger_root] + formatter=default + handlers=console + """).strip() + fd, fn = tempfile.mkstemp(prefix='test_logging_', suffix='.ini') + try: + os.write(fd, ini.encode('ascii')) + os.close(fd) + logging.config.fileConfig( + fn, + encoding="utf-8", + defaults=dict( + version=1, + disable_existing_loggers=False, + formatters={ + "generic": { + "format": "%(asctime)s [%(process)d] [%(levelname)s] %(message)s", + "datefmt": "[%Y-%m-%d %H:%M:%S %z]", + "class": "logging.Formatter" + }, + }, + ) + ) + finally: + os.unlink(fn) + + +@support.requires_working_socket() +@threading_helper.requires_working_threading() +class SocketHandlerTest(BaseTest): + + """Test for SocketHandler objects.""" + + server_class = TestTCPServer + address = ('localhost', 0) + + def setUp(self): + """Set up a TCP server to receive log messages, and a SocketHandler + pointing to that server's address and port.""" + BaseTest.setUp(self) + # Issue #29177: deal with errors that happen during setup + self.server = self.sock_hdlr = self.server_exception = None + try: + self.server = server = self.server_class(self.address, + self.handle_socket, 0.01) + server.start() + # Uncomment next line to test error recovery in setUp() + # raise OSError('dummy error raised') + except OSError as e: + self.server_exception = e + return + server.ready.wait() + hcls = logging.handlers.SocketHandler + if isinstance(server.server_address, tuple): + self.sock_hdlr = hcls('localhost', server.port) + else: + self.sock_hdlr = hcls(server.server_address, None) + self.log_output = '' + self.root_logger.removeHandler(self.root_logger.handlers[0]) + self.root_logger.addHandler(self.sock_hdlr) + self.handled = threading.Semaphore(0) + + def tearDown(self): + """Shutdown the TCP server.""" + try: + if self.sock_hdlr: + self.root_logger.removeHandler(self.sock_hdlr) + self.sock_hdlr.close() + if self.server: + self.server.stop() + finally: + BaseTest.tearDown(self) + + def handle_socket(self, request): + conn = request.connection + while True: + chunk = conn.recv(4) + if len(chunk) < 4: + break + slen = struct.unpack(">L", chunk)[0] + chunk = conn.recv(slen) + while len(chunk) < slen: + chunk = chunk + conn.recv(slen - len(chunk)) + obj = pickle.loads(chunk) + record = logging.makeLogRecord(obj) + self.log_output += record.msg + '\n' + self.handled.release() + + def test_output(self): + # The log message sent to the SocketHandler is properly received. + if self.server_exception: + self.skipTest(self.server_exception) + logger = logging.getLogger("tcp") + logger.error("spam") + self.handled.acquire() + logger.debug("eggs") + self.handled.acquire() + self.assertEqual(self.log_output, "spam\neggs\n") + + def test_noserver(self): + if self.server_exception: + self.skipTest(self.server_exception) + # Avoid timing-related failures due to SocketHandler's own hard-wired + # one-second timeout on socket.create_connection() (issue #16264). + self.sock_hdlr.retryStart = 2.5 + # Kill the server + self.server.stop() + # The logging call should try to connect, which should fail + try: + raise RuntimeError('Deliberate mistake') + except RuntimeError: + self.root_logger.exception('Never sent') + self.root_logger.error('Never sent, either') + now = time.time() + self.assertGreater(self.sock_hdlr.retryTime, now) + time.sleep(self.sock_hdlr.retryTime - now + 0.001) + self.root_logger.error('Nor this') + + +@unittest.skipUnless(hasattr(socket, "AF_UNIX"), "Unix sockets required") +class UnixSocketHandlerTest(SocketHandlerTest): + + """Test for SocketHandler with unix sockets.""" + + if hasattr(socket, "AF_UNIX"): + server_class = TestUnixStreamServer + + def setUp(self): + # override the definition in the base class + self.address = socket_helper.create_unix_domain_name() + self.addCleanup(os_helper.unlink, self.address) + SocketHandlerTest.setUp(self) + +@support.requires_working_socket() +@threading_helper.requires_working_threading() +class DatagramHandlerTest(BaseTest): + + """Test for DatagramHandler.""" + + server_class = TestUDPServer + address = ('localhost', 0) + + def setUp(self): + """Set up a UDP server to receive log messages, and a DatagramHandler + pointing to that server's address and port.""" + BaseTest.setUp(self) + # Issue #29177: deal with errors that happen during setup + self.server = self.sock_hdlr = self.server_exception = None + try: + self.server = server = self.server_class(self.address, + self.handle_datagram, 0.01) + server.start() + # Uncomment next line to test error recovery in setUp() + # raise OSError('dummy error raised') + except OSError as e: + self.server_exception = e + return + server.ready.wait() + hcls = logging.handlers.DatagramHandler + if isinstance(server.server_address, tuple): + self.sock_hdlr = hcls('localhost', server.port) + else: + self.sock_hdlr = hcls(server.server_address, None) + self.log_output = '' + self.root_logger.removeHandler(self.root_logger.handlers[0]) + self.root_logger.addHandler(self.sock_hdlr) + self.handled = threading.Event() + + def tearDown(self): + """Shutdown the UDP server.""" + try: + if self.server: + self.server.stop() + if self.sock_hdlr: + self.root_logger.removeHandler(self.sock_hdlr) + self.sock_hdlr.close() + finally: + BaseTest.tearDown(self) + + def handle_datagram(self, request): + slen = struct.pack('>L', 0) # length of prefix + packet = request.packet[len(slen):] + obj = pickle.loads(packet) + record = logging.makeLogRecord(obj) + self.log_output += record.msg + '\n' + self.handled.set() + + def test_output(self): + # The log message sent to the DatagramHandler is properly received. + if self.server_exception: + self.skipTest(self.server_exception) + logger = logging.getLogger("udp") + logger.error("spam") + self.handled.wait() + self.handled.clear() + logger.error("eggs") + self.handled.wait() + self.assertEqual(self.log_output, "spam\neggs\n") + +@unittest.skipUnless(hasattr(socket, "AF_UNIX"), "Unix sockets required") +class UnixDatagramHandlerTest(DatagramHandlerTest): + + """Test for DatagramHandler using Unix sockets.""" + + if hasattr(socket, "AF_UNIX"): + server_class = TestUnixDatagramServer + + def setUp(self): + # override the definition in the base class + self.address = socket_helper.create_unix_domain_name() + self.addCleanup(os_helper.unlink, self.address) + DatagramHandlerTest.setUp(self) + +@support.requires_working_socket() +@threading_helper.requires_working_threading() +class SysLogHandlerTest(BaseTest): + + """Test for SysLogHandler using UDP.""" + + server_class = TestUDPServer + address = ('localhost', 0) + + def setUp(self): + """Set up a UDP server to receive log messages, and a SysLogHandler + pointing to that server's address and port.""" + BaseTest.setUp(self) + # Issue #29177: deal with errors that happen during setup + self.server = self.sl_hdlr = self.server_exception = None + try: + self.server = server = self.server_class(self.address, + self.handle_datagram, 0.01) + server.start() + # Uncomment next line to test error recovery in setUp() + # raise OSError('dummy error raised') + except OSError as e: + self.server_exception = e + return + server.ready.wait() + hcls = logging.handlers.SysLogHandler + if isinstance(server.server_address, tuple): + self.sl_hdlr = hcls((server.server_address[0], server.port)) + else: + self.sl_hdlr = hcls(server.server_address) + self.log_output = b'' + self.root_logger.removeHandler(self.root_logger.handlers[0]) + self.root_logger.addHandler(self.sl_hdlr) + self.handled = threading.Event() + + def tearDown(self): + """Shutdown the server.""" + try: + if self.server: + self.server.stop() + if self.sl_hdlr: + self.root_logger.removeHandler(self.sl_hdlr) + self.sl_hdlr.close() + finally: + BaseTest.tearDown(self) + + def handle_datagram(self, request): + self.log_output = request.packet + self.handled.set() + + def test_output(self): + if self.server_exception: + self.skipTest(self.server_exception) + # The log message sent to the SysLogHandler is properly received. + logger = logging.getLogger("slh") + logger.error("sp\xe4m") + self.handled.wait(support.LONG_TIMEOUT) + self.assertEqual(self.log_output, b'<11>sp\xc3\xa4m\x00') + self.handled.clear() + self.sl_hdlr.append_nul = False + logger.error("sp\xe4m") + self.handled.wait(support.LONG_TIMEOUT) + self.assertEqual(self.log_output, b'<11>sp\xc3\xa4m') + self.handled.clear() + self.sl_hdlr.ident = "h\xe4m-" + logger.error("sp\xe4m") + self.handled.wait(support.LONG_TIMEOUT) + self.assertEqual(self.log_output, b'<11>h\xc3\xa4m-sp\xc3\xa4m') + + def test_udp_reconnection(self): + logger = logging.getLogger("slh") + self.sl_hdlr.close() + self.handled.clear() + logger.error("sp\xe4m") + self.handled.wait(support.LONG_TIMEOUT) + self.assertEqual(self.log_output, b'<11>sp\xc3\xa4m\x00') + +@unittest.skipUnless(hasattr(socket, "AF_UNIX"), "Unix sockets required") +class UnixSysLogHandlerTest(SysLogHandlerTest): + + """Test for SysLogHandler with Unix sockets.""" + + if hasattr(socket, "AF_UNIX"): + server_class = TestUnixDatagramServer + + def setUp(self): + # override the definition in the base class + self.address = socket_helper.create_unix_domain_name() + self.addCleanup(os_helper.unlink, self.address) + SysLogHandlerTest.setUp(self) + +@unittest.skipUnless(socket_helper.IPV6_ENABLED, + 'IPv6 support required for this test.') +class IPv6SysLogHandlerTest(SysLogHandlerTest): + + """Test for SysLogHandler with IPv6 host.""" + + server_class = TestUDPServer + address = ('::1', 0) + + def setUp(self): + self.server_class.address_family = socket.AF_INET6 + super(IPv6SysLogHandlerTest, self).setUp() + + def tearDown(self): + self.server_class.address_family = socket.AF_INET + super(IPv6SysLogHandlerTest, self).tearDown() + +@support.requires_working_socket() +@threading_helper.requires_working_threading() +class HTTPHandlerTest(BaseTest): + """Test for HTTPHandler.""" + + def setUp(self): + """Set up an HTTP server to receive log messages, and a HTTPHandler + pointing to that server's address and port.""" + BaseTest.setUp(self) + self.handled = threading.Event() + + def handle_request(self, request): + self.command = request.command + self.log_data = urlparse(request.path) + if self.command == 'POST': + try: + rlen = int(request.headers['Content-Length']) + self.post_data = request.rfile.read(rlen) + except: + self.post_data = None + request.send_response(200) + request.end_headers() + self.handled.set() + + def test_output(self): + # The log message sent to the HTTPHandler is properly received. + logger = logging.getLogger("http") + root_logger = self.root_logger + root_logger.removeHandler(self.root_logger.handlers[0]) + for secure in (False, True): + addr = ('localhost', 0) + if secure: + try: + import ssl + except ImportError: + sslctx = None + else: + here = os.path.dirname(__file__) + localhost_cert = os.path.join(here, "certdata", "keycert.pem") + sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + sslctx.load_cert_chain(localhost_cert) + + context = ssl.create_default_context(cafile=localhost_cert) + else: + sslctx = None + context = None + self.server = server = TestHTTPServer(addr, self.handle_request, + 0.01, sslctx=sslctx) + server.start() + server.ready.wait() + host = 'localhost:%d' % server.server_port + secure_client = secure and sslctx + self.h_hdlr = logging.handlers.HTTPHandler(host, '/frob', + secure=secure_client, + context=context, + credentials=('foo', 'bar')) + self.log_data = None + root_logger.addHandler(self.h_hdlr) + + for method in ('GET', 'POST'): + self.h_hdlr.method = method + self.handled.clear() + msg = "sp\xe4m" + logger.error(msg) + self.handled.wait() + self.assertEqual(self.log_data.path, '/frob') + self.assertEqual(self.command, method) + if method == 'GET': + d = parse_qs(self.log_data.query) + else: + d = parse_qs(self.post_data.decode('utf-8')) + self.assertEqual(d['name'], ['http']) + self.assertEqual(d['funcName'], ['test_output']) + self.assertEqual(d['msg'], [msg]) + + self.server.stop() + self.root_logger.removeHandler(self.h_hdlr) + self.h_hdlr.close() + +class MemoryTest(BaseTest): + + """Test memory persistence of logger objects.""" + + def setUp(self): + """Create a dict to remember potentially destroyed objects.""" + BaseTest.setUp(self) + self._survivors = {} + + def _watch_for_survival(self, *args): + """Watch the given objects for survival, by creating weakrefs to + them.""" + for obj in args: + key = id(obj), repr(obj) + self._survivors[key] = weakref.ref(obj) + + def _assertTruesurvival(self): + """Assert that all objects watched for survival have survived.""" + # Trigger cycle breaking. + gc.collect() + dead = [] + for (id_, repr_), ref in self._survivors.items(): + if ref() is None: + dead.append(repr_) + if dead: + self.fail("%d objects should have survived " + "but have been destroyed: %s" % (len(dead), ", ".join(dead))) + + def test_persistent_loggers(self): + # Logger objects are persistent and retain their configuration, even + # if visible references are destroyed. + self.root_logger.setLevel(logging.INFO) + foo = logging.getLogger("foo") + self._watch_for_survival(foo) + foo.setLevel(logging.DEBUG) + self.root_logger.debug(self.next_message()) + foo.debug(self.next_message()) + self.assert_log_lines([ + ('foo', 'DEBUG', '2'), + ]) + del foo + # foo has survived. + self._assertTruesurvival() + # foo has retained its settings. + bar = logging.getLogger("foo") + bar.debug(self.next_message()) + self.assert_log_lines([ + ('foo', 'DEBUG', '2'), + ('foo', 'DEBUG', '3'), + ]) + + +class EncodingTest(BaseTest): + def test_encoding_plain_file(self): + # In Python 2.x, a plain file object is treated as having no encoding. + log = logging.getLogger("test") + fn = make_temp_file(".log", "test_logging-1-") + # the non-ascii data we write to the log. + data = "foo\x80" + try: + handler = logging.FileHandler(fn, encoding="utf-8") + log.addHandler(handler) + try: + # write non-ascii data to the log. + log.warning(data) + finally: + log.removeHandler(handler) + handler.close() + # check we wrote exactly those bytes, ignoring trailing \n etc + f = open(fn, encoding="utf-8") + try: + self.assertEqual(f.read().rstrip(), data) + finally: + f.close() + finally: + if os.path.isfile(fn): + os.remove(fn) + + def test_encoding_cyrillic_unicode(self): + log = logging.getLogger("test") + # Get a message in Unicode: Do svidanya in Cyrillic (meaning goodbye) + message = '\u0434\u043e \u0441\u0432\u0438\u0434\u0430\u043d\u0438\u044f' + # Ensure it's written in a Cyrillic encoding + writer_class = codecs.getwriter('cp1251') + writer_class.encoding = 'cp1251' + stream = io.BytesIO() + writer = writer_class(stream, 'strict') + handler = logging.StreamHandler(writer) + log.addHandler(handler) + try: + log.warning(message) + finally: + log.removeHandler(handler) + handler.close() + # check we wrote exactly those bytes, ignoring trailing \n etc + s = stream.getvalue() + # Compare against what the data should be when encoded in CP-1251 + self.assertEqual(s, b'\xe4\xee \xf1\xe2\xe8\xe4\xe0\xed\xe8\xff\n') + + +class WarningsTest(BaseTest): + + def test_warnings(self): + with warnings.catch_warnings(): + logging.captureWarnings(True) + self.addCleanup(logging.captureWarnings, False) + warnings.filterwarnings("always", category=UserWarning) + stream = io.StringIO() + h = logging.StreamHandler(stream) + logger = logging.getLogger("py.warnings") + logger.addHandler(h) + warnings.warn("I'm warning you...") + logger.removeHandler(h) + s = stream.getvalue() + h.close() + self.assertGreater(s.find("UserWarning: I'm warning you...\n"), 0) + + # See if an explicit file uses the original implementation + a_file = io.StringIO() + warnings.showwarning("Explicit", UserWarning, "dummy.py", 42, + a_file, "Dummy line") + s = a_file.getvalue() + a_file.close() + self.assertEqual(s, + "dummy.py:42: UserWarning: Explicit\n Dummy line\n") + + def test_warnings_no_handlers(self): + with warnings.catch_warnings(): + logging.captureWarnings(True) + self.addCleanup(logging.captureWarnings, False) + + # confirm our assumption: no loggers are set + logger = logging.getLogger("py.warnings") + self.assertEqual(logger.handlers, []) + + warnings.showwarning("Explicit", UserWarning, "dummy.py", 42) + self.assertEqual(len(logger.handlers), 1) + self.assertIsInstance(logger.handlers[0], logging.NullHandler) + + +def formatFunc(format, datefmt=None): + return logging.Formatter(format, datefmt) + +class myCustomFormatter: + def __init__(self, fmt, datefmt=None): + pass + +def handlerFunc(): + return logging.StreamHandler() + +class CustomHandler(logging.StreamHandler): + pass + +class CustomListener(logging.handlers.QueueListener): + pass + +class CustomQueue(queue.Queue): + pass + +class CustomQueueProtocol: + def __init__(self, maxsize=0): + self.queue = queue.Queue(maxsize) + + def __getattr__(self, attribute): + queue = object.__getattribute__(self, 'queue') + return getattr(queue, attribute) + +class CustomQueueFakeProtocol(CustomQueueProtocol): + # An object implementing the minimial Queue API for + # the logging module but with incorrect signatures. + # + # The object will be considered a valid queue class since we + # do not check the signatures (only callability of methods) + # but will NOT be usable in production since a TypeError will + # be raised due to the extra argument in 'put_nowait'. + def put_nowait(self): + pass + +class CustomQueueWrongProtocol(CustomQueueProtocol): + put_nowait = None + +class MinimalQueueProtocol: + def put_nowait(self, x): pass + def get(self): pass + +def queueMaker(): + return queue.Queue() + +def listenerMaker(arg1, arg2, respect_handler_level=False): + def func(queue, *handlers, **kwargs): + kwargs.setdefault('respect_handler_level', respect_handler_level) + return CustomListener(queue, *handlers, **kwargs) + return func + +class ConfigDictTest(BaseTest): + + """Reading logging config from a dictionary.""" + + check_no_resource_warning = warnings_helper.check_no_resource_warning + expected_log_pat = r"^(\w+) \+\+ (\w+)$" + + # config0 is a standard configuration. + config0 = { + 'version': 1, + 'formatters': { + 'form1' : { + 'format' : '%(levelname)s ++ %(message)s', + }, + }, + 'handlers' : { + 'hand1' : { + 'class' : 'logging.StreamHandler', + 'formatter' : 'form1', + 'level' : 'NOTSET', + 'stream' : 'ext://sys.stdout', + }, + }, + 'root' : { + 'level' : 'WARNING', + 'handlers' : ['hand1'], + }, + } + + # config1 adds a little to the standard configuration. + config1 = { + 'version': 1, + 'formatters': { + 'form1' : { + 'format' : '%(levelname)s ++ %(message)s', + }, + }, + 'handlers' : { + 'hand1' : { + 'class' : 'logging.StreamHandler', + 'formatter' : 'form1', + 'level' : 'NOTSET', + 'stream' : 'ext://sys.stdout', + }, + }, + 'loggers' : { + 'compiler.parser' : { + 'level' : 'DEBUG', + 'handlers' : ['hand1'], + }, + }, + 'root' : { + 'level' : 'WARNING', + }, + } + + # config1a moves the handler to the root. Used with config8a + config1a = { + 'version': 1, + 'formatters': { + 'form1' : { + 'format' : '%(levelname)s ++ %(message)s', + }, + }, + 'handlers' : { + 'hand1' : { + 'class' : 'logging.StreamHandler', + 'formatter' : 'form1', + 'level' : 'NOTSET', + 'stream' : 'ext://sys.stdout', + }, + }, + 'loggers' : { + 'compiler.parser' : { + 'level' : 'DEBUG', + }, + }, + 'root' : { + 'level' : 'WARNING', + 'handlers' : ['hand1'], + }, + } + + # config2 has a subtle configuration error that should be reported + config2 = { + 'version': 1, + 'formatters': { + 'form1' : { + 'format' : '%(levelname)s ++ %(message)s', + }, + }, + 'handlers' : { + 'hand1' : { + 'class' : 'logging.StreamHandler', + 'formatter' : 'form1', + 'level' : 'NOTSET', + 'stream' : 'ext://sys.stdbout', + }, + }, + 'loggers' : { + 'compiler.parser' : { + 'level' : 'DEBUG', + 'handlers' : ['hand1'], + }, + }, + 'root' : { + 'level' : 'WARNING', + }, + } + + # As config1 but with a misspelt level on a handler + config2a = { + 'version': 1, + 'formatters': { + 'form1' : { + 'format' : '%(levelname)s ++ %(message)s', + }, + }, + 'handlers' : { + 'hand1' : { + 'class' : 'logging.StreamHandler', + 'formatter' : 'form1', + 'level' : 'NTOSET', + 'stream' : 'ext://sys.stdout', + }, + }, + 'loggers' : { + 'compiler.parser' : { + 'level' : 'DEBUG', + 'handlers' : ['hand1'], + }, + }, + 'root' : { + 'level' : 'WARNING', + }, + } + + + # As config1 but with a misspelt level on a logger + config2b = { + 'version': 1, + 'formatters': { + 'form1' : { + 'format' : '%(levelname)s ++ %(message)s', + }, + }, + 'handlers' : { + 'hand1' : { + 'class' : 'logging.StreamHandler', + 'formatter' : 'form1', + 'level' : 'NOTSET', + 'stream' : 'ext://sys.stdout', + }, + }, + 'loggers' : { + 'compiler.parser' : { + 'level' : 'DEBUG', + 'handlers' : ['hand1'], + }, + }, + 'root' : { + 'level' : 'WRANING', + }, + } + + # config3 has a less subtle configuration error + config3 = { + 'version': 1, + 'formatters': { + 'form1' : { + 'format' : '%(levelname)s ++ %(message)s', + }, + }, + 'handlers' : { + 'hand1' : { + 'class' : 'logging.StreamHandler', + 'formatter' : 'misspelled_name', + 'level' : 'NOTSET', + 'stream' : 'ext://sys.stdout', + }, + }, + 'loggers' : { + 'compiler.parser' : { + 'level' : 'DEBUG', + 'handlers' : ['hand1'], + }, + }, + 'root' : { + 'level' : 'WARNING', + }, + } + + # config4 specifies a custom formatter class to be loaded + config4 = { + 'version': 1, + 'formatters': { + 'form1' : { + '()' : __name__ + '.ExceptionFormatter', + 'format' : '%(levelname)s:%(name)s:%(message)s', + }, + }, + 'handlers' : { + 'hand1' : { + 'class' : 'logging.StreamHandler', + 'formatter' : 'form1', + 'level' : 'NOTSET', + 'stream' : 'ext://sys.stdout', + }, + }, + 'root' : { + 'level' : 'NOTSET', + 'handlers' : ['hand1'], + }, + } + + # As config4 but using an actual callable rather than a string + config4a = { + 'version': 1, + 'formatters': { + 'form1' : { + '()' : ExceptionFormatter, + 'format' : '%(levelname)s:%(name)s:%(message)s', + }, + 'form2' : { + '()' : __name__ + '.formatFunc', + 'format' : '%(levelname)s:%(name)s:%(message)s', + }, + 'form3' : { + '()' : formatFunc, + 'format' : '%(levelname)s:%(name)s:%(message)s', + }, + }, + 'handlers' : { + 'hand1' : { + 'class' : 'logging.StreamHandler', + 'formatter' : 'form1', + 'level' : 'NOTSET', + 'stream' : 'ext://sys.stdout', + }, + 'hand2' : { + '()' : handlerFunc, + }, + }, + 'root' : { + 'level' : 'NOTSET', + 'handlers' : ['hand1'], + }, + } + + # config5 specifies a custom handler class to be loaded + config5 = { + 'version': 1, + 'formatters': { + 'form1' : { + 'format' : '%(levelname)s ++ %(message)s', + }, + }, + 'handlers' : { + 'hand1' : { + 'class' : __name__ + '.CustomHandler', + 'formatter' : 'form1', + 'level' : 'NOTSET', + 'stream' : 'ext://sys.stdout', + }, + }, + 'loggers' : { + 'compiler.parser' : { + 'level' : 'DEBUG', + 'handlers' : ['hand1'], + }, + }, + 'root' : { + 'level' : 'WARNING', + }, + } + + # config6 specifies a custom handler class to be loaded + # but has bad arguments + config6 = { + 'version': 1, + 'formatters': { + 'form1' : { + 'format' : '%(levelname)s ++ %(message)s', + }, + }, + 'handlers' : { + 'hand1' : { + 'class' : __name__ + '.CustomHandler', + 'formatter' : 'form1', + 'level' : 'NOTSET', + 'stream' : 'ext://sys.stdout', + '9' : 'invalid parameter name', + }, + }, + 'loggers' : { + 'compiler.parser' : { + 'level' : 'DEBUG', + 'handlers' : ['hand1'], + }, + }, + 'root' : { + 'level' : 'WARNING', + }, + } + + # config 7 does not define compiler.parser but defines compiler.lexer + # so compiler.parser should be disabled after applying it + config7 = { + 'version': 1, + 'formatters': { + 'form1' : { + 'format' : '%(levelname)s ++ %(message)s', + }, + }, + 'handlers' : { + 'hand1' : { + 'class' : 'logging.StreamHandler', + 'formatter' : 'form1', + 'level' : 'NOTSET', + 'stream' : 'ext://sys.stdout', + }, + }, + 'loggers' : { + 'compiler.lexer' : { + 'level' : 'DEBUG', + 'handlers' : ['hand1'], + }, + }, + 'root' : { + 'level' : 'WARNING', + }, + } + + # config8 defines both compiler and compiler.lexer + # so compiler.parser should not be disabled (since + # compiler is defined) + config8 = { + 'version': 1, + 'disable_existing_loggers' : False, + 'formatters': { + 'form1' : { + 'format' : '%(levelname)s ++ %(message)s', + }, + }, + 'handlers' : { + 'hand1' : { + 'class' : 'logging.StreamHandler', + 'formatter' : 'form1', + 'level' : 'NOTSET', + 'stream' : 'ext://sys.stdout', + }, + }, + 'loggers' : { + 'compiler' : { + 'level' : 'DEBUG', + 'handlers' : ['hand1'], + }, + 'compiler.lexer' : { + }, + }, + 'root' : { + 'level' : 'WARNING', + }, + } + + # config8a disables existing loggers + config8a = { + 'version': 1, + 'disable_existing_loggers' : True, + 'formatters': { + 'form1' : { + 'format' : '%(levelname)s ++ %(message)s', + }, + }, + 'handlers' : { + 'hand1' : { + 'class' : 'logging.StreamHandler', + 'formatter' : 'form1', + 'level' : 'NOTSET', + 'stream' : 'ext://sys.stdout', + }, + }, + 'loggers' : { + 'compiler' : { + 'level' : 'DEBUG', + 'handlers' : ['hand1'], + }, + 'compiler.lexer' : { + }, + }, + 'root' : { + 'level' : 'WARNING', + }, + } + + config9 = { + 'version': 1, + 'formatters': { + 'form1' : { + 'format' : '%(levelname)s ++ %(message)s', + }, + }, + 'handlers' : { + 'hand1' : { + 'class' : 'logging.StreamHandler', + 'formatter' : 'form1', + 'level' : 'WARNING', + 'stream' : 'ext://sys.stdout', + }, + }, + 'loggers' : { + 'compiler.parser' : { + 'level' : 'WARNING', + 'handlers' : ['hand1'], + }, + }, + 'root' : { + 'level' : 'NOTSET', + }, + } + + config9a = { + 'version': 1, + 'incremental' : True, + 'handlers' : { + 'hand1' : { + 'level' : 'WARNING', + }, + }, + 'loggers' : { + 'compiler.parser' : { + 'level' : 'INFO', + }, + }, + } + + config9b = { + 'version': 1, + 'incremental' : True, + 'handlers' : { + 'hand1' : { + 'level' : 'INFO', + }, + }, + 'loggers' : { + 'compiler.parser' : { + 'level' : 'INFO', + }, + }, + } + + # As config1 but with a filter added + config10 = { + 'version': 1, + 'formatters': { + 'form1' : { + 'format' : '%(levelname)s ++ %(message)s', + }, + }, + 'filters' : { + 'filt1' : { + 'name' : 'compiler.parser', + }, + }, + 'handlers' : { + 'hand1' : { + 'class' : 'logging.StreamHandler', + 'formatter' : 'form1', + 'level' : 'NOTSET', + 'stream' : 'ext://sys.stdout', + 'filters' : ['filt1'], + }, + }, + 'loggers' : { + 'compiler.parser' : { + 'level' : 'DEBUG', + 'filters' : ['filt1'], + }, + }, + 'root' : { + 'level' : 'WARNING', + 'handlers' : ['hand1'], + }, + } + + # As config1 but using cfg:// references + config11 = { + 'version': 1, + 'true_formatters': { + 'form1' : { + 'format' : '%(levelname)s ++ %(message)s', + }, + }, + 'handler_configs': { + 'hand1' : { + 'class' : 'logging.StreamHandler', + 'formatter' : 'form1', + 'level' : 'NOTSET', + 'stream' : 'ext://sys.stdout', + }, + }, + 'formatters' : 'cfg://true_formatters', + 'handlers' : { + 'hand1' : 'cfg://handler_configs[hand1]', + }, + 'loggers' : { + 'compiler.parser' : { + 'level' : 'DEBUG', + 'handlers' : ['hand1'], + }, + }, + 'root' : { + 'level' : 'WARNING', + }, + } + + # As config11 but missing the version key + config12 = { + 'true_formatters': { + 'form1' : { + 'format' : '%(levelname)s ++ %(message)s', + }, + }, + 'handler_configs': { + 'hand1' : { + 'class' : 'logging.StreamHandler', + 'formatter' : 'form1', + 'level' : 'NOTSET', + 'stream' : 'ext://sys.stdout', + }, + }, + 'formatters' : 'cfg://true_formatters', + 'handlers' : { + 'hand1' : 'cfg://handler_configs[hand1]', + }, + 'loggers' : { + 'compiler.parser' : { + 'level' : 'DEBUG', + 'handlers' : ['hand1'], + }, + }, + 'root' : { + 'level' : 'WARNING', + }, + } + + # As config11 but using an unsupported version + config13 = { + 'version': 2, + 'true_formatters': { + 'form1' : { + 'format' : '%(levelname)s ++ %(message)s', + }, + }, + 'handler_configs': { + 'hand1' : { + 'class' : 'logging.StreamHandler', + 'formatter' : 'form1', + 'level' : 'NOTSET', + 'stream' : 'ext://sys.stdout', + }, + }, + 'formatters' : 'cfg://true_formatters', + 'handlers' : { + 'hand1' : 'cfg://handler_configs[hand1]', + }, + 'loggers' : { + 'compiler.parser' : { + 'level' : 'DEBUG', + 'handlers' : ['hand1'], + }, + }, + 'root' : { + 'level' : 'WARNING', + }, + } + + # As config0, but with properties + config14 = { + 'version': 1, + 'formatters': { + 'form1' : { + 'format' : '%(levelname)s ++ %(message)s', + }, + }, + 'handlers' : { + 'hand1' : { + 'class' : 'logging.StreamHandler', + 'formatter' : 'form1', + 'level' : 'NOTSET', + 'stream' : 'ext://sys.stdout', + '.': { + 'foo': 'bar', + 'terminator': '!\n', + } + }, + }, + 'root' : { + 'level' : 'WARNING', + 'handlers' : ['hand1'], + }, + } + + # config0 but with default values for formatter. Skipped 15, it is defined + # in the test code. + config16 = { + 'version': 1, + 'formatters': { + 'form1' : { + 'format' : '%(message)s ++ %(customfield)s', + 'defaults': {"customfield": "defaultvalue"} + }, + }, + 'handlers' : { + 'hand1' : { + 'class' : 'logging.StreamHandler', + 'formatter' : 'form1', + 'level' : 'NOTSET', + 'stream' : 'ext://sys.stdout', + }, + }, + 'root' : { + 'level' : 'WARNING', + 'handlers' : ['hand1'], + }, + } + + class CustomFormatter(logging.Formatter): + custom_property = "." + + def format(self, record): + return super().format(record) + + config17 = { + 'version': 1, + 'formatters': { + "custom": { + "()": CustomFormatter, + "style": "{", + "datefmt": "%Y-%m-%d %H:%M:%S", + "format": "{message}", # <-- to force an exception when configuring + ".": { + "custom_property": "value" + } + } + }, + 'handlers' : { + 'hand1' : { + 'class' : 'logging.StreamHandler', + 'formatter' : 'custom', + 'level' : 'NOTSET', + 'stream' : 'ext://sys.stdout', + }, + }, + 'root' : { + 'level' : 'WARNING', + 'handlers' : ['hand1'], + }, + } + + config18 = { + "version": 1, + "handlers": { + "console": { + "class": "logging.StreamHandler", + "level": "DEBUG", + }, + "buffering": { + "class": "logging.handlers.MemoryHandler", + "capacity": 5, + "target": "console", + "level": "DEBUG", + "flushLevel": "ERROR" + } + }, + "loggers": { + "mymodule": { + "level": "DEBUG", + "handlers": ["buffering"], + "propagate": "true" + } + } + } + + bad_format = { + "version": 1, + "formatters": { + "mySimpleFormatter": { + "format": "%(asctime)s (%(name)s) %(levelname)s: %(message)s", + "style": "$" + } + }, + "handlers": { + "fileGlobal": { + "class": "logging.StreamHandler", + "level": "DEBUG", + "formatter": "mySimpleFormatter" + }, + "bufferGlobal": { + "class": "logging.handlers.MemoryHandler", + "capacity": 5, + "formatter": "mySimpleFormatter", + "target": "fileGlobal", + "level": "DEBUG" + } + }, + "loggers": { + "mymodule": { + "level": "DEBUG", + "handlers": ["bufferGlobal"], + "propagate": "true" + } + } + } + + # Configuration with custom logging.Formatter subclass as '()' key and 'validate' set to False + custom_formatter_class_validate = { + 'version': 1, + 'formatters': { + 'form1': { + '()': __name__ + '.ExceptionFormatter', + 'format': '%(levelname)s:%(name)s:%(message)s', + 'validate': False, + }, + }, + 'handlers' : { + 'hand1' : { + 'class': 'logging.StreamHandler', + 'formatter': 'form1', + 'level': 'NOTSET', + 'stream': 'ext://sys.stdout', + }, + }, + "loggers": { + "my_test_logger_custom_formatter": { + "level": "DEBUG", + "handlers": ["hand1"], + "propagate": "true" + } + } + } + + # Configuration with custom logging.Formatter subclass as 'class' key and 'validate' set to False + custom_formatter_class_validate2 = { + 'version': 1, + 'formatters': { + 'form1': { + 'class': __name__ + '.ExceptionFormatter', + 'format': '%(levelname)s:%(name)s:%(message)s', + 'validate': False, + }, + }, + 'handlers' : { + 'hand1' : { + 'class': 'logging.StreamHandler', + 'formatter': 'form1', + 'level': 'NOTSET', + 'stream': 'ext://sys.stdout', + }, + }, + "loggers": { + "my_test_logger_custom_formatter": { + "level": "DEBUG", + "handlers": ["hand1"], + "propagate": "true" + } + } + } + + # Configuration with custom class that is not inherited from logging.Formatter + custom_formatter_class_validate3 = { + 'version': 1, + 'formatters': { + 'form1': { + 'class': __name__ + '.myCustomFormatter', + 'format': '%(levelname)s:%(name)s:%(message)s', + 'validate': False, + }, + }, + 'handlers' : { + 'hand1' : { + 'class': 'logging.StreamHandler', + 'formatter': 'form1', + 'level': 'NOTSET', + 'stream': 'ext://sys.stdout', + }, + }, + "loggers": { + "my_test_logger_custom_formatter": { + "level": "DEBUG", + "handlers": ["hand1"], + "propagate": "true" + } + } + } + + # Configuration with custom function, 'validate' set to False and no defaults + custom_formatter_with_function = { + 'version': 1, + 'formatters': { + 'form1': { + '()': formatFunc, + 'format': '%(levelname)s:%(name)s:%(message)s', + 'validate': False, + }, + }, + 'handlers' : { + 'hand1' : { + 'class': 'logging.StreamHandler', + 'formatter': 'form1', + 'level': 'NOTSET', + 'stream': 'ext://sys.stdout', + }, + }, + "loggers": { + "my_test_logger_custom_formatter": { + "level": "DEBUG", + "handlers": ["hand1"], + "propagate": "true" + } + } + } + + # Configuration with custom function, and defaults + custom_formatter_with_defaults = { + 'version': 1, + 'formatters': { + 'form1': { + '()': formatFunc, + 'format': '%(levelname)s:%(name)s:%(message)s:%(customfield)s', + 'defaults': {"customfield": "myvalue"} + }, + }, + 'handlers' : { + 'hand1' : { + 'class': 'logging.StreamHandler', + 'formatter': 'form1', + 'level': 'NOTSET', + 'stream': 'ext://sys.stdout', + }, + }, + "loggers": { + "my_test_logger_custom_formatter": { + "level": "DEBUG", + "handlers": ["hand1"], + "propagate": "true" + } + } + } + + config_queue_handler = { + 'version': 1, + 'handlers' : { + 'h1' : { + 'class': 'logging.FileHandler', + }, + # key is before depended on handlers to test that deferred config works + 'ah' : { + 'class': 'logging.handlers.QueueHandler', + 'handlers': ['h1'] + }, + }, + "root": { + "level": "DEBUG", + "handlers": ["ah"] + } + } + + def apply_config(self, conf): + logging.config.dictConfig(conf) + + def check_handler(self, name, cls): + h = logging.getHandlerByName(name) + self.assertIsInstance(h, cls) + + def test_config0_ok(self): + # A simple config which overrides the default settings. + with support.captured_stdout() as output: + self.apply_config(self.config0) + self.check_handler('hand1', logging.StreamHandler) + logger = logging.getLogger() + # Won't output anything + logger.info(self.next_message()) + # Outputs a message + logger.error(self.next_message()) + self.assert_log_lines([ + ('ERROR', '2'), + ], stream=output) + # Original logger output is empty. + self.assert_log_lines([]) + + def test_config1_ok(self, config=config1): + # A config defining a sub-parser as well. + with support.captured_stdout() as output: + self.apply_config(config) + logger = logging.getLogger("compiler.parser") + # Both will output a message + logger.info(self.next_message()) + logger.error(self.next_message()) + self.assert_log_lines([ + ('INFO', '1'), + ('ERROR', '2'), + ], stream=output) + # Original logger output is empty. + self.assert_log_lines([]) + + def test_config2_failure(self): + # A simple config which overrides the default settings. + self.assertRaises(Exception, self.apply_config, self.config2) + + def test_config2a_failure(self): + # A simple config which overrides the default settings. + self.assertRaises(Exception, self.apply_config, self.config2a) + + def test_config2b_failure(self): + # A simple config which overrides the default settings. + self.assertRaises(Exception, self.apply_config, self.config2b) + + def test_config3_failure(self): + # A simple config which overrides the default settings. + self.assertRaises(Exception, self.apply_config, self.config3) + + def test_config4_ok(self): + # A config specifying a custom formatter class. + with support.captured_stdout() as output: + self.apply_config(self.config4) + self.check_handler('hand1', logging.StreamHandler) + #logger = logging.getLogger() + try: + raise RuntimeError() + except RuntimeError: + logging.exception("just testing") + sys.stdout.seek(0) + self.assertEqual(output.getvalue(), + "ERROR:root:just testing\nGot a [RuntimeError]\n") + # Original logger output is empty + self.assert_log_lines([]) + + def test_config4a_ok(self): + # A config specifying a custom formatter class. + with support.captured_stdout() as output: + self.apply_config(self.config4a) + #logger = logging.getLogger() + try: + raise RuntimeError() + except RuntimeError: + logging.exception("just testing") + sys.stdout.seek(0) + self.assertEqual(output.getvalue(), + "ERROR:root:just testing\nGot a [RuntimeError]\n") + # Original logger output is empty + self.assert_log_lines([]) + + def test_config5_ok(self): + self.test_config1_ok(config=self.config5) + self.check_handler('hand1', CustomHandler) + + def test_config6_failure(self): + self.assertRaises(Exception, self.apply_config, self.config6) + + def test_config7_ok(self): + with support.captured_stdout() as output: + self.apply_config(self.config1) + logger = logging.getLogger("compiler.parser") + # Both will output a message + logger.info(self.next_message()) + logger.error(self.next_message()) + self.assert_log_lines([ + ('INFO', '1'), + ('ERROR', '2'), + ], stream=output) + # Original logger output is empty. + self.assert_log_lines([]) + with support.captured_stdout() as output: + self.apply_config(self.config7) + self.check_handler('hand1', logging.StreamHandler) + logger = logging.getLogger("compiler.parser") + self.assertTrue(logger.disabled) + logger = logging.getLogger("compiler.lexer") + # Both will output a message + logger.info(self.next_message()) + logger.error(self.next_message()) + self.assert_log_lines([ + ('INFO', '3'), + ('ERROR', '4'), + ], stream=output) + # Original logger output is empty. + self.assert_log_lines([]) + + # Same as test_config_7_ok but don't disable old loggers. + def test_config_8_ok(self): + with support.captured_stdout() as output: + self.apply_config(self.config1) + logger = logging.getLogger("compiler.parser") + # All will output a message + logger.info(self.next_message()) + logger.error(self.next_message()) + self.assert_log_lines([ + ('INFO', '1'), + ('ERROR', '2'), + ], stream=output) + # Original logger output is empty. + self.assert_log_lines([]) + with support.captured_stdout() as output: + self.apply_config(self.config8) + self.check_handler('hand1', logging.StreamHandler) + logger = logging.getLogger("compiler.parser") + self.assertFalse(logger.disabled) + # Both will output a message + logger.info(self.next_message()) + logger.error(self.next_message()) + logger = logging.getLogger("compiler.lexer") + # Both will output a message + logger.info(self.next_message()) + logger.error(self.next_message()) + self.assert_log_lines([ + ('INFO', '3'), + ('ERROR', '4'), + ('INFO', '5'), + ('ERROR', '6'), + ], stream=output) + # Original logger output is empty. + self.assert_log_lines([]) + + def test_config_8a_ok(self): + with support.captured_stdout() as output: + self.apply_config(self.config1a) + self.check_handler('hand1', logging.StreamHandler) + logger = logging.getLogger("compiler.parser") + # See issue #11424. compiler-hyphenated sorts + # between compiler and compiler.xyz and this + # was preventing compiler.xyz from being included + # in the child loggers of compiler because of an + # overzealous loop termination condition. + hyphenated = logging.getLogger('compiler-hyphenated') + # All will output a message + logger.info(self.next_message()) + logger.error(self.next_message()) + hyphenated.critical(self.next_message()) + self.assert_log_lines([ + ('INFO', '1'), + ('ERROR', '2'), + ('CRITICAL', '3'), + ], stream=output) + # Original logger output is empty. + self.assert_log_lines([]) + with support.captured_stdout() as output: + self.apply_config(self.config8a) + self.check_handler('hand1', logging.StreamHandler) + logger = logging.getLogger("compiler.parser") + self.assertFalse(logger.disabled) + # Both will output a message + logger.info(self.next_message()) + logger.error(self.next_message()) + logger = logging.getLogger("compiler.lexer") + # Both will output a message + logger.info(self.next_message()) + logger.error(self.next_message()) + # Will not appear + hyphenated.critical(self.next_message()) + self.assert_log_lines([ + ('INFO', '4'), + ('ERROR', '5'), + ('INFO', '6'), + ('ERROR', '7'), + ], stream=output) + # Original logger output is empty. + self.assert_log_lines([]) + + def test_config_9_ok(self): + with support.captured_stdout() as output: + self.apply_config(self.config9) + self.check_handler('hand1', logging.StreamHandler) + logger = logging.getLogger("compiler.parser") + # Nothing will be output since both handler and logger are set to WARNING + logger.info(self.next_message()) + self.assert_log_lines([], stream=output) + self.apply_config(self.config9a) + # Nothing will be output since handler is still set to WARNING + logger.info(self.next_message()) + self.assert_log_lines([], stream=output) + self.apply_config(self.config9b) + # Message should now be output + logger.info(self.next_message()) + self.assert_log_lines([ + ('INFO', '3'), + ], stream=output) + + def test_config_10_ok(self): + with support.captured_stdout() as output: + self.apply_config(self.config10) + self.check_handler('hand1', logging.StreamHandler) + logger = logging.getLogger("compiler.parser") + logger.warning(self.next_message()) + logger = logging.getLogger('compiler') + # Not output, because filtered + logger.warning(self.next_message()) + logger = logging.getLogger('compiler.lexer') + # Not output, because filtered + logger.warning(self.next_message()) + logger = logging.getLogger("compiler.parser.codegen") + # Output, as not filtered + logger.error(self.next_message()) + self.assert_log_lines([ + ('WARNING', '1'), + ('ERROR', '4'), + ], stream=output) + + def test_config11_ok(self): + self.test_config1_ok(self.config11) + + def test_config12_failure(self): + self.assertRaises(Exception, self.apply_config, self.config12) + + def test_config13_failure(self): + self.assertRaises(Exception, self.apply_config, self.config13) + + def test_config14_ok(self): + with support.captured_stdout() as output: + self.apply_config(self.config14) + h = logging._handlers['hand1'] + self.assertEqual(h.foo, 'bar') + self.assertEqual(h.terminator, '!\n') + logging.warning('Exclamation') + self.assertTrue(output.getvalue().endswith('Exclamation!\n')) + + def test_config15_ok(self): + + with self.check_no_resource_warning(): + fn = make_temp_file(".log", "test_logging-X-") + + config = { + "version": 1, + "handlers": { + "file": { + "class": "logging.FileHandler", + "filename": fn, + "encoding": "utf-8", + } + }, + "root": { + "handlers": ["file"] + } + } + + self.apply_config(config) + self.apply_config(config) + + handler = logging.root.handlers[0] + self.addCleanup(closeFileHandler, handler, fn) + + def test_config16_ok(self): + self.apply_config(self.config16) + h = logging._handlers['hand1'] + + # Custom value + result = h.formatter.format(logging.makeLogRecord( + {'msg': 'Hello', 'customfield': 'customvalue'})) + self.assertEqual(result, 'Hello ++ customvalue') + + # Default value + result = h.formatter.format(logging.makeLogRecord( + {'msg': 'Hello'})) + self.assertEqual(result, 'Hello ++ defaultvalue') + + def test_config17_ok(self): + self.apply_config(self.config17) + h = logging._handlers['hand1'] + self.assertEqual(h.formatter.custom_property, 'value') + + def test_config18_ok(self): + self.apply_config(self.config18) + handler = logging.getLogger('mymodule').handlers[0] + self.assertEqual(handler.flushLevel, logging.ERROR) + + def setup_via_listener(self, text, verify=None): + text = text.encode("utf-8") + # Ask for a randomly assigned port (by using port 0) + t = logging.config.listen(0, verify) + t.start() + t.ready.wait() + # Now get the port allocated + port = t.port + t.ready.clear() + try: + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.settimeout(2.0) + sock.connect(('localhost', port)) + + slen = struct.pack('>L', len(text)) + s = slen + text + sentsofar = 0 + left = len(s) + while left > 0: + sent = sock.send(s[sentsofar:]) + sentsofar += sent + left -= sent + sock.close() + finally: + t.ready.wait(2.0) + logging.config.stopListening() + threading_helper.join_thread(t) + + @support.requires_working_socket() + def test_listen_config_10_ok(self): + with support.captured_stdout() as output: + self.setup_via_listener(json.dumps(self.config10)) + self.check_handler('hand1', logging.StreamHandler) + logger = logging.getLogger("compiler.parser") + logger.warning(self.next_message()) + logger = logging.getLogger('compiler') + # Not output, because filtered + logger.warning(self.next_message()) + logger = logging.getLogger('compiler.lexer') + # Not output, because filtered + logger.warning(self.next_message()) + logger = logging.getLogger("compiler.parser.codegen") + # Output, as not filtered + logger.error(self.next_message()) + self.assert_log_lines([ + ('WARNING', '1'), + ('ERROR', '4'), + ], stream=output) + + @support.requires_working_socket() + def test_listen_config_1_ok(self): + with support.captured_stdout() as output: + self.setup_via_listener(textwrap.dedent(ConfigFileTest.config1)) + logger = logging.getLogger("compiler.parser") + # Both will output a message + logger.info(self.next_message()) + logger.error(self.next_message()) + self.assert_log_lines([ + ('INFO', '1'), + ('ERROR', '2'), + ], stream=output) + # Original logger output is empty. + self.assert_log_lines([]) + + @support.requires_working_socket() + def test_listen_verify(self): + + def verify_fail(stuff): + return None + + def verify_reverse(stuff): + return stuff[::-1] + + logger = logging.getLogger("compiler.parser") + to_send = textwrap.dedent(ConfigFileTest.config1) + # First, specify a verification function that will fail. + # We expect to see no output, since our configuration + # never took effect. + with support.captured_stdout() as output: + self.setup_via_listener(to_send, verify_fail) + # Both will output a message + logger.info(self.next_message()) + logger.error(self.next_message()) + self.assert_log_lines([], stream=output) + # Original logger output has the stuff we logged. + self.assert_log_lines([ + ('INFO', '1'), + ('ERROR', '2'), + ], pat=r"^[\w.]+ -> (\w+): (\d+)$") + + # Now, perform no verification. Our configuration + # should take effect. + + with support.captured_stdout() as output: + self.setup_via_listener(to_send) # no verify callable specified + logger = logging.getLogger("compiler.parser") + # Both will output a message + logger.info(self.next_message()) + logger.error(self.next_message()) + self.assert_log_lines([ + ('INFO', '3'), + ('ERROR', '4'), + ], stream=output) + # Original logger output still has the stuff we logged before. + self.assert_log_lines([ + ('INFO', '1'), + ('ERROR', '2'), + ], pat=r"^[\w.]+ -> (\w+): (\d+)$") + + # Now, perform verification which transforms the bytes. + + with support.captured_stdout() as output: + self.setup_via_listener(to_send[::-1], verify_reverse) + logger = logging.getLogger("compiler.parser") + # Both will output a message + logger.info(self.next_message()) + logger.error(self.next_message()) + self.assert_log_lines([ + ('INFO', '5'), + ('ERROR', '6'), + ], stream=output) + # Original logger output still has the stuff we logged before. + self.assert_log_lines([ + ('INFO', '1'), + ('ERROR', '2'), + ], pat=r"^[\w.]+ -> (\w+): (\d+)$") + + def test_bad_format(self): + self.assertRaises(ValueError, self.apply_config, self.bad_format) + + def test_bad_format_with_dollar_style(self): + config = copy.deepcopy(self.bad_format) + config['formatters']['mySimpleFormatter']['format'] = "${asctime} (${name}) ${levelname}: ${message}" + + self.apply_config(config) + handler = logging.getLogger('mymodule').handlers[0] + self.assertIsInstance(handler.target, logging.Handler) + self.assertIsInstance(handler.formatter._style, + logging.StringTemplateStyle) + self.assertEqual(sorted(logging.getHandlerNames()), + ['bufferGlobal', 'fileGlobal']) + + def test_custom_formatter_class_with_validate(self): + self.apply_config(self.custom_formatter_class_validate) + handler = logging.getLogger("my_test_logger_custom_formatter").handlers[0] + self.assertIsInstance(handler.formatter, ExceptionFormatter) + + def test_custom_formatter_class_with_validate2(self): + self.apply_config(self.custom_formatter_class_validate2) + handler = logging.getLogger("my_test_logger_custom_formatter").handlers[0] + self.assertIsInstance(handler.formatter, ExceptionFormatter) + + def test_custom_formatter_class_with_validate2_with_wrong_fmt(self): + config = self.custom_formatter_class_validate.copy() + config['formatters']['form1']['style'] = "$" + + # Exception should not be raised as we have configured 'validate' to False + self.apply_config(config) + handler = logging.getLogger("my_test_logger_custom_formatter").handlers[0] + self.assertIsInstance(handler.formatter, ExceptionFormatter) + + def test_custom_formatter_class_with_validate3(self): + self.assertRaises(ValueError, self.apply_config, self.custom_formatter_class_validate3) + + def test_custom_formatter_function_with_validate(self): + self.assertRaises(ValueError, self.apply_config, self.custom_formatter_with_function) + + def test_custom_formatter_function_with_defaults(self): + self.assertRaises(ValueError, self.apply_config, self.custom_formatter_with_defaults) + + def test_baseconfig(self): + d = { + 'atuple': (1, 2, 3), + 'alist': ['a', 'b', 'c'], + 'adict': {'d': 'e', 'f': 3 }, + 'nest1': ('g', ('h', 'i'), 'j'), + 'nest2': ['k', ['l', 'm'], 'n'], + 'nest3': ['o', 'cfg://alist', 'p'], + } + bc = logging.config.BaseConfigurator(d) + self.assertEqual(bc.convert('cfg://atuple[1]'), 2) + self.assertEqual(bc.convert('cfg://alist[1]'), 'b') + self.assertEqual(bc.convert('cfg://nest1[1][0]'), 'h') + self.assertEqual(bc.convert('cfg://nest2[1][1]'), 'm') + self.assertEqual(bc.convert('cfg://adict.d'), 'e') + self.assertEqual(bc.convert('cfg://adict[f]'), 3) + v = bc.convert('cfg://nest3') + self.assertEqual(v.pop(1), ['a', 'b', 'c']) + self.assertRaises(KeyError, bc.convert, 'cfg://nosuch') + self.assertRaises(ValueError, bc.convert, 'cfg://!') + self.assertRaises(KeyError, bc.convert, 'cfg://adict[2]') + + def test_namedtuple(self): + # see bpo-39142 + from collections import namedtuple + + class MyHandler(logging.StreamHandler): + def __init__(self, resource, *args, **kwargs): + super().__init__(*args, **kwargs) + self.resource: namedtuple = resource + + def emit(self, record): + record.msg += f' {self.resource.type}' + return super().emit(record) + + Resource = namedtuple('Resource', ['type', 'labels']) + resource = Resource(type='my_type', labels=['a']) + + config = { + 'version': 1, + 'handlers': { + 'myhandler': { + '()': MyHandler, + 'resource': resource + } + }, + 'root': {'level': 'INFO', 'handlers': ['myhandler']}, + } + with support.captured_stderr() as stderr: + self.apply_config(config) + logging.info('some log') + self.assertEqual(stderr.getvalue(), 'some log my_type\n') + + def test_config_callable_filter_works(self): + def filter_(_): + return 1 + self.apply_config({ + "version": 1, "root": {"level": "DEBUG", "filters": [filter_]} + }) + assert logging.getLogger().filters[0] is filter_ + logging.getLogger().filters = [] + + def test_config_filter_works(self): + filter_ = logging.Filter("spam.eggs") + self.apply_config({ + "version": 1, "root": {"level": "DEBUG", "filters": [filter_]} + }) + assert logging.getLogger().filters[0] is filter_ + logging.getLogger().filters = [] + + def test_config_filter_method_works(self): + class FakeFilter: + def filter(self, _): + return 1 + filter_ = FakeFilter() + self.apply_config({ + "version": 1, "root": {"level": "DEBUG", "filters": [filter_]} + }) + assert logging.getLogger().filters[0] is filter_ + logging.getLogger().filters = [] + + def test_invalid_type_raises(self): + class NotAFilter: pass + for filter_ in [None, 1, NotAFilter()]: + self.assertRaises( + ValueError, + self.apply_config, + {"version": 1, "root": {"level": "DEBUG", "filters": [filter_]}} + ) + + def do_queuehandler_configuration(self, qspec, lspec): + cd = copy.deepcopy(self.config_queue_handler) + fn = make_temp_file('.log', 'test_logging-cqh-') + cd['handlers']['h1']['filename'] = fn + if qspec is not None: + cd['handlers']['ah']['queue'] = qspec + if lspec is not None: + cd['handlers']['ah']['listener'] = lspec + qh = None + try: + self.apply_config(cd) + qh = logging.getHandlerByName('ah') + self.assertEqual(sorted(logging.getHandlerNames()), ['ah', 'h1']) + self.assertIsNotNone(qh.listener) + qh.listener.start() + logging.debug('foo') + logging.info('bar') + logging.warning('baz') + + # Need to let the listener thread finish its work + while support.sleeping_retry(support.LONG_TIMEOUT, + "queue not empty"): + if qh.listener.queue.empty(): + break + + # wait until the handler completed its last task + qh.listener.queue.join() + + with open(fn, encoding='utf-8') as f: + data = f.read().splitlines() + self.assertEqual(data, ['foo', 'bar', 'baz']) + finally: + if qh: + qh.listener.stop() + h = logging.getHandlerByName('h1') + if h: + self.addCleanup(closeFileHandler, h, fn) + else: + self.addCleanup(os.remove, fn) + + @threading_helper.requires_working_threading() + @support.requires_subprocess() + def test_config_queue_handler(self): + qs = [CustomQueue(), CustomQueueProtocol()] + dqs = [{'()': f'{__name__}.{cls}', 'maxsize': 10} + for cls in ['CustomQueue', 'CustomQueueProtocol']] + dl = { + '()': __name__ + '.listenerMaker', + 'arg1': None, + 'arg2': None, + 'respect_handler_level': True + } + qvalues = (None, __name__ + '.queueMaker', __name__ + '.CustomQueue', *dqs, *qs) + lvalues = (None, __name__ + '.CustomListener', dl, CustomListener) + for qspec, lspec in itertools.product(qvalues, lvalues): + self.do_queuehandler_configuration(qspec, lspec) + + # Some failure cases + qvalues = (None, 4, int, '', 'foo') + lvalues = (None, 4, int, '', 'bar') + for qspec, lspec in itertools.product(qvalues, lvalues): + if lspec is None and qspec is None: + continue + with self.assertRaises(ValueError) as ctx: + self.do_queuehandler_configuration(qspec, lspec) + msg = str(ctx.exception) + self.assertEqual(msg, "Unable to configure handler 'ah'") + + def _apply_simple_queue_listener_configuration(self, qspec): + self.apply_config({ + "version": 1, + "handlers": { + "queue_listener": { + "class": "logging.handlers.QueueHandler", + "queue": qspec, + }, + }, + }) + + @threading_helper.requires_working_threading() + @support.requires_subprocess() + @patch("multiprocessing.Manager") + def test_config_queue_handler_does_not_create_multiprocessing_manager(self, manager): + # gh-120868, gh-121723, gh-124653 + + for qspec in [ + {"()": "queue.Queue", "maxsize": -1}, + queue.Queue(), + # queue.SimpleQueue does not inherit from queue.Queue + queue.SimpleQueue(), + # CustomQueueFakeProtocol passes the checks but will not be usable + # since the signatures are incompatible. Checking the Queue API + # without testing the type of the actual queue is a trade-off + # between usability and the work we need to do in order to safely + # check that the queue object correctly implements the API. + CustomQueueFakeProtocol(), + MinimalQueueProtocol(), + ]: + with self.subTest(qspec=qspec): + self._apply_simple_queue_listener_configuration(qspec) + manager.assert_not_called() + + @patch("multiprocessing.Manager") + def test_config_queue_handler_invalid_config_does_not_create_multiprocessing_manager(self, manager): + # gh-120868, gh-121723 + + for qspec in [object(), CustomQueueWrongProtocol()]: + with self.subTest(qspec=qspec), self.assertRaises(ValueError): + self._apply_simple_queue_listener_configuration(qspec) + manager.assert_not_called() + + @skip_if_tsan_fork + @support.requires_subprocess() + @unittest.skipUnless(support.Py_DEBUG, "requires a debug build for testing" + " assertions in multiprocessing") + def test_config_reject_simple_queue_handler_multiprocessing_context(self): + # multiprocessing.SimpleQueue does not implement 'put_nowait' + # and thus cannot be used as a queue-like object (gh-124653) + + import multiprocessing + + if support.MS_WINDOWS: + start_methods = ['spawn'] + else: + start_methods = ['spawn', 'fork', 'forkserver'] + + for start_method in start_methods: + with self.subTest(start_method=start_method): + ctx = multiprocessing.get_context(start_method) + qspec = ctx.SimpleQueue() + with self.assertRaises(ValueError): + self._apply_simple_queue_listener_configuration(qspec) + + @skip_if_tsan_fork + @support.requires_subprocess() + @unittest.skipUnless(support.Py_DEBUG, "requires a debug build for testing" + "assertions in multiprocessing") + def test_config_queue_handler_multiprocessing_context(self): + # regression test for gh-121723 + if support.MS_WINDOWS: + start_methods = ['spawn'] + else: + start_methods = ['spawn', 'fork', 'forkserver'] + for start_method in start_methods: + with self.subTest(start_method=start_method): + ctx = multiprocessing.get_context(start_method) + with ctx.Manager() as manager: + q = manager.Queue() + records = [] + # use 1 process and 1 task per child to put 1 record + with ctx.Pool(1, initializer=self._mpinit_issue121723, + initargs=(q, "text"), maxtasksperchild=1): + records.append(q.get(timeout=60)) + self.assertTrue(q.empty()) + self.assertEqual(len(records), 1) + + @staticmethod + def _mpinit_issue121723(qspec, message_to_log): + # static method for pickling support + logging.config.dictConfig({ + 'version': 1, + 'disable_existing_loggers': True, + 'handlers': { + 'log_to_parent': { + 'class': 'logging.handlers.QueueHandler', + 'queue': qspec + } + }, + 'root': {'handlers': ['log_to_parent'], 'level': 'DEBUG'} + }) + # log a message (this creates a record put in the queue) + logging.getLogger().info(message_to_log) + + @support.requires_subprocess() + def test_multiprocessing_queues(self): + # See gh-119819 + + cd = copy.deepcopy(self.config_queue_handler) + from multiprocessing import Queue as MQ, Manager as MM + q1 = MQ() # this can't be pickled + q2 = MM().Queue() # a proxy queue for use when pickling is needed + q3 = MM().JoinableQueue() # a joinable proxy queue + for qspec in (q1, q2, q3): + fn = make_temp_file('.log', 'test_logging-cmpqh-') + cd['handlers']['h1']['filename'] = fn + cd['handlers']['ah']['queue'] = qspec + qh = None + try: + self.apply_config(cd) + qh = logging.getHandlerByName('ah') + self.assertEqual(sorted(logging.getHandlerNames()), ['ah', 'h1']) + self.assertIsNotNone(qh.listener) + self.assertIs(qh.queue, qspec) + self.assertIs(qh.listener.queue, qspec) + finally: + h = logging.getHandlerByName('h1') + if h: + self.addCleanup(closeFileHandler, h, fn) + else: + self.addCleanup(os.remove, fn) + + def test_90195(self): + # See gh-90195 + config = { + 'version': 1, + 'disable_existing_loggers': False, + 'handlers': { + 'console': { + 'level': 'DEBUG', + 'class': 'logging.StreamHandler', + }, + }, + 'loggers': { + 'a': { + 'level': 'DEBUG', + 'handlers': ['console'] + } + } + } + logger = logging.getLogger('a') + self.assertFalse(logger.disabled) + self.apply_config(config) + self.assertFalse(logger.disabled) + # Should disable all loggers ... + self.apply_config({'version': 1}) + self.assertTrue(logger.disabled) + del config['disable_existing_loggers'] + self.apply_config(config) + # Logger should be enabled, since explicitly mentioned + self.assertFalse(logger.disabled) + + def test_111615(self): + # See gh-111615 + import_helper.import_module('_multiprocessing') # see gh-113692 + mp = import_helper.import_module('multiprocessing') + + config = { + 'version': 1, + 'handlers': { + 'sink': { + 'class': 'logging.handlers.QueueHandler', + 'queue': mp.get_context('spawn').Queue(), + }, + }, + 'root': { + 'handlers': ['sink'], + 'level': 'DEBUG', + }, + } + logging.config.dictConfig(config) + + # gh-118868: check if kwargs are passed to logging QueueHandler + def test_kwargs_passing(self): + class CustomQueueHandler(logging.handlers.QueueHandler): + def __init__(self, *args, **kwargs): + super().__init__(queue.Queue()) + self.custom_kwargs = kwargs + + custom_kwargs = {'foo': 'bar'} + + config = { + 'version': 1, + 'handlers': { + 'custom': { + 'class': CustomQueueHandler, + **custom_kwargs + }, + }, + 'root': { + 'level': 'DEBUG', + 'handlers': ['custom'] + } + } + + logging.config.dictConfig(config) + + handler = logging.getHandlerByName('custom') + self.assertEqual(handler.custom_kwargs, custom_kwargs) + + +class ManagerTest(BaseTest): + def test_manager_loggerclass(self): + logged = [] + + class MyLogger(logging.Logger): + def _log(self, level, msg, args, exc_info=None, extra=None): + logged.append(msg) + + man = logging.Manager(None) + self.assertRaises(TypeError, man.setLoggerClass, int) + man.setLoggerClass(MyLogger) + logger = man.getLogger('test') + logger.warning('should appear in logged') + logging.warning('should not appear in logged') + + self.assertEqual(logged, ['should appear in logged']) + + def test_set_log_record_factory(self): + man = logging.Manager(None) + expected = object() + man.setLogRecordFactory(expected) + self.assertEqual(man.logRecordFactory, expected) + +class ChildLoggerTest(BaseTest): + def test_child_loggers(self): + r = logging.getLogger() + l1 = logging.getLogger('abc') + l2 = logging.getLogger('def.ghi') + c1 = r.getChild('xyz') + c2 = r.getChild('uvw.xyz') + self.assertIs(c1, logging.getLogger('xyz')) + self.assertIs(c2, logging.getLogger('uvw.xyz')) + c1 = l1.getChild('def') + c2 = c1.getChild('ghi') + c3 = l1.getChild('def.ghi') + self.assertIs(c1, logging.getLogger('abc.def')) + self.assertIs(c2, logging.getLogger('abc.def.ghi')) + self.assertIs(c2, c3) + + def test_get_children(self): + r = logging.getLogger() + l1 = logging.getLogger('foo') + l2 = logging.getLogger('foo.bar') + l3 = logging.getLogger('foo.bar.baz.bozz') + l4 = logging.getLogger('bar') + kids = r.getChildren() + expected = {l1, l4} + self.assertEqual(expected, kids & expected) # might be other kids for root + self.assertNotIn(l2, expected) + kids = l1.getChildren() + self.assertEqual({l2}, kids) + kids = l2.getChildren() + self.assertEqual(set(), kids) + +class DerivedLogRecord(logging.LogRecord): + pass + +class LogRecordFactoryTest(BaseTest): + + def setUp(self): + class CheckingFilter(logging.Filter): + def __init__(self, cls): + self.cls = cls + + def filter(self, record): + t = type(record) + if t is not self.cls: + msg = 'Unexpected LogRecord type %s, expected %s' % (t, + self.cls) + raise TypeError(msg) + return True + + BaseTest.setUp(self) + self.filter = CheckingFilter(DerivedLogRecord) + self.root_logger.addFilter(self.filter) + self.orig_factory = logging.getLogRecordFactory() + + def tearDown(self): + self.root_logger.removeFilter(self.filter) + BaseTest.tearDown(self) + logging.setLogRecordFactory(self.orig_factory) + + def test_logrecord_class(self): + self.assertRaises(TypeError, self.root_logger.warning, + self.next_message()) + logging.setLogRecordFactory(DerivedLogRecord) + self.root_logger.error(self.next_message()) + self.assert_log_lines([ + ('root', 'ERROR', '2'), + ]) + + +@threading_helper.requires_working_threading() +class QueueHandlerTest(BaseTest): + # Do not bother with a logger name group. + expected_log_pat = r"^[\w.]+ -> (\w+): (\d+)$" + + def setUp(self): + BaseTest.setUp(self) + self.queue = queue.Queue(-1) + self.que_hdlr = logging.handlers.QueueHandler(self.queue) + self.name = 'que' + self.que_logger = logging.getLogger('que') + self.que_logger.propagate = False + self.que_logger.setLevel(logging.WARNING) + self.que_logger.addHandler(self.que_hdlr) + + def tearDown(self): + self.que_hdlr.close() + BaseTest.tearDown(self) + + def test_queue_handler(self): + self.que_logger.debug(self.next_message()) + self.assertRaises(queue.Empty, self.queue.get_nowait) + self.que_logger.info(self.next_message()) + self.assertRaises(queue.Empty, self.queue.get_nowait) + msg = self.next_message() + self.que_logger.warning(msg) + data = self.queue.get_nowait() + self.assertTrue(isinstance(data, logging.LogRecord)) + self.assertEqual(data.name, self.que_logger.name) + self.assertEqual((data.msg, data.args), (msg, None)) + + def test_formatting(self): + msg = self.next_message() + levelname = logging.getLevelName(logging.WARNING) + log_format_str = '{name} -> {levelname}: {message}' + formatted_msg = log_format_str.format(name=self.name, + levelname=levelname, message=msg) + formatter = logging.Formatter(self.log_format) + self.que_hdlr.setFormatter(formatter) + self.que_logger.warning(msg) + log_record = self.queue.get_nowait() + self.assertEqual(formatted_msg, log_record.msg) + self.assertEqual(formatted_msg, log_record.message) + + @unittest.skipUnless(hasattr(logging.handlers, 'QueueListener'), + 'logging.handlers.QueueListener required for this test') + def test_queue_listener(self): + handler = TestHandler(support.Matcher()) + listener = logging.handlers.QueueListener(self.queue, handler) + listener.start() + try: + self.que_logger.warning(self.next_message()) + self.que_logger.error(self.next_message()) + self.que_logger.critical(self.next_message()) + finally: + listener.stop() + self.assertTrue(handler.matches(levelno=logging.WARNING, message='1')) + self.assertTrue(handler.matches(levelno=logging.ERROR, message='2')) + self.assertTrue(handler.matches(levelno=logging.CRITICAL, message='3')) + handler.close() + + # Now test with respect_handler_level set + + handler = TestHandler(support.Matcher()) + handler.setLevel(logging.CRITICAL) + listener = logging.handlers.QueueListener(self.queue, handler, + respect_handler_level=True) + listener.start() + try: + self.que_logger.warning(self.next_message()) + self.que_logger.error(self.next_message()) + self.que_logger.critical(self.next_message()) + finally: + listener.stop() + self.assertFalse(handler.matches(levelno=logging.WARNING, message='4')) + self.assertFalse(handler.matches(levelno=logging.ERROR, message='5')) + self.assertTrue(handler.matches(levelno=logging.CRITICAL, message='6')) + handler.close() + + @unittest.skipUnless(hasattr(logging.handlers, 'QueueListener'), + 'logging.handlers.QueueListener required for this test') + def test_queue_listener_with_StreamHandler(self): + # Test that traceback and stack-info only appends once (bpo-34334, bpo-46755). + listener = logging.handlers.QueueListener(self.queue, self.root_hdlr) + listener.start() + try: + 1 / 0 + except ZeroDivisionError as e: + exc = e + self.que_logger.exception(self.next_message(), exc_info=exc) + self.que_logger.error(self.next_message(), stack_info=True) + listener.stop() + self.assertEqual(self.stream.getvalue().strip().count('Traceback'), 1) + self.assertEqual(self.stream.getvalue().strip().count('Stack'), 1) + + @unittest.skipUnless(hasattr(logging.handlers, 'QueueListener'), + 'logging.handlers.QueueListener required for this test') + def test_queue_listener_with_multiple_handlers(self): + # Test that queue handler format doesn't affect other handler formats (bpo-35726). + self.que_hdlr.setFormatter(self.root_formatter) + self.que_logger.addHandler(self.root_hdlr) + + listener = logging.handlers.QueueListener(self.queue, self.que_hdlr) + listener.start() + self.que_logger.error("error") + listener.stop() + self.assertEqual(self.stream.getvalue().strip(), "que -> ERROR: error") + +if hasattr(logging.handlers, 'QueueListener'): + import multiprocessing + from unittest.mock import patch + + @threading_helper.requires_working_threading() + class QueueListenerTest(BaseTest): + """ + Tests based on patch submitted for issue #27930. Ensure that + QueueListener handles all log messages. + """ + + repeat = 20 + + @staticmethod + def setup_and_log(log_queue, ident): + """ + Creates a logger with a QueueHandler that logs to a queue read by a + QueueListener. Starts the listener, logs five messages, and stops + the listener. + """ + logger = logging.getLogger('test_logger_with_id_%s' % ident) + logger.setLevel(logging.DEBUG) + handler = logging.handlers.QueueHandler(log_queue) + logger.addHandler(handler) + listener = logging.handlers.QueueListener(log_queue) + listener.start() + + logger.info('one') + logger.info('two') + logger.info('three') + logger.info('four') + logger.info('five') + + listener.stop() + logger.removeHandler(handler) + handler.close() + + @patch.object(logging.handlers.QueueListener, 'handle') + def test_handle_called_with_queue_queue(self, mock_handle): + for i in range(self.repeat): + log_queue = queue.Queue() + self.setup_and_log(log_queue, '%s_%s' % (self.id(), i)) + self.assertEqual(mock_handle.call_count, 5 * self.repeat, + 'correct number of handled log messages') + + @patch.object(logging.handlers.QueueListener, 'handle') + def test_handle_called_with_mp_queue(self, mock_handle): + # bpo-28668: The multiprocessing (mp) module is not functional + # when the mp.synchronize module cannot be imported. + support.skip_if_broken_multiprocessing_synchronize() + for i in range(self.repeat): + log_queue = multiprocessing.Queue() + self.setup_and_log(log_queue, '%s_%s' % (self.id(), i)) + log_queue.close() + log_queue.join_thread() + self.assertEqual(mock_handle.call_count, 5 * self.repeat, + 'correct number of handled log messages') + + @staticmethod + def get_all_from_queue(log_queue): + try: + while True: + yield log_queue.get_nowait() + except queue.Empty: + return [] + + def test_no_messages_in_queue_after_stop(self): + """ + Five messages are logged then the QueueListener is stopped. This + test then gets everything off the queue. Failure of this test + indicates that messages were not registered on the queue until + _after_ the QueueListener stopped. + """ + # bpo-28668: The multiprocessing (mp) module is not functional + # when the mp.synchronize module cannot be imported. + support.skip_if_broken_multiprocessing_synchronize() + for i in range(self.repeat): + queue = multiprocessing.Queue() + self.setup_and_log(queue, '%s_%s' %(self.id(), i)) + # time.sleep(1) + items = list(self.get_all_from_queue(queue)) + queue.close() + queue.join_thread() + + expected = [[], [logging.handlers.QueueListener._sentinel]] + self.assertIn(items, expected, + 'Found unexpected messages in queue: %s' % ( + [m.msg if isinstance(m, logging.LogRecord) + else m for m in items])) + + def test_calls_task_done_after_stop(self): + # Issue 36813: Make sure queue.join does not deadlock. + log_queue = queue.Queue() + listener = logging.handlers.QueueListener(log_queue) + listener.start() + listener.stop() + with self.assertRaises(ValueError): + # Make sure all tasks are done and .join won't block. + log_queue.task_done() + + +ZERO = datetime.timedelta(0) + +class UTC(datetime.tzinfo): + def utcoffset(self, dt): + return ZERO + + dst = utcoffset + + def tzname(self, dt): + return 'UTC' + +utc = UTC() + +class AssertErrorMessage: + + def assert_error_message(self, exception, message, *args, **kwargs): + try: + self.assertRaises((), *args, **kwargs) + except exception as e: + self.assertEqual(message, str(e)) + +class FormatterTest(unittest.TestCase, AssertErrorMessage): + def setUp(self): + self.common = { + 'name': 'formatter.test', + 'level': logging.DEBUG, + 'pathname': os.path.join('path', 'to', 'dummy.ext'), + 'lineno': 42, + 'exc_info': None, + 'func': None, + 'msg': 'Message with %d %s', + 'args': (2, 'placeholders'), + } + self.variants = { + 'custom': { + 'custom': 1234 + } + } + + def get_record(self, name=None): + result = dict(self.common) + if name is not None: + result.update(self.variants[name]) + return logging.makeLogRecord(result) + + def test_percent(self): + # Test %-formatting + r = self.get_record() + f = logging.Formatter('${%(message)s}') + self.assertEqual(f.format(r), '${Message with 2 placeholders}') + f = logging.Formatter('%(random)s') + self.assertRaises(ValueError, f.format, r) + self.assertFalse(f.usesTime()) + f = logging.Formatter('%(asctime)s') + self.assertTrue(f.usesTime()) + f = logging.Formatter('%(asctime)-15s') + self.assertTrue(f.usesTime()) + f = logging.Formatter('%(asctime)#15s') + self.assertTrue(f.usesTime()) + + def test_braces(self): + # Test {}-formatting + r = self.get_record() + f = logging.Formatter('$%{message}%$', style='{') + self.assertEqual(f.format(r), '$%Message with 2 placeholders%$') + f = logging.Formatter('{random}', style='{') + self.assertRaises(ValueError, f.format, r) + f = logging.Formatter("{message}", style='{') + self.assertFalse(f.usesTime()) + f = logging.Formatter('{asctime}', style='{') + self.assertTrue(f.usesTime()) + f = logging.Formatter('{asctime!s:15}', style='{') + self.assertTrue(f.usesTime()) + f = logging.Formatter('{asctime:15}', style='{') + self.assertTrue(f.usesTime()) + + def test_dollars(self): + # Test $-formatting + r = self.get_record() + f = logging.Formatter('${message}', style='$') + self.assertEqual(f.format(r), 'Message with 2 placeholders') + f = logging.Formatter('$message', style='$') + self.assertEqual(f.format(r), 'Message with 2 placeholders') + f = logging.Formatter('$$%${message}%$$', style='$') + self.assertEqual(f.format(r), '$%Message with 2 placeholders%$') + f = logging.Formatter('${random}', style='$') + self.assertRaises(ValueError, f.format, r) + self.assertFalse(f.usesTime()) + f = logging.Formatter('${asctime}', style='$') + self.assertTrue(f.usesTime()) + f = logging.Formatter('$asctime', style='$') + self.assertTrue(f.usesTime()) + f = logging.Formatter('${message}', style='$') + self.assertFalse(f.usesTime()) + f = logging.Formatter('${asctime}--', style='$') + self.assertTrue(f.usesTime()) + + def test_format_validate(self): + # Check correct formatting + # Percentage style + f = logging.Formatter("%(levelname)-15s - %(message) 5s - %(process)03d - %(module) - %(asctime)*.3s") + self.assertEqual(f._fmt, "%(levelname)-15s - %(message) 5s - %(process)03d - %(module) - %(asctime)*.3s") + f = logging.Formatter("%(asctime)*s - %(asctime)*.3s - %(process)-34.33o") + self.assertEqual(f._fmt, "%(asctime)*s - %(asctime)*.3s - %(process)-34.33o") + f = logging.Formatter("%(process)#+027.23X") + self.assertEqual(f._fmt, "%(process)#+027.23X") + f = logging.Formatter("%(foo)#.*g") + self.assertEqual(f._fmt, "%(foo)#.*g") + + # StrFormat Style + f = logging.Formatter("$%{message}%$ - {asctime!a:15} - {customfield['key']}", style="{") + self.assertEqual(f._fmt, "$%{message}%$ - {asctime!a:15} - {customfield['key']}") + f = logging.Formatter("{process:.2f} - {custom.f:.4f}", style="{") + self.assertEqual(f._fmt, "{process:.2f} - {custom.f:.4f}") + f = logging.Formatter("{customfield!s:#<30}", style="{") + self.assertEqual(f._fmt, "{customfield!s:#<30}") + f = logging.Formatter("{message!r}", style="{") + self.assertEqual(f._fmt, "{message!r}") + f = logging.Formatter("{message!s}", style="{") + self.assertEqual(f._fmt, "{message!s}") + f = logging.Formatter("{message!a}", style="{") + self.assertEqual(f._fmt, "{message!a}") + f = logging.Formatter("{process!r:4.2}", style="{") + self.assertEqual(f._fmt, "{process!r:4.2}") + f = logging.Formatter("{process!s:<#30,.12f}- {custom:=+#30,.1d} - {module:^30}", style="{") + self.assertEqual(f._fmt, "{process!s:<#30,.12f}- {custom:=+#30,.1d} - {module:^30}") + f = logging.Formatter("{process!s:{w},.{p}}", style="{") + self.assertEqual(f._fmt, "{process!s:{w},.{p}}") + f = logging.Formatter("{foo:12.{p}}", style="{") + self.assertEqual(f._fmt, "{foo:12.{p}}") + f = logging.Formatter("{foo:{w}.6}", style="{") + self.assertEqual(f._fmt, "{foo:{w}.6}") + f = logging.Formatter("{foo[0].bar[1].baz}", style="{") + self.assertEqual(f._fmt, "{foo[0].bar[1].baz}") + f = logging.Formatter("{foo[k1].bar[k2].baz}", style="{") + self.assertEqual(f._fmt, "{foo[k1].bar[k2].baz}") + f = logging.Formatter("{12[k1].bar[k2].baz}", style="{") + self.assertEqual(f._fmt, "{12[k1].bar[k2].baz}") + + # Dollar style + f = logging.Formatter("${asctime} - $message", style="$") + self.assertEqual(f._fmt, "${asctime} - $message") + f = logging.Formatter("$bar $$", style="$") + self.assertEqual(f._fmt, "$bar $$") + f = logging.Formatter("$bar $$$$", style="$") + self.assertEqual(f._fmt, "$bar $$$$") # this would print two $($$) + + # Testing when ValueError being raised from incorrect format + # Percentage Style + self.assertRaises(ValueError, logging.Formatter, "%(asctime)Z") + self.assertRaises(ValueError, logging.Formatter, "%(asctime)b") + self.assertRaises(ValueError, logging.Formatter, "%(asctime)*") + self.assertRaises(ValueError, logging.Formatter, "%(asctime)*3s") + self.assertRaises(ValueError, logging.Formatter, "%(asctime)_") + self.assertRaises(ValueError, logging.Formatter, '{asctime}') + self.assertRaises(ValueError, logging.Formatter, '${message}') + self.assertRaises(ValueError, logging.Formatter, '%(foo)#12.3*f') # with both * and decimal number as precision + self.assertRaises(ValueError, logging.Formatter, '%(foo)0*.8*f') + + # StrFormat Style + # Testing failure for '-' in field name + self.assert_error_message( + ValueError, + "invalid format: invalid field name/expression: 'name-thing'", + logging.Formatter, "{name-thing}", style="{" + ) + # Testing failure for style mismatch + self.assert_error_message( + ValueError, + "invalid format: no fields", + logging.Formatter, '%(asctime)s', style='{' + ) + # Testing failure for invalid conversion + self.assert_error_message( + ValueError, + "invalid conversion: 'Z'" + ) + self.assertRaises(ValueError, logging.Formatter, '{asctime!s:#30,15f}', style='{') + self.assert_error_message( + ValueError, + "invalid format: expected ':' after conversion specifier", + logging.Formatter, '{asctime!aa:15}', style='{' + ) + # Testing failure for invalid spec + self.assert_error_message( + ValueError, + "invalid format: bad specifier: '.2ff'", + logging.Formatter, '{process:.2ff}', style='{' + ) + self.assertRaises(ValueError, logging.Formatter, '{process:.2Z}', style='{') + self.assertRaises(ValueError, logging.Formatter, '{process!s:<##30,12g}', style='{') + self.assertRaises(ValueError, logging.Formatter, '{process!s:<#30#,12g}', style='{') + self.assertRaises(ValueError, logging.Formatter, '{process!s:{{w}},{{p}}}', style='{') + # Testing failure for mismatch braces + self.assert_error_message( + ValueError, + "invalid format: expected '}' before end of string", + logging.Formatter, '{process', style='{' + ) + self.assert_error_message( + ValueError, + "invalid format: Single '}' encountered in format string", + logging.Formatter, 'process}', style='{' + ) + self.assertRaises(ValueError, logging.Formatter, '{{foo!r:4.2}', style='{') + self.assertRaises(ValueError, logging.Formatter, '{{foo!r:4.2}}', style='{') + self.assertRaises(ValueError, logging.Formatter, '{foo/bar}', style='{') + self.assertRaises(ValueError, logging.Formatter, '{foo:{{w}}.{{p}}}}', style='{') + self.assertRaises(ValueError, logging.Formatter, '{foo!X:{{w}}.{{p}}}', style='{') + self.assertRaises(ValueError, logging.Formatter, '{foo!a:random}', style='{') + self.assertRaises(ValueError, logging.Formatter, '{foo!a:ran{dom}', style='{') + self.assertRaises(ValueError, logging.Formatter, '{foo!a:ran{d}om}', style='{') + self.assertRaises(ValueError, logging.Formatter, '{foo.!a:d}', style='{') + + # Dollar style + # Testing failure for mismatch bare $ + self.assert_error_message( + ValueError, + "invalid format: bare \'$\' not allowed", + logging.Formatter, '$bar $$$', style='$' + ) + self.assert_error_message( + ValueError, + "invalid format: bare \'$\' not allowed", + logging.Formatter, 'bar $', style='$' + ) + self.assert_error_message( + ValueError, + "invalid format: bare \'$\' not allowed", + logging.Formatter, 'foo $.', style='$' + ) + # Testing failure for mismatch style + self.assert_error_message( + ValueError, + "invalid format: no fields", + logging.Formatter, '{asctime}', style='$' + ) + self.assertRaises(ValueError, logging.Formatter, '%(asctime)s', style='$') + + # Testing failure for incorrect fields + self.assert_error_message( + ValueError, + "invalid format: no fields", + logging.Formatter, 'foo', style='$' + ) + self.assertRaises(ValueError, logging.Formatter, '${asctime', style='$') + + def test_defaults_parameter(self): + fmts = ['%(custom)s %(message)s', '{custom} {message}', '$custom $message'] + styles = ['%', '{', '$'] + for fmt, style in zip(fmts, styles): + f = logging.Formatter(fmt, style=style, defaults={'custom': 'Default'}) + r = self.get_record() + self.assertEqual(f.format(r), 'Default Message with 2 placeholders') + r = self.get_record("custom") + self.assertEqual(f.format(r), '1234 Message with 2 placeholders') + + # Without default + f = logging.Formatter(fmt, style=style) + r = self.get_record() + self.assertRaises(ValueError, f.format, r) + + # Non-existing default is ignored + f = logging.Formatter(fmt, style=style, defaults={'Non-existing': 'Default'}) + r = self.get_record("custom") + self.assertEqual(f.format(r), '1234 Message with 2 placeholders') + + def test_invalid_style(self): + self.assertRaises(ValueError, logging.Formatter, None, None, 'x') + + def test_time(self): + r = self.get_record() + dt = datetime.datetime(1993, 4, 21, 8, 3, 0, 0, utc) + # We use None to indicate we want the local timezone + # We're essentially converting a UTC time to local time + r.created = time.mktime(dt.astimezone(None).timetuple()) + r.msecs = 123 + f = logging.Formatter('%(asctime)s %(message)s') + f.converter = time.gmtime + self.assertEqual(f.formatTime(r), '1993-04-21 08:03:00,123') + self.assertEqual(f.formatTime(r, '%Y:%d'), '1993:21') + f.format(r) + self.assertEqual(r.asctime, '1993-04-21 08:03:00,123') + + def test_default_msec_format_none(self): + class NoMsecFormatter(logging.Formatter): + default_msec_format = None + default_time_format = '%d/%m/%Y %H:%M:%S' + + r = self.get_record() + dt = datetime.datetime(1993, 4, 21, 8, 3, 0, 123, utc) + r.created = time.mktime(dt.astimezone(None).timetuple()) + f = NoMsecFormatter() + f.converter = time.gmtime + self.assertEqual(f.formatTime(r), '21/04/1993 08:03:00') + + def test_issue_89047(self): + f = logging.Formatter(fmt='{asctime}.{msecs:03.0f} {message}', style='{', datefmt="%Y-%m-%d %H:%M:%S") + for i in range(2500): + time.sleep(0.0004) + r = logging.makeLogRecord({'msg': 'Message %d' % (i + 1)}) + s = f.format(r) + self.assertNotIn('.1000', s) + + +class TestBufferingFormatter(logging.BufferingFormatter): + def formatHeader(self, records): + return '[(%d)' % len(records) + + def formatFooter(self, records): + return '(%d)]' % len(records) + +class BufferingFormatterTest(unittest.TestCase): + def setUp(self): + self.records = [ + logging.makeLogRecord({'msg': 'one'}), + logging.makeLogRecord({'msg': 'two'}), + ] + + def test_default(self): + f = logging.BufferingFormatter() + self.assertEqual('', f.format([])) + self.assertEqual('onetwo', f.format(self.records)) + + def test_custom(self): + f = TestBufferingFormatter() + self.assertEqual('[(2)onetwo(2)]', f.format(self.records)) + lf = logging.Formatter('<%(message)s>') + f = TestBufferingFormatter(lf) + self.assertEqual('[(2)(2)]', f.format(self.records)) + +class ExceptionTest(BaseTest): + def test_formatting(self): + r = self.root_logger + h = RecordingHandler() + r.addHandler(h) + try: + raise RuntimeError('deliberate mistake') + except: + logging.exception('failed', stack_info=True) + r.removeHandler(h) + h.close() + r = h.records[0] + self.assertTrue(r.exc_text.startswith('Traceback (most recent ' + 'call last):\n')) + self.assertTrue(r.exc_text.endswith('\nRuntimeError: ' + 'deliberate mistake')) + self.assertTrue(r.stack_info.startswith('Stack (most recent ' + 'call last):\n')) + self.assertTrue(r.stack_info.endswith('logging.exception(\'failed\', ' + 'stack_info=True)')) + + +class LastResortTest(BaseTest): + def test_last_resort(self): + # Test the last resort handler + root = self.root_logger + root.removeHandler(self.root_hdlr) + old_lastresort = logging.lastResort + old_raise_exceptions = logging.raiseExceptions + + try: + with support.captured_stderr() as stderr: + root.debug('This should not appear') + self.assertEqual(stderr.getvalue(), '') + root.warning('Final chance!') + self.assertEqual(stderr.getvalue(), 'Final chance!\n') + + # No handlers and no last resort, so 'No handlers' message + logging.lastResort = None + with support.captured_stderr() as stderr: + root.warning('Final chance!') + msg = 'No handlers could be found for logger "root"\n' + self.assertEqual(stderr.getvalue(), msg) + + # 'No handlers' message only printed once + with support.captured_stderr() as stderr: + root.warning('Final chance!') + self.assertEqual(stderr.getvalue(), '') + + # If raiseExceptions is False, no message is printed + root.manager.emittedNoHandlerWarning = False + logging.raiseExceptions = False + with support.captured_stderr() as stderr: + root.warning('Final chance!') + self.assertEqual(stderr.getvalue(), '') + finally: + root.addHandler(self.root_hdlr) + logging.lastResort = old_lastresort + logging.raiseExceptions = old_raise_exceptions + + +class FakeHandler: + + def __init__(self, identifier, called): + for method in ('acquire', 'flush', 'close', 'release'): + setattr(self, method, self.record_call(identifier, method, called)) + + def record_call(self, identifier, method_name, called): + def inner(): + called.append('{} - {}'.format(identifier, method_name)) + return inner + + +class RecordingHandler(logging.NullHandler): + + def __init__(self, *args, **kwargs): + super(RecordingHandler, self).__init__(*args, **kwargs) + self.records = [] + + def handle(self, record): + """Keep track of all the emitted records.""" + self.records.append(record) + + +class ShutdownTest(BaseTest): + + """Test suite for the shutdown method.""" + + def setUp(self): + super(ShutdownTest, self).setUp() + self.called = [] + + raise_exceptions = logging.raiseExceptions + self.addCleanup(setattr, logging, 'raiseExceptions', raise_exceptions) + + def raise_error(self, error): + def inner(): + raise error() + return inner + + def test_no_failure(self): + # create some fake handlers + handler0 = FakeHandler(0, self.called) + handler1 = FakeHandler(1, self.called) + handler2 = FakeHandler(2, self.called) + + # create live weakref to those handlers + handlers = map(logging.weakref.ref, [handler0, handler1, handler2]) + + logging.shutdown(handlerList=list(handlers)) + + expected = ['2 - acquire', '2 - flush', '2 - close', '2 - release', + '1 - acquire', '1 - flush', '1 - close', '1 - release', + '0 - acquire', '0 - flush', '0 - close', '0 - release'] + self.assertEqual(expected, self.called) + + def _test_with_failure_in_method(self, method, error): + handler = FakeHandler(0, self.called) + setattr(handler, method, self.raise_error(error)) + handlers = [logging.weakref.ref(handler)] + + logging.shutdown(handlerList=list(handlers)) + + self.assertEqual('0 - release', self.called[-1]) + + def test_with_ioerror_in_acquire(self): + self._test_with_failure_in_method('acquire', OSError) + + def test_with_ioerror_in_flush(self): + self._test_with_failure_in_method('flush', OSError) + + def test_with_ioerror_in_close(self): + self._test_with_failure_in_method('close', OSError) + + def test_with_valueerror_in_acquire(self): + self._test_with_failure_in_method('acquire', ValueError) + + def test_with_valueerror_in_flush(self): + self._test_with_failure_in_method('flush', ValueError) + + def test_with_valueerror_in_close(self): + self._test_with_failure_in_method('close', ValueError) + + def test_with_other_error_in_acquire_without_raise(self): + logging.raiseExceptions = False + self._test_with_failure_in_method('acquire', IndexError) + + def test_with_other_error_in_flush_without_raise(self): + logging.raiseExceptions = False + self._test_with_failure_in_method('flush', IndexError) + + def test_with_other_error_in_close_without_raise(self): + logging.raiseExceptions = False + self._test_with_failure_in_method('close', IndexError) + + def test_with_other_error_in_acquire_with_raise(self): + logging.raiseExceptions = True + self.assertRaises(IndexError, self._test_with_failure_in_method, + 'acquire', IndexError) + + def test_with_other_error_in_flush_with_raise(self): + logging.raiseExceptions = True + self.assertRaises(IndexError, self._test_with_failure_in_method, + 'flush', IndexError) + + def test_with_other_error_in_close_with_raise(self): + logging.raiseExceptions = True + self.assertRaises(IndexError, self._test_with_failure_in_method, + 'close', IndexError) + + +class ModuleLevelMiscTest(BaseTest): + + """Test suite for some module level methods.""" + + def test_disable(self): + old_disable = logging.root.manager.disable + # confirm our assumptions are correct + self.assertEqual(old_disable, 0) + self.addCleanup(logging.disable, old_disable) + + logging.disable(83) + self.assertEqual(logging.root.manager.disable, 83) + + self.assertRaises(ValueError, logging.disable, "doesnotexists") + + class _NotAnIntOrString: + pass + + self.assertRaises(TypeError, logging.disable, _NotAnIntOrString()) + + logging.disable("WARN") + + # test the default value introduced in 3.7 + # (Issue #28524) + logging.disable() + self.assertEqual(logging.root.manager.disable, logging.CRITICAL) + + def _test_log(self, method, level=None): + called = [] + support.patch(self, logging, 'basicConfig', + lambda *a, **kw: called.append((a, kw))) + + recording = RecordingHandler() + logging.root.addHandler(recording) + + log_method = getattr(logging, method) + if level is not None: + log_method(level, "test me: %r", recording) + else: + log_method("test me: %r", recording) + + self.assertEqual(len(recording.records), 1) + record = recording.records[0] + self.assertEqual(record.getMessage(), "test me: %r" % recording) + + expected_level = level if level is not None else getattr(logging, method.upper()) + self.assertEqual(record.levelno, expected_level) + + # basicConfig was not called! + self.assertEqual(called, []) + + def test_log(self): + self._test_log('log', logging.ERROR) + + def test_debug(self): + self._test_log('debug') + + def test_info(self): + self._test_log('info') + + def test_warning(self): + self._test_log('warning') + + def test_error(self): + self._test_log('error') + + def test_critical(self): + self._test_log('critical') + + def test_set_logger_class(self): + self.assertRaises(TypeError, logging.setLoggerClass, object) + + class MyLogger(logging.Logger): + pass + + logging.setLoggerClass(MyLogger) + self.assertEqual(logging.getLoggerClass(), MyLogger) + + logging.setLoggerClass(logging.Logger) + self.assertEqual(logging.getLoggerClass(), logging.Logger) + + def test_subclass_logger_cache(self): + # bpo-37258 + message = [] + + class MyLogger(logging.getLoggerClass()): + def __init__(self, name='MyLogger', level=logging.NOTSET): + super().__init__(name, level) + message.append('initialized') + + logging.setLoggerClass(MyLogger) + logger = logging.getLogger('just_some_logger') + self.assertEqual(message, ['initialized']) + stream = io.StringIO() + h = logging.StreamHandler(stream) + logger.addHandler(h) + try: + logger.setLevel(logging.DEBUG) + logger.debug("hello") + self.assertEqual(stream.getvalue().strip(), "hello") + + stream.truncate(0) + stream.seek(0) + + logger.setLevel(logging.INFO) + logger.debug("hello") + self.assertEqual(stream.getvalue(), "") + finally: + logger.removeHandler(h) + h.close() + logging.setLoggerClass(logging.Logger) + + def test_logging_at_shutdown(self): + # bpo-20037: Doing text I/O late at interpreter shutdown must not crash + code = textwrap.dedent(""" + import logging + + class A: + def __del__(self): + try: + raise ValueError("some error") + except Exception: + logging.exception("exception in __del__") + + a = A() + """) + rc, out, err = assert_python_ok("-c", code) + err = err.decode() + self.assertIn("exception in __del__", err) + self.assertIn("ValueError: some error", err) + + def test_logging_at_shutdown_open(self): + # bpo-26789: FileHandler keeps a reference to the builtin open() + # function to be able to open or reopen the file during Python + # finalization. + filename = os_helper.TESTFN + self.addCleanup(os_helper.unlink, filename) + + code = textwrap.dedent(f""" + import builtins + import logging + + class A: + def __del__(self): + logging.error("log in __del__") + + # basicConfig() opens the file, but logging.shutdown() closes + # it at Python exit. When A.__del__() is called, + # FileHandler._open() must be called again to re-open the file. + logging.basicConfig(filename={filename!r}, encoding="utf-8") + + a = A() + + # Simulate the Python finalization which removes the builtin + # open() function. + del builtins.open + """) + assert_python_ok("-c", code) + + with open(filename, encoding="utf-8") as fp: + self.assertEqual(fp.read().rstrip(), "ERROR:root:log in __del__") + + def test_recursion_error(self): + # Issue 36272 + code = textwrap.dedent(""" + import logging + + def rec(): + logging.error("foo") + rec() + + rec() + """) + rc, out, err = assert_python_failure("-c", code) + err = err.decode() + self.assertNotIn("Cannot recover from stack overflow.", err) + self.assertEqual(rc, 1) + + def test_get_level_names_mapping(self): + mapping = logging.getLevelNamesMapping() + self.assertEqual(logging._nameToLevel, mapping) # value is equivalent + self.assertIsNot(logging._nameToLevel, mapping) # but not the internal data + new_mapping = logging.getLevelNamesMapping() # another call -> another copy + self.assertIsNot(mapping, new_mapping) # verify not the same object as before + self.assertEqual(mapping, new_mapping) # but equivalent in value + + +class LogRecordTest(BaseTest): + def test_str_rep(self): + r = logging.makeLogRecord({}) + s = str(r) + self.assertTrue(s.startswith('')) + + def test_dict_arg(self): + h = RecordingHandler() + r = logging.getLogger() + r.addHandler(h) + d = {'less' : 'more' } + logging.warning('less is %(less)s', d) + self.assertIs(h.records[0].args, d) + self.assertEqual(h.records[0].message, 'less is more') + r.removeHandler(h) + h.close() + + @staticmethod # pickled as target of child process in the following test + def _extract_logrecord_process_name(key, logMultiprocessing, conn=None): + prev_logMultiprocessing = logging.logMultiprocessing + logging.logMultiprocessing = logMultiprocessing + try: + import multiprocessing as mp + name = mp.current_process().name + + r1 = logging.makeLogRecord({'msg': f'msg1_{key}'}) + + # https://bugs.python.org/issue45128 + with support.swap_item(sys.modules, 'multiprocessing', None): + r2 = logging.makeLogRecord({'msg': f'msg2_{key}'}) + + results = {'processName' : name, + 'r1.processName': r1.processName, + 'r2.processName': r2.processName, + } + finally: + logging.logMultiprocessing = prev_logMultiprocessing + if conn: + conn.send(results) + else: + return results + + def test_multiprocessing(self): + support.skip_if_broken_multiprocessing_synchronize() + multiprocessing_imported = 'multiprocessing' in sys.modules + try: + # logMultiprocessing is True by default + self.assertEqual(logging.logMultiprocessing, True) + + LOG_MULTI_PROCESSING = True + # When logMultiprocessing == True: + # In the main process processName = 'MainProcess' + r = logging.makeLogRecord({}) + self.assertEqual(r.processName, 'MainProcess') + + results = self._extract_logrecord_process_name(1, LOG_MULTI_PROCESSING) + self.assertEqual('MainProcess', results['processName']) + self.assertEqual('MainProcess', results['r1.processName']) + self.assertEqual('MainProcess', results['r2.processName']) + + # In other processes, processName is correct when multiprocessing in imported, + # but it is (incorrectly) defaulted to 'MainProcess' otherwise (bpo-38762). + import multiprocessing + parent_conn, child_conn = multiprocessing.Pipe() + p = multiprocessing.Process( + target=self._extract_logrecord_process_name, + args=(2, LOG_MULTI_PROCESSING, child_conn,) + ) + p.start() + results = parent_conn.recv() + self.assertNotEqual('MainProcess', results['processName']) + self.assertEqual(results['processName'], results['r1.processName']) + self.assertEqual('MainProcess', results['r2.processName']) + p.join() + + finally: + if multiprocessing_imported: + import multiprocessing + + def test_optional(self): + NONE = self.assertIsNone + NOT_NONE = self.assertIsNotNone + + r = logging.makeLogRecord({}) + NOT_NONE(r.thread) + NOT_NONE(r.threadName) + NOT_NONE(r.process) + NOT_NONE(r.processName) + NONE(r.taskName) + log_threads = logging.logThreads + log_processes = logging.logProcesses + log_multiprocessing = logging.logMultiprocessing + log_asyncio_tasks = logging.logAsyncioTasks + try: + logging.logThreads = False + logging.logProcesses = False + logging.logMultiprocessing = False + logging.logAsyncioTasks = False + r = logging.makeLogRecord({}) + + NONE(r.thread) + NONE(r.threadName) + NONE(r.process) + NONE(r.processName) + NONE(r.taskName) + finally: + logging.logThreads = log_threads + logging.logProcesses = log_processes + logging.logMultiprocessing = log_multiprocessing + logging.logAsyncioTasks = log_asyncio_tasks + + async def _make_record_async(self, assertion): + r = logging.makeLogRecord({}) + assertion(r.taskName) + + @support.requires_working_socket() + def test_taskName_with_asyncio_imported(self): + try: + make_record = self._make_record_async + with asyncio.Runner() as runner: + logging.logAsyncioTasks = True + runner.run(make_record(self.assertIsNotNone)) + logging.logAsyncioTasks = False + runner.run(make_record(self.assertIsNone)) + finally: + asyncio.set_event_loop_policy(None) + + @support.requires_working_socket() + def test_taskName_without_asyncio_imported(self): + try: + make_record = self._make_record_async + with asyncio.Runner() as runner, support.swap_item(sys.modules, 'asyncio', None): + logging.logAsyncioTasks = True + runner.run(make_record(self.assertIsNone)) + logging.logAsyncioTasks = False + runner.run(make_record(self.assertIsNone)) + finally: + asyncio.set_event_loop_policy(None) + + +class BasicConfigTest(unittest.TestCase): + + """Test suite for logging.basicConfig.""" + + def setUp(self): + super(BasicConfigTest, self).setUp() + self.handlers = logging.root.handlers + self.saved_handlers = logging._handlers.copy() + self.saved_handler_list = logging._handlerList[:] + self.original_logging_level = logging.root.level + self.addCleanup(self.cleanup) + logging.root.handlers = [] + + def tearDown(self): + for h in logging.root.handlers[:]: + logging.root.removeHandler(h) + h.close() + super(BasicConfigTest, self).tearDown() + + def cleanup(self): + setattr(logging.root, 'handlers', self.handlers) + logging._handlers.clear() + logging._handlers.update(self.saved_handlers) + logging._handlerList[:] = self.saved_handler_list + logging.root.setLevel(self.original_logging_level) + + def test_no_kwargs(self): + logging.basicConfig() + + # handler defaults to a StreamHandler to sys.stderr + self.assertEqual(len(logging.root.handlers), 1) + handler = logging.root.handlers[0] + self.assertIsInstance(handler, logging.StreamHandler) + self.assertEqual(handler.stream, sys.stderr) + + formatter = handler.formatter + # format defaults to logging.BASIC_FORMAT + self.assertEqual(formatter._style._fmt, logging.BASIC_FORMAT) + # datefmt defaults to None + self.assertIsNone(formatter.datefmt) + # style defaults to % + self.assertIsInstance(formatter._style, logging.PercentStyle) + + # level is not explicitly set + self.assertEqual(logging.root.level, self.original_logging_level) + + def test_strformatstyle(self): + with support.captured_stdout() as output: + logging.basicConfig(stream=sys.stdout, style="{") + logging.error("Log an error") + sys.stdout.seek(0) + self.assertEqual(output.getvalue().strip(), + "ERROR:root:Log an error") + + def test_stringtemplatestyle(self): + with support.captured_stdout() as output: + logging.basicConfig(stream=sys.stdout, style="$") + logging.error("Log an error") + sys.stdout.seek(0) + self.assertEqual(output.getvalue().strip(), + "ERROR:root:Log an error") + + def test_filename(self): + + def cleanup(h1, h2, fn): + h1.close() + h2.close() + os.remove(fn) + + logging.basicConfig(filename='test.log', encoding='utf-8') + + self.assertEqual(len(logging.root.handlers), 1) + handler = logging.root.handlers[0] + self.assertIsInstance(handler, logging.FileHandler) + + expected = logging.FileHandler('test.log', 'a', encoding='utf-8') + self.assertEqual(handler.stream.mode, expected.stream.mode) + self.assertEqual(handler.stream.name, expected.stream.name) + self.addCleanup(cleanup, handler, expected, 'test.log') + + def test_filemode(self): + + def cleanup(h1, h2, fn): + h1.close() + h2.close() + os.remove(fn) + + logging.basicConfig(filename='test.log', filemode='wb') + + handler = logging.root.handlers[0] + expected = logging.FileHandler('test.log', 'wb') + self.assertEqual(handler.stream.mode, expected.stream.mode) + self.addCleanup(cleanup, handler, expected, 'test.log') + + def test_stream(self): + stream = io.StringIO() + self.addCleanup(stream.close) + logging.basicConfig(stream=stream) + + self.assertEqual(len(logging.root.handlers), 1) + handler = logging.root.handlers[0] + self.assertIsInstance(handler, logging.StreamHandler) + self.assertEqual(handler.stream, stream) + + def test_format(self): + logging.basicConfig(format='%(asctime)s - %(message)s') + + formatter = logging.root.handlers[0].formatter + self.assertEqual(formatter._style._fmt, '%(asctime)s - %(message)s') + + def test_datefmt(self): + logging.basicConfig(datefmt='bar') + + formatter = logging.root.handlers[0].formatter + self.assertEqual(formatter.datefmt, 'bar') + + def test_style(self): + logging.basicConfig(style='$') + + formatter = logging.root.handlers[0].formatter + self.assertIsInstance(formatter._style, logging.StringTemplateStyle) + + def test_level(self): + old_level = logging.root.level + self.addCleanup(logging.root.setLevel, old_level) + + logging.basicConfig(level=57) + self.assertEqual(logging.root.level, 57) + # Test that second call has no effect + logging.basicConfig(level=58) + self.assertEqual(logging.root.level, 57) + + def test_incompatible(self): + assertRaises = self.assertRaises + handlers = [logging.StreamHandler()] + stream = sys.stderr + assertRaises(ValueError, logging.basicConfig, filename='test.log', + stream=stream) + assertRaises(ValueError, logging.basicConfig, filename='test.log', + handlers=handlers) + assertRaises(ValueError, logging.basicConfig, stream=stream, + handlers=handlers) + # Issue 23207: test for invalid kwargs + assertRaises(ValueError, logging.basicConfig, loglevel=logging.INFO) + # Should pop both filename and filemode even if filename is None + logging.basicConfig(filename=None, filemode='a') + + def test_handlers(self): + handlers = [ + logging.StreamHandler(), + logging.StreamHandler(sys.stdout), + logging.StreamHandler(), + ] + f = logging.Formatter() + handlers[2].setFormatter(f) + logging.basicConfig(handlers=handlers) + self.assertIs(handlers[0], logging.root.handlers[0]) + self.assertIs(handlers[1], logging.root.handlers[1]) + self.assertIs(handlers[2], logging.root.handlers[2]) + self.assertIsNotNone(handlers[0].formatter) + self.assertIsNotNone(handlers[1].formatter) + self.assertIs(handlers[2].formatter, f) + self.assertIs(handlers[0].formatter, handlers[1].formatter) + + def test_force(self): + old_string_io = io.StringIO() + new_string_io = io.StringIO() + old_handlers = [logging.StreamHandler(old_string_io)] + new_handlers = [logging.StreamHandler(new_string_io)] + logging.basicConfig(level=logging.WARNING, handlers=old_handlers) + logging.warning('warn') + logging.info('info') + logging.debug('debug') + self.assertEqual(len(logging.root.handlers), 1) + logging.basicConfig(level=logging.INFO, handlers=new_handlers, + force=True) + logging.warning('warn') + logging.info('info') + logging.debug('debug') + self.assertEqual(len(logging.root.handlers), 1) + self.assertEqual(old_string_io.getvalue().strip(), + 'WARNING:root:warn') + self.assertEqual(new_string_io.getvalue().strip(), + 'WARNING:root:warn\nINFO:root:info') + + def test_encoding(self): + try: + encoding = 'utf-8' + logging.basicConfig(filename='test.log', encoding=encoding, + errors='strict', + format='%(message)s', level=logging.DEBUG) + + self.assertEqual(len(logging.root.handlers), 1) + handler = logging.root.handlers[0] + self.assertIsInstance(handler, logging.FileHandler) + self.assertEqual(handler.encoding, encoding) + logging.debug('The Øresund Bridge joins Copenhagen to Malmö') + finally: + handler.close() + with open('test.log', encoding='utf-8') as f: + data = f.read().strip() + os.remove('test.log') + self.assertEqual(data, + 'The Øresund Bridge joins Copenhagen to Malmö') + + def test_encoding_errors(self): + try: + encoding = 'ascii' + logging.basicConfig(filename='test.log', encoding=encoding, + errors='ignore', + format='%(message)s', level=logging.DEBUG) + + self.assertEqual(len(logging.root.handlers), 1) + handler = logging.root.handlers[0] + self.assertIsInstance(handler, logging.FileHandler) + self.assertEqual(handler.encoding, encoding) + logging.debug('The Øresund Bridge joins Copenhagen to Malmö') + finally: + handler.close() + with open('test.log', encoding='utf-8') as f: + data = f.read().strip() + os.remove('test.log') + self.assertEqual(data, 'The resund Bridge joins Copenhagen to Malm') + + def test_encoding_errors_default(self): + try: + encoding = 'ascii' + logging.basicConfig(filename='test.log', encoding=encoding, + format='%(message)s', level=logging.DEBUG) + + self.assertEqual(len(logging.root.handlers), 1) + handler = logging.root.handlers[0] + self.assertIsInstance(handler, logging.FileHandler) + self.assertEqual(handler.encoding, encoding) + self.assertEqual(handler.errors, 'backslashreplace') + logging.debug('😂: ☃️: The Øresund Bridge joins Copenhagen to Malmö') + finally: + handler.close() + with open('test.log', encoding='utf-8') as f: + data = f.read().strip() + os.remove('test.log') + self.assertEqual(data, r'\U0001f602: \u2603\ufe0f: The \xd8resund ' + r'Bridge joins Copenhagen to Malm\xf6') + + def test_encoding_errors_none(self): + # Specifying None should behave as 'strict' + try: + encoding = 'ascii' + logging.basicConfig(filename='test.log', encoding=encoding, + errors=None, + format='%(message)s', level=logging.DEBUG) + + self.assertEqual(len(logging.root.handlers), 1) + handler = logging.root.handlers[0] + self.assertIsInstance(handler, logging.FileHandler) + self.assertEqual(handler.encoding, encoding) + self.assertIsNone(handler.errors) + + message = [] + + def dummy_handle_error(record): + message.append(str(sys.exception())) + + handler.handleError = dummy_handle_error + logging.debug('The Øresund Bridge joins Copenhagen to Malmö') + self.assertTrue(message) + self.assertIn("'ascii' codec can't encode " + "character '\\xd8' in position 4:", message[0]) + finally: + handler.close() + with open('test.log', encoding='utf-8') as f: + data = f.read().strip() + os.remove('test.log') + # didn't write anything due to the encoding error + self.assertEqual(data, r'') + + @support.requires_working_socket() + def test_log_taskName(self): + async def log_record(): + logging.warning('hello world') + + handler = None + log_filename = make_temp_file('.log', 'test-logging-taskname-') + self.addCleanup(os.remove, log_filename) + try: + encoding = 'utf-8' + logging.basicConfig(filename=log_filename, errors='strict', + encoding=encoding, level=logging.WARNING, + format='%(taskName)s - %(message)s') + + self.assertEqual(len(logging.root.handlers), 1) + handler = logging.root.handlers[0] + self.assertIsInstance(handler, logging.FileHandler) + + with asyncio.Runner(debug=True) as runner: + logging.logAsyncioTasks = True + runner.run(log_record()) + with open(log_filename, encoding='utf-8') as f: + data = f.read().strip() + self.assertRegex(data, r'Task-\d+ - hello world') + finally: + asyncio.set_event_loop_policy(None) + if handler: + handler.close() + + + def _test_log(self, method, level=None): + # logging.root has no handlers so basicConfig should be called + called = [] + + old_basic_config = logging.basicConfig + def my_basic_config(*a, **kw): + old_basic_config() + old_level = logging.root.level + logging.root.setLevel(100) # avoid having messages in stderr + self.addCleanup(logging.root.setLevel, old_level) + called.append((a, kw)) + + support.patch(self, logging, 'basicConfig', my_basic_config) + + log_method = getattr(logging, method) + if level is not None: + log_method(level, "test me") + else: + log_method("test me") + + # basicConfig was called with no arguments + self.assertEqual(called, [((), {})]) + + def test_log(self): + self._test_log('log', logging.WARNING) + + def test_debug(self): + self._test_log('debug') + + def test_info(self): + self._test_log('info') + + def test_warning(self): + self._test_log('warning') + + def test_error(self): + self._test_log('error') + + def test_critical(self): + self._test_log('critical') + + +class LoggerAdapterTest(unittest.TestCase): + def setUp(self): + super(LoggerAdapterTest, self).setUp() + old_handler_list = logging._handlerList[:] + + self.recording = RecordingHandler() + self.logger = logging.root + self.logger.addHandler(self.recording) + self.addCleanup(self.logger.removeHandler, self.recording) + self.addCleanup(self.recording.close) + + def cleanup(): + logging._handlerList[:] = old_handler_list + + self.addCleanup(cleanup) + self.addCleanup(logging.shutdown) + self.adapter = logging.LoggerAdapter(logger=self.logger, extra=None) + + def test_exception(self): + msg = 'testing exception: %r' + exc = None + try: + 1 / 0 + except ZeroDivisionError as e: + exc = e + self.adapter.exception(msg, self.recording) + + self.assertEqual(len(self.recording.records), 1) + record = self.recording.records[0] + self.assertEqual(record.levelno, logging.ERROR) + self.assertEqual(record.msg, msg) + self.assertEqual(record.args, (self.recording,)) + self.assertEqual(record.exc_info, + (exc.__class__, exc, exc.__traceback__)) + + def test_exception_excinfo(self): + try: + 1 / 0 + except ZeroDivisionError as e: + exc = e + + self.adapter.exception('exc_info test', exc_info=exc) + + self.assertEqual(len(self.recording.records), 1) + record = self.recording.records[0] + self.assertEqual(record.exc_info, + (exc.__class__, exc, exc.__traceback__)) + + def test_critical(self): + msg = 'critical test! %r' + self.adapter.critical(msg, self.recording) + + self.assertEqual(len(self.recording.records), 1) + record = self.recording.records[0] + self.assertEqual(record.levelno, logging.CRITICAL) + self.assertEqual(record.msg, msg) + self.assertEqual(record.args, (self.recording,)) + self.assertEqual(record.funcName, 'test_critical') + + def test_is_enabled_for(self): + old_disable = self.adapter.logger.manager.disable + self.adapter.logger.manager.disable = 33 + self.addCleanup(setattr, self.adapter.logger.manager, 'disable', + old_disable) + self.assertFalse(self.adapter.isEnabledFor(32)) + + def test_has_handlers(self): + self.assertTrue(self.adapter.hasHandlers()) + + for handler in self.logger.handlers: + self.logger.removeHandler(handler) + + self.assertFalse(self.logger.hasHandlers()) + self.assertFalse(self.adapter.hasHandlers()) + + def test_nested(self): + msg = 'Adapters can be nested, yo.' + adapter = PrefixAdapter(logger=self.logger, extra=None) + adapter_adapter = PrefixAdapter(logger=adapter, extra=None) + adapter_adapter.prefix = 'AdapterAdapter' + self.assertEqual(repr(adapter), repr(adapter_adapter)) + adapter_adapter.log(logging.CRITICAL, msg, self.recording) + self.assertEqual(len(self.recording.records), 1) + record = self.recording.records[0] + self.assertEqual(record.levelno, logging.CRITICAL) + self.assertEqual(record.msg, f"Adapter AdapterAdapter {msg}") + self.assertEqual(record.args, (self.recording,)) + self.assertEqual(record.funcName, 'test_nested') + orig_manager = adapter_adapter.manager + self.assertIs(adapter.manager, orig_manager) + self.assertIs(self.logger.manager, orig_manager) + temp_manager = object() + try: + adapter_adapter.manager = temp_manager + self.assertIs(adapter_adapter.manager, temp_manager) + self.assertIs(adapter.manager, temp_manager) + self.assertIs(self.logger.manager, temp_manager) + finally: + adapter_adapter.manager = orig_manager + self.assertIs(adapter_adapter.manager, orig_manager) + self.assertIs(adapter.manager, orig_manager) + self.assertIs(self.logger.manager, orig_manager) + + def test_styled_adapter(self): + # Test an example from the Cookbook. + records = self.recording.records + adapter = StyleAdapter(self.logger) + adapter.warning('Hello, {}!', 'world') + self.assertEqual(str(records[-1].msg), 'Hello, world!') + self.assertEqual(records[-1].funcName, 'test_styled_adapter') + adapter.log(logging.WARNING, 'Goodbye {}.', 'world') + self.assertEqual(str(records[-1].msg), 'Goodbye world.') + self.assertEqual(records[-1].funcName, 'test_styled_adapter') + + def test_nested_styled_adapter(self): + records = self.recording.records + adapter = PrefixAdapter(self.logger) + adapter.prefix = '{}' + adapter2 = StyleAdapter(adapter) + adapter2.warning('Hello, {}!', 'world') + self.assertEqual(str(records[-1].msg), '{} Hello, world!') + self.assertEqual(records[-1].funcName, 'test_nested_styled_adapter') + adapter2.log(logging.WARNING, 'Goodbye {}.', 'world') + self.assertEqual(str(records[-1].msg), '{} Goodbye world.') + self.assertEqual(records[-1].funcName, 'test_nested_styled_adapter') + + def test_find_caller_with_stacklevel(self): + the_level = 1 + trigger = self.adapter.warning + + def innermost(): + trigger('test', stacklevel=the_level) + + def inner(): + innermost() + + def outer(): + inner() + + records = self.recording.records + outer() + self.assertEqual(records[-1].funcName, 'innermost') + lineno = records[-1].lineno + the_level += 1 + outer() + self.assertEqual(records[-1].funcName, 'inner') + self.assertGreater(records[-1].lineno, lineno) + lineno = records[-1].lineno + the_level += 1 + outer() + self.assertEqual(records[-1].funcName, 'outer') + self.assertGreater(records[-1].lineno, lineno) + lineno = records[-1].lineno + the_level += 1 + outer() + self.assertEqual(records[-1].funcName, 'test_find_caller_with_stacklevel') + self.assertGreater(records[-1].lineno, lineno) + + def test_extra_in_records(self): + self.adapter = logging.LoggerAdapter(logger=self.logger, + extra={'foo': '1'}) + + self.adapter.critical('foo should be here') + self.assertEqual(len(self.recording.records), 1) + record = self.recording.records[0] + self.assertTrue(hasattr(record, 'foo')) + self.assertEqual(record.foo, '1') + + def test_extra_not_merged_by_default(self): + self.adapter.critical('foo should NOT be here', extra={'foo': 'nope'}) + self.assertEqual(len(self.recording.records), 1) + record = self.recording.records[0] + self.assertFalse(hasattr(record, 'foo')) + + +class PrefixAdapter(logging.LoggerAdapter): + prefix = 'Adapter' + + def process(self, msg, kwargs): + return f"{self.prefix} {msg}", kwargs + + +class Message: + def __init__(self, fmt, args): + self.fmt = fmt + self.args = args + + def __str__(self): + return self.fmt.format(*self.args) + + +class StyleAdapter(logging.LoggerAdapter): + def log(self, level, msg, /, *args, stacklevel=1, **kwargs): + if self.isEnabledFor(level): + msg, kwargs = self.process(msg, kwargs) + self.logger.log(level, Message(msg, args), **kwargs, + stacklevel=stacklevel+1) + + +class LoggerTest(BaseTest, AssertErrorMessage): + + def setUp(self): + super(LoggerTest, self).setUp() + self.recording = RecordingHandler() + self.logger = logging.Logger(name='blah') + self.logger.addHandler(self.recording) + self.addCleanup(self.logger.removeHandler, self.recording) + self.addCleanup(self.recording.close) + self.addCleanup(logging.shutdown) + + def test_set_invalid_level(self): + self.assert_error_message( + TypeError, 'Level not an integer or a valid string: None', + self.logger.setLevel, None) + self.assert_error_message( + TypeError, 'Level not an integer or a valid string: (0, 0)', + self.logger.setLevel, (0, 0)) + + def test_exception(self): + msg = 'testing exception: %r' + exc = None + try: + 1 / 0 + except ZeroDivisionError as e: + exc = e + self.logger.exception(msg, self.recording) + + self.assertEqual(len(self.recording.records), 1) + record = self.recording.records[0] + self.assertEqual(record.levelno, logging.ERROR) + self.assertEqual(record.msg, msg) + self.assertEqual(record.args, (self.recording,)) + self.assertEqual(record.exc_info, + (exc.__class__, exc, exc.__traceback__)) + + def test_log_invalid_level_with_raise(self): + with support.swap_attr(logging, 'raiseExceptions', True): + self.assertRaises(TypeError, self.logger.log, '10', 'test message') + + def test_log_invalid_level_no_raise(self): + with support.swap_attr(logging, 'raiseExceptions', False): + self.logger.log('10', 'test message') # no exception happens + + def test_find_caller_with_stack_info(self): + called = [] + support.patch(self, logging.traceback, 'print_stack', + lambda f, file: called.append(file.getvalue())) + + self.logger.findCaller(stack_info=True) + + self.assertEqual(len(called), 1) + self.assertEqual('Stack (most recent call last):\n', called[0]) + + def test_find_caller_with_stacklevel(self): + the_level = 1 + trigger = self.logger.warning + + def innermost(): + trigger('test', stacklevel=the_level) + + def inner(): + innermost() + + def outer(): + inner() + + records = self.recording.records + outer() + self.assertEqual(records[-1].funcName, 'innermost') + lineno = records[-1].lineno + the_level += 1 + outer() + self.assertEqual(records[-1].funcName, 'inner') + self.assertGreater(records[-1].lineno, lineno) + lineno = records[-1].lineno + the_level += 1 + outer() + self.assertEqual(records[-1].funcName, 'outer') + self.assertGreater(records[-1].lineno, lineno) + lineno = records[-1].lineno + root_logger = logging.getLogger() + root_logger.addHandler(self.recording) + trigger = logging.warning + outer() + self.assertEqual(records[-1].funcName, 'outer') + root_logger.removeHandler(self.recording) + trigger = self.logger.warning + the_level += 1 + outer() + self.assertEqual(records[-1].funcName, 'test_find_caller_with_stacklevel') + self.assertGreater(records[-1].lineno, lineno) + + def test_make_record_with_extra_overwrite(self): + name = 'my record' + level = 13 + fn = lno = msg = args = exc_info = func = sinfo = None + rv = logging._logRecordFactory(name, level, fn, lno, msg, args, + exc_info, func, sinfo) + + for key in ('message', 'asctime') + tuple(rv.__dict__.keys()): + extra = {key: 'some value'} + self.assertRaises(KeyError, self.logger.makeRecord, name, level, + fn, lno, msg, args, exc_info, + extra=extra, sinfo=sinfo) + + def test_make_record_with_extra_no_overwrite(self): + name = 'my record' + level = 13 + fn = lno = msg = args = exc_info = func = sinfo = None + extra = {'valid_key': 'some value'} + result = self.logger.makeRecord(name, level, fn, lno, msg, args, + exc_info, extra=extra, sinfo=sinfo) + self.assertIn('valid_key', result.__dict__) + + def test_has_handlers(self): + self.assertTrue(self.logger.hasHandlers()) + + for handler in self.logger.handlers: + self.logger.removeHandler(handler) + self.assertFalse(self.logger.hasHandlers()) + + def test_has_handlers_no_propagate(self): + child_logger = logging.getLogger('blah.child') + child_logger.propagate = False + self.assertFalse(child_logger.hasHandlers()) + + def test_is_enabled_for(self): + old_disable = self.logger.manager.disable + self.logger.manager.disable = 23 + self.addCleanup(setattr, self.logger.manager, 'disable', old_disable) + self.assertFalse(self.logger.isEnabledFor(22)) + + def test_is_enabled_for_disabled_logger(self): + old_disabled = self.logger.disabled + old_disable = self.logger.manager.disable + + self.logger.disabled = True + self.logger.manager.disable = 21 + + self.addCleanup(setattr, self.logger, 'disabled', old_disabled) + self.addCleanup(setattr, self.logger.manager, 'disable', old_disable) + + self.assertFalse(self.logger.isEnabledFor(22)) + + def test_root_logger_aliases(self): + root = logging.getLogger() + self.assertIs(root, logging.root) + self.assertIs(root, logging.getLogger(None)) + self.assertIs(root, logging.getLogger('')) + self.assertIs(root, logging.getLogger('root')) + self.assertIs(root, logging.getLogger('foo').root) + self.assertIs(root, logging.getLogger('foo.bar').root) + self.assertIs(root, logging.getLogger('foo').parent) + + self.assertIsNot(root, logging.getLogger('\0')) + self.assertIsNot(root, logging.getLogger('foo.bar').parent) + + def test_invalid_names(self): + self.assertRaises(TypeError, logging.getLogger, any) + self.assertRaises(TypeError, logging.getLogger, b'foo') + + def test_pickling(self): + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + for name in ('', 'root', 'foo', 'foo.bar', 'baz.bar'): + logger = logging.getLogger(name) + s = pickle.dumps(logger, proto) + unpickled = pickle.loads(s) + self.assertIs(unpickled, logger) + + def test_caching(self): + root = self.root_logger + logger1 = logging.getLogger("abc") + logger2 = logging.getLogger("abc.def") + + # Set root logger level and ensure cache is empty + root.setLevel(logging.ERROR) + self.assertEqual(logger2.getEffectiveLevel(), logging.ERROR) + self.assertEqual(logger2._cache, {}) + + # Ensure cache is populated and calls are consistent + self.assertTrue(logger2.isEnabledFor(logging.ERROR)) + self.assertFalse(logger2.isEnabledFor(logging.DEBUG)) + self.assertEqual(logger2._cache, {logging.ERROR: True, logging.DEBUG: False}) + self.assertEqual(root._cache, {}) + self.assertTrue(logger2.isEnabledFor(logging.ERROR)) + + # Ensure root cache gets populated + self.assertEqual(root._cache, {}) + self.assertTrue(root.isEnabledFor(logging.ERROR)) + self.assertEqual(root._cache, {logging.ERROR: True}) + + # Set parent logger level and ensure caches are emptied + logger1.setLevel(logging.CRITICAL) + self.assertEqual(logger2.getEffectiveLevel(), logging.CRITICAL) + self.assertEqual(logger2._cache, {}) + + # Ensure logger2 uses parent logger's effective level + self.assertFalse(logger2.isEnabledFor(logging.ERROR)) + + # Set level to NOTSET and ensure caches are empty + logger2.setLevel(logging.NOTSET) + self.assertEqual(logger2.getEffectiveLevel(), logging.CRITICAL) + self.assertEqual(logger2._cache, {}) + self.assertEqual(logger1._cache, {}) + self.assertEqual(root._cache, {}) + + # Verify logger2 follows parent and not root + self.assertFalse(logger2.isEnabledFor(logging.ERROR)) + self.assertTrue(logger2.isEnabledFor(logging.CRITICAL)) + self.assertFalse(logger1.isEnabledFor(logging.ERROR)) + self.assertTrue(logger1.isEnabledFor(logging.CRITICAL)) + self.assertTrue(root.isEnabledFor(logging.ERROR)) + + # Disable logging in manager and ensure caches are clear + logging.disable() + self.assertEqual(logger2.getEffectiveLevel(), logging.CRITICAL) + self.assertEqual(logger2._cache, {}) + self.assertEqual(logger1._cache, {}) + self.assertEqual(root._cache, {}) + + # Ensure no loggers are enabled + self.assertFalse(logger1.isEnabledFor(logging.CRITICAL)) + self.assertFalse(logger2.isEnabledFor(logging.CRITICAL)) + self.assertFalse(root.isEnabledFor(logging.CRITICAL)) + + +class BaseFileTest(BaseTest): + "Base class for handler tests that write log files" + + def setUp(self): + BaseTest.setUp(self) + self.fn = make_temp_file(".log", "test_logging-2-") + self.rmfiles = [] + + def tearDown(self): + for fn in self.rmfiles: + os.unlink(fn) + if os.path.exists(self.fn): + os.unlink(self.fn) + BaseTest.tearDown(self) + + def assertLogFile(self, filename): + "Assert a log file is there and register it for deletion" + self.assertTrue(os.path.exists(filename), + msg="Log file %r does not exist" % filename) + self.rmfiles.append(filename) + + def next_rec(self): + return logging.LogRecord('n', logging.DEBUG, 'p', 1, + self.next_message(), None, None, None) + +class FileHandlerTest(BaseFileTest): + def test_delay(self): + os.unlink(self.fn) + fh = logging.FileHandler(self.fn, encoding='utf-8', delay=True) + self.assertIsNone(fh.stream) + self.assertFalse(os.path.exists(self.fn)) + fh.handle(logging.makeLogRecord({})) + self.assertIsNotNone(fh.stream) + self.assertTrue(os.path.exists(self.fn)) + fh.close() + + def test_emit_after_closing_in_write_mode(self): + # Issue #42378 + os.unlink(self.fn) + fh = logging.FileHandler(self.fn, encoding='utf-8', mode='w') + fh.setFormatter(logging.Formatter('%(message)s')) + fh.emit(self.next_rec()) # '1' + fh.close() + fh.emit(self.next_rec()) # '2' + with open(self.fn) as fp: + self.assertEqual(fp.read().strip(), '1') + +class RotatingFileHandlerTest(BaseFileTest): + def test_should_not_rollover(self): + # If file is empty rollover never occurs + rh = logging.handlers.RotatingFileHandler( + self.fn, encoding="utf-8", maxBytes=1) + self.assertFalse(rh.shouldRollover(None)) + rh.close() + + # If maxBytes is zero rollover never occurs + rh = logging.handlers.RotatingFileHandler( + self.fn, encoding="utf-8", maxBytes=0) + self.assertFalse(rh.shouldRollover(None)) + rh.close() + + with open(self.fn, 'wb') as f: + f.write(b'\n') + rh = logging.handlers.RotatingFileHandler( + self.fn, encoding="utf-8", maxBytes=0) + self.assertFalse(rh.shouldRollover(None)) + rh.close() + + @unittest.skipIf(support.is_wasi, "WASI does not have /dev/null.") + def test_should_not_rollover_non_file(self): + # bpo-45401 - test with special file + # We set maxBytes to 1 so that rollover would normally happen, except + # for the check for regular files + rh = logging.handlers.RotatingFileHandler( + os.devnull, encoding="utf-8", maxBytes=1) + self.assertFalse(rh.shouldRollover(self.next_rec())) + rh.close() + + def test_should_rollover(self): + with open(self.fn, 'wb') as f: + f.write(b'\n') + rh = logging.handlers.RotatingFileHandler(self.fn, encoding="utf-8", maxBytes=2) + self.assertTrue(rh.shouldRollover(self.next_rec())) + rh.close() + + def test_file_created(self): + # checks that the file is created and assumes it was created + # by us + os.unlink(self.fn) + rh = logging.handlers.RotatingFileHandler(self.fn, encoding="utf-8") + rh.emit(self.next_rec()) + self.assertLogFile(self.fn) + rh.close() + + def test_max_bytes(self, delay=False): + kwargs = {'delay': delay} if delay else {} + os.unlink(self.fn) + rh = logging.handlers.RotatingFileHandler( + self.fn, encoding="utf-8", backupCount=2, maxBytes=100, **kwargs) + self.assertIs(os.path.exists(self.fn), not delay) + small = logging.makeLogRecord({'msg': 'a'}) + large = logging.makeLogRecord({'msg': 'b'*100}) + self.assertFalse(rh.shouldRollover(small)) + self.assertFalse(rh.shouldRollover(large)) + rh.emit(small) + self.assertLogFile(self.fn) + self.assertFalse(os.path.exists(self.fn + ".1")) + self.assertFalse(rh.shouldRollover(small)) + self.assertTrue(rh.shouldRollover(large)) + rh.emit(large) + self.assertTrue(os.path.exists(self.fn)) + self.assertLogFile(self.fn + ".1") + self.assertFalse(os.path.exists(self.fn + ".2")) + self.assertTrue(rh.shouldRollover(small)) + self.assertTrue(rh.shouldRollover(large)) + rh.close() + + def test_max_bytes_delay(self): + self.test_max_bytes(delay=True) + + def test_rollover_filenames(self): + def namer(name): + return name + ".test" + rh = logging.handlers.RotatingFileHandler( + self.fn, encoding="utf-8", backupCount=2, maxBytes=1) + rh.namer = namer + rh.emit(self.next_rec()) + self.assertLogFile(self.fn) + self.assertFalse(os.path.exists(namer(self.fn + ".1"))) + rh.emit(self.next_rec()) + self.assertLogFile(namer(self.fn + ".1")) + self.assertFalse(os.path.exists(namer(self.fn + ".2"))) + rh.emit(self.next_rec()) + self.assertLogFile(namer(self.fn + ".2")) + self.assertFalse(os.path.exists(namer(self.fn + ".3"))) + rh.emit(self.next_rec()) + self.assertFalse(os.path.exists(namer(self.fn + ".3"))) + rh.close() + + def test_namer_rotator_inheritance(self): + class HandlerWithNamerAndRotator(logging.handlers.RotatingFileHandler): + def namer(self, name): + return name + ".test" + + def rotator(self, source, dest): + if os.path.exists(source): + os.replace(source, dest + ".rotated") + + rh = HandlerWithNamerAndRotator( + self.fn, encoding="utf-8", backupCount=2, maxBytes=1) + self.assertEqual(rh.namer(self.fn), self.fn + ".test") + rh.emit(self.next_rec()) + self.assertLogFile(self.fn) + rh.emit(self.next_rec()) + self.assertLogFile(rh.namer(self.fn + ".1") + ".rotated") + self.assertFalse(os.path.exists(rh.namer(self.fn + ".1"))) + rh.close() + + @support.requires_zlib() + def test_rotator(self): + def namer(name): + return name + ".gz" + + def rotator(source, dest): + with open(source, "rb") as sf: + data = sf.read() + compressed = zlib.compress(data, 9) + with open(dest, "wb") as df: + df.write(compressed) + os.remove(source) + + rh = logging.handlers.RotatingFileHandler( + self.fn, encoding="utf-8", backupCount=2, maxBytes=1) + rh.rotator = rotator + rh.namer = namer + m1 = self.next_rec() + rh.emit(m1) + self.assertLogFile(self.fn) + m2 = self.next_rec() + rh.emit(m2) + fn = namer(self.fn + ".1") + self.assertLogFile(fn) + newline = os.linesep + with open(fn, "rb") as f: + compressed = f.read() + data = zlib.decompress(compressed) + self.assertEqual(data.decode("ascii"), m1.msg + newline) + rh.emit(self.next_rec()) + fn = namer(self.fn + ".2") + self.assertLogFile(fn) + with open(fn, "rb") as f: + compressed = f.read() + data = zlib.decompress(compressed) + self.assertEqual(data.decode("ascii"), m1.msg + newline) + rh.emit(self.next_rec()) + fn = namer(self.fn + ".2") + with open(fn, "rb") as f: + compressed = f.read() + data = zlib.decompress(compressed) + self.assertEqual(data.decode("ascii"), m2.msg + newline) + self.assertFalse(os.path.exists(namer(self.fn + ".3"))) + rh.close() + +class TimedRotatingFileHandlerTest(BaseFileTest): + @unittest.skipIf(support.is_wasi, "WASI does not have /dev/null.") + def test_should_not_rollover(self): + # See bpo-45401. Should only ever rollover regular files + fh = logging.handlers.TimedRotatingFileHandler( + os.devnull, 'S', encoding="utf-8", backupCount=1) + time.sleep(1.1) # a little over a second ... + r = logging.makeLogRecord({'msg': 'testing - device file'}) + self.assertFalse(fh.shouldRollover(r)) + fh.close() + + # other test methods added below + def test_rollover(self): + fh = logging.handlers.TimedRotatingFileHandler( + self.fn, 'S', encoding="utf-8", backupCount=1) + fmt = logging.Formatter('%(asctime)s %(message)s') + fh.setFormatter(fmt) + r1 = logging.makeLogRecord({'msg': 'testing - initial'}) + fh.emit(r1) + self.assertLogFile(self.fn) + time.sleep(1.1) # a little over a second ... + r2 = logging.makeLogRecord({'msg': 'testing - after delay'}) + fh.emit(r2) + fh.close() + # At this point, we should have a recent rotated file which we + # can test for the existence of. However, in practice, on some + # machines which run really slowly, we don't know how far back + # in time to go to look for the log file. So, we go back a fair + # bit, and stop as soon as we see a rotated file. In theory this + # could of course still fail, but the chances are lower. + found = False + now = datetime.datetime.now() + GO_BACK = 5 * 60 # seconds + for secs in range(GO_BACK): + prev = now - datetime.timedelta(seconds=secs) + fn = self.fn + prev.strftime(".%Y-%m-%d_%H-%M-%S") + found = os.path.exists(fn) + if found: + self.rmfiles.append(fn) + break + msg = 'No rotated files found, went back %d seconds' % GO_BACK + if not found: + # print additional diagnostics + dn, fn = os.path.split(self.fn) + files = [f for f in os.listdir(dn) if f.startswith(fn)] + print('Test time: %s' % now.strftime("%Y-%m-%d %H-%M-%S"), file=sys.stderr) + print('The only matching files are: %s' % files, file=sys.stderr) + for f in files: + print('Contents of %s:' % f) + path = os.path.join(dn, f) + with open(path, 'r') as tf: + print(tf.read()) + self.assertTrue(found, msg=msg) + + def test_rollover_at_midnight(self, weekly=False): + os_helper.unlink(self.fn) + now = datetime.datetime.now() + atTime = now.time() + if not 0.1 < atTime.microsecond/1e6 < 0.9: + # The test requires all records to be emitted within + # the range of the same whole second. + time.sleep((0.1 - atTime.microsecond/1e6) % 1.0) + now = datetime.datetime.now() + atTime = now.time() + atTime = atTime.replace(microsecond=0) + fmt = logging.Formatter('%(asctime)s %(message)s') + when = f'W{now.weekday()}' if weekly else 'MIDNIGHT' + for i in range(3): + fh = logging.handlers.TimedRotatingFileHandler( + self.fn, encoding="utf-8", when=when, atTime=atTime) + fh.setFormatter(fmt) + r2 = logging.makeLogRecord({'msg': f'testing1 {i}'}) + fh.emit(r2) + fh.close() + self.assertLogFile(self.fn) + with open(self.fn, encoding="utf-8") as f: + for i, line in enumerate(f): + self.assertIn(f'testing1 {i}', line) + + os.utime(self.fn, (now.timestamp() - 1,)*2) + for i in range(2): + fh = logging.handlers.TimedRotatingFileHandler( + self.fn, encoding="utf-8", when=when, atTime=atTime) + fh.setFormatter(fmt) + r2 = logging.makeLogRecord({'msg': f'testing2 {i}'}) + fh.emit(r2) + fh.close() + rolloverDate = now - datetime.timedelta(days=7 if weekly else 1) + otherfn = f'{self.fn}.{rolloverDate:%Y-%m-%d}' + self.assertLogFile(otherfn) + with open(self.fn, encoding="utf-8") as f: + for i, line in enumerate(f): + self.assertIn(f'testing2 {i}', line) + with open(otherfn, encoding="utf-8") as f: + for i, line in enumerate(f): + self.assertIn(f'testing1 {i}', line) + + def test_rollover_at_weekday(self): + self.test_rollover_at_midnight(weekly=True) + + def test_invalid(self): + assertRaises = self.assertRaises + assertRaises(ValueError, logging.handlers.TimedRotatingFileHandler, + self.fn, 'X', encoding="utf-8", delay=True) + assertRaises(ValueError, logging.handlers.TimedRotatingFileHandler, + self.fn, 'W', encoding="utf-8", delay=True) + assertRaises(ValueError, logging.handlers.TimedRotatingFileHandler, + self.fn, 'W7', encoding="utf-8", delay=True) + + # TODO: Test for utc=False. + def test_compute_rollover_daily_attime(self): + currentTime = 0 + rh = logging.handlers.TimedRotatingFileHandler( + self.fn, encoding="utf-8", when='MIDNIGHT', + utc=True, atTime=None) + try: + actual = rh.computeRollover(currentTime) + self.assertEqual(actual, currentTime + 24 * 60 * 60) + + actual = rh.computeRollover(currentTime + 24 * 60 * 60 - 1) + self.assertEqual(actual, currentTime + 24 * 60 * 60) + + actual = rh.computeRollover(currentTime + 24 * 60 * 60) + self.assertEqual(actual, currentTime + 48 * 60 * 60) + + actual = rh.computeRollover(currentTime + 25 * 60 * 60) + self.assertEqual(actual, currentTime + 48 * 60 * 60) + finally: + rh.close() + + atTime = datetime.time(12, 0, 0) + rh = logging.handlers.TimedRotatingFileHandler( + self.fn, encoding="utf-8", when='MIDNIGHT', + utc=True, atTime=atTime) + try: + actual = rh.computeRollover(currentTime) + self.assertEqual(actual, currentTime + 12 * 60 * 60) + + actual = rh.computeRollover(currentTime + 12 * 60 * 60 - 1) + self.assertEqual(actual, currentTime + 12 * 60 * 60) + + actual = rh.computeRollover(currentTime + 12 * 60 * 60) + self.assertEqual(actual, currentTime + 36 * 60 * 60) + + actual = rh.computeRollover(currentTime + 13 * 60 * 60) + self.assertEqual(actual, currentTime + 36 * 60 * 60) + finally: + rh.close() + + # TODO: Test for utc=False. + def test_compute_rollover_weekly_attime(self): + currentTime = int(time.time()) + today = currentTime - currentTime % 86400 + + atTime = datetime.time(12, 0, 0) + + wday = time.gmtime(today).tm_wday + for day in range(7): + rh = logging.handlers.TimedRotatingFileHandler( + self.fn, encoding="utf-8", when='W%d' % day, interval=1, backupCount=0, + utc=True, atTime=atTime) + try: + if wday > day: + # The rollover day has already passed this week, so we + # go over into next week + expected = (7 - wday + day) + else: + expected = (day - wday) + # At this point expected is in days from now, convert to seconds + expected *= 24 * 60 * 60 + # Add in the rollover time + expected += 12 * 60 * 60 + # Add in adjustment for today + expected += today + + actual = rh.computeRollover(today) + if actual != expected: + print('failed in timezone: %d' % time.timezone) + print('local vars: %s' % locals()) + self.assertEqual(actual, expected) + + actual = rh.computeRollover(today + 12 * 60 * 60 - 1) + if actual != expected: + print('failed in timezone: %d' % time.timezone) + print('local vars: %s' % locals()) + self.assertEqual(actual, expected) + + if day == wday: + # goes into following week + expected += 7 * 24 * 60 * 60 + actual = rh.computeRollover(today + 12 * 60 * 60) + if actual != expected: + print('failed in timezone: %d' % time.timezone) + print('local vars: %s' % locals()) + self.assertEqual(actual, expected) + + actual = rh.computeRollover(today + 13 * 60 * 60) + if actual != expected: + print('failed in timezone: %d' % time.timezone) + print('local vars: %s' % locals()) + self.assertEqual(actual, expected) + finally: + rh.close() + + def test_compute_files_to_delete(self): + # See bpo-46063 for background + wd = tempfile.mkdtemp(prefix='test_logging_') + self.addCleanup(shutil.rmtree, wd) + times = [] + dt = datetime.datetime.now() + for i in range(10): + times.append(dt.strftime('%Y-%m-%d_%H-%M-%S')) + dt += datetime.timedelta(seconds=5) + prefixes = ('a.b', 'a.b.c', 'd.e', 'd.e.f', 'g') + files = [] + rotators = [] + for prefix in prefixes: + p = os.path.join(wd, '%s.log' % prefix) + rotator = logging.handlers.TimedRotatingFileHandler(p, when='s', + interval=5, + backupCount=7, + delay=True) + rotators.append(rotator) + if prefix.startswith('a.b'): + for t in times: + files.append('%s.log.%s' % (prefix, t)) + elif prefix.startswith('d.e'): + def namer(filename): + dirname, basename = os.path.split(filename) + basename = basename.replace('.log', '') + '.log' + return os.path.join(dirname, basename) + rotator.namer = namer + for t in times: + files.append('%s.%s.log' % (prefix, t)) + elif prefix == 'g': + def namer(filename): + dirname, basename = os.path.split(filename) + basename = 'g' + basename[6:] + '.oldlog' + return os.path.join(dirname, basename) + rotator.namer = namer + for t in times: + files.append('g%s.oldlog' % t) + # Create empty files + for fn in files: + p = os.path.join(wd, fn) + with open(p, 'wb') as f: + pass + # Now the checks that only the correct files are offered up for deletion + for i, prefix in enumerate(prefixes): + rotator = rotators[i] + candidates = rotator.getFilesToDelete() + self.assertEqual(len(candidates), 3, candidates) + if prefix.startswith('a.b'): + p = '%s.log.' % prefix + for c in candidates: + d, fn = os.path.split(c) + self.assertTrue(fn.startswith(p)) + elif prefix.startswith('d.e'): + for c in candidates: + d, fn = os.path.split(c) + self.assertTrue(fn.endswith('.log'), fn) + self.assertTrue(fn.startswith(prefix + '.') and + fn[len(prefix) + 2].isdigit()) + elif prefix == 'g': + for c in candidates: + d, fn = os.path.split(c) + self.assertTrue(fn.endswith('.oldlog')) + self.assertTrue(fn.startswith('g') and fn[1].isdigit()) + + def test_compute_files_to_delete_same_filename_different_extensions(self): + # See GH-93205 for background + wd = pathlib.Path(tempfile.mkdtemp(prefix='test_logging_')) + self.addCleanup(shutil.rmtree, wd) + times = [] + dt = datetime.datetime.now() + n_files = 10 + for _ in range(n_files): + times.append(dt.strftime('%Y-%m-%d_%H-%M-%S')) + dt += datetime.timedelta(seconds=5) + prefixes = ('a.log', 'a.log.b') + files = [] + rotators = [] + for i, prefix in enumerate(prefixes): + backupCount = i+1 + rotator = logging.handlers.TimedRotatingFileHandler(wd / prefix, when='s', + interval=5, + backupCount=backupCount, + delay=True) + rotators.append(rotator) + for t in times: + files.append('%s.%s' % (prefix, t)) + for t in times: + files.append('a.log.%s.c' % t) + # Create empty files + for f in files: + (wd / f).touch() + # Now the checks that only the correct files are offered up for deletion + for i, prefix in enumerate(prefixes): + backupCount = i+1 + rotator = rotators[i] + candidates = rotator.getFilesToDelete() + self.assertEqual(len(candidates), n_files - backupCount, candidates) + matcher = re.compile(r"^\d{4}-\d{2}-\d{2}_\d{2}-\d{2}-\d{2}\Z") + for c in candidates: + d, fn = os.path.split(c) + self.assertTrue(fn.startswith(prefix+'.')) + suffix = fn[(len(prefix)+1):] + self.assertRegex(suffix, matcher) + + # Run with US-style DST rules: DST begins 2 a.m. on second Sunday in + # March (M3.2.0) and ends 2 a.m. on first Sunday in November (M11.1.0). + @support.run_with_tz('EST+05EDT,M3.2.0,M11.1.0') + def test_compute_rollover_MIDNIGHT_local(self): + # DST begins at 2012-3-11T02:00:00 and ends at 2012-11-4T02:00:00. + DT = datetime.datetime + def test(current, expected): + actual = fh.computeRollover(current.timestamp()) + diff = actual - expected.timestamp() + if diff: + self.assertEqual(diff, 0, datetime.timedelta(seconds=diff)) + + fh = logging.handlers.TimedRotatingFileHandler( + self.fn, encoding="utf-8", when='MIDNIGHT', utc=False) + + test(DT(2012, 3, 10, 23, 59, 59), DT(2012, 3, 11, 0, 0)) + test(DT(2012, 3, 11, 0, 0), DT(2012, 3, 12, 0, 0)) + test(DT(2012, 3, 11, 1, 0), DT(2012, 3, 12, 0, 0)) + + test(DT(2012, 11, 3, 23, 59, 59), DT(2012, 11, 4, 0, 0)) + test(DT(2012, 11, 4, 0, 0), DT(2012, 11, 5, 0, 0)) + test(DT(2012, 11, 4, 1, 0), DT(2012, 11, 5, 0, 0)) + + fh.close() + + fh = logging.handlers.TimedRotatingFileHandler( + self.fn, encoding="utf-8", when='MIDNIGHT', utc=False, + atTime=datetime.time(12, 0, 0)) + + test(DT(2012, 3, 10, 11, 59, 59), DT(2012, 3, 10, 12, 0)) + test(DT(2012, 3, 10, 12, 0), DT(2012, 3, 11, 12, 0)) + test(DT(2012, 3, 10, 13, 0), DT(2012, 3, 11, 12, 0)) + + test(DT(2012, 11, 3, 11, 59, 59), DT(2012, 11, 3, 12, 0)) + test(DT(2012, 11, 3, 12, 0), DT(2012, 11, 4, 12, 0)) + test(DT(2012, 11, 3, 13, 0), DT(2012, 11, 4, 12, 0)) + + fh.close() + + fh = logging.handlers.TimedRotatingFileHandler( + self.fn, encoding="utf-8", when='MIDNIGHT', utc=False, + atTime=datetime.time(2, 0, 0)) + + test(DT(2012, 3, 10, 1, 59, 59), DT(2012, 3, 10, 2, 0)) + # 2:00:00 is the same as 3:00:00 at 2012-3-11. + test(DT(2012, 3, 10, 2, 0), DT(2012, 3, 11, 3, 0)) + test(DT(2012, 3, 10, 3, 0), DT(2012, 3, 11, 3, 0)) + + test(DT(2012, 3, 11, 1, 59, 59), DT(2012, 3, 11, 3, 0)) + # No time between 2:00:00 and 3:00:00 at 2012-3-11. + test(DT(2012, 3, 11, 3, 0), DT(2012, 3, 12, 2, 0)) + test(DT(2012, 3, 11, 4, 0), DT(2012, 3, 12, 2, 0)) + + test(DT(2012, 11, 3, 1, 59, 59), DT(2012, 11, 3, 2, 0)) + test(DT(2012, 11, 3, 2, 0), DT(2012, 11, 4, 2, 0)) + test(DT(2012, 11, 3, 3, 0), DT(2012, 11, 4, 2, 0)) + + # 1:00:00-2:00:00 is repeated twice at 2012-11-4. + test(DT(2012, 11, 4, 1, 59, 59), DT(2012, 11, 4, 2, 0)) + test(DT(2012, 11, 4, 1, 59, 59, fold=1), DT(2012, 11, 4, 2, 0)) + test(DT(2012, 11, 4, 2, 0), DT(2012, 11, 5, 2, 0)) + test(DT(2012, 11, 4, 3, 0), DT(2012, 11, 5, 2, 0)) + + fh.close() + + fh = logging.handlers.TimedRotatingFileHandler( + self.fn, encoding="utf-8", when='MIDNIGHT', utc=False, + atTime=datetime.time(2, 30, 0)) + + test(DT(2012, 3, 10, 2, 29, 59), DT(2012, 3, 10, 2, 30)) + # No time 2:30:00 at 2012-3-11. + test(DT(2012, 3, 10, 2, 30), DT(2012, 3, 11, 3, 30)) + test(DT(2012, 3, 10, 3, 0), DT(2012, 3, 11, 3, 30)) + + test(DT(2012, 3, 11, 1, 59, 59), DT(2012, 3, 11, 3, 30)) + # No time between 2:00:00 and 3:00:00 at 2012-3-11. + test(DT(2012, 3, 11, 3, 0), DT(2012, 3, 12, 2, 30)) + test(DT(2012, 3, 11, 3, 30), DT(2012, 3, 12, 2, 30)) + + test(DT(2012, 11, 3, 2, 29, 59), DT(2012, 11, 3, 2, 30)) + test(DT(2012, 11, 3, 2, 30), DT(2012, 11, 4, 2, 30)) + test(DT(2012, 11, 3, 3, 0), DT(2012, 11, 4, 2, 30)) + + fh.close() + + fh = logging.handlers.TimedRotatingFileHandler( + self.fn, encoding="utf-8", when='MIDNIGHT', utc=False, + atTime=datetime.time(1, 30, 0)) + + test(DT(2012, 3, 11, 1, 29, 59), DT(2012, 3, 11, 1, 30)) + test(DT(2012, 3, 11, 1, 30), DT(2012, 3, 12, 1, 30)) + test(DT(2012, 3, 11, 1, 59, 59), DT(2012, 3, 12, 1, 30)) + # No time between 2:00:00 and 3:00:00 at 2012-3-11. + test(DT(2012, 3, 11, 3, 0), DT(2012, 3, 12, 1, 30)) + test(DT(2012, 3, 11, 3, 30), DT(2012, 3, 12, 1, 30)) + + # 1:00:00-2:00:00 is repeated twice at 2012-11-4. + test(DT(2012, 11, 4, 1, 0), DT(2012, 11, 4, 1, 30)) + test(DT(2012, 11, 4, 1, 29, 59), DT(2012, 11, 4, 1, 30)) + test(DT(2012, 11, 4, 1, 30), DT(2012, 11, 5, 1, 30)) + test(DT(2012, 11, 4, 1, 59, 59), DT(2012, 11, 5, 1, 30)) + # It is weird, but the rollover date jumps back from 2012-11-5 + # to 2012-11-4. + test(DT(2012, 11, 4, 1, 0, fold=1), DT(2012, 11, 4, 1, 30, fold=1)) + test(DT(2012, 11, 4, 1, 29, 59, fold=1), DT(2012, 11, 4, 1, 30, fold=1)) + test(DT(2012, 11, 4, 1, 30, fold=1), DT(2012, 11, 5, 1, 30)) + test(DT(2012, 11, 4, 1, 59, 59, fold=1), DT(2012, 11, 5, 1, 30)) + test(DT(2012, 11, 4, 2, 0), DT(2012, 11, 5, 1, 30)) + test(DT(2012, 11, 4, 2, 30), DT(2012, 11, 5, 1, 30)) + + fh.close() + + # Run with US-style DST rules: DST begins 2 a.m. on second Sunday in + # March (M3.2.0) and ends 2 a.m. on first Sunday in November (M11.1.0). + @support.run_with_tz('EST+05EDT,M3.2.0,M11.1.0') + def test_compute_rollover_W6_local(self): + # DST begins at 2012-3-11T02:00:00 and ends at 2012-11-4T02:00:00. + DT = datetime.datetime + def test(current, expected): + actual = fh.computeRollover(current.timestamp()) + diff = actual - expected.timestamp() + if diff: + self.assertEqual(diff, 0, datetime.timedelta(seconds=diff)) + + fh = logging.handlers.TimedRotatingFileHandler( + self.fn, encoding="utf-8", when='W6', utc=False) + + test(DT(2012, 3, 4, 23, 59, 59), DT(2012, 3, 5, 0, 0)) + test(DT(2012, 3, 5, 0, 0), DT(2012, 3, 12, 0, 0)) + test(DT(2012, 3, 5, 1, 0), DT(2012, 3, 12, 0, 0)) + + test(DT(2012, 10, 28, 23, 59, 59), DT(2012, 10, 29, 0, 0)) + test(DT(2012, 10, 29, 0, 0), DT(2012, 11, 5, 0, 0)) + test(DT(2012, 10, 29, 1, 0), DT(2012, 11, 5, 0, 0)) + + fh.close() + + fh = logging.handlers.TimedRotatingFileHandler( + self.fn, encoding="utf-8", when='W6', utc=False, + atTime=datetime.time(0, 0, 0)) + + test(DT(2012, 3, 10, 23, 59, 59), DT(2012, 3, 11, 0, 0)) + test(DT(2012, 3, 11, 0, 0), DT(2012, 3, 18, 0, 0)) + test(DT(2012, 3, 11, 1, 0), DT(2012, 3, 18, 0, 0)) + + test(DT(2012, 11, 3, 23, 59, 59), DT(2012, 11, 4, 0, 0)) + test(DT(2012, 11, 4, 0, 0), DT(2012, 11, 11, 0, 0)) + test(DT(2012, 11, 4, 1, 0), DT(2012, 11, 11, 0, 0)) + + fh.close() + + fh = logging.handlers.TimedRotatingFileHandler( + self.fn, encoding="utf-8", when='W6', utc=False, + atTime=datetime.time(12, 0, 0)) + + test(DT(2012, 3, 4, 11, 59, 59), DT(2012, 3, 4, 12, 0)) + test(DT(2012, 3, 4, 12, 0), DT(2012, 3, 11, 12, 0)) + test(DT(2012, 3, 4, 13, 0), DT(2012, 3, 11, 12, 0)) + + test(DT(2012, 10, 28, 11, 59, 59), DT(2012, 10, 28, 12, 0)) + test(DT(2012, 10, 28, 12, 0), DT(2012, 11, 4, 12, 0)) + test(DT(2012, 10, 28, 13, 0), DT(2012, 11, 4, 12, 0)) + + fh.close() + + fh = logging.handlers.TimedRotatingFileHandler( + self.fn, encoding="utf-8", when='W6', utc=False, + atTime=datetime.time(2, 0, 0)) + + test(DT(2012, 3, 4, 1, 59, 59), DT(2012, 3, 4, 2, 0)) + # 2:00:00 is the same as 3:00:00 at 2012-3-11. + test(DT(2012, 3, 4, 2, 0), DT(2012, 3, 11, 3, 0)) + test(DT(2012, 3, 4, 3, 0), DT(2012, 3, 11, 3, 0)) + + test(DT(2012, 3, 11, 1, 59, 59), DT(2012, 3, 11, 3, 0)) + # No time between 2:00:00 and 3:00:00 at 2012-3-11. + test(DT(2012, 3, 11, 3, 0), DT(2012, 3, 18, 2, 0)) + test(DT(2012, 3, 11, 4, 0), DT(2012, 3, 18, 2, 0)) + + test(DT(2012, 10, 28, 1, 59, 59), DT(2012, 10, 28, 2, 0)) + test(DT(2012, 10, 28, 2, 0), DT(2012, 11, 4, 2, 0)) + test(DT(2012, 10, 28, 3, 0), DT(2012, 11, 4, 2, 0)) + + # 1:00:00-2:00:00 is repeated twice at 2012-11-4. + test(DT(2012, 11, 4, 1, 59, 59), DT(2012, 11, 4, 2, 0)) + test(DT(2012, 11, 4, 1, 59, 59, fold=1), DT(2012, 11, 4, 2, 0)) + test(DT(2012, 11, 4, 2, 0), DT(2012, 11, 11, 2, 0)) + test(DT(2012, 11, 4, 3, 0), DT(2012, 11, 11, 2, 0)) + + fh.close() + + fh = logging.handlers.TimedRotatingFileHandler( + self.fn, encoding="utf-8", when='W6', utc=False, + atTime=datetime.time(2, 30, 0)) + + test(DT(2012, 3, 4, 2, 29, 59), DT(2012, 3, 4, 2, 30)) + # No time 2:30:00 at 2012-3-11. + test(DT(2012, 3, 4, 2, 30), DT(2012, 3, 11, 3, 30)) + test(DT(2012, 3, 4, 3, 0), DT(2012, 3, 11, 3, 30)) + + test(DT(2012, 3, 11, 1, 59, 59), DT(2012, 3, 11, 3, 30)) + # No time between 2:00:00 and 3:00:00 at 2012-3-11. + test(DT(2012, 3, 11, 3, 0), DT(2012, 3, 18, 2, 30)) + test(DT(2012, 3, 11, 3, 30), DT(2012, 3, 18, 2, 30)) + + test(DT(2012, 10, 28, 2, 29, 59), DT(2012, 10, 28, 2, 30)) + test(DT(2012, 10, 28, 2, 30), DT(2012, 11, 4, 2, 30)) + test(DT(2012, 10, 28, 3, 0), DT(2012, 11, 4, 2, 30)) + + fh.close() + + fh = logging.handlers.TimedRotatingFileHandler( + self.fn, encoding="utf-8", when='W6', utc=False, + atTime=datetime.time(1, 30, 0)) + + test(DT(2012, 3, 11, 1, 29, 59), DT(2012, 3, 11, 1, 30)) + test(DT(2012, 3, 11, 1, 30), DT(2012, 3, 18, 1, 30)) + test(DT(2012, 3, 11, 1, 59, 59), DT(2012, 3, 18, 1, 30)) + # No time between 2:00:00 and 3:00:00 at 2012-3-11. + test(DT(2012, 3, 11, 3, 0), DT(2012, 3, 18, 1, 30)) + test(DT(2012, 3, 11, 3, 30), DT(2012, 3, 18, 1, 30)) + + # 1:00:00-2:00:00 is repeated twice at 2012-11-4. + test(DT(2012, 11, 4, 1, 0), DT(2012, 11, 4, 1, 30)) + test(DT(2012, 11, 4, 1, 29, 59), DT(2012, 11, 4, 1, 30)) + test(DT(2012, 11, 4, 1, 30), DT(2012, 11, 11, 1, 30)) + test(DT(2012, 11, 4, 1, 59, 59), DT(2012, 11, 11, 1, 30)) + # It is weird, but the rollover date jumps back from 2012-11-11 + # to 2012-11-4. + test(DT(2012, 11, 4, 1, 0, fold=1), DT(2012, 11, 4, 1, 30, fold=1)) + test(DT(2012, 11, 4, 1, 29, 59, fold=1), DT(2012, 11, 4, 1, 30, fold=1)) + test(DT(2012, 11, 4, 1, 30, fold=1), DT(2012, 11, 11, 1, 30)) + test(DT(2012, 11, 4, 1, 59, 59, fold=1), DT(2012, 11, 11, 1, 30)) + test(DT(2012, 11, 4, 2, 0), DT(2012, 11, 11, 1, 30)) + test(DT(2012, 11, 4, 2, 30), DT(2012, 11, 11, 1, 30)) + + fh.close() + + # Run with US-style DST rules: DST begins 2 a.m. on second Sunday in + # March (M3.2.0) and ends 2 a.m. on first Sunday in November (M11.1.0). + @support.run_with_tz('EST+05EDT,M3.2.0,M11.1.0') + def test_compute_rollover_MIDNIGHT_local_interval(self): + # DST begins at 2012-3-11T02:00:00 and ends at 2012-11-4T02:00:00. + DT = datetime.datetime + def test(current, expected): + actual = fh.computeRollover(current.timestamp()) + diff = actual - expected.timestamp() + if diff: + self.assertEqual(diff, 0, datetime.timedelta(seconds=diff)) + + fh = logging.handlers.TimedRotatingFileHandler( + self.fn, encoding="utf-8", when='MIDNIGHT', utc=False, interval=3) + + test(DT(2012, 3, 8, 23, 59, 59), DT(2012, 3, 11, 0, 0)) + test(DT(2012, 3, 9, 0, 0), DT(2012, 3, 12, 0, 0)) + test(DT(2012, 3, 9, 1, 0), DT(2012, 3, 12, 0, 0)) + test(DT(2012, 3, 10, 23, 59, 59), DT(2012, 3, 13, 0, 0)) + test(DT(2012, 3, 11, 0, 0), DT(2012, 3, 14, 0, 0)) + test(DT(2012, 3, 11, 1, 0), DT(2012, 3, 14, 0, 0)) + + test(DT(2012, 11, 1, 23, 59, 59), DT(2012, 11, 4, 0, 0)) + test(DT(2012, 11, 2, 0, 0), DT(2012, 11, 5, 0, 0)) + test(DT(2012, 11, 2, 1, 0), DT(2012, 11, 5, 0, 0)) + test(DT(2012, 11, 3, 23, 59, 59), DT(2012, 11, 6, 0, 0)) + test(DT(2012, 11, 4, 0, 0), DT(2012, 11, 7, 0, 0)) + test(DT(2012, 11, 4, 1, 0), DT(2012, 11, 7, 0, 0)) + + fh.close() + + fh = logging.handlers.TimedRotatingFileHandler( + self.fn, encoding="utf-8", when='MIDNIGHT', utc=False, interval=3, + atTime=datetime.time(12, 0, 0)) + + test(DT(2012, 3, 8, 11, 59, 59), DT(2012, 3, 10, 12, 0)) + test(DT(2012, 3, 8, 12, 0), DT(2012, 3, 11, 12, 0)) + test(DT(2012, 3, 8, 13, 0), DT(2012, 3, 11, 12, 0)) + test(DT(2012, 3, 10, 11, 59, 59), DT(2012, 3, 12, 12, 0)) + test(DT(2012, 3, 10, 12, 0), DT(2012, 3, 13, 12, 0)) + test(DT(2012, 3, 10, 13, 0), DT(2012, 3, 13, 12, 0)) + + test(DT(2012, 11, 1, 11, 59, 59), DT(2012, 11, 3, 12, 0)) + test(DT(2012, 11, 1, 12, 0), DT(2012, 11, 4, 12, 0)) + test(DT(2012, 11, 1, 13, 0), DT(2012, 11, 4, 12, 0)) + test(DT(2012, 11, 3, 11, 59, 59), DT(2012, 11, 5, 12, 0)) + test(DT(2012, 11, 3, 12, 0), DT(2012, 11, 6, 12, 0)) + test(DT(2012, 11, 3, 13, 0), DT(2012, 11, 6, 12, 0)) + + fh.close() + + # Run with US-style DST rules: DST begins 2 a.m. on second Sunday in + # March (M3.2.0) and ends 2 a.m. on first Sunday in November (M11.1.0). + @support.run_with_tz('EST+05EDT,M3.2.0,M11.1.0') + def test_compute_rollover_W6_local_interval(self): + # DST begins at 2012-3-11T02:00:00 and ends at 2012-11-4T02:00:00. + DT = datetime.datetime + def test(current, expected): + actual = fh.computeRollover(current.timestamp()) + diff = actual - expected.timestamp() + if diff: + self.assertEqual(diff, 0, datetime.timedelta(seconds=diff)) + + fh = logging.handlers.TimedRotatingFileHandler( + self.fn, encoding="utf-8", when='W6', utc=False, interval=3) + + test(DT(2012, 2, 19, 23, 59, 59), DT(2012, 3, 5, 0, 0)) + test(DT(2012, 2, 20, 0, 0), DT(2012, 3, 12, 0, 0)) + test(DT(2012, 2, 20, 1, 0), DT(2012, 3, 12, 0, 0)) + test(DT(2012, 3, 4, 23, 59, 59), DT(2012, 3, 19, 0, 0)) + test(DT(2012, 3, 5, 0, 0), DT(2012, 3, 26, 0, 0)) + test(DT(2012, 3, 5, 1, 0), DT(2012, 3, 26, 0, 0)) + + test(DT(2012, 10, 14, 23, 59, 59), DT(2012, 10, 29, 0, 0)) + test(DT(2012, 10, 15, 0, 0), DT(2012, 11, 5, 0, 0)) + test(DT(2012, 10, 15, 1, 0), DT(2012, 11, 5, 0, 0)) + test(DT(2012, 10, 28, 23, 59, 59), DT(2012, 11, 12, 0, 0)) + test(DT(2012, 10, 29, 0, 0), DT(2012, 11, 19, 0, 0)) + test(DT(2012, 10, 29, 1, 0), DT(2012, 11, 19, 0, 0)) + + fh.close() + + fh = logging.handlers.TimedRotatingFileHandler( + self.fn, encoding="utf-8", when='W6', utc=False, interval=3, + atTime=datetime.time(0, 0, 0)) + + test(DT(2012, 2, 25, 23, 59, 59), DT(2012, 3, 11, 0, 0)) + test(DT(2012, 2, 26, 0, 0), DT(2012, 3, 18, 0, 0)) + test(DT(2012, 2, 26, 1, 0), DT(2012, 3, 18, 0, 0)) + test(DT(2012, 3, 10, 23, 59, 59), DT(2012, 3, 25, 0, 0)) + test(DT(2012, 3, 11, 0, 0), DT(2012, 4, 1, 0, 0)) + test(DT(2012, 3, 11, 1, 0), DT(2012, 4, 1, 0, 0)) + + test(DT(2012, 10, 20, 23, 59, 59), DT(2012, 11, 4, 0, 0)) + test(DT(2012, 10, 21, 0, 0), DT(2012, 11, 11, 0, 0)) + test(DT(2012, 10, 21, 1, 0), DT(2012, 11, 11, 0, 0)) + test(DT(2012, 11, 3, 23, 59, 59), DT(2012, 11, 18, 0, 0)) + test(DT(2012, 11, 4, 0, 0), DT(2012, 11, 25, 0, 0)) + test(DT(2012, 11, 4, 1, 0), DT(2012, 11, 25, 0, 0)) + + fh.close() + + fh = logging.handlers.TimedRotatingFileHandler( + self.fn, encoding="utf-8", when='W6', utc=False, interval=3, + atTime=datetime.time(12, 0, 0)) + + test(DT(2012, 2, 18, 11, 59, 59), DT(2012, 3, 4, 12, 0)) + test(DT(2012, 2, 19, 12, 0), DT(2012, 3, 11, 12, 0)) + test(DT(2012, 2, 19, 13, 0), DT(2012, 3, 11, 12, 0)) + test(DT(2012, 3, 4, 11, 59, 59), DT(2012, 3, 18, 12, 0)) + test(DT(2012, 3, 4, 12, 0), DT(2012, 3, 25, 12, 0)) + test(DT(2012, 3, 4, 13, 0), DT(2012, 3, 25, 12, 0)) + + test(DT(2012, 10, 14, 11, 59, 59), DT(2012, 10, 28, 12, 0)) + test(DT(2012, 10, 14, 12, 0), DT(2012, 11, 4, 12, 0)) + test(DT(2012, 10, 14, 13, 0), DT(2012, 11, 4, 12, 0)) + test(DT(2012, 10, 28, 11, 59, 59), DT(2012, 11, 11, 12, 0)) + test(DT(2012, 10, 28, 12, 0), DT(2012, 11, 18, 12, 0)) + test(DT(2012, 10, 28, 13, 0), DT(2012, 11, 18, 12, 0)) + + fh.close() + + +def secs(**kw): + return datetime.timedelta(**kw) // datetime.timedelta(seconds=1) + +for when, exp in (('S', 1), + ('M', 60), + ('H', 60 * 60), + ('D', 60 * 60 * 24), + ('MIDNIGHT', 60 * 60 * 24), + # current time (epoch start) is a Thursday, W0 means Monday + ('W0', secs(days=4, hours=24)), + ): + for interval in 1, 3: + def test_compute_rollover(self, when=when, interval=interval, exp=exp): + rh = logging.handlers.TimedRotatingFileHandler( + self.fn, encoding="utf-8", when=when, interval=interval, backupCount=0, utc=True) + currentTime = 0.0 + actual = rh.computeRollover(currentTime) + if when.startswith('W'): + exp += secs(days=7*(interval-1)) + else: + exp *= interval + if exp != actual: + # Failures occur on some systems for MIDNIGHT and W0. + # Print detailed calculation for MIDNIGHT so we can try to see + # what's going on + if when == 'MIDNIGHT': + try: + if rh.utc: + t = time.gmtime(currentTime) + else: + t = time.localtime(currentTime) + currentHour = t[3] + currentMinute = t[4] + currentSecond = t[5] + # r is the number of seconds left between now and midnight + r = logging.handlers._MIDNIGHT - ((currentHour * 60 + + currentMinute) * 60 + + currentSecond) + result = currentTime + r + print('t: %s (%s)' % (t, rh.utc), file=sys.stderr) + print('currentHour: %s' % currentHour, file=sys.stderr) + print('currentMinute: %s' % currentMinute, file=sys.stderr) + print('currentSecond: %s' % currentSecond, file=sys.stderr) + print('r: %s' % r, file=sys.stderr) + print('result: %s' % result, file=sys.stderr) + except Exception as e: + print('exception in diagnostic code: %s' % e, file=sys.stderr) + self.assertEqual(exp, actual) + rh.close() + name = "test_compute_rollover_%s" % when + if interval > 1: + name += "_interval" + test_compute_rollover.__name__ = name + setattr(TimedRotatingFileHandlerTest, name, test_compute_rollover) + + +@unittest.skipUnless(win32evtlog, 'win32evtlog/win32evtlogutil/pywintypes required for this test.') +class NTEventLogHandlerTest(BaseTest): + def test_basic(self): + logtype = 'Application' + elh = win32evtlog.OpenEventLog(None, logtype) + num_recs = win32evtlog.GetNumberOfEventLogRecords(elh) + + try: + h = logging.handlers.NTEventLogHandler('test_logging') + except pywintypes.error as e: + if e.winerror == 5: # access denied + raise unittest.SkipTest('Insufficient privileges to run test') + raise + + r = logging.makeLogRecord({'msg': 'Test Log Message'}) + h.handle(r) + h.close() + # Now see if the event is recorded + self.assertLess(num_recs, win32evtlog.GetNumberOfEventLogRecords(elh)) + flags = win32evtlog.EVENTLOG_BACKWARDS_READ | \ + win32evtlog.EVENTLOG_SEQUENTIAL_READ + found = False + GO_BACK = 100 + events = win32evtlog.ReadEventLog(elh, flags, GO_BACK) + for e in events: + if e.SourceName != 'test_logging': + continue + msg = win32evtlogutil.SafeFormatMessage(e, logtype) + if msg != 'Test Log Message\r\n': + continue + found = True + break + msg = 'Record not found in event log, went back %d records' % GO_BACK + self.assertTrue(found, msg=msg) + + +class MiscTestCase(unittest.TestCase): + def test__all__(self): + not_exported = { + 'logThreads', 'logMultiprocessing', 'logProcesses', 'currentframe', + 'PercentStyle', 'StrFormatStyle', 'StringTemplateStyle', + 'Filterer', 'PlaceHolder', 'Manager', 'RootLogger', 'root', + 'threading', 'logAsyncioTasks'} + support.check__all__(self, logging, not_exported=not_exported) + + +# Set the locale to the platform-dependent default. I have no idea +# why the test does this, but in any case we save the current locale +# first and restore it at the end. +def setUpModule(): + unittest.enterModuleContext(support.run_with_locale('LC_ALL', '')) + + +if __name__ == "__main__": + unittest.main() diff --git a/Lib/test/test_smtplib.py b/Lib/test/test_smtplib.py new file mode 100644 index 0000000000..b5508a833e --- /dev/null +++ b/Lib/test/test_smtplib.py @@ -0,0 +1,1560 @@ +import base64 +import email.mime.text +from email.message import EmailMessage +from email.base64mime import body_encode as encode_base64 +import email.utils +import hashlib +import hmac +import socket +import smtplib +import io +import re +import sys +import time +import select +import errno +import textwrap +import threading + +import unittest +from test import support, mock_socket +from test.support import hashlib_helper +from test.support import socket_helper +from test.support import threading_helper +from test.support import asyncore +from test.support import smtpd +from unittest.mock import Mock + + +support.requires_working_socket(module=True) + +HOST = socket_helper.HOST + +if sys.platform == 'darwin': + # select.poll returns a select.POLLHUP at the end of the tests + # on darwin, so just ignore it + def handle_expt(self): + pass + smtpd.SMTPChannel.handle_expt = handle_expt + + +def server(evt, buf, serv): + serv.listen() + evt.set() + try: + conn, addr = serv.accept() + except TimeoutError: + pass + else: + n = 500 + while buf and n > 0: + r, w, e = select.select([], [conn], []) + if w: + sent = conn.send(buf) + buf = buf[sent:] + + n -= 1 + + conn.close() + finally: + serv.close() + evt.set() + +class GeneralTests: + + def setUp(self): + smtplib.socket = mock_socket + self.port = 25 + + def tearDown(self): + smtplib.socket = socket + + # This method is no longer used but is retained for backward compatibility, + # so test to make sure it still works. + def testQuoteData(self): + teststr = "abc\n.jkl\rfoo\r\n..blue" + expected = "abc\r\n..jkl\r\nfoo\r\n...blue" + self.assertEqual(expected, smtplib.quotedata(teststr)) + + def testBasic1(self): + mock_socket.reply_with(b"220 Hola mundo") + # connects + client = self.client(HOST, self.port) + client.close() + + def testSourceAddress(self): + mock_socket.reply_with(b"220 Hola mundo") + # connects + client = self.client(HOST, self.port, + source_address=('127.0.0.1',19876)) + self.assertEqual(client.source_address, ('127.0.0.1', 19876)) + client.close() + + def testBasic2(self): + mock_socket.reply_with(b"220 Hola mundo") + # connects, include port in host name + client = self.client("%s:%s" % (HOST, self.port)) + client.close() + + def testLocalHostName(self): + mock_socket.reply_with(b"220 Hola mundo") + # check that supplied local_hostname is used + client = self.client(HOST, self.port, local_hostname="testhost") + self.assertEqual(client.local_hostname, "testhost") + client.close() + + def testTimeoutDefault(self): + mock_socket.reply_with(b"220 Hola mundo") + self.assertIsNone(mock_socket.getdefaulttimeout()) + mock_socket.setdefaulttimeout(30) + self.assertEqual(mock_socket.getdefaulttimeout(), 30) + try: + client = self.client(HOST, self.port) + finally: + mock_socket.setdefaulttimeout(None) + self.assertEqual(client.sock.gettimeout(), 30) + client.close() + + def testTimeoutNone(self): + mock_socket.reply_with(b"220 Hola mundo") + self.assertIsNone(socket.getdefaulttimeout()) + socket.setdefaulttimeout(30) + try: + client = self.client(HOST, self.port, timeout=None) + finally: + socket.setdefaulttimeout(None) + self.assertIsNone(client.sock.gettimeout()) + client.close() + + def testTimeoutZero(self): + mock_socket.reply_with(b"220 Hola mundo") + with self.assertRaises(ValueError): + self.client(HOST, self.port, timeout=0) + + def testTimeoutValue(self): + mock_socket.reply_with(b"220 Hola mundo") + client = self.client(HOST, self.port, timeout=30) + self.assertEqual(client.sock.gettimeout(), 30) + client.close() + + def test_debuglevel(self): + mock_socket.reply_with(b"220 Hello world") + client = self.client() + client.set_debuglevel(1) + with support.captured_stderr() as stderr: + client.connect(HOST, self.port) + client.close() + expected = re.compile(r"^connect:", re.MULTILINE) + self.assertRegex(stderr.getvalue(), expected) + + def test_debuglevel_2(self): + mock_socket.reply_with(b"220 Hello world") + client = self.client() + client.set_debuglevel(2) + with support.captured_stderr() as stderr: + client.connect(HOST, self.port) + client.close() + expected = re.compile(r"^\d{2}:\d{2}:\d{2}\.\d{6} connect: ", + re.MULTILINE) + self.assertRegex(stderr.getvalue(), expected) + + +class SMTPGeneralTests(GeneralTests, unittest.TestCase): + + client = smtplib.SMTP + + +class LMTPGeneralTests(GeneralTests, unittest.TestCase): + + client = smtplib.LMTP + + @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), "test requires Unix domain socket") + def testUnixDomainSocketTimeoutDefault(self): + local_host = '/some/local/lmtp/delivery/program' + mock_socket.reply_with(b"220 Hello world") + try: + client = self.client(local_host, self.port) + finally: + mock_socket.setdefaulttimeout(None) + self.assertIsNone(client.sock.gettimeout()) + client.close() + + def testTimeoutZero(self): + super().testTimeoutZero() + local_host = '/some/local/lmtp/delivery/program' + with self.assertRaises(ValueError): + self.client(local_host, timeout=0) + +# Test server thread using the specified SMTP server class +def debugging_server(serv, serv_evt, client_evt): + serv_evt.set() + + try: + if hasattr(select, 'poll'): + poll_fun = asyncore.poll2 + else: + poll_fun = asyncore.poll + + n = 1000 + while asyncore.socket_map and n > 0: + poll_fun(0.01, asyncore.socket_map) + + # when the client conversation is finished, it will + # set client_evt, and it's then ok to kill the server + if client_evt.is_set(): + serv.close() + break + + n -= 1 + + except TimeoutError: + pass + finally: + if not client_evt.is_set(): + # allow some time for the client to read the result + time.sleep(0.5) + serv.close() + asyncore.close_all() + serv_evt.set() + +MSG_BEGIN = '---------- MESSAGE FOLLOWS ----------\n' +MSG_END = '------------ END MESSAGE ------------\n' + +# NOTE: Some SMTP objects in the tests below are created with a non-default +# local_hostname argument to the constructor, since (on some systems) the FQDN +# lookup caused by the default local_hostname sometimes takes so long that the +# test server times out, causing the test to fail. + +# Test behavior of smtpd.DebuggingServer +class DebuggingServerTests(unittest.TestCase): + + maxDiff = None + + def setUp(self): + self.thread_key = threading_helper.threading_setup() + self.real_getfqdn = socket.getfqdn + socket.getfqdn = mock_socket.getfqdn + # temporarily replace sys.stdout to capture DebuggingServer output + self.old_stdout = sys.stdout + self.output = io.StringIO() + sys.stdout = self.output + + self.serv_evt = threading.Event() + self.client_evt = threading.Event() + # Capture SMTPChannel debug output + self.old_DEBUGSTREAM = smtpd.DEBUGSTREAM + smtpd.DEBUGSTREAM = io.StringIO() + # Pick a random unused port by passing 0 for the port number + self.serv = smtpd.DebuggingServer((HOST, 0), ('nowhere', -1), + decode_data=True) + # Keep a note of what server host and port were assigned + self.host, self.port = self.serv.socket.getsockname()[:2] + serv_args = (self.serv, self.serv_evt, self.client_evt) + self.thread = threading.Thread(target=debugging_server, args=serv_args) + self.thread.start() + + # wait until server thread has assigned a port number + self.serv_evt.wait() + self.serv_evt.clear() + + def tearDown(self): + socket.getfqdn = self.real_getfqdn + # indicate that the client is finished + self.client_evt.set() + # wait for the server thread to terminate + self.serv_evt.wait() + threading_helper.join_thread(self.thread) + # restore sys.stdout + sys.stdout = self.old_stdout + # restore DEBUGSTREAM + smtpd.DEBUGSTREAM.close() + smtpd.DEBUGSTREAM = self.old_DEBUGSTREAM + del self.thread + self.doCleanups() + threading_helper.threading_cleanup(*self.thread_key) + + def get_output_without_xpeer(self): + test_output = self.output.getvalue() + return re.sub(r'(.*?)^X-Peer:\s*\S+\n(.*)', r'\1\2', + test_output, flags=re.MULTILINE|re.DOTALL) + + def testBasic(self): + # connect + smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) + smtp.quit() + + def testSourceAddress(self): + # connect + src_port = socket_helper.find_unused_port() + try: + smtp = smtplib.SMTP(self.host, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT, + source_address=(self.host, src_port)) + self.addCleanup(smtp.close) + self.assertEqual(smtp.source_address, (self.host, src_port)) + self.assertEqual(smtp.local_hostname, 'localhost') + smtp.quit() + except OSError as e: + if e.errno == errno.EADDRINUSE: + self.skipTest("couldn't bind to source port %d" % src_port) + raise + + def testNOOP(self): + smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) + self.addCleanup(smtp.close) + expected = (250, b'OK') + self.assertEqual(smtp.noop(), expected) + smtp.quit() + + def testRSET(self): + smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) + self.addCleanup(smtp.close) + expected = (250, b'OK') + self.assertEqual(smtp.rset(), expected) + smtp.quit() + + def testELHO(self): + # EHLO isn't implemented in DebuggingServer + smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) + self.addCleanup(smtp.close) + expected = (250, b'\nSIZE 33554432\nHELP') + self.assertEqual(smtp.ehlo(), expected) + smtp.quit() + + def testEXPNNotImplemented(self): + # EXPN isn't implemented in DebuggingServer + smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) + self.addCleanup(smtp.close) + expected = (502, b'EXPN not implemented') + smtp.putcmd('EXPN') + self.assertEqual(smtp.getreply(), expected) + smtp.quit() + + def test_issue43124_putcmd_escapes_newline(self): + # see: https://bugs.python.org/issue43124 + smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) + self.addCleanup(smtp.close) + with self.assertRaises(ValueError) as exc: + smtp.putcmd('helo\nX-INJECTED') + self.assertIn("prohibited newline characters", str(exc.exception)) + smtp.quit() + + def testVRFY(self): + smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) + self.addCleanup(smtp.close) + expected = (252, b'Cannot VRFY user, but will accept message ' + \ + b'and attempt delivery') + self.assertEqual(smtp.vrfy('nobody@nowhere.com'), expected) + self.assertEqual(smtp.verify('nobody@nowhere.com'), expected) + smtp.quit() + + def testSecondHELO(self): + # check that a second HELO returns a message that it's a duplicate + # (this behavior is specific to smtpd.SMTPChannel) + smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) + self.addCleanup(smtp.close) + smtp.helo() + expected = (503, b'Duplicate HELO/EHLO') + self.assertEqual(smtp.helo(), expected) + smtp.quit() + + def testHELP(self): + smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) + self.addCleanup(smtp.close) + self.assertEqual(smtp.help(), b'Supported commands: EHLO HELO MAIL ' + \ + b'RCPT DATA RSET NOOP QUIT VRFY') + smtp.quit() + + def testSend(self): + # connect and send mail + m = 'A test message' + smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) + self.addCleanup(smtp.close) + smtp.sendmail('John', 'Sally', m) + # XXX(nnorwitz): this test is flaky and dies with a bad file descriptor + # in asyncore. This sleep might help, but should really be fixed + # properly by using an Event variable. + time.sleep(0.01) + smtp.quit() + + self.client_evt.set() + self.serv_evt.wait() + self.output.flush() + mexpect = '%s%s\n%s' % (MSG_BEGIN, m, MSG_END) + self.assertEqual(self.output.getvalue(), mexpect) + + def testSendBinary(self): + m = b'A test message' + smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) + self.addCleanup(smtp.close) + smtp.sendmail('John', 'Sally', m) + # XXX (see comment in testSend) + time.sleep(0.01) + smtp.quit() + + self.client_evt.set() + self.serv_evt.wait() + self.output.flush() + mexpect = '%s%s\n%s' % (MSG_BEGIN, m.decode('ascii'), MSG_END) + self.assertEqual(self.output.getvalue(), mexpect) + + def testSendNeedingDotQuote(self): + # Issue 12283 + m = '.A test\n.mes.sage.' + smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) + self.addCleanup(smtp.close) + smtp.sendmail('John', 'Sally', m) + # XXX (see comment in testSend) + time.sleep(0.01) + smtp.quit() + + self.client_evt.set() + self.serv_evt.wait() + self.output.flush() + mexpect = '%s%s\n%s' % (MSG_BEGIN, m, MSG_END) + self.assertEqual(self.output.getvalue(), mexpect) + + def test_issue43124_escape_localhostname(self): + # see: https://bugs.python.org/issue43124 + # connect and send mail + m = 'wazzuuup\nlinetwo' + smtp = smtplib.SMTP(HOST, self.port, local_hostname='hi\nX-INJECTED', + timeout=support.LOOPBACK_TIMEOUT) + self.addCleanup(smtp.close) + with self.assertRaises(ValueError) as exc: + smtp.sendmail("hi@me.com", "you@me.com", m) + self.assertIn( + "prohibited newline characters: ehlo hi\\nX-INJECTED", + str(exc.exception), + ) + # XXX (see comment in testSend) + time.sleep(0.01) + smtp.quit() + + debugout = smtpd.DEBUGSTREAM.getvalue() + self.assertNotIn("X-INJECTED", debugout) + + def test_issue43124_escape_options(self): + # see: https://bugs.python.org/issue43124 + # connect and send mail + m = 'wazzuuup\nlinetwo' + smtp = smtplib.SMTP( + HOST, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) + + self.addCleanup(smtp.close) + smtp.sendmail("hi@me.com", "you@me.com", m) + with self.assertRaises(ValueError) as exc: + smtp.mail("hi@me.com", ["X-OPTION\nX-INJECTED-1", "X-OPTION2\nX-INJECTED-2"]) + msg = str(exc.exception) + self.assertIn("prohibited newline characters", msg) + self.assertIn("X-OPTION\\nX-INJECTED-1 X-OPTION2\\nX-INJECTED-2", msg) + # XXX (see comment in testSend) + time.sleep(0.01) + smtp.quit() + + debugout = smtpd.DEBUGSTREAM.getvalue() + self.assertNotIn("X-OPTION", debugout) + self.assertNotIn("X-OPTION2", debugout) + self.assertNotIn("X-INJECTED-1", debugout) + self.assertNotIn("X-INJECTED-2", debugout) + + def testSendNullSender(self): + m = 'A test message' + smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) + self.addCleanup(smtp.close) + smtp.sendmail('<>', 'Sally', m) + # XXX (see comment in testSend) + time.sleep(0.01) + smtp.quit() + + self.client_evt.set() + self.serv_evt.wait() + self.output.flush() + mexpect = '%s%s\n%s' % (MSG_BEGIN, m, MSG_END) + self.assertEqual(self.output.getvalue(), mexpect) + debugout = smtpd.DEBUGSTREAM.getvalue() + sender = re.compile("^sender: <>$", re.MULTILINE) + self.assertRegex(debugout, sender) + + def testSendMessage(self): + m = email.mime.text.MIMEText('A test message') + smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) + self.addCleanup(smtp.close) + smtp.send_message(m, from_addr='John', to_addrs='Sally') + # XXX (see comment in testSend) + time.sleep(0.01) + smtp.quit() + + self.client_evt.set() + self.serv_evt.wait() + self.output.flush() + # Remove the X-Peer header that DebuggingServer adds as figuring out + # exactly what IP address format is put there is not easy (and + # irrelevant to our test). Typically 127.0.0.1 or ::1, but it is + # not always the same as socket.gethostbyname(HOST). :( + test_output = self.get_output_without_xpeer() + del m['X-Peer'] + mexpect = '%s%s\n%s' % (MSG_BEGIN, m.as_string(), MSG_END) + self.assertEqual(test_output, mexpect) + + def testSendMessageWithAddresses(self): + m = email.mime.text.MIMEText('A test message') + m['From'] = 'foo@bar.com' + m['To'] = 'John' + m['CC'] = 'Sally, Fred' + m['Bcc'] = 'John Root , "Dinsdale" ' + smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) + self.addCleanup(smtp.close) + smtp.send_message(m) + # XXX (see comment in testSend) + time.sleep(0.01) + smtp.quit() + # make sure the Bcc header is still in the message. + self.assertEqual(m['Bcc'], 'John Root , "Dinsdale" ' + '') + + self.client_evt.set() + self.serv_evt.wait() + self.output.flush() + # Remove the X-Peer header that DebuggingServer adds. + test_output = self.get_output_without_xpeer() + del m['X-Peer'] + # The Bcc header should not be transmitted. + del m['Bcc'] + mexpect = '%s%s\n%s' % (MSG_BEGIN, m.as_string(), MSG_END) + self.assertEqual(test_output, mexpect) + debugout = smtpd.DEBUGSTREAM.getvalue() + sender = re.compile("^sender: foo@bar.com$", re.MULTILINE) + self.assertRegex(debugout, sender) + for addr in ('John', 'Sally', 'Fred', 'root@localhost', + 'warped@silly.walks.com'): + to_addr = re.compile(r"^recips: .*'{}'.*$".format(addr), + re.MULTILINE) + self.assertRegex(debugout, to_addr) + + def testSendMessageWithSomeAddresses(self): + # Make sure nothing breaks if not all of the three 'to' headers exist + m = email.mime.text.MIMEText('A test message') + m['From'] = 'foo@bar.com' + m['To'] = 'John, Dinsdale' + smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) + self.addCleanup(smtp.close) + smtp.send_message(m) + # XXX (see comment in testSend) + time.sleep(0.01) + smtp.quit() + + self.client_evt.set() + self.serv_evt.wait() + self.output.flush() + # Remove the X-Peer header that DebuggingServer adds. + test_output = self.get_output_without_xpeer() + del m['X-Peer'] + mexpect = '%s%s\n%s' % (MSG_BEGIN, m.as_string(), MSG_END) + self.assertEqual(test_output, mexpect) + debugout = smtpd.DEBUGSTREAM.getvalue() + sender = re.compile("^sender: foo@bar.com$", re.MULTILINE) + self.assertRegex(debugout, sender) + for addr in ('John', 'Dinsdale'): + to_addr = re.compile(r"^recips: .*'{}'.*$".format(addr), + re.MULTILINE) + self.assertRegex(debugout, to_addr) + + def testSendMessageWithSpecifiedAddresses(self): + # Make sure addresses specified in call override those in message. + m = email.mime.text.MIMEText('A test message') + m['From'] = 'foo@bar.com' + m['To'] = 'John, Dinsdale' + smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) + self.addCleanup(smtp.close) + smtp.send_message(m, from_addr='joe@example.com', to_addrs='foo@example.net') + # XXX (see comment in testSend) + time.sleep(0.01) + smtp.quit() + + self.client_evt.set() + self.serv_evt.wait() + self.output.flush() + # Remove the X-Peer header that DebuggingServer adds. + test_output = self.get_output_without_xpeer() + del m['X-Peer'] + mexpect = '%s%s\n%s' % (MSG_BEGIN, m.as_string(), MSG_END) + self.assertEqual(test_output, mexpect) + debugout = smtpd.DEBUGSTREAM.getvalue() + sender = re.compile("^sender: joe@example.com$", re.MULTILINE) + self.assertRegex(debugout, sender) + for addr in ('John', 'Dinsdale'): + to_addr = re.compile(r"^recips: .*'{}'.*$".format(addr), + re.MULTILINE) + self.assertNotRegex(debugout, to_addr) + recip = re.compile(r"^recips: .*'foo@example.net'.*$", re.MULTILINE) + self.assertRegex(debugout, recip) + + def testSendMessageWithMultipleFrom(self): + # Sender overrides To + m = email.mime.text.MIMEText('A test message') + m['From'] = 'Bernard, Bianca' + m['Sender'] = 'the_rescuers@Rescue-Aid-Society.com' + m['To'] = 'John, Dinsdale' + smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) + self.addCleanup(smtp.close) + smtp.send_message(m) + # XXX (see comment in testSend) + time.sleep(0.01) + smtp.quit() + + self.client_evt.set() + self.serv_evt.wait() + self.output.flush() + # Remove the X-Peer header that DebuggingServer adds. + test_output = self.get_output_without_xpeer() + del m['X-Peer'] + mexpect = '%s%s\n%s' % (MSG_BEGIN, m.as_string(), MSG_END) + self.assertEqual(test_output, mexpect) + debugout = smtpd.DEBUGSTREAM.getvalue() + sender = re.compile("^sender: the_rescuers@Rescue-Aid-Society.com$", re.MULTILINE) + self.assertRegex(debugout, sender) + for addr in ('John', 'Dinsdale'): + to_addr = re.compile(r"^recips: .*'{}'.*$".format(addr), + re.MULTILINE) + self.assertRegex(debugout, to_addr) + + def testSendMessageResent(self): + m = email.mime.text.MIMEText('A test message') + m['From'] = 'foo@bar.com' + m['To'] = 'John' + m['CC'] = 'Sally, Fred' + m['Bcc'] = 'John Root , "Dinsdale" ' + m['Resent-Date'] = 'Thu, 1 Jan 1970 17:42:00 +0000' + m['Resent-From'] = 'holy@grail.net' + m['Resent-To'] = 'Martha , Jeff' + m['Resent-Bcc'] = 'doe@losthope.net' + smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) + self.addCleanup(smtp.close) + smtp.send_message(m) + # XXX (see comment in testSend) + time.sleep(0.01) + smtp.quit() + + self.client_evt.set() + self.serv_evt.wait() + self.output.flush() + # The Resent-Bcc headers are deleted before serialization. + del m['Bcc'] + del m['Resent-Bcc'] + # Remove the X-Peer header that DebuggingServer adds. + test_output = self.get_output_without_xpeer() + del m['X-Peer'] + mexpect = '%s%s\n%s' % (MSG_BEGIN, m.as_string(), MSG_END) + self.assertEqual(test_output, mexpect) + debugout = smtpd.DEBUGSTREAM.getvalue() + sender = re.compile("^sender: holy@grail.net$", re.MULTILINE) + self.assertRegex(debugout, sender) + for addr in ('my_mom@great.cooker.com', 'Jeff', 'doe@losthope.net'): + to_addr = re.compile(r"^recips: .*'{}'.*$".format(addr), + re.MULTILINE) + self.assertRegex(debugout, to_addr) + + def testSendMessageMultipleResentRaises(self): + m = email.mime.text.MIMEText('A test message') + m['From'] = 'foo@bar.com' + m['To'] = 'John' + m['CC'] = 'Sally, Fred' + m['Bcc'] = 'John Root , "Dinsdale" ' + m['Resent-Date'] = 'Thu, 1 Jan 1970 17:42:00 +0000' + m['Resent-From'] = 'holy@grail.net' + m['Resent-To'] = 'Martha , Jeff' + m['Resent-Bcc'] = 'doe@losthope.net' + m['Resent-Date'] = 'Thu, 2 Jan 1970 17:42:00 +0000' + m['Resent-To'] = 'holy@grail.net' + m['Resent-From'] = 'Martha , Jeff' + smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) + self.addCleanup(smtp.close) + with self.assertRaises(ValueError): + smtp.send_message(m) + smtp.close() + +class NonConnectingTests(unittest.TestCase): + + def testNotConnected(self): + # Test various operations on an unconnected SMTP object that + # should raise exceptions (at present the attempt in SMTP.send + # to reference the nonexistent 'sock' attribute of the SMTP object + # causes an AttributeError) + smtp = smtplib.SMTP() + self.assertRaises(smtplib.SMTPServerDisconnected, smtp.ehlo) + self.assertRaises(smtplib.SMTPServerDisconnected, + smtp.send, 'test msg') + + def testNonnumericPort(self): + # check that non-numeric port raises OSError + self.assertRaises(OSError, smtplib.SMTP, + "localhost", "bogus") + self.assertRaises(OSError, smtplib.SMTP, + "localhost:bogus") + + def testSockAttributeExists(self): + # check that sock attribute is present outside of a connect() call + # (regression test, the previous behavior raised an + # AttributeError: 'SMTP' object has no attribute 'sock') + with smtplib.SMTP() as smtp: + self.assertIsNone(smtp.sock) + + +class DefaultArgumentsTests(unittest.TestCase): + + def setUp(self): + self.msg = EmailMessage() + self.msg['From'] = 'Páolo ' + self.smtp = smtplib.SMTP() + self.smtp.ehlo = Mock(return_value=(200, 'OK')) + self.smtp.has_extn, self.smtp.sendmail = Mock(), Mock() + + def testSendMessage(self): + expected_mail_options = ('SMTPUTF8', 'BODY=8BITMIME') + self.smtp.send_message(self.msg) + self.smtp.send_message(self.msg) + self.assertEqual(self.smtp.sendmail.call_args_list[0][0][3], + expected_mail_options) + self.assertEqual(self.smtp.sendmail.call_args_list[1][0][3], + expected_mail_options) + + def testSendMessageWithMailOptions(self): + mail_options = ['STARTTLS'] + expected_mail_options = ('STARTTLS', 'SMTPUTF8', 'BODY=8BITMIME') + self.smtp.send_message(self.msg, None, None, mail_options) + self.assertEqual(mail_options, ['STARTTLS']) + self.assertEqual(self.smtp.sendmail.call_args_list[0][0][3], + expected_mail_options) + + +# test response of client to a non-successful HELO message +class BadHELOServerTests(unittest.TestCase): + + def setUp(self): + smtplib.socket = mock_socket + mock_socket.reply_with(b"199 no hello for you!") + self.old_stdout = sys.stdout + self.output = io.StringIO() + sys.stdout = self.output + self.port = 25 + + def tearDown(self): + smtplib.socket = socket + sys.stdout = self.old_stdout + + def testFailingHELO(self): + self.assertRaises(smtplib.SMTPConnectError, smtplib.SMTP, + HOST, self.port, 'localhost', 3) + + +class TooLongLineTests(unittest.TestCase): + respdata = b'250 OK' + (b'.' * smtplib._MAXLINE * 2) + b'\n' + + def setUp(self): + self.thread_key = threading_helper.threading_setup() + self.old_stdout = sys.stdout + self.output = io.StringIO() + sys.stdout = self.output + + self.evt = threading.Event() + self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.sock.settimeout(15) + self.port = socket_helper.bind_port(self.sock) + servargs = (self.evt, self.respdata, self.sock) + self.thread = threading.Thread(target=server, args=servargs) + self.thread.start() + self.evt.wait() + self.evt.clear() + + def tearDown(self): + self.evt.wait() + sys.stdout = self.old_stdout + threading_helper.join_thread(self.thread) + del self.thread + self.doCleanups() + threading_helper.threading_cleanup(*self.thread_key) + + def testLineTooLong(self): + self.assertRaises(smtplib.SMTPResponseException, smtplib.SMTP, + HOST, self.port, 'localhost', 3) + + +sim_users = {'Mr.A@somewhere.com':'John A', + 'Ms.B@xn--fo-fka.com':'Sally B', + 'Mrs.C@somewhereesle.com':'Ruth C', + } + +sim_auth = ('Mr.A@somewhere.com', 'somepassword') +sim_cram_md5_challenge = ('PENCeUxFREJoU0NnbmhNWitOMjNGNn' + 'dAZWx3b29kLmlubm9zb2Z0LmNvbT4=') +sim_lists = {'list-1':['Mr.A@somewhere.com','Mrs.C@somewhereesle.com'], + 'list-2':['Ms.B@xn--fo-fka.com',], + } + +# Simulated SMTP channel & server +class ResponseException(Exception): pass +class SimSMTPChannel(smtpd.SMTPChannel): + + quit_response = None + mail_response = None + rcpt_response = None + data_response = None + rcpt_count = 0 + rset_count = 0 + disconnect = 0 + AUTH = 99 # Add protocol state to enable auth testing. + authenticated_user = None + + def __init__(self, extra_features, *args, **kw): + self._extrafeatures = ''.join( + [ "250-{0}\r\n".format(x) for x in extra_features ]) + super(SimSMTPChannel, self).__init__(*args, **kw) + + # AUTH related stuff. It would be nice if support for this were in smtpd. + def found_terminator(self): + if self.smtp_state == self.AUTH: + line = self._emptystring.join(self.received_lines) + print('Data:', repr(line), file=smtpd.DEBUGSTREAM) + self.received_lines = [] + try: + self.auth_object(line) + except ResponseException as e: + self.smtp_state = self.COMMAND + self.push('%s %s' % (e.smtp_code, e.smtp_error)) + return + super().found_terminator() + + + def smtp_AUTH(self, arg): + if not self.seen_greeting: + self.push('503 Error: send EHLO first') + return + if not self.extended_smtp or 'AUTH' not in self._extrafeatures: + self.push('500 Error: command "AUTH" not recognized') + return + if self.authenticated_user is not None: + self.push( + '503 Bad sequence of commands: already authenticated') + return + args = arg.split() + if len(args) not in [1, 2]: + self.push('501 Syntax: AUTH [initial-response]') + return + auth_object_name = '_auth_%s' % args[0].lower().replace('-', '_') + try: + self.auth_object = getattr(self, auth_object_name) + except AttributeError: + self.push('504 Command parameter not implemented: unsupported ' + ' authentication mechanism {!r}'.format(auth_object_name)) + return + self.smtp_state = self.AUTH + self.auth_object(args[1] if len(args) == 2 else None) + + def _authenticated(self, user, valid): + if valid: + self.authenticated_user = user + self.push('235 Authentication Succeeded') + else: + self.push('535 Authentication credentials invalid') + self.smtp_state = self.COMMAND + + def _decode_base64(self, string): + return base64.decodebytes(string.encode('ascii')).decode('utf-8') + + def _auth_plain(self, arg=None): + if arg is None: + self.push('334 ') + else: + logpass = self._decode_base64(arg) + try: + *_, user, password = logpass.split('\0') + except ValueError as e: + self.push('535 Splitting response {!r} into user and password' + ' failed: {}'.format(logpass, e)) + return + self._authenticated(user, password == sim_auth[1]) + + def _auth_login(self, arg=None): + if arg is None: + # base64 encoded 'Username:' + self.push('334 VXNlcm5hbWU6') + elif not hasattr(self, '_auth_login_user'): + self._auth_login_user = self._decode_base64(arg) + # base64 encoded 'Password:' + self.push('334 UGFzc3dvcmQ6') + else: + password = self._decode_base64(arg) + self._authenticated(self._auth_login_user, password == sim_auth[1]) + del self._auth_login_user + + def _auth_buggy(self, arg=None): + # This AUTH mechanism will 'trap' client in a neverending 334 + # base64 encoded 'BuGgYbUgGy' + self.push('334 QnVHZ1liVWdHeQ==') + + def _auth_cram_md5(self, arg=None): + if arg is None: + self.push('334 {}'.format(sim_cram_md5_challenge)) + else: + logpass = self._decode_base64(arg) + try: + user, hashed_pass = logpass.split() + except ValueError as e: + self.push('535 Splitting response {!r} into user and password ' + 'failed: {}'.format(logpass, e)) + return False + valid_hashed_pass = hmac.HMAC( + sim_auth[1].encode('ascii'), + self._decode_base64(sim_cram_md5_challenge).encode('ascii'), + 'md5').hexdigest() + self._authenticated(user, hashed_pass == valid_hashed_pass) + # end AUTH related stuff. + + def smtp_EHLO(self, arg): + resp = ('250-testhost\r\n' + '250-EXPN\r\n' + '250-SIZE 20000000\r\n' + '250-STARTTLS\r\n' + '250-DELIVERBY\r\n') + resp = resp + self._extrafeatures + '250 HELP' + self.push(resp) + self.seen_greeting = arg + self.extended_smtp = True + + def smtp_VRFY(self, arg): + # For max compatibility smtplib should be sending the raw address. + if arg in sim_users: + self.push('250 %s %s' % (sim_users[arg], smtplib.quoteaddr(arg))) + else: + self.push('550 No such user: %s' % arg) + + def smtp_EXPN(self, arg): + list_name = arg.lower() + if list_name in sim_lists: + user_list = sim_lists[list_name] + for n, user_email in enumerate(user_list): + quoted_addr = smtplib.quoteaddr(user_email) + if n < len(user_list) - 1: + self.push('250-%s %s' % (sim_users[user_email], quoted_addr)) + else: + self.push('250 %s %s' % (sim_users[user_email], quoted_addr)) + else: + self.push('550 No access for you!') + + def smtp_QUIT(self, arg): + if self.quit_response is None: + super(SimSMTPChannel, self).smtp_QUIT(arg) + else: + self.push(self.quit_response) + self.close_when_done() + + def smtp_MAIL(self, arg): + if self.mail_response is None: + super().smtp_MAIL(arg) + else: + self.push(self.mail_response) + if self.disconnect: + self.close_when_done() + + def smtp_RCPT(self, arg): + if self.rcpt_response is None: + super().smtp_RCPT(arg) + return + self.rcpt_count += 1 + self.push(self.rcpt_response[self.rcpt_count-1]) + + def smtp_RSET(self, arg): + self.rset_count += 1 + super().smtp_RSET(arg) + + def smtp_DATA(self, arg): + if self.data_response is None: + super().smtp_DATA(arg) + else: + self.push(self.data_response) + + def handle_error(self): + raise + + +class SimSMTPServer(smtpd.SMTPServer): + + channel_class = SimSMTPChannel + + def __init__(self, *args, **kw): + self._extra_features = [] + self._addresses = {} + smtpd.SMTPServer.__init__(self, *args, **kw) + + def handle_accepted(self, conn, addr): + self._SMTPchannel = self.channel_class( + self._extra_features, self, conn, addr, + decode_data=self._decode_data) + + def process_message(self, peer, mailfrom, rcpttos, data): + self._addresses['from'] = mailfrom + self._addresses['tos'] = rcpttos + + def add_feature(self, feature): + self._extra_features.append(feature) + + def handle_error(self): + raise + + +# Test various SMTP & ESMTP commands/behaviors that require a simulated server +# (i.e., something with more features than DebuggingServer) +class SMTPSimTests(unittest.TestCase): + + def setUp(self): + self.thread_key = threading_helper.threading_setup() + self.real_getfqdn = socket.getfqdn + socket.getfqdn = mock_socket.getfqdn + self.serv_evt = threading.Event() + self.client_evt = threading.Event() + # Pick a random unused port by passing 0 for the port number + self.serv = SimSMTPServer((HOST, 0), ('nowhere', -1), decode_data=True) + # Keep a note of what port was assigned + self.port = self.serv.socket.getsockname()[1] + serv_args = (self.serv, self.serv_evt, self.client_evt) + self.thread = threading.Thread(target=debugging_server, args=serv_args) + self.thread.start() + + # wait until server thread has assigned a port number + self.serv_evt.wait() + self.serv_evt.clear() + + def tearDown(self): + socket.getfqdn = self.real_getfqdn + # indicate that the client is finished + self.client_evt.set() + # wait for the server thread to terminate + self.serv_evt.wait() + threading_helper.join_thread(self.thread) + del self.thread + self.doCleanups() + threading_helper.threading_cleanup(*self.thread_key) + + def testBasic(self): + # smoke test + smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) + smtp.quit() + + def testEHLO(self): + smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) + + # no features should be present before the EHLO + self.assertEqual(smtp.esmtp_features, {}) + + # features expected from the test server + expected_features = {'expn':'', + 'size': '20000000', + 'starttls': '', + 'deliverby': '', + 'help': '', + } + + smtp.ehlo() + self.assertEqual(smtp.esmtp_features, expected_features) + for k in expected_features: + self.assertTrue(smtp.has_extn(k)) + self.assertFalse(smtp.has_extn('unsupported-feature')) + smtp.quit() + + def testVRFY(self): + smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) + + for addr_spec, name in sim_users.items(): + expected_known = (250, bytes('%s %s' % + (name, smtplib.quoteaddr(addr_spec)), + "ascii")) + self.assertEqual(smtp.vrfy(addr_spec), expected_known) + + u = 'nobody@nowhere.com' + expected_unknown = (550, ('No such user: %s' % u).encode('ascii')) + self.assertEqual(smtp.vrfy(u), expected_unknown) + smtp.quit() + + def testEXPN(self): + smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) + + for listname, members in sim_lists.items(): + users = [] + for m in members: + users.append('%s %s' % (sim_users[m], smtplib.quoteaddr(m))) + expected_known = (250, bytes('\n'.join(users), "ascii")) + self.assertEqual(smtp.expn(listname), expected_known) + + u = 'PSU-Members-List' + expected_unknown = (550, b'No access for you!') + self.assertEqual(smtp.expn(u), expected_unknown) + smtp.quit() + + def testAUTH_PLAIN(self): + self.serv.add_feature("AUTH PLAIN") + smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) + resp = smtp.login(sim_auth[0], sim_auth[1]) + self.assertEqual(resp, (235, b'Authentication Succeeded')) + smtp.close() + + def testAUTH_LOGIN(self): + self.serv.add_feature("AUTH LOGIN") + smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) + resp = smtp.login(sim_auth[0], sim_auth[1]) + self.assertEqual(resp, (235, b'Authentication Succeeded')) + smtp.close() + + def testAUTH_LOGIN_initial_response_ok(self): + self.serv.add_feature("AUTH LOGIN") + with smtplib.SMTP(HOST, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) as smtp: + smtp.user, smtp.password = sim_auth + smtp.ehlo("test_auth_login") + resp = smtp.auth("LOGIN", smtp.auth_login, initial_response_ok=True) + self.assertEqual(resp, (235, b'Authentication Succeeded')) + + def testAUTH_LOGIN_initial_response_notok(self): + self.serv.add_feature("AUTH LOGIN") + with smtplib.SMTP(HOST, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) as smtp: + smtp.user, smtp.password = sim_auth + smtp.ehlo("test_auth_login") + resp = smtp.auth("LOGIN", smtp.auth_login, initial_response_ok=False) + self.assertEqual(resp, (235, b'Authentication Succeeded')) + + def testAUTH_BUGGY(self): + self.serv.add_feature("AUTH BUGGY") + + def auth_buggy(challenge=None): + self.assertEqual(b"BuGgYbUgGy", challenge) + return "\0" + + smtp = smtplib.SMTP( + HOST, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT + ) + try: + smtp.user, smtp.password = sim_auth + smtp.ehlo("test_auth_buggy") + expect = r"^Server AUTH mechanism infinite loop.*" + with self.assertRaisesRegex(smtplib.SMTPException, expect) as cm: + smtp.auth("BUGGY", auth_buggy, initial_response_ok=False) + finally: + smtp.close() + + @hashlib_helper.requires_hashdigest('md5', openssl=True) + def testAUTH_CRAM_MD5(self): + self.serv.add_feature("AUTH CRAM-MD5") + smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) + resp = smtp.login(sim_auth[0], sim_auth[1]) + self.assertEqual(resp, (235, b'Authentication Succeeded')) + smtp.close() + + @hashlib_helper.requires_hashdigest('md5', openssl=True) + def testAUTH_multiple(self): + # Test that multiple authentication methods are tried. + self.serv.add_feature("AUTH BOGUS PLAIN LOGIN CRAM-MD5") + smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) + resp = smtp.login(sim_auth[0], sim_auth[1]) + self.assertEqual(resp, (235, b'Authentication Succeeded')) + smtp.close() + + def test_auth_function(self): + supported = {'PLAIN', 'LOGIN'} + try: + hashlib.md5() + except ValueError: + pass + else: + supported.add('CRAM-MD5') + for mechanism in supported: + self.serv.add_feature("AUTH {}".format(mechanism)) + for mechanism in supported: + with self.subTest(mechanism=mechanism): + smtp = smtplib.SMTP(HOST, self.port, + local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) + smtp.ehlo('foo') + smtp.user, smtp.password = sim_auth[0], sim_auth[1] + method = 'auth_' + mechanism.lower().replace('-', '_') + resp = smtp.auth(mechanism, getattr(smtp, method)) + self.assertEqual(resp, (235, b'Authentication Succeeded')) + smtp.close() + + def test_quit_resets_greeting(self): + smtp = smtplib.SMTP(HOST, self.port, + local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) + code, message = smtp.ehlo() + self.assertEqual(code, 250) + self.assertIn('size', smtp.esmtp_features) + smtp.quit() + self.assertNotIn('size', smtp.esmtp_features) + smtp.connect(HOST, self.port) + self.assertNotIn('size', smtp.esmtp_features) + smtp.ehlo_or_helo_if_needed() + self.assertIn('size', smtp.esmtp_features) + smtp.quit() + + def test_with_statement(self): + with smtplib.SMTP(HOST, self.port) as smtp: + code, message = smtp.noop() + self.assertEqual(code, 250) + self.assertRaises(smtplib.SMTPServerDisconnected, smtp.send, b'foo') + with smtplib.SMTP(HOST, self.port) as smtp: + smtp.close() + self.assertRaises(smtplib.SMTPServerDisconnected, smtp.send, b'foo') + + def test_with_statement_QUIT_failure(self): + with self.assertRaises(smtplib.SMTPResponseException) as error: + with smtplib.SMTP(HOST, self.port) as smtp: + smtp.noop() + self.serv._SMTPchannel.quit_response = '421 QUIT FAILED' + self.assertEqual(error.exception.smtp_code, 421) + self.assertEqual(error.exception.smtp_error, b'QUIT FAILED') + + #TODO: add tests for correct AUTH method fallback now that the + #test infrastructure can support it. + + # Issue 17498: make sure _rset does not raise SMTPServerDisconnected exception + def test__rest_from_mail_cmd(self): + smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) + smtp.noop() + self.serv._SMTPchannel.mail_response = '451 Requested action aborted' + self.serv._SMTPchannel.disconnect = True + with self.assertRaises(smtplib.SMTPSenderRefused): + smtp.sendmail('John', 'Sally', 'test message') + self.assertIsNone(smtp.sock) + + # Issue 5713: make sure close, not rset, is called if we get a 421 error + def test_421_from_mail_cmd(self): + smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) + smtp.noop() + self.serv._SMTPchannel.mail_response = '421 closing connection' + with self.assertRaises(smtplib.SMTPSenderRefused): + smtp.sendmail('John', 'Sally', 'test message') + self.assertIsNone(smtp.sock) + self.assertEqual(self.serv._SMTPchannel.rset_count, 0) + + def test_421_from_rcpt_cmd(self): + smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) + smtp.noop() + self.serv._SMTPchannel.rcpt_response = ['250 accepted', '421 closing'] + with self.assertRaises(smtplib.SMTPRecipientsRefused) as r: + smtp.sendmail('John', ['Sally', 'Frank', 'George'], 'test message') + self.assertIsNone(smtp.sock) + self.assertEqual(self.serv._SMTPchannel.rset_count, 0) + self.assertDictEqual(r.exception.args[0], {'Frank': (421, b'closing')}) + + def test_421_from_data_cmd(self): + class MySimSMTPChannel(SimSMTPChannel): + def found_terminator(self): + if self.smtp_state == self.DATA: + self.push('421 closing') + else: + super().found_terminator() + self.serv.channel_class = MySimSMTPChannel + smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) + smtp.noop() + with self.assertRaises(smtplib.SMTPDataError): + smtp.sendmail('John@foo.org', ['Sally@foo.org'], 'test message') + self.assertIsNone(smtp.sock) + self.assertEqual(self.serv._SMTPchannel.rcpt_count, 0) + + def test_smtputf8_NotSupportedError_if_no_server_support(self): + smtp = smtplib.SMTP( + HOST, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) + self.addCleanup(smtp.close) + smtp.ehlo() + self.assertTrue(smtp.does_esmtp) + self.assertFalse(smtp.has_extn('smtputf8')) + self.assertRaises( + smtplib.SMTPNotSupportedError, + smtp.sendmail, + 'John', 'Sally', '', mail_options=['BODY=8BITMIME', 'SMTPUTF8']) + self.assertRaises( + smtplib.SMTPNotSupportedError, + smtp.mail, 'John', options=['BODY=8BITMIME', 'SMTPUTF8']) + + def test_send_unicode_without_SMTPUTF8(self): + smtp = smtplib.SMTP( + HOST, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) + self.addCleanup(smtp.close) + self.assertRaises(UnicodeEncodeError, smtp.sendmail, 'Alice', 'Böb', '') + self.assertRaises(UnicodeEncodeError, smtp.mail, 'Älice') + + def test_send_message_error_on_non_ascii_addrs_if_no_smtputf8(self): + # This test is located here and not in the SMTPUTF8SimTests + # class because it needs a "regular" SMTP server to work + msg = EmailMessage() + msg['From'] = "Páolo " + msg['To'] = 'Dinsdale' + msg['Subject'] = 'Nudge nudge, wink, wink \u1F609' + smtp = smtplib.SMTP( + HOST, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) + self.addCleanup(smtp.close) + with self.assertRaises(smtplib.SMTPNotSupportedError): + smtp.send_message(msg) + + def test_name_field_not_included_in_envelop_addresses(self): + smtp = smtplib.SMTP( + HOST, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) + self.addCleanup(smtp.close) + + message = EmailMessage() + message['From'] = email.utils.formataddr(('Michaël', 'michael@example.com')) + message['To'] = email.utils.formataddr(('René', 'rene@example.com')) + + self.assertDictEqual(smtp.send_message(message), {}) + + self.assertEqual(self.serv._addresses['from'], 'michael@example.com') + self.assertEqual(self.serv._addresses['tos'], ['rene@example.com']) + + +class SimSMTPUTF8Server(SimSMTPServer): + + def __init__(self, *args, **kw): + # The base SMTP server turns these on automatically, but our test + # server is set up to munge the EHLO response, so we need to provide + # them as well. And yes, the call is to SMTPServer not SimSMTPServer. + self._extra_features = ['SMTPUTF8', '8BITMIME'] + smtpd.SMTPServer.__init__(self, *args, **kw) + + def handle_accepted(self, conn, addr): + self._SMTPchannel = self.channel_class( + self._extra_features, self, conn, addr, + decode_data=self._decode_data, + enable_SMTPUTF8=self.enable_SMTPUTF8, + ) + + def process_message(self, peer, mailfrom, rcpttos, data, mail_options=None, + rcpt_options=None): + self.last_peer = peer + self.last_mailfrom = mailfrom + self.last_rcpttos = rcpttos + self.last_message = data + self.last_mail_options = mail_options + self.last_rcpt_options = rcpt_options + + +class SMTPUTF8SimTests(unittest.TestCase): + + maxDiff = None + + def setUp(self): + self.thread_key = threading_helper.threading_setup() + self.real_getfqdn = socket.getfqdn + socket.getfqdn = mock_socket.getfqdn + self.serv_evt = threading.Event() + self.client_evt = threading.Event() + # Pick a random unused port by passing 0 for the port number + self.serv = SimSMTPUTF8Server((HOST, 0), ('nowhere', -1), + decode_data=False, + enable_SMTPUTF8=True) + # Keep a note of what port was assigned + self.port = self.serv.socket.getsockname()[1] + serv_args = (self.serv, self.serv_evt, self.client_evt) + self.thread = threading.Thread(target=debugging_server, args=serv_args) + self.thread.start() + + # wait until server thread has assigned a port number + self.serv_evt.wait() + self.serv_evt.clear() + + def tearDown(self): + socket.getfqdn = self.real_getfqdn + # indicate that the client is finished + self.client_evt.set() + # wait for the server thread to terminate + self.serv_evt.wait() + threading_helper.join_thread(self.thread) + del self.thread + self.doCleanups() + threading_helper.threading_cleanup(*self.thread_key) + + def test_test_server_supports_extensions(self): + smtp = smtplib.SMTP( + HOST, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) + self.addCleanup(smtp.close) + smtp.ehlo() + self.assertTrue(smtp.does_esmtp) + self.assertTrue(smtp.has_extn('smtputf8')) + + def test_send_unicode_with_SMTPUTF8_via_sendmail(self): + m = '¡a test message containing unicode!'.encode('utf-8') + smtp = smtplib.SMTP( + HOST, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) + self.addCleanup(smtp.close) + smtp.sendmail('Jőhn', 'Sálly', m, + mail_options=['BODY=8BITMIME', 'SMTPUTF8']) + self.assertEqual(self.serv.last_mailfrom, 'Jőhn') + self.assertEqual(self.serv.last_rcpttos, ['Sálly']) + self.assertEqual(self.serv.last_message, m) + self.assertIn('BODY=8BITMIME', self.serv.last_mail_options) + self.assertIn('SMTPUTF8', self.serv.last_mail_options) + self.assertEqual(self.serv.last_rcpt_options, []) + + def test_send_unicode_with_SMTPUTF8_via_low_level_API(self): + m = '¡a test message containing unicode!'.encode('utf-8') + smtp = smtplib.SMTP( + HOST, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) + self.addCleanup(smtp.close) + smtp.ehlo() + self.assertEqual( + smtp.mail('Jő', options=['BODY=8BITMIME', 'SMTPUTF8']), + (250, b'OK')) + self.assertEqual(smtp.rcpt('János'), (250, b'OK')) + self.assertEqual(smtp.data(m), (250, b'OK')) + self.assertEqual(self.serv.last_mailfrom, 'Jő') + self.assertEqual(self.serv.last_rcpttos, ['János']) + self.assertEqual(self.serv.last_message, m) + self.assertIn('BODY=8BITMIME', self.serv.last_mail_options) + self.assertIn('SMTPUTF8', self.serv.last_mail_options) + self.assertEqual(self.serv.last_rcpt_options, []) + + def test_send_message_uses_smtputf8_if_addrs_non_ascii(self): + msg = EmailMessage() + msg['From'] = "Páolo " + msg['To'] = 'Dinsdale' + msg['Subject'] = 'Nudge nudge, wink, wink \u1F609' + # XXX I don't know why I need two \n's here, but this is an existing + # bug (if it is one) and not a problem with the new functionality. + msg.set_content("oh là là, know what I mean, know what I mean?\n\n") + # XXX smtpd converts received /r/n to /n, so we can't easily test that + # we are successfully sending /r/n :(. + expected = textwrap.dedent("""\ + From: Páolo + To: Dinsdale + Subject: Nudge nudge, wink, wink \u1F609 + Content-Type: text/plain; charset="utf-8" + Content-Transfer-Encoding: 8bit + MIME-Version: 1.0 + + oh là là, know what I mean, know what I mean? + """) + smtp = smtplib.SMTP( + HOST, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) + self.addCleanup(smtp.close) + self.assertEqual(smtp.send_message(msg), {}) + self.assertEqual(self.serv.last_mailfrom, 'főo@bar.com') + self.assertEqual(self.serv.last_rcpttos, ['Dinsdale']) + self.assertEqual(self.serv.last_message.decode(), expected) + self.assertIn('BODY=8BITMIME', self.serv.last_mail_options) + self.assertIn('SMTPUTF8', self.serv.last_mail_options) + self.assertEqual(self.serv.last_rcpt_options, []) + + +EXPECTED_RESPONSE = encode_base64(b'\0psu\0doesnotexist', eol='') + +class SimSMTPAUTHInitialResponseChannel(SimSMTPChannel): + def smtp_AUTH(self, arg): + # RFC 4954's AUTH command allows for an optional initial-response. + # Not all AUTH methods support this; some require a challenge. AUTH + # PLAIN does those, so test that here. See issue #15014. + args = arg.split() + if args[0].lower() == 'plain': + if len(args) == 2: + # AUTH PLAIN with the response base 64 + # encoded. Hard code the expected response for the test. + if args[1] == EXPECTED_RESPONSE: + self.push('235 Ok') + return + self.push('571 Bad authentication') + +class SimSMTPAUTHInitialResponseServer(SimSMTPServer): + channel_class = SimSMTPAUTHInitialResponseChannel + + +class SMTPAUTHInitialResponseSimTests(unittest.TestCase): + def setUp(self): + self.thread_key = threading_helper.threading_setup() + self.real_getfqdn = socket.getfqdn + socket.getfqdn = mock_socket.getfqdn + self.serv_evt = threading.Event() + self.client_evt = threading.Event() + # Pick a random unused port by passing 0 for the port number + self.serv = SimSMTPAUTHInitialResponseServer( + (HOST, 0), ('nowhere', -1), decode_data=True) + # Keep a note of what port was assigned + self.port = self.serv.socket.getsockname()[1] + serv_args = (self.serv, self.serv_evt, self.client_evt) + self.thread = threading.Thread(target=debugging_server, args=serv_args) + self.thread.start() + + # wait until server thread has assigned a port number + self.serv_evt.wait() + self.serv_evt.clear() + + def tearDown(self): + socket.getfqdn = self.real_getfqdn + # indicate that the client is finished + self.client_evt.set() + # wait for the server thread to terminate + self.serv_evt.wait() + threading_helper.join_thread(self.thread) + del self.thread + self.doCleanups() + threading_helper.threading_cleanup(*self.thread_key) + + def testAUTH_PLAIN_initial_response_login(self): + self.serv.add_feature('AUTH PLAIN') + smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) + smtp.login('psu', 'doesnotexist') + smtp.close() + + def testAUTH_PLAIN_initial_response_auth(self): + self.serv.add_feature('AUTH PLAIN') + smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', + timeout=support.LOOPBACK_TIMEOUT) + smtp.user = 'psu' + smtp.password = 'doesnotexist' + code, response = smtp.auth('plain', smtp.auth_plain) + smtp.close() + self.assertEqual(code, 235) + + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_smtpnet.py b/Lib/test/test_smtpnet.py new file mode 100644 index 0000000000..be25e961f7 --- /dev/null +++ b/Lib/test/test_smtpnet.py @@ -0,0 +1,90 @@ +import unittest +from test import support +from test.support import import_helper +from test.support import socket_helper +import smtplib +import socket + +ssl = import_helper.import_module("ssl") + +support.requires("network") + +def check_ssl_verifiy(host, port): + context = ssl.create_default_context() + with socket.create_connection((host, port)) as sock: + try: + sock = context.wrap_socket(sock, server_hostname=host) + except Exception: + return False + else: + sock.close() + return True + + +class SmtpTest(unittest.TestCase): + testServer = 'smtp.gmail.com' + remotePort = 587 + + def test_connect_starttls(self): + support.get_attribute(smtplib, 'SMTP_SSL') + context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + context.check_hostname = False + context.verify_mode = ssl.CERT_NONE + with socket_helper.transient_internet(self.testServer): + server = smtplib.SMTP(self.testServer, self.remotePort) + try: + server.starttls(context=context) + except smtplib.SMTPException as e: + if e.args[0] == 'STARTTLS extension not supported by server.': + unittest.skip(e.args[0]) + else: + raise + server.ehlo() + server.quit() + + +class SmtpSSLTest(unittest.TestCase): + testServer = 'smtp.gmail.com' + remotePort = 465 + + def test_connect(self): + support.get_attribute(smtplib, 'SMTP_SSL') + with socket_helper.transient_internet(self.testServer): + server = smtplib.SMTP_SSL(self.testServer, self.remotePort) + server.ehlo() + server.quit() + + def test_connect_default_port(self): + support.get_attribute(smtplib, 'SMTP_SSL') + with socket_helper.transient_internet(self.testServer): + server = smtplib.SMTP_SSL(self.testServer) + server.ehlo() + server.quit() + + @support.requires_resource('walltime') + def test_connect_using_sslcontext(self): + context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + context.check_hostname = False + context.verify_mode = ssl.CERT_NONE + support.get_attribute(smtplib, 'SMTP_SSL') + with socket_helper.transient_internet(self.testServer): + server = smtplib.SMTP_SSL(self.testServer, self.remotePort, context=context) + server.ehlo() + server.quit() + + def test_connect_using_sslcontext_verified(self): + with socket_helper.transient_internet(self.testServer): + can_verify = check_ssl_verifiy(self.testServer, self.remotePort) + if not can_verify: + self.skipTest("SSL certificate can't be verified") + + support.get_attribute(smtplib, 'SMTP_SSL') + context = ssl.create_default_context() + with socket_helper.transient_internet(self.testServer): + server = smtplib.SMTP_SSL(self.testServer, self.remotePort, context=context) + server.ehlo() + server.quit() + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file From 493b21778f950bd485b4590984116c7eee85a7f6 Mon Sep 17 00:00:00 2001 From: Ashwin Naren Date: Wed, 15 Jan 2025 12:03:28 -0800 Subject: [PATCH 2/8] fix expected failures on test_smtplib.py --- Lib/test/test_smtplib.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/Lib/test/test_smtplib.py b/Lib/test/test_smtplib.py index b5508a833e..a36d7bbe2a 100644 --- a/Lib/test/test_smtplib.py +++ b/Lib/test/test_smtplib.py @@ -1170,6 +1170,8 @@ def auth_buggy(challenge=None): finally: smtp.close() + # TODO: RUSTPYTHON + @unittest.expectedFailure @hashlib_helper.requires_hashdigest('md5', openssl=True) def testAUTH_CRAM_MD5(self): self.serv.add_feature("AUTH CRAM-MD5") @@ -1179,6 +1181,8 @@ def testAUTH_CRAM_MD5(self): self.assertEqual(resp, (235, b'Authentication Succeeded')) smtp.close() + # TODO: RUSTPYTHON + @unittest.expectedFailure @hashlib_helper.requires_hashdigest('md5', openssl=True) def testAUTH_multiple(self): # Test that multiple authentication methods are tried. @@ -1189,6 +1193,8 @@ def testAUTH_multiple(self): self.assertEqual(resp, (235, b'Authentication Succeeded')) smtp.close() + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_auth_function(self): supported = {'PLAIN', 'LOGIN'} try: @@ -1453,6 +1459,8 @@ def test_send_unicode_with_SMTPUTF8_via_low_level_API(self): self.assertIn('SMTPUTF8', self.serv.last_mail_options) self.assertEqual(self.serv.last_rcpt_options, []) + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_send_message_uses_smtputf8_if_addrs_non_ascii(self): msg = EmailMessage() msg['From'] = "Páolo " From 9ca7262694b6d82105abe506e65cc265168764d9 Mon Sep 17 00:00:00 2001 From: Ashwin Naren Date: Wed, 15 Jan 2025 12:15:15 -0800 Subject: [PATCH 3/8] mark all rustpython fails on test_logging.py --- Lib/test/test_logging.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/Lib/test/test_logging.py b/Lib/test/test_logging.py index 8eb03354f8..031f51a4dd 100644 --- a/Lib/test/test_logging.py +++ b/Lib/test/test_logging.py @@ -1131,6 +1131,8 @@ class SMTPHandlerTest(BaseTest): # bpo-14314, bpo-19665, bpo-34092: don't wait forever TIMEOUT = support.LONG_TIMEOUT + # TODO: RUSTPYTHON + @unittest.skip(reason="Hangs RustPython") def test_basic(self): sockmap = {} server = TestSMTPServer((socket_helper.HOST, 0), self.process_message, 0.001, @@ -4025,6 +4027,8 @@ def _mpinit_issue121723(qspec, message_to_log): # log a message (this creates a record put in the queue) logging.getLogger().info(message_to_log) + # TODO: RustPython + @unittest.expectedFailure @support.requires_subprocess() def test_multiprocessing_queues(self): # See gh-119819 @@ -4083,6 +4087,8 @@ def test_90195(self): # Logger should be enabled, since explicitly mentioned self.assertFalse(logger.disabled) + # TODO: RustPython + @unittest.expectedFailure def test_111615(self): # See gh-111615 import_helper.import_module('_multiprocessing') # see gh-113692 @@ -4530,6 +4536,8 @@ def test_dollars(self): f = logging.Formatter('${asctime}--', style='$') self.assertTrue(f.usesTime()) + # TODO: RustPython + @unittest.expectedFailure def test_format_validate(self): # Check correct formatting # Percentage style @@ -4703,6 +4711,8 @@ def test_defaults_parameter(self): def test_invalid_style(self): self.assertRaises(ValueError, logging.Formatter, None, None, 'x') + # TODO: RustPython + @unittest.expectedFailure def test_time(self): r = self.get_record() dt = datetime.datetime(1993, 4, 21, 8, 3, 0, 0, utc) @@ -4717,6 +4727,8 @@ def test_time(self): f.format(r) self.assertEqual(r.asctime, '1993-04-21 08:03:00,123') + # TODO: RustPython + @unittest.expectedFailure def test_default_msec_format_none(self): class NoMsecFormatter(logging.Formatter): default_msec_format = None @@ -5047,6 +5059,8 @@ def __init__(self, name='MyLogger', level=logging.NOTSET): h.close() logging.setLoggerClass(logging.Logger) + # TODO: RustPython + @unittest.expectedFailure def test_logging_at_shutdown(self): # bpo-20037: Doing text I/O late at interpreter shutdown must not crash code = textwrap.dedent(""" @@ -5066,6 +5080,8 @@ def __del__(self): self.assertIn("exception in __del__", err) self.assertIn("ValueError: some error", err) + # TODO: RustPython + @unittest.expectedFailure def test_logging_at_shutdown_open(self): # bpo-26789: FileHandler keeps a reference to the builtin open() # function to be able to open or reopen the file during Python @@ -5507,6 +5523,8 @@ def test_encoding_errors_default(self): self.assertEqual(data, r'\U0001f602: \u2603\ufe0f: The \xd8resund ' r'Bridge joins Copenhagen to Malm\xf6') + # TODO: RustPython + @unittest.expectedFailure def test_encoding_errors_none(self): # Specifying None should behave as 'strict' try: @@ -6241,6 +6259,8 @@ def rotator(source, dest): rh.close() class TimedRotatingFileHandlerTest(BaseFileTest): + # TODO: RustPython + @unittest.expectedFailure @unittest.skipIf(support.is_wasi, "WASI does not have /dev/null.") def test_should_not_rollover(self): # See bpo-45401. Should only ever rollover regular files @@ -6294,6 +6314,8 @@ def test_rollover(self): print(tf.read()) self.assertTrue(found, msg=msg) + # TODO: RustPython + @unittest.expectedFailure def test_rollover_at_midnight(self, weekly=False): os_helper.unlink(self.fn) now = datetime.datetime.now() @@ -6337,6 +6359,8 @@ def test_rollover_at_midnight(self, weekly=False): for i, line in enumerate(f): self.assertIn(f'testing1 {i}', line) + # TODO: RustPython + @unittest.expectedFailure def test_rollover_at_weekday(self): self.test_rollover_at_midnight(weekly=True) From 71a140d3c6b120279c14fdd432d6b3a4dffa9f9e Mon Sep 17 00:00:00 2001 From: Ashwin Naren Date: Wed, 15 Jan 2025 16:26:09 -0800 Subject: [PATCH 4/8] fix failing test --- Lib/test/test_logging.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/Lib/test/test_logging.py b/Lib/test/test_logging.py index 031f51a4dd..870d863c8e 100644 --- a/Lib/test/test_logging.py +++ b/Lib/test/test_logging.py @@ -2170,6 +2170,8 @@ def handle_request(self, request): request.end_headers() self.handled.set() + # TODO: RUSTPYTHON + @unittest.expectedFailure() def test_output(self): # The log message sent to the HTTPHandler is properly received. logger = logging.getLogger("http") From 0c0bc1511f3bfb34913c1975d9d7243e784950d5 Mon Sep 17 00:00:00 2001 From: Ashwin Naren Date: Wed, 15 Jan 2025 20:49:18 -0800 Subject: [PATCH 5/8] fix test --- Lib/test/test_logging.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Lib/test/test_logging.py b/Lib/test/test_logging.py index 870d863c8e..f9e1c0cb55 100644 --- a/Lib/test/test_logging.py +++ b/Lib/test/test_logging.py @@ -2171,7 +2171,7 @@ def handle_request(self, request): self.handled.set() # TODO: RUSTPYTHON - @unittest.expectedFailure() + @unittest.expectedFailure def test_output(self): # The log message sent to the HTTPHandler is properly received. logger = logging.getLogger("http") From 01043f33b6d881663dad3f1f5a37cd7b49873f1c Mon Sep 17 00:00:00 2001 From: Ashwin Naren Date: Wed, 15 Jan 2025 20:58:18 -0800 Subject: [PATCH 6/8] fix test --- Lib/test/test_logging.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/Lib/test/test_logging.py b/Lib/test/test_logging.py index f9e1c0cb55..031f51a4dd 100644 --- a/Lib/test/test_logging.py +++ b/Lib/test/test_logging.py @@ -2170,8 +2170,6 @@ def handle_request(self, request): request.end_headers() self.handled.set() - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_output(self): # The log message sent to the HTTPHandler is properly received. logger = logging.getLogger("http") From c7630771d8d2f104577c01685feec276ad359223 Mon Sep 17 00:00:00 2001 From: Ashwin Naren Date: Thu, 16 Jan 2025 08:49:26 -0800 Subject: [PATCH 7/8] disable test on test_logging.py --- Lib/test/test_logging.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/Lib/test/test_logging.py b/Lib/test/test_logging.py index 031f51a4dd..4baca2b481 100644 --- a/Lib/test/test_logging.py +++ b/Lib/test/test_logging.py @@ -2170,6 +2170,8 @@ def handle_request(self, request): request.end_headers() self.handled.set() + # TODO: RUSTPYTHON + @unittest.skip("RUSTPYTHON") def test_output(self): # The log message sent to the HTTPHandler is properly received. logger = logging.getLogger("http") From 8610fbcac766da8a1568935c605197963050b3f2 Mon Sep 17 00:00:00 2001 From: Ashwin Naren Date: Thu, 16 Jan 2025 12:30:26 -0800 Subject: [PATCH 8/8] skip tests from OS dependent bugs --- Lib/test/test_logging.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/Lib/test/test_logging.py b/Lib/test/test_logging.py index 4baca2b481..a570d65f6c 100644 --- a/Lib/test/test_logging.py +++ b/Lib/test/test_logging.py @@ -6261,8 +6261,8 @@ def rotator(source, dest): rh.close() class TimedRotatingFileHandlerTest(BaseFileTest): - # TODO: RustPython - @unittest.expectedFailure + # TODO: RUSTPYTHON + @unittest.skip("OS dependent bug") @unittest.skipIf(support.is_wasi, "WASI does not have /dev/null.") def test_should_not_rollover(self): # See bpo-45401. Should only ever rollover regular files @@ -6317,7 +6317,7 @@ def test_rollover(self): self.assertTrue(found, msg=msg) # TODO: RustPython - @unittest.expectedFailure + @unittest.skip("OS dependent bug") def test_rollover_at_midnight(self, weekly=False): os_helper.unlink(self.fn) now = datetime.datetime.now() @@ -6362,7 +6362,7 @@ def test_rollover_at_midnight(self, weekly=False): self.assertIn(f'testing1 {i}', line) # TODO: RustPython - @unittest.expectedFailure + @unittest.skip("OS dependent bug") def test_rollover_at_weekday(self): self.test_rollover_at_midnight(weekly=True)