[SPARK-41970] Introduce SparkPath for typesafety
### What changes were proposed in this pull request?
This PR proposes a strongly typed `SparkPath` that encapsulates a url-encoded string. It has helper methods for creating Hadoop `Path`s, `URI`s, and url-encoded strings.
The intent is to identify and fix bugs in the way that Spark handles these paths. To do this we introduced the `SparkPath` type to `PartitionedFile` (a widely used class) and then fixed the resulting compile errors, uncovering and fixing several bugs along the way.

### Why are the changes needed?

Given `val str = "s3://bucket/path with space/a"`, there is a difference between `new Path(str)` and `new Path(new URI(str))`, and thus a difference between `new URI(str)` and `new Path(str).toUri`.
Both `URI` and `Path` round-trip between construction and `toString`, but they are not interchangeable: `URI` expects a url-encoded string, while `Path` treats its input as unencoded. Spark confuses these two representations (uri-encoded vs not). This PR uses types to disambiguate them.
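
For example, a minimal sketch of the asymmetry (the bucket and file names are illustrative, not from this PR):

```scala
import java.net.URI
import org.apache.hadoop.fs.Path

object PathVsUri extends App {
  val str = "s3://bucket/path with space/a"  // a hadoop Path string, not url-encoded

  // new URI(str) would throw URISyntaxException here: the single-argument URI
  // constructor expects an already url-encoded string.
  val p = new Path(str)
  println(p)        // s3://bucket/path with space/a      (round-trips)
  println(p.toUri)  // s3://bucket/path%20with%20space/a  (url-encoded)

  // Feeding the url-encoded form back into new Path(...) escapes the '%' again and
  // yields a different path, whereas new Path(new URI(...)) decodes it first.
  println(new Path(p.toUri.toString).toUri)     // s3://bucket/path%2520with%2520space/a
  println(new Path(new URI(p.toUri.toString)))  // s3://bucket/path with space/a
}
```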

### Does this PR introduce _any_ user-facing change?

This PR proposes changing the public API of `PartitionedFile`, and of various other methods, in the name of type safety: it needs to be clear to callers of an API what kind of path string is expected.
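
A minimal sketch of what the change looks like from a caller's point of view, based on the new `PartitionedFile` definition further down in this diff (the `describe` helper is illustrative, not part of the PR):

```scala
import org.apache.spark.sql.execution.datasources.PartitionedFile

object PartitionedFileCaller {
  // Before: file.filePath was a plain String, and a caller had to guess whether it was
  // url-encoded before wrapping it in a hadoop Path or a java URI.
  // After: file.filePath is a SparkPath, and the accessors name the representation.
  def describe(file: PartitionedFile): String = {
    val encoded: String = file.urlEncodedPath  // url-encoded string form
    val hadoopPath      = file.toPath          // org.apache.hadoop.fs.Path
    val uri             = file.pathUri         // java.net.URI
    s"reading $encoded ($hadoopPath, $uri) starting at offset ${file.start}"
  }
}
```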

### How was this patch tested?

We rely on existing tests, and update the default temp path creation to include paths with spaces.

Closes apache#39488 from databricks-david-lewis/SPARK_PATH.

Authored-by: David Lewis <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
databricks-david-lewis authored and cloud-fan committed Jan 19, 2023
1 parent 498b3ec commit faedcd9
Showing 42 changed files with 216 additions and 133 deletions.
@@ -18,7 +18,6 @@
package org.apache.spark.sql.avro

import java.io._
import java.net.URI

import scala.util.control.NonFatal

@@ -96,9 +95,9 @@ private[sql] class AvroFileFormat extends FileFormat
// Doing input file filtering is improper because we may generate empty tasks that process no
// input files but stress the scheduler. We should probably add a more general input file
// filtering mechanism for `FileFormat` data sources. See SPARK-16317.
if (parsedOptions.ignoreExtension || file.filePath.endsWith(".avro")) {
if (parsedOptions.ignoreExtension || file.urlEncodedPath.endsWith(".avro")) {
val reader = {
val in = new FsInput(new Path(new URI(file.filePath)), conf)
val in = new FsInput(file.toPath, conf)
try {
val datumReader = userProvidedSchema match {
case Some(userSchema) => new GenericDatumReader[GenericRecord](userSchema)
@@ -16,14 +16,11 @@
*/
package org.apache.spark.sql.v2.avro

import java.net.URI

import scala.util.control.NonFatal

import org.apache.avro.file.DataFileReader
import org.apache.avro.generic.{GenericDatumReader, GenericRecord}
import org.apache.avro.mapred.FsInput
import org.apache.hadoop.fs.Path

import org.apache.spark.TaskContext
import org.apache.spark.broadcast.Broadcast
@@ -62,9 +59,9 @@ case class AvroPartitionReaderFactory(
val conf = broadcastedConf.value.value
val userProvidedSchema = options.schema

if (options.ignoreExtension || partitionedFile.filePath.endsWith(".avro")) {
if (options.ignoreExtension || partitionedFile.urlEncodedPath.endsWith(".avro")) {
val reader = {
val in = new FsInput(new Path(new URI(partitionedFile.filePath)), conf)
val in = new FsInput(partitionedFile.toPath, conf)
try {
val datumReader = userProvidedSchema match {
case Some(userSchema) => new GenericDatumReader[GenericRecord](userSchema)
@@ -18,13 +18,11 @@
package org.apache.spark.sql.avro

import java.io._
import java.net.URI

import org.apache.avro.file.DataFileReader
import org.apache.avro.generic.{GenericDatumReader, GenericRecord}
import org.apache.avro.mapred.FsInput
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path

import org.apache.spark.SparkConf
import org.apache.spark.sql._
@@ -62,8 +60,8 @@ class AvroRowReaderSuite
case BatchScanExec(_, f: AvroScan, _, _, _, _, _) => f
}
val filePath = fileScan.get.fileIndex.inputFiles(0)
val fileSize = new File(new URI(filePath)).length
val in = new FsInput(new Path(new URI(filePath)), new Configuration())
val fileSize = new File(filePath.toUri).length
val in = new FsInput(filePath.toPath, new Configuration())
val reader = DataFileReader.openReader(in, new GenericDatumReader[GenericRecord]())

val it = new Iterator[InternalRow] with AvroUtils.RowReader {
@@ -2357,7 +2357,8 @@ class AvroV2Suite extends AvroSuite with ExplainSuiteHelper {
assert(fileScan.get.dataFilters.nonEmpty)
assert(fileScan.get.planInputPartitions().forall { partition =>
partition.asInstanceOf[FilePartition].files.forall { file =>
file.filePath.contains("p1=1") && file.filePath.contains("p2=2")
file.urlEncodedPath.contains("p1=1") &&
file.urlEncodedPath.contains("p2=2")
}
})
checkAnswer(df, Row("b", 1, 2))
@@ -47,7 +47,7 @@ private[spark] class WorkerWatcher(
private[deploy] var isShutDown = false

// Let's filter events only from the worker's rpc system
private val expectedAddress = RpcAddress.fromURIString(workerUrl)
private val expectedAddress = RpcAddress.fromUrlString(workerUrl)
private def isWorker(address: RpcAddress) = expectedAddress == address

private def exitNonZero() =
55 changes: 55 additions & 0 deletions core/src/main/scala/org/apache/spark/paths/SparkPath.scala
@@ -0,0 +1,55 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.paths

import java.net.URI

import org.apache.hadoop.fs.{FileStatus, Path}

/**
* A canonical representation of a file path. This class is intended to provide
* type-safety to the way that Spark handles Paths. Paths can be represented as
* Strings in multiple ways, which are not always compatible. Spark regularly uses
* two ways: 1. hadoop Path.toString and 2. java URI.toString.
*/
case class SparkPath private (private val underlying: String) {
def urlEncoded: String = underlying
def toUri: URI = new URI(underlying)
def toPath: Path = new Path(toUri)
override def toString: String = underlying
}

object SparkPath {
/**
* Creates a SparkPath from a hadoop Path string.
* Please be very sure that the provided string is encoded (or not encoded) in the right way.
*
* Please see the hadoop Path documentation here:
* https://hadoop.apache.org/docs/stable/api/org/apache/hadoop/fs/Path.html#Path-java.lang.String-
*/
def fromPathString(str: String): SparkPath = fromPath(new Path(str))
def fromPath(path: Path): SparkPath = fromUri(path.toUri)
def fromFileStatus(fs: FileStatus): SparkPath = fromPath(fs.getPath)

/**
* Creates a SparkPath from a url-encoded string.
* Note: It is the responsibility of the caller to ensure that str is a valid url-encoded string.
*/
def fromUrlString(str: String): SparkPath = SparkPath(str)
def fromUri(uri: URI): SparkPath = fromUrlString(uri.toString)
}
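
For example, a short usage sketch for the class above (the s3 path, with its space, is illustrative):

```scala
import org.apache.spark.paths.SparkPath

object SparkPathExample extends App {
  // The same file reaching Spark through the two string conventions.
  val fromHadoopString  = SparkPath.fromPathString("s3://bucket/path with space/a")
  val fromEncodedString = SparkPath.fromUrlString("s3://bucket/path%20with%20space/a")

  assert(fromHadoopString == fromEncodedString)  // one canonical, url-encoded form
  assert(fromHadoopString.urlEncoded == "s3://bucket/path%20with%20space/a")
  assert(fromHadoopString.toPath.toString == "s3://bucket/path with space/a")
  assert(fromHadoopString.toUri.getPath == "/path with space/a")  // URI getters decode
}
```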
2 changes: 1 addition & 1 deletion core/src/main/scala/org/apache/spark/rpc/RpcAddress.scala
@@ -39,7 +39,7 @@ private[spark] case class RpcAddress(_host: String, port: Int) {
private[spark] object RpcAddress {

/** Return the [[RpcAddress]] represented by `uri`. */
def fromURIString(uri: String): RpcAddress = {
def fromUrlString(uri: String): RpcAddress = {
val uriObj = new java.net.URI(uri)
RpcAddress(uriObj.getHost, uriObj.getPort)
}
@@ -19,7 +19,7 @@ package org.apache.spark.ml.source.image

import com.google.common.io.{ByteStreams, Closeables}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileStatus, Path}
import org.apache.hadoop.fs.FileStatus
import org.apache.hadoop.mapreduce.Job

import org.apache.spark.ml.image.ImageSchema
@@ -71,8 +71,8 @@ private[image] class ImageFileFormat extends FileFormat with DataSourceRegister
if (!imageSourceOptions.dropInvalid && requiredSchema.isEmpty) {
Iterator(emptyUnsafeRow)
} else {
val origin = file.filePath
val path = new Path(origin)
val origin = file.urlEncodedPath
val path = file.toPath
val fs = path.getFileSystem(broadcastedHadoopConf.value.value)
val stream = fs.open(path)
val bytes = try {
8 changes: 8 additions & 0 deletions scalastyle-config.xml
@@ -437,4 +437,12 @@ This file is divided into 3 sections:
Use org.apache.spark.util.Utils.createTempDir instead.
</customMessage>
</check>

<check customId="pathfromuri" level="error" class="org.scalastyle.file.RegexChecker" enabled="true">
<parameters><parameter name="regex">new Path\(new URI\(</parameter></parameters>
<customMessage><![CDATA[
Are you sure that this string is uri encoded? Please be careful when converting hadoop Paths
and URIs to and from String. If possible, please use SparkPath.
]]></customMessage>
</check>
</scalastyle>
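
For illustration, a sketch of the pattern this rule flags and the intended replacement (`str` and the helper names are hypothetical):

```scala
import java.net.URI
import org.apache.hadoop.fs.Path
import org.apache.spark.paths.SparkPath

object PathConversions {
  // Flagged by the "pathfromuri" rule above: is `str` really url-encoded here?
  def suspect(str: String): Path = new Path(new URI(str))

  // Preferred: state the encoding explicitly and let SparkPath do the conversion.
  def fromEncoded(str: String): Path = SparkPath.fromUrlString(str).toPath   // str is url-encoded
  def fromPlain(str: String): Path   = SparkPath.fromPathString(str).toPath  // str is a hadoop Path string
}
```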
7 changes: 4 additions & 3 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -34,6 +34,7 @@ import org.apache.spark.api.java.function._
import org.apache.spark.api.python.{PythonRDD, SerDeUtil}
import org.apache.spark.api.r.RRDD
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.paths.SparkPath
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, QueryPlanningTracker, ScalaReflection, TableIdentifier}
import org.apache.spark.sql.catalyst.analysis._
@@ -3924,18 +3925,18 @@ class Dataset[T] private[sql](
* @since 2.0.0
*/
def inputFiles: Array[String] = {
val files: Seq[String] = queryExecution.optimizedPlan.collect {
val files: Seq[SparkPath] = queryExecution.optimizedPlan.collect {
case LogicalRelation(fsBasedRelation: FileRelation, _, _, _) =>
fsBasedRelation.inputFiles
case fr: FileRelation =>
fr.inputFiles
case r: HiveTableRelation =>
r.tableMeta.storage.locationUri.map(_.toString).toArray
r.tableMeta.storage.locationUri.map(SparkPath.fromUri).toArray
case DataSourceV2ScanRelation(DataSourceV2Relation(table: FileTable, _, _, _, _),
_, _, _, _) =>
table.fileIndex.inputFiles
}.flatten
files.toSet.toArray
files.iterator.map(_.urlEncoded).toSet.toArray
}

/**
@@ -632,8 +632,8 @@ case class FileSourceScanExec(
}
}.groupBy { f =>
BucketingUtils
.getBucketId(new Path(f.filePath).getName)
.getOrElse(throw QueryExecutionErrors.invalidBucketFile(f.filePath))
.getBucketId(f.toPath.getName)
.getOrElse(throw QueryExecutionErrors.invalidBucketFile(f.urlEncodedPath))
}

val prunedFilesGroupedToBuckets = if (optionalBucketSet.isDefined) {
@@ -17,12 +17,14 @@

package org.apache.spark.sql.execution

import org.apache.spark.paths.SparkPath

/**
* An interface for relations that are backed by files. When a class implements this interface,
* the list of files that it returns will be returned to a user who calls `inputFiles` on any
* DataFrame that queries this relation.
*/
trait FileRelation {
/** Returns the list of files that will be read when scanning this relation. */
def inputFiles: Array[String]
def inputFiles: Array[SparkPath]
}
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution

import org.apache.hadoop.fs.{BlockLocation, FileStatus, LocatedFileStatus, Path}

import org.apache.spark.paths.SparkPath
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.datasources._
@@ -36,7 +37,7 @@ object PartitionedFileUtil {
val remaining = file.getLen - offset
val size = if (remaining > maxSplitBytes) maxSplitBytes else remaining
val hosts = getBlockHosts(getBlockLocations(file), offset, size)
PartitionedFile(partitionValues, filePath.toUri.toString, offset, size, hosts,
PartitionedFile(partitionValues, SparkPath.fromPath(filePath), offset, size, hosts,
file.getModificationTime, file.getLen)
}
} else {
@@ -49,7 +50,7 @@
filePath: Path,
partitionValues: InternalRow): PartitionedFile = {
val hosts = getBlockHosts(getBlockLocations(file), 0, file.getLen)
PartitionedFile(partitionValues, filePath.toUri.toString, 0, file.getLen, hosts,
PartitionedFile(partitionValues, SparkPath.fromPath(filePath), 0, file.getLen, hosts,
file.getModificationTime, file.getLen)
}

@@ -22,6 +22,7 @@ import java.net.URI
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path

import org.apache.spark.paths.SparkPath
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.catalog.{CatalogTable, ExternalCatalogUtils}
import org.apache.spark.sql.catalyst.expressions._
Expand Down Expand Up @@ -94,7 +95,7 @@ class CatalogFileIndex(
}
}

override def inputFiles: Array[String] = filterPartitions(Nil).inputFiles
override def inputFiles: Array[SparkPath] = filterPartitions(Nil).inputFiles

// `CatalogFileIndex` may be a member of `HadoopFsRelation`, `HadoopFsRelation` may be a member
// of `LogicalRelation`, and `LogicalRelation` may be used as the cache key. So we need to
@@ -70,7 +70,8 @@ import org.apache.spark.util.{HadoopFSUtils, ThreadUtils, Utils}
*
* @param paths A list of file system paths that hold data. These will be globbed before if
* the "__globPaths__" option is true, and will be qualified. This option only works
* when reading from a [[FileFormat]].
* when reading from a [[FileFormat]]. These paths are expected to be hadoop [[Path]]
* strings.
* @param userSpecifiedSchema An optional specification of the schema of the data. When present
* we skip attempting to infer the schema.
* @param partitionColumns A list of column names that the relation is partitioned by. This list is
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.datasources

import org.apache.hadoop.fs._

import org.apache.spark.paths.SparkPath
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.StructType
@@ -62,7 +63,7 @@ trait FileIndex {
* Returns the list of files that will be read when scanning this relation. This call may be
* very expensive for large tables.
*/
def inputFiles: Array[String]
def inputFiles: Array[SparkPath]

/** Refresh any cached file listings */
def refresh(): Unit
@@ -18,13 +18,15 @@
package org.apache.spark.sql.execution.datasources

import java.io.{Closeable, FileNotFoundException, IOException}
import java.net.URI

import scala.util.control.NonFatal

import org.apache.hadoop.fs.Path

import org.apache.spark.{Partition => RDDPartition, SparkUpgradeException, TaskContext}
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.paths.SparkPath
import org.apache.spark.rdd.{InputFileBlockHolder, RDD}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.{FileSourceOptions, InternalRow}
@@ -51,12 +53,17 @@ import org.apache.spark.util.NextIterator
*/
case class PartitionedFile(
partitionValues: InternalRow,
filePath: String,
filePath: SparkPath,
start: Long,
length: Long,
@transient locations: Array[String] = Array.empty,
modificationTime: Long = 0L,
fileSize: Long = 0L) {

def pathUri: URI = filePath.toUri
def toPath: Path = filePath.toPath
def urlEncodedPath: String = filePath.urlEncoded

override def toString: String = {
s"path: $filePath, range: $start-${start + length}, partition values: $partitionValues"
}
@@ -140,14 +147,14 @@ class FileScanRDD(
private def updateMetadataRow(): Unit =
if (metadataColumns.nonEmpty && currentFile != null) {
updateMetadataInternalRow(metadataRow, metadataColumns.map(_.name),
new Path(currentFile.filePath), currentFile.fileSize, currentFile.modificationTime)
currentFile.toPath, currentFile.fileSize, currentFile.modificationTime)
}

/**
* Create an array of constant column vectors containing all required metadata columns
*/
private def createMetadataColumnVector(c: ColumnarBatch): Array[ColumnVector] = {
val path = new Path(currentFile.filePath)
val path = currentFile.toPath
metadataColumns.map(_.name).map {
case FILE_PATH =>
val columnVector = new ConstantColumnVector(c.numRows(), StringType)
@@ -223,7 +230,8 @@ class FileScanRDD(
updateMetadataRow()
logInfo(s"Reading File $currentFile")
// Sets InputFileBlockHolder for the file block's information
InputFileBlockHolder.set(currentFile.filePath, currentFile.start, currentFile.length)
InputFileBlockHolder
.set(currentFile.urlEncodedPath, currentFile.start, currentFile.length)

resetCurrentIterator()
if (ignoreMissingFiles || ignoreCorruptFiles) {
@@ -278,12 +286,13 @@ class FileScanRDD(
} catch {
case e: SchemaColumnConvertNotSupportedException =>
throw QueryExecutionErrors.unsupportedSchemaColumnConvertError(
currentFile.filePath, e.getColumn, e.getLogicalType, e.getPhysicalType, e)
currentFile.urlEncodedPath, e.getColumn, e.getLogicalType, e.getPhysicalType, e)
case sue: SparkUpgradeException => throw sue
case NonFatal(e) =>
e.getCause match {
case sue: SparkUpgradeException => throw sue
case _ => throw QueryExecutionErrors.cannotReadFilesError(e, currentFile.filePath)
case _ =>
throw QueryExecutionErrors.cannotReadFilesError(e, currentFile.urlEncodedPath)
}
}
} else {