Skip to content

Commit b7995af

Browse files
author
chengyitian
committed
AJ-703: support tensor;
1 parent e575c16 commit b7995af

File tree

5 files changed

+191
-1
lines changed

5 files changed

+191
-1
lines changed

src/com/xxdb/data/AbstractTensor.java

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
package com.xxdb.data;
2+
3+
public abstract class AbstractTensor extends AbstractEntity implements Tensor {
4+
5+
@Override
6+
public DATA_FORM getDataForm() {
7+
return DATA_FORM.DF_TENSOR;
8+
}
9+
10+
@Override
11+
public int columns() {
12+
return 1;
13+
}
14+
15+
}

src/com/xxdb/data/BasicEntityFactory.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,8 @@ else if(form == Entity.DATA_FORM.DF_SET)
7777
return new BasicSet(type, in);
7878
else if(form == Entity.DATA_FORM.DF_CHUNK)
7979
return new BasicChunkMeta(in);
80+
else if (form == Entity.DATA_FORM.DF_TENSOR)
81+
return new BasicTensor(type, in);
8082
else if(type == Entity.DATA_TYPE.DT_ANY && (form == Entity.DATA_FORM.DF_VECTOR || form == Entity.DATA_FORM.DF_PAIR))
8183
return new BasicAnyVector(in);
8284
else if(type.getValue() >= Entity.DATA_TYPE.DT_BOOL_ARRAY.getValue() && type.getValue() <= DT_DECIMAL128_ARRAY.getValue())

src/com/xxdb/data/BasicTensor.java

Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
package com.xxdb.data;
2+
3+
import com.xxdb.io.ExtendedDataInput;
4+
import com.xxdb.io.ExtendedDataOutput;
5+
import java.io.IOException;
6+
7+
public class BasicTensor extends AbstractTensor {
8+
9+
private DATA_TYPE dataType;
10+
private int tensorType;
11+
private int deviceType;
12+
private int tensorFlags;
13+
private int dimensions;
14+
15+
/**
16+
* shapes: shape[i] represents the size of the i-th dimension.
17+
*/
18+
private long[] shapes;
19+
20+
/**
21+
* strides: strides[i] represents the distance between elements in the i-th dimension.
22+
*/
23+
private long[] strides;
24+
25+
private long preserveValue;
26+
27+
private long elemCount;
28+
29+
private Vector data;
30+
31+
protected BasicTensor(DATA_TYPE dataType, ExtendedDataInput in) throws IOException {
32+
deserialize(dataType, in);
33+
}
34+
35+
protected void deserialize(DATA_TYPE dataType, ExtendedDataInput in) throws IOException {
36+
this.dataType = dataType;
37+
tensorType = in.readByte();
38+
deviceType = in.readByte();
39+
tensorFlags = in.readInt();
40+
dimensions = in.readInt();
41+
42+
shapes = new long[dimensions];
43+
strides = new long[dimensions];
44+
45+
for (int d = 0; d < dimensions; d++)
46+
shapes[d] = in.readLong();
47+
48+
for (int d = 0; d < dimensions; d++)
49+
strides[d] = in.readLong();
50+
51+
preserveValue = in.readLong();
52+
elemCount = in.readLong();
53+
54+
if (elemCount > Integer.MAX_VALUE)
55+
throw new RuntimeException("tensor element count more than 2,147,483,647(Integer.MAX_VALUE).");
56+
57+
Vector subVector = BasicEntityFactory.instance().createVectorWithDefaultValue(dataType, (int) elemCount, -1);
58+
subVector.deserialize(0, (int) elemCount, in);
59+
this.data = subVector;
60+
}
61+
62+
@Override
63+
public DATA_CATEGORY getDataCategory() {
64+
return getDataCategory(dataType);
65+
}
66+
67+
@Override
68+
public DATA_TYPE getDataType() {
69+
return dataType;
70+
}
71+
72+
@Override
73+
public int rows() {
74+
return data.rows();
75+
}
76+
77+
@Override
78+
public void write(ExtendedDataOutput output) throws IOException {
79+
throw new RuntimeException("BasicTensor not support write method.");
80+
}
81+
82+
public int getDimensions() {
83+
return dimensions;
84+
}
85+
86+
public long[] getShapes() {
87+
return shapes;
88+
}
89+
90+
public long[] getStrides() {
91+
return strides;
92+
}
93+
94+
public long getElemCount() {
95+
return elemCount;
96+
}
97+
98+
public Vector getData() {
99+
return data;
100+
}
101+
102+
@Override
103+
public String getString() {
104+
StringBuilder sb = new StringBuilder();
105+
sb.append("tensor<").append(getDataTypeString());;
106+
for (long shape : shapes) {
107+
sb.append("[").append(shape).append("]");
108+
}
109+
sb.append(">(");
110+
printTensor(sb, 0, 0, new int[dimensions]);
111+
sb.append(")");
112+
return sb.toString();
113+
}
114+
115+
private void printTensor(StringBuilder sb, int depth, int index, int[] indices) {
116+
if (depth == dimensions) {
117+
int flatIndex = getFlatIndex(indices);
118+
sb.append(data.get(flatIndex));
119+
return;
120+
}
121+
122+
sb.append("[");
123+
long size = shapes[depth];
124+
for (int i = 0; i < size; i++) {
125+
indices[depth] = i;
126+
if (depth == dimensions - 1 && size > 5 && i == 5) {
127+
sb.append("...");
128+
break;
129+
} else {
130+
if (i > 0) {
131+
sb.append(",");
132+
}
133+
printTensor(sb, depth + 1, index * (int) size + i, indices);
134+
}
135+
}
136+
sb.append("]");
137+
}
138+
139+
private String getDataTypeString() {
140+
switch (dataType) {
141+
case DT_BOOL:
142+
return "bool";
143+
case DT_BYTE:
144+
return "char";
145+
case DT_SHORT:
146+
return "short";
147+
case DT_INT:
148+
return "int";
149+
case DT_LONG:
150+
return "long";
151+
case DT_FLOAT:
152+
return "float";
153+
case DT_DOUBLE:
154+
return "double";
155+
default:
156+
throw new IllegalArgumentException("Unsupported data type: " + dataType);
157+
}
158+
}
159+
160+
private int getFlatIndex(int[] indices) {
161+
int flatIndex = 0;
162+
int stride = 1;
163+
for (int i = dimensions - 1; i >= 0; i--) {
164+
flatIndex += indices[i] * stride;
165+
stride *= shapes[i];
166+
}
167+
return flatIndex;
168+
}
169+
}

src/com/xxdb/data/Entity.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ public static DATA_TYPE valueOfTypeName(String name){
6464
};
6565

6666
enum DATA_CATEGORY {NOTHING,LOGICAL,INTEGRAL,FLOATING,TEMPORAL,LITERAL,SYSTEM,MIXED,BINARY,ARRAY,DENARY};
67-
enum DATA_FORM {DF_SCALAR,DF_VECTOR,DF_PAIR,DF_MATRIX,DF_SET,DF_DICTIONARY,DF_TABLE,DF_CHART,DF_CHUNK,DF_SYSOBJ};
67+
enum DATA_FORM {DF_SCALAR,DF_VECTOR,DF_PAIR,DF_MATRIX,DF_SET,DF_DICTIONARY,DF_TABLE,DF_CHART,DF_CHUNK,DF_SYSOBJ,DF_TENSOR};
6868
enum PARTITION_TYPE {SEQ, VALUE, RANGE, LIST, COMPO, HASH}
6969
enum DURATION {NS, US, MS, SECOND, MINUTE, HOUR, DAY, WEEK, MONTH, YEAR, BDAY, TDAY};
7070
DATA_FORM getDataForm();

src/com/xxdb/data/Tensor.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
package com.xxdb.data;
2+
3+
public interface Tensor extends Entity {
4+
}

0 commit comments

Comments
 (0)