Keras: why does the loss of my multi-input NN get stuck?

deikduxw posted on 2022-11-13 in Other

I need to build a network with this structure: [structure diagram]
My attempt:

import numpy as np
from keras.models import Model
from keras.layers import Input, Dense, Add

input_array = []
output_array = []
for i in range(14):
    # one small MLP per scalar input feature
    input_layer = Input(shape=(1,))
    hidden1 = Dense(128, activation='relu')(input_layer)
    hidden2 = Dense(128, activation='relu')(hidden1)
    output_layer = Dense(1, activation='relu')(hidden2)

    input_array.append(input_layer)
    output_array.append(output_layer)

# merge the 14 branches by summing their scalar outputs
summation = Add()(output_array)

model = Model(inputs=input_array, outputs=summation)
model.compile(loss='mse', optimizer='Adam')
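What this model computes is y = g_1(x_1) + ... + g_14(x_14), where each g_i is its own 1 -> 128 -> 128 -> 1 MLP. A minimal pure-Python sketch of that forward pass, assuming random placeholder weights (hypothetical values, not Keras's actual initializer or trained weights), makes one property easy to see: since every branch ends in ReLU, the summed output can never be negative.

```python
import random

def relu(z):
    return max(0.0, z)

def dense(inputs, weights, biases, activation):
    # weights[j][k] connects input k to unit j
    return [activation(sum(w * x for w, x in zip(row, inputs)) + b)
            for row, b in zip(weights, biases)]

def random_layer(n_in, n_out, rng):
    # Hypothetical placeholder weights, zero biases
    weights = [[rng.uniform(-0.1, 0.1) for _ in range(n_in)] for _ in range(n_out)]
    return weights, [0.0] * n_out

def make_branch(rng, hidden=128):
    # Dense(128, relu) -> Dense(128, relu) -> Dense(1, relu), as in the question
    return [random_layer(1, hidden, rng),
            random_layer(hidden, hidden, rng),
            random_layer(hidden, 1, rng)]

def branch_forward(branch, x):
    h = [x]
    for weights, biases in branch:
        h = dense(h, weights, biases, relu)
    return h[0]

def model_forward(branches, xs):
    # Equivalent of the Add() layer: sum of the 14 scalar branch outputs
    return sum(branch_forward(b, x) for b, x in zip(branches, xs))

rng = random.Random(0)
branches = [make_branch(rng) for _ in range(14)]
xs = [rng.uniform(-2.0, 2.0) for _ in range(14)]
y = model_forward(branches, xs)
print(y >= 0.0)  # True: a sum of ReLU outputs is always >= 0
```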

I also tested it on dummy data:

test_X = np.array([i for i in range(1, 21)])
ty = np.array([np.exp(i / 10) for i in range(1, 21)])
test_X = test_X.reshape(-1, 1, 1)
model.fit([test_X] * 14, ty, epochs=100, batch_size=14)
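A quick stdlib check of the dummy targets (my own observation, not something stated in the question): every target exp(i/10) is strictly positive, so it lies inside the range of a model whose output is a sum of ReLU values.

```python
import math

# Same targets as the dummy-data test: exp(i / 10) for i = 1..20
ty = [math.exp(i / 10) for i in range(1, 21)]

# exp(.) > 0 everywhere, so every target is reachable by a non-negative output
print(len(ty), min(ty) > 0)  # 20 True
```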

It works fine, but if I try to fit my real training data:

transformed_G = [array([[[1.91459711e+00]],[[1.90613065e+00]],[[1.78386092e+00]],[[1.61354920e+00]],[[1.53859274e+00]],[[1.50765169e+00]],[[1.47722348e+00]],[[1.44736809e+00]],[[1.41812393e+00]],[[1.38951279e+00]],[[1.36154440e+00]],[[1.33422060e+00]],[[1.30753901e+00]],[[1.28149621e+00]],[[1.15764704e+00]],[[9.36670929e-01]],[[8.97971224e-01]],[[8.17983422e-01]],[[5.51106504e-01]],[[8.95766049e-04]]]),array([[[1.36093816]],[[1.37581026]],[[1.40231904]],[[1.30026948]],[[1.23904628]],[[1.21334898]],[[1.18858706]],[[1.16531144]],[[1.14402097]],[[1.12515567]],[[1.10909199]],[[1.09614006]],[[1.08654289]],[[1.08047722]],[[1.10716623]],[[1.5314739]],[[1.97420908]],[[1.90175807]],[[1.28591518]],[[0.00209013]]]),array([[[1.36093816]],[[1.37581026]],[[1.40231904]],[[1.30026948]],[[1.23904628]],[[1.21334898]],[[1.18858706]],[[1.16531144]],[[1.14402097]],[[1.12515567]],[[1.10909199]],[[1.09614006]],[[1.08654289]],[[1.08047722]],[[1.10716623]],[[1.5314739]],[[1.97420908]],[[1.90175807]],[[1.28591518]],[[0.00209013]]]),array([[[1.36093816]],[[1.37581026]],[[1.40231904]],[[1.30026948]],[[1.23904628]],[[1.21334898]],[[1.18858706]],[[1.16531144]],[[1.14402097]],[[1.12515567]],[[1.10909199]],[[1.09614006]],[[1.08654289]],[[1.08047722]],[[1.10716623]],[[1.5314739]],[[1.97420908]],[[1.90175807]],[[1.28591518]],[[0.00209013]]]),array([[[1.08327839]],[[1.09824762]],[[1.15427649]],[[1.12407517]],[[1.10129741]],[[1.09234529]],[[1.08443805]],[[1.07799145]],[[1.07338263]],[[1.07094503]],[[1.07096466]],[[1.07367788]],[[1.07927056]],[[1.08787846]],[[1.18146992]],[[1.73310742]],[[2.25335663]],[[2.17327423]],[[1.46961735]],[[0.00238871]]]),array([[[1.08327839]],[[1.09824762]],[[1.15427649]],[[1.12407517]],[[1.10129741]],[[1.09234529]],[[1.08443805]],[[1.07799145]],[[1.07338263]],[[1.07094503]],[[1.07096466]],[[1.07367788]],[[1.07927056]],[[1.08787846]],[[1.18146992]],[[1.73310742]],[[2.25335663]],[[2.17327423]],[[1.46961735]],[[0.00238871]]]),array([[[1.08327839]],[[1.09824762]],[[1.15427649]],
[[1.12407517]],[[1.10129741]],[[1.09234529]],[[1.08443805]],[[1.07799145]],[[1.07338263]],[[1.07094503]],[[1.07096466]],[[1.07367788]],[[1.07927056]],[[1.08787846]],[[1.18146992]],[[1.73310742]],[[2.25335663]],[[2.17327423]],[[1.46961735]],[[0.00238871]]]),array([[[1.77102136e+00]],[[1.75553355e+00]],[[1.57613838e+00]],[[1.35643362e+00]],[[1.26607554e+00]],[[1.23008574e+00]],[[1.19554821e+00]],[[1.16257070e+00]],[[1.13122940e+00]],[[1.10157316e+00]],[[1.07362746e+00]],[[1.04739830e+00]],[[1.02287566e+00]],[[1.00003676e+00]],[[9.07647047e-01]],[[8.28180617e-01]],[[8.77795042e-01]],[[8.16838440e-01]],[[5.51106508e-01]],[[8.95770636e-04]]]),array([[[1.77102136e+00]],[[1.75553355e+00]],[[1.57613838e+00]],[[1.35643362e+00]],[[1.26607554e+00]],[[1.23008574e+00]],[[1.19554821e+00]],[[1.16257070e+00]],[[1.13122940e+00]],[[1.10157316e+00]],[[1.07362746e+00]],[[1.04739830e+00]],[[1.02287566e+00]],[[1.00003676e+00]],[[9.07647047e-01]],[[8.28180617e-01]],[[8.77795042e-01]],[[8.16838440e-01]],[[5.51106508e-01]],[[8.95770636e-04]]]),array([[[1.77102136e+00]],[[1.75553355e+00]],[[1.57613838e+00]],[[1.35643362e+00]],[[1.26607554e+00]],[[1.23008574e+00]],[[1.19554821e+00]],[[1.16257070e+00]],[[1.13122940e+00]],[[1.10157316e+00]],[[1.07362746e+00]],[[1.04739830e+00]],[[1.02287566e+00]],[[1.00003676e+00]],[[9.07647047e-01]],[[8.28180617e-01]],[[8.77795042e-01]],[[8.16838440e-01]],[[5.51106508e-01]],[[8.95770636e-04]]]),array([[[1.72031091e+00]],[[1.72242849e+00]],[[1.68208421e+00]],[[1.57053628e+00]],[[1.51154613e+00]],[[1.48581045e+00]],[[1.45978826e+00]],[[1.43361747e+00]],[[1.40741667e+00]],[[1.38128732e+00]],[[1.35531613e+00]],[[1.32957744e+00]],[[1.30413568e+00]],[[1.27904771e+00]],[[1.15734846e+00]],[[9.36670934e-01]],[[8.97971229e-01]],[[8.17983426e-01]],[[5.51106508e-01]],[[8.95770636e-04]]]),array([[[1.72031091e+00]],[[1.72242849e+00]],[[1.68208421e+00]],[[1.57053628e+00]],[[1.51154613e+00]],[[1.48581045e+00]],[[1.45978826e+00]],[[1.43361747e+00]],[[1.40741667e+00]],[[1.38128732e+00]],[[1.35531613e+00]],[[1.32957744e+00]],[[1.30413568e+00]],[[1.27904771e+00]],[[1.15734846e+00]],[[9.36670934e-01]],[[8.97971229e-01]],[[8.17983426e-01]],[[5.51106508e-01]],[[8.95770636e-04]]]),array([[[1.72031091e+00]],[[1.72242849e+00]],[[1.68208421e+00]],[[1.57053628e+00]],[[1.51154613e+00]],[[1.48581045e+00]],[[1.45978826e+00]],[[1.43361747e+00]],[[1.40741667e+00]],[[1.38128732e+00]],[[1.35531613e+00]],[[1.32957744e+00]],[[1.30413568e+00]],[[1.27904771e+00]],[[1.15734846e+00]],[[9.36670934e-01]],[[8.97971229e-01]],[[8.17983426e-01]],[[5.51106508e-01]],[[8.95770636e-04]]]),array([[[1.90477789e+00]],[[1.89829261e+00]],[[1.78303974e+00]],[[1.61353753e+00]],[[1.53859274e+00]],[[1.50765169e+00]],[[1.47722348e+00]],[[1.44736810e+00]],[[1.41812394e+00]],[[1.38951280e+00]],[[1.36154441e+00]],[[1.33422061e+00]],[[1.30753902e+00]],[[1.28149622e+00]],[[1.15764705e+00]],[[9.36670934e-01]],[[8.97971229e-01]],[[8.17983426e-01]],[[5.51106508e-01]],[[8.95770636e-04]]])]

data_E = np.array([  4.4612765 ,  -2.3341443 , -10.378765, -13.874788,
       -14.534859  , -14.705036  , -14.821358  , -14.8896    ,
       -14.914813  , -14.90294   , -14.857485  , -14.782236  ,
       -14.681216  , -14.557151  , -13.653582  , -10.439137  ,
        -7.4652775 ,  -5.1739723 ,  -3.5250227 ,  -0.78888653])
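For comparison with the dummy run, a stdlib-only look at the real targets (values copied from data_E above): 19 of the 20 are negative. If the model's output is a sum of ReLU values, the closest it can get to each negative target is 0, which matches the "always predicts 0" symptom.

```python
data_E = [4.4612765, -2.3341443, -10.378765, -13.874788,
          -14.534859, -14.705036, -14.821358, -14.8896,
          -14.914813, -14.90294, -14.857485, -14.782236,
          -14.681216, -14.557151, -13.653582, -10.439137,
          -7.4652775, -5.1739723, -3.5250227, -0.78888653]

negatives = [y for y in data_E if y < 0]

# Best value a non-negative output can produce for each target
best_reachable = [max(y, 0.0) for y in data_E]

print(len(negatives), len(data_E))  # 19 20
```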

model.fit(transformed_G, data_E, epochs=100, batch_size=14)

My loss gets stuck after a few epochs:

Epoch 1/100
2/2 [==============================] - 3s 13ms/step - loss: 11.7047
Epoch 2/100
2/2 [==============================] - 0s 10ms/step - loss: 11.7047
Epoch 3/100
2/2 [==============================] - 0s 12ms/step - loss: 11.7047
Epoch 4/100
2/2 [==============================] - 0s 11ms/step - loss: 11.7047
Epoch 5/100
2/2 [==============================] - 0s 12ms/step - loss: 11.7047
Epoch 6/100
2/2 [==============================] - 0s 12ms/step - loss: 11.7047
Epoch 7/100
2/2 [==============================] - 0s 11ms/step - loss: 11.7046
Epoch 8/100
2/2 [==============================] - 0s 20ms/step - loss: 11.7046
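The 2/2 in each progress bar is the number of batches per epoch. Assuming Keras's usual behavior of keeping a final partial batch, 20 samples with batch_size=14 give one full batch of 14 and one batch of 6:

```python
import math

n_samples = 20   # each of the 14 input arrays has 20 rows
batch_size = 14  # as passed to model.fit

steps_per_epoch = math.ceil(n_samples / batch_size)
last_batch = n_samples - batch_size * (steps_per_epoch - 1)
print(steps_per_epoch, last_batch)  # 2 6
```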

In this case it also always predicts 0.
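A toy, single-unit reproduction of the "always predicts 0" symptom, under simplifying assumptions (one input x = 1.0, a made-up negative target of -10, plain gradient descent instead of Adam). Once the ReLU's input drops below zero its derivative is 0, so no gradient reaches the weights and the output is frozen at 0. Adam changes the step sizes, but not the fact that the gradient through a ReLU is zero for negative inputs.

```python
def relu(z):
    return max(0.0, z)

def relu_grad(z):
    # 1 for z > 0, else 0 (0 at z == 0 by convention here)
    return 1.0 if z > 0 else 0.0

def train_relu_unit(x, target, w, b, lr=0.01, steps=100):
    """Gradient descent on (relu(w*x + b) - target)**2 for one unit."""
    losses = []
    for _ in range(steps):
        z = w * x + b
        pred = relu(z)
        losses.append((pred - target) ** 2)
        # Chain rule: dL/dz = 2 * (pred - target) * relu'(z)
        grad_z = 2.0 * (pred - target) * relu_grad(z)
        w -= lr * grad_z * x
        b -= lr * grad_z
    return w, b, relu(w * x + b), losses

w, b, pred, losses = train_relu_unit(x=1.0, target=-10.0, w=0.5, b=0.5)
print(pred)        # 0.0: the unit has "died"
print(losses[-1])  # 100.0, and it no longer changes
```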

Answer #1 (by 6pp0gazn)

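The solution can take any value in (-∞, ∞), so the output layer's activation must also be able to produce values from -∞ to ∞. Almost all of your targets in data_E are negative, but ReLU never outputs anything below 0; once an output unit's input goes negative, its gradient is zero, the weights stop updating, and the model gets stuck predicting 0. Change

output_layer = Dense(1, activation='relu')(hidden2)

to

output_layer = Dense(1, activation='linear')(hidden2)

or use a leaky ReLU (LeakyReLU) instead.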
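A toy illustration of why the suggested fix helps, using a single unit with made-up numbers (x = 1.0, target -10, plain gradient descent rather than Adam). With a linear output, or a leaky ReLU whose slope stays non-zero for negative inputs, the gradient never vanishes and the unit reaches the negative target. The slope alpha=0.3 is an assumption meant to mirror the default of keras.layers.LeakyReLU; check the default in your Keras version.

```python
def identity(z):
    # Linear activation: returns (value, derivative)
    return z, 1.0

def leaky_relu(z, alpha=0.3):
    # Small non-zero slope for z <= 0, so the gradient never becomes 0
    return (z, 1.0) if z > 0 else (alpha * z, alpha)

def train_unit(activation, x, target, w, b, lr=0.01, steps=5000):
    """Plain gradient descent on (activation(w*x + b) - target)**2."""
    for _ in range(steps):
        pred, slope = activation(w * x + b)
        grad_z = 2.0 * (pred - target) * slope  # dL/dz by the chain rule
        w -= lr * grad_z * x
        b -= lr * grad_z
    return activation(w * x + b)[0]

pred_linear = train_unit(identity, x=1.0, target=-10.0, w=0.5, b=0.5)
pred_leaky = train_unit(leaky_relu, x=1.0, target=-10.0, w=0.5, b=0.5)
print(round(pred_linear, 3), round(pred_leaky, 3))  # -10.0 -10.0
```

The leaky version converges more slowly because its effective step is scaled by alpha squared once the input is negative, which is one reason a plain linear output is the more common choice for unbounded regression targets.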
