|
21 | 21 |
|
22 | 22 | import org.apache.parquet.column.ColumnDescriptor;
|
23 | 23 | import org.apache.parquet.schema.LogicalTypeAnnotation;
|
24 |
| -import org.apache.parquet.schema.MessageType; |
25 | 24 | import org.apache.parquet.schema.PrimitiveType;
|
26 | 25 | import org.apache.parquet.schema.Type;
|
| 26 | +import org.apache.parquet.schema.Types; |
27 | 27 | import org.apache.spark.sql.types.*;
|
28 | 28 |
|
29 | 29 | import org.apache.comet.CometSchemaImporter;
|
@@ -290,15 +290,154 @@ public static ColumnDescriptor buildColumnDescriptor(ParquetColumnSpec columnSpe
|
290 | 290 | }
|
291 | 291 |
|
292 | 292 | String name = columnSpec.getPath()[columnSpec.getPath().length - 1];
|
| 293 | + // Reconstruct the logical type from parameters |
| 294 | + LogicalTypeAnnotation logicalType = null; |
| 295 | + if (columnSpec.getLogicalTypeName() != null) { |
| 296 | + logicalType = |
| 297 | + reconstructLogicalType( |
| 298 | + columnSpec.getLogicalTypeName(), columnSpec.getLogicalTypeParams()); |
| 299 | + } |
293 | 300 |
|
294 | 301 | PrimitiveType primitiveType;
|
295 | 302 | if (primType == PrimitiveType.PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY) {
|
296 |
| - primitiveType = new PrimitiveType(repetition, primType, columnSpec.getTypeLength(), name); |
| 303 | + primitiveType = |
| 304 | + Types.primitive(primType, repetition) |
| 305 | + .length(columnSpec.getTypeLength()) |
| 306 | + .as(logicalType) |
| 307 | + .id(columnSpec.getFieldId()) |
| 308 | + .named(name); |
297 | 309 | } else {
|
298 |
| - primitiveType = new PrimitiveType(repetition, primType, name); |
| 310 | + primitiveType = |
| 311 | + Types.primitive(primType, repetition) |
| 312 | + .as(logicalType) |
| 313 | + .id(columnSpec.getFieldId()) |
| 314 | + .named(name); |
299 | 315 | }
|
300 | 316 |
|
301 |
| - MessageType schema = new MessageType("root", primitiveType); |
302 |
| - return schema.getColumnDescription(columnSpec.getPath()); |
| 317 | + return new ColumnDescriptor( |
| 318 | + columnSpec.getPath(), |
| 319 | + primitiveType, |
| 320 | + columnSpec.getMaxRepetitionLevel(), |
| 321 | + columnSpec.getMaxDefinitionLevel()); |
| 322 | + } |
| 323 | + |
| 324 | + private static LogicalTypeAnnotation reconstructLogicalType( |
| 325 | + String logicalTypeName, java.util.Map<String, String> params) { |
| 326 | + |
| 327 | + switch (logicalTypeName) { |
| 328 | + // MAP |
| 329 | + case "MapLogicalTypeAnnotation": |
| 330 | + return LogicalTypeAnnotation.mapType(); |
| 331 | + |
| 332 | + // LIST |
| 333 | + case "ListLogicalTypeAnnotation": |
| 334 | + return LogicalTypeAnnotation.listType(); |
| 335 | + |
| 336 | + // STRING |
| 337 | + case "StringLogicalTypeAnnotation": |
| 338 | + return LogicalTypeAnnotation.stringType(); |
| 339 | + |
| 340 | + // MAP_KEY_VALUE |
| 341 | + case "MapKeyValueLogicalTypeAnnotation": |
| 342 | + return LogicalTypeAnnotation.MapKeyValueTypeAnnotation.getInstance(); |
| 343 | + |
| 344 | + // ENUM |
| 345 | + case "EnumLogicalTypeAnnotation": |
| 346 | + return LogicalTypeAnnotation.enumType(); |
| 347 | + |
| 348 | + // DECIMAL |
| 349 | + case "DecimalLogicalTypeAnnotation": |
| 350 | + if (!params.containsKey("scale") || !params.containsKey("precision")) { |
| 351 | + throw new IllegalArgumentException( |
| 352 | + "Missing required parameters for DecimalLogicalTypeAnnotation: " + params); |
| 353 | + } |
| 354 | + int scale = Integer.parseInt(params.get("scale")); |
| 355 | + int precision = Integer.parseInt(params.get("precision")); |
| 356 | + return LogicalTypeAnnotation.decimalType(scale, precision); |
| 357 | + |
| 358 | + // DATE |
| 359 | + case "DateLogicalTypeAnnotation": |
| 360 | + return LogicalTypeAnnotation.dateType(); |
| 361 | + |
| 362 | + // TIME |
| 363 | + case "TimeLogicalTypeAnnotation": |
| 364 | + if (!params.containsKey("isAdjustedToUTC") || !params.containsKey("unit")) { |
| 365 | + throw new IllegalArgumentException( |
| 366 | + "Missing required parameters for TimeLogicalTypeAnnotation: " + params); |
| 367 | + } |
| 368 | + |
| 369 | + boolean isUTC = Boolean.parseBoolean(params.get("isAdjustedToUTC")); |
| 370 | + String timeUnitStr = params.get("unit"); |
| 371 | + |
| 372 | + LogicalTypeAnnotation.TimeUnit timeUnit; |
| 373 | + switch (timeUnitStr) { |
| 374 | + case "MILLIS": |
| 375 | + timeUnit = LogicalTypeAnnotation.TimeUnit.MILLIS; |
| 376 | + break; |
| 377 | + case "MICROS": |
| 378 | + timeUnit = LogicalTypeAnnotation.TimeUnit.MICROS; |
| 379 | + break; |
| 380 | + case "NANOS": |
| 381 | + timeUnit = LogicalTypeAnnotation.TimeUnit.NANOS; |
| 382 | + break; |
| 383 | + default: |
| 384 | + throw new IllegalArgumentException("Unknown time unit: " + timeUnitStr); |
| 385 | + } |
| 386 | + return LogicalTypeAnnotation.timeType(isUTC, timeUnit); |
| 387 | + |
| 388 | + // TIMESTAMP |
| 389 | + case "TimestampLogicalTypeAnnotation": |
| 390 | + if (!params.containsKey("isAdjustedToUTC") || !params.containsKey("unit")) { |
| 391 | + throw new IllegalArgumentException( |
| 392 | + "Missing required parameters for TimestampLogicalTypeAnnotation: " + params); |
| 393 | + } |
| 394 | + boolean isAdjustedToUTC = Boolean.parseBoolean(params.get("isAdjustedToUTC")); |
| 395 | + String unitStr = params.get("unit"); |
| 396 | + |
| 397 | + LogicalTypeAnnotation.TimeUnit unit; |
| 398 | + switch (unitStr) { |
| 399 | + case "MILLIS": |
| 400 | + unit = LogicalTypeAnnotation.TimeUnit.MILLIS; |
| 401 | + break; |
| 402 | + case "MICROS": |
| 403 | + unit = LogicalTypeAnnotation.TimeUnit.MICROS; |
| 404 | + break; |
| 405 | + case "NANOS": |
| 406 | + unit = LogicalTypeAnnotation.TimeUnit.NANOS; |
| 407 | + break; |
| 408 | + default: |
| 409 | + throw new IllegalArgumentException("Unknown timestamp unit: " + unitStr); |
| 410 | + } |
| 411 | + return LogicalTypeAnnotation.timestampType(isAdjustedToUTC, unit); |
| 412 | + |
| 413 | + // INTEGER |
| 414 | + case "IntLogicalTypeAnnotation": |
| 415 | + if (!params.containsKey("isSigned") || !params.containsKey("bitWidth")) { |
| 416 | + throw new IllegalArgumentException( |
| 417 | + "Missing required parameters for IntLogicalTypeAnnotation: " + params); |
| 418 | + } |
| 419 | + boolean isSigned = Boolean.parseBoolean(params.get("isSigned")); |
| 420 | + int bitWidth = Integer.parseInt(params.get("bitWidth")); |
| 421 | + return LogicalTypeAnnotation.intType(bitWidth, isSigned); |
| 422 | + |
| 423 | + // JSON |
| 424 | + case "JsonLogicalTypeAnnotation": |
| 425 | + return LogicalTypeAnnotation.jsonType(); |
| 426 | + |
| 427 | + // BSON |
| 428 | + case "BsonLogicalTypeAnnotation": |
| 429 | + return LogicalTypeAnnotation.bsonType(); |
| 430 | + |
| 431 | + // UUID |
| 432 | + case "UUIDLogicalTypeAnnotation": |
| 433 | + return LogicalTypeAnnotation.uuidType(); |
| 434 | + |
| 435 | + // INTERVAL |
| 436 | + case "IntervalLogicalTypeAnnotation": |
| 437 | + return LogicalTypeAnnotation.IntervalLogicalTypeAnnotation.getInstance(); |
| 438 | + |
| 439 | + default: |
| 440 | + throw new IllegalArgumentException("Unknown logical type: " + logicalTypeName); |
| 441 | + } |
303 | 442 | }
|
304 | 443 | }
|
0 commit comments