Files

35 lines
903 B
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from contextlib import nullcontext
from functools import wraps
from typing import Callable, Optional, Tuple, Type, TypeVar, Union, overload, ContextManager
import torch
# Names exported by `from <this module> import *`.
__all__ = [
    "LayerType",
    "PadType",
    "nullwrap",
    "disable_compiler",
]

# Accepted forms for a "layer" argument: a string, any callable, or a
# ``torch.nn.Module`` subclass (the class itself, not an instance).
LayerType = Union[str, Callable, Type[torch.nn.Module]]

# Accepted forms for a "padding" argument: a string, a single int, or a
# pair of ints.
PadType = Union[str, int, Tuple[int, int]]

# Generic "some callable" type; used in `nullwrap`'s decorator overload so the
# decorated function keeps its own static type.
F = TypeVar("F", bound=Callable[..., object])
class _NullWrapContext(nullcontext):
    """A ``nullcontext`` that can additionally be used as a decorator.

    A plain ``contextlib.nullcontext`` instance is not callable, so applying
    the bare-call form ``@nullwrap()`` to a function would raise ``TypeError``.
    Calling an instance of this class forwards to :func:`nullwrap`, making
    ``@nullwrap()`` equivalent to ``@nullwrap``.  Used with ``with``, it behaves
    exactly like ``nullcontext()``: nothing happens on enter/exit and the
    ``as`` target is ``None``.
    """

    def __call__(self, fn: F) -> F:
        # Decorator-factory form: `@nullwrap()` -> same wrapper as `@nullwrap`.
        return nullwrap(fn)


@overload
def nullwrap(fn: F) -> F: ...  # decorator form: `@nullwrap`
@overload
def nullwrap(fn: None = ...) -> _NullWrapContext: ...  # `with nullwrap():` or `@nullwrap()`
def nullwrap(fn: Optional[F] = None):
    """No-op helper usable as a decorator or as a context manager.

    Args:
        fn: Function to decorate. When omitted (or ``None``), a no-op context
            manager is returned instead; that object can also be applied to a
            function as a decorator.

    Returns:
        If ``fn`` is given: a wrapper that calls ``fn`` with the same positional
        and keyword arguments and returns its result.  ``functools.wraps``
        copies ``fn``'s metadata (``__name__``, ``__doc__``, ``__wrapped__``,
        ...) onto the wrapper.
        Otherwise: a ``nullcontext`` subclass instance whose ``__enter__``
        returns ``None`` and which, when called with a function, returns the
        same kind of wrapper as above.
    """
    if fn is None:
        # Context-manager / decorator-factory form.
        return _NullWrapContext()

    # Decorator form: forward every call unchanged.
    @wraps(fn)
    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)

    return wrapper
# Prefer torch's own `torch.compiler.disable` when this torch build exposes it
# (presumably absent on older releases — confirm the minimum supported torch);
# otherwise fall back to the no-op `nullwrap`, so callers can use
# `disable_compiler` unconditionally.
_torch_compiler = getattr(torch, "compiler", None)
disable_compiler = getattr(_torch_compiler, "disable", None) or nullwrap
del _torch_compiler  # temporary only; keep the module namespace unchanged