Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ONNX] Stable Diffusion exporter and pipeline #399

Open
wants to merge 3 commits into
base: main
Choose a base branch
from

Conversation

anton-l
Copy link
Member

@anton-l anton-l commented Sep 7, 2022

Usage:

pip install onnxruntime

python scripts/convert_stable_diffusion_checkpoint_to_onnx.py --model_path="CompVis/stable-diffusion-v1-4" --output_path="./sd_onnx"

from diffusers import StableDiffusionOnnxPipeline

pipe = StableDiffusionOnnxPipeline.from_pretrained("./sd_onnx", provider="CPUExecutionProvider")

def __init__(
self,
vae_decoder: OnnxModel,
text_encoder: OnnxModel,
tokenizer: CLIPTokenizer,
unet: OnnxModel,
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
safety_checker: OnnxModel,
feature_extractor: CLIPFeatureExtractor,
):
Copy link
Member Author

@anton-l anton-l Sep 7, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The main difference it that the models are replaced with diffusers.onnx_utils.OnnxModel

has_nsfw_concepts = torch.tensor([len(res["bad_concepts"]) > 0 for res in result])

for idx, has_nsfw_concept in enumerate(has_nsfw_concepts):
if has_nsfw_concept:
images[idx] = np.zeros(images[idx].shape) # black image
images[idx] = torch.zeros(images[idx].shape) # black image
Copy link
Member Author

@anton-l anton-l Sep 7, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Important changes to the safety checker to make it traceable

if is_input_numpy:
images = images.numpy()
has_nsfw_concepts = has_nsfw_concepts.numpy()

Copy link
Member Author

@anton-l anton-l Sep 7, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

backward-compatible outputs of the safety checker (for community pipelines that use it)

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

class OnnxModel:
base_model_prefix = "onnx_model"

def __init__(self, model=None, **kwargs):
self.model = model
self.model_save_dir = kwargs.get("model_save_dir", None)
self.latest_model_name = kwargs.get("latest_model_name", "model.onnx")
Copy link
Member Author

@anton-l anton-l Sep 7, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is mostly copied from optimum.onnxruntime.ORTModel

The copy was necessary to remove the AutoModel and AutoConfig-related code (the optimum model was intended as a wrapper for transformers models only)

@@ -33,6 +33,10 @@ def __init__(self, config: CLIPConfig):

@torch.no_grad()
def forward(self, clip_input, images):
is_input_numpy = isinstance(images, np.ndarray)
if is_input_numpy:
Copy link
Member

@patrickvonplaten patrickvonplaten Sep 7, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's try to avoid one-liners -> is_input_numpy is not easier to understand or more readable than isinstance(images, np.ndarray) IMO, but it requires one to have to remember one more attribute.

Suggested change
if is_input_numpy:
if isinstance(images, np.ndarray):

Copy link
Member Author

@anton-l anton-l Sep 7, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we need to remember the type to convert the outputs back, so is_input_numpy this is used later in the function


if any(has_nsfw_concepts):
logger.warning(
"Potential NSFW content was detected in one or more images. A black image will be returned instead."
" Try again with a different prompt and/or seed."
)

if is_input_numpy:
Copy link
Member

@patrickvonplaten patrickvonplaten Sep 7, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if is_input_numpy:
if isinstance(images, np.ndarray):

Copy link
Member Author

@anton-l anton-l Sep 7, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is no longer True here (see the comment above)

except NameError:
pass

def git_config_username_and_email(self, git_user: str = None, git_email: str = None):
Copy link
Member

@patrickvonplaten patrickvonplaten Sep 7, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need this method?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants