Meta Llama3 in torchtune
Llama3-8B
Meta Llama 3 is a new family of models released by Meta AI that improves upon the performance of the Llama2 family of models across a range of different benchmarks. Currently there are two sizes of Meta Llama 3: 8B and 70B. In this tutorial we will focus on the 8B model. There are a few main changes between the Llama2-7B and Llama3-8B models:
• Llama3-8B uses grouped-query attention instead of the standard multi-head attention from Llama2-7B
• Llama3-8B has a larger vocab size (128,256 instead of 32,000 from Llama2 models)
• Llama3-8B uses a different tokenizer than Llama2 models (tiktoken instead of sentencepiece)
• Llama3-8B uses a larger intermediate dimension in its MLP layers than Llama2-7B
• Llama3-8B uses a higher base value to calculate theta in its rotary positional embeddings
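Since the architecture is exposed directly through a model builder, here is a minimal sketch (assuming only that torchtune is installed) that instantiates Llama3-8B via the same torchtune.models.llama3.llama3_8b component the configs later in this tutorial reference:

    import torch
    from torchtune.models.llama3 import llama3_8b

    # Build on the meta device so no real weight memory is allocated
    with torch.device("meta"):
        model = llama3_8b()  # TransformerDecoder with the Llama3-8B hyperparameters

    # Roughly 8B parameters, reflecting the larger vocab and wider MLP layers
    print(f"{sum(p.numel() for p in model.parameters()) / 1e9:.2f}B parameters")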
Let’s take a look at how we can fine-tune Llama3-8B-Instruct with LoRA on a single device using torchtune. In this example, we will fine-tune for one epoch on a common instruct dataset for illustrative purposes. The basic command for a single-device LoRA fine-tune is shown below (the recipe and config names are those listed by tune ls; confirm them for your installed version):
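    tune run lora_finetune_single_device --config llama3/8B_lora_single_device

This assumes you have already downloaded the model weights and tokenizer to <checkpoint_dir>, e.g. via tune download meta-llama/Meta-Llama-3-8B-Instruct --output-dir <checkpoint_dir> --hf-token <HF_TOKEN> (the flag names here are an assumption; check tune download --help).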
• NOTE
To see a full list of recipes and their corresponding configs, simply run tune ls from the command line.
This will load the Llama3-8B-Instruct checkpoint and tokenizer from <checkpoint_dir> used in the tune download command above, then save a final checkpoint in the same directory
following the original format. For more details on the checkpoint formats supported in torchtune, see our checkpointing deep-dive.
• NOTE
To see the full set of configurable parameters for this (and other) configs we can use tune cp to copy (and modify) the default config. tune cp can be used with recipe scripts too,
in case you want to make more custom changes that cannot be achieved by directly modifying existing configurable parameters. For more on tune cp see the section on modifying
configs.
Once training is complete, the model checkpoints will be saved and their locations will be logged. For LoRA fine-tuning, the final checkpoint will contain the merged weights, and a copy of
just the (much smaller) LoRA weights will be saved separately.
In our experiments, we observed a peak memory usage of 18.5 GB. The default config can be trained on a consumer GPU with 24 GB VRAM.
If you have multiple GPUs available, you can run the distributed version of the recipe. torchtune makes use of the FSDP APIs from PyTorch Distributed to shard the model, optimizer states, and gradients across devices. This should enable you to increase your batch size, resulting in faster overall training. For example, on two devices (again, confirm the recipe and config names with tune ls):
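    tune run --nproc_per_node 2 lora_finetune_distributed --config llama3/8B_lora

Here --nproc_per_node sets the number of processes (one per GPU) to shard across.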
Finally, if we want to use even less memory, we can leverage torchtune’s QLoRA recipe, which keeps the frozen base weights in a quantized 4-bit format while training the LoRA adapters (config name per tune ls):
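    tune run lora_finetune_single_device --config llama3/8B_qlora_single_device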
Since our default configs enable full bfloat16 training, all of the above commands can be run on devices with at least 24 GB of VRAM; in fact, the QLoRA recipe should have peak allocated memory below 10 GB. You can also experiment with different configurations of LoRA and QLoRA, or even run a full fine-tune. Try it out!
Next, let’s evaluate the fine-tuned model. torchtune provides an integration with EleutherAI’s evaluation harness for model evaluation on common benchmark tasks.
• NOTE
Make sure you’ve first installed the evaluation harness via pip install "lm_eval==0.4.*" .
For this tutorial we’ll use the truthfulqa_mc2 task from the harness. This task measures a model’s propensity to be truthful when answering questions, reporting zero-shot accuracy on a question paired with one or more true responses and one or more false responses. First, let’s copy the config so we can point the YAML file to our fine-tuned checkpoint files.
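The config name below is the built-in evaluation config listed by tune ls:

    tune cp eleuther_evaluation ./custom_eval_config.yaml

Next, modify custom_eval_config.yaml to point to the fine-tuned checkpoint: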
    model:
      _component_: torchtune.models.llama3.llama3_8b

    checkpointer:
      _component_: torchtune.utils.FullModelMetaCheckpointer
      checkpoint_dir: <checkpoint_dir>
      checkpoint_files: [meta_model_0.pt]  # assumed merged-checkpoint filename; match your output
      output_dir: <checkpoint_dir>
      model_type: LLAMA3
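With the config updated, run the evaluation recipe (recipe name as listed by tune ls):

    tune run eleuther_eval --config ./custom_eval_config.yaml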
Try it for yourself and see what accuracy your model gets!
Similar to what we did for evaluation, let’s copy and modify the default generation config (config name per tune ls):
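    tune cp generation ./custom_generation_config.yaml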
Now we modify custom_generation_config.yaml to point to our checkpoint and tokenizer.
    model:
      _component_: torchtune.models.llama3.llama3_8b

    checkpointer:
      _component_: torchtune.utils.FullModelMetaCheckpointer
      checkpoint_dir: <checkpoint_dir>
      checkpoint_files: [meta_model_0.pt]  # assumed merged-checkpoint filename
      output_dir: <checkpoint_dir>
      model_type: LLAMA3

    # tokenizer section assumed from torchtune's llama3 builders
    tokenizer:
      _component_: torchtune.models.llama3.llama3_tokenizer
      path: <checkpoint_dir>/tokenizer.model
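Then run generation; the key=value prompt override below follows torchtune’s CLI conventions, and the prompt itself is just an example:

    tune run generate --config ./custom_generation_config.yaml prompt="Hello, my name is"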
Running generation with our LoRA-finetuned model, we see the following output:
    [generate.py:122] Hello, my name is Sarah and I am a busy working mum of two young children, living in the North East of England.
    ...
    [generate.py:135] Time for inference: 10.88 sec total, 18.94 tokens/sec
    [generate.py:138] Bandwidth achieved: 346.09 GB/s
    [generate.py:139] Memory used: 18.31 GB
If you’ve been following along this far, you know the drill by now. Let’s copy the quantization config and point it at our fine-tuned model.
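As before, the config name comes from torchtune’s recipe list:

    tune cp quantization ./custom_quantization_config.yaml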
    # Model arguments
    model:
      _component_: torchtune.models.llama3.llama3_8b

    checkpointer:
      _component_: torchtune.utils.FullModelMetaCheckpointer
      checkpoint_dir: <checkpoint_dir>
      checkpoint_files: [meta_model_0.pt]  # assumed merged-checkpoint filename
      output_dir: <checkpoint_dir>
      model_type: LLAMA3
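Then run the quantization recipe (recipe name as shown by tune ls):

    tune run quantize --config ./custom_quantization_config.yaml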
We can see that the model is now under 5 GB, or just over four bits for each of the 8B parameters.
• NOTE
Unlike the fine-tuned checkpoints, the quantization recipe outputs a single checkpoint file. This is because our quantization APIs currently don’t support any conversion across
formats. As a result you won’t be able to use these quantized models outside of torchtune. But you should be able to use these with the generation and evaluation recipes within
torchtune. These results will help inform which quantization methods you should use with your favorite inference engine.
Let’s take our quantized model and run the same generation again. First, we’ll make one more change to our custom_generation_config.yaml.
    checkpointer:
      # we need to use the custom torchtune checkpointer
      # instead of the HF checkpointer for loading
      # quantized models
      _component_: torchtune.utils.FullModelTorchTuneCheckpointer
      checkpoint_dir: <checkpoint_dir>
      # assumed filename for the quantized checkpoint written by the
      # quantize recipe; match your actual output
      checkpoint_files: [meta_model_0-4w.pt]
      output_dir: <checkpoint_dir>
      model_type: LLAMA3
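With that change in place, we can rerun the same generation command as before:

    tune run generate --config ./custom_generation_config.yaml prompt="Hello, my name is"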
This is just the beginning of what you can do with Meta Llama3 using torchtune and the broader ecosystem. We look forward to seeing what you build!