TFRecord 是 Tensorflow 官方推荐的可扩展的数据存取格式。

TFRecord 文件由一系列带 CRC32C 校验的记录组成,每一条记录的格式如下:

uint64 length
uint32 masked_crc32_of_length
byte   data[length]
uint32 masked_crc32_of_data

TFRecord 格式本身并不限制 data 的内容,但通常存放的是序列化后的 tf.train.Example 消息。该消息的结构由 example.proto(Example)和 feature.proto(Features、Feature)定义,在 Python 中对应 tf.train.Example 类。

定义如下:

// Top-level message stored in each record: wraps a single feature map.
message Example {
  Features features = 1;
};

// Named features: maps each feature name (string key) to its Feature value.
message Features {
  map<string, Feature> feature = 1;
};

// One feature value. Because the three list fields sit inside a `oneof`,
// at most one of them can be set on a given Feature.
message Feature {
  oneof kind {
    BytesList bytes_list = 1;
    FloatList float_list = 2;
    Int64List int64_list = 3;
  }
};

读写示例:

import tensorflow as tf

def int64_feature(value):
    """Wrap one integer in a ``tf.train.Feature`` backed by an ``Int64List``.

    Args:
        value: a single integer; it becomes the list's only element.

    Returns:
        A ``tf.train.Feature`` whose ``int64_list`` holds ``[value]``.
    """
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)

def bytes_feature(value):
    """Wrap one byte string in a ``tf.train.Feature`` backed by a ``BytesList``.

    Args:
        value: a single ``bytes`` object; it becomes the list's only element.

    Returns:
        A ``tf.train.Feature`` whose ``bytes_list`` holds ``[value]``.
    """
    bytes_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=bytes_list)

# Write data: serialize two (label, data) samples as tf.train.Example records.
tfrecord_filename = "test.tfrecord"

samples = [
    (0, '1234'.encode('utf8')),
    (1, 'abcd'.encode('utf8')),
]

# The writer is used as a context manager so the file is flushed and closed
# even if building or writing one of the examples raises part-way through.
with tf.python_io.TFRecordWriter(tfrecord_filename) as writer:
    for label, data in samples:
        example = tf.train.Example(features=tf.train.Features(feature={
            'label': int64_feature(label),
            'data': bytes_feature(data),
        }))
        writer.write(example.SerializeToString())


# Read data directly: iterate over the raw serialized records and decode each
# one back into a tf.train.Example, numbering records from 1.
records = tf.python_io.tf_record_iterator(tfrecord_filename)
for idx, record in enumerate(records, start=1):
    example = tf.train.Example()
    example.ParseFromString(record)
    print('record {}'.format(idx))
    print(example)

# Read data through a graph and a session.
def parse_function(example_proto):
    """Decode one serialized tf.train.Example into a ``(label, data)`` pair.

    Args:
        example_proto: a scalar string tensor holding one serialized Example.

    Returns:
        A tuple ``(label, data)`` of tensors: ``label`` is a scalar int64
        (0 when the feature is absent) and ``data`` is a scalar string
        ('' when the feature is absent).
    """
    parsed = tf.parse_single_example(
        example_proto,
        features={
            'label': tf.FixedLenFeature((), tf.int64, default_value=0),
            'data': tf.FixedLenFeature((), tf.string, default_value=''),
        },
    )
    return parsed['label'], parsed['data']

# Build the input pipeline inside the CPU device scope: tf.device only affects
# ops *created* within it, so wrapping just the Session/run calls (as before)
# did not pin any of the dataset or iterator ops.
with tf.device('/cpu:0'):
    dataset = tf.data.TFRecordDataset([tfrecord_filename])
    dataset = dataset.map(parse_function)
    # Named `iterator` rather than `iter` to avoid shadowing the builtin.
    iterator = dataset.make_one_shot_iterator()
    next_element = iterator.get_next()

with tf.Session() as sess:
    # Drain every record instead of assuming a fixed count; the iterator
    # signals exhaustion by raising tf.errors.OutOfRangeError.
    while True:
        try:
            element = sess.run(next_element)
        except tf.errors.OutOfRangeError:
            break
        print(element)