Train and deploy a DreamBooth model on Replicate

Posted November 21, 2022 by @bfirsh and @zeke

Generative AI has been abuzz with DreamBooth. It's a way to train Stable Diffusion on a particular object or style, creating your own version of the model that generates those objects or styles. You can train a model with as few as three images and the training process takes less than half an hour.

Notably, DreamBooth works with people, so you can make a version of Stable Diffusion that can generate images of yourself.

People have been making some magical products with DreamBooth, such as Avatar AI and ProfilePicture.AI.

Now, you can create your own projects with DreamBooth too. We've built an API that lets you train DreamBooth models and run predictions on them in the cloud.

You need as few as three training images, and training takes about 20 minutes, depending on how many training steps you use. It costs about $2.50 to train a model.

Train a DreamBooth model

First, grab your API token and set it:

export REPLICATE_API_TOKEN=...

Next, gather your training data as a set of JPEGs in a directory called data/ and zip it up:

zip -r data.zip data
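If you'd rather script this step, here is a minimal Python sketch that builds the same archive with the standard library's zipfile module. It assumes your images use a .jpg or .jpeg extension and skips everything else; the function name zip_training_images is hypothetical.

```python
import zipfile
from pathlib import Path
from typing import List

JPEG_SUFFIXES = {".jpg", ".jpeg"}


def zip_training_images(data_dir: str = "data", zip_path: str = "data.zip") -> List[str]:
    """Zip every JPEG under data_dir into zip_path; return the archived names."""
    root = Path(data_dir)
    # Keep only JPEGs, sorted so the archive order is deterministic.
    images = sorted(
        p for p in root.rglob("*")
        if p.is_file() and p.suffix.lower() in JPEG_SUFFIXES
    )
    if not images:
        raise ValueError(f"no JPEG files found in {root}")
    names = []
    with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
        for image in images:
            # Prefix entries with the directory name, as `zip -r data.zip data` does.
            arcname = (Path(root.name) / image.relative_to(root)).as_posix()
            zf.write(image, arcname=arcname)
            names.append(arcname)
    return names
```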

Put this zip file somewhere accessible via HTTP. If you like, you can use our API for uploading files:

RESPONSE=$(curl -X POST -H "Authorization: Token $REPLICATE_API_TOKEN" https://dreambooth-api-experimental.replicate.com/v1/upload/data.zip)
curl -X PUT -H "Content-Type: application/zip" --upload-file data.zip "$(jq -r ".upload_url" <<< "$RESPONSE")"
SERVING_URL=$(jq -r ".serving_url" <<< "$RESPONSE")
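The same two-step upload can be sketched in Python with only the standard library. This assumes, as the jq calls above imply, that the first request returns JSON containing upload_url and serving_url. The helper names are hypothetical, and the request builders are kept separate from the network calls so you can inspect them before sending anything.

```python
import json
import os
import urllib.request

UPLOAD_ENDPOINT = "https://dreambooth-api-experimental.replicate.com/v1/upload/"


def build_upload_request(token: str, filename: str = "data.zip") -> urllib.request.Request:
    """Step 1: ask the API for an upload URL (mirrors the first curl call)."""
    return urllib.request.Request(
        UPLOAD_ENDPOINT + filename,
        method="POST",
        headers={"Authorization": f"Token {token}"},
    )


def build_put_request(upload_url: str, data: bytes) -> urllib.request.Request:
    """Step 2: PUT the zip bytes to the upload URL (mirrors the second curl call)."""
    return urllib.request.Request(
        upload_url,
        data=data,
        method="PUT",
        headers={"Content-Type": "application/zip"},
    )


def upload_training_data(path: str = "data.zip") -> str:
    """Upload the zip and return its serving_url, to pass as instance_data."""
    token = os.environ["REPLICATE_API_TOKEN"]
    with urllib.request.urlopen(build_upload_request(token, os.path.basename(path))) as resp:
        urls = json.load(resp)
    with open(path, "rb") as f:
        put = build_put_request(urls["upload_url"], f.read())
    with urllib.request.urlopen(put):
        pass  # A successful response means the file is uploaded.
    return urls["serving_url"]
```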

Then, start a training job:

curl -X POST \
    -H "Authorization: Token $REPLICATE_API_TOKEN" \
    -H "Content-Type: application/json" \
    -d '{
            "input": {
                "instance_prompt": "a photo of a cjw person",
                "class_prompt": "a photo of a person",
                "instance_data": "'"$SERVING_URL"'",
                "max_train_steps": 2000
            },
            "model": "yourusername/yourmodel",
            "trainer_version": "cd3f925f7ab21afaef7d45224790eedbb837eeac40d22e8fefe015489ab644aa",
            "webhook_completed": "https://example.com/dreambooth-webhook"
        }' \
    https://dreambooth-api-experimental.replicate.com/v1/trainings
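Building the request body in Python sidesteps the shell quoting needed to splice $SERVING_URL into the JSON. This is a sketch of the same request as the curl command above; the helper names are hypothetical, and webhook_completed is only included when you pass one, since it's optional.

```python
import json
import os
import urllib.request
from typing import Optional

TRAININGS_URL = "https://dreambooth-api-experimental.replicate.com/v1/trainings"


def build_training_payload(
    serving_url: str,
    model: str,
    trainer_version: str,
    webhook: Optional[str] = None,
    instance_prompt: str = "a photo of a cjw person",
    class_prompt: str = "a photo of a person",
    max_train_steps: int = 2000,
) -> dict:
    """Build the JSON body for POST /v1/trainings, matching the curl example."""
    payload = {
        "input": {
            "instance_prompt": instance_prompt,
            "class_prompt": class_prompt,
            "instance_data": serving_url,
            "max_train_steps": max_train_steps,
        },
        "model": model,
        "trainer_version": trainer_version,
    }
    # webhook_completed is optional, so leave it out unless one was given.
    if webhook is not None:
        payload["webhook_completed"] = webhook
    return payload


def start_training(payload: dict) -> dict:
    """Send the training request and return the API's JSON response."""
    req = urllib.request.Request(
        TRAININGS_URL,
        data=json.dumps(payload).encode("utf-8"),
        method="POST",
        headers={
            "Authorization": f"Token {os.environ['REPLICATE_API_TOKEN']}",
            "Content-Type": "application/json",
        },
    )
    with urllib.request.urlopen(req) as resp:
        return json.load(resp)
```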

You need to set:

  • instance_prompt: The prompt that describes your training images, in the format a [identifier] [class noun], where identifier is a rare token. The example above uses cjw, but you can use any string you like. For best results, use an identifier of three Unicode characters with no spaces.
  • class_prompt: A prompt describing the broader category of images you're training on, in the format a [class noun]. The trainer uses it to generate additional images of that class, which helps the model avoid overfitting to your training data.
  • instance_data: The URL to your training data.
  • max_train_steps: The number of training steps to run. Fewer steps make training faster but typically lower the quality; more steps are slower but typically produce better results.
  • model: A name to give your model on Replicate, in the form username/modelname. For example, bfirsh/bfirshbooth. Replicate automatically creates the model for you if it doesn't exist yet.
  • trainer_version: The version of DreamBooth and Stable Diffusion to use. See the "versions" section below for more details.
  • webhook_completed: A webhook to call when the job finishes. (Optional.)
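A small helper can keep instance_prompt and class_prompt consistent with each other. This sketch follows the "a photo of a ..." wording from the example above, which is a convention rather than a requirement; make_prompts is a hypothetical name, and it only checks that the identifier is non-empty and has no spaces.

```python
from typing import Tuple


def make_prompts(identifier: str, class_noun: str) -> Tuple[str, str]:
    """Return (instance_prompt, class_prompt) for an identifier and class noun."""
    if not identifier or any(ch.isspace() for ch in identifier):
        raise ValueError("identifier must be non-empty and contain no spaces")
    if not class_noun.strip():
        raise ValueError("class noun must not be empty")
    # Same wording as the example: "a photo of a cjw person" / "a photo of a person".
    instance_prompt = f"a photo of a {identifier} {class_noun}"
    class_prompt = f"a photo of a {class_noun}"
    return instance_prompt, class_prompt
```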

Behind the scenes, a training job runs the replicate/dreambooth model. Any input that model accepts can be passed in the input object.

The API responds with this object:

{
  "id": "rrr4z55ocneqzikepnug6xezpe",
  "input": {
    "instance_prompt": "a photo of a cjw person",
    "class_prompt": "a photo of a person",
    "instance_data": "https://replicate.delivery/pbxt/HoUeWsrtTTCJEpKGdLKqIYTfo8nbUTSNs565MkGxEstjfwKt/data.zip",
    "max_train_steps": 2000
  },
  "model": "yourusername/yourmodel",
  "status": "starting",
  "trainer_version": "cd3f925f7ab21afaef7d45224790eedbb837eeac40d22e8fefe015489ab644aa",
  "webhook_completed": "https://example.com/dreambooth-webhook"
}

You can get the status of the training job by calling GET /v1/trainings/<id>:

curl -H "Authorization: Token $REPLICATE_API_TOKEN" \
  https://dreambooth-api-experimental.replicate.com/v1/trainings/rrr4z55ocneqzikepnug6xezpe
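Training takes a while, so if you're not using a webhook you'll probably want to poll. This sketch assumes the job stops changing once its status is succeeded, failed, or canceled (only starting and succeeded appear in the examples here). The function names and the 30-second default interval are arbitrary choices, and fetch is passed in as a function so the loop can be exercised without network access.

```python
import json
import os
import time
import urllib.request
from typing import Callable

TRAININGS_URL = "https://dreambooth-api-experimental.replicate.com/v1/trainings"

# Assumption: a job's status no longer changes once it reaches one of these.
TERMINAL_STATUSES = {"succeeded", "failed", "canceled"}


def fetch_training(training_id: str) -> dict:
    """GET /v1/trainings/<id>, the same call as the curl example."""
    req = urllib.request.Request(
        f"{TRAININGS_URL}/{training_id}",
        headers={"Authorization": f"Token {os.environ['REPLICATE_API_TOKEN']}"},
    )
    with urllib.request.urlopen(req) as resp:
        return json.load(resp)


def wait_for_training(
    fetch: Callable[[], dict],
    interval: float = 30.0,
    sleep: Callable[[float], None] = time.sleep,
) -> dict:
    """Call fetch() until the status is terminal, sleeping between attempts."""
    while True:
        training = fetch()
        if training["status"] in TERMINAL_STATUSES:
            return training
        sleep(interval)
```

For example, wait_for_training(lambda: fetch_training("rrr4z55ocneqzikepnug6xezpe")) returns the final training object, whose version field you can use once the status is succeeded.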

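The status endpoint responds with the same object, updated with the job's current status. When training has succeeded, it also includes a version field identifying your new model: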

{
  "id": "rrr4z55ocneqzikepnug6xezpe",
  "input": {
    "instance_prompt": "a photo of a cjw person",
    "class_prompt": "a photo of a person",
    "instance_data": "https://replicate.delivery/pbxt/HoUeWsrtTTCJEpKGdLKqIYTfo8nbUTSNs565MkGxEstjfwKt/data.zip",
    "max_train_steps": 2000
  },
  "model": "yourusername/yourmodel",
  "status": "succeeded",
  "trainer_version": "cd3f925f7ab21afaef7d45224790eedbb837eeac40d22e8fefe015489ab644aa",
  "webhook_completed": "https://example.com/dreambooth-webhook",
  "version": "8abccf52e7cba9f6e82317253f4a3549082e966db5584e92c808ece132037776"
}

This is the same object that is sent to your webhook.
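If you set webhook_completed, your server receives this object when the job finishes. Here is a minimal receiver sketched with Python's built-in http.server. It assumes the request body is the JSON object shown above and performs no authentication of incoming requests; the names trained_version and DreamBoothWebhook are hypothetical.

```python
import json
from http.server import BaseHTTPRequestHandler, HTTPServer
from typing import Optional


def trained_version(body: bytes) -> Optional[str]:
    """Return the new model version, or None if the training didn't succeed."""
    training = json.loads(body)
    if training.get("status") != "succeeded":
        return None
    return training.get("version")


class DreamBoothWebhook(BaseHTTPRequestHandler):
    def do_POST(self) -> None:
        # Read exactly as many bytes as the client says it sent.
        length = int(self.headers.get("Content-Length", 0))
        version = trained_version(self.rfile.read(length))
        if version:
            print(f"Training succeeded; new version: {version}")
        else:
            print("Training finished without succeeding")
        # Acknowledge receipt with an empty 200 response.
        self.send_response(200)
        self.end_headers()


# To run locally (this blocks forever):
# HTTPServer(("", 8000), DreamBoothWebhook).serve_forever()
```

To try it, start the server and make it reachable at the URL you passed as webhook_completed.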

Run your trained model

When the training process has completed successfully, it pushes the model to Replicate. You can run it like any other Replicate model, using the website or the API.

To run it on the website, go to your dashboard, then click "models".

Your new model is private by default, and only visible to you. If you want anyone to be able to see and run your model, then you can make it public in the Settings tab on your model page.

You can use the version from the successful training response to run predictions with the HTTP API. For example:

curl -X POST \
    -H "Authorization: Token $REPLICATE_API_TOKEN" \
    -H "Content-Type: application/json" \
    -d '{
            "input": {
                "prompt": "painting of cjw by andy warhol"
            },
            "version": "8abccf52e7cba9f6e82317253f4a3549082e966db5584e92c808ece132037776"
        }' \
    https://api.replicate.com/v1/predictions
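The same prediction request can be made from Python without the client library. Serializing the body with json.dumps guarantees valid JSON, so stray trailing commas can't creep in. This is a sketch; the helper names are hypothetical, and the API returns a prediction object that may still be running, so you may need to poll it much like the training job.

```python
import json
import os
import urllib.request

PREDICTIONS_URL = "https://api.replicate.com/v1/predictions"


def build_prediction_request(token: str, version: str, prompt: str) -> urllib.request.Request:
    """Build the same POST /v1/predictions request as the curl example."""
    body = {"version": version, "input": {"prompt": prompt}}
    return urllib.request.Request(
        PREDICTIONS_URL,
        data=json.dumps(body).encode("utf-8"),
        method="POST",
        headers={
            "Authorization": f"Token {token}",
            "Content-Type": "application/json",
        },
    )


def create_prediction(version: str, prompt: str) -> dict:
    """Send the request and return the prediction object as a dict."""
    req = build_prediction_request(os.environ["REPLICATE_API_TOKEN"], version, prompt)
    with urllib.request.urlopen(req) as resp:
        return json.load(resp)
```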

Or, with Python:

import replicate

# The client reads your API token from the REPLICATE_API_TOKEN environment variable.
model = replicate.models.get("yourusername/yourmodel")
version = model.versions.get("8abccf52e7cba9f6e82317253f4a3549082e966db5584e92c808ece132037776")
output = version.predict(prompt="painting of cjw by andy warhol")
print(output)

To learn more about running models on Replicate, take a look at the Python getting started guide or the HTTP API reference.

Versions

By default, DreamBooth trains a Stable Diffusion 1.5 model. This model tends to work better for DreamBooth because it covers a wider range of styles.

To use a different base model, set the trainer_version option to one of the supported versions:

  • Stable Diffusion 1.5: cd3f925f7ab21afaef7d45224790eedbb837eeac40d22e8fefe015489ab644aa
  • Custom Checkpoints: 9c41656f8ae2e3d2af4c1b46913d7467cd891f2c1c5f3d97f1142e876e63ed7a
  • Stable Diffusion 2.1-base: d5e058608f43886b9620a8fbb1501853b8cbae4f45c857a014011c86ee614ffb
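If you switch between base models often, it can help to keep the IDs above in one place. The short names used as keys here are hypothetical labels, not identifiers the API understands; only the version hashes are passed as trainer_version.

```python
# Trainer versions listed above, keyed by hypothetical short names.
TRAINER_VERSIONS = {
    "sd-1.5": "cd3f925f7ab21afaef7d45224790eedbb837eeac40d22e8fefe015489ab644aa",
    "custom-checkpoints": "9c41656f8ae2e3d2af4c1b46913d7467cd891f2c1c5f3d97f1142e876e63ed7a",
    "sd-2.1-base": "d5e058608f43886b9620a8fbb1501853b8cbae4f45c857a014011c86ee614ffb",
}


def trainer_version(name: str = "sd-1.5") -> str:
    """Look up a trainer_version hash; defaults to Stable Diffusion 1.5."""
    try:
        return TRAINER_VERSIONS[name]
    except KeyError:
        raise ValueError(
            f"unknown trainer {name!r}; choose from {sorted(TRAINER_VERSIONS)}"
        ) from None
```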

To find other versions that are available, take a look at the release notes of the DreamBooth trainer.

Next steps

If you have any questions about using this, drop into #dreambooth on our Discord.

Happy training! 🚂