在Tensorflow中生成曲面

7vhp5slm  于 2022-11-25  发布在  其他
关注(0)|答案(1)|浏览(146)

我有一个形状为[sampling_size* sampling_size, 2]的2D网格。我使用它在Tensorflow中生成3D表面,如下所示:

def cube(G):

    res = []

    for (X, Y) in G:
        if X >= -1 and X < 1 and Y >= -1 and Y < 1:
            res.append(1.)
        else:
            res.append(0.)

    return tf.convert_to_tensor(res)
    
Z_cube = cube(grid)

cube_2d = tf.reshape(Z_cube, [sampling_size, sampling_size])
plot_surface(X, Y, cube_2d)

下面是另一个例子:

def prism(G):

    res = []

    for (X, Y) in G:
        if X >= -1 and X < 1 and Y >= -1 and Y < 1:
            res.append(X + 1.)
        else:
            res.append(0.)

    return tf.convert_to_tensor(res)

Z_prism = prism(grid)

prism_2d = tf.reshape(Z_prism, [sampling_size, sampling_size])

我的问题是:由于这使用循环,所以这种方法效率不高,生成一个立方体需要10秒。
我想知道是否有人知道一种更有效的矢量化方法来生成这些表面。
编辑:我使用以下代码生成网格

sampling_size = 100
limit = math.pi

def generate_grid(_from, _to, _step):

    range_ = tf.range(_from, _to, _step, dtype=float)

    x, y = tf.meshgrid(range_, range_)

    _x = tf.reshape(x, (-1,1))
    _y = tf.reshape(y, (-1,1))

    return tf.squeeze(tf.stack([_x, _y], axis=-1)), x, y

grid, X, Y = generate_grid(-limit, limit, 2*limit / sampling_size)

而对于标图:

import matplotlib.pyplot as plt
from matplotlib import cm

def plot_surface(X, Y, Z, a = 30, b = 15):
    fig = plt.figure()
    ax = plt.axes(projection='3d')

    ax.plot_surface(X, Y, Z, rstride=3, cstride=3, linewidth=1, antialiased=True,
                    cmap=cm.viridis)
    ax.view_init(a, b)
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')

    plt.show()
jdgnovmf

jdgnovmf1#

您要寻找的是tf.where的多路复用模式。根据条件,选择应该从TensorA还是TensorB中获取元素。
然后,您可以按照以下方式重写棱镜函数:

def tf_prism(G):
    X,Y = tf.unstack(G, axis=-1)
    # Here, the operator '&' replaces 'tf.math.logical_and'
    # Do not use the keyword 'and' it will not work
    return tf.where(
        (X >= -1) & (X < 1) & (Y >= -1) & (Y < 1),
        X + 1,
        0
    )

比较执行速度与timeit:

[1]: %timeit tf_prism(grid)
373 µs ± 3.67 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
[2]: %timeit prism(grid) 
6.47 s ± 127 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

相关问题