<?xml version="1.0" encoding="utf-8"?><feed xmlns="http://www.w3.org/2005/Atom" ><generator uri="https://jekyllrb.com/" version="3.10.0">Jekyll</generator><link href="https://kitewatermelon.github.io/feed.xml" rel="self" type="application/atom+xml" /><link href="https://kitewatermelon.github.io/" rel="alternate" type="text/html" /><updated>2026-06-11T23:35:23+09:00</updated><id>https://kitewatermelon.github.io/feed.xml</id><title type="html">개발새발 01한 인생</title><subtitle>Computer Vision, Medical AI 논문 리뷰 및 개발 이야기를 다루는 기술 블로그입니다.</subtitle><author><name>YSPARK</name></author><entry><title type="html">[논문리뷰] LeJEPA: Provable and Scalable Self-Supervised Learning Without the Heuristics (1)</title><link href="https://kitewatermelon.github.io/paper-review/lejepa-1/" rel="alternate" type="text/html" title="[논문리뷰] LeJEPA: Provable and Scalable Self-Supervised Learning Without the Heuristics (1)" /><published>2026-06-08T00:00:00+09:00</published><updated>2026-06-08T00:00:00+09:00</updated><id>https://kitewatermelon.github.io/paper-review/lejepa-1</id><content type="html" xml:base="https://kitewatermelon.github.io/paper-review/lejepa-1/"><![CDATA[<blockquote>
  <p>NeurIPS 2022 [<a href="https://arxiv.org/pdf/2511.08544">Paper</a>] [<a href="https://github.com/rbalestr-lab/lejepa">GitHub</a>]<br />
 Randall Balestriero, Yann LeCun
 14 Nov 2025</p>
</blockquote>

<h2 id="들어가며">들어가며</h2>
<p>LeJEPA는 Brown Univ.의 Randall Balestriero 교수와 AI의 대부인 Yann LeCun 교수가 다양한 downstream task에도 강건한 모델을 만들기 위해 모델 출력의 임베딩이 따라야하는 분포가 Isotropic Gaussian임을 이론적으로 증명하고, 그 분포를 따르도록 하는 SigReg라는 regulaization을 제안한 논문이다.</p>

<p>필자는 이를 보다 높은 해상도로 이해하고 싶은 욕심이 있기에 다음과 같은 구성을 결심했다.</p>

<p>(1)편에서 <code class="language-plaintext highlighter-rouge">1 Introduction ~ 2 Background and Notations</code> 을 통해 배경 지식을 정리할 것이다.</p>

<p>(2)편에서 <code class="language-plaintext highlighter-rouge">3 Latent Euclidean: Embeddings  Should be Isotropic Gaussian</code> 을 깊게 공부하여 왜 임베딩이 Isotropic Gaussian이 되어야 하는지 증명할 것이다.</p>

<p>(3)편에서 <code class="language-plaintext highlighter-rouge">4 SIGReg: Reliable Isotropic  Gaussian Regularization in High-Dimension</code>을 통해 어떻게 LeJEPA가 model의 출력을 Isotropic Gaussian으로 강제하는지 알아볼 것이다.</p>

<p>(4)편에서 <code class="language-plaintext highlighter-rouge">5 LeJEPA: Stable and Scalable  Implementation</code>가 어떻게 구현되는지 확인 할 것이다.</p>

<p>(5)편에서 <code class="language-plaintext highlighter-rouge">6 LeJEPA: Empirical Validation ~ 7 Conclusion</code>을 통해 LeJEPA가 실제로 의미가 있는지 알아볼 것이다.</p>

<h2 id="1-introduction">1 Introduction</h2>
<p>세계와 그 역학의 조작 가능한 representation을 학습하는 것은 AI 분야에서 오랜 질문으로, 그 기원은 수세기 전으로 거슬러 올라간다. 이미지 인식, 로봇 공학, 물리학, 우주 탐사와 같은 여러 분야에서 공통된 질문은 <code class="language-plaintext highlighter-rouge">관찰을 통해 조직적이고 실행 가능한 고차원 임베딩 공간을 어떻게 학습할 것인가? </code>이다.</p>

<p>파라미터화된 비선형 오퍼레이터인 $f_\theta$를 사용하여 관측치를 임베딩으로 매핑하는 것은 이 퍼즐의 표준적인 첫 조각이다. 그 다음 퍼즐은 아직 표준화가 덜 된 어떻게 $f_\theta$를 학습할 것인가? 이다.</p>

<p>Joint-Embedding Predictive Architectures (JEPAs)는 $f_\theta$를 의미론적으로 관련있는 두 <code class="language-plaintext highlighter-rouge">views</code>간 predictive agreement를 최대화하여 $f_\theta$를 학습시키도록 제안한다.</p>

<blockquote>
  <p>여기서 <code class="language-plaintext highlighter-rouge">views</code>란 transformations과 corruptions라는 두가지 폼이 존재한다.<br />
결국엔 입력 이미지를 masking, cropping, blurring, temporal or spatial translations, geometric or photometric transformations, viewpoint changes, views from different sensor modalities, etc. 하는 것이다.</p>
</blockquote>

<p>어떤 경우든간에 <code class="language-plaintext highlighter-rouge">views</code>는 두 개가 의미적으로 어느 정도 연관되어 있어야 prediction task가 인코더 $f_\theta$의 임베딩을 데이터 속에 숨겨진 진짜 지식 쪽으로 정렬시킬 수 있다.</p>

<p>그런데 JEPA의 prediction task는 잘못하면 모델이 다 똑같은 벡터를 뱉어버리는 collapse를 일으키게 된다. identical embeddings (complete collapse)나 low dimensional subspace (dimensional collapse)</p>

<blockquote>
  <p>이를 출력하면 모델 입장에서 가장 쉬운 해답이 되기 때문인데 이는 뭐가 들어오든 같은 벡터를 뱉으면 prediction loss가 0이 되기 때문이다.</p>
</blockquote>

<p>이를 해결하기 위해 SOTA 방법은 휴리스틱에 의존하게 된다. 대표적으로 stop-gradient, 비대칭 뷰 생성 , EMA 스케줄링을 통한 teascher-student network, explicit한 normalization, whitening layers 도입 등이 있으며 이들은 모두 하이퍼파라미터를 조금만 바꿔도 collapse가 나거나 성능이 떨어진다. 저자들은 이론은 안 파고 스케일만 키우고 있음을 지적하며 이론적 뒷받침에 대한 필요를 설명한다.</p>

<p>저자들은 묻고 답한다.</p>
<blockquote>
  <p>JEPA가 준수해야 할 필요 조건이 무엇일까? <br />
(i) prediction task를 풀되, (기존의 JEPA) <br />
(2) 임베딩의 분포를 등방성 가우스 분포를 강제하는 것이다. (SIGReg)</p>
</blockquote>

<p>저자들은 임베딩의 분포를 등방성 가우스 분포를 강제하기 위해 SIGReg(Sktched Isotropic Gaussian Regularization)를 제안한다.</p>

<p>SIGReg는 (1) 통계적으로 보장되었으며 (2) 기존의 휴리스틱들을 사용하지 않아도 되고 (3) 메모리와 계산 복잡성이 선형이며 (4) collapse 방지를 위한 하이퍼파라미터는 오직 하나만 있어도 되는 장점이 있다.</p>]]></content><author><name>YSPARK</name></author><category term="Paper-Review" /><category term="Computer-Vision" /><category term="Self-Supervised-Learning" /><category term="JEPA" /><summary type="html"><![CDATA[LeJEPA: Provable and Scalable Self-Supervised Learning Without the Heuristics (1)]]></summary></entry><entry><title type="html">[코드 리뷰] mlx로 Vision Transformer 만들어보기</title><link href="https://kitewatermelon.github.io/code-review/mlx-vit/" rel="alternate" type="text/html" title="[코드 리뷰] mlx로 Vision Transformer 만들어보기" /><published>2026-06-07T00:00:00+09:00</published><updated>2026-06-07T00:00:00+09:00</updated><id>https://kitewatermelon.github.io/code-review/mlx-vit</id><content type="html" xml:base="https://kitewatermelon.github.io/code-review/mlx-vit/"><![CDATA[<p>이 글이 맘에 들거나 도움이 드셨다면 아래 레포에 스타 하나 부탁드립니다!!</p>

<p>code repo: <a href="https://github.com/kitewatermelon/vit-mlx">📥 vit-mlx repo</a></p>

<h1 id="서론">서론</h1>

<p>6월 3일 선거가 있던 날 나의 구현 능력이 궁금해서 뭐라도 손으로 구현해보고 싶었다. 그러던 중 “나는 Vision Transformer를 얼마나 이해하고 있을까?” 라는 질문이 문득 떠올랐다.</p>

<p>Q1: 왜 <code class="language-plaintext highlighter-rouge">ViT</code>?  <br />
A1: 연구에서 가장 자주 사용한게 Vision Transformer(ViT)이기도 하고 Sequential과 ModuleList라는 개념을 최근에 공부했는데 둘 다 사용해보고 싶어서 구현하고 싶었다. 본 구현에선 <a href="https://arxiv.org/pdf/2010.11929">ViT 논문</a>만을 참고해서 구현하는 것을 목표로 했다.</p>

<p>Q2: 왜 <code class="language-plaintext highlighter-rouge">mlx</code>? <br />
A2: macbook과 mac mini를 사용하는 한명의 apple의 팬으로써, Mac slicon에서 동작하는 <code class="language-plaintext highlighter-rouge">mlx</code>라는 프레임워크에 관심이 있었다. <code class="language-plaintext highlighter-rouge">mlx</code>는 <code class="language-plaintext highlighter-rouge">numpy-like,  torch-like</code>라 별로 어렵지 않겠지… 라는 생각과 함께 구현을 했다.</p>

<p>Q3: 참고한 사이트? <br />
A3: <code class="language-plaintext highlighter-rouge">timm</code> 라이브러리를 주로 참고하였다. 코딩 에이전트/LLM은 cls 토큰을 배치 단위로 expand하는 법을 몰라서 찾아본 것 외에는 사용하지 않았다.</p>

<p>아래는 ViT 논문에서 제공하는 ViT의 구조도이다.</p>

<p><img src="/assets/img/code-review/mlx-vit/fig1.webp" alt="png" /></p>

<p>ViT를 구현하기 위해 고려해야 하는 사항은 크게 3가지가 있었다.</p>
<ol>
  <li>patch embedding을 어떻게 구현할지?</li>
  <li>Multi head Attention을 어떻게 구현할지?</li>
  <li><code class="language-plaintext highlighter-rouge">forward()</code>를 어떻게 구현할지?</li>
</ol>

<p>이제 차근차근 고려사항을 어떻게 구현했는지 알아보자.</p>

<h1 id="구현">구현</h1>
<h2 id="1-patch-embedding">1. Patch Embedding</h2>
<p><img src="/assets/img/code-review/mlx-vit/fig2.webp" alt="png" />
ViT는 Transformer 기반의 입력을 받기 위해 이미지를 고정된 크기의 패치로 잘라서 flatten한 후 Linear Projection을 한다. 이 과정이 굉장히 복잡할 것 같지만 실제로는 <code class="language-plaintext highlighter-rouge">nn.Conv2d</code>를 통해 간단하게 구현이 가능하다.</p>

<ol>
  <li>patch size와 stride를 같게 주면 겹치는 부분없이 patchfy 하는 꼴이 되이 된다.</li>
  <li>convolution 연산은 결국 패치들에 대하여 Linear projection을 하는 꼴이 된다.</li>
  <li>해당 피처는 NHWC의 모양의 피처맵이 나오게 된다. H와 W에 대하여 flatten 해주면 1D의 토큰 나열이 나오게 된다.</li>
</ol>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">class</span> <span class="nc">PatchEmbedding</span><span class="p">(</span><span class="n">nn</span><span class="p">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">is_rgb</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span> <span class="n">patch_size</span><span class="o">=</span><span class="mi">16</span><span class="p">,</span> <span class="n">embed_dim</span><span class="o">=</span><span class="mi">768</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">().</span><span class="n">__init__</span><span class="p">()</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">patch_size</span> <span class="o">=</span> <span class="n">patch_size</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">embed_dim</span> <span class="o">=</span> <span class="n">embed_dim</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">in_channels</span> <span class="o">=</span> <span class="mi">3</span> <span class="k">if</span> <span class="n">is_rgb</span> <span class="k">else</span> <span class="mi">1</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">proj</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Conv2d</span><span class="p">(</span>
                <span class="n">in_channels</span><span class="o">=</span><span class="bp">self</span><span class="p">.</span><span class="n">in_channels</span><span class="p">,</span>
                <span class="n">out_channels</span><span class="o">=</span><span class="n">embed_dim</span><span class="p">,</span>
                <span class="n">kernel_size</span><span class="o">=</span><span class="n">patch_size</span><span class="p">,</span>
                <span class="n">stride</span><span class="o">=</span><span class="n">patch_size</span>
            <span class="p">)</span>

    <span class="k">def</span> <span class="nf">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
        <span class="c1"># B, C, H, W = x.shape
</span>        <span class="c1"># print(f"Input Shape: {x.shape}")
</span>        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">proj</span><span class="p">(</span><span class="n">x</span><span class="p">).</span><span class="n">flatten</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span> <span class="c1"># 두번째 차원인 H/patch_size와 W/patch_size를 하나의 차원으로 합침
</span>        <span class="k">return</span> <span class="n">x</span>
</code></pre></div></div>

<h2 id="2-block---multi-head-self-attentionmhsa">2. Block - Multi Head Self-Attention(MHSA)</h2>

<p><img src="/assets/img/code-review/mlx-vit/fig3.webp" alt="png" /></p>

<p>MHSA는 Transformer block(layer)에서 필수적으로 구현을 해야한다. 비록 요즘에는 flash attention이라는게 나와서 단순히 호출을 하면 되지만, 본 글에서는 직접 수식을 구현하는 형태로 진행한다.</p>

<p>우선 self-attention의 아주 기본적인 수식은 다음과 같다.</p>

<h3 id="2-1-qkv-계산">2-1. QKV 계산</h3>
<p>\(Q = XW^Q, \quad K = XW^K, \quad V = XW^V\)</p>

<h3 id="2-2-single-head-scaled-dot-product-attention">2-2. Single-Head Scaled Dot-Product Attention</h3>
<p>\(\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V\)</p>

<h3 id="2-3-multi-head">2-3. Multi-Head</h3>
<p>\(\text{head}_i = \text{Attention}(QW_i^Q,\ KW_i^K,\ VW_i^V)\)</p>

\[\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)W^O\]

<p>위 식을 코드로 구현하기 위해</p>
<ol>
  <li>우리는 projection 행렬을 만들어 Q, K, V를 얻어야 한다.</li>
  <li>self-attention 수식을 구현해야 하고</li>
  <li>그 와중에 head를 고려해서 구현해야 한다.</li>
</ol>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">class</span> <span class="nc">MHSA</span><span class="p">(</span><span class="n">nn</span><span class="p">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">embed_dim</span><span class="o">=</span><span class="mi">768</span><span class="p">,</span> <span class="n">num_heads</span><span class="o">=</span><span class="mi">12</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">().</span><span class="n">__init__</span><span class="p">()</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">embed_dim</span> <span class="o">=</span> <span class="n">embed_dim</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">num_heads</span> <span class="o">=</span> <span class="n">num_heads</span>

        <span class="k">assert</span> <span class="n">embed_dim</span> <span class="o">%</span> <span class="n">num_heads</span> <span class="o">==</span> <span class="mi">0</span><span class="p">,</span> <span class="s">"embed_dim % num_heads != 0 !!!!"</span> 

        <span class="bp">self</span><span class="p">.</span><span class="n">head_dim</span> <span class="o">=</span> <span class="n">embed_dim</span> <span class="o">//</span> <span class="n">num_heads</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">scale</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">head_dim</span> <span class="o">**</span> <span class="o">-</span><span class="mf">0.5</span>

        <span class="bp">self</span><span class="p">.</span><span class="n">qkv</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">embed_dim</span><span class="p">,</span> <span class="n">embed_dim</span> <span class="o">*</span> <span class="mi">3</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="bp">False</span><span class="p">)</span>

        <span class="bp">self</span><span class="p">.</span><span class="n">proj</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">embed_dim</span><span class="p">,</span> <span class="n">embed_dim</span><span class="p">)</span>


    <span class="k">def</span> <span class="nf">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
        <span class="c1"># BND -&gt; BN(3*D)
</span>        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">qkv</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
        <span class="c1"># Q, K, V 3개로 chunk -&gt; [BND, BND, BND]
</span>        <span class="n">qkv</span> <span class="o">=</span> <span class="n">mx</span><span class="p">.</span><span class="n">split</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">indices_or_sections</span><span class="o">=</span><span class="mi">3</span><span class="p">)</span> 
        <span class="c1"># print(qkv[0].shape, qkv[1].shape, qkv[2].shape)
</span>        
        <span class="c1"># 1. Q,K,V 각각 BHND로 reshape (관례)
</span>        <span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="o">=</span> <span class="nb">map</span><span class="p">(</span><span class="k">lambda</span> <span class="n">t</span><span class="p">:</span> <span class="n">rearrange</span><span class="p">(</span><span class="n">t</span><span class="p">,</span> <span class="s">'b n (h d) -&gt; b h n d'</span><span class="p">,</span> <span class="n">h</span><span class="o">=</span><span class="bp">self</span><span class="p">.</span><span class="n">num_heads</span><span class="p">),</span> <span class="n">qkv</span><span class="p">)</span> 
        <span class="c1"># 2. Mat. Mul. -&gt; BHN(Q)N(K)
</span>        <span class="n">attn</span> <span class="o">=</span> <span class="n">q</span> <span class="o">@</span> <span class="n">k</span><span class="p">.</span><span class="n">transpose</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span> 
        <span class="c1"># 3. Scale: sqrt(dk)
</span>        <span class="n">attn_score</span> <span class="o">=</span> <span class="n">attn</span> <span class="o">*</span> <span class="bp">self</span><span class="p">.</span><span class="n">scale</span> 
        <span class="c1"># 4. attn_w: K에 대하여 softmax하기 위해서 axis=-1로 설정
</span>        <span class="n">attn_weight</span> <span class="o">=</span> <span class="n">mx</span><span class="p">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">attn_score</span><span class="p">,</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span> 
        <span class="c1"># 5. Mat. Mul. -&gt; BHND
</span>        <span class="n">out</span> <span class="o">=</span> <span class="n">attn_weight</span> <span class="o">@</span> <span class="n">v</span> 
        <span class="c1"># 6. 각 헤드 concat
</span>        <span class="n">out</span> <span class="o">=</span> <span class="n">rearrange</span><span class="p">(</span><span class="n">out</span><span class="p">,</span> <span class="s">'b h n d -&gt; b n (h d)'</span><span class="p">,</span> <span class="n">h</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">num_heads</span><span class="p">)</span> <span class="c1"># MHSA concat
</span>        <span class="c1"># 7. concat 후 정보 섞어주기 위해 같은 차원으로 projection
</span>        <span class="n">out</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">proj</span><span class="p">(</span><span class="n">out</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">out</span><span class="p">,</span> <span class="n">attn_weight</span> <span class="c1"># 최종 출력과 attn_weight 같이 출력
</span>
</code></pre></div></div>

<p>구현체의 포인트는 다음과 같다.</p>

<p>projection 행렬은 아래처럼 <code class="language-plaintext highlighter-rouge">nn.Linear</code>로 QKV를 한번에 projection한 후</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="bp">self</span><span class="p">.</span><span class="n">qkv</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">embed_dim</span><span class="p">,</span> <span class="n">embed_dim</span> <span class="o">*</span> <span class="mi">3</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="bp">False</span><span class="p">)</span>
</code></pre></div></div>

<p>forward에서 각각 QKV로 2번째 차원인 D를 3개로 split해서 각각을 Q, K, V로 사용한다.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">qkv</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
        <span class="c1"># Q, K, V 3개로 chunk -&gt; [BND, BND, BND]
</span>        <span class="n">qkv</span> <span class="o">=</span> <span class="n">mx</span><span class="p">.</span><span class="n">split</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">indices_or_sections</span><span class="o">=</span><span class="mi">3</span><span class="p">)</span> 
</code></pre></div></div>

<p>그 후에 Q,K,V 각각 BHND로 reshape을 한 후 Q, K를 행렬곱한다. 이때 K.transpose를(0,1,3,2)를 통해 K의 N과 D차원을 바꾸어 N $\times$ N 모양을 만들어준 후 scaling을 해주어 attention score가 만들어지도록 한다.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>        <span class="c1"># 1. Q,K,V 각각 BHND로 reshape (관례)
</span>        <span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="o">=</span> <span class="nb">map</span><span class="p">(</span><span class="k">lambda</span> <span class="n">t</span><span class="p">:</span> <span class="n">rearrange</span><span class="p">(</span><span class="n">t</span><span class="p">,</span> <span class="s">'b n (h d) -&gt; b h n d'</span><span class="p">,</span> <span class="n">h</span><span class="o">=</span><span class="bp">self</span><span class="p">.</span><span class="n">num_heads</span><span class="p">),</span> <span class="n">qkv</span><span class="p">)</span> 
        <span class="c1"># 2. Mat. Mul. -&gt; BHN(Q)N(K)
</span>        <span class="n">attn</span> <span class="o">=</span> <span class="n">q</span> <span class="o">@</span> <span class="n">k</span><span class="p">.</span><span class="n">transpose</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span> 
        <span class="n">attn_score</span> <span class="o">=</span> <span class="n">attn</span> <span class="o">*</span> <span class="bp">self</span><span class="p">.</span><span class="n">scale</span> 
</code></pre></div></div>

<p>이후에 K에 대하여 softmax를 하는데 이때 Q에 대하여 해야할지 K에 대하여 해야할지 헷갈렸다.
어쨋든 Query에 대한 Key값의 확률을 구하는 과정이므로 axis=-1로 하여 K에 대하여 softmax를 하도록 했다.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>         <span class="c1"># 4. attn_w: K에 대하여 softmax하기 위해서 axis=-1로 설정
</span>        <span class="n">attn_weight</span> <span class="o">=</span> <span class="n">mx</span><span class="p">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">attn_score</span><span class="p">,</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span> 
</code></pre></div></div>
<p>그 이후에 V를 곱하고 head를 concat한 후 head 별 정보를 섞어주기 위해 동일한 차원으로 projection한다. forward 과정에서는 <code class="language-plaintext highlighter-rouge">out</code>만 사용하겠지만, 여러 분석 과정에서 <code class="language-plaintext highlighter-rouge">attn_weight</code>가 필요할 수 있으므로 동시에 반환하도록 설정했다.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>        <span class="n">out</span> <span class="o">=</span> <span class="n">attn_weight</span> <span class="o">@</span> <span class="n">v</span> 
        <span class="c1"># 6. 각 헤드 concat
</span>        <span class="n">out</span> <span class="o">=</span> <span class="n">rearrange</span><span class="p">(</span><span class="n">out</span><span class="p">,</span> <span class="s">'b h n d -&gt; b n (h d)'</span><span class="p">,</span> <span class="n">h</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">num_heads</span><span class="p">)</span> <span class="c1"># MHSA concat
</span>        <span class="c1"># 7. concat 후 정보 섞어주기 위해 같은 차원으로 projection
</span>        <span class="n">out</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">proj</span><span class="p">(</span><span class="n">out</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">out</span><span class="p">,</span> <span class="n">attn_weight</span> <span class="c1"># 최종 출력과 attn_weight 같이 반환
</span></code></pre></div></div>

<h2 id="3-block---mlp">3. Block - MLP</h2>
<h3 id="3-1-mlp">3-1. MLP</h3>
<p>Transformer block의 또 다른 요소는 MLP인데 논문에 그렇게 자세히 나와있지 않아서 MLP는 <code class="language-plaintext highlighter-rouge">timm</code>을 참고하여 다음과 같이 간단하게 구현했다. 이 글에서는 짧게 작성하지만 실제로는 비선형성이 추가가 되는 부분이라 굉장히 중요한 부분이다. 절대 무시해선 안된다!</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">class</span> <span class="nc">MLP</span><span class="p">(</span><span class="n">nn</span><span class="p">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">embed_dim</span><span class="o">=</span><span class="mi">768</span><span class="p">,</span> <span class="n">mlp_ratio</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">dropout_rate</span><span class="o">=</span><span class="mf">0.1</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">().</span><span class="n">__init__</span><span class="p">()</span> 
        <span class="c1"># 아래 timm-like MLP 참조함
</span>        <span class="c1"># https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/mlp.py
</span>        <span class="bp">self</span><span class="p">.</span><span class="n">net</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Sequential</span><span class="p">(</span>
                <span class="n">nn</span><span class="p">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">embed_dim</span><span class="p">,</span> <span class="n">embed_dim</span> <span class="o">*</span> <span class="n">mlp_ratio</span><span class="p">),</span>
                <span class="n">nn</span><span class="p">.</span><span class="n">GELU</span><span class="p">(),</span>
                <span class="n">nn</span><span class="p">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">dropout_rate</span><span class="p">),</span>
                <span class="n">nn</span><span class="p">.</span><span class="n">LayerNorm</span><span class="p">(</span><span class="n">embed_dim</span> <span class="o">*</span> <span class="n">mlp_ratio</span><span class="p">),</span>
                <span class="n">nn</span><span class="p">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">embed_dim</span> <span class="o">*</span> <span class="n">mlp_ratio</span><span class="p">,</span> <span class="n">embed_dim</span> <span class="p">),</span>
                <span class="n">nn</span><span class="p">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">dropout_rate</span><span class="p">),</span>
            <span class="p">)</span>
        
    <span class="k">def</span> <span class="nf">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
        <span class="k">return</span> <span class="bp">self</span><span class="p">.</span><span class="n">net</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
</code></pre></div></div>

<h3 id="3-2-block">3-2. Block</h3>
<p>Block도 여기서는 짧게 넘어가고자 한다. x를 normalization 하고 각 모듈을 통과한 후에 그 값에 x를 더하는 식의 Residual connection을 적용한다. 이는 논문 figure에 자세히 설명이 되어 있다.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">class</span> <span class="nc">Block</span><span class="p">(</span><span class="n">nn</span><span class="p">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">embed_dim</span><span class="o">=</span><span class="mi">768</span><span class="p">,</span> <span class="n">num_heads</span><span class="o">=</span><span class="mi">12</span><span class="p">,</span> <span class="n">mlp_ratio</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">dropout_rate</span><span class="o">=</span><span class="mf">0.1</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">().</span><span class="n">__init__</span><span class="p">()</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">norm1</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">LayerNorm</span><span class="p">(</span><span class="n">dims</span><span class="o">=</span><span class="n">embed_dim</span><span class="p">)</span> <span class="c1"># 별도의 norm1
</span>        <span class="bp">self</span><span class="p">.</span><span class="n">norm2</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">LayerNorm</span><span class="p">(</span><span class="n">dims</span><span class="o">=</span><span class="n">embed_dim</span><span class="p">)</span> <span class="c1"># 별도의 norm1
</span>        <span class="bp">self</span><span class="p">.</span><span class="n">mhsa</span> <span class="o">=</span> <span class="n">MHSA</span><span class="p">(</span><span class="n">embed_dim</span><span class="o">=</span><span class="n">embed_dim</span><span class="p">,</span> <span class="n">num_heads</span><span class="o">=</span><span class="n">num_heads</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">mlp</span> <span class="o">=</span> <span class="n">MLP</span><span class="p">(</span><span class="n">embed_dim</span><span class="o">=</span><span class="n">embed_dim</span><span class="p">,</span> <span class="n">mlp_ratio</span><span class="o">=</span><span class="n">mlp_ratio</span><span class="p">,</span> <span class="n">dropout_rate</span><span class="o">=</span><span class="n">dropout_rate</span><span class="p">)</span>
    
    <span class="k">def</span> <span class="nf">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
        <span class="c1"># 1단계 - attention: 뭐가 더 중요한지 확인
</span>        <span class="n">x_norm</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">norm1</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="c1"># Layer normalization 1
</span>        <span class="n">x_attn</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">mhsa</span><span class="p">(</span><span class="n">x_norm</span><span class="p">)</span> <span class="c1"># MHSA
</span>        <span class="n">x</span> <span class="o">=</span> <span class="n">x_attn</span> <span class="o">+</span> <span class="n">x</span> <span class="c1"># Residual connection
</span>        
        <span class="c1"># 2단계 - MLP: 비선형성 증가
</span>        <span class="n">x_norm</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">norm2</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="c1"># Layer normalization 2
</span>        <span class="n">x_mlp</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">mlp</span><span class="p">(</span><span class="n">x_norm</span><span class="p">)</span> <span class="c1"># 비선형성 증가를 위한 MLP
</span>        <span class="n">x</span> <span class="o">=</span> <span class="n">x_mlp</span> <span class="o">+</span> <span class="n">x</span>

        <span class="k">return</span> <span class="n">x</span>
</code></pre></div></div>

<h2 id="4-vit">4. ViT</h2>
<p>이제 입력을 받는 부분과 Transformer block을 전부 완성했다. 그럼에도 불구하고 아직 세가지가 더 완성이 되어야하는데 cls token과 positional embedding을 처리하는 부분과 간단한 classification head이다. 그리고 이를 한번에 실행시킬 forward 부분이 필요하다.</p>

<h3 id="4-1-cls-token과-positional-embedding-그리고-classification-head">4-1. cls token과 positional embedding 그리고 classification head</h3>
<p>cls token과 positional embedding은 결국 하나의 파라미터다. 따라서 <code class="language-plaintext highlighter-rouge">mx.random.normal()</code>을 사용한다. <code class="language-plaintext highlighter-rouge">mlx</code>는 명시적으로 parameter와 tensor를 구분하지 않는데 이는 model.parameters()를 호출하면 모듈 안의 모든 <code class="language-plaintext highlighter-rouge">mx.array</code>를 자동으로 <code class="language-plaintext highlighter-rouge">pytree</code>로 수집해버리기 때문이라고 한다.</p>

<p>classification head는 매우 간단하게 <code class="language-plaintext highlighter-rouge">embed_dim</code>에서 <code class="language-plaintext highlighter-rouge">num_classes</code>개로 projection하도록 만들었다.</p>

<p><code class="language-plaintext highlighter-rouge">mlx</code>와 <code class="language-plaintext highlighter-rouge">torch</code>랑 다른 점은 <code class="language-plaintext highlighter-rouge">nn.ModuleList</code>를 사용하는 것이 아니라 그냥 리스트에 <code class="language-plaintext highlighter-rouge">[]</code>에 각 모듈을 집어 넣어서 사용한다.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">class</span> <span class="nc">ViT</span><span class="p">(</span><span class="n">nn</span><span class="p">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="s">"""
        cls token based ViT 
    """</span>
    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">img_size</span><span class="o">=</span><span class="mi">224</span><span class="p">,</span> <span class="n">patch_size</span><span class="o">=</span><span class="mi">16</span><span class="p">,</span> <span class="n">embed_dim</span><span class="o">=</span><span class="mi">768</span><span class="p">,</span> <span class="n">num_heads</span><span class="o">=</span><span class="mi">12</span><span class="p">,</span> <span class="n">mlp_ratio</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">dropout_rate</span><span class="o">=</span><span class="mf">0.1</span><span class="p">,</span> <span class="n">depth</span><span class="o">=</span><span class="mi">12</span><span class="p">,</span> <span class="n">num_classes</span><span class="o">=</span><span class="mi">1000</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">().</span><span class="n">__init__</span><span class="p">()</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">patch_size</span> <span class="o">=</span> <span class="n">patch_size</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">embed_dim</span> <span class="o">=</span> <span class="n">embed_dim</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">num_heads</span> <span class="o">=</span> <span class="n">num_heads</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">mlp_ratio</span> <span class="o">=</span> <span class="n">mlp_ratio</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">dropout_rate</span> <span class="o">=</span> <span class="n">dropout_rate</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">depth</span> <span class="o">=</span> <span class="n">depth</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">num_patches</span> <span class="o">=</span> <span class="nb">int</span><span class="p">((</span><span class="n">img_size</span> <span class="o">//</span> <span class="n">patch_size</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span><span class="p">)</span>
        
        <span class="bp">self</span><span class="p">.</span><span class="n">norm</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">LayerNorm</span><span class="p">(</span><span class="n">dims</span><span class="o">=</span><span class="n">embed_dim</span><span class="p">)</span> <span class="c1"># final norm
</span>

        <span class="c1"># !주의! self.cls_token은 single batch 기준으로 만들어졌기 때문에 (0번째 차원이 1), _pos_embed 매서드에서 동적으로 배치 차원을 늘려줘야 함. 
</span>        <span class="bp">self</span><span class="p">.</span><span class="n">pos_embed</span> <span class="o">=</span> <span class="n">mx</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">normal</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">num_patches</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span> <span class="n">embed_dim</span><span class="p">),</span> <span class="n">scale</span><span class="o">=</span><span class="mf">0.02</span><span class="p">)</span>  <span class="c1"># 학습 가능한 파라미터로 position embedding 학습
</span>        <span class="bp">self</span><span class="p">.</span><span class="n">cls_token</span> <span class="o">=</span> <span class="n">mx</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">normal</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">embed_dim</span><span class="p">),</span> <span class="n">scale</span><span class="o">=</span><span class="mf">0.02</span><span class="p">)</span> <span class="c1"># 학습 가능한 파라미터로 cls 토큰 학습 
</span>
        <span class="c1"># FOR TEST 실제 실험 시 동작 변경!
</span>        <span class="c1"># self.cls_token = mx.zeros((1, 1, embed_dim))  # 동작 확인 용 cls token
</span>        <span class="c1"># self.pos_embed = mx.ones((1, self.num_patches + 1, embed_dim)) # 동작 확인 용 position embedding
</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">patch_embed</span> <span class="o">=</span> <span class="n">PatchEmbedding</span><span class="p">(</span>
            <span class="n">patch_size</span><span class="o">=</span><span class="n">patch_size</span><span class="p">,</span> 
            <span class="n">embed_dim</span><span class="o">=</span><span class="n">embed_dim</span>
            <span class="p">)</span>
        
        <span class="bp">self</span><span class="p">.</span><span class="n">blocks</span> <span class="o">=</span> <span class="p">[</span>
            <span class="n">Block</span><span class="p">(</span>
                <span class="n">embed_dim</span><span class="o">=</span><span class="n">embed_dim</span><span class="p">,</span> 
                <span class="n">num_heads</span><span class="o">=</span><span class="n">num_heads</span><span class="p">,</span> 
                <span class="n">mlp_ratio</span><span class="o">=</span><span class="n">mlp_ratio</span><span class="p">,</span> 
                <span class="n">dropout_rate</span><span class="o">=</span><span class="n">dropout_rate</span>
                <span class="p">)</span>  <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">depth</span><span class="p">)</span>
            <span class="p">]</span>
        
        <span class="bp">self</span><span class="p">.</span><span class="n">head</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">embed_dim</span><span class="p">,</span> <span class="n">num_classes</span><span class="p">)</span>
</code></pre></div></div>

<h3 id="4-2-forward">4-2. forward</h3>
<p>forward는 단순하게 구현된다.</p>
<ol>
  <li>patch embedding을 한 후 <code class="language-plaintext highlighter-rouge">_pos_embed</code> 매서드를 이용하여 positional embedding과 cls token을 추가한다.</li>
  <li><code class="language-plaintext highlighter-rouge">[]</code>에 넣어놓은 Transformer block을 차례로 통과시켜 최종 출력을 얻는다.</li>
  <li>최종 출력을 normalization한 후 cls token을 이용하여 head를 통과시켜 최종 결과를 얻는다.
mlx는 torch와 다르게 <code class="language-plaintext highlighter-rouge">forward()</code>대신 <code class="language-plaintext highlighter-rouge">__call__()</code>를 이용하여 forward 기능을 제공한다. 이는 jax의 영향을 받은 것으로 생각한다.</li>
</ol>

<p><code class="language-plaintext highlighter-rouge">_pos_embed</code>를 만들면서 가장 신경 썼던 부분은 <code class="language-plaintext highlighter-rouge">self.cls_token</code>의 B 차원을 1로 정의했기 때문에 배치 단위로의 확장하는 것 이었다. torch에서는 보통 <code class="language-plaintext highlighter-rouge">expand</code>로 구현이 되는데 mlx에서는 <code class="language-plaintext highlighter-rouge">broadcast_to</code>라는 함수를 통해 구현을 해야했다.</p>

<p>또한 <code class="language-plaintext highlighter-rouge">self.cls_token</code>과 <code class="language-plaintext highlighter-rouge">self.pos_embed</code>이 정확히 동작하는지 확인하기 위해서 각각 zeros와 ones로 만들어 의도대로 정확히 동작하는지 확인했다.</p>

<ul>
  <li>그 결과 N차원이 1 늘어났으며 0번째에 0으로 이루어진 벡터가 추가된 것을 확인하여 <code class="language-plaintext highlighter-rouge">self.cls_token</code>이 제대로 추가되고 있는 것을 확인했다.</li>
  <li>그 이후에 <code class="language-plaintext highlighter-rouge">self.pos_embed</code>이 1씩 더해지는 것을 확인하여 positional embedding이 잘 더해지는 것을 확인했다.</li>
</ul>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>    <span class="k">def</span> <span class="nf">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">patch_embed</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">_pos_embed</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
        <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">block</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">blocks</span><span class="p">):</span>
            <span class="n">x</span> <span class="o">=</span> <span class="n">block</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
            <span class="c1"># print(f"{i} 번째 layer")
</span>        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">norm</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
        <span class="n">cls</span> <span class="o">=</span> <span class="n">x</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span>       <span class="c1"># CLS token만 추출
</span>        <span class="k">return</span> <span class="bp">self</span><span class="p">.</span><span class="n">head</span><span class="p">(</span><span class="n">cls</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">_pos_embed</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
        <span class="n">B</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="n">C</span> <span class="o">=</span> <span class="n">x</span><span class="p">.</span><span class="n">shape</span>
        <span class="c1"># print(B, N, C)
</span>        <span class="c1"># broadcast_to 함수로 cls token B 만큼 복제함. 
</span>        <span class="c1"># numpy-like 이므로 자세한 동작은 다음 문서 참고. https://numpy.org/doc/2.2/reference/generated/numpy.broadcast_to.html
</span>        <span class="n">cls</span> <span class="o">=</span> <span class="n">mx</span><span class="p">.</span><span class="n">broadcast_to</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">cls_token</span><span class="p">,</span> <span class="p">(</span><span class="n">B</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">C</span><span class="p">))</span>   
        <span class="n">x</span> <span class="o">=</span> <span class="n">mx</span><span class="p">.</span><span class="n">concatenate</span><span class="p">((</span><span class="n">cls</span><span class="p">,</span> <span class="n">x</span><span class="p">),</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span> <span class="c1"># [cls] token N차원의 제일 앞에 concat, axis=1로 해줘야 N+1 됨. (BNC)
</span>        <span class="c1"># print(x)
</span>        <span class="c1"># array([[[0, 0, 0, ..., 0, 0, 0], &gt; zeros로 설정 해놓고 제대로 추가 됐는지 테스트 완료
</span>        <span class="c1">#         [-0.625268, 0.0861489, 0.387924, ..., 0.565768, 1.14554, 0.420791],
</span>        <span class="c1">#         [1.4328, 0.820314, -0.0266257, ..., 0.0697592, -1.02677, 0.830738],
</span>        <span class="c1">#         ...,
</span>        <span class="c1">#         [0.164638, 0.335743, 0.710777, ..., 0.172818, -0.326656, 0.0479117],
</span>        <span class="c1">#         [-0.466198, 0.0355091, -0.264295, ..., -0.378135, 0.381905, -0.481186],
</span>        <span class="c1">#         [0.474137, 1.21557, -0.281954, ..., 0.562486, -0.0671904, 0.0877942]]], dtype=float32)
</span>        
        <span class="n">x</span> <span class="o">+=</span> <span class="bp">self</span><span class="p">.</span><span class="n">pos_embed</span> <span class="c1"># position 정보 postion wise 하게 추가
</span>        <span class="c1"># print(x)
</span>        <span class="c1"># array([[[1, 1, 1, ..., 1, 1, 1], &gt; ones로 설정 해놓고 제대로 추가 됐는지 테스트 완료
</span>        <span class="c1">#         [0.374732, 1.08615, 1.38792, ..., 1.56577, 2.14554, 1.42079],
</span>        <span class="c1">#         [2.4328, 1.82031, 0.973374, ..., 1.06976, -0.0267704, 1.83074],
</span>        <span class="c1">#         ...,
</span>        <span class="c1">#         [1.16464, 1.33574, 1.71078, ..., 1.17282, 0.673344, 1.04791],
</span>        <span class="c1">#         [0.533802, 1.03551, 0.735705, ..., 0.621866, 1.38191, 0.518814],
</span>        <span class="c1">#         [1.47414, 2.21557, 0.718046, ..., 1.56249, 0.93281, 1.08779]]], dtype=float32)
</span>        <span class="k">return</span> <span class="n">x</span>
    
    <span class="k">def</span> <span class="nf">get_params_info</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
        <span class="n">params</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">trainable_parameters</span><span class="p">()</span>
        <span class="n">flat</span> <span class="o">=</span> <span class="n">mlx</span><span class="p">.</span><span class="n">utils</span><span class="p">.</span><span class="n">tree_flatten</span><span class="p">(</span><span class="n">params</span><span class="p">)</span>
        <span class="c1"># flat = [("layer.weight", array), ("layer.bias", array), ...]
</span>        
        <span class="n">total</span> <span class="o">=</span> <span class="nb">sum</span><span class="p">(</span><span class="n">v</span><span class="p">.</span><span class="n">size</span> <span class="k">for</span> <span class="n">_</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">flat</span><span class="p">)</span>  <span class="c1"># 언패킹 필요
</span>        <span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"Total trainable parameters: </span><span class="si">{</span><span class="n">total</span><span class="si">:</span><span class="p">,</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">total</span>
</code></pre></div></div>

<h1 id="글을-마무리하며">글을 마무리하며…</h1>
<p>이번 구현을 통해 얻은 것은 다음과 같다.</p>
<ul>
  <li>softmax, MHSA 등에서 얻은 차원에 대한 이해</li>
  <li>모델 설계 과정에서 동작 테스트하는 법</li>
  <li>모델 중간 상태의 shape을 예상하는 능력</li>
  <li>개발의 주도권을 AI에 뺏기지 않고 문서를 찾아보고 이해하는 능력</li>
</ul>

<p>사전에 ViT와 torch에 대한 전반적인 지식이 있어서 구현에 큰 어려움을 겪지는 않았으나, 어려웠던 점은 다음과 같다.</p>
<ul>
  <li>cls token을 배치 단위로 확장하는 과정</li>
  <li>MHSA에서 multi-head를 확장하는 과정</li>
  <li>글에 적지는 않았지만 Block에서 MHSA 이전과 MLP 이전에 normalization 할 때 같은 norm을 사용했던 버그</li>
</ul>

<p>앞으로도 AI를 사용하지 않고 문서 찾아보고 논문 찾아보며 직접 구현하여 AI modeling에 대한 이해를 높혀야겠다. 잘 알아야 AI도 잘쓰니까…</p>]]></content><author><name>YSPARK</name></author><category term="Code-Review" /><category term="ViT" /><category term="MLX-basic" /><summary type="html"><![CDATA[mlx로 Vision Transformer를 구현하며 알게 된 실무적인 경험을 나눈다.]]></summary></entry><entry><title type="html">[코드 리뷰] nn.Conv2d는 사실 convolution을 하지 않는다.</title><link href="https://kitewatermelon.github.io/code-review/conv2d/" rel="alternate" type="text/html" title="[코드 리뷰] nn.Conv2d는 사실 convolution을 하지 않는다." /><published>2026-06-04T00:00:00+09:00</published><updated>2026-06-04T00:00:00+09:00</updated><id>https://kitewatermelon.github.io/code-review/conv2d</id><content type="html" xml:base="https://kitewatermelon.github.io/code-review/conv2d/"><![CDATA[<p>code 다운로드: <a href="assets/code/code-review/conv2d.ipynb">📥 conv2d.ipynb 다운로드</a></p>

<p>어느 날 PyTorch 문서를 둘러보던 중 재밌는 사실을 알게 되었다.
<a href="https://docs.pytorch.org/docs/2.12/generated/torch.nn.Conv2d.html">nn.Conv2d</a> 문서에 따르면,
<strong>“<code class="language-plaintext highlighter-rouge">nn.Conv2d</code>는 사실 convolution을 하지 않는다”</strong>
라는 것이다. 문서에 의하면 “where ⋆ is the valid 2D cross-correlation operator” 라고 하는데, cross-correlation와 convolution의 차이는 무엇일까?</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">import</span> <span class="nn">torch.nn</span> <span class="k">as</span> <span class="n">nn</span>
</code></pre></div></div>

<h2 id="nnconv2d는-무엇일까"><code class="language-plaintext highlighter-rouge">nn.Conv2d</code>는 무엇일까?</h2>

<p>우선 이 포스트를 통해 <code class="language-plaintext highlighter-rouge">nn.Conv2d</code>가 무엇인지 고민해본다. 결국 <code class="language-plaintext highlighter-rouge">nn.Conv2d</code>는 파라미터(<code class="language-plaintext highlighter-rouge">nn.Parameter</code>)이다. 파라미터는 텐서지만, 텐서는 파라미터가 아니다. 둘의 차이는 무엇일까?</p>

<p>파라미터는 <code class="language-plaintext highlighter-rouge">.requires_grad=True</code>이며 <code class="language-plaintext highlighter-rouge">model.parameters()</code>에 자동으로 등록되는 텐서라고 생각하면 편하다. 아래 예제를 통해 알아보자.
즉 파라미터는 학습 가능한 텐서이다.</p>

<p>아래 코드에서는 <code class="language-plaintext highlighter-rouge">rnd</code>라는 변수를 (5,5) 사이즈의 표준 정규 분포를 따르는 텐서로 초기화한다.</p>

<p><code class="language-plaintext highlighter-rouge">torch.randn</code>의 output이 애초에 tensor긴 하지만 본 예제에서는 명시적으로 표현하기 위해 <code class="language-plaintext highlighter-rouge">torch.tensor</code>로 감싼다.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">rnd</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">randn</span><span class="p">(</span><span class="mi">5</span><span class="p">,</span> <span class="mi">5</span><span class="p">)</span>
<span class="n">t</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">rnd</span><span class="p">)</span>
<span class="n">t</span>
</code></pre></div></div>
<p>출력값은 다음과 같다.</p>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>/tmp/ipykernel_723/1694594144.py:2: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).
  t = torch.tensor(rnd)





tensor([[-0.7679,  0.3281,  0.8054,  1.1955, -0.6240],
        [ 0.5000, -0.1930, -0.0020, -1.0173, -0.0083],
        [-0.7387,  0.9879,  0.8556,  0.8847,  0.2356],
        [ 0.4543,  1.8703,  1.0567,  0.2428,  1.7751],
        [ 0.5252,  0.1668, -0.9017,  0.5592, -0.6604]])
</code></pre></div></div>

<p>이번에는 <code class="language-plaintext highlighter-rouge">nn.Parameter</code>로 <code class="language-plaintext highlighter-rouge">rnd</code> 변수를 감싼 뒤 결과를 출력했다.당연하게도 두 값은 모두 같다.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>

<span class="n">param</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">rnd</span><span class="p">)</span>
<span class="n">param</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Parameter containing:
tensor([[-0.7679,  0.3281,  0.8054,  1.1955, -0.6240],
        [ 0.5000, -0.1930, -0.0020, -1.0173, -0.0083],
        [-0.7387,  0.9879,  0.8556,  0.8847,  0.2356],
        [ 0.4543,  1.8703,  1.0567,  0.2428,  1.7751],
        [ 0.5252,  0.1668, -0.9017,  0.5592, -0.6604]], requires_grad=True)
</code></pre></div></div>

<p>텐서와 파라미터의 차이는 <code class="language-plaintext highlighter-rouge">.requires_grad</code>가 True인지 아닌지와 <code class="language-plaintext highlighter-rouge">nn.Module</code>에서 자동으로 학습 가능한 파라미터로 포함하는지 아닌지의 차이가 있다.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">print</span><span class="p">(</span><span class="n">t</span><span class="p">.</span><span class="n">requires_grad</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="n">param</span><span class="p">.</span><span class="n">requires_grad</span><span class="p">)</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>False
True
</code></pre></div></div>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">class</span> <span class="nc">Model</span><span class="p">(</span><span class="n">nn</span><span class="p">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">().</span><span class="n">__init__</span><span class="p">()</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">w</span> <span class="o">=</span> <span class="n">param</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">t</span> <span class="o">=</span> <span class="n">t</span>

<span class="n">model</span> <span class="o">=</span> <span class="n">Model</span><span class="p">()</span>
<span class="nb">list</span><span class="p">(</span><span class="n">model</span><span class="p">.</span><span class="n">parameters</span><span class="p">())</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>[Parameter containing:
 tensor([[-0.7679,  0.3281,  0.8054,  1.1955, -0.6240],
         [ 0.5000, -0.1930, -0.0020, -1.0173, -0.0083],
         [-0.7387,  0.9879,  0.8556,  0.8847,  0.2356],
         [ 0.4543,  1.8703,  1.0567,  0.2428,  1.7751],
         [ 0.5252,  0.1668, -0.9017,  0.5592, -0.6604]], requires_grad=True)]
</code></pre></div></div>

<h2 id="linear와-conv2d">Linear와 Conv2d</h2>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">class</span> <span class="nc">Model</span><span class="p">(</span><span class="n">nn</span><span class="p">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">().</span><span class="n">__init__</span><span class="p">()</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">conv</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">128</span><span class="p">,</span> <span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">fc</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">128</span><span class="p">,</span> <span class="mi">10</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">conv</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
        <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="p">.</span><span class="n">view</span><span class="p">(</span><span class="n">x</span><span class="p">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">fc</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">x</span>

<span class="n">model</span> <span class="o">=</span> <span class="n">Model</span><span class="p">()</span>
<span class="k">print</span><span class="p">(</span><span class="n">model</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="nb">type</span><span class="p">(</span><span class="n">model</span><span class="p">.</span><span class="n">conv</span><span class="p">.</span><span class="n">weight</span><span class="p">),</span> <span class="n">model</span><span class="p">.</span><span class="n">conv</span><span class="p">.</span><span class="n">weight</span><span class="p">.</span><span class="n">shape</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="nb">type</span><span class="p">(</span><span class="n">model</span><span class="p">.</span><span class="n">conv</span><span class="p">.</span><span class="n">bias</span><span class="p">),</span> <span class="n">model</span><span class="p">.</span><span class="n">conv</span><span class="p">.</span><span class="n">bias</span><span class="p">.</span><span class="n">shape</span><span class="p">)</span>

<span class="k">print</span><span class="p">(</span><span class="nb">type</span><span class="p">(</span><span class="n">model</span><span class="p">.</span><span class="n">fc</span><span class="p">.</span><span class="n">weight</span><span class="p">),</span> <span class="n">model</span><span class="p">.</span><span class="n">fc</span><span class="p">.</span><span class="n">weight</span><span class="p">.</span><span class="n">shape</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="nb">type</span><span class="p">(</span><span class="n">model</span><span class="p">.</span><span class="n">fc</span><span class="p">.</span><span class="n">bias</span><span class="p">),</span> <span class="n">model</span><span class="p">.</span><span class="n">fc</span><span class="p">.</span><span class="n">bias</span><span class="p">.</span><span class="n">shape</span><span class="p">)</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Model(
  (conv): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1))
  (fc): Linear(in_features=128, out_features=10, bias=True)
)
&lt;class 'torch.nn.parameter.Parameter'&gt; torch.Size([128, 3, 3, 3])
&lt;class 'torch.nn.parameter.Parameter'&gt; torch.Size([128])
&lt;class 'torch.nn.parameter.Parameter'&gt; torch.Size([10, 128])
&lt;class 'torch.nn.parameter.Parameter'&gt; torch.Size([10])
</code></pre></div></div>

<p>결국 <code class="language-plaintext highlighter-rouge">nn.Conv2d</code>의 weight는 학습 가능한 파라미터이고, 이 파라미터가 cross-correlation 연산에 사용된다.</p>

<h2 id="nnconv2d는-사실-convolution을-하지-않는다"><code class="language-plaintext highlighter-rouge">nn.Conv2d</code>는 사실 convolution을 하지 않는다.</h2>
<p>잠시 신호처리의 관점에서 CNN을 바라보자 그러면 CNN이 왜 필터 혹은 커널로 불리는지 이해할 수 있다.</p>

<p>convolution 이란 어떤 함수 $f$와 커널이라는 함수 $g$중 한 함수를 뒤집고(flip) → 슬라이딩하면서 → 원소별 곱의 합(내적)을 하는 것인데, 이를 하게 되면 어떤 $f$와 $g$가 비슷한 방향을 가질 때 값이 커지게 된다. 이게 신호처리에서 필터링의 원리이고 이 때문에 CNN의 파라미터를 커널, 필터 등으로 부르는 것이다.</p>

\[(f * g)(t) = \int f(\tau) \cdot g(t - \tau) \, d\tau\]

<p>근데 torch의 문서에서 볼 수 있듯이 CNN은 사실 convolution을 하지 않는다. cross-correlation과 convolition의 차이는 무엇일까?
우선, cross-correlation의 수식은 다음과 같다.</p>

\[(f \star g)(t) = \int f(\tau) \cdot g(t + \tau) \, d\tau\]

<p>cross-correlation을 convolution 대신 사용하면 (1) 두 함수 중 하나를 굳이 뒤집을 필요가 없어서 연산이 단순해지고, (2) 커널이 입력과 같은 방향으로 슬라이딩하기 때문에 좀 더 직관적이다. CNN은 어차피 커널 값을 학습으로 찾으므로 flip 여부가 결과에 영향을 주지 않아 cross-correlation을 사용해도 무방하다.</p>

<p>아래는 어떤 필터들에 대하여 convolition과 cross-correlation의 결과 차이이다.</p>

<p><img src="/assets/img/code-review/conv2d/fig1.webp" alt="png" /></p>

<p>결과에서 볼 수 있듯이 결과는 크게 변하지 않고 한번 뒤집고 말고의 차이이다. 역사적으로 신호처리에서 convolution을 먼저 사용했고, CNN도 그 이름을 그대로 차용했기 때문에 관습적으로 convolution이라 부른다.</p>]]></content><author><name>YSPARK</name></author><category term="Code-Review" /><category term="PyTorch-basic" /><summary type="html"><![CDATA[nn.Conv2d가 뭐고, 왜 convolution을 하지 않는지 알아본다.]]></summary></entry><entry><title type="html">[코드 리뷰] Dataloader의 동작과 역할</title><link href="https://kitewatermelon.github.io/code-review/dataloader/" rel="alternate" type="text/html" title="[코드 리뷰] Dataloader의 동작과 역할" /><published>2026-05-13T00:00:00+09:00</published><updated>2026-05-13T00:00:00+09:00</updated><id>https://kitewatermelon.github.io/code-review/dataloader</id><content type="html" xml:base="https://kitewatermelon.github.io/code-review/dataloader/"><![CDATA[<!-- code 다운로드: [📥 eval-no_grad-inference_mode.ipynb 다운로드](assets/code/code-review/eval-no_grad-inference_mode.ipynb) -->

<p>PyTorch 라이브러리를 사용하여 딥러닝 코드를 작성하다보면 필수적으로 Dataset과 Dataloader 클래스를 만나게 된다. 보통의 경우 두 클래스의 정확한 동작과 역할을 잘 알지 못하며, 두 클래스의 역할을 혼용하기도 한다. 본 글을 통해 독자들이 Dataset과 Dataloader의 동작과 역할을 정확히 이해할 수 있길 바란다. 특히 본 글은 Dataloader 클래스에 초점을 둔 채로 작성되었으니, 이 부분 참고하여 읽으면 도움이 될 것 같다.</p>

<h2 id="dataset과-dataloader의-차이">Dataset과 Dataloader의 차이</h2>
<p>Dataset과 Dataloader는 PyTorch 데이터 파이프라인의 핵심 두 축이지만, 각자의 역할은 명확히 구분된다.Dataset은 “무엇을 읽을 것인가”를 정의한다. 디스크에서 이미지를 읽거나, 레이블을 매핑하거나, 전처리(transform)를 적용하는 등 개별 샘플 하나를 어떻게 가져올지를 담당한다. <strong>getitem</strong>(index)를 구현하면 dataset[i]와 같이 특정 인덱스의 샘플을 꺼낼 수 있다.Dataloader는 “어떻게 꺼낼 것인가”를 정의한다. Dataset이 만들어둔 샘플을 어떤 순서로, 몇 개씩, 몇 개의 프로세스로 꺼낼지를 담당한다. 즉 배치 구성·셔플·병렬 로딩 등의 로직은 모두 Dataloader의 몫이다.
두 클래스를 혼용하는 흔한 실수는, 배치 처리나 셔플 로직을 Dataset 내부에 직접 구현하는 것이다. Dataset은 단일 샘플 반환에만 집중하고, 나머지는 Dataloader에 위임하는 것이 올바른 설계다.</p>

<p>이제 Dataloader에 대하여 알아보자.</p>

<p><a href="https://docs.pytorch.org/docs/2.11/data.html#module-torch.utils.data">torch.utils.data</a>의 공식 문서는 다음과 같은 문장과 함께 시작된다.</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Eng: 
At the heart of PyTorch data loading utility is the torch.utils.data.DataLoader class. It represents a Python iterable over a dataset, with support for
- map-style and iterable-style datasets,
- customizing data loading order,
- automatic batching,
- single- and multi-process data loading,
- automatic memory pinning.

Kor:
PyTorch 데이터 로딩 유틸리티의 핵심은 torch.utils.data.DataLoader 클래스입니다. 이 클래스는 데이터셋을 순회할 수 있는 Python 이터러블로, 다음 기능들을 지원한다.
- map-style 및 iterable-style 데이터셋
- 데이터 로딩 순서 커스터마이징
- 자동 배치 처리
- 단일/다중 프로세스 데이터 로딩
- 자동 메모리 고정(pinning)
</code></pre></div></div>

<p>이 옵션들은 DataLoader의 생성자 인수를 통해 설정되며, 생성자의 시그니처는 다음과 같다:</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None, *, prefetch_factor=2,
           persistent_workers=False)
</code></pre></div></div>

<h2 id="map-style-and-iterable-style-datasets">map-style and iterable-style datasets</h2>
<p>파이썬의 Iterable(이터러블)은 for 문이나 list(), tuple() 등에서 내부의 요소를 한 번에 하나씩 차례대로 반환(순회)할 수 있는 반복 가능한 객체를 말한다. 이들의 특징으로는 내부적으로 <code class="language-plaintext highlighter-rouge">__iter__()</code> 매서드나 <code class="language-plaintext highlighter-rouge">__getitem__()</code> 메서드를 구현하여 인덱싱이 가능한 객체라는 점이다.</p>

<p>Dataset 클래스는 데이터를 어떻게 읽을지 정의하는 클래스로 아래와 같이 두가지로 나뉜다.</p>

<ul>
  <li>Map-style: <code class="language-plaintext highlighter-rouge">__getitem__(), __len__()</code> 인덱스로 접근하며 기본적인 Dataset 클래스가 이해 해당한다. 인덱스가 꼭 정수일 필요는 없으며, 문자열로도 사용 가능하다.</li>
  <li>Iterable-style: <code class="language-plaintext highlighter-rouge">__iter__()</code> 스트림 방식으로 접근하며, IterableDataset 라는 클래스로 정의한다. 흔히 사용되는 방식은 아니고, 랜덤 접근이 비싸거나 불가능한 경우 (DB 스트림, 원격 서버, 실시간 로그 등)에 사용한다.</li>
</ul>

<p>보통 대부분의 데이터셋은 Map-style Dataset을 사용하므로, 해당 클래스에 집중하여 작성하겠다. Iterable-style Dataset을 다룬다면 공식문서를 읽어보는 것을 추천한다.</p>

<h2 id="customizing-data-loading-order">customizing data loading order</h2>
<p>이 섹션은 “데이터를 어떤 순서로 꺼낼 것인가” 를 결정하는 메커니즘을 설명한다. Map-style은 Sampler가 순서를 결정하게 되는데, shuffle=True로 지정을 하게 되면 내부적으로 RandomSampler를 만들어서 사용하게 된다.</p>

<p>아래 표는 PyTorch에서 지원하는 Sampler 클래스들이니 필요한 곳에 사용하면 된다.</p>

<table>
  <thead>
    <tr>
      <th>Sampler 종류</th>
      <th>동작</th>
      <th>주요 파라미터</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td><code class="language-plaintext highlighter-rouge">SequentialSampler</code></td>
      <td>항상 같은 순서로 순차 샘플링</td>
      <td><code class="language-plaintext highlighter-rouge">data_source</code></td>
    </tr>
    <tr>
      <td><code class="language-plaintext highlighter-rouge">RandomSampler</code></td>
      <td>랜덤 샘플링 (복원/비복원 선택)</td>
      <td><code class="language-plaintext highlighter-rouge">replacement</code>, <code class="language-plaintext highlighter-rouge">num_samples</code>, <code class="language-plaintext highlighter-rouge">generator</code></td>
    </tr>
    <tr>
      <td><code class="language-plaintext highlighter-rouge">SubsetRandomSampler</code></td>
      <td>주어진 인덱스 목록 내에서 랜덤 샘플링 (비복원)</td>
      <td><code class="language-plaintext highlighter-rouge">indices</code>, <code class="language-plaintext highlighter-rouge">generator</code></td>
    </tr>
    <tr>
      <td><code class="language-plaintext highlighter-rouge">WeightedRandomSampler</code></td>
      <td>가중치 확률 기반 샘플링</td>
      <td><code class="language-plaintext highlighter-rouge">weights</code>, <code class="language-plaintext highlighter-rouge">num_samples</code>, <code class="language-plaintext highlighter-rouge">replacement</code></td>
    </tr>
    <tr>
      <td><code class="language-plaintext highlighter-rouge">BatchSampler</code></td>
      <td>다른 Sampler를 감싸서 배치 단위 인덱스 반환</td>
      <td><code class="language-plaintext highlighter-rouge">sampler</code>, <code class="language-plaintext highlighter-rouge">batch_size</code>, <code class="language-plaintext highlighter-rouge">drop_last</code></td>
    </tr>
    <tr>
      <td><code class="language-plaintext highlighter-rouge">DistributedSampler</code></td>
      <td>DDP 환경에서 프로세스별 데이터 구간 분할</td>
      <td><code class="language-plaintext highlighter-rouge">num_replicas</code>, <code class="language-plaintext highlighter-rouge">rank</code>, <code class="language-plaintext highlighter-rouge">shuffle</code>, <code class="language-plaintext highlighter-rouge">seed</code>, <code class="language-plaintext highlighter-rouge">drop_last</code></td>
    </tr>
  </tbody>
</table>

<h2 id="automatic-batching">automatic batching</h2>
<p>DataLoader는 <code class="language-plaintext highlighter-rouge">batch_size</code>, <code class="language-plaintext highlighter-rouge">drop_last</code>, <code class="language-plaintext highlighter-rouge">batch_sampler</code>, <code class="language-plaintext highlighter-rouge">collate_fn</code> 인수를 통해 개별로 가져온 데이터 샘플을 자동으로 배치로 묶는 기능을 지원한다.</p>

<h3 id="automatic-batching-default">Automatic batching (default)</h3>
<p>가장 일반적인 방식으로, 미니배치 단위로 데이터를 가져와 배치 샘플로 묶는다. 이때, <code class="language-plaintext highlighter-rouge">batch_size</code>와 <code class="language-plaintext highlighter-rouge">drop_last</code>는 본질적으로 sampler로부터 batch_sampler를 구성하는 데 사용된다. (<code class="language-plaintext highlighter-rouge">batch_size</code>와 <code class="language-plaintext highlighter-rouge">drop_last</code>는 사용자 편의를 위한 shortcut이고, 실제로는 batch_sampler로 변환되어 동작한다는 뜻)</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">for</span> <span class="n">indices</span> <span class="ow">in</span> <span class="n">batch_sampler</span><span class="p">:</span>
    <span class="k">yield</span> <span class="n">collate_fn</span><span class="p">([</span><span class="n">dataset</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">indices</span><span class="p">])</span>
</code></pre></div></div>
<ol>
  <li>sampler가 인덱스를 하나씩 생성</li>
  <li>batch_sampler가 그 인덱스를 batch_size개씩 묶음(drop_last도 여기서 처리)</li>
  <li>dataset[i]로 각 샘플을 가져옴</li>
  <li>collate_fn이 샘플 리스트를 하나의 배치 텐서로 변환</li>
</ol>

<h3 id="collate_fn-다루기">collate_fn 다루기</h3>
<p>Automatic batching이 활성화되면 <code class="language-plaintext highlighter-rouge">collate_fn</code>은 샘플 목록을 받아 배치로 묶어 반환한다. 기본 collate_fn(default_collate())의 동작은 다음과 같다:</p>

<ol>
  <li>배치 차원 추가: 항상 새로운 첫 번째 차원을 배치 차원으로 추가</li>
  <li>자동 타입 변환: NumPy 배열, Python 수치값 → PyTorch 텐서로 자동 변환</li>
  <li>데이터 구조 보존: dict면 dict, list면 list, tuple이면 tuple 구조를 그대로 유지하되 값은 배치 텐서로</li>
</ol>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1"># 개별 샘플 3개
</span><span class="p">[(</span><span class="n">img1</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span> <span class="p">(</span><span class="n">img2</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span> <span class="p">(</span><span class="n">img3</span><span class="p">,</span> <span class="mi">0</span><span class="p">)]</span>

<span class="c1"># collate_fn 적용 후
</span><span class="p">(</span>
    <span class="n">torch</span><span class="p">.</span><span class="n">stack</span><span class="p">([</span><span class="n">img1</span><span class="p">,</span> <span class="n">img2</span><span class="p">,</span> <span class="n">img3</span><span class="p">],</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">),</span>  <span class="c1"># shape: (3, C, H, W)
</span>    <span class="n">torch</span><span class="p">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">])</span>           <span class="c1"># shape: (3,)
</span><span class="p">)</span>
</code></pre></div></div>

<p>Q1. 데이터셋 크기가 64개이고 batch_size=8일 때, DataLoader를 순회하면 몇 번 iterate되며, 각 iteration에서 반환되는 튜플의 구조는?</p>

<h2 id="single--and-multi-process-data-loading">single- and multi-process data loading</h2>
<p>Python은 GIL 정책으로 인해 스레드간 완전한 병렬화가 불가능하다. 데이터 로딩이 연산 코드를 블로킹하는 것을 방지하기 위해, PyTorch는 <code class="language-plaintext highlighter-rouge">num_workers</code>를 양의 정수로 설정하는 것만으로 다중 프로세스 데이터 로딩으로 간단히 전환할 수 있다.</p>

<p>Map-style 데이터셋의 경우, 메인 프로세스가 sampler로 인덱스를 생성해 워커에 전달한다. 따라서 셔플 랜덤화는 메인 프로세스에서 처리되며, 워커는 할당받은 인덱스에 따라 데이터를 로딩하게 된다.</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>메인 프로세스
  └─ Sampler → [3, 1, 4, 2, 0, ...] 인덱스 생성
       └─ 워커 0 → dataset[3], dataset[1]
       └─ 워커 1 → dataset[4], dataset[2]
       └─ 워커 2 → dataset[0], ...
</code></pre></div></div>
<ul>
  <li>해당 작업은 <code class="language-plaintext highlighter-rouge">multiprocessing</code> 라이브러리에 의존하므로, 윈도우와 UNIX의 동작이 다르다.</li>
  <li>기본적으로 각 워커의 PyTorch 시드는 base_seed + worker_id로 설정된다.</li>
  <li>Q2. 왜 워커마다 서로 다른 난수를 사용할까?</li>
</ul>

<h2 id="automatic-memory-pinning">automatic memory pinning</h2>
<p>Host → GPU 복사는 pinned(page-locked) 메모리에서 시작할 때 훨씬 빠르다.</p>

<p>데이터 로딩 시, DataLoader에 <code class="language-plaintext highlighter-rouge">pin_memory=True</code>를 전달하면 가져온 데이터 텐서를 자동으로 pinned 메모리에 올려 CUDA GPU로의 데이터 전송 속도를 높일 수 있다.
기본 메모리 고정 로직은 텐서, 그리고 텐서를 담은 map/iterable만 인식한다.
텐서나 스토리지를 고정하면 비동기 GPU 복사도 사용할 수 있습니다. to()나 cuda() 호출 시 <code class="language-plaintext highlighter-rouge">non_blocking=True</code> 인수를 추가하면 되며, 이를 통해 데이터 전송과 연산을 오버랩시킬 수 있다.
DataLoader 생성자에 pin_memory=True를 전달하면 DataLoader가 pinned 메모리에 배치를 올려서 반환한다.</p>

<ul>
  <li>Pinned memory는 RAM을 고정 점유하므로 메모리 부족 시 심각한 문제가 발생한다.</li>
  <li>Pinning 자체가 비싼 연산이라 데이터가 작거나 CPU 병목이 없으면 오히려 손해가 될 수 있다.</li>
  <li>다중 프로세스 로딩에서 CUDA 텐서 직접 반환보다 <code class="language-plaintext highlighter-rouge">pin_memory=True</code> 사용이 권장된다.</li>
</ul>

<p>오늘은 PyTorch Dataloader의 핵심 동작 방식을 살펴보았다. Dataset이 개별 샘플을 정의하는 역할이라면, Dataloader는 그 샘플을 어떤 순서로, 몇 개씩, 몇 개의 프로세스로 꺼낼지를 결정한다. Sampler로 순서를 제어하고, collate_fn으로 배치를 구성하며, num_workers와 pin_memory로 로딩 성능을 끌어올리는 구조를 이해하면 데이터 파이프라인 병목을 진단하고 최적화하는 데 큰 도움이 된다.</p>

<p>A1.</p>
<ul>
  <li>iterate 횟수: 8번 (64 / 8)</li>
  <li>각 iteration의 반환값:</li>
</ul>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="p">(</span>
    <span class="n">tensor</span><span class="p">(...),</span>  <span class="c1"># shape: (8, C, H, W)  ← 이미지 배치
</span>    <span class="n">tensor</span><span class="p">(...)</span>   <span class="c1"># shape: (8,)           ← 레이블 배치
</span><span class="p">)</span>
</code></pre></div></div>

<p>튜플 1개 안에 배치 텐서 2개.</p>

<p>A2. 데이터 augmentation 시 다양성을 확보하기 위해, 같은 시드를 쓰면 모든 워커가 동일한 augmentation을 적용하게 되어 배치 내 다양성이 사라진다.</p>]]></content><author><name>YSPARK</name></author><category term="Code-Review" /><category term="PyTorch-basic" /><summary type="html"><![CDATA[Dataloader의 동작과 역할을 알아본다.]]></summary></entry><entry><title type="html">[코드 리뷰] eval() vs no_grad() vs inference_mode()</title><link href="https://kitewatermelon.github.io/code-review/eval-no_grad-inference_mode/" rel="alternate" type="text/html" title="[코드 리뷰] eval() vs no_grad() vs inference_mode()" /><published>2026-04-09T00:00:00+09:00</published><updated>2026-04-10T00:00:00+09:00</updated><id>https://kitewatermelon.github.io/code-review/eval-no_grad-inference_mode</id><content type="html" xml:base="https://kitewatermelon.github.io/code-review/eval-no_grad-inference_mode/"><![CDATA[<p>code 다운로드: <a href="assets/code/code-review/eval-no_grad-inference_mode.ipynb">📥 eval-no_grad-inference_mode.ipynb 다운로드</a></p>

<p>딥러닝 코드를 보다보면 model.eval(), with torch.no_grad() 그리고 with torch.inference_mode()를 많이 보곤한다. 오늘은 이 세 함수의 역할에 대하여 알아본다. 또한 with torch.inference_mode()가 왜</p>

<h2 id="eval">eval()</h2>
<p>nn.Dropout()과 nn.BatchNormNd() 클래스는 대표적으로 train과 eval일때 역할이 다른 함수이다.</p>

<p>nn.Dropout의 <a href="https://docs.pytorch.org/docs/stable/generated/torch.nn.Dropout.html">공식 문서</a>에 따르면</p>
<ul>
  <li>train 시 입력의 일부 요소를 0으로 바꾸고 $\frac{1}{1-p}$로 나머지 값들을 scaling한다.</li>
  <li>eval 시에는 모든 입력을 그대로 사용한다.</li>
  <li>이 방법을 사용하면 eval()시에 스케일링을 따로 하지 않아도 되어 불필요한 오버헤드를 줄이는 데 도와준다.</li>
</ul>

<p>nn.BatchNormNd의 <a href="https://docs.pytorch.org/docs/stable/generated/torch.nn.BatchNorm1d.html">공식 문서</a>에 따르면</p>
<ul>
  <li>train 시 var를 두번 계산하는데, forward 계산용: biased (N으로 나눔, correction=0)과 running_var 저장용: unbiased (N-1로 나눔, correction=1)</li>
  <li>eval 시에는 train때 쌓아둔 running mean/var로 정규화하여 사용한다.</li>
</ul>

<p>위와 같이 train과 eval의 동작이 다른 모듈을 효과적으로 train과 eval로 제어하기 위해 .train()과 .eval()을 사용하게 된다. 아래 예제 코드가 도움이 되길 바란다.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">import</span> <span class="nn">torch.nn</span> <span class="k">as</span> <span class="n">nn</span>

<span class="c1"># ============================================================
# 1. model.eval() — BN/Dropout 동작 모드 전환 (gradient와 무관)
# ============================================================
</span>
<span class="k">print</span><span class="p">(</span><span class="s">"="</span> <span class="o">*</span> <span class="mi">60</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="s">"1. model.eval()은 BN/Dropout 모드만 바꾼다"</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="s">"="</span> <span class="o">*</span> <span class="mi">60</span><span class="p">)</span>

<span class="c1"># --- Dropout ---
</span><span class="n">dropout</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Dropout</span><span class="p">(</span><span class="mf">0.8</span><span class="p">)</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">ones</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">8</span><span class="p">)</span>

<span class="n">dropout</span><span class="p">.</span><span class="n">train</span><span class="p">()</span>
<span class="k">print</span><span class="p">(</span><span class="s">"train:"</span><span class="p">,</span> <span class="n">dropout</span><span class="p">(</span><span class="n">x</span><span class="p">))</span>

<span class="n">dropout</span><span class="p">.</span><span class="nb">eval</span><span class="p">()</span>
<span class="k">print</span><span class="p">(</span><span class="s">"eval: "</span><span class="p">,</span> <span class="n">dropout</span><span class="p">(</span><span class="n">x</span><span class="p">))</span>

<span class="c1"># --- BatchNorm ---
</span><span class="n">bn</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">BatchNorm1d</span><span class="p">(</span><span class="mi">4</span><span class="p">)</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">5</span><span class="p">):</span>
    <span class="n">data</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">randn</span><span class="p">(</span><span class="mi">8</span><span class="p">,</span> <span class="mi">4</span><span class="p">)</span> <span class="o">+</span> <span class="n">i</span>
    <span class="n">bn</span><span class="p">(</span><span class="n">data</span><span class="p">)</span>
    <span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"  step </span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s">: running_mean = </span><span class="si">{</span><span class="n">bn</span><span class="p">.</span><span class="n">running_mean</span><span class="p">.</span><span class="n">tolist</span><span class="p">()</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>

<span class="n">bn</span><span class="p">.</span><span class="nb">eval</span><span class="p">()</span>
<span class="k">print</span><span class="p">(</span><span class="s">"eval output:"</span><span class="p">,</span> <span class="n">bn</span><span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="n">zeros</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">4</span><span class="p">)))</span>

<span class="c1"># --- eval이어도 gradient는 살아있다 ---
</span><span class="n">model</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">4</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
<span class="n">model</span><span class="p">.</span><span class="nb">eval</span><span class="p">()</span>
<span class="n">inp</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">randn</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="n">requires_grad</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
<span class="n">out</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">inp</span><span class="p">)</span>
<span class="n">out</span><span class="p">.</span><span class="nb">sum</span><span class="p">().</span><span class="n">backward</span><span class="p">()</span>  <span class="c1"># 정상 동작!
</span><span class="k">print</span><span class="p">(</span><span class="s">"eval 모드에서 backward:"</span><span class="p">,</span> <span class="n">inp</span><span class="p">.</span><span class="n">grad</span> <span class="ow">is</span> <span class="ow">not</span> <span class="bp">None</span><span class="p">)</span>  <span class="c1"># True
</span></code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>============================================================
1. model.eval()은 BN/Dropout 모드만 바꾼다
============================================================
train: tensor([[5., 0., 0., 5., 0., 0., 5., 0.]])
eval:  tensor([[1., 1., 1., 1., 1., 1., 1., 1.]])
  step 0: running_mean = [0.02741658128798008, 0.01935068890452385, -0.012505004182457924, 0.10418272018432617]
  step 1: running_mean = [0.10060923546552658, 0.13381418585777283, 0.021229349076747894, 0.23009954392910004]
  step 2: running_mean = [0.3041425943374634, 0.3139670193195343, 0.279056578874588, 0.3843126893043518]
  step 3: running_mean = [0.5296856760978699, 0.5720039010047913, 0.5548611283302307, 0.650330126285553]
  step 4: running_mean = [0.8712624907493591, 0.9221630096435547, 0.8863516449928284, 0.9963715672492981]
eval output: tensor([[-0.9029, -0.9566, -0.9251, -1.0156],
        [-0.9029, -0.9566, -0.9251, -1.0156]],
       grad_fn=&lt;NativeBatchNormBackward0&gt;)
eval 모드에서 backward: True
</code></pre></div></div>

<h2 id="torchno_grad">torch.no_grad()</h2>
<p>torch.no_grad()의 <a href="https://docs.pytorch.org/docs/stable/generated/torch.no_grad.html">공식 문서</a>에 따르면 no_grad()는 gradient calculation을 비활성화하는 Context-manager이다. 이를 통해 memory consumption을 줄이기 때문에 추론 단계에서 사용하는 것이 권장된다.</p>

<p>따라서 .eval()로 각 모듈을 eval 모드로 전환하고, with torch.no_grad() 안에서 추론을 함으로 계산 효율을 높혀준다.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1"># ============================================================
# 2. torch.no_grad() — grad_fn 생략, version tracking은 유지
# ============================================================
</span>
<span class="k">print</span><span class="p">(</span><span class="s">"="</span> <span class="o">*</span> <span class="mi">60</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="s">"2. torch.no_grad()"</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="s">"="</span> <span class="o">*</span> <span class="mi">60</span><span class="p">)</span>

<span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">randn</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="n">requires_grad</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>

<span class="k">with</span> <span class="n">torch</span><span class="p">.</span><span class="n">no_grad</span><span class="p">():</span>
    <span class="n">a</span> <span class="o">=</span> <span class="n">x</span> <span class="o">*</span> <span class="mi">2</span>

<span class="k">print</span><span class="p">(</span><span class="s">"grad_fn:"</span><span class="p">,</span> <span class="n">a</span><span class="p">.</span><span class="n">grad_fn</span><span class="p">)</span>           <span class="c1"># None (autograd 그래프 안 만듦)
</span><span class="k">print</span><span class="p">(</span><span class="s">"requires_grad:"</span><span class="p">,</span> <span class="n">a</span><span class="p">.</span><span class="n">requires_grad</span><span class="p">)</span> <span class="c1"># False
</span><span class="k">print</span><span class="p">(</span><span class="s">"_version:"</span><span class="p">,</span> <span class="n">a</span><span class="p">.</span><span class="n">_version</span><span class="p">)</span>           <span class="c1"># 0 (version tracking은 살아있음)
</span>
<span class="c1"># 블록 밖에서 다시 autograd 텐서와 연산하면 grad 붙음
</span><span class="n">b</span> <span class="o">=</span> <span class="n">a</span> <span class="o">+</span> <span class="n">x</span>
<span class="k">print</span><span class="p">(</span><span class="s">"밖에서 a+x requires_grad:"</span><span class="p">,</span> <span class="n">b</span><span class="p">.</span><span class="n">requires_grad</span><span class="p">)</span>  <span class="c1"># True
</span></code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>============================================================
2. torch.no_grad()
============================================================
grad_fn: None
requires_grad: False
_version: 0
밖에서 a+x requires_grad: True
</code></pre></div></div>

<h2 id="inference_mode">inference_mode()</h2>
<p>torch.inference_mode()는 no_grad()와 비슷한데, 추가적인 오버헤드를 줄여준다. <a href="https://docs.pytorch.org/docs/stable/generated/torch.autograd.grad_mode.inference_mode.html">공식 문서</a>에서 말하는 추가적인 오버헤드란 view tracking과 version counter bumps을 비활성화 하는 것이다.</p>

<p><a href="https://docs.pytorch.org/serve/performance_checklist.html">이 글도</a> 한번 쯤 읽는 것이 좋아보인다.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1"># ============================================================
# 3. torch.inference_mode() — version tracking까지 제거
# ============================================================
</span>
<span class="k">print</span><span class="p">(</span><span class="s">"="</span> <span class="o">*</span> <span class="mi">60</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="s">"3. torch.inference_mode()"</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="s">"="</span> <span class="o">*</span> <span class="mi">60</span><span class="p">)</span>

<span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">randn</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="n">requires_grad</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>

<span class="k">with</span> <span class="n">torch</span><span class="p">.</span><span class="n">inference_mode</span><span class="p">():</span>
    <span class="n">c</span> <span class="o">=</span> <span class="n">x</span> <span class="o">*</span> <span class="mi">2</span>

<span class="k">print</span><span class="p">(</span><span class="s">"grad_fn:"</span><span class="p">,</span> <span class="n">c</span><span class="p">.</span><span class="n">grad_fn</span><span class="p">)</span>             <span class="c1"># None
</span><span class="k">print</span><span class="p">(</span><span class="s">"requires_grad:"</span><span class="p">,</span> <span class="n">c</span><span class="p">.</span><span class="n">requires_grad</span><span class="p">)</span>   <span class="c1"># False
</span><span class="k">print</span><span class="p">(</span><span class="s">"is_inference:"</span><span class="p">,</span> <span class="n">c</span><span class="p">.</span><span class="n">is_inference</span><span class="p">())</span>   <span class="c1"># True
</span>
<span class="k">try</span><span class="p">:</span>
    <span class="k">print</span><span class="p">(</span><span class="n">c</span><span class="p">.</span><span class="n">_version</span><span class="p">)</span>
<span class="k">except</span> <span class="nb">RuntimeError</span> <span class="k">as</span> <span class="n">e</span><span class="p">:</span>
    <span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"_version 접근 에러: </span><span class="si">{</span><span class="n">e</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>
    <span class="c1"># "Inference tensors do not track version counter."
</span></code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>============================================================
3. torch.inference_mode()
============================================================
grad_fn: None
requires_grad: False
is_inference: True
_version 접근 에러: Inference tensors do not track version counter.
</code></pre></div></div>

<h2 id="version-counter">version counter?</h2>
<p>version counter의 역할은 다음과 같다. tensor의 내부 구현으로 tensor._version으로 read-only로 접근 가능하다. <a href="https://discuss.pytorch.org/t/how-to-get-the-version-numbers-of-a-modules-parameters/90726/2">참고</a></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1"># ============================================================
# 4. version counter가 왜 필요한지
# ============================================================
</span>
<span class="k">print</span><span class="p">(</span><span class="s">"="</span> <span class="o">*</span> <span class="mi">60</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="s">"4. version counter의 역할"</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="s">"="</span> <span class="o">*</span> <span class="mi">60</span><span class="p">)</span>

<span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">randn</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="n">requires_grad</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">x</span> <span class="o">*</span> <span class="mi">2</span>
<span class="n">z</span> <span class="o">=</span> <span class="n">y</span> <span class="o">**</span> <span class="mi">2</span>   <span class="c1"># dz/dy = 2y → backward 때 y의 값이 필요
</span><span class="k">print</span><span class="p">(</span><span class="s">"y._version:"</span><span class="p">,</span> <span class="n">y</span><span class="p">.</span><span class="n">_version</span><span class="p">)</span>  <span class="c1"># 0
</span>
<span class="n">y</span><span class="p">.</span><span class="n">mul_</span><span class="p">(</span><span class="mi">2</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="s">"y._version:"</span><span class="p">,</span> <span class="n">y</span><span class="p">.</span><span class="n">_version</span><span class="p">)</span>  <span class="c1"># 1
</span>
<span class="k">try</span><span class="p">:</span>
    <span class="n">z</span><span class="p">.</span><span class="n">backward</span><span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="n">ones</span><span class="p">(</span><span class="mi">3</span><span class="p">))</span>
<span class="k">except</span> <span class="nb">RuntimeError</span> <span class="k">as</span> <span class="n">e</span><span class="p">:</span>
    <span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"backward 에러: </span><span class="si">{</span><span class="n">e</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>============================================================
4. version counter의 역할
============================================================
y._version: 0
y._version: 1
backward 에러: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [3]], which is output 0 of MulBackward0, is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True, check_nan=False).
</code></pre></div></div>]]></content><author><name>YSPARK</name></author><category term="Code-Review" /><category term="PyTorch-basic" /><summary type="html"><![CDATA[eval() vs no_grad() vs inference_mode()의 차이를 알아본다.]]></summary></entry><entry><title type="html">[코드 리뷰] I-JEPA-(1) overall train code</title><link href="https://kitewatermelon.github.io/code-review/ijepa-1/" rel="alternate" type="text/html" title="[코드 리뷰] I-JEPA-(1) overall train code" /><published>2026-03-20T00:00:00+09:00</published><updated>2026-03-20T00:00:00+09:00</updated><id>https://kitewatermelon.github.io/code-review/ijepa-1</id><content type="html" xml:base="https://kitewatermelon.github.io/code-review/ijepa-1/"><![CDATA[<p>실습 예제: <a href="https://github.com/facebookresearch/ijepa/tree/main">I-JEPA official repository</a>
논문 리뷰: <a href="/paper-review/ijepa/">I-JEPA 논문 리뷰</a></p>

<h3 id="0-i-jepa">0. I-JEPA</h3>

<p>I-JEPA는 the target block representations를 single context block으로 부터 predict 하는 것을 목표로 하는 self-supervised learning architecture이다.</p>

<p>따라서 target block representation을 만들기 위한 target encoder와 single context block을 만들기 위한 context encoder, 그리고 predict를 하기 위한 predictor가 필요하다.</p>

<p>이번 코드 리뷰는 공식 repo의 src/train.py 아래 5가지로 나뉘며 필요한 부분을 주석으로 설명한다.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">Epoch</span> <span class="n">loop</span>    
  <span class="err">└─</span> <span class="n">Iteration</span> <span class="n">loop</span> <span class="p">(</span><span class="n">배치마다</span><span class="p">)</span>   
    <span class="err">├─</span> <span class="mf">1.</span> <span class="n">데이터</span> <span class="n">로드</span> <span class="p">(</span><span class="n">이미지</span> <span class="o">+</span> <span class="n">마스크</span><span class="p">)</span>   
    <span class="err">├─</span> <span class="mf">2.</span> <span class="n">Forward</span> <span class="p">(</span><span class="n">Target</span> <span class="o">/</span> <span class="n">Context</span><span class="p">)</span>   
    <span class="err">├─</span> <span class="mf">3.</span> <span class="n">Loss</span> <span class="n">계산</span>   
    <span class="err">├─</span> <span class="mf">4.</span> <span class="n">Backward</span> <span class="o">&amp;</span> <span class="n">Optimizer</span> <span class="n">step</span>   
    <span class="err">└─</span> <span class="mf">5.</span> <span class="n">Target</span> <span class="n">encoder</span> <span class="n">momentum</span> <span class="n">update</span> <span class="p">(</span><span class="n">EMA</span><span class="p">)</span>   
</code></pre></div></div>

<h3 id="1-데이터-로드-이미지--마스크">1. 데이터 로드 (이미지 + 마스크)</h3>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1"># -- TRAINING LOOP
</span><span class="k">for</span> <span class="n">epoch</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">start_epoch</span><span class="p">,</span> <span class="n">num_epochs</span><span class="p">):</span>
    <span class="n">logger</span><span class="p">.</span><span class="n">info</span><span class="p">(</span><span class="s">'Epoch %d'</span> <span class="o">%</span> <span class="p">(</span><span class="n">epoch</span> <span class="o">+</span> <span class="mi">1</span><span class="p">))</span>

    <span class="c1"># -- update distributed-data-loader epoch
</span>    <span class="n">unsupervised_sampler</span><span class="p">.</span><span class="n">set_epoch</span><span class="p">(</span><span class="n">epoch</span><span class="p">)</span>

    <span class="n">loss_meter</span> <span class="o">=</span> <span class="n">AverageMeter</span><span class="p">()</span>
    <span class="n">maskA_meter</span> <span class="o">=</span> <span class="n">AverageMeter</span><span class="p">()</span>
    <span class="n">maskB_meter</span> <span class="o">=</span> <span class="n">AverageMeter</span><span class="p">()</span>
    <span class="n">time_meter</span> <span class="o">=</span> <span class="n">AverageMeter</span><span class="p">()</span>

    <span class="k">for</span> <span class="n">itr</span><span class="p">,</span> <span class="p">(</span><span class="n">udata</span><span class="p">,</span> <span class="n">masks_enc</span><span class="p">,</span> <span class="n">masks_pred</span><span class="p">)</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">unsupervised_loader</span><span class="p">):</span>

        <span class="k">def</span> <span class="nf">load_imgs</span><span class="p">():</span>
            <span class="c1"># -- unsupervised imgs
</span>            <span class="n">imgs</span> <span class="o">=</span> <span class="n">udata</span><span class="p">[</span><span class="mi">0</span><span class="p">].</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">,</span> <span class="n">non_blocking</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
            <span class="n">masks_1</span> <span class="o">=</span> <span class="p">[</span><span class="n">u</span><span class="p">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">,</span> <span class="n">non_blocking</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span> <span class="k">for</span> <span class="n">u</span> <span class="ow">in</span> <span class="n">masks_enc</span><span class="p">]</span>
            <span class="n">masks_2</span> <span class="o">=</span> <span class="p">[</span><span class="n">u</span><span class="p">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">,</span> <span class="n">non_blocking</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span> <span class="k">for</span> <span class="n">u</span> <span class="ow">in</span> <span class="n">masks_pred</span><span class="p">]</span>
            <span class="k">return</span> <span class="p">(</span><span class="n">imgs</span><span class="p">,</span> <span class="n">masks_1</span><span class="p">,</span> <span class="n">masks_2</span><span class="p">)</span>

        <span class="c1"># 원본 이미지, context 위치 정보, target 위치 정보
</span>        <span class="n">imgs</span><span class="p">,</span> <span class="n">masks_enc</span><span class="p">,</span> <span class="n">masks_pred</span> <span class="o">=</span> <span class="n">load_imgs</span><span class="p">()</span> 
        <span class="n">maskA_meter</span><span class="p">.</span><span class="n">update</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">masks_enc</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="mi">0</span><span class="p">]))</span>
        <span class="n">maskB_meter</span><span class="p">.</span><span class="n">update</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">masks_pred</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="mi">0</span><span class="p">]))</span>

        <span class="k">def</span> <span class="nf">train_step</span><span class="p">():</span>
            <span class="n">_new_lr</span> <span class="o">=</span> <span class="n">scheduler</span><span class="p">.</span><span class="n">step</span><span class="p">()</span>
            <span class="n">_new_wd</span> <span class="o">=</span> <span class="n">wd_scheduler</span><span class="p">.</span><span class="n">step</span><span class="p">()</span>
</code></pre></div></div>

<h3 id="2-forward-target--context">2. Forward (Target / Context)</h3>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>            <span class="c1"># --
</span>            <span class="k">def</span> <span class="nf">forward_target</span><span class="p">():</span>
                <span class="c1"># target encoder는 EMA based optimization이라 grad 계산 필요 없음.
</span>                <span class="k">with</span> <span class="n">torch</span><span class="p">.</span><span class="n">no_grad</span><span class="p">():</span>
                    <span class="n">h</span> <span class="o">=</span> <span class="n">target_encoder</span><span class="p">(</span><span class="n">imgs</span><span class="p">)</span>
                    <span class="c1"># normalize over feature-dim
</span>                    <span class="c1"># 각 토큰 벡터를 독립적으로 평균 0, 분산 1로 만들어 예측 타깃의 스케일을 고정
</span>                    <span class="n">h</span> <span class="o">=</span> <span class="n">F</span><span class="p">.</span><span class="n">layer_norm</span><span class="p">(</span><span class="n">h</span><span class="p">,</span> <span class="p">(</span><span class="n">h</span><span class="p">.</span><span class="n">size</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">),))</span> 
                    <span class="n">B</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">h</span><span class="p">)</span>
                    <span class="c1"># -- create targets (masked regions of h)
</span>                    <span class="c1"># target embedding에서 예측 대상 위치(masks_pred)만 추출
</span>                    <span class="n">h</span> <span class="o">=</span> <span class="n">apply_masks</span><span class="p">(</span><span class="n">h</span><span class="p">,</span> <span class="n">masks_pred</span><span class="p">)</span>
                    
                    <span class="c1"># context block 수(len(masks_enc))만큼 target을 복제하여 shape 맞춤.
</span>                    <span class="c1"># 현재 공식 구현에서는 len(masks_enc)==1이지만,
</span>                    <span class="c1"># multi-context 확장을 고려한 일반화 코드임.
</span>                    <span class="n">h</span> <span class="o">=</span> <span class="n">repeat_interleave_batch</span><span class="p">(</span>
                        <span class="n">h</span><span class="p">,</span> <span class="n">B</span><span class="p">,</span> 
                        <span class="n">repeat</span><span class="o">=</span><span class="nb">len</span><span class="p">(</span><span class="n">masks_enc</span><span class="p">)</span>
                      <span class="p">)</span>
                    <span class="k">return</span> <span class="n">h</span>

            <span class="k">def</span> <span class="nf">forward_context</span><span class="p">():</span>
                <span class="c1"># 인코딩 후
</span>                <span class="n">z</span> <span class="o">=</span> <span class="n">encoder</span><span class="p">(</span><span class="n">imgs</span><span class="p">,</span> <span class="n">masks_enc</span><span class="p">)</span>
                <span class="c1"># context embedding, context 위치 정보, target 위치 정보를 이용하여 예측
</span>                <span class="n">z</span> <span class="o">=</span> <span class="n">predictor</span><span class="p">(</span><span class="n">z</span><span class="p">,</span> <span class="n">masks_enc</span><span class="p">,</span> <span class="n">masks_pred</span><span class="p">)</span>
                <span class="k">return</span> <span class="n">z</span>
</code></pre></div></div>

<h3 id="3-loss-계산">3. Loss 계산</h3>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>            <span class="k">def</span> <span class="nf">loss_fn</span><span class="p">(</span><span class="n">z</span><span class="p">,</span> <span class="n">h</span><span class="p">):</span>
                <span class="c1"># 논문은 L2이나 실제 구현으로는 L1
</span>                <span class="c1"># gradient 안정성을 위한 것으로 예측됨. 
</span>                <span class="n">loss</span> <span class="o">=</span> <span class="n">F</span><span class="p">.</span><span class="n">smooth_l1_loss</span><span class="p">(</span><span class="n">z</span><span class="p">,</span> <span class="n">h</span><span class="p">)</span>
                <span class="n">loss</span> <span class="o">=</span> <span class="n">AllReduce</span><span class="p">.</span><span class="nb">apply</span><span class="p">(</span><span class="n">loss</span><span class="p">)</span>
                <span class="k">return</span> <span class="n">loss</span>
</code></pre></div></div>

<h3 id="4-backward--optimizer-step">4. Backward &amp; Optimizer step</h3>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>            <span class="c1"># Step 1. Forward
</span>            <span class="k">with</span> <span class="n">torch</span><span class="p">.</span><span class="n">cuda</span><span class="p">.</span><span class="n">amp</span><span class="p">.</span><span class="n">autocast</span><span class="p">(</span>
                <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="p">.</span><span class="n">bfloat16</span><span class="p">,</span> 
                <span class="n">enabled</span><span class="o">=</span><span class="n">use_bfloat16</span>
              <span class="p">):</span>
                <span class="n">h</span> <span class="o">=</span> <span class="n">forward_target</span><span class="p">()</span>
                <span class="n">z</span> <span class="o">=</span> <span class="n">forward_context</span><span class="p">()</span>
                <span class="n">loss</span> <span class="o">=</span> <span class="n">loss_fn</span><span class="p">(</span><span class="n">z</span><span class="p">,</span> <span class="n">h</span><span class="p">)</span>

            <span class="c1">#  Step 2. Backward &amp; step (context encoder와 prediction encoder만 gradient based optimization)
</span>            <span class="k">if</span> <span class="n">use_bfloat16</span><span class="p">:</span>
                <span class="n">scaler</span><span class="p">.</span><span class="n">scale</span><span class="p">(</span><span class="n">loss</span><span class="p">).</span><span class="n">backward</span><span class="p">()</span>
                <span class="n">scaler</span><span class="p">.</span><span class="n">step</span><span class="p">(</span><span class="n">optimizer</span><span class="p">)</span>
                <span class="n">scaler</span><span class="p">.</span><span class="n">update</span><span class="p">()</span>
            <span class="k">else</span><span class="p">:</span>
                <span class="n">loss</span><span class="p">.</span><span class="n">backward</span><span class="p">()</span>
                <span class="n">optimizer</span><span class="p">.</span><span class="n">step</span><span class="p">()</span>
            <span class="n">grad_stats</span> <span class="o">=</span> <span class="n">grad_logger</span><span class="p">(</span>
                <span class="n">encoder</span><span class="p">.</span><span class="n">named_parameters</span><span class="p">()</span>
              <span class="p">)</span>
            <span class="n">optimizer</span><span class="p">.</span><span class="n">zero_grad</span><span class="p">()</span>
</code></pre></div></div>

<h3 id="5-target-encoder-momentum-update-ema">5. Target encoder momentum update (EMA)</h3>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>            <span class="c1"># Step 3. momentum update of target encoder는 EMA based optimization
</span>            <span class="k">with</span> <span class="n">torch</span><span class="p">.</span><span class="n">no_grad</span><span class="p">():</span>
                <span class="n">m</span> <span class="o">=</span> <span class="nb">next</span><span class="p">(</span><span class="n">momentum_scheduler</span><span class="p">)</span>
                <span class="k">for</span> <span class="n">param_q</span><span class="p">,</span> <span class="n">param_k</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span>
                    <span class="n">encoder</span><span class="p">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="c1"># 파라미터를 들고 와서
</span>                    <span class="n">target_encoder</span><span class="p">.</span><span class="n">parameters</span><span class="p">()</span>
                  <span class="p">):</span>
                    <span class="c1"># θk​←m⋅θk​+(1−m)⋅θq​
</span>                    <span class="n">param_k</span><span class="p">.</span><span class="n">data</span><span class="p">.</span><span class="n">mul_</span><span class="p">(</span><span class="n">m</span><span class="p">).</span><span class="n">add_</span><span class="p">(</span>
                        <span class="p">(</span><span class="mf">1.</span><span class="o">-</span><span class="n">m</span><span class="p">)</span> <span class="o">*</span> <span class="n">param_q</span><span class="p">.</span><span class="n">detach</span><span class="p">().</span><span class="n">data</span> 
                      <span class="p">)</span>

            <span class="k">return</span> <span class="p">(</span><span class="nb">float</span><span class="p">(</span><span class="n">loss</span><span class="p">),</span> <span class="n">_new_lr</span><span class="p">,</span> <span class="n">_new_wd</span><span class="p">,</span> <span class="n">grad_stats</span><span class="p">)</span>
</code></pre></div></div>]]></content><author><name>YSPARK</name></author><category term="Code-Review" /><category term="JEPA" /><summary type="html"><![CDATA[I-JEPA의 전반적인 학습 코드를 본다.]]></summary></entry><entry><title type="html">[논문리뷰] Brain Network Transformer</title><link href="https://kitewatermelon.github.io/paper-review/bnt/" rel="alternate" type="text/html" title="[논문리뷰] Brain Network Transformer" /><published>2026-03-16T00:00:00+09:00</published><updated>2026-03-16T00:00:00+09:00</updated><id>https://kitewatermelon.github.io/paper-review/bnt</id><content type="html" xml:base="https://kitewatermelon.github.io/paper-review/bnt/"><![CDATA[<blockquote>
  <p>NeurIPS 2022 [<a href="https://arxiv.org/pdf/2210.06681">Paper</a>] [<a href="https://github.com/Wayfear/BrainNetworkTransformer">GitHub</a>]<br />
 Xuan Kan, Wei Dai, Hejie Cui, Zilong Zhang, Ying Guo, Carl Yang
 15 Oct 2022</p>
</blockquote>

<h2 id="1-introduction">1. Introduction</h2>
<p>Brain network analysis는 신경 과학자들에게 사람의 뇌 구조 이해와 임상 결과 예측을 위한 흥미로운 연구이다. 다양한 모달 중 fMRI가 뇌 네트워크 구조를 위해 주로 사용된다. fMRI의 노드는 atlas 기반의 ROIs로 정의되고 각 노드간 BOLD 신호의 상관관계로 엣지를 정의한다. 연구자들은 action, language, and vision 같은 cognitive-related tasks가 일어날때 특정 영역이 동시에 활성화되거나 비활성화 되는 것을 관측해왔다. 이런 패턴을 바탕으로 brain regions을 다양한 기능적 모듈로 분류하여 질병을 분석하여 진단, 진행 이해 및 치료에 사용할 수 있다.</p>

<p>Transfer의 여러 분야에서의 성공을 해왔고 그 중 GAT라는 모델이 처음으로 GNNs의 영역에 적용되었다. 그러나 이는 이웃 노드의 local 구조만 고려하였다. Graph Transformer는 edge information을 attention mechanism에 집어넣었고 각 노드의 eigenvectors를 position embedding(PE)로 활용하였다. SAN은 eigenvalue와 eigenvectors를 동시에 고려하여 PE를 더 강화하고 attention을 local 구조에서 global 구조로 확장하였다. Graphomer는 독창적인 메커니즘으로 OGB Large-Scale Challenge에서 우승하였다.</p>

<p>그러나 brain networks는 기존 Graph Transformer 모델이 실용적이지 않은 독특한 특성이 여럿 있다.</p>

<ol>
  <li>주로 사용되는 방법이 ROIs 간의 BOLD 신호의 correlation이다. 이는 centrality, spatial 그리고 edge encoding 같은 디자인을 방해한다. 왜냐하면 각 노드들은 같은 차수를 가지고 모든 노드간 single hop으로 연결되어 있기 때문이다.</li>
  <li>기존 Graph transformer models는 eigenvalue와 eigenvectors를 주로 PE에 사용한다. 왜냐하면 그들은 각 노드의 identity와 positional information을 제공하기 때문이다. 그러나 brain network에서는, brain network adjacency matrix의 각 노드의 해당 행으로 정의되는 connection profile이 가장 효과적인 node feature로 인식된다. 이 node feature는 구조적, 위치적 정보 둘 다 자연스럽게 인코딩한다. 이는 앞서 설명한 PE 디자인이 중복으로 여겨 진다.</li>
  <li>Scalability도 중요하다. 보통은 노드와 엣지의 수가 50과 2500개 미만인데, brain network는 atlas에 따라 100~400 개의 노드를 가지고 이는 최대 160k개의 엣지가 생긴다. 그러므로 현존하는 Graph transformer model으로 모든 엣지의 기능 생성과 같은 작업이 불가능하지 않더라도 시간이 많이 걸릴 수 있다.</li>
</ol>

<p>본 논문에서는 brain network analysis를 위해 brain network의 독특한 특성을 transformer-based moedl의 힘으로 완전히 해방하는 BRAIN NETWORK TRANSFORMER (BRAINNETTF, 다른 논무에서는 주로 BNT로 불림)을 제안한다. 특히, 기존 GNN에서의 발견을 통해 connection profiles의 초기 node feature를 효과적으로 초기화하는 방법을 제안한다.</p>

<p>한 단계 나아가서 brain network analysis에 GNNs를 사용할 때는 학습된 node embedding을 기반으로 readout function을 통해 graph-level embedding을 생성해야 한다. brain network의 특성상 동일한 기능적 모듈에 속하는 노드들은 다양한 자극에 대한 활성화 및 비활성화 반응에서 유사한 행동 양식을 공유하는 경우가 많다. 이를 위해 ORTHONORMAL CLUSTERING READOUT을 설계하여 노드들의 cluster에서 graph-level embedding을 pooling한다.</p>

<p>마지막으로 공개 데이터셋의 부족은 brain network analysis에서 무시할 수 없는 과제이다. 예를 들어 ABIDE는 별도의 접근 허가 없이 fMRI를 완전히 사용할 수 있지만 17개의 기관에서 서로 다른 스캐너와 파라미터를 사용하여 획득된다. 이러한 inter-site variability는 실질적으로 의미 있는 집단 간 차이를 가려버리며 이로 인해 학습시 불안정성 증가 및 검증/테스트 셋과의 유의미한 격차로 나타난다. 이를 해결하기 위해 stratified sampling을 제안하고 표준화 할 것을 제안한다.</p>

<h2 id="2-background-and-related-work">2 Background and Related Work</h2>
<p>해당 섹션에서는 배경과 관련 연구를 설명한다. 본문에서는 다루지 않는다.</p>

<h2 id="3-brain-network-transformer">3 BRAIN NETWORK TRANSFORMER</h2>
<h3 id="31-problem-definition">3.1 Problem Definition</h3>
<p>brain network analysis에서 brain network: $X \in \mathbb{R}^{V \times V} $ where $V$는 node (ROIs)의 수 이며 주로 성, 질병의 유무, brain subject의 특징을 예측하는 것이 모델의 주 목표이다. BNT의 전반적인 프레임워크는 아래 그림과 같다. $L$개의 MHSA layer와 graph pooling operator OCREAD로 두개의 main component로 이루어져있다.</p>

<center>
<img src="/assets/img/paper-review/bnt/fig2.webp" width="80%" />
</center>
<p><br /></p>

<p>MHSA에서 non-linear mapping $X \rarr Z^L \in \mathbb{R}^{V \times V}$을 통해 attention-enhanced node features $Z^L$을 학습한다. OCREAD은 enhanced node embeddings $Z^L$을 graph-level embeddings $Z_G \in \mathbb{R}^{K \times V}$ 로 압축한다. $K$는 hyperparam인 number of clusters이다. $Z_G$는 flatten 후 MLP를 통과하여 graph-level prediction을 한다. 모든 학습 과정은 CE를 통한 supervised learning을 통해 이루어진다.</p>

<h3 id="32-multi-head-self-attention-module-mhsa">3.2 Multi-Head Self-Attention Module (MHSA)</h3>
<p>brain network에 적합한 Transformer-based model을 개발하기 위해 PE와 attention mechanism이라는 두가지 기초 디자인에 대한 재고가 필요하다. 현존하는 모델은 eigendecomposition을 통해 주로 위치 정보를 인코딩하지만, 이는 dense한 brain network에서 비용이 많이 들고 edge의 존재가 유익하지 않다.</p>

<p>brain networks의 ROI 노드는 이미 필요한 위치 정보를 가지고 있으므로 eigendecomposition을 통한 position encoding은 중복이다. 기존 연구에서 connection profile $X_i$을 통한 분석이 항상 eigenvectors를 이용한 방법보다 좋은 성능을 보였다. 또한 이전 연구에서 edge weight를 attention score에 통합하면 완전 그래프에서 attention의 효과를 크게 저하시킬 수 있음을 경험적으로 입증하였다. 또한 brain network는 edge가 매우 많기 때문에 edge-wise embedding을 생성하는 것이 계산적으로 감당하기 어렵다. 또한 이 케이스에서 모든 엣지가 단순히 존재하기 때문에 존재 여부 자체도 attention score 계산에 유용한 정보를 제공하지 않는다.</p>

<p>이러한 관점에서 BNT는</p>
<ol>
  <li>connection profile을 초기 node feature로 삼고 PE를 없앤다.</li>
  <li>edge weight나 상대 위치 정보를를 사용하지 않는 바닐라 pair-wise mechanism을 사용한다.</li>
</ol>

<p>$Z^L = \text{MHSA}(X) \in \mathbb{R}^{V \times V}$를 생성하는 $L$-Layer non-linear mapping module인 MHSA 수식은 다음과 같다. 각 레이어 $l$에 대하여 출력 $Z^l$은 다음과 같이 얻어진다.</p>

\[\begin{equation}
Z^l = \left( \Big\|_{m=1}^{M} h^{l,m} \right) W^l_Oh^{l,m} = \text{Softmax} \left( 
    \frac{W^{l,m}_Q Z^{l-1} 
    \left( W^{l,m}_K Z^{l-1} \right)^\top}
    {\sqrt{d^{l,m}_K}} 
\right) W^{l,m}_V Z^{l-1}
\end{equation}\]

<p>여기서 $Z^0 = X$, ∥는 concatenation 연산자, M은 헤드 수, l은 레이어 인덱스, $W^l_O,\ W^{l,m}_Q,\ W^{l,m}_K,\ W^{l,m}_V$는 학습 가능한 파라미터, $d^{l,m}_K$ 는 $W^{l,m}_K$ ​의 첫 번째 차원이다.</p>

<h3 id="33-orthonormal-clustering-readout-ocread">3.3 ORTHONORMAL CLUSTERING READOUT (OCREAD)</h3>
<p>graph-level representation을 학습하는 readout function은 brain network analysis를 위해 필수적이다. Mean, Sum, Max가 주로 사용된다. 그러나 현존하는 방법 중 어느 것도 fig1(a)에 표시된 것 처럼 brain network의 동일한 기능 모듈의 노드가 유사한 동작과 clustering representation을 갖는 경향이 있는 속성을 사용하지 않는다.</p>

<center>
<img src="/assets/img/paper-review/bnt/fig1.webp" width="80%" />
</center>
<p><br /></p>

<p>이를 해결하기 위해 ROIs긴 modular-level similarities의 이점을 활용하는 novel readout function을 제안한다.</p>

<p>$V$ 차원을 가진 $K$ 개의 cluster center $E \in \mathbb{R}^{K \times V}$ 가 주어졌을때 Softmax projection operator가 노드 $i$를 cluster $k$로 할당하는 probability $P_{ik}$를 계산하는 함수로 사용된다.</p>

\[\begin{equation}
P_{ik} = \frac{e^{\langle Z^L_{i\cdot},\, E_{k\cdot} \rangle}}{\sum_{k'}^{K} e^{\langle Z^L_{i\cdot},\, E_{k'\cdot} \rangle}},
\end{equation}\]

<p>soft assignment가 계산된 이후로 $Z^L$은 soft cluster information의 가이드를 받아 graph-level embedding $Z_G$로 집약된다. $Z_G = P^TZ^L$
그러나 GT 없이 node embedding과 cluster를 학습하는 것은 어렵다. 따라서 클러스터 센터 초기화가 매우 중요하다. 이를 해결하기 위해 Fig1(b)에서의 관측을 활용한다. 이는 orthonormal(직교) embedding이 brain network내에서 node clustering을 향상시키는 것이다.</p>

<blockquote>
  <p>이 부분은 완전히 이해하지 못했다. 논문 본문에 이론적 정의가 잘 나와 있으니 참고하면 될 것 같다.</p>
</blockquote>

<h3 id="34-generalizing-ocread-to-other-graph-tasks-and-domains">3.4 Generalizing OCREAD to Other Graph Tasks and Domains</h3>
<p>본 논문에서 OCREAD는 FC based brain network를 이용하였다. 그러나 이에 국한되지 않고 Structural connectivities(SC)등에도 사용가능하며 protein-protein interaction networks나 유전자 발현 network에서도 사용가능하다.</p>

<h2 id="4-experiments">4 Experiments</h2>
<p>저자들은 다음 세가지 RQ에 대한 검증을 위주로 실험했다.</p>

<p>RQ1. How does BRAINNETTF perform compared with state-of-the-art models of various types? (SOTA 급인지?)</p>

<p>RQ2. How does our proposed OCREAD module perform with different model choices? (OCREAD 모듈은 다양한 모델 선택에서 어떻게 작동하는지?)</p>

<p>RQ3. Does the learned model of BRAINNETTF exhibit consistency with existing neuroscience knowledge and suggest reasonable explainability? (BNT가 현존하는 neuroscience knowledge와 일관성있고 함리적인 설명 가능성을 제안하는지?)</p>

<h3 id="41-experimental-settings">4.1 Experimental Settings</h3>
<p>Dataset:</p>
<ul>
  <li>ABIDE
    <ul>
      <li>#: 1009 subjects</li>
      <li>자폐: 516 subjects</li>
      <li>atals:Craddock 200</li>
      <li>network 바로 다운 가능</li>
      <li>multi-site problem을 stratified sampling로 해결</li>
    </ul>
  </li>
  <li>ABCD
    <ul>
      <li>#: 7901 subjects</li>
      <li>여자: 3961 subjects</li>
      <li>atals: HCP 360 ROI</li>
    </ul>
  </li>
</ul>

<p>Metrics는 두 데이터셋 모두 binary classification이므로 AUROC를 사용했으며 임상 적용성을 위해 Sensitivity와 Specificity까지 3가지를 5 random seed에 평균과 표준편차를 보고한다. Model의 헤드는 4개 레이어는 2개를 사용했으면 7:1:2 split을 활용한다. Adam으로 1e-4의 초기 learning rate를 사용한다. weight decay는 1e-4를 사용하며, batch size는 64로 설정됐다. 200 epoch동안 AUROC가 가장 좋은 모델을 테스트에 사용했다.</p>

<h3 id="42-performance-analysis-rq1">4.2 Performance Analysis (RQ1)</h3>

<center>
<img src="/assets/img/paper-review/bnt/tab1.webp" width="80%" />
</center>
<p><br /></p>

<p>(a) BNT vs other graph transformers <br />
위 표를 보면 알 수 있지만 VanillaTF가 다른 비교군을 AUROC 측면에서 이겼다. 저자들은 이를 brain network의 특성 때문이라고 분석했다. 이는</p>

<p>(b) BRAINNETTF vs. neural network models on fixed brain networks <br />
BrainGNN, BrainGB 및 BrainnetCNN 등의 NN 기반 모델보다 더 좋은 성능을 보였다.</p>

<p>(c) BRAINNETTF vs. neural network models on learnable brain networks <br />
learnable graph를 사용한 경우에도 더 좋은 성능을 보였다.</p>

<h3 id="43-ablation-studies-on-the-ocread-module-rq2">4.3 Ablation Studies on the OCREAD Module (RQ2)</h3>
<h4 id="431-ocread-with-varying-readout-functions">4.3.1 OCREAD with varying readout functions</h4>
<p>아래 표는 SAN, Graphormer and VanillaTF에 대하여 다양한 readout function과 본 논문에서 제안된 OCREAD를 비교한 표이다. 전반적으로 OCREAD가 가장 효과적인 readout function이었으며 다양한 transformer 아키텍처의 예측 성능을 높혀주었다.</p>

<center>
<img src="/assets/img/paper-review/bnt/tab2.webp" width="80%" />
</center>
<p><br /></p>

<h4 id="432-ocread-with-varying-cluster-initializations">4.3.2 OCREAD with varying cluster initializations</h4>
<p>OCREAD의 설계가 BNT의 성능에 어떻게 영향을 미치는지 추가로 입증하기 위해 cluster center 초기화 방법과 cluster $K$의 수를 어떻게 선택하는지 이 섹션에서 논의한다.
Random, Learnable, Orthonormal 세가지로 초기화 방법을 비교했으며 2, 3, 4, 5, 10, 50, 100로 $K$를 비교한다. fig3(a)가 이의 결과이다.</p>

<center>
<img src="/assets/img/paper-review/bnt/fig3.webp" width="80%" />
</center>
<p><br /></p>

<p>결과에 따르면 최적의 $K$의 수는 상대적으로 작다. 이는 적은 계산량으로 이끌며 일반적인 functional module의 수가 25개 미만인 것과 일치한다. 충분히 큰 $K$에서는 세가지 방법이 비슷한 성능을 보이나 적은 $K$에서 제안된 orthonormal 방법이 가장 안정적으로 성능을 낸다. OCREAD의 std가 가장 작은 것도 알아두면 좋다.</p>

<h3 id="44-in-depth-analysis-of-attention-scores-and-cluster-assignments-rq3">4.4 In-depth Analysis of Attention Scores and Cluster Assignments (RQ3)</h3>
<p>fig3(b)는 ABCD 테스트셋에서 MHSA 첫 레이어의 평균 self-attention score이다. 이는 학습된 attention score가 available labels 기반의 functional modules의 구분과 잘 일치하여 transformer 모델의 효율성과 설명 가능성을 보여준다. ABIDE는 label을 제공하지 않기 때문에 시각화하지 못했다.</p>

<p>아래 그림은 두 가지 초기화 방법을 사용하여 OCREAD의 노드에 대한 cluster soft assignment 결과 $P$를 보여준다. orthonormal 초기화는 random 초기화보다 좀 더 구분 가능한 $P$를 보여준다. 각 클래스 내에서 orthonormal 초기화는 노드가 그룹을 형성하도록 장려한다.</p>

<center>
<img src="/assets/img/paper-review/bnt/fig4.webp" width="80%" />
</center>
<p><br /></p>

<h2 id="5-discussion-and-conclusion">5 Discussion and Conclusion</h2>
<p>본 논문에서는 brain network analysis를 위한 OCREAD를 갖춘 특화된 graph transformer를 제시한다. 두개의 대규모 brain network 데이터셋에 대한 광범위한 실험으로 BNT가 SOTA인 것을 확인했으며 brain network의 잠재적인 노드 기능 유사성을 모델링하기 위해 OCREAD를 설계하고 이론적, 경험적으로 그 효과를 입증한다. 마지막으로 ABIDE에 대한 재표준화된 데이터셋 분할은 커뮤니티의 새로운 방법에 대한 공정한 평가를 제공할 수 있다. 향후 작업을 위해 BNT는 explicit explanation modules을 통해 개선될 수 있으며 정신 장애에 대한 필수 신경 회로 발굴 및 청소년의 인지 발달 이해와 같은 추가 뇌 네트워크 분석을 위한 백본으로 사용될 수 있다.</p>

<h2 id="개인적인-생각">개인적인 생각</h2>
<ul>
  <li>본 논문은 brain network 관련 연구에서 베이스라인으로 사용되는 모델이라 자세히 리뷰를 하고 싶었는데, 일단 OCREAD라는 것이 나에게는 너무 어려워서 제대로 이해하지 못했다. 추후에 BNT에 대하여, 또 OCREAD에 대하여 자세히 공부할 일이 있으면 이론까지 공부를 하여 해당 본문을 수정할 의향이 있다.</li>
  <li>기존 모델들이 왜 brain network에서 약한 모습을 보였는지 이해할 수 있었다. brain network의 고유한 특성에 대하여 이해할 수 있었다. 저자들의 뇌 기능에 관한 깊은 insight를 얻을 수 있어서 좋았다.</li>
  <li>brain network가 complete graph 라서 edge 정보가 불필요한 것을 이해하였으며 connection profile이 구조적 위치적 정보를 가지고 있는 것을 알게 되었다.</li>
  <li>다시 원문을 보니 stratified sampling에 대하여 자세히 설명하지 않은 것 같다.</li>
  <li>글을 보면 자연스럽지 못한 부분이 참 많은 것 같다. 이는 내가 이 논문을 완전히 이해하고 있지 못하다는 뜻이다. 여러모로 아쉬운 리뷰였지만, 이번 리뷰를 발판 삼아 더 나은 리뷰를 하는 블로거가 되어야겠다.</li>
</ul>]]></content><author><name>YSPARK</name></author><category term="Paper-Review" /><category term="Medical-AI" /><category term="Brain-Network" /><category term="NeurIPS" /><summary type="html"><![CDATA[Brain Network Transformer (NeurIPS)]]></summary></entry><entry><title type="html">[논문리뷰] Self-Supervised Learning from Images with a Joint-Embedding Predictive Architecture</title><link href="https://kitewatermelon.github.io/paper-review/ijepa/" rel="alternate" type="text/html" title="[논문리뷰] Self-Supervised Learning from Images with a Joint-Embedding Predictive Architecture" /><published>2026-03-14T00:00:00+09:00</published><updated>2026-03-12T00:00:00+09:00</updated><id>https://kitewatermelon.github.io/paper-review/ijepa</id><content type="html" xml:base="https://kitewatermelon.github.io/paper-review/ijepa/"><![CDATA[<blockquote>
  <p>CVPR 2023 [<a href="https://arxiv.org/pdf/2301.08243">Paper</a>] [<a href="https://github.com/facebookresearch/ijepa">GitHub</a>]<br />
 Mahmoud Assran, Quentin Duval, Ishan Misra, Piotr Bojanowski, Pascal Vincent, Michael Rabbat, Yann LeCun, Nicolas Ballas
 13 Apr 2023</p>
</blockquote>

<h2 id="1-introduction">1. Introduction</h2>
<p>CV 분야에서는 invariance-based method와 generative method라는 두가지 SSL 기법이 있다. invariance-based pretraining method는 같은 이미지의 서로 다른 view에서 비슷한 임베딩을 얻으며 최적화한다. 이때 서로 다른 view는 hand-crafted augmentations를 통해 주로 만든다. 이 pretraining 기법은 high semantic level의 representations을 얻을 수 있지만, 특정 task나 다른 데이터 분포에서 강한 편향을 주입하기도 한다. 다양한 수준의 추상화가 필요한 task에 대해 이런 편향을 일반화하는 것은 아직 불분명한 경우가 많다. 예를 들어 image classification과 instance segmentation은 같은 invariance를 요구하지 않는다. 추가적으로 image-specific augmentation을 audio 같은 다른 모달에 일반화하는 것은 간단하지 않다.</p>

<p>Cognitive learning theories는 생물학적 시스템에서 representations learning 이면에 있는 구동 메커니즘은 감각 입력 반응을 예측하기 위한 내부 모델의 adaptation이라고 제안한다. 이 아이디어는 self-supervised generative methods의 핵심이다. 이는 입력의 일부를 제거하거나 오염시키고 해당 부분을 예측하는 방식이다. Masked pretraining task는 view-invariance method 보다 사전 지식(hand-crafted transformers을 의미하는 것 같음.)이 덜 필요하고 다른 모달에 일반화 성능이 좋다. 그러나 invariance-based method보다 낮은 semantic 수준을 보이며ㅡ off-the-shelf evaluation에서 성능이 낮다. 결과적으로 end-to-end fine-tunning 같은 복잡한 adaptation 메커니즘을 활용해야 이 방법의 완전한 이점을 누릴 수 있다.</p>

<p>본 논문에서는, I-JEPA라는 추가 사전 지식 없이 self-supervised representations의 semantic 수준을 높이는 method를 도입한다. I-JEPA의 핵심 아이디어는 abstract representation space(추상 표현 공간)에서 누락된 정보를 예측하는 것이다. 예를 들어 같은 이미지의 단일 context block를 주고, 여러 target block의 representation을 예측하는 것이다.</p>

<center>
<img src="/assets/img/paper-review/ijepa/fig3.webp" width="80%" />
</center>
<p><br /></p>

<p>pixel/token space에서 예측하는 generative methods에 비해 I-JEPA는 불필요한 pixel-level 디테일이 잠재적으로 지워진 target의 abstract를 예측함으로 model이 더 의미있는 특징을 학습하도록 이끌어낸다.</p>

<p>I-JEPA를 더 semantic representations을 생산하도록 선택된 또 다른 핵심 디자인은 multi-block masking strategy이다. 특히 이미지에서 충분히 큰 target blocks를 예측하도록 하는 것의 중요성을 입증한다.</p>

<p>저자들은 방대한 양의 실험을 통해 다음을 입증한다.</p>
<ul>
  <li>strong off-the-shelf representation을 hand-crafted view augmentation없이 학습한다.</li>
  <li>I-JEPA는 view-invatiant pretraining approaches와 비등한 semantic task 결과를 보였고, low-level visions tasks에서는 더 나은 결과를 보였다.</li>
  <li>I-JEPA는 scalable하고 효율적인다.</li>
</ul>

<h2 id="2-background">2. Background</h2>
<p>SSL은 system이 입력간의 관계를 포착하도록 하는 representation learning 기법이다. 이 목표는 incompatible inputs 끼리는 높은 에너지를, compatible inputs끼리는 낮은 에너지를 할당하는 Energy-Based Models(EBMs)의 프레임워크를 이용하여 쉽게 설명 가능하다. 현존하는 많은 SSL 방법들이 이 프레임워크로 설명가능하다. 다음 그림을 보면 이해가 쉽다.</p>

<center>
<img src="/assets/img/paper-review/ijepa/fig2.webp" width="80%" />
</center>
<p><br /></p>

<h3 id="joint-embedding-architectures">Joint-Embedding Architectures</h3>
<p>Invariance-based pretraining methods는 compatible inputs인 $x,y$ 간에 비슷한 embedding을 산출하고,incompatible inputs에는 다른 embedding을 산출하느 Joint-Embedding Architectures을 사용하는 EBMs로 설명가능하다(Figure 2a.). image-based pretraining의 관점에서 compatible inputs인 $x,y$는 주로 같은 입력 이미지에 랜덤하게 hand-crafted augmentation을 적용하여 만든다. JEAs의 주된 과제는 energy landscapes가 평평해 지는 representation collapse이다(입력에 관계없이 완전히 같은 출력을 내보냄.). 지난 몇년간 representation collapse을 방지하기 위한 다양한 방법이 연구되었다. contrastive loss를 사용하거나 non-contrastive loss를 사용하여 embedding간 정보 중복을 최소화 하거나 평균 임베딩의 엔트로피를 극대화하는 clustering-based 등의 방법이 있다. 또 서로 다른 인코더를 사용해서 collapse를 방지하는 방법도 있다.</p>

<h3 id="generative-architectures">Generative Architectures</h3>
<p>Reconstruction-based method 역시 Generative Architectures를 사용하는 EBMs 프레임워크로 설명할 수 있다(Figure 2b.). Generative Architectures는 compatible signal $x$에서 바로 $y$를 reconstruction한다. 이때 디코더는 이를 촉진하기 위해서 $z$를 추가적으로 조건으로 받는다. image-based pretraining의 관점에서 가장 흔한 compatible inputs $x,y$를 만드는 방법은 masking이다. $z$는 mask와 position token이다. 이는 어느 이미지 패치를 디코더가 reconstruction할지 명시해준다. $y$보다 $z$의 정보 용량이 적은한 representation collapse는 문제가 되지 않는다.</p>

<h3 id="joint-embedding-predictive-architectures">Joint-Embedding Predictive Architectures</h3>
<p>Figure 2c.에서 볼 수 있듯이 Joint-Embedding Predictive Architectures는 Generative Architectures와 비슷하다. 이와 가장 큰 차이는 loss 계산이 input space가 아닌 embedding space에서 일어나는 것이다. JEPAs는 예측을 용이하기 위한 변수 $z$를 조건으로 받아 compatible signal $x$로 부터 signal $y$의 임베딩을 예측 네트워크를 통해 예측하도록 학습한다. 자세한 그림은 Figure 3.을 참조하면된다.</p>

<p>JEA와 다르게 JEPAs는 representation invatiant를 hand-crafted augmentation을 이용하여 representation invatiant를 찾지 않고 대신 추가 정보 $z$를 조건으로 할 때 서로를 예측하는 representation을 찾는다. 그러나 JEA와 마찬가지로 representation collapse는 JEPAs에서도 문제가 되는데, 이를 방지하기 위해 $x$와 $y$ 사이에 비대칭 아키텍처를 사용한다.</p>

<h2 id="3-method">3. Method</h2>
<p>Figure 3.을 다시 한번 보자.</p>

<center>
<img src="/assets/img/paper-review/ijepa/fig3.webp" width="80%" />
</center>
<p><br /></p>

<p>I-JEPA의 전반적인 목적은 같은 이미지에서 context block이 주어졌을때 다양한 target block의 representation을 예측하는 것이다. context-encoder와 target-encoder 모두 ViT를 사용하였으며 decoder구조는 MAE에서 따왔다. 둘의 차이는 I-JEPA는 non-generative method이며 prediction이 representation space에서 일어나는 것이다.</p>

<h3 id="targets">Targets</h3>
<p>주어진 입력 이미지 $y$를 겹치지 않는 N개의 패치를 만든다. target-encoder $f_{\bar \theta}$에 이걸 넣어서 대응하는 patch-level representation $s_y=\lbrace s_{y1},…, s_{yN} \rbrace $을 만든다. loss를 위한 targets를 얻기 위해 $s_y$에서 M개의 랜덤한 sample block을 뽑는다. i번째 블록에 대응하는 마스크를 $B_i$로 표시하고 $s_y(i) = \lbrace s_{yj} \rbrace _{j \in B_i}$로 표시한다. 저자들은 실험에서 M=4로 셋하고 0.75:1의 종횡비와 0.15~0.2의 스케일로 block을 샘플링했다.</p>

<h3 id="context">Context</h3>
<p>I-JEPA의 목표는 single context block으로 부터 target block의 representation을 예측하는 것이다. 이를 위해 이미지의 0.85~1의 스케일로 $x$를 샘플링하고 $B_x$를 이용해 context blocks을 할당한다. 이후 target block과 겹치는 부분을 없애준다. 아래 그림이 target blocks와 context block을 이해하는데 도움을 준다.</p>

<center>
<img src="/assets/img/paper-review/ijepa/fig4.webp" width="80%" />
</center>
<p><br /></p>

<h3 id="prediction">Prediction</h3>
<p>context encoder의 출력 $s_x$가 주어졌을 때 우리는 $M$ 개의 target block의 
representation $s_y(1), \ldots, s_y(M)$ 을 예측하기를 바란다. 이를 위해, 
대상 마스크 $B_i$ 에 해당하는 주어진 target block $s_y(i)$ 에 대해 예측기 
$g_\phi(\cdot, \cdot)$ 는 context encoder의 출력 $s_x$와 예측하려는 각 패치에 
대한 마스크 토큰 $\lbrace m_j \rbrace_{j \in B_i}$ 를 입력으로 취하고 패치 수준 예측 \(\hat{s}_{y}(i) = \lbrace \hat{s}_{yj} \rbrace_{j \in B_i} = g_\phi(s_x, \lbrace m_j \rbrace_{j \in B_i})\) 을 출력한다. 
mask token은 positional embedding이 추가된 shared learnable vector이다.</p>

<h3 id="loss">Loss</h3>
<p>loss는 predicted patch-level representations $\hat{s}_y(i)$과 the target patch-level representation $s_y(i)$간의 평균 $L_2$ distance이다. 
predictor, $\phi$와 context encoder, $\theta$는 gradient based optimization을 target encoder $\bar \theta$는 EMA 방식으로 학습한다.</p>

<h2 id="4-related-work">4. Related Work</h2>
<p>본문에서 다양한 SSL 기법에 대한 설명을 하고 있으나 이 글에서는 다루지 않겠다.</p>

<h2 id="5-image-classification">5. Image Classification</h2>

<h3 id="imagenet-1k">ImageNet-1K</h3>
<center>
<img src="/assets/img/paper-review/ijepa/tab1.webp" width="80%" />
</center>
<p><br /></p>

<p>hand-crafted augmentation을 사용하지 않는 다른 유명한 방법인 MAE, CAE 그리고 data2vec과 비교했을때 I-JEPA는 더 적은 연산으로 linear probing 성능을 향상시켰다.</p>

<h3 id="low-shot-imagenet-1k">Low-Shot ImageNet-1K</h3>
<center>
<img src="/assets/img/paper-review/ijepa/tab2.webp" width="80%" />
</center>
<p><br /></p>

<p>IN1k의 1%(각 클래스 별로 12~13 장)으로 학습한 결과이다. 적은 에폭으로도 비슷한 구조의 MAE보다 나은 성능을 보였다. 이미지 해상도가 높아져도 JEAs보다 더 나은 성능을 보인다.</p>

<h3 id="transfer-learning">Transfer learning</h3>
<p>기존 모델들 보다 더 좋은 성능을 보였으며 view-invariance-based와의 간격도 줄었다.</p>
<center>
<img src="/assets/img/paper-review/ijepa/tab3.webp" width="80%" />
</center>
<p><br /></p>

<h2 id="6-local-prediction-tasks">6. Local Prediction Tasks</h2>
<ol>
  <li>에서 I-JEPA의 강력함을 엿볼 수 있었는데 이 섹션에서는 I-JEPArk local image feature를 학습하고 low-level이고 dense prediction task에서 view-invariance based method보다 더 나은 결과를 보임을 입증한다.</li>
</ol>
<center>
<img src="/assets/img/paper-review/ijepa/tab4.webp" width="80%" />
</center>
<p><br /></p>

<h2 id="7-scalability">7. Scalability</h2>
<h3 id="model-efficiency">Model Efficiency</h3>
<p>I-JEPA는 기존 방법들보다 더 높은 확장성을 제공한다. MAE 같은 reconstructionbased methods는 픽셀을 target으로 삼는 반면, I-JEPA는 representation space에서 계산을 하기 때문에 약간의 오버헤드가 있다. 그러나 5배 더 빠른 수렴을 보여준다. 또한 I-JEPA로 ViT-H/14를 학습하는 것 보다 ViT-S/16으로 iBOT을 학습하는 것이 더 적은 연산을 필요로 한다.</p>

<center>
<img src="/assets/img/paper-review/ijepa/fig5.webp" width="80%" />
</center>
<p><br /></p>

<h3 id="scaling-data-size">Scaling data size</h3>
<p>I-JEPA는 더 큰 데이터셋에서 pretraining할 때 효과적임을 아래 표를 통해 확인할 수 있다.</p>
<center>
<img src="/assets/img/paper-review/ijepa/tab5.webp" width="80%" />
</center>
<p><br /></p>

<h3 id="scaling-model-size">Scaling model size</h3>
<p>위 표는 I-JEPA가 더 큰 모델에서 pretraining을 할때 더 효과적임을 입증한다. 그러나 ViT-G/16은 입력 팿치가 더 커서 local prediction task 성능이 안좋다.</p>

<h2 id="8-predictor-visualizations">8. Predictor Visualizations</h2>
<p>I-JEPA로 학습한 모델을 RCDM framework로 생성을 시킨 결과이다. I-JEPA의 predictor는 고수준 object의 부분을 정확한 Pose로 잘잡아낸다.</p>

<center>
<img src="/assets/img/paper-review/ijepa/fig6.webp" width="80%" />
</center>
<p><br /></p>

<h2 id="9-ablations">9. Ablations</h2>
<h3 id="predicting-in-representation-space">Predicting in representation space</h3>
<p>pixel space vs representation space에서 loss를 계산할 때의 성능 차이를 ImageNet-1K 1% linear probe로 비교한 실험이다.</p>

<center>
<img src="/assets/img/paper-review/ijepa/tab7.webp" width="80%" />
</center>
<p><br /></p>

<p>pixel space에서 예측하게 되면 모델이 픽셀 수준의 세부 정보(텍스처, 조명, 노이즈 등)까지 다 맞춰야 해서 representation이 low-level detail에 오염된다. 그러나 
representation space에서 예측하게 되면 target encoder가 추상적인 예측 타겟을 만들 수 있으므로 의미없는 픽셀 수준 디테일이 제거된 상태로 학습된다.</p>

<h3 id="masking-strategy">Masking strategy</h3>
<p>다양한 마스킹 전략을 ablation한 결과이다. multi-block masking이 I-JEPA가 semantic representation을 학습하는 데에 도움이 되는 가이드를 하는 것을 알아냈다.</p>

<center>
<img src="/assets/img/paper-review/ijepa/tab6.webp" width="80%" />
</center>
<p><br /></p>

<h2 id="10-conclusion">10. Conclusion</h2>
<p>본 논문에서 저자들은 I-JEPA를 소개한다. I-JEPA는 hand-crafted augmentation에 의존하지 않으며 semantic image representation을 학습하는 간단하고 효율적인 방법이다. I-JEPA는 다른 pixel-level의 방법보다 빠르게 수렴하며 높은 수준의 semantic representation을 학습한다. view-invariance based method와 달리 I-JEPA는 and-crafted augmentation에 의존하지 않고 JEA를 사용하여 general representation을 학습할 수 있는 경로를 강조한다.</p>

<h2 id="개인적인-생각">개인적인 생각</h2>
<ul>
  <li>오랜만에 CV 이론 논문을 리뷰해서 introduction에 힘을 줘버려서 Experiments 부분을 제대로 리뷰하지 못한 것 같아서 아쉬웠다.</li>
  <li>얀 르쿤이 심열을 기울인 방법으로 디테일이 크게 돋보인 논문이었다.</li>
  <li>아직 익숙하지 않은 개념이라, 따로 드는 생각은 없는 것 같다. 코드를 한번 뜯어봐야겠다.</li>
</ul>]]></content><author><name>YSPARK</name></author><category term="Paper-Review" /><category term="Computer-Vision" /><category term="Self-Supervised-Learning" /><category term="JEPA" /><category term="CVPR" /><summary type="html"><![CDATA[Self-Supervised Learning from Images with a Joint-Embedding Predictive Architecture (CVPR)]]></summary></entry><entry><title type="html">[코드 리뷰] Tensor Slicing</title><link href="https://kitewatermelon.github.io/code-review/tensor-slice/" rel="alternate" type="text/html" title="[코드 리뷰] Tensor Slicing" /><published>2026-03-12T00:00:00+09:00</published><updated>2026-03-12T00:00:00+09:00</updated><id>https://kitewatermelon.github.io/code-review/tensor-slice</id><content type="html" xml:base="https://kitewatermelon.github.io/code-review/tensor-slice/"><![CDATA[<p>실습 예제: <a href="https://colab.research.google.com/drive/1btkCLqW3QAqOZZ6ymbmL5trOHW4ri4IS#scrollTo=9_nIJ9P9H5Np">Colab</a></p>

<h2 id="1-introduction">1. Introduction</h2>
<p>딥러닝 관련 논문을 읽으며 공부하다 보면 대부분의 논문 구현이 PyTorch(torch) 기반으로 되어 있음을 알 수 있다. 아래 그래프에서 볼 수 있듯이 2024년 기준 PyTorch, TensorFlow, JAX 중 PyTorch를 사용한 프로젝트가 제일 많은 비중을 차지하는 것을 볼 수 있다 <a href="https://softwaremill.com/ml-engineer-comparison-of-pytorch-tensorflow-jax-and-flax/">[1]</a>.</p>

<p><img src="https://softwaremill.com/user/pages/blog/229.ml-engineer-comparison-of-pytorch-tensorflow-jax-and-flax/image2.png?g-1efd1e18" alt="image.png" /></p>

<h3 id="11-tensor-slicing이란">1.1. Tensor Slicing이란?</h3>
<p>텐서 슬라이싱은 다차원 배열에서 원하는 부분만 선택적으로 추출하는 연산으로 NumPy의 배열 인덱싱에서 유래했고 PyTorch도 동일한 문법을 사용한다 (NumPy-like).</p>

<h3 id="12-왜-필요한가">1.2. 왜 필요한가?</h3>
<p>딥러닝에서 텐서는 보통 (B, H, W, C) 같은 고차원 구조를 가진다. 모델 내부에서 특정 배치만, 특정 채널만, 특정 공간 위치만 꺼내서 연산해야 할 일이 매우 많다. 슬라이싱 없이는 불필요한 데이터까지 복사하거나 반복문으로 순회해야 하는데, 슬라이싱은 이걸 뷰(view) 방식으로 해결한다.</p>

<h3 id="13-원리">1.3. 원리</h3>
<p>핵심은 “메모리를 복사하지 않는다”는 것이다.</p>

<p>텐서는 내부적으로 두 가지로 구성되는데:</p>
<ul>
  <li>storage: 실제 데이터가 1D로 연속 저장된 메모리</li>
  <li>stride + offset: “몇 칸 건너뛰면 다음 원소인지”를 기술하는 메타데이터</li>
</ul>

<p>슬라이싱을 하면 storage는 그대로 두고 stride와 offset만 바꾼 새 텐서 객체를 반환하기 때문에 빠르고 메모리 효율적이다 <a href="https://docs.pytorch.org/docs/stable/tensor_view.html">[2]</a>.</p>

<p>메모리를 복사하지 않는 것으로 문제가 생길 수 있는데, 이 부분은 3장에서 다룬다.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="n">plt</span>
</code></pre></div></div>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">t</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">zeros</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span>  <span class="c1"># (B, H, W, C)
</span>
<span class="k">print</span><span class="p">(</span><span class="n">t</span><span class="p">[</span><span class="mi">0</span><span class="p">].</span><span class="n">shape</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="n">t</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">,</span> <span class="p">:,</span> <span class="p">:].</span><span class="n">shape</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="n">t</span><span class="p">[...,</span> <span class="mi">0</span><span class="p">].</span><span class="n">shape</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="n">t</span><span class="p">[</span><span class="mi">0</span><span class="p">:</span><span class="mi">2</span><span class="p">].</span><span class="n">shape</span><span class="p">)</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>torch.Size([3, 3, 3])
torch.Size([3, 3, 3])
torch.Size([3, 3, 3])
torch.Size([2, 3, 3, 3])
</code></pre></div></div>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">print</span><span class="p">(</span><span class="n">t</span><span class="p">.</span><span class="n">stride</span><span class="p">())</span>
<span class="k">print</span><span class="p">(</span><span class="n">t</span><span class="p">[</span><span class="mi">0</span><span class="p">].</span><span class="n">stride</span><span class="p">())</span>
<span class="k">print</span><span class="p">(</span><span class="n">t</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">].</span><span class="n">stride</span><span class="p">())</span>
<span class="n">t_slice</span> <span class="o">=</span> <span class="n">t</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="k">print</span><span class="p">(</span><span class="n">t_slice</span><span class="p">.</span><span class="n">data_ptr</span><span class="p">()</span> <span class="o">==</span> <span class="n">t</span><span class="p">.</span><span class="n">data_ptr</span><span class="p">())</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>(27, 9, 3, 1)
(9, 3, 1)
(27, 3, 1)
True
</code></pre></div></div>

<h2 id="2-시각화로-텐스-슬라이싱-제대로-보기">2. 시각화로 텐스 슬라이싱 제대로 보기</h2>
<p>본 글에서는 사람들의 이해를 돕기 위해 matplotlib로 시각화한다. Computer Vision 영역에서 제일 많이 사용되는 4D [B, H, W, C] 형태를 시각화 할 것이며, [3,3,3,3] 사이즈의 boolean 자료형의 입력을 사용한다. 이때 B는 batch 사이즈이고, H는 이미지의 높이, W는 이미지의 너비, C는 RGB로 판단한다. 따라서 $3^2$ 크기의 컬러 이미지 3개가 있는 상황이다.</p>

<p>기본적으로 <code class="language-plaintext highlighter-rouge">tensor.zeros().dtype(bool)</code> 로 4차원 False tensor를 생성하여 슬라이싱 되는 부분만 True로 변환하여 어떤 부분이 슬라이싱 되는지 시각화한다.</p>

<p>우리들의 천하무적 클로드가 <code class="language-plaintext highlighter-rouge">visualize_tensor()</code>라는 시각화 코드를 만들어줬다:</p>

<details>
<summary>코드 정보</summary>

<div>

    <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="n">np</span>
<span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="n">plt</span>
<span class="kn">import</span> <span class="nn">matplotlib.patches</span> <span class="k">as</span> <span class="n">mpatches</span>
<span class="kn">from</span> <span class="nn">mpl_toolkits.mplot3d.art3d</span> <span class="kn">import</span> <span class="n">Poly3DCollection</span>

<span class="n">RGB_COLORS</span> <span class="o">=</span> <span class="p">[</span><span class="s">'red'</span><span class="p">,</span> <span class="s">'green'</span><span class="p">,</span> <span class="s">'blue'</span><span class="p">]</span>
<span class="n">RGB_LABELS</span> <span class="o">=</span> <span class="p">[</span><span class="s">'R'</span><span class="p">,</span> <span class="s">'G'</span><span class="p">,</span> <span class="s">'B'</span><span class="p">]</span>

<span class="k">def</span> <span class="nf">draw_cube</span><span class="p">(</span><span class="n">ax</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">z</span><span class="p">,</span> <span class="n">filled</span><span class="o">=</span><span class="bp">False</span><span class="p">,</span> <span class="n">color</span><span class="o">=</span><span class="s">'steelblue'</span><span class="p">):</span>
    <span class="n">vertices</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">array</span><span class="p">([</span>
        <span class="p">[</span><span class="n">x</span><span class="p">,</span>   <span class="n">y</span><span class="p">,</span>   <span class="n">z</span><span class="p">],</span>   <span class="p">[</span><span class="n">x</span><span class="o">+</span><span class="mi">1</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span>   <span class="n">z</span><span class="p">],</span>   <span class="p">[</span><span class="n">x</span><span class="o">+</span><span class="mi">1</span><span class="p">,</span> <span class="n">y</span><span class="o">+</span><span class="mi">1</span><span class="p">,</span> <span class="n">z</span><span class="p">],</span>   <span class="p">[</span><span class="n">x</span><span class="p">,</span>   <span class="n">y</span><span class="o">+</span><span class="mi">1</span><span class="p">,</span> <span class="n">z</span><span class="p">],</span>
        <span class="p">[</span><span class="n">x</span><span class="p">,</span>   <span class="n">y</span><span class="p">,</span>   <span class="n">z</span><span class="o">+</span><span class="mi">1</span><span class="p">],</span> <span class="p">[</span><span class="n">x</span><span class="o">+</span><span class="mi">1</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span>   <span class="n">z</span><span class="o">+</span><span class="mi">1</span><span class="p">],</span> <span class="p">[</span><span class="n">x</span><span class="o">+</span><span class="mi">1</span><span class="p">,</span> <span class="n">y</span><span class="o">+</span><span class="mi">1</span><span class="p">,</span> <span class="n">z</span><span class="o">+</span><span class="mi">1</span><span class="p">],</span> <span class="p">[</span><span class="n">x</span><span class="p">,</span>   <span class="n">y</span><span class="o">+</span><span class="mi">1</span><span class="p">,</span> <span class="n">z</span><span class="o">+</span><span class="mi">1</span><span class="p">],</span>
    <span class="p">])</span>
    <span class="n">faces</span> <span class="o">=</span> <span class="p">[</span>
        <span class="p">[</span><span class="n">vertices</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">vertices</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">vertices</span><span class="p">[</span><span class="mi">2</span><span class="p">],</span> <span class="n">vertices</span><span class="p">[</span><span class="mi">3</span><span class="p">]],</span>
        <span class="p">[</span><span class="n">vertices</span><span class="p">[</span><span class="mi">4</span><span class="p">],</span> <span class="n">vertices</span><span class="p">[</span><span class="mi">5</span><span class="p">],</span> <span class="n">vertices</span><span class="p">[</span><span class="mi">6</span><span class="p">],</span> <span class="n">vertices</span><span class="p">[</span><span class="mi">7</span><span class="p">]],</span>
        <span class="p">[</span><span class="n">vertices</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">vertices</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">vertices</span><span class="p">[</span><span class="mi">5</span><span class="p">],</span> <span class="n">vertices</span><span class="p">[</span><span class="mi">4</span><span class="p">]],</span>
        <span class="p">[</span><span class="n">vertices</span><span class="p">[</span><span class="mi">2</span><span class="p">],</span> <span class="n">vertices</span><span class="p">[</span><span class="mi">3</span><span class="p">],</span> <span class="n">vertices</span><span class="p">[</span><span class="mi">7</span><span class="p">],</span> <span class="n">vertices</span><span class="p">[</span><span class="mi">6</span><span class="p">]],</span>
        <span class="p">[</span><span class="n">vertices</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">vertices</span><span class="p">[</span><span class="mi">3</span><span class="p">],</span> <span class="n">vertices</span><span class="p">[</span><span class="mi">7</span><span class="p">],</span> <span class="n">vertices</span><span class="p">[</span><span class="mi">4</span><span class="p">]],</span>
        <span class="p">[</span><span class="n">vertices</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">vertices</span><span class="p">[</span><span class="mi">2</span><span class="p">],</span> <span class="n">vertices</span><span class="p">[</span><span class="mi">6</span><span class="p">],</span> <span class="n">vertices</span><span class="p">[</span><span class="mi">5</span><span class="p">]],</span>
    <span class="p">]</span>
    <span class="k">if</span> <span class="n">filled</span><span class="p">:</span>
        <span class="n">poly</span> <span class="o">=</span> <span class="n">Poly3DCollection</span><span class="p">(</span><span class="n">faces</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="mf">0.5</span><span class="p">,</span>
                                <span class="n">facecolor</span><span class="o">=</span><span class="n">color</span><span class="p">,</span> <span class="n">edgecolor</span><span class="o">=</span><span class="s">'black'</span><span class="p">,</span> <span class="n">linewidth</span><span class="o">=</span><span class="mf">0.5</span><span class="p">)</span>
    <span class="k">else</span><span class="p">:</span>
        <span class="n">poly</span> <span class="o">=</span> <span class="n">Poly3DCollection</span><span class="p">(</span><span class="n">faces</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="mf">0.03</span><span class="p">,</span>
                                <span class="n">facecolor</span><span class="o">=</span><span class="s">'white'</span><span class="p">,</span> <span class="n">edgecolor</span><span class="o">=</span><span class="s">'gray'</span><span class="p">,</span> <span class="n">linewidth</span><span class="o">=</span><span class="mf">0.3</span><span class="p">,</span> <span class="n">linestyle</span><span class="o">=</span><span class="s">'--'</span><span class="p">)</span>
    <span class="n">ax</span><span class="p">.</span><span class="n">add_collection3d</span><span class="p">(</span><span class="n">poly</span><span class="p">)</span>


<span class="k">def</span> <span class="nf">visualize_tensor</span><span class="p">(</span><span class="n">tensor</span><span class="p">):</span>
    <span class="k">if</span> <span class="nb">hasattr</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="s">'numpy'</span><span class="p">):</span>
        <span class="n">arr</span> <span class="o">=</span> <span class="n">tensor</span><span class="p">.</span><span class="n">numpy</span><span class="p">().</span><span class="n">astype</span><span class="p">(</span><span class="nb">bool</span><span class="p">)</span>
    <span class="k">else</span><span class="p">:</span>
        <span class="n">arr</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">asarray</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="nb">bool</span><span class="p">)</span>

    <span class="k">assert</span> <span class="n">arr</span><span class="p">.</span><span class="n">shape</span> <span class="o">==</span> <span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="s">"Input must be [3,3,3,3]"</span>
    <span class="n">B</span><span class="p">,</span> <span class="n">H</span><span class="p">,</span> <span class="n">W</span><span class="p">,</span> <span class="n">C</span> <span class="o">=</span> <span class="n">arr</span><span class="p">.</span><span class="n">shape</span>

    <span class="n">fig</span> <span class="o">=</span> <span class="n">plt</span><span class="p">.</span><span class="n">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">5</span> <span class="o">*</span> <span class="n">B</span><span class="p">,</span> <span class="mi">6</span><span class="p">))</span>

    <span class="k">for</span> <span class="n">b</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">B</span><span class="p">):</span>
        <span class="n">ax</span> <span class="o">=</span> <span class="n">fig</span><span class="p">.</span><span class="n">add_subplot</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">B</span><span class="p">,</span> <span class="n">b</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span> <span class="n">projection</span><span class="o">=</span><span class="s">'3d'</span><span class="p">)</span>

        <span class="k">for</span> <span class="n">h</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">H</span><span class="p">):</span>
            <span class="k">for</span> <span class="n">w</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">W</span><span class="p">):</span>
                <span class="k">for</span> <span class="n">c</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">C</span><span class="p">):</span>
                    <span class="n">draw_cube</span><span class="p">(</span><span class="n">ax</span><span class="p">,</span> <span class="n">w</span><span class="p">,</span> <span class="n">c</span><span class="p">,</span> <span class="n">h</span><span class="p">,</span>
                              <span class="n">filled</span><span class="o">=</span><span class="nb">bool</span><span class="p">(</span><span class="n">arr</span><span class="p">[</span><span class="n">b</span><span class="p">,</span> <span class="n">h</span><span class="p">,</span> <span class="n">w</span><span class="p">,</span> <span class="n">c</span><span class="p">]),</span>
                              <span class="n">color</span><span class="o">=</span><span class="n">RGB_COLORS</span><span class="p">[</span><span class="n">c</span><span class="p">])</span>  <span class="c1"># c=0→R, c=1→G, c=2→B
</span>
        <span class="n">ax</span><span class="p">.</span><span class="n">set_xlabel</span><span class="p">(</span><span class="s">'C'</span><span class="p">,</span> <span class="n">labelpad</span><span class="o">=</span><span class="mi">6</span><span class="p">)</span>
        <span class="n">ax</span><span class="p">.</span><span class="n">set_ylabel</span><span class="p">(</span><span class="s">'W'</span><span class="p">,</span> <span class="n">labelpad</span><span class="o">=</span><span class="mi">6</span><span class="p">)</span>
        <span class="n">ax</span><span class="p">.</span><span class="n">set_zlabel</span><span class="p">(</span><span class="s">'H'</span><span class="p">,</span> <span class="n">labelpad</span><span class="o">=</span><span class="mi">6</span><span class="p">)</span>

        <span class="n">ticks</span>  <span class="o">=</span> <span class="p">[</span><span class="mf">0.5</span><span class="p">,</span> <span class="mf">1.5</span><span class="p">,</span> <span class="mf">2.5</span><span class="p">]</span>
        <span class="n">ax</span><span class="p">.</span><span class="n">set_xticks</span><span class="p">(</span><span class="n">ticks</span><span class="p">);</span> <span class="n">ax</span><span class="p">.</span><span class="n">set_xticklabels</span><span class="p">(</span><span class="n">RGB_LABELS</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">7</span><span class="p">)</span>  <span class="c1"># R/G/B 표기
</span>        <span class="n">ax</span><span class="p">.</span><span class="n">set_yticks</span><span class="p">(</span><span class="n">ticks</span><span class="p">);</span> <span class="n">ax</span><span class="p">.</span><span class="n">set_yticklabels</span><span class="p">([</span><span class="s">'1'</span><span class="p">,</span> <span class="s">'2'</span><span class="p">,</span> <span class="s">'3'</span><span class="p">],</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">7</span><span class="p">)</span>
        <span class="n">ax</span><span class="p">.</span><span class="n">set_zticks</span><span class="p">(</span><span class="n">ticks</span><span class="p">);</span> <span class="n">ax</span><span class="p">.</span><span class="n">set_zticklabels</span><span class="p">([</span><span class="s">'1'</span><span class="p">,</span> <span class="s">'2'</span><span class="p">,</span> <span class="s">'3'</span><span class="p">],</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">7</span><span class="p">)</span>

        <span class="n">ax</span><span class="p">.</span><span class="n">set_xlim</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">C</span><span class="p">);</span> <span class="n">ax</span><span class="p">.</span><span class="n">set_ylim</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">W</span><span class="p">);</span> <span class="n">ax</span><span class="p">.</span><span class="n">set_zlim</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">H</span><span class="p">)</span>
        <span class="n">ax</span><span class="p">.</span><span class="n">set_title</span><span class="p">(</span><span class="sa">f</span><span class="s">'B=</span><span class="si">{</span><span class="n">b</span><span class="o">+</span><span class="mi">1</span><span class="si">}</span><span class="s">'</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">11</span><span class="p">)</span>
        <span class="n">ax</span><span class="p">.</span><span class="n">view_init</span><span class="p">(</span><span class="n">elev</span><span class="o">=</span><span class="mi">20</span><span class="p">,</span> <span class="n">azim</span><span class="o">=-</span><span class="mi">60</span><span class="p">)</span>

    <span class="c1"># Legend: R/G/B + False
</span>    <span class="n">patches</span> <span class="o">=</span> <span class="p">[</span><span class="n">mpatches</span><span class="p">.</span><span class="n">Patch</span><span class="p">(</span><span class="n">facecolor</span><span class="o">=</span><span class="n">c</span><span class="p">,</span> <span class="n">edgecolor</span><span class="o">=</span><span class="s">'black'</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="sa">f</span><span class="s">'True (</span><span class="si">{</span><span class="n">l</span><span class="si">}</span><span class="s">)'</span><span class="p">)</span>
               <span class="k">for</span> <span class="n">c</span><span class="p">,</span> <span class="n">l</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">RGB_COLORS</span><span class="p">,</span> <span class="n">RGB_LABELS</span><span class="p">)]</span>
    <span class="n">patches</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">mpatches</span><span class="p">.</span><span class="n">Patch</span><span class="p">(</span><span class="n">facecolor</span><span class="o">=</span><span class="s">'white'</span><span class="p">,</span> <span class="n">edgecolor</span><span class="o">=</span><span class="s">'gray'</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="s">'False'</span><span class="p">,</span> <span class="n">linestyle</span><span class="o">=</span><span class="s">'--'</span><span class="p">))</span>
    <span class="n">fig</span><span class="p">.</span><span class="n">legend</span><span class="p">(</span><span class="n">handles</span><span class="o">=</span><span class="n">patches</span><span class="p">,</span> <span class="n">loc</span><span class="o">=</span><span class="s">'lower center'</span><span class="p">,</span>
               <span class="n">ncol</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">9</span><span class="p">,</span> <span class="n">frameon</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span> <span class="n">bbox_to_anchor</span><span class="o">=</span><span class="p">(</span><span class="mf">0.5</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">))</span>

    <span class="n">plt</span><span class="p">.</span><span class="n">suptitle</span><span class="p">(</span><span class="s">'[B, H, W, C] Tensor Slice Visualization'</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">13</span><span class="p">,</span> <span class="n">y</span><span class="o">=</span><span class="mf">1.01</span><span class="p">)</span>
    <span class="n">plt</span><span class="p">.</span><span class="n">tight_layout</span><span class="p">()</span>
    <span class="n">plt</span><span class="p">.</span><span class="n">show</span><span class="p">()</span>


</code></pre></div>    </div>
  </div>
</details>

<p>전체 체크</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">t1</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">ones</span><span class="p">([</span><span class="mi">3</span><span class="p">,</span><span class="mi">3</span><span class="p">,</span><span class="mi">3</span><span class="p">,</span><span class="mi">3</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="nb">bool</span><span class="p">)</span>
<span class="n">visualize_tensor</span><span class="p">(</span><span class="n">t1</span><span class="p">)</span>
</code></pre></div></div>

<p><img src="/assets/img/code-review/tensor-slice/tensor-slice_7_1.webp" alt="png" /></p>

<p>두번째 이미지만 슬라이싱</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">t</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">zeros</span><span class="p">([</span><span class="mi">3</span><span class="p">,</span><span class="mi">3</span><span class="p">,</span><span class="mi">3</span><span class="p">,</span><span class="mi">3</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="nb">bool</span><span class="p">)</span>
<span class="n">t</span><span class="p">[</span><span class="mi">1</span><span class="p">,:,:,:]</span> <span class="o">=</span> <span class="bp">True</span>
<span class="n">visualize_tensor</span><span class="p">(</span><span class="n">t</span><span class="p">)</span>
</code></pre></div></div>

<p><img src="/assets/img/code-review/tensor-slice/tensor-slice_9_1.webp" alt="png" /></p>

<p>RGB 중 G만 시각화<br />
<code class="language-plaintext highlighter-rouge">...</code>(Ellipsis)는 “나머지 차원은 전부 : 로 채워줘” 라는 뜻이다.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">t</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">zeros</span><span class="p">([</span><span class="mi">3</span><span class="p">,</span><span class="mi">3</span><span class="p">,</span><span class="mi">3</span><span class="p">,</span><span class="mi">3</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="nb">bool</span><span class="p">)</span>
<span class="n">t</span><span class="p">[:,:,:,</span><span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="bp">True</span>
<span class="n">visualize_tensor</span><span class="p">(</span><span class="n">t</span><span class="p">)</span>

<span class="n">t</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">zeros</span><span class="p">([</span><span class="mi">3</span><span class="p">,</span><span class="mi">3</span><span class="p">,</span><span class="mi">3</span><span class="p">,</span><span class="mi">3</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="nb">bool</span><span class="p">)</span>
<span class="n">t</span><span class="p">[...,</span><span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="bp">True</span>
<span class="n">visualize_tensor</span><span class="p">(</span><span class="n">t</span><span class="p">)</span>
</code></pre></div></div>

<p><img src="/assets/img/code-review/tensor-slice/tensor-slice_11_1.webp" alt="png" /></p>

<p><img src="/assets/img/code-review/tensor-slice/tensor-slice_11_3.webp" alt="png" /></p>

<p>십자가 모양으로 슬라이싱</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">t</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">zeros</span><span class="p">([</span><span class="mi">3</span><span class="p">,</span><span class="mi">3</span><span class="p">,</span><span class="mi">3</span><span class="p">,</span><span class="mi">3</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="nb">bool</span><span class="p">)</span>

<span class="n">t</span><span class="p">[:,</span><span class="mi">1</span><span class="p">,...]</span> <span class="o">=</span> <span class="bp">True</span>
<span class="n">t</span><span class="p">[...,</span><span class="mi">1</span><span class="p">,:]</span> <span class="o">=</span> <span class="bp">True</span>

<span class="n">visualize_tensor</span><span class="p">(</span><span class="n">t</span><span class="p">)</span>
</code></pre></div></div>

<p><img src="/assets/img/code-review/tensor-slice/tensor-slice_13_1.webp" alt="png" /></p>

<p>중심 부분만 슬라이싱</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">t</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">zeros</span><span class="p">([</span><span class="mi">3</span><span class="p">,</span><span class="mi">3</span><span class="p">,</span><span class="mi">3</span><span class="p">,</span><span class="mi">3</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="nb">bool</span><span class="p">)</span>

<span class="n">t</span><span class="p">[:,</span><span class="mi">1</span><span class="p">,</span><span class="mi">1</span><span class="p">,:]</span> <span class="o">=</span> <span class="bp">True</span>

<span class="n">visualize_tensor</span><span class="p">(</span><span class="n">t</span><span class="p">)</span>
</code></pre></div></div>

<p><img src="/assets/img/code-review/tensor-slice/tensor-slice_15_1.webp" alt="png" /></p>

<p>멋지게 인덱싱 해보기</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1"># 안 멋진 방법
</span><span class="n">t</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">zeros</span><span class="p">([</span><span class="mi">3</span><span class="p">,</span><span class="mi">3</span><span class="p">,</span><span class="mi">3</span><span class="p">,</span><span class="mi">3</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="nb">bool</span><span class="p">)</span>

<span class="n">t</span><span class="p">[:,</span><span class="mi">0</span><span class="p">,</span><span class="mi">0</span><span class="p">,</span><span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="bp">True</span>
<span class="n">t</span><span class="p">[:,</span><span class="mi">1</span><span class="p">,</span><span class="mi">1</span><span class="p">,</span><span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="bp">True</span>
<span class="n">t</span><span class="p">[:,</span><span class="mi">2</span><span class="p">,</span><span class="mi">2</span><span class="p">,</span><span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="bp">True</span>

<span class="n">visualize_tensor</span><span class="p">(</span><span class="n">t</span><span class="p">)</span>

<span class="c1"># 멋진 방법
</span><span class="n">t</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">zeros</span><span class="p">([</span><span class="mi">3</span><span class="p">,</span><span class="mi">3</span><span class="p">,</span><span class="mi">3</span><span class="p">,</span><span class="mi">3</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="nb">bool</span><span class="p">)</span>

<span class="n">idx</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">3</span><span class="p">)</span>
<span class="n">t</span><span class="p">[:,</span> <span class="n">idx</span><span class="p">,</span> <span class="n">idx</span><span class="p">,</span> <span class="n">idx</span><span class="p">]</span> <span class="o">=</span> <span class="bp">True</span>  <span class="c1"># H=W=C 인 대각선
</span>
<span class="n">visualize_tensor</span><span class="p">(</span><span class="n">t</span><span class="p">)</span>
</code></pre></div></div>

<p><img src="/assets/img/code-review/tensor-slice/tensor-slice_17_1.webp" alt="png" /></p>

<p><img src="/assets/img/code-review/tensor-slice/tensor-slice_17_3.webp" alt="png" /></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">t</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">zeros</span><span class="p">([</span><span class="mi">3</span><span class="p">,</span><span class="mi">3</span><span class="p">,</span><span class="mi">3</span><span class="p">,</span><span class="mi">3</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="nb">bool</span><span class="p">)</span>

<span class="n">t</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="bp">True</span> <span class="c1"># R채널만, shape [B, H, W]
</span><span class="n">visualize_tensor</span><span class="p">(</span><span class="n">t</span><span class="p">)</span>

<span class="n">t</span> <span class="o">=</span> <span class="n">t</span><span class="p">.</span><span class="n">T</span>
<span class="n">t</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="bp">True</span> <span class="c1"># R채널만, shape [B, H, W]
</span>
<span class="n">visualize_tensor</span><span class="p">(</span><span class="n">t</span><span class="p">)</span>
</code></pre></div></div>

<p><img src="/assets/img/code-review/tensor-slice/tensor-slice_18_1.webp" alt="png" /></p>

<p><img src="/assets/img/code-review/tensor-slice/tensor-slice_18_3.webp" alt="png" /></p>

<h2 id="3-contiguous에-대하여">3. <code class="language-plaintext highlighter-rouge">contiguous()</code>에 대하여</h2>

<h3 id="31-메모리-레이아웃부터-이해하기">3.1. 메모리 레이아웃부터 이해하기</h3>

<p>PyTorch 텐서는 내부적으로 <strong>1D 메모리(storage)</strong> 위에 존재한다. 예를 들어 shape <code class="language-plaintext highlighter-rouge">[2, 3]</code> 텐서는 실제로 메모리에 이렇게 저장된다:</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>메모리: [a, b, c, d, e, f]
         ↕
tensor([[a, b, c],
        [d, e, f]])
</code></pre></div></div>

<p>이때 “다음 원소로 가려면 몇 칸 건너뛰어야 하는가”를 <strong>stride</strong>라고 한다.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">torch</span>

<span class="n">t</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">tensor</span><span class="p">([[</span><span class="mi">1</span><span class="p">,</span><span class="mi">2</span><span class="p">,</span><span class="mi">3</span><span class="p">],[</span><span class="mi">4</span><span class="p">,</span><span class="mi">5</span><span class="p">,</span><span class="mi">6</span><span class="p">]])</span>
<span class="k">print</span><span class="p">(</span><span class="n">t</span><span class="p">.</span><span class="n">stride</span><span class="p">())</span>  <span class="c1"># (3, 1) → 행 이동시 3칸, 열 이동시 1칸
</span></code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>(3, 1)
</code></pre></div></div>

<h3 id="32-슬라이싱-후-stride가-꼬이는-상황">3.2. 슬라이싱 후 stride가 꼬이는 상황</h3>

<p>이 슬라이싱은 <strong>메모리를 복사하지 않고</strong> stride/offset만 바꿔서 반환한다.
그 결과 메모리 상에서 원소들이 <strong>띄엄띄엄</strong> 놓이게 된다.</p>

<p>stride의 마지막 값이 1이 아니라는 건, 메모리에서 원소들이 연속적으로 붙어있지 않다는 뜻이다.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">t</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">zeros</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span>
<span class="n">t_slice</span> <span class="o">=</span> <span class="n">t</span><span class="p">[:,</span> <span class="p">:,</span> <span class="p">:,</span> <span class="mi">1</span><span class="p">]</span>   <span class="c1"># C 채널 중 G만 추출 → shape [3,3,3]
</span>
<span class="k">print</span><span class="p">(</span><span class="n">t</span><span class="p">.</span><span class="n">stride</span><span class="p">())</span>               <span class="c1"># (27, 9, 3, 1)
</span><span class="k">print</span><span class="p">(</span><span class="n">t_slice</span><span class="p">.</span><span class="n">stride</span><span class="p">())</span>         <span class="c1"># (27, 9, 3)  ← 마지막이 1이 아님!
</span><span class="k">print</span><span class="p">(</span><span class="n">t_slice</span><span class="p">.</span><span class="n">is_contiguous</span><span class="p">())</span>  <span class="c1"># False
</span></code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>(27, 9, 3, 1)
(27, 9, 3)
False
</code></pre></div></div>

<h3 id="33-언제-문제가-터지나">3.3. 언제 문제가 터지나?</h3>

<p><code class="language-plaintext highlighter-rouge">view()</code>는 메모리가 연속적으로 배치되어 있다고 가정한다. 그래서 비연속 텐서에 <code class="language-plaintext highlighter-rouge">.view()</code>를 쓰면 에러가 발생한다.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">t</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">rand</span><span class="p">([</span><span class="mi">1</span><span class="p">,</span><span class="mi">2</span><span class="p">,</span><span class="mi">2</span><span class="p">,</span><span class="mi">2</span><span class="p">])</span>

<span class="c1"># transpose/permute는 stride 관계가 틀어져서 진짜 에러 발생
</span><span class="n">t_transposed</span> <span class="o">=</span> <span class="n">t</span><span class="p">.</span><span class="n">transpose</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>  <span class="c1"># stride 관계가 깨짐
</span><span class="k">print</span><span class="p">(</span><span class="n">t_transposed</span><span class="p">.</span><span class="n">is_contiguous</span><span class="p">())</span>  <span class="c1"># False
</span>
<span class="k">try</span><span class="p">:</span>
    <span class="n">t_transposed</span><span class="p">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
<span class="k">except</span> <span class="nb">RuntimeError</span> <span class="k">as</span> <span class="n">e</span><span class="p">:</span>
    <span class="k">print</span><span class="p">(</span><span class="n">e</span><span class="p">)</span>

<span class="c1"># 해결
</span><span class="k">print</span><span class="p">(</span><span class="n">t_transposed</span><span class="p">.</span><span class="n">contiguous</span><span class="p">().</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">))</span>
<span class="k">print</span><span class="p">(</span><span class="n">t_transposed</span><span class="p">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">))</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>False
view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
tensor([0.0949, 0.1639, 0.6846, 0.3884, 0.6910, 0.5094, 0.1464, 0.4296])
tensor([0.0949, 0.1639, 0.6846, 0.3884, 0.6910, 0.5094, 0.1464, 0.4296])
</code></pre></div></div>

<h3 id="34-contiguous의-역할">3.4. <code class="language-plaintext highlighter-rouge">contiguous()</code>의 역할</h3>

<p><code class="language-plaintext highlighter-rouge">.contiguous()</code>는 <strong>메모리를 새로 할당하고 데이터를 연속된 형태로 복사</strong>한다.
이때 <strong>실제 copy가 발생</strong>하기 때문에 주의가 필요하다.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">t_cont</span> <span class="o">=</span> <span class="n">t_slice</span><span class="p">.</span><span class="n">contiguous</span><span class="p">()</span>

<span class="k">print</span><span class="p">(</span><span class="n">t_cont</span><span class="p">.</span><span class="n">is_contiguous</span><span class="p">())</span>                      <span class="c1"># True
</span><span class="k">print</span><span class="p">(</span><span class="n">t_slice</span><span class="p">.</span><span class="n">data_ptr</span><span class="p">()</span> <span class="o">==</span> <span class="n">t_cont</span><span class="p">.</span><span class="n">data_ptr</span><span class="p">())</span>     <span class="c1"># False → 다른 메모리
</span><span class="n">t_cont</span><span class="p">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>                                    <span class="c1"># 정상 작동
</span></code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>True
False

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0.])
</code></pre></div></div>

<h3 id="35-view-vs-reshape-정리">3.5. <code class="language-plaintext highlighter-rouge">view()</code> vs <code class="language-plaintext highlighter-rouge">reshape()</code> 정리</h3>

<table>
  <thead>
    <tr>
      <th> </th>
      <th><code class="language-plaintext highlighter-rouge">view()</code></th>
      <th><code class="language-plaintext highlighter-rouge">reshape()</code></th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td>contiguous 필요</td>
      <td>✅ 반드시</td>
      <td>❌ 아니어도 됨</td>
    </tr>
    <tr>
      <td>동작 방식</td>
      <td>항상 view (zero-copy)</td>
      <td>contiguous면 view, 아니면 내부적으로 copy</td>
    </tr>
    <tr>
      <td>에러 발생</td>
      <td>비연속이면 RuntimeError</td>
      <td>없음</td>
    </tr>
  </tbody>
</table>

<p>실무에서는 보통 아래 두 패턴 중 하나를 쓴다:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1"># 패턴 1: 명시적으로 contiguous 보장 후 view
</span><span class="n">t_slice</span><span class="p">.</span><span class="n">contiguous</span><span class="p">().</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>

<span class="c1"># 패턴 2: reshape에 맡기기 (더 간편)
</span><span class="n">t_slice</span><span class="p">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0.])
</code></pre></div></div>

<h3 id="36-실무에서-자주-만나는-케이스">3.6. 실무에서 자주 만나는 케이스</h3>

<p>Vision Transformer나 멀티헤드 어텐션 구현에서 특히 자주 나온다.
<code class="language-plaintext highlighter-rouge">transpose()</code>와 <code class="language-plaintext highlighter-rouge">permute()</code>는 <strong>항상 비연속 텐서를 반환</strong>하기 때문에, 이후에 <code class="language-plaintext highlighter-rouge">view()</code>를 쓸 계획이라면 <code class="language-plaintext highlighter-rouge">.contiguous()</code>를 습관적으로 붙여주는 것이 좋다.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">B</span><span class="p">,</span> <span class="n">H</span><span class="p">,</span> <span class="n">W</span><span class="p">,</span> <span class="n">C</span> <span class="o">=</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">3</span>
<span class="n">feature_map</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">B</span><span class="p">,</span> <span class="n">H</span><span class="p">,</span> <span class="n">W</span><span class="p">,</span> <span class="n">C</span><span class="p">)</span>

<span class="c1"># [B, H, W, C] → 특정 채널 추 후 reshape
</span><span class="n">x</span> <span class="o">=</span> <span class="n">feature_map</span><span class="p">[:,</span> <span class="p">:,</span> <span class="p">:,</span> <span class="mi">0</span><span class="p">]</span>     <span class="c1"># shape [B, H, W], 비연속 가능성 있음
</span><span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="p">.</span><span class="n">contiguous</span><span class="p">().</span><span class="n">view</span><span class="p">(</span><span class="n">B</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>  <span class="c1"># 안전하게 flatten출
</span>
<span class="c1"># transpose 후 reshape할 때
</span><span class="n">x</span> <span class="o">=</span> <span class="n">feature_map</span><span class="p">.</span><span class="n">transpose</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>  <span class="c1"># transpose는 항상 비연속!
</span><span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="p">.</span><span class="n">contiguous</span><span class="p">().</span><span class="n">view</span><span class="p">(</span><span class="n">B</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
</code></pre></div></div>]]></content><author><name>YSPARK</name></author><category term="Code-Review" /><category term="PyTorch-basic" /><summary type="html"><![CDATA[tensor slicing이란 무엇인지 깨닫고, 메모리를 고려하며 코딩하는 법]]></summary></entry><entry><title type="html">[논문리뷰] Interpretable fMRI Captioning via Contrastive Learning</title><link href="https://kitewatermelon.github.io/paper-review/brain-decoding-with-blip2/" rel="alternate" type="text/html" title="[논문리뷰] Interpretable fMRI Captioning via Contrastive Learning" /><published>2026-03-10T00:00:00+09:00</published><updated>2026-03-10T00:00:00+09:00</updated><id>https://kitewatermelon.github.io/paper-review/brain-decoding-with-blip2</id><content type="html" xml:base="https://kitewatermelon.github.io/paper-review/brain-decoding-with-blip2/"><![CDATA[<blockquote>
  <p>MICCAI 2025 [<a href="https://papers.miccai.org/miccai-2025/paper/2049_paper.pdf">Paper</a>] [<a href="https://github.com/slavaheroes/brain-decoding-with-blip2">GitHub</a>]<br />
 Vyacheslav Shen, Kassymzhomart Kunanbayev, Donggon Jang, Daeshik Kim
 20 Sep 2025</p>
</blockquote>

<h2 id="1-introduction">1. Introduction</h2>
<p>뇌의 계층적 이미지 처리는 CNN 개발에 영감을 주었다. CNN 레이어 전체에 걸쳐 특징과 돌출 맵을 시각화하면 초기 레이어에서 엣지를 감지하고 깊어질수록 클래스 특화된 특징을 감지하는 것은 시각 피질의 기능과 유사하다. 더욱이 CNN-learned representations는 원숭이와 사람의 neural activity와 강한 상관관계가 있다. 이런 유사성 덕분에 neural activity로 DNN feature를 역으로 예측하는 방식으로 DNNs는 visual representations를 디코딩하는데 많이 사용된다.</p>

<p>Huthet al.은 fMRI 데이터를 단어 임베딩에 매핑하여 몇 시간 분량의 서술된 이야기를 디코딩할 수 있음을 보여주었으며, 최근에는 LDM을 이용하여 fMRI 데이터로부터 고해상도 자극 이미지를 reconstruction 하는 연구도 있었다. 한편, Transformer 아키텍처와 GPT-2는 neural activity로부터 자연어 재구성을 크게 향상시켰다. 그러나 생성된 출력물의 품질과 의미론적 일관성을 위해 추가적인 개성과 대안이 필요하다. 기존에는 brain activity의 시각 자극에서 이미지를 재구성하는 방식으로 접근했으나, 최근에는 multimodal deep learning이 대안을 제공한다. 신경 반응을 바로 textual descriptions로 디코딩하는 것인데 이를 fMRI captioning이라고 한다. 이런 관점에서 multimodal retrieval은 brain activity로부터 무엇이 보였고 근본적으로 의미론적인 내용을 유연하게 디코딩할 수 있다. fMRI-based decoding의 발전에도 불구하고 효율적으로 brain activity와 의미 있는 textual descriptions을 align하는 것은 아직 여러 문제가 있는데, 연산 효율, 의미론적 일관성 그리고 retrieval capabilities이다. 본 논문에선 contrastive learning을 통해 이 문제를 해결한다.</p>

<p>본 논문의 contribution은 다음과 같다.</p>
<ul>
  <li>연산 효율이 좋은 two-stage training을 도입하여 fMRI 데이터와 VL model(BLIP-2)을 align한다.</li>
  <li>synthetic fMRI patterns을 이용하여 interpretability decoding analysis를 제안한다.</li>
</ul>

<h3 id="11-related-work">1.1 Related Work</h3>
<p>CLIP (Contrastive Language-Image Pre-training)은 image 인코더와 text 인코더로 구성되며 multimodal model의 진보에 크게 기여했다. LDM의 reverse diffusion process에서 가이드를 하는 역할도 하고 VLMs에 LLMs과 visual data를 align 할 때도 사용한다.</p>

<p>이런 유능함에 힘입어 fMRI 신호로 CLIP의 image embedding을 예측하도록 하여 시각 자극을 재건하는 곳에 쓰인다. 그러나 breain decoding 연구에는 fMRI의 차원이 15,724로 충분히 고차원인데 conditional embedding 역시 257 × 768이나 257 × 1024 같은 고차원으로, 높은 연산량을 요구받는 어려움이 있다.</p>

<p>본 논문에서는 BLIP-2를 이용하여 visual embedding의 차원을 32 × 768로 compact하게 만든다. BLIP-2는 Q-Former(Querying Transformer)를 사용하여 이미지 인코더 기능을 LLM 임베딩 공간에 매핑한다. 압축 네트워크 역할을 하는 Q-Former는 대규모 frozen image features(257 × 1024)를 compact query tokens(32 × 768)으로 인코딩하여 뇌 디코딩에 적합한 텍스트 관련 및 의미론적으로 풍부한 이미지 표현을 보존한다.</p>

<h2 id="2-methodology">2 Methodology</h2>
<h3 id="21-dataset">2.1 Dataset</h3>
<p>Natural Scenes Dataset (NSD) 데이터셋을 사용한다. 이는 COCO dataset의 image를 각각 3초간 본 8명의 피험자의 7T fMRI 데이터셋이다. 기존 연구와 일관되도록 subj1의 데이터에서만 정량 분석을 한다. subj1은 모든 실험 시험을 완료하여 24,980개의 fMRI 시험(이미지당 최대 3회 반복)에 해당하는 8,859개의 훈련 이미지와 2,770개의 fMRI 시험이 포함된 982개의 테스트 이미지의 데이터 세트를 얻었다. 여러번 보여진 이미지에 대해서는 대응하는 fMRI trials에 대하여 평균을 취했다.</p>

<p>Ozcelik et al.을 따라 ridge regression을 사용한 GLM에서 억은 단일 실험 베타 가중치를 사용하여 fMRI를 처리했다. 시각 축을 따라 z-정규화했으며 NSDGeneral Regions-of-Interest (ROI) 마스크를 사용하여 15,764 복셀 벡터를 추출했다.</p>

<h3 id="22-fmri-captioning-with-blip-2">2.2 fMRI Captioning with BLIP-2</h3>
<p>본 논문에서는 textual descriptions from fMRI activity를 생성하기 위해 pre-trained BLIP-2를 사용했다. 이는 compact language-aligned image representations (32 × 768)을 제공하기 때문이다.</p>

<center>
<img src="/assets/img/paper-review/brain-decoding-with-blip2/fig1.webp" width="80%" />
</center>
<p><br /></p>

<p>위 그림에서 볼 수 있듯이 feature extraction and Brain Model training으로 2 단계 프레임워크가 시작된다.</p>

<p>첫번째 단계에서는 stimulus image이 BLIP-2 이미지 인코더로 처리되고 BLIP-2 Q-Former안의 learned query vectors와 cross attention 하여 32 × 768의 최종 representation을 뽑는다. Brain Model은 ridge regression을 이용하여 fMRI activity(15,764 voxel)을 32 × 768의 최종 representation의 임베딩과 매핑한다.</p>

<p>두번째 단계에서는 retrieval을 위해 Brain Model의 출력과 text embeddings를 contrastive learning을 통해 align한다. GT caption은 BLIP-2 Q-Former’s self-attention 레이어를 통해 text embedding을 생성한다. image-text space와 Brain Model의 출력을 align하기 위해 linear projection layer를 도입한다. (fig 2 참고) 최종 loss는 다음과 같다.</p>

\[\begin{equation}
\mathcal L = \lambda_1\mathcal L MSE(b,i) + \lambda_2\mathcal L CLIP(b,t) + \lambda_3\mathcal L CLIP(i,t)
\end{equation}\]

<p>역할은 다음과 같다.</p>
<ol>
  <li>Mean Squared Error (MSE) loss: Brain Model’s predicted embeddings b 와 the GT image embeddings i의 alignment를 보존</li>
  <li>Brain-text contrastive loss: Brain Model’s outputs b 와 text embeddings t를 align해서 text retrieval 성능 향상</li>
  <li>Image-text contrastive loss: catastrophic forgetting 방지 및 t와 i의 일관성을 강화하며 robust image-text를 align</li>
</ol>

<h2 id="3-results--discussion">3 Results &amp; Discussion</h2>
<h3 id="31-retrieval">3.1 Retrieval</h3>
<blockquote>
  <p>Multimodal Retrieval이란?</p>
  <ul>
    <li>여러 종류의 데이터(뇌 신호, 이미지, 텍스트)를 서로 검색할 수 있는 능력</li>
    <li>예시</li>
  </ul>

  <table>
    <thead>
      <tr>
        <th>입력 (Query)</th>
        <th>검색 대상 (Retrieved)</th>
        <th>의미</th>
      </tr>
    </thead>
    <tbody>
      <tr>
        <td>fMRI 뇌 신호</td>
        <td>이미지 (B→I)</td>
        <td>“이 뇌 활동을 봤을 때 어떤 이미지를 본 거지?”</td>
      </tr>
      <tr>
        <td>이미지</td>
        <td>fMRI 뇌 신호 (I→B)</td>
        <td>“이 이미지를 봤을 때의 뇌 신호는 어느 것이지”</td>
      </tr>
      <tr>
        <td>fMRI 뇌 신호</td>
        <td>텍스트 (B→T)</td>
        <td>“이 뇌 활동을 설명하는 문장은 무엇이지?”</td>
      </tr>
      <tr>
        <td>텍스트</td>
        <td>fMRI 뇌 신호 (T→B)</td>
        <td>“이 문장에 해당하는 뇌 신호는 어느 것이지?”</td>
      </tr>
    </tbody>
  </table>
</blockquote>

<h4 id="image-and-brain-retrieval">image and brain retrieval</h4>
<p>이미지를 BLIP-2 Q-Former representation으로 만들고 fMRI-derived representation과 image embedding의 cosine similarity를 계산한다. MindEye-2의 eval protocol을 따라 300 sample의 top-1 retrieval accuracy를 측정한다. 보고된 결과는 30번의 시도에 대한 평균 정확도를 반영한다.</p>

<h4 id="textbrain-retrieval">text/brain retrieval</h4>
<p>text-aligned image embedding을 stage 2의 Brain Model을 이용하여 예측한다. caption embedding을 BLIP-2 Q-Former를 이용하여 얻으며 올바른지 확인하기 위해 cosine similarity를 계산한다. 50번의 시도에 대한 평균 정확도를 보고 한다.</p>

<p>성능은 다음 표와 같으며 T → B와 B → T가 가능 한 모델임을 보여준다.</p>

<center>
<img src="/assets/img/paper-review/brain-decoding-with-blip2/tab1.webp" width="80%" />
</center>
<p><br /></p>

<h3 id="32-fmri-captioning">3.2 fMRI Captioning</h3>
<p>BLIP-2에 구현되어 있는 OPT-2.7B decoder-only language model를 이용하여 textual descriptions을 생성한다. 6개 중 5개에서 다른 모델들을 stage 1에서도 이미 넘어섰으며 stage 2는 압도적인 성능을 보인다.</p>

<center>
<img src="/assets/img/paper-review/brain-decoding-with-blip2/tab2.webp" width="80%" />
</center>
<p><br /></p>

<p>아래의 Figure 4는 정성적인 성능을 보여준다. Stage 1보다 Stage 2에서 구체적인 caption이 나왔다. (beach 보다 wave, horses보다 zebra 등…)</p>

<center>
<img src="/assets/img/paper-review/brain-decoding-with-blip2/fig4.webp" width="80%" />
</center>
<p><br /></p>

<h3 id="33-interpretability-analysis-of-roi-specific-fmri-signals">3.3 Interpretability Analysis of ROI-Specific fMRI Signals</h3>
<p>서로 다른 뇌 영역의 역할을 분석하기 위해 ROI-based interpretability analysis를 Brain Diffuser를 따라 한다. ROI의 voxel의 값을 1로 하고, 나머지를 0으로 만들어 synthetic fMRI 신호를 생성한다. Brain Model을 통해 처리 되고 정규화 후 11로 스케일되고 나서 caption 생성을 위해 language model을 통과한다. 아래 표는 그 결과이다.</p>

<center>
<img src="/assets/img/paper-review/brain-decoding-with-blip2/tab3.webp" width="80%" />
</center>
<p><br /></p>

<p>이 결과는 인간의 계층적, 모듈적 특성을 반영하는 시각 처리의 신경과학적 연구 결과와 일치한다. 예를 하나만 들자면 V1은 basic black-and-white features를 highlight한다. floc-words 영역은 텍스트 및 기호와 관련된 caption을 생성한다. 이런 결과는 Brain Diffuser의 결과와 일관되게 같다.</p>

<h2 id="4-conclusion">4 Conclusion</h2>
<p>본 논문에서는 연산 효율이 좋은 2 단계의 학습 프레임워크를 제안한다. contrastive learning을 도입하여 fMRI로 부터 정확한 captions을 생성하도록 하였으며, Vision-Language model representations과 brain activity를 align하여 multimnodal retrieval의 성능을 향상시켰다. ROI-optimal stimuli analysis는 decoding 과정에서 특정 뇌 영역의 contribuution을 식별하며 interpretability를 향상시켰다. 일반화 능력을 향상시키기 위해 cross-subject decoding에 초점을 두고, 적용 가능성을 향상시키기 위하여 multimodal generarion을 더 탐구하는 것을 future work로 두며 저자들은 글을 마무리 짓는다.</p>

<h2 id="개인적인-생각">개인적인 생각</h2>
<ul>
  <li>본 논문은 BLIP-2 Q-Former를 이용하여 연산 효율을 높이며 multimodal retrieval, fMRI captioning, Interpretability Analysis의 3가지 실험을 통해 우수성을 입증했다.</li>
  <li>새로운 데이터셋을 통해 fMRI가 질환 연구에만 사용되는 것이 아닌 신경과학 분야에서 뇌를 이해하기 위해 사용되는 것을 확인하며 fMRI의 범용성을 알 수 있었다.</li>
  <li>b, i에서는 왜 MSE를 사용하고, 나머지는 왜 CLIP loss를 사용하는지 이해하지 못했는데 이유는 다음과 같다.
    <blockquote>
      <p>MSE: Brain Model의 출력 b가 Image Embedding i와 “최대한 똑같은 벡터값”이 되길 원함 <br />
CLIP: Brain Model의 출력 b가 Text Embedding t와 “의미적으로 가까운 공간”에 있길 원함</p>
    </blockquote>
  </li>
  <li>이 논문 역시 작년에 직접 설명을 들었었는데, 배경지식의 부족으로 그저 지나친 논문중에 하나였다. 이제 공부를 해서 어느정도 이해를 할 수 있어서 기쁘다. 저자분은 한국말을 잘하셨다.</li>
</ul>]]></content><author><name>YSPARK</name></author><category term="Paper-Review" /><category term="Medical-AI" /><category term="Brain-Decoding" /><category term="Contrastive-Learning" /><category term="MICCAI" /><summary type="html"><![CDATA[Interpretable fMRI Captioning via Contrastive Learning (MICCAI 2025)]]></summary></entry></feed>