New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[ONNX] Stable Diffusion exporter and pipeline #399
base: main
Are you sure you want to change the base?
Conversation
| def __init__( | ||
| self, | ||
| vae_decoder: OnnxModel, | ||
| text_encoder: OnnxModel, | ||
| tokenizer: CLIPTokenizer, | ||
| unet: OnnxModel, | ||
| scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], | ||
| safety_checker: OnnxModel, | ||
| feature_extractor: CLIPFeatureExtractor, | ||
| ): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The main difference it that the models are replaced with diffusers.onnx_utils.OnnxModel
| has_nsfw_concepts = torch.tensor([len(res["bad_concepts"]) > 0 for res in result]) | ||
|
|
||
| for idx, has_nsfw_concept in enumerate(has_nsfw_concepts): | ||
| if has_nsfw_concept: | ||
| images[idx] = np.zeros(images[idx].shape) # black image | ||
| images[idx] = torch.zeros(images[idx].shape) # black image |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Important changes to the safety checker to make it traceable
| if is_input_numpy: | ||
| images = images.numpy() | ||
| has_nsfw_concepts = has_nsfw_concepts.numpy() | ||
|
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
backward-compatible outputs of the safety checker (for community pipelines that use it)
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. |
| class OnnxModel: | ||
| base_model_prefix = "onnx_model" | ||
|
|
||
| def __init__(self, model=None, **kwargs): | ||
| self.model = model | ||
| self.model_save_dir = kwargs.get("model_save_dir", None) | ||
| self.latest_model_name = kwargs.get("latest_model_name", "model.onnx") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is mostly copied from optimum.onnxruntime.ORTModel
The copy was necessary to remove the AutoModel and AutoConfig-related code (the optimum model was intended as a wrapper for transformers models only)
| @@ -33,6 +33,10 @@ def __init__(self, config: CLIPConfig): | |||
|
|
|||
| @torch.no_grad() | |||
| def forward(self, clip_input, images): | |||
| is_input_numpy = isinstance(images, np.ndarray) | |||
| if is_input_numpy: | |||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's try to avoid one-liners -> is_input_numpy is not easier to understand or more readable than isinstance(images, np.ndarray) IMO, but it requires one to have to remember one more attribute.
| if is_input_numpy: | |
| if isinstance(images, np.ndarray): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we need to remember the type to convert the outputs back, so is_input_numpy this is used later in the function
|
|
||
| if any(has_nsfw_concepts): | ||
| logger.warning( | ||
| "Potential NSFW content was detected in one or more images. A black image will be returned instead." | ||
| " Try again with a different prompt and/or seed." | ||
| ) | ||
|
|
||
| if is_input_numpy: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| if is_input_numpy: | |
| if isinstance(images, np.ndarray): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this is no longer True here (see the comment above)
| except NameError: | ||
| pass | ||
|
|
||
| def git_config_username_and_email(self, git_user: str = None, git_email: str = None): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do we need this method?
Usage:
pip install onnxruntimepython scripts/convert_stable_diffusion_checkpoint_to_onnx.py --model_path="CompVis/stable-diffusion-v1-4" --output_path="./sd_onnx"