Spaces:
Paused
Paused
| """Test RAGLite's reranking functionality.""" | |
| import pytest | |
| from rerankers.models.ranker import BaseRanker | |
| from raglite import RAGLiteConfig, hybrid_search, rerank_chunks, retrieve_chunks | |
| from raglite._database import Chunk | |
| from raglite._flashrank import PatchedFlashRankRanker as FlashRankRanker | |
@pytest.fixture(
    params=[
        pytest.param(None, id="no_reranker"),
        pytest.param(FlashRankRanker("ms-marco-MiniLM-L-12-v2", verbose=0), id="flashrank_english"),
        pytest.param(
            (
                ("en", FlashRankRanker("ms-marco-MiniLM-L-12-v2", verbose=0)),
                ("other", FlashRankRanker("ms-marco-MultiBERT-L-12", verbose=0)),
            ),
            id="flashrank_multilingual",
        ),
    ],
)
def reranker(
    request: pytest.FixtureRequest,
) -> BaseRanker | tuple[tuple[str, BaseRanker], ...] | None:
    """Provide a reranker configuration to test RAGLite with.

    The fixture is parametrized so that every shape accepted by `RAGLiteConfig.reranker` is
    exercised:
    - no reranker (`None`);
    - a single `BaseRanker`;
    - a tuple of `(language, ranker)` pairs.

    Returns
    -------
        The current parameter, taken from `request.param`.
    """
    # NOTE(review): the decorator appeared to be missing from this copy of the file, leaving
    # `request.param` undefined. The params above are reconstructed to cover each member of the
    # return type; confirm the model names against the project's RAGLiteConfig defaults.
    reranker: BaseRanker | tuple[tuple[str, BaseRanker], ...] | None = request.param
    return reranker
def test_reranker(
    raglite_test_config: RAGLiteConfig,
    reranker: BaseRanker | tuple[tuple[str, BaseRanker], ...] | None,
) -> None:
    """Test that reranking recovers the top search result from an inverted chunk order.

    The chunks found by a hybrid search are reversed and passed to `rerank_chunks` twice:
    once as `Chunk` objects and once as chunk ids. When a reranker is configured (and the
    embedder is not `text-embedding-3-small`), the first reranked chunk must equal the first
    chunk returned by the search.
    """
    # Rebuild the config with the parametrized reranker; only db_url and embedder are carried over.
    raglite_test_config = RAGLiteConfig(
        db_url=raglite_test_config.db_url,
        embedder=raglite_test_config.embedder,
        reranker=reranker,
    )
    # Search for a query.
    query = "What does it mean for two events to be simultaneous?"
    chunk_ids, _ = hybrid_search(query, num_results=3, config=raglite_test_config)
    # Retrieve the chunks and check they come back in the same order as their ids.
    chunks = retrieve_chunks(chunk_ids, config=raglite_test_config)
    assert all(isinstance(chunk, Chunk) for chunk in chunks)
    assert all(chunk_id == chunk.id for chunk_id, chunk in zip(chunk_ids, chunks, strict=True))
    # The top-chunk check only applies when a reranker is configured. It is skipped for the
    # text-embedding-3-small embedder, presumably because that embedder's search ranking does
    # not agree with the reranker's on this query. NOTE(review): confirm.
    check_top_chunk = (
        reranker is not None and "text-embedding-3-small" not in raglite_test_config.embedder
    )
    # Rerank the chunks given an inverted chunk order.
    reranked_chunks = rerank_chunks(query, chunks[::-1], config=raglite_test_config)
    if check_top_chunk:
        assert reranked_chunks[0] == chunks[0]
    # Test that we can also rerank given the chunk_ids only.
    reranked_chunks = rerank_chunks(query, chunk_ids[::-1], config=raglite_test_config)
    if check_top_chunk:
        assert reranked_chunks[0] == chunks[0]