Je construis un Variational Autoencoder (VAE) dans PyTorch et j'ai un problème pour écrire du code indépendant du périphérique. L'Autoencoder est un enfant de nn.Module
avec un réseau d'encodeur et de décodeur, qui le sont aussi. Tous les poids du réseau peuvent être déplacés d'un appareil à un autre en appelant net.to (appareil)
.
Le problème que j'ai est avec l'astuce de reparamétrisation:
class VariationalGenerator(nn.Module): def __init__(self, input_nc, output_nc): super(VariationalGenerator, self).__init__() self.input_nc = input_nc self.output_nc = output_nc embedding_size = 128 self._train_noise = torch.randn(batch_size, embedding_size) self._eval_noise = torch.randn(1, embedding_size) self.noise = self._train_noise # Create encoder self.encoder = Encoder(input_nc, embedding_size) # Create decoder self.decoder = Decoder(output_nc, embedding_size) def train(self, mode=True): super(VariationalGenerator, self).train(mode) self.noise = self._train_noise def eval(self): super(VariationalGenerator, self).eval() self.noise = self._eval_noise def forward(self, inputs): # Calculate parameters of embedding space mu, log_sigma = self.encoder.forward(inputs) # Resample noise if training if self.training: self.noise.normal_() # Reparametrize noise to embedding space inputs = mu + self.noise * torch.exp(0.5 * log_sigma) # Decode to image inputs = self.decoder(inputs) return inputs, mu, log_sigma
Le bruit est un tenseur de la même taille que mu
et sigma
et enregistré comme variable membre du module autoencoder. Il est initialisé dans le constructeur et rééchantillonné sur place à chaque étape d'apprentissage. Je le fais de cette façon pour éviter de construire un nouveau tenseur de bruit à chaque étape et de le pousser vers l'appareil souhaité. De plus, je souhaite corriger le bruit dans l'évaluation. Voici le code:
encoding = mu + noise * sigma
Quand je déplace maintenant l'auto-encodeur vers le GPU avec net.to ('cuda: 0')
j'obtiens un erreur de transfert car le tenseur de bruit n'est pas déplacé.
Je ne veux pas ajouter de paramètre de périphérique au constructeur, car il n'est toujours pas possible de le déplacer vers un autre périphérique plus tard. J'ai également essayé d'envelopper le bruit dans nn.Parameter
afin qu'il soit affecté par net.to ()
, mais cela donne une erreur de l'optimiseur, car le bruit est marqué comme requires_grad=False
.
N'importe qui a une solution pour déplacer tous les modules avec net.to ()
?
p>
3 Réponses :
Utilisez ceci:
net.to(device) input = input.to(device)
Maintenant pour le modèle et chaque tenseur que vous utilisez
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Ce n'était pas le problème. Le problème est que le tenseur de bruit n'est pas déplacé avec les poids lorsque net.to (périphérique)
est utilisé.
Après quelques essais et erreurs, j'ai trouvé deux méthodes:
self._train_noise = torch.randn (batch_size, embedding_size)
par self.register_buffer ('_ train_noise', torch.randn (batch_size, embedding_size) code > le tenseur de bruit est ajouté au module en tant que tampon. Cela permet à net.to (périphérique)
de l'affecter également. De plus, le tenseur fait maintenant partie de state_dict.
Remplacer net.to (périphérique)
: en utilisant ceci, le bruit reste en dehors de state_dict.
def to(device): new_self = super(VariationalGenerator, self).to(device) new_self._train_noise = new_self._train_noise.to(device) new_self._eval_noise = new_self._eval_noise.to(device) return new_self
L'approche 2 serait probablement meilleure en remplaçant _apply
plutôt que simplement à
; puis .cuda ()
, etc. fonctionnent tous aussi. ( ma réponse )
Une meilleure version de la deuxième approche de tilman151 consiste probablement à remplacer _apply
, plutôt que à . De cette façon, net.cuda ()
, net.float ()
, etc. fonctionneront également tous, car ils appellent tous _apply
plutôt que à (comme on peut le voir dans la source , qui est plus simple que vous ne le pensez):
def _apply(self, fn): super(VariationalGenerator, self)._apply(fn) self._train_noise = fn(self._train_noise) self._eval_noise = fn(self._eval_noise) return self