import io
import threading
from typing import Optional, Callable
from dataclasses import dataclass

import paramiko

from config import get_config


@dataclass
class CommandResult:
    """Result of an SSH command execution."""

    stdout: str
    stderr: str
    exit_code: int
    success: bool


class SSHManager:
    """Manages SSH connections to target hosts.

    Clients are cached per ``username@hostname:port`` key. Access to the cache
    is guarded by a single ``threading.Lock``.
    """

    def __init__(self):
        self._connections: dict[str, paramiko.SSHClient] = {}
        self._lock = threading.Lock()

    @staticmethod
    def _connection_key(hostname: str, username: str, port: int) -> str:
        """Build the cache key identifying one (user, host, port) connection."""
        return f"{username}@{hostname}:{port}"

    @staticmethod
    def _is_active(client: paramiko.SSHClient) -> bool:
        """Return True if *client* has a live transport; never raises."""
        try:
            transport = client.get_transport()
            return transport is not None and transport.is_active()
        except Exception:
            return False

    @staticmethod
    def _close_quietly(closeable) -> None:
        """Close *closeable*, ignoring errors (best-effort teardown)."""
        try:
            closeable.close()
        except Exception:
            pass

    def _get_private_key(self) -> paramiko.PKey:
        """Parse the private key from config.

        Tries RSA, Ed25519, then ECDSA, and returns the first that parses.

        Raises:
            ValueError: if the key matches none of the supported formats.
                The last parser error is chained as the cause.
        """
        config = get_config()
        key_data = config.ssh_private_key

        key_file = io.StringIO(key_data)
        last_error: Optional[Exception] = None
        for key_class in [paramiko.RSAKey, paramiko.Ed25519Key, paramiko.ECDSAKey]:
            try:
                key_file.seek(0)  # each attempt must start from the beginning
                return key_class.from_private_key(key_file)
            except Exception as e:
                last_error = e
                continue

        raise ValueError(
            "Unable to parse SSH private key. Supported formats: RSA, Ed25519, ECDSA"
        ) from last_error

    def connect(self, hostname: str, username: str = "root", port: int = 22) -> bool:
        """Establish SSH connection to a host, reusing a live cached one.

        Returns:
            True once a live connection exists.

        Raises:
            ConnectionError: if key loading or the SSH handshake fails. The
                underlying exception is chained as the cause.
        """
        connection_key = self._connection_key(hostname, username, port)

        # NOTE(review): the lock is held for the whole handshake (up to the
        # 30s timeout), which serializes connects to *different* hosts too.
        with self._lock:
            existing = self._connections.get(connection_key)
            if existing is not None:
                if self._is_active(existing):
                    return True
                # Dead connection: drop it AND close it so its socket and
                # transport resources are released rather than leaked.
                del self._connections[connection_key]
                self._close_quietly(existing)

            client: Optional[paramiko.SSHClient] = None
            try:
                client = paramiko.SSHClient()
                # NOTE(review): AutoAddPolicy accepts any unknown host key,
                # so host identity is not verified (MITM exposure). Consider
                # loading known_hosts and using RejectPolicy.
                client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
                private_key = self._get_private_key()
                client.connect(
                    hostname=hostname,
                    port=port,
                    username=username,
                    pkey=private_key,
                    timeout=30,
                    allow_agent=False,
                    look_for_keys=False,
                )
            except Exception as e:
                # Don't leak a partially-initialized client.
                if client is not None:
                    self._close_quietly(client)
                raise ConnectionError(
                    f"Failed to connect to {connection_key}: {str(e)}"
                ) from e

            self._connections[connection_key] = client
            return True

    def execute(
        self,
        hostname: str,
        command: str,
        username: str = "root",
        port: int = 22,
        timeout: int = 60,
        on_output: Optional[Callable[[str], None]] = None,
    ) -> CommandResult:
        """Execute a command on a remote host.

        Args:
            hostname, username, port: identify a connection made via connect().
            command: command string passed to ``exec_command``.
            timeout: channel timeout (seconds) passed to ``exec_command``.
            on_output: if given, called with each stdout line (trailing
                newline stripped) as it arrives.

        Returns:
            A CommandResult. Errors raised while running the command are not
            propagated. They are reported as ``exit_code=-1``, with the
            exception text in ``stderr``.

        Raises:
            ConnectionError: if there is no cached connection for the host.
        """
        connection_key = self._connection_key(hostname, username, port)

        with self._lock:
            client = self._connections.get(connection_key)

        if not client:
            raise ConnectionError(f"Not connected to {connection_key}. Call connect() first.")

        channel = None
        try:
            _stdin, stdout, stderr = client.exec_command(command, timeout=timeout)
            channel = stdout.channel

            # NOTE(review): stderr is drained only after stdout reaches EOF;
            # a command emitting very large stderr output may stall. Verify,
            # or read both streams concurrently if that matters.
            if on_output:
                # Iterating the paramiko file yields decoded text lines.
                chunks: list[str] = []
                for line in stdout:
                    chunks.append(line)
                    on_output(line.rstrip("\n"))
                stdout_data = "".join(chunks)
            else:
                stdout_data = stdout.read().decode("utf-8", errors="replace")

            stderr_data = stderr.read().decode("utf-8", errors="replace")
            exit_code = channel.recv_exit_status()

            return CommandResult(
                stdout=stdout_data,
                stderr=stderr_data,
                exit_code=exit_code,
                success=exit_code == 0,
            )
        except Exception as e:
            return CommandResult(
                stdout="",
                stderr=str(e),
                exit_code=-1,
                success=False,
            )
        finally:
            # Always release the per-command channel, including on timeout.
            if channel is not None:
                self._close_quietly(channel)

    def disconnect(self, hostname: str, username: str = "root", port: int = 22) -> None:
        """Close SSH connection to a host (no-op if not connected)."""
        connection_key = self._connection_key(hostname, username, port)

        with self._lock:
            client = self._connections.pop(connection_key, None)

        if client:
            self._close_quietly(client)

    def disconnect_all(self) -> None:
        """Close all SSH connections and empty the cache."""
        with self._lock:
            for client in self._connections.values():
                self._close_quietly(client)
            self._connections.clear()

    def is_connected(self, hostname: str, username: str = "root", port: int = 22) -> bool:
        """Check if a cached connection to the host exists and is active."""
        connection_key = self._connection_key(hostname, username, port)

        with self._lock:
            client = self._connections.get(connection_key)
            if not client:
                return False
            return self._is_active(client)


# Global SSH manager instance (created lazily by get_ssh_manager).
_ssh_manager: Optional[SSHManager] = None
# Guards lazy creation so concurrent first calls can't build two managers
# with separate connection caches.
_ssh_manager_lock = threading.Lock()


def get_ssh_manager() -> SSHManager:
    """Get the global SSH manager instance, creating it on first use."""
    global _ssh_manager
    if _ssh_manager is None:
        with _ssh_manager_lock:
            if _ssh_manager is None:
                _ssh_manager = SSHManager()
    return _ssh_manager