Fix bug in SWA with missing feature data (feathr-ai#1171)
* Fix bug in SWA with missing feature data

* remove unwanted code

* Address feedback and version bump

---------

Co-authored-by: Rakesh Kashyap Hanasoge Padmanabha <[email protected]>
rakeshkashyap123 and Rakesh Kashyap Hanasoge Padmanabha authored May 17, 2023
1 parent c06699a commit 04e89fd
Showing 3 changed files with 161 additions and 8 deletions.
FeathrClient.scala
@@ -231,7 +231,7 @@ class FeathrClient private[offline] (sparkSession: SparkSession, featureGroups:

     val (joinedDF, header) = doJoinObsAndFeatures(joinConfig, jobContext, obsData)
     (joinedDF, header, Map(SuppressedExceptionHandlerUtils.MISSING_DATA_EXCEPTION
-      -> SuppressedExceptionHandlerUtils.missingFeatures.mkString))
+      -> SuppressedExceptionHandlerUtils.missingFeatures.mkString(", ")))
   }

/**
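Note on the one-line change above: Scala's parameterless mkString concatenates elements with no separator, so the set of missing feature names would render as one fused string in the suppressed-exception message; mkString(", ") joins them with a comma and space. A standalone illustration (not Feathr code):

// Standalone illustration of the separator difference reported in the message.
val missingFeatures = Seq("f1", "f1Sum", "f2")
println(missingFeatures.mkString)        // prints: f1f1Sumf2
println(missingFeatures.mkString(", "))  // prints: f1, f1Sum, f2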
SlidingWindowAggregationJoiner.scala
@@ -177,7 +177,6 @@ private[offline] class SlidingWindowAggregationJoiner(
       } else if (originalSourceDf.isEmpty && shouldAddDefaultColForMissingData) { // If add default col for missing data
         // flag is set and there is a data related error, an empty dataframe will be returned.
         res.map(emptyFeatures.add)
-        val exceptionMsg = emptyFeatures.mkString
         log.warn(s"Missing data for features ${emptyFeatures}. Default values will be populated for this column.")
         SuppressedExceptionHandlerUtils.missingFeatures ++= emptyFeatures
         anchors.map(anchor => (anchor, originalSourceDf))
@@ -299,7 +298,7 @@
         substituteDefaults(withFDSFeatureDF, defaults.keys.filter(joinedFeatures.contains).toSeq, defaults, userSpecifiedTypesConfig, ss)

       allInferredFeatureTypes ++= inferredTypes
-      contextDF = standardizeFeatureColumnNames(origContextObsColumns, withFeatureContextDF, joinedFeatures, keyTags.map(keyTagList))
+      contextDF = standardizeFeatureColumnNames(ss, origContextObsColumns, withFeatureContextDF, joinedFeatures, keyTags.map(keyTagList))
       if (shouldCheckPoint(ss)) {
         // checkpoint complicated dataframe for each stage to avoid Spark failure
         contextDF = contextDF.checkpoint(true)
@@ -325,13 +324,19 @@
    * @return
    */
   def standardizeFeatureColumnNames(
+      ss: SparkSession,
       origContextObsColumns: Seq[String],
       withSWAFeatureDF: DataFrame,
       featureNames: Seq[String],
       keyTags: Seq[String]): DataFrame = {
     val inputColumnSize = origContextObsColumns.size
     val outputColumnNum = withSWAFeatureDF.columns.size
-    if (outputColumnNum != inputColumnSize + featureNames.size) {
+    val shouldAddDefaultColForMissingData = FeathrUtils.getFeathrJobParam(ss.sparkContext.getConf,
+      FeathrUtils.ADD_DEFAULT_COL_FOR_MISSING_DATA).toBoolean
+
+    // Do not perform this check if shouldAddDefaultColForMissingData is true as we add the null values to all SWA features at once,
+    // and do not care for the SWA groupings.
+    if (!shouldAddDefaultColForMissingData && (outputColumnNum != inputColumnSize + featureNames.size)) {
       throw new FeathrIllegalStateException(
         s"Number of columns (${outputColumnNum}) in the dataframe returned by " +
         s"sliding window aggregation does not equal to number of columns in the observation data (${inputColumnSize}) " +
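The guarded check above enforces a simple invariant: after the sliding-window join, the output DataFrame should contain exactly the original observation columns plus one column per requested SWA feature. When ADD_DEFAULT_COL_FOR_MISSING_DATA is on, missing features are back-filled with default/null columns in one pass, so the per-grouping count no longer holds and the check is skipped. A minimal sketch of the same invariant, with hypothetical names and decoupled from Spark and Feathr internals:

// Sketch only; all names are hypothetical and for illustration.
def validateSwaOutputColumns(
    observationColumns: Seq[String],
    featureNames: Seq[String],
    outputColumns: Seq[String],
    addDefaultColForMissingData: Boolean): Unit = {
  val expected = observationColumns.size + featureNames.size
  // When default columns are back-filled for missing feature data, the strict
  // equality no longer holds per SWA grouping, so the check is skipped.
  if (!addDefaultColForMissingData && outputColumns.size != expected) {
    throw new IllegalStateException(
      s"SWA output has ${outputColumns.size} columns, expected $expected " +
        s"(${observationColumns.size} observation columns + ${featureNames.size} features).")
  }
}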
SlidingWindowAggIntegTest.scala
@@ -5,7 +5,6 @@ import com.linkedin.feathr.offline.transformation.MultiLevelAggregationTransform
 import com.linkedin.feathr.offline.util.FeathrUtils
 import com.linkedin.feathr.offline.util.FeathrUtils.{FILTER_NULLS, SKIP_MISSING_FEATURE, setFeathrJobParam}
 import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
-import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.{LongType, StructField, StructType}
 import org.apache.spark.sql.{DataFrame, Row}
 import org.testng.Assert._
@@ -19,10 +18,7 @@ import scala.collection.mutable


 import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.{StringType, TimestampType}

-import scala.concurrent.duration._
class SlidingWindowAggIntegTest extends FeathrIntegTest {

def getDf(): DataFrame = {
@@ -259,6 +255,157 @@ class SlidingWindowAggIntegTest extends FeathrIntegTest {
assertEquals(row1f1f1, TestUtils.build1dSparseTensorFDSRow(Array("f1t1"), Array(12.0f)))
}

/**
* test SWA with lateralview parameters and ADD_DEFAULT_COL_FOR_MISSING_DATA flag set
*/
@Test
def testLocalAnchorSWATestWithDataMissingFlagSet: Unit = {
setFeathrJobParam(FeathrUtils.ADD_DEFAULT_COL_FOR_MISSING_DATA, "true")
val df = runLocalFeatureJoinForTest(
joinConfigAsString =
"""
| settings: {
| observationDataTimeSettings: {
| absoluteTimeRange: {
| startTime: "2018-05-01"
| endTime: "2018-05-03"
| timeFormat: "yyyy-MM-dd"
| }
| }
| joinTimeSettings: {
| timestampColumn: {
| def: timestamp
| format: "yyyy-MM-dd"
| }
| }
|}
|
|features: [
| {
| key: [x],
| featureList: ["f1", "f1Sum", "f2", "f1f1"]
| },
| {
| key: [x, y]
| featureList: ["f3", "f4"]
| }
|]
""".stripMargin,
featureDefAsString =
"""
|sources: {
| ptSource: {
| type: "PASSTHROUGH"
| }
| swaSource: {
| location: { path: "missingData/localSWAAnchorTestFeatureData/daily" }
| timePartitionPattern: "yyyy/MM/dd"
| timeWindowParameters: {
| timestampColumn: "timestamp"
| timestampColumnFormat: "yyyy-MM-dd"
| }
| }
|}
|
|anchors: {
| ptAnchor: {
| source: "ptSource"
| key: "x"
| features: {
| f1f1: {
| def: "([$.term:$.value] in passthroughFeatures if $.name == 'f1f1')"
| }
| }
| }
| swaAnchor: {
| source: "swaSource"
| key: "substring(x, 0)"
| lateralViewParameters: {
| lateralViewDef: explode(features)
| lateralViewItemAlias: feature
| }
| features: {
| f1: {
| def: "feature.col.value"
| filter: "feature.col.name = 'f1'"
| aggregation: SUM
| groupBy: "feature.col.term"
| window: 3d
| }
| }
| }
|
| swaAnchor2: {
| source: "swaSource"
| key: "x"
| lateralViewParameters: {
| lateralViewDef: explode(features)
| lateralViewItemAlias: feature
| }
| features: {
| f1Sum: {
| def: "feature.col.value"
| filter: "feature.col.name = 'f1'"
| aggregation: SUM
| groupBy: "feature.col.term"
| window: 3d
| }
| }
| }
| swaAnchorWithKeyExtractor: {
| source: "swaSource"
| keyExtractor: "com.linkedin.feathr.offline.anchored.keyExtractor.SimpleSampleKeyExtractor"
| features: {
| f3: {
| def: "aggregationWindow"
| aggregation: SUM
| window: 3d
| }
| }
| }
| swaAnchorWithKeyExtractor2: {
| source: "swaSource"
| keyExtractor: "com.linkedin.feathr.offline.anchored.keyExtractor.SimpleSampleKeyExtractor"
| features: {
| f4: {
| def: "aggregationWindow"
| aggregation: SUM
| window: 3d
| }
| }
| }
| swaAnchorWithKeyExtractor3: {
| source: "swaSource"
| keyExtractor: "com.linkedin.feathr.offline.anchored.keyExtractor.SimpleSampleKeyExtractor2"
| lateralViewParameters: {
| lateralViewDef: explode(features)
| lateralViewItemAlias: feature
| }
| features: {
| f2: {
| def: "feature.col.value"
| filter: "feature.col.name = 'f2'"
| aggregation: SUM
| groupBy: "feature.col.term"
| window: 3d
| }
| }
| }
|}
""".stripMargin,
"slidingWindowAgg/localAnchorTestObsData.avro.json").data
df.show()

// validate output in name term value format
val featureList = df.collect().sortBy(row => if (row.get(0) != null) row.getAs[String]("x") else "null")
val row0 = featureList(0)
val row0f1 = row0.getAs[Row]("f1")
assertEquals(row0f1, null)
val row0f2 = row0.getAs[Row]("f2")
assertEquals(row0f2, null)
setFeathrJobParam(FeathrUtils.ADD_DEFAULT_COL_FOR_MISSING_DATA, "false")
}

/**
* test SWA with lateralview parameters
*/
@@ -869,6 +1016,7 @@ class SlidingWindowAggIntegTest extends FeathrIntegTest {
 | timestampColumnFormat: "yyyy-MM-dd"
 | }
 | }
+|
 |}
 |
 |anchors: {
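The new test above turns ADD_DEFAULT_COL_FOR_MISSING_DATA on, runs the SWA join against the missingData source, asserts that features backed by the missing data (f1, f2) come back as null rather than failing the join, and finally resets the flag. One hedged observation on the teardown: the flag is process-global test state, so an assertion failure before the final setFeathrJobParam call would leave it enabled for later tests. A small sketch of a try/finally wrapper, assuming the same setFeathrJobParam helper and FeathrUtils constant the test already imports (not part of this commit):

// Hypothetical hardening of the test's setup/teardown.
def withDefaultColForMissingData[T](body: => T): T = {
  setFeathrJobParam(FeathrUtils.ADD_DEFAULT_COL_FOR_MISSING_DATA, "true")
  try {
    body
  } finally {
    // Always reset the process-global flag, even if an assertion in `body` fails.
    setFeathrJobParam(FeathrUtils.ADD_DEFAULT_COL_FOR_MISSING_DATA, "false")
  }
}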
