Train and deploy a DreamBooth model on Replicate

Posted November 21, 2022 by @bfirsh and @zeke

Generative AI has been abuzz with DreamBooth. It's a way to train Stable Diffusion on a particular object or style, creating your own version of the model that generates those objects or styles. You can train a model with as few as three images and the training process takes less than half an hour.

Notably, DreamBooth works with people, so you can make a version of Stable Diffusion that can generate images of yourself.

People have been making some magical products with DreamBooth, such as Avatar AI and ProfilePicture.AI.

Now, you can create your own projects with DreamBooth too. We've built an API that lets you train DreamBooth models and run predictions on them in the cloud.

You need as few as three training images, and training takes about 20 minutes, depending on how many training steps you use. It costs about $2.50 to train a model.

Train a DreamBooth model

First, grab your API token and set it:

export REPLICATE_API_TOKEN=...

Next, gather your training data as a set of JPEGs in a directory called data/ and zip it up:

zip -r data.zip data
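If you prepare your training data from a script, the same archive can be built in Python. This is a minimal sketch using only the standard library; the helper name zip_training_images is made up for illustration. It keeps the data/ prefix on each entry, as zip -r data.zip data does, but stores only files and adds no separate directory entries.

```python
import zipfile
from pathlib import Path


def zip_training_images(data_dir="data", zip_path="data.zip"):
    """Zip every file under data_dir and return the archive entry names."""
    root = Path(data_dir)
    added = []
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as archive:
        for path in sorted(root.rglob("*")):
            if path.is_file():
                # Store paths relative to the parent directory so that
                # entries look like data/photo.jpg, as with `zip -r`.
                arcname = path.relative_to(root.parent).as_posix()
                archive.write(path, arcname)
                added.append(arcname)
    return added
```

Calling zip_training_images() with no arguments reads data/ and writes data.zip in the current directory.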

Put this zip file somewhere accessible via HTTP. If you like, you can use our API for uploading files:

RESPONSE=$(curl -X POST -H "Authorization: Token $REPLICATE_API_TOKEN" https://dreambooth-api-experimental.replicate.com/v1/upload/data.zip)
curl -X PUT -H "Content-Type: application/zip" --upload-file data.zip "$(jq -r ".upload_url" <<< "$RESPONSE")"
SERVING_URL=$(jq -r ".serving_url" <<< "$RESPONSE")
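The upload is a two-step flow: a POST returns an upload_url and a serving_url, then a PUT sends the zip file to the upload_url. Here is a rough Python sketch of the same flow using only the standard library. The function names are hypothetical, and the endpoint and field names are taken from the curl commands above, so treat this as an illustration rather than an official client.

```python
import json
import os
import urllib.request

API_BASE = "https://dreambooth-api-experimental.replicate.com/v1"


def build_upload_request(filename, token):
    """Build the POST that asks the API for an upload URL and a serving URL."""
    return urllib.request.Request(
        f"{API_BASE}/upload/{filename}",
        method="POST",
        headers={"Authorization": f"Token {token}"},
    )


def build_put_request(upload_url, payload):
    """Build the PUT that sends the zip bytes to the returned upload_url."""
    return urllib.request.Request(
        upload_url,
        data=payload,
        method="PUT",
        headers={"Content-Type": "application/zip"},
    )


def upload_training_data(zip_path, token):
    """Upload zip_path and return the serving_url to use as instance_data."""
    filename = os.path.basename(zip_path)
    with urllib.request.urlopen(build_upload_request(filename, token)) as resp:
        urls = json.load(resp)
    with open(zip_path, "rb") as f:
        put = build_put_request(urls["upload_url"], f.read())
    with urllib.request.urlopen(put) as resp:
        resp.read()  # drain the response; the body is not needed
    return urls["serving_url"]
```

Only upload_training_data touches the network, for example upload_training_data("data.zip", os.environ["REPLICATE_API_TOKEN"]). The two builder functions just construct requests, so you can inspect them before sending anything.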

Then, start a training job:

curl -X POST \
    -H "Authorization: Token $REPLICATE_API_TOKEN" \
    -H "Content-Type: application/json" \
    -d '{
            "input": {
                "instance_prompt": "a photo of a cjw person",
                "class_prompt": "a photo of a person",
                "instance_data": "'"$SERVING_URL"'",
                "max_train_steps": 2000
            },
            "model": "yourusername/yourmodel",
            "webhook_completed": "https://example.com/dreambooth-webhook"
        }' \
    https://dreambooth-api-experimental.replicate.com/v1/trainings

You need to set:

  • instance_prompt: the prompt that you use to describe your training images, in the format a [identifier] [class noun], where identifier is some rare token. In the example above, we use cjw, but you can use any string that you like. For best results, use an identifier containing three Unicode characters, without spaces.
  • class_prompt: a prompt of the broader category of images that you're training on, in the format a [class noun]. This is used to generate other images like your training data to avoid overfitting.
  • instance_data: the URL to your training data.
  • max_train_steps: the maximum number of steps of the training job. This is the number to lower if training takes too long.
  • model: a name to give your model on Replicate, in the form username/modelname. For example, bfirsh/bfirshbooth. Replicate automatically creates the model for you if it doesn't exist yet.
  • webhook_completed: a webhook to call when the job finishes. (Optional.)
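If you start training jobs from code, it can help to assemble these fields in one place. The following is a small sketch in Python; build_training_request is a hypothetical helper, not part of any Replicate library, and its check on the model name is just a convenience based on the username/modelname form described above.

```python
import json


def build_training_request(instance_prompt, class_prompt, instance_data, model,
                           max_train_steps=2000, webhook_completed=None):
    """Return the JSON-serializable body for POST /v1/trainings."""
    username, _, name = model.partition("/")
    if not username or not name or "/" in name:
        raise ValueError("model must be in the form username/modelname")
    body = {
        "input": {
            "instance_prompt": instance_prompt,
            "class_prompt": class_prompt,
            "instance_data": instance_data,
            "max_train_steps": max_train_steps,
        },
        "model": model,
    }
    # The webhook is optional, so leave the key out when it is not set.
    if webhook_completed is not None:
        body["webhook_completed"] = webhook_completed
    return body
```

Passing the result through json.dumps gives a string you can send as the request body in place of the inline JSON in the curl command.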

Behind the scenes, this runs the replicate/dreambooth model. Any input to that model can be passed in the input object.

The API responds with this object:

{
  "id": "rrr4z55ocneqzikepnug6xezpe",
  "input": {
    "instance_prompt": "a photo of a cjw person",
    "class_prompt": "a photo of a person",
    "instance_data": "https://replicate.delivery/pbxt/HoUeWsrtTTCJEpKGdLKqIYTfo8nbUTSNs565MkGxEstjfwKt/data.zip",
    "max_train_steps": 2000
  },
  "model": "yourusername/yourmodel",
  "status": "starting",
  "webhook_completed": "https://example.com/dreambooth-webhook"
}

You can get the status of the training job by calling GET /v1/trainings/<id>:

curl -H "Authorization: Token $REPLICATE_API_TOKEN" \
  https://dreambooth-api-experimental.replicate.com/v1/trainings/rrr4z55ocneqzikepnug6xezpe

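It responds with the same object, updated with the current status. Once training succeeds, it also includes the version of your trained model: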

{
  "id": "rrr4z55ocneqzikepnug6xezpe",
  "input": {
    "instance_prompt": "a photo of a cjw person",
    "class_prompt": "a photo of a person",
    "instance_data": "https://replicate.delivery/pbxt/HoUeWsrtTTCJEpKGdLKqIYTfo8nbUTSNs565MkGxEstjfwKt/data.zip",
    "max_train_steps": 2000
  },
  "model": "yourusername/yourmodel",
  "status": "succeeded",
  "webhook_completed": "https://example.com/dreambooth-webhook",
  "version": "8abccf52e7cba9f6e82317253f4a3549082e966db5584e92c808ece132037776"
}

This is the same object that is sent to your webhook.
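If you'd rather poll than use a webhook, a loop like the following could work. This is a sketch under some assumptions: the helper names are made up, and only the "starting" and "succeeded" statuses appear in the examples above, so the "failed" and "canceled" values are guesses at what a finished but unsuccessful job might report.

```python
import json
import time
import urllib.request

API_BASE = "https://dreambooth-api-experimental.replicate.com/v1"
# "succeeded" appears in the examples above; "failed" and "canceled" are assumed.
TERMINAL_STATUSES = {"succeeded", "failed", "canceled"}


def fetch_training(training_id, token):
    """GET /v1/trainings/<id> and return the decoded JSON object."""
    request = urllib.request.Request(
        f"{API_BASE}/trainings/{training_id}",
        headers={"Authorization": f"Token {token}"},
    )
    with urllib.request.urlopen(request) as response:
        return json.load(response)


def wait_for_training(fetch, training_id, poll_seconds=30.0, max_polls=120,
                      sleep=time.sleep):
    """Call fetch(training_id) until the status is terminal, then return the object."""
    for _ in range(max_polls):
        training = fetch(training_id)
        if training["status"] in TERMINAL_STATUSES:
            return training
        sleep(poll_seconds)
    raise TimeoutError(f"training {training_id} did not finish after {max_polls} polls")


def version_from_training(training):
    """Return the version ID from a finished training object."""
    if training["status"] != "succeeded":
        raise RuntimeError(f"training ended with status {training['status']!r}")
    return training["version"]
```

To poll the real API, pass a function that wraps fetch_training with your token, such as lambda training_id: fetch_training(training_id, token). Because the webhook receives the same object, version_from_training can also be applied to a decoded webhook payload.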

Run your trained model

When the training process has completed successfully, it pushes the model to Replicate. You can run it like any other Replicate model, using the website or the API.

To run it on the website, go to your dashboard, then click on "models".

Your new model is private by default, and only visible to you. If you want anyone to be able to see and run your model, then you can make it public in the Settings tab on your model page.

The version in the successful training response can be used to run predictions via the API. For example:

curl -X POST \
    -H "Authorization: Token $REPLICATE_API_TOKEN" \
    -H "Content-Type: application/json" \
    -d '{
            "input": {
                "prompt": "painting of cjw by andy warhol"
            },
            "version": "8abccf52e7cba9f6e82317253f4a3549082e966db5584e92c808ece132037776"
        }' \
    https://api.replicate.com/v1/predictions

Or, with Python:

import replicate
model = replicate.models.get("yourusername/yourmodel")
version = model.versions.get("8abccf52e7cba9f6e82317253f4a3549082e966db5584e92c808ece132037776")
version.predict(prompt="painting of cjw by andy warhol")

To learn more about running models on Replicate, take a look at the Python getting started guide or the HTTP API reference.

Next steps

If you have any questions about using this, drop into #dreambooth on our Discord.

Happy training! 🚂
