Java Deep Learning Tutorial

So bauen Sie ein neuronales Netzwerk auf

29.05.2024
Von 
Matthew Tyson ist Java-Entwickler und schreibt unter anderem für unsere US-Schwesterpublikation Infoworld.com.

System feinabstimmen

Nun müssen wir nur noch einige Trainingsdaten in das Netz fließen lassen und es mit weiteren Predictions austesten. Im Folgenden betrachten wir, wie man Trainingsdaten bereitstellt.

Listing 9: Trainingsdaten

List<List<Integer>> data = new ArrayList<List<Integer>>();

data.add(Arrays.asList(115, 66));

data.add(Arrays.asList(175, 78));

data.add(Arrays.asList(205, 72));

data.add(Arrays.asList(120, 67));

List<Double> answers = Arrays.asList(1.0,0.0,0.0,1.0);

Network network = new Network();

network.train(data, answers);

In Listing 9 bestehen unsere Trainingsdaten aus einer Liste von zweidimensionalen Integer-Sets (wir könnten sie uns als Gewicht und Größe vorstellen) und einer Liste von Antworten (wobei 1.0 weiblich und 0.0 männlich ist).

Wenn wir den Trainingsalgorithmus eine Logging-Funktionalität hinzufügen, erhalten wir nachfolgendes Resultat.

Listing 10. Trainingsdaten-Logging

// Logging:

if (epoch % 10 == 0) System.out.println(String.format("Epoch: %s | bestEpochLoss: %.15f | thisEpochLoss: %.15f", epoch, bestEpochLoss, thisEpochLoss));

// output:

Epoch: 910 | bestEpochLoss: 0.034404863820424 | thisEpochLoss: 0.034437939546120

Epoch: 920 | bestEpochLoss: 0.033875954196897 | thisEpochLoss: 0.431451026477016

Epoch: 930 | bestEpochLoss: 0.032509260025490 | thisEpochLoss: 0.032509260025490

Epoch: 940 | bestEpochLoss: 0.003092720117159 | thisEpochLoss: 0.003098025397281

Epoch: 950 | bestEpochLoss: 0.002990128276146 | thisEpochLoss: 0.431062364628853

Epoch: 960 | bestEpochLoss: 0.001651762688346 | thisEpochLoss: 0.001651762688346

Epoch: 970 | bestEpochLoss: 0.001637709485751 | thisEpochLoss: 0.001636810460399

Epoch: 980 | bestEpochLoss: 0.001083365453009 | thisEpochLoss: 0.391527869500699

Epoch: 990 | bestEpochLoss: 0.001078338540452 | thisEpochLoss: 0.001078338540452

Wie in Listing 10 zu sehen, nimmt der "Loss" (also die Fehlerabweichung von "100 Prozent korrekt") langsam ab. Das Modell nähert sich also immer mehr einer genauen Vorhersage an. Nun gilt es zu überprüfen, wie gut unser Modell mit echten Daten funktioniert.

Listing 11: Vorhersagen

System.out.println("");

System.out.println(String.format(" male, 167, 73: %.10f", network.predict(167, 73)));

System.out.println(String.format("female, 105, 67: %.10", network.predict(105, 67)));

System.out.println(String.format("female, 120, 72: %.10f | network1000: %.10f", network.predict(120, 72)));

System.out.println(String.format(" male, 143, 67: %.10f | network1000: %.10f", network.predict(143, 67)));

System.out.println(String.format(" male', 130, 66: %.10f | network: %.10f", network.predict(130, 66)));

Wie in Listing 11 zu sehen, füttern wir unser trainiertes neuronales Netz mit Daten und geben die Vorhersagen aus. Das Resultat sieht in etwa wie folgt aus.

Listing 12: Trainierte Predictions

male, 167, 73: 0.0279697143

female, 105, 67: 0.9075809407

female, 120, 72: 0.9075808235

male, 143, 67: 0.0305401413

male, 130, 66: network: 0.9009811922

Listing 12 zeigt, dass das Netzwerk bei den meisten Wertepaaren (Vektoren) ziemlich gute Arbeit geleistet hat. Es gibt den weiblichen Datensätzen eine Schätzung um 0.907 - was ziemlich nahe an 1 ist. Zwei männliche Datensätze weisen 0.027 und 0.030 auf - und nähern sich damit der 0. Der männliche Ausreißer-Datensatz (130, 67) wird als "wahrscheinlich weiblich" angesehen, bei einem Wert von 0.900 allerdings mit geringerer Zuversicht.

Es gibt eine Reihe von Möglichkeiten, die Einstellungen an diesem System zu verändern: Die Anzahl der Epochs in einem Trainingslauf ist dabei ein wichtiger Faktor. Je mehr Epochs, desto besser wird das Modell auf die Daten abgestimmt. Das kann auch die Genauigkeit von Live-Daten verbessern, die mit den Trainingssätzen übereinstimmen. Allerdings kann es auch in einem "Overtraining" resultieren - also einem Modell, das zuversichtlich die falschen Ergebnisse für Randfälle vorhersagt.

Den vollständigen Code für dieses Tutorial finden Sie in diesem GitHub-Repository - zusammen mit einigen zusätzlichen Funktionen. (fm)

Dieser Beitrag basiert auf einem Artikel unserer US-Schwesterpublikation Infoworld.