From 7e7ac8bc45932ca783ed71b4955d9c7f1fe9b6f4 Mon Sep 17 00:00:00 2001
From: Kent Yao
Date: Wed, 9 Jul 2025 11:37:57 +0800
Subject: [PATCH] [SPARK-52704][SQL][FOLLOWUP] Move session state utilities
 from trait FileFormat to SessionStateHelper

---
 .../execution/datasources/FileFormat.scala   | 20 ++-----
 .../datasources/csv/CSVFileFormat.scala      |  2 +-
 .../datasources/json/JsonFileFormat.scala    |  2 +-
 .../execution/datasources/v2/FileScan.scala  | 11 ++--
 .../datasources/xml/XmlFileFormat.scala      |  2 +-
 .../sql/internal/SessionStateHelper.scala    | 53 +++++++++++++++++++
 6 files changed, 63 insertions(+), 27 deletions(-)
 create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/internal/SessionStateHelper.scala

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala
index 05c939bed5f8b..8a254b464da71 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala
@@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
 import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
 import org.apache.spark.sql.errors.QueryExecutionErrors
-import org.apache.spark.sql.internal.{SessionState, SQLConf}
+import org.apache.spark.sql.internal.{SessionStateHelper, SQLConf}
 import org.apache.spark.sql.sources.Filter
 import org.apache.spark.sql.types._
 
@@ -235,20 +235,6 @@
    */
   def fileConstantMetadataExtractors: Map[String, PartitionedFile => Any] =
     FileFormat.BASE_METADATA_EXTRACTORS
-
-  protected def sessionState(sparkSession: SparkSession): SessionState = {
-    sparkSession.sessionState
-  }
-
-  protected def sqlConf(sparkSession: SparkSession): SQLConf = {
-    sessionState(sparkSession).conf
-  }
-
-  protected def hadoopConf(
-      sparkSession: SparkSession,
-      options: Map[String, String]): Configuration = {
-    sessionState(sparkSession).newHadoopConfWithOptions(options)
-  }
 }
 
 object FileFormat {
@@ -370,7 +356,7 @@
 /**
  * The base class file format that is based on text file.
 */
-abstract class TextBasedFileFormat extends FileFormat {
+abstract class TextBasedFileFormat extends FileFormat with SessionStateHelper {
   private var codecFactory: CompressionCodecFactory = _
 
   override def isSplitable(
@@ -378,7 +364,7 @@ abstract class TextBasedFileFormat extends FileFormat {
       options: Map[String, String],
       path: Path): Boolean = {
     if (codecFactory == null) {
-      codecFactory = new CompressionCodecFactory(hadoopConf(sparkSession, options))
+      codecFactory = new CompressionCodecFactory(getHadoopConf(sparkSession, options))
     }
     val codec = codecFactory.getCodec(path)
     codec == null || codec.isInstanceOf[SplittableCompressionCodec]
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
index c0ec531add357..bf189268b4d6a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
@@ -168,7 +168,7 @@ case class CSVFileFormat() extends TextBasedFileFormat with DataSourceRegister {
   private def getCsvOptions(
       sparkSession: SparkSession,
       options: Map[String, String]): CSVOptions = {
-    val conf = sqlConf(sparkSession)
+    val conf = getSqlConf(sparkSession)
     new CSVOptions(
       options,
       conf.csvColumnPruning,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
index a165993b5af58..e3b78ef432505 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
@@ -136,7 +136,7 @@ case class JsonFileFormat() extends TextBasedFileFormat with DataSourceRegister
       spark: SparkSession,
       options: Map[String, String],
       inRead: Boolean = true): JSONOptions = {
-    val conf = sqlConf(spark)
+    val conf = getSqlConf(spark)
     if (inRead) {
       new JSONOptionsInRead(options, conf.sessionLocalTimeZone, conf.columnNameOfCorruptRecord)
     } else {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala
index 9ace0540ec243..32fe21701af67 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala
@@ -30,11 +30,11 @@ import org.apache.spark.sql.catalyst.expressions.{AttributeSet, Expression, Expr
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
 import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
-import org.apache.spark.sql.connector.read.{Batch, InputPartition, Scan, Statistics, SupportsReportStatistics}
+import org.apache.spark.sql.connector.read._
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.execution.PartitionedFileUtil
 import org.apache.spark.sql.execution.datasources._
-import org.apache.spark.sql.internal.{SessionState, SQLConf}
+import org.apache.spark.sql.internal.{SessionStateHelper, SQLConf}
 import org.apache.spark.sql.internal.connector.SupportsMetadata
 import org.apache.spark.sql.sources.Filter
 import org.apache.spark.sql.types.StructType
@@ -113,10 +113,7 @@ trait FileScan extends Scan
 
   override def hashCode(): Int = getClass.hashCode()
 
-  override def conf: SQLConf = {
-    val sessionState: SessionState = sparkSession.sessionState
-    sessionState.conf
-  }
+  override def conf: SQLConf = SessionStateHelper.getSqlConf(sparkSession)
 
   val maxMetadataValueLength = conf.maxMetadataStringLength
 
@@ -177,7 +174,7 @@ trait FileScan extends Scan
     if (splitFiles.length == 1) {
       val path = splitFiles(0).toPath
       if (!isSplitable(path) && splitFiles(0).length >
-        sparkSession.sparkContext.conf.get(IO_WARNING_LARGEFILETHRESHOLD)) {
+        SessionStateHelper.getSparkConf(sparkSession).get(IO_WARNING_LARGEFILETHRESHOLD)) {
         logWarning(log"Loading one large unsplittable file ${MDC(PATH, path.toString)} with only " +
           log"one partition, the reason is: ${MDC(REASON, getFileUnSplittableReason(path))}")
       }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlFileFormat.scala
index 5038c55e046ed..06151c5fa4d64 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlFileFormat.scala
@@ -42,7 +42,7 @@ case class XmlFileFormat() extends TextBasedFileFormat with DataSourceRegister {
   private def getXmlOptions(
       sparkSession: SparkSession,
       parameters: Map[String, String]): XmlOptions = {
-    val conf = sqlConf(sparkSession)
+    val conf = getSqlConf(sparkSession)
     new XmlOptions(parameters, conf.sessionLocalTimeZone, conf.columnNameOfCorruptRecord, true)
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionStateHelper.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionStateHelper.scala
new file mode 100644
index 0000000000000..6279f8c123765
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionStateHelper.scala
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.internal
+
+import org.apache.hadoop.conf.Configuration
+
+import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.sql.SparkSession
+
+/**
+ * Helper trait to access session state related configurations and utilities.
+ * It also provides type annotations for IDEs to build indexes.
+ */
+trait SessionStateHelper {
+  private def sessionState(sparkSession: SparkSession): SessionState = {
+    sparkSession.sessionState
+  }
+
+  private def sparkContext(sparkSession: SparkSession): SparkContext = {
+    sparkSession.sparkContext
+  }
+
+  def getSparkConf(sparkSession: SparkSession): SparkConf = {
+    sparkContext(sparkSession).conf
+  }
+
+  def getSqlConf(sparkSession: SparkSession): SQLConf = {
+    sessionState(sparkSession).conf
+  }
+
+  def getHadoopConf(
+      sparkSession: SparkSession,
+      options: Map[String, String]): Configuration = {
+    sessionState(sparkSession).newHadoopConfWithOptions(options)
+  }
+}
+
+object SessionStateHelper extends SessionStateHelper
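
Reviewer note, not part of the patch: a minimal usage sketch of the relocated helpers, mirroring how CSVFileFormat mixes in the trait and how FileScan calls the companion object above. The object name ExampleUsage and method resolveConfs are hypothetical; getSqlConf, getHadoopConf, and getSparkConf are the methods introduced in SessionStateHelper.scala by this patch.

import org.apache.hadoop.conf.Configuration

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.internal.SessionStateHelper

// Hypothetical caller: mix in the trait (as TextBasedFileFormat now does) to read
// per-session SQL settings and to build a Hadoop conf with per-relation options applied.
object ExampleUsage extends SessionStateHelper {
  def resolveConfs(
      spark: SparkSession,
      options: Map[String, String]): (String, Configuration) = {
    // Session-local time zone from SQLConf, as the CSV/JSON/XML option builders use it.
    val timeZone = getSqlConf(spark).sessionLocalTimeZone
    // Hadoop configuration derived from the session state, with `options` layered on top.
    val hadoopConf = getHadoopConf(spark, options)
    (timeZone, hadoopConf)
  }
}

// Callers that do not mix in the trait can go through the companion object, as FileScan does:
//   SessionStateHelper.getSqlConf(sparkSession).maxMetadataStringLength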