diff --git a/src/main/scala/net/snowflake/spark/snowflake/Conversions.scala b/src/main/scala/net/snowflake/spark/snowflake/Conversions.scala
index 00f5e0d7..198060b4 100644
--- a/src/main/scala/net/snowflake/spark/snowflake/Conversions.scala
+++ b/src/main/scala/net/snowflake/spark/snowflake/Conversions.scala
@@ -19,10 +19,9 @@ package net.snowflake.spark.snowflake
 
 import java.sql.Timestamp
 import java.text._
-import java.time.ZonedDateTime
+import java.time.{LocalDateTime, ZonedDateTime}
 import java.time.format.DateTimeFormatter
 import java.util.{Date, TimeZone}
-
 import net.snowflake.client.jdbc.internal.fasterxml.jackson.databind.JsonNode
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.InternalRow
@@ -126,6 +125,15 @@ private[snowflake] object Conversions {
         TimeZone.getDefault.toZoneId))
   }
 
+  def formatTimestamp(t: LocalDateTime): String = {
+    // For writing to Snowflake, the time zone needs to be included
+    // in the timestamp string. The Spark default time zone is used.
+    timestampWriteFormatter.format(
+      ZonedDateTime.of(
+        t,
+        TimeZone.getDefault.toZoneId))
+  }
+
   // All strings are converted into double-quoted strings, with
   // quote inside converted to double quotes
   def formatString(s: String): String = {
@@ -176,7 +184,7 @@ private[snowflake] object Conversions {
       case ShortType => data.toShort
       case StringType =>
         if (isIR) UTF8String.fromString(data) else data
-      case TimestampType => parseTimestamp(data, isIR)
+      case TimestampType | TimestampNTZType => parseTimestamp(data, isIR)
       case _ => data
     }
   }
@@ -276,7 +284,7 @@ private[snowflake] object Conversions {
       case ShortType => data.shortValue()
       case StringType =>
         if (isIR) UTF8String.fromString(data.asText()) else data.asText()
-      case TimestampType => parseTimestamp(data.asText(), isIR)
+      case TimestampType | TimestampNTZType => parseTimestamp(data.asText(), isIR)
       case ArrayType(dt, _) =>
        val result = new Array[Any](data.size())
        (0 until data.size())
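Side note on the new `formatTimestamp(LocalDateTime)` overload: a minimal standalone sketch of its behavior. The pattern `yyyy-MM-dd HH:mm:ss.SSS Z` is an assumption for illustration; the real `timestampWriteFormatter` is defined elsewhere in Conversions.scala and is not shown in this patch.

```scala
import java.time.{LocalDateTime, ZonedDateTime}
import java.time.format.DateTimeFormatter
import java.util.TimeZone

object FormatTimestampSketch {
  // Hypothetical stand-in for Conversions.timestampWriteFormatter;
  // the actual pattern used by the connector may differ.
  private val timestampWriteFormatter =
    DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss.SSS Z")

  // Same shape as the new overload: an NTZ value carries no zone of its
  // own, so the JVM/Spark default zone is attached at write time.
  def formatTimestamp(t: LocalDateTime): String =
    timestampWriteFormatter.format(
      ZonedDateTime.of(t, TimeZone.getDefault.toZoneId))

  def main(args: Array[String]): Unit = {
    // e.g. "2015-07-01 00:00:00.001 +0000" when the default zone is UTC
    println(formatTimestamp(LocalDateTime.parse("2015-07-01T00:00:00.001")))
  }
}
```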
diff --git a/src/main/scala/net/snowflake/spark/snowflake/FilterPushdown.scala b/src/main/scala/net/snowflake/spark/snowflake/FilterPushdown.scala
index 8f769fd8..9097240d 100644
--- a/src/main/scala/net/snowflake/spark/snowflake/FilterPushdown.scala
+++ b/src/main/scala/net/snowflake/spark/snowflake/FilterPushdown.scala
@@ -69,7 +69,7 @@ private[snowflake] object FilterPushdown {
       )) !
     case DateType =>
       StringVariable(Option(value).map(_.asInstanceOf[Date].toString)) + "::DATE"
-    case TimestampType =>
+    case TimestampType | TimestampNTZType =>
       StringVariable(Option(value).map(_.asInstanceOf[Timestamp].toString)) + "::TIMESTAMP(3)"
     case _ =>
       value match {
diff --git a/src/main/scala/net/snowflake/spark/snowflake/SnowflakeJDBCWrapper.scala b/src/main/scala/net/snowflake/spark/snowflake/SnowflakeJDBCWrapper.scala
index cd7cef81..68b1ad9c 100644
--- a/src/main/scala/net/snowflake/spark/snowflake/SnowflakeJDBCWrapper.scala
+++ b/src/main/scala/net/snowflake/spark/snowflake/SnowflakeJDBCWrapper.scala
@@ -168,6 +168,7 @@ private[snowflake] class JDBCWrapper {
         "BINARY"
       }
     case TimestampType => "TIMESTAMP"
+    case TimestampNTZType => "TIMESTAMP_NTZ"
     case DateType => "DATE"
     case t: DecimalType => s"DECIMAL(${t.precision},${t.scale})"
     case _: StructType | _: ArrayType | _: MapType => "VARIANT"
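For context, how the new `TIMESTAMP_NTZ` branch plays out when a schema is turned into column DDL. This is an illustrative sketch only, not connector API; `NtzDdlSketch` and `typeToSnowflake` are hypothetical names mirroring the `JDBCWrapper` match above.

```scala
import org.apache.spark.sql.types._

object NtzDdlSketch {
  // Mirrors the mapping in JDBCWrapper, trimmed to a few cases.
  def typeToSnowflake(dt: DataType): String = dt match {
    case TimestampType    => "TIMESTAMP"     // alias; semantics follow Snowflake's TIMESTAMP_TYPE_MAPPING
    case TimestampNTZType => "TIMESTAMP_NTZ" // wall-clock time, no time zone attached
    case DateType         => "DATE"
    case StringType       => "STRING"
    case _                => "VARIANT"       // simplified fallback for this sketch
  }

  def main(args: Array[String]): Unit = {
    val schema = StructType(Seq(
      StructField("ts", TimestampType),
      StructField("ts_ntz", TimestampNTZType)))
    // Prints: "TS" TIMESTAMP, "TS_NTZ" TIMESTAMP_NTZ
    println(schema.fields
      .map(f => s""""${f.name.toUpperCase}" ${typeToSnowflake(f.dataType)}""")
      .mkString(", "))
  }
}
```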
diff --git a/src/main/scala/net/snowflake/spark/snowflake/SnowflakeWriter.scala b/src/main/scala/net/snowflake/spark/snowflake/SnowflakeWriter.scala
index 5257ac79..ba73ca16 100644
--- a/src/main/scala/net/snowflake/spark/snowflake/SnowflakeWriter.scala
+++ b/src/main/scala/net/snowflake/spark/snowflake/SnowflakeWriter.scala
@@ -19,6 +19,7 @@ package net.snowflake.spark.snowflake
 import java.sql.{Date, Timestamp}
 
 import net.snowflake.client.jdbc.internal.apache.commons.codec.binary.Base64
+import net.snowflake.spark.snowflake.Conversions.timestampWriteFormatter
 import net.snowflake.spark.snowflake.DefaultJDBCWrapper.{snowflakeStyleSchema, snowflakeStyleString}
 import net.snowflake.spark.snowflake.Parameters.MergedParameters
 import net.snowflake.spark.snowflake.io.SupportedFormat
@@ -27,6 +28,9 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.types._
 import org.apache.spark.sql._
 
+import java.time.{LocalDateTime, ZonedDateTime}
+import java.util.TimeZone
+
 /**
   * Functions to write data to Snowflake.
   *
@@ -285,6 +289,11 @@ private[snowflake] class SnowflakeWriter(jdbcWrapper: JDBCWrapper) {
           if (v == null) ""
           else Conversions.formatTimestamp(v.asInstanceOf[Timestamp])
         }
+      case TimestampNTZType =>
+        (v: Any) => {
+          if (v == null) ""
+          else Conversions.formatTimestamp(v.asInstanceOf[LocalDateTime])
+        }
       case StringType =>
         (v: Any) => {
diff --git a/src/main/scala/net/snowflake/spark/snowflake/io/ParquetUtils.scala b/src/main/scala/net/snowflake/spark/snowflake/io/ParquetUtils.scala
index b860ec69..e8522ab1 100644
--- a/src/main/scala/net/snowflake/spark/snowflake/io/ParquetUtils.scala
+++ b/src/main/scala/net/snowflake/spark/snowflake/io/ParquetUtils.scala
@@ -93,7 +93,7 @@ object ParquetUtils {
         builder.stringBuilder()
           .prop("logicalType", "date")
           .endString()
-      case TimestampType =>
+      case TimestampType | TimestampNTZType =>
         builder.stringBuilder()
           .prop("logicalType", " timestamp-micros")
           .endString()
diff --git a/src/main/scala/net/snowflake/spark/snowflake/io/SnowflakeResultSetRDD.scala b/src/main/scala/net/snowflake/spark/snowflake/io/SnowflakeResultSetRDD.scala
index 9ca177b8..f1f35e42 100644
--- a/src/main/scala/net/snowflake/spark/snowflake/io/SnowflakeResultSetRDD.scala
+++ b/src/main/scala/net/snowflake/spark/snowflake/io/SnowflakeResultSetRDD.scala
@@ -236,7 +236,7 @@ case class ResultIterator[T: ClassTag](
       case FloatType => data.getFloat(index + 1)
       case IntegerType => data.getInt(index + 1)
       case LongType => data.getLong(index + 1)
-      case TimestampType =>
+      case TimestampType | TimestampNTZType =>
        if (isIR) {
          DateTimeUtils.fromJavaTimestamp(data.getTimestamp(index + 1))
        } else {
diff --git a/src/main/scala/net/snowflake/spark/snowflake/io/StageWriter.scala b/src/main/scala/net/snowflake/spark/snowflake/io/StageWriter.scala
index 2d8dc0cb..b2da845a 100644
--- a/src/main/scala/net/snowflake/spark/snowflake/io/StageWriter.scala
+++ b/src/main/scala/net/snowflake/spark/snowflake/io/StageWriter.scala
@@ -25,7 +25,7 @@ import net.snowflake.spark.snowflake.io.SupportedFormat.SupportedFormat
 import net.snowflake.spark.snowflake.DefaultJDBCWrapper.DataBaseOperations
 import net.snowflake.spark.snowflake.test.{TestHook, TestHookFlag}
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.types.{BinaryType, StructType, TimestampType}
+import org.apache.spark.sql.types.{BinaryType, StructType, TimestampNTZType, TimestampType}
 import org.apache.spark.sql.{SQLContext, SaveMode}
 import org.slf4j.LoggerFactory
 
@@ -925,7 +925,7 @@ private[io] object StageWriter {
     val mappingFromString = getMappingFromString(mappingList, fromString)
 
     val hasTimestampColumn: Boolean =
-      schema.exists(field => field.dataType == TimestampType)
+      schema.exists(field => field.dataType == TimestampType || field.dataType == TimestampNTZType)
 
     val timestampFormat: String =
       if (params.getStringTimestampFormat.isEmpty) {
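Putting the write path together, a hedged end-to-end usage sketch. It assumes a Spark build whose `spark.implicits._` can encode `java.time.LocalDateTime` as `TimestampNTZType` (Spark 3.4+); all connection option values are placeholders, and the table name `NTZ_TEST` is made up for illustration.

```scala
import java.time.LocalDateTime
import org.apache.spark.sql.SparkSession

object NtzWriteSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("ntz-write-sketch").getOrCreate()
    import spark.implicits._

    // A LocalDateTime column becomes TimestampNTZType, which this patch
    // maps to a TIMESTAMP_NTZ column on the Snowflake side.
    val df = Seq(LocalDateTime.parse("2015-07-01T00:00:00.001")).toDF("ts_ntz")

    df.write
      .format("net.snowflake.spark.snowflake")
      .options(Map(
        "sfURL" -> "<account>.snowflakecomputing.com", // placeholder
        "sfUser" -> "<user>",                          // placeholder
        "sfPassword" -> "<password>",                  // placeholder
        "sfDatabase" -> "<database>",                  // placeholder
        "sfSchema" -> "<schema>",                      // placeholder
        "dbtable" -> "NTZ_TEST"))                      // hypothetical table
      .mode("append")
      .save()
  }
}
```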
diff --git a/src/test/scala/net/snowflake/spark/snowflake/ConversionsSuite.scala b/src/test/scala/net/snowflake/spark/snowflake/ConversionsSuite.scala
index bb451553..96eb840e 100644
--- a/src/test/scala/net/snowflake/spark/snowflake/ConversionsSuite.scala
+++ b/src/test/scala/net/snowflake/spark/snowflake/ConversionsSuite.scala
@@ -144,6 +144,7 @@ class ConversionsSuite extends FunSuite {
         |"short":123,
         |"string": "test string",
         |"timestamp": "2015-07-01 00:00:00.001",
+        |"timestamp_ntz": "2015-07-01 00:00:00.001",
         |"array":[1,2,3,4,5],
         |"map":{"a":1,"b":2,"c":3},
         |"structure":{"num":123,"str":"str1"}
@@ -163,6 +164,7 @@ class ConversionsSuite extends FunSuite {
         StructField("short", ShortType, nullable = false),
         StructField("string", StringType, nullable = false),
         StructField("timestamp", TimestampType, nullable = false),
+        StructField("timestamp_ntz", TimestampNTZType, nullable = false),
         StructField("array", ArrayType(IntegerType), nullable = false),
         StructField("map", MapType(StringType, IntegerType), nullable = false),
         StructField(
@@ -199,9 +201,10 @@ class ConversionsSuite extends FunSuite {
     assert(result.getShort(8) == 123.toShort)
     assert(result.getString(9) == "test string")
     assert(result.getTimestamp(10) == Timestamp.valueOf("2015-07-01 00:00:00.001"))
-    assert(result.getSeq(11) sameElements Array(1, 2, 3, 4, 5))
-    assert(result.getMap(12) == Map("b" -> 2, "a" -> 1, "c" -> 3))
-    assert(result.getStruct(13) == Row(123, "str1"))
+    assert(result.getTimestamp(11) == Timestamp.valueOf("2015-07-01 00:00:00.001"))
+    assert(result.getSeq(12) sameElements Array(1, 2, 3, 4, 5))
+    assert(result.getMap(13) == Map("b" -> 2, "a" -> 1, "c" -> 3))
+    assert(result.getStruct(14) == Row(123, "str1"))
   }
 }
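The reshuffled assertions above hinge on one invariant: the `timestamp` and `timestamp_ntz` JSON fields carry the same wall-clock string, so after the shared `parseTimestamp` path both columns materialize as equal `java.sql.Timestamp` values. A trivial standalone sketch of that invariant:

```scala
import java.sql.Timestamp

object NtzRoundTripSketch {
  def main(args: Array[String]): Unit = {
    // Both fields in the test JSON hold this same wall-clock string.
    val zoned = Timestamp.valueOf("2015-07-01 00:00:00.001")
    val ntz   = Timestamp.valueOf("2015-07-01 00:00:00.001")
    assert(zoned == ntz, "NTZ column should parse to the same value as the zoned column")
    println(s"both columns parse to $zoned")
  }
}
```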