Passing a trained model to another function in Airflow
I want to use Airflow to run and display my model training. I created a model in a Python function, and now I want to pass it to another function that will train it. I am currently trying to pass it via XCom, but I get this error:
INFO - Done. Returned value was: <keras.wrappers.scikit_learn.KerasClassifier object at 0x7fce52713940>
[2022-04-27, 10:18:37 CEST] {xcom.py:447} ERROR - Could not serialize the XCom value into JSON.
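The error comes from how XCom stores values. With the default XCom backend in Airflow 2 (pickling stays disabled unless `enable_xcom_pickling` is turned on in the `[core]` config section), whatever a `python_callable` returns is serialized to JSON. The standard-library sketch below reproduces the same kind of failure outside Airflow. `FakeModel` and `try_json_serialize` are hypothetical names, and `FakeModel` stands in for the `KerasClassifier` object shown in the log.

```python
import json


class FakeModel:
    """Hypothetical stand-in for the KerasClassifier returned by create_model."""

    def __init__(self, layers):
        self.layers = layers


def try_json_serialize(value):
    """Return (True, json_text) if value is JSON-serializable, else (False, error message)."""
    try:
        return True, json.dumps(value)
    except TypeError as exc:
        return False, str(exc)


# Plain data such as a file path string serializes without problems ...
ok, text = try_json_serialize("/tmp/models/model.h5")
print(ok, text)

# ... but an arbitrary Python object does not, which is what XCom runs into.
ok, err = try_json_serialize(FakeModel(layers=[12, 1]))
print(ok, err)
```

Strings, numbers, lists and dicts pass; a model object does not, no matter which task returns it.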
How do I pass the model to the other function?
This is a part of my code:
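from airflow.operators.python import PythonOperator
from keras.layers import Dense
from keras.models import Sequential
from keras.wrappers.scikit_learn import KerasClassifier


def build_network():
    # create and compile the underlying Keras network
    # (build_fn must return a compiled Keras model, not the wrapper itself)
    model = Sequential()
    model.add(Dense(12, input_dim=8, activation='relu'))
    model.add(Dense(1, activation='sigmoid'))
    model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
    return model


def create_model():
    # wrap the network in a scikit-learn style classifier
    model = KerasClassifier(build_fn=build_network, verbose=0)
    return model  # the return value is pushed to XCom -> "Could not serialize" error


def trainModel(ti):
    model = ti.xcom_pull(task_ids='create_model')
    X, Y = ti.xcom_pull(task_ids='loadData')
    batch_size = 10
    epochs = 10
    # batch_size (not steps_per_epoch) sets the number of samples per batch
    model.fit(X, Y,
              batch_size=batch_size,
              epochs=epochs,
              )
    #model.save(model_file_name)
    return model  # same problem: a model object cannot be stored as JSON


create_model_task = PythonOperator(
    task_id='create_model',
    python_callable=create_model,
    provide_context=True,
    dag=MODEL_DAG
)

trainModel_Task = PythonOperator(
    task_id='trainModel',
    python_callable=trainModel,
    provide_context=True,
    dag=MODEL_DAG
)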
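A common workaround is to keep the model itself out of XCom: the first task saves the model to storage both tasks can reach and returns only the file path, which is a plain string and serializes to JSON; the next task pulls the path and loads the model from it. The sketch below shows that pattern with the standard library only, under these assumptions: `FakeTaskInstance` (with its `xcom_push_return` helper) is a hypothetical stand-in for Airflow's `ti` that mimics the JSON round-trip of the default XCom backend (in Airflow the return value is pushed automatically), a dict stands in for the Keras model, `pickle` stands in for Keras' own saving (with Keras you would more likely save the underlying network with `model.save(path)` and reload it with `keras.models.load_model(path)`), and a temporary directory stands in for shared storage.

```python
import json
import os
import pickle
import tempfile


class FakeTaskInstance:
    """Hypothetical stand-in for Airflow's TaskInstance (`ti`).

    Each task's return value is stored as JSON, mimicking the default
    (non-pickling) XCom backend, so non-JSON values would fail here too.
    """

    def __init__(self):
        self._store = {}

    def xcom_push_return(self, task_id, value):
        self._store[task_id] = json.dumps(value)  # raises TypeError for model objects

    def xcom_pull(self, task_ids):
        return json.loads(self._store[task_ids])


# Hypothetical shared location; in a real deployment every worker that runs
# these tasks must be able to read and write it.
MODEL_DIR = tempfile.mkdtemp()


def create_model():
    model = {"layers": [12, 1], "trained": False}  # stand-in for the Keras model
    path = os.path.join(MODEL_DIR, "untrained_model.pkl")
    with open(path, "wb") as f:
        pickle.dump(model, f)
    return path  # a string, so it is safe to push to XCom


def train_model(ti):
    path = ti.xcom_pull(task_ids="create_model")
    with open(path, "rb") as f:
        model = pickle.load(f)
    model["trained"] = True  # stand-in for model.fit(X, Y, ...)
    trained_path = os.path.join(MODEL_DIR, "trained_model.pkl")
    with open(trained_path, "wb") as f:
        pickle.dump(model, f)
    return trained_path


# Simulate the two tasks running in order and exchanging only file paths.
ti = FakeTaskInstance()
ti.xcom_push_return("create_model", create_model())
ti.xcom_push_return("trainModel", train_model(ti))
print(ti.xcom_pull(task_ids="trainModel"))
```

The design choice is that XCom carries only small metadata (a path) while the model lives in files. This assumes both tasks see the same filesystem; if they run on separate machines, the path would have to point at a shared volume or object storage instead.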
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow