Last active
October 25, 2023 15:49
-
-
Save Codelaby/4dc54511809f3e4cc7c80e6cc937cf3d to your computer and use it in GitHub Desktop.
BM25 algorithm in Swift
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import Foundation | |
/// A searchable text document.
///
/// Identity is defined by `id` alone: two documents with the same `id`
/// compare equal and hash identically even if their `content` differs,
/// so they occupy a single slot when used as dictionary keys.
struct Document: Hashable {
    let id: String
    let content: String

    static func == (lhs: Document, rhs: Document) -> Bool {
        lhs.id == rhs.id
    }

    func hash(into hasher: inout Hasher) {
        hasher.combine(id)
    }
}
/// Scores every document against `query` using Okapi BM25.
///
/// The query and every document go through one tokenization: `normalizeText`,
/// then a split on any run of whitespace. As a result, term frequency, document
/// frequency and document length all count the same tokens. Each document is
/// tokenized exactly once.
///
/// IDF uses the non-negative form `log(1 + (N - df + 0.5) / (df + 0.5))`.
/// The classic `log((N - df + 0.5) / (df + 0.5))` becomes negative once a term
/// occurs in more than half of the documents. That made documents *containing*
/// a common query term score below documents that do not contain it at all.
///
/// A query term repeated in `query` contributes once per occurrence.
///
/// - Parameters:
///   - query: Free-text query.
///   - documents: Corpus to score. Documents sharing an `id` collapse to one key.
///   - k1: Term-frequency saturation parameter (default 1.2).
///   - b: Length-normalization strength, normally in [0, 1] (default 0.75).
/// - Returns: A score for every document; 0 when no query term occurs in it.
func calculateBM25(query: String, documents: [Document], k1: Double = 1.2, b: Double = 0.75) -> [Document: Double] {
    let tokenize: (String) -> [String] = { text in
        normalizeText(text).split(whereSeparator: { $0.isWhitespace }).map(String.init)
    }

    // Splitting on whitespace runs drops empty tokens, so "a  b" yields no "" term.
    let queryTerms = tokenize(query)
    let tokenizedDocuments = documents.map(tokenize)

    let documentCount = Double(documents.count)
    let totalLength = tokenizedDocuments.reduce(0) { $0 + $1.count }
    let averageDocumentLength = documents.isEmpty ? 0 : Double(totalLength) / documentCount

    // Per-document token counts, built once and reused for every query term.
    let termCounts: [[String: Int]] = tokenizedDocuments.map { tokens in
        tokens.reduce(into: [String: Int]()) { counts, token in counts[token, default: 0] += 1 }
    }

    // Document frequency: how many documents contain each distinct query term as a token.
    var documentFrequencies: [String: Int] = [:]
    for term in Set(queryTerms) {
        documentFrequencies[term] = termCounts.filter { $0[term] != nil }.count
    }

    var scores: [Document: Double] = [:]
    for (index, document) in documents.enumerated() {
        let documentLength = Double(tokenizedDocuments[index].count)
        var score = 0.0
        for term in queryTerms {
            let termFrequency = Double(termCounts[index][term] ?? 0)
            // An absent term contributes 0. Skipping it also means the length ratio
            // below is only evaluated when this document (hence the average) is non-empty.
            guard termFrequency > 0 else { continue }
            let documentFrequency = Double(documentFrequencies[term] ?? 0)
            let idf = log(1 + (documentCount - documentFrequency + 0.5) / (documentFrequency + 0.5))
            let lengthNormalization = (1 - b) + b * (documentLength / averageDocumentLength)
            let numerator = (k1 + 1) * termFrequency
            let denominator = k1 * lengthNormalization + termFrequency
            score += idf * (numerator / denominator)
        }
        scores[document] = score
    }
    return scores
}
/// Counts, for each query term, how many documents contain it as a whole token.
///
/// Each document is tokenized with `normalizeText` followed by a whitespace
/// split, which matches how term frequency is counted. A term therefore only
/// matches complete words: "es" no longer matches inside "Estonia", and
/// diacritic-folded terms such as "espana" match "España".
/// Each distinct term is counted at most once per document, even if it is
/// repeated in `queryTerms`. Empty terms are ignored.
///
/// - Parameters:
///   - documents: Corpus to scan.
///   - queryTerms: Terms to count. They are compared verbatim, so they are
///     expected to be already normalized (as `normalizeText` output is).
/// - Returns: Document frequency per term. Terms that occur in no document are absent.
func calculateDocumentFrequencies(documents: [Document], queryTerms: [String]) -> [String: Int] {
    let uniqueTerms = Set(queryTerms.filter { !$0.isEmpty })
    var documentFrequencies: [String: Int] = [:]
    guard !uniqueTerms.isEmpty else { return documentFrequencies }

    for document in documents {
        // Tokenize each document once rather than once per term.
        let tokens = Set(
            normalizeText(document.content)
                .split(whereSeparator: { $0.isWhitespace })
                .map(String.init)
        )
        for term in uniqueTerms where tokens.contains(term) {
            documentFrequencies[term, default: 0] += 1
        }
    }
    return documentFrequencies
}
/// Returns how many times `term` occurs as a whole token in `document`.
///
/// The content is normalized with `normalizeText` and split on any run of
/// whitespace (spaces, tabs, newlines). Tokens separated by line breaks or
/// repeated spaces are therefore counted correctly. `term` is compared
/// verbatim and is expected to be normalized already. An empty term returns 0.
func calculateTermFrequency(term: String, document: Document) -> Double {
    // Without this guard, an empty term would count the empty chunks between consecutive separators.
    guard !term.isEmpty else { return 0 }
    let tokens = normalizeText(document.content).split(whereSeparator: { $0.isWhitespace })
    return Double(tokens.filter { $0 == term }.count)
}
/// Canonicalizes text for matching.
///
/// The text is lowercased and diacritics are folded away (e.g. "ñ" becomes "n").
/// Every character that is not an ASCII letter, digit or whitespace is then
/// deleted. Whitespace survives, so the result can still be split into words.
func normalizeText(_ text: String) -> String {
    let lowered = text.lowercased()
    let folded = lowered.folding(options: .diacriticInsensitive, locale: .current)
    return folded.replacingOccurrences(of: #"[^a-z0-9\s]+"#, with: "", options: .regularExpression)
}
/// Mean document length, measured in single-space-separated chunks.
///
/// Each document's length is `content.components(separatedBy: " ").count`.
/// Punctuation stays attached, and consecutive spaces produce empty chunks that
/// still count. Callers computing a length ratio must measure each document the
/// same way.
///
/// - Returns: The average length, or 0 for an empty collection (instead of 0/0 = NaN).
func calculateAverageDocumentLength(documents: [Document]) -> Double {
    guard !documents.isEmpty else { return 0 }
    // Lowercasing never adds or removes U+0020, so the previous `lowercased()` call could not change the count.
    let totalLength = documents.reduce(0) { total, document in
        total + document.content.components(separatedBy: " ").count
    }
    return Double(totalLength) / Double(documents.count)
}
// Example usage: score four airport descriptions against the query "Image Credit".
let documents = [
    Document(id: "doc1", content: "This is the first document. ID: CKG, Name: Chongqing Jiangbei International Airport, City: Chongqing, City 2: Jiangbei, Country: China, Description: Opened in 1990, Chongqing Jiangbei International replaced the older Baishiyi Airport. Its three-letter code comes from the city’s former English name: Chungking. Image Credit: byeangel, Image Credit Link: https://www.flickr.com/photos/byeangel/. State: Yubei District."),
    Document(id: "doc2", content: "This document is the second document. ID: LCG, Name: Aeroporto da Coruña-Alvedro, City: A Coruña, City 2: Galicia, Country: Spain, Description: Formerly known as Alvedro Airport, A Coruña Airport was inaugurated in 1963. Its airport code comes from the Spanish city of La Coruña, Galicia. Image Credit: Caneles, Image Credit Link: https://www.flickr.com/photos/94446676@N00/."),
    Document(id: "doc3", content: "And this is the third document. ID: MAD, Name: Aeropuerto Adolfo Suárez Madrid-Barajas, City: Madrid, City 2: Barajas, Country: Spain, Description: Spain’s largest airport honors former Prime Minister Adolfo Suárez, but its airport code honors its home in the capital city of Madrid. Image Credit: Anh Dinh, Image Credit Link: https://www.flickr.com/photos/anhgemus-photography/. Name in English: Adolfo Suárez Madrid–Barajas Airport."),
    Document(id: "doc4", content: "This is the fourth document. ID: BCN, Name: Aeropuerto de Barcelona-El Prat, City: Barcelona, City 2: El Prat de Llobregat, Country: Spain, Description: Barcelona’s first airfield was built in 1916, but a new location in El Prat was chosen in 1918. The airport now uses the code BCN which stands for Barcelona. Image Credit: Camilo Rueda López, Image Credit Link: https://www.flickr.com/photos/kozumel/."),
]
let query = "Image Credit"
let scores = calculateBM25(query: query, documents: documents)
// Dictionary iteration order is unspecified, so rank explicitly: highest score first,
// ties broken by document id so the output is identical on every run.
let rankedScores = scores.sorted { lhs, rhs in
    lhs.value != rhs.value ? lhs.value > rhs.value : lhs.key.id < rhs.key.id
}
for (document, score) in rankedScores {
    print("Document ID: \(document.id), Score: \(score)")
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import Foundation | |
/// A searchable text document.
///
/// Identity is defined by `id` alone: two documents with the same `id`
/// compare equal and hash identically even if their `content` differs,
/// so they occupy a single slot when used as dictionary keys.
struct Document: Hashable {
    let id: String
    let content: String

    static func == (lhs: Document, rhs: Document) -> Bool {
        lhs.id == rhs.id
    }

    func hash(into hasher: inout Hasher) {
        hasher.combine(id)
    }
}
/// Scores every document against `query` using Okapi BM25.
///
/// The query and every document go through one tokenization: `normalizeText`,
/// then a split on any run of whitespace. As a result, term frequency, document
/// frequency and document length all count the same tokens. Each document is
/// tokenized exactly once.
///
/// IDF uses the non-negative form `log(1 + (N - df + 0.5) / (df + 0.5))`.
/// The classic `log((N - df + 0.5) / (df + 0.5))` becomes negative once a term
/// occurs in more than half of the documents. That made documents *containing*
/// a common query term score below documents that do not contain it at all.
///
/// A query term repeated in `query` contributes once per occurrence.
///
/// - Parameters:
///   - query: Free-text query.
///   - documents: Corpus to score. Documents sharing an `id` collapse to one key.
///   - k1: Term-frequency saturation parameter (default 1.2).
///   - b: Length-normalization strength, normally in [0, 1] (default 0.75).
/// - Returns: A score for every document; 0 when no query term occurs in it.
func calculateBM25(query: String, documents: [Document], k1: Double = 1.2, b: Double = 0.75) -> [Document: Double] {
    let tokenize: (String) -> [String] = { text in
        normalizeText(text).split(whereSeparator: { $0.isWhitespace }).map(String.init)
    }

    // Splitting on whitespace runs drops empty tokens, so "a  b" yields no "" term.
    let queryTerms = tokenize(query)
    let tokenizedDocuments = documents.map(tokenize)

    let documentCount = Double(documents.count)
    let totalLength = tokenizedDocuments.reduce(0) { $0 + $1.count }
    let averageDocumentLength = documents.isEmpty ? 0 : Double(totalLength) / documentCount

    // Per-document token counts, built once and reused for every query term.
    let termCounts: [[String: Int]] = tokenizedDocuments.map { tokens in
        tokens.reduce(into: [String: Int]()) { counts, token in counts[token, default: 0] += 1 }
    }

    // Document frequency: how many documents contain each distinct query term as a token.
    var documentFrequencies: [String: Int] = [:]
    for term in Set(queryTerms) {
        documentFrequencies[term] = termCounts.filter { $0[term] != nil }.count
    }

    var scores: [Document: Double] = [:]
    for (index, document) in documents.enumerated() {
        let documentLength = Double(tokenizedDocuments[index].count)
        var score = 0.0
        for term in queryTerms {
            let termFrequency = Double(termCounts[index][term] ?? 0)
            // An absent term contributes 0. Skipping it also means the length ratio
            // below is only evaluated when this document (hence the average) is non-empty.
            guard termFrequency > 0 else { continue }
            let documentFrequency = Double(documentFrequencies[term] ?? 0)
            let idf = log(1 + (documentCount - documentFrequency + 0.5) / (documentFrequency + 0.5))
            let lengthNormalization = (1 - b) + b * (documentLength / averageDocumentLength)
            let numerator = (k1 + 1) * termFrequency
            let denominator = k1 * lengthNormalization + termFrequency
            score += idf * (numerator / denominator)
        }
        scores[document] = score
    }
    return scores
}
/// Counts, for each query term, how many documents contain it as a whole token.
///
/// Each document is tokenized with `normalizeText` followed by a whitespace
/// split, which matches how term frequency is counted. A term therefore only
/// matches complete words: "es" no longer matches inside "Estonia", and
/// diacritic-folded terms such as "espana" match "España".
/// Each distinct term is counted at most once per document, even if it is
/// repeated in `queryTerms`. Empty terms are ignored.
///
/// - Parameters:
///   - documents: Corpus to scan.
///   - queryTerms: Terms to count. They are compared verbatim, so they are
///     expected to be already normalized (as `normalizeText` output is).
/// - Returns: Document frequency per term. Terms that occur in no document are absent.
func calculateDocumentFrequencies(documents: [Document], queryTerms: [String]) -> [String: Int] {
    let uniqueTerms = Set(queryTerms.filter { !$0.isEmpty })
    var documentFrequencies: [String: Int] = [:]
    guard !uniqueTerms.isEmpty else { return documentFrequencies }

    for document in documents {
        // Tokenize each document once rather than once per term.
        let tokens = Set(
            normalizeText(document.content)
                .split(whereSeparator: { $0.isWhitespace })
                .map(String.init)
        )
        for term in uniqueTerms where tokens.contains(term) {
            documentFrequencies[term, default: 0] += 1
        }
    }
    return documentFrequencies
}
/// Returns how many times `term` occurs as a whole token in `document`.
///
/// The content is normalized with `normalizeText` and split on any run of
/// whitespace (spaces, tabs, newlines). Tokens separated by line breaks or
/// repeated spaces are therefore counted correctly. `term` is compared
/// verbatim and is expected to be normalized already. An empty term returns 0.
func calculateTermFrequency(term: String, document: Document) -> Double {
    // Without this guard, an empty term would count the empty chunks between consecutive separators.
    guard !term.isEmpty else { return 0 }
    let tokens = normalizeText(document.content).split(whereSeparator: { $0.isWhitespace })
    return Double(tokens.filter { $0 == term }.count)
}
/// Canonicalizes text for matching.
///
/// The text is lowercased and diacritics are folded away (e.g. "ñ" becomes "n").
/// Every character that is not an ASCII letter, digit or whitespace is then
/// deleted. Whitespace survives, so the result can still be split into words.
func normalizeText(_ text: String) -> String {
    let lowered = text.lowercased()
    let folded = lowered.folding(options: .diacriticInsensitive, locale: .current)
    return folded.replacingOccurrences(of: #"[^a-z0-9\s]+"#, with: "", options: .regularExpression)
}
/// Mean document length, measured in single-space-separated chunks.
///
/// Each document's length is `content.components(separatedBy: " ").count`.
/// Punctuation stays attached, and consecutive spaces produce empty chunks that
/// still count. Callers computing a length ratio must measure each document the
/// same way.
///
/// - Returns: The average length, or 0 for an empty collection (instead of 0/0 = NaN).
func calculateAverageDocumentLength(documents: [Document]) -> Double {
    guard !documents.isEmpty else { return 0 }
    // Lowercasing never adds or removes U+0020, so the previous `lowercased()` call could not change the count.
    let totalLength = documents.reduce(0) { total, document in
        total + document.content.components(separatedBy: " ").count
    }
    return Double(totalLength) / Double(documents.count)
}
// Example usage: look up a country by its dial code.
// The query "+34" normalizes to the token "34", which appears only in Spain's
// entry (doc18), so that is the only document with a non-zero score.
// The two-letter country code can be queried the same way, e.g. "es".
let documents = [
    Document(id: "doc1", content: "name: Bangladesh, dialCode: +880, code: BD, displayName: Bangladés"),
    Document(id: "doc2", content: "name: Estonia, dialCode: +372, code: EE, displayName: Estonia"),
    Document(id: "doc3", content: "name: French Guiana, dialCode: +594, code: GF, displayName: Guayana Francesa"),
    Document(id: "doc4", content: "name: French Polynesia, dialCode: +689, code: PF, displayName: Polinesia Francesa"),
    Document(id: "doc5", content: "name: Guernsey, dialCode: +44, code: GG, displayName: Guernesey"),
    Document(id: "doc6", content: "name: Indonesia, dialCode: +62, code: ID, displayName: Indonesia"),
    Document(id: "doc7", content: "name: Lesotho, dialCode: +266, code: LS, displayName: Lesoto"),
    Document(id: "doc8", content: "name: Maldives, dialCode: +960, code: MV, displayName: Maldivas"),
    Document(id: "doc9", content: "name: Micronesia, Federated States of Micronesia, dialCode: +691, code: FM, displayName: Micronesia"),
    Document(id: "doc10", content: "name: Netherlands, dialCode: +31, code: NL, displayName: Países Bajos"),
    Document(id: "doc11", content: "name: Palestinian Territory, Occupied, dialCode: +970, code: PS, displayName: Territorios Palestinos"),
    Document(id: "doc12", content: "name: Philippines, dialCode: +63, code: PH, displayName: Filipinas"),
    Document(id: "doc13", content: "name: Saint Kitts and Nevis, dialCode: +1869, code: KN, displayName: San Cristóbal y Nieves"),
    Document(id: "doc14", content: "name: Saint Vincent and the Grenadines, dialCode: +1784, code: VC, displayName: San Vicente y las Granadinas"),
    Document(id: "doc15", content: "name: Seychelles, dialCode: +248, code: SC, displayName: Seychelles"),
    Document(id: "doc16", content: "name: Slovakia, dialCode: +421, code: SK, displayName: Eslovaquia"),
    Document(id: "doc17", content: "name: Slovenia, dialCode: +386, code: SI, displayName: Eslovenia"),
    Document(id: "doc18", content: "name: Spain, dialCode: +34, code: ES, displayName: España"),
    Document(id: "doc19", content: "name: Swaziland, dialCode: +268, code: SZ, displayName: Esuatini"),
    Document(id: "doc20", content: "name: Timor-Leste, dialCode: +670, code: TL, displayName: Timor Oriental"),
    Document(id: "doc21", content: "name: United Arab Emirates, dialCode: +971, code: AE, displayName: Emiratos Árabes Unidos"),
    Document(id: "doc22", content: "name: United States, dialCode: +1, code: US, displayName: Estados Unidos"),
    Document(id: "doc23", content: "name: Virgin Islands, British, dialCode: +1284, code: VG, displayName: Islas Vírgenes Británicas"),
    Document(id: "doc24", content: "name: Virgin Islands, U.S., dialCode: +1340, code: VI, displayName: Islas Vírgenes de EE. UU."),
]
let query = "+34"
// Previously referenced the undefined name `documents2`, which did not compile.
let scores = calculateBM25(query: query, documents: documents)
// Dictionary iteration order is unspecified, so rank explicitly: highest score first,
// ties broken by document id so the output is identical on every run.
let rankedScores = scores.sorted { lhs, rhs in
    lhs.value != rhs.value ? lhs.value > rhs.value : lhs.key.id < rhs.key.id
}
for (document, score) in rankedScores {
    print("Document ID: \(document.id), Score: \(score)")
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment