Created
October 13, 2024 05:32
-
-
Save tspannhw/d51abdd385bdb3340133a34c0f872934 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import time

from pymilvus import AnnSearchRequest, RRFRanker, WeightedRanker

# Multi-vector hybrid search: five ANN requests (text and image embeddings from
# several encoders) are sent to one collection and fused by a WeightedRanker.
#
# NOTE(review): this snippet expects `model`, `textmodel`, `splade_ef`,
# `bge_m3`, `collection`, `Image` and `requests` to already be in scope
# (e.g. from earlier notebook cells) -- none of them are created here.

QUERY = "creepy skull head ghost"
IMAGE_URL = "http://192.168.1.166:9000/images/ghost4.jpg"
# Without a timeout, requests waits forever on an unreachable host.
IMAGE_TIMEOUT_SECONDS = 10
# Number of candidates each individual ANN request contributes to the rerank.
PER_REQUEST_LIMIT = 3
# Number of fused results returned by the hybrid search.
FINAL_LIMIT = 10


def _search_params(data, anns_field, metric_type, limit=PER_REQUEST_LIMIT):
    """Return the keyword arguments for one ``AnnSearchRequest``.

    Args:
        data: Query vectors to search with.
        anns_field: Name of the collection's vector field to search.
        metric_type: Metric passed to Milvus for this field (e.g. "COSINE").
        limit: Maximum number of hits this request returns.
    """
    return {
        "data": data,
        "anns_field": anns_field,
        "param": {"metric_type": metric_type},
        "limit": limit,
    }


# Use WeightedRanker to combine results with specified weights.
# There is one weight per request in `reqs` below (five of each); see the
# Milvus reranking docs for how weights map onto requests.
rerank = WeightedRanker(0.6, 0.7, 0.8, 0.9, 0.6)

queries = [QUERY]
query_embeddings = model.encode(queries)

# Request 0: text query encoded by `model`, searched against "vector".
search_param_0 = _search_params(query_embeddings, "vector", "COSINE")
request_0 = AnnSearchRequest(**search_param_0)

# Request 1: text query encoded by `textmodel`, searched against "text_vector2".
search_param_1 = _search_params(textmodel.encode([QUERY]), "text_vector2", "COSINE")
request_1 = AnnSearchRequest(**search_param_1)

# Request 2: text query encoded by `splade_ef`, searched against "text_vector"
# with inner product (presumably a sparse vector field -- confirm in schema).
search_param_2 = _search_params(splade_ef.encode_queries([QUERY]), "text_vector", "IP")
request_2 = AnnSearchRequest(**search_param_2)

# Request 3: a reference image encoded by `model`, searched against the same
# "vector" field as request 0.
response = requests.get(IMAGE_URL, stream=True, timeout=IMAGE_TIMEOUT_SECONDS)
# Fail here on a 4xx/5xx instead of handing an error page to the image decoder.
response.raise_for_status()
images = Image.open(response.raw)
image_embeddings = model.encode([images])
search_param_3 = _search_params(image_embeddings, "vector", "COSINE")
request_3 = AnnSearchRequest(**search_param_3)

# Request 4: the "dense" output of `bge_m3`, searched against "text_vector3".
bgem3_queries = bge_m3.encode_queries([QUERY])
search_param_4 = _search_params(bgem3_queries["dense"], "text_vector3", "L2")
request_4 = AnnSearchRequest(**search_param_4)

reqs = [request_0, request_1, request_2, request_3, request_4]

collection.load()

# Time only the hybrid search call. perf_counter() is monotonic, so the
# measured interval is unaffected by system clock adjustments.
t0 = time.perf_counter()
results = collection.hybrid_search(
    reqs,
    rerank,
    output_fields=["s3path", "timestamp", "ghostclass", "location"],
    limit=FINAL_LIMIT,
)
t1 = time.perf_counter()

print(f"search latency: {round(t1-t0, 4)} seconds\n\n")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment