Created
March 23, 2026 07:18
-
-
Save d3vw/8dbbc133fc40bb25bf37941f082ba6bb to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # | |
| # Copyright (c) Microsoft. All rights reserved. | |
| # To learn more, please visit the documentation - Quickstart: Azure Content Safety: https://aka.ms/acsstudiodoc | |
| # | |
| import enum | |
| import json | |
| import requests | |
| from typing import Union | |
class MediaType(enum.Enum):
    """Kind of content submitted for analysis (selects the analyze endpoint)."""

    # auto() numbers members from 1 in definition order: Text=1, Image=2.
    Text = enum.auto()
    Image = enum.auto()
class Category(enum.Enum):
    """Harm categories; member names are matched against the API's category strings."""

    # auto() numbers members from 1 in definition order: Hate=1 … Violence=4.
    Hate = enum.auto()
    SelfHarm = enum.auto()
    Sexual = enum.auto()
    Violence = enum.auto()
class Action(enum.Enum):
    """Moderation verdict.

    Values are ordered so a larger value is the stricter action; the decision
    logic compares ``.value`` to pick the strictest action across categories,
    so keep Reject defined after Accept.
    """

    Accept = enum.auto()  # == 1
    Reject = enum.auto()  # == 2
class DetectionError(Exception):
    """Raised when the Content Safety API answers a request with an error."""

    def __init__(self, code: str, message: str) -> None:
        """
        Exception raised when there is an error in detecting the content.
        Args:
        - code (str): The error code.
        - message (str): The error message.
        """
        # Initialise the base Exception so ``args`` is set explicitly and the
        # exception behaves like any other (pickling, chaining, logging).
        super().__init__(code, message)
        self.code = code
        self.message = message

    def __str__(self) -> str:
        # Without this, tracebacks render the raw args tuple "('code', 'msg')".
        return f"{self.code}: {self.message}"

    def __repr__(self) -> str:
        return f"DetectionError(code={self.code}, message={self.message})"
class Decision:
    """Result of ``ContentSafety.make_decision``: an overall verdict plus per-category verdicts."""

    def __init__(
        self, suggested_action: Action, action_by_category: dict[Category, Action]
    ) -> None:
        """
        Store the verdicts produced by the decision step.
        Args:
        - suggested_action (Action): The overall action recommended for the content.
        - action_by_category (dict[Category, Action]): The action chosen for each evaluated category.
        """
        self.suggested_action = suggested_action
        self.action_by_category = action_by_category
class ContentSafety:
    """
    Minimal client for the Content Safety ``text:analyze`` / ``image:analyze``
    REST operations, plus a threshold-based decision helper.
    """

    def __init__(self, endpoint: str, subscription_key: str, api_version: str) -> None:
        """
        Creates a new ContentSafety instance.
        Args:
        - endpoint (str): The endpoint URL for the Content Safety API.
        - subscription_key (str): The subscription key for the Content Safety API.
        - api_version (str): The version of the Content Safety API to use.
        """
        self.endpoint = endpoint
        self.subscription_key = subscription_key
        self.api_version = api_version

    def build_url(self, media_type: MediaType) -> str:
        """
        Builds the URL for the Content Safety API based on the media type.
        Args:
        - media_type (MediaType): The type of media to analyze.
        Returns:
        - str: The URL for the Content Safety API.
        Raises:
        - ValueError: If the media type is not Text or Image.
        """
        # Endpoints are often copied with a trailing "/"; strip it so the URL
        # does not contain "//contentsafety".
        base = self.endpoint.rstrip("/")
        if media_type == MediaType.Text:
            operation = "text:analyze"
        elif media_type == MediaType.Image:
            operation = "image:analyze"
        else:
            raise ValueError(f"Invalid Media Type {media_type}")
        return f"{base}/contentsafety/{operation}?api-version={self.api_version}"

    def build_headers(self) -> dict[str, str]:
        """
        Builds the headers for the Content Safety API request.
        Returns:
        - dict[str, str]: The headers for the Content Safety API request.
        """
        return {
            "Ocp-Apim-Subscription-Key": self.subscription_key,
            "Content-Type": "application/json",
        }

    def build_request_body(
        self,
        media_type: MediaType,
        content: str,
        blocklists: list[str],
    ) -> dict:
        """
        Builds the request body for the Content Safety API request.
        Args:
        - media_type (MediaType): The type of media to analyze.
        - content (str): The content to analyze (for images, presumably the
          base64-encoded image bytes — confirm against the API reference).
        - blocklists (list[str]): The blocklists to use; only sent for text.
        Returns:
        - dict: The request body for the Content Safety API request.
        Raises:
        - ValueError: If the media type is not Text or Image.
        """
        if media_type == MediaType.Text:
            return {
                "text": content,
                "blocklistNames": blocklists,
            }
        elif media_type == MediaType.Image:
            return {"image": {"content": content}}
        else:
            raise ValueError(f"Invalid Media Type {media_type}")

    @staticmethod
    def _error_from_response(response: "requests.Response") -> DetectionError:
        """Build a DetectionError from a non-200 response, tolerating non-JSON or unexpected bodies."""
        code = str(response.status_code)
        message = response.text or response.reason or ""
        try:
            body = response.json()
        except ValueError:  # JSONDecodeError (e.g. an HTML gateway error page)
            body = None
        error = body.get("error") if isinstance(body, dict) else None
        if isinstance(error, dict):
            code = error.get("code", code)
            message = error.get("message", message)
        return DetectionError(code, message)

    def detect(
        self,
        media_type: MediaType,
        content: str,
        blocklists: Union[list[str], None] = None,
        timeout: Union[float, None] = 30.0,
    ) -> dict:
        """
        Detects unsafe content using the Content Safety API.
        Args:
        - media_type (MediaType): The type of media to analyze.
        - content (str): The content to analyze.
        - blocklists (list[str] | None): The blocklists to use for text analysis
          (None means no blocklists).
        - timeout (float | None): Seconds to wait for the HTTP request; None waits forever.
        Returns:
        - dict: The parsed JSON response from the Content Safety API.
        Raises:
        - DetectionError: If the API responds with a non-200 status.
        - requests.RequestException: On transport failures, including timeouts.
        """
        url = self.build_url(media_type)
        headers = self.build_headers()
        # A None default avoids sharing one mutable list across calls.
        request_body = self.build_request_body(
            media_type, content, [] if blocklists is None else blocklists
        )
        response = requests.post(
            url, headers=headers, data=json.dumps(request_body), timeout=timeout
        )
        # Check the status before parsing so error bodies that are not the
        # expected JSON still surface as DetectionError.
        if response.status_code != 200:
            raise self._error_from_response(response)
        return response.json()

    def get_detect_result_by_category(
        self, category: Category, detect_result: dict
    ) -> dict:
        """
        Gets the detection result for the given category from the Content Safety API response.
        Args:
        - category (Category): The category to get the detection result for.
        - detect_result (dict): The Content Safety API response.
        Returns:
        - dict: The "categoriesAnalysis" entry whose "category" equals ``category.name``.
        Raises:
        - ValueError: If the response has no entry for the category (including
          when "categoriesAnalysis" is missing or empty).
        """
        for res in detect_result.get("categoriesAnalysis") or []:
            if res.get("category") == category.name:
                return res
        raise ValueError(f"Invalid Category {category}")

    def make_decision(
        self,
        detection_result: dict,
        reject_thresholds: dict[Category, int],
    ) -> Decision:
        """
        Makes a decision based on the Content Safety API response and the specified reject thresholds.
        Users can customize their decision-making method.

        A category is rejected when its severity is >= its threshold; a threshold
        of -1 disables the check for that category. Any blocklist match forces
        the overall action to Reject (per-category actions are unaffected).
        Args:
        - detection_result (dict): The Content Safety API response.
        - reject_thresholds (dict[Category, int]): The reject thresholds for each category.
        Returns:
        - Decision: The decision based on the Content Safety API response and the specified reject thresholds.
        Raises:
        - ValueError: On an invalid threshold or a missing category result.
        """
        action_result: dict[Category, Action] = {}
        final_action = Action.Accept
        for category, threshold in reject_thresholds.items():
            if threshold not in (-1, 0, 2, 4, 6):
                raise ValueError("RejectThreshold can only be in (-1, 0, 2, 4, 6)")
            cate_detect_res = self.get_detect_result_by_category(
                category, detection_result
            )
            if cate_detect_res is None or "severity" not in cate_detect_res:
                raise ValueError(f"Can not find detection result for {category}")
            severity = cate_detect_res["severity"]
            action = (
                Action.Reject
                if threshold != -1 and severity >= threshold
                else Action.Accept
            )
            action_result[category] = action
            # Action values are ordered by strictness; keep the strictest.
            if action.value > final_action.value:
                final_action = action
        # A non-empty "blocklistsMatch" list means a blocklist item was hit.
        if detection_result.get("blocklistsMatch"):
            final_action = Action.Reject
        return Decision(final_action, action_result)
if __name__ == "__main__":
    # Replace the placeholders with your own values
    endpoint = "<endpoint>"
    subscription_key = "<subscription_key>"
    api_version = "2024-09-01"

    # Initialize the ContentSafety object
    content_safety = ContentSafety(endpoint, subscription_key, api_version)

    # Set the media type and blocklists
    media_type = MediaType.Text
    blocklists = []

    # Set the content to be tested
    content = "<test_content>"

    # Detect content safety; report API errors without a raw traceback
    try:
        detection_result = content_safety.detect(media_type, content, blocklists)
    except DetectionError as err:
        raise SystemExit(f"Detection failed: {err.code}: {err.message}")

    # Set the reject thresholds for each category
    reject_thresholds = {
        Category.Hate: 4,
        Category.SelfHarm: 4,
        Category.Sexual: 4,
        Category.Violence: 4,
    }

    # Make a decision based on the detection result and reject thresholds
    decision_result = content_safety.make_decision(detection_result, reject_thresholds)

    # Report the decision
    print(f"Suggested action: {decision_result.suggested_action.name}")
    for category, action in decision_result.action_by_category.items():
        print(f"  {category.name}: {action.name}")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment