Skip to content

Commit

Permalink
Improve int4 constexpr-ness, add more operators, numeric_limits.
Browse files Browse the repository at this point in the history
This is to allow better support for int4 in C++.

PiperOrigin-RevId: 561778414
  • Loading branch information
cantonios authored and The ml_dtypes Authors committed Aug 31, 2023
1 parent fa8060c commit 28cc246
Show file tree
Hide file tree
Showing 2 changed files with 592 additions and 32 deletions.
227 changes: 195 additions & 32 deletions ml_dtypes/include/int4.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
#define ML_DTYPES_INT4_H_

#include <cstdint>
#include <limits>
#include <optional>
#include <ostream>
#include <sstream>
Expand All @@ -30,11 +31,15 @@ struct i4 {
UnderlyingTy v : 4;

public:
i4() : v(0) {}
explicit i4(UnderlyingTy val) : v(val & 0x0F) {}
constexpr i4() : v(0) {}
constexpr i4(const i4& other) = default;
constexpr i4(i4&& other) = default;
constexpr i4& operator=(const i4& other) = default;
constexpr i4& operator=(i4&&) = default;

explicit constexpr i4(UnderlyingTy val) : v(val & 0x0F) {}
template <typename T>
explicit i4(T t) : i4(static_cast<UnderlyingTy>(t)) {}
i4(const i4& other) = default;
explicit constexpr i4(T t) : i4(static_cast<UnderlyingTy>(t)) {}

static constexpr i4 lowest() {
return std::is_signed<UnderlyingTy>::value ? i4(-8) : i4(0);
Expand All @@ -44,41 +49,112 @@ struct i4 {
}

template <typename T, typename = std::enable_if_t<std::is_arithmetic_v<T>>>
explicit operator T() const {
explicit constexpr operator T() const {
return static_cast<T>(v);
}
// NOLINTNEXTLINE(google-explicit-constructor)
operator std::optional<int64_t>() const { return static_cast<int64_t>(v); }

i4 operator-() const { return i4(-v); }
i4 operator+(const i4& other) const { return i4((v + other.v)); }
i4 operator-(const i4& other) const { return i4((v - other.v)); }
i4 operator*(const i4& other) const { return i4((v * other.v)); }
i4 operator/(const i4& other) const { return i4((v / other.v)); }
i4 operator%(const i4& other) const { return i4((v % other.v)); }

i4 operator>>(const int amount) const { return i4((v >> amount)); }
i4 operator<<(const int amount) const { return i4((v << amount)); }

bool operator==(const i4& other) const { return v == other.v; }
bool operator!=(const i4& other) const { return v != other.v; }
bool operator<(const i4& other) const { return v < other.v; }
bool operator>(const i4& other) const { return v > other.v; }
bool operator<=(const i4& other) const { return v <= other.v; }
bool operator>=(const i4& other) const { return v >= other.v; }

bool operator==(const int64_t other) const { return v == other; }
bool operator!=(const int64_t other) const { return v != other; }
bool operator<(const int64_t other) const { return v < other; }
bool operator>(const int64_t other) const { return v > other; }
bool operator<=(const int64_t other) const { return v <= other; }
bool operator>=(const int64_t other) const { return v >= other; }

i4& operator++() {
constexpr operator std::optional<int64_t>() const {
return static_cast<int64_t>(v);
}

constexpr i4 operator-() const { return i4(-v); }
constexpr i4 operator+(const i4& other) const { return i4((v + other.v)); }
constexpr i4 operator-(const i4& other) const { return i4((v - other.v)); }
constexpr i4 operator*(const i4& other) const { return i4((v * other.v)); }
constexpr i4 operator/(const i4& other) const { return i4((v / other.v)); }
constexpr i4 operator%(const i4& other) const { return i4((v % other.v)); }

constexpr i4 operator&(const i4& other) const { return i4((v & other.v)); }
constexpr i4 operator|(const i4& other) const { return i4((v | other.v)); }
constexpr i4 operator^(const i4& other) const { return i4((v ^ other.v)); }
constexpr i4 operator~() const { return i4(~v); }
constexpr i4 operator>>(int amount) const { return i4((v >> amount)); }
constexpr i4 operator<<(int amount) const { return i4((v << amount)); }

constexpr bool operator==(const i4& other) const { return v == other.v; }
constexpr bool operator!=(const i4& other) const { return v != other.v; }
constexpr bool operator<(const i4& other) const { return v < other.v; }
constexpr bool operator>(const i4& other) const { return v > other.v; }
constexpr bool operator<=(const i4& other) const { return v <= other.v; }
constexpr bool operator>=(const i4& other) const { return v >= other.v; }

constexpr bool operator==(int64_t other) const { return v == other; }
constexpr bool operator!=(int64_t other) const { return v != other; }
constexpr bool operator<(int64_t other) const { return v < other; }
constexpr bool operator>(int64_t other) const { return v > other; }
constexpr bool operator<=(int64_t other) const { return v <= other; }
constexpr bool operator>=(int64_t other) const { return v >= other; }

friend constexpr bool operator==(int64_t a, const i4& b) { return a == b.v; }
friend constexpr bool operator!=(int64_t a, const i4& b) { return a != b.v; }
friend constexpr bool operator<(int64_t a, const i4& b) { return a < b.v; }
friend constexpr bool operator>(int64_t a, const i4& b) { return a > b.v; }
friend constexpr bool operator<=(int64_t a, const i4& b) { return a <= b.v; }
friend constexpr bool operator>=(int64_t a, const i4& b) { return a >= b.v; }

constexpr i4& operator++() {
v = (v + 1) & 0x0F;
return *this;
}

constexpr i4 operator++(int) {
i4 orig = *this;
this->operator++();
return orig;
}

constexpr i4& operator--() {
v = (v - 1) & 0x0F;
return *this;
}

constexpr i4 operator--(int) {
i4 orig = *this;
this->operator--();
return orig;
}

constexpr i4& operator+=(const i4& other) {
*this = *this + other;
return *this;
}
constexpr i4& operator-=(const i4& other) {
*this = *this - other;
return *this;
}
constexpr i4& operator*=(const i4& other) {
*this = *this * other;
return *this;
}
constexpr i4& operator/=(const i4& other) {
*this = *this / other;
return *this;
}
constexpr i4& operator%=(const i4& other) {
*this = *this % other;
return *this;
}
constexpr i4& operator&=(const i4& other) {
*this = *this & other;
return *this;
}
constexpr i4& operator|=(const i4& other) {
*this = *this | other;
return *this;
}
constexpr i4& operator^=(const i4& other) {
*this = *this ^ other;
return *this;
}
constexpr i4& operator>>=(int amount) {
*this = *this >> amount;
return *this;
}
constexpr i4& operator<<=(int amount) {
*this = *this << amount;
return *this;
}

friend ::std::ostream& operator<<(::std::ostream& os, const i4& num) {
os << static_cast<int16_t>(num.v);
return os;
Expand All @@ -94,6 +170,93 @@ struct i4 {
using int4 = i4<int8_t>;
using uint4 = i4<uint8_t>;

namespace internal {

struct int4_numeric_limits_base {
static inline constexpr const bool is_specialized = true;
static inline constexpr const bool is_integer = true;
static inline constexpr const bool is_exact = true;
static inline constexpr const bool has_infinity = false;
static inline constexpr const bool has_quiet_NaN = false;
static inline constexpr const bool has_signaling_NaN = false;
static inline constexpr const std::float_denorm_style has_denorm =
std::denorm_absent;
static inline constexpr const bool has_denorm_loss = false;
static inline constexpr const std::float_round_style round_style =
std::round_toward_zero;
static inline constexpr const bool is_iec559 = false;
static inline constexpr const bool is_bounded = true;
static inline constexpr const int max_digits10 = 0; // Not used for integers.
static inline constexpr const int radix = 2;
static inline constexpr const int min_exponent = 0;
static inline constexpr const int min_exponent10 = 0;
static inline constexpr const int max_exponent = 0;
static inline constexpr const int max_exponent10 = 0;
static inline constexpr const bool traps = true;
static inline constexpr const bool tinyness_before = false;

static constexpr ml_dtypes::int4 epsilon() noexcept {
return ml_dtypes::int4(0);
}
static constexpr ml_dtypes::int4 round_error() noexcept {
return ml_dtypes::int4(0);
}
static constexpr ml_dtypes::int4 infinity() noexcept {
return ml_dtypes::int4(0);
}
static constexpr ml_dtypes::int4 quiet_NaN() noexcept {
return ml_dtypes::int4(0);
}
static constexpr ml_dtypes::int4 signaling_NaN() noexcept {
return ml_dtypes::int4(0);
}
static constexpr ml_dtypes::int4 denorm_min() noexcept {
return ml_dtypes::int4(0);
}
};

} // namespace internal

} // namespace ml_dtypes

namespace std {

template <>
struct numeric_limits<ml_dtypes::int4>
: public ml_dtypes::internal::int4_numeric_limits_base {
static inline constexpr const bool is_signed = true;
static inline constexpr const bool is_modulo = false;
static inline constexpr const int digits = 3;
static inline constexpr const int digits10 = 0; // floor(3 * log10(2))
static constexpr ml_dtypes::int4 min() noexcept {
return ml_dtypes::int4::lowest();
}
static constexpr ml_dtypes::int4 lowest() noexcept {
return ml_dtypes::int4::lowest();
}
static constexpr ml_dtypes::int4 max() noexcept {
return ml_dtypes::int4::highest();
}
};

template <>
struct numeric_limits<ml_dtypes::uint4>
: public ml_dtypes::internal::int4_numeric_limits_base {
static inline constexpr const bool is_signed = false;
static inline constexpr const bool is_modulo = true;
static inline constexpr const int digits = 4;
static inline constexpr const int digits10 = 1; // floor(4 * log10(2))
static constexpr ml_dtypes::uint4 min() noexcept {
return ml_dtypes::uint4::lowest();
}
static constexpr ml_dtypes::uint4 lowest() noexcept {
return ml_dtypes::uint4::lowest();
}
static constexpr ml_dtypes::uint4 max() noexcept {
return ml_dtypes::uint4::highest();
}
};

} // namespace std

#endif // ML_DTYPES_INT4_H_
Loading

0 comments on commit 28cc246

Please sign in to comment.