Keras / TensorFlow Recommenders：加载已保存的排序模型后进行预测时出错

laximzn5  于 2023-01-17  发布在  其他
关注(0)|答案(1)|浏览(179)

我用 TensorFlow Recommenders 构建了一个排序模型。直接用这个模型进行预测时，一切正常。
但如果先保存、再加载该模型，用加载后的模型进行预测时就会报错。下面是我的模型代码和报错信息。
排序模型:

class HMRankingModel(tfrs.Model):
    """Ranking model built with TensorFlow Recommenders (``tfrs``).

    The model has two towers:

    * a customer tower (``customer_id`` + ``age``);
    * an article tower (``article_id``).

    Each tower is projected by its own Dense layer. The two projections are
    combined with an element-wise product and passed through two Dense
    layers, producing one score per example. Training uses mean squared
    error through ``tfrs.tasks.Ranking``.

    Relies on module-level names defined outside this class:
    ``unique_customer_ids``, ``unique_article_ids``, ``age_buckets``,
    ``embedding_dimension``, ``activation`` and ``dense_size``.
    """

    def __init__(self):
        super().__init__()
        
        # Customer model
        # Customer id string -> integer index -> embedding vector.
        self.customer_input = tf.keras.Input(shape=(1,), dtype=tf.string, name='customer_input')
        # NOTE(review): the "+ 1" in the Embedding size presumably covers the
        # single out-of-vocabulary index StringLookup reserves by default
        # (mask_token=None means no extra mask index) -- confirm.
        self.customer_sl = tf.keras.layers.StringLookup(vocabulary=unique_customer_ids, mask_token=None, name='customer_string_lookup')(self.customer_input)
        # Input shape is (1,), so the embedding output carries a length-1 axis 1;
        # squeeze it away so the vector can be concatenated with the age embedding.
        self.customer_embedding = tf.squeeze(tf.keras.layers.Embedding(len(unique_customer_ids) + 1, embedding_dimension, name='customer_emb')(self.customer_sl), axis=1)

        # NOTE(review): no dtype is given, so this input uses Keras' default float
        # dtype (float32 unless floatx was changed). The error raised by the
        # reloaded SavedModel shows its signatures expect ``age`` as float32 and
        # reject an int64 array -- consider ``dtype=tf.float32`` to make the
        # contract explicit, and feed float32 ages at inference time.
        self.age_input = tf.keras.Input(shape=(1,), name='age_input')
        # Bucketise raw ages using the boundaries in ``age_buckets``; N boundaries
        # yield N + 1 bucket indices, hence ``len(age_buckets) + 1`` embedding rows.
        self.age_discretization = tf.keras.layers.Discretization(age_buckets.tolist(), name='age_discretization')(self.age_input)
        self.age_embedding = tf.squeeze(tf.keras.layers.Embedding(len(age_buckets) + 1, embedding_dimension, name='age_embedding')(self.age_discretization), axis=1)
        
        # Join customer-id and age embeddings, then project to embedding_dimension
        # so the result has the same width as the article tower for Multiply below.
        self.customer_merged = tf.keras.layers.concatenate([self.customer_embedding, self.age_embedding], axis=-1, name='customer_merged')
        self.customer_dense = tf.keras.layers.Dense(embedding_dimension, activation=activation, name='customer_dense')(self.customer_merged)
        
        
        # Article model
        # Article id string -> integer index -> embedding vector (mask_token left at
        # the StringLookup default here, unlike the customer lookup above).
        self.article_input = tf.keras.Input(shape=(1,), dtype=tf.string, name='article_input')
        self.article_sl = tf.keras.layers.StringLookup(vocabulary=unique_article_ids, name='article_string_lookup')(self.article_input)
        self.article_final = tf.squeeze(tf.keras.layers.Embedding(len(unique_article_ids)+1, embedding_dimension, name='article_emb')(self.article_sl), axis=1)

        self.article_dense = tf.keras.layers.Dense(embedding_dimension, activation=activation, name='article_dense')(self.article_final)        
        

        # Multiply model
        # Element-wise product of the two tower outputs, a hidden Dense layer, and
        # a 1-unit Dense layer with no activation that produces the score.
        self.towers_multiplied = tf.keras.layers.Multiply(name='towers_multiplied')([self.customer_dense, self.article_dense])
        self.towers_dense = tf.keras.layers.Dense(dense_size, activation=activation, name='towers_dense1')(self.towers_multiplied)
        self.output_node = tf.keras.layers.Dense(1, name='output_node')(self.towers_dense)
        
        
        # Model definition
        # Functional model whose dict inputs are keyed by feature name; ``call``
        # forwards the matching entries of the feature dict to it.
        self.model = tf.keras.Model(inputs={'customer_id': self.customer_input, 
                                            'article_id': self.article_input,
                                            'age': self.age_input,
                                            }, 
                                    outputs=self.output_node)
        
        # Regression-style ranking task: MSE loss, RMSE reported as a metric.
        self.task = tfrs.tasks.Ranking(
            loss = tf.keras.losses.MeanSquaredError(),
            metrics=[tf.keras.metrics.RootMeanSquaredError()]
        )
        
    def call(self, features):
        """Return predicted scores for a batch of features.

        Only the ``"customer_id"``, ``"article_id"`` and ``"age"`` entries of
        ``features`` are passed to the underlying model; other keys are ignored.
        """
        return self.model({'customer_id': features["customer_id"], 
                           'article_id': features["article_id"], 
                           'age': features["age"], 
                           })   
        
    def compute_loss(self, features_dict, training=False):
        """Compute the ranking loss for one batch.

        Pops the ``"count"`` label out of ``features_dict``, which mutates the
        dict that was passed in. Predicts from the remaining features and
        returns the task's loss. ``training`` is accepted to match the tfrs
        signature but is not used here.
        """
        labels = features_dict.pop("count")
        predictions = self(features_dict)
        return self.task(labels=labels, predictions=predictions)
# Build, compile and train the ranking model on the cached train/validation data.
ranking_model = HMRankingModel()
ranking_model.compile(
    optimizer=tf.keras.optimizers.Adagrad(learning_rate=0.1),
)
ranking_model.fit(
    x=cached_train,
    validation_data=cached_validation,
    epochs=epochs,
)

用原始（未保存）模型进行预测：

# Score one (customer, article, age) example with the in-memory model.
# Here an int64 age array is accepted and a prediction is returned (see output below).
ranking_model(dict(
    customer_id=np.array(["18b3a4767533c8f1f6ff274b57ca200939c9fda3992c5bb3b50b31dc6d6b1ee5"]),
    age=np.array([29]),
    article_id=np.array(['562245059']),
))

输出：

<tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[1.3872527]], dtype=float32)>

然后保存模型：

# Export the trained model in SavedModel format to ranking_model_path.
tf.saved_model.save(obj=ranking_model, export_dir=ranking_model_path)

加载模型并进行预测时，出现以下错误：

# Reload the exported model and score one example with it.
# NOTE: as written, this call raises
#   "ValueError: Could not find matching concrete function ..."
# The loaded object only exposes the concrete functions saved with it, and all
# of their signatures expect 'age' as float32 (TensorSpec(..., dtype=tf.float32)),
# while np.array([29]) is int64. Passing e.g. np.array([29], dtype=np.float32)
# matches the saved signature.
saved_ranking_model = tf.saved_model.load(ranking_model_path)
predictions = saved_ranking_model({
    'customer_id': np.array(["18b3a4767533c8f1f6ff274b57ca200939c9fda3992c5bb3b50b31dc6d6b1ee5"]), 
    'age': np.array([29]), 
    'article_id': np.array(['141661025'])
})

输出:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Input In [31], in <cell line: 1>()
----> 1 predictions = saved_ranking_model({
      2     'customer_id': np.array(["18b3a4767533c8f1f6ff274b57ca200939c9fda3992c5bb3b50b31dc6d6b1ee5"]), 
      3     'age': np.array([29]), 
      4     'article_id': np.array(['141661025'])
      5 })

File ~/miniconda3/envs/tf/lib/python3.9/site-packages/tensorflow/python/saved_model/load.py:686, in _call_attribute(instance, *args, **kwargs)
    685 def _call_attribute(instance, *args, **kwargs):
--> 686   return instance.__call__(*args, **kwargs)

File ~/miniconda3/envs/tf/lib/python3.9/site-packages/tensorflow/python/util/traceback_utils.py:153, in filter_traceback.<locals>.error_handler(*args, **kwargs)
    151 except Exception as e:
    152   filtered_tb = _process_traceback_frames(e.__traceback__)
--> 153   raise e.with_traceback(filtered_tb) from None
    154 finally:
    155   del filtered_tb

File ~/miniconda3/envs/tf/lib/python3.9/site-packages/tensorflow/python/saved_model/function_deserialization.py:286, in recreate_function.<locals>.restored_function_body(*args, **kwargs)
    282   positional, keyword = concrete_function.structured_input_signature
    283   signature_descriptions.append(
    284       "Option {}:\n  {}\n  Keyword arguments: {}"
    285       .format(index + 1, _pretty_format_positional(positional), keyword))
--> 286 raise ValueError(
    287     "Could not find matching concrete function to call loaded from the "
    288     f"SavedModel. Got:\n  {_pretty_format_positional(args)}\n  Keyword "
    289     f"arguments: {kwargs}\n\n Expected these arguments to match one of the "
    290     f"following {len(saved_function.concrete_functions)} option(s):\n\n"
    291     f"{(chr(10)+chr(10)).join(signature_descriptions)}")

ValueError: Could not find matching concrete function to call loaded from the SavedModel. Got:
  Positional arguments (2 total):
    * {'age': <tf.Tensor 'features:0' shape=(1,) dtype=int64>,
 'article_id': <tf.Tensor 'features_1:0' shape=(1,) dtype=string>,
 'customer_id': <tf.Tensor 'features_2:0' shape=(1,) dtype=string>}
    * False
  Keyword arguments: {}

 Expected these arguments to match one of the following 4 option(s):

Option 1:
  Positional arguments (2 total):
    * {'age': TensorSpec(shape=(None,), dtype=tf.float32, name='age'),
 'article_id': TensorSpec(shape=(None,), dtype=tf.string, name='article_id'),
 'customer_id': TensorSpec(shape=(None,), dtype=tf.string, name='customer_id')}
    * False
  Keyword arguments: {}

Option 2:
  Positional arguments (2 total):
    * {'age': TensorSpec(shape=(None,), dtype=tf.float32, name='features/age'),
 'article_id': TensorSpec(shape=(None,), dtype=tf.string, name='features/article_id'),
 'customer_id': TensorSpec(shape=(None,), dtype=tf.string, name='features/customer_id')}
    * False
  Keyword arguments: {}

Option 3:
  Positional arguments (2 total):
    * {'age': TensorSpec(shape=(None,), dtype=tf.float32, name='features/age'),
 'article_id': TensorSpec(shape=(None,), dtype=tf.string, name='features/article_id'),
 'customer_id': TensorSpec(shape=(None,), dtype=tf.string, name='features/customer_id')}
    * True
  Keyword arguments: {}

Option 4:
  Positional arguments (2 total):
    * {'age': TensorSpec(shape=(None,), dtype=tf.float32, name='age'),
 'article_id': TensorSpec(shape=(None,), dtype=tf.string, name='article_id'),
 'customer_id': TensorSpec(shape=(None,), dtype=tf.string, name='customer_id')}
    * True
  Keyword arguments: {}

如果从模型中移除年龄特征以及所有与年龄相关的层，保存后再加载的模型就能正常预测。我猜问题出在年龄相关的层上，但找不到具体原因。

fquxozlt

fquxozlt1#

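问题出在输入数据的类型上。报错信息中列出的 SavedModel 具体函数签名要求 “age” 为 float32（`TensorSpec(shape=(None,), dtype=tf.float32, name='age')`），而传入的 `np.array([29])` 是 int64，因此找不到匹配的签名。只需将 “age” 的值转换为 float32：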

# Score one example with the reloaded SavedModel. 'age' is built as float32 so
# the input matches the dtype of the saved concrete-function signatures.
predictions = saved_ranking_model(
    {
        'customer_id': np.array(["18b3a4767533c8f1f6ff274b57ca200939c9fda3992c5bb3b50b31dc6d6b1ee5"]),
        'age': np.array([29.0], dtype=np.float32),
        'article_id': np.array(['141661025']),
    }
)

相关问题