Java Deep Learning Tutorial

So bauen Sie ein neuronales Netzwerk auf

14.01.2024
Von 
Matthew Tyson ist Java-Entwickler und schreibt unter anderem für unsere US-Schwesterpublikation Infoworld.com.
Um neuronale Netze wirklich zu verstehen, erstellen Sie am besten selbst eines. Dieses Tutorial zeigt Ihnen, wie das mit Java funktioniert.
Künstliche neuronale Netze bilden eine Säule moderner, künstlicher Intelligenz.
Künstliche neuronale Netze bilden eine Säule moderner, künstlicher Intelligenz.
Foto: Sergey Tarasov - shutterstock.com

Künstliche neuronale Netze sind eine Form des Deep Learning. Der beste Weg, um ihre Funktionsweise vollständig zu durchdringen, besteht darin, sich selbst die Hände "schmutzig" zu machen. Dieser Artikel liefert Ihnen dafür die Grundlage und demonstriert, wie Sie ein neuronales Netzwerk in Java aufbauen und trainieren. Unser Beispiel für diesen Artikel ist dabei keineswegs ein produktionsreifes System - vielmehr gibt es in verständlicher Form Aufschluss über alle Hauptkomponenten.

Ein grundlegendes, neuronales Netz

Ein neuronales Netz ist ein Graph, bestehend aus Knoten (Nodes), die Neuronen genannt werden. Das Neuron ist die Grundeinheit der Berechnung: Es empfängt Eingaben und verarbeitet diese mithilfe:

  • eines Weight-per-Input-Algorithmus,

  • eines Bias-per-Node-Algorithmus sowie

  • eines Final-Function-Processor-Algorithmus.

Nachfolgende Abbildung zeigt ein Neuron mit zwei Inputs:

Ein neuronales Netz bestehend aus zwei Input-Neuronen.
Ein neuronales Netz bestehend aus zwei Input-Neuronen.
Foto: Foundry / Matthew Tyson

Dieses Modell ist sehr variabel - im Folgenden verwenden wir diese Konfiguration.

Unser erster Schritt besteht darin, eine Neuron-Class zu modellieren, die diese Werte enthalten soll. Eine erste Version der Class sehen Sie im folgenden Listing 1 - diese wird sich im weiteren Verlauf verändern, wenn weitere Funktionen hinzukommen.

Listing 1: Eine einfache Neuron-Class

class Neuron {

Random random = new Random();

private Double bias = random.nextDouble(-1, 1);

public Double weight1 = random.nextDouble(-1, 1);

private Double weight2 = random.nextDouble(-1, 1);

public double compute(double input1, double input2){

double preActivation = (this.weight1 * input1) + (this.weight2 * input2) + this.bias;

double output = Util.sigmoid(preActivation);

return output;

}

}

Wie Sie sehen, ist die Neuron-Class recht simpel und weist drei Mitglieder auf: bias, weight1 und weight2. Jedes dieser Mitglieder wird mit einem zufälligen Double zwischen -1 und 1 initialisiert.

Geht es darum, den Output des Neurons zu berechnen, folgen wir dem in Abbildung 1 gezeigten Algorithmus: Wir multiplizieren jede Eingabe mit ihrer Gewichtung plus dem Bias: input1 * weight1 + input2 * weight2 + bias. So erhalten wir die unverarbeitete Berechnung (preActivation), die wir durch die Aktivierungsfunktion laufen lassen. In diesem Fall verwenden wir die Sigmoid-Aktivierungsfunktion, die Werte in einem Bereich von -1 bis 1 komprimiert. Im Folgenden die statische Util.sigmoid()-Methode.

Listing 2: Sigmoid-Aktivierungsfunktion

public class Util {

public static double sigmoid(double in){

return 1 / (1 + Math.exp(-in));

}

}

Nachdem wir nun die Funktionsweise von Neuronen beleuchtet haben, gilt es, einige Neuronen in ein Netzwerk einfügen. Dazu nutzen wir eine Network-Class mit einer Liste von Neuronen.

Listing 3: Die Neural Network Class

class Network {

List<Neuron> neurons = Arrays.asList(

new Neuron(), new Neuron(), new Neuron(), /* input nodes */

new Neuron(), new Neuron(), /* hidden nodes */

new Neuron()); /* output node */

}

}

Obwohl die Liste der Neuronen eindimensional ist, werden wir sie während der Nutzung zu einem Netzwerk verbinden. Die ersten drei Neuronen sind Inputs, die folgenden beiden versteckt und das letzte der Output-Knoten.

Eine Prediction anstoßen

Nun soll es darum gehen, ein Netzwerk zu Prediction-Zwecken einzusetzen. Dazu verwenden wir einen einfachen Datensatz mit zwei ganzzahligen Inputs und einem Antwortformat von 0 bis 1. In unserem Beispiel wird eine Kombination aus Gewicht und Größe verwendet, um das Geschlecht einer Person zu erraten.

Dabei wird davon ausgegangen, dass mehr Gewicht und Größe auf eine männliche Person hindeuten. Dieselbe Formel ließe sich für jede beliebige Wahrscheinlichkeitsrechnung mit zwei Faktoren und einem Output nutzen. Den Input könnte man auch als Vektor betrachten - und somit die Gesamtfunktion der Neuronen als Umwandlung eines Vektors in einem Skalarwert. Die Prediction-Phase des Netzes gestaltet sich wie folgt.

Listing 4: Network prediction

public Double predict(Integer input1, Integer input2){

return neurons.get(5).compute(

neurons.get(4).compute(

neurons.get(2).compute(input1, input2),

neurons.get(1).compute(input1, input2)

),

neurons.get(3).compute(

neurons.get(1).compute(input1, input2),

neurons.get(0).compute(input1, input2)

)

);

}

Listing 4 zeigt, dass die beiden Inputs an die ersten drei Neuronen fließen. Deren Outputs werden an die Neuronen 4 und 5 weitergeleitet wird, die wiederum in das Output-Neuron einspeisen. Dieser Prozess wird als Feedforward bezeichnet. Nun könnten wir das Netz zu einer Prediction auffordern.

Listing 5: Prediction

Network network = new Network();

Double prediction = network.predict(Arrays.asList(115, 66));

System.out.println("prediction: " + prediction);

Das würde sicher zu Ergebnissen führen - die allerdings nur auf Zufallswerten und Bias basieren. Für eine echte Prediction ist es nötig, das Netzwerk zuvor zu trainieren.

Das Netzwerk trainieren

Das Training eines neuronalen Netzwerks folgt einem Prozess, der als Backpropagation bekannt ist. Der beinhaltet im Grunde, Änderungen rückwärts durch das Netzwerk zu "schieben", damit sich der Output in Richtung eines gewünschten Zielwerts bewegt. Backpropagation lässt sich mit Hilfe von Funktionsdifferenzierung durchführen - für unser Beispiel werden wir allerdings einen anderen Weg gehen und jedem Neuron die Fähigkeit verleihen, zu "mutieren".

In jeder Trainingsrunde (auch Epoch genannt) wählen wir ein anderes Neuron aus, um eine kleine, zufällige Anpassung an einer seiner Eigenschaften (weight1, weight2 oder bias) vorzunehmen und dann zu prüfen, ob sich die Ergebnisse verbessern. Ist das der Fall, behalten wir diese Änderung mit einer remember()-Methode bei. Wenn sich die Ergebnisse verschlechtert haben, machen wir sie mit einer forget()-Methode rückgängig.

Um die Änderungen zu tracken, fügen wir Class-Mitglieder hinzu (old*-Versionen von weights und bias). Im Folgenden betrachten wir die Methoden mutate(), remember() und forget().

Listing 6: mutate(), remember(), forget()

public class Neuron() {

private Double oldBias = random.nextDouble(-1, 1), bias = random.nextDouble(-1, 1);

public Double oldWeight1 = random.nextDouble(-1, 1), weight1 = random.nextDouble(-1, 1);

private Double oldWeight2 = random.nextDouble(-1, 1), weight2 = random.nextDouble(-1, 1);

public void mutate(){

int propertyToChange = random.nextInt(0, 3);

Double changeFactor = random.nextDouble(-1, 1);

if (propertyToChange == 0){

this.bias += changeFactor;

} else if (propertyToChange == 1){

this.weight1 += changeFactor;

} else {

this.weight2 += changeFactor;

};

}

public void forget(){

bias = oldBias;

weight1 = oldWeight1;

weight2 = oldWeight2;

}

public void remember(){

oldBias = bias;

oldWeight1 = weight1;

oldWeight2 = weight2;

}

}

Zusammengefasst:

  • Die mutate()-Methode wählt eine zufällige Eigenschaft und einen zufälligen Wert zwischen -1 und 1 aus und ändert dann die Eigenschaft.

  • Die forget()-Methode setzt diese Änderung auf den alten Wert zurück.

  • Die remember()-Methode kopiert den neuen Wert in den Puffer.

Um nun die neuen Fähigkeiten unseres Neurons zu nutzen, fügen wir Network eine train()-Methode hinzu.

Listing 7: Die Network.train()-Methode

public void train(List<List<Integer>> data, List<Double> answers){

Double bestEpochLoss = null;

for (int epoch = 0; epoch < 1000; epoch++){

// adapt neuron

Neuron epochNeuron = neurons.get(epoch % 6);

List<Double> predictions = new ArrayList<Double>();

for (int i = 0; i < data.size(); i++){

predictions.add(i, this.predict(data.get(i).get(0), data.get(i).get(1)));

}

Double thisEpochLoss = Util.meanSquareLoss(answers, predictions);

if (bestEpochLoss == null){

bestEpochLoss = thisEpochLoss;

epochNeuron.remember();

} else {

if (thisEpochLoss < bestEpochLoss){

bestEpochLoss = thisEpochLoss;

epochNeuron.remember();

} else {

epochNeuron.forget();

}

}

}

Die train()-Methode iteriert eintausendmal über die aufgeführten data, answers und lists. Es handelt sich um gleich große Trainingsmengen: data beinhaltet Input-Werte, answers die bekannten, richtigen Antworten. Die Methode ermittelt dann einen Wert darüber, wie nahe das Ergebnis des Netzwerks den bekannten, richtigen Antworten kommt. Dann wird ein zufälliges Neuron verändert (mutiert), wobei die Änderung beibehalten wird, wenn ein neuer Test ergibt, dass sie eine bessere Vorhersage zur Folge hatte.

Die Ergebnisse lassen sich mithilfe der Mean-Squared-Error (MSE) -Formel überprüfen - einer dafür gängigen Methode.

Listing 8: MSE-Funktion

public static Double meanSquareLoss(List<Double> correctAnswers, List<Double> predictedAnswers){

double sumSquare = 0;

for (int i = 0; i < correctAnswers.size(); i++){

double error = correctAnswers.get(i) - predictedAnswers.get(i);

sumSquare += (error * error);

}

return sumSquare / (correctAnswers.size());

}