35 lines
903 B
Python
35 lines
903 B
Python
from contextlib import nullcontext
|
||
from functools import wraps
|
||
from typing import Callable, Optional, Tuple, Type, TypeVar, Union, overload, ContextManager
|
||
|
||
import torch
|
||
|
||
# Names exported by ``from <module> import *``; the ``F`` TypeVar is not exported.
__all__ = [
    "LayerType",
    "PadType",
    "nullwrap",
    "disable_compiler",
]
|
||
|
||
|
||
# Ways a layer may be specified: a string, any callable, or an ``nn.Module``
# subclass. NOTE(review): how strings/callables are resolved lives elsewhere.
LayerType = Union[
    str,
    Callable,
    Type[torch.nn.Module],
]

# Ways padding may be specified: a string, a single int, or an (int, int) pair.
# NOTE(review): presumably mode name / symmetric amount / per-side amounts — confirm.
PadType = Union[
    str,
    int,
    Tuple[int, int],
]
|
||
|
||
F = TypeVar("F", bound=Callable[..., object])


class _NullContext(nullcontext):
    """A ``nullcontext`` that can also be applied as a decorator.

    Returned by ``nullwrap()`` when no function is given, so that both
    ``with nullwrap(): ...`` and the parenthesised decorator form
    ``@nullwrap()`` work. Entering it yields ``None``.
    """

    def __call__(self, fn: F) -> F:
        # Decorating with an instance is equivalent to the bare ``@nullwrap`` form.
        return nullwrap(fn)


@overload
def nullwrap(fn: F, *args: object, **kwargs: object) -> F: ...  # decorator form


@overload
def nullwrap(
    fn: None = ..., *args: object, **kwargs: object
) -> _NullContext: ...  # context-manager / decorator-factory form


def nullwrap(fn: Optional[F] = None, *args: object, **kwargs: object):
    """Do-nothing stand-in usable as a decorator or as a context manager.

    This module uses it as the fallback for ``disable_compiler`` when
    ``torch.compiler.disable`` is unavailable, so it supports these call shapes:

    * ``nullwrap(fn)`` / ``@nullwrap`` -- returns a transparent wrapper around ``fn``.
    * ``nullwrap()`` -- returns an object usable both as ``with nullwrap():``
      and as ``@nullwrap()``.

    Args:
        fn: Function to wrap, or ``None`` for the context-manager form.
        *args, **kwargs: Extra options (e.g. ``recursive=False``) that the real
            ``torch.compiler.disable`` may be called with. They are accepted
            and ignored. NOTE(review): confirm against the supported torch versions.

    Returns:
        A ``functools.wraps``-decorated wrapper that calls ``fn`` unchanged, or a
        ``_NullContext`` (a ``nullcontext`` subclass) when ``fn`` is ``None``.
    """
    # Context-manager form: ``with nullwrap():`` -- also callable as a decorator.
    if fn is None:
        return _NullContext()

    # Decorator form: ``@nullwrap`` -- preserves fn's name, docstring and __wrapped__.
    @wraps(fn)
    def wrapper(*call_args, **call_kwargs):
        return fn(*call_args, **call_kwargs)

    return wrapper
|
||
|
||
|
||
disable_compiler = getattr(getattr(torch, "compiler", None), "disable", None) or nullwrap
|