PyTorch torch稀疏

2025-07-02 11:33 更新

PyTorch 稀疏張量詳解:從基礎(chǔ)到應(yīng)用

一、稀疏張量是什么?

稀疏張量是指那些包含大量零值的張量。在處理高維數(shù)據(jù)或大規(guī)模模型時(shí),稀疏張量能夠顯著節(jié)省內(nèi)存和計(jì)算資源。PyTorch 支持 COO(Coordinate Format)格式的稀疏張量,這種格式通過兩個(gè)密集張量來表示:一個(gè)值張量和一個(gè) 2D 索引張量。

二、稀疏張量的創(chuàng)建和操作

(一)創(chuàng)建稀疏張量

  1. 基本創(chuàng)建方法
    • 使用索引和值張量構(gòu)造稀疏張量。
    • 示例代碼:
      
      import torch

定義索引張量 (LongTensor) 和值張量 (FloatTensor)

 indices = torch.LongTensor([[0, 1, 1], [2, 0, 2]])
 values = torch.FloatTensor([3, 4, 5])

創(chuàng)建稀疏張量,指定大小

 sparse_tensor = torch.sparse.FloatTensor(indices, values, torch.Size([2, 3]))

轉(zhuǎn)換為密集張量查看結(jié)果

 dense_tensor = sparse_tensor.to_dense()
 print(dense_tensor)
 ```

  1. 混合稀疏張量
    • 僅前 n 個(gè)維度是稀疏的,其余維度是密集的。
    • 示例代碼:
      indices = torch.LongTensor([[2, 4]])
      values = torch.FloatTensor([[1, 3], [5, 7]])
      mixed_sparse_tensor = torch.sparse.FloatTensor(indices, values)
      dense_mixed = mixed_sparse_tensor.to_dense()
      print(dense_mixed)

  1. 空稀疏張量
    • 指定大小構(gòu)造空的稀疏張量。
    • 示例代碼:
      empty_sparse = torch.sparse.FloatTensor(2, 3)
      print(empty_sparse)

(二)稀疏張量的基本操作

  1. 加法操作
    • 對(duì)兩個(gè)稀疏張量進(jìn)行加法操作。
    • 示例代碼:
      
      indices1 = torch.LongTensor([[0, 1], [0, 1]])
      values1 = torch.FloatTensor([2, 3])
      sparse1 = torch.sparse.FloatTensor(indices1, values1, torch.Size([2, 2]))

indices2 = torch.LongTensor([[0, 1], [1, 0]]) values2 = torch.FloatTensor([4, 5]) sparse2 = torch.sparse.FloatTensor(indices2, values2, torch.Size([2, 2]))

result_sparse = sparse1 + sparse2 print(result_sparse.to_dense())



2. **矩陣乘法**
   - 稀疏矩陣與密集矩陣的乘法。
   - 示例代碼:
     ```python
     sparse_mat = torch.sparse.FloatTensor(indices, values, torch.Size([2, 3]))
     dense_mat = torch.randn(3, 2)
     product = torch.sparse.mm(sparse_mat, dense_mat)
     print(product)

  1. 求和操作
    • 對(duì)稀疏張量指定維度求和。
    • 示例代碼:
      sum_result = torch.sparse.sum(sparse_tensor, dim=1)
      print(sum_result.to_dense())

(三)稀疏張量的屬性和方法

  1. coalesce()
    • 合并稀疏張量中重復(fù)的索引項(xiàng)。
    • 示例代碼:
      
      # 創(chuàng)建一個(gè)包含重復(fù)索引的稀疏張量
      indices_repeat = torch.LongTensor([[0, 0, 1], [1, 1, 0]])
      values_repeat = torch.FloatTensor([1, 1, 2])
      sparse_repeat = torch.sparse.FloatTensor(indices_repeat, values_repeat, torch.Size([2, 2]))

合并重復(fù)索引

 coalesced_sparse = sparse_repeat.coalesce()
 print(coalesced_sparse.indices())
 print(coalesced_sparse.values())
 ```

  1. is_coalesced()
    • 檢查稀疏張量是否已合并。
    • 示例代碼:
      print(sparse_tensor.is_coalesced())
      print(coalesced_sparse.is_coalesced())

  1. indices()values()
    • 獲取稀疏張量的索引和值張量。
    • 示例代碼:
      print(sparse_tensor.indices())
      print(sparse_tensor.values())

  1. to_dense()
    • 將稀疏張量轉(zhuǎn)換為密集張量。
    • 示例代碼:
      dense_tensor = sparse_tensor.to_dense()
      print(dense_tensor)

三、稀疏張量的應(yīng)用場(chǎng)景

(一)自然語言處理中的稀疏嵌入

在自然語言處理任務(wù)中,詞嵌入矩陣通常是稀疏的。使用稀疏張量可以有效減少內(nèi)存占用并加速計(jì)算。

## 假設(shè)有一個(gè)稀疏的詞嵌入矩陣
word_indices = torch.LongTensor([[0, 2], [1, 3]])
word_values = torch.FloatTensor([[0.1, 0.2], [0.3, 0.4]])
embedding_sparse = torch.sparse.FloatTensor(word_indices, word_values, torch.Size([10000, 128]))


## 在模型中使用稀疏嵌入
dense_output = embedding_sparse.mm(input_vector)

(二)推薦系統(tǒng)中的稀疏用戶-項(xiàng)目交互矩陣

推薦系統(tǒng)中用戶與項(xiàng)目的交互數(shù)據(jù)通常是稀疏的。利用稀疏張量可以高效地存儲(chǔ)和處理這些數(shù)據(jù)。

## 創(chuàng)建用戶-項(xiàng)目交互稀疏矩陣
user_indices = torch.LongTensor([[0, 1, 2], [3, 5, 7]])
interaction_values = torch.FloatTensor([1, 1, 1])
user_item_matrix = torch.sparse.FloatTensor(user_indices, interaction_values, torch.Size([1000, 1000]))


## 使用稀疏矩陣進(jìn)行模型訓(xùn)練
model_output = user_item_matrix.mm(item_embeddings)

四、總結(jié)

通過本教程,我們?cè)敿?xì)介紹了 PyTorch 中稀疏張量的創(chuàng)建、操作及應(yīng)用場(chǎng)景。稀疏張量在處理大規(guī)模數(shù)據(jù)時(shí)能夠顯著節(jié)省內(nèi)存和計(jì)算資源,適用于自然語言處理、推薦系統(tǒng)等多個(gè)領(lǐng)域。掌握稀疏張量的使用方法,可以幫助我們更高效地構(gòu)建和優(yōu)化深度學(xué)習(xí)模型。

以上內(nèi)容是否對(duì)您有幫助:
在線筆記
App下載
App下載

掃描二維碼

下載編程獅App

公眾號(hào)
微信公眾號(hào)

編程獅公眾號(hào)