python-3.x How to pass a variable from one task to another in Airflow

z9smfwbn  posted on 2023-01-10 in Python

The code below works, but my requirement is to pass totalbuckets as an input to the function instead of using it as a global variable. I'm having trouble passing it as a variable and then doing an xcom_pull on it in the next task. This DAG basically creates buckets based on the number of inputs, and totalbuckets is a constant. Appreciate any help.

from collections import defaultdict

from airflow import DAG
from airflow.operators.python import PythonOperator, BranchPythonOperator
from airflow.utils.trigger_rule import TriggerRule

with DAG('test-live', catchup=False, schedule_interval=None, default_args=args) as test_live:

    totalbuckets = 3

    # branches based on number of buckets
    def branch_buckets(**context):

        buckets = defaultdict(list)
        for i in range(len(inputs_to_process)):
            buckets[f'bucket_{(1+i % totalbuckets)}'].append(inputs_to_process[i])
      
        for bucket_name, input_sublist in buckets.items():
            context['ti'].xcom_push(key = bucket_name, value = input_sublist)
        return list(buckets.keys())
    
    # BranchPythonOperator will launch the buckets and distributes inputs among the buckets
    branch_buckets = BranchPythonOperator(
        task_id='branch_buckets',
        python_callable=branch_buckets,
        trigger_rule=TriggerRule.NONE_FAILED,
        provide_context=True,
        dag=test_live
    )  
    # update provider tables with merge sql
    def update_inputs(sf_conn_id, bucket_name, **context):
        input_sublist = context['ti'].xcom_pull(task_ids='branch_buckets', key=bucket_name)
        print(f"Processing inputs {input_sublist} in {bucket_name}")

        from custom.hooks.snowflake_hook import SnowflakeHook
        for p in input_sublist:
            merge_sql=f"""
            merge into ......"""

    bucket_tasks = []
    for i in range(totalbuckets):
        task = PythonOperator(
            task_id=f'bucket_{i+1}',
            python_callable=update_inputs,
            provide_context=True,
            op_kwargs={'bucket_name': f'bucket_{i+1}', 'sf_conn_id': SF_CONN_ID},
            dag=test_live
        )
        bucket_tasks.append(task)
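
For reference, the bare XCom push/pull pattern for handing a value from one task to the next looks roughly like this (a minimal, self-contained sketch; the DAG id, task ids, and pushed value are illustrative and separate from the DAG above):

from airflow import DAG
from airflow.operators.python import PythonOperator
import pendulum

def push_value(**context):
    # push a value under an explicit key from the upstream task
    context['ti'].xcom_push(key='totalbuckets', value=3)

def pull_value(**context):
    # pull it back in the downstream task by task_id and key
    totalbuckets = context['ti'].xcom_pull(task_ids='push_value', key='totalbuckets')
    print(f"totalbuckets = {totalbuckets}")

with DAG('xcom-demo', start_date=pendulum.datetime(2023, 1, 1),
         schedule_interval=None, catchup=False) as xcom_demo:
    push = PythonOperator(task_id='push_value', python_callable=push_value)
    pull = PythonOperator(task_id='pull_value', python_callable=pull_value)
    push >> pull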

ohfgkhjo1#

totalbuckets should be a run conf variable, which you can provide for each run you create from the UI, the CLI, the Airflow REST API, or even the Python API.

from collections import defaultdict

from airflow import DAG
from airflow.operators.python import PythonOperator, BranchPythonOperator
from airflow.models.param import Param
from airflow.utils.trigger_rule import TriggerRule
with DAG(
    'test-live',
    catchup=False,
    schedule_interval=None,
    default_args=args,
    params={"totalbuckets": Param(default=3, type="integer")},
) as test_live:
    # branches based on number of buckets
    def branch_buckets(**context):

        buckets = defaultdict(list)
        for i in range(len(inputs_to_process)):
            buckets[f'bucket_{(1+i % int("{{ params.totalbuckets }}"))}'].append(inputs_to_process[i])

        for bucket_name, input_sublist in buckets.items():
            context['ti'].xcom_push(key = bucket_name, value = input_sublist)
        return list(buckets.keys())

    # BranchPythonOperator will launch the buckets and distributes inputs among the buckets
    branch_buckets = BranchPythonOperator(
        task_id='branch_buckets',
        python_callable=branch_buckets,
        trigger_rule=TriggerRule.NONE_FAILED,
        provide_context=True,
        dag=test_live
    )
    # update provider tables with merge sql
    def update_inputs(sf_conn_id, bucket_name, **context):
        input_sublist = context['ti'].xcom_pull(task_ids='branch_buckets', key=bucket_name)
        print(f"Processing inputs {input_sublist} in {bucket_name}")

        from custom.hooks.snowflake_hook import SnowflakeHook
        for p in input_sublist:
            merge_sql=f"""
                merge into ......"""

    bucket_tasks = []
    for i in range(int("{{ params.totalbuckets }}")):
        task= PythonOperator(
            task_id=f'bucket_{i+1}',
            python_callable=update_inputs,
            provide_context=True,
            op_kwargs={'bucket_name':f'bucket_{i+1}','sf_conn_id': SF_CONN_ID},
            dag=test_live
        )
        bucket_tasks.append(task)

An example of triggering it:

airflow dags trigger --conf '{"totalbuckets": 10}' test-live

Or via the UI.
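
Since the answer also mentions the REST API, here is a minimal sketch of triggering the same run with a conf payload through Airflow's stable REST API (the host, credentials, and auth method below are placeholders for your own deployment):

import requests

# create a run of the "test-live" DAG and pass totalbuckets in the run conf
response = requests.post(
    'http://localhost:8080/api/v1/dags/test-live/dagRuns',
    auth=('airflow', 'airflow'),  # basic auth, if enabled on the webserver
    json={'conf': {'totalbuckets': 10}},
)
response.raise_for_status()
print(response.json())  # the created DagRun, including its run_id and conf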


1tu0hz3e2#

@hussein awala I'm doing something like the below, but totalbuckets cannot be resolved in bucket_tasks.

from collections import defaultdict

from airflow import DAG
from airflow.operators.python import PythonOperator, BranchPythonOperator
from airflow.utils.trigger_rule import TriggerRule

with DAG('test-live', catchup=False, schedule_interval=None, default_args=args) as test_live:

    # totalbuckets = 3

    def branch_buckets(totalbuckets, **context):

        buckets = defaultdict(list)
        for i in range(len(inputs_to_process)):
            buckets[f'bucket_{(1+i % totalbuckets)}'].append(inputs_to_process[i])
      
        for bucket_name, input_sublist in buckets.items():
            context['ti'].xcom_push(key = bucket_name, value = input_sublist)
        return list(buckets.keys())
    
    # BranchPythonOperator will launch the buckets and distributes inputs among the buckets
    branch_buckets = BranchPythonOperator(
        task_id='branch_buckets',
        python_callable=branch_buckets,
        trigger_rule=TriggerRule.NONE_FAILED,
        provide_context=True, op_kwargs={'totalbuckets':3},
        dag=test_live
    )  
    # update provider tables with merge sql
    def update_inputs(sf_conn_id, bucket_name, **context):
        input_sublist = context['ti'].xcom_pull(task_ids='branch_buckets', key=bucket_name)
        print(f"Processing inputs {input_sublist} in {bucket_name}")

        from custom.hooks.snowflake_hook import SnowflakeHook
        for p in input_sublist:
            merge_sql=f"""
            merge into ......"""

    bucket_tasks = []
    for i in range(totalbuckets):
        task = PythonOperator(
            task_id=f'bucket_{i+1}',
            python_callable=update_inputs,
            provide_context=True,
            op_kwargs={'bucket_name': f'bucket_{i+1}', 'sf_conn_id': SF_CONN_ID},
            dag=test_live
        )
        bucket_tasks.append(task)
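
For what it's worth, the op_kwargs value itself does reach the callable as an ordinary keyword argument at runtime; a minimal sketch of just that binding, in isolation (the DAG id and value here are illustrative):

from airflow import DAG
from airflow.operators.python import PythonOperator
import pendulum

def show_buckets(totalbuckets, **context):
    # op_kwargs entries are passed to the callable as keyword arguments
    print(f"totalbuckets = {totalbuckets}")

with DAG('op-kwargs-demo', start_date=pendulum.datetime(2023, 1, 1),
         schedule_interval=None, catchup=False) as demo:
    PythonOperator(
        task_id='show_buckets',
        python_callable=show_buckets,
        op_kwargs={'totalbuckets': 3},
    )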
