Instantly share code, notes, and snippets.
Created
July 15, 2023 05:51
-
Star
0
(0)
You must be signed in to star a gist -
Fork
0
(0)
You must be signed in to fork a gist
-
Save PttCodingMan/7cdb47d857c0cad987ade593ae86a765 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import logging | |
import re | |
from queue import PriorityQueue | |
from sentence_transformers import SentenceTransformer, util | |
from sklearn.metrics.pairwise import cosine_similarity | |
from post_walker import post_walker | |
# Log lines look like "[0715 05:51:00] message".
logging.basicConfig(
    level=logging.INFO,
    format='[%(asctime)s] %(message)s',
    datefmt='%m%d %H:%M:%S',
)
# Module-level logger: the functions in this file log through the global
# name `logger`, so it must exist even when the file is imported instead of
# being run as a script.
logger = logging.getLogger(__name__)

# Sentence-embedding model whose vectors are compared to score post similarity.
model = SentenceTransformer('paraphrase-multilingual-MiniLM-L12-v2')
# Maximum number of related posts attached to each post.
related_top = 5
# A post whose raw text contains any of these markers is skipped entirely.
block_post_flags = ['top: true', 'hidden: true']
class Post:
    """A blog post together with the text used to compare it to other posts.

    Attributes:
        title: Post title.
        link: The post's abbrlink, used to build links to it.
        tags: Tag strings for the post.
        raw_data: The raw text the post was built from.
        analytics_content: Tags, title and body joined by spaces, with any
            previously generated '## 相關文章' section cut off.
        related_posts: ``Post`` objects judged related to this one.
        related_score: Initialised to 0; not modified by this class.
    """

    def __init__(self, title: str, abbrlink: str, tags: list[str], raw_data: str):
        """Store the metadata and derive ``analytics_content`` from ``raw_data``.

        Args:
            title: Post title.
            abbrlink: Short link identifier of the post.
            tags: Tag strings; they are joined with spaces into the content.
            raw_data: Raw post text. Everything after the last '---' is
                treated as the body; if there is no '---' the whole text is.
        """
        self.title = title
        self.link = abbrlink
        self.tags = tags
        self.raw_data = raw_data
        self.related_posts = []
        self.related_score = 0

        # rfind() returns -1 when there is no '---'; blindly adding 3 would
        # then chop the first two characters off the text.
        delimiter_at = raw_data.rfind('---')
        body = raw_data if delimiter_at == -1 else raw_data[delimiter_at + 3:]

        content = ' '.join(tags) + ' ' + title + ' ' + body.strip()

        # Drop an already generated related-posts section so old links do not
        # influence the similarity of this post.
        heading_at = content.rfind('## 相關文章')
        if heading_at != -1:
            content = content[:heading_at]
        self.analytics_content = content

    def add_related_post(self, related_post: 'Post') -> None:
        """Append ``related_post`` to this post's ``related_posts`` list.

        ``related_posts`` is a plain list of ``Post`` objects, so the post is
        appended as-is; no similarity score is computed or stored here.
        """
        self.related_posts.append(related_post)
def count_must_related_post():
    """Attach the most similar posts to every collected post.

    Encodes the ``analytics_content`` of every post in ``posts`` with
    ``model``, computes all pairwise cosine similarities, and appends up to
    ``related_top`` of the highest-scoring other posts (best first) to each
    post's ``related_posts`` list. Does nothing when no post was collected.
    """
    # Resolved here rather than through the global `logger`, which is only
    # bound when the file runs as a script; this returns the same logger.
    log = logging.getLogger(__name__)

    all_posts = list(posts.values())
    if not all_posts:
        # Nothing to encode or compare.
        return

    sentences = [post.analytics_content for post in all_posts]
    embedding = model.encode(sentences, convert_to_tensor=False)
    cosine_scores = util.cos_sim(embedding, embedding)

    for i, current_post in enumerate(all_posts):
        # Negated scores make an ascending sort put the most similar post
        # first; the index breaks ties in favour of the earlier post.
        ranked = sorted(
            (-cosine_scores[i][j].item(), j)
            for j in range(len(all_posts))
            if j != i  # a post is never related to itself
        )

        log.info('post %s top %s', current_post.title, related_top)
        for negated_score, j in ranked[:related_top]:
            related_post = all_posts[j]
            log.info('related post %s score %s', related_post.title, -negated_score)
            current_post.related_posts.append(related_post)
# ---------------------------------------------------------------------------
# Registries shared by collect_posts(), count_must_related_post() and
# add_related_post().
# ---------------------------------------------------------------------------
# Post objects keyed by their analytics_content text.
posts = {}
# The same Post objects keyed by the raw text they were built from.
raw_posts = {}
def is_block_post(raw_data: str) -> bool:
    """Tell whether ``raw_data`` contains any marker from ``block_post_flags``.

    The check is a plain substring test against the whole text.
    """
    return any(marker in raw_data for marker in block_post_flags)
def collect_posts(raw_data: str) -> str:
    """Register a post in ``posts`` and ``raw_posts`` when it has usable metadata.

    A post is registered only if it is not a block post and its text contains
    a ``title: ...`` line, an ``abbrlink: ...`` line and a ``tags:`` line
    followed by zero or more ``- tag`` items. Anything else is skipped.

    Args:
        raw_data: Raw text of one post.

    Returns:
        ``raw_data`` unchanged, in every case.
    """
    if is_block_post(raw_data):
        return raw_data

    title_match = re.search(r"title: (.+)", raw_data)
    abbrlink_match = re.search(r"abbrlink: (.+)", raw_data)
    # Items may be indented with spaces/tabs only, and a line starting with
    # '---' ends the list. With `\s*` the pattern crossed line breaks and
    # swallowed the closing '---' plus any bullet list following it.
    tags_match = re.search(r'tags:\n((?:[ \t]*-(?!--)[ \t]*.+\n)*)', raw_data)
    if not (title_match and abbrlink_match and tags_match):
        return raw_data

    title = title_match.group(1).strip()

    link = abbrlink_match.group(1).strip()
    # Unwrap the value only when it is really enclosed in a matching pair of
    # single or double quotes.
    if len(link) >= 2 and link[0] == link[-1] and link[0] in '\'"':
        link = link[1:-1]

    tags = []
    for line in tags_match.group(1).splitlines():
        # Every matched line starts with '-' after its indentation.
        tag = line.strip()[1:].strip().lower()
        if tag:
            tags.append(tag)
    tags.sort()

    post = Post(title, link, tags, raw_data)
    posts[post.analytics_content] = post
    raw_posts[post.raw_data] = post
    return raw_data
def add_related_post(raw_data: str) -> str:
    """Return ``raw_data`` with a freshly generated '## 相關文章' section.

    Posts that were not registered in ``raw_posts`` are returned untouched.
    For registered posts any existing '## 相關文章' section is removed first;
    a new one listing ``related_posts`` as Markdown links is then appended.
    When the post has no related posts no heading is written at all.

    Args:
        raw_data: Raw text of one post, exactly as it was seen when collected.

    Returns:
        The post text to use in place of ``raw_data``.
    """
    current_post = raw_posts.get(raw_data)
    if current_post is None:
        return raw_data

    heading = '## 相關文章'
    if heading in raw_data:
        # Cut the old section so it is replaced instead of duplicated.
        raw_data = raw_data[:raw_data.rfind(heading)].strip()

    if is_block_post(raw_data):
        return raw_data

    related = current_post.related_posts
    if not related:
        # Avoid leaving an empty heading with nothing underneath it.
        return raw_data

    links = ''.join(f'- [{post.title}](/{post.link})\n' for post in related)
    return f'{raw_data}\n\n{heading}\n\n{links}'
if __name__ == '__main__':
    # Bound at module scope on purpose: the functions above log through the
    # global name `logger`.
    logger = logging.getLogger(__name__)

    # Step 1: hand every post to collect_posts to fill the registries.
    # NOTE(review): post_walker is presumed to call the callback once per
    # post with its raw text - confirm against post_walker.
    post_walker(collect_posts)

    # Step 2: score the collected posts against each other.
    count_must_related_post()

    # Step 3: walk the posts again to produce the related-posts sections.
    post_walker(add_related_post)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment