Airflow SSHOperator: How To Securely Access Pem File Across Tasks?
We are running Airflow via AWS's managed MWAA offering. As part of that offering, AWS includes a tutorial on securely using the SSHOperator in conjunction with AWS Secrets Manager. The gist of their solution is described below:
- Run a Task that fetches the pem file from a Secrets Manager location and stores it on the filesystem at /tmp/mypem.pem.
- In the SSH Connection, include the extra information that specifies the file location:
{"key_file":"/tmp/mypem.pem"}
- Use the SSH Connection in the SSHOperator.
In short the workflow is supposed to be:
Task1 gets the pem -> Task2 uses the pem via the SSHOperator
All of this is great in theory, but it doesn't actually work: Task1 may run on a different node from Task2, so Task2 can't read the /tmp/mypem.pem file that Task1 wrote. According to AWS Support, AWS is aware of this limitation, so we need another way to do this.
Question
How can we securely store and access a pem file that can then be used by Tasks running on different nodes via the SSHOperator?
Solution 1:[1]
I ran into the same problem. I extended the SSHOperator to do both steps in one call.
In AWS Secrets Manager, two secrets are added for Airflow to retrieve at execution time:
{variables_prefix}/airflow-user-ssh-key : the value of the private key
{connections_prefix}/ssh_airflow_user : ssh://[email protected]?key_file=%2Ftmp%2Fairflow-user-ssh-key
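As a rough illustration of how the second secret ties back to the first, the standard-library sketch below decodes the percent-encoded key_file parameter and derives the Variable name from it, mirroring the basename/splitext inference used by the operator below. The login and host in the URI are placeholder assumptions, not values from the original answer.

```python
from os.path import basename, splitext
from urllib.parse import parse_qs, urlsplit

# Hypothetical connection URI; "airflow-user" and "example-host" are placeholders.
uri = "ssh://airflow-user@example-host?key_file=%2Ftmp%2Fairflow-user-ssh-key"

parts = urlsplit(uri)
# parse_qs percent-decodes values, so %2F becomes "/".
key_file = parse_qs(parts.query)["key_file"][0]
# Strip the directory and any extension to get the Variable name.
variable_name = splitext(basename(key_file))[0]

print(key_file)
print(variable_name)
```

The derived name matches the last path segment of the first secret, which is why the two entries have to be kept in sync.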
from typing import Optional, Sequence
from os.path import basename, splitext
from airflow.models import Variable
from airflow.providers.ssh.operators.ssh import SSHOperator
from airflow.providers.ssh.hooks.ssh import SSHHook
# Reuses the parent's name so existing DAG code can switch by changing the import.
class SSHOperator(SSHOperator):
    """
    SSHOperator to execute commands on given remote host using the ssh_hook.

    :param ssh_conn_id: :ref:`ssh connection id<howto/connection:ssh>`
        from airflow Connections.
    :param ssh_key_var: name of Variable holding private key.
        Writes the key to "/tmp/{variable_name}" to use in SSH connection.
        May also be inferred from "key_file" in "extras" in "ssh_conn_id".
    :param remote_host: remote host to connect (templated)
        Nullable. If provided, it will replace the `remote_host` which was
        defined in `ssh_hook` or predefined in the connection of `ssh_conn_id`.
    :param command: command to execute on remote host. (templated)
    :param timeout: (deprecated) timeout (in seconds) for executing the command. The default is 10 seconds.
        Use conn_timeout and cmd_timeout parameters instead.
    :param environment: a dict of shell environment variables. Note that the
        server will reject them silently if `AcceptEnv` is not set in SSH config.
    :param get_pty: request a pseudo-terminal from the server. Set to ``True``
        to have the remote process killed upon task timeout.
        The default is ``False`` but note that `get_pty` is forced to ``True``
        when the `command` starts with ``sudo``.
    """

    template_fields: Sequence[str] = ("command", "remote_host")
    template_ext: Sequence[str] = (".sh",)
    template_fields_renderers = {"command": "bash"}

    def __init__(
        self,
        *,
        ssh_conn_id: Optional[str] = None,
        ssh_key_var: Optional[str] = None,
        remote_host: Optional[str] = None,
        command: Optional[str] = None,
        timeout: Optional[int] = None,
        environment: Optional[dict] = None,
        get_pty: bool = False,
        **kwargs,
    ) -> None:
        super().__init__(
            ssh_conn_id=ssh_conn_id,
            remote_host=remote_host,
            command=command,
            timeout=timeout,
            environment=environment,
            get_pty=get_pty,
            **kwargs,
        )
        if ssh_key_var is None:
            # Infer the Variable name from the connection's key_file,
            # e.g. "/tmp/airflow-user-ssh-key" -> "airflow-user-ssh-key".
            key_file = SSHHook(ssh_conn_id=self.ssh_conn_id).key_file
            key_filename = basename(key_file)
            key_filename_no_extension = splitext(key_filename)[0]
            self.ssh_key_var = key_filename_no_extension
        else:
            self.ssh_key_var = ssh_key_var

    def import_ssh_key(self):
        # Write the private key on the same worker that runs this task.
        with open(f"/tmp/{self.ssh_key_var}", "w") as file:
            file.write(Variable.get(self.ssh_key_var))

    def execute(self, context):
        self.import_ssh_key()
        # Return the parent's result so the command output can still be pushed to XCom.
        return super().execute(context)
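One possible hardening of import_ssh_key, which is not part of the original answer, is to create the key file with owner-only permissions rather than the process default. The sketch below assumes a POSIX worker; write_private_key is a hypothetical helper name, and the key material is a dummy string.

```python
import os
import stat
import tempfile


def write_private_key(path: str, key_text: str) -> None:
    """Write key_text to path, readable and writable by the owner only."""
    # Create with mode 0o600 so a new file is never briefly world-readable.
    fd = os.open(path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    with os.fdopen(fd, "w") as file:
        # Tighten permissions even if the file already existed.
        os.fchmod(file.fileno(), 0o600)
        file.write(key_text)


# Usage with a throwaway directory and a dummy value instead of a real key.
with tempfile.TemporaryDirectory() as tmp_dir:
    key_path = os.path.join(tmp_dir, "airflow-user-ssh-key")
    write_private_key(key_path, "dummy-key-material\n")
    mode = stat.S_IMODE(os.stat(key_path).st_mode)
    print(oct(mode))
```

In the operator above, the body of import_ssh_key could call such a helper with f"/tmp/{self.ssh_key_var}" and the Variable value.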
Solution 2:[2]
The answer by holly is good. I am sharing a different way I solved this problem. I converted the SSH Connection into a URI, stored that URI in Secrets Manager under the expected connections path, and everything worked via the SSHOperator. Below are the general steps I took.
- Generate an encoded URI
import json
from airflow.models.connection import Connection
from pathlib import Path
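# Placeholder path; point this at your own PEM file.
pem = Path("/my/pem/file.pem").read_text()
myconn = Connection(
    conn_id="connX",
    conn_type="ssh",
    host="10.x.y.z",
    login="mylogin",
    # The key itself travels in the connection extras, so no file is needed on the worker.
    extra=json.dumps(dict(private_key=pem)),
)
print(myconn.get_uri())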
- Input that URI under the environment's configured path in Secrets Manager. The important note here is to input the value in the plaintext field without including a key. Example:
airflow/connections/connX and under Plaintext only include the URI value
- Now in the SSHOperator you can reference this connection ID like any other.
from airflow.providers.ssh.operators.ssh import SSHOperator
remote_task = SSHOperator(
    task_id="ssh_and_execute_command",
    ssh_conn_id="connX",
    command="whoami",
)
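The reason a multi-line PEM can be stored as a single plaintext URI is that the key is percent-encoded in the query string. The sketch below shows only that encoding round trip with the standard library; it does not call Airflow, and the exact URI that Connection.get_uri() emits may differ by Airflow version. The host and the key contents are dummy values.

```python
from urllib.parse import parse_qs, urlencode, urlsplit

# Dummy multi-line value standing in for real PEM contents.
pem = "-----BEGIN EXAMPLE KEY-----\nAAAA+BBBB/CCCC==\n-----END EXAMPLE KEY-----\n"

# Hypothetical login and host; the query string carries the extra field.
uri = "ssh://mylogin@example-host/?" + urlencode({"private_key": pem})

# The encoded URI is a single line with no raw newlines or spaces.
single_line = "\n" not in uri and " " not in uri

# Decoding the query string recovers the key character for character.
decoded = parse_qs(urlsplit(uri).query)["private_key"][0]
print(single_line, decoded == pem)
```

Because the key lives inside the connection itself, every worker that resolves connX gets the key without anything being written to a shared filesystem.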
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|---|
| Solution 1 | holly.evans |
| Solution 2 | Howard_Roark |
