afiaka87 / retrieval-augmented-diffusion

Generate 768px images from text using CompVis `retrieval-augmented-diffusion`

  • Public
  • 38.4K runs
  • A100 (80GB)
  • GitHub
  • Paper
  • License

Input

prompts (string)

(batched) Use up to 8 prompts by separating with a `|` character.

Default: ""
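Since up to 8 prompts can be joined with a `|` character, a small helper for batching a longer prompt list might look like this (a sketch; `batch_prompts` is a hypothetical name, not part of the API):

```python
def batch_prompts(prompts, batch_size=8):
    """Split a list of prompts into "|"-joined strings of at most
    batch_size prompts each, the format the model expects."""
    return [
        "|".join(prompts[i : i + batch_size])
        for i in range(0, len(prompts), batch_size)
    ]
```

For example, `batch_prompts(["a cat", "a dog", "a fish"], batch_size=2)` returns `["a cat|a dog", "a fish"]`.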

image_prompt (file)

(overrides `prompts`) Use an image as the prompt to generate variations of an existing image.

integer
(minimum: 1, maximum: 8)

Number of variations to generate when using only one text prompt, or an `image_prompt`.

Default: 4

database_name (string)

Which database to use for the semantic search. Different databases have different capabilities.

Default: "laion-aesthetic"

use_database (boolean)

Whether to use the database for the semantic search.

Default: true

scale (number)

Classifier-free guidance scale for the sampler.

Default: 5

num_database_results (integer)
(minimum: 1, maximum: 20)

The number of search results from the retrieval backend to guide the generation with.

Default: 10

width (integer)

Desired width of generated images. Values other than 768 are likely to cause zooming issues.

Default: 768

height (integer)

Desired height of generated images. Values other than 768 are not supported and are likely to cause artifacts.

Default: 768

steps (integer)

How many steps to run the model for. Using more will make generation take longer. 50 tends to work well.

Default: 50

ddim_sampling (boolean)

Use DDIM sampling instead of the faster PLMS sampling.

Default: false

ddim_eta (number)

The eta parameter for DDIM sampling.

Default: 0

string

(experimental) Use this prompt as a negative prompt for the sampler.

Default: ""

seed (integer)

Seed for the random number generator. Set to -1 to use a random seed.

Default: -1
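The `-1` value is resolved to a random seed by the model. To pick and record a concrete seed yourself for reproducibility (a sketch; `resolve_seed` is a hypothetical helper, and the 32-bit range is an assumption, not documented by the model):

```python
import random

def resolve_seed(seed):
    """Return seed unchanged, or draw a random one when seed == -1.
    The 32-bit range is an assumption about the server-side behavior."""
    if seed == -1:
        return random.randint(0, 2**32 - 1)
    return seed
```

Logging the resolved seed alongside each prediction lets you re-run it exactly later.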

Output

image

caption

an astronaut riding a horse in photorealistic style, digital art

This example was created by a different version, afiaka87/retrieval-augmented-diffusion:a4ecc1ee.

Run time and cost

This model costs approximately $0.048 to run on Replicate, or 20 runs per $1, but this varies depending on your inputs. It is also open source and you can run it on your own computer with Docker.

This model runs on Nvidia A100 (80GB) GPU hardware. Predictions typically complete within 35 seconds. The predict time for this model varies significantly based on the inputs.

Readme

work-in-progress

Note: This model is likely to change soon. Make sure to use the model’s unique SHA version if you use it from the API. The current stable version is: ac7878dfd12cb445115fb250a79faf3c3028e6cfa6c94788a7c9c53b0ce5898e
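As a concrete example, the pinned identifier combines the model name and the SHA above using the `owner/name:version` convention shown elsewhere on this page:

```python
MODEL = "afiaka87/retrieval-augmented-diffusion"
VERSION = "ac7878dfd12cb445115fb250a79faf3c3028e6cfa6c94788a7c9c53b0ce5898e"

# "owner/name:version" pins API calls to this exact model version,
# so later changes to the model do not silently alter your results.
pinned = f"{MODEL}:{VERSION}"
```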

CompVis’ Retrieval-Augmented Diffusion model, from the latent-diffusion repository.

API example: Batching with the Replicate API

# In an interactive environment (e.g. a notebook), you can use these for display.
from IPython.display import Image as IPythonImage
from IPython.display import display

from pathlib import Path

import replicate
import requests
import shutil
from tqdm import tqdm

def download_image(url, path):
    response = requests.get(url, stream=True)
    response.raise_for_status()
    with open(path, "wb") as out_file:
        shutil.copyfileobj(response.raw, out_file)
    del response


target_dir = Path("generations")
target_dir.mkdir(exist_ok=True)

rdm_model = replicate.models.get("afiaka87/retrieval-augmented-diffusion")

subjects = ["An astronaut", "Teddy bears", "A bowl of soup"]
actions = [
    "riding a horse",
    "lounging in a tropical resort in space",
    "playing basketball with cats in space",
]
styles = [
    "in a photorealistic style",
    "in the style of Andy Warhol",
    "as a pencil drawing",
]

prompts = [
    f"{subject} {action} {style}"
    for subject in subjects
    for action in actions
    for style in styles
]
for prompt_index in tqdm(range(0, len(prompts), 8)):
    # Slicing past the end of a list is safe in Python, so the final
    # batch simply contains the remaining prompts.
    batch = prompts[prompt_index : prompt_index + 8]

    batch = "|".join(batch)  # the "|" character is how text is batched on rdm
    print(f"{prompt_index} {batch}")

    # run the model
    generations = rdm_model.predict(
        prompts=batch,
        database_name="laion-aesthetic",
        scale=5.0,
        use_database=True,
        num_database_results=20,
        width=768,
        height=768,
        steps=50,
        ddim_sampling=False,
        ddim_eta=0.0,
        seed=8675309
    )
    for generation_index, generation in enumerate(generations):
        target_stub = target_dir / f"{prompt_index:04d}-{generation_index:03d}"

        image_url = generation["image"]
        image_format = image_url.split(".")[-1].lower()
        image_path = target_stub.with_suffix(f".{image_format}")
        download_image(image_url, image_path)
        print(f"Downloaded {image_url} to {image_path}")

        # display in notebook, may want to decrease width for large images
        display(IPythonImage(filename=str(image_path), width=512))

        caption = generation["caption"]
        caption_path = target_stub.with_suffix(".txt")
        with open(caption_path, "w") as f:
            f.write(caption)
        print(f"Saved caption {caption} to {caption_path}")

databases

Retrieval-augmentation works by searching for images similar to the prompt you provide, then allowing the model to see these during generation.

The model can see from 1 to 20 existing datapoints. Reduce this number to narrow the model's focus for a prediction; increase it to encourage abstraction and composition. Higher values, however, can result in mode collapse (a “flattening” of otherwise photorealistic images, for instance), perhaps because the retrieved results are inconsistent in style (drawing, photo, etc.).
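The retrieval step can be pictured as a nearest-neighbor search over image embeddings. The following is an illustrative sketch only (plain NumPy cosine similarity over toy vectors; the real model searches CLIP embeddings in a large pre-built index, and `retrieve_neighbors` is a hypothetical name):

```python
import numpy as np

def retrieve_neighbors(query, database, top_k=10):
    """Return indices of the top_k database embeddings most similar
    (by cosine similarity) to the query embedding."""
    q = query / np.linalg.norm(query)
    db = database / np.linalg.norm(database, axis=1, keepdims=True)
    similarities = db @ q
    return np.argsort(-similarities)[:top_k]

# Toy example: 5 random "image embeddings" and a query near entry 2.
rng = np.random.default_rng(0)
database = rng.normal(size=(5, 512))
query = database[2] + 0.01 * rng.normal(size=512)
neighbors = retrieve_neighbors(query, database, top_k=3)
```

The retrieved embeddings are what the model "sees" during generation, which is why `num_database_results` changes the character of the output.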

simulacra

John David Pressman, Katherine Crowson, and Simulacra Captions Contributors (2022). Simulacra Aesthetic Captions. Stability AI. https://github.com/JD-P/simulacra-aesthetic-captions (Thanks Katherine!)

prompt-engineer

~700K synthetic CLIP image embeddings, generated with the unCLIP/conditioned-prior from laion-ai.

aesthetic

A ~4M image subset of laion2B, the laion-aesthetic dataset from laion-ai, curated using their trained aesthetic classifier.

You can view and search the dataset here, if using the default options on the sidebar.

Screenshot: a search for “cats” in the laion-aesthetic dataset.