Skip to content

Commit

Permalink
Added 'MoreMath' tests
Browse files Browse the repository at this point in the history
  • Loading branch information
superbobry committed Oct 15, 2015
1 parent 5921b67 commit f79db55
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 41 deletions.
48 changes: 11 additions & 37 deletions src/main/kotlin/org/jetbrains/bio/viktor/MoreMath.kt
Original file line number Diff line number Diff line change
Expand Up @@ -3,30 +3,18 @@ package org.jetbrains.bio.viktor
import org.apache.commons.math3.util.FastMath

/**
* Useful mathematical routines absent in [java.util.Math]
* and [org.apache.commons.math3.util.FastMath].
* Evaluates log(exp(a) + exp(b)) using the following trick
*
* When adding new functionality please consider reading
* http://blog.juma.me.uk/2011/02/23/performance-of-fastmath-from-commons-math.
* log(exp(a) + log(exp(b)) = a + log(1 + exp(b - a))
*
* @author Alexey Dievsky
* @author Sergei Lebedev
* @since 0.1.0
* assuming a >= b.
*/
object MoreMath {
/**
* Evaluates log(exp(a) + exp(b)) using the following trick
*
* log(exp(a) + log(exp(b)) = a + log(1 + exp(b - a))
*
* assuming a >= b.
*/
@JvmStatic fun logAddExp(a: Double, b: Double): Double {
return when {
a.isInfinite() && a < 0 -> b
b.isInfinite() && b < 0 -> a
else -> Math.max(a, b) + StrictMath.log1p(FastMath.exp(-Math.abs(a - b)))
}
fun Double.logAddExp(b: Double): Double {
val a = this
return when {
a.isInfinite() && a < 0 -> b
b.isInfinite() && b < 0 -> a
else -> Math.max(a, b) + StrictMath.log1p(FastMath.exp(-Math.abs(a - b)))
}
}

Expand All @@ -42,7 +30,7 @@ object MoreMath {
* @author Alexey Dievsky
* @since 0.1.0
*/
class KahanSum private constructor(private var accumulator: Double) {
class KahanSum @JvmOverloads constructor(private var accumulator: Double = 0.0) {
private var compensator = 0.0

/**
Expand All @@ -68,19 +56,5 @@ class KahanSum private constructor(private var accumulator: Double) {
/**
* Returns the sum accumulated so far.
*/
fun result(): Double = accumulator + compensator

companion object {
/**
* Creates and returns a zero-initiated accumulator which can be
* fed doubles and polled for the accumulated sum.
*/
@JvmStatic fun create(): KahanSum = create(0.0)

/**
* Creates and returns an accumulator which can be fed
* doubles and polled for the accumulated sum.
*/
@JvmStatic fun create(initial: Double): KahanSum = KahanSum(initial)
}
fun result() = accumulator + compensator
}
8 changes: 4 additions & 4 deletions src/main/kotlin/org/jetbrains/bio/viktor/StridedVector.kt
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ open class StridedVector(protected val data: DoubleArray,
}

open fun sum(): Double {
val acc = KahanSum.create()
val acc = KahanSum()
for (pos in 0..size - 1) {
acc += unsafeGet(pos)
}
Expand All @@ -228,7 +228,7 @@ open class StridedVector(protected val data: DoubleArray,
}

open fun cumSum() {
val acc = KahanSum.create()
val acc = KahanSum()
for (pos in 0..size - 1) {
acc += unsafeGet(pos)
unsafeSet(pos, acc.result())
Expand Down Expand Up @@ -307,7 +307,7 @@ open class StridedVector(protected val data: DoubleArray,

open fun logSumExp(): Double {
val offset = max()
val sum = KahanSum.create()
val sum = KahanSum()
for (pos in 0..size - 1) {
sum += FastMath.exp(unsafeGet(pos) - offset)
}
Expand All @@ -325,7 +325,7 @@ open class StridedVector(protected val data: DoubleArray,
checkSize(other)
checkSize(dst)
for (pos in 0..size - 1) {
dst.unsafeSet(pos, MoreMath.logAddExp(unsafeGet(pos), other.unsafeGet(pos)))
dst.unsafeSet(pos, unsafeGet(pos) logAddExp other.unsafeGet(pos))
}
}

Expand Down
40 changes: 40 additions & 0 deletions src/test/kotlin/org/jetbrains/bio/viktor/MoreMathTests.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package org.jetbrains.bio.viktor

import org.junit.Test
import java.util.*
import kotlin.test.assertEquals
import kotlin.test.assertTrue

class MoreMathTest {
@Test fun testLogAddExpEdgeCases() {
val r = Random()
val logx = -Math.abs(r.nextDouble())
assertEquals(logx, Double.NEGATIVE_INFINITY logAddExp logx)
assertEquals(logx, logx logAddExp Double.NEGATIVE_INFINITY)
assertEquals(Double.NEGATIVE_INFINITY,
Double.NEGATIVE_INFINITY logAddExp Double.NEGATIVE_INFINITY)
}
}

class KahanSumTest {
@Test fun testPrecision() {
val bigNumber = 10000000
for (d in 9..15) {
// note that in each case 1/d is not precisely representable as a double,
// which is bound to lead to accumulating rounding errors.
val oneDth = 1.0 / d
val preciseSum = KahanSum()
var impreciseSum = 0.0
for (i in 0..bigNumber * d - 1) {
preciseSum += oneDth
impreciseSum += oneDth
}

val imprecision = Math.abs(impreciseSum - bigNumber)
val precision = Math.abs(preciseSum.result() - bigNumber)
assertTrue(imprecision >= precision,
"Kahan's algorithm yielded worse precision than ordinary summation: " +
"$precision is greater than $imprecision")
}
}
}

0 comments on commit f79db55

Please sign in to comment.