Commit

azure-cosmos-spark: Adding metrics (bytesWritten, recordsWritten, totalRequestCharge) for the Spark connector write path (Azure#37510)

* azure-cosmos-spark: Adding metrics (bytesWritten, recordsWritten, totalRequestCharge) for the Spark connector write path

* Update CosmosRowConverterBase.scala

* Codestyle fixes

* Renaming TestMetricsPublisher

* Fixing build error

* Update SparkInternalsBridge.scala

* Updating changelogs

* Update CosmosConflictsTest.java

* Update BulkExecutor.java

* Update BulkExecutor.java

* Update CosmosBulkAsyncTest.java

* Reacting to code review comments

* Update CHANGELOG.md
FabianMeiswinkel authored Nov 21, 2023
1 parent b09aa28 commit 1856ee3
Showing 52 changed files with 2,209 additions and 637 deletions.
1 change: 1 addition & 0 deletions sdk/cosmos/azure-cosmos-spark_3-1_2-12/CHANGELOG.md
@@ -3,6 +3,7 @@
### 4.24.0-beta.1 (Unreleased)

#### Features Added
* Added `bytesWritten` and `recordsWritten` metrics in the sink of the Azure Cosmos DB connector. - See [PR 37510](https://github.com/Azure/azure-sdk-for-java/pull/37510)

#### Breaking Changes

@@ -11,7 +11,7 @@ import org.apache.spark.sql.connector.catalog.{NamespaceChange, SupportsNamespac

// scalastyle:off underscore.import

class CosmosCatalog
private[spark] class CosmosCatalog
extends CosmosCatalogBase
with SupportsNamespaces {

@@ -0,0 +1,83 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

package com.azure.cosmos.spark

import com.azure.cosmos.CosmosDiagnosticsContext
import com.azure.cosmos.implementation.ImplementationBridgeHelpers
import org.apache.spark.SparkInternalsBridge
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.write.WriterCommitMessage
import org.apache.spark.sql.types.StructType

import java.util.concurrent.atomic.AtomicLong

private class CosmosWriter(
    userConfig: Map[String, String],
    cosmosClientStateHandles: Broadcast[CosmosClientMetadataCachesSnapshots],
    diagnosticsConfig: DiagnosticsConfig,
    inputSchema: StructType,
    partitionId: Int,
    taskId: Long,
    epochId: Option[Long],
    sparkEnvironmentInfo: String)
  extends CosmosWriterBase(
    userConfig,
    cosmosClientStateHandles,
    diagnosticsConfig,
    inputSchema,
    partitionId,
    taskId,
    epochId,
    sparkEnvironmentInfo
  ) with OutputMetricsPublisherTrait {

  private val recordsWritten = new AtomicLong(0)
  private val bytesWritten = new AtomicLong(0)
  private val count: AtomicLong = new AtomicLong(0)

  override def getOutputMetricsPublisher(): OutputMetricsPublisherTrait = this

  override def trackWriteOperation(recordCount: Long, diagnostics: Option[CosmosDiagnosticsContext]): Unit = {
    if (recordCount > 0) {
      recordsWritten.addAndGet(recordCount)
    }

    diagnostics match {
      case Some(ctx) =>
        bytesWritten.addAndGet(
          if (ImplementationBridgeHelpers
            .CosmosDiagnosticsContextHelper
            .getCosmosDiagnosticsContextAccessor
            .getOperationType(ctx)
            .isReadOnlyOperation) {

            ctx.getMaxRequestPayloadSizeInBytes + ctx.getMaxResponsePayloadSizeInBytes
          } else {
            ctx.getMaxRequestPayloadSizeInBytes
          }
        )
      case None =>
    }
  }

  override def write(internalRow: InternalRow): Unit = {
    super.write(internalRow)

    if (count.incrementAndGet() % SparkInternalsBridge.NUM_ROWS_PER_UPDATE == 0) {
      SparkInternalsBridge.updateInternalTaskMetrics(recordsWritten.get, bytesWritten.get)
    }
  }

  override def commit(): WriterCommitMessage = {
    val commitMessage = super.commit()

    // In Spark 3.1 there is no concept of custom metrics yet, so bytesWritten and recordsWritten
    // need to be updated manually in the data source (via the internal TaskMetrics API). This is a
    // fairly ugly workaround, but given that Spark 3.1 is close to end-of-life, the risk of this
    // behavior changing within Spark 3.1 is low enough to make it acceptable.
    SparkInternalsBridge.updateInternalTaskMetrics(recordsWritten.get, bytesWritten.get)

    commitMessage
  }
}
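On the Spark 3.1 path the numbers published through SparkInternalsBridge land in the task's built-in output metrics, so they can be observed with a plain SparkListener. A minimal sketch of that (not part of this change; the listener class and the write options are illustrative):

import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}

// Minimal sketch: sums the built-in output metrics that the Spark 3.1 write path above
// populates via SparkInternalsBridge.updateInternalTaskMetrics.
class CosmosWriteMetricsListener extends SparkListener {
  @volatile private var records = 0L
  @volatile private var bytes = 0L

  override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
    Option(taskEnd.taskMetrics).foreach { tm =>
      records += tm.outputMetrics.recordsWritten
      bytes += tm.outputMetrics.bytesWritten
    }
  }

  def summary: String = s"recordsWritten=$records, bytesWritten=$bytes"
}

// Usage (illustrative): register the listener before running the Cosmos DB write.
// val listener = new CosmosWriteMetricsListener
// spark.sparkContext.addSparkListener(listener)
// df.write.format("cosmos.oltp").options(cosmosWriteConfig).mode("Append").save()
// println(listener.summary)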
@@ -26,7 +26,7 @@ private case class ItemsScanBuilder(session: SparkSession,
with SupportsPushDownRequiredColumns {

@transient private lazy val log = LoggerHelper.getLogger(diagnosticsConfig, this.getClass)
log.logInfo(s"Instantiated ${this.getClass.getSimpleName}")
log.logTrace(s"Instantiated ${this.getClass.getSimpleName}")

val configMap = config.asScala.toMap
val readConfig = CosmosReadConfig.parseCosmosReadConfig(configMap)
@@ -22,7 +22,7 @@ private class ItemsWriterBuilder
)
extends WriteBuilder {
@transient private lazy val log = LoggerHelper.getLogger(diagnosticsConfig, this.getClass)
log.logInfo(s"Instantiated ${this.getClass.getSimpleName}")
log.logTrace(s"Instantiated ${this.getClass.getSimpleName}")

override def buildForBatch(): BatchWrite =
new ItemsBatchWriter(
@@ -0,0 +1,18 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
package org.apache.spark

// Lives in the org.apache.spark package so it can call the private[spark] task metrics
// APIs (TaskMetrics and the OutputMetrics setters).
object SparkInternalsBridge {

  val NUM_ROWS_PER_UPDATE = 100

  def updateInternalTaskMetrics(recordsWrittenSnapshot: Long, bytesWrittenSnapshot: Long): Unit = {
    Option(TaskContext.get()) match {
      case Some(taskContext) =>
        val outputMetrics = taskContext.taskMetrics.outputMetrics
        outputMetrics.setRecordsWritten(recordsWrittenSnapshot)
        outputMetrics.setBytesWritten(bytesWrittenSnapshot)
      case None =>
    }
  }
}
1 change: 1 addition & 0 deletions sdk/cosmos/azure-cosmos-spark_3-2_2-12/CHANGELOG.md
@@ -3,6 +3,7 @@
### 4.24.0-beta.1 (Unreleased)

#### Features Added
* Added `bytesWritten`, `recordsWritten` and `cosmos.totalRequestCharge` metrics in the sink of the Azure Cosmos DB connector. - See [PR 37510](https://github.com/Azure/azure-sdk-for-java/pull/37510)

#### Breaking Changes

@@ -0,0 +1,125 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

package com.azure.cosmos.spark

import com.azure.cosmos.CosmosDiagnosticsContext
import com.azure.cosmos.implementation.ImplementationBridgeHelpers
import org.apache.spark.SparkInternalsBridge
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.metric.CustomTaskMetric
import org.apache.spark.sql.connector.write.WriterCommitMessage
import org.apache.spark.sql.execution.metric.CustomMetrics
import org.apache.spark.sql.types.StructType

import java.util.concurrent.atomic.AtomicLong

private class CosmosWriter(
    userConfig: Map[String, String],
    cosmosClientStateHandles: Broadcast[CosmosClientMetadataCachesSnapshots],
    diagnosticsConfig: DiagnosticsConfig,
    inputSchema: StructType,
    partitionId: Int,
    taskId: Long,
    epochId: Option[Long],
    sparkEnvironmentInfo: String)
  extends CosmosWriterBase(
    userConfig,
    cosmosClientStateHandles,
    diagnosticsConfig,
    inputSchema,
    partitionId,
    taskId,
    epochId,
    sparkEnvironmentInfo
  ) with OutputMetricsPublisherTrait {

  private val recordsWritten = new AtomicLong(0)
  private val bytesWritten = new AtomicLong(0)
  private val totalRequestCharge = new AtomicLong(0)

  private val recordsWrittenMetric = new CustomTaskMetric {
    override def name(): String = CosmosConstants.MetricNames.RecordsWritten
    override def value(): Long = recordsWritten.get()
  }

  private val bytesWrittenMetric = new CustomTaskMetric {
    override def name(): String = CosmosConstants.MetricNames.BytesWritten
    override def value(): Long = bytesWritten.get()
  }

  private val totalRequestChargeMetric = new CustomTaskMetric {
    override def name(): String = CosmosConstants.MetricNames.TotalRequestCharge

    // Internally request charges are captured with up to 2 fractional digits (scaled by 100)
    // to reduce rounding loss while accumulating
    override def value(): Long = totalRequestCharge.get() / 100L
  }
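  // Example (illustrative, not in the original source): an operation charged 7.38 RU is
  // tracked as 738 in totalRequestCharge; value() reports the accumulated total in whole RUs,
  // so three such operations (2214 internally) surface as 22 RUs for this task.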

  private val metrics = Array(recordsWrittenMetric, bytesWrittenMetric, totalRequestChargeMetric)

  private val count: AtomicLong = new AtomicLong(0)

  override def currentMetricsValues(): Array[CustomTaskMetric] = {
    metrics
  }

  override def getOutputMetricsPublisher(): OutputMetricsPublisherTrait = this

  override def trackWriteOperation(recordCount: Long, diagnostics: Option[CosmosDiagnosticsContext]): Unit = {
    if (recordCount > 0) {
      recordsWritten.addAndGet(recordCount)
    }

    diagnostics match {
      case Some(ctx) =>
        // Capturing request charges with 2 fractional digits internally (scaled by 100)
        totalRequestCharge.addAndGet((ctx.getTotalRequestCharge * 100L).toLong)
        bytesWritten.addAndGet(
          if (ImplementationBridgeHelpers
            .CosmosDiagnosticsContextHelper
            .getCosmosDiagnosticsContextAccessor
            .getOperationType(ctx)
            .isReadOnlyOperation) {

            ctx.getMaxRequestPayloadSizeInBytes + ctx.getMaxResponsePayloadSizeInBytes
          } else {
            ctx.getMaxRequestPayloadSizeInBytes
          }
        )
      case None =>
    }
  }

  override def write(internalRow: InternalRow): Unit = {
    super.write(internalRow)

    if (count.incrementAndGet() % SparkInternalsBridge.NUM_ROWS_PER_UPDATE == 0) {
      SparkInternalsBridge.updateInternalTaskMetrics(currentMetricsValues())
    }
  }

  override def commit(): WriterCommitMessage = {
    val commitMessage = super.commit()

    // TODO @fabianm - this is a workaround; it shouldn't be necessary to do this here.
    // Unfortunately WriteToDataSourceV2Exec.scala does not update custom metrics after the
    // call to commit - meaning data sources which write asynchronously and flush in commit
    // won't get accurate metrics, because updates between the last call to write and flushing the
    // writes are lost. See https://issues.apache.org/jira/browse/SPARK-45759
    // Once the above issue is addressed (probably in Spark 3.4.1 or 3.5) this needs to be changed.
    //
    // NOTE: This also means that the request charge metric cannot be updated in commit - so the
    // request charge metric at the end of a task will be slightly outdated/behind.
    CustomMetrics.updateMetrics(
      currentMetricsValues(),
      SparkInternalsBridge.getInternalCustomTaskMetricsAsSQLMetric(CosmosConstants.MetricNames.KnownCustomMetricNames))

    // In Spark 3.2 CustomMetrics.updateMetrics does not yet update the built-in
    // bytesWritten and recordsWritten metrics
    SparkInternalsBridge.updateInternalTaskMetrics(currentMetricsValues())

    commitMessage
  }
}
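The CosmosConstants.MetricNames values referenced above are not part of this excerpt. Based on the metric names called out in the CHANGELOG entries, a plausible shape would be the following - an illustrative sketch only, not the actual connector source (the Set type for KnownCustomMetricNames is an assumption):

// Illustrative guess at the constants referenced above - not the actual connector source.
object MetricNames {
  val BytesWritten = "bytesWritten"
  val RecordsWritten = "recordsWritten"
  val TotalRequestCharge = "cosmos.totalRequestCharge"

  val KnownCustomMetricNames: Set[String] = Set(BytesWritten, RecordsWritten, TotalRequestCharge)
}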
@@ -27,7 +27,7 @@ private case class ItemsScanBuilder(session: SparkSession,
with SupportsPushDownRequiredColumns {

@transient private lazy val log = LoggerHelper.getLogger(diagnosticsConfig, this.getClass)
log.logInfo(s"Instantiated ${this.getClass.getSimpleName}")
log.logTrace(s"Instantiated ${this.getClass.getSimpleName}")

val configMap = config.asScala.toMap
val readConfig = CosmosReadConfig.parseCosmosReadConfig(configMap)
@@ -0,0 +1,89 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
package com.azure.cosmos.spark

import com.azure.cosmos.spark.diagnostics.LoggerHelper
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.connector.metric.{CustomMetric, CustomSumMetric}
import org.apache.spark.sql.connector.write.streaming.StreamingWrite
import org.apache.spark.sql.connector.write.{BatchWrite, Write, WriteBuilder}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap

// scalastyle:off underscore.import
import scala.collection.JavaConverters._
// scalastyle:on underscore.import

private class ItemsWriterBuilder
(
  userConfig: CaseInsensitiveStringMap,
  inputSchema: StructType,
  cosmosClientStateHandles: Broadcast[CosmosClientMetadataCachesSnapshots],
  diagnosticsConfig: DiagnosticsConfig,
  sparkEnvironmentInfo: String
)
  extends WriteBuilder {
  @transient private lazy val log = LoggerHelper.getLogger(diagnosticsConfig, this.getClass)
  log.logTrace(s"Instantiated ${this.getClass.getSimpleName}")

  override def build(): Write = {
    new CosmosWrite
  }

  override def buildForBatch(): BatchWrite =
    new ItemsBatchWriter(
      userConfig.asCaseSensitiveMap().asScala.toMap,
      inputSchema,
      cosmosClientStateHandles,
      diagnosticsConfig,
      sparkEnvironmentInfo)

  override def buildForStreaming(): StreamingWrite =
    new ItemsBatchWriter(
      userConfig.asCaseSensitiveMap().asScala.toMap,
      inputSchema,
      cosmosClientStateHandles,
      diagnosticsConfig,
      sparkEnvironmentInfo)

  private class CosmosWrite extends Write {

    private[this] val supportedCosmosMetrics: Array[CustomMetric] = {
      Array(
        new CustomSumMetric {
          override def name(): String = CosmosConstants.MetricNames.BytesWritten
          override def description(): String = CosmosConstants.MetricNames.BytesWritten
        },
        new CustomSumMetric {
          override def name(): String = CosmosConstants.MetricNames.RecordsWritten
          override def description(): String = CosmosConstants.MetricNames.RecordsWritten
        },
        new CustomSumMetric {
          override def name(): String = CosmosConstants.MetricNames.TotalRequestCharge
          override def description(): String = CosmosConstants.MetricNames.TotalRequestCharge
        }
      )
    }

    override def toBatch(): BatchWrite =
      new ItemsBatchWriter(
        userConfig.asCaseSensitiveMap().asScala.toMap,
        inputSchema,
        cosmosClientStateHandles,
        diagnosticsConfig,
        sparkEnvironmentInfo)

    override def toStreaming: StreamingWrite =
      new ItemsBatchWriter(
        userConfig.asCaseSensitiveMap().asScala.toMap,
        inputSchema,
        cosmosClientStateHandles,
        diagnosticsConfig,
        sparkEnvironmentInfo)

    override def supportedCustomMetrics(): Array[CustomMetric] = supportedCosmosMetrics
  }
}
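On the driver, Spark pairs each CustomTaskMetric reported by CosmosWriter.currentMetricsValues() with the CustomSumMetric of the same name declared above and aggregates the per-task values. A standalone sketch of that aggregation behavior (not part of the connector; the sample task values are made up, and the metric name is taken from the CHANGELOG entry):

import org.apache.spark.sql.connector.metric.CustomSumMetric

object CustomMetricAggregationSketch {
  def main(args: Array[String]): Unit = {
    val totalRequestCharge = new CustomSumMetric {
      override def name(): String = "cosmos.totalRequestCharge"
      override def description(): String = "cosmos.totalRequestCharge"
    }

    // Hypothetical per-task values as reported by the writer (already divided by 100,
    // i.e. whole RUs per task).
    val perTaskValues = Array(7L, 12L, 3L)

    // CustomSumMetric sums the per-task values - here it yields "22".
    println(totalRequestCharge.aggregateTaskMetrics(perTaskValues))
  }
}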