Using open-source models for faster and cheaper text embeddings

Posted by @nateraw

Embeddings are a powerful tool for working with text. By “embedding” text into vectors, you encode its meaning into a representation that can more easily be used for tasks like semantic search, clustering, and classification. If you’re new to embeddings, check out this awesome introduction by Simon Willison to get up to speed. These days, embeddings are being used for even more interesting applications like Retrieval Augmented Generation, which uses semantic search over embeddings to improve the quality of responses from language models.

In this guide, we’ll see how to use the BAAI/bge-large-en-v1.5 model on Replicate to generate text embeddings. The “BAAI General Embedding” (BGE) suite of models, released by the Beijing Academy of Artificial Intelligence (BAAI), is open source and available on the Hugging Face Hub.

As of October 2023, the large BGE model we’ll use here is the current state-of-the-art open source model for text embeddings. It is ranked higher than OpenAI embeddings on the MTEB leaderboard, and is 4x cheaper to run on Replicate for large-scale text embedding (more on this later!).

👇 The code in this post is also available as a hosted, interactive Google Colab notebook.

Prerequisites

You’ll need:

  • An account on Replicate: You’ll use Replicate to run the BGE model. It’s free to get started, and you get a bit of credit when you sign up. After that, you pay per second for your usage. See how billing works for more details.
  • A Python environment to follow along in (or you can use the Google Colab notebook instead).

👀 See the model in the Replicate UI here, and more ways to run it (Node.js, cURL, Docker, etc.) here.

Install the dependencies

Start by installing the following dependencies:

pip install replicate

# to count tokens:
pip install transformers sentencepiece

# for our example "samsum" dataset:
pip install datasets py7zr scikit-learn

Authenticate with Replicate

Grab a Replicate API token from replicate.com/account/api-tokens and set it as an environment variable:

export REPLICATE_API_TOKEN=...
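
If you’re following along in a notebook rather than a shell (for example, the Colab linked above), you can set the same variable from Python instead. The value below is just a placeholder; paste in your own token.

import os

# The replicate client reads this environment variable for authentication
os.environ["REPLICATE_API_TOKEN"] = "..."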

Generate embeddings from a list of text

Now you can generate embeddings. We’ll use the replicate Python library to run the model:

import json
import replicate

texts = [
    "the happy cat",
    "the quick brown fox jumps over the lazy dog",
    "lorem ipsum dolor sit amet",
    "this is a test",
]

output = replicate.run(
    "nateraw/bge-large-en-v1.5:9cf9f015a9cb9c61d1a2610659cdac4a4ca222f2d3707a68517b18c198a9add1",
    input={"texts": json.dumps(texts)}
)
print(output)

The output will be a list of embeddings, one for each input text.
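
To get a feel for what those embeddings are good for, here’s a small sketch that ranks the example texts by similarity to the first one. It assumes output is a plain list of float vectors, one 1024-dimensional vector per input for this model; the ranking itself is just cosine similarity.

import numpy as np

# Assumes `output` is a list of embedding vectors, one per input text
embeddings = np.array(output)
print(embeddings.shape)  # expected to be (4, 1024) for bge-large-en-v1.5

# Rank the texts by cosine similarity to the first one ("the happy cat")
query = embeddings[0]
sims = embeddings @ query / (np.linalg.norm(embeddings, axis=1) * np.linalg.norm(query))
for text, sim in sorted(zip(texts, sims), key=lambda pair: -pair[1]):
    print(f"{sim:.3f}  {text}")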

Generate embeddings from a JSONL file

JSONL (or “JSON lines”) is a file format for storing structured data in a text-based, line-delimited format. Each line in the file is a standalone JSON object.

Here’s an example of a JSONL file, dummy_example.jsonl:

{"text": "the happy cat"}
{"text": "the quick brown fox jumps over the lazy dog"}
{"text": "lorem ipsum dolor sit amet"}
{"text": "this is a test"}

Run the model on this file by passing it as the path input:

output = replicate.run(
    "nateraw/bge-large-en-v1.5:9cf9f015a9cb9c61d1a2610659cdac4a4ca222f2d3707a68517b18c198a9add1",
    input={"path": open("dummy_example.jsonl", "rb")}
)
len(output)
# Output:
# 4

Real-world example: Embedding the SAMSum dataset

The SAMSum dataset is a collection of ~14k example dialogues with manually annotated summaries. It is often used for training and evaluating language models.

Here we’ll encode the whole SAMSum dataset. We’ll use the datasets library to load the dataset, convert it to a JSONL file, and then run the BGE model on it to generate text embeddings.

from pathlib import Path

from datasets import load_dataset

dataset_name = "samsum"
text_field = "dialogue"
outfile_name = "samsum_dialogue.jsonl"

ds = load_dataset(dataset_name, split='train')
ds = ds.remove_columns([x for x in ds.column_names if x != text_field])
ds = ds.rename_column(text_field, "text")
texts = ds["text"]
texts[0]
# Output:
# "Amanda: I baked cookies. Do you want some?\r\nJerry: Sure!\r\nAmanda: I'll bring you tomorrow :-)"

To convert the dataset to a JSONL file, call .to_json on the dataset.

ds.to_json(outfile_name)

If all goes well, the dataset should be written to samsum_dialogue.jsonl. Use the head command to see the first few lines of the file:

head -n 5 samsum_dialogue.jsonl

You should see the following:

{"text":"Amanda: I baked  cookies. Do you want some?\r\nJerry: Sure!\r\nAmanda: I'll bring you tomorrow :-)"}
{"text":"Olivia: Who are you voting for in this election? \r\nOliver: Liberals as always.\r\nOlivia: Me too!!\r\nOliver: Great"}
{"text":"Tim: Hi, what's up?\r\nKim: Bad mood tbh, I was going to do lots of stuff but ended up procrastinating\r\nTim: What did you plan on doing?\r\nKim: Oh you know, uni stuff and unfucking my room\r\nKim: Maybe tomorrow I'll move my ass and do everything\r\nKim: We were going to defrost a fridge so instead of shopping I'll eat some defrosted veggies\r\nTim: For doing stuff I recommend Pomodoro technique where u use breaks for doing chores\r\nTim: It really helps\r\nKim: thanks, maybe I'll do that\r\nTim: I also like using post-its in kaban style"}
{"text":"Edward: Rachel, I think I'm in ove with Bella..\r\nrachel: Dont say anything else..\r\nEdward: What do you mean??\r\nrachel: Open your fu**ing door.. I'm outside"}
{"text":"Sam: hey  overheard rick say something\r\nSam: i don't know what to do :-\/\r\nNaomi: what did he say??\r\nSam: he was talking on the phone with someone\r\nSam: i don't know who\r\nSam: and he was telling them that he wasn't very happy here\r\nNaomi: damn!!!\r\nSam: he was saying he doesn't like being my roommate\r\nNaomi: wow, how do you feel about it?\r\nSam: i thought i was a good rommate\r\nSam: and that we have a nice place\r\nNaomi: that's true man!!!\r\nNaomi: i used to love living with you before i moved in with me boyfriend\r\nNaomi: i don't know why he's saying that\r\nSam: what should i do???\r\nNaomi: honestly if it's bothering you that much you should talk to him\r\nNaomi: see what's going on\r\nSam: i don't want to get in any kind of confrontation though\r\nSam: maybe i'll just let it go\r\nSam: and see how it goes in the future\r\nNaomi: it's your choice sam\r\nNaomi: if i were you i would just talk to him and clear the air"}

Let’s embed the dataset. This time we’ll specify convert_to_numpy=True to get the embeddings as a numpy array, which is a more efficient output format for such a large dataset.

import time

start = time.time()
output = replicate.run(
    "nateraw/bge-large-en-v1.5:9cf9f015a9cb9c61d1a2610659cdac4a4ca222f2d3707a68517b18c198a9add1",
    input=dict(
        path=open(outfile_name, "rb"),
        convert_to_numpy=True,
        batch_size=64
    )
)
time_to_embed = time.time() - start
print(f"that took {time_to_embed:.2f} seconds.")
print("output", output)

# Output:
# that took 65.51 seconds.
# output https://replicate.delivery/pbxt/ZpzzGcdZf5VbCCgynfufoXww7MtymKITDa0HfAZOOVsvNNJHB/embeddings.npy

Load the predictions

Since we converted the embeddings to a numpy array, we’ll load the output file with numpy:

import requests
from io import BytesIO

import numpy as np

embeds = np.load(BytesIO(requests.get(output).content))
embeds.shape
# Output:
# (14732, 1024)
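
With the embeddings loaded, you can already put them to work. Here’s a minimal semantic-search sketch using the scikit-learn we installed earlier: it embeds an example query with the same model, then ranks the SAMSum dialogues by cosine similarity. It reuses replicate, json, np, texts, and embeds from earlier in this post, and the query string and top-k value are just illustrative.

from sklearn.metrics.pairwise import cosine_similarity

# Embed the query with the same model so the vectors live in the same space
query = "friends talking about baking cookies"
query_embed = replicate.run(
    "nateraw/bge-large-en-v1.5:9cf9f015a9cb9c61d1a2610659cdac4a4ca222f2d3707a68517b18c198a9add1",
    input={"texts": json.dumps([query])}
)

# Rank all dialogues by similarity and show the closest matches
sims = cosine_similarity(np.array(query_embed), embeds)[0]
for idx in sims.argsort()[::-1][:3]:
    print(f"{sims[idx]:.3f}  {texts[idx][:80]!r}")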

Price vs. OpenAI

At the time of this writing, OpenAI’s Ada v2 embedding model (text-embedding-ada-002) costs $0.0001 / 1K tokens.

On Replicate, you’re charged by the second for the hardware you’re running on. The nateraw/bge-large-en-v1.5 model we’re using here runs on A40 (Large) instances, which cost $0.000725/sec.

Below, we’ll compare both OpenAI and Replicate. To do so, we’ll need to count the number of tokens in the dataset. We’ll use the transformers library to do this:

from datasets import Dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-large-en-v1.5")

text = """\
Lorem ipsum dolor sit amet, consectetur adipiscing elit, \
sed do eiusmod tempor a b
""" * 16  # Not long enough, need >= 512 tokens, so multiply by 16

ds = Dataset.from_dict({"text": [text] * 10000})

def count_tokens(ex):
    ex['num_tokens'] = len(tokenizer.encode(ex["text"], truncation=True, add_special_tokens=False))
    return ex

ds = ds.map(count_tokens)

In the snippet above, we prepare a benchmark dataset with 512 tokens per example, the maximum sequence length supported by the BGE model. In total, the dataset has 5,120,000 tokens. Let’s double-check that:

total_tokens = sum(ds['num_tokens'])
total_tokens
# Output:
# 5120000

Finally, we’ll write this dataset to a JSONL file, just as we did earlier.

outfile_name = "benchmark.jsonl"
ds.to_json(outfile_name)

Run the benchmark

Now we’ll run the benchmark. We’ll use replicate.predictions.create to run the model asynchronously. This returns a Prediction object, which we can use to get the results of the run as well as its associated metrics. We can then use the predict_time metric to calculate the price of the run.

model = replicate.models.get("nateraw/bge-large-en-v1.5")
version = model.latest_version
prediction = replicate.predictions.create(
    version,
    input=dict(
        path=open(outfile_name, "rb"),
        convert_to_numpy=True,
        batch_size=64
    )
)
prediction.wait()
output = prediction.output
time_to_embed = prediction.metrics['predict_time']
print(f"that took {time_to_embed:.2f} seconds.")
print("output", output)
# Output:
# that took 151.92 seconds.
# output https://replicate.delivery/pbxt/VVrkEaiaem3uHCzAqAOmCaewTobbvrmA20QNpJo8tE39VTyRA/embeddings.npy

Let’s see what the price of this run would have been using the OpenAI API:

openai_cost = 0.0001  # per 1k tokens
openai_price = total_tokens / 1000 * openai_cost
print(f"OpenAI price: ${openai_price:.3f} USD")
# OpenAI price: $0.512 USD

And the price on Replicate:

replicate_price = time_to_embed * 0.000725
print(f"Replicate price: ${replicate_price:.3f}")
# Replicate price: $0.110
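
Another way to look at it is to express this Replicate run as an effective per-token price, using only the numbers we just computed (these figures depend on the measured predict_time, so they’ll vary a bit from run to run):

# Effective price per 1K tokens for this run on Replicate
replicate_per_1k = replicate_price / (total_tokens / 1000)
print(f"Replicate effective price: ${replicate_per_1k:.7f} / 1K tokens")
# ~$0.0000215 / 1K tokens

print(f"OpenAI would cost ~{openai_price / replicate_price:.1f}x more for this workload")
# ~4.6x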

The price on Replicate is more than 4x cheaper than OpenAI, and that’s with a model ranked higher on the MTEB leaderboard. 🎉

Next steps

If you enjoyed this post and want to see a more in-depth example of using this text embedding model in the wild, check out this blog post by @jakedahn, which covers how to do Retrieval Augmented Generation (RAG) with ChromaDB and Mistral.

Happy hacking!