Cnnamon Documentation

A modular framework for DNA sequence-based Convolutional Neural Network (CNN) model training and explainability.

What is Cnnamon?

Cnnamon is a Python framework for building, training, and interpreting 1D CNNs for DNA sequence analysis. Its key features are listed below.

Key Features

🔬 Filter Visualization

Extract and visualize sequence motifs learned by convolutional filters using multiple methods (softmax, top-activating, positive-activating, significant-activating).

📊 Filter Importance

Identify the most important filters in your model through perturbation analysis.

🔗 Filter Clustering

Group functionally similar filters based on activation profiles.

🎯 Class Enrichment

Discover which filters are enriched for specific output classes.

Installation

Requirements

Install from Source

git clone https://github.com/yourusername/cnnamon.git
cd cnnamon
pip install -e .

Dependencies

pip install tensorflow numpy pandas matplotlib seaborn scikit-learn logomaker pycirclize scipy statsmodels joblib tqdm

Quick Start

Basic Workflow

import cnnamon
from cnnamon.util import PrepareData, KerasModelBuilder
from cnnamon.CNN1D import FilterVisualize, FilterImportance

# 1. Prepare data
preparer = PrepareData(
    intervalfile="peaks.bed",
    genomefasta="genome.fa",
    outdir="output/"
)
train, test, val = preparer.run()

# 2. Build and train model
model = KerasModelBuilder.from_json("model_config.json")
model.train(train['x'], train['y'], val['x'], val['y'])
model.save("my_model.keras")

# 3. Visualize learned motifs
motifs = FilterVisualize.top_activating(model.model, test['x'])
motifs.to_motifs(savefig="motifs.png")

# 4. Analyze filter importance
importance = FilterImportance(model.model, test, n_cores=4)
importance.boxplot(savefig="importance.png")

PrepareData

The PrepareData class handles genomic sequence data preparation for CNN training, including sequence extraction, one-hot encoding, and flexible dataset splitting.

Initialization

PrepareData(intervalfile, genomefasta, outdir, **kwargs)

Initialize data preparer with genomic intervals and reference genome.

Core Parameters:

intervalfile (str): Path to BED file containing genomic intervals
genomefasta (str): Path to reference genome FASTA file
outdir (str): Output directory for processed data

Configuration Options (**kwargs):

split_segmentation (str): Splitting strategy. Options: 'random', 'chromosome', 'custom'. (Default: 'random')
ratios (list): Train/Test/Validation split ratios [train, test, val]. Used for 'random' and 'chromosome' modes. (Default: [0.6, 0.2, 0.2])
augment_RC (bool): If True, augments the dataset with Reverse Complement sequences. (Default: False)
seed (int): Random seed for reproducibility. (Optional)
save_splits (str): Set to "1" to save .npy files to disk. (Default: "0")

Custom Split Parameters (if split_segmentation='custom'):

train_chr_list (list): List of chromosomes for training (e.g., ['chr1', 'chr2'])
val_chr_list (list): List of chromosomes for validation
test_chr_list (list): List of chromosomes for testing

Methods

run() → Tuple[Dict, Dict, Dict]

Execute full data preparation pipeline.

Returns:

(train_dict, test_dict, validation_dict) where each contains:

  • 'x': One-hot encoded sequences (N × L × 4)
  • 'y': Label array
  • 'info': Interval information (chr, start, end)

load_splits_from_disk(directory_path) → Tuple[Dict, Dict, Dict]
STATIC

Load previously saved data splits (if save_splits="1" was used).

Parameters:

directory_path (str): Path to saved split directory
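
To make the N × L × 4 layout of 'x' concrete, here is a minimal sketch of one-hot encoding in plain Python. The helper name, the A, C, G, T channel order, and the all-zero row for unknown bases such as N are assumptions for illustration; the conventions PrepareData actually uses are not stated here.

```python
# Hypothetical helper, not part of the cnnamon API.
# Assumed channel order: A, C, G, T. Unknown bases (e.g. N) become all zeros.
CHANNELS = "ACGT"

def one_hot(seq):
    """Encode one DNA string as a list of L rows with 4 channels each."""
    rows = []
    for base in seq.upper():
        row = [0, 0, 0, 0]
        idx = CHANNELS.find(base)
        if idx >= 0:
            row[idx] = 1
        rows.append(row)
    return rows

seqs = ["ACGT", "ggNa"]
x = [one_hot(s) for s in seqs]  # nested list with shape N x L x 4
shape = (len(x), len(x[0]), len(x[0][0]))
print(shape)
print(x[1])
```

All sequences must have the same length L for the result to be a regular array; in practice the nested list would be a NumPy array.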

Splitting Strategies

Cnnamon supports three robust strategies for dividing your genomic data:

🔀 Random Split ('random')

Pools all intervals from the genome and splits them randomly based on ratios. Best for non-overlapping, independent peaks.

🧬 Chromosome Split ('chromosome')

Randomly selects entire chromosomes to hold out for testing/validation based on ratios. Prevents data leakage between similar sequences on the same chromosome.

🛠 Custom Split ('custom')

Manually assign specific chromosomes to each set using *_chr_list parameters. Ideal for benchmarking against standard splits (e.g., "chr8 for test").
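
The idea behind the chromosome strategy can be sketched in a few lines. This is an illustration only: the function name is hypothetical, and it assumes the ratios apply to the number of chromosomes rather than the number of intervals, which the description above does not specify.

```python
import random

def chromosome_split(intervals, ratios=(0.6, 0.2, 0.2), seed=None):
    """Assign whole chromosomes to train/test/val so no chromosome is shared.

    intervals: list of (chrom, start, end) tuples.
    ratios: fractions of chromosomes for train, test, val (an assumption).
    """
    chroms = sorted({chrom for chrom, _, _ in intervals})
    random.Random(seed).shuffle(chroms)
    n_train = round(ratios[0] * len(chroms))
    n_test = round(ratios[1] * len(chroms))
    groups = {
        "train": set(chroms[:n_train]),
        "test": set(chroms[n_train:n_train + n_test]),
        "val": set(chroms[n_train + n_test:]),
    }
    # Every interval follows its chromosome into exactly one split.
    return {
        name: [iv for iv in intervals if iv[0] in members]
        for name, members in groups.items()
    }

intervals = [(f"chr{i}", 100 * j, 100 * j + 50) for i in range(1, 11) for j in range(3)]
splits = chromosome_split(intervals, ratios=(0.6, 0.2, 0.2), seed=42)
for name, ivs in splits.items():
    print(name, sorted({iv[0] for iv in ivs}))
```

Because chromosomes differ in how many intervals they carry, the resulting interval counts only approximate the requested ratios.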

Examples

1. Random Split with Augmentation

preparer = PrepareData(
    intervalfile="peaks.bed", genomefasta="hg38.fa", outdir="out_rnd/",
    split_segmentation="random",
    ratios=[0.8, 0.1, 0.1],
    augment_RC=True,    # Double data with Reverse Complements
    seed=42
)
train, test, val = preparer.run()
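
For intuition about what augment_RC adds, the sketch below builds reverse complements on plain strings. It is not Cnnamon code: the helper names are hypothetical, and appending each reverse complement with the label of its source sequence is an assumption about how the augmentation behaves.

```python
# Hypothetical illustration of reverse-complement augmentation; not cnnamon code.
COMPLEMENT = str.maketrans("ACGTNacgtn", "TGCANtgcan")

def reverse_complement(seq):
    """Complement every base, then reverse the string."""
    return seq.translate(COMPLEMENT)[::-1]

def augment_rc(seqs, labels):
    """Append the reverse complement of each sequence with the same label (assumed)."""
    aug_seqs = list(seqs) + [reverse_complement(s) for s in seqs]
    aug_labels = list(labels) + list(labels)
    return aug_seqs, aug_labels

aug_seqs, aug_labels = augment_rc(["AACG", "TTNA"], [1, 0])
print(aug_seqs)
print(aug_labels)
```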

2. Chromosome Holdout

preparer = PrepareData(
    intervalfile="peaks.bed", genomefasta="hg38.fa", outdir="out_chr/",
    split_segmentation="chromosome",
    ratios=[0.7, 0.15, 0.15],
    seed=123
)

3. Custom Chromosome Sets

preparer = PrepareData(
    intervalfile="peaks.bed", genomefasta="hg38.fa", outdir="out_custom/",
    split_segmentation="custom",
    train_chr_list=["chr1", "chr2", "chr3"],
    val_chr_list=["chr4"],
    test_chr_list=["chr5"]
)

KerasModelBuilder

A flexible wrapper for building, training, and evaluating Keras models directly from JSON configurations. It supports the full Keras Functional/Sequential API capabilities by dynamically loading layers, optimizers, and callbacks.

1. Configuration Structure (JSON)

The configuration file is divided into four main sections: model, compile, training_params, and callbacks.

A. Model Architecture

Define layers using the standard Keras class_name and config dictionary. You can use any layer available in tf.keras.layers.

"model": {
    "layers": [
        {
            "class_name": "Conv1D",
            "config": {
                "filters": 64,
                "kernel_size": 8,
                "activation": "relu",
                "input_shape": [400, 4]
            }
        },
        { "class_name": "MaxPooling1D", "config": { "pool_size": 2 } },
        { "class_name": "Flatten", "config": {} },
        { "class_name": "Dense", "config": { "units": 1, "activation": "sigmoid" } }
    ]
}
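
As a sanity check on this architecture, the sketch below works out the layer output sizes and parameter counts by hand. It assumes the Keras defaults that the config does not override: 'valid' padding and stride 1 for Conv1D, and a MaxPooling1D stride equal to pool_size. The helper names are illustrative.

```python
def conv1d_out_len(length, kernel_size, stride=1):
    """Output length of a 'valid' 1D convolution."""
    return (length - kernel_size) // stride + 1

def pool1d_out_len(length, pool_size, stride=None):
    """Output length of 'valid' 1D pooling; stride defaults to pool_size."""
    stride = pool_size if stride is None else stride
    return (length - pool_size) // stride + 1

seq_len, channels, filters, kernel = 400, 4, 64, 8

conv_len = conv1d_out_len(seq_len, kernel)             # positions after Conv1D
pool_len = pool1d_out_len(conv_len, 2)                 # positions after MaxPooling1D
flat = pool_len * filters                              # features after Flatten
conv_params = kernel * channels * filters + filters   # weights + biases
dense_params = flat * 1 + 1                            # weights + bias for 1 unit

print(conv_len, pool_len, flat, conv_params, dense_params)
```

Most of the parameters sit in the final Dense layer, which is worth keeping in mind when changing the input length or the pooling size.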

B. Compilation

Specifies the optimizer, loss function, and metrics. Optimizers can be simple strings or detailed objects.

"compile": {
    "optimizer": {
        "class_name": "Adam",
        "config": { "learning_rate": 0.0001 }
    },
    "loss": "binary_crossentropy",
    "metrics": ["accuracy"]
}

C. Training Parameters

Arguments passed directly to the model.fit() method.

"training_params": {
    "epochs": 50,
    "batch_size": 32,
    "verbose": 1
}

D. Callbacks

Define Keras callbacks. Keys match the callback class names in tf.keras.callbacks.

"callbacks": {
    "EarlyStopping": {
        "monitor": "val_loss",
        "patience": 5,
        "restore_best_weights": true
    },
    "CSVLogger": {
        "filename": "logs/training_log.csv"
    }
}
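
Putting the four fragments together, a complete configuration might look like the sketch below. It assumes the four sections are top-level keys of a single JSON object, which the description implies but does not show, and the presence check is only an illustration, not a documented Cnnamon validation step.

```python
import json

# Assumed layout: the four documented sections as top-level keys.
config = {
    "model": {
        "layers": [
            {"class_name": "Conv1D",
             "config": {"filters": 64, "kernel_size": 8,
                        "activation": "relu", "input_shape": [400, 4]}},
            {"class_name": "MaxPooling1D", "config": {"pool_size": 2}},
            {"class_name": "Flatten", "config": {}},
            {"class_name": "Dense", "config": {"units": 1, "activation": "sigmoid"}},
        ]
    },
    "compile": {
        "optimizer": {"class_name": "Adam", "config": {"learning_rate": 0.0001}},
        "loss": "binary_crossentropy",
        "metrics": ["accuracy"],
    },
    "training_params": {"epochs": 50, "batch_size": 32, "verbose": 1},
    "callbacks": {
        "EarlyStopping": {"monitor": "val_loss", "patience": 5,
                          "restore_best_weights": True},
        "CSVLogger": {"filename": "logs/training_log.csv"},
    },
}

DOCUMENTED_SECTIONS = ("model", "compile", "training_params", "callbacks")
missing = [key for key in DOCUMENTED_SECTIONS if key not in config]
if missing:
    raise ValueError(f"config is missing sections: {missing}")

text = json.dumps(config, indent=4)  # Python True is written as JSON true
print(text[:80])
```

The serialized text can be written to a file such as config.json, or the dictionary itself can be handed to KerasModelBuilder.from_json, which is documented to accept either a path or a dictionary.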

2. Core Methods

KerasModelBuilder.from_json(json_path_or_dict)
CLASS METHOD

Initialize and build the model directly from a JSON file path or a dictionary.

train(x_train, y_train, x_val=None, y_val=None)

Train the model using the parameters and callbacks defined in the JSON configuration.

evaluate(x_test, y_test)

Run standard Keras evaluation (returns loss and metrics).

plot_history(title="Training History", savefig=None)

Plot training vs. validation metrics (loss, accuracy, etc.) over epochs.

Parameters:

title (str): Main title for the plot.
savefig (str): Path to save the plot.

save(path) / load(path)

Save the trained model to a .keras file or load a pre-trained one.

3. Advanced Evaluation (model.eval)

The builder includes a helper eval object for generating publication-ready plots and metrics.

model.eval.cm(x_test, y_test, class_names=None, title="Confusion Matrix", ...)

Plot a confusion matrix with row-normalized proportions (recall) and raw counts.

Parameters:

class_names (list): List of label names for the axes (e.g. ["Negative", "Positive"]).
title (str): Custom title for the plot.
xlabel (str): Label for the x-axis.
ylabel (str): Label for the y-axis.
savefig (str): Path to save the plot image.

model.eval.roc(x_test, y_test, class_names=None, title="ROC Curve", ...)

Plot the Receiver Operating Characteristic (ROC) curve with AUC score (Binary only).

Parameters:

class_names (list): List of label names.
title (str): Custom title for the plot.
savefig (str): Path to save the plot.

model.eval.auc(x_test, y_test)

Calculate and print the AUC score (supports One-vs-Rest for multiclass).
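
To clarify what these plots report, the sketch below computes a row-normalized confusion matrix and a binary AUC by hand. The function names are hypothetical and this is not how model.eval is implemented; it only shows the quantities involved. Row normalization makes each diagonal entry the recall of that class, and AUC is computed here as the fraction of positive/negative pairs ranked correctly, with ties counted as one half.

```python
def confusion_counts(y_true, y_pred, n_classes):
    """Raw counts: rows are true classes, columns are predicted classes."""
    counts = [[0] * n_classes for _ in range(n_classes)]
    for t, p in zip(y_true, y_pred):
        counts[t][p] += 1
    return counts

def row_normalize(counts):
    """Divide each row by its total, so the diagonal holds per-class recall."""
    normalized = []
    for row in counts:
        total = sum(row)
        normalized.append([c / total if total else 0.0 for c in row])
    return normalized

def binary_auc(y_true, scores):
    """Probability that a random positive scores higher than a random negative."""
    pos = [s for y, s in zip(y_true, scores) if y == 1]
    neg = [s for y, s in zip(y_true, scores) if y == 0]
    wins = sum((p > n) + 0.5 * (p == n) for p in pos for n in neg)
    return wins / (len(pos) * len(neg))

y_true = [0, 0, 0, 0, 1, 1, 1, 1, 1, 1]
y_pred = [0, 0, 0, 1, 1, 1, 1, 0, 0, 1]
counts = confusion_counts(y_true, y_pred, n_classes=2)
props = row_normalize(counts)
print(counts)
print(props)

auc = binary_auc([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8])
print(auc)
```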

4. Full Example

from cnnamon.util import KerasModelBuilder

# 1. Initialize and Build
builder = KerasModelBuilder.from_json("config.json")

# 2. Train
builder.train(train_x, train_y, val_x, val_y)

# 3. Plot History
builder.plot_history(savefig="plots/history.png")

# 4. Advanced Evaluation
# Confusion Matrix with custom labels
builder.eval.cm(
    test_x, test_y, 
    class_names=["Non-Promoter", "Promoter"],
    title="Promoter Prediction Performance",
    savefig="plots/cm.png"
)

# ROC Curve
builder.eval.roc(
    test_x, test_y,
    title="Model ROC",
    savefig="plots/roc.png"
)

Filter Visualization

The FilterVisualize class is the primary tool for interpreting what the convolutional filters in your model have learned. It provides several methods that turn filter weights or filter activations into human-readable sequence motifs (Position Frequency Matrices).

1. Visualization Methods

FilterVisualize.softmax(model, background=None) → MotifSet
STATIC

The "Quick & Dirty" Method. Applies a Softmax function directly to the filter weights. This gives you a theoretical idea of what the filter wants to see, but doesn't tell you if it actually fires on real data.

Parameters:

model (keras.Model): Trained model.
background (dict, optional): Background frequencies (e.g. {'A':0.25, ...}). Default is uniform.
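
The core of the softmax approach can be sketched as follows, assuming a filter is stored as kernel_size rows of four weights in A, C, G, T order. The function names are hypothetical, and the optional background correction is left out because its exact use is not described here.

```python
import math

def softmax(row):
    """Numerically stable softmax over one position's four weights."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

def filter_to_ppm(weights):
    """Turn a (kernel_size x 4) weight matrix into per-position probabilities."""
    return [softmax(row) for row in weights]

# Toy filter with three positions; columns are A, C, G, T (assumed order).
weights = [
    [2.0, 0.0, 0.0, 0.0],    # prefers A
    [0.0, 0.0, 0.0, 0.0],    # no preference
    [-1.0, 1.0, -1.0, 3.0],  # prefers T
]
ppm = filter_to_ppm(weights)
for row in ppm:
    print([round(p, 3) for p in row])
```

A position with equal weights comes out as 0.25 for every base, so it carries no information in the resulting logo.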

FilterVisualize.top_activating(model, x_data, percentile=90.0, include_all_positive=False, n_cores=1) → MotifSet
STATIC RECOMMENDED

The "Data-Driven" Method. Feeds your actual sequences through the model and extracts the subsequences that produce the highest activation scores. Because the motifs are built from real sequences, they show what each filter actually responds to in your data.

Parameters:

model (keras.Model): Trained model.
x_data (np.ndarray): Input sequences (one-hot encoded).
percentile (float): The percentile cutoff for "top" activations (default: 90.0). Higher means stricter/cleaner motifs.
include_all_positive (bool): If True, ignores percentile and builds motifs from all subsequences with activation > 0. Useful for rare motifs.
n_cores (int): Number of parallel cores to use.
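
A minimal single-filter sketch of the data-driven idea is shown below. It works on plain strings that contain only A, C, G, and T, uses a nearest-rank percentile over all window activations, and counts bases in the windows that pass the cutoff. All names are hypothetical, and the exact way Cnnamon pools activations before applying the percentile is an assumption.

```python
import math

BASES = "ACGT"

def scan(seq, weights):
    """ReLU activation of one filter at every valid offset, with its window."""
    k = len(weights)
    hits = []
    for start in range(len(seq) - k + 1):
        score = sum(weights[i][BASES.index(seq[start + i])] for i in range(k))
        hits.append((max(score, 0.0), seq[start:start + k]))
    return hits

def percentile(values, q):
    """Nearest-rank percentile, q in [0, 100]."""
    ordered = sorted(values)
    rank = max(1, math.ceil(q / 100 * len(ordered)))
    return ordered[rank - 1]

def top_activating_pfm(seqs, weights, q=90.0, include_all_positive=False):
    """Count bases per position over the windows that activate the filter most."""
    hits = [hit for seq in seqs for hit in scan(seq, weights)]
    cutoff = 0.0 if include_all_positive else percentile([a for a, _ in hits], q)
    top = [sub for a, sub in hits if a > 0 and a >= cutoff]
    counts = [[0] * 4 for _ in range(len(weights))]
    for sub in top:
        for pos, base in enumerate(sub):
            counts[pos][BASES.index(base)] += 1
    return counts, top

# Toy filter that rewards the pattern T-A-T (+1 per match, -1 per mismatch).
weights = [[-1, -1, -1, 1], [1, -1, -1, -1], [-1, -1, -1, 1]]
seqs = ["GGTATGG", "CCTATCC", "ACGTACG"]

strict_counts, strict_top = top_activating_pfm(seqs, weights, q=90.0)
loose_counts, loose_top = top_activating_pfm(seqs, weights, include_all_positive=True)
print(strict_top, strict_counts)
print(loose_top, loose_counts)
```

With the percentile cutoff only the two perfect matches survive; with include_all_positive=True the weaker partial match is counted as well, which broadens the motif.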

FilterVisualize.pos_activating(model, n_cores=1, background=None) → MotifSet
STATIC

The "Consensus" Method. Builds a motif by looking only at the positive weights in the filter kernel. It constructs a "consensus" sequence that would theoretically maximize activation.

FilterVisualize.significant_activating(model, x_data, n_perturbations=100, p_value_cutoff=0.01, n_cores=1) → MotifSet
STATIC RIGOROUS

The "Statistical" Method. Performs a perturbation test to compare real filter activations against a "null" model of random noise. It only builds motifs from subsequences that are statistically significant (p < cutoff). Warning: Computationally expensive.

Parameters:

n_perturbations (int): Number of random filters to generate per real filter (default: 100).
p_value_cutoff (float): Significance threshold (default: 0.01).
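
One common way to turn such a comparison into a p-value is the empirical estimate sketched below. Whether Cnnamon uses this exact formula is an assumption; the function name is hypothetical. The sketch also shows how n_perturbations bounds the smallest p-value that can be reached.

```python
def empirical_p_value(observed, null_values):
    """Share of null activations at least as large as the observed one.

    The +1 terms keep the estimate above zero for a finite null sample.
    """
    exceed = sum(1 for v in null_values if v >= observed)
    return (exceed + 1) / (len(null_values) + 1)

# Stand-in null distribution: 100 activations from random filters (toy values).
null_activations = list(range(100))

p_borderline = empirical_p_value(99, null_activations)   # ties the largest null value
p_extreme = empirical_p_value(100, null_activations)     # exceeds every null value
print(p_borderline, p_extreme)
```

Under this formula, 100 null samples give a minimum p-value of 1/101, so with a cutoff of 0.01 only activations that exceed every null value pass. Raising n_perturbations gives finer resolution at a higher compute cost.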

2. The MotifSet Object

All visualization methods return a MotifSet object, which behaves like a dictionary of DataFrames and adds the plotting and export methods below.

motifs.to_motifs(savefig=None, figsize=None)

Plots all filters in a grid layout as Sequence Logos (Information Content bits).
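
For reference, the standard information-content scaling behind sequence logos can be sketched as follows. It assumes a uniform background and no small-sample correction, and the function names are hypothetical; the scaling options to_motifs actually applies are not documented here.

```python
import math

def information_content(ppm_row):
    """Bits at one position: 2 minus the Shannon entropy of the four bases."""
    entropy = -sum(p * math.log2(p) for p in ppm_row if p > 0)
    return 2.0 - entropy

def logo_heights(ppm):
    """Letter heights: each base's probability scaled by the position's bits."""
    return [[p * information_content(row) for p in row] for row in ppm]

ppm = [
    [1.0, 0.0, 0.0, 0.0],      # fully conserved A
    [0.5, 0.5, 0.0, 0.0],      # A or C
    [0.25, 0.25, 0.25, 0.25],  # no preference
]
bits = [information_content(row) for row in ppm]
heights = logo_heights(ppm)
print(bits)
print(heights)
```

A fully conserved position reaches the DNA maximum of 2 bits, while a position with no preference contributes nothing to the logo.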

motifs.to_meme(outfile)

Exports all motifs to a MEME-formatted text file. Compatible with tools like TOMTOM (for matching against JASPAR/HOCOMOCO databases) or FIMO.
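
For readers unfamiliar with the format, the sketch below writes motifs in the minimal MEME motif layout. It is an illustration, not the to_meme implementation: the function name, the uniform background, and the nsites value are placeholders chosen for this example.

```python
def motifs_to_meme_text(motifs, nsites=20):
    """Render {name: rows of A, C, G, T probabilities} as minimal MEME text."""
    lines = [
        "MEME version 4",
        "",
        "ALPHABET= ACGT",
        "",
        "strands: + -",
        "",
        "Background letter frequencies",
        "A 0.25 C 0.25 G 0.25 T 0.25",
        "",
    ]
    for name, ppm in motifs.items():
        lines.append(f"MOTIF {name}")
        lines.append(
            f"letter-probability matrix: alength= 4 w= {len(ppm)} nsites= {nsites} E= 0"
        )
        for row in ppm:
            lines.append(" ".join(f"{p:.6f}" for p in row))
        lines.append("")
    return "\n".join(lines)

toy_motifs = {
    "filter_0": [[1.0, 0.0, 0.0, 0.0], [0.5, 0.5, 0.0, 0.0]],
    "filter_1": [[0.25, 0.25, 0.25, 0.25]],
}
meme_text = motifs_to_meme_text(toy_motifs)
print(meme_text)
```

Each row of the matrix must sum to 1, so count matrices need to be normalized before they are written this way.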

motifs.to_svgs(outdir)

Saves each individual filter as a high-quality SVG vector graphic in the specified directory. Great for publications or clustering analysis.

3. Examples

Example 1: Standard Discovery

from cnnamon.CNN1D import FilterVisualize

# 1. Generate motifs from Test Set (Top 10% of activators)
motifs = FilterVisualize.top_activating(
    model, 
    test['x'], 
    percentile=90.0, 
    n_cores=4
)

# 2. Visualize in a grid
motifs.to_motifs(savefig="figures/all_motifs.png")

# 3. Export for external tools
motifs.to_meme("results/learned_motifs.meme")

Example 2: Analyzing Rare Motifs

If you suspect a filter detects a very rare signal (e.g. only 50 instances in 10,000 sequences), a 90th percentile cutoff might be too strict. Use include_all_positive=True to capture every instance.

rare_motifs = FilterVisualize.top_activating(
    model, 
    test['x'], 
    include_all_positive=True,  # Capture ALL positive activations
    n_cores=4
)
rare_motifs.to_motifs(savefig="figures/rare_motifs.png")

Example 3: High-Rigor Verification

Use the statistical method to filter out noise "motifs" that are just random GC-rich patches.

rigorous_motifs = FilterVisualize.significant_activating(
    model, 
    test['x'], 
    n_perturbations=200, 
    p_value_cutoff=0.01,
    n_cores=8
)
# Only statistically significant patterns will appear here
rigorous_motifs.to_motifs()

Filter Importance

Identify the most important filters through perturbation analysis.

Initialization

FilterImportance(model, testset, n_iterations=10, method='mean', n_cores=1, batch_size=32)

Runs the full perturbation experiment and ranks filters by importance.

Parameters:

model (keras.Model): Trained model
testset (Dict): Data dict with 'x' and 'y' keys
n_iterations (int): Perturbation rounds per filter
method (str): 'mean' or 'median' for aggregation
n_cores (int): Number of parallel CPU cores
batch_size (int): Evaluation batch size (increase for GPU speedup)

How It Works

  1. Calculate baseline model loss on test data
  2. For each filter:
    • Perturb its weights with Gaussian noise (preserving mean/std)
    • Evaluate model loss with perturbed filter
    • Repeat n_iterations times
  3. Rank filters by average loss increase
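
The steps above can be sketched with a toy model in plain Python. Everything here is a stand-in: the "model" is a weighted sum of filter responses, the loss is mean squared error, and the noise is drawn from a Gaussian with each filter's own mean and standard deviation, which is one reading of "preserving mean/std". The function names are hypothetical.

```python
import random
import statistics

def perturb(weights, rng):
    """Draw replacement weights from a Gaussian with the filter's own mean and std."""
    mu = statistics.fmean(weights)
    sigma = statistics.pstdev(weights)
    return [rng.gauss(mu, sigma) for _ in weights]

def filter_importance(filters, loss_fn, n_iterations=10, seed=0):
    """Rank filters by the mean loss increase caused by perturbing each one."""
    rng = random.Random(seed)
    baseline = loss_fn(filters)
    scores = {}
    for idx in range(len(filters)):
        increases = []
        for _ in range(n_iterations):
            trial = list(filters)
            trial[idx] = perturb(filters[idx], rng)
            increases.append(loss_fn(trial) - baseline)
        scores[idx] = statistics.fmean(increases)
    ranking = sorted(scores, key=scores.get, reverse=True)
    return ranking, scores

# Toy stand-in for a trained model: a weighted sum of filter responses.
FILTERS = [[1.0, -1.0, 2.0], [0.3, 0.8, -0.5], [0.5, -0.5, 1.0]]
GAINS = [5.0, 0.0, 1.0]  # filter 1 has no influence on the output
X = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]

def predict(filters, x):
    return sum(g * sum(w * v for w, v in zip(f, x)) for g, f in zip(GAINS, filters))

TARGETS = [predict(FILTERS, x) for x in X]

def mse_loss(filters):
    return statistics.fmean([(predict(filters, x) - t) ** 2 for x, t in zip(X, TARGETS)])

ranking, scores = filter_importance(FILTERS, mse_loss, n_iterations=20, seed=1)
print(ranking)
print(scores)
```

The filter whose output is ignored by the rest of the model shows no loss increase and lands at the bottom of the ranking, which is the behaviour the analysis is designed to reveal.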

Visualization Methods

boxplot(savefig=None, **kwargs)

Plot distribution of perturbed losses as boxplots, ordered by importance.

violin(savefig=None, **kwargs)

Plot distribution of perturbed losses as violin plots, ordered by importance.

Example

from cnnamon.CNN1D import FilterImportance

importance = FilterImportance(
    model,
    testset=test,
    n_iterations=20,
    method='mean',
    n_cores=8,
    batch_size=256
)

# Visualize results
importance.boxplot(savefig="importance_boxplot.png")
importance.violin(savefig="importance_violin.png")

# Access ranking
print("Most important filters:", importance.filter_importance_ranking[:5])

⚠️ Performance Note: This analysis can be slow. Increase batch_size (e.g., 256-512) for GPU acceleration, and use n_cores for CPU parallelization.

Filter Clustering

Group functionally similar filters based on their activation profiles.

Initialization

FilterClustering(model, testset, target_layer=None, linkage_method='ward')

Performs hierarchical clustering of filters.

Parameters:

model (keras.Model): Trained model
testset (Dict): Data dict with 'x' and 'y' keys
target_layer (str, optional): Conv1D layer name (uses first if None)
linkage_method (str): Clustering linkage ('ward', 'average', 'complete')

Visualization Methods

plot_heatmap(savefig=None, **kwargs)

Plot clustered heatmap showing filter activation patterns.

plot_dendrogram(savefig=None, **kwargs)

Plot hierarchical clustering dendrogram.

plot_circlize(savefig=None, **kwargs)

Plot the clustering dendrogram as a circular tree of filter relationships.

get_clusters(n_clusters) → Dict[str, List[str]]

Extract filter groups by cutting dendrogram at specified number of clusters.

Returns:

Dictionary mapping cluster names to filter lists.
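
To illustrate the kind of result get_clusters returns, here is a naive agglomerative clustering sketch in plain Python. It uses average linkage on Euclidean distance for brevity, whereas FilterClustering defaults to 'ward'; the function names, the cluster naming scheme, and the toy activation profiles are all assumptions.

```python
def euclidean(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b)) ** 0.5

def cluster_filters(profiles, n_clusters):
    """Merge the two closest groups until n_clusters groups remain."""
    clusters = [[name] for name in profiles]

    def linkage(c1, c2):
        # Average linkage: mean distance over all cross-group pairs.
        pairs = [(a, b) for a in c1 for b in c2]
        return sum(euclidean(profiles[a], profiles[b]) for a, b in pairs) / len(pairs)

    while len(clusters) > n_clusters:
        i, j = min(
            ((i, j) for i in range(len(clusters)) for j in range(i + 1, len(clusters))),
            key=lambda ij: linkage(clusters[ij[0]], clusters[ij[1]]),
        )
        clusters[i] = clusters[i] + clusters[j]
        del clusters[j]
    return {f"cluster_{k + 1}": sorted(c) for k, c in enumerate(clusters)}

# Toy activation profiles: one value per test sequence for each filter.
profiles = {
    "filter_0": [1.0, 0.9, 0.0, 0.1],
    "filter_1": [0.9, 1.0, 0.1, 0.0],
    "filter_2": [0.0, 0.1, 1.0, 0.9],
    "filter_3": [0.1, 0.0, 0.9, 1.0],
}
groups = cluster_filters(profiles, n_clusters=2)
print(groups)
```

Filters that respond to the same sequences end up in the same group, which is the signal to look for when checking for redundant filters.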

Example

from cnnamon.CNN1D import FilterClustering

clustering = FilterClustering(
    model,
    testset=test,
    target_layer='conv1d_0',
    linkage_method='ward'
)

# Visualize clustering
clustering.plot_heatmap(savefig="cluster_heatmap.png")
clustering.plot_dendrogram(savefig="dendrogram.png")
clustering.plot_circlize(savefig="circular_tree.png")

# Get cluster assignments
clusters = clustering.get_clusters(n_clusters=5)
for cluster_name, filters in clusters.items():
    print(f"{cluster_name}: {filters}")

💡 Use Case: Clustering helps identify redundant filters or functionally related motif groups. Combine with FilterVisualize to see what motifs each cluster represents.

Filter Enrichment

Discover which filters are enriched for specific output classes through statistical testing.

Initialization

FilterEnrichment(model, testset, target_layer=None, class_names=None, method='mann-whitney', n_cores=1)

Performs class-specific enrichment analysis with FDR correction.

Parameters:

model (keras.Model): Trained model
testset (Dict): Data dict with 'x' and 'y' keys
target_layer (str, optional): Conv1D layer name
class_names (List[str], optional): Custom class labels
method (str): Statistical test to use; 'mann-whitney' (Mann-Whitney U) is the only documented option
n_cores (int): Parallel processing cores

How It Works

  1. Extract filter activations for all sequences
  2. For each filter and each class:
    • Split activations into class-positive and class-negative groups
    • Calculate log2 fold change (enrichment direction)
    • Perform statistical test (Mann-Whitney U)
  3. Apply FDR correction (Benjamini-Hochberg) across all tests
  4. Identify significantly enriched filter-class pairs
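
The building blocks of this procedure can be sketched in plain Python. The function names and the pseudocount are assumptions, and the sketch stops at the U statistic rather than deriving a p-value, which in practice would come from a statistics library such as SciPy.

```python
import math

def log2_fold_change(in_class, out_class, pseudocount=1e-9):
    """log2 ratio of mean activation inside vs. outside the class."""
    mean_in = sum(in_class) / len(in_class)
    mean_out = sum(out_class) / len(out_class)
    return math.log2((mean_in + pseudocount) / (mean_out + pseudocount))

def mann_whitney_u(in_class, out_class):
    """U statistic for in_class: pairs where in > out, ties counted as 0.5."""
    return sum((a > b) + 0.5 * (a == b) for a in in_class for b in out_class)

def benjamini_hochberg(p_values):
    """Convert p-values to q-values (FDR), preserving the input order."""
    m = len(p_values)
    order = sorted(range(m), key=lambda i: p_values[i])
    q_values = [0.0] * m
    running_min = 1.0
    for rank in range(m, 0, -1):  # walk from the largest p-value down
        i = order[rank - 1]
        running_min = min(running_min, p_values[i] * m / rank)
        q_values[i] = running_min
    return q_values

in_class = [2.0, 4.0, 6.0]    # toy activations for sequences in the class
out_class = [1.0, 2.0, 0.0]   # toy activations for all other sequences
lfc = log2_fold_change(in_class, out_class)
u_stat = mann_whitney_u(in_class, out_class)
q_values = benjamini_hochberg([0.01, 0.04, 0.03, 0.005])
print(lfc, u_stat, q_values)
```

The correction is applied once across all filter-class tests, so q-values depend on how many filters and classes are tested together.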

Methods

get_results(value_type='q-value', q_cutoff=1.0, logFC_cutoff=0.0) → DataFrame

Extract filtered enrichment results.

Parameters:

value_type (str): 'q-value', 'p-value', or 'logFC'
q_cutoff (float): Maximum q-value threshold
logFC_cutoff (float): Minimum absolute log2 fold change

plot_heatmap(q_cutoff=0.05, savefig=None, **kwargs)

Plot log2 fold change heatmap with significance markers (*) for q ≤ cutoff.

Example

from cnnamon.CNN1D import FilterEnrichment

enrichment = FilterEnrichment(
    model,
    testset=test,
    class_names=['Enhancer', 'Promoter', 'Silencer'],
    method='mann-whitney',
    n_cores=4
)

# Visualize enrichment
enrichment.plot_heatmap(
    q_cutoff=0.05,
    savefig="enrichment_heatmap.png"
)

# Get significant results
significant = enrichment.get_results(
    value_type='logFC',
    q_cutoff=0.05,
    logFC_cutoff=1.0
)
print(significant)

# Access raw data
q_values = enrichment.q_values  # DataFrame: filters × classes
logFCs = enrichment.logFCs      # DataFrame: filters × classes

💡 Interpretation:
  • Positive logFC: Filter activates more strongly for sequences in that class
  • Negative logFC: Filter activates less for sequences in that class
  • * marker: q-value ≤ significance threshold (statistically significant)