  1. Prerequisites
    This tutorial assumes you've already successfully completed "Gemma-7B Inference using NGC PyTorch". For fine-tuning Gemma, we will rely on the NGC PyTorch container and the Python virtual environment (gemma-venv) we set up in that tutorial.

  2. Set up TRL
    We will use HuggingFace TRL to fine-tune Gemma-7B on the OpenAssistant dataset. First, we need to update our Python environment with some extra libraries to support TRL. To do this, we can launch an interactive shell in the PyTorch container, just like we did in the previous tutorial. Then, we install trl and peft:
    Code Block
    [cluster][user@cluster-ln001 ~]$ srun --environment=gemma-pytorch --container-workdir=$PWD --pty bash
    user@nid001234:/bret/scratch/cscs/user/gemma-inference$ source ./gemma-venv/bin/activate
    (gemma-venv) user@nid001234:/bret/scratch/cscs/user/gemma-inference$ python -m pip install trl peft
    # ... pip output ...
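    Optionally, you can verify the installation from the same shell by importing both libraries and printing their versions. The exact version numbers you see depend on what pip resolved:
    Code Block
    (gemma-venv) user@nid001234:/bret/scratch/cscs/user/gemma-inference$ python -c "import trl, peft; print(trl.__version__, peft.__version__)"
    # ... prints the installed trl and peft versions ...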
    When this step is complete, you can exit the shell by typing exit.

    Next, we also need to clone the TRL git repository so that we can access some of the scripts in it. We also need to check out a particular commit to keep it compatible with the fine-tuning script we're going to set up:
    Code Block
    [cluster][user@cluster-ln001 ~]$ git clone https://github.com/huggingface/trl
    [cluster][user@cluster-ln001 ~]$ cd trl
    [cluster][user@cluster-ln001 trl]$ git checkout 9bc478ecbb2e9d0c9311784428347c382c05303d
    # ... git output ...
    [cluster][user@cluster-ln001 trl]$ cd ..
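    If you want to confirm that the checkout worked, git rev-parse should print exactly the commit hash from above:
    Code Block
    [cluster][user@cluster-ln001 ~]$ git -C trl rev-parse HEAD
    9bc478ecbb2e9d0c9311784428347c382c05303d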
  3. Fine-tune Gemma
    At this point, we can set up a fine-tuning script and start training Gemma-7B. Use your favorite text editor to create the file fine-tune-gemma.sh just outside the trl and gemma-venv directories:
    Code Block
    title: fine-tune-gemma.sh
    #!/bin/bash
    
    GPUS_PER_NODE=4
    ACCEL_PROCS=$(( $SLURM_NNODES * $GPUS_PER_NODE ))
    
    MAIN_ADDR=$(echo "${SLURM_NODELIST}" | sed 's/[],].*//g; s/\[//g')
    MAIN_PORT=12802
    
    source ./gemma-venv/bin/activate
    accelerate launch --config_file trl/examples/accelerate_configs/multi_gpu.yaml \
               --num_machines=$SLURM_NNODES --num_processes=$ACCEL_PROCS \
               --machine_rank $SLURM_PROCID \
               --main_process_ip $MAIN_ADDR --main_process_port $MAIN_PORT \
               trl/examples/scripts/sft.py \
               --model_name google/gemma-7b \
               --dataset_name OpenAssistant/oasst_top1_2023-08-25 \
               --per_device_train_batch_size 2 \
               --gradient_accumulation_steps 1 \
               --learning_rate 2e-4 \
               --save_steps 1894 \
               --use_peft \
               --lora_r 16 --lora_alpha 32 \
               --lora_target_modules q_proj k_proj v_proj o_proj \
               --output_dir gemma-finetuned-openassistant
    This script has quite a bit more content to unpack. We use HuggingFace accelerate to launch the fine-tuning process, so we need to make sure that accelerate understands which hardware is available and where. Setting this up will be useful in the long run because it means we can tell Slurm how much hardware to reserve, and the script will set up all the details for us.

    First, we manually set GPUS_PER_NODE to 4, since the cluster has four GH200 chips per compute node. Then we calculate how many processes accelerate should launch; this should be four per node, so we multiply the two corresponding variables. Next, we use a bit of sed to extract the name of the head node from the Slurm node list (SLURM_NODELIST). Accelerate expects one main node and launches tasks on the other nodes from this main node. Finally, we source our Python environment and launch Gemma fine-tuning. The first four lines of the accelerate launch command configure accelerate itself; everything after that configures the trl/examples/scripts/sft.py script, which performs the actual training.
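
    To see what the head-node extraction does, you can run the sed expression by hand on an example node list. The node names below are purely illustrative; your SLURM_NODELIST will contain the nodes Slurm actually assigned:
    Code Block
    [cluster][user@cluster-ln001 ~]$ echo "nid[001234,001236-001237]" | sed 's/[],].*//g; s/\[//g'
    nid001234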

    Next, we also need to create a short Slurm batch script to launch our fine-tuning script:
    Code Block
    title: fine-tune-sft.sbatch
    #!/bin/bash
    #SBATCH --job-name=gemma-finetune
    #SBATCH --time=00:30:00
    #SBATCH --ntasks-per-node=1
    #SBATCH --cpus-per-task=288
    #SBATCH --account=<project>
    
    export HF_HOME=$SCRATCH/huggingface
    
    srun --environment=gemma-pytorch --container-workdir=$PWD bash fine-tune-gemma.sh
    We set a few Slurm parameters like we already did in the previous tutorial. Note that we leave the number of nodes unspecified. This way, we can decide the number of nodes we want to use when we launch the batch job using Slurm.

    Now that we've set up a fine-tuning script and a Slurm batch script, we can launch our fine-tuning job. We'll start by launching it on two nodes. It should take about 10-15 minutes to fine-tune Gemma:
    Code Block
    [cluster][user@cluster-ln001 ~]$ sbatch --nodes=2 fine-tune-sft.sbatch
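    While the job is running, you can monitor it with the usual Slurm tools. The job ID below is a placeholder; by default, Slurm writes the job's output to a file named slurm-<jobid>.out in the submission directory unless you set --output in the batch script:
    Code Block
    [cluster][user@cluster-ln001 ~]$ squeue -u $USER
    [cluster][user@cluster-ln001 ~]$ tail -f slurm-<jobid>.out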


  4. Compare fine-tuned Gemma against default Gemma
    We can reuse our Python script from the previous tutorial to run inference, first with the default Gemma model and then with the model we just fine-tuned. Let's try a different prompt in gemma-inference.py:
    Code Block
    input_text = "What are the 5 tallest mountains in the Swiss Alps?"

    We can run inference using our batch script from the previous tutorial:
    Code Block
    [cluster][user@cluster-ln001 ~]$ sbatch ./gemma-inference.sbatch

    Inspecting the output should yield something like this:
    Code Block
    <bos>What are the 5 tallest mountains in the Swiss Alps?
    
    The Swiss Alps are home to some of the tallest mountains in the world. Here are
    the 5 tallest mountains in the Swiss Alps:
    
    1. Mont Blanc (4,808 meters)
    2. Matterhorn (4,411 meters)
    3. Dom (4,161 meters)
    4. Jungfrau (4,158 meters)
    5. Mont Rose (4,117 meters)<eos>

    Next, we can update the model line in our Python inference script to use the model that we just fine-tuned:
    Code Block
    model = AutoModelForCausalLM.from_pretrained("gemma-finetuned-openassistant/checkpoint-1894", device_map="auto")
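    Because the fine-tuning run used --use_peft, the checkpoint may contain LoRA adapter weights rather than a fully merged model. Recent transformers releases with peft installed can load such a checkpoint directly with the line above; if yours cannot, a minimal sketch of an explicit adapter load via peft (assuming the adapter files are in the checkpoint directory) would look like this:
    Code Block
    from peft import AutoPeftModelForCausalLM

    # Loads the base model referenced by the adapter's config and applies the LoRA weights on top.
    model = AutoPeftModelForCausalLM.from_pretrained(
        "gemma-finetuned-openassistant/checkpoint-1894", device_map="auto"
    )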

    If we re-run inference, the output will be a bit more detailed and explanatory, similar to output we might expect from a helpful chatbot. One example looks like this:
    Code Block
    <bos>What are the 5 tallest mountains in the Swiss Alps?
    
    The Swiss Alps are home to some of the tallest mountains in Europe, and they are a popular destination for mountaineers and hikers. Here are the five tallest mountains in the Swiss Alps:
    
    1. Mont Blanc (4,808 m/15,774 ft): Mont Blanc is the highest mountain in the Alps and the highest mountain in Europe outside of Russia. It is located on the border between France and Italy, and it is a popular destination for mountaineers and hikers.
    
    2. Dufourspitze (4,634 m/15,203 ft): Dufourspitze is the highest mountain in Switzerland and the second-highest mountain in the Alps. It is located in the Valais canton of Switzerland, and it is a popular destination for mountaineers and hikers.
    
    3. Liskamm (4,527 m/14,855 ft): Liskamm is a mountain in the Bernese Alps of Switzerland. It is located in the Bern canton of Switzerland, and it is a popular destination for mountaineers and hikers.
    
    4. Weisshorn (4,506 m/14,783 ft): Weisshorn is a mountain in the Pennine Alps of Switzerland. It is located in the Valais canton of Switzerland, and it is a popular destination for mountaineers and hikers.
    
    5. Matterhorn (4,478 m/14,690 ft): Matterhorn is a mountain in the Pennine Alps of Switzerland. It is located in the Valais canton of Switzerland, and it is a popular destination for mountaineers and hikers.
    
    These mountains are all located in the Swiss Alps, and they are a popular destination for mountaineers and hikers. If you are planning a trip to the Swiss Alps, be sure to check out these mountains and plan your itinerary accordingly.

    Your output may look different after fine-tuning, but in general you will see that the fine-tuned model generates more verbose output. Double-checking the output reveals that the list of mountains produced by Gemma is not actually correct (for example, Mont Blanc lies on the French-Italian border, not in Switzerland). The following table lists the 5 tallest Swiss peaks, according to Wikipedia.

    Peak              Height
    Dufourspitze      4,634 m
    Nordend           4,609 m
    Zumsteinspitze    4,563 m
    Signalkuppe       4,554 m
    Dom               4,545 m

    This is an important reminder that the output of machine-learning models like Gemma always needs to be checked against reliable sources before it is relied upon.