Skip to content

Commit 9090c27

Browse files
Merge pull request #2315 from ds-wook/refactor/migrate-lightgcn-pytorch
refactor: migrate lightgcn pytorch
2 parents f7a10b0 + 77ce705 commit 9090c27

7 files changed

Lines changed: 2152 additions & 1580 deletions

File tree

examples/02_model_collaborative_filtering/lightgcn_deep_dive.ipynb

Lines changed: 1267 additions & 832 deletions
Large diffs are not rendered by default.

examples/06_benchmarks/benchmark_utils.py

Lines changed: 25 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -54,7 +54,6 @@
5454
pass # skip this import if we are not in a Spark environment
5555

5656
try:
57-
from recommenders.models.deeprec.deeprec_utils import prepare_hparams
5857
from recommenders.models.deeprec.models.graphrec.lightgcn import LightGCN
5958
from recommenders.models.deeprec.DataModel.ImplicitCF import ImplicitCF
6059
from recommenders.models.ncf.ncf_singlenode import NCF
@@ -385,10 +384,32 @@ def prepare_training_lightgcn(train, test):
385384

386385

387386
def train_lightgcn(params, data):
388-
hparams = prepare_hparams(**params)
389-
model = LightGCN(hparams, data)
387+
ctor_keys = {"embed_size", "n_layers", "seed"}
388+
fit_keys = {
389+
"epochs",
390+
"learning_rate",
391+
"batch_size",
392+
"decay",
393+
"eval_epoch",
394+
"top_k",
395+
"metrics",
396+
"save_model",
397+
"save_epoch",
398+
}
399+
400+
ctor_kwargs = {k: params[k] for k in ctor_keys if k in params}
401+
fit_kwargs = {k: params[k] for k in fit_keys if k in params}
402+
if "MODEL_DIR" in params:
403+
fit_kwargs["model_dir"] = params["MODEL_DIR"]
404+
405+
model = LightGCN(
406+
n_users=data.n_users,
407+
n_items=data.n_items,
408+
norm_adj=data.get_norm_adj_mat(),
409+
**ctor_kwargs,
410+
)
390411
with Timer() as t:
391-
model.fit()
412+
model.fit(data, **fit_kwargs)
392413
return model, t
393414

394415

examples/07_tutorials/KDD2020-tutorial/step5_run_lightgcn.ipynb

Lines changed: 453 additions & 472 deletions
Large diffs are not rendered by default.

recommenders/models/deeprec/DataModel/ImplicitCF.py

Lines changed: 7 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,9 @@
11
# Copyright (c) Recommenders contributors.
22
# Licensed under the MIT License.
33

4+
import logging
45
import random
6+
57
import numpy as np
68
import pandas as pd
79
import scipy.sparse as sp
@@ -13,6 +15,8 @@
1315
DEFAULT_PREDICTION_COL,
1416
)
1517

18+
logger = logging.getLogger(__name__)
19+
1620

1721
class ImplicitCF(object):
1822
"""Data processing class for GCN models which use implicit feedback.
@@ -156,7 +160,7 @@ def get_norm_adj_mat(self):
156160
if self.adj_dir is None:
157161
raise FileNotFoundError
158162
norm_adj_mat = sp.load_npz(self.adj_dir + "/norm_adj_mat.npz")
159-
print("Already load norm adj matrix.")
163+
logger.info("Already load norm adj matrix.")
160164

161165
except FileNotFoundError:
162166
norm_adj_mat = self.create_norm_adj_mat()
@@ -180,15 +184,15 @@ def create_norm_adj_mat(self):
180184
adj_mat[: self.n_users, self.n_users :] = R
181185
adj_mat[self.n_users :, : self.n_users] = R.T
182186
adj_mat = adj_mat.todok()
183-
print("Already create adjacency matrix.")
187+
logger.info("Already create adjacency matrix.")
184188

185189
rowsum = np.array(adj_mat.sum(1))
186190
d_inv = np.power(rowsum + 1e-9, -0.5).flatten()
187191
d_inv[np.isinf(d_inv)] = 0.0
188192
d_mat_inv = sp.diags(d_inv)
189193
norm_adj_mat = d_mat_inv.dot(adj_mat)
190194
norm_adj_mat = norm_adj_mat.dot(d_mat_inv)
191-
print("Already normalize adjacency matrix.")
195+
logger.info("Already normalize adjacency matrix.")
192196

193197
return norm_adj_mat.tocsr()
194198

0 commit comments

Comments
 (0)