vipr_reflectometry.reflectorch.load_model package

Submodules

vipr_reflectometry.reflectorch.load_model.reflectorch_model_loader module

class vipr_reflectometry.reflectorch.load_model.reflectorch_model_loader.ReflectorchModelLoader(**kw: Any)

Bases: ModelLoaderHandler

class Meta

Bases: object

label = 'reflectorch'
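The Meta.label attribute gives this handler the name 'reflectorch'. This page does not document how VIPR resolves handlers. As a minimal sketch, assuming a plain dictionary registry, label-keyed lookup could work as follows. The registry, both helper functions, and StubReflectorchLoader are hypothetical.

```python
from typing import Dict

# Hypothetical registry mapping a handler's Meta.label to its class.
LOADER_REGISTRY: Dict[str, type] = {}


def register_loader(cls: type) -> type:
    """Register a handler class under its Meta.label (hypothetical helper)."""
    label = cls.Meta.label
    if label in LOADER_REGISTRY:
        raise ValueError(f"duplicate loader label: {label!r}")
    LOADER_REGISTRY[label] = cls
    return cls


def get_loader(label: str) -> type:
    """Look up a registered handler class by label (hypothetical helper)."""
    try:
        return LOADER_REGISTRY[label]
    except KeyError:
        raise KeyError(f"no model loader registered as {label!r}") from None


@register_loader
class StubReflectorchLoader:
    """Hypothetical stand-in for ReflectorchModelLoader; only Meta.label matters."""

    class Meta:
        label = "reflectorch"


print(get_loader("reflectorch").__name__)  # StubReflectorchLoader
```

Keying on a short label lets a configuration file select a loader by name without importing the class directly.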

class vipr_reflectometry.reflectorch.load_model.reflectorch_model_loader.ReflectorchModelLoaderParams(*, config_name: str = 'b_mc_point_xray_conv_standard_L2_InputQ', model_name: str | None = None, root_dir: str | None = None, weights_format: str = 'safetensors', repo_id: str = 'valentinsingularity/reflectivity', device: str = 'cpu')

Bases: BaseModel

Parameters for the reflectorch model loader.

config_name: str
device: str
model_config: ClassVar[ConfigDict] = {'extra': 'forbid'}

Configuration for the model; a dictionary conforming to pydantic.config.ConfigDict. With 'extra': 'forbid', passing any keyword that is not one of the fields listed here raises a validation error.

model_name: str | None
repo_id: str
root_dir: str | None
weights_format: str
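The fields and defaults above can be exercised without installing vipr_reflectometry. The sketch below is a hypothetical standalone mirror (LoaderParamsSketch is not a real class) written against pydantic v2. It reproduces the documented defaults and the extra='forbid' setting, so a misspelled keyword is rejected instead of being silently ignored.

```python
from typing import Optional

from pydantic import BaseModel, ConfigDict, ValidationError


class LoaderParamsSketch(BaseModel):
    """Hypothetical mirror of ReflectorchModelLoaderParams for illustration."""

    # Reject unknown keys, matching the documented model_config.
    model_config = ConfigDict(extra="forbid")

    config_name: str = "b_mc_point_xray_conv_standard_L2_InputQ"
    model_name: Optional[str] = None
    root_dir: Optional[str] = None
    weights_format: str = "safetensors"
    repo_id: str = "valentinsingularity/reflectivity"
    device: str = "cpu"


params = LoaderParamsSketch(device="cuda")
print(params.config_name)  # b_mc_point_xray_conv_standard_L2_InputQ

try:
    LoaderParamsSketch(devcie="cuda")  # misspelled key is not ignored
except ValidationError as exc:
    print(exc.errors()[0]["type"])  # extra_forbidden
```

Note that the real loader's defaults (device='cpu') differ from TimedInferenceModel's constructor default (device='cuda'), so it is worth passing device explicitly.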

Module contents

class vipr_reflectometry.reflectorch.load_model.TimedInferenceModel(config_name: str | None = None, model_name: str | None = None, root_dir: str | None = None, weights_format: str = 'safetensors', repo_id: str = 'valentinsingularity/reflectivity', trainer: PointEstimatorTrainer | None = None, device='cuda')

Bases: InferenceModel

InferenceModel with timing measurements for NN prediction and polishing.

Extends reflectorch.InferenceModel to add detailed timing metrics:

  • nn_prediction_s: Time for neural network inference

  • fitting_s: Time for scipy optimization (polishing)

  • total_s: Combined execution time

The timing information is automatically added to the prediction result dict under the 'timing' key.
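The three keys can be produced by timing each phase with time.perf_counter and defining total_s as their sum. This is a minimal standard-library sketch: PhaseTimer and predict_with_timing are hypothetical names, and the stand-in phases replace the actual network call and scipy polishing.

```python
import time
from contextlib import contextmanager
from typing import Dict, Iterator, List


class PhaseTimer:
    """Accumulates wall-clock seconds per named phase (hypothetical helper)."""

    def __init__(self) -> None:
        self.seconds: Dict[str, float] = {"nn_prediction_s": 0.0, "fitting_s": 0.0}

    @contextmanager
    def phase(self, key: str) -> Iterator[None]:
        start = time.perf_counter()
        try:
            yield
        finally:
            # Record the elapsed time even if the timed block raises.
            self.seconds[key] += time.perf_counter() - start

    def as_timing(self) -> Dict[str, float]:
        timing = dict(self.seconds)
        # total_s is the sum of the two phases, as documented.
        timing["total_s"] = timing["nn_prediction_s"] + timing["fitting_s"]
        return timing


def predict_with_timing(curve: List[float], polish: bool = False) -> dict:
    """Hypothetical prediction routine with stand-in phases."""
    timer = PhaseTimer()
    with timer.phase("nn_prediction_s"):
        params = [sum(curve) / len(curve)]  # stand-in for network inference
    if polish:
        with timer.phase("fitting_s"):
            time.sleep(0.01)  # stand-in for scipy polishing
    # fitting_s stays 0.0 when polishing is skipped.
    return {"predicted_params": params, "timing": timer.as_timing()}


result = predict_with_timing([1.0, 2.0, 3.0])
print(sorted(result["timing"]))  # ['fitting_s', 'nn_prediction_s', 'total_s']
```

time.perf_counter is used rather than time.time because it is monotonic and has the highest available resolution for measuring short intervals.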

Example:

>>> model = TimedInferenceModel(config_name='mc66', device='cpu')
>>> result = model.predict(curve, q_values, prior_bounds, polish_prediction=True)
>>> print(result['timing'])
{'nn_prediction_s': 0.123, 'fitting_s': 2.456, 'total_s': 2.579}

predict(*args, polish_prediction=False, **kwargs)

Override predict to inject timing measurements.

Behaves identically to InferenceModel.predict(), but adds a 'timing' dictionary to the result containing:

  • nn_prediction_s – neural network inference time

  • fitting_s – scipy fitting time (0.0 when polish_prediction=False)

  • total_s – total time (nn_prediction_s + fitting_s)

Parameters:
  • *args – Positional arguments forwarded to InferenceModel.predict().

  • polish_prediction (bool) – Whether to perform polishing/fitting.

  • **kwargs – Keyword arguments forwarded to InferenceModel.predict().

Returns:

Prediction results with an added 'timing' key.

Return type:

dict
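Because predict() forwards *args and **kwargs unchanged and only appends a 'timing' key, the override can be sketched as a thin subclass. This sketch assumes the base model runs the network and the polishing as separate steps. StubInferenceModel, TimedStubModel, and the _nn_predict/_polish hooks are hypothetical, and reflectorch's actual InferenceModel internals may differ.

```python
import time
from typing import Any, Dict, List


class StubInferenceModel:
    """Hypothetical stand-in for reflectorch.InferenceModel."""

    def _nn_predict(self, curve: List[float]) -> List[float]:
        return [sum(curve) / len(curve)]  # stand-in for the network

    def _polish(self, params: List[float]) -> List[float]:
        return [round(p, 3) for p in params]  # stand-in for scipy fitting

    def predict(self, curve: List[float], polish_prediction: bool = False) -> Dict[str, Any]:
        params = self._nn_predict(curve)
        if polish_prediction:
            params = self._polish(params)
        return {"predicted_params": params}


class TimedStubModel(StubInferenceModel):
    """Adds a 'timing' dict by timing the two hypothetical phase hooks."""

    def _nn_predict(self, curve: List[float]) -> List[float]:
        start = time.perf_counter()
        try:
            return super()._nn_predict(curve)
        finally:
            self._timing["nn_prediction_s"] += time.perf_counter() - start

    def _polish(self, params: List[float]) -> List[float]:
        start = time.perf_counter()
        try:
            return super()._polish(params)
        finally:
            self._timing["fitting_s"] += time.perf_counter() - start

    def predict(self, *args: Any, polish_prediction: bool = False, **kwargs: Any) -> Dict[str, Any]:
        # Reset per call so fitting_s stays 0.0 when polishing is skipped.
        self._timing = {"nn_prediction_s": 0.0, "fitting_s": 0.0}
        result = super().predict(*args, polish_prediction=polish_prediction, **kwargs)
        t = self._timing
        result["timing"] = {**t, "total_s": t["nn_prediction_s"] + t["fitting_s"]}
        return result


model = TimedStubModel()
out = model.predict([1.0, 2.0, 3.0], polish_prediction=True)
print(out["predicted_params"])  # [2.0]
```

Timing the individual phase hooks, rather than the whole super().predict() call, is what allows nn_prediction_s and fitting_s to be reported separately while the rest of the result dict passes through untouched.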