Je suis nouveau à Pytorch et à Numpy, donc cela peut être une question stupide. J'aimerais voir des images mal classées par mon net, avec l'étiquette correcte et l'étiquette prédite. Voici mon code
valid_and_test_set = torchvision.datasets.MNIST("./mnist", train=False, download=True) dataset_valid, dataset_test = torch.utils.data.random_split(valid_and_test_set,[5000, 5000]) dataset_test.dataset.transform = transform #transform is composed by unsqueeze, normalize, view and gaussian noise with randn dataset_test.dataset.target_transform = OneHot() #OneHot return the label dataloader_test = torch.utils.data.DataLoader(dataset_test.dataset, batch_size=5000, num_workers=num_workers, pin_memory=True) def test(dataset, dataloader): net.eval() with torch.no_grad(): for batch in dataloader: inputs = batch[0] inputs = inputs.to(device, non_blocking=True) outputs = net(inputs) predictions = torch.argmax(outputs, dim=1) return predictions
3 Réponses :
Il y a au moins deux façons de le faire. P>
On est, pour stocker les images mal classées lors de l'évaluation (en cours d'exécution dans les données de test) et tracé ceux-ci. Ceci est montré ici P>
Une autre manière consiste à utiliser le tensorboard. C'est assez élégant à mon avis et vous pouvez trouver un guide complet pour celui-ci ICI P>
def showimg(model): model=np.reshape(model.numpy(),[28,28]) # For 1D Vector #If you normalize the image then use Next three-line #Otherwise skip that mean=np.array([0.485, 0.456, 0.406] ) std=np.array([0.229, 0.224, 0.225]) model=(model*std+mean) #print(model) cv2.imshow("ABC", model) #waits for user to press any key #(this is necessary to avoid Python kernel form crashing) cv2.waitKey(0) #closing all open windows cv2.destroyAllWindows()
Quelle est la variable "Data" dans la fonction de train?
Modifié cette variable.
Je reçois cette erreur, je ne sais pas ce que cela signifie
ValueError Traceback (most recent call last) in 288 289 # test on validation --> 290 predictions = test(dataset_valid, dataloader_valid) 291 accuracy_valid = 100. * predictions.eq(dataset_valid.dataset.targets[dataset_valid.indices].to(device)).sum().float() / len(dataset_valid) 292 in test(dataset, dataloader) 236 print("Predicted Label") 237 print(predictions[sampleno]) --> 238 showimages(inputs[sampleno].cpu()) 239 return predictions 240 in showimages(model) 240 241 def showimages(model): --> 242 model=np.transpose(model.numpy(),(1,2,0)) 243 244 <__array_function__ internals> in transpose(*args, **kwargs) ~/.local/lib/python3.7/site-packages/numpy/core/fromnumeric.py in transpose(a, axes) 649 650 """ --> 651 return _wrapfunc(a, 'transpose', axes) 652 653 ~/.local/lib/python3.7/site-packages/numpy/core/fromnumeric.py in _wrapfunc(obj, method, *args, **kwds) 59 60 try: ---> 61 return bound(*args, **kwds) 62 except TypeError: 63 # A TypeError occurs if the object does have such a method in its ValueError: axes don't match array
Quelle est la taille de votre entrée code> et
SAMPLENO] code>?
TORCH.SIZE ([5000, 784]) et TORCH.SIZE ([784]). Le problème est que c'est en niveaux de gris et vous avez utilisé RVB (je pense)
Ok, je l'ai fait pour le tableau 3D (image RVB) mais dans votre entrée de cas est un vecteur, j'ai modifié ma réponse. Je suppose que ça fonctionnera cette fois
Presque fait, j'ai eu une erreur de forme non valide (1, 28, 28) sur plt.imshow (je devais l'utiliser parce que j'ai des problèmes avec OpenCV sur Debian)
Utilisez np.squeeze () code> fonction avant d'utiliser
plt.inshow code>
Je l'ai fait, mais je n'ai aucune sortie. Je suis désolé si je vous dérangeant, j'ai juste besoin de voir les images et c'est plus difficile que de construire le modèle.
Qu'entendez-vous par aucune sortie? Erreur MSG? ou image vierge? Est-ce que vous normalisez les images avant de vous entraîner et de tester?
Je ne veux rien dire, pas de message d'erreur, pas de lignes / d'images vierges, absolument rien. J'ai normalisé l'entrée et dénormalisée avant l'impression que vous l'avez fait dans votre exemple (avec 1 valeur au lieu de 3)
Essayez la version opencv. Je n'utilise pas plt code> autant.
J'ai oublié de plt.show () après PLT.IMSHOW, et cela fonctionne maintenant. Merci beaucoup pour votre aide.
Vous pouvez supprimer cette réponse et cette conversation. Puisque ce n'est pas la réponse réelle