Source code for rtm_wrapper.util

"""
Misc utilities.
"""
from __future__ import annotations

import importlib.metadata
import logging.config
import platform
import subprocess
from typing import Any, Callable, Hashable, Iterable, TypeVar

from typing_extensions import Never

_T = TypeVar("_T")
_H = TypeVar("_H", bound=Hashable)


class TrapCalledError(RuntimeError):
    """
    Raised when a trap callable is invoked.

    Unlike most exceptions, ``args`` holds the positional arguments the trap
    was *called with*, not the arguments passed to this constructor. The
    keyword arguments of that call are stored in ``kwargs``.
    """

    message: str
    args: tuple[Any, ...]
    kwargs: dict[str, Any]

    def __init__(
        self, message: str, args: tuple[Any, ...], kwargs: dict[str, Any]
    ) -> None:
        super().__init__(message)
        # Description supplied when the trap was created.
        self.message = message
        # Deliberately overrides BaseException.args: callers read the trap's
        # positional call arguments from here. __reduce__ compensates for it.
        self.args = args
        # Keyword arguments the trap was called with.
        self.kwargs = kwargs

    def __reduce__(self) -> tuple[Any, ...]:
        # The default BaseException.__reduce__ rebuilds the exception as
        # ``type(self)(*self.args)``. Because ``args`` was overridden above,
        # that call does not match this constructor and pickling/copying
        # fails. Rebuild from the real constructor arguments instead, and
        # carry over any extra instance state (e.g. notes added later).
        return (
            type(self),
            (self.message, self.args, self.kwargs),
            self.__dict__.copy(),
        )

    def __str__(self) -> str:
        return f"Trap called: {self.message}."
def setup_debug_root_logging(level: int = logging.NOTSET) -> None:
    """
    Install a simple debugging configuration on the root logger.

    A single handler writes every record that passes the root logger's
    ``level`` threshold to stdout, using a verbose format that includes the
    timestamp, level, thread, logger name, function and line number.

    Call this once, early in an application entry point and BEFORE any
    ``logging.getLogger`` calls: the configuration sets
    ``disable_existing_loggers``, so loggers created earlier are disabled.
    """
    console_formatter = {
        "format": (
            "[{asctime},{msecs:06.2f}] {levelname:7s} "
            "({threadName}:{name}) {funcName}:{lineno} {message}"
        ),
        "style": "{",
        "datefmt": "%Y-%m-%d %H:%M:%S",
        "validate": True,
    }
    console_handler = {
        "class": "logging.StreamHandler",
        "formatter": "console",
        # The handler itself filters nothing; the root level decides.
        "level": "NOTSET",
        "stream": "ext://sys.stdout",
    }
    config = {
        "version": 1,
        "disable_existing_loggers": True,
        "formatters": {"console": console_formatter},
        "handlers": {"console": console_handler},
        "root": {"handlers": ["console"], "level": level},
    }
    logging.config.dictConfig(config)
def partition_dict(
    dictionary: dict[_H, _T], predicate: Callable[[_H], bool]
) -> tuple[dict[_H, _T], dict[_H, _T]]:
    """
    Split a dictionary in two according to a predicate on its keys.

    Returns ``(matching, rest)``: entries whose key satisfies ``predicate``
    go to the first dict, all others to the second. Insertion order is
    preserved in both, and ``predicate`` is called exactly once per key.

    >>> d = {i: i**2 for i in range(6)}
    >>> partition_dict(d, lambda x: x % 2 == 0)
    ({0: 0, 2: 4, 4: 16}, {1: 1, 3: 9, 5: 25})
    """
    matching: dict[_H, _T] = {}
    rest: dict[_H, _T] = {}
    for k, v in dictionary.items():
        target = matching if predicate(k) else rest
        target[k] = v
    return matching, rest
def build_version() -> str:
    """
    Return this distribution's version, with the local git commit if available.

    The base version comes from the installed ``rtm-wrapper`` distribution
    metadata. When this module's source directory is inside a git work tree,
    the short hash of ``HEAD`` is appended as ``<version>+<commit>``. If git
    is missing or not executable, the directory is not a git work tree, the
    lookup times out, or it prints nothing, the base version is returned.

    Raises ``importlib.metadata.PackageNotFoundError`` if ``rtm-wrapper`` is
    not installed.
    """
    from pathlib import Path

    base_version = importlib.metadata.version("rtm-wrapper")
    try:
        result = subprocess.run(
            ["git", "rev-parse", "--short", "HEAD"],
            # Query the repository containing this module, not whichever
            # repository the process happens to have been launched from.
            cwd=Path(__file__).resolve().parent,
            text=True,
            check=True,
            capture_output=True,
            # Never let a version lookup stall startup.
            timeout=10,
        )
    except (OSError, subprocess.SubprocessError):
        # OSError: git not found / not executable / cwd unusable.
        # SubprocessError: non-zero exit (not a repo) or timeout.
        return base_version
    build_commit = result.stdout.strip()
    if not build_commit:
        return base_version
    return f"{base_version}+{build_commit}"
def platform_summary() -> str:
    """
    Return a one-line summary of the Python interpreter and host platform.

    Format: ``"<implementation> <python version> (<uname fields>)"`` where the
    non-empty fields of ``platform.uname()`` are joined by single spaces.
    ``platform`` reports undeterminable fields (e.g. ``processor`` on some
    systems) as empty strings; they are skipped so the summary contains no
    stray separators such as ``"x86_64 )"``.
    """
    uname_fields = " ".join(field for field in platform.uname() if field)
    return (
        f"{platform.python_implementation()} {platform.python_version()} "
        f"({uname_fields})"
    )
def first_or(iterable: Iterable[_T], default: _T | None = None) -> _T | None:
    """
    Return the first item produced by ``iterable``, or ``default`` if it
    produces none.

    At most one item is consumed from the iterable.
    """
    return next(iter(iterable), default)
def trap(message: str) -> Callable[..., Never]:
    """
    Build a callable that unconditionally raises ``TrapCalledError``.

    Whatever positional and keyword arguments the returned callable receives
    are attached to the raised error together with ``message``.
    """

    def _raise(*args: Any, **kwargs: Any) -> Never:
        error = TrapCalledError(message, args, kwargs)
        raise error

    return _raise