Our proposal focuses on enhancing the partial volumetric representation, which is inherently limited compared to the original volumetric information, by utilizing the relational information among data. In contrast to existing works, which simply expose partial 3D spatial information to the model through basic slice extraction16,17 and feature combination19,20 without addressing how this information is utilized in lesion analysis, our method provides the relationships among data samples, derived from the 3D spatial information, as a prior. We term this approach of determining the input for the 2D student model 'partial input restriction'. Figure 2 provides a comprehensive schematic illustration of the experiment and the model's structure. In the following sections, we discuss in detail the 2D projections we adopted and the structure used for 3D-to-2D KD.
2D projection alteration for 3D-to-2D KD
To demonstrate the effectiveness of 3D-to-2D KD, we prepared the model's input information through the following process. From the given 3D volumetric image, we re-sample the striatal region at its median level, as pre-defined by slice indices, along the three axes (axial, coronal, and sagittal). When adjacent slices are included, we acquire one neighboring slice on each side of the median slice19. We refer to this process as 'slice extraction'. The volumetric features contained in the resulting images are handled in three major ways:
Single slice as input: The extracted slices are used directly as non-i.i.d. input data for training the diagnostic model. In this case, to prevent data leakage caused by multiple slices extracted from one sample being distributed across the training, validation, and test sets, a subject-level data split is performed before conducting slice extraction36.
Aggregated slices with early fusion (EF): The extracted slices are assigned to the RGB bands (i.e., channel-level concatenation), thereby conveying thicker volume information to the model than a single slice. This method allows for a more comprehensive representation of the volumetric features within the model19.
Aggregated feature with joint fusion (JF): As in EF, the extracted adjacent slices for each plane are assigned to the RGB bands. Each plane image is then fed into a 2D CNN to encode plane-wise features. These plane-wise features are concatenated along the channel dimension and passed through a Feed-Forward Network (FFN) to encode a comprehensive volumetric feature. This process integrates detailed spatial information from the different planes, resulting in a more robust representation of the volumetric characteristics in the model20.
In previous studies, different network parameters are used for each plane to learn plane-wise features. In our approach, however, we designed the model to share parameters across planes, recognizing that the striatal patterns appearing in each plane are homogeneous targets for the prediction model to capture. This design choice was also driven by the need to minimize model parameters in our experimental setup. A detailed explanation is provided in the '2D Student Network for diagnostic system' section. The shared-parameter strategy not only reduces the model's complexity but also ensures a more consistent learning process across the planes.
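To make the shared-parameter design concrete, the following is a minimal PyTorch sketch of joint fusion with a single backbone shared across the three planes, assuming a torchvision ResNet18; the class name, feature dimension, and two-layer FFN head are illustrative choices rather than the exact implementation:

```python
import torch
import torch.nn as nn
from torchvision.models import resnet18


class PlaneSharedJF(nn.Module):
    """Joint fusion with a single 2D backbone shared across the three planes.

    Each plane input is a 3-channel stack (median slice plus its neighbors);
    the same ResNet18 parameters encode all planes, and an FFN fuses the
    concatenated plane-wise features into a diagnostic prediction.
    """

    def __init__(self, feat_dim: int = 512, num_classes: int = 2):
        super().__init__()
        backbone = resnet18(weights=None)   # trained from scratch
        backbone.fc = nn.Identity()         # expose the 512-d pooled feature
        self.backbone = backbone           # one parameter set for all planes
        self.ffn = nn.Sequential(
            nn.Linear(3 * feat_dim, feat_dim),
            nn.ReLU(inplace=True),
            nn.Linear(feat_dim, num_classes),
        )

    def forward(self, axial, coronal, sagittal):
        # Encode every plane with the SAME backbone (shared parameters).
        feats = [self.backbone(p) for p in (axial, coronal, sagittal)]
        f = torch.cat(feats, dim=1)        # concatenated plane-wise feature
        return self.ffn(f), f              # prediction and feature for KD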
Automatic projection via rank pooling
In previous studies applying the 2D projection method, particularly for AD diagnosis, researchers heuristically determined the representative slice based on the findings to be observed, such as cortical atrophy in structural MRI or amyloid plaque load in beta-amyloid PET. These findings are expected to appear across multiple regions, making it easier for researchers to select slices arbitrarily. However, in cases like detecting cerebral malignancies or PD, where the lesion areas to be observed are localized, manually extracting diagnostic slices involves significant time and effort, and the selection of representative slices can be subjectively influenced by the researcher. To overcome this, we propose a method that treats a 3D volumetric image as a dynamic image whose scenes change along an arbitrary axis, thereby summarizing 3D spatial information without practitioner intervention. This approach achieves automatic 2D projection, reducing both the subjective bias and the manual workload of selecting representative slices.
Bilen et al.37 initially proposed the concept of a 'dynamic image' as a compact representation for video analysis. A dynamic image is obtained by applying rank pooling to the frames of a video, which effectively turns standard 2D CNNs into dynamic-aware models through fine-tuning on video data. Prior research compares the construction operations of the rank-pooler, contrasting approximated rank-pooling with modified rank-pooling, which directly ranks the feature frames. These studies report that while modified rank-pooling is about 45 times slower than approximated rank-pooling, it yields approximately 3% higher accuracy. Therefore, we adopted modified rank-pooling as our automatic projection technique to better encapsulate the rich volumetric information during condensation. First, given \(N\) volumetric images \({x}_{T}\in {\mathbb{R}}^{D\times H\times W}\), \({X}_{T}=\{{x}_{T1},\dots ,{x}_{TN}\}\), each containing \(D\) slices, we enumerate the slices \({I}_{1},{I}_{2},\cdots ,{I}_{D}\) of each volume \({x}_{Ti}\) \((i=1,\dots ,N)\) along a chosen axis.
$$\widehat{\rho }\left( {I}_{1},{I}_{2}, \cdots , {I}_{D} \right)= \sum _{d=1}^{D}{\alpha }_{d}{I}_{d}$$
1
$${\alpha }_{d}=2d-D-1$$
2
Rank-pooling operations, as described in Eq. (1), produce a dynamic image by multiplying each slice by a coefficient and summing the weighted slices. In modified rank-pooling, \(\rho\) is a function that scores a sequence of \(D\) slices by reflecting their rank and maps it to a single value. In other words, the optimized rank in modified rank-pooling is derived from a weighted sum determined by the linear weighting function \({\alpha }_{d}\) corresponding to depth \(d\) along an arbitrary axis. This incorporates the sequential and spatial information contained in the slices into a single, more nuanced representation of the volumetric data. We obtain plane-wise dynamic images by applying rank-pooling directly along each axis of the 3D volumetric image. Subsequently, we compare the results of applying EF or JF to these plane-wise dynamic images, along with channel-level concatenation of representative slices for each axis as per the 2D + e approach. This allows us to comprehensively evaluate how effectively the different fusion techniques capture the complex spatial information inherent in 3D volumetric imaging.
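For concreteness, the linear-coefficient rank pooling of Eqs. (1) and (2) can be sketched in a few lines of NumPy; the function name is ours, and intensity rescaling of the resulting image is omitted:

```python
import numpy as np


def dynamic_image(volume: np.ndarray, axis: int = 0) -> np.ndarray:
    """Project a 3D volume to a 2D dynamic image via rank pooling (Eqs. 1-2).

    Each slice along `axis` is weighted by alpha_d = 2d - D - 1 (d = 1..D),
    so early and late slices receive opposite-signed weights, and the
    weighted slices are summed into a single 2D image.
    """
    volume = np.moveaxis(volume, axis, 0)         # slices first: (D, H, W)
    D = volume.shape[0]
    alphas = 2.0 * np.arange(1, D + 1) - D - 1    # linear ranking coefficients
    return np.tensordot(alphas, volume, axes=1)   # weighted sum over slices


# Plane-wise dynamic images along the three axes of a (D, H, W) volume:
# dyn_axial, dyn_coronal, dyn_sagittal = (dynamic_image(vol, a) for a in range(3))
```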
Building volumetric prior knowledge through training the 3D teacher network
To enable the 2D student network, which will be used in the diagnostic system, to understand the original 3D space from the partial volumetric features it observes, we employ volumetric prior knowledge. To form this prior knowledge, we begin by using a 3D teacher network to encode the original 3D space. The architecture of the 3D teacher network fundamentally follows that of ResNet1838, but with all 2D convolutional layers replaced by 3D convolutional layers. ResNet's residual block repeats a sequence of a convolutional layer, a normalization layer, and a ReLU activation twice, and adds the input features back through an identity mapping before the final ReLU activation. Our structural adaptation differs from the conventional ResNet18, which stacks two residual blocks per group, repeated four times; instead, we stack four individual residual blocks. This modification is driven by the need for a 3D CNN that incorporates the validated structure of 2D CNNs while remaining lightweight enough for training on the typically smaller datasets characteristic of medical data, balancing the model's complexity and depth to suit medical imaging analysis. Our internal experiments revealed that reducing the number of residual blocks in ResNet18 led to higher performance on both the validation and test sets than the standard ResNet18. After the sequence of residual blocks, we apply Adaptive Average Pooling and encode the output into a 1024-dimensional vector. Finally, we use a single-layer neural network to map this vector to a 2-dimensional output representing Healthy Control (HC) and PD. This network calculates the probability distribution \(P\left(k|{x}_{T}\right)={\widehat{y}}_{T}\) for the diagnostic label, where \(k\) is the random variable associated with the diagnostic label.
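The sketch below illustrates such a teacher in PyTorch. The four individual 3D residual blocks, the adaptive average pooling into a 1024-dimensional vector, and the single-layer classifier follow the description above, while the channel widths, strides, and single-channel input are our illustrative assumptions:

```python
import torch
import torch.nn as nn


class Residual3DBlock(nn.Module):
    """(conv3d-norm-ReLU) x 2 with an identity shortcut before the last ReLU."""

    def __init__(self, in_ch: int, out_ch: int, stride: int = 2):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv3d(in_ch, out_ch, 3, stride=stride, padding=1, bias=False),
            nn.BatchNorm3d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv3d(out_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm3d(out_ch),
        )
        # Projection shortcut when the identity mapping cannot match shapes.
        self.shortcut = (
            nn.Identity()
            if stride == 1 and in_ch == out_ch
            else nn.Sequential(
                nn.Conv3d(in_ch, out_ch, 1, stride=stride, bias=False),
                nn.BatchNorm3d(out_ch),
            )
        )

    def forward(self, x):
        return torch.relu(self.body(x) + self.shortcut(x))


class Teacher3D(nn.Module):
    """Four individual 3D residual blocks -> 1024-d pooled feature -> 2 classes."""

    def __init__(self, num_classes: int = 2):
        super().__init__()
        widths, in_ch, blocks = [128, 256, 512, 1024], 1, []
        for w in widths:
            blocks.append(Residual3DBlock(in_ch, w))
            in_ch = w
        self.blocks = nn.Sequential(*blocks)
        self.pool = nn.AdaptiveAvgPool3d(1)
        self.fc = nn.Linear(1024, num_classes)    # single-layer HC/PD head

    def forward(self, x):                          # x: (B, 1, D, H, W)
        f = self.pool(self.blocks(x)).flatten(1)   # penultimate feature f_T
        return self.fc(f), f
```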
We train the 3D teacher network from scratch on the given dataset \({D}_{T}\). In our experiment, the dataset \({D}_{T}=\{{X}_{T}, {Y}_{T}\}\) consists of \({X}_{T}\), the 3D volumetric images \({x}_{T}\in {\mathbb{R}}^{D\times H\times W}\) with \({X}_{T}=\{{x}_{T1},\dots ,{x}_{TN}\}\), and \({Y}_{T}\), the corresponding diagnostic labels evaluated according to the PPMI criteria for each volumetric image. Here, \(H\), \(W\), and \(D\) need not align with the standard neuroimaging coordinate system; we define \(H\) as the axis running from left to right of the brain, \(W\) as the axis from the front to the back of the brain, and \(D\) as the brain's longitudinal axis. The 3D teacher network is trained with the cross-entropy loss between the diagnostic label and the predicted probability distribution \({\widehat{y}}_{T}\):
$${L}_{CL{S}^{T}}\left({\widehat{y}}_{T},{y}_{T} \right)={L}_{CE}\left({\widehat{y}}_{T},{y}_{T}\right)=-\frac{1}{B}\sum _{i=1}^{B}{y}_{Ti}\log\left({\widehat{y}}_{Ti}\right)$$
3
Once the 3D teacher network’s model parameters are trained, they are frozen during the training of the student model.
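A minimal sketch of this procedure, assuming the illustrative `Teacher3D` above (returning logits and the penultimate feature) together with hypothetical `loader_T` and `opt` objects:

```python
import torch.nn.functional as F

# Train the 3D teacher on D_T = {X_T, Y_T} with cross-entropy (Eq. 3).
for x_t, y_t in loader_T:
    logits_t, _ = teacher(x_t)
    loss = F.cross_entropy(logits_t, y_t)
    opt.zero_grad()
    loss.backward()
    opt.step()

# Freeze the trained teacher for the distillation stage: parameters stop
# receiving gradients, and eval() fixes the normalization statistics.
teacher.requires_grad_(False)
teacher.eval()
```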
2D Student Network for diagnostic system
To distill the 3D volumetric prior knowledge embedded in the 3D teacher network, a 2D student network, using a standard ResNet18 as its backbone, learns from the partially restricted input from scratch while simultaneously mimicking the knowledge representation of the 3D teacher network. This process is akin to the 2D student network piecing together fragmented volumetric information to reconstruct complete 3D volumetric knowledge. Consequently, the effectiveness of 3D-to-2D KD can vary depending on how rich the input information is in volumetric knowledge. The structure of the 2D student network is therefore designed to adapt to alterations in the partial input restriction. For example, when the 2D + e or rank-pooling-based projection method is combined through EF, the input channel size is 3. In the JF setup, as illustrated in Fig. 2, we encode the projected inputs for each plane using shared backbone parameters. The encoded features are then concatenated, and a single-layer neural network maps them into a 2-dimensional vector. This process predicts the probability distribution \(P\left(k|{x}_{S}\right)={\widehat{y}}_{S}\) for HC and PD.
We train the 2D student network from scratch using a dataset \({D}_{S}=\{{X}_{S}, {Y}_{S}\}\) created by applying the 2D projection method to the dataset \({D}_{T}\). Here, \({X}_{S}\) represents the projected 2D images \({x}_{S}\in {\mathbb{R}}^{C\times H\times W}\), with \({X}_{S}=\{{x}_{S1},\dots ,{x}_{SN}\}\), generated by a predefined partial input restriction function \(R\) (i.e., \({x}_{S}=R\left({x}_{T}\right)\)). In our experimental setup, \({Y}_{S}\) is used as the Ground Truth (GT) label, identical to \({Y}_{T}\), and we ensure that \({D}_{T}\) and \({D}_{S}\) share the same subject IDs. Here, \(C\) is the channel size of the model's input images, which varies depending on the 2D projection method used. The objective function for the 2D student network is formulated as follows:
$${L}_{CL{S}^{S}}\left({\widehat{y}}_{S},{y}_{S} \right)={L}_{CE}\left({\widehat{y}}_{S},{y}_{S}\right)=-\frac{1}{B}\sum _{i=1}^{B}{y}_{Si}\log\left({\widehat{y}}_{Si}\right)$$
4
Aligning volumetric feature representation with 3D-to-2D Knowledge Distillation
As mentioned earlier, although 2D projection methods provide partial volumetric information to the 2D CNN model in their own distinct ways, they do not necessarily convey the context in which these information fragments are used in the original 3D space. Furthermore, with a limited dataset and purely data-driven feature learning under the given objective function, the visual representation created by the 2D student network is unlikely to coincidentally align with the original 3D volumetric representation produced by the 3D teacher network. To minimize the modality gap between the original 3D data and the projected 2D data, we rectify the graph-level representations created by the teacher and student networks for minibatch data. By having the 2D student network closely mimic the similarity matrix between data samples, calculated from the 3D volumetric information discovered by the 3D teacher network, the ability of the 2D network to handle fragmented volumetric features is significantly enhanced.
In the initial KD approach, the knowledge representation used is the soft prediction, obtained by applying a temperature-scaled softmax function to the model's output logits. This soft prediction is also referred to as a soft target. The KD loss using soft targets, denoted as \({L}_{s.t.}\), is defined as follows:
$${L}_{s.t.}={L}_{CE}(\sigma \left(\frac{{z}_{S}}{T}\right), \sigma \left(\frac{{z}_{T}}{T}\right))$$
5
In this context, \({z}_{S}\) and \({z}_{T}\) represent the logits produced by the student and teacher networks, respectively, while \(\sigma\) denotes the softmax function. The temperature parameter \(T\) smooths the computed class probabilities: it adjusts the sharpness of the distribution, preventing the softmax output from becoming too extreme. In this setup, the soft targets generated by the fixed (non-updating) teacher network act as additional pseudo-labels for training the student model.
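For reference, Eq. (5) can be written out as follows; the temperature value is illustrative, and the \(T^{2}\) loss rescaling sometimes used in practice is omitted:

```python
import torch
import torch.nn.functional as F


def soft_target_loss(z_s: torch.Tensor, z_t: torch.Tensor, T: float = 4.0):
    """Soft-target KD loss of Eq. (5): cross-entropy between the
    temperature-softened teacher and student class distributions."""
    p_t = F.softmax(z_t / T, dim=1)          # teacher soft targets (pseudo labels)
    log_p_s = F.log_softmax(z_s / T, dim=1)  # softened student log-probabilities
    return -(p_t * log_p_s).sum(dim=1).mean()
```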
We adopt the flattened features in the penultimate layer, corresponding to the volumetric representations encoded by the networks involved in 3D-to-2D KD across the 3D and 2D modalities, as the distilled representation for 3D-to-2D KD. \({f}_{T}\in {\mathbb{R}}^{B\times {C}_{T}}\) and \({f}_{S}\in {\mathbb{R}}^{B\times {C}_{S}}\) denote the feature vectors from the penultimate layer of each network, where \({C}_{T}\) and \({C}_{S}\) are the feature dimensions of the teacher and student networks, respectively. For a given minibatch, the interrelationships of the data points in the embedding space formed by the feature vectors of the 3D teacher network are expressed through a similarity matrix, computed as follows:
$$\widetilde{{f}_{T}}=\frac{{f}_{T}}{{\Vert {f}_{T}\Vert }_{2}};\quad {S}_{T}=\widetilde{{f}_{T}} \cdot {\widetilde{{f}_{T}}}^{\top }$$
6
For the distilled representations, we first apply the \(l2\)-norm, and the similarity matrices \({S}_{T}, {S}_{S}\in {\mathbb{R}}^{B\times B}\) are computed through linear affinity. In our internal experiments, we considered various similarity measures, such as linear affinity (i.e., simple matrix multiplication), the Radial Basis Function (RBF), and k-nearest-neighbor (kNN)-based affinity; since no significant differences were observed, we opted for linear affinity for clarity. The similarity matrix for the penultimate-layer feature representation of the 2D student network is calculated analogously:
$$\widetilde{{f}_{S}}=\frac{{f}_{S}}{{\Vert {f}_{S}\Vert }_{2}};\quad {S}_{S}=\widetilde{{f}_{S}} \cdot {\widetilde{{f}_{S}}}^{\top }$$
7
We distill the volumetric prior knowledge of the teacher network by directly reducing the difference between the similarity matrices created from the encoded features produced by the teacher and student networks. We define the 3D-to-2D KD loss based on volumetric features as follows:
$${L}_{fg}\left({S}_{T},{S}_{S}\right)=\frac{1}{{B}^{2}}\sum _{\left(i,j\right)\in I}{\left({S}_{T,ij}-{S}_{S,ij}\right)}^{2}$$
8
\(I\) represents the set containing all pairs of data points included in the minibatch input.
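A compact PyTorch rendering of Eqs. (6)-(8), assuming the penultimate features have already been flattened into \(B\times C\) matrices:

```python
import torch
import torch.nn.functional as F


def similarity_kd_loss(f_t: torch.Tensor, f_s: torch.Tensor) -> torch.Tensor:
    """Feature-graph KD loss of Eqs. (6)-(8).

    f_t: (B, C_T) penultimate features of the frozen 3D teacher.
    f_s: (B, C_S) penultimate features of the 2D student.
    Rows are l2-normalized, turned into B x B linear-affinity similarity
    matrices, and matched by their mean squared element-wise difference.
    """
    s_t = F.normalize(f_t, dim=1) @ F.normalize(f_t, dim=1).T  # S_T, Eq. (6)
    s_s = F.normalize(f_s, dim=1) @ F.normalize(f_s, dim=1).T  # S_S, Eq. (7)
    return ((s_t - s_s) ** 2).mean()      # (1 / B^2) * sum over pairs, Eq. (8)
```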
Finally, the total loss \({L}_{total}\) used for training the 2D student network is defined as follows:
$${L}_{total}={L}_{CL{S}^{S}}\left({\widehat{y}}_{S}, {y}_{S}\right)+{L}_{fg}({S}_{T}, {S}_{S})$$
9
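Putting the pieces together, one student update under Eq. (9) might look as follows; `student`, `teacher`, and `similarity_kd_loss` refer to the illustrative sketches above, and both networks are assumed to return a (logits, penultimate feature) pair:

```python
import torch
import torch.nn.functional as F

# One training step of the 2D student: supervised cross-entropy on the
# projected input plus the feature-graph distillation term.
logits_s, f_s = student(x_s)              # x_s = R(x_t), the restricted input
with torch.no_grad():                     # teacher is frozen
    _, f_t = teacher(x_t)
loss = F.cross_entropy(logits_s, y_s) + similarity_kd_loss(f_t, f_s)
loss.backward()
```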
Figure 2. Illustration of 3D-to-2D Knowledge Distillation. \({L}_{fg}\) is the 3D-to-2D KD loss based on volumetric features, \({L}_{{CLS}^{S}}\) is the cross-entropy loss between the diagnostic label \({y}_{S}\) and the probability distribution \({\widehat{y}}_{S}\), \({\Phi }\) is the similarity measure, FFN is the Feed-Forward Network, \({f}_{T}\) is the feature vector in the penultimate layer of the 3D teacher network, and \({f}_{S}\) is the feature vector in the penultimate layer of the 2D student network.