From 1b76656358288477872fb7d7c44969c863c3d7c9 Mon Sep 17 00:00:00 2001 From: truthless11 Date: Wed, 18 Sep 2019 20:16:50 +0800 Subject: [PATCH 1/3] add transformer --- .../nlg/multiwoz/transformer/dataloader.py | 99 ++++ .../nlg/multiwoz/transformer/evaluator.py | 222 ++++++++ .../modules/nlg/multiwoz/transformer/tools.py | 461 ++++++++++++++++ .../modules/nlg/multiwoz/transformer/train.py | 191 +++++++ .../multiwoz/transformer/transformer/Beam.py | 116 ++++ .../transformer/transformer/Constants.py | 57 ++ .../transformer/transformer/Transformer.py | 505 ++++++++++++++++++ .../multiwoz/transformer/transformer_nlg.py | 270 ++++++++++ 8 files changed, 1921 insertions(+) create mode 100644 convlab/modules/nlg/multiwoz/transformer/dataloader.py create mode 100644 convlab/modules/nlg/multiwoz/transformer/evaluator.py create mode 100644 convlab/modules/nlg/multiwoz/transformer/tools.py create mode 100644 convlab/modules/nlg/multiwoz/transformer/train.py create mode 100644 convlab/modules/nlg/multiwoz/transformer/transformer/Beam.py create mode 100644 convlab/modules/nlg/multiwoz/transformer/transformer/Constants.py create mode 100644 convlab/modules/nlg/multiwoz/transformer/transformer/Transformer.py create mode 100644 convlab/modules/nlg/multiwoz/transformer/transformer_nlg.py diff --git a/convlab/modules/nlg/multiwoz/transformer/dataloader.py b/convlab/modules/nlg/multiwoz/transformer/dataloader.py new file mode 100644 index 0000000..d55b664 --- /dev/null +++ b/convlab/modules/nlg/multiwoz/transformer/dataloader.py @@ -0,0 +1,99 @@ +# -*- coding: utf-8 -*- +""" +Created on Sat Sep 14 16:29:52 2019 + +@author: truthless +""" +import os +import json +from convlab.modules.nlg.multiwoz.transformer.transformer import Constants +import logging +import torch + +# input_ids, rep_in, resp_out, act_vecs + +logger = logging.getLogger(__name__) + +def get_batch(part, tokenizer, max_seq_length=50): + examples = [] + data_dir = os.path.dirname(os.path.abspath(__file__)) + + if part == 'train': + with open('{}/data/train.json'.format(data_dir)) as f: + source = json.load(f) + elif part == 'val': + with open('{}/data/val.json'.format(data_dir)) as f: + source = json.load(f) + elif part == 'test': + with open('{}/data/test.json'.format(data_dir)) as f: + source = json.load(f) + else: + raise ValueError(f'Unknown option {part}') + + logger.info("Loading total {} dialogs".format(len(source))) + for num_dial, dialog_info in enumerate(source): + hist = [] + dialog_file = dialog_info['file'] + dialog = dialog_info['info'] + + for turn_num, turn in enumerate(dialog): + + tokens = tokenizer.tokenize(turn['user']) + if len(hist) == 0: + if len(tokens) > max_seq_length - 2: + tokens = tokens[:max_seq_length - 2] + else: + tokens = hist + [Constants.SEP_WORD] + tokens + if len(tokens) > max_seq_length - 2: + tokens = tokens[-(max_seq_length - 2):] + + tokens = [Constants.CLS_WORD] + tokens + [Constants.SEP_WORD] + input_ids = tokenizer.convert_tokens_to_ids(tokens) + + padding = [Constants.PAD] * (max_seq_length - len(input_ids)) + input_ids += padding + + resp = [Constants.SOS_WORD] + tokenizer.tokenize(turn['sys']) + [Constants.EOS_WORD] + + if len(resp) > Constants.RESP_MAX_LEN: + resp = resp[:Constants.RESP_MAX_LEN-1] + [Constants.EOS_WORD] + else: + resp = resp + [Constants.PAD_WORD] * (Constants.RESP_MAX_LEN - len(resp)) + + resp_inp_ids = tokenizer.convert_tokens_to_ids(resp[:-1]) + resp_out_ids = tokenizer.convert_tokens_to_ids(resp[1:]) + + act_vecs = [0] * len(Constants.act_ontology) + for intent in 
turn['act']: + for values in turn['act'][intent]: + w = intent + '-' + values[0] + '-' + values[1] + act_vecs[Constants.act_ontology.index(w)] = 1 + + examples.append([input_ids, resp_inp_ids, resp_out_ids, act_vecs, dialog_file]) + + sys = tokenizer.tokenize(turn['sys']) + if turn_num == 0: + hist = tokens[1:-1] + [Constants.SEP_WORD] + sys + else: + hist = hist + [Constants.SEP_WORD] + tokens[1:-1] + [Constants.SEP_WORD] + sys + + all_input_ids = torch.tensor([f[0] for f in examples], dtype=torch.long) + all_response_in = torch.tensor([f[1] for f in examples], dtype=torch.long) + all_response_out = torch.tensor([f[2] for f in examples], dtype=torch.long) + all_act_vecs = torch.tensor([f[3] for f in examples], dtype=torch.float32) + all_files = [f[4] for f in examples] + + return all_input_ids, all_response_in, all_response_out, all_act_vecs, all_files + +def get_info(source, part): + result = {} + + for num_dial, dialog_info in enumerate(source): + dialog_file = dialog_info['file'] + dialog = dialog_info['info'] + result[dialog_file] = [] + + for turn_num, turn in enumerate(dialog): + result[dialog_file].append(turn[part]) + + return result diff --git a/convlab/modules/nlg/multiwoz/transformer/evaluator.py b/convlab/modules/nlg/multiwoz/transformer/evaluator.py new file mode 100644 index 0000000..9590a86 --- /dev/null +++ b/convlab/modules/nlg/multiwoz/transformer/evaluator.py @@ -0,0 +1,222 @@ +from convlab.modules.util.multiwoz.dbquery import query +import json +import os + +domains = ['restaurant', 'hotel', 'attraction', 'train', 'taxi', 'hospital', 'police'] +requestables = ['phone', 'address', 'postcode', 'reference', 'id'] + + +def queryResultVenues(domain, turn, real_belief=False): + if real_belief: + constraints = turn.items() + else: + constraints = turn['metadata'][domain]['semi'].items() + result = query(domain, constraints) + return result + +def issubset(A, B): + A = set(A) + B = set(B) + return A.issubset(B) + +def parseGoal(goal, d, domain): + """Parses user goal into dictionary format.""" + goal[domain] = {} + goal[domain] = {'informable': [], 'requestable': [], 'booking': []} + if 'info' in d['goal'][domain]: + if domain == 'train': + # we consider dialogues only where train had to be booked! 
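+                # a booked train makes 'reference' requestable; an explicit 'trainID' request adds 'id'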
+ if 'book' in d['goal'][domain]: + goal[domain]['requestable'].append('reference') + if 'reqt' in d['goal'][domain]: + if 'trainID' in d['goal'][domain]['reqt']: + goal[domain]['requestable'].append('id') + else: + if 'reqt' in d['goal'][domain]: + for s in d['goal'][domain]['reqt']: # addtional requests: + if s in ['phone', 'address', 'postcode', 'reference', 'id']: + # ones that can be easily delexicalized + goal[domain]['requestable'].append(s) + if 'book' in d['goal'][domain]: + goal[domain]['requestable'].append("reference") + + goal[domain]["informable"] = d['goal'][domain]['info'] + if 'book' in d['goal'][domain]: + goal[domain]["booking"] = d['goal'][domain]['book'] + + return goal + + +def evaluateModel(dialogues, mode='valid'): + """Gathers statistics for the whole sets.""" + data_dir = os.path.dirname(os.path.abspath(__file__)) + with open(os.path.join(data_dir, 'data/delex.json')) as fin: + delex_dialogues = json.load(fin) + successes, matches = 0, 0 + total = 0 + + for filename, dial in dialogues.items(): + if filename not in delex_dialogues: + filename += ".json" + + data = delex_dialogues[filename] + + success, match, _ = evaluateDialogue(dial, data) + + successes += success + matches += match + total += 1 + + # Print results + matches = matches / float(total) * 100 + successes = successes / float(total) * 100 + + print('Corpus Entity Matches : %2.2f%%' % (matches)) + print('Corpus Requestable Success : %2.2f%%' % (successes)) + + +def evaluateDialogue(dialog, realDialogue): + # get the list of domains in the goal + goal = {} + for domain in domains: + if realDialogue['goal'][domain]: + goal = parseGoal(goal, realDialogue, domain) + + real_requestables = {} + for domain in goal.keys(): + real_requestables[domain] = goal[domain]['requestable'] + + # CHECK IF MATCH HAPPENED + provided_requestables = {} + venue_offered = {} + + for domain in goal.keys(): + venue_offered[domain] = [] + provided_requestables[domain] = [] + + for t, sent_t in enumerate(dialog): + #sent_t = sent_t.replace("colleges", "[attaraction_name]") + #sent_t = sent_t.replace("college", "[attaraction_name]") + for domain in goal.keys(): + # Search for the only restaurant, hotel, attraction or train with an ID + if '[' + domain + '_name]' in sent_t or 'trainid]' in sent_t: + if domain in ['restaurant', 'hotel', 'attraction', 'train']: + # HERE YOU CAN PUT YOUR BELIEF STATE ESTIMATION + venues = queryResultVenues(domain, realDialogue['log'][t*2 + 1]) + # if venue has changed + if len(venue_offered[domain]) == 0 and venues: + venue_offered[domain] = venues#random.sample(venues, 1) + else: + flag = True + for ven in venue_offered[domain]: + if ven not in venues: + flag = False + break + if not flag and venues: # sometimes there are no results so sample won't work + venue_offered[domain] = venues + else: + venue_offered[domain] = '[' + domain + '_name]' + + # ATTENTION: assumption here - we didn't provide phone or address twice! etc + for requestable in requestables: + if requestable == 'reference': + if domain + '_reference' in sent_t: + + if 'restaurant_reference' in sent_t: + if realDialogue['log'][t * 2]['db_pointer'][-5] == 1: # if pointer was allowing for that? + provided_requestables[domain].append('reference') + + elif 'hotel_reference' in sent_t: + if realDialogue['log'][t * 2]['db_pointer'][-3] == 1: # if pointer was allowing for that? 
+ provided_requestables[domain].append('reference') + + elif 'train_reference' in sent_t: + if realDialogue['log'][t * 2]['db_pointer'][-1] == 1: # if pointer was allowing for that? + provided_requestables[domain].append('reference') + + else: + provided_requestables[domain].append('reference') + else: + if domain + '_' + requestable + ']' in sent_t: + provided_requestables[domain].append(requestable) + + # if name was given in the task + for domain in goal.keys(): + # if name was provided for the user, the match is being done automatically + if 'name' in goal[domain]['informable']: + venue_offered[domain] = '[' + domain + '_name]' + + # special domains - entity does not need to be provided + if domain in ['taxi', 'police', 'hospital']: + venue_offered[domain] = '[' + domain + '_name]' + + if domain == 'train': + if not venue_offered[domain]: + if goal[domain]['requestable'] and 'id' not in goal[domain]['requestable']: + venue_offered[domain] = '[' + domain + '_name]' + """ + Given all inform and requestable slots + we go through each domain from the user goal + and check whether right entity was provided and + all requestable slots were given to the user. + The dialogue is successful if that's the case for all domains. + """ + # HARD EVAL + stats = {'restaurant': [0, 0, 0], 'hotel': [0, 0, 0], 'attraction': [0, 0, 0], 'train': [0, 0,0], 'taxi': [0, 0, 0], + 'hospital': [0, 0, 0], 'police': [0, 0, 0]} + + match = 0 + success = 0 + # MATCH + for domain in goal.keys(): + match_stat = 0 + if domain in ['restaurant', 'hotel', 'attraction', 'train']: + if type(venue_offered[domain]) is str and '_name' in venue_offered[domain]: + match += 1 + match_stat = 1 + elif venue_offered[domain]: + groundtruth = queryResultVenues(domain, goal[domain]['informable'], real_belief=True) + if issubset(venue_offered[domain], groundtruth): + match += 1 + match_stat = 1 + else: + if '[' + domain + '_name]' in venue_offered[domain]: + match += 1 + match_stat = 1 + + stats[domain][0] = match_stat + stats[domain][2] = 1 + + if match == len(goal): + match = 1 + else: + match = 0 + + # SUCCESS + if match: + for domain in goal.keys(): + success_stat = 0 + domain_success = 0 + if len(real_requestables[domain]) == 0: + success += 1 + success_stat = 1 + stats[domain][1] = success_stat + continue + # if values in sentences are super set of requestables + for request in set(provided_requestables[domain]): + if request in real_requestables[domain]: + domain_success += 1 + + if domain_success >= len(real_requestables[domain]): + success += 1 + success_stat = 1 + + stats[domain][1] = success_stat + + if success >= len(real_requestables): + success = 1 + else: + success = 0 + + #print requests, 'DIFF', requests_real, 'SUCC', success + return success, match, stats diff --git a/convlab/modules/nlg/multiwoz/transformer/tools.py b/convlab/modules/nlg/multiwoz/transformer/tools.py new file mode 100644 index 0000000..c5adbe2 --- /dev/null +++ b/convlab/modules/nlg/multiwoz/transformer/tools.py @@ -0,0 +1,461 @@ +from convlab.modules.nlg.multiwoz.transformer.transformer import Constants +import json +import math +from collections import Counter +from nltk.util import ngrams +import numpy +import torch +import os + +def get_n_params(*params_list): + pp=0 + for params in params_list: + for p in params: + nn=1 + for s in list(p.size()): + nn = nn*s + pp += nn + return pp + +def filter_sents(sents, END): + hyps = [] + for batch_id in range(len(sents)): + done = False + for beam_id in range(len(sents[batch_id])): + sent = 
sents[batch_id][beam_id] + for s in sent[::-1]: + if s in [Constants.PAD, Constants.EOS]: + pass + elif s in END: + done = True + break + elif s not in END: + done = False + break + if done: + hyps.append(sent) + break + if len(hyps) < batch_id + 1: + hyps.append(sents[batch_id][0]) + return hyps + +def obtain_TP_TN_FN_FP(pred, act, TP, TN, FN, FP, elem_wise=False): + if isinstance(pred, torch.Tensor): + if elem_wise: + TP += ((pred.data == 1) & (act.data == 1)).sum(0) + TN += ((pred.data == 0) & (act.data == 0)).sum(0) + FN += ((pred.data == 0) & (act.data == 1)).sum(0) + FP += ((pred.data == 1) & (act.data == 0)).sum(0) + else: + TP += ((pred.data == 1) & (act.data == 1)).cpu().sum().item() + TN += ((pred.data == 0) & (act.data == 0)).cpu().sum().item() + FN += ((pred.data == 0) & (act.data == 1)).cpu().sum().item() + FP += ((pred.data == 1) & (act.data == 0)).cpu().sum().item() + return TP, TN, FN, FP + else: + TP += ((pred > 0).astype('long') & (act > 0).astype('long')).sum() + TN += ((pred == 0).astype('long') & (act == 0).astype('long')).sum() + FN += ((pred == 0).astype('long') & (act > 0).astype('long')).sum() + FP += ((pred > 0).astype('long') & (act == 0).astype('long')).sum() + return TP, TN, FN, FP + +class F1Scorer(object): + ## BLEU score calculator via GentScorer interface + ## it calculates the BLEU-4 by taking the entire corpus in + ## Calulate based multiple candidates against multiple references + def __init__(self): + pass + + def score(self, hypothesis, corpus, n=1): + # containers + data_dir = os.path.dirname(os.path.abspath(__file__)) + with open(os.path.join(data_dir, 'data/placeholder.json')) as f: + placeholder = json.load(f)['placeholder'] + + TP, TN, FN, FP = 0, 0, 0, 0 + # accumulate ngram statistics + files = hypothesis.keys() + for f in files: + hyps = hypothesis[f] + refs = corpus[f] + + hyps = [hyp.split() for hyp in hyps] + refs = [ref.split() for ref in refs] + # Shawn's evaluation + #refs[0] = [u'GO_'] + refs[0] + [u'EOS_'] + #hyps[0] = [u'GO_'] + hyps[0] + [u'EOS_'] + for hyp, ref in zip(hyps, refs): + pred = numpy.zeros((len(placeholder), ), 'float32') + gt = numpy.zeros((len(placeholder), ), 'float32') + for h in hyp: + if h in placeholder: + pred[placeholder.index(h)] += 1 + for r in ref: + if r in placeholder: + gt[placeholder.index(r)] += 1 + TP, TN, FN, FP = obtain_TP_TN_FN_FP(pred, gt, TP, TN, FN, FP) + + precision = TP / (TP + FP + 0.001) + recall = TP / (TP + FN + 0.001) + F1 = 2 * precision * recall / (precision + recall + 0.001) + return F1 + +def sentenceBLEU(hyps, refs, n=1): + count = [0, 0, 0, 0] + clip_count = [0, 0, 0, 0] + r = 0 + c = 0 + weights = [0.25, 0.25, 0.25, 0.25] + hyps = [hyp.split() for hyp in hyps] + refs = [ref.split() for ref in refs] + # Shawn's evaluation + refs[0] = [u'GO_'] + refs[0] + [u'EOS_'] + hyps[0] = [u'GO_'] + hyps[0] + [u'EOS_'] + for idx, hyp in enumerate(hyps): + for i in range(4): + # accumulate ngram counts + hypcnts = Counter(ngrams(hyp, i + 1)) + cnt = sum(hypcnts.values()) + count[i] += cnt + + # compute clipped counts + max_counts = {} + for ref in refs: + refcnts = Counter(ngrams(ref, i + 1)) + for ng in hypcnts: + max_counts[ng] = max(max_counts.get(ng, 0), refcnts[ng]) + clipcnt = dict((ng, min(count, max_counts[ng])) \ + for ng, count in hypcnts.items()) + clip_count[i] += sum(clipcnt.values()) + + # accumulate r & c + bestmatch = [1000, 1000] + for ref in refs: + if bestmatch[0] == 0: break + diff = abs(len(ref) - len(hyp)) + if diff < bestmatch[0]: + bestmatch[0] = diff + bestmatch[1] = 
len(ref) + r += bestmatch[1] + c += len(hyp) + if n == 1: + break + p0 = 1e-7 + bp = 1 if c > r else math.exp(1 - float(r) / float(c)) + p_ns = [float(clip_count[i]) / float(count[i] + p0) + p0 \ + for i in range(4)] + s = math.fsum(w * math.log(p_n) \ + for w, p_n in zip(weights, p_ns) if p_n) + bleu = bp * math.exp(s) + return bleu + +class BLEUScorer(object): + ## BLEU score calculator via GentScorer interface + ## it calculates the BLEU-4 by taking the entire corpus in + ## Calulate based multiple candidates against multiple references + def __init__(self): + pass + + def score(self, old_hypothesis, old_corpus, n=1): + file_names = old_hypothesis.keys() + hypothesis = [] + corpus = [] + for f in file_names: + old_h = old_hypothesis[f] + old_c = old_corpus[f] + for h, c in zip(old_h, old_c): + hypothesis.append([h]) + corpus.append([c]) + # containers + count = [0, 0, 0, 0] + clip_count = [0, 0, 0, 0] + r = 0 + c = 0 + weights = [0.25, 0.25, 0.25, 0.25] + # accumulate ngram statistics + for hyps, refs in zip(hypothesis, corpus): + hyps = [hyp.split() for hyp in hyps] + refs = [ref.split() for ref in refs] + # Shawn's evaluation + refs[0] = [u'GO_'] + refs[0] + [u'EOS_'] + hyps[0] = [u'GO_'] + hyps[0] + [u'EOS_'] + for idx, hyp in enumerate(hyps): + for i in range(4): + # accumulate ngram counts + hypcnts = Counter(ngrams(hyp, i + 1)) + cnt = sum(hypcnts.values()) + count[i] += cnt + + # compute clipped counts + max_counts = {} + for ref in refs: + refcnts = Counter(ngrams(ref, i + 1)) + for ng in hypcnts: + max_counts[ng] = max(max_counts.get(ng, 0), refcnts[ng]) + clipcnt = dict((ng, min(count, max_counts[ng])) \ + for ng, count in hypcnts.items()) + clip_count[i] += sum(clipcnt.values()) + + # accumulate r & c + bestmatch = [1000, 1000] + for ref in refs: + if bestmatch[0] == 0: break + diff = abs(len(ref) - len(hyp)) + if diff < bestmatch[0]: + bestmatch[0] = diff + bestmatch[1] = len(ref) + r += bestmatch[1] + c += len(hyp) + if n == 1: + break + # computing bleu score + p0 = 1e-7 + bp = 1 if c > r else math.exp(1 - float(r) / float(c)) + p_ns = [float(clip_count[i]) / float(count[i] + p0) + p0 \ + for i in range(4)] + s = math.fsum(w * math.log(p_n) \ + for w, p_n in zip(weights, p_ns) if p_n) + bleu = bp * math.exp(s) + return bleu + +class Tokenizer(object): + def __init__(self, vocab, ivocab, lower_case=True): + super(Tokenizer, self).__init__() + self.lower_case = lower_case + self.ivocab = ivocab + self.vocab = vocab + + self.vocab_len = len(self.vocab) + + def tokenize(self, sent): + if self.lower_case: + return sent.lower().split() + else: + return sent.split() + + def get_word_id(self, w, template=None): + if w in self.vocab: + return self.vocab[w] + else: + return self.vocab[Constants.UNK_WORD] + + + def get_word(self, k, template=None): + k = str(k) + return self.ivocab[k] + + def convert_tokens_to_ids(self, sent, template=None): + return [self.get_word_id(w, template) for w in sent] + + def convert_id_to_tokens(self, word_ids, template_ids=None, remain_eos=False): + if isinstance(word_ids, list): + if remain_eos: + return " ".join([self.get_word(wid, None) for wid in word_ids + if wid != Constants.PAD]) + else: + return " ".join([self.get_word(wid, None) for wid in word_ids + if wid not in [Constants.PAD, Constants.EOS] ]) + else: + if remain_eos: + return " ".join([self.get_word(wid.item(), None) for wid in word_ids + if wid != Constants.PAD]) + else: + return " ".join([self.get_word(wid.item(), None) for wid in word_ids + if wid not in [Constants.PAD, Constants.EOS]]) 
+ + def convert_template(self, template_ids): + return [self.get_word(wid) for wid in template_ids if wid != Constants.PAD] +"""" +def nondetokenize(d_p, d_r): + UNK = "xxxxxxx" + placeholder = json.load(open('data/placeholder.json')) + dialog_id = 0 + for dialog, gt_dialog in zip(d_p, d_r): + turn_id = 0 + for turn, gt_turn in zip(dialog, gt_dialog): + kb = gt_turn['KB'] + bs = gt_turn['BS'] + acts = gt_turn['act'] + ref = gt_turn['sys_orig'] + + def change_words(domain, word, acts, keys, act_keys, kb_cols): + for key, act_key, kb_col in zip(keys, act_keys, kb_cols): + if key in word: + + if "reference" in word: + for act_name in acts: + if domain in act_name and act_key in act_name and acts[act_name] != "?": + new = acts[act_name].lower() + return new + + if kb != "None" and kb_col in kb[0]: + new = kb[1][kb[0].index(kb_col)].lower() + return new + return None + + words = turn.split(' ') + for i in range(len(words)): + word = words[i] + if word in placeholder: + if "reference" in word: + done = False + for act_name in acts: + if "ref" in act_name: + words[i] = acts[act_name].lower() + done = True + break + if not done: + words[i]= UNK + else: + if "attraction" in word: + new = change_words("attraction", words[i], acts, ["address", "area", "name", "phone", "postcode", "pricerange"], + ["addr", "area", "name", "phone", "post", "price"], + ["address", "area", "name", "phone", "postcode", "pricerange"]) + if new: + words[i] = new + elif "hotel" in word: + new = change_words("hotel", words[i], acts, ["name", "phone", "address", "postcode", "pricerange", "area"], + ["name", "phone", "addr", "post", "price", "area"], + ["name", "phone", "address", "postcode", "pricerange", "area"]) + if new: + words[i] = new + elif "restaurant" in word: + new = change_words("restaurant", words[i], acts, ["name", "phone", "address", "postcode", "food", "pricerange", "area"], + ["name", "phone", "addr", "post", "food", "price", "area"], + ["name", "phone", "address", "postcode", "food", "pricerange", "area"]) + if new: + words[i] = new + elif "train" in word: + new = change_words("train", words[i], acts, ["trainid", "price"], ["id", "ticket"], ["trainID", "price"]) + if new: + words[i] = new + elif "police" in word: + new = change_words("police", words[i], acts, ["name", "phone", "address", "postcode"], + ["name", "phone", "addr", "post"], + ["name", "phone", "address", "postcode"]) + if new: + words[i] = new + elif "hospital" in word: + new = change_words("hospital", words[i], acts, ["name", "phone", "address", "postcode", "department", "name"], + ["name", "phone", "address", "postcode", "department", "name"], + ["name", "phone", "address", "postcode", "department", "name"]) + if new: + words[i] = new + elif "taxi" in word: + new = change_words("taxi", words[i], acts, ["phone", "type"], ["phone", "car"], ["phone", "type"]) + if new: + words[i] = new + elif "value_count" in word: + words[i] = "1" + + elif "value_time" in word: + words[i] = "1:00" + + elif "value_day" in word: + words[i] = "monday" + + elif "value_place" in word: + words[i] = "cambridge" + + new_words = " ".join(words) + d_p[dialog_id][turn_id] = new_words + turn_id += 1 + dialog_id += 1 +""" +def nondetokenize(d_p, d_r): + need_replace = 0 + success = 0 + for gt_dialog_info in d_r: + file_name = gt_dialog_info['file'] + gt_dialog = gt_dialog_info['info'] + for turn_id in range(len(d_p[file_name])): + act = gt_dialog[turn_id]['act'] + words = d_p[file_name][turn_id].split(' ') + counter = {} + for i in range(len(words)): + if "[" in words[i] and 
"]" in words[i]: + need_replace += 1. + domain, slot = words[i].split('_') + domain = domain[1:].capitalize() + slot = slot[:-1].capitalize() + key = '-'.join((domain, slot)) + flag = False + for intent in act: + _domain, _intent = intent.split('-') + if domain == _domain and _intent in ['Inform', 'Recommend', 'Offerbook']: + for values in act[intent]: + if (slot == values[0]) and ('none' != values[-1]) and ((key not in counter) or (counter[key] == int(values[1])-1)): + words[i] = values[-1] + counter[key] = int(values[1]) + flag = True + success += 1. + break + if flag: + break + d_p[file_name][turn_id] = " ".join(words) + success_rate = success / need_replace + return success_rate +""" +class Templator(object): + with open('data/placeholder.json') as f: + fields = json.load(f)['field'] + templates = {} + for f in fields: + if 'pricerange' in f: + templates[f] = "its price is {}".format(f) + elif 'type' in f: + templates[f] = "it is of {} type".format(f) + elif "address" in f: + templates[f] = "its address is {}".format(f) + elif "name" in f: + templates[f] = "its name is {}".format(f) + elif "postcode" in f: + templates[f] = "its postcode is {}".format(f) + elif "phone" in f: + templates[f] = "its phone number is {}".format(f) + elif "reference" in f: + templates[f] = "its reference is {}".format(f) + elif "area" in f: + templates[f] = "it is located in {}".format(f) + elif "arriveby" in f: + templates[f] = "it arrives by {}".format(f) + elif "departure" in f: + templates[f] = "it departs at {}".format(f) + elif "destination" in f: + templates[f] = "its destination is at {}".format(f) + elif "day" in f: + templates[f] = "it is at the time of {}".format(f) + elif "stars" in f: + templates[f] = "it has {} stars".format(f) + elif "department" in f: + templates[f] = "its department is {}".format(f) + elif "food" in f: + templates[f] = "it provides {} food".format(f) + elif "duration" in f: + templates[f] = "it takes {} long".format(f) + elif "leaveat" in f: + templates[f] = "it leaves at {}".format(f) + elif "trainid" in f: + templates[f] = "its train id is {}".format(f) + elif "price" in f: + templates[f] = "its price is {}".format(f) + elif "entrance" in f: + templates[f] = "its fee is {}".format(f) + elif "parking": + templates[f] = {"yes":"it has parking", "no":"it does not have parking"} + elif "internet": + templates[f] = {"no":"it has internet", "no":"it does not have internet"} + + @staticmethod + def source2tempalte(source): + string = "" + for k, v in source.items(): + if "_id]" not in k: + if k in Templator.templates: + if isinstance(Templator.templates[k], str): + string += Templator.templates[k] + " . " + else: + string += Templator.templates[k][v] + " . 
" + return string +""" + \ No newline at end of file diff --git a/convlab/modules/nlg/multiwoz/transformer/train.py b/convlab/modules/nlg/multiwoz/transformer/train.py new file mode 100644 index 0000000..55bab1b --- /dev/null +++ b/convlab/modules/nlg/multiwoz/transformer/train.py @@ -0,0 +1,191 @@ +import json +import torch +import logging +import os +import argparse +import time +from convlab.modules.nlg.multiwoz.transformer.transformer.Transformer import TransformerDecoder +from torch.optim.lr_scheduler import MultiStepLR +from convlab.modules.nlg.multiwoz.transformer.transformer import Constants +from convlab.modules.nlg.multiwoz.transformer.dataloader import get_batch, get_info +from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler +from convlab.modules.nlg.multiwoz.transformer.tools import Tokenizer, BLEUScorer, F1Scorer, nondetokenize +from collections import OrderedDict +from convlab.modules.nlg.multiwoz.transformer.evaluator import evaluateModel + +logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', + datefmt = '%m/%d/%Y %H:%M:%S', + level = logging.INFO) +logger = logging.getLogger(__name__) + +def parse_opt(): + parser = argparse.ArgumentParser() + parser.add_argument('--option', type=str, default="train", help="whether to train or test the model", choices=['train', 'test', 'postprocess']) + parser.add_argument('--emb_dim', type=int, default=128, help="the embedding dimension") + parser.add_argument('--dropout', type=float, default=0.2, help="the embedding dimension") + parser.add_argument('--resume', action='store_true', default=False, help="whether to resume previous run") + parser.add_argument('--batch_size', type=int, default=256, help="the embedding dimension") + parser.add_argument('--beam_size', type=int, default=2, help="the embedding dimension") + parser.add_argument('--layer_num', type=int, default=3, help="the embedding dimension") + parser.add_argument('--evaluate_every', type=int, default=5, help="the embedding dimension") + parser.add_argument('--head', type=int, default=4, help="the embedding dimension") + parser.add_argument("--output_dir", default="checkpoints/", type=str, \ + help="The output directory where the model predictions and checkpoints will be written.") + args = parser.parse_args() + return args + +args = parse_opt() +device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + +act_ontology = Constants.act_ontology +with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), "data/vocab.json"), 'r') as f: + vocabulary = json.load(f) +vocab, ivocab = vocabulary['vocab'], vocabulary['rev'] +tokenizer = Tokenizer(vocab, ivocab) + +logger.info("Loading Vocabulary of {} size".format(tokenizer.vocab_len)) +# Loading the dataset + +os.makedirs(args.output_dir, exist_ok=True) +checkpoint_file = os.path.join(args.output_dir, 'transformer') + + +if 'train' in args.option: + *train_examples, _ = get_batch('train', tokenizer) + train_data = TensorDataset(*train_examples) + train_sampler = RandomSampler(train_data) + train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.batch_size) + *val_examples, val_id = get_batch('val', tokenizer) + with open('data/val.json') as f: + dialogs = json.load(f) + gt_turns = get_info(dialogs, 'sys') +elif 'test' in args.option or 'postprocess' in args.option: + *val_examples, val_id = get_batch('test', tokenizer) + with open('data/test.json') as f: + dialogs = json.load(f) + gt_turns = get_info(dialogs, 
'sys') + gt_turns_nondelex = get_info(dialogs, 'sys_orig') +eval_data = TensorDataset(*val_examples) +eval_sampler = SequentialSampler(eval_data) +eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.batch_size) + +BLEU_calc = BLEUScorer() +F1_calc = F1Scorer() + +decoder = TransformerDecoder(vocab_size=tokenizer.vocab_len, d_word_vec=args.emb_dim, act_dim=len(Constants.act_ontology), + n_layers=args.layer_num, d_model=args.emb_dim, n_head=args.head, dropout=args.dropout) + +decoder.to(device) +loss_func = torch.nn.BCELoss() +loss_func.to(device) + +ce_loss_func = torch.nn.CrossEntropyLoss(ignore_index=Constants.PAD) +ce_loss_func.to(device) + +if args.option == 'train': + decoder.train() + if args.resume: + decoder.load_state_dict(torch.load(checkpoint_file)) + logger.info("Reloaing the encoder and decoder from {}".format(checkpoint_file)) + + logger.info("Start Training with {} batches".format(len(train_dataloader))) + + optimizer = torch.optim.Adam(filter(lambda x: x.requires_grad, decoder.parameters()), betas=(0.9, 0.98), eps=1e-09) + scheduler = MultiStepLR(optimizer, milestones=[50, 100, 150, 200], gamma=0.5) + + best_BLEU = 0 + for epoch in range(360): + for step, batch in enumerate(train_dataloader): + batch = tuple(t.to(device) for t in batch) + input_ids, rep_in, resp_out, act_vecs = batch + + decoder.zero_grad() + optimizer.zero_grad() + + logits = decoder(tgt_seq=rep_in, src_seq=input_ids, act_vecs=act_vecs) + + loss = ce_loss_func(logits.contiguous().view(logits.size(0) * logits.size(1), -1).contiguous(), \ + resp_out.contiguous().view(-1)) + + loss.backward() + optimizer.step() + + if step % 100 == 0: + logger.info("epoch {} step {} training loss {}".format(epoch, step, loss.item())) + + scheduler.step() + if loss.item() < 3.0 and epoch > 0 and epoch % args.evaluate_every == 0: + logger.info("start evaluating BLEU on validation set") + decoder.eval() + # Start Evaluating after each epoch + model_turns = {} + TP, TN, FN, FP = 0, 0, 0, 0 + for batch_step, batch in enumerate(eval_dataloader): + batch = tuple(t.to(device) for t in batch) + input_ids, rep_in, resp_out, act_vecs = batch + + hyps = decoder.translate_batch(act_vecs=act_vecs, \ + src_seq=input_ids, n_bm=args.beam_size, + max_token_seq_len=50) + + for hyp_step, hyp in enumerate(hyps): + pred = tokenizer.convert_id_to_tokens(hyp) + file_name = val_id[batch_step * args.batch_size + hyp_step] + if file_name not in model_turns: + model_turns[file_name] = [pred] + else: + model_turns[file_name].append(pred) + BLEU = BLEU_calc.score(model_turns, gt_turns) + + logger.info("{} epoch, Validation BLEU {} ".format(epoch, BLEU)) + if BLEU > best_BLEU: + torch.save(decoder.state_dict(), checkpoint_file) + best_BLEU = BLEU + decoder.train() + +elif args.option == "test": + decoder.load_state_dict(torch.load(checkpoint_file)) + logger.info("Loading model from {}".format(checkpoint_file)) + decoder.eval() + logger.info("Start Testing with {} batches".format(len(eval_dataloader))) + + model_turns = {} + act_turns = {} + step = 0 + start_time = time.time() + TP, TN, FN, FP = 0, 0, 0, 0 + for batch_step, batch in enumerate(eval_dataloader): + batch = tuple(t.to(device) for t in batch) + input_ids, rep_in, resp_out, act_vecs = batch + + hyps = decoder.translate_batch(act_vecs=act_vecs, src_seq=input_ids, + n_bm=args.beam_size, max_token_seq_len=50) + for hyp_step, hyp in enumerate(hyps): + pred = tokenizer.convert_id_to_tokens(hyp) + file_name = val_id[batch_step * args.batch_size + hyp_step] + if file_name not in 
model_turns: + model_turns[file_name] = [pred] + else: + model_turns[file_name].append(pred) + + logger.info("finished {}/{} used {} sec/per-sent".format(batch_step, len(eval_dataloader), \ + (time.time() - start_time) / args.batch_size)) + start_time = time.time() + + model_turns = OrderedDict(sorted(model_turns.items())) + + BLEU = BLEU_calc.score(model_turns, gt_turns) + entity_F1 = F1_calc.score(model_turns, gt_turns) + + logger.info("BLEU = {} EntityF1 = {}".format(BLEU, entity_F1)) + + # non delex + evaluateModel(model_turns) + + success_rate = nondetokenize(model_turns, dialogs) + BLEU = BLEU_calc.score(model_turns, gt_turns_nondelex) + + logger.info("Non delex BLEU {}, Success Rate {}".format(BLEU, success_rate)) + +else: + raise ValueError("No such option") diff --git a/convlab/modules/nlg/multiwoz/transformer/transformer/Beam.py b/convlab/modules/nlg/multiwoz/transformer/transformer/Beam.py new file mode 100644 index 0000000..b92d476 --- /dev/null +++ b/convlab/modules/nlg/multiwoz/transformer/transformer/Beam.py @@ -0,0 +1,116 @@ +""" Manage beam search info structure. + + Heavily borrowed from OpenNMT-py. + For code in OpenNMT-py, please check the following link: + https://github.com/OpenNMT/OpenNMT-py/blob/master/onmt/Beam.py +""" + +import torch +from convlab.modules.nlg.multiwoz.transformer.transformer import Constants + +class Beam(object): + ''' Beam search ''' + + def __init__(self, size, device=False): + + self.size = size + self._done = False + + # The score for each translation on the beam. + self.scores = torch.zeros((size,), dtype=torch.float, device=device) + self.all_scores = [] + + # The backpointers at each time-step. + self.prev_ks = [] + + # The outputs at each time-step. + self.next_ys = [torch.full((size,), Constants.PAD, dtype=torch.long, device=device)] + self.next_ys[0][0] = Constants.SOS + self.finished = [False for _ in range(size)] + + def get_current_state(self): + "Get the outputs for the current timestep." + return self.get_tentative_hypothesis() + + def get_current_origin(self): + "Get the backpointers for the current timestep." + return self.prev_ks[-1] + + @property + def done(self): + return self._done + + def advance(self, word_prob): + "Update beam status and check if finished or not." + num_words = word_prob.size(1) + + for i in range(self.size): + if self.finished[i]: + word_prob[i, :].fill_(-1000) + word_prob[i, Constants.PAD].fill_(0) + #import pdb + #pdb.set_trace() + # Sum the previous scores. + if len(self.prev_ks) > 0: + beam_lk = word_prob + self.scores.unsqueeze(1).expand_as(word_prob) + else: + beam_lk = word_prob[0] + + flat_beam_lk = beam_lk.view(-1) + + #best_scores, best_scores_id = flat_beam_lk.topk(self.size, 0, True, True) # 1st sort + best_scores, best_scores_id = flat_beam_lk.topk(self.size, 0, True, True) # 2nd sort + + self.all_scores.append(self.scores) + self.scores = best_scores + + # bestScoresId is flattened as a (beam x word) array, + # so we need to calculate which word and beam each score came from + prev_k = best_scores_id / num_words + self.prev_ks.append(prev_k) + self.next_ys.append(best_scores_id - prev_k * num_words) + + # End condition is when top-of-beam is EOS. 
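+        # note: the top-of-beam EOS check below is disabled in this variant; the beam is
+        # considered done only once every hypothesis has emitted EOS (or PAD), as tracked
+        # in self.finished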
+ #if self.next_ys[-1][0].item() == Constants.EOS: + # self._done = True + # self.all_scores.append(self.scores) + self.finished = [] + for i in range(self.size): + self.finished.append(self.next_ys[-1][i].item() in [Constants.EOS, Constants.PAD]) + + if all(self.finished): + self._done = True + #self._done = self.finished[0] + + return self._done + + def sort_scores(self): + "Sort the scores." + return torch.sort(self.scores, 0, True) + + def get_the_best_score_and_idx(self): + "Get the score of the best in the beam." + scores, ids = self.sort_scores() + return scores[1], ids[1] + + def get_tentative_hypothesis(self): + "Get the decoded sequence for the current timestep." + + if len(self.next_ys) == 1: + dec_seq = self.next_ys[0].unsqueeze(1) + else: + _, keys = self.sort_scores() + hyps = [self.get_hypothesis(k) for k in keys] + hyps = [[Constants.SOS] + h for h in hyps] + dec_seq = torch.LongTensor(hyps) + + return dec_seq + + def get_hypothesis(self, k): + """ Walk back to construct the full hypothesis. """ + hyp = [] + for j in range(len(self.prev_ks) - 1, -1, -1): + hyp.append(self.next_ys[j+1][k]) + k = self.prev_ks[j][k] + + return list(map(lambda x: x.item(), hyp[::-1])) diff --git a/convlab/modules/nlg/multiwoz/transformer/transformer/Constants.py b/convlab/modules/nlg/multiwoz/transformer/transformer/Constants.py new file mode 100644 index 0000000..4d02a36 --- /dev/null +++ b/convlab/modules/nlg/multiwoz/transformer/transformer/Constants.py @@ -0,0 +1,57 @@ +import os +import numpy + +PAD = 0 +EOS = 1 +SOS = 2 +UNK = 3 +CLS = 4 +SEP = 5 + +PAD_WORD = '[PAD]' +EOS_WORD = '[EOS]' +SOS_WORD = '[SOS]' +UNK_WORD = '[UNK]' +CLS_WORD = '[CLS]' +SEP_WORD = '[SEP]' + +TEMPLATE_MAX_LEN = 50 +RESP_MAX_LEN = 40 +MAX_SEGMENT = 20 +T = 1 + +def append_or_add(dictionary, name, key): + if name in dictionary: + dictionary[name].append(key) + else: + dictionary[name] = [key] + +domains = ['restaurant', 'hotel', 'attraction', 'train', 'taxi', 'hospital', 'police', 'bus', 'booking', 'general'] +functions = ['inform', 'request', 'recommend', 'book', 'select', 'sorry', 'none'] +arguments = ['pricerange', 'id', 'address', 'postcode', 'type', 'food', 'phone', 'name', 'area', 'choice', + 'price', 'time', 'reference', 'none', 'parking', 'stars', 'internet', 'day', 'arriveby', 'departure', + 'destination', 'leaveat', 'duration', 'trainid', 'people', 'department', 'stay'] + +used_levels = domains + functions + arguments +#used_levels = functions + arguments +act_len = len(used_levels) +def act_to_vectors(acts): + r = numpy.zeros((act_len, ), 'float32') + for act in acts: + p1, p2, p3 = act.split('-') + if len(used_levels) == len(domains + functions + arguments): + r[domains.index(p1)] + r[len(domains) + functions.index(p2)] += 1 + r[len(domains) + len(functions) + arguments.index(p3)] += 1 + else: + r[functions.index(p2)] += 1 + r[len(functions) + arguments.index(p3)] += 1 + return (r > 0).astype('float32') + +id_to_acts = {} +for i, name in enumerate(used_levels): + id_to_acts[i] = name + +with open(os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'data/template.txt'), 'r') as f: + act_ontology = f.readlines() + act_ontology = list(map(lambda x:x.strip(), act_ontology)) diff --git a/convlab/modules/nlg/multiwoz/transformer/transformer/Transformer.py b/convlab/modules/nlg/multiwoz/transformer/transformer/Transformer.py new file mode 100644 index 0000000..a9a4352 --- /dev/null +++ b/convlab/modules/nlg/multiwoz/transformer/transformer/Transformer.py @@ -0,0 +1,505 @@ +import torch 
+import torch.nn as nn +import numpy as np +import torch.nn.functional as F +import math +from convlab.modules.nlg.multiwoz.transformer.transformer.Beam import Beam +from convlab.modules.nlg.multiwoz.transformer.transformer import Constants + +class PositionalEmbedding(nn.Module): + + def __init__(self, d_model, max_len=512): + super(PositionalEmbedding, self).__init__() + + # Compute the positional encodings once in log space. + pe = torch.zeros(max_len, d_model).float() + pe.require_grad = False + + position = torch.arange(0, max_len).float().unsqueeze(1) + div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp() + + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + + pe = pe.unsqueeze(0) + self.register_buffer('pe', pe) + + def forward(self, x): + return self.pe[:, :x.size(1)] + +class EncoderLayer(nn.Module): + ''' Compose with two layers ''' + + def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1): + super(EncoderLayer, self).__init__() + self.slf_attn = MultiHeadAttention( + n_head, d_model, d_k, d_v, dropout=dropout) + self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout) + + def forward(self, enc_input, non_pad_mask=None, slf_attn_mask=None): + enc_output, enc_slf_attn = self.slf_attn(enc_input, enc_input, enc_input, mask=slf_attn_mask) + enc_output *= non_pad_mask + + enc_output = self.pos_ffn(enc_output) + enc_output *= non_pad_mask + + return enc_output, enc_slf_attn + +class AverageHeadAttention(nn.Module): + def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1): + super(AverageHeadAttention, self).__init__() + + self.n_head = n_head + self.d_k = d_k + self.d_v = d_v + + self.w_qs = nn.Linear(d_model, n_head * d_k) + self.w_ks = nn.Linear(d_model, n_head * d_k) + self.w_vs = nn.Linear(d_model, n_head * d_v) + nn.init.normal_(self.w_qs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k))) + nn.init.normal_(self.w_ks.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k))) + nn.init.normal_(self.w_vs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_v))) + + self.attention = ScaledDotProductAttention(temperature=np.power(d_k, 0.5)) + self.layer_norm = nn.LayerNorm(d_model) + + self.fc = nn.Linear(d_v, d_model) + nn.init.xavier_normal_(self.fc.weight) + + self.dropout = nn.Dropout(dropout) + + + def forward(self, a, q, k, v, mask=None): + + d_k, d_v, n_head = self.d_k, self.d_v, self.n_head + + sz_b, len_q, _ = q.size() + sz_b, len_k, _ = k.size() + sz_b, len_v, _ = v.size() + residual = q + + q = self.w_qs(q).view(sz_b, len_q, n_head, d_k) + k = self.w_ks(k).view(sz_b, len_k, n_head, d_k) + v = self.w_vs(v).view(sz_b, len_v, n_head, d_v) + + q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_q, d_k) # (n*b) x lq x dk + k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_k, d_k) # (n*b) x lk x dk + v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v, d_v) # (n*b) x lv x dv + + mask = mask.repeat(n_head, 1, 1) # (n*b) x .. x .. 
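+        # unlike standard multi-head attention, the per-head outputs are not concatenated:
+        # they are weighted by the act distribution `a` and summed over the head dimension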
+ output, attn = self.attention(q, k, v, mask=mask) + + output = output.view(n_head, sz_b, len_q, d_v) + a = a.permute(1, 0).contiguous()[:, :, None, None] + + #output = output * a + output = torch.sum(output * a, 0) + output = output.view(sz_b, len_q, -1) + #output = output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_q, -1) # b x lq x (n*dv) + + output = self.dropout(self.fc(output)) + output = self.layer_norm(output + residual) + + return output, attn + +class MultiHeadAttention(nn.Module): + ''' Multi-Head Attention module ''' + + def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1): + super(MultiHeadAttention, self).__init__() + + self.n_head = n_head + self.d_k = d_k + self.d_v = d_v + + self.w_qs = nn.Linear(d_model, n_head * d_k) + self.w_ks = nn.Linear(d_model, n_head * d_k) + self.w_vs = nn.Linear(d_model, n_head * d_v) + nn.init.normal_(self.w_qs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k))) + nn.init.normal_(self.w_ks.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k))) + nn.init.normal_(self.w_vs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_v))) + + self.attention = ScaledDotProductAttention(temperature=np.power(d_k, 0.5)) + self.layer_norm = nn.LayerNorm(d_model) + + self.fc = nn.Linear(n_head * d_v, d_model) + nn.init.xavier_normal_(self.fc.weight) + + self.dropout = nn.Dropout(dropout) + + + def forward(self, q, k, v, mask=None): + + d_k, d_v, n_head = self.d_k, self.d_v, self.n_head + + sz_b, len_q, _ = q.size() + sz_b, len_k, _ = k.size() + sz_b, len_v, _ = v.size() + + residual = q + + q = self.w_qs(q).view(sz_b, len_q, n_head, d_k) + k = self.w_ks(k).view(sz_b, len_k, n_head, d_k) + v = self.w_vs(v).view(sz_b, len_v, n_head, d_v) + + q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_q, d_k) # (n*b) x lq x dk + k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_k, d_k) # (n*b) x lk x dk + v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v, d_v) # (n*b) x lv x dv + + mask = mask.repeat(n_head, 1, 1) # (n*b) x .. x .. 
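+        # standard multi-head attention: head outputs are concatenated and projected back to d_model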
+ output, attn = self.attention(q, k, v, mask=mask) + + output = output.view(n_head, sz_b, len_q, d_v) + output = output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_q, -1) # b x lq x (n*dv) + + output = self.dropout(self.fc(output)) + output = self.layer_norm(output + residual) + + return output, attn + +class PositionwiseFeedForward(nn.Module): + ''' A two-feed-forward-layer module ''' + + def __init__(self, d_in, d_hid, dropout=0.1): + super(PositionwiseFeedForward, self).__init__() + self.w_1 = nn.Conv1d(d_in, d_hid, 1) # position-wise + self.w_2 = nn.Conv1d(d_hid, d_in, 1) # position-wise + self.layer_norm = nn.LayerNorm(d_in) + self.dropout = nn.Dropout(dropout) + + def forward(self, x): + residual = x + output = x.transpose(1, 2) + output = self.w_2(F.relu(self.w_1(output))) + output = output.transpose(1, 2) + output = self.dropout(output) + output = self.layer_norm(output + residual) + return output + +class ScaledDotProductAttention(nn.Module): + ''' Scaled Dot-Product Attention ''' + + def __init__(self, temperature, attn_dropout=0.1): + super(ScaledDotProductAttention, self).__init__() + self.temperature = temperature + self.dropout = nn.Dropout(attn_dropout) + self.softmax = nn.Softmax(dim=2) + + def forward(self, q, k, v, mask=None): + + attn = torch.bmm(q, k.transpose(1, 2)) + attn = attn / self.temperature + + if mask is not None: + attn = attn.masked_fill(mask, -np.inf) + + attn = self.softmax(attn) + attn = self.dropout(attn) + output = torch.bmm(attn, v) + + return output, attn + +def get_non_pad_mask(seq): + assert seq.dim() == 2 + return seq.ne(Constants.PAD).type(torch.float).unsqueeze(-1) + +def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None): + ''' Sinusoid position encoding table ''' + + def cal_angle(position, hid_idx): + return position / np.power(10000, 2 * (hid_idx // 2) / d_hid) + + def get_posi_angle_vec(position): + return [cal_angle(position, hid_j) for hid_j in range(d_hid)] + + sinusoid_table = np.array([get_posi_angle_vec(pos_i) for pos_i in range(n_position)]) + + sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i + sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 + + if padding_idx is not None: + # zero vector for padding dimension + sinusoid_table[padding_idx] = 0. + + return torch.FloatTensor(sinusoid_table) + +def get_attn_key_pad_mask(seq_k, seq_q): + ''' For masking out the padding part of key sequence. ''' + + # Expand to fit the shape of key query attention matrix. + len_q = seq_q.size(1) + padding_mask = seq_k.eq(Constants.PAD) + padding_mask = padding_mask.unsqueeze(1).expand(-1, len_q, -1) # b x lq x lk + + return padding_mask + +def get_subsequent_mask(seq): + ''' For masking out the subsequent info. ''' + + sz_b, len_s = seq.size() + subsequent_mask = torch.triu( + torch.ones((len_s, len_s), device=seq.device, dtype=torch.uint8), diagonal=1) + subsequent_mask = subsequent_mask.unsqueeze(0).expand(sz_b, -1, -1) # b x ls x ls + + return subsequent_mask + +class Transformer(nn.Module): + ''' A encoder model with self attention mechanism. 
''' + + def __init__(self, n_src_vocab, len_max_seq, d_word_vec, + n_layers, n_head, d_k, d_v, + d_model, d_inner, embedding, dropout=0.1): + + super(Transformer, self).__init__() + + n_position = len_max_seq + 1 + + self.src_word_emb = nn.Embedding(n_src_vocab, d_word_vec, padding_idx=Constants.PAD) + #self.src_word_emb = nn.Embedding.from_pretrained(embedding, freeze=False) + + self.position_enc = nn.Embedding.from_pretrained( + get_sinusoid_encoding_table(n_position, d_word_vec, padding_idx=0), + freeze=True) + + self.layer_stack = nn.ModuleList([ + EncoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout) + for _ in range(n_layers)]) + + def forward(self, src_seq, src_pos, act_vocab_id): + # -- Prepare masks + slf_attn_mask = get_attn_key_pad_mask(seq_k=src_seq, seq_q=src_seq) + non_pad_mask = get_non_pad_mask(src_seq) + + # -- Forward Word Embedding + enc_output = self.src_word_emb(src_seq) + self.position_enc(src_pos) + # -- Forward Ontology Embedding + ontology_embedding = self.src_word_emb(act_vocab_id) + + for enc_layer in self.layer_stack: + enc_output, enc_slf_attn = enc_layer( + enc_output, + non_pad_mask=non_pad_mask, + slf_attn_mask=slf_attn_mask) + + dot_prod = torch.sum(enc_output[:, :, None, :] * ontology_embedding[None, None, :, :], -1) + #index = length[:, None, None].repeat(1, 1, dot_prod.size(-1)) + #pooled_dot_prod = dot_prod.gather(1, index).squeeze() + pooled_dot_prod = dot_prod[:, 0, :] + pooling_likelihood = torch.sigmoid(pooled_dot_prod) + return pooling_likelihood, enc_output + +class AvgDecoderLayer(nn.Module): + ''' Compose with three layers ''' + + def __init__(self, d_model, d_inner, n_head, d_k, d_v, n_head_enc, dropout=0.1): + super(AvgDecoderLayer, self).__init__() + self.slf_attn = AverageHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout) + self.enc_attn = MultiHeadAttention(n_head_enc, d_model, d_model // n_head_enc, d_model // n_head_enc, dropout=dropout) + self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout) + + def forward(self, act_vecs, dec_input, enc_output, non_pad_mask=None, slf_attn_mask=None, dec_enc_attn_mask=None): + dec_output, dec_slf_attn = self.slf_attn(act_vecs, dec_input, dec_input, dec_input, mask=slf_attn_mask) + dec_output *= non_pad_mask + dec_output, dec_enc_attn = self.enc_attn(dec_output, enc_output, enc_output, mask=dec_enc_attn_mask) + dec_output *= non_pad_mask + + dec_output = self.pos_ffn(dec_output) + dec_output *= non_pad_mask + + return dec_output, dec_slf_attn, None + +class DecoderLayer(nn.Module): + ''' Compose with three layers ''' + + def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1): + super(DecoderLayer, self).__init__() + self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout) + self.enc_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout) + self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout) + + def forward(self, dec_input, enc_output, non_pad_mask=None, slf_attn_mask=None, dec_enc_attn_mask=None): + dec_output, dec_slf_attn = self.slf_attn( + dec_input, dec_input, dec_input, mask=slf_attn_mask) + dec_output *= non_pad_mask + + dec_output, dec_enc_attn = self.enc_attn( + dec_output, enc_output, enc_output, mask=dec_enc_attn_mask) + dec_output *= non_pad_mask + + dec_output = self.pos_ffn(dec_output) + dec_output *= non_pad_mask + + return dec_output, dec_slf_attn, dec_enc_attn + +class TransformerDecoder(nn.Module): + ''' A decoder model with self attention mechanism. 
''' + + def __init__(self, vocab_size, d_word_vec, n_layers, d_model, n_head, act_dim, dropout=0.1): + + super(TransformerDecoder, self).__init__() + d_k = d_model // n_head + d_v = d_model // n_head + d_inner = d_model * 4 + + self.tgt_word_emb = nn.Embedding(vocab_size, d_word_vec, padding_idx=Constants.PAD) + self.act_word_emb = nn.Linear(act_dim, d_word_vec, bias=False) + + self.post_word_emb = PositionalEmbedding(d_model=d_word_vec) + + self.enc_layer_stack = nn.ModuleList([ + EncoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout) + for _ in range(n_layers)]) + + self.layer_stack = nn.ModuleList([ + DecoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout) + for _ in range(n_layers)]) + + self.tgt_word_prj = nn.Linear(d_model, vocab_size, bias=False) + + + def forward(self, tgt_seq, src_seq, act_vecs): + # -- Encode source + non_pad_mask = get_non_pad_mask(src_seq) + slf_attn_mask = get_attn_key_pad_mask(seq_k=src_seq, seq_q=src_seq) + enc_inp = self.tgt_word_emb(src_seq) + self.post_word_emb(src_seq) + + for layer in self.enc_layer_stack: + enc_inp, _ = layer(enc_inp, non_pad_mask, slf_attn_mask) + enc_output = enc_inp + + # -- Prepare masks + non_pad_mask = get_non_pad_mask(tgt_seq) + + slf_attn_mask_subseq = get_subsequent_mask(tgt_seq) + slf_attn_mask_keypad = get_attn_key_pad_mask(seq_k=tgt_seq, seq_q=tgt_seq) + slf_attn_mask = (slf_attn_mask_keypad + slf_attn_mask_subseq).gt(0) + dec_enc_attn_mask = get_attn_key_pad_mask(seq_k=src_seq, seq_q=tgt_seq) + + # -- Forward + dec_output = self.tgt_word_emb(tgt_seq) + self.post_word_emb(tgt_seq) + self.act_word_emb(act_vecs)[:, None, :] + + for dec_layer in self.layer_stack: + dec_output, dec_slf_attn, dec_enc_attn = dec_layer( + dec_output, enc_output, + non_pad_mask=non_pad_mask, + slf_attn_mask=slf_attn_mask, + dec_enc_attn_mask=dec_enc_attn_mask) + + logits = self.tgt_word_prj(dec_output) + return logits + + def translate_batch(self, act_vecs, src_seq, n_bm, max_token_seq_len=30): + ''' Translation work in one batch ''' + device = src_seq.device + def collate_active_info(act_vecs, src_seq, inst_idx_to_position_map, active_inst_idx_list): + # Sentences which are still active are collected, + # so the decoder will not run on completed sentences. 
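+            # map the surviving instance ids to their rows in the beam-expanded tensors
+            # and index-select only those rows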
+ n_prev_active_inst = len(inst_idx_to_position_map) + active_inst_idx = [inst_idx_to_position_map[k] for k in active_inst_idx_list] + active_inst_idx = torch.LongTensor(active_inst_idx).to(device) + + active_src_seq = collect_active_part(src_seq, active_inst_idx, n_prev_active_inst, n_bm) + active_act_vecs = collect_active_part(act_vecs, active_inst_idx, n_prev_active_inst, n_bm) + #active_template_output = collect_active_part(template_output, active_inst_idx, n_prev_active_inst, n_bm) + + active_inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(active_inst_idx_list) + + return active_act_vecs, active_src_seq, active_inst_idx_to_position_map + + def beam_decode_step(inst_dec_beams, len_dec_seq, active_inst_idx_list, act_vecs, src_seq, \ + inst_idx_to_position_map, n_bm): + ''' Decode and update beam status, and then return active beam idx ''' + n_active_inst = len(inst_idx_to_position_map) + + #dec_partial_seq = [b.get_current_state() for b in inst_dec_beams if not b.done] + dec_partial_seq = [inst_dec_beams[idx].get_current_state() + for idx in active_inst_idx_list if not inst_dec_beams[idx].done] + dec_partial_seq = torch.stack(dec_partial_seq).to(device) + dec_partial_seq = dec_partial_seq.view(-1, len_dec_seq) + + logits = self.forward(dec_partial_seq, src_seq, act_vecs)[:, -1, :] / Constants.T + word_prob = F.log_softmax(logits, dim=1) + word_prob = word_prob.view(n_active_inst, n_bm, -1) + + # Update the beam with predicted word prob information and collect incomplete instances + active_inst_idx_list = [] + for inst_idx, inst_position in inst_idx_to_position_map.items(): + is_inst_complete = inst_dec_beams[inst_idx].advance(word_prob[inst_position]) + if not is_inst_complete: + active_inst_idx_list += [inst_idx] + + return active_inst_idx_list + with torch.no_grad(): + #-- Repeat data for beam search + n_inst, len_s = src_seq.size() + src_seq = src_seq.repeat(1, n_bm).view(n_inst * n_bm, len_s) + act_vecs = act_vecs.repeat(1, n_bm).view(n_inst * n_bm, -1) + + #-- Prepare beams + inst_dec_beams = [Beam(n_bm, device=device) for _ in range(n_inst)] + + #-- Bookkeeping for active or not + active_inst_idx_list = list(range(n_inst)) + inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(active_inst_idx_list) + + #-- Decode + for len_dec_seq in range(1, max_token_seq_len + 1): + active_inst_idx_list = beam_decode_step(inst_dec_beams, len_dec_seq, active_inst_idx_list, + act_vecs, src_seq, inst_idx_to_position_map, n_bm) + + if not active_inst_idx_list: + break # all instances have finished their path to + + act_vecs, src_seq, inst_idx_to_position_map = collate_active_info( + act_vecs, src_seq, inst_idx_to_position_map, active_inst_idx_list) + + def collect_hypothesis_and_scores(inst_dec_beams, n_best): + all_hyp, all_scores = [], [] + for beam in inst_dec_beams: + scores = beam.scores + hyps = np.array([beam.get_hypothesis(i) for i in range(beam.size)], 'long') + lengths = (hyps != Constants.PAD).sum(-1) + normed_scores = [scores[i].item()/lengths[i] for i, hyp in enumerate(hyps)] + idxs = np.argsort(normed_scores)[::-1] + + all_hyp.append([hyps[idx] for idx in idxs]) + all_scores.append([normed_scores[idx] for idx in idxs]) + """ + for inst_idx in range(len(inst_dec_beams)): + scores, tail_idxs = inst_dec_beams[inst_idx].sort_scores() + all_scores += [scores[:n_best]] + + hyps = [inst_dec_beams[inst_idx].get_hypothesis(i) for i in tail_idxs[:n_best]] + all_hyp += [hyps] + """ + return all_hyp, all_scores + + batch_hyp, batch_scores = 
collect_hypothesis_and_scores(inst_dec_beams, n_bm) + + result = [] + for _ in batch_hyp: + finished = False + for r in _: + if len(r) >= 8 and len(r) < 40: + result.append(r) + finished = True + break + if not finished: + result.append(_[0]) + return result + +def get_inst_idx_to_tensor_position_map(inst_idx_list): + ''' Indicate the position of an instance in a tensor. ''' + return {inst_idx: tensor_position for tensor_position, inst_idx in enumerate(inst_idx_list)} + +def collect_active_part(beamed_tensor, curr_active_inst_idx, n_prev_active_inst, n_bm): + ''' Collect tensor parts associated to active instances. ''' + _, *d_hs = beamed_tensor.size() + n_curr_active_inst = len(curr_active_inst_idx) + new_shape = (n_curr_active_inst * n_bm, *d_hs) + + beamed_tensor = beamed_tensor.view(n_prev_active_inst, -1) + beamed_tensor = beamed_tensor.index_select(0, curr_active_inst_idx) + beamed_tensor = beamed_tensor.view(*new_shape) + + return beamed_tensor diff --git a/convlab/modules/nlg/multiwoz/transformer/transformer_nlg.py b/convlab/modules/nlg/multiwoz/transformer/transformer_nlg.py new file mode 100644 index 0000000..c6dfd92 --- /dev/null +++ b/convlab/modules/nlg/multiwoz/transformer/transformer_nlg.py @@ -0,0 +1,270 @@ +# -*- coding: utf-8 -*- + +import re +import os +import zipfile +import json +import torch +import pickle +from copy import deepcopy +from convlab.lib.file_util import cached_path +from convlab.modules.nlg.nlg import NLG +from convlab.modules.word_policy.multiwoz.hdsa.tools import Tokenizer +from convlab.modules.word_policy.multiwoz.hdsa.transformer import Constants +from convlab.modules.word_policy.multiwoz.hdsa.transformer.Transformer import TransformerDecoder + +timepat = re.compile("\d{1,2}[:]\d{1,2}") +pricepat = re.compile("\d{1,3}[.]\d{1,2}") + +DEFAULT_DIRECTORY = "models" +DEFAULT_ARCHIVE_FILE = os.path.join(DEFAULT_DIRECTORY, "transformer.zip") + +def insertSpace(token, text): + sidx = 0 + while True: + sidx = text.find(token, sidx) + if sidx == -1: + break + if sidx + 1 < len(text) and re.match('[0-9]', text[sidx - 1]) and \ + re.match('[0-9]', text[sidx + 1]): + sidx += 1 + continue + if text[sidx - 1] != ' ': + text = text[:sidx] + ' ' + text[sidx:] + sidx += 1 + if sidx + len(token) < len(text) and text[sidx + len(token)] != ' ': + text = text[:sidx + 1] + ' ' + text[sidx + 1:] + sidx += 1 + return text + +def normalize(text, sub=True): + # lower case every word + text = text.lower() + + # replace white spaces in front and end + text = re.sub(r'^\s*|\s*$', '', text) + + # hotel domain pfb30 + text = re.sub(r"b&b", "bed and breakfast", text) + text = re.sub(r"b and b", "bed and breakfast", text) + + # normalize phone number + ms = re.findall('\(?(\d{3})\)?[-.\s]?(\d{3})[-.\s]?(\d{4,5})', text) + if ms: + sidx = 0 + for m in ms: + sidx = text.find(m[0], sidx) + if text[sidx - 1] == '(': + sidx -= 1 + eidx = text.find(m[-1], sidx) + len(m[-1]) + text = text.replace(text[sidx:eidx], ''.join(m)) + + # normalize postcode + ms = re.findall('([a-z]{1}[\. ]?[a-z]{1}[\. ]?\d{1,2}[, ]+\d{1}[\. ]?[a-z]{1}[\. ]?[a-z]{1}|[a-z]{2}\d{2}[a-z]{2})', + text) + if ms: + sidx = 0 + for m in ms: + sidx = text.find(m, sidx) + eidx = sidx + len(m) + text = text[:sidx] + re.sub('[,\. 
]', '', m) + text[eidx:] + + # weird unicode bug + text = re.sub(u"(\u2018|\u2019)", "'", text) + + # replace time and and price + if sub: + text = re.sub(timepat, ' [value_time] ', text) + text = re.sub(pricepat, ' [train_price] ', text) + #text = re.sub(pricepat2, '[value_price]', text) + + # replace st. + text = text.replace(';', ',') + text = re.sub('$\/', '', text) + text = text.replace('/', ' and ') + + # replace other special characters + text = text.replace('-', ' ') + text = re.sub('[\":\<>@\(\)]', '', text) + + # insert white space before and after tokens: + for token in ['?', '.', ',', '!']: + text = insertSpace(token, text) + + # insert white space for 's + text = insertSpace('\'s', text) + + # replace it's, does't, you'd ... etc + text = re.sub('^\'', '', text) + text = re.sub('\'$', '', text) + text = re.sub('\'\s', ' ', text) + text = re.sub('\s\'', ' ', text) + + # remove multiple spaces + text = re.sub(' +', ' ', text) + + # concatenate numbers + tokens = text.split() + i = 1 + while i < len(tokens): + if re.match(u'^\d+$', tokens[i]) and \ + re.match(u'\d+$', tokens[i - 1]): + tokens[i - 1] += tokens[i] + del tokens[i] + else: + i += 1 + text = ' '.join(tokens) + + return text + + +def delexicalise(utt, dictionary): + for key, val in dictionary: + utt = (' ' + utt + ' ').replace(' ' + key + ' ', ' ' + val + ' ') + utt = utt[1:-1] # why this? + + return utt + +def delexicaliseReferenceNumber(sent, turn): + """Based on the belief state, we can find reference number that + during data gathering was created randomly.""" + for domain in turn: + if turn[domain]['book']['booked']: + for slot in turn[domain]['book']['booked'][0]: + if slot == 'reference': + val = '[' + domain + '_' + slot + ']' + else: + val = '[' + domain + '_' + slot + ']' + key = normalize(turn[domain]['book']['booked'][0][slot]) + sent = (' ' + sent + ' ').replace(' ' + key + ' ', ' ' + val + ' ') + + # try reference with hashtag + key = normalize("#" + turn[domain]['book']['booked'][0][slot]) + sent = (' ' + sent + ' ').replace(' ' + key + ' ', ' ' + val + ' ') + + # try reference with ref# + key = normalize("ref#" + turn[domain]['book']['booked'][0][slot]) + sent = (' ' + sent + ' ').replace(' ' + key + ' ', ' ' + val + ' ') + return sent + +class Transformer(NLG): + + def __init__(self, + archive_file=DEFAULT_ARCHIVE_FILE, + use_cuda=False, + model_file=None): + if not os.path.isfile(archive_file): + if not model_file: + raise Exception("No model for Transformer is specified!") + archive_file = cached_path(model_file) + model_dir = os.path.dirname(os.path.abspath(__file__)) + if not os.path.exists(os.path.join(model_dir, 'checkpoints')): + archive = zipfile.ZipFile(archive_file, 'r') + archive.extractall(model_dir) + + with open(os.path.join(model_dir, "data/vocab.json"), 'r') as f: + vocabulary = json.load(f) + + vocab, ivocab = vocabulary['vocab'], vocabulary['rev'] + self.tokenizer = Tokenizer(vocab, ivocab) + self.max_seq_length = 50 + + self.decoder = TransformerDecoder(vocab_size=self.tokenizer.vocab_len, d_word_vec=128, act_dim=len(Constants.act_ontology), + n_layers=3, d_model=128, n_head=4, dropout=0.2) + self.device = 'cuda' if use_cuda else 'cpu' + self.decoder.to(self.device) + checkpoint_file = os.path.join(model_dir, "checkpoints/transformer") + self.decoder.load_state_dict(torch.load(checkpoint_file)) + + with open(os.path.join(model_dir, 'data/svdic.pkl'), 'rb') as f: + self.dic = pickle.load(f) + + def generate(self, meta, state): + """ + meta = {"Attraction-Inform": 
[["Choice","many"],["Area","centre of town"]], + "Attraction-Select": [["Type","church"],["Type"," swimming"],["Type"," park"]]} + """ + usr_post = state['history'][-1][-1] + usr = delexicalise(' '.join(usr_post.split()), self.dic) + + # parsing reference number GIVEN belief state + usr = delexicaliseReferenceNumber(usr, state['belief_state']) + + # changes to numbers only here + digitpat = re.compile('\d+') + usr = re.sub(digitpat, '[value_count]', usr) + + tokens = self.tokenizer.tokenize(usr) + if self.history: + tokens = self.history + [Constants.SEP_WORD] + tokens + if len(tokens) > self.max_seq_length - 2: + tokens = tokens[-(self.max_seq_length - 2):] + tokens = [Constants.CLS_WORD] + tokens + [Constants.SEP_WORD] + input_ids = self.tokenizer.convert_tokens_to_ids(tokens) + input_ids = torch.tensor([input_ids], dtype=torch.long).to(self.device) + + # add placeholder value + meta = deepcopy(meta) + for k, v in meta.items(): + domain, intent = k.split('-') + if intent == "Request": + for pair in v: + if not isinstance(pair[1], str): + pair[1] = str(pair[1]) + pair.insert(1, '?') + else: + counter = {} + for pair in v: + if not isinstance(pair[1], str): + pair[1] = str(pair[1]) + if pair[0] == 'Internet' or pair[0] == 'Parking': + pair.insert(1, 'yes') + elif pair[0] == 'none': + pair.insert(1, 'none') + else: + if pair[0] in counter: + counter[pair[0]] += 1 + else: + counter[pair[0]] = 1 + pair.insert(1, str(counter[pair[0]])) + + act_vecs = [0] * len(Constants.act_ontology) + for intent in meta: + for values in meta[intent]: + w = intent + '-' + values[0] + '-' + values[1] + if w in Constants.act_ontology: + act_vecs[Constants.act_ontology.index(w)] = 1 + + act_vecs = torch.tensor([act_vecs], dtype=torch.long).to(self.device) + + hyps = self.decoder.translate_batch(act_vecs=act_vecs, src_seq=input_ids, + n_bm=2, max_token_seq_len=40) + pred = self.tokenizer.convert_id_to_tokens(hyps[0]) + + if not self.history: + self.history = tokens[1:-1] + [Constants.SEP_WORD] + self.tokenizer.tokenize(pred) + else: + self.history = self.history + [Constants.SEP_WORD] + tokens[1:-1] + [Constants.SEP_WORD] + self.tokenizer.tokenize(pred) + + # replace the placeholder with entities + words = pred.split(' ') + counter = {} + for i in range(len(words)): + if "[" in words[i] and "]" in words[i]: + domain, slot = words[i].split('_') + domain = domain[1:].capitalize() + slot = slot[:-1].capitalize() + key = '-'.join((domain, slot)) + flag = False + for intent in meta: + _domain, _intent = intent.split('-') + if domain == _domain and _intent in ['Inform', 'Recommend', 'Offerbook']: + for values in meta[intent]: + if (slot == values[0]) and ('none' != values[-1]) and ((key not in counter) or (counter[key] == int(values[1])-1)): + words[i] = values[-1] + counter[key] = int(values[1]) + flag = True + break + if flag: + break + return " ".join(words) From d09b5e3511575ab5452cedc974c5916fcd7f0523 Mon Sep 17 00:00:00 2001 From: truthless11 Date: Thu, 19 Sep 2019 10:02:51 +0800 Subject: [PATCH 2/3] add transformer --- .gitignore | 4 +++ .../nlg/multiwoz/transformer/README.md | 33 +++++++++++++++++++ 2 files changed, 37 insertions(+) create mode 100644 convlab/modules/nlg/multiwoz/transformer/README.md diff --git a/.gitignore b/.gitignore index 174f6c7..fd2b8c1 100644 --- a/.gitignore +++ b/.gitignore @@ -47,6 +47,10 @@ convlab/modules/nlg/multiwoz/sc_lstm/resource_usr convlab/modules/nlg/multiwoz/sc_lstm/sclstm_usr.pt convlab/modules/nlg/multiwoz/sc_lstm/sclstm_usr.res +# transformer 
+convlab/modules/nlg/multiwoz/transformer/data/
+convlab/modules/nlg/multiwoz/transformer/checkpoints/
+
 # svm
 convlab/modules/nlu/multiwoz/svm/model/
 
diff --git a/convlab/modules/nlg/multiwoz/transformer/README.md b/convlab/modules/nlg/multiwoz/transformer/README.md
new file mode 100644
index 0000000..b39759b
--- /dev/null
+++ b/convlab/modules/nlg/multiwoz/transformer/README.md
@@ -0,0 +1,33 @@
+# Transformer
+
+Transformer encodes the user utterance and decodes the system utterance with a stacked Transformer architecture where self-attention and multi-head attention mechanisms are used. It was first proposed at [NIPS 2017](https://papers.nips.cc/paper/7181-attention-is-all-you-need.pdf). Here we project the dialog act vector and add it to the input word embeddings to feed the semantic information into the Transformer.
+
+# Run the code
+
+TRAIN
+
+```sh
+$ PYTHONPATH=../../../.. python train.py
+```
+
+TEST
+
+```sh
+$ PYTHONPATH=../../../.. python train.py --option test
+```
+
+# Data
+
+We use the MultiWOZ data under the `data` directory; the trained model is saved in the `checkpoints` directory.
+
+# Reference
+
+```
+@inproceedings{vaswani2017attention,
+  title={Attention is all you need},
+  author={Vaswani, Ashish and Shazeer, Noam and Parmar, Niki and Uszkoreit, Jakob and Jones, Llion and Gomez, Aidan N and Kaiser, Lukasz and Polosukhin, Illia},
+  booktitle={Advances in neural information processing systems},
+  pages={5998--6008},
+  year={2017}
+}
+```
\ No newline at end of file

From d01ca974a658fbc7799f86f4a92e7e0f540281da Mon Sep 17 00:00:00 2001
From: truthless11
Date: Thu, 19 Sep 2019 14:36:01 +0800
Subject: [PATCH 3/3] fix evaluator

---
 convlab/modules/nlg/multiwoz/transformer/evaluator.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/convlab/modules/nlg/multiwoz/transformer/evaluator.py b/convlab/modules/nlg/multiwoz/transformer/evaluator.py
index 9590a86..3e39e03 100644
--- a/convlab/modules/nlg/multiwoz/transformer/evaluator.py
+++ b/convlab/modules/nlg/multiwoz/transformer/evaluator.py
@@ -176,7 +176,7 @@ def evaluateDialogue(dialog, realDialogue):
                 match_stat = 1
             elif venue_offered[domain]:
                 groundtruth = queryResultVenues(domain, goal[domain]['informable'], real_belief=True)
-                if issubset(venue_offered[domain], groundtruth):
+                if len(venue_offered[domain]) > 0 and venue_offered[domain][0] in groundtruth:
                     match += 1
                     match_stat = 1
             else:
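As a rough illustration of the conditioning scheme the README describes (and which `TransformerDecoder.forward` applies through `act_word_emb`), the standalone sketch below shows how a multi-hot dialog act vector can be projected and added to the token embeddings before decoding. It is a minimal sketch only: the class, argument names, and dimensions are illustrative and are not part of the patches above.

```python
import torch
import torch.nn as nn


class ActConditionedEmbedding(nn.Module):
    """Minimal sketch: combine token embeddings with a projected dialog act vector."""

    def __init__(self, vocab_size, d_model, act_dim, pad_idx=0):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, d_model, padding_idx=pad_idx)
        self.act_proj = nn.Linear(act_dim, d_model, bias=False)

    def forward(self, token_ids, act_vec):
        # token_ids: (batch, seq_len); act_vec: (batch, act_dim), a multi-hot vector
        # over the act ontology. The projected act vector is broadcast over the
        # sequence dimension and added to every token embedding, so each decoding
        # step is conditioned on the target dialog acts.
        return self.tok_emb(token_ids) + self.act_proj(act_vec)[:, None, :]


if __name__ == "__main__":
    emb = ActConditionedEmbedding(vocab_size=100, d_model=16, act_dim=8)
    tokens = torch.randint(1, 100, (2, 5))   # two toy responses of length 5
    acts = torch.zeros(2, 8)
    acts[:, 3] = 1.0                         # mark one dialog act as active
    print(emb(tokens, acts).shape)           # torch.Size([2, 5, 16])
```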