Je construis un Variational Autoencoder (VAE) dans PyTorch et j'ai un problème pour écrire du code indépendant du périphérique. L'Autoencoder est un enfant de nn.Module avec un réseau d'encodeur et de décodeur, qui le sont aussi. Tous les poids du réseau peuvent être déplacés d'un appareil à un autre en appelant net.to (appareil) .
Le problème que j'ai est avec l'astuce de reparamétrisation:
class VariationalGenerator(nn.Module):
def __init__(self, input_nc, output_nc):
super(VariationalGenerator, self).__init__()
self.input_nc = input_nc
self.output_nc = output_nc
embedding_size = 128
self._train_noise = torch.randn(batch_size, embedding_size)
self._eval_noise = torch.randn(1, embedding_size)
self.noise = self._train_noise
# Create encoder
self.encoder = Encoder(input_nc, embedding_size)
# Create decoder
self.decoder = Decoder(output_nc, embedding_size)
def train(self, mode=True):
super(VariationalGenerator, self).train(mode)
self.noise = self._train_noise
def eval(self):
super(VariationalGenerator, self).eval()
self.noise = self._eval_noise
def forward(self, inputs):
# Calculate parameters of embedding space
mu, log_sigma = self.encoder.forward(inputs)
# Resample noise if training
if self.training:
self.noise.normal_()
# Reparametrize noise to embedding space
inputs = mu + self.noise * torch.exp(0.5 * log_sigma)
# Decode to image
inputs = self.decoder(inputs)
return inputs, mu, log_sigma
Le bruit est un tenseur de la même taille que mu et sigma et enregistré comme variable membre du module autoencoder. Il est initialisé dans le constructeur et rééchantillonné sur place à chaque étape d'apprentissage. Je le fais de cette façon pour éviter de construire un nouveau tenseur de bruit à chaque étape et de le pousser vers l'appareil souhaité. De plus, je souhaite corriger le bruit dans l'évaluation. Voici le code:
encoding = mu + noise * sigma
Quand je déplace maintenant l'auto-encodeur vers le GPU avec net.to ('cuda: 0') j'obtiens un erreur de transfert car le tenseur de bruit n'est pas déplacé.
Je ne veux pas ajouter de paramètre de périphérique au constructeur, car il n'est toujours pas possible de le déplacer vers un autre périphérique plus tard. J'ai également essayé d'envelopper le bruit dans nn.Parameter afin qu'il soit affecté par net.to () , mais cela donne une erreur de l'optimiseur, car le bruit est marqué comme requires_grad=False.
N'importe qui a une solution pour déplacer tous les modules avec net.to () ?
p>
3 Réponses :
Utilisez ceci:
net.to(device) input = input.to(device)
Maintenant pour le modèle et chaque tenseur que vous utilisez
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Ce n'était pas le problème. Le problème est que le tenseur de bruit n'est pas déplacé avec les poids lorsque net.to (périphérique) est utilisé.
Après quelques essais et erreurs, j'ai trouvé deux méthodes:
self._train_noise = torch.randn (batch_size, embedding_size) par self.register_buffer ('_ train_noise', torch.randn (batch_size, embedding_size) code > le tenseur de bruit est ajouté au module en tant que tampon. Cela permet à net.to (périphérique) de l'affecter également. De plus, le tenseur fait maintenant partie de state_dict. Remplacer net.to (périphérique) : en utilisant ceci, le bruit reste en dehors de state_dict.
def to(device):
new_self = super(VariationalGenerator, self).to(device)
new_self._train_noise = new_self._train_noise.to(device)
new_self._eval_noise = new_self._eval_noise.to(device)
return new_self
L'approche 2 serait probablement meilleure en remplaçant _apply plutôt que simplement à ; puis .cuda () , etc. fonctionnent tous aussi. ( ma réponse )
Une meilleure version de la deuxième approche de tilman151 consiste probablement à remplacer _apply , plutôt que à . De cette façon, net.cuda () , net.float () , etc. fonctionneront également tous, car ils appellent tous _apply plutôt que à (comme on peut le voir dans la source , qui est plus simple que vous ne le pensez):
def _apply(self, fn):
super(VariationalGenerator, self)._apply(fn)
self._train_noise = fn(self._train_noise)
self._eval_noise = fn(self._eval_noise)
return self