diff --git a/deepmd/pd/train/training.py b/deepmd/pd/train/training.py index 7d91218468..4daf5c8608 100644 --- a/deepmd/pd/train/training.py +++ b/deepmd/pd/train/training.py @@ -81,6 +81,9 @@ from deepmd.utils.data import ( DataRequirementItem, ) +from deepmd.utils.finetune import ( + warn_configuration_mismatch_during_finetune, +) from deepmd.utils.path import ( DPH5Path, ) @@ -520,9 +523,8 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR: new_state_dict = {} target_state_dict = self.wrapper.state_dict() # pretrained_model - pretrained_model = get_model_for_wrapper( - state_dict["_extra_state"]["model_params"] - ) + pretrained_model_params = state_dict["_extra_state"]["model_params"] + pretrained_model = get_model_for_wrapper(pretrained_model_params) pretrained_model_wrapper = ModelWrapper(pretrained_model) pretrained_model_wrapper.set_state_dict(state_dict) # update type related params @@ -557,6 +559,25 @@ def collect_single_finetune_params( ) -> None: _new_fitting = _finetune_rule_single.get_random_fitting() _model_key_from = _finetune_rule_single.get_model_branch() + _input_model_params = ( + model_params["model_dict"][_model_key] + if self.multi_task + else model_params + ) + _pretrained_model_params = ( + pretrained_model_params["model_dict"][_model_key_from] + if "model_dict" in pretrained_model_params + else pretrained_model_params + ) + if ( + "descriptor" in _input_model_params + and "descriptor" in _pretrained_model_params + ): + warn_configuration_mismatch_during_finetune( + _input_model_params["descriptor"], + _pretrained_model_params["descriptor"], + _model_key_from, + ) target_keys = [ i for i in _random_state_dict.keys() diff --git a/deepmd/pd/utils/finetune.py b/deepmd/pd/utils/finetune.py index 7b3bdf615b..510a91c392 100644 --- a/deepmd/pd/utils/finetune.py +++ b/deepmd/pd/utils/finetune.py @@ -8,6 +8,7 @@ from deepmd.utils.finetune import ( FinetuneRuleItem, + warn_descriptor_config_differences, ) log = logging.getLogger(__name__) @@ -57,6 +58,12 @@ def get_finetune_rule_single( random_fitting=new_fitting, ) if change_model_params: + if "descriptor" in single_config and "descriptor" in single_config_chosen: + warn_descriptor_config_differences( + single_config["descriptor"], + single_config_chosen["descriptor"], + model_branch_chosen, + ) trainable_param = { "descriptor": single_config.get("descriptor", {}).get("trainable", True), "fitting_net": single_config.get("fitting_net", {}).get("trainable", True), diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index 048ec8b1c6..d5d3965324 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -115,6 +115,9 @@ from deepmd.utils.data import ( DataRequirementItem, ) +from deepmd.utils.finetune import ( + warn_configuration_mismatch_during_finetune, +) if torch.__version__.startswith("2"): import torch._dynamo @@ -787,9 +790,8 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR: new_state_dict = {} target_state_dict = self.wrapper.state_dict() # pretrained_model - pretrained_model = get_model_for_wrapper( - state_dict["_extra_state"]["model_params"] - ) + pretrained_model_params = state_dict["_extra_state"]["model_params"] + pretrained_model = get_model_for_wrapper(pretrained_model_params) pretrained_model_wrapper = ModelWrapper(pretrained_model) pretrained_model_wrapper.load_state_dict(state_dict) # update type related params @@ -824,6 +826,25 @@ def collect_single_finetune_params( ) -> None: _new_fitting = _finetune_rule_single.get_random_fitting() _model_key_from = _finetune_rule_single.get_model_branch() + _input_model_params = ( + model_params["model_dict"][_model_key] + if self.multi_task + else model_params + ) + _pretrained_model_params = ( + pretrained_model_params["model_dict"][_model_key_from] + if "model_dict" in pretrained_model_params + else pretrained_model_params + ) + if ( + "descriptor" in _input_model_params + and "descriptor" in _pretrained_model_params + ): + warn_configuration_mismatch_during_finetune( + _input_model_params["descriptor"], + _pretrained_model_params["descriptor"], + _model_key_from, + ) target_keys = [ i for i in _random_state_dict.keys() diff --git a/deepmd/pt/utils/finetune.py b/deepmd/pt/utils/finetune.py index 0e86c9aa6c..c4db694578 100644 --- a/deepmd/pt/utils/finetune.py +++ b/deepmd/pt/utils/finetune.py @@ -14,6 +14,7 @@ ) from deepmd.utils.finetune import ( FinetuneRuleItem, + warn_descriptor_config_differences, ) from deepmd.utils.model_branch_dict import ( get_model_dict, @@ -69,6 +70,12 @@ def get_finetune_rule_single( random_fitting=new_fitting, ) if change_model_params: + if "descriptor" in single_config and "descriptor" in single_config_chosen: + warn_descriptor_config_differences( + single_config["descriptor"], + single_config_chosen["descriptor"], + model_branch_chosen, + ) trainable_param = { "descriptor": single_config.get("descriptor", {}).get("trainable", True), "fitting_net": single_config.get("fitting_net", {}).get("trainable", True), diff --git a/deepmd/pt_expt/train/training.py b/deepmd/pt_expt/train/training.py index 202c5d10de..bd6fdb02a3 100644 --- a/deepmd/pt_expt/train/training.py +++ b/deepmd/pt_expt/train/training.py @@ -70,6 +70,9 @@ from deepmd.utils.data_system import ( DeepmdDataSystem, ) +from deepmd.utils.finetune import ( + warn_configuration_mismatch_during_finetune, +) from deepmd.utils.path import ( DPPath, ) @@ -1121,6 +1124,29 @@ def _make_sample( model_with_new_type_stat=model_with_new_type_stat, ) + for model_key in self.model_keys: + finetune_rule = finetune_links[model_key] + _model_key_from = finetune_rule.get_model_branch() + input_model_params = ( + model_params["model_dict"][model_key] + if self.multi_task + else model_params + ) + branch_pretrained_model_params = ( + pretrained_model_params["model_dict"][_model_key_from] + if "model_dict" in pretrained_model_params + else pretrained_model_params + ) + if ( + "descriptor" in input_model_params + and "descriptor" in branch_pretrained_model_params + ): + warn_configuration_mismatch_during_finetune( + input_model_params["descriptor"], + branch_pretrained_model_params["descriptor"], + _model_key_from, + ) + # Selective weight copy (per-branch key remapping) pretrained_state = pretrained_wrapper.state_dict() target_state = self._unwrapped.state_dict() diff --git a/deepmd/utils/finetune.py b/deepmd/utils/finetune.py index c019cc68ab..9bcb85dc82 100644 --- a/deepmd/utils/finetune.py +++ b/deepmd/utils/finetune.py @@ -1,8 +1,197 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import logging +from collections.abc import ( + Mapping, +) +from copy import ( + deepcopy, +) +from typing import ( + Any, +) log = logging.getLogger(__name__) +_IGNORED_DESCRIPTOR_KEYS = frozenset({"trainable"}) +_MISSING = object() +_MAX_DESCRIPTOR_CONFIG_DIFFS = 20 +_MAX_CONFIG_VALUE_LENGTH = 200 + + +def _infer_synthetic_type_count(descriptor: Mapping[str, Any]) -> int: + """Infer a safe type count for descriptor-only normalization. + + The real model ``type_map`` is not available at every finetune warning call + site. Use descriptor fields that explicitly encode per-type lists to avoid + normalizing a 3+-type descriptor against the historical two-type stub. This + is still a best-effort normalization helper: intentional type-map changes + may still show up in type-count-dependent fields such as ``sel``. + """ + type_count = 2 + for key in ("sel", "sel_a", "sel_r"): + value = descriptor.get(key) + if isinstance(value, list) and all(isinstance(item, int) for item in value): + type_count = max(type_count, len(value)) + exclude_types = descriptor.get("exclude_types") + if isinstance(exclude_types, list): + for pair in exclude_types: + if ( + isinstance(pair, list) + and len(pair) == 2 + and all(isinstance(item, int) for item in pair) + ): + type_count = max(type_count, pair[0] + 1, pair[1] + 1) + return type_count + + +def _normalize_descriptor_for_compare( + descriptor: Mapping[str, Any], +) -> Mapping[str, Any]: + """Normalize a descriptor config so implicit defaults do not warn.""" + from deepmd.utils.argcheck import ( + normalize, + ) + + config = { + "model": { + "descriptor": deepcopy(dict(descriptor)), + "fitting_net": {"neuron": [240, 240, 240]}, + "type_map": [ + f"Type{ii}" for ii in range(_infer_synthetic_type_count(descriptor)) + ], + }, + "training": {"training_data": {"systems": ["fake"]}, "numb_steps": 100}, + } + return normalize(config, multi_task=False)["model"]["descriptor"] + + +def _format_config_value(value: Any) -> str: + text = repr(value) + if len(text) > _MAX_CONFIG_VALUE_LENGTH: + text = text[: _MAX_CONFIG_VALUE_LENGTH - 3] + "..." + return text + + +def _iter_descriptor_config_differences( + input_config: Any, + pretrained_config: Any, + prefix: str = "", +) -> list[tuple[str, Any, Any]]: + differences: list[tuple[str, Any, Any]] = [] + if isinstance(input_config, Mapping) and isinstance(pretrained_config, Mapping): + keys = sorted(set(input_config) | set(pretrained_config)) + for key in keys: + if key in _IGNORED_DESCRIPTOR_KEYS: + continue + key_path = f"{prefix}.{key}" if prefix else str(key) + if key not in input_config: + differences.append((key_path, _MISSING, pretrained_config[key])) + elif key not in pretrained_config: + differences.append((key_path, input_config[key], _MISSING)) + else: + differences.extend( + _iter_descriptor_config_differences( + input_config[key], pretrained_config[key], key_path + ) + ) + return differences + if input_config != pretrained_config: + return [(prefix, input_config, pretrained_config)] + return differences + + +def _descriptor_config_differences( + input_descriptor: Mapping[str, Any], + pretrained_descriptor: Mapping[str, Any], +) -> list[tuple[str, Any, Any]]: + """Return meaningful descriptor differences, ignoring implicit defaults.""" + input_descriptor_cmp: Mapping[str, Any] = input_descriptor + pretrained_descriptor_cmp: Mapping[str, Any] = pretrained_descriptor + try: + input_descriptor_cmp = _normalize_descriptor_for_compare(input_descriptor) + pretrained_descriptor_cmp = _normalize_descriptor_for_compare( + pretrained_descriptor + ) + except Exception: + # Some in-flight or legacy descriptor schemas may not be normalizable with + # the minimal synthetic config above. If either side fails, compare raw + # descriptor against raw descriptor; mixing normalized and raw values would + # report implicit defaults as spurious differences. + input_descriptor_cmp = input_descriptor + pretrained_descriptor_cmp = pretrained_descriptor + return _iter_descriptor_config_differences( + input_descriptor_cmp, pretrained_descriptor_cmp + ) + + +def _format_descriptor_differences( + differences: list[tuple[str, Any, Any]], + *, + overwrite: bool, +) -> str: + lines = [] + shown = differences[:_MAX_DESCRIPTOR_CONFIG_DIFFS] + for key, input_value, pretrained_value in shown: + input_text = ( + "(missing)" + if input_value is _MISSING + else _format_config_value(input_value) + ) + pretrained_text = ( + "(missing)" + if pretrained_value is _MISSING + else _format_config_value(pretrained_value) + ) + if overwrite: + lines.append(f" {key}: {input_text} -> {pretrained_text}") + else: + lines.append(f" {key}: input={input_text}, pretrained={pretrained_text}") + remaining = len(differences) - len(shown) + if remaining > 0: + lines.append(f" ... and {remaining} more difference(s)") + return "\n".join(lines) + + +def warn_descriptor_config_differences( + input_descriptor: Mapping[str, Any], + pretrained_descriptor: Mapping[str, Any], + model_branch: str = "Default", +) -> None: + """Warn when ``--use-pretrain-script`` overwrites descriptor config.""" + differences = _descriptor_config_differences( + input_descriptor, pretrained_descriptor + ) + if not differences: + return + log.warning( + "Descriptor configuration in input.json differs from pretrained model " + f"(branch '{model_branch}'). The input descriptor configuration will be " + "overwritten with the pretrained model's descriptor configuration " + "except for the trainable flag:\n" + + _format_descriptor_differences(differences, overwrite=True) + ) + + +def warn_configuration_mismatch_during_finetune( + input_descriptor: Mapping[str, Any], + pretrained_descriptor: Mapping[str, Any], + model_branch: str = "Default", +) -> None: + """Warn when fine-tuning loads only compatible descriptor parameters.""" + differences = _descriptor_config_differences( + input_descriptor, pretrained_descriptor + ) + if not differences: + return + log.warning( + "Descriptor configuration mismatch detected between input.json and " + f"pretrained model (branch '{model_branch}'). Only descriptor parameters " + "that are compatible with the pretrained model can be reused; " + "incompatible parameters may be reinitialized, skipped, or rejected by " + "backend-specific loading:\n" + + _format_descriptor_differences(differences, overwrite=False) + ) + class FinetuneRuleItem: def __init__( diff --git a/source/tests/common/test_finetune_utils.py b/source/tests/common/test_finetune_utils.py new file mode 100644 index 0000000000..776d7bb131 --- /dev/null +++ b/source/tests/common/test_finetune_utils.py @@ -0,0 +1,99 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import logging + +from deepmd.utils import ( + finetune, +) + + +def test_descriptor_normalization_uses_descriptor_type_count(): + assert finetune._infer_synthetic_type_count({"sel": [16, 24, 32]}) == 3 + assert finetune._infer_synthetic_type_count({"exclude_types": [[0, 3]]}) == 4 + + +def test_descriptor_config_warning_reports_nested_difference(monkeypatch, caplog): + monkeypatch.setattr( + finetune, + "_normalize_descriptor_for_compare", + lambda descriptor: descriptor, + ) + + input_descriptor = { + "type": "dpa3", + "repflow": {"nlayers": 6, "rcut": 6.0}, + "trainable": False, + } + pretrained_descriptor = { + "type": "dpa3", + "repflow": {"nlayers": 16, "rcut": 6.0}, + "trainable": True, + } + + with caplog.at_level(logging.WARNING): + finetune.warn_configuration_mismatch_during_finetune( + input_descriptor, + pretrained_descriptor, + ) + + assert "repflow.nlayers" in caplog.text + assert "input=6, pretrained=16" in caplog.text + assert "trainable" not in caplog.text + + +def test_descriptor_config_warning_skips_default_only_difference(caplog): + with caplog.at_level(logging.WARNING): + finetune.warn_descriptor_config_differences( + {"type": "se_e2_a", "sel": [16, 16], "rcut": 6.0}, + { + "type": "se_e2_a", + "sel": [16, 16], + "rcut": 6.0, + "activation_function": "tanh", + }, + ) + + assert caplog.text == "" + + +def test_descriptor_config_warning_falls_back_to_raw_if_normalization_fails( + monkeypatch, caplog +): + input_descriptor = {"type": "dpa3", "repflow": {"nlayers": 6}} + pretrained_descriptor = {"type": "dpa3", "repflow": {"nlayers": 16}} + + def normalize_one_side_then_fail(descriptor): + if descriptor is pretrained_descriptor: + raise ValueError("legacy schema") + return {**descriptor, "implicit_default": True} + + monkeypatch.setattr( + finetune, + "_normalize_descriptor_for_compare", + normalize_one_side_then_fail, + ) + + with caplog.at_level(logging.WARNING): + finetune.warn_configuration_mismatch_during_finetune( + input_descriptor, + pretrained_descriptor, + ) + + assert "repflow.nlayers" in caplog.text + assert "implicit_default" not in caplog.text + + +def test_descriptor_config_warning_distinguishes_none_from_missing(monkeypatch, caplog): + monkeypatch.setattr( + finetune, + "_normalize_descriptor_for_compare", + lambda descriptor: descriptor, + ) + + with caplog.at_level(logging.WARNING): + finetune.warn_configuration_mismatch_during_finetune( + {"type": "dpa3", "input_none": None}, + {"type": "dpa3", "pretrained_none": None}, + ) + + assert "input_none: input=None, pretrained=(missing)" in caplog.text + assert "pretrained_none: input=(missing), pretrained=None" in caplog.text