Commit cdd7f22

[SPARK-52704][SQL] Simplify interoperations between SQLConf and file-format options in TextBasedFileFormats
### What changes were proposed in this pull request?

Simplify interoperations between SQLConf and file-format options in TextBasedFileFormats.

### Why are the changes needed?

- Reduce code duplication
- Restore type annotation for IDE

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Existing tests

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #51398 from yaooqinn/SPARK-52704.

Authored-by: Kent Yao <yao@apache.org>
Signed-off-by: Kent Yao <yao@apache.org>
1 parent 25b0f7b · commit cdd7f22

File tree

4 files changed: +58 −57 lines

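The change boils down to one pattern, repeated across the CSV, JSON, and XML formats: stop re-reading `sparkSession.sessionState.conf` at every call site and go through a small typed helper instead. A minimal before/after sketch, condensed from the CSV diff below:

// Before: options rebuilt inline, re-reading sessionState at each call site.
val parsedOptions = new CSVOptions(
  options,
  columnPruning = sparkSession.sessionState.conf.csvColumnPruning,
  sparkSession.sessionState.conf.sessionLocalTimeZone)

// After: FileFormat exposes sqlConf(sparkSession): SQLConf, and each format
// funnels option construction through one private builder.
val conf = sqlConf(sparkSession)
val parsedOptions = new CSVOptions(
  options, conf.csvColumnPruning, conf.sessionLocalTimeZone, conf.columnNameOfCorruptRecord)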

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala

Lines changed: 16 additions & 3 deletions
@@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
 import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
 import org.apache.spark.sql.errors.QueryExecutionErrors
-import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.{SessionState, SQLConf}
 import org.apache.spark.sql.sources.Filter
 import org.apache.spark.sql.types._

@@ -235,6 +235,20 @@ trait FileFormat {
    */
  def fileConstantMetadataExtractors: Map[String, PartitionedFile => Any] =
    FileFormat.BASE_METADATA_EXTRACTORS
+
+  protected def sessionState(sparkSession: SparkSession): SessionState = {
+    sparkSession.sessionState
+  }
+
+  protected def sqlConf(sparkSession: SparkSession): SQLConf = {
+    sessionState(sparkSession).conf
+  }
+
+  protected def hadoopConf(
+      sparkSession: SparkSession,
+      options: Map[String, String]): Configuration = {
+    sessionState(sparkSession).newHadoopConfWithOptions(options)
+  }
 }

 object FileFormat {

@@ -364,8 +378,7 @@ abstract class TextBasedFileFormat extends FileFormat {
      options: Map[String, String],
      path: Path): Boolean = {
    if (codecFactory == null) {
-      codecFactory = new CompressionCodecFactory(
-        sparkSession.sessionState.newHadoopConfWithOptions(options))
+      codecFactory = new CompressionCodecFactory(hadoopConf(sparkSession, options))
    }
    val codec = codecFactory.getCodec(path)
    codec == null || codec.isInstanceOf[SplittableCompressionCodec]
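The three protected helpers simply layer on one another, so an implementation names the session once and gets properly typed results back. A short sketch of what each call resolves to, per the additions above, as it would appear inside any FileFormat implementation:

// Inside a FileFormat implementation:
val state = sessionState(sparkSession)        // sparkSession.sessionState, typed as SessionState
val conf  = sqlConf(sparkSession)             // sessionState(sparkSession).conf, typed as SQLConf
val hconf = hadoopConf(sparkSession, options) // sessionState(sparkSession).newHadoopConfWithOptions(options)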

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala

Lines changed: 20 additions & 25 deletions
@@ -43,23 +43,15 @@ case class CSVFileFormat() extends TextBasedFileFormat with DataSourceRegister {
      sparkSession: SparkSession,
      options: Map[String, String],
      path: Path): Boolean = {
-    val parsedOptions = new CSVOptions(
-      options,
-      columnPruning = sparkSession.sessionState.conf.csvColumnPruning,
-      sparkSession.sessionState.conf.sessionLocalTimeZone)
-    val csvDataSource = CSVDataSource(parsedOptions)
-    csvDataSource.isSplitable && super.isSplitable(sparkSession, options, path)
+    val parsedOptions = getCsvOptions(sparkSession, options)
+    CSVDataSource(parsedOptions).isSplitable && super.isSplitable(sparkSession, options, path)
  }

  override def inferSchema(
      sparkSession: SparkSession,
      options: Map[String, String],
      files: Seq[FileStatus]): Option[StructType] = {
-    val parsedOptions = new CSVOptions(
-      options,
-      columnPruning = sparkSession.sessionState.conf.csvColumnPruning,
-      sparkSession.sessionState.conf.sessionLocalTimeZone)
-
+    val parsedOptions = getCsvOptions(sparkSession, options)
    CSVDataSource(parsedOptions).inferSchema(sparkSession, files, parsedOptions)
  }

@@ -76,25 +68,21 @@ case class CSVFileFormat() extends TextBasedFileFormat with DataSourceRegister {
        throw QueryCompilationErrors.dataTypeUnsupportedByDataSourceError("CSV", field)
      }
    }
-    val conf = job.getConfiguration
-    val csvOptions = new CSVOptions(
-      options,
-      columnPruning = sparkSession.sessionState.conf.csvColumnPruning,
-      sparkSession.sessionState.conf.sessionLocalTimeZone)
-    csvOptions.compressionCodec.foreach { codec =>
-      CompressionCodecs.setCodecConfiguration(conf, codec)
+    val parsedOptions = getCsvOptions(sparkSession, options)
+    parsedOptions.compressionCodec.foreach { codec =>
+      CompressionCodecs.setCodecConfiguration(job.getConfiguration, codec)
    }

    new OutputWriterFactory {
      override def newInstance(
          path: String,
          dataSchema: StructType,
          context: TaskAttemptContext): OutputWriter = {
-        new CsvOutputWriter(path, dataSchema, context, csvOptions)
+        new CsvOutputWriter(path, dataSchema, context, parsedOptions)
      }

      override def getFileExtension(context: TaskAttemptContext): String = {
-        "." + csvOptions.extension + CodecStreams.getCompressionExtension(context)
+        "." + parsedOptions.extension + CodecStreams.getCompressionExtension(context)
      }
    }
  }

@@ -109,11 +97,7 @@ case class CSVFileFormat() extends TextBasedFileFormat with DataSourceRegister {
      hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = {
    val broadcastedHadoopConf =
      SerializableConfiguration.broadcast(sparkSession.sparkContext, hadoopConf)
-    val parsedOptions = new CSVOptions(
-      options,
-      sparkSession.sessionState.conf.csvColumnPruning,
-      sparkSession.sessionState.conf.sessionLocalTimeZone,
-      sparkSession.sessionState.conf.columnNameOfCorruptRecord)
+    val parsedOptions = getCsvOptions(sparkSession, options)
    val isColumnPruningEnabled = parsedOptions.isColumnPruningEnabled(requiredSchema)

    // Check a field requirement for corrupt records here to throw an exception in a driver side

@@ -180,4 +164,15 @@ case class CSVFileFormat() extends TextBasedFileFormat with DataSourceRegister {
  }

  override def allowDuplicatedColumnNames: Boolean = true
+
+  private def getCsvOptions(
+      sparkSession: SparkSession,
+      options: Map[String, String]): CSVOptions = {
+    val conf = sqlConf(sparkSession)
+    new CSVOptions(
+      options,
+      conf.csvColumnPruning,
+      conf.sessionLocalTimeZone,
+      conf.columnNameOfCorruptRecord)
+  }
 }

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala

Lines changed: 18 additions & 22 deletions
@@ -39,22 +39,15 @@ case class JsonFileFormat() extends TextBasedFileFormat with DataSourceRegister
      sparkSession: SparkSession,
      options: Map[String, String],
      path: Path): Boolean = {
-    val parsedOptions = new JSONOptionsInRead(
-      options,
-      sparkSession.sessionState.conf.sessionLocalTimeZone,
-      sparkSession.sessionState.conf.columnNameOfCorruptRecord)
-    val jsonDataSource = JsonDataSource(parsedOptions)
-    jsonDataSource.isSplitable && super.isSplitable(sparkSession, options, path)
+    val parsedOptions = getJsonOptions(sparkSession, options)
+    JsonDataSource(parsedOptions).isSplitable && super.isSplitable(sparkSession, options, path)
  }

  override def inferSchema(
      sparkSession: SparkSession,
      options: Map[String, String],
      files: Seq[FileStatus]): Option[StructType] = {
-    val parsedOptions = new JSONOptionsInRead(
-      options,
-      sparkSession.sessionState.conf.sessionLocalTimeZone,
-      sparkSession.sessionState.conf.columnNameOfCorruptRecord)
+    val parsedOptions = getJsonOptions(sparkSession, options)
    JsonDataSource(parsedOptions).inferSchema(sparkSession, files, parsedOptions)
  }

@@ -63,13 +56,9 @@ case class JsonFileFormat() extends TextBasedFileFormat with DataSourceRegister
      job: Job,
      options: Map[String, String],
      dataSchema: StructType): OutputWriterFactory = {
-    val conf = job.getConfiguration
-    val parsedOptions = new JSONOptions(
-      options,
-      sparkSession.sessionState.conf.sessionLocalTimeZone,
-      sparkSession.sessionState.conf.columnNameOfCorruptRecord)
+    val parsedOptions = getJsonOptions(sparkSession, options, inRead = false)
    parsedOptions.compressionCodec.foreach { codec =>
-      CompressionCodecs.setCodecConfiguration(conf, codec)
+      CompressionCodecs.setCodecConfiguration(job.getConfiguration, codec)
    }

    new OutputWriterFactory {

@@ -96,12 +85,7 @@ case class JsonFileFormat() extends TextBasedFileFormat with DataSourceRegister
      hadoopConf: Configuration): PartitionedFile => Iterator[InternalRow] = {
    val broadcastedHadoopConf =
      SerializableConfiguration.broadcast(sparkSession.sparkContext, hadoopConf)
-
-    val parsedOptions = new JSONOptionsInRead(
-      options,
-      sparkSession.sessionState.conf.sessionLocalTimeZone,
-      sparkSession.sessionState.conf.columnNameOfCorruptRecord)
-
+    val parsedOptions = getJsonOptions(sparkSession, options)
    val actualSchema =
      StructType(requiredSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord))
    // Check a field requirement for corrupt records here to throw an exception in a driver side

@@ -147,4 +131,16 @@ case class JsonFileFormat() extends TextBasedFileFormat with DataSourceRegister

    case _ => false
  }
+
+  private def getJsonOptions(
+      spark: SparkSession,
+      options: Map[String, String],
+      inRead: Boolean = true): JSONOptions = {
+    val conf = sqlConf(spark)
+    if (inRead) {
+      new JSONOptionsInRead(options, conf.sessionLocalTimeZone, conf.columnNameOfCorruptRecord)
+    } else {
+      new JSONOptions(options, conf.sessionLocalTimeZone, conf.columnNameOfCorruptRecord)
+    }
+  }
 }
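JSON is the one format that needs two option classes, so the consolidated builder takes an inRead flag instead of duplicating the constructor calls. A sketch of the two kinds of call sites as they appear in the diff above:

// Read paths (isSplitable, inferSchema, buildReader) keep the read-side defaults.
val readOptions = getJsonOptions(sparkSession, options)                   // JSONOptionsInRead
// prepareWrite only needs the base options, so it opts out explicitly.
val writeOptions = getJsonOptions(sparkSession, options, inRead = false)  // JSONOptions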

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlFileFormat.scala

Lines changed: 4 additions & 7 deletions
@@ -39,22 +39,19 @@ case class XmlFileFormat() extends TextBasedFileFormat with DataSourceRegister {

  override def shortName(): String = "xml"

-  def getXmlOptions(
+  private def getXmlOptions(
      sparkSession: SparkSession,
      parameters: Map[String, String]): XmlOptions = {
-    new XmlOptions(parameters,
-      sparkSession.sessionState.conf.sessionLocalTimeZone,
-      sparkSession.sessionState.conf.columnNameOfCorruptRecord,
-      true)
+    val conf = sqlConf(sparkSession)
+    new XmlOptions(parameters, conf.sessionLocalTimeZone, conf.columnNameOfCorruptRecord, true)
  }

  override def isSplitable(
      sparkSession: SparkSession,
      options: Map[String, String],
      path: Path): Boolean = {
    val xmlOptions = getXmlOptions(sparkSession, options)
-    val xmlDataSource = XmlDataSource(xmlOptions)
-    xmlDataSource.isSplitable && super.isSplitable(sparkSession, options, path)
+    XmlDataSource(xmlOptions).isSplitable && super.isSplitable(sparkSession, options, path)
  }

  override def inferSchema(
