diff --git a/train.py b/train.py index e4633ce..0b5847f 100644 --- a/train.py +++ b/train.py @@ -13,7 +13,11 @@ import datetime -device = "cuda" if torch.cuda.is_available() else "cpu" +device = "cpu" +if torch.cuda.is_available(): + device = "cuda" +elif torch.backends.mps.is_available(): + device = "mps" lexicon = get_lexicon(config["model_type"])