TensorFlow Tensor变换API

tf.stack()

1
2
3
4
5
stack(
values,
axis=0,
name="stack"
)

官方文档:

https://www.tensorflow.org/api_docs/python/tf/stack

定义:

tensorflow/python/ops/array_ops.py

功能:

将由 R 维的 tensor 堆成 R+1 维的 tensor .

说明:

通过沿着 axis 维,将 values 中的 tensor 列表填充到一个比values 中的 tensor 高一维的 tensor中。

给定一个长度为 N ,由 shape(A,B,C)tensor 构成的列表;

如果 axis == 0,输出的 tensorshape(N, A, B, C)

如果 axis == 1,输出的 tensorshape(A, N, B, C)

示例:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
import tensorflow as tf
x = [1, 4]
y = [2, 5]
z = [3, 6]
stack_0 = tf.stack([x, y, z])
print (stack_0.get_shape()) # (3, 2)
stack_1 = tf.stack([x, y, z], axis=1)
print (stack_1.get_shape()) # (2, 3)
with tf.Session() as sess:
stack_0_val, stack_1_val = sess.run([stack_0, stack_1])
print ('stack_0_val:')
print (stack_0_val) # [[1, 4], [2, 5], [3, 6]]
print ('stack_0_val.shape: %s' % str(stack_0_val.shape)) # (3, 2)
print ('stack_1_val:')
print (stack_1_val) # [[1, 2, 3], [4, 5, 6]]
print ('stack_1_val.shape: %s' % str(stack_1_val.shape)) # (2, 3)

输出结果:

1
2
3
4
5
6
7
8
9
10
11
stack_0.get_shape(): (3, 2)
stack_1.get_shape(): (2, 3)
stack_0_val:
[[1 4]
[2 5]
[3 6]]
stack_0_val.shape: (3, 2)
stack_1_val:
[[1 2 3]
[4 5 6]]
stack_1_val.shape: (2, 3)

-------------本文结束感谢您的阅读-------------
坚持整理学习笔记,您的支持将鼓励我继续整理下去!