# test_preprocess.py
import hashlib
import json
import os
from pathlib import Path
import anndata as ad
import numpy as np
import pandas as pd
import pytest
from scipy.sparse import csr_matrix
from teddy.data_processing.preprocessing.preprocess import (
compute_and_save_medians,
filter_cells_by_gene_counts,
filter_cells_by_mitochondrial_fraction,
filter_highly_variable_genes,
initialize_processed_layer,
log_transform_layer,
normalize_data_inplace,
preprocess,
set_raw_if_necessary,
update_metadata,
)
@pytest.fixture
def synthetic_anndata():
"""
Returns a small synthetic AnnData object for testing.
"""
X = np.array([[0, 2, 3], [4, 0, 6], [7, 8, 0], [0, 0, 1]], dtype=float)
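    # 4 cells × 3 genes; per-cell totals are 5, 10, 15, 1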
var = pd.DataFrame(index=["geneA", "geneB", "geneC"])
obs = pd.DataFrame(index=["cell1", "cell2", "cell3", "cell4"])
adata = ad.AnnData(X=csr_matrix(X), var=var, obs=obs)
return adata
def test_set_raw_if_necessary(synthetic_anndata):
data = synthetic_anndata.copy()
    # The synthetic matrix contains only integer values, so the function
    # should take a raw snapshot of the counts.
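    # Sanity-check the premise: every entry in X is integer-valued
    assert np.all(np.mod(data.X.toarray(), 1) == 0)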
result = set_raw_if_necessary(data)
assert result is not None
assert result.raw is not None
assert (result.raw.X != 0).toarray().any()
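    # A stricter check, assuming the function stores X unmodified in .raw
    raw_X = result.raw.X
    raw_dense = raw_X.toarray() if hasattr(raw_X, "toarray") else np.asarray(raw_X)
    assert np.array_equal(raw_dense, synthetic_anndata.X.toarray())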
def test_initialize_processed_layer(synthetic_anndata):
data = synthetic_anndata.copy()
    # Manually set data.raw to integer count data so the function can copy it
data.raw = data.copy()
# processed layer does not exist initially
assert "processed" not in data.layers
data = initialize_processed_layer(data)
assert "processed" in data.layers
    # The processed layer should be an exact copy of data.raw.X
    assert np.array_equal(data.layers["processed"].toarray(), data.raw.X.toarray())
def test_filter_cells_by_gene_counts(synthetic_anndata):
data = synthetic_anndata.copy()
data.raw = data # So that processed can be set
data.layers["processed"] = data.X.copy()
    # Filter out cells whose total counts fall below min_count=2
result = filter_cells_by_gene_counts(data, min_count=2)
    # cell4 (total = 1) is the only cell below the threshold, so it is removed
assert result.n_obs == 3
assert "cell4" not in result.obs_names
def test_filter_cells_by_mitochondrial_fraction(synthetic_anndata):
data = synthetic_anndata.copy()
data.raw = data # needed so we can add a processed layer
data.var["feature_name"] = ["MT-GENE1", "geneB", "geneC"] # Suppose first gene is a MT gene
data.layers["processed"] = data.X.copy()
    # Cells whose mitochondrial fraction exceeds max_mito_prop=0.25 are removed
result = filter_cells_by_mitochondrial_fraction(data, max_mito_prop=0.25)
    # Fractions are 0.0, 0.4, ~0.467, and 0.0, so cell2 and cell3 exceed
    # the threshold and only cell1 and cell4 remain
assert result.n_obs == 2
assert "cell2" not in result.obs_names
assert "cell3" not in result.obs_names
assert "cell1" in result.obs_names
assert "cell4" in result.obs_names
def test_normalize_data_inplace(synthetic_anndata):
data = synthetic_anndata.copy()
data.layers["processed"] = data.X.copy()
    # Ensure the layer is CSR (the fixture already provides one, but be explicit)
if not isinstance(data.layers["processed"], csr_matrix):
data.layers["processed"] = csr_matrix(data.layers["processed"])
normalize_data_inplace(data.layers["processed"], 1e4)
    # After normalization, every row should sum to the target total (1e4)
    row_sums = np.array(data.layers["processed"].sum(axis=1)).flatten()
    for s in row_sums:
        assert s == pytest.approx(10000.0, rel=1e-4)
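    # The layer should still be a CSR matrix after in-place normalization
    assert isinstance(data.layers["processed"], csr_matrix)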
def test_log_transform_layer(synthetic_anndata):
data = synthetic_anndata.copy()
data.layers["processed"] = data.X.copy()
log_transform_layer(data, "processed")
    # Verify the layer was log1p-transformed by spot-checking one entry:
    # cell2/geneA was 4, so the transformed value should be log(1 + 4) ≈ 1.609
    val_after = data.layers["processed"][1, 0]
    assert val_after == pytest.approx(np.log1p(4), rel=1e-5)
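    # log1p(0) == 0, so entries that were zero should remain zero
    assert data.layers["processed"][0, 0] == 0.0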
def test_compute_and_save_medians(tmp_path, synthetic_anndata):
"""
Check that compute_and_save_medians writes out a JSON file with medians of non-zero entries.
"""
data = synthetic_anndata.copy()
data.layers["processed"] = data.X.copy()
    # Write the AnnData out so the medians path can be derived from a real file
mock_data_path = str(tmp_path / "test.h5ad")
data.write_h5ad(mock_data_path)
hyperparams = {
"load_dir": str(tmp_path),
"save_dir": str(tmp_path),
"median_column": "index",
}
compute_and_save_medians(data, mock_data_path, hyperparams)
# check that the .json file was written
median_path = mock_data_path.replace(".h5ad", "_medians.json")
assert os.path.exists(median_path)
with open(median_path, "r") as f:
mdict = json.load(f)
    # Expect one median per gene
assert len(mdict) == data.n_vars
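    # Assuming the JSON is keyed by the median_column values (here the var
    # index), the keys should be the gene names
    assert set(mdict.keys()) == set(data.var_names)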
def test_update_metadata(synthetic_anndata):
data = synthetic_anndata.copy()
metadata = {}
hyperparams = {"some": "arg"}
data.obs["some_col"] = [1, 2, 3, 4] # just to show we have 4 cells
new_meta = update_metadata(metadata, data, hyperparams)
assert new_meta.get("cell_count", None) == 4
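    # Tolerate either spelling of the processing-args key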
assert "processings_args" in new_meta or "processing_args" in new_meta
def test_preprocess_pipeline_end_to_end(tmp_path):
"""
    Test the entire preprocess pipeline with minimal hyperparameters on synthetic data.
"""
    # 1) Create a small synthetic dataset (10 cells × 5 genes) with a seeded
    #    RNG so the test is reproducible
    rng = np.random.default_rng(0)
    X = rng.random((10, 5))  # or integer counts if your pipeline expects that
# 2) Construct an AnnData in memory
adata = ad.AnnData(
X=X,
obs=pd.DataFrame(index=[f"cell_{i}" for i in range(X.shape[0])]),
var=pd.DataFrame(index=[f"gene_{j}" for j in range(X.shape[1])]),
)
    # 3) Set .raw to a snapshot of the same data, as many pipelines do
    #    (this copies the current X, var, and obs into .raw)
    adata.raw = adata
    # By default .raw.X here is a dense numpy array, not a sparse matrix
    assert isinstance(adata.raw.X, np.ndarray)
    # 4) If the code wants .raw.X as a CSR matrix, use the supported route:
    #    materialize .raw as its own AnnData, sparsify X, and reassign
    raw_adata = adata.raw.to_adata()
    raw_adata.X = csr_matrix(raw_adata.X)
    adata.raw = raw_adata
    assert isinstance(adata.raw.X, csr_matrix)
    # 5) Write the .h5ad and the metadata.json to disk
    data_path = tmp_path / "before.h5ad"
adata.write_h5ad(data_path)
metadata = {"sample": "test_sample"}
metadata_path = tmp_path / "metadata.json"
with open(metadata_path, "w") as f:
json.dump(metadata, f, indent=4)
    # 6) Minimal hyperparameters
hyperparameters = {
"load_dir": str(tmp_path),
"save_dir": str(tmp_path),
"reference_id_only": False,
"remove_assays": [],
"min_gene_counts": 0,
"max_mitochondrial_prop": None,
"hvg_method": None,
"normalized_total": None,
"median_dict": None,
"median_column": "index",
"log1p": False,
"compute_medians": False,
}
    # 7) Run preprocess
returned = preprocess(str(data_path), str(metadata_path), hyperparameters)
assert returned is not None
    # 8) Check that the output file was created. compute_medians was False, so
    #    no *_medians.json is expected; instead check the final .h5ad. The
    #    pipeline writes to data_path.replace(load_dir, save_dir), and with
    #    load_dir == save_dir that is the original path, so "before.h5ad"
    #    should still exist (now preprocessed)
    assert os.path.exists(str(data_path))