
Quick Start

This guide will get you up and running with Fast-LLM. Let's train a model and see some results!

Prerequisites

To follow this guide, you'll need:

  • Hardware: At least one NVIDIA GPU, preferably with Ampere architecture or newer. Note that this tutorial is designed for 80 GB A100 or H100 GPUs, and some adjustments are needed to run it with less memory or an earlier architecture (see the quick check after this list).
  • Software: Depending on your setup, you'll need one of the following:
    • Docker: If you're using the prebuilt Docker image on your local machine.
    • Python 3.12: If you're setting up a custom environment (virtual environment, bare-metal, etc.) on your local machine.
    • Cluster Setup: Access to a Docker-enabled Slurm cluster or to a Kubernetes cluster with Kubeflow if you're using those environments.
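If you're unsure what hardware you have, you can list your GPUs and their memory with the following command (assuming the NVIDIA driver is installed):

nvidia-smi --query-gpu=name,memory.total --format=csv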

🏗 Step 1: Initial Setup

First, create a working directory for this tutorial:

mkdir ./fast-llm-tutorial

We'll use this directory to store all the files and data needed for training.

Now, select the compute environment that matches your setup or preferred workflow. Once you select an environment, all sections of this guide will adapt to provide instructions specific to your choice:

Use a preconfigured Docker container with the Fast-LLM image, which includes all the required software and dependencies. Run the following command to pull the image and start a container:

docker run --gpus all -it --rm \
    -v $(pwd)/fast-llm-tutorial:/app/fast-llm-tutorial \
    ghcr.io/servicenow/fast-llm:latest \
    bash

Replace --gpus all with --gpus '"device=0,1,2,3,4,5,6,7"' etc. if you want to use specific GPUs.

Once inside the container, all commands from this guide can be executed as-is. The fast-llm-tutorial directory is mounted inside the container at /app/fast-llm-tutorial, so any files saved there will persist and be accessible on your host machine as well.
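As a quick sanity check that the container can actually see your GPUs, run the following inside it. You should see one entry per GPU passed in with --gpus:

nvidia-smi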

If you prefer not to use the prebuilt Docker image or already have an environment you'd like to use (e.g., a custom Docker image, virtual environment, or bare-metal setup), follow these steps to install the necessary software and dependencies:

  1. Ensure Python 3.12: Install Python 3.12 (or later) if it's not already available on your system. For a Python virtual environment, run:

    python3.12 -m venv ./fast-llm-tutorial/venv
    source ./fast-llm-tutorial/venv/bin/activate
    pip install --upgrade pip
    

    You can deactivate the virtual environment later with deactivate.

  2. Verify CUDA Installation: Make sure CUDA 12.1 or later is installed in your environment. Verify with:

    nvcc --version
    

    If CUDA is not installed or the version is incorrect, follow the CUDA installation guide to set it up.

  3. Pre-install PyTorch and pybind11: Install PyTorch and pybind11 to meet Fast-LLM's requirements:

    pip install pybind11 "torch>=2.2.2"
    
  4. Install NVIDIA APEX: Fast-LLM uses certain kernels from APEX. Follow the installation instructions on their GitHub page, ensuring you use the --cuda_ext and --fast_layer_norm options to install all kernels supported by Fast-LLM:

    git clone https://github.com/NVIDIA/apex ./fast-llm-tutorial/apex
    pushd ./fast-llm-tutorial/apex
    pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" --config-settings "--build-option=--fast_layer_norm" ./
    popd
    
  5. Install Fast-LLM and Dependencies: Finally, install Fast-LLM along with its remaining dependencies, including FlashAttention-2:

    pip install --no-build-isolation "git+https://github.com/ServiceNow/Fast-LLM.git#egg=fast_llm[CORE,OPTIONAL,DEV]"
    
  6. Verify the Installation: Confirm the setup with the following commands:

    python -c "import torch; print(torch.cuda.is_available())"
    python -c "from amp_C import *"
    python -c "import flash_attn; print(flash_attn.__version__)"
    python -c "import fast_llm; print(fast_llm.__version__)"
    

If you made it this far without any errors, your local environment is ready to run Fast-LLM.

Use Docker-enabled Slurm for this tutorial. The ghcr.io/servicenow/fast-llm:latest Docker image will be pulled and run on the compute nodes. Ensure the fast-llm-tutorial directory is accessible across all nodes (e.g., via a shared filesystem like NFS).

Use Kubernetes with Kubeflow and a PyTorchJob resource to train our model using the ghcr.io/servicenow/fast-llm:latest Docker image. We'll copy the configuration files and dataset to shared persistent volume claims (PVCs) to ensure all nodes have access to the same data. Follow these steps:

  1. Create a Persistent Volume Claim (PVC)

    Create a PVC named pvc-fast-llm-tutorial to store input data and output results:

    kubectl apply -f - <<EOF
    apiVersion: "v1"
    kind: "PersistentVolumeClaim"
    metadata:
      name: "pvc-fast-llm-tutorial"
    spec:
      storageClassName: local-path  
      accessModes:
        - ReadWriteMany
      resources:
        requests:
          storage: 100Gi  
    EOF
    

    StorageClassName

    Replace local-path with the appropriate StorageClassName for your Kubernetes cluster. Consult your cluster admin or documentation if unsure.
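    If you're unsure which storage classes your cluster offers, you can list them with:

    kubectl get storageclass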

  2. Set Up a Temporary Pod for Data Management

    Create a temporary pod to manage input data and results:

    kubectl apply -f - <<EOF
    apiVersion: v1
    kind: Pod
    metadata:
      name: pod-fast-llm-tutorial
    spec:
      containers:
        - name: fast-llm-tutorial-container
          image: ghcr.io/servicenow/fast-llm:latest
          command: ["sleep", "infinity"]
          volumeMounts:
            - mountPath: /app/fast-llm-tutorial
              name: fast-llm-tutorial
      volumes:
        - name: fast-llm-tutorial
          persistentVolumeClaim:
            claimName: pvc-fast-llm-tutorial
    EOF
    

    Purpose of the Temporary Pod

    This pod ensures you have an interactive container for managing input data and retrieving results. Use kubectl exec to interact with it:

    kubectl exec -it pod-fast-llm-tutorial -- bash
    

    Use kubectl cp to copy files between the pod and your local machine:

    kubectl cp ./fast-llm-tutorial pod-fast-llm-tutorial:/app
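    The same command works in the other direction, which is how you'll retrieve results later. For example, to copy the experiment directory (created during training in a later step) back to your local machine:

    kubectl cp pod-fast-llm-tutorial:/app/fast-llm-tutorial/experiment ./fast-llm-tutorial/experiment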
    

🤖 Step 2: Choose Your Training Configuration

This guide offers two training configurations:

Small: A quick, single-node setup with a smaller model that gives immediate results. Ideal for getting started and understanding the basics. It's the "hello world" of Fast-LLM.

Big: A more advanced setup with more data and a larger model to explore Fast-LLM's full capabilities. This configuration requires more resources and time to complete, but it prepares you for production-like workloads.

Choose based on your goals for this tutorial.

📥 Step 3: Download the Pretrained Model

For the small configuration, we'll use a SmolLM2 model configuration with 135M parameters, which is fast to train. Run the following commands to download the model configuration and tokenizer:

git lfs install
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/HuggingFaceTB/SmolLM2-135M ./fast-llm-tutorial/pretrained-model
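As an optional sanity check, confirm that the model configuration and tokenizer files were downloaded (the model weights are intentionally skipped here, since the small run starts from a randomly initialized model):

ls ./fast-llm-tutorial/pretrained-model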

For the big configuration, we'll use a Llama model with 8B parameters. We'll grab the model from the HuggingFace Hub and save it to the fast-llm-tutorial/pretrained-model folder.

Access Required

Meta gates access to its Llama models. You need to request access from Meta at https://huggingface.co/meta-llama/Llama-3.1-8B before you can download the model. You'll also need to authenticate with your HuggingFace account to download it:

pip install huggingface_hub
huggingface-cli login

When asked whether to use the token as a git credential, answer yes.

git lfs install
git clone https://huggingface.co/meta-llama/Llama-3.1-8B ./fast-llm-tutorial/pretrained-model

📚 Step 4: Prepare the Training Data

For this tutorial, we'll use text from the OpenWebText dataset. This dataset is a free approximation of the WebText data OpenAI used for GPT-2, and it's perfect for our test run!

Create a configuration file for the dataset preparation. Save one of the following as ./fast-llm-tutorial/prepare-config.yaml, depending on the configuration you chose in Step 2.

Small configuration:

output_path: fast-llm-tutorial/dataset

loading_workers: 16  
tokenize_workers: 16
saving_workers: 16

dataset:
  path: stas/openwebtext-10k  
  split: "train"
  trust_remote_code: true

tokenizer:
  path: fast-llm-tutorial/pretrained-model

splits:  
  training: 0.9
  validation: 0.1
Big configuration:

output_path: fast-llm-tutorial/dataset

loading_workers: 128  
tokenize_workers: 128
saving_workers: 128

dataset:
  path: openwebtext
  split: train
  trust_remote_code: true

tokenizer:
  path: fast-llm-tutorial/pretrained-model

splits:  
  training: 0.99
  validation: 0.01
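If you'd like to peek at the raw data before preparing it, you can print the first document of the small dataset (optional; this assumes the Hugging Face datasets library is available in your environment, which the prepare command uses under the hood):

python -c "from datasets import load_dataset; ds = load_dataset('stas/openwebtext-10k', split='train', trust_remote_code=True); print(len(ds)); print(ds[0]['text'][:200])"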

Fast-LLM ships with a prepare command that will download and preprocess the dataset for you.

In the Docker container or your local environment, run data preparation with the following command:

fast-llm prepare gpt_memmap --config fast-llm-tutorial/prepare-config.yaml

On the Slurm cluster, submit the data preparation job with the following command:

sbatch <<EOF
#!/bin/bash
#SBATCH --job-name=fast-llm-prepare
#SBATCH --nodes=4
#SBATCH --ntasks-per-node=1
#SBATCH --exclusive
#SBATCH --output=/app/fast-llm-tutorial/prepare-output.log
#SBATCH --error=/app/fast-llm-tutorial/prepare-error.log

MASTER_ADDR=\$(scontrol show hostnames \$SLURM_JOB_NODELIST | head -n 1)
MASTER_PORT=8001

export PYTHONHASHSEED=0

srun \
    --container-image="ghcr.io/servicenow/fast-llm:latest" \
    --container-mounts="$(pwd)/fast-llm-tutorial:/app/fast-llm-tutorial" \
    --container-env="PYTHONHASHSEED" \
    --ntasks-per-node=\$SLURM_NTASKS_PER_NODE \
    bash -c "
        torchrun --rdzv_backend=static \
                 --rdzv_id=0 \
                 --rdzv_endpoint=\${MASTER_ADDR}:\${MASTER_PORT} \
                 --node_rank=\\$SLURM_NODEID \
                 --nproc_per_node=\\$SLURM_NTASKS_PER_NODE \
                 --nnodes=\\$SLURM_NNODES:\\$SLURM_NNODES \
                 --max_restarts=0 \
                 --rdzv_conf=timeout=3600 \
                 --no_python \
                 fast-llm prepare gpt_memmap \
                 --config fast-llm-tutorial/prepare-config.yaml"
EOF

You can follow the job's progress by running squeue -u $USER and checking the logs in fast-llm-tutorial/prepare-output.log and fast-llm-tutorial/prepare-error.log.

On Kubernetes, copy the files to the shared PVC if they're not already there:

kubectl cp ./fast-llm-tutorial pod-fast-llm-tutorial:/app

Then, run data preparation with the following command:

kubectl apply -f - <<EOF
apiVersion: "kubeflow.org/v1"
kind: "PyTorchJob"
metadata:
  name: "fast-llm-prepare"
spec:
  nprocPerNode: "1"
  pytorchReplicaSpecs:
    Master:
      replicas: 1
      restartPolicy: Never
      template:
        spec:
          tolerations:
            - key: nvidia.com/gpu
              value: "true"
              operator: Equal
              effect: NoSchedule
          containers:
            - name: pytorch
              image: ghcr.io/servicenow/fast-llm:latest
              resources:
                limits:
                  memory: "1024Gi"
                requests:
                  memory: "1024Gi"
                  cpu: 128
              command:
                - /bin/bash
                - -c
                - |
                  torchrun --rdzv_backend=static \
                           --rdzv_id=0 \
                           --rdzv_endpoint=\${MASTER_ADDR}:\${MASTER_PORT} \
                           --node_rank=\${RANK} \
                           --nproc_per_node=\${PET_NPROC_PER_NODE} \
                           --nnodes=\${PET_NNODES}:\${PET_NNODES} \
                           --max_restarts=0 \
                           --rdzv_conf=timeout=3600 \
                           --no_python \
                           fast-llm prepare gpt_memmap \
                           --config fast-llm-tutorial/prepare-config.yaml
              env:
                - name: PYTHONHASHSEED
                  value: "0"
              securityContext:
                capabilities:
                  add:
                    - IPC_LOCK
              volumeMounts:
                - mountPath: /app/fast-llm-tutorial
                  name: fast-llm-tutorial
                - mountPath: /dev/shm
                  name: dshm
          volumes:
            - name: fast-llm-tutorial
              persistentVolumeClaim:
                claimName: pvc-fast-llm-tutorial
            - name: dshm
              emptyDir:
                medium: Memory
                sizeLimit: "1024Gi"
    Worker:
      replicas: 3
      restartPolicy: Never
      template:
        spec:
          tolerations:
            - key: nvidia.com/gpu
              value: "true"
              operator: Equal
              effect: NoSchedule
          containers:
            - name: pytorch
              image: ghcr.io/servicenow/fast-llm:latest
              resources:
                limits:
                  memory: "1024Gi"
                requests:
                  memory: "1024Gi"
                  cpu: 128
              command:
                - /bin/bash
                - -c
                - |
                  torchrun --rdzv_backend=static \
                           --rdzv_id=0 \
                           --rdzv_endpoint=\${MASTER_ADDR}:\${MASTER_PORT} \
                           --node_rank=\${RANK} \
                           --nproc_per_node=\${PET_NPROC_PER_NODE} \
                           --nnodes=\${PET_NNODES}:\${PET_NNODES} \
                           --max_restarts=0 \
                           --rdzv_conf=timeout=3600 \
                           --no_python \
                           fast-llm prepare gpt_memmap \
                           --config fast-llm-tutorial/prepare-config.yaml
              env:
                - name: PYTHONHASHSEED
                  value: "0"
              securityContext:
                capabilities:
                  add:
                    - IPC_LOCK
              volumeMounts:
                - mountPath: /app/fast-llm-tutorial
                  name: fast-llm-tutorial
                - mountPath: /dev/shm
                  name: dshm
          volumes:
            - name: fast-llm-tutorial
              persistentVolumeClaim:
                claimName: pvc-fast-llm-tutorial
            - name: dshm
              emptyDir:
                medium: Memory
                sizeLimit: "1024Gi"
EOF

You can follow the job's progress by running kubectl get pods and checking the logs with kubectl logs fast-llm-prepare-master-0.
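Whichever environment you used, once the job finishes you can check that the tokenized dataset was written to fast-llm-tutorial/dataset (inside a container or pod, that's /app/fast-llm-tutorial/dataset):

ls fast-llm-tutorial/dataset

Among other files, you should see fast_llm_config_training.yaml and fast_llm_config_validation.yaml, which the training configuration in the next step points to.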

⚙️ Step 5: Configure Fast-LLM

Next, we'll create a configuration file for Fast-LLM.

FlashAttention

Fast-LLM uses FlashAttention by default. If you're using Volta GPUs, you must disable FlashAttention by setting use_flash_attention: no under model.base_model.transformer in the configuration file below.

Micro-Batch Size

The micro_batch_size in the configuration below is optimized for 80GB GPUs. If you're using GPUs with less memory, you will need to lower this value. Alternatively, you can decrease the sequence_length to reduce the memory footprint.
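As a rough guide (assuming batch_size is the global batch across all GPUs and micro_batch_size is what each GPU processes per forward/backward pass), the batch settings below relate as follows:

batch_size = micro_batch_size x data-parallel GPUs x gradient accumulation steps
small: 480 = 60 x 8 x 1
big:   512 =  2 x 32 x 8

So if you halve micro_batch_size to fit a smaller GPU, you can keep batch_size unchanged and let the extra accumulation steps absorb the difference.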

Save one of the following as fast-llm-tutorial/train-config.yaml, depending on the configuration you chose in Step 2.

Small configuration:

training:
  train_iters: 100  
  logs:
    interval: 10
  validation:
    iterations: 25
    interval: 100
  export:  
    format: llama
    interval: 100
  wandb:  
    project_name: fast-llm-tutorial
    group_name: Small
    entity_name: null
batch:
  micro_batch_size: 60  
  sequence_length: 1024
  batch_size: 480  
data:
  datasets:
    Training:
      type: file
      path: fast-llm-tutorial/dataset/fast_llm_config_training.yaml  
    Validation:
      type: file
      path: fast-llm-tutorial/dataset/fast_llm_config_validation.yaml  
optimizer:
  learning_rate:
    base: 6.0e-04
pretrained:
  format: llama  
  path: fast-llm-tutorial/pretrained-model
  model_weights: no  
model:
  base_model:
    transformer:
      use_flash_attention: yes  
  distributed:
    training_dtype: bf16  
run:
  experiment_dir: fast-llm-tutorial/experiment
Big configuration:

training:
  train_iters: 100_000  
  logs:
    interval: 10
  validation:
    iterations: 25
    interval: 1000
  checkpoint:
    interval: 1000
    keep: 5
  test_iters: 0
  export:  
    format: llama
    interval: 20_000
  wandb:  
    project_name: fast-llm-tutorial
    group_name: Big
    entity_name: null
batch:
  micro_batch_size: 2  
  sequence_length: 4096
  batch_size: 512  
data:
  datasets:
    Training:
      type: file
      path: fast-llm-tutorial/dataset/fast_llm_config_training.yaml  
    Validation:
      type: file
      path: fast-llm-tutorial/dataset/fast_llm_config_validation.yaml  
optimizer:  
  weight_decay: 0.1
  beta_1: 0.9
  beta_2: 0.95
  learning_rate:  
    base: 6.0e-04
    minimum: 6.0e-05
    decay_style: cosine
    decay_iterations: 100_000
    warmup_iterations: 2000
pretrained:
  format: llama  
  path: fast-llm-tutorial/pretrained-model
  model_weights: yes  
model:
  base_model:
    transformer:
      use_flash_attention: yes  
    cross_entropy_impl: fused  
  multi_stage:
    zero_stage: 2  
  distributed:
    training_dtype: bf16  
run:
  experiment_dir: fast-llm-tutorial/experiment
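To get a sense of the scale difference between the two configurations, multiply iterations by batch size by sequence length:

small: 100 iterations x 480 sequences x 1,024 tokens ≈ 49M tokens
big:   100,000 iterations x 512 sequences x 4,096 tokens ≈ 210B tokens

This is why the big configuration calls for a multi-node cluster and considerably more time.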

🔑 (Optional) Step 6: Add Your Weights & Biases API Key

If you included the W&B section in your configuration, you'll need to add your API key. Save it to ./fast-llm-tutorial/.wandb_api_key and use the WANDB_API_KEY_PATH environment variable as shown in the training command.
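For example (replace the placeholder with the API key from your W&B account settings; the file should contain nothing but the key):

echo 'your-wandb-api-key' > ./fast-llm-tutorial/.wandb_api_key
chmod 600 ./fast-llm-tutorial/.wandb_api_key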

🚀 Step 7: Launch Training

Alright, the big moment! Let's launch the training run.

Python Hash Seed

The Python hash seed must be set to 0 to ensure consistent, reproducible ordering in hash-dependent operations across processes. Training will fail if this isn't set.

In the Docker container or your local environment, if you have 8 GPUs available, run the following to start training:

export PYTHONHASHSEED=0
# export WANDB_API_KEY_PATH=/app/fast-llm-tutorial/.wandb_api_key
torchrun --standalone --nnodes 1 --nproc_per_node=8 --no_python \
    fast-llm train gpt --config fast-llm-tutorial/train-config.yaml

On the Slurm cluster, if you have 4 nodes with 8 GPUs each, submit the training job with:

sbatch <<EOF
#!/bin/bash
#SBATCH --job-name=fast-llm-train
#SBATCH --nodes=4
#SBATCH --gpus-per-node=8
#SBATCH --ntasks-per-node=1
#SBATCH --exclusive
#SBATCH --output=/app/fast-llm-tutorial/train-output.log
#SBATCH --error=/app/fast-llm-tutorial/train-error.log

export PYTHONHASHSEED=0
export WANDB_API_KEY_PATH=/app/fast-llm-tutorial/.wandb_api_key
export TORCH_NCCL_ASYNC_ERROR_HANDLING=1
export NCCL_DEBUG=INFO

srun \
    --container-image="ghcr.io/servicenow/fast-llm:latest" \
    --container-mounts="$(pwd)/fast-llm-tutorial:/app/fast-llm-tutorial" \
    --container-env="PYTHONHASHSEED,WANDB_API_KEY_PATH,TORCH_NCCL_ASYNC_ERROR_HANDLING,NCCL_DEBUG" \
    --gpus-per-node=\$SLURM_GPUS_PER_NODE \
    --ntasks-per-node=\$SLURM_NTASKS_PER_NODE \
    bash -c "
        torchrun --rdzv_backend=static \
                 --rdzv_id=0 \
                 --rdzv_endpoint=\${MASTER_ADDR}:\${MASTER_PORT} \
                 --node_rank=\\$SLURM_NODEID \
                 --nproc_per_node=\\$SLURM_GPUS_PER_NODE \
                 --nnodes=\\$SLURM_NNODES \
                 --max_restarts=0 \
                 --rdzv_conf=timeout=3600 \
                 --no_python \
                 fast-llm train gpt \
                 --config fast-llm-tutorial/train-config.yaml"
EOF

On Kubernetes, copy the configuration file to the shared PVC:

kubectl cp ./fast-llm-tutorial/train-config.yaml pod-fast-llm-tutorial:/app/fast-llm-tutorial

Then, if you have 4 nodes with 8 GPUs each, launch the PyTorchJob to start training:

kubectl apply -f - <<EOF
apiVersion: "kubeflow.org/v1"
kind: "PyTorchJob"
metadata:
  name: "fast-llm-train"
spec:
  nprocPerNode: "8"
  pytorchReplicaSpecs:
    Master:
      replicas: 1
      restartPolicy: Never
      template:
        spec:
          tolerations:
            - key: nvidia.com/gpu
              value: "true"
              operator: Equal
              effect: NoSchedule
          containers:
            - name: pytorch
              image: ghcr.io/servicenow/fast-llm:latest
              resources:
                limits:
                  nvidia.com/gpu: 8
                  rdma/rdma_shared_device_a: 1
                  memory: "1024Gi"
                requests:
                  nvidia.com/gpu: 8
                  rdma/rdma_shared_device_a: 1
                  memory: "1024Gi"
                  cpu: 128
              command:
                - /bin/bash
                - -c
                - |
                  torchrun --rdzv_backend=static \
                           --rdzv_endpoint=\${MASTER_ADDR}:\${MASTER_PORT} \
                           --node_rank=\${RANK} \
                           --nproc_per_node=\${PET_NPROC_PER_NODE} \
                           --nnodes=\${PET_NNODES} \
                           --max_restarts=0 \
                           --rdzv_conf=timeout=3600 \
                           --no_python \
                           fast-llm train gpt \
                           --config fast-llm-tutorial/train-config.yaml
              env:
                - name: PYTHONHASHSEED
                  value: "0"
                - name: WANDB_API_KEY_PATH
                  value: "/app/fast-llm-tutorial/.wandb_api_key"
                - name: TORCH_NCCL_ASYNC_ERROR_HANDLING
                  value: "1"
                - name: NCCL_DEBUG
                  value: "INFO"
              securityContext:
                capabilities:
                  add:
                    - IPC_LOCK
              volumeMounts:
                - mountPath: /app/fast-llm-tutorial
                  name: fast-llm-tutorial
                - mountPath: /dev/shm
                  name: dshm
          volumes:
            - name: fast-llm-tutorial
              persistentVolumeClaim:
                claimName: pvc-fast-llm-tutorial
            - name: dshm
              emptyDir:
                medium: Memory
                sizeLimit: "1024Gi"
    Worker:
      replicas: 3
      restartPolicy: Never
      template:
        spec:
          tolerations:
            - key: nvidia.com/gpu
              value: "true"
              operator: Equal
              effect: NoSchedule
          containers:
            - name: pytorch
              image: ghcr.io/servicenow/fast-llm:latest
              resources:
                limits:
                  nvidia.com/gpu: 8
                  rdma/rdma_shared_device_a: 1
                  memory: "1024Gi"
                requests:
                  nvidia.com/gpu: 8
                  rdma/rdma_shared_device_a: 1
                  memory: "1024Gi"
                  cpu: 128
              command:
                - /bin/bash
                - -c
                - |
                  torchrun --rdzv_backend=static \
                           --rdzv_endpoint=\${MASTER_ADDR}:\${MASTER_PORT} \
                           --node_rank=\${RANK} \
                           --nproc_per_node=\${PET_NPROC_PER_NODE} \
                           --nnodes=\${PET_NNODES} \
                           --max_restarts=0 \
                           --rdzv_conf=timeout=3600 \
                           --no_python \
                           fast-llm train gpt \
                           --config fast-llm-tutorial/train-config.yaml
              env:
                - name: PYTHONHASHSEED
                  value: "0"
                - name: WANDB_API_KEY_PATH
                  value: "/app/fast-llm-tutorial/.wandb_api_key"
                - name: TORCH_NCCL_ASYNC_ERROR_HANDLING
                  value: "1"
                - name: NCCL_DEBUG
                  value: "INFO"
              securityContext:
                capabilities:
                  add:
                    - IPC_LOCK
              volumeMounts:
                - mountPath: /app/fast-llm-tutorial
                  name: fast-llm-tutorial
                - mountPath: /dev/shm
                  name: dshm
          volumes:
            - name: fast-llm-tutorial
              persistentVolumeClaim:
                claimName: pvc-fast-llm-tutorial
            - name: dshm
              emptyDir:
                medium: Memory
                sizeLimit: "1024Gi"
EOF

📊 Step 8: Track Training Progress

In the Docker container or your local environment, Fast-LLM will log training progress to the console every 10 iterations.

You can cancel training at any time by pressing Ctrl+C in the terminal.

On the Slurm cluster, use squeue -u $USER to see the job status. Follow train-output.log and train-error.log in your working directory for logs. Fast-LLM will log training progress to those files every 10 iterations.

You can cancel training by running scancel <job_id>.

On Kubernetes, use kubectl get pods to see the job status and kubectl logs fast-llm-train-master-0 to check the logs. Fast-LLM will log training progress to the console every 10 iterations.

You can cancel training by deleting the PyTorchJob:

kubectl delete pytorchjob fast-llm-train

Cleaning Up Resources

Delete the data management pod and PVC if you're finished with the tutorial:

kubectl delete pod pod-fast-llm-tutorial
kubectl delete pvc pvc-fast-llm-tutorial

This will shut down the temporary pod and remove the PVC with all its contents.

You can expect to see the following performance metrics in Fast-LLM's output:

| Performance Metric             | 8x V100-SXM2-32GB [1] | 8x A100-SXM4-80GB [2] | 8x H100-SXM5-80GB [3] |
|--------------------------------|-----------------------|-----------------------|-----------------------|
| tokens/s/GPU                   | 16,700                | 149,000               | 294,000               |
| tflop/s (model)                | 15.3                  | 137                   | 268                   |
| peak tflop/s (theoretical) [7] | 125                   | 312                   | 990                   |
| utilization                    | 12.2%                 | 44%                   | 27%                   |
| total training time            | 68 minutes            |                       | 3.9 minutes           |

| Performance Metric             | 32x V100-SXM2-32GB [4] | 32x A100-SXM4-80GB [5] | 32x H100-SXM5-80GB [6] |
|--------------------------------|------------------------|------------------------|------------------------|
| tokens/s/GPU                   |                        |                        | 10,100                 |
| tflop/s (model)                |                        |                        | 487                    |
| peak tflop/s (theoretical) [7] | 125                    | 312                    | 990                    |
| utilization                    |                        |                        | 49.2%                  |
| total training time            |                        |                        | 180 hours              |

If you included the W&B section in your configuration, you can also track your training progress on the Weights & Biases dashboard. Follow the link in the console output to view your training run.

🎉 Final Thoughts

And that's it! You've set up your environment, prepared the data, chosen a model, configured training, and launched a full training run with Fast-LLM. You can try out the saved model directly with Transformers.

Small configuration:

import transformers

model = transformers.AutoModelForCausalLM.from_pretrained("/app/fast-llm-tutorial/experiment/export/llama/100").cuda()
tokenizer = transformers.AutoTokenizer.from_pretrained("fast-llm-tutorial/pretrained-model/")

inputs = {k: v.cuda() for k, v in tokenizer("This is what the small model can do after fine-tuning for 100 steps:", return_tensors="pt").items()}
outputs = model.generate(**inputs, max_new_tokens=100)

print(tokenizer.decode(outputs[0]))
Big configuration:

import transformers

model = transformers.AutoModelForCausalLM.from_pretrained("/app/fast-llm-tutorial/experiment/export/llama/100000").cuda()
tokenizer = transformers.AutoTokenizer.from_pretrained("fast-llm-tutorial/pretrained-model/")

inputs = {k: v.cuda() for k, v in tokenizer("This is what the big model can do after fine-tuning for 100K steps:", return_tensors="pt").items()}
outputs = model.generate(**inputs, max_new_tokens=100)

print(tokenizer.decode(outputs[0]))

From here, feel free to tweak the model, try out larger datasets, or scale things up to larger clusters. The sky's the limit!

Happy training!


  1. Precision was set to fp16, since bf16 is not supported on V100 GPUs. FlashAttention was disabled, as it is not supported on V100 GPUs. Micro-batch size was set to 12.

  2. Precision was set to bf16. FlashAttention was enabled. Micro-batch size was set to 60.

  3. Precision was set to bf16. FlashAttention was enabled. Micro-batch size was set to 60.

  4. Precision was set to fp16, since bf16 is not supported on V100 GPUs. FlashAttention was disabled, as it is not supported on V100 GPUs. Micro-batch size was set to 4.

  5. Precision was set to bf16. FlashAttention was enabled. Micro-batch size was set to 1. ZeRO stage 2 was used.

  6. Precision was set to bf16. FlashAttention was enabled. Micro-batch size was set to 2. ZeRO stage 2 was used.

  7. Theoretical peak performance of the GPU for dense tensors in fp16 or bf16 precision, depending on the GPU architecture. Source: Wikipedia.