Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MLIR] Add f8E8M0FNU type #111028

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open

Conversation

sergey-kozub
Copy link
Contributor

@sergey-kozub sergey-kozub commented Oct 3, 2024

This PR adds f8E8M0FNU type to MLIR.

f8E8M0FNU type is proposed in OpenCompute MX Specification. It defines a 8-bit floating point number with bit layout S0E8M0. Unlike IEEE-754 types, there are no infinity, denormals, zeros or negative values.

f8E8M0FNU
- Exponent bias: 127
- Maximum stored exponent value: 254 (binary 1111'1110)
- Maximum unbiased exponent value: 254 - 127 = 127
- Minimum stored exponent value: 0 (binary 0000'0000)
- Minimum unbiased exponent value: 0127 = -127
- Doesn't have zero
- Doesn't have infinity
- NaN is encoded as binary 1111'1111

Additional details:
- Zeros cannot be represented
- Negative values cannot be represented
- Mantissa is always 1

Related PRs:

  • PR-107127 [APFloat] Add APFloat support for E8M0 type
  • PR-105573 [MLIR] Add f6E3M2FN type - was used as a template for this PR
  • PR-107999 [MLIR] Add f6E2M3FN type
  • PR-108877 [MLIR] Add f4E2M1FN type

@llvmbot
Copy link
Collaborator

llvmbot commented Oct 3, 2024

@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-ods
@llvm/pr-subscribers-mlir-core

@llvm/pr-subscribers-mlir-llvm

Author: Sergey Kozub (sergey-kozub)

Changes

This PR adds f8E8M0FNU type to MLIR.

f8E8M0FNU type is proposed in OpenCompute MX Specification. It defines a 8-bit floating point number with bit layout S0E8M0. Unlike IEEE-754 types, there are no infinity, zeros or negative values.

f8E8M0FNU
- Exponent bias: 127
- Maximum stored exponent value: 254 (binary 1111'1110)
- Maximum unbiased exponent value: 254 - 127 = 127
- Minimum stored exponent value: 0 (binary 0000'0000)
- Minimum unbiased exponent value: 0127 = -127
- Doesn't have zero
- Doesn't have infinity
- NaN is encoded as binary 1111'1111

Additional details:
- Zeros cannot be represented
- Negative values cannot be represented
- Mantissa is always 1

Related PRs:

  • PR-107127 [APFloat] Add APFloat support for E8M0 type
  • PR-105573 [MLIR] Add f6E3M2FN type - was used as a template for this PR
  • PR-107999 [MLIR] Add f6E2M3FN type
  • PR-108877 [MLIR] Add f4E2M1FN type

Patch is 20.61 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/111028.diff

24 Files Affected:

  • (modified) mlir/include/mlir-c/BuiltinTypes.h (+10)
  • (modified) mlir/include/mlir/IR/Builders.h (+1)
  • (modified) mlir/include/mlir/IR/BuiltinTypes.h (+11-5)
  • (modified) mlir/include/mlir/IR/BuiltinTypes.td (+23)
  • (modified) mlir/include/mlir/IR/CommonTypeConstraints.td (+2)
  • (modified) mlir/include/mlir/IR/Types.h (+1)
  • (modified) mlir/lib/AsmParser/TokenKinds.def (+1)
  • (modified) mlir/lib/AsmParser/TypeParser.cpp (+4)
  • (modified) mlir/lib/Bindings/Python/IRTypes.cpp (+22)
  • (modified) mlir/lib/CAPI/IR/BuiltinTypes.cpp (+12)
  • (modified) mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp (+2-1)
  • (modified) mlir/lib/Dialect/Arith/Utils/Utils.cpp (+1)
  • (modified) mlir/lib/IR/AsmPrinter.cpp (+1)
  • (modified) mlir/lib/IR/Builders.cpp (+4)
  • (modified) mlir/lib/IR/BuiltinTypes.cpp (+2)
  • (modified) mlir/lib/IR/MLIRContext.cpp (+5)
  • (modified) mlir/lib/IR/Types.cpp (+3)
  • (modified) mlir/python/mlir/_mlir_libs/_mlir/ir.pyi (+14)
  • (modified) mlir/python/mlir/extras/types.py (+2)
  • (modified) mlir/test/IR/attribute.mlir (+4)
  • (modified) mlir/test/Target/LLVMIR/llvmir.mlir (+3)
  • (modified) mlir/test/python/ir/builtin_types.py (+9)
  • (modified) mlir/utils/lldb-scripts/mlirDataFormatters.py (+1)
  • (modified) mlir/utils/tree-sitter-mlir/grammar.js (+2-1)
diff --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h
index 6dc25a56b8e614..6875fab7bf7961 100644
--- a/mlir/include/mlir-c/BuiltinTypes.h
+++ b/mlir/include/mlir-c/BuiltinTypes.h
@@ -179,6 +179,16 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E3M4(MlirType type);
 /// context.
 MLIR_CAPI_EXPORTED MlirType mlirFloat8E3M4TypeGet(MlirContext ctx);
 
+/// Returns the typeID of an Float8E8M0FNU type.
+MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E8M0FNUTypeGetTypeID(void);
+
+/// Checks whether the given type is an f8E8M0FNU type.
+MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E8M0FNU(MlirType type);
+
+/// Creates an f8E8M0FNU type in the given context. The type is owned by the
+/// context.
+MLIR_CAPI_EXPORTED MlirType mlirFloat8E8M0FNUTypeGet(MlirContext ctx);
+
 /// Returns the typeID of an BFloat16 type.
 MLIR_CAPI_EXPORTED MlirTypeID mlirBFloat16TypeGetTypeID(void);
 
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index ee5d7879625309..04a8bddc3cd59a 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -70,6 +70,7 @@ class Builder {
   FloatType getFloat8E4M3FNUZType();
   FloatType getFloat8E4M3B11FNUZType();
   FloatType getFloat8E3M4Type();
+  FloatType getFloat8E8M0FNUType();
   FloatType getBF16Type();
   FloatType getF16Type();
   FloatType getTF32Type();
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 91e68b4066dd67..25535408f4528a 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -70,6 +70,7 @@ class FloatType : public Type {
   static FloatType getFloat4E2M1FN(MLIRContext *ctx);
   static FloatType getFloat6E2M3FN(MLIRContext *ctx);
   static FloatType getFloat6E3M2FN(MLIRContext *ctx);
+  static FloatType getFloat8E8M0FNU(MLIRContext *ctx);
 
   /// Methods for support type inquiry through isa, cast, and dyn_cast.
   static bool classof(Type type);
@@ -416,11 +417,12 @@ inline bool BaseMemRefType::isValidElementType(Type type) {
 }
 
 inline bool FloatType::classof(Type type) {
-  return llvm::isa<
-      Float4E2M1FNType, Float6E2M3FNType, Float6E3M2FNType, Float8E5M2Type,
-      Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType, Float8E4M3FNUZType,
-      Float8E4M3B11FNUZType, Float8E3M4Type, BFloat16Type, Float16Type,
-      FloatTF32Type, Float32Type, Float64Type, Float80Type, Float128Type>(type);
+  return llvm::isa<Float4E2M1FNType, Float6E2M3FNType, Float6E3M2FNType,
+                   Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType,
+                   Float8E5M2FNUZType, Float8E4M3FNUZType,
+                   Float8E4M3B11FNUZType, Float8E3M4Type, Float8E8M0FNUType,
+                   BFloat16Type, Float16Type, FloatTF32Type, Float32Type,
+                   Float64Type, Float80Type, Float128Type>(type);
 }
 
 inline FloatType FloatType::getFloat4E2M1FN(MLIRContext *ctx) {
@@ -463,6 +465,10 @@ inline FloatType FloatType::getFloat8E3M4(MLIRContext *ctx) {
   return Float8E3M4Type::get(ctx);
 }
 
+inline FloatType FloatType::getFloat8E8M0FNU(MLIRContext *ctx) {
+  return Float8E8M0FNUType::get(ctx);
+}
+
 inline FloatType FloatType::getBF16(MLIRContext *ctx) {
   return BFloat16Type::get(ctx);
 }
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index b2b41b16beec29..dca228097d782d 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -296,6 +296,29 @@ def Builtin_Float6E3M2FN : Builtin_FloatType<"Float6E3M2FN", "f6E3M2FN"> {
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// Float8E8M0FNUType
+
+def Builtin_Float8E8M0FNU : Builtin_FloatType<"Float8E8M0FNU", "f8E8M0FNU"> {
+  let summary = "8-bit floating point with 8-bit exponent, no mantissa or sign";
+  let description = [{
+    An 8-bit floating point type with no sign bit, 8 bits exponent and no
+    mantissa. This is not a standard type as defined by IEEE-754; it is intended
+    to be used for representing scaling factors, so it cannot represent zeros
+    and negative numbers. The values it can represent are powers of two in the
+    range [-127,127] and NaN.
+
+      * bit encoding: S0E8M0
+      * exponent bias: 127
+      * infinities: Not supported
+      * NaNs: Supported with all bits set to 1
+      * denormals: Not supported
+
+    Open Compute Project (OCP) microscaling formats (MX) specification:
+    https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // BFloat16Type
 
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 211385245555ad..48e4c24f838652 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -353,6 +353,8 @@ def F6E2M3FN : Type<CPred<"$_self.isFloat6E2M3FN()">, "f6E2M3FN type">,
                BuildableType<"$_builder.getFloat6E2M3FNType()">;
 def F6E3M2FN : Type<CPred<"$_self.isFloat6E3M2FN()">, "f6E3M2FN type">,
                BuildableType<"$_builder.getFloat6E3M2FNType()">;
+def F8E8M0FNU : Type<CPred<"$_self.isFloat8E8M0FNU()">, "f8E8M0FNU type">,
+                BuildableType<"$_builder.getFloat8E8M0FNUType()">;
 
 def AnyComplex : Type<CPred<"::llvm::isa<::mlir::ComplexType>($_self)">,
                       "complex-type", "::mlir::ComplexType">;
diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h
index 1b52b97f29b5f5..acd0f894abbbe6 100644
--- a/mlir/include/mlir/IR/Types.h
+++ b/mlir/include/mlir/IR/Types.h
@@ -135,6 +135,7 @@ class Type {
   bool isFloat8E4M3FNUZ() const;
   bool isFloat8E4M3B11FNUZ() const;
   bool isFloat8E3M4() const;
+  bool isFloat8E8M0FNU() const;
   bool isBF16() const;
   bool isF16() const;
   bool isTF32() const;
diff --git a/mlir/lib/AsmParser/TokenKinds.def b/mlir/lib/AsmParser/TokenKinds.def
index 2b29177b7dff0f..49da8c3dea5fa5 100644
--- a/mlir/lib/AsmParser/TokenKinds.def
+++ b/mlir/lib/AsmParser/TokenKinds.def
@@ -104,6 +104,7 @@ TOK_KEYWORD(f8E3M4)
 TOK_KEYWORD(f4E2M1FN)
 TOK_KEYWORD(f6E2M3FN)
 TOK_KEYWORD(f6E3M2FN)
+TOK_KEYWORD(f8E8M0FNU)
 TOK_KEYWORD(f128)
 TOK_KEYWORD(false)
 TOK_KEYWORD(floordiv)
diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp
index 60903a86ff8ce1..c614eb39b364be 100644
--- a/mlir/lib/AsmParser/TypeParser.cpp
+++ b/mlir/lib/AsmParser/TypeParser.cpp
@@ -49,6 +49,7 @@ OptionalParseResult Parser::parseOptionalType(Type &type) {
   case Token::kw_f8E4M3FNUZ:
   case Token::kw_f8E4M3B11FNUZ:
   case Token::kw_f8E3M4:
+  case Token::kw_f8E8M0FNU:
   case Token::kw_bf16:
   case Token::kw_f16:
   case Token::kw_tf32:
@@ -336,6 +337,9 @@ Type Parser::parseNonFunctionType() {
   case Token::kw_f8E3M4:
     consumeToken(Token::kw_f8E3M4);
     return builder.getFloat8E3M4Type();
+  case Token::kw_f8E8M0FNU:
+    consumeToken(Token::kw_f8E8M0FNU);
+    return builder.getFloat8E8M0FNUType();
   case Token::kw_bf16:
     consumeToken(Token::kw_bf16);
     return builder.getBF16Type();
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index 5a369b5d4938cb..6f192bc4bffeef 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -331,6 +331,27 @@ class PyFloat8E3M4Type : public PyConcreteType<PyFloat8E3M4Type, PyFloatType> {
   }
 };
 
+/// Floating Point Type subclass - Float8E8M0FNUType.
+class PyFloat8E8M0FNUType
+    : public PyConcreteType<PyFloat8E8M0FNUType, PyFloatType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E8M0FNU;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirFloat8E8M0FNUTypeGetTypeID;
+  static constexpr const char *pyClassName = "Float8E8M0FNUType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](DefaultingPyMlirContext context) {
+          MlirType t = mlirFloat8E8M0FNUTypeGet(context->get());
+          return PyFloat8E8M0FNUType(context->getRef(), t);
+        },
+        py::arg("context") = py::none(), "Create a float8_e8m0fnu type.");
+  }
+};
+
 /// Floating Point Type subclass - BF16Type.
 class PyBF16Type : public PyConcreteType<PyBF16Type, PyFloatType> {
 public:
@@ -953,6 +974,7 @@ void mlir::python::populateIRTypes(py::module &m) {
   PyFloat8E4M3B11FNUZType::bind(m);
   PyFloat8E5M2FNUZType::bind(m);
   PyFloat8E3M4Type::bind(m);
+  PyFloat8E8M0FNUType::bind(m);
   PyBF16Type::bind(m);
   PyF16Type::bind(m);
   PyTF32Type::bind(m);
diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
index efc1e857a39c7a..252ff54afe0c5d 100644
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -205,6 +205,18 @@ MlirType mlirFloat8E3M4TypeGet(MlirContext ctx) {
   return wrap(FloatType::getFloat8E3M4(unwrap(ctx)));
 }
 
+MlirTypeID mlirFloat8E8M0FNUTypeGetTypeID() {
+  return wrap(Float8E8M0FNUType::getTypeID());
+}
+
+bool mlirTypeIsAFloat8E8M0FNU(MlirType type) {
+  return unwrap(type).isFloat8E8M0FNU();
+}
+
+MlirType mlirFloat8E8M0FNUTypeGet(MlirContext ctx) {
+  return wrap(FloatType::getFloat8E8M0FNU(unwrap(ctx)));
+}
+
 MlirTypeID mlirBFloat16TypeGetTypeID() {
   return wrap(BFloat16Type::getTypeID());
 }
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index fd6369b5bb4ee5..5a92fa839e9847 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -250,7 +250,8 @@ Type LLVMTypeConverter::convertFloatType(FloatType type) const {
   if (type.isFloat8E5M2() || type.isFloat8E4M3() || type.isFloat8E4M3FN() ||
       type.isFloat8E5M2FNUZ() || type.isFloat8E4M3FNUZ() ||
       type.isFloat8E4M3B11FNUZ() || type.isFloat8E3M4() ||
-      type.isFloat4E2M1FN() || type.isFloat6E2M3FN() || type.isFloat6E3M2FN())
+      type.isFloat4E2M1FN() || type.isFloat6E2M3FN() || type.isFloat6E3M2FN() ||
+      type.isFloat8E8M0FNU())
     return IntegerType::get(&getContext(), type.getWidth());
   return type;
 }
diff --git a/mlir/lib/Dialect/Arith/Utils/Utils.cpp b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
index c0aa16cc0da407..67dcce454f028b 100644
--- a/mlir/lib/Dialect/Arith/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
@@ -370,6 +370,7 @@ std::optional<FloatType> parseFloatType(MLIRContext *ctx, StringRef name) {
       .Case("f8E5M2FNUZ", b.getFloat8E5M2FNUZType())
       .Case("f8E4M3FNUZ", b.getFloat8E4M3FNUZType())
       .Case("f8E3M4", b.getFloat8E3M4Type())
+      .Case("f8E8M0FNU", b.getFloat8E8M0FNUType())
       .Case("bf16", b.getBF16Type())
       .Case("f16", b.getF16Type())
       .Case("f32", b.getF32Type())
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 7f95f5ace8c00f..96fb66d53fb835 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -2588,6 +2588,7 @@ void AsmPrinter::Impl::printTypeImpl(Type type) {
       .Case<Float8E4M3FNUZType>([&](Type) { os << "f8E4M3FNUZ"; })
       .Case<Float8E4M3B11FNUZType>([&](Type) { os << "f8E4M3B11FNUZ"; })
       .Case<Float8E3M4Type>([&](Type) { os << "f8E3M4"; })
+      .Case<Float8E8M0FNUType>([&](Type) { os << "f8E8M0FNU"; })
       .Case<BFloat16Type>([&](Type) { os << "bf16"; })
       .Case<Float16Type>([&](Type) { os << "f16"; })
       .Case<FloatTF32Type>([&](Type) { os << "tf32"; })
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 7aed415343e551..a9bc3c0ef65a23 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -74,6 +74,10 @@ FloatType Builder::getFloat8E3M4Type() {
   return FloatType::getFloat8E3M4(context);
 }
 
+FloatType Builder::getFloat8E8M0FNUType() {
+  return FloatType::getFloat8E8M0FNU(context);
+}
+
 FloatType Builder::getBF16Type() { return FloatType::getBF16(context); }
 
 FloatType Builder::getF16Type() { return FloatType::getF16(context); }
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 782a32b3074680..25e9f80c9963cb 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -121,6 +121,8 @@ const llvm::fltSemantics &FloatType::getFloatSemantics() {
     return APFloat::Float8E4M3B11FNUZ();
   if (llvm::isa<Float8E3M4Type>(*this))
     return APFloat::Float8E3M4();
+  if (llvm::isa<Float8E8M0FNUType>(*this))
+    return APFloat::Float8E8M0FNU();
   if (llvm::isa<BFloat16Type>(*this))
     return APFloat::BFloat();
   if (llvm::isa<Float16Type>(*this))
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index f45de17dd24910..f05666fcde207b 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -231,6 +231,7 @@ class MLIRContextImpl {
   Float8E4M3FNUZType f8E4M3FNUZTy;
   Float8E4M3B11FNUZType f8E4M3B11FNUZTy;
   Float8E3M4Type f8E3M4Ty;
+  Float8E8M0FNUType f8E8M0FNUTy;
   BFloat16Type bf16Ty;
   Float16Type f16Ty;
   FloatTF32Type tf32Ty;
@@ -326,6 +327,7 @@ MLIRContext::MLIRContext(const DialectRegistry &registry, Threading setting)
   impl->f8E4M3FNUZTy = TypeUniquer::get<Float8E4M3FNUZType>(this);
   impl->f8E4M3B11FNUZTy = TypeUniquer::get<Float8E4M3B11FNUZType>(this);
   impl->f8E3M4Ty = TypeUniquer::get<Float8E3M4Type>(this);
+  impl->f8E8M0FNUTy = TypeUniquer::get<Float8E8M0FNUType>(this);
   impl->bf16Ty = TypeUniquer::get<BFloat16Type>(this);
   impl->f16Ty = TypeUniquer::get<Float16Type>(this);
   impl->tf32Ty = TypeUniquer::get<FloatTF32Type>(this);
@@ -1049,6 +1051,9 @@ Float8E4M3B11FNUZType Float8E4M3B11FNUZType::get(MLIRContext *context) {
 Float8E3M4Type Float8E3M4Type::get(MLIRContext *context) {
   return context->getImpl().f8E3M4Ty;
 }
+Float8E8M0FNUType Float8E8M0FNUType::get(MLIRContext *context) {
+  return context->getImpl().f8E8M0FNUTy;
+}
 BFloat16Type BFloat16Type::get(MLIRContext *context) {
   return context->getImpl().bf16Ty;
 }
diff --git a/mlir/lib/IR/Types.cpp b/mlir/lib/IR/Types.cpp
index efefbc299a91f3..e190902b2e4898 100644
--- a/mlir/lib/IR/Types.cpp
+++ b/mlir/lib/IR/Types.cpp
@@ -49,6 +49,9 @@ bool Type::isFloat8E4M3FNUZ() const {
 bool Type::isFloat8E4M3B11FNUZ() const {
   return llvm::isa<Float8E4M3B11FNUZType>(*this);
 }
+bool Type::isFloat8E8M0FNU() const {
+  return llvm::isa<Float8E8M0FNUType>(*this);
+}
 bool Type::isFloat8E3M4() const { return llvm::isa<Float8E3M4Type>(*this); }
 bool Type::isBF16() const { return llvm::isa<BFloat16Type>(*this); }
 bool Type::isF16() const { return llvm::isa<Float16Type>(*this); }
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
index 41ed84e0467254..fb7efb8cd28a5e 100644
--- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
@@ -117,6 +117,7 @@ __all__ = [
     "Float8E4M3Type",
     "Float8E5M2FNUZType",
     "Float8E5M2Type",
+    "Float8E8M0FNUType",
     "FloatAttr",
     "FloatTF32Type",
     "FloatType",
@@ -1660,6 +1661,19 @@ class Float8E5M2Type(FloatType):
     @property
     def typeid(self) -> TypeID: ...
 
+class Float8E8M0FNUType(FloatType):
+    static_typeid: ClassVar[TypeID]
+    @staticmethod
+    def get(context: Context | None = None) -> Float8E8M0FNUType:
+        """
+        Create a float8_e8m0fnu type.
+        """
+    @staticmethod
+    def isinstance(other: Type) -> bool: ...
+    def __init__(self, cast_from_type: Type) -> None: ...
+    @property
+    def typeid(self) -> TypeID: ...
+
 class FloatAttr(Attribute):
     static_typeid: ClassVar[TypeID]
     @staticmethod
diff --git a/mlir/python/mlir/extras/types.py b/mlir/python/mlir/extras/types.py
index 5b24a6d526f2f8..34eee1edb57ff5 100644
--- a/mlir/python/mlir/extras/types.py
+++ b/mlir/python/mlir/extras/types.py
@@ -20,6 +20,7 @@
     Float8E4M3FNType,
     Float8E4M3Type,
     Float8E5M2Type,
+    Float8E8M0FNUType,
     FunctionType,
     IndexType,
     IntegerType,
@@ -80,6 +81,7 @@ def ui(width):
 f4E2M1FN = lambda: Float4E2M1FNType.get()
 f6E2M3FN = lambda: Float6E2M3FNType.get()
 f6E3M2FN = lambda: Float6E3M2FNType.get()
+f8E8M0FNU = lambda: Float8E8M0FNUType.get()
 
 none = lambda: NoneType.get()
 
diff --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir
index 31a4663f72e6e9..a62de3f5004d73 100644
--- a/mlir/test/IR/attribute.mlir
+++ b/mlir/test/IR/attribute.mlir
@@ -76,6 +76,10 @@ func.func @float_attrs_pass() {
     // CHECK: float_attr = 2.000000e+00 : f8E3M4
     float_attr = 2. : f8E3M4
   } : () -> ()
+  "test.float_attrs"() {
+    // CHECK: float_attr = 2.000000e+00 : f8E8M0FNU
+    float_attr = 2. : f8E8M0FNU
+  } : () -> ()
   "test.float_attrs"() {
     // CHECK: float_attr = 2.000000e+00 : f16
     float_attr = 2. : f16
diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir
index 327c9f05f4c72c..c884f83cb4d32d 100644
--- a/mlir/test/Target/LLVMIR/llvmir.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir.mlir
@@ -72,6 +72,9 @@ llvm.mlir.global internal @f8E5M2FNUZ_global_as_i8(1.5 : f8E5M2FNUZ) : i8
 // CHECK: @f8E4M3B11FNUZ_global_as_i8 = internal global i8 92
 llvm.mlir.global internal @f8E4M3B11FNUZ_global_as_i8(1.5 : f8E4M3B11FNUZ) : i8
 
+// CHECK: @f8E8M0FNU_global_as_i8 = internal global i8 127
+llvm.mlir.global internal @f8E8M0FNU_global_as_i8(1.0 : f8E8M0FNU) : i8
+
 // CHECK: @bf16_global_as_i16 = internal global i16 16320
 llvm.mlir.global internal @bf16_global_as_i16(1.5 : bf16) : i16
 
diff --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py
index 6154a6ff9e9aed..48ddc8359ca0a1 100644
--- a/mlir/test/python/ir/builtin_types.py
+++ b/mlir/test/python/ir/builtin_types.py
@@ -133,6 +133,8 @@ def testFloatTypeSubclasses():
     # CHECK: True
     print(isinstance(Type.parse("f8E5M2FNUZ", ctx), FloatType))
     # CHECK: True
+    print(isinstance(Type.parse("f8E8M0FNU", ctx), FloatType))
+    # CHECK: True
     print(isinstance(Type.parse("f16", ctx), FloatType))
     # CHECK: True
     print(isinstance(Type.parse("bf16", ctx), FloatType))
@@ -259,6 +261,8 @@ def testFloatType():
         print("float:", Float8E4M3FNUZType.get())
         # CHECK: float: f8E4M3B11FNUZ
         print("float:", Float8E4M3B11FNUZType.get())
+        # CHECK: float: f8E8M0FNU
+        print("float:", Float8E8M0FNUType.get())
         # CHECK: float: bf16
         print("float:", BF16Type.get())
         # CHECK: float: f16
@@ -631,6 +635,7 @@ def testTypeIDs():
             (Float8E4M3FNUZType, Float8E4M3FNUZType.get()),
             (Float8E4M3B11FNUZType, Float8E4M3B11FNUZType.get()),
             (Float8E5M2FNUZType, Float8E5M2FNUZType.get()),
+            (Float8E8M0FNUType, Float8E8M0FNUType.get()),
             (BF16Type, BF16Type.get()),
             (F16Type, F16Type.get()),
             (F32Type, F32Type.get()),
@@ -659,6 +664,7 @@ def testTypeIDs():
         # CHECK: Float8E4M3FNUZType(f8E4M3FNUZ)
         # CHECK: Float8E4M3B11FNUZType(f8E4M3B11FNUZ)
         # CHECK: Float8E5M2FNUZType(f8E5M2FNUZ)
+        # CHECK: Float8E8M0FNUType(f8E8M0FNU)
         # CHECK: BF16Type(bf16)
         # CHECK: F16Type(f16)
         # CHECK: F32Type(f32)
@@ -761,6 +767,9 @@ def print_downcasted(typ):
         # CHECK: Float8E5M2FNUZType
         # CHECK: Float8E5M2FNUZType(f8E5M2FNUZ)
         print_downcasted(Float8E5M2FNUZType.get())
+        # CHECK: Float8E8M0FNUType
+        # CHECK: Float8E8M0FNUType(f8E8M0FNU)
+        print_downcasted(Float8E8M0FNUType.get())
         # CHECK: BF16Type
         # CHECK: BF16Type(bf16)
         print_downcasted(BF16Type.get())
diff --git a/mlir/utils/lldb-scripts/mlirDataFormatters.py b/mlir/utils/lldb-scripts/mlirDataFormatters.py
index 54d3d703640403..38e8278eefbbd3 100644
--- a/mlir/utils/lldb-scripts/mlirDataFormatters.py
+++ b/mlir/utils/lldb-scripts/mlirDataFormatters.py
@@ -60,6 +60,7 @@ def build_ptr_str_from_addr(addrValue: lldb.SBValue, type: lldb.SBType):
     "mlir::Float8E4M3FNUZType": '"f8E4M3FNUZ"',
     "mlir::Float8E4M3B11FNUZType": '"f8E4M3B11FNUZ"',
     "mlir::Float8E3M4Type": '"f8E3M4"',
+    "mlir::Float8E8M0FNUType": '"f8E8M0FNU"',
     "mlir::BFloat16Type": '"bf16"',
     "mlir::Float16Type": '"f16"',
     "mlir::FloatTF32Type": '"tf32"',
diff --git a/mlir/utils/tree-sitter-mlir/grammar.js b/mlir/utils/tree-sitter-mlir/grammar.js
index f7d916dfb57e2f..2dadd46c4760ca 1006...
[truncated]

@durga4github
Copy link
Contributor

Looks good to me. Follows the same template from the earlier PRs.

In the commit message:
Unlike IEEE-754 types, there are no infinity, zeros or negative values

Can we also add "no denorms" too?

@sergey-kozub
Copy link
Contributor Author

Can we also add "no denorms" too?

Added, thanks.

@stellaraccident
Copy link
Contributor

stellaraccident commented Oct 3, 2024

Fly on the wall, but at what point, upon removing all features that traditionally make something a "floating point number" (mantissa, zero, denorms, infinities) does something no longer make any sense at all being part of a floating point hierarchy. It's just a bit-vector with a special error value.

I'm not blocking this in any way or even asking seriously. Just kind of balking at the cargo cult mentality that is going into bundling these things together like this.

@River707
Copy link
Contributor

River707 commented Oct 3, 2024

Fly on the wall, but at what point, upon removing all features that traditionally make something a "floating point number" (mantissa, zero, denorms, infinities) does something no longer make any sense at all being part of a floating point hierarchy. It's just a bit-vector with a special error value.

I'm not blocking this in any way or even asking seriously. Just kind of balking at the cargo cult mentality that is going into bundling these things together like this.

I had the same gut reaction... We are now up to ~18 floating point types... That kind of points to a serious issue with the way things are scaling here, and I think we should really rethink what's being done (especially given that each one of these PRs are identical, add big chunks of code in the core library, etc).

@sergey-kozub
Copy link
Contributor Author

Fly on the wall, but at what point, upon removing all features that traditionally make something a "floating point number" (mantissa, zero, denorms, infinities) does something no longer make any sense at all being part of a floating point hierarchy. It's just a bit-vector with a special error value.

I also wondered what makes it a floating point number.
From my perspective, both int and floating point types represent values in a numeric range. The difference is that for int types, the distance between adjacent numbers is constant, and for floating point types it's variable.

For E8M0, mantissa is there but is implicit (has 1 bit which has value of one) - other FP types also have an implicit bit of data.
Unsigned ints also can't represent negative numbers (same with E8M0). Not having infinities is also common, e.g. for other FP8 types like E4M3FN and E5M2FNUZ.

The E8M0 is intended to be used as a scaling factor in block scaled formats like MXFP8, which is exactly why it doesn't have negatives, infinities or zeros - none of these makes sense for a scaling factor.

@stellaraccident
Copy link
Contributor

Fly on the wall, but at what point, upon removing all features that traditionally make something a "floating point number" (mantissa, zero, denorms, infinities) does something no longer make any sense at all being part of a floating point hierarchy. It's just a bit-vector with a special error value.
I'm not blocking this in any way or even asking seriously. Just kind of balking at the cargo cult mentality that is going into bundling these things together like this.

I had the same gut reaction... We are now up to ~18 floating point types... That kind of points to a serious issue with the way things are scaling here, and I think we should really rethink what's being done (especially given that each one of these PRs are identical, add big chunks of code in the core library, etc).

Ok, I didn't want to unilaterally hit pause, but it seems like we've got a number of people with the same analysis. Should we at least discuss this a bit more? Or proceed with this patch? I agree that the path we're on is not very sustainable.

@sergey-kozub
Copy link
Contributor Author

especially given that each one of these PRs are identical

The amount of boilerplate code is annoying, I believe this could (and should) be generalized.
One could come up with a few dozen more FP8 (or smaller dtypes).

@stellaraccident
Copy link
Contributor

especially given that each one of these PRs are identical

The amount of boilerplate code is annoying, I believe this could (and should) be generalized. One could come up with a few dozen more FP8 (or smaller dtypes).

Yeah, it was never meant to scale beyond the primary fp8 types. Needs a rethink... if not in this case, certainly soon.

@sergey-kozub
Copy link
Contributor Author

One could come up with a few dozen more FP8 (or smaller dtypes).

For example, https://arxiv.org/html/2405.13938v1 mentions more esoterics like e0m3 and e1m3.

@stellaraccident
Copy link
Contributor

One could come up with a few dozen more FP8 (or smaller dtypes).

For example, https://arxiv.org/html/2405.13938v1 mentions more esoterics like e0m3 and e1m3.

Just judging by the mood and temperament of the industry, we'll end up with just about every combination before too long. Might as well try to structure the code for that eventuality vs being the victim of it.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants