diff --git a/src/ghstack/pull.py b/src/ghstack/pull.py index b3312e1..4872a7f 100644 --- a/src/ghstack/pull.py +++ b/src/ghstack/pull.py @@ -102,9 +102,21 @@ async def _find_head_with_tree( commit, commit_tree = line.split() if commit_tree == tree: return commit + + # Some historical or externally rewritten ghstack heads do not keep the + # old head commit in first-parent history. The source id is still enough + # to recover the merge base as long as the tree object is present locally. + if await sh.agit("cat-file", "-e", f"{tree}^{{tree}}", exitcode=True): + return await sh.agit( + "commit-tree", + tree, + input="Synthetic ghstack pull merge base\n\n[ghstack-poisoned]\n", + ) + raise RuntimeError( "Could not find the previously checked out ghstack head commit. " - "The local ghstack-source-id does not appear in the remote head history." + "The local ghstack-source-id does not appear in the remote head history, " + "and the corresponding tree is not available locally." ) @@ -147,18 +159,33 @@ async def _finish_pull(sh: ghstack.shell.Shell, state: Dict[str, Any]) -> None: await _clear_state(sh) -async def main( +async def _local_ghstack_stack( + sh: ghstack.shell.Shell, *, github_url: str +) -> List[str]: + stack = [] + commit = "HEAD" + while True: + commit_msg = await sh.agit("log", "-1", "--format=%B", commit) + if ghstack.diff.PullRequestResolved.search(commit_msg, github_url) is None: + break + commit_hash = await sh.agit("rev-parse", commit) + stack.append(commit_hash) + parents = (await sh.agit("rev-list", "--parents", "-n", "1", commit)).split() + if len(parents) != 2: + break + commit = parents[1] + stack.reverse() + return stack + + +async def _pull_current( github: ghstack.github.GitHubEndpoint, sh: ghstack.shell.Shell, remote_name: str, github_url: str, pull_request: Optional[str] = None, - continue_: bool = False, + parent_override: Optional[str] = None, ) -> None: - if continue_: - await _finish_pull(sh, await _read_state(sh)) - return - params = await _resolve_params( pull_request=pull_request, github_url=github_url, @@ -216,7 +243,15 @@ async def main( returncode, merge_tree_output = await _run_git_for_status( sh, - ["merge-tree", "--write-tree", "--messages", remote_head, local_imputed_head], + [ + "merge-tree", + "--write-tree", + "--messages", + "--merge-base", + old_head, + remote_head, + local_imputed_head, + ], ) merged_tree = merge_tree_output.splitlines()[0] if returncode == 0 else None @@ -229,7 +264,11 @@ async def main( if m_remote_source_id is not None else await sh.agit("rev-parse", f"{remote_orig}^{{tree}}") ) - remote_orig_parent = await sh.agit("rev-parse", f"{remote_orig}^") + remote_orig_parent = ( + parent_override + if parent_override is not None + else await sh.agit("rev-parse", f"{remote_orig}^") + ) author_name = await sh.agit("log", "-1", "--format=%an", "HEAD") author_email = await sh.agit("log", "-1", "--format=%ae", "HEAD") @@ -268,3 +307,64 @@ async def main( }, ) await sh.agit("checkout", pulled_orig) + + +async def _pull_stack( + github: ghstack.github.GitHubEndpoint, + sh: ghstack.shell.Shell, + remote_name: str, + github_url: str, + stack: List[str], +) -> None: + current_head: Optional[str] = None + for i, commit in enumerate(stack): + if i == 0: + await sh.agit("checkout", commit) + parent_override = None + else: + assert current_head is not None + await sh.agit("checkout", current_head) + await sh.agit("cherry-pick", commit) + parent_override = await sh.agit("rev-parse", "HEAD^") + + await _pull_current( + github=github, + sh=sh, + remote_name=remote_name, + github_url=github_url, + parent_override=parent_override, + ) + current_head = await sh.agit("rev-parse", "HEAD") + + +async def main( + github: ghstack.github.GitHubEndpoint, + sh: ghstack.shell.Shell, + remote_name: str, + github_url: str, + pull_request: Optional[str] = None, + continue_: bool = False, +) -> None: + if continue_: + await _finish_pull(sh, await _read_state(sh)) + return + + if pull_request is None: + stack = await _local_ghstack_stack(sh, github_url=github_url) + if len(stack) > 1: + await _pull_stack( + github=github, + sh=sh, + remote_name=remote_name, + github_url=github_url, + stack=stack, + ) + return + + await _pull_current( + github=github, + sh=sh, + remote_name=remote_name, + github_url=github_url, + pull_request=pull_request, + ) diff --git a/test/pull/rewritten_head_history.py.test b/test/pull/rewritten_head_history.py.test new file mode 100644 index 0000000..3a18ca1 --- /dev/null +++ b/test/pull/rewritten_head_history.py.test @@ -0,0 +1,38 @@ +from ghstack.test_prelude import * + +await init_test() + +await commit("A") +(A,) = await gh_submit("Initial") +old_orig = A.orig + +await write_file_and_add("remote.txt", "remote change") +await git("commit", "--amend", "--no-edit") +await gh_submit("Remote update") + +remote_tree = await git("rev-parse", "origin/gh/ezyang/1/head^{tree}") +rewritten_head = await get_upstream_sh().agit( + "commit-tree", + "-p", + "gh/ezyang/1/base", + remote_tree, + input="Rewritten remote head\n\n[ghstack-poisoned]\n", +) +await get_upstream_sh().agit( + "update-ref", + "refs/heads/gh/ezyang/1/head", + rewritten_head, +) + +await checkout(old_orig) +await write_file_and_add("local.txt", "local change") +await git("commit", "--amend", "--no-edit") + +await gh_pull() + +assert_eq(await git("show", "HEAD:remote.txt"), "remote change") +assert_eq(await git("show", "HEAD:local.txt"), "local change") + +await gh_submit("Local update") + +ok() diff --git a/test/pull/stack_after_top_only_pull.py.test b/test/pull/stack_after_top_only_pull.py.test new file mode 100644 index 0000000..bdea3f2 --- /dev/null +++ b/test/pull/stack_after_top_only_pull.py.test @@ -0,0 +1,51 @@ +from ghstack.test_prelude import * + +await init_test() + +await commit("A") +await commit("B") +A, B = await gh_submit("Initial") +old_top = B.orig + +await checkout(A) +await write_file_and_add("A.txt", "remote A") +await git("commit", "--amend", "--no-edit") +await gh_submit("Remote bottom update") + +await checkout(old_top) +await write_file_and_add("B.txt", "local B") +await git("commit", "--amend", "--no-edit") + +# Simulate the old top-only behavior: the top PR was pulled/recreated, but +# the lower PR in the local stack remained stale. +await gh_pull(f"https://github.com/pytorch/pytorch/pull/{B.number}") +assert_eq( + await git("rev-parse", "HEAD^^{tree}"), + await git("rev-parse", A.orig + "^{tree}"), +) + +await gh_pull() + +assert_eq(await git("show", "HEAD:A.txt"), "remote A") +assert_eq(await git("show", "HEAD:B.txt"), "local B") +assert_eq( + await git("rev-parse", "HEAD^^{tree}"), + await git("rev-parse", "origin/gh/ezyang/1/orig^{tree}"), +) +assert ( + "ghstack-source-id: " + + await git("rev-parse", "origin/gh/ezyang/1/orig^{tree}") + in await git("log", "-1", "--format=%B", "HEAD^") +) +assert ( + "ghstack-source-id: " + + await git("rev-parse", "origin/gh/ezyang/2/orig^{tree}") + in await git("log", "-1", "--format=%B", "HEAD") +) +pulled_tree = await git("rev-parse", "HEAD^{tree}") +await gh_pull() +assert_eq(await git("rev-parse", "HEAD^{tree}"), pulled_tree) + +await gh_submit("Local update") + +ok() diff --git a/test/pull/stack_parent_update.py.test b/test/pull/stack_parent_update.py.test new file mode 100644 index 0000000..8320ebf --- /dev/null +++ b/test/pull/stack_parent_update.py.test @@ -0,0 +1,44 @@ +from ghstack.test_prelude import * + +await init_test() + +await commit("A") +await commit("B") +A, B = await gh_submit("Initial") +old_top = B.orig + +await checkout(A) +await write_file_and_add("A.txt", "remote A") +await git("commit", "--amend", "--no-edit") +await gh_submit("Remote bottom update") + +await checkout(old_top) +await write_file_and_add("B.txt", "local B") +await git("commit", "--amend", "--no-edit") + +await gh_pull() + +assert_eq(await git("show", "HEAD:A.txt"), "remote A") +assert_eq(await git("show", "HEAD:B.txt"), "local B") +assert_eq( + await git("rev-parse", "HEAD^^"), + await git("rev-parse", "origin/main"), +) +assert_eq( + await git("rev-parse", "HEAD^^{tree}"), + await git("rev-parse", "origin/gh/ezyang/1/orig^{tree}"), +) +assert ( + "ghstack-source-id: " + + await git("rev-parse", "origin/gh/ezyang/1/orig^{tree}") + in await git("log", "-1", "--format=%B", "HEAD^") +) +assert ( + "ghstack-source-id: " + + await git("rev-parse", "origin/gh/ezyang/2/orig^{tree}") + in await git("log", "-1", "--format=%B", "HEAD") +) + +await gh_submit("Local update") + +ok()