|
|
|
|
|
import hashlib |
|
|
import json |
|
|
import os |
|
|
from pathlib import Path |
|
|
|
|
|
import anndata as ad |
|
|
import numpy as np |
|
|
import pandas as pd |
|
|
import pytest |
|
|
from scipy.sparse import csr_matrix |
|
|
|
|
|
from teddy.data_processing.preprocessing.preprocess import ( |
|
|
compute_and_save_medians, |
|
|
filter_cells_by_gene_counts, |
|
|
filter_cells_by_mitochondrial_fraction, |
|
|
filter_highly_variable_genes, |
|
|
initialize_processed_layer, |
|
|
log_transform_layer, |
|
|
normalize_data_inplace, |
|
|
preprocess, |
|
|
set_raw_if_necessary, |
|
|
update_metadata, |
|
|
) |
|
|
|
|
|
|
|
|
@pytest.fixture
def synthetic_anndata():
    """
    Build a tiny sparse AnnData (4 cells x 3 genes) shared by the unit tests.

    Counts in X (rows are cells, columns are geneA, geneB, geneC):

        cell1: 0 2 3
        cell2: 4 0 6
        cell3: 7 8 0
        cell4: 0 0 1
    """
    gene_table = pd.DataFrame(index=["geneA", "geneB", "geneC"])
    cell_table = pd.DataFrame(index=["cell1", "cell2", "cell3", "cell4"])
    counts = csr_matrix(
        np.array(
            [
                [0, 2, 3],
                [4, 0, 6],
                [7, 8, 0],
                [0, 0, 1],
            ],
            dtype=float,
        )
    )
    return ad.AnnData(X=counts, var=gene_table, obs=cell_table)
|
|
|
|
|
|
|
|
def test_set_raw_if_necessary(synthetic_anndata):
    """Check that set_raw_if_necessary returns an object whose .raw is set and non-empty."""
    outcome = set_raw_if_necessary(synthetic_anndata.copy())

    assert outcome is not None
    assert outcome.raw is not None
    # At least one stored raw count must be non-zero.
    nonzero_mask = (outcome.raw.X != 0).toarray()
    assert nonzero_mask.any()
|
|
|
|
|
|
|
|
def test_initialize_processed_layer(synthetic_anndata):
    """Check that initialize_processed_layer adds a non-empty "processed" layer."""
    adata = synthetic_anndata.copy()
    # Provide a raw snapshot up front; the function under test may read from it.
    adata.raw = adata.copy()
    assert "processed" not in adata.layers

    adata = initialize_processed_layer(adata)

    assert "processed" in adata.layers
    has_nonzero = (adata.layers["processed"] != 0).toarray().any()
    assert has_nonzero
|
|
|
|
|
|
|
|
def test_filter_cells_by_gene_counts(synthetic_anndata):
    """
    Check that filter_cells_by_gene_counts drops low-count cells.

    In the fixture, cell4 is the only cell with a single non-zero gene (and a
    total count of 1). With ``min_count=2`` it is expected to be the only
    cell removed.
    """
    adata = synthetic_anndata.copy()
    adata.raw = adata
    adata.layers["processed"] = adata.X.copy()

    filtered = filter_cells_by_gene_counts(adata, min_count=2)

    assert filtered.n_obs == 3
    assert "cell4" not in filtered.obs_names
|
|
|
|
|
|
|
|
def test_filter_cells_by_mitochondrial_fraction(synthetic_anndata):
    """
    Check that cells above the mitochondrial-proportion cutoff are removed.

    geneA is labelled "MT-GENE1" in ``var["feature_name"]``; presumably the
    function identifies mitochondrial genes by that prefix. The per-cell
    mitochondrial fractions are then:

        cell1: 0/5
        cell2: 4/10
        cell3: 7/15
        cell4: 0/1

    With a 0.25 cutoff, cell2 and cell3 should be dropped.
    """
    adata = synthetic_anndata.copy()
    adata.raw = adata
    adata.var["feature_name"] = ["MT-GENE1", "geneB", "geneC"]
    adata.layers["processed"] = adata.X.copy()

    kept = filter_cells_by_mitochondrial_fraction(adata, max_mito_prop=0.25)

    assert kept.n_obs == 2
    for dropped in ("cell2", "cell3"):
        assert dropped not in kept.obs_names
    for retained in ("cell1", "cell4"):
        assert retained in kept.obs_names
|
|
|
|
|
|
|
|
def test_normalize_data_inplace(synthetic_anndata):
    """
    Check that normalize_data_inplace rescales every row of the layer so it
    sums to the target total (1e4).
    """
    target_total = 1e4
    data = synthetic_anndata.copy()

    processed = data.X.copy()
    # Make sure the layer is CSR before normalizing (the fixture already is).
    if not isinstance(processed, csr_matrix):
        processed = csr_matrix(processed)
    data.layers["processed"] = processed

    normalize_data_inplace(data.layers["processed"], target_total)

    row_sums = np.asarray(data.layers["processed"].sum(axis=1)).ravel()
    assert row_sums.shape == (data.n_obs,)
    # pytest convention: wrap the *expected* value in approx.
    # A scalar approx is compared element-wise against the whole array.
    assert row_sums == pytest.approx(target_total, rel=1e-4)
|
|
|
|
|
|
|
|
def test_log_transform_layer(synthetic_anndata):
    """
    Check that log_transform_layer replaces the "processed" layer with
    log1p of its original values.
    """
    data = synthetic_anndata.copy()
    # Compute the expectation before the layer is transformed.
    expected = np.log1p(data.X.toarray())
    data.layers["processed"] = data.X.copy()

    log_transform_layer(data, "processed")

    layer = data.layers["processed"]
    actual = layer.toarray() if hasattr(layer, "toarray") else np.asarray(layer)

    # Spot-check one known entry: raw count 4 at (cell2, geneA).
    assert actual[1, 0] == pytest.approx(np.log1p(4), rel=1e-5)
    # Check the whole matrix; zeros map to log1p(0) == 0.
    np.testing.assert_allclose(actual, expected, rtol=1e-5)
|
|
|
|
|
|
|
|
def test_compute_and_save_medians(tmp_path, synthetic_anndata):
    """
    Check that compute_and_save_medians writes ``<name>_medians.json`` next to
    the input .h5ad, with one entry per gene.

    The file is expected to hold medians of non-zero entries. This test only
    verifies that the file exists and how many entries it has.
    """
    adata = synthetic_anndata.copy()
    adata.layers["processed"] = adata.X.copy()

    h5ad_file = str(tmp_path / "test.h5ad")
    adata.write_h5ad(h5ad_file)

    settings = dict(
        load_dir=str(tmp_path),
        save_dir=str(tmp_path),
        median_column="index",
    )
    compute_and_save_medians(adata, h5ad_file, settings)

    medians_file = Path(h5ad_file.replace(".h5ad", "_medians.json"))
    assert medians_file.exists()
    medians = json.loads(medians_file.read_text())
    assert len(medians) == adata.n_vars
|
|
|
|
|
|
|
|
def test_update_metadata(synthetic_anndata):
    """Check that update_metadata records the cell count and the processing arguments."""
    adata = synthetic_anndata.copy()
    adata.obs["some_col"] = [1, 2, 3, 4]

    updated = update_metadata({}, adata, {"some": "arg"})

    assert updated.get("cell_count") == 4
    # Accept either spelling of the processing-arguments key.
    assert any(key in updated for key in ("processings_args", "processing_args"))
|
|
|
|
|
|
|
|
def test_preprocess_pipeline_end_to_end(tmp_path):
    """
    Run the whole preprocess pipeline on small synthetic data.

    The hyperparameters are minimal and appear to disable the optional
    filtering, HVG, normalization, log1p and median steps. The test checks
    that the pipeline completes and returns a result.
    """
    # Seeded generator so the test is reproducible across runs.
    rng = np.random.default_rng(0)
    X = rng.random((10, 5))

    adata = ad.AnnData(
        X=X,
        obs=pd.DataFrame(index=[f"cell_{i}" for i in range(X.shape[0])]),
        var=pd.DataFrame(index=[f"gene_{j}" for j in range(X.shape[1])]),
    )

    # Store .raw as a CSR matrix while .X stays dense.
    # NOTE(review): presumably preprocess expects sparse raw counts; confirm.
    raw_adata = adata.copy()
    raw_adata.X = csr_matrix(raw_adata.X)
    adata.raw = raw_adata

    data_path = tmp_path / "before.h5ad"
    adata.write_h5ad(data_path)

    metadata_path = tmp_path / "metadata.json"
    with open(metadata_path, "w") as f:
        json.dump({"sample": "test_sample"}, f, indent=4)

    hyperparameters = {
        "load_dir": str(tmp_path),
        "save_dir": str(tmp_path),
        "reference_id_only": False,
        "remove_assays": [],
        "min_gene_counts": 0,
        "max_mitochondrial_prop": None,
        "hvg_method": None,
        "normalized_total": None,
        "median_dict": None,
        "median_column": "index",
        "log1p": False,
        "compute_medians": False,
    }

    returned = preprocess(str(data_path), str(metadata_path), hyperparameters)
    assert returned is not None

    # load_dir == save_dir, so the processed output presumably lands at the
    # input path. At minimum, the file must still exist after the run.
    assert data_path.exists()
|
|
|
|
|
|
|
|
|