訓練模型以對視訊進行分類
在這個範例中,令 model 為用於對視訊輸入進行分類的 Keras 模型;令 X 為由視訊輸入組成的大型資料集,形狀為 (樣本, 幀, 通道, 行, 列),即 (samples, frames, channels, rows, cols);令 Y 為對應的 one-hot 編碼標籤資料集,形狀為 (樣本, 類別),即 (samples, classes)。兩個資料集都儲存在名為 video_data.h5 的 HDF5 檔案中(資料集名稱分別為 "X" 與 "Y"),該檔案另有一個記錄樣本總數的屬性 sample_count。
以下是使用 fit_generator 訓練模型的函式:
def train_model(model, video_data_fn="video_data.h5", validation_ratio=0.3, batch_size=32,
                nb_epoch=100):
    """ Train the video classification model on data stored in an HDF5 file.

    The sample indices are shuffled once and split into a training part and a
    validation part. Batches are then streamed from the open HDF5 file by
    generate_training_sequences / generate_validation_sequences and fed to
    ``model.fit_generator``.

    Args:
        model: Keras model to train (trained in place).
        video_data_fn: path to an HDF5 file holding the "X" and "Y" datasets
            and a "sample_count" attribute.
        validation_ratio: fraction of the samples held out for validation.
        batch_size: number of samples per generated batch.
        nb_epoch: number of training epochs (defaults to the previously
            hard-coded value of 100).

    Returns:
        The value returned by ``model.fit_generator`` (in Keras, a History
        object with the per-epoch metrics).
    """
    with h5py.File(video_data_fn, "r") as video_data:
        sample_count = int(video_data.attrs["sample_count"])

        # Shuffle once so that the train/validation split is random.
        sample_idxs = np.random.permutation(sample_count)
        split = int((1 - validation_ratio) * sample_count)
        training_sample_idxs = sample_idxs[:split]
        validation_sample_idxs = sample_idxs[split:]

        training_sequence_generator = generate_training_sequences(
            batch_size=batch_size,
            video_data=video_data,
            training_sample_idxs=training_sample_idxs)
        validation_sequence_generator = generate_validation_sequences(
            batch_size=batch_size,
            video_data=video_data,
            validation_sample_idxs=validation_sample_idxs)

        # The generators read lazily from video_data, so training must run
        # while the HDF5 file is still open (i.e. inside this `with`).
        # NOTE: these keyword names are the Keras 1.x fit_generator API;
        # Keras 2 renamed them (steps_per_epoch, epochs, workers, ...).
        return model.fit_generator(generator=training_sequence_generator,
                                   validation_data=validation_sequence_generator,
                                   samples_per_epoch=len(training_sample_idxs),
                                   nb_val_samples=len(validation_sample_idxs),
                                   nb_epoch=nb_epoch,
                                   max_q_size=1,
                                   verbose=2,
                                   class_weight=None,
                                   nb_worker=1)
以下是訓練與驗證用的序列產生器(generator):
def generate_training_sequences(batch_size, video_data, training_sample_idxs):
    """ Generate training batches on demand, cycling over the samples forever.

    Each pass walks ``training_sample_idxs`` in its given order in chunks of
    ``batch_size``. The last chunk holds the remainder when the number of
    indices is not a multiple of ``batch_size``. After the last chunk, the
    pass starts over from the beginning.

    Args:
        batch_size: maximum number of samples per batch; must be >= 1.
        video_data: mapping (e.g. an open h5py.File) whose "X" and "Y"
            entries can be indexed with a list of sample indices.
        training_sample_idxs: sequence of sample indices used for training.

    Yields:
        An (X, Y) tuple of numpy arrays for one batch.

    Raises:
        ValueError: if batch_size < 1 or training_sample_idxs is empty.
    """
    # Fail fast: with no indices (or a non-positive batch size) the loop
    # below would never yield and would spin forever.
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1, got %r" % (batch_size,))
    training_sample_count = len(training_sample_idxs)
    if training_sample_count == 0:
        raise ValueError("training_sample_idxs is empty; nothing to generate")

    while True:
        for start in range(0, training_sample_count, batch_size):
            # h5py fancy indexing requires the index list in increasing
            # order; X and Y are indexed identically, so they stay aligned.
            batch_idxs = sorted(training_sample_idxs[start:start + batch_size])
            X = video_data["X"][batch_idxs]
            Y = video_data["Y"][batch_idxs]
            yield (np.asarray(X), np.asarray(Y))
def generate_validation_sequences(batch_size, video_data, validation_sample_idxs):
    """ Generate validation batches on demand, cycling over the samples forever.

    Each pass walks ``validation_sample_idxs`` in its given order in chunks of
    ``batch_size``. The last chunk holds the remainder when the number of
    indices is not a multiple of ``batch_size``. After the last chunk, the
    pass starts over from the beginning.

    Args:
        batch_size: maximum number of samples per batch; must be >= 1.
        video_data: mapping (e.g. an open h5py.File) whose "X" and "Y"
            entries can be indexed with a list of sample indices.
        validation_sample_idxs: sequence of sample indices used for validation.

    Yields:
        An (X, Y) tuple of numpy arrays for one batch.

    Raises:
        ValueError: if batch_size < 1 or validation_sample_idxs is empty.
    """
    # Fail fast: with no indices (or a non-positive batch size) the loop
    # below would never yield and would spin forever.
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1, got %r" % (batch_size,))
    validation_sample_count = len(validation_sample_idxs)
    if validation_sample_count == 0:
        raise ValueError("validation_sample_idxs is empty; nothing to generate")

    while True:
        for start in range(0, validation_sample_count, batch_size):
            # h5py fancy indexing requires the index list in increasing
            # order; X and Y are indexed identically, so they stay aligned.
            batch_idxs = sorted(validation_sample_idxs[start:start + batch_size])
            X = video_data["X"][batch_idxs]
            Y = video_data["Y"][batch_idxs]
            yield (np.asarray(X), np.asarray(Y))