PyTorch 命名為 Tensors 操作員范圍

2020-09-15 14:00 更新

原文:PyTorch 命名為 Tensors 操作員范圍

請首先閱讀命名張量,以了解命名張量。

本文檔是名稱推斷的參考,HTH1 是一個定義張量命名方式的過程:

  1. 使用名稱提供其他自動運行時正確性檢查
  2. 將名稱從輸入張量傳播到輸出張量

以下是命名張量及其關(guān)聯(lián)的名稱推斷規(guī)則支持的所有操作的列表。

如果此處未列出操作,但對您的用例有幫助,請搜索問題是否已提交,否則請提交一個問題。

警告

命名的張量 API 是實驗性的,隨時可能更改。

Supported Operations

API 名稱推斷規(guī)則
Tensor.abs() , torch.abs() 保留輸入名稱
Tensor.abs_() Keeps input names
Tensor.acos() , torch.acos() Keeps input names
Tensor.acos_() Keeps input names
Tensor.add() , torch.add() 統(tǒng)一輸入的名稱
Tensor.add_() Unifies names from inputs
Tensor.addmm() , torch.addmm() 縮小暗淡
Tensor.addmm_() Contracts away dims
Tensor.addmv() , torch.addmv() Contracts away dims
Tensor.addmv_() Contracts away dims
Tensor.align_as() 查看文件
Tensor.align_to() See documentation
Tensor.all(),torch.all() 沒有
Tensor.any()torch.any() None
Tensor.asin() , torch.asin() Keeps input names
Tensor.asin_() Keeps input names
Tensor.atan() , torch.atan() Keeps input names
Tensor.atan2() , torch.atan2() Unifies names from inputs
Tensor.atan2_() Unifies names from inputs
Tensor.atan_() Keeps input names
Tensor.bernoulli() , torch.bernoulli() Keeps input names
Tensor.bernoulli_() None
Tensor.bfloat16() Keeps input names
Tensor.bitwise_not() , torch.bitwise_not() Keeps input names
Tensor.bitwise_not_() None
Tensor.bmm() , torch.bmm() Contracts away dims
Tensor.bool() Keeps input names
Tensor.byte() Keeps input names
torch.cat() Unifies names from inputs
Tensor.cauchy_() None
Tensor.ceil() , torch.ceil() Keeps input names
Tensor.ceil_() None
Tensor.char() Keeps input names
Tensor.chunk() , torch.chunk() Keeps input names
Tensor.clamp() , torch.clamp() Keeps input names
Tensor.clamp_() None
Tensor.copy_() 輸出功能和就地變體
Tensor.cos() , torch.cos() Keeps input names
Tensor.cos_() None
Tensor.cosh() , torch.cosh() Keeps input names
Tensor.cosh_() None
Tensor.cpu() Keeps input names
Tensor.cuda() Keeps input names
Tensor.cumprod() , torch.cumprod() Keeps input names
Tensor.cumsum() , torch.cumsum() Keeps input names
Tensor.data_ptr() None
Tensor.detach() ,torch.detach() Keeps input names
Tensor.detach_() None
Tensor.device , torch.device() None
Tensor.digamma() , torch.digamma() Keeps input names
Tensor.digamma_() None
Tensor.dim() None
Tensor.div() , torch.div() Unifies names from inputs
Tensor.div_() Unifies names from inputs
Tensor.dot() , torch.dot() None
Tensor.double() Keeps input names
Tensor.element_size() None
torch.empty() 工廠功能
torch.empty_like() Factory functions
Tensor.eq() , torch.eq() Unifies names from inputs
Tensor.erf() , torch.erf() Keeps input names
Tensor.erf_() None
Tensor.erfc() , torch.erfc() Keeps input names
Tensor.erfc_() None
Tensor.erfinv() , torch.erfinv() Keeps input names
Tensor.erfinv_() None
Tensor.exp() , torch.exp() Keeps input names
Tensor.exp_() None
Tensor.expand() Keeps input names
Tensor.expm1() , torch.expm1() Keeps input names
Tensor.expm1_() None
Tensor.exponential_() None
Tensor.fill_() None
Tensor.flatten() , torch.flatten() See documentation
Tensor.float() Keeps input names
Tensor.floor() , torch.floor() Keeps input names
Tensor.floor_() None
Tensor.frac() , torch.frac() Keeps input names
Tensor.frac_() None
Tensor.ge() , torch.ge() Unifies names from inputs
Tensor.get_device() ,torch.get_device() None
Tensor.grad None
Tensor.gt() , torch.gt() Unifies names from inputs
Tensor.half() Keeps input names
Tensor.has_names() See documentation
Tensor.index_fill() ,torch.index_fill() Keeps input names
Tensor.index_fill_() None
Tensor.int() Keeps input names
Tensor.is_contiguous() None
Tensor.is_cuda None
Tensor.is_floating_point() , torch.is_floating_point() None
Tensor.is_leaf None
Tensor.is_pinned() None
Tensor.is_shared() None
Tensor.is_signed() ,torch.is_signed() None
Tensor.is_sparse None
torch.is_tensor() None
Tensor.item() None
Tensor.kthvalue() , torch.kthvalue() 移除尺寸
Tensor.le() , torch.le() Unifies names from inputs
Tensor.log() , torch.log() Keeps input names
Tensor.log10() , torch.log10() Keeps input names
Tensor.log10_() None
Tensor.log1p() , torch.log1p() Keeps input names
Tensor.log1p_() None
Tensor.log2() , torch.log2() Keeps input names
Tensor.log2_() None
Tensor.log_() None
Tensor.log_normal_() None
Tensor.logical_not() , torch.logical_not() Keeps input names
Tensor.logical_not_() None
Tensor.logsumexp() , torch.logsumexp() Removes dimensions
Tensor.long() Keeps input names
Tensor.lt() , torch.lt() Unifies names from inputs
torch.manual_seed() None
Tensor.masked_fill() ,torch.masked_fill() Keeps input names
Tensor.masked_fill_() None
Tensor.masked_select() , torch.masked_select() 將遮罩對齊到輸入,然后 unified_names_from_input_tensors
Tensor.matmul() , torch.matmul() Contracts away dims
Tensor.mean() , torch.mean() Removes dimensions
Tensor.median() , torch.median() Removes dimensions
Tensor.mm() , torch.mm() Contracts away dims
Tensor.mode() , torch.mode() Removes dimensions
Tensor.mul() , torch.mul() Unifies names from inputs
Tensor.mul_() Unifies names from inputs
Tensor.mv() , torch.mv() Contracts away dims
Tensor.names See documentation
Tensor.narrow() , torch.narrow() Keeps input names
Tensor.ndim None
Tensor.ndimension() None
Tensor.ne() , torch.ne() Unifies names from inputs
Tensor.neg() , torch.neg() Keeps input names
Tensor.neg_() None
torch.normal() Keeps input names
Tensor.normal_() None
Tensor.numel() , torch.numel() None
torch.ones() Factory functions
Tensor.pow() , torch.pow() Unifies names from inputs
Tensor.pow_() None
Tensor.prod() , torch.prod() Removes dimensions
torch.rand() Factory functions
torch.rand() Factory functions
torch.randn() Factory functions
torch.randn() Factory functions
Tensor.random_() None
Tensor.reciprocal() , torch.reciprocal() Keeps input names
Tensor.reciprocal_() None
Tensor.refine_names() See documentation
Tensor.register_hook() None
Tensor.rename() See documentation
Tensor.rename_() See documentation
Tensor.requires_grad None
Tensor.requires_grad_() None
Tensor.resize_() 只允許不改變形狀的調(diào)整大小
Tensor.resize_as_() Only allow resizes that do not change shape
Tensor.round() , torch.round() Keeps input names
Tensor.round_() None
Tensor.rsqrt() , torch.rsqrt() Keeps input names
Tensor.rsqrt_() None
Tensor.select() ,torch.select() Removes dimensions
Tensor.short() Keeps input names
Tensor.sigmoid() , torch.sigmoid() Keeps input names
Tensor.sigmoid_() None
Tensor.sign() , torch.sign() Keeps input names
Tensor.sign_() None
Tensor.sin() , torch.sin() Keeps input names
Tensor.sin_() None
Tensor.sinh() , torch.sinh() Keeps input names
Tensor.sinh_() None
Tensor.size() None
Tensor.split() , torch.split() Keeps input names
Tensor.sqrt() , torch.sqrt() Keeps input names
Tensor.sqrt_() None
Tensor.squeeze() , torch.squeeze() Removes dimensions
Tensor.std() , torch.std() Removes dimensions
torch.std_mean() Removes dimensions
Tensor.stride() None
Tensor.sub() ,torch.sub() Unifies names from inputs
Tensor.sub_() Unifies names from inputs
Tensor.sum() , torch.sum() Removes dimensions
Tensor.tan() , torch.tan() Keeps input names
Tensor.tan_() None
Tensor.tanh() , torch.tanh() Keeps input names
Tensor.tanh_() None
torch.tensor() Factory functions
Tensor.to() Keeps input names
Tensor.topk() , torch.topk() Removes dimensions
Tensor.transpose() , torch.transpose() 排列尺寸
Tensor.trunc() , torch.trunc() Keeps input names
Tensor.trunc_() None
Tensor.type() None
Tensor.type_as() Keeps input names
Tensor.unbind() , torch.unbind() Removes dimensions
Tensor.unflatten() See documentation
Tensor.uniform_() None
Tensor.var() , torch.var() Removes dimensions
torch.var_mean() Removes dimensions
Tensor.zero_() None
torch.zeros() Factory functions

保留輸入名稱

所有逐點一元函數(shù)以及其他一些一元函數(shù)都遵循此規(guī)則。

  • 檢查姓名:無
  • 傳播名稱:輸入張量的名稱會傳播到輸出。

>>> x = torch.randn(3, 3, names=('N', 'C'))
>>> x.abs().names
('N', 'C')

移除尺寸

所有縮小操作,例如 sum() ,都會通過縮小所需尺寸來刪除尺寸。 select()squeeze() 等其他操作會刪除尺寸。

只要有人可以將整數(shù)維度索引傳遞給運算符,就可以傳遞維度名稱。 包含維索引列表的函數(shù)也可以包含維名稱列表。

  • 檢查名稱:如果dimdims作為名稱列表傳入,請檢查self中是否存在這些名稱。
  • 傳播名稱:如果在輸出張量中不存在dimdims指定的輸入張量的尺寸,則這些尺寸的相應(yīng)名稱不會出現(xiàn)在output.names中。

>>> x = torch.randn(1, 3, 3, 3, names=('N', 'C', 'H', 'W'))
>>> x.squeeze('N').names
('C', 'H', 'W')


>>> x = torch.randn(3, 3, 3, 3, names=('N', 'C', 'H', 'W'))
>>> x.sum(['N', 'C']).names
('H', 'W')


## Reduction ops with keepdim=True don't actually remove dimensions.
>>> x = torch.randn(3, 3, 3, 3, names=('N', 'C', 'H', 'W'))
>>> x.sum(['N', 'C'], keepdim=True).names
('N', 'C', 'H', 'W')

統(tǒng)一輸入中的名稱

所有二進制算術(shù)運算都遵循此規(guī)則。 廣播操作仍然從右側(cè)進行位置廣播,以保持與未命名張量的兼容性。 要通過名稱執(zhí)行顯式廣播,請使用 Tensor.align_as() 。

  • 檢查名稱:所有名稱都必須從右側(cè)位置匹配。 即,在tensor + other中,對于(-min(tensor.dim(), other.dim()) + 1, -1]中的所有i,match(tensor.names[i], other.names[i])必須為 true。
  • 檢查名稱:此外,所有命名的尺寸必須從右對齊。 在匹配期間,如果我們將命名尺寸A與未命名尺寸None匹配,則A不得出現(xiàn)在具有未命名尺寸的張量中。
  • 傳播名稱:從兩個張量的右邊開始統(tǒng)一名稱對,以產(chǎn)生輸出名稱。

例如,

## tensor: Tensor[   N, None]
## other:  Tensor[None,    C]
>>> tensor = torch.randn(3, 3, names=('N', None))
>>> other = torch.randn(3, 3, names=(None, 'C'))
>>> (tensor + other).names
('N', 'C')

檢查姓名:

  • match(tensor.names[-1], other.names[-1])True
  • match(tensor.names[-2], tensor.names[-2])True
  • 由于我們將 tensor 中的None'C'匹配,因此請確保 tensor 中不存在'C'
  • 檢查以確保other中不存在'N'(不存在)。

最后,使用[unify('N', None), unify(None, 'C')] = ['N', 'C']計算輸出名稱

更多示例:

## Dimensions don't match from the right:
## tensor: Tensor[N, C]
## other:  Tensor[   N]
>>> tensor = torch.randn(3, 3, names=('N', 'C'))
>>> other = torch.randn(3, names=('N',))
>>> (tensor + other).names
RuntimeError: Error when attempting to broadcast dims ['N', 'C'] and dims
['N']: dim 'C' and dim 'N' are at the same position from the right but do
not match.


## Dimensions aren't aligned when matching tensor.names[-1] and other.names[-1]:
## tensor: Tensor[N, None]
## other:  Tensor[      N]
>>> tensor = torch.randn(3, 3, names=('N', None))
>>> other = torch.randn(3, names=('N',))
>>> (tensor + other).names
RuntimeError: Misaligned dims when attempting to broadcast dims ['N'] and
dims ['N', None]: dim 'N' appears in a different position from the right
across both lists.

注意

在最后兩個示例中,可以通過名稱對齊張量,然后執(zhí)行加法。 使用 Tensor.align_as() 按名稱對齊張量,或使用 Tensor.align_to() 將張量對齊到自定義尺寸順序。

排列尺寸

某些操作(例如 Tensor.t())會置換尺寸順序。 維度名稱附加到各個維度,因此也可以排列。

如果操作員輸入位置索引dim,它也可以采用尺寸名稱作為dim

  • 檢查名稱:如果將dim作為名稱傳遞,請檢查其是否在張量中存在。
  • 傳播名稱:以與要排列的維相同的方式排列維名稱。

>>> x = torch.randn(3, 3, names=('N', 'C'))
>>> x.transpose('N', 'C').names
('C', 'N')

收縮消失

矩陣乘法函數(shù)遵循此方法的某些變體。 讓我們先通過 torch.mm() ,然后概括一下批矩陣乘法的規(guī)則。

對于torch.mm(tensor, other)

  • Check names: None
  • 傳播名稱:結(jié)果名稱為(tensor.names[-2], other.names[-1])。

>>> x = torch.randn(3, 3, names=('N', 'D'))
>>> y = torch.randn(3, 3, names=('in', 'out'))
>>> x.mm(y).names
('N', 'out')

本質(zhì)上,矩陣乘法在二維上執(zhí)行點積運算,使它們折疊。 當(dāng)兩個張量矩陣相乘時,收縮尺寸消失,并且不出現(xiàn)在輸出張量中。

torch.mv() , torch.dot() 的工作方式類似:名稱推斷不會檢查輸入名稱,并且會刪除點積所涉及的尺寸:

>>> x = torch.randn(3, 3, names=('N', 'D'))
>>> y = torch.randn(3, names=('something',))
>>> x.mv(y).names
('N',)

現(xiàn)在,讓我們看一下torch.matmul(tensor, other)。 假設(shè)tensor.dim() >= 2other.dim() >= 2。

  • 檢查名稱:檢查輸入的批次尺寸是否對齊并可以廣播。 請參見統(tǒng)一輸入的名稱,以了解對齊輸入的含義。
  • 傳播名稱:結(jié)果名稱是通過統(tǒng)一批次尺寸并刪除合同規(guī)定的尺寸獲得的:unify(tensor.names[:-2], other.names[:-2]) + (tensor.names[-2], other.names[-1])。

例子:

## Batch matrix multiply of matrices Tensor['C', 'D'] and Tensor['E', 'F'].
## 'A', 'B' are batch dimensions.
>>> x = torch.randn(3, 3, 3, 3, names=('A', 'B', 'C', 'D))
>>> y = torch.randn(3, 3, 3, names=('B', 'E', 'F))
>>> torch.matmul(x, y).names
('A', 'B', 'C', 'F')

最后,還有許多功能的融合add版本。 即 addmm()addmv() 。 這些被視為構(gòu)成 mm() 的名稱推斷和 add() 的命名推斷。

工廠功能

現(xiàn)在,工廠函數(shù)采用新的names參數(shù),該參數(shù)將名稱與每個維度相關(guān)聯(lián)。

>>> torch.zeros(2, 3, names=('N', 'C'))
tensor([[0., 0., 0.],
        [0., 0., 0.]], names=('N', 'C'))

輸出功能和就地變型

指定為out=張量的張量具有以下行為:

  • 如果沒有命名維,則將從操作中計算出的名稱傳播到其中。
  • 如果它具有任何命名維,則從該操作計算出的名稱必須與現(xiàn)有名稱完全相同。 否則,操作錯誤。

所有就地方法都會將輸入修改為具有與根據(jù)名稱推斷計算出的名稱相同的名稱。 例如,

>>> x = torch.randn(3, 3)
>>> y = torch.randn(3, 3, names=('N', 'C'))
>>> x.names
(None, None)


>>> x += y
>>> x.names
('N', 'C')
以上內(nèi)容是否對您有幫助:
在線筆記
App下載
App下載

掃描二維碼

下載編程獅App

公眾號
微信公眾號

編程獅公眾號