-1
votes

Afficher des images mal classées à Pytorch

Je suis nouveau à Pytorch et à Numpy, donc cela peut être une question stupide. J'aimerais voir des images mal classées par mon net, avec l'étiquette correcte et l'étiquette prédite. Voici mon code

valid_and_test_set = torchvision.datasets.MNIST("./mnist", train=False, download=True)
dataset_valid, dataset_test = torch.utils.data.random_split(valid_and_test_set,[5000, 5000])
dataset_test.dataset.transform = transform #transform is composed by unsqueeze, normalize, view and gaussian noise with randn
dataset_test.dataset.target_transform = OneHot() #OneHot return the label
dataloader_test = torch.utils.data.DataLoader(dataset_test.dataset, batch_size=5000, num_workers=num_workers, pin_memory=True)

def test(dataset, dataloader):
    net.eval()  
    with torch.no_grad():
        for batch in dataloader:
            inputs = batch[0]
            inputs = inputs.to(device, non_blocking=True)
            outputs = net(inputs)
            predictions = torch.argmax(outputs, dim=1)
            return predictions


0 commentaires

3 Réponses :


0
votes

Il y a au moins deux façons de le faire.

On est, pour stocker les images mal classées lors de l'évaluation (en cours d'exécution dans les données de test) et tracé ceux-ci. Ceci est montré ici

Une autre manière consiste à utiliser le tensorboard. C'est assez élégant à mon avis et vous pouvez trouver un guide complet pour celui-ci ICI


0 commentaires

0
votes
def showimg(model):
    model=np.reshape(model.numpy(),[28,28]) # For 1D Vector
    
    #If you normalize the image then use Next three-line
    #Otherwise skip that
    mean=np.array([0.485, 0.456, 0.406] )
    std=np.array([0.229, 0.224, 0.225])
    model=(model*std+mean)
    


    #print(model)

    cv2.imshow("ABC", model)
    
    #waits for user to press any key
    #(this is necessary to avoid Python kernel form crashing)
    cv2.waitKey(0)

    #closing all open windows
    cv2.destroyAllWindows()

2 commentaires

Quelle est la variable "Data" dans la fonction de train?


Modifié cette variable.



0
votes

Je reçois cette erreur, je ne sais pas ce que cela signifie

ValueError                                Traceback (most recent call last)
 in 
    288 
    289         # test on validation
--> 290         predictions = test(dataset_valid, dataloader_valid)
    291         accuracy_valid = 100. * predictions.eq(dataset_valid.dataset.targets[dataset_valid.indices].to(device)).sum().float() / len(dataset_valid)
    292 

 in test(dataset, dataloader)
    236                     print("Predicted Label")
    237                     print(predictions[sampleno])
--> 238                     showimages(inputs[sampleno].cpu())
    239             return predictions
    240 

 in showimages(model)
    240 
    241 def showimages(model):
--> 242     model=np.transpose(model.numpy(),(1,2,0))
    243 
    244     

<__array_function__ internals> in transpose(*args, **kwargs)

~/.local/lib/python3.7/site-packages/numpy/core/fromnumeric.py in transpose(a, axes)
    649 
    650     """
--> 651     return _wrapfunc(a, 'transpose', axes)
    652 
    653 

~/.local/lib/python3.7/site-packages/numpy/core/fromnumeric.py in _wrapfunc(obj, method, *args, **kwds)
     59 
     60     try:
---> 61         return bound(*args, **kwds)
     62     except TypeError:
     63         # A TypeError occurs if the object does have such a method in its

ValueError: axes don't match array


11 commentaires

Quelle est la taille de votre entrée et SAMPLENO] ?


TORCH.SIZE ([5000, 784]) et TORCH.SIZE ([784]). Le problème est que c'est en niveaux de gris et vous avez utilisé RVB (je pense)


Ok, je l'ai fait pour le tableau 3D (image RVB) mais dans votre entrée de cas est un vecteur, j'ai modifié ma réponse. Je suppose que ça fonctionnera cette fois


Presque fait, j'ai eu une erreur de forme non valide (1, 28, 28) sur plt.imshow (je devais l'utiliser parce que j'ai des problèmes avec OpenCV sur Debian)


Utilisez np.squeeze () fonction avant d'utiliser plt.inshow


Je l'ai fait, mais je n'ai aucune sortie. Je suis désolé si je vous dérangeant, j'ai juste besoin de voir les images et c'est plus difficile que de construire le modèle.


Qu'entendez-vous par aucune sortie? Erreur MSG? ou image vierge? Est-ce que vous normalisez les images avant de vous entraîner et de tester?


Je ne veux rien dire, pas de message d'erreur, pas de lignes / d'images vierges, absolument rien. J'ai normalisé l'entrée et dénormalisée avant l'impression que vous l'avez fait dans votre exemple (avec 1 valeur au lieu de 3)


Essayez la version opencv. Je n'utilise pas plt autant.


J'ai oublié de plt.show () après PLT.IMSHOW, et cela fonctionne maintenant. Merci beaucoup pour votre aide.


Vous pouvez supprimer cette réponse et cette conversation. Puisque ce n'est pas la réponse réelle