Fine-tune SDXL with your own images

Stability AI recently open-sourced SDXL, the newest and most powerful version of Stable Diffusion yet. Replicate was ready from day one with a hosted version of SDXL that you can run from the web or using our cloud API.

Today, we’re following up to announce fine-tuning support for SDXL 1.0. Fine-tuning lets you train SDXL on a particular object or style and create a new model that generates images of that object or style. For example, we fine-tuned SDXL on images from the Barbie movie and on images of our colleague Zeke. There are multiple ways to fine-tune SDXL, such as Dreambooth, LoRA (originally developed for LLMs), and Textual Inversion. We’ve got all of these covered for SDXL 1.0.

In this post, we’ll show you how to fine-tune SDXL on your own images with one line of code and publish the fine-tuned result as your own hosted public or private model. You can train your model with just a few images, and the training process takes about 10-15 minutes. You can also download your fine-tuned LoRA weights to use elsewhere.

What is fine-tuning?

Fine-tuning is a process of taking a pre-trained model and training it with more data to create a new model that is better suited to a particular task. You can fine-tune image generation models like SDXL on your own images to create a new version of the model that is better at generating images of a particular person, object, or style.

Prepare your training images

The training API expects a zip file containing your training images. A handful of images (5-6) is enough to fine-tune SDXL on a single person, but you might need more if your training subject is more complex or the images differ widely from one another.

Check out the example datasets in the SDXL repository for inspiration.

Keep the following guidelines in mind when preparing your training images:

  • Images can be of yourself, your pet, your favorite stuffed animal, or any unique object.
  • Images should contain only the subject itself, without background noise or other objects.
  • Do not use images of other people without their consent.
  • Images can be in JPEG or PNG format.
  • Dimensions and size don’t matter.
  • Filenames don’t matter.

Put your images in a folder and zip it up. The directory structure of the zip file doesn’t matter:

zip -r data.zip data
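
If you would rather build the archive from Python, the following is a minimal sketch using only the standard library. The function name `zip_training_images` and the suffix filter are our own illustration, not part of Replicate's tooling; it keeps only JPEG and PNG files, matching the guidelines above.

```python
import zipfile
from pathlib import Path

# File extensions treated as training images (JPEG or PNG, per the guidelines).
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}


def zip_training_images(image_dir, zip_path="data.zip"):
    """Zip every JPEG/PNG found under image_dir and return the archived names.

    Hypothetical helper: the name and behavior are illustrative only.
    """
    image_dir = Path(image_dir)
    images = [
        path
        for path in sorted(image_dir.rglob("*"))
        if path.is_file() and path.suffix.lower() in IMAGE_SUFFIXES
    ]
    if not images:
        raise ValueError(f"no JPEG or PNG images found in {image_dir}")

    names = []
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as archive:
        for path in images:
            # Store paths relative to the image folder.
            arcname = path.relative_to(image_dir).as_posix()
            archive.write(path, arcname)
            names.append(arcname)
    return names
```

Calling `zip_training_images("data")` would write `data.zip` in the current directory and return the list of files it included, which is a quick way to confirm nothing unexpected ended up in the archive.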

Add your Replicate API token

Before starting the training job you need to grab your Replicate API token from replicate.com/account. In your shell, store that token in an environment variable called REPLICATE_API_TOKEN.

export REPLICATE_API_TOKEN=r8_...
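
Before launching a training from a script, it can help to fail early if the token is missing. The sketch below is only an illustration: `require_replicate_token` is a hypothetical helper, and it checks nothing beyond whether the environment variable is present and non-empty.

```python
import os


def require_replicate_token(environ=os.environ):
    """Return the token from REPLICATE_API_TOKEN or raise a clear error.

    Hypothetical helper; it does not validate the token with Replicate.
    """
    token = environ.get("REPLICATE_API_TOKEN", "").strip()
    if not token:
        raise RuntimeError(
            "REPLICATE_API_TOKEN is not set; export it in your shell first"
        )
    return token
```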

Upload your training data

Upload your zip file of training data somewhere on the internet that is publicly accessible, like an S3 bucket or a GitHub Pages site.

Create a model

You also need to create a model on Replicate that will be the destination for the trained SDXL version. Go to replicate.com/create to create the model. In the example below we call it my-name/my-model.

You can make your model public or private. If your model is private, only you will be able to run it. If your model is public, anyone will be able to run it, but only you will be able to update it.

Start the training

Now that you’ve gathered your training data and created a model, it’s time to start the training process using Replicate’s API.

This guide uses Python, but if you want to use another language you can use a client library or call the HTTP API directly.

If you don’t already have a Python environment configured, you can kick off the training process using a hosted Jupyter notebook on Google Colab.

Start by installing the Replicate Python package:

pip install replicate

Then create a training:

import replicate

training = replicate.trainings.create(
    version="stability-ai/sdxl:39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b",
    input={
        "input_images": "https://my-domain/my-input-images.zip",
    },
    destination="my-name/my-model"
)
print(training)

The input_images input parameter is required, but there are other inputs you can set as well. See the training inputs in the SDXL README for a full list of inputs. Note that by default we will be using LoRA for training, and if you instead want to use Dreambooth you can set is_lora to false. If you wish to perform just the textual inversion, you can set lora_lr to 0.
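
The three modes described above differ only in the inputs you pass. As a rough sketch, the helper below builds the `input` dictionary for each one. The function name and the mode labels (`"lora"`, `"dreambooth"`, `"textual_inversion"`) are hypothetical; only `input_images`, `is_lora`, and `lora_lr` come from the text above.

```python
def build_training_input(images_url, mode="lora"):
    """Build the `input` dict for replicate.trainings.create.

    The mode labels are made up for this example; the input names
    (input_images, is_lora, lora_lr) are the ones described in the post.
    """
    inputs = {"input_images": images_url}
    if mode == "lora":
        # LoRA is the default, so no extra inputs are needed.
        pass
    elif mode == "dreambooth":
        # Turn LoRA off to use Dreambooth instead.
        inputs["is_lora"] = False
    elif mode == "textual_inversion":
        # A LoRA learning rate of 0 leaves only the textual inversion.
        inputs["lora_lr"] = 0
    else:
        raise ValueError(f"unknown mode: {mode!r}")
    return inputs
```

You could then pass the result as `input=build_training_input("https://my-domain/my-input-images.zip", "dreambooth")` in the `replicate.trainings.create` call.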

Fine-tune using Dreambooth + LoRA with faces dataset

If you’re fine-tuning on faces, the default training parameters will work well, but you can also enable the use_face_detection_instead setting. This automatically applies face segmentation so that training focuses only on the faces in your images.

import replicate

training = replicate.trainings.create(
    version="stability-ai/sdxl:39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b",
    input={
        "input_images": "https://my-domain/face-images.zip",
        "use_face_detection_instead": True,
    },
    destination="my-name/my-model"
)

Fine-tuning a style

To get the best results for a style you need to:

  • Increase the LoRA learning rate (lora_lr). This stops the training from focusing too closely on the details. Experiment with different values like 1e-4 and 2e-4. Our Barbie fine-tune used 4e-4.
  • Use a different caption_prefix to refer to a style.

import replicate

training = replicate.trainings.create(
    version="stability-ai/sdxl:39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b",
    input={
        "input_images": "https://my-domain/style-images.zip",
        "lora_lr": 2e-4,
        "caption_prefix": "In the style of TOK,",
    },
    destination="my-name/my-model"
)

To show what’s possible, we put together a couple of fine-tunes based on the Barbie and Tron: Legacy movies.

Monitor training progress

Visit replicate.com/trainings to follow the progress of your training job, or inspect the training programmatically:

training.reload()
print(training.status)
print("\n".join(training.logs.split("\n")[-10:]))
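
If you want a script to block until the training is done, one option is a simple polling loop around `training.reload()`. This is a sketch rather than an official pattern: `wait_for_training` is a hypothetical helper, the 30-second interval is arbitrary, and it treats `succeeded`, `failed`, and `canceled` as the final statuses.

```python
import time

# Statuses after which the training will not change any further.
TERMINAL_STATUSES = {"succeeded", "failed", "canceled"}


def wait_for_training(training, poll_seconds=30, sleep=time.sleep):
    """Reload the training until it reaches a final status, then return it.

    `training` is any object with a reload() method and a status attribute,
    such as the object returned by replicate.trainings.create.
    """
    while True:
        training.reload()
        if training.status in TERMINAL_STATUSES:
            return training.status
        sleep(poll_seconds)
```

For example, `status = wait_for_training(training)` returns once the job has finished, and you can check whether `status == "succeeded"` before running the model.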

Run the model

When the model has finished training you can run it using the GUI on replicate.com/my-name/my-model, or via the API:

output = replicate.run(
    "my-name/my-model:abcde1234...",
    input={"prompt": "a photo of TOK riding a rainbow unicorn"},
)

The trained concept is named TOK by default, but you can change that by setting token_string and caption_prefix inputs during the training process.

How fine-tuning works

Before fine-tuning starts, the input images are preprocessed using multiple models:

  • SwinIR upscales the input images to a higher resolution.
  • BLIP generates text captions for each input image.
  • CLIPSeg removes regions of the images that are not interesting or helpful for training.

The full list of training parameters is available in the SDXL model README.

Advanced: Using your fine-tuned model with Diffusers

If you’re using the diffusers library directly to build a custom pipeline, you can use the weights from the model you’ve trained on Replicate.

The .output field of the training object contains both a pointer to the trained version and a URL to the trained weights:

print(training.output)
# {
#   'version': 'cloneofsimo/sdxl_mixes:...',
#   'weights': 'https://pbxt.replicate.delivery/.../trained_model.tar'
# }

Download the .tar file in the weights field and untar it. Now you can load the weights with the diffusers package.
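
The unpacking step might look like the sketch below, using Python's standard `tarfile` module. `extract_weights` is a hypothetical helper, and the `training_out` directory name is chosen only to match the paths used in the loading code in this section; check the returned file list, since the archive's internal layout is not specified here.

```python
import tarfile
from pathlib import Path


def extract_weights(tar_path, out_dir="training_out"):
    """Unpack the trained weights archive and list the extracted files.

    Hypothetical helper; out_dir is an assumption chosen to match the
    "training_out/..." paths used when loading the weights.
    """
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    # Use the safer "data" extraction filter where this Python supports it.
    kwargs = {"filter": "data"} if hasattr(tarfile, "data_filter") else {}
    with tarfile.open(tar_path) as archive:
        archive.extractall(out_dir, **kwargs)
    return sorted(
        path.relative_to(out_dir).as_posix()
        for path in out_dir.rglob("*")
        if path.is_file()
    )
```

To fetch the archive first, something like `urllib.request.urlretrieve(training.output["weights"], "trained_model.tar")` should work, followed by `extract_weights("trained_model.tar")`.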

import torch
from diffusers import DiffusionPipeline
from safetensors import safe_open

# TokenEmbeddingsHandler is defined in dataset_and_utils.py from the SDXL repository
from dataset_and_utils import TokenEmbeddingsHandler

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    variant="fp16",
).to("cuda")

# Read the fine-tuned UNet tensors from the safetensors file
tensors = {}
with safe_open("training_out/unet.safetensors", framework="pt", device="cuda") as f:
    for key in f.keys():
        tensors[key] = f.get_tensor(key)

# strict=False tolerates UNet keys that are missing from the file
pipe.unet.load_state_dict(tensors, strict=False)  # should take < 2 seconds

text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
tokenizers = [pipe.tokenizer, pipe.tokenizer_2]

embhandler = TokenEmbeddingsHandler(text_encoders, tokenizers)
embhandler.load_embeddings("training_out/embeddings.pti")

Generate outputs by prompting the model according to special_params.json.

pipe(prompt="A photo of <s0><s1>").images[0].save("monster.png")
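
Rather than typing the special tokens by hand, you could read them from special_params.json. The sketch below assumes that file is a flat JSON object mapping a placeholder such as `TOK` to its token string such as `<s0><s1>`, and that it sits in the `training_out` directory; both are assumptions, so inspect your own file first. `apply_token_map` is a hypothetical helper.

```python
import json


def apply_token_map(prompt, params_path="training_out/special_params.json"):
    """Replace placeholders in the prompt with the trained token strings.

    Assumes special_params.json is a flat mapping like {"TOK": "<s0><s1>"};
    verify this against the file in your own training output.
    """
    with open(params_path) as f:
        token_map = json.load(f)
    for placeholder, replacement in token_map.items():
        prompt = prompt.replace(placeholder, replacement)
    return prompt
```

With that in place, `pipe(prompt=apply_token_map("A photo of TOK"))` lets you write prompts the same way as with the hosted model.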

Advanced: Replacing generated prompts with custom prompts for training

For most users, the captions that BLIP generates for training work quite well. However, you can provide your own captions by adding a caption.csv file to the zip file of input images provided for training. Each input image needs to have a corresponding caption. See the example CSV for the specifics of the formatting.

What’s next?

We’ll continue to make SDXL fine-tuning better over the coming weeks. Follow along on Twitter and in Discord.