diff --git a/deepmd/tf/descriptor/se_atten.py b/deepmd/tf/descriptor/se_atten.py index 3ca763870b..4018ebab14 100644 --- a/deepmd/tf/descriptor/se_atten.py +++ b/deepmd/tf/descriptor/se_atten.py @@ -705,7 +705,11 @@ def _pass_filter( ), ) self.recovered_switch *= tf.reshape( - tf.slice(tf.reshape(mask, [-1, 4]), [0, 0], [-1, 1]), + tf.slice( + tf.reshape(tf.cast(mask, self.filter_precision), [-1, 4]), + [0, 0], + [-1, 1], + ), [-1, natoms[0], self.sel_all_a[0]], ) else: