tensorflow 如何将Union类型转换为tensor类型?

xpcnnkqh  于 2023-11-21  发布在  其他
关注(0)|答案(1)|浏览(100)

我想将制表符分隔的文本转换为2DTensor对象,以便将数据馈送到CNN中。
什么是正确的方法来做到这一点?
我写了以下内容:

from typing import List, Union, cast
import tensorflow as tf

CellType = Union[str, float, int, bool]
RowType = List[CellType]

# Mapping Python types to TensorFlow data types
TF_DATA_TYPES = {
    str: tf.string,
    float: tf.float32,
    int: tf.int32,
    bool: tf.bool
}

def convert_string_to_tensorflow_object(data_string):
    # Split the string into lines
    linesStringList1d: List[str] = data_string.strip().split('\n')

    # Split each line into columns
    dataStringList2d: List[List[str]] = []
    for line in linesStringList1d:
        rowItem: List[str] = line.split(' ')
        dataStringList2d.append(rowItem)

    # Convert the data to TensorFlow tensors
    listOfRows: List[RowType] = []
    for rowItem in dataStringList2d:
        oneRow: RowType = []
        for stringItem in rowItem:
            oneRow.append(cast(CellType, stringItem))
        listOfRows.append(oneRow)

    # Get the TensorFlow data type based on the Python type of CellType
    tf_data_type = TF_DATA_TYPES[type(CellType)]

    listOfRows = tf.constant(listOfRows, dtype=tf_data_type)

    # Create a TensorFlow dataset
    return listOfRows

if __name__ == "__main__":
    # Example usage
    data_string: str = """
    1 ASN C  7.042   9.118  0.000 1 1 1 1  1  0
    2 LEU H  5.781   5.488  7.470 0 0 0 0  1  0
    3 THR H  5.399   5.166  6.452 0 0 0 0  0  0
    4 GLU H  5.373   4.852  6.069 0 0 0 0  1  0
    5 LEU H  5.423   5.164  6.197 0 0 0 0  2  0
    """

    tensorflow_dataset = convert_string_to_tensorflow_object(data_string)

    print(tensorflow_dataset)

字符串
输出量:

C:\Users\pc\AppData\Local\Programs\Python\Python311\python.exe C:/git/heca_v2~~2/src/cnn_lib/convert_string_to_tensorflow_object.py
Traceback (most recent call last):
  File "C:\git\heca_v2~~2\src\cnn_lib\convert_string_to_tensorflow_object.py", line 51, in <module>
    tensorflow_dataset = convert_string_to_tensorflow_object(data_string)
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\git\heca_v2~~2\src\cnn_lib\convert_string_to_tensorflow_object.py", line 34, in convert_string_to_tensorflow_object
    tf_data_type = TF_DATA_TYPES[type(CellType)]
                   ~~~~~~~~~~~~~^^^^^^^^^^^^^^^^
KeyError: <class 'typing._UnionGenericAlias'>

Process finished with exit code 1


我可以解决这个错误吗?

olmpazwi

olmpazwi1#

你得到的错误是因为type(CellType)没有返回TF_DATA_TYPES字典中的一个键。CellType是一个联合类型,在它上面调用type()将 * 不 * 返回strfloatintbool
不要试图从CellType中查找数据类型,而是检查实际的数据项并将其转换为适当的数据类型。
你期望的最终结果是什么?
二维Tensor
您还可以将行列表转换为TensorFlowTensor。
这将要求所有数据都是相同的数据类型,因此您可能需要决定一种通用的数据类型,该类型可以表示所有数据而不会丢失信息。
由于TensorFlow CNN (Convolutional Neural Network)通常处理数字数据,因此请尝试float
最后,您的代码尝试用空格(' ')分隔每一行,但您提到数据是制表符分隔的。您应该将line.split(' ')更改为line.split('\t')

from typing import List
import tensorflow as tf

def convert_string_to_tensorflow_object(data_string):
    # Split the string into lines
    linesStringList1d: List[str] = data_string.strip().split('\n')

    # Split each line into columns
    dataStringList2d: List[List[str]] = [line.split('\t') for line in linesStringList1d]

    # Convert the string items to float, as CNNs typically work with numeric data
    dataFloatList2d: List[List[float]] = [[float(item) for item in row] for row in dataStringList2d]

    # Convert the data to a TensorFlow tensor
    tensor = tf.constant(dataFloatList2d, dtype=tf.float32)

    return tensor

if __name__ == "__main__":
    # Example usage
    data_string: str = """
    1\tASN\tC\t7.042\t9.118\t0.000\t1\t1\t1\t1\t1\t0
    2\tLEU\tH\t5.781\t5.488\t7.470\t0\t0\t0\t0\t1\t0
    3\tTHR\tH\t5.399\t5.166\t6.452\t0\t0\t0\t0\t0\t0
    4\tGLU\tH\t5.373\t4.852\t6.069\t0\t0\t0\t0\t1\t0
    5\tLEU\tH\t5.423\t5.164\t6.197\t0\t0\t0\t0\t2\t0
    """

    tensorflow_tensor = convert_string_to_tensorflow_object(data_string)
    print(tensorflow_tensor)

字符串
这样,您可以通过制表符拆分每行,将数字的字符串表示转换为float(假设您的CNN可以处理浮点数据),并从2D浮点数列表中创建tf.Tensor

| 2D List of mixed dtypes (as per CellType) |
\------+------------------------------------/
       |
       | Convert all elements to float
       v
| 2D List of float                          |
\------+------------------------------------/
       |
       | Convert to TensorFlow Tensor
       v
| 2D TensorFlow Tensor                      |
\------+------------------------------------/

相关问题