vipr_reflectometry.reflectorch.load_model package
Submodules
vipr_reflectometry.reflectorch.load_model.reflectorch_model_loader module
- class vipr_reflectometry.reflectorch.load_model.reflectorch_model_loader.ReflectorchModelLoader(**kw: Any)
Bases: ModelLoaderHandler
- class vipr_reflectometry.reflectorch.load_model.reflectorch_model_loader.ReflectorchModelLoaderParams(*, config_name: str = 'b_mc_point_xray_conv_standard_L2_InputQ', model_name: str | None = None, root_dir: str | None = None, weights_format: str = 'safetensors', repo_id: str = 'valentinsingularity/reflectivity', device: str = 'cpu')
Bases: BaseModel
Parameters for the reflectorch model loader.
- model_config: ClassVar[ConfigDict] = {'extra': 'forbid'}
Configuration for the model; should be a dictionary conforming to pydantic.config.ConfigDict.
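The `'extra': 'forbid'` setting means that unknown parameter names are rejected rather than silently ignored. The following is a minimal standard-library sketch of that behaviour, not the actual class: `LoaderParamsSketch` and `build_params` are hypothetical names, and a frozen dataclass stands in for the real parameter model. Only the field names and defaults are taken from the signature above.

```python
from dataclasses import dataclass, fields
from typing import Any, Dict, Optional


@dataclass(frozen=True)
class LoaderParamsSketch:
    """Hypothetical stand-in mirroring the documented fields and defaults."""

    config_name: str = "b_mc_point_xray_conv_standard_L2_InputQ"
    model_name: Optional[str] = None
    root_dir: Optional[str] = None
    weights_format: str = "safetensors"
    repo_id: str = "valentinsingularity/reflectivity"
    device: str = "cpu"


def build_params(raw: Dict[str, Any]) -> LoaderParamsSketch:
    """Reject unknown keys up front, similar in spirit to extra='forbid'."""
    allowed = {f.name for f in fields(LoaderParamsSketch)}
    unknown = sorted(set(raw) - allowed)
    if unknown:
        raise ValueError(f"unexpected parameter(s): {', '.join(unknown)}")
    return LoaderParamsSketch(**raw)
```

With this sketch, a misspelled key such as `devise` raises an error instead of leaving `device` at its default, which is the practical benefit of forbidding extra fields.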
Module contents
- class vipr_reflectometry.reflectorch.load_model.TimedInferenceModel(config_name: str | None = None, model_name: str | None = None, root_dir: str | None = None, weights_format: str = 'safetensors', repo_id: str = 'valentinsingularity/reflectivity', trainer: PointEstimatorTrainer | None = None, device='cuda')
Bases: InferenceModel
InferenceModel with timing measurements for NN prediction and polishing.
Extends reflectorch.InferenceModel to add detailed timing metrics:
- nn_prediction_s: Time for neural network inference
- fitting_s: Time for scipy optimization (polishing)
- total_s: Combined execution time
The timing information is automatically added to the prediction result dict under the 'timing' key.
Example:
>>> model = TimedInferenceModel(config_name='mc66', device='cpu')
>>> result = model.predict(curve, q_values, prior_bounds, polish_prediction=True)
>>> print(result['timing'])
{'nn_prediction_s': 0.123, 'fitting_s': 2.456, 'total_s': 2.579}
- predict(*args, polish_prediction=False, **kwargs)
Override predict to inject timing measurements.
Behaves identically to InferenceModel.predict(), but adds a 'timing' dictionary to the result containing:
- nn_prediction_s – neural network inference time
- fitting_s – scipy fitting time (0.0 when polish_prediction=False)
- total_s – total time (nn_prediction_s + fitting_s)
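The timing contract described above can be illustrated with a small, library-independent sketch. It is not the actual implementation: `predict_with_timing`, `nn_step`, and `polish_step` are hypothetical names, and the two callables merely stand in for the network inference and the optional scipy polishing. Only the three keys of the 'timing' dictionary and the `0.0` fitting time without polishing come from the documentation.

```python
import time
from typing import Any, Callable, Dict, Optional

Result = Dict[str, Any]


def predict_with_timing(
    nn_step: Callable[[], Result],
    polish_step: Optional[Callable[[Result], Result]] = None,
) -> Result:
    """Run both stages and attach a 'timing' dict to the result."""
    # Time the (stand-in) neural network inference.
    start = time.perf_counter()
    result = dict(nn_step())
    nn_s = time.perf_counter() - start

    # Time the optional (stand-in) polishing; stays 0.0 when skipped.
    fit_s = 0.0
    if polish_step is not None:
        start = time.perf_counter()
        result = dict(polish_step(result))
        fit_s = time.perf_counter() - start

    result["timing"] = {
        "nn_prediction_s": nn_s,
        "fitting_s": fit_s,
        "total_s": nn_s + fit_s,
    }
    return result
```

A caller can then read `result["timing"]["fitting_s"] / result["timing"]["total_s"]` to see how much of the run was spent polishing; for the example values shown in the class docstring this is roughly 95%.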