|
| 1 | +# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"). You |
| 4 | +# may not use this file except in compliance with the License. A copy of |
| 5 | +# the License is located at |
| 6 | +# |
| 7 | +# http://aws.amazon.com/apache2.0/ |
| 8 | +# |
| 9 | +# or in the "license" file accompanying this file. This file is |
| 10 | +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF |
| 11 | +# ANY KIND, either express or implied. See the License for the specific |
| 12 | +# language governing permissions and limitations under the License. |
| 13 | +from __future__ import absolute_import |
| 14 | + |
| 15 | +import argparse |
| 16 | +from random import randint |
| 17 | +import struct |
| 18 | +import sys |
| 19 | + |
| 20 | +import numpy as np |
| 21 | +import tensorflow as tf |
| 22 | + |
| 23 | +# Utility functions for generating a recordio encoded file of labeled numpy data |
| 24 | +# for testing. Each file contains one or more records. Each record is a TensorFlow |
| 25 | +# protobuf Example object. Each object contains an integer label and a numpy array |
| 26 | +# encoded as a byte list. |
| 27 | + |
| 28 | +# This file can be used in script mode to generate a single file or be used |
| 29 | +# as a module to generate files via build_record_file. |
| 30 | + |
| 31 | +_kmagic = 0xced7230a |
| 32 | + |
| 33 | +padding = {} |
| 34 | +for amount in range(4): |
| 35 | + if sys.version_info >= (3,): |
| 36 | + padding[amount] = bytes([0x00 for _ in range(amount)]) |
| 37 | + else: |
| 38 | + padding[amount] = bytearray([0x00 for _ in range(amount)]) |
| 39 | + |
| 40 | + |
def write_recordio(f, data, header_flag=0):
    """Appends one RecordIO record (magic, header, payload, pad) to file f.

    The header packs header_flag into the top 3 bits and the payload length
    into the low 29 bits; the payload is zero-padded to a 4-byte boundary.
    """
    payload_len = len(data)
    f.write(struct.pack('I', _kmagic))
    f.write(struct.pack('I', (header_flag << 29) | payload_len))
    f.write(data)
    # Pad out to the next multiple of 4 bytes.
    f.write(padding[(((payload_len + 3) >> 2) << 2) - payload_len])
| 50 | + |
| 51 | + |
def write_recordio_multipart(f, data):
    """Splits one data point into three consecutive multipart RecordIO records."""
    third = int(len(data) / 3)
    # Header flags 1, 2 and 3 mark the start, middle and end parts.
    parts = (data[:third], data[third:2 * third], data[2 * third:])
    for flag, part in enumerate(parts, start=1):
        write_recordio(f, part, flag)
| 64 | + |
| 65 | + |
def string_feature(value):
    """Encodes a numpy array's raw bytes as a single-element TF BytesList Feature.

    Uses ndarray.tobytes(): the former tostring() alias produced the same
    bytes but is deprecated and was removed in NumPy 2.0.
    """
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value.tobytes()]))
| 68 | + |
| 69 | + |
def label_feature(value):
    """Wraps an integer label in a single-element TF Int64List Feature."""
    int_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int_list)
| 72 | + |
| 73 | + |
def write_numpy_array(f, feature_name, label, arr, multipart=False):
    """Serializes a labeled numpy array as a TF Example and appends it to f.

    The Example carries the label under 'labels' and the array's raw bytes
    under feature_name; it is framed either as one RecordIO record or, when
    multipart is True, as three multipart records.
    """
    example = tf.train.Example(features=tf.train.Features(feature={
        'labels': label_feature(label),
        feature_name: string_feature(arr),
    }))
    serialized = example.SerializeToString()
    if multipart:
        write_recordio_multipart(f, serialized)
    else:
        write_recordio(f, serialized)
| 81 | + |
| 82 | + |
def build_record_file(filename, num_records, dimension, classes=2, data_feature_name='data', multipart=False):
    """Builds a recordio encoded file of TF protobuf Example objects. Each object
    is a labeled numpy array. Each example has two fields - a single int64 'labels'
    field and a single bytes list field, containing a serialized numpy array.

    Each generated numpy array is a multidimensional normal with
    the specified dimension. The normal distribution is class specific, each class
    has a different mean for the distribution, so it should be possible to learn
    a multiclass classifier on this data. Class means are deterministic - so multiple
    calls to this function with the same number of classes will produce samples drawn
    from the same distribution for each class.

    Args:
        filename - the file to write to
        num_records - how many labeled numpy arrays to generate
        dimension - the size of each numpy array.
        classes - the cardinality of labels
        data_feature_name - the name to give the numpy array in the Example object
        multipart - if True, each record is written as three multipart RecordIO records
    """
    with open(filename, 'wb') as f:
        for i in range(num_records):
            cur_class = i % classes
            # Center each class's distribution at a distinct integer offset so
            # the classes are separable (and deterministic across calls).
            loc = int(cur_class - (classes / 2))
            write_numpy_array(f, data_feature_name, cur_class, np.random.normal(loc=loc, size=(dimension,)), multipart)
| 107 | + |
| 108 | + |
def build_single_record_file(filename, dimension, classes=2, data_feature_name='data'):
    """Writes a single serialized TF Example (no RecordIO framing) to filename.

    The example holds one randomly-chosen class label under 'labels' and one
    normal sample of the given dimension, centered per-class as in
    build_record_file.
    """
    label = randint(0, classes - 1)
    loc = int(label - (classes / 2))

    sample = np.random.normal(loc=loc, size=(dimension,))
    features = tf.train.Features(feature={
        'labels': label_feature(label),
        data_feature_name: string_feature(sample),
    })
    with open(filename, 'wb') as f:
        f.write(tf.train.Example(features=features).SerializeToString())
| 118 | + |
| 119 | + |
def validate_record_file(filename, dimension):
    """Checks that the first RecordIO record in filename decodes to a float64
    numpy array of the given dimension.

    Raises AssertionError when the file does not start with the RecordIO
    magic number or the decoded array has the wrong length.
    """
    # Use a context manager so the file handle is closed (the previous
    # open(...).read() leaked it).
    with open(filename, 'rb') as f:
        data = f.read()
    magic_number, header = struct.unpack('II', data[0:8])
    assert magic_number == _kmagic
    # The top 3 bits of the header carry the multipart flag (see
    # write_recordio); only the low 29 bits are the payload length.
    length = header & ((1 << 29) - 1)
    encoded = data[8:8 + length]

    features = {
        'data': tf.io.FixedLenFeature([], tf.string),
        'labels': tf.io.FixedLenFeature([], tf.int64),
    }
    parsed = tf.io.parse_single_example(encoded, features)
    array = tf.io.decode_raw(parsed['data'], tf.float64)

    assert array.shape[0] == dimension
| 133 | + |
| 134 | + |
if __name__ == '__main__':
    # Script mode: generate one labeled-data file, then sanity-check it.
    arg_parser = argparse.ArgumentParser(description="Generate synthetic multi-class training data")
    arg_parser.add_argument('--dimension', default=65536, type=int)
    arg_parser.add_argument('--classes', default=2, type=int)
    arg_parser.add_argument('--num-records', default=4, type=int)
    arg_parser.add_argument('--data-feature-name', default='data')
    arg_parser.add_argument('filename', type=str)
    opts = arg_parser.parse_args()

    build_record_file(opts.filename, opts.num_records, opts.dimension, opts.classes, opts.data_feature_name)
    validate_record_file(opts.filename, opts.dimension)
0 commit comments