afiaka87 / retrieval-augmented-diffusion

Generate 768px images from text using CompVis `retrieval-augmented-diffusion`




Run time and cost

This model runs on NVIDIA A100 (40GB) GPU hardware. Predictions typically complete within 35 seconds, but prediction time varies significantly with the inputs.



Note: This model is likely to change soon. Make sure to use the model’s unique SHA version if you use it from the API. The current stable version is: ac7878dfd12cb445115fb250a79faf3c3028e6cfa6c94788a7c9c53b0ce5898e
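If you call the model from code, one way to follow this advice is to always build an explicit, pinned model reference. The sketch below uses a hypothetical helper, `pinned_model_ref` (not part of the Replicate client), that joins the model name and version into the `owner/name:version` form and checks that the version looks like a 64-character hex SHA. Newer releases of the Replicate Python client accept that form in `replicate.run`.

```python
import re

# Hypothetical helper, not part of the Replicate client.
_SHA_RE = re.compile(r"[0-9a-f]{64}")


def pinned_model_ref(model: str, version: str) -> str:
    """Return an "owner/name:version" reference for a pinned model version."""
    if model.count("/") != 1:
        raise ValueError(f"expected 'owner/name', got {model!r}")
    if not _SHA_RE.fullmatch(version):
        raise ValueError(f"expected a 64-character hex version SHA, got {version!r}")
    return f"{model}:{version}"


STABLE_VERSION = "ac7878dfd12cb445115fb250a79faf3c3028e6cfa6c94788a7c9c53b0ce5898e"
MODEL_REF = pinned_model_ref("afiaka87/retrieval-augmented-diffusion", STABLE_VERSION)
print(MODEL_REF)
```

With a recent client you would pass `MODEL_REF` to `replicate.run(MODEL_REF, input={...})`. The exact input names depend on the model's API schema.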

CompVis’ Retrieval-Augmented Diffusion model (RDM), from the CompVis latent-diffusion codebase.

API example: Batching with the Replicate API

```python
# In an interactive (Jupyter/IPython) environment, these let you preview images inline.
from IPython.display import Image as IPythonImage
from IPython.display import display

from pathlib import Path
import shutil

import replicate
import requests
from tqdm import tqdm


def download_image(url, path):
    """Stream the file at `url` to `path` on disk."""
    with requests.get(url, stream=True) as response:
        response.raise_for_status()
        with open(path, "wb") as out_file:
            shutil.copyfileobj(response.raw, out_file)


target_dir = Path("generations")
target_dir.mkdir(parents=True, exist_ok=True)  # make sure the output folder exists

rdm_model = replicate.models.get("afiaka87/retrieval-augmented-diffusion")

subjects = ["An astronaut", "Teddy bears", "A bowl of soup"]
actions = [
    "riding a horse",
    "lounging in a tropical resort in space",
    "playing basketball with cats in space",
]
styles = [
    "in a photorealistic style",
    "in the style of Andy Warhol",
    "as a pencil drawing",
]

# Every subject/action/style combination: 3 * 3 * 3 = 27 prompts.
prompts = [
    f"{subject} {action} {style}"
    for subject in subjects
    for action in actions
    for style in styles
]

batch_size = 8
for prompt_index in tqdm(range(0, len(prompts), batch_size)):
    # Slicing past the end is safe, so the last batch simply holds the remainder.
    batch = prompts[prompt_index : prompt_index + batch_size]

    batch = "|".join(batch)  # "|" character is how text is batched on rdm
    print(f"{prompt_index} {batch}")

    # Run the model. NOTE: the input name `prompts` is an assumption; check the
    # model's API schema for the exact input names and any other options.
    generations = rdm_model.predict(prompts=batch)

    for generation_index, generation in enumerate(generations):
        target_stub = target_dir / f"{prompt_index:04d}-{generation_index:03d}"

        image_url = generation["image"]
        image_format = image_url.split(".")[-1].lower()
        image_path = target_stub.with_suffix(f".{image_format}")
        download_image(image_url, image_path)
        print(f"Downloaded {image_url} to {image_path}")

        # Display in the notebook; you may want to decrease the width for large images.
        display(IPythonImage(str(image_path), width=512))

        caption = generation["caption"]
        caption_path = target_stub.with_suffix(".txt")
        with open(caption_path, "w") as f:
            f.write(caption)
        print(f"Saved caption {caption} to {caption_path}")
```
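The batching logic in the example reduces to a small pure function: chunk the prompt list into groups of at most eight and join each group with "|". The sketch below uses a hypothetical helper, `batch_prompts` (not part of the Replicate client or the model), that you can use to check batches offline before spending predictions. The batch size of 8 mirrors the example and is not a documented model limit.

```python
from typing import Iterator, List


def batch_prompts(prompts: List[str], batch_size: int = 8, sep: str = "|") -> Iterator[str]:
    """Yield prompts in chunks of at most `batch_size`, each joined with `sep`."""
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    for prompt in prompts:
        # A separator inside a prompt would silently split it into two prompts.
        if sep in prompt:
            raise ValueError(f"prompt contains the separator {sep!r}: {prompt!r}")
    for start in range(0, len(prompts), batch_size):
        yield sep.join(prompts[start : start + batch_size])


subjects = ["An astronaut", "Teddy bears", "A bowl of soup"]
actions = [
    "riding a horse",
    "lounging in a tropical resort in space",
    "playing basketball with cats in space",
]
styles = ["in a photorealistic style", "in the style of Andy Warhol", "as a pencil drawing"]
prompts = [f"{s} {a} {st}" for s in subjects for a in actions for st in styles]

batches = list(batch_prompts(prompts))
print(len(prompts), [b.count("|") + 1 for b in batches])  # 27 [8, 8, 8, 3]
```

Rejecting prompts that already contain "|" is a design choice. Because "|" doubles as the batch delimiter, such a prompt would otherwise be split into two prompts without any error.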


Retrieval augmentation works by searching a database for images similar to the prompt you provide (by comparing CLIP embeddings), then letting the model see these retrieved neighbors during generation.

The model can see from 1 to 20 existing datapoints. Reduce this number to narrow the model’s capabilities for that prediction. Increase it to encourage abstraction and composition; however, this can cause mode collapse (for instance, “flattening” of otherwise photorealistic images), perhaps because the retrieved results are inconsistent in style (drawings, photos, etc.).
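To make the neighbor count concrete, here is a minimal sketch of a retrieval step under simplifying assumptions. The prompt and database entries are already embedded as plain vectors (stand-ins for CLIP embeddings), similarity is cosine similarity, and `k` is clamped to the 1 to 20 range described above. The function `retrieve_neighbors` and the toy vectors are hypothetical. This is a brute-force search for illustration; a real deployment would typically use an approximate nearest-neighbor index.

```python
import math
from typing import List, Sequence, Tuple

MIN_NEIGHBORS, MAX_NEIGHBORS = 1, 20  # range of neighbors the model can see


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity between two vectors (0.0 if either is all zeros)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def retrieve_neighbors(
    query: Sequence[float], database: Sequence[Sequence[float]], k: int
) -> List[Tuple[int, float]]:
    """Return (index, similarity) pairs for the k database vectors most similar to query."""
    k = max(MIN_NEIGHBORS, min(MAX_NEIGHBORS, k))  # clamp k to the supported range
    scored = [(i, cosine(query, vec)) for i, vec in enumerate(database)]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:k]


# Toy 3-d "embeddings" standing in for CLIP vectors.
database = [
    [1.0, 0.0, 0.0],
    [0.9, 0.1, 0.0],
    [0.0, 1.0, 0.0],
    [0.0, 0.0, 1.0],
]
query = [1.0, 0.05, 0.0]
print([i for i, _ in retrieve_neighbors(query, database, k=2)])  # [0, 1]
```

Raising `k` here pulls in entries 2 and 3, whose similarity to the query is near zero. This offers one plausible intuition for why a large neighbor count can mix in inconsistent material.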

Available retrieval databases:

  • simulacra: Simulacra Aesthetic Captions, by John David Pressman, Katherine Crowson, and Simulacra Captions Contributors (2022), Stability AI. (Thanks, Katherine!)
  • ~700K synthetic CLIP image embeddings, generated with the unCLIP/conditioned-prior from laion-ai.
  • laion-aesthetic: a ~4M-image subset of laion2B from laion-ai, curated using their trained aesthetic classifier.

You can view and search the dataset here; it is the one used with the default options in the sidebar.

[Image: a search for “cats” in the laion-aesthetic dataset]