[go: up one dir, main page]

CN110288081A - A recurrent network model and learning method based on FW mechanism and LSTM - Google Patents

A recurrent network model and learning method based on FW mechanism and LSTM Download PDF

Info

Publication number
CN110288081A
CN110288081A CN201910476156.XA CN201910476156A CN110288081A CN 110288081 A CN110288081 A CN 110288081A CN 201910476156 A CN201910476156 A CN 201910476156A CN 110288081 A CN110288081 A CN 110288081A
Authority
CN
China
Prior art keywords
module
data
unit
training
evaluation
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN201910476156.XA
Other languages
Chinese (zh)
Inventor
王军茹
卢继华
易军凯
徐懿
李梦泽
何天恺
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Beijing Institute of Technology BIT
Beijing Information Science and Technology University
Original Assignee
Beijing Institute of Technology BIT
Beijing Information Science and Technology University
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Beijing Institute of Technology BIT, Beijing Information Science and Technology University filed Critical Beijing Institute of Technology BIT
Priority to CN201910476156.XA priority Critical patent/CN110288081A/en
Publication of CN110288081A publication Critical patent/CN110288081A/en
Pending legal-status Critical Current

Links

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/048Activation functions
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/049Temporal neural networks, e.g. delay elements, oscillating neurons or pulsed inputs
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • General Health & Medical Sciences (AREA)
  • Biomedical Technology (AREA)
  • Biophysics (AREA)
  • Computational Linguistics (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Evolutionary Computation (AREA)
  • Artificial Intelligence (AREA)
  • Molecular Biology (AREA)
  • Computing Systems (AREA)
  • General Engineering & Computer Science (AREA)
  • General Physics & Mathematics (AREA)
  • Mathematical Physics (AREA)
  • Software Systems (AREA)
  • Health & Medical Sciences (AREA)
  • Machine Translation (AREA)

Abstract

The present invention relates to a kind of Recursive Networks model and learning method based on FW mechanism and LSTM, belongs to recurrent neural network and natural language processing technique field.Learning method including Recursive Networks model and support based on FW mechanism and LSTM;The former includes data import modul, data generation module, load and iteration module, parameter setting module, definition module, Recursive Networks training, assessment and test module;Learning method includes: 1 importing data;2, which will import data, is split to obtain training data, assessment data and test data;3, according to data are imported, obtain pre-set configuration parameter;4 complete the initialization of weight parameter;Training, assessment and test data are sent into LSTM unit by 5 calculates output vector;6 calculate loss function, optimize to network parameter, export complexity.The network model and learning method further improve the accuracy and convergence rate of LSTM model treatment.

Description

一种基于FW机制及LSTM的递归网络模型及学习方法A recurrent network model and learning method based on FW mechanism and LSTM

技术领域technical field

本发明涉及一种基于FW机制及LSTM的递归网络模型及学习方法,属于递归神经网络以及自然语言处理技术领域。The invention relates to a recursive network model and a learning method based on a FW mechanism and an LSTM, and belongs to the technical field of recursive neural networks and natural language processing.

背景技术Background technique

自然语言处理模型通常采用递归神经网络(Recurrent Neural Network,RNN)结构。RNN由隐藏层状态以及权重这两种时间规模的变量组成。隐藏层状态在每个时间步进都会更新一次;而权重则在序列所有信息全部输入网络之后再进行更新。因此,代表着网络中各层间连接关系的权重往往对应着网络的“长期记忆”。但是,实际网络各层关系变化与递进,往往与输入序列长度相关,可能是3,5个时间步进,也可能是30,50个时间步进需要更新。Natural language processing models usually adopt a recurrent neural network (Recurrent Neural Network, RNN) structure. RNN consists of two time-scale variables, hidden layer states and weights. The hidden layer state is updated at each time step; the weights are updated after all information about the sequence has been fed into the network. Therefore, the weight representing the connection relationship between layers in the network often corresponds to the "long-term memory" of the network. However, the change and progress of the relationship between each layer of the actual network is often related to the length of the input sequence. It may be 3 or 5 time steps, or it may be 30 or 50 time steps that need to be updated.

基于LSTM单元的语言模型是RNN应用较为广泛的改进网络之一,该模型通过对文本数据的训练,根据输入的文本预测该段文本中即将出现的下一个单词。网络参数初始状态用零向量进行初始化,并在读取每个单词后得到更新。模型在处理输入数据时使用反向传播法进行网络参数的优化。把输入数据,即若干句子组成的段落划分为固定长度的输入块,每个输入块还有固定长度个单词,每当处理完一个输入块后执行反向传播对网络参数进行更新。The language model based on the LSTM unit is one of the widely used improved networks of RNN. The model predicts the next word that will appear in the text according to the input text through the training of the text data. The initial state of network parameters is initialized with a zero vector and updated after each word is read. The model uses the backpropagation method to optimize the network parameters when processing the input data. Divide the input data, that is, paragraphs composed of several sentences, into fixed-length input blocks, and each input block has fixed-length words. After each input block is processed, backpropagation is performed to update the network parameters.

Jimmy Ba等人提出了快速权重(Fast Weights,FW)机制,即引入更新周期处于隐藏层状态以及权重这两种时间规模之间的新变量来存储快速更新的隐藏层状态,对于序列到序列模型的学习已被证明十分有效。出于上述考虑,在保留现有隐藏层状态和标准权重的同时引入新的变量,这种变量的更新周期要比隐藏层更新周期更长,但是比标准权重更新周期更短,也称为快速权重。Jimmy Ba et al. proposed the Fast Weights (FW) mechanism, which introduces a new variable between the hidden layer state and the weight of the update period to store the fast updated hidden layer state. For sequence-to-sequence models learning has been proven to be very effective. For the above considerations, new variables are introduced while retaining the existing hidden layer state and standard weights. The update cycle of this variable is longer than the hidden layer update cycle, but shorter than the standard weight update cycle, also known as fast Weights.

在神经网络训练方面,一般需经过复杂而耗时的处理,才能获得较好的学习性能,即需要较高的时间和计算成本。因此,研究者们为降低此时间及计算成本,往往选择批量处理。In terms of neural network training, generally complex and time-consuming processing is required to obtain better learning performance, that is, higher time and computing costs are required. Therefore, researchers often choose batch processing in order to reduce this time and computational cost.

其中,批量正则化为其中的一个很典型的技术,然而其在递归神经网络的作用并不明显。因此,G.Hinton等人提出了层正则化(layer normalization,LN),具体实现为计算递归神经网络中某一个训练样本中在一个隐藏层上所有隐藏单元的状态的均值和标准差。LN用于解决快速权重机制中随着训练变多,解决隐藏层更新值期间的溢出问题。Among them, batch regularization is a typical technique, but its role in recurrent neural networks is not obvious. Therefore, G. Hinton and others proposed layer normalization (layer normalization, LN), which is specifically implemented to calculate the mean and standard deviation of the states of all hidden units on a hidden layer in a training sample in a recurrent neural network. LN is used to solve the overflow problem during the update value of the hidden layer as the training increases in the fast weight mechanism.

衡量语言模型性能的评价指标参数是复杂度perplexity和loss。其中,perplexity表示语言模型在学习文本数据后根据句子前面的单词预测下一单词的平均可选数量。例如,一个序列是由A、B、C、D、E五个字母无规律随机组成,那么预测下一个字母时,有5个等概率的选项,那么复杂度的值即为5。因此,若一个语言模型的复杂度为K,就说明语言模型在预测即将出现的单词时,平均有K个单词拥有相同的概率作为合理预测的选择。其中,K为整数,为目标单词的总数。以PTB模型为例,评价PTB模型性能指标的复杂度perplexity值的计算公式为(1):The evaluation index parameters to measure the performance of the language model are complexity perplexity and loss. Among them, perplexity represents the average number of options for the language model to predict the next word based on the words in front of the sentence after learning the text data. For example, if a sequence is randomly composed of five letters A, B, C, D, and E, then when predicting the next letter, there are 5 options with equal probability, and the complexity value is 5. Therefore, if the complexity of a language model is K, it means that when the language model predicts upcoming words, on average, K words have the same probability as reasonable prediction choices. Among them, K is an integer, which is the total number of target words. Taking the PTB model as an example, the calculation formula for evaluating the complexity perplexity value of the PTB model performance index is (1):

其中,Ptargeti表示第i个目标单词,ln为对数函数;Among them, Ptarget i represents the i-th target word, and ln is a logarithmic function;

另一衡量语言模型性能的评价指标参数loss定义为目标单词出现概率的平均负对数,表达式如(2):Another evaluation index parameter loss to measure the performance of the language model is defined as the average negative logarithm of the target word occurrence probability, the expression is as (2):

perplexity值与loss的关系为(3):The relationship between the perplexity value and loss is (3):

perplexity=eloss (3)perplexity=e loss (3)

当语言模型学习句子中单词与单词之间相互关联的逻辑关系时,模型的学习能力越强,根据之前出现的单词预测下一个单词时,备选的单词数量就越少,相应的复杂度perplexity就越低。所以复杂度perplexity能够很好地反映网络的学习性能。复杂度perplexity越低,代表网络预测句子中下一个单词的能力越强,效果也就越好。When the language model learns the logical relationship between words and words in a sentence, the stronger the learning ability of the model is, when predicting the next word based on the previous word, the fewer the number of alternative words, the corresponding complexity perplexity lower. So the complexity perplexity can well reflect the learning performance of the network. The lower the complexity perplexity, the stronger the ability of the network to predict the next word in the sentence, and the better the effect.

发明内容Contents of the invention

本发明的目的在于进一步提升现有基于LSTM递归神经网络在处理时间关联度强的自然语言时存在复杂度性能有待进一步提升的技术现状提出了一种基于FW机制及LSTM的递归网络模型及学习方法。The purpose of the present invention is to further improve the technical status of the existing LSTM-based recursive neural network when processing natural language with strong temporal correlation, and the complexity performance needs to be further improved. A recursive network model and learning method based on FW mechanism and LSTM is proposed. .

所述基于FW机制及LSTM的递归网络模型及学习方法包括基于FW机制及LSTM的递归网络模型以及所依托的学习方法;The recursive network model and learning method based on the FW mechanism and LSTM include the recurrent network model based on the FW mechanism and LSTM and the learning method it relies on;

其中,所述基于FW机制及LSTM的递归网络模型包括数据导入模块、数据生成模块、加载与迭代模块、参数设定模块、模型定义模块、递归网络训练模块、递归网络评估模块以及递归网络测试模块;Wherein, the recursive network model based on the FW mechanism and LSTM includes a data import module, a data generation module, a loading and iteration module, a parameter setting module, a model definition module, a recursive network training module, a recursive network evaluation module and a recursive network testing module ;

其中,数据生成模块又包括数据拆分单元;加载与迭代模块包括数据加载单元和迭代单元;Among them, the data generation module includes a data splitting unit; the loading and iteration module includes a data loading unit and an iteration unit;

数据拆分单元包括训练数据生成单元、评估数据生成单元以及测试数据生成单元;The data splitting unit includes a training data generating unit, an evaluation data generating unit and a test data generating unit;

递归网络训练模块包括dropout单元、更新单元和结果储存单元;递归网络评估模块以及递归网络测试模块仅包括更新单元和结果储存单元;The recursive network training module includes a dropout unit, an update unit and a result storage unit; the recursive network evaluation module and the recursive network test module only include an update unit and a result storage unit;

其中,更新单元包括长短时记忆单元和快速权重单元;Wherein, the update unit includes a long short-term memory unit and a fast weight unit;

所述基于FW机制及LSTM的递归网络模型中各模块的连接关系如下:The connection relationship of each module in the recursive network model based on FW mechanism and LSTM is as follows:

数据导入模块与数据生成模块相连,数据生成模块和加载与迭代模块相连,参数设定模块和加载与迭代模块以及模型定义模块相连,递归网络训练模块和加载与迭代模块、递归网络评估模块以及模型定义模块相连,递归网络评估模块与加载与迭代模块、递归网络训练模块、递归网络测试模块以及模型定义模块相连;递归网络测试模块与加载与迭代模块、递归网络评估模块和模型定义模块相连;The data import module is connected to the data generation module, the data generation module is connected to the loading and iteration module, the parameter setting module is connected to the loading and iteration module and the model definition module, the recursive network training module is connected to the loading and iteration module, the recursive network evaluation module and the model The definition module is connected, the recursive network evaluation module is connected with the loading and iteration module, the recursive network training module, the recursive network testing module and the model definition module; the recursive network testing module is connected with the loading and iteration module, the recursive network evaluation module and the model definition module;

数据生成模块中各单元的连接关系如下:数据拆分单元中训练数据、评估数据和测试数据分别与训练标签生成单元、评估标签生成单元以及测试标签生成单元相连;The connection relationship of each unit in the data generation module is as follows: the training data, evaluation data and test data in the data splitting unit are respectively connected to the training label generation unit, evaluation label generation unit and test label generation unit;

加载与迭代模块中各单元的连接关系如下:数据加载单元和迭代单元相连;Loading and the connection relation of each unit in the iterative module are as follows: the data loading unit is connected with the iterative unit;

所述基于FW机制及LSTM的递归网络模型中各模块的信号产生及输出关系如下:The signal generation and output relationship of each module in the recursive network model based on the FW mechanism and LSTM is as follows:

数据导入模块的输出接入数据生成模块;数据生成模块处理后接入加载与迭代模块;参数设定模块为加载与迭代模块和模型定义模块提供输入参数及FW模型参数;加载与迭代模块分别为递归网络训练模块、递归网络评估模块和递归网络测试模块提供训练数据和训练标签、评估数据和评估标签以及测试数据和测试标签;模型定义模块将FW模型参数分别输入递归网络训练模块、递归网络评估模块以及递归网络测试模块;递归网络训练模块将训练好的网络参数送入递归网络评估模块;递归网络评估模块将评估后的网络参数送入递归网络测试模块;The output of the data import module is connected to the data generation module; the data generation module is connected to the loading and iteration module after processing; the parameter setting module provides input parameters and FW model parameters for the loading and iteration module and the model definition module; the loading and iteration modules are respectively The recursive network training module, the recursive network evaluation module and the recursive network testing module provide training data and training labels, evaluation data and evaluation labels, and test data and test labels; the model definition module inputs the FW model parameters into the recurrent network training module and the recursive network evaluation module respectively. module and the recursive network testing module; the recursive network training module sends the trained network parameters into the recursive network evaluation module; the recursive network evaluation module sends the evaluated network parameters into the recursive network testing module;

递归网络训练模块、评估模块以及测试模块中的各单元连接关系如下:The connection relationship of each unit in the recurrent network training module, evaluation module and test module is as follows:

dropout单元接收数据并与长短时记忆单元相连,长短时记忆单元与数据输入和快速权重单元相连,结果储存单元与快速权重单元和结果相连。The dropout unit receives data and is connected with the long-short-term memory unit, the long-short-term memory unit is connected with the data input and the fast weight unit, and the result storage unit is connected with the fast weight unit and the result.

所述基于FW机制及LSTM的递归网络模型以及所依托的学习方法,包括如下步骤:The recursive network model based on the FW mechanism and LSTM and the learning method it relies on include the following steps:

步骤一、待训练和测试的数据经数据导入模块导入,具体为:Step 1. The data to be trained and tested is imported through the data import module, specifically:

通过读取文本路径,获取文本数据;Obtain text data by reading the text path;

步骤二、数据生成模块对经数据导入模块导入的数据经数据拆分单元进行拆分,分别得到训练数据、评估数据和测试数据;Step 2, the data generation module splits the data imported by the data import module through the data splitting unit to obtain training data, evaluation data and test data respectively;

其中,拆分具体为:将步骤一导入的文本数据按照每j个字符为一句话进行拆分;Wherein, the splitting is specifically: splitting the text data imported in step 1 according to every j characters into a sentence;

其中,j的取值范围为5到50;Wherein, the value range of j is from 5 to 50;

步骤三、训练数据生成单元随机选取x%比例经数据拆分单元拆分后的数据生成训练集;评估数据生成单元随机选取y%比例经数据拆分单元拆分后的数据生成评估集;测试数据生成单元随机选取z%比例经数据拆分单元拆分后的数据生成测试集;Step 3, the training data generation unit randomly selects the x% ratio of the data split by the data splitting unit to generate a training set; the evaluation data generation unit randomly selects the y% ratio of the data split by the data splitting unit to generate an evaluation set; test The data generating unit randomly selects the z% ratio to generate a test set from the data split by the data splitting unit;

其中,x%+y%+z%=1;Wherein, x%+y%+z%=1;

步骤四、训练标签生成单元将训练数据生成单元生成的训练集中每一个数据后移一位得到训练标签;评估标签生成单元将评估数据生成单元生成的评估集中每一个数据后移一位得到评估标签;测试标签生成单元将测试集中每一个数据后移一位得到测试标签;Step 4, the training label generating unit shifts each data in the training set generated by the training data generating unit by one bit to obtain the training label; the evaluation label generating unit shifts each data in the evaluation set generated by the evaluation data generating unit by one bit to obtain the evaluation label ; The test label generation unit shifts each data in the test set by one bit to obtain the test label;

步骤五、参数设定模块根据数据导入模块导入文本的模型规模,获取配置参数,,再将获取的配置参数输入参数设定模块;Step 5, the parameter setting module obtains configuration parameters according to the model scale of the text imported by the data import module, and then inputs the obtained configuration parameters into the parameter setting module;

其中,配置参数包括初始规模、学习率、最大梯度正则值、层数、步数、隐藏层大小、最大epoch数、极大epoch值、dropout率、衰减率、批大小以及vocab大小;Among them, the configuration parameters include initial scale, learning rate, maximum gradient regular value, number of layers, number of steps, hidden layer size, maximum epoch number, maximum epoch value, dropout rate, decay rate, batch size, and vocab size;

步骤六、加载与迭代模块中的数据加载单元按照参数设定模块中获取的配置参数加载训练集、评估集以及测试集中的数据,并设定初始化数据序号i为1;Step 6. The data loading unit in the loading and iteration module loads the data in the training set, evaluation set and test set according to the configuration parameters obtained in the parameter setting module, and sets the initialization data sequence number i to 1;

步骤七、模型定义模块根据参数设定模块中的配置参数,使用伪随机函数在配置范围内生成随机值作为权重矩阵参数,完成权重参数的初始化;Step 7, the model definition module uses the pseudo-random function to generate random values within the configuration range as the weight matrix parameters according to the configuration parameters in the parameter setting module, and completes the initialization of the weight parameters;

步骤八、加载与迭代模块中迭代模块判断当前数据集中的数据是否发送完毕,并依据判断结果进行操作,具体为:Step 8. The iteration module in the loading and iteration module judges whether the data in the current data set has been sent, and operates according to the judgment result, specifically:

若当前数据集中的数据没发送完,则发送第i组数据,判断进行训练、评估还是测试,并跳至步骤九,跳至步骤八;否则停止迭代;If the data in the current data set has not been sent, send the i-th group of data, judge whether to train, evaluate or test, and skip to step 9, and skip to step 8; otherwise, stop the iteration;

步骤九、判断当前数据是否是训练数据,若是则依据dropout率对输入数据进行抽取,抽取后数据,跳至步骤十;否则,跳至步骤十;Step 9. Determine whether the current data is training data. If so, extract the input data according to the dropout rate. After the extracted data, skip to step 10; otherwise, skip to step 10;

步骤十、将步骤九输入的数据送入更新单元中的长短时记忆单元以及快速权重单元计算得到输出向量,同时利用梯度下降法对网络进行优化,具体为:Step 10. Send the data input in step 9 to the long-short-term memory unit and the fast weight unit in the update unit to calculate the output vector, and at the same time use the gradient descent method to optimize the network, specifically:

步骤10.1更新单元基于输入层权重Wx、标准权重Wh计算起始隐藏层状态,通过公式(4)计算当前t时刻的初始隐藏状态:Step 10.1 The update unit calculates the initial hidden layer state based on the input layer weight Wx and the standard weight Wh, and calculates the initial hidden state at the current time t by formula (4):

h0 t=f(LN(Wx*xt+Wh*ht-1)) (4)h 0 t =f(LN(W x *x t +W h *h t-1 )) (4)

其中,输入层权重记为Wx、标准权重记为Wh;h0为起始隐藏层状态,LN为层正则化函数;f为激活函数;xt为当前t时刻的输入层数据;ht-1为当前时刻的前一时刻,即t-1时刻,隐藏层状态对应的数据,简称隐藏层状态;Among them, the input layer weight is recorded as Wx, and the standard weight is recorded as Wh; h 0 is the initial hidden layer state, LN is the layer regularization function; f is the activation function; x t is the input layer data at the current time t; h t- 1 is the data corresponding to the state of the hidden layer at the moment before the current moment, that is, time t-1, referred to as the state of the hidden layer;

优选的,激活函数f为SeLU函数、Leaky Relu函数以及Swish函数中的一种;Preferably, the activation function f is one of SeLU function, Leaky Relu function and Swish function;

标准权重Wh是RNN网络中隐藏层向下一个时间步进传播的权重;输入层权重Wx是输入层到隐藏层传播的权重;The standard weight Wh is the weight propagated from the hidden layer to the next time step in the RNN network; the input layer weight Wx is the weight propagated from the input layer to the hidden layer;

步骤10.2、快速权重单元计算快速权重,具体通过公式(5)计算:Step 10.2, the fast weight unit calculates the fast weight, specifically calculated by formula (5):

WA(t)=λWA(t-1)+ηht-1hT t-1 (5)W A (t)=λW A (t-1)+ηh t-1 h T t-1 (5)

其中,WA(t)是第t时刻的快速权重,是仅作用在隐藏层每个时间步进内的权重;一个时间步进更新的总次数,记为s+1次;λ是衰减率、η是学习率、ht-1为t-1时刻对应的隐藏层状态;hT t-1是ht-1即t-1时刻对应隐藏层状态的转置;Among them, W A (t) is the fast weight at the tth moment, which is the weight that only acts on the hidden layer in each time step; the total number of updates in a time step is recorded as s+1 times; λ is the decay rate , η is the learning rate, h t-1 is the hidden layer state corresponding to the t-1 moment; h T t-1 is h t-1 , that is, the transposition of the hidden layer state corresponding to the t-1 moment;

其中,时间步进更新的总次数s+1中的s即步数;Among them, s in the total number of times of time step update s+1 is the number of steps;

其中,衰减率的取值范围为0.9到0.995,学习率的取值范围为0.3到0.8;Among them, the decay rate ranges from 0.9 to 0.995, and the learning rate ranges from 0.3 to 0.8;

步骤10.3、快速权重单元计算隐藏层状态并更新s次隐藏层状态;Step 10.3, the fast weight unit calculates the state of the hidden layer and updates the state of the hidden layer for s times;

步骤10.4、慢速权重单元计算归一化输出;Step 10.4, the slow weight unit calculates the normalized output;

其中,网络的归一化输出通过Softmax或sigmoid函数两者之一实现;Among them, the normalized output of the network is realized by either Softmax or sigmoid function;

步骤10.5、结果存储单元计算基于步骤10.4计算出的归一化输出计算损失loss和复杂度perplexity;Step 10.5, the result storage unit calculation is based on the normalized output calculated in step 10.4 to calculate loss and complexity perplexity;

步骤10.6、慢速权重单元判断是否达到最后一个Epoch,如果没有达到,则更新单元则更新隐藏层状态以及训练参数或测试参数,将当前i加1,跳至步骤八。Step 10.6. The slow weight unit judges whether the last Epoch has been reached. If not, the update unit updates the state of the hidden layer and the training parameters or test parameters, adds 1 to the current i, and skips to step 8.

有益效果Beneficial effect

本发明一种基于FW机制及LSTM的递归网络模型及学习方法,与现有技术相比,具有如下有益效果:A recursive network model and learning method based on the FW mechanism and LSTM of the present invention, compared with the prior art, has the following beneficial effects:

1.所述递归网络模型引入快速权重以及LSTM机制,通过衰减系数及学习率的参数优化,使得以储存短期记忆信息的网络模型的学习准确度得到了大幅度提高;1. The recursive network model introduces fast weight and LSTM mechanism, and through parameter optimization of attenuation coefficient and learning rate, the learning accuracy of the network model with stored short-term memory information is greatly improved;

2.本发明所述方法与现有LSTM模型以及引入快速权重的RNN模型相比,模型的训练所述方法采用LSTM结合SeLU激活函数以及层正则化使得训练、评估及测试的收敛速度大大提高。2. Compared with the existing LSTM model and the RNN model that introduces fast weights, the method of the present invention adopts LSTM in combination with SeLU activation function and layer regularization to greatly improve the convergence speed of training, evaluation and testing.

附图说明Description of drawings

图1是本发明基于FW机制及LSTM的递归网络模型的组成及各模块的连接示意图;Fig. 1 is the composition and the connection schematic diagram of each module of the recursive network model based on FW mechanism and LSTM of the present invention;

图2是本发明基于FW机制及LSTM的递归网络模型中数据生成模块的组成及连接示意图;Fig. 2 is the composition and connection schematic diagram of the data generation module in the recursive network model based on FW mechanism and LSTM of the present invention;

图3是本发明基于FW机制及LSTM的递归网络模型中加载与迭代模块的组成示意图以及与数据生成模块、参数设定模块、模型定义模块、递归网络训练模块、递归网络评估模块和递归网络测试模块的连接关系;Fig. 3 is a composition schematic diagram of loading and iteration modules in the recursive network model based on FW mechanism and LSTM of the present invention and data generation module, parameter setting module, model definition module, recursive network training module, recursive network evaluation module and recursive network test The connection relationship of the modules;

图4是本发明基于FW机制及LSTM的递归网络模型中递归网络训练模块、递归网络评估模块以及递归网络测试模块三者的关系与组成示意图;4 is a schematic diagram of the relationship and composition of the recursive network training module, the recursive network evaluation module and the recursive network testing module in the recursive network model based on the FW mechanism and LSTM of the present invention;

图5是本发明基于FW机制及LSTM的递归网络模型中长短时记忆单元和快速权重单元的组成示意图;Fig. 5 is a schematic diagram of the composition of the long-short-term memory unit and the fast weight unit in the recursive network model based on the FW mechanism and LSTM of the present invention;

图6是本发明基于FW机制及LSTM的递归网络模型依托的方法处理关联度大的短句文本数据集不同batch size的学习效果对比;Fig. 6 is a comparison of the learning effects of different batch sizes of short sentence text data sets with a large correlation degree based on the method of the present invention based on the FW mechanism and the recursive network model of LSTM;

图7是本发明基于FW机制及LSTM的递归网络模型依托的方法处理关联度大的短句文本数据集不同模型的log(perplexity)对比。Fig. 7 is a log (perplexity) comparison of different models of short sentence text data sets with high relevance by the method based on the FW mechanism and the recursive network model of LSTM in the present invention.

具体实施方式Detailed ways

下面结合附图和实施例对本发明基于FW机制及LSTM的递归网络模型及学习方法做进一步说明和详细描述。The recurrent network model and learning method based on the FW mechanism and LSTM of the present invention will be further explained and described in detail below in conjunction with the drawings and embodiments.

实施例1Example 1

本实施例阐述了基于本发明所述的基于FW机制及LSTM的递归网络模型的组成及工作流程。This embodiment describes the composition and workflow of the recurrent network model based on the FW mechanism and LSTM of the present invention.

具体实施时,语料采用流行应用于自然语言处理的NLTK文本语料库中的富有代表性的短句库——欧盟国家会议语料europarl_raw进行试验。In the specific implementation, the corpus uses the representative short sentence library in the NLTK text corpus popularly used in natural language processing - the EU national conference corpus europarl_raw for experiments.

europarl_raw语料库文本数据来源于会议对话,句子大多数为中短句,长度大约为十个单词左右,句式较为简单,大多为主谓宾结构。具体到本实施例,采用图1中各模块对该数据集进行处理。The text data of the europarl_raw corpus comes from conference conversations. Most of the sentences are short and medium sentences, about ten words in length. The sentence structure is relatively simple, and most of them have a subject-predicate-object structure. Specifically in this embodiment, the data set is processed by using each module in FIG. 1 .

图1示意了基于FW机制及LSTM的递归网络模型的组成及各模块的连接,从图1中可以看出,数据导入模块导入的数据送入数据生成模块中;数据生成模块生成训练数据、评估数据以及测试数据及其标签,输入加载与迭代模块中;加载与迭代模块与模型定义模块接收参数设定模块的参数,并分别连入递归网络训练模块、评估模块以及测试模块,进行训练、评估及测试。Figure 1 schematically shows the composition of the recursive network model based on the FW mechanism and LSTM and the connection of each module. It can be seen from Figure 1 that the data imported by the data import module is sent to the data generation module; the data generation module generates training data, evaluates The data and test data and their labels are input into the loading and iteration module; the loading and iteration module and the model definition module receive the parameters of the parameter setting module, and are respectively connected to the recurrent network training module, evaluation module and test module for training and evaluation and testing.

。首先通过数据导入模块通过读取文本路径将文本数据导入;导入后输出至数据生成模块,数据生成模块进一步将原始数据拆分成训练数据、评估数据和测试数据,再经过训练标签生成单元、评估标签生成单元以及测试标签生成单元生成各数据集的标签,其结构如图2中数据生成模块的连接示意图所示。. First, the text data is imported through the data import module by reading the text path; after import, it is output to the data generation module, and the data generation module further splits the original data into training data, evaluation data and test data, and then passes through the training label generation unit, evaluation The label generation unit and the test label generation unit generate labels for each data set, and their structure is shown in the connection schematic diagram of the data generation module in FIG. 2 .

其中,参数设定模块中预先设置的配置参数有如下表1中所述的4种:Among them, the configuration parameters preset in the parameter setting module include four types as described in Table 1 below:

表1各配置具体参数设置Table 1 Specific parameter settings for each configuration

参数设定模块依据数据导入模块导入的文本的模型规模,获取由表1所示的合适配置参数,将其输入参数设定模块,尔后送往模型定义模块和加载与迭代模块。模型定义模块根据参数设定模块中的配置参数,使用伪随机函数在配置范围内生成随机值作为权重矩阵参数,完成权重参数的初始化。According to the model size of the text imported by the data import module, the parameter setting module obtains the appropriate configuration parameters shown in Table 1, and inputs them into the parameter setting module, and then sends them to the model definition module and the loading and iteration module. According to the configuration parameters in the parameter setting module, the model definition module uses the pseudo-random function to generate random values within the configuration range as the parameters of the weight matrix to complete the initialization of the weight parameters.

加载与迭代模块中的数据加载单元按照参数设定模块中获取的配置参数加载训练集、评估集以及测试集中的数据;迭代模块判断当前数据集中的数据是否发送完毕,并依据判断结果进行操作。若当前数据集为训练数据,则输出至递归网络训练模块;若为评估数据,则输出至递归网络评估模块;若为测试数据,则输出至递归网络测试模块。由图3可看出加载与迭代模块的工作示意图以及与数据生成模块、参数设定模块、模型定义模块、递归网络训练模块、递归网络评估模块和递归网络测试模块的连接关系。The data loading unit in the loading and iteration module loads the data in the training set, evaluation set, and test set according to the configuration parameters obtained in the parameter setting module; the iteration module judges whether the data in the current data set has been sent, and operates according to the judgment result. If the current data set is training data, it is output to the recursive network training module; if it is evaluation data, it is output to the recurrent network evaluation module; if it is test data, it is output to the recurrent network testing module. From Figure 3, we can see the working diagram of the loading and iteration module and the connection relationship with the data generation module, parameter setting module, model definition module, recursive network training module, recursive network evaluation module and recursive network testing module.

图4示意了基于FW机制及LSTM的递归网络模型中递归网络训练模块、递归网络评估模块以及递归网络测试模块三者的关系与组成。递归网络评估模块以及递归网络测试模块与递归网络训练模块差异是不包括dropout单元而仅包含更新单元和结果储存单元。递归网络训练模块将训练好的网络参数送入递归网络评估模块;递归网络评估模块将评估后的网络参数送入递归网络测试模块。Figure 4 illustrates the relationship and composition of the recurrent network training module, recurrent network evaluation module and recurrent network testing module in the recurrent network model based on the FW mechanism and LSTM. The difference between the recurrent network evaluation module and the recurrent network test module and the recurrent network training module is that the dropout unit is not included but only the update unit and the result storage unit are included. The recursive network training module sends the trained network parameters to the recursive network evaluation module; the recursive network evaluation module sends the evaluated network parameters to the recursive network testing module.

从图4可以看出,更新单元包括长短时记忆单元和快速权重单元;图4可以看出,递归网络评估模块以及递归网络测试模块与递归网络训练模块差异是不包括dropout单元;仅包含更新单元和结果储存单元。As can be seen from Figure 4, the update unit includes a long-short-term memory unit and a fast weight unit; as can be seen in Figure 4, the difference between the recursive network evaluation module and the recursive network test module and the recursive network training module is that the dropout unit is not included; only the update unit is included and result storage unit.

图5示意了本模型中长短时记忆单元和快速权重单元的组成。图5中,Xt对应t时刻的输入层数据;C(t-1)以及C’(t)分别对应LSTM在t时刻记忆单元C的输入和输出;C’(t)再经过快速权重进行更新,生成C(t);作为下一时刻LSTM记忆单元C的输入;ht-1以及ht分别为t-1时刻以及t时刻的LSTMcell的输出。图5中的σ为激活函数sigmoid;tanh为tanh激活函数。Figure 5 illustrates the composition of the long short-term memory unit and the fast weight unit in this model. In Figure 5, X t corresponds to the input layer data at time t; C(t-1) and C'(t) correspond to the input and output of LSTM memory unit C at time t; Update and generate C(t); as the input of the LSTM memory unit C at the next moment; h t-1 and h t are the outputs of the LSTM cell at time t-1 and time t respectively. σ in Figure 5 is the activation function sigmoid; tanh is the tanh activation function.

图5中,C’(t)=h0(t)以及C(t)hs(t)分别对应初始快速权重更新前的t时刻记忆单元的输入,以及更新后t时刻记忆单元的输入。In FIG. 5 , C'(t)=h 0 (t) and C(t)h s (t) respectively correspond to the input of the memory unit at time t before the initial fast weight update and the input of the memory unit at time t after the update.

实施例2Example 2

本实例阐述了基于本发明递归网络模型所依托的方法,处理关联度大的断句文本数据集的学习效果对比。This example illustrates the comparison of the learning effects of processing the sentence sentence data sets with high relevance based on the method relied on by the recursive network model of the present invention.

我们将目光转移到对由句子之间关联性较强,且句子长度较短的文本数据的处理,由于句子较短,更加注重短期内输入单词与单词之间的联系。我们使用流行应用于自然语言处理的NLTK文本语料库中的富有代表性的短句库——欧盟国家会议语料europarl_raw进行试验。We turn our attention to the processing of text data with strong correlation between sentences and short sentences. Due to the short sentences, we pay more attention to the connection between input words and words in a short period of time. We conduct experiments using the European Union National Assembly corpus europarl_raw, a representative short sentence corpus from the NLTK text corpus popularly used in natural language processing.

在使用europarl_raw语料库时,将num_steps统一设置为10,代表网络按照每输入十个单词为一句完整的句子处理。When using the europarl_raw corpus, set num_steps to 10, which means that the network processes every ten words as a complete sentence.

首先需确定合适的更新次数s。Firstly, it is necessary to determine the appropriate number of updates s.

当快速权重在当前的时刻得到更新后,将对隐藏状态进行循环s次的更新,相比于toy game场景的样本数据,文本数据的前后单词关联性较为复杂,我们需要加快更新频率,即加大s的数值以更大的发挥快速权重处理短期记忆的功能。我们调整一个时间步骤内隐藏状态的更新次数,固定隐藏单元数为50,batch_size为20,改变S=5,6,7,8,记录模型训练效果,如下表2所示:When the fast weight is updated at the current moment, the hidden state will be updated s times. Compared with the sample data of the toy game scene, the correlation between the words before and after the text data is more complicated. We need to speed up the update frequency, that is, add Larger values of s give greater play to the function of fast weight processing short-term memory. We adjust the number of updates of the hidden state within a time step, fix the number of hidden units to 50, batch_size to 20, change S=5, 6, 7, 8, and record the model training effect, as shown in Table 2 below:

表2不同更新次数下模型分别在训练到第5,10,13个epoch时的复杂度对比Table 2 Comparison of the complexity of the model when it is trained to the 5th, 10th, and 13th epochs under different update times

更新次数sUpdate times s 复杂度-5Complexity -5 复杂度-10Complexity -10 复杂度-13Complexity - 13 55 189.380189.380 108.083108.083 105.231105.231 66 145.939145.939 73.87573.875 71.33171.331 77 138.889138.889 68.32368.323 65.94665.946 88 139.400139.400 70.04970.049 67.64267.642

如表2所示,当更新次数s=7时,快速权重模型在训练到第5个epoch时复杂度为138.889,第10个epoch下降为68.323,第13个epoch时收敛于65.946。As shown in Table 2, when the number of updates s=7, the complexity of the fast weight model is 138.889 when training to the fifth epoch, drops to 68.323 at the 10th epoch, and converges to 65.946 at the 13th epoch.

此后我们将确定合适的batch size。Thereafter we will determine the appropriate batch size.

合适的batch size对于一个网络的学习性能至关重要,batch size过大,会导致模型在进行梯度下降法寻找最优解时找到的是局部最小值而不是全局最小值,而batchsize过小则会导致收敛速度慢,模型学习效果差。所以为了提升引入了快速权重的新模型的性能,我们固定隐藏单元数量为50,更新次数s设置为前文中验证过的最优值7,组成句子的单词数num_steps=10,改变batch size等于10,20,30,50,记录模型训练效果,如下表3所示:An appropriate batch size is crucial to the learning performance of a network. If the batch size is too large, it will cause the model to find the local minimum instead of the global minimum when the gradient descent method is used to find the optimal solution. If the batch size is too small, it will As a result, the convergence speed is slow and the model learning effect is poor. Therefore, in order to improve the performance of the new model that introduces fast weights, we fix the number of hidden units to 50, set the number of updates s to the optimal value 7 verified in the previous article, the number of words that make up the sentence num_steps=10, and change the batch size to 10 ,20,30,50, record the model training effect, as shown in Table 3 below:

表3不同batch size下模型在第10个epoch时的复杂度对比Table 3 Comparison of the complexity of the model at the 10th epoch under different batch sizes

从表3可以看到,batch_size=20时模型收敛后的复杂度最低,在训练到第10个epoch时复杂度为45.139,在训练到第13个epoch时复杂度低至43.344,batch size等于10和30时,模型训练到第13个epoch的复杂度约为51。为了更直观的表示出不同batch size下复杂度的差异,我们对复杂度取以10为底的对数log(perplexity),对比不同batch size下快速权重模型的log(perplexity)差异,如图6所示。As can be seen from Table 3, when batch_size=20, the complexity of the model after convergence is the lowest. When training to the 10th epoch, the complexity is 45.139. When training to the 13th epoch, the complexity is as low as 43.344. The batch size is equal to 10 and 30, the complexity of model training to the 13th epoch is about 51. In order to more intuitively show the difference in complexity under different batch sizes, we take the logarithm log(perplexity) with the base 10 as the complexity, and compare the log(perplexity) difference of the fast weight model under different batch sizes, as shown in Figure 6 shown.

图6中,横坐标为训练epoch数,纵坐标为以10为底,复杂度的对数log(perplexity)。可以看到以每20个epoch为一批数据进行训练时模型的复杂度最低,学习效果最佳,并进行语言模型的对比。In Figure 6, the abscissa is the number of training epochs, and the ordinate is the logarithm log(perplexity) of the complexity based on 10. It can be seen that when every 20 epochs are used as a batch of data for training, the complexity of the model is the lowest, and the learning effect is the best, and the language model is compared.

固定隐藏单元数量为50,组成句子的单词数num_steps=10,使用SeLU函数作为激活函数。对比LSTM模型,RNN模型,快速权重与LSTM网络结合的模型和快速权重与RNN结合的模型共四个模型的训练效果。模型训练复杂度如表4所示:The number of fixed hidden units is 50, the number of words that make up a sentence is num_steps=10, and the SeLU function is used as the activation function. Compare the training effects of the LSTM model, the RNN model, the model combining fast weight and LSTM network, and the model combining fast weight and RNN. The model training complexity is shown in Table 4:

表4:不同模型基于europarl_raw数据库的训练复杂度对比Table 4: Comparison of training complexity of different models based on europarl_raw database

模型名称model name 复杂度-5Complexity -5 复杂度-10Complexity -10 复杂度-15Complexity -15 复杂度-20Complexity - 20 LSTMLSTMs 267.602267.602 178.175178.175 174.935174.935 174.824174.824 LSTM+FWLSTM+FW 90.94590.945 45.13945.139 43.28043.280 43.20843.208 RNNRNN 1037.7191037.719 421.531421.531 412.841412.841 412.510412.510 RNN+FWRNN+FW 533.806533.806 378.564378.564 369.842369.842 369.474369.474

从表4能够看到,引入快速权重的LSTM模型在训练到第5个epoch时复杂度为90.945,在第10个epoch时进一步降低至45.139.训练到第15个epoch时模型复杂度为43.280,模型达到收敛。同样训练到第15个epoch时,LSTM模型的复杂度收敛至174.824,比引入快速权重的LSTM模型高出131,其次是引入快速权重的RNN网络,复杂度收敛于369.474,效果最差的是RNN模型,复杂度收敛于412.510。As can be seen from Table 4, the complexity of the LSTM model that introduces fast weights is 90.945 when it is trained to the 5th epoch, and further reduces to 45.139 at the 10th epoch. When it is trained to the 15th epoch, the model complexity is 43.280, The model reaches convergence. Similarly, when training to the 15th epoch, the complexity of the LSTM model converges to 174.824, which is 131 higher than the LSTM model with fast weights, followed by the RNN network with fast weights, the complexity converges to 369.474, and the worst effect is RNN model, the complexity converges to 412.510.

为了更直观的表示不同模型的复杂度差异,将复杂度取以10为底的对数,对比不同模型的log(perplexity)差异,如图7所示。In order to express the complexity difference of different models more intuitively, the complexity is taken as the logarithm to the base 10, and the log (perplexity) difference of different models is compared, as shown in Figure 7.

从图7能够看出,引入了快速权重的LSTM模型收敛后的复杂度最低,模型学习效果最好,且与未引入快速权重的LSTM模型差异很大,说明在LSTM网络中引入快速权重,模型训练效果提升明显。RNN模型收敛后的复杂度最高,加入快速权重后的RNN模型训练效果略有提升,但效果不明显。It can be seen from Figure 7 that the LSTM model with fast weights introduced has the lowest complexity after convergence, and the model learning effect is the best, and it is very different from the LSTM model without fast weights, which shows that the introduction of fast weights in the LSTM network, the model The training effect is significantly improved. The complexity of the RNN model after convergence is the highest, and the training effect of the RNN model after adding fast weights is slightly improved, but the effect is not obvious.

以上所述为本发明的较佳实施例而已,本发明不应该局限于该实施例和附图所公开的内容。凡是不脱离本发明所公开的精神下完成的等效或修改,都落入本发明保护的范围。The above description is only a preferred embodiment of the present invention, and the present invention should not be limited to the content disclosed in this embodiment and the accompanying drawings. All equivalents or modifications accomplished without departing from the disclosed spirit of the present invention fall within the protection scope of the present invention.

Claims (9)

1.基于FW机制及LSTM的递归网络模型,其特征在于:包括数据导入模块、数据生成模块、加载与迭代模块、参数设定模块、模型定义模块、递归网络训练模块、递归网络评估模块以及递归网络测试模块;1. The recursive network model based on the FW mechanism and LSTM is characterized in that it includes a data import module, a data generation module, a loading and iteration module, a parameter setting module, a model definition module, a recursive network training module, a recursive network evaluation module and a recursive network model. Network test module; 其中,数据生成模块又包括数据拆分单元;加载与迭代模块包括数据加载单元和迭代单元;Among them, the data generation module includes a data splitting unit; the loading and iteration module includes a data loading unit and an iteration unit; 数据拆分单元包括训练数据生成单元、评估数据生成单元以及测试数据生成单元;The data splitting unit includes a training data generating unit, an evaluation data generating unit and a test data generating unit; 递归网络训练模块包括dropout单元、更新单元和结果储存单元;递归网络评估模块以及递归网络测试模块仅包括更新单元和结果储存单元;The recursive network training module includes a dropout unit, an update unit and a result storage unit; the recursive network evaluation module and the recursive network test module only include an update unit and a result storage unit; 其中,更新单元包括长短时记忆单元和快速权重单元;Wherein, the update unit includes a long short-term memory unit and a fast weight unit; 所述基于FW机制及LSTM的递归网络模型中各模块的连接关系如下:The connection relationship of each module in the recursive network model based on FW mechanism and LSTM is as follows: 数据导入模块与数据生成模块相连,数据生成模块和加载与迭代模块相连,参数设定模块和加载与迭代模块以及模型定义模块相连,递归网络训练模块和加载与迭代模块、递归网络评估模块以及模型定义模块相连,递归网络评估模块与加载与迭代模块、递归网络训练模块、递归网络测试模块以及模型定义模块相连;递归网络测试模块与加载与迭代模块、递归网络评估模块和模型定义模块相连;The data import module is connected to the data generation module, the data generation module is connected to the loading and iteration module, the parameter setting module is connected to the loading and iteration module and the model definition module, the recursive network training module is connected to the loading and iteration module, the recursive network evaluation module and the model The definition module is connected, the recursive network evaluation module is connected with the loading and iteration module, the recursive network training module, the recursive network testing module and the model definition module; the recursive network testing module is connected with the loading and iteration module, the recursive network evaluation module and the model definition module; 数据生成模块中各单元的连接关系如下:数据拆分单元中训练数据、评估数据和测试数据分别与训练标签生成单元、评估标签生成单元以及测试标签生成单元相连;The connection relationship of each unit in the data generation module is as follows: the training data, evaluation data and test data in the data splitting unit are respectively connected to the training label generation unit, evaluation label generation unit and test label generation unit; 加载与迭代模块中各单元的连接关系如下:数据加载单元和迭代单元相连;Loading and the connection relation of each unit in the iterative module are as follows: the data loading unit is connected with the iterative unit; 所述基于FW机制及LSTM的递归网络模型中各模块的信号产生及输出关系如下:The signal generation and output relationship of each module in the recursive network model based on the FW mechanism and LSTM is as follows: 数据导入模块的输出接入数据生成模块;数据生成模块处理后接入加载与迭代模块;参数设定模块为加载与迭代模块和模型定义模块提供输入参数及FW模型参数;加载与迭代模块分别为递归网络训练模块、递归网络评估模块和递归网络测试模块提供训练数据和训练标签、评估数据和评估标签以及测试数据和测试标签;模型定义模块将FW模型参数分别输入递归网络训练模块、递归网络评估模块以及递归网络测试模块;递归网络训练模块将训练好的网络参数送入递归网络评估模块;递归网络评估模块将评估后的网络参数送入递归网络测试模块;The output of the data import module is connected to the data generation module; the data generation module is connected to the loading and iteration module after processing; the parameter setting module provides input parameters and FW model parameters for the loading and iteration module and the model definition module; the loading and iteration modules are respectively The recursive network training module, the recursive network evaluation module and the recursive network testing module provide training data and training labels, evaluation data and evaluation labels, and test data and test labels; the model definition module inputs the FW model parameters into the recurrent network training module and the recursive network evaluation module respectively. module and the recursive network testing module; the recursive network training module sends the trained network parameters into the recursive network evaluation module; the recursive network evaluation module sends the evaluated network parameters into the recursive network testing module; 递归网络训练模块、评估模块以及测试模块中的各单元连接关系如下:The connection relationship of each unit in the recurrent network training module, evaluation module and test module is as follows: dropout单元接收数据并与长短时记忆单元相连,长短时记忆单元与数据输入和快速权重单元相连,结果储存单元与快速权重单元和结果相连。The dropout unit receives data and is connected with the long-short-term memory unit, the long-short-term memory unit is connected with the data input and the fast weight unit, and the result storage unit is connected with the fast weight unit and the result. 2.如权利要求1所述的基于FW机制及LSTM的递归网络模型依托的学习方法,其特征在于:包括如下步骤:2. the learning method based on the recursive network model of FW mechanism and LSTM as claimed in claim 1, is characterized in that: comprise the steps: 步骤一、待训练和测试的数据经数据导入模块导入,具体为:Step 1. The data to be trained and tested is imported through the data import module, specifically: 步骤二、数据生成模块对经数据导入模块导入的数据经数据拆分单元进行拆分,分别得到训练数据、评估数据和测试数据;Step 2, the data generation module splits the data imported by the data import module through the data splitting unit to obtain training data, evaluation data and test data respectively; 步骤三、训练数据生成单元随机选取x%比例经数据拆分单元拆分后的数据生成训练集;评估数据生成单元随机选取y%比例经数据拆分单元拆分后的数据生成评估集;测试数据生成单元随机选取z%比例经数据拆分单元拆分后的数据生成测试集;Step 3, the training data generation unit randomly selects the x% ratio of the data split by the data splitting unit to generate a training set; the evaluation data generation unit randomly selects the y% ratio of the data split by the data splitting unit to generate an evaluation set; test The data generating unit randomly selects the z% ratio to generate a test set from the data split by the data splitting unit; 步骤四、训练标签生成单元将训练数据生成单元生成的训练集中每一个数据后移一位得到训练标签;评估标签生成单元将评估数据生成单元生成的评估集中每一个数据后移一位得到评估标签;测试标签生成单元将测试集中每一个数据后移一位得到测试标签;Step 4, the training label generating unit shifts each data in the training set generated by the training data generating unit by one bit to obtain the training label; the evaluation label generating unit shifts each data in the evaluation set generated by the evaluation data generating unit by one bit to obtain the evaluation label ; The test label generation unit shifts each data in the test set by one bit to obtain the test label; 步骤五、参数设定模块根据数据导入模块导入文本的模型规模,获取配置参数,再将获取的配置参数输入参数设定模块;Step 5, the parameter setting module obtains configuration parameters according to the model scale of the text imported by the data import module, and then inputs the obtained configuration parameters into the parameter setting module; 其中,配置参数包括初始规模、学习率、最大梯度正则值、层数、步数、隐藏层大小、最大epoch数、极大epoch值、dropout率、衰减率、批大小以及vocab大小;Among them, the configuration parameters include initial scale, learning rate, maximum gradient regular value, number of layers, number of steps, hidden layer size, maximum epoch number, maximum epoch value, dropout rate, decay rate, batch size, and vocab size; 步骤六、加载与迭代模块中的数据加载单元按照参数设定模块中获取的配置参数加载训练集、评估集以及测试集中的数据,并设定初始化数据序号i为1;Step 6. The data loading unit in the loading and iteration module loads the data in the training set, evaluation set and test set according to the configuration parameters obtained in the parameter setting module, and sets the initialization data sequence number i to 1; 步骤七、模型定义模块根据参数设定模块中的配置参数,使用伪随机函数在配置范围内生成随机值作为权重矩阵参数,完成权重参数的初始化;Step 7, the model definition module uses the pseudo-random function to generate random values within the configuration range as the weight matrix parameters according to the configuration parameters in the parameter setting module, and completes the initialization of the weight parameters; 步骤八、加载与迭代模块中迭代模块判断当前数据集中的数据是否发送完毕,并依据判断结果进行操作,具体为:Step 8. The iteration module in the loading and iteration module judges whether the data in the current data set has been sent, and operates according to the judgment result, specifically: 若当前数据集中的数据没发送完,则发送第i组数据,判断进行训练、评估还是测试,并跳至步骤九,跳至步骤八;否则停止迭代;If the data in the current data set has not been sent, send the i-th group of data, judge whether to train, evaluate or test, and skip to step 9, and skip to step 8; otherwise, stop the iteration; 步骤九、判断当前数据是否是训练数据,若是则依据dropout率对输入数据进行抽取,抽取后数据,跳至步骤十;否则,跳至步骤十;Step 9. Determine whether the current data is training data. If so, extract the input data according to the dropout rate. After the extracted data, skip to step 10; otherwise, skip to step 10; 步骤十、将步骤九输入的数据送入更新单元中的长短时记忆单元以及快速权重单元计算得到输出向量,同时利用梯度下降法对网络进行优化,具体为:Step 10. Send the data input in step 9 to the long-short-term memory unit and the fast weight unit in the update unit to calculate the output vector, and at the same time use the gradient descent method to optimize the network, specifically: 步骤10.1更新单元基于输入层权重Wx、标准权重Wh计算起始隐藏层状态,通过公式(4)计算当前t时刻的初始隐藏状态:Step 10.1 The update unit calculates the initial hidden layer state based on the input layer weight Wx and the standard weight Wh, and calculates the initial hidden state at the current time t by formula (4): h0 t=f(LN(Wx*xt+ Wh*ht-1)) (4)h 0 t =f(LN(W x *x t + W h *h t-1 )) (4) 其中,输入层权重记为Wx、标准权重记为Wh;h0为起始隐藏层状态,LN为层正则化函数;f为激活函数;xt为当前t时刻的输入层数据;ht-1为当前时刻的前一时刻,即t-1时刻,隐藏层状态对应的数据,简称隐藏层状态;Among them, the input layer weight is recorded as Wx, and the standard weight is recorded as Wh; h 0 is the initial hidden layer state, LN is the layer regularization function; f is the activation function; x t is the input layer data at the current time t; h t- 1 is the data corresponding to the state of the hidden layer at the moment before the current moment, that is, time t-1, referred to as the state of the hidden layer; 标准权重Wh是RNN网络中隐藏层向下一个时间步进传播的权重;输入层权重Wx是输入层到隐藏层传播的权重;The standard weight Wh is the weight propagated from the hidden layer to the next time step in the RNN network; the input layer weight Wx is the weight propagated from the input layer to the hidden layer; 步骤10.2、快速权重单元计算快速权重,具体通过公式(5)计算:Step 10.2, the fast weight unit calculates the fast weight, specifically calculated by formula (5): WA(t)=λWA(t-1)+ηht-1hT t-1 (5)W A (t)=λW A (t-1)+ηh t-1 h T t-1 (5) 其中,WA(t)是第t时刻的快速权重,是仅作用在隐藏层每个时间步进内的权重;一个时间步进更新的总次数,记为s+1次;λ是衰减率、η是学习率、ht-1为t-1时刻对应的隐藏层状态;hT t-1是ht-1即t-1时刻对应隐藏层状态的转置;Among them, W A (t) is the fast weight at the tth moment, which is the weight that only acts on the hidden layer in each time step; the total number of updates in a time step is recorded as s+1 times; λ is the decay rate , η is the learning rate, h t-1 is the hidden layer state corresponding to the t-1 moment; h T t-1 is h t-1 , that is, the transposition of the hidden layer state corresponding to the t-1 moment; 其中,时间步进更新的总次数s+1中的s即步数;Among them, s in the total number of times of time step update s+1 is the number of steps; 步骤10.3、快速权重单元计算隐藏层状态并更新s次隐藏层状态;Step 10.3, the fast weight unit calculates the state of the hidden layer and updates the state of the hidden layer for s times; 步骤10.4、慢速权重单元计算归一化输出;Step 10.4, the slow weight unit calculates the normalized output; 其中,网络的归一化输出通过Softmax或sigmoid函数两者之一实现;Among them, the normalized output of the network is realized by either Softmax or sigmoid function; 步骤10.5、结果存储单元计算基于步骤10.4计算出的归一化输出计算损失loss和复杂度perplexity;Step 10.5, the calculation of the result storage unit is based on the normalized output calculated in step 10.4 to calculate loss and complexity perplexity; 步骤10.6、慢速权重单元判断是否达到最后一个Epoch,如果没有达到,则更新单元则更新隐藏层状态以及训练参数或测试参数,将当前i加1,跳至步骤八。Step 10.6. The slow weight unit judges whether the last Epoch has been reached. If not, the update unit updates the state of the hidden layer and the training parameters or test parameters, adds 1 to the current i, and skips to step 8. 3.如权利要求2所述的基于FW机制及LSTM的递归网络模型依托的学习方法,其特征在于:步骤一通过读取文本路径,获取文本数据。3. The learning method based on the recursive network model of FW mechanism and LSTM as claimed in claim 2, characterized in that: Step 1 obtains the text data by reading the text path. 4.如权利要求2所述的基于FW机制及LSTM的递归网络模型依托的学习方法,其特征在于:步骤二中,拆分具体为:将步骤一导入的文本数据按照每j个字符为一句话进行拆分。4. the learning method based on the recursive network model of FW mechanism and LSTM as claimed in claim 2 is characterized in that: in step 2, splitting is specifically: the text data that step 1 imports is according to every j character is a sentence Words are split. 5.如权利要求4所述的基于FW机制及LSTM的递归网络模型依托的学习方法,其特征在于:j的取值范围为5到50。5. The learning method based on FW mechanism and LSTM recursive network model as claimed in claim 4, characterized in that: the value range of j is 5 to 50. 6.如权利要求2所述的基于FW机制及LSTM的递归网络模型依托的学习方法,其特征在于:步骤三中,x%+y%+z%=1。6. The learning method based on FW mechanism and LSTM recursive network model as claimed in claim 2, characterized in that: in step 3, x%+y%+z%=1. 7.如权利要求2所述的基于FW机制及LSTM的递归网络模型依托的学习方法,其特征在于:步骤10.1中激活函数f为SeLU函数、Leaky Relu函数以及Swish函数中的一种。7. The learning method based on the recursive network model of FW mechanism and LSTM as claimed in claim 2, characterized in that: in step 10.1, the activation function f is one of SeLU function, Leaky Relu function and Swish function. 8.如权利要求2所述的基于FW机制及LSTM的递归网络模型依托的学习方法,其特征在于:步骤10.2中,衰减率的取值范围为0.9到0.995。8. The learning method based on FW mechanism and LSTM recursive network model as claimed in claim 2, characterized in that: in step 10.2, the decay rate ranges from 0.9 to 0.995. 9.如权利要求2所述的基于FW机制及LSTM的递归网络模型依托的学习方法,其特征在于:步骤10.2中,学习率的取值范围为0.3到0.8。9. The learning method based on FW mechanism and LSTM recursive network model as claimed in claim 2, characterized in that: in step 10.2, the learning rate ranges from 0.3 to 0.8.
CN201910476156.XA 2019-06-03 2019-06-03 A recurrent network model and learning method based on FW mechanism and LSTM Pending CN110288081A (en)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN201910476156.XA CN110288081A (en) 2019-06-03 2019-06-03 A recurrent network model and learning method based on FW mechanism and LSTM

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN201910476156.XA CN110288081A (en) 2019-06-03 2019-06-03 A recurrent network model and learning method based on FW mechanism and LSTM

Publications (1)

Publication Number Publication Date
CN110288081A true CN110288081A (en) 2019-09-27

Family

ID=68003232

Family Applications (1)

Application Number Title Priority Date Filing Date
CN201910476156.XA Pending CN110288081A (en) 2019-06-03 2019-06-03 A recurrent network model and learning method based on FW mechanism and LSTM

Country Status (1)

Country Link
CN (1) CN110288081A (en)

Citations (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
WO2018151125A1 (en) * 2017-02-15 2018-08-23 日本電信電話株式会社 Word vectorization model learning device, word vectorization device, speech synthesis device, method for said devices, and program
CN109214452A (en) * 2018-08-29 2019-01-15 杭州电子科技大学 Based on the HRRP target identification method for paying attention to depth bidirectional circulating neural network
US20190087709A1 (en) * 2016-04-29 2019-03-21 Cambricon Technologies Corporation Limited Apparatus and method for executing recurrent neural network and lstm computations
CN109508377A (en) * 2018-11-26 2019-03-22 南京云思创智信息科技有限公司 Text feature, device, chat robots and storage medium based on Fusion Model
US20190114544A1 (en) * 2017-10-16 2019-04-18 Illumina, Inc. Semi-Supervised Learning for Training an Ensemble of Deep Convolutional Neural Networks

Patent Citations (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US20190087709A1 (en) * 2016-04-29 2019-03-21 Cambricon Technologies Corporation Limited Apparatus and method for executing recurrent neural network and lstm computations
WO2018151125A1 (en) * 2017-02-15 2018-08-23 日本電信電話株式会社 Word vectorization model learning device, word vectorization device, speech synthesis device, method for said devices, and program
US20190114544A1 (en) * 2017-10-16 2019-04-18 Illumina, Inc. Semi-Supervised Learning for Training an Ensemble of Deep Convolutional Neural Networks
CN109214452A (en) * 2018-08-29 2019-01-15 杭州电子科技大学 Based on the HRRP target identification method for paying attention to depth bidirectional circulating neural network
CN109508377A (en) * 2018-11-26 2019-03-22 南京云思创智信息科技有限公司 Text feature, device, chat robots and storage medium based on Fusion Model

Non-Patent Citations (1)

* Cited by examiner, † Cited by third party
Title
T. ANDERSON KELLER ET AL.: "FAST WEIGHT LONG SHORT-TERM MEMORY", 《ICLR 2018》 *

Similar Documents

Publication Publication Date Title
CN107358948B (en) An attention model-based approach to language input relevance detection
US11170158B2 (en) Abstractive summarization of long documents using deep learning
Merity et al. An analysis of neural language modeling at multiple scales
CN109635109B (en) Sentence classification method based on LSTM combined with part of speech and multi-attention mechanism
CN108268643A (en) A kind of Deep Semantics matching entities link method based on more granularity LSTM networks
CN111400470A (en) Question processing method and device, computer equipment and storage medium
CN109285562A (en) Speech emotion recognition method based on attention mechanism
CN111125333B (en) A Generative Question Answering Method Based on Representation Learning and Multilayer Covering Mechanism
CN110288029B (en) Image description method based on Tri-LSTMs model
CN108520298A (en) A Semantic Consistency Verification Method for Land and Air Conversation Based on Improved LSTM-RNN
CN108879732B (en) Power system transient stability assessment method and device
Shi et al. The prediction of character based on recurrent neural network language model
CN111353040A (en) GRU-based attribute level emotion analysis method
CN111782799A (en) An Enhanced Text Summarization Generation Method Based on Replication Mechanism and Variational Neural Inference
CN115600602A (en) Method, system and terminal device for extracting key elements of long text
CN113780346B (en) Priori constraint classifier adjustment method, system and readable storage medium
Zhang et al. Evaluation of judicial imprisonment term prediction model based on text mutation
CN116579350B (en) Robustness analysis method and device for dialogue understanding model and computer equipment
CN114036938A (en) News classification method for extracting text features by fusing topic information and word vectors
CN110288081A (en) A recurrent network model and learning method based on FW mechanism and LSTM
EP4293956A1 (en) Method for predicting malicious domains
CN116318845A (en) DGA domain name detection method under unbalanced proportion condition of positive and negative samples
CN112884019B (en) An image-to-language method based on the fusion gate recurrent network model
CN116049405A (en) Language characterization model pre-training method based on generator-discriminant architecture
CN118132710B (en) Dialogue-level sentiment analysis method based on multi-scale sliding window and dynamic aggregation

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
WD01 Invention patent application deemed withdrawn after publication
WD01 Invention patent application deemed withdrawn after publication

Application publication date: 20190927