Disclosure of Invention
In order to overcome the defects of the prior art, the invention provides a 3D medical image segmentation method, system, electronic device and computer-readable storage medium based on paired attention. A 3D medical image segmentation model is designed around a Paired Attention Transformer (PAT) module, which reduces the spatial dimension and effectively learns channel and spatial information in the 3D feature map, thereby improving segmentation performance while reducing the model parameter count and increasing the model computation speed.
In a first aspect, the present invention provides a paired attention based 3D medical image segmentation method;
A paired attention based 3D medical image segmentation method comprising:
Acquiring a 3D medical image to be segmented;
inputting the 3D medical image to be segmented into a trained 3D medical image segmentation model for processing so as to obtain a segmentation result;
the 3D medical image segmentation model comprises an encoder and a decoder, wherein the encoder is connected with the decoder; the encoder comprises a first encoding module and a plurality of second encoding modules which are sequentially connected, the first encoding module comprising a patch embedding layer and a paired attention transformer module, and each second encoding module comprising a paired attention transformer module and a downsampling layer; the decoder comprises a plurality of decoding modules which are sequentially connected, each decoding module comprising a skip connection module, a paired attention transformer module and an upsampling module.
Further, the paired attention transformer module is composed of a normalization layer, a multi-layer perceptron and a multi-head paired attention module, wherein data input into the paired attention transformer module sequentially passes through the normalization layer, the multi-head paired attention module and the multi-layer perceptron.
Preferably, the multi-head paired attention module is used for capturing the channel dependencies of the input data through channel attention to obtain a channel attention output feature map, capturing the spatial dependencies of the input data through spatial attention to obtain a spatial attention output feature map, merging the channel attention output feature map and the spatial attention output feature map with the original 3D voxel features of the input data, and applying a 3D convolution to obtain a deep feature representation of the input data.
Preferably, the channel attention formula in the multi-head paired attention module is expressed as:

X_C = V_channel · Softmax( (Q_channel)^T · K_channel / √d )

wherein X_C represents the output obtained through channel attention, Q_channel is the channel query vector, K_channel is the channel key vector, V_channel is the channel value, and d is the size of each vector;
The spatial attention formula in the multi-head paired attention module is expressed as:

X_S = Softmax( Q_spatial · (K_spatial_proj)^T / √d ) · V_spatial_proj

wherein X_S is the output obtained through spatial attention, Q_spatial is the spatial query vector, K_spatial_proj is the projection of the spatial key vector, V_spatial_proj is the projection of the spatial value, and d is the size of each vector.
Further, the first encoding module is used for carrying out embedding processing and segmentation on the 3D medical image to be segmented, obtaining a 3D voxel feature map and adding position codes. The second encoding modules are used for carrying out paired attention transformation and downsampling operations on the 3D voxel feature map so as to reduce its dimensions stage by stage. The decoding modules are used for upsampling the dimension-reduced 3D voxel feature map, splicing it with 3D voxel feature maps of different dimensions, carrying out paired attention transformation so as to increase the dimensions of the spliced 3D voxel feature map stage by stage, and outputting the predicted final segmentation result through a convolution operation.
Further, the inputting the 3D medical image to be segmented into the trained 3D medical image segmentation model for processing includes:
Carrying out embedding processing and segmentation on the 3D medical image to be segmented, obtaining a 3D voxel feature map and adding a position code;
performing paired attention transformation and downsampling operations on the 3D voxel feature map so as to reduce its dimensions stage by stage;
and upsampling the dimension-reduced 3D voxel feature map, splicing it with 3D voxel feature maps of different dimensions, carrying out paired attention transformation to increase the dimensions of the spliced 3D voxel feature map stage by stage, and outputting the predicted final segmentation result through a convolution operation.
Further, the training mode for the 3D medical image segmentation model includes:
acquiring training data;
setting an AdamW optimizer and adaptively adjusting the learning rate;
And training the 3D medical image segmentation model according to training data, the learning rate and a preset loss function.
In a second aspect, the present invention provides a paired attention based 3D medical image segmentation system;
A paired attention based 3D medical image segmentation system comprising:
the acquisition module is used for acquiring the 3D medical image to be segmented;
The 3D medical image segmentation module is used for inputting the 3D medical image to be segmented into a trained 3D medical image segmentation model for processing so as to obtain a segmentation result;
the 3D medical image segmentation model comprises an encoder and a decoder, wherein the encoder is connected with the decoder; the encoder comprises a first encoding module and a plurality of second encoding modules which are sequentially connected, the first encoding module comprising a patch embedding layer and a paired attention transformer module, and each second encoding module comprising a paired attention transformer module and a downsampling layer; the decoder comprises a plurality of decoding modules which are sequentially connected, each decoding module comprising a skip connection module, a paired attention transformer module and an upsampling module.
In a third aspect, the present invention provides an electronic device;
An electronic device comprising a memory, a processor, and computer instructions stored on the memory and executable on the processor, wherein the computer instructions, when executed by the processor, perform the steps of the paired attention-based 3D medical image segmentation method described above.
In a fourth aspect, the present invention provides a computer-readable storage medium;
A computer readable storage medium storing computer instructions which, when executed by a processor, perform the steps of the paired attention based 3D medical image segmentation method described above.
Compared with the prior art, the invention has the beneficial effects that:
According to the technical scheme provided by the invention, the 3D medical image segmentation model PAT-Unet is designed based on the Paired Attention Transformer module, which effectively combines the dependency information among channels with the rich information in the spatial dimension, improving the segmentation effect while reducing the parameter count of the model and accelerating model inference.
Compared with existing methods, the technical scheme provided by the invention can capture fine texture details in the image; while producing more accurate segmentation maps, the model's parameter count and number of operations are both reduced by more than 67%.
Detailed Description
It should be noted that the following detailed description is exemplary and is intended to provide further explanation of the invention. Unless defined otherwise, all technical and scientific terms used herein have the same meaning as commonly understood by one of ordinary skill in the art to which this invention belongs.
It is noted that the terminology used herein is for the purpose of describing particular embodiments only and is not intended to limit exemplary embodiments according to the present invention. As used herein, unless the context clearly indicates otherwise, the singular forms are intended to include the plural forms as well. Furthermore, it is to be understood that the terms "comprises" and "comprising" and any variations thereof are intended to cover non-exclusive inclusions; for example, processes, methods, systems, products or devices that comprise a series of steps or units are not necessarily limited to those steps or units expressly listed, but may include other steps or units not expressly listed or inherent to such processes, methods, products or devices.
Embodiments of the invention and features of the embodiments may be combined with each other without conflict.
Example 1
Prior-art 3D medical image segmentation methods improve segmentation precision, but do so at the cost of increased model parameter and computation counts, resulting in poor model robustness and high demand for computing resources, which affects both the efficiency and the precision of 3D medical image segmentation. Accordingly, the present invention provides a 3D medical image segmentation method based on paired attention.
Next, a detailed description will be given of the paired attention-based 3D medical image segmentation method disclosed in this embodiment with reference to fig. 1 to 5. The 3D medical image segmentation method based on the paired attention comprises the following steps of:
S1, acquiring a 3D medical image to be segmented.
S2, inputting the 3D medical image to be segmented into a trained 3D medical image segmentation model for processing so as to obtain a segmentation result. The 3D medical image segmentation model comprises an encoder and a decoder, the encoder being connected with the decoder. The encoder comprises a first encoding module and 3 second encoding modules which are sequentially connected: the first encoding module comprises a patch embedding layer and a Paired Attention Transformer (PAT) module, and each second encoding module comprises a PAT module and a downsampling layer. The decoder comprises 4 decoding modules which are sequentially connected, each decoding module comprising a skip connection module, a PAT module and an upsampling module.
The specific flow of inputting the 3D medical image to be segmented into the trained 3D medical image segmentation model for processing is as follows:
inputting the 3D medical image to be segmented into the encoder, and processing it through the encoder and the decoder, comprising the following steps:
In this embodiment, the 3D medical image to be segmented is a first 3D voxel feature map with a size of 128×128×64×1, i.e. Height×Width×Depth×Channel format, where 128×128 represents the height and width of the data volume, 64 is the depth of the data volume, and 1 is the channel number of the feature volume.
In the first stage of the encoder, i.e. the first encoding module, the first 3D voxel feature map is first subjected to patch embedding by the patch embedding layer: the three-dimensional data volume is split into a number of small data blocks with a low-dimensional representation, and position encoding is added to these data blocks. The encoded first 3D voxel feature map is then input into the Paired Attention Transformer module for processing, which focuses on the feature regions of the medical image, yielding a second 3D voxel feature map with a size of 32×32×16×32.
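The patch-embedding step above can be sketched in a few lines of numpy. This is a minimal illustration, not the model's actual code: the 4×4×4 patch size is inferred from the stated shapes (128×128×64×1 in, 32×32×16 tokens out), and the projection matrix and position codes are random stand-ins for the learned parameters.

```python
import numpy as np

# Assumed patch size P=4 (128/4=32, 64/4=16) and embedding dimension 32.
H, W, D, C_in, C_embed, P = 128, 128, 64, 1, 32, 4

x = np.random.rand(H, W, D, C_in)

# Split the volume into non-overlapping PxPxP patches and flatten each patch.
patches = x.reshape(H // P, P, W // P, P, D // P, P, C_in)
patches = patches.transpose(0, 2, 4, 1, 3, 5, 6).reshape(-1, P * P * P * C_in)

# Linear projection to the embedding dimension, plus a position code
# (both random here, standing in for learned weights).
W_embed = np.random.rand(P * P * P * C_in, C_embed)
pos = np.random.rand(patches.shape[0], C_embed)
tokens = patches @ W_embed + pos

print(tokens.shape)  # (16384, 32): 32*32*16 tokens, each of embedding dim 32
```

Each of the 16384 tokens corresponds to one 4×4×4 block of the input volume, matching the 32×32×16 spatial grid of the second 3D voxel feature map.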
In the second stage of the encoder, i.e. the first second encoding module, the second 3D voxel feature map first passes through a downsampling layer formed by a 3D convolution with a stride of 2 and a kernel size of 3×3×3 followed by a normalization operation, and is then processed by a Paired Attention Transformer module to obtain a third 3D voxel feature map with a size of 16×16×8×64.
In the third stage of the encoder, i.e. the second second encoding module, the third 3D voxel feature map first passes through a downsampling layer formed by a 3D convolution with a stride of 2 and a kernel size of 3×3×3 followed by a normalization operation, and is then processed by a Paired Attention Transformer module to obtain a fourth 3D voxel feature map with a size of 8×8×4×128.
In the fourth stage of the encoder, i.e. the third second encoding module, the fourth 3D voxel feature map first passes through a downsampling layer formed by a 3D convolution with a stride of 2 and a kernel size of 3×3×3 followed by a normalization operation, and is then processed by a Paired Attention Transformer module to obtain a fifth 3D voxel feature map with a size of 4×4×2×256.
The processing of the 3D voxel feature map by the Paired Attention Transformer module is the same as the processing of the Paired Attention Transformer module in the decoder section described below, and is described in detail below, and is not repeated here.
The first through fifth 3D voxel feature maps obtained by the encoder are input into the decoder for processing. In the first stage of the decoder (the first decoding module), the fifth 3D voxel feature map is upsampled through an upsampling layer to a 3D voxel feature map with a size of 8×8×4×128, which is then spliced with the fourth 3D voxel feature map through the skip connection module and processed by a Paired Attention Transformer module to obtain a sixth 3D voxel feature map with a size of 8×8×4×128.
In the second stage of the decoder (the second decoding module), the sixth 3D voxel feature map is upsampled through an upsampling layer to a size of 16×16×8×64, spliced with the third 3D voxel feature map through the skip connection module, and processed by a Paired Attention Transformer module to obtain a seventh 3D voxel feature map with a size of 16×16×8×64.
In the third stage of the decoder (the third decoding module), the seventh 3D voxel feature map is first upsampled through an upsampling layer to a size of 32×32×16×32, then spliced with the second 3D voxel feature map through the skip connection module, and the spliced map is processed by a Paired Attention Transformer module to obtain an eighth 3D voxel feature map with a size of 32×32×16×32.
In the fourth stage of the decoder (the fourth decoding module), the eighth 3D voxel feature map is first upsampled through an upsampling layer to a size of 128×128×64, and is spliced through the skip connection module with the result of processing the first 3D voxel feature map by a 3D convolution with a kernel size of 3×3×1. The spliced result is then processed by a 3D convolution with a kernel size of 3×3×3 to obtain the final prediction output of the model, i.e. the final segmentation result of the medical feature region: a ninth 3D voxel feature map.
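The stage-by-stage shapes described above can be verified with a small arithmetic sketch. The channel widths and the halving/doubling pattern are taken from the sizes stated in the text; this is a bookkeeping check, not the model implementation.

```python
# Track (H, W, D, C) shapes through the encoder and decoder stages.
def down(shape, c_out):
    """Stride-2 downsampling: halve the spatial dims, set the channel count."""
    h, w, d, _ = shape
    return (h // 2, w // 2, d // 2, c_out)

def up(shape, c_out):
    """Upsampling: double the spatial dims, set the channel count."""
    h, w, d, _ = shape
    return (h * 2, w * 2, d * 2, c_out)

s1 = (32, 32, 16, 32)   # second feature map: after patch embedding + PAT
s2 = down(s1, 64)       # third feature map (encoder stage 2)
s3 = down(s2, 128)      # fourth feature map (encoder stage 3)
s4 = down(s3, 256)      # fifth feature map (encoder stage 4, bottleneck)

d1 = up(s4, 128)        # sixth feature map, skip-connected to s3's level
d2 = up(d1, 64)         # seventh feature map
d3 = up(d2, 32)         # eighth feature map
print(s2, s3, s4, d1, d2, d3)
```

Each decoder output matches the encoder feature map it is spliced with, which is exactly what the skip connections require.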
The Paired Attention Transformer module used in the above operations is shown in fig. 3 and mainly consists of a normalization layer (Layer Norm), a multi-layer perceptron (MLP), and a multi-head paired attention (MPA) module. The multi-head paired attention module, shown in fig. 4, consists of two parts, channel attention and spatial attention, which capture channel dependencies and spatial dependencies respectively.
The channel attention operation in the multi-head paired attention first transposes the vector Q_channel and performs a scaled dot-product operation with the vector K_channel, then uses Softmax to measure the similarity between each channel feature and the remaining channel features, yielding the channel attention map. A dot-product operation between the channel attention map and the vector V_channel captures the dependencies among the channels in the feature map, giving the output of the channel attention. The channel attention formula in the multi-head paired attention is given in equation (1):
X_C = V_channel · Softmax( (Q_channel)^T · K_channel / √d )    (1)

where X_C represents the output obtained through channel attention, Q_channel, K_channel and V_channel represent the channel query vector, channel key vector and channel value respectively, and d is the size of each vector.
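A minimal numpy sketch of the channel attention operation may help make the shapes concrete: the attention map is C×C (from Q^T K), so its cost scales with the channel count rather than the voxel count. The dimensions below are illustrative assumptions.

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

HWD, C, d = 64, 8, 8  # flattened voxel count, channels, vector size (assumed)

Q = np.random.rand(HWD, C)  # channel query
K = np.random.rand(HWD, C)  # channel key
V = np.random.rand(HWD, C)  # channel value

attn = softmax(Q.T @ K / np.sqrt(d), axis=-1)  # C x C channel-attention map
X_c = V @ attn                                 # HWD x C channel-attention output
print(attn.shape, X_c.shape)                   # (8, 8) (64, 8)
```

Note how the softmax normalizes similarity across channels, so each output channel is a weighted mixture of the value channels.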
The spatial attention operation in the multi-head paired attention projects V_spatial and K_spatial, each of dimension HWD×C, onto a spatial dimension P×C to obtain V_spatial_proj and K_spatial_proj respectively. A scaled dot-product operation is performed between Q_spatial and the transpose of K_spatial_proj (of dimension P×C), and the result is processed with Softmax to obtain a spatial attention map of dimension HWD×P. Finally, a dot-product operation between this spatial attention map and the projected V_spatial_proj generates a spatial attention output feature map of dimension HWD×C. The spatial attention formula in the multi-head paired attention is given in equation (2):
X_S = Softmax( Q_spatial · (K_spatial_proj)^T / √d ) · V_spatial_proj    (2)

in equation (2), X_S is the output obtained through spatial attention, Q_spatial, K_spatial_proj and V_spatial_proj represent the spatial query vector, the projection of the spatial key vector and the projection of the spatial value respectively, and d is the size of each vector.
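The spatial-attention projection can be sketched the same way: because K and V are first projected from HWD×C down to P×C, the attention map is HWD×P instead of HWD×HWD, which is where the computational saving comes from. The projection matrix below is a random stand-in for the learned projection, and all dimensions are illustrative assumptions.

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

HWD, C, P, d = 64, 8, 16, 8  # voxel count, channels, projected tokens, vector size

Q = np.random.rand(HWD, C)   # spatial query
K = np.random.rand(HWD, C)   # spatial key (before projection)
V = np.random.rand(HWD, C)   # spatial value (before projection)

proj = np.random.rand(P, HWD)   # stand-in for the learned spatial projection
K_proj = proj @ K               # P x C
V_proj = proj @ V               # P x C

attn = softmax(Q @ K_proj.T / np.sqrt(d), axis=-1)  # HWD x P spatial-attention map
X_s = attn @ V_proj                                  # HWD x C spatial output
print(attn.shape, X_s.shape)                         # (64, 16) (64, 8)
```

With P fixed (16 here), the attention cost grows linearly in the number of voxels rather than quadratically.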
The outputs of the channel and spatial attention are then merged with the original 3D voxel features, and a 3D convolution operation is applied to the merged result to extract a deeper feature representation. The final output of the multi-head paired attention module is given in equation (3):
X = Conv1( Conv3( X_C + X_S ) )    (3)

where X_C and X_S represent the output feature maps of the channel and spatial attention respectively, Conv1 represents a 3D convolution with a kernel size of 1×1×1, and Conv3 represents a 3D convolution with a kernel size of 3×3×3.
Further, training the 3D medical image segmentation model includes:
step 1, training data are acquired.
Two disclosed 3D medical image segmentation datasets, synapse and ACDC, were selected as training data.
The Synapse dataset consists of abdominal CT scans of 30 patients. Following the dataset split of the TransUnet model, 18 scans are used as the training set and the remaining 12 as the test set. In the experimental results section, the Dice similarity coefficient (DSC) and 95% Hausdorff distance (HD95) are reported for 8 abdominal organs: spleen, left kidney, pancreas, stomach, aorta, liver, gall bladder and right kidney. The Automated Cardiac Diagnosis Challenge (ACDC) dataset is split into 70 training samples, 10 validation samples and 20 test samples.
And 2, preprocessing and enhancing data.
First, the Synapse and ACDC datasets are acquired and the input three-dimensional data volumes are cropped to a size of 128×128×64.
Second, random rotation and random flipping operations, each with 50% probability, are applied to the cropped training images and the corresponding ground-truth segmentation images. These preprocessing and augmentation operations effectively compensate for the small number of training images in the original datasets, strengthening the model's resistance to overfitting and improving its robustness.
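The augmentation described above can be sketched as follows. The text specifies random rotation and flipping at 50% probability but not their exact parameters, so the flip axis and in-plane 90-degree rotation here are assumptions; the key point is that the same transform is applied to the image and its segmentation mask.

```python
import numpy as np

def augment(image, mask, rng):
    """Apply the same random flip/rotation (each with 50% probability)
    to a 3D image and its ground-truth segmentation mask."""
    if rng.random() < 0.5:                    # random flip along a spatial axis
        axis = int(rng.integers(0, 3))
        image = np.flip(image, axis=axis)
        mask = np.flip(mask, axis=axis)
    if rng.random() < 0.5:                    # random in-plane 90-degree rotation
        k = int(rng.integers(1, 4))
        image = np.rot90(image, k=k, axes=(0, 1))
        mask = np.rot90(mask, k=k, axes=(0, 1))
    return image, mask

rng = np.random.default_rng(0)
img = np.random.rand(128, 128, 64)            # cropped volume, as in the text
seg = (img > 0.5).astype(np.int64)            # toy binary ground-truth mask
img_aug, seg_aug = augment(img, seg, rng)
print(img_aug.shape, seg_aug.shape)           # (128, 128, 64) (128, 128, 64)
```

Because the height and width are equal (128×128), the in-plane rotation preserves the crop size.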
And step 3, inputting the data after data enhancement into a 3D medical image segmentation model for processing to obtain a final model prediction graph.
The loss is calculated using a combination of loss functions, where the loss function measures the error between the model's prediction and the ground-truth segmentation image.
This embodiment uses the sum of the cross-entropy loss and the soft Dice loss to calculate the loss between the 3D voxel result predicted by the model and the ground truth, thereby combining the advantages of both loss functions. The loss function is given in equation (4):
L(Z, P) = 1 − (2/N) Σ_{j=1}^{N} ( Σ_{v=1}^{V} Z_{v,j} P_{v,j} ) / ( Σ_{v=1}^{V} Z_{v,j}² + Σ_{v=1}^{V} P_{v,j}² ) − (1/V) Σ_{v=1}^{V} Σ_{j=1}^{N} Z_{v,j} log P_{v,j}    (4)

where V is the total number of voxels in the 3D voxel feature map, N is the number of predicted classes, Z_{v,j} is the true value of the j-th class at voxel v, and P_{v,j} is the predicted probability output by the model for the j-th class at voxel v.
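The combined loss can be sketched in numpy. This follows the common soft-Dice-plus-cross-entropy formulation from one-hot targets Z and softmax probabilities P; the epsilon term and shapes are illustrative assumptions, not the document's exact implementation.

```python
import numpy as np

def combined_loss(Z, P, eps=1e-8):
    """Soft Dice loss plus voxel-wise cross entropy.
    Z: one-hot targets of shape (V, N); P: predicted probabilities (V, N)."""
    dice_per_class = (2.0 * (Z * P).sum(axis=0)) / (
        (Z ** 2).sum(axis=0) + (P ** 2).sum(axis=0) + eps
    )
    dice_loss = 1.0 - dice_per_class.mean()          # averaged over N classes
    ce_loss = -(Z * np.log(P + eps)).sum(axis=1).mean()  # averaged over V voxels
    return dice_loss + ce_loss

rng = np.random.default_rng(0)
V, N = 100, 3
logits = rng.normal(size=(V, N))
P = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)  # softmax probs
Z = np.eye(N)[rng.integers(0, N, size=V)]                        # one-hot targets
print(round(float(combined_loss(Z, P)), 4))
```

A perfect prediction (P equal to Z) drives both terms to approximately zero, which is the sanity check one would run on any loss of this form.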
An AdamW optimizer is set with an initial learning rate of 0.001 and a weight decay of 3e-5; the weight decay coefficient prevents the model from overfitting, and adaptive learning-rate adjustment speeds up model convergence.
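The optimizer setup above can be illustrated with a learning-rate schedule sketch. The text does not name the adaptive schedule, so polynomial decay (a common choice for this kind of segmentation training) is assumed here; only the initial learning rate of 0.001 and the 1000-epoch budget come from the text.

```python
def poly_lr(epoch, total_epochs=1000, base_lr=1e-3, power=0.9):
    """Polynomial learning-rate decay (assumed schedule, not from the source):
    starts at base_lr and decays smoothly toward 0 over training."""
    return base_lr * (1 - epoch / total_epochs) ** power

print(poly_lr(0))    # 0.001 at the start of training
print(poly_lr(500))  # decayed below the base rate by the halfway point
```

The weight decay of 3e-5 would be passed to the optimizer separately (e.g. as AdamW's weight-decay parameter); it is independent of the learning-rate schedule.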
The loss between the model's segmentation prediction map and the ground truth is calculated according to the combined loss function, and the AdamW optimizer is used for gradient updates and adaptive learning-rate adjustment, with 8 samples per batch and 1000 training epochs in total. For the Synapse dataset, the average Dice score and the 95% Hausdorff distance are reported as evaluation metrics; the ACDC dataset uses only the average Dice score.
The comparison models used in the experiments are the state-of-the-art medical image segmentation models UNETR and nnFormer; experimental comparison data for these models are shown in table 1, and a visual comparison with other models on the ACDC dataset is shown in fig. 5.
Table 1 experimental comparison of the method with other models
Example two
The embodiment discloses a 3D medical image segmentation system based on paired attention, comprising:
the acquisition module is used for acquiring the 3D medical image to be segmented;
The 3D medical image segmentation module is used for inputting the 3D medical image to be segmented into a trained 3D medical image segmentation model for processing so as to obtain a segmentation result;
the 3D medical image segmentation model comprises an encoder and a decoder, wherein the encoder is connected with the decoder; the encoder comprises a first encoding module and a plurality of second encoding modules which are sequentially connected, the first encoding module comprising a patch embedding layer and a paired attention transformer module, and each second encoding module comprising a paired attention transformer module and a downsampling layer; the decoder comprises a plurality of decoding modules which are sequentially connected, each decoding module comprising a skip connection module, a paired attention transformer module and an upsampling module.
It should be noted that the acquisition module and the 3D medical image segmentation module correspond to the steps in the first embodiment; the modules share the same examples and application scenarios as their corresponding steps, but are not limited to the disclosure of the first embodiment. It should also be noted that the modules described above may be implemented as part of a system in a computer system, for example as a set of computer-executable instructions.
Example III
An electronic device according to a third embodiment of the present invention includes a memory, a processor, and computer instructions stored in the memory and executable on the processor; when executed by the processor, the computer instructions complete the steps of the paired attention-based 3D medical image segmentation method described above.
Example IV
A fourth embodiment of the present invention provides a computer readable storage medium storing computer instructions that, when executed by a processor, perform the steps of the paired attention-based 3D medical image segmentation method described above.
The present invention is described with reference to flowchart illustrations and/or block diagrams of methods, apparatus (systems) and computer program products according to embodiments of the invention. It will be understood that each flow and/or block of the flowchart illustrations and/or block diagrams, and combinations of flows and/or blocks in the flowchart illustrations and/or block diagrams, can be implemented by computer program instructions. These computer program instructions may be provided to a processor of a general purpose computer, special purpose computer, embedded processor, or other programmable data processing apparatus to produce a machine, such that the instructions, which execute via the processor of the computer or other programmable data processing apparatus, create means for implementing the functions specified in the flowchart flow or flows and/or block diagram block or blocks.
These computer program instructions may also be stored in a computer-readable memory that can direct a computer or other programmable data processing apparatus to function in a particular manner, such that the instructions stored in the computer-readable memory produce an article of manufacture including instruction means which implement the function specified in the flowchart flow or flows and/or block diagram block or blocks.
These computer program instructions may also be loaded onto a computer or other programmable data processing apparatus to cause a series of operational steps to be performed on the computer or other programmable apparatus to produce a computer implemented process such that the instructions which execute on the computer or other programmable apparatus provide steps for implementing the functions specified in the flowchart flow or flows and/or block diagram block or blocks.
The embodiments in this specification are described in a progressive manner; for details of one embodiment, reference may be made to the related description of another embodiment.
The above description is only of the preferred embodiments of the present invention and is not intended to limit the present invention, but various modifications and variations can be made to the present invention by those skilled in the art. Any modification, equivalent replacement, improvement, etc. made within the spirit and principle of the present invention should be included in the protection scope of the present invention.