Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions config/agent/agent.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ session:
# https://asyncssh.readthedocs.io/en/latest/api.html#supported-algorithms
# enabled all algorithms by default, add extra algorithms if needed
encryption_algs: []
# Set terminal size, optional item
term_size:
width: 1000
height: 100
profiles:
# global profile, used when no other profile matche
global:
Expand All @@ -38,6 +42,9 @@ session:
expiration_time: "16:00"
# If the idle time exceeds the maximum, close the session. Unit: Seconds
max_idle_time: 600.0
# https://asyncssh.readthedocs.io/en/stable/api.html#publickeyalgs
# A list of server host key algorithms to allow during the SSH handshake
server_host_key_algs: []
# profile based on vendor/type/version, priority is higher than default
vendor:
cisco:
Expand All @@ -46,13 +53,16 @@ session:
read_timeout: 60.0
expiration_time: "16:00"
max_idle_time: 600.0
server_host_key_algs: []
# 9.8:
# read_timeout: 12.0
# expiration_time: "16:00"
# max_idle_time: 600.0
# server_host_key_algs: []
# profile based on IP address, priority is higher than vendor/type/version
# ip:
# 192.168.60.198:
# read_timeout: 15.0
# expiration_time: "16:00"
# max_idle_time: 600.0
# server_host_key_algs: []
103 changes: 84 additions & 19 deletions packages/agent/src/netdriver_agent/client/channel.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from abc import abstractmethod
from dependency_injector.providers import Configuration
from netdriver_core.utils.terminal import simulate_output, simulate_output_oct_to_chinese
from pydantic import IPvAnyAddress
from re import Match, Pattern
from typing import Optional, Tuple, List
import asyncssh
import re

from netdriver_core.exception.errors import ChannelError, ChannelReadTimeout
from netdriver_core.exception.errors import ChannelError
from netdriver_core.log import logman
from netdriver_core.utils.asyncu import async_timeout

Expand Down Expand Up @@ -95,19 +98,26 @@
_DEFAULT_READ_BUFFER_SIZE = 8192


def update_ssh_config(kwargs: dict, config: Configuration) -> dict:
def update_ssh_config(kwargs: dict, profile: dict, config: Configuration) -> dict:
""" Update SSH configuration with defaults and provided parameters """
extra_kex_algs = set(config.session.ssh.kex_algs() or [])
extra_encryption_algs = (config.session.ssh.encryption_algs() or [])
ssh = config.session.ssh
extra_kex_algs = set(ssh.kex_algs() or [])
extra_encryption_algs = (ssh.encryption_algs() or [])
ssh_config = _DEFAUTL_SSH_CONFIG.copy()
ssh_config["kex_algs"] = list(ssh_config["kex_algs"].union(extra_kex_algs))
ssh_config["encryption_algs"] = list(ssh_config["encryption_algs"].union(extra_encryption_algs))
ssh_config["login_timeout"] = config.session.ssh.login_timeout() or ssh_config["login_timeout"]
ssh_config["connect_timeout"] = config.session.ssh.connect_timeout() or ssh_config["connect_timeout"]
ssh_config["keepalive_interval"] = config.session.ssh.keepalive_interval() or ssh_config["keepalive_interval"]
ssh_config["keepalive_count_max"] = config.session.ssh.keepalive_count_max() or ssh_config["keepalive_count_max"]
ssh_config["login_timeout"] = ssh.login_timeout() or ssh_config["login_timeout"]
ssh_config["connect_timeout"] = ssh.connect_timeout() or ssh_config["connect_timeout"]
ssh_config["keepalive_interval"] = ssh.keepalive_interval() or ssh_config["keepalive_interval"]
ssh_config["keepalive_count_max"] = ssh.keepalive_count_max() or ssh_config["keepalive_count_max"]
server_host_key_algs = profile.get("server_host_key_algs", [])
if server_host_key_algs:
ssh_config["server_host_key_algs"] = server_host_key_algs
term_size = ()
kwargs.update(ssh_config)
return kwargs
if ssh.term_size() and ssh.term_size.width() and ssh.term_size.height():
term_size = (ssh.term_size.width(), ssh.term_size.height())
return kwargs, term_size


class Channel:
Expand All @@ -125,7 +135,6 @@ async def create(cls,
username: Optional[str] = None,
password: Optional[str] = None,
encode: str = "utf-8",
term_size: Tuple = None,
logger: object = None,
profile: dict = {},
config: Configuration = None,
Expand All @@ -136,13 +145,13 @@ async def create(cls,
cls._read_channel_until_timeout = profile.get("read_timeout", DEFAULT_SESSION_PROFILE.get("read_timeout", 10))

if protocol == "ssh":
kwargs = update_ssh_config(kwargs, config)
kwargs, term_size = update_ssh_config(kwargs, profile, config)
conn = await asyncssh.connect(
host=str(ip), port=port, username=username, password=password,
encoding=encode, **kwargs)
terminal = await conn.create_process(term_type="ansi", term_size=term_size)
terminal.stdout.channel.set_encoding(encoding=encode, errors='replace')
return SSHChannel(conn, terminal, logger=logger)
return SSHChannel(conn, terminal, logger=logger, encode=encode)
else:
raise ValueError(f"protocol {protocol} not supported.")

Expand Down Expand Up @@ -195,11 +204,13 @@ class SSHChannel(Channel):

def __init__(self, conn: asyncssh.SSHClientConnection,
terminal: asyncssh.SSHClientProcess,
logger: object = None) -> None:
logger: object = None,
encode: str = "utf-8") -> None:
""" SSH Channel """
self._conn = conn
self._terminal = terminal
self._logger = logger
self._encode = encode

def _check_channel(self):
if not self._conn:
Expand Down Expand Up @@ -235,7 +246,7 @@ async def read_channel_until(
:return: str, the data read from the channel
"""
self._check_channel()
output = ReadBuffer(cmd=cmd)
output = ReadBuffer(cmd=cmd, encode=self._encode)
while not self.read_at_eof():
chunk = await self.read_channel(self._read_buffer_size)
output.append(chunk)
Expand Down Expand Up @@ -276,20 +287,74 @@ class ReadBuffer:
_cmd: str
_is_cmd_displayed: bool = False

def __init__(self, cmd: str = '', line_break: str = '\n') -> None:
def __init__(self, cmd: str = '', line_break: str = '\n', encode: str = None) -> None:
""" Initialize read buffer """
self._buffer = []
self._last_line_pos = (0, 0)
self._line_break = line_break
self._cmd = cmd
self._is_cmd_displayed = False
self._encode = encode

def _check_cmd_displayed(self, line: str = '') -> bool:
def _check_cmd_displayed(self, pattern: Pattern, line: str = '') -> bool:
if not self._is_cmd_displayed and self._cmd and line:
# check if the command is displayed in the line
if self._cmd in line:
if self._cmd_in_line(pattern, line):
self._is_cmd_displayed = True
log.trace(f"Command '{self._cmd}' is displayed in the line: {line}")

def _cmd_in_line(self, pattern: Pattern, line: str = '') -> bool:
"""check if the line contains cmd"""

log.trace(f"Line repr = {repr(line)}")
log.trace(f"Line escape = {line.encode('unicode_escape').decode('ascii')}")

# Topsec output extra ' \r' char
# Fortinet output extra ' \x08' char
if self._cmd in re.sub(r'\s[\r\x08]', '', line):
return True

# Juniper input extra spaces, and the output will remove the extra spaces
if self._cmd.replace(' ', '') in re.sub(r'[\x07\s]', '', line):
return True

# Fortinet input Chinese and output octal char
chinese = simulate_output_oct_to_chinese(output=line, encoding=self._encode)
log.trace(f"Line oct to chinese: {repr(chinese)}")
if self._cmd in chinese:
return True

# Line Remove prompt
for index in range(1, len(line)):
if re.match(pattern, line[:index]):
line = line[index:].lstrip()
break
log.trace(f"Line remove prompt = {repr(line)}")

# Topsec chinese escape failed, character ignored
if '\ufffd' in line:
line_splits = re.sub(r"(\s\r|\r\n)", '', line).split('\ufffd')
log.trace(f"Line split by \\ufffd = {line_splits}")
for line_split in line_splits:
if line_split not in self._cmd:
return False
return True

line = simulate_output(line)
log.trace(f"Line simulate output = {repr(line)}")

# Line simulate output remove prompt
for index in range(1, len(line)):
if re.match(pattern, line[:index]):
line = line[index:].lstrip()
break
log.trace(f"Line simulate output remove prompt = {repr(line)}")

# Array or Cisco display ultra wide processing
if '$' in line and line.replace('\x08', '').split('$')[0] in self._cmd:
return True

return False

def _is_real_prompt(self) -> bool:
if self._cmd:
Expand Down Expand Up @@ -332,7 +397,7 @@ def check_pattern(self, pattern: Pattern, is_update_checkpos: bool = True) -> Ma
while lb_pos != -1:
# found a line break, concat the line
line = ''.join([line, self._buffer[i][line_start_pos:lb_pos], self._line_break])
self._check_cmd_displayed(line)
self._check_cmd_displayed(pattern, line)
line_start_pos = lb_pos + len(self._line_break)
log.trace(f"Checking buffer[{i}][:{line_start_pos}]: {line}")
matched = pattern.search(line)
Expand All @@ -350,7 +415,7 @@ def check_pattern(self, pattern: Pattern, is_update_checkpos: bool = True) -> Ma

# no line break found, check the rest of buffer item
line = ''.join([line, self._buffer[i][line_start_pos:]])
self._check_cmd_displayed(line)
self._check_cmd_displayed(pattern, line)
line_start_pos += len(line)
if i == buffer_size - 1:
# if no line break found and no more buffer, check the last line
Expand Down
66 changes: 0 additions & 66 deletions packages/agent/src/netdriver_agent/client/merger.py

This file was deleted.

Loading
Loading