# Readme
You can optimize this pipeline with Pruna by following this minimal example:
import torch
from f_lite import FLitePipeline
# Trick required because it is not a native diffusers model:
# register the F-Lite "DiT" class in diffusers' loading registries so that
# save_pretrained / from_pretrained can (de)serialize it like a native component.
from diffusers.pipelines.pipeline_loading_utils import LOADABLE_CLASSES, ALL_IMPORTABLE_CLASSES
LOADABLE_CLASSES["f_lite"] = LOADABLE_CLASSES["f_lite.model"] = {"DiT": ["save_pretrained", "from_pretrained"]}
ALL_IMPORTABLE_CLASSES["DiT"] = ["save_pretrained", "from_pretrained"]
# Load the pipeline in bfloat16 and move it to the GPU.
pipeline = FLitePipeline.from_pretrained("Freepik/F-Lite", torch_dtype=torch.bfloat16).to("cuda")
# Compile the transformer backbone with torch.compile for faster inference.
pipeline.dit_model = torch.compile(pipeline.dit_model)
from pruna_pro import SmashConfig, smash
# Initialize the SmashConfig
smash_config = SmashConfig()
smash_config["cacher"] = "auto"
# smash_config["auto_cache_mode"] = "taylor"
smash_config["auto_speed_factor"] = 0.8 # Lower is faster, but reduces quality
# NOTE(review): presumably required because F-Lite is not a model pruna knows
# natively — confirm against pruna_pro docs.
smash_config["auto_custom_model"] = True
# Apply the configured optimizations; experimental=True opts in to
# experimental pruna_pro features.
smashed_pipe = smash(
model=pipeline,
smash_config=smash_config,
experimental=True,
)
# Tell the cache helper how to hook into this custom pipeline: which object
# and method drive the denoising loop, which kwarg carries the step count,
# and which backbone module (and method) should be cached.
smashed_pipe.cache_helper.configure(
pipe=pipeline,
pipe_call_method="__call__",
step_argument="num_inference_steps",
backbone=pipeline.dit_model,
backbone_call_method="forward",
)
# Run inference; a seeded CUDA generator makes the output reproducible.
smashed_pipe(
prompt="A cake with 'pruna' written on it",
height=1024,
width=1024,
num_inference_steps=30,
guidance_scale=3.0,
negative_prompt=None,
generator=torch.Generator(device="cuda").manual_seed(0),
).images[0]