PyTorch implementation of "TryOnDiffusion: A Tale of Two UNets", a virtual try-on diffusion-based network by Google
This is an unofficial implementation of the “TryOnDiffusion: A Tale of Two UNets” paper.
In this repository you can find:
In this repository you WON’T find:
You are welcome to join our Discord server and share your thoughts about this project in the dedicated # 🐙 | open-source channel.
To get a better sense of how to use this repository, please look at the example scripts available in the ./examples directory.
import torch

from tryondiffusion import BaseParallelUnet

unet = BaseParallelUnet()

# dummy inputs: a batch of one 128x128 RGB image and 18 2D pose keypoints
noisy_images = torch.randn(1, 3, 128, 128)
ca_images = torch.randn(1, 3, 128, 128)
garment_images = torch.randn(1, 3, 128, 128)
person_poses = torch.randn(1, 18, 2)
garment_poses = torch.randn(1, 18, 2)
time = torch.randn(1,)
ca_noise_times = torch.randn(1,)
garment_noise_times = torch.randn(1,)

output = unet(
    noisy_images=noisy_images,
    time=time,
    ca_images=ca_images,
    ca_noise_times=ca_noise_times,
    garment_images=garment_images,
    garment_noise_times=garment_noise_times,
    person_poses=person_poses,
    garment_poses=garment_poses,
    # lowres inputs are not needed for the base UNet
    lowres_cond_img=None,
    lowres_noise_times=None,
)
print(output.shape)  # (1, 3, 128, 128)
For more control over the UNet structure, you can instantiate tryondiffusion.ParallelUNet directly or edit the pre-defined templates:
class BaseParallelUnet(ParallelUNet):
    def __init__(self, *args, **kwargs):
        default_kwargs = dict(
            image_size=(128, 128),
            dim=128,
            context_dim=512,
            time_cond_dim=512,
            pose_depth=2,
            feature_channels=(128, 256, 512, 1024),
            num_blocks=(3, 4, 6, 7),
            layer_attns=(False, False, True, True),
            layer_cross_attns=(False, False, True, True),
            attn_heads=(None, None, 4, 8),
            attn_dim_head=(None, None, 128, 64),
        )
        super().__init__(*args, **{**default_kwargs, **kwargs})
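For instance, a smaller variant can be built by passing the same keyword arguments directly. This is only a minimal sketch: the argument names are taken from the template above, while the values below are illustrative and not from the paper.

from tryondiffusion import ParallelUNet

# hypothetical lightweight configuration for quick experiments
small_unet = ParallelUNet(
    image_size=(64, 64),
    dim=64,
    context_dim=256,
    time_cond_dim=256,
    pose_depth=1,
    feature_channels=(64, 128, 256, 512),
    num_blocks=(2, 2, 3, 3),
    layer_attns=(False, False, True, True),
    layer_cross_attns=(False, False, True, True),
    attn_heads=(None, None, 4, 4),
    attn_dim_head=(None, None, 64, 64),
)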
Use the tryondiffusion.TryOnImagen class to manage multiple UNets for training and sampling:
import torch

from tryondiffusion import TryOnImagen, get_unet_by_name

unet1 = get_unet_by_name("base")
unet2 = get_unet_by_name("sr")

device = "cuda" if torch.cuda.is_available() else "cpu"

imagen = TryOnImagen(
    unets=(unet1, unet2),
    image_sizes=((128, 128), (256, 256)),
    timesteps=(256, 128),
)
imagen = imagen.to(device)
Perform a forward pass to get the loss:
person_images = person_images.to(device)
ca_images = ca_images.to(device)
garment_images = garment_images.to(device)
person_pose = person_pose.to(device)
garment_pose = garment_pose.to(device)
loss = imagen(
    person_images=person_images,
    ca_images=ca_images,
    garment_images=garment_images,
    person_poses=person_pose,
    garment_poses=garment_pose,
    unet_number=TRAIN_UNET_NUMBER,  # 1 for the base UNet, 2 for the SR UNet
)
loss.backward()
Call the sample method to generate images in a cascade:
images = imagen.sample(
    # notice we don't pass person_images
    ca_images=ca_images,
    garment_images=garment_images,
    person_poses=person_pose,
    garment_poses=garment_pose,
    batch_size=1,
    cond_scale=2.0,
    start_at_unet_number=1,
    return_all_unet_outputs=False,
    return_pil_images=True,
    use_tqdm=True,
    use_one_unet_in_gpu=True,
)
images[0].show() # or images[0].save("output.png")
Use the tryondiffusion.TryOnImagenTrainer class to handle training multiple UNets and manage their states. This class is almost identical to the ImagenTrainer from the original imagen-pytorch repository, so you can also refer to it for more examples of how to use it.
from tryondiffusion import TryOnImagenTrainer

trainer = TryOnImagenTrainer(
    imagen=imagen,
    lr=1e-4,
    eps=1e-8,
    betas=(0.9, 0.99),
    max_grad_norm=1.0,
    checkpoint_every=5000,
    checkpoint_path="./checkpoints",
    # use the `accelerate_` prefix to pass arguments to the underlying `accelerate.Accelerator`
    accelerate_cpu=True,
    accelerate_gradient_accumulation_steps=8,
    # see more options in the class definition
)
# see ./examples/test_tryon_imagen_trainer.py for an example dataset and dataloader
trainer.add_train_dataloader(train_dataloader)
trainer.add_valid_dataloader(validation_dataloader)

# training loop
for i in range(10):
    # choose which unet in the cascade you wish to train
    loss = trainer.train_step(unet_number=1)
    print(f"train loss: {loss}")

    # validation
    valid_loss = trainer.valid_step(unet_number=1)
    print(f"valid loss: {valid_loss}")
# sampling with the trainer
validation_sample = next(trainer.valid_dl_iter)
_ = validation_sample.pop("person_images")  # person_images are only needed for training

imagen_sample_kwargs = dict(
    **validation_sample,
    batch_size=BATCH_SIZE,
    cond_scale=2.0,
    start_at_unet_number=1,
    return_all_unet_outputs=True,
    return_pil_images=True,
    use_tqdm=True,
    use_one_unet_in_gpu=True,
)
images = trainer.sample(**imagen_sample_kwargs)

for unet_output in images:
    for image in unet_output:
        image.show()
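For reference, here is a minimal sketch of a dataset producing the batch keys assumed by the forward and sample calls above (person_images, ca_images, garment_images, person_poses, garment_poses). It uses random tensors in place of real data and is only illustrative; the actual example lives in ./examples/test_tryon_imagen_trainer.py.

import torch
from torch.utils.data import DataLoader, Dataset


class DummyTryonDataset(Dataset):
    """Illustrative dataset yielding random tensors with the expected shapes."""

    def __init__(self, length=100, image_size=(128, 128)):
        self.length = length
        self.image_size = image_size

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        h, w = self.image_size
        return {
            "person_images": torch.randn(3, h, w),
            "ca_images": torch.randn(3, h, w),
            "garment_images": torch.randn(3, h, w),
            "person_poses": torch.randn(18, 2),
            "garment_poses": torch.randn(18, 2),
        }


train_dataloader = DataLoader(DummyTryonDataset(), batch_size=4, shuffle=True)
validation_dataloader = DataLoader(DummyTryonDataset(), batch_size=4)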
Some context regarding our experiment, which doesn’t follow the original paper exactly:
Already at 250k steps (50% of the 500k prescribed by the paper) we can see some outstanding results on difficult backgrounds, fabrics, and poses. We therefore stopped at this stage to focus on the SR UNet.
At 125k steps, the SR UNet does a good job of upscaling the base UNet output. It adds a lot of detail and even fixes mistakes in the input image. However, as is immediately apparent, a persistent color tint plagues the generation results.
This is a known issue, well documented in the literature on diffusion upscalers. Other researchers we know who have tried to implement this paper have reported the same phenomenon.
We have stopped pursuing this direction, as we currently have other ideas we prefer to explore for virtual try-on (and our resources are limited). However, we still believe this implementation is mostly correct and might be improved by:
Switching to the v prediction objective, as suggested by Progressive Distillation for Fast Sampling of Diffusion Models, which was reported to remedy the color shift issue in upscaling UNets for video. This is already implemented and can be tried by changing the prediction objective in the TryOnImagen class:

imagen = TryOnImagen(...,
    pred_objectives=("noise", "v"),
    ...
)
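For context, v-prediction trains the network to predict v = α_t·ε − σ_t·x₀ instead of the noise ε (see the Progressive Distillation paper cited below). A minimal sketch of wiring this into the cascade, reusing the instantiation from earlier (the only change relative to that snippet is pred_objectives):

imagen = TryOnImagen(
    unets=(unet1, unet2),
    image_sizes=((128, 128), (256, 256)),
    timesteps=(256, 128),
    # keep epsilon-prediction for the base UNet, switch the SR UNet to v-prediction
    pred_objectives=("noise", "v"),
)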
@misc{zhu2023tryondiffusion,
title={TryOnDiffusion: A Tale of Two UNets},
author={Luyang Zhu and Dawei Yang and Tyler Zhu and Fitsum Reda and William Chan and Chitwan Saharia and Mohammad Norouzi and Ira Kemelmacher-Shlizerman},
year={2023},
eprint={2306.08276},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
@misc{saharia2022photorealistic,
title={Photorealistic Text-to-Image Diffusion Models with Deep Language Understanding},
author={Chitwan Saharia and William Chan and Saurabh Saxena and Lala Li and Jay Whang and Emily Denton and Seyed Kamyar Seyed Ghasemipour and Burcu Karagol Ayan and S. Sara Mahdavi and Rapha Gontijo Lopes and Tim Salimans and Jonathan Ho and David J Fleet and Mohammad Norouzi},
year={2022},
eprint={2205.11487},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
@misc{salimans2022progressive,
title={Progressive Distillation for Fast Sampling of Diffusion Models},
author={Tim Salimans and Jonathan Ho},
year={2022},
eprint={2202.00512},
archivePrefix={arXiv},
primaryClass={cs.LG}
}
@misc{hoogeboom2023simple,
title={Simple diffusion: End-to-end diffusion for high resolution images},
author={Emiel Hoogeboom and Jonathan Heek and Tim Salimans},
year={2023},
eprint={2301.11093},
archivePrefix={arXiv},
primaryClass={cs.CV}
}