From e08ce14966effb42b14f124893bf96d6fb533712 Mon Sep 17 00:00:00 2001
From: Atlogit <86947554+Atlogit@users.noreply.github.com>
Date: Mon, 30 Sep 2024 15:52:12 +0300
Subject: [PATCH] Update torch.cuda.amp to torch.amp in pytorch.py (#945)

Update pytorch.py to get rid of the deprecation warning that
torch.cuda.amp.autocast emits since PyTorch 2.4. torch.amp.autocast
takes the device type as its first argument; the on/off flag is passed
via the `enabled` keyword, since its second positional parameter is
`dtype`.
---
 thinc/shims/pytorch.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/thinc/shims/pytorch.py b/thinc/shims/pytorch.py
index 505669867..3d18daf0f 100644
--- a/thinc/shims/pytorch.py
+++ b/thinc/shims/pytorch.py
@@ -111,7 +111,7 @@ def predict(self, inputs: ArgsKwargs) -> Any:
         """
         self._model.eval()
         with torch.no_grad():
-            with torch.cuda.amp.autocast(self._mixed_precision):
+            with torch.amp.autocast("cuda", enabled=self._mixed_precision):
                 outputs = self._model(*inputs.args, **inputs.kwargs)
         self._model.train()
         return outputs
@@ -125,7 +125,7 @@ def begin_update(self, inputs: ArgsKwargs):
         self._model.train()
 
         # Note: mixed-precision autocast must not be applied to backprop.
-        with torch.cuda.amp.autocast(self._mixed_precision):
+        with torch.amp.autocast("cuda", enabled=self._mixed_precision):
             output = self._model(*inputs.args, **inputs.kwargs)
 
         def backprop(grads):
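
Note (not part of the patch): a minimal standalone sketch of the autocast
migration this commit performs, assuming a CUDA-capable PyTorch 2.4+ install;
the toy model and tensor here are hypothetical and only illustrate the API
difference:

    import torch

    # Hypothetical toy model and input, for illustration only.
    model = torch.nn.Linear(4, 4).cuda()
    x = torch.randn(2, 4, device="cuda")

    # Old API, deprecated since PyTorch 2.4 (emits a FutureWarning):
    #   with torch.cuda.amp.autocast(enabled=True):
    #       y = model(x)

    # New API: the device type is the first argument, and the on/off flag
    # must be passed as the `enabled` keyword, because the second
    # positional parameter of torch.amp.autocast is `dtype`.
    with torch.amp.autocast("cuda", enabled=True):
        y = model(x)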