replicate/dreambooth

Train your own custom Stable Diffusion model using a small set of images

This is a low-level model for generating model weights with DreamBooth. Read on to get started training and publishing your own model.

Prerequisites

  • A paid Replicate account. Sign up at https://replicate.com/account. Signing up is free; you pay by the second for the predictions you run. Each DreamBooth training run costs about $2.50.
  • Some imagery to train your model

Option 1: Train a model using GitHub Actions

We created a template repository on GitHub that makes it easy to train and publish your own model using a GitHub Actions workflow. This approach doesn't require you to write any code and gives you a foundation for a simple and continuous training process, so you can easily refine and improve your model over time with new training data.

To get started, check out replicate/dreambooth-action on GitHub.

Option 2: Train a model on the command line

If you like to get your hands dirty, you can train your model on the command line using zip, curl, and jq.

First, grab your API token and set it:

export REPLICATE_API_TOKEN=...

Next, gather your training data as a set of JPEGs in a directory called data/ and zip it up:

zip -r data.zip data
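If you would rather build the archive from a script, the same step can be sketched in Python with the standard library. The function name `zip_training_images` is hypothetical, and filtering to `.jpg`/`.jpeg` files is an assumption based on the "set of JPEGs" wording above; `zip -r` itself would include every file in the directory.

```python
import zipfile
from pathlib import Path


def zip_training_images(data_dir="data", archive="data.zip"):
    """Zip the JPEG files under data_dir and return how many were added."""
    data_path = Path(data_dir)
    images = sorted(
        p for p in data_path.rglob("*")
        if p.is_file() and p.suffix.lower() in {".jpg", ".jpeg"}
    )
    if not images:
        raise ValueError(f"no JPEG files found in {str(data_dir)!r}")
    with zipfile.ZipFile(archive, "w", zipfile.ZIP_DEFLATED) as zf:
        for image in images:
            # Keep the directory name as a prefix (data/1.jpg), the same
            # layout that `zip -r data.zip data` produces.
            arcname = image.relative_to(data_path.parent).as_posix()
            zf.write(image, arcname)
    return len(images)
```

Calling `zip_training_images()` from the directory that contains data/ writes data.zip next to it.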

Put this zip file somewhere accessible via HTTP. If you like, you can use our API for uploading files:

# Ask the API for an upload URL and a serving URL
RESPONSE=$(curl -s -X POST -H "Authorization: Token $REPLICATE_API_TOKEN" https://dreambooth-api-experimental.replicate.com/v1/upload/data.zip)
# Upload the zip file, then keep the URL that the training job will read it from
curl -X PUT -H "Content-Type: application/zip" --upload-file data.zip "$(jq -r ".upload_url" <<< "$RESPONSE")"
SERVING_URL=$(jq -r ".serving_url" <<< "$RESPONSE")
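For reference, here is a rough Python equivalent of the upload step that uses only the standard library. It assumes the endpoint behaves exactly as the curl commands above suggest (a POST that returns `upload_url` and `serving_url`, followed by a PUT of the zip file). The function names are hypothetical and not part of any Replicate client.

```python
import json
import os
import urllib.request

API_BASE = "https://dreambooth-api-experimental.replicate.com/v1"


def build_upload_slot_request(token, filename="data.zip"):
    """Build the POST that asks the API for an upload URL and a serving URL."""
    return urllib.request.Request(
        f"{API_BASE}/upload/{filename}",
        method="POST",
        headers={"Authorization": f"Token {token}"},
    )


def build_put_request(upload_url, payload):
    """Build the PUT that sends the zip bytes to the upload URL."""
    return urllib.request.Request(
        upload_url,
        data=payload,
        method="PUT",
        headers={"Content-Type": "application/zip"},
    )


def upload_training_data(token, path="data.zip"):
    """Upload the archive and return the serving URL (performs network calls)."""
    slot_request = build_upload_slot_request(token, os.path.basename(path))
    with urllib.request.urlopen(slot_request) as resp:
        slot = json.load(resp)
    with open(path, "rb") as f:
        payload = f.read()
    with urllib.request.urlopen(build_put_request(slot["upload_url"], payload)):
        pass  # a non-2xx response raises urllib.error.HTTPError
    return slot["serving_url"]
```

The returned serving URL plays the same role as `$SERVING_URL` in the shell version: `upload_training_data(os.environ["REPLICATE_API_TOKEN"])`.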

Then, start a training job:

curl -X POST \
    -H "Authorization: Token $REPLICATE_API_TOKEN" \
    -H "Content-Type: application/json" \
    -d '{
            "input": {
                "instance_prompt": "a photo of a cjw person",
                "class_prompt": "a photo of a person",
                "instance_data": "'"$SERVING_URL"'",
                "max_train_steps": 2000
            },
            "model": "yourusername/yourmodel",
            "webhook_completed": "https://example.com/dreambooth-webhook"
        }' \
    https://dreambooth-api-experimental.replicate.com/v1/trainings

You need to set:

  • instance_prompt: the prompt that you use to describe your training images, in the format a [identifier] [class noun], where identifier is some rare token. In the example above, we use cjw, but you can use any string that you like. For best results, use an identifier containing three Unicode characters, without spaces.
  • class_prompt: a prompt describing the broader category of images that you're training on, in the format a [class noun]. It is used to generate other images of that category alongside your training data, which helps avoid overfitting.
  • instance_data: the URL to your training data.
  • max_train_steps: the maximum number of steps of the training job. Lower this number if a training run takes too long.
  • model: a name to give your model on Replicate, in the form username/modelname. For example, bfirsh/bfirshbooth. Replicate automatically creates the model for you if it doesn't exist yet.
  • webhook_completed: a webhook to call when the job finishes. (Optional.)
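The fields above can be assembled and sanity-checked before sending. The sketch below is one way to do that in Python; the helper name `build_training_payload` and its validation rules (a single slash in the model name, an http(s) URL for the data, a positive step count) are assumptions made for illustration, not checks that the API is known to perform.

```python
import json
import re


def build_training_payload(instance_prompt, class_prompt, instance_data, model,
                           max_train_steps=2000, webhook_completed=None):
    """Return the JSON-serializable body for POST /v1/trainings."""
    # Assumed rule: model names look like username/modelname.
    if not re.fullmatch(r"[^/\s]+/[^/\s]+", model):
        raise ValueError("model must be in the form username/modelname")
    # Assumed rule: the training data must be reachable over HTTP(S).
    if not instance_data.startswith(("http://", "https://")):
        raise ValueError("instance_data must be an http(s) URL")
    if max_train_steps <= 0:
        raise ValueError("max_train_steps must be positive")
    payload = {
        "input": {
            "instance_prompt": instance_prompt,
            "class_prompt": class_prompt,
            "instance_data": instance_data,
            "max_train_steps": max_train_steps,
        },
        "model": model,
    }
    # webhook_completed is optional, so only include it when given.
    if webhook_completed:
        payload["webhook_completed"] = webhook_completed
    return payload
```

Serializing the result with `json.dumps(payload)` gives a body you can pass to curl with `-d` or send from any HTTP client, and it avoids hand-written quoting mistakes.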

Behind the scenes, this runs the replicate/dreambooth model. Any input to that model can be passed in the input object.

The API responds with this object:

{
  "id": "rrr4z55ocneqzikepnug6xezpe",
  "input": {
    "instance_prompt": "a photo of a cjw person",
    "class_prompt": "a photo of a person",
    "instance_data": "https://replicate.delivery/pbxt/HoUeWsrtTTCJEpKGdLKqIYTfo8nbUTSNs565MkGxEstjfwKt/data.zip",
    "max_train_steps": 2000
  },
  "model": "yourusername/yourmodel",
  "status": "starting",
  "webhook_completed": "https://example.com/dreambooth-webhook"
}

You can get the status of the training job by calling GET /v1/trainings/<id>:

curl -H "Authorization: Token $REPLICATE_API_TOKEN" \
  https://dreambooth-api-experimental.replicate.com/v1/trainings/rrr4z55ocneqzikepnug6xezpe

It responds with the same object, with an updated status and, once training has succeeded, the version of the trained model:

{
  "id": "rrr4z55ocneqzikepnug6xezpe",
  "input": {
    "instance_prompt": "a photo of a cjw person",
    "class_prompt": "a photo of a person",
    "instance_data": "https://replicate.delivery/pbxt/HoUeWsrtTTCJEpKGdLKqIYTfo8nbUTSNs565MkGxEstjfwKt/data.zip",
    "max_train_steps": 2000
  },
  "model": "yourusername/yourmodel",
  "status": "succeeded",
  "webhook_completed": "https://example.com/dreambooth-webhook",
  "version": "8abccf52e7cba9f6e82317253f4a3549082e966db5584e92c808ece132037776"
}

This is the same object that is sent to your webhook.
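If you are not using a webhook, you can poll the status endpoint until the job finishes. The loop below is a minimal sketch: only the "starting" and "succeeded" statuses appear in the responses above, so treating "failed" and "canceled" as terminal states is an assumption, as are the function names, the polling interval, and the timeout.

```python
import json
import time
import urllib.request

API_BASE = "https://dreambooth-api-experimental.replicate.com/v1"
# Assumption: these statuses mean the job will never succeed.
FAILURE_STATES = {"failed", "canceled"}


def fetch_training(token, training_id):
    """GET /v1/trainings/<id> and return the decoded JSON object."""
    req = urllib.request.Request(
        f"{API_BASE}/trainings/{training_id}",
        headers={"Authorization": f"Token {token}"},
    )
    with urllib.request.urlopen(req) as resp:
        return json.load(resp)


def wait_for_training(fetch, poll_interval=30.0, timeout=3600.0, sleep=time.sleep):
    """Call fetch() until the training succeeds, fails, or the timeout passes."""
    deadline = time.monotonic() + timeout
    while True:
        training = fetch()
        status = training.get("status")
        if status == "succeeded":
            return training
        if status in FAILURE_STATES:
            raise RuntimeError(f"training ended with status {status!r}")
        if time.monotonic() >= deadline:
            raise TimeoutError("training did not finish before the timeout")
        sleep(poll_interval)
```

A typical call would be `wait_for_training(lambda: fetch_training(token, "rrr4z55ocneqzikepnug6xezpe"))`. Passing the fetch function in as an argument keeps the loop easy to exercise without touching the network.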

Run your trained model

When the training process has completed successfully, it pushes the model to Replicate. You can run it like any other Replicate model, using the website or the API.

To run it on the website, go to your dashboard, then click on "models".

Your new model is private by default, and only visible to you. If you want anyone to be able to see and run your model, then you can make it public in the Settings tab on your model page.

The version in the successful training response can be used to run predictions via an API. For example:

curl -X POST \
    -H "Authorization: Token $REPLICATE_API_TOKEN" \
    -H "Content-Type: application/json" \
    -d '{
            "input": {
                "prompt": "painting of cjw by andy warhol"
            },
            "version": "8abccf52e7cba9f6e82317253f4a3549082e966db5584e92c808ece132037776"
        }' \
    https://api.replicate.com/v1/predictions

Or, with Python:

import replicate

# The client reads your API token from the REPLICATE_API_TOKEN environment variable.
model = replicate.models.get("yourusername/yourmodel")
version = model.versions.get("8abccf52e7cba9f6e82317253f4a3549082e966db5584e92c808ece132037776")
output = version.predict(prompt="painting of cjw by andy warhol")
print(output)
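If you prefer not to install the client library, the HTTP request can also be built straight from a finished training response. This is a sketch using only the standard library; `build_prediction_request` is a hypothetical helper, and it assumes the response has the shape shown earlier (a `status` of "succeeded" plus a `version` field).

```python
import json
import urllib.request

PREDICTIONS_URL = "https://api.replicate.com/v1/predictions"


def build_prediction_request(token, training, prompt):
    """Build a POST to /v1/predictions from a succeeded training response."""
    if training.get("status") != "succeeded" or "version" not in training:
        raise ValueError("training has not produced a version yet")
    # json.dumps always emits valid JSON, so there are no stray trailing commas.
    body = json.dumps({
        "version": training["version"],
        "input": {"prompt": prompt},
    })
    return urllib.request.Request(
        PREDICTIONS_URL,
        data=body.encode("utf-8"),
        method="POST",
        headers={
            "Authorization": f"Token {token}",
            "Content-Type": "application/json",
        },
    )
```

Sending it is then `urllib.request.urlopen(build_prediction_request(token, training, "painting of cjw by andy warhol"))`, which returns the prediction object as JSON.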

To learn more about running models on Replicate, take a look at the Python getting started guide or the HTTP API reference.

Next steps

If you have any questions about using this, drop into #dreambooth on our Discord.

Happy training! 🚂
