Skip to content

Commit

Permalink
[MLIR][sparse] Add soa property to sparse_tensor Python bindings (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
mtsokol authored Oct 2, 2024
1 parent 504585d commit b50ce4c
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 2 deletions.
1 change: 1 addition & 0 deletions mlir/include/mlir-c/Dialect/SparseTensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ enum MlirSparseTensorLevelFormat {
enum MlirSparseTensorLevelPropertyNondefault {
MLIR_SPARSE_PROPERTY_NON_UNIQUE = 0x0001,
MLIR_SPARSE_PROPERTY_NON_ORDERED = 0x0002,
MLIR_SPARSE_PROPERTY_SOA = 0x0004,
};

//===----------------------------------------------------------------------===//
Expand Down
3 changes: 2 additions & 1 deletion mlir/lib/Bindings/Python/DialectSparseTensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ static void populateDialectSparseTensorSubmodule(const py::module &m) {
py::enum_<MlirSparseTensorLevelPropertyNondefault>(m, "LevelProperty",
py::module_local())
.value("non_ordered", MLIR_SPARSE_PROPERTY_NON_ORDERED)
.value("non_unique", MLIR_SPARSE_PROPERTY_NON_UNIQUE);
.value("non_unique", MLIR_SPARSE_PROPERTY_NON_UNIQUE)
.value("soa", MLIR_SPARSE_PROPERTY_SOA);

mlir_attribute_subclass(m, "EncodingAttr",
mlirAttributeIsASparseTensorEncodingAttr)
Expand Down
4 changes: 3 additions & 1 deletion mlir/lib/CAPI/Dialect/SparseTensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@ static_assert(
static_assert(static_cast<int>(MLIR_SPARSE_PROPERTY_NON_ORDERED) ==
static_cast<int>(LevelPropNonDefault::Nonordered) &&
static_cast<int>(MLIR_SPARSE_PROPERTY_NON_UNIQUE) ==
static_cast<int>(LevelPropNonDefault::Nonunique),
static_cast<int>(LevelPropNonDefault::Nonunique) &&
static_cast<int>(MLIR_SPARSE_PROPERTY_SOA) ==
static_cast<int>(LevelPropNonDefault::SoA),
"MlirSparseTensorLevelProperty (C-API) and "
"LevelPropertyNondefault (C++) mismatch");

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,10 @@ def main():
prop = st.LevelProperty
levels = [
[builder(fmt.compressed, [prop.non_unique]), builder(fmt.singleton)],
[
builder(fmt.compressed, [prop.non_unique]),
builder(fmt.singleton, [prop.soa]),
],
[builder(fmt.dense), builder(fmt.compressed)],
[builder(fmt.dense), builder(fmt.loose_compressed)],
[builder(fmt.compressed), builder(fmt.compressed)],
Expand Down

0 comments on commit b50ce4c

Please sign in to comment.