Model training and recognition method and device, electronic equipment and storage medium

Document No.: 1043298  Publication date: 2020-10-09

Reading note: the technology "Model training and recognition method and device, electronic equipment and storage medium" was created by Qin Haining, Shan Pei and Zhang Ruifei on 2020-06-29. Its main content is as follows. The application provides a model training and recognition method and device, an electronic device and a storage medium, which are used to mitigate gradient explosion or gradient vanishing when a generative adversarial network (GAN) is trained. The method comprises: acquiring text data and the text category corresponding to the text data; training a GAN with the text data as training data and the text category as the training label to obtain a GAN model, the GAN model comprising a generator and a discriminator; wherein training the GAN comprises: obtaining a state inference matrix, the state inference matrix representing the importance of the generator; performing a Kalman filtering operation on a state observation matrix and the state inference matrix to obtain a loss value of the GAN model, the state observation matrix representing the importance of the discriminator; and adjusting the loss value of the GAN model according to the accuracy of the GAN model.

1. A method of model training, comprising:

acquiring text data and a text category corresponding to the text data;

training a generative adversarial network (GAN) by taking the text data as training data and the text category as a training label to obtain a GAN model, wherein the GAN model comprises a generator and a discriminator;

wherein training the GAN comprises:

obtaining a state inference matrix, wherein the state inference matrix represents the degree of importance of the discriminator;

performing a Kalman filtering operation on a state observation matrix and the state inference matrix to obtain a loss value of the GAN model, wherein the state observation matrix represents the degree of importance of the generator;

and adjusting the loss value of the GAN model according to the accuracy of the GAN model.

2. The method of claim 1, wherein adjusting the loss value of the GAN model according to the accuracy of the GAN model comprises:

judging whether the accuracy of the GAN model gradually converges;

if so, resetting the loss value of the GAN model to a first direction;

if not, resetting the loss value of the GAN model to a second direction, wherein the first direction is opposite to the second direction.

3. The method of claim 1, wherein obtaining the state inference matrix comprises:

obtaining the accuracy of the GAN model;

and calculating the state inference matrix according to the accuracy of the GAN model.

4. The method of claim 3, wherein obtaining the accuracy of the GAN model comprises:

predicting the text data with the GAN model to obtain a prediction label;

and calculating the accuracy of the GAN model according to the prediction label and the training label.

5. The method according to any one of claims 1-4, wherein the GAN model is a WGAN-GP model.

6. An identification method, comprising:

obtaining text content;

using the GAN model obtained by the method according to any one of claims 1 to 5 to identify the category of the text content, thereby obtaining the category corresponding to the text content.

7. A model training apparatus, comprising:

a data category obtaining module, configured to obtain text data and a text category corresponding to the text data;

a network model training module, configured to train a generative adversarial network (GAN) by using the text data as training data and the text category as a training label, to obtain a GAN model, where the GAN model includes a generator and a discriminator;

wherein the network model training module comprises:

an inference matrix obtaining module, configured to obtain a state inference matrix, where the state inference matrix represents the degree of importance of the discriminator;

a Kalman filtering module, configured to perform a Kalman filtering operation on a state observation matrix and the state inference matrix to obtain a loss value of the GAN model, where the state observation matrix represents the degree of importance of the generator;

and a loss value adjusting module, configured to adjust the loss value of the GAN model according to the accuracy of the GAN model.

8. The apparatus of claim 7, wherein the loss value adjusting module comprises:

a gradual convergence judging module, configured to judge whether the accuracy of the GAN model gradually converges;

a first direction resetting module, configured to reset the loss value of the GAN model to a first direction if the accuracy of the GAN model gradually converges;

and a second direction resetting module, configured to reset the loss value of the GAN model to a second direction if the accuracy of the GAN model does not gradually converge, wherein the first direction is opposite to the second direction.

9. An electronic device, comprising: a processor and a memory, the memory storing machine-readable instructions executable by the processor, the machine-readable instructions, when executed by the processor, performing the method of any of claims 1-6.

10. A storage medium, characterized in that the storage medium has stored thereon a computer program which, when being executed by a processor, carries out the method according to any one of claims 1-6.

Technical Field

The application relates to the technical field of artificial intelligence and machine learning, in particular to a model training and recognition method and device, an electronic device and a storage medium.

Background

Model training refers to training a target model on training data; depending on the nature of the training data, the training mode may be supervised learning, unsupervised learning, and the like.

Supervised learning (also called supervised training) is a machine learning method that learns or builds a learning model or learning function from training data and uses it to make inferences about new examples.

Unsupervised learning, also known as unsupervised training, is a machine learning method that automatically classifies or groups input data without being given pre-labeled training examples; unsupervised learning mainly includes cluster analysis, association rules, dimensionality reduction, adversarial learning, and the like.

At present, to improve the stability of GAN training, the loss value between the prediction label and the training label is generally clipped directly to a constant interval. This greatly limits the expressive ability of the GAN model and makes it difficult for the GAN to fit complex functions; in practice, gradient explosion or gradient vanishing has been observed when GANs are trained this way.

Disclosure of Invention

An object of the embodiments of the present application is to provide a model training and recognition method and device, an electronic device, and a storage medium, which are used to solve the problem of gradient explosion or gradient vanishing when a generative adversarial network (GAN) is trained.

An embodiment of the present application provides a model training method, comprising: acquiring text data and the text category corresponding to the text data; training a GAN with the text data as training data and the text category as the training label to obtain a GAN model, the GAN model comprising a generator and a discriminator; wherein training the GAN comprises: obtaining a state inference matrix, which represents the importance of the generator; performing a Kalman filtering operation on a state observation matrix and the state inference matrix to obtain a loss value of the GAN model, the state observation matrix representing the importance of the discriminator; and adjusting the loss value of the GAN model according to the accuracy of the GAN model.

In this implementation, when the GAN model is trained, a Kalman filtering operation is performed on the state observation matrix and the obtained state inference matrix to obtain the loss value of the GAN model, and that loss value is adjusted according to the accuracy of the GAN model. In other words, during GAN training the loss value is adjusted dynamically according to the model's accuracy, so that the threshold range for clipping the loss value is found as early as possible; this makes the GAN model converge faster and more stably and effectively mitigates gradient explosion or gradient vanishing during training.

Optionally, in this embodiment of the present application, adjusting the loss value of the GAN model according to the accuracy of the GAN model includes: judging whether the accuracy of the GAN model gradually converges; if so, resetting the loss value of the GAN model to a first direction; if not, resetting the loss value of the GAN model to a second direction, the first direction being opposite to the second direction.

In this implementation, if the accuracy of the GAN model gradually converges, the loss value of the GAN model is reset to the first direction; otherwise it is reset to the second, opposite direction. That is, the loss value is adjusted dynamically according to the accuracy of the GAN model, so that the threshold range for clipping the loss value is found as early as possible and the GAN model converges faster and more stably.
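This section does not say how "gradually converges" is judged or what the two directions are numerically. The following is a minimal sketch under two stated assumptions: accuracy is treated as converging when the last `window` per-batch accuracies lie within a band of width `tol`, and the first and second directions are represented as +1 and -1. The helper names `accuracy_converging` and `reset_direction` and their default values are hypothetical, not taken from the application.

```python
from typing import Sequence


def accuracy_converging(history: Sequence[float], window: int = 5, tol: float = 0.01) -> bool:
    """Hypothetical convergence test for the per-batch accuracy history.

    Assumption: accuracy 'gradually converges' when the last `window`
    accuracies all lie within a band of width `tol`.
    """
    if len(history) < window:
        return False  # too few batches to judge convergence yet
    recent = history[-window:]
    return max(recent) - min(recent) <= tol


def reset_direction(history: Sequence[float], window: int = 5, tol: float = 0.01) -> int:
    """Return +1 for the 'first direction' and -1 for the opposite 'second direction'.

    The sign convention is an assumption; the application only states that
    the two directions are opposite.
    """
    return 1 if accuracy_converging(history, window, tol) else -1
```

For example, a history ending in 0.705, 0.71, 0.708, 0.709, 0.71 is judged converging (direction +1), whereas a steadily rising history such as 0.5, 0.6, 0.7, 0.8, 0.9 is not (direction -1). How the direction is then applied to the loss value is left open here.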

Optionally, in an embodiment of the present application, obtaining the state inference matrix includes: obtaining the accuracy of the GAN model; and calculating the state inference matrix according to the accuracy of the GAN model. Computing the state inference matrix from the accuracy of the GAN model in this way effectively speeds up obtaining the state inference matrix.

Optionally, in this embodiment of the present application, obtaining the accuracy of the GAN model includes: predicting the text data with the GAN model to obtain a prediction label; and calculating the accuracy of the GAN model according to the prediction label and the training label. Computing the accuracy directly from the prediction labels and the training labels in this way effectively speeds up obtaining the accuracy of the GAN model.

An embodiment of the present application further provides an identification method, comprising: obtaining text content; and identifying the category of the text content with the trained GAN model to obtain the category corresponding to the text content. In this implementation, using the trained GAN model to classify the obtained text content effectively speeds up obtaining the category corresponding to the text content.

An embodiment of the present application further provides a model training apparatus, including: a data category obtaining module, configured to obtain text data and the text category corresponding to the text data; and a network model training module, configured to train a GAN with the text data as training data and the text category as the training label to obtain a GAN model, the GAN model comprising a generator and a discriminator; wherein the network model training module includes: an inference matrix obtaining module, configured to obtain a state inference matrix that represents the importance of the generator; a Kalman filtering module, configured to perform a Kalman filtering operation on a state observation matrix and the state inference matrix to obtain a loss value of the GAN model, the state observation matrix representing the importance of the discriminator; and a loss value adjusting module, configured to adjust the loss value of the GAN model according to the accuracy of the GAN model.

Optionally, in an embodiment of the present application, the loss value adjusting module includes: a gradual convergence judging module, configured to judge whether the accuracy of the GAN model gradually converges; a first direction resetting module, configured to reset the loss value of the GAN model to a first direction if the accuracy of the GAN model gradually converges; and a second direction resetting module, configured to reset the loss value of the GAN model to a second direction if the accuracy does not gradually converge, the first direction being opposite to the second direction.

Optionally, in an embodiment of the present application, the inference matrix obtaining module includes: an accuracy obtaining module, configured to obtain the accuracy of the GAN model; and an inference matrix calculating module, configured to calculate the state inference matrix according to the accuracy of the GAN model.

Optionally, in an embodiment of the present application, the accuracy obtaining module includes: a prediction label obtaining module, configured to predict the text data with the GAN model to obtain a prediction label; and an accuracy calculating module, configured to calculate the accuracy of the GAN model according to the prediction label and the training label.

An embodiment of the present application further provides an identification apparatus, including: a text content obtaining module, configured to obtain text content; and an identification category obtaining module, configured to identify the category of the text content with the trained GAN model to obtain the category corresponding to the text content.

An embodiment of the present application further provides an electronic device, including a processor and a memory, the memory storing machine-readable instructions executable by the processor, which, when executed by the processor, perform the method described above.

An embodiment of the present application further provides a storage medium on which a computer program is stored, the computer program, when executed by a processor, performing the method described above.

Drawings

To illustrate the technical solutions of the embodiments of the present application more clearly, the drawings used in the embodiments are briefly described below. It should be understood that the following drawings show only some embodiments of the present application and should not be regarded as limiting its scope; those skilled in the art can derive other related drawings from them without inventive effort.

FIG. 1 is a schematic diagram of a model training method provided by an embodiment of the present application;

FIG. 2 is a schematic diagram of an identification method provided by an embodiment of the present application;

FIG. 3 is a schematic structural diagram of a model training apparatus provided in an embodiment of the present application;

FIG. 4 is a schematic structural diagram of an identification device provided in an embodiment of the present application;

FIG. 5 is a schematic structural diagram of an electronic device provided in an embodiment of the present application.

Detailed Description

The technical solution in the embodiments of the present application will be clearly and completely described below with reference to the drawings in the embodiments of the present application.

Before the model training and recognition method provided by the embodiments of the present application is described, some related concepts are introduced as follows:

Artificial intelligence (AI) is a new technical science that studies and develops theories, methods, techniques and application systems for simulating, extending and expanding human intelligence. It is a branch of computer science that attempts to understand the essence of intelligence and to produce new intelligent machines that can respond in a manner similar to human intelligence; research in this field includes robotics, language recognition, image recognition, natural language processing, expert systems, and the like.

Machine learning is the branch of artificial intelligence that studies learning behavior. Drawing on the scientific or theoretical viewpoints of cognitive science, biology, philosophy, statistics, information theory, control theory, computational complexity and the like, it explores the laws of human cognition and the learning process through basic methods such as induction, generalization, specialization and analogy, and builds algorithms that improve automatically through experience, giving computer systems the ability to learn specific knowledge and skills automatically. The main methods of machine learning include decision trees, Bayesian learning, instance-based learning, genetic algorithms, rule learning, explanation-based learning, and the like.

A generative adversarial network (GAN) is an unsupervised learning method in machine learning in which two neural networks learn by playing a game against each other.

Gradient vanishing means that, in a neural network, the earlier hidden layers learn more slowly than the later hidden layers, because the gradient shrinks as it is propagated backward through successive layers; as a result, classification accuracy can decrease as the number of hidden layers increases. This phenomenon is called gradient vanishing.

Gradient explosion means that, in a neural network, the gradient grows as it is propagated backward through successive layers, so that the earlier hidden layers receive excessively large updates and training becomes unstable. This phenomenon is called gradient explosion.

A loss function, also known as a cost function, maps an event (an element of a sample space) to a real number that expresses the economic or opportunity cost associated with that event, thereby intuitively quantifying the "cost" of the event. The loss function determines how the training process penalizes the difference between the network's predicted results and the true results, and different loss functions suit different types of tasks: for example, the softmax cross-entropy loss is often used to select one class out of several, the sigmoid cross-entropy loss is often used for several independent binary decisions, and the Euclidean loss is often used for problems whose output can take arbitrary real values.
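The three loss functions named above can be written in a few lines of NumPy. This is a sketch of the textbook formulas for a single sample, not code from the application; the function names are hypothetical, and the Euclidean loss uses the common half-squared-distance convention.

```python
import numpy as np


def softmax_cross_entropy(logits: np.ndarray, target: int) -> float:
    """Single-label loss: choose exactly one of K classes."""
    shifted = logits - logits.max()  # subtract the max for numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum())
    return float(-log_probs[target])


def sigmoid_cross_entropy(logits: np.ndarray, targets: np.ndarray) -> float:
    """Multi-label loss: K independent yes/no decisions, averaged."""
    # Numerically stable form of -[t*log(sigmoid(z)) + (1 - t)*log(1 - sigmoid(z))]
    losses = np.maximum(logits, 0) - logits * targets + np.log1p(np.exp(-np.abs(logits)))
    return float(losses.mean())


def euclidean_loss(pred: np.ndarray, target: np.ndarray) -> float:
    """Regression loss for arbitrary real-valued outputs: half the squared L2 distance."""
    return float(0.5 * np.sum((pred - target) ** 2))
```

With uninformative logits of zero, both cross-entropy losses equal ln 2, which is a quick sanity check on the formulas.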

A server is a device that provides computing services over a network, for example an x86 server or a non-x86 server; non-x86 servers include mainframes, minicomputers and UNIX servers. In a specific implementation, the server may be a mainframe or a minicomputer. A minicomputer is a closed, dedicated device that provides computing services, mainly supports the UNIX operating system, and uses dedicated processors such as Reduced Instruction Set Computing (RISC) or MIPS processors; a mainframe, also known as a large host, is a device that provides computing services using a dedicated processor instruction set, operating system and application software.

Before the model training and recognition method provided by the embodiments of the present application is introduced, the cause of gradient explosion or gradient vanishing when a comparative embodiment trains a GAN is analyzed. In the comparative embodiment, the loss value between the prediction label and the training label is clipped directly to a constant interval, specifically [-0.01, 0.01]; this truncation can also be understood as weight clipping. Clipping the loss value to a constant interval greatly limits the expressive ability of the GAN model and makes it difficult for the GAN to fit complex functions, so gradient vanishing or gradient explosion easily occurs after the gradient passes through a multi-layer network. The reason is that the GAN discriminator is a multi-layer network: if the clipping threshold is set slightly too small, the gradient shrinks at every layer it passes through, and gradient vanishing (also called exponential decay) appears after several layers; conversely, if the clipping threshold is set slightly too large, the gradient grows at every layer, and gradient explosion (also called exponential explosion) appears after several layers. In short, gradient vanishing or gradient explosion occurs because the loss value of the GAN model is clipped to a static, constant interval.
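The layer-by-layer shrinking or growth described above can be reproduced with a toy model. The sketch below assumes a stack of purely linear layers (no activations) of width 64 whose weights are drawn from U(-1, 1) and then clipped elementwise to [-c, c]; it backpropagates a random gradient and reports how much its norm changes. It illustrates the scaling trend only, not the application's discriminator, and `backprop_gradient_ratio` is a hypothetical name.

```python
import numpy as np


def backprop_gradient_ratio(clip: float, depth: int = 10, width: int = 64, seed: int = 0) -> float:
    """Toy model: gradient norm after `depth` layers divided by the norm at the output.

    Each layer is a linear map whose weights are drawn from U(-1, 1) and then
    clipped elementwise to [-clip, clip], mimicking weight clipping.
    Activations are ignored, so only the scaling trend is meaningful.
    """
    rng = np.random.default_rng(seed)
    grad = rng.standard_normal(width)
    start_norm = np.linalg.norm(grad)
    for _ in range(depth):
        weights = np.clip(rng.uniform(-1.0, 1.0, size=(width, width)), -clip, clip)
        grad = weights.T @ grad  # chain rule through a linear layer: dL/dx = W^T dL/dy
    return float(np.linalg.norm(grad) / start_norm)


for c in (0.01, 0.5):
    print(f"clip={c}: gradient norm ratio after 10 layers = {backprop_gradient_ratio(c):.3g}")
```

For i.i.d. zero-mean weights with root-mean-square value s, each layer scales the gradient norm by roughly s·√width. With c = 0.01 that factor is about 0.01 × 8 ≈ 0.08, so the gradient vanishes; with c = 0.5 it is about 3, so the gradient explodes.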

In the GAN training method provided by the embodiments of the present application, the loss value of the GAN model is adjusted dynamically according to the accuracy of the GAN model, so that the threshold range for clipping the loss value is found as early as possible; this makes the GAN model converge faster and more stably and effectively mitigates gradient explosion or gradient vanishing during GAN training.

It should be noted that the model training and recognition method provided in the embodiments of the present application may be executed by an electronic device, which refers either to a device terminal capable of executing a computer program or to the server described above. Device terminals include, for example, a smartphone, a personal computer (PC), a tablet computer, a personal digital assistant (PDA), a mobile Internet device (MID), a network switch or a network router.

Before the model training and recognition method provided in the embodiments of the present application is described, its applicable scenarios are introduced. They include, but are not limited to, classifying text data in machine learning to obtain the category of the text data, such as the sentiment orientation, topic or main idea of the text.

Please refer to FIG. 1, which is a schematic diagram of the model training method provided in an embodiment of the present application; the model training method may include the following steps:

Step S100: obtain the text data and the text category corresponding to the text data.

The text data refers to language material stored in a corpus that actually occurs in real-world language use; a corpus is a basic resource that carries linguistic knowledge with an electronic computer as its carrier. The text data here includes, for example, text information such as web articles, textbooks or patent documents; such text is the most common form of unstructured data and contains a large amount of latent information.

The text category refers to the specific category of the text data, and the category differs depending on the classification method. For example, text data may be classified by sentiment, or by the subject or topic to which it belongs; topic classes may include law, politics, society, and the like.

In step S100, the text data and the corresponding text category may be obtained separately, for example by manually collecting text data and manually labeling the text category of each item. They may also be obtained together, for example as a training data packet in which the text data and the corresponding text categories are packed; such a packet may be obtained in the following ways: first, by retrieving a pre-stored training data packet from a file system or a database; second, by receiving the training data packet from another terminal device; third, by downloading the training data packet from the Internet with software such as a browser, or by accessing the Internet with another application program.
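As an illustration of the first mode (reading a pre-stored training data packet from a file system), the sketch below assumes the packet is a JSON Lines file with hypothetical `text` and `category` fields. The application does not define a packet format, so both the layout and the name `load_training_packet` are assumptions.

```python
import json
from pathlib import Path
from typing import List, Tuple


def load_training_packet(path: str) -> Tuple[List[str], List[str]]:
    """Read a hypothetical JSON Lines training packet.

    Each non-blank line is assumed to look like {"text": "...", "category": "..."};
    this layout is an illustration, not a format defined by the application.
    """
    texts, categories = [], []
    with Path(path).open(encoding="utf-8") as fh:
        for line in fh:
            if not line.strip():
                continue  # skip blank lines
            record = json.loads(line)
            texts.append(record["text"])
            categories.append(record["category"])
    return texts, categories
```

The two returned lists are aligned by index, so `texts[i]` is used as training data and `categories[i]` as its training label.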

Step S200: train the GAN with the text data as training data and the text category as the training label to obtain a GAN model.

It should be noted that, for ease of distinction, a trained neural network is referred to in the embodiments of the present application as a neural network model (for example, a GAN model), while an untrained neural network is referred to simply as a network (for example, a GAN). In fact, the neural network before training and the neural network model after training have the same network structure; that is, the GAN and the GAN model share the same structure. The network structure of the GAN model is described below.

The GAN model consists of a generator and a discriminator, each formed by a multi-layer network. The generator takes random samples from a latent space as input, and its output should imitate the real samples in the training set as closely as possible; the discriminator takes either a real sample or the output of the generator as input, and its goal is to tell the generator's output apart from real samples as well as possible. The generator tries to deceive the discriminator, while the discriminator tries to distinguish the generator's output from real samples; the two compete and continuously adjust their parameters, with the ultimate goal that the discriminator can no longer tell whether the generator's output is real.
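Claim 5 names WGAN-GP as the GAN model, and WGAN-GP (Gulrajani et al., 2017) replaces weight clipping with a gradient penalty on the critic (discriminator). This section does not spell out that loss, so the following is a minimal NumPy sketch of the standard WGAN-GP critic objective for a hypothetical one-hidden-layer critic with an analytic input gradient. The names `critic`, `critic_input_grad` and `wgan_gp_critic_loss` and the penalty weight λ = 10 are assumptions for illustration, and the Kalman-based loss adjustment of this application is not included.

```python
import numpy as np


def critic(x, W, b, v):
    """Hypothetical critic D(x) = v . tanh(W x + b), applied to each row of x."""
    return np.tanh(x @ W.T + b) @ v


def critic_input_grad(x, W, b, v):
    """Analytic gradient dD/dx for each row of x."""
    h = np.tanh(x @ W.T + b)          # hidden activations, shape (batch, hidden)
    return ((1.0 - h ** 2) * v) @ W   # chain rule, shape (batch, features)


def wgan_gp_critic_loss(real, fake, W, b, v, lam=10.0, seed=0):
    """E[D(fake)] - E[D(real)] + lam * E[(||grad D(x_hat)|| - 1)^2]."""
    rng = np.random.default_rng(seed)
    eps = rng.uniform(size=(real.shape[0], 1))   # one mixing weight per sample
    x_hat = eps * real + (1.0 - eps) * fake      # random points between real and fake
    grad_norm = np.linalg.norm(critic_input_grad(x_hat, W, b, v), axis=1)
    penalty = np.mean((grad_norm - 1.0) ** 2)
    return float(np.mean(critic(fake, W, b, v)) - np.mean(critic(real, W, b, v)) + lam * penalty)
```

In training, this loss is minimized with respect to the critic parameters, while the generator minimizes -E[D(fake)]; the penalty pushes the critic's gradient norm toward 1 instead of clipping its weights to a fixed interval.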

When the generation of the countermeasure network in step S200 is trained, the training data and the training labels may be divided into multiple batches for training, where the training labels include the text categories, the number of the training data and the training labels in each batch may also be adjusted according to specific situations, and the implementation of training the countermeasure network using the training data and the training labels in each batch may include the following steps:

Step S210: obtain a state inference matrix.

The state inference matrix is a matrix characterizing the importance of the generator and is denoted P in the formulas, in several variants: when the training data and training labels are divided into multiple batches, $P_k$ denotes the posterior state inference matrix at the k-th training, and $P_k^-$ denotes the prior state inference matrix at the k-th training. The difference between prior and posterior is that the prior state inference matrix is calculated without knowing the accuracy of the GAN model on the current batch (for example, the k-th batch), whereas the posterior state inference matrix is calculated once the accuracy of the GAN model on that batch is known. The specific calculation is described in detail below.

Obtaining the state inference matrix in step S210 may include:

Step S211: obtain the accuracy of the GAN model.

The accuracy of the GAN model is obtained during training by inputting the training data into the GAN to obtain predicted labels; the accuracy is the probability that a predicted label equals the corresponding training label.

Obtaining the accuracy of the GAN model in step S211 may, for example, comprise: predicting the text data with the GAN model to obtain predicted labels, and calculating the accuracy of the GAN model from the predicted labels and the training labels. Specifically, suppose the training data and training labels are divided into multiple batches with 10 text data and 10 text labels (the text labels are also called training labels) in each batch. The 10 text data are input into the GAN as training data, and the GAN outputs 10 predicted labels. If all 10 predicted labels equal the corresponding text labels, the accuracy of the GAN model is 100%; if only 5 of the 10 predicted labels equal the corresponding text labels, the accuracy is 50%. In this implementation, predicting the text data with the GAN model and computing the accuracy directly from the predicted labels and the training labels effectively speeds up obtaining the accuracy of the GAN model.
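The per-batch accuracy in the 10-label example above can be sketched as a simple match count. The function name `batch_accuracy` is hypothetical, and labels are assumed to be comparable category values (integers here).

```python
from typing import Sequence


def batch_accuracy(predicted: Sequence[int], labels: Sequence[int]) -> float:
    """Fraction of predicted labels that equal the training labels in one batch."""
    if len(predicted) != len(labels):
        raise ValueError("predicted and labels must have the same length")
    if not labels:
        raise ValueError("empty batch")
    correct = sum(1 for p, t in zip(predicted, labels) if p == t)
    return correct / len(labels)


labels = [0, 1, 2, 1, 0, 2, 1, 0, 2, 1]
print(batch_accuracy(labels, labels))  # 1.0  (all 10 match -> 100%)
# Keep the first 5 predictions correct and shift the last 5 to a wrong class.
half = labels[:5] + [(t + 1) % 3 for t in labels[5:]]
print(batch_accuracy(half, labels))    # 0.5  (5 of 10 match -> 50%)
```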

Step S212: calculate the state inference matrix according to the accuracy of the GAN model.

Calculating the state inference matrix according to the accuracy of the GAN model in step S212 may proceed as follows. In the initial state, that is, when the GAN model is trained on the first of the multiple batches into which the training data and training labels are divided, the state inference matrix may be determined directly from the accuracy of the GAN model; for example, every entry of the state inference matrix is set to that accuracy. When the GAN model is trained on the subsequent batches (the second batch, the third batch, and so on, until the training data and training labels of all batches are used up), the state inference matrix is obtained by a Kalman filtering operation that uses the accuracy of the GAN model together with its accuracy convergence condition. Because the Kalman filtering operation is iterative, its specific process is described in detail below.
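The first-batch initialization above (every entry set to the accuracy) can be sketched as below. The square shape and the `size` parameter are assumptions for illustration; the application does not state the dimensions of the state inference matrix.

```python
from typing import List


def initial_state_inference_matrix(accuracy: float, size: int) -> List[List[float]]:
    """First batch: fill a size x size matrix P with the GAN model's accuracy.

    The square shape and `size` are hypothetical; only the fill rule
    (each value equals the accuracy) comes from the description.
    """
    if not 0.0 <= accuracy <= 1.0:
        raise ValueError("accuracy must lie in [0, 1]")
    if size <= 0:
        raise ValueError("size must be positive")
    return [[accuracy] * size for _ in range(size)]


P0 = initial_state_inference_matrix(0.5, 2)
print(P0)  # [[0.5, 0.5], [0.5, 0.5]]
```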

Step S220: perform a Kalman filtering operation on the state observation matrix and the state inference matrix to obtain the loss value of the GAN model.

The state observation matrix is a matrix characterizing the importance of the discriminator and is denoted R in the formulas. It may be obtained, for example, according to the accuracy of the GAN model. Specifically, while training on the multiple batches of training data and training labels, the discriminator of the GAN model predicts each batch to obtain multiple predicted labels; these predicted labels are multiplied by the accuracy of the GAN model on the corresponding batch to obtain a one-dimensional label vector, which is then reshaped into the state observation matrix according to the matrix format of the state observation matrix.
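A minimal sketch of that construction, assuming the discriminator's predicted labels are numeric class indices and that the target "matrix format" is a row-major `rows x cols` shape; the function name and the shape parameters are hypothetical.

```python
from typing import List, Sequence


def observation_matrix(batch_predictions: Sequence[Sequence[float]],
                       batch_accuracies: Sequence[float],
                       rows: int, cols: int) -> List[List[float]]:
    """Scale each batch's predicted labels by that batch's accuracy,
    flatten into one vector, then reshape row-major into rows x cols."""
    if len(batch_predictions) != len(batch_accuracies):
        raise ValueError("one accuracy is required per batch")
    vector = [p * acc
              for preds, acc in zip(batch_predictions, batch_accuracies)
              for p in preds]
    if len(vector) != rows * cols:
        raise ValueError("label vector length does not match the target shape")
    return [vector[r * cols:(r + 1) * cols] for r in range(rows)]


# Two batches of two predicted labels each, with accuracies 0.5 and 0.25.
R = observation_matrix([[1, 0], [2, 1]], [0.5, 0.25], rows=2, cols=2)
print(R)  # [[0.5, 0.0], [0.5, 0.25]]
```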

Of course, in a specific implementation, the state observation matrix may also be determined according to the change in the accuracy of the GAN model. For example, the trend of the accuracy is obtained from the accuracy change value. If the change value is negative, that is, the accuracy has decreased, the previous state observation matrix is adjusted downward: it is reduced according to the ratio of the accuracy change value to the accuracy (for example, the change value divided by the accuracy gives the change ratio). If the change value is positive, that is, the accuracy has increased, the previous state observation matrix is adjusted upward according to the same ratio (in a manner similar to the decrease), yielding the current state observation matrix.
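One possible reading of this adjustment is sketched below. Two points are assumptions, not statements from the application: the change value is divided by the previous accuracy, and the matrix is scaled proportionally by `1 + ratio` (which equals `curr_acc / prev_acc` and therefore stays positive). The function name is hypothetical.

```python
from typing import List, Sequence


def update_observation_matrix(prev_R: Sequence[Sequence[float]],
                              prev_acc: float,
                              curr_acc: float) -> List[List[float]]:
    """Scale the previous observation matrix by the relative accuracy change."""
    if prev_acc <= 0:
        raise ValueError("previous accuracy must be positive")
    change = curr_acc - prev_acc  # negative: accuracy fell; positive: it rose
    ratio = change / prev_acc     # assumption: divide by the previous accuracy
    scale = 1.0 + ratio           # assumption: proportional up/down adjustment
    return [[value * scale for value in row] for row in prev_R]


R_prev = [[1.0, 0.0], [0.0, 1.0]]
print(update_observation_matrix(R_prev, prev_acc=0.5, curr_acc=0.4))  # scaled down by ~0.8
print(update_observation_matrix(R_prev, prev_acc=0.5, curr_acc=0.6))  # scaled up by ~1.2
```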

Kalman filtering (Kalman filter) is an efficient recursive filter that can estimate the state of a dynamic system from a series of incomplete and noisy measurements. In this embodiment, the measurement can be understood as the process of calculating the accuracy of the GAN model, the dynamic system as the process of training the GAN model, and estimating the state of the dynamic system as predicting the loss value of the GAN model. The whole training process aims to minimize the loss value, but how the loss value will change is not known in advance while each batch is being trained. Training also involves various interference noise, such as wrong training labels, an unreasonable method of obtaining the loss value, or unreasonable training hyperparameters, and this noise affects the loss value of the GAN model. Kalman filtering takes the measurements at different times, considers their joint distribution, and then produces an estimate of the unknown variable that is more accurate than an estimate based on a single measurement. That is, based on the accuracy calculated during the training of each batch, the method considers the distribution of prediction errors in the predicted accuracy of each batch together with the distribution of interference errors in the actually calculated accuracy, and dynamically predicts the loss value of the GAN model; this is more accurate than predicting the loss value from the actually calculated accuracy alone.

The loss value (Loss) of the GAN model is the value that determines how the difference between the network's predicted result and the true result is "penalized" during training; it can also be understood as the difference between the predicted label and the training label calculated by the loss function of the GAN model.

There are several ways to calculate the difference between the predicted label and the training label:

In the first way, the KL divergence (Kullback-Leibler divergence, KLD) is used to characterize the difference between the predicted label and the training label. The KL divergence is also known as relative entropy in information theory, as randomness in continuous time series, as information gain in statistical model inference, and as information divergence.

In the second way, the JS divergence (Jensen-Shannon divergence, JSD) is used to characterize the difference between the predicted label and the training label. The JS divergence measures the similarity of two probability distributions; as a symmetric variant of the KL divergence, it resolves the asymmetry of the KL divergence.

In a third way, the difference value between the predicted label and the training label is characterized by using the Wasserstein distance, which is a distance between two probability distributions.
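The three measures above can be sketched for discrete distributions, for example a predicted class-probability vector compared with a one-hot training label. Assumptions: natural logarithms; a small `eps` floor so that KL stays finite when q has zeros; and, for the Wasserstein distance, category indices treated as points on a line with unit spacing (the 1-D closed form as a sum of CDF differences). Function names are hypothetical.

```python
import math
from typing import Sequence


def kl_divergence(p: Sequence[float], q: Sequence[float], eps: float = 1e-12) -> float:
    """KL(p || q) = sum p_i * log(p_i / q_i); terms with p_i = 0 contribute 0."""
    return sum(pi * math.log(pi / max(qi, eps)) for pi, qi in zip(p, q) if pi > 0)


def js_divergence(p: Sequence[float], q: Sequence[float]) -> float:
    """Symmetric JS divergence via the mixture m = (p + q) / 2."""
    m = [(pi + qi) / 2 for pi, qi in zip(p, q)]
    return 0.5 * kl_divergence(p, m) + 0.5 * kl_divergence(q, m)


def wasserstein_1d(p: Sequence[float], q: Sequence[float]) -> float:
    """W1 between distributions on support 0..n-1 (unit spacing, an assumption)."""
    total, cdf_p, cdf_q = 0.0, 0.0, 0.0
    for pi, qi in zip(p, q):
        cdf_p += pi
        cdf_q += qi
        total += abs(cdf_p - cdf_q)
    return total


label = [1.0, 0.0, 0.0]  # one-hot training label
pred = [0.7, 0.2, 0.1]   # predicted class probabilities
print(kl_divergence(label, pred), kl_divergence(pred, label))  # asymmetric
print(js_divergence(label, pred), js_divergence(pred, label))  # symmetric
print(wasserstein_1d([1.0, 0.0, 0.0], [0.0, 0.0, 1.0]))        # 2.0
```

KL(label || pred) reduces here to the cross-entropy of the true class, while JS gives the same value in both directions.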

Performing the Kalman filtering operation on the state observation matrix and the state inference matrix in step S220 may include:

performing the Kalman filtering operation on the state observation matrix and the state inference matrix according to

$$\hat{x}_k^- = A\hat{x}_{k-1} + Bu_k$$
$$P_k^- = AP_{k-1}A^T + Q$$
$$K_k = P_k^-H^T\left(HP_k^-H^T + R\right)^{-1}$$
$$\hat{x}_k = \hat{x}_k^- + K_k\left(z_k - H\hat{x}_k^-\right)$$
$$P_k = \left(I - K_kH\right)P_k^-$$

where k denotes the k-th of the batches into which the training data and training labels are divided; $\hat{x}_k^-$ and $\hat{x}_{k-1}$ respectively represent the prior accuracy at the k-th training and the accuracy estimate from the (k-1)-th training; $\hat{x}_k$ represents the posterior accuracy at the k-th training, the difference between prior and posterior being as described above; A represents the degree of correlation between the accuracy of the (k-1)-th batch and the accuracy of the k-th batch in the absence of noise interference, and A may change during the training of each batch; B represents the degree of correlation between the control input parameters $u_k$ and the accuracy of the GAN model; $P_k^-$ represents the prior state inference matrix at the k-th training; $P_k$ and $P_{k-1}$ respectively represent the posterior state inference matrices at the k-th and (k-1)-th training; Q represents the interference noise covariance matrix; $K_k$ represents the Kalman filtering coefficient at the k-th training; H represents the degree of correlation between the loss value of the GAN model and the accuracy of the GAN model; R represents the state observation matrix; $z_k = Hx_k + v_k$ represents the loss value of the GAN model, where $v_k$ represents the interference noise in obtaining the loss value; and I represents the identity matrix.
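A scalar sketch of one predict/update cycle of these five equations follows, so every "matrix" is 1x1. The default values of A, B, u, H, Q and R, as well as the names `KalmanState` and `kalman_step`, are illustrative assumptions; the application does not give numeric values.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass
class KalmanState:
    x: float  # posterior accuracy estimate x_hat_{k-1}
    P: float  # posterior state inference value P_{k-1} (1x1 matrix)


def kalman_step(state: KalmanState, z: float, *, A: float = 1.0, B: float = 0.0,
                u: float = 0.0, H: float = 1.0, Q: float = 1e-4,
                R: float = 1e-2) -> Tuple[KalmanState, float]:
    """One scalar Kalman cycle; z is the measurement z_k (the loss value in the text)."""
    # Predict: prior accuracy x_hat_k^- and prior state inference value P_k^-.
    x_prior = A * state.x + B * u
    P_prior = A * state.P * A + Q
    # Kalman coefficient K_k balances P_k^- (generator) against R (discriminator).
    K = P_prior * H / (H * P_prior * H + R)
    # Update: move the prior toward the measurement by K times the residual.
    x_post = x_prior + K * (z - H * x_prior)
    P_post = (1.0 - K * H) * P_prior
    return KalmanState(x=x_post, P=P_post), K


state = KalmanState(x=0.5, P=1.0)
for z in (0.55, 0.62, 0.60, 0.66):
    state, K = kalman_step(state, z)
    print(f"estimate={state.x:.3f}  P={state.P:.5f}  K={K:.3f}")
```

In the matrix case the same steps use matrix products, transposes and an inverse (for example with NumPy); the scalar form keeps the sketch dependency-free.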

It will be understood that the Kalman filtering coefficient above, also called the Kalman coefficient, serves to balance the state inference matrix P against the state observation matrix R, that is, to decide whether to trust the generator more or the discriminator more. Specifically, in the formula $K_k = P_k^-H^T\left(HP_k^-H^T + R\right)^{-1}$, the closer the state observation matrix R is to 0, the greater the residual weight given by the Kalman coefficient K; conversely, the closer the prior state inference matrix $P_k^-$ at the k-th training is to 0, the smaller the residual weight given by K. The residual weight is the importance of the Kalman coefficient in the formula, that is, the weight that determines how far the prior accuracy of the GAN model at the k-th training is moved toward the posterior accuracy; the difference between the prior accuracy and the posterior accuracy can be understood as the residual. This residual weight of the GAN model is smaller if the prediction model is trusted more, and larger if the observation model is trusted more.
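The two limiting cases can be checked numerically with the scalar form of the gain formula. The function name `kalman_gain` and the sample values of P, H and R are illustrative assumptions.

```python
def kalman_gain(P_prior: float, H: float, R: float) -> float:
    """Scalar form of K_k = P_k^- H^T (H P_k^- H^T + R)^{-1}."""
    denom = H * P_prior * H + R
    if denom == 0:
        raise ZeroDivisionError("H * P_prior * H + R must be nonzero")
    return P_prior * H / denom


H = 2.0
for R in (1.0, 0.1, 1e-6):
    print("R =", R, "K =", kalman_gain(1.0, H, R))  # K grows toward 1/H = 0.5 as R -> 0
for P in (1.0, 0.1, 1e-6):
    print("P =", P, "K =", kalman_gain(P, H, 1.0))  # K shrinks toward 0 as P_k^- -> 0
```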

Step S230: adjust the loss value of the GAN model according to the accuracy of the GAN model.

Step S230 may include:

Step S231: determine whether the accuracy of the GAN model gradually converges.

Whether the accuracy of the GAN model gradually converges may be determined in step S231, for example, from the historical data of the accuracy of the GAN model, in either of the following ways:

In the first way, convergence is judged from the slope of the historical accuracy data. For example, if the slope of the historical accuracy data is smaller than a preset threshold, the accuracy of the GAN model is determined to be gradually converging; if the slope is greater than or equal to the preset threshold, the accuracy is determined not to be gradually converging. The preset threshold can be set according to the specific situation.

In the second way, convergence is judged from how the historical accuracy data changes within a preset period. For example, if the change rate of the historical accuracy data within the preset period is greater than a preset proportion, the accuracy of the GAN model is determined to be gradually converging; if the change rate within the preset period is smaller than or equal to the preset proportion, the accuracy is determined not to be gradually converging. The preset proportion can be set according to the specific situation.
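The first (slope-based) way can be sketched with a least-squares slope over the most recent accuracy values. Several details are assumptions rather than statements from the application: the slope is fitted over a hypothetical `window` of recent batches, its absolute value is compared with the threshold (so a steady decline is not mistaken for convergence), the default `threshold` is arbitrary, and fewer than two points count as "not yet converging".

```python
from typing import Sequence


def least_squares_slope(values: Sequence[float]) -> float:
    """Slope of the best-fit line through (i, values[i]) for i = 0..n-1."""
    n = len(values)
    if n < 2:
        raise ValueError("need at least two points")
    mean_x = (n - 1) / 2
    mean_y = sum(values) / n
    num = sum((i - mean_x) * (y - mean_y) for i, y in enumerate(values))
    den = sum((i - mean_x) ** 2 for i in range(n))
    return num / den


def is_converging(history: Sequence[float], threshold: float = 0.01,
                  window: int = 5) -> bool:
    """First way of step S231: slope of recent accuracy history below a threshold."""
    recent = list(history[-window:])
    if len(recent) < 2:
        return False  # assumption: too little history to call it converged
    return abs(least_squares_slope(recent)) < threshold


print(is_converging([0.80, 0.81, 0.80, 0.81, 0.80]))  # True: accuracy has flattened
print(is_converging([0.1, 0.2, 0.3, 0.4, 0.5]))       # False: still rising by ~0.1 per batch
```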

Step S232: if the accuracy of the GAN model gradually converges, reset the loss value of the GAN model in a first direction.

Step S232 may, for example, be implemented as follows: if the accuracy of the GAN model converges steadily, the loss value of the GAN model is increased, for example by multiplying it by 1.1 or 1.01; of course, in a specific implementation, the loss value may also be multiplied by another number greater than 1.

Step S233: if the accuracy of the GAN model does not gradually converge, reset the loss value of the GAN model in a second direction, the first direction being opposite to the second direction.

Step S233 may, for example, be implemented as follows: if the accuracy of the GAN model does not gradually converge, the loss value of the GAN model is reduced, for example by multiplying it by 0.99 or 0.999; of course, in a specific implementation, the loss value may also be multiplied by another number smaller than 1. In this implementation, the loss value of the GAN model is reset in the first direction if the accuracy gradually converges, and in the opposite, second direction if it does not. That is, the loss value of the GAN model is dynamically adjusted according to the accuracy of the GAN model, so that the threshold range of the loss value is found and truncated as quickly as possible and the GAN model converges faster and more stably.
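Steps S232 and S233 amount to multiplying the loss by a factor above or below 1. The sketch below uses the example factors 1.01 and 0.99 from the text as defaults; the function name `adjust_loss` and the validation rules are hypothetical.

```python
def adjust_loss(loss: float, converging: bool,
                up: float = 1.01, down: float = 0.99) -> float:
    """Step S232: scale up when accuracy converges; step S233: scale down otherwise."""
    if up <= 1.0:
        raise ValueError("the first-direction factor must be greater than 1")
    if not 0.0 < down < 1.0:
        raise ValueError("the second-direction factor must lie in (0, 1)")
    return loss * (up if converging else down)


print(adjust_loss(2.0, converging=True))   # about 2.02
print(adjust_loss(2.0, converging=False))  # about 1.98
```

In a framework with automatic differentiation, multiplying the loss by a constant c before back-propagation also multiplies every gradient by c, which is how this adjustment would reach the parameter updates.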

In this implementation, when the GAN model is trained, a Kalman filtering operation is performed on the state observation matrix and the obtained state inference matrix to obtain the loss value of the GAN model, and the loss value is adjusted according to the accuracy of the GAN model. In other words, during GAN training the loss value is dynamically adjusted according to the accuracy of the GAN model, so that the threshold range of the loss value is found and truncated as quickly as possible; this makes the GAN model converge faster and more stably and effectively mitigates gradient explosion or vanishing gradients during GAN training.

Referring to fig. 2, an embodiment of the present application further provides a recognition method. After the GAN model is trained, it may be used to recognize the category of text content; that is, after step S200, the method may further include the following steps:

Step S300: obtain text content.

Text content refers to information stored as text, for example articles, textbooks, patent documents, or other text on the network. Text is the most common form of unstructured data and contains a large amount of latent information.

The text content in step S300 may be obtained in any of the following ways: in the first way, pre-stored text content is acquired, for example from a file system or a database; in the second way, text content is received from another terminal device; in the third way, text content is obtained from the internet using software such as a browser, or by another application program that accesses the internet.

Step S400: recognize the category of the text content using the GAN model to obtain the category corresponding to the text content.

Step S400 may, for example, be implemented by recognizing the category of the text content with the trained GAN model to obtain the corresponding category. The GAN model here may specifically be a GAN model, a WGAN (Wasserstein GAN) model, or a WGAN-GP (Wasserstein GAN with gradient penalty) model. In this implementation, text content is obtained and its category is recognized with the trained GAN model, which effectively speeds up obtaining the category corresponding to the text content.

Please refer to fig. 3, which is a schematic structural diagram of a model training apparatus provided in an embodiment of the present application. The embodiment provides a model training apparatus 500, including:

A data category obtaining module 510, configured to obtain text data and the text category corresponding to the text data.

A network model training module 520, configured to train the GAN using the text data as training data and the text categories as training labels to obtain a GAN model, the GAN model including a generator and a discriminator.

The network model training module 520 includes:

An inference matrix obtaining module 521, configured to obtain a state inference matrix, the state inference matrix characterizing the importance of the generator.

A Kalman filtering module 522, configured to perform a Kalman filtering operation on the state observation matrix and the state inference matrix to obtain the loss value of the GAN model, the state observation matrix characterizing the importance of the discriminator.

A loss value adjusting module 523, configured to adjust the loss value of the GAN model according to the accuracy of the GAN model.

Optionally, in an embodiment of the present application, the loss value adjusting module includes:

A gradual convergence judging module, configured to judge whether the accuracy of the GAN model gradually converges.

A first direction resetting module, configured to reset the loss value of the GAN model in the first direction if the accuracy of the GAN model gradually converges.

A second direction resetting module, configured to reset the loss value of the GAN model in a second direction if the accuracy of the GAN model does not gradually converge, the first direction being opposite to the second direction.

Optionally, in an embodiment of the present application, the inference matrix obtaining module includes:

An accuracy obtaining module, configured to obtain the accuracy of the GAN model.

An inference matrix calculation module, configured to calculate the state inference matrix according to the accuracy of the GAN model.

Optionally, in an embodiment of the present application, the accuracy obtaining module includes:

A predicted label obtaining module, configured to predict the text data using the GAN model to obtain predicted labels.

An accuracy calculation module, configured to calculate the accuracy of the GAN model according to the predicted labels and the training labels.

Please refer to fig. 4, which is a schematic structural diagram of a recognition apparatus provided in an embodiment of the present application. The embodiment provides a recognition apparatus 600, including:

A text content obtaining module 610, configured to obtain text content.

A recognition category obtaining module 620, configured to recognize the category of the text content using the GAN model to obtain the category corresponding to the text content.

It should be understood that the apparatus corresponds to the above embodiments of the model training and recognition method and can execute the steps involved in those method embodiments; for its specific functions, refer to the description above, and detailed description is omitted here to avoid repetition. The apparatus includes at least one software function module that can be stored in memory in the form of software or firmware, or solidified in the operating system (OS) of the apparatus.

Please refer to fig. 5, which illustrates a schematic structural diagram of an electronic device according to an embodiment of the present application. An electronic device 700 provided in an embodiment of the present application includes: a processor 710 and a memory 720, the memory 720 storing machine readable instructions executable by the processor 710, the machine readable instructions when executed by the processor 710 performing the method as above.

The embodiment of the present application further provides a storage medium 730, where the storage medium 730 stores a computer program, and the computer program is executed by the processor 710 to perform the above model training and recognition method.

The storage medium 730 may be implemented by any type of volatile or nonvolatile storage device or combination thereof, such as a Static Random Access Memory (SRAM), an Electrically Erasable Programmable Read-Only Memory (EEPROM), an Erasable Programmable Read-Only Memory (EPROM), a Programmable Read-Only Memory (PROM), a Read-Only Memory (ROM), a magnetic Memory, a flash Memory, a magnetic disk, or an optical disk.

In the embodiments provided in the present application, it should be understood that the disclosed apparatus and method may be implemented in other ways. The apparatus embodiments described above are merely illustrative, and for example, the flowchart and block diagrams in the figures illustrate the architecture, functionality, and operation of possible implementations of apparatus, methods and computer program products according to various embodiments of the present application. In this regard, each block in the flowchart or block diagrams may represent a module, segment, or portion of code, which comprises one or more executable instructions for implementing the specified logical function(s). It should also be noted that, in some alternative implementations, the functions noted in the block may occur out of the order noted in the figures. For example, two blocks shown in succession may, in fact, be executed substantially concurrently, or the blocks may sometimes be executed in the reverse order, depending upon the functionality involved. It will also be noted that each block of the block diagrams and/or flowchart illustration, and combinations of blocks in the block diagrams and/or flowchart illustration, can be implemented by special purpose hardware-based systems which perform the specified functions or acts, or combinations of special purpose hardware and computer instructions.

In addition, functional modules in the embodiments of the present application may be integrated together to form an independent part, or each module may exist separately, or two or more modules may be integrated to form an independent part.

In this document, relational terms such as first and second, and the like may be used solely to distinguish one entity or action from another entity or action without necessarily requiring or implying any actual such relationship or order between such entities or actions.

The above description is only an alternative embodiment of the embodiments of the present application, but the scope of the embodiments of the present application is not limited thereto, and any person skilled in the art can easily conceive of changes or substitutions within the technical scope of the embodiments of the present application, and all the changes or substitutions should be covered by the scope of the embodiments of the present application.
