TFRecord 格式
TFRecord 是 TensorFlow 官方推荐的可扩展的数据存取格式。
TFRecord 文件由一系列带 CRC32C 校验数据的记录组成。每一条记录的格式如下[1]:
uint64 length
uint32 masked_crc32_of_length
byte data[length]
uint32 masked_crc32_of_data
每条记录中的 data 是按 example.proto 中定义的 Example 消息序列化得到的字节串;该消息在 Python 中对应 tf.train.Example 类。
定义如下:
// Top-level message: one example, holding a single Features message.
message Example {
Features features = 1;
};
// Maps each feature name (a string key) to its Feature value.
message Features {
map<string, Feature> feature = 1;
};
// One feature value; the oneof means at most one of the three list
// fields (bytes, float, int64) is set at a time.
message Feature {
oneof kind {
BytesList bytes_list = 1;
FloatList float_list = 2;
Int64List int64_list = 3;
}
};
读写示例:
import tensorflow as tf
def int64_feature(value):
    """Wrap a single integer in a ``tf.train.Feature``.

    Args:
        value: The integer to store; it becomes the only element of the
            feature's ``int64_list``.

    Returns:
        A ``tf.train.Feature`` whose ``int64_list`` holds ``[value]``.
    """
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)
def bytes_feature(value):
    """Wrap a single bytes value in a ``tf.train.Feature``.

    Args:
        value: The bytes object to store; it becomes the only element of
            the feature's ``bytes_list``.

    Returns:
        A ``tf.train.Feature`` whose ``bytes_list`` holds ``[value]``.
    """
    bytes_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=bytes_list)
# Write the data
tfrecord_filename = "test.tfrecotd"
# (label, text) pairs; each pair becomes one record in the file.
samples = [
    (0, '1234'),
    (1, 'abcd'),
]
# The context manager closes the writer even if building or writing a
# record raises, so the file handle is never leaked.
with tf.python_io.TFRecordWriter(tfrecord_filename) as writer:
    for label, text in samples:
        example = tf.train.Example(features=tf.train.Features(feature={
            'label': int64_feature(label),
            'data': bytes_feature(text.encode('utf8')),
        }))
        # Each record is the serialized bytes of one Example message.
        writer.write(example.SerializeToString())
# Read the records directly, without building a graph.
# enumerate(start=1) numbers the records from 1, replacing a hand-kept counter.
for idx, record in enumerate(
        tf.python_io.tf_record_iterator(tfrecord_filename), start=1):
    # Each record is a serialized Example; parse it back into a message.
    example = tf.train.Example()
    example.ParseFromString(record)
    print('record {}'.format(idx))
    print(example)
# Read the data through a graph and a session
def parse_function(example_proto):
    """Parse one serialized ``tf.train.Example`` into a ``(label, data)`` pair.

    Args:
        example_proto: A serialized Example, passed straight to
            ``tf.parse_single_example``.

    Returns:
        A tuple ``(label, data)`` taken from the parsed features: ``label``
        is declared as a scalar ``tf.int64`` feature with default 0 and
        ``data`` as a scalar ``tf.string`` feature with default ``''``.
    """
    schema = {
        'label': tf.FixedLenFeature((), tf.int64, default_value=0),
        'data': tf.FixedLenFeature((), tf.string, default_value=''),
    }
    parsed = tf.parse_single_example(example_proto, features=schema)
    return parsed['label'], parsed['data']
# Build the input pipeline: read the TFRecord file and parse every record.
dataset = tf.data.TFRecordDataset([tfrecord_filename])
dataset = dataset.map(parse_function)
# Named `iterator` (not `iter`) so the builtin iter() is not shadowed.
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
# NOTE(review): tf.device applies to ops created inside its scope; the
# dataset ops above are built before entering it, so this scope presumably
# does not affect their placement -- confirm whether that is intended.
with tf.device('/cpu:0'), tf.Session() as sess:
    # Two fetches: each sess.run pulls the next (label, data) element.
    print(sess.run(next_element))
    print(sess.run(next_element))