张量的索引和切片
索引
- 索引是指张量中元素所在的位置,我们可以通过索引来访问或修改张量中的元素。
- 在正向索引中,张量中第一个元素的索引值为
0
(注意索引值为0
才表示第一个元素,这与我们日常数学中的概念有所区别);第二个元素的索引为1
,以此类推。 - 在反向索引中,倒数第一个元素的索引值为
-1
;倒数第二个元素的索引值为-2
,以此类推。
假设有一个图片张量,图片总数为 6
张,图像宽和高分别为 28
和 28
,且为彩色图像(3
通道),那么该图片张量的形状为 (6, 28, 28, 3)
。我们用正态分布来模拟该张量:
import tensorflow as tf
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
image_tensor = tf.random.normal([6, 28, 28, 3]) # 创建尺寸为 (6, 28, 28, 3) 的图片张量
接下来,我们利用索引来访问图片张量中的元素:
print(image_tensor[0]) # 输出第一张图片
tf.Tensor(
[[[-0.99267304 1.3366516 -0.8440698 ]
[ 1.0535443 0.01591775 -0.20434138]
[-0.27021262 -0.30959526 1.0631086 ]
...
[-1.008951 -0.95502394 1.4857111 ]
[ 0.2602379 -1.2565975 -1.5890635 ]
[ 1.2195483 -0.39602706 0.01813171]]], shape=(28, 28, 3), dtype=float32)
我们也可以通过反向索引来访问最后一张图片,并对其进行验证:
由于总共只有
6
张图片,那么image_tensor[5]
表示索引值为5
的图片数据,对应第6
张图片;根据反向索引规则,image_tensor[-1]
也表示最后一张图片(第6
张图片)。
print(image_tensor[5] == image_tensor[-1]) # 判断两个张量是否相等
tf.Tensor(
[[[ True True True]
[ True True True]
[ True True True]
...
[ True True True]
[ True True True]
[ True True True]]], shape=(28, 28, 3), dtype=bool)
访问二维数组时,我们需要两个索引:一个表示行(row
),一个表示列(column
)。在 TensorFlow
中,访问二维数组的方式有两种,一种类似于 C/C++
中的索引格式,另一种类似于 Numpy
。例如,我们可以通过以下方式,访问第 1
张图片的第 2
行数据:
print(image_tensor[0][1]) # 输出第一张图片的第二行
tf.Tensor(
[[-1.2002944 0.12152377 0.2234562 ]
[-2.709574 1.2775909 1.4541649 ]
...
[-0.33908963 1.0630795 1.558053 ]
[-0.17027545 0.8788766 0.87936914]], shape=(28, 3), dtype=float32)
我们也可以将上述代码中的两个维度合并到一个括号中,并用逗号隔开,来实现二维数组的访问,其中 image_tensor[0][1]
和 image_tensor[0, 1]
是等价的:
print(image_tensor[0][1] == image_tensor[0, 1]) # 判断两个张量是否相等
tf.Tensor(
[[ True True True]
[ True True True]
...
[ True True True]
[ True True True]], shape=(28, 3), dtype=bool)
切片
切片表示以给定的规则获取张量中的部分数据,我们可以利用切片来访问或者修改张量中的元素。
切片的语法为 [start:stop:step]
,start
表示开始的索引值,end
表示结束的索引值(不包含 end
位置),step
表示步长(即每次访问的间隔)
例如,比如我们要访问第 2
张到第 4
张图片,我们可以这么写:
print(image_tensor[1:4]) # 输出第二到第四张图片
tf.Tensor(
[[[[-6.1072814e-01 -2.4276777e-01 -1.2022603e+00]
[-3.9724684e-01 -1.7809823e-02 1.0459790e+00]
[-4.6726876e-01 -1.9859287e+00 -1.4355340e+00]
...
[-5.0956316e-02 7.6612139e-01 8.5568899e-01]
[-1.1918964e+00 -1.2277472e+00 -4.8515704e-01]
[ 2.0091980e+00 -3.1338432e-01 -1.2111558e+00]]]], shape=(3, 28, 28, 3), dtype=float32)
上述代码中省略了 step
,则默认 step
的值为 1
,这是切片的一种省略方式。
我们也可以省略 start
,end
和 step
,这样表示提取该维度的所有元素。
例如,image_tensor[1, ::]
表示第 2
张图片的所有行数据,::
表示提取行维度上的所有元素,等价于 image_tensor[1]
:
print(image_tensor[1, ::] == image_tensor[1]) # 判断两个张量是否相等
tf.Tensor(
[[[ True True True]
[ True True True]
[ True True True]
...
[ True True True]
[ True True True]
[ True True True]]], shape=(28, 28, 3), dtype=bool)
切片参数省略方式总结如下:
切片参数方式 | 含义描述 |
---|---|
start:end:step |
从 start 开始,到 end 结束(不包括 end ),步长为 step |
start:end |
从 start 开始,到 end 结束(不包括 end ),步长为 1 |
start: |
从 start 开始到后续所有元素,步长为 1 |
start::step |
从 start 开始到后续所有元素,步长为 step |
:end:step |
从 0 开始,到 end 结束(不包括 end ),步长为 step |
:end |
从 0 开始,到 end 结束(不包括 end ),步长为 1 |
::step |
从 0 开始到后续所有元素,步长为 step |
:: |
所有元素 |
: |
所有元素 |
若张量的维度较多,可以采用单个冒号 :
来提取该维度的所有元素。例如,我们想提取图片张量中的蓝色通道,可以这么写:
print(image_tensor[:, :, :, 2]) # 输出第三个通道的所有像素
tf.Tensor(
[[[-0.8440698 -0.20434138 1.0631086 ... 1.4857111 -1.5890635
0.01813171]
[ 0.2234562 1.4541649 1.558053 ... 0.2963377 0.8863706
0.87936914]
[-1.1262956 -1.4410621 -0.64045465 ... -0.6537862 0.47904253
1.2033157 ]
...
[-0.5564168 -0.3602578 -0.07460994 ... -0.596759 -0.4439934
-1.2628192 ]
[ 0.04733002 1.7597612 0.7633362 ... -0.03441245 0.799629
-0.20505095]
[ 0.9229313 -1.3190516 1.7249115 ... -1.8670495 -0.5548941
-1.2218807 ]]], shape=(6, 28, 28), dtype=float32)
为了避免出现冒号过多的情况,我们也可以采用省略号 ...
的方式来读取相邻多个维度上的张量,例如,我们想读取前两张图片中的绿通道和蓝通道,可以这么写:
print(image_tensor[:2, ..., 1:]) # 输出前两张图片中的绿色和蓝色通道
tf.Tensor(
[[[[ 1.33665156e+00 -8.44069779e-01]
[ 1.59177464e-02 -2.04341382e-01]
[-3.09595257e-01 1.06310856e+00]
...
[-9.55023944e-01 1.48571110e+00]
[-1.25659752e+00 -1.58906353e+00]
[-3.96027058e-01 1.81317069e-02]]]], shape=(2, 28, 28, 2), dtype=float32)
省略号切片方式总结如下:
切片方式 | 切片含义描述 |
---|---|
m, ..., n |
省略号 ... 表示:最左边以 m 为界,最右边以 n 为界,中间维度全部包含,其他维度按照 m 和 n 的格式进行读取 |
m, ... |
省略号 ... 表示:最左边以 m 为界,m 之后维度全部包含,其他维度按照 m 的格式进行读取 |
..., n |
省略号 ... 表示:最右边以 n 为界,n 之前维度全部包含,其他维度按照 n 的格式进行读取 |
... |
省略号 ... 表示:读取所有维度数据 |