Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 24 additions & 3 deletions deepmd/pd/train/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,9 @@
from deepmd.utils.data import (
DataRequirementItem,
)
from deepmd.utils.finetune import (
warn_configuration_mismatch_during_finetune,
)
from deepmd.utils.path import (
DPH5Path,
)
Expand Down Expand Up @@ -520,9 +523,8 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR:
new_state_dict = {}
target_state_dict = self.wrapper.state_dict()
# pretrained_model
pretrained_model = get_model_for_wrapper(
state_dict["_extra_state"]["model_params"]
)
pretrained_model_params = state_dict["_extra_state"]["model_params"]
pretrained_model = get_model_for_wrapper(pretrained_model_params)
pretrained_model_wrapper = ModelWrapper(pretrained_model)
pretrained_model_wrapper.set_state_dict(state_dict)
# update type related params
Expand Down Expand Up @@ -557,6 +559,25 @@ def collect_single_finetune_params(
) -> None:
_new_fitting = _finetune_rule_single.get_random_fitting()
_model_key_from = _finetune_rule_single.get_model_branch()
_input_model_params = (
model_params["model_dict"][_model_key]
if self.multi_task
else model_params
)
_pretrained_model_params = (
pretrained_model_params["model_dict"][_model_key_from]
if "model_dict" in pretrained_model_params
else pretrained_model_params
)
if (
"descriptor" in _input_model_params
and "descriptor" in _pretrained_model_params
):
warn_configuration_mismatch_during_finetune(
_input_model_params["descriptor"],
_pretrained_model_params["descriptor"],
_model_key_from,
)
target_keys = [
i
for i in _random_state_dict.keys()
Expand Down
7 changes: 7 additions & 0 deletions deepmd/pd/utils/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from deepmd.utils.finetune import (
FinetuneRuleItem,
warn_descriptor_config_differences,
)

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -57,6 +58,12 @@ def get_finetune_rule_single(
random_fitting=new_fitting,
)
if change_model_params:
if "descriptor" in single_config and "descriptor" in single_config_chosen:
warn_descriptor_config_differences(
single_config["descriptor"],
single_config_chosen["descriptor"],
model_branch_chosen,
)
trainable_param = {
"descriptor": single_config.get("descriptor", {}).get("trainable", True),
"fitting_net": single_config.get("fitting_net", {}).get("trainable", True),
Expand Down
27 changes: 24 additions & 3 deletions deepmd/pt/train/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,9 @@
from deepmd.utils.data import (
DataRequirementItem,
)
from deepmd.utils.finetune import (
warn_configuration_mismatch_during_finetune,
)

if torch.__version__.startswith("2"):
import torch._dynamo
Expand Down Expand Up @@ -787,9 +790,8 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR:
new_state_dict = {}
target_state_dict = self.wrapper.state_dict()
# pretrained_model
pretrained_model = get_model_for_wrapper(
state_dict["_extra_state"]["model_params"]
)
pretrained_model_params = state_dict["_extra_state"]["model_params"]
pretrained_model = get_model_for_wrapper(pretrained_model_params)
pretrained_model_wrapper = ModelWrapper(pretrained_model)
pretrained_model_wrapper.load_state_dict(state_dict)
# update type related params
Expand Down Expand Up @@ -824,6 +826,25 @@ def collect_single_finetune_params(
) -> None:
_new_fitting = _finetune_rule_single.get_random_fitting()
_model_key_from = _finetune_rule_single.get_model_branch()
_input_model_params = (
model_params["model_dict"][_model_key]
if self.multi_task
else model_params
)
_pretrained_model_params = (
pretrained_model_params["model_dict"][_model_key_from]
if "model_dict" in pretrained_model_params
else pretrained_model_params
)
if (
"descriptor" in _input_model_params
and "descriptor" in _pretrained_model_params
):
warn_configuration_mismatch_during_finetune(
_input_model_params["descriptor"],
_pretrained_model_params["descriptor"],
_model_key_from,
)
target_keys = [
i
for i in _random_state_dict.keys()
Expand Down
7 changes: 7 additions & 0 deletions deepmd/pt/utils/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
)
from deepmd.utils.finetune import (
FinetuneRuleItem,
warn_descriptor_config_differences,
)
from deepmd.utils.model_branch_dict import (
get_model_dict,
Expand Down Expand Up @@ -69,6 +70,12 @@ def get_finetune_rule_single(
random_fitting=new_fitting,
)
if change_model_params:
if "descriptor" in single_config and "descriptor" in single_config_chosen:
warn_descriptor_config_differences(
single_config["descriptor"],
single_config_chosen["descriptor"],
model_branch_chosen,
)
trainable_param = {
"descriptor": single_config.get("descriptor", {}).get("trainable", True),
"fitting_net": single_config.get("fitting_net", {}).get("trainable", True),
Expand Down
26 changes: 26 additions & 0 deletions deepmd/pt_expt/train/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@
from deepmd.utils.data_system import (
DeepmdDataSystem,
)
from deepmd.utils.finetune import (
warn_configuration_mismatch_during_finetune,
)
from deepmd.utils.path import (
DPPath,
)
Expand Down Expand Up @@ -1121,6 +1124,29 @@ def _make_sample(
model_with_new_type_stat=model_with_new_type_stat,
)

for model_key in self.model_keys:
finetune_rule = finetune_links[model_key]
_model_key_from = finetune_rule.get_model_branch()
input_model_params = (
model_params["model_dict"][model_key]
if self.multi_task
else model_params
)
branch_pretrained_model_params = (
pretrained_model_params["model_dict"][_model_key_from]
if "model_dict" in pretrained_model_params
else pretrained_model_params
)
if (
"descriptor" in input_model_params
and "descriptor" in branch_pretrained_model_params
):
warn_configuration_mismatch_during_finetune(
input_model_params["descriptor"],
branch_pretrained_model_params["descriptor"],
_model_key_from,
)

# Selective weight copy (per-branch key remapping)
pretrained_state = pretrained_wrapper.state_dict()
target_state = self._unwrapped.state_dict()
Expand Down
189 changes: 189 additions & 0 deletions deepmd/utils/finetune.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,197 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import logging
from collections.abc import (
Mapping,
)
from copy import (
deepcopy,
)
from typing import (
Any,
)

log = logging.getLogger(__name__)

_IGNORED_DESCRIPTOR_KEYS = frozenset({"trainable"})
_MISSING = object()
_MAX_DESCRIPTOR_CONFIG_DIFFS = 20
_MAX_CONFIG_VALUE_LENGTH = 200


def _infer_synthetic_type_count(descriptor: Mapping[str, Any]) -> int:
"""Infer a safe type count for descriptor-only normalization.

The real model ``type_map`` is not available at every finetune warning call
site. Use descriptor fields that explicitly encode per-type lists to avoid
normalizing a 3+-type descriptor against the historical two-type stub. This
is still a best-effort normalization helper: intentional type-map changes
may still show up in type-count-dependent fields such as ``sel``.
"""
type_count = 2
for key in ("sel", "sel_a", "sel_r"):
value = descriptor.get(key)
if isinstance(value, list) and all(isinstance(item, int) for item in value):
type_count = max(type_count, len(value))
exclude_types = descriptor.get("exclude_types")
if isinstance(exclude_types, list):
for pair in exclude_types:
if (
isinstance(pair, list)
and len(pair) == 2
and all(isinstance(item, int) for item in pair)
):
type_count = max(type_count, pair[0] + 1, pair[1] + 1)
return type_count


def _normalize_descriptor_for_compare(
    descriptor: Mapping[str, Any],
) -> Mapping[str, Any]:
    """Return the descriptor config as normalized inside a synthetic input.

    The descriptor is embedded in a minimal placeholder configuration and
    passed through ``normalize`` so that implicit defaults do not show up as
    differences when two descriptors are compared.

    Parameters
    ----------
    descriptor
        Descriptor section to normalize. It is deep-copied first, so the
        caller's mapping is not handed to ``normalize``.

    Returns
    -------
    Mapping[str, Any]
        The ``["model"]["descriptor"]`` entry of the normalized configuration.
    """
    # Imported lazily, inside the function, as in the original code.
    from deepmd.utils.argcheck import (
        normalize,
    )

    descriptor_copy = deepcopy(dict(descriptor))
    n_types = _infer_synthetic_type_count(descriptor)

    # Placeholder model/training sections; only the descriptor is of interest.
    synthetic_model = {
        "descriptor": descriptor_copy,
        "fitting_net": {"neuron": [240, 240, 240]},
        "type_map": [f"Type{index}" for index in range(n_types)],
    }
    synthetic_training = {
        "training_data": {"systems": ["fake"]},
        "numb_steps": 100,
    }
    normalized = normalize(
        {"model": synthetic_model, "training": synthetic_training},
        multi_task=False,
    )
    return normalized["model"]["descriptor"]


def _format_config_value(value: Any) -> str:
    """Render ``value`` with ``repr``, shortening overly long output.

    Text longer than ``_MAX_CONFIG_VALUE_LENGTH`` characters is cut so that,
    with a trailing ``...``, it is exactly that long.
    """
    rendered = repr(value)
    if len(rendered) <= _MAX_CONFIG_VALUE_LENGTH:
        return rendered
    # Reserve three characters for the ellipsis marker.
    return rendered[: _MAX_CONFIG_VALUE_LENGTH - 3] + "..."


def _iter_descriptor_config_differences(
    input_config: Any,
    pretrained_config: Any,
    prefix: str = "",
) -> list[tuple[str, Any, Any]]:
    """Recursively collect the differences between two config values.

    Parameters
    ----------
    input_config
        Value taken from the input configuration.
    pretrained_config
        Value taken from the pretrained model configuration.
    prefix
        Dotted key path of the values being compared; empty at the top level.

    Returns
    -------
    list[tuple[str, Any, Any]]
        ``(key_path, input_value, pretrained_value)`` entries in sorted key
        order. A side that lacks the key is reported as ``_MISSING``. Keys
        listed in ``_IGNORED_DESCRIPTOR_KEYS`` are skipped at every level.
    """
    both_are_mappings = isinstance(input_config, Mapping) and isinstance(
        pretrained_config, Mapping
    )
    if not both_are_mappings:
        # Leaf comparison: anything that is not a pair of mappings is compared
        # as a whole value.
        if input_config != pretrained_config:
            return [(prefix, input_config, pretrained_config)]
        return []

    found: list[tuple[str, Any, Any]] = []
    for key in sorted(set(input_config) | set(pretrained_config)):
        if key in _IGNORED_DESCRIPTOR_KEYS:
            continue
        key_path = f"{prefix}.{key}" if prefix else str(key)
        in_input = key in input_config
        in_pretrained = key in pretrained_config
        if in_input and in_pretrained:
            # Present on both sides: descend and gather nested differences.
            found.extend(
                _iter_descriptor_config_differences(
                    input_config[key], pretrained_config[key], key_path
                )
            )
        elif in_input:
            found.append((key_path, input_config[key], _MISSING))
        else:
            found.append((key_path, _MISSING, pretrained_config[key]))
    return found


def _descriptor_config_differences(
    input_descriptor: Mapping[str, Any],
    pretrained_descriptor: Mapping[str, Any],
) -> list[tuple[str, Any, Any]]:
    """Return meaningful descriptor differences, ignoring implicit defaults.

    Both descriptors are normalized before comparison so that values left at
    their defaults do not count as differences. If normalizing either side
    fails, both raw descriptors are compared instead and the failure is
    recorded at DEBUG level.

    Parameters
    ----------
    input_descriptor
        Descriptor section from the input configuration.
    pretrained_descriptor
        Descriptor section stored with the pretrained model.

    Returns
    -------
    list[tuple[str, Any, Any]]
        ``(key_path, input_value, pretrained_value)`` entries as produced by
        ``_iter_descriptor_config_differences``.
    """
    input_descriptor_cmp: Mapping[str, Any]
    pretrained_descriptor_cmp: Mapping[str, Any]
    try:
        input_descriptor_cmp = _normalize_descriptor_for_compare(input_descriptor)
        pretrained_descriptor_cmp = _normalize_descriptor_for_compare(
            pretrained_descriptor
        )
    except Exception:
        # Some in-flight or legacy descriptor schemas may not be normalizable with
        # the minimal synthetic config. If either side fails, compare raw
        # descriptor against raw descriptor; mixing normalized and raw values would
        # report implicit defaults as spurious differences.
        # This is a deliberate best-effort path, so the error is not re-raised;
        # it is only logged so the fallback can be diagnosed.
        log.debug(
            "Could not normalize descriptor configuration for comparison; "
            "falling back to comparing the raw descriptor configurations.",
            exc_info=True,
        )
        input_descriptor_cmp = input_descriptor
        pretrained_descriptor_cmp = pretrained_descriptor
    return _iter_descriptor_config_differences(
        input_descriptor_cmp, pretrained_descriptor_cmp
    )


def _format_descriptor_differences(
    differences: list[tuple[str, Any, Any]],
    *,
    overwrite: bool,
) -> str:
    """Format descriptor differences as one text line per entry.

    Parameters
    ----------
    differences
        ``(key_path, input_value, pretrained_value)`` entries to render.
    overwrite
        If True, render each entry as ``input -> pretrained``; otherwise
        render it as ``input=..., pretrained=...``.

    Returns
    -------
    str
        Newline-joined lines. At most ``_MAX_DESCRIPTOR_CONFIG_DIFFS`` entries
        are shown, followed by a summary line when more exist.
    """

    def _render(value: Any) -> str:
        # The sentinel marks a key that is absent on that side.
        if value is _MISSING:
            return "(missing)"
        return _format_config_value(value)

    rendered = []
    for path, input_value, pretrained_value in differences[
        :_MAX_DESCRIPTOR_CONFIG_DIFFS
    ]:
        input_text = _render(input_value)
        pretrained_text = _render(pretrained_value)
        rendered.append(
            f" {path}: {input_text} -> {pretrained_text}"
            if overwrite
            else f" {path}: input={input_text}, pretrained={pretrained_text}"
        )

    # One line was appended per shown entry, so this counts the hidden ones.
    hidden = len(differences) - len(rendered)
    if hidden > 0:
        rendered.append(f" ... and {hidden} more difference(s)")
    return "\n".join(rendered)


def warn_descriptor_config_differences(
    input_descriptor: Mapping[str, Any],
    pretrained_descriptor: Mapping[str, Any],
    model_branch: str = "Default",
) -> None:
    """Warn when ``--use-pretrain-script`` overwrites descriptor config.

    Nothing is logged when the two descriptors have no meaningful
    differences.

    Parameters
    ----------
    input_descriptor
        Descriptor section from the input configuration.
    pretrained_descriptor
        Descriptor section stored with the pretrained model.
    model_branch
        Name of the pretrained model branch, quoted in the warning.
    """
    differences = _descriptor_config_differences(
        input_descriptor, pretrained_descriptor
    )
    if differences:
        header = (
            "Descriptor configuration in input.json differs from pretrained model "
            f"(branch '{model_branch}'). The input descriptor configuration will be "
            "overwritten with the pretrained model's descriptor configuration "
            "except for the trainable flag:\n"
        )
        details = _format_descriptor_differences(differences, overwrite=True)
        log.warning(header + details)


def warn_configuration_mismatch_during_finetune(
    input_descriptor: Mapping[str, Any],
    pretrained_descriptor: Mapping[str, Any],
    model_branch: str = "Default",
) -> None:
    """Warn when fine-tuning loads only compatible descriptor parameters.

    Nothing is logged when the two descriptors have no meaningful
    differences.

    Parameters
    ----------
    input_descriptor
        Descriptor section from the input configuration.
    pretrained_descriptor
        Descriptor section stored with the pretrained model.
    model_branch
        Name of the pretrained model branch, quoted in the warning.
    """
    differences = _descriptor_config_differences(
        input_descriptor, pretrained_descriptor
    )
    if differences:
        header = (
            "Descriptor configuration mismatch detected between input.json and "
            f"pretrained model (branch '{model_branch}'). Only descriptor parameters "
            "that are compatible with the pretrained model can be reused; "
            "incompatible parameters may be reinitialized, skipped, or rejected by "
            "backend-specific loading:\n"
        )
        details = _format_descriptor_differences(differences, overwrite=False)
        log.warning(header + details)


class FinetuneRuleItem:
def __init__(
Expand Down
Loading
Loading