Skip to content

Commit 6a012d4

Browse files
committed
improve SparseMatrix.Entry
1 parent 2838c70 commit 6a012d4

File tree

3 files changed

+33
-30
lines changed

3 files changed

+33
-30
lines changed

math/src/main/java/smile/math/matrix/SparseMatrix.java

+29-26
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,7 @@ public void foreachNonzero(int startColumn, int endColumn, MatrixElementConsumer
260260
* Provides a stream over all of the non-zero elements of a sparse matrix.
261261
*/
262262
public Stream<Entry> nonzeros() {
263-
return StreamSupport.stream(new SparseMatrixSpliterator(this, 0, ncols), false);
263+
return StreamSupport.stream(new SparseMatrixSpliterator(0, ncols), false);
264264
}
265265

266266
/**
@@ -270,42 +270,40 @@ public Stream<Entry> nonzeros() {
270270
* @param endColumn The first column after the ones that we should scan.
271271
*/
272272
public Stream<Entry> nonzeros(int startColumn, int endColumn) {
273-
return StreamSupport.stream(new SparseMatrixSpliterator(this, startColumn, endColumn), false);
273+
return StreamSupport.stream(new SparseMatrixSpliterator(startColumn, endColumn), false);
274274
}
275275

276276
/**
277277
* Provides a spliterator for access to a sparse matrix in column major order.
278278
* <p>
279279
* This is exposed to facilitate lower level access to the stream API for a matrix.
280280
*/
281-
private static class SparseMatrixSpliterator implements Spliterator<Entry> {
282-
private final SparseMatrix m;
281+
private class SparseMatrixSpliterator implements Spliterator<Entry> {
283282
private int col; // current column, advanced on split or traversal
284283
private int index; // current element within column
285284
private final int fence; // one past the last column to process
286285

287-
SparseMatrixSpliterator(SparseMatrix matrix, int col, int fence) {
288-
this.m = matrix;
286+
SparseMatrixSpliterator(int col, int fence) {
289287
this.col = col;
290-
this.index = m.colIndex[col];
288+
this.index = colIndex[col];
291289
this.fence = fence;
292290
}
293291

294292
public void forEachRemaining(Consumer<? super Entry> action) {
295293
for (; col < fence; col++) {
296-
for (; index < m.colIndex[col + 1]; index++) {
297-
action.accept(new Entry(m.rowIndex[index], col, m.x[index], index, m.x));
294+
for (; index < colIndex[col + 1]; index++) {
295+
action.accept(new Entry(rowIndex[index], col, x[index], index));
298296
}
299297
}
300298
}
301299

302300
public boolean tryAdvance(Consumer<? super Entry> action) {
303301
if (col < fence) {
304-
while (col < fence && index >= m.colIndex[col + 1]) {
302+
while (col < fence && index >= colIndex[col + 1]) {
305303
col++;
306304
}
307305
if (col < fence) {
308-
action.accept(new Entry(m.rowIndex[index], col, m.x[index], index, m.x));
306+
action.accept(new Entry(rowIndex[index], col, x[index], index));
309307
index++;
310308
return true;
311309
} else {
@@ -322,15 +320,15 @@ public Spliterator<Entry> trySplit() {
322320
int mid = ((lo + fence) >>> 1) & ~1; // force midpoint to be even
323321
if (lo < mid) { // split out left half
324322
col = mid; // reset this Spliterator's origin
325-
return new SparseMatrixSpliterator(m, lo, mid);
323+
return new SparseMatrixSpliterator(lo, mid);
326324
} else {
327325
// too small to split
328326
return null;
329327
}
330328
}
331329

332330
public long estimateSize() {
333-
return (long) (m.colIndex[fence] - m.colIndex[col]);
331+
return (long) (colIndex[fence] - colIndex[col]);
334332
}
335333

336334
public int characteristics() {
@@ -339,32 +337,37 @@ public int characteristics() {
339337
}
340338

341339
/**
342-
* Encapsulates important information about an entry in a matrix for use
343-
* in streaming, including a string that leads back to the original cell
344-
* so that in-place updates are possible.
340+
* Encapsulates an entry in a matrix for use in streaming. As typical stream object,
341+
* this object is immutable. But we can update the corresponding value in the matrix
342+
* through <code>update</code> method. This provides an efficient way to update the
343+
* non-zero entries of a sparse matrix.
345344
*/
346-
public static class Entry {
345+
public class Entry {
347346
// these fields are exposed for direct access to simplify in-lining by the JVM
348347
public final int row;
349348
public final int col;
350-
public double x;
349+
public final double value;
351350

352351
// these are hidden due to internal dependency
353352
private final int index;
354-
private double[] values;
355353

356-
357-
public Entry(int row, int col, double x, int index, double[] values) {
354+
/**
355+
* Private constructor. Only the enclosure matrix can creates
356+
* the instances of entry.
357+
*/
358+
private Entry(int row, int col, double value, int index) {
358359
this.row = row;
359360
this.col = col;
360-
this.x = x;
361+
this.value = value;
361362
this.index = index;
362-
this.values = values;
363363
}
364364

365-
public void set(double value) {
366-
this.x = value;
367-
values[index] = value;
365+
/**
366+
* Update the value of entry in the matrix. Note that the field <code>value</code>
367+
* is final and thus not updated.
368+
*/
369+
public void update(double value) {
370+
x[index] = value;
368371
}
369372
}
370373

math/src/test/java/smile/math/matrix/SparseMatrixTest.java

+3-3
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ public void nonZeroIterator() {
225225
entry -> {
226226
int i = entry.row;
227227
int j = entry.col;
228-
double x = entry.x;
228+
double x = entry.value;
229229

230230
assertEquals(d[i][j], x, 0);
231231
assertEquals(d[i][j], m.get(i, j), 0);
@@ -243,7 +243,7 @@ public void nonZeroIterator() {
243243
assertTrue(col >= 100);
244244
assertTrue(col < 400);
245245

246-
assertEquals(d[entry.row][col], entry.x, 0);
246+
assertEquals(d[entry.row][col], entry.value, 0);
247247
assertEquals(d[entry.row][col], m.get(entry.row, col), 0);
248248
k.incrementAndGet();
249249
}
@@ -283,7 +283,7 @@ public void iterationSpeed() {
283283
double[] sum2 = new double[2000];
284284
for (int rep = 0; rep < 1000; rep++) {
285285
m.nonzeros()
286-
.forEach(entry -> sum2[entry.col] += entry.x);
286+
.forEach(entry -> sum2[entry.col] += entry.value);
287287
}
288288
t1 = System.nanoTime() / 1e9;
289289
sum = 0;

plot/src/main/java/smile/plot/SparseMatrixPlot.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ public SparseMatrixPlot(SparseMatrix sparse, Color[] palette) {
114114

115115
// In case of outliers, we use 1% and 99% quantiles as lower and
116116
// upper limits instead of min and max.
117-
double[] values = sparse.nonzeros().mapToDouble(entry -> entry.x).filter(x -> !Double.isNaN(x)).toArray();
117+
double[] values = sparse.nonzeros().mapToDouble(entry -> entry.value).filter(x -> !Double.isNaN(x)).toArray();
118118

119119
if (values.length == 0) {
120120
throw new IllegalArgumentException("Sparse matrix has no non-zero values");

0 commit comments

Comments
 (0)