[go: up one dir, main page]

CN118072099A - Class increment learning method based on joint distillation playback strategy - Google Patents

Class increment learning method based on joint distillation playback strategy Download PDF

Info

Publication number
CN118072099A
CN118072099A CN202410269441.5A CN202410269441A CN118072099A CN 118072099 A CN118072099 A CN 118072099A CN 202410269441 A CN202410269441 A CN 202410269441A CN 118072099 A CN118072099 A CN 118072099A
Authority
CN
China
Prior art keywords
distillation
data
task
joint
distilled
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202410269441.5A
Other languages
Chinese (zh)
Inventor
莫建文
左丽芳
欧阳宁
林乐平
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Guilin University of Electronic Technology
Original Assignee
Guilin University of Electronic Technology
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Guilin University of Electronic Technology filed Critical Guilin University of Electronic Technology
Priority to CN202410269441.5A priority Critical patent/CN118072099A/en
Publication of CN118072099A publication Critical patent/CN118072099A/en
Pending legal-status Critical Current

Links

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/764Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/74Image or video pattern matching; Proximity measures in feature spaces
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/77Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
    • G06V10/774Generating sets of training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/82Arrangements for image or video recognition or understanding using pattern recognition or machine learning using neural networks
    • YGENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
    • Y02TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
    • Y02TCLIMATE CHANGE MITIGATION TECHNOLOGIES RELATED TO TRANSPORTATION
    • Y02T10/00Road transport of goods or passengers
    • Y02T10/10Internal combustion engine [ICE] based vehicles
    • Y02T10/40Engine management systems

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Evolutionary Computation (AREA)
  • Databases & Information Systems (AREA)
  • Computing Systems (AREA)
  • Artificial Intelligence (AREA)
  • General Health & Medical Sciences (AREA)
  • Medical Informatics (AREA)
  • Software Systems (AREA)
  • Health & Medical Sciences (AREA)
  • Physics & Mathematics (AREA)
  • General Physics & Mathematics (AREA)
  • Multimedia (AREA)
  • Image Analysis (AREA)

Abstract

The invention discloses a class increment learning method based on a joint distillation playback strategy, which comprises the following steps: 1) Training a distillation model of an initial task; 2) Setting a buffer area to store distilled data of an initial task; 3) Establishing a combined distillation model based on distillation data of an old task and original data of a new task; 4) The co-distilled data is played back during the next task training. According to the technical scheme, the distillation data of the old task and the original data of the new task are utilized to carry out combined distillation, and category characteristic information of the old task is fully considered, so that characteristic ambiguity is avoided. On the basis, the same learning rate is used in the distillation process of new and old tasks, and an optimal buffer area size is searched for storing distilled data for playback, so that catastrophic forgetting in a class increment learning scene is effectively relieved, and the classification effect of images is improved.

Description

基于联合蒸馏回放策略的类增量学习方法Class-incremental learning method based on joint distillation and replay strategy

技术领域Technical Field

本发明涉及图像分类技术领域,具体涉及一种基于联合蒸馏回放策略的类增量学习方法。The present invention relates to the technical field of image classification, and in particular to a class incremental learning method based on a joint distillation replay strategy.

背景技术Background technique

近年来,类增量学习逐渐成为深度学习领域内的研究热点之一。增量学习有三种基本场景:任务增量学习、域增量学习和类增量学习。在任务增量学习场景中,模型需要逐步学习一组可明确区分的任务;在域增量学习场景中,模型需要在不同环境中学习同一问题;在类增量学习场景中,模型需要区分所有类。由于现实世界的场景往往会随着时间的推移产生新的类,类增量学习相比任务增量学习和域增量学习更接近现实世界的应用场景,也更具有挑战性。然而,随着新类别的不断增加,模型必须不断更新以纳入新任务的知识,而这一学习过程可能会破坏甚至完全覆盖旧任务的知识,从而导致灾难性遗忘问题。因此,解决类增量学习中的灾难性遗忘问题尤为重要。In recent years, incremental learning has gradually become one of the research hotspots in the field of deep learning. There are three basic scenarios for incremental learning: task incremental learning, domain incremental learning, and class incremental learning. In the task incremental learning scenario, the model needs to gradually learn a set of clearly distinguishable tasks; in the domain incremental learning scenario, the model needs to learn the same problem in different environments; in the class incremental learning scenario, the model needs to distinguish all classes. Since real-world scenarios often generate new classes over time, class incremental learning is closer to real-world application scenarios and more challenging than task incremental learning and domain incremental learning. However, as new categories continue to increase, the model must be continuously updated to incorporate the knowledge of new tasks, and this learning process may destroy or even completely overwrite the knowledge of old tasks, leading to catastrophic forgetting problems. Therefore, it is particularly important to solve the catastrophic forgetting problem in incremental learning.

目前的研究工作中,减轻类增量学习中灾难性遗忘的方法大致可分为三大类:基于正则化的方法、基于偏差校正的方法以及基于回放的方法。基于正则化的方法通过在新任务的学习中增加正则化项对旧任务中重要的参数施加约束,平衡新旧任务并且限制旧知识的遗忘。然而,实际应用中很难设计一个合理的度量来估计神经网络中每个参数的重要性。基于偏差校正的方法的出发点是解决增量学习中的偏差问题,避免增量学习网络偏向于新任务的类别,但校正过程中可能会引入新的偏差。In current research, methods to mitigate catastrophic forgetting in incremental learning can be roughly divided into three categories: regularization-based methods, bias correction-based methods, and replay-based methods. Regularization-based methods impose constraints on important parameters in old tasks by adding regularization terms in the learning of new tasks, balancing new and old tasks and limiting the forgetting of old knowledge. However, in practical applications, it is difficult to design a reasonable metric to estimate the importance of each parameter in a neural network. The starting point of the bias correction-based method is to solve the bias problem in incremental learning and prevent the incremental learning network from being biased towards the category of the new task, but new biases may be introduced during the correction process.

基于回放的方法根据回放内容的不同可分为两类:特征回放和数据回放。基于特征回放的类增量学习方法通过保存和回放深度特征空间的旧类别样本特征或者旧类别原型向量来减轻遗忘。基于数据回放的类增量学习方法通过保存和回放旧任务的范例、生成数据来减轻遗忘。随着数据集蒸馏技术的发展,数据集蒸馏技术可以与类增量学习场景相结合。数据集蒸馏技术的核心思想是优化合成数据集,将旧任务训练数据集的知识压缩成一个小的合成蒸馏数据集,在学习新任务时回放旧任务的合成蒸馏数据集。然而,当数据集蒸馏技术单独应用于每个任务时,合成图像没有考虑到旧任务数据集的类别特征,这可能会出现特征模糊的现象,尤其是当新任务的样本与旧任务的样本类别特征相似时,会加重特征模糊的情况,从而增加对这些类别的遗忘。Replay-based methods can be divided into two categories according to the different replay contents: feature replay and data replay. Feature replay-based incremental learning methods mitigate forgetting by saving and replaying old category sample features or old category prototype vectors in the deep feature space. Data replay-based incremental learning methods mitigate forgetting by saving and replaying examples of old tasks and generating data. With the development of dataset distillation technology, dataset distillation technology can be combined with incremental learning scenarios. The core idea of dataset distillation technology is to optimize the synthetic dataset, compress the knowledge of the old task training dataset into a small synthetic distilled dataset, and replay the synthetic distilled dataset of the old task when learning the new task. However, when the dataset distillation technology is applied to each task separately, the synthetic image does not take into account the category characteristics of the old task dataset, which may cause feature ambiguity, especially when the sample category characteristics of the new task are similar to those of the old task, which will aggravate the feature ambiguity and increase the forgetfulness of these categories.

所以,需要一种新的技术解决上述出现的问题。Therefore, a new technology is needed to solve the above problems.

发明内容Summary of the invention

本发明的目的是针对现有技术不足,而提供一种基于联合蒸馏回放策略的类增量学习方法。这种方法首先在新任务原始数据集的数据蒸馏过程中加入旧任务的蒸馏数据,这个联合蒸馏过程使得新旧任务样本之间的相似类别特征得到考虑。其次,保持新旧任务蒸馏过程中使用相同的学习率,并寻找一个最优的缓冲区规模以存储蒸馏数据。最后,在下一个任务的学习过程中回放存储在缓冲区中的联合蒸馏数据,实现类增量学习。The purpose of the present invention is to provide a class incremental learning method based on a joint distillation replay strategy to address the deficiencies of the prior art. This method first adds the distilled data of the old task to the data distillation process of the original data set of the new task. This joint distillation process allows similar category features between the new and old task samples to be considered. Secondly, the same learning rate is used in the distillation process of the new and old tasks, and an optimal buffer size is found to store the distilled data. Finally, the joint distilled data stored in the buffer is replayed during the learning process of the next task to achieve class incremental learning.

实现本发明目的的技术方案是:The technical solution for achieving the purpose of the present invention is:

基于联合蒸馏回放策略的类增量学习方法,与现有的技术不同的是,包括如下步骤:The incremental learning method based on the joint distillation replay strategy is different from the existing technology in that it includes the following steps:

1)训练初始任务的蒸馏模型:给定初始任务的训练数据集初始任务的蒸馏模型参数θ,l(xi,θ)为数据xi在蒸馏模型上的损失函数,蒸馏模型优化的目标是通过训练获得一个θ*使得模型在整个数据集上的损失最小,定义如下:1) Train the distillation model for the initial task: Given the training dataset for the initial task The distillation model parameter θ of the initial task, l( xi , θ) is the loss function of the data xi on the distillation model. The goal of distillation model optimization is to obtain a θ * through training so that the loss of the model on the entire data set is minimized, which is defined as follows:

假设蒸馏模型随机初始化参数为θ0,θ0满足p(θ0)的分布,使用标准的随机梯度下降来优化更新蒸馏数据和和学习率/>从而更新初始任务的蒸馏模型参数,假设现在进行第k次蒸馏模型参数更新,则:Assume that the distillation model is randomly initialized with parameters θ 0 , θ 0 satisfies the distribution of p(θ 0 ), and use standard stochastic gradient descent to optimize and update the distillation data. and learning rate/> Thereby updating the distillation model parameters of the initial task. Assuming that the k-th distillation model parameter update is now performed, then:

此时,初始任务的蒸馏模型的优化目标变为最小化损失函数从而得到初始任务的蒸馏数据:At this point, the optimization objective of the distillation model of the initial task becomes to minimize the loss function Thus, the distilled data of the initial task is obtained:

2)设置一个缓冲区存储初始任务的蒸馏数据:初始化缓冲区将步骤1)中的所述初始任务的蒸馏数据/>存储在缓冲区中,则:2) Set up a buffer to store the distillation data of the initial task: Initialize the buffer The distillation data of the initial task in step 1) Stored in a buffer, then:

3)建立基于旧任务的蒸馏数据与新任务的原始数据的联合蒸馏模型:假设在类增量学习场景中,神经网络需要持续学习T个任务,除了对初始任务进行简单的数据集蒸馏之外,对剩余的每个任务都进行联合蒸馏,并将获得的联合蒸馏数据存储在步骤2)中的所述缓冲区中,为了使模型在类增量学习场景中能持续学习,联合蒸馏过程中需要使用固定的学习率η,过程如下公式所示:3) Establish a joint distillation model based on the distilled data of the old task and the original data of the new task: Assuming that in a class incremental learning scenario, the neural network needs to continuously learn T tasks. In addition to a simple data set distillation for the initial task, each of the remaining tasks is jointly distilled, and the obtained joint distillation data is stored in the buffer in step 2). In order to enable the model to continue learning in the class incremental learning scenario, a fixed learning rate η needs to be used in the joint distillation process. The process is shown in the following formula:

其中,和/>分别表示前一任务的蒸馏数据和当前任务的蒸馏数据,/>表示当前任务的原始训练数据集,/>表示前一任务的蒸馏数据和当前任务的原始训练数据集在蒸馏模型参数下的损失;in, and/> Respectively represent the distilled data of the previous task and the distilled data of the current task, /> Represents the original training dataset of the current task, /> Represents the loss of the distilled data of the previous task and the original training data set of the current task under the parameters of the distilled model;

4)在下一个任务训练过程中回放联合蒸馏数据:随着下一个任务数据流的到来,为了不遗忘旧任务的知识,从缓冲区取出步骤3)中的所述旧任务的联合蒸馏数据,将联合蒸馏数据和新任务的原始训练数据一起输入到分类网络中,训练的过程可以表示为:4) Replay the joint distillation data during the next task training process: With the arrival of the next task data stream, in order not to forget the knowledge of the old task, the joint distillation data of the old task in step 3) is taken out from the buffer, and the joint distillation data and the original training data of the new task are input into the classification network together. The training process can be expressed as:

其中,表示每一步s从训练数据集/>中选取的一个批量数据,/>表示分类网络的参数,/>是从缓冲区取出的旧任务的联合蒸馏数据集,包括分类网络已学习到的所有类别的蒸馏数据。in, Indicates that each step s is from the training data set/> A batch of data selected from / > Represents the parameters of the classification network,/> It is a joint distilled dataset of old tasks taken from the buffer, including the distilled data of all categories that the classification network has learned.

步骤2)中所述缓冲区存储过程为:在联合蒸馏过程结束后,将旧任务的联合蒸馏数据存储在所设置好的缓冲区中,并且为了使模型能更充分地学习到旧类别的特征信息,缓冲区规模大小设置为100,即联合蒸馏后每个旧类保存100张蒸馏图像,减少对旧类的遗忘,缓冲区中存储的蒸馏数据可以表示为:The buffer storage process in step 2) is as follows: after the joint distillation process is completed, the joint distillation data of the old task is stored in the set buffer, and in order to enable the model to learn the feature information of the old category more fully, the buffer size is set to 100, that is, after the joint distillation, 100 distilled images are saved for each old category to reduce the forgetting of the old category. The distillation data stored in the buffer can be expressed as:

其中,表示对初始任务进行简单的数据集蒸馏后得到的蒸馏数据,表示对旧任务的蒸馏数据与新任务的原始数据联合蒸馏后得到的联合蒸馏数据。in, Represents the distilled data obtained after a simple data set distillation of the initial task. Represents the joint distilled data obtained by jointly distilling the distilled data of the old task and the original data of the new task.

步骤3)中所述联合蒸馏过程具体为:初始化当前任务的蒸馏数和蒸馏模型参数,在优化器每次迭代时选取一批数据/>接着用蒸馏数据/>对联合蒸馏模型参数进行K=1,2,…,k次梯度下降更新,然后计算批数据xq在蒸馏模型参数下的损失并更新蒸馏数据:The joint distillation process in step 3) is as follows: Initialize the distillation number of the current task and distillation model parameters, selecting a batch of data at each iteration of the optimizer/> Then use the distillation data/> Perform K=1, 2, …, k gradient descent updates on the joint distillation model parameters, then calculate the loss of the batch data xq under the distillation model parameters and update the distillation data:

其中,p(θ0)为初始化权重的分布,批数据xq的批次大小为n,优化器迭代Q次。Where p(θ 0 ) is the distribution of the initialization weights, the batch size of the batch data x q is n, and the optimizer iterates Q times.

步骤4)中所述联合蒸馏回放过程为:当分类网络学习新任务时,回放旧任务的联合蒸馏数据,即当前输入到分类网络中的训练数据集为则基于联合蒸馏回放策略的类增量学习过程可以表示为:The joint distillation playback process in step 4) is as follows: when the classification network learns a new task, the joint distillation data of the old task is played back, that is, the training data set currently input to the classification network is Then the class incremental learning process based on the joint distillation replay strategy can be expressed as:

因此,在所有任务学习结束后,分类网络参数更新为缓冲区更新为/> Therefore, after all tasks are learned, the classification network parameters are updated to The buffer is updated as />

本技术方案使用联合蒸馏的策略更新旧任务的蒸馏数据,可以避免蒸馏过程中的特征模糊;在联合蒸馏的过程中保持新旧任务使用相同的学习率,从而保证类增量学习的有效性。This technical solution uses a joint distillation strategy to update the distillation data of the old task, which can avoid feature ambiguity in the distillation process; in the process of joint distillation, the new and old tasks are kept using the same learning rate, thereby ensuring the effectiveness of incremental learning.

这种方法不再是单独处理每个学习任务的蒸馏数据,而是在新任务原始数据集的数据蒸馏过程中加入旧任务的蒸馏数据。在此基础上,寻找一个最优的缓冲区规模以存储联合蒸馏数据,从而维持对历史任务的记忆。最后在新任务到来时,回放旧任务的联合蒸馏数据。所提的联合蒸馏回放策略能有效地缓解灾难性遗忘,提升在类增量学习场景的分类性能。This method no longer processes the distilled data of each learning task separately, but adds the distilled data of the old task to the data distillation process of the original data set of the new task. On this basis, an optimal buffer size is found to store the joint distilled data to maintain the memory of historical tasks. Finally, when a new task arrives, the joint distilled data of the old task is replayed. The proposed joint distillation replay strategy can effectively alleviate catastrophic forgetting and improve the classification performance in the incremental learning scenario.

附图说明BRIEF DESCRIPTION OF THE DRAWINGS

图1为实施例提供的一种基于联合蒸馏回放策略的类增量学习方法的流程示意图。FIG1 is a flow chart of a class incremental learning method based on a joint distillation replay strategy provided in an embodiment.

图2为实施例提供的一种联合蒸馏方法的流程示意图。FIG2 is a schematic flow diagram of a combined distillation method provided in an embodiment.

具体实施方式Detailed ways

下面结合附图和实施例对本发明的内容作进一步的阐述,但不是对本发明的限定。The content of the present invention is further described below in conjunction with the drawings and embodiments, but the present invention is not limited thereto.

实施例:Example:

参照图1.基于联合蒸馏回放策略的类增量学习方法,与现有的技术不同的是,包括如下步骤:Referring to Figure 1, the incremental learning method based on the joint distillation replay strategy is different from the existing technology in that it includes the following steps:

1)训练初始任务的蒸馏模型:给定初始任务的训练数据集初始任务的蒸馏模型参数θ,l(xi,θ)为数据xi在蒸馏模型上的损失函数,蒸馏模型优化的目标是通过训练获得一个θ*使得模型在整个数据集上的损失最小,定义如下:1) Train the distillation model for the initial task: Given the training dataset for the initial task The distillation model parameter θ of the initial task, l( xi , θ) is the loss function of the data xi on the distillation model. The goal of distillation model optimization is to obtain a θ * through training so that the loss of the model on the entire data set is minimized, which is defined as follows:

假设蒸馏模型随机初始化参数为θ0,θ0满足p(θ0)的分布,使用标准的随机梯度下降来优化更新蒸馏数据和和学习率/>从而更新初始任务的蒸馏模型参数,假设现在进行第k次蒸馏模型参数更新,则:Assume that the distillation model is randomly initialized with parameters θ 0 , θ 0 satisfies the distribution of p(θ 0 ), and use standard stochastic gradient descent to optimize and update the distillation data. and learning rate/> Thus, the distillation model parameters of the initial task are updated. Assuming that the k-th distillation model parameter update is now performed, then:

此时,初始任务的蒸馏模型的优化目标变为最小化损失函数从而得到初始任务的蒸馏数据:At this point, the optimization objective of the distillation model of the initial task becomes to minimize the loss function Thus, the distilled data of the initial task is obtained:

2)设置一个缓冲区存储初始任务的蒸馏数据:初始化缓冲区将步骤1)中的所述初始任务的蒸馏数据/>存储在缓冲区中,则:2) Set up a buffer to store the distillation data of the initial task: Initialize the buffer The distillation data of the initial task in step 1) Stored in a buffer, then:

3)建立基于旧任务的蒸馏数据与新任务的原始数据的联合蒸馏模型:参照图2,假设在类增量学习场景中,神经网络需要持续学习T个任务,除了对初始任务进行简单的数据集蒸馏之外,对剩余的每个任务t∈{2,...,T}都进行联合蒸馏,并将获得的联合蒸馏数据存储在步骤2)中的所述缓冲区中,为了使模型在类增量学习场景中能持续学习,联合蒸馏过程中需要使用固定的学习率η,过程如下公式所示:3) Establish a joint distillation model based on the distilled data of the old task and the original data of the new task: Referring to Figure 2, assuming that in a class incremental learning scenario, the neural network needs to continuously learn T tasks. In addition to a simple data set distillation for the initial task, each of the remaining tasks t∈{2,...,T} is jointly distilled, and the obtained joint distillation data is stored in the buffer in step 2). In order to enable the model to continue learning in a class incremental learning scenario, a fixed learning rate η needs to be used in the joint distillation process. The process is shown in the following formula:

其中,和/>分别表示前一任务的蒸馏数据和当前任务的蒸馏数据,/>表示当前任务的原始训练数据集,/>表示前一任务的蒸馏数据和当前任务的原始训练数据集在蒸馏模型参数下的损失;in, and/> Respectively represent the distilled data of the previous task and the distilled data of the current task, /> Represents the original training dataset of the current task, /> Represents the loss of the distilled data of the previous task and the original training data set of the current task under the parameters of the distilled model;

4)在下一个任务训练过程中回放联合蒸馏数据:随着下一个任务数据流的到来,为了不遗忘旧任务的知识,从缓冲区取出步骤3)中的所述旧任务的联合蒸馏数据,将联合蒸馏数据和新任务的原始训练数据一起输入到分类网络中,训练的过程可以表示为:4) Replay the joint distillation data during the next task training process: With the arrival of the next task data stream, in order not to forget the knowledge of the old task, the joint distillation data of the old task in step 3) is taken out from the buffer, and the joint distillation data and the original training data of the new task are input into the classification network together. The training process can be expressed as:

其中,表示每一步s从训练数据集/>中选取的一个批量数据,/>表示分类网络的参数,/>是从缓冲区取出的旧任务的联合蒸馏数据集,包括分类网络已学习到的所有类别的蒸馏数据。in, Indicates that each step s is from the training data set/> A batch of data selected from / > Represents the parameters of the classification network,/> It is a joint distilled dataset of old tasks taken from the buffer, including the distilled data of all categories that the classification network has learned.

步骤2)中所述缓冲区存储过程为:在联合蒸馏过程结束后,将旧任务的联合蒸馏数据存储在所设置好的缓冲区中,并且为了使模型能更充分地学习到旧类别的特征信息,缓冲区规模大小设置为100,即联合蒸馏后每个旧类保存100张蒸馏图像,减少对旧类的遗忘,缓冲区中存储的蒸馏数据可以表示为:The buffer storage process in step 2) is as follows: after the joint distillation process is completed, the joint distillation data of the old task is stored in the set buffer, and in order to enable the model to learn the feature information of the old category more fully, the buffer size is set to 100, that is, after the joint distillation, 100 distilled images are saved for each old category to reduce the forgetting of the old category. The distillation data stored in the buffer can be expressed as:

其中,表示对初始任务进行简单的数据集蒸馏后得到的蒸馏数据,表示对旧任务的蒸馏数据与新任务的原始数据联合蒸馏后得到的联合蒸馏数据。in, Represents the distilled data obtained after a simple data set distillation of the initial task. Represents the joint distilled data obtained by jointly distilling the distilled data of the old task and the original data of the new task.

参照图2,步骤3)中所述联合蒸馏过程具体为:初始化当前任务的蒸馏数和蒸馏模型参数,在优化器每次迭代时选取一批数据/>接着用蒸馏数据/>对联合蒸馏模型参数进行K=1,2,…,k次梯度下降更新,然后计算批数据xq在蒸馏模型参数下的损失并更新蒸馏数据:Referring to FIG. 2 , the joint distillation process in step 3) is specifically as follows: Initialize the distillation number of the current task and distillation model parameters, selecting a batch of data at each iteration of the optimizer/> Then use the distillation data/> Perform K=1, 2, …, k gradient descent updates on the joint distillation model parameters, then calculate the loss of the batch data xq under the distillation model parameters and update the distillation data:

其中,p(θ0)为初始化权重的分布,批数据xq的批次大小为n,优化器迭代Q次。Where p(θ 0 ) is the distribution of the initialization weights, the batch size of the batch data x q is n, and the optimizer iterates Q times.

步骤4)中所述联合蒸馏回放过程为:当分类网络学习新任务时,回放旧任务的联合蒸馏数据,即当前输入到分类网络中的训练数据集为则基于联合蒸馏回放策略的类增量学习过程可以表示为:The joint distillation playback process in step 4) is as follows: when the classification network learns a new task, the joint distillation data of the old task is played back, that is, the training data set currently input to the classification network is Then the class incremental learning process based on the joint distillation replay strategy can be expressed as:

因此,在所有任务学习结束后,分类网络参数更新为缓冲区更新为/> Therefore, after all tasks are learned, the classification network parameters are updated to The buffer is updated as />

Claims (4)

1.一种基于联合蒸馏回放策略的类增量学习方法,其特征在于,包括如下步骤:1. A class incremental learning method based on a joint distillation replay strategy, characterized by comprising the following steps: 1)训练初始任务的蒸馏模型:给定初始任务的训练数据集初始任务的蒸馏模型参数θ,l(xi,θ)为数据xi在蒸馏模型上的损失函数,蒸馏模型优化的目标是通过训练获得一个θ*使得模型在整个数据集上的损失最小,定义如下:1) Train the distillation model for the initial task: Given the training dataset for the initial task The distillation model parameter θ of the initial task, l( xi , θ) is the loss function of the data xi on the distillation model. The goal of distillation model optimization is to obtain a θ * through training so that the loss of the model on the entire data set is minimized, which is defined as follows: 假设蒸馏模型随机初始化参数为θ0,θ0满足p(θ0)的分布,使用标准的随机梯度下降来优化更新蒸馏数据和和学习率/>从而更新初始任务的蒸馏模型参数,假设现在进行第k次蒸馏模型参数更新,则:Assume that the distillation model is randomly initialized with parameters θ 0 , θ 0 satisfies the distribution of p(θ 0 ), and use standard stochastic gradient descent to optimize and update the distillation data. and learning rate/> Thereby updating the distillation model parameters of the initial task. Assuming that the k-th distillation model parameter update is now performed, then: 此时,初始任务的蒸馏模型的优化目标变为最小化损失函数从而得到初始任务的蒸馏数据:At this point, the optimization objective of the distillation model of the initial task becomes to minimize the loss function Thus, the distilled data of the initial task is obtained: 2)设置一个缓冲区存储初始任务的蒸馏数据:初始化缓冲区将步骤1)中的所述初始任务的蒸馏数据/>存储在缓冲区中,则:2) Set up a buffer to store the distillation data of the initial task: Initialize the buffer The distillation data of the initial task in step 1) Stored in a buffer, then: 3)建立基于旧任务的蒸馏数据与新任务的原始数据的联合蒸馏模型:假设在类增量学习场景中,神经网络需要持续学习T个任务,除了对初始任务进行简单的数据集蒸馏之外,对剩余的每个任务t∈{2,...,T}都进行联合蒸馏,并将获得的联合蒸馏数据存储在步骤2)中的所述缓冲区中,为了使模型在类增量学习场景中能持续学习,联合蒸馏过程中需要使用固定的学习率η,过程如下公式所示:3) Establish a joint distillation model based on the distilled data of the old task and the original data of the new task: Assuming that in a class incremental learning scenario, the neural network needs to continuously learn T tasks. In addition to a simple data set distillation for the initial task, each of the remaining tasks t∈{2,...,T} is jointly distilled, and the obtained joint distillation data is stored in the buffer in step 2). In order to enable the model to continue learning in a class incremental learning scenario, a fixed learning rate η needs to be used in the joint distillation process. The process is shown in the following formula: 其中,和/>分别表示前一任务的蒸馏数据和当前任务的蒸馏数据,/>表示当前任务的原始训练数据集,/>表示前一任务的蒸馏数据和当前任务的原始训练数据集在蒸馏模型参数下的损失;in, and/> Respectively represent the distilled data of the previous task and the distilled data of the current task, /> Represents the original training dataset of the current task, /> Represents the loss of the distilled data of the previous task and the original training data set of the current task under the parameters of the distilled model; 4)在下一个任务训练过程中回放联合蒸馏数据:随着下一个任务数据流的到来,为了不遗忘旧任务的知识,从缓冲区取出步骤3)中的所述旧任务的联合蒸馏数据,将联合蒸馏数据和新任务的原始训练数据一起输入到分类网络中,训练的过程可以表示为:4) Replay the joint distillation data during the next task training process: With the arrival of the next task data stream, in order not to forget the knowledge of the old task, the joint distillation data of the old task in step 3) is taken out from the buffer, and the joint distillation data and the original training data of the new task are input into the classification network together. The training process can be expressed as: 其中,表示每一步s从训练数据集/>中选取的一个批量数据,/>表示分类网络的参数,/>是从缓冲区取出的旧任务的联合蒸馏数据集,包括分类网络已学习到的所有类别的蒸馏数据。in, Indicates that each step s is from the training data set/> A batch of data selected from / > Represents the parameters of the classification network,/> It is a joint distilled dataset of old tasks taken from the buffer, including the distilled data of all categories that the classification network has learned. 2.根据权利要求1所述的基于联合蒸馏回放策略的类增量学习方法,其特征在于,步骤2)中所述缓冲区存储过程为:在联合蒸馏过程结束后,将旧任务的联合蒸馏数据存储在所设置好的缓冲区中,并且为了使模型能更充分地学习到旧类别的特征信息,缓冲区规模大小设置为100,即联合蒸馏后每个旧类保存100张蒸馏图像,减少对旧类的遗忘,缓冲区中存储的蒸馏数据可以表示为:2. The class incremental learning method based on the joint distillation playback strategy according to claim 1 is characterized in that the buffer storage process in step 2) is: after the joint distillation process is completed, the joint distillation data of the old task is stored in the set buffer, and in order to enable the model to learn the feature information of the old category more fully, the buffer size is set to 100, that is, after the joint distillation, 100 distilled images are saved for each old class to reduce the forgetting of the old class. The distillation data stored in the buffer can be expressed as: 其中,表示对初始任务进行简单的数据集蒸馏后得到的蒸馏数据,/>表示对旧任务的蒸馏数据与新任务的原始数据联合蒸馏后得到的联合蒸馏数据。in, Represents the distilled data obtained after a simple data set distillation of the initial task, /> Represents the joint distilled data obtained by jointly distilling the distilled data of the old task and the original data of the new task. 3.根据权利要求1所述的基于联合蒸馏回放策略的类增量学习方法,其特征在于,步骤3)中所述联合蒸馏过程具体为:初始化当前任务的蒸馏数和蒸馏模型参数,在优化器每次迭代时选取一批数据/>接着用蒸馏数据/>对联合蒸馏模型参数进行K=1,2,…,k次梯度下降更新,然后计算批数据xq在蒸馏模型参数下的损失并更新蒸馏数据:3. The incremental learning method based on the joint distillation playback strategy according to claim 1 is characterized in that the joint distillation process in step 3) is specifically: initializing the distillation data of the current task and distillation model parameters, selecting a batch of data at each iteration of the optimizer/> Then use the distillation data/> Perform K=1, 2, …, k gradient descent updates on the joint distillation model parameters, then calculate the loss of the batch data xq under the distillation model parameters and update the distillation data: 其中,p(θ0)为初始化权重的分布,批数据xq的批次大小为n,优化器迭代Q次。Where p(θ 0 ) is the distribution of the initialization weights, the batch size of the batch data x q is n, and the optimizer iterates Q times. 4.根据权利要求1所述的基于联合蒸馏回放策略的类增量学习方法,其特征在于,步骤4)中所述联合蒸馏回放过程为:当分类网络学习新任务时,回放旧任务的联合蒸馏数据,即当前输入到分类网络中的训练数据集为则基于联合蒸馏回放策略的类增量学习过程可以表示为:4. The incremental learning method based on the joint distillation replay strategy according to claim 1 is characterized in that the joint distillation replay process in step 4) is: when the classification network learns a new task, the joint distillation data of the old task is replayed, that is, the training data set currently input to the classification network is Then the class incremental learning process based on the joint distillation replay strategy can be expressed as: 因此,在所有任务学习结束后,分类网络参数更新为缓冲区更新为/> Therefore, after all tasks are learned, the classification network parameters are updated to The buffer is updated as />
CN202410269441.5A 2024-03-11 2024-03-11 Class increment learning method based on joint distillation playback strategy Pending CN118072099A (en)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202410269441.5A CN118072099A (en) 2024-03-11 2024-03-11 Class increment learning method based on joint distillation playback strategy

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202410269441.5A CN118072099A (en) 2024-03-11 2024-03-11 Class increment learning method based on joint distillation playback strategy

Publications (1)

Publication Number Publication Date
CN118072099A true CN118072099A (en) 2024-05-24

Family

ID=91101609

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202410269441.5A Pending CN118072099A (en) 2024-03-11 2024-03-11 Class increment learning method based on joint distillation playback strategy

Country Status (1)

Country Link
CN (1) CN118072099A (en)

Cited By (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN118586475A (en) * 2024-08-05 2024-09-03 浙江浙能电力股份有限公司萧山发电厂 A federated category incremental learning modeling method based on stable feature prototypes

Cited By (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN118586475A (en) * 2024-08-05 2024-09-03 浙江浙能电力股份有限公司萧山发电厂 A federated category incremental learning modeling method based on stable feature prototypes

Similar Documents

Publication Publication Date Title
CN109948029B (en) Neural network self-adaptive depth Hash image searching method
Prabhu et al. Online continual learning without the storage constraint
CN110704636B (en) An Improved Vector Representation Method of Knowledge Graph Based on Node2vec
CN111898728A (en) Team robot decision-making method based on multi-Agent reinforcement learning
CN117523295B (en) Passive domain adaptive image classification method based on class guide element learning
CN118072099A (en) Class increment learning method based on joint distillation playback strategy
CN112131403B (en) Knowledge graph representation learning method in dynamic environment
Fu et al. Me-d2n: Multi-expert domain decompositional network for cross-domain few-shot learning
CN111563590A (en) An Active Learning Method Based on Generative Adversarial Models
CN113112397A (en) Image style migration method based on style and content decoupling
CN112434552B (en) Neural network model adjustment method, device, equipment and storage medium
KR102579686B1 (en) Method for transforming an image step by step taking into account angle changes
CN113626610A (en) Knowledge graph embedding method and device, computer equipment and storage medium
CN116091867A (en) A model training, image recognition method, device, equipment and storage medium
CN116401377B (en) A temporal knowledge graph reasoning method based on diffusion probability distribution
CN117494790A (en) Incremental learning method based on multi-level knowledge distillation
CN117058436A (en) Class increment image classification method based on dual attention vision transducer network
WO2023240583A1 (en) Cross-media corresponding knowledge generating method and apparatus
CN115496174A (en) Method for optimizing network representation learning, model training method and system
CN116051591B (en) Strip steel image threshold segmentation method based on improved cuckoo search algorithm
CN116563635B (en) Image classification system based on category attribute modeling
KR102746256B1 (en) Method for transferring style of visual data
KR102822628B1 (en) Method for obtaining environment containing virtual objects by using a neural network model
CN114781642B (en) A method and device for generating cross-media correspondence knowledge
CN117829254A (en) A reinforcement learning method and device for transferring offline strategy to online learning

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination