diff --git a/sorts/tree_sort.py b/sorts/tree_sort.py index 056864957d4d..865fa191337e 100644 --- a/sorts/tree_sort.py +++ b/sorts/tree_sort.py @@ -1,12 +1,11 @@ """ Tree_sort algorithm. - Build a Binary Search Tree and then iterate thru it to get a sorted list. """ from __future__ import annotations -from collections.abc import Iterator +from collections.abc import Iterable, Iterator from dataclasses import dataclass @@ -39,7 +38,7 @@ def insert(self, val: int) -> None: self.right.insert(val) -def tree_sort(arr: list[int]) -> tuple[int, ...]: +def tree_sort(arr: Iterable[int]) -> tuple[int, ...]: """ >>> tree_sort([]) () @@ -53,14 +52,17 @@ def tree_sort(arr: list[int]) -> tuple[int, ...]: (-4, 2, 5, 7, 9) >>> tree_sort([5, 6, 1, -1, 4, 37, 2, 7]) (-1, 1, 2, 4, 5, 6, 7, 37) - - # >>> tree_sort(range(10, -10, -1)) == tuple(sorted(range(10, -10, -1))) - # True + >>> tree_sort(range(10, -10, -1)) == tuple(sorted(range(10, -10, -1))) + True """ - if len(arr) == 0: - return tuple(arr) - root = Node(arr[0]) - for item in arr[1:]: + iterator = iter(arr) + try: + first = next(iterator) + except StopIteration: + return () + + root = Node(first) + for item in iterator: root.insert(item) return tuple(root)