Journal RPCDataloader: chargement et pré-traitement de données distribué pour l'IA

11
22
avr.
2023

Sommaire

Introduction

On continue la lignée des utilitaires pour une grappe de calcul IA (cluster en bon franglais).
Cette fois-ci, l'objectif est de déplacer le chargement et le pré-traitement des données sur des serveurs différents de ceux qui gèrent le modèle en lui-même.

On s'intéresse plus précisément à un algorithme d’entraînement de réseau de neurones avec le framework PyTorch. Bien que la librairie que je présente (RPCDataloader) n'ait pas une très forte dépendance à PyTorch, son API est conçue pour s'intégrer avec.

Contexte : Entraînement de réseaux de neurones avec PyTorch

Algorithmiquement, l’entraînement d'un réseau de neurones suit les étapes suivantes en boucle :

  1. On prépare un petit ensemble d'exemples (textes, images, sons, etc.) tiré d'un grand corpus.
  2. On applique un modèle sur ces données.
  3. On applique une métrique entre la prédiction et la vérité terrain attendue.
  4. On calcule la dérivée de cette métrique par rapport aux paramètres du modèle.
  5. On fait un pas de gradient sur les paramètres dans le sens qui maximise la métrique (ou la minimise selon l'objectif).

La première étape est réalisée sur CPU et comprend des pré-traitements potentiellement lourds. Le reste est fait sur GPU car les opérations à effectuer s'y prêtent bien.

Par "sur GPU", on veut en fait dire qu'un processus (en l'occurrence un script python) suit les étapes algorithmiques, les décompose en calculs élémentaires, puis soumet ces calculs au GPU (via une file d'attente). Pour maximiser le taux d'utilisation du GPU, il faut donc que la partie logique (le script) reste en avance sur les calculs.

Le script est écrit par le chercheur, mais PyTorch fournit une énorme bibliothèque d'algorithmes allant de la simple addition à des opérations composées très haut niveau, par exemple le calcul de la dérivée en un point de la sortie d'un modèle par rapport à ses paramètres.

PyTorch fournit aussi des outils pour le chargement et la préparation des données. Comme évoqué précédemment, ces calculs ne sont pas déportés sur un GPU mais dans des processus séparés. Le script principal là encore le travail (i.e. les indices des éléments à préparer) via une file d'attente.

Le schéma ci-dessous donne une vue simplifiée de l'ensemble.

Fonctionnement général de PyTorch

Le pitch

Le mode de fonctionnement présenté ci-dessus fonctionne plutôt pas mal dans la mesure où la puissance de calcul du GPU est bien souvent le facteur limitant. Le script principal, bien qu'il soit en python, parvient à garder de l'avance sur le GPU.

Là où ça coince, c'est plutôt au niveau des prétraitements. Comment faire si toute la puissance des CPUs n'est pas suffisante pour alimenter le reste de la chaîne en continu ?
On gâche alors de la ressource GPU et on occupe un serveur plus longtemps. C'est aussi mauvais pour le taux d'utilisation d'un cluster multi-utilisateur. Une expérience qui réserve 75% des CPUs mais seulement 50% des GPUs bloquera souvent un nœud de calcul entier, car les expériences avec des besoins complémentaires sont rares en pratique.

Vous allez me dire qu'on aurait dû acheter des serveurs avec plus de CPUs ou moins de GPUs. Je vous répondrai que c'est un exercice difficile, prévoir de la marge coûte très cher et on dimensionne pour le cas général, pas l'exception.

L'idée pour résoudre ce problème, c'est tout simplement de déplacer les processus de chargement et de préparation des données sur un autre nœud de calcul.

Initialisation

Dans PyTorch, les processus pour les données sont obtenus en forkant le processus principal. C'est ainsi qu'ils récupèrent un objet décrivant le dataset et qui contient généralement la liste des fichiers ainsi que les données qui peuvent rentrer en mémoire. C'est simple, mais peu efficace en utilisation de RAM car les processus enfants gardent les autres variables du parent, tandis que le script principal garde inutilement les données relatives au dataset.

Dans RPCDataLoader, on démarre des petits serveurs TCP qui attendent des commandes à exécuter. C'est de l'exécution de procédure distante (RPC) très simplifiée. Le script principal leur ordonne initialement d'instancier un objet pour le dataset, puis leur ordonne de charger et de traiter tel ou tel indice.

Communication

À l'origine, les processus données reçoivent leurs instructions du script principal via un Pipe.
Comme le résultat renvoyé dans l'autre sens est plutôt lourd, seule une description de la réponse est renvoyée via un Pipe, tandis que les buffers de données sont placés en mémoire partagée pour économiser une copie… Mais une copie des buffers sera quand même nécessaire, car il faut placer ces données en mémoire verrouillée (pinned memory). En gros ça garantit au GPU que Linux ne va pas déplacer ces plages de la mémoire, ce qui permet de faire des transferts plus rapides et asynchrones.

Dans RPCDataLoader, les pipes sont remplacés par des sockets TCP et l'optimisation de la copie grâce à la mémoire partagée est perdue… ou presque !
En effet, le module de sérialisation de données de python, appelé pickle, propose une fonctionnalité très intéressante bien qu'un peu obscure : on peut sérialiser et désérialiser les buffers de données séparément de l'objet qui les encapsule (arguments buffer_callback et buffers). Dans RPCDataloader j'en profite donc pour allouer les buffers directement en pinned memory avant de les réceptionner, ce qui économise une copie.

Gestion des variables distantes

Parmi les commandes que le script envoie au processus distants, il y a la création de variables. Cela pose la question de savoir quand elles devront être détruites. Le script distant ne pouvant pas prendre cette initiative, j'alloue pour chaque variable distante une référence locale qui sert de témoin. Avec la fonction weakref.finalize de python, on peut faire en sorte d'être notifié à la destruction du témoin local et d'ordonner simultanément la suppression de la variable distante.
C'est l'API RPC PyTorch qui m'a inspiré cette manière de faire.

En surchargeant la procédure de sérialisation/dé-sérialisation de cet objet, on peut aussi obtenir un comportement assez élégant lors de l'envoi d'une commande à appliquer sur l'objet distant :

rpc_async(
    host="node1:8888",
    func=somefunc,
    args=[1, 2, my_rref])

Lorsque le processus distant réceptionne la commande, my_rref devient l'objet référencé durant la dé-sérialisation. Ça fonctionne même si la référence est encapsulée à l'intérieur d'une liste ou d'un objet.

Le code est ici si jamais mes explications ne sont pas claires.

Conclusion

La librairie fonctionne et les performances sont plutôt satisfaisantes. Je craignais que le passage par le réseau et par pickle entraînerait de la latence et un surcoût CPU, mais il n'en est rien. Il a fallu par contre prendre garde à quelques détails d'implémentation sur la partie réseaux pour ne pas voir s'effondrer les performances. Et il faut bien entendu un réseau très rapide.

Sur l'exemple en lien ci-dessous, on peut occuper 8 NVidia A100 à plus de 80% en utilisant l'équivalent de 10 CPUs sur le nœud principal. La configuration typique pour accompagner ces 8 GPUs prévoit 128 cœurs, précisément pour absorber la charge des pré-traitements de données localement.

À l'avenir, j'aimerais ajouter une forme d'authentification, car n'importe qui peut envoyer une commande aux workers actuellement, ce qui restreint l'utilisation à un environnement de confiance.

Pour aller plus loin :

  • # merci

    Posté par  . Évalué à 1.

    Merci du partage et pour les explications !

  • # Merci

    Posté par  . Évalué à 1.

    Pour la découverte, ici on utilise knative et Kafka pour faire ça. J'avoue ça juste marche parfaitement. Mais je vais quand même regarder ça. Merci.

  • # Pourquoi RPC?

    Posté par  (site web personnel) . Évalué à 3.

    Merci pour cette présentation très intéressante qui m’a fait découvrir les technologies distribuées gérées par pytorch.

    Si j’ai bien compris, RPC est utilisé pour faire discuter entre eux les différents nœuds de calcul : ceux avec les GPUs et ceux avec les CPUs. Pourquoi avoir choisi cette technologie ? Je me demande notamment si le protocole MPI a été envisagé sachant que son implémentation sur les grappes de calcul est souvent très optimisée bas niveau par le fabricant du supercalculateur. À moins que le serveur avec les nœuds CPUs ne soit en dehors de la grappe de calcul ?

    • [^] # Re: Pourquoi RPC?

      Posté par  . Évalué à 1.

      J'y ai pensé pour une v2 au moment de la mise au point. Là le TCP est sur une couche IPoIB (lien 100GB/s) qui s'avère largement suffisante, donc j'ai laissé ça de côté. Et ça fait une dépendance en moins.

Suivre le flux des commentaires

Note : les commentaires appartiennent à celles et ceux qui les ont postés. Nous n’en sommes pas responsables.