
且构网 - 分享程序员编程开发的那些事

预训练的Tensorflow模型RGB -> RGBY通道扩展

更新时间:2022-04-11 01:06:53

使用 Keras API 中的 layer.get_weights() 和 layer.set_weights() 函数。

Use the layer.get_weights() and layer.set_weights() functions of the Keras API.

为4通道VGG创建模板结构（设置输入 shape=(height, width, 4)）。然后将3通道RGB模型的权重加载到4通道模型中，作为RGBB。

Create a template structure for a 4-channel VGG (set the input shape=(height, width, 4)). Then load the weights from the 3-channel RGB model into the 4-channel model as RGBB.


Below is the code that does the procedure. For the sequential VGG, the only layer that needs to be modified is the first convolution layer; the structure of all subsequent layers is independent of the number of input channels.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from keras.applications.vgg19 import VGG19
from keras.models import Model

# Full VGG19 with ImageNet weights; this is the 3-channel source model.
vgg19 = VGG19(weights='imagenet')
vgg19.summary()  # Lists every layer, to see which top layers get dropped below

# Keep everything up to and including the last pooling layer ('block5_pool'),
# i.e. the VGG without its top layers. Its weights are copied into the
# 4-channel template later on.
base_output = vgg19.get_layer('block5_pool').output
pretrained = Model(inputs=vgg19.input, outputs=base_output)

#%% Prepare model template with 4 input channels
# Kept for reference: inspect config['layers'][i] to see the original
# layer-by-layer structure that the template below reproduces.
config = pretrained.get_config()

from keras.layers import Input, Conv2D, MaxPooling2D
from keras import optimizers

# (filters, number of 3x3 convolutions) for each of the five VGG19 blocks.
vgg19_blocks = ((64, 2), (128, 2), (256, 4), (512, 4), (512, 4))

# Same topology and layer names as `pretrained`, but with a 4-channel input.
# Kernels start at zero because they are meant to be overwritten with the
# pretrained weights; for training from scratch change kernel_initializer
# to e.g. 'VarianceScaling'.
inputs = Input(shape=(224, 224, 4), name='input_17')
x = inputs
for block_idx, (filters, n_convs) in enumerate(vgg19_blocks, start=1):
    for conv_idx in range(1, n_convs + 1):
        x = Conv2D(filters, (3, 3), padding='same', activation='relu',
                   kernel_initializer='zeros',
                   name='block{}_conv{}'.format(block_idx, conv_idx))(x)
    # Strides equal to the pool size (the Keras default) halve each spatial dim.
    x = MaxPooling2D(pool_size=(2, 2), strides=(2, 2),
                     name='block{}_pool'.format(block_idx))(x)

vgg_template = Model(inputs=inputs, outputs=x)


#%% Copy the pretrained weights into the 4-channel template
import numpy as np

# Only the first convolution's kernel changes shape because of the 4th input
# channel; every later layer has identical weight shapes in both models.
layers_to_modify = ['block1_conv1']

for layer in pretrained.layers:  # The pretrained model and the template share
                                 # layer names, so iterating either one works.
    weights = layer.get_weights()
    if not weights:  # Skip input, pooling and other weightless layers.
        continue

    target_layer = vgg_template.get_layer(name=layer.name)

    if layer.name in layers_to_modify:
        kernels, biases = weights
        # Keras Conv2D kernels are (kh, kw, in_channels, out_channels), so the
        # input-channel axis is -2 (channels_last). Duplicate the last
        # input-channel slice to feed the new 4th channel (RGB -> RGBB).
        # NOTE(review): which colour that slice corresponds to depends on the
        # preprocessing channel order used with this VGG; confirm.
        kernels_extra_channel = np.concatenate(
            (kernels, kernels[:, :, -1:, :]), axis=-2)
        target_layer.set_weights([kernels_extra_channel, biases])
    else:
        # Shapes already match: copy the pretrained weights unchanged, so the
        # zero-initialized template layers do not stay at zero.
        target_layer.set_weights(weights)


#%% Save 4 channel model populated with weights for further use
