13 lines
383 B
Python
13 lines
383 B
Python
import torch
|
|
from contextlib import suppress
|
|
|
|
|
|
def get_autocast(precision, device_type='cuda'):
    """Return a factory that builds the autocast context for ``precision``.

    Args:
        precision: Precision mode name. ``'amp'`` selects float16 autocast;
            ``'amp_bfloat16'`` / ``'amp_bf16'`` select bfloat16 autocast.
            Any other value disables autocast.
        device_type: Device type passed to ``torch.autocast``. Defaults to
            ``'cuda'``, matching the previous ``torch.cuda.amp.autocast``
            behavior.

    Returns:
        A callable that returns a context manager when called.

        - AMP modes: the callable accepts the optional ``enabled``, ``dtype``
          and ``cache_enabled`` arguments, in the same order as the legacy
          ``torch.cuda.amp.autocast``.
        - Any other precision: the callable is ``contextlib.suppress``.
          ``suppress()`` with no arguments suppresses nothing, so it acts as
          a no-op context.
    """
    if precision == 'amp':
        amp_dtype = torch.float16
    elif precision in ('amp_bfloat16', 'amp_bf16'):
        # amp_bfloat16 is more stable than amp float16 for clip training
        amp_dtype = torch.bfloat16
    else:
        return suppress

    def autocast(enabled=True, dtype=amp_dtype, cache_enabled=True):
        # torch.cuda.amp.autocast is deprecated (FutureWarning since PyTorch
        # 2.4); torch.autocast(device_type, ...) is the device-generic
        # replacement. Both AMP modes now share the same argument handling.
        return torch.autocast(
            device_type,
            dtype=dtype,
            enabled=enabled,
            cache_enabled=cache_enabled,
        )

    return autocast
|