TPUでのモデル・optimizerの重み保存と復元

出典: qiita.com

def get_model_weights_as_numpy(model):
    """Snapshot every weight of a model as host-side arrays, keyed by name.

    Args:
        model: Object exposing ``name`` and ``weights`` (an iterable of
            variables, each providing ``name`` and ``numpy()``).

    Returns:
        dict: ``{'model_name': model.name, 'weights': {var_name: var.numpy()}}``.
    """
    # Each variable converts itself to a NumPy value via ``numpy()``.
    numpy_weights = {var.name: var.numpy() for var in model.weights}
    return {'model_name': model.name, 'weights': numpy_weights}


def get_optimizer_weights_as_numpy(optimizer, model):
    """Snapshot the optimizer's slot variables for every model weight.

    Args:
        optimizer: Optimizer exposing ``get_slot_names()``,
            ``get_slot(var, slot_name)`` and ``_name`` (TF2 OptimizerV2-style).
        model: Object whose ``weights`` are the variables to look slots up for.

    Returns:
        dict: ``{'optimizer_name': optimizer._name,
        'weights': {var_name: {slot_name: slot.numpy()}}}``. Every model weight
        gets an entry; a weight the optimizer holds no slots for (presumably
        non-trainable weights) maps to an empty dict.

    Raises:
        Errors other than ``KeyError`` from ``get_slot`` or ``numpy()`` propagate
        instead of being silently discarded (the previous bare ``except`` also
        swallowed e.g. KeyboardInterrupt and yielded an incomplete checkpoint).
    """
    slot_names = optimizer.get_slot_names()
    weights = {}
    for v in model.weights:
        slots = {}
        for slot in slot_names:
            try:
                slot_var = optimizer.get_slot(v, slot)
            except KeyError:
                # A missing slot is expected to surface as KeyError (as in TF2
                # OptimizerV2.get_slot); skip just this slot and keep collecting
                # the others instead of abandoning the whole variable.
                continue
            slots[slot] = slot_var.numpy()
        weights[v.name] = slots
    return {'optimizer_name': optimizer._name, 'weights': weights}


def save_weights_as_pickle(file_prefix, optimizer, model):
    """Pickle the model weights and optimizer slots to ``file_prefix + '.pkl'``.

    The pickled dict has the form ``{'model': ..., 'optimizer': ...}``, built by
    ``get_model_weights_as_numpy`` and ``get_optimizer_weights_as_numpy``.

    Args:
        file_prefix: Output path without extension; ``'.pkl'`` is appended.
        optimizer: Optimizer whose slot variables are saved.
        model: Model whose weights (and per-weight optimizer slots) are saved.
    """
    # Local import so this helper does not depend on a module-level
    # ``import pickle`` (none is present alongside these helpers).
    import pickle

    # Build the full snapshot before opening the file: opening with 'wb'
    # truncates it, so a failure while collecting must not wipe a checkpoint.
    all_weights = {
        'model': get_model_weights_as_numpy(model),
        'optimizer': get_optimizer_weights_as_numpy(optimizer, model),
    }
    with open(file_prefix + '.pkl', 'wb') as f:
        pickle.dump(all_weights, f)


def set_model_weights_from_numpy(weights, model):
    """Assign saved values back onto the model's variables, matched by name.

    Args:
        weights: Flat mapping of variable name -> value to assign.
        model: Model whose ``weights`` variables are assigned in place.

    A variable whose name has no saved value is left untouched and reported
    on stdout.
    """
    saved_names = weights.keys()
    for var in model.weights:
        if var.name not in saved_names:
            print('Not loaded weights: ' + var.name)
            continue
        var.assign(weights[var.name])



def set_optimizer_weights_from_numpy(weights, optimizer, model):
    """Create optimizer slot variables initialised from previously saved values.

    Args:
        weights: Dict with ``'optimizer_name'`` and ``'weights'`` keys, where
            ``'weights'`` maps variable name -> {slot name -> saved value}
            (the layout built by ``get_optimizer_weights_as_numpy``).
        optimizer: Optimizer on which ``add_slot`` is called for every saved slot.
        model: Model whose ``weights`` are matched by name against the saved entries.

    Model variables with no saved entry are reported on stdout and skipped.

    NOTE(review): ``tf`` is not imported in this file; assumes
    ``import tensorflow as tf`` is done elsewhere.
    """
    # Scope under the saved optimizer name -- presumably so the created slot
    # variables get the same names the optimizer would give them itself
    # (TODO confirm).
    with tf.name_scope(weights['optimizer_name']):
        optimizer_weights = weights['weights']
        for v in model.weights:
            if v.name in optimizer_weights.keys():
                for slot in optimizer_weights[v.name].keys():
                    # Create the slot with the trained value as its initial value.
                    # NOTE(review): TF2 OptimizerV2.add_slot appears to return an
                    # already-existing slot unchanged, so the initializer would be
                    # ignored if the slot was created earlier -- confirm.
                    initializer = tf.initializers.Constant(optimizer_weights[v.name][slot])
                    optimizer.add_slot(v, slot, initializer=initializer)
            else:
                print('Not loaded optimizer weights: ' + v.name)



def load_weights_from_pickle(file_prefix, optimizer, model, *, restore_model=False):
    """Restore optimizer slots (and optionally model weights) from a pickle.

    Reads ``file_prefix + '.pkl'`` in the layout written by
    ``save_weights_as_pickle``: ``{'model': {...}, 'optimizer': {...}}``.

    Args:
        file_prefix: Path without extension; ``'.pkl'`` is appended.
        optimizer: Optimizer whose slot variables are recreated.
        model: Model whose weights the saved entries are matched against.
        restore_model: If True, also assign the pickled model weights before
            restoring the optimizer. Defaults to False because the model weights
            are normally restored separately (``model.load_weights`` on the .h5
            file in the training loop).

    Security:
        ``pickle.load`` can execute arbitrary code; only load files you wrote.
    """
    # Local import so this helper does not depend on a module-level
    # ``import pickle`` (none is present alongside these helpers).
    import pickle

    with open(file_prefix + '.pkl', 'rb') as f:
        # NOTE: unpickling runs arbitrary code -- only load trusted checkpoints.
        weights = pickle.load(f)

    if restore_model:
        # set_model_weights_from_numpy matches variable names against the keys
        # of the mapping it receives, so pass the flat name -> value dict rather
        # than the {'model_name', 'weights'} wrapper stored under 'model'
        # (passing the wrapper would match no variable at all).
        set_model_weights_from_numpy(weights['model']['weights'], model)

    set_optimizer_weights_from_numpy(weights['optimizer'], optimizer, model)




        # Load the model (only when epoch > 0, i.e. not on the first epoch).
        # Model weights come from the .h5 file; optimizer slots are then
        # recreated from the pickle written by save_weights_as_pickle.
        # NOTE(review): epoch, SAVE_MODEL_ROOT, file_prefix and model are defined
        # outside this excerpt -- presumably in the enclosing training loop.
        if epoch > 0:
          model.load_weights(f'{SAVE_MODEL_ROOT}/model.h5')
          load_weights_from_pickle(file_prefix, model.optimizer, model)



        # Save the model: weights to the .h5 file, plus a pickle
        # (file_prefix + '.pkl') holding the model weights and optimizer slots.
        model.save_weights(f'{SAVE_MODEL_ROOT}/model.h5')
        save_weights_as_pickle(file_prefix, model.optimizer, model)