diff --git a/README.md b/README.md index 0f11bd35..a866454f 100644 --- a/README.md +++ b/README.md @@ -4,6 +4,8 @@ This is the PyTorch implementation of the [RotatE](https://openreview.net/forum?id=HkgEQnRqYQ) model for knowledge graph embedding (KGE). We provide a toolkit that gives state-of-the-art performance of several popular KGE models. The toolkit is quite efficient, which is able to train a large KGE model within a few hours on a single GPU. +A faster multi-GPU implementation of RotatE and other KGE models is available in [GraphVite](https://github.com/DeepGraphLearning/graphvite). + **Implemented features** Models: diff --git a/codes/dataloader.py b/codes/dataloader.py index 70d43a25..ed3f3492 100644 --- a/codes/dataloader.py +++ b/codes/dataloader.py @@ -59,8 +59,8 @@ def __getitem__(self, idx): negative_sample = np.concatenate(negative_sample_list)[:self.negative_sample_size] - negative_sample = torch.from_numpy(negative_sample) - + negative_sample = torch.LongTensor(negative_sample) + positive_sample = torch.LongTensor(positive_sample) return positive_sample, negative_sample, subsampling_weight, self.mode @@ -181,4 +181,4 @@ def one_shot_iterator(dataloader): ''' while True: for data in dataloader: - yield data \ No newline at end of file + yield data diff --git a/codes/run.py b/codes/run.py index 9cc7d2e9..457c6fdf 100644 --- a/codes/run.py +++ b/codes/run.py @@ -284,7 +284,6 @@ def main(args): logging.info('Start Training...') logging.info('init_step = %d' % init_step) - logging.info('learning_rate = %d' % current_learning_rate) logging.info('batch_size = %d' % args.batch_size) logging.info('negative_adversarial_sampling = %d' % args.negative_adversarial_sampling) logging.info('hidden_dim = %d' % args.hidden_dim) @@ -296,6 +295,8 @@ def main(args): # Set valid dataloader as it would be evaluated during training if args.do_train: + logging.info('learning_rate = %d' % current_learning_rate) + training_logs = [] #Training Loop @@ -357,4 +358,4 @@ def main(args): log_metrics('Test', step, metrics) if __name__ == '__main__': - main(parse_args()) \ No newline at end of file + main(parse_args())