Airflow: get XCom from previous DAG run

I am writing a sensor that scans S3 files for a fixed period of time and adds the list of new files that arrived in that period to XCom for the next task. To do that, I am trying to access the list of files pushed to XCom by the previous run. I can do that with the snippet below:

context['task_instance'].get_previous_ti(state=State.SUCCESS).xcom_pull(key='new_files', task_ids=self.task_id, dag_id=self.dag_id)
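Once the previous run's list has been pulled, the remaining step described in the question is working out which files are new. Below is a minimal pure-Python sketch of that step, assuming the XCom value is a list of S3 keys already handled (or None when there is no previous run); `find_new_files` is a hypothetical helper, not an Airflow API.

```python
from typing import Iterable, List, Optional


def find_new_files(
    previous_files: Optional[Iterable[str]], current_files: Iterable[str]
) -> List[str]:
    """Return keys in current_files that the previous run did not list.

    previous_files is whatever xcom_pull returned for the previous run;
    it is None when there is no previous run (or nothing was pushed).
    """
    seen = set(previous_files or [])
    new_files: List[str] = []
    # Keep the listing order and skip duplicates within the current listing
    for key in current_files:
        if key not in seen:
            new_files.append(key)
            seen.add(key)
    return new_files


# Hypothetical S3 keys, for illustration only
previous = ["data/a.csv", "data/b.csv"]
current = ["data/a.csv", "data/b.csv", "data/c.csv"]
print(find_new_files(previous, current))  # ['data/c.csv']
print(find_new_files(None, current))      # ['data/a.csv', 'data/b.csv', 'data/c.csv']
```

Inside a sensor, the resulting list could then be pushed for the next run, e.g. with `context['task_instance'].xcom_push(key='new_files', value=...)`.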

However, the context object is only passed to the poke method, and I want to access the previous XCom in __init__. Is there another way to do it without using the context?

Note: I do not want to access the underlying XCom database table directly.

Thanks



Solution 1:[1]

I found a solution that (kind of) uses the underlying database, but you don't have to create a SQLAlchemy connection directly to use it.

The trick is to use the airflow.models.DagRun class, specifically its find() method, which lets you fetch all runs of a DAG by dag_id between two execution dates. From each DAG run you can then pull out the task instances and, from there, access their XComs.
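To make the windowing and indexing concrete without a running Airflow instance, the sketch below models what the solution relies on: runs filtered to an execution-date window (inclusive bounds assumed) and sorted oldest first, which is how `DagRun.find` orders its results in Airflow 2.x as far as I know. The current run is assumed to be the last entry, so the previous run sits at index `-2`. `FakeDagRun`, `find_runs` and `previous_run_value` are hypothetical stand-ins, not Airflow classes or functions.

```python
from dataclasses import dataclass, field
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, List, Optional


@dataclass
class FakeDagRun:
    """Hypothetical stand-in for a DAG run: an execution date plus pushed XComs."""
    execution_date: datetime
    xcoms: Dict[str, Any] = field(default_factory=dict)  # task_id -> return value


def find_runs(runs: List[FakeDagRun], start: datetime, end: datetime) -> List[FakeDagRun]:
    # Keep runs inside [start, end] and order them oldest first
    in_window = [r for r in runs if start <= r.execution_date <= end]
    return sorted(in_window, key=lambda r: r.execution_date)


def previous_run_value(runs: List[FakeDagRun], task_id: str) -> Optional[Any]:
    # The last run is assumed to be the current one, so the previous run is at -2
    if len(runs) < 2:
        return None
    return runs[-2].xcoms.get(task_id)


now = datetime(2021, 6, 1, 12, tzinfo=timezone.utc)
# (hours ago, value) pairs in arbitrary order; the 30-hour-old run is outside the window
history = [
    FakeDagRun(now - timedelta(hours=h), {"get_new_value": v})
    for h, v in [(1, 42), (30, 5), (0, 99), (2, 17)]
]
window = find_runs(history, now - timedelta(days=1), now)
print([r.xcoms["get_new_value"] for r in window])  # [17, 42, 99]
print(previous_run_value(window, "get_new_value"))  # 42
```

Unlike the DAG below, which returns 0 when there is no previous run, this sketch returns None in that case.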

import logging
from datetime import datetime, timedelta, timezone
from random import randint

from airflow import models
from airflow.operators.python import PythonOperator
from airflow.utils.dates import days_ago

default_args = {
    "start_date": days_ago(0),
    "retries": 0,
}
with models.DAG(
    "prev_xcom_tester",
    catchup=False,
    default_args=default_args,
    schedule_interval="@hourly",
    # max_active_runs is a DAG argument; it has no effect inside default_args
    max_active_runs=1,
    tags=["testing"],
) as dag:

    def get_new_value(**context):
        # The return value is pushed to XCom under the key "return_value"
        num = randint(1, 100)
        logging.info(f"building new value: {num}")
        return num

    def get_prev_xcom(**context):

        try:
            # Runs of this DAG from the last day, ordered by execution_date (oldest first)
            dag_runs = models.DagRun.find(
                dag_id="prev_xcom_tester",
                execution_start_date=(datetime.now(timezone.utc) - timedelta(days=1)),
                execution_end_date=datetime.now(timezone.utc),
            )

            this_val = context["ti"].xcom_pull(task_ids="get_new_value")

            # The last entry is the current run, so only earlier runs are checked
            for dr in dag_runs[:-1]:
                prev_val = dr.get_task_instance("get_new_value").xcom_pull(
                    task_ids="get_new_value"
                )
                logging.info(f"Checking dag run: {dr}, xcom was: {prev_val}")
                if this_val == prev_val:
                    logging.info(f"we already processed {this_val} in {dr}")

            # The second-to-last entry is the previous run
            return (
                dag_runs[-2]
                .get_task_instance("get_new_value")
                .xcom_pull(task_ids="get_new_value")
            )
        except Exception as e:
            # e.g. IndexError on the first run, when there is no previous run yet
            logging.warning(f"Could not read previous XCom: {e}")
            return 0

    def check_vals_match(**context):
        ti = context["ti"]
        prev_run_val = ti.xcom_pull(task_ids="get_prev_xcoms")
        current_run_val = ti.xcom_pull(task_ids="get_new_value")
        logging.info(
            f"Prev Run Val: {prev_run_val}\nCurrent Run Val: {current_run_val}"
        )
        return prev_run_val == current_run_val

    xcom_setter = PythonOperator(task_id="get_new_value", python_callable=get_new_value)

    xcom_getter = PythonOperator(
        task_id="get_prev_xcoms",
        python_callable=get_prev_xcom,
    )

    xcom_checker = PythonOperator(
        task_id="check_xcoms_match", python_callable=check_vals_match
    )

    xcom_setter >> xcom_getter >> xcom_checker

This DAG demonstrates how to:

  • Generate a random int between 1 and 100 and pass it through XCom
  • Find all DAG runs by dag_id and time span, and check whether this value was already processed in a past run
  • Return True if the current value matches the value from the previous run
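The two checks in the list above reduce to simple comparisons once the earlier values have been pulled. A minimal sketch, assuming the XCom values from earlier runs are already collected into a list ordered oldest first, with the current run's value excluded; `already_processed` and `matches_previous` are hypothetical names.

```python
from typing import Any, List, Optional


def already_processed(current: Any, earlier_values: List[Any]) -> bool:
    # Mirrors the loop over dag_runs[:-1]: any earlier run with the same value
    return any(v == current for v in earlier_values)


def matches_previous(current: Any, earlier_values: List[Any]) -> bool:
    # Mirrors check_vals_match: compare against the most recent earlier run only
    previous: Optional[Any] = earlier_values[-1] if earlier_values else None
    return previous == current


earlier = [17, 42, 8]  # hypothetical values from earlier runs, oldest first
print(already_processed(42, earlier))  # True
print(matches_previous(42, earlier))   # False
print(matches_previous(8, earlier))    # True
```

Keeping the "seen before" check separate from the "same as last run" check makes it easy to choose which one a downstream task should act on.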

Hope this helps!

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source
[1] Solution 1: niallsc