Fine-tune a language model
You can fine-tune language models to make them better at a particular task. With Replicate, you can fine-tune and run your model in the cloud without having to set up any GPUs.
You can train a language model to do things like:
- Generate text in a particular style: Trained on text in a particular style, it can generate more text in that style.
- Classify text: Trained on text and the category that text should be in, it can generate the category for more text.
- Answer questions: Trained on questions and their answers, it can generate answers about a particular topic in a particular style.
- Be a chatbot: Trained on a conversation and the next message in that conversation, it can suggest the next response to that conversation.
- Extract structured data from text: Trained on some text and a particular piece of data in that text, it can extract that piece of data from more text.
Some of these tasks are possible with prompting alone, but a prompt can only hold a limited amount of data. Fine-tuning lets the model learn from far more examples than would fit in a single prompt.
In this guide, we’ll show you how to create a text summarizer. We’ll be using Llama 2 7B, an open-source large language model from Meta, and fine-tuning it on a dataset of messenger-like conversations with summaries. When we’re done, you’ll be able to distill chat transcripts, emails, webpages, and other documents into a brief summary. Short and sweet.
Supported models
You can fine-tune many language models on Replicate, including:
- meta/llama-2-70b-chat: 70 billion parameter model fine-tuned on chat completions. If you want to build a chat bot with the best accuracy, this is the one to use.
- meta/llama-2-13b-chat: 13 billion parameter model fine-tuned on chat completions. Use this if you’re building a chat bot and would prefer it to be faster and cheaper at the expense of accuracy.
- meta/llama-2-7b-chat: 7 billion parameter model fine-tuned on chat completions. This is an even smaller, faster model.
- meta/llama-2-7b: 7 billion parameter base model.
- meta/llama-2-70b: 70 billion parameter base model.
To see all the language models that currently support fine-tuning, check out our collection of trainable language models.
If you’re looking to fine-tune image models, check out our guide to fine-tuning image models.
Prepare your training data
Your training data should be in a single JSONL text file. JSONL (or “JSON lines”) is a file format for storing structured data in a text-based, line-delimited format. Each line in the file is a standalone JSON object.
If you’re building an instruction-tuned model like a chat bot that answers questions, structure your data using an object with a `prompt` key and a `completion` key on each line:
{"prompt": "...", "completion": "..."}
{"prompt": "Why don't scientists trust atoms?", "completion": "Because they make up everything!"}
{"prompt": "Why did the scarecrow win an award?", "completion": "Because he was outstanding in his field!"}
{"prompt": "What do you call fake spaghetti?", "completion": "An impasta!"}
If you’re building an autocompleting model for tasks like completing a user’s writing, code completion, finishing lists, or few-shotting specific tasks like classification, or if you want more control over the format of your training data, structure each JSON line as a single object with a `text` key and a string value:
{"text": "..."}
{"text": "..."}
{"text": "..."}
Here are some existing datasets to give you a sense of what real training data looks like:
- SAMSum - a collection of messenger-like conversations with summaries.
- Hacker News - all stories and comments from Hacker News from its launch in 2006 to present.
- Elixir docstrings - API reference documentation for all functions in the Elixir programming language’s standard library.
- ABC notated music - a collection of songs in a shorthand form of musical notation for computers.
- Recipes - Lists of food ingredients as prompts, and recipe instructions as responses.
- Bob Dylan Lyrics - Song titles and lyrics from Bob Dylan’s repertoire.
In this guide, we’ll be using the SAMSum dataset, transformed into JSONL.
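If you want to build that JSONL file yourself, here’s one possible transformation using the Hugging Face datasets library (a sketch assuming the samsum dataset’s dialogue and summary fields; the prepared gist used later in this guide may be formatted differently):

import json
from datasets import load_dataset  # pip install datasets (samsum may also require py7zr)

ds = load_dataset("samsum", split="train")
with open("samsum.jsonl", "w") as f:
    for row in ds:
        f.write(json.dumps({"prompt": row["dialogue"], "completion": row["summary"]}) + "\n")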
What factors affect training time?
Training time varies depending on:
- Dataset size: The larger your training dataset, the longer fine-tuning will take. When training on your own data, start with a small dataset (500 to a few thousand lines). If the results are promising, then you can increase the dataset size.
- Training hardware: The more powerful the GPU you use, the faster it will train. Training jobs often run on different hardware from predictions. You can find out what hardware a model uses for training by clicking the “Train” tab. To find out which models can be trained, check out the trainable language models collection.
- Pack sequences: If your training dataset consists of many short examples, you can set the `pack_sequences` input to `true` to speed up your training. This packs your examples together to make full use of the available sequence length (~2048 tokens), which reduces the number of steps needed during training (see the sketch below).
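For example, when you create the training later in this guide, packing is just one more key in the training input (a sketch; the URL is a placeholder):

# Hypothetical training input with sequence packing enabled
training_input = {
    "train_data": "https://example.com/data.jsonl",  # placeholder: your uploaded JSONL
    "num_train_epochs": 3,
    "pack_sequences": True,  # pack short examples into full ~2048-token sequences
}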
How long will it take and how much will it cost?
In this guide, we’ll fine-tune Llama 2 7B, which uses an 8x Nvidia A40 (Large) instance for training, costing $0.348 per minute. Training for 3 epochs on the SAMSum dataset should take ~75 minutes, so it will total about $26 (75 × $0.348 ≈ $26.10).
Here are more example training times and costs for some of the datasets mentioned above:
| Dataset | Dataset size | Training hardware | Training time | Cost |
| --- | --- | --- | --- | --- |
| SAMSum | 11 MB (~3,500,000 tokens) | 8x Nvidia A40 (Large) | 75 minutes (3 epochs) | ~$26 |
| Elixir docstrings | 1.5 MB (~450,000 tokens) | 8x Nvidia A40 (Large) | 17 minutes | ~$6 |
| ABC notated music | 67 MB (~40,000,000 tokens) | 8x Nvidia A40 (Large) | 12.5 hours | ~$260 |
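To ballpark a run on your own data, multiply the expected training time by the hardware’s per-minute price. A quick sketch:

# Rough cost estimate: training minutes x price per minute
minutes = 75              # expected training time
price_per_minute = 0.348  # 8x Nvidia A40 (Large)
print(f"~${minutes * price_per_minute:.2f}")  # prints ~$26.10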
Create a model
You need to create an empty model on Replicate for your trained model. When your training finishes, it will be pushed as a new version to this model.
Go to replicate.com/create and create a new model called “llama2-summarizer”. You probably want to make it private to start, and you have the option of making it public later.
Authenticate
Authenticate by setting your token in an environment variable:
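export REPLICATE_API_TOKEN=<your-api-token>

Replace <your-api-token> with the token from your account settings on replicate.com.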
Upload your training data
If you’re using the example SAMSum dataset, you can skip this section. It’s already uploaded!
If you’ve created your own dataset, you’ll need to upload it somewhere on the internet that is publicly accessible, like an S3 bucket or a GitHub Pages site.
If you like, you can use our API for uploading files. Run these commands to upload your `data.jsonl` file:
# Request a presigned upload URL for data.jsonl
RESPONSE=$(curl -s -X POST -H "Authorization: Token $REPLICATE_API_TOKEN" https://dreambooth-api-experimental.replicate.com/v1/upload/data.jsonl)
# Upload the file to that URL
curl -X PUT -H "Content-Type: application/jsonl" --upload-file data.jsonl "$(jq -r ".upload_url" <<< "$RESPONSE")"
# The serving URL is what you'll pass as train_data in the next step
SERVING_URL=$(jq -r ".serving_url" <<< "$RESPONSE")
echo "$SERVING_URL"
Create a training
To find out which models can be trained, check out the trainable language models collection.
Install the Python library:
pip install replicate
Then, run this to create a training with meta/llama-2-7b as the base model:
import replicate

training = replicate.trainings.create(
    version="meta/llama-2-7b:bf0a2a692f015ee21527ed2668e338032c1f937b4fcfa1f217f5cd79bf33478c",
    input={
        "train_data": "https://gist.githubusercontent.com/nateraw/055c55b000e4c37d43ce8eb142ccc0a2/raw/d13853512fc83e8c656a3e8b6e1270dd3c398e77/samsum.jsonl",
        "num_train_epochs": 3,
    },
    destination="your-username/llama2-summarizer",  # the model you created earlier
)
print(training)
It takes these arguments:
- `version`: The model to train, in the format `{username}/{model}:{version}`.
- `input`: The training data and params to pass to the training process, which are defined by the model. Llama 2's params can be found in the model's "Train" tab.
- `destination`: The model to push the trained version to.
You can also create the training with the HTTP API. The optional `webhook` parameter tells Replicate where to POST a notification when the training completes:

$ curl -s -X POST \
  -d '{"destination": "{username}/llama2-summarizer", "input": {"train_data": "https://storage.googleapis.com/dan-scratch-public/fine-tuning/70k_samples.jsonl"}, "webhook": "https://example.com/my/webhook/endpoint"}' \
  -H "Authorization: Token $REPLICATE_API_TOKEN" \
  https://api.replicate.com/v1/models/meta/llama-2-7b/versions/bf0a2a692f015ee21527ed2668e338032c1f937b4fcfa1f217f5cd79bf33478c/trainings
The API response will look like this:
{
  "id": "zz4ibbonubfz7carwiefibzgga",
  "version": "4841472f9a9d279cf03ba0a8f633b13eccd0e0a033d3af0c9b58830982d33132",
  "status": "starting",
  "input": {
    "data": "..."
  },
  "output": null,
  "error": null,
  "logs": null,
  "started_at": null,
  "created_at": "2023-03-28T21:47:58.566434Z",
  "completed_at": null
}
To learn more about creating trainings, take a look at the API documentation.
Wait for the training to finish
Once you’ve kicked off your training, visit replicate.com/trainings in your browser to monitor the progress.
If you set a webhook in the previous step, you’ll receive a POST request at your webhook URL when the training completes.
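The webhook body is the training object, with the same shape as the JSON response shown earlier. Here's a minimal sketch of a receiver, assuming Flask (the route matches the example webhook URL above):

from flask import Flask, request

app = Flask(__name__)

@app.route("/my/webhook/endpoint", methods=["POST"])
def handle_training_webhook():
    training = request.get_json()
    if training.get("status") == "succeeded":
        # training["output"] contains the trained weights and version
        print("Trained version:", training["output"]["version"])
    return "", 200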
If you're not using webhooks, you can poll for the status of the training job. When `training.status` is `succeeded`, your model has been trained.
# If you've got a handle to the object returned by create()
training.reload()

# If you've got the training ID
training = replicate.trainings.get("zz4ibbonubfz7carwiefibzgga")

if training.status == "succeeded":
    print(training.output)
    # {"weights": "...", "version": "..."}
To poll over HTTP instead, find the `id` from the JSON response in the previous step and use it to make a follow-up API request:
$ curl -s \
-H "Authorization: Token $REPLICATE_API_TOKEN" \
https://api.replicate.com/v1/trainings/zz4ibbonubfz7carwiefibzgga
Check the `status` property of the response. When status is `succeeded`, your new model is ready to use!
Run the model
You can now run your model from the web or with an API. To use your model in the browser, go to your model page.
To use your model with the Python library, run the `version` from the training output:
prompt = """[INST] <<SYS>>\
Use the Input to provide a summary of a conversation.
<</SYS>>
Input:
Harry: Who are you?
Hagrid: Rubeus Hagrid, Keeper of Keys and Grounds at Hogwarts. Of course, you know all about Hogwarts.
Harry: Sorry, no.
Hagrid: No? Blimey, Harry, did you never wonder where yer parents learned it all?
Harry: All what?
Hagrid: Yer a wizard, Harry.
Harry: I-- I'm a what?
Hagrid: A wizard! And a thumpin' good 'un, I'll wager, once you've been trained up a bit. [/INST]
Summary: """
output = replicate.run(
    training.output["version"],
    input={"prompt": prompt, "stop_sequences": "</s>"}
)
for s in output:
    print(s, end="", flush=True)
Or call the HTTP API directly, passing the `version` from the training output:
$ curl -s -X POST \
  -d "{\"version\": \"$TRAINED_OUTPUT_VERSION\", \"input\": {\"prompt\": \"How do I exit vim?\"}}" \
  -H "Authorization: Token $REPLICATE_API_TOKEN" \
  https://api.replicate.com/v1/predictions
Find the `urls.get` field in the JSON response and call that URL:
$ curl -s \
-H "Authorization: Token $REPLICATE_API_TOKEN" \
https://api.replicate.com/v1/predictions/$PREDICTION_ID
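When the prediction has completed, its `output` field holds the generated text, typically as an array of strings. For example, you could concatenate it with jq (a sketch):

$ curl -s \
  -H "Authorization: Token $REPLICATE_API_TOKEN" \
  https://api.replicate.com/v1/predictions/$PREDICTION_ID | jq -r '.output | join("")'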
Next steps
- Try training with your own dataset. Take a look at the “Prepare your training data” section above for instructions on how to do that.
- Try training with a different size of Llama 2. Larger models usually get better results, but they’re slower to train and run.