Quick Start
This guide will get you up and running with Fast-LLM. Let's train a model and see some results!
Prerequisites¶
To follow this guide, you'll need:
- Hardware: At least one NVIDIA GPU, preferably with Ampere architecture or newer. Note that this tutorial is designed for 80 GB A100s or H100 GPUs, and some adjustments are needed to run it with less memory or an earlier architecture.
- Software: Depending on your setup, you'll need one of the following:
- Docker: If you're using the prebuilt Docker image on your local machine.
- Python 3.10: If you're setting up a custom environment (virtual environment, bare-metal, etc.) on your local machine.
- Cluster Setup: Access to a Docker-enabled Slurm cluster or to a Kubernetes cluster with Kubeflow if you're using those environments.
🏗 Step 1: Initial Setup¶
First, create a working directory for this tutorial:
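For example (this is the directory name the rest of this guide assumes):

mkdir ./fast-llm-tutorial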
We'll use this directory to store all the files and data needed for training.
Now, select the compute environment that matches your setup or preferred workflow. Once you select an environment, all sections of this guide will adapt to provide instructions specific to your choice:
Use a preconfigured Docker container with the Fast-LLM image, which includes all the required software and dependencies. Run the following command to pull the image and start a container:
docker run --gpus all -it --rm \
-v $(pwd)/fast-llm-tutorial:/app/fast-llm-tutorial \
ghcr.io/servicenow/fast-llm:latest \
bash
Replace --gpus all with --gpus '"device=0,1,2,3,4,5,6,7"' etc. if you want to use specific GPUs.
Once inside the container, all commands from this guide can be executed as-is. The fast-llm-tutorial directory is mounted inside the container at /app/fast-llm-tutorial, so any files saved there will persist and be accessible on your host machine as well.
If you prefer not to use the prebuilt Docker image or already have an environment you'd like to use (e.g., a custom Docker image, virtual environment, or bare-metal setup), follow these steps to install the necessary software and dependencies:
- Ensure Python 3.10: Install Python 3.10 (or later) if it's not already available on your system. For a Python virtual environment, run:

python3.10 -m venv ./fast-llm-tutorial/venv
source ./fast-llm-tutorial/venv/bin/activate
pip install --upgrade pip

You can deactivate the virtual environment later with deactivate.
- Verify CUDA Installation: Make sure CUDA 12.1 or later is installed in your environment; verify with the command shown in the sketch after this list. If CUDA is not installed or the version is incorrect, follow the CUDA installation guide to set it up.
- Pre-install PyTorch and pybind11: Install PyTorch and pybind11 to meet Fast-LLM's requirements (see the sketch after this list).
- Install NVIDIA APEX: Fast-LLM uses certain kernels from APEX. Follow the installation instructions on their GitHub page, ensuring you use the --cuda_ext and --fast_layer_norm options to install all kernels supported by Fast-LLM:

git clone https://github.com/NVIDIA/apex ./fast-llm-tutorial/apex
pushd ./fast-llm-tutorial/apex
pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" --config-settings "--build-option=--fast_layer_norm" ./
popd
- Install Fast-LLM and Dependencies: Finally, install Fast-LLM along with its remaining dependencies, including FlashAttention-2 (see the sketch after this list).
- Verify the Installation: Confirm the setup with the commands shown in the sketch after this list.
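The exact commands for the CUDA check, the PyTorch/pybind11 pre-install, the Fast-LLM install, and the verification depend on your system and on the current Fast-LLM packaging, so treat the following as a minimal sketch rather than the definitive procedure (the pip source for Fast-LLM is an assumption; check the project README):

nvcc --version                                   # confirm CUDA 12.1 or later is available
pip install torch pybind11                       # pre-install PyTorch and pybind11
pip install --no-build-isolation "fast-llm @ git+https://github.com/ServiceNow/Fast-LLM.git"   # install Fast-LLM and its remaining dependencies (assumed source)
python -c "import torch; print(torch.cuda.is_available())"   # verify that GPUs are visible to PyTorch
python -c "import fast_llm" && fast-llm --help                # verify that Fast-LLM imports and its CLI responds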
If you made it this far without any errors, your local environment is ready to run Fast-LLM.
Use Docker-enabled Slurm for this tutorial. The ghcr.io/servicenow/fast-llm:latest Docker image will be pulled and run on the compute nodes. Ensure the fast-llm-tutorial directory is accessible across all nodes (e.g., via a shared filesystem like NFS).
Use Kubernetes with Kubeflow and a PyTorchJob resource to train our model using the ghcr.io/servicenow/fast-llm:latest Docker image. We'll copy the configuration files and dataset to a shared persistent volume claim (PVC) to ensure all nodes have access to the same data. Follow these steps:
- Create a Persistent Volume Claim (PVC)

Create a PVC named pvc-fast-llm-tutorial to store input data and output results:

kubectl apply -f - <<EOF
apiVersion: "v1"
kind: "PersistentVolumeClaim"
metadata:
  name: "pvc-fast-llm-tutorial"
spec:
  storageClassName: local-path # (1)!
  accessModes:
    - ReadWriteMany
  resources:
    requests:
      storage: 100Gi # (2)!
EOF
- Replace with your cluster's StorageClassName.
- Adjust the storage size as needed.
StorageClassName

Replace local-path with the appropriate StorageClassName for your Kubernetes cluster. Consult your cluster admin or documentation if unsure.
- Set Up a Temporary Pod for Data Management

Create a temporary pod to manage input data and results:

kubectl apply -f - <<EOF
apiVersion: v1
kind: Pod
metadata:
  name: pod-fast-llm-tutorial
spec:
  containers:
    - name: fast-llm-tutorial-container
      image: ghcr.io/servicenow/fast-llm:latest
      command: ["sleep", "infinity"]
      volumeMounts:
        - mountPath: /app/fast-llm-tutorial
          name: fast-llm-tutorial
  volumes:
    - name: fast-llm-tutorial
      persistentVolumeClaim:
        claimName: pvc-fast-llm-tutorial
EOF
Purpose of the Temporary Pod

This pod gives you an interactive container for managing input data and retrieving results. Use kubectl exec to interact with it and kubectl cp to copy files between the pod and your local machine, as shown below.
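For example (a sketch; the paths assume the mount point and file names used elsewhere in this guide):

kubectl exec -it pod-fast-llm-tutorial -- bash
kubectl cp ./fast-llm-tutorial/prepare-config.yaml pod-fast-llm-tutorial:/app/fast-llm-tutorial/prepare-config.yaml
kubectl cp pod-fast-llm-tutorial:/app/fast-llm-tutorial/experiment ./fast-llm-tutorial/experiment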
🤖 Step 2: Choose Your Training Configuration¶
This guide offers two training configurations:
Small: A quick, single-node setup with a smaller model and immediate results, for testing Fast-LLM. Ideal for getting started and understanding the basics. It's the "hello world" of Fast-LLM.
Big: A more advanced setup with more data and larger models, for exploring Fast-LLM's full capabilities. This configuration requires more resources and time to complete, but it prepares you for production-like workloads.
Choose based on your goals for this tutorial.
📥 Step 3: Download the Pretrained Model¶
For the small configuration, we'll use a SmolLM2 model configuration with 135M parameters, which is fast to train. Run the following commands to download the model configuration and tokenizer:
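The exact download method isn't prescribed here; a sketch using the Hugging Face CLI, assuming the HuggingFaceTB/SmolLM2-135M repository and the pretrained-model directory referenced later in this guide:

huggingface-cli download HuggingFaceTB/SmolLM2-135M --local-dir ./fast-llm-tutorial/pretrained-model

This pulls the full repository (configuration, tokenizer, and weights); only the configuration and tokenizer are needed for the from-scratch small run.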
For the big configuration, we'll use a Llama model with 8B parameters. We'll grab the model from the Hugging Face Hub and save it to the fast-llm-tutorial/pretrained-model folder.
Access Required
Meta gates access to their Llama models. You need to request access to the model from Meta at https://huggingface.co/meta-llama/Llama-3.1-8B before you can download it. You'll also need to authenticate with your Hugging Face account to download the model:
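For example, with the Hugging Face CLI (you'll be prompted for an access token):

huggingface-cli login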
When asked whether to use the token as git credentials, answer yes.
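Then download the model into the directory used throughout this guide (a sketch; any method that places the files in fast-llm-tutorial/pretrained-model works):

huggingface-cli download meta-llama/Llama-3.1-8B --local-dir ./fast-llm-tutorial/pretrained-model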
📚 Step 4: Prepare the Training Data¶
For this tutorial, we'll use text from the OpenWebText dataset. This dataset is a free approximation of the WebText data OpenAI used for GPT-2, and it's perfect for our test run!
Create a configuration file for the dataset preparation. Copy the following content:
output_path: fast-llm-tutorial/dataset
loading_workers: 16 # (1)!
tokenize_workers: 16
saving_workers: 16
dataset:
  path: stas/openwebtext-10k # (2)!
  split: "train"
  trust_remote_code: true
tokenizer:
  path: fast-llm-tutorial/pretrained-model
- Processing speed scales linearly with the number of CPUs.
- This small dataset is restricted to the first 10K records of the OpenWebText dataset to speed up the process. If you want to use the full dataset, replace it with openwebtext.
Save it as ./fast-llm-tutorial/prepare-config.yaml.
Fast-LLM ships with a prepare command that will download and preprocess the dataset for you.
If you're using the Docker container or a custom local environment, run data preparation with the following command:
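This is the same invocation the distributed job scripts below wrap in torchrun, run directly on a single node (a sketch):

fast-llm prepare gpt_memmap --config fast-llm-tutorial/prepare-config.yaml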
On Slurm, run data preparation with the following command:
sbatch <<EOF
#!/bin/bash
#SBATCH --job-name=fast-llm-prepare
#SBATCH --nodes=4
#SBATCH --ntasks-per-node=1
#SBATCH --exclusive
#SBATCH --output=/app/fast-llm-tutorial/prepare-output.log
#SBATCH --error=/app/fast-llm-tutorial/prepare-error.log
MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)
MASTER_PORT=8001
export PYTHONHASHSEED=0
srun \
--container-image="ghcr.io/servicenow/fast-llm:latest" \
--container-mounts="$(pwd)/fast-llm-tutorial:/app/fast-llm-tutorial" \
--container-env="PYTHONHASHSEED" \
--ntasks-per-node=$SLURM_NTASKS_PER_NODE \
bash -c "
torchrun --rdzv_backend=static \
--rdzv_id=0 \
--rdzv_endpoint=\${MASTER_ADDR}:\${MASTER_PORT} \
--node_rank=\\$SLURM_NODEID \
--nproc_per_node=\\$SLURM_NTASKS_PER_NODE \
--nnodes=\\$SLURM_NNODES:\\$SLURM_NNODES \
--max_restarts=0 \
--rdzv_conf=timeout=3600 \
--no_python \
fast-llm prepare gpt_memmap \
--config fast-llm-tutorial/prepare-config.yaml"
EOF
You can follow the job's progress by running squeue -u $USER and checking the logs in fast-llm-tutorial/prepare-output.log and fast-llm-tutorial/prepare-error.log.
Copy the files to the shared PVC if they're not already there:
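A sketch using the temporary pod created in Step 1 (the prepare config and the pretrained model are the inputs the job needs on the PVC):

kubectl cp ./fast-llm-tutorial/prepare-config.yaml pod-fast-llm-tutorial:/app/fast-llm-tutorial/prepare-config.yaml
kubectl cp ./fast-llm-tutorial/pretrained-model pod-fast-llm-tutorial:/app/fast-llm-tutorial/pretrained-model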
Then, run data preparation with the following command:
kubectl apply -f - <<EOF
apiVersion: "kubeflow.org/v1"
kind: "PyTorchJob"
metadata:
  name: "fast-llm-prepare"
spec:
  nprocPerNode: "1"
  pytorchReplicaSpecs:
    Master:
      replicas: 1
      restartPolicy: Never
      template:
        spec:
          tolerations:
            - key: nvidia.com/gpu
              value: "true"
              operator: Equal
              effect: NoSchedule
          containers:
            - name: pytorch
              image: ghcr.io/servicenow/fast-llm:latest
              resources:
                limits:
                  memory: "1024Gi"
                  cpu:
                requests:
                  memory: "1024Gi"
                  cpu: 128
              command:
                - /bin/bash
                - -c
                - |
                  torchrun --rdzv_backend=static \
                           --rdzv_id=0 \
                           --rdzv_endpoint=\${MASTER_ADDR}:\${MASTER_PORT} \
                           --node_rank=\${RANK} \
                           --nproc_per_node=\${PET_NPROC_PER_NODE} \
                           --nnodes=\${PET_NNODES}:\${PET_NNODES} \
                           --max_restarts=0 \
                           --rdzv_conf=timeout=3600 \
                           --no_python \
                           fast-llm prepare gpt_memmap \
                           --config fast-llm-tutorial/prepare-config.yaml
              env:
                - name: PYTHONHASHSEED
                  value: "0"
              securityContext:
                capabilities:
                  add:
                    - IPC_LOCK
              volumeMounts:
                - mountPath: /app/fast-llm-tutorial
                  name: fast-llm-tutorial
                - mountPath: /dev/shm
                  name: dshm
          volumes:
            - name: fast-llm-tutorial
              persistentVolumeClaim:
                claimName: pvc-fast-llm-tutorial
            - name: dshm
              emptyDir:
                medium: Memory
                sizeLimit: "1024Gi"
    Worker:
      replicas: 3
      restartPolicy: Never
      template:
        spec:
          tolerations:
            - key: nvidia.com/gpu
              value: "true"
              operator: Equal
              effect: NoSchedule
          containers:
            - name: pytorch
              image: ghcr.io/servicenow/fast-llm:latest
              resources:
                limits:
                  memory: "1024Gi"
                  cpu:
                requests:
                  memory: "1024Gi"
                  cpu: 128
              command:
                - /bin/bash
                - -c
                - |
                  torchrun --rdzv_backend=static \
                           --rdzv_id=0 \
                           --rdzv_endpoint=\${MASTER_ADDR}:\${MASTER_PORT} \
                           --node_rank=\${RANK} \
                           --nproc_per_node=\${PET_NPROC_PER_NODE} \
                           --nnodes=\${PET_NNODES}:\${PET_NNODES} \
                           --max_restarts=0 \
                           --rdzv_conf=timeout=3600 \
                           --no_python \
                           fast-llm prepare gpt_memmap \
                           --config fast-llm-tutorial/prepare-config.yaml
              env:
                - name: PYTHONHASHSEED
                  value: "0"
              securityContext:
                capabilities:
                  add:
                    - IPC_LOCK
              volumeMounts:
                - mountPath: /app/fast-llm-tutorial
                  name: fast-llm-tutorial
                - mountPath: /dev/shm
                  name: dshm
          volumes:
            - name: fast-llm-tutorial
              persistentVolumeClaim:
                claimName: pvc-fast-llm-tutorial
            - name: dshm
              emptyDir:
                medium: Memory
                sizeLimit: "1024Gi"
EOF
You can follow the job's progress by running kubectl get pods and checking the logs with kubectl logs fast-llm-prepare-master-0.
⚙️ Step 5: Configure Fast-LLM¶
Next, we'll create a configuration file for Fast-LLM.
FlashAttention
Fast-LLM uses FlashAttention by default. If you're using Volta GPUs, you must disable FlashAttention by setting use_flash_attention: no in the configuration file, as shown below.
Micro-Batch Size
The micro_batch_size in the configuration below is optimized for 80GB GPUs. If you're using GPUs with less memory, you will need to lower this value. Alternatively, you can decrease the sequence_length to reduce the memory footprint.
Save the following as fast-llm-tutorial/train-config.yaml:
training:
  train_iters: 100 # (1)!
  logs:
    interval: 10
  validation:
    iterations: 25
    interval: 100
  export: # (2)!
    format: llama
    interval: 100
  wandb: # (3)!
    project_name: fast-llm-tutorial
    group_name: Small
    entity_name: null
batch:
  micro_batch_size: 60 # (4)!
  sequence_length: 1024
  batch_size: 480 # (5)!
data:
  format: file
  path: fast-llm-tutorial/dataset/fast_llm_dataset.json # (6)!
  split: [9, 1, 0] # (7)!
optimizer:
  learning_rate:
    base: 6.0e-04
pretrained:
  format: llama # (8)!
  path: fast-llm-tutorial/pretrained-model
  model_weights: no # (9)!
model:
  base_model:
    transformer:
      use_flash_attention: yes # (10)!
  distributed:
    training_dtype: bf16 # (11)!
run:
  experiment_dir: fast-llm-tutorial/experiment
- For the small run, we'll stop after 100 iterations.
- The trained model will be saved in Transformers Llama format to fast-llm-tutorial/experiment/export/llama/100 at the end of the small run. You can also save it as a Fast-LLM checkpoint by setting the format to fast_llm.
- Entirely optional, but it's a good idea to track your training progress with Weights & Biases. Replace null with your own W&B entity name. If you don't want to use W&B, just ignore this section.
- Adjust the number of sequences per GPU based on GPU memory. For SmolLM2-135M at a sequence length of 1024 on an 80GB GPU, a micro_batch_size of 60 should work well.
- Must be divisible by the number of GPUs and the micro_batch_size. At 1024 tokens per sequence, 480 corresponds to about 500,000 tokens per batch.
- Location of the dataset metadata file generated in Step 4.
- 90% train, 10% validation, 0% test. These settings need to be adjusted based on the size of your dataset.
- Format of the pretrained model. Since SmolLM2 is a Llama model, we set this to llama.
- We'll train SmolLM2-135M from scratch. You can set this to yes to continue training from a checkpoint (if you put one in the model directory).
- By default, Fast-LLM uses FlashAttention for faster training. If you're using Volta GPUs, set this to no.
- bf16 (bfloat16, or Brain Floating Point 16) is supported on Ampere GPUs and higher. On Volta GPUs, use fp16 (half-precision floating point) for training instead of bf16.
training:
  train_iters: 100_000 # (1)!
  logs:
    interval: 10
  validation:
    iterations: 25
    interval: 1000
  checkpoint:
    interval: 1000
    keep: 5
  test_iters: 0
  export: # (2)!
    format: llama
    interval: 20_000
  wandb: # (3)!
    project_name: fast-llm-tutorial
    group_name: Big
    entity_name: null
batch:
  micro_batch_size: 2 # (4)!
  sequence_length: 4096
  batch_size: 512 # (5)!
data:
  format: file
  path: fast-llm-tutorial/dataset/fast_llm_dataset.json # (6)!
  split: [99, 1, 0] # (7)!
optimizer: # (8)!
  weight_decay: 0.1
  beta_1: 0.9
  beta_2: 0.95
  learning_rate: # (9)!
    base: 6.0e-04
    minimum: 6.0e-05
    decay_style: cosine
    decay_iterations: 100_000
    warmup_iterations: 2000
pretrained:
  format: llama # (10)!
  path: fast-llm-tutorial/pretrained-model
  model_weights: yes # (11)!
model:
  base_model:
    transformer:
      use_flash_attention: yes # (12)!
    cross_entropy_impl: fused # (13)!
  multi_stage:
    zero_stage: 2 # (14)!
  distributed:
    training_dtype: bf16 # (15)!
run:
  experiment_dir: fast-llm-tutorial/experiment
- Total number of training tokens will be approximately 210B: 100,000 iterations * 512 sequences per batch * 4,096 tokens per sequence.
- A permanent model checkpoint in Transformers Llama format will be saved to fast-llm-tutorial/experiment/export/llama/[iteration]/ every 20,000 iterations. You can also save it as a Fast-LLM checkpoint by setting the format to fast_llm.
- Entirely optional, but it's a good idea to track your training progress with Weights & Biases. Replace null with your own W&B entity name. If you don't want to use W&B, just ignore this section.
- Adjust the number of sequences per GPU based on GPU memory. Considering a 4k token sequence length and 80GB GPUs, a micro_batch_size of 1 should work well.
- Must be divisible by the number of GPUs and the micro_batch_size. At 4k tokens per sequence, 512 corresponds to about 2.1 million tokens per batch.
- Location of the dataset metadata file generated in Step 4.
- 99% train, 1% validation, 0% test. These settings need to be adjusted based on the size of your dataset. If you're using a smaller dataset, you need to increase the validation split.
- These are good default optimizer settings for training models.
- We are using a cosine decay schedule with linear warmup. After reaching the peak learning rate base at warmup_iterations, the learning rate will decay to minimum at decay_iterations, following a cosine curve. The minimum learning rate should be 1/10th of the base learning rate per Chinchilla.
- Format of the pretrained model. Since it's a Llama model, we set this to llama.
- We want to continue training Llama-3.1-8B from a checkpoint. If you're training from scratch, set this to no.
- By default, Fast-LLM uses FlashAttention for faster training. If you're using Volta GPUs, set this to no.
- Configure Fast-LLM to use the fused cross-entropy loss implementation rather than the default Triton implementation for models with a large vocabulary size such as Llama-3.1-8B. This avoids issues with block size limitations in our current Triton code.
- We are using ZeRO stage 2 for this tutorial. You can set this to 1, 2, or 3 for ZeRO-1, ZeRO-2, or ZeRO-3, respectively.
- bf16 (bfloat16, or Brain Floating Point 16) is supported on Ampere GPUs and higher. On Volta GPUs, use fp16 (half-precision floating point) for training instead of bf16.
🔑 (Optional) Step 6: Add Your Weights & Biases API Key¶
If you included the W&B section in your configuration, you'll need to add your API key. Save it to ./fast-llm-tutorial/.wandb_api_key and use the WANDB_API_KEY_PATH environment variable as shown in the training command.
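For example (replace the placeholder with your actual key, and keep the file private):

echo "<your-wandb-api-key>" > ./fast-llm-tutorial/.wandb_api_key
chmod 600 ./fast-llm-tutorial/.wandb_api_key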
🚀 Step 7: Launch Training¶
Alright, the big moment! Let's launch the training run.
Python Hash Seed
The Python hash seed must be set to 0 to ensure consistent, reproducible ordering in hash-dependent operations across processes. Training will fail if this isn't set.
If you're using the Docker container or a custom local environment and have 8 GPUs available, run the following to start training:
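A single-node sketch assembled from the multi-node scripts below; torchrun's --standalone flag replaces the static rendezvous used across nodes, and --nproc_per_node should match your GPU count:

export PYTHONHASHSEED=0
export WANDB_API_KEY_PATH=./fast-llm-tutorial/.wandb_api_key  # only if you're using W&B
torchrun --standalone --nproc_per_node=8 --no_python \
    fast-llm train gpt --config fast-llm-tutorial/train-config.yaml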
If you have 4 nodes with 8 GPUs each, run the following to start training:
sbatch <<EOF
#!/bin/bash
#SBATCH --job-name=fast-llm-train
#SBATCH --nodes=4
#SBATCH --gpus-per-node=8
#SBATCH --ntasks-per-node=1
#SBATCH --exclusive
#SBATCH --output=/app/fast-llm-tutorial/train-output.log
#SBATCH --error=/app/fast-llm-tutorial/train-error.log
export PYTHONHASHSEED=0
export WANDB_API_KEY_PATH=/app/fast-llm-tutorial/.wandb_api_key
export TORCH_NCCL_ASYNC_ERROR_HANDLING=1
export NCCL_DEBUG=INFO
srun \
--container-image="ghcr.io/servicenow/fast-llm:latest" \
--container-mounts="$(pwd)/fast-llm-tutorial:/app/fast-llm-tutorial" \
--container-env="PYTHONHASHSEED,WANDB_API_KEY_PATH,TORCH_NCCL_ASYNC_ERROR_HANDLING,NCCL_DEBUG" \
--gpus-per-node=\$SLURM_GPUS_PER_NODE \
--ntasks-per-node=\$SLURM_NTASKS_PER_NODE \
bash -c "
torchrun --rdzv_backend=static \
--rdzv_id=0 \
--rdzv_endpoint=\${MASTER_ADDR}:\${MASTER_PORT} \
--node_rank=\\$SLURM_NODEID \
--nproc_per_node=\\$SLURM_GPUS_PER_NODE \
--nnodes=\\$SLURM_NNODES \
--max_restarts=0 \
--rdzv_conf=timeout=3600 \
--no_python \
fast-llm train gpt \
--config fast-llm-tutorial/train-config.yaml"
EOF
Copy the configuration file to the shared PVC:
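A sketch using the temporary pod created in Step 1:

kubectl cp ./fast-llm-tutorial/train-config.yaml pod-fast-llm-tutorial:/app/fast-llm-tutorial/train-config.yaml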
If you have 4 nodes with 8 GPUs each, run the following to start training:
kubectl apply -f - <<EOF
apiVersion: "kubeflow.org/v1"
kind: "PyTorchJob"
metadata:
  name: "fast-llm-train"
spec:
  nprocPerNode: "8"
  pytorchReplicaSpecs:
    Master:
      replicas: 1
      restartPolicy: Never
      template:
        spec:
          tolerations:
            - key: nvidia.com/gpu
              value: "true"
              operator: Equal
              effect: NoSchedule
          containers:
            - name: pytorch
              image: ghcr.io/servicenow/fast-llm:latest
              resources:
                limits:
                  nvidia.com/gpu: 8
                  rdma/rdma_shared_device_a: 1
                  memory: "1024Gi"
                  cpu:
                requests:
                  nvidia.com/gpu: 8
                  rdma/rdma_shared_device_a: 1
                  memory: "1024Gi"
                  cpu: 128
              command:
                - /bin/bash
                - -c
                - |
                  torchrun --rdzv_backend=static \
                           --rdzv_endpoint=\${MASTER_ADDR}:\${MASTER_PORT} \
                           --node_rank=\${RANK} \
                           --nproc_per_node=\${PET_NPROC_PER_NODE} \
                           --nnodes=\${PET_NNODES} \
                           --max_restarts=0 \
                           --rdzv_conf=timeout=3600 \
                           --no_python \
                           fast-llm train gpt \
                           --config fast-llm-tutorial/train-config.yaml
              env:
                - name: PYTHONHASHSEED
                  value: "0"
                - name: WANDB_API_KEY_PATH
                  value: "/app/fast-llm-tutorial/.wandb_api_key"
                - name: TORCH_NCCL_ASYNC_ERROR_HANDLING
                  value: "1"
                - name: NCCL_DEBUG
                  value: "INFO"
              securityContext:
                capabilities:
                  add:
                    - IPC_LOCK
              volumeMounts:
                - mountPath: /app/fast-llm-tutorial
                  name: fast-llm-tutorial
                - mountPath: /dev/shm
                  name: dshm
          volumes:
            - name: fast-llm-tutorial
              persistentVolumeClaim:
                claimName: pvc-fast-llm-tutorial
            - name: dshm
              emptyDir:
                medium: Memory
                sizeLimit: "1024Gi"
    Worker:
      replicas: 3
      restartPolicy: Never
      template:
        spec:
          tolerations:
            - key: nvidia.com/gpu
              value: "true"
              operator: Equal
              effect: NoSchedule
          containers:
            - name: pytorch
              image: ghcr.io/servicenow/fast-llm:latest
              resources:
                limits:
                  nvidia.com/gpu: 8
                  rdma/rdma_shared_device_a: 1
                  memory: "1024Gi"
                  cpu:
                requests:
                  nvidia.com/gpu: 8
                  rdma/rdma_shared_device_a: 1
                  memory: "1024Gi"
                  cpu: 128
              command:
                - /bin/bash
                - -c
                - |
                  torchrun --rdzv_backend=static \
                           --rdzv_endpoint=\${MASTER_ADDR}:\${MASTER_PORT} \
                           --node_rank=\${RANK} \
                           --nproc_per_node=\${PET_NPROC_PER_NODE} \
                           --nnodes=\${PET_NNODES} \
                           --max_restarts=0 \
                           --rdzv_conf=timeout=3600 \
                           --no_python \
                           fast-llm train gpt \
                           --config fast-llm-tutorial/train-config.yaml
              env:
                - name: PYTHONHASHSEED
                  value: "0"
                - name: WANDB_API_KEY_PATH
                  value: "/app/fast-llm-tutorial/.wandb_api_key"
                - name: TORCH_NCCL_ASYNC_ERROR_HANDLING
                  value: "1"
                - name: NCCL_DEBUG
                  value: "INFO"
              securityContext:
                capabilities:
                  add:
                    - IPC_LOCK
              volumeMounts:
                - mountPath: /app/fast-llm-tutorial
                  name: fast-llm-tutorial
                - mountPath: /dev/shm
                  name: dshm
          volumes:
            - name: fast-llm-tutorial
              persistentVolumeClaim:
                claimName: pvc-fast-llm-tutorial
            - name: dshm
              emptyDir:
                medium: Memory
                sizeLimit: "1024Gi"
EOF
📊 Step 8: Track Training Progress¶
If you launched training locally (Docker or custom environment), Fast-LLM will log training progress to the console every 10 iterations. You can cancel training at any time by pressing Ctrl+C in the terminal.
On Slurm, use squeue -u $USER to see the job status, and follow train-output.log and train-error.log in your working directory for logs. Fast-LLM will log training progress to those files every 10 iterations. You can cancel training by running scancel <job_id>.
On Kubernetes, use kubectl get pods to see the job status and kubectl logs fast-llm-train-master-0 to check the logs. Fast-LLM will log training progress to the console every 10 iterations. You can cancel training by deleting the PyTorchJob:
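For example (assuming the Kubeflow training operator exposes the resource under the pytorchjob name):

kubectl delete pytorchjob fast-llm-train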
You can expect to see performance metrics in Fast-LLM's output; the footnotes at the bottom of this page list the settings used for the reference numbers.
If you included the W&B section in your configuration, you can also track your training progress on the Weights & Biases dashboard. Follow the link in the console output to view your training run.
🎉 Final Thoughts¶
And that's it! You've set up your environment, chosen a configuration, downloaded a model, prepared the data, configured training, and launched a full training run with Fast-LLM. You can try out the saved model directly with Transformers.
import transformers

model = transformers.AutoModelForCausalLM.from_pretrained("/app/fast-llm-tutorial/experiment/export/llama/100").cuda()
tokenizer = transformers.AutoTokenizer.from_pretrained("fast-llm-tutorial/pretrained-model/")

inputs = {k: v.cuda() for k, v in tokenizer("This is what the small model can do after fine-tuning for 100 steps:", return_tensors="pt").items()}
outputs = model.generate(**inputs, max_new_tokens=100)
print(tokenizer.decode(outputs[0]))
import transformers

model = transformers.AutoModelForCausalLM.from_pretrained("/app/fast-llm-tutorial/experiment/export/llama/100000").cuda()
tokenizer = transformers.AutoTokenizer.from_pretrained("fast-llm-tutorial/pretrained-model/")

inputs = {k: v.cuda() for k, v in tokenizer("This is what the big model can do after fine-tuning for 100K steps:", return_tensors="pt").items()}
outputs = model.generate(**inputs, max_new_tokens=100)
print(tokenizer.decode(outputs[0]))
From here, feel free to tweak the model, try out larger datasets, or scale things up to larger clusters. The sky's the limit!
Happy training!
- Precision was set to fp16, since bf16 is not supported on V100 GPUs. FlashAttention was disabled, as it is not supported on V100 GPUs. Micro-batch size was set to 12.
- Precision was set to bf16. FlashAttention was enabled. Micro-batch size was set to 60.
- Precision was set to bf16. FlashAttention was enabled. Micro-batch size was set to 60.
- Precision was set to fp16, since bf16 is not supported on V100 GPUs. FlashAttention was disabled, as it is not supported on V100 GPUs. Micro-batch size was set to 4.
- Precision was set to bf16. FlashAttention was enabled. Micro-batch size was set to 1. ZeRO stage 2 was used.
- Precision was set to bf16. FlashAttention was enabled. Micro-batch size was set to 2. ZeRO stage 2 was used.
- Theoretical peak performance of the GPU for dense tensors in fp16 or bf16 precision, depending on the GPU architecture. Source: Wikipedia.