Skip to content

Commit

Permalink
Gradient.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Jan 25, 2024
1 parent 9a83836 commit 77a730e
Showing 1 changed file with 11 additions and 7 deletions.
18 changes: 11 additions & 7 deletions include/xgboost/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -142,9 +142,9 @@ namespace detail {
template <typename T>
class GradientPairInternal {
/*! \brief gradient statistics */
T grad_;
T grad_{0};
/*! \brief second order gradient statistics */
T hess_;
T hess_{0};

XGBOOST_DEVICE void SetGrad(T g) { grad_ = g; }
XGBOOST_DEVICE void SetHess(T h) { hess_ = h; }
Expand All @@ -161,15 +161,18 @@ class GradientPairInternal {
a += b;
}

XGBOOST_DEVICE GradientPairInternal() : grad_(0), hess_(0) {}
GradientPairInternal() = default;

XGBOOST_DEVICE GradientPairInternal(T grad, T hess) {
SetGrad(grad);
SetHess(hess);
}

// Copy constructor if of same value type, marked as default to be trivially_copyable
GradientPairInternal(const GradientPairInternal<T> &g) = default;
GradientPairInternal(GradientPairInternal const &g) = default;
GradientPairInternal(GradientPairInternal &&g) = default;
GradientPairInternal &operator=(GradientPairInternal const &that) = default;
GradientPairInternal &operator=(GradientPairInternal &&that) = default;

// Copy constructor if different value type - use getters and setters to
// perform conversion
Expand Down Expand Up @@ -274,10 +277,11 @@ class GradientPairInt64 {
GradientPairInt64() = default;

// Copy constructor if of same value type, marked as default to be trivially_copyable
GradientPairInt64(const GradientPairInt64 &g) = default;
GradientPairInt64(GradientPairInt64 const &g) = default;
GradientPairInt64 &operator=(GradientPairInt64 const &g) = default;

XGBOOST_DEVICE T GetQuantisedGrad() const { return grad_; }
XGBOOST_DEVICE T GetQuantisedHess() const { return hess_; }
[[nodiscard]] XGBOOST_DEVICE T GetQuantisedGrad() const { return grad_; }
[[nodiscard]] XGBOOST_DEVICE T GetQuantisedHess() const { return hess_; }

XGBOOST_DEVICE GradientPairInt64 &operator+=(const GradientPairInt64 &rhs) {
grad_ += rhs.grad_;
Expand Down

0 comments on commit 77a730e

Please sign in to comment.