[SPARK-52704][SQL] Simplify interoperations between SQLConf and file-format options in TextBasedFileFormats #51398

Status: Closed · wants to merge 1 commit
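The diff below follows a single pattern: repeated `sparkSession.sessionState.conf.*` lookups are routed through new protected helpers on `FileFormat` (`sessionState`, `sqlConf`, `hadoopConf`), and each text-based format gains a private `get*Options` helper. Condensed from the CSV changes below, the shape of the change is roughly:

```scala
// Before: each call site re-read SQLConf fields and built the options object inline.
val parsedOptions = new CSVOptions(
  options,
  columnPruning = sparkSession.sessionState.conf.csvColumnPruning,
  sparkSession.sessionState.conf.sessionLocalTimeZone)

// After: the format-specific helper reads SQLConf once through the new sqlConf() accessor.
val parsedOptions = getCsvOptions(sparkSession, options)
```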
FileFormat.scala

@@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
import org.apache.spark.sql.errors.QueryExecutionErrors
-import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.{SessionState, SQLConf}
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.types._

@@ -235,6 +235,20 @@ trait FileFormat {
*/
def fileConstantMetadataExtractors: Map[String, PartitionedFile => Any] =
FileFormat.BASE_METADATA_EXTRACTORS

+ protected def sessionState(sparkSession: SparkSession): SessionState = {
+   sparkSession.sessionState
+ }
Member commented:

According to the method body, this is a utility method; it could also be a static method.


+ protected def sqlConf(sparkSession: SparkSession): SQLConf = {
+   sessionState(sparkSession).conf
+ }
Member commented:

Same point here. Both sqlConf and sessionState are utility methods and could be static methods as well.


+ protected def hadoopConf(
+     sparkSession: SparkSession,
+     options: Map[String, String]): Configuration = {
+   sessionState(sparkSession).newHadoopConfWithOptions(options)
+ }
@dongjoon-hyun (Member) commented on Jul 8, 2025:

Ditto. This method has no direct relation to any underlying object or to trait FileFormat itself. It would be better to move it out independently, together with the sessionState and sqlConf methods above.
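A minimal sketch of the alternative these three comments describe: the same accessors hoisted out of trait FileFormat into a standalone utility object. The object name FileFormatUtils is an illustrative placeholder, not something proposed in this PR.

```scala
package org.apache.spark.sql.execution.datasources

import org.apache.hadoop.conf.Configuration

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.internal.{SessionState, SQLConf}

// Hypothetical helper object (name is a placeholder): the three accessors this PR
// adds to trait FileFormat, expressed as static-style utilities instead.
object FileFormatUtils {

  def sessionState(sparkSession: SparkSession): SessionState =
    sparkSession.sessionState

  def sqlConf(sparkSession: SparkSession): SQLConf =
    sessionState(sparkSession).conf

  def hadoopConf(
      sparkSession: SparkSession,
      options: Map[String, String]): Configuration =
    sessionState(sparkSession).newHadoopConfWithOptions(options)
}
```

With this shape the formats would call FileFormatUtils.sqlConf(sparkSession) and so on, rather than inheriting the helpers from FileFormat; whether to keep them on the trait or move them out like this is exactly the question the comments raise.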

}

object FileFormat {
@@ -364,8 +378,7 @@ abstract class TextBasedFileFormat extends FileFormat {
options: Map[String, String],
path: Path): Boolean = {
if (codecFactory == null) {
- codecFactory = new CompressionCodecFactory(
-   sparkSession.sessionState.newHadoopConfWithOptions(options))
+ codecFactory = new CompressionCodecFactory(hadoopConf(sparkSession, options))
}
val codec = codecFactory.getCodec(path)
codec == null || codec.isInstanceOf[SplittableCompressionCodec]
CSVFileFormat.scala
@@ -43,23 +43,15 @@ case class CSVFileFormat() extends TextBasedFileFormat with DataSourceRegister {
sparkSession: SparkSession,
options: Map[String, String],
path: Path): Boolean = {
- val parsedOptions = new CSVOptions(
-   options,
-   columnPruning = sparkSession.sessionState.conf.csvColumnPruning,
-   sparkSession.sessionState.conf.sessionLocalTimeZone)
- val csvDataSource = CSVDataSource(parsedOptions)
- csvDataSource.isSplitable && super.isSplitable(sparkSession, options, path)
+ val parsedOptions = getCsvOptions(sparkSession, options)
+ CSVDataSource(parsedOptions).isSplitable && super.isSplitable(sparkSession, options, path)
}

override def inferSchema(
sparkSession: SparkSession,
options: Map[String, String],
files: Seq[FileStatus]): Option[StructType] = {
- val parsedOptions = new CSVOptions(
-   options,
-   columnPruning = sparkSession.sessionState.conf.csvColumnPruning,
-   sparkSession.sessionState.conf.sessionLocalTimeZone)
+ val parsedOptions = getCsvOptions(sparkSession, options)
CSVDataSource(parsedOptions).inferSchema(sparkSession, files, parsedOptions)
}

@@ -76,25 +68,21 @@ case class CSVFileFormat() extends TextBasedFileFormat with DataSourceRegister {
throw QueryCompilationErrors.dataTypeUnsupportedByDataSourceError("CSV", field)
}
}
- val conf = job.getConfiguration
- val csvOptions = new CSVOptions(
-   options,
-   columnPruning = sparkSession.sessionState.conf.csvColumnPruning,
-   sparkSession.sessionState.conf.sessionLocalTimeZone)
- csvOptions.compressionCodec.foreach { codec =>
-   CompressionCodecs.setCodecConfiguration(conf, codec)
+ val parsedOptions = getCsvOptions(sparkSession, options)
+ parsedOptions.compressionCodec.foreach { codec =>
+   CompressionCodecs.setCodecConfiguration(job.getConfiguration, codec)
}

new OutputWriterFactory {
override def newInstance(
path: String,
dataSchema: StructType,
context: TaskAttemptContext): OutputWriter = {
- new CsvOutputWriter(path, dataSchema, context, csvOptions)
+ new CsvOutputWriter(path, dataSchema, context, parsedOptions)
}

override def getFileExtension(context: TaskAttemptContext): String = {
"." + csvOptions.extension + CodecStreams.getCompressionExtension(context)
"." + parsedOptions.extension + CodecStreams.getCompressionExtension(context)
}
}
}
@@ -109,11 +97,7 @@ case class CSVFileFormat() extends TextBasedFileFormat with DataSourceRegister {
hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = {
val broadcastedHadoopConf =
SerializableConfiguration.broadcast(sparkSession.sparkContext, hadoopConf)
- val parsedOptions = new CSVOptions(
-   options,
-   sparkSession.sessionState.conf.csvColumnPruning,
-   sparkSession.sessionState.conf.sessionLocalTimeZone,
-   sparkSession.sessionState.conf.columnNameOfCorruptRecord)
+ val parsedOptions = getCsvOptions(sparkSession, options)
val isColumnPruningEnabled = parsedOptions.isColumnPruningEnabled(requiredSchema)

// Check a field requirement for corrupt records here to throw an exception in a driver side
@@ -180,4 +164,15 @@ }
}

override def allowDuplicatedColumnNames: Boolean = true

+ private def getCsvOptions(
+     sparkSession: SparkSession,
+     options: Map[String, String]): CSVOptions = {
+   val conf = sqlConf(sparkSession)
+   new CSVOptions(
+     options,
+     conf.csvColumnPruning,
+     conf.sessionLocalTimeZone,
+     conf.columnNameOfCorruptRecord)
+ }
}
JsonFileFormat.scala
@@ -39,22 +39,15 @@ case class JsonFileFormat() extends TextBasedFileFormat with DataSourceRegister
sparkSession: SparkSession,
options: Map[String, String],
path: Path): Boolean = {
- val parsedOptions = new JSONOptionsInRead(
-   options,
-   sparkSession.sessionState.conf.sessionLocalTimeZone,
-   sparkSession.sessionState.conf.columnNameOfCorruptRecord)
- val jsonDataSource = JsonDataSource(parsedOptions)
- jsonDataSource.isSplitable && super.isSplitable(sparkSession, options, path)
+ val parsedOptions = getJsonOptions(sparkSession, options)
+ JsonDataSource(parsedOptions).isSplitable && super.isSplitable(sparkSession, options, path)
}

override def inferSchema(
sparkSession: SparkSession,
options: Map[String, String],
files: Seq[FileStatus]): Option[StructType] = {
- val parsedOptions = new JSONOptionsInRead(
-   options,
-   sparkSession.sessionState.conf.sessionLocalTimeZone,
-   sparkSession.sessionState.conf.columnNameOfCorruptRecord)
+ val parsedOptions = getJsonOptions(sparkSession, options)
JsonDataSource(parsedOptions).inferSchema(sparkSession, files, parsedOptions)
}

@@ -63,13 +56,9 @@ case class JsonFileFormat() extends TextBasedFileFormat with DataSourceRegister
job: Job,
options: Map[String, String],
dataSchema: StructType): OutputWriterFactory = {
- val conf = job.getConfiguration
- val parsedOptions = new JSONOptions(
-   options,
-   sparkSession.sessionState.conf.sessionLocalTimeZone,
-   sparkSession.sessionState.conf.columnNameOfCorruptRecord)
+ val parsedOptions = getJsonOptions(sparkSession, options, inRead = false)
parsedOptions.compressionCodec.foreach { codec =>
-   CompressionCodecs.setCodecConfiguration(conf, codec)
+   CompressionCodecs.setCodecConfiguration(job.getConfiguration, codec)
}

new OutputWriterFactory {
@@ -96,12 +85,7 @@ case class JsonFileFormat() extends TextBasedFileFormat with DataSourceRegister
hadoopConf: Configuration): PartitionedFile => Iterator[InternalRow] = {
val broadcastedHadoopConf =
SerializableConfiguration.broadcast(sparkSession.sparkContext, hadoopConf)

- val parsedOptions = new JSONOptionsInRead(
-   options,
-   sparkSession.sessionState.conf.sessionLocalTimeZone,
-   sparkSession.sessionState.conf.columnNameOfCorruptRecord)
+ val parsedOptions = getJsonOptions(sparkSession, options)
val actualSchema =
StructType(requiredSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord))
// Check a field requirement for corrupt records here to throw an exception in a driver side
@@ -147,4 +131,16 @@

case _ => false
}

+ private def getJsonOptions(
+     spark: SparkSession,
+     options: Map[String, String],
+     inRead: Boolean = true): JSONOptions = {
+   val conf = sqlConf(spark)
+   if (inRead) {
+     new JSONOptionsInRead(options, conf.sessionLocalTimeZone, conf.columnNameOfCorruptRecord)
+   } else {
+     new JSONOptions(options, conf.sessionLocalTimeZone, conf.columnNameOfCorruptRecord)
+   }
+ }
}
XmlFileFormat.scala
@@ -39,22 +39,19 @@ case class XmlFileFormat() extends TextBasedFileFormat with DataSourceRegister {

override def shortName(): String = "xml"

- def getXmlOptions(
+ private def getXmlOptions(
sparkSession: SparkSession,
parameters: Map[String, String]): XmlOptions = {
-   new XmlOptions(parameters,
-     sparkSession.sessionState.conf.sessionLocalTimeZone,
-     sparkSession.sessionState.conf.columnNameOfCorruptRecord,
-     true)
+   val conf = sqlConf(sparkSession)
+   new XmlOptions(parameters, conf.sessionLocalTimeZone, conf.columnNameOfCorruptRecord, true)
}

override def isSplitable(
sparkSession: SparkSession,
options: Map[String, String],
path: Path): Boolean = {
val xmlOptions = getXmlOptions(sparkSession, options)
- val xmlDataSource = XmlDataSource(xmlOptions)
- xmlDataSource.isSplitable && super.isSplitable(sparkSession, options, path)
+ XmlDataSource(xmlOptions).isSplitable && super.isSplitable(sparkSession, options, path)
}

override def inferSchema(