Train and deploy a DreamBooth model on Replicate

Posted by @bfirsh and @zeke

Update, August 2023: We've added fine-tuning support to SDXL, the latest version of Stable Diffusion. The DreamBooth API described below still works, but you can achieve better results at a higher resolution using SDXL. Check out the SDXL fine-tuning blog post to get started, or read on to use the old DreamBooth API.

Generative AI has been abuzz with DreamBooth. It’s a way to train Stable Diffusion on a particular object or style, creating your own version of the model that generates those objects or styles. You can train a model with as few as three images and the training process takes less than half an hour.

Notably, DreamBooth works with people, so you can make a version of Stable Diffusion that can generate images of yourself.

People have been making some magical products with DreamBooth, such as Avatar AI and ProfilePicture.AI.

Now, you can create your own projects with DreamBooth too. We’ve built an API that lets you train DreamBooth models and run predictions on them in the cloud.

You need as few as three training images, and training takes about 20 minutes, depending on how many training steps (max_train_steps) you use. It costs about $2.50 to train a model.

Train a DreamBooth model

First, grab your API token and set it in a terminal:

export REPLICATE_API_TOKEN=...

Next, gather your training data as a set of JPEGs in a directory called data/ and zip it up:

zip -r data.zip data
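If you prefer to script this step, here is a minimal Python sketch of the same thing using the standard library's zipfile module. It assumes your images end in .jpg or .jpeg, and zip_training_images is a hypothetical helper name, not part of any Replicate tool.

```python
import zipfile
from pathlib import Path


def zip_training_images(data_dir: str = "data", zip_path: str = "data.zip") -> list:
    """Zip the JPEGs in data_dir into zip_path, roughly like `zip -r data.zip data`.

    Returns the archive member names so you can check what was included.
    """
    root = Path(data_dir)
    # Collect .jpg/.jpeg files at any depth (case-insensitive), sorted for a stable archive.
    images = sorted(
        p for p in root.rglob("*")
        if p.is_file() and p.suffix.lower() in {".jpg", ".jpeg"}
    )
    if not images:
        raise ValueError(f"no JPEG files found in {root}")
    # Store members as data/<file>, matching the layout that `zip -r` produces.
    names = [p.relative_to(root.parent).as_posix() for p in images]
    with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
        for path, name in zip(images, names):
            zf.write(path, arcname=name)
    return names
```

Calling zip_training_images() from the directory that contains data/ writes data.zip next to it and returns member names such as data/photo1.jpg.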

Put this zip file somewhere accessible via HTTP. If you like, you can use our API for uploading files. Run these three commands (they use jq to parse the JSON responses):

RESPONSE=$(curl -X POST -H "Authorization: Bearer $REPLICATE_API_TOKEN" https://dreambooth-api-experimental.replicate.com/v1/upload/data.zip)

curl -X PUT -H "Content-Type: application/zip" --upload-file data.zip "$(jq -r ".upload_url" <<< "$RESPONSE")"

SERVING_URL=$(jq -r ".serving_url" <<< "$RESPONSE")
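The same upload flow can be sketched in Python with the standard library's urllib.request. This assumes the upload endpoint behaves exactly as the curl commands show: a POST returns JSON containing upload_url and serving_url, and the file is then PUT to upload_url. The function name upload_zip is hypothetical.

```python
import json
import os
import urllib.request

UPLOAD_ENDPOINT = "https://dreambooth-api-experimental.replicate.com/v1/upload/"


def upload_zip(zip_path: str = "data.zip") -> str:
    """Upload a zip via the DreamBooth upload API and return its serving URL."""
    token = os.environ["REPLICATE_API_TOKEN"]
    filename = os.path.basename(zip_path)
    # Step 1: ask for a one-off upload URL and the URL the file will be served from.
    request = urllib.request.Request(
        UPLOAD_ENDPOINT + filename,
        method="POST",
        headers={"Authorization": f"Bearer {token}"},
    )
    with urllib.request.urlopen(request) as response:
        urls = json.load(response)
    # Step 2: PUT the zip file's bytes to the upload URL (no API token needed here).
    with open(zip_path, "rb") as f:
        put = urllib.request.Request(
            urls["upload_url"],
            data=f.read(),
            method="PUT",
            headers={"Content-Type": "application/zip"},
        )
    with urllib.request.urlopen(put):
        pass
    # Step 3: this is the URL to pass as instance_data.
    return urls["serving_url"]
```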

Then, start a training job:

curl -X POST \
    -H "Authorization: Bearer $REPLICATE_API_TOKEN" \
    -H "Content-Type: application/json" \
    -d '{
            "input": {
                "instance_prompt": "a photo of a cjw person",
                "class_prompt": "a photo of a person",
                "instance_data": "'"$SERVING_URL"'",
                "max_train_steps": 2000
            },
            "model": "yourusername/yourmodel",
            "trainer_version": "cd3f925f7ab21afaef7d45224790eedbb837eeac40d22e8fefe015489ab644aa",
            "webhook_completed": "https://example.com/dreambooth-webhook"
        }' \
    https://dreambooth-api-experimental.replicate.com/v1/trainings
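Building the JSON body in code avoids the '"$SERVING_URL"' quoting trick the shell version needs. Here is a minimal Python sketch using only the standard library; it assumes the trainings endpoint accepts a JSON body with a Content-Type: application/json header and returns the training object as JSON. build_training_request and start_training are hypothetical helper names.

```python
import json
import os
import urllib.request
from typing import Optional

TRAININGS_URL = "https://dreambooth-api-experimental.replicate.com/v1/trainings"


def build_training_request(
    instance_data: str,
    model: str,
    trainer_version: str,
    instance_prompt: str = "a photo of a cjw person",
    class_prompt: str = "a photo of a person",
    max_train_steps: int = 2000,
    webhook_completed: Optional[str] = None,
) -> dict:
    """Assemble the body for POST /v1/trainings, with the same fields as the curl example."""
    body = {
        "input": {
            "instance_prompt": instance_prompt,
            "class_prompt": class_prompt,
            "instance_data": instance_data,
            "max_train_steps": max_train_steps,
        },
        "model": model,
        "trainer_version": trainer_version,
    }
    # webhook_completed is optional, so it is only included when given.
    if webhook_completed is not None:
        body["webhook_completed"] = webhook_completed
    return body


def start_training(body: dict) -> dict:
    """POST the training request and return the API's JSON response."""
    request = urllib.request.Request(
        TRAININGS_URL,
        data=json.dumps(body).encode("utf-8"),
        method="POST",
        headers={
            "Authorization": f"Bearer {os.environ['REPLICATE_API_TOKEN']}",
            "Content-Type": "application/json",
        },
    )
    with urllib.request.urlopen(request) as response:
        return json.load(response)
```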

You need to set:

  • instance_prompt: The prompt that describes your training images, in the format a [identifier] [class noun], where identifier is a rare token. The example above uses cjw, but you can use any string you like. For best results, use an identifier of three Unicode characters, without spaces.
  • class_prompt: A prompt describing the broader category of images you’re training on, in the format a [class noun]. The trainer uses it to generate other images of that class alongside your training data, which helps the model avoid overfitting.
  • instance_data: The URL to your training data.
  • max_train_steps: The number of training steps to run. Fewer steps train faster but typically give lower-quality results; more steps take longer but typically improve quality.
  • model: A name to give your model on Replicate, in the form username/modelname. For example, bfirsh/bfirshbooth. Replicate automatically creates the model for you if it doesn’t exist yet.
  • trainer_version: The version of DreamBooth and Stable Diffusion to use. See the “Versions” section below for more details.
  • webhook_completed: A webhook to call when the job finishes. (Optional.)
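The prompt and model-name conventions above are easy to get subtly wrong, so a small helper can build and check them before you submit a job. This is a hypothetical sketch, not part of Replicate's API: it follows the "a photo of a [identifier] [class noun]" wording of the example, and its model-name check only enforces the username/modelname shape; Replicate's actual naming rules may be stricter.

```python
from typing import Tuple


def make_prompts(identifier: str, class_noun: str) -> Tuple[str, str]:
    """Return (instance_prompt, class_prompt) in the style of the example above."""
    # The identifier should be a single rare token, so reject empty strings and whitespace.
    if not identifier or any(ch.isspace() for ch in identifier):
        raise ValueError("identifier must be non-empty and contain no spaces")
    instance_prompt = f"a photo of a {identifier} {class_noun}"
    class_prompt = f"a photo of a {class_noun}"
    return instance_prompt, class_prompt


def check_model_name(model: str) -> None:
    """Raise ValueError unless model looks like username/modelname."""
    owner, sep, name = model.partition("/")
    # Exactly one slash, both halves non-empty, and no whitespace anywhere.
    if not (sep and owner and name) or "/" in name or any(ch.isspace() for ch in model):
        raise ValueError(f"expected username/modelname, got {model!r}")
```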

Behind the scenes, a training job runs the replicate/dreambooth model, so any input that model accepts can be passed in the input object.

The API responds with this object:

{
  "id": "rrr4z55ocneqzikepnug6xezpe",
  "input": {
    "instance_prompt": "a photo of a cjw person",
    "class_prompt": "a photo of a person",
    "instance_data": "https://replicate.delivery/pbxt/HoUeWsrtTTCJEpKGdLKqIYTfo8nbUTSNs565MkGxEstjfwKt/data.zip",
    "max_train_steps": 2000
  },
  "model": "yourusername/yourmodel",
  "status": "starting",
  "trainer_version": "cd3f925f7ab21afaef7d45224790eedbb837eeac40d22e8fefe015489ab644aa",
  "webhook_completed": "https://example.com/dreambooth-webhook"
}

You can get the status of the training job by calling GET /v1/trainings/<id>:

curl -H "Authorization: Bearer $REPLICATE_API_TOKEN" \
  https://dreambooth-api-experimental.replicate.com/v1/trainings/rrr4z55ocneqzikepnug6xezpe
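If you aren't using a webhook, you can poll this endpoint until the job finishes. A minimal Python sketch follows. It assumes "succeeded", "failed", and "canceled" are the final statuses (only "starting" and "succeeded" appear in this post; the other two mirror Replicate's prediction statuses), and get_training and wait_for_training are hypothetical names.

```python
import json
import os
import time
import urllib.request
from typing import Callable

TRAININGS_URL = "https://dreambooth-api-experimental.replicate.com/v1/trainings"
# Assumed final states; any other status is treated as "still running".
TERMINAL_STATUSES = {"succeeded", "failed", "canceled"}


def get_training(training_id: str) -> dict:
    """GET /v1/trainings/<id> and return the training object."""
    request = urllib.request.Request(
        f"{TRAININGS_URL}/{training_id}",
        headers={"Authorization": f"Bearer {os.environ['REPLICATE_API_TOKEN']}"},
    )
    with urllib.request.urlopen(request) as response:
        return json.load(response)


def wait_for_training(
    training_id: str,
    poll_seconds: float = 30.0,
    fetch: Callable[[str], dict] = get_training,
    sleep: Callable[[float], None] = time.sleep,
) -> dict:
    """Poll until the training reaches a final status, then return it.

    fetch and sleep are parameters so the loop can be tested without network or waiting.
    """
    while True:
        training = fetch(training_id)
        if training["status"] in TERMINAL_STATUSES:
            return training
        sleep(poll_seconds)
```

Once wait_for_training returns a training whose status is "succeeded", its version field holds the ID you need to run the model.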

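The endpoint responds with the same training object, updated with the job’s current status. Once training has succeeded, the object also includes version, the ID of your trained model version: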

{
  "id": "rrr4z55ocneqzikepnug6xezpe",
  "input": {
    "instance_prompt": "a photo of a cjw person",
    "class_prompt": "a photo of a person",
    "instance_data": "https://replicate.delivery/pbxt/HoUeWsrtTTCJEpKGdLKqIYTfo8nbUTSNs565MkGxEstjfwKt/data.zip",
    "max_train_steps": 2000
  },
  "model": "yourusername/yourmodel",
  "status": "succeeded",
  "trainer_version": "cd3f925f7ab21afaef7d45224790eedbb837eeac40d22e8fefe015489ab644aa",
  "webhook_completed": "https://example.com/dreambooth-webhook",
  "version": "8abccf52e7cba9f6e82317253f4a3549082e966db5584e92c808ece132037776"
}

This same object is also what gets sent to your webhook_completed URL when the job finishes.
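If you set webhook_completed, you need an HTTP endpoint to receive that object. Here is a minimal standard-library sketch, assuming the webhook arrives as a POST with the training object as its JSON body. It doesn't verify that requests really come from Replicate, and version_from_webhook, TrainingWebhookHandler, and serve are hypothetical names.

```python
import json
from http.server import BaseHTTPRequestHandler, HTTPServer
from typing import Optional


def version_from_webhook(body: bytes) -> Optional[str]:
    """Return the new model version ID if the training object says it succeeded."""
    training = json.loads(body)
    if training.get("status") == "succeeded":
        return training.get("version")
    return None


class TrainingWebhookHandler(BaseHTTPRequestHandler):
    # Assumes the webhook is an HTTP POST whose body is the training object as JSON.
    def do_POST(self) -> None:
        length = int(self.headers.get("Content-Length", 0))
        version = version_from_webhook(self.rfile.read(length))
        if version:
            print(f"training succeeded, new version: {version}")
        else:
            print("training finished without a new version")
        self.send_response(200)
        self.end_headers()


def serve(port: int = 8000) -> None:
    """Listen on all interfaces; your webhook_completed URL must reach this server."""
    HTTPServer(("", port), TrainingWebhookHandler).serve_forever()
```

Calling serve(8000) blocks and handles webhook requests until you stop the process.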

Run your trained model

When the training process has completed successfully, it pushes the model to Replicate.

You can run the model like any other model on Replicate, using the website or the API.

To run it on the website, go to your dashboard and click “Models”.

Your new model is private by default, and only visible to you. If you want anyone to be able to see and run your model, then you can make it public in the Settings tab on your model page.

To run the model through the API, you first need its version ID. You can find it on the “API” tab of the model page, or in the version field of the training API response.
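The identifier that replicate.run expects in the Python example below is the model name and version ID joined by a colon. A small hypothetical helper can build it from a finished training object:

```python
def model_reference(training: dict) -> str:
    """Build the "owner/model:version" string from a succeeded training object."""
    if training.get("status") != "succeeded" or not training.get("version"):
        raise ValueError("training has not succeeded yet, so there is no version to run")
    return f"{training['model']}:{training['version']}"
```

For example, you could pass model_reference(training) as the first argument to replicate.run.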

Then, you can make an API call:

curl -X POST \
    -H "Authorization: Bearer $REPLICATE_API_TOKEN" \
    -H "Content-Type: application/json" \
    -d '{
            "input": {
                "prompt": "painting of cjw by andy warhol"
            },
            "version": "8abccf52e7cba9f6e82317253f4a3549082e966db5584e92c808ece132037776"
        }' \
    https://api.replicate.com/v1/predictions

Or, with Python:

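import replicate

# The client reads your API token from the REPLICATE_API_TOKEN environment variable.
output = replicate.run(
    "yourusername/yourmodel:8abccf52e7cba9f6e82317253f4a3549082e966db5584e92c808ece132037776",
    input={"prompt": "painting of cjw by andy warhol"},
)
print(output)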

To learn more about running models on Replicate, take a look at the Python getting started guide or the HTTP API reference.

Versions

By default, DreamBooth trains a Stable Diffusion 1.5 model. This version tends to work better for DreamBooth because it can produce a wider variety of styles.

To use a different base model, set the trainer_version option to one of these supported versions:

  • Stable Diffusion 1.5: cd3f925f7ab21afaef7d45224790eedbb837eeac40d22e8fefe015489ab644aa
  • Custom Checkpoints: 9c41656f8ae2e3d2af4c1b46913d7467cd891f2c1c5f3d97f1142e876e63ed7a
  • Stable Diffusion 2.1-base: d5e058608f43886b9620a8fbb1501853b8cbae4f45c857a014011c86ee614ffb

To find other versions that are available, take a look at the release notes of the DreamBooth trainer.
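If your scripts switch between base models, a lookup table keeps the long IDs in one place. This is a minimal sketch; the short keys such as "stable-diffusion-1.5" are our own labels, not names Replicate defines, and the table only covers the versions listed above.

```python
# The trainer versions listed above, keyed by hypothetical short labels.
TRAINER_VERSIONS = {
    "stable-diffusion-1.5": "cd3f925f7ab21afaef7d45224790eedbb837eeac40d22e8fefe015489ab644aa",
    "custom-checkpoints": "9c41656f8ae2e3d2af4c1b46913d7467cd891f2c1c5f3d97f1142e876e63ed7a",
    "stable-diffusion-2.1-base": "d5e058608f43886b9620a8fbb1501853b8cbae4f45c857a014011c86ee614ffb",
}


def trainer_version(base_model: str = "stable-diffusion-1.5") -> str:
    """Look up the trainer_version ID for a base model, defaulting to Stable Diffusion 1.5."""
    try:
        return TRAINER_VERSIONS[base_model]
    except KeyError:
        raise ValueError(
            f"unknown base model {base_model!r}; expected one of {sorted(TRAINER_VERSIONS)}"
        ) from None
```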

Next steps

If you have any questions about using this, drop into #dreambooth on our Discord.

Happy training! 🚂