Filters & Hooks Integration

This document details how the Reflectometry Plugin integrates with VIPR’s filter and hook systems for workflow customization.

Filter System

Filters transform vipr.plugins.inference.dataset.DataSet objects in the inference pipeline. The Reflectometry Plugin uses filters for preprocessing steps like data cleaning and interpolation.

Filter Weight System

Filters execute in weight-order (lowest first):

weight=-10: NeutronDataCleaner.clean_experimental_data
weight=0:   Reflectorch._preprocess_interpolate
weight=0:   FlowPreprocessor._preprocess_flow

Pattern: Cleaning (weight=-10) runs before interpolation (weight=0).
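
The dispatch logic lives in VIPR's filter runner; the sketch below only illustrates the ordering rule. The filter names come from the table above, but the runner loop itself is hypothetical:

# Hypothetical runner loop -- illustrates the ordering rule only.
registered = [
    ('Reflectorch._preprocess_interpolate', 0),
    ('NeutronDataCleaner.clean_experimental_data', -10),
    ('FlowPreprocessor._preprocess_flow', 0),
]

# Lower weight runs first; Python's sort is stable, so filters with
# equal weight keep their registration order.
for name, weight in sorted(registered, key=lambda item: item[1]):
    print(f'weight={weight:>3}: {name}')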

Neutron Data Cleaner Filter

Implementation Pattern

The vipr.plugins.discovery.decorators.discover_filter() decorator registers the filter with VIPR’s discovery system (see Discovery Plugin for details):

from vipr.plugins.discovery.decorators import discover_filter
from vipr.plugins.inference.dataset import DataSet

class NeutronDataCleaner:
    def __init__(self, app):
        self.app = app
    
    @discover_filter(
        'INFERENCE_PREPROCESS_PRE_FILTER',  # Filter slot
        weight=-10,                          # Execution order
        enabled_in_config=False,             # Requires explicit config
        parameters={
            'error_threshold': {
                'type': 'float',
                'default': 0.5,
                'help': 'Relative error threshold (dR/R) for filtering'
            },
            'consecutive_errors': {
                'type': 'int',
                'default': 3,
                'help': 'Consecutive high-error points to trigger truncation'
            },
            'remove_single_errors': {
                'type': 'bool',
                'default': False,
                'help': 'Remove isolated high-error points'
            }
        }
    )
    def clean_experimental_data(self, data: DataSet, **kwargs) -> DataSet:
        """
        Clean experimental reflectometry data.
        
        Removes:
        1. Negative intensity values (R < 0)
        2. High-error points (dR/R > threshold)
        """
        # Extract parameters from kwargs
        threshold = float(kwargs.get('error_threshold', 0.5))
        consecutive = int(kwargs.get('consecutive_errors', 3))
        remove_singles = bool(kwargs.get('remove_single_errors', False))
        
        # Process each spectrum independently
        cleaned_q, cleaned_r = [], []
        cleaned_dq = [] if data.has_x_errors() else None
        cleaned_dy = [] if data.has_y_errors() else None
        
        for i in range(data.batch_size):
            # Clean spectrum
            q_clean, r_clean, dq_clean, dr_clean = self._clean_single_spectrum(
                data.x[i], data.y[i], 
                data.dx[i] if data.has_x_errors() else None,
                data.dy[i] if data.has_y_errors() else None,
                threshold, consecutive, remove_singles
            )
            
            cleaned_q.append(q_clean)
            cleaned_r.append(r_clean)
            if dq_clean is not None:
                cleaned_dq.append(dq_clean)
            if dr_clean is not None:
                cleaned_dy.append(dr_clean)
        
        # Return padded DataSet (uniform shape)
        return self._create_padded_dataset(
            cleaned_q, cleaned_r, cleaned_dq, cleaned_dy, data
        )

Key Patterns:

  • vipr.plugins.discovery.decorators.discover_filter() decorator specifies filter slot, weight, and parameters

  • Method signature: (self, data: DataSet, **kwargs) -> DataSet

  • enabled_in_config=False requires explicit YAML activation

  • Parameters extracted from kwargs (provided by config)

  • Returns transformed vipr.plugins.inference.dataset.DataSet (immutable pattern)

DataSet Error Propagation

Filters must handle error arrays (dx, dy) correctly:

def _clean_single_spectrum(self, q, r, dq, dr, threshold, consecutive, remove_singles):
    """Clean single spectrum maintaining error consistency."""
    
    # Step 1: Remove negative intensities
    mask = r > 0
    q, r = q[mask], r[mask]
    if dq is not None:
        dq = dq[mask]
    if dr is not None:
        dr = dr[mask]
    
    # Step 2: Filter high error bars (if dr available)
    if dr is not None and len(r) > 0:
        # Remove isolated high-error points if requested
        if remove_singles:
            mask = (dr / r) < threshold
            q, r, dr = q[mask], r[mask], dr[mask]
            if dq is not None:
                dq = dq[mask]
        
        # Truncate at the first run of `consecutive` high-error points.
        # Recompute the relative error on the (possibly filtered) arrays.
        rel_error = dr / r
        count = 0
        for idx, err in enumerate(rel_error):
            if err >= threshold:
                count += 1
                if count >= consecutive:
                    cutoff = idx - consecutive + 1
                    q, r, dr = q[:cutoff], r[:cutoff], dr[:cutoff]
                    if dq is not None:
                        dq = dq[:cutoff]
                    break
            else:
                count = 0  # run broken; restart the consecutive counter
    
    return q, r, dq, dr

Key Pattern: Error arrays (dx, dy) must be transformed consistently with the data arrays (x, y) in every DataSet transformation.

Handling Variable-Length Spectra

After cleaning, the spectra have different lengths. The filter zero-pads them back to a uniform DataSet shape:

import numpy as np

def _create_padded_dataset(self, cleaned_q, cleaned_r, cleaned_dq, cleaned_dy, original_data):
    """Create uniform DataSet from variable-length cleaned spectra."""
    
    cleaned_lengths = [len(q) for q in cleaned_q]
    max_len = max(cleaned_lengths)
    batch_size = original_data.batch_size
    
    # Create zero-padded arrays (error arrays only if errors survived cleaning)
    q_uniform = np.zeros((batch_size, max_len), dtype=np.float64)
    r_uniform = np.zeros((batch_size, max_len), dtype=np.float64)
    dq_uniform = np.zeros((batch_size, max_len), dtype=np.float64) if cleaned_dq else None
    dr_uniform = np.zeros((batch_size, max_len), dtype=np.float64) if cleaned_dy else None
    
    # Fill with cleaned data; trailing entries stay zero (padding)
    for i in range(batch_size):
        n = cleaned_lengths[i]
        q_uniform[i, :n] = cleaned_q[i]
        r_uniform[i, :n] = cleaned_r[i]
        if dq_uniform is not None:
            dq_uniform[i, :n] = cleaned_dq[i]
        if dr_uniform is not None:
            dr_uniform[i, :n] = cleaned_dy[i]
    
    # Return DataSet with padding metadata
    return DataSet(
        x=q_uniform,
        y=r_uniform,
        dx=dq_uniform,
        dy=dr_uniform,
        metadata={
            **original_data.metadata,
            'cleaned': True,
            'cleaned_lengths': cleaned_lengths,
            'padded': True
        }
    )

Pattern: Zero-padding at Q=0, R=0 is ignored by subsequent interpolation (real data has Q>0). The vipr.plugins.inference.dataset.DataSet stores padding metadata.
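
A quick numpy check (with made-up values) showing why a Q>0 mask recovers each unpadded spectrum before interpolation:

import numpy as np

# Hypothetical padded spectrum: three real points plus two padding zeros
q_padded = np.array([0.01, 0.02, 0.03, 0.0, 0.0])
r_padded = np.array([1.00, 0.50, 0.20, 0.0, 0.0])

valid_mask = q_padded > 0            # real reflectometry data always has Q > 0
assert valid_mask.sum() == 3         # matches metadata['cleaned_lengths'][i]
q, r = q_padded[valid_mask], r_padded[valid_mask]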

Reflectorch Interpolation Filter

Implementation via Extension Pattern

import numpy as np

from vipr.plugins.discovery.decorators import discover_filter
from vipr.plugins.inference.dataset import DataSet

class Reflectorch:
    """Extension providing Reflectorch-specific filters."""
    
    def __init__(self, app):
        self.app = app
    
    @discover_filter(
        'INFERENCE_PREPROCESS_PRE_FILTER',
        weight=0,
        enabled_in_config=False,
        parameters={}
    )
    def _preprocess_interpolate(self, data: DataSet, **kwargs) -> DataSet:
        """Interpolate experimental data to model's Q-grid."""
        
        # Get model from workflow
        model = self.app.inference.model
        
        # Get model's Q-grid
        model_q = model.get_q_grid()
        
        # Interpolate each spectrum
        interpolated_y = []
        interpolated_dy = []
        
        for i in range(data.batch_size):
            # Skip padded regions (Q=0)
            valid_mask = data.x[i] > 0
            valid_q = data.x[i][valid_mask]
            valid_r = data.y[i][valid_mask]
            
            # Interpolate reflectivity
            interp_r = np.interp(model_q, valid_q, valid_r)
            interpolated_y.append(interp_r)
            
            # Interpolate errors if present
            if data.has_y_errors():
                valid_dr = data.dy[i][valid_mask]
                interp_dr = np.interp(model_q, valid_q, valid_dr)
                interpolated_dy.append(interp_dr)
        
        # Return interpolated DataSet
        return DataSet(
            x=np.tile(model_q, (data.batch_size, 1)),  # All spectra on same grid
            y=np.array(interpolated_y),
            dy=np.array(interpolated_dy) if interpolated_dy else None,
            metadata={**data.metadata, 'interpolated': True}
        )

Key Patterns:

  • Filter implemented as an extension method (app.extend('reflectorch', Reflectorch(app))); see the sketch after this list

  • Accesses model from workflow: self.app.inference.model

  • Handles padded data (Q=0 regions)

  • Transforms error arrays consistently in vipr.plugins.inference.dataset.DataSet
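
The first bullet references app.extend(); here is a minimal sketch of the plugin's load() entry point under that assumption, mirroring the load(app) pattern shown for hooks below:

def load(app):
    """Plugin entry point: attach the extension so its
    discover_filter-decorated methods become discoverable."""
    app.extend('reflectorch', Reflectorch(app))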

Flow Preprocessor Filter

import numpy as np
import torch

# interp_reflectivity is reflectorch's log-space interpolation helper;
# the exact import path is assumed here and may vary between versions
from reflectorch import interp_reflectivity

from vipr.plugins.discovery.decorators import discover_filter
from vipr.plugins.inference.dataset import DataSet

class FlowPreprocessor:
    """Preprocessing for normalizing flow models (CINN, NSF, MAF)."""
    
    def __init__(self, app):
        self.app = app
    
    @discover_filter('INFERENCE_PREPROCESS_PRE_FILTER', enabled_in_config=False)
    def _preprocess_flow(self, data: DataSet, **kwargs) -> DataSet:
        """
        Flow Network preprocessing for reflectivity curves.
        
        Performs (vectorized for batch efficiency):
        1. Q-grid interpolation to model's grid
        2. Flow-specific curve scaling (using reflectorch components)
        3. Proper tensor formatting for flow models
        """
        model = kwargs.get('model')
        config_name = self._get_config_name(kwargs)
        
        # Load q_generator and curves_scaler from config or model
        q_generator, curves_scaler = self._load_components(model, config_name)
        
        # Get model's q-grid
        q_model = q_generator.q.cpu().numpy()
        
        # Log-space interpolation, one curve at a time (physically correct for reflectivity)
        curves_interp = np.array([
            interp_reflectivity(q_model, q_exp, curve_exp) 
            for q_exp, curve_exp in zip(data.x, data.y)
        ])
        
        # Batch-wise scaling using reflectorch curves_scaler
        curves_tensor = torch.from_numpy(curves_interp).float()
        curves_scaled = curves_scaler.scale(curves_tensor).numpy()
        
        # Tile q-grid for batch
        q_model_batch = np.tile(q_model, (data.batch_size, 1))
        
        return DataSet(
            x=q_model_batch,
            y=curves_scaled,
            dx=None,
            dy=None,
            metadata={
                **data.metadata,
                'flow_preprocessed': True,
                'q_grid_interpolated': True,
                'curve_scaled': True
            }
        )

Pattern: Flow preprocessing is more complex than a simple log-transform: it uses reflectorch’s q_generator and curves_scaler components for proper normalization.

Hook System

Hooks execute at specific workflow points without transforming data. See vipr.core.exc for hook system details.

Environment Setup Hook

import os
from pathlib import Path

def setup_reflectometry_env_defaults(app):
    """Hook: Set environment defaults before inference starts."""
    
    working_dir = Path.cwd()
    
    defaults = {
        'REFLECTORCH_ROOT_DIR': working_dir / 'storage' / 'reflectorch',
    }
    
    for key, default_path in defaults.items():
        if key not in os.environ:
            os.environ[key] = str(default_path)
            app.log.debug(f"Set default {key}={default_path}")

# Register in plugin loading
def load(app):
    app.hook.register('INFERENCE_BEFORE_START_HOOK', setup_reflectometry_env_defaults)

Key Patterns:

  • Hook signature: (app) -> None

  • No return value (side effects only)

  • Registered in plugin load(app) function

  • Runs before inference workflow starts

Visualization Hooks (Postprocessing)

import corner  # corner-plot library used for posterior visualization

from vipr.plugins.discovery.decorators import discover_hook

class BasicCornerPlot:
    """Postprocessing hook for corner plot visualization."""
    
    def __init__(self, app):
        self.app = app
    
    @discover_hook(
        'INFERENCE_POSTPROCESS_PRE_PRE_FILTER_HOOK',
        weight=0,
        enabled_in_config=False
    )
    def _create_basic_corner_plot(self, app):
        """Generate corner plot from posterior samples."""
        
        # Access prediction results from workflow
        result = app.inference.result
        samples = result['posterior_samples']
        
        # Create corner plot
        fig = corner.corner(samples)
        
        # Store in data collector for UI
        if hasattr(app, 'flow_dc'):
            app.flow_dc.add_plot('corner_plot', fig)
        
        app.log.info("Corner plot created")

Key Patterns:

  • Hook method signature: (self, app) -> None (no DataSet in, no return value)

  • Accesses prediction results from workflow state: app.inference.result

  • Stores output in the data collector (app.flow_dc) when available

  • Registered via the discover_hook() decorator with the same slot/weight scheme as filters

Configuration

Filter Activation

Filters must be explicitly enabled in YAML:

vipr:
  inference:
    filters:
      INFERENCE_PREPROCESS_PRE_FILTER:
        # Neutron data cleaning (weight=-10)
        - class: vipr_reflectometry.shared.preprocessing.neutron_data_cleaner.NeutronDataCleaner
          enabled: true
          method: clean_experimental_data
          parameters:
            error_threshold: 0.5
            consecutive_errors: 3
            remove_single_errors: false
          weight: -10
        
        # Interpolation (weight=0)
        - class: vipr_reflectometry.reflectorch.reflectorch_extension.Reflectorch
          enabled: true
          method: _preprocess_interpolate
          parameters: null
          weight: 0

Configuration Structure:

  • Filter slot: INFERENCE_PREPROCESS_PRE_FILTER

  • Class path: Fully qualified Python path

  • Method: Filter method name

  • Parameters: Passed as kwargs to filter

  • Weight: Execution order (lower runs first)

  • Enabled: Explicit activation flag

Hook Activation

vipr:
  inference:
    hooks:
      INFERENCE_POSTPROCESS_PRE_PRE_FILTER_HOOK:
        - class: vipr_reflectometry.flow_models.postprocessors.basic_corner_plot.BasicCornerPlot
          enabled: true
          method: _create_basic_corner_plot
          parameters: null
          weight: 0

Pattern: Same configuration structure as filters, but no return value expected.

Filter Execution Flow

LoadDataStep (outputs DataSet)
    ↓
NormalizeStep (optional)
    ↓
PreprocessStep:
    ↓
    PRE_PRE_FILTER_HOOK (hooks before filters)
    ↓
    PREPROCESS_PRE_FILTER (weight=-10):
        → NeutronDataCleaner.clean_experimental_data
            Input:  DataSet (raw experimental data)
            Output: DataSet (cleaned, padded)
    ↓
    PREPROCESS_PRE_FILTER (weight=0):
        → Reflectorch._preprocess_interpolate
            Input:  DataSet (cleaned)
            Output: DataSet (interpolated to model grid)
    ↓
    POST_PRE_FILTER_HOOK (hooks after filters)
    ↓
PredictionStep (uses preprocessed DataSet)

Key Pattern: Multiple filters on same slot execute in weight order.

Key Takeaways

Filter Implementation Checklist

  1. ✅ Use vipr.plugins.discovery.decorators.discover_filter() decorator with slot, weight, parameters

  2. ✅ Signature: (self, data: DataSet, **kwargs) -> DataSet

  3. ✅ Handle error arrays (dx, dy) consistently

  4. ✅ Return immutable vipr.plugins.inference.dataset.DataSet

  5. ✅ Require explicit config activation (enabled_in_config=False)
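
Putting the checklist together, a minimal skeleton filter. MyScaleFilter, its weight, and the scale parameter are hypothetical; each numbered comment maps to a checklist item:

from vipr.plugins.discovery.decorators import discover_filter
from vipr.plugins.inference.dataset import DataSet

class MyScaleFilter:  # hypothetical example class
    def __init__(self, app):
        self.app = app
    
    @discover_filter(
        'INFERENCE_PREPROCESS_PRE_FILTER',   # (1) slot
        weight=5,                            # (1) execution order
        enabled_in_config=False,             # (5) explicit YAML activation
        parameters={
            'scale': {'type': 'float', 'default': 1.0, 'help': 'y scale factor'}
        }
    )
    def apply_scale(self, data: DataSet, **kwargs) -> DataSet:  # (2) signature
        scale = float(kwargs.get('scale', 1.0))
        # (3) transform y and dy consistently; (4) return a new DataSet
        return DataSet(
            x=data.x,
            y=data.y * scale,
            dx=data.dx,
            dy=data.dy * scale if data.has_y_errors() else None,
            metadata={**data.metadata, 'scaled': True}
        )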

Hook Implementation Checklist

  1. ✅ Use vipr.plugins.discovery.decorators.discover_hook() decorator (or app.hook.register())

  2. ✅ Signature: (app) -> None

  3. ✅ No return value (side effects only)

  4. ✅ Access workflow state via app.inference.*
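
And the equivalent minimal hook skeleton (MyReportHook and its method name are hypothetical):

from vipr.plugins.discovery.decorators import discover_hook

class MyReportHook:  # hypothetical example class
    def __init__(self, app):
        self.app = app
    
    @discover_hook('INFERENCE_POSTPROCESS_PRE_PRE_FILTER_HOOK', enabled_in_config=False)
    def report(self, app):                            # (2) (app) -> None
        result = app.inference.result                 # (4) workflow state
        app.log.info(f'result keys: {list(result)}')  # (3) side effects only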

VIPR Framework Patterns Used