[问题] 生成对抗训练程式

楼主: psw (ICK)   2023-08-29 09:54:24
如题
这个程式训练一些照片
最后把训练的鉴别网络权重参数结果存在TESTgen/discriminator_weights.h5中
但后来要加载TESTgen/discriminator_weights.h5这个参数鉴别网络时却不断说discrimi
nator_weights.h5
里有问题
我打开discriminator_weights.h5中看起来是网络参数
跟float32浮点数格式
但要加载用来辨识其他照片时却说无法加载HTF5格式
我用的是tensrflow GPU
求跪强者们开示
谢谢
import os
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Input, Dense, Reshape, Flatten
from tensorflow.keras.layers import BatchNormalization, Activation, LeakyReLU
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.optimizers import Adam
#缺Keras HDF5 格式
# 设定图像参数
img_rows = 28
img_cols = 28
channels = 1
# 设定生成器
def build_generator():
? ? noise_shape = (100,)
? ? model = Sequential()
? ? model.add(Dense(256, input_shape=noise_shape))
? ? model.add(LeakyReLU(alpha=0.2))
? ? model.add(BatchNormalization(momentum=0.8))
? ? model.add(Dense(512))
? ? model.add(LeakyReLU(alpha=0.2))
? ? model.add(BatchNormalization(momentum=0.8))
? ? model.add(Dense(1024))
? ? model.add(LeakyReLU(alpha=0.2))
? ? model.add(BatchNormalization(momentum=0.8))
? ? model.add(Dense(img_rows * img_cols * channels, activation='tanh'))
? ? model.add(Reshape((img_rows, img_cols, channels)))
? ? model.summary()
? ? noise = Input(shape=noise_shape)
? ? img = model(noise)
? ? return Model(noise, img)
# 设定鉴别器
def build_discriminator():
? ? model = Sequential()
? ? model.add(Flatten(input_shape=(img_rows, img_cols, channels)))
? ? model.add(Dense((img_rows * img_cols * channels), input_shape=(img_rows, i
mg_cols, channels)))
? ? model.add(LeakyReLU(alpha=0.2))
? ? model.add(Dense(int((img_rows * img_cols * channels) / 2)))
? ? model.add(LeakyReLU(alpha=0.2))
? ? model.add(Dense(1, activation='sigmoid'))
? ? model.summary()
? ? img = Input(shape=(img_rows, img_cols, channels))
? ? validity = model(img)
? ? return Model(img, validity)
# 设定生成器和对抗器
generator = build_generator()
discriminator = build_discriminator()
# 编译鉴别器
discriminator.compile(loss='binary_crossentropy',
? ? ? ? ? ? ? ? ? ? ? optimizer=Adam(0.0002, 0.5),
? ? ? ? ? ? ? ? ? ? ? metrics=['accuracy'])
# 建立结合模型
z = Input(shape=(100,))
img = generator(z)
discriminator.trainable = False
validity = discriminator(img)
combined = Model(z, validity)
combined.compile(loss='binary_crossentropy',
? ? ? ? ? ? ? ? 漑ptimizer=Adam(0.0002, 0.5))
# 加载并预处理MNIST资料集
(X_train, _), (_, _) = mnist.load_data()
X_train = (X_train.astype(np.float32) - 127.5) / 127.5
X_train = np.expand_dims(X_train, axis=3)
# 定义训练参数
epochs = 3000
batch_size = 128
save_interval = 100
# 定义图像标签
valid = np.ones((batch_size, 1))
fake = np.zeros((batch_size, 1))
# 训练生成器和鉴别器
for epoch in range(epochs):
? ? # 训练鉴别器
? ? idx = np.random.randint(0, X_train.shape[0], batch_size)
? ? imgs = X_train[idx]
? ? noise = np.random.normal(0, 1, (batch_size, 100))
? ? gen_imgs = generator.predict(noise)
? ? d_loss_real = discriminator.train_on_batch(imgs, valid)
? ? d_loss_fake = discriminator.train_on_batch(gen_imgs, fake)
? ? d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
? ? # 训练生成器
? ? noise = np.random.normal(0, 1, (batch_size, 100))
? ? g_loss = combined.train_on_batch(noise, valid)
? ? # 显示训练进度
? ? if epoch % save_interval == 0:
? ? ? ? print(f"Epoch {epoch}/{epochs}, D loss: {d_loss[0]}, acc.: {100 * d_lo
ss[1]}, G loss: {g_loss}")
? ? ? ? # 显示生成的图像
? ? ? ? r, c = 2, 2
? ? ? ? noise = np.random.normal(0, 1, (r * c, 100))
? ? ? ? gen_imgs = generator.predict(noise)
? ? ? ? fig, axs = plt.subplots(r, c)
? ? ? ? cnt = 0
? ? ? ? for i in range(r):
? ? ? ? ? ? for j in range(c):
? ? ? ? ? ? ? ? axs[i, j].imshow(gen_imgs[cnt, :, :, 0], cmap='gray')
? ? ? ? ? ? ? ? axs[i, j].axis('off')
? ? ? ? ? ? ? ? cnt += 1
? ? ? ? plt.show()
? ?
# 将生成网络和鉴别器的参数保存到TESTgen资料夹中
os.makedirs("TESTgen", exist_ok=True)
generator.save_weights("TESTgen/generator_weights.h5")
discriminator.save_weights("TESTgen/discriminator_weights.h5", save_format="h5
")
with open("TESTgen.txt", "w") as f:
? ? f.write("Generator and discriminator parameters saved.")
print("训练完成并保存生成网络和鉴别器参数。") ? ?
? ?
?
? ? import os
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Input, Dense, Reshape, Flatten
from tensorflow.keras.layers import BatchNormalization, Activation, LeakyReLU
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.optimizers import Adam
# 汇入所需的库和模组
import os
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Input, Dense, Reshape, Flatten
from tensorflow.keras.layers import BatchNormalization, Activation, LeakyReLU
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.optimizers import Adam
# 设定图像参数
img_rows = 28
img_cols = 28
channels = 1
# 设定生成器
def build_generator():
? ? noise_shape = (100,)
? ? model = Sequential()
? ? model.add(Dense(256, input_shape=noise_shape))
? ? model.add(LeakyReLU(alpha=0.2))
? ? model.add(BatchNormalization(momentum=0.8))
? ? model.add(Dense(512))
? ? model.add(LeakyReLU(alpha=0.2))
? ? model.add(BatchNormalization(momentum=0.8))
? ? model.add(Dense(1024))
? ? model.add(LeakyReLU(alpha=0.2))
? ? model.add(BatchNormalization(momentum=0.8))
? ? model.add(Dense(img_rows * img_cols * channels, activation='tanh'))
? ? model.add(Reshape((img_rows, img_cols, channels)))
? ? model.summary()
? ? noise = Input(shape=noise_shape)
? ? img = model(noise)
? ? return Model(noise, img)
# 建立生成器模型
def build_generator():
? ? noise_shape = (100,)
? ? model = Sequential()
? ? model.add(Dense(256, input_shape=noise_shape)) ?# 全连接层,输入是噪音
? ? model.add(LeakyReLU(alpha=0.2)) ?# LeakyReLU 激活函数
? ? model.add(BatchNormalization(momentum=0.8)) ?# BatchNormalization 正规化
? ? model.add(Dense(512))
? ? model.add(LeakyReLU(alpha=0.2))
? ? model.add(BatchNormalization(momentum=0.8))
? ? model.add(Dense(1024))
? ? model.add(LeakyReLU(alpha=0.2))
? ? model.add(BatchNormalization(momentum=0.8))
? ? model.add(Dense(img_rows * img_cols * channels, activation='tanh')) ?# 生
成器输出,使用 tanh 激活函数
? ? model.add(Reshape((img_rows, img_cols, channels))) ?# 重塑输出形状
? ? model.summary()
? ? noise = Input(shape=noise_shape) ?# 噪音输入
? ? img = model(noise) ?# 使用模型生成图像
? ? return Model(noise, img) ?# 返回噪音和生成图像模型
# 设定鉴别器
def build_discriminator():
? ? model = Sequential()
? ? model.add(Flatten(input_shape=(img_rows, img_cols, channels))) ?# 将图像展
平为一维
? ? model.add(Dense((img_rows * img_cols * channels), input_shape=(img_rows, i
mg_cols, channels))) ?# 全连接层
? ? model.add(LeakyReLU(alpha=0.2)) ?# LeakyReLU 激活函数
? ? model.add(Dense(int((img_rows * img_cols * channels) / 2)))
? ? model.add(LeakyReLU(alpha=0.2))
? ? model.add(Dense(1, activation='sigmoid')) ?# 预测真假的输出,使用 sigmoid
激活函数
? ? model.summary()
? ? img = Input(shape=(img_rows, img_cols, channels)) ?# 图像输入
? ? validity = model(img) ?# 使用模型判断真假
? ? return Model(img, validity) ?# 返回图像和判断真假模型
# 建立生成器和鉴别器
generator = build_generator() ?# 创建生成器模型
discriminator = build_discriminator() ?# 创建鉴别器模型
# 编译鉴别器
discriminator.compile(loss='binary_crossentropy',
? ? ? ? ? ? ? ? ? ? ? optimizer=Adam(0.0002, 0.5),
? ? ? ? ? ? ? ? ? ? ? metrics=['accuracy'])
# 建立结合模型
z = Input(shape=(100,))
img = generator(z)
discriminator.trainable = False ?# 在结合模型中,鉴别器权重冻结
validity = discriminator(img)
combined = Model(z, validity) ?# 创建结合模型,输入噪音,输出真假
combined.compile(loss='binary_crossentropy',
? ? ? ? ? ? ? ? 漑ptimizer=Adam(0.0002, 0.5))
# 加载并预处理MNIST资料集
(X_train, _), (_, _) = mnist.load_data() ?# 加载MNIST数据集
X_train = (X_train.astype(np.float32) - 127.5) / 127.5 ?# 正规化数据到-1到1之

X_train = np.expand_dims(X_train, axis=3) ?# 增加一个维度(通道)
# 定义训练参数
epochs = 3000 ?# 训练迭代次数
batch_size = 128 ?# 批次大小
save_interval = 100 ?# 每隔多少个迭代保存模型
# 定义图像标签
valid = np.ones((batch_size, 1)) ?# 真实标签
fake = np.zeros((batch_size, 1)) ?# 假标签
# 训练生成器和鉴别器
for epoch in range(epochs):
? ? # 训练鉴别器
? ? idx = np.random.randint(0, X_train.shape[0], batch_size)
? ? imgs = X_train[idx] ?# 随机选取真实图像
? ? noise = np.random.normal(0, 1, (batch_size, 100)) ?#
?
(X_tr
X_tra

X_tra
作者: lycantrope (阿宽)   2023-08-29 10:52:00
问GPT,不经大脑复制贴上,也没写你是怎么加载h5
作者: tsoahans (ㄎㄎ)   2023-08-29 15:04:00
save_weights对应load_weights model.save对load_model
作者: lycantrope (阿宽)   2023-08-29 15:18:00
同楼上,从model= build_discriminator()产生model后model.load_weights才对
作者: chang1248w (彩棠)   2023-09-14 22:58:00
超热心ww

Links booklink

Contact Us: admin [ a t ] ucptt.com