AST Utils

November 22, 2024

Notes on converting large amounts of legacy unittest code to pytest.

import ast

"""
Convert UnitTest to Pytest
"""
class AssertFalseTransformer(ast.NodeTransformer):
    """Convert ``<obj>.assertFalse(expr[, msg])`` into ``assert not expr[, msg]``.

    ``unittest.TestCase.assertFalse`` checks truthiness (``bool(expr) is
    False``), not identity with ``False``: ``assertFalse(0)`` and
    ``assertFalse([])`` pass.  ``assert not expr`` is therefore the faithful
    pytest equivalent; ``assert expr is False`` would start failing for every
    falsy value that is not the ``False`` singleton.

    Only calls that form a whole expression statement are rewritten.
    ``assert`` is a statement, so it cannot replace a call nested inside a
    larger expression (e.g. ``x = self.assertFalse(y)``).  Calls whose
    arguments cannot be mapped onto ``(expr, msg)`` (starred arguments,
    ``**kwargs``, unknown keywords, too many positionals) are left untouched.
    """

    # Parameter names of unittest's ``assertFalse(expr, msg=None)``.
    _PARAMS = ('expr', 'msg')

    def visit_Expr(self, node):
        call = node.value
        if not (isinstance(call, ast.Call)
                and isinstance(call.func, ast.Attribute)
                and call.func.attr == 'assertFalse'):
            return self.generic_visit(node)

        bound = self._bind_arguments(call)
        if bound is None or 'expr' not in bound:
            return self.generic_visit(node)

        replacement = ast.Assert(
            test=ast.UnaryOp(op=ast.Not(), operand=bound['expr']),
            msg=bound.get('msg'),
        )
        # Give the new statement the original statement's position, then give
        # the synthesized ``not`` node the same position so the tree remains
        # valid input for compile().
        return ast.fix_missing_locations(ast.copy_location(replacement, node))

    def _bind_arguments(self, call):
        """Map *call*'s arguments onto ``assertFalse(expr, msg=None)``.

        Returns a dict of parameter name -> argument AST node, or ``None``
        when the arguments cannot be mapped unambiguously.
        """
        if len(call.args) > len(self._PARAMS):
            return None
        if any(isinstance(arg, ast.Starred) for arg in call.args):
            return None
        bound = dict(zip(self._PARAMS, call.args))
        for kw in call.keywords:
            # ``kw.arg`` is None for ``**kwargs``; duplicates would be a
            # TypeError at runtime, so leave such calls alone as well.
            if kw.arg not in self._PARAMS or kw.arg in bound:
                return None
            bound[kw.arg] = kw.value
        return bound

class AssertEqualTransformer(ast.NodeTransformer):
    """Convert ``<obj>.assertEqual(first, second[, msg])`` into ``assert first == second[, msg]``.

    Only calls that form a whole expression statement are rewritten.
    ``assert`` is a statement, so it cannot replace a call nested inside a
    larger expression (e.g. ``x = self.assertEqual(a, b)``).  Both the
    positional and keyword forms of ``first``/``second``/``msg`` are accepted.
    Calls whose arguments cannot be mapped onto those parameters (starred
    arguments, ``**kwargs``, unknown keywords, too many positionals) are
    left untouched.
    """

    # Parameter names of unittest's ``assertEqual(first, second, msg=None)``.
    _PARAMS = ('first', 'second', 'msg')

    def visit_Expr(self, node):
        call = node.value
        if not (isinstance(call, ast.Call)
                and isinstance(call.func, ast.Attribute)
                and call.func.attr == 'assertEqual'):
            return self.generic_visit(node)

        bound = self._bind_arguments(call)
        if bound is None or 'first' not in bound or 'second' not in bound:
            return self.generic_visit(node)

        replacement = ast.Assert(
            test=ast.Compare(
                left=bound['first'],
                ops=[ast.Eq()],
                comparators=[bound['second']],
            ),
            msg=bound.get('msg'),
        )
        # Give the new statement the original statement's position, then give
        # the synthesized Compare node the same position so the tree remains
        # valid input for compile().
        return ast.fix_missing_locations(ast.copy_location(replacement, node))

    def _bind_arguments(self, call):
        """Map *call*'s arguments onto ``assertEqual(first, second, msg=None)``.

        Returns a dict of parameter name -> argument AST node, or ``None``
        when the arguments cannot be mapped unambiguously.
        """
        if len(call.args) > len(self._PARAMS):
            return None
        if any(isinstance(arg, ast.Starred) for arg in call.args):
            return None
        bound = dict(zip(self._PARAMS, call.args))
        for kw in call.keywords:
            # ``kw.arg`` is None for ``**kwargs``; duplicates would be a
            # TypeError at runtime, so leave such calls alone as well.
            if kw.arg not in self._PARAMS or kw.arg in bound:
                return None
            bound[kw.arg] = kw.value
        return bound


class SetupTransformer(ast.NodeTransformer):
    """Convert setUp to a pytest fixture"""
    def visit_FunctionDef(self, node):
        # Convert setUp to a pytest fixture
        if node.name == 'setUp':
            node.decorator_list.append(ast.Name(id='pytest.fixture', ctx=ast.Load()))
            node.name = 'setup_fixture'
        # Convert tearDown to a pytest fixture
        elif node.name == 'tearDown':
            node.decorator_list.append(ast.Name(id='pytest.fixture', ctx=ast.Load()))
            node.name = 'teardown_fixture'
        return self.generic_visit(node)
import ast
import astor

# Demo: run the transformers over a sample unittest class and print the
# converted source.
code = """
class TestExample(unittest.TestCase):
    def setUp(self):
        self.foo = 'bar'
    
    def test_example(self):
        self.assertFalse(False)
        self.assertEqual(1, 1)
"""

tree = ast.parse(code)

# Applied in order; each transformer sees the previous one's output.
transformers = [
    AssertFalseTransformer(),
    AssertEqualTransformer(),
    SetupTransformer(),
]

for transformer in transformers:
    tree = transformer.visit(tree)

# The converted code references ``pytest.fixture``.  Add ``import pytest`` so
# the printed module is runnable, unless the source already imports it.
if not any(
    isinstance(stmt, ast.Import)
    and any(alias.name == 'pytest' for alias in stmt.names)
    for stmt in tree.body
):
    tree.body.insert(0, ast.Import(names=[ast.alias(name='pytest')]))

# Nodes synthesized above carry no line/column info.  Fill it in so ``tree``
# is also valid input for compile(), not only for source generation.
ast.fix_missing_locations(tree)

print(astor.to_source(tree))