Skip to content

Commit

Permalink
unify sd into one single file
Browse files Browse the repository at this point in the history
  • Loading branch information
xmfan committed Oct 4, 2023
1 parent 71d32e1 commit 599e3d5
Show file tree
Hide file tree
Showing 7 changed files with 13 additions and 172 deletions.
19 changes: 13 additions & 6 deletions torchbenchmark/models/stable_diffusion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,20 @@
It requires users to specify "HUGGINGFACE_AUTH_TOKEN" in environment variable
to authorize login and agree HuggingFace terms and conditions.
"""
from torch import nn
from torchbenchmark.tasks import COMPUTER_VISION
from torchbenchmark.util.model import BenchmarkModel
from torchbenchmark.util.framework.huggingface.model_factory import HuggingFaceAuthMixin

import torch
from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler

class SDPipelineWrapper(nn.Module):
def __init__(self, pipe):
self.pipe = pipe

def forward(self, x):
return self.pipe(x)

class Model(BenchmarkModel, HuggingFaceAuthMixin):
task = COMPUTER_VISION.GENERATION
Expand All @@ -28,9 +35,9 @@ def __init__(self, test, device, batch_size=None, extra_args=[]):
batch_size=batch_size, extra_args=extra_args)
model_id = "stabilityai/stable-diffusion-2"
scheduler = EulerDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
self.pipe = StableDiffusionPipeline.from_pretrained(model_id, scheduler=scheduler)
pipe = StableDiffusionPipeline.from_pretrained(model_id, scheduler=scheduler)
self.model = SDPipelineWrapper(pipe).to(self.device)
self.example_inputs = "a photo of an astronaut riding a horse on mars"
self.pipe.to(self.device)

def enable_fp16(self):
# This model uses fp16 by default
Expand All @@ -41,14 +48,14 @@ def get_module(self):
random_input = torch.randn(1, 4, 128, 128).to(self.device)
timestep = torch.tensor([1.0]).to(self.device)
encoder_hidden_states = torch.randn(1, 1, 1024).to(self.device)
return self.pipe.unet, [random_input, timestep, encoder_hidden_states]
return self.model, [random_input, timestep, encoder_hidden_states]

def set_module(self, module):
self.pipe.unet = module
self.model = module

def train(self):
raise NotImplementedError("Train test is not implemented for the stable diffusion model.")

def eval(self):
image = self.pipe(self.example_inputs)
return (image, )
images = self.model(self.example_inputs)
return (images, )
58 changes: 0 additions & 58 deletions torchbenchmark/models/stable_diffusion_text_encoder/__init__.py

This file was deleted.

17 changes: 0 additions & 17 deletions torchbenchmark/models/stable_diffusion_text_encoder/install.py

This file was deleted.

10 changes: 0 additions & 10 deletions torchbenchmark/models/stable_diffusion_text_encoder/metadata.yaml

This file was deleted.

54 changes: 0 additions & 54 deletions torchbenchmark/models/stable_diffusion_unet/__init__.py

This file was deleted.

17 changes: 0 additions & 17 deletions torchbenchmark/models/stable_diffusion_unet/install.py

This file was deleted.

10 changes: 0 additions & 10 deletions torchbenchmark/models/stable_diffusion_unet/metadata.yaml

This file was deleted.

0 comments on commit 599e3d5

Please sign in to comment.