Skip to content

Commit

Permalink
Update torch.cuda.amp to torch.amp in pytorch.py (#945)
Browse files Browse the repository at this point in the history
update pytorch.py to get rid of torch.cuda.amp deprecated warning.
  • Loading branch information
Atlogit committed Sep 30, 2024
1 parent 04956f0 commit e08ce14
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions thinc/shims/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def predict(self, inputs: ArgsKwargs) -> Any:
"""
self._model.eval()
with torch.no_grad():
with torch.cuda.amp.autocast(self._mixed_precision):
with torch.amp.autocast("cuda", self._mixed_precision):
outputs = self._model(*inputs.args, **inputs.kwargs)
self._model.train()
return outputs
Expand All @@ -125,7 +125,7 @@ def begin_update(self, inputs: ArgsKwargs):
self._model.train()

# Note: mixed-precision autocast must not be applied to backprop.
with torch.cuda.amp.autocast(self._mixed_precision):
with torch.amp.autocast("cuda", self._mixed_precision):
output = self._model(*inputs.args, **inputs.kwargs)

def backprop(grads):
Expand Down

0 comments on commit e08ce14

Please sign in to comment.