TensorFlow张量索引和切片

TensorFlow张量索引和切片

张量的索引和切片

索引

  • 索引是指张量中元素所在的位置,我们可以通过索引来访问或修改张量中的元素。
  • 正向索引中,张量中第一个元素的索引值为 0(注意索引值为 0 才表示第一个元素,这与我们日常数学中的概念有所区别);第二个元素的索引为 1,以此类推。
  • 反向索引中,倒数第一个元素的索引值为 -1;倒数第二个元素的索引值为 -2,以此类推。

假设有一个图片张量,图片总数为 6 张,图像宽和高分别为 2828,且为彩色图像(3 通道),那么该图片张量的形状为 (6, 28, 28, 3)。我们用正态分布来模拟该张量:

import tensorflow as tf
import os

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

image_tensor = tf.random.normal([6, 28, 28, 3])  # 创建尺寸为 (6, 28, 28, 3) 的图片张量

接下来,我们利用索引来访问图片张量中的元素:

print(image_tensor[0])  # 输出第一张图片
tf.Tensor(
    [[[-0.99267304  1.3366516  -0.8440698 ]
      [ 1.0535443   0.01591775 -0.20434138]
      [-0.27021262 -0.30959526  1.0631086 ]
      ...
      [-1.008951   -0.95502394  1.4857111 ]
      [ 0.2602379  -1.2565975  -1.5890635 ]
      [ 1.2195483  -0.39602706  0.01813171]]], shape=(28, 28, 3), dtype=float32)

我们也可以通过反向索引来访问最后一张图片,并对其进行验证:

由于总共只有 6 张图片,那么 image_tensor[5] 表示索引值为 5 的图片数据,对应第 6 张图片;根据反向索引规则,image_tensor[-1] 也表示最后一张图片(第 6 张图片)。

print(image_tensor[5] == image_tensor[-1])  # 判断两个张量是否相等
tf.Tensor(
  [[[ True  True  True]
      [ True  True  True]
      [ True  True  True]
      ...
      [ True  True  True]
      [ True  True  True]
      [ True  True  True]]], shape=(28, 28, 3), dtype=bool)

访问二维数组时,我们需要两个索引:一个表示行(row),一个表示列(column)。在 TensorFlow 中,访问二维数组的方式有两种,一种类似于 C/C++ 中的索引格式,另一种类似于 Numpy。例如,我们可以通过以下方式,访问第 1 张图片的第 2 行数据:

print(image_tensor[0][1])  # 输出第一张图片的第二行
tf.Tensor(
    [[-1.2002944   0.12152377  0.2234562 ]
     [-2.709574    1.2775909   1.4541649 ]
     ...
     [-0.33908963  1.0630795   1.558053  ]
     [-0.17027545  0.8788766   0.87936914]], shape=(28, 3), dtype=float32)

我们也可以将上述代码中的两个维度合并到一个括号中,并用逗号隔开,来实现二维数组的访问,其中 image_tensor[0][1]image_tensor[0, 1] 是等价的:

print(image_tensor[0][1] == image_tensor[0, 1])  # 判断两个张量是否相等
tf.Tensor(
    [[ True  True  True]
     [ True  True  True]
     ...
     [ True  True  True]
     [ True  True  True]], shape=(28, 3), dtype=bool)

切片

切片表示以给定的规则获取张量中的部分数据,我们可以利用切片来访问或者修改张量中的元素。

切片的语法为 [start:stop:step]start 表示开始的索引值,end 表示结束的索引值(不包含 end 位置),step 表示步长(即每次访问的间隔)

例如,比如我们要访问第 2 张到第 4 张图片,我们可以这么写:

print(image_tensor[1:4]) # 输出第二到第四张图片
tf.Tensor(
    [[[[-6.1072814e-01 -2.4276777e-01 -1.2022603e+00]
       [-3.9724684e-01 -1.7809823e-02  1.0459790e+00]
       [-4.6726876e-01 -1.9859287e+00 -1.4355340e+00]
       ...
       [-5.0956316e-02  7.6612139e-01  8.5568899e-01]
       [-1.1918964e+00 -1.2277472e+00 -4.8515704e-01]
       [ 2.0091980e+00 -3.1338432e-01 -1.2111558e+00]]]], shape=(3, 28, 28, 3), dtype=float32)

上述代码中省略了 step,则默认 step 的值为 1,这是切片的一种省略方式。

我们也可以省略 startendstep,这样表示提取该维度的所有元素。

例如,image_tensor[1, ::] 表示第 2 张图片的所有行数据,:: 表示提取行维度上的所有元素,等价于 image_tensor[1]

print(image_tensor[1, ::] == image_tensor[1])  # 判断两个张量是否相等
tf.Tensor(
    [[[ True  True  True]
      [ True  True  True]
      [ True  True  True]
      ...
      [ True  True  True]
      [ True  True  True]
      [ True  True  True]]], shape=(28, 28, 3), dtype=bool)

切片参数省略方式总结如下:

切片参数方式 含义描述
start:end:step start 开始,到 end 结束(不包括 end),步长为 step
start:end start 开始,到 end 结束(不包括 end),步长为 1
start: start 开始到后续所有元素,步长为 1
start::step start 开始到后续所有元素,步长为 step
:end:step 0 开始,到 end 结束(不包括 end),步长为 step
:end 0 开始,到 end 结束(不包括 end),步长为 1
::step 0 开始到后续所有元素,步长为 step
:: 所有元素
: 所有元素

若张量的维度较多,可以采用单个冒号 : 来提取该维度的所有元素。例如,我们想提取图片张量中的蓝色通道,可以这么写:

print(image_tensor[:, :, :, 2])  # 输出第三个通道的所有像素
tf.Tensor(
    [[[-0.8440698  -0.20434138  1.0631086  ...  1.4857111  -1.5890635
        0.01813171]
      [ 0.2234562   1.4541649   1.558053   ...  0.2963377   0.8863706
        0.87936914]
      [-1.1262956  -1.4410621  -0.64045465 ... -0.6537862   0.47904253
        1.2033157 ]
      ...
      [-0.5564168  -0.3602578  -0.07460994 ... -0.596759   -0.4439934
       -1.2628192 ]
      [ 0.04733002  1.7597612   0.7633362  ... -0.03441245  0.799629
       -0.20505095]
      [ 0.9229313  -1.3190516   1.7249115  ... -1.8670495  -0.5548941
       -1.2218807 ]]], shape=(6, 28, 28), dtype=float32)

为了避免出现冒号过多的情况,我们也可以采用省略号 ... 的方式来读取相邻多个维度上的张量,例如,我们想读取前两张图片中的绿通道和蓝通道,可以这么写:

print(image_tensor[:2, ..., 1:])  # 输出前两张图片中的绿色和蓝色通道
tf.Tensor(
    [[[[ 1.33665156e+00 -8.44069779e-01]
       [ 1.59177464e-02 -2.04341382e-01]
       [-3.09595257e-01  1.06310856e+00]
       ...
       [-9.55023944e-01  1.48571110e+00]
       [-1.25659752e+00 -1.58906353e+00]
       [-3.96027058e-01  1.81317069e-02]]]], shape=(2, 28, 28, 2), dtype=float32)

省略号切片方式总结如下:

切片方式 切片含义描述
m, ..., n 省略号 ... 表示:最左边以 m 为界,最右边以 n 为界,中间维度全部包含,其他维度按照 mn 的格式进行读取
m, ... 省略号 ... 表示:最左边以 m 为界,m 之后维度全部包含,其他维度按照 m 的格式进行读取
..., n 省略号 ... 表示:最右边以 n 为界,n 之前维度全部包含,其他维度按照 n 的格式进行读取
... 省略号 ... 表示:读取所有维度数据

Copyright: 采用 知识共享署名4.0 国际许可协议进行许可

Links: https://cangmang.xyz/articles/1648196494511