前言

TFRecords是 TensorFlow子集的binary storage format。

如果你正在一个很大的数据集上工作,使用一个binary file来存储data对于提升数据导入的pipeline的性能有很大作用,结果就是训练时间的减少。Binary data占据更少的磁盘空间,copy的时间更少,并且读起来更高效,这个差距在机械硬盘上体现的更明显,因为机械硬盘比SSD读写的速度慢很多。

另外,不光是性能上的优势,TFRecords在多个方面被优化以用于Tensorflow。首先,它非常方便联合多个datasets并且和libarary中提供的预处理方法无缝连接。尤其在datasets太大的时候,这个优势体现在只有当前的data(比如batch)需要load和process。另一个主要优势是TFRecords在存储序列数据(比如时间序列,单词序列)的时候非常高效和方便导入。查看(Reading Data)(https://www.tensorflow.org/api_guides/python/reading_data)来查看更多如何读取TFRecord files。

当然它很不便的地方就是必须要把数据转化为TFRecords格式。官方文档给出了例子,但是在实际使用时候仍然只能看到表面。

Convert to TFRecord

将上面面四张图片存入本地,然后在相同目录下新建一个文件convert.py粘贴下面的内容,就可以将图片转为samples.tfrecord

import tensorflow as tf
import os
import numpy as np
from PIL import Image
import glob

def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

filename = "samples.tfrecords"
print("Writing", filename)

img_paths = glob.glob("*.jpg")
                            
with tf.python_io.TFRecordWriter(filename) as writer:
    for img_path in img_paths:
        print(img_path)
        img = np.array(Image.open(img_path))
        height, width = img.shape[0], img.shape[1]
        img_raw = img.tostring()
        example = tf.train.Example(
            features = tf.train.Features(
                feature = {
                    'img_raw': _bytes_feature(img_raw),
                    'width': _int64_feature(width),
                    'height': _int64_feature(height),
                    'label': _int64_feature(1)
                }
            )
        )
        writer.write(example.SerializeToString())

官方文档中对于tf.train.Example的描述并不充分。这是由于tf.train.Example不是一个Python class,而是一个protocol buffer。Protocol buffer是Google开发的序列化结构数据的方法。TFRecords有两类主要结构tf.train.Exampletf.train.SequenceExample。这里只用到了前者。 因为tf.train.Example适合每个特征都有相同的类型的时候,比如所有年龄都用一个整型表示

上面的代码,为每个样本,也就是每个图片定一个四个feature(img_raw, width, height, label),每一个feature都用tf.train.Feature来表示,包装好同类型的list of data。

tf.train.Feature的核心是tf.train.BytesList, tf.train.FloatList, tf.train.Int64List,这三者在创建的时候分别需要赋值一个包含bytes, float, int数据的list作为value。

多个Feature组成tf.train.Features。它在创建的时候只有一个关键字feature=,需要输入一个字典,key是feature的名字,value是tf.train.Feature

最后,Features会被包装进tf.train.Example,然后tf.python_io.TFRecordWriter会将序列化后的tf.train.Example写入磁盘。和file handler一样,tf.python_io.TFRecordWriter包含write,flush,close方法,这是使用with语句来保证写入完成后关闭writer。

Read TFRecord

实际使用中经常会看到两类方法来读取TFRecord, 一类使用Threading and Queues,另一类使用tf.data API。实际上在tensorflow1.2之前,推荐使用多线程队列的方法,而在tensorflow1.4之后,这个老旧的方法逐渐被抛弃,tf.data的接口更加简单,使得数据输入的pipeline的构建更加容易。

这里对两种方法都会给出例子:

Using tf.data API

import tensorflow as tf
import numpy as np
import os

filename = "samples.tfrecords"
IMG_HEIGHT = 224
IMG_WIDTH = 224
BATCH_SIZE = 2
NUM_EPOCHS = 2

def decode(serialized_example):
    features = tf.parse_single_example(
        serialized_example,
        features={
            'img_raw': tf.FixedLenFeature([], tf.string),
            'width': tf.FixedLenFeature([], tf.int64),
            'height': tf.FixedLenFeature([], tf.int64),
            'label': tf.FixedLenFeature([], tf.int64)
        }
    )
    image = tf.decode_raw(features['img_raw'], tf.uint8)
    label = tf.cast(features['label'], tf.int32)
    height, width = features['height'], features['width']
    image = tf.reshape(image, [height, width, 3])
    return image, label

def resize(image, label):
    resized_img = tf.image.resize_image_with_crop_or_pad(
        image=image,
        target_height=IMG_HEIGHT,
        target_width=IMG_WIDTH
    )
    return resized_img, label

def inputs(batch_size, num_epochs):
    with tf.name_scope("input"):
        dataset  = tf.data.TFRecordDataset(filename)
        dataset = dataset.map(decode)
        dataset = dataset.map(resize)
        dataset = dataset.repeat(num_epochs)
        dataset = dataset.batch(batch_size)
        iterator = dataset.make_one_shot_iterator()
    return iterator.get_next()

def run():
    with tf.Graph().as_default():
        image_batch, label_batch = inputs(BATCH_SIZE, NUM_EPOCHS)
        with tf.Session() as sess:
            try:
                while True:
                    images, labels = sess.run([image_batch, label_batch])
                    print(images.shape, labels)
            except tf.errors.OutOfRangeError:
                print("done.")

if __name__ == "__main__":
    run()

官方文档中的Importing Data章节对于Dataset这个类如何使用给出了详细描述。

Using Queue

import tensorflow as tf


filename = 'samples.tfrecords'

def read_and_decode(filename_queue):
    reader = tf.TFRecordReader()

    _, serialized_example = reader.read(filename_queue)

    feature = {
        'img_raw':tf.FixedLenFeature([], tf.string),
        'width':tf.FixedLenFeature([], tf.int64),
        'height': tf.FixedLenFeature([], tf.int64),
        'label': tf.FixedLenFeature([], tf.int64)
    }
    features = tf.parse_single_example(serialized_example, features=feature)

    image = tf.decode_raw(features['img_raw'], tf.uint8)
    height = tf.cast(features['height'], tf.int32)
    width = tf.cast(features['width'], tf.int32)
    image = tf.reshape(image, tf.stack([height, width, 3]))
    resized_image = tf.image.resize_image_with_crop_or_pad(image=image, target_height=224, target_width=224)
    label = tf.cast(features['label'], tf.int32)

    images, labels = tf.train.shuffle_batch([resized_image, label], batch_size=2, capacity=30, num_threads=2, min_after_dequeue=10)
    return images, labels

if __name__ == "__main__":
    filename_queue = tf.train.string_input_producer([filename], num_epochs=2)
    _, labels = read_and_decode(filename_queue)
    init_op = tf.group([tf.global_variables_initializer(), tf.local_variables_initializer()])
    with tf.Session() as sess:
        sess.run(init_op)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)

        try:
            while True:
                lbl = sess.run([labels])
                print(lbl)
        except tf.errors.OutOfRangeError as e:
            coord.request_stop(e)
        coord.request_stop(e)
        coord.join(threads)

这里对于多线程和队列的用法给出了详细描述。

参考