diff --git a/jax_triton/triton_lib.py b/jax_triton/triton_lib.py index 96d0669..33459d3 100644 --- a/jax_triton/triton_lib.py +++ b/jax_triton/triton_lib.py @@ -538,7 +538,8 @@ def triton_kernel_call_lowering( named_args = dict(unsafe_zip(fn.arg_names, args)) if isinstance(fn, autotuner.Autotuner): - if any(idx not in fn.key_idx for idx, _, _ in scalar_args): + key_idxs = [fn.arg_names.index(k) for k in fn.keys] + if any(idx not in key_idxs for idx, _, _ in scalar_args): logging.warning( "Auto-tuning key does not include all scalar arguments. " "We may perform redundant auto-tuning."