2 changes: 2 additions & 0 deletions Pipfile
@@ -18,6 +18,8 @@ scikit-learn = "*"
 torch = "*"
 pytorch-ignite = "*"
 torchtext = "*"
+pytorch-pretrained-bert = "*"
+pandas = "*"
 
 [requires]
 python_version = "3.6"
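The new `pytorch-pretrained-bert` dependency suggests BERT-based embeddings are being added alongside the existing "vanilla" ones (see the new `--embeddings` flag in `bin/train.py` below). A minimal sketch of how that library is typically used; the checkpoint name and how this project wires it in are assumptions, since neither appears in this diff:

```python
import torch
from pytorch_pretrained_bert import BertTokenizer, BertModel

# Assumption: the base uncased checkpoint; the diff does not show
# which model the project actually loads.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertModel.from_pretrained("bert-base-uncased")
model.eval()

tokens = tokenizer.tokenize("[CLS] a sample sentence [SEP]")
token_ids = torch.tensor([tokenizer.convert_tokens_to_ids(tokens)])

with torch.no_grad():
    # encoded_layers holds one hidden-state tensor per transformer layer
    encoded_layers, pooled = model(token_ids)
```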
50 changes: 49 additions & 1 deletion Pipfile.lock

Some generated files are not rendered by default.

12 changes: 6 additions & 6 deletions bin/test.py
@@ -10,9 +10,9 @@
 
 from torch.utils.data import DataLoader
 
+from src.vocab import get_vocab
 from src.matching_network import MatchingNetwork
 from src.evaluation import (predict, save_predictions, generate_episode_data)
-from src.data import read_vocab, read_data_set
 from src.datasets import EpisodesSampler, EpisodesDataset
 from src.utils import extract_model_parameters, get_model_name
 
@@ -50,11 +50,11 @@
 parser.add_argument("test_set", help="Path to the test CSV file")
 
 
-def _load_model(model_path):
+def _load_model(model_path, vocab):
     model_file_name = os.path.basename(args.model)
     distance, embeddings, N, k = extract_model_parameters(model_file_name)
     model_name = get_model_name(distance, embeddings, N, k)
-    model = MatchingNetwork(model_name, distance_metric=distance)
+    model = MatchingNetwork(model_name, vocab, distance_metric=distance)
     model_state_dict = torch.load(model_path)
     model.load_state_dict(model_state_dict)
 
@@ -63,11 +63,11 @@ def _load_model(model_path):
 
 def main(args):
     print("Loading model...")
-    model, _, N, k = _load_model(args.model)
+    vocab = get_vocab(args.embeddings, args.vocab)
+    model, embeddings, N, k = _load_model(args.model, vocab)
 
     print("Loading dataset...")
-    vocab = read_vocab(args.vocab)
-    X_test, y_test = read_data_set(args.test_set, vocab)
+    X_test, y_test = vocab.to_tensors(args.test_set)
     test_set = EpisodesDataset(X_test, y_test, k=k)
     sampler = EpisodesSampler(test_set, N=N, episodes_multiplier=30)
     test_loader = DataLoader(test_set, sampler=sampler, batch_size=BATCH_SIZE)
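Taken together, the `bin/test.py` changes replace the free functions from `src.data` (`read_vocab`, `read_data_set`) with a vocab object that is built once and shared between data loading and the model. A sketch of the interface these call sites imply; the method signatures are inferred from usage, not taken from `src/vocab.py` (which this diff does not show):

```python
class Vocab:
    """Interface implied by the call sites in this PR (sketch, not the real class)."""

    def to_tensors(self, csv_path):
        """Read a CSV data set and return (X, y) index tensors (replaces read_data_set)."""
        raise NotImplementedError

    def to_text(self, tensor):
        """Map index tensors back to strings (replaces reverse_tensor)."""
        raise NotImplementedError

    def save(self, path):
        """Persist the vocabulary to disk (replaces store_vocab)."""
        raise NotImplementedError

    def __len__(self):
        """Number of tokens in the vocabulary."""
        raise NotImplementedError
```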
17 changes: 13 additions & 4 deletions bin/train.py
@@ -5,7 +5,7 @@
 
 from torch.utils.data import DataLoader
 
-from src.data import read_vocab, read_data_set
+from src.vocab import get_vocab
 from src.datasets import EpisodesSampler, EpisodesDataset
 from src.matching_network import MatchingNetwork
 from src.training import train
@@ -44,6 +44,14 @@
     type=str,
     default='cosine',
     help="Distance metric to be used")
+parser.add_argument(
+    "-e",
+    "--embeddings",
+    action="store",
+    dest="embeddings",
+    type=str,
+    default='vanilla',
+    help="Type of embedding")
 parser.add_argument(
     "-p",
     "--processing-steps",
@@ -65,8 +73,8 @@ def _get_loader(data_set, N, episodes_multiplier=1):
 
 def main(args):
     print("Loading dataset...")
-    vocab = read_vocab(args.vocab)
-    X_train, y_train = read_data_set(args.training_set, vocab)
+    vocab = get_vocab(args.embeddings, args.vocab)
+    X_train, y_train = vocab.to_tensors(args.training_set)
 
     # Split training further into train and valid
     X_train, X_valid, y_train, y_valid = train_test_split_tensors(
@@ -77,11 +85,12 @@ def main(args):
     print("Initialising model...")
     model_name = get_model_name(
         distance=args.distance_metric,
-        embeddings='vanilla',
+        embeddings=args.embeddings,
         N=args.N,
         k=args.k)
     model = MatchingNetwork(
         model_name,
+        vocab,
         fce=True,
         processing_steps=args.processing_steps,
         distance_metric=args.distance_metric)
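`bin/train.py` now takes the embedding type from the command line (`-e/--embeddings`, defaulting to `vanilla`) and passes it to `get_vocab`, so the model name and the vocabulary both reflect the chosen embeddings. The factory itself is not part of this diff; a plausible sketch, with the `bert` branch and both class names as assumptions:

```python
# Hypothetical sketch of the get_vocab factory in src/vocab.py;
# the real dispatch logic is not shown in this diff.
def get_vocab(embeddings, vocab_path):
    if embeddings == "vanilla":
        return VanillaVocab.load(vocab_path)  # assumed loader name
    if embeddings == "bert":
        return BertVocab()  # assumed class name, backed by pytorch-pretrained-bert
    raise ValueError(f"Unknown embeddings type: {embeddings}")
```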
6 changes: 3 additions & 3 deletions bin/vocab.py
@@ -4,7 +4,7 @@
 
 from argparse import ArgumentParser
 
-from src.data import generate_vocab, store_vocab
+from src.vocab import VanillaVocab
 
 parser = ArgumentParser()
 parser.add_argument("input", help="Path to the input CSV data set")
@@ -13,10 +13,10 @@
 
 def main(args):
     print("Generating vocab...")
-    vocab = generate_vocab(args.input)
+    vocab = VanillaVocab.generate_vocab(args.input)
 
     print("Storing vocab...")
-    store_vocab(vocab, args.output)
+    vocab.save(args.output)
 
     print(f"Stored vocab of size {len(vocab)} at {args.output}")
 
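With the class-based API, generating and persisting a vocabulary is a short round trip. A usage sketch with illustrative paths (the on-disk format is not shown in this diff):

```python
from src.vocab import VanillaVocab

vocab = VanillaVocab.generate_vocab("data/train.csv")  # illustrative path
vocab.save("data/vocab.json")  # illustrative path and extension
print(f"Stored vocab of size {len(vocab)}")  # VanillaVocab must define __len__
```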
3 changes: 3 additions & 0 deletions src/data.py
@@ -9,6 +9,9 @@
 from torchtext.data import Field, TabularDataset
 from torchtext.vocab import Vocab
 
+print(
+    "[WARNING] Don't use src.data anymore. Use the Vocab interfaces instead.")
+
 VOCAB_SIZE = 27443
 
 PADDING_TOKEN_INDEX = 1
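The module-level `print` fires once when `src.data` is imported, which is a lightweight way to flag the deprecation. A conventional alternative would be the standard `warnings` module, sketched here, which lets callers filter or escalate the message:

```python
import warnings

warnings.warn(
    "src.data is deprecated; use the Vocab interfaces instead.",
    DeprecationWarning,
    stacklevel=2,
)
```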
10 changes: 4 additions & 6 deletions src/evaluation.py
@@ -6,8 +6,6 @@
 import torch
 import numpy as np
 
-from .data import reverse_tensor
-
 RESULTS_PATH = os.path.join(
     os.path.dirname(os.path.dirname(__file__)), "results")
 
@@ -174,11 +172,11 @@ def _episode_to_text(support_set, targets, labels, target_labels, vocab):
     # First, we need to flatten these...
     N, k, _ = support_set.shape
     flat_support_set = support_set.view(N * k, -1)
-    flat_support_set = reverse_tensor(flat_support_set, vocab)
+    flat_support_set = vocab.to_text(flat_support_set)
     support_set = flat_support_set.reshape(N, k)
 
-    targets = reverse_tensor(targets, vocab)
-    labels = reverse_tensor(labels, vocab)
-    target_labels = reverse_tensor(target_labels, vocab)
+    targets = vocab.to_text(targets)
+    labels = vocab.to_text(labels)
+    target_labels = vocab.to_text(target_labels)
 
     return support_set, targets, labels, target_labels
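`_episode_to_text` now delegates index-to-string conversion to the vocab: the `(N, k, seq_len)` support set is flattened so each row can become one sentence, converted, then reshaped back to `N` classes by `k` examples. A hypothetical sketch of what `to_text` might do for the vanilla vocabulary, extrapolated from the old `reverse_tensor` helper and the `PADDING_TOKEN_INDEX = 1` constant in `src/data.py`:

```python
import numpy as np

# Hypothetical sketch; the real implementation lives in src/vocab.py,
# which this diff does not show.
def to_text(self, tensor, padding_index=1):
    sentences = []
    for row in tensor:
        # Drop padding, then map each remaining index back to its token
        # (itos is torchtext's index-to-string list).
        tokens = [self.itos[idx] for idx in row.tolist() if idx != padding_index]
        sentences.append(" ".join(tokens))
    return np.array(sentences)  # reshape-able, as _episode_to_text requires
```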