Fine-tune a language model

You can fine-tune language models to make them better at a particular task. With Replicate, you can fine-tune and run your model in the cloud without having to set up any GPUs.

You can train a language model to do things like:

  • Generate text in a particular style: Trained on text in a particular style, it can generate more text in that style.
  • Classify text: Trained on text and the category that text should be in, it can generate the category for more text.
  • Answer questions: Trained on questions and their answers, it can generate answers about a particular topic in a particular style.
  • Be a chatbot: Trained on a conversation and the next message in that conversation, it can suggest the next response to that conversation.
  • Extract structured data from text: Trained on some text and a particular piece of data in that text, it can extract that piece of data from more text.

Some of these tasks are possible with prompting alone, but a prompt can only hold a limited amount of data. Fine-tuning lets the model learn from far more examples than fit in a single prompt.

In this guide, we’ll show you how to create a text summarizer. We’ll use Llama 2 7B, an open-source large language model from Meta, and fine-tune it on a dataset of messenger-like conversations with summaries. When we’re done, you’ll be able to distill chat transcripts, emails, webpages, and other documents into a brief summary. Short and sweet.

Supported models

You can fine-tune many language models on Replicate, including:

  • meta/llama-2-70b-chat: 70 billion parameter model fine-tuned on chat completions. If you want to build a chatbot with the best accuracy, this is the one to use.
  • meta/llama-2-13b-chat: 13 billion parameter model fine-tuned on chat completions. Use this if you’re building a chatbot and would prefer it to be faster and cheaper at some cost in accuracy.
  • meta/llama-2-7b-chat: 7 billion parameter model fine-tuned on chat completions. This is an even smaller, faster model.
  • meta/llama-2-7b: 7 billion parameter base model.
  • meta/llama-2-70b: 70 billion parameter base model.

To see all the language models that currently support fine-tuning, check out our collection of trainable language models.

If you’re looking to fine-tune image models, check out our guide to fine-tuning image models.

Prepare your training data

Your training data should be in a single JSONL text file. JSONL (or “JSON lines”) is a file format for storing structured data in a text-based, line-delimited format. Each line in the file is a standalone JSON object.

If you’re building an instruction-tuned model like a chatbot that answers questions, structure your data using an object with a prompt key and a completion key on each line:

{"prompt": "...", "completion": "..."}
{"prompt": "Why don't scientists trust atoms?", "completion": "Because they make up everything!"}
{"prompt": "Why did the scarecrow win an award?", "completion": "Because he was outstanding in his field!"}
{"prompt": "What do you call fake spaghetti?", "completion": "An impasta!"}
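If your examples start out as Python data, the standard library's json module is enough to produce this format. A minimal sketch; the example pairs and the jokes.jsonl filename are just for illustration:

```python
import json

# Illustrative prompt/completion pairs, in the same shape as the lines above.
pairs = [
    ("Why don't scientists trust atoms?", "Because they make up everything!"),
    ("What do you call fake spaghetti?", "An impasta!"),
]


def write_prompt_completion_jsonl(pairs, path):
    """Write one {"prompt": ..., "completion": ...} JSON object per line."""
    with open(path, "w", encoding="utf-8") as f:
        for prompt, completion in pairs:
            # json.dumps escapes quotes and newlines, so each record stays on one line.
            f.write(json.dumps({"prompt": prompt, "completion": completion}) + "\n")


write_prompt_completion_jsonl(pairs, "jokes.jsonl")
```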

If you’re building an autocompletion model (for example, to continue a user’s writing, complete code, finish lists, or perform few-shot tasks like classification), or if you want more control over the format of your training data, structure each JSON line as a single object with a text key and a string value:

{"text": "..."}
{"text": "..."}
{"text": "..."}
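Before uploading, it can save a failed training run to check that every line parses and uses one of the two layouts above. The check_jsonl helper below is our own pre-flight sketch, not Replicate's validation logic; requiring exactly these keys and string values is an assumption based on the formats described here:

```python
import json


def check_jsonl(path):
    """Return a list of problems; an empty list means every line looks usable.

    Accepts the two layouts described above: {"prompt", "completion"} or {"text"}.
    """
    problems = []
    with open(path, encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            if not line.strip():
                problems.append(f"line {lineno}: blank line")
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError as exc:
                problems.append(f"line {lineno}: invalid JSON ({exc.msg})")
                continue
            if not isinstance(record, dict):
                problems.append(f"line {lineno}: not a JSON object")
            elif set(record) not in ({"prompt", "completion"}, {"text"}):
                problems.append(f"line {lineno}: unexpected keys {sorted(record)}")
            elif not all(isinstance(v, str) for v in record.values()):
                problems.append(f"line {lineno}: values must be strings")
    return problems
```

For example, print(check_jsonl("samsum.jsonl") or "looks good") prints either the list of problems or a confirmation.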

Here are some existing datasets to give you a sense of what real training data looks like:

  • SAMSum - a collection of messenger-like conversations with summaries.
  • Hacker News - all stories and comments from Hacker News from its launch in 2006 to present.
  • Elixir docstrings - API reference documentation for all functions in the Elixir programming language’s standard library.
  • ABC notated music - a collection of songs in a shorthand form of musical notation for computers.
  • Recipes - lists of food ingredients as prompts, and recipe instructions as responses.
  • Bob Dylan Lyrics - song titles and lyrics from Bob Dylan’s repertoire.

You don’t need to prepare any of the datasets listed above, since they’re already uploaded. However, for reference, here’s how you can use the Hugging Face Datasets library to download the SAMSum dataset, prepare it, and save it as a JSONL file:

# Important: install these dependencies first:
# pip install datasets py7zr
from datasets import load_dataset

# Llama 2 chat-style template: {message} is the conversation, {summary} is the target text.
PROMPT_TEMPLATE = "[INST] <<SYS>>\nUse the Input to provide a summary of a conversation.\n<</SYS>>\n\nInput:\n{message} [/INST]\n\nSummary: {summary}"


def format_instruction(sample):
    # Collapse one SAMSum row into a single {"text": ...} record.
    return {"text": PROMPT_TEMPLATE.format(message=sample["dialogue"], summary=sample["summary"])}


ds = load_dataset("samsum", split="train")
print(format_instruction(ds[5])["text"])  # preview one formatted example

# Replace the original columns with the single "text" column, then write JSON Lines.
ds = ds.map(format_instruction, remove_columns=["id", "dialogue", "summary"])
ds.to_json("samsum.jsonl")

Note that we chose to prepare our text using a custom prompt template. The template should produce examples that look like this:

[INST] <<SYS>>
Use the Input to provide a summary of a conversation.
<</SYS>>

Input:
Neville: Hi there, does anyone remember what date I got married on?
Don: Are you serious?
Neville: Dead serious. We're on vacation, and Tina's mad at me about something. I have a strange suspicion that this might have something to do with our wedding anniversary, but I have nowhere to check.
Wyatt: Hang on, I'll ask my wife.
Don: Haha, someone's in a lot of trouble :D
Wyatt: September 17. I hope you remember the year ;) [/INST]

Summary: Wyatt reminds Neville his wedding anniversary is on the 17th of September. Neville's wife is upset and it might be because Neville forgot about their anniversary.

When you prompt the model at inference time, use the same template you used during training, since that’s the format the model has learned. For example, to have the model summarize the conversation above, fill in the template and leave the summary empty:

[INST] <<SYS>>
Use the Input to provide a summary of a conversation.
<</SYS>>

Input:
Neville: Hi there, does anyone remember what date I got married on?
Don: Are you serious?
Neville: Dead serious. We're on vacation, and Tina's mad at me about something. I have a strange suspicion that this might have something to do with our wedding anniversary, but I have nowhere to check.
Wyatt: Hang on, I'll ask my wife.
Don: Haha, someone's in a lot of trouble :D
Wyatt: September 17. I hope you remember the year ;) [/INST]

Summary:

The model should then know to fill in the summary section with a summary of the conversation.
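A simple way to keep training and inference in sync is to reuse the exact PROMPT_TEMPLATE string from the preparation script and format it with an empty summary, so the prompt ends with "Summary: " for the model to complete. A minimal sketch; build_inference_prompt is a hypothetical helper name:

```python
# Same template string as in the preparation script above.
PROMPT_TEMPLATE = "[INST] <<SYS>>\nUse the Input to provide a summary of a conversation.\n<</SYS>>\n\nInput:\n{message} [/INST]\n\nSummary: {summary}"


def build_inference_prompt(message):
    """Fill the training template but leave the summary empty for the model to write."""
    return PROMPT_TEMPLATE.format(message=message, summary="")


prompt = build_inference_prompt(
    "Wyatt: Hang on, I'll ask my wife.\nDon: Haha, someone's in a lot of trouble :D"
)
print(prompt)
```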

What factors affect training time?

Training time varies depending on:

  • Dataset size: The larger your training dataset, the longer fine-tuning will take. When training on your own data, start with a small dataset (500 to a few thousand lines). If the results are promising, then you can increase the dataset size.
  • Training hardware: The more powerful the GPU you use, the faster it will train. Training jobs often run on different hardware from predictions. You can find out what hardware a model uses for training by clicking the “Train” tab. To find out which models can be trained, check out the trainable language models collection.
  • Pack sequences: If your training dataset consists of many short examples, you can set the pack_sequences input to true to speed up your training. This ‘packs together’ your examples to make full use of the available sequence length (~2048 tokens). This reduces the number of steps needed during training.
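To see why packing cuts the number of training steps, here is a toy greedy packer. It only illustrates the idea and is not Replicate's implementation: it assumes you already know each example's token count, keeps examples in order, and starts a new sequence whenever the next example would push the current one past 2048 tokens:

```python
def greedy_pack(token_counts, max_len=2048):
    """Group consecutive examples so each packed sequence holds at most max_len tokens.

    Returns a list of packed sequences, each a list of example indices.
    An example longer than max_len gets a sequence of its own here;
    a real trainer would typically truncate or split it.
    """
    packed, current, used = [], [], 0
    for i, n in enumerate(token_counts):
        if current and used + n > max_len:
            packed.append(current)
            current, used = [], 0
        current.append(i)
        used += n
    if current:
        packed.append(current)
    return packed


# Ten short examples of ~300 tokens each fit into two packed sequences.
counts = [300] * 10
packed = greedy_pack(counts)
print(len(counts), "->", len(packed))  # prints: 10 -> 2
```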

How long will it take and how much will it cost?

In this guide, we’ll fine-tune Llama 2 7B, which trains on an 8x Nvidia A40 (Large) instance costing $0.348 per minute. Training for 3 epochs on the SAMSum dataset should take ~75 minutes, so it will total about $26.

Here are more example training times and costs for some of the datasets mentioned above:

Dataset             Dataset size                  Training hardware        Training time            Cost
SAMSum              11 MB (~3,500,000 tokens)     8x Nvidia A40 (Large)    75 minutes (3 epochs)    ~$26
Elixir docstrings   1.5 MB (~450,000 tokens)      8x Nvidia A40 (Large)    17 minutes               ~$6
ABC notated music   67 MB (~40,000,000 tokens)    8x Nvidia A40 (Large)    12.5 hours               ~$260
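The cost column is training time in minutes multiplied by the per-minute hardware price. A quick check of the figures above, assuming the $0.348/minute rate quoted in this guide (prices may have changed since it was written):

```python
# USD per minute for 8x Nvidia A40 (Large), as quoted in this guide.
A40_LARGE_PER_MINUTE = 0.348


def training_cost(minutes, price_per_minute=A40_LARGE_PER_MINUTE):
    """Estimated cost in USD: billed minutes times the per-minute price."""
    return minutes * price_per_minute


for name, minutes in [("SAMSum", 75), ("Elixir docstrings", 17), ("ABC notated music", 12.5 * 60)]:
    print(f"{name}: ~${training_cost(minutes):.2f}")
# SAMSum: ~$26.10
# Elixir docstrings: ~$5.92
# ABC notated music: ~$261.00
```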

Create a model

You need to create an empty model on Replicate for your trained model. When your training finishes, it will be pushed as a new version to this model.

Go to replicate.com/create and create a new model called “llama2-summarizer”. You probably want to make it private to start, and you have the option of making it public later.

Authenticate

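Authenticate by setting your Replicate API token in the REPLICATE_API_TOKEN environment variable, which the commands in this guide read:

export REPLICATE_API_TOKEN=<paste-your-token-here>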
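The replicate Python client reads its token from REPLICATE_API_TOKEN, and the curl commands in this guide expand the same variable. If you want scripts to fail early with a clear message when it's missing, a small check like this works; require_api_token is our own hypothetical helper, not part of the client:

```python
import os


def require_api_token(env_var="REPLICATE_API_TOKEN"):
    """Return the API token from the environment, or raise a clear error if it's unset."""
    token = os.environ.get(env_var, "").strip()
    if not token:
        raise RuntimeError(f"{env_var} is not set; export it before running the examples")
    return token
```

Call require_api_token() at the top of a script, before the first API request.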

Upload your training data

If you’re using the example SAMSum dataset, you can skip this section. It’s already uploaded!

If you’ve created your own dataset, you’ll need to upload it somewhere on the internet that is publicly accessible, like an S3 bucket or a GitHub Pages site.

If you like, you can use our API for uploading files. Run these commands to upload your data.jsonl file (they use jq to read fields from the JSON response):

# Request an upload URL for data.jsonl
RESPONSE=$(curl -s -X POST -H "Authorization: Token $REPLICATE_API_TOKEN" https://dreambooth-api-experimental.replicate.com/v1/upload/data.jsonl)

# Upload the file to the returned upload_url
curl -X PUT -H "Content-Type: application/jsonl" --upload-file data.jsonl "$(jq -r ".upload_url" <<< "$RESPONSE")"

# serving_url is the public URL of your data; pass it as train_data
SERVING_URL=$(jq -r ".serving_url" <<< "$RESPONSE")
echo "$SERVING_URL"
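If you'd rather stay in Python, the same two-step flow (POST to get upload_url and serving_url, then PUT the file to upload_url) can be sketched with the standard library. It simply mirrors the curl commands above against the same experimental endpoint; upload_training_file is a hypothetical helper, and error handling is omitted:

```python
import json
import os
import urllib.request

UPLOAD_ENDPOINT = "https://dreambooth-api-experimental.replicate.com/v1/upload/"


def upload_training_file(path, token=None):
    """Upload a local JSONL file and return its public serving URL."""
    token = token or os.environ["REPLICATE_API_TOKEN"]
    filename = os.path.basename(path)

    # Step 1: ask for an upload URL (same POST as the first curl command).
    req = urllib.request.Request(
        UPLOAD_ENDPOINT + filename,
        method="POST",
        headers={"Authorization": f"Token {token}"},
    )
    with urllib.request.urlopen(req) as resp:
        info = json.load(resp)

    # Step 2: PUT the file contents to the returned upload_url.
    with open(path, "rb") as f:
        put = urllib.request.Request(
            info["upload_url"],
            data=f.read(),
            method="PUT",
            headers={"Content-Type": "application/jsonl"},
        )
    with urllib.request.urlopen(put):
        pass

    # Step 3: serving_url is the URL to use as train_data.
    return info["serving_url"]
```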

Create a training

To find out which models can be trained, check out the trainable language models collection.

Install the Python library:

pip install replicate

Then, run this to create a training with meta/llama-2-7b as the base model:

import replicate

# Replace with your Replicate username (the owner of the model you created above).
username = "your-username"

training = replicate.trainings.create(
    version="meta/llama-2-7b:73001d654114dad81ec65da3b834e2f691af1e1526453189b7bf36fb3f32d0f9",
    input={
        "train_data": "https://gist.githubusercontent.com/nateraw/055c55b000e4c37d43ce8eb142ccc0a2/raw/d13853512fc83e8c656a3e8b6e1270dd3c398e77/samsum.jsonl",
        "num_train_epochs": 3,
    },
    destination=f"{username}/llama2-summarizer",
)

print(training)

It takes these arguments:

  • version: The model to train, in the format {username}/{model}:{version}.
  • input: The training data and params to pass to the training process, which are defined by the model. Llama 2's params can be found in the model's "Train" tab.
  • destination: The model to push the trained version to.
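The version string has three parts, and the HTTP endpoint in the curl example below is built from the same three parts. The sketch below splits a version reference and rebuilds that trainings URL; split_version and trainings_url are hypothetical helpers, and treating version IDs as 64 lowercase hex characters is an assumption based on the Llama 2 7B ID used in this guide:

```python
import re

# Matches {username}/{model}:{version}.
# Assumption: version IDs are 64 lowercase hex characters, like the ID in this guide.
VERSION_RE = re.compile(r"^(?P<owner>[^/\s]+)/(?P<model>[^:\s]+):(?P<version>[0-9a-f]{64})$")


def split_version(ref):
    """Split 'owner/model:version' into its three parts, or raise ValueError."""
    m = VERSION_RE.match(ref)
    if m is None:
        raise ValueError(f"expected {{username}}/{{model}}:{{version}}, got {ref!r}")
    return m.group("owner"), m.group("model"), m.group("version")


def trainings_url(ref):
    """Build the HTTP trainings endpoint for a model version."""
    owner, model, version = split_version(ref)
    return f"https://api.replicate.com/v1/models/{owner}/{model}/versions/{version}/trainings"


print(trainings_url("meta/llama-2-7b:73001d654114dad81ec65da3b834e2f691af1e1526453189b7bf36fb3f32d0f9"))
```

A check like this catches the common mistake of passing "meta/llama-2-7b" without a version ID.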
You can also create a training with the HTTP API. Replace {username} with your Replicate username; the webhook field is optional:

curl -s -X POST \
  -H "Authorization: Token $REPLICATE_API_TOKEN" \
  -H "Content-Type: application/json" \
  -d '{"destination": "{username}/llama2-summarizer", "input": {"train_data": "https://storage.googleapis.com/dan-scratch-public/fine-tuning/70k_samples.jsonl"}, "webhook": "https://example.com/my/webhook/endpoint"}' \
  https://api.replicate.com/v1/models/meta/llama-2-7b/versions/73001d654114dad81ec65da3b834e2f691af1e1526453189b7bf36fb3f32d0f9/trainings

The API response will look like this:

{
  "id": "zz4ibbonubfz7carwiefibzgga",
  "version": "73001d654114dad81ec65da3b834e2f691af1e1526453189b7bf36fb3f32d0f9",
  "status": "starting",
  "input": {
    "train_data": "..."
  },
  "output": null,
  "error": null,
  "logs": null,
  "started_at": null,
  "created_at": "2023-03-28T21:47:58.566434Z",
  "completed_at": null
}

To learn more about creating trainings, take a look at the API documentation.

Wait for the training to finish

Once you’ve kicked off your training, visit replicate.com/trainings in your browser to monitor the progress.

If you set a webhook in the previous step, you’ll receive a POST request at your webhook URL when the training completes.

If you're not using webhooks, you can poll for the status of the training job. When `training.status` is `succeeded`, then your model has been trained.

# If you've got a handle to the object returned by create()
training.reload()

# If you've got the training ID
training = replicate.trainings.get("zz4ibbonubfz7carwiefibzgga")

if training.status == "succeeded":
    print(training.output)
    # {"weights": "...", "version": "..."}
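To block until the training finishes, you can wrap reload() in a loop. A minimal sketch that relies only on the status attribute and reload() method shown above; wait_for_training is a hypothetical helper, and the set of terminal statuses is our assumption about Replicate's status values:

```python
import time

# Assumption: once a training reaches one of these statuses, it no longer changes.
TERMINAL_STATUSES = {"succeeded", "failed", "canceled"}


def wait_for_training(training, poll_seconds=30, timeout_seconds=None):
    """Reload `training` until its status is terminal, then return it.

    `training` is the object from replicate.trainings.create() or
    replicate.trainings.get(); only .status and .reload() are used.
    Raises TimeoutError if timeout_seconds is set and exceeded.
    """
    start = time.monotonic()
    while training.status not in TERMINAL_STATUSES:
        if timeout_seconds is not None and time.monotonic() - start > timeout_seconds:
            raise TimeoutError(f"training still {training.status!r} after {timeout_seconds} s")
        time.sleep(poll_seconds)
        training.reload()
    return training
```

Call it as training = wait_for_training(training), then check that training.status == "succeeded" before reading training.output["version"]. The 30-second interval is an arbitrary choice; the trainings in this guide run for tens of minutes or longer, so there's little to gain from polling faster.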

To poll with the HTTP API instead, find the `id` in the JSON response from the previous step and use it in a follow-up API request:

curl -s \
  -H "Authorization: Token $REPLICATE_API_TOKEN" \
  https://api.replicate.com/v1/trainings/zz4ibbonubfz7carwiefibzgga

Check the `status` property of the response. When status is `succeeded`, your new model is ready to use!

Run the model

You can now run your model from the web or with an API. To use your model in the browser, go to your model page.

To use your model with the Python client, run the `version` from the training output:

# Same template as training. Don't end the first line with a backslash:
# inside a string it would remove the newline after <<SYS>>.
prompt = """[INST] <<SYS>>
Use the Input to provide a summary of a conversation.
<</SYS>>

Input:
Harry: Who are you?
Hagrid: Rubeus Hagrid, Keeper of Keys and Grounds at Hogwarts. Of course, you know all about Hogwarts.
Harry: Sorry, no.
Hagrid: No? Blimey, Harry, did you never wonder where yer parents learned it all?
Harry: All what?
Hagrid: Yer a wizard, Harry.
Harry: I-- I'm a what?
Hagrid: A wizard! And a thumpin' good 'un, I'll wager, once you've been trained up a bit. [/INST]

Summary: """

output = replicate.run(
    training.output["version"],
    input={"prompt": prompt, "stop_sequences": "</s>"},
)
# The output arrives as a sequence of text chunks; print each one as it arrives.
for s in output:
    print(s, end="", flush=True)

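To use your model with the HTTP API, set TRAINED_OUTPUT_VERSION to the `version` from the training output and create a prediction. The quoting around the variable lets the shell expand it inside the single-quoted JSON:

curl -s -X POST \
  -H "Authorization: Token $REPLICATE_API_TOKEN" \
  -H "Content-Type: application/json" \
  -d '{"version": "'"$TRAINED_OUTPUT_VERSION"'", "input": {"prompt": "How do I exit vim?"}}' \
  https://api.replicate.com/v1/predictions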

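Find the `urls.get` field in the JSON response and call that URL. It points to the prediction by its `id`:

curl -s \
  -H "Authorization: Token $REPLICATE_API_TOKEN" \
  https://api.replicate.com/v1/predictions/$PREDICTION_ID

Repeat the request until `status` is `succeeded`; the generated text is in the `output` field.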

Results

Let’s check whether the fine-tuned model really summarizes better than the base model. We’ll compare the output of the fine-tuned model to the base model on the same prompt: the Neville conversation shown in the “Prepare your training data” section.

Model: Fine-tuned on SAMSum
Output: Neville and Tina are on vacation. Neville has a suspicion that Tina is mad at him because it’s their wedding anniversary. Neville’s wedding anniversary is September 17.

Model: Base Llama-2-7b
Output: Neville: Hey guys! Does anyone know when we got married? Don: You are kidding right? Neville: Nope, it is really important for some reason. Our anniversary is coming up soon, so I want to make sure I don’t miss it. Wyatt: Hold on, let me ask my wife. Don: Ha ha, looks like he’s going to get into big trouble :) Wyatt: It was September 17th. Just in case you forgot the year too…

The fine-tuned model writes a short summary, while the base model mostly rewrites the conversation line by line.

Next steps

  • Try training with your own dataset. Take a look at the “Prepare your training data” section above for instructions on how to do that.
  • Try training with a different size of Llama 2. Larger models usually get better results, but they’re slower to train and run.