Warum wir KI-Prinzipien brauchen

Künstliche Intelligenz verändert unsere Welt grundlegend. Algorithmen beeinflussen zunehmend, wie wir uns verhalten, denken und fühlen. Unternehmen rund um den Globus werden KI-Technologien zunehmend nutzen und ihre derzeitigen Prozesse und Geschäftsmodelle neu erfinden. Unsere sozialen Strukturen, die Art und Weise, wie wir arbeiten und wie wir miteinander interagieren, werden sich mit den Fortschritten der Digitalisierung, insbesondere der KI, verändern.

Neben ihrem sozialen und wirtschaftlichen Einfluss spielt KI auch eine wichtige Rolle bei einer der größten Herausforderungen unserer Zeit: dem Klimawandel. Einerseits kann KI Instrumente bereitstellen, um einen Teil dieser dringenden Herausforderung zu bewältigen. Andererseits wird die Entwicklung und Implementierung von KI-Anwendungen viel Energie verbrauchen und große Mengen an Treibhausgasen ausstoßen.

Risiken der KI

Mit dem Fortschritt einer Technologie, die einen so großen Einfluss auf alle Bereiche unseres Lebens hat, gehen große Chancen, aber auch große Risiken einher. Um Euch einen Eindruck von den Risiken zu vermitteln, haben wir sechs Beispiele herausgegriffen:

  • KI kann zur Überwachung von Menschen eingesetzt werden, zum Beispiel durch Gesichtserkennungssysteme. Einige Länder setzen diese Technologie bereits seit einigen Jahren intensiv ein.
  • KI wird in sehr sensiblen Bereichen eingesetzt. In diesen können schon kleine Fehlfunktionen dramatische Auswirkungen haben. Beispiele dafür sind autonomes Fahren, robotergestützte Chirurgie, Kreditwürdigkeitsprüfung, Auswahl von Bewerber:innen oder Strafverfolgung.
  • Der Skandal um Facebook und Cambridge Analytica hat gezeigt, dass Daten und KI-Technologien zur Erstellung psychografischer Profile genutzt werden können. Diese Profile ermöglichen die gezielte Ansprache von Personen mit maßgeschneiderten Inhalten. Beispielsweise zur Beeinflussung von politischen Wahlen. Dieses Beispiel zeigt die enorme Macht der KI-Technologien und die Möglichkeit für Missbrauch und Manipulation.
  • Mit den jüngsten Fortschritten in der Computer Vision Technologie können Deep Learning Algorithmen nun zur Erstellung von Deepfakes verwendet werden. Deepfakes sind realistische Videos oder Bilder von Menschen, in denen diese etwas tun oder sagen, was sie nie in der Realität getan oder gesagt haben. Die Möglichkeiten für Missbrauch dieser Technologie sind vielfältig.
  • KI-Lösungen werden häufig entwickelt, um manuelle Prozesse zu verbessern oder zu optimieren. Es wird Anwendungsfälle geben, bei denen menschliche Arbeit ersetzt wird. Dabei entstehen unterschiedlichste Herausforderungen, die nicht ignoriert, sondern frühzeitig angegangen werden müssen.
  • In der Vergangenheit haben KI-Modelle diskriminierende Muster der Daten, auf denen sie trainiert wurden, reproduziert. So hat Amazon beispielsweise ein KI-System in seinem Rekrutierungsprozess eingesetzt, das Frauen eindeutig benachteiligte.

Diese Beispiele machen deutlich, dass jedes Unternehmen und jede Person, die KI-Systeme entwickelt, sehr sorgfältig darüber nachdenken sollte, welche Auswirkungen das System auf die Gesellschaft, bestimmte Gruppen oder sogar Einzelpersonen haben wird oder haben könnte.

Daher besteht die große Herausforderung für uns darin, sicherzustellen, dass die von uns entwickelten KI-Technologien den Menschen helfen und sie befähigen, während wir gleichzeitig potenzielle Risiken minimieren.

Warum gibt es im Jahr 2022 keine offizielle Regelung?

Vielleicht fragt Ihr euch, warum es keine Gesetze gibt, die sich mit diesem Thema befassen. Das Problem bei neuen Technologien, insbesondere bei künstlicher Intelligenz, ist, dass sie sich schnell weiterentwickeln, manchmal sogar zu schnell.

Die jüngsten Veröffentlichungen neuer Sprachmodelle wie GPT-3 oder Computer Vision Modelle, z. B. DALLE-2, haben selbst die Erwartungen vieler KI-Expert:innen übertroffen. Die Fähigkeiten und Anwendungen der KI-Technologien werden sich schneller weiterentwickeln, als die Regulierung es kann. Und wir sprechen hier nicht von Monaten, sondern von Jahren.

Dabei ist zu erwähnen, dass die EU einen ersten Versuch in diese Richtung unternommen hat, indem sie eine Regulierung von künstlicher Intelligenz vorgeschlagen hat. In diesem Vorschlag wird jedoch darauf hingewiesen, dass die Verordnung frühestens in der zweiten Hälfte des Jahres 2024 für die anwendenden Unternehmen gelten könnte. Das sind Jahre, nachdem die oben beschriebenen Beispiele Realität geworden sind.

Unser Ansatz: statworx AI Principles

Die logische Konsequenz daraus ist, dass wir uns als Unternehmen selbst dieser Herausforderung stellen müssen. Und genau deshalb arbeiten wir derzeit an den statworx AI Principles, einer Reihe von Prinzipien, die uns bei der Entwicklung von KI-Lösungen leiten und Orientierung geben sollen.

Was wir bisher getan haben und wie wir dazu gekommen sind

In unserer Arbeitsgruppe „AI & Society“ haben wir begonnen, uns mit diesem Thema zu beschäftigen. Zunächst haben wir den Markt gescannt und viele interessante Paper gefunden. Allerdings sind wir zu dem Schluss gekommen, dass sich keins davon 1:1 auf unser Geschäftsmodell übertragen lässt. Oft waren diese Prinzipien oder Richtlinien sehr schwammig oder zu detailliert und zusätzlich ungeeignet für ein Beratungsunternehmen, das im B2B-Bereich als Dienstleister tätig ist. Also beschlossen wir, dass wir selbst eine Lösung entwickeln mussten.

In den ersten Diskussionen darüber wurden vier große Herausforderungen deutlich:

  • Einerseits müssen die AI Principles klar und für das breite Publikum verständlich formuliert sein, damit auch Nicht-Expert:innen sie verstehen. Andererseits müssen sie konkret sein, um sie in unseren Entwicklungsprozess integrieren zu können.
  • Als Dienstleister haben wir nur begrenzte Kontrolle und Entscheidungsgewalt über einige Aspekte einer KI-Lösung. Daher müssen wir verstehen, was wir entscheiden können und was außerhalb unserer Kontrolle liegt.
  • Unsere AI Principles werden nur dann einen nachhaltigen Mehrwert schaffen, wenn wir auch nach ihnen handeln können. Deshalb müssen wir sie in unseren Kundenprojekten anwenden und bewerben. Wir sind uns darüber im Klaren, dass Budgetzwänge, finanzielle Ziele und andere Faktoren dem entgegenstehen könnten, da es zusätzlichen Zeit- und Geldaufwand erfordert.
  • Außerdem ist nicht immer klar, was falsch und richtig ist. Unsere Diskussionen haben gezeigt, dass es viele unterschiedliche Auffassungen darüber gibt, was richtig und notwendig ist. Das bedeutet, dass wir eine gemeinsame Basis finden müssen, auf die wir uns als Unternehmen einigen können.

Unsere zwei wichtigsten Erkenntnisse

Eine wichtige Erkenntnis aus diesen Überlegungen war, dass wir zwei Dinge brauchen.

In einem ersten Schritt brauchen wir übergeordnete Grundsätze, die verständlich und klar sind und bei denen alle mit an Bord sind. Diese Grundsätze dienen als Leitidee und geben Orientierung bei der Entscheidungsfindung. In einem zweiten Schritt wird daraus ein Framework abgeleitet, welches diese Grundsätze in allen Phasen unserer Projekte in konkrete Maßnahmen übersetzt.

Die zweite wichtige Erkenntnis ist, dass es durchaus schwierig ist, diesen Prozess zu durchlaufen und sich diese Fragen zu stellen. Aber gleichzeitig auch, dass dies für jedes Unternehmen, das KI-Technologie entwickelt oder einsetzt, unvermeidlich ist.

 

Was kommt als nächstes?

Bis jetzt sind wir fast am Ende des ersten Schritts angelangt. Wir werden die statworx AI Principles bald über unsere Kanäle kommunizieren. Wenn Ihr euch ebenfalls in diesem Prozess befindet, würden wir uns freuen, mit Euch in Kontakt zu treten, um zu verstehen, wie ihr vorgegangen seid und was ihr dabei gelernt habt.

Quellen

https://www.nytimes.com/2019/04/14/technology/china-surveillance-artificial-intelligence-racial-profiling.html

https://www.nytimes.com/2018/04/04/us/politics/cambridge-analytica-scandal-fallout.html

https://www.reuters.com/article/us-amazon-com-jobs-automation-insight-idUSKCN1MK08G

https://digital-strategy.ec.europa.eu/en/policies/european-approach-artificial-intelligence

https://www.bundesregierung.de/breg-de/themen/umgang-mit-desinformation/deep-fakes-1876736

https://www.welt.de/wirtschaft/article173642209/Jobverlust-Diese-Jobs-werden-als-erstes-durch-Roboter-ersetzt.html

Jan Fischer Jan Fischer Jan Fischer Jan Fischer Jan Fischer Jan Fischer Alexander Blaufuss

Inhalt

Einführung

Je komplexer ein beliebiges Data Science Projekt in Python wird, desto schwieriger wird es in der Regel, den Überblick darüber zu behalten, wie alle Module miteinander interagieren. Wenn man in einem Team an einem größeren Projekt arbeitet, wie es hier bei STATWORX oft der Fall ist, kann die Codebasis schnell so groß werden, dass die Komplexität abschreckend wirken kann. In einem typischen Szenario arbeitet jedes Teammitglied in seiner „Ecke“ des Projekts, so dass jeder nur über ein solides lokales Wissen über den Code des Projekts verfügt, aber möglicherweise nur eine vage Vorstellung von der Gesamtarchitektur des Projekts hat. Im Idealfall sollte jedoch jeder, der an dem Projekt beteiligt ist, einen guten globalen Überblick über das Projekt haben. Damit meine ich nicht, dass man wissen muss, wie jede Funktion intern funktioniert, sondern eher, dass man die Zuständigkeit der Hauptmodule kennt und weiß, wie sie miteinander verbunden sind.

Ein visuelles Hilfsmittel, um die globale Struktur kennenzulernen, kann ein Call Graph sein. Ein Call Graph ist ein gerichteter Graph, der anzeigt, welche Funktion welche Funktion aufruft. Er wird aus den Daten eines Python-Profilers wie cProfile erstellt.

Da sich ein solcher Graph in einem Projekt, an dem ich arbeite, als hilfreich erwiesen hat, habe ich ein Paket namens project_graph erstellt, das einen solchen Call Graph für ein beliebiges Python-Skript erstellt. Das Paket erstellt ein Profil des gegebenen Skripts über cProfile, konvertiert es in einen gefilterten Punktgraphen über gprof2dot und exportiert es schließlich als .png-Datei.

Warum sind Projektgrafiken nützlich?

Als erstes kleines Beispiel soll dieses einfache Modul dienen.

# test_script.py

import time
from tests.goodnight import sleep_five_seconds

def sleep_one_seconds():
    time.sleep(1)

def sleep_two_seconds():
    time.sleep(2)

for i in range(3):
    sleep_one_seconds()

sleep_two_seconds()

sleep_five_seconds()

Nach der Installation (siehe unten) wird durch Eingabe von project_graph test_script.py in die Kommandozeile die folgende png-Datei neben dem Skript platziert:

Das zu profilierende Skript dient immer als Ausgangspunkt und ist die Wurzel des Baums. Jedes Kästchen ist mit dem Namen einer Funktion, dem Gesamtprozentsatz der in der Funktion verbrachten Zeit und der Anzahl ihrer Aufrufe beschriftet. Die Zahl in Klammern gibt an, wieviel Zeit innerhalb einer Funktion verbracht wurde, jedoch ohne die Zeit in weiteren Unterfunktion zu berücksichtigen.

In diesem Fall wird die gesamte Zeit in der Funktion sleep des externen Moduls time verbracht, weshalb die Zahl 0,00% beträgt. In selbstgeschriebenen Funktionen wird nur selten viel Zeit verbracht, da die Arbeitslast eines Skripts in der Regel schnell auf sehr einfache Funktionen der Python-Implementierung selbst rausläuft. Neben den Pfeilen ist auch die Zeit angegeben, die eine Funktion an die andere weitergibt, zusammen mit der Anzahl der Aufrufe. Die Farben (ROT-GRÜN-BLAU, absteigend) und die Dicke der Pfeile zeigen die Relevanz der verschiedenen Stellen im Programm an.

Beachten Sie, dass sich die Prozentsätze der drei obigen Funktionen nicht zu 100 % aufaddieren. Der Grund dafür ist, dass der Graph so eingestellt ist, dass er nur selbst geschriebene Funktionen enthält. In diesem Fall hat das Importieren des Moduls time den Python-Interpreter dazu veranlasst, 0,04% der Zeit für eine Funktion des Moduls importlib aufzuwenden.

Auswertung mit externen Packages

Betrachten wir ein zweites Beispiel:

# test_script_2.py

import pandas as pd
from tests.goodnight import sleep_five_seconds

# some random madness
for i in range(1000):
   a_frame = pd.DataFrame([[1,2,3]])

sleep_five_seconds()

In diesem Skript wird ein Teil der Arbeit in einem externen Paket erledigt, das auf der Top-Ebene und nicht in einer benutzerdefinierten Funktion aufgerufen wird. Um dies im Graphen zu erfassen, können wir das externe Paket (pandas) mit der Flag -x hinzufügen. Die Initialisierung eines Pandas DataFrame wird jedoch in vielen Pandas-internen Funktionen durchgeführt. Offen gesagt, bin ich persönlich nicht an den inneren Verwicklungen von pandas interessiert, weshalb ich möchte, dass der Baum nicht zu tief in die Pandas-Mechanik „hineinwächst“. Diesem Umstand kann man Rechnung tragen, indem man nur Funktionen auftauchen lässt, die einen minimalen Prozentsatz der Laufzeit in ihnen verbringen. Genau dies kann mit der -m-Flag erreicht werden.

In Kombination ergibt project_graph -m 8 -x pandas test_script_2.py das folgende Ergebnis:

Project Graph Creation Example 02

Spaß(-Beispiele) beiseite, nun wollen wir uns ernsteren Dingen zuwenden. Ein echtes Data Science Projekt könnte wie dieses aussehen:

Project Graph Creation Example 03

Dieses Mal ist der Baum viel größer. Er ist sogar noch größer als in der Abbildung zu sehen, da viel mehr selbst geschriebene Funktionen aufgerufen werden. Sie werden jedoch aus Gründen der Übersichtlichkeit aus dem Baum entfernt, da Funktionen, für die weniger als 0,5 % der Gesamtzeit aufgewendet werden, herausgefiltert werden (dies ist die Standardeinstellung für die -m Flag). Beachten Sie, dass ein solches Diagramm auch bei der Suche nach Leistungsengpässen sehr vorteilhaft ist. Man sieht sofort, welche Funktionen den größten Teil der Arbeitslast tragen, wann sie aufgerufen werden und wie oft sie aufgerufen werden. Das kann Sie davor bewahren, Ihr Programm an den falschen Stellen zu optimieren und dabei den Elefanten im Raum zu übersehen.

Wie man project graph verwendet

Installation

Gehen Sie in Ihrer Projektumgebung wie folgt vor:

brew install graphviz

pip install git+https://github.com/fior-di-latte/project_graph.git

Verwendung

Wechseln Sie in der Projektumgebung in das aktuelle Arbeitsverzeichnis des Projekts (das ist wichtig!) und geben Sie für die Standardverwendung ein:

project_graph myscript.py

Wenn Ihr Skript einen argparser enthält, verwenden Sie (vergessen Sie nicht die Anführungsstriche!):

project_graph "myscript.py <arg1> <arg2> (...)"

Wenn Sie den gesamten Graphen sehen wollen, einschließlich aller externen Pakete, verwenden Sie:

project_graph -a myscript.py

Wenn Sie eine andere Sichtbarkeitsschwelle als 1% verwenden wollen, benutzen Sie:

project_graph -m <percent_value> myscript.py

Wenn Sie schließlich externe Pakete in den Graphen aufnehmen wollen, können Sie sie wie folgt angeben:

project_graph -x <package1> -x <package2> (...) myscript.py

Schluss & Hinweise

Dieses Paket hat einige Schwächen, von denen die meisten behoben werden können, z.B. durch Formatierung des Codes in einen funktionsbasierten Stil, durch Trimmen mit der -m-Flag oder durch Hinzufügen von Paketen mit der-x-Flag. Wenn etwas seltsam erscheint ist der erste Schritt wahrscheinlich die Verwendung der -a-Flag zur Fehlersuche. Wesentliche Einschränkungen sind die folgenden:

  • Es funktioniert nur auf Unix-Systemen.
  • Es zeigt keinen wahrheitsgetreuen Graphen an, wenn es mit Multiprocessing verwendet wird. Der Grund dafür ist, dass cProfile nicht mit Multiprocessing kompatibel ist. Wenn Multiprocessing verwendet wird, wird nur der Root-Prozess profiliert, was zu falschen Berechnungszeiten im Graphen führt. Wechseln Sie zu einer nicht-parallelen Version des Zielskripts.
  • Die Profilerstellung eines Skripts kann zu einem beträchtlichen Overhead bei der Berechnung führen. Es kann sinnvoll sein, die in Ihrem Skript geleistete Arbeit zu verringern (d. h. die Menge der Eingabedaten zu reduzieren). In diesem Fall kann die in den Funktionen verbrachte Zeit natürlich massiv verzerrt werden, wenn die Funktionen nicht linear skalieren.
  • Verschachtelte Funktionen werden im Diagramm nicht angezeigt. Insbesondere ein Dekorator verschachtelt implizit Ihre Funktion und versteckt sie daher. Das heißt, wenn Sie einen externen Dekorator verwenden, vergessen Sie nicht, das Paket des Dekorators über die-x Flag hinzuzufügen (zum Beispiel project_graph -x numba myscript.py).
  • Wenn Ihre selbst geschriebene Funktion ausschließlich von einer Funktion eines externen Pakets aufgerufen wird, müssen Sie das externe Paket manuell mit der -x Flag hinzufügen. Andernfalls wird Ihre Funktion nicht im Baum auftauchen, da ihr Parent eine externe Funktion ist und daher nicht berücksichtigt wird.

Sie können das kleine Paket gerne für Ihr eigenes Projekt verwenden, sei es für Leistungsanalysen, Code-Einführungen für neue Teammitglieder oder aus reiner Neugier. Was mich betrifft, so finde ich es sehr befriedigend, eine solche Visualisierung meiner Projekte zu sehen. Wenn Sie Probleme bei der Verwendung haben, zögern Sie nicht, mich auf Github zu kontaktieren (https://github.com/fior-di-latte/project_graph/).

PS: Wenn Sie nach einem ähnlichen Paket in R suchen, sehen Sie sich Jakobs Beitrag über Flussdiagramme von Funktionen an.

Du willst Python lernen? Oder bist du ein R-Profi und dir entfallen bei der Arbeit mit Python regelmäßig die wichtigen Funktionen und Befehle? Oder vielleicht brauchst du von Zeit zu Zeit eine kleine Gedächtnisstütze beim Programmieren? Genau dafür wurden Cheatsheets erfunden!

Cheatsheets helfen dir in all diesen Situationen weiter. Unser erstes Cheatsheet mit den Grundlagen von Python ist der Start einer neuen Blog-Serie, in der weitere Cheatsheets in unserem einzigartigen STATWORX Stil folgen werden.

Du kannst also neugierig sein auf unsere Serie von neuen Python-Cheatsheets, die sowohl Grundlagen als auch Pakete und Arbeitsfelder, die für Data Science relevant sind, behandeln werden.

Unsere Cheatsheets stehen euch zum kostenfreien Download frei zur Verfügung, ohne Anmeldung oder sonstige Paywall.

Warum haben wir neue Cheatsheets erstellt?

Als erfahrene R User sucht man schier endlos nach entsprechend modernen Python Cheatsheets, ähnlich denen, die du von R Studio kennst.

Klar, es gibt eine Vielzahl von Cheatsheets für jeden Themenbereich, die sich aber in Design und Inhalt stark unterscheiden. Sobald man mehrere Cheatsheets in unterschiedlichen Designs verwendet, muss man sich ständig neu orientieren und verliert so insgesamt viel Zeit. Für uns als Data Scientists ist es wichtig, einheitliche Cheatsheets zu haben, anhand derer wir schnell die gewünschte Funktion oder den Befehl finden können.

Diesem nervigen Zusammensuchen von Informationen wollen wir entgegenwirken. Daher möchten wir auf unserem Blog zukünftig regelmäßig neue Cheatsheets in einer Designsprache veröffentlichen – und euch alle an dieser Arbeitserleichterung teilhaben lassen.

 

Was enthält das erste Cheatsheet?

Unser erstes Cheatsheet in dieser Reihe richtet sich in erster Linie an Python-Neulinge, an R-Nutzer, die Python seltener verwenden, oder an Leute, die gerade erst anfangen, mit Python zu arbeiten.

Es erleichtert den Einstieg und Überblick in Python. Die grundlegende Syntax, die Datentypen und der Umgang mit diesen werden vorgestellt und grundlegende Kontrollstrukturen eingeführt. So kannst du schnell auf die Inhalte zugreifen, die du z.B. in unserer STATWORX Academy gelernt hast oder dir die Grundlagen für dein nächstes Programmierprojekt ins Gedächtnis rufen.

Was behandelt das STATWORX Cheatsheet Episode 2?

Das nächste Cheatsheet behandelt den ersten Schritt eines Data Scientists in einem neuen Projekt: Data Wrangling. Außerdem erwartet dich ein Cheatsheet für pandas über das Laden, Auswählen, Manipulieren, Aggregieren und Zusammenführen von Daten. Happy Coding!

Data Science ist in aller Munde, doch wie lässt es sich am besten im Unternehmen einsetzen? Was muss man bei der Planung eines AI Projektes beachten? Was sind die Risiken, und was sind die potenziellen Vorteile? Es sind genau diese Fragen, mit welchen sich die Studierenden der Dualen Hochschulen Baden-Württemberg im Rahmen der Gastvorlesung mit STATWORX auseinandergesetzt haben.

Inhalt

Aspekte der Vorlesung

Anfang Juni haben unser COO Fabian Müller und Data Science Consultant Paul Mora eine Vorlesung im Rahmen des Wirtschaftsingenieur-Studienganges der DHBW gehalten. Der Fokus der Vorlesung war es, den Studierenden bewusst zu machen welche Aspekte es bei der Planung und Evaluierung eines Data Science Projektes zu beachten gilt. Neben den finanziellen Risiken wurde hierbei auch explizit auf die ethischen Fragen der Nutzung von Künstlicher Intelligenz eingegangen.

Fabian Müller, COO bei STATWORX, hält regelmäßig Vorträge an Hochschulen & Universitäten, um aktiv Aufklärung zum Thema künstliche Intelligenz zu betreiben.

Eine unserer Missionen bei STATWORX ist es, unser Wissen mit der Gesellschaft zu teilen. Vorträge an Hochschulen und Universitäten sind dabei eine tolle Möglichkeit, die Generation von morgen für Vorteile und Risiken von KI zu sensibilisieren. 

Hands-on Case Study

Als benotete Hausaufgabe haben sich die Studierenden dann in Gruppen aufgeteilt und einen selbst erdachten Data Science Use Case im Rahmen eines Unternehmens bewertet. Eine besonders gute Bearbeitung der Aufgabe ist dem Team von Christian Paul, Mark Kekel, Sebastian Schmidt und Moritz Brüggemann gelungen. Wie im folgenden Abstract beschrieben, widmete sich das Team der Überlegung des Einsatzes von Data Science bei der Vorhersage von Kundenbestellungen.

Consultant Paul Mora erklärt den Studierenden der DHBW den AI Project Canvas.

Abstract: Anwendung künstlicher Intelligenz im Kontext eines fiktiven Unternehmens

Die vorliegende Fallstudie gibt einen Überblick über die Möglichkeiten einer KI-gesteuerten Problemlösung anhand des fiktiven und aufstrebenden Unternehmens aus dem Bereich der Wintersportausrüster. Hierbei wurden vier unterschiedliche Use-Cases, die von der Nutzung einer KI profitieren, innerhalb einer Machbarkeits-Wirkungs-Matrix analysiert und das Konzept eines KI-gesteuerten After-Sales-Managements priorisiert.

Bezüglich des After-Sales-Managements wurden bis dato keine innovativen Methoden zur Verkaufsförderung entwickelt. Lediglich die Versendung von Gutscheinen, vier Wochen nach Erhalt der Bestellung, findet bereits Anwendung. Dies stellt hierbei jedoch keine adäquate Lösung für eine langfristige Kundenbindung dar. Mithilfe konzentrierter Rabatt- oder Gutscheinaktionen sollen Kunden zukünftig zum richtigen Zeitpunkt zu einem erneuten Kauf der Produkte angeregt werden. Der richtige Zeitpunkt, also der Fälligkeitstag, an dem der Bedarf des Kunden auftritt, soll hierbei unter der Verwendung von KI fortlaufend ermittelt werden. Unter dem Einsatz der KI erhofft sich das Management den Customer Journey nachvollziehen und diesen zukünftig vorhersagen zu können. Die absatzsteigernde Maßnahme basiert dabei auf dem von Daniel Kahnemann und Vernon L. Smith entwickeltem Konzept der deskriptiven Entscheidungstheorie, welche empirisch darstellt, wie Entscheidungen in der Realität getroffen werden. Die deskriptive Entscheidungstheorie definiert dabei Anreize zur richtigen Zeit, um gegenwärtige Bedürfnisse/ Bedarfe zu stillen, als einen zentralen Aspekt in der Entscheidungsfindung eines Entscheidungsträgers.

Das Data Sciences Model Canvas wurde hierbei als Werkzeug zur Strukturierung des Implementierungsprozesses der KI innerhalb des Unternehmens gewählt. Dabei soll das vorliegende Machine-Learning-Problem, unter dessen Verwendung zukünftige Bestelltermine der Kunden vorausgesagt werden sollen, mithilfe des sogenannten „Supervised Learnings“ bearbeitet werden. Übergreifend versucht der Algorithmus eine Hypothese zu finden, die möglichst zielsichere Annahmen trifft, wobei es sich unterkategorisiert um ein Regressionsproblem handelt. Richtig umgesetzt, werden Kunden bereits zum Zeitpunkt, an dem ihr Bedarf auftritt, mithilfe von konzentrierten Rabattaktionen zu einem Kauf angeregt. Dies ermöglicht unter anderem auch die Bindung hybrider Kunden, deren Nachfrageverhalten zwar wechselhaft ist, jedoch latent beeinflusst werden kann. Der Einsatz eines intelligenten After-Sales-Management-Systems ermöglicht somit eine langfristige Markt- und Kundenorientierung.

Interesse geweckt?

Den voll ausgearbeiteten Bericht sowie eine kurze und prägnante Management-Präsentation könnt Ihr euch nachfolgend herunterladen. Der Bericht zeigt, wie man Data Science effektiv innerhalb ein Unternehmen verwenden kann, um Kundenbeziehung zu stärken und Entscheidungen fundierter zu treffen. Des Weiteren präsentiert der Bericht drei weitere potentielle Einsatzmöglichkeiten von AI und wägt dessen Vorteile und Nachteile durch das AI Project Canvas ab.

Management Summary

Mit Kubernetes steht uns eine Technologie zur Verfügung, welche in vielerlei Hinsicht die Bereitstellung und Wartung von Anwendungen und Rechenlasten, insbesondere das Training und Hosten von Machine Learning Modellen, enorm vereinfacht. Gleichzeitig ermöglicht sie uns, die benötigten Hardware-Ressourcen dazu an den Bedarf anzupassen, und bietet damit eine skalierbare und kostentransparente Lösung.

Dieser Artikel behandelt zuerst den Weg vom Server hin zu dem Management und der Orchestrierung von Containern: isolierte Anwendungen oder Modelle, welche mit all ihren Anforderungen einmal verpackt werden und im Anschluss fast überall ausgeführt werden können. Unabhängig vom Server können diese mit Kubernetes beliebig repliziert werden und ermöglichen somit aufwandslos und schier nahtlos eine durchgehende Erreichbarkeit ihrer Dienste auch unter hoher Last. Ebenfalls kann ihre Anzahl bis auf einen Mindeststand reduziert werden, wenn die Nachfrage vorübergehend oder periodisch schwindet, um Rechenressourcen anderweitig zu nutzen oder unnötige Kosten zu vermeiden.

Aus den Möglichkeiten dieser Infrastruktur geht ein nützliches Architektur-Paradigma hervor, die Microservices. Ehemals zentralisierte Anwendungen werden so in ihre Funktionalitäten heruntergebrochen, welche ein hohes Maß an Wiederverwendbarkeit bieten. Diese können von unterschiedlichen Diensten angesprochen und verwendet werden und skalieren einzeln je nach internem Bedarf. Ein Beispiel hierfür sind große und komplexe Sprachmodelle im Natural Language Processing, welche den Kontext eines Textes unabhängig von dessen weiterer Verwendung erfassen können, und damit vielen downstream Zwecken zugrunde liegen. Andere Microservices (Modelle), wie zur Text-Klassifikation oder Zusammenfassung, können diese aufrufen und die Teilergebnisse weiterverarbeiten.

Nach einer kurzen Einführung der allgemeinen Begrifflichkeiten und Funktionsweise von Kubernetes, sowie mögliche Anwendungsfälle, richtet sich das Augenmerk auf die am weitesten verbreitete Form Kubernetes zu nutzen: mit Cloud Anbietern wie Google GCP, Amazon AWS oder Microsoft Azure. Diese erlauben sog. Kubernetes Clustern, dynamisch mehr oder weniger Ressourcen zu beanspruchen, wenn gleich die entstehenden Kosten auf pay-per-use Basis absehbar bleiben. Auch weitere gängige Dienste wie Datenspeicher, Versionierung und Networking können von den Anbietern einfach eingebunden werden. Letztlich gibt der Beitrag noch einen Ausblick über Tools und Weiterentwicklungen, welche entweder die Nutzung von Kubernetes noch effizienter machen oder das Verfahren hin zu Serverless Architekturen weiter abstrahieren und vereinfachen.

Inhalt

Einleitung

Über die letzten 20 Jahre sind Unmengen neuer Technologien in der Softwareentwicklung und -Bereitstellung zu Tage gekommen, welche nicht nur die Auswahl an Diensten, Programmiersprachen, Bibliotheken oder ähnliches vervielfacht und diversifiziert haben, sondern gar auch bei vielen Anwendungsfällen oder -Gebieten bis hin zu einem Paradigmenwechsel geführt haben.

Abb. 1: Google Trends Chart, die den zuvor genannten Paradigmenwechsel verdeutlicht.

Betrachtet man so auch die Art und Weise der Bereitstellung von Softwarelösungen, Modellen oder Rechen- und Arbeitslasten über die Jahre, lässt sich erkennen wie auch in diesem Bereich die Neuerungen u.a. zu mehr Flexibilität, Skalierbarkeit und Ressourceneffizienz geführt haben.

Zu Beginn wurden diese als lokale Prozesse direkt auf einem (von mehreren Anwendungen geteilten) Server betrieben, was einige Einschränkungen und Probleme aufwarf: zum einen ist man bei der Auswahl der technischen Werkzeuge an die Begebenheiten der Server und deren Betriebssystem gebunden, zum anderen sind alle Anwendungen, welche auf dem Server gehostet werden, durch dessen Speicher- und Prozessorkapazitäten begrenzt. Somit teilen sie sich nicht nur in Summe die Ressourcen, sondern auch eine eventuelle Prozess-übergreifende Fehleranfälligkeit.

Als erste Weiterentwicklung können Virtuelle Maschinen daraufhin eine weitere Abstraktionsebene bieten: durch das auf dem Server aufgesetzte Emulieren („Virtualisieren“) einer eigenständigen Maschine entsteht für die Entwicklung und das Deployment Modularität und damit größere Freiheit: zum Beispiel in der Wahl des Betriebssystems oder der verwendeten Programmiersprachen und -Bibliotheken. Aus Sicht des „echten“ Servers können die Ressourcen, welche der Anwendung zustehen sollen, besser beschränkt bzw. garantiert werden. Jedoch sind deren Anforderungen auch bedeutend höher, da die Virtuelle Maschine auch das virtuelle Betriebssystem unterhalten muss.

Letztendlich wurde dieses Prinzip durch die Verbreitung von Containern, vor allem Docker, wesentlich verschlankt und vereinfacht. Vereinfacht gesagt baut/konfiguriert man für eine Anwendung oder ein Machine Learning Modell einen eigenen virtuellen, abgegrenzten Server. So enthält jeder Container sein eigenes Dateisystem und gewisse Systembibliotheken, aber nicht das Betriebssystem. Damit wird er technisch zu einem Sandkasten, dessen andere Konfiguration, Code-Abhängigkeiten oder Fehler sich nicht auf den Host-Server auswirken, aber gleichzeitig als relativ „leichtgewichtige“ Prozesse direkt auf diesem laufen können.

_Vergleich virtuelle Maschine und Docker Container Architektur
Abb. 2: Vergleich zwischen Virtueller Maschine und Docker Container Systemarchitektur, Quelle: https://i1.wp.com/www.docker.com/blog/wp-content/uploads/Blog.-Are-containers-..VM-Image-1-1024×435.png?ssl=1

Es besteht also die Möglichkeit, alles für die gewünschte Anwendung zu kopieren, installieren, usw., und dies in einem verpackten Container überall in einem einheitlichen Format bereitzustellen. Dies ist nicht nur für das Produktionsumfeld extrem nützlich, sondern findet bei STATWORX auch gerne in der Entwicklung von komplizierteren Projekten oder der Proof-of-Concept Phase Gebrauch. Zwischenschritte oder -Ergebnisse, wie beispielsweise die Extraktion von Text aus Bildern, können als Container wie ein kleiner Webserver von denjenigen verwendet werden, die an der Weiterverarbeitung des Textes interessiert sind, etwa die Extraktion gewisser zentraler Informationen, oder die Bestimmung von dessen Stimmung oder Absicht.

Diese Unterteilung in sogenannte „Microservices“ mit Hilfe von Containern hilft ungemein bei der Wiederverwendbarkeit der einzelnen Module, bei der Planung und Entwicklung der Architektur komplexer Systeme; sie befreit gleichzeitig die einzelnen Arbeitsschritte von technischen Abhängigkeiten gegenübereinander und erleichtert die Wartungs- und Update-Prozeduren.

Nach diesem kleinen Überblick über die mächtigen und vielseitigen Möglichkeiten der Bereitstellung von Software wird sich der folgende Text damit beschäftigen, wie man diese Container (sprich Anwendungen oder Modelle) verlässlich und skalierbar für Kunden, andere Anwendungen, interne Dienste oder Berechnungen mit Kubernetes bereitstellen kann.

Kubernetes – 8 wesentliche Komponenten

Kubernetes wurde 2014 von Google als open-source Container-Management Software (auch Container-Orchestrierung genannt) vorgestellt. Intern benutzte man bereits seit Jahren eigens entwickelte Tools, um Arbeitslasten und Anwendungen zu verwalten, und sah in der Entwicklung von Kubernetes nicht nur das Zusammenkommen von best practises und lessons learned, sondern auch die Möglichkeit damit ein neues Geschäftsfeld im Cloud Computing zu erschließen.

Der Name Kubernetes (griechisch für Steuermann) wurde angeblich in Bezug auf ein symbolisches Containerschiff ausgewählt, für dessen optimalen Betrieb jener verantwortlich ist.

1.    Nodes

Spricht man von einer Kubernetes-Instanz, wird sie als (Kubernetes) Cluster bezeichnet: dieses besteht aus mehreren Servern, genannt Nodes. Eine davon, die sogenannte Master-Node, ist komplett für den administrativen Betrieb zuständig, und ist die Schnittstelle, welche vom Entwickler angesprochen wird. Alle weiteren, genannt Worker-Nodes, sind zu Beginn unbelegt und damit flexibel einsetzbar. Während Nodes tatsächlich physische Instanzen sind, meist in Rechenzentren, sind die nun folgenden Begrifflichkeiten Konzepte von Kubernetes.

2.    Pods

Soll eine Anwendung auf dem Cluster bereitgestellt werden, wird im einfachsten Fall der gewünschte Container angegeben, und daraufhin (automatisch) ein sogenannter Pod erstellt und einer Node zugewiesen. Der Pod ähnelt hier einfach einem laufenden Container. Sollen gleich mehrere Instanzen der gleichen Anwendung parallel laufen, etwa um bessere Verfügbarkeit zu bieten, kann die Anzahl der Replicas angegeben werden. Hierbei wird die spezifizierte Anzahl an Pods mit jeweils derselben Anwendung auf die Nodes verteilt. Sollte der Bedarf nach der Anwendung trotz Replicas die Kapazitäten übersteigen, können mit dem Horizontal Autoscaler automatisch noch mehr Pods erstellt werden. Besonders bei Deep Learning Modellen mit verhältnismäßig langer Inferenzzeit können hier Metriken wie CPU- oder GPU-Auslastung überwacht werden, und die Anzahl der Pods vergrößert oder verringert werden, um sowohl Kapazitäten als auch Kosten zu optimieren.

Illustration des Autoscaling und der Belegung der Nodes
Abb. 3: Illustration des Autoscaling und der Belegung der Nodes. Die Breite der Balken entspricht dem Ressourcenbedarf der Pods bzw. der Kapazität der Nodes.

Um nicht zu verwirren: Letztlich ist jeder laufende Container, also jede Arbeitslast, ein Pod. Im Falle der Bereitstellung einer Anwendung geschieht das technisch über ein Deployment, zeitlich begrenzte Rechenlasten sind hingegen Jobs. Persistente Speicher wie Datenbanken werden mit StatefulSets verwaltet. Die folgende Abbildung gibt einen Überblick über die Begriffe:

Deployment-Controller
Abb. 4: Was ist was in Kubernetes? Im Deployment wird angegeben was gewünscht ist; der Deployment-Controller kümmert sich um das Erstellen, den Erhalt und das Skalieren der Modell-Container, welche denn als einzelne Pods auf den Nodes laufen. Jobs und StatefulSets funktionieren analog mit ihrem eigenen Controller.

3.    Jobs

Mit Kubernetes Jobs können sowohl einmalige als auch wiederkehrende Jobs (sog. CronJobs) in Form eines Container-Deployment auf dem Cluster ausgeführt werden.

Im einfachsten Fall können diese wie ein Skript gesehen werden, welches für Wartungs- oder Aufbereitungsarbeiten von beispielsweise Datenbanken genutzt werden kann. Des Weiteren verwendet man diese auch zum Batch-Processing, wenn zum Beispiel Deep Learning Modelle auf größere Datenmengen angewandt werden sollen und es sich aber nicht lohnt das Modell durchgehend auf dem Cluster zu halten. Der Modell-Container wird hier eigens hochgefahren, erhält Zugriff auf das gewünschte Dataset, führt seine Inferenz darüber aus, speichert die Ergebnisse und fährt sich herunter. Auch für die Herkunft und anschließende Speicherung der Daten ist man hier flexibel, so können eigene oder Cloud Datenbanken, Bucket/Objekt-Speicher oder auch lokale Daten und Logging-Frameworks angebunden werden.

Für wiederkehrende CronJobs kann ein einfaches Zeitschema spezifiziert werden, sodass beispielsweise nachts bestimmte Kundendaten, -transaktionen oder ähnliches verarbeitet werden. Mit Natural Language Processing können so zum Beispiel nachts automatisch Pressespiegel erstellt werden, welche am folgenden Morgen ausgewertet bereitstehen: Nachrichten zu einem Unternehmen, dessen Branche, Wirtschaftsstandorte, Kunden, usw. können aggregiert oder bezogen werden, mit NLP ausgewertet, zusammengefasst, und mit Stimmungsbildern präsentiert oder nach Themen/Inhalten geordnet werden.

Auch arbeitsintensive ETL (Extract Transform Load) Prozesse können so außerhalb der Geschäftszeiten durchgeführt oder vorbereitet werden.

4.    Rolling Updates

Soll ein Deployment auf die neuste Version gebracht werden, oder muss ein Rollback auf eine ältere Version vollzogen werden, können in Kubernetes Rolling Updates angestoßen werden. Diese garantieren durchgehende Erreichbarkeit der Anwendungen und Modelle innerhalb einer Continuous Integration/Continuous Deployment Pipeline.

Ein solches Rollout kann reibungslos in einem oder wenigen Schritten angestoßen und überwacht werden. Durch eine Rollout-History besteht auch die Möglichkeit, nicht nur auf eine vorherige Containerversion zurückzuspringen, sondern auch die vorherigen Deployment-Parameter wiederhergestellt werden, sprich Mindest- und Höchstanzahl der Nodes, welche Ressourcengruppe (GPU Nodes, CPU Nodes mit wenig/viel RAM,…), Health-Checks usw.

Wird ein Rolling Update angestoßen, werden die jeweiligen bestehenden Pods so lange am Laufen und erreichbar gehalten, bis dieselbe Anzahl an neuen Pods hochgefahren und zugänglich sind. Hier gibt es sowohl Methoden, um zu garantieren, dass keine Requests verloren gehen, wie auch Parameter, die für den Wechsel eine Mindesterreichbarkeit oder einen maximalen Überschuss an Pods regeln.

Illustration eines Rolling Updates
Abb. 5: Illustration eines Rolling Updates.

Die Abbildung 5 veranschaulicht das Rolling Update.

1) Die bisher aktuelle Version einer Anwendung liegt mit 2 Replicas auf dem Kubernetes Cluster und ist gewohnt ansprechbar.

2) Ein Rolling Update auf Version V2 wird gestartet, dieselbe Anzahl an Pods wie für V1 werden erstellt.

3) Sobald die neuen Pods den Zustand „Running“ haben und ggf. Health-Checks absolviert wurden, damit also funktional sind, werden die Container der älteren Version heruntergefahren.

4) Die älteren Pods sind entfernt und die Ressourcen wieder freigegeben.

Der DevOps- und Zeitaufwand ist hierbei marginal, intern ändern sich keine Hostnamen oder ähnliches, während der Dienst aus Sicht der Konsumierenden wie bisher in gewohnter Weise ansprechbar ist (gleiche IP, URL, …) und lediglich auf die neuste Version gebracht wurde.

5.    Platform/Infrastructure as a Service

Natürlich lässt sich ein Kubernetes Cluster auch lokal auf eigener Hardware on-premises einrichten sowie auf teilweise vorgefertigten Lösungen wie DGX Workbenches.

Einige unserer Kunden haben strikte Richtlinien oder Auflagen bezüglich (Data-) Compliance oder Informationssicherheit, und möchten nicht, dass möglicherweise sensible Daten das Unternehmen verlassen. Weiterhin kann so vermieden werden, dass der Datenverkehr über nicht-europäische Knotenpunkte fließt oder generell in ausländischen Rechenzentren landet.

Erfahrungsgemäß ist dies aber nur in einem sehr geringen Anteil der Fall. Durch Verschlüsselung, Rechtemanagement und SLAs der Betreiber erachten wir die Verwendung von Cloud-Diensten und -Rechenzentren als allgemein sicher und verwenden diese auch für größere Projekte. Diesbezüglich sind auch Deployment, Wartung, CI/CD Pipelines dank Methoden der Containerization (Docker) und Abstraktion (Kubernetes) größtenteils identisch und einfach zu verwenden.

Alle großen Cloud-Betreiber wie Google (GCP), Amazon (AWS) und Microsoft (Azure), aber auch kleinere Anbieter und bald sogar spannende neue deutsche Projekte, bieten sehr ähnliche Kubernetes Dienste an. Dadurch wird es noch einfacher, ein Projekt oder Modell bereitzustellen und vor allem zu skalieren, da durch auto-scaling das Cluster je nach Ressourcenbedarf erweitert oder verkleinert werden kann. Dies entbindet uns aus technischer Sicht größtenteils davon die Nachfrage eines Dienstes abschätzen zu müssen, während die Rentabilität und Kostenstruktur gleichbleiben. Weiterhin können die Dienste auch in unterschiedlichen (geographischen) Zonen gehostet und betrieben werden, um schnellste Erreichbarkeit und Redundanz zu garantieren.

6.    Node-Vielfalt

Die Cloud-Betreiber bieten eine große Anzahl unterschiedlicher Node-Typen an, um für alle Anwendungsfälle vom einfacheren Webservice bis hin zu High Performance Computing alle Ressourcenanforderungen zu befriedigen. Besonders im Anwendungsfeld Deep Learning lassen sich so die immer größer werdenden Modelle stets auf der benötigten neuesten Hardware trainieren und bereitstellen.

Während wir beispielsweise für kleinere NLP Zwecke Nodes mit einer durchschnittlichen CPU und geringem Arbeitsspeicher verwenden, lassen sich große Transformer-Modelle im gleichen Cluster auf GPU-Nodes deployen, was deren Verwendung effektiv erst ermöglicht und gleichzeitig die Inferenz (Anwendung des Modells) um Faktor 20 beschleunigen kann. Da neuerdings die Bedeutung dedizierter Hardware für Neuronale Netze stetig zunimmt, bietet Google auch Zugriff auf die eigens entwickelten, für Tensorflow optimierten TPUs an.

Die Organisation und Gruppierung all dieser unterschiedlichen Nodes erfolgt in Kubernetes in sog. Node Pools. Diese können im Deployment ausgewählt bzw. angegeben werden, sodass den Pods der Modelle die richtigen Ressourcen zugeteilt werden.

7.    Cluster Autoscaling

Das Ausmaß der Nutzung von Modellen oder Diensten, intern oder durch Kunden, ist oftmals nicht absehbar oder schwankt zeitlich stark. Mit einem Cluster Autoscaler können automatisch neue Nodes erstellt werden, oder nicht benötigte „leerstehende“ Nodes entfernt werden. Auch hier kann ein Minimum an Nodes angegeben werden, welche immer bereitstehen sollen sowie, wenn gewünscht, auch eine maximale Anzahl, die nicht überschritten werden kann, um ggf. die Kosten zu deckeln.

8.    Anbindung anderer Dienste

Prinzipiell können Cloud Dienste verschiedener Anbieter kombiniert werden, komfortabler und einfacher ist jedoch die Nutzung eines Anbieters (Beispiel Google GCP). Somit können Dienste wie Datenbuckets, Container-Registry, Lambda Funktionen Cloud-intern ohne große Authentifizierungsprozesse eingebunden und verwendet werden. Des Weiteren ist gerade in einer Microservice-Architektur die Netzwerkkommunikation unter den einzelnen Hosts (Anwendungen, Modelle) wichtig und innerhalb eines Anbieters erleichtert. Hier kann auch Zugangskontrolle/RBAC implementiert werden, sowie mehrere Cluster oder Projekte mit einem Virtuellen Netzwerk überbrückt werden, um die Zuständigkeits- und Kompetenzbereiche besser zu trennen.

Umfeld und zukünftige Entwicklungen

Die steigende Nutzung und Verbreitung von Kubernetes haben ein ganzes Umfeld an nützlichen Tools, wie auch Weiterentwicklungen und weitere Abstraktionen mit sich gebracht, welche dessen Verwendung weiter erleichtern.

Tools und Pipelines basierend auf Kubernetes

Mit Kubeflow lässt sich beispielsweise das Training von Machine Learning Modellen als TensorFlow Training Job anstoßen und fertige Modelle mit TensorFlow Serving bereitstellen.

Der ganze Prozess kann auch in eine Pipeline verpackt werden, welche dann mit Verweis auf Trainings-, Validation- und Testdaten in Speicherbuckets das Training verschiedener Modelle durchführt, überwacht, deren Metriken loggt und die Modell-Performance vergleicht. Der Workflow beinhaltet auch die Aufbereitung der Inputdaten, sodass nach erstmaligem Aufbau der Pipeline einfach Experimente zur Exploration von Modellarchitekturen und Hyperparameter-Tuning angestellt werden können

Serverless

Durch Serverless Deployment Verfahren wie Cloud Run oder Amazon Fargate wird ein weiterer Abstraktionsschritt weg von den technischen Anforderungen unternommen. Hiermit können Container binnen Sekunden deployed werden, und skalieren wie Pods auf einem Kubernetes Cluster, ohne dass man dieses überhaupt erstellen oder warten muss. Dieselbe Infrastruktur wurde also noch einmal in ihrer Benutzung vereinfacht. Nach dem Prinzip pay-per-use wird nur die Zeit berechnet, in welcher der Code wirklich aufgerufen und ausgeführt wird.

Fazit

Kubernetes ist heute zu einer zentralen Säule im Machine Learning Deployment geworden. Der Weg von der Daten- und Modellexploration zum Prototyp und schließlich in die Produktion ist durch Bibliotheken wie PyTorch, TensorFlow und Keras zum einen enorm verschlankt und vereinfacht worden. Gleichzeitig können diese Methoden bei Bedarf aber auch enorm detailliert verwendet werden, um maßgeschneiderte Komponenten zu entwickeln oder mittels Transfer Learning bestehende Modelle einzubinden und anzupassen. Container Technologien wie Docker erlauben im Anschluss, das Ergebnis mit all dessen Anforderungen und Abhängigkeiten zu bündeln und ohne weiteren Aufwand fast überall blitzschnell auszuführen. Im letzten Schritt ist deren Bereitstellung, Wartung und Skalierung mit Kubernetes ebenfalls ungemein vereinfacht und leistungsfähig geworden.

All dies erlaubt uns eigene Produkte sowie Lösungen für Kunden strukturiert zu entwickeln:

  • Die Komponenten und die Rahmeninfrastruktur haben eine hohe Wiederverwendbarkeit
  • Mit verhältnismäßig geringem Zeit- und Kostenaufwand kann ein erster Meilenstein oder Proof-of-Concept erreicht werden
  • Die weiterführende Entwicklungsarbeit folgt auf natürliche Weise weiter diesem Prozess
  • Fertige Deployments skalieren ohne zusätzlichen Aufwand, mit Kosten proportional zum Bedarf
  • Daraus folgt eine verlässliche Plattform mit planbarer Kostenstruktur

Wenn Sie sich im Anschluss an diesen Artikel weiter über einige zentrale Komponenten informieren möchten, haben wir hier noch einige interessante Beiträge über:

Quellen

 

Jonas Braun

Management Summary

Machine Learning Projekte zu deployen und zu überwachen ist ein komplexes Vorhaben. Neben dem konsequenten Dokumentieren von Modellparametern und den dazugehörigen Evaluationsmetriken, besteht die Herausforderung vor allem darin, das gewünschte Modell in eine Produktivumgebung zu überführen. Sofern mehrere Personen an der Entwicklung beteiligt sind, ergeben sich zusätzlich Synchronisationsprobleme in Bezug auf die Entwicklungsumgebungen und Versionsstände der Modelle. Aus diesem Grund werden Tools zum effizienten Management von Modellergebnissen bis hin zu umfangreichen Trainings- und Inferenzpipelines benötigt.

In diesem Artikel werden die typischen Herausforderungen entlang des Machine Learning Workflows dargestellt und mit MLflow eine mögliche Lösungsplattform beschrieben. Zusätzlich stellen wir drei verschiedene Szenarien dar, mit deren Hilfe sich Machine Learning Workflows professionalisieren lassen:

  1. Einsteigervariante:
    Modellparameter und Performance-Metriken werden über eine R/Python API geloggt und in einer GUI übersichtlich dargestellt. Zusätzlich werden die trainierten Modelle als Artefakt abgespeichert und können über APIs bereitgestellt werden.
  2. Fortgeschrittenes Modellmanagement:
    Neben dem Tracking von Parametern und Metriken werden bestimmte Modelle geloggt und versioniert. Dies ermöglicht ein kontrolliertes Monitoring und vereinfacht das Deployment von ausgewählten Modellversionen.
  3. Kollaboratives Workflowmanagement:
    Das Abkapseln von Machine Learning Projekten als Pakete oder Git Repositories und der damit einhergehenden lokalen Reproduzierbarkeit von Entwicklungsumgebungen, ermöglichen eine reibungslose Entwicklung von Machine Learning Projekten mit mehreren Beteiligten.

Je nach Reifegrad Ihres Machine Learning Projektes können die drei Szenarien als Inspiration für einen potenziellen Machine Learning Workflow dienen. Zum besseren Verständnis haben wir jedes Szenario detailliert ausgearbeitet und geben Empfehlungen hinsichtlich der zu verwendeten APIs und Deployment-Umgebungen.

Herausforderungen entlang des Machine Learning Workflows

Das Training von Machine Learning Modellen wird immer einfacher. Mittlerweile ermöglichen eine Vielzahl von Open Source Tools eine effiziente Datenaufbereitung sowie ein immer einfacheres Modelltraining und Deployment.

Der Mehrwert für Unternehmen entsteht vor allem durch das systematische Zusammenspiel von Modelltraining, in Form von Modellidentifikation, Hyperparametertuning und Fitting auf den Trainingsdaten, und Deployment, also dem Bereitstellen des Modells zur Berechnung von Vorhersagen. Insbesondere in frühen Phasen der Entwicklung von Machine Learning Initiativen wird dieses Zusammenspiel häufig nicht als kontinuierlicher Prozess etabliert. Ein Modell kann jedoch nur dann langfristig Mehrwerte generieren, wenn ein stabiler Produktionsprozess vom Modelltraining, über dessen Validierung bis hin zum Test und Deployment implementiert wird. Sofern dieser Prozess korrekt implementiert wird können bei der operativen Inbetriebnahme des Modells komplexe Abhängigkeiten und langfristig kostspielige Wartungsarbeiten entstehen [2]. Die folgenden Risiken sind hierbei besonders hervorzuheben

1. Gewährleistung von Synchronität

Häufig werden im explorativen Kontext Datenaufbereitungs- und Modellierungs-Workflows lokal entwickelt. Unterschiedliche Konfigurationen der Entwicklungsumgebungen oder gar der Einsatz von verschiedenen Technologien erschweren eine Reproduktion von Ergebnissen, insbesondere zwischen Entwickler*innen bzw. Teams. Zusätzlich ergeben sich potenzielle Gefahren hinsichtlich der Kompatibilität des Workflows, sofern mehrere Skripte in einer logischen Reihenfolge exekutiert werden müssen. Ohne einer entsprechenden Versionskontroll-Logik kann der Synchronisationsaufwand im Nachhinein nur mit großem Aufwand gewährleistet werden.

2. Aufwand der Dokumentation

Um die Performance des Modells zu bewerten, werden häufig im Anschluss an das Training Modellmetriken berechnet. Diese hängen von verschiedenen Faktoren ab, wie z.B. der Parametrisierung des Modells oder den verwendeten Einflussfaktoren. Diese Metainformationen über das Modell werden häufig nicht zentral gespeichert. Zur systematischen Weiterentwicklung und Verbesserung eines Modells ist es jedoch zwingend erforderlich, eine Übersicht über die Parametrisierung und Performance aller vergangenen Trainingsläufe zu haben.

3. Heterogenität von Modellformaten

Neben der Verwaltung von Modellparametern und Ergebnissen besteht die Herausforderung das Modell anschließend in die Produktionsumgebung zu überführen. Sofern verschiedene Modelle aus mehreren Paketen zum Training verwendet werden kann das Deployment aufgrund unterschiedlicher Pakete und Versionen schnell umständlich und fehleranfällig werden.

4. Wiederherstellung alter Ergebnisse

In einem typischen Machine Learning Projekt ergibt sich häufig die Situation, dass ein Modell über einen langen Zeitraum entwickelt wird. Beispielsweise können neue Features verwendet oder auch gänzlich neue Architekturen evaluiert werden. Nicht zwangsläufig führen diese Experimente zu besseren Ergebnissen. Sofern Experimente nicht sauber versioniert werden, besteht die Gefahr alte Ergebnisse nicht mehr nachbilden zu können.

Um diese und weitere Herausforderungen im Umgang und Management von Machine Learning Workflows zu lösen, wurden in den vergangenen Jahren verschiedene Tools entwickelt, wie beispielsweise TensorFlow TFX, cortex, Marvin oder MLFlow. Insbesondere letzteres ist aktuell eine der am häufigsten verwendeten Lösungen.

MLflow ist ein Open Source Projekt mit dem Ziel, das Beste aus existierenden ML Plattformen zu vereinen, um die Integration zu bestehenden ML Bibliotheken, Algorithmen und Deployment Tools so unkompliziert wie möglich zu gestalten [3]. Im Folgenden werden die wesentlichen MLflow Module vorgestellt und Möglichkeiten erörtert, mit der Machine Learning Workflows über MLflow abgebildet werden können.

MLflow Services

MLflow besteht aus vier Komponenten: MLflow Tracking, MLflow Models, MLflow Projectsund MLflow Registry. Je nach Anforderung an das Experimental- und Deployment-Szenario können alle Services gemeinsam genutzt, oder auch einzelne Komponenten isoliert werden.

Mit MLflowTracking lassen sich alle Hyperparameter, Metriken (Modell-Performance) und Artefakte, wie bspw. Charts, loggen. MLflow Tracking bietet die Möglichkeit, für jeden Trainings- oder Scoring-Lauf eines Modells Voreinstellungen, Parameter und Ergebnisse für ein kollektives Monitoring zu sammeln. Die geloggten Ergebnisse lassen sich in einer GUI visualisieren oder alternativ über eine REST API ansprechen.

Das Modul MLflow Models fungiert als Schnittstelle zwischen Technologien und ermöglicht ein vereinfachtes Deployment. Ein Modell wird je nach Typ als Binary, z.B, als reine Python-Funktion oder als Keras-, oder H2O-Modell gespeichert. Man spricht hierbei von den sogenannten model flavors. Weiterhin stellt MLflow Models eine Unterstützung zur Modellbereitstellung auf verschiedenen Machine Learning Cloud Services bereit, z.B. für AzureML und Amazon Sagemaker.

MLflow Projects dienen dazu, einzelne ML-Projekte in einem Paket oder Git-Repository abzukapseln. Die Basiskonfigurationen des jeweiligen Environments werden über eine YAML-Datei festgelegt. Über diese kann z.B. gesteuert werden, wie genau das conda-Environment parametrisiert ist, das im Falle einer Ausführung von MLflow erstellt wird. Durch MLflow Projects können Experimente, die lokal entwickelt wurden, auf anderen Rechnern in der gleichen Umgebung ausgeführt werden. Dies ist bspw. bei der Entwicklung in kleineren Teams von Vorteil.

Ein zentralisiertes Modellmanagement bietet MLflow Registry. Ausgewählte MLflow Models können darin registriert und versioniert werden. Ein Staging-Workflow ermöglicht ein kontrolliertes Überführen von Modellen in die Produktivumgebung. Der gesamte Prozess lässt sich wiederum über eine GUI oder eine REST API steuern.

Beispiele für Machine Learning Pipelines mit MLflow

Im Folgenden werden mit Hilfe der o.g. MLflow Module drei verschiedene ML Workflow-Szenarien dargestellt. Diese steigern sich von Szenario zu Szenario hinsichtlich der Komplexität. In allen Szenarien wird ein Datensatz mittels eines Python Skripts in eine Entwicklungsumgebung geladen, verarbeitet und ein Machine Learning Modell trainiert. Der letzte Schritt stellt in allen Szenarien ein Deployment des ML Modells in eine beispielhafte Produktivumgebung dar.

1. Szenario – Die Einsteigervariante

Szenario 1 – Simple Metrics TrackingSzenario 1 – Simple Metrics Tracking

Szenario 1 bedient sich der Module MLflow Tracking und MLflow Models. Hierbei können mittels der Python API die Modellparameter und Metriken der einzelnen Runs auf dem MLflow Tracking Server Backend Store gespeichert und das entsprechende MLflow Model File als Artefakt auf dem MLflow Tracking Server Artifact Store abgelegt werden. Jeder Run wird hierbei einem Experiment zugeordnet. Beispielsweise könnte ein Experiment ‚fraud_classification‘ lauten und ein Run wäre ein bestimmtes ML Modell mit einer Hyperparameterkonfiguration und den entsprechenden Metriken. Jeder Run wird zur eindeutigen Zuordnung mit einer einzigartigen RunID abgespeichert.

Artikel MLFlow Tool Bild 01

Im Screenshot wird die MLflow Tracking UI beispielhaft nach der Ausführung eines Modelltrainings dargestellt. Der Server wird im Beispiel lokal gehostet. Selbstverständlich besteht auch die Möglichkeit den Server Remote, beispielsweise in einem Docker Container, innerhalb einer VM zu hosten. Neben den Parametern und Modellmetriken werden zudem der Zeitpunkt des Modelltrainings sowie der User und der Name des zugrundeliegenden Skripts geloggt. Klickt man auf einen bestimmten Run werden zudem weitere Informationen dargestellt, wie beispielsweise die RunID und die Modelltrainingsdauer.

Artikel MLFlow Tool Bild 02

Sofern man neben den Metriken zusätzlich noch weitere Artefakte, wie bspw. das Modell, geloggt hat, wird das MLflow Model Artifact ebenfalls in der Run-Ansicht dargestellt. In dem Beispiel wurde ein Modell aus dem sklearn.svm Package verwendet. Das File MLmodel enthält Metadaten mit Informationen über die Art und Weise, wie das Modell geladen werden soll. Zusätzlich dazu wird ein conda.yaml erstellt, das alle Paketabhängigkeiten des Environments zum Trainingszeitpunkt enthält. Das Modell selbst befindet sich als serialisierte Version unter model.pklund enthält die auf den Trainingsdaten optimierten Modellparameter.

Artikel MLFlow Tool Bild 03

Das Deployment des trainierten Modells kann nun auf mehrere Weisen erfolgen. Möchte man beispielsweise das Modell mit der besten Accuracy Metrik deployen, kann der MLflow Tracking Server über die Python API mlflow.list_run_infos angesteuert werden, um so die RunID des gesuchten Modells zu identifizieren. Nun kann der Pfad zu dem gewünschten Artefakt zusammengesetzt werden und das Modell bspw. über das Python Paket pickle geladen werden. Dieser Workflow kann nun über ein Dockerfile getriggert werden, was ein flexibles Deployment in die Infrastruktur Ihrer Wahl ermöglicht. MLFlow bietet für das Deployment auf Microsoft Azure und AWS zusätzliche gesonderte APIs an. Sofern das Modell bspw. auf AzureML deployed werden soll, kann ein Azure ML Container Image mit der Python API mlflow.azureml.build_image erstellt werden, welches als Webservice nach Azure Container Instances oder Azure Kubernetes Service deployed werden kann. Neben dem MLflow Tracking Server besteht auch die Möglichkeit andere Ablagesysteme für das Artefakt zu verwenden, wie zum Beispiel Amazon S3, Azure Blob Storage, Google Cloud Storage, SFTP Server, NFS und HDFS.

2. Szenario – Fortgeschrittenes Modellmanagement

Szenario 2 – Advanced Model ManagementSzenario 2 – Advanced Model Management

Szenario 2 beinhaltet, neben den in Szenario 1 verwendeten Modulen, zusätzlich MLflow Model Registry als Modelmanagementkomponente. Hierbei besteht die Möglichkeit, aus bestimmten Runs die dort geloggten Modelle zu registrieren und zu verarbeiten. Diese Schritte können über die API oder GUI gesteuert werden. Eine Grundvoraussetzung, um die Model Registry zu nutzen, ist eine Bereitstellung des MLflow Tracking Server Backend Store als Database Backend Store. Um ein Modell über die GUI zu registrieren, wählt man einen bestimmten Run aus und scrollt in die Artefakt Übersicht.

Artikel MLFlow Tool Bild 04

Mit einem Klick auf Register Model öffnet sich ein neues Fenster, in dem ein Modell registriert werden kann. Sofern man eine neue Version eines bereits existierenden Modells registrieren möchte, wählt man das gesuchte Modell aus dem Dropdown Feld aus. Ansonsten kann jederzeit ein neues Modell angelegt werden. Nach dem Klick auf den Button Register erscheint in dem Reiter Models das zuvor registrierte Modell mit einer entsprechenden Versionierung.

Artikel MLFlow Tool Bild 05

Jedes Modell beinhaltet eine Übersichtsseite, bei der alle vergangenen Versionen dargestellt werden. Dies ist bspw. nützlich, um nachzuvollziehen, welche Modelle wann in Produktion waren.

Artikel MLFlow Tool Bild 06

Wählt man nun eine Modellversion aus, gelangt man auf eine Übersicht, bei der beispielsweise eine Modellbeschreibung angefügt werden kann. Ebenso gelangt man über den Link Source Run zu dem Run, aus dem das Modell registriert worden ist. Hier befindet sich auch das dazugehörige Artefakt, das später zum Deployment verwendet werden kann.

Artikel MLFlow Tool Bild 07

Zusätzlich können einzelne Modellversionen in dem Bereich Stage in festgelegte Phasen kategorisiert werden. Dieses Feature kann beispielsweise dazu genutzt werden, um festzulegen, welches Modell gerade in der Produktion verwendet wird oder dahin überführt werden soll. Für das Deployment kann, im Gegensatz zu Szenario 1, die Versionierung und der Staging-Status dazu verwendet werden, um das geeignete Modell identifizieren und zu deployen. Hierzu kann z.B. die Python API MlflowClient().search_model_versions verwendet werden, um das gewünschte Modell und die dazugehörige RunID zu filtern. Ähnlich wie in Szenario 1 kann dann das Deployment beispielsweise nach AWS Sagemaker oder AzureML über die jeweiligen Python APIs vollzogen werden.

3. Szenario – Kollaboratives Workflowmanagement

Szenario 3 – Full Workflow ManagementSzenario 3 – Full Workflow Management

Das Szenario 3 beinhaltet, neben denen in Szenario 2 verwendeten Modulen, zusätzlich noch das Modul MLflow Projects. Wie bereits erläutert, eignen sich MLflow Projects besonders gut für kollaborative Arbeiten. Jedes Git Repository oder jede lokale Umgebung kann hierbei als Projekt fungieren und mittels eines MLproject File gesteuert werden. Hierbei können Paketabhängigkeiten in einem conda.yaml festgehalten und beim Starten des Projekts auf das MLproject File zugegriffen werden. Anschließend wird die entsprechende conda Umgebung mit allen Abhängigkeiten vor dem Training und Logging des Modells erstellt. Dies verhindert den Bedarf eines manuellen Angleichens der Entwicklungsumgebungen aller beteiligten Entwickler*innen und garantiert zudem standardisierte und vergleichbare Ergebnisse aller Runs. Insbesondere letzteres ist erforderlich im Deployment Kontext, da allgemein nicht garantiert werden kann, dass unterschiedliche Package-Versionen dieselben Modellartefakte produzieren. Anstelle einer conda Umgebung kann auch eine Docker Umgebung mittels eines Dockerfiles definiert werden. Dies bietet den Vorteil, dass auch von Python unabhängige Paketabhängigkeiten festgelegt werden können. Ebenso ermöglichen MLflow Projects durch die Anwendung unterschiedlicher commit hashes oder branch names das Verwenden verschiedener Projektstände, sofern ein Git Repository verwendet wird.

Ein interessanter Use Case hierbei ist die modularisierte Entwicklung von Machine Learning Trainingspipelines [4]. Hierbei kann bspw. die Datenaufbereitung vom Modelltraining entkoppelt und parallel weiterentwickelt werden, während parallel ein anderes Team einen unterschiedlichen branch name verwendet, um das Modell zu trainieren. Hierbei muss lediglich beim Starten des Projektes im MLflow Projects File ein unterschiedlicher branch name als Parameter verwendet werden. Die finale Datenaufbereitung kann im Anschluss auf denselben branch name gepusht werden, der zum Modelltraining verwendet wird und wäre somit bereits vollständig in der Trainingspipeline implementiert. Das Deployment kann ebenfalls als Teilmodul innerhalb der Projektpipeline mittels eines Python Skripts über das ML Project File gesteuert werden und analog zu Szenario 1 oder 2 auf eine Plattform Ihrer Wahl erfolgen.

Fazit und Ausblick

MLflow bietet eine flexible Möglichkeit den Machine Learning Workflow robust gegen die typischen Herausforderungen im Alltag eines Data Scientists zu gestalten, wie beispielsweise Synchronisationsprobleme aufgrund unterschiedlicher Entwicklungsumgebungen oder fehlendes Modellmanagement. Je nach Reifegrad des bestehenden Machine Learning Workflows können verschiedene Services aus dem MLflow Portfolio verwendet werden, um eine höhere Professionalisierungsstufe zu erreichen.

Im Artikel wurden drei, in der Komplexität aufsteigende, Machine Learning Workflows exemplarisch dargestellt. Vom einfachen Logging der Ergebnisse in einer interaktiven UI, bis hin zu komplexeren, modularen Modellierungspipelines können MLflow Services unterstützen. Logischerweise ergeben sich auch außerhalb des MLflow Ökosystems Synergien mit anderen Tools, wie zum Beispiel Docker/Kubernetes zur Modellskalierung oder auch Jenkins zur Steuerung der CI/CD Pipeline. Sofern noch weiteres Interesse an MLOps Herausforderungen und Best Practices besteht verweise ich auf das von uns kostenfrei zur Verfügung gestellte Webinar zu MLOps von unserem CEO Sebastian Heinz.

Quellen

John Vicente John Vicente John Vicente

In den letzten drei Beiträgen dieser Serie haben wir erklärt, wie man ein Deep-Learning-Modell trainiert, um ein Auto anhand seiner Marke und seines Modells zu klassifizieren, basierend auf einem Bild des Autos (Teil 1), wie man dieses Modell aus einem Docker-Container mit TensorFlow Serving einsetzt (Teil 2) und wie man die Vorhersagen des Modells erklärt (Teil 3). In diesem Beitrag lernt ihr, wie ihr mit Dash eine ansprechende Oberfläche um unseren Auto-Modell-Classifier herum bauen könnt.

Wir werden unsere Machine Learning-Vorhersagen und -Erklärungen in ein lustiges und spannendes Spiel verwandeln. Wir präsentieren den Anwender*innen zunächst ein Bild von einem Auto. Die Anwender*innen müssen erraten, um welches Automodell und welche Marke es sich handelt – das Machine-Learning-Modell wird das Gleiche tun. Nach 5 Runden wird ausgewertet, wer die Automarke besser vorhersagen kann: die Anwender*innen oder das Modell.

Inhalt

Das Tech Stack: Was ist Dash?

Dash ist, wie der Name schon sagt, eine Software zum Erstellen von Dashboards in Python. In Python, fragen ihr euch? Ja – ihr müsst nichts direkt in HTML oder Javascript programmieren (obwohl ein grundlegendes Verständnis von HTML sicherlich hilfreich ist). Eine hervorragende Einführung findet ihr in dem ausgezeichneten Blogpost meines Kollegen Alexander Blaufuss.

Um das Layout und Styling unserer Web-App zu vereinfachen, verwenden wir auch Dash Bootstrap Components. Sie folgen weitgehend der gleichen Syntax wie die standardmäßigen Dash-Komponenten und fügen sich nahtlos in das Dash-Erlebnis ein.

Denkt daran, dass Dash für Dashboards gemacht ist – das heißt, es ist für Interaktivität gemacht, aber nicht unbedingt für Apps mit mehreren Seiten. Mit dieser Info im Hinterkopf werden wir in diesem Artikel Dash an seine Grenzen bringen.

Organisation ist alles – Die Projektstruktur

Um alles nachbauen zu können, solltet ihr euch unser GitHub-Repository ansehen, auf dem alle Dateien verfügbar sind. Außerdem könnt ihr alle Docker-Container mit einem Klick starten und loslegen.

Die Dateien für das Frontend selbst sind logischerweise in mehrere Teile aufgeteilt. Es ist zwar möglich, alles in eine Datei zu schreiben, aber man verliert leicht den Überblick und daher später schwer zu pflegen. Die Dateien folgen der Struktur des Artikels:

  1. In einer Datei wird das gesamte Layout definiert. Jeder Button, jede Überschrift, jeder Text wird dort gesetzt.
  2. In einer anderen Datei wird die gesamte Dashboard-Logik (sogenannte Callbacks) definiert. Dort wird z. B. definiert, was passieren soll, nachdem die Benutzer*innen auf eine Schaltfläche geklickt hat.
  3. Wir brauchen ein Modul, das 5 zufällige Bilder auswählt und die Kommunikation mit der Prediction and Explainable API übernimmt.
  4. Abschließend gibt es noch zwei Dateien, die die Haupteinstiegspunkte (Entry Points) zum Starten der App sind.

Erstellen der Einstiegspunkte – Das große Ganze

Beginnen wir mit dem letzten Teil, dem Haupteinstiegspunkt für unser Dashboard. Wenn ihr wisst, wie man eine Web-App schreibt, wie z. B. eine Dash-Anwendung oder auch eine Flask-App, ist euch das Konzept einer App-Instanz vertraut. Vereinfacht ausgedrückt, ist die App-Instanz alles. Sie enthält die Konfiguration für die App und schließlich das gesamte Layout. In unserem Fall initialisieren wir die App-Instanz direkt mit den Bootstrap-CSS-Dateien, um das Styling überschaubarer zu machen. Im gleichen Schritt exponieren wir die zugrundeliegende Flask-App. Die Flask-App wird verwendet, um das Frontend in einer produktiven Umgebung zu bedienen.

# app.py
import dash
import dash_bootstrap_components as dbc

# ...

# Initialize Dash App with Bootstrap CSS
app = dash.Dash(
    __name__,
    external_stylesheets=[dbc.themes.BOOTSTRAP],
)

# Underlying Flask App for productive deployment
server = app.server

Diese Einstellung wird für jede Dash-Anwendung verwendet. Im Gegensatz zu einem Dashboard benötigen wir eine Möglichkeit, mit mehreren URL-Pfaden umzugehen. Genauer gesagt, wenn die Benutzer*innen /attempt eingibt, wollen wir ihm erlauben, ein Auto zu erraten; wenn er /result eingibt, wollen wir das Ergebnis seiner Vorhersage anzeigen.

Zunächst definieren wir das Layout. Bemerkenswert ist, dass es zunächst grundsätzlich leer ist. Ihr findet dort eine spezielle Dash Core Component. Diese Komponente dient dazu, die aktuelle URL dort zu speichern und funktioniert in beide Richtungen. Mit einem Callback können wir den Inhalt auslesen, herausfinden, welche Seite die Benutzer*innen aufrufen möchte, und das Layout entsprechend rendern. Wir können auch den Inhalt dieser Komponente manipulieren, was praktisch eine Weiterleitung auf eine andere Seite ist. Das leere div wird als Platzhalter für das eigentliche Layout verwendet.

# launch_dashboard.py
import dash_bootstrap_components as dbc
import dash_core_components as dcc
import dash_html_components as html
from app import app

# ...

# Set Layout
app.layout = dbc.Container(
    [dcc.Location(id='url', refresh=False),
     html.Div(id='main-page')])

Die Magie geschieht in der folgenden Funktion. Die Funktion selbst hat ein Argument, den aktuellen Pfad als String. Basierend auf dieser Eingabe gibt sie das richtige Layout zurück. Wenn die Benutzer*innen zum Beispiel zum ersten Mal auf die Seite zugreift, ist der Pfad / und das Layout daher start_page. Auf das Layout werden wir gleich noch im Detail eingehen; beachtet zunächst, dass wir an jedes Layout immer eine Instanz der App selbst und den aktuellen Spielzustand übergeben.

Damit diese Funktion tatsächlich funktioniert, müssen wir sie mit dem Callback Decorator schmücken. Jeder Callback benötigt mindestens eine Eingabe und mindestens eine Ausgabe. Eine Änderung des Inputs löst die Funktion aus. Der Eingang ist einfach die oben definierte Ortskomponente mit der Eigenschaft Pathname. Einfach ausgedrückt, aus welchem Grund auch immer sich der Pfad ändert, wird diese Funktion ausgelöst. Die Ausgabe ist das neue Layout, gerendert in dem zuvor zunächst leeren div.

# launch_dashboard.py
import dash_html_components as html
from dash.dependencies import Input, Output
from dash.exceptions import PreventUpdate

# ...

@app.callback(Output('main-page', 'children'), [Input('url', 'pathname')])
def display_page(pathname: str) -> html:
    """Function to define the routing. Mapping routes to layout.

    Arguments:
        pathname {str} -- pathname from url/browser

    Raises:
        PreventUpdate: Unknown/Invalid route, do nothing

    Returns:
        html -- layout
    """
    if pathname == '/attempt':
        return main_layout(app, game_data, attempt(app, game_data))

    elif pathname == '/result':
        return main_layout(app, game_data, result(app, game_data))

    elif pathname == '/finish':
        return main_layout(app, game_data, finish_page(app, game_data))

    elif pathname == '/':
        return main_layout(app, game_data, start_page(app, game_data))

    else:
        raise PreventUpdate

Layout – Schön & Shiny

Beginnen wir mit dem Layout unserer App – wie soll sie aussehen? Wir haben uns für ein relativ einfaches Aussehen entschieden. Wie ihr in der Animation oben sehen könnt, besteht die App aus drei Teilen: dem Header, dem Hauptcontent und dem Footer. Der Header und der Footer sind auf jeder Seite gleich, nur der Hauptinhalt ändert sich. Einige Layouts aus dem Hauptcontent sind in der Regel eher schwierig zu erstellen. Zum Beispiel besteht die Ergebnisseite aus vier Boxen. Die Boxen sollten immer die gleiche Breite von genau der Hälfte der verwendeten Bildschirmgröße haben, können aber je nach Bildgröße in der Höhe variieren. Sie dürfen sich aber nicht überlappen, usw. Von den Cross-Browser-Inkompatibilitäten ganz zu schweigen.

Ihr könnt euch sicher vorstellen, dass wir leicht mehrere Arbeitstage damit hätten verbringen können, das optimale Layout zu finden. Glücklicherweise können wir uns wieder einmal auf Bootstrap und das Bootstrap Grid System verlassen. Die Hauptidee ist, dass ihr so viele Zeilen wie ihr wollt (zwei, im Fall der Ergebnisseite) und bis zu 12 Spalten pro Zeile (ebenfalls zwei für die Ergebnisseite) erstellen könnt. Die Begrenzung auf 12 Spalten basiert auf der Tatsache, dass Bootstrap die Seite intern in 12 gleich große Spalten aufteilt. Ihr müsst nur mit einer einfachen CSS-Klasse definieren, wie groß die Spalte sein soll. Und was noch viel cooler ist: Ihr könnt mehrere Layouts einstellen, je nach Bildschirmgröße. Es wäre also nicht schwierig, unsere App vollständig responsive zu machen.

Um auf den Dash-Teil zurückzukommen, bauen wir eine Funktion für jedes unabhängige Layout-Teil. Den Header, den Footer und eine für jede URL, die die Benutzer*innen aufrufen könnte. Für den Header sieht das so aus:

# layout.py
import dash_bootstrap_components as dbc
import dash_html_components as html

# ...

def get_header(app: dash.Dash, data: GameData) -> html:
    """Layout for the header

    Arguments:
        app {dash.Dash} -- dash app instance
        data {GameData} -- game data

    Returns:
        html -- html layout
    """
    logo = app.get_asset_url("logo.png")

    score_user, score_ai = count_score(data)

    header = dbc.Container(
        dbc.Navbar(
            [
                html.A(
                    # Use row and col to control vertical alignment of logo / brand
                    dbc.Row(
                        [
                            dbc.Col(html.Img(src=logo, height="40px")),
                            dbc.Col(
                                dbc.NavbarBrand("Beat the AI - Car Edition",
                                                className="ml-2")),
                        ],
                        align="center",
                        no_gutters=True,
                    ),
                    href="/",
                ),
                # You find the score counter here; Left out for clarity
            ],
            color=COLOR_STATWORX,
            dark=True,
        ),
        className='mb-4 mt-4 navbar-custom')

    return header

Auch hier seht ihr, dass wir die App-Instanz und den globalen Spieldatenstatus an die Layout-Funktion übergeben. In einer perfekten Welt müssten wir mit keiner dieser Variablen im Layout herumspielen. Leider ist das eine der Einschränkungen von Dash. Eine perfekte Trennung von Layout und Logik ist nicht möglich. Die App-Instanz wird benötigt, um dem Webserver mitzuteilen, dass er das STATWORX-Logo als statische Datei ausliefern soll.

Natürlich könnte man das Logo von einem externen Server ausliefern, das machen wir ja auch für die Fahrzeugbilder, aber nur für ein Logo wäre das ein bisschen zu viel des Guten. Für die Spieldaten müssen wir den aktuellen Punktestand des Benutzers und der KI berechnen. Alles andere ist entweder normales HTML oder Bootstrap-Komponenten. Wer sich damit nicht auskennt, den kann ich noch einmal auf den Blogpost von meinem Kollegen Alexander verweisen oder auf eines der zahlreichen HTML-Tutorials im Internet.

Callbacks – Reaktivität einführen

Wie bereits erwähnt, sind Callbacks das Mittel der Wahl, um das Layout interaktiv zu gestalten. In unserem Fall bestehen sie hauptsächlich aus der Handhabung des Dropdowns sowie der Button Klicks. Während die Dropdowns relativ einfach zu programmieren waren, bereiteten uns die Buttons einige Kopfschmerzen.

Einem guten Programmierstandard folgend, sollte jede Funktion genau eine Verantwortung haben. Deshalb haben wir für jeden Button einen Callback eingerichtet. Nach einer Art Eingabevalidierung und Datenmanipulation ist das Ziel, die Benutzer*innen auf die folgende Seite umzuleiten. Während die Eingabe für den Callback das Button-Klick-Ereignis und möglicherweise einige andere Eingabeformulare ist, ist die Ausgabe immer die Location-Komponente, um die Benutzer*innen weiterzuleiten. Leider erlaubt Dash nicht, mehr als einen Callback zum gleichen Ausgang zu haben. Daher waren wir gezwungen, die Logik für jede Schaltfläche in eine Funktion zu quetschen.

Da wir die Benutzereingaben auf der Versuchsseite validieren mussten, haben wir die aktuellen Werte aus dem Dropdown an den Callback übergeben. Während das für die Versuchsseite einwandfrei funktionierte, funktionierte die Schaltfläche auf der Ergebnisseite nicht mehr, da keine Dropdowns zur Übergabe an die Funktion verfügbar waren. Wir mussten ein verstecktes, nicht funktionierendes Dummy-Dropdown in die Ergebnisseite einfügen, damit die Schaltfläche wieder funktionierte. Das ist zwar eine Lösung und funktioniert in unserem Fall einwandfrei, aber für eine umfangreichere Anwendung könnte es zu kompliziert sein.

Data Download – Wir brauchen Autos

Jetzt haben wir eine schöne App mit funktionierenden Buttons und so weiter, aber die Daten fehlen noch. Wir müssen Bilder, Vorhersagen und Erklärungen in die App einbinden.

Die High-Level-Idee ist, dass jede Komponente für sich alleine läuft – zum Beispiel in einem eigenen Docker-Container mit eigenem Webserver. Alles ist nur lose über APIs miteinander gekoppelt. Der Ablauf ist der folgende:

  • Schritt 1: Abfrage einer Liste aller verfügbaren Auto-Images. Wähle zufällig 5 aus und fordere diese Bilder vom Webserver an.
  • Schritt 2: Sende für alle 5 Bilder eine Anfrage an die Vorhersage-API und parse das Ergebnis aus der API.
  • Schritt 3: Sende wiederum für alle 5 Bilder eine Anfrage an die Explainable-API und speichere das zurückgegebene Bild.

Kombiniert nun jede Ausgabe in der GameData-Klasse.

Aktuell speichern wir die GameData-Instanz als globale Variable. Das erlaubt uns, von überall darauf zuzugreifen. Das ist zwar theoretisch eine schlaue Idee, funktioniert aber nicht, wenn mehr als eine Benutzerin versucht, auf die App zuzugreifen. Derdie zweite Benutzerin wird den Spielstatus vom ersten sehen. Da wir planen, das Spiel auf Messen auf einer großen Leinwand zu zeigen, ist das für den Moment in Ordnung. In Zukunft könnten wir das Dashboard mit Shiny Proxy starten, so dass jeder Benutzer seinen eigenen Docker-Container mit einem isolierten globalen Status erhält.

Data Storage – Die Autos parken

Die native Dash-Methode besteht darin, benutzerspezifische Zustände in einer Store-Komponente zu speichern. Das ist im Grunde dasselbe wie die oben erläuterte Location-Komponente. Die Daten werden im Webbrowser gespeichert, ein Callback wird ausgelöst, und die Daten werden an den Server gesendet. Der erste Nachteil ist, dass wir bei jedem Seitenwechsel die gesamte Spieldateninstanz vom Browser zum Server übertragen müssen. Das kann ziemlich viel Traffic verursachen und verlangsamt das gesamte App-Erlebnis.

Außerdem müssen wir, wenn wir den Spielzustand ändern wollen, dies über einen Callback tun. Die Beschränkung auf einen Callback pro Ausgabe gilt auch hier. Unserer Meinung nach macht es nicht allzu viel aus, wenn Sie ein klassisches Dashboard haben; dafür ist Dash gedacht. Die Verantwortlichkeiten sind getrennt. In unserem Fall wird der Spielstatus von mehreren Komponenten aus aufgerufen und verändert. Wir haben Dash definitiv an seine Grenzen gebracht.

Eine weitere Sache, die ihr im Auge behalten solltet, wenn ihr euch entscheidet, eure eigene Microservice-App zu bauen, ist die Performance der API-Aufrufe. Anfänglich haben wir die berühmte requests Bibliothek verwendet. Während wir große Fans dieser Bibliothek sind, sind alle Anfragen blockierend. Daher wird die zweite Anfrage ausgeführt, sobald die erste abgeschlossen ist. Da unsere Anfragen relativ langsam sind (bedenkt, dass im Hintergrund vollwertige neuronale Netze laufen), verbringt die App viel Zeit mit Warten. Wir haben asynchrone Aufrufe mit Hilfe der Bibliothek aiohttp implementiert. Alle Anfragen werden nun parallel verschickt. Die App verbringt weniger Zeit mit Warten, und der Benutzer ist früher bereit zum Spielen.

Fazit und Hinweise

Auch wenn die Web-App einwandfrei funktioniert, gibt es ein paar Dinge, die zu beachten sind. Wir haben Dash verwendet, wohl wissend, dass es als Dashboarding-Tool gedacht ist. Wir haben es bis an die Grenzen und darüber hinaus getrieben, was zu einigen suboptimalen interessanten Design-Entscheidungen führte.

Zum Beispiel könnt ihr nur einen Callback pro Ausgabeparameter setzen. Mehrere Callbacks für dieselbe Ausgabe sind derzeit nicht möglich. Da das Routing von einer Seite zur anderen im Wesentlichen eine Änderung des Ausgabeparameters (‚url‘, ‚pathname‘) ist, muss jeder Seitenwechsel durch einen Callback geleitet werden. Das erhöht die Komplexität des Codes exponentiell.

Ein weiteres Problem ist die Schwierigkeit, Zustände über mehrere Seiten hinweg zu speichern. Dash bietet mit der Store Component die Möglichkeit, Benutzerdaten im Frontend zu speichern. Das ist eine hervorragende Lösung für kleine Apps; bei größeren steht man schnell vor dem gleichen Problem wie oben – ein Callback, eine Funktion zum Schreiben in den Store, reicht einfach nicht aus. Entweder ihr nutzt den globalen Zustand von Python, was schwierig wird, wenn mehrere Benutzer gleichzeitig auf die Seite zugreifen, oder ihr bindet einen cache ein.

In unserer Blogserie haben wir Ihnen gezeigt, wie ihr den gesamten Lebenszyklus eines Data-Science-Projekts durchlauft, von der Datenexploration über das Modelltraining bis hin zur Bereitstellung und Visualisierung. Dies ist der letzte Artikel dieser Serie, und wir hoffen, ihr habt beim Erstellen der Anwendung genauso viel gelernt wie wir.

Um das Durchblättern der vier Artikel zu erleichtern, sind hier die direkten Links:

  1. Transfer Learning mit ResNet
  2. Deployment von TensorFlow-Modellen in Docker mit TensorFlow Serving
  3. Erklärbarkeit von Deep Learning Modellen mit Grad-CAM

Im ersten Artikel dieser Serie über die Klassifizierung von Automodellen haben wir ein Modell gebaut, das Transfer Learning verwendet, um das Automodell durch ein Bild eines Autos zu klassifizieren. Im zweiten Beitrag haben wir gezeigt, wie TensorFlow Serving verwendet werden kann, um ein TensorFlow-Modell am Beispiel des Automodell-Classifiers einzusetzen. Diesen dritten Beitrag widmen wir einem weiteren wesentlichen Aspekt von Deep Learning und maschinellem Lernen im Allgemeinen: der Erklärbarkeit von Modellvorhersagen (englisch: Explainable AI).

Wir beginnen mit einer kurzen allgemeinen Einführung in das Thema Erklärbarkeit beim maschinellen Lernen. Als nächstes werden wir kurz auf verbreitete Methoden eingehen, die zur Erklärung und Interpretation von CNN-Vorhersagen verwendet werden können. Anschließend werden wir Grad-CAM, eine gradientenbasierte Methode, ausführlich erklären, indem wir Schritt für Schritt eine Implementierung des Verfahrens durchgehen. Zum Schluss zeigen wir Ergebnisse, die wir mit unserer Grad-CAM-Implementierung für den Auto-Modell-Classifier berechnet haben.

Inhalt

Eine kurze Einführung in die Erklärbarkeit von Machine Learning Modellen

In den letzten Jahren war die Erklärbarkeit ein immer wiederkehrendes Thema – aber dennoch ein Nischenthema – im Machine Learning. In den letzten vier Jahren jedoch hat das Interesse an diesem Thema stark zugenommen. Stark dazu beigetragen hat unter anderem die steigende Anzahl von Machine Learning-Modellen in der Produktion. Einerseits führt dies zu einer wachsenden Zahl von Endnutzern, die verstehen müssen, wie die Modelle Entscheidungen treffen. Andererseits müssen immer mehr Entwickler*innen von Machine Learning verstehen, warum (oder warum nicht) ein Modell auf eine bestimmte Weise funktioniert.

Dieser steigende Bedarf an Erklärbarkeit führte in den letzten Jahren zu einigen sowohl methodisch als auch technisch bemerkenswerten Innovationen:

Methoden zur Erklärung von CNN-Outputs für Bilddaten

Deep Neural Networks (DNNs) und insbesondere komplexe Architekturen wie CNNs galten lange Zeit als reine Blackbox-Modelle. Wie oben beschrieben änderte sich dies in den letzten Jahren, und inzwischen gibt es verschiedene Methoden, um CNN-Outputs zu erklären. Zum Beispiel implementiert die hervorragende Bibliothek tf-explain eine breite Palette nützlicher Methoden für TensorFlow 2.x. Wir werden nun kurz auf die Ideen der verschiedenen Ansätze eingehen, bevor wir uns Grad-CAM zuwenden:

Activations Visualization

Activations Visualization ist die einfachste Visualisierungstechnik. Hierbei wird die Ausgabe einer bestimmten Layer innerhalb des Netzwerks während des Vorwärtsdurchlaufs ausgegeben. Diese kann hilfreich sein, um ein Gefühl für die extrahierten Features zu bekommen, da die meisten Activations während des Trainings gegen Null tendieren (bei Verwendung der ReLu-Activation). Ein Beispiel für die Ausgabe der ersten Faltungsschicht des Auto-Modell-Classifiers ist unten dargestellt:

Activations Beispielbild

Vanilla Gradients

Man kann die Vanilla-Gradients der Ausgabe der vorhergesagten Klassen für das Eingangsbild verwenden, um die Bedeutung der Eingangspixel abzuleiten.

Vanilla Gradients Beispielbild

Wir sehen hier, dass der hervorgehobene Bereich hauptsächlich auf das Auto fokussiert ist. Im Vergleich zu den unten besprochenen Methoden ist der diskriminierende Bereich viel weniger eingegrenzt.

Occlusion Sensitivity

Bei diesem Ansatz wird die Signifikanz bestimmter Teile des Eingangsbildes berechnet, indem die Vorhersage des Modells für verschiedene ausgeblendete Teile des Eingangsbildes bewertet wird. Teile des Bildes werden iterativ ausgeblendet, indem sie durch graue Pixel ersetzt werden. Je schwächer die Vorhersage wird, wenn ein Teil des Bildes ausgeblendet ist, desto wichtiger ist dieser Teil für die endgültige Vorhersage. Basierend auf der Unterscheidungskraft der Bildregionen kann eine Heatmap erstellt und dargestellt werden. Die Anwendung der Occlusion Sensitivity für unseren Auto-Modell-Classifier hat keine aussagekräftigen Ergebnisse geliefert. Daher zeigen wir das Beispielbild von tf-explain, welches das Ergebnis der Anwendung des Verfahrens der Occlusion Sensitivity für ein Katzenbild zeigt.

Occlusion Sensitivity Beispielbild

CNN Fixations

Ein weiterer interessanter Ansatz namens CNN Fixations wurde in diesem Paper vorgestellt . Die Idee dabei ist, zurück zu verfolgen, welche Neuronen in jeder Schicht signifikant waren, indem man die Activations aus der Vorwärtsrechnung und die Netzwerkgewichte betrachtet. Die Neuronen mit großem Einfluss werden als Fixations bezeichnet. Dieser Ansatz erlaubt es also, die wesentlichen Regionen für das Ergebnis zu finden, ohne wiederholte Modellvorhersagen berechnen zu müssen (wie dies z.B. für die oben erklärte Occlusion Sensitivity der Fall ist).

Das Verfahren kann wie folgt beschrieben werden: Der Knoten, der der Klasse entspricht, wird als Fixation in der Ausgabeschicht gewählt. Dann werden die Fixations für die vorherige Schicht bestimmt, indem berechnet wird, welche der Knoten den größten Einfluss auf die Fixations der nächsthöheren Ebene haben, die im letzten Schritt bestimmt wurden. Die Knotengewichtung wird durch Multiplikation von Activations und Netzwerk-Gewichten errechnet. Wenn ihr an den Details des Verfahrens interessiert seid, schaut euch das Paper oder das entsprechende Github Repo an. Dieses Backtracking wird so lange durchgeführt, bis das Eingabebild erreicht ist, was eine Menge von Pixeln mit beträchtlicher Unterscheidungskraft ergibt. Ein Beispiel aus dem Paper ist unten dargestellt.

CNN Fixations Beispielbild

CAM

Das in diesem Paper vorgestellte Class Activation Mapping (CAM) ist ein Verfahren, um die diskriminante(n) Region(en) für eine CNN-Vorhersage durch die Berechnung von sogenannten Class Activation Maps zu finden. Ein wesentlicher Nachteil dieses Verfahrens ist, dass das Netzwerk als letzten Schritt vor der Vorhersageschicht ein Global Average Pooling (GAP) verwenden muss. Es ist daher nicht möglich, diesen Ansatz für allgemeine CNN-Architekturen anzuwenden. Ein Beispiel ist in der folgenden Abbildung dargestellt (entnommen aus dem CAM paper):

CAM Beispielbild

Die Class Activation Map weist jeder Position (x, y) in der letzten Faltungsschicht eine Bedeutung zu, indem sie die Linearkombination der Activations – gewichtet mit den entsprechenden Ausgangsgewichten für die beobachtete Klasse (im obigen Beispiel „Australian Terrier“) – berechnet. Die resultierende Class Activation Mapping wird dann auf die Größe des Eingabebildes hochgerechnet. Dies wird durch die oben dargestellte Heatmap veranschaulicht. Aufgrund der Architektur von CNNs ist die Aktivierung, z. B. oben links für eine beliebige Schicht, direkt mit der oberen linken Seite des Eingabebildes verbunden. Deshalb können wir nur aus der Betrachtung der letzten CNN-Schicht schließen, welche Eingabebereiche wichtig sind.

Bei dem Grad-CAM-Verfahren, das wir unten im Detail besprechen werden, handelt es sich um eine Verallgemeinerung von CAM. Grad-CAM kann auf Netzwerke mit allgemeinen CNN-Architekturen angewendet werden, die mehrere fully connected Layers am Ausgang enthalten.

Grad-CAM

Grad-CAM erweitert die Anwendbarkeit des CAM-Verfahrens durch das Einbeziehen von Gradienteninformationen. Konkret bestimmt der Gradient der Loss-Funktion in Bezug auf die letzte Faltungsschicht das Gewicht für jede der entsprechenden Feature Maps. Wie beim obigen CAM-Verfahren bestehen die weiteren Schritte in der Berechnung der gewichteten Summe der Aktivierungen und dem anschließenden Upsampling des Ergebnisses auf die Bildgröße, um das Originalbild mit der erhaltenen Heatmap darzustellen. Wir werden nun den Code, der zur Ausführung von Grad-CAM verwendet werden kann, zeigen und diskutieren. Der vollständige Code ist hier auf GitHub verfügbar.

import pickle
import tensorflow as tf
import cv2
from car_classifier.modeling import TransferModel

INPUT_SHAPE = (224, 224, 3)

# Load list of targets
file = open('.../classes.pickle', 'rb')
classes = pickle.load(file)

# Load model
model = TransferModel('ResNet', INPUT_SHAPE, classes=classes)
model.load('...')

# Gradient model, takes the original input and outputs tuple with:
# - output of conv layer (in this case: conv5_block3_3_conv)
# - output of head layer (original output)
grad_model = tf.keras.models.Model([model.model.inputs],
                                   [model.model.get_layer('conv5_block3_3_conv').output,
                                    model.model.output])

# Run model and record outputs, loss, and gradients
with tf.GradientTape() as tape:
    conv_outputs, predictions = grad_model(img)
    loss = predictions[:, label_idx]

# Output of conv layer
output = conv_outputs[0]

# Gradients of loss w.r.t. conv layer
grads = tape.gradient(loss, conv_outputs)[0]

# Guided Backprop (elimination of negative values)
gate_f = tf.cast(output > 0, 'float32')
gate_r = tf.cast(grads > 0, 'float32')
guided_grads = gate_f * gate_r * grads

# Average weight of filters
weights = tf.reduce_mean(guided_grads, axis=(0, 1))

# Class activation map (cam)
# Multiply output values of conv filters (feature maps) with gradient weights
cam = np.zeros(output.shape[0: 2], dtype=np.float32)
for i, w in enumerate(weights):
    cam += w * output[:, :, i]

# Or more elegant: 
# cam = tf.reduce_sum(output * weights, axis=2)

# Rescale to org image size and min-max scale
cam = cv2.resize(cam.numpy(), (224, 224))
cam = np.maximum(cam, 0)
heatmap = (cam - cam.min()) / (cam.max() - cam.min())

Detailbetrachtung des Codes

  • Der erste Schritt besteht darin, eine Instanz des Modells zu laden.
  • Dann erstellen wir eine neue keras.Model-Instanz, die zwei Ausgaben hat: Die Aktivierungen der letzten CNN-Schicht ('conv5_block3_3_conv') und die ursprüngliche Modellausgabe.
  • Als nächstes führen wir eine Vorwärtsrechnung für unser neues grad_model aus, wobei wir als Eingabe ein Bild ( img) der Form (1, 224, 224, 3) verwenden, das mit der Methode resnetv2.preprocess_input vorverarbeitet wurde. Zur Aufzeichnung der Gradienten wird tf.GradientTape angelegt und angewendet (die Gradienten werden hierbei im tapeObjekt gespeichert). Weiterhin werden die Ausgaben der Faltungsschicht (conv_outputs) und des heads (predictions) gespeichert. Schließlich können wir label_idx verwenden, um den Verlust zu erhalten, der dem Label entspricht, für das wir die diskriminierenden Regionen finden wollen.
  • Mit Hilfe der gradient-Methode kann man die gewünschten Gradienten aus tape extrahieren. In diesem Fall benötigen wir den Gradienten des Verlustes in Bezug auf die Ausgabe der Faltungsschicht.
  • In einem weiteren Schritt wird eine guided Backprop angewendet. Dabei werden nur Werte für die Gradienten behalten, bei denen sowohl die Aktivierungen als auch die Gradienten positiv sind. Dies bedeutet im Wesentlichen, dass die Aufmerksamkeit auf die Aktivierungen beschränkt wird, die positiv zu der gewünschten Ausgabevorhersage beitragen.
  • Die weights werden durch Mittelung der erhaltenen geführten Gradienten für jeden Filter berechnet.
  • Die Class Activation Map cam wird dann als gewichteter Durchschnitt der Aktivierungen der Feature Map (output) berechnet. Die Methode mit der obigen for-Schleife hilft zu verstehen, was die Funktion im Detail tut. Eine weniger einfache, aber effizientere Art, die CAM-Berechnung zu implementieren, ist die Verwendung von tf.reduce_mean und wird in der kommentierten Zeile unterhalb der Schleifenimplementierung gezeigt.
  • Schließlich wird das Resampling (Größenänderung) mit der resize-Methode von OpenCV2 durchgeführt, und die Heatmap wird so skaliert, dass sie Werte in [0, 1] enthält, um sie zu plotten.

Eine Version von Grad-CAM ist auch in tf-explain implementiert.

Beispiele für den Auto-Modell-Classifier

Wir verwenden nun die Grad-CAM-Implementierung, um die Vorhersagen des TransferModel für die Klassifizierung von Automodellen zu interpretieren und zu erklären. Wir beginnen mit der Betrachtung von Fahrzeugbildern, die von vorne aufgenommen wurden.

Grad-CAM für Fahrzeugaufnahmen von der Vorderseite
Grad-CAM für Fahrzeugaufnahmen von der Vorderseite

Die roten Regionen markieren die wichtigsten diskriminierenden Regionen, die blauen Regionen die unwichtigsten. Wir können sehen, dass sich das CNN bei Bildern von vorne auf den Kühlergrill des Autos und den Bereich des Logos konzentriert. Ist das Auto leicht gekippt, verschiebt sich der Fokus mehr auf den Rand des Fahrzeugs. Dies ist auch bei leicht gekippten Bildern von der Rückseite des Fahrzeugs der Fall, wie im mittleren Bild unten gezeigt.

Grad-CAM für Fahrzeugaufnahmen von der Rückseite
Grad-CAM für Fahrzeugaufnahmen von der Rückseite

Bei Bildern von der Rückseite des Autos liegt der wichtigste Unterscheidungsbereich in der Nähe des Nummernschilds. Wie bereits erwähnt, hat bei Autos, die aus einem Winkel betrachtet werden, die nächstgelegene Ecke die höchste Trennschärfe. Ein sehr interessantes Beispiel ist die Mercedes-Benz C-Klasse auf der rechten Seite, bei der sich das Modell nicht nur auf die Rückleuchten konzentriert, sondern auch die höchste Trennschärfe auf den Modellschriftzug legt.

Grad-CAM für Fahrzeugaufnahmen von der Seite
Grad-CAM für Fahrzeugaufnahmen von der Seite

Wenn wir Bilder von der Seite betrachten, stellen wir fest, dass die diskriminierende Region auf die untere Hälfte der Autos beschränkt ist. Auch hier bestimmt der Winkel, aus dem das Fahrzeugbild aufgenommen wurde, die Verschiebung der Region in Richtung der vorderen oder hinteren Ecke.

Im Allgemeinen ist die wichtigste Tatsache, dass die diskriminierenden Bereiche immer auf Teile der Autos beschränkt sind. Es gibt keine Bilder, bei denen der Hintergrund eine hohe Unterscheidungskraft hat. Die Betrachtung der Heatmaps und der zugehörigen diskriminierenden Regionen kann als Sanity-Check für CNN-Modelle verwendet werden.

Fazit

Wir haben mehrere Ansätze zur Erklärung von CNN-Klassifikatorausgaben diskutiert. Wir haben Grad-CAM im Detail vorgestellt, indem wir den Code untersucht und uns Beispiele für den Auto-Modell-Classifier angeschaut haben. Am auffälligsten ist, dass die durch das Grad-CAM-Verfahren hervorgehobenen diskriminierenden Regionen immer auf das Auto fokussiert sind und nie auf die Hintergründe der Bilder. Das Ergebnis zeigt, dass das Modell so funktioniert, wie wir es erwarten und spezifische Teile des Autos zur Unterscheidung zwischen verschiedenen Modellen verwendet werden.

Im vierten und letzten Teil dieser Blog-Serie werden wir zeigen, wie der Car Classifier mit Dash in eine Web-Anwendung eingebaut werden kann. Bis bald!

Im ersten Beitrag dieser Serie haben wir Transfer Learning im Detail besprochen und ein Modell zur Klassifizierung von Automodellen erstellt. In diesem Beitrag werden wir das Problem der Modellbereitstellung am Beispiel des im ersten Beitrags vorgestellten TransferModel diskutieren.

Ein Modell ist in der Praxis nutzlos, wenn es keine einfache Möglichkeit gibt, damit zu interagieren. Mit anderen Worten: Wir brauchen eine API für unsere Modelle. TensorFlow Serving wurde entwickelt, um diese Funktionalitäten für TensorFlow-Modelle bereitzustellen. In diesem Beitrag zeigen wir, wie ein TensorFlow Serving Server in einem Docker-Container gestartet werden kann und wie wir mit dem Server über HTTP-Anfragen interagieren können.

Wenn ihr noch nie mit Docker gearbeitet habt, empfehlen wir, dieses Tutorial von Docker durchzuarbeiten, bevor ihr diesen Artikel lest. Wenn ihr euch ein Beispiel für das Deployment in Docker ansehen möchtet, empfehlen wir euch, diesen Blogbeitrag von unserem Kollegen Oliver Guggenbühl zu lesen, in dem beschrieben wird, wie ein R-Skript in Docker ausgeführt werden kann.

Inhalt

Einführung in TensorFlow Serving

Zum Einstieg geben wir euch zunächst einen Überblick über TensorFlow Serving.

TensorFlow Serving ist das Serving-System von TensorFlow, das entwickelt wurde, um das Deployment von verschiedenen Modellen mit einer einheitlichen API zu ermöglichen. Unter Verwendung der Abstraktion von Servables, die im Grunde Objekte sind, mit denen Inferenz durchgeführt werden kann, ist es möglich, mehrere Versionen von deployten Modellen zu serven. Das ermöglicht zum Beispiel, dass eine neue Version eines Modells hochgeladen werden kann, während die vorherige Version noch für Kunden verfügbar ist. Im Großen und Ganzen sind sogenannte Manager für die Verwaltung des Lebenszyklus von Servables verantwortlich, d. h. für das Laden, Bereitstellen und Löschen.

In diesem Beitrag werden wir zeigen, wie eine einzelne Modellversion deployed werden kann. Die unten aufgeführten Code-Beispiele zeigen, wie ein Server in einem Docker-Container gestartet werden kann und wie die Predict API verwendet werden kann, um mit dem Modell zu interagieren. Um mehr über TensorFlow Serving zu erfahren, verweisen wir auf die TensorFlow-Website.

Implementierung

Wir werden nun die folgenden drei Schritte besprechen, die erforderlich sind, um das Modell einzusetzen und Requests zu senden.

  • Speichern eines Modells im richtigen Format und in der richtigen Ordnerstruktur mit TensorFlow SavedModel
  • Ausführen eines Serving-Servers innerhalb eines Docker-Containers
  • Interaktion mit dem Modell über REST Requests

Speichern von TensorFlow-Modellen

Für diejenigen, die den ersten Beitrag dieser Serie nicht gelesen haben, folgt nun eine kurze Zusammenfassung der wichtigsten Punkte, die zum Verständnis des nachfolgenden Codes notwendig sind:

Das TransferModel.model (unten im Code auch self.model) ist eine tf.keras.Model Instanz, also kann es mit der eingebauten save Methode gespeichert werden. Da das Modell auf im Internet gescrapten Daten trainiert wurde, können sich die Klassenbezeichnungen beim erneuten Scraping der Daten ändern. Wir speichern daher die Index-Klassen-Zuordnung beim Speichern des Modells in classes.pickle ab. TensorFlow Serving erfordert, dass das Modell im SavedModel Format gespeichert wird. Wenn Sie tf.keras.Model.save verwenden, muss der Pfad ein Ordnername sein, sonst wird das Modell in einem anderen, inkompatiblen Format (z.B. HDF5) gespeichert. Im Code unten enthält folderpath den Pfad des Ordners, in dem wir alle modellrelevanten Informationen speichern wollen. Das SavedModel wird unter folderpath/model gespeichert und das Class Mapping wird als folderpath/classes.pickle gespeichert.

def save(self, folderpath: str):
    """
    Save the model using tf.keras.model.save

    Args:
        folderpath: (Full) Path to folder where model should be stored
    """

    # Make sure folderpath ends on slash, else fix
    if not folderpath.endswith("/"):
        folderpath += "/"

    if self.model is not None:
        os.mkdir(folderpath)
        model_path = folderpath + "model"
        # Save model to model dir
        self.model.save(filepath=model_path)
        # Save associated class mapping
        class_df = pd.DataFrame({'classes': self.classes})
        class_df.to_pickle(folderpath + "classes.pickle")
    else:
        raise AttributeError('Model does not exist')

TensorFlow Serving im Docker Container starten

Nachdem wir das Modell auf der Festplatte gespeichert haben, müssen wir nun den TensorFlow Serving Server starten. Am schnellsten deployen kann man TensorFlow Serving mithilfe eines Docker-Containers. Der erste Schritt ist daher das Ziehen des TensorFlow Serving Images von DockerHub. Das kann im Terminal mit dem Befehl docker pull tensorflow/serving gemacht werden.

Dann können wir den unten stehenden Code verwenden, um einen TensorFlow Serving Container zu starten. Er führt den Shell-Befehl zum Starten eines Containers aus. Die in docker_run_cmd gesetzten Optionen sind die folgenden:

  • Das Serving-Image exponiert Port 8501 für die REST-API, die wir später zum Senden von Anfragen verwenden werden. Wir mappen mithilfe der -p– Flag also den Host-Port 8501 auf den Port 8501 des Containers.
  • Als nächstes binden wir unser Modell mit -v in den Container ein. Es ist wichtig, dass das Modell in einem versionierten Ordner gespeichert ist (hier MODEL_VERSION=1); andernfalls wird das Serving-Image das Modell nicht finden. Der model_path_guest muss also die Form <path>/<model name>/MODEL_VERSION haben, wobei MODEL_VERSION eine ganze Zahl ist.
  • Mit -e können wir die Umgebungsvariable MODEL_NAME setzen, die den Namen unseres Modells enthält.
  • Die Option --name tf_serving wird nur benötigt, um unserem neuen Docker-Container einen bestimmten Namen zuzuweisen.

Wenn wir versuchen, diese Datei zweimal hintereinander auszuführen, wird der Docker-Befehl beim zweiten Mal nicht ausgeführt, da bereits ein Container mit dem Namen tf_serving existiert. Um dieses Problem zu vermeiden, verwenden wir docker_run_cmd_cond. Hier prüfen wir zunächst, ob ein Container mit diesem spezifischen Namen bereits existiert und läuft. Wenn ja, lassen wir ihn gleich; wenn nicht, prüfen wir, ob eine beendete Version des Containers existiert. Wenn ja, wird diese gelöscht und ein neuer Container gestartet; wenn nicht, wird direkt ein neuer Container erstellt.

import os

MODEL_FOLDER = 'models'
MODEL_SAVED_NAME = 'resnet_unfreeze_all_filtered.tf'
MODEL_NAME = 'resnet_unfreeze_all_filtered'
MODEL_VERSION = '1'

# Define paths on host and guest system
model_path_host = os.path.join(os.getcwd(), MODEL_FOLDER, MODEL_SAVED_NAME, 'model')
model_path_guest = os.path.join('/models', MODEL_NAME, MODEL_VERSION)

# Container start command
docker_run_cmd = f'docker run ' 
                 f'-p 8501:8501 ' 
                 f'-v {model_path_host}:{model_path_guest} ' 
                 f'-e MODEL_NAME={MODEL_NAME} ' 
                 f'-d ' 
                 f'--name tf_serving ' 
                 f'tensorflow/serving'

# If container is not running, create a new instance and run it
docker_run_cmd_cond = f'if [ ! "(docker ps -q -f name=tf_serving)" ]; then n'                        f'   if [ "(docker ps -aq -f status=exited -f name=tf_serving)" ]; then 														n' 
                      f'   		docker rm tf_serving n' 
                      f'   fi n' 
                      f'   {docker_run_cmd} n' 
                      f'fi'

# Start container
os.system(docker_run_cmd_cond)

Anstatt das Modell von unserer lokalen Festplatte zu mounten, indem wir das -v-Flag im Docker-Befehl verwenden, könnten wir das Modell auch in das Docker-Image kopieren, so dass das Modell einfach durch das Ausführen eines Containers und die Angabe der Port-Zuweisungen bedient werden könnte. Es ist wichtig zu beachten, dass in diesem Fall das Modell mit der Ordnerstruktur Ordnerpfad/<Modellname>/1 gespeichert werden muss, wie oben erklärt. Wenn dies nicht der Fall ist, wird der TensorFlow Serving Container das Modell nicht finden. Wir werden hier nicht weiter auf diesen Fall eingehen. Wenn ihr daran interessiert seid, eure Modelle auf diese Weise zu deployen, verweisen wir auf diese Anleitung auf der TensorFlow Webseite.

REST Request

Da das Modell nun geserved ist und bereit zur Verwendung ist, brauchen wir einen Weg, um damit zu interagieren. TensorFlow Serving bietet zwei Optionen, um Anfragen an den Server zu senden: gRCP und REST API, welche beide an unterschiedlichen Ports verfügbar sind. Im folgenden Codebeispiel werden wir REST verwenden, um das Modell abzufragen.

Zuerst laden wir ein Bild von der Festplatte, für das wir eine Vorhersage machen wollen. Dies kann mit dem image Modul von TensorFlow gemacht werden. Als nächstes konvertieren wir das Bild in ein Numpy-Array mittels der img_to_array-Methode. Der nächste und letzte Schritt ist entscheidend für unseren Car Classifier Use Case: da wir das Trainingsbild vorverarbeitet haben, bevor wir unser Modell trainiert haben (z.B. Normalisierung), müssen wir die gleiche Transformation auf das Bild anwenden, das wir vorhersagen wollen. Die praktische Funktion „preprocess_input“ sorgt dafür, dass alle notwendigen Transformationen auf unser Bild angewendet werden.

from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications.resnet_v2 import preprocess_input

# Load image
img = image.load_img(path, target_size=(224, 224))
img = image.img_to_array(img)

# Preprocess and reshape data
img = preprocess_input(img)
img = img.reshape(-1, *img.shape)

Die RESTful API von TensorFlow Serving bietet mehrere Endpunkte. Im Allgemeinen akzeptiert die API Post-Requests der folgenden Struktur:

POST http://host:port/<URI>:<VERB>

URI: /v1/models/{MODEL_NAME}[/versions/{MODEL_VERSION}]
VERB: classify|regress|predict

Für unser Modell können wir die folgende URL für Vorhersagen verwenden: http://localhost:8501/v1/models/resnet_unfreeze_all_filtered:predict

Die Portnummer (hier 8501) ist der Port des Hosts, den wir oben angegeben haben, um ihn auf den Port 8501 des Serving-Images abzubilden. Wie oben erwähnt, ist 8501 der Port des Serving-Containers, der für die REST-API verwendet wird. Die Modellversion ist optional und wird standardmäßig auf die neueste Version gesetzt, wenn sie weggelassen wird.

In Python kann die Bibliothek requests verwendet werden, um HTTP-Anfragen zu senden. Wie in der Dokumentation angegeben, muss der Request-Body für die predict API ein JSON-Objekt mit den unten aufgeführten Schlüssel-Wert-Paaren sein:

  • signature_name – zu verwendende Signatur (weitere Informationen finden Sie in der Dokumentation)
  • instances – Modelleingabe im Zeilenformat
import json
import requests

# Send image as list to TF serving via json dump
request_url = 'http://localhost:8501/v1/models/resnet_unfreeze_all_filtered:predict'
request_body = json.dumps({"signature_name": "serving_default", "instances": img.tolist()})
request_headers = {"content-type": "application/json"}
json_response = requests.post(request_url, data=request_body, headers=request_headers)
response_body = json.loads(json_response.text)
predictions = response_body['predictions']

# Get label from prediction
y_hat_idx = np.argmax(predictions)
y_hat = classes[y_hat_idx]

Der Response-Body ist ebenfalls ein JSON-Objekt mit einem einzigen Schlüssel namens predictions. Da wir für jede Zeile in den Instanzen die Wahrscheinlichkeit für alle 300 Klassen erhalten, verwenden wir np.argmax, um die wahrscheinlichste Klasse zurückzugeben. Alternativ hätten wir auch die übergeordnete classify-API verwenden können.

Fazit

In diesem zweiten Blog-Artikel der Serie „Car Model Classification“ haben wir gelernt, wie ein TensorFlow-Modell zur Bilderkennung mittels TensorFlow Serving als RestAPI bereitgestellt werden kann, und wie damit Modellabfragen ausgeführt werden können.

Dazu haben wir zuerst das Modell im SavedModel Format abgespeichert. Als nächstes haben wir den TensorFlow Serving-Server in einem Docker-Container gestartet. Schließlich haben wir gezeigt, wie man Vorhersagen aus dem Modell mit Hilfe der API-Endpunkte und einem korrekt spezifizierten Request Body anfordert.

Ein Hauptkritikpunkt an Deep Learning Modellen jeglicher Art ist die fehlende Erklärbarkeit der Vorhersagen. Im dritten Beitrag werden wir zeigen, wie man Modellvorhersagen mit einer Methode namens Grad-CAM erklären kann.

Deep Learning ist eines der Themen im Bereich der künstlichen Intelligenz, die uns bei STATWORX besonders faszinieren. In dieser Blogserie möchten wir veranschaulichen, wie ein End-to-end Deep Learning Projekt implementiert werden kann. Dabei verwenden wir die TensorFlow 2.x Bibliothek für die Implementierung.

Die Themen der 4-teiligen Blogserie umfassen:

  • Transfer Learning für Computer Vision
  • Deployment über TensorFlow Serving
  • Interpretierbarkeit von Deep-Learning-Modellen mittels Grad-CAM
  • Integration des Modells in ein Dashboard

Im ersten Teil zeigen wir, wie man Transfer Learning nutzen kann, um die Marke eines Autos mittels Bildklassifizierung vorherzusagen. Wir beginnen mit einem kurzen Überblick über Transfer Learning und das ResNet und gehen dann auf die Details der Implementierung ein. Der vorgestellte Code ist in diesem Github Repository zu finden.

Table of Contents

Einführung: Transfer Learning & ResNet

Was ist Transfer Learning?

Beim traditionellen (Machine) Learning entwickeln wir ein Modell und trainieren es auf neuen Daten für jede neue Aufgabe, die ansteht. Transfer Learning unterscheidet sich von diesem Ansatz dadurch, dass das gesammelte Wissen von einer Aufgabe auf eine andere übertragen wird. Dieser Ansatz ist besonders nützlich, wenn einem zu wenige Trainingsdaten zur Verfügung stehen. Modelle, die für ein ähnliches Problem vortrainiert wurden, können als Ausgangspunkt für das Training neuer Modelle verwendet werden. Die vortrainierten Modelle werden als Basismodelle bezeichnet.

In unserem Beispiel kann ein Deep Learning-Modell, das auf dem ImageNet-Datensatz trainiert wurde, als Ausgangspunkt für die Erstellung eines Klassifikationsnetzwerks für Automodelle verwendet werden. Die Hauptidee hinter dem Transfer Learning für Deep Learning-Modelle ist, dass die ersten Layer eines Netzwerks verwendet werden, um wichtige High-Level-Features zu extrahieren, die für die jeweilige Art der behandelten Daten ähnlich bleiben. Die finalen Layer, auch „head“ genannt, des ursprünglichen Netzwerks werden durch einen benutzerdefinierten head ersetzt, der für das vorliegende Problem geeignet ist. Die Gewichte im head werden zufällig initialisiert, und das resultierende Netz kann für die spezifische Aufgabe trainiert werden.

Es gibt verschiedene Möglichkeiten, wie das Basismodell beim Training behandelt werden kann. Im ersten Schritt können seine Gewichte fixiert werden. Wenn der Lernfortschritt darauf schließen lässt, dass das Modell nicht flexibel genug ist, können bestimmte Layer oder das gesamte Basismodell auch mit trainiert werden. Ein weiterer wichtiger Aspekt, den es zu beachten gilt, ist, dass der Input die gleiche Dimensionalität haben muss wie die Daten, auf denen das Basismodell initial trainiert wurde – sofern die ersten Layer des Basismodells festgehalten werden sollen.

image-20200319174208670

Als nächstes stellen wir kurz das ResNet vor, eine beliebte und leistungsfähige CNN-Architektur für Bilddaten. Anschließend zeigen wir, wie wir Transfer Learning mit ResNet zur Klassifizierung von Automodellen eingesetzt haben.

Was ist ResNet?

Das Training von Deep Neural Networks kann aufgrund des sogenannten Vanishing Gradient-Problems schnell zur Herausforderung werden. Aber was sind Vanishing Gradients? Neuronale Netze werden in der Regel mit Back-Propagation trainiert. Dieser Algorithmus nutzt die Kettenregel der Differentialrechnung, um Gradienten in tieferen Layern des Netzes abzuleiten, indem Gradienten aus früheren Layern multipliziert werden. Da Gradienten in Deep Networks wiederholt multipliziert werden, können sie sich während der Backpropagation schnell infinitesimal kleinen Werten annähern.

ResNet ist ein CNN-Netz, welches das Problem des Vanishing Gradients mit sogenannten Residualblöcken löst (eine gute Erklärung, warum sie ‚Residual‘ heißen, findest du hier). Im Residualblock wird die unmodifizierte Eingabe an das nächste Layer weitergereicht, indem sie zum Ausgang eines Layers addiert wird (siehe Abbildung rechts). Diese Modifikation sorgt dafür, dass ein besserer Informationsfluss von der Eingabe zu den tieferen Layers möglich ist. Die gesamte ResNet-Architektur ist im rechten Netzwerk in der linken Abbildung unten dargestellt. Weiter sind daneben ein klassisches CNN und das VGG-19-Netzwerk, eine weitere Standard-CNN-Architektur, abgebildet.

Resnet-Architecture_Residual-Block

ResNet hat sich als leistungsfähige Netzarchitektur für Bildklassifikationsprobleme erwiesen. Zum Beispiel hat ein Ensemble von ResNets mit 152 Layern den ILSVRC 2015 Bildklassifikationswettbewerb gewonnen. Im Modul tensorflow.keras.application sind vortrainierte ResNet-Modelle unterschiedlicher Größe verfügbar, nämlich ResNet50, ResNet101, ResNet152 und die entsprechenden zweiten Versionen (ResNet50V2, …). Die Zahl hinter dem Modellnamen gibt die Anzahl der Layer an, über die die Netze verfügen. Die verfügbaren Gewichte sind auf dem ImageNet-Datensatz vortrainiert. Die Modelle wurden auf großen Rechenclustern unter Verwendung von spezialisierter Hardware (z.B. TPU) über signifikante Zeiträume trainiert. Transfer Learning ermöglicht es uns daher, diese Trainingsergebnisse zu nutzen und die erhaltenen Gewichte als Ausgangspunkt zu verwenden.

Klassifizierung von Automodellen

Als anschauliches Beispiel für die Anwendung von Transfer Learning behandeln wir das Problem der Klassifizierung des Automodells anhand eines Bildes des Autos. Wir beginnen mit der Beschreibung des verwendeten Datensatzes und wie wir unerwünschte Beispiele aus dem Datensatz herausfiltern können. Anschließend gehen wir darauf ein, wie eine Datenpipeline mit tensorflow.data eingerichtet werden kann. Im zweiten Abschnitt werden wir die Modellimplementierung durchgehen und aufzeigen, auf welche Aspekte ihr beim Training und bei der Inferenz besonders achten müsst.

Datenvorbereitung

Wir haben den Datensatz aus diesem GitHub Repo verwendet – dort könnt ihr den gesamten Datensatz herunterladen. Der Autor hat einen Datascraper gebaut, um alle Autobilder von der Car Connection Website zu scrapen. Er erklärt, dass viele Bilder aus dem Innenraum der Autos stammen. Da sie im Datensatz nicht erwünscht sind, filtern wir sie anhand der Pixelfarbe heraus. Der Datensatz enthält 64’467 jpg-Bilder, wobei die Dateinamen Informationen über die Automarke, das Modell, das Baujahr usw. enthalten. Für einen detaillierteren Einblick in den Datensatz empfehlen wir euch, das originale GitHub Repo zu konsultieren. Hier sind drei Beispielbilder:

Car Collage 01

Bei der Betrachtung der Daten haben wir festgestellt, dass im Datensatz noch viele unerwünschte Bilder enthalten waren, z.B. Bilder von Außenspiegeln, Türgriffen, GPS-Panels oder Leuchten. Beispiele für unerwünschte Bilder sind hier zu sehen:

Car Collage 02

Daher ist es von Vorteil, die Daten zusätzlich vorzufiltern, um mehr unerwünschte Bilder zu entfernen.

Filtern unerwünschter Bilder aus dem Datensatz

Es gibt mehrere mögliche Ansätze, um Nicht-Auto-Bilder aus dem Datensatz herauszufiltern:

  1. Verwendung eines vortrainierten Modells
  2. Ein anderes Modell trainieren, um Auto/Nicht-Auto zu klassifizieren
  3. Trainieren eines Generative Networks auf einem Auto-Datensatz und Verwendung des Diskriminatorteil des Netzwerks

Wir haben uns für den ersten Ansatz entschieden, da er der direkteste ist und ausgezeichnete, vortrainierte Modelle leicht verfügbar sind. Wenn ihr den zweiten oder dritten Ansatz verfolgen wollt, könnt ihr z. B. diesen Datensatz verwenden, um das Modell zu trainieren. Dieser Datensatz enthält nur Bilder von Autos, ist aber deutlich kleiner als der von uns verwendete Datensatz.

Unsere Wahl fiel auf das ResNet50V2 im Modul tensorflow.keras.applications mit den vortrainierten „imagenet“-Gewichten. In einem ersten Schritt müssen wir jetzt die Indizes und Klassennamen der imagenet-Labels herausfinden, die den Autobildern entsprechen.

# Class labels in imagenet corresponding to cars
CAR_IDX = [656, 627, 817, 511, 468, 751, 705, 757, 717, 734, 654, 675, 864, 609, 436]

CAR_CLASSES = ['minivan', 'limousine', 'sports_car', 'convertible', 'cab', 'racer', 'passenger_car', 'recreational_vehicle', 'pickup', 'police_van', 'minibus', 'moving_van', 'tow_truck', 'jeep', 'landrover', 'beach_wagon']

Als nächstes laden wir das vortrainierte ResNet50V2-Modell.

from tensorflow.keras.applications import ResNet50V2

model = ResNet50V2(weights='imagenet')

Wir können dieses Modell nun verwenden, um die Bilder zu klassifizieren. Die Bilder, die der Vorhersagemethode zugeführt werden, müssen identisch skaliert sein wie die Bilder, die zum Training verwendet wurden. Die verschiedenen ResNet-Modelle werden auf unterschiedlich skalierten Bilddaten trainiert. Es ist daher wichtig, das richtige Preprocessing anzuwenden.

from tensorflow.keras.applications.resnet_v2 import preprocess_input

image = tf.io.read_file(filename)
image = tf.image.decode_jpeg(image)
image = tf.cast(image, tf.float32)
image = tf.image.resize_with_crop_or_pad(image, target_height=224, target_width=224)
image = preprocess_input(image)
predictions = model.predict(image)

Es gibt verschiedene Ideen, wie die erhaltenen Vorhersagen für die Autoerkennung verwendet werden können.

  • Ist eine der CAR_CLASSES unter den Top-k-Vorhersagen?
  • Ist die kumulierte Wahrscheinlichkeit der CAR_CLASSES in den Vorhersagen größer als ein definierter Schwellenwert?
  • Spezielle Behandlung unerwünschter Bilder (z. B. Erkennen und Herausfiltern von Rädern)?

Wir zeigen den Code für den Vergleich der kumulierten Wahrscheinlichkeitsmaße über die CAR_CLASSES.

def is_car_acc_prob(predictions, thresh=THRESH, car_idx=CAR_IDX):
    """
    Determine if car on image by accumulating probabilities of car prediction and comparing to threshold

    Args:
        predictions: (?, 1000) matrix of probability predictions resulting from ResNet with                                              imagenet weights
        thresh: threshold accumulative probability over which an image is considered a car
        car_idx: indices corresponding to cars

    Returns:
        np.array of booleans describing if car or not
    """
    predictions = np.array(predictions, dtype=float)
    car_probs = predictions[:, car_idx]
    car_probs_acc = car_probs.sum(axis=1)
    return car_probs_acc > thresh

Je höher der Schwellenwert eingestellt ist, desto strenger ist das Filterverfahren. Ein Wert für den Schwellenwert, der gute Ergebnisse liefert, ist THRESH = 0.1. Damit wird sichergestellt, dass nicht zu viele echte Bilder von Autos verloren gehen. Die Wahl eines geeigneten Schwellenwerts bleibt jedoch eine subjektive Angelegenheit.

Das Colab-Notebook, in dem die Funktion is_car_acc_prob zum Filtern des Datensatzes verwendet wird, ist im GitHub Repository verfügbar.

Bei der Abstimmung der Vorfilterung haben wir Folgendes beobachtet:

  • Viele der Autobilder mit hellem Hintergrund wurden als „Strandwagen“ klassifiziert. Wir haben daher entschieden, auch die Klasse „Strandwagen“ in imagenet als eine der CAR_CLASSES zu berücksichtigen.
  • Bilder, die die Front eines Autos zeigen, bekommen oft eine hohe Wahrscheinlichkeit der Klasse „Kühlergrill“ („grille“) zugeordnet, d.h. dem Gitter an der Front eines Autos, das zur Kühlung dient. Diese Zuordnung ist korrekt, führt aber dazu, dass die oben gezeigte Prozedur bestimmte Bilder von Autos nicht als Autos betrachtet, da wir „grille“ nicht in die CAR_CLASSES aufgenommen haben. Dieses Problem führt zu dem Kompromiss, entweder viele Nahaufnahmen von Autokühlergrills im Datensatz zu belassen oder einige Autobilder herauszufiltern. Wir haben uns für den zweiten Ansatz entschieden, da er einen saubereren Datensatz ergibt.

Nach der Vorfilterung der Bilder mit dem vorgeschlagenen Verfahren verbleiben zunächst 53’738 von 64’467 im Datensatz.

Übersicht über die endgültigen Datensätze

Der vorgefilterte Datensatz enthält Bilder von 323 Automodellen. Wir haben uns dazu entschieden, unsere Aufmerksamkeit auf die 300 häufigsten Klassen im Datensatz zu reduzieren. Das ist deshalb sinnvoll, da einige der am wenigsten häufigen Klassen weniger als zehn Repräsentanten haben und somit nicht sinnvoll in ein Trainings-, Validierungs- und Testset aufgeteilt werden können. Reduziert man den Datensatz auf die Bilder der 300 häufigsten Klassen, erhält man einen Datensatz mit 53.536 beschrifteten Bildern. Die Klassenvorkommen sind wie folgt verteilt:

Histogram

Die Anzahl der Bilder pro Klasse (Automodell) reicht von 24 bis knapp unter 500. Wir können sehen, dass der Datensatz sehr unausgewogen ist. Dies muss beim Training und bei der Auswertung des Modells unbedingt beachtet werden.

Aufbau von Datenpipelines mit tf.data

Selbst nach der Vorfilterung und der Reduktion auf die besten 300 Klassen bleiben immer noch zahlreiche Bilder übrig. Dies stellt ein potenzielles Problem dar, da wir nicht einfach alle Bilder auf einmal in den Speicher unserer GPU laden können. Um dieses Problem zu lösen, werden wir tf.data verwenden.

Mit tf.data und insbesondere der tf.data.Dataset API lassen sich elegante und gleichzeitig sehr effiziente Eingabe-Pipelines erstellen. Die API enthält viele allgemeine Methoden, die zum Laden und Transformieren potenziell großer Datensätze verwendet werden können. Die Methode tf.data.Dataset ist besonders nützlich, wenn Modelle auf GPU(s) trainiert werden. Es ermöglicht das Laden von Daten von der Festplatte, wendet on-the-fly Transformationen an und erstellt Batches, die dann an die GPU gesendet werden. Und das alles geschieht so, dass die GPU nie auf neue Daten warten muss.

Die folgenden Funktionen erstellen eine <code>tf.data.Dataset-Instanz für unseren konkreten Anwendungsfall:

def construct_ds(input_files: list,
                 batch_size: int,
                 classes: list,
                 label_type: str,
                 input_size: tuple = (212, 320),
                 prefetch_size: int = 10,
                 shuffle_size: int = 32,
                 shuffle: bool = True,
                 augment: bool = False):
    """
    Function to construct a tf.data.Dataset set from list of files

    Args:
        input_files: list of files
        batch_size: number of observations in batch
        classes: list with all class labels
        input_size: size of images (output size)
        prefetch_size: buffer size (number of batches to prefetch)
        shuffle_size: shuffle size (size of buffer to shuffle from)
        shuffle: boolean specifying whether to shuffle dataset
        augment: boolean if image augmentation should be applied
        label_type: 'make' or 'model'

    Returns:
        buffered and prefetched tf.data.Dataset object with (image, label) tuple
    """
    # Create tf.data.Dataset from list of files
    ds = tf.data.Dataset.from_tensor_slices(input_files)

    # Shuffle files
    if shuffle:
        ds = ds.shuffle(buffer_size=shuffle_size)

    # Load image/labels
    ds = ds.map(lambda x: parse_file(x, classes=classes, input_size=input_size,                                                                                                                                        label_type=label_type))

    # Image augmentation
    if augment and tf.random.uniform((), minval=0, maxval=1, dtype=tf.dtypes.float32, seed=None, name=None) < 0.7:
        ds = ds.map(image_augment)

    # Batch and prefetch data
    ds = ds.batch(batch_size=batch_size)
    ds = ds.prefetch(buffer_size=prefetch_size)

    return ds

Wir werden nun die verwendeten tf.data-Methoden beschreiben:

  • from_tensor_slices() ist eine der verfügbaren Methoden für die Erstellung eines Datensatzes. Der erzeugte Datensatz enthält Slices des angegebenen Tensors, in diesem Fall die Dateinamen.
  • Als nächstes betrachtet die Methode shuffle() jeweils buffer_size-Elemente separat und mischt diese Elemente isoliert vom Rest des Datensatzes. Wenn das Mischen des gesamten Datensatzes erforderlich ist, muss buffer_size größer sein als die Anzahl der Einträge im Datensatz. Das Mischen wird nur durchgeführt, wenn shuffle=True gesetzt ist.
  • Mit map() lassen sich beliebige Funktionen auf den Datensatz anwenden. Wir haben eine Funktion parse_file() erstellt, die im GitHub Repo zu finden ist. Sie ist verantwortlich für das Lesen und die Größenänderung der Bilder, das Ableiten der Beschriftungen aus dem Dateinamen und die Kodierung der Beschriftungen mit einem One-Hot-Encoder. Wenn die Flag „augment“ gesetzt ist, wird das Verfahren zur Datenerweiterung aktiviert. Die Augmentierung wird nur in 70 % der Fälle angewendet, da es von Vorteil ist, das Modell auch auf nicht modifizierten Bildern zu trainieren. Die in image_augment verwendeten Augmentierungstechniken sind Flipping, Helligkeits- und Kontrastanpassungen.
  • Schließlich wird die Methode batch() verwendet, um den Datensatz in Batches der Größe batch_size zu gruppieren, und die Methode prefetch() ermöglicht die Vorbereitung späterer Batches, während der aktuelle Batch verarbeitet wird, und verbessert so die Leistung. Wenn die Methode nach einem Aufruf von batch() verwendet wird, werden prefetch_size-Batches vorab geholt.

Fine Tuning des Modells

Nachdem wir unsere Eingabe-Pipeline definiert haben, wenden wir uns nun dem Trainingsteil des Modells zu. Der Code unten zeigt auf, wie ein Modell basierend auf dem vortrainierten ResNet instanziiert werden kann:

from tensorflow.keras.applications import ResNet50V2
from tensorflow.keras import Model
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D


class TransferModel:

    def __init__(self, shape: tuple, classes: list):
        """
        Class for transfer learning from ResNet

        Args:
            shape: Input shape as tuple (height, width, channels)
            classes: List of class labels
        """
        self.shape = shape
        self.classes = classes
        self.history = None
        self.model = None

        # Use pre-trained ResNet model
        self.base_model = ResNet50V2(include_top=False,
                                     input_shape=self.shape,
                                     weights='imagenet')

        # Allow parameter updates for all layers
        self.base_model.trainable = True

        # Add a new pooling layer on the original output
        add_to_base = self.base_model.output
        add_to_base = GlobalAveragePooling2D(data_format='channels_last', name='head_gap')(add_to_base)

        # Add new output layer as head
        new_output = Dense(len(self.classes), activation='softmax', name='head_pred')(add_to_base)

        # Define model
        self.model = Model(self.base_model.input, new_output)

Ein paar weitere Details zum oben stehenden Code:

  • Wir erzeugen zunächst eine Instanz der Klasse tf.keras.applications.ResNet50V2. Mit include_top=False weisen wir das vortrainierte Modell an, den ursprünglichen head des Modells (in diesem Fall für die Klassifikation von 1000 Klassen auf ImageNet ausgelegt) wegzulassen.
  • Mit base_model.trainable = True werden alle Layer trainierbar.
  • Mit der funktionalen API tf.keras stapeln wir dann ein neues Pooling-Layer auf den letzten Faltungsblock des ursprünglichen ResNet-Modells. Dies ist ein notwendiger Zwischenschritt, bevor die Ausgabe an die endgültigen Klassifizierungs-Layer weitergeleitet wird.
  • Die endgültigen Klassifizierungs-Layer wird dann mit „tf.keras.layers.Dense“ definiert. Wir definieren die Anzahl der Neuronen so, dass sie gleich der Anzahl der gewünschten Klassen ist. Und die Softmax-Aktivierungsfunktion sorgt dafür, dass die Ausgabe eine Pseudowahrscheinlichkeit im Bereich von (0,1] ist.

Die Vollversion von TransferModel (s. GitHub) enthält auch die Option, das Basismodell durch ein VGG16-Netzwerk zu ersetzen, ein weiteres Standard-CNN für die Bildklassifikation. Außerdem erlaubt es, nur bestimmte Layer freizugeben, d.h. wir können die entsprechenden Parameter trainierbar machen, während wir die anderen festgehalten werden. Standardmässig haben wir hier alle Parameter trainierbar gemacht.

Nachdem wir das Modell definiert haben, müssen wir es für das Training konfigurieren. Dies kann mit der compile()-Methode von tf.keras.Model gemacht werden:

def compile(self, **kwargs):
      """
    Compile method
    """
    self.model.compile(**kwargs)

Wir übergeben dann die folgenden Keyword-Argumente an unsere Methode:

  • loss = "categorical_crossentropy" für die Mehrklassen-Klassifikation,
  • optimizer = Adam(0.0001) für die Verwendung des Adam-Optimierers aus tf.keras.optimizers mit einer relativ kleinen Lernrate (mehr zur Lernrate weiter unten), und
  • metrics = ["categorical_accuracy"] für die Trainings- und Validierungsüberwachung.

Als Nächstes wollen wir uns das Trainingsverfahren ansehen. Dazu definieren wir eine train-Methode für unsere oben vorgestellte TransferModel-Klasse:

from tensorflow.keras.callbacks import EarlyStopping

def train(self,
          ds_train: tf.data.Dataset,
          epochs: int,
          ds_valid: tf.data.Dataset = None,
          class_weights: np.array = None):
    """
    Trains model in ds_train with for epochs rounds

    Args:
        ds_train: training data as tf.data.Dataset
        epochs: number of epochs to train
        ds_valid: optional validation data as tf.data.Dataset
        class_weights: optional class weights to treat unbalanced classes

    Returns
        Training history from self.history
    """

    # Define early stopping as callback
    early_stopping = EarlyStopping(monitor='val_loss',
                                   min_delta=0,
                                   patience=12,
                                   restore_best_weights=True)

    callbacks = [early_stopping]

    # Fitting
    self.history = self.model.fit(ds_train,
                                  epochs=epochs,
                                  validation_data=ds_valid,
                                  callbacks=callbacks,
                                  class_weight=class_weights)

    return self.history

Da unser Modell eine Instanz von tensorflow.keras.Model ist, können wir es mit der Methode fit trainieren. Um Overfitting zu verhindern, wird Early Stopping verwendet, indem es als Callback-Funktion an die fit-Methode übergeben wird. Der patience-Parameter kann eingestellt werden, um festzulegen, wie schnell das Early Stopping angewendet werden soll. Der Parameter steht für die Anzahl der Epochen, nach denen, wenn keine Abnahme des Validierungsverlustes registriert wird, das Training abgebrochen wird. Weiterhin können Klassengewichte an die Methode fit übergeben werden. Klassengewichte erlauben es, unausgewogene Daten zu behandeln, indem den verschiedenen Klassen unterschiedliche Gewichte zugewiesen werden, wodurch die Wirkung von Klassen mit weniger Trainingsbeispielen erhöht werden kann.

Wir können den Trainingsprozess mit einem vortrainierten Modell wie folgt beschreiben: Da die Gewichte im head zufällig initialisiert werden und die Gewichte des Basismodells vortrainiert sind, setzt sich das Training aus dem Training des heads von Grund auf und der Feinabstimmung der Gewichte des vortrainierten Modells zusammen. Es wird generell für Transfer Learning empfohlen, eine kleine Lernrate zu verwenden (z. B. 1e-4), da eine zu große Lernrate die nahezu optimalen vortrainierten Gewichte des Basismodells zerstören kann.

Der Trainingsvorgang kann beschleunigt werden, indem zunächst einige Epochen lang trainiert wird, ohne dass das Basismodell trainierbar ist. Der Zweck dieser ersten Epochen ist es, die Gewichte des heads an das Problem anzupassen. Dies beschleunigt das Training, da wenn nur der head trainiert wird, viel weniger Parameter trainierbar sind und somit für jeden Batch aktualisiert werden. Die resultierenden Modellgewichte können dann als Ausgangspunkt für das Training des gesamten Modells verwendet werden, wobei das Basismodell trainierbar ist. Für das hier betrachtete Autoklassifizierungsproblem führte die Anwendung dieses zweistufigen Trainings zu keiner nennenswerten Leistungsverbesserung.

Evaluation/Vorhersage der Modell Performance

Bei der Verwendung der API tf.data.Dataset muss man auf die Art der verwendeten Methoden achten. Die folgende Methode in unserer Klasse TransferModel kann als Vorhersagemethode verwendet werden.

def predict(self, ds_new: tf.data.Dataset, proba: bool = True):
    """
    Predict class probs or labels on ds_new
    Labels are obtained by taking the most likely class given the predicted probs

    Args:
        ds_new: New data as tf.data.Dataset
        proba: Boolean if probabilities should be returned

    Returns:
        class labels or probabilities
    """

    p = self.model.predict(ds_new)

    if proba:
        return p
    else:
        return [np.argmax(x) for x in p]

Es ist wichtig, dass der Datensatz ds_new nicht gemischt wird, sonst stimmen die erhaltenen Vorhersagen nicht mit den erhaltenen Bildern überein, wenn ein zweites Mal über den Datensatz iteriert wird. Dies ist der Fall, da die Flag reshuffle_each_iteration in der Implementierung der Methode shuffle standardmäßig auf True gesetzt ist. Ein weiterer Effekt des Shufflens ist, dass mehrere Aufrufe der Methode take nicht die gleichen Daten zurückgeben. Dies ist wichtig, wenn z. B. Vorhersagen für nur eine Charge überprüft werden sollen. Ein einfaches Beispiel, an dem dies zu sehen ist, ist:

# Use construct_ds method from above to create a shuffled dataset
ds = construct_ds(..., shuffle=True)

# Take 1 batch (e.g. 32 images) of dataset: This returns a new dataset
ds_batch = ds.take(1)

# Predict labels for one batch
predictions = model.predict(ds_batch)

# Predict labels again: The result will not be the same as predictions above due to shuffling
predictions_2 = model.predict(ds_batch)

Eine Funktion zum Plotten von Bildern, die mit den entsprechenden Vorhersagen beschriftet sind, könnte wie folgt aussehen:

def show_batch_with_pred(model, ds, classes, rescale=True, size=(10, 10), title=None):
      for image, label in ds.take(1):
        image_array = image.numpy()
        label_array = label.numpy()
        batch_size = image_array.shape[0]
        pred = model.predict(image, proba=False)
        for idx in range(batch_size):
            label = classes[np.argmax(label_array[idx])]
            ax = plt.subplot(np.ceil(batch_size / 4), 4, idx + 1)
            if rescale:
                plt.imshow(image_array[idx] / 255)
            else:
                plt.imshow(image_array[idx])
            plt.title("label: " + label + "n" 
                      + "prediction: " + classes[pred[idx]], fontsize=10)
            plt.axis('off')

Die Methode show_batch_with_pred funktioniert auch für gemischte Datensätze, da image und label demselben Aufruf der Methode take entsprechen.

Die Auswertung der Model-Performance kann mit der Methode evaluate von keras.Model durchgeführt werden.

Wie akkurat ist unser finales Modell?

Das Modell erreicht eine kategoriale Genauigkeit von etwas über 70 % für die Vorhersage des Automodells für Bilder aus 300 Modellklassen. Um die Vorhersagen des Modells besser zu verstehen, ist es hilfreich, die Konfusionsmatrix zu betrachten. Unten ist die Heatmap der Vorhersagen des Modells für den Validierungsdatensatz abgebildet.

heatmap

Wir haben die Heatmap auf Einträge der Konfusionsmatrix in [0, 5] beschränkt, da das Zulassen einer weiteren Spanne keine Region außerhalb der Diagonalen signifikant hervorgehoben hat. Wie in der Heatmap zu sehen ist, wird eine Klasse den Beispielen fast aller Klassen zugeordnet. Das ist an der dunkelroten vertikalen Linie zwei Drittel rechts in der Abbildung oben zu erkennen.

Abgesehen von der zuvor erwähnten Klasse gibt es keine offensichtlichen Verzerrungen in den Vorhersagen. Wir möchten an dieser Stelle betonen, dass die Accuracy im Allgemeinen nicht ausreicht, um die Leistung eines Modells zufriedenstellend zu beurteilen, insbesondere im Fall unausgewogener Klassen.

Fazit und nächste Schritte

In diesem Blog-Beitrag haben wir Transfer Learning mit dem ResNet50V2 angewendet, um das Fahrzeugmodell anhand von Bildern von Autos zu klassifizieren. Unser Modell erreicht 70% kategoriale Genauigkeit über 300 Klassen.

Wir haben festgestellt, dass das Trainieren des gesamten Basismodells und die Verwendung einer kleinen Lernrate die besten Ergebnisse erzielen. Ein cooles Auto-Klassifikationsmodell entwickelt zu haben ist großartig, aber wie können wir unser Modell in einer produktiven Umgebung einsetzen? Natürlich könnten wir unsere eigene Modell-API mit Flask oder FastAPI bauen… Aber gibt es vielleicht sogar einen einfacheren, standardisierten Weg?

Im zweiten Beitrag unserer Blog-Serie, „Deployment von TensorFlow-Modellen in Docker mit TensorFlow Serving“ zeigen wir Euch, wie dieses Modell mit TensorFlow Serving bereitgestellt werden kann.

Deep Learning ist eines der Themen im Bereich der künstlichen Intelligenz, die uns bei STATWORX besonders faszinieren. In dieser Blogserie möchten wir veranschaulichen, wie ein End-to-end Deep Learning Projekt implementiert werden kann. Dabei verwenden wir die TensorFlow 2.x Bibliothek für die Implementierung.

Die Themen der 4-teiligen Blogserie umfassen:

Im ersten Teil zeigen wir, wie man Transfer Learning nutzen kann, um die Marke eines Autos mittels Bildklassifizierung vorherzusagen. Wir beginnen mit einem kurzen Überblick über Transfer Learning und das ResNet und gehen dann auf die Details der Implementierung ein. Der vorgestellte Code ist in diesem Github Repository zu finden.

Table of Contents

Einführung: Transfer Learning & ResNet

Was ist Transfer Learning?

Beim traditionellen (Machine) Learning entwickeln wir ein Modell und trainieren es auf neuen Daten für jede neue Aufgabe, die ansteht. Transfer Learning unterscheidet sich von diesem Ansatz dadurch, dass das gesammelte Wissen von einer Aufgabe auf eine andere übertragen wird. Dieser Ansatz ist besonders nützlich, wenn einem zu wenige Trainingsdaten zur Verfügung stehen. Modelle, die für ein ähnliches Problem vortrainiert wurden, können als Ausgangspunkt für das Training neuer Modelle verwendet werden. Die vortrainierten Modelle werden als Basismodelle bezeichnet.

In unserem Beispiel kann ein Deep Learning-Modell, das auf dem ImageNet-Datensatz trainiert wurde, als Ausgangspunkt für die Erstellung eines Klassifikationsnetzwerks für Automodelle verwendet werden. Die Hauptidee hinter dem Transfer Learning für Deep Learning-Modelle ist, dass die ersten Layer eines Netzwerks verwendet werden, um wichtige High-Level-Features zu extrahieren, die für die jeweilige Art der behandelten Daten ähnlich bleiben. Die finalen Layer, auch „head“ genannt, des ursprünglichen Netzwerks werden durch einen benutzerdefinierten head ersetzt, der für das vorliegende Problem geeignet ist. Die Gewichte im head werden zufällig initialisiert, und das resultierende Netz kann für die spezifische Aufgabe trainiert werden.

Es gibt verschiedene Möglichkeiten, wie das Basismodell beim Training behandelt werden kann. Im ersten Schritt können seine Gewichte fixiert werden. Wenn der Lernfortschritt darauf schließen lässt, dass das Modell nicht flexibel genug ist, können bestimmte Layer oder das gesamte Basismodell auch mit trainiert werden. Ein weiterer wichtiger Aspekt, den es zu beachten gilt, ist, dass der Input die gleiche Dimensionalität haben muss wie die Daten, auf denen das Basismodell initial trainiert wurde – sofern die ersten Layer des Basismodells festgehalten werden sollen.

image-20200319174208670

Als nächstes stellen wir kurz das ResNet vor, eine beliebte und leistungsfähige CNN-Architektur für Bilddaten. Anschließend zeigen wir, wie wir Transfer Learning mit ResNet zur Klassifizierung von Automodellen eingesetzt haben.

Was ist ResNet?

Das Training von Deep Neural Networks kann aufgrund des sogenannten Vanishing Gradient-Problems schnell zur Herausforderung werden. Aber was sind Vanishing Gradients? Neuronale Netze werden in der Regel mit Back-Propagation trainiert. Dieser Algorithmus nutzt die Kettenregel der Differentialrechnung, um Gradienten in tieferen Layern des Netzes abzuleiten, indem Gradienten aus früheren Layern multipliziert werden. Da Gradienten in Deep Networks wiederholt multipliziert werden, können sie sich während der Backpropagation schnell infinitesimal kleinen Werten annähern.

ResNet ist ein CNN-Netz, welches das Problem des Vanishing Gradients mit sogenannten Residualblöcken löst (eine gute Erklärung, warum sie ‚Residual‘ heißen, findest du hier). Im Residualblock wird die unmodifizierte Eingabe an das nächste Layer weitergereicht, indem sie zum Ausgang eines Layers addiert wird (siehe Abbildung rechts). Diese Modifikation sorgt dafür, dass ein besserer Informationsfluss von der Eingabe zu den tieferen Layers möglich ist. Die gesamte ResNet-Architektur ist im rechten Netzwerk in der linken Abbildung unten dargestellt. Weiter sind daneben ein klassisches CNN und das VGG-19-Netzwerk, eine weitere Standard-CNN-Architektur, abgebildet.

Resnet-Architecture_Residual-Block

ResNet hat sich als leistungsfähige Netzarchitektur für Bildklassifikationsprobleme erwiesen. Zum Beispiel hat ein Ensemble von ResNets mit 152 Layern den ILSVRC 2015 Bildklassifikationswettbewerb gewonnen. Im Modul tensorflow.keras.application sind vortrainierte ResNet-Modelle unterschiedlicher Größe verfügbar, nämlich ResNet50, ResNet101, ResNet152 und die entsprechenden zweiten Versionen (ResNet50V2, …). Die Zahl hinter dem Modellnamen gibt die Anzahl der Layer an, über die die Netze verfügen. Die verfügbaren Gewichte sind auf dem ImageNet-Datensatz vortrainiert. Die Modelle wurden auf großen Rechenclustern unter Verwendung von spezialisierter Hardware (z.B. TPU) über signifikante Zeiträume trainiert. Transfer Learning ermöglicht es uns daher, diese Trainingsergebnisse zu nutzen und die erhaltenen Gewichte als Ausgangspunkt zu verwenden.

Klassifizierung von Automodellen

Als anschauliches Beispiel für die Anwendung von Transfer Learning behandeln wir das Problem der Klassifizierung des Automodells anhand eines Bildes des Autos. Wir beginnen mit der Beschreibung des verwendeten Datensatzes und wie wir unerwünschte Beispiele aus dem Datensatz herausfiltern können. Anschließend gehen wir darauf ein, wie eine Datenpipeline mit tensorflow.data eingerichtet werden kann. Im zweiten Abschnitt werden wir die Modellimplementierung durchgehen und aufzeigen, auf welche Aspekte ihr beim Training und bei der Inferenz besonders achten müsst.

Datenvorbereitung

Wir haben den Datensatz aus diesem GitHub Repo verwendet – dort könnt ihr den gesamten Datensatz herunterladen. Der Autor hat einen Datascraper gebaut, um alle Autobilder von der Car Connection Website zu scrapen. Er erklärt, dass viele Bilder aus dem Innenraum der Autos stammen. Da sie im Datensatz nicht erwünscht sind, filtern wir sie anhand der Pixelfarbe heraus. Der Datensatz enthält 64’467 jpg-Bilder, wobei die Dateinamen Informationen über die Automarke, das Modell, das Baujahr usw. enthalten. Für einen detaillierteren Einblick in den Datensatz empfehlen wir euch, das originale GitHub Repo zu konsultieren. Hier sind drei Beispielbilder:

Car Collage 01

Bei der Betrachtung der Daten haben wir festgestellt, dass im Datensatz noch viele unerwünschte Bilder enthalten waren, z.B. Bilder von Außenspiegeln, Türgriffen, GPS-Panels oder Leuchten. Beispiele für unerwünschte Bilder sind hier zu sehen:

Car Collage 02

Daher ist es von Vorteil, die Daten zusätzlich vorzufiltern, um mehr unerwünschte Bilder zu entfernen.

Filtern unerwünschter Bilder aus dem Datensatz

Es gibt mehrere mögliche Ansätze, um Nicht-Auto-Bilder aus dem Datensatz herauszufiltern:

  1. Verwendung eines vortrainierten Modells
  2. Ein anderes Modell trainieren, um Auto/Nicht-Auto zu klassifizieren
  3. Trainieren eines Generative Networks auf einem Auto-Datensatz und Verwendung des Diskriminatorteil des Netzwerks

Wir haben uns für den ersten Ansatz entschieden, da er der direkteste ist und ausgezeichnete, vortrainierte Modelle leicht verfügbar sind. Wenn ihr den zweiten oder dritten Ansatz verfolgen wollt, könnt ihr z. B. diesen Datensatz verwenden, um das Modell zu trainieren. Dieser Datensatz enthält nur Bilder von Autos, ist aber deutlich kleiner als der von uns verwendete Datensatz.

Unsere Wahl fiel auf das ResNet50V2 im Modul tensorflow.keras.applications mit den vortrainierten „imagenet“-Gewichten. In einem ersten Schritt müssen wir jetzt die Indizes und Klassennamen der imagenet-Labels herausfinden, die den Autobildern entsprechen.

# Class labels in imagenet corresponding to cars
CAR_IDX = [656, 627, 817, 511, 468, 751, 705, 757, 717, 734, 654, 675, 864, 609, 436]

CAR_CLASSES = ['minivan', 'limousine', 'sports_car', 'convertible', 'cab', 'racer', 'passenger_car', 'recreational_vehicle', 'pickup', 'police_van', 'minibus', 'moving_van', 'tow_truck', 'jeep', 'landrover', 'beach_wagon']

Als nächstes laden wir das vortrainierte ResNet50V2-Modell.

from tensorflow.keras.applications import ResNet50V2

model = ResNet50V2(weights='imagenet')

Wir können dieses Modell nun verwenden, um die Bilder zu klassifizieren. Die Bilder, die der Vorhersagemethode zugeführt werden, müssen identisch skaliert sein wie die Bilder, die zum Training verwendet wurden. Die verschiedenen ResNet-Modelle werden auf unterschiedlich skalierten Bilddaten trainiert. Es ist daher wichtig, das richtige Preprocessing anzuwenden.

from tensorflow.keras.applications.resnet_v2 import preprocess_input

image = tf.io.read_file(filename)
image = tf.image.decode_jpeg(image)
image = tf.cast(image, tf.float32)
image = tf.image.resize_with_crop_or_pad(image, target_height=224, target_width=224)
image = preprocess_input(image)
predictions = model.predict(image)

Es gibt verschiedene Ideen, wie die erhaltenen Vorhersagen für die Autoerkennung verwendet werden können.

Wir zeigen den Code für den Vergleich der kumulierten Wahrscheinlichkeitsmaße über die CAR_CLASSES.

def is_car_acc_prob(predictions, thresh=THRESH, car_idx=CAR_IDX):
    """
    Determine if car on image by accumulating probabilities of car prediction and comparing to threshold

    Args:
        predictions: (?, 1000) matrix of probability predictions resulting from ResNet with                                              imagenet weights
        thresh: threshold accumulative probability over which an image is considered a car
        car_idx: indices corresponding to cars

    Returns:
        np.array of booleans describing if car or not
    """
    predictions = np.array(predictions, dtype=float)
    car_probs = predictions[:, car_idx]
    car_probs_acc = car_probs.sum(axis=1)
    return car_probs_acc > thresh

Je höher der Schwellenwert eingestellt ist, desto strenger ist das Filterverfahren. Ein Wert für den Schwellenwert, der gute Ergebnisse liefert, ist THRESH = 0.1. Damit wird sichergestellt, dass nicht zu viele echte Bilder von Autos verloren gehen. Die Wahl eines geeigneten Schwellenwerts bleibt jedoch eine subjektive Angelegenheit.

Das Colab-Notebook, in dem die Funktion is_car_acc_prob zum Filtern des Datensatzes verwendet wird, ist im GitHub Repository verfügbar.

Bei der Abstimmung der Vorfilterung haben wir Folgendes beobachtet:

Nach der Vorfilterung der Bilder mit dem vorgeschlagenen Verfahren verbleiben zunächst 53’738 von 64’467 im Datensatz.

Übersicht über die endgültigen Datensätze

Der vorgefilterte Datensatz enthält Bilder von 323 Automodellen. Wir haben uns dazu entschieden, unsere Aufmerksamkeit auf die 300 häufigsten Klassen im Datensatz zu reduzieren. Das ist deshalb sinnvoll, da einige der am wenigsten häufigen Klassen weniger als zehn Repräsentanten haben und somit nicht sinnvoll in ein Trainings-, Validierungs- und Testset aufgeteilt werden können. Reduziert man den Datensatz auf die Bilder der 300 häufigsten Klassen, erhält man einen Datensatz mit 53.536 beschrifteten Bildern. Die Klassenvorkommen sind wie folgt verteilt:

Histogram

Die Anzahl der Bilder pro Klasse (Automodell) reicht von 24 bis knapp unter 500. Wir können sehen, dass der Datensatz sehr unausgewogen ist. Dies muss beim Training und bei der Auswertung des Modells unbedingt beachtet werden.

Aufbau von Datenpipelines mit tf.data

Selbst nach der Vorfilterung und der Reduktion auf die besten 300 Klassen bleiben immer noch zahlreiche Bilder übrig. Dies stellt ein potenzielles Problem dar, da wir nicht einfach alle Bilder auf einmal in den Speicher unserer GPU laden können. Um dieses Problem zu lösen, werden wir tf.data verwenden.

Mit tf.data und insbesondere der tf.data.Dataset API lassen sich elegante und gleichzeitig sehr effiziente Eingabe-Pipelines erstellen. Die API enthält viele allgemeine Methoden, die zum Laden und Transformieren potenziell großer Datensätze verwendet werden können. Die Methode tf.data.Dataset ist besonders nützlich, wenn Modelle auf GPU(s) trainiert werden. Es ermöglicht das Laden von Daten von der Festplatte, wendet on-the-fly Transformationen an und erstellt Batches, die dann an die GPU gesendet werden. Und das alles geschieht so, dass die GPU nie auf neue Daten warten muss.

Die folgenden Funktionen erstellen eine <code>tf.data.Dataset-Instanz für unseren konkreten Anwendungsfall:

def construct_ds(input_files: list,
                 batch_size: int,
                 classes: list,
                 label_type: str,
                 input_size: tuple = (212, 320),
                 prefetch_size: int = 10,
                 shuffle_size: int = 32,
                 shuffle: bool = True,
                 augment: bool = False):
    """
    Function to construct a tf.data.Dataset set from list of files

    Args:
        input_files: list of files
        batch_size: number of observations in batch
        classes: list with all class labels
        input_size: size of images (output size)
        prefetch_size: buffer size (number of batches to prefetch)
        shuffle_size: shuffle size (size of buffer to shuffle from)
        shuffle: boolean specifying whether to shuffle dataset
        augment: boolean if image augmentation should be applied
        label_type: 'make' or 'model'

    Returns:
        buffered and prefetched tf.data.Dataset object with (image, label) tuple
    """
    # Create tf.data.Dataset from list of files
    ds = tf.data.Dataset.from_tensor_slices(input_files)

    # Shuffle files
    if shuffle:
        ds = ds.shuffle(buffer_size=shuffle_size)

    # Load image/labels
    ds = ds.map(lambda x: parse_file(x, classes=classes, input_size=input_size,                                                                                                                                        label_type=label_type))

    # Image augmentation
    if augment and tf.random.uniform((), minval=0, maxval=1, dtype=tf.dtypes.float32, seed=None, name=None) < 0.7:
        ds = ds.map(image_augment)

    # Batch and prefetch data
    ds = ds.batch(batch_size=batch_size)
    ds = ds.prefetch(buffer_size=prefetch_size)

    return ds

Wir werden nun die verwendeten tf.data-Methoden beschreiben:

Fine Tuning des Modells

Nachdem wir unsere Eingabe-Pipeline definiert haben, wenden wir uns nun dem Trainingsteil des Modells zu. Der Code unten zeigt auf, wie ein Modell basierend auf dem vortrainierten ResNet instanziiert werden kann:

from tensorflow.keras.applications import ResNet50V2
from tensorflow.keras import Model
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D


class TransferModel:

    def __init__(self, shape: tuple, classes: list):
        """
        Class for transfer learning from ResNet

        Args:
            shape: Input shape as tuple (height, width, channels)
            classes: List of class labels
        """
        self.shape = shape
        self.classes = classes
        self.history = None
        self.model = None

        # Use pre-trained ResNet model
        self.base_model = ResNet50V2(include_top=False,
                                     input_shape=self.shape,
                                     weights='imagenet')

        # Allow parameter updates for all layers
        self.base_model.trainable = True

        # Add a new pooling layer on the original output
        add_to_base = self.base_model.output
        add_to_base = GlobalAveragePooling2D(data_format='channels_last', name='head_gap')(add_to_base)

        # Add new output layer as head
        new_output = Dense(len(self.classes), activation='softmax', name='head_pred')(add_to_base)

        # Define model
        self.model = Model(self.base_model.input, new_output)

Ein paar weitere Details zum oben stehenden Code:

Die Vollversion von TransferModel (s. GitHub) enthält auch die Option, das Basismodell durch ein VGG16-Netzwerk zu ersetzen, ein weiteres Standard-CNN für die Bildklassifikation. Außerdem erlaubt es, nur bestimmte Layer freizugeben, d.h. wir können die entsprechenden Parameter trainierbar machen, während wir die anderen festgehalten werden. Standardmässig haben wir hier alle Parameter trainierbar gemacht.

Nachdem wir das Modell definiert haben, müssen wir es für das Training konfigurieren. Dies kann mit der compile()-Methode von tf.keras.Model gemacht werden:

def compile(self, **kwargs):
      """
    Compile method
    """
    self.model.compile(**kwargs)

Wir übergeben dann die folgenden Keyword-Argumente an unsere Methode:

Als Nächstes wollen wir uns das Trainingsverfahren ansehen. Dazu definieren wir eine train-Methode für unsere oben vorgestellte TransferModel-Klasse:

from tensorflow.keras.callbacks import EarlyStopping

def train(self,
          ds_train: tf.data.Dataset,
          epochs: int,
          ds_valid: tf.data.Dataset = None,
          class_weights: np.array = None):
    """
    Trains model in ds_train with for epochs rounds

    Args:
        ds_train: training data as tf.data.Dataset
        epochs: number of epochs to train
        ds_valid: optional validation data as tf.data.Dataset
        class_weights: optional class weights to treat unbalanced classes

    Returns
        Training history from self.history
    """

    # Define early stopping as callback
    early_stopping = EarlyStopping(monitor='val_loss',
                                   min_delta=0,
                                   patience=12,
                                   restore_best_weights=True)

    callbacks = [early_stopping]

    # Fitting
    self.history = self.model.fit(ds_train,
                                  epochs=epochs,
                                  validation_data=ds_valid,
                                  callbacks=callbacks,
                                  class_weight=class_weights)

    return self.history

Da unser Modell eine Instanz von tensorflow.keras.Model ist, können wir es mit der Methode fit trainieren. Um Overfitting zu verhindern, wird Early Stopping verwendet, indem es als Callback-Funktion an die fit-Methode übergeben wird. Der patience-Parameter kann eingestellt werden, um festzulegen, wie schnell das Early Stopping angewendet werden soll. Der Parameter steht für die Anzahl der Epochen, nach denen, wenn keine Abnahme des Validierungsverlustes registriert wird, das Training abgebrochen wird. Weiterhin können Klassengewichte an die Methode fit übergeben werden. Klassengewichte erlauben es, unausgewogene Daten zu behandeln, indem den verschiedenen Klassen unterschiedliche Gewichte zugewiesen werden, wodurch die Wirkung von Klassen mit weniger Trainingsbeispielen erhöht werden kann.

Wir können den Trainingsprozess mit einem vortrainierten Modell wie folgt beschreiben: Da die Gewichte im head zufällig initialisiert werden und die Gewichte des Basismodells vortrainiert sind, setzt sich das Training aus dem Training des heads von Grund auf und der Feinabstimmung der Gewichte des vortrainierten Modells zusammen. Es wird generell für Transfer Learning empfohlen, eine kleine Lernrate zu verwenden (z. B. 1e-4), da eine zu große Lernrate die nahezu optimalen vortrainierten Gewichte des Basismodells zerstören kann.

Der Trainingsvorgang kann beschleunigt werden, indem zunächst einige Epochen lang trainiert wird, ohne dass das Basismodell trainierbar ist. Der Zweck dieser ersten Epochen ist es, die Gewichte des heads an das Problem anzupassen. Dies beschleunigt das Training, da wenn nur der head trainiert wird, viel weniger Parameter trainierbar sind und somit für jeden Batch aktualisiert werden. Die resultierenden Modellgewichte können dann als Ausgangspunkt für das Training des gesamten Modells verwendet werden, wobei das Basismodell trainierbar ist. Für das hier betrachtete Autoklassifizierungsproblem führte die Anwendung dieses zweistufigen Trainings zu keiner nennenswerten Leistungsverbesserung.

Evaluation/Vorhersage der Modell Performance

Bei der Verwendung der API tf.data.Dataset muss man auf die Art der verwendeten Methoden achten. Die folgende Methode in unserer Klasse TransferModel kann als Vorhersagemethode verwendet werden.

def predict(self, ds_new: tf.data.Dataset, proba: bool = True):
    """
    Predict class probs or labels on ds_new
    Labels are obtained by taking the most likely class given the predicted probs

    Args:
        ds_new: New data as tf.data.Dataset
        proba: Boolean if probabilities should be returned

    Returns:
        class labels or probabilities
    """

    p = self.model.predict(ds_new)

    if proba:
        return p
    else:
        return [np.argmax(x) for x in p]

Es ist wichtig, dass der Datensatz ds_new nicht gemischt wird, sonst stimmen die erhaltenen Vorhersagen nicht mit den erhaltenen Bildern überein, wenn ein zweites Mal über den Datensatz iteriert wird. Dies ist der Fall, da die Flag reshuffle_each_iteration in der Implementierung der Methode shuffle standardmäßig auf True gesetzt ist. Ein weiterer Effekt des Shufflens ist, dass mehrere Aufrufe der Methode take nicht die gleichen Daten zurückgeben. Dies ist wichtig, wenn z. B. Vorhersagen für nur eine Charge überprüft werden sollen. Ein einfaches Beispiel, an dem dies zu sehen ist, ist:

# Use construct_ds method from above to create a shuffled dataset
ds = construct_ds(..., shuffle=True)

# Take 1 batch (e.g. 32 images) of dataset: This returns a new dataset
ds_batch = ds.take(1)

# Predict labels for one batch
predictions = model.predict(ds_batch)

# Predict labels again: The result will not be the same as predictions above due to shuffling
predictions_2 = model.predict(ds_batch)

Eine Funktion zum Plotten von Bildern, die mit den entsprechenden Vorhersagen beschriftet sind, könnte wie folgt aussehen:

def show_batch_with_pred(model, ds, classes, rescale=True, size=(10, 10), title=None):
      for image, label in ds.take(1):
        image_array = image.numpy()
        label_array = label.numpy()
        batch_size = image_array.shape[0]
        pred = model.predict(image, proba=False)
        for idx in range(batch_size):
            label = classes[np.argmax(label_array[idx])]
            ax = plt.subplot(np.ceil(batch_size / 4), 4, idx + 1)
            if rescale:
                plt.imshow(image_array[idx] / 255)
            else:
                plt.imshow(image_array[idx])
            plt.title("label: " + label + "n" 
                      + "prediction: " + classes[pred[idx]], fontsize=10)
            plt.axis('off')

Die Methode show_batch_with_pred funktioniert auch für gemischte Datensätze, da image und label demselben Aufruf der Methode take entsprechen.

Die Auswertung der Model-Performance kann mit der Methode evaluate von keras.Model durchgeführt werden.

Wie akkurat ist unser finales Modell?

Das Modell erreicht eine kategoriale Genauigkeit von etwas über 70 % für die Vorhersage des Automodells für Bilder aus 300 Modellklassen. Um die Vorhersagen des Modells besser zu verstehen, ist es hilfreich, die Konfusionsmatrix zu betrachten. Unten ist die Heatmap der Vorhersagen des Modells für den Validierungsdatensatz abgebildet.

heatmap

Wir haben die Heatmap auf Einträge der Konfusionsmatrix in [0, 5] beschränkt, da das Zulassen einer weiteren Spanne keine Region außerhalb der Diagonalen signifikant hervorgehoben hat. Wie in der Heatmap zu sehen ist, wird eine Klasse den Beispielen fast aller Klassen zugeordnet. Das ist an der dunkelroten vertikalen Linie zwei Drittel rechts in der Abbildung oben zu erkennen.

Abgesehen von der zuvor erwähnten Klasse gibt es keine offensichtlichen Verzerrungen in den Vorhersagen. Wir möchten an dieser Stelle betonen, dass die Accuracy im Allgemeinen nicht ausreicht, um die Leistung eines Modells zufriedenstellend zu beurteilen, insbesondere im Fall unausgewogener Klassen.

Fazit und nächste Schritte

In diesem Blog-Beitrag haben wir Transfer Learning mit dem ResNet50V2 angewendet, um das Fahrzeugmodell anhand von Bildern von Autos zu klassifizieren. Unser Modell erreicht 70% kategoriale Genauigkeit über 300 Klassen.

Wir haben festgestellt, dass das Trainieren des gesamten Basismodells und die Verwendung einer kleinen Lernrate die besten Ergebnisse erzielen. Ein cooles Auto-Klassifikationsmodell entwickelt zu haben ist großartig, aber wie können wir unser Modell in einer produktiven Umgebung einsetzen? Natürlich könnten wir unsere eigene Modell-API mit Flask oder FastAPI bauen… Aber gibt es vielleicht sogar einen einfacheren, standardisierten Weg?

Im zweiten Beitrag unserer Blog-Serie, „Deployment von TensorFlow-Modellen in Docker mit TensorFlow Serving“ zeigen wir Euch, wie dieses Modell mit TensorFlow Serving bereitgestellt werden kann.