Skip to content

Commit

Permalink
Add resample
Browse files Browse the repository at this point in the history
  • Loading branch information
sryza committed Feb 27, 2016
1 parent be38e9f commit f5878b4
Show file tree
Hide file tree
Showing 4 changed files with 345 additions and 1 deletion.
122 changes: 122 additions & 0 deletions src/main/scala/com/cloudera/sparkts/Resample.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
/**
* Copyright (c) 2016, Cloudera, Inc. All Rights Reserved.
*
* Cloudera, Inc. licenses this file to you under the Apache License,
* Version 2.0 (the "License"). You may not use this file except in
* compliance with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
* CONDITIONS OF ANY KIND, either express or implied. See the License for
* the specific language governing permissions and limitations under the
* License.
*/

package com.cloudera.sparkts

import org.apache.spark.mllib.linalg.{Vectors, Vector}

private[sparkts] object Resample {
  /**
   * Converts a time series to a new date-time index, with flexible semantics for aggregating
   * observations when downsampling.
   *
   * Based on the closedRight and stampRight parameters, resampling partitions time into non-
   * overlapping intervals, each corresponding to a date-time in the target index. Each resulting
   * value in the output series is determined by applying an aggregation function over all the
   * values that fall within the corresponding window in the input series. If no values in the
   * input series fall within the window, a NaN is used and the aggregation function is not called.
   *
   * Compare with the equivalent functionality in Pandas:
   * http://pandas.pydata.org/pandas-docs/stable/generated/pandas.DataFrame.resample.html
   *
   * @param ts The values of the input series.
   * @param sourceIndex The date-time index of the input series.
   * @param targetIndex The date-time index of the resulting series.
   * @param aggr Function for aggregating multiple points that fall within a window. It is called
   *             as aggr(values, start, end), where values holds all the input values and the
   *             window covers the positions from start (inclusive) to end (exclusive).
   * @param closedRight If true, the windows are open on the left and closed on the right. Otherwise
   *                    the windows are closed on the left and open on the right.
   * @param stampRight If true, each date-time in the resulting series marks the end of a window.
   *                   The first window has no lower bound, and all observations after the end of
   *                   the last window are ignored. Otherwise, each date-time in the resulting
   *                   series marks the start of a window. All observations before the start of
   *                   the first window are ignored, and the last window has no upper bound.
   * @return The values of the resampled series, one per date-time in targetIndex.
   */
  def resample(
      ts: Vector,
      sourceIndex: DateTimeIndex,
      targetIndex: DateTimeIndex,
      aggr: (Array[Double], Int, Int) => Double,
      closedRight: Boolean,
      stampRight: Boolean): Vector = {
    val tsarr = ts.toArray
    val result = new Array[Double](targetIndex.size)
    val sourceIter = sourceIndex.nanosIterator().buffered
    val targetIter = targetIndex.nanosIterator().buffered

    // Window owned by stamp "c", with "p" the previous stamp and "n" the next stamp:
    //
    //   !closedRight &&  stampRight:  [p, c)
    //   !closedRight && !stampRight:  [c, n)
    //    closedRight &&  stampRight:  (p, c]
    //    closedRight && !stampRight:  (c, n]
    //
    // Only the upper bound is tested here. The lower bound is implied by the fact that the
    // source position j never moves backwards between windows.

    // End predicate should return true iff dt falls after the window labeled by cur DT (at i)
    val endPredicate: (Long, Long, Long) => Boolean = if (!closedRight && stampRight) {
      (cur, next, dt) => dt >= cur
    } else if (!closedRight && !stampRight) {
      (cur, next, dt) => dt >= next
    } else if (closedRight && stampRight) {
      (cur, next, dt) => dt > cur
    } else {
      (cur, next, dt) => dt > next
    }

    var i = 0 // index in result array
    var j = 0 // index in source array

    // Skip observations that don't belong with any stamp. This only applies when stamps mark
    // the start of a window. Both iterators are checked for emptiness so that an empty target
    // index, an empty source, or a source lying entirely before the first stamp does not fail
    // on head.
    if (!stampRight && targetIter.hasNext) {
      val firstStamp = targetIter.head
      while (sourceIter.hasNext &&
          (sourceIter.head < firstStamp || (closedRight && sourceIter.head == firstStamp))) {
        sourceIter.next()
        j += 1
      }
    }

    // Invariant is that nothing lower than j should be needed to populate result(i)
    while (i < result.length) {
      val cur = targetIter.next()
      // The last stamp has no successor, so its "next" is pushed out to the largest Long.
      val next = if (targetIter.hasNext) targetIter.head else Long.MaxValue
      val sourceStartIndex = j

      // Consume every observation that still belongs to the current window.
      while (sourceIter.hasNext && !endPredicate(cur, next, sourceIter.head)) {
        sourceIter.next()
        j += 1
      }
      val sourceEndIndex = j

      if (sourceStartIndex == sourceEndIndex) {
        // Empty window: no observation to aggregate.
        result(i) = Double.NaN
      } else {
        result(i) = aggr(tsarr, sourceStartIndex, sourceEndIndex)
      }

      i += 1
    }
    Vectors.dense(result)
  }
}
34 changes: 34 additions & 0 deletions src/main/scala/com/cloudera/sparkts/TimeSeries.scala
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,40 @@ class TimeSeries[K](val index: DateTimeIndex, val data: DenseMatrix,
* Gets the first univariate series and its key.
*/
def head(): (K, Vector) = {
  // Only the first element of the key/series iterator is consumed.
  val keyedSeries = univariateKeyAndSeriesIterator()
  keyedSeries.next()
}

/**
 * Resamples every univariate series in this TimeSeries onto a new date-time index.
 *
 * Time is partitioned into non-overlapping windows, one per date-time in the target index, as
 * determined by the closedRight and stampRight parameters. Each output value is the result of
 * applying the aggregation function to the input values that fall within the corresponding
 * window. A window containing no input values yields NaN.
 *
 * Compare with the equivalent functionality in Pandas:
 * http://pandas.pydata.org/pandas-docs/stable/generated/pandas.DataFrame.resample.html
 *
 * @param targetIndex The date-time index of the resulting series.
 * @param aggr Function for aggregating multiple points that fall within a window.
 * @param closedRight If true, the windows are open on the left and closed on the right. Otherwise
 *                    the windows are closed on the left and open on the right.
 * @param stampRight If true, each date-time in the resulting series marks the end of a window,
 *                   so all observations after the end of the last window are ignored.
 *                   Otherwise, each date-time in the resulting series marks the start of a
 *                   window, so all observations before the start of the first window are
 *                   ignored.
 * @return A TimeSeries with each series resampled to targetIndex.
 */
def resample(
    targetIndex: DateTimeIndex,
    aggr: (Array[Double], Int, Int) => Double,
    closedRight: Boolean,
    stampRight: Boolean): TimeSeries[K] = {
  mapSeries(targetIndex, series =>
    Resample.resample(
      ts = series,
      sourceIndex = index,
      targetIndex = targetIndex,
      aggr = aggr,
      closedRight = closedRight,
      stampRight = stampRight))
}
}

object TimeSeries {
Expand Down
37 changes: 36 additions & 1 deletion src/main/scala/com/cloudera/sparkts/TimeSeriesRDD.scala
Original file line number Diff line number Diff line change
Expand Up @@ -472,14 +472,49 @@ class TimeSeriesRDD[K](val index: DateTimeIndex, parent: RDD[(K, Vector)])

/**
 * Returns a TimeSeriesRDD rebased on top of a new index. Any timestamps that exist in the new
 * index but not in the existing index will be filled in with NaNs. [[resample]] offers similar
 * functionality with richer semantics for aggregating values within windows.
 *
 * @param newIndex The DateTimeIndex for the new RDD
 */
def withIndex(newIndex: DateTimeIndex): TimeSeriesRDD[K] = {
  // NaN is the value handed to the rebaser alongside the old and new indices.
  mapSeries(TimeSeriesUtils.rebaser(index, newIndex, Double.NaN), newIndex)
}

/**
 * Resamples every series in this TimeSeriesRDD onto a new date-time index.
 *
 * Time is partitioned into non-overlapping windows, one per date-time in the target index, as
 * determined by the closedRight and stampRight parameters. Each output value is the result of
 * applying the aggregation function to the input values that fall within the corresponding
 * window. A window containing no input values yields NaN.
 *
 * Compare with the equivalent functionality in Pandas:
 * http://pandas.pydata.org/pandas-docs/stable/generated/pandas.DataFrame.resample.html
 *
 * @param targetIndex The date-time index of the resulting series.
 * @param aggr Function for aggregating multiple points that fall within a window.
 * @param closedRight If true, the windows are open on the left and closed on the right. Otherwise
 *                    the windows are closed on the left and open on the right.
 * @param stampRight If true, each date-time in the resulting series marks the end of a window,
 *                   so all observations after the end of the last window are ignored.
 *                   Otherwise, each date-time in the resulting series marks the start of a
 *                   window, so all observations before the start of the first window are
 *                   ignored.
 * @return A TimeSeriesRDD with each series resampled to targetIndex.
 */
def resample(
    targetIndex: DateTimeIndex,
    aggr: (Array[Double], Int, Int) => Double,
    closedRight: Boolean,
    stampRight: Boolean): TimeSeriesRDD[K] = {
  mapSeries(
    series =>
      Resample.resample(
        ts = series,
        sourceIndex = index,
        targetIndex = targetIndex,
        aggr = aggr,
        closedRight = closedRight,
        stampRight = stampRight),
    targetIndex)
}
}

object TimeSeriesRDD {
Expand Down
153 changes: 153 additions & 0 deletions src/test/scala/com/cloudera/sparkts/ResampleSuite.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
/**
* Copyright (c) 2015, Cloudera, Inc. All Rights Reserved.
*
* Cloudera, Inc. licenses this file to you under the Apache License,
* Version 2.0 (the "License"). You may not use this file except in
* compliance with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
* CONDITIONS OF ANY KIND, either express or implied. See the License for
* the specific language governing permissions and limitations under the
* License.
*/

package com.cloudera.sparkts

import java.time.{ZonedDateTime, ZoneId}

import org.apache.spark.mllib.linalg._

import org.scalatest._

import scala.collection.mutable.ArrayBuffer

/**
 * Tests for Resample.resample, driven by fixtures that are "drawn" as strings: the column at
 * which a value appears in the string determines its timestamp.
 */
class ResampleSuite extends FunSuite with ShouldMatchers {
/**
 * Resamples the series drawn in `series` onto the index drawn in `result`, aggregating each
 * window by summation, and checks that the output equals the values drawn in `result`.
 *
 * @param series Drawing of the input series (see seriesStrToIndexAndVec for the format).
 * @param closedRight Passed through to Resample.resample.
 * @param stampRight Passed through to Resample.resample.
 * @param result Drawing of the expected output; supplies both the target index and the
 *               expected values.
 */
def verify(
series: String,
closedRight: Boolean,
stampRight: Boolean,
result: String): Unit = {
// Parses a drawing into an irregular index and a dense vector of values:
//  - spaces are skipped, but still advance the column position;
//  - 'N' stands for NaN, any other character is parsed as a single-digit number;
//  - a character at column i is stamped at baseDT plus i days;
//  - two characters in adjacent columns are merged into one two-digit value, stamped at the
//    column of the first character.
// NOTE(review): because columns are significant, the exact spacing inside the fixture
// literals below determines the timestamps - confirm the literals have not had their
// whitespace collapsed.
def seriesStrToIndexAndVec(str: String): (DateTimeIndex, Vector) = {
val Z = ZoneId.of("Z")
val baseDT = ZonedDateTime.of(2015, 4, 8, 0, 0, 0, 0, Z)
// Pairs of (value, column) for every non-space character.
val seriesPointsRaw = str.toCharArray.zipWithIndex.filter(_._1 != ' ').map { case (c, i) =>
(if (c == 'N') Double.NaN else c.toString.toDouble, i)
}.toBuffer
// Account for numbers with multiple digits:
// a point whose successor sits in the very next column is combined with that successor
// (first * 10 + second), and the successor is consumed. Only pairs are merged.
val seriesPoints = new ArrayBuffer[(Double, Int)]()
val iter = seriesPointsRaw.iterator.buffered
while (iter.hasNext) {
val point = iter.next()
if (iter.hasNext && iter.head._2 == point._2 + 1) {
seriesPoints += ((point._1 * 10 + iter.next()._1, point._2))
} else {
seriesPoints += point
}
}
val index = DateTimeIndex.irregular(seriesPoints.map(x => baseDT.plusDays(x._2)).toArray)
val vec = Vectors.dense(seriesPoints.map(_._1).toArray)
(index, vec)
}

val (sourceIndex, sourceVec) = seriesStrToIndexAndVec(series)
val (resultIndex, resultVec) = seriesStrToIndexAndVec(result)

// Aggregates a window by summing the values at positions start (inclusive) to end (exclusive).
def aggr(arr: Array[Double], start: Int, end: Int) = {
arr.slice(start, end).sum
}
val resampled = Resample.resample(sourceVec, sourceIndex, resultIndex, aggr, closedRight,
stampRight)

resampled should be (resultVec)
}

// Each case below is (input drawing, closedRight, stampRight, expected drawing).
test("downsampling") {
verify(
"0 1 2 3 4 5 6 7 8",
false, false,
"3 12 21"
)

verify(
"0 1 2 3 4 5 6 7 8",
true, false,
"6 15 15"
)

verify(
"1 2 3 4 5 6 7 8 9",
false, true,
"N 6 15"
)

verify(
"0 1 2 3 4 5 6 7 8",
true, true,
"0 6 15"
)

verify(
"0 1 2 3 4 5 6 7 8",
false, false,
" 12 21 N"
)

verify(
"1 2 3 4 5 6 7 8 9",
true, false,
" 18 17 N"
)

verify(
"0 1 2 3 4 5 6 7 8",
false, true,
" 3 12 21"
)

verify(
"1 2 3 4 5 6 7 8 9",
true, true,
" 10 18 17"
)

verify(
"0 1 2 3 4 5 6 7 8",
false, false,
"6 15 15"
)

verify(
"0 1 2 3 4 5 6 7 8",
true, false,
"6 15 15"
)

verify(
"1 2 3 4 5 6 7 8 9",
false, true,
"N 10 18"
)

verify(
"0 1 2 3 4 5 6 7 8",
true, true,
"0 6 15"
)
}

// Target index is denser than the source; windows with no observation are expected to be NaN.
test("upsampling") {
verify(
"1 2 3 4 5",
false, false,
"1 N 2 N 3 N 4 N 5"
)

verify(
"1 2 3 4 5",
false, false,
"1 2 N 3 N 4 N 5"
)
}
}

0 comments on commit f5878b4

Please sign in to comment.