Skip to content

Commit

Permalink
Matrix refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
altavir committed Apr 7, 2019
1 parent 4065fda commit 14f05eb
Show file tree
Hide file tree
Showing 18 changed files with 191 additions and 205 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package scientifik.kmath.linear

import koma.matrix.ejml.EJMLMatrixFactory
import scientifik.kmath.structures.Matrix
import kotlin.random.Random
import kotlin.system.measureTimeMillis

Expand Down Expand Up @@ -30,7 +31,7 @@ fun main() {

//commons-math

val cmContext = CMMatrixContext
val cmContext = CMLUPSolver

val commonsTime = measureTimeMillis {
cmContext.run {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package scientifik.kmath.linear

import koma.matrix.ejml.EJMLMatrixFactory
import scientifik.kmath.structures.Matrix
import kotlin.random.Random
import kotlin.system.measureTimeMillis

Expand Down
1 change: 1 addition & 0 deletions build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ subprojects {
targets.all {
sourceSets.all {
languageSettings.progressiveMode = true
languageSettings.enableLanguageFeature("InlineClasses")
}
}

Expand Down
39 changes: 21 additions & 18 deletions kmath-commons/src/main/kotlin/scientifik/kmath/linear/CMMatrix.kt
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@ package scientifik.kmath.linear
import org.apache.commons.math3.linear.*
import org.apache.commons.math3.linear.RealMatrix
import org.apache.commons.math3.linear.RealVector
import scientifik.kmath.structures.Matrix

class CMMatrix(val origin: RealMatrix, features: Set<MatrixFeature>? = null) : Matrix<Double> {
class CMMatrix(val origin: RealMatrix, features: Set<MatrixFeature>? = null) : FeaturedMatrix<Double> {
override val rowNum: Int get() = origin.rowDimension
override val colNum: Int get() = origin.columnDimension

override val features: Set<MatrixFeature> = features ?: sequence<MatrixFeature> {
if(origin is DiagonalMatrix) yield(DiagonalFeature)
if (origin is DiagonalMatrix) yield(DiagonalFeature)
}.toSet()

override fun suggestFeature(vararg features: MatrixFeature) =
Expand Down Expand Up @@ -45,28 +46,13 @@ fun Point<Double>.toCM(): CMVector = if (this is CMVector) {

fun RealVector.toPoint() = CMVector(this)

object CMMatrixContext : MatrixContext<Double>, LinearSolver<Double> {
object CMMatrixContext : MatrixContext<Double> {

override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> Double): CMMatrix {
val array = Array(rows) { i -> DoubleArray(columns) { j -> initializer(i, j) } }
return CMMatrix(Array2DRowRealMatrix(array))
}

override fun solve(a: Matrix<Double>, b: Matrix<Double>): CMMatrix {
val decomposition = LUDecomposition(a.toCM().origin)
return decomposition.solver.solve(b.toCM().origin).toMatrix()
}

override fun solve(a: Matrix<Double>, b: Point<Double>): CMVector {
val decomposition = LUDecomposition(a.toCM().origin)
return decomposition.solver.solve(b.toCM().origin).toPoint()
}

override fun inverse(a: Matrix<Double>): CMMatrix {
val decomposition = LUDecomposition(a.toCM().origin)
return decomposition.solver.inverse.toMatrix()
}

override fun Matrix<Double>.dot(other: Matrix<Double>) =
CMMatrix(this.toCM().origin.multiply(other.toCM().origin))

Expand All @@ -87,6 +73,23 @@ object CMMatrixContext : MatrixContext<Double>, LinearSolver<Double> {
CMMatrix(this.toCM().origin.scalarMultiply(value.toDouble()))
}

object CMLUPSolver: LinearSolver<Double>{
override fun solve(a: Matrix<Double>, b: Matrix<Double>): CMMatrix {
val decomposition = LUDecomposition(a.toCM().origin)
return decomposition.solver.solve(b.toCM().origin).toMatrix()
}

override fun solve(a: Matrix<Double>, b: Point<Double>): CMVector {
val decomposition = LUDecomposition(a.toCM().origin)
return decomposition.solver.solve(b.toCM().origin).toPoint()
}

override fun inverse(a: Matrix<Double>): CMMatrix {
val decomposition = LUDecomposition(a.toCM().origin)
return decomposition.solver.inverse.toMatrix()
}
}

operator fun CMMatrix.plus(other: CMMatrix): CMMatrix = CMMatrix(this.origin.add(other.origin))
operator fun CMMatrix.minus(other: CMMatrix): CMMatrix = CMMatrix(this.origin.subtract(other.origin))

Expand Down
28 changes: 0 additions & 28 deletions kmath-core/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -13,34 +13,6 @@ kotlin {
val commonMain by getting {
dependencies {
api(project(":kmath-memory"))
api(kotlin("stdlib"))
}
}
val commonTest by getting {
dependencies {
implementation(kotlin("test-common"))
implementation(kotlin("test-annotations-common"))
}
}
val jvmMain by getting {
dependencies {
api(kotlin("stdlib-jdk8"))
}
}
val jvmTest by getting {
dependencies {
implementation(kotlin("test"))
implementation(kotlin("test-junit"))
}
}
val jsMain by getting {
dependencies {
api(kotlin("stdlib-js"))
}
}
val jsTest by getting {
dependencies {
implementation(kotlin("test-js"))
}
}
// mingwMain {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package scientifik.kmath.linear

import scientifik.kmath.operations.Ring
import scientifik.kmath.structures.*
import kotlin.jvm.JvmSynthetic

/**
* Basic implementation of Matrix space based on [NDStructure]
Expand All @@ -25,7 +24,7 @@ class BufferMatrix<T : Any>(
override val colNum: Int,
val buffer: Buffer<out T>,
override val features: Set<MatrixFeature> = emptySet()
) : Matrix<T> {
) : FeaturedMatrix<T> {

init {
if (buffer.size != rowNum * colNum) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ import scientifik.kmath.structures.Buffer.Companion.boxing
import kotlin.math.sqrt


/**
* Basic operations on matrices. Operates on [Matrix]
*/
interface MatrixContext<T : Any> {
/**
* Produce a matrix with this context and given dimensions
Expand Down Expand Up @@ -101,18 +104,18 @@ interface GenericMatrixContext<T : Any, R : Ring<T>> : MatrixContext<T> {
operator fun Matrix<T>.times(number: Number): Matrix<T> =
produce(rowNum, colNum) { i, j -> elementContext.run { get(i, j) * number } }

operator fun Number.times(matrix: Matrix<T>): Matrix<T> = matrix * this
operator fun Number.times(matrix: FeaturedMatrix<T>): Matrix<T> = matrix * this

override fun Matrix<T>.times(value: T): Matrix<T> =
produce(rowNum, colNum) { i, j -> elementContext.run { get(i, j) * value } }
}

/**
* Specialized 2-d structure
* A 2d structure plus optional matrix-specific features
*/
interface Matrix<T : Any> : Structure2D<T> {
val rowNum: Int
val colNum: Int
interface FeaturedMatrix<T : Any> : Matrix<T> {

override val shape: IntArray get() = intArrayOf(rowNum, colNum)

val features: Set<MatrixFeature>

Expand All @@ -122,70 +125,54 @@ interface Matrix<T : Any> : Structure2D<T> {
* The implementation does not guarantee to check that matrix actually have the feature, so one should be careful to
* add only those features that are valid.
*/
fun suggestFeature(vararg features: MatrixFeature): Matrix<T>

override fun get(index: IntArray): T = get(index[0], index[1])

override val shape: IntArray get() = intArrayOf(rowNum, colNum)

val rows: Point<Point<T>>
get() = VirtualBuffer(rowNum) { i ->
VirtualBuffer(colNum) { j -> get(i, j) }
}
fun suggestFeature(vararg features: MatrixFeature): FeaturedMatrix<T>

val columns: Point<Point<T>>
get() = VirtualBuffer(colNum) { j ->
VirtualBuffer(rowNum) { i -> get(i, j) }
}
companion object {

override fun elements(): Sequence<Pair<IntArray, T>> = sequence {
for (i in (0 until rowNum)) {
for (j in (0 until colNum)) {
yield(intArrayOf(i, j) to get(i, j))
}
}
}
}

companion object {
fun real(rows: Int, columns: Int, initializer: (Int, Int) -> Double) =
MatrixContext.real.produce(rows, columns, initializer)

/**
* Build a square matrix from given elements.
*/
fun <T : Any> square(vararg elements: T): Matrix<T> {
val size: Int = sqrt(elements.size.toDouble()).toInt()
if (size * size != elements.size) error("The number of elements ${elements.size} is not a full square")
val buffer = elements.asBuffer()
return BufferMatrix(size, size, buffer)
}
fun Structure2D.Companion.real(rows: Int, columns: Int, initializer: (Int, Int) -> Double) =
MatrixContext.real.produce(rows, columns, initializer)

fun <T : Any> build(rows: Int, columns: Int): MatrixBuilder<T> = MatrixBuilder(rows, columns)
}
/**
* Build a square matrix from given elements.
*/
fun <T : Any> Structure2D.Companion.square(vararg elements: T): FeaturedMatrix<T> {
val size: Int = sqrt(elements.size.toDouble()).toInt()
if (size * size != elements.size) error("The number of elements ${elements.size} is not a full square")
val buffer = elements.asBuffer()
return BufferMatrix(size, size, buffer)
}

fun <T : Any> Structure2D.Companion.build(rows: Int, columns: Int): MatrixBuilder<T> = MatrixBuilder(rows, columns)

class MatrixBuilder<T : Any>(val rows: Int, val columns: Int) {
operator fun invoke(vararg elements: T): Matrix<T> {
operator fun invoke(vararg elements: T): FeaturedMatrix<T> {
if (rows * columns != elements.size) error("The number of elements ${elements.size} is not equal $rows * $columns")
val buffer = elements.asBuffer()
return BufferMatrix(rows, columns, buffer)
}
}

val Matrix<*>.features get() = (this as? FeaturedMatrix)?.features?: emptySet()

/**
* Check if matrix has the given feature class
*/
inline fun <reified T : Any> Matrix<*>.hasFeature(): Boolean = features.find { it is T } != null
inline fun <reified T : Any> Matrix<*>.hasFeature(): Boolean =
features.find { it is T } != null

/**
* Get the first feature matching given class. Does not guarantee that matrix has only one feature matching the criteria
*/
inline fun <reified T : Any> Matrix<*>.getFeature(): T? = features.filterIsInstance<T>().firstOrNull()
inline fun <reified T : Any> Matrix<*>.getFeature(): T? =
features.filterIsInstance<T>().firstOrNull()

/**
* Diagonal matrix of ones. The matrix is virtual no actual matrix is created
*/
fun <T : Any, R : Ring<T>> GenericMatrixContext<T, R>.one(rows: Int, columns: Int): Matrix<T> =
fun <T : Any, R : Ring<T>> GenericMatrixContext<T, R>.one(rows: Int, columns: Int): FeaturedMatrix<T> =
VirtualMatrix<T>(rows, columns) { i, j ->
if (i == j) elementContext.one else elementContext.zero
}
Expand All @@ -194,7 +181,7 @@ fun <T : Any, R : Ring<T>> GenericMatrixContext<T, R>.one(rows: Int, columns: In
/**
* A virtual matrix of zeroes
*/
fun <T : Any, R : Ring<T>> GenericMatrixContext<T, R>.zero(rows: Int, columns: Int): Matrix<T> =
fun <T : Any, R : Ring<T>> GenericMatrixContext<T, R>.zero(rows: Int, columns: Int): FeaturedMatrix<T> =
VirtualMatrix<T>(rows, columns) { _, _ -> elementContext.zero }

class TransposedFeature<T : Any>(val original: Matrix<T>) : MatrixFeature
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@ package scientifik.kmath.linear

import scientifik.kmath.operations.Field
import scientifik.kmath.operations.Ring
import scientifik.kmath.structures.MutableBuffer
import scientifik.kmath.structures.*
import scientifik.kmath.structures.MutableBuffer.Companion.boxing
import scientifik.kmath.structures.MutableBufferFactory
import scientifik.kmath.structures.NDStructure
import scientifik.kmath.structures.get

/**
* Common implementation of [LUPDecompositionFeature]
*/
class LUPDecomposition<T : Comparable<T>>(
private val elementContext: Ring<T>,
internal val lu: NDStructure<T>,
Expand All @@ -20,7 +20,7 @@ class LUPDecomposition<T : Comparable<T>>(
*
* L is a lower-triangular matrix with [Ring.one] in diagonal
*/
override val l: Matrix<T> = VirtualMatrix(lu.shape[0], lu.shape[1], setOf(LFeature)) { i, j ->
override val l: FeaturedMatrix<T> = VirtualMatrix(lu.shape[0], lu.shape[1], setOf(LFeature)) { i, j ->
when {
j < i -> lu[i, j]
j == i -> elementContext.one
Expand All @@ -34,7 +34,7 @@ class LUPDecomposition<T : Comparable<T>>(
*
* U is an upper-triangular matrix including the diagonal
*/
override val u: Matrix<T> = VirtualMatrix(lu.shape[0], lu.shape[1], setOf(UFeature)) { i, j ->
override val u: FeaturedMatrix<T> = VirtualMatrix(lu.shape[0], lu.shape[1], setOf(UFeature)) { i, j ->
if (j >= i) lu[i, j] else elementContext.zero
}

Expand All @@ -45,7 +45,7 @@ class LUPDecomposition<T : Comparable<T>>(
* P is a sparse matrix with exactly one element set to [Ring.one] in
* each row and each column, all other elements being set to [Ring.zero].
*/
override val p: Matrix<T> = VirtualMatrix(lu.shape[0], lu.shape[1]) { i, j ->
override val p: FeaturedMatrix<T> = VirtualMatrix(lu.shape[0], lu.shape[1]) { i, j ->
if (j == pivot[i]) elementContext.one else elementContext.zero
}

Expand All @@ -62,7 +62,9 @@ class LUPDecomposition<T : Comparable<T>>(

}


/**
* Common implementation of LUP [LinearSolver] based on commons-math code
*/
class LUSolver<T : Comparable<T>, F : Field<T>>(
val context: GenericMatrixContext<T, F>,
val bufferFactory: MutableBufferFactory<T> = ::boxing,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,17 @@ package scientifik.kmath.linear
import scientifik.kmath.operations.Field
import scientifik.kmath.operations.Norm
import scientifik.kmath.operations.RealField
import scientifik.kmath.structures.Matrix
import scientifik.kmath.structures.VirtualBuffer
import scientifik.kmath.structures.asSequence


/**
* A group of methods to resolve equation A dot X = B, where A and B are matrices or vectors
*/
interface LinearSolver<T : Any> {
interface LinearSolver<T : Any> {
fun solve(a: Matrix<T>, b: Matrix<T>): Matrix<T>
fun solve(a: Matrix<T>, b: Point<T>): Point<T> = solve(a, b.toMatrix()).toVector()
fun solve(a: Matrix<T>, b: Point<T>): Point<T> = solve(a, b.toMatrix()).asPoint()
fun inverse(a: Matrix<T>): Matrix<T>
}

Expand All @@ -32,14 +33,14 @@ object VectorL2Norm : Norm<Point<out Number>, Double> {
typealias RealVector = Vector<Double, RealField>
typealias RealMatrix = Matrix<Double>



/**
* Convert matrix to vector if it is possible
*/
fun <T: Any> Matrix<T>.toVector(): Point<T> =
fun <T : Any> Matrix<T>.asPoint(): Point<T> =
if (this.colNum == 1) {
VirtualBuffer(rowNum){ get(it, 0) }
} else error("Can't convert matrix with more than one column to vector")
VirtualBuffer(rowNum) { get(it, 0) }
} else {
error("Can't convert matrix with more than one column to vector")
}

fun <T: Any> Point<T>.toMatrix(): Matrix<T> = VirtualMatrix(size, 1) { i, _ -> get(i) }
fun <T : Any> Point<T>.toMatrix() = VirtualMatrix(size, 1) { i, _ -> get(i) }
Loading

0 comments on commit 14f05eb

Please sign in to comment.