训练模型以对视频进行分类
在这个例子中,设 model 为一个用于对视频输入进行分类的 Keras 模型;X 为一个由视频输入组成的大型数据集,形状为 (样本数, 帧数, 通道数, 行数, 列数);Y 为对应的独热编码(one-hot)标签数据集,形状为 (样本数, 类别数)。两个数据集都存储在名为 video_data.h5 的 HDF5 文件中,数据集名称分别为 "X" 和 "Y"。该 HDF5 文件还带有一个 sample_count 属性,记录样本总数。
以下是使用 fit_generator 训练模型的函数:
def train_model(model, video_data_fn="video_data.h5", validation_ratio=0.3,
                batch_size=32, epochs=100):
    """Train the video classification model from an HDF5 dataset.

    The HDF5 file is read for a ``sample_count`` attribute and the datasets
    ``"X"`` (inputs) and ``"Y"`` (labels). Sample indices are shuffled once,
    split into a training part and a validation part, and streamed to the
    model through on-demand batch generators.

    Args:
        model: Keras model to train; must provide ``fit_generator``.
        video_data_fn: Path of the HDF5 file holding the data.
        validation_ratio: Fraction of samples held out for validation;
            must be strictly between 0 and 1.
        batch_size: Number of samples per generated batch.
        epochs: Number of training epochs (defaults to the previously
            hard-coded value of 100).

    Returns:
        Whatever ``model.fit_generator`` returns (in Keras, a ``History``
        object).

    Raises:
        ValueError: If ``validation_ratio`` is not in (0, 1), or if the split
            leaves the training or validation set empty. Either case would
            otherwise make the corresponding generator loop forever without
            yielding a batch.
    """
    if not 0 < validation_ratio < 1:
        raise ValueError(
            "validation_ratio must be strictly between 0 and 1, got %r"
            % (validation_ratio,))

    with h5py.File(video_data_fn, "r") as video_data:
        sample_count = int(video_data.attrs["sample_count"])

        # Shuffle once, then cut at a single split point so the two index
        # sets are disjoint and together cover every sample.
        sample_idxs = np.random.permutation(sample_count)
        split = int((1 - validation_ratio) * sample_count)
        training_sample_idxs = sample_idxs[:split]
        validation_sample_idxs = sample_idxs[split:]

        if len(training_sample_idxs) == 0 or len(validation_sample_idxs) == 0:
            raise ValueError(
                "cannot split %d samples with validation_ratio=%r into "
                "non-empty training and validation sets"
                % (sample_count, validation_ratio))

        training_sequence_generator = generate_training_sequences(
            batch_size=batch_size,
            video_data=video_data,
            training_sample_idxs=training_sample_idxs)
        validation_sequence_generator = generate_validation_sequences(
            batch_size=batch_size,
            video_data=video_data,
            validation_sample_idxs=validation_sample_idxs)

        # The generators read from ``video_data`` lazily, so training must run
        # inside this ``with`` block, before the file is closed.
        # NOTE(review): these are Keras 1.x keyword names (samples_per_epoch,
        # nb_val_samples, nb_epoch, max_q_size, nb_worker); Keras 2 renamed
        # them and counts steps in batches. Confirm the installed version.
        # A single worker with a queue of 1 presumably avoids concurrent reads
        # from the shared HDF5 handle. Confirm before raising them.
        return model.fit_generator(generator=training_sequence_generator,
                                   validation_data=validation_sequence_generator,
                                   samples_per_epoch=len(training_sample_idxs),
                                   nb_val_samples=len(validation_sample_idxs),
                                   nb_epoch=epochs,
                                   max_q_size=1,
                                   verbose=2,
                                   class_weight=None,
                                   nb_worker=1)
以下是训练集和验证集的序列生成器:
def generate_training_sequences(batch_size, video_data, training_sample_idxs):
    """Endlessly yield ``(X, Y)`` training batches, read on demand.

    The index list is walked in order, in chunks of ``batch_size``; the last
    chunk holds any remainder. After the last chunk the walk starts over, so
    the generator never terminates (as ``fit_generator`` expects). Each chunk's
    indices are sorted before indexing ``video_data["X"]`` and
    ``video_data["Y"]``, so X and Y rows stay aligned. Sorting also satisfies
    h5py, which requires increasing indices for list-based selection.

    Args:
        batch_size: Maximum number of samples per batch; must be >= 1.
        video_data: Mapping (e.g. an open ``h5py.File``) with ``"X"`` and
            ``"Y"`` entries that support indexing by a list of indices.
        training_sample_idxs: Sequence of sample indices to draw from.

    Yields:
        Tuples ``(np.array(X_batch), np.array(Y_batch))``.

    Raises:
        ValueError: On the first ``next()`` if ``batch_size < 1`` or
            ``training_sample_idxs`` is empty. An empty index list would
            otherwise make the outer loop spin forever without yielding.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1, got %r" % (batch_size,))
    training_sample_count = len(training_sample_idxs)
    if training_sample_count == 0:
        raise ValueError("training_sample_idxs must not be empty")

    while True:
        # range() with a step yields the same batches as the manual
        # full-batch/remainder bookkeeping. It also works on Python 2 and 3,
        # unlike xrange.
        for start in range(0, training_sample_count, batch_size):
            batch_idxs = sorted(training_sample_idxs[start:start + batch_size])
            X = video_data["X"][batch_idxs]
            Y = video_data["Y"][batch_idxs]
            yield (np.array(X), np.array(Y))
def generate_validation_sequences(batch_size, video_data, validation_sample_idxs):
    """Endlessly yield ``(X, Y)`` validation batches, read on demand.

    The index list is walked in order, in chunks of ``batch_size``; the last
    chunk holds any remainder. After the last chunk the walk starts over, so
    the generator never terminates (as ``fit_generator`` expects). Each chunk's
    indices are sorted before indexing ``video_data["X"]`` and
    ``video_data["Y"]``, so X and Y rows stay aligned. Sorting also satisfies
    h5py, which requires increasing indices for list-based selection.

    Args:
        batch_size: Maximum number of samples per batch; must be >= 1.
        video_data: Mapping (e.g. an open ``h5py.File``) with ``"X"`` and
            ``"Y"`` entries that support indexing by a list of indices.
        validation_sample_idxs: Sequence of sample indices to draw from.

    Yields:
        Tuples ``(np.array(X_batch), np.array(Y_batch))``.

    Raises:
        ValueError: On the first ``next()`` if ``batch_size < 1`` or
            ``validation_sample_idxs`` is empty. An empty index list would
            otherwise make the outer loop spin forever without yielding.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1, got %r" % (batch_size,))
    validation_sample_count = len(validation_sample_idxs)
    if validation_sample_count == 0:
        raise ValueError("validation_sample_idxs must not be empty")

    while True:
        # range() with a step yields the same batches as the manual
        # full-batch/remainder bookkeeping. It also works on Python 2 and 3,
        # unlike xrange.
        for start in range(0, validation_sample_count, batch_size):
            batch_idxs = sorted(validation_sample_idxs[start:start + batch_size])
            X = video_data["X"][batch_idxs]
            Y = video_data["Y"][batch_idxs]
            yield (np.array(X), np.array(Y))