diff --git a/jax_triton/triton_lib.py b/jax_triton/triton_lib.py
index b8b0e8d..96d0669 100644
--- a/jax_triton/triton_lib.py
+++ b/jax_triton/triton_lib.py
@@ -14,9 +14,9 @@
 
 """Module for calling Triton kernels from JAX."""
 
-# b/301982023
 from __future__ import annotations
 
+from collections.abc import Callable, Sequence
 import copy
 import dataclasses
 import functools
@@ -26,9 +26,7 @@
 import tempfile
 import types
 from typing import Any, Protocol, Union
-from collections.abc import Callable, Sequence
 import zlib
-from functools import partial
 
 from absl import logging
 import jax
@@ -44,6 +42,7 @@
 import jax.numpy as jnp
 import numpy as np
 
+
 CAN_USE_TRITON = False
 try:
   import triton
@@ -226,19 +225,6 @@ def compile_ttir_to_ptx_inplace(
 ) -> CompilationResult:
   if cuda_options.debug:
     print(ttir)
-  if isinstance(ttir, ir.Module):
-    context = _triton.ir.context()
-    _triton.ir.load_dialects(context)
-    cuda_backend.load_dialects(context)
-
-    # Triton compilation APIs only accept Triton-specific MLIR wrappers.
-    # So, here we serialize an ir.Module to a file and then deserialize
-    # it as a tl_ir.module.
-    with tempfile.NamedTemporaryFile(mode="wb") as f:
-      ttir.operation.write_bytecode(f)
-      f.flush()
-      ttir = tl_ir.parse_mlir_module(f.name, context)
-      ttir.context = context
   try:
     metadata = {}
     opt_ttir = cuda_backend.make_ttir(ttir, metadata, cuda_options)
@@ -295,19 +281,6 @@ def compile_ttir_to_hsaco_inplace(
 ) -> CompilationResult:
   if hip_options.debug:
     print(ttir)
-  if isinstance(ttir, ir.Module):
-    context = _triton.ir.context()
-    _triton.ir.load_dialects(context)
-    hip_backend.load_dialects(context)
-
-    # Triton compilation APIs only accept Triton-specific MLIR wrappers.
-    # So, here we serialize an ir.Module to a file and then deserialize
-    # it as a tl_ir.module.
-    with tempfile.NamedTemporaryFile(mode="wb") as f:
-      ttir.operation.write_bytecode(f)
-      f.flush()
-      ttir = tl_ir.parse_mlir_module(f.name, context)
-      ttir.context = context
   try:
     metadata = {}
     opt_ttir = hip_backend.make_ttir(ttir, metadata, hip_options)
@@ -554,7 +527,6 @@ def triton_kernel_call_lowering(
         "`input_output_aliases` only supported on `jaxlib>=0.3.22"
     )
 
-  kernel_call_name = name
   args = list(ctx.avals_in)
   arg_dtypes = list(map(get_triton_type, ctx.avals_in))
 
@@ -720,13 +692,17 @@ def prune_configs(configs, named_args, **kwargs):
       operand_output_aliases=dict(input_output_aliases),
   ).results
 
-mlir.register_lowering(triton_kernel_call_p,
-                       partial(triton_kernel_call_lowering, get_cuda_backend),
-                       platform='cuda')
+mlir.register_lowering(
+    triton_kernel_call_p,
+    functools.partial(triton_kernel_call_lowering, get_cuda_backend),
+    platform="cuda",
+)
 
-mlir.register_lowering(triton_kernel_call_p,
-                       partial(triton_kernel_call_lowering, get_hip_backend),
-                       platform='rocm')
+mlir.register_lowering(
+    triton_kernel_call_p,
+    functools.partial(triton_kernel_call_lowering, get_hip_backend),
+    platform="rocm",
+)
 
 
 class ShapeDtype(Protocol):