From a14d19556cad7966599ac9c0bc3cf35900796568 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adria=CC=81n=20Gonza=CC=81lez=20Marti=CC=81n?= Date: Sun, 17 Mar 2019 16:18:36 +0000 Subject: [PATCH 01/15] add VanillaVocab interface --- bin/test.py | 6 +- bin/train.py | 6 +- bin/vocab.py | 6 +- src/data.py | 3 + src/evaluation.py | 10 +-- src/vocab.py | 218 ++++++++++++++++++++++++++++++++++++++++++++++ 6 files changed, 234 insertions(+), 15 deletions(-) create mode 100644 src/vocab.py diff --git a/bin/test.py b/bin/test.py index 4a3271d..6f5ebf2 100644 --- a/bin/test.py +++ b/bin/test.py @@ -10,9 +10,9 @@ from torch.utils.data import DataLoader +from src.vocab import VanillaVocab from src.matching_network import MatchingNetwork from src.evaluation import (predict, save_predictions, generate_episode_data) -from src.data import read_vocab, read_data_set from src.datasets import EpisodesSampler, EpisodesDataset from src.utils import extract_model_parameters, get_model_name @@ -66,8 +66,8 @@ def main(args): model, _, N, k = _load_model(args.model) print("Loading dataset...") - vocab = read_vocab(args.vocab) - X_test, y_test = read_data_set(args.test_set, vocab) + vocab = VanillaVocab(args.vocab) + X_test, y_test = vocab.to_tensors(args.test_set) test_set = EpisodesDataset(X_test, y_test, k=k) sampler = EpisodesSampler(test_set, N=N, episodes_multiplier=30) test_loader = DataLoader(test_set, sampler=sampler, batch_size=BATCH_SIZE) diff --git a/bin/train.py b/bin/train.py index 17942ee..8a0f82b 100644 --- a/bin/train.py +++ b/bin/train.py @@ -5,7 +5,7 @@ from torch.utils.data import DataLoader -from src.data import read_vocab, read_data_set +from src.vocab import VanillaVocab from src.datasets import EpisodesSampler, EpisodesDataset from src.matching_network import MatchingNetwork from src.training import train @@ -65,8 +65,8 @@ def _get_loader(data_set, N, episodes_multiplier=1): def main(args): print("Loading dataset...") - vocab = read_vocab(args.vocab) - X_train, y_train = read_data_set(args.training_set, vocab) + vocab = VanillaVocab(args.vocab) + X_train, y_train = vocab.to_tensors(args.training_set) # Split training further into train and valid X_train, X_valid, y_train, y_valid = train_test_split_tensors( diff --git a/bin/vocab.py b/bin/vocab.py index ce4a309..e625130 100644 --- a/bin/vocab.py +++ b/bin/vocab.py @@ -4,7 +4,7 @@ from argparse import ArgumentParser -from src.data import generate_vocab, store_vocab +from src.vocab import VanillaVocab parser = ArgumentParser() parser.add_argument("input", help="Path to the input CSV data set") @@ -13,10 +13,10 @@ def main(args): print("Generating vocab...") - vocab = generate_vocab(args.input) + vocab = VanillaVocab.generate_vocab(args.input) print("Storing vocab...") - store_vocab(vocab, args.output) + vocab.save(args.output) print(f"Stored vocab of size {len(vocab)} at {args.output}") diff --git a/src/data.py b/src/data.py index f1e88f1..b1d1be5 100644 --- a/src/data.py +++ b/src/data.py @@ -9,6 +9,9 @@ from torchtext.data import Field, TabularDataset from torchtext.vocab import Vocab +print( + "[WARNING] Don't use src.data anymore. 
Use the Vocab interfaces instead.") + VOCAB_SIZE = 27443 PADDING_TOKEN_INDEX = 1 diff --git a/src/evaluation.py b/src/evaluation.py index c61f551..0329e49 100644 --- a/src/evaluation.py +++ b/src/evaluation.py @@ -6,8 +6,6 @@ import torch import numpy as np -from .data import reverse_tensor - RESULTS_PATH = os.path.join( os.path.dirname(os.path.dirname(__file__)), "results") @@ -174,11 +172,11 @@ def _episode_to_text(support_set, targets, labels, target_labels, vocab): # First, we need to flatten these... N, k, _ = support_set.shape flat_support_set = support_set.view(N * k, -1) - flat_support_set = reverse_tensor(flat_support_set, vocab) + flat_support_set = vocab.to_text(flat_support_set) support_set = flat_support_set.reshape(N, k) - targets = reverse_tensor(targets, vocab) - labels = reverse_tensor(labels, vocab) - target_labels = reverse_tensor(target_labels, vocab) + targets = vocab.to_text(targets) + labels = vocab.to_text(labels) + target_labels = vocab.to_text(target_labels) return support_set, targets, labels, target_labels diff --git a/src/vocab.py b/src/vocab.py new file mode 100644 index 0000000..9f10053 --- /dev/null +++ b/src/vocab.py @@ -0,0 +1,218 @@ +import json +import numpy as np + +from collections import defaultdict, Counter + +from torchtext.vocab import Vocab +from torchtext.data import Field, TabularDataset + + +class AbstractVocab(object): + """ + Abstract interface for the Vocab classes which allows to map between text + and numbers. + """ + + def __len__(self): + raise NotImplementedError() + + def to_tensors(self, file_path): + raise NotImplementedError() + + def to_text(self, X): + raise NotImplementedError() + + +class VanillaVocab(AbstractVocab): + """ + Allows to map between text and numbers using a simple tokenizer. + """ + + def __init__(self, file_path): + """ + Initialise the vocabulary by reading it from a file path. + + Parameters + --- + file_path : str + Path to the vocab file. + """ + super().__init__() + + self.vocab = self._read_vocab(file_path) + self.padding_token_index = 1 + + def _read_vocab(self, file_path): + """ + Reads a vocab from its previously stored state. + + Inspired by https://github.com/pytorch/text/issues/253#issue-305929871 + + Parameters + --- + file_path : str + Path to the JSON file with the vocab info. + + Returns + --- + vocab : torchtext.Vocab + Vocabulary created. + """ + vocab_state = {} + with open(file_path) as file: + vocab_state = json.load(file) + + vocab = Vocab(Counter()) + vocab.__dict__.update(vocab_state) + vocab.stoi = defaultdict(lambda: 0, vocab.stoi) + return vocab + + def __len__(self): + """ + Returns the size of the vocabulary. + + Returns + --- + int + Number of tokens in the vocabulary. + """ + return len(self.vocab) + + def save(self, file_path): + """ + Stores a vocab in a JSON file. + + Inspired by https://github.com/pytorch/text/issues/253#issue-305929871 + + Parameters + --- + vocab : torchtext.Vocab + Vocabulary with our corpus. + file_path : str + Path to the vocab state to write. + + """ + vocab_state = dict(self.vocab.__dict__, stoi=dict(self.vocab.stoi)) + with open(file_path, 'w') as file: + json.dump(vocab_state, file) + + def to_tensors(self, file_path): + """ + Reads the data set from one of the pre-processed CSVs composed + of columns `label` and `sentence`. + + Parameters + --- + file_path : str + Path to the CSV file. + vocab : torchtext.Vocab + Vocabulary to use. + + Returns + --- + X : torch.Tensor[num_labels x num_examples x sen_length] + Sentences on the dataset grouped by labels. 
+ y : torch.Tensor[num_labels] + Labels for each group of sentences. + """ + sentence = Field( + batch_first=True, sequential=True, tokenize=self._tokenizer) + sentence.vocab = self.vocab + + label = Field(is_target=True) + label.vocab = self.vocab + + data_set = TabularDataset( + path=file_path, + format='csv', + skip_header=True, + fields=[('label', label), ('sentence', sentence)]) + + sentences_tensor = sentence.process(data_set.sentence) + labels_tensor = label.process(data_set.label).squeeze() + + # Infer num_labels and group sentences by label + num_labels = labels_tensor.unique().shape[0] + num_examples = labels_tensor.shape[0] // num_labels + y = labels_tensor[::num_examples] + sen_length = sentences_tensor.shape[-1] + X = sentences_tensor.view(num_labels, num_examples, sen_length) + + return X, y + + @classmethod + def _tokenizer(cls, text): + """ + Simple tokenizer which splits the token by the space + character. The CSVs have already been pre-processed with + spaCy, therefore this should be enough. + + Parameters + --- + text : str + Input text to tokenize. + + Returns + --- + iterator + Iterator over token text. + """ + return text.split(' ') + + def to_text(self, X): + """ + Reverses some numericalised tensor into text. + + Parameters + ---- + X : torch.Tensor[num_elements x sen_length] + Sentences on the tensor. + vocab : torchtext.Vocab + Vocabulary to use. + + Returns + ---- + sentences : np.array[num_elements] + Array of strings. + """ + sentences = [] + for sentence_tensor in X: + if len(sentence_tensor.shape) == 0: + # 0-D tensor + sentences.append(self.vocab.itos[sentence_tensor]) + continue + + sentence = [ + self.vocab.itos[token] for token in sentence_tensor + if token != self.padding_token_index + ] + sentences.append(' '.join(sentence)) + + return np.array(sentences) + + @classmethod + def generate_vocab(cls, file_path): + """ + Generate the vocabulary from one of the pre-processed CSVs composed + of columns `label` and `sentence`. + + Parameters + --- + file_path : str + Path to the CSV file. + + Returns + --- + vocab : torchtext.Vocab + Vocabulary generated from the file. 
+ """ + text = Field(sequential=True, tokenize=cls._tokenizer) + + data_set = TabularDataset( + path=file_path, + format='csv', + skip_header=True, + fields=[('label', text), ('sentence', text)]) + + text.build_vocab(data_set.label, data_set.sentence) + return text.vocab From e3c6c23d353646535fcb89f74953bd3625cfdf13 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adria=CC=81n=20Gonza=CC=81lez=20Marti=CC=81n?= Date: Sun, 17 Mar 2019 16:36:57 +0000 Subject: [PATCH 02/15] start to work on bert's interface --- Pipfile | 2 ++ Pipfile.lock | 50 ++++++++++++++++++++++++++++++++++++++++++- bin/train.py | 14 +++++++++--- src/vocab.py | 60 ++++++++++++++++++++++++++++++++++++++++++++++++---- 4 files changed, 118 insertions(+), 8 deletions(-) diff --git a/Pipfile b/Pipfile index 9860653..5d968eb 100644 --- a/Pipfile +++ b/Pipfile @@ -18,6 +18,8 @@ scikit-learn = "*" torch = "*" pytorch-ignite = "*" torchtext = "*" +pytorch-pretrained-bert = "*" +pandas = "*" [requires] python_version = "3.6" diff --git a/Pipfile.lock b/Pipfile.lock index 6ada896..308cfa6 100644 --- a/Pipfile.lock +++ b/Pipfile.lock @@ -1,7 +1,7 @@ { "_meta": { "hash": { - "sha256": "9c3ebbd53468ba536a91e2972c0f7affb2004c1787485c72b4bda1979d84d074" + "sha256": "5a112ee07d4119f2a059b1d324542eec71ab5941f427813b3e0fd055c1c2f640" }, "pipfile-spec": 6, "requires": { @@ -52,6 +52,20 @@ ], "version": "==3.1.0" }, + "boto3": { + "hashes": [ + "sha256:1b4a86e1167ba7cbb9dbf2a0a0b86447b35a2b901ae5aace75b8196631680957", + "sha256:f5b12367c530dac45782251b672f1e911da5c74285f89850b0f4f5694b8c388c" + ], + "version": "==1.9.115" + }, + "botocore": { + "hashes": [ + "sha256:7c8ec120bc5bcc4076aebd7dac3a679777ff3a3ce3263c64d7342ea7982b578c", + "sha256:f4607f8800f87fd8eacd450699666f92d7fbc48fbb757903ad56825ce08e072a" + ], + "version": "==1.12.115" + }, "certifi": { "hashes": [ "sha256:59b7658e26ca9c7339e00f8f4636cdfe59d34fa37b9b04f6f9e9926b3cece1a5", @@ -127,6 +141,14 @@ ], "version": "==0.2.9" }, + "docutils": { + "hashes": [ + "sha256:02aec4bd92ab067f6ff27a38a38a41173bf01bed8f89157768c1573f53e474a6", + "sha256:51e64ef2ebfb29cae1faa133b3710143496eca21c530f3f71424d77687764274", + "sha256:7a4bd47eaf6596e1295ecb11361139febe29b084a87bf005bf899f9a42edc3c6" + ], + "version": "==0.14" + }, "entrypoints": { "hashes": [ "sha256:589f874b313739ad35be6e0cd7efde2a4e9b6fea91edcc34e58ecbb8dbe56d19", @@ -198,6 +220,13 @@ ], "version": "==2.10" }, + "jmespath": { + "hashes": [ + "sha256:3720a4b1bd659dd2eecad0666459b9788813e032b83e7ba58578e48254e0a0e6", + "sha256:bde2aef6f44302dfb30320115b17d030798de8c4110e28d5cf6cf91a7a31074c" + ], + "version": "==0.9.4" + }, "jsonschema": { "hashes": [ "sha256:0c0a81564f181de3212efa2d17de1910f8732fa1b71c42266d983cd74304e20d", @@ -453,6 +482,7 @@ "sha256:cc8fc0c7a8d5951dc738f1c1447f71c43734244453616f32b8aa0ef6013a5dfb", "sha256:d7b460bc316064540ce0c41c1438c416a40746fd8a4fb2999668bf18f3c4acf1" ], + "index": "pypi", "version": "==0.24.2" }, "pandocfilters": { @@ -587,6 +617,7 @@ "sha256:7e6584c74aeed623791615e26efd690f29817a27c73085b78e4bad02493df2fb", "sha256:c89805f6f4d64db21ed966fda138f8a5ed7a4fdbc1a8ee329ce1b74e3c74da9e" ], + "markers": "python_version >= '2.7'", "version": "==2.8.0" }, "pytorch-ignite": { @@ -597,6 +628,15 @@ "index": "pypi", "version": "==0.1.2" }, + "pytorch-pretrained-bert": { + "hashes": [ + "sha256:138c9702cc8da0c949a3b266a0c6e436aee4ae1c722b5d3eb1e47fb4b2b0f197", + "sha256:9ec5998f501381d86d6e0b4c4d92c1c2888f3f093e3a13177b3b94494b1bf7d7", + "sha256:f30ae5d19a95b64bd7068170640608cc457488948aa7643855aa261c2c8ab8b7" + ], 
+ "index": "pypi", + "version": "==0.6.1" + }, "pytz": { "hashes": [ "sha256:32b0891edff07e28efe91284ed9c31e123d84bea3fd98e1f72be2508f43ef8d9", @@ -668,6 +708,13 @@ ], "version": "==2.21.0" }, + "s3transfer": { + "hashes": [ + "sha256:7b9ad3213bff7d357f888e0fab5101b56fa1a0548ee77d121c3a3dbfbef4cb2e", + "sha256:f23d5cb7d862b104401d9021fc82e5fa0e0cf57b7660a1331425aab0c691d021" + ], + "version": "==0.2.0" + }, "scikit-learn": { "hashes": [ "sha256:018f470a7e685767d84ce6fac87af59e064e87ec3cea71eaf12646f9538e293d", @@ -875,6 +922,7 @@ "sha256:61bf29cada3fc2fbefad4fdf059ea4bd1b4a86d2b6d15e1c7c0b582b9752fe39", "sha256:de9529817c93f27c8ccbfead6985011db27bd0ddfcdb2d86f3f663385c6a9c22" ], + "markers": "python_version >= '3.4'", "version": "==1.24.1" }, "wcwidth": { diff --git a/bin/train.py b/bin/train.py index 8a0f82b..6e6fe59 100644 --- a/bin/train.py +++ b/bin/train.py @@ -5,7 +5,7 @@ from torch.utils.data import DataLoader -from src.vocab import VanillaVocab +from src.vocab import get_vocab from src.datasets import EpisodesSampler, EpisodesDataset from src.matching_network import MatchingNetwork from src.training import train @@ -44,6 +44,14 @@ type=str, default='cosine', help="Distance metric to be used") +parser.add_argument( + "-e", + "--embeddings", + action="store", + dest="embeddings", + type=str, + default='vanilla', + help="Type of embedding") parser.add_argument( "-p", "--processing-steps", @@ -65,7 +73,7 @@ def _get_loader(data_set, N, episodes_multiplier=1): def main(args): print("Loading dataset...") - vocab = VanillaVocab(args.vocab) + vocab = get_vocab(args.embeddings, args.vocab) X_train, y_train = vocab.to_tensors(args.training_set) # Split training further into train and valid @@ -77,7 +85,7 @@ def main(args): print("Initialising model...") model_name = get_model_name( distance=args.distance_metric, - embeddings='vanilla', + embeddings=args.embeddings, N=args.N, k=args.k) model = MatchingNetwork( diff --git a/src/vocab.py b/src/vocab.py index 9f10053..20bf048 100644 --- a/src/vocab.py +++ b/src/vocab.py @@ -1,11 +1,14 @@ import json import numpy as np +import pandas as pd from collections import defaultdict, Counter from torchtext.vocab import Vocab from torchtext.data import Field, TabularDataset +from pytorch_pretrained_bert import BertTokenizer + class AbstractVocab(object): """ @@ -86,8 +89,6 @@ def save(self, file_path): Parameters --- - vocab : torchtext.Vocab - Vocabulary with our corpus. file_path : str Path to the vocab state to write. @@ -105,8 +106,6 @@ def to_tensors(self, file_path): --- file_path : str Path to the CSV file. - vocab : torchtext.Vocab - Vocabulary to use. Returns --- @@ -216,3 +215,56 @@ def generate_vocab(cls, file_path): text.build_vocab(data_set.label, data_set.sentence) return text.vocab + + +class BertVocab(AbstractVocab): + """ + Implementation of mappings between text and tensors using Bert. + """ + + def __init__(self): + """ + Initialise Bert's tokenizer. + """ + self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') + + def to_tensors(self, file_path): + """ + Reads the data set from one of the pre-processed CSVs composed + of columns `label` and `sentence`. + + Parameters + --- + file_path : str + Path to the CSV file. + + Returns + --- + X : torch.Tensor[num_labels x num_examples x sen_length] + Sentences on the dataset grouped by labels. + y : torch.Tensor[num_labels] + Labels for each group of sentences. 
+ """ + data_set = pd.read_csv(file_path) + sentences_tokens = self.tokenizer.tokenize(data_set['sentences']) + labels_tokens = self.tokenizer.tokenize(data_set['labels']) + raise NotImplementedError() + + +VOCABS = {'vanilla': VanillaVocab, 'bert': BertVocab} + + +def get_vocab(embeddings, *args, **kwargs): + """ + Returns an initialised vocab, forwarding the extra args and kwargs. + + Parameters + --- + embeddings : str + Embeddings to use. Can be one of the VOCABS keys. + + Returns + --- + AbstractVocab + """ + return VOCABS[embeddings](*args, **kwargs) From a4097b82844ca103a46e3eab06bfa8433f805aa5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adria=CC=81n=20Gonza=CC=81lez=20Marti=CC=81n?= Date: Sun, 17 Mar 2019 18:25:55 +0000 Subject: [PATCH 03/15] add method to convert to tensor --- src/vocab.py | 139 ++++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 126 insertions(+), 13 deletions(-) diff --git a/src/vocab.py b/src/vocab.py index 20bf048..3a2e029 100644 --- a/src/vocab.py +++ b/src/vocab.py @@ -1,4 +1,5 @@ import json +import torch import numpy as np import pandas as pd @@ -16,6 +17,8 @@ class AbstractVocab(object): and numbers. """ + padding_token_index = 0 + def __len__(self): raise NotImplementedError() @@ -31,6 +34,8 @@ class VanillaVocab(AbstractVocab): Allows to map between text and numbers using a simple tokenizer. """ + padding_token_index = 1 + def __init__(self, file_path): """ Initialise the vocabulary by reading it from a file path. @@ -43,7 +48,6 @@ def __init__(self, file_path): super().__init__() self.vocab = self._read_vocab(file_path) - self.padding_token_index = 1 def _read_vocab(self, file_path): """ @@ -130,14 +134,7 @@ def to_tensors(self, file_path): sentences_tensor = sentence.process(data_set.sentence) labels_tensor = label.process(data_set.label).squeeze() - # Infer num_labels and group sentences by label - num_labels = labels_tensor.unique().shape[0] - num_examples = labels_tensor.shape[0] // num_labels - y = labels_tensor[::num_examples] - sen_length = sentences_tensor.shape[-1] - X = sentences_tensor.view(num_labels, num_examples, sen_length) - - return X, y + return _reshape_tensors(sentences_tensor, labels_tensor) @classmethod def _tokenizer(cls, text): @@ -222,10 +219,14 @@ class BertVocab(AbstractVocab): Implementation of mappings between text and tensors using Bert. """ - def __init__(self): + padding_token_index = 0 + + def __init__(self, *args, **kwargs): """ Initialise Bert's tokenizer. """ + super().__init__() + self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') def to_tensors(self, file_path): @@ -246,9 +247,121 @@ def to_tensors(self, file_path): Labels for each group of sentences. 
""" data_set = pd.read_csv(file_path) - sentences_tokens = self.tokenizer.tokenize(data_set['sentences']) - labels_tokens = self.tokenizer.tokenize(data_set['labels']) - raise NotImplementedError() + + # Convert into tokens and find max sen length + sentences_tokens, labels_tokens, sen_length = self._to_tokens(data_set) + + # Convert into tensors + num_elems = len(sentences_tokens) + sentences_tensor = torch.zeros((num_elems, sen_length)) + labels_tensor = torch.zeros(num_elems) + + for idx in range(num_elems): + tensor_sentence = self.tokenizer.convert_tokens_to_ids( + sentences_tokens[idx]) + tensor_label = self.tokenizer.convert_tokens_to_ids( + labels_tokens[idx]) + + sentences_tensor[idx, :len(tensor_sentence)] = torch.Tensor( + tensor_sentence) + labels_tensor[idx] = tensor_label[0] + + return _reshape_tensors(sentences_tensor, labels_tensor) + + def _to_tokens(self, data_set): + """ + Tokenize the dataset. + + Parameters + --- + data_set : pd.DataFrame[label, sentence] + Dataset with two columns. + + Returns + --- + sentences_tokens : list + List of tokenized sentences. + labels_tokens : list + List of tokenized labels. + sen_length : int + Maximum sentence length. + """ + sentences_tokens = [] + labels_tokens = [] + sen_length = 0 + for idx, row in data_set.iterrows(): + token_sentence = self._tokenize(row['sentence']) + token_label = self._tokenize(row['label']) + + if len(token_label) > 1: + continue + # raise ValueError(f"Label '{row['label']}' was split " + # f"into more than one tokens: " + # f"{token_label}") + + length = len(token_sentence) + if length > sen_length: + sen_length = length + + sentences_tokens.append(token_sentence) + labels_tokens.append(token_label) + + return sentences_tokens, labels_tokens, sen_length + + def _tokenize(self, text): + """ + Tokenize a text using Bert's tokenizer but processing it first to + replace: + + - => [UNK] + - => [MASK] + + Parameters + --- + text : str + Input string. + + Returns + --- + list + List of tokens. + """ + with_unk = text.replace('', '[UNK]') + with_mask = with_unk.replace('', '[MASK]') + + return self.tokenizer.tokenize(with_mask) + + +def _reshape_tensors(sentences_tensor, labels_tensor): + """ + Reshape tensors to the [N x k x sen_lenth] structure. + + Parameters + --- + sentences_tensor : torch.Tensor[num_elems x sen_length] + Flat tensor with all the sentences. + labels_tensor : torch.Tensor[num_elems] + Flat tensor with all the labels. + + Returns + --- + X : torch.Tensor[num_labels x num_examples x sen_length] + Sentences on the dataset grouped by labels. + y : torch.Tensor[num_labels] + Labels for each group of sentences. 
+ """ + # Infer num_labels and num_examples by label + num_labels = labels_tensor.unique().shape[0] + num_examples = labels_tensor.shape[0] // num_labels + y = labels_tensor[::num_examples] + + # More robust to potentially duplicated labels + num_labels = y.shape[0] + + sen_length = sentences_tensor.shape[-1] + X = sentences_tensor.view(num_labels, num_examples, sen_length) + + return X, y VOCABS = {'vanilla': VanillaVocab, 'bert': BertVocab} From d19dbe1297f28f3b6853387b63ad48f2a6660889 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adria=CC=81n=20Gonza=CC=81lez=20Marti=CC=81n?= Date: Sun, 17 Mar 2019 18:30:49 +0000 Subject: [PATCH 04/15] add remaining methods --- src/vocab.py | 43 +++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 41 insertions(+), 2 deletions(-) diff --git a/src/vocab.py b/src/vocab.py index 3a2e029..4e04647 100644 --- a/src/vocab.py +++ b/src/vocab.py @@ -163,8 +163,6 @@ def to_text(self, X): ---- X : torch.Tensor[num_elements x sen_length] Sentences on the tensor. - vocab : torchtext.Vocab - Vocabulary to use. Returns ---- @@ -229,6 +227,17 @@ def __init__(self, *args, **kwargs): self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') + def __len__(self): + """ + Returns the length of Bert's vocabulary. + + Returns + --- + int + Length of the vocabulary. + """ + return len(self.tokenizer.vocab) + def to_tensors(self, file_path): """ Reads the data set from one of the pre-processed CSVs composed @@ -331,6 +340,36 @@ def _tokenize(self, text): return self.tokenizer.tokenize(with_mask) + def to_text(self, X): + """ + Reverses some numericalised tensor into text. + + Parameters + ---- + X : torch.Tensor[num_elements x sen_length] + Sentences on the tensor. + + Returns + ---- + sentences : np.array[num_elements] + Array of strings. + """ + sentences = [] + for sentence_tensor in X: + if len(sentence_tensor.shape) == 0: + # 0-D tensor + sentences.append(self.tokenizer.ids_to_tokens[sentence_tensor]) + continue + + sentence = [ + self.tokenizer.ids_to_tokens[token_id] + for token_id in sentence_tensor + if token_id != self.padding_token_index + ] + sentences.append(' '.join(sentence)) + + return np.array(sentences) + def _reshape_tensors(sentences_tensor, labels_tensor): """ From 0deb44619b9d4572b9d870d874df71979f1a6e09 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adria=CC=81n=20Gonza=CC=81lez=20Marti=CC=81n?= Date: Sun, 17 Mar 2019 18:32:13 +0000 Subject: [PATCH 05/15] add comment on hack --- src/vocab.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/vocab.py b/src/vocab.py index 4e04647..6cca6e4 100644 --- a/src/vocab.py +++ b/src/vocab.py @@ -302,6 +302,10 @@ def _to_tokens(self, data_set): token_sentence = self._tokenize(row['sentence']) token_label = self._tokenize(row['label']) + # TODO: This is a shortcut to avoid dealing with + # situations where Bert's word-piece tokenizer + # splits a label into multiple tokens and thus + # multiple token ids to predict for a single input. 
if len(token_label) > 1: continue # raise ValueError(f"Label '{row['label']}' was split " From 6a4cf6430e61fdb77813e3a4412e4d72c37b589e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adria=CC=81n=20Gonza=CC=81lez=20Marti=CC=81n?= Date: Sun, 17 Mar 2019 19:43:58 +0000 Subject: [PATCH 06/15] add Matt's encoding layer --- bin/test.py | 10 ++++---- bin/train.py | 1 + src/matching_network.py | 51 ++++++++++++++++++++++++++++++----------- src/vocab.py | 3 +++ 4 files changed, 46 insertions(+), 19 deletions(-) diff --git a/bin/test.py b/bin/test.py index 6f5ebf2..0b56a1b 100644 --- a/bin/test.py +++ b/bin/test.py @@ -10,7 +10,7 @@ from torch.utils.data import DataLoader -from src.vocab import VanillaVocab +from src.vocab import get_vocab from src.matching_network import MatchingNetwork from src.evaluation import (predict, save_predictions, generate_episode_data) from src.datasets import EpisodesSampler, EpisodesDataset @@ -50,11 +50,11 @@ parser.add_argument("test_set", help="Path to the test CSV file") -def _load_model(model_path): +def _load_model(model_path, vocab): model_file_name = os.path.basename(args.model) distance, embeddings, N, k = extract_model_parameters(model_file_name) model_name = get_model_name(distance, embeddings, N, k) - model = MatchingNetwork(model_name, distance_metric=distance) + model = MatchingNetwork(model_name, vocab, distance_metric=distance) model_state_dict = torch.load(model_path) model.load_state_dict(model_state_dict) @@ -63,10 +63,10 @@ def _load_model(model_path): def main(args): print("Loading model...") - model, _, N, k = _load_model(args.model) + vocab = get_vocab(args.embeddings, args.vocab) + model, embeddings, N, k = _load_model(args.model, vocab) print("Loading dataset...") - vocab = VanillaVocab(args.vocab) X_test, y_test = vocab.to_tensors(args.test_set) test_set = EpisodesDataset(X_test, y_test, k=k) sampler = EpisodesSampler(test_set, N=N, episodes_multiplier=30) diff --git a/bin/train.py b/bin/train.py index 6e6fe59..9eabfa8 100644 --- a/bin/train.py +++ b/bin/train.py @@ -90,6 +90,7 @@ def main(args): k=args.k) model = MatchingNetwork( model_name, + vocab, fce=True, processing_steps=args.processing_steps, distance_metric=args.distance_metric) diff --git a/src/matching_network.py b/src/matching_network.py index 367f07d..bb29cad 100644 --- a/src/matching_network.py +++ b/src/matching_network.py @@ -1,8 +1,11 @@ import torch + from torch import nn from torch.nn import functional as F + +from pytorch_pretrained_bert import BertModel + from .similarity import get_similarity_func -from .data import VOCAB_SIZE, PADDING_TOKEN_INDEX class EncodingLayer(nn.Module): @@ -11,23 +14,31 @@ class EncodingLayer(nn.Module): embedding. """ - def __init__(self, vocab_size, encoding_size): + def __init__(self, encoding_size, vocab): """ Initialises the encoding layer. Parameters --- - vocab_size : int - Size of the vocabulary to do one-hot encodings. encoding_size : int Target size of the encoding. + vocab : AbstractVocab + Vocabulary used for the encodings. 
""" super().__init__() - self.encoding_layer = nn.Embedding( - num_embeddings=vocab_size, - embedding_dim=encoding_size, - padding_idx=PADDING_TOKEN_INDEX) + self.vocab_size = len(vocab) + self.padding_token_index = vocab.padding_token_index + self.embeddings = vocab.name + + if self.embeddings == "bert": + self.encoding_layer = BertModel.from_pretrained( + 'bert-base-uncased') + else: + self.encoding_layer = nn.Embedding( + num_embeddings=self.vocab_size, + embedding_dim=encoding_size, + padding_idx=self.padding_token_index) def forward(self, sentences): """ @@ -58,8 +69,20 @@ def forward(self, sentences): sen_length = sentences.shape[2] flattened = reshaped.reshape(-1, sen_length) - encoded_flat = self.encoding_layer(flattened) - pooled_flat = encoded_flat.sum(dim=1) + + if self.embeddings == "bert": + # We don't want to fine-tune BERT! + with torch.no_grad(): + encoded_layers, _ = self.encoding_layer(flattened) + + # We have a hidden states for each of the 12 layers + # in model bert-base-uncased + + # Remove useless dimension + pooled_flat = torch.squeeze(encoded_layers[11]) + else: + encoded_flat = self.encoding_layer(flattened) + pooled_flat = encoded_flat.sum(dim=1) # Re-shape into original form (4D or 3D tensor) enc_size = pooled_flat.shape[1] @@ -249,8 +272,8 @@ class MatchingNetwork(nn.Module): def __init__(self, name, + vocab, fce=True, - vocab_size=VOCAB_SIZE, processing_steps=5, distance_metric="cosine"): """ @@ -260,10 +283,10 @@ def __init__(self, --- name : str Name of the model. Used for storing checkpoints. + vocab : AbstractVocab + AbstractVocab object. fce : bool Flag to decide if we should use Full Context Embeddings. - vocab_size : int - Size of the vocabulary to do one-hot encodings. processing_steps : int How many processing steps to take when embedding the target query. @@ -275,7 +298,7 @@ def __init__(self, self.name = name self.encoding_size = 64 - self.vocab_size = vocab_size + self.vocab_size = len(vocab) self.encode = EncodingLayer(self.vocab_size, self.encoding_size) self.g = GLayer(self.encoding_size, fce=fce) diff --git a/src/vocab.py b/src/vocab.py index 6cca6e4..424dd12 100644 --- a/src/vocab.py +++ b/src/vocab.py @@ -17,6 +17,7 @@ class AbstractVocab(object): and numbers. """ + name = "" padding_token_index = 0 def __len__(self): @@ -34,6 +35,7 @@ class VanillaVocab(AbstractVocab): Allows to map between text and numbers using a simple tokenizer. """ + name = "vanilla" padding_token_index = 1 def __init__(self, file_path): @@ -217,6 +219,7 @@ class BertVocab(AbstractVocab): Implementation of mappings between text and tensors using Bert. 
""" + name = "bert" padding_token_index = 0 def __init__(self, *args, **kwargs): From a1a322b7a365e21e7a18b355760ca7ff189cf87a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adria=CC=81n=20Gonza=CC=81lez=20Marti=CC=81n?= Date: Sun, 17 Mar 2019 19:46:41 +0000 Subject: [PATCH 07/15] fix type mismatch --- src/matching_network.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/matching_network.py b/src/matching_network.py index bb29cad..e32a35e 100644 --- a/src/matching_network.py +++ b/src/matching_network.py @@ -300,7 +300,7 @@ def __init__(self, self.encoding_size = 64 self.vocab_size = len(vocab) - self.encode = EncodingLayer(self.vocab_size, self.encoding_size) + self.encode = EncodingLayer(self.encoding_size, vocab) self.g = GLayer(self.encoding_size, fce=fce) self.f = FLayer(self.encoding_size, processing_steps=processing_steps) From c5e915a12609bda3c4d8e8c42e566613f72eed4b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adria=CC=81n=20Gonza=CC=81lez=20Marti=CC=81n?= Date: Sun, 17 Mar 2019 20:35:32 +0000 Subject: [PATCH 08/15] fixes for variable sentence encoding --- src/matching_network.py | 41 ++++++++++++++++++++++++++++------------- 1 file changed, 28 insertions(+), 13 deletions(-) diff --git a/src/matching_network.py b/src/matching_network.py index e32a35e..e00b396 100644 --- a/src/matching_network.py +++ b/src/matching_network.py @@ -102,21 +102,24 @@ class FLayer(nn.Module): Referred to in the paper as `f(x)` on the paper. """ - def __init__(self, encoding_size, processing_steps): + def __init__(self, sentence_encoding_size, encoding_size, + processing_steps): """ Initialise the `f()` layer. Parameters ---- - encoding_size : int + sentence_encoding_size : int Size of the sentence encodings. + encoding_size : int + Size of the embedded sentence through f(). processing_steps : int Number of processing steps for the LSTM. Referred to as `K` on the paper. """ super().__init__() self.lstm_cell = nn.LSTMCell( - input_size=encoding_size, hidden_size=encoding_size) + input_size=sentence_encoding_size, hidden_size=encoding_size) self.processing_steps = processing_steps def forward(self, targets, support_embeddings): @@ -125,7 +128,7 @@ def forward(self, targets, support_embeddings): Parameters ---- - targets : torch.Tensor[batch_size x T x encoding_size] + targets : torch.Tensor[batch_size x T x sentence_encoding_size] List of targets to predict. support_embeddings : torch.Tensor[batch_size x N x k x encoding_size] Embeddings of the support set. @@ -137,8 +140,10 @@ def forward(self, targets, support_embeddings): # Flatten so that targets are 2D # (i.e. [(batch_size * T) x encoding_size]) T = targets.shape[1] - encoding_size = targets.shape[2] - flattened_targets = targets.view(-1, encoding_size) + encoding_size = support_embeddings.shape[3] + sentence_encoding_size = targets.shape[2] + + flattened_targets = targets.view(-1, sentence_encoding_size) h_prev = torch.zeros_like(flattened_targets) c_prev = torch.zeros_like(flattened_targets) @@ -206,14 +211,16 @@ class GLayer(nn.Module): Referred to in the paper as `g()`. """ - def __init__(self, encoding_size, fce): + def __init__(self, sentence_encoding_size, encoding_size, fce): """ Initialise the g()-layer. Parameters --- - encoding_size : int + sentence_encoding_size : int Size of the sentence encodings. + encoding_size : int + Size of the embedded sentence through g(). fce : bool Flag to decide if we should use Full Context Embeddings. 
""" @@ -222,7 +229,7 @@ def __init__(self, encoding_size, fce): self.fce_layer = None if fce: self.fce_layer = nn.LSTM( - input_size=encoding_size, + input_size=sentence_encoding_size, hidden_size=encoding_size, bidirectional=True, batch_first=True) @@ -233,7 +240,7 @@ def forward(self, support_encodings): Parameters --- - support_set : torch.Tensor[batch_size x N x k x encoding_size] + support_set : torch.Tensor[batch_size x N x k x sentence_encoding_size] Support set containing [batch_size] episodes of [N] labels with [k] examples each. The last dimension represents the list of tokens in each sentence. @@ -297,12 +304,20 @@ def __init__(self, self.name = name + self.sentence_encoding_size = 64 + if vocab.name == 'bert': + self.sentence_encoding_size = 210 + self.encoding_size = 64 self.vocab_size = len(vocab) - self.encode = EncodingLayer(self.encoding_size, vocab) - self.g = GLayer(self.encoding_size, fce=fce) - self.f = FLayer(self.encoding_size, processing_steps=processing_steps) + self.encode = EncodingLayer(self.sentence_encoding_size, vocab) + self.g = GLayer( + self.sentence_encoding_size, self.encoding_size, fce=fce) + self.f = FLayer( + self.sentence_encoding_size, + self.encoding_size, + processing_steps=processing_steps) self.distance_metric = distance_metric From 43bc4ef99879eddb4109d2ede6b8b55f12c93e21 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adria=CC=81n=20Gonza=CC=81lez=20Marti=CC=81n?= Date: Sun, 17 Mar 2019 20:50:24 +0000 Subject: [PATCH 09/15] change entire encoding size to 210 --- src/matching_network.py | 45 +++++++++++++++-------------------------- 1 file changed, 16 insertions(+), 29 deletions(-) diff --git a/src/matching_network.py b/src/matching_network.py index e00b396..ba8fecb 100644 --- a/src/matching_network.py +++ b/src/matching_network.py @@ -32,8 +32,7 @@ def __init__(self, encoding_size, vocab): self.embeddings = vocab.name if self.embeddings == "bert": - self.encoding_layer = BertModel.from_pretrained( - 'bert-base-uncased') + self.bert_layer = BertModel.from_pretrained('bert-base-uncased') else: self.encoding_layer = nn.Embedding( num_embeddings=self.vocab_size, @@ -102,24 +101,21 @@ class FLayer(nn.Module): Referred to in the paper as `f(x)` on the paper. """ - def __init__(self, sentence_encoding_size, encoding_size, - processing_steps): + def __init__(self, encoding_size, processing_steps): """ Initialise the `f()` layer. Parameters ---- - sentence_encoding_size : int - Size of the sentence encodings. encoding_size : int - Size of the embedded sentence through f(). + Size of the sentence encodings. processing_steps : int Number of processing steps for the LSTM. Referred to as `K` on the paper. """ super().__init__() self.lstm_cell = nn.LSTMCell( - input_size=sentence_encoding_size, hidden_size=encoding_size) + input_size=encoding_size, hidden_size=encoding_size) self.processing_steps = processing_steps def forward(self, targets, support_embeddings): @@ -128,7 +124,7 @@ def forward(self, targets, support_embeddings): Parameters ---- - targets : torch.Tensor[batch_size x T x sentence_encoding_size] + targets : torch.Tensor[batch_size x T x encoding_size] List of targets to predict. support_embeddings : torch.Tensor[batch_size x N x k x encoding_size] Embeddings of the support set. @@ -140,10 +136,8 @@ def forward(self, targets, support_embeddings): # Flatten so that targets are 2D # (i.e. 
[(batch_size * T) x encoding_size]) T = targets.shape[1] - encoding_size = support_embeddings.shape[3] - sentence_encoding_size = targets.shape[2] - - flattened_targets = targets.view(-1, sentence_encoding_size) + encoding_size = targets.shape[2] + flattened_targets = targets.view(-1, encoding_size) h_prev = torch.zeros_like(flattened_targets) c_prev = torch.zeros_like(flattened_targets) @@ -211,16 +205,14 @@ class GLayer(nn.Module): Referred to in the paper as `g()`. """ - def __init__(self, sentence_encoding_size, encoding_size, fce): + def __init__(self, encoding_size, fce): """ Initialise the g()-layer. Parameters --- - sentence_encoding_size : int - Size of the sentence encodings. encoding_size : int - Size of the embedded sentence through g(). + Size of the sentence encodings. fce : bool Flag to decide if we should use Full Context Embeddings. """ @@ -229,7 +221,7 @@ def __init__(self, sentence_encoding_size, encoding_size, fce): self.fce_layer = None if fce: self.fce_layer = nn.LSTM( - input_size=sentence_encoding_size, + input_size=encoding_size, hidden_size=encoding_size, bidirectional=True, batch_first=True) @@ -240,7 +232,7 @@ def forward(self, support_encodings): Parameters --- - support_set : torch.Tensor[batch_size x N x k x sentence_encoding_size] + support_set : torch.Tensor[batch_size x N x k x encoding_size] Support set containing [batch_size] episodes of [N] labels with [k] examples each. The last dimension represents the list of tokens in each sentence. @@ -304,20 +296,15 @@ def __init__(self, self.name = name - self.sentence_encoding_size = 64 + self.encoding_size = 64 if vocab.name == 'bert': - self.sentence_encoding_size = 210 + self.encoding_size = 210 - self.encoding_size = 64 self.vocab_size = len(vocab) - self.encode = EncodingLayer(self.sentence_encoding_size, vocab) - self.g = GLayer( - self.sentence_encoding_size, self.encoding_size, fce=fce) - self.f = FLayer( - self.sentence_encoding_size, - self.encoding_size, - processing_steps=processing_steps) + self.encode = EncodingLayer(self.encoding_size, vocab) + self.g = GLayer(self.encoding_size, fce=fce) + self.f = FLayer(self.encoding_size, processing_steps=processing_steps) self.distance_metric = distance_metric From bd95af3b3d45447028827bbf262e4a48e116553b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adria=CC=81n=20Gonza=CC=81lez=20Marti=CC=81n?= Date: Sun, 17 Mar 2019 20:54:07 +0000 Subject: [PATCH 10/15] fix --- src/matching_network.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/matching_network.py b/src/matching_network.py index ba8fecb..8647a7d 100644 --- a/src/matching_network.py +++ b/src/matching_network.py @@ -32,7 +32,8 @@ def __init__(self, encoding_size, vocab): self.embeddings = vocab.name if self.embeddings == "bert": - self.bert_layer = BertModel.from_pretrained('bert-base-uncased') + self.encoding_layer = BertModel.from_pretrained( + 'bert-base-uncased') else: self.encoding_layer = nn.Embedding( num_embeddings=self.vocab_size, From 2df69ac3849f446749a31351a70fffbc26c48516 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adria=CC=81n=20Gonza=CC=81lez=20Marti=CC=81n?= Date: Sun, 17 Mar 2019 20:59:14 +0000 Subject: [PATCH 11/15] test --- src/matching_network.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/matching_network.py b/src/matching_network.py index 8647a7d..56557d9 100644 --- a/src/matching_network.py +++ b/src/matching_network.py @@ -394,6 +394,8 @@ def _to_logits(self, attention, labels): """ # Sum across labels attention = attention.sum(dim=3) + import ipdb + 
ipdb.set_trace() batch_size, T, N = attention.shape logits = torch.zeros((batch_size, T, self.vocab_size)) From cfcaff1367b14f8e01ea2304bb7b0ef1b1b27887 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adria=CC=81n=20Gonza=CC=81lez=20Marti=CC=81n?= Date: Sun, 17 Mar 2019 21:02:41 +0000 Subject: [PATCH 12/15] test2 --- src/matching_network.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/matching_network.py b/src/matching_network.py index 56557d9..548d5e0 100644 --- a/src/matching_network.py +++ b/src/matching_network.py @@ -80,6 +80,9 @@ def forward(self, sentences): # Remove useless dimension pooled_flat = torch.squeeze(encoded_layers[11]) + + import ipdb + ipdb.set_trace() else: encoded_flat = self.encoding_layer(flattened) pooled_flat = encoded_flat.sum(dim=1) @@ -394,8 +397,6 @@ def _to_logits(self, attention, labels): """ # Sum across labels attention = attention.sum(dim=3) - import ipdb - ipdb.set_trace() batch_size, T, N = attention.shape logits = torch.zeros((batch_size, T, self.vocab_size)) From 2dc0328af203bb84939d0307408ba1eb6a3378ff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adria=CC=81n=20Gonza=CC=81lez=20Marti=CC=81n?= Date: Sun, 17 Mar 2019 21:09:17 +0000 Subject: [PATCH 13/15] Use sum-pooling. --- src/matching_network.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/matching_network.py b/src/matching_network.py index 548d5e0..e5d49bf 100644 --- a/src/matching_network.py +++ b/src/matching_network.py @@ -79,10 +79,8 @@ def forward(self, sentences): # in model bert-base-uncased # Remove useless dimension - pooled_flat = torch.squeeze(encoded_layers[11]) - - import ipdb - ipdb.set_trace() + encoded_flat = torch.squeeze(encoded_layers[11]) + pooled_flat = encoded_flat.sum(dim=1) else: encoded_flat = self.encoding_layer(flattened) pooled_flat = encoded_flat.sum(dim=1) From 21e9422b1d9821312cfdf01e37871595c6778795 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adria=CC=81n=20Gonza=CC=81lez=20Marti=CC=81n?= Date: Sun, 17 Mar 2019 21:15:32 +0000 Subject: [PATCH 14/15] reduce dim down to 64 --- src/matching_network.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/matching_network.py b/src/matching_network.py index e5d49bf..bc8589c 100644 --- a/src/matching_network.py +++ b/src/matching_network.py @@ -32,8 +32,10 @@ def __init__(self, encoding_size, vocab): self.embeddings = vocab.name if self.embeddings == "bert": - self.encoding_layer = BertModel.from_pretrained( - 'bert-base-uncased') + bert_encoding_size = 768 + self.bert_layer = BertModel.from_pretrained('bert-base-uncased') + self.encoding_layer = nn.Linear( + in_features=bert_encoding_size, out_features=encoding_size) else: self.encoding_layer = nn.Embedding( num_embeddings=self.vocab_size, @@ -73,7 +75,7 @@ def forward(self, sentences): if self.embeddings == "bert": # We don't want to fine-tune BERT! 
with torch.no_grad(): - encoded_layers, _ = self.encoding_layer(flattened) + encoded_layers, _ = self.bert_layer(flattened) # We have a hidden states for each of the 12 layers # in model bert-base-uncased @@ -81,6 +83,9 @@ def forward(self, sentences): # Remove useless dimension encoded_flat = torch.squeeze(encoded_layers[11]) pooled_flat = encoded_flat.sum(dim=1) + + # Reduce dimensionality to 64 + pooled_flat = self.encoding_layer(pooled_flat) else: encoded_flat = self.encoding_layer(flattened) pooled_flat = encoded_flat.sum(dim=1) @@ -299,9 +304,6 @@ def __init__(self, self.name = name self.encoding_size = 64 - if vocab.name == 'bert': - self.encoding_size = 210 - self.vocab_size = len(vocab) self.encode = EncodingLayer(self.encoding_size, vocab) From 1ee59bf24d673cc7c737c6538ca5aef544991e10 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adria=CC=81n=20Gonza=CC=81lez=20Marti=CC=81n?= Date: Sun, 17 Mar 2019 21:25:00 +0000 Subject: [PATCH 15/15] fix device --- src/matching_network.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/matching_network.py b/src/matching_network.py index bc8589c..95dc881 100644 --- a/src/matching_network.py +++ b/src/matching_network.py @@ -330,7 +330,8 @@ def _similarity(self, support_embeddings, target_embeddings): """ batch_size, N, k, _ = support_embeddings.shape T = target_embeddings.shape[1] - similarities = torch.zeros(batch_size, T, N, k) + similarities = torch.zeros((batch_size, T, N, k), + device=support_embeddings.device) similarity_func = get_similarity_func(self.distance_metric) # TODO: Would be good to optimise this so that it's vectorised. @@ -398,7 +399,8 @@ def _to_logits(self, attention, labels): # Sum across labels attention = attention.sum(dim=3) batch_size, T, N = attention.shape - logits = torch.zeros((batch_size, T, self.vocab_size)) + logits = torch.zeros((batch_size, T, self.vocab_size), + device=attention.device) # TODO: Would be good to optimise this so that it's vectorised. for batch_idx in range(batch_size):
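
Usage note (not part of the patches above): a minimal sketch of how the vocab interface added in this series is wired together once all 15 patches are applied, mirroring bin/train.py. The file paths, model name, and hyper-parameter values below are illustrative assumptions, not values taken from the series.

    from src.vocab import get_vocab
    from src.matching_network import MatchingNetwork

    # Pick the embedding backend by name: "vanilla" reads a previously stored
    # JSON vocab, "bert" loads the pre-trained bert-base-uncased tokenizer.
    vocab = get_vocab("vanilla", "data/vocab.json")        # illustrative path

    # X: [num_labels x num_examples x sen_length], y: [num_labels]
    X_train, y_train = vocab.to_tensors("data/train.csv")  # illustrative path

    # The model now receives the vocab object itself (instead of vocab_size),
    # so the encoding layer can switch between nn.Embedding and frozen BERT.
    model = MatchingNetwork(
        "example-model",           # illustrative; bin/train.py builds this via get_model_name()
        vocab,
        fce=True,
        processing_steps=5,
        distance_metric="cosine")

    # Tensors can be mapped back to strings for inspection of episodes.
    sentences = vocab.to_text(X_train.view(-1, X_train.shape[-1]))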