Fine-tune Llama 2 on Replicate

Posted by @cbh123

Llama 2 is the first open-source language model of the same caliber as OpenAI’s models, and because it’s open source you can hack it to do new things that aren’t possible with GPT-4.

Like become a better poet. Talk like Homer Simpson. Write Midjourney prompts. Or replace your best friends.

A stampede of futuristic llamas by ai-forever/kandinsky-2.2

One of the main reasons to fine-tune models is so you can use a small model to do a task that would normally require a large model. That means doing the same task, but cheaper and faster. For example, the 7 billion parameter Llama 2 model isn’t good at summarizing text out of the box, but we can teach it how.

In this guide, we’ll show you how to create a text summarizer. We’ll be using Llama 2 7B, an open-source large language model from Meta, and fine-tuning it on a dataset of messenger-like conversations paired with summaries. When we’re done, you’ll be able to distill chat transcripts, emails, webpages, and other documents into a brief summary. Short and sweet.

This is a short guide just to get you started. If you want to dig into this in more depth or create your own dataset, take a look at our guide to fine-tuning language models on Replicate.

Supported models

Here are the Llama models on Replicate that you can fine-tune:

  • meta/llama-2-7b
  • meta/llama-2-7b-chat
  • meta/llama-2-13b
  • meta/llama-2-13b-chat
  • meta/llama-2-70b
  • meta/llama-2-70b-chat

If your model will respond to instructions from users, use one of the chat models. If you just want to complete text, use one of the base models.

Training data

Your training data should be in a JSONL text file. To learn more about constructing datasets, take a look at our full guide to fine-tuning language models.

In this guide, we’ll be using the SAMSum dataset, transformed into JSONL.
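Each line of the JSONL file is one JSON object containing a training example. Here’s a minimal sketch of producing a file in that shape — the prompt/completion keys follow our fine-tuning guide, and the example conversation is the well-known first sample from SAMSum; check the full guide if the expected schema has changed:

```python
import json

# One JSON object per line. The "prompt"/"completion" keys follow
# Replicate's language-model fine-tuning guide.
examples = [
    {
        "prompt": "Amanda: I baked cookies. Do you want some?\nJerry: Sure!\nAmanda: I'll bring you tomorrow :-)",
        "completion": "Amanda baked cookies and will bring Jerry some tomorrow.",
    },
]

with open("samsum.jsonl", "w") as f:
    for example in examples:
        f.write(json.dumps(example) + "\n")
```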

Create a model

You need to create an empty model on Replicate as the destination for your trained model. When training finishes, the result will be pushed to this model as a new version.

Go to replicate.com/create and create a new model called “llama2-summarizer”.

Authenticate

Authenticate by setting your token in an environment variable:

export REPLICATE_API_TOKEN=<paste-your-token-here>

Find your API token in your account settings.
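The Python client reads REPLICATE_API_TOKEN from the environment automatically, so a quick check before kicking off a training saves a failed API call. A tiny sketch (the helper name is ours):

```python
import os

def api_token_is_set() -> bool:
    """The replicate client reads REPLICATE_API_TOKEN from the environment."""
    return bool(os.environ.get("REPLICATE_API_TOKEN"))
```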

Create a training

Install the Python library:

pip install replicate

And kick off training, replacing the destination name with your username and the name of your new model:

import replicate

training = replicate.trainings.create(
  version="meta/llama-2-7b:73001d654114dad81ec65da3b834e2f691af1e1526453189b7bf36fb3f32d0f9",
  input={
    "train_data": "https://gist.githubusercontent.com/nateraw/055c55b000e4c37d43ce8eb142ccc0a2/raw/d13853512fc83e8c656a3e8b6e1270dd3c398e77/samsum.jsonl",
    "num_train_epochs": 3,
  },
  destination="<your-username>/llama2-summarizer"
)

print(training)

It takes these arguments:

  • version: The model to train, in the format {username}/{model}:{version}.
  • input: The training data and params to pass to the training process, which are defined by the model. Llama 2’s params can be found in the model’s “Train” tab.
  • destination: The model to push the trained version to, in the format your-username/your-model-name.

Once you’ve kicked off your training, visit replicate.com/trainings in your browser to monitor the progress.
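You can also poll the training from Python instead of the browser. A minimal sketch using the client’s trainings API — the helper name, poll interval, and terminal-state tuple are ours:

```python
import time

# States in which a Replicate training has finished running.
TERMINAL = ("succeeded", "failed", "canceled")

def wait_for_training(training_id: str, poll_seconds: int = 30):
    """Block until the training reaches a terminal state, then return it."""
    import replicate  # imported here so the helper stays self-contained
    training = replicate.trainings.get(training_id)
    while training.status not in TERMINAL:
        time.sleep(poll_seconds)
        training.reload()
    return training
```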

Run the model

You can now run your model from the web or with an API. To use your model in the browser, go to your model page.

To use your model with an API, run the version from the training output:

# Refresh the training object so the new model version is available in training.output
training.reload()

prompt = """[INST] <<SYS>>\
Use the Input to provide a summary of a conversation.
<</SYS>>

Input:
Harry: Who are you?
Hagrid: Rubeus Hagrid, Keeper of Keys and Grounds at Hogwarts. Of course, you know all about Hogwarts.
Harry: Sorry, no.
Hagrid: No? Blimey, Harry, did you never wonder where yer parents learned it all?
Harry: All what?
Hagrid: Yer a wizard, Harry.
Harry: I-- I'm a what?
Hagrid: A wizard! And a thumpin' good 'un, I'll wager, once you've been trained up a bit. [/INST]

Summary: """

output = replicate.run(
  training.output["version"],
  input={"prompt": prompt, "stop_sequences": "</s>"}
)
for s in output:
  print(s, end="", flush=True)

That’s it! You’ve fine-tuned Llama 2 and can run your new model with an API.
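If you’ll be summarizing lots of transcripts, it helps to wrap the prompt template in a helper. This sketch (the function name is ours) rebuilds the same [INST]/<<SYS>> chat-style structure used above:

```python
def build_summary_prompt(transcript: str) -> str:
    """Wrap a conversation transcript in the Llama 2 chat prompt template."""
    system = "Use the Input to provide a summary of a conversation."
    return (
        f"[INST] <<SYS>>\n{system}\n<</SYS>>\n\n"
        f"Input:\n{transcript} [/INST]\n\nSummary: "
    )
```

You can then pass build_summary_prompt(transcript) as the prompt input to replicate.run.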

Next steps

To go deeper — building your own dataset, or tuning the training parameters — check out our full guide to fine-tuning language models on Replicate.

Happy hacking! 🦙