backwardcompatibilityml.widgets.model_comparison package


backwardcompatibilityml.widgets.model_comparison.model_comparison module

class backwardcompatibilityml.widgets.model_comparison.model_comparison.ModelComparison(h1, h2, dataset, performance_metric=<function model_accuracy>, port=None, get_instance_image_by_id=None, get_instance_metadata=None, device='cpu')

Bases: object

Model Comparison widget

The ModelComparison class is an interactive widget intended for use within a Jupyter Notebook. Its interactive UI works as follows:

  1. It compares two models, h1 and h2, on a dataset with regard to compatibility.
  2. The comparison is made by comparing the sets of classification errors that h1 and h2 make on the dataset.
  3. The Venn diagram within the widget breaks down the overlap between the sets of classification errors made by h1 and h2.
  4. The bar chart shows, per class, the number of errors made by h2 that are not made by h1.
  5. The error instances table provides an exploratory view of the instances that h1 and h2 have misclassified. The table is linked to the Venn diagram and the bar chart, so the user can filter the error instances it displays by clicking on regions of those components.
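The overlap that the Venn diagram and bar chart display can be reproduced with plain Python sets. This is a minimal sketch, not the library's implementation: it assumes each model's predictions and the targets are already available as lists of class labels indexed by instance id, and `error_breakdown` is a hypothetical helper name.

```python
from collections import Counter


def error_breakdown(targets, h1_preds, h2_preds):
    """Hypothetical helper: split misclassified instance ids into Venn regions
    and count, per true class, the errors made by h2 but not by h1."""
    h1_errors = {i for i, (y, p) in enumerate(zip(targets, h1_preds)) if p != y}
    h2_errors = {i for i, (y, p) in enumerate(zip(targets, h2_preds)) if p != y}

    venn = {
        "h1_only": h1_errors - h2_errors,  # errors only the reference model makes
        "both": h1_errors & h2_errors,     # errors shared by both models
        "h2_only": h2_errors - h1_errors,  # new errors introduced by h2
    }
    # Bar-chart data: the new h2 errors grouped by true class label.
    new_errors_per_class = Counter(targets[i] for i in venn["h2_only"])
    return venn, new_errors_per_class


targets = [0, 1, 2, 1, 0, 2]
h1_preds = [0, 2, 2, 1, 1, 2]
h2_preds = [0, 1, 0, 0, 1, 2]
venn, per_class = error_breakdown(targets, h1_preds, h2_preds)
print(sorted(venn["h2_only"]))   # [2, 3]
print(sorted(per_class.items()))  # [(1, 1), (2, 1)]
```

Instances in the "h2_only" region are the ones where h2 breaks a prediction that h1 got right, which is why the bar chart counts only that region.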
Parameters:
  • h1 – The reference model.
  • h2 – The model to compare against h1.
  • dataset – The list of dataset samples as (batch_ids, input, target). The data must be batched.
  • performance_metric
    A function to evaluate model performance. The function is expected to have the following signature:
    metric(model, dataset, device)
      model: The model being evaluated.
      dataset: The dataset as a list of (input, target) pairs.
      device: The device PyTorch is using for training, either “cpu” or “cuda”.

    If unspecified, accuracy is used.

  • port – An integer value to indicate the port to which the Flask service should bind.
  • get_instance_image_by_id
    A function that returns an image representation, in PNG format, of the data corresponding to the instance id. It should be a function with the following parameter:
    instance_id: An integer instance id.

    It should return a PNG image.

  • get_instance_metadata
    A function that returns a text string describing some metadata corresponding to the instance id. It should be a function with the following parameter:
    instance_id: An integer instance id.

    It should return a string.

  • device – A string, either “cpu” or “cuda”, indicating the device PyTorch is performing training on. The default is “cpu”. If your models reside on the GPU, set this to “cuda” so that the input and target tensors are transferred to the GPU during training.
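A custom performance_metric only needs to follow the metric(model, dataset, device) signature. The sketch below makes two assumptions: `make_batches`, `simple_accuracy`, and `threshold_model` are hypothetical names, and a plain Python callable stands in for a PyTorch model so the example runs without PyTorch. Because the dataset is described above both as batched (batch_ids, input, target) samples and as (input, target) pairs, the metric reads the last two elements of each entry and therefore accepts either shape.

```python
def make_batches(samples, batch_size):
    """Hypothetical helper: group (input, target) samples into
    (batch_ids, inputs, targets) tuples, using sample positions as ids."""
    batches = []
    for start in range(0, len(samples), batch_size):
        chunk = samples[start:start + batch_size]
        batch_ids = list(range(start, start + len(chunk)))
        inputs = [x for x, _ in chunk]
        targets = [y for _, y in chunk]
        batches.append((batch_ids, inputs, targets))
    return batches


def simple_accuracy(model, dataset, device):
    """Hypothetical performance_metric with the metric(model, dataset, device) signature.

    `model` is assumed to map a list of inputs to a list of predicted labels.
    `device` is accepted for signature compatibility but unused here,
    because this stand-in does not use tensors.
    """
    correct = total = 0
    for entry in dataset:
        # Take the last two elements: works for (inputs, targets)
        # and (batch_ids, inputs, targets) entries alike.
        *_, inputs, targets = entry
        predictions = model(inputs)
        correct += sum(p == y for p, y in zip(predictions, targets))
        total += len(targets)
    return correct / total if total else 0.0


def threshold_model(inputs):
    """Stand-in for a trained binary classifier."""
    return [int(x > 0.5) for x in inputs]


samples = [(0.2, 0), (0.7, 1), (0.4, 1), (0.9, 1), (0.1, 0)]
dataset = make_batches(samples, batch_size=2)
print(simple_accuracy(threshold_model, dataset, device="cpu"))  # 0.8
```

With an actual PyTorch model, a metric of this shape would typically move each batch's inputs and targets to `device`, run the model, and compare `outputs.argmax(dim=1)` with the targets.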
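Both instance callbacks take an integer instance id. The sketch below uses a hypothetical in-memory `INSTANCES` table and builds an 8-bit grayscale PNG with only the standard library (`struct` and `zlib`), so it runs without an imaging package. It returns raw PNG bytes. Whether the widget expects bytes, a file-like object such as `io.BytesIO`, or another image type is an assumption to verify against the library.

```python
import struct
import zlib

# Hypothetical lookup table: instance id -> (label, rows of 8-bit grayscale pixels).
INSTANCES = {
    0: ("cat", [[0, 128], [255, 64]]),
    1: ("dog", [[255, 255], [0, 0]]),
}


def _png_chunk(tag, data):
    """Encode one PNG chunk: length, tag, data, then CRC32 over tag + data."""
    return (struct.pack(">I", len(data)) + tag + data
            + struct.pack(">I", zlib.crc32(tag + data) & 0xFFFFFFFF))


def get_instance_image_by_id(instance_id):
    """Return the instance's pixels encoded as raw PNG bytes."""
    _, rows = INSTANCES[instance_id]
    height, width = len(rows), len(rows[0])
    # Each scanline is prefixed with filter-type byte 0 (no filtering).
    raw = b"".join(b"\x00" + bytes(row) for row in rows)
    # IHDR fields: width, height, bit depth 8, color type 0 (grayscale),
    # compression 0, filter method 0, interlace 0.
    ihdr = struct.pack(">IIBBBBB", width, height, 8, 0, 0, 0, 0)
    return (b"\x89PNG\r\n\x1a\n"
            + _png_chunk(b"IHDR", ihdr)
            + _png_chunk(b"IDAT", zlib.compress(raw))
            + _png_chunk(b"IEND", b""))


def get_instance_metadata(instance_id):
    """Return a short, human-readable description of the instance."""
    label, rows = INSTANCES[instance_id]
    return f"id={instance_id}, label={label}, size={len(rows[0])}x{len(rows)}"


print(get_instance_metadata(1))  # id=1, label=dog, size=2x2
```

In practice the table lookup would be replaced by whatever maps an instance id back to its stored sample.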

A small helper function that returns a dictionary containing the environment type and the base URL of the Flask service for that environment type.

Parameters: flask_service_env – An instance of an environment from rai_core_flask.environments.
Returns: A dictionary containing the environment type, as a string, and the base URL to use when accessing the Flask service for this environment type.
backwardcompatibilityml.widgets.model_comparison.model_comparison.init_app_routes(app, comparison_manager)

Defines the API for the Flask app.

Parameters:
  • app – The Flask app to use for the API.
  • comparison_manager – The ComparisonManager that the API will control.
backwardcompatibilityml.widgets.model_comparison.model_comparison.render_widget_html(api_service_environment, data)

Renders the HTML for the compatibility analysis widget.

Parameters: api_service_environment – A dictionary of the environment type, the base URL, and the port for the Flask service.
Returns: The widget HTML, rendered as a string.
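One way to render a widget like this is to serialize the environment dictionary and the widget data to JSON and embed them in a page that a front-end bundle then reads. This is a hypothetical sketch, not the library's template: the `PAGE` template, the `render_widget_html_sketch` name, and the `environment_type`, `base_url`, and `port` keys are illustrative assumptions.

```python
import json
from string import Template

# Hypothetical page template; the real widget markup and scripts are not shown here.
PAGE = Template("""<div id="widget-root"></div>
<script>
  window.WIDGET_CONFIG = $config;
</script>""")


def render_widget_html_sketch(api_service_environment, data):
    """Embed the service environment and widget data into an HTML string."""
    config = {"environment": api_service_environment, "data": data}
    # Escape "</" so a data string cannot terminate the <script> element early.
    payload = json.dumps(config).replace("</", "<\\/")
    return PAGE.substitute(config=payload)


html = render_widget_html_sketch(
    {"environment_type": "local", "base_url": "http://localhost", "port": 5000},
    {"title": "</script> demo"},
)
print(html)
```

The escaping step matters because a data string containing `</script>` would otherwise close the script element; `<\/` is still a valid JSON string escape, so the browser parses the payload unchanged.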

Module contents