Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 32 additions & 22 deletions python_files/vscode_pytest/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,21 @@ class TestItem(TestData):
class TestNode(TestData):
"""A general class that handles all test data which contains children."""

children: list[TestNode | TestItem | None]
children: Children
lineno: NotRequired[str] # Optional field for class/function nodes


class Children:
def __init__(self, init=None):
self._children = dict(init) if init is not None else {}

def add(self, child: TestNode | TestItem):
self._children[child["id_"]] = child

def values(self):
return list(self._children.values())


class VSCodePytestError(Exception):
"""A custom exception class for pytest errors."""

Expand Down Expand Up @@ -439,7 +450,7 @@ def pytest_sessionfinish(session, exitstatus):
"name": "",
"path": test_root_path,
"type_": "error",
"children": [],
"children": Children(),
"id_": "",
}
send_discovery_message(os.fsdecode(test_root_path), error_node)
Expand All @@ -459,7 +470,7 @@ def pytest_sessionfinish(session, exitstatus):
"name": "",
"path": test_root_path,
"type_": "error",
"children": [],
"children": Children(),
"id_": "",
}
send_discovery_message(os.fsdecode(test_root_path), error_node)
Expand Down Expand Up @@ -664,8 +675,7 @@ def process_parameterized_test(
)
function_nodes_dict[parent_id] = function_test_node

if test_node not in function_test_node["children"]:
function_test_node["children"].append(test_node)
function_test_node["children"].add(test_node)

# Check if the parent node of the function is file, if so create/add to this file node.
if isinstance(test_case.parent, pytest.File):
Expand All @@ -676,8 +686,7 @@ def process_parameterized_test(
if parent_test_case is None:
parent_test_case = create_file_node(parent_path)
file_nodes_dict[parent_path_key] = parent_test_case
if function_test_node not in parent_test_case["children"]:
parent_test_case["children"].append(function_test_node)
parent_test_case["children"].add(function_test_node)

# Return the function node as the test node to handle subsequent nesting
return function_test_node
Expand Down Expand Up @@ -725,8 +734,7 @@ def build_test_tree(session: pytest.Session) -> TestNode:
test_class_node = create_class_node(case_iter)
class_nodes_dict[case_iter.nodeid] = test_class_node
# Check if the class already has the child node. This will occur if the test is parameterized.
if node_child_iter not in test_class_node["children"]:
test_class_node["children"].append(node_child_iter)
test_class_node["children"].add(node_child_iter)
# Iterate up.
node_child_iter = test_class_node
case_iter = case_iter.parent
Expand All @@ -744,8 +752,8 @@ def build_test_tree(session: pytest.Session) -> TestNode:
test_file_node = create_file_node(parent_path)
file_nodes_dict[parent_path_key] = test_file_node
# Check if the class is already a child of the file node.
if test_class_node is not None and test_class_node not in test_file_node["children"]:
test_file_node["children"].append(test_class_node)
if test_class_node is not None:
test_file_node["children"].add(test_class_node)
elif not hasattr(test_case, "callspec"):
# This includes test cases that are pytest functions or a doctests.
if test_case.parent is None:
Expand All @@ -762,12 +770,13 @@ def build_test_tree(session: pytest.Session) -> TestNode:
if parent_test_case is None:
parent_test_case = create_file_node(parent_path)
file_nodes_dict[parent_path_key] = parent_test_case
parent_test_case["children"].append(test_node)
parent_test_case["children"].add(test_node)
# Process all files and construct them into nested folders
session_children_dict = construct_nested_folders(
file_nodes_dict, session_node, session_children_dict
)
session_node["children"] = list(session_children_dict.values())
session_node["children"] = Children(session_children_dict)

return session_node


Expand Down Expand Up @@ -807,8 +816,7 @@ def build_nested_folders(
if curr_folder_node is None:
curr_folder_node = create_folder_node(curr_folder_name, iterator_path)
created_files_folders_dict[iterator_path_key] = curr_folder_node
if prev_folder_node not in curr_folder_node["children"]:
curr_folder_node["children"].append(prev_folder_node)
curr_folder_node["children"].add(prev_folder_node)
iterator_path = iterator_path.parent
prev_folder_node = curr_folder_node
# Handles error where infinite loop occurs.
Expand Down Expand Up @@ -857,7 +865,7 @@ def create_session_node(session: pytest.Session) -> TestNode:
"name": node_path.name,
"path": node_path,
"type_": "folder",
"children": [],
"children": Children(),
"id_": os.fspath(node_path),
}

Expand All @@ -884,7 +892,7 @@ def create_class_node(class_module: pytest.Class | DescribeBlock) -> TestNode:
"name": class_module.name,
"path": get_node_path(class_module),
"type_": "class",
"children": [],
"children": Children(),
"id_": get_absolute_test_id(class_module.nodeid, get_node_path(class_module)),
"lineno": class_line,
}
Expand All @@ -905,7 +913,7 @@ def create_parameterized_function_node(
"name": function_name,
"path": test_path,
"type_": "function",
"children": [],
"children": Children(),
"id_": function_id,
}

Expand All @@ -921,7 +929,7 @@ def create_file_node(calculated_node_path: pathlib.Path) -> TestNode:
"path": calculated_node_path,
"type_": "file",
"id_": os.fspath(calculated_node_path),
"children": [],
"children": Children(),
}


Expand All @@ -937,7 +945,7 @@ def create_folder_node(folder_name: str, path_iterator: pathlib.Path) -> TestNod
"path": path_iterator,
"type_": "folder",
"id_": os.fspath(path_iterator),
"children": [],
"children": Children(),
}


Expand Down Expand Up @@ -1092,15 +1100,17 @@ def send_discovery_message(cwd: str, session_node: TestNode) -> None:
}
if ERRORS is not None:
payload["error"] = ERRORS
send_message(payload, cls_encoder=PathEncoder)
send_message(payload, cls_encoder=CustomEncoder)


class PathEncoder(json.JSONEncoder):
class CustomEncoder(json.JSONEncoder):
"""A custom JSON encoder that encodes pathlib.Path objects as strings."""

def default(self, o):
if isinstance(o, pathlib.Path):
return os.fspath(o)
if isinstance(o, Children):
return o.values()
return super().default(o)


Expand Down