Fine-tune Llama 2 for English to Hinglish translation with axolotl

Posted by @nateraw

In this guide, we’ll fine-tune Meta’s Llama-2-7b for language translation on Replicate using axolotl. Our goal is to train a model that takes an English input and translates it to Hinglish, a hybrid language that blends elements of both Hindi and English. We’ll use axolotl, a library that streamlines fine-tuning language models.

The full code for this guide is available here ➡️ on GitHub.

Hinglish is widely used in the Indian subcontinent, particularly in urban areas and among younger generations who are fluent in both languages. This linguistic fusion occurs in both spoken and written forms, often mixing English words into Hindi sentences and vice versa. It’s common to see it in SMS messages, emails, and on social media.

Most language translation systems focus on Hindi written with Devanagari script, the writing system used in modern Hindi. Hinglish, however, is written in the Latin alphabet, which may be easier to read for non-native Hindi speakers learning the language (and certainly easier to type!).

For example:

  • English: Where is the book?
  • Hindi: किताब कहाँ है?
  • Hinglish: book kaha hai?

Since there aren’t many machine learning models out there for translating to and from Hinglish, let’s train one ourselves so we can impress our Hindustani friends and coworkers.

The model we train in this guide can be found here ➡️ on Replicate.

What this guide covers:

  1. How to use cog to write our own custom trainer that runs on Replicate (if you’re not already familiar with cog, you can read about it here)
  2. Preparing a custom dataset with English and Hinglish pairs
  3. Training on Replicate
  4. Inference with our fine-tuned translation model

This is a more advanced guide, so if you’re new to Replicate or fine-tuning language models, you may want to check out this guide first for a more beginner-friendly walkthrough.

Writing a custom trainer

To add training support for a model with cog, we have to write a Python script that holds the training code, and then point to it from our cog.yaml file. Here’s our cog.yaml:

# Configuration for Cog ⚙️
# Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md

build:
  # set to true if your model requires a GPU
  gpu: true
  cuda: "12.1"

  # python version in the form '3.8' or '3.8.12'
  python_version: "3.11"

  # a list of packages in the format <package-name>==<version>
  python_packages:
    - "aiohttp[speedups]"
    - "torch==2.1.1"
    - "packaging==23.2"

  # commands run after the environment is setup
  run:
    - 'git clone https://github.com/OpenAccess-AI-Collective/axolotl && pip install -e "./axolotl[flash-attn,deepspeed]"'
    - curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.5.4/pget_linux_x86_64" && chmod +x /usr/local/bin/pget

# predict.py defines how predictions are run on your model
predict: "predict.py:Predictor"
# train.py is the script that trains your model
train: "train.py:train"

We set up our environment by running the pip install command provided by axolotl, which includes support for all sorts of goodies: flash attention, deepspeed, FSDP, weight quantization with bitsandbytes, etc.

Note that we specify train: "train.py:train". This tells cog to look for a file, train.py, containing a function, train, that runs the training logic. The input arguments to the train function become the parameters you can adjust via the Replicate API when you kick off a training.

At a high level, the training script will look like this:

from cog import BaseModel, Input, Path


class TrainingOutput(BaseModel):
    weights: Path


def train(config: Path = Input(description="axolotl config file")) -> TrainingOutput:
    # run training code
    do_training()

    # zip up outputs (model weights, config, etc.) to training_output.zip
    prepare_output_zipfile()

    # Return training output, specifying the path to archived outputs
    return TrainingOutput(weights=Path("training_output.zip"))

Axolotl training is configured by a single config file, so we’ll specify it as an input to our train function. You can see some examples of these config files in the axolotl repo.

We return a TrainingOutput object, which specifies the path to our archived outputs. In our case, we’ll zip up the outputs to training_output.zip and return that. This will be uploaded to Replicate and available for download after the training is complete.

Following the usage instructions from axolotl’s documentation, you would run the following to train a model locally:

accelerate launch -m axolotl.cli.train your_config.yaml

Our training script will work by running the same command, but via subprocess.

See the full code for train.py below (or on GitHub):

train.py
import os
import subprocess
import time
from argparse import ArgumentParser
from zipfile import ZipFile

import psutil
import torch
import yaml
from cog import BaseModel, Input, Path


def zip_files(directory, output_path, file_paths):
    with ZipFile(output_path, "w") as zip:
        for file_path in file_paths:
            print(f"Adding file to {output_path}: {file_path}")
            zip.write(file_path, arcname=file_path.relative_to(directory))


os.environ["HF_HOME"] = "./hf-cache"
# Enable anonymous logging to wandb when no API key is set
os.environ["WANDB_ANONYMOUS"] = "must" if not os.environ.get("WANDB_API_KEY") else "allow"
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
# where the adapter weights will be saved
OUTPUT_DIR = "./lora-out"


class TrainingOutput(BaseModel):
    weights: Path


def train(
    config: Path = Input(description="axolotl config file"),
    mixed_precision: str = Input(default="bf16", description="Mixed precision (no,fp16,bf16,fp8)"),
) -> TrainingOutput:
    axolotl_cfg = yaml.safe_load(config.read_text())
    print(f"----- axolotl_cfg -----\n{yaml.dump(axolotl_cfg)}\n-----------------------\n")
    base_model = axolotl_cfg["base_model"].lower()
    num_gpus = torch.cuda.device_count()
    multi_gpu = num_gpus > 1
    cmd = (
        [
            "accelerate",
            "launch",
            "-m",
        ]
        + (["--multi_gpu"] if multi_gpu else [])
        + [
            f"--mixed_precision={mixed_precision}",
            f"--num_processes={num_gpus}",
            "--num_machines=1",
            "--dynamo_backend=no",
            "axolotl.cli.train",
            f"{config}",
            f"--base_model={base_model}",
            f"--output_dir={OUTPUT_DIR}",
            "--save_total_limit=1",
        ]
    )

    print("-" * 80)
    print(cmd)
    print("-" * 80)

    p = None
    try:
        p = subprocess.Popen(cmd, close_fds=False)
        p.wait()
        return_code = p.poll()
        if return_code != 0:
            raise Exception(f"Training failed with exit code {return_code}! Check logs for details")
        directory = Path(OUTPUT_DIR)
        weights_out_path = Path("training_output.zip")
        zip_files(directory, weights_out_path, sorted(f for f in directory.glob("*") if f.is_file()))
        return TrainingOutput(weights=weights_out_path)
    finally:
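        # If the training process is still running here (e.g. an error was raised or the job
        # was interrupted), terminate the accelerate launcher and any child processes it spawned.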
        if p and p.poll() is None:
            top = psutil.Process(p.pid)
            children = top.children(recursive=True)
            for process in children + [top]:
                process.terminate()
            _, alive = psutil.wait_procs(children + [top], timeout=5)
            if alive:
                for process in alive:
                    print(f"process {process.pid} survived termination")
            else:
                print("terminated all processes successfully")


def parse_args(args=None):
    parser = ArgumentParser()
    parser.add_argument("--config", type=Path, required=True, help="axolotl config file")
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="bf16",
        help="Mixed Precision config for accelerate, which launches the script. (no,fp16,bf16,fp8)",
    )
    return parser.parse_args(args=args)


if __name__ == "__main__":
    args = parse_args()
    print(args)
    train(**vars(args))

Setting up the predictor

Now that we have our training script, we can set up the cog predictor to run inference on our fine-tuned model.

Here’s a high level overview of what our predictor will look like:

from cog import BasePredictor, Input, Path


BASE_MODEL_ID = "nousresearch/llama-2-7b-hf"
PEFT_MODEL_DIR = "peft_model"

class Predictor(BasePredictor):
    def setup(self, weights: Path = None):

        # Download training_output.zip if provided + unzip it
        if weights:
            download_and_unzip(weights, PEFT_MODEL_DIR)

        # Initialize base model + tokenizer
        self.model, self.tokenizer = ...

        # Load adapter weights
        if weights:
            self.model.load_adapter(PEFT_MODEL_DIR)

    def predict(...):
        # make predictions with model
        ...

When a model is fine-tuned with our training script, its training_output.zip will be passed along to the weights argument of our predictor’s setup function. If no weights are provided, we’ll just use the base model.

Here’s the code for the predictor (which you can also find on GitHub):

predict.py
import os
from threading import Thread
from typing import Optional

import torch
from cog import BasePredictor, ConcatenateIterator, Input, Path

from utils import download_and_unzip_weights

# Set HF_HOME before importing transformers
CACHE_DIR = "./hf-cache"
os.environ["HF_HOME"] = CACHE_DIR
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
)
from peft import PeftConfig


BASE_MODEL_ID = "nousresearch/llama-2-7b-hf"
PEFT_MODEL_DIR = "peft_model"
PROMPT_TEMPLATE = "{prompt}"


class Predictor(BasePredictor):
    def setup(self, weights: Optional[Path] = None):
        print("Starting setup")
        base_model_id = BASE_MODEL_ID

        if weights:
            print(f"Weights: {weights}")
            download_and_unzip_weights(weights, PEFT_MODEL_DIR)
            config = PeftConfig.from_pretrained(PEFT_MODEL_DIR)
            base_model_id = config.base_model_name_or_path
            print(f"Overriding default Base model id {BASE_MODEL_ID} with: {base_model_id}")
        else:
            print("----- NOT USING ADAPTER MODEL -----")

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = AutoModelForCausalLM.from_pretrained(
            base_model_id,
            torch_dtype="auto",
            device_map="auto",
        )
        self.tokenizer = AutoTokenizer.from_pretrained(base_model_id)

        if weights:
            print(f"loading adapter from {PEFT_MODEL_DIR}")
            self.model.load_adapter(PEFT_MODEL_DIR)

    def predict(
        self,
        prompt: str,
        max_new_tokens: int = Input(
            description="The maximum number of tokens the model should generate as output.",
            default=512,
        ),
        temperature: float = Input(
            description="The value used to modulate the next token probabilities.", default=0.7
        ),
        do_sample: bool = Input(
            description="Whether or not to use sampling; otherwise use greedy decoding.", default=True
        ),
        top_p: float = Input(
            description="A probability threshold for generating the output. If < 1.0, only keep the top tokens with cumulative probability >= top_p (nucleus filtering). Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751).",
            default=0.95,
        ),
        top_k: int = Input(
            description="The number of highest probability tokens to consider for generating the output. If > 0, only keep the top k tokens with highest probability (top-k filtering).",
            default=50,
        ),
        prompt_template: str = Input(
            description="The template used to format the prompt before passing it to the model. For no template, you can set this to `{prompt}`.",
            default=PROMPT_TEMPLATE,
        ),
    ) -> ConcatenateIterator:
        prompt = prompt_template.format(prompt=prompt)
        print(f"=== Formatted Prompt ===\n{prompt}\n{'=' * 24}\n")
        inputs = self.tokenizer([prompt], return_tensors="pt", return_token_type_ids=False).to(self.device)
        streamer = TextIteratorStreamer(self.tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
        generate_kwargs = dict(
            **inputs,
            streamer=streamer,
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            top_k=top_k,
            num_beams=1,
            **({"temperature": temperature, "top_p": top_p} if do_sample else {}),
        )
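        # Run generation in a background thread; the streamer yields decoded tokens
        # as they are produced, so we can stream output back to the caller.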
        t = Thread(target=self.model.generate, kwargs=generate_kwargs)
        t.start()
        for text in streamer:
            yield text

# For local inference, hard coding in the translation prompt template
_prompt_template = """\
### System:
Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
Translate the input from English to Hinglish

### Input:
{prompt}

### Response:
"""


if __name__ == "__main__":
    p = Predictor()
    p.setup(weights=Path("training_output.zip"))
    for text in p.predict(
        "What time is the game tomorrow?",
        max_new_tokens=512,
        temperature=0.0,
        do_sample=False,
        top_p=0.95,
        top_k=50,
        prompt_template=_prompt_template,
    ):
        print(text, end="")

Preparing a custom dataset

📝 Note: If you’re following along and want to use the same dataset as ours, you can skip this section and just use the one we already pushed to the Hugging Face Hub: nateraw/axolotl-english-to-hinglish.

As mentioned earlier, axolotl training is configured by a single config file. This config file specifies (among other things) the dataset to use for training. To train our translation model, we’ll need to prepare a dataset with English and Hinglish pairs and make sure it’s in a format that axolotl can work with.

Thankfully, there’s a nice dataset available on the Hugging Face Hub, findnitai/english-to-hinglish, that includes ~189,000 English-Hinglish pairs from various real and synthetic sources. We just have to convert it to the format axolotl expects.

The format we’ll be using is the alpaca format, which expects instruction, input, and output columns. An example dataset in this format used by axolotl is mhenrichsen/alpaca_2k_test, which we’ll use as a reference. The instruction column is a string describing the task, while the input and output columns hold the task’s input and output, respectively (so in our case, the input column will be English and the output column will be Hinglish).
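
For reference, a single reformatted training example will look roughly like this (the values below reuse the earlier English/Hinglish example and are illustrative, not an actual row from the dataset):

# One alpaca-style record: an instruction, an English input, and a Hinglish output.
# The original dataset's "source" column is carried along unchanged.
example = {
    "instruction": "Translate the input from English to Hinglish",
    "input": "Where is the book?",
    "output": "book kaha hai?",
}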

Prompt format

During training, the model will see examples in the following format:

### System:
Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
Translate the input from English to Hinglish

### Input:
{input}

### Response:
{output}

Then, at inference time, we’ll use the following template to prompt the model for a translation (including a newline at the end). The next token the model generates will be the start of the translation.

### System:
Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
Translate the input from English to Hinglish

### Input:
{prompt}

### Response:

Reformatting the dataset

We’ll use the datasets library to reformat the dataset.

from datasets import load_dataset

ds = load_dataset("findnitai/english-to-hinglish", split="train")
ds = ds.train_test_split(test_size=0.055, seed=1234)

# remap input/output cols, add instruction feature, then reorder columns
ds = ds.map(lambda ex: ex['translation'], remove_columns=['translation']).rename_columns({'en': 'input', 'hi_ng': 'output'})
ds = ds.map(lambda ex: {'instruction': "Translate the input from English to Hinglish", **ex})
ds = ds.select_columns(["instruction", "input", "output", "source"])
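
To sanity-check the transformation before pushing, you can print a sample (a quick check, not part of the original script):

# Each split should now have instruction / input / output / source columns
print(ds)
print(ds["train"][0])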

Then, we’ll push the dataset to the Hugging Face Hub so we can access it from our training script. You’ll need to be authenticated with the Hugging Face Hub to do this.
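
If you haven’t logged in yet, one way to authenticate is from Python (you could also run huggingface-cli login in your terminal):

from huggingface_hub import login

# Prompts for a Hugging Face access token with write access
login()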

from huggingface_hub import create_repo

repo_name = "axolotl-english-to-hinglish"
repo_url = create_repo(repo_name, repo_type="dataset", exist_ok=True)
repo_id = repo_url.repo_id
ds.push_to_hub(repo_id)

You can check out the prepared dataset here.
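
If you’d like to double-check it from Python, you can load it straight back from the Hub:

from datasets import load_dataset

ds = load_dataset("nateraw/axolotl-english-to-hinglish")
print(ds)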

Running locally

To run everything locally, you can clone the example repo, cd axolotl-training-minimal, and run:

cog run python train.py --config config/debug.yaml

This will produce a training_output.zip file in the current directory, which you can then use to run inference with the predictor:

cog run python predict.py

Pushing the model to Replicate

To push our trainer to Replicate, we first need to create a new model on Replicate to push the image to. The one from this guide is here.

Then, we’ll run cog push to push the model to Replicate:

cog push r8.im/nateraw/axolotl-trainer-llama-2-7b

Train the model on Replicate

Now that we’ve pushed our model, we can launch a training using Replicate’s Python library.

First, we’ll need to authenticate by setting our Replicate API token as an environment variable. Sign up for Replicate or log in if you haven’t already. Then, you can find your API token on your account page.

export REPLICATE_API_TOKEN=<paste-your-token-here>

Then, we can define our axolotl config file and kick off a training. Here’s our config file:

config.yaml
base_model: NousResearch/Llama-2-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true

load_in_8bit: true
load_in_4bit: false
strict: false

datasets:
  - path: nateraw/axolotl-english-to-hinglish
    type: alpaca
dataset_prepared_path:
val_set_size: 0
output_dir: ./lora-out

sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true

adapter: lora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:

wandb_project: axolotl-english-to-hinglish
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 1
micro_batch_size: 10
num_epochs: 3
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

######DEBUGGING########################
# max_steps: 8
# warmup_steps: 0
#
# Note: comment out warmup_steps below
#######################################

warmup_steps: 10
evals_per_epoch: 1
eval_table_size:
eval_table_max_new_tokens: 128
saves_per_epoch: 0
save_total_limit: 0
debug: true
deepspeed: deepspeed/zero2.json
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
  bos_token: "<s>"
  eos_token: "</s>"
  unk_token: "<unk>"

From Python, we can run the following, which will create a destination model (the name of the fine-tuned model) if it doesn’t already exist, and then kick off a training:

import replicate
from replicate.exceptions import ReplicateException

dest_model_id = "nateraw/axolotl-llama-2-7b-english-to-hinglish"

try:
    replicate.models.create(
        owner=dest_model_id.split("/")[0],
        name=dest_model_id.split("/")[1],
        visibility="public",
        hardware="gpu-a40-large",  # This is the hardware used for inference
    )
except ReplicateException:
    print("Model already exists")

training = replicate.trainings.create(
    version="nateraw/axolotl-trainer-llama-2-7b:895d22a6364e09e25fc2ba5ce8c23a0f8e957f7ac69620ab8dd21a07cde2cfb5",
    input={
        "config": open("config.yaml", "rb"),
    },
    destination=dest_model_id,
)
print(training)

If you check out the trainings dashboard, you should see your training’s status. You can click on it to see the logs and monitor its progress. When it’s finished, you can run inference on it from the web or with an API, or download the weights and run inference locally.
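
You can also check on the training from Python while it runs. Here’s a small sketch using the same client, where training is the object returned by replicate.trainings.create above:

import replicate

# Re-fetch the training by id to check its current status
training = replicate.trainings.get(training.id)
print(training.status)  # e.g. "starting", "processing", "succeeded", or "failed"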

Run the model

Here’s the trained model from this guide, which you can use to translate English to Hinglish.

Note that the prompt_template in our trained model’s default example is set to match the template used during training. If you’re using a different dataset than the one in this guide, you’ll want to update the prompt template with the correct one given your training dataset configuration.

To make predictions with Python using the Replicate Python library, you can run:

import replicate

output = replicate.run(
    "nateraw/axolotl-llama-2-7b-english-to-hinglish:03c8cc6582309c28ec5fdea84c94f49085fb105a1137f4771525376a88d8d95f",
    input={
        "prompt": "What's happening?",
        "do_sample": False,
    }
)
# The nateraw/axolotl-llama-2-7b-english-to-hinglish model can stream output as it's running.
# The predict method returns an iterator, and you can iterate over that output.
for item in output:
    # https://replicate.com/nateraw/axolotl-llama-2-7b-english-to-hinglish/api#output-schema
    print(item, end="")
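
Because our predictor exposes a prompt_template input, you can also pass the training-time template explicitly instead of relying on the model’s default. For example:

prompt_template = """\
### System:
Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
Translate the input from English to Hinglish

### Input:
{prompt}

### Response:
"""

output = replicate.run(
    "nateraw/axolotl-llama-2-7b-english-to-hinglish:03c8cc6582309c28ec5fdea84c94f49085fb105a1137f4771525376a88d8d95f",
    input={
        "prompt": "What's happening?",
        "do_sample": False,
        "prompt_template": prompt_template,
    },
)
print("".join(output))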

See other ways to run inference (Node.js, cURL, Docker) with the model here.

Next Steps

This was an advanced guide to setting up a custom language model trainer on Replicate using axolotl. Here are some additional resources you may find helpful: