Skip to content

Commit b92b10c

Browse files
Infer missing arrow schema (#233)
* Refactor: pass the whole metadata object to result handlers
  Signed-off-by: Levko Kravets <levko.ne@gmail.com>
* Infer Arrow schema when it is not available
  Signed-off-by: Levko Kravets <levko.ne@gmail.com>
---------
Signed-off-by: Levko Kravets <levko.ne@gmail.com>
1 parent ff9fc0d commit b92b10c

12 files changed: +186 −52 lines changed

lib/DBSQLOperation.ts

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -372,20 +372,20 @@ export default class DBSQLOperation implements IOperation {
372372

373373
switch (resultFormat) {
374374
case TSparkRowSetType.COLUMN_BASED_SET:
375-
resultSource = new JsonResultHandler(this.context, this._data, metadata.schema);
375+
resultSource = new JsonResultHandler(this.context, this._data, metadata);
376376
break;
377377
case TSparkRowSetType.ARROW_BASED_SET:
378378
resultSource = new ArrowResultConverter(
379379
this.context,
380-
new ArrowResultHandler(this.context, this._data, metadata.arrowSchema, metadata.lz4Compressed),
381-
metadata.schema,
380+
new ArrowResultHandler(this.context, this._data, metadata),
381+
metadata,
382382
);
383383
break;
384384
case TSparkRowSetType.URL_BASED_SET:
385385
resultSource = new ArrowResultConverter(
386386
this.context,
387-
new CloudFetchResultHandler(this.context, this._data, metadata.lz4Compressed),
388-
metadata.schema,
387+
new CloudFetchResultHandler(this.context, this._data, metadata),
388+
metadata,
389389
);
390390
break;
391391
// no default

lib/result/ArrowResultConverter.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ import {
1313
RecordBatchReader,
1414
util as arrowUtils,
1515
} from 'apache-arrow';
16-
import { TTableSchema, TColumnDesc } from '../../thrift/TCLIService_types';
16+
import { TGetResultSetMetadataResp, TColumnDesc } from '../../thrift/TCLIService_types';
1717
import IClientContext from '../contracts/IClientContext';
1818
import IResultsProvider, { ResultsProviderFetchNextOptions } from './IResultsProvider';
1919
import { getSchemaColumns, convertThriftValue } from './utils';
@@ -34,7 +34,7 @@ export default class ArrowResultConverter implements IResultsProvider<Array<any>
3434

3535
private pendingRecordBatch?: RecordBatch<TypeMap>;
3636

37-
constructor(context: IClientContext, source: IResultsProvider<Array<Buffer>>, schema?: TTableSchema) {
37+
constructor(context: IClientContext, source: IResultsProvider<Array<Buffer>>, { schema }: TGetResultSetMetadataResp) {
3838
this.context = context;
3939
this.source = source;
4040
this.schema = getSchemaColumns(schema);

lib/result/ArrowResultHandler.ts

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import LZ4 from 'lz4';
2-
import { TRowSet } from '../../thrift/TCLIService_types';
2+
import { TGetResultSetMetadataResp, TRowSet } from '../../thrift/TCLIService_types';
33
import IClientContext from '../contracts/IClientContext';
44
import IResultsProvider, { ResultsProviderFetchNextOptions } from './IResultsProvider';
5+
import { hiveSchemaToArrowSchema } from './utils';
56

67
export default class ArrowResultHandler implements IResultsProvider<Array<Buffer>> {
78
protected readonly context: IClientContext;
@@ -15,13 +16,14 @@ export default class ArrowResultHandler implements IResultsProvider<Array<Buffer
1516
constructor(
1617
context: IClientContext,
1718
source: IResultsProvider<TRowSet | undefined>,
18-
arrowSchema?: Buffer,
19-
isLZ4Compressed?: boolean,
19+
{ schema, arrowSchema, lz4Compressed }: TGetResultSetMetadataResp,
2020
) {
2121
this.context = context;
2222
this.source = source;
23-
this.arrowSchema = arrowSchema;
24-
this.isLZ4Compressed = isLZ4Compressed ?? false;
23+
// Arrow schema is not available in old DBR versions, which also don't support native Arrow types,
24+
// so it's possible to infer Arrow schema from Hive schema ignoring `useArrowNativeTypes` option
25+
this.arrowSchema = arrowSchema ?? hiveSchemaToArrowSchema(schema);
26+
this.isLZ4Compressed = lz4Compressed ?? false;
2527
}
2628

2729
public async hasMore() {

lib/result/CloudFetchResultHandler.ts

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import LZ4 from 'lz4';
22
import fetch, { RequestInfo, RequestInit } from 'node-fetch';
3-
import { TRowSet, TSparkArrowResultLink } from '../../thrift/TCLIService_types';
3+
import { TGetResultSetMetadataResp, TRowSet, TSparkArrowResultLink } from '../../thrift/TCLIService_types';
44
import IClientContext from '../contracts/IClientContext';
55
import IResultsProvider, { ResultsProviderFetchNextOptions } from './IResultsProvider';
66

@@ -15,10 +15,14 @@ export default class CloudFetchResultHandler implements IResultsProvider<Array<B
1515

1616
private downloadTasks: Array<Promise<Buffer>> = [];
1717

18-
constructor(context: IClientContext, source: IResultsProvider<TRowSet | undefined>, isLZ4Compressed?: boolean) {
18+
constructor(
19+
context: IClientContext,
20+
source: IResultsProvider<TRowSet | undefined>,
21+
{ lz4Compressed }: TGetResultSetMetadataResp,
22+
) {
1923
this.context = context;
2024
this.source = source;
21-
this.isLZ4Compressed = isLZ4Compressed ?? false;
25+
this.isLZ4Compressed = lz4Compressed ?? false;
2226
}
2327

2428
public async hasMore() {

lib/result/JsonResultHandler.ts

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import { ColumnCode } from '../hive/Types';
2-
import { TRowSet, TTableSchema, TColumn, TColumnDesc } from '../../thrift/TCLIService_types';
2+
import { TGetResultSetMetadataResp, TRowSet, TColumn, TColumnDesc } from '../../thrift/TCLIService_types';
33
import IClientContext from '../contracts/IClientContext';
44
import IResultsProvider, { ResultsProviderFetchNextOptions } from './IResultsProvider';
55
import { getSchemaColumns, convertThriftValue } from './utils';
@@ -11,7 +11,11 @@ export default class JsonResultHandler implements IResultsProvider<Array<any>> {
1111

1212
private readonly schema: Array<TColumnDesc>;
1313

14-
constructor(context: IClientContext, source: IResultsProvider<TRowSet | undefined>, schema?: TTableSchema) {
14+
constructor(
15+
context: IClientContext,
16+
source: IResultsProvider<TRowSet | undefined>,
17+
{ schema }: TGetResultSetMetadataResp,
18+
) {
1519
this.context = context;
1620
this.source = source;
1721
this.schema = getSchemaColumns(schema);

lib/result/ResultSlicer.ts

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,11 +52,13 @@ export default class ResultSlicer<T> implements IResultsProvider<Array<T>> {
5252
// Fetch items from source results provider until we reach a requested count
5353
while (resultsCount < options.limit) {
5454
// eslint-disable-next-line no-await-in-loop
55-
const chunk = await this.source.fetchNext(options);
56-
if (chunk.length === 0) {
55+
const hasMore = await this.source.hasMore();
56+
if (!hasMore) {
5757
break;
5858
}
5959

60+
// eslint-disable-next-line no-await-in-loop
61+
const chunk = await this.source.fetchNext(options);
6062
result.push(chunk);
6163
resultsCount += chunk.length;
6264
}

lib/result/utils.ts

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,23 @@
11
import Int64 from 'node-int64';
2+
import {
3+
Schema,
4+
Field,
5+
DataType,
6+
Bool as ArrowBool,
7+
Int8 as ArrowInt8,
8+
Int16 as ArrowInt16,
9+
Int32 as ArrowInt32,
10+
Int64 as ArrowInt64,
11+
Float32 as ArrowFloat32,
12+
Float64 as ArrowFloat64,
13+
Utf8 as ArrowString,
14+
Date_ as ArrowDate,
15+
Binary as ArrowBinary,
16+
DateUnit,
17+
RecordBatchWriter,
18+
} from 'apache-arrow';
219
import { TTableSchema, TColumnDesc, TPrimitiveTypeEntry, TTypeId } from '../../thrift/TCLIService_types';
20+
import HiveDriverError from '../errors/HiveDriverError';
321

422
export function getSchemaColumns(schema?: TTableSchema): Array<TColumnDesc> {
523
if (!schema) {
@@ -73,3 +91,52 @@ export function convertThriftValue(typeDescriptor: TPrimitiveTypeEntry | undefin
7391
return value;
7492
}
7593
}
94+
95+
// This type map corresponds to Arrow without native types support (most complex types are serialized as strings)
96+
const hiveTypeToArrowType: Record<TTypeId, DataType | null> = {
97+
[TTypeId.BOOLEAN_TYPE]: new ArrowBool(),
98+
[TTypeId.TINYINT_TYPE]: new ArrowInt8(),
99+
[TTypeId.SMALLINT_TYPE]: new ArrowInt16(),
100+
[TTypeId.INT_TYPE]: new ArrowInt32(),
101+
[TTypeId.BIGINT_TYPE]: new ArrowInt64(),
102+
[TTypeId.FLOAT_TYPE]: new ArrowFloat32(),
103+
[TTypeId.DOUBLE_TYPE]: new ArrowFloat64(),
104+
[TTypeId.STRING_TYPE]: new ArrowString(),
105+
[TTypeId.TIMESTAMP_TYPE]: new ArrowString(),
106+
[TTypeId.BINARY_TYPE]: new ArrowBinary(),
107+
[TTypeId.ARRAY_TYPE]: new ArrowString(),
108+
[TTypeId.MAP_TYPE]: new ArrowString(),
109+
[TTypeId.STRUCT_TYPE]: new ArrowString(),
110+
[TTypeId.UNION_TYPE]: new ArrowString(),
111+
[TTypeId.USER_DEFINED_TYPE]: new ArrowString(),
112+
[TTypeId.DECIMAL_TYPE]: new ArrowString(),
113+
[TTypeId.NULL_TYPE]: null,
114+
[TTypeId.DATE_TYPE]: new ArrowDate(DateUnit.DAY),
115+
[TTypeId.VARCHAR_TYPE]: new ArrowString(),
116+
[TTypeId.CHAR_TYPE]: new ArrowString(),
117+
[TTypeId.INTERVAL_YEAR_MONTH_TYPE]: new ArrowString(),
118+
[TTypeId.INTERVAL_DAY_TIME_TYPE]: new ArrowString(),
119+
};
120+
121+
export function hiveSchemaToArrowSchema(schema?: TTableSchema): Buffer | undefined {
122+
if (!schema) {
123+
return undefined;
124+
}
125+
126+
const columns = getSchemaColumns(schema);
127+
128+
const arrowFields = columns.map((column) => {
129+
const hiveType = column.typeDesc.types[0].primitiveEntry?.type ?? undefined;
130+
const arrowType = hiveType !== undefined ? hiveTypeToArrowType[hiveType] : undefined;
131+
if (!arrowType) {
132+
throw new HiveDriverError(`Unsupported column type: ${hiveType ? TTypeId[hiveType] : 'undefined'}`);
133+
}
134+
return new Field(column.columnName, arrowType, true);
135+
});
136+
137+
const arrowSchema = new Schema(arrowFields);
138+
const writer = new RecordBatchWriter();
139+
writer.reset(undefined, arrowSchema);
140+
writer.finish();
141+
return Buffer.from(writer.toUint8Array(true));
142+
}

tests/unit/result/ArrowResultConverter.test.js

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,30 +57,30 @@ describe('ArrowResultHandler', () => {
5757
it('should convert data', async () => {
5858
const context = {};
5959
const rowSetProvider = new ResultsProviderMock([sampleArrowBatch]);
60-
const result = new ArrowResultConverter(context, rowSetProvider, sampleThriftSchema);
60+
const result = new ArrowResultConverter(context, rowSetProvider, { schema: sampleThriftSchema });
6161
expect(await result.fetchNext({ limit: 10000 })).to.be.deep.eq([{ 1: 1 }]);
6262
});
6363

6464
it('should return empty array if no data to process', async () => {
6565
const context = {};
6666
const rowSetProvider = new ResultsProviderMock([], []);
67-
const result = new ArrowResultConverter(context, rowSetProvider, sampleThriftSchema);
67+
const result = new ArrowResultConverter(context, rowSetProvider, { schema: sampleThriftSchema });
6868
expect(await result.fetchNext({ limit: 10000 })).to.be.deep.eq([]);
6969
expect(await result.hasMore()).to.be.false;
7070
});
7171

7272
it('should return empty array if no schema available', async () => {
7373
const context = {};
7474
const rowSetProvider = new ResultsProviderMock([sampleArrowBatch]);
75-
const result = new ArrowResultConverter(context, rowSetProvider);
75+
const result = new ArrowResultConverter(context, rowSetProvider, {});
7676
expect(await result.hasMore()).to.be.false;
7777
expect(await result.fetchNext({ limit: 10000 })).to.be.deep.eq([]);
7878
});
7979

8080
it('should detect nulls', async () => {
8181
const context = {};
8282
const rowSetProvider = new ResultsProviderMock([arrowBatchAllNulls]);
83-
const result = new ArrowResultConverter(context, rowSetProvider, thriftSchemaAllNulls);
83+
const result = new ArrowResultConverter(context, rowSetProvider, { schema: thriftSchemaAllNulls });
8484
expect(await result.fetchNext({ limit: 10000 })).to.be.deep.eq([
8585
{
8686
boolean_field: null,

tests/unit/result/ArrowResultHandler.test.js

Lines changed: 38 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ describe('ArrowResultHandler', () => {
6161
it('should return data', async () => {
6262
const context = {};
6363
const rowSetProvider = new ResultsProviderMock([sampleRowSet1]);
64-
const result = new ArrowResultHandler(context, rowSetProvider, sampleArrowSchema);
64+
const result = new ArrowResultHandler(context, rowSetProvider, { arrowSchema: sampleArrowSchema });
6565

6666
const batches = await result.fetchNext({ limit: 10000 });
6767
expect(await rowSetProvider.hasMore()).to.be.false;
@@ -74,7 +74,10 @@ describe('ArrowResultHandler', () => {
7474
it('should handle LZ4 compressed data', async () => {
7575
const context = {};
7676
const rowSetProvider = new ResultsProviderMock([sampleRowSet1LZ4Compressed]);
77-
const result = new ArrowResultHandler(context, rowSetProvider, sampleArrowSchema, true);
77+
const result = new ArrowResultHandler(context, rowSetProvider, {
78+
arrowSchema: sampleArrowSchema,
79+
lz4Compressed: true,
80+
});
7881

7982
const batches = await result.fetchNext({ limit: 10000 });
8083
expect(await rowSetProvider.hasMore()).to.be.false;
@@ -87,7 +90,7 @@ describe('ArrowResultHandler', () => {
8790
it('should not buffer any data', async () => {
8891
const context = {};
8992
const rowSetProvider = new ResultsProviderMock([sampleRowSet1]);
90-
const result = new ArrowResultHandler(context, rowSetProvider, sampleArrowSchema);
93+
const result = new ArrowResultHandler(context, rowSetProvider, { arrowSchema: sampleArrowSchema });
9194
expect(await rowSetProvider.hasMore()).to.be.true;
9295
expect(await result.hasMore()).to.be.true;
9396

@@ -100,34 +103,61 @@ describe('ArrowResultHandler', () => {
100103
const context = {};
101104
case1: {
102105
const rowSetProvider = new ResultsProviderMock();
103-
const result = new ArrowResultHandler(context, rowSetProvider, sampleArrowSchema);
106+
const result = new ArrowResultHandler(context, rowSetProvider, { arrowSchema: sampleArrowSchema });
104107
expect(await result.fetchNext({ limit: 10000 })).to.be.deep.eq([]);
105108
expect(await result.hasMore()).to.be.false;
106109
}
107110
case2: {
108111
const rowSetProvider = new ResultsProviderMock([sampleRowSet2]);
109-
const result = new ArrowResultHandler(context, rowSetProvider, sampleArrowSchema);
112+
const result = new ArrowResultHandler(context, rowSetProvider, { arrowSchema: sampleArrowSchema });
110113
expect(await result.fetchNext({ limit: 10000 })).to.be.deep.eq([]);
111114
expect(await result.hasMore()).to.be.false;
112115
}
113116
case3: {
114117
const rowSetProvider = new ResultsProviderMock([sampleRowSet3]);
115-
const result = new ArrowResultHandler(context, rowSetProvider, sampleArrowSchema);
118+
const result = new ArrowResultHandler(context, rowSetProvider, { arrowSchema: sampleArrowSchema });
116119
expect(await result.fetchNext({ limit: 10000 })).to.be.deep.eq([]);
117120
expect(await result.hasMore()).to.be.false;
118121
}
119122
case4: {
120123
const rowSetProvider = new ResultsProviderMock([sampleRowSet4]);
121-
const result = new ArrowResultHandler(context, rowSetProvider, sampleArrowSchema);
124+
const result = new ArrowResultHandler(context, rowSetProvider, { arrowSchema: sampleArrowSchema });
122125
expect(await result.fetchNext({ limit: 10000 })).to.be.deep.eq([]);
123126
expect(await result.hasMore()).to.be.false;
124127
}
125128
});
126129

130+
it('should infer arrow schema from thrift schema', async () => {
131+
const context = {};
132+
const rowSetProvider = new ResultsProviderMock([sampleRowSet2]);
133+
134+
const sampleThriftSchema = {
135+
columns: [
136+
{
137+
columnName: '1',
138+
typeDesc: {
139+
types: [
140+
{
141+
primitiveEntry: {
142+
type: 3,
143+
typeQualifiers: null,
144+
},
145+
},
146+
],
147+
},
148+
position: 1,
149+
},
150+
],
151+
};
152+
153+
const result = new ArrowResultHandler(context, rowSetProvider, { schema: sampleThriftSchema });
154+
expect(result.arrowSchema).to.not.be.undefined;
155+
});
156+
127157
it('should return empty array if no schema available', async () => {
128158
const context = {};
129159
const rowSetProvider = new ResultsProviderMock([sampleRowSet2]);
130-
const result = new ArrowResultHandler(context, rowSetProvider);
160+
const result = new ArrowResultHandler(context, rowSetProvider, {});
131161
expect(await result.fetchNext({ limit: 10000 })).to.be.deep.eq([]);
132162
expect(await result.hasMore()).to.be.false;
133163
});

tests/unit/result/CloudFetchResultHandler.test.js

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ describe('CloudFetchResultHandler', () => {
8686
getConfig: () => clientConfig,
8787
};
8888

89-
const result = new CloudFetchResultHandler(context, rowSetProvider);
89+
const result = new CloudFetchResultHandler(context, rowSetProvider, {});
9090

9191
case1: {
9292
result.pendingLinks = [];
@@ -119,7 +119,7 @@ describe('CloudFetchResultHandler', () => {
119119
getConfig: () => clientConfig,
120120
};
121121

122-
const result = new CloudFetchResultHandler(context, rowSetProvider);
122+
const result = new CloudFetchResultHandler(context, rowSetProvider, {});
123123

124124
sinon.stub(result, 'fetch').returns(
125125
Promise.resolve({
@@ -153,7 +153,7 @@ describe('CloudFetchResultHandler', () => {
153153
getConfig: () => clientConfig,
154154
};
155155

156-
const result = new CloudFetchResultHandler(context, rowSetProvider);
156+
const result = new CloudFetchResultHandler(context, rowSetProvider, {});
157157

158158
sinon.stub(result, 'fetch').returns(
159159
Promise.resolve({
@@ -213,7 +213,7 @@ describe('CloudFetchResultHandler', () => {
213213
getConfig: () => clientConfig,
214214
};
215215

216-
const result = new CloudFetchResultHandler(context, rowSetProvider, true);
216+
const result = new CloudFetchResultHandler(context, rowSetProvider, { lz4Compressed: true });
217217

218218
const expectedBatch = Buffer.concat([sampleArrowSchema, sampleArrowBatch]);
219219

@@ -244,7 +244,7 @@ describe('CloudFetchResultHandler', () => {
244244
getConfig: () => clientConfig,
245245
};
246246

247-
const result = new CloudFetchResultHandler(context, rowSetProvider);
247+
const result = new CloudFetchResultHandler(context, rowSetProvider, {});
248248

249249
sinon.stub(result, 'fetch').returns(
250250
Promise.resolve({
@@ -275,7 +275,7 @@ describe('CloudFetchResultHandler', () => {
275275
getConfig: () => clientConfig,
276276
};
277277

278-
const result = new CloudFetchResultHandler(context, rowSetProvider);
278+
const result = new CloudFetchResultHandler(context, rowSetProvider, {});
279279

280280
sinon.stub(result, 'fetch').returns(
281281
Promise.resolve({

0 commit comments

Comments (0)