Skip to content

Instantly share code, notes, and snippets.

@Norod
Last active October 14, 2022 18:44
Show Gist options
  • Select an option

  • Save Norod/2ec83804b173c22a2b26ba9002566dee to your computer and use it in GitHub Desktop.

Select an option

Save Norod/2ec83804b173c22a2b26ba9002566dee to your computer and use it in GitHub Desktop.
Patches the safety checker in the diffusers library's Stable Diffusion pipelines (including the Flax variants) so that potentially NSFW images are no longer blacked out. The caller can still decide what to do with each image based on the value of `has_nsfw_concept[index]`.
import importlib
import importlib.util  # `import importlib` alone does not guarantee `importlib.util` is loaded


def _module_origin(module_name):
    """Return the path of the source file that defines *module_name*.

    ``importlib.util.find_spec`` locates the module without executing it,
    although for a dotted name its parent packages are imported first.

    Raises:
        ModuleNotFoundError: If the module, or its source file, cannot be
            located (e.g. the installed diffusers version does not ship it).
    """
    spec = importlib.util.find_spec(module_name)
    # find_spec returns None for a missing leaf module, and a spec's origin
    # can be None (e.g. namespace packages); both mean "no file to patch".
    if spec is None or spec.origin is None:
        raise ModuleNotFoundError(
            "Cannot locate the source file of " + repr(module_name)
            + "; is a compatible diffusers version installed?",
            name=module_name,
        )
    return spec.origin


# Safety checker modules (the default one and its Flax counterpart).
checker_location = _module_origin('diffusers.pipelines.stable_diffusion.safety_checker')
checker_location_flax = _module_origin('diffusers.pipelines.stable_diffusion.safety_checker_flax')
# Module that defines FlaxStableDiffusionPipeline.
blocker_location_flax = _module_origin('diffusers.pipelines.stable_diffusion.pipeline_flax_stable_diffusion')
import fileinput
def replace_in_file(file_path, search_text, new_text):
    """Replace every occurrence of *search_text* with *new_text* in a file.

    The file is read and written as UTF-8 with its line endings preserved,
    and it is rewritten only when at least one occurrence was found, so an
    unmatched (or already patched) file is left completely untouched.

    Args:
        file_path: Path of the text file to edit in place.
        search_text: Non-empty text to look for.
        new_text: Text to substitute for each occurrence.

    Returns:
        The number of occurrences replaced (0 if the file was not modified).

    Raises:
        ValueError: If *file_path* is None or *search_text* is empty.
        OSError: If the file cannot be read or written.
    """
    if file_path is None:
        # A missing module spec yields None; never fall back to any default input.
        raise ValueError("file_path must not be None")
    if not search_text:
        # str.replace('') would insert new_text between every character.
        raise ValueError("search_text must not be empty")
    # newline='' disables newline translation so CRLF/LF files round-trip exactly.
    with open(file_path, encoding="utf-8", newline="") as handle:
        original = handle.read()
    count = original.count(search_text)
    if count:
        with open(file_path, "w", encoding="utf-8", newline="") as handle:
            handle.write(original.replace(search_text, new_text))
    return count
# The safety checkers overwrite each flagged image with zeros (a black image).
# Swap that assignment for a log line so the image is returned untouched; the
# caller still receives has_nsfw_concept and can decide what to do with it.
string_to_look_for = 'images[idx] = np.zeros(images[idx].shape)'
string_to_replace_with = 'print("Image at index " + str(idx) + " might contain NSFW content")'


def _file_contains(path, text):
    """Return True if the UTF-8 text file at *path* contains *text*."""
    with open(path, encoding="utf-8") as handle:
        return text in handle.read()


# (file to patch, [(text to find, replacement), ...]), applied in this order.
_patches = [
    (checker_location, [(string_to_look_for, string_to_replace_with)]),
    (checker_location_flax, [
        (string_to_look_for, string_to_replace_with),
        # NOTE(review): presumably makes the Flax checker skip copying the
        # image batch before editing it; confirm against installed diffusers.
        ('images_was_copied = False', 'images_was_copied = True'),
    ]),
    (blocker_location_flax, [
        # The Flax pipeline writes images_uint8_casted[i] back into images[i]
        # (presumably only for flagged images); log instead, using the same
        # message as the checker patches rather than a bare no-op print.
        ('images[i] = np.asarray(images_uint8_casted[i])',
         'print("Image at index " + str(i) + " might contain NSFW content")'),
    ]),
]

for _path, _replacements in _patches:
    print("Inspecting: " + str(_path))
    for _old, _new in _replacements:
        if not _file_contains(_path, _old):
            # Plain text substitution silently does nothing on a mismatch, so
            # surface it: the file is already patched or from another version.
            print("  Pattern not found (already patched or unsupported diffusers version): " + _old)
            continue
        replace_in_file(_path, _old, _new)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment