import torch
from transformers import BertModel, BertTokenizer


def create_embeddings(chunks_with_metadata):
    """
    Generate embeddings for each chunk using BERT.

    Parameters:
    - chunks_with_metadata (list): A list of dictionaries containing chunk text and page number.

    Returns:
    - list: A list of dictionaries with embeddings and metadata.
    """
    # Load pre-trained BERT tokenizer and model
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = BertModel.from_pretrained('bert-base-uncased')

    vectors = []
    for idx, chunk_data in enumerate(chunks_with_metadata):
        # Tokenize the chunk and run it through BERT
        inputs = tokenizer(chunk_data['chunk'], return_tensors='pt', padding=True, truncation=True, max_length=512)
        with torch.no_grad():
            outputs = model(**inputs)

        # Mean-pool the token embeddings to get a single chunk embedding
        embeddings = outputs.last_hidden_state.mean(dim=1).squeeze().tolist()

        # Construct the vector dictionary
        vector_data = {
            'id': f'vec{idx + 1}',
            'values': embeddings,
            'metadata': {
                'chunk': chunk_data['chunk'],
                'page_number': chunk_data['page_number']
            }
        }
        vectors.append(vector_data)

    return vectors
def create_chunks_with_metadata(pages):
    """
    Split page content into chunks (paragraphs) and store them with page-number metadata,
    ensuring each chunk does not exceed a specified token limit.

    Parameters:
    - pages (list): A list of strings where each string represents the content of one page.

    Returns:
    - list: A list of dictionaries. Each dictionary has two keys: 'chunk' and 'page_number'.
    """
    # Leave room for the [CLS] and [SEP] tokens BERT adds to each sequence
    MAX_TOKENS = 510
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    chunks_with_metadata = []
    for page_number, page_content in enumerate(pages, start=1):
        # Split the page content on blank lines to get paragraphs
        paragraphs = [p.strip() for p in page_content.split('\n\n') if p.strip()]
        for paragraph in paragraphs:
            # Tokenize the paragraph
            tokens = tokenizer.tokenize(paragraph)
            # If the paragraph is too long, split it into smaller chunks
            while tokens:
                chunk_tokens = tokens[:MAX_TOKENS]
                tokens = tokens[MAX_TOKENS:]
                chunk_text = tokenizer.convert_tokens_to_string(chunk_tokens)
                chunk_data = {
                    'chunk': chunk_text,
                    'page_number': page_number
                }
                chunks_with_metadata.append(chunk_data)

    return chunks_with_metadata
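

# A minimal usage sketch, not part of the original gist. It assumes a hypothetical
# `pages` list of plain-text page strings (in practice these would come from a PDF
# or similar source) and simply chains the two functions above.
if __name__ == '__main__':
    pages = [
        "Introduction to the topic.\n\nFirst page, second paragraph.",
        "Second page with a single paragraph of content.",
    ]

    chunks = create_chunks_with_metadata(pages)
    vectors = create_embeddings(chunks)

    print(f"Created {len(vectors)} vectors from {len(chunks)} chunks")
    print(vectors[0]['id'], vectors[0]['metadata']['page_number'])
    # bert-base-uncased produces 768-dimensional embeddings
    print(f"Embedding dimension: {len(vectors[0]['values'])}")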