# Source: qiita.com
def get_model_weights_as_numpy(model):
    """Snapshot every model weight as a NumPy array.

    Args:
        model: object exposing ``name`` and an iterable ``weights`` of
            variables, each with ``name`` and ``numpy()``.

    Returns:
        ``{'model_name': model.name, 'weights': {variable_name: array}}``.
    """
    snapshot = {var.name: var.numpy() for var in model.weights}
    return {'model_name': model.name, 'weights': snapshot}
def get_optimizer_weights_as_numpy(optimizer, model):
    """Snapshot the optimizer's slot variables for each model weight as NumPy arrays.

    Args:
        optimizer: optimizer exposing ``get_slot_names()``,
            ``get_slot(var, slot_name)`` and a ``_name`` attribute.
        model: object whose ``weights`` are the variables to look slots up for.

    Returns:
        ``{'optimizer_name': optimizer._name,
        'weights': {variable_name: {slot_name: array}}}``.
        Every model weight gets an entry. Slots the optimizer does not hold for
        a variable are omitted from that variable's dict, so the dict may be
        empty.
    """
    slot_names = optimizer.get_slot_names()
    weights = {}
    for v in model.weights:
        slots = {}
        for slot in slot_names:
            try:
                slot_var = optimizer.get_slot(v, slot)
            except KeyError:
                # get_slot signals a missing slot with KeyError (Keras
                # OptimizerV2). Skip only that slot; any other exception is a
                # real error and should propagate rather than be swallowed.
                continue
            slots[slot] = slot_var.numpy()
        weights[v.name] = slots
    return {'optimizer_name': optimizer._name, 'weights': weights}
def save_weights_as_pickle(file_prefix, optimizer, model):
    """Pickle the model weights and optimizer slots to ``<file_prefix>.pkl``.

    The pickled object is ``{'model': ..., 'optimizer': ...}``. The values are
    the dicts returned by get_model_weights_as_numpy and
    get_optimizer_weights_as_numpy respectively.
    """
    payload = {
        'model': get_model_weights_as_numpy(model),
        'optimizer': get_optimizer_weights_as_numpy(optimizer, model),
    }
    path = file_prefix + '.pkl'
    with open(path, 'wb') as fh:
        pickle.dump(payload, fh)
def set_model_weights_from_numpy(weights, model):
    """Assign saved arrays to the model's variables, matched by variable name.

    Args:
        weights: mapping of variable name to the value to assign.
        model: object whose ``weights`` are variables with ``name`` and
            ``assign(value)``.

    Variables with no entry in ``weights`` are left untouched and reported on
    stdout.
    """
    for var in model.weights:
        name = var.name
        if name not in weights:
            print('Not loaded weights: ' + name)
            continue
        var.assign(weights[name])
def set_optimizer_weights_from_numpy(weights, optimizer, model):
    """Recreate optimizer slot variables from a saved snapshot.

    Args:
        weights: dict shaped like the output of get_optimizer_weights_as_numpy,
            i.e. ``{'optimizer_name': str,
            'weights': {variable_name: {slot_name: array}}}``.
        optimizer: optimizer exposing ``add_slot(var, slot_name, initializer=...)``.
        model: object whose ``weights`` are the variables owning the slots.

    Each saved slot is registered with a constant initializer holding the saved
    value. Model variables missing from the snapshot are reported on stdout.
    """
    # Slots are created under the saved optimizer's name scope. Presumably this
    # makes their names match what the optimizer would create itself — confirm.
    with tf.name_scope(weights['optimizer_name']):
        saved_slots = weights['weights']
        for var in model.weights:
            if var.name not in saved_slots:
                print('Not loaded optimizer weights: ' + var.name)
                continue
            for slot_name, value in saved_slots[var.name].items():
                # NOTE(review): Keras add_slot appears to return an already
                # existing slot unchanged, ignoring the initializer. If so, this
                # must run before the optimizer creates its slots (e.g. before
                # the first training step) — confirm.
                optimizer.add_slot(
                    var, slot_name, initializer=tf.initializers.Constant(value))
def load_weights_from_pickle(file_prefix, optimizer, model):
    """Restore optimizer slot variables from ``<file_prefix>.pkl``.

    Only the ``'optimizer'`` entry of the pickled dict is applied here. The
    ``'model'`` entry that save_weights_as_pickle also writes is not used by
    this function.

    SECURITY: ``pickle.load`` can execute arbitrary code. Only load files this
    program wrote itself, never untrusted input.
    """
    with open(file_prefix + '.pkl', 'rb') as fh:
        snapshot = pickle.load(fh)
    set_optimizer_weights_from_numpy(snapshot['optimizer'], optimizer, model)
# Usage snippet. epoch, model, file_prefix and SAVE_MODEL_ROOT are defined
# outside this view.
if epoch > 0:
    # Not the first epoch: restore model weights from the HDF5 file, then the
    # optimizer slot variables from the pickle (presumably written by an
    # earlier save below — confirm).
    model.load_weights(f'{SAVE_MODEL_ROOT}/model.h5')
    load_weights_from_pickle(file_prefix, model.optimizer, model)
# NOTE(review): training code presumably runs between load and save; it is not
# shown here.
# Save model weights (HDF5) and optimizer slot variables (pickle) separately.
model.save_weights(f'{SAVE_MODEL_ROOT}/model.h5')
save_weights_as_pickle(file_prefix, model.optimizer, model)