在學習keras深度學習框架的過程中我們可能會遇到Keras運行變慢,內(nèi)存消耗變大的問題,這些問題其實是有g(shù)et_value函數(shù)運行越來越慢導致的,那么怎么解決這些問題呢?接下來小編就帶你來了解。
問題描述
如上圖所示,經(jīng)過時間和內(nèi)存消耗跟蹤測試,發(fā)現(xiàn)是keras.backend.get_value() 函數(shù)導致的程序越來越慢,而且嚴重的造成內(nèi)存泄露;
查看該函數(shù)內(nèi)部實現(xiàn),發(fā)現(xiàn)一個主要核心是x.eval(session=get_session()),該語句可能是導致內(nèi)存泄露和運行慢的核心語句; 根據(jù)查看一些博文得到了運行得越來越慢的
原因:該x.eval函數(shù)會添加新的節(jié)點到tf的圖中;而這也導致了tf的圖越來越大,內(nèi)存泄露;
解決方法
import tensorflow.keras.backend as K
def get_my_session(gpu_fraction=0.1):
'''Assume that you have 6GB of GPU memory and want to allocate ~2GB'''
num_threads = os.environ.get('OMP_NUM_THREADS')
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=gpu_fraction)
if num_threads:
return tf.Session(config=tf.ConfigProto(
gpu_options=gpu_options, intra_op_parallelism_threads=num_threads))
else:
return tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
K.set_session(get_my_session())
如上圖所示, 我在使用tensorflow之前(也就是該工程文件前面),對session進行自定義,然后用自定義的session設定keras.backend.set_session();
然后刪除get_value() 函數(shù),直接用get_value()中所使用的執(zhí)行語句x.eval(session=get_my_session());這樣這個添加節(jié)點導致內(nèi)存泄露的核心語句x.eval()就使用的是該工程統(tǒng)一自定義session,然后用tf.reset_default_graph() 對圖重置就可以了
即上圖問題代碼修改為:
output = ctc_decode(y_pred,input_length=input_length,)
output = output[0][0]
out = output.eval(session=get_my_session())
# 刪除 K.get_value(out[0][0])
tf.reset_default_graph() # 然后重置tf圖,這句很關(guān)鍵
這樣就解決了get_value()導致的越來越慢的問題;
個人認為:這樣可能就不會總是添加新的節(jié)點,導致tf圖不斷地無限變大;而是重復使用這一個自定義的節(jié)點。
補充:tensorflow與keras之間版本問題引起get_session問題解決辦法
1.產(chǎn)生報錯原因
import tensorflow.keras.backend as K
def __init__(self, **kwargs):
self.__dict__.update(self._defaults) # set up default values
self.__dict__.update(kwargs) # and update with user overrides
self.class_names = self._get_class()
self.anchors = self._get_anchors()
self.sess = K.get_session()
報錯如下:
get_session is not available when using TensorFlow 2.0.
意思是 tf2.0 沒有 get_session
2.解決方案1
import tensorflow.python.keras.backend as K
sess = K.get_session()
3. 解決方案2
import tensorflow as tf
sess = tf.compat.v1.keras.backend.get_session()
之前一直采用方案1 解決,感覺比較方便;但是解決方案1 有其它屬性會丟失問題
比如AttributeError: module ‘keras.backend' has no attribute image_dim_ordering
所以建議大家采用方案2
以上就是Keras內(nèi)存消耗變大和keras運行變慢的解決方案,希望能給大家一個參考,也希望大家多多支持W3Cschool。