Working with pytest on PyTorch
Prerequisites
To run the code in this post yourself, make sure you have the following installed: torch, ipytest>0.9, and pytest-pytorch, the plugin introduced in this post.
pip install torch 'ipytest>0.9' pytest-pytorch
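If you want to confirm the installation from Python before continuing, a small check along the following lines can help. It is a minimal sketch that assumes the distribution names match the pip command above. It only reports whether each package is installed and does not verify the ipytest>0.9 version constraint.

```python
from importlib.metadata import PackageNotFoundError, version

# Distribution names taken from the pip command above (assumed to match).
REQUIRED = ("torch", "ipytest", "pytest-pytorch")


def missing_packages(required=REQUIRED):
    """Return the names in `required` that are not installed."""
    missing = []
    for name in required:
        try:
            version(name)  # raises PackageNotFoundError if not installed
        except PackageNotFoundError:
            missing.append(name)
    return missing


if __name__ == "__main__":
    print(missing_packages() or "all prerequisites installed")
```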
Before we start testing, we need to configure ipytest. We use ipytest.autoconfig() as a base and add some pytest CLI flags to get concise output.
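import ipytest

# Use ipytest's automatic setup as a base (with defopts disabled).
ipytest.autoconfig(defopts=False)

# Flags that keep the pytest output short.
default_flags = ("--quiet", "--disable-warnings")


def _configure_ipytest(*additional_flags, collect_only=False):
    addopts = list(default_flags)
    if collect_only:
        # Only list the collected tests instead of running them.
        addopts.append("--collect-only")
    addopts.extend(additional_flags)
    ipytest.config(addopts=addopts)


def enable_pytest_pytorch(collect_only=False):
    # No extra flag needed: the plugin is active once it is installed.
    _configure_ipytest(collect_only=collect_only)


def disable_pytest_pytorch(collect_only=False):
    # Switch the plugin off to get plain pytest behavior.
    _configure_ipytest("--disable-pytest-pytorch", collect_only=collect_only)


# Start with the plugin disabled.
disable_pytest_pytorch()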
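To see exactly which options each helper hands to pytest, the flag assembly can be pulled out into a side-effect-free function. The sketch below mirrors the logic of _configure_ipytest but returns the list instead of calling ipytest.config. The names build_addopts and addopts_for are hypothetical and exist only for illustration.

```python
# Same defaults as in the ipytest configuration above.
DEFAULT_FLAGS = ("--quiet", "--disable-warnings")


def build_addopts(*additional_flags, collect_only=False):
    """Hypothetical helper: return the option list instead of applying it."""
    addopts = list(DEFAULT_FLAGS)
    if collect_only:
        addopts.append("--collect-only")
    addopts.extend(additional_flags)
    return addopts


def addopts_for(plugin_enabled, collect_only=False):
    """Hypothetical helper: options used by enable/disable_pytest_pytorch."""
    extra = () if plugin_enabled else ("--disable-pytest-pytorch",)
    return build_addopts(*extra, collect_only=collect_only)


print(addopts_for(False))
print(addopts_for(True, collect_only=True))
```

Running it prints `['--quiet', '--disable-warnings', '--disable-pytest-pytorch']` for the disabled case and `['--quiet', '--disable-warnings', '--collect-only']` for the enabled, collect-only case. The only difference between the two modes is the single `--disable-pytest-pytorch` flag.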
%%run_pytest[clean] {MODULE}::TestFoo::test_bar
from torch.testing._internal.common_utils import TestCase
from torch.testing._internal.common_device_type import instantiate_device_type_tests


class TestFoo(TestCase):
    # `device` is supplied by instantiate_device_type_tests below.
    def test_bar(self, device):
        assert False, "Don't worry, this is supposed to happen!"


# Generate the device-specific variant of TestFoo, restricted to the CPU.
instantiate_device_type_tests(TestFoo, globals(), only_for=["cpu"])
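The reason the node ID TestFoo::test_bar is a problem is that instantiate_device_type_tests generates new, device-specific test classes and methods rather than collecting TestFoo as written. The sketch below shows the kind of mapping from a generic node ID to the generated ones. It is a hypothetical helper, not pytest-pytorch's implementation, and it assumes PyTorch's naming scheme: the class name gets the upper-cased device type appended (TestFooCPU) and each method gets a `_<device>` suffix (test_bar_cpu). dtype parametrization, which adds further suffixes, is ignored here.

```python
def device_specific_node_ids(node_id, devices=("cpu",)):
    """Hypothetical helper: map 'path::TestFoo::test_bar' to generated node IDs.

    Expects exactly three '::'-separated parts (file, class, test) and
    raises ValueError otherwise (from the tuple unpacking).
    """
    path, cls, test = node_id.split("::")
    # Assumed naming: <Class><DEVICE>::<test>_<device>
    return [f"{path}::{cls}{device.upper()}::{test}_{device}" for device in devices]


print(device_specific_node_ids("test_foo.py::TestFoo::test_bar"))
print(device_specific_node_ids("test_foo.py::TestFoo::test_bar", devices=("cpu", "cuda")))
```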
If the absence of this very basic pytest feature has ever frustrated you, you no longer need to worry. By installing the pytest-pytorch plugin with
pip install pytest-pytorch
or
conda install -c conda-forge pytest-pytorch
you get the default pytest experience back, even if your workflow involves running tests from within your IDE!