diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala
index d3078740b819c..05c939bed5f8b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala
@@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
 import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
 import org.apache.spark.sql.errors.QueryExecutionErrors
-import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.{SessionState, SQLConf}
 import org.apache.spark.sql.sources.Filter
 import org.apache.spark.sql.types._
 
@@ -235,6 +235,20 @@ trait FileFormat {
    */
   def fileConstantMetadataExtractors: Map[String, PartitionedFile => Any] =
     FileFormat.BASE_METADATA_EXTRACTORS
+
+  protected def sessionState(sparkSession: SparkSession): SessionState = {
+    sparkSession.sessionState
+  }
+
+  protected def sqlConf(sparkSession: SparkSession): SQLConf = {
+    sessionState(sparkSession).conf
+  }
+
+  protected def hadoopConf(
+      sparkSession: SparkSession,
+      options: Map[String, String]): Configuration = {
+    sessionState(sparkSession).newHadoopConfWithOptions(options)
+  }
 }
 
 object FileFormat {
@@ -364,8 +378,7 @@ abstract class TextBasedFileFormat extends FileFormat {
       options: Map[String, String],
       path: Path): Boolean = {
     if (codecFactory == null) {
-      codecFactory = new CompressionCodecFactory(
-        sparkSession.sessionState.newHadoopConfWithOptions(options))
+      codecFactory = new CompressionCodecFactory(hadoopConf(sparkSession, options))
     }
     val codec = codecFactory.getCodec(path)
     codec == null || codec.isInstanceOf[SplittableCompressionCodec]
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
index 8aaeae3ae952f..c0ec531add357 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
@@ -43,23 +43,15 @@ case class CSVFileFormat() extends TextBasedFileFormat with DataSourceRegister {
       sparkSession: SparkSession,
       options: Map[String, String],
       path: Path): Boolean = {
-    val parsedOptions = new CSVOptions(
-      options,
-      columnPruning = sparkSession.sessionState.conf.csvColumnPruning,
-      sparkSession.sessionState.conf.sessionLocalTimeZone)
-    val csvDataSource = CSVDataSource(parsedOptions)
-    csvDataSource.isSplitable && super.isSplitable(sparkSession, options, path)
+    val parsedOptions = getCsvOptions(sparkSession, options)
+    CSVDataSource(parsedOptions).isSplitable && super.isSplitable(sparkSession, options, path)
   }
 
   override def inferSchema(
       sparkSession: SparkSession,
       options: Map[String, String],
       files: Seq[FileStatus]): Option[StructType] = {
-    val parsedOptions = new CSVOptions(
-      options,
-      columnPruning = sparkSession.sessionState.conf.csvColumnPruning,
-      sparkSession.sessionState.conf.sessionLocalTimeZone)
-
+    val parsedOptions = getCsvOptions(sparkSession, options)
     CSVDataSource(parsedOptions).inferSchema(sparkSession, files, parsedOptions)
   }
 
@@ -76,13 +68,9 @@ case class CSVFileFormat() extends TextBasedFileFormat with DataSourceRegister {
         throw QueryCompilationErrors.dataTypeUnsupportedByDataSourceError("CSV", field)
       }
     }
-    val conf = job.getConfiguration
-    val csvOptions = new CSVOptions(
-      options,
-      columnPruning = sparkSession.sessionState.conf.csvColumnPruning,
-      sparkSession.sessionState.conf.sessionLocalTimeZone)
-    csvOptions.compressionCodec.foreach { codec =>
-      CompressionCodecs.setCodecConfiguration(conf, codec)
+    val parsedOptions = getCsvOptions(sparkSession, options)
+    parsedOptions.compressionCodec.foreach { codec =>
+      CompressionCodecs.setCodecConfiguration(job.getConfiguration, codec)
     }
 
     new OutputWriterFactory {
@@ -90,11 +78,11 @@ case class CSVFileFormat() extends TextBasedFileFormat with DataSourceRegister {
           path: String,
           dataSchema: StructType,
           context: TaskAttemptContext): OutputWriter = {
-        new CsvOutputWriter(path, dataSchema, context, csvOptions)
+        new CsvOutputWriter(path, dataSchema, context, parsedOptions)
       }
 
       override def getFileExtension(context: TaskAttemptContext): String = {
-        "." + csvOptions.extension + CodecStreams.getCompressionExtension(context)
+        "." + parsedOptions.extension + CodecStreams.getCompressionExtension(context)
       }
     }
   }
@@ -109,11 +97,7 @@ case class CSVFileFormat() extends TextBasedFileFormat with DataSourceRegister {
       hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = {
     val broadcastedHadoopConf =
       SerializableConfiguration.broadcast(sparkSession.sparkContext, hadoopConf)
-    val parsedOptions = new CSVOptions(
-      options,
-      sparkSession.sessionState.conf.csvColumnPruning,
-      sparkSession.sessionState.conf.sessionLocalTimeZone,
-      sparkSession.sessionState.conf.columnNameOfCorruptRecord)
+    val parsedOptions = getCsvOptions(sparkSession, options)
 
     val isColumnPruningEnabled = parsedOptions.isColumnPruningEnabled(requiredSchema)
     // Check a field requirement for corrupt records here to throw an exception in a driver side
@@ -180,4 +164,15 @@ case class CSVFileFormat() extends TextBasedFileFormat with DataSourceRegister {
   }
 
   override def allowDuplicatedColumnNames: Boolean = true
+
+  private def getCsvOptions(
+      sparkSession: SparkSession,
+      options: Map[String, String]): CSVOptions = {
+    val conf = sqlConf(sparkSession)
+    new CSVOptions(
+      options,
+      conf.csvColumnPruning,
+      conf.sessionLocalTimeZone,
+      conf.columnNameOfCorruptRecord)
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
index ed096cf289b56..a165993b5af58 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
@@ -39,22 +39,15 @@ case class JsonFileFormat() extends TextBasedFileFormat with DataSourceRegister
       sparkSession: SparkSession,
       options: Map[String, String],
       path: Path): Boolean = {
-    val parsedOptions = new JSONOptionsInRead(
-      options,
-      sparkSession.sessionState.conf.sessionLocalTimeZone,
-      sparkSession.sessionState.conf.columnNameOfCorruptRecord)
-    val jsonDataSource = JsonDataSource(parsedOptions)
-    jsonDataSource.isSplitable && super.isSplitable(sparkSession, options, path)
+    val parsedOptions = getJsonOptions(sparkSession, options)
+    JsonDataSource(parsedOptions).isSplitable && super.isSplitable(sparkSession, options, path)
   }
 
   override def inferSchema(
       sparkSession: SparkSession,
       options: Map[String, String],
       files: Seq[FileStatus]): Option[StructType] = {
-    val parsedOptions = new JSONOptionsInRead(
-      options,
-      sparkSession.sessionState.conf.sessionLocalTimeZone,
-      sparkSession.sessionState.conf.columnNameOfCorruptRecord)
+    val parsedOptions = getJsonOptions(sparkSession, options)
     JsonDataSource(parsedOptions).inferSchema(sparkSession, files, parsedOptions)
   }
 
@@ -63,13 +56,9 @@ case class JsonFileFormat() extends TextBasedFileFormat with DataSourceRegister
       job: Job,
       options: Map[String, String],
       dataSchema: StructType): OutputWriterFactory = {
-    val conf = job.getConfiguration
-    val parsedOptions = new JSONOptions(
-      options,
-      sparkSession.sessionState.conf.sessionLocalTimeZone,
-      sparkSession.sessionState.conf.columnNameOfCorruptRecord)
+    val parsedOptions = getJsonOptions(sparkSession, options, inRead = false)
     parsedOptions.compressionCodec.foreach { codec =>
-      CompressionCodecs.setCodecConfiguration(conf, codec)
+      CompressionCodecs.setCodecConfiguration(job.getConfiguration, codec)
     }
 
     new OutputWriterFactory {
@@ -96,12 +85,7 @@ case class JsonFileFormat() extends TextBasedFileFormat with DataSourceRegister
       hadoopConf: Configuration): PartitionedFile => Iterator[InternalRow] = {
     val broadcastedHadoopConf =
       SerializableConfiguration.broadcast(sparkSession.sparkContext, hadoopConf)
-
-    val parsedOptions = new JSONOptionsInRead(
-      options,
-      sparkSession.sessionState.conf.sessionLocalTimeZone,
-      sparkSession.sessionState.conf.columnNameOfCorruptRecord)
-
+    val parsedOptions = getJsonOptions(sparkSession, options)
     val actualSchema =
       StructType(requiredSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord))
     // Check a field requirement for corrupt records here to throw an exception in a driver side
@@ -147,4 +131,16 @@ case class JsonFileFormat() extends TextBasedFileFormat with DataSourceRegister
     case _ => false
   }
+
+  private def getJsonOptions(
+      spark: SparkSession,
+      options: Map[String, String],
+      inRead: Boolean = true): JSONOptions = {
+    val conf = sqlConf(spark)
+    if (inRead) {
+      new JSONOptionsInRead(options, conf.sessionLocalTimeZone, conf.columnNameOfCorruptRecord)
+    } else {
+      new JSONOptions(options, conf.sessionLocalTimeZone, conf.columnNameOfCorruptRecord)
+    }
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlFileFormat.scala
index 5072a87af4df7..5038c55e046ed 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlFileFormat.scala
@@ -39,13 +39,11 @@ case class XmlFileFormat() extends TextBasedFileFormat with DataSourceRegister {
 
   override def shortName(): String = "xml"
 
-  def getXmlOptions(
+  private def getXmlOptions(
       sparkSession: SparkSession,
       parameters: Map[String, String]): XmlOptions = {
-    new XmlOptions(parameters,
-      sparkSession.sessionState.conf.sessionLocalTimeZone,
-      sparkSession.sessionState.conf.columnNameOfCorruptRecord,
-      true)
+    val conf = sqlConf(sparkSession)
+    new XmlOptions(parameters, conf.sessionLocalTimeZone, conf.columnNameOfCorruptRecord, true)
   }
 
   override def isSplitable(
@@ -53,8 +51,7 @@ case class XmlFileFormat() extends TextBasedFileFormat with DataSourceRegister {
       options: Map[String, String],
       path: Path): Boolean = {
     val xmlOptions = getXmlOptions(sparkSession, options)
-    val xmlDataSource = XmlDataSource(xmlOptions)
-    xmlDataSource.isSplitable && super.isSplitable(sparkSession, options, path)
+    XmlDataSource(xmlOptions).isSplitable && super.isSplitable(sparkSession, options, path)
   }
 
   override def inferSchema(