One Shot Learning con le reti siamesi usando PyTorch

Articolo in lingua originale di Harshvardhan Gupta

Esempio di One Shot Learning.  Sorgente.

Questo articolo assume che ci sia familiarità con le reti neurali. 

Questo articolo rappresenta la prima parte, potete trovare la seconda qui.

Le reti neurali profonde sono l’algoritmo di riferimento per la classificazione delle immagini. Ciò è dovuto in parte al fatto che possono avere un numero arbitrariamente elevato di parametri addestrabili. Tuttavia, questo ha il costo di richiedere una grande quantità di dati, che a volte non sono disponibili. In questo articolo parleremo di un metodo di apprendimento, il One Shot Learning, che mira a mitigare questo problema. Nella seconda parte dell’articolo affronteremo, inoltre, come implementare in PyTorch una rete neurale in grado di utilizzare questo metodo di apprendimento.

Questo articolo prende spunto da questo paper.

Classificazione standard vs. classificazione One Shot

La classificazione standard è quella utilizzata da quasi tutti i modelli di classificazione. L’input viene immesso in una serie di layer e alla fine vengono emesse le probabilità di appartenere a ciascuna classe che l’immagine ha. Se si vuole classificare un’immagine come raffigurante un cane o un gatto, si addestra il modello su immagini simili (ma non uguali) di cani/gatti che la rete riceverà poi durante la fase di predizione vera e propria. Naturalmente, ciò richiede che si disponga di un set di dati simile a quello che ci si aspetterebbe una volta utilizzato il modello per la predizione.

I modelli di classificazione One Shot, invece, richiedono un solo campione di addestramento per ogni classe che si vuole prevedere. Il modello viene comunque addestrato su diverse istanze, ma queste devono appartenere a un dominio simile a quello dell’esempio di addestramento.

Un buon esempio potrebbe essere il riconoscimento facciale. Si potrebbe addestrare un modello di classificazione One Shot su un set di dati che contiene immagini di alcune persone scattate con diverse angolazioni, diversa illuminazione, ecc. Per riconoscere se un’immagine contiene la persona X si scatta una singola foto di quella persona e poi si chiede al modello se quella persona è presente in quell’immagine (si noti che il modello non è stato addestrato utilizzando alcuna foto della persona X).

Come esseri umani siam iin grado di riconoscere una persona dalla sua faccia anche avendola incontrata una sola volta, vogliamo che anche i computer abbiano tale abilità in quanto spesso non disponiamo di molti dati, ovvero non abbiamo un numero sufficiente di immagini della stessa persona.

Introduzione alle reti siamesi

Le reti siamesi rappresentano un tipo speciale di architettura di reti neurali. Invece che un modello che apprende come classificare gli input che riceve, queste reti neurali imparano a distinguere tra due input, imparano la somiglianza tra di essi.

Architettura

Le reti siamesi sono due reti neurali identiche, ciascuna di esse prende in input due immagini. Gli ultimi livelli delle reti rappresentano poi l’input di una funzione di perdita di paragone che calcola la somiglianza tra le due immagini. Ho realizzato un’immagine che può aiutare a capire quest’architettura:

Figure 1

Ci sono due reti sorelle, che sono reti neurali identiche, esattamente con gli stessi pesi. Ciascun immagine nella coppia di immagini viene inviata come input a queste reti. Le reti sono ottimizzate utilizzando una funzione di perdita di paragone (arriveremo poi alla funzione esatta).

Funzione di perdita di paragone

L’obiettivo dell’architettura siamese non è quello di classificare le immagini in input, ma essere in grado di distinguerle. Quindi una funzione di perdita utilizzata per la classificazione, come la Cross Entropy, per esempio, non sarebbe una scelta adatta. Invece, quest’architettura si adatta meglio all’utilizzo di una funzione di perdita di paragone. Intuitivamente, questa funzione valuta semplicemente quanto bene la rete sta distinguendo una coppia di immagini date.

Potete leggere più dettagli in questo articolo.

La funzione di perdita di paragone ha questa formula:

Equazione 1

Dove Dw viene definita come la distanza euclidea tra due output delle reti siamesi. Matematicamente, la distanza euclidea si calcola come:

Equazione 1.1

Dove Gw rappresenta l’output di una delle due reti sorelle. X1 e X2 sono la coppia di dati in input.

 

Spiegazione dell’equazione 1

Y può assumere valore 0 o 1. Se gli input provengono dalla stessa classe, allora Y = 0, altrimenti Y = 1.

Max() è una funzione che estrae il valore maggiore tra 0 e m-Dw, dove m è un valore margine maggiore di 0. Avere un margine significa che le coppie dissimili con un indice di somiglianza oltre il margine non contribuiscono alla funzione di perdita. Questo ha senso in quanto vorremmo ottimizzare la rete solamente basandoci sulle coppie che sono effettivamente dissimili e che la rete considera, invece, simili.

 

I dataset

Utilizzeremo due dataset: MNIST (classico) e OmniGlot. MNIST verrà utilizzato per addestrare il modello a comprendere come distinguere i caratteri e poi testeremo il modello utilizzando OmniGlot.

 

Omniglot

Questo dataset consiste in campioni provenienti da 50 lingue diverse. Ciascun alfabeto ha solamente 20 campioni. Questo dataset è considerato un “trasposto” di MNIST che contiene invece 10 classi dove i campioni sono numeri. In OmniGlot c’è un numero molto grande di classi, con pochi esempi in ciascuna di esse.

Figura 2: Alcuni esempi dal dataset OmniGlot

OmniGlot verrà utilizzato come dataset di classificazione one shot per essere in grado di riconoscere molte classi diverse attraverso esempi alla mano.

Conclusione

Abbiamo fatto un salto nelle premesse generali del one shot learning provando a risolvere il problema attraverso un’architettura di rete neurale chiamata rete siamese. Abbiamo parlato di funzione di perdita che distingue tra una coppia di input.

Nella seconda parte di questo articolo andremo a fonto implementando un’architettura di questo tipo, addestrandola sul dataset MNIST per poi ottenere le predizioni su OmniGlot.

Share:

Contenuti
Torna in alto